# Testing different readout methods in the few-shot context

In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.core.pylabtools import figsize

import seaborn as sns

import numpy as np
import pandas as pd

from tqdm import tqdm
import torch

import os
os.environ["HF_DATASETS_DISABLE_PROGRESS_BAR"] = "0"

In [2]:
from datasets import load_dataset

def load_cdfsl_dataset(name):
    """
    Loads CD-FSL datasets using the most stable configurations 
    to avoid legacy script and config errors.
    """
    match name:
        case "EuroSAT":
            return load_dataset("timm/eurosat-rgb", split="train")
        
        case "ISIC":
            return load_dataset("marmal88/skin_cancer", split="train")
        
        case "PlantVillage":
            return load_dataset("mohanty/PlantVillage", "default", split="train")
        
        case "ChestX":
            return load_dataset("g-ronimo/NIH-Chest-X-ray-dataset_10k",  split="train")
        case _:
            raise ValueError(f"Unknown dataset: {name}")


def n_way_k_shot_sample(ds, k, seed=None):
    """n-way k-shot subsample from the dataset"""
    if seed is not None:
        raise NotImplementedError('No seeding')
    labels = torch.tensor(ds['label']).unique()

    counts_remaining = {label.item(): k for label in labels}
    
    perms = torch.randperm( len(ds))
    inds = []

    for index in perms:
        if sum(counts_remaining.values()) <= 0:
            break

        label = ds[index.item()]['label']
        if counts_remaining[label] > 0:
            counts_remaining[label] -= 1
            inds.append(index)

    return ds.select(inds)



In [9]:
from transformers import Trainer, TrainingArguments
import evaluate
import numpy as np

def accuracy(model, ds):
    metric = evaluate.load("accuracy")
    
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return metric.compute(predictions=predictions, references=labels)
    
    eval_args = TrainingArguments(
        output_dir="./results",
        per_device_eval_batch_size=64,
        do_train=False,
        do_eval=True,
        report_to="none", # Keeps it quiet
    )
    
    trainer = Trainer(
        model=model,
        args=eval_args,
        eval_dataset=ds.rename_columns({'input': 'x'}),
        compute_metrics=compute_metrics,
        
    )
    
    results = trainer.evaluate()
    return results

## How do randomized readouts compare to baseline readouts for FSL?

Initial run, single sample: they are not as good. (Preliminary data, but...) The randomized readouts are not inheritly well-suited to FSL. It's possible that they need different hyper-parameters, but those are hard to find without over-fitting on the test set.

Could be worth running additional test, but I' not trying to over-fit.

#### Setup datasets

In [4]:
from src.model.setup import load_processed_dataset

model_name = "facebook/dinov2-base"

ds_raw = load_cdfsl_dataset("EuroSAT")
ds = load_processed_dataset(model_name, '', full_dataset=ds_raw)


In [5]:
# subsample
ds_subset = n_way_k_shot_sample(ds, 5)
ds_subset.set_format('pt')
ds.set_format('pt')

### Baseline

train: 100%, test: 72%

In [6]:
from src.model.harness import ModelWrapper, TrainConfig

train_cfg = TrainConfig(epochs=100, steps_per_epoch=1, weight_decay=1e-2)
model = ModelWrapper(model_name, num_classes=10, train_cfg=train_cfg)

Some weights of Dinov2ForImageClassification were not initialized from the model checkpoint at facebook/dinov2-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
from src.train.training import train_readout

dl_train = torch.utils.data.DataLoader(ds_subset, batch_size=50)
dl_val = torch.utils.data.DataLoader(ds_subset, batch_size=50)  # re-using the train dataloader)
train_readout(model, dl_train, dl_val, run_name='fs_base')

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mlrast[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



  | Name                | Type                         | Params | Mode 
-----------------------------------------------------------------------------
0 | model               | Dinov2ForImageClassification | 86.6 M | eval 
1 | classification_loss | CrossEntropyLoss             | 0      | train
-----------------------------------------------------------------------------
15.4 K    Trainable params
86.6 M    Non-trainable params
86.6 M    Total params
346.383   Total estimated model params size (MB)
1         Modules in train mode
226       Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 50. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_li

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved. New best score: 0.040


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.040 >= min_delta = 0.0. New best score: 0.080


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.060 >= min_delta = 0.0. New best score: 0.140


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.120 >= min_delta = 0.0. New best score: 0.260


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.260 >= min_delta = 0.0. New best score: 0.520


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.180 >= min_delta = 0.0. New best score: 0.700


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.120 >= min_delta = 0.0. New best score: 0.820


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.840


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.080 >= min_delta = 0.0. New best score: 0.920


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.040 >= min_delta = 0.0. New best score: 0.960


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.980


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 1.000


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Monitored metric val/accuracy did not improve in the last 5 records. Best score: 1.000. Signaling Trainer to stop.


0,1
epoch,▁▁▂▂▃▃▄▄▅▅▅▆▆▇▇██
trainer/global_step,▁▁▂▂▃▃▄▄▅▅▅▆▆▇▇██
val/accuracy,▁▁▂▃▄▆▇▇▇████████
val/loss,██▇▆▅▄▃▂▂▂▂▁▁▁▁▁▁

0,1
epoch,16.0
trainer/global_step,16.0
val/accuracy,1.0
val/loss,0.07756


In [8]:
# load classifiers
from safetensors import safe_open
raw = safe_open('fs_base/readout/model.safetensors', framework='pt')
model.model.load_state_dict({key: raw.get_tensor(key) for key in raw.keys()}, strict=False )

_IncompatibleKeys(missing_keys=['dinov2.embeddings.cls_token', 'dinov2.embeddings.mask_token', 'dinov2.embeddings.position_embeddings', 'dinov2.embeddings.patch_embeddings.projection.weight', 'dinov2.embeddings.patch_embeddings.projection.bias', 'dinov2.encoder.layer.0.norm1.weight', 'dinov2.encoder.layer.0.norm1.bias', 'dinov2.encoder.layer.0.attention.attention.query.weight', 'dinov2.encoder.layer.0.attention.attention.query.bias', 'dinov2.encoder.layer.0.attention.attention.key.weight', 'dinov2.encoder.layer.0.attention.attention.key.bias', 'dinov2.encoder.layer.0.attention.attention.value.weight', 'dinov2.encoder.layer.0.attention.attention.value.bias', 'dinov2.encoder.layer.0.attention.output.dense.weight', 'dinov2.encoder.layer.0.attention.output.dense.bias', 'dinov2.encoder.layer.0.layer_scale1.lambda1', 'dinov2.encoder.layer.0.norm2.weight', 'dinov2.encoder.layer.0.norm2.bias', 'dinov2.encoder.layer.0.mlp.fc1.weight', 'dinov2.encoder.layer.0.mlp.fc1.bias', 'dinov2.encoder.layer

In [9]:
accuracy(model, ds)



{'eval_loss': 0.77437824010849,
 'eval_model_preparation_time': 0.0017,
 'eval_accuracy': 0.7508641975308642,
 'eval_runtime': 164.3294,
 'eval_samples_per_second': 98.582,
 'eval_steps_per_second': 1.546}

### Randomized readouts

train: 100%, test: 69%

In [10]:
from src.model.CLS_token_probing import SimpleReadOutAttachment

train_cfg = TrainConfig(epochs=100, steps_per_epoch=1, weight_decay=1e-2)
readout = SimpleReadOutAttachment(11)
model = ModelWrapper(model_name, num_classes=10, readout_module=readout, train_cfg=train_cfg)

/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'readout_module' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['readout_module'])`.
Some weights of Dinov2ForImageClassification were not initialized from the model checkpoint at facebook/dinov2-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
from src.train.training import train_readout

dl_train = torch.utils.data.DataLoader(ds_subset, batch_size=50)
dl_val = torch.utils.data.DataLoader(ds_subset, batch_size=50)  # re-using the train dataloader)
train_readout(model, dl_train, dl_val, run_name='fs_random')

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



  | Name                | Type                         | Params | Mode 
-----------------------------------------------------------------------------
0 | model               | Dinov2ForImageClassification | 86.6 M | eval 
1 | classification_loss | CrossEntropyLoss             | 0      | train
2 | readout             | SimpleReadOutAttachment      | 9.2 M  | train
-----------------------------------------------------------------------------
2.1 M     Trainable params
86.6 M    Non-trainable params
88.7 M    Total params
354.711   Total estimated model params size (MB)
15        Modules in train mode
226       Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved. New best score: 0.000


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.100 >= min_delta = 0.0. New best score: 0.100


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.080 >= min_delta = 0.0. New best score: 0.180


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.060 >= min_delta = 0.0. New best score: 0.240


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.040 >= min_delta = 0.0. New best score: 0.280


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.040 >= min_delta = 0.0. New best score: 0.320


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.060 >= min_delta = 0.0. New best score: 0.380


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.080 >= min_delta = 0.0. New best score: 0.460


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.120 >= min_delta = 0.0. New best score: 0.580


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.080 >= min_delta = 0.0. New best score: 0.660


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.100 >= min_delta = 0.0. New best score: 0.760


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.080 >= min_delta = 0.0. New best score: 0.840


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.060 >= min_delta = 0.0. New best score: 0.900


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.920


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.060 >= min_delta = 0.0. New best score: 0.980


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 1.000


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Monitored metric val/accuracy did not improve in the last 5 records. Best score: 1.000. Signaling Trainer to stop.


0,1
epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
trainer/global_step,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇███
val/accuracy,▁▁▂▂▂▂▃▃▃▄▄▄▅▆▆▆▆▇▇▇▇▇▇▇███████████
val/loss,██▆▅▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,34.0
trainer/global_step,34.0
val/accuracy,1.0
val/loss,0.18398


In [13]:
# load the readout
model.add_readout(SimpleReadOutAttachment.from_pretrained('fs_random/readout/'))

Loading weights from local directory


In [14]:
accuracy(model, ds)



{'eval_loss': 0.9830414652824402,
 'eval_model_preparation_time': 0.0014,
 'eval_accuracy': 0.6937037037037037,
 'eval_runtime': 174.4564,
 'eval_samples_per_second': 92.86,
 'eval_steps_per_second': 1.456}

Initial results suggest that learning random readout generalizes less well than learned probes.

### Training random only

High patience.

Train: 72%, test: 40%

In [14]:
from src.model.harness import ModelWrapper, TrainConfig

from src.model.CLS_token_probing import SimpleReadOutAttachment

train_cfg = TrainConfig(epochs=100, steps_per_epoch=1, weight_decay=1e-2)
readout = SimpleReadOutAttachment(11, train_probe=False)
model = ModelWrapper(model_name, num_classes=10, readout_module=readout, train_cfg=train_cfg)

/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/utilities/parsing.py:208: Attribute 'readout_module' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['readout_module'])`.
Some weights of Dinov2ForImageClassification were not initialized from the model checkpoint at facebook/dinov2-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
from src.train.training import train_readout

dl_train = torch.utils.data.DataLoader(ds_subset, batch_size=50)
dl_val = torch.utils.data.DataLoader(ds_subset, batch_size=50)  # re-using the train dataloader)
train_readout(model, dl_train, dl_val, run_name='fs_random', patience=20)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



  | Name                | Type                         | Params | Mode 
-----------------------------------------------------------------------------
0 | model               | Dinov2ForImageClassification | 86.6 M | eval 
1 | classification_loss | CrossEntropyLoss             | 0      | train
2 | readout             | SimpleReadOutAttachment      | 9.2 M  | train
-----------------------------------------------------------------------------
1.3 M     Trainable params
87.3 M    Non-trainable params
88.7 M    Total params
354.711   Total estimated model params size (MB)
15        Modules in train mode
226       Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/Users/luke/.local/defaultPythonEnv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved. New best score: 0.000


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.120 >= min_delta = 0.0. New best score: 0.120


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.040 >= min_delta = 0.0. New best score: 0.160


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.180


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.200


Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.080 >= min_delta = 0.0. New best score: 0.280


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.100 >= min_delta = 0.0. New best score: 0.380


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.060 >= min_delta = 0.0. New best score: 0.440


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.040 >= min_delta = 0.0. New best score: 0.480


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.500


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.040 >= min_delta = 0.0. New best score: 0.540


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.040 >= min_delta = 0.0. New best score: 0.580


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.600


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.620


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.020 >= min_delta = 0.0. New best score: 0.640


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Metric val/accuracy improved by 0.080 >= min_delta = 0.0. New best score: 0.720


Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Monitored metric val/accuracy did not improve in the last 20 records. Best score: 0.720. Signaling Trainer to stop.


0,1
epoch,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇█
lr-Adam,▁
train/loss,▁
trainer/global_step,▁▁▁▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
val/accuracy,▁▁▂▃▃▄▄▄▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██▇████▇███
val/loss,█▇▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,78.0
lr-Adam,0.00054
train/loss,1.51212
trainer/global_step,78.0
val/accuracy,0.66
val/loss,1.43353


In [16]:
# load the readout
model.add_readout(SimpleReadOutAttachment.from_pretrained('fs_random/readout/', train_probe=False))

Loading weights from local directory


In [17]:
accuracy(model, ds)



{'eval_loss': 1.8679229021072388,
 'eval_model_preparation_time': 0.0014,
 'eval_accuracy': 0.41444444444444445,
 'eval_runtime': 173.545,
 'eval_samples_per_second': 93.348,
 'eval_steps_per_second': 1.464}

Not training the linear probe at all:
- low patience: even worse readouts. The training barely broke 50%, test ~33%
- patience 20: train hit 72%, test: 40%

## Next steps:

1. Multiple readouts w/ combinations?
2. Self-calibrating readout?