In [6]:
import pandas as pd

# Load the three classification-result tables, one DataFrame per CSV file
deb, ogle, wumacat = (
    pd.read_csv(csv_path)
    for csv_path in (
        '../data/tests_results/classification_gaia_deb.csv',
        '../data/tests_results/classification_gaia_ogle.csv',
        '../data/tests_results/classification_tess_WUMaCat.csv',
    )
)

In [85]:
# Summarise the columns, dtypes and non-null counts of the `deb` table
deb.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 52 entries, 0 to 51
Data columns (total 10 columns):
 #   Column           Non-Null Count  Dtype 
---  ------           --------------  ----- 
 0   Name             52 non-null     object
 1   Gaia             52 non-null     int64 
 2   binary_tess_Res  52 non-null     object
 3   spot_Tess_Res    52 non-null     object
 4   binary_tess_ViT  52 non-null     object
 5   spot_tess_ViT    52 non-null     object
 6   binary_gaia_Res  52 non-null     object
 7   spot_Gaia_Res    52 non-null     object
 8   binary_gaia_ViT  52 non-null     object
 9   spot_Gaia_ViT    52 non-null     object
dtypes: int64(1), object(9)
memory usage: 4.2+ KB


In [34]:
def calculate_correct_classification_percentage(df, class_name='det'):
    """
    Calculates, for every binary classification column, the percentage of rows labelled ``class_name``.

    A column is treated as a binary classification column when its name contains
    the substring ``'binary'``. The denominator is the total number of rows in
    ``df`` (missing values therefore count as non-matching).

    Args:
        df (pd.DataFrame): The input DataFrame containing classification results.
        class_name (str): The label that counts as a correct classification (default: 'det').

    Returns:
        dict: Maps each matching column name (unchanged, in column order) to its
            percentage as a float. For an empty DataFrame every percentage is NaN,
            because the ratio is undefined.
    """
    total_objects = len(df)
    results = {}

    for col in df.columns:
        # str() keeps non-string column labels (e.g. integers) from raising TypeError
        if 'binary' not in str(col):
            continue

        if total_objects == 0:
            # No rows: report "undefined" instead of dividing by zero
            results[col] = float('nan')
            continue

        # Summing the boolean comparison counts the matching rows directly
        correct_classifications = int((df[col] == class_name).sum())
        results[col] = (correct_classifications / total_objects) * 100

    return results

# Share of rows labelled 'det' in each binary column of the `deb` table
deb_classification_percentages = calculate_correct_classification_percentage(
    deb,
    class_name='det',
)

print("Deb Classification Percentages:")
for col in deb_classification_percentages:
    percentage = deb_classification_percentages[col]
    print(f"{col}: {percentage:.2f}%")

Deb Classification Percentages:
 binary_tess_Res: 100.00%
 binary_tess_ViT: 100.00%
  binary_gaia_Res: 100.00%
 binary_gaia_ViT: 100.00%


In [None]:
# Share of rows labelled 'over' in each binary column of the `wumacat` table
wumacat_classification_percentages = calculate_correct_classification_percentage(
    wumacat,
    class_name='over',
)

print("WU UMa Classification Percentages:")
for col in wumacat_classification_percentages:
    percentage = wumacat_classification_percentages[col]
    print(f"{col}: {percentage:.2f}%")

WU UMa Classification Percentages:
 binary_tess_Res: 100.00%
 binary_tess_ViT: 100.00%
  binary_gaia_Res: 96.67%
 binary_gaia_ViT: 91.11%


In [89]:
# Count how many rows of `wumacat` carry each label in the binary_gaia_ViT column
wumacat[['binary_gaia_ViT']].value_counts()

binary_gaia_ViT
over               82
det                 8
Name: count, dtype: int64

In [None]:
def clean_dataframe(df):
    """
    Returns a cleaned copy of the DataFrame: spaces are removed from column names
    and surrounding whitespace is stripped from string values.

    The input frame is left untouched, so the cell calling this function can be
    re-run safely and the raw data stays available.

    Args:
        df (pd.DataFrame): The input DataFrame (not modified).

    Returns:
        pd.DataFrame: A new DataFrame with cleaned column names and string values.
            Non-string column labels and non-string cell values are kept as they are.
    """
    def _without_spaces(label):
        # Only string labels can contain spaces; leave any other label type alone
        return label.replace(' ', '') if isinstance(label, str) else label

    def _strip_if_str(value):
        return value.strip() if isinstance(value, str) else value

    # rename() returns a new frame, so the caller's DataFrame is never mutated
    cleaned = df.rename(columns=_without_spaces)

    for col in cleaned.columns:
        column = cleaned[col]
        if pd.api.types.is_object_dtype(column):
            # Element-wise on purpose: `.str.strip()` would turn non-string
            # entries of a mixed object column into NaN.
            cleaned[col] = column.map(_strip_if_str)
        elif pd.api.types.is_string_dtype(column):
            # Dedicated string dtypes hold only strings / missing values
            cleaned[col] = column.str.strip()
    return cleaned

# Clean the three tables in the same order as before: ogle, wumacat, deb
ogle, wumacat, deb = (
    clean_dataframe(frame) for frame in (ogle, wumacat, deb)
)

                  Name                 Gaia orig_ogle_class binary_I_Res  \
0  OGLE-BLG-ECL-002011  6028823779367951744             det          det   
1  OGLE-BLG-ECL-004840  4107331719835398656             det          det   
2  OGLE-BLG-ECL-005098  4059230147580348160             det          det   
3  OGLE-BLG-ECL-005728  4107530701320089728            over         over   
4  OGLE-BLG-ECL-010040  4109951241823489408             det          det   

  spot_I_Res binary_I_ViT spot_I_ViT binary_gaia_Res spot_Gaia_Res  \
0          s          det          s             det             n   
1          n          det          s             det             s   
2          s          det          s             det             s   
3          s         over          n            over             n   
4          n         over          n             det             n   

  binary_gaia_ViT spot_Gaia_ViT  
0             det             n  
1             det             s  
2             det   

In [59]:
# Per original OGLE class: percentage of each predicted label (Gaia-based columns)
for gaia_col in ('binary_gaia_Res', 'binary_gaia_ViT'):
    proportions = ogle.groupby('orig_ogle_class')[[gaia_col]].value_counts(normalize=True)
    print(round(proportions * 100, 1))

orig_ogle_class  binary_gaia_Res
det              det                96.8
                 over                3.2
over             over               97.4
                 det                 2.6
Name: proportion, dtype: float64
orig_ogle_class  binary_gaia_ViT
det              det                93.5
                 over                6.5
over             over               96.1
                 det                 3.9
Name: proportion, dtype: float64


In [61]:
# Per original OGLE class: percentage of each predicted label (binary_I_* columns)
for i_band_col in ('binary_I_Res', 'binary_I_ViT'):
    proportions = ogle.groupby('orig_ogle_class')[[i_band_col]].value_counts(normalize=True)
    print(round(proportions * 100, 1))

orig_ogle_class  binary_I_Res
det              det             86.3
                 over            13.7
over             over            96.1
                 det              3.9
Name: proportion, dtype: float64
orig_ogle_class  binary_I_ViT
det              det             82.3
                 over            17.7
over             over            94.7
                 det              5.3
Name: proportion, dtype: float64


In [83]:
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
import numpy as np

import sklearn.metrics
import matplotlib.pyplot as plt

# Prediction columns and the plot title used for each one (paired by position)
predicted_label_cols = ['binary_I_Res', 'binary_I_ViT', 'binary_gaia_Res', 'binary_gaia_ViT']
titles_to_print = ['ResNet_I', 'ViT_I', 'ResNet_Gaia', 'ViT_Gaia']
assert len(predicted_label_cols) == len(titles_to_print), "each prediction column needs a title"

# True labels
true_label_col = 'orig_ogle_class'
true_labels = ogle[true_label_col]

# Fit the encoder on every label found in the truth column OR in any prediction
# column. Fitting on the truth column alone makes transform() raise ValueError
# as soon as a model predicts a class the truth column never contains.
le = LabelEncoder()
le.fit(np.concatenate(
    [ogle[label_col].to_numpy() for label_col in [true_label_col, *predicted_label_cols]]
))
true_labels_encoded = le.transform(true_labels)

# Fixed label list so every confusion matrix has one row/column per known class,
# matching the tick labels taken from le.classes_
encoded_classes = np.arange(len(le.classes_))

# Font size of the numbers written inside the heatmap cells
annot_kws = {'size': 20}

# One row-normalised confusion matrix (saved as PNG) per prediction column
for predicted_label_col, title in zip(predicted_label_cols, titles_to_print):
    predicted_labels = ogle[predicted_label_col]
    predicted_labels_encoded = le.transform(predicted_labels)

    # Raw counts: rows are true classes, columns are predicted classes
    cm = sklearn.metrics.confusion_matrix(
        true_labels_encoded, predicted_labels_encoded, labels=encoded_classes
    )

    # Normalise each row by the number of objects of that true class.
    # A class with no true samples yields NaN instead of a division-by-zero warning.
    class_totals = cm.sum(axis=1, keepdims=True)
    custom_cm = np.divide(
        cm, class_totals,
        out=np.full(cm.shape, np.nan),
        where=class_totals != 0,
    )

    # NOTE: despite the name, this holds one value PER TRUE CLASS (diagonal count
    # divided by the row total), not a single overall accuracy. The name and the
    # printed label are kept unchanged for compatibility with existing output.
    overall_accuracy = np.diag(custom_cm)

    # Fraction of each true class that was assigned to another class
    misclassification_rates = 1 - overall_accuracy

    # Explicit figure/axes per plot instead of the implicit pyplot state
    fig, ax = plt.subplots(figsize=(10, 8))

    # ViT models are drawn in red, all other models in blue
    cmap = "Reds" if 'ViT' in predicted_label_col else "Blues"

    sns.heatmap(custom_cm, annot=True, fmt=".2f", cmap=cmap, cbar=False,
                xticklabels=le.classes_, yticklabels=le.classes_,
                annot_kws=annot_kws, square=True, ax=ax)

    ax.set_xlabel('Predicted class', fontsize=20)
    ax.set_ylabel('True class', fontsize=20)
    ax.set_title(f'{title}', fontsize=22)
    ax.tick_params(axis='both', which='major', labelsize=18)

    # Save the plot to a file, then close the figure to free memory
    fig.tight_layout()
    fig.savefig(f'confusion_matrix_{title}.png')
    plt.close(fig)

    print(f"Misclassification rates for {predicted_label_col}:", misclassification_rates)
    print(f"Overall accuracy for {predicted_label_col}:", overall_accuracy)

print("Confusion matrices saved as PNG files.")


Misclassification rates for binary_I_Res: [0.13709677 0.03947368]
Overall accuracy for binary_I_Res: [0.86290323 0.96052632]
Misclassification rates for binary_I_ViT: [0.17741935 0.05263158]
Overall accuracy for binary_I_ViT: [0.82258065 0.94736842]
Misclassification rates for binary_gaia_Res: [0.03225806 0.02631579]
Overall accuracy for binary_gaia_Res: [0.96774194 0.97368421]
Misclassification rates for binary_gaia_ViT: [0.06451613 0.03947368]
Overall accuracy for binary_gaia_ViT: [0.93548387 0.96052632]
Confusion matrices saved as PNG files.
