In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("..")

In [None]:
# Python
from pathlib import Path
import os
import warnings
import math
import datetime
import time
warnings.filterwarnings('ignore')

# Extern
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.models.resnet import resnet18
from dotted_dict import DottedDict
import pickle
from tqdm import tqdm
import pprint
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression

# Local
from BTwins.barlow import *
from BTwins.transform_utils import *
from BTwins.utils import *
from csprites.datasets import ClassificationDataset
import utils
from backbone import get_backbone
from optimizer import get_optimizer
from Beta.models import get_projector
import plot_utils
import eval_utils

In [None]:
# Experiment configuration: every tunable of this run is collected here.
config = {
    # --- hardware ---
    "device": "cuda",
    "cuda_visible_devices": "1",
    # --- data ---
    "p_data": "/mnt/data/csprites/single_csprites_64x64_n7_c32_a32_p30_s3_bg_inf_random_function_70000",
    "target_variable": "shape",
    "batch_size": 512,
    "num_workers": 20,
    # --- schedule ---
    "num_epochs": 5,
    "freqs": {
        "ckpt": 50,      # write a checkpoint every N epochs
        "linprob": 5,    # run the linear-probe evaluation every N epochs
    },
    "num_vis": 64,
    # --- model / optimisation ---
    "backbone": "FCN16i223o64",
    "backbone_args": {
        "ch_last": 128,
        "dim_in": 3,
    },
    "dim_out": 64,
    "optimizer": "adam",
    "optimizer_args": {
        "lr": 0.001,
        "weight_decay": 1e-6,
    },
    "projector": [512, 512, 512],
    # Fraction of the projector output shared by both feature groups.
    "cnt_overlap": 0.0,
    # Share of the remaining projector dims assigned to the "stl" group.
    "ratio_stl_geo": 0.5,
    # --- output file names (relative to the experiment directory) ---
    "p_ckpts": "ckpts",
    "p_model": "model_{}.ckpt",
    "p_stats": "stats.pkl",
    "p_config": "config.pkl",
    "p_R_train": "R_train.npy",
    "p_R_valid": "R_valid.npy",
    "p_Y_valid": "Y_valid.npy",
    "p_Y_train": "Y_train.npy",
}

# Experiments are grouped by dataset name below a fixed root.
p_base = Path("/mnt/experiments/csprites") / Path(config["p_data"]).name / "tmp"

# A timestamp suffix keeps repeated runs in separate directories.
ts = time.time()
st = datetime.datetime.fromtimestamp(ts).strftime("%Y-%m-%d_%H-%M-%S")

experiment_name = f"BT_[{config['backbone']}_d{config['backbone_args']['ch_last']}]_[S4]_{st}"
config["p_experiment"] = str(p_base / experiment_name)

# Attribute-style access (config.batch_size) is used throughout the notebook.
config = DottedDict(config)
pprint.pprint(config)

# CUDA SETTINGS

In [None]:
# TORCH SETTINGS
# Restrict the visible GPUs first.
# NOTE(review): this only takes effect if CUDA has not been initialised yet in
# this kernel - confirm none of the imported modules touches CUDA at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_visible_devices
torch.backends.cudnn.benchmark = True
device = torch.device(config.device)

# Dataset

In [None]:
# Dataset metadata is stored as a pickle inside the dataset directory.
p_ds_config = Path(config.p_data) / "config.pkl"

# NOTE: unpickling can execute arbitrary code - only load datasets you trust.
with p_ds_config.open("rb") as file:
    ds_config = pickle.load(file)

# Position of the configured target variable within the dataset's class list,
# and the number of distinct values that variable can take.
target_variable = config.target_variable
matching_indices = [
    position
    for position, class_name in enumerate(ds_config["classes"])
    if class_name == target_variable
]
target_idx = matching_indices[0]
n_classes = ds_config["n_classes"][target_variable]

In [None]:
# Normalisation (and its inverse, used for plotting) from the dataset statistics.
means, stds = ds_config["means"], ds_config["stds"]
norm_transform = utils.normalize_transform(means, stds)
inverse_norm_transform = utils.inverse_normalize_transform(means, stds)


def target_transform(x):
    """Return the entry of a sample's target `x` that belongs to the configured target variable."""
    return x[target_idx]


# Appearance ("stl") augmentations: colour jitter, grayscale, blur, solarisation.
color_jitter = transforms.ColorJitter(
    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4
)
stl_transform = transforms.Compose([
    transforms.RandomApply([color_jitter], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    GaussianBlur(p=0.5),
    Solarization(p=0.2),
])

# Final step shared by every view: to tensor, then normalise.
fin_transform = transforms.Compose([
    transforms.ToTensor(),
    norm_transform,
])

# Training transform; the training loop unpacks four views per sample from it.
train_transform = CSpritesTransform(
    img_size=ds_config["img_size"],
    scale=(0.6, 1.0),
    ratio=(1, 1),
    p_hflip=0.5,
    p_vflip=0.5,
    stl_transform=stl_transform,
    fin_transform=fin_transform,
)

# Plain, augmentation-free transform for the linear-probe data.
transform_linprob = transforms.Compose([
    transforms.Resize(ds_config["img_size"]),
    transforms.ToTensor(),
    norm_transform,
])

In [None]:
# Arguments shared by both datasets and both loaders.
dataset_kwargs = dict(p_data=config.p_data, target_transform=target_transform)
loader_kwargs = dict(
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=False,
)

# TRAIN: augmented views; the last incomplete batch is dropped.
ds_train = ClassificationDataset(
    transform=train_transform, split="train", **dataset_kwargs
)
dl_train = DataLoader(ds_train, drop_last=True, **loader_kwargs)

# LINPROB VALID: un-augmented images of the validation split.
ds_linprob = ClassificationDataset(
    transform=transform_linprob, split="valid", **dataset_kwargs
)
dl_linprob = DataLoader(ds_linprob, **loader_kwargs)

# Visualize Data

In [None]:
# Number of samples per visualisation row; hard-coded to 8, config.num_vis is not used here.
n_vis = 8 #config.num_vis

In [None]:
# dl_train
# A training batch is (four views, targets); only the views are needed here.
views, _ = next(iter(dl_train))
x11, x12, x21, x22 = views

In [None]:
# Keep the first n_vis samples of every view and undo the normalisation for plotting.
# NOTE: the views are overwritten in place, so re-running this cell applies the
# inverse transform a second time - re-run the batch cell above first.
x11, x12, x21, x22 = (
    inverse_norm_transform(view[:n_vis]) for view in (x11, x12, x21, x22)
)

In [None]:
# One image row per augmented view, in the order x11, x12, x21, x22.
# (Replaces four copy-pasted plotting stanzas with a single loop.)
for view in (x11, x12, x21, x22):
    grid_img = torchvision.utils.make_grid(view, nrow=n_vis)
    plt.figure(figsize=(n_vis * 2, 2))
    # Move the channel axis last, which is the layout imshow expects.
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()

In [None]:
# dl_linprob: show the first n_vis validation images together with their labels.
x, y = next(iter(dl_linprob))
x = inverse_norm_transform(x[:n_vis])
y = y[:n_vis]
#
grid_img = torchvision.utils.make_grid(x, nrow=n_vis)
plt.figure(figsize=(n_vis * 2, 2))
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()
# The labels were selected for `target_variable`, so decode them with that
# variable's class map rather than a hard-coded "shape".
y = [ds_config["class_maps"][target_variable][idx.item()] for idx in y]
print(y)

# Model

In [None]:
class BarlowTwins(nn.Module):
    """Backbone + projector with two Barlow-Twins style cross-correlation losses.

    Two separate, non-affine batch-norm layers standardise the feature groups
    before the cross-correlation is computed:

    * ``bn_stl`` expects ``dim_stl + dim_cnt`` features,
    * ``bn_geo`` expects ``dim_geo + dim_cnt`` features.

    Args:
        backbone: module mapping an input batch to representations.
        projector: module mapping representations to the embedding used by the losses.
        dim_stl: number of embedding dims reserved for the "stl" group.
        dim_geo: number of embedding dims reserved for the "geo" group.
        dim_cnt: number of embedding dims shared by both groups.
    """

    def __init__(self, backbone, projector, dim_stl, dim_geo, dim_cnt):
        super().__init__()
        self.backbone = backbone
        self.projector = projector
        self.dim_cnt = dim_cnt
        self.dim_stl = dim_stl
        self.dim_geo = dim_geo

        # affine=False: plain standardisation, no learnable scale or shift.
        self.bn_geo = nn.BatchNorm1d(dim_geo + dim_cnt, affine=False)
        self.bn_stl = nn.BatchNorm1d(dim_stl + dim_cnt, affine=False)

    def get_representation(self, x):
        """Return the backbone output for ``x`` (no projector)."""
        return self.backbone(x)

    def forward(self, x):
        """Return the projector output computed on the backbone output for ``x``."""
        return self.projector(self.backbone(x))

    @staticmethod
    def _cross_correlation_loss(bn, z1, z2):
        """Shared implementation of both losses.

        Args:
            bn: batch-norm layer used to standardise ``z1`` and ``z2``.
            z1, z2: 2-D feature batches with the same shape.

        Returns:
            ``(on_diag, off_diag)``: the summed squared deviation of the diagonal
            from 1, and the summed squared off-diagonal entries, of the
            cross-correlation matrix.
        """
        # Empirical cross-correlation matrix of the standardised features.
        c = bn(z1).T @ bn(z2)

        # Average over the batch dimension.
        c.div_(z1.shape[0])

        # The in-place ops only modify the diagonal of ``c``; the off-diagonal
        # entries read afterwards are untouched.
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        return on_diag, off_diag

    def barlow_stl_loss(self, z1, z2):
        """On-/off-diagonal loss terms for the "stl" group (normalised by ``bn_stl``)."""
        return self._cross_correlation_loss(self.bn_stl, z1, z2)

    def barlow_geo_loss(self, z1, z2):
        """On-/off-diagonal loss terms for the "geo" group (normalised by ``bn_geo``)."""
        return self._cross_correlation_loss(self.bn_geo, z1, z2)

def feature_split(z, dim, overlap=None):
    """Split the feature axis (axis 1) of ``z`` at column ``dim``.

    Args:
        z: 2-D batch of features.
        dim: column at which the second part starts.
        overlap: optional number of extra columns appended to the first part;
            ``None`` or ``0`` means no overlap.

    Returns:
        ``(z1, z2)`` where ``z1`` holds the first ``dim`` (+ ``overlap``) columns
        and ``z2`` holds every column from ``dim`` onward, so with an overlap
        the columns ``dim .. dim + overlap - 1`` appear in both parts.
    """
    first_width = dim + overlap if overlap else dim
    return z[:, :first_width], z[:, dim:]

In [None]:
# backbone
backbone = get_backbone(config.backbone, **config.backbone_args)

# barlow projector
barlow_projector = get_projector(planes_in=backbone.dim_out, sizes=config.projector)

overlap_cnt = config["cnt_overlap"]
ratio_stl_geo = config["ratio_stl_geo"]
# Partition the projector output: dim_cnt shared dims, the remainder divided
# between the "stl" and "geo" groups according to ratio_stl_geo.
dim_cnt = int(barlow_projector.dim_out * overlap_cnt)
dim_stl_geo = barlow_projector.dim_out - dim_cnt
dim_stl = int(ratio_stl_geo * dim_stl_geo)
dim_geo = dim_stl_geo - dim_stl
# Weight of the off-diagonal loss term for each group, derived from its width.
w_off_stl = calc_lambda(dim_stl + dim_cnt)
w_off_geo = calc_lambda(dim_geo + dim_cnt)
#
model = BarlowTwins(backbone, barlow_projector, dim_stl=dim_stl, dim_geo=dim_geo, dim_cnt=dim_cnt)
print(dim_stl, dim_geo, dim_cnt)
print(w_off_stl, w_off_geo)

# `device` is a torch.device; comparing it with the string "cpu" is always
# unequal, so the device *type* is checked instead.
on_gpu = device.type != "cpu"
if on_gpu and torch.cuda.device_count() > 1:
    print("Using {} gpus!".format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)
    # DataParallel does not forward custom attributes of the wrapped module,
    # so the backbone is re-exposed for the cells that use `model.backbone`.
    model.backbone = model.module.backbone
elif on_gpu:
    print("Using 1 GPU!")
else:
    print("Using CPU!")
print(device)
model = model.to(device)
model

In [None]:
# Build the optimizer named in the config over all model parameters;
# config.optimizer_args carries the learning rate and weight decay.
optimizer = get_optimizer(config.optimizer, model.parameters(), config.optimizer_args)

In [None]:
# Running statistics: mean training loss per epoch, plus the linear-probe and
# kNN accuracies for the epochs in which the probe is evaluated.
stats = DottedDict({
    "train": {"loss": [], "epoch": []},
    "linprob": {"linacc": [], "knnacc": [], "epoch": []},
})

# Output locations: the experiment directory and its checkpoint sub-directory.
p_experiment = Path(config.p_experiment)
p_ckpts = p_experiment / config.p_ckpts
p_experiment.mkdir(parents=True, exist_ok=True)
p_ckpts.mkdir(exist_ok=True)

In [None]:
# DataParallel does not expose the wrapped model's custom methods, so the loss
# helpers, the backbone and the state dict are always taken from the underlying
# model. Forward passes still go through `model`.
core_model = model.module if isinstance(model, torch.nn.DataParallel) else model

global_step = 0
for epoch_idx in range(1, config.num_epochs + 1, 1):
    ################
    # TRAIN
    ################
    model.train()
    epoch_step = 0
    epoch_loss = 0

    desc = "[{:3}/{:3}]".format(epoch_idx, config.num_epochs)
    pbar = tqdm(dl_train, bar_format= desc + '{bar:10}{n_fmt}/{total_fmt}{postfix}')
    #
    for (x11, x12, x21, x22), _ in pbar:
        x11 = x11.to(device)
        x12 = x12.to(device)
        x21 = x21.to(device)
        x22 = x22.to(device)
        # Reset gradients from the previous step.
        for param in model.parameters():
            param.grad = None

        # PROJECT
        z11 = model(x11)
        z12 = model(x12)
        z21 = model(x21)
        z22 = model(x22)

        # SPLIT every embedding into its stl part and its geo part.
        z11_stl, z11_geo = feature_split(z11, dim_stl, dim_cnt)
        z12_stl, z12_geo = feature_split(z12, dim_stl, dim_cnt)
        z21_stl, z21_geo = feature_split(z21, dim_stl, dim_cnt)
        z22_stl, z22_geo = feature_split(z22, dim_stl, dim_cnt)

        # GEO LOSS: pairs (x11, x12) and (x21, x22).
        geo_1112_on, geo_1112_off = core_model.barlow_geo_loss(z11_geo, z12_geo)
        geo_2122_on, geo_2122_off = core_model.barlow_geo_loss(z21_geo, z22_geo)

        # STL LOSS: pairs (x11, x21) and (x12, x22).
        stl_1121_on, stl_1121_off = core_model.barlow_stl_loss(z11_stl, z21_stl)
        stl_1222_on, stl_1222_off = core_model.barlow_stl_loss(z12_stl, z22_stl)

        # Average the four on-diagonal and the four weighted off-diagonal terms.
        loss_on = (geo_1112_on + geo_2122_on + stl_1121_on + stl_1222_on) / 4
        loss_off = (
            (geo_1112_off * w_off_geo)
            + (geo_2122_off * w_off_geo)
            + (stl_1121_off * w_off_stl)
            + (stl_1222_off * w_off_stl)
        )
        # Previously the averaged value was assigned to a misspelled name
        # ("loff_off"), so the off-diagonal term entered the loss un-averaged.
        loss_off = loss_off / 4

        loss = loss_on + loss_off

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_step += 1
        global_step += 1
        #
        pbar.set_postfix({'L': loss.item(),
                          'on': loss_on.item(),
                          'od': loss_off.item(),
                          'g_1_on': geo_1112_on.item(),
                          'g_2_on': geo_2122_on.item(),
                          's_1_on': stl_1121_on.item(),
                          's_2_on': stl_1222_on.item(),
                          'g_1_od': geo_1112_off.item(),
                          'g_2_od': geo_2122_off.item(),
                          's_1_od': stl_1121_off.item(),
                          's_2_od': stl_1222_off.item()
                         })

    stats.train.loss.append(epoch_loss / epoch_step)
    stats.train.epoch.append(epoch_idx)

    ################
    # Linprob
    ################
    if epoch_idx % config.freqs.linprob == 0 or epoch_idx == config.num_epochs:
        model.eval()
        linacc, knnacc = utils.linprob_model(core_model.backbone, dl_linprob, device)
        print("    Linprob Eval @LR: {:.2f} @KNN: {:.2f}".format(linacc, knnacc))
        stats.linprob.epoch.append(epoch_idx)
        stats.linprob.knnacc.append(knnacc)
        stats.linprob.linacc.append(linacc)
        model.train()
    ##################
    # Checkpoint
    ##################
    if epoch_idx % config.freqs.ckpt == 0 or epoch_idx == config.num_epochs:
        print("save model!")
        # Always save the unwrapped model so the keys carry no "module." prefix.
        torch.save(core_model.state_dict(), p_ckpts / config.p_model.format(epoch_idx))

# SAVE RESULTS 

In [None]:
# plot losses: mean training loss per epoch
p_loss_plot = p_experiment / "barlow_loss.png"
plt.plot(stats.train.epoch, stats.train.loss, label="train")
plt.legend()
plt.savefig(p_loss_plot)
plt.show()

# plot linprob acc: kNN and linear-probe accuracy at the evaluated epochs
p_acc_plot = p_experiment / "linprob_acc.png"
for curve_label, accuracies in (
    ("knn", stats.linprob.knnacc),
    ("lin", stats.linprob.linacc),
):
    plt.plot(stats.linprob.epoch, accuracies, label=curve_label)
plt.yticks([0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 1])
plt.legend()
plt.savefig(p_acc_plot)
plt.show()

In [None]:
# Persist the run configuration and the collected statistics in the experiment directory.
for file_name, payload in ((config.p_config, config), (config.p_stats, stats)):
    with open(p_experiment / file_name, "wb") as file:
        pickle.dump(payload, file)

# GET REPRESENTATIONS

In [None]:
# The raw training loader gets its own name: rebinding `dl_train` here would
# replace the augmented training loader, and re-running the training cell
# afterwards would then iterate over the wrong data.
dl_train_raw, dl_valid = utils.get_raw_csprites_dataloader(
    p_data=config.p_data,
    img_size=ds_config["img_size"],
    batch_size=config.batch_size,
    norm_transform=norm_transform,
    num_workers=config.num_workers,
)
p_R_train = p_experiment / config["p_R_train"]
p_Y_train = p_experiment / config["p_Y_train"]
p_R_valid = p_experiment / config["p_R_valid"]
p_Y_valid = p_experiment / config["p_Y_valid"]
# Backbone representations (R) and labels (Y) for both splits; the validation
# call additionally returns the de-normalised images (X) for later plots.
model.eval()
R_train, Y_train = utils.get_representations(model.backbone, dl_train_raw, device, imgs=False)
R_valid, Y_valid, X_valid = utils.get_representations(model.backbone, dl_valid, device, imgs=True, inverse_norm_transform=inverse_norm_transform)
#
np.save(p_R_train, R_train)
np.save(p_Y_train, Y_train)
np.save(p_R_valid, R_valid)
np.save(p_Y_valid, Y_valid)

#
print("TRAIN (R, Y)", R_train.shape, Y_train.shape)
print("VALID (R, Y)", R_valid.shape, Y_valid.shape)

# EVAL REPRESENTATIONS

In [None]:
# Run the project's evaluation routine on the train/valid representations and
# labels computed above, without showing figures inline (show=False).
# NOTE(review): presumably writes its results into p_experiment - confirm in eval_utils.
eval_utils.eval_representations(
    R_train=R_train,
    R_valid=R_valid,
    Y_train=Y_train,
    Y_valid=Y_valid,
    X_valid=X_valid,
    p_experiment=p_experiment,
    class_names = ds_config["classes"],
    show=False
)

In [None]:
# Display the experiment directory so the saved artefacts are easy to locate.
p_experiment

# EVAL SEPARATELY

In [None]:
# plt means
# Calls the project's plot_mean_dists helper on the validation representations and shows the figure.
# NOTE(review): p_dir is presumably the output directory for the plot - confirm in plot_utils.
plot_utils.plot_mean_dists(
    R=R_valid,
    p_dir=p_experiment,
    show=True)

In [None]:
# plot class dist
# Calls plot_class_dist on the validation representations and labels, titled
# with the dataset's class names; the figure is shown and its target path is
# class_distribution.png in the experiment directory.
# NOTE(review): the meaning of n_plot=100 is defined in plot_utils - confirm there.
plot_utils.plot_class_dist(
    R=R_valid,
    Y=Y_valid,
    n_plot=100,
    p_plot=p_experiment / "class_distribution.png",
    show=True,
    titles=ds_config["classes"])

In [None]:
# predict classes from features
# Calls utils.predict_all with the train/valid representations and labels and
# the dataset's class names; the plot path is score_lr.png in the experiment directory.
# NOTE(review): presumably fits one classifier per target variable - confirm in utils.
utils.predict_all(
    R_train=R_train,
    Y_train=Y_train,
    R_valid=R_valid,
    Y_valid=Y_valid,
    target_names=ds_config["classes"],
    show=True,
    p_plot=p_experiment / "score_lr.png")

In [None]:
# show latent
# Calls utils.plot_latent_by_imgs with the validation representations, images
# and labels; the plot path is feature_dims_highest.png in the experiment directory.
# NOTE(review): judging by the file name, n_imgs=50 selects the images with the
# highest activation per feature dimension - confirm in utils.
utils.plot_latent_by_imgs(
    R=R_valid,
    X=X_valid,
    Y=Y_valid,
    n_imgs=50,
    show=True,
    p_plot=p_experiment / "feature_dims_highest.png")