In [1]:
import torch 

from la.utils.io_utils import load_data 
from la.utils.io_utils import preprocess_dataset

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [2]:
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List
from la.data.my_dataset_dict import MyDatasetDict

initialize(version_base=None, config_path=str("../conf"), job_name="relative_supervision")

hydra.initialize()

In [3]:
import logging


pylogger = logging.getLogger(__name__)
# A fresh logger inherits the root's default WARNING level, so the `pylogger.info(...)`
# calls in later cells would be silently dropped. Enable INFO for this logger and attach
# a handler once (the guard keeps re-runs of this cell from stacking duplicate handlers).
pylogger.setLevel(logging.INFO)
if not pylogger.handlers:
    pylogger.addHandler(logging.StreamHandler())

In [4]:
from nn_core.common import PROJECT_ROOT

# Compose the `relative_supervision` experiment config (no command-line overrides)
# from the config directory that Hydra was initialized with.
cfg = compose(config_name="relative_supervision")

In [5]:
# Load the dataset splits described by the config and run the project's preprocessing
# (the recorded output shows a "Converting to RGB" pass over each split).
dataset = preprocess_dataset(load_data(cfg), cfg)

Converting to RGB: 100%|██████████| 50000/50000 [00:25<00:00, 1952.57 examples/s]
Converting to RGB: 100%|██████████| 10000/10000 [00:05<00:00, 1909.50 examples/s]


In [6]:
# Inspect one preprocessed training example before the teacher transform is applied.
# NOTE(review): `shape[0]` is the spatial size only if `x` is laid out (H, W, C); the
# Student section later reads `shape[-1]` from the transformed inputs instead — confirm
# which size the teacher's `input_dim` is meant to receive.
first_x = dataset['train'][0]['x']
img_size = first_x.shape[0]
pylogger.info(f"Image size: {img_size}")
pylogger.info(f"Train shape: {first_x.shape}")

In [7]:
import pytorch_lightning as pl

# Build the teacher LightningModule from its Hydra config. With `_recursive_=False`
# Hydra does not instantiate nested configs, so `model` is handed over as a config node.
teacher: pl.LightningModule = hydra.utils.instantiate(
    cfg.teacher,
    model=cfg.teacher.model,
    num_classes=cfg.dataset.num_classes,
    input_dim=img_size,
    _recursive_=False,
)

In [8]:
import numpy as np

# Splits processed by the following cells. Defined unconditionally: the transform loop
# iterates over `modes` even when subsampling is disabled (previously a NameError).
modes = ['train', 'test']

# Set `subsample = True` for a quick smoke test on the first `subsample_size` rows per split.
subsample = False
subsample_size = 100
if subsample:
    for mode in modes:
        # Clamp so a split smaller than `subsample_size` does not request out-of-range rows.
        n_rows = min(subsample_size, len(dataset[mode]))
        dataset[mode] = dataset[mode].select(np.arange(n_rows))

In [9]:
# Apply the teacher's input transform to the `x` column of every split.
map_params = dict(
    function=lambda example: {"x": teacher.transform_func(example["x"])},
    writer_batch_size=100,
    num_proc=1,
)

for mode in modes:
    dataset[mode] = dataset[mode].map(desc=f"Transforming {mode} samples", **map_params)

Transforming train samples: 100%|██████████| 100/100 [00:11<00:00,  8.43 examples/s]
Transforming test samples: 100%|██████████| 100/100 [00:11<00:00,  8.44 examples/s]


In [10]:
# Return torch tensors, exposing only the input and label columns.
torch_columns = ["x", "y"]
dataset.set_format(type="torch", columns=torch_columns)

# Teacher

In [11]:
from functools import partial 
from pytorch_lightning import Trainer

# DataLoader factory shared by every train/test loader below; callers pass the dataset
# and `shuffle` (and may override `batch_size`).
loader_func = partial(
    torch.utils.data.DataLoader,
    batch_size=512,
    num_workers=8,
)

# TODO: switch back num epochs
# `Trainer(gpus=...)` is deprecated since PyTorch Lightning 1.7 and removed in 2.0;
# `accelerator="gpu", devices=1` is the supported equivalent.
trainer_func = partial(
    Trainer,
    accelerator="gpu",
    devices=1,
    max_epochs=30,
    logger=False,
    enable_progress_bar=True,
)

In [12]:
# Shuffled loader for training, deterministic order for evaluation.
train_loader, test_loader = (
    loader_func(dataset['train'], shuffle=True),
    loader_func(dataset['test'], shuffle=False),
)

In [13]:
# A fresh Trainer per model so no fit state carries over between experiments.
trainer = trainer_func()
trainer.fit(model=teacher, train_dataloaders=train_loader)

  rank_zero_deprecation(


  rank_zero_warn(
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Epoch 29: 100%|██████████| 1/1 [00:01<00:00,  1.82s/it, loss=2.03, loss/train=0.953]

Epoch 29: 100%|██████████| 1/1 [00:01<00:00,  1.89s/it, loss=2.03, loss/train=0.953]


In [14]:
# Evaluate the trained teacher on the held-out split; the returned metrics list is displayed.
trainer.test(model=teacher, dataloaders=test_loader)

Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00,  3.18it/s]


[{'loss/test': 5.043737888336182, 'acc/test': 0.20999999344348907}]

# Student

In [15]:
# Spatial size of the transformed training inputs, read from the last axis of one example.
sample_x_shape = dataset['train'][0]['x'].shape
img_size = sample_x_shape[-1]
img_size

224

In [16]:
# Build the (non-distilled) student baseline from its Hydra config; `model` is passed
# through un-instantiated because `_recursive_=False`.
student: pl.LightningModule = hydra.utils.instantiate(
    cfg.student,
    model=cfg.student.model,
    num_classes=cfg.dataset.num_classes,
    input_dim=img_size,
    _recursive_=False,
)

In [17]:
# Rebuild loaders for the student run: shuffled for training, ordered for evaluation.
train_loader, test_loader = (
    loader_func(dataset['train'], shuffle=True),
    loader_func(dataset['test'], shuffle=False),
)

In [18]:
# A fresh Trainer per model so no fit state carries over between experiments.
trainer = trainer_func()
trainer.fit(model=student, train_dataloaders=train_loader)

Epoch 29: 100%|██████████| 1/1 [00:02<00:00,  2.04s/it, loss=3.78, loss/train=3.420]

Epoch 29: 100%|██████████| 1/1 [00:02<00:00,  2.05s/it, loss=3.78, loss/train=3.420]


In [19]:
# Evaluate the student baseline on the held-out split; the returned metrics list is displayed.
trainer.test(model=student, dataloaders=test_loader)

Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 40.05it/s]


[{'loss/test': 4.9328932762146, 'acc/test': 0.029999999329447746}]

# Standard distillation

In [20]:
from tqdm import tqdm


def embed_samples(dataset, model, batch_size: int = 8, device: str = "cuda", desc: str = "Embedding samples"):
    """Run ``model`` over ``dataset`` and attach its outputs as new columns.

    The model is called on ``batch["x"]`` and must return a mapping with ``"embeds"``
    and ``"logits"`` entries. Rows are read in order (no shuffling), so output row ``i``
    is stored on dataset row ``i``.

    Args:
        dataset: dataset split to embed; must expose an ``"x"`` column as tensors.
        model: module whose forward returns ``{"embeds": ..., "logits": ...}``.
            It is moved to ``device``; its train/eval mode is restored afterwards.
        batch_size: inference batch size.
        device: device the model and inputs are moved to.
        desc: progress-bar label (previously taken from a leaked global ``mode``,
            which mislabelled the train split).

    Returns:
        The dataset returned by ``dataset.map`` with additional ``"embedding"`` and
        ``"logits"`` columns.
    """
    model = model.to(device)
    dataloader = loader_func(dataset, shuffle=False, batch_size=batch_size)

    # Inference must not apply dropout or update BatchNorm running statistics (Lightning
    # leaves the module in train mode after fit/test), and needs no autograd graph.
    was_training = model.training
    model.eval()
    batch_embeddings = []
    batch_logits = []
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=desc):
                model_out = model(batch["x"].to(device))
                batch_embeddings.append(model_out["embeds"].cpu())
                batch_logits.append(model_out["logits"].cpu())
    finally:
        # Restore the caller's mode so later training of this model is unaffected.
        model.train(was_training)

    # Concatenating per-batch outputs yields the same row-stacked tensors as stacking rows.
    all_embeddings = torch.cat(batch_embeddings)
    all_logits = torch.cat(batch_logits)

    # One pass writes both columns; `ind` holds the dataset row indices of each map batch.
    return dataset.map(
        function=lambda x, ind: {
            "embedding": all_embeddings[ind],
            "logits": all_logits[ind],
        },
        desc="Storing embeddings and logits",
        with_indices=True,
        batched=True,
        batch_size=128,
        num_proc=1,
        writer_batch_size=10,
    )

In [21]:
# Attach the teacher's embeddings and logits to each split.
for split_name in ("train", "test"):
    dataset[split_name] = embed_samples(dataset[split_name], teacher)

Embedding test samples: 100%|██████████| 13/13 [00:01<00:00, 10.79it/s]
Storing embedded samples: 100%|██████████| 100/100 [00:01<00:00, 89.11 examples/s]
Storing logits: 100%|██████████| 100/100 [00:01<00:00, 92.79 examples/s]
Embedding test samples: 100%|██████████| 13/13 [00:01<00:00, 11.69it/s]
Storing embedded samples: 100%|██████████| 100/100 [00:01<00:00, 84.04 examples/s]
Storing logits: 100%|██████████| 100/100 [00:01<00:00, 79.64 examples/s]


In [22]:
# Expose inputs, labels and the stored teacher outputs as torch tensors on both splits.
for split_name in ("train", "test"):
    dataset[split_name].set_format(
        type="torch",
        columns=['x', 'y', 'embedding', 'logits'],
    )

# Relative student (anchor-based)

In [23]:
# Number of training examples sampled (without replacement) as anchors for the relative student.
num_anchors = 99
# Seeded generator so the anchor subset — and every result that depends on it — is
# reproducible across kernel restarts; the unseeded global RNG made each run differ.
anchor_seed = 0
anchor_rng = np.random.default_rng(anchor_seed)
anchor_idxs = anchor_rng.choice(len(dataset['train']), size=num_anchors, replace=False)

anchors = dataset['train'].select(anchor_idxs)

In [24]:
# Build the relative student from its Hydra config, handing it the sampled anchor rows.
# `img_size` here is the value computed in the Student section (transformed inputs).
relative_student = hydra.utils.instantiate(
    cfg.student_relative,
    model=cfg.student_relative.model,
    num_classes=cfg.dataset.num_classes,
    input_dim=img_size,
    anchors=anchors,
    _recursive_=False,
)

In [25]:
# Rebuild loaders: the splits were replaced by embed_samples and now carry extra columns.
train_loader, test_loader = (
    loader_func(dataset['train'], shuffle=True),
    loader_func(dataset['test'], shuffle=False),
)

In [26]:
# A fresh Trainer per model so no fit state carries over between experiments.
trainer = trainer_func()
trainer.fit(model=relative_student, train_dataloaders=train_loader)

Epoch 29: 100%|██████████| 1/1 [00:02<00:00,  2.03s/it, loss=4.76, loss/train=4.690]

Epoch 29: 100%|██████████| 1/1 [00:02<00:00,  2.20s/it, loss=4.76, loss/train=4.690]


In [27]:
# Evaluate the relative student on the held-out split; the returned metrics list is displayed.
trainer.test(model=relative_student, dataloaders=test_loader)

Testing: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
