Try out the Xie Ermon paper of Continuous relaxation for subset selection

In [1]:
import torch
from torch.utils.data import DataLoader

from torchvision import datasets
import torchvision.transforms as transforms


from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F

import numpy as np

from torchvision.utils import save_image

import matplotlib.pyplot as plt

import math

In [2]:
import os
from os import listdir

In [3]:
BASE_PATH_DATA = '../data/'

In [4]:
n_epochs = 5
batch_size = 64
lr = 0.0002
b1 = 0.5
b2 = 0.999
img_size = 28
channels = 1

log_interval = 100


z_size = 20

n = 28 * 28

# from running
# EPSILON = np.finfo(tf.float32.as_numpy_dtype).tiny
#EPSILON = 1.1754944e-38
EPSILON = 1e-10

In [5]:
cuda = True if torch.cuda.is_available() else False

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

device = torch.device("cuda:0" if cuda else "cpu")
print(cuda)

True


In [6]:
import scipy.io as sio

In [7]:
a = sio.loadmat("../data/zeisel/CITEseq.mat")
data= a['G'].T
N,d=data.shape
#transformation from integer entries 
data=np.log(data+np.ones(data.shape))
for i in range(N):
    data[i,:]=data[i,:]/np.linalg.norm(data[i,:])

#load labels from file
a = sio.loadmat("../data/zeisel/CITEseq-labels.mat")
l_aux = a['labels']
labels = np.array([i for [i] in l_aux])

#load names from file
a = sio.loadmat("../data/zeisel/CITEseq_names.mat")
names=[a['citeseq_names'][i][0][0] for i in range(N)]

In [8]:
slices = np.random.permutation(np.arange(data.shape[0]))
upto = int(.8 * len(data))

train_data = data[slices[:upto]]
test_data = data[slices[upto:]]

train_data = Tensor(train_data).to(device)
test_data = Tensor(test_data).to(device)

In [9]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

In [10]:
def gumbel_keys(w):
    # sample some gumbels
    uniform = (1.0 - EPSILON) * torch.rand_like(w) + EPSILON
    z = torch.log(-torch.log(uniform))
    w = w + z
    return w


def continuous_topk(w, k, t, separate=False):
    softmax = nn.Softmax(dim = -1)
    khot_list = []
    onehot_approx = torch.zeros_like(w, dtype = torch.float32)
    for i in range(k):
        ### conver the following into pytorch
        #khot_mask = tf.maximum(1.0 - onehot_approx, EPSILON)
        max_mask = 1 - onehot_approx < EPSILON
        khot_mask = 1 - onehot_approx
        khot_mask[max_mask] = EPSILON
        
        w += torch.log(khot_mask)
        #onehot_approx = tf.nn.softmax(w / t, axis=-1)
        onehot_approx = softmax(w/t)
        khot_list.append(onehot_approx)
    if separate:
        return torch.stack(khot_list)
    else:
        return torch.sum(torch.stack(khot_list), dim = 0) 


def sample_subset(w, k, t=0.1):
    '''
    Args:
        w (Tensor): Float Tensor of weights for each element. In gumbel mode
            these are interpreted as log probabilities
        k (int): number of elements in the subset sample
        t (float): temperature of the softmax
    '''
    w = gumbel_keys(w)
    return continuous_topk(w, k, t)

In [11]:
# L1 VAE model we are loading
class VAE_Gumbel(nn.Module):
    def __init__(self, input_size, hidden_layer_size, z_size, k, t = 0.1):
        super(VAE_Gumbel, self).__init__()
        
        self.k = k
        self.t = t
        
        self.weight_creator = nn.Sequential(
            nn.Linear(input_size, hidden_layer_size),
            nn.ReLU(),
            nn.Linear(hidden_layer_size, input_size)
        )
        
        self.fc1 = nn.Linear(input_size, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, input_size)

    def encode(self, x):
        w = self.weight_creator(x)
        subset_indices = sample_subset(w, self.k, self.t)
        x = x * subset_indices
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [12]:
def train(df, model, optimizer, epoch):
    model.train()
    train_loss = 0
    permutations = torch.randperm(df.shape[0])
    for i in range(math.ceil(len(df)/batch_size)):
        batch_ind = permutations[i * batch_size : (i+1) * batch_size]
        batch_data = df[batch_ind, :]
        
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(batch_data)
        loss = loss_function(recon_batch, batch_data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if i % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i * len(batch_data), len(df),
                100. * i / len(df),
                loss.item() / len(batch_data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(df)))

In [13]:
def test(df, model, epoch):
    model.eval()
    test_loss = 0
    inds = np.arange(df.shape[0])
    with torch.no_grad():
        for i in range(math.ceil(len(df)/batch_size)):
            batch_ind = inds[i * batch_size : (i+1) * batch_size]
            batch_data = df[batch_ind, :]
            batch_data = batch_data.to(device)
            recon_batch, mu, logvar = model(batch_data)
            test_loss += loss_function(recon_batch, batch_data, mu, logvar).item()


    test_loss /= len(df)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [14]:
vae_gumbel = VAE_Gumbel(500, 250, 20, k = 50)
vae_gumbel.to(device)
vae_gumbel_optimizer = torch.optim.Adam(vae_gumbel.parameters(), 
                                                lr=lr, 
                                                betas = (b1,b2))

In [15]:
for epoch in range(1, 50 + 1):
        train(train_data, vae_gumbel, vae_gumbel_optimizer, epoch)
        #with torch.no_grad():
        #    model.diag.data[torch.abs(model.diag) < 0.05] = 0
        test(test_data, vae_gumbel, epoch)

====> Epoch: 1 Average loss: 201.8854
====> Test set loss: 89.2234
====> Epoch: 2 Average loss: 81.1297
====> Test set loss: 76.9310
====> Epoch: 3 Average loss: 74.8473
====> Test set loss: 73.6720
====> Epoch: 4 Average loss: 72.9427
====> Test set loss: 73.0174
====> Epoch: 5 Average loss: 72.2479
====> Test set loss: 72.1608
====> Epoch: 6 Average loss: 71.6219
====> Test set loss: 71.9428
====> Epoch: 7 Average loss: 70.9108
====> Test set loss: 71.2396
====> Epoch: 8 Average loss: 70.3231
====> Test set loss: 70.0839
====> Epoch: 9 Average loss: 69.5585
====> Test set loss: 69.6861
====> Epoch: 10 Average loss: 68.8508
====> Test set loss: 68.7922
====> Epoch: 11 Average loss: 68.0780
====> Test set loss: 68.0090
====> Epoch: 12 Average loss: 67.4405
====> Test set loss: 67.5323
====> Epoch: 13 Average loss: 67.0006
====> Test set loss: 67.1028
====> Epoch: 14 Average loss: 66.6365
====> Test set loss: 66.7566
====> Epoch: 15 Average loss: 66.3018
====> Test set loss: 66.4537
===

In [16]:
torch.save(vae_gumbel.state_dict(), BASE_PATH_DATA + "../data/models/zeisel/gumbel.pt")

## Do a subset for calculation error because of memory

In [17]:
with torch.no_grad():
    print(torch.sum((vae_gumbel(train_data[0:64, :])[0] - train_data[0:64,:])**2) / 64)

tensor(0.2657, device='cuda:0')


In [18]:
with torch.no_grad():
    print(torch.sum((vae_gumbel(test_data[0:64, :])[0] - test_data[0:64,:])**2) / 64)

tensor(0.2445, device='cuda:0')


**Sometimes does better on test!**

Let's look at some of the weights. Are the sparse ones consistent?

In [19]:
with torch.no_grad():
    weights_train = vae_gumbel.weight_creator(train_data[0:64,:])
# same k and t as above
subset_indices = sample_subset(weights_train, k=50, t=0.1)

In [20]:
subset_indices = subset_indices.clone().detach().cpu().numpy()

In [21]:
np.where(np.isclose(subset_indices[0,:], 0))[0]

array([  0,   1,   8,  11,  23,  31,  34,  36,  42,  44,  48,  50,  52,
        53,  55,  56,  57,  61,  63,  64,  66,  67,  68,  69,  71,  72,
        73,  74,  75,  77,  78,  79,  80,  82,  83,  84,  85,  86,  88,
        90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102,
       103, 104, 105, 106, 107, 109, 110, 111, 112, 114, 115, 117, 118,
       119, 121, 122, 123, 124, 126, 127, 129, 130, 131, 132, 133, 134,
       135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
       148, 149, 150, 152, 153, 154, 156, 157, 158, 159, 160, 161, 162,
       163, 164, 165, 166, 167, 168, 169, 170, 172, 173, 174, 175, 176,
       177, 178, 179, 180, 181, 182, 184, 185, 186, 187, 188, 189, 190,
       191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203,
       204, 205, 206, 207, 208, 209, 210, 213, 214, 215, 216, 217, 218,
       219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
       232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 24

In [22]:
np.where(np.isclose(subset_indices[1,:], 0))[0]

array([  6,   7,   9,  11,  15,  17,  23,  30,  31,  34,  36,  39,  40,
        42,  43,  44,  45,  46,  48,  50,  52,  53,  54,  55,  56,  57,
        60,  61,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,
        74,  76,  77,  78,  79,  80,  82,  83,  84,  85,  86,  88,  89,
        90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102,
       103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115,
       117, 118, 119, 120, 121, 122, 123, 124, 126, 127, 128, 129, 130,
       131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
       144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156,
       157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
       170, 171, 172, 173, 174, 175, 177, 178, 179, 180, 181, 182, 183,
       184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196,
       197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209,
       210, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 22

In [23]:
np.intersect1d( np.where(np.isclose(subset_indices[1,:], 0))[0], np.where(np.isclose(subset_indices[0,:], 0))[0])

array([ 11,  23,  31,  34,  36,  42,  44,  48,  50,  52,  53,  55,  56,
        57,  61,  63,  64,  66,  67,  68,  69,  71,  72,  73,  74,  77,
        78,  79,  80,  82,  83,  84,  85,  86,  88,  90,  91,  92,  93,
        94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106,
       107, 109, 110, 111, 112, 114, 115, 117, 118, 119, 121, 122, 123,
       124, 126, 127, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138,
       139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 152,
       153, 154, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166,
       167, 168, 169, 170, 172, 173, 174, 175, 177, 178, 179, 180, 181,
       182, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195,
       196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208,
       209, 210, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
       224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236,
       237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 24

In [24]:
np.intersect1d( np.where(np.isclose(subset_indices[12,:], 0))[0], np.where(np.isclose(subset_indices[15,:], 0))[0])

array([ 11,  17,  23,  31,  34,  36,  42,  44,  45,  48,  50,  52,  53,
        54,  55,  56,  57,  58,  60,  61,  63,  66,  67,  68,  69,  71,
        73,  74,  76,  78,  79,  80,  82,  83,  84,  85,  86,  88,  89,
        90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102,
       103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115,
       117, 118, 119, 120, 121, 122, 123, 124, 126, 127, 128, 129, 130,
       131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
       144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156,
       157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
       170, 171, 172, 173, 174, 175, 177, 178, 179, 180, 181, 182, 183,
       184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196,
       197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209,
       210, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224,
       225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 23

**Tend to select the same sparse features**

**Let's try modifying our previous methods to have the weights be the output of a neural network.**

In [25]:
# L1 VAE model we are loading
class VAE_L1_Hypernetwork(nn.Module):
    def __init__(self, input_size, hidden_layer_size, z_size):
        super(VAE_L1_Hypernetwork, self).__init__()
        
        self.weight_creator = nn.Sequential(
            nn.Linear(input_size, hidden_layer_size),
            nn.ReLU(),
            nn.Linear(hidden_layer_size, input_size)
        )
        
        self.fc1 = nn.Linear(input_size, hidden_layer_size)
        self.fc21 = nn.Linear(hidden_layer_size, z_size)
        self.fc22 = nn.Linear(hidden_layer_size, z_size)
        self.fc3 = nn.Linear(z_size, hidden_layer_size)
        self.fc4 = nn.Linear(hidden_layer_size, input_size)

    def encode(self, x):
        self.l1_weights = self.weight_creator(x)
        x = x * self.l1_weights
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [26]:
def l1_norm_50_2(weights):
    return torch.sum(torch.stack([(torch.norm(weights[i, :], 1)**2 - 50)**2 for i in range(weights.shape[0])]))

def l1_norm_50(weights):
    return torch.sum(torch.stack([(torch.norm(weights[i, :], 1) - 50)**2 for i in range(weights.shape[0])]))

def mean_l1_norm(weights):
    return torch.sum(torch.stack([torch.norm(weights[i, :], 1) for i in range(weights.shape[0])]))

In [27]:
def train_hypernetwork(df, model, optimizer, epoch):
    model.train()
    train_loss = 0
    permutations = torch.randperm(df.shape[0])
    for i in range(math.ceil(len(df)/batch_size)):
        batch_ind = permutations[i * batch_size : (i+1) * batch_size]
        batch_data = df[batch_ind, :]
        
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(batch_data)
        loss = loss_function(recon_batch, batch_data, mu, logvar)
        
        loss += 100 * l1_norm_50_2(model.l1_weights)
        loss += 100 * l1_norm_50(model.l1_weights)
        
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if i % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, i * len(batch_data), len(df),
                100. * i / len(df),
                loss.item() / len(batch_data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(df)))

In [28]:
vae_l1_hypernetwork = VAE_L1_Hypernetwork(500, 250, 20)
vae_l1_hypernetwork.to(device)
vae_l1_hypernetwork_optimizer = torch.optim.Adam(vae_l1_hypernetwork.parameters(), 
                                                lr=lr, 
                                                betas = (b1,b2))

In [29]:
for epoch in range(1, 50 + 1):
        train_hypernetwork(train_data, vae_l1_hypernetwork, vae_l1_hypernetwork_optimizer, epoch)
        #with torch.no_grad():
        #    model.diag.data[torch.abs(model.diag) < 0.05] = 0
        test(test_data, vae_l1_hypernetwork, epoch)

====> Epoch: 1 Average loss: 394124.1187
====> Test set loss: 116.8953
====> Epoch: 2 Average loss: 183976.3396
====> Test set loss: 75.7574
====> Epoch: 3 Average loss: 183682.8196
====> Test set loss: 73.1656
====> Epoch: 4 Average loss: 183623.6807
====> Test set loss: 72.2464
====> Epoch: 5 Average loss: 183590.4244
====> Test set loss: 71.2622
====> Epoch: 6 Average loss: 183563.2187
====> Test set loss: 70.1436
====> Epoch: 7 Average loss: 183547.1288
====> Test set loss: 69.2699
====> Epoch: 8 Average loss: 183535.3562
====> Test set loss: 68.8273
====> Epoch: 9 Average loss: 183527.5683
====> Test set loss: 67.9229
====> Epoch: 10 Average loss: 183518.9251
====> Test set loss: 67.3084
====> Epoch: 11 Average loss: 183514.3807
====> Test set loss: 66.5727
====> Epoch: 12 Average loss: 183508.2074
====> Test set loss: 66.1613
====> Epoch: 13 Average loss: 183505.4259
====> Test set loss: 65.7909
====> Epoch: 14 Average loss: 183506.2770
====> Test set loss: 65.5926
====> Epoch: 1

====> Epoch: 48 Average loss: 183524.9320
====> Test set loss: 63.9373
====> Epoch: 49 Average loss: 183524.2789
====> Test set loss: 63.9189
====> Epoch: 50 Average loss: 183529.8860
====> Test set loss: 63.8926


In [30]:
with torch.no_grad():
    print(torch.sum((vae_l1_hypernetwork(train_data[0:64, :])[0] - train_data[0:64, :])**2) / 64)

tensor(0.2707, device='cuda:0')


In [31]:
with torch.no_grad():
    print(torch.sum((vae_l1_hypernetwork(test_data[0:64, :])[0] - test_data[0:64, :])**2) / 64)

tensor(0.2552, device='cuda:0')


In [32]:
_ = vae_l1_hypernetwork(test_data[0:64, ])

Look at the weights

In [33]:
np.histogram(torch.abs(vae_l1_hypernetwork.l1_weights)[1, :].clone().detach().cpu().numpy())

(array([103, 130, 126,  72,  29,  18,  10,   6,   5,   1]),
 array([6.9841743e-05, 6.0694804e-03, 1.2069119e-02, 1.8068759e-02,
        2.4068397e-02, 3.0068036e-02, 3.6067676e-02, 4.2067315e-02,
        4.8066951e-02, 5.4066591e-02, 6.0066231e-02], dtype=float32))

In [34]:
np.sum(torch.abs(vae_l1_hypernetwork.l1_weights)[1, :].clone().detach().cpu().numpy())

7.255764

In [35]:
np.sum(torch.abs(vae_l1_hypernetwork.l1_weights)[1, :].clone().detach().cpu().numpy() < 1e-3)

18

In [36]:
torch.abs(vae_l1_hypernetwork.l1_weights)[1, :].clone().detach().cpu().numpy()

array([2.69319303e-02, 3.10958177e-03, 2.43317857e-02, 1.94285996e-02,
       1.20058507e-02, 4.48931754e-03, 2.50979159e-02, 1.51752867e-02,
       5.92885353e-03, 7.98936561e-03, 1.91015601e-02, 3.21855247e-02,
       2.04350948e-02, 9.63868387e-03, 3.49931046e-02, 1.67993177e-02,
       3.26597458e-03, 7.01383129e-03, 1.01694260e-02, 1.76095217e-02,
       9.89056379e-03, 1.47589184e-02, 2.23821178e-02, 9.32968594e-03,
       1.00713596e-02, 1.85512789e-02, 7.53364898e-03, 3.59241106e-03,
       3.11556123e-02, 1.80314649e-02, 1.20125711e-04, 5.79204410e-03,
       3.33424881e-02, 6.16349839e-03, 4.64464352e-03, 1.42324269e-02,
       4.22804570e-03, 5.95882162e-03, 3.46041471e-03, 3.89055870e-02,
       1.00746844e-02, 2.13784948e-02, 5.36743551e-03, 2.38343831e-02,
       1.26601188e-02, 9.98461246e-03, 1.27780885e-02, 1.92644857e-02,
       5.58926724e-03, 1.03354119e-02, 2.02848762e-03, 6.94607943e-03,
       2.93867569e-02, 1.17492406e-02, 1.06210560e-02, 5.27538732e-03,
      

Instance wise does not work very well.