In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("..")

In [None]:
# Python
from pathlib import Path
import os
import warnings
import math
import datetime
import time
warnings.filterwarnings('ignore')

# TORCH
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.models.resnet import resnet18

# MISC
from tqdm import tqdm
import pprint
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from dotted_dict import DottedDict
import pickle

from BTwins.utils import calc_lambda
from BTwins.barlow import *
from BTwins.transform_utils import *
from csprites.datasets import ClassificationDataset
import utils
from backbone import get_backbone
from optimizer import get_optimizer

In [None]:
# Experiment configuration: every value a reader might want to tune lives here.
config = {
    'device': 'cuda',
    'cuda_visible_devices': '1',
    'p_data': '/mnt/data/csprites/single_csprites_64x64_n7_c32_a16_p38_s1_bg_1_constant_color_70000',
    'target_variable': 'shape',
    'batch_size': 512,
    'num_workers': 20,
    'num_epochs': 30,
    'freqs': {
        'ckpt': 50,         # epochs
        'linprob': 5,       # epochs
    },
    'num_vis': 64,
    'backbone': "FCN8i223o32",
    # Defined exactly once: a repeated dict key silently overrides the earlier
    # one, which makes later edits to the "wrong" copy have no effect.
    'backbone_args': {
        'ch_last': 32,
        'dim_in': 3,
    },
    'dim_out': 32,
    'optimizer': 'adam',
    'optimizer_args': {
        'lr': 0.001,
        'weight_decay': 1e-6
    },
    'projector': [64, 64, 64],
    # Constant factor applied to the Barlow Twins loss; 1.0 means "no rescaling".
    # NOTE(review): the model cell reads config.scale_factor, which was missing
    # here - confirm that 1.0 is the intended value.
    'scale_factor': 1.0,
    'w_on': None,
    'w_off': None,
    'w_stl': None,
    'w_cnt': None,
    'w_geo': None,
    'p_ckpts': "ckpts",
    'p_model': "model_{}.ckpt",
    'p_stats': "stats.pkl",
    'p_config': 'config.pkl',
    'p_R_train': 'R_train.npy',
    'p_R_valid': 'R_valid.npy',
    'p_Y_valid': 'Y_valid.npy',
    'p_Y_train': 'Y_train.npy',
}
p_base = Path("/mnt/experiments/csprites") / Path(config["p_data"]).name / "tmp"
# Timestamp that makes every run's experiment directory unique.
ts = time.time()
st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d_%H-%M-%S')
# One placeholder per argument. The previous template had three placeholders for
# four arguments, so str.format put dim_out into the target slot and silently
# dropped the timestamp (all runs then shared one directory).
experiment_name = "BTwins_bb_[{}_d{}]_target_[{}]_{}".format(
    config["backbone"],
    config["dim_out"],
    config["target_variable"],
    st,
)
config["p_experiment"] = str(p_base / experiment_name)
# The off-diagonal loss weight is derived from the projector's output width.
config['lambd'] = calc_lambda(config["projector"][-1])
config = DottedDict(config)
pprint.pprint(config)

In [None]:
# Runtime setup: restrict the visible CUDA devices to the configured ones,
# enable cuDNN's benchmark mode, and resolve the device used by later cells.
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_visible_devices
torch.backends.cudnn.benchmark = True
device = torch.device(config.device)

# Data

In [None]:
# Load the dataset's own config and resolve the target that should be predicted.
p_ds_config = Path(config.p_data, "config.pkl")

# NOTE(review): pickle.load can execute arbitrary code stored in the file -
# only point `p_data` at datasets from a trusted source.
with open(p_ds_config, "rb") as file:
    ds_config = pickle.load(file)

target_variable = config.target_variable
# Position of the configured target inside the dataset's list of target names
# (first match wins; an unknown name raises IndexError).
matching_idcs = [
    idx
    for idx, target in enumerate(ds_config["classes"])
    if target == target_variable
]
target_idx = matching_idcs[0]
n_classes = ds_config["n_classes"][target_variable]

In [None]:
# Normalisation with the dataset statistics, plus its inverse for plotting.
norm_transform = utils.normalize_transform(ds_config["means"], ds_config["stds"])
inverse_norm_transform = utils.inverse_normalize_transform(ds_config["means"], ds_config["stds"])


def target_transform(x):
    """Return the entry of `x` at the module-level `target_idx` (looked up at call time)."""
    return x[target_idx]


def init_transform(x):
    """Identity transform: hand the input back unchanged."""
    return x


# Photometric ("style") augmentations.
stl_transform = transforms.Compose([
    transforms.RandomApply(
        [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    GaussianBlur(p=0.5),
    Solarization(p=0.2),
])

# Geometric augmentations: square random-resized crop followed by random flips.
# NOTE(review): the upper bound of `scale` is 1.9, i.e. above 1 - confirm this is
# intended and not a typo for 0.9 / 1.0.
geo_transform = transforms.Compose([
    transforms.RandomResizedCrop(
        ds_config["img_size"],
        scale=(0.6, 1.9),
        ratio=(1, 1),
        interpolation=Image.BICUBIC,
    ),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
])

# Final step shared by the augmented views: to tensor, then normalise.
fin_transform = transforms.Compose([transforms.ToTensor(), norm_transform])

# Training transform assembled from the four stages defined above.
train_transform = CSpritesTripleTransform(
    init_transform=init_transform,
    stl_transform=stl_transform,
    geo_transform=geo_transform,
    fin_transform=fin_transform,
)

# Augmentation-free transform used for linear probing / representation export.
transform_linprob = transforms.Compose([
    transforms.Resize(ds_config["img_size"]),
    transforms.ToTensor(),
    norm_transform,
])

In [None]:
# TRAIN
# Keyword arguments shared by every dataset / loader built in this cell.
dataset_kwargs = dict(p_data=config.p_data, target_transform=target_transform)
loader_kwargs = dict(
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=False,
)

# TRAIN: augmented views; the incomplete final batch is dropped.
ds_train = ClassificationDataset(transform=train_transform, split="train", **dataset_kwargs)
dl_train = DataLoader(ds_train, drop_last=True, **loader_kwargs)

# LINPROB: un-augmented images of the train and valid splits.
ds_linprob_train = ClassificationDataset(transform=transform_linprob, split="train", **dataset_kwargs)
dl_linprob_train = DataLoader(ds_linprob_train, **loader_kwargs)
ds_linprob_valid = ClassificationDataset(transform=transform_linprob, split="valid", **dataset_kwargs)
dl_linprob_valid = DataLoader(ds_linprob_valid, **loader_kwargs)

# Number of batches per loader.
for loader in (dl_train, dl_linprob_train, dl_linprob_valid):
    print(len(loader))

# Visualize Data

In [None]:
# Number of images taken from a batch for the preview grids below.
n_vis = 36

In [None]:
# dl_train: take one batch (three views per image; targets are ignored) and
# undo the normalisation on the first `n_vis` images of every view.
views, _ = next(iter(dl_train))
x_ori, x_stl, x_geo = (inverse_norm_transform(view[:n_vis]) for view in views)

In [None]:
# One image grid per view, shown in the order: geometric, original, style.
grid_cols = int(np.sqrt(n_vis))
for batch in (x_geo, x_ori, x_stl):
    grid_img = torchvision.utils.make_grid(batch, nrow=grid_cols)
    # make_grid returns channels-first; imshow wants channels-last.
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()

In [None]:
# dl_linprob_train: preview one batch together with its decoded labels.
x, y = next(iter(dl_linprob_train))
x = inverse_norm_transform(x[:n_vis])
y = y[:n_vis]
#
grid_img = torchvision.utils.make_grid(x, nrow=int(np.sqrt(n_vis)))
# make_grid returns channels-first; imshow wants channels-last.
plt.imshow(grid_img.permute(1, 2, 0))
# Decode with the class map of the *configured* target. The map used to be
# hard-coded to "shape", which mislabels the batch for any other target.
class_map = ds_config["class_maps"][target_variable]
y = [class_map[idx.item()] for idx in y]
print(y)

# Model

In [None]:
class BarlowTwins(nn.Module):
    """Backbone plus projector trained with the Barlow Twins loss.

    The loss pushes the cross-correlation matrix between the (batch-normalised)
    projections of two views towards the identity:
    ``scale_factor * (on_diag + lambd * off_diag)``.

    Args:
        backbone: module producing the representation; its output is fed
            straight into the projector, so it is expected to be
            ``(batch, features)``.
        barlow_projector: either a ready-made ``nn.Module`` or a sequence of
            layer widths (e.g. ``[64, 64, 64]``) from which an MLP of
            ``Linear -> BatchNorm1d -> ReLU`` blocks with a final ``Linear`` is built.
        lambd: weight of the off-diagonal (redundancy) term.
        scale_factor: constant the total loss is multiplied with.
        dim_in: output width of ``backbone``. Only used when the projector is
            built from widths; if ``None`` the first layer is an
            ``nn.LazyLinear`` that infers the width on the first forward pass.
            Pass it explicitly when the model is wrapped (e.g. DataParallel)
            before its first forward pass.

    Raises:
        ValueError: if no layer widths are given, or if a ready-made projector
            contains no ``nn.Linear`` to read the output width from.
    """

    def __init__(self, backbone, barlow_projector, lambd=0.0051, scale_factor=1.0, dim_in=None):
        super().__init__()
        self.backbone = backbone
        self.lambd = lambd
        self.scale_factor = scale_factor

        if isinstance(barlow_projector, nn.Module):
            self.projector = barlow_projector
            linears = [m for m in barlow_projector.modules() if isinstance(m, nn.Linear)]
            if not linears:
                raise ValueError("cannot infer the projector's output width: it has no nn.Linear layer")
            # Assumes the last registered linear layer produces the projector output.
            dim_z = linears[-1].out_features
        else:
            sizes = [int(size) for size in barlow_projector]
            if not sizes:
                raise ValueError("barlow_projector needs at least one layer width")
            self.projector = self._build_projector(sizes, dim_in)
            dim_z = sizes[-1]

        # normalization layer for the projections z1 and z2
        self.bn = nn.BatchNorm1d(dim_z, affine=False)

    @property
    def barlow_projector(self):
        """Alias of ``projector`` (the attribute name the constructor used before)."""
        return self.projector

    @staticmethod
    def _build_projector(sizes, dim_in):
        """Build the projector MLP with output widths ``sizes`` on top of ``dim_in`` features."""
        def linear(n_in, n_out):
            if n_in is None:
                return nn.LazyLinear(n_out, bias=False)
            return nn.Linear(n_in, n_out, bias=False)

        layers = []
        n_in = dim_in
        for n_out in sizes[:-1]:
            layers += [linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ReLU(inplace=True)]
            n_in = n_out
        layers.append(linear(n_in, sizes[-1]))
        return nn.Sequential(*layers)

    @staticmethod
    def _off_diagonal(c):
        """Return the off-diagonal elements of the square matrix ``c`` as a flat tensor."""
        n, m = c.shape
        assert n == m, "cross-correlation matrix must be square"
        # Dropping the last element and viewing as (n-1, n+1) lines all diagonal
        # entries up in column 0, which is then cut away.
        return c.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    def get_representation(self, x):
        """Return the backbone output for ``x`` (the projector is not applied)."""
        return self.backbone(x)

    def forward(self, y1, y2, return_all=False):
        """Compute the Barlow Twins loss for two views of the same batch.

        Returns:
            The scalar loss, or - if ``return_all`` is true - the tuple
            ``(loss, on_diag, lambd * off_diag)`` (the two terms are not
            multiplied by ``scale_factor``).
        """
        z1 = self.projector(self.backbone(y1))
        z2 = self.projector(self.backbone(y2))

        # empirical cross-correlation matrix, averaged over the batch
        c = self.bn(z1).T @ self.bn(z2)
        c = c / z1.shape[0]

        # diagonal entries are pulled towards 1, all others towards 0
        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        off_diag = self._off_diagonal(c).pow(2).sum()
        loss = self.scale_factor * (on_diag + self.lambd * off_diag)
        if return_all:
            return loss, on_diag, off_diag * self.lambd
        return loss

In [None]:
# Build the model from the config. `scale_factor` falls back to 1.0 (no loss
# rescaling) when the config does not define it.
model = BarlowTwins(get_backbone(config.backbone, **config.backbone_args),
                    config.projector,
                    config.lambd,
                    config.get("scale_factor", 1.0))
# `device` is a torch.device, not a str, so `device != "cpu"` does not test
# what it appears to; compare the device *type* instead.
use_gpu = device.type != "cpu"
n_gpus = torch.cuda.device_count()
if n_gpus > 1 and use_gpu:
    print("Using {} gpus!".format(n_gpus))
    model = torch.nn.DataParallel(model)
    # Keep `model.backbone` reachable on the DataParallel wrapper.
    model.backbone = model.module.backbone
elif use_gpu:
    print("Using 1 GPU!")
else:
    print("Using CPU!")
model = model.to(device)

In [None]:
# Show the model's repr as the cell output.
model

In [None]:
# Build the optimizer named in the config over all model parameters.
optimizer = get_optimizer(config.optimizer, model.parameters(), config.optimizer_args)

In [None]:
# Per-epoch metric history, filled in by the training loop.
stats = DottedDict({
    'train': {'loss': [], 'epoch': []},
    'linprob': {'linacc': [], 'knnacc': [], 'epoch': []},
})

# Create the experiment directory and, inside it, the checkpoint directory.
p_experiment = Path(config.p_experiment)
p_ckpts = p_experiment / config.p_ckpts
p_experiment.mkdir(parents=True, exist_ok=True)
p_ckpts.mkdir(exist_ok=True)

In [None]:
# Show where this run's artefacts are written.
config.p_experiment

In [None]:
# Train the model; periodically run the linear probe and write checkpoints.

# The bare BarlowTwins module: DataParallel keeps the wrapped model in `.module`.
core = model.module if isinstance(model, torch.nn.DataParallel) else model

# NOTE(review): this cell used to pass an undefined name `dl_linprob` to
# utils.linprob_model. It is bound to the validation loader here so the cell
# runs; confirm which loader(s) linprob_model actually expects.
dl_linprob = dl_linprob_valid

global_step = 0
for epoch_idx in range(1, config.num_epochs + 1):
    ################
    # TRAIN
    ################
    model.train()
    epoch_step = 0
    epoch_loss = 0.0

    desc = "Epoch [{:3}/{:3}] {}:".format(epoch_idx, config.num_epochs, 'train')
    pbar = tqdm(dl_train, bar_format=desc + '{bar:10}{r_bar}{bar:-10b}')
    #
    for views, _ in pbar:
        # The training transform yields three views per image (see the preview
        # cell), so unpacking into exactly two names fails. Use the last two
        # entries; for a two-view transform this is the original (x1, x2).
        # NOTE(review): assumes the last two views are the augmented pair to contrast.
        x1 = views[-2].to(device)
        x2 = views[-1].to(device)

        optimizer.zero_grad(set_to_none=True)
        # Call the module (not .forward) so hooks run. `.mean()` is a no-op for
        # scalars and reduces the per-GPU values DataParallel gathers.
        loss, on_diag, off_diag = (
            term.mean() for term in model(x1, x2, return_all=True)
        )
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_step += 1
        global_step += 1
        #
        pbar.set_postfix({'loss': loss.item(), "on_diag": on_diag.item(), "off_diag": off_diag.item()})

    # max(..., 1) guards against an empty loader (e.g. fewer samples than one
    # batch with drop_last=True).
    stats.train.loss.append(epoch_loss / max(epoch_step, 1))
    stats.train.epoch.append(epoch_idx)

    is_last_epoch = epoch_idx == config.num_epochs

    ################
    # Linprob
    ################
    if epoch_idx % config.freqs.linprob == 0 or is_last_epoch:
        model.eval()
        linacc, knnacc = utils.linprob_model(core.backbone, dl_linprob, device)
        print("    Linprob Eval @LR: {:.2f} @KNN: {:.2f}".format(linacc, knnacc))
        stats.linprob.epoch.append(epoch_idx)
        stats.linprob.knnacc.append(knnacc)
        stats.linprob.linacc.append(linacc)
        model.train()
    # Checkpoint: always store the bare module's weights so the file loads
    # identically with or without DataParallel.
    if epoch_idx % config.freqs.ckpt == 0 or is_last_epoch:
        print("save model!")
        torch.save(core.state_dict(), p_ckpts / config.p_model.format(epoch_idx))

# Plot 

In [None]:
# Barlow Twins training loss per epoch.
plt.plot(stats.train.epoch, stats.train.loss, label="train")
plt.legend()
plt.savefig(p_experiment / "barlow_loss.png")
plt.show()

# Linear-probe accuracies per evaluated epoch. (`stats` records no linprob
# loss, so there is no linprob-loss plot.)
for key, label in (("knnacc", "knn"), ("linacc", "lin")):
    plt.plot(stats.linprob.epoch, stats.linprob[key], label=label)
plt.yticks([0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 1])
plt.legend()
plt.savefig(p_experiment / "linprob_acc.png")
plt.show()

# Save stats and config

In [None]:
# Pickle the config and the metric history into the experiment directory.
for filename, payload in ((config.p_config, config), (config.p_stats, stats)):
    with open(p_experiment / filename, "wb") as file:
        pickle.dump(payload, file)

# Get Representations 

In [None]:
# Output files for the exported representations (R) and targets (Y).
p_R_train, p_Y_train, p_R_valid, p_Y_valid = (
    p_experiment / config[key]
    for key in ("p_R_train", "p_Y_train", "p_R_valid", "p_Y_valid")
)

In [None]:
# TRAIN
ds_train = ClassificationDataset(
    p_data = config.p_data,
    transform=transform_linprob,
    target_transform=None,
    split="train"
)
dl_train = DataLoader(
    ds_train,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=False,
    drop_last=True
)
# LINPROB
ds_valid = ClassificationDataset(
    p_data = config.p_data,
    transform=transform_linprob,
    target_transform=None,
    split="valid"
)
dl_valid = DataLoader(
    ds_valid,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers = config.num_workers,
    pin_memory=False
)

In [None]:
# Export backbone representations and targets for both splits.
model.eval()
R_train, Y_train = utils.get_representations(model.backbone, dl_train, device)
# The validation arrays must come from the validation loader; they used to be
# computed from `dl_train`, which stored train data under the valid file names.
R_valid, Y_valid = utils.get_representations(model.backbone, dl_valid, device)
#
np.save(p_R_train, R_train)
np.save(p_Y_train, Y_train)
np.save(p_R_valid, R_valid)
np.save(p_Y_valid, Y_valid)

# Get Representations for Evaluation

In [None]:
# EVAL with all Features
# EVAL with all features: shuffled loaders over un-augmented images that keep
# the full target vector (target_transform=None).
eval_dataset_kwargs = dict(
    p_data=config.p_data,
    transform=transform_linprob,
    target_transform=None,
)
eval_loader_kwargs = dict(
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=False,
)

ds_eval_train = ClassificationDataset(split="train", **eval_dataset_kwargs)
dl_eval_train = DataLoader(ds_eval_train, **eval_loader_kwargs)

ds_eval_valid = ClassificationDataset(split="valid", **eval_dataset_kwargs)
dl_eval_valid = DataLoader(ds_eval_valid, **eval_loader_kwargs)

In [None]:
# Representations and targets of the train split (no images requested).
R_train, Y_train = utils.get_representations(
    model.backbone,
    dl_eval_train,
    device,
    imgs=False,
)
print(R_train.shape, Y_train.shape)

In [None]:
# Representations, targets and images of the valid split; the inverse
# normalisation is handed over together with the image request.
R_valid, Y_valid, X_valid = utils.get_representations(
    model.backbone,
    dl_eval_valid,
    device,
    imgs=True,
    inverse_norm_transform=inverse_norm_transform,
)
print(R_valid.shape, Y_valid.shape, X_valid.shape)

# Feature and Sample Mean Distributions

In [None]:
# Bar plots of the validation representations: the mean of every feature
# (axis 0) and the mean of every sample (axis 1).
R = R_valid
for axis, title, filename in (
    (0, "Feature Mean", "feature_dist_valid.png"),
    (1, "Sample Mean", "sample_dist_valid.png"),
):
    means = R.mean(axis=axis)
    plt.bar(range(len(means)), means, width=1)
    plt.title(title)
    plt.savefig(p_experiment / filename)
    plt.show()

# Class Distributions on features

In [None]:
# For a random subset of validation samples, draw one strip per
# (feature, target) pair: x = min-max scaled feature value, colour = target value.

# Never ask for more samples than exist (choice(..., replace=False) would raise).
n_plot = min(100, R_valid.shape[0])
idcs = np.random.choice(R_valid.shape[0], size=n_plot, replace=False)
#
R_plot = R_valid[idcs]
Y_plot = Y_valid[idcs]
#
dim_features = R_plot.shape[1]
num_targets = Y_plot.shape[1]
scale = 4
figsize = (num_targets * scale, dim_features)
# squeeze=False keeps `axes` two-dimensional even for a single row or column,
# so axes[row][col] always works.
fig, axes = plt.subplots(dim_features, num_targets, figsize=figsize, squeeze=False)
xx = np.ones(n_plot)
for row_idx in range(dim_features):
    # Min-max scale the feature to [0, 1]; a constant feature has zero range
    # and is mapped to 0 instead of producing NaNs from 0/0.
    shifted = R_plot[:, row_idx] - R_plot[:, row_idx].min()
    span = shifted.max()
    r = shifted / span if span > 0 else np.zeros_like(shifted)
    for col_idx in range(num_targets):
        ax = axes[row_idx][col_idx]
        y = Y_plot[:, col_idx]
        #
        ax.scatter(r, xx, c=y, cmap="turbo")
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
        ax.set_ylim([0.95, 1.05])
plt.savefig(p_experiment / "class_distribution.png")
plt.show()

# Predict Classes from Features

In [None]:
# Fit one logistic regression per target on the train representations and
# report its accuracy on the valid representations.
results = []
# The loop variable is deliberately NOT called `target_idx`: that module-level
# name is read by `target_transform`, and overwriting it here would silently
# change which target the earlier datasets return.
for col_idx in range(Y_valid.shape[1]):
    target = ds_config["classes"][col_idx]
    if np.unique(Y_train[:, col_idx]).size == 1:
        # Only one class present: nothing to fit. NaN leaves the bar empty
        # (an infinite bar height is not a meaningful "not available").
        print("{:>15}: acc = NA".format(target))
        results.append(np.nan)
        continue
    clf = LogisticRegression(random_state=0).fit(R_train, Y_train[:, col_idx])
    score = clf.score(R_valid, Y_valid[:, col_idx])
    print("{:>15}: acc = {:.2f}".format(target, score))
    results.append(score)

fig, ax = plt.subplots(1, 1)
ax.bar(range(len(results)), results, width=1)
ax.set_ylim([0, 1])
ax.set_xticks(np.arange(len(ds_config["classes"])))
ax.set_xticklabels(ds_config["classes"])
plt.title("Prediction Accurace LR on valid")
plt.savefig(p_experiment / "score_lr.png")
plt.show()

# Visualize Latent Dimensions

In [None]:
# For every latent dimension, collect the indices of the `n_imgs` validation
# samples with the largest value in that dimension (ascending order).
R, X, Y = R_valid, X_valid, Y_valid
#
n_imgs = 50
topic_idcs = np.array([
    np.argsort(R[:, dim_idx])[-n_imgs:]
    for dim_idx in range(R.shape[1])
])

In [None]:
# Assemble a mosaic with one row per latent dimension and one 64x64 tile per
# selected image.
tile = 64
n_rows, n_cols = topic_idcs.shape
h, w = n_rows * tile, n_cols * tile
img = np.zeros((h, w, 3))
print(img.shape)
for row_idx, col_idx in np.ndindex(n_rows, n_cols):
    row_span = slice(row_idx * tile, (row_idx + 1) * tile)
    col_span = slice(col_idx * tile, (col_idx + 1) * tile)
    img[row_span, col_span, :] = X[topic_idcs[row_idx, col_idx]]

In [None]:
# Show the shape of the validation image array as the cell output.
X.shape

In [None]:
# Plot and save the mosaic. `figsize` is (width, height) in inches, whereas
# `topic_idcs.shape` is (rows, cols) - so it has to be swapped to give the
# figure the same aspect ratio as the image (one inch per tile).
n_rows, n_cols = topic_idcs.shape
plt.figure(figsize=(n_cols, n_rows))
plt.imshow(img)
plt.savefig(p_experiment / "feature_dims_sorted.png")