2/24

Try Binarize

Try YYT

Try on Zeisel

In [1]:
import torch
from torch.utils.data import DataLoader

from torchvision import datasets
import torchvision.transforms as transforms


from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F

import numpy as np

from torchvision.utils import save_image

import matplotlib.pyplot as plt

import math

In [2]:
import os
from os import listdir

In [3]:
BASE_PATH_DATA = '../data/'

In [4]:
n_epochs = 5
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
img_size = 28
channels = 1

log_interval = 100


z_size = 20

n = 28 * 28

In [5]:
cuda = True if torch.cuda.is_available() else False

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

device = torch.device("cuda:0" if cuda else "cpu")
print(cuda)

True


In [6]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        BASE_PATH_DATA + '/mnist/train',
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor()]
        ),
    ),
    batch_size=batch_size,
    shuffle=True,
)


test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        BASE_PATH_DATA + '/mnist/test', 
        train=False, 
        download = True,
        transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor()]
        )
    ),
    batch_size=batch_size, shuffle=True
)

In [7]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

In [8]:
# Vanilla VAE model
class VAE(nn.Module):
    def __init__(self, hidden_layer_size, z_size):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [9]:
# L1 VAE model we are loading
class VAE_l1_diag(nn.Module):
    def __init__(self, hidden_layer_size, z_size):
        super(VAE_l1_diag, self).__init__()
        
        self.diag = nn.Parameter(torch.normal(torch.zeros(784), 
                                 torch.ones(784)).to(device).requires_grad_(True))
        
        self.selection_layer = torch.diag(self.diag)
        self.fc1 = nn.Linear(784, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, 784)

    def encode(self, x):
        h0 = torch.mm(x, self.selection_layer)
        h1 = F.relu(self.fc1(h0))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Try a binarized model

In [10]:
# the model we are loading
class VAE_binary_diag(nn.Module):
    def __init__(self, hidden_layer_size, z_size, indices_diag):
        super(VAE_binary_diag, self).__init__()
        
        self.diag = nn.Parameter(indices_diag.to(device).requires_grad_(False), requires_grad=False)
        
        self.selection_layer = torch.diag(self.diag).requires_grad_(False)
        self.fc1 = nn.Linear(784, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, 784)

    def encode(self, x):
        h0 = torch.mm(x, self.selection_layer)
        h1 = F.relu(self.fc1(h0))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


Load Diag Model where $\lambda = 100$ for l1 norm. Not using pre-trained model.

In [11]:
pretrained_l1_diag_model = VAE_l1_diag(400, 20).to(device)
pretrained_l1_diag_model.load_state_dict(torch.load(BASE_PATH_DATA + 
                                            "../data/models/with_regularization/l1_norm_diag_100_lambda.pt"))
pretrained_l1_diag_model.eval()
pretrained_l1_diag_model.requires_grad_(False)

VAE_l1_diag(
  (fc1): Linear(in_features=784, out_features=400, bias=True)
  (fc21): Linear(in_features=400, out_features=20, bias=True)
  (fc22): Linear(in_features=400, out_features=20, bias=True)
  (fc3): Linear(in_features=20, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=784, bias=True)
)

In [12]:
mask = np.logical_not(np.abs(pretrained_l1_diag_model.diag.data.cpu().numpy()) < 1e-4)

In [13]:
diag_indices = torch.zeros(len(mask)).to(device)
diag_indices.masked_fill_(Tensor(mask).to(torch.bool), 1)

tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 1.,
        0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1.,
        0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1.,
        1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 1., 1., 1.,
        1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 1., 1.,
        0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1., 0., 0.,
        1., 0., 1., 0., 1., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0., 1., 0., 1.,
        1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
        1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0.,
        0., 1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1.,
        1., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 1., 1., 0., 0., 0.,
        1., 1., 0., 0., 0., 0., 1., 0., 

In [14]:
model_binary_diag = VAE_binary_diag(400, 20, diag_indices).to(device)
optimizer_binary_diag = torch.optim.Adam(model_binary_diag.parameters(), lr=lr, betas = (b1,b2))

In [15]:
def train(model, optimizer, epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

In [16]:
def test(model, epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         '../data/binarized_diag/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [17]:
torch.manual_seed(123)
for epoch in range(1, n_epochs + 1):
        train(model_binary_diag, optimizer_binary_diag, epoch)
        test(model_binary_diag, epoch)
        with torch.no_grad():
            sample = torch.randn(64, 20).to(device)
            sample = model_binary_diag.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       '../data/binarized_diag/sample_' + str(epoch) + '.png')

====> Epoch: 1 Average loss: 198.6609
====> Test set loss: 156.1388
====> Epoch: 2 Average loss: 145.3550
====> Test set loss: 135.2520
====> Epoch: 3 Average loss: 131.3534
====> Test set loss: 126.2985
====> Epoch: 4 Average loss: 124.8538
====> Test set loss: 121.4120
====> Epoch: 5 Average loss: 120.7668
====> Test set loss: 118.3144


Use results from L1 Diag model with no attempt to match pre trained latent. $\lambda = 100$

In [18]:
torch.save(model_binary_diag.state_dict(), BASE_PATH_DATA + "../data/models/binarized_diag/base.pt")

Reconstruction looks slightly worse. but not by much

# Try Squeeze Fit like formulationw with $W = YY^T$

$W$ is the selection layer

also going to match pretrained vanilla model.

In [19]:
class VAE_rank_k_selection(nn.Module):
    
    def __init__(self, hidden_layer_size, z_size, k, n = 784):
        super(VAE_rank_k_selection, self).__init__()
        
        Y = torch.normal(mean = 0, std = 1/10*torch.ones(k*n).reshape(n, k)).to(device)
        self.Y = nn.Parameter(Y.detach().clone(), requires_grad=True)
        
        #Ytest = torch.normal(mean = 0, std = 1/10*torch.ones(k*n)).reshape((k,n)).to(device)
        #Ytest.requires_grad_(False)
        
        #self.selection_layer = torch.mm(self.Y, torch.t(self.Y))
        #self.selection_layer = torch.matmul(self.Y, Ytest)
        #self.selection_layer = nn.Parameter(torch.diag(torch.ones(784)))
        #self.selection_layer = nn.Parameter(torch.normal(mean = 0, std = 1/10*torch.ones(n*n)).reshape((n,n)))
        
        self.fc1 = nn.Linear(784, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, 784)

    def encode(self, x):
        self.selection_layer = torch.mm(self.Y, torch.t(self.Y))
        h0 = torch.mm(x, self.selection_layer)
        h1 = F.relu(self.fc1(h0))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [20]:
pretrained_vanilla_model = VAE(400, 20).to(device)
pretrained_vanilla_model.load_state_dict(torch.load(BASE_PATH_DATA + "../data/models/first_try/no_norm.pt"))
pretrained_vanilla_model.eval()
pretrained_vanilla_model.requires_grad_(False)

for param in pretrained_vanilla_model.parameters():
    param.requires_grad = False

In [21]:
def train_wyyt_pretrained(model, pretrained_model, optimizer, epoch, reg_lambda_1, reg_lambda_latent, k, n):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss_base = loss_function(recon_batch, data, mu, logvar)
        
        row_wise_norm_constraints = torch.sum(model.Y**2, dim = 1) 
        match = torch.ones(n).to(device) * k / n
        l2_loss_selection = reg_lambda_1 * F.mse_loss(row_wise_norm_constraints, match)
        #loss += l2_loss_selection
        
        h1_pretrained = F.relu(pretrained_model.fc1(data.view(-1, 784)))
        h0 = torch.mm(data.view(-1, 784), model.selection_layer)
        h1_model = F.relu(model.fc1(h0))
        l2_loss = reg_lambda_latent * F.mse_loss(h1_model, h1_pretrained)
        #loss += l2_loss
        loss = loss_base + l2_loss_selection + l2_loss
        
        #loss.backward(retain_graph = True)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} \t L2 Loss Selection: {}\t L2 Loss Latent:{}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data), l2_loss_selection.item() / len(data), l2_loss.item() / len(data)))
            
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

In [22]:
def test(model, epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         '../data/w_yyt_model_results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [23]:
def try_wyyt_model(pretrained_model, reg_lambda_1, reg_lambda_latent, k, n = 784):
    torch.manual_seed(123)
    
    model = VAE_rank_k_selection(400, 20, k).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas = (b1,b2))
    for epoch in range(1, n_epochs + 1):
        train_wyyt_pretrained(model, pretrained_model, optimizer, epoch, reg_lambda_1, reg_lambda_latent, k, n)
        #with torch.no_grad():
        #    model.diag.data[torch.abs(model.diag) < 0.05] = 0
        print(model.Y)
        print(model.selection_layer)
        test(model, epoch)
        with torch.no_grad():
            sample = torch.randn(64, 20).to(device)
            sample = model.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       '../data/w_yyt_model_results/sample_' + str(epoch) + '.png')
    return model

In [24]:
wyyt_model = try_wyyt_model(pretrained_vanilla_model, 1, 1, 40)

====> Epoch: 1 Average loss: 184.3343
Parameter containing:
tensor([[ 0.0227, -0.0165, -0.0335,  ..., -0.1031, -0.0465, -0.1142],
        [ 0.0889,  0.1494,  0.0608,  ...,  0.0072,  0.0432, -0.1444],
        [ 0.0244, -0.0543,  0.0986,  ..., -0.0134, -0.0309, -0.0861],
        ...,
        [ 0.0253, -0.0566,  0.0148,  ..., -0.1599, -0.0051, -0.0221],
        [ 0.0320, -0.0257,  0.0814,  ...,  0.1534,  0.0273,  0.1593],
        [-0.0094, -0.0879, -0.0404,  ...,  0.0090, -0.1960, -0.0853]],
       device='cuda:0', requires_grad=True)
tensor([[ 0.4326,  0.0028,  0.0335,  ..., -0.0115, -0.0565,  0.0408],
        [ 0.0028,  0.3921, -0.0140,  ...,  0.0145,  0.0403, -0.0477],
        [ 0.0335, -0.0140,  0.4308,  ..., -0.1360, -0.0887,  0.0175],
        ...,
        [-0.0115,  0.0145, -0.1360,  ...,  0.4376, -0.0943,  0.0043],
        [-0.0565,  0.0403, -0.0887,  ..., -0.0943,  0.4156,  0.0100],
        [ 0.0408, -0.0477,  0.0175,  ...,  0.0043,  0.0100,  0.2938]],
       device='cuda:0', grad

====> Epoch: 4 Average loss: 121.4545
Parameter containing:
tensor([[ 0.0313, -0.0152, -0.0297,  ..., -0.1075, -0.0438, -0.1189],
        [ 0.0880,  0.1513,  0.0462,  ...,  0.0047,  0.0570, -0.1353],
        [ 0.0050, -0.0775,  0.0979,  ..., -0.0053, -0.0194, -0.0808],
        ...,
        [ 0.0366, -0.0583,  0.0101,  ..., -0.1508,  0.0039, -0.0176],
        [ 0.0269, -0.0101,  0.0698,  ...,  0.1441,  0.0334,  0.1564],
        [-0.0065, -0.0922, -0.0435,  ...,  0.0069, -0.1793, -0.0737]],
       device='cuda:0', requires_grad=True)
tensor([[ 0.4088,  0.0094,  0.0335,  ..., -0.0074, -0.0627,  0.0384],
        [ 0.0094,  0.3633, -0.0079,  ...,  0.0006,  0.0323, -0.0550],
        [ 0.0335, -0.0079,  0.4196,  ..., -0.1350, -0.0958,  0.0154],
        ...,
        [-0.0074,  0.0006, -0.1350,  ...,  0.4120, -0.0861,  0.0056],
        [-0.0627,  0.0323, -0.0958,  ..., -0.0861,  0.3888,  0.0118],
        [ 0.0384, -0.0550,  0.0154,  ...,  0.0056,  0.0118,  0.2840]],
       device='cuda:0', grad

In [25]:
wyyt_model = try_wyyt_model(pretrained_vanilla_model, 10, 1, 20)

====> Epoch: 1 Average loss: 190.0930
Parameter containing:
tensor([[ 0.0322, -0.0085, -0.0310,  ..., -0.1124, -0.0166,  0.0965],
        [ 0.1538,  0.1287,  0.1525,  ..., -0.0918, -0.0395, -0.1108],
        [ 0.0859,  0.1611,  0.0723,  ..., -0.0747, -0.0236,  0.0823],
        ...,
        [ 0.1046,  0.1369,  0.1653,  ...,  0.0954,  0.0312, -0.1186],
        [ 0.1278, -0.0508,  0.1704,  ..., -0.1938, -0.0013, -0.1034],
        [-0.0590,  0.0271,  0.1023,  ...,  0.1311, -0.0253,  0.0442]],
       device='cuda:0', requires_grad=True)
tensor([[ 0.1234,  0.0092,  0.0066,  ...,  0.0626,  0.0228, -0.0016],
        [ 0.0092,  0.3115,  0.0170,  ...,  0.1697,  0.1557,  0.0518],
        [ 0.0066,  0.0170,  0.2254,  ..., -0.0164,  0.0361,  0.0259],
        ...,
        [ 0.0626,  0.1697, -0.0164,  ...,  0.3016,  0.0834,  0.0792],
        [ 0.0228,  0.1557,  0.0361,  ...,  0.0834,  0.2279, -0.0308],
        [-0.0016,  0.0518,  0.0259,  ...,  0.0792, -0.0308,  0.2141]],
       device='cuda:0', grad

====> Epoch: 4 Average loss: 126.2323
Parameter containing:
tensor([[ 0.0341, -0.0066, -0.0383,  ..., -0.1417, -0.0144,  0.0995],
        [ 0.1490,  0.1291,  0.1595,  ..., -0.0877, -0.0153, -0.0987],
        [ 0.0841,  0.1520,  0.0645,  ..., -0.0773, -0.0238,  0.0773],
        ...,
        [ 0.1060,  0.1279,  0.1599,  ...,  0.0882,  0.0322, -0.1262],
        [ 0.1206, -0.0561,  0.1739,  ..., -0.1825,  0.0071, -0.1033],
        [-0.0387,  0.0238,  0.0857,  ...,  0.1188, -0.0153,  0.0426]],
       device='cuda:0', requires_grad=True)
tensor([[ 0.1242,  0.0024,  0.0102,  ...,  0.0544,  0.0237, -0.0049],
        [ 0.0024,  0.2938,  0.0089,  ...,  0.1699,  0.1473,  0.0547],
        [ 0.0102,  0.0089,  0.2139,  ..., -0.0193,  0.0341,  0.0216],
        ...,
        [ 0.0544,  0.1699, -0.0193,  ...,  0.2915,  0.0892,  0.0717],
        [ 0.0237,  0.1473,  0.0341,  ...,  0.0892,  0.2226, -0.0256],
        [-0.0049,  0.0547,  0.0216,  ...,  0.0717, -0.0256,  0.1952]],
       device='cuda:0', grad

In [26]:
torch.sum(wyyt_model.Y**2, 1)

tensor([0.1222, 0.2875, 0.2090, 0.1689, 0.2546, 0.1601, 0.2567, 0.0715, 0.3131,
        0.2944, 0.1292, 0.1225, 0.2799, 0.1964, 0.1563, 0.1255, 0.1785, 0.2490,
        0.1475, 0.2081, 0.2750, 0.1586, 0.1345, 0.2663, 0.1067, 0.1539, 0.2719,
        0.1424, 0.2507, 0.1534, 0.1560, 0.2481, 0.1711, 0.1306, 0.1778, 0.1486,
        0.1932, 0.1269, 0.0964, 0.0652, 0.2008, 0.0692, 0.1528, 0.2497, 0.1951,
        0.1678, 0.1888, 0.1659, 0.3504, 0.3095, 0.1099, 0.1765, 0.2469, 0.1593,
        0.1584, 0.2014, 0.1797, 0.1924, 0.1519, 0.1755, 0.2514, 0.1705, 0.2000,
        0.1455, 0.1217, 0.2141, 0.1469, 0.1915, 0.0734, 0.3153, 0.1232, 0.2525,
        0.1684, 0.1165, 0.1545, 0.0859, 0.1934, 0.1705, 0.2485, 0.0996, 0.2260,
        0.0870, 0.1525, 0.2606, 0.2227, 0.2339, 0.0939, 0.1451, 0.1758, 0.1795,
        0.1909, 0.2582, 0.1546, 0.1334, 0.1241, 0.0888, 0.2084, 0.1335, 0.1215,
        0.1309, 0.1837, 0.1651, 0.1475, 0.2308, 0.1403, 0.1368, 0.1415, 0.1559,
        0.1038, 0.2456, 0.1150, 0.2035, 

In [27]:
np.linalg.matrix_rank(wyyt_model.Y.clone().detach().cpu().numpy())

20

In [28]:
torch.save(wyyt_model.state_dict(), BASE_PATH_DATA + "../data/models/wyyt/base.pt")

# Zeisel Data

In [29]:
import scipy.io as sio

In [30]:
a = sio.loadmat("../data/zeisel/CITEseq.mat")
data= a['G'].T
N,d=data.shape
#transformation from integer entries 
data=np.log(data+np.ones(data.shape))
for i in range(N):
    data[i,:]=data[i,:]/np.linalg.norm(data[i,:])

#load labels from file
a = sio.loadmat("../data/zeisel/CITEseq-labels.mat")
l_aux = a['labels']
labels = np.array([i for [i] in l_aux])

#load names from file
a = sio.loadmat("../data/zeisel/CITEseq_names.mat")
names=[a['citeseq_names'][i][0][0] for i in range(N)]

In [31]:
slices = np.random.permutation(np.arange(data.shape[0]))
upto = int(.8 * len(data))

In [32]:
train_data = data[slices[:upto]]
test_data = data[slices[upto:]]

In [33]:
train_data = Tensor(train_data).to(device)

In [34]:
test_data = Tensor(test_data).to(device)

In [35]:
train_data.shape

torch.Size([6893, 500])

In [36]:
# Vanilla VAE model
class VAE(nn.Module):
    def __init__(self, input_size, hidden_layer_size, z_size):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(input_size, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, input_size)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [37]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

In [38]:
vanilla_vae_zeisel = VAE(500, 250, 20)
vanilla_vae_zeisel.to(device)
vanilla_optimizer_zeisel = torch.optim.Adam(vanilla_vae_zeisel.parameters(), 
                                            lr=lr, 
                                            betas = (b1,b2))

In [39]:
torch.randperm(10)

tensor([0, 7, 3, 9, 5, 2, 6, 1, 8, 4])

In [40]:
def train(df, model, optimizer, epoch):
    model.train()
    train_loss = 0
    permutations = torch.randperm(df.shape[0])
    for i in range(math.ceil(len(df)/batch_size)):
        batch_ind = permutations[i * batch_size : (i+1) * batch_size]
        batch_data = df[batch_ind, :]
        
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(batch_data)
        loss = loss_function(recon_batch, batch_data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if i % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i * len(batch_data), len(df),
                100. * i / len(df),
                loss.item() / len(batch_data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(df)))

In [41]:
def test(df, model, epoch):
    model.eval()
    test_loss = 0
    inds = np.arange(df.shape[0])
    with torch.no_grad():
        for i in range(math.ceil(len(df)/batch_size)):
            batch_ind = inds[i * batch_size : (i+1) * batch_size]
            batch_data = df[batch_ind, :]
            batch_data = batch_data.to(device)
            recon_batch, mu, logvar = model(batch_data)
            test_loss += loss_function(recon_batch, batch_data, mu, logvar).item()


    test_loss /= len(df)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [42]:
for epoch in range(1, 50 + 1):
        train(train_data, vanilla_vae_zeisel, vanilla_optimizer_zeisel, epoch)
        #with torch.no_grad():
        #    model.diag.data[torch.abs(model.diag) < 0.05] = 0
        test(test_data, vanilla_vae_zeisel, epoch)

====> Epoch: 1 Average loss: 182.3264
====> Test set loss: 87.9929
====> Epoch: 2 Average loss: 81.2936
====> Test set loss: 76.4991
====> Epoch: 3 Average loss: 75.5499
====> Test set loss: 74.1908
====> Epoch: 4 Average loss: 73.6617
====> Test set loss: 72.8495
====> Epoch: 5 Average loss: 72.8050
====> Test set loss: 72.3773
====> Epoch: 6 Average loss: 72.0510
====> Test set loss: 71.3801
====> Epoch: 7 Average loss: 71.4381
====> Test set loss: 70.7115
====> Epoch: 8 Average loss: 70.5748
====> Test set loss: 69.6569
====> Epoch: 9 Average loss: 69.6501
====> Test set loss: 68.8323
====> Epoch: 10 Average loss: 68.6314
====> Test set loss: 67.9397
====> Epoch: 11 Average loss: 67.8105
====> Test set loss: 67.2086
====> Epoch: 12 Average loss: 67.1232
====> Test set loss: 66.6755
====> Epoch: 13 Average loss: 66.7141
====> Test set loss: 66.2104
====> Epoch: 14 Average loss: 66.4174
====> Test set loss: 65.9457
====> Epoch: 15 Average loss: 66.1391
====> Test set loss: 65.7967
===

In [43]:
len(vanilla_vae_zeisel(test_data))

3

In [44]:
vanilla_vae_zeisel(test_data)[0][0,:]

tensor([0.1300, 0.1169, 0.1031, 0.1111, 0.1288, 0.1041, 0.1105, 0.1013, 0.1040,
        0.1114, 0.1073, 0.0932, 0.1085, 0.0911, 0.1012, 0.1120, 0.0869, 0.0796,
        0.1004, 0.0977, 0.0833, 0.0863, 0.1030, 0.0971, 0.0713, 0.0845, 0.0859,
        0.0788, 0.0921, 0.0676, 0.0866, 0.0861, 0.0811, 0.0820, 0.0868, 0.0779,
        0.0941, 0.0833, 0.0899, 0.0837, 0.0884, 0.0789, 0.0905, 0.0748, 0.0742,
        0.0719, 0.0725, 0.0600, 0.0576, 0.0845, 0.0739, 0.0827, 0.0920, 0.0640,
        0.0841, 0.0720, 0.0754, 0.0693, 0.0641, 0.0664, 0.0664, 0.0684, 0.0542,
        0.0502, 0.0626, 0.0655, 0.0713, 0.0687, 0.0729, 0.0648, 0.0593, 0.0539,
        0.0571, 0.0558, 0.0655, 0.0750, 0.0617, 0.0596, 0.0562, 0.0592, 0.0609,
        0.0457, 0.0585, 0.0497, 0.0555, 0.0483, 0.0587, 0.0495, 0.0546, 0.0554,
        0.0479, 0.0445, 0.0469, 0.0496, 0.0550, 0.0442, 0.0465, 0.0441, 0.0484,
        0.0442, 0.0488, 0.0470, 0.0556, 0.0459, 0.0405, 0.0523, 0.0502, 0.0352,
        0.0390, 0.0396, 0.0415, 0.0428, 

In [45]:
test_data[0,:]

tensor([0.1385, 0.1385, 0.1177, 0.1045, 0.1357, 0.1259, 0.1327, 0.1311, 0.0847,
        0.1259, 0.1075, 0.1129, 0.1177, 0.0938, 0.0730, 0.1372, 0.0847, 0.0938,
        0.1154, 0.1129, 0.0977, 0.1012, 0.1012, 0.1199, 0.0793, 0.1075, 0.1129,
        0.0847, 0.0938, 0.0282, 0.1045, 0.0938, 0.0282, 0.0977, 0.0977, 0.0793,
        0.1240, 0.0977, 0.0977, 0.0793, 0.0847, 0.0847, 0.1045, 0.0793, 0.0847,
        0.0793, 0.0656, 0.0656, 0.0565, 0.1103, 0.0847, 0.0895, 0.0895, 0.0565,
        0.0895, 0.0282, 0.1012, 0.0656, 0.0730, 0.0793, 0.1154, 0.0793, 0.0000,
        0.0282, 0.0730, 0.0847, 0.0938, 0.0895, 0.1129, 0.0565, 0.0793, 0.0282,
        0.0565, 0.0000, 0.0730, 0.1012, 0.0000, 0.0656, 0.0847, 0.0282, 0.0847,
        0.0282, 0.0282, 0.0793, 0.0565, 0.0282, 0.0847, 0.0656, 0.0730, 0.0656,
        0.0282, 0.0000, 0.0565, 0.0000, 0.0656, 0.0282, 0.0282, 0.0656, 0.0565,
        0.0565, 0.0565, 0.0000, 0.0656, 0.0000, 0.0448, 0.0565, 0.0730, 0.0448,
        0.0282, 0.0282, 0.0000, 0.0448, 

In [46]:
vanilla_vae_zeisel(test_data)[0][0,:] - test_data[0,:]

tensor([-0.0107, -0.0248, -0.0097,  0.0067, -0.0063, -0.0252, -0.0202, -0.0344,
         0.0115, -0.0232, -0.0051, -0.0195, -0.0180, -0.0030,  0.0317, -0.0294,
         0.0099, -0.0133, -0.0272, -0.0228, -0.0165, -0.0209, -0.0006, -0.0278,
        -0.0016, -0.0288, -0.0261, -0.0083, -0.0093,  0.0366, -0.0214, -0.0115,
         0.0689, -0.0151, -0.0166, -0.0049, -0.0379, -0.0147, -0.0143, -0.0020,
        -0.0057, -0.0079, -0.0210, -0.0085, -0.0118, -0.0112,  0.0012,  0.0044,
         0.0045, -0.0335, -0.0182, -0.0146, -0.0097,  0.0123, -0.0166,  0.0521,
        -0.0356,  0.0013, -0.0058, -0.0183, -0.0495, -0.0165,  0.0675,  0.0287,
        -0.0182, -0.0270, -0.0314, -0.0262, -0.0502,  0.0059, -0.0245,  0.0268,
         0.0031,  0.0594, -0.0153, -0.0368,  0.0545, -0.0141, -0.0335,  0.0261,
        -0.0244,  0.0228,  0.0253, -0.0314, -0.0056,  0.0187, -0.0311, -0.0167,
        -0.0238, -0.0177,  0.0156,  0.0455, -0.0111,  0.0586, -0.0066,  0.0195,
         0.0127, -0.0107, -0.0105, -0.01

In [47]:
torch.sum((vanilla_vae_zeisel(test_data)[0][0,:] - test_data[0,:])**2)

tensor(0.1969, device='cuda:0', grad_fn=<SumBackward0>)

In [48]:
torch.sum((vanilla_vae_zeisel(train_data)[0] - train_data)**2) / len(train_data)

tensor(0.2585, device='cuda:0', grad_fn=<DivBackward0>)

In [49]:
torch.sum((vanilla_vae_zeisel(test_data)[0] - test_data)**2) / len(test_data)

tensor(0.2595, device='cuda:0', grad_fn=<DivBackward0>)

In [50]:
torch.save(vanilla_vae_zeisel.state_dict(), BASE_PATH_DATA + "../data/models/zeisel/vanilla.pt")

In [52]:
#vanilla_vae_zeisel.require_grad_(False)
for param in vanilla_vae_zeisel.parameters():
    param.requires_grad_(False)

## Let's try L1 Diag with L2 Norm on Pretrained

In [53]:
# L1 VAE model we are loading
class VAE_l1_diag(nn.Module):
    def __init__(self, input_size, hidden_layer_size, z_size):
        super(VAE_l1_diag, self).__init__()
        
        self.diag = nn.Parameter(torch.normal(torch.zeros(input_size), 
                                 torch.ones(input_size)).to(device).requires_grad_(True))
        
        self.selection_layer = torch.diag(self.diag)
        self.fc1 = nn.Linear(input_size, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, input_size)

    def encode(self, x):
        h0 = torch.mm(x, self.selection_layer)
        h1 = F.relu(self.fc1(h0))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [54]:
def train_l1_diag_pretrained(df, model, pretrained_model, optimizer, epoch, reg_lambda_l1, reg_lambda_latent):
    model.train()
    train_loss = 0
    
    permutations = torch.randperm(df.shape[0])
    for i in range(math.ceil(len(df)/batch_size)):
        batch_ind = permutations[i * batch_size : (i+1) * batch_size]
        batch_data = df[batch_ind, :]
        
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(batch_data)
        loss = loss_function(recon_batch, batch_data, mu, logvar)
        
        l1_norm = reg_lambda_l1 * torch.norm(model.diag, p=1)
        loss += l1_norm
        
        h1_pretrained = F.relu(pretrained_model.fc1(batch_data))
        h0 = torch.mm(batch_data, model.selection_layer)
        h1_model = F.relu(model.fc1(h0))
        l2_loss = reg_lambda_latent * F.mse_loss(h1_model, h1_pretrained)
        loss += l2_loss
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        
        if i % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} \t L2 Loss:{}'.format(
                epoch, i * len(batch_data), len(df),
                100. * i / len(df),
                loss.item() / len(batch_data), l2_loss.item() / len(batch_data)))
            
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(df)))

In [55]:
def test(df, model, epoch):
    model.eval()
    test_loss = 0
    inds = np.arange(df.shape[0])
    with torch.no_grad():
        for i in range(math.ceil(len(df)/batch_size)):
            batch_ind = inds[i * batch_size : (i+1) * batch_size]
            batch_data = df[batch_ind, :]
            batch_data = batch_data.to(device)
            recon_batch, mu, logvar = model(batch_data)
            test_loss += loss_function(recon_batch, batch_data, mu, logvar).item()


    test_loss /= len(df)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [56]:
vae_l1_diag_zeisel = VAE_l1_diag(500, 250, 20)
vae_l1_diag_zeisel.to(device)
vae_l1_diag_zeisel_optimizer = torch.optim.Adam(vae_l1_diag_zeisel.parameters(), 
                                                lr=lr, 
                                                betas = (b1,b2))

In [57]:
for epoch in range(1, 50 + 1):
        train_l1_diag_pretrained(train_data, vae_l1_diag_zeisel, 
                                 vanilla_vae_zeisel, vae_l1_diag_zeisel_optimizer, epoch, 1, 1)
        #with torch.no_grad():
        #    model.diag.data[torch.abs(model.diag) < 0.05] = 0
        test(test_data, vae_l1_diag_zeisel, epoch)

====> Epoch: 1 Average loss: 193.4225
====> Test set loss: 88.7991
====> Epoch: 2 Average loss: 87.9029
====> Test set loss: 76.7371
====> Epoch: 3 Average loss: 81.5598
====> Test set loss: 74.3858
====> Epoch: 4 Average loss: 79.7597
====> Test set loss: 73.2635
====> Epoch: 5 Average loss: 78.6583
====> Test set loss: 72.2839
====> Epoch: 6 Average loss: 78.2305
====> Test set loss: 72.2003
====> Epoch: 7 Average loss: 77.2982
====> Test set loss: 71.3937
====> Epoch: 8 Average loss: 76.7368
====> Test set loss: 70.7223
====> Epoch: 9 Average loss: 75.9334
====> Test set loss: 70.0074
====> Epoch: 10 Average loss: 75.0348
====> Test set loss: 69.0656
====> Epoch: 11 Average loss: 74.0992
====> Test set loss: 68.4087
====> Epoch: 12 Average loss: 73.2727
====> Test set loss: 67.8040
====> Epoch: 13 Average loss: 72.5573
====> Test set loss: 67.0638
====> Epoch: 14 Average loss: 71.8036
====> Test set loss: 66.4230
====> Epoch: 15 Average loss: 71.1446
====> Test set loss: 65.9010
===

====> Epoch: 38 Average loss: 66.0989
====> Test set loss: 63.4080
====> Epoch: 39 Average loss: 65.9970
====> Test set loss: 63.3508
====> Epoch: 40 Average loss: 65.8867
====> Test set loss: 63.3442
====> Epoch: 41 Average loss: 65.7838
====> Test set loss: 63.3209
====> Epoch: 42 Average loss: 65.7153
====> Test set loss: 63.2676
====> Epoch: 43 Average loss: 65.6176
====> Test set loss: 63.2789
====> Epoch: 44 Average loss: 65.5152
====> Test set loss: 63.2705
====> Epoch: 45 Average loss: 65.4288
====> Test set loss: 63.2181
====> Epoch: 46 Average loss: 65.3522
====> Test set loss: 63.2289
====> Epoch: 47 Average loss: 65.2812
====> Test set loss: 63.2318
====> Epoch: 48 Average loss: 65.1974
====> Test set loss: 63.1955
====> Epoch: 49 Average loss: 65.1203
====> Test set loss: 63.1944
====> Epoch: 50 Average loss: 65.0550
====> Test set loss: 63.1597


In [58]:
torch.save(vae_l1_diag_zeisel.state_dict(), BASE_PATH_DATA + "../data/models/zeisel/l1_diag_pretrained.pt")

### Let's see what is empty

In [59]:
np.sum(np.abs(vae_l1_diag_zeisel.diag.clone().detach().cpu().numpy()) < 1e-4)

294

In [60]:
np.where(np.abs(vae_l1_diag_zeisel.diag.clone().detach().cpu().numpy()) < 1e-4)[0]

array([  0,   6,   7,   8,  14,  15,  17,  19,  20,  23,  24,  27,  28,
        29,  31,  32,  34,  35,  36,  37,  41,  42,  44,  45,  47,  48,
        50,  51,  54,  55,  56,  60,  61,  64,  65,  67,  69,  70,  71,
        72,  73,  75,  77,  80,  81,  82,  85,  88,  90,  91,  92,  97,
        98, 100, 102, 103, 109, 110, 112, 113, 115, 116, 119, 120, 121,
       122, 123, 124, 126, 128, 129, 131, 133, 134, 135, 136, 137, 139,
       142, 144, 145, 146, 148, 151, 153, 155, 156, 159, 162, 163, 164,
       165, 166, 169, 170, 171, 174, 175, 178, 179, 180, 181, 182, 184,
       188, 191, 192, 193, 194, 196, 197, 198, 200, 202, 204, 205, 206,
       208, 209, 210, 211, 218, 219, 220, 222, 223, 227, 230, 231, 233,
       234, 235, 237, 238, 240, 241, 243, 244, 245, 246, 248, 249, 250,
       253, 254, 255, 256, 259, 260, 261, 262, 266, 267, 268, 271, 274,
       275, 277, 278, 279, 280, 281, 284, 286, 288, 290, 295, 298, 302,
       304, 307, 308, 309, 312, 314, 317, 318, 319, 321, 322, 32

In [61]:
with torch.no_grad():
    print(torch.sum((vae_l1_diag_zeisel(train_data)[0][1,:] - train_data[1,:])**2))

tensor(0.3108, device='cuda:0')


In [62]:
with torch.no_grad():
    print(torch.sum((vae_l1_diag_zeisel(test_data)[0][1,:] - test_data[1,:])**2))

tensor(0.2359, device='cuda:0')


In [63]:
torch.sum((vae_l1_diag_zeisel(train_data)[0] - train_data)**2) / len(train_data)

tensor(0.2598, device='cuda:0', grad_fn=<DivBackward0>)

In [64]:
torch.sum((vae_l1_diag_zeisel(test_data)[0] - test_data)**2) / len(test_data)

tensor(0.2618, device='cuda:0', grad_fn=<DivBackward0>)

In [65]:
vae_l1_diag_zeisel(test_data)[0][0,:] - test_data[0,:]

tensor([-0.0087, -0.0197, -0.0243,  0.0025, -0.0151, -0.0197, -0.0204, -0.0264,
         0.0060, -0.0169, -0.0036, -0.0310, -0.0103, -0.0082,  0.0123, -0.0252,
        -0.0212, -0.0169, -0.0147, -0.0240, -0.0229, -0.0153, -0.0207, -0.0296,
        -0.0117, -0.0420, -0.0319, -0.0153, -0.0104,  0.0368, -0.0238, -0.0040,
         0.0273, -0.0151, -0.0126, -0.0090, -0.0385, -0.0167, -0.0130, -0.0050,
        -0.0026, -0.0063, -0.0243, -0.0010, -0.0075, -0.0208,  0.0030, -0.0096,
        -0.0161, -0.0247, -0.0305, -0.0171, -0.0051,  0.0037, -0.0110,  0.0363,
        -0.0349,  0.0102, -0.0117, -0.0174, -0.0555, -0.0111,  0.0344,  0.0143,
        -0.0162, -0.0299, -0.0345, -0.0158, -0.0458,  0.0134, -0.0290,  0.0135,
        -0.0100,  0.0473, -0.0182, -0.0309,  0.0593, -0.0152, -0.0326,  0.0047,
        -0.0278,  0.0012,  0.0216, -0.0439, -0.0059,  0.0015, -0.0265, -0.0170,
        -0.0174, -0.0144,  0.0199,  0.0338, -0.0176,  0.0311, -0.0201,  0.0049,
         0.0098, -0.0409, -0.0130, -0.02

In [66]:
test_data[0,:]

tensor([0.1385, 0.1385, 0.1177, 0.1045, 0.1357, 0.1259, 0.1327, 0.1311, 0.0847,
        0.1259, 0.1075, 0.1129, 0.1177, 0.0938, 0.0730, 0.1372, 0.0847, 0.0938,
        0.1154, 0.1129, 0.0977, 0.1012, 0.1012, 0.1199, 0.0793, 0.1075, 0.1129,
        0.0847, 0.0938, 0.0282, 0.1045, 0.0938, 0.0282, 0.0977, 0.0977, 0.0793,
        0.1240, 0.0977, 0.0977, 0.0793, 0.0847, 0.0847, 0.1045, 0.0793, 0.0847,
        0.0793, 0.0656, 0.0656, 0.0565, 0.1103, 0.0847, 0.0895, 0.0895, 0.0565,
        0.0895, 0.0282, 0.1012, 0.0656, 0.0730, 0.0793, 0.1154, 0.0793, 0.0000,
        0.0282, 0.0730, 0.0847, 0.0938, 0.0895, 0.1129, 0.0565, 0.0793, 0.0282,
        0.0565, 0.0000, 0.0730, 0.1012, 0.0000, 0.0656, 0.0847, 0.0282, 0.0847,
        0.0282, 0.0282, 0.0793, 0.0565, 0.0282, 0.0847, 0.0656, 0.0730, 0.0656,
        0.0282, 0.0000, 0.0565, 0.0000, 0.0656, 0.0282, 0.0282, 0.0656, 0.0565,
        0.0565, 0.0565, 0.0000, 0.0656, 0.0000, 0.0448, 0.0565, 0.0730, 0.0448,
        0.0282, 0.0282, 0.0000, 0.0448, 

In [67]:
vae_l1_diag_zeisel(test_data)[0][0,:]

tensor([0.1328, 0.1095, 0.1011, 0.1082, 0.1289, 0.1018, 0.1080, 0.1060, 0.0968,
        0.1029, 0.0988, 0.0857, 0.1108, 0.0834, 0.0940, 0.1175, 0.0713, 0.0743,
        0.1012, 0.0915, 0.0811, 0.0864, 0.0897, 0.0983, 0.0750, 0.0731, 0.0852,
        0.0768, 0.0766, 0.0651, 0.0875, 0.0818, 0.0739, 0.0837, 0.0877, 0.0777,
        0.0951, 0.0808, 0.0833, 0.0821, 0.0962, 0.0761, 0.0925, 0.0834, 0.0726,
        0.0672, 0.0778, 0.0509, 0.0516, 0.0754, 0.0567, 0.0833, 0.0823, 0.0660,
        0.0755, 0.0665, 0.0732, 0.0705, 0.0645, 0.0564, 0.0645, 0.0658, 0.0435,
        0.0477, 0.0540, 0.0535, 0.0669, 0.0733, 0.0700, 0.0596, 0.0594, 0.0479,
        0.0593, 0.0460, 0.0571, 0.0703, 0.0596, 0.0562, 0.0519, 0.0480, 0.0543,
        0.0367, 0.0524, 0.0452, 0.0566, 0.0409, 0.0601, 0.0437, 0.0554, 0.0562,
        0.0484, 0.0436, 0.0455, 0.0461, 0.0551, 0.0342, 0.0414, 0.0342, 0.0443,
        0.0377, 0.0431, 0.0375, 0.0500, 0.0386, 0.0366, 0.0548, 0.0426, 0.0333,
        0.0395, 0.0367, 0.0346, 0.0408, 