In [1]:
import torch


from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F

import numpy as np

import matplotlib.pyplot as plt
#from sklearn.manifold import TSNE

#import math

#import gc

from utils import *

from sklearn.preprocessing import MinMaxScaler

from scipy.stats import pearsonr

import seaborn as sns
import os

In [2]:
# Pin the global PyTorch and NumPy random generators so reruns draw the same numbers.
torch.manual_seed(seed=0)
np.random.seed(seed=0)

In [3]:
# Probe for CUDA once; every device-related global below derives from this flag.
cuda = bool(torch.cuda.is_available())

# Legacy float-tensor constructor that allocates on the selected device.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

# First GPU when one is present, otherwise the CPU.
# (To force CPU regardless of hardware, override `device` here.)
device = torch.device("cuda:0" if cuda else "cpu")
print("Device", device, sep="\n")

Device
cuda:0


In [4]:
# Arguments handed to generate_synthetic_data_with_noise(N, z_size, D).
# The models in this notebook are constructed with 2*D input features.
D = 30
N = 10000
z_size = 8

# Training-style settings (names suggest epochs / batch size / Adam-like lr and betas — TODO confirm).
# NOTE(review): an earlier comment here recommended 50 epochs (10 for quick local runs),
# but the value is 600 — confirm which setting is intended.
n_epochs = 600
batch_size = 64
lr = 1e-4
b1, b2 = 0.9, 0.999

# Passed as `t` to every model constructor in this notebook.
global_t = 0.099
# Candidate subset sizes k, derived from D: [D//10, D//6, D//3, D//2, D].
k_lab = [D // divisor for divisor in (10, 6, 3, 2, 1)]
trial_num = 5

In [5]:
# Create the train/test tensors that every later cell works on.
# NOTE(review): the helper is presumably provided by the `utils` star import; the meaning
# of its arguments is inferred only from the names N, z_size and D — confirm.
train_data, test_data = generate_synthetic_data_with_noise(N, z_size, D)

In [6]:
# Root directory of the saved model checkpoints, relative to this notebook.
# The trailing slash is part of the value.
BASE_PATH = "../data/models/final_run/"
# Alternative absolute location, kept for reference:
# BASE_PATH = '/scratch/ns3429/sparse-subset/data/models/final_run/'

In [7]:
def load_trial_model(trial_path, k):
    """Rebuild the model that matches a checkpoint path and load its weights.

    The model class is selected by the model-family name contained in
    ``trial_path``. Every family is constructed with ``2*D`` inputs, a hidden
    size of 100, a latent size of 20 and temperature ``global_t``.

    Args:
        trial_path: Path to a saved ``state_dict``. It must contain one of
            'vanilla_vae_gumbel', 'batching_gumbel_vae', 'globalgate_vae',
            'runningstate_vae' or 'concrete_vae_nmsl'.
        k: Forwarded to the model constructor as ``k``.

    Returns:
        The model with loaded weights, moved to ``device`` and set to eval mode.

    Raises:
        ValueError: If ``trial_path`` contains none of the known family names.
    """
    if 'vanilla_vae_gumbel' in trial_path:
        model = VAE_Gumbel(2*D, 100, 20, k = k, t = global_t)
    elif 'batching_gumbel_vae' in trial_path:
        model = VAE_Gumbel_NInsta(2*D, 100, 20, k = k, t = global_t)
    elif 'globalgate_vae' in trial_path:
        model = VAE_Gumbel_GlobalGate(2*D, 100, 20, k = k, t = global_t)
    elif 'runningstate_vae' in trial_path:
        model = VAE_Gumbel_RunningState(2*D, 100, 20, k = k, t = global_t, alpha = 0.9)
    elif 'concrete_vae_nmsl' in trial_path:
        model = ConcreteVAE_NMSL(2*D, 100, 20, k = k, t = global_t)
    else:
        # Previously this fell through and crashed later with an unbound `model`.
        raise ValueError(
            "Cannot infer the model family from checkpoint path: {!r}".format(trial_path)
        )

    # map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(trial_path, map_location=device))
    model.to(device)
    model.eval()

    return model

In [8]:
def obtain_model_metrics(model):
    """Placeholder for per-model metric collection; not implemented yet.

    Currently ignores ``model`` and returns ``None``.
    """
    return None

In [9]:
# test model
path = '../data/models/final_run/vanilla_vae_gumbel/k_15/model_trial_3.pt'
model = VAE_Gumbel(2*D, 100, 20, k = 15, t = global_t) 
model.load_state_dict(torch.load(path))
model.to(device)

VAE_Gumbel(
  (encoder): Sequential(
    (0): Linear(in_features=60, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
  )
  (enc_mean): Linear(in_features=100, out_features=20, bias=True)
  (enc_logvar): Linear(in_features=100, out_features=20, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=20, out_features=100, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=100, out_features=60, bias=True)
    (3): Sigmoid()
  )
  (weight_creator): Sequential(
    (0): Linear(in_features=60, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=60, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
  )
)

BCE

In [10]:
# Take the first 64 test rows and run them through the model without tracking gradients.
# The first element of the model output is what the cells below compare against the input.
test_df = test_data[:64]
with torch.no_grad():
    pred_df = model(test_df)[0]

In [11]:
# Column-wise mean of the selected test rows.
torch.mean(test_df, dim=0)

tensor([0.0464, 0.1195, 0.0845, 0.0104, 0.2252, 0.1259, 0.1879, 0.2976, 0.1553,
        0.1303, 0.3214, 0.1200, 0.0780, 0.2892, 0.1563, 0.0764, 0.1087, 0.0170,
        0.3132, 0.1194, 0.0414, 0.1190, 0.1890, 0.1809, 0.1659, 0.1510, 0.1903,
        0.1344, 0.0871, 0.0294, 0.4809, 0.4847, 0.4540, 0.4661, 0.4507, 0.5386,
        0.5251, 0.4553, 0.4645, 0.5043, 0.5350, 0.4467, 0.4940, 0.4733, 0.5062,
        0.4916, 0.4410, 0.4825, 0.4733, 0.4785, 0.5255, 0.4801, 0.5315, 0.5335,
        0.5443, 0.4986, 0.5396, 0.4663, 0.4663, 0.4903], device='cuda:0')

In [12]:
# Column-wise mean of the model output for the same rows.
torch.mean(pred_df, dim=0)

tensor([0.0605, 0.1226, 0.1053, 0.0253, 0.2312, 0.1258, 0.2084, 0.3063, 0.1771,
        0.1592, 0.3263, 0.1122, 0.0983, 0.2907, 0.1670, 0.0844, 0.1208, 0.0189,
        0.3198, 0.1345, 0.0441, 0.1405, 0.2189, 0.1823, 0.1709, 0.1352, 0.1966,
        0.1574, 0.0859, 0.0503, 0.4701, 0.5047, 0.4484, 0.4648, 0.4346, 0.5511,
        0.5150, 0.4591, 0.4863, 0.5043, 0.5335, 0.4692, 0.5127, 0.4659, 0.5057,
        0.5039, 0.4515, 0.5015, 0.4748, 0.4784, 0.5431, 0.4873, 0.5155, 0.5153,
        0.5320, 0.5046, 0.5208, 0.4845, 0.4921, 0.5073], device='cuda:0')

In [13]:
# Summed (not averaged) binary cross-entropy between the model output and its input batch.
F.binary_cross_entropy(input=pred_df, target=test_df, reduction='sum')

tensor(1876.1711, device='cuda:0')

In [14]:
def bce_model(data, model, eval_batch_size=None):
    """Average, per row, of the summed binary cross-entropy of ``model`` on ``data``.

    The data is processed in consecutive row batches with gradients disabled.
    For each batch the first element of ``model(batch)`` is scored against the
    batch itself with ``reduction='sum'``; the grand total is divided by the
    number of rows in ``data``.

    Args:
        data: Tensor with one sample per row; also used as the BCE target.
        model: Callable whose return value is indexable, with element ``0``
            having the same shape as the input batch.
        eval_batch_size: Rows per forward pass. Defaults to the notebook-level
            ``batch_size`` when omitted.

    Returns:
        Scalar tensor: total summed BCE divided by ``data.shape[0]``.

    Raises:
        ZeroDivisionError: If ``data`` has no rows.
    """
    step = batch_size if eval_batch_size is None else eval_batch_size
    n_rows = data.shape[0]
    test_loss = 0
    with torch.no_grad():
        # Batches are sliced from `data` itself; the previous version built its
        # index array from the global `test_data`, which broke for any other input.
        for start in range(0, n_rows, step):
            batch_data = data[start:start + step]
            test_pred = model(batch_data)[0]
            test_loss += F.binary_cross_entropy(test_pred, batch_data, reduction='sum')
    return test_loss / n_rows

In [15]:
# Average per-row summed BCE of the loaded model over the whole test set.
bce_model(data=test_data, model=model)

tensor(29.3722, device='cuda:0')

In [16]:
def top_logits_gumbel_vanilla_vae_gumbel(data, model):
    """Per-feature selection frequencies from the model's logits and sampled subsets.

    ``data`` is processed in row batches of the notebook-level ``batch_size``
    with gradients disabled. For every batch:

    * ``model.weight_creator(batch)`` yields the logits; the indices of the
      ``model.k`` largest logits in each row are counted per feature.
    * ``sample_subset(logits, model.k, model.t, separate=True)`` yields the
      sampled subsets; each entry is reduced to a one-hot at its argmax along
      axis 2 and counted per feature.

    Both counters are finally divided by the number of rows in ``data``.

    Args:
        data: 2-D tensor, one sample per row.
        model: Object exposing ``weight_creator``, ``k`` and ``t``.

    Returns:
        ``(all_logits, all_subsets)``: two 1-D float32 tensors of length
        ``data.shape[1]`` located on ``data.device``.
    """
    n_rows, n_features = data.shape[0], data.shape[1]

    # Accumulators live on the same device as the data; CPU accumulators would
    # fail as soon as a GPU batch result is added to them.
    all_logits = torch.zeros(n_features, dtype = torch.float32, device = data.device)
    all_subsets = torch.zeros(n_features, dtype = torch.float32, device = data.device)

    with torch.no_grad():
        for start in range(0, n_rows, batch_size):
            batch_data = data[start:start + batch_size]

            logits = model.weight_creator(batch_data)
            # assumes the result is 3-D with features on the last axis — TODO confirm against sample_subset
            subsets = sample_subset(logits, model.k, model.t, separate = True)

            # Indices (not values) of the k largest logits in each row.
            topk_idx = torch.topk(logits, k = model.k, dim = 1, sorted = True)[1]
            one_hotted_top_k = torch.nn.functional.one_hot(topk_idx, num_classes = n_features).sum(dim = 1)

            # Harden every sampled entry to a one-hot at its argmax feature.
            max_idx = torch.argmax(subsets, 2, keepdim=True)
            one_hot = torch.zeros_like(subsets).scatter_(2, max_idx, 1)

            all_logits += one_hotted_top_k.sum(dim = 0)
            all_subsets += one_hot.sum(dim = (0, 1))

        # Normalise by the rows actually processed, not by the training-set size.
        all_logits /= n_rows
        all_subsets /= n_rows

    return all_logits, all_subsets

In [None]:
def top_logits_gumbel_batching_vae_gumbel(data, model):
    