In [None]:
from pathlib import Path
import datetime
#
from dotted_dict import DottedDict
import torch
#
import numpy as np
import pprint
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.neighbors import KNeighborsClassifier
#
import torch.nn.functional as F
import matplotlib.pyplot as plt
#
from torch.utils.tensorboard import SummaryWriter

In [None]:
from models.backbones import *
from models.projectors import *
from models.barlow_twins import BarlowTwins
from optimizers import *
from augmentations import SimSiamAugmentation, Augmentation
from datasets import get_dataset
from utils import show, show_batch, save_checkpoint
from config_utils import get_dataloaders_from_config, get_config_template, add_paths_to_confg
from train_utils import down_knn, down_train_linear, down_valid_linear, std_cov_valid

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
# Shared pretty-printer (4-space indent) for inspecting nested objects such as the config.
pp = pprint.PrettyPrinter(indent=4)

In [None]:
# Experiment configuration: start from the project template and override every
# setting this run depends on. The values below are read by the model/data
# construction cells and by the training loop further down.
config = get_config_template()

#################
# DEVICE
#################
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
config.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
#config.device = 'cpu'

#################
# frequencies
#################
# Intervals, in epochs: the training loop triggers each action when
# `epoch % freq == 0` (checkpointing additionally skips epoch 0).
config.freqs = {
    "ckpt": 10,
    "lin_eval": 5,
    "knn_eval": 5,
    "std_eval": 5,
}
#################
# data
#################
config.p_data = "/mnt/data/pytorch"
config.dataset = "cifar10"
config.img_size = 64
config.n_classes = 10
config.train_split = 'train'
config.down_train_split = 'train'
config.down_valid_split = "valid"
# Training augmentations as (transform name, kwargs) pairs.
# Presumably resolved to torchvision transforms by the project's augmentation
# builder — TODO confirm.
config.augmentations_train = [
    ("RandomResizedCrop", {'size': config.img_size, "scale": (0.2, 1.0)}),
    ("RandomHorizontalFlip", {'p': 0.5}),
    ("RandomApply", {
        "transforms": [
            ("ColorJitter", {"brightness": 0.4,
                             "contrast": 0.4,
                             "saturation": 0.2,
                             'hue': 0.1})
        ],
        "p": 0.8,
    }),
    ("RandomApply", {
        "transforms": [
            ('GaussianBlur', {
             # 128 // 20 * 2 + 1 evaluates to 13 (an odd kernel size).
             # NOTE(review): the 128 is hard-coded while config.img_size is 64;
             # confirm whether the kernel should be derived from config.img_size.
             'kernel_size': 128 // 20 * 2 + 1, 'sigma': (0.5, 2.0)})
        ],
        "p": 0.9,
    }),
    ("RandomGrayscale", {"p": 0.1}),
    ("ToTensor", {}),
    ('Normalize', {'mean': [0.485, 0.456, 0.406],
                   'std':[0.229, 0.224, 0.225]}),
]
#
# Evaluation-time pipeline: resize to a square of img_size, convert to tensor,
# and normalize with the same mean/std as the training pipeline.
config.augmentations_valid = [
    ("Resize", {'size': (config.img_size, config.img_size)}),
    ("ToTensor", {}),
    ('Normalize', {'mean': [0.485, 0.456, 0.406],
                   'std':[0.229, 0.224, 0.225]}),
]
#################
# train model
#################
config.backbone =  "ResNet-18"
# Keyword arguments forwarded to get_projector(...) together with the
# backbone's output dimension.
config.projector_args = {
    'd_out': 2048,
    'd_hidden': 2048,
    'n_hidden': 2,
    'normalize': True,
    'dropout_rate': None,
    'activation_last': False,
    'normalize_last': False,
    'dropout_rate_last': None,
}
#################
# training
#################
config.batch_size = 512
config.num_epochs = 1600
config.num_workers = 8

#################
# optimizer
#################
config.optimizer = "sgd"
config.optimizer_args = {
        "lr": 0.6,
        "weight_decay": 1e-6,
        "momentum": 0.90
    }
config.scheduler = "cosine_decay"
# T_max equals the number of pretraining epochs, i.e. the schedule is laid out
# for one scheduler step per epoch.
config.scheduler_args = {
        "T_max": config.num_epochs,
        "eta_min": 0,
}
#################
# down train
#################
# Settings for the linear probe that the training loop trains during evaluation.
config.down_batch_size = 128
config.down_num_epochs = 2
config.down_num_workers = 8

#################
# down optimizer
#################
config.down_optimizer = "sgd"
config.down_optimizer_args = {
        # Base learning rate 0.03 scaled linearly with batch size relative to 256.
        "lr": 0.03 * config.down_batch_size / 256,
        "weight_decay": 5e-4,  # used always
        "momentum": 0.9
    }
config.down_scheduler = "cosine_decay"
config.down_scheduler_args = {
        "T_max": config.down_num_epochs,
        "eta_min": 0,
}

# Loss hyper-parameters; 'scale' and 'lmbda' are passed to the BarlowTwins
# constructor when the model is built.
config.loss = {
    'scale': 0.024,
    'lmbda': 0.0051
}
config.debug = False
config.p_base = "/mnt/experiments/barlow"
# Presumably derives the run-specific paths used later in the notebook
# (p_logs, p_ckpts, p_train, fs_ckpt) from p_base — TODO confirm in config_utils.
add_paths_to_confg(config)
# Wrap in DottedDict so entries are reachable with attribute access.
config = DottedDict(config)

In [None]:
# META VARS
# P_CKPT: path of a checkpoint to load, or None to start from scratch.
# CONTINUE: only relevant when P_CKPT is set. When True, the checkpoint's config,
# optimizer state, scheduler state and epoch/step counters are restored as well;
# when False only the model weights are loaded.
P_CKPT = None
CONTINUE = True

In [None]:
# Optionally load a checkpoint. With CONTINUE set, the checkpoint's stored config
# replaces the one defined above so the resumed run uses the original settings.
if P_CKPT is not None:
    print("LOADING CHECKPOINT {}".format(P_CKPT))
    # Deserialize onto the CPU: without map_location, tensors saved from a CUDA
    # device cannot be loaded on a CPU-only host. The later load_state_dict calls
    # copy the values onto whatever device the model/optimizer live on.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoints. The checkpoint also stores a config object, so PyTorch versions
    # where torch.load defaults to weights_only=True may need weights_only=False.
    ckpt = torch.load(P_CKPT, map_location="cpu")

    if CONTINUE:
        print("USING CKPT Config")
        config = ckpt["config"]

In [None]:
# Build everything the training loop needs from `config`: the Barlow Twins model
# (backbone + projector), the three dataloaders, and the optimizer/scheduler pair.
# create model
backbone = get_backbone(config.backbone, zero_init_residual=True, pretrained=False)
# The projector's input width is the backbone's output dimension.
projector = get_projector(d_in=backbone.dim_out, **config.projector_args)
model = BarlowTwins(backbone, projector, config.loss["scale"], config.loss["lmbda"])
model = model.to(config.device) # important to put model already to device, otherwise optimizer fails! (BUG)

# load data
# dl_train feeds pretraining; dl_down_train / dl_down_valid feed the evaluations.
dl_train, dl_down_train, dl_down_valid = get_dataloaders_from_config(config)

# optimizer
optimizer = get_optimizer(config.optimizer, model, config.optimizer_args)
scheduler = get_scheduler(config.scheduler, optimizer, config.scheduler_args)

In [None]:
# Initialise the progress counters and, if a checkpoint was loaded, restore state
# from it. Model weights are always restored; optimizer, scheduler and counters
# only when CONTINUE is set.
global_step = 0
epoch = 0
#
if P_CKPT is not None:
    # `r` holds the result of load_state_dict and is printed for inspection.
    r = model.load_state_dict(ckpt['model_state_dict'])
    print("Load model state dict", r)
    if CONTINUE:
        print("LOAD optimizer")
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        #
        print("LOAD scheduler")
        scheduler.load_state_dict(ckpt['lr_scheduler_state_dict'])
        #
        global_step = ckpt['global_step']
        # NOTE(review): the training loop saves a checkpoint after finishing an
        # epoch and later resumes with range(epoch, ...). If save_checkpoint stores
        # that epoch index unchanged under 'global_epoch', the resumed run repeats
        # the already-trained epoch — confirm against utils.save_checkpoint.
        epoch = ckpt['global_epoch']
        print("Continue epoch {}, step {}".format(epoch, global_step))

In [None]:
def train_step(model, optimizer, device, x1, x2):
    """Run a single optimisation step on one pair of input batches.

    Args:
        model: callable module; ``model(x1, x2)`` must return a scalar loss tensor.
        optimizer: optimizer that updates ``model``'s parameters.
        device: device both inputs are moved to before the forward pass.
        x1: first input batch.
        x2: second input batch.

    Returns:
        The loss tensor produced by the forward pass (not detached).
    """
    view_a = x1.to(device)
    view_b = x2.to(device)

    # Evaluation code may have left the model in eval mode, so switch back
    # to training mode on every step.
    model.train()

    optimizer.zero_grad()
    batch_loss = model(view_a, view_b)
    batch_loss.backward()
    optimizer.step()
    return batch_loss

In [None]:
# Set up logging: a TensorBoard writer on the run's log directory, plus the log
# and checkpoint directories themselves.
# tensorboard
writer = SummaryWriter(config.p_logs)

# create train dir
config.p_logs.mkdir(exist_ok=True, parents=True)
config.p_ckpts.mkdir(exist_ok=True, parents=True)
#
# Print the command for opening TensorBoard on this run's logs.
print("tensorboard --logdir={}".format(config.p_logs))

In [None]:
# Display the model's repr as the cell output.
model

In [None]:
# Main pretraining loop.
#
# Each epoch runs the due evaluations on the *current* backbone (std/covariance,
# k-NN accuracy, linear-probe accuracy — each at its own interval from
# config.freqs), then one pass over dl_train, then advances the learning-rate
# schedule, and finally writes a checkpoint when one is due. All scalars are
# logged to TensorBoard against `global_step`.
model = model.to(config.device)
for epoch in range(epoch, config.num_epochs, 1):
    # STD EVAL
    if epoch % config.freqs.std_eval == 0:
        std, cov = std_cov_valid(dl_down_valid, model.backbone, config.device)
        plt.matshow(cov)
        plt.colorbar()
        print("min {:.3f} max: {:.3f}".format(cov.min(), cov.max()))
        plt.show()
        #
        writer.add_scalar('std', std, global_step)

    # KNN EVAL
    if epoch % config.freqs.knn_eval == 0:
        acc = down_knn(dl_down_valid, model.backbone, config.device, n_neighbors=5)
        #
        writer.add_scalar('acc_knn', acc, global_step)

    # LINEAR EVAL
    if epoch % config.freqs.lin_eval == 0:
        # A fresh linear classifier is created for every evaluation.
        classifier = torch.nn.Linear(model.backbone.dim_out, config.n_classes).to(config.device)
        classifier.weight.data.normal_(mean=0.0, std=0.01)
        classifier.bias.data.zero_()
        #
        criterion = torch.nn.CrossEntropyLoss().to(config.device)
        #
        optimizer_down = get_optimizer(config.down_optimizer, classifier, config.down_optimizer_args)
        # NOTE(review): scheduler_down is constructed but not passed to
        # down_train_linear below, so it is never stepped from this cell.
        scheduler_down = get_scheduler(config.down_scheduler, optimizer_down, config.down_scheduler_args)
        #
        _, _ = down_train_linear(model.backbone, classifier, dl_down_train,
                                 optimizer_down, config.device, config.down_num_epochs)

        acc = down_valid_linear(
            model.backbone,
            classifier,
            dl_down_valid,
            config.device)
        writer.add_scalar('acc_linear', acc, global_step)

    # TRAIN STEP
    losses, step = 0., 0.
    p_bar = tqdm(dl_train, desc=f'Pretrain {epoch}')
    for (x1, x2), target in p_bar:
        loss = train_step(model, optimizer, config.device, x1, x2)
        losses += loss.item()
        global_step += 1
        step += 1
        # Running mean of the loss over the batches seen so far in this epoch.
        p_bar.set_postfix({'loss': losses / step})
        #
        writer.add_scalar('batch loss', loss.item(), global_step)

    writer.add_scalar('epoch loss', losses / step, global_step)

    # LR SCHEDULE
    # The scheduler is configured with T_max = num_epochs, so it must advance
    # once per epoch. Without this call the learning rate never decays.
    # Stepping happens before checkpointing so the saved scheduler state
    # reflects the epochs already trained.
    scheduler.step()

    # CHECKPOINTING
    if epoch % config.freqs.ckpt == 0 and epoch != 0:
        p_ckpt = config.p_ckpts / config.fs_ckpt.format(config.dataset, epoch)
        config.p_ckpts.mkdir(exist_ok=True, parents=True)
        #
        save_checkpoint(model, optimizer, scheduler, config, epoch, global_step, p_ckpt)
        print('\nSave model for epoch {} at {}'.format(epoch, p_ckpt))
    writer.add_scalar('epoch', epoch, global_step)

## Prepare and export representations

In [None]:
import datasets

In [None]:
# train: 1000, 5000, 10000, 50000, 100000
# Sample count used in the export file names: number of downstream-train batches
# times the batch size (an upper bound if the last batch is partial).
n_samples = len(dl_down_train) * config.down_batch_size

# Raw images are exported only while n_samples <= max_imgs.
# NOTE(review): both values are reassigned in a later cell before the export
# loop runs; only the file names built in the next cell use the values set here.
max_imgs = 20000

In [None]:
# Output locations for the exported arrays: X = input images, R = backbone
# representations, Y = targets, all under <p_train>/representations.
p_features_base = config.p_train / "representations"
# parents=True (as used for the log/checkpoint directories) so the export also
# works when the run directory itself has not been created yet.
p_features_base.mkdir(exist_ok=True, parents=True)
# NOTE(review): the file names embed the n_samples value current at this point;
# n_samples is reassigned in a later cell, so the suffix may not match the
# number of rows actually exported.
p_imgs = p_features_base / f'X_{n_samples}.npy'
p_features = p_features_base / f"R_{n_samples}.npy"
p_targets = p_features_base / f"Y_{n_samples}.npy"

In [None]:
# Approximate sample counts of the downstream train/valid loaders
# (number of batches times batch size).
print(len(dl_down_train) * config.down_batch_size)
print(len(dl_down_valid) *  config.down_batch_size)

In [None]:
# Dataset used for the export: the 'train' split with the validation-time
# (non-random) augmentation pipeline.
# downstream=True presumably makes the augmentation return a single view
# instead of a pair — TODO confirm in augmentations.SimSiamAugmentation.
trans_final = SimSiamAugmentation(config.augmentations_valid, downstream=True)
ds_final = get_dataset(
            dataset=config.dataset,
            p_data=config.p_data,
            transform=trans_final,
            target_transform=None,
            split='train')

In [None]:
# Loader for the export pass. Samples are drawn one at a time and in shuffled
# order; images, representations and targets are appended together in the export
# loop, so the saved arrays stay row-aligned with each other.
batch_size = 1
dl_final = DataLoader(
        ds_final,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        drop_last=False,  # keep every sample, including a partial last batch
        pin_memory=True
    )

In [None]:
# Export the whole dataset. The export loop compares n_samples against a running
# count of *samples*, so it has to be the dataset size: len(dl_final) is the
# number of batches and only equals the sample count while batch_size == 1.
n_samples = len(ds_final)
# No cap on image export: with np.inf the `n_samples <= max_imgs` checks always pass.
max_imgs = np.inf

In [None]:
# Run the backbone over dl_final and collect, batch by batch, the representations,
# the targets and (when n_samples <= max_imgs) the input images as NumPy arrays.
# Stops once at least n_samples samples have been processed.
all_features, all_imgs, all_targets = [], [], []
all_samples = 0

model.eval()
with torch.no_grad():
    for batch_imgs, batch_targets in dl_final:
        batch_feats = model.backbone(batch_imgs.to(config.device))
        all_features.append(batch_feats.detach().cpu().numpy())

        batch_imgs = batch_imgs.detach().cpu().numpy()
        if n_samples <= max_imgs:
            all_imgs.append(batch_imgs)

        all_targets.append(batch_targets.detach().cpu().numpy())

        all_samples += batch_imgs.shape[0]

        # Progress report every time the running count hits a multiple of 1000.
        if all_samples % 1000 == 0:
            print(all_samples)
        if all_samples >= n_samples:
            break

In [None]:
# Stack the per-batch representation arrays along the first axis.
R = np.concatenate(all_features)
print(R.shape)

In [None]:
# Images were only collected under this same condition, so X is only built
# (and defined) when it holds.
if n_samples <= max_imgs:
    X = np.concatenate(all_imgs)
    print(X.shape)

In [None]:
# Stack the per-batch target arrays along the first axis.
Y = np.concatenate(all_targets)
print(Y.shape)

In [None]:
# Save the input images; X exists only when this condition held during collection.
if n_samples <= max_imgs:
    np.save(p_imgs, X)

In [None]:
# Save the backbone representations.
np.save(p_features, R)

In [None]:
# Save the targets, row-aligned with the representations.
np.save(p_targets, Y)

In [None]:
# Show the image export path (printed whether or not the images were saved).
print(p_imgs)