In [2]:
# Mount Google Drive at /content/drive; the cells below clone the repo into,
# and read/write data under, this mount point.
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [21]:
%%bash
# Run this cell to create a home directory on Google Drive where all repos will be cloned.
# NOTE: a cell magic such as %%bash is only recognized on the FIRST line of a cell,
# so explanatory notes go below it as shell comments.
# `set -e` aborts on the first failing command: if Drive is not mounted, the `cd`
# fails loudly instead of `mkdir` silently creating `home` in the current directory.
set -e
cd /content/drive/MyDrive/
mkdir -p home

In [None]:
%%bash
# Run this cell to clone (if not already present) and install the invertible_cl repo.
# NOTE: %%bash must stay on the first line of the cell to act as a cell magic.
# The package is installed in editable (`-e`) mode from the Drive clone.
# `set -e` stops the cell if the home directory is missing or the clone fails,
# so `pip install -e .` never runs from the wrong directory.
set -e
cd /content/drive/MyDrive/home
if [ ! -d "invertible_cl" ]; then
  git clone https://github.com/mishgon/invertible_cl.git
fi
cd invertible_cl
pip install -e .

In [24]:
%%bash
# Run this cell to update (pull) and reinstall the invertible_cl repo.
# NOTE: %%bash must stay on the first line of the cell to act as a cell magic.
# `set -e` stops the cell if the clone is missing or `git pull` fails, instead of
# reinstalling from the wrong directory or from a stale checkout.
set -e
cd /content/drive/MyDrive/home/invertible_cl
git pull
pip install -e .

Already up to date.
Obtaining file:///content/drive/MyDrive/home/invertible_cl
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Installing collected packages: invertible-cl
  Attempting uninstall: invertible-cl
    Found existing installation: invertible-cl 0.0.1
    Uninstalling invertible-cl-0.0.1:
      Successfully uninstalled invertible-cl-0.0.1
  Running setup.py develop for invertible-cl
Successfully installed invertible-cl-0.0.1


In [2]:
# try restarting the notebook if some imports do not work

import os

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from invertible_cl.eval.models.probing import OnlineProbing
from invertible_cl.pretrain.data import CIFAR10
from invertible_cl.pretrain.models import VICReg

In [None]:
# VICReg pretraining on CIFAR-10, logged to TensorBoard on Google Drive.
# Callbacks: OnlineProbing (given embed_dim and num_classes; presumably a classifier
# fitted on the embeddings during pretraining -- see invertible_cl.eval) and
# LearningRateMonitor.

# --- configuration -------------------------------------------------------------
seed = 42
batch_size = 256
lr = 1e-2 * batch_size / 256  # linear scaling rule: change lr proportionally to batch size
num_epochs = 1000
# Colab VMs usually have only a few CPU cores; more DataLoader workers than cores
# just oversubscribes the CPU, so cap the original 8 at the available core count.
num_workers = min(8, os.cpu_count() or 1)

# User files on Google Drive live under /content/drive/MyDrive (the repo cells above
# use it too); the bare mount root /content/drive is not itself a Drive folder.
drive_root = '/content/drive/MyDrive'
data_dir = f'{drive_root}/data/cifar10/'
experiments_dir = f'{drive_root}/experiments/'

# Seed Python/NumPy/PyTorch RNGs (and DataLoader workers) for a reproducible run.
pl.seed_everything(seed, workers=True)

# --- data, model, callbacks ----------------------------------------------------
datamodule = CIFAR10(
    data_dir=data_dir,
    batch_size=batch_size,
    num_workers=num_workers
)

model = VICReg(
    encoder='resnet18_32x32',
    proj_dim=4096,
    lr=lr
)

callbacks = [
    OnlineProbing(
        embed_dim=model.embed_dim,
        num_classes=datamodule.num_classes
    ),
    LearningRateMonitor()
]

# --- logging & training --------------------------------------------------------
logger = TensorBoardLogger(
    save_dir=experiments_dir,
    name='pretrain/cifar10/vicreg/'
)

trainer = pl.Trainer(
    logger=logger,
    callbacks=callbacks,
    accelerator='gpu',
    max_epochs=num_epochs,
    gradient_clip_val=1.0,
    log_every_n_steps=10
)

trainer.fit(
    model=model,
    datamodule=datamodule
)
