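This notebook demonstrates how to initialize the classification head of an embedding network with a trained linear probe. It proceeds in four steps:

1. Load OpenAI CLIP ViT-B/16 with a one-layer probe head for CIFAR-10.
2. Measure the model's accuracy.
3. Initialize the head from a linear probe trained on the same data.
4. Measure the accuracy again.

All evaluation below uses a 5,000-example subset of the training set, so the reported numbers are training accuracy.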

### Load libraries

In [1]:
# Reload edited modules (e.g. the project-local `src` package imported below)
# before each cell executes, so library changes are picked up without
# restarting the kernel. Comments are kept on their own lines because text
# after a line magic is passed to the magic as arguments.
%load_ext autoreload
%autoreload 2

In [2]:
import pathlib
import sys

import git.repo
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.utils.data

# Locate the repository root by searching upward from the current working
# directory, then put it on sys.path so the project-local `src` package
# imported below can be resolved.
_repo = git.repo.Repo(".", search_parent_directories=True)
if _repo.working_tree_dir is None:
    # A bare repository has no working tree; previously this silently became
    # Path("None") and the `src` imports failed with a confusing error.
    raise RuntimeError(
        "Git repository has no working tree; cannot determine GIT_ROOT."
    )
GIT_ROOT = pathlib.Path(_repo.working_tree_dir)
# Avoid accumulating duplicate sys.path entries when this cell is re-run.
if str(GIT_ROOT) not in sys.path:
    sys.path.append(str(GIT_ROOT))

from src.pretrain import finetune
from src.pretrain.datasets.vision import cifar10, imagenette, svhn
from src.pretrain.models.vision import laion_clip, msft_beit, openai_clip
from src.pretrain.probes import fc_probe

### Load the network and the training dataset

In [3]:
# Number of training examples to keep; a small subset keeps this demo fast.
N_TRAIN_SUBSET = 5000

dataset_cfg = cifar10.CIFAR10()
embedder_cfg = openai_clip.OpenaiClipConfig(id="openai/ViT-B/16")
finetune_cfg = finetune.Config(
    embedder_cfg=embedder_cfg, dataset_cfg=dataset_cfg
)

# `.float()` presumably casts the weights to fp32 (e.g. if the checkpoint
# loads in half precision) — confirm against the embedder implementation.
embedder = embedder_cfg.get_model().float()
ds_train_full = dataset_cfg.get_train_ds(embedder.preprocess)
# Clamp the subset size so a smaller dataset is never indexed out of range.
ds_train = torch.utils.data.Subset(  # type: ignore
    ds_train_full,
    indices=range(min(N_TRAIN_SUBSET, len(ds_train_full))),  # type: ignore
)
# Build a 1-layer FC probe model on top of `embedder`, with one output per
# CIFAR-10 class, and move it to the GPU.
model = (
    fc_probe.FCProbeConfig(n_layers=1, n_classes=10)
    .get_fc_probe(embedder)
    .cuda()
)

Files already downloaded and verified


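### Compute the model's baseline accuracy

Before the probe is trained, accuracy on the 10-class training subset is expected to be near chance (about 10%).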

In [4]:
# Baseline: evaluate the model on the training subset before any probe
# initialization, and display the resulting metrics dict.
model.eval()
baseline_loader = finetune_cfg.get_loader(ds_train, eval_mode=True)
test_dict, _ = finetune.evaluate(
    model=model, loader=baseline_loader, cfg=finetune_cfg
)
test_dict

  0%|          | 0/10 [00:00<?, ?it/s]

{'acc': WBMetric(data=0.07639999985098839, summary='max'),
 'loss': WBMetric(data=2.3347501899719236, summary='min')}

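### Initialize the model's head with a trained linear probe

The probe is fit on `ds_train`, the same subset used for evaluation, so the accuracy reported afterwards is training accuracy.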

In [5]:
# Train a linear probe on `ds_train` and load it into `model`. With
# verbose=True the probe's metrics are printed. The call returns nothing, and
# `model` appears to be updated in place: its accuracy changes afterwards
# without reassignment.
probe_init_kwargs = dict(
    model=model,
    ds=ds_train,
    cfg=finetune_cfg,
    verbose=True,
)
finetune.init_model_with_trained_linear_probe(**probe_init_kwargs)

  0%|          | 0/10 [00:00<?, ?it/s]

Linear probe results:
acc: 1.0
xent_orig: 0.0011216763
xent: -0.0


In [6]:
# Re-evaluate on the same training subset now that the probe weights have been
# loaded into the model, and display the resulting metrics dict.
model.eval()
loader_after_init = finetune_cfg.get_loader(ds_train, eval_mode=True)
(test_dict, _) = finetune.evaluate(
    cfg=finetune_cfg,
    loader=loader_after_init,
    model=model,
)
test_dict

  0%|          | 0/10 [00:00<?, ?it/s]

{'acc': WBMetric(data=1.0, summary='max'),
 'loss': WBMetric(data=0.0011196479812264442, summary='min')}