So we know that Gumbel selects features relatively well. Its effect on Zeisel, though, is a bit muddled because of reconstruction. Let's build a simple synthetic dataset: half the features are real and half are noise.

In [1]:
import torch


from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F

import numpy as np

import matplotlib.pyplot as plt
#from sklearn.manifold import TSNE

#import math

#import gc

from utils import *

from sklearn.preprocessing import MinMaxScaler

In [2]:
# Fix the RNG state of PyTorch and NumPy (in that order) so data generation,
# the train/test split and model initialisation are reproducible.
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(0)

In [3]:
import os

# Root directory for on-disk datasets. It defaults to the original cluster path.
# Set the SPARSE_SUBSET_DATA environment variable (e.g. to '../data/' when running
# locally) to override it without editing the notebook.
BASE_PATH_DATA = os.environ.get("SPARSE_SUBSET_DATA", '/scratch/ns3429/sparse-subset/data/')

In [4]:
# Training length. The vanilla VAE gives really good results on this synthetic
# data with n_epochs = 50; a smaller value (here 20) keeps local runs reasonably short.
n_epochs = 20
batch_size = 64

# Adam optimizer settings: learning rate and the (beta1, beta2) pair.
lr = 0.0001
b1, b2 = 0.9, 0.999

# Small constant. An earlier variant used the smallest positive float32,
# np.finfo(tf.float32.as_numpy_dtype).tiny ~= 1.1754944e-38.
# NOTE(review): EPSILON is not referenced in this notebook's visible cells — confirm it is still needed.
EPSILON = 1e-10

In [5]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
# is_available() already returns a bool, so no ternary is needed.
cuda = torch.cuda.is_available()

# Legacy typed-tensor constructor matching the chosen device type.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

device = torch.device("cuda:0" if cuda else "cpu")
# device = 'cpu'  # uncomment to force CPU (note: `Tensor` above still follows `cuda`)

In [6]:
print("Device")
print(device)

Device
cuda:0


Create the data

In [7]:
# Synthetic data sizes:
#   D      - number of real features (D noise features are appended later)
#   N      - number of samples
#   z_size - dimension of the latent Gaussian that generates the real features
D, N, z_size = 30, 5000, 5

In [8]:
latent_data = np.random.normal(loc=0.0, scale=1.0, size=(N, z_size))  # (N, z_size) standard-normal latents

In [9]:
# Fixed, randomly initialised generator mapping z_size latents to D real features.
data_mapper = nn.Sequential(
    nn.Linear(z_size, 2 * z_size, bias=False),  # latent -> hidden, no bias
    nn.Tanh(),
    nn.Linear(2 * z_size, D),                   # hidden -> D features (bias on by default)
    nn.ReLU(),                                  # negative outputs become exactly 0 -> sparse features
)
data_mapper.to(device)

# Freeze the generator: it only produces data and is never trained.
data_mapper.requires_grad_(False)

Sequential(
  (0): Linear(in_features=5, out_features=10, bias=False)
  (1): Tanh()
  (2): Linear(in_features=10, out_features=30, bias=True)
  (3): ReLU()
)

In [10]:
# Convert the NumPy latents to a float32 tensor on `device`. torch.as_tensor replaces
# the legacy typed constructor `Tensor(...)`, which recent PyTorch versions warn
# against. It also still works if this cell is re-run after latent_data is a tensor.
latent_data = torch.as_tensor(latent_data, dtype=torch.float32, device=device)
latent_data.requires_grad_(False)

tensor([[ 1.7641,  0.4002,  0.9787,  2.2409,  1.8676],
        [-0.9773,  0.9501, -0.1514, -0.1032,  0.4106],
        [ 0.1440,  1.4543,  0.7610,  0.1217,  0.4439],
        ...,
        [ 0.2501, -1.0168,  0.0459,  0.5006,  1.2243],
        [-0.5595,  1.5234, -0.5857,  0.8466, -0.1063],
        [ 0.7700,  0.7508, -0.5606, -1.7603,  0.4371]], device='cuda:0')

In [11]:
# data_mapper is frozen, so no autograd graph is needed for this forward pass.
with torch.no_grad():
    actual_data = data_mapper(latent_data)

In [12]:
# Number of nonzero real features in each of the first five samples.
for row in actual_data[:5]:
    print((row != 0).sum())

tensor(19, device='cuda:0')
tensor(12, device='cuda:0')
tensor(18, device='cuda:0')
tensor(14, device='cuda:0')
tensor(14, device='cuda:0')


For each sample, about half of the features are nonzero, whereas in Zeisel only about 25% are nonzero. That makes this dataset easier than Zeisel, which is good for a first check.

In [13]:
# D columns of i.i.d. Gaussian noise (mean 0, std 0.01). They are drawn on the
# CPU as a flat vector, reshaped to (N, D), then moved to the device.
noise_features = (
    torch.empty(N * D)
    .normal_(mean=0, std=0.01)
    .reshape(N, D)
    .to(device)
)
noise_features.requires_grad_(False)

tensor([[-1.0921e-02, -6.1085e-04, -1.4928e-02,  ..., -1.4309e-02,
          1.6859e-02, -1.2177e-02],
        [ 7.6496e-03,  1.1971e-02, -2.2414e-02,  ...,  1.0256e-02,
         -5.5957e-03,  4.3434e-03],
        [ 2.7566e-03,  1.0969e-03,  3.5942e-03,  ...,  6.0039e-03,
          8.7524e-04,  7.0365e-03],
        ...,
        [ 1.8449e-02,  8.3797e-04, -8.9499e-03,  ...,  8.9735e-04,
         -1.6982e-03,  7.8153e-03],
        [-1.0649e-02, -9.6204e-03, -8.1562e-03,  ..., -2.2612e-04,
         -1.4104e-02, -8.2127e-03],
        [ 2.1183e-02, -1.1416e-02,  1.8769e-03,  ..., -1.3100e-02,
         -6.2333e-03, -4.3646e-05]], device='cuda:0')

In [14]:
# Append the D noise columns after the D real features, giving shape (N, 2*D).
# Re-running this cell would keep appending noise, so fail loudly instead.
assert actual_data.shape[1] == D, (
    "noise already appended to actual_data; re-run the data-generation cells first"
)
actual_data = torch.cat([actual_data, noise_features], dim=1)

In [15]:
actual_data.size()

torch.Size([5000, 60])

In [16]:
# Scale every column (real and noise) to [0, 1]. The binary cross-entropy losses
# used later require targets in that range.
# NOTE(review): the scaler is fit on all N rows *before* the train/test split,
# so test-set min/max values leak into the scaling of the training data.
# Splitting first and fitting on the training rows only would avoid this.
scaler = MinMaxScaler()
actual_data = Tensor(scaler.fit_transform(actual_data.cpu().numpy()))

In [17]:
torch.std(actual_data, dim=0)

tensor([0.1866, 0.2313, 0.2054, 0.2209, 0.2323, 0.1899, 0.1801, 0.1969, 0.1133,
        0.2353, 0.0925, 0.1310, 0.1725, 0.1902, 0.2294, 0.2275, 0.2082, 0.0530,
        0.0980, 0.1738, 0.1728, 0.2156, 0.0460, 0.0932, 0.0255, 0.1816, 0.1587,
        0.2263, 0.2125, 0.2393, 0.1326, 0.1439, 0.1281, 0.1421, 0.1297, 0.1413,
        0.1492, 0.1272, 0.1233, 0.1420, 0.1422, 0.1363, 0.1256, 0.1288, 0.1377,
        0.1438, 0.1337, 0.1331, 0.1258, 0.1346, 0.1507, 0.1223, 0.1429, 0.1343,
        0.1348, 0.1361, 0.1388, 0.1391, 0.1426, 0.1383], device='cuda:0')

Standard deviations of the real and noise features are comparable.

In [18]:
torch.max(actual_data, dim=0)

torch.return_types.max(
values=tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000], device='cuda:0'),
indices=tensor([ 738, 4262, 1553, 1484, 2220, 1316, 3892, 1316, 3121, 3883, 1838,  623,
        1004, 4856,  689, 2033, 2038, 1316, 4515, 4562, 4668,  616,  894, 4515,
        1885, 3892, 4615,  819, 4397, 4293,  713, 2220, 3813, 4659, 4389, 3659,
         309, 1804,  495, 4790, 3110, 4671,   36,    0, 1215,  148, 4008, 1317,
        2503, 1402, 1580, 2684, 4078, 3334, 1376, 2499, 1301, 3114, 4203, 3183],
     

In [19]:
torch.min(actual_data, dim=0)

torch.return_types.min(
values=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0'),
indices=tensor([   1,    8,    0,    1,    0,    5,    1,    1,    1,    0,    1,    1,
           0,   13,    3,    3,    1,    0,    0,    1,    1,    8,    0,    0,
           0,    1,  254,    4,    0,    0,  454, 3677, 1909, 3750, 3638, 4476,
         553, 4105,  289, 1150, 2707, 1846, 3579, 3101,  299, 1324, 3277, 4318,
        3023,  967, 2932, 3588,  919, 1190,  271, 2937, 3428, 3955, 2719,  198],
       device='cuda:0'))

In [20]:
# Number of nonzero entries (real + noise) in each of the first five scaled samples.
for row in actual_data[:5]:
    print((row != 0).sum())

tensor(49, device='cuda:0')
tensor(42, device='cuda:0')
tensor(48, device='cuda:0')
tensor(44, device='cuda:0')
tensor(44, device='cuda:0')


In [21]:
# Random 80/20 split of the rows: shuffle the row indices, keep the first 80% for training.
upto = int(.8 * len(actual_data))
slices = np.random.permutation(len(actual_data))
train_data, test_data = actual_data[slices[:upto]], actual_data[slices[upto:]]

In [22]:
train_data.size()

torch.Size([4000, 60])

In [23]:
test_data.size()

torch.Size([1000, 60])

The data is ready. Now we feed it into two Gumbel models: one trained to match a pretrained VAE, and one trained jointly with a VAE.

Pretrain the VAE first

In [24]:
# Vanilla VAE over all 2*D features. Judging by the printed architecture,
# enc_mean/enc_logvar have 20 outputs, i.e. latent width 20.
pretrain_vae = VAE(2*D, 100, 20)
pretrain_vae.to(device)

pretrain_vae_optimizer = torch.optim.Adam(
    pretrain_vae.parameters(),
    lr=lr,
    betas=(b1, b2),
)

# Alternative optimizer, kept for reference:
# pretrain_vae_optimizer = torch.optim.SGD(pretrain_vae.parameters(), lr=lr, momentum=0.9)

In [25]:
# Epochs are numbered from 1. Each epoch runs train() on the training split, then test() on the held-out split.
for epoch in range(1, n_epochs + 1):
    train(train_data, pretrain_vae, pretrain_vae_optimizer, epoch, batch_size)
    test(test_data, pretrain_vae, epoch, batch_size)

====> Epoch: 1 Average loss: 40.4244
====> Test set loss: 38.9524
====> Epoch: 2 Average loss: 37.3049
====> Test set loss: 35.4656
====> Epoch: 3 Average loss: 34.6850
====> Test set loss: 34.1042
====> Epoch: 4 Average loss: 33.8812
====> Test set loss: 33.6500
====> Epoch: 5 Average loss: 33.5430
====> Test set loss: 33.3353
====> Epoch: 6 Average loss: 33.3164
====> Test set loss: 33.1817
====> Epoch: 7 Average loss: 33.1319
====> Test set loss: 33.0265
====> Epoch: 8 Average loss: 33.0185
====> Test set loss: 32.9396
====> Epoch: 9 Average loss: 32.9161
====> Test set loss: 32.8299
====> Epoch: 10 Average loss: 32.8035
====> Test set loss: 32.6680
====> Epoch: 11 Average loss: 32.5824
====> Test set loss: 32.3997
====> Epoch: 12 Average loss: 32.3290
====> Test set loss: 32.1813
====> Epoch: 13 Average loss: 32.1558
====> Test set loss: 32.0547
====> Epoch: 14 Average loss: 32.0321
====> Test set loss: 31.9535
====> Epoch: 15 Average loss: 31.9635
====> Test set loss: 31.8698
====

In [26]:
# Mean binary cross-entropy between the model's first output (its reconstruction)
# and the test inputs, computed without tracking gradients.
with torch.no_grad():
    print("Test Loss")
    print(F.binary_cross_entropy(input=pretrain_vae(test_data)[0], target=test_data))

Test Loss
tensor(0.5172, device='cuda:0')


Actually pretty good! Roughly 35 percent off when wrong.

The test loss is 0.49 with n_epochs = 50.
It is 0.54 with n_epochs = 10.

Note: if the final layer of the data mapper is not a ReLU, the reconstruction is usually spot on. It becomes troublesome when some features are sparse (the ReLU sets many of them to exactly zero).

Compare means

In [27]:
torch.mean(train_data, dim=0)[:D]  # per-feature means of the D real features

tensor([0.0931, 0.3051, 0.1384, 0.2096, 0.2092, 0.2112, 0.1042, 0.2325, 0.0347,
        0.2014, 0.0244, 0.0502, 0.0919, 0.2735, 0.2254, 0.3294, 0.2786, 0.0084,
        0.0338, 0.1001, 0.0849, 0.2512, 0.0048, 0.0257, 0.0015, 0.0874, 0.4764,
        0.3270, 0.2107, 0.2899], device='cuda:0')

In [28]:
torch.mean(test_data, dim=0)[:D]  # per-feature means of the D real features

tensor([0.0978, 0.3031, 0.1394, 0.2099, 0.2063, 0.2098, 0.1076, 0.2316, 0.0298,
        0.1968, 0.0261, 0.0561, 0.0848, 0.2727, 0.2320, 0.3355, 0.2739, 0.0056,
        0.0301, 0.0953, 0.0889, 0.2483, 0.0063, 0.0231, 0.0009, 0.0924, 0.4782,
        0.3251, 0.2092, 0.2881], device='cuda:0')

In [29]:
torch.mean(pretrain_vae(test_data)[0], dim=0)[:D]  # reconstruction means, real features

tensor([0.0974, 0.3121, 0.1461, 0.2162, 0.2220, 0.2156, 0.1141, 0.2397, 0.0427,
        0.2098, 0.0319, 0.0692, 0.1002, 0.2707, 0.2366, 0.3450, 0.2797, 0.0256,
        0.0445, 0.0980, 0.1043, 0.2610, 0.0198, 0.0355, 0.0178, 0.1067, 0.4767,
        0.3174, 0.2263, 0.2916], device='cuda:0', grad_fn=<SliceBackward>)

Compare standard deviations

In [30]:
torch.std(test_data, dim=0)[:D]  # per-feature std of the D real features

tensor([0.1913, 0.2338, 0.2076, 0.2199, 0.2306, 0.1890, 0.1805, 0.1945, 0.1051,
        0.2321, 0.0976, 0.1389, 0.1652, 0.1922, 0.2341, 0.2330, 0.2048, 0.0390,
        0.0929, 0.1692, 0.1755, 0.2236, 0.0499, 0.0923, 0.0152, 0.1863, 0.1593,
        0.2244, 0.2096, 0.2350], device='cuda:0')

In [75]:
torch.std(pretrain_vae(test_data)[0], dim=0)[:D]  # reconstruction std, real features

tensor([0.1150, 0.0651, 0.1637, 0.1249, 0.0813, 0.0640, 0.1193, 0.1396, 0.0550,
        0.1004, 0.0389, 0.0728, 0.0831, 0.1564, 0.2018, 0.1958, 0.0672, 0.0400,
        0.0595, 0.1027, 0.1094, 0.1252, 0.0288, 0.0476, 0.0292, 0.1376, 0.0661,
        0.1911, 0.0611, 0.1617], device='cuda:0')

In [74]:
torch.std(pretrain_vae(test_data)[0], dim=0)[D:]  # reconstruction std, the D noise columns (2*D total)

tensor([0.0227, 0.0189, 0.0205, 0.0207, 0.0211, 0.0221, 0.0240, 0.0220, 0.0226,
        0.0221, 0.0231, 0.0189, 0.0223, 0.0205, 0.0187, 0.0195, 0.0201, 0.0195,
        0.0187, 0.0216, 0.0235, 0.0188, 0.0225, 0.0241, 0.0205, 0.0223, 0.0199,
        0.0221, 0.0198, 0.0199], device='cuda:0')

In [32]:
# Per-feature ratio std(reconstruction) / std(data) over the D real test features.
# Values near 1 mean the reconstruction preserves that feature's spread.
# (Despite its name, average_std holds the per-feature ratios, not an average.)
recon_std = pretrain_vae(test_data)[0].std(dim=0)
average_std = recon_std[:D] / test_data.std(dim=0)[:D]

In [33]:
# Per-feature ratios, followed by their mean as a plain Python float.
print(average_std, average_std.mean().item(), sep="\n")

tensor([0.5937, 0.2845, 0.7928, 0.5789, 0.3512, 0.3359, 0.6825, 0.7307, 0.5134,
        0.4415, 0.3738, 0.5347, 0.4935, 0.8127, 0.8684, 0.8394, 0.3234, 1.0021,
        0.5954, 0.5871, 0.6287, 0.5625, 0.5662, 0.4930, 1.7515, 0.7608, 0.4121,
        0.8517, 0.2969, 0.6839], device='cuda:0', grad_fn=<DivBackward0>)
0.6247512102127075


The mean ratio is about 0.8 with n_epochs = 50 and 0.43 with n_epochs = 10.

Compare values

In [34]:
# Index of the test sample whose values are compared in the following cells.
samp = 45

In [35]:
test_data[samp][:D]  # real-feature values of the chosen test sample

tensor([0.0000, 0.5286, 0.0000, 0.0231, 0.6963, 0.3966, 0.0103, 0.2939, 0.0000,
        0.3152, 0.0000, 0.0000, 0.0000, 0.0617, 0.4150, 0.6262, 0.2110, 0.0000,
        0.0000, 0.0000, 0.0000, 0.6829, 0.0000, 0.0000, 0.0000, 0.0000, 0.3516,
        0.0000, 0.6491, 0.4560], device='cuda:0')

In [36]:
pretrain_vae(test_data)[0][samp][:D]  # reconstruction of the same sample, real features

tensor([0.0260, 0.4043, 0.0297, 0.2576, 0.3167, 0.2183, 0.1417, 0.3313, 0.0390,
        0.1503, 0.0259, 0.0783, 0.0397, 0.1553, 0.3596, 0.5376, 0.1995, 0.0185,
        0.0231, 0.0407, 0.1605, 0.3890, 0.0124, 0.0281, 0.0093, 0.1379, 0.4134,
        0.1563, 0.2425, 0.1784], device='cuda:0', grad_fn=<SliceBackward>)

In [37]:
# Mean absolute error over the real features of the chosen sample.
(test_data[samp, :D] - pretrain_vae(test_data)[0][samp, :D]).abs().mean()

tensor(0.1230, device='cuda:0', grad_fn=<MeanBackward0>)

In [38]:
# Average of the model's second output over the test set (presumably the posterior
# means from enc_mean, which has 20 outputs — TODO confirm against VAE.forward).
# The old `[:, :D]` slice used the data dimension D = 30 on this 20-wide latent
# tensor. It selected every column anyway, so it has been removed as misleading.
pretrain_vae(test_data)[1].mean()

tensor(-0.0455, device='cuda:0', grad_fn=<MeanBackward0>)

In [39]:
# Average of exp(third output) over the test set. If output [2] is the log-variance
# from enc_logvar (20 outputs — TODO confirm against VAE.forward), this is the mean
# posterior variance. The old `[:, :D]` slice (D = 30) selected every latent column,
# so it has been removed as misleading.
torch.exp(pretrain_vae(test_data)[2]).mean()

tensor(0.7776, device='cuda:0', grad_fn=<MeanBackward0>)

In [40]:
# Freeze all parameters of the pretrained VAE; it is passed unchanged to
# train_pre_trained below. The returned module repr shows the architecture.
pretrain_vae.requires_grad_(requires_grad=False)

VAE(
  (encoder): Sequential(
    (0): Linear(in_features=60, out_features=200, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=200, out_features=100, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=100, out_features=100, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Linear(in_features=100, out_features=100, bias=True)
    (7): LeakyReLU(negative_slope=0.01)
  )
  (enc_mean): Linear(in_features=100, out_features=20, bias=True)
  (enc_logvar): Linear(in_features=100, out_features=20, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=20, out_features=200, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=200, out_features=60, bias=True)
    (3): Sigmoid()
  )
)

This looks pretty good.

**Gumbel matching pretrained VAE next**

In [41]:
# let's see how it does here
vae_gumbel_with_pre = VAE_Gumbel(2*D, 100, 20, k = 3*z_size)
vae_gumbel_with_pre.to(device)
vae_gumbel_with_pre_optimizer = torch.optim.Adam(vae_gumbel_with_pre.parameters(), 
                                                lr=lr, 
                                                betas = (b1,b2))

In [42]:
# Each epoch: train_pre_trained (which also receives the frozen pretrain_vae),
# then evaluate on the test split.
for epoch in range(1, n_epochs + 1):
    train_pre_trained(
        train_data, vae_gumbel_with_pre, vae_gumbel_with_pre_optimizer,
        epoch, pretrain_vae, batch_size,
    )
    test(test_data, vae_gumbel_with_pre, epoch, batch_size)

====> Epoch: 1 Average loss: 126.1919
====> Test set loss: 39.1105
====> Epoch: 2 Average loss: 109.5175
====> Test set loss: 36.7800
====> Epoch: 3 Average loss: 94.6618
====> Test set loss: 34.9128
====> Epoch: 4 Average loss: 90.6348
====> Test set loss: 34.0960
====> Epoch: 5 Average loss: 85.0320
====> Test set loss: 33.6748
====> Epoch: 6 Average loss: 70.3975
====> Test set loss: 33.3098
====> Epoch: 7 Average loss: 54.9485
====> Test set loss: 32.9794
====> Epoch: 8 Average loss: 50.1609
====> Test set loss: 32.6813
====> Epoch: 9 Average loss: 47.6164
====> Test set loss: 32.4655
====> Epoch: 10 Average loss: 45.8215
====> Test set loss: 32.2971
====> Epoch: 11 Average loss: 44.7431
====> Test set loss: 32.1836
====> Epoch: 12 Average loss: 43.9962
====> Test set loss: 32.0745
====> Epoch: 13 Average loss: 43.3767
====> Test set loss: 31.9876
====> Epoch: 14 Average loss: 42.6998
====> Test set loss: 31.9468
====> Epoch: 15 Average loss: 42.4021
====> Test set loss: 31.8453
==

In [43]:
# Mean binary cross-entropy between the Gumbel model's first output and the
# test inputs, computed without tracking gradients.
with torch.no_grad():
    print("Test Loss")
    print(F.binary_cross_entropy(input=vae_gumbel_with_pre(test_data)[0], target=test_data))

Test Loss
tensor(0.5209, device='cuda:0')


Means

In [44]:
torch.mean(test_data, dim=0)[:D]  # per-feature means of the D real features

tensor([0.0978, 0.3031, 0.1394, 0.2099, 0.2063, 0.2098, 0.1076, 0.2316, 0.0298,
        0.1968, 0.0261, 0.0561, 0.0848, 0.2727, 0.2320, 0.3355, 0.2739, 0.0056,
        0.0301, 0.0953, 0.0889, 0.2483, 0.0063, 0.0231, 0.0009, 0.0924, 0.4782,
        0.3251, 0.2092, 0.2881], device='cuda:0')

In [45]:
torch.mean(vae_gumbel_with_pre(test_data)[0], dim=0)[:D]  # reconstruction means, real features

tensor([0.1026, 0.3113, 0.1555, 0.2103, 0.2154, 0.2139, 0.1029, 0.2296, 0.0346,
        0.2128, 0.0285, 0.0556, 0.1053, 0.2863, 0.2217, 0.3214, 0.2817, 0.0200,
        0.0359, 0.0998, 0.0931, 0.2494, 0.0175, 0.0319, 0.0153, 0.0941, 0.4700,
        0.3410, 0.2185, 0.3018], device='cuda:0', grad_fn=<SliceBackward>)

Deviations

In [77]:
torch.std(test_data, dim=0)[:D]  # per-feature std of the D real features

tensor([0.1913, 0.2338, 0.2076, 0.2199, 0.2306, 0.1890, 0.1805, 0.1945, 0.1051,
        0.2321, 0.0976, 0.1389, 0.1652, 0.1922, 0.2341, 0.2330, 0.2048, 0.0390,
        0.0929, 0.1692, 0.1755, 0.2236, 0.0499, 0.0923, 0.0152, 0.1863, 0.1593,
        0.2244, 0.2096, 0.2350], device='cuda:0')

In [79]:
torch.std(vae_gumbel_with_pre(test_data)[0], dim=0)[:D]  # reconstruction std, real features

tensor([0.1153, 0.0518, 0.1646, 0.1163, 0.0842, 0.0573, 0.1136, 0.1312, 0.0339,
        0.0973, 0.0276, 0.0546, 0.0667, 0.1537, 0.1882, 0.1878, 0.0449, 0.0208,
        0.0360, 0.0927, 0.0970, 0.1194, 0.0200, 0.0306, 0.0193, 0.1121, 0.0552,
        0.1921, 0.0659, 0.1561], device='cuda:0', grad_fn=<SliceBackward>)

Values

In [46]:
test_data[samp][:D]  # real-feature values of the chosen test sample

tensor([0.0000, 0.5286, 0.0000, 0.0231, 0.6963, 0.3966, 0.0103, 0.2939, 0.0000,
        0.3152, 0.0000, 0.0000, 0.0000, 0.0617, 0.4150, 0.6262, 0.2110, 0.0000,
        0.0000, 0.0000, 0.0000, 0.6829, 0.0000, 0.0000, 0.0000, 0.0000, 0.3516,
        0.0000, 0.6491, 0.4560], device='cuda:0')

In [47]:
vae_gumbel_with_pre(test_data)[0][samp][:D]  # reconstruction of the same sample, real features

tensor([0.0356, 0.3265, 0.0356, 0.2972, 0.3427, 0.2824, 0.1551, 0.3338, 0.0507,
        0.1601, 0.0465, 0.1108, 0.0659, 0.1482, 0.4148, 0.5075, 0.3063, 0.0286,
        0.0437, 0.0535, 0.1857, 0.3813, 0.0241, 0.0340, 0.0231, 0.2043, 0.4821,
        0.1731, 0.2410, 0.1717], device='cuda:0', grad_fn=<SliceBackward>)

In [48]:
with torch.no_grad():
    # Selection weights for the first two test samples, then a subset of size
    # 3 * z_size at temperature t = 0.1, moved to the CPU. sample_subset presumably
    # adds Gumbel noise, so repeated runs may differ — TODO confirm.
    w = vae_gumbel_with_pre.weight_creator(test_data[:2])
    subset_indices = sample_subset(w, k=3 * z_size, t=0.1).cpu()

In [49]:
# Top 3 * z_size feature indices per sample. Indices less than D (= 30) are real
# features; indices >= D are the appended noise columns.
subset_indices.argsort(dim=1, descending=True)[:, :3 * z_size]

(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 array([ 3, 15, 19, 21, 35, 47, 49, 50, 51, 56, 59,  1,  5,  6,  8, 14, 15,
        19, 30, 38, 49]))

**Joint Train**

In [50]:
# Joint training setup: a vanilla VAE and a Gumbel-subset VAE (k = 3 * z_size)
# share one Adam optimizer over the parameters of both models.
joint_vanilla_vae = VAE(2*D, 100, 20)
joint_vanilla_vae.to(device)

joint_vae_gumbel = VAE_Gumbel(2*D, 100, 20, k=3 * z_size)
joint_vae_gumbel.to(device)

joint_optimizer = torch.optim.Adam(
    [*joint_vanilla_vae.parameters(), *joint_vae_gumbel.parameters()],
    lr=lr,
    betas=(b1, b2),
)

In [51]:
# Each epoch: a joint training pass over both models, then joint evaluation.
for epoch in range(1, n_epochs + 1):
    train_joint(train_data, joint_vanilla_vae, joint_vae_gumbel,
                joint_optimizer, epoch, batch_size)
    test_joint(test_data, joint_vanilla_vae, joint_vae_gumbel,
               epoch, batch_size)

====> Epoch: 1 Average loss: 80.8498
====> Test set loss: 78.6082
====> Epoch: 2 Average loss: 75.8394
====> Test set loss: 75.3927
====> Epoch: 3 Average loss: 71.5442
====> Test set loss: 74.9460
====> Epoch: 4 Average loss: 68.7228
====> Test set loss: 74.6494
====> Epoch: 5 Average loss: 67.4864
====> Test set loss: 75.0293
====> Epoch: 6 Average loss: 66.8799
====> Test set loss: 72.3640
====> Epoch: 7 Average loss: 66.5198
====> Test set loss: 70.8693
====> Epoch: 8 Average loss: 66.3276
====> Test set loss: 70.2999
====> Epoch: 9 Average loss: 66.1376
====> Test set loss: 69.8311
====> Epoch: 10 Average loss: 66.0242
====> Test set loss: 69.7916
====> Epoch: 11 Average loss: 65.9360
====> Test set loss: 69.5051
====> Epoch: 12 Average loss: 65.8537
====> Test set loss: 69.1673
====> Epoch: 13 Average loss: 65.7782
====> Test set loss: 69.1441
====> Epoch: 14 Average loss: 65.6924
====> Test set loss: 69.1444
====> Epoch: 15 Average loss: 65.6371
====> Test set loss: 69.8364
====

In [52]:
# Mean binary cross-entropy between the jointly trained Gumbel model's first
# output and the test inputs, computed without tracking gradients.
with torch.no_grad():
    print("Test Loss")
    print(F.binary_cross_entropy(input=joint_vae_gumbel(test_data)[0], target=test_data))

Test Loss
tensor(0.5416, device='cuda:0')


Means

In [53]:
torch.mean(test_data, dim=0)[:D]  # per-feature means of the D real features

tensor([0.0978, 0.3031, 0.1394, 0.2099, 0.2063, 0.2098, 0.1076, 0.2316, 0.0298,
        0.1968, 0.0261, 0.0561, 0.0848, 0.2727, 0.2320, 0.3355, 0.2739, 0.0056,
        0.0301, 0.0953, 0.0889, 0.2483, 0.0063, 0.0231, 0.0009, 0.0924, 0.4782,
        0.3251, 0.2092, 0.2881], device='cuda:0')

In [54]:
torch.mean(joint_vae_gumbel(test_data)[0], dim=0)[:D]  # reconstruction means, real features

tensor([0.1045, 0.3075, 0.1511, 0.2136, 0.2214, 0.2187, 0.1068, 0.2332, 0.0491,
        0.2182, 0.0362, 0.0662, 0.1080, 0.2849, 0.2242, 0.3268, 0.2868, 0.0284,
        0.0435, 0.1078, 0.0941, 0.2522, 0.0245, 0.0395, 0.0227, 0.0989, 0.4723,
        0.3333, 0.2250, 0.3063], device='cuda:0', grad_fn=<SliceBackward>)

Deviations

In [82]:
torch.std(test_data, dim=0)[:D]  # per-feature std of the D real features

tensor([0.1913, 0.2338, 0.2076, 0.2199, 0.2306, 0.1890, 0.1805, 0.1945, 0.1051,
        0.2321, 0.0976, 0.1389, 0.1652, 0.1922, 0.2341, 0.2330, 0.2048, 0.0390,
        0.0929, 0.1692, 0.1755, 0.2236, 0.0499, 0.0923, 0.0152, 0.1863, 0.1593,
        0.2244, 0.2096, 0.2350], device='cuda:0')

In [81]:
torch.std(joint_vae_gumbel(test_data)[0], dim=0)[:D]  # reconstruction std, real features

tensor([0.0422, 0.0386, 0.0399, 0.0387, 0.0384, 0.0383, 0.0428, 0.0385, 0.0335,
        0.0405, 0.0257, 0.0377, 0.0396, 0.0373, 0.0401, 0.0324, 0.0356, 0.0294,
        0.0332, 0.0454, 0.0443, 0.0425, 0.0228, 0.0294, 0.0249, 0.0474, 0.0228,
        0.0286, 0.0389, 0.0340], device='cuda:0', grad_fn=<SliceBackward>)

Values

In [55]:
test_data[samp][:D]  # real-feature values of the chosen test sample

tensor([0.0000, 0.5286, 0.0000, 0.0231, 0.6963, 0.3966, 0.0103, 0.2939, 0.0000,
        0.3152, 0.0000, 0.0000, 0.0000, 0.0617, 0.4150, 0.6262, 0.2110, 0.0000,
        0.0000, 0.0000, 0.0000, 0.6829, 0.0000, 0.0000, 0.0000, 0.0000, 0.3516,
        0.0000, 0.6491, 0.4560], device='cuda:0')

In [56]:
joint_vae_gumbel(test_data)[0][samp][:D]  # reconstruction of the same sample, real features

tensor([0.0954, 0.3105, 0.1620, 0.2101, 0.2351, 0.2171, 0.0862, 0.2378, 0.0265,
        0.2463, 0.0175, 0.0578, 0.0978, 0.3219, 0.2149, 0.3163, 0.3104, 0.0110,
        0.0319, 0.0950, 0.0762, 0.2342, 0.0099, 0.0270, 0.0086, 0.0724, 0.4506,
        0.3724, 0.2340, 0.3332], device='cuda:0', grad_fn=<SliceBackward>)

In [57]:
with torch.no_grad():
    # Selection weights for the first two test samples, then a subset of size
    # 3 * z_size at temperature t = 0.1, moved to the CPU. sample_subset presumably
    # adds Gumbel noise, so repeated runs may differ — TODO confirm.
    w = joint_vae_gumbel.weight_creator(test_data[:2])
    subset_indices = sample_subset(w, k=3 * z_size, t=0.1).cpu()

In [58]:
# Top 3 * z_size feature indices per sample (indices >= D are noise columns).
subset_indices.argsort(dim=1, descending=True)[:, :3 * z_size]

tensor([[26, 39, 55, 50, 56, 13, 11, 37, 31, 30],
        [ 8, 25, 20, 52, 43, 48, 16, 15, 26, 50]])

Joint training while selecting exactly z_size features. Why does it pick the noise variables?

In [59]:
# Same joint setup as before, but the Gumbel VAE now selects exactly k = z_size
# features. Both models share one Adam optimizer.
joint_vanilla_vae = VAE(2*D, 100, 20)
joint_vanilla_vae.to(device)

joint_vae_gumbel = VAE_Gumbel(2*D, 100, 20, k=z_size)
joint_vae_gumbel.to(device)

joint_optimizer = torch.optim.Adam(
    [*joint_vanilla_vae.parameters(), *joint_vae_gumbel.parameters()],
    lr=lr,
    betas=(b1, b2),
)

In [60]:
# Each epoch: a joint training pass over both models, then joint evaluation.
for epoch in range(1, n_epochs + 1):
    train_joint(train_data, joint_vanilla_vae, joint_vae_gumbel,
                joint_optimizer, epoch, batch_size)
    test_joint(test_data, joint_vanilla_vae, joint_vae_gumbel,
               epoch, batch_size)

====> Epoch: 1 Average loss: 81.5972
====> Test set loss: 83.5908
====> Epoch: 2 Average loss: 76.2796
====> Test set loss: 75.9106
====> Epoch: 3 Average loss: 72.1026
====> Test set loss: 72.9315
====> Epoch: 4 Average loss: 68.9180
====> Test set loss: 72.0876
====> Epoch: 5 Average loss: 67.5563
====> Test set loss: 71.8436
====> Epoch: 6 Average loss: 66.9143
====> Test set loss: 70.9081
====> Epoch: 7 Average loss: 66.5269
====> Test set loss: 70.0472
====> Epoch: 8 Average loss: 66.3147
====> Test set loss: 70.5378
====> Epoch: 9 Average loss: 66.1673
====> Test set loss: 69.6297
====> Epoch: 10 Average loss: 66.0504
====> Test set loss: 69.0215
====> Epoch: 11 Average loss: 65.9466
====> Test set loss: 68.6673
====> Epoch: 12 Average loss: 65.8798
====> Test set loss: 68.7583
====> Epoch: 13 Average loss: 65.8044
====> Test set loss: 68.4594
====> Epoch: 14 Average loss: 65.7616
====> Test set loss: 68.4906
====> Epoch: 15 Average loss: 65.6961
====> Test set loss: 68.7406
====

In [61]:
# Mean binary cross-entropy between the (k = z_size) Gumbel model's first output
# and the test inputs, computed without tracking gradients.
with torch.no_grad():
    print("Test Loss")
    print(F.binary_cross_entropy(input=joint_vae_gumbel(test_data)[0], target=test_data))

Test Loss
tensor(0.5430, device='cuda:0')


Means

In [62]:
torch.mean(test_data, dim=0)[:D]  # per-feature means of the D real features

tensor([0.0978, 0.3031, 0.1394, 0.2099, 0.2063, 0.2098, 0.1076, 0.2316, 0.0298,
        0.1968, 0.0261, 0.0561, 0.0848, 0.2727, 0.2320, 0.3355, 0.2739, 0.0056,
        0.0301, 0.0953, 0.0889, 0.2483, 0.0063, 0.0231, 0.0009, 0.0924, 0.4782,
        0.3251, 0.2092, 0.2881], device='cuda:0')

In [69]:
torch.mean(joint_vae_gumbel(test_data)[0], dim=0)[:D]  # reconstruction means, real features

tensor([0.1043, 0.3059, 0.1469, 0.2165, 0.2183, 0.2168, 0.1126, 0.2374, 0.0493,
        0.2107, 0.0356, 0.0629, 0.1037, 0.2789, 0.2302, 0.3307, 0.2845, 0.0304,
        0.0463, 0.1113, 0.0959, 0.2545, 0.0247, 0.0398, 0.0259, 0.0993, 0.4844,
        0.3294, 0.2177, 0.2941], device='cuda:0', grad_fn=<SliceBackward>)

Deviations

In [83]:
torch.std(test_data, dim=0)[:D]  # per-feature std of the D real features

tensor([0.1913, 0.2338, 0.2076, 0.2199, 0.2306, 0.1890, 0.1805, 0.1945, 0.1051,
        0.2321, 0.0976, 0.1389, 0.1652, 0.1922, 0.2341, 0.2330, 0.2048, 0.0390,
        0.0929, 0.1692, 0.1755, 0.2236, 0.0499, 0.0923, 0.0152, 0.1863, 0.1593,
        0.2244, 0.2096, 0.2350], device='cuda:0')

In [84]:
torch.std(joint_vae_gumbel(test_data)[0], dim=0)[:D]  # reconstruction std, real features

tensor([0.0421, 0.0380, 0.0389, 0.0380, 0.0371, 0.0375, 0.0419, 0.0377, 0.0326,
        0.0397, 0.0250, 0.0369, 0.0385, 0.0363, 0.0385, 0.0311, 0.0349, 0.0291,
        0.0321, 0.0444, 0.0433, 0.0411, 0.0229, 0.0289, 0.0246, 0.0463, 0.0226,
        0.0292, 0.0382, 0.0326], device='cuda:0', grad_fn=<SliceBackward>)

Values

In [64]:
test_data[samp][:D]  # real-feature values of the chosen test sample

tensor([0.0000, 0.5286, 0.0000, 0.0231, 0.6963, 0.3966, 0.0103, 0.2939, 0.0000,
        0.3152, 0.0000, 0.0000, 0.0000, 0.0617, 0.4150, 0.6262, 0.2110, 0.0000,
        0.0000, 0.0000, 0.0000, 0.6829, 0.0000, 0.0000, 0.0000, 0.0000, 0.3516,
        0.0000, 0.6491, 0.4560], device='cuda:0')

In [65]:
joint_vae_gumbel(test_data)[0][samp][:D]  # reconstruction of the same sample, real features

tensor([0.0551, 0.2485, 0.1082, 0.1835, 0.1943, 0.1613, 0.0786, 0.1836, 0.0109,
        0.1750, 0.0067, 0.0196, 0.0446, 0.2755, 0.2051, 0.3167, 0.2483, 0.0037,
        0.0090, 0.0602, 0.0459, 0.2150, 0.0034, 0.0155, 0.0029, 0.0388, 0.4901,
        0.3103, 0.1910, 0.2466], device='cuda:0', grad_fn=<SliceBackward>)

In [66]:
with torch.no_grad():
    # Selection weights for the first ten test samples, then a subset of size
    # z_size at temperature t = 0.1, moved to the CPU. sample_subset presumably
    # adds Gumbel noise, so repeated runs may differ — TODO confirm.
    w = joint_vae_gumbel.weight_creator(test_data[:10])
    subset_indices = sample_subset(w, k=z_size, t=0.1).cpu()

In [67]:
# Top z_size feature indices per sample (indices >= D are noise columns).
subset_indices.argsort(dim=1, descending=True)[:, :z_size]

tensor([[52, 41, 36, 58, 10],
        [29, 35, 12,  4, 28]])

Matching the pretrained VAE actually works better here than joint training.
The Gumbel trick greatly reduces the ability to make predictions.
Notice that the reconstructed standard deviations are much lower than those of the original data. Not being able to use a model that sees the full data as an anchor definitely hurts.