2/24

Try Binarize

Try YYT

Try on Zeisel

In [1]:
import torch
from torch.utils.data import DataLoader

from torchvision import datasets
import torchvision.transforms as transforms


from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F

import numpy as np

from torchvision.utils import save_image

import matplotlib.pyplot as plt

import math

In [2]:
import os
from os import listdir

In [3]:
# Base directory (relative to the notebook) that the MNIST download and the
# model checkpoint paths below are built from.
BASE_PATH_DATA = '../data/'

In [4]:
# Number of passes over MNIST in the training loops below.
n_epochs = 5
# Mini-batch size used by the MNIST loaders and by the manual Zeisel batching.
batch_size = 64
# Adam learning rate and (beta1, beta2) coefficients.
lr = 0.0002
b1 = 0.5
b2 = 0.999
# Side length handed to transforms.Resize for the MNIST images.
img_size = 28
# NOTE(review): `channels` is not referenced in the visible part of the notebook.
channels = 1

# Training progress is printed every `log_interval` batches.
log_interval = 100


# Latent dimension. NOTE(review): the model constructors below pass 20 literally
# rather than reading this constant.
z_size = 20

# Number of pixels in a flattened 28x28 image.
n = 28 * 28

In [5]:
# Detect GPU support once; the tensor constructor and the device follow from it.
cuda = bool(torch.cuda.is_available())

if cuda:
    Tensor = torch.cuda.FloatTensor
    device = torch.device("cuda:0")
else:
    Tensor = torch.FloatTensor
    device = torch.device("cpu")

print(cuda)

True


In [6]:
def _make_mnist_loader(root, is_train):
    """Return a shuffled DataLoader over one MNIST split stored (or downloaded) under `root`.

    Images are resized to `img_size` and converted to tensors; batches hold
    `batch_size` images.
    """
    preprocessing = transforms.Compose(
        [transforms.Resize(img_size), transforms.ToTensor()]
    )
    split = datasets.MNIST(root, train=is_train, download=True, transform=preprocessing)
    return torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True)


train_loader = _make_mnist_loader(BASE_PATH_DATA + '/mnist/train', True)
test_loader = _make_mnist_loader(BASE_PATH_DATA + '/mnist/test', False)

In [7]:
def loss_function(recon_x, x, mu, logvar):
    """VAE training loss: reconstruction BCE plus KL divergence, both summed (not averaged).

    Args:
        recon_x: decoder output, compared against `x` flattened to rows of 784 values.
        x: input batch; reshaped with ``view(-1, 784)``.
        mu: latent means produced by the encoder.
        logvar: latent log-variances produced by the encoder.

    Returns:
        Scalar tensor ``BCE + KLD``.
    """
    flat_target = x.view(-1, 784)
    reconstruction_term = F.binary_cross_entropy(recon_x, flat_target, reduction='sum')

    # Closed-form KL term, see Appendix B of Kingma and Welling,
    # "Auto-Encoding Variational Bayes", ICLR 2014 (https://arxiv.org/abs/1312.6114):
    #   KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * kl_inner.sum()

    return reconstruction_term + kl_term

In [8]:
# Vanilla VAE model
class VAE(nn.Module):
    """Fully connected VAE over flattened 784-value inputs.

    Encoder: 784 -> hidden (ReLU) -> (mu, logvar), each of size `z_size`.
    Decoder: z -> hidden (ReLU) -> 784 (sigmoid).
    """

    def __init__(self, hidden_layer_size, z_size):
        super().__init__()

        # Layers are created in the order fc1, fc21, fc22, fc3, fc4.
        self.fc1 = nn.Linear(784, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, 784)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened inputs."""
        hidden = F.relu(self.fc1(x))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors to sigmoid outputs of width 784."""
        hidden = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(hidden))

    def forward(self, x):
        """Flatten `x` to rows of 784 values and return ``(reconstruction, mu, logvar)``."""
        flat = x.view(-1, 784)
        mu, logvar = self.encode(flat)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar


In [9]:
# L1 VAE model we are loading
class VAE_l1_diag(nn.Module):
    """VAE whose 784-value input is first multiplied by ``diag(self.diag)``.

    `diag` is a learnable vector with one weight per input value; the rest of
    the network is the same fully connected encoder/decoder as the vanilla VAE.

    The selection matrix is rebuilt from `diag` on every `encode` call. Building
    it only once in ``__init__`` would leave a stale snapshot: values copied in
    by ``load_state_dict`` or changed by an optimizer step would never reach the
    forward pass.
    """

    def __init__(self, hidden_layer_size, z_size):
        super(VAE_l1_diag, self).__init__()

        # One weight per input value, drawn from N(0, 1) and created on `device`.
        self.diag = nn.Parameter(torch.normal(torch.zeros(784),
                                 torch.ones(784)).to(device).requires_grad_(True))

        # Kept so the attribute exists right after construction; `encode`
        # replaces it with a matrix built from the current `diag`.
        self.selection_layer = torch.diag(self.diag.detach())
        self.fc1 = nn.Linear(784, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, 784)

    def encode(self, x):
        """Scale `x` by the current `diag`, then return ``(mu, logvar)``."""
        # Rebuilt on every call so it always reflects the current parameter values.
        self.selection_layer = torch.diag(self.diag)
        h0 = torch.mm(x, self.selection_layer)
        h1 = F.relu(self.fc1(h0))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Map latent vectors to sigmoid outputs of width 784."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Flatten `x` to rows of 784 values and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Try a binarized model

In [10]:
# the model we are loading
class VAE_binary_diag(nn.Module):
    """VAE whose 784-value input is first multiplied by ``diag(self.diag)``, with `diag` frozen.

    `indices_diag` supplies the per-input weights; they are stored as a
    parameter with ``requires_grad=False``, so they are saved in the state dict
    but not trained.

    The selection matrix is rebuilt from `diag` on every `encode` call, so a
    mask copied in by ``load_state_dict`` is actually used by the forward pass.
    """

    def __init__(self, hidden_layer_size, z_size, indices_diag):
        super(VAE_binary_diag, self).__init__()

        # Clone so the parameter does not share storage with the caller's tensor
        # (``.to(device)`` alone returns the same tensor when it is already there).
        self.diag = nn.Parameter(indices_diag.detach().clone().to(device), requires_grad=False)

        # Available right after construction; `encode` refreshes it from `diag`.
        self.selection_layer = torch.diag(self.diag).requires_grad_(False)
        self.fc1 = nn.Linear(784, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, 784)

    def encode(self, x):
        """Scale `x` by the current `diag`, then return ``(mu, logvar)``."""
        self.selection_layer = torch.diag(self.diag)
        h0 = torch.mm(x, self.selection_layer)
        h1 = F.relu(self.fc1(h0))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Map latent vectors to sigmoid outputs of width 784."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Flatten `x` to rows of 784 values and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


Load the diag model that was trained with $\lambda = 100$ on the L1 norm, without using the pre-trained model.

In [11]:
# Rebuild the architecture, restore the saved weights, and freeze the model.
pretrained_l1_diag_model = VAE_l1_diag(400, 20).to(device)
# map_location keeps this cell working on a CPU-only machine when the checkpoint
# was saved from a GPU.
pretrained_l1_diag_model.load_state_dict(
    torch.load(BASE_PATH_DATA + "../data/models/with_regularization/l1_norm_diag_100_lambda.pt",
               map_location=device))
pretrained_l1_diag_model.eval()
pretrained_l1_diag_model.requires_grad_(False)

VAE_l1_diag(
  (fc1): Linear(in_features=784, out_features=400, bias=True)
  (fc21): Linear(in_features=400, out_features=20, bias=True)
  (fc22): Linear(in_features=400, out_features=20, bias=True)
  (fc3): Linear(in_features=20, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=784, bias=True)
)

In [12]:
# True for every input whose learned weight is not below 1e-4 in absolute value.
mask = np.logical_not(np.abs(pretrained_l1_diag_model.diag.data.cpu().numpy()) < 1e-4)

In [13]:
# Turn the boolean mask into a float vector on `device`: 1 where the mask is True, 0 elsewhere.
diag_indices = torch.zeros(len(mask)).to(device)
diag_indices.masked_fill_(Tensor(mask).to(torch.bool), 1)

tensor([0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 1.,
        0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1.,
        0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1.,
        1., 0., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 1., 1., 1.,
        1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 1., 1., 1.,
        0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 1., 1., 0., 0.,
        1., 0., 1., 0., 1., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0., 1., 0., 1.,
        1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
        1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 0., 1., 1., 1., 0., 0., 0., 0.,
        0., 1., 1., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1.,
        1., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 1., 1., 0., 0., 0.,
        1., 1., 0., 0., 0., 0., 1., 0., 

In [14]:
# VAE with the fixed 0/1 selection vector (400 hidden units, 20 latent dimensions)
# and the Adam optimizer used to train it.
model_binary_diag = VAE_binary_diag(400, 20, diag_indices).to(device)
optimizer_binary_diag = torch.optim.Adam(model_binary_diag.parameters(), lr=lr, betas = (b1,b2))

In [15]:
def train(model, optimizer, epoch):
    """Run one training epoch over `train_loader`.

    Every `log_interval` batches the per-sample loss of the current batch is
    printed; at the end the loss averaged over the whole dataset is printed.

    Args:
        model: module whose forward returns ``(reconstruction, mu, logvar)``.
        optimizer: optimizer stepping `model`'s parameters.
        epoch: epoch number, used only in the printed messages.
    """
    model.train()
    dataset_size = len(train_loader.dataset)
    batches_per_epoch = len(train_loader)
    running_loss = 0
    for step, (images, _) in enumerate(train_loader):
        images = images.to(device)
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(images)
        batch_loss = loss_function(reconstruction, images, mu, logvar)
        batch_loss.backward()
        running_loss += batch_loss.item()
        optimizer.step()
        if step % log_interval != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(images), dataset_size,
            100. * step / batches_per_epoch,
            batch_loss.item() / len(images)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, running_loss / dataset_size))

In [16]:
def test(model, epoch):
    """Evaluate `model` on `test_loader` and save a reconstruction grid for the first batch.

    Prints the loss averaged over the test set. For the first batch, up to 8
    inputs and their reconstructions are written to
    ``../data/binarized_diag/reconstruction_<epoch>.png``.

    Args:
        model: module whose forward returns ``(reconstruction, mu, logvar)``.
        epoch: epoch number, used in the image file name.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Let view() infer the batch dimension: hard-coding the global
                # batch_size raises whenever this batch holds fewer samples.
                comparison = torch.cat([data[:n],
                                      recon_batch.view(-1, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         '../data/binarized_diag/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [17]:
# Seed torch's RNG before training and sampling.
torch.manual_seed(123)
for epoch in range(1, n_epochs + 1):
        train(model_binary_diag, optimizer_binary_diag, epoch)
        test(model_binary_diag, epoch)
        # Decode 64 standard-normal latent vectors of size 20 and save the
        # decoded 28x28 images for this epoch.
        with torch.no_grad():
            sample = torch.randn(64, 20).to(device)
            sample = model_binary_diag.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       '../data/binarized_diag/sample_' + str(epoch) + '.png')

====> Epoch: 1 Average loss: 198.6112
====> Test set loss: 155.5003
====> Epoch: 2 Average loss: 145.2190
====> Test set loss: 135.3320
====> Epoch: 3 Average loss: 131.5038
====> Test set loss: 126.3548
====> Epoch: 4 Average loss: 124.9050
====> Test set loss: 121.5326
====> Epoch: 5 Average loss: 121.0645
====> Test set loss: 118.7398


Uses the results from the L1 diag model trained with $\lambda = 100$, with no attempt to match the pre-trained latent.

In [18]:
# Persist the trained weights (including the frozen `diag` vector) of the binarized model.
torch.save(model_binary_diag.state_dict(), BASE_PATH_DATA + "../data/models/binarized_diag/base.pt")

Reconstruction looks slightly worse, but not by much.

# Try a Squeeze-Fit-like formulation with $W = YY^T$

$W$ is the selection layer

We are also going to match the pretrained vanilla model.

In [19]:
class VAE_rank_k_selection(nn.Module):
    """VAE whose input is first multiplied by the selection matrix ``W = Y Y^T``.

    `Y` is a learnable ``(n, k)`` matrix, so `W` is ``(n, n)`` with rank at most
    `k`. The encoder/decoder around it is the same fully connected VAE as the
    vanilla model.

    Args:
        hidden_layer_size: width of the hidden layers.
        z_size: latent dimension.
        k: number of columns of `Y`.
        n: input width; sizes `Y`, the first and last linear layers, and the
            flattening done in `forward`. Defaults to 784.
    """

    def __init__(self, hidden_layer_size, z_size, k, n = 784):
        super(VAE_rank_k_selection, self).__init__()

        self.input_size = n

        # Entries of Y are drawn from N(0, 0.1^2) before any layer is created,
        # then moved to `device`.
        Y = torch.normal(mean = 0, std = 1/10*torch.ones(k*n).reshape(n, k)).to(device)
        self.Y = nn.Parameter(Y.detach().clone(), requires_grad=True)

        # Layer widths follow `n` instead of a hard-coded 784, so the `n`
        # argument is honoured everywhere.
        self.fc1 = nn.Linear(n, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, n)

    def encode(self, x):
        """Apply ``W = Y Y^T`` to `x`, then return ``(mu, logvar)``.

        The matrix is stored on ``self.selection_layer`` so callers can reuse
        the one built during the latest forward pass.
        """
        self.selection_layer = torch.mm(self.Y, torch.t(self.Y))
        h0 = torch.mm(x, self.selection_layer)
        h1 = F.relu(self.fc1(h0))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Map latent vectors to sigmoid outputs of width `n`."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Flatten `x` to rows of `n` values and return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x.view(-1, self.input_size))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [20]:
# Rebuild the vanilla VAE, restore its saved weights, and freeze it.
pretrained_vanilla_model = VAE(400, 20).to(device)
# map_location keeps this cell working on a CPU-only machine when the checkpoint
# was saved from a GPU.
pretrained_vanilla_model.load_state_dict(
    torch.load(BASE_PATH_DATA + "../data/models/first_try/no_norm.pt", map_location=device))
pretrained_vanilla_model.eval()
pretrained_vanilla_model.requires_grad_(False)

# Explicit per-parameter freeze, kept alongside the module-level call above.
for param in pretrained_vanilla_model.parameters():
    param.requires_grad = False

In [21]:
def train_wyyt_pretrained(model, pretrained_model, optimizer, epoch, reg_lambda_1, reg_lambda_latent, k, n):
    """Run one training epoch of the ``W = Y Y^T`` model over `train_loader`.

    The loss is the sum of three terms:
      * the VAE loss from `loss_function`;
      * `reg_lambda_1` times the MSE between the squared row norms of
        ``model.Y`` and the constant ``k / n``;
      * `reg_lambda_latent` times the MSE between the first hidden activations
        of `model` (after its selection layer) and those of `pretrained_model`.

    Args:
        model: model exposing `Y`, `selection_layer` and `fc1`.
        pretrained_model: reference model exposing `fc1`.
        optimizer: optimizer stepping `model`'s parameters.
        epoch: epoch number, used only in the printed messages.
        reg_lambda_1: weight of the row-norm penalty.
        reg_lambda_latent: weight of the hidden-activation matching penalty.
        k: numerator of the row-norm target ``k / n``.
        n: length of the row-norm target vector.
    """
    model.train()
    dataset_size = len(train_loader.dataset)
    batches_per_epoch = len(train_loader)
    running_loss = 0
    for step, (images, _) in enumerate(train_loader):
        images = images.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(images)
        vae_loss = loss_function(recon_batch, images, mu, logvar)

        # Penalise squared row norms of Y that deviate from k / n.
        row_sq_norms = (model.Y ** 2).sum(dim=1)
        target_norms = torch.ones(n).to(device) * k / n
        selection_penalty = reg_lambda_1 * F.mse_loss(row_sq_norms, target_norms)

        # Match the reference model's first hidden activations; the selection
        # matrix is the one the forward pass above just stored on the model.
        flat_images = images.view(-1, 784)
        reference_hidden = F.relu(pretrained_model.fc1(flat_images))
        selected = torch.mm(flat_images, model.selection_layer)
        model_hidden = F.relu(model.fc1(selected))
        latent_penalty = reg_lambda_latent * F.mse_loss(model_hidden, reference_hidden)

        total_loss = vae_loss + selection_penalty + latent_penalty

        total_loss.backward()
        running_loss += total_loss.item()
        optimizer.step()

        if step % log_interval != 0:
            continue
        n_in_batch = len(images)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} \t L2 Loss Selection: {}\t L2 Loss Latent:{}'.format(
            epoch, step * n_in_batch, dataset_size,
            100. * step / batches_per_epoch,
            total_loss.item() / n_in_batch,
            selection_penalty.item() / n_in_batch,
            latent_penalty.item() / n_in_batch))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, running_loss / dataset_size))

In [22]:
def test(model, epoch):
    """Evaluate `model` on `test_loader` and save a reconstruction grid for the first batch.

    Prints the loss averaged over the test set. For the first batch, up to 8
    inputs and their reconstructions are written to
    ``../data/w_yyt_model_results/reconstruction_<epoch>.png``.

    Args:
        model: module whose forward returns ``(reconstruction, mu, logvar)``.
        epoch: epoch number, used in the image file name.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Let view() infer the batch dimension: hard-coding the global
                # batch_size raises whenever this batch holds fewer samples.
                comparison = torch.cat([data[:n],
                                      recon_batch.view(-1, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         '../data/w_yyt_model_results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [23]:
def try_wyyt_model(pretrained_model, reg_lambda_1, reg_lambda_latent, k, n = 784):
    """Train a fresh ``W = Y Y^T`` VAE for `n_epochs` epochs and return it.

    After every epoch the current `Y` and selection matrix are printed, the
    model is evaluated with `test`, and 64 decoded prior samples are saved to
    ``../data/w_yyt_model_results/sample_<epoch>.png``.

    Args:
        pretrained_model: reference model whose hidden activations are matched.
        reg_lambda_1: weight of the row-norm penalty on `Y`.
        reg_lambda_latent: weight of the hidden-activation matching penalty.
        k: number of columns of `Y`.
        n: number of rows of `Y` and length of the row-norm target.
            NOTE(review): values other than 784 also require the model layers
            and the data to use that width.

    Returns:
        The trained model.
    """
    torch.manual_seed(123)

    # Forward `n` so Y has as many rows as the regulariser target has entries;
    # previously the model always used its own default.
    model = VAE_rank_k_selection(400, 20, k, n).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas = (b1,b2))
    for epoch in range(1, n_epochs + 1):
        train_wyyt_pretrained(model, pretrained_model, optimizer, epoch, reg_lambda_1, reg_lambda_latent, k, n)
        print(model.Y)
        print(model.selection_layer)
        test(model, epoch)
        # Decode 64 standard-normal latent vectors of size 20 and save them.
        with torch.no_grad():
            sample = torch.randn(64, 20).to(device)
            sample = model.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       '../data/w_yyt_model_results/sample_' + str(epoch) + '.png')
    return model

In [24]:
# reg_lambda_1 = 1, reg_lambda_latent = 1, k = 40 (n left at its default of 784).
wyyt_model = try_wyyt_model(pretrained_vanilla_model, 1, 1, 40)

====> Epoch: 1 Average loss: 184.3343
Parameter containing:
tensor([[ 0.0227, -0.0165, -0.0335,  ..., -0.1031, -0.0465, -0.1142],
        [ 0.0889,  0.1494,  0.0608,  ...,  0.0072,  0.0432, -0.1444],
        [ 0.0244, -0.0543,  0.0986,  ..., -0.0134, -0.0309, -0.0861],
        ...,
        [ 0.0253, -0.0566,  0.0148,  ..., -0.1599, -0.0051, -0.0221],
        [ 0.0320, -0.0257,  0.0814,  ...,  0.1534,  0.0273,  0.1593],
        [-0.0094, -0.0879, -0.0404,  ...,  0.0090, -0.1960, -0.0853]],
       device='cuda:0', requires_grad=True)
tensor([[ 0.4326,  0.0028,  0.0335,  ..., -0.0115, -0.0565,  0.0408],
        [ 0.0028,  0.3921, -0.0140,  ...,  0.0145,  0.0403, -0.0477],
        [ 0.0335, -0.0140,  0.4308,  ..., -0.1360, -0.0887,  0.0175],
        ...,
        [-0.0115,  0.0145, -0.1360,  ...,  0.4376, -0.0943,  0.0043],
        [-0.0565,  0.0403, -0.0887,  ..., -0.0943,  0.4156,  0.0100],
        [ 0.0408, -0.0477,  0.0175,  ...,  0.0043,  0.0100,  0.2938]],
       device='cuda:0', grad

====> Epoch: 4 Average loss: 121.4545
Parameter containing:
tensor([[ 0.0313, -0.0152, -0.0297,  ..., -0.1075, -0.0438, -0.1189],
        [ 0.0880,  0.1513,  0.0462,  ...,  0.0047,  0.0570, -0.1353],
        [ 0.0050, -0.0775,  0.0979,  ..., -0.0053, -0.0194, -0.0808],
        ...,
        [ 0.0366, -0.0583,  0.0101,  ..., -0.1508,  0.0039, -0.0176],
        [ 0.0269, -0.0101,  0.0698,  ...,  0.1441,  0.0334,  0.1564],
        [-0.0065, -0.0922, -0.0435,  ...,  0.0069, -0.1793, -0.0737]],
       device='cuda:0', requires_grad=True)
tensor([[ 0.4088,  0.0094,  0.0335,  ..., -0.0074, -0.0627,  0.0384],
        [ 0.0094,  0.3633, -0.0079,  ...,  0.0006,  0.0323, -0.0550],
        [ 0.0335, -0.0079,  0.4196,  ..., -0.1350, -0.0958,  0.0154],
        ...,
        [-0.0074,  0.0006, -0.1350,  ...,  0.4120, -0.0861,  0.0056],
        [-0.0627,  0.0323, -0.0958,  ..., -0.0861,  0.3888,  0.0118],
        [ 0.0384, -0.0550,  0.0154,  ...,  0.0056,  0.0118,  0.2840]],
       device='cuda:0', grad

In [25]:
# reg_lambda_1 = 10, reg_lambda_latent = 1, k = 20; rebinds `wyyt_model` to this second run.
wyyt_model = try_wyyt_model(pretrained_vanilla_model, 10, 1, 20)

====> Epoch: 1 Average loss: 190.0930
Parameter containing:
tensor([[ 0.0322, -0.0085, -0.0310,  ..., -0.1124, -0.0166,  0.0965],
        [ 0.1538,  0.1287,  0.1525,  ..., -0.0918, -0.0395, -0.1108],
        [ 0.0859,  0.1611,  0.0723,  ..., -0.0747, -0.0236,  0.0823],
        ...,
        [ 0.1046,  0.1369,  0.1653,  ...,  0.0954,  0.0312, -0.1186],
        [ 0.1278, -0.0508,  0.1704,  ..., -0.1938, -0.0013, -0.1034],
        [-0.0590,  0.0271,  0.1023,  ...,  0.1311, -0.0253,  0.0442]],
       device='cuda:0', requires_grad=True)
tensor([[ 0.1234,  0.0092,  0.0066,  ...,  0.0626,  0.0228, -0.0016],
        [ 0.0092,  0.3115,  0.0170,  ...,  0.1697,  0.1557,  0.0518],
        [ 0.0066,  0.0170,  0.2254,  ..., -0.0164,  0.0361,  0.0259],
        ...,
        [ 0.0626,  0.1697, -0.0164,  ...,  0.3016,  0.0834,  0.0792],
        [ 0.0228,  0.1557,  0.0361,  ...,  0.0834,  0.2279, -0.0308],
        [-0.0016,  0.0518,  0.0259,  ...,  0.0792, -0.0308,  0.2141]],
       device='cuda:0', grad

====> Epoch: 4 Average loss: 126.2323
Parameter containing:
tensor([[ 0.0341, -0.0066, -0.0383,  ..., -0.1417, -0.0144,  0.0995],
        [ 0.1490,  0.1291,  0.1595,  ..., -0.0877, -0.0153, -0.0987],
        [ 0.0841,  0.1520,  0.0645,  ..., -0.0773, -0.0238,  0.0773],
        ...,
        [ 0.1060,  0.1279,  0.1599,  ...,  0.0882,  0.0322, -0.1262],
        [ 0.1206, -0.0561,  0.1739,  ..., -0.1825,  0.0071, -0.1033],
        [-0.0387,  0.0238,  0.0857,  ...,  0.1188, -0.0153,  0.0426]],
       device='cuda:0', requires_grad=True)
tensor([[ 0.1242,  0.0024,  0.0102,  ...,  0.0544,  0.0237, -0.0049],
        [ 0.0024,  0.2938,  0.0089,  ...,  0.1699,  0.1473,  0.0547],
        [ 0.0102,  0.0089,  0.2139,  ..., -0.0193,  0.0341,  0.0216],
        ...,
        [ 0.0544,  0.1699, -0.0193,  ...,  0.2915,  0.0892,  0.0717],
        [ 0.0237,  0.1473,  0.0341,  ...,  0.0892,  0.2226, -0.0256],
        [-0.0049,  0.0547,  0.0216,  ...,  0.0717, -0.0256,  0.1952]],
       device='cuda:0', grad

In [26]:
# Squared L2 norm of every row of Y (the quantity the row-norm penalty compares with k / n).
torch.sum(wyyt_model.Y**2, 1)

tensor([0.1222, 0.2875, 0.2090, 0.1689, 0.2546, 0.1601, 0.2567, 0.0715, 0.3131,
        0.2944, 0.1292, 0.1225, 0.2799, 0.1964, 0.1563, 0.1255, 0.1785, 0.2490,
        0.1475, 0.2081, 0.2750, 0.1586, 0.1345, 0.2663, 0.1067, 0.1539, 0.2719,
        0.1424, 0.2507, 0.1534, 0.1560, 0.2481, 0.1711, 0.1306, 0.1778, 0.1486,
        0.1932, 0.1269, 0.0964, 0.0652, 0.2008, 0.0692, 0.1528, 0.2497, 0.1951,
        0.1678, 0.1888, 0.1659, 0.3504, 0.3095, 0.1099, 0.1765, 0.2469, 0.1593,
        0.1584, 0.2014, 0.1797, 0.1924, 0.1519, 0.1755, 0.2514, 0.1705, 0.2000,
        0.1455, 0.1217, 0.2141, 0.1469, 0.1915, 0.0734, 0.3153, 0.1232, 0.2525,
        0.1684, 0.1165, 0.1545, 0.0859, 0.1934, 0.1705, 0.2485, 0.0996, 0.2260,
        0.0870, 0.1525, 0.2606, 0.2227, 0.2339, 0.0939, 0.1451, 0.1758, 0.1795,
        0.1909, 0.2582, 0.1546, 0.1334, 0.1241, 0.0888, 0.2084, 0.1335, 0.1215,
        0.1309, 0.1837, 0.1651, 0.1475, 0.2308, 0.1403, 0.1368, 0.1415, 0.1559,
        0.1038, 0.2456, 0.1150, 0.2035, 

In [27]:
# Numerical rank of the learned Y.
np.linalg.matrix_rank(wyyt_model.Y.clone().detach().cpu().numpy())

20

In [28]:
# Persist the weights of the most recently trained W = Y Y^T model.
torch.save(wyyt_model.state_dict(), BASE_PATH_DATA + "../data/models/wyyt/base.pt")

# Zeisel Data

In [29]:
import scipy.io as sio

In [30]:
# Load the matrix stored under key 'G' and transpose it.
# NOTE(review): presumably rows are samples and columns are features after the transpose — confirm.
# NOTE(review): the section is titled "Zeisel" but the files are named CITEseq — confirm which dataset this is.
a = sio.loadmat("../data/zeisel/CITEseq.mat")
data= a['G'].T
N,d=data.shape
# log(1 + x) transform of the raw entries (integer entries, per the original note)
data=np.log(data+np.ones(data.shape))
# Scale every row to unit Euclidean norm.
# NOTE(review): an all-zero row has norm 0 and would turn into NaNs here.
for i in range(N):
    data[i,:]=data[i,:]/np.linalg.norm(data[i,:])

# Load the labels and unwrap each one-element row into a flat array.
a = sio.loadmat("../data/zeisel/CITEseq-labels.mat")
l_aux = a['labels']
labels = np.array([i for [i] in l_aux])

# Load the names, unwrapping the nested arrays stored under 'citeseq_names'.
# NOTE(review): this reads N entries, i.e. one per row of `data` — confirm that is intended.
a = sio.loadmat("../data/zeisel/CITEseq_names.mat")
names=[a['citeseq_names'][i][0][0] for i in range(N)]

In [31]:
# Random row order and the cut-off index for an 80/20 train/test split.
# NOTE(review): NumPy's RNG is not seeded in the visible cells, so the split differs between runs.
slices = np.random.permutation(np.arange(data.shape[0]))
upto = int(.8 * len(data))

In [32]:
# First 80% of the shuffled rows for training, the remainder for testing.
train_data = data[slices[:upto]]
test_data = data[slices[upto:]]

In [33]:
# Convert the NumPy training split to a float tensor on `device` (rebinds `train_data`).
train_data = Tensor(train_data).to(device)

In [34]:
# Convert the NumPy test split to a float tensor on `device` (rebinds `test_data`).
test_data = Tensor(test_data).to(device)

In [35]:
# (rows, features) of the training tensor.
train_data.shape

torch.Size([6893, 500])

In [36]:
# Vanilla VAE model
class VAE(nn.Module):
    """Fully connected VAE over inputs of width `input_size`.

    Encoder: input -> hidden (ReLU) -> (mu, logvar), each of size `z_size`.
    Decoder: z -> hidden (ReLU) -> input width (sigmoid).
    """

    def __init__(self, input_size, hidden_layer_size, z_size):
        super().__init__()

        # Layers are created in the order fc1, fc21, fc22, fc3, fc4.
        self.fc1 = nn.Linear(input_size, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, input_size)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of inputs."""
        hidden = F.relu(self.fc1(x))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors to sigmoid outputs of the input width."""
        hidden = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(hidden))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for `x` (used as given, no reshaping)."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar


In [37]:
def loss_function(recon_x, x, mu, logvar):
    """VAE training loss: reconstruction BCE plus KL divergence, both summed (not averaged).

    Args:
        recon_x: decoder output, same shape as `x`.
        x: input batch, used as given (no reshaping).
        mu: latent means produced by the encoder.
        logvar: latent log-variances produced by the encoder.

    Returns:
        Scalar tensor ``BCE + KLD``.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL term, see Appendix B of Kingma and Welling,
    # "Auto-Encoding Variational Bayes", ICLR 2014 (https://arxiv.org/abs/1312.6114):
    #   KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * kl_inner.sum()

    return reconstruction_term + kl_term

In [38]:
# Vanilla VAE with 500 inputs, 250 hidden units and 20 latent dimensions,
# moved to `device`, plus the Adam optimizer used to train it.
vanilla_vae_zeisel = VAE(500, 250, 20)
vanilla_vae_zeisel.to(device)
vanilla_optimizer_zeisel = torch.optim.Adam(vanilla_vae_zeisel.parameters(), 
                                            lr=lr, 
                                            betas = (b1,b2))

In [39]:
# Scratch check: torch.randperm(n) returns a random permutation of 0..n-1.
torch.randperm(10)

tensor([0, 7, 3, 9, 5, 2, 6, 1, 8, 4])

In [40]:
def train(df, model, optimizer, epoch):
    """Train `model` for one epoch on the rows of `df`, visited in a fresh random order.

    Rows are split into consecutive chunks of `batch_size` from a random
    permutation. Every `log_interval` batches the per-sample loss of the
    current batch is printed; at the end the loss averaged over all rows is
    printed.

    Args:
        df: 2-D tensor of samples (one row per sample).
        model: module whose forward returns ``(reconstruction, mu, logvar)``.
        optimizer: optimizer stepping `model`'s parameters.
        epoch: epoch number, used only in the printed messages.
    """
    model.train()
    train_loss = 0
    n_batches = math.ceil(len(df) / batch_size)
    permutations = torch.randperm(df.shape[0])
    for i in range(n_batches):
        batch_ind = permutations[i * batch_size : (i+1) * batch_size]
        batch_data = df[batch_ind, :]

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(batch_data)
        loss = loss_function(recon_batch, batch_data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if i % log_interval == 0:
            # `i` counts batches, so progress is i / n_batches (dividing by the
            # number of rows under-reported it), and the rows seen so far are
            # i * batch_size even when the current batch is a short last one.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i * batch_size, len(df),
                100. * i / n_batches,
                loss.item() / len(batch_data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(df)))

In [41]:
def test(df, model, epoch):
    """Print the loss of `model` averaged over the rows of `df`.

    Rows are processed in their stored order, `batch_size` at a time, with
    gradients disabled.

    Args:
        df: 2-D tensor of samples (one row per sample).
        model: module whose forward returns ``(reconstruction, mu, logvar)``.
        epoch: accepted for symmetry with the training function; not used.
    """
    model.eval()
    accumulated = 0
    n_rows = df.shape[0]
    with torch.no_grad():
        for start in range(0, n_rows, batch_size):
            chunk = df[start:start + batch_size, :].to(device)
            recon_batch, mu, logvar = model(chunk)
            accumulated += loss_function(recon_batch, chunk, mu, logvar).item()

    mean_loss = accumulated / len(df)
    print('====> Test set loss: {:.4f}'.format(mean_loss))

In [42]:
# Train the vanilla VAE for 50 epochs, evaluating on the held-out rows after each one.
for epoch in range(1, 50 + 1):
        train(train_data, vanilla_vae_zeisel, vanilla_optimizer_zeisel, epoch)
        test(test_data, vanilla_vae_zeisel, epoch)

====> Epoch: 1 Average loss: 182.6022
====> Test set loss: 88.3004
====> Epoch: 2 Average loss: 81.2381
====> Test set loss: 77.1690
====> Epoch: 3 Average loss: 75.4734
====> Test set loss: 74.5743
====> Epoch: 4 Average loss: 73.5186
====> Test set loss: 73.2552
====> Epoch: 5 Average loss: 72.6909
====> Test set loss: 72.8146
====> Epoch: 6 Average loss: 72.0207
====> Test set loss: 71.7761
====> Epoch: 7 Average loss: 71.3482
====> Test set loss: 71.2176
====> Epoch: 8 Average loss: 70.5316
====> Test set loss: 70.0781
====> Epoch: 9 Average loss: 69.5983
====> Test set loss: 69.3159
====> Epoch: 10 Average loss: 68.6155
====> Test set loss: 68.4623
====> Epoch: 11 Average loss: 67.7127
====> Test set loss: 67.6483
====> Epoch: 12 Average loss: 67.0463
====> Test set loss: 67.1137
====> Epoch: 13 Average loss: 66.6574
====> Test set loss: 66.6134
====> Epoch: 14 Average loss: 66.3515
====> Test set loss: 66.3929
====> Epoch: 15 Average loss: 66.0892
====> Test set loss: 66.2369
===

In [43]:
# The forward pass returns a 3-tuple: (reconstruction, mu, logvar).
len(vanilla_vae_zeisel(test_data))

3

In [44]:
# Reconstruction of the first test row.
vanilla_vae_zeisel(test_data)[0][0,:]

tensor([0.1304, 0.1165, 0.1040, 0.1138, 0.1340, 0.1026, 0.1105, 0.1015, 0.1029,
        0.1074, 0.1102, 0.0947, 0.1057, 0.0935, 0.1055, 0.1120, 0.0977, 0.0779,
        0.0986, 0.0969, 0.0813, 0.0828, 0.1066, 0.0972, 0.0780, 0.0876, 0.0860,
        0.0779, 0.0906, 0.0704, 0.0860, 0.0855, 0.0898, 0.0825, 0.0847, 0.0765,
        0.0918, 0.0834, 0.0876, 0.0837, 0.0864, 0.0790, 0.0921, 0.0761, 0.0736,
        0.0743, 0.0700, 0.0616, 0.0622, 0.0798, 0.0726, 0.0808, 0.0902, 0.0657,
        0.0813, 0.0728, 0.0739, 0.0682, 0.0650, 0.0667, 0.0670, 0.0682, 0.0546,
        0.0555, 0.0614, 0.0668, 0.0699, 0.0681, 0.0707, 0.0637, 0.0577, 0.0573,
        0.0632, 0.0578, 0.0657, 0.0720, 0.0590, 0.0569, 0.0594, 0.0604, 0.0587,
        0.0519, 0.0575, 0.0506, 0.0548, 0.0482, 0.0576, 0.0472, 0.0534, 0.0528,
        0.0465, 0.0457, 0.0472, 0.0544, 0.0555, 0.0462, 0.0451, 0.0491, 0.0484,
        0.0436, 0.0470, 0.0494, 0.0530, 0.0466, 0.0429, 0.0502, 0.0473, 0.0369,
        0.0411, 0.0384, 0.0429, 0.0426, 

In [45]:
# The first test row itself, for comparison with its reconstruction.
test_data[0,:]

tensor([0.1874, 0.0807, 0.1246, 0.1202, 0.1783, 0.0695, 0.0976, 0.0976, 0.0551,
        0.0695, 0.0695, 0.0898, 0.0807, 0.1202, 0.1502, 0.1421, 0.0976, 0.0898,
        0.1102, 0.0695, 0.0807, 0.0695, 0.1358, 0.0807, 0.0898, 0.0976, 0.0898,
        0.1043, 0.0695, 0.0898, 0.0695, 0.0976, 0.0348, 0.0898, 0.1043, 0.1043,
        0.0551, 0.0348, 0.0898, 0.0551, 0.0898, 0.0695, 0.0695, 0.0807, 0.0695,
        0.0695, 0.0348, 0.0551, 0.0976, 0.0551, 0.0695, 0.0807, 0.0551, 0.1154,
        0.0551, 0.0898, 0.0348, 0.0551, 0.0348, 0.0976, 0.0551, 0.0551, 0.0695,
        0.0551, 0.0695, 0.0551, 0.0551, 0.0695, 0.0695, 0.0807, 0.0000, 0.0000,
        0.0348, 0.0551, 0.0551, 0.0551, 0.0348, 0.0000, 0.0000, 0.0348, 0.0551,
        0.0348, 0.0898, 0.0551, 0.0695, 0.0000, 0.0348, 0.0348, 0.0551, 0.0348,
        0.0898, 0.0348, 0.0898, 0.1102, 0.0000, 0.0348, 0.0000, 0.0000, 0.0807,
        0.0551, 0.0000, 0.0695, 0.0000, 0.0695, 0.0348, 0.0551, 0.0348, 0.0551,
        0.0551, 0.0551, 0.0000, 0.0551, 

In [46]:
# Element-wise reconstruction error for the first test row. This is a fresh
# forward pass, so new latent noise is sampled.
vanilla_vae_zeisel(test_data)[0][0,:] - test_data[0,:]

tensor([-0.0614,  0.0316, -0.0163, -0.0083, -0.0457,  0.0285,  0.0144, -0.0027,
         0.0411,  0.0287,  0.0353,  0.0024,  0.0155, -0.0279, -0.0418, -0.0370,
         0.0067, -0.0116, -0.0236,  0.0202, -0.0019,  0.0070, -0.0343,  0.0103,
        -0.0074, -0.0165, -0.0027, -0.0301,  0.0139, -0.0240,  0.0128, -0.0167,
         0.0718, -0.0084, -0.0266, -0.0334,  0.0281,  0.0457, -0.0089,  0.0223,
        -0.0128,  0.0055,  0.0134, -0.0097,  0.0016,  0.0002,  0.0279,  0.0140,
        -0.0334,  0.0172, -0.0043, -0.0080,  0.0226, -0.0469,  0.0144, -0.0111,
         0.0294,  0.0088,  0.0310, -0.0372,  0.0103,  0.0060, -0.0023,  0.0063,
        -0.0163,  0.0024,  0.0065, -0.0075, -0.0093, -0.0206,  0.0530,  0.0575,
         0.0297,  0.0042,  0.0017,  0.0055,  0.0173,  0.0482,  0.0521,  0.0207,
         0.0007,  0.0217, -0.0382, -0.0070, -0.0198,  0.0471,  0.0165,  0.0120,
        -0.0080,  0.0115, -0.0490,  0.0126, -0.0455, -0.0447,  0.0564,  0.0147,
         0.0395,  0.0616, -0.0356, -0.01

In [47]:
# Summed squared reconstruction error for the first test row.
torch.sum((vanilla_vae_zeisel(test_data)[0][0,:] - test_data[0,:])**2)

tensor(0.3154, device='cuda:0', grad_fn=<SumBackward0>)

In [48]:
# Squared reconstruction error summed over all entries, divided by the number of training rows.
torch.sum((vanilla_vae_zeisel(train_data)[0] - train_data)**2) / len(train_data)

tensor(0.2596, device='cuda:0', grad_fn=<DivBackward0>)

In [49]:
# Squared reconstruction error summed over all entries, divided by the number of test rows.
torch.sum((vanilla_vae_zeisel(test_data)[0] - test_data)**2) / len(test_data)

tensor(0.2582, device='cuda:0', grad_fn=<DivBackward0>)

In [50]:
# Persist the trained vanilla VAE weights.
torch.save(vanilla_vae_zeisel.state_dict(), BASE_PATH_DATA + "../data/models/zeisel/vanilla.pt")

In [51]:
# Freeze every parameter of the trained vanilla VAE so later cells cannot update it.
for param in vanilla_vae_zeisel.parameters():
    param.requires_grad_(False)

## Let's try L1 Diag with L2 Norm on Pretrained

In [52]:
# L1 VAE model we are loading
class VAE_l1_diag(nn.Module):
    """VAE whose input is first multiplied by ``diag(self.diag)``.

    `diag` is a learnable vector with one weight per input feature; the rest of
    the network is a fully connected encoder/decoder of width `input_size`.
    The selection matrix is rebuilt from `diag` on every `encode` call and kept
    on ``self.selection_layer``.
    """

    def __init__(self, input_size, hidden_layer_size, z_size):
        super().__init__()

        # One N(0, 1) weight per input feature, drawn before the layers are
        # created and placed on `device`.
        initial_weights = torch.normal(torch.zeros(input_size), torch.ones(input_size))
        self.diag = nn.Parameter(initial_weights.to(device).requires_grad_(True))

        self.fc1 = nn.Linear(input_size, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, input_size)

    def encode(self, x):
        """Scale `x` by the current `diag`, then return ``(mu, logvar)``."""
        self.selection_layer = torch.diag(self.diag)
        selected = torch.mm(x, self.selection_layer)
        hidden = F.relu(self.fc1(selected))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors to sigmoid outputs of the input width."""
        hidden = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(hidden))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for `x` (used as given, no reshaping)."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar


In [53]:
def train_l1_diag_pretrained(df, model, pretrained_model, optimizer, epoch, reg_lambda_l1, reg_lambda_latent):
    """Train ``model`` for one epoch with an L1 penalty and a hidden-layer matching penalty.

    The per-batch objective is the sum of
      * ``loss_function`` (reconstruction + KL, summed over the batch),
      * ``reg_lambda_l1 * ||model.diag||_1``, and
      * ``reg_lambda_latent * MSE`` between the first hidden layer of ``model``
        and the first hidden layer of ``pretrained_model`` on the same batch.

    Uses the notebook-level ``batch_size``, ``log_interval`` and
    ``loss_function``.

    Args:
        df: 2-D tensor of samples (rows) by features (columns).
        model: model being trained; must expose ``diag`` and ``fc1``.
        pretrained_model: reference model; only its ``fc1`` layer is used.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used for logging only.
        reg_lambda_l1: weight of the L1 penalty on ``model.diag``.
        reg_lambda_latent: weight of the hidden-layer matching penalty.
    """
    model.train()
    train_loss = 0

    num_batches = math.ceil(len(df) / batch_size)
    # Fresh random sample order for this epoch.
    permutations = torch.randperm(df.shape[0])
    for i in range(num_batches):
        batch_ind = permutations[i * batch_size : (i+1) * batch_size]
        batch_data = df[batch_ind, :]

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(batch_data)
        loss = loss_function(recon_batch, batch_data, mu, logvar)

        l1_norm = reg_lambda_l1 * torch.norm(model.diag, p=1)
        loss += l1_norm

        # The pretrained activations are a fixed target: compute them without
        # autograd so nothing is tracked for (or accumulated in) the reference model.
        with torch.no_grad():
            h1_pretrained = F.relu(pretrained_model.fc1(batch_data))
        # Recompute the selected input directly from the current weights rather
        # than reading an attribute left behind by the forward pass; the
        # element-wise product equals multiplying by diag(model.diag).
        h0 = batch_data * model.diag
        h1_model = F.relu(model.fc1(h0))
        # NOTE(review): mse_loss averages over elements while loss_function sums
        # over the batch, so the two terms are on different scales -- confirm
        # this is intended.
        l2_loss = reg_lambda_latent * F.mse_loss(h1_model, h1_pretrained)
        loss += l2_loss

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if i % log_interval == 0:
            # Progress is measured in batches: i of num_batches done so far.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} \t L2 Loss:{}'.format(
                epoch, i * batch_size, len(df),
                100. * i / num_batches,
                loss.item() / len(batch_data), l2_loss.item() / len(batch_data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(df)))

In [54]:
def test(df, model, epoch):
    """Evaluate ``model`` on ``df`` and print the average per-sample loss.

    Samples are processed in order, in chunks of the notebook-level
    ``batch_size``, with autograd disabled. ``epoch`` is accepted for symmetry
    with the training function but is not used.
    """
    model.eval()
    row_ids = np.arange(df.shape[0])
    num_batches = math.ceil(len(df) / batch_size)
    running_total = 0
    with torch.no_grad():
        for batch_no in range(num_batches):
            lo = batch_no * batch_size
            hi = lo + batch_size
            chunk = df[row_ids[lo:hi], :].to(device)
            recon, mu, logvar = model(chunk)
            running_total += loss_function(recon, chunk, mu, logvar).item()

    mean_loss = running_total / len(df)
    print('====> Test set loss: {:.4f}'.format(mean_loss))

In [55]:
# Build the L1-diagonal VAE (500 input features, 250 hidden units, 20 latent
# dimensions), move it to the compute device, and create its Adam optimizer.
vae_l1_diag_zeisel = VAE_l1_diag(input_size=500, hidden_layer_size=250, z_size=20)
vae_l1_diag_zeisel.to(device)
vae_l1_diag_zeisel_optimizer = torch.optim.Adam(
    vae_l1_diag_zeisel.parameters(), lr=lr, betas=(b1, b2)
)

In [56]:
# Train for 50 epochs with both regularization weights set to 1, evaluating on
# the test split after every epoch.
for epoch in range(1, 51):
    train_l1_diag_pretrained(
        train_data,
        vae_l1_diag_zeisel,
        vanilla_vae_zeisel,
        vae_l1_diag_zeisel_optimizer,
        epoch,
        reg_lambda_l1=1,
        reg_lambda_latent=1,
    )
    # Optional hard-thresholding of small selection weights (disabled):
    # with torch.no_grad():
    #     vae_l1_diag_zeisel.diag.data[torch.abs(vae_l1_diag_zeisel.diag) < 0.05] = 0
    test(test_data, vae_l1_diag_zeisel, epoch)

====> Epoch: 1 Average loss: 193.1159
====> Test set loss: 89.1026
====> Epoch: 2 Average loss: 87.7522
====> Test set loss: 77.4059
====> Epoch: 3 Average loss: 81.4604
====> Test set loss: 74.8661
====> Epoch: 4 Average loss: 79.7153
====> Test set loss: 73.8384
====> Epoch: 5 Average loss: 78.6345
====> Test set loss: 72.8051
====> Epoch: 6 Average loss: 78.2005
====> Test set loss: 72.6735
====> Epoch: 7 Average loss: 77.2240
====> Test set loss: 71.7985
====> Epoch: 8 Average loss: 76.6667
====> Test set loss: 71.1682
====> Epoch: 9 Average loss: 75.8993
====> Test set loss: 70.5100
====> Epoch: 10 Average loss: 75.0019
====> Test set loss: 69.5515
====> Epoch: 11 Average loss: 74.0373
====> Test set loss: 68.9356
====> Epoch: 12 Average loss: 73.2010
====> Test set loss: 68.2291
====> Epoch: 13 Average loss: 72.4673
====> Test set loss: 67.4483
====> Epoch: 14 Average loss: 71.6646
====> Test set loss: 66.7649
====> Epoch: 15 Average loss: 70.9610
====> Test set loss: 66.2847
===

====> Epoch: 37 Average loss: 66.1113
====> Test set loss: 63.8553
====> Epoch: 38 Average loss: 65.9971
====> Test set loss: 63.8339
====> Epoch: 39 Average loss: 65.9024
====> Test set loss: 63.7793
====> Epoch: 40 Average loss: 65.7887
====> Test set loss: 63.7679
====> Epoch: 41 Average loss: 65.7027
====> Test set loss: 63.7552
====> Epoch: 42 Average loss: 65.6144
====> Test set loss: 63.7137
====> Epoch: 43 Average loss: 65.5224
====> Test set loss: 63.6913
====> Epoch: 44 Average loss: 65.4296
====> Test set loss: 63.6607
====> Epoch: 45 Average loss: 65.3435
====> Test set loss: 63.6076
====> Epoch: 46 Average loss: 65.2536
====> Test set loss: 63.6110
====> Epoch: 47 Average loss: 65.1858
====> Test set loss: 63.6052
====> Epoch: 48 Average loss: 65.1005
====> Test set loss: 63.6098
====> Epoch: 49 Average loss: 65.0172
====> Test set loss: 63.5974
====> Epoch: 50 Average loss: 64.9584
====> Test set loss: 63.6003


In [57]:
# Persist the trained L1-diagonal VAE weights under <BASE_PATH_DATA>/models/zeisel/.
# BASE_PATH_DATA already points at the data directory, so it is not repeated in
# the relative part; the target directory is created if it does not exist yet.
l1_diag_model_path = os.path.join(BASE_PATH_DATA, "models", "zeisel", "l1_diag_pretrained.pt")
os.makedirs(os.path.dirname(l1_diag_model_path), exist_ok=True)
torch.save(vae_l1_diag_zeisel.state_dict(), l1_diag_model_path)

**Let's see the weights of the selection layer**

In [70]:
# Learned per-feature selection weights (the `diag` parameter of the L1 model).
vae_l1_diag_zeisel.diag

Parameter containing:
tensor([-6.3273e-05, -1.3902e+00, -6.8938e-01,  2.9555e-05,  4.4665e-01,
        -6.4411e-02, -2.0655e-07, -1.2139e-05, -8.0846e-05,  5.4511e-02,
         3.9446e-01,  1.2396e-01, -4.1297e-01,  9.5827e-02, -9.6834e-06,
        -4.2757e-05, -1.2402e+00,  5.1740e-05, -5.8258e-01,  2.9363e-05,
        -7.0815e-05, -6.5134e-01,  1.2553e+00, -8.6336e-05, -7.6553e-05,
        -1.1529e+00,  4.6713e-05,  4.2166e-05,  3.8737e-05, -2.5682e-05,
         7.3785e-01, -7.1304e-05, -5.0949e-05, -9.7768e-01, -4.0172e-05,
        -2.1261e-05,  5.5754e-05,  7.6230e-06,  4.2893e-01, -5.2703e-01,
        -5.5457e-05,  4.0582e-06, -2.9569e-05,  8.0104e-06,  2.7945e-05,
        -2.1696e-05,  1.5579e-01, -3.2284e-05,  7.1349e-05,  1.7207e+00,
        -7.7178e-05, -4.1721e-05, -4.0866e-01,  6.5149e-02,  2.2152e-05,
        -3.6020e-05,  4.2638e-05, -9.4398e-01,  1.8881e-01, -9.2252e-01,
        -2.2585e-05, -1.1457e-04, -4.0208e-01,  5.4353e-01, -9.8283e-05,
         2.0178e-06, -5.8073e

### Which selection weights were driven to (near) zero?

In [59]:
# Number of selection weights whose magnitude is below 1e-4.
np.sum(np.abs(vae_l1_diag_zeisel.diag.detach().cpu().numpy()) < 1e-4)

288

In [60]:
# Feature indices whose selection weight has magnitude below 1e-4.
np.where(np.abs(vae_l1_diag_zeisel.diag.detach().cpu().numpy()) < 1e-4)[0]

array([  0,   3,   6,   7,   8,  14,  15,  17,  19,  20,  23,  24,  26,
        27,  28,  29,  31,  32,  34,  35,  36,  37,  40,  41,  42,  43,
        44,  45,  47,  48,  50,  51,  54,  55,  56,  60,  64,  65,  67,
        69,  70,  71,  72,  75,  77,  80,  81,  82,  84,  85,  88,  90,
        91,  92,  97,  98, 100, 102, 103, 107, 109, 110, 112, 113, 115,
       116, 118, 119, 121, 122, 123, 124, 128, 129, 131, 133, 134, 135,
       136, 137, 139, 142, 144, 145, 146, 148, 151, 153, 155, 156, 159,
       162, 163, 165, 166, 169, 170, 174, 175, 180, 181, 182, 183, 184,
       188, 191, 192, 193, 194, 196, 197, 198, 200, 201, 202, 204, 205,
       208, 209, 210, 211, 213, 219, 220, 222, 223, 227, 230, 231, 232,
       233, 234, 235, 237, 238, 240, 241, 243, 244, 245, 246, 247, 248,
       249, 250, 253, 254, 255, 256, 259, 260, 261, 262, 267, 268, 270,
       271, 274, 276, 278, 280, 281, 286, 288, 290, 291, 295, 298, 302,
       304, 306, 307, 308, 309, 311, 314, 315, 318, 319, 321, 32

In [63]:
# Mean (per sample) summed squared reconstruction error of the L1 model on the
# training data. The forward pass samples z, so the value varies between runs.
torch.sum((vae_l1_diag_zeisel(train_data)[0] - train_data).pow(2)) / len(train_data)

tensor(0.2634, device='cuda:0', grad_fn=<DivBackward0>)

In [64]:
# Mean (per sample) summed squared reconstruction error of the L1 model on the
# test data. The forward pass samples z, so the value varies between runs.
torch.sum((vae_l1_diag_zeisel(test_data)[0] - test_data).pow(2)) / len(test_data)

tensor(0.2607, device='cuda:0', grad_fn=<DivBackward0>)

In [65]:
# Per-feature residual for the first test sample: reconstruction minus input.
vae_l1_diag_zeisel(test_data)[0][0,:] - test_data[0,:]

tensor([-5.9445e-02,  3.9319e-02, -3.1205e-02, -9.7333e-03, -5.0948e-02,
         3.6899e-02,  1.0317e-02,  5.1701e-03,  3.5854e-02,  3.6776e-02,
         4.0770e-02, -5.2774e-04,  2.3183e-02, -2.9636e-02, -5.4838e-02,
        -3.0614e-02, -1.9111e-02, -1.1739e-02, -1.1819e-02,  2.2336e-02,
        -8.1006e-03,  1.8426e-02, -5.0393e-02,  1.0266e-02, -1.7134e-02,
        -2.6794e-02, -7.7010e-03, -3.1495e-02,  1.5575e-02, -2.1924e-02,
         1.1787e-02, -8.9945e-03,  3.3456e-02, -6.1293e-03, -2.1720e-02,
        -3.2058e-02,  3.1439e-02,  4.6353e-02, -6.3351e-03,  2.2218e-02,
        -7.7415e-03,  7.7045e-03,  1.4176e-02,  5.8020e-04,  6.8509e-03,
        -9.3299e-03,  3.3766e-02,  6.2676e-03, -5.0022e-02,  2.8517e-02,
        -1.2883e-02, -9.3973e-03,  2.7565e-02, -5.0535e-02,  2.4865e-02,
        -2.0956e-02,  3.2009e-02,  1.9721e-02,  3.0522e-02, -3.4791e-02,
         6.9582e-03,  1.1278e-02, -3.2429e-02, -8.5957e-03, -7.4868e-03,
         3.0959e-03,  3.9729e-03,  4.1524e-03, -2.1

In [66]:
# First test sample, shown for comparison with its reconstruction.
test_data[0,:]

tensor([0.1874, 0.0807, 0.1246, 0.1202, 0.1783, 0.0695, 0.0976, 0.0976, 0.0551,
        0.0695, 0.0695, 0.0898, 0.0807, 0.1202, 0.1502, 0.1421, 0.0976, 0.0898,
        0.1102, 0.0695, 0.0807, 0.0695, 0.1358, 0.0807, 0.0898, 0.0976, 0.0898,
        0.1043, 0.0695, 0.0898, 0.0695, 0.0976, 0.0348, 0.0898, 0.1043, 0.1043,
        0.0551, 0.0348, 0.0898, 0.0551, 0.0898, 0.0695, 0.0695, 0.0807, 0.0695,
        0.0695, 0.0348, 0.0551, 0.0976, 0.0551, 0.0695, 0.0807, 0.0551, 0.1154,
        0.0551, 0.0898, 0.0348, 0.0551, 0.0348, 0.0976, 0.0551, 0.0551, 0.0695,
        0.0551, 0.0695, 0.0551, 0.0551, 0.0695, 0.0695, 0.0807, 0.0000, 0.0000,
        0.0348, 0.0551, 0.0551, 0.0551, 0.0348, 0.0000, 0.0000, 0.0348, 0.0551,
        0.0348, 0.0898, 0.0551, 0.0695, 0.0000, 0.0348, 0.0348, 0.0551, 0.0348,
        0.0898, 0.0348, 0.0898, 0.1102, 0.0000, 0.0348, 0.0000, 0.0000, 0.0807,
        0.0551, 0.0000, 0.0695, 0.0000, 0.0695, 0.0348, 0.0551, 0.0348, 0.0551,
        0.0551, 0.0551, 0.0000, 0.0551, 

In [67]:
# Reconstruction of the first test sample by the L1 model. The forward pass
# samples the latent code, so repeated calls give slightly different values.
vae_l1_diag_zeisel(test_data)[0][0,:]

tensor([0.1302, 0.1091, 0.1011, 0.1097, 0.1312, 0.1021, 0.1034, 0.1067, 0.0969,
        0.1008, 0.1022, 0.0906, 0.1040, 0.0866, 0.1003, 0.1133, 0.0854, 0.0772,
        0.0957, 0.0907, 0.0805, 0.0869, 0.0933, 0.0947, 0.0782, 0.0763, 0.0865,
        0.0753, 0.0781, 0.0652, 0.0862, 0.0809, 0.0837, 0.0832, 0.0833, 0.0764,
        0.0937, 0.0799, 0.0815, 0.0810, 0.0921, 0.0746, 0.0959, 0.0840, 0.0715,
        0.0690, 0.0729, 0.0538, 0.0580, 0.0732, 0.0593, 0.0800, 0.0794, 0.0673,
        0.0752, 0.0678, 0.0707, 0.0687, 0.0670, 0.0540, 0.0629, 0.0640, 0.0469,
        0.0507, 0.0565, 0.0543, 0.0633, 0.0698, 0.0672, 0.0588, 0.0603, 0.0478,
        0.0607, 0.0485, 0.0549, 0.0683, 0.0593, 0.0555, 0.0507, 0.0506, 0.0532,
        0.0432, 0.0513, 0.0458, 0.0535, 0.0435, 0.0572, 0.0440, 0.0524, 0.0536,
        0.0445, 0.0454, 0.0430, 0.0506, 0.0542, 0.0354, 0.0385, 0.0387, 0.0442,
        0.0395, 0.0433, 0.0380, 0.0471, 0.0367, 0.0376, 0.0508, 0.0418, 0.0375,
        0.0372, 0.0365, 0.0366, 0.0403, 