In [1]:
import torch


from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F

import numpy as np

import matplotlib.pyplot as plt
#from sklearn.manifold import TSNE

#import math

#import gc

from utils import *

from sklearn.preprocessing import MinMaxScaler

from scipy.stats import pearsonr

import seaborn as sns
import os
import scipy
import scipy.io

In [2]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
# `Tensor` is the matching float-tensor constructor used later to move the
# numpy train/test splits onto the chosen device.
cuda = torch.cuda.is_available()

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

device = torch.device("cuda:0" if cuda else "cpu")
print("Device")
print(device)

Device
cuda:0


In [3]:
# Seed numpy and torch so that the train/test permutation below, and the
# stochastic model initialisation / sampling in later cells, are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_DIR = "../data/zeisel"

# Load the data matrix and transpose it: rows (N) are the samples that get
# split into train/test below, columns (d) are the candidate features.
data_mat = scipy.io.loadmat(f"{DATA_DIR}/zeisel_data.mat")
data = data_mat['zeisel_data'].T
N, d = data.shape

# Labels for the first level of the hierarchy.
labels1_mat = scipy.io.loadmat(f"{DATA_DIR}/zeisel_labels1.mat")
l_aux = labels1_mat['zeisel_labels1']
l_0 = [l_aux[i][0] for i in range(l_aux.shape[0])]
# Labels for the second level of the hierarchy.
labels2_mat = scipy.io.loadmat(f"{DATA_DIR}/zeisel_labels2.mat")
l_aux = labels2_mat['zeisel_labels2']
l_1 = [l_aux[i][0] for i in range(l_aux.shape[0])]
# Stack both levels: labels[0] is level one, labels[1] is level two.
labels = np.array([l_0, l_1])

# Names indexed by row (range(N)); each row entry holds two names, kept
# separately in names0 / names1.
names_mat = scipy.io.loadmat(f"{DATA_DIR}/zeisel_names.mat")
names0 = [names_mat['zeisel_names'][i][0][0] for i in range(N)]
names1 = [names_mat['zeisel_names'][i][1][0] for i in range(N)]

# Random 80/20 split over rows (deterministic thanks to SEED above).
slices = np.random.permutation(np.arange(data.shape[0]))
upto = int(.8 * len(data))

train_data = data[slices[:upto]]
test_data = data[slices[upto:]]

# Fit the min-max scaler on the training split only, so no test-set
# statistics leak into preprocessing; the test split reuses that fit.
scaler = MinMaxScaler()
train_data = scaler.fit_transform(train_data)
test_data = scaler.transform(test_data)

# Convert both splits to float tensors on the selected device.
train_data = Tensor(train_data).to(device)
test_data = Tensor(test_data).to(device)

In [4]:
# NOTE(review): `N` and `z_size` are not referenced by any later cell in this
# notebook (the model constructors below pass their sizes as literals, e.g.
# 50). Kept in case something imported from utils relies on them — TODO
# confirm and remove.
N = 10000
z_size = 100

# Training schedule; the logged runs below were produced with 600 epochs.
n_epochs = 600
batch_size = 64
# Adam optimizer settings: learning rate and (beta1, beta2).
lr = 0.0001
b1 = 0.9
b2 = 0.999

# Initial temperature given to each model (`t=`); train_model multiplies
# model.t by 0.99 after every epoch.
global_t = 4
# Number of features each model selects (`k=`).
k = 50

In [5]:
def train_model(train_data, model, anneal_rate=0.99, min_temperature=0.001):
    """Train `model` on `train_data` with Adam, annealing its temperature.

    Runs `n_epochs` epochs through the `train` helper (imported from utils),
    using the module-level `lr`, `b1`, `b2` and `batch_size` settings. After
    every epoch the model's temperature attribute `model.t` is multiplied by
    `anneal_rate` and clipped from below at `min_temperature`.

    Args:
        train_data: training tensor, passed through unchanged to `train`.
        model: module exposing `.parameters()` and a mutable `.t` attribute.
        anneal_rate: per-epoch multiplicative decay of `model.t`
            (default 0.99, the value previously hard-coded here).
        min_temperature: lower bound for `model.t` (default 0.001, the value
            previously hard-coded here).

    Returns:
        The same `model` object, trained in place.
    """
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr,
                                 betas=(b1, b2))

    for epoch in range(1, n_epochs + 1):
        train(train_data,
              model,
              optimizer,
              epoch,
              batch_size)
        # Anneal the temperature geometrically, never letting it reach zero.
        model.t = max(min_temperature, model.t * anneal_rate)

    return model

def save_model(base_path, model):
    """Save `model.state_dict()` to `base_path`, creating parent directories.

    Args:
        base_path: destination file path. Missing parent directories are
            created; a bare filename (no directory part) saves into the
            current working directory.
        model: object with a `state_dict()` method (e.g. a torch module).

    Raises:
        Exception: "COULD NOT MAKE PATH" if the parent directory cannot be
            created (the original OSError is chained as the cause).
    """
    directory = os.path.dirname(base_path)
    # os.makedirs('') raises, so only create a directory when there is one.
    # exist_ok=True also covers the race where another process creates it
    # between our check and the call (the old code referenced an unimported
    # `errno` there and crashed with NameError).
    if directory:
        try:
            os.makedirs(directory, exist_ok=True)
        except OSError as exc:
            raise Exception("COULD NOT MAKE PATH") from exc
    with open(base_path, 'wb') as handle:
        torch.save(model.state_dict(), handle)

In [6]:
def top_logits_gumbel_runningstate_vae(data, model):
    """Return the features selected by a trained VAE_Gumbel_RunningState.

    Args:
        data: 2-D tensor; only its column count (`data.shape[1]`) is used, as
            the length of the returned indicator vector.
        model: a VAE_Gumbel_RunningState with `logit_enc`, `k` and `t`.

    Returns:
        (indicator, sampled):
            indicator -- vector of length data.shape[1] with a 1 at each of
                the `model.k` largest entries of the flattened `logit_enc`.
                Assumes logit_enc has no more than data.shape[1] elements —
                TODO confirm.
            sampled -- whatever utils.sample_subset returns for those logits
                at temperature `model.t` (presumably a relaxed k-hot sample).
    """
    assert isinstance(model, VAE_Gumbel_RunningState)
    n_columns = data.shape[1]
    with torch.no_grad():
        flat_logits = model.logit_enc.clone().view(-1)
        _, chosen = torch.topk(flat_logits, k=model.k, sorted=True)
        indicator = torch.nn.functional.one_hot(chosen, num_classes=n_columns).sum(dim=0)
        sampled = sample_subset(flat_logits, model.k, model.t)
    return indicator, sampled

def top_logits_gumbel_concrete_vae_nsml(data, model):
    """Return the features selected by a trained ConcreteVAE_NMSL.

    Perturbs `model.logit_enc` with `gumbel_keys` (from utils; presumably
    adds Gumbel noise, so the result is stochastic), applies a softmax over
    the last dimension at temperature `model.t`, and sums the resulting
    probabilities over dim 0 to get one score per feature.

    Args:
        data: 2-D tensor; only its column count (`data.shape[1]`) is used, as
            the length of the returned indicator vector.
        model: a ConcreteVAE_NMSL with `logit_enc`, `k` and `t`.

    Returns:
        (all_logits, all_subsets):
            all_logits -- vector of length data.shape[1] with a 1 at each of
                the `model.k` highest-scoring features.
            all_subsets -- the summed per-feature scores themselves.
    """
    assert isinstance(model, ConcreteVAE_NMSL)

    with torch.no_grad():
        noisy_logits = gumbel_keys(model.logit_enc, EPSILON=torch.finfo(torch.float32).eps)
        probs = torch.softmax(noisy_logits / model.t, dim=-1)

        # Reduce once and reuse for both the returned scores and the ranking.
        all_subsets = probs.sum(dim=0)
        inds = torch.argsort(all_subsets, descending=True)[:model.k]
        all_logits = torch.nn.functional.one_hot(inds, num_classes=data.shape[1]).sum(dim=0)

    return all_logits, all_subsets

In [7]:
# Build the running-state Gumbel VAE. First positional arg is the input width
# (number of columns of train_data); 200 and 50 are passed straight through
# (presumably hidden and latent sizes — TODO confirm against utils).
model = VAE_Gumbel_RunningState(
    train_data.shape[1], 200, 50,
    k=k, t=global_t, alpha=0.9,
)
model.to(device)
# train_model trains in place and hands back the same object.
model = train_model(train_data, model)
model.set_burned_in()

====> Epoch: 1 Average loss: 2719.7863
====> Epoch: 2 Average loss: 2562.4391
====> Epoch: 3 Average loss: 2402.9005
====> Epoch: 4 Average loss: 2254.6724
====> Epoch: 5 Average loss: 2180.2413
====> Epoch: 6 Average loss: 2136.3696
====> Epoch: 7 Average loss: 2119.3101
====> Epoch: 8 Average loss: 2116.3955
====> Epoch: 9 Average loss: 2109.4229
====> Epoch: 10 Average loss: 2104.4947
====> Epoch: 11 Average loss: 2099.3665
====> Epoch: 12 Average loss: 2098.0626
====> Epoch: 13 Average loss: 2093.2165
====> Epoch: 14 Average loss: 2089.3273
====> Epoch: 15 Average loss: 2086.5812
====> Epoch: 16 Average loss: 2082.6283
====> Epoch: 17 Average loss: 2075.4121
====> Epoch: 18 Average loss: 2076.1247
====> Epoch: 19 Average loss: 2072.3223
====> Epoch: 20 Average loss: 2068.0519
====> Epoch: 21 Average loss: 2065.8067
====> Epoch: 22 Average loss: 2061.2673
====> Epoch: 23 Average loss: 2061.3985
====> Epoch: 24 Average loss: 2059.9693
====> Epoch: 25 Average loss: 2058.2188
====> Epo

====> Epoch: 59 Average loss: 2024.0044
====> Epoch: 60 Average loss: 2023.7686
====> Epoch: 61 Average loss: 2022.3812
====> Epoch: 62 Average loss: 2022.2131
====> Epoch: 63 Average loss: 2020.9206
====> Epoch: 64 Average loss: 2018.8638
====> Epoch: 65 Average loss: 2021.4104
====> Epoch: 66 Average loss: 2021.8113
====> Epoch: 67 Average loss: 2019.4620
====> Epoch: 68 Average loss: 2019.7700
====> Epoch: 69 Average loss: 2019.7934
====> Epoch: 70 Average loss: 2017.5489
====> Epoch: 71 Average loss: 2016.7070
====> Epoch: 72 Average loss: 2018.3007
====> Epoch: 73 Average loss: 2017.1173
====> Epoch: 74 Average loss: 2018.7274
====> Epoch: 75 Average loss: 2019.8160
====> Epoch: 76 Average loss: 2019.1308
====> Epoch: 77 Average loss: 2019.9339
====> Epoch: 78 Average loss: 2018.4822
====> Epoch: 79 Average loss: 2016.3845
====> Epoch: 80 Average loss: 2015.3701
====> Epoch: 81 Average loss: 2015.0001
====> Epoch: 82 Average loss: 2016.5938
====> Epoch: 83 Average loss: 2014.2559


====> Epoch: 118 Average loss: 2024.6919
====> Epoch: 119 Average loss: 2030.8778
====> Epoch: 120 Average loss: 2023.0372
====> Epoch: 121 Average loss: 2022.1427
====> Epoch: 122 Average loss: 2028.8640
====> Epoch: 123 Average loss: 2033.1398
====> Epoch: 124 Average loss: 2025.4590
====> Epoch: 125 Average loss: 2026.7152
====> Epoch: 126 Average loss: 2026.5339
====> Epoch: 127 Average loss: 2028.5624
====> Epoch: 128 Average loss: 2028.9531
====> Epoch: 129 Average loss: 2029.7511
====> Epoch: 130 Average loss: 2023.6737
====> Epoch: 131 Average loss: 2022.6096
====> Epoch: 132 Average loss: 2020.3704
====> Epoch: 133 Average loss: 2023.5657
====> Epoch: 134 Average loss: 2019.7366
====> Epoch: 135 Average loss: 2031.9446
====> Epoch: 136 Average loss: 2028.4923
====> Epoch: 137 Average loss: 2024.4576
====> Epoch: 138 Average loss: 2026.4345
====> Epoch: 139 Average loss: 2036.4513
====> Epoch: 140 Average loss: 2033.2672
====> Epoch: 141 Average loss: 2036.7067
====> Epoch: 142

====> Epoch: 175 Average loss: 2023.3863
====> Epoch: 176 Average loss: 2022.9597
====> Epoch: 177 Average loss: 2023.5643
====> Epoch: 178 Average loss: 2022.0133
====> Epoch: 179 Average loss: 2024.3357
====> Epoch: 180 Average loss: 2020.3996
====> Epoch: 181 Average loss: 2021.7470
====> Epoch: 182 Average loss: 2019.9728
====> Epoch: 183 Average loss: 2019.9039
====> Epoch: 184 Average loss: 2021.5862
====> Epoch: 185 Average loss: 2021.6697
====> Epoch: 186 Average loss: 2020.6729
====> Epoch: 187 Average loss: 2020.6787
====> Epoch: 188 Average loss: 2019.9528
====> Epoch: 189 Average loss: 2020.7088
====> Epoch: 190 Average loss: 2022.3504
====> Epoch: 191 Average loss: 2020.5086
====> Epoch: 192 Average loss: 2020.1604
====> Epoch: 193 Average loss: 2018.7598
====> Epoch: 194 Average loss: 2019.4064
====> Epoch: 195 Average loss: 2018.2586
====> Epoch: 196 Average loss: 2019.3989
====> Epoch: 197 Average loss: 2018.9714
====> Epoch: 198 Average loss: 2017.8115
====> Epoch: 199

====> Epoch: 233 Average loss: 1992.3626
====> Epoch: 234 Average loss: 1995.3649
====> Epoch: 235 Average loss: 1996.6368
====> Epoch: 236 Average loss: 1997.7052
====> Epoch: 237 Average loss: 1991.4645
====> Epoch: 238 Average loss: 1987.8746
====> Epoch: 239 Average loss: 1989.0517
====> Epoch: 240 Average loss: 1984.5535
====> Epoch: 241 Average loss: 2010.9621
====> Epoch: 242 Average loss: 1993.8861
====> Epoch: 243 Average loss: 1987.0152
====> Epoch: 244 Average loss: 1992.3891
====> Epoch: 245 Average loss: 1983.4104
====> Epoch: 246 Average loss: 1983.5403
====> Epoch: 247 Average loss: 1997.2021
====> Epoch: 248 Average loss: 1993.0572
====> Epoch: 249 Average loss: 1997.4820
====> Epoch: 250 Average loss: 1996.9337
====> Epoch: 251 Average loss: 1992.9070
====> Epoch: 252 Average loss: 1989.3763
====> Epoch: 253 Average loss: 1989.3293
====> Epoch: 254 Average loss: 1993.2557
====> Epoch: 255 Average loss: 1989.4333
====> Epoch: 256 Average loss: 1993.2791
====> Epoch: 257

====> Epoch: 290 Average loss: 1977.7277
====> Epoch: 291 Average loss: 1991.9609
====> Epoch: 292 Average loss: 2005.1371
====> Epoch: 293 Average loss: 1986.8984
====> Epoch: 294 Average loss: 1983.5934
====> Epoch: 295 Average loss: 1985.8394
====> Epoch: 296 Average loss: 1980.1194
====> Epoch: 297 Average loss: 1974.8710
====> Epoch: 298 Average loss: 1975.8510
====> Epoch: 299 Average loss: 1971.1145
====> Epoch: 300 Average loss: 1968.6953
====> Epoch: 301 Average loss: 1965.7752
====> Epoch: 302 Average loss: 1968.6686
====> Epoch: 303 Average loss: 1978.2070
====> Epoch: 304 Average loss: 1968.4447
====> Epoch: 305 Average loss: 1967.0836
====> Epoch: 306 Average loss: 1975.9365
====> Epoch: 307 Average loss: 1974.6803
====> Epoch: 308 Average loss: 1968.1688
====> Epoch: 309 Average loss: 1974.6180
====> Epoch: 310 Average loss: 1965.6220
====> Epoch: 311 Average loss: 1967.3585
====> Epoch: 312 Average loss: 1972.2200
====> Epoch: 313 Average loss: 1972.8836
====> Epoch: 314

====> Epoch: 348 Average loss: 1942.4646
====> Epoch: 349 Average loss: 1938.7805
====> Epoch: 350 Average loss: 1940.9146
====> Epoch: 351 Average loss: 1938.8266
====> Epoch: 352 Average loss: 1943.1686
====> Epoch: 353 Average loss: 1938.1879
====> Epoch: 354 Average loss: 1942.4045
====> Epoch: 355 Average loss: 1937.0470
====> Epoch: 356 Average loss: 1940.3836
====> Epoch: 357 Average loss: 1932.4209
====> Epoch: 358 Average loss: 1941.1222
====> Epoch: 359 Average loss: 1936.9021
====> Epoch: 360 Average loss: 1936.8637
====> Epoch: 361 Average loss: 1932.0271
====> Epoch: 362 Average loss: 1931.5752
====> Epoch: 363 Average loss: 1934.2504
====> Epoch: 364 Average loss: 1937.5633
====> Epoch: 365 Average loss: 1927.7868
====> Epoch: 366 Average loss: 1930.6416
====> Epoch: 367 Average loss: 1936.0103
====> Epoch: 368 Average loss: 1951.5425
====> Epoch: 369 Average loss: 1937.7041
====> Epoch: 370 Average loss: 1942.4384
====> Epoch: 371 Average loss: 1934.2406
====> Epoch: 372

====> Epoch: 405 Average loss: 1910.5306
====> Epoch: 406 Average loss: 1906.5771
====> Epoch: 407 Average loss: 1908.4568
====> Epoch: 408 Average loss: 1910.9125
====> Epoch: 409 Average loss: 1907.5444
====> Epoch: 410 Average loss: 1910.4051
====> Epoch: 411 Average loss: 1907.8604
====> Epoch: 412 Average loss: 1908.6206
====> Epoch: 413 Average loss: 1905.6071
====> Epoch: 414 Average loss: 1906.3550
====> Epoch: 415 Average loss: 1913.3211
====> Epoch: 416 Average loss: 1910.2592
====> Epoch: 417 Average loss: 1909.1264
====> Epoch: 418 Average loss: 1905.7145
====> Epoch: 419 Average loss: 1904.7619
====> Epoch: 420 Average loss: 1906.8775
====> Epoch: 421 Average loss: 1910.4118
====> Epoch: 422 Average loss: 1910.8417
====> Epoch: 423 Average loss: 1909.9817
====> Epoch: 424 Average loss: 1910.6006
====> Epoch: 425 Average loss: 1909.9075
====> Epoch: 426 Average loss: 1906.7653
====> Epoch: 427 Average loss: 1912.3288
====> Epoch: 428 Average loss: 1908.2351
====> Epoch: 429

====> Epoch: 463 Average loss: 1893.9608
====> Epoch: 464 Average loss: 1895.7253
====> Epoch: 465 Average loss: 1895.2169
====> Epoch: 466 Average loss: 1893.7589
====> Epoch: 467 Average loss: 1895.2960
====> Epoch: 468 Average loss: 1894.5266
====> Epoch: 469 Average loss: 1894.5988
====> Epoch: 470 Average loss: 1893.8565
====> Epoch: 471 Average loss: 1889.9075
====> Epoch: 472 Average loss: 1892.2043
====> Epoch: 473 Average loss: 1892.9188
====> Epoch: 474 Average loss: 1892.0323
====> Epoch: 475 Average loss: 1892.1483
====> Epoch: 476 Average loss: 1889.7404
====> Epoch: 477 Average loss: 1890.3523
====> Epoch: 478 Average loss: 1892.9607
====> Epoch: 479 Average loss: 1892.1804
====> Epoch: 480 Average loss: 1890.7418
====> Epoch: 481 Average loss: 1893.5384
====> Epoch: 482 Average loss: 1888.9050
====> Epoch: 483 Average loss: 1888.9705
====> Epoch: 484 Average loss: 1890.6468
====> Epoch: 485 Average loss: 1888.8957
====> Epoch: 486 Average loss: 1891.6306
====> Epoch: 487

====> Epoch: 520 Average loss: 1886.4293
====> Epoch: 521 Average loss: 1885.6462
====> Epoch: 522 Average loss: 1884.4226
====> Epoch: 523 Average loss: 1886.8171
====> Epoch: 524 Average loss: 1885.1492
====> Epoch: 525 Average loss: 1884.8867
====> Epoch: 526 Average loss: 1883.9219
====> Epoch: 527 Average loss: 1883.8640
====> Epoch: 528 Average loss: 1885.4657
====> Epoch: 529 Average loss: 1884.7080
====> Epoch: 530 Average loss: 1886.6413
====> Epoch: 531 Average loss: 1883.4822
====> Epoch: 532 Average loss: 1883.1099
====> Epoch: 533 Average loss: 1885.3765
====> Epoch: 534 Average loss: 1882.9911
====> Epoch: 535 Average loss: 1883.6280
====> Epoch: 536 Average loss: 1884.0722
====> Epoch: 537 Average loss: 1883.1941
====> Epoch: 538 Average loss: 1882.2811
====> Epoch: 539 Average loss: 1880.3132
====> Epoch: 540 Average loss: 1882.1960
====> Epoch: 541 Average loss: 1881.0973
====> Epoch: 542 Average loss: 1882.0058
====> Epoch: 543 Average loss: 1883.3488
====> Epoch: 544

====> Epoch: 578 Average loss: 1879.1089
====> Epoch: 579 Average loss: 1876.6190
====> Epoch: 580 Average loss: 1878.4924
====> Epoch: 581 Average loss: 1879.7926
====> Epoch: 582 Average loss: 1878.5590
====> Epoch: 583 Average loss: 1878.5802
====> Epoch: 584 Average loss: 1877.2785
====> Epoch: 585 Average loss: 1877.9464
====> Epoch: 586 Average loss: 1876.5743
====> Epoch: 587 Average loss: 1877.8058
====> Epoch: 588 Average loss: 1877.2253
====> Epoch: 589 Average loss: 1878.6132
====> Epoch: 590 Average loss: 1878.4486
====> Epoch: 591 Average loss: 1876.2508
====> Epoch: 592 Average loss: 1876.4000
====> Epoch: 593 Average loss: 1877.6010
====> Epoch: 594 Average loss: 1876.9728
====> Epoch: 595 Average loss: 1877.3129
====> Epoch: 596 Average loss: 1875.7477
====> Epoch: 597 Average loss: 1875.2815
====> Epoch: 598 Average loss: 1875.9209
====> Epoch: 599 Average loss: 1877.5551
====> Epoch: 600 Average loss: 1877.3490


In [8]:
# (indicator of the top-k features, sampled subset) for the running-state model.
top_logits_running_state = top_logits_gumbel_runningstate_vae(
    data=test_data,
    model=model,
)

In [9]:
# Feature indices flagged in the running-state model's indicator vector.
torch.argsort(
    top_logits_running_state[0],
    descending=True,
)[:k]

tensor([   0,    2,    4,    7,   13,   17,   26,   36,   51,   63,   64,   90,
          92,  104,  106,  133,  138,  146,  150,  154,  173,  213,  245,  317,
         388,  421,  427,  448,  480,  601,  612,  826,  839,  851, 1029, 1079,
        1124, 1563, 1747, 2229, 2392, 2397, 2538, 2746, 2756, 2784, 2855, 2866,
        3119, 3521], device='cuda:0')

In [10]:
# Indices of the k features picked by the running-state model, as a numpy
# array. Uses the config constant `k` instead of a hard-coded 50 so the two
# cannot drift apart.
inds = torch.argsort(top_logits_running_state[0], descending=True)[:k].cpu().numpy()

In [11]:
# TODO: map `inds` back to feature names. The names loaded in the data cell
# (names0 / names1) are indexed by row (range(N)), not by feature column, so
# no per-column name list is available yet.
print("HOW TO GET NAME OF FEATURES?")

HOW TO GET NAME OF FEATURES?


In [12]:
# Persist the running-state model's weights; the directory name tracks `k`
# (identical to the previous hard-coded "k_50" path when k == 50).
save_model(f"../data/models/final_run_zeisel/runningstate_vae/k_{k}/model.pt", model)

In [13]:
# Concrete VAE baseline with the same k and starting temperature; the
# trailing train_model(...) call is the cell's displayed value.
model = ConcreteVAE_NMSL(
    train_data.shape[1], 200, 50,
    k=k, t=global_t,
)
model.to(device)
train_model(train_data, model)

====> Epoch: 1 Average loss: 2714.6364
====> Epoch: 2 Average loss: 2549.6524
====> Epoch: 3 Average loss: 2388.7925
====> Epoch: 4 Average loss: 2282.6377
====> Epoch: 5 Average loss: 2222.3463
====> Epoch: 6 Average loss: 2166.5590
====> Epoch: 7 Average loss: 2107.2921
====> Epoch: 8 Average loss: 2059.6841
====> Epoch: 9 Average loss: 2024.6347
====> Epoch: 10 Average loss: 2004.8789
====> Epoch: 11 Average loss: 1994.1026
====> Epoch: 12 Average loss: 1987.4479
====> Epoch: 13 Average loss: 1981.9815
====> Epoch: 14 Average loss: 1975.3556
====> Epoch: 15 Average loss: 1968.1564
====> Epoch: 16 Average loss: 1960.2452
====> Epoch: 17 Average loss: 1955.8200
====> Epoch: 18 Average loss: 1951.3502
====> Epoch: 19 Average loss: 1947.2246
====> Epoch: 20 Average loss: 1944.4444
====> Epoch: 21 Average loss: 1941.8641
====> Epoch: 22 Average loss: 1939.1970
====> Epoch: 23 Average loss: 1936.8234
====> Epoch: 24 Average loss: 1935.0333
====> Epoch: 25 Average loss: 1933.3056
====> Epo

====> Epoch: 60 Average loss: 1921.0144
====> Epoch: 61 Average loss: 1920.9372
====> Epoch: 62 Average loss: 1920.9568
====> Epoch: 63 Average loss: 1920.8284
====> Epoch: 64 Average loss: 1920.7421
====> Epoch: 65 Average loss: 1920.5400
====> Epoch: 66 Average loss: 1920.5512
====> Epoch: 67 Average loss: 1920.5122
====> Epoch: 68 Average loss: 1920.1050
====> Epoch: 69 Average loss: 1920.0133
====> Epoch: 70 Average loss: 1920.1817
====> Epoch: 71 Average loss: 1920.1067
====> Epoch: 72 Average loss: 1919.5454
====> Epoch: 73 Average loss: 1919.7832
====> Epoch: 74 Average loss: 1919.4514
====> Epoch: 75 Average loss: 1919.5024
====> Epoch: 76 Average loss: 1919.2004
====> Epoch: 77 Average loss: 1919.2699
====> Epoch: 78 Average loss: 1918.9586
====> Epoch: 79 Average loss: 1918.8879
====> Epoch: 80 Average loss: 1918.6753
====> Epoch: 81 Average loss: 1918.5988
====> Epoch: 82 Average loss: 1918.5382
====> Epoch: 83 Average loss: 1918.5302
====> Epoch: 84 Average loss: 1918.4410


====> Epoch: 119 Average loss: 1914.1710
====> Epoch: 120 Average loss: 1914.0002
====> Epoch: 121 Average loss: 1913.8724
====> Epoch: 122 Average loss: 1913.6888
====> Epoch: 123 Average loss: 1913.5349
====> Epoch: 124 Average loss: 1913.3521
====> Epoch: 125 Average loss: 1913.6918
====> Epoch: 126 Average loss: 1913.2128
====> Epoch: 127 Average loss: 1912.9095
====> Epoch: 128 Average loss: 1913.0292
====> Epoch: 129 Average loss: 1912.8781
====> Epoch: 130 Average loss: 1912.6843
====> Epoch: 131 Average loss: 1912.5743
====> Epoch: 132 Average loss: 1912.6164
====> Epoch: 133 Average loss: 1912.4997
====> Epoch: 134 Average loss: 1912.3227
====> Epoch: 135 Average loss: 1912.1124
====> Epoch: 136 Average loss: 1911.7905
====> Epoch: 137 Average loss: 1911.4484
====> Epoch: 138 Average loss: 1911.7487
====> Epoch: 139 Average loss: 1911.4617
====> Epoch: 140 Average loss: 1911.4718
====> Epoch: 141 Average loss: 1911.0157
====> Epoch: 142 Average loss: 1910.7755
====> Epoch: 143

====> Epoch: 177 Average loss: 1888.5724
====> Epoch: 178 Average loss: 1887.0717
====> Epoch: 179 Average loss: 1886.4753
====> Epoch: 180 Average loss: 1887.0190
====> Epoch: 181 Average loss: 1885.8686
====> Epoch: 182 Average loss: 1886.1201
====> Epoch: 183 Average loss: 1885.3424
====> Epoch: 184 Average loss: 1884.9834
====> Epoch: 185 Average loss: 1884.6295
====> Epoch: 186 Average loss: 1884.9661
====> Epoch: 187 Average loss: 1884.7420
====> Epoch: 188 Average loss: 1884.4348
====> Epoch: 189 Average loss: 1884.3561
====> Epoch: 190 Average loss: 1883.6886
====> Epoch: 191 Average loss: 1883.4077
====> Epoch: 192 Average loss: 1884.1762
====> Epoch: 193 Average loss: 1883.8996
====> Epoch: 194 Average loss: 1882.9473
====> Epoch: 195 Average loss: 1883.1739
====> Epoch: 196 Average loss: 1882.4847
====> Epoch: 197 Average loss: 1883.0093
====> Epoch: 198 Average loss: 1882.8644
====> Epoch: 199 Average loss: 1882.2652
====> Epoch: 200 Average loss: 1882.0904
====> Epoch: 201

====> Epoch: 235 Average loss: 1881.1621
====> Epoch: 236 Average loss: 1882.1228
====> Epoch: 237 Average loss: 1882.0465
====> Epoch: 238 Average loss: 1882.0293
====> Epoch: 239 Average loss: 1881.7593
====> Epoch: 240 Average loss: 1881.8700
====> Epoch: 241 Average loss: 1882.7124
====> Epoch: 242 Average loss: 1881.8279
====> Epoch: 243 Average loss: 1882.2970
====> Epoch: 244 Average loss: 1882.2607
====> Epoch: 245 Average loss: 1883.6308
====> Epoch: 246 Average loss: 1881.8782
====> Epoch: 247 Average loss: 1883.1702
====> Epoch: 248 Average loss: 1882.9160
====> Epoch: 249 Average loss: 1882.7469
====> Epoch: 250 Average loss: 1882.3638
====> Epoch: 251 Average loss: 1882.3976
====> Epoch: 252 Average loss: 1882.6316
====> Epoch: 253 Average loss: 1882.5255
====> Epoch: 254 Average loss: 1882.7328
====> Epoch: 255 Average loss: 1882.7764
====> Epoch: 256 Average loss: 1883.1831
====> Epoch: 257 Average loss: 1883.5181
====> Epoch: 258 Average loss: 1882.6221
====> Epoch: 259

====> Epoch: 293 Average loss: 1886.4493
====> Epoch: 294 Average loss: 1885.9929
====> Epoch: 295 Average loss: 1886.8064
====> Epoch: 296 Average loss: 1887.4402
====> Epoch: 297 Average loss: 1885.6933
====> Epoch: 298 Average loss: 1886.7734
====> Epoch: 299 Average loss: 1886.4972
====> Epoch: 300 Average loss: 1887.7351
====> Epoch: 301 Average loss: 1886.8946
====> Epoch: 302 Average loss: 1886.9151
====> Epoch: 303 Average loss: 1886.9301
====> Epoch: 304 Average loss: 1886.4590
====> Epoch: 305 Average loss: 1887.1801
====> Epoch: 306 Average loss: 1887.5624
====> Epoch: 307 Average loss: 1887.2737
====> Epoch: 308 Average loss: 1887.4604
====> Epoch: 309 Average loss: 1887.9515
====> Epoch: 310 Average loss: 1887.2081
====> Epoch: 311 Average loss: 1886.3545
====> Epoch: 312 Average loss: 1888.0406
====> Epoch: 313 Average loss: 1887.4949
====> Epoch: 314 Average loss: 1887.0683
====> Epoch: 315 Average loss: 1887.9613
====> Epoch: 316 Average loss: 1887.4852
====> Epoch: 317

====> Epoch: 350 Average loss: 1887.4702
====> Epoch: 351 Average loss: 1889.3213
====> Epoch: 352 Average loss: 1889.4227
====> Epoch: 353 Average loss: 1890.5902
====> Epoch: 354 Average loss: 1889.3134
====> Epoch: 355 Average loss: 1888.0262
====> Epoch: 356 Average loss: 1889.5825
====> Epoch: 357 Average loss: 1887.8062
====> Epoch: 358 Average loss: 1889.0232
====> Epoch: 359 Average loss: 1888.9525
====> Epoch: 360 Average loss: 1889.3091
====> Epoch: 361 Average loss: 1888.6513
====> Epoch: 362 Average loss: 1888.4845
====> Epoch: 363 Average loss: 1891.4788
====> Epoch: 364 Average loss: 1891.2915
====> Epoch: 365 Average loss: 1888.8692
====> Epoch: 366 Average loss: 1890.2269
====> Epoch: 367 Average loss: 1890.2767
====> Epoch: 368 Average loss: 1889.6032
====> Epoch: 369 Average loss: 1890.1352
====> Epoch: 370 Average loss: 1890.1317
====> Epoch: 371 Average loss: 1889.0751
====> Epoch: 372 Average loss: 1888.8442
====> Epoch: 373 Average loss: 1890.9236
====> Epoch: 374

====> Epoch: 408 Average loss: 1889.9648
====> Epoch: 409 Average loss: 1888.4854
====> Epoch: 410 Average loss: 1890.1419
====> Epoch: 411 Average loss: 1890.2389
====> Epoch: 412 Average loss: 1889.6099
====> Epoch: 413 Average loss: 1889.3492
====> Epoch: 414 Average loss: 1890.6933
====> Epoch: 415 Average loss: 1891.4217
====> Epoch: 416 Average loss: 1890.9343
====> Epoch: 417 Average loss: 1889.5356
====> Epoch: 418 Average loss: 1890.8525
====> Epoch: 419 Average loss: 1890.5876
====> Epoch: 420 Average loss: 1889.4447
====> Epoch: 421 Average loss: 1890.6998
====> Epoch: 422 Average loss: 1889.5643
====> Epoch: 423 Average loss: 1890.7089
====> Epoch: 424 Average loss: 1891.5750
====> Epoch: 425 Average loss: 1890.2744
====> Epoch: 426 Average loss: 1891.2411
====> Epoch: 427 Average loss: 1889.8411
====> Epoch: 428 Average loss: 1889.4134
====> Epoch: 429 Average loss: 1890.8174
====> Epoch: 430 Average loss: 1890.5495
====> Epoch: 431 Average loss: 1890.4735
====> Epoch: 432

====> Epoch: 466 Average loss: 1889.2105
====> Epoch: 467 Average loss: 1891.3147
====> Epoch: 468 Average loss: 1890.2687
====> Epoch: 469 Average loss: 1890.2704
====> Epoch: 470 Average loss: 1890.2404
====> Epoch: 471 Average loss: 1888.7441
====> Epoch: 472 Average loss: 1890.3162
====> Epoch: 473 Average loss: 1889.2823
====> Epoch: 474 Average loss: 1892.2157
====> Epoch: 475 Average loss: 1890.0881
====> Epoch: 476 Average loss: 1889.9890
====> Epoch: 477 Average loss: 1890.4596
====> Epoch: 478 Average loss: 1890.0731
====> Epoch: 479 Average loss: 1891.1417
====> Epoch: 480 Average loss: 1889.5395
====> Epoch: 481 Average loss: 1888.9553
====> Epoch: 482 Average loss: 1891.1861
====> Epoch: 483 Average loss: 1889.0171
====> Epoch: 484 Average loss: 1888.9530
====> Epoch: 485 Average loss: 1889.2079
====> Epoch: 486 Average loss: 1889.6632
====> Epoch: 487 Average loss: 1890.6119
====> Epoch: 488 Average loss: 1888.7256
====> Epoch: 489 Average loss: 1889.8198
====> Epoch: 490

====> Epoch: 524 Average loss: 1890.0656
====> Epoch: 525 Average loss: 1887.8578
====> Epoch: 526 Average loss: 1889.2596
====> Epoch: 527 Average loss: 1889.6063
====> Epoch: 528 Average loss: 1890.3909
====> Epoch: 529 Average loss: 1890.7394
====> Epoch: 530 Average loss: 1890.8248
====> Epoch: 531 Average loss: 1890.8084
====> Epoch: 532 Average loss: 1890.5217
====> Epoch: 533 Average loss: 1889.4594
====> Epoch: 534 Average loss: 1890.0006
====> Epoch: 535 Average loss: 1889.7483
====> Epoch: 536 Average loss: 1890.1728
====> Epoch: 537 Average loss: 1889.3981
====> Epoch: 538 Average loss: 1889.9759
====> Epoch: 539 Average loss: 1888.7913
====> Epoch: 540 Average loss: 1888.6244
====> Epoch: 541 Average loss: 1888.9946
====> Epoch: 542 Average loss: 1888.7783
====> Epoch: 543 Average loss: 1889.7300
====> Epoch: 544 Average loss: 1888.7462
====> Epoch: 545 Average loss: 1890.5786
====> Epoch: 546 Average loss: 1890.5773
====> Epoch: 547 Average loss: 1888.9917
====> Epoch: 548

====> Epoch: 581 Average loss: 1890.1977
====> Epoch: 582 Average loss: 1890.7496
====> Epoch: 583 Average loss: 1889.5106
====> Epoch: 584 Average loss: 1890.9907
====> Epoch: 585 Average loss: 1888.4598
====> Epoch: 586 Average loss: 1890.0474
====> Epoch: 587 Average loss: 1888.5036
====> Epoch: 588 Average loss: 1890.3388
====> Epoch: 589 Average loss: 1888.9394
====> Epoch: 590 Average loss: 1887.9638
====> Epoch: 591 Average loss: 1890.4983
====> Epoch: 592 Average loss: 1889.1679
====> Epoch: 593 Average loss: 1889.3311
====> Epoch: 594 Average loss: 1889.2232
====> Epoch: 595 Average loss: 1888.7560
====> Epoch: 596 Average loss: 1889.6751
====> Epoch: 597 Average loss: 1888.1413
====> Epoch: 598 Average loss: 1888.2543
====> Epoch: 599 Average loss: 1887.7388
====> Epoch: 600 Average loss: 1888.4059


ConcreteVAE_NMSL(
  (encoder): Sequential(
    (0): Linear(in_features=50, out_features=200, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
  )
  (enc_mean): Linear(in_features=200, out_features=50, bias=True)
  (enc_logvar): Linear(in_features=200, out_features=50, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=50, out_features=200, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=200, out_features=4000, bias=True)
    (3): Sigmoid()
  )
)

In [14]:
# (indicator of the top-k features, summed per-feature scores) for the concrete model.
top_logits_concrete = top_logits_gumbel_concrete_vae_nsml(
    data=test_data,
    model=model,
)

In [15]:
# Feature indices flagged in the concrete model's indicator vector.
torch.argsort(
    top_logits_concrete[0],
    descending=True,
)[:k]

tensor([  20,  147,  148,  263,  266,  464,  515,  624,  627,  662,  718,  739,
         805,  855,  917,  945, 1001, 1061, 1155, 1187, 1189, 1205, 1220, 1243,
        1261, 1317, 1418, 1491, 1718, 1790, 1870, 1956, 1961, 2034, 2071, 2098,
        2139, 2187, 2230, 2362, 2547, 2709, 2747, 2805, 2834, 3248, 3333, 3342,
        3586, 3715], device='cuda:0')

In [16]:
# Persist the concrete model's weights; the directory name tracks `k`
# (identical to the previous hard-coded "k_50" path when k == 50).
save_model(f"../data/models/final_run_zeisel/concrete_vae/k_{k}/model.pt", model)