We have tried an L1 penalty on the weight layer that selects features, and have trained that model to match a pretrained VAE. Can we instead train a vanilla model and an L1 model together? How does joint training compare with training each model on its own?

In [1]:
import torch
from torch.utils.data import DataLoader

from torchvision import datasets
import torchvision.transforms as transforms


from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F

import numpy as np

from torchvision.utils import save_image

import matplotlib.pyplot as plt
import math

In [2]:
import os
from os import listdir

In [3]:
# Root directory for the datasets used by this notebook.
BASE_PATH_DATA = "../data/"

In [4]:
# Optimisation settings.
n_epochs = 5
batch_size = 64
lr = 0.0002
b1 = 0.5    # Adam beta1
b2 = 0.999  # Adam beta2

# Image geometry: 28 x 28 pixels, one channel.
img_size = 28
channels = 1
n = img_size * img_size

# Print training progress every `log_interval` batches.
log_interval = 100

# Dimensionality of the latent code.
z_size = 40

# Floor keeping logs/masks away from zero in the Gumbel top-k relaxation.
# (float32 "tiny", ~1.18e-38, was tried before settling on this value.)
EPSILON = 1e-10

In [5]:
# Use the first GPU when one is available, otherwise the CPU.
cuda = torch.cuda.is_available()

# Legacy tensor constructor matching the chosen device type.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

device = torch.device("cuda:0" if cuda else "cpu")

In [6]:
# Show which device the notebook will run on.
print(device)

cuda:0


In [7]:
class VAE(nn.Module):
    """Vanilla fully connected VAE: input -> hidden -> (mu, logvar) -> z -> hidden -> input.

    Args:
        hidden_layer_size: width of the single hidden layer in encoder and decoder.
        z_size: dimensionality of the latent code.
        input_size: number of input features. Defaults to 500, the width this
            class previously hard-coded, so existing ``VAE(h, z)`` calls are
            unchanged.
    """

    def __init__(self, hidden_layer_size, z_size, input_size=500):
        super(VAE, self).__init__()

        # Layer creation order is unchanged so seeded initialisation matches.
        self.fc1 = nn.Linear(input_size, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)  # mean head
        self.fc22 = nn.Linear(hidden_layer_size, z_size)  # log-variance head
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, input_size)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to input space; ReLU keeps outputs non-negative."""
        h3 = F.relu(self.fc3(z))
        return F.relu(self.fc4(h3))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [8]:
# L1 VAE model we are loading
class VAE_l1_diag(nn.Module):
    """VAE whose encoder first rescales each input feature by a learned weight.

    The per-feature weights are ``self.diag``. The L1 penalty on them is not
    applied in this class; presumably the training code adds it to the loss
    (not visible here). Attribute/layer names are unchanged so previously saved
    state dicts still load.

    Args:
        input_size: number of input features.
        hidden_layer_size: width of the hidden layers.
        z_size: latent dimensionality.
    """

    def __init__(self, input_size, hidden_layer_size, z_size):
        super(VAE_l1_diag, self).__init__()

        # One learnable weight per input feature, initialised from N(0, 1).
        # nn.Parameter already requires grad, and model.to(device) moves it with
        # the Linear layers; moving only this tensor at construction (as before)
        # left it on a different device than fc1 until model.to() was called.
        self.diag = nn.Parameter(torch.normal(torch.zeros(input_size),
                                              torch.ones(input_size)))

        self.fc1 = nn.Linear(input_size, hidden_layer_size)
        # Not used by encode(); kept so checkpoints with an `fcextra` entry load.
        self.fcextra = nn.Linear(hidden_layer_size, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, input_size)

    @property
    def selection_layer(self):
        """Dense diagonal matrix ``diag(self.diag)``, built on demand for inspection."""
        return torch.diag(self.diag)

    def encode(self, x):
        """Return ``(mu, logvar)`` for ``x`` after per-feature rescaling."""
        # x @ diag(d) just scales column j by d[j]; broadcasting does the same in
        # O(batch * input_size) without an input_size^2 matrix, and also accepts
        # a single 1-D sample (torch.mm required 2-D input).
        h0 = x * self.diag
        h1 = F.leaky_relu(self.fc1(h0))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to input space; sigmoid keeps outputs in (0, 1)."""
        h3 = F.leaky_relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [9]:
def gumbel_keys(w):
    """Perturb ``w`` with i.i.d. standard Gumbel(0, 1) noise (Gumbel-max trick).

    Args:
        w: tensor of (unnormalised) log-probabilities.

    Returns:
        A new tensor ``w + g`` with ``g = -log(-log(u))`` and ``u`` uniform on
        ``[EPSILON, 1)``; ``w`` itself is not modified.
    """
    # Keep u strictly positive so the inner log never sees 0.
    uniform = (1.0 - EPSILON) * torch.rand_like(w) + EPSILON
    # Inverse CDF of the standard Gumbel. Without the leading minus the sample
    # is the mirrored (Gumbel-min) distribution, so argmax/top-k of w + noise
    # would no longer sample in proportion to exp(w).
    gumbel = -torch.log(-torch.log(uniform))
    return w + gumbel


# equations 3, 4 and 5
def continuous_topk(w, k, t, separate=False):
    """Continuous relaxation of selecting the top-``k`` entries of ``w``.

    Runs ``k`` rounds of a tempered softmax over the last dimension. Before each
    round, ``log(max(1 - onehot, EPSILON))`` of the previous round's output is
    added to the scores, suppressing mass that has already been selected.

    Args:
        w: score tensor; not modified.
        k: number of rounds, i.e. elements in the relaxed subset.
        t: softmax temperature (smaller -> closer to hard one-hot vectors).
        separate: if True, return the per-round relaxed one-hot vectors stacked
            along a new leading dimension of size ``k``; otherwise their sum.

    Returns:
        A tensor shaped like ``w`` (relaxed k-hot vector), or ``(k, *w.shape)``
        when ``separate`` is True.
    """
    khot_list = []
    onehot_approx = torch.zeros_like(w, dtype=torch.float32)
    for _ in range(k):
        # Same as tf.maximum(1.0 - onehot_approx, EPSILON).
        khot_mask = torch.clamp(1.0 - onehot_approx, min=EPSILON)
        # Out-of-place on purpose: the previous `w +=` modified the caller's
        # tensor whenever continuous_topk was called directly.
        w = w + torch.log(khot_mask)
        onehot_approx = F.softmax(w / t, dim=-1)
        khot_list.append(onehot_approx)
    khots = torch.stack(khot_list)
    return khots if separate else torch.sum(khots, dim=0)


def sample_subset(w, k, t=0.1):
    """Draw a differentiable relaxed k-hot sample from the scores ``w``.

    Args:
        w (Tensor): Float Tensor of weights for each element. In gumbel mode
            these are interpreted as log probabilities.
        k (int): number of elements in the subset sample.
        t (float): temperature of the softmax.

    Returns:
        Tensor shaped like ``w``; along the last dimension it is a sum of ``k``
        softmax outputs, so it sums to ``k`` up to rounding.
    """
    perturbed = gumbel_keys(w)
    return continuous_topk(perturbed, k, t)

In [23]:
# VAE whose encoder input is gated by a learned, relaxed top-k feature mask.
class VAE_Gumbel(nn.Module):
    """VAE with per-sample feature selection through a Gumbel top-k relaxation.

    ``weight_creator`` produces one score per input feature; ``sample_subset``
    turns the scores into a relaxed k-hot mask that multiplies the input before
    the encoder.

    Args:
        input_size: number of input features.
        hidden_layer_size: width of the hidden layers.
        z_size: latent dimensionality.
        k: number of features kept by the relaxed mask.
        t: softmax temperature of the relaxation.
    """

    def __init__(self, input_size, hidden_layer_size, z_size, k, t=0.1):
        super(VAE_Gumbel, self).__init__()

        self.k = k
        self.t = t

        # Scoring network: one logit per input feature.
        self.weight_creator = nn.Sequential(
            nn.Linear(input_size, hidden_layer_size),
            nn.ReLU(),
            nn.Linear(hidden_layer_size, input_size),
        )

        # Encoder.
        self.fc1 = nn.Linear(input_size, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)  # mean head
        self.fc22 = nn.Linear(hidden_layer_size, z_size)  # log-variance head
        # Decoder.
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, input_size)

    def encode(self, x):
        """Mask ``x`` with a sampled relaxed k-hot vector and return ``(mu, logvar)``."""
        scores = self.weight_creator(x)
        mask = sample_subset(scores, self.k, self.t)
        hidden = F.relu(self.fc1(x * mask))
        return self.fc21(hidden), self.fc22(hidden)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``."""
        std = (0.5 * logvar).exp()
        return mu + torch.randn_like(std) * std

    def decode(self, z):
        """Map latent codes to input space; ReLU keeps outputs non-negative."""
        return F.relu(self.fc4(F.relu(self.fc3(z))))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar


In [24]:
### KLD between the two variational blocks

# KL divergence D(P_1 || P_2) between two diagonal Gaussians.
def kld_joint_autoencoders(mu_1, mu_2, logvar_1, logvar_2):
    """Return KL(N(mu_1, exp(logvar_1)) || N(mu_2, exp(logvar_2))), summed.

    Closed form from equation 6 of Doersch, "Tutorial on Variational
    Autoencoders" (https://arxiv.org/pdf/1606.05908.pdf), per dimension:
        0.5 * (-1 - log(s1^2/s2^2) + (m1 - m2)^2 / s2^2 + s1^2/s2^2)

    Args:
        mu_1, mu_2, logvar_1, logvar_2: tensors of matching shape; dimension 1
            is the latent dimension.

    Returns:
        Scalar tensor: the KL summed over latent dimensions and then the batch.
    """
    log_ratio = logvar_1 - logvar_2
    squared_gap = (mu_1 - mu_2).pow(2)
    per_dim = 0.5 * (-1 - log_ratio + squared_gap / logvar_2.exp() + torch.exp(log_ratio))
    per_sample = per_dim.sum(dim=1)
    return per_sample.sum()

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function_per_autoencoder(recon_x, x, mu, logvar):
    """Summed squared reconstruction error plus KL(q(z|x) || N(0, I)).

    Args:
        recon_x: model reconstruction, same shape as ``x``.
        x: target batch.
        mu, logvar: posterior mean and log-variance.

    Returns:
        Scalar tensor (sum over elements and batch, not averaged).
    """
    reconstruction = F.mse_loss(recon_x, x, reduction='sum')
    # Appendix B of Kingma & Welling, "Auto-Encoding Variational Bayes", ICLR 2014
    # (https://arxiv.org/abs/1312.6114): -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_to_prior = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + kl_to_prior


def loss_function(x, ae_1, ae_2):
    """Compute both models' VAE losses and the KL between their posteriors on ``x``.

    Both autoencoders must return ``(recon_x, mu, logvar)`` when called.
    ``ae_1`` is meant to be the vanilla VAE and ``ae_2`` the feature-selecting
    one (L1-penalised or Gumbel).

    Returns:
        ``(loss of ae_1, loss of ae_2, KL(q_1 || q_2))`` as scalar tensors.
    """
    (recon_1, mu_1, logvar_1), (recon_2, mu_2, logvar_2) = [ae(x) for ae in (ae_1, ae_2)]

    per_model = [
        loss_function_per_autoencoder(recon, x, mu, logvar)
        for recon, mu, logvar in ((recon_1, mu_1, logvar_1), (recon_2, mu_2, logvar_2))
    ]
    agreement = kld_joint_autoencoders(mu_1, mu_2, logvar_1, logvar_2)
    return per_model[0], per_model[1], agreement

In [47]:
def train_joint(df, model1, model2, optimizer, epoch):
    """Train two VAEs for one epoch on the joint objective.

    Minimises ``loss(model1) + loss(model2) + KL(q_1 || q_2)`` as returned by
    ``loss_function``, over a random permutation of the rows of ``df``. Uses the
    module-level ``batch_size`` and ``log_interval``.

    Args:
        df: 2-D tensor of training rows; no device transfer happens here, so it
            must already live on the models' device.
        model1, model2: the two autoencoders (vanilla first).
        optimizer: a single optimizer covering the parameters of both models.
        epoch: epoch number, used only for logging.

    Returns:
        Average joint loss per sample over the epoch (also printed).
    """
    model1.train()
    model2.train()
    train_loss = 0
    n_samples = len(df)
    permutations = torch.randperm(df.shape[0])
    for i in range(math.ceil(n_samples / batch_size)):
        batch_ind = permutations[i * batch_size : (i + 1) * batch_size]
        batch_data = df[batch_ind, :]

        optimizer.zero_grad()
        loss_vae_1, loss_vae_2, joint_kld_loss = loss_function(batch_data, model1, model2)
        loss = loss_vae_1 + loss_vae_2 + joint_kld_loss
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if i % log_interval == 0:
            n_batch = len(batch_data)
            seen = i * batch_size
            # Progress is samples seen over total samples; the previous version
            # divided the batch index by the sample count.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, n_samples,
                100. * seen / n_samples,
                loss.item() / n_batch))

            # .item() so plain numbers are printed, not tensor reprs.
            print('Loss vae 1: {}\tLoss vae 2: {}\tJoint KLD Loss {}'.format(
                loss_vae_1.item() / n_batch,
                loss_vae_2.item() / n_batch,
                joint_kld_loss.item() / n_batch))

    avg_loss = train_loss / n_samples
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss

In [48]:
def test_joint(df, model1, model2, epoch):
    """Evaluate the joint objective on ``df`` without updating the models.

    Iterates over ``df`` in order in chunks of the module-level ``batch_size``,
    moving each chunk to the module-level ``device``.

    Args:
        df: 2-D tensor of evaluation rows.
        model1, model2: the two autoencoders (vanilla first).
        epoch: unused; kept for symmetry with ``train_joint``.

    Returns:
        Average joint loss per sample (also printed).
    """
    model1.eval()
    model2.eval()
    test_loss = 0.0
    with torch.no_grad():
        for start in range(0, len(df), batch_size):
            batch_data = df[start : start + batch_size].to(device)
            loss_vae_1, loss_vae_2, joint_kld_loss = loss_function(batch_data, model1, model2)
            # Accumulate a Python float instead of a device tensor.
            test_loss += (loss_vae_1 + loss_vae_2 + joint_kld_loss).item()

    test_loss /= len(df)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss

### set up the data

In [49]:
import scipy.io as sio

In [50]:
# Load the expression matrix; after the transpose rows are samples, columns features.
a = sio.loadmat(os.path.join(BASE_PATH_DATA, "zeisel", "CITEseq.mat"))
data = a['G'].T
N, d = data.shape

# log(1 + x) transform of the integer entries; log1p is accurate near zero.
data = np.log1p(data)

# Min-max scale each feature (column) to [0, 1], vectorised over columns.
# A constant column would divide by zero and become NaN; its range is set to 1
# so it maps to 0 instead.
col_min = data.min(axis=0)
col_range = data.max(axis=0) - col_min
col_range[col_range == 0] = 1
data = (data - col_min) / col_range

# Load labels from file.
a = sio.loadmat(os.path.join(BASE_PATH_DATA, "zeisel", "CITEseq-labels.mat"))
l_aux = a['labels']
labels = np.array([i for [i] in l_aux])

# Load names from file. Iterate over the stored array itself rather than
# range(N): presumably there is one name per feature (d of them), not per
# sample — TODO confirm against the .mat file.
a = sio.loadmat(os.path.join(BASE_PATH_DATA, "zeisel", "CITEseq_names.mat"))
names = [entry[0][0] for entry in a['citeseq_names']]

In [51]:
# Random row order for an 80/20 train/test split.
slices = np.random.permutation(data.shape[0])
upto = int(.8 * len(data))

In [52]:
# First 80% of the permuted rows train, the rest test; both go to the device.
train_data = Tensor(data[slices[:upto]]).to(device)
test_data = Tensor(data[slices[upto:]]).to(device)

In [53]:
# Vanilla baseline: 250 hidden units, 200-dimensional latent space.
vanilla_vae_zeisel = VAE(hidden_layer_size=250, z_size=200)
vanilla_vae_zeisel.to(device)

VAE(
  (fc1): Linear(in_features=500, out_features=250, bias=True)
  (fc21): Linear(in_features=250, out_features=200, bias=True)
  (fc22): Linear(in_features=250, out_features=200, bias=True)
  (fc3): Linear(in_features=200, out_features=250, bias=True)
  (fc4): Linear(in_features=250, out_features=500, bias=True)
)

In [54]:
# Display the vanilla model's architecture.
vanilla_vae_zeisel

VAE(
  (fc1): Linear(in_features=500, out_features=250, bias=True)
  (fc21): Linear(in_features=250, out_features=200, bias=True)
  (fc22): Linear(in_features=250, out_features=200, bias=True)
  (fc3): Linear(in_features=200, out_features=250, bias=True)
  (fc4): Linear(in_features=250, out_features=500, bias=True)
)

In [55]:
# Gumbel feature-selection VAE keeping k = 200 of the 500 input features.
vae_gumbel = VAE_Gumbel(input_size=500, hidden_layer_size=250, z_size=200, k=200)
vae_gumbel.to(device)

VAE_Gumbel(
  (weight_creator): Sequential(
    (0): Linear(in_features=500, out_features=250, bias=True)
    (1): ReLU()
    (2): Linear(in_features=250, out_features=500, bias=True)
  )
  (fc1): Linear(in_features=500, out_features=250, bias=True)
  (fc21): Linear(in_features=250, out_features=200, bias=True)
  (fc22): Linear(in_features=250, out_features=200, bias=True)
  (fc3): Linear(in_features=200, out_features=250, bias=True)
  (fc4): Linear(in_features=250, out_features=500, bias=True)
)

In [56]:
# One Adam optimiser over both models so the joint loss updates them together.
joint_params = [*vanilla_vae_zeisel.parameters(), *vae_gumbel.parameters()]
joint_optimizer = torch.optim.Adam(joint_params, lr=lr, betas=(b1, b2))

In [57]:
# Fifty epochs of joint training, evaluating on the held-out rows after each.
for epoch in range(1, 51):
    train_joint(train_data, vanilla_vae_zeisel, vae_gumbel, joint_optimizer, epoch)
    test_joint(test_data, vanilla_vae_zeisel, vae_gumbel, epoch)

Loss vae 1: 38.75425338745117	Loss vae 2: 38.25804901123047	Joint KLD Loss 1.3607683181762695
Loss vae 1: 26.20536231994629	Loss vae 2: 26.407848358154297	Joint KLD Loss 0.2125486582517624
====> Epoch: 1 Average loss: 59.7581
====> Test set loss: 51.7693
Loss vae 1: 24.932126998901367	Loss vae 2: 25.524137496948242	Joint KLD Loss 0.26394617557525635
Loss vae 1: 21.23501205444336	Loss vae 2: 20.887840270996094	Joint KLD Loss 0.2497882843017578
====> Epoch: 2 Average loss: 48.3270
====> Test set loss: 45.1713
Loss vae 1: 22.338294982910156	Loss vae 2: 22.102371215820312	Joint KLD Loss 0.3088669776916504
Loss vae 1: 19.6541805267334	Loss vae 2: 20.398244857788086	Joint KLD Loss 0.2005358636379242
====> Epoch: 3 Average loss: 42.8524
====> Test set loss: 40.6811
Loss vae 1: 21.339506149291992	Loss vae 2: 21.048254013061523	Joint KLD Loss 0.21843723952770233
Loss vae 1: 18.05624771118164	Loss vae 2: 18.214248657226562	Joint KLD Loss 0.17674225568771362
====> Epoch: 4 Average loss: 39.1408
=

Loss vae 1: 14.997343063354492	Loss vae 2: 14.743370056152344	Joint KLD Loss 0.050267454236745834
====> Epoch: 24 Average loss: 31.0699
====> Test set loss: 31.0906
Loss vae 1: 15.18156909942627	Loss vae 2: 15.287999153137207	Joint KLD Loss 0.07508301734924316
Loss vae 1: 16.07442283630371	Loss vae 2: 15.987509727478027	Joint KLD Loss 0.0756925567984581
====> Epoch: 25 Average loss: 31.0014
====> Test set loss: 30.9272
Loss vae 1: 14.69028377532959	Loss vae 2: 14.180795669555664	Joint KLD Loss 0.059787970036268234
Loss vae 1: 16.044801712036133	Loss vae 2: 16.172056198120117	Joint KLD Loss 0.06199256703257561
====> Epoch: 26 Average loss: 31.0274
====> Test set loss: 30.9478
Loss vae 1: 16.265708923339844	Loss vae 2: 16.29326629638672	Joint KLD Loss 0.07221371680498123
Loss vae 1: 14.635459899902344	Loss vae 2: 15.20244312286377	Joint KLD Loss 0.062713623046875
====> Epoch: 27 Average loss: 30.9603
====> Test set loss: 30.8924
Loss vae 1: 14.717414855957031	Loss vae 2: 14.2828559875488

====> Epoch: 47 Average loss: 30.4621
====> Test set loss: 30.3816
Loss vae 1: 15.341231346130371	Loss vae 2: 15.212363243103027	Joint KLD Loss 0.05455673113465309
Loss vae 1: 15.74246883392334	Loss vae 2: 15.321748733520508	Joint KLD Loss 0.04794357344508171
====> Epoch: 48 Average loss: 30.3907
====> Test set loss: 30.4078
Loss vae 1: 15.447562217712402	Loss vae 2: 16.094022750854492	Joint KLD Loss 0.053147003054618835
Loss vae 1: 14.497469902038574	Loss vae 2: 15.361926078796387	Joint KLD Loss 0.048253707587718964
====> Epoch: 49 Average loss: 30.3841
====> Test set loss: 30.3752
Loss vae 1: 15.511667251586914	Loss vae 2: 16.03175926208496	Joint KLD Loss 0.05240229517221451
Loss vae 1: 15.115747451782227	Loss vae 2: 15.207712173461914	Joint KLD Loss 0.05242437869310379
====> Epoch: 50 Average loss: 30.4151
====> Test set loss: 30.3500


In [58]:
# Mean per-row squared reconstruction error of the vanilla VAE on 64 test rows.
with torch.no_grad():
    batch = test_data[0:64, :]
    reconstruction = vanilla_vae_zeisel(batch)[0]
    print(((reconstruction - batch) ** 2).sum() / 64)

tensor(13.6654, device='cuda:0')


In [59]:
# Same metric for the Gumbel VAE on the same 64 test rows.
with torch.no_grad():
    batch = test_data[0:64, :]
    reconstruction = vae_gumbel(batch)[0]
    print(((reconstruction - batch) ** 2).sum() / 64)

tensor(13.6343, device='cuda:0')


# Sweep the subset size k: for each k, jointly train a fresh vanilla VAE and a
# fresh Gumbel VAE (latent size 20) for 50 epochs, then record the Gumbel
# model's mean per-row squared reconstruction error on the first 64 test rows.
final_losses = []
for k in [10, 25, 50, 250]:
    vanilla_vae_zeisel = VAE(250, 20)
    vanilla_vae_zeisel.to(device)

    vae_gumbel = VAE_Gumbel(500, 250, 20, k = k)
    vae_gumbel.to(device)

    # train_joint takes ONE optimizer (passing two raised a TypeError). Adam keeps
    # per-parameter state, so a single Adam over both models is equivalent to
    # one Adam per model stepped on the same joint loss.
    sweep_optimizer = torch.optim.Adam(
        list(vanilla_vae_zeisel.parameters()) + list(vae_gumbel.parameters()),
        lr=lr,
        betas=(b1, b2))
    for epoch in range(1, 50 + 1):
        train_joint(train_data, vanilla_vae_zeisel, vae_gumbel, sweep_optimizer, epoch)

    vae_gumbel.eval()
    with torch.no_grad():
        batch = test_data[0:64, :]
        recon_error = (torch.sum((vae_gumbel(batch)[0] - batch) ** 2) / 64).item()
    final_losses.append(recon_error)
    print("Gumbel Reconstruction Loss with Joint Training at k {}: {:.4f}".format(k, recon_error))

final_losses

In [60]:
# Gumbel VAE reconstruction of the first test row.
vae_gumbel(test_data[0])[0]

tensor([0.5349, 0.5274, 0.4967, 0.4788, 0.4552, 0.4752, 0.4918, 0.4644, 0.4842,
        0.4379, 0.4471, 0.4018, 0.4543, 0.3787, 0.3776, 0.4840, 0.3472, 0.3995,
        0.4212, 0.4618, 0.4138, 0.3804, 0.4153, 0.4940, 0.3286, 0.4405, 0.4333,
        0.4224, 0.4106, 0.3891, 0.4117, 0.4167, 0.3601, 0.4228, 0.4414, 0.4511,
        0.3945, 0.4099, 0.3780, 0.4205, 0.4538, 0.3754, 0.4761, 0.4042, 0.3593,
        0.3754, 0.3514, 0.4047, 0.3747, 0.3668, 0.2910, 0.4076, 0.4145, 0.2974,
        0.4268, 0.4075, 0.3685, 0.3524, 0.3864, 0.4225, 0.3351, 0.3585, 0.2713,
        0.3217, 0.3263, 0.3203, 0.3258, 0.3880, 0.3618, 0.3389, 0.3422, 0.3264,
        0.3431, 0.2994, 0.3474, 0.3597, 0.3084, 0.3165, 0.3374, 0.2574, 0.3257,
        0.2464, 0.3373, 0.2775, 0.3463, 0.2811, 0.3196, 0.2957, 0.3210, 0.2780,
        0.3342, 0.2280, 0.2287, 0.3220, 0.3781, 0.2444, 0.2456, 0.2455, 0.3262,
        0.2350, 0.2349, 0.1963, 0.2622, 0.2510, 0.2336, 0.2822, 0.2149, 0.3032,
        0.1881, 0.2813, 0.1917, 0.2142, 

In [61]:
# The first test row itself, for comparison.
test_data[0]

tensor([0.6068, 0.6551, 0.5780, 0.5294, 0.4702, 0.6572, 0.6578, 0.5857, 0.6613,
        0.5666, 0.5739, 0.3545, 0.6372, 0.3434, 0.2615, 0.6320, 0.2276, 0.3662,
        0.5136, 0.4590, 0.6113, 0.4520, 0.4931, 0.6468, 0.2969, 0.5324, 0.5899,
        0.5625, 0.5376, 0.5772, 0.5386, 0.6193, 0.2588, 0.5929, 0.5571, 0.6239,
        0.6625, 0.5682, 0.5634, 0.5998, 0.5809, 0.5121, 0.6033, 0.4546, 0.3926,
        0.5438, 0.5582, 0.4037, 0.4857, 0.4359, 0.3070, 0.5646, 0.6615, 0.2881,
        0.7467, 0.4565, 0.5441, 0.3444, 0.5953, 0.4732, 0.5364, 0.4981, 0.2493,
        0.1810, 0.5172, 0.2361, 0.5444, 0.5303, 0.4621, 0.1596, 0.4457, 0.3931,
        0.1700, 0.1679, 0.2717, 0.5608, 0.4428, 0.3694, 0.4735, 0.1499, 0.4308,
        0.0000, 0.4117, 0.5303, 0.4604, 0.4424, 0.5058, 0.4513, 0.4016, 0.4232,
        0.4628, 0.3828, 0.2537, 0.2999, 0.4363, 0.3020, 0.3709, 0.1700, 0.2000,
        0.3182, 0.4113, 0.1407, 0.5723, 0.5377, 0.0000, 0.4232, 0.3623, 0.0000,
        0.2759, 0.0000, 0.2275, 0.2652, 

In [62]:
# Vanilla VAE reconstruction of the first test row.
vanilla_vae_zeisel(test_data[0])[0]

tensor([0.5567, 0.5163, 0.4704, 0.4493, 0.4056, 0.4780, 0.5040, 0.4574, 0.4642,
        0.4493, 0.3840, 0.3351, 0.4668, 0.3232, 0.3194, 0.4631, 0.2301, 0.3734,
        0.4104, 0.4340, 0.4029, 0.3576, 0.3383, 0.4662, 0.2737, 0.4077, 0.3662,
        0.4180, 0.4007, 0.3476, 0.3970, 0.4113, 0.2377, 0.3731, 0.4503, 0.4273,
        0.3733, 0.3936, 0.3678, 0.4015, 0.4214, 0.3478, 0.4196, 0.3530, 0.3519,
        0.3371, 0.3481, 0.3700, 0.2923, 0.4254, 0.2635, 0.3925, 0.3928, 0.2604,
        0.4173, 0.3662, 0.3461, 0.3125, 0.3394, 0.4337, 0.3063, 0.3426, 0.2847,
        0.2596, 0.2968, 0.2763, 0.3311, 0.3793, 0.3387, 0.3180, 0.3449, 0.2736,
        0.2344, 0.2327, 0.3148, 0.3415, 0.2950, 0.2932, 0.2958, 0.2145, 0.3119,
        0.1669, 0.3295, 0.2516, 0.3234, 0.2192, 0.3227, 0.3188, 0.2986, 0.2863,
        0.3453, 0.1990, 0.2382, 0.2161, 0.3761, 0.2013, 0.2577, 0.1789, 0.3311,
        0.2012, 0.2155, 0.1727, 0.2801, 0.2502, 0.1598, 0.3068, 0.2291, 0.2302,
        0.1302, 0.2056, 0.0892, 0.1937, 

In [63]:
# Feature indices the vanilla VAE reconstructs as exactly zero for the first test row.
np.flatnonzero(vanilla_vae_zeisel(test_data[0])[0].detach().cpu().numpy() == 0)

array([176, 199, 212, 215, 225, 262, 263, 265, 267, 279, 283, 292, 311,
       331, 376, 391, 407, 409, 445, 447, 454, 476, 488])

In [64]:
# Number of non-zero entries in the Gumbel VAE reconstruction of the first test row.
print((vae_gumbel(test_data[0])[0] != 0).sum())

tensor(499, device='cuda:0')


In [65]:
# Number of non-zero entries in the first test row.
print((test_data[0] != 0).sum())

tensor(263, device='cuda:0')


In [66]:
# Number of non-zero entries in the vanilla VAE reconstruction of the first test row.
print((vanilla_vae_zeisel(test_data[0])[0] != 0).sum())

tensor(488, device='cuda:0')


In [67]:
# Relaxed k-hot masks for the first 64 test rows, sampled with the model's own
# k and t. The previous hard-coded k=200 stopped matching once the k sweep
# rebound `vae_gumbel` (its last model has k=250).
with torch.no_grad():
    w = vae_gumbel.weight_creator(test_data[0:64, :])
    subset_indices = sample_subset(w, k = vae_gumbel.k, t = vae_gumbel.t)

In [68]:
# Each row of the relaxed mask should sum to k.
torch.sum(subset_indices, dim=1)

tensor([200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000,
        200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000,
        200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000,
        200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000,
        200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000,
        200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000,
        200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000,
        200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000,
        200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000, 200.0000,
        200.0000], device='cuda:0', grad_fn=<SumBackward1>)

In [69]:
def len_unique(arr):
    """Return the number of distinct values in ``arr`` (flattened)."""
    distinct = np.unique(arr)
    return distinct.size

In [70]:
# Number of distinct values in every test row.
np.apply_along_axis(len_unique, 1, test_data.detach().cpu().numpy())

array([182, 252, 149, ..., 166, 220, 193])

In [71]:
# Fraction of test rows where feature 499 is positive.
np.mean(test_data[:, 499].detach().cpu().numpy() > 0)

0.3677494199535963

In [72]:
# Fraction of test rows where feature 60 is positive.
np.mean(test_data[:, 60].detach().cpu().numpy() > 0)

0.9408352668213457