# Imports

In [None]:
import os

# Switch into the project's `vlm_toolbox/` directory so the project modules
# imported below (config, metric, ...) resolve. os.chdir with a relative path
# is not idempotent: re-running this cell from inside `vlm_toolbox/` would walk
# two more levels up and fail, so only change directory when not already there.
if os.path.basename(os.getcwd()) != 'vlm_toolbox':
    os.chdir('../../vlm_toolbox/')

In [None]:
# IPython autoreload: re-import edited project modules without restarting the kernel.
# `%reload_ext` right after `%load_ext` is redundant on a fresh kernel; on a re-run it
# reloads the already-loaded extension. Mode 2 reloads modules before each execution.
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

In [None]:
import gc
import warnings

import torch
from matplotlib import pyplot as plt

from config.enums import ImageDatasets, ModelType, Setups, Trainers
from config.setup import Setup
from metric.accuracy import AccuracyMetricEvaluator
from metric.visualization.accuracy import plot_model_accuracy

In [None]:
# Disable tokenizer parallelism (presumably read by the HuggingFace `tokenizers`
# library, used by the models under test) and hide every Python warning so the
# notebook output stays readable.
os.environ.update(TOKENIZERS_PARALLELISM="false")
warnings.simplefilter('ignore')

In [None]:
def flush():
    """Free memory between runs.

    Runs the garbage collector first, so objects kept alive only by reference
    cycles are released, then asks PyTorch to return cached CUDA blocks.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

# Config

In [None]:
# Taxonomic label columns to evaluate, ordered from the broadest rank to the
# narrowest; each one is used as `label_column_name` in the loop below.
cols = [
    'phylum',
    'class',
    'order',
    'family',
    'genus',
    'specific_epithet',
]

In [None]:
def _inaturalist_animalia_setup(label_column, trainer_name, **extra):
    """Build an iNaturalist Setup restricted to kingdom Animalia for one label column.

    Args:
        label_column: annotation column used as the classification label.
        trainer_name: a ``Trainers`` member.
        **extra: any further ``Setup`` keyword arguments (shots, setup type, ...).
    """
    return Setup(
        dataset_name=ImageDatasets.INATURALIST,
        trainer_name=trainer_name,
        label_column_name=label_column,
        # Built fresh on every call so no two Setups share one mutable dict.
        annotations_key_value_criteria={'kingdom': ['Animalia']},
        **extra,
    )


# For each taxonomic column, load the stored overall accuracy of three models
# and plot them side by side.
for label_column in cols:
    clip_setup = _inaturalist_animalia_setup(
        label_column, Trainers.CLIP, setup_type=Setups.EVAL_ONLY,
    )
    # CoOp evaluated with model_type=ZERO_SHOT: presumably the untrained
    # prompt baseline — TODO confirm against ModelType semantics.
    baseline_coop_setup = _inaturalist_animalia_setup(
        label_column,
        Trainers.COOP,
        n_shots=16,
        setup_type=Setups.EVAL_ONLY,
        model_type=ModelType.ZERO_SHOT,
        enable_novelty=True,
    )
    coop_setup = _inaturalist_animalia_setup(
        label_column, Trainers.COOP, n_shots=16, enable_novelty=True,
    )

    clip_metrics = AccuracyMetricEvaluator.load(clip_setup)['overall']
    baseline_coop_metrics = AccuracyMetricEvaluator.load(baseline_coop_setup)['overall']
    # Relabel so the two CoOp entries can be told apart; presumably
    # plot_model_accuracy names series by 'trainer_name' — verify.
    baseline_coop_metrics['trainer_name'] = 'baseline_coop'
    coop_metrics = AccuracyMetricEvaluator.load(coop_setup)['overall']

    col_name = clip_setup.get_label_column_name() or 'default'
    title = (
        f"Overall Performance on the '{col_name}' Column - "
        f"Dataset: {clip_setup.get_dataset_name()}"
    )
    plot_model_accuracy([clip_metrics, baseline_coop_metrics, coop_metrics], title=title)
    plt.show()