In [None]:
# Automatically re-import edited project modules (e.g. core.dataset, core.train)
# before executing each cell, so library edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

## Set up data

In [None]:
from torchvision.datasets import STL10

# Unlabeled STL-10 split, used below as the source images for SimCLR pretraining.
# With download=True torchvision fetches the archive into ./data if it is missing.
stl10_unlabeled = STL10(root="./data", split="unlabeled", download=True)

In [None]:
from core.dataset import AugmentedDataset
import torchvision.transforms.functional as tvf

# Wrap the raw unlabeled images in the project's augmentation dataset; each item is
# indexed as aug_ds[i][0][0] further down to obtain a single image tensor.
# NOTE(review): (100, 100) looks like the output image size (height, width) — confirm in core.dataset.
aug_ds = AugmentedDataset(
    stl10_unlabeled,
    (100, 100),
)

## Set up the PyTorch Lightning embeddor

In [None]:
import torchvision, torch
import pytorch_lightning as pl

# Pretrained EfficientNet-B0 used as the SimCLR backbone.
# `pretrained=True` is the keyword torchvision's model builders accept; the previous
# `from_pretrained=True` is not a torchvision argument (depending on the version it is
# either rejected or silently swallowed, leaving the backbone randomly initialised).
# Newer torchvision prefers `weights=...` but still accepts `pretrained=True` with a warning.
base_model = torchvision.models.efficientnet_b0(pretrained=True)

# Replace the classification head with a pass-through so the backbone returns the
# features that would otherwise feed the classifier. torch.nn.Identity is the built-in
# equivalent of the former hand-written class; the alias keeps the `Identity` name.
Identity = torch.nn.Identity
base_model.classifier = Identity()

# Probe the feature width with one augmented sample. Do it in eval mode and without
# autograd so the probe neither updates the pretrained BatchNorm running statistics
# nor builds a graph; then restore whatever mode the model was in.
was_training = base_model.training
base_model.eval()
with torch.no_grad():
    base_out_size = base_model(aug_ds[0][0][0].unsqueeze(0)).shape[1]
base_model.train(was_training)
print(base_out_size)

In [None]:
from core.train import SimCLREmbeddor

# --- SimCLR pretraining hyperparameters ---
embed_size = 50
temperature = 0.1
batch_size = 100
train_count = 10_000
val_count = 1_000
lr = 3e-3  # initial value; the LR-finder cells below search for a better one

embeddor = SimCLREmbeddor(
    aug_ds,
    base_model,
    base_out_size,
    embed_size=embed_size,
    temperature=temperature,
    batch_size=batch_size,
    train_count=train_count,
    val_count=val_count,
    lr=lr,
)

In [None]:
# Pretraining trainer; the same instance is used for the LR search and for fit().
trainer = pl.Trainer(max_epochs=100, log_every_n_steps=50)

#### Find a good initial learning rate

In [None]:
# Run Lightning's learning-rate range test on the embeddor.
lr_find_emb = trainer.tuner.lr_find(model=embeddor)

In [None]:
# Draw the loss-vs-learning-rate curve with the suggested point marked,
# then keep that suggestion for the training run.
lr_find_emb.plot(suggest=True, show=False)
emb_lr = lr_find_emb.suggestion()

In [None]:
from pytorch_lightning.utilities.parsing import lightning_setattr

# Apply the LR-finder suggestion to the model, not to an optimizer object.
# trainer.fit() calls configure_optimizers() again and builds new optimizers, so
# editing the object returned by embeddor.optimizers() here is not the optimizer fit()
# will use (outside of a fit() run it may even be an empty list).
# lightning_setattr updates `embeddor.lr` and/or `embeddor.hparams.lr`, whichever
# exists, and raises AttributeError if neither does.
# NOTE(review): assumes SimCLREmbeddor.configure_optimizers reads its learning rate from
# `lr` (attribute or hparam) — confirm in core.train.
if emb_lr is None:
    # suggestion() returns None when it cannot compute a suggestion from the LR sweep.
    print("LR finder produced no suggestion; keeping the configured learning rate.")
else:
    lightning_setattr(embeddor, "lr", emb_lr)

In [None]:
# Train the SimCLR embeddor.
trainer.fit(model=embeddor)

#### Save checkpoint

In [None]:
# Persist the trained embeddor; the classifier section reloads it from this path.
emb_checkpoint_file = "checkpoints/efficientnet-b0-stl10-embeddor.ckpt"
trainer.save_checkpoint(filepath=emb_checkpoint_file)

## Set up the PyTorch Lightning classifier

In [None]:
from core.dataset import get_normalized_dataset

# Derive a normalized variant of the STL10 dataset *class* (it is instantiated
# in the next cell and passed as a class to SimCLRClassifier).
norm_ds_class = get_normalized_dataset(
    STL10,
)

In [None]:
# Instantiate the normalized dataset for the "train" split from ./data.
# NOTE(review): `norm_ds` is not referenced by the cells below (SimCLRClassifier is given
# `norm_ds_class`); presumably kept for inspection — confirm whether it is needed.
norm_ds = norm_ds_class("./data", "train")

In [None]:
from core.train import SimCLRClassifier

# --- Downstream classifier hyperparameters ---
# Spelled `batch_size` so it actually reaches SimCLRClassifier below; a misspelt name
# would leave the pretraining value (100) in `batch_size` and silently reuse it.
batch_size = 160
n_classes = 10
freeze_base = True
epochs = 100
lr = .03

# Restore the trained SimCLR embeddor from its checkpoint and keep its `.model`
# attribute (presumably the underlying nn.Module). `dataset`/`base_model` are the same
# objects the embeddor was originally built with (aug_ds, base_model). The result gets
# its own name so the `embeddor` LightningModule from pretraining is not overwritten.
pretrained_embed_model = SimCLREmbeddor.load_from_checkpoint(
    emb_checkpoint_file, dataset=aug_ds, base_model=base_model
).model

# NOTE(review): `epochs` is passed to the classifier, but the classifier trainer is
# created with max_epochs=10 — confirm which value is intended.
classifier = SimCLRClassifier(norm_ds_class, pretrained_embed_model, n_classes=n_classes,
                              freeze_base=freeze_base, batch_size=batch_size,
                              epochs=epochs, lr=lr)

In [None]:
# New trainer for the classifier stage (rebinds `trainer`, replacing the pretraining one).
# NOTE(review): max_epochs=10 differs from the `epochs` value handed to SimCLRClassifier
# in the previous cell — confirm which is intended.
trainer = pl.Trainer(max_epochs=10, log_every_n_steps=50)

#### Find a good initial learning rate

In [None]:
# Run Lightning's learning-rate range test on the classifier.
lr_find_cls = trainer.tuner.lr_find(model=classifier)

In [None]:
# Draw the loss-vs-learning-rate curve with the suggested point marked,
# then keep that suggestion for the classifier run.
lr_find_cls.plot(suggest=True, show=False)
cls_lr = lr_find_cls.suggestion()

In [None]:
from pytorch_lightning.utilities.parsing import lightning_setattr

# Apply the LR-finder suggestion to the model, not to an optimizer object.
# trainer.fit() calls configure_optimizers() again and builds new optimizers, so
# editing the object returned by classifier.optimizers() here is not the optimizer
# fit() will use (outside of a fit() run it may even be an empty list).
# lightning_setattr updates `classifier.lr` and/or `classifier.hparams.lr`, whichever
# exists, and raises AttributeError if neither does.
# NOTE(review): assumes SimCLRClassifier.configure_optimizers reads its learning rate from
# `lr` (attribute or hparam) — confirm in core.train.
if cls_lr is None:
    # suggestion() returns None when it cannot compute a suggestion from the LR sweep.
    print("LR finder produced no suggestion; keeping the configured learning rate.")
else:
    lightning_setattr(classifier, "lr", cls_lr)

In [None]:
# Train the downstream classifier on top of the pretrained embeddor.
trainer.fit(model=classifier)

## View logs

In [None]:
# Show TensorBoard inline. The trainers above set no default_root_dir, so Lightning
# writes its logs under ./lightning_logs/.
%load_ext tensorboard
%tensorboard --logdir ./lightning_logs/