# Imports

In [None]:
import os
# NOTE(review): presumably moves into the project package root so the
# project-local imports below (config, metric, pipeline) resolve. The path is
# relative to the current working directory, so re-running this cell will try
# to change directory again from the new location.
os.chdir('../../vlm_toolbox/')

In [None]:
# Enable IPython's autoreload extension; `reload_ext` makes re-running the cell
# safe when the extension is already loaded.
%load_ext autoreload
%reload_ext autoreload
# Mode 2: reload modules before executing each cell, so edits to project code
# are picked up without restarting the kernel.
%autoreload 2

In [None]:
import gc
import warnings

import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib import pyplot as plt

from config.annotations import AnnotationsConfig
from config.enums import (
    CLIPBackbones,
    ImageDatasets,
    LossType,
    Metrics,
    ModelType,
    PrecisionDtypes,
    Setups,
    Trainers,
)
from config.setup import Setup
from metric.classification import ClassificationMetricEvaluator
from metric.visualization.accuracy import plot_model_accuracy
from pipeline.pipeline import Pipeline

In [None]:
# Silence every Python warning to keep the notebook output readable, and turn
# off tokenizer parallelism through its environment switch.
warnings.filterwarnings('ignore')
os.environ.update(TOKENIZERS_PARALLELISM='false')

In [None]:
def flush():
    """Free unreachable Python objects, then release PyTorch's cached CUDA memory."""
    # Garbage-collect first so objects kept alive only by reference cycles are
    # gone before the CUDA caching allocator is asked to release its blocks.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

# Util

In [None]:
def compute_per_class_statistics(df, top_k):
    """Summarise, per true class, the dominant prediction at each of the first ``top_k`` ranks.

    For every rank ``k`` in ``1..top_k`` and every ``class_id``, pick the label
    that appears most often in ``pred@k_label_id`` (the first one wins on ties).
    Also record the fraction of that class's samples that received it and the
    mean ``pred@k_prob`` over those samples.

    Args:
        df: Per-sample frame with a ``class_id`` column and, for each rank ``k``,
            ``pred@k_label_id`` and ``pred@k_prob`` columns.
        top_k: Number of ranks to summarise. It must be >= 1, because
            ``is_correct`` reads the rank-1 columns.

    Returns:
        A frame with one row per class, sorted by ``class_id``. For each rank it
        holds ``pred_class_id@k``, ``frequency@k`` and ``confidence@k``. It also
        holds ``is_correct``, which is True when the dominant rank-1 prediction
        equals ``class_id``. A class with no non-null prediction at some rank
        gets NaN in that rank's columns.
    """
    total_samples_per_class = (
        df['class_id'].value_counts().rename_axis('class_id').reset_index(name='total_samples')
    )
    final_df = total_samples_per_class[['class_id']].sort_values(by='class_id').reset_index(drop=True)

    for k in range(1, top_k + 1):
        pred_col = f'pred@{k}_label_id'
        conf_col = f'pred@{k}_prob'
        freq_col = f'frequency@{k}'

        # Count how often each label was predicted at rank k within each class,
        # and the mean confidence of those predictions.
        grouped = (
            df.groupby(['class_id', pred_col])
            .agg(frequency=(pred_col, 'count'), average_confidence=(conf_col, 'mean'))
            .reset_index()
            .merge(total_samples_per_class, on='class_id')
        )
        grouped['normalized_frequency'] = grouped['frequency'] / grouped['total_samples']
        grouped = grouped.rename(
            columns={
                pred_col: f'pred_class_id@{k}',
                'average_confidence': f'confidence@{k}',
                'normalized_frequency': freq_col,
            }
        )[['class_id', f'pred_class_id@{k}', freq_col, f'confidence@{k}']]

        # Keep only the most frequent prediction per class.
        dominant = grouped.loc[grouped.groupby('class_id')[freq_col].idxmax()]

        # Join on class_id rather than concatenating by position. A class whose
        # rank-k predictions are all null is absent from `dominant`, and a
        # positional concat would shift every later class onto the wrong row.
        final_df = final_df.merge(dominant, on='class_id', how='left')

    final_df['is_correct'] = final_df['class_id'] == final_df['pred_class_id@1']
    return final_df

def display_statistics(per_class_acc_df, col, y, group_by, title, class_cnt=10, gap_size=5):
    """Plot the ``class_cnt`` worst and best rows by ``col`` as horizontal bars.

    The worst rows are drawn first. Next come ``gap_size`` empty spacer rows
    labelled ``"."``, ``".."``, ... on the y-axis, and then the best rows, so
    the two groups are visually separated.

    Args:
        per_class_acc_df: Frame with one row per class.
        col: Numeric column used for ranking and plotted on the x-axis.
        y: Column providing the y-axis category labels.
        group_by: Column used as the bar colour (``hue``); may be the same as ``y``.
        title: Figure title.
        class_cnt: Number of rows shown in each of the worst/best groups.
        gap_size: Number of spacer rows inserted between the two groups.
    """
    worst_df = per_class_acc_df.sort_values(by=[col], ascending=True).head(class_cnt)
    top_df = per_class_acc_df.sort_values(by=[col], ascending=False).head(class_cnt)

    # One spacer row per distinct dot label (empty x value, so no bar is drawn).
    # `y` is assigned last on purpose: when `y` and `group_by` name the same
    # column, the dot label must win over None. Otherwise the spacer labels are
    # missing and the gap disappears.
    spacer_df = pd.DataFrame({
        col: [float('nan')] * gap_size,
        group_by: [None] * gap_size,
        y: ['.' * i for i in range(1, gap_size + 1)],
    })

    final_df = pd.concat([worst_df, spacer_df, top_df], ignore_index=True)
    # Height scales with class_cnt; clamp so small counts still give a usable figure.
    plt.figure(figsize=(17, max(1, int(8 * class_cnt / 10))))
    sns.barplot(
        data=final_df,
        x=col,
        y=y,
        hue=group_by,
        orient="h",
        saturation=1,
        width=0.75,
        dodge=False,
    )
    plt.title(title)
    plt.tight_layout()
    plt.show()

# Config

In [None]:
# Preprocessing batch size, random seed, and a top-k depth. In the cells shown
# here, only TOP_K may be referenced elsewhere; the other two appear unused.
PREPROCESS_BATCH_SIZE, RANDOM_STATE, TOP_K = 512, 42, 5

### Setup

In [None]:
# Taxonomic rank columns; the first one ('phylum') is passed as the label column.
columns = ['phylum', 'class', 'order', 'family', 'genus', 'specific_epithet']
setup = Setup(
    # Data: iNaturalist with an annotation filter (presumably keeping only
    # kingdom Animalia), labelled by the first rank column.
    dataset_name=ImageDatasets.INATURALIST,
    annotations_key_value_criteria={'kingdom': ['Animalia']},
    label_column_name=columns[0],
    validation_size=0.15,
    n_shots=16,
    # Model: CLIP ViT-B/16 backbone with the CoOp trainer.
    backbone_name=CLIPBackbones.CLIP_VIT_B_16,
    trainer_name=Trainers.COOP,
    model_type=ModelType.FEW_SHOT,
    setup_type=Setups.FULL,
    # Optimisation settings.
    num_epochs=100,
    train_batch_size=1024,
    precision_dtype=PrecisionDtypes.FP16,
    loss_type=LossType.LABEL_SMOOTHING_LOSS,
)

In [None]:
# Load the classification metrics stored for this setup (presumably from disk)
# and show which entries are available; 'per_sample' is used further down.
metrics_dict = ClassificationMetricEvaluator.load(setup)
metrics_dict.keys()

# Load Metrics

In [None]:
# Build a pipeline for this setup and initialise just the parts this notebook
# needs: its labels and its metric evaluator.
pipeline = Pipeline(setup=setup)
pipeline.setup_labels()
# NOTE(review): calls a private Pipeline method; prefer a public entry point if
# one exists, since private methods may change without notice.
pipeline._initialize_metric_evaluator()
metric_evaluator = pipeline.metric_evaluator
label_handler = pipeline.label_handler

In [None]:
# Extra metrics to compute alongside the evaluator's defaults.
extra_metrics = [Metrics.BALANCED_ACCURACY, Metrics.COHEN_KAPPA, Metrics.M_CORR_COEFF]
metric_evaluator.register_metrics(extra_metrics)

In [None]:
# Dataset annotation config, plus label/prompt/class lookups from the label
# handler (fetched in the same order as the getters are listed).
annotations_config = AnnotationsConfig.get_config(dataset_name=setup.dataset_name)
(
    labels,
    prompts_df,
    class_ids,
    class_id_label_id_adj_matrix,
    label_id_prompt_id_mapping,
    classes_df,
) = (
    label_handler.get_labels(),
    label_handler.get_prompts_df(),
    label_handler.get_class_ids(),
    label_handler.get_class_id_label_id_adj_matrix(),
    label_handler.get_label_id_prompt_id_mapping(),
    label_handler.get_classes_df(),
)

In [None]:
# Per-sample predictions from the loaded metrics. A sample counts as correct
# when its true label was ranked first.
per_sample_acc_df = metrics_dict['per_sample']
per_sample_acc_df['is_correct'] = per_sample_acc_df['correct_pred_rank'] == 1

# Largest k for which per-class top-k accuracy is reported. This was previously
# `min(top_k, 1)`, which always evaluated to 1 and left the top-k branch dead,
# so the `top_{k}_accuracy` columns used by later cells were never produced.
# NOTE(review): `correct_pred_rank` presumably only holds ranks up to the
# evaluator's own top_k, so k is capped there as well.
max_k = min(metric_evaluator.top_k, TOP_K)

per_class_records = []
for class_id, group in per_sample_acc_df.groupby('actual_label_id'):
    class_accuracies = {'label_id': class_id}

    # Full metric set (at top-1) for this class, as computed by the evaluator.
    group_metrics = metric_evaluator.get_metrics(predictions_df=group, main_metric_only=False, top_k=1)
    class_accuracies.update(group_metrics.iloc[0].to_dict())
    if 'top_k' in class_accuracies:
        class_accuracies['top_k'] = int(class_accuracies['top_k'])

    # Top-k accuracy for each k: fraction of samples whose true-label rank is
    # <= k. -1 appears to be a sentinel for "not ranked" and is excluded.
    ranks = group['correct_pred_rank']
    for k in range(1, max_k + 1):
        class_accuracies[f'top_{k}_accuracy'] = ((ranks <= k) & (ranks != -1)).mean()

    class_accuracies['group_cnt'] = len(group)
    per_class_records.append(class_accuracies)

# Build the frame once instead of concatenating inside the loop (quadratic).
per_class_acc_df = pd.DataFrame(per_class_records)

overall_acc_df = metric_evaluator.get_metrics(predictions_df=per_sample_acc_df, main_metric_only=False)

In [None]:
# Overall (all-sample) metrics computed in the previous cell.
overall_acc_df

In [None]:
# Number of samples per true label id (group sizes from the per-class loop).
dict(zip(per_class_acc_df['label_id'], per_class_acc_df['group_cnt']))

# Visualize Statistics

In [None]:
# Visualise the overall metrics with the project's accuracy plotting helper.
plot_model_accuracy(overall_acc_df)

In [None]:
# Histogram of each sample's top-1 confidence, split by whether the top-1
# prediction was correct, on a log-scaled y-axis.
plt.figure(figsize=(15, 8))
g = sns.histplot(per_sample_acc_df, x='pred@1_prob', hue='is_correct', stat='probability')
g.set_yscale("log")
plt.xlabel('Confidence', fontsize=13)
# stat='probability' makes bar heights proportions of samples (normalised over
# all hue groups by default), not a density. The axis was previously mislabelled.
plt.ylabel('Probability', fontsize=13)
plt.title("Samples' Top-1 Prediction's Confidence Histogram", fontsize=14)
plt.show()

In [None]:
# Overlaid histograms of per-class top-1/3/5 accuracy. Each class gets weight
# 1/n_classes, so bar heights are fractions of classes. The weights depend only
# on the row count, not on any particular metric column.
n_classes = len(per_class_acc_df)
weights = np.ones(n_classes) / max(n_classes, 1)
plt.figure(figsize=(9, 6))
bins = 50

for k in (1, 3, 5):
    plt.hist(per_class_acc_df[f'top_{k}_accuracy'], bins=bins, label=f'Top-{k}', alpha=0.5, weights=weights)

plt.xlabel('Accuracy')
plt.ylabel('Fraction of classes')
plt.legend(loc='upper right')
plt.title("Top-5 & Top-3 & Top-1 Acc. Per Class")
plt.tight_layout()
plt.show()

## Class-wise Accuracies

In [None]:
# Column used to colour the bars (hue) in the class-wise plots below.
coarse_grained_col = 'label'

In [None]:
# Top-1: the 10 worst and 10 best classes by per-class top-1 accuracy.
# NOTE(review): the per-class frame is keyed by 'label_id'; confirm a 'label'
# column is present (e.g. returned by get_metrics) before relying on this plot.
display_statistics(
    per_class_acc_df,
    col='top_1_accuracy',
    y='label',
    group_by=coarse_grained_col,
    title='Top-1 Worst & Best Acc. Performance',
    class_cnt=10,
)

In [None]:
# Top-3: the 10 worst and 10 best classes by per-class top-3 accuracy.
display_statistics(
    per_class_acc_df,
    col='top_3_accuracy',
    y='label',
    group_by=coarse_grained_col,
    title='Top-3 Worst & Best Acc. Performance',
    class_cnt=10,
)

In [None]:
# Top-5: the 10 worst and 10 best classes by per-class top-5 accuracy.
display_statistics(
    per_class_acc_df,
    col='top_5_accuracy',
    y='label',
    group_by=coarse_grained_col,
    title='Top-5 Worst & Best Acc. Performance',
    class_cnt=10,
)