In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to the project's .py files
# take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

### LOSS

In [None]:
import torch
from model import D
# Smoke-test the loss D on two random tensors that share one shape.
size = (1, 32, 64)
# p is drawn first, then z, so the random stream is consumed in a fixed order.
p, z = (torch.rand(size) for _ in range(2))
# Bare last expression: the notebook renders whatever D returns.
D(p, z)

### PROJECTION MLP

In [None]:
import torch
from model import projection_MLP
# Shape check for projection_MLP: a (b, d_i) batch must come out as (b, d_o).
d_i, d_h, d_o = 256, 1024, 1024  # d_i / d_o are the input / output widths; d_h presumably the hidden width
n_hidden = 10  # presumably the number of hidden layers -- confirm in model.py
b = 4  # batch size (leading dimension of x)
# The module is built before x is sampled (same order of random draws as before).
f = projection_MLP(d_i, d_h, d_o, n_hidden)
# The input is a flat 2-D batch of shape (b, d_i).
x = torch.rand(b, d_i)
y = f(x)
# Only the output shape is verified here, not the values.
assert tuple(y.shape) == (b, d_o)

### PREDICTION MLP

In [None]:
import torch
from model import prediction_MLP
# Shape check for prediction_MLP: a (b, d_i) batch must come out as (b, d_o).
d_i, d_h, d_o = 2048, 512, 2048  # d_i / d_o are the input / output widths; d_h presumably the hidden width
b = 8  # batch size (leading dimension of x)
# The module is built before x is sampled (same order of random draws as before).
f = prediction_MLP(d_i, d_h, d_o)
x = torch.rand(b, d_i)
y = f(x)
# Only the output shape is verified here, not the values.
assert tuple(y.shape) == (b, d_o)

### SimSiam 

In [None]:
import torch
from model import SimSiam
from torchvision.models import resnet50
# End-to-end smoke test: run SimSiam once on two random image batches.
b, channels, img_size = 8, 3, 224
# Keyword dictionaries handed to SimSiam for its projector and predictor.
projector_args = {"hidden_dim": 2048, "out_dim": 2048, "n_hidden_layers": 1}
# The predictor's input and output widths both equal the projector's output width.
predictor_args = {
    "hidden_dim": 512,
    "in_dim": projector_args["out_dim"],
    "out_dim": projector_args["out_dim"],
}
# Two random batches of identical shape stand in for the two augmented views.
x1 = torch.rand(b, channels, img_size, img_size)
x2 = torch.rand(x1.shape)
# ResNet-50 with its final fc layer swapped for Identity; the fc layer's
# in_features is recorded on the backbone as `out_dim` before the swap.
backbone = resnet50()
backbone.out_dim = backbone.fc.in_features
backbone.fc = torch.nn.Identity()
# Forward both batches through the model; the result is kept in L.
model = SimSiam(backbone, projector_args, predictor_args)
L = model(x1, x2)

## Configs

In [None]:
import pprint
# Import the helpers by name: a star import hides where names come from and
# can silently shadow notebook variables.
from configs import add_paths, simsiam_cifar10, simsiam_default
#
pp = pprint.PrettyPrinter(indent=2)
# Build the default config (debug disabled), pass it through add_paths and
# pretty-print the result.
config = simsiam_default(debug=False)
config = add_paths(config)
pp.pprint(config)
# The CIFAR-10 config is built and printed in the next cell; the unused
# duplicate assignment that used to end this cell was dropped.

In [None]:
# Replace `config` with whatever simsiam_cifar10() returns and pretty-print it
# using the `pp` printer created in the previous cell.
config = simsiam_cifar10()
pp.pprint(config)

# Test Augmentations

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image

from augmentations import SimSiamAugmentations, LinearProbAugmentations

In [None]:
# Per-channel statistics passed to the augmentation classes below; by the
# variable's name the first list holds the means and the second the stds.
imagenet_mean_std = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]]
# Size argument handed to every augmentation pipeline in this section.
img_size = 60
# Image.open is lazy and leaves the file handle open until the pixels are read.
# Converting inside a `with` block forces the read, closes the file, and pins
# the mode to RGB so the image always has the three channels the mean/std
# lists above describe.
with PIL.Image.open("data/flower.jpg") as img_file:
    x = img_file.convert("RGB")

### SimSiam

In [None]:
# Two SimSiam augmentation objects built with the same img_size: one also
# receives the mean/std statistics, the other does not (presumably this
# toggles normalisation -- confirm in augmentations.py).
aug_do_norm = SimSiamAugmentations(img_size, imagenet_mean_std)
aug_no_norm = SimSiamAugmentations(img_size)

In [None]:
# Produce two augmented views of the same source image.
x1, x2 = aug_do_norm(x)
# The views are channel-first tensors, but imshow expects channel-last arrays.
# permute(1, 2, 0) moves the channel axis to the end while keeping rows and
# columns in place; transpose(0, 2) would swap height and width as well and
# show every view mirrored along its diagonal.
x1 = x1.permute(1, 2, 0).numpy()
x2 = x2.permute(1, 2, 0).numpy()
# NOTE(review): these views were built with the mean/std statistics, so values
# presumably fall outside [0, 1] and imshow will clip them -- confirm.
fig, axs = plt.subplots(1, 3, figsize=(10, 5))
axs[0].imshow(np.array(x))
axs[0].set_title("original")
axs[1].imshow(x1)
axs[1].set_title("view 1")
axs[2].imshow(x2)
axs[2].set_title("view 2");

In [None]:
# Produce two augmented views of the same source image (no mean/std given).
x1, x2 = aug_no_norm(x)
# The views are channel-first tensors, but imshow expects channel-last arrays.
# permute(1, 2, 0) moves the channel axis to the end while keeping rows and
# columns in place; transpose(0, 2) would swap height and width as well and
# show every view mirrored along its diagonal.
x1 = x1.permute(1, 2, 0).numpy()
x2 = x2.permute(1, 2, 0).numpy()
#
fig, axs = plt.subplots(1, 3, figsize=(10, 5))
axs[0].imshow(np.array(x))
axs[0].set_title("original")
axs[1].imshow(x1)
axs[1].set_title("view 1")
axs[2].imshow(x2)
axs[2].set_title("view 2");

In [None]:
# Linear-probe augmentations in train and validation mode (no mean/std given).
aug_train = LinearProbAugmentations(img_size, train=True)
aug_valid = LinearProbAugmentations(img_size, train=False)
# The outputs are channel-first tensors, but imshow expects channel-last
# arrays. permute(1, 2, 0) moves the channel axis to the end while keeping rows
# and columns in place; transpose(0, 2) would swap height and width as well and
# show the images mirrored along their diagonal.
x_aug_train = aug_train(x).permute(1, 2, 0).numpy()
x_aug_valid = aug_valid(x).permute(1, 2, 0).numpy()
#
fig, axs = plt.subplots(1, 3, figsize=(10, 5))
axs[0].imshow(np.array(x))
axs[0].set_title("original")
axs[1].imshow(x_aug_train)
axs[1].set_title("train")
axs[2].imshow(x_aug_valid)
axs[2].set_title("valid");

In [None]:
# Linear-probe augmentations in train and validation mode, this time given the
# mean/std statistics.
aug_train = LinearProbAugmentations(img_size, train=True, means_std=imagenet_mean_std)
aug_valid = LinearProbAugmentations(img_size, train=False, means_std=imagenet_mean_std)
# The outputs are channel-first tensors, but imshow expects channel-last
# arrays. permute(1, 2, 0) moves the channel axis to the end while keeping rows
# and columns in place; transpose(0, 2) would swap height and width as well and
# show the images mirrored along their diagonal.
x_aug_train = aug_train(x).permute(1, 2, 0).numpy()
x_aug_valid = aug_valid(x).permute(1, 2, 0).numpy()
# NOTE(review): with mean/std supplied the values presumably fall outside
# [0, 1] and imshow will clip them -- confirm.
fig, axs = plt.subplots(1, 3, figsize=(10, 5))
axs[0].imshow(np.array(x))
axs[0].set_title("original")
axs[1].imshow(x_aug_train)
axs[1].set_title("train")
axs[2].imshow(x_aug_valid)
axs[2].set_title("valid");

In [None]:
from augmentations import get_aug

In [None]:
# Print what get_aug returns for each (train, train_classifier) combination
# exercised here, always with a size argument of 64.
for train_flag, classifier_flag in ((True, False), (True, True), (False, True)):
    print(get_aug(64, train=train_flag, train_classifier=classifier_flag))

### Backbones

In [None]:
from utils import get_backbone

In [None]:
# Build each backbone once to check that construction works; `model` ends up
# holding the last one ("resnet18").
for backbone_name in ("resnet50", "resnet18"):
    model = get_backbone(backbone_name)

In [None]:
# Bare expression: the notebook renders the repr of the current `model`.
model

### Optimizers

In [None]:
from utils import get_optimizer
import torchvision.models as models

In [None]:
# Settings passed to get_optimizer in the next cell.
optimizer_args = {"lr": 0.03, "weight_decay": 0.0005}
# The model handed to get_optimizer alongside those settings.
model = models.mobilenet_v2()

In [None]:
# Build and print one optimizer per name; `optim` ends up holding the last
# one ("adam"), which the scheduler cell further down receives.
for optimizer_name in ("sgd", "adam"):
    optim = get_optimizer(optimizer_name, model, optimizer_args)
    print(optim)

In [None]:
from utils import get_scheduler

In [None]:
# Scheduler settings, then the scheduler itself, attached to the `optim`
# object left over from the optimizer cell.
scheduler_args = {"T_max": 100, "eta_min": 0}
scheduler_name = "cosine_decay"
scheduler = get_scheduler(scheduler_name, optim, scheduler_args)

### Datasets

In [None]:
from utils import get_dataset
import matplotlib.pyplot as plt
import numpy as np

In [None]:
img_size = 96
# NOTE(review): absolute, machine-specific path combined with download=False;
# the data must already exist at this location.
p_data = "/mnt/data/pytorch"
dataset = "cifar10"
# Build the train split first and the test split second; the two calls differ
# only in the `train` flag.
ds_train, ds_test = (
    get_dataset(
        dataset=dataset,
        data_dir=p_data,
        transform=None,
        train=is_train,
        download=False,
    )
    for is_train in (True, False)
)
# Report the number of items in each split, one per line.
print(len(ds_train), len(ds_test), sep="\n")

In [None]:
# Show element 0 of item 0 from each split side by side (train left, test right).
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
first_train, first_test = ds_train[0][0], ds_test[0][0]
ax[0].imshow(np.array(first_train))
ax[1].imshow(np.array(first_test))