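This notebook prototypes the ability to initialize an embedding network with a trained linear probe: a linear probe is fitted on precomputed embeddings, and its coefficients and intercepts are then copied into the network's readout layer.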

### Load libraries

In [1]:
# Reload all imported modules (e.g. src.pretrain) before each cell executes,
# so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import pathlib
import sys

import git.repo
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import nn

# Locate the repository root (searching upward from the notebook's directory)
# and make it importable so the `src.pretrain` imports below resolve.
_repo = git.repo.Repo(".", search_parent_directories=True)
if _repo.working_tree_dir is None:
    # A bare repository has no working tree; str(None) would otherwise yield
    # the bogus relative path "None" and the src imports would fail confusingly.
    raise RuntimeError("Could not determine git working tree directory")
GIT_ROOT = pathlib.Path(_repo.working_tree_dir)
# Guard so re-running this cell does not keep appending the same entry.
if str(GIT_ROOT) not in sys.path:
    sys.path.append(str(GIT_ROOT))

from src.pretrain import finetune, gen_embeddings, probe_embeddings
from src.pretrain.datasets.embedding import EmbeddingDataset
from src.pretrain.datasets.vision import cifar10, imagenette, svhn
from src.pretrain.models.vision import laion_clip, msft_beit, openai_clip
from src.pretrain.probes import fc_probe, linear_probe

### Construct configs

In [3]:
dataset_cfg = cifar10.CIFAR10()
embedder_cfg = openai_clip.OpenaiClipConfig(id="openai/ViT-B/16")

# The embedding, probe, and finetune configs are all built from the same
# (dataset, embedder) pair.
_shared_cfg_kwargs = dict(dataset_cfg=dataset_cfg, embedder_cfg=embedder_cfg)
embedding_cfg = gen_embeddings.Config(**_shared_cfg_kwargs)
probe_cfg = probe_embeddings.Config(**_shared_cfg_kwargs)
finetune_cfg = finetune.Config(**_shared_cfg_kwargs)

### Load network

In [4]:
# Build the embedder (cast with .float()) and wrap it with a single-layer
# probe head over the 10 CIFAR-10 classes.
embedder = embedder_cfg.get_model().float()
probe_head_cfg = fc_probe.FCProbeConfig(n_classes=10, n_layers=1)
model = probe_head_cfg.get_fc_probe(embedder)
model.cuda();

### Load dataset (raw and embeddings)

In [5]:
# Raw test images, preprocessed by the embedder's own transform, for
# end-to-end evaluation of `model`.
ds_test = dataset_cfg.get_test_ds(embedder.preprocess)
loader_test = finetune_cfg.get_loader(ds_test, eval_mode=True)

# Cached embeddings located via embedding_cfg (presumably written by
# gen_embeddings), cast to float32 to match the .float() model above.
embedding_path = embedding_cfg.full_save_path
eds = EmbeddingDataset.load_from_file(embedding_path).astype(np.float32)
print(eds.xs_train.shape, eds.xs_test.shape, eds.xs_test.dtype)

Files already downloaded and verified
(50000, 512) (10000, 512) float32


### Compute baseline accuracy of model

In [6]:
# Baseline: evaluate the model before its readout layer is overwritten with
# the fitted probe's parameters.
model.eval()
test_dict, _ = finetune.evaluate(cfg=finetune_cfg, loader=loader_test, model=model)
test_dict

  0%|          | 0/20 [00:00<?, ?it/s]

{'acc': WBMetric(data=0.06460000001192093, summary='max'),
 'loss': WBMetric(data=2.3414895526885986, summary='min')}

### Fit a linear probe on the precomputed embeddings

In [7]:
# Annotation only (no value is assigned here). Presumably declared in its own
# cell so static type checkers know `res_dict`'s type, because the cell that
# actually assigns it begins with a cell magic. NOTE(review): confirm intent.
res_dict: dict

In [8]:
%%capture
# We capture output to hide some annoying warnings
res_dict = linear_probe.run_experiment(
    ds=eds,
    c=100,
    use_gpu=True,
    return_clf_params=True,
)

In [9]:
# Probe metrics as reported by linear_probe.run_experiment (see that module
# for how "xent" differs from "xent_orig").
tuple(res_dict[key] for key in ("acc", "xent", "xent_orig"))

(0.945, 0.3735188, 0.31929624)

### Copy the fitted probe's parameters into the model's readout layer

In [10]:
# https://discuss.pytorch.org/t/how-do-i-pass-numpy-array-to-conv2d-weight-for-initialization/56595/3

# Output unit i of the readout layer must correspond to class label i.
assert np.all(res_dict["classes"] == np.arange(len(res_dict["classes"])))

# The probe is expected to hold exactly one layer (n_layers=1 above); the
# unpacking raises if that is not the case.
readout_lyr: nn.Linear
readout_lyr, = model.probe # type: ignore

# nn.Linear.weight has shape (out_features, in_features). The `.T` implies
# clf_coef is stored as (in_features, n_classes); the asserts enforce that.
# Shapes are checked explicitly because Tensor.copy_ broadcasts, so e.g. a
# (1, d) coefficient or a scalar intercept would otherwise be silently
# replicated across all classes instead of raising.
new_weight = torch.from_numpy(res_dict["clf_coef"]).T
new_bias = torch.from_numpy(res_dict["clf_intercept"])
assert new_weight.shape == readout_lyr.weight.shape, (
    new_weight.shape,
    readout_lyr.weight.shape,
)
assert new_bias.shape == readout_lyr.bias.shape, (
    new_bias.shape,
    readout_lyr.bias.shape,
)

# no_grad: in-place writes to leaf Parameters that require grad are otherwise
# rejected by autograd. copy_ also handles the CPU -> CUDA move and dtype cast.
with torch.no_grad():
    readout_lyr.weight.copy_(new_weight)
    readout_lyr.bias.copy_(new_bias)

In [11]:
# Now compute resulting accuracy
model.eval()
test_dict, _ = finetune.evaluate(
    model=model,
    loader=loader_test,
    cfg=finetune_cfg,
)
test_dict

  0%|          | 0/20 [00:00<?, ?it/s]

{'acc': WBMetric(data=0.9452000002861023, summary='max'),
 'loss': WBMetric(data=0.31933134059906004, summary='min')}