# Testing different readout methods in the few-shot context

In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.core.pylabtools import figsize

import seaborn as sns

import numpy as np
import pandas as pd

from tqdm import tqdm
import torch

import os
# Set before `datasets` is imported in the next cell so the library can read it at import time.
# NOTE(review): the `datasets` library appears to document this variable as
# HF_DATASETS_DISABLE_PROGRESS_BARS (plural) — confirm the singular spelling below is
# actually honoured; "0" would in any case mean "do not disable".
os.environ["HF_DATASETS_DISABLE_PROGRESS_BAR"] = "0"

In [2]:
from datasets import load_dataset

def load_cdfsl_dataset(name):
    """
    Return the "train" split of one of the supported CD-FSL datasets.

    Each supported name is mapped to the positional arguments (hub repository
    id and, where one is needed, a config name) that are forwarded to
    ``load_dataset``.

    Args:
        name: One of "EuroSAT", "ISIC", "PlantVillage" or "ChestX".

    Returns:
        Whatever ``load_dataset(..., split="train")`` returns for that source.

    Raises:
        ValueError: If ``name`` is not one of the supported dataset names.
    """
    # (dataset name, positional arguments handed to load_dataset)
    hub_sources = (
        ("EuroSAT", ("timm/eurosat-rgb",)),
        ("ISIC", ("marmal88/skin_cancer",)),
        ("PlantVillage", ("mohanty/PlantVillage", "default")),
        ("ChestX", ("g-ronimo/NIH-Chest-X-ray-dataset_10k",)),
    )

    for known_name, hub_args in hub_sources:
        if name == known_name:
            return load_dataset(*hub_args, split="train")

    raise ValueError(f"Unknown dataset: {name}")


def n_way_k_shot_sample(ds, k, seed=None):
    """Draw a random k-shot subsample that covers every class in the dataset.

    "n" is implicit: it is the number of distinct values in the ``'label'``
    column. The dataset is visited in a random order and an example is kept
    while its class still has fewer than ``k`` examples selected.

    Args:
        ds: Dataset-like object supporting ``ds['label']`` (the full label
            column), ``len(ds)`` and ``ds.select(indices)``. Labels must be
            convertible with ``int()``.
        k: Maximum number of examples to keep per class. A class with fewer
            than ``k`` examples contributes all of them; ``k <= 0`` selects
            nothing.
        seed: Optional integer seed. When given, the permutation is drawn
            from a dedicated ``torch.Generator`` so the sample is reproducible
            and the global torch RNG is left untouched. When ``None`` the
            global torch RNG is used.

    Returns:
        ``ds.select(indices)`` where ``indices`` is a list of plain Python
        ints in the order they were drawn.
    """
    # Read the label column once; indexing ds[i] per step would materialise
    # the whole row (e.g. decode the image) just to look at its label.
    labels = [int(label) for label in ds['label']]

    counts_remaining = {label: k for label in set(labels)}
    # Running total of open slots, so the stop test is O(1) per step.
    slots_remaining = k * len(counts_remaining)

    generator = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)
    visit_order = torch.randperm(len(labels), generator=generator).tolist()

    inds = []
    for index in visit_order:
        if slots_remaining <= 0:
            break

        label = labels[index]
        if counts_remaining[label] > 0:
            counts_remaining[label] -= 1
            slots_remaining -= 1
            inds.append(index)

    return ds.select(inds)



In [3]:
from src.model.setup import image_model_setup

import error: No module named 'triton'


W0217 14:45:24.041000 9762 torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.


In [4]:
# Checkpoint identifier handed to the project's model setup helper, and the raw
# EuroSAT "train" split that the helper receives as `full_dataset`.
model_name = "facebook/dinov2-base"
ds_raw = load_cdfsl_dataset( "EuroSAT")


# NOTE(review): the meaning of the positional '' and 10 arguments is defined in
# src.model.setup.image_model_setup, which is not visible here — presumably 10 is
# the number of classes; confirm. The third return value is discarded.
model, ds, _ = image_model_setup(model_name, '', 10, full_dataset=ds_raw)

Some weights of Dinov2ForImageClassification were not initialized from the model checkpoint at facebook/dinov2-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Draw a 5-shot subsample (up to 5 examples per label) from `ds` and switch it
# to the 'pt' (PyTorch tensor) output format.
ds_subset = n_way_k_shot_sample(ds, 5)
ds_subset.set_format('pt')

# The full dataset is switched to the same format; the result is not reassigned,
# so later cells rely on `ds` itself having been reconfigured by this call.
ds.set_format('pt')

In [6]:
from transformers import Trainer, TrainingArguments
import evaluate
import numpy as np

def accuracy(model, ds):
    """Run an evaluation-only ``Trainer`` pass of ``model`` over ``ds``.

    Args:
        model: Model object passed straight to ``Trainer(model=...)``.
        ds: Dataset exposing ``rename_columns``; its 'input' column is renamed
            to 'pixel_values' before evaluation.

    Returns:
        The result of ``Trainer.evaluate()``, which includes whatever the
        "accuracy" metric computes from the arg-max predictions.
    """
    accuracy_metric = evaluate.load("accuracy")

    def _accuracy_from_logits(eval_pred):
        """Score arg-max class predictions against the reference labels."""
        logits, labels = eval_pred
        return accuracy_metric.compute(
            predictions=np.argmax(logits, axis=-1),
            references=labels,
        )

    # Evaluation-only configuration; report_to="none" turns off logging integrations.
    eval_only_args = TrainingArguments(
        output_dir="./results",
        per_device_eval_batch_size=64,
        do_train=False,
        do_eval=True,
        report_to="none",
    )

    # Expose the inputs under the 'pixel_values' column name for evaluation.
    eval_ready_ds = ds.rename_columns({'input': 'pixel_values'})

    evaluator = Trainer(
        model=model,
        args=eval_only_args,
        eval_dataset=eval_ready_ds,
        compute_metrics=_accuracy_from_logits,
    )
    return evaluator.evaluate()

In [7]:
# Evaluate the wrapped model (`model.model`) on the full dataset `ds`, not on `ds_subset`.
# NOTE(review): the loading log above reports the classifier head as newly initialized,
# so this presumably measures an untrained baseline — confirm image_model_setup does not train it.
results = accuracy(model.model, ds)



In [9]:
# Show the metrics returned by the evaluation run above.
print(results)