Scratch notebook for testing things (currently: a BYOL training run on MNIST).

In [1]:
import torch
import numpy as np
from torchvision.datasets import MNIST
from torchvision.transforms import v2

from implementations.simclr import run_simclr_train
from implementations.byol import run_byol_train
from implementations.utils import make_simclr_augment_fn

In [2]:
def byol_tau_cosine_schedule(state, config):
    """Callback that sets ``state["tau"]`` from a cosine schedule over epochs.

    Takes the most recent entry of ``state["monitors"]["epoch"]`` as the
    current epoch ``k`` and ``config["max_epochs"]`` as the horizon ``K``,
    then stores

        tau = 1 - (1 - tau_base) * (cos(pi * k / K) + 1) / 2

    so tau equals ``config["tau_base"]`` at ``k = 0`` and reaches 1 at
    ``k = K``.

    Args:
        state: Mutable training-state dict; ``state["monitors"]["epoch"]``
            must be a non-empty sequence. Updated in place.
        config: Dict providing ``"max_epochs"`` and ``"tau_base"``.

    Returns:
        None. The result is written to ``state["tau"]``.
    """
    # NOTE(review): tau is presumably the EMA coefficient of BYOL's target
    # network -- confirm against implementations.byol.
    current_epoch = state["monitors"]["epoch"][-1]
    total_epochs = config["max_epochs"]

    # Gap between the base value and the final value of 1.
    gap_to_one = 1 - config["tau_base"]
    # Goes from 2 at k = 0 down to 0 at k = K.
    cosine_term = np.cos(np.pi * current_epoch / total_epochs) + 1

    state["tau"] = 1 - gap_to_one * cosine_term / 2

def get_default_byol_config():
    """Return a freshly built dict of default BYOL training settings.

    The dict is later handed to ``run_byol_train``; callers may add or
    override entries (e.g. ``"augment_fn"``) before training. A new dict is
    created on every call, so mutating the result does not affect later calls.

    Returns:
        dict: Configuration with device/epoch settings, data-loader kwargs,
        model component names and kwargs, optimizer and scheduler settings,
        and the per-epoch callbacks.
    """
    data_kwargs = {
        "batch_size": 4096,
        "drop_last": True,
    }

    # Model components, referenced by name with their constructor kwargs.
    base_kwargs = {
        "n_channels": 1,
    }
    # NOTE(review): in_features=2048 presumably matches the resnet50 feature
    # width -- confirm against the base network implementation.
    projection_head_kwargs = {
        "in_features": 2048,
        "out_features": 256,
        "hidden_dim": 4096
    }
    predictor_kwargs = {
        "in_features": 256,
        "out_features": 256,
        "hidden_dim": 4096
    }

    optim_kwargs = {
        "lr": 0.2,
        "lr_scaling": "batch_linear",
        "weight_decay": 1.5e-6,
        "trust_coefficient": 0.001
    }

    return {
        "device": "cuda",
        "max_epochs": 1000,
        "verbose_epoch": True,

        "data_kwargs": data_kwargs,
        # Keeps only the first element of each item (presumably the image of
        # an (image, label) pair -- verify against the dataset in use).
        "transform": lambda sample: sample[0],

        "base_class": "resnet50",
        "base_kwargs": base_kwargs,
        "projection_head_class": "byol_default",
        "projection_head_kwargs": projection_head_kwargs,
        "predictor_class": "byol_default",
        "predictor_kwargs": predictor_kwargs,

        # Starting value consumed by byol_tau_cosine_schedule.
        "tau_base": 0.996,
        "optim_class": "lars",
        "optim_kwargs": optim_kwargs,
        "scheduler_class": "cosine",
        "scheduler_kwargs": {},

        "callbacks": [byol_tau_cosine_schedule],
        "monitor_names": []
    }

In [3]:
# Start from the BYOL defaults and attach an augmentation function built for
# 28x28 inputs with color distortion switched off.
config = get_default_byol_config()
config.update(
    augment_fn=make_simclr_augment_fn(
        image_size=(28, 28),
        crop_scale=(0.2, 1),
        do_color_distort=False,
    )
)

# MNIST training split stored under ./data (download=True); each PIL image is
# turned into a tensor and then converted to float32 with value scaling on.
dataset = MNIST(
    root="./data",
    train=True,
    download=True,
    transform=v2.Compose([
        v2.PILToTensor(),
        v2.ToDtype(torch.float32, scale=True),
    ]),
)

In [None]:
# Run BYOL training on the MNIST dataset with the config prepared above.
# NOTE: long-running -- the saved progress bar shows roughly 62 s/epoch, i.e.
# an estimate of about 17 hours for the configured 1000 epochs.
result = run_byol_train(dataset, config)

Training:   0%|          | 5/1000 [05:09<17:05:52, 61.86s/epoch, epoch=5, train_loss=1.64]