Try out the Xie Ermon paper of Continuous relaxation for subset selection

In [1]:
import torch
from torch.utils.data import DataLoader

from torchvision import datasets
import torchvision.transforms as transforms


from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F

import numpy as np

from torchvision.utils import save_image

import matplotlib.pyplot as plt

import math

In [2]:
import os
from os import listdir

In [3]:
BASE_PATH_DATA = '../data/'

In [4]:
n_epochs = 5
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
img_size = 28
channels = 1

log_interval = 100


z_size = 20

n = 28 * 28

# from running
# EPSILON = np.finfo(tf.float32.as_numpy_dtype).tiny
#EPSILON = 1.1754944e-38
EPSILON = 1e-10

In [5]:
cuda = True if torch.cuda.is_available() else False

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

device = torch.device("cuda:0" if cuda else "cpu")
print(cuda)

False


In [6]:
import scipy.io as sio

In [7]:
a = sio.loadmat("../data/zeisel/CITEseq.mat")
data= a['G'].T
N,d=data.shape
#transformation from integer entries 
data=np.log(data+np.ones(data.shape))
for i in range(N):
    data[i,:]=data[i,:]/np.linalg.norm(data[i,:])

#load labels from file
a = sio.loadmat("../data/zeisel/CITEseq-labels.mat")
l_aux = a['labels']
labels = np.array([i for [i] in l_aux])

#load names from file
a = sio.loadmat("../data/zeisel/CITEseq_names.mat")
names=[a['citeseq_names'][i][0][0] for i in range(N)]

In [8]:
slices = np.random.permutation(np.arange(data.shape[0]))
upto = int(.8 * len(data))

train_data = data[slices[:upto]]
test_data = data[slices[upto:]]

train_data = Tensor(train_data).to(device)
test_data = Tensor(test_data).to(device)

In [9]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

In [10]:
def gumbel_keys(w):
    # sample some gumbels
    uniform = (1.0 - EPSILON) * torch.rand_like(w) + EPSILON
    z = torch.log(-torch.log(uniform))
    w = w + z
    return w


def continuous_topk(w, k, t, separate=False):
    softmax = nn.Softmax(dim = -1)
    khot_list = []
    onehot_approx = torch.zeros_like(w, dtype = torch.float32)
    for i in range(k):
        ### conver the following into pytorch
        #khot_mask = tf.maximum(1.0 - onehot_approx, EPSILON)
        max_mask = 1 - onehot_approx < EPSILON
        khot_mask = 1 - onehot_approx
        khot_mask[max_mask] = EPSILON
        
        w += torch.log(khot_mask)
        #onehot_approx = tf.nn.softmax(w / t, axis=-1)
        onehot_approx = softmax(w/t)
        khot_list.append(onehot_approx)
    if separate:
        return torch.stack(khot_list)
    else:
        return torch.sum(torch.stack(khot_list), dim = 0) 


def sample_subset(w, k, t=0.1):
    '''
    Args:
        w (Tensor): Float Tensor of weights for each element. In gumbel mode
            these are interpreted as log probabilities
        k (int): number of elements in the subset sample
        t (float): temperature of the softmax
    '''
    w = gumbel_keys(w)
    return continuous_topk(w, k, t)

In [16]:
# L1 VAE model we are loading
class VAE_Gumbel(nn.Module):
    def __init__(self, input_size, hidden_layer_size, z_size, k, t = 0.1):
        super(VAE_Gumbel, self).__init__()
        
        self.k = k
        self.t = t
        
        self.weight_creator = nn.Sequential(
            nn.Linear(input_size, hidden_layer_size),
            nn.ReLU(),
            nn.Linear(hidden_layer_size, input_size)
        )
        
        self.fc1 = nn.Linear(input_size, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, input_size)

    def encode(self, x):
        w = self.weight_creator(x)
        subset_indices = sample_subset(w, self.k, self.t)
        x = x * subset_indices
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [17]:
def train(df, model, optimizer, epoch):
    model.train()
    train_loss = 0
    permutations = torch.randperm(df.shape[0])
    for i in range(math.ceil(len(df)/batch_size)):
        batch_ind = permutations[i * batch_size : (i+1) * batch_size]
        batch_data = df[batch_ind, :]
        
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(batch_data)
        loss = loss_function(recon_batch, batch_data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if i % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i * len(batch_data), len(df),
                100. * i / len(df),
                loss.item() / len(batch_data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(df)))

In [18]:
def test(df, model, epoch):
    model.eval()
    test_loss = 0
    inds = np.arange(df.shape[0])
    with torch.no_grad():
        for i in range(math.ceil(len(df)/batch_size)):
            batch_ind = inds[i * batch_size : (i+1) * batch_size]
            batch_data = df[batch_ind, :]
            batch_data = batch_data.to(device)
            recon_batch, mu, logvar = model(batch_data)
            test_loss += loss_function(recon_batch, batch_data, mu, logvar).item()


    test_loss /= len(df)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [19]:
vae_gumbel = VAE_Gumbel(500, 250, 20, k = 50)
vae_gumbel.to(device)
vae_gumbel_optimizer = torch.optim.Adam(vae_gumbel.parameters(), 
                                                lr=lr, 
                                                betas = (b1,b2))

In [48]:
for epoch in range(1, 50 + 1):
        train(train_data, vae_gumbel, vae_gumbel_optimizer, epoch)
        #with torch.no_grad():
        #    model.diag.data[torch.abs(model.diag) < 0.05] = 0
        test(test_data, vae_gumbel, epoch)

====> Epoch: 1 Average loss: 201.9731
====> Test set loss: 88.6152
====> Epoch: 2 Average loss: 80.9494
====> Test set loss: 76.5628
====> Epoch: 3 Average loss: 74.7947
====> Test set loss: 73.7309
====> Epoch: 4 Average loss: 73.2649
====> Test set loss: 72.8139
====> Epoch: 5 Average loss: 72.4091
====> Test set loss: 72.0782
====> Epoch: 6 Average loss: 71.8642
====> Test set loss: 71.6768
====> Epoch: 7 Average loss: 71.2473
====> Test set loss: 70.8740
====> Epoch: 8 Average loss: 70.5291
====> Test set loss: 70.0870
====> Epoch: 9 Average loss: 69.7739
====> Test set loss: 69.4165
====> Epoch: 10 Average loss: 68.9996
====> Test set loss: 68.5028
====> Epoch: 11 Average loss: 68.0274
====> Test set loss: 67.6850
====> Epoch: 12 Average loss: 67.1969
====> Test set loss: 66.9268
====> Epoch: 13 Average loss: 66.5495
====> Test set loss: 66.5136
====> Epoch: 14 Average loss: 66.2013
====> Test set loss: 66.0763
====> Epoch: 15 Average loss: 65.8976
====> Test set loss: 65.9249
===

In [49]:
with torch.no_grad():
    print(torch.sum((vae_gumbel(train_data)[0][1,:] - train_data[1,:])**2))

tensor(0.2363)


In [50]:
with torch.no_grad():
    print(torch.sum((vae_gumbel(test_data)[0][1,:] - test_data[1,:])**2))

tensor(0.2005)


In [51]:
torch.save(vae_gumbel.state_dict(), BASE_PATH_DATA + "../data/models/zeisel/gumbel.pt")

In [52]:
torch.sum((vae_gumbel(train_data)[0] - train_data)**2) / len(train_data)

tensor(0.2706, grad_fn=<DivBackward0>)

In [53]:
torch.sum((vae_gumbel(test_data)[0] - test_data)**2) / len(test_data)

tensor(0.2654, grad_fn=<DivBackward0>)

Let's try modifying our previous methods to have the weights be the output of a neural network.

In [38]:
# L1 VAE model we are loading
class VAE_L1_Hypernetwork(nn.Module):
    def __init__(self, input_size, hidden_layer_size, z_size):
        super(VAE_L1_Hypernetwork, self).__init__()
        
        self.weight_creator = nn.Sequential(
            nn.Linear(input_size, hidden_layer_size),
            nn.ReLU(),
            nn.Linear(hidden_layer_size, input_size)
        )
        
        self.fc1 = nn.Linear(input_size, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, input_size)

    def encode(self, x):
        self.l1_weights = self.weight_creator(x)
        x = x * self.l1_weights
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [49]:
def mean_l1_norm(weights):
    return torch.sum(torch.stack([torch.norm(weights[i, :], 1) for i in range(weights.shape[0])]))

In [53]:
def train_hypernetwork(df, model, optimizer, epoch):
    model.train()
    train_loss = 0
    permutations = torch.randperm(df.shape[0])
    for i in range(math.ceil(len(df)/batch_size)):
        batch_ind = permutations[i * batch_size : (i+1) * batch_size]
        batch_data = df[batch_ind, :]
        
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(batch_data)
        loss = loss_function(recon_batch, batch_data, mu, logvar)
        
        loss += mean_l1_norm(model.l1_weights)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if i % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i * len(batch_data), len(df),
                100. * i / len(df),
                loss.item() / len(batch_data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(df)))

In [54]:
vae_l1_hypernetwork = VAE_L1_Hypernetwork(500, 250, 20)
vae_l1_hypernetwork.to(device)
vae_l1_hypernetwork_optimizer = torch.optim.Adam(vae_l1_hypernetwork.parameters(), 
                                                lr=lr, 
                                                betas = (b1,b2))

In [55]:
for epoch in range(1, 50 + 1):
        train_hypernetwork(train_data, vae_l1_hypernetwork, vae_l1_hypernetwork_optimizer, epoch)
        #with torch.no_grad():
        #    model.diag.data[torch.abs(model.diag) < 0.05] = 0
        test(test_data, vae_l1_hypernetwork, epoch)

====> Epoch: 1 Average loss: 231.0197
====> Test set loss: 124.8660
====> Epoch: 2 Average loss: 90.6851
====> Test set loss: 75.4202
====> Epoch: 3 Average loss: 75.1994
====> Test set loss: 72.8025
====> Epoch: 4 Average loss: 73.2216
====> Test set loss: 71.4933
====> Epoch: 5 Average loss: 71.8990
====> Test set loss: 70.6264
====> Epoch: 6 Average loss: 70.9756
====> Test set loss: 69.3226
====> Epoch: 7 Average loss: 69.9467
====> Test set loss: 68.7080
====> Epoch: 8 Average loss: 69.1143
====> Test set loss: 68.1165
====> Epoch: 9 Average loss: 68.3865
====> Test set loss: 67.1594
====> Epoch: 10 Average loss: 67.6970
====> Test set loss: 66.6947
====> Epoch: 11 Average loss: 67.1095
====> Test set loss: 66.0758
====> Epoch: 12 Average loss: 66.5872
====> Test set loss: 65.7160
====> Epoch: 13 Average loss: 66.1843
====> Test set loss: 65.5178
====> Epoch: 14 Average loss: 65.9462
====> Test set loss: 65.2116
====> Epoch: 15 Average loss: 65.6303
====> Test set loss: 65.0563
==

In [56]:
torch.sum((vae_l1_hypernetwork(train_data)[0] - train_data)**2) / len(train_data)

tensor(0.2847, grad_fn=<DivBackward0>)

In [57]:
torch.sum((vae_l1_hypernetwork(test_data)[0] - test_data)**2) / len(test_data)

tensor(0.2894, grad_fn=<DivBackward0>)

In [58]:
_ = vae_l1_hypernetwork(train_data)

Look at the weights

In [69]:
torch.sum(torch.abs(vae_l1_hypernetwork.l1_weights) < 1e-4, dim = -1)

tensor([154, 122, 143,  ..., 109, 140, 184])

In [71]:
np.histogram(torch.abs(vae_l1_hypernetwork.l1_weights)[0, :].clone().detach().numpy())

(array([138, 121,  90,  62,  38,  24,  12,   5,   7,   3]),
 array([7.6973811e-07, 8.8643887e-05, 1.7651804e-04, 2.6439218e-04,
        3.5226633e-04, 4.4014049e-04, 5.2801461e-04, 6.1588880e-04,
        7.0376293e-04, 7.9163711e-04, 8.7951124e-04], dtype=float32))