In [None]:
from pathlib import Path
import datetime
#
from dotted_dict import DottedDict
import torch
import torch.nn as nn
#
import numpy as np
import pprint
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
#
import torch.nn.functional as F
import matplotlib.pyplot as plt
#
from torch.utils.tensorboard import SummaryWriter

In [None]:
from models.backbones import *
from models.projectors import *
from models.barlow_twins import BarlowTwins
from optimizers import *
from augmentations import SimSiamAugmentation, Augmentation
from datasets import get_dataset
from utils import show, show_batch, save_checkpoint
from config_utils import get_dataloaders_from_config, get_config_template, add_paths_to_confg
from train_utils import down_knn, down_train_linear, down_valid_linear, std_cov_valid

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
# Pretty-printer (4-space indent) for inspecting nested objects such as the config.
pp = pprint.PrettyPrinter(indent=4)

In [None]:
# Experiment configuration: start from the project template and override the
# device, evaluation frequencies, data, model, optimizer and loss settings.
config = get_config_template()

#################
# DEVICE
#################
# The experiments target the second GPU ('cuda:1'). A host with a single GPU
# reports CUDA as available but has no device index 1, so fall back to
# 'cuda:0' there instead of failing on the first `.to(config.device)`.
if torch.cuda.is_available():
    config.device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    config.device = 'cpu'

#################
# frequencies
#################
# All values are in epochs: the training loop runs the respective step
# whenever `epoch % freq == 0`.
config.freqs = {
    "ckpt": 20,
    "lin_eval": 1,
    "knn_eval": 1,
    "std_eval": 1,
}
#################
# data
#################
config.p_data = "/mnt/data/pytorch"
config.dataset = "cifar10"
config.img_size = 32
config.n_classes = 10
config.train_split = 'train'
config.down_train_split = 'train'
config.down_valid_split = "valid"
# Each entry is a (transform name, keyword arguments) pair.
# NOTE(review): the Normalize statistics look like the usual ImageNet channel
# mean/std rather than CIFAR-10 specific values -- confirm this is intended.
config.augmentations_train = [
    ("RandomResizedCrop", {'size': config.img_size, "scale": (0.2, 1.0)}),
    ("RandomHorizontalFlip", {'p': 0.5}),
    ("RandomApply", {
        "transforms": [
            ("ColorJitter", {"brightness": 0.3,
                             "contrast": 0.3,
                             "saturation": 0.1,
                             'hue': 0.1})
        ],
        "p": 0.5,
    }),
    ("RandomGrayscale", {"p": 0.1}),
    ("ToTensor", {}),
    ('Normalize', {'mean': [0.485, 0.456, 0.406],
                   'std':[0.229, 0.224, 0.225]}),
]
#
config.augmentations_valid = [
    ("Resize", {'size': (config.img_size, config.img_size)}),
    ("ToTensor", {}),
    ('Normalize', {'mean': [0.485, 0.456, 0.406],
                   'std':[0.229, 0.224, 0.225]}),
]
#################
# train model
#################
config.backbone =  "MobileNet-v3-Small"
# The same arguments are used for both heads (means and log-variances).
config.projector_args = {
    'd_out': 512,
    'd_hidden': 512,
    'n_hidden': 0,
    'normalize': False,
    'dropout_rate': 0.05,
    'activation_last': False,
    'normalize_last': False,
    'dropout_rate_last': None,
}
#################
# training
#################
config.batch_size = 1024
config.num_epochs = 400
config.num_workers = 8

#################
# optimizer
#################
config.optimizer = "sgd"
config.optimizer_args = {
        "lr": 1e-2,
        "weight_decay": 1e-6,
        "momentum": 0.9
    }
config.scheduler = "cosine_decay"
# T_max equals the number of epochs, i.e. the schedule spans one full run.
config.scheduler_args = {
        "T_max": config.num_epochs,
        "eta_min": 0,
}
#################
# down train
#################
config.down_batch_size = 512
config.down_num_epochs = 1
config.down_num_workers = 8

#################
# down optimizer
#################
config.down_optimizer = "sgd"
config.down_optimizer_args = {
        # base LR 0.03 scaled linearly with the batch size (reference size 256)
        "lr": 0.03 * config.down_batch_size / 256,
        "weight_decay": 5e-4,  # used always
        "momentum": 0.9
    }
config.down_scheduler = "cosine_decay"
config.down_scheduler_args = {
        "T_max": config.down_num_epochs,
        "eta_min": 0,
}

# Weights of the two loss terms (reconstruction / KL divergence).
config.loss = {
    'lmda_rec': 10,
    'lmda_kld': 1
}
config.debug = False
config.p_base = "/mnt/experiments/siamesevae"
# Adds derived paths to the config (p_logs, p_ckpts and fs_ckpt are read later).
add_paths_to_confg(config)
# Wrap so that nested dicts support attribute access (e.g. config.freqs.ckpt).
config = DottedDict(config)

In [None]:
# META VARS
# P_CKPT: path of a checkpoint to load; None starts a fresh run.
# CONTINUE: when a checkpoint is given, also restore its config, optimizer,
# scheduler and epoch/step counters instead of only the model weights.
P_CKPT = None
CONTINUE = True

In [None]:
# Optionally load a checkpoint. `ckpt` is reused further down to restore the
# model weights and, when CONTINUE is set, optimizer/scheduler/counters.
if P_CKPT is not None:
    print("LOADING CHECKPOINT {}".format(P_CKPT))
    # map_location="cpu": without it the tensors are restored onto the device
    # they were saved from, which fails if that CUDA device does not exist on
    # this host. load_state_dict copies the values onto the right device later.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(P_CKPT, map_location="cpu")

    if CONTINUE:
        # Resume with exactly the settings the checkpointed run was using.
        print("USING CKPT Config")
        config = ckpt["config"]

In [None]:
class SiameseVAE(torch.nn.Module):
    """Backbone followed by two heads that parameterize a diagonal Gaussian.

    One head predicts the mean, the other the log-variance of the latent
    code; ``forward`` returns a sample drawn with the reparametrization trick.

    Args:
        backbone: module mapping an input batch to intermediate features.
        net_means: indexable module (e.g. ``nn.Sequential``) producing the
            means; its last layer must expose ``out_features``.
        net_logvars: module producing the log-variances.
    """

    def __init__(self, backbone, net_means, net_logvars):
        super().__init__()

        self.backbone = backbone
        self.net_means = net_means
        self.net_logvars = net_logvars

        # Width of the latent code, read off the final layer of the mean head.
        self.dim_out = net_means[-1].out_features

    def encode(self, x):
        """Return ``(z_mu, z_logvar)`` for the batch ``x``."""
        features = self.backbone(x)
        # The second head predicts log_var = log(std**2), so that
        # std = exp(0.5 * log_var); predicting std directly would be the
        # alternative.
        return self.net_means(features), self.net_logvars(features)

    def forward(self, x):
        """Encode ``x`` and return one latent sample per input."""
        return self.reparametrize(*self.encode(x))

    def reparametrize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)``."""
        noise = torch.randn_like(logvar)
        sigma = (0.5 * logvar).exp()
        return noise * sigma + mu

    def loss_kld(self, mu, logvar):
        """Per-sample KL divergence to a standard normal, summed over dim 1."""
        per_dim = 1 + logvar - mu.pow(2) - logvar.exp()
        return -0.5 * per_dim.sum(dim=1)

    def loss_rec(self, x1, x2):
        """Per-sample mean squared difference between ``x1`` and ``x2`` (mean over dim 1)."""
        # Alternatives that were tried and left disabled: F.mse_loss and a
        # negative cosine similarity of the (normalized) codes.
        return (x1 - x2).pow(2).mean(dim=1)

In [None]:
# load data
# --- data ---------------------------------------------------------------
# dl_train feeds the pre-training loop; the two "down" loaders are used by the
# downstream evaluations (linear probe training / validation metrics).
dl_train, dl_down_train, dl_down_valid = get_dataloaders_from_config(config)

# --- model --------------------------------------------------------------
backbone = get_backbone(config.backbone, pretrained=False)

# Two separately constructed heads built from the same arguments: the first
# predicts the latent means, the second the latent log-variances.
projector_means, projector_logvars = (
    get_projector(d_in=backbone.dim_out, **config.projector_args)
    for _ in range(2)
)

model = SiameseVAE(backbone, projector_means, projector_logvars)

# --- optimisation -------------------------------------------------------
optimizer = get_optimizer(config.optimizer, model, config.optimizer_args)
scheduler = get_scheduler(config.scheduler, optimizer, config.scheduler_args)

In [None]:
# Training counters; overwritten below when resuming from a checkpoint.
global_step = 0
epoch = 0
#
if P_CKPT is not None:
    # `r` reports missing / unexpected keys of the restored state dict.
    r = model.load_state_dict(ckpt['model_state_dict'])
    print("Load model state dict", r)
    if CONTINUE:
        # Move the model to its training device *before* restoring the
        # optimizer: Optimizer.load_state_dict casts the stored state (e.g.
        # SGD momentum buffers) to the device of the current parameters. If
        # the model were still on the CPU here and only moved afterwards, the
        # buffers would stay on the CPU and the first optimizer.step() would
        # fail with a device mismatch. Module.to works in place, so the
        # optimizer keeps referencing the same parameters.
        model.to(config.device)
        print("LOAD optimizer")
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        #
        print("LOAD scheduler")
        scheduler.load_state_dict(ckpt['lr_scheduler_state_dict'])
        #
        global_step = ckpt['global_step']
        epoch = ckpt['global_epoch']
        print("Continue epoch {}, step {}".format(epoch, global_step))

In [None]:
# tensorboard
# Make sure the output directories for logs and checkpoints exist.
for out_dir in (config.p_logs, config.p_ckpts):
    out_dir.mkdir(parents=True, exist_ok=True)

# TensorBoard writer that collects all scalars of this run.
writer = SummaryWriter(config.p_logs)

# Print the command that opens this run in TensorBoard.
print("tensorboard --logdir={}".format(config.p_logs))

In [None]:
# Show the module tree (repr) as cell output for a quick architecture check.
model

## DEBUGGING

In [None]:
# Move the model to the configured device before running the single debug step.
model = model.to(config.device)

In [None]:
def dl_generator():
    """Iterate over ``dl_train`` and yield only the input pair of each batch.

    Every batch is expected to unpack as ``((x1, x2), target)``; the target is
    dropped. Presumably x1/x2 are two augmented views of the same images --
    confirm against the training dataset/augmentation.
    """
    for views, _ in dl_train:
        x1, x2 = views
        yield x1, x2


# Single generator instance so that repeated `next(generator)` calls walk
# through the training data batch by batch.
generator = dl_generator()

In [None]:

# One manual optimization step on a single batch, including the KLD term,
# to inspect the individual loss values.
x1, x2 = next(generator)

model.train()
x1, x2 = x1.to(config.device), x2.to(config.device)
optimizer.zero_grad()

# Encode both inputs into the parameters of their latent distributions.
mu1, logvar1 = model.encode(x1)
mu2, logvar2 = model.encode(x2)

# Sample each latent from its OWN distribution. (The second sample used to be
# drawn around mu1, so mu2 never reached the reconstruction term.)
z1 = model.reparametrize(mu1, logvar1)
z2 = model.reparametrize(mu2, logvar2)

# rec loss: pulls the two sampled codes together
loss_rec = model.loss_rec(z1, z2).mean()

# kld loss: one term per input, averaged
loss_kld1 = model.loss_kld(mu1, logvar1).mean()
loss_kld2 = model.loss_kld(mu2, logvar2).mean()

loss_kld = (loss_kld1 + loss_kld2) / 2

# Weighted sum with the weights from config.loss.
loss = (config.loss.lmda_kld * loss_kld) + (config.loss.lmda_rec * loss_rec)

loss.backward()
optimizer.step()

print("rec {} kld {} kld_1 {} kld_2 {}".format(loss_rec.item(), loss_kld.item(), loss_kld1.item(), loss_kld2.item()))

## END DEBUGGING

In [None]:
# Main pre-training loop. Per epoch: evaluate the current representation
# (std/cov, kNN, linear probe -- each at its configured frequency), train for
# one pass over dl_train, advance the LR schedule and periodically checkpoint.
model = model.to(config.device)
for epoch in range(epoch, config.num_epochs, 1):
    # STD EVAL (full model output)
    if epoch % config.freqs.std_eval == 0:
        std, cov = std_cov_valid(dl_down_valid, model, config.device)
        plt.matshow(cov)
        plt.colorbar()
        print("min {:.3f} max: {:.3f}".format(cov.min(), cov.max()))
        plt.show()
        #
        writer.add_scalar('std', std, global_step)

    # STD EVAL BACKBONE (features before the heads)
    if epoch % config.freqs.std_eval == 0:
        std, cov = std_cov_valid(dl_down_valid, model.backbone, config.device)
        plt.matshow(cov)
        plt.colorbar()
        print("backbone min {:.3f} max: {:.3f}".format(cov.min(), cov.max()))
        plt.show()
        #
        writer.add_scalar('std_backbone', std, global_step)

    # KNN EVAL (full model output)
    if epoch % config.freqs.knn_eval == 0:
        acc = down_knn(dl_down_valid, model, config.device, n_neighbors=5)
        #
        writer.add_scalar('acc_knn', acc, global_step)

    # KNN EVAL BACKBONE
    if epoch % config.freqs.knn_eval == 0:
        acc = down_knn(dl_down_valid, model.backbone, config.device, n_neighbors=5)
        #
        writer.add_scalar('acc_knn_back', acc, global_step)

    # LINEAR EVAL: train a fresh linear classifier on top of the model output
    if epoch % config.freqs.lin_eval == 0:
        classifier = torch.nn.Linear(model.dim_out, config.n_classes).to(config.device)
        classifier.weight.data.normal_(mean=0.0, std=0.01)
        classifier.bias.data.zero_()
        #
        criterion = torch.nn.CrossEntropyLoss().to(config.device)
        #

        optimizer_down = get_optimizer(config.down_optimizer, classifier, config.down_optimizer_args)
        # NOTE(review): scheduler_down (and criterion) are created but not
        # passed to down_train_linear -- confirm whether the helper needs them.
        scheduler_down = get_scheduler(config.down_scheduler, optimizer_down, config.down_scheduler_args)
        #
        _, _ = down_train_linear(model, classifier, dl_down_train,
                              optimizer_down, config.device, config.down_num_epochs)

        acc = down_valid_linear(
                model,
                classifier,
                dl_down_valid,
                config.device)
        writer.add_scalar('acc_linear', acc, global_step)

    # TRAIN STEP
    # Running sums over the epoch, used for the progress bar / epoch mean.
    losses, step = 0., 0.
    losses_kld = 0
    losses_rec = 0
    # The evaluations above may have switched modes; training needs train mode.
    model.train()
    p_bar = tqdm(dl_train, desc=f'Pretrain {epoch}')
    for (x1, x2), target in p_bar:
        x1, x2 = x1.to(config.device), x2.to(config.device)
        optimizer.zero_grad()

        mu1, logvar1 = model.encode(x1)
        z1 = model.reparametrize(mu1, logvar1)

        mu2, logvar2 = model.encode(x2)
        z2 = model.reparametrize(mu2, logvar2)

        # rec loss -- currently the only term that is optimized
        loss_rec = model.loss_rec(z1, z2).mean()
        loss = loss_rec

        ## kld loss (disabled: constant zeros keep the logging below working)
        #loss_kld1 = model.loss_kld(mu1, logvar1).mean()
        #loss_kld2 = model.loss_kld(mu2, logvar2).mean()
        loss_kld1 = torch.Tensor([0])
        loss_kld2 = torch.Tensor([0])

        loss_kld = (loss_kld1 + loss_kld2) / 2

        #loss = (config.loss.lmda_kld * loss_kld) + (config.loss.lmda_rec * loss_rec)

        loss.backward()
        optimizer.step()

        losses += loss.item()
        losses_kld += loss_kld.item()
        losses_rec += loss_rec.item()
        global_step += 1
        step += 1
        p_bar.set_postfix({'loss': losses / step, 'rec': losses_rec / step, 'kld': losses_kld / step})
        #
        writer.add_scalar('loss', loss.item(), global_step)
        writer.add_scalar('rec loss', loss_rec.item(), global_step)
        writer.add_scalar('kld loss', loss_kld.item(), global_step)
        writer.add_scalar('kld1 loss', loss_kld1.item(), global_step)
        writer.add_scalar('kld2 loss', loss_kld2.item(), global_step)

    # max(..., 1) avoids a ZeroDivisionError if the loader yielded no batch.
    writer.add_scalar('epoch loss', losses / max(step, 1), global_step)

    # Advance the learning-rate schedule once per epoch (T_max is configured in
    # epochs). Without this call the scheduler is never applied and the
    # learning rate stays at its initial value for the whole run.
    scheduler.step()

    # CHECKPOINTING
    # NOTE(review): resuming restarts the loop at the stored epoch index, so
    # that epoch is run again -- confirm whether epoch + 1 should be stored.
    if epoch % config.freqs.ckpt == 0 and epoch != 0:
        p_ckpt = config.p_ckpts / config.fs_ckpt.format(config.dataset, epoch)
        config.p_ckpts.mkdir(exist_ok=True, parents=True)
        #
        save_checkpoint(model, optimizer, scheduler, config, epoch, global_step, p_ckpt)
        print('\nSave model for epoch {} at {}'.format(epoch, p_ckpt))
    writer.add_scalar('epoch', epoch, global_step)