In [None]:
from pathlib import Path
import datetime
#
from dotted_dict import DottedDict
import torch
#
import numpy as np
import pprint
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier

In [None]:
# Load IPython's autoreload extension so edited project modules can be re-imported without a kernel restart.
%load_ext autoreload

In [None]:
# Mode 2: reload all modules before executing each cell.
%autoreload 2

In [None]:
# Shared pretty-printer (4-space indent) used to dump the configuration further down.
pp = pprint.PrettyPrinter(indent=4)

In [None]:
def get_config_template():
    """Return a fresh, empty ``DottedDict`` to be filled with experiment settings."""
    return DottedDict()

def add_paths_to_confg(config):
    """Derive the run-specific output paths and store them on ``config`` in place.

    Reads ``config.dataset``, ``config.backbone``, ``config.debug`` and
    ``config.p_base``; writes ``config.fs_ckpt`` (checkpoint filename template),
    ``config.p_train`` (run directory), ``config.p_ckpts`` and ``config.p_logs``.
    Nothing is created on disk and nothing is returned.
    """
    # Unique run directory name: dataset, backbone and a second-resolution timestamp.
    now = datetime.datetime.now()
    run_name = f"run_{config.dataset}_{config.backbone}_{now:%Y%m%d-%H%M%S}"

    # Checkpoint filename template (two positional fields, the second zero-padded to 6).
    config.fs_ckpt = "model_{}_epoch_{:0>6}.ckpt"

    # Debug runs are collected under an extra "tmp" level below the base directory.
    sub_dirs = ("tmp", run_name) if config.debug else (run_name,)
    config.p_train = Path(config.p_base).joinpath(*sub_dirs)

    for attr, leaf in (("p_ckpts", "ckpts"), ("p_logs", "logs")):
        setattr(config, attr, config.p_train / leaf)

In [None]:
# Central experiment configuration: the cells below read their settings from `config`.
config = get_config_template()

#################
# device
#################
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
config.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
#config.device = 'cpu'

#################
# frequencies
#################
# NOTE(review): presumably intervals (in epochs) between checkpoints / evaluations -- confirm.
# The training loop visible in this notebook does not read these values.
config.freqs = {
    "ckpt": 5,
    "lin_eval": 5,
    "knn_eval": 5,
    "std_eval": 5,
    "plot_cov": 5,
}
#################
# data
#################
config.p_data = "/mnt/data/pytorch"
config.dataset = "cifar10"
config.img_size = 32
config.n_classes = 10
config.train_split = 'train'
config.valid_split = "valid"
# Augmentation pipelines are described as (name, kwargs) pairs and handed to
# SimSiamAugmentation further down.
# NOTE(review): the names look like torchvision transform classes -- confirm against `augmentations`.
config.augmentations_train = [
    ("RandomResizedCrop", {'size': config.img_size, "scale": (0.2, 1.0)}),
    ("RandomHorizontalFlip", {'p': 0.5}),
    ("RandomApply", {
        "transforms": [
            ("ColorJitter", {"brightness": 0.4,
                             "contrast": 0.4,
                             "saturation": 0.2,
                             'hue': 0.1})
        ],
        "p": 0.8,
    }),
    ("RandomGrayscale", {"p": 0.1}),
    ("ToTensor", {}),
    # NOTE(review): these look like the usual ImageNet channel statistics rather than
    # CIFAR-10 ones -- confirm this is intended.
    ('Normalize', {'mean': [0.485, 0.456, 0.406],
                   'std':[0.229, 0.224, 0.225]}),
]
#
# Validation pipeline: resize and normalise only, no random augmentation.
config.augmentations_valid = [
    ("Resize", {'size': (config.img_size, config.img_size)}),
    ("ToTensor", {}),
    ('Normalize', {'mean': [0.485, 0.456, 0.406],
                   'std':[0.229, 0.224, 0.225]}),
]
#################
# train model
#################
config.backbone =  "ResNet-18"
# Keyword arguments forwarded to get_projector() (together with d_in taken from the backbone).
config.projector_args = {
    'd_out': 512,
    'd_hidden': 512,
    'n_hidden': 3,
    'normalize': True,
    'dropout_rate': None,
    'activation_last': False,
    'normalize_last': False,
    'dropout_rate_last': None,
}
#################
# training
#################
config.batch_size = 512
config.num_epochs = 20
config.num_workers = 8

#################
# optimizer
#################
config.optimizer = "sgd"
config.optimizer_args = {
        "lr": 0.3,
        "weight_decay": 5e-4,  # used always
        "momentum": 0.9
    }
config.scheduler = "cosine_decay"
# T_max equals the number of pre-training epochs.
# NOTE(review): this suggests the scheduler is meant to be stepped once per epoch -- confirm.
config.scheduler_args = {
        "T_max": config.num_epochs,
        "eta_min": 0,
}
#################
# down train
#################
# Settings prefixed with "down_" are used for the downstream data loaders below.
config.down_batch_size = 128
config.down_num_epochs = 100
config.down_num_workers = 8

#################
# down optimizer
#################
config.down_optimizer = "sgd"
config.down_optimizer_args = {
        # Base learning rate 0.03 scaled linearly with the batch size relative to 256.
        "lr": 0.03 * config.down_batch_size / 256,
        "weight_decay": 5e-4,  # used always
        "momentum": 0.9
    }
config.down_scheduler = "cosine_decay"
config.down_scheduler_args = {
        "T_max": config.down_num_epochs,
        "eta_min": 0,
}

# Loss hyper-parameters; both values are passed to the BarlowTwins constructor below.
config.loss = {
    'scale': 0.024,
    'lmbda': 0.0051
}
# With debug=True, add_paths_to_confg() places the run directory under <p_base>/tmp.
config.debug = False
config.p_base = "/mnt/experiments/barlow"
# Must run after dataset, backbone, debug and p_base are set, because it reads all four.
add_paths_to_confg(config)

In [None]:
# Dump the complete configuration so the settings of this run are recorded in the cell output.
pp.pprint(config)

In [None]:
from models.backbones import *
from models.projectors import *
from models.barlow_twins import BarlowTwins
from optimizers import *
from augmentations import SimSiamAugmentation, Augmentation
from datasets import get_dataset
from utils import show, show_batch

In [None]:
# create model
backbone = get_backbone(config.backbone, zero_init_residual=True)
projector = get_projector(d_in=backbone.dim_out, **config.projector_args)
model = BarlowTwins(backbone, projector, config.loss["scale"], config.loss["lmbda"])

In [None]:
# Display the model's repr as the cell output for a quick visual check of the architecture.
model

In [None]:
# Build the pre-training optimizer and its learning-rate scheduler from the names/kwargs in config.
optimizer = get_optimizer(config.optimizer, model, config.optimizer_args)
# NOTE(review): the scheduler only has an effect if something calls scheduler.step() -- verify the training loop does.
scheduler = get_scheduler(config.scheduler, optimizer, config.scheduler_args)

In [None]:
# Augmentations
# downstream=False is used for pre-training (the training loop unpacks two views per
# sample); downstream=True is used for the loaders that feed single images to the
# evaluation code.
trans_train = SimSiamAugmentation(config.augmentations_train, downstream=False)
trans_down_train = SimSiamAugmentation(config.augmentations_train, downstream=True)
trans_down_valid = SimSiamAugmentation(config.augmentations_valid, downstream=True)

# Datasets
# Split names come from the config so that changing them there actually takes effect.
ds_train = get_dataset(
    dataset=config.dataset,
    p_data=config.p_data,
    transform=trans_train,
    target_transform=None,
    split=config.train_split
)
ds_down_train = get_dataset(
    dataset=config.dataset,
    p_data=config.p_data,
    transform=trans_down_train,
    target_transform=None,
    split=config.train_split
)
# Validation data must go through the deterministic validation pipeline
# (config.augmentations_valid), not the random training augmentations.
ds_down_valid = get_dataset(
    dataset=config.dataset,
    p_data=config.p_data,
    transform=trans_down_valid,
    target_transform=None,
    split=config.valid_split
)

In [None]:
# DataLoader
# Pre-training loader: batches of augmented view pairs, reshuffled every epoch.
dl_train = DataLoader(
    ds_train,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    drop_last=False,
    pin_memory=True
)
# Downstream training loader (single-view samples).
dl_down_train = DataLoader(
    ds_down_train,
    batch_size=config.down_batch_size,
    shuffle=True,
    num_workers=config.down_num_workers,
    drop_last=False,
    pin_memory=True
)
# Validation loader: no shuffling, so evaluation visits the samples in a fixed order.
dl_down_valid = DataLoader(
    ds_down_valid,
    batch_size=config.down_batch_size,
    shuffle=False,
    num_workers=config.down_num_workers,
    drop_last=False,
    pin_memory=True
)

In [None]:
def train_epoch(epoch, data_loader, model, optimizer, device, debug=False):
    """Run one pre-training epoch and return the mean loss.

    Args:
        epoch: Epoch index, used only for the progress-bar label.
        data_loader: Iterable yielding ``((x1, x2), target)`` batches; the
            target is ignored.
        model: Module called as ``model(x1, x2)`` that returns a scalar loss.
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the two views are moved to before the forward pass.
        debug: If ``True``, stop after 10 batches.

    Returns:
        ``(loss_avg, step)``: the mean loss over the processed batches and the
        number of processed batches (as a float). ``loss_avg`` is ``nan`` when
        the loader yielded no batch.
    """
    model.train()

    losses, step = 0., 0.
    # The context manager closes the bar even when the debug early-exit breaks the loop.
    with tqdm(data_loader, desc=f'Pretrain {epoch}') as p_bar:
        for (x1, x2), _ in p_bar:
            x1, x2 = x1.to(device), x2.to(device)
            optimizer.zero_grad()
            loss = model(x1, x2)
            loss.backward()
            optimizer.step()
            losses += loss.item()
            step += 1

            # Running mean of the loss over the batches seen so far.
            p_bar.set_postfix({'loss': losses / step})

            if debug is True and step == 10:
                break

    # An empty loader leaves step at 0; report nan instead of dividing by zero.
    loss_avg = losses / step if step else float('nan')
    return loss_avg, step

def train_step(model, optimizer, device, x1, x2):
    """Perform a single optimisation step on one pair of views.

    Switches ``model`` to training mode, moves both views to ``device``,
    computes ``model(x1, x2)`` as the loss, back-propagates it and applies one
    optimizer update.

    Returns:
        The loss tensor produced by the forward pass.
    """
    model.train()
    view_a = x1.to(device)
    view_b = x2.to(device)

    optimizer.zero_grad()
    batch_loss = model(view_a, view_b)
    batch_loss.backward()
    optimizer.step()

    return batch_loss


def _knn_features(data_loader, model, device, p_bar=None):
    """Embed every batch of ``data_loader`` and return ``(features, targets)`` as numpy arrays."""
    feats, labels = [], []
    with torch.no_grad():
        for data, target in data_loader:
            out = model.backbone(data.to(device))
            # Collapse everything except the batch axis. A bare squeeze() would also
            # drop the batch axis whenever a batch holds a single sample (possible for
            # the last batch when drop_last=False).
            out = out.flatten(start_dim=1)
            out = model.projector(out)
            feats.append(out.cpu().numpy())
            labels.append(target.cpu().numpy())
            if p_bar is not None:
                p_bar.update()
    return np.concatenate(feats), np.concatenate(labels)


def knn_eval(epoch, data_loader, model, device, n_neighbors=5, train_loader=None):
    """Score the projector embeddings with a brute-force k-nearest-neighbour classifier.

    Args:
        epoch: Epoch index, used only for the progress-bar label.
        data_loader: Loader yielding ``(data, target)`` batches to be scored.
        model: Object exposing ``backbone`` and ``projector`` callables and ``eval()``.
        device: Device the inputs are moved to.
        n_neighbors: Number of neighbours used for the vote.
        train_loader: Optional loader yielding ``(data, target)`` batches. When
            given, the classifier is fit on its embeddings and scored on
            ``data_loader`` (held-out accuracy).

    Returns:
        Mean accuracy reported by ``KNeighborsClassifier.score``.

    NOTE: without ``train_loader`` the classifier is fit and scored on the same
    samples, so every query point is also a stored neighbour and the resulting
    accuracy is optimistic. Pass ``train_loader`` for a held-out estimate.
    """
    model.eval()
    #
    n_batches = len(data_loader)
    if train_loader is not None:
        n_batches += len(train_loader)
    #
    with tqdm(total=n_batches, desc=f'Valid KNN {epoch}') as p_bar:
        x_eval, y_eval = _knn_features(data_loader, model, device, p_bar)
        if train_loader is None:
            x_fit, y_fit = x_eval, y_eval
        else:
            x_fit, y_fit = _knn_features(train_loader, model, device, p_bar)

        neigh = KNeighborsClassifier(n_neighbors=n_neighbors,
                                     algorithm='brute', n_jobs=8)
        neigh.fit(x_fit, y_fit)
        score = neigh.score(x_eval, y_eval)

        p_bar.set_postfix({"acc": score})
    return score

In [None]:
# Pre-training loop: a kNN evaluation at the start of every epoch, followed by one
# pass over the pre-training loader.
global_step = 0
epoch = 0
model = model.to(config.device)
for epoch in range(epoch, config.num_epochs, 1):
    # Evaluate before training so that epoch 0 reports the untrained baseline.
    score = knn_eval(epoch, dl_down_valid, model, config.device)
    print(score)

    losses, step = 0., 0.
    p_bar = tqdm(dl_train, desc=f'Pretrain {epoch}')
    for (x1, x2), target in p_bar:
        loss = train_step(model, optimizer, config.device, x1, x2)
        losses += loss.item()
        step += 1
        global_step += 1
        # Running mean of the loss over the batches of this epoch.
        p_bar.set_postfix({'loss': losses / step})

    # Advance the learning-rate schedule once per epoch (T_max is set to num_epochs);
    # without this call the learning rate never leaves its initial value.
    scheduler.step()

    # TODO: knn eval

    # TODO: std eval

    # TODO: linear train
    # TODO: linear eval