This notebook demonstrates how to initialize an embedding network with a trained linear probe.

### Load libraries

In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import dataclasses
import pathlib
import sys

import git.repo
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.utils.data

GIT_ROOT = pathlib.Path(
    str(git.repo.Repo(".", search_parent_directories=True).working_tree_dir)
)
sys.path.append(str(GIT_ROOT))

from src.pretrain import finetune, gen_embeddings
from src.pretrain.datasets.embedding import EmbeddingDataset
from src.pretrain.datasets.vision import cifar10, imagenette, svhn
from src.pretrain.models.vision import laion_clip, msft_beit, openai_clip
from src.pretrain.probes import fc_probe, linear_probe

### Load network and train dataset

In [3]:
torch.manual_seed(0)

dataset_cfg = cifar10.CIFAR10()
embedder_cfg = openai_clip.OpenaiClipConfig(id="openai/ViT-B/16")
finetune_cfg = finetune.Config(
    embedder_cfg=embedder_cfg, dataset_cfg=dataset_cfg
)

embedder = embedder_cfg.get_model().float()
# ds_train = torch.utils.data.Subset( # type: ignore
#     dataset_cfg.get_train_ds(embedder.preprocess),
#     indices=range(1000),  # to make testing faster
# )
ds_train, _, _ = torch.utils.data.random_split(
    dataset_cfg.get_train_ds(embedder.preprocess),
    [1000, 10000, 39000],
)
ds_test = dataset_cfg.get_test_ds(embedder.preprocess)
model = (
    fc_probe.FCProbeConfig(n_layers=1, n_classes=10)
    .get_fc_probe(embedder)
    .cuda()
)

Files already downloaded and verified
Files already downloaded and verified


### Compute baseline accuracy of model

In [4]:
model.eval()
test_dict, _ = finetune.evaluate(
    model=model,
    loader=finetune_cfg.get_loader(ds_test, eval_mode=True),
    cfg=finetune_cfg,
)
test_dict

  0%|          | 0/20 [00:00<?, ?it/s]

{'acc': WBMetric(data=0.09580000004768371, summary='max'),
 'loss': WBMetric(data=2.3757041038513185, summary='min')}

### Modify model to have optimal probe

In [5]:
sub_eds, clf = finetune.init_model_with_trained_linear_probe(
    model=model,
    ds=ds_train,
    cfg=finetune_cfg,
    verbose=True,
)

  0%|          | 0/2 [00:00<?, ?it/s]

[W] [19:15:29.152145] L-BFGS stopped, because the line search failed to advance (step delta = 0.000000)
Linear probe results:
acc: 1.0
xent_orig: 0.00097038585
xent: -0.0


In [7]:
# Now compute resulting accuracy
model.eval()
test_dict, _ = finetune.evaluate(
    model=model,
    loader=finetune_cfg.get_loader(ds_test, eval_mode=True),
    cfg=finetune_cfg,
)
test_dict

  0%|          | 0/20 [00:00<?, ?it/s]

{'acc': WBMetric(data=0.9365000004768371, summary='max'),
 'loss': WBMetric(data=0.24160788896083832, summary='min')}

### Compare to pure linear probe baseline

In [6]:
embedding_cfg = gen_embeddings.Config(
    dataset_cfg=dataset_cfg,
    embedder_cfg=embedder_cfg,
)
eds = EmbeddingDataset.load_from_file(embedding_cfg.full_save_path).astype(
    np.float32
)

eds.xs_train.shape, eds.xs_test.shape, eds.xs_test.dtype

((50000, 512), (10000, 512), dtype('float32'))

In [7]:
clf.score(eds.xs_test, eds.ys_test)

0.9363999962806702

### Compare decision function

In [8]:
# EDS_NEW is with a new conda environment and possible new cuda version.
eds_new = gen_embeddings.embed_dataset(
    model.embedder,
    ds=torch.utils.data.Subset(
        dataset_cfg.get_test_ds(model.embedder.preprocess),
        indices=range(2),
    ),
    cfg=gen_embeddings.Config(embedder_cfg=embedder_cfg, dataset_cfg=dataset_cfg),
)

print("Embedding comparison:")
print(eds_new[0][:2])
print(eds.xs_test[:2])

Files already downloaded and verified


  0%|          | 0/1 [00:00<?, ?it/s]

Embedding comparison:
[[ 0.77331966 -0.29994226  0.07011063 ...  0.19855547  0.48564512
   0.08899247]
 [ 0.50367385 -1.0752835   0.0752296  ...  0.24002843  0.15829396
   0.09126556]]
[[ 0.7758789  -0.30029297  0.06982422 ...  0.19909668  0.48583984
   0.08843994]
 [ 0.50341797 -1.0722656   0.0770874  ...  0.23986816  0.15856934
   0.09106445]]


In [10]:
model = model.train()
for imgs, _ in dataclasses.replace(finetune_cfg, eval_batch_size=2).get_loader(
    ds_test, eval_mode=True
):
    with torch.no_grad():
        embeddings = model.embedder.get_embeddings(imgs.cuda())
        logits = model(imgs.cuda())
    break

print("Embedding comparison:")
print(embeddings.cpu().numpy())
print(eds.xs_test[:2])
print()
print("Logit comparison:")
print(logits.cpu().numpy())
print(clf.decision_function(eds.xs_test[:2]).T)

Embedding comparison:
[[ 0.77331966 -0.29994226  0.07011063 ...  0.19855547  0.48564512
   0.08899247]
 [ 0.50367385 -1.0752835   0.0752296  ...  0.24002843  0.15829396
   0.09126556]]
[[ 0.7758789  -0.30029297  0.06982422 ...  0.19909668  0.48583984
   0.08843994]
 [ 0.50341797 -1.0722656   0.0770874  ...  0.23986816  0.15856934
   0.09106445]]

Logit comparison:
[[-3.5558755  -2.8956425  -0.860857    9.480637   -3.1567836   5.441544
   3.006999   -2.4137075  -2.6379287  -3.4146023 ]
 [-0.867033    0.7218502  -0.24550343  1.3061492  -5.9057727  -9.009769
  -4.9407964  -2.07026    17.649624    2.373362  ]]
[[-3.5421736  -2.8947928  -0.8484082   9.486113   -3.1623018   5.4385223
   3.0049486  -2.4247503  -2.6465964  -3.4168882 ]
 [-0.85966563  0.7409808  -0.27235842  1.2886996  -5.902449   -8.9882965
  -4.910681   -2.0559483  17.616325    2.3548064 ]]
