# Imports

In [None]:
import os

# Make the vlm_toolbox directory the working directory (presumably so the
# project-local imports in the cells below resolve).
# The target path is relative, so guard against re-running this cell: a
# second chdir would resolve '../../vlm_toolbox' from the new location and
# either fail or end up in the wrong directory.
if os.path.basename(os.getcwd()) != 'vlm_toolbox':
    os.chdir('../../vlm_toolbox')

In [None]:
# Enable IPython's autoreload extension in mode 2, which reloads modules
# before each cell runs so edits to project code are picked up without
# restarting the kernel.
# NOTE(review): %reload_ext right after %load_ext looks redundant — confirm
# whether both are needed.
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

In [None]:
import os
import pandas as pd
import numpy as np
import torch
import warnings
from matplotlib import pyplot as plt

from config.enums import (
    CLIPBackbones,
    ImageDatasets,
    LossType,
    ModelType,
    Stages,
    PrecisionDtypes,
    Setups,
    Metrics,
    Trainers,
    SamplingStrategy,
    SamplingType,
)
from config.logging import LoggerFactory
from config.setup import Setup
from metric.visualization.accuracy import plot_model_accuracy
from pipeline.pipeline import Pipeline

In [None]:
# Keep the notebook output quiet: ignore every Python warning and set
# TOKENIZERS_PARALLELISM to "false" (presumably read by the tokenizers
# library — verify against its documentation).
warnings.simplefilter('ignore')
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Config

In [None]:
# Logger that is passed to the Pipeline constructor further down.
# NOTE(review): notebook=False is passed although this runs inside a
# notebook — confirm that this is intended.
logger = LoggerFactory.create_logger("coop_finetuning_logger", notebook=False)

In [None]:
# Last expression of the cell, so the notebook displays whatever
# Setup.list_available_experiment_results() returns.
Setup.list_available_experiment_results()

In [None]:
# Experiment configuration consumed by the Pipeline created further down.
setup = Setup(
    # Trainer, model and data selection.
    trainer_name=Trainers.COOP,
    model_type=ModelType.FEW_SHOT,
    backbone_name=CLIPBackbones.CLIP_VIT_B_16,
    dataset_name=ImageDatasets.IMAGENET_1K,
    setup_type=Setups.FULL,
    label_column_name='coarse',
    validation_size=0.15,
    # Sampling.
    n_shots=1,
    sampling_type=SamplingType.UNDER_SAMPLING,
    sampling_strategy=SamplingStrategy.RANDOM_UNDER_SAMPLING,
    # Training schedule, batching and precision.
    # NOTE: a later cell calls pipeline.setup.set_num_epochs(5).
    num_epochs=200,
    train_batch_size=128,
    eval_batch_size=1024,
    precision_dtype=PrecisionDtypes.FP16,
    top_k=np.inf,
    # Disabled overrides, kept for quick experimentation:
    # train_split=Stages.EVAL,
    # load_from_checkpoint=True,
    # loss_type=LossType.ENLARGED_LARGE_MARGIN_LOSS,
    # coarse_column_name='coarse',
    # enable_novelty=True,
    # top_k=67,
)
# Last expression of the cell, so the notebook displays the relative save path.
setup.get_relative_save_path()

### Device

In [None]:
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Plain string form of the device ('cuda' or 'cpu'); passed to Pipeline below.
DEVICE_TYPE = DEVICE.type
# Last expression of the cell, so the notebook displays the chosen device.
DEVICE

# Training and Evaluation

In [None]:
# Create the Pipeline from the experiment setup, the device type string
# ('cuda' or 'cpu') and the logger defined above.
pipeline = Pipeline(setup, device_type=DEVICE_TYPE, logger=logger)

In [None]:
# Data preparation step of the pipeline (details live in Pipeline.setup_data).
pipeline.setup_data()

In [None]:
print(str(pipeline.model))

In [None]:
pipeline.setup_model()

In [None]:
# Set the number of epochs to 5 for this run; the Setup above was created
# with num_epochs=200.
pipeline.setup.set_num_epochs(5)

In [None]:
# Run training.
# NOTE(review): the meaning of collate_all_m2_samples is defined in
# Pipeline.train — confirm before changing it.
pipeline.train(collate_all_m2_samples=True)

In [None]:
# Run the pipeline's evaluation step; as the cell's last expression, any
# return value is displayed by the notebook.
pipeline.evaluate()

In [None]:
# Save the run with predictions included, then print the value returned by
# save().
saved_dirs_dict = pipeline.save(save_predictions=True)
print(saved_dirs_dict)

In [None]:
# Fetch the pipeline's metrics and pass them to plot_model_accuracy with
# top_k=5.
main_metric_df = pipeline.get_metrics()
plot_model_accuracy(main_metric_df, top_k=5)
# Render the current matplotlib figure(s).
plt.show()