In [None]:
import os
# Move the working directory up one level so that the project-local modules
# imported in the next cell (`loss`, `functional`) can be resolved from there
# -- presumably the repository root; confirm against the project layout.
# NOTE: this is a relative chdir, so it is not idempotent: re-running this cell
# in the same kernel moves up one more directory each time.
os.chdir('../')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as tF
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
from loss import MaximalCodingRateReduction, label_to_membership
import functional as F

In [None]:
class MCR2(nn.Module):
    """Maximal Coding Rate Reduction objective, Delta R = R - R_c.

    Args:
        eps: distortion parameter used inside both log-det terms.
    """
    def __init__(self, eps=0.1):
        super(MCR2, self).__init__()
        self.eps = eps

    def compute_discrimn_loss(self, Z):
        """Theoretical Discriminative Loss.

        Args:
            Z: feature matrix of shape (d, n) -- features are columns.

        Returns:
            Scalar tensor ``0.5 * logdet(I + d / (n * eps) * Z Z^T)``.
        """
        d, n = Z.shape
        I = torch.eye(d, dtype=Z.dtype, device=Z.device)
        scalar = d / (n * self.eps)
        logdet = torch.logdet(I + scalar * Z @ Z.T)
        return logdet / 2.

    def compute_compress_loss(self, Z, Pi):
        """Theoretical Compressive Loss.

        Args:
            Z: feature matrix of shape (d, n) -- features are columns.
            Pi: membership matrix of shape (n, k); column j holds the weight of
                every sample in class j.

        Returns:
            Scalar tensor ``sum_j tr(Pi_j) / (2n) * logdet(I + d / (tr(Pi_j) * eps) * Z Pi_j Z^T)``.
            Classes with zero total membership contribute nothing.
        """
        d, n = Z.shape
        I = torch.eye(d, dtype=Z.dtype, device=Z.device)
        # Start from a tensor (not a Python float) so the result is always a
        # tensor, even when no class contributes; forward() calls .item() on it.
        compress_loss = Z.new_zeros(())
        for j in range(Pi.shape[1]):
            pi_j = Pi[:, j]
            trPi = pi_j.sum()
            if trPi == 0:
                # Empty class: d / (trPi * eps) would be inf and turn the whole
                # loss into NaN; its weight trPi / (2n) is zero anyway.
                continue
            scalar = d / (trPi * self.eps)
            # (Z * pi_j) scales column i of Z by pi_j[i], i.e. Z @ diag(pi_j),
            # without materialising the n x n diagonal matrix.
            log_det = torch.logdet(I + scalar * (Z * pi_j) @ Z.T)
            compress_loss = compress_loss + trPi / (2 * n) * log_det
        return compress_loss

    def forward(self, Z, Pi):
        """Evaluate the objective for row-wise features.

        Args:
            Z: features of shape (n, d); transposed internally to (d, n).
            Pi: membership matrix of shape (n, k).

        Returns:
            Tuple ``(-Delta R, R, R_c)``: the first entry is a tensor to be
            minimised, the other two are Python floats for logging.
        """
        discrimn_loss = self.compute_discrimn_loss(Z.T)
        compress_loss = self.compute_compress_loss(Z.T, Pi)
        total_loss = discrimn_loss - compress_loss
        return -total_loss, discrimn_loss.item(), compress_loss.item()

In [None]:
class MCR2Variational(nn.Module):
    """Equation 9 in writeup.

    Variational form of the rate-reduction objective: the compressive term is
    computed from per-class factors ``Us[j]`` and a Frobenius penalty ties
    ``Us[j] @ Us[j].T`` to ``Z Pi_j Z^T``.

    Args:
        eps: distortion parameter used inside the log terms.
        mu: weight of the factor-matching penalty.
    """
    def __init__(self, eps, mu):
        super(MCR2Variational, self).__init__()
        self.eps = eps
        self.mu = mu

    def loss_discrimn(self, Z):
        """Return ``0.5 * logdet(I + d / (n * eps) * Z Z^T)`` for Z of shape (d, n)."""
        d, n = Z.shape
        I = torch.eye(d, dtype=Z.dtype, device=Z.device)
        return 0.5 * torch.logdet(I + d / (n * self.eps) * Z @ Z.T)

    def loss_compress(self, Z, Pi, Us):
        """Compressive term computed from the column norms of each ``Us[j]``.

        Args:
            Z: feature matrix of shape (d, n); only its shape is used here.
            Pi: membership matrix of shape (n, k).
            Us: indexable collection of k factor matrices.

        Returns:
            Scalar tensor
            ``sum_j tr(Pi_j) / (2n) * sum_c log(1 + d / (tr(Pi_j) * eps) * ||Us[j][:, c]||^2)``.
            Classes with zero total membership contribute nothing.
        """
        d, n = Z.shape
        # Tensor accumulator so the result supports .item() even if every class is skipped.
        compress_loss = Z.new_zeros(())
        for j in range(Pi.shape[1]):
            trPi_j = Pi[:, j].sum()
            if trPi_j == 0:
                # Empty class: d / (trPi_j * eps) is inf and 0 * inf would make
                # the loss NaN; the class weight is zero, so skip it.
                continue
            scalar_j = trPi_j / (2 * n)
            # Squared 2-norm of every column of Us[j].
            sq_norms = (Us[j] ** 2).sum(dim=0)
            compress_loss = compress_loss + scalar_j * torch.log(1 + d / (trPi_j * self.eps) * sq_norms).sum()
        return compress_loss

    def reg_U(self, Z, Pi, Us):
        """Return ``0.5 * sum_j ||Z Pi_j Z^T - Us[j] Us[j]^T||_F^2`` for Z of shape (d, n)."""
        loss_reg = Z.new_zeros(())
        for j in range(Pi.shape[1]):
            # (Z * Pi[:, j]) equals Z @ diag(Pi[:, j]) without building the n x n matrix.
            class_cov = (Z * Pi[:, j]) @ Z.T
            loss_reg = loss_reg + torch.linalg.norm(class_cov - (Us[j] @ Us[j].T), ord='fro') ** 2
        return 0.5 * loss_reg

    def forward(self, Z, Pi, Us):
        """Evaluate the variational objective for row-wise features.

        Args:
            Z: features of shape (n, d); transposed internally to (d, n).
            Pi: membership matrix of shape (n, k).
            Us: k factor matrices, one per class.

        Returns:
            Tuple ``(-(R - R_c - mu * reg), R, R_c, mu * reg)``: the first entry
            is a tensor to be minimised, the rest are Python floats for logging.
        """
        loss_R = self.loss_discrimn(Z.T)
        loss_Rc = self.loss_compress(Z.T, Pi, Us)
        loss_reg_U = self.mu * self.reg_U(Z.T, Pi, Us)
        loss_obj = loss_R - loss_Rc - loss_reg_U
        return -loss_obj, loss_R.item(), loss_Rc.item(), loss_reg_U.item()
    
class Simple(nn.Module):
    """Small fully connected network: Linear(3, 3) -> ReLU -> Linear(3, 3).

    The output of the second linear layer is passed through ``F.normalize``,
    where ``F`` is the project-local ``functional`` module imported at the top
    of the notebook (presumably a per-sample unit-norm rescaling -- confirm
    against that module).
    """
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(3, 3)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(3, 3)

    def forward(self, x):
        """Map a batch ``x`` with 3 input features to normalised 3-feature outputs."""
        hidden = self.relu(self.linear1(x))
        return F.normalize(self.linear2(hidden))

In [None]:
def plot_loss(a, b, c):
    """Draw three loss histories on a single set of axes and show the figure.

    Args:
        a: sequence plotted with the legend entry ``Delta R``.
        b: sequence plotted with the legend entry ``R``.
        c: sequence plotted with the legend entry ``R_c``.

    Each sequence is plotted against its own index (0 .. len - 1).
    """
    _, axes = plt.subplots()
    curves = ((a, r'$\Delta R$'), (b, r'$R$'), (c, r'$R_c$'))
    for values, legend_label in curves:
        steps = np.arange(len(values))
        axes.plot(steps, values, label=legend_label)
    axes.legend()
    plt.show()
    
def plot_params(params):
    """Show every slice ``params[j]`` as an image, side by side, with one colorbar.

    Args:
        params: array-like whose first axis indexes the panels; ``params[j]``
            must be something ``imshow`` can draw (e.g. a 2-D detached tensor
            or NumPy array).

    All panels share the same colour limits so that the single colorbar is
    valid for every image, not only the last one drawn.
    """
    n_panels = params.shape[0]
    # squeeze=False keeps `axes` a 2-D array even when there is only one panel;
    # without it plt.subplots returns a bare Axes and indexing it raises.
    fig, axes = plt.subplots(ncols=n_panels, squeeze=False)
    axes = axes[0]
    vmin, vmax = float(params.min()), float(params.max())
    for j in range(n_panels):
        im = axes[j].imshow(params[j], vmin=vmin, vmax=vmax)
        axes[j].set_title(f'U{j}')
    # Attach the colorbar to all panels so it does not shrink only the last one.
    fig.colorbar(im, ax=list(axes), pad=0.02, drawedges=0)
    plt.show()

In [None]:
# data: two 3-D Gaussian blobs with n samples each, centred on the first and
# second coordinate axes, both with covariance 0.01 * I.
# Seed the global torch RNG so the samples (and, on a fresh Restart & Run All,
# the network initialisations that follow) are reproducible.
torch.manual_seed(0)
n = 100
X1 = torch.distributions.multivariate_normal.MultivariateNormal(torch.tensor([1., 0., 0.]), 0.01*torch.eye(3)).sample([n])
X2 = torch.distributions.multivariate_normal.MultivariateNormal(torch.tensor([0., 1., 0.]), 0.01*torch.eye(3)).sample([n])
X = torch.cat([X1, X2])
# Integer class labels: 0 for the first blob, 1 for the second.
y = torch.cat([torch.tensor(0).repeat(n), torch.tensor(1).repeat(n)])
# Project helper; presumably returns an (n_samples, n_classes) membership matrix,
# which is how Pi is indexed downstream (Pi[:, j]) -- confirm in `loss`.
Pi = label_to_membership(y)

## Original MCR2 objective

In [None]:
net = Simple()
optimizer = optim.SGD(net.parameters(), lr=0.01)
criterion = MCR2(eps=0.1)

n_epochs = 5000
log_every = 500  # print a progress line this often instead of once per epoch

all_loss_true, all_loss_discrimn, all_loss_compress = [], [], []
for epoch in range(n_epochs):
    optimizer.zero_grad()
    Z = net(X)
    # criterion returns (-Delta R, R, R_c); the first entry is the tensor to minimise.
    loss_true, loss_discrimn, loss_compress = criterion(Z, Pi)
    all_loss_true.append(-loss_true.item())
    all_loss_discrimn.append(loss_discrimn)
    all_loss_compress.append(loss_compress)
    if epoch % log_every == 0 or epoch == n_epochs - 1:
        print('{} | {:.5f} {:.5f} {:.5f}'.format(epoch, -loss_true.item(), loss_discrimn, loss_compress))
    loss_true.backward()
    optimizer.step()

# Re-evaluate the features with the final weights: the Z left over from the
# loop was computed before the last optimizer step.
with torch.no_grad():
    Z = net(X)

In [None]:
# Gram matrix Z Z^T of the features produced by the training cell above
# (rows of Z are samples, so entry (i, k) is the inner product of samples i and k).
plt.imshow(Z @ Z.T, cmap='Blues')

In [None]:
# Loss histories recorded by the training loop above: Delta R, R and R_c per epoch.
plot_loss(all_loss_true, all_loss_discrimn, all_loss_compress)

## Variational Form - Equation 9

In [None]:
mu = 5.           # weight of the factor-matching penalty passed to MCR2Variational
net_lr = 0.001    # SGD learning rate for the network weights
param_lr = 0.001  # SGD learning rate for the factors Us
net = Simple()

# Initialise one factor per class from the untrained features: with the SVD
# U S V^T of the symmetric PSD matrix Z^T diag(Pi[:, j]) Z, the factor
# U sqrt(S) satisfies (U sqrt(S)) (U sqrt(S))^T = Z^T diag(Pi[:, j]) Z.
init_Us = []
with torch.no_grad():
    Z = net(X)
    # One factor per column of Pi rather than a hard-coded class count.
    for j in range(Pi.shape[1]):
        U, S, _ = torch.linalg.svd(Z.T @ Pi[:, j].diag() @ Z)
        init_Us.append(U @ (S**0.5).diag())
init_Us = torch.stack(init_Us)
Us = nn.Parameter(
    init_Us, 
    requires_grad=True
    )

criterion_mcr2var = MCR2Variational(0.1, mu)
# Separate optimizers so the network and the factors can be stepped on different schedules.
optimizer_net = optim.SGD(net.parameters(), lr=net_lr)
optimizer_Us = optim.SGD([Us], lr=param_lr)

In [None]:
n_epochs = 20000
net_step_every = 20  # the network is updated once per this many epochs
plot_every = 100     # show the feature Gram matrix this often

all_loss_true, all_loss_discrimn, all_loss_compress, all_loss_reg = [], [], [], []
for epoch in range(n_epochs):
    optimizer_net.zero_grad()
    optimizer_Us.zero_grad()
    Z = net(X)
    # criterion returns (negated objective, R, R_c, mu * reg); only the first is a tensor.
    loss_true, loss_discrimn, loss_compress, loss_reg = criterion_mcr2var(Z, Pi, Us)
    all_loss_true.append(loss_discrimn - loss_compress)
    all_loss_discrimn.append(loss_discrimn)
    all_loss_compress.append(loss_compress)
    all_loss_reg.append(loss_reg)
    loss_true.backward()

    # Alternating schedule: the factors Us are stepped every epoch, the network
    # only every `net_step_every` epochs (gradients are zeroed each epoch, so the
    # network step uses the current epoch's gradient only).
    if epoch % net_step_every == 0:
        optimizer_net.step()
    optimizer_Us.step()

    if epoch % plot_every == 0:
        plt.imshow(Z.detach() @ Z.detach().T, cmap='Blues')
        plt.title(f'step{epoch}')
        plt.show()

# NOTE: these rebind Z and Us to detached tensors for plotting. After this,
# `Us` is no longer the nn.Parameter held by optimizer_Us, so re-running this
# cell without re-running the setup cell would not train the factors.
Z = Z.detach()
Us = Us.detach()

In [None]:
# Results of the variational run: loss histories (R - R_c, R, R_c), the learned
# per-class factors Us[j], and the Gram matrix Z Z^T of the final features.
plot_loss(all_loss_true, all_loss_discrimn, all_loss_compress)
plot_params(Us)
plt.imshow(Z @ Z.T)