In [None]:
from pathlib import Path
import datetime
#
from dotted_dict import DottedDict
import torch
#
import numpy as np
import pprint
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
#
import torch.nn.functional as F

In [None]:
from models.backbones import *
from models.projectors import *
from models.barlow_twins import BarlowTwins
from optimizers import *
from augmentations import SimSiamAugmentation, Augmentation
from datasets import get_dataset
from utils import show, show_batch
from config_utils import get_dataloaders_from_config
from train_utils import down_knn, down_train_linear, down_valid_linear

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
# Shared pretty-printer (4-space indent) used below to dump the experiment config.
pp = pprint.PrettyPrinter(indent=4)

In [None]:
def get_config_template():
    """Return a new, empty ``DottedDict`` to be filled in with experiment settings."""
    return DottedDict()

def add_paths_to_confg(config):
    """Fill ``config`` in place with the checkpoint name template and run directories.

    Reads ``config.dataset``, ``config.backbone``, ``config.debug`` and
    ``config.p_base``. Writes ``config.fs_ckpt``, ``config.p_train``,
    ``config.p_ckpts`` and ``config.p_logs``. Returns nothing and does not
    create any directory on disk.
    """
    # Run directory name: dataset, backbone and a second-resolution timestamp,
    # so every call yields a distinct name (unless repeated within one second).
    stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    run_dir_name = f"run_{config.dataset}_{config.backbone}_{stamp}"

    # Checkpoint file-name template with two positional format fields;
    # the second one is left-padded with zeros to six characters.
    config.fs_ckpt = "model_{}_epoch_{:0>6}.ckpt"

    # Debug runs are placed under an extra "tmp" directory below the base path.
    parent_dir = Path(config.p_base, "tmp") if config.debug else Path(config.p_base)
    config.p_train = parent_dir / run_dir_name

    # Sub-directories for checkpoints and logs live inside the run directory.
    config.p_ckpts = config.p_train / "ckpts"
    config.p_logs = config.p_train / "logs"

In [None]:
# Experiment configuration for this notebook. Everything below is stored on a
# single DottedDict so later cells can read the settings as attributes.
config = get_config_template()

#################
# device
#################
# Use the first CUDA device when one is available, otherwise fall back to CPU.
config.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Uncomment the next line to force CPU execution:
#config.device = 'cpu'

#################
# frequencies
#################
# Modulo periods for periodic actions. See the training loop for whether a
# given entry is compared against the epoch index or the global step counter.
config.freqs = {
    "ckpt": 2,
    "lin_eval": 2,
    "knn_eval": 2,
    "std_eval": 2,
    "plot_cov": 2,
}
#################
# data
#################
# NOTE(review): absolute, machine-specific data path — adjust per environment.
config.p_data = "/mnt/data/pytorch"
config.dataset = "cifar10"
config.img_size = 32
config.n_classes = 10
config.train_split = 'train'
config.valid_split = "valid"
# Training-time augmentation pipeline as (transform name, kwargs) pairs.
# Presumably resolved to transform objects by the dataloader helper — confirm.
config.augmentations_train = [
    ("RandomResizedCrop", {'size': config.img_size, "scale": (0.2, 1.0)}),
    ("RandomHorizontalFlip", {'p': 0.5}),
    ("RandomApply", {
        "transforms": [
            ("ColorJitter", {"brightness": 0.4,
                             "contrast": 0.4,
                             "saturation": 0.2,
                             'hue': 0.1})
        ],
        "p": 0.8,
    }),
    ("RandomGrayscale", {"p": 0.1}),
    ("ToTensor", {}),
    # NOTE(review): these mean/std values look like the ImageNet channel
    # statistics rather than CIFAR-10's — confirm this is intended.
    ('Normalize', {'mean': [0.485, 0.456, 0.406],
                   'std':[0.229, 0.224, 0.225]}),
]
#
# Validation-time pipeline: resize to a square of img_size, convert, and
# normalise with the same statistics as the training pipeline.
config.augmentations_valid = [
    ("Resize", {'size': (config.img_size, config.img_size)}),
    ("ToTensor", {}),
    ('Normalize', {'mean': [0.485, 0.456, 0.406],
                   'std':[0.229, 0.224, 0.225]}),
]
#################
# train model
#################
config.backbone =  "ResNet-18"
# Keyword arguments forwarded to get_projector(...) when the model is built.
config.projector_args = {
    'd_out': 512,
    'd_hidden': 512,
    'n_hidden': 2,
    'normalize': True,
    'dropout_rate': None,
    'activation_last': False,
    'normalize_last': False,
    'dropout_rate_last': None,
}
#################
# training
#################
config.batch_size = 512
config.num_epochs = 10
config.num_workers = 8

#################
# optimizer
#################
config.optimizer = "sgd"
config.optimizer_args = {
        "lr": 0.3,
        "weight_decay": 5e-4,
        "momentum": 0.9
    }
config.scheduler = "cosine_decay"
# T_max is tied to the number of pre-training epochs.
config.scheduler_args = {
        "T_max": config.num_epochs,
        "eta_min": 0,
}
#################
# down train
#################
# Settings for the downstream (linear-probe) training run during evaluation.
config.down_batch_size = 128
config.down_num_epochs = 2
config.down_num_workers = 8

#################
# down optimizer
#################
config.down_optimizer = "sgd"
config.down_optimizer_args = {
        # learning rate scaled linearly with the downstream batch size (reference: 256)
        "lr": 0.03 * config.down_batch_size / 256,
        "weight_decay": 5e-4,  # used always
        "momentum": 0.9
    }
config.down_scheduler = "cosine_decay"
# T_max is tied to the number of downstream training epochs.
config.down_scheduler_args = {
        "T_max": config.down_num_epochs,
        "eta_min": 0,
}

# Loss hyper-parameters; both values are passed to the BarlowTwins constructor.
config.loss = {
    'scale': 0.024,
    'lmbda': 0.0051
}
# When debug is True, add_paths_to_confg puts the run directory under p_base/tmp.
config.debug = False
# NOTE(review): absolute, machine-specific output path — adjust per environment.
config.p_base = "/mnt/experiments/barlow"
# Derive fs_ckpt / p_train / p_ckpts / p_logs from the values set above.
add_paths_to_confg(config)
# Re-wrap the config in a DottedDict. get_config_template() already returned
# one, so this presumably only re-converts nested values — TODO confirm it is needed.
config = DottedDict(config)

In [None]:
# Dump the complete configuration so the settings used are recorded in the notebook output.
pp.pprint(config)

In [None]:
# Build the network: backbone (pretrained=False, zero_init_residual=True),
# a projector sized from the backbone's output dimension, and the BarlowTwins
# wrapper that receives the two loss hyper-parameters from the config.
backbone = get_backbone(
    config.backbone,
    zero_init_residual=True,
    pretrained=False,
)
projector = get_projector(
    d_in=backbone.dim_out,
    **config.projector_args,
)
model = BarlowTwins(
    backbone,
    projector,
    config.loss["scale"],
    config.loss["lmbda"],
)

In [None]:
# Bare expression: shows the model's repr as this cell's output.
model

In [None]:
# Pre-training optimiser for the model, and a learning-rate scheduler attached
# to that optimiser; both are selected by name and configured from the config.
optimizer = get_optimizer(
    config.optimizer,
    model,
    config.optimizer_args,
)
scheduler = get_scheduler(
    config.scheduler,
    optimizer,
    config.scheduler_args,
)

In [None]:
# Three loaders built from the config: dl_train feeds pre-training, while
# dl_down_train / dl_down_valid feed the downstream (kNN and linear) evaluation.
dl_train, dl_down_train, dl_down_valid = get_dataloaders_from_config(config)

In [None]:
def train_step(model, optimizer, device, x1, x2):
    """Run one optimisation step on a pair of inputs and return the loss.

    Args:
        model: module put into training mode here; ``model(x1, x2)`` must
            return a value that supports ``.backward()``.
        optimizer: optimiser whose gradients are cleared and which is stepped once.
        device: target passed to ``.to(...)`` for both inputs.
        x1, x2: the two inputs handed to the model.

    Returns:
        The loss object produced by the model's forward call (not detached).
    """
    model.train()
    first_view = x1.to(device)
    second_view = x2.to(device)

    # Clear stale gradients, then forward -> backward -> parameter update.
    optimizer.zero_grad()
    batch_loss = model(first_view, second_view)
    batch_loss.backward()
    optimizer.step()

    return batch_loss

In [None]:
# Pre-training driver. Each epoch:
#   1. optionally evaluates the current backbone (kNN and/or a fresh linear probe),
#   2. runs one pass over dl_train, calling train_step per batch,
#   3. advances the learning-rate scheduler by one step.
global_step = 0
epoch = 0
model = model.to(config.device)
for epoch in range(epoch, config.num_epochs, 1):
    # KNN EVAL
    if epoch % config.freqs.knn_eval == 0:
        acc = down_knn(dl_down_valid, model.backbone, config.device, n_neighbors=5)
        # Report right away: `acc` is reused by the linear evaluation below.
        print(f"[epoch {epoch}] kNN eval: {acc}")
    # LINEAR EVAL
    if epoch % config.freqs.lin_eval == 0:
        # A new linear classifier is created for every evaluation, so the
        # result reflects the backbone as it is at this epoch.
        classifier = torch.nn.Linear(model.backbone.dim_out, config.n_classes).to(config.device)
        classifier.weight.data.normal_(mean=0.0, std=0.01)
        classifier.bias.data.zero_()
        #
        criterion = torch.nn.CrossEntropyLoss().to(config.device)
        #
        optimizer_down = get_optimizer(config.down_optimizer, classifier, config.down_optimizer_args)
        # NOTE(review): scheduler_down (like criterion above) is created but not
        # passed to down_train_linear — confirm whether the helper should receive it.
        scheduler_down = get_scheduler(config.down_scheduler, optimizer_down, config.down_scheduler_args)
        #
        _, _ = down_train_linear(model.backbone, classifier, dl_down_train,
                                 optimizer_down, config.device, config.down_num_epochs)

        acc = down_valid_linear(
            model.backbone,
            classifier,
            dl_down_valid,
            config.device)
        print(f"[epoch {epoch}] linear eval: {acc}")

    # TRAIN STEP
    losses, step = 0., 0.
    p_bar = tqdm(dl_train, desc=f'Pretrain {epoch}')
    for (x1, x2), target in p_bar:
        # Per-step hooks, not implemented yet (placeholders kept from the original).
        if global_step % config.freqs.std_eval == 0:
            pass  # TODO: std evaluation
        if global_step % config.freqs.plot_cov == 0:
            pass  # TODO: covariance plot
        if global_step % config.freqs.ckpt == 0:
            pass  # TODO: step-level checkpoint
        # loss
        loss = train_step(model, optimizer, config.device, x1, x2)
        losses += loss.item()
        global_step += 1
        step += 1
        # Running mean of the loss over the batches seen so far in this epoch.
        p_bar.set_postfix({'loss': losses / step})

    # Advance the learning-rate schedule once per epoch. The scheduler is
    # configured with T_max = config.num_epochs, so without this call the
    # learning rate would stay at its initial value for the whole run.
    # Assumes get_scheduler returns an object exposing step() — confirm.
    scheduler.step()

    # checkpointing
    if epoch % config.freqs.ckpt == 0:
        pass  # TODO: epoch-level checkpoint (config.p_ckpts / config.fs_ckpt)