In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append("../..")
sys.path.append("..")

In [None]:
# Python
from PIL import Image
from pathlib import Path
import os
import warnings
import math
import datetime
import time
warnings.filterwarnings('ignore')

# TORCH
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.models.resnet import resnet18

# MISC
from tqdm import tqdm
import pprint
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from dotted_dict import DottedDict
import pickle

from BTwins.utils import calc_lambda
from BTwins.barlow import *
from BTwins.transform_utils import *
from Beta.models import *
from csprites.datasets import ClassificationDataset
import utils
from backbone import get_backbone
from optimizer import get_optimizer
from functions import *
from plot_utils import *

In [None]:
# Single source of truth for this run: device selection, dataset location,
# training schedule, model layout, loss weights and output file names.
# It is wrapped in a DottedDict further down, so later cells read it both as
# config["key"] and config.key.
config = {
    # CUDA SETTINGS
    'device': 'cuda',
    'cuda_visible_devices': '1',     # exported as CUDA_VISIBLE_DEVICES before the device is created
    
    # DATA
    #'p_data': "/mnt/data/csprites/single_csprites_64x64_n7_c8_a8_p12_s1_bg_1_constant_color_64512",
    #'p_data': "/mnt/data/csprites/single_csprites_64x64_n7_c8_a16_p4_s3_bg_inf_random_function_43008",
    'p_data': '/mnt/data/csprites/single_csprites_64x64_n7_c32_a16_p38_s1_bg_1_constant_color_70000',
    'target_variable': 'shape',      # must be one of the names in the dataset's "classes" list
    
    # TRAIN
    'batch_size': 512,
    'num_workers': 24,
    'num_epochs': 20,
    # Each frequency is tested with `epoch_idx % freq == 0`; np.inf therefore
    # never triggers. 'plot_classes' is not read by any cell in this notebook.
    'freqs': {
        'ckpt': 100,                 # epochs
        'linprob': 5,                # epochs
        'plot_features': np.inf,     # epochs
        'plot_classes': np.inf,      # epochs
    },
    # NOTE(review): the preview cells use their own `n_vis`; 'num_vis' looks unused here.
    'num_vis': 64,
    
    # backbone
    'backbone': "FCN8i223o32",
    'backbone_args': {
        'ch_last': 32,
        'dim_in': 3,
    },
    # projectors
    # Layer sizes handed to get_projector; an empty list is passed through as-is.
    'beta_projector': [128,128],
    'barlow_projector': [],
    'optimizer': 'adam',
    'optimizer_args': {
        'lr': 0.001,
        'weight_decay': 1e-6
    },
    # LOSS
    # r_stl / r_geo: fractions of the barlow projector output assigned to the
    # style and geometry slices; the remainder becomes the shared content slice.
    'r_stl': 0.3,
    'r_geo': 0.3,
    'w_beta': 1,
    'w_barlow': 1,
    # Parameters of the target Beta distribution used by the KL term.
    'a_true': 0.1,
    'b_true': 0.9,
    # None means "fill in with calc_lambda(...)" once the slice sizes are known.
    'w_off_stl': None, # SET ADAPTIVE
    'w_off_geo': None,
    'w_on': 1,
    
    # PATH STUFF
    # File names relative to the experiment directory; 'p_model' is a template
    # formatted with the epoch index when a checkpoint is written.
    'p_ckpts': "ckpts",
    'p_model': "model_{}.ckpt",
    'p_stats': "stats.pkl",
    'p_config': 'config.pkl',
    'p_R_train': 'R_train.npy',
    'p_R_valid': 'R_valid.npy',
    'p_Y_valid': 'Y_valid.npy',
    'p_Y_train': 'Y_train.npy',
    'p_R_train_bb': 'R_train_bb.npy',
    'p_R_valid_bb': 'R_valid_bb.npy',
    'p_Y_valid_bb': 'Y_valid_bb.npy',
    'p_Y_train_bb': 'Y_train_bb.npy',
}

# All runs for one dataset live under <experiments>/<dataset name>/tmp.
p_base = Path("/mnt/experiments/csprites") / Path(config["p_data"]).name / "tmp"

# PATHS
# A wall-clock timestamp makes every run directory unique.
ts = time.time()
st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d_%H-%M-%S')
# The template must have one placeholder per argument: str.format silently
# ignores surplus positional arguments, so with only three placeholders the
# timestamp was dropped (every run shared one directory) and `ch_last` landed
# in the "target" slot.
config["p_experiment"] = str(p_base / "Beta_[{}_d{}]_target_[{}]_{}".format(
    config["backbone"],
    config["backbone_args"]["ch_last"],
    config["target_variable"],
    st))
# Attribute-style access (config.key) for the rest of the notebook.
config = DottedDict(config)
pprint.pprint(config)

In [None]:
# TORCH SETTINGS
# Limit the GPUs this process may see, then turn on cuDNN benchmark mode and
# create the device object that every later cell moves tensors/models to.
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_visible_devices
torch.backends.cudnn.benchmark = True
device = torch.device(config.device)

# Data

In [None]:
# Dataset metadata is stored as a pickle next to the data. Unpickling can run
# arbitrary code, so only point `p_data` at directories you trust.
p_ds_config = Path(config.p_data) / "config.pkl"
with p_ds_config.open("rb") as file:
    ds_config = pickle.load(file)

target_variable = config.target_variable
# Position of the requested target inside ds_config["classes"]; the first
# match wins and an unknown name raises IndexError.
matching_positions = [
    position
    for position, class_name in enumerate(ds_config["classes"])
    if class_name == target_variable
]
target_idx = matching_positions[0]
n_classes = ds_config["n_classes"][target_variable]

In [None]:
# Per-channel normalisation and its inverse (the inverse is used for plotting).
norm_transform = utils.normalize_transform(
    ds_config["means"],
    ds_config["stds"])
inverse_norm_transform = utils.inverse_normalize_transform(
    ds_config["means"],
    ds_config["stds"]
)
# Pick the configured target out of a sample's label vector.
# The index is bound as a default argument on purpose: a plain closure would
# look up the *global* `target_idx` on every call, and later cells rebind that
# name, which silently changed the label returned by the datasets.
target_transform = lambda x, idx=target_idx: x[idx]
#
# No-op applied before the style/geometry augmentations.
init_transform = lambda x: x
# Photometric ("style") augmentations: colour jitter, grayscale, blur, solarize.
stl_transform = transforms.Compose([
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=(0.2, 2),
                                            contrast=(0.2, 2),
                                            saturation=(0.2, 2),
                                            hue=(-0.45, 0.45))],
                    p=0.9
                ),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=0.5),
                Solarization(p=0.2)
])

# Geometric augmentations: square random-resized crop plus random flips.
# NOTE(review): the upper `scale` bound is above 1 — confirm this is intended.
geo_transform = transforms.Compose([
    transforms.RandomResizedCrop(ds_config["img_size"],
                                 scale=(0.6, 1.9),
                                 ratio=(1, 1),
                                 interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5)
])

# Final conversion shared by all views: tensor + normalisation.
fin_transform = transforms.Compose([
                transforms.ToTensor(),
                norm_transform
            ])

# The training loop unpacks three views per sample: (x_ori, x_stl, x_geo).
train_transform = CSpritesTripleTransform(
    init_transform = init_transform,
    stl_transform=stl_transform,
    geo_transform=geo_transform,
    fin_transform=fin_transform
)

# Deterministic pipeline for linear probing / evaluation (no augmentation).
transform_linprob = transforms.Compose([
                transforms.Resize(ds_config["img_size"]),
                transforms.ToTensor(),
                norm_transform
            ])

In [None]:
# TRAIN
# Augmented triple-view dataset; incomplete final batches are dropped.
ds_train = ClassificationDataset(
    p_data=config.p_data,
    transform=train_transform,
    target_transform=target_transform,
    split="train",
)
dl_train = DataLoader(
    ds_train,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=False,
    drop_last=True,
)

# LINPROB
# Both linear-probe loaders share the same settings and use the un-augmented
# transform; only the split differs.
linprob_loader_kwargs = dict(
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=False,
)
ds_linprob_train = ClassificationDataset(
    p_data=config.p_data,
    transform=transform_linprob,
    target_transform=target_transform,
    split="train",
)
dl_linprob_train = DataLoader(ds_linprob_train, **linprob_loader_kwargs)
ds_linprob_valid = ClassificationDataset(
    p_data=config.p_data,
    transform=transform_linprob,
    target_transform=target_transform,
    split="valid",
)
dl_linprob_valid = DataLoader(ds_linprob_valid, **linprob_loader_kwargs)

# Number of batches per loader.
for loader in (dl_train, dl_linprob_train, dl_linprob_valid):
    print(len(loader))

# Visualize Data

In [None]:
# Number of images shown per preview grid; the grids use int(sqrt(n_vis))
# images per row, so a perfect square gives a full square grid.
n_vis = 49

In [None]:
# dl_train
# Take one training batch, keep the first n_vis samples of each of the three
# views and undo the normalisation so they can be displayed.
views, _ = next(iter(dl_train))
x_ori, x_stl, x_geo = (inverse_norm_transform(view[:n_vis]) for view in views)

In [None]:
# One image grid per view, in the order: geometric, original, style.
grid_cols = int(np.sqrt(n_vis))
for batch in (x_geo, x_ori, x_stl):
    grid_img = torchvision.utils.make_grid(batch, nrow=grid_cols)
    # make_grid returns channels-first; imshow needs channels-last.
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()

In [None]:
# dl_linprob_train
# Show the first n_vis un-augmented images of one batch together with their
# decoded labels.
x,y = next(iter(dl_linprob_train))
x = x[:n_vis]
y = y[:n_vis]
#
x = inverse_norm_transform(x)
#
grid_img = torchvision.utils.make_grid(x, nrow=int(np.sqrt(n_vis)))
plt.imshow(grid_img.permute(1, 2, 0))
#
# Decode with the map of the configured target: the labels come from
# `target_transform`, so a hard-coded "shape" map is only right for that one
# setting. (Assumes class_maps is keyed like n_classes — by target name.)
y = [ds_config["class_maps"][target_variable][idx.item()] for idx in y]
print(y)

# Model

In [None]:
class BetaBarlowTwins(nn.Module):
    """Backbone followed by a beta projector and a barlow projector.

    The feature width is described by three slice sizes (`dim_stl`, `dim_geo`,
    `dim_cnt`). Two non-affine BatchNorm1d layers are provided for the
    "style + content" and "geometry + content" slices; they are not applied
    by any method of this class and are meant to be called by the caller.
    """

    def __init__(self, backbone, beta_projector, barlow_projector, dim_stl, dim_geo, dim_cnt):
        """Store the sub-networks and slice sizes and build the two norm layers."""
        super().__init__()
        # Slice sizes first: they are plain ints and do not affect the
        # registration order of the sub-modules below.
        self.dim_stl, self.dim_geo, self.dim_cnt = dim_stl, dim_geo, dim_cnt

        self.backbone = backbone
        self.beta_projector = beta_projector
        self.barlow_projector = barlow_projector

        stl_width = self.dim_stl + self.dim_cnt
        geo_width = self.dim_geo + self.dim_cnt
        self.bn_stl = nn.BatchNorm1d(stl_width, affine=False)
        self.bn_geo = nn.BatchNorm1d(geo_width, affine=False)

    def backbone_proj(self, x):
        """Return the raw backbone output for `x`."""
        return self.backbone(x)

    def beta_proj(self, x):
        """Return the beta projector applied to the backbone output."""
        features = self.backbone(x)
        return self.beta_projector(features)

    def barlow_proj(self, x):
        """Return the barlow projector applied to `beta_proj(x)`."""
        beta_features = self.beta_proj(x)
        return self.barlow_projector(beta_features)

    def forward(self, x):
        """Default forward pass: identical to `beta_proj`."""
        return self.beta_proj(x)

In [None]:
# backbone
backbone = get_backbone(config.backbone, **config.backbone_args)

# beta projector (sigmoid on the last layer)
beta_projector = get_projector(planes_in=backbone.dim_out, sizes=config.beta_projector, activation_last="Sigmoid")

# barlow projector
barlow_projector = get_projector(planes_in=beta_projector.dim_out, sizes=config.barlow_projector)

# stl vs geo
# Split the projector width into style / geometry slices by ratio; whatever
# is left over is the content slice shared by both.
dim_stl = int(config.r_stl * barlow_projector.dim_out)
dim_geo = int(config.r_geo * barlow_projector.dim_out)
dim_cnt = barlow_projector.dim_out - dim_stl - dim_geo
#
# Off-diagonal weights left as None in the config are derived from the width
# of the slice they apply to.
if config["w_off_stl"] is None:
    config["w_off_stl"] = calc_lambda(dim_stl + dim_cnt)
if config["w_off_geo"] is None:
    config["w_off_geo"] = calc_lambda(dim_geo + dim_cnt)
#

model = BetaBarlowTwins(backbone, beta_projector, barlow_projector, dim_stl, dim_geo, dim_cnt)
print("#params", utils.count_parameters(model))
#
# `device` is a torch.device; comparing it with the string "cpu" is always
# unequal, so the CPU branch could never be reached. Compare its type instead.
use_gpu = device.type != "cpu"
if torch.cuda.device_count() > 1 and use_gpu:
    print("Using {} gpus!".format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)
    # Keep `model.backbone` reachable through the wrapper.
    # NOTE(review): the training loop also uses model.beta_proj,
    # model.barlow_projector and model.bn_*, which the wrapper does not
    # expose — the multi-GPU path needs more work before it can run.
    model.backbone = model.module.backbone
elif use_gpu:
    print("Using 1 GPU!")
else:
    print("Using CPU!")
model = model.to(device)
model

In [None]:
# Optimizer over all model parameters; its name and keyword arguments come
# from config.optimizer / config.optimizer_args.
optimizer = get_optimizer(config.optimizer, model.parameters(), config.optimizer_args)

In [None]:
# Per-epoch history: every key maps to its own list that grows by one entry
# per training epoch / per linear-probe run.
train_stat_keys = (
    'loss', 'loss_beta', 'loss_barlow', 'loss_on', 'loss_off',
    'a_min', 'a_mean', 'a_max',
    'b_min', 'b_mean', 'b_max',
    'epoch',
)
linprob_stat_keys = ('linacc', 'knnacc', 'epoch')
stats = DottedDict({
    'train': {key: [] for key in train_stat_keys},
    'linprob': {key: [] for key in linprob_stat_keys},
})
#
# Output directories: <experiment>/ and <experiment>/<ckpts>/.
p_experiment = Path(config.p_experiment)
p_experiment.mkdir(parents=True, exist_ok=True)
p_ckpts = p_experiment / config.p_ckpts
p_ckpts.mkdir(exist_ok=True)

In [None]:
config.p_experiment

In [None]:
def cc_loss(z1, z2):
    """Return the on- and off-diagonal penalties of a cross-correlation matrix.

    The matrix is ``z1.T @ z2`` divided by the number of rows of ``z1``.

    Returns:
        (on_diag, off_diag): the sum of squared ``diagonal - 1`` and the sum
        of squares of whatever ``off_diagonal`` extracts from the matrix.
    """
    n_samples = z1.shape[0]
    corr = z1.T @ z2
    corr.div_(n_samples)

    # In-place on a view of `corr`, in the same order as before, so that
    # `off_diagonal` sees the matrix in exactly the same state.
    diag = torch.diagonal(corr)
    diag.add_(-1).pow_(2)
    on_diag = diag.sum()

    off = off_diagonal(corr)
    off_diag = off.pow_(2).sum()
    return on_diag, off_diag

def cc_zero_loss(z1, z2):
    """Return the sum of all squared entries of ``z1.T @ z2 / z1.shape[0]``.

    Every entry is penalised, the diagonal included, so the value is zero
    only when the whole matrix is zero.
    """
    corr = (z1.T @ z2).div_(z1.shape[0])
    return corr.pow_(2).sum()

def feature_split(z, d_stl, d_cnt, d_geo):
    """Split feature columns into two overlapping slices.

    Returns:
        (z_stl, z_geo): the first ``d_stl + d_cnt`` columns, and every column
        from index ``d_stl`` onwards. ``d_geo`` is accepted but not used.
    """
    stl_end = d_stl + d_cnt
    return z[:, :stl_end], z[:, d_stl:]

In [None]:
# Smoke test for feature_split on random data. The width is derived from the
# slice sizes (they sum to the barlow projector's output width) instead of a
# hard-coded 128, so the check stays valid when the projector config changes.
n_features = dim_stl + dim_cnt + dim_geo
z = torch.rand((10, n_features))
z_stl, z_geo = feature_split(z, dim_stl, dim_cnt, dim_geo)
assert z_stl.shape == (10, dim_stl + dim_cnt)
assert z_geo.shape == (10, dim_cnt + dim_geo)
print(z_stl.shape)
print(z_geo.shape)

In [None]:
# Target Beta(a_true, b_true) distribution for the KL term on the beta features.
a_true, b_true = torch.Tensor([config.a_true, config.b_true])
dist_true = Beta(a_true, b_true)
plot_beta_pdf(dist_true, "True")
#
# Per-step quantities that are averaged over an epoch and appended to
# stats.train under the same key.
tracked_keys = (
    'loss', 'loss_beta', 'loss_barlow', 'loss_on', 'loss_off',
    'a_min', 'a_mean', 'a_max', 'b_min', 'b_mean', 'b_max',
)
global_step = 0
for epoch_idx in range(1, config.num_epochs + 1, 1):
    ################
    # TRAIN
    ################
    model.train()

    # STATS
    epoch_step = 0
    epoch_sums = dict.fromkeys(tracked_keys, 0.0)
    #
    desc = "[{:3}/{:3}]".format(epoch_idx, config.num_epochs)
    pbar = tqdm(dl_train, bar_format= desc + '{bar:10}{n_fmt}/{total_fmt}{postfix}')
    #
    for (x_ori, x_stl, x_geo), _ in pbar:
        x_ori = x_ori.to(device)
        x_stl = x_stl.to(device)
        x_geo = x_geo.to(device)

        # Reset gradients by dropping them instead of zero-filling.
        for param in model.parameters():
            param.grad = None

        ######################
        # BETA LOSS
        ######################

        z_ori = model.beta_proj(x_ori)
        z_stl = model.beta_proj(x_stl)
        z_geo = model.beta_proj(x_geo)

        # With w_beta == 0 the term is only monitored, so no graph is built
        # for it; otherwise it takes part in the backward pass.
        with torch.set_grad_enabled(config.w_beta != 0):
            a_z, b_z = beta_params(torch.cat([z_ori, z_stl, z_geo], axis=0))
            loss_beta = kl_beta_beta((a_z,b_z),(a_true,b_true),forward=True).sum()

        ######################
        # BARLOW  LOSS
        ######################

        z_ori = model.barlow_projector(z_ori)
        z_stl = model.barlow_projector(z_stl)
        z_geo = model.barlow_projector(z_geo)

        z_ori_stl, z_ori_geo = feature_split(z_ori, dim_stl, dim_cnt, dim_geo)
        z_stl_stl, z_stl_geo = feature_split(z_stl, dim_stl, dim_cnt, dim_geo)
        z_geo_stl, z_geo_geo = feature_split(z_geo, dim_stl, dim_cnt, dim_geo)

        # --------
        # Normalize
        # --------
        z_ori_stl = model.bn_stl(z_ori_stl)
        z_stl_stl = model.bn_stl(z_stl_stl)
        z_geo_stl = model.bn_stl(z_geo_stl)
        #
        z_ori_geo = model.bn_geo(z_ori_geo)
        z_stl_geo = model.bn_geo(z_stl_geo)
        z_geo_geo = model.bn_geo(z_geo_geo)
        #

        # this kind of works
        # The original view is correlated with the style view on the style
        # slice and with the geometry view on the geometry slice.
        on_diag_stl, off_diag_stl = cc_loss(z_ori_stl, z_stl_stl)
        on_diag_geo, off_diag_geo = cc_loss(z_ori_geo, z_geo_geo)
        #
        on_diag = on_diag_stl + on_diag_geo
        off_diag = off_diag_stl * config["w_off_stl"] + off_diag_geo * config["w_off_geo"]
        #

        #on_diag_stl, _ = cc_loss(z_ori_stl, z_stl_stl)
        #on_diag_geo, _ = cc_loss(z_ori_geo, z_geo_geo)
        #on_diag = on_diag_stl + on_diag_geo
        #off_diag = cc_zero_loss(z_stl_geo, z_geo_stl)

        loss_barlow = config["w_on"] * on_diag + off_diag

        #loss_zero = cc_zero_loss(z_stl_geo, z_geo_stl)
        # Placeholder for the disabled term above; kept as a one-element
        # tensor so the total loss keeps its shape.
        loss_zero = torch.zeros(1, device=device)

        ######################
        # Total  LOSS
        ######################
        loss = config.w_barlow * loss_barlow + config.w_beta * loss_beta + loss_zero
        loss.backward()
        optimizer.step()

        ######################
        # Track  Stats
        ######################
        step_values = {
            'loss': loss.item(),
            'loss_beta': loss_beta.item(),
            'loss_barlow': loss_barlow.item(),
            'loss_on': on_diag.item(),
            'loss_off': off_diag.item(),
            'a_min': a_z.min().item(),
            'a_mean': a_z.mean().item(),
            'a_max': a_z.max().item(),
            'b_min': b_z.min().item(),
            'b_mean': b_z.mean().item(),
            'b_max': b_z.max().item(),
        }
        for key in tracked_keys:
            epoch_sums[key] += step_values[key]
        epoch_step += 1
        #
        global_step += 1
        #
        pbar.set_postfix(
              {'L': step_values['loss'],
               'CC': step_values['loss_barlow'],
               'on': step_values['loss_on'],
               'off': step_values['loss_off'],
               'beta': step_values['loss_beta'],
               'zero': loss_zero.item(),
               'al': step_values['a_min'],
               'ah': step_values['a_max'],
               'bl': step_values['b_min'],
               'bh': step_values['b_max'],
               'on_stl': on_diag_stl.item(),
               'on_geo': on_diag_geo.item(),
               }
          )

    # Epoch averages.
    for key in tracked_keys:
        stats.train[key].append(epoch_sums[key] / epoch_step)
    stats.train.epoch.append(epoch_idx)

    ################
    # Linprob
    ################
    if epoch_idx % config.freqs.linprob == 0 or epoch_idx == config.num_epochs:
        model.eval()
        linacc, knnacc = utils.linprob_model(
            model.beta_proj,
            dl_linprob_valid,
            device
        )
        print("    Linprob Beta @LR: {:.2f} @KNN: {:.2f}".format(linacc, knnacc))
        #
        linacc, knnacc = utils.linprob_model(
        model.backbone,
        dl_linprob_valid,
        device
        )
        print("    Linprob Back @LR: {:.2f} @KNN: {:.2f}".format(linacc, knnacc))
        # Only the backbone scores are recorded: the beta-head scores above
        # are overwritten before this point.
        stats.linprob.epoch.append(epoch_idx)
        stats.linprob.knnacc.append(knnacc)
        stats.linprob.linacc.append(linacc)
        model.train()

    ################
    # PLOT FEATURES
    ################
    if epoch_idx % config.freqs.plot_features == 0:
        model.eval()
        # NOTE(review): `dl_eval_valid` is created in a later cell; with a
        # finite plot frequency this raises NameError on a fresh kernel.
        with torch.no_grad():
            R, Y = utils.get_representations(model.beta_proj, dl_eval_valid, device)
        n_samples = 1000
        idcs = np.random.choice(R.shape[0], size=n_samples, replace=False)
        R = R[idcs]
        Y = Y[idcs]
        #
        # Own name on purpose: the global `target_idx` selects the training
        # label and must not be rebound here.
        plot_target_idx = 0
        cmap = "turbo"
        #
        a_est, b_est = beta_params(R)
        #
        for idx in range(R.shape[1]):
            title = "a={:.3f} b={:.3f}".format(a_est[idx], b_est[idx])
            simplex_plot(R[:,idx], title=title, c=Y[:, plot_target_idx], cmap=cmap)

    ################
    # Checkpoint
    ################
    if epoch_idx % config.freqs.ckpt == 0 or epoch_idx == config.num_epochs:
        print("save model!")
        # Unwrap based on what the model actually is, so the saved keys never
        # carry a "module." prefix and no device comparison is needed.
        model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
        torch.save(model_to_save.state_dict(), p_ckpts / config.p_model.format(epoch_idx))
            

# Get Representations

In [None]:
# EVAL with all Features
# Same un-augmented transform as the linear-probe loaders, but without a
# target transform, so every sample keeps its untransformed target.
eval_loader_kwargs = dict(
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=False,
)
ds_eval_train = ClassificationDataset(
    p_data=config.p_data,
    transform=transform_linprob,
    target_transform=None,
    split="train",
)
dl_eval_train = DataLoader(ds_eval_train, **eval_loader_kwargs)
ds_eval_valid = ClassificationDataset(
    p_data=config.p_data,
    transform=transform_linprob,
    target_transform=None,
    split="valid",
)
dl_eval_valid = DataLoader(ds_eval_valid, **eval_loader_kwargs)

In [None]:
# Beta-head representations and targets for the training split (no images).
R_train, Y_train = utils.get_representations(model.beta_proj, dl_eval_train, device, imgs=False)
print(R_train.shape, Y_train.shape)

In [None]:
# Same for the validation split, additionally returning the images
# (the inverse normalisation is handed to the helper for that purpose).
R_valid, Y_valid, X_valid = utils.get_representations(model.beta_proj, dl_eval_valid, device, imgs=True, inverse_norm_transform=inverse_norm_transform)
print(R_valid.shape, Y_valid.shape, X_valid.shape)

In [None]:
# Value range of the training representations.
print(R_train.min(), R_train.max())

# DISTS

In [None]:
R = R_valid
# Two bar charts of the validation representations: the mean of every feature
# over samples, then the mean of every sample over features. Each is saved to
# the experiment directory before being shown.
bar_charts = (
    (R.mean(axis=0), "Feature Mean", "feature_dist_valid.png"),
    (R.mean(axis=1), "Sample Mean", "sample_dist_valid.png"),
)
for heights, title, filename in bar_charts:
    plt.bar(range(len(heights)), heights, width=1)
    plt.title(title)
    plt.savefig(p_experiment / filename)
    plt.show()

# Class Distributions on features

In [None]:
# One panel per target variable; inside a panel every feature dimension is a
# horizontal strip of points (x = feature value, colour = class of that target).
# Never ask for more samples than exist (replace=False would raise).
n_plot = min(100, R_valid.shape[0])
idcs = np.random.choice(R_valid.shape[0], size=n_plot, replace=False)
#
R_plot = R_valid[idcs]
Y_plot = Y_valid[idcs]
#
dim_features = R_plot.shape[1]
num_targets = Y_plot.shape[1]
scale = 4
figsize = (num_targets * scale, dim_features)
# squeeze=False keeps `axes` two-dimensional, so indexing also works when
# there is only a single target.
fig, axes = plt.subplots(1, num_targets, figsize=figsize, squeeze=False)
for col_idx in range(num_targets):
    ax = axes[0, col_idx]
    ax.set_title(ds_config["classes"][col_idx])
    # targets
    y = Y_plot[:,col_idx]
    for row_idx in range(dim_features):
        # reps
        r = R_plot[:, row_idx]
        #r = (r - r.min()) / (r - r.min()).max()
        xx = np.ones(len(r)) * row_idx
        #
        ax.scatter(r, xx, c=y, cmap="turbo")
        #ax.set_ylim([0.95, 1.05])
    # Ticks carry no information here; set once per panel.
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])
# Lay the figure out before saving so the file matches what is displayed.
plt.tight_layout()
plt.savefig(p_experiment / "class_distribution.png")
plt.show()

# Predict Classes from Features

In [None]:
# Fit one logistic regression per target variable on the training
# representations and score it on the validation representations.
results = []
# The loop variable is deliberately not called `target_idx`: that global
# selects the label returned by the training / linear-probe datasets.
for col_idx in range(Y_valid.shape[1]):
    target = ds_config["classes"][col_idx]
    # A target with a single class in the training split cannot be fitted.
    if len(set(Y_train[:, col_idx])) == 1:
        print("{:>15}: acc = NA".format(target))
        # NaN marks "not available"; matplotlib leaves the bar out, whereas
        # an infinite bar height is not a valid accuracy.
        results.append(np.nan)
        continue
    clf = LogisticRegression(random_state=0).fit(R_train, Y_train[:, col_idx])
    score = clf.score(R_valid, Y_valid[:, col_idx])
    print("{:>15}: acc = {:.2f}".format(target, score))
    results.append(score)

fig, ax = plt.subplots(1, 1)
ax.bar(range(len(results)), results, width=1)
ax.set_ylim([0, 1])
ax.set_xticks(np.arange(len(ds_config["classes"])))
ax.set_xticklabels(ds_config["classes"])
plt.title("Prediction Accurace LR on valid")
plt.savefig(p_experiment / "score_lr.png")

# Visualize Latent Dimensions

In [None]:
R = R_valid
X = X_valid
Y = Y_valid
#
# For every feature dimension, the indices of the n_imgs samples with the
# largest value in that dimension (ascending, so the strongest comes last).
n_imgs = 20
topic_idcs = np.array([
    np.argsort(R[:, dim_idx])[-n_imgs:]
    for dim_idx in range(R.shape[1])
])

In [None]:
# Paste the selected images into one big canvas: one row per feature
# dimension, one column per selected sample, tiles of 64 x 64 pixels.
tile = 64
n_rows, n_cols = topic_idcs.shape
h, w = np.array(topic_idcs.shape) * tile
img = np.zeros((h, w, 3))
print(img.shape)
for row_idx in range(n_rows):
    top = row_idx * tile
    for col_idx in range(n_cols):
        left = col_idx * tile
        img[top:top + tile, left:left + tile, :] = X[topic_idcs[row_idx][col_idx]]

In [None]:
# figsize is (width, height). The canvas has one row per feature dimension
# and one column per image, so the shape must be reversed; otherwise the
# figure is transposed relative to the image and the mosaic renders tiny.
plt.figure(figsize=topic_idcs.shape[::-1])
plt.imshow(img)
# Save at full resolution, independent of the figure size.
Image.fromarray(np.uint8(img * 255)).save(p_experiment / "feature_dims_highest.png")
plt.show()

In [None]:
p_experiment

# OLD

# Eval

In [None]:
# Recompute beta-head representations and targets for both splits without
# tracking gradients.
with torch.no_grad():
    (R_train, Y_train), (R_valid, Y_valid) = (
        utils.get_representations(model.beta_proj, loader, device)
        for loader in (dl_eval_train, dl_eval_valid)
    )

In [None]:
# Simplex plot of every feature dimension on a random validation subset,
# coloured by one target column and titled with the estimated Beta parameters.
R = R_valid
Y = Y_valid
n_samples = 500
idcs = np.random.choice(R.shape[0], size=n_samples, replace=False)
R = R[idcs]
Y = Y[idcs]
#
# Own name on purpose: rebinding the global `target_idx` would change the
# label returned by the training / linear-probe datasets.
plot_target_idx = 2

cmap = "turbo"
#
a_est, b_est = beta_params(R)
#
for idx in range(R.shape[1]):
    title = "{:.3f} - {:.3f}".format(a_est[idx], b_est[idx])
    simplex_plot(R[:,idx], title=title, c=Y[:, plot_target_idx], cmap=cmap)

# ROW PLOTS

In [None]:
# Run the linear-probe validation loader through the beta head and collect
# de-normalised images, representations and labels as NumPy arrays.
model.eval()
image_batches, rep_batches, label_batches = [], [], []
with torch.no_grad():
    for x, y in dl_linprob_valid:
        x = x.to(device)
        r = model.beta_proj(x)
        image_batches.append(inverse_norm_transform(x).detach().cpu().numpy())
        rep_batches.append(r.detach().cpu().numpy())
        label_batches.append(y.cpu().numpy())
R = np.concatenate(rep_batches)
Y = np.concatenate(label_batches)
# Channels-first -> channels-last for plotting.
X = np.transpose(np.concatenate(image_batches), axes=(0,2,3,1))

In [None]:
R.shape

In [None]:
# For every feature dimension, the indices of the n_imgs samples with the
# largest value in that dimension (ascending, so the strongest comes last).
n_imgs = 30
topic_idcs = np.array([
    np.argsort(R[:, dim_idx])[-n_imgs:]
    for dim_idx in range(R.shape[1])
])

In [None]:
# Paste the selected images into one big canvas: one row per feature
# dimension, one column per selected sample, tiles of 64 x 64 pixels.
tile = 64
n_rows, n_cols = topic_idcs.shape
h, w = np.array(topic_idcs.shape) * tile
img = np.zeros((h, w, 3))
print(img.shape)
for row_idx in range(n_rows):
    top = row_idx * tile
    for col_idx in range(n_cols):
        left = col_idx * tile
        img[top:top + tile, left:left + tile, :] = X[topic_idcs[row_idx][col_idx]]

In [None]:
R.shape

In [None]:
# figsize is (width, height). The canvas has one row per feature dimension
# and one column per image, so the shape must be reversed to match it.
plt.figure(figsize=topic_idcs.shape[::-1])
plt.imshow(img)

# TODO
- linprob backbone features + beta features
- plot features
- show cross-correlation matrix
- add all losses to pbar! what to do with weights?
- show class distributions for all classes!
- more workers