# Testing different readout methods in the few-shot context

In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.core.pylabtools import figsize

import seaborn as sns

import numpy as np
import pandas as pd

from tqdm import tqdm
import torch

import os
# Configure Hugging Face `datasets` progress-bar behaviour before `datasets` is imported (next cell).
# NOTE(review): "0" reads as a "false" value for a DISABLE flag, so this presumably leaves progress bars
# enabled; `datasets` also appears to read HF_DATASETS_DISABLE_PROGRESS_BARS (plural). Confirm the
# intended variable name/value if the goal was to silence progress bars.
os.environ["HF_DATASETS_DISABLE_PROGRESS_BAR"] = "0"

In [2]:
from datasets import load_dataset

def load_cdfsl_dataset(name):
    """Load the ``train`` split of one CD-FSL benchmark dataset from the Hugging Face Hub.

    The Hub repositories/configs below were chosen as the most stable sources,
    avoiding legacy loading scripts and config errors.

    Args:
        name: one of ``"EuroSAT"``, ``"ISIC"``, ``"PlantVillage"``, ``"ChestX"``.

    Returns:
        Whatever ``load_dataset(..., split="train")`` returns for that source.

    Raises:
        ValueError: if ``name`` is not one of the supported benchmarks.
    """
    # (benchmark name, positional arguments for load_dataset); every source uses split="train".
    hub_sources = (
        ("EuroSAT", ("timm/eurosat-rgb",)),
        ("ISIC", ("marmal88/skin_cancer",)),
        ("PlantVillage", ("mohanty/PlantVillage", "default")),
        ("ChestX", ("g-ronimo/NIH-Chest-X-ray-dataset_10k",)),
    )
    for known_name, hub_args in hub_sources:
        if name == known_name:
            return load_dataset(*hub_args, split="train")
    raise ValueError(f"Unknown dataset: {name}")


def n_way_k_shot_sample(ds, k, seed=None):
    """Draw a class-balanced few-shot subsample: up to ``k`` random rows for every label.

    "n" is the number of distinct labels found in ``ds['label']``. A label with fewer
    than ``k`` rows contributes all of its rows.

    Args:
        ds: dataset supporting ``ds['label']`` (the whole label column) and
            ``ds.select(indices)`` — e.g. a Hugging Face ``datasets.Dataset``.
        k: number of shots (rows) to keep per label.
        seed: optional integer seed. When given, the same subsample is drawn on every
            call. When ``None``, torch's global RNG is used.

    Returns:
        ``ds.select(indices)`` for the chosen row indices, in random (permutation) order.
    """
    # Read the label column once. Indexing whole rows (ds[i]) also loads every other column
    # of the row (presumably decoding images here), and on a torch-formatted dataset yields
    # tensor labels that would not match the int keys of counts_remaining.
    labels = [int(label) for label in ds['label']]

    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    order = torch.randperm(len(labels), generator=generator).tolist()

    counts_remaining = dict.fromkeys(labels, k)
    # Total picks still needed: allows stopping early without re-summing the dict each step.
    still_needed = k * len(counts_remaining)
    inds = []
    for index in order:
        if still_needed <= 0:
            break
        label = labels[index]
        if counts_remaining[label] > 0:
            counts_remaining[label] -= 1
            still_needed -= 1
            inds.append(index)  # plain int, as expected by ds.select

    return ds.select(inds)



In [3]:
from transformers import Trainer, TrainingArguments
import evaluate
import numpy as np

def accuracy(model, ds):
    """Evaluate ``model`` on ``ds`` with a Hugging Face ``Trainer`` and return its metrics dict.

    Args:
        model: a ``transformers`` model that ``transformers.Trainer`` can evaluate;
            presumably an image classifier taking ``pixel_values`` — TODO confirm.
        ds: a ``datasets.Dataset`` with a label column. An ``input`` column, if present,
            is renamed to ``pixel_values`` before evaluation.

    Returns:
        The dict returned by ``Trainer.evaluate()``. It includes the accuracy from
        ``compute_metrics``, reported with the Trainer's eval prefix (``eval_accuracy``).
    """
    # Import under aliases: a later cell runs `from pytorch_lightning import Trainer`, which
    # rebinds the notebook-global `Trainer`; relying on the global breaks this function after that.
    from transformers import Trainer as HFTrainer
    from transformers import TrainingArguments as HFTrainingArguments

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        # If predictions arrive as a tuple (extra model outputs), take the first element,
        # assumed to be the logits.
        if isinstance(logits, tuple):
            logits = logits[0]
        predictions = np.argmax(logits, axis=-1)
        # Plain fraction-correct; avoids loading the `evaluate` metric on every call.
        return {"accuracy": float(np.mean(predictions == labels))}

    eval_args = HFTrainingArguments(
        output_dir="./results",
        per_device_eval_batch_size=64,
        do_train=False,
        do_eval=True,
        report_to="none",  # no experiment-tracker logging for this one-off evaluation
    )

    # The model presumably expects `pixel_values`; rename only when the dataset still uses `input`.
    eval_ds = ds.rename_columns({"input": "pixel_values"}) if "input" in ds.column_names else ds

    trainer = HFTrainer(
        model=model,
        args=eval_args,
        eval_dataset=eval_ds,
        compute_metrics=compute_metrics,
    )
    return trainer.evaluate()

W0220 10:04:06.098000 55369 torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.


import error: No module named 'triton'


## How do randomized readouts compare to baseline readouts for FSL?

### Baseline

In [4]:
from src.model.setup import image_model_setup
from src.model.CLS_token_probing import ModuleSpecificDecoder, SimpleReadOutAttachment
model_name = "facebook/dinov2-base"
dataset_name = "temp_dataset_subsample"  # NOTE(review): unused in the visible cells; image_model_setup below receives '' instead

ds_raw = load_cdfsl_dataset("EuroSAT")
# image_model_setup is project code (src.model.setup). The 10 presumably is the number of output classes:
# the classifier head shown two cells below has out_features=10 (EuroSAT has 10 classes) — TODO confirm.
model, ds, _ = image_model_setup(model_name, '', 10, full_dataset=ds_raw)

Some weights of Dinov2ForImageClassification were not initialized from the model checkpoint at facebook/dinov2-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Inspect the classification head trained by the baseline readout.
# NOTE(review): in_features=1536 (see output) is 2 x 768, presumably the CLS token concatenated with
# the mean patch token in Dinov2ForImageClassification — confirm.
model.get_classifier_module()

Linear(in_features=1536, out_features=10, bias=True)

In [5]:
# Few-shot support set: 5 rows per label (5-shot over every class in ds).
# NOTE(review): no seed is passed, so the support set (and every number downstream) changes on each run.
ds_subset = n_way_k_shot_sample(ds, 5)
ds_subset.set_format('pt')

# Sampling happens before `ds` is switched to torch format; n_way_k_shot_sample keys its per-label
# counts by label value, so the order presumably matters — TODO confirm before reordering.
ds.set_format('pt')

In [6]:
# Baseline readout: freeze the backbone but keep the classifier head trainable
# (the training summary below reports 15.4 K trainable params = 1536*10 + 10 for the linear head).
model.freeze_backbone(freeze_classifier=False)

In [7]:
from src.model.harness import TrainConfig

In [8]:
# Readout training budget. steps_per_epoch=1 presumably means one optimizer step per epoch; the
# 50-example support set is a single batch. The recorded run stopped early on val/accuracy.
model.train_cfg = TrainConfig(epochs=100, steps_per_epoch=1, weight_decay=1e-2) # doesn't save hyper-parameters, TODO

In [9]:
from src.train.training import train_readout

In [10]:
# The support set (batch size inferred as 50 in the output) fits in one batch.
dl_train = torch.utils.data.DataLoader(ds_subset, batch_size=50)
# NOTE(review): validation uses the same support set, so val/accuracy (and early stopping on it)
# measures training fit, not generalization.
dl_val = torch.utils.data.DataLoader(ds_subset, batch_size=50)  # same support set as dl_train
train_readout(model, dl_train, dl_val)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mlrast[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



  | Name                | Type                         | Params | Mode 
-----------------------------------------------------------------------------
0 | model               | Dinov2ForImageClassification | 86.6 M | eval 
1 | classification_loss | CrossEntropyLoss             | 0      | train
-----------------------------------------------------------------------------
15.4 K    Trainable params
86.6 M    Non-trainable params
86.6 M    Total params
346.383   Total estimated model params size (MB)
1         Modules in train mode
226       Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 50. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_li

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved. New best score: 0.220


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.060 >= min_delta = 0.0. New best score: 0.280


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.260 >= min_delta = 0.0. New best score: 0.540


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.160 >= min_delta = 0.0. New best score: 0.700


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.200 >= min_delta = 0.0. New best score: 0.900


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.920


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.040 >= min_delta = 0.0. New best score: 0.960


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.980


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 1.000


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Monitored metric val/accuracy did not improve in the last 5 records. Best score: 1.000. Signaling Trainer to stop.


[34m[1mwandb[0m: [32m[41mERROR[0m Control-C detected -- Run data was not synced


KeyboardInterrupt: 

In [12]:
from src.model.harness import ModelWrapper

In [13]:
# Reload a saved checkpoint (presumably the best one from the interrupted baseline run above).
# NOTE(review): the run id 'nevajzf3' is hard-coded and not produced by any visible cell, so this
# cell fails on a fresh Restart & Run All.
model = ModelWrapper.load_from_checkpoint('middle_decoders/nevajzf3/checkpoints/best.ckpt')

Some weights of Dinov2ForImageClassification were not initialized from the model checkpoint at facebook/dinov2-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
from pytorch_lightning import Trainer

In [15]:
# Evaluation-only trainer. NOTE(review): `Trainer` here is pytorch_lightning's (imported in the previous
# cell); that import shadows the transformers `Trainer` imported earlier for accuracy().
trainer = Trainer()

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [16]:
# Sanity check on the support set itself (also used for training/validation), so 1.0 here is training accuracy.
trainer.test(model, dl_val)

/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing: |                                                | 0/? [00:00<?, ?it/s]

[{'test/accuracy': 1.0}]

In [17]:
# Evaluate on the full dataset `ds` (presumably derived from the EuroSAT train split loaded above).
# NOTE(review): ds_subset was sampled from ds, so the support examples are included in this "test" set;
# excluding them would give a clean held-out estimate.
dl_test = torch.utils.data.DataLoader(ds, batch_size=64)

In [18]:
# Full-dataset accuracy of the 5-shot linear readout (about 0.72 in the recorded output).
trainer.test(model, dl_test)

/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing: |                                                | 0/? [00:00<?, ?it/s]

/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 64. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 8. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'test/accuracy': 0.7203086614608765}]

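These settings work well. The linear readout reaches 100% accuracy on the 5-shot support set and about 72% on the full dataset. The full dataset still contains the 50 support examples, so this is not a strictly held-out estimate.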

In [20]:
from src.model.setup import image_model_setup
from src.model.CLS_token_probing import ModuleSpecificDecoder, SimpleReadOutAttachment
# Same backbone/dataset setup as the baseline section (duplicated; a shared helper would keep them in sync).
model_name = "facebook/dinov2-base"
dataset_name = "temp_dataset_subsample"  # NOTE(review): unused in the visible cells

ds_raw = load_cdfsl_dataset("EuroSAT")
model, ds, _ = image_model_setup(model_name, '', 10, full_dataset=ds_raw)

# Attach a readout module. 11 presumably selects the backbone layer it reads from (the later
# load_from_checkpoint call passes layer_ind=11) — TODO confirm against SimpleReadOutAttachment.
readout = SimpleReadOutAttachment(11)
model.add_readout(readout)
# Unlike the baseline, the classifier head is frozen too, so only the readout trains
# (9.2 M params in the training summary below).
model.freeze_backbone(freeze_classifier=True)

Some weights of Dinov2ForImageClassification were not initialized from the model checkpoint at facebook/dinov2-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [23]:
# Reuses `ds_subset` sampled in the baseline section (not re-drawn after `ds` was reloaded above),
# so both readouts are trained on the same support set.
dl_train = torch.utils.data.DataLoader(ds_subset, batch_size=50)
# NOTE(review): validation again uses the support set, so val/accuracy measures training fit.
dl_val = torch.utils.data.DataLoader(ds_subset, batch_size=50)  # same support set as dl_train

model.train_cfg = TrainConfig(epochs=100, steps_per_epoch=1, weight_decay=1e-2)  # same settings as the baseline
# NOTE(review): the output shows the wandb run from the baseline being reused (wandb.finish() between runs
# would separate them). The recorded run ended with `TypeError: asdict() should be called on dataclass
# instances` after early stopping — cause not visible here; TODO investigate.
train_readout(model, dl_train, dl_val)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory ./middle_decoders/3b80yeye/checkpoints exists and is not empty.

  | Name                | Type                         | Params | Mode 
-----------------------------------------------------------------------------
0 | model               | Dinov2ForImageClassification | 86.6 M | eval 
1 | classification_loss | CrossEntropyLoss             | 0      | train
2 | readout             | SimpleReadOutAttachment      | 9.2 M  | train
---------------------

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved. New best score: 0.020


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.040


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.080 >= min_delta = 0.0. New best score: 0.120


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.100 >= min_delta = 0.0. New best score: 0.220


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.240


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.040 >= min_delta = 0.0. New best score: 0.280


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.080 >= min_delta = 0.0. New best score: 0.360


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.160 >= min_delta = 0.0. New best score: 0.520


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.120 >= min_delta = 0.0. New best score: 0.640


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.660


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.180 >= min_delta = 0.0. New best score: 0.840


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.860


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.040 >= min_delta = 0.0. New best score: 0.900


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.080 >= min_delta = 0.0. New best score: 0.980


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 1.000


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Monitored metric val/accuracy did not improve in the last 5 records. Best score: 1.000. Signaling Trainer to stop.


TypeError: asdict() should be called on dataclass instances

In [28]:
# Reload the trained readout on its own. NOTE(review): in the recorded run this raises the same
# asdict() TypeError as the training cell; the run id and 'last-v1.ckpt' name are hard-coded.
readout = SimpleReadOutAttachment.load_from_checkpoint('middle_decoders/3b80yeye/checkpoints/last-v1.ckpt', layer_ind=11)

TypeError: asdict() should be called on dataclass instances