In [None]:
import sys
sys.path.append("..")
from plot_utils import *
from BTwins.utils import calc_lambda

In [None]:
import warnings
# Blanket filter: suppresses every warning category for the rest of the session,
# which also hides deprecation notices from the libraries used below.
warnings.filterwarnings('ignore')

In [None]:
import numpy as np
import time
import torch
import torch.nn as nn
from torch.distributions import Beta
from torch.distributions.dirichlet import Dirichlet
from tqdm import tqdm
from torchvision import transforms
from sklearn.linear_model import LogisticRegression
from dotted_dict import DottedDict
import pprint
#
from utils import *
from functions import *

In [None]:
class GMDist():
    """Mixture of diagonal Gaussians with categorical mixing weights.

    Parameters
    ----------
    mus : torch.Tensor
        Component means, shape ``(n_components, n_dims)``. A 1-D tensor is
        treated as a single component and reshaped to ``(1, n_dims)``.
    stds : torch.Tensor
        Per-dimension standard deviations; must have the same shape as ``mus``.
    alphas : torch.Tensor
        Mixing weights, one entry per component; must sum to 1 (checked with a
        small tolerance).

    Raises
    ------
    AssertionError
        If the shapes disagree or the weights do not sum to 1.
    """
    def __init__(self, mus, stds, alphas):
        assert mus.shape == stds.shape
        # Tolerant comparison: float weights (e.g. 1/3) may not sum to exactly 1.
        assert abs(float(alphas.sum()) - 1.0) < 1e-6
        if len(mus.shape) == 1:
            mus = mus.reshape((1, -1))
            stds = stds.reshape((1, -1))
        assert mus.shape[0] == len(alphas)
        self.mus = mus
        self.stds = stds
        self.alphas = alphas
        #
        self.dist_cat = torch.distributions.Categorical(alphas)

    def _draw_labels(self, n):
        """Draw ``n`` component labels; return them sorted plus per-component counts."""
        y = self.dist_cat.sample((n,))
        # minlength gives every component a count (possibly 0), so the caller
        # always has one chunk per component to concatenate, even when n == 0.
        counts = torch.bincount(y, minlength=len(self.alphas)).tolist()
        return y.sort()[0], counts

    def _component(self, idx):
        """Normal distribution of component ``idx``."""
        return torch.distributions.Normal(self.mus[idx], self.stds[idx])

    def sample(self, n, shuffle=True):
        """Draw ``n`` points from the mixture.

        Returns ``(x, y)`` where ``x`` has shape ``(n, n_dims)`` and ``y`` holds
        the component index of each row. Without ``shuffle`` the rows are
        grouped by component in ascending index order.
        """
        y, counts = self._draw_labels(n)
        x = torch.cat([self._component(idx).sample((n_sub,))
                       for idx, n_sub in enumerate(counts)])
        if shuffle:
            idcs = torch.randperm(x.size()[0])
            x = x[idcs]
            y = y[idcs]
        return x, y

    def sample_barlow(self, n, shuffle=True):
        """Draw ``n`` pairs of points; both members of a pair share a component.

        Returns ``((x1, x2), y)``: ``x1[i]`` and ``x2[i]`` are independent draws
        from component ``y[i]``. The same permutation is applied to all three
        tensors when ``shuffle`` is set.
        """
        y, counts = self._draw_labels(n)
        x1 = []
        x2 = []
        for idx, n_sub in enumerate(counts):
            dist = self._component(idx)
            x1.append(dist.sample((n_sub,)))
            x2.append(dist.sample((n_sub,)))
        x1 = torch.cat(x1)
        x2 = torch.cat(x2)
        if shuffle:
            idcs = torch.randperm(x1.size()[0])
            x1 = x1[idcs]
            x2 = x2[idcs]
            y = y[idcs]
        return (x1, x2), y

def get_normalizer(x):
    """Build a standardisation function from the reference samples ``x``.

    The column-wise mean and standard deviation of ``x`` (reduced over axis 0)
    are computed once and captured; the returned callable subtracts that mean
    from its argument and divides by that standard deviation.
    """
    center = x.mean(axis=0)
    scale = (x - center).std(axis=0)

    def normalize(samples):
        shifted = samples - center
        return shifted / scale

    return normalize

def normalize_z(z, eps):
    """Standardise ``z`` column-wise (statistics over axis 0).

    ``eps`` is added to the standard deviation in the denominator.
    """
    col_mean = z.mean(axis=0)
    col_std = torch.sqrt(z.var(axis=0))
    centered = z - col_mean
    return centered / (col_std + eps)

In [None]:
# Tabulate calc_lambda(d) for several widths d, together with the ratio between
# d and (d**2 - d) * lambda in both directions.
for d in (2, 4, 8, 16, 32, 64, 128, 512):
    lmbda = calc_lambda(d)
    off_weight = (d**2 - d) * lmbda
    g = d / off_weight
    h = off_weight / d
    print("{:>4d}: {:8.4f} on/off={:.4f} off/on={:8.4f}".format(d, lmbda, g, h))

# GMM

In [None]:
# Name of the matplotlib colormap passed as `cmap=` to the plots in this notebook.
cmap = "turbo"

In [None]:
# Three-component 2-D mixture with equal weights.
mus = torch.Tensor([[2, 2], [-2, 1], [0, -1]])
stds = torch.Tensor([[0.1, 0.1], [0.1, 0.4], [0.1, 0.1]])
alphas = torch.Tensor([1/3, 1/3, 1/3])
dist_in = GMDist(mus, stds, alphas)

n_samples = 1000
x, y = dist_in.sample(n_samples)

# Fit the standardisation on this sample and apply it to the same sample.
normalizer = get_normalizer(x)
x_normalized = normalizer(x)

# Overlay raw and standardised points on the same axes, coloured by component.
for _pts in (x, x_normalized):
    plt.scatter(_pts[:, 0], _pts[:, 1], c=y, cmap=cmap)

# Per-column mean and variance, raw first and standardised second.
for _pts in (x, x_normalized):
    print(_pts.mean(axis=0), _pts.var(axis=0))

In [None]:
# Draw paired samples: x1[i] and x2[i] come from the same mixture component y[i].
(x1, x2), y = dist_in.sample_barlow(n_samples)
# Use the shared colormap for both views so the class colours are comparable.
plt.scatter(x1[:,0], x1[:,1], c=y, cmap=cmap)
plt.show()
#
plt.scatter(x2[:,0], x2[:,1], c=y, cmap=cmap)

# Beta-Barlows with Projector

In [None]:
def get_activation(activation):
    """Return a new activation module for the given name.

    Supported names are ``"ReLU"`` (created with ``inplace=True``) and
    ``"Sigmoid"``; any other value raises ``NotImplementedError`` carrying the
    offending value.
    """
    if activation == "Sigmoid":
        return nn.Sigmoid()
    if activation == "ReLU":
        return nn.ReLU(inplace=True)
    raise NotImplementedError(activation)

class BasicBlock(nn.Module):
    """Linear layer, optionally followed by BatchNorm1d and an activation.

    The layers are stored in order in the ``ff`` sequential container.
    ``activation`` is a name understood by ``get_activation`` or ``None`` for
    no activation.
    """
    def __init__(self, d_in, d_out, batch_norm=True, activation=None, bias=False):
        super().__init__()
        stack = [nn.Linear(d_in, d_out, bias=bias)]
        stack += [nn.BatchNorm1d(d_out)] if batch_norm else []
        stack += [get_activation(activation)] if activation is not None else []
        self.ff = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.ff(x)

class Net(nn.Module):
    """Feed-forward stack of ``BasicBlock``s.

    ``n_hid`` blocks with ReLU activation map ``d_in -> d_hid -> ... -> d_hid``;
    one final block maps the last width to ``d_out`` using
    ``batch_norm_last`` / ``activation_last``. ``batch_norm`` and ``bias``
    apply to the hidden blocks, ``bias`` also to the final one. The output
    width is exposed as both ``d_out`` and ``dim_out``.
    """
    def __init__(self, d_in, d_hid, n_hid, d_out,
                 batch_norm=False, activation_last="ReLU", batch_norm_last=True, bias=True):
        super().__init__()
        self.d_in, self.d_out = d_in, d_out
        self.n_hid, self.d_hid = n_hid, d_hid

        # Layer widths seen by the hidden blocks: input width, then n_hid hidden widths.
        self.dims = [d_in] + [d_hid] * n_hid
        hidden = [
            BasicBlock(w_in, w_out, batch_norm=batch_norm, activation="ReLU", bias=bias)
            for w_in, w_out in zip(self.dims[:-1], self.dims[1:])
        ]
        head = BasicBlock(self.dims[-1], d_out, batch_norm_last, activation_last, bias=bias)
        self.net = nn.Sequential(*hidden, head)
        self.dim_out = d_out

    def forward(self, x):
        """Run ``x`` through all blocks in order."""
        return self.net(x)

class BetaBarlow(nn.Module):
    """Container for a backbone and two projection heads.

    ``forward`` returns ``beta_proj(backbone(x))``; ``barlow_proj`` and the
    non-affine ``bn`` layer (sized by ``barlow_proj.dim_out``) are only held
    here and are not applied by ``forward``.
    """
    def __init__(self, backbone, beta_proj, barlow_proj):
        super().__init__()
        self.backbone, self.beta_proj, self.barlow_proj = backbone, beta_proj, barlow_proj
        #
        self.bn = nn.BatchNorm1d(barlow_proj.dim_out, affine=False)

    def forward(self, x):
        features = self.backbone(x)
        return self.beta_proj(features)

In [None]:
# INPUT DATA
# Four-component 2-D mixture with equal weights.
mus = torch.Tensor([
    [2, 2],
    [-2,1],
    [0, -1],
    [3, 5],
])
stds = torch.Tensor([
    [0.4, 0.2],
    [0.2, 0.7],
    [0.3, 0.3],
    [0.1, 0.2],
])
alphas = torch.Tensor([0.25, 0.25, 0.25, 0.25])
dist_in = GMDist(mus, stds, alphas)
# Fit the normaliser on one sample ...
x,y = dist_in.sample(10000)
normalizer = get_normalizer(x)

# ... and apply it to a second, independently drawn sample.
x,y = dist_in.sample(10000)
xt = normalizer(x)
# Left panel: raw samples; right panel: the same samples after normalisation.
fig, axes = plt.subplots(1, 2)
axes[0].scatter(x[:,0], x[:,1], c=y, cmap=cmap)
axes[0].set_title("Unnormalized")
axes[1].scatter(xt[:,0], xt[:,1], c=y, cmap=cmap)
axes[1].set_title("Normalized")
plt.show()
# Per-column mean and variance, raw first and normalised second.
print(x.mean(axis=0), x.var(axis=0))
print(xt.mean(axis=0), xt.var(axis=0))

In [None]:
# global config
# DIM_IN is the number of columns of the mixture means defined above.
DIM_IN  = mus.shape[1]
DIM_BB  = 64
DIM_OUT = 8
# Net builds n_hid hidden blocks plus one final block, so passing "N - 1" as
# n_hid below yields N blocks in total for each sub-network.
N_BACKBONE = 5
N_BETAPROJ = 1
N_BARLOWPROJ = 3
#
config = {
    # Read by the training cell: samples per step, steps per epoch, epochs,
    # evaluation/plot interval (in epochs) and the Adam learning rate.
    "train":
    {
        "batch_size": 512,
        "num_steps": 500,
        "num_epochs": 20,
        "plot_freq": 5,
        "lr": 0.001,
    },
    "loss":
    {
        "barlow": {
            # NOTE(review): w_off and eps are not read by the cells shown in this notebook.
            "w_off": 41,                     # used with mean
            "lmbda": calc_lambda(DIM_OUT),   # used with sum
            "eps": 1e-8,
        },
        # Parameters of the reference Beta distribution.
        "beta":
        {
            "a_true": 0.1,
            "b_true": 0.5,
        },
        # Weights of the two terms in the total training loss.
        "w_barlow": 2,
        "w_beta": 1
    },
    # The three sections below are unpacked as keyword arguments into Net(...).
    "backbone":
    {
        "d_in": DIM_IN,
        "d_out": DIM_BB,
        "d_hid": DIM_BB,
        "n_hid": N_BACKBONE - 1,
        "batch_norm": True,
        "bias": True,
        "activation_last": "ReLU",
        "batch_norm_last": True,
    },
    "beta_proj":{
        "bias": False,
        "d_in": DIM_BB,
        "n_hid": N_BETAPROJ - 1,
        "d_hid": DIM_BB,
        "d_out": DIM_OUT,
        "batch_norm": True,
        "activation_last": "Sigmoid",
        "batch_norm_last": False,
    },
    "barlow_proj":
    {
        "bias": False,
        "d_in": DIM_OUT,
        "d_out": DIM_OUT,
        "n_hid": N_BARLOWPROJ - 1,
        "d_hid": DIM_OUT,
        "batch_norm": True,
        "batch_norm_last": False,
        "activation_last": None
    }
}
# DottedDict allows attribute-style access (config.train.lr) in the cells below.
config = DottedDict(config)
pprint.pprint(config)

In [None]:
# BETA LOSS
# Reference Beta distribution; unpacking the 1-D tensor gives two 0-dim tensors
# that the training cell passes to kl_beta_beta.
a_true, b_true = torch.Tensor([config.loss.beta.a_true, config.loss.beta.b_true])
dist_true = Beta(a_true, b_true)
# plot_beta_pdf is not defined in this notebook (presumably from plot_utils — verify).
plot_beta_pdf(dist_true, "True")

In [None]:
# Assemble the three sub-networks from their config sections, then print each architecture.
model = BetaBarlow(
    Net(**config.backbone),
    Net(**config.beta_proj),
    Net(**config.barlow_proj),
)
for _part in (model.barlow_proj, model.beta_proj, model.backbone):
    print(_part)


In [None]:
# Baseline linear probes: first on the normalised inputs, then on the
# embeddings of a freshly constructed (not yet trained) model.
x_train, y_train = dist_in.sample(1000)
x_valid, y_valid = dist_in.sample(1000)
x_train, x_valid = normalizer(x_train), normalizer(x_valid)

clf = LogisticRegression(random_state=0)
clf.fit(x_train, y_train)
print("LR(X)",clf.score(x_valid, y_valid))

model = BetaBarlow(
    Net(**config.backbone),
    Net(**config.beta_proj),
    Net(**config.barlow_proj),
)
model.eval()
with torch.no_grad():
    z_train, z_valid = model(x_train), model(x_valid)

clf = LogisticRegression(random_state=0)
clf.fit(z_train, y_train)
print("LR(Z)",clf.score(z_valid, y_valid))

# One plot per embedding dimension, titled with the Beta parameters returned by beta_params.
a_z, b_z = beta_params(z_train)
for idx in range(z_train.shape[1]):
    title = "Z_{}, alpha={:.3f}, beta={:.3f}".format(idx, a_z[idx].item(), b_z[idx].item())
    column = z_train[:,idx].detach().numpy()
    simplex_plot(column, title=title, c=y_train, cmap=cmap)

In [None]:
# Train a fresh BetaBarlow model with Adam on a weighted sum of a Beta-KL term
# and a Barlow-Twins-style cross-correlation term; every `plot_freq` epochs,
# plot the embeddings and report a linear-probe score.
model = BetaBarlow(
    backbone = Net(**config.backbone),
    beta_proj = Net(**config.beta_proj),
    barlow_proj = Net(**config.barlow_proj)
    
)
optimizer = torch.optim.Adam(model.parameters(), lr=config.train.lr)
#
all_loss = []
# NOTE: all_a / all_b are initialised here but never appended to in this cell.
all_a = []
all_b = []
for epoch_idx in range(1, config.train.num_epochs + 1, 1):
    # ##########
    # TRAIN
    # ##########
    model.train()
    desc = "[{:3}/{:3}]".format(epoch_idx, config.train.num_epochs)
    pbar = tqdm(range(config.train.num_steps), bar_format= desc + '{bar:10}{n_fmt}/{total_fmt}{postfix}')
    epoch_loss = 0
    for step in pbar:
        # Two views per label: x1[i] and x2[i] are drawn from the same mixture component.
        (x1, x2), _ = dist_in.sample_barlow(config.train.batch_size)
        x1 = normalizer(x1)
        x2 = normalizer(x2)
        # Clear gradients left over from the previous step.
        for param in model.parameters():
            param.grad = None
        # Same computation as model.forward, applied to each view.
        z1 = model.beta_proj(model.backbone(x1))
        z2 = model.beta_proj(model.backbone(x2))
        
        # BETA LOSS
        # beta_params / kl_beta_beta are not defined in this notebook (presumably
        # from utils/functions — verify). Both views are pooled before calling beta_params.
        a_z, b_z = beta_params(torch.cat([z1, z2], axis=0))
        loss_beta = kl_beta_beta((a_z,b_z),(a_true,b_true),forward=True).sum()
        
        # BARLOW LOSS
        # From here on z1 / z2 are the outputs of the Barlow projector.
        z1 = model.barlow_proj(z1)
        z2 = model.barlow_proj(z2)
        # model.bn is a BatchNorm1d without affine parameters.
        z1_norm = model.bn(z1)
        z2_norm = model.bn(z2)
        #
        #z1_norm = beta_normalize(z1, dist_true.mean, dist_true.stddev)
        #z2_norm = beta_normalize(z2, dist_true.mean, dist_true.stddev)
        # Cross-correlation between the two views, divided by the batch size.
        c = z1_norm.T @ z2_norm
        c.div_(z1.shape[0])
        
        # In-place ops on the diagonal view modify c's diagonal; the off-diagonal
        # term is taken from a helper (off_diagonal) that is not defined in this notebook.
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        loss_barlow = on_diag + config.loss.barlow.lmbda * off_diag
        
        # LOSS
        loss = config.loss.w_barlow * loss_barlow + config.loss.w_beta * loss_beta
        loss.backward()
        optimizer.step()
        #
        epoch_loss += loss.item()
        pbar.set_postfix(
              {'L': loss.item(),
               'on': on_diag.item(),
               'off': off_diag.item(),
               'dkl': loss_beta.item(),
               'a_min': a_z.min().item(),
               'a_max': a_z.max().item(),
               'b_min': b_z.min().item(),
               'b_max': b_z.max().item()
               }
          )

    # Mean loss per step for this epoch.
    all_loss.append(epoch_loss / config.train.num_steps)
    ############
    # EVAL
    ############
    if epoch_idx % config.train.plot_freq == 0:
        model.eval()
        with torch.no_grad():
            # x, y, z, a_z, b_z set here remain in the namespace and are read by later cells.
            x, y = dist_in.sample(config.train.batch_size)
            x = normalizer(x)
            z = model(x)
            a_z, b_z = beta_params(z)
            #
            for idx in range(z.shape[1]):
                title = "Z_{}, alpha={:.3f}, beta={:.3f}".format(idx, a_z[idx].item(), b_z[idx].item())
                simplex_plot(z[:,idx].detach().numpy(), title=title, c=y, cmap=cmap)
            
            # Linear probe: fit on 1000 fresh samples, score on 100 fresh samples.
            x_train,y_train = dist_in.sample(1000)
            x_valid, y_valid = dist_in.sample(100)
            #
            x_train = normalizer(x_train)
            x_valid = normalizer(x_valid)
            #
            z_train = model(x_train)
            z_valid = model(x_valid)
            #
            clf = LogisticRegression(random_state=0).fit(z_train, y_train)
            print("   LINPROB: {:.3f}".format(clf.score(z_valid, y_valid)))

In [None]:
# Inspection cell: displays the colormap name currently in use.
cmap

In [None]:
# Inspect the RGBA colours assigned to the labels currently held in `y`.
# `y_show` is only created by a later cell, so this cell uses `y` to stay
# runnable top-to-bottom. Labels are rescaled to [0, 1] because integer inputs
# to a Colormap are treated as raw lookup-table indices.
plt.get_cmap(cmap)((y / max(mus.shape[0] - 1, 1)).numpy())

In [None]:
# Mean of z over axis 0: one bar per column (embedding dimension) of z.
plt.bar(range(z.shape[1]), z.mean(axis=0), width=1.0)
plt.show()
# Mean of z over axis 1: one bar per row (sample) of z.
plt.bar(range(z.shape[0]), z.mean(axis=1), width=1.0)

In [None]:
# For each mixture component, plot every embedding dimension of `z` restricted
# to the samples of that component. Titles show the Beta parameters in
# a_z / b_z, which were computed on all classes together.
_class_cmap = plt.get_cmap(cmap)
_n_classes = mus.shape[0]
for clz in range(_n_classes):
    print("*"*100)
    print("CLASS: {}".format(clz))
    print("*"*100)
    # The class selection does not depend on the embedding index: compute it once.
    show_idcs = (y == clz)
    z_show = z[show_idcs]
    y_show = y[show_idcs]
    # Integer inputs to a Colormap are raw lookup-table indices, which would map
    # labels 0..n-1 to nearly identical colours; rescale labels to [0, 1] so the
    # classes are spread over the whole colormap.
    colors = _class_cmap((y_show / max(_n_classes - 1, 1)).numpy())
    for idx in range(z.shape[1]):
        title = "Z_{}, alpha={:.3f}, beta={:.3f}".format(idx, a_z[idx].item(), b_z[idx].item())
        simplex_plot(z_show[:,idx].detach().numpy(), title=title, c=colors)
        