In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("..")

In [None]:
# Python
from pathlib import Path
import os
import warnings
import math
import datetime
import time
warnings.filterwarnings('ignore')

# Extern
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.models.resnet import resnet18
from dotted_dict import DottedDict
import pickle
from tqdm import tqdm
import pprint
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression

# Local
from BTwins.barlow import *
from BTwins.transform_utils import *
from BTwins.utils import *
from csprites.datasets import ClassificationDataset
import utils
from backbone import get_backbone
from optimizer import get_optimizer

# Notes
- Without geometric stuff transform!

# Paper Stuff
### Learning Rates
Batch Size	Learning Rate
- 128  0.7
- 256  0.4
- 512  0.3
- 1024 0.25
- 2048 0.2
- 4096 0.2

In [None]:
# Experiment configuration. Everything a reader might want to tune lives here;
# derived entries (experiment path, lambda) are appended below the literal.
config = {
    # hardware / data
    'device': 'cuda',
    'cuda_visible_devices': '0',
    'p_data': '/mnt/data/csprites/single_csprites_64x64_n7_c128_a32_p10_s3_bg_inf_random_function_100000',
    'target_variable': 'shape',
    'batch_size': 1024,
    'num_workers': 6,
    'num_epochs': 200,
    # how often (in epochs) to checkpoint and to run the linear probe
    'freqs': {'ckpt': 50, 'linprob': 5},
    'num_vis': 64,
    # model / optimisation
    'backbone': "ResNet-18",
    'optimizer': 'adam',
    'optimizer_args': {'lr': 0.001, 'weight_decay': 1e-6},
    'projector': [4096, 4096, 4096],
    'scale_factor': 1,
    # file names (relative to the experiment directory)
    'p_ckpts': "ckpts",
    'p_model': "model_{}.ckpt",
    'p_stats': "stats.pkl",
    'p_config': 'config.pkl',
    'p_R_train': 'R_train.npy',
    'p_R_valid': 'R_valid.npy',
    # loss weights
    'w_stl': 0.5,
    'w_geo': 0.5,
    'w_l1': 0.05,
    'p_Y_valid': 'Y_valid.npy',
    'p_Y_train': 'Y_train.npy',
    # linear-probe settings
    'linprob': {
        'optimizer': 'adam',
        'optimizer_args': {'lr': 0.001, 'weight_decay': 1e-6},
        'n_hid': 0,
        'd_hid': 1024,
        'num_epochs': 1,
    },
}

# Timestamp that makes every run's experiment directory unique.
ts = time.time()
st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d_%H-%M-%S')

# Experiments are grouped by dataset name below a common root.
p_base = Path("/mnt/experiments/csprites") / Path(config["p_data"]).name
experiment_name = "BTwins_bb_[{}]_target_[{}]_{}".format(
    config["backbone"], config["target_variable"], st)
config["p_experiment"] = str(p_base / experiment_name)

# lambda is derived from half the projector output width (the size of each
# of the two sub-embeddings the model splits the projection into).
half_projection_dim = config["projector"][-1] // 2
config['lambd'] = calc_lambda(half_projection_dim)

config = DottedDict(config)
pprint.pprint(config)

In [None]:
# TORCH SETTINGS
# Restrict the visible GPUs, enable the cuDNN autotuner, and resolve the
# device object used by the rest of the notebook.
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_visible_devices
torch.backends.cudnn.benchmark = True
device = torch.device(config.device)

# Dataset

In [None]:
# Load the dataset's own config (pickled next to the data) and look up
# the index and class count of the variable we want to predict.
p_ds_config = Path(config.p_data) / "config.pkl"
with open(p_ds_config, "rb") as file:
    ds_config = pickle.load(file)

target_variable = config.target_variable
matching_indices = [
    idx
    for idx, target in enumerate(ds_config["classes"])
    if target == target_variable
]
target_idx = matching_indices[0]  # first (expected: only) match
n_classes = ds_config["n_classes"][target_variable]

In [None]:
# (De)normalisation built from the dataset statistics stored in ds_config.
norm_transform = utils.normalize_transform(ds_config["means"], ds_config["stds"])
inverse_norm_transform = utils.inverse_normalize_transform(ds_config["means"], ds_config["stds"])

# Pick the configured target out of each sample's label vector.
target_transform = lambda labels: labels[target_idx]

# Style (appearance) augmentations.
stl_transform = transforms.Compose([
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    GaussianBlur(p=0.5),
    Solarization(p=0.2),
])

# Final conversion applied after augmentation: tensor + normalisation.
fin_transform = transforms.Compose([transforms.ToTensor(), norm_transform])

# Training transform handed to the training dataset; the training loader
# is unpacked into four views per sample further down.
train_transform = CSpritesTransform(
    img_size=ds_config["img_size"],
    scale=(0.6, 1.0),
    ratio=(1, 1),
    p_hflip=0.5,
    p_vflip=0.5,
    stl_transform=stl_transform,
    fin_transform=fin_transform,
)

# Deterministic transform used for the linear probe / feature extraction.
transform_linprob = transforms.Compose([
    transforms.Resize(ds_config["img_size"]),
    transforms.ToTensor(),
    norm_transform,
])

In [None]:
# Loader options shared by the training and linear-probe loaders.
_loader_common = dict(
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=False,
)

# TRAIN: augmented multi-view samples from the "train" split.
ds_train = ClassificationDataset(
    p_data=config.p_data,
    split="train",
    transform=train_transform,
    target_transform=target_transform,
)
dl_train = DataLoader(ds_train, drop_last=True, **_loader_common)

# LINPROB: un-augmented samples from the "valid" split.
ds_linprob = ClassificationDataset(
    p_data=config.p_data,
    split="valid",
    transform=transform_linprob,
    target_transform=target_transform,
)
dl_linprob = DataLoader(ds_linprob, **_loader_common)

In [None]:
# Display the config attribute of the training dataset for inspection.
ds_train.config

# Visualize Data

In [None]:
n_vis = 4 # number of samples shown per grid below (config.num_vis is not used here)

In [None]:
# dl_train: fetch one batch; each sample yields four augmented views and a
# target (the target is discarded here).
(x11, x12, x21, x22),_ = next(iter(dl_train))

In [None]:
# Keep the first n_vis samples of every view and undo the normalisation so
# the images can be displayed.
# NOTE: this overwrites x11..x22, so the cell is not idempotent -- re-fetch
# the batch before running it a second time.
x11, x12, x21, x22 = (
    inverse_norm_transform(view[:n_vis]) for view in (x11, x12, x21, x22)
)

In [None]:
# View x11 as a roughly square grid (channels moved last for imshow).
n_cols = int(np.sqrt(n_vis))
grid_img = torchvision.utils.make_grid(x11, nrow=n_cols)
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()

In [None]:
# View x12 as a roughly square grid (channels moved last for imshow).
n_cols = int(np.sqrt(n_vis))
grid_img = torchvision.utils.make_grid(x12, nrow=n_cols)
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()

In [None]:
# View x21 as a roughly square grid (channels moved last for imshow).
n_cols = int(np.sqrt(n_vis))
grid_img = torchvision.utils.make_grid(x21, nrow=n_cols)
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()

In [None]:
# View x22 as a roughly square grid (channels moved last for imshow).
n_cols = int(np.sqrt(n_vis))
grid_img = torchvision.utils.make_grid(x22, nrow=n_cols)
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()

In [None]:
# dl_linprob: show the first n_vis images of one batch together with the
# human-readable names of their targets.
x, y = next(iter(dl_linprob))
x = inverse_norm_transform(x[:n_vis])  # undo normalisation for display
y = y[:n_vis]
#
grid_img = torchvision.utils.make_grid(x, nrow=int(np.sqrt(n_vis)))
plt.imshow(grid_img.permute(1, 2, 0))
plt.show()
# The loader's targets are indices of `target_variable` (see target_transform),
# so the class names must come from that variable's map -- not from a
# hard-coded "shape" map, which is only right for the default config.
y = [ds_config["class_maps"][target_variable][idx.item()] for idx in y]
print(y)

# Model

In [None]:
class CspritesBarlowTwinsL1(nn.Module):
    '''
    Barlow Twins variant with a split embedding and an L1 penalty.

    Adapted from https://github.com/facebookresearch/barlowtwins for arbitrary backbones.
    The projector output is split into two halves: the first `d_stl` dimensions
    are trained with a Barlow Twins loss on view pairs listed as "same stl" in
    `forward`, the remaining `d_geo` dimensions on the pairs listed as "same geo".
    In addition, the mean L1 norm of the backbone representations of all four
    views is added to the loss, weighted by `w_l1`.
    '''

    def __init__(self, backbone, projection_sizes, lambd, w_stl=0.5, w_geo=0.5, scale_factor=1, w_l1=0.1):
        '''
        :param backbone: Model backbone; must expose `dim_out`, the width of its output
        :param projection_sizes: sizes of the layers in the projection MLP; the last entry
                                 must be even because it is split into two halves
        :param lambd: weight of the off-diagonal term of the Barlow Twins loss
        :param w_stl: weight applied to each "stl" Barlow Twins loss
        :param w_geo: weight applied to each "geo" Barlow Twins loss
        :param scale_factor: Factor to scale each Barlow Twins loss by, default is 1
        :param w_l1: weight of the L1 penalty on the backbone representations
        '''
        super().__init__()
        self.backbone = backbone
        self.lambd = lambd
        self.w_l1 = w_l1
        self.w_stl = w_stl
        self.w_geo = w_geo
        self.scale_factor = scale_factor
        # projector: Linear -> BatchNorm -> ReLU for every hidden layer,
        # followed by a final bias-free Linear layer
        sizes = [backbone.dim_out] + list(projection_sizes)
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # the projection is split into a "stl" half and a "geo" half
        assert projection_sizes[-1] % 2 == 0, "last projection size must be even"

        self.d_stl = projection_sizes[-1] // 2
        self.d_geo = projection_sizes[-1] - self.d_stl
        # non-affine normalization layers applied to each half before the
        # cross-correlation is computed
        self.bn_stl = nn.BatchNorm1d(self.d_stl, affine=False)
        self.bn_geo = nn.BatchNorm1d(self.d_geo, affine=False)

    def get_representation(self, x):
        '''Return the backbone output for `x` (no projector applied).'''
        return self.backbone(x)

    def _barlow_loss(self, bn, z1, z2):
        '''
        Barlow Twins loss between two batches of embeddings.

        :param bn: normalization layer applied to both inputs
        :param z1: embeddings of the first view, one row per sample
        :param z2: embeddings of the second view, same shape as `z1`
        :return: scalar loss, already multiplied by `scale_factor`
        '''
        # empirical cross-correlation matrix, averaged over the batch
        c = bn(z1).T @ bn(z2)
        c.div_(z1.shape[0])

        # diagonal entries are pushed towards 1, all others towards 0;
        # the in-place ops modify `c`, which is not reused afterwards
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        return self.scale_factor * (on_diag + self.lambd * off_diag)

    def barlow_stl_loss(self, z1, z2):
        '''Barlow Twins loss on the "stl" half of the projection.'''
        return self._barlow_loss(self.bn_stl, z1, z2)

    def barlow_geo_loss(self, z1, z2):
        '''Barlow Twins loss on the "geo" half of the projection.'''
        return self._barlow_loss(self.bn_geo, z1, z2)

    def forward(self, y11, y12, y21, y22):
        """
        same geo: (y11, y12)(y21, y22)
        same stl: (y11, y21), (y12, y22)

        :return: tuple (loss, barlow_loss, geo_1112_loss, geo_2122_loss,
                 stl_1121_loss, stl_1222_loss, l1_loss). The geo/stl entries
                 are already multiplied by `w_geo` / `w_stl`; `barlow_loss` is
                 a quarter of their sum; `l1_loss` is unweighted and
                 `loss = barlow_loss + w_l1 * l1_loss`.
        """
        r11 = self.backbone(y11)
        r12 = self.backbone(y12)
        r21 = self.backbone(y21)
        r22 = self.backbone(y22)
        # L1 penalty over the representations of ALL four views
        # (previously r21 was included twice and r22 not at all)
        l1_loss = torch.abs(torch.cat([r11, r12, r21, r22], dim=0)).sum(dim=1).mean()
        #
        z11 = self.projector(r11)
        z12 = self.projector(r12)
        z21 = self.projector(r21)
        z22 = self.projector(r22)
        # first d_stl columns -> stl embedding, remaining columns -> geo embedding
        z11_stl, z11_geo = z11[:, :self.d_stl], z11[:, self.d_stl:]
        z12_stl, z12_geo = z12[:, :self.d_stl], z12[:, self.d_stl:]
        z21_stl, z21_geo = z21[:, :self.d_stl], z21[:, self.d_stl:]
        z22_stl, z22_geo = z22[:, :self.d_stl], z22[:, self.d_stl:]
        #
        # GEO LOSS
        geo_1112_loss = self.barlow_geo_loss(z11_geo, z12_geo) * self.w_geo
        geo_2122_loss = self.barlow_geo_loss(z21_geo, z22_geo) * self.w_geo

        # STL LOSS
        stl_1121_loss = self.barlow_stl_loss(z11_stl, z21_stl) * self.w_stl
        stl_1222_loss = self.barlow_stl_loss(z12_stl, z22_stl) * self.w_stl

        barlow_loss = 0.25 * (geo_1112_loss + geo_2122_loss +
                              stl_1121_loss + stl_1222_loss)
        loss = barlow_loss + self.w_l1 * l1_loss
        return loss, barlow_loss, geo_1112_loss, geo_2122_loss, stl_1121_loss, stl_1222_loss, l1_loss


# Test Pipeline

In [None]:
# Build a throw-away model instance for the pipeline smoke test below.
test_backbone = get_backbone(config.backbone, pretrained=False, zero_init_residual=True)
model = CspritesBarlowTwinsL1(
    test_backbone,
    config.projector,
    config.lambd,
    w_stl=config.w_stl,
    w_geo=config.w_geo,
    scale_factor=config.scale_factor,
    w_l1=config.w_l1,
)

In [None]:
# Fetch one fresh training batch (four views plus targets) for the smoke test.
(x11, x12, x21, x22), y = next(iter(dl_train))

In [None]:
# Smoke test: a single forward pass. Neither the model nor the batch has been
# moved to `device` in this section.
loss, barlow_loss, geo_1112_loss, geo_2122_loss, stl_1121_loss, stl_1222_loss, l1_loss = model(x11, x12, x21, x22)

# Prepare train run

In [None]:
# Build the model that is actually trained.
model = CspritesBarlowTwinsL1(get_backbone(config.backbone, pretrained=False, zero_init_residual=True),
                    config.projector,
                    config.lambd,
                    w_stl = config.w_stl,
                    w_geo = config.w_geo,
                    scale_factor = config.scale_factor,
                    w_l1 = config.w_l1
)

# Compare the device *type*: a torch.device compared directly with the string
# "cpu" does not compare by type, so the CPU branch below was unreachable.
# torch.device(...) accepts both a torch.device and a plain string.
use_gpu = torch.device(device).type != "cpu"

if use_gpu and torch.cuda.device_count() > 1:
    print("Using {} gpus!".format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)
    # expose the backbone on the wrapper so `model.backbone` works either way
    model.backbone = model.module.backbone
elif use_gpu:
    print("Using 1 GPU!")
else:
    print("Using CPU!")
model = model.to(device)

In [None]:
# Display the module structure.
model

In [None]:
# Build the optimizer named in the config over all model parameters.
optimizer = get_optimizer(config.optimizer, model.parameters(), config.optimizer_args)

In [None]:
# Per-epoch bookkeeping: mean training loss, plus probe accuracies for the
# epochs in which the probe is run.
stats = DottedDict({
    'train': {'loss': [], 'epoch': []},
    'linprob': {'linacc': [], 'knnacc': [], 'epoch': []},
})

# Create the experiment directory and its checkpoint sub-directory.
p_experiment = Path(config.p_experiment)
p_ckpts = p_experiment / config.p_ckpts
p_experiment.mkdir(exist_ok=True, parents=True)
p_ckpts.mkdir(exist_ok=True)

In [None]:
# Display where this run's artefacts are written.
config.p_experiment

In [None]:
# Main training loop: one pass over dl_train per epoch, a periodic
# linear-probe / kNN evaluation on backbone features, and periodic checkpoints.
global_step = 0
for epoch_idx in range(1, config.num_epochs + 1, 1):
    ################
    # TRAIN
    ################
    model.train()
    epoch_step = 0
    epoch_loss = 0

    desc = "Ep[{:3}/{:3}]{}".format(epoch_idx, config.num_epochs, 'train')
    pbar = tqdm(dl_train, bar_format= desc + '{bar:10}{r_bar}{bar:-10b}')
    #
    for (x11, x12, x21, x22), _ in pbar:
        x11 = x11.to(device)
        x12 = x12.to(device)
        x21 = x21.to(device)
        x22 = x22.to(device)
        # reset gradients (setting them to None instead of zero-filling)
        for param in model.parameters():
            param.grad = None
        # .mean() is a no-op for scalar losses; under DataParallel each loss
        # comes back as one value per replica and has to be reduced before
        # backward() / item() can be called.
        loss, barlow_loss, geo_1112_loss, geo_2122_loss, stl_1121_loss, stl_1222_loss, l1_loss = (
            term.mean() for term in model(x11, x12, x21, x22)
        )
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_step += 1
        global_step += 1
        #
        pbar.set_postfix({'l': loss.item(),
                          'b': barlow_loss.item(),
                          "g1": geo_1112_loss.item(),
                          "g2": geo_2122_loss.item(),
                          "s1": stl_1121_loss.item(),
                          "s2": stl_1222_loss.item(),
                          "l1": l1_loss.item()
                         })

    # guard against an empty loader (e.g. dataset smaller than one batch)
    stats.train.loss.append(epoch_loss / max(epoch_step, 1))
    stats.train.epoch.append(epoch_idx)

    ################
    # Linprob
    ################
    if epoch_idx % config.freqs.linprob == 0 or epoch_idx == config.num_epochs:
        model.eval()
        R = []
        Y = []
        with torch.no_grad():
            for x, y in dl_linprob:
                x = x.to(device)
                r = model.backbone(x)
                R.append(r.detach().cpu().numpy())
                Y.append(y.cpu().numpy())
        R = np.concatenate(R)
        Y = np.concatenate(Y)
        # Fit on 80% of the features and score on the held-out 20%. Scoring on
        # the very samples the classifiers were fitted on (as before) reports
        # training accuracy, which overstates representation quality.
        # The fixed seed gives every evaluation the same split sizes.
        perm = np.random.RandomState(0).permutation(len(R))
        n_fit = int(0.8 * len(R))
        fit_idx, test_idx = perm[:n_fit], perm[n_fit:]
        #
        knn = KNeighborsClassifier(n_neighbors=5)
        knn.fit(R[fit_idx], Y[fit_idx])
        knnacc = knn.score(R[test_idx], Y[test_idx])
        #
        clf = LogisticRegression(random_state=0, tol=0.001, max_iter=200).fit(R[fit_idx], Y[fit_idx])
        linacc = clf.score(R[test_idx], Y[test_idx])
        print("    Linprob Eval @LR: {:.2f} @KNN: {:.2f}".format(linacc, knnacc))
        stats.linprob.epoch.append(epoch_idx)
        stats.linprob.knnacc.append(knnacc)
        stats.linprob.linacc.append(linacc)
        model.train()
    # Checkpoint
    if epoch_idx % config.freqs.ckpt == 0 or epoch_idx == config.num_epochs:
        print("save model!")
        # always store the bare module's weights, whether or not the model is
        # wrapped in DataParallel
        bare_model = model.module if isinstance(model, torch.nn.DataParallel) else model
        torch.save(bare_model.state_dict(), p_ckpts / config.p_model.format(epoch_idx))

### 4L
- EP05:  50, 53
- EP10:  58, 56
- EP15:  64, 60
- EP20:  70, 62
- EP30:  80, 67
- EP40:  86, 71
- EP50:  91, 76
- EP70:  96, 80
### 1L
- EP05:  56, 56
- EP10:  59, 60
- EP15:  63, 62
- EP20:  66, 64
- EP30:  71, 69
- EP40:  76, 74
- EP50:  78, 77
- EP70:  82, 82
- EP80:  82, 83
- EP100: 83, 86
- EP120: 84, 87
- EP150: 

# Plot 

In [None]:
# Training loss per epoch.
fig_loss, ax_loss = plt.subplots()
ax_loss.plot(stats.train.epoch, stats.train.loss, label="train")
ax_loss.legend()
fig_loss.savefig(p_experiment / "barlow_loss.png")
plt.show()

# (No linear-probe loss plot: `stats.linprob` only records accuracies.)

# Linear-probe and kNN accuracy per evaluated epoch, on a log-scaled y axis.
fig_acc, ax_acc = plt.subplots()
ax_acc.plot(stats.linprob.epoch, stats.linprob.knnacc, label="knn")
ax_acc.plot(stats.linprob.epoch, stats.linprob.linacc, label="lin")
ax_acc.set_yscale("log")
ax_acc.legend()
fig_acc.savefig(p_experiment / "linprob_acc.png")
plt.show()

# Save stats and config

In [None]:
# Pickle the run configuration and the collected statistics into the
# experiment directory.
for filename, payload in ((config.p_config, config), (config.p_stats, stats)):
    with open(p_experiment / filename, "wb") as file:
        pickle.dump(payload, file)

# Get Representations 

In [None]:
# Output files for the exported representations (R) and labels (Y).
p_R_train, p_Y_train = p_experiment / config.p_R_train, p_experiment / config.p_Y_train
p_R_valid, p_Y_valid = p_experiment / config.p_R_valid, p_experiment / config.p_Y_valid

In [None]:
# Loaders for exporting representations: deterministic transform, original
# order, and the untransformed targets (target_transform=None).
# NOTE: this rebinds ds_train / dl_train, replacing the augmented training
# loader defined earlier in the notebook.
# TRAIN
ds_train = ClassificationDataset(
    p_data = config.p_data,
    transform=transform_linprob,
    target_transform=None,
    split="train"
)
dl_train = DataLoader(
    ds_train,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=False,
    # keep the last partial batch: with drop_last=True the exported
    # R_train / Y_train silently missed the tail of the training split
    drop_last=False
)
# VALID
ds_valid = ClassificationDataset(
    p_data = config.p_data,
    transform=transform_linprob,
    target_transform=None,
    split="valid"
)
dl_valid = DataLoader(
    ds_valid,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers = config.num_workers,
    pin_memory=False
)

In [None]:
# Run the backbone (eval mode, no gradients) over both loaders and collect
# the representations and labels as numpy arrays.
R_train, Y_train = [], []
R_valid, Y_valid = [], []
#
model.eval()
for loader, reps, labels in ((dl_train, R_train, Y_train),
                             (dl_valid, R_valid, Y_valid)):
    for x, y in tqdm(loader):
        x = x.to(device)
        with torch.no_grad():
            r = model.backbone(x).detach().cpu().numpy()
        reps.append(r)
        labels.append(y.numpy())

# Stack the per-batch arrays along the sample axis.
R_train, R_valid = np.concatenate(R_train), np.concatenate(R_valid)
Y_train, Y_valid = np.concatenate(Y_train), np.concatenate(Y_valid)

In [None]:
# Write representations and labels to the experiment directory.
for target_path, array in (
    (p_R_train, R_train),
    (p_Y_train, Y_train),
    (p_R_valid, R_valid),
    (p_Y_valid, Y_valid),
):
    np.save(target_path, array)

In [None]:
# Display the path the training representations were saved to.
p_R_train

In [None]:
# Sanity check: shape of the exported training labels.
Y_train.shape

In [None]:
# Sanity check: shape of the exported validation labels.
Y_valid.shape