From Notebook 10 we know that having the model learn a single set of logits per batch works well and selects the right features.

How can we make the logits consistent over batches?

We collect the gradients and the Gumbel-selected features separately in two modes, accumulating over every epoch in each: before burn-in (the first `n_epochs // 5 * 4` epochs) and after burn-in (the remaining epochs).

We follow Notebook 10 with the slight modifications described just above. Recall that in Notebook 10 we compared the behavior when all the features were real against the behavior when half of the features were noise.

In [1]:
import torch


from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F

import numpy as np

import matplotlib.pyplot as plt
#from sklearn.manifold import TSNE

#import math

#import gc

from utils import *

from sklearn.preprocessing import MinMaxScaler

from scipy.stats import pearsonr

import seaborn as sns

In [2]:
# Seed both the torch and the legacy numpy global RNGs so the notebook is reproducible.
SEED = 0
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(SEED)

In [3]:
# Training hyperparameters.
# n_epochs: total training epochs. The training loops below treat the first
#   n_epochs // 5 * 4 epochs as "before burn-in" and the rest as "post burn-in".
#   600 is long; lower it for quicker local iteration.
n_epochs = 600
batch_size = 64
# Adam optimizer settings (learning rate and the (beta1, beta2) moment decay rates).
lr = 0.0001
b1 = 0.9
b2 = 0.999

In [4]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
# is_available() already returns a bool, so no ternary is needed.
cuda = torch.cuda.is_available()

# Legacy tensor constructor matching the chosen device; used below to move numpy data onto it.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

device = torch.device("cuda:0" if cuda else "cpu")
print("Device")
print(device)

Device
cuda:0


In [5]:
# D: number of real features (an equal number of noise features is appended later),
# N: number of samples, z_size: dimensionality of the latent data.
D, N, z_size = 30, 10000, 8

In [6]:
# Re-seed so this cell regenerates exactly the same dataset when re-run on its own.
torch.manual_seed(0)
np.random.seed(0)

# Standard-normal latent samples, shape (N, z_size).
latent_data = np.random.normal(loc=0.0, scale=1.0, size=N * z_size).reshape(N, z_size)

# Fixed, randomly initialised generator z_size -> 2*z_size (tanh) -> D (ReLU).
# Its weights are frozen; it only maps latents to the D "real" features.
data_mapper = nn.Sequential(
    nn.Linear(z_size, 2 * z_size, bias=False),
    nn.Tanh(),
    nn.Linear(2 * z_size, D, bias=True),
    nn.ReLU(),
).to(device)
data_mapper.requires_grad_(False)

latent_data = Tensor(latent_data)
latent_data.requires_grad_(False)

actual_data = data_mapper(latent_data)

# Count the nonzero (ReLU-active) features in each of the first 5 samples.
for row in actual_data[:5]:
    print(torch.sum(row != 0))

tensor(14, device='cuda:0')
tensor(18, device='cuda:0')
tensor(14, device='cuda:0')
tensor(17, device='cuda:0')
tensor(16, device='cuda:0')


Add noise features: D extra columns of Gaussian noise (std 0.01) that carry no signal.

In [7]:
# N x D matrix of pure Gaussian noise (mean 0, std 0.01); these columns carry no signal
# and are appended to the real features as distractors.
noise_features = torch.empty(N * D).normal_(mean=0, std=0.01).view(N, D).to(device)
noise_features.requires_grad_(False)

tensor([[ 0.0013,  0.0135,  0.0054,  ..., -0.0047,  0.0033, -0.0097],
        [ 0.0080, -0.0057,  0.0010,  ...,  0.0009, -0.0134,  0.0105],
        [-0.0103, -0.0029,  0.0185,  ..., -0.0133, -0.0037,  0.0134],
        ...,
        [ 0.0073, -0.0149, -0.0108,  ..., -0.0047, -0.0137,  0.0070],
        [ 0.0006, -0.0141, -0.0124,  ..., -0.0085,  0.0069, -0.0110],
        [-0.0159,  0.0177, -0.0087,  ..., -0.0076, -0.0009,  0.0078]],
       device='cuda:0')

In [8]:
# Column layout after this: [0, D) real features, [D, 2*D) noise features.
actual_data = torch.cat((actual_data, noise_features), 1)

In [9]:
# Expect (N, 2*D): real plus noise columns.
actual_data.size()

torch.Size([10000, 60])

In [10]:
# Min-max scale every column (real and noise alike) to [0, 1], then move back onto the device.
# NOTE(review): the scaler is fit on all rows before the train/test split, so test rows
# influence the scaling; confirm this is acceptable for this experiment.
scaler = MinMaxScaler()
actual_data = Tensor(scaler.fit_transform(actual_data.cpu().numpy()))

# Random 80/20 row split into train and test sets.
slices = np.random.permutation(np.arange(actual_data.shape[0]))
upto = int(.8 * len(actual_data))
train_data, test_data = actual_data[slices[:upto]], actual_data[slices[upto:]]

In [11]:
# Initial value for the models' `t` attribute (presumably the Gumbel-softmax temperature —
# confirm in utils); the training loops anneal it by 0.99 per epoch with a floor of 0.001.
global_t = 4

Example of what worked before
Vanilla Gumbel

In [12]:
# Baseline model over all 2*D input columns (D real + D noise).
# NOTE(review): 100 and 20 are presumably the hidden and latent sizes, and k = 3*z_size the
# number of features to select — confirm against VAE_Gumbel in utils.
vae_gumbel_truncated = VAE_Gumbel(2 * D, 100, 20, k=3 * z_size, t=global_t)
vae_gumbel_truncated.to(device)
vae_gumbel_trunc_optimizer = torch.optim.Adam(
    vae_gumbel_truncated.parameters(),
    betas=(b1, b2),
    lr=lr,
)

In [13]:
# Per-input-column accumulators for the two phases, allocated directly on the device.
gradients_before_burnin = torch.zeros(train_data.shape[1], device=device)
gradient_post_burn_in = torch.zeros(train_data.shape[1], device=device)
subset_indices_before_burnin = torch.zeros(train_data.shape[1], device=device)
subset_indices_post_burnin = torch.zeros(train_data.shape[1], device=device)

# Epochs 1..burn_in_epochs count as "before burn-in"; later epochs as "post burn-in".
burn_in_epochs = n_epochs // 5 * 4

for epoch in range(1, n_epochs + 1):
    grads = train_truncated_with_gradients(train_data, vae_gumbel_truncated,
                                           vae_gumbel_trunc_optimizer,
                                           epoch,
                                           batch_size,
                                           Dim=2 * D)

    # Geometric annealing of t, floored at 0.001.
    vae_gumbel_truncated.t = max(0.001, vae_gumbel_truncated.t * 0.99)

    if epoch <= burn_in_epochs:
        gradients_before_burnin += grads
        with torch.no_grad():
            subset_indices_before_burnin += vae_gumbel_truncated.subset_indices.sum(dim=0)
    else:
        gradient_post_burn_in += grads
        with torch.no_grad():
            subset_indices_post_burnin += vae_gumbel_truncated.subset_indices.sum(dim=0)

====> Epoch: 1 Average loss: 40.2651
====> Epoch: 2 Average loss: 36.4149
====> Epoch: 3 Average loss: 33.8501
====> Epoch: 4 Average loss: 33.2810
====> Epoch: 5 Average loss: 33.0638
====> Epoch: 6 Average loss: 32.9366
====> Epoch: 7 Average loss: 32.8513
====> Epoch: 8 Average loss: 32.7571
====> Epoch: 9 Average loss: 32.6861
====> Epoch: 10 Average loss: 32.6218
====> Epoch: 11 Average loss: 32.5453
====> Epoch: 12 Average loss: 32.4585
====> Epoch: 13 Average loss: 32.3123
====> Epoch: 14 Average loss: 32.1066
====> Epoch: 15 Average loss: 31.9131
====> Epoch: 16 Average loss: 31.7394
====> Epoch: 17 Average loss: 31.5845
====> Epoch: 18 Average loss: 31.4851
====> Epoch: 19 Average loss: 31.4033
====> Epoch: 20 Average loss: 31.3518
====> Epoch: 21 Average loss: 31.3093


====> Epoch: 22 Average loss: 31.2534
====> Epoch: 23 Average loss: 31.1998
====> Epoch: 24 Average loss: 31.1411
====> Epoch: 25 Average loss: 31.0923
====> Epoch: 26 Average loss: 31.0160
====> Epoch: 27 Average loss: 30.9540
====> Epoch: 28 Average loss: 30.8884
====> Epoch: 29 Average loss: 30.8450
====> Epoch: 30 Average loss: 30.7908
====> Epoch: 31 Average loss: 30.7684
====> Epoch: 32 Average loss: 30.7387
====> Epoch: 33 Average loss: 30.7075
====> Epoch: 34 Average loss: 30.6920
====> Epoch: 35 Average loss: 30.6579
====> Epoch: 36 Average loss: 30.6577
====> Epoch: 37 Average loss: 30.6316
====> Epoch: 38 Average loss: 30.6118
====> Epoch: 39 Average loss: 30.5835
====> Epoch: 40 Average loss: 30.5747
====> Epoch: 41 Average loss: 30.5585
====> Epoch: 42 Average loss: 30.5314
====> Epoch: 43 Average loss: 30.5203


====> Epoch: 44 Average loss: 30.5022
====> Epoch: 45 Average loss: 30.4841
====> Epoch: 46 Average loss: 30.4663
====> Epoch: 47 Average loss: 30.4575
====> Epoch: 48 Average loss: 30.4404
====> Epoch: 49 Average loss: 30.4250
====> Epoch: 50 Average loss: 30.4042
====> Epoch: 51 Average loss: 30.3971
====> Epoch: 52 Average loss: 30.3863
====> Epoch: 53 Average loss: 30.3664
====> Epoch: 54 Average loss: 30.3530
====> Epoch: 55 Average loss: 30.3444
====> Epoch: 56 Average loss: 30.3375
====> Epoch: 57 Average loss: 30.3284
====> Epoch: 58 Average loss: 30.3196
====> Epoch: 59 Average loss: 30.3147
====> Epoch: 60 Average loss: 30.3014
====> Epoch: 61 Average loss: 30.2949
====> Epoch: 62 Average loss: 30.2913
====> Epoch: 63 Average loss: 30.2834
====> Epoch: 64 Average loss: 30.2757


====> Epoch: 65 Average loss: 30.2654
====> Epoch: 66 Average loss: 30.2637
====> Epoch: 67 Average loss: 30.2571
====> Epoch: 68 Average loss: 30.2518
====> Epoch: 69 Average loss: 30.2418
====> Epoch: 70 Average loss: 30.2352
====> Epoch: 71 Average loss: 30.2283
====> Epoch: 72 Average loss: 30.2183
====> Epoch: 73 Average loss: 30.2116
====> Epoch: 74 Average loss: 30.1944
====> Epoch: 75 Average loss: 30.1899
====> Epoch: 76 Average loss: 30.1843
====> Epoch: 77 Average loss: 30.1684
====> Epoch: 78 Average loss: 30.1633
====> Epoch: 79 Average loss: 30.1601
====> Epoch: 80 Average loss: 30.1569
====> Epoch: 81 Average loss: 30.1455
====> Epoch: 82 Average loss: 30.1371
====> Epoch: 83 Average loss: 30.1439
====> Epoch: 84 Average loss: 30.1341
====> Epoch: 85 Average loss: 30.1279
====> Epoch: 86 Average loss: 30.1200


====> Epoch: 87 Average loss: 30.1181
====> Epoch: 88 Average loss: 30.1153
====> Epoch: 89 Average loss: 30.1075
====> Epoch: 90 Average loss: 30.1034
====> Epoch: 91 Average loss: 30.0930
====> Epoch: 92 Average loss: 30.0952
====> Epoch: 93 Average loss: 30.0980
====> Epoch: 94 Average loss: 30.0796
====> Epoch: 95 Average loss: 30.0802
====> Epoch: 96 Average loss: 30.0688
====> Epoch: 97 Average loss: 30.0652
====> Epoch: 98 Average loss: 30.0658
====> Epoch: 99 Average loss: 30.0579
====> Epoch: 100 Average loss: 30.0537
====> Epoch: 101 Average loss: 30.0427
====> Epoch: 102 Average loss: 30.0443
====> Epoch: 103 Average loss: 30.0378
====> Epoch: 104 Average loss: 30.0301
====> Epoch: 105 Average loss: 30.0257
====> Epoch: 106 Average loss: 30.0231
====> Epoch: 107 Average loss: 30.0177


====> Epoch: 108 Average loss: 30.0199
====> Epoch: 109 Average loss: 30.0074
====> Epoch: 110 Average loss: 30.0055
====> Epoch: 111 Average loss: 30.0071
====> Epoch: 112 Average loss: 29.9978
====> Epoch: 113 Average loss: 30.0045
====> Epoch: 114 Average loss: 30.0021
====> Epoch: 115 Average loss: 29.9932
====> Epoch: 116 Average loss: 29.9942
====> Epoch: 117 Average loss: 29.9924
====> Epoch: 118 Average loss: 29.9892
====> Epoch: 119 Average loss: 29.9854
====> Epoch: 120 Average loss: 29.9870
====> Epoch: 121 Average loss: 29.9905
====> Epoch: 122 Average loss: 29.9824
====> Epoch: 123 Average loss: 29.9842
====> Epoch: 124 Average loss: 29.9793
====> Epoch: 125 Average loss: 29.9784
====> Epoch: 126 Average loss: 29.9763
====> Epoch: 127 Average loss: 29.9796
====> Epoch: 128 Average loss: 29.9794


====> Epoch: 129 Average loss: 29.9723
====> Epoch: 130 Average loss: 29.9691
====> Epoch: 131 Average loss: 29.9707
====> Epoch: 132 Average loss: 29.9651
====> Epoch: 133 Average loss: 29.9616
====> Epoch: 134 Average loss: 29.9645
====> Epoch: 135 Average loss: 29.9590
====> Epoch: 136 Average loss: 29.9606
====> Epoch: 137 Average loss: 29.9561
====> Epoch: 138 Average loss: 29.9548
====> Epoch: 139 Average loss: 29.9564
====> Epoch: 140 Average loss: 29.9572
====> Epoch: 141 Average loss: 29.9559
====> Epoch: 142 Average loss: 29.9542
====> Epoch: 143 Average loss: 29.9503
====> Epoch: 144 Average loss: 29.9454
====> Epoch: 145 Average loss: 29.9502
====> Epoch: 146 Average loss: 29.9500
====> Epoch: 147 Average loss: 29.9389
====> Epoch: 148 Average loss: 29.9412
====> Epoch: 149 Average loss: 29.9349


====> Epoch: 150 Average loss: 29.9383
====> Epoch: 151 Average loss: 29.9426
====> Epoch: 152 Average loss: 29.9381
====> Epoch: 153 Average loss: 29.9365
====> Epoch: 154 Average loss: 29.9476
====> Epoch: 155 Average loss: 29.9429
====> Epoch: 156 Average loss: 29.9336
====> Epoch: 157 Average loss: 29.9321
====> Epoch: 158 Average loss: 29.9348
====> Epoch: 159 Average loss: 29.9296
====> Epoch: 160 Average loss: 29.9322
====> Epoch: 161 Average loss: 29.9336
====> Epoch: 162 Average loss: 29.9313
====> Epoch: 163 Average loss: 29.9239
====> Epoch: 164 Average loss: 29.9307
====> Epoch: 165 Average loss: 29.9240
====> Epoch: 166 Average loss: 29.9253
====> Epoch: 167 Average loss: 29.9242
====> Epoch: 168 Average loss: 29.9268
====> Epoch: 169 Average loss: 29.9207
====> Epoch: 170 Average loss: 29.9188


====> Epoch: 171 Average loss: 29.9267
====> Epoch: 172 Average loss: 29.9230
====> Epoch: 173 Average loss: 29.9220
====> Epoch: 174 Average loss: 29.9230
====> Epoch: 175 Average loss: 29.9173
====> Epoch: 176 Average loss: 29.9172
====> Epoch: 177 Average loss: 29.9208
====> Epoch: 178 Average loss: 29.9151
====> Epoch: 179 Average loss: 29.9158
====> Epoch: 180 Average loss: 29.9137
====> Epoch: 181 Average loss: 29.9146
====> Epoch: 182 Average loss: 29.9150
====> Epoch: 183 Average loss: 29.9101
====> Epoch: 184 Average loss: 29.9106
====> Epoch: 185 Average loss: 29.9094
====> Epoch: 186 Average loss: 29.9090
====> Epoch: 187 Average loss: 29.9178
====> Epoch: 188 Average loss: 29.9133
====> Epoch: 189 Average loss: 29.9083
====> Epoch: 190 Average loss: 29.9103
====> Epoch: 191 Average loss: 29.9082


====> Epoch: 192 Average loss: 29.9089
====> Epoch: 193 Average loss: 29.9097
====> Epoch: 194 Average loss: 29.9047
====> Epoch: 195 Average loss: 29.9094
====> Epoch: 196 Average loss: 29.9056
====> Epoch: 197 Average loss: 29.9052
====> Epoch: 198 Average loss: 29.9051
====> Epoch: 199 Average loss: 29.8995
====> Epoch: 200 Average loss: 29.9051
====> Epoch: 201 Average loss: 29.9035
====> Epoch: 202 Average loss: 29.9012
====> Epoch: 203 Average loss: 29.9051
====> Epoch: 204 Average loss: 29.8985
====> Epoch: 205 Average loss: 29.9003
====> Epoch: 206 Average loss: 29.9015
====> Epoch: 207 Average loss: 29.9029
====> Epoch: 208 Average loss: 29.9030
====> Epoch: 209 Average loss: 29.8939
====> Epoch: 210 Average loss: 29.9005
====> Epoch: 211 Average loss: 29.8990
====> Epoch: 212 Average loss: 29.8978


====> Epoch: 213 Average loss: 29.8991
====> Epoch: 214 Average loss: 29.9021
====> Epoch: 215 Average loss: 29.8905
====> Epoch: 216 Average loss: 29.8975
====> Epoch: 217 Average loss: 29.9000
====> Epoch: 218 Average loss: 29.8912
====> Epoch: 219 Average loss: 29.8905
====> Epoch: 220 Average loss: 29.8936
====> Epoch: 221 Average loss: 29.8924
====> Epoch: 222 Average loss: 29.8933
====> Epoch: 223 Average loss: 29.8942
====> Epoch: 224 Average loss: 29.8936
====> Epoch: 225 Average loss: 29.8911
====> Epoch: 226 Average loss: 29.8963
====> Epoch: 227 Average loss: 29.8905
====> Epoch: 228 Average loss: 29.8873
====> Epoch: 229 Average loss: 29.8904
====> Epoch: 230 Average loss: 29.8891
====> Epoch: 231 Average loss: 29.8897
====> Epoch: 232 Average loss: 29.8890
====> Epoch: 233 Average loss: 29.8892


====> Epoch: 234 Average loss: 29.8886
====> Epoch: 235 Average loss: 29.8867
====> Epoch: 236 Average loss: 29.8956
====> Epoch: 237 Average loss: 29.8872
====> Epoch: 238 Average loss: 29.8886


KeyboardInterrupt: 

In [None]:
# Accumulated pre-burn-in values over the real columns (first D) vs the noise columns (last D).
print(gradients_before_burnin[:D].sum(), gradients_before_burnin[D:].sum(), sep="\n")
sns.heatmap(gradients_before_burnin.detach().clone().cpu().numpy().reshape(-1, 1))


In [None]:
# Accumulated post-burn-in values over the real columns (first D) vs the noise columns (last D).
print(gradient_post_burn_in[:D].sum(), gradient_post_burn_in[D:].sum(), sep="\n")
sns.heatmap(gradient_post_burn_in.detach().clone().cpu().numpy().reshape(-1, 1))

In [None]:
# Pre-burn-in selection mass on real (first D) vs noise (last D) columns.
print(subset_indices_before_burnin[:D].sum(), subset_indices_before_burnin[D:].sum(), sep="\n")
sns.heatmap(subset_indices_before_burnin.detach().clone().cpu().numpy().reshape(-1, 1))

In [None]:
# Post-burn-in selection mass on real (first D) vs noise (last D) columns.
print(subset_indices_post_burnin[:D].sum(), subset_indices_post_burnin[D:].sum(), sep="\n")
sns.heatmap(subset_indices_post_burnin.detach().clone().cpu().numpy().reshape(-1, 1))

In [None]:
# Positive means the real columns received more post-burn-in selection mass than the noise columns.
subset_indices_post_burnin[:D].sum().sub(subset_indices_post_burnin[D:].sum())

**VAE_Gumbel_NInsta test here**

In [None]:
# Same configuration as the baseline, but using the VAE_Gumbel_NInsta variant from utils.
vae_gumbel_truncated = VAE_Gumbel_NInsta(2 * D, 100, 20, k=3 * z_size, t=global_t)
vae_gumbel_truncated.to(device)
vae_gumbel_trunc_optimizer = torch.optim.Adam(
    vae_gumbel_truncated.parameters(),
    betas=(b1, b2),
    lr=lr,
)



In [None]:
# Per-input-column accumulators for the two phases, allocated directly on the device.
gradients_before_burnin = torch.zeros(train_data.shape[1], device=device)
gradient_post_burn_in = torch.zeros(train_data.shape[1], device=device)
subset_indices_before_burnin = torch.zeros(train_data.shape[1], device=device)
subset_indices_post_burnin = torch.zeros(train_data.shape[1], device=device)

# Epochs 1..burn_in_epochs count as "before burn-in"; later epochs as "post burn-in".
burn_in_epochs = n_epochs // 5 * 4

for epoch in range(1, n_epochs + 1):
    grads = train_truncated_with_gradients(train_data, vae_gumbel_truncated,
                                           vae_gumbel_trunc_optimizer,
                                           epoch,
                                           batch_size,
                                           Dim=2 * D)

    # Geometric annealing of t, floored at 0.001.
    vae_gumbel_truncated.t = max(0.001, vae_gumbel_truncated.t * 0.99)

    if epoch <= burn_in_epochs:
        gradients_before_burnin += grads
        with torch.no_grad():
            subset_indices_before_burnin += vae_gumbel_truncated.subset_indices.sum(dim=0)
    else:
        gradient_post_burn_in += grads
        with torch.no_grad():
            subset_indices_post_burnin += vae_gumbel_truncated.subset_indices.sum(dim=0)

In [None]:
# Accumulated pre-burn-in values over the real columns (first D) vs the noise columns (last D).
print(gradients_before_burnin[:D].sum(), gradients_before_burnin[D:].sum(), sep="\n")
sns.heatmap(gradients_before_burnin.detach().clone().cpu().numpy().reshape(-1, 1))


In [None]:
# Accumulated post-burn-in values over the real columns (first D) vs the noise columns (last D).
print(gradient_post_burn_in[:D].sum(), gradient_post_burn_in[D:].sum(), sep="\n")
sns.heatmap(gradient_post_burn_in.detach().clone().cpu().numpy().reshape(-1, 1))

In [None]:
# Pre-burn-in selection mass on real (first D) vs noise (last D) columns.
print(subset_indices_before_burnin[:D].sum(), subset_indices_before_burnin[D:].sum(), sep="\n")
sns.heatmap(subset_indices_before_burnin.detach().clone().cpu().numpy().reshape(-1, 1))

In [None]:
# Post-burn-in selection mass on real (first D) vs noise (last D) columns.
print(subset_indices_post_burnin[:D].sum(), subset_indices_post_burnin[D:].sum(), sep="\n")
sns.heatmap(subset_indices_post_burnin.detach().clone().cpu().numpy().reshape(-1, 1))

In [None]:
# Positive means the real columns received more post-burn-in selection mass than the noise columns.
subset_indices_post_burnin[:D].sum().sub(subset_indices_post_burnin[D:].sum())

The new model: `VAE_Gumbel_NInstaState`, which is switched into its post-burn-in mode with `set_burned_in()`.


In [None]:
# Same configuration again, now with the stateful VAE_Gumbel_NInstaState variant from utils.
vae_gumbel_truncated = VAE_Gumbel_NInstaState(2 * D, 100, 20, k=3 * z_size, t=global_t)
vae_gumbel_truncated.to(device)
vae_gumbel_trunc_optimizer = torch.optim.Adam(
    vae_gumbel_truncated.parameters(),
    betas=(b1, b2),
    lr=lr,
)



In [None]:
# Per-input-column accumulators for the two phases, allocated directly on the device.
gradients_before_burnin = torch.zeros(train_data.shape[1], device=device)
gradient_post_burn_in = torch.zeros(train_data.shape[1], device=device)
subset_indices_before_burnin = torch.zeros(train_data.shape[1], device=device)
subset_indices_post_burnin = torch.zeros(train_data.shape[1], device=device)

# Epochs 1..burn_in_epochs count as "before burn-in"; later epochs as "post burn-in".
burn_in_epochs = n_epochs // 5 * 4

for epoch in range(1, n_epochs + 1):
    grads = train_truncated_with_gradients_gumbel_state(train_data, vae_gumbel_truncated,
                                                        vae_gumbel_trunc_optimizer,
                                                        epoch,
                                                        batch_size,
                                                        Dim=2 * D)

    # Geometric annealing of t, floored at 0.001. The sample_subset draws below
    # therefore use the already-annealed t for this epoch.
    vae_gumbel_truncated.t = max(0.001, vae_gumbel_truncated.t * 0.99)

    if epoch <= burn_in_epochs:
        gradients_before_burnin += grads
        with torch.no_grad():
            subset_indices_before_burnin += sample_subset(vae_gumbel_truncated.logit_enc,
                                                          vae_gumbel_truncated.k,
                                                          vae_gumbel_truncated.t).view(-1)
        if epoch == burn_in_epochs:
            # Last burn-in epoch: switch the model into its post-burn-in mode.
            print("BURN IN DEBUG")
            vae_gumbel_truncated.set_burned_in()
            print("Going post burn in")
    else:
        gradient_post_burn_in += grads
        with torch.no_grad():
            subset_indices_post_burnin += sample_subset(vae_gumbel_truncated.logit_enc,
                                                        vae_gumbel_truncated.k,
                                                        vae_gumbel_truncated.t).view(-1)

In [None]:
# Mean pre-burn-in value per real column (first D) vs per noise column (last D).
print(gradients_before_burnin[:D].mean(), gradients_before_burnin[D:].mean(), sep="\n")
sns.heatmap(gradients_before_burnin.detach().clone().cpu().numpy().reshape(-1, 1))


In [None]:
# Mean post-burn-in value per real column (first D) vs per noise column (last D).
print(gradient_post_burn_in[:D].mean(), gradient_post_burn_in[D:].mean(), sep="\n")
sns.heatmap(gradient_post_burn_in.detach().clone().cpu().numpy().reshape(-1, 1))

In [None]:
# Pre-burn-in selection mass on real (first D) vs noise (last D) columns.
print(subset_indices_before_burnin[:D].sum(), subset_indices_before_burnin[D:].sum(), sep="\n")
sns.heatmap(subset_indices_before_burnin.detach().clone().cpu().numpy().reshape(-1, 1))

In [None]:
# Post-burn-in selection mass on real (first D) vs noise (last D) columns.
print(subset_indices_post_burnin[:D].sum(), subset_indices_post_burnin[D:].sum(), sep="\n")
sns.heatmap(subset_indices_post_burnin.detach().clone().cpu().numpy().reshape(-1, 1))

In [None]:
# Number of real columns that received any post-burn-in selection mass.
subset_indices_post_burnin[:D].gt(0).sum()

In [None]:
# Number of noise columns that received any post-burn-in selection mass.
subset_indices_post_burnin[D:].gt(0).sum()

In [None]:
# Post-burn-in selection mass for each real column.
subset_indices_post_burnin[0:D]

In [None]:
# Post-burn-in selection mass for each noise column.
subset_indices_post_burnin[D:]

In [None]:
# Positive means the real columns received more post-burn-in selection mass than the noise columns.
subset_indices_post_burnin[:D].sum().sub(subset_indices_post_burnin[D:].sum())

In [None]:
# Rank every input column by one sample_subset draw at the current t and keep the top k indices.
# NOTE(review): sample_subset presumably injects Gumbel noise, so this ranking may vary
# between runs of the cell — confirm in utils.
top_ind = sample_subset(vae_gumbel_truncated.logit_enc,
                        vae_gumbel_truncated.k,
                        vae_gumbel_truncated.t).view(-1).argsort(descending=True)[:vae_gumbel_truncated.k]

In [None]:
# How many of the top-k indices fall in the real-feature columns [0, D)?
# Uses D instead of a hard-coded 30 so the check follows the data configuration, and a
# vectorised tensor reduction instead of Python's element-by-element builtin sum.
(top_ind < D).sum()

In [None]:
# How many of the top-k indices fall in the noise-feature columns [D, 2*D)?
# Uses D instead of a hard-coded 30 so the check follows the data configuration, and a
# vectorised tensor reduction instead of Python's element-by-element builtin sum.
(top_ind >= D).sum()