In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import math
import pickle
import matplotlib.pyplot as plt
from torchvision.utils import make_grid, save_image

In [2]:
# Mini-batch size used by both the train and test DataLoaders.
batch_size = 100
# Number of worker processes handed to each DataLoader.
num_threads = 8

# Layer sizes: flattened 28x28 image, hidden code, and classifier output.
inp_dim = 784
hid_dim = 1000
out_dim = 10

# Prefer the second GPU as before, but fall back to the first GPU when the
# machine only exposes one; requesting "cuda:1" there would fail as soon as a
# tensor or module is moved to the device.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [3]:
def data_loader():
    """Build the MNIST DataLoaders.

    Both splits are stored under ``./data`` (downloaded if missing), converted
    with ``transforms.ToTensor()`` and batched with the module-level
    ``batch_size`` / ``num_threads`` settings. Only the training split is
    shuffled.

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    def make_loader(is_train):
        # The train flag selects the split and also decides whether to shuffle.
        dataset = torchvision.datasets.MNIST(root='./data',
                                             train=is_train,
                                             transform=transforms.ToTensor(),
                                             download=True)
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size,
                                           shuffle=is_train,
                                           num_workers=num_threads)

    return make_loader(True), make_loader(False)

In [4]:
def output_plot(outputs, nrow=10):
    """Save a grid of reconstructed digits to ``plots/outputs_<k>.png``.

    Args:
        outputs: tensor whose elements can be viewed as ``(-1, 1, 28, 28)``.
        nrow: number of images per grid row.

    Reads the module-level ``k`` to build the file name.
    """
    import os

    # save_image does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    images = outputs.view(-1, 1, 28, 28)
    # The path is passed positionally: the keyword was renamed from
    # `filename` to `fp` in torchvision, so `filename=` breaks on newer
    # releases while the positional form works on both.
    # NOTE(review): scale_each appears to matter only together with
    # normalize=True, which is not set here -- confirm the intent.
    save_image(images, 'plots/outputs_%d.png' % (int(k)), nrow=nrow, padding=1, scale_each=True)

def filter_plot(model, nrow=30):
    """Save the first 120 encoder filters to ``plots/filters_<k>.png``.

    Args:
        model: object exposing a weight tensor ``W`` whose rows index the 784
            input pixels and whose columns index hidden units.
        nrow: number of filters per grid row.

    Reads the module-level ``k`` to build the file name.
    """
    import os

    # save_image does not create missing directories.
    os.makedirs('plots', exist_ok=True)
    # Work on a detached copy so plotting never touches the live parameter.
    weights = torch.t(model.W.detach().clone())
    # One 28x28 single-channel image per hidden unit; the unit count is taken
    # from the tensor itself so models built with a non-default hid_dim work.
    weights = weights.reshape(weights.size(0), 1, 28, 28)
    weights = weights[0:120]
    # The path is passed positionally: the keyword was renamed from
    # `filename` to `fp` in torchvision, so `filename=` breaks on newer
    # releases while the positional form works on both.
    save_image(weights, 'plots/filters_%d.png' % (k),
               nrow=nrow, padding=1, normalize=True, scale_each=True)
    
def plot_activations(hid):
    """Save a log-scale histogram of activations to ``plots/activations_<k>.png``.

    Args:
        hid: tensor of activation values; it is moved to the CPU and converted
            to a NumPy array before plotting.

    Reads the module-level ``k`` to build the file name.
    """
    values = hid.cpu().numpy()
    exponents = list(range(1, 8))

    axes = plt.gca()
    axes.hist(values, histtype='barstacked')
    axes.set_yscale('log')
    # Tick positions are powers of ten, labelled by their exponent.
    axes.set_yticks([10**e for e in exponents])
    axes.set_yticklabels(exponents)
    axes.set_xlim(0, 3)
    axes.grid(True)
    axes.figure.savefig('plots/activations_%d.png' % (k))
    plt.close()
    
def plot_loss_curve(k, train_list, test_list, ylim):
    """Plot train and test loss curves and save them to ``plots/loss_curves_<k>.png``.

    Args:
        k: value inserted into the output file name.
        train_list: sequence of training losses (solid red line).
        test_list: sequence of test losses (dashed red line).
        ylim: y-axis limits passed straight to matplotlib.
    """
    axes = plt.gca()
    curves = ((train_list, 'r-', 'train loss'),
              (test_list, 'r--', 'test loss'))
    for series, line_style, legend_name in curves:
        axes.plot(series, line_style, label=legend_name)
    axes.legend()
    axes.set_ylim(ylim)
    axes.figure.savefig('plots/loss_curves_%d.png' % (k))
    plt.close()

In [5]:
class KsparseAE(nn.Module):
    """Autoencoder that keeps only the k largest hidden activations per sample.

    The encoder is ``b + x @ W``; every hidden value below the k-th largest of
    its row is zeroed. In unsupervised mode the decoder reuses the transposed
    encoder matrix (``c + hid @ W.T``); in supervised mode the sparse code is
    fed to a linear layer of size ``out_dim`` instead.

    Args:
        supervised: if True, ``forward`` returns the linear head's output,
            otherwise the reconstruction.
        inp_dim: size of a flattened input.
        hid_dim: number of hidden units.
        out_dim: output size of the supervised linear head.
        k: number of hidden units kept per sample. When left as ``None`` the
            module-level ``k`` is looked up at call time, which preserves the
            original behaviour for existing callers.

    Attributes:
        hid: sparse hidden code of the most recent ``forward`` call.
    """

    def __init__(self, supervised=False, inp_dim=inp_dim, hid_dim=hid_dim, out_dim=out_dim, k=None):
        super(KsparseAE, self).__init__()
        self.supervised = supervised
        self.inp_dim = inp_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        self.k = k

        # nn.Parameter already enables gradients, so no requires_grad flag is
        # needed on the underlying tensors.
        self.b = nn.Parameter(torch.empty(self.hid_dim))
        self.c = nn.Parameter(torch.empty(self.inp_dim))
        self.W = nn.Parameter(torch.empty(self.inp_dim, self.hid_dim))
        for param in (self.b, self.c, self.W):
            nn.init.normal_(param, std=0.01)
        self.layer = nn.Linear(self.hid_dim, self.out_dim)

    def forward(self, x):
        """Encode ``x``, sparsify the code, and decode or classify it.

        Args:
            x: batch of flattened inputs with ``inp_dim`` features per row.

        Returns:
            The linear head's output when ``supervised`` is set, otherwise the
            reconstruction of ``x``.
        """
        # Fall back to the notebook-level `k` when none was given explicitly.
        n_active = self.k if self.k is not None else k

        hid = self.b + torch.matmul(x, self.W)
        # The k-th smallest of -hid is the negated k-th largest of hid.
        neg_kth, _ = torch.kthvalue(-hid, n_active, keepdim=True)
        hid[hid < -neg_kth] = 0
        self.hid = hid

        if self.supervised:
            return self.layer(hid)
        return self.c + torch.matmul(hid, torch.t(self.W))

In [6]:
def train(model, supervised=False):
    """Train ``model`` on MNIST and write checkpoints, plots and loss logs.

    Hyper-parameters are read from module-level names: ``lr``, ``n_epoch``,
    ``k`` (used only in output file names here), ``batch_size`` and ``device``.

    Args:
        model: network whose ``forward`` takes flattened images and which
            exposes the hidden code of the last call as ``model.hid``.
        supervised: if True, train against the digit labels and report test
            accuracy; otherwise train to reconstruct the input.

    Side effects:
        * saves ``models/KsparseAE_<k>.pt`` after every epoch;
        * calls ``filter_plot`` once training finishes;
        * unsupervised: calls ``plot_activations`` with the first hidden
          unit's activation for every training sample seen;
        * supervised: prints accuracy and pickles the per-epoch losses to
          ``loss/train_ELBO_<k>.txt`` and ``loss/test_ELBO_<k>.txt``.
    """
    import os

    train_loader, test_loader = data_loader()

    # Integer class labels need a classification loss: MSELoss cannot
    # broadcast (batch, out_dim) outputs against (batch,) labels.
    criterion = nn.CrossEntropyLoss() if supervised else nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr)

    # torch.save does not create missing directories.
    os.makedirs('models', exist_ok=True)

    train_list, test_list = [], []
    # Collected per batch and concatenated once at the end; growing a tensor
    # with torch.cat on every step copies the whole history each time.
    first_unit_acts = []
    correct, length = 0, 0
    for epoch in range(n_epoch):
        model.train()
        train_loss = 0.0
        cnt = 0
        for x, y in train_loader:
            cnt += 1
            new_batch_size = x.size(0)
            inputs = x.view(new_batch_size, -1).to(device)
            labels = y.to(device)
            outputs = model(inputs)
            if not supervised:
                first_unit_acts.append(model.hid.detach()[:, 0])
            loss = criterion(outputs, labels if supervised else inputs)
            # .item() keeps the running total a plain float, so the autograd
            # graph of every batch is not kept alive for the whole epoch.
            train_loss += loss.item() * new_batch_size / batch_size

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss = train_loss / cnt
        train_list.append(train_loss)

        model.eval()
        test_loss = 0.0
        cnt = 0
        length = 0
        correct = 0
        with torch.no_grad():
            for x, y in test_loader:
                cnt += 1
                new_batch_size = x.size(0)
                length += new_batch_size
                inputs = x.view(new_batch_size, -1).to(device)
                labels = y.to(device)
                outputs = model(inputs)
                if supervised:
                    correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
                loss = criterion(outputs, labels if supervised else inputs)
                test_loss += loss.item() * new_batch_size / batch_size
        test_loss = test_loss / cnt
        test_list.append(test_loss)

        torch.save(model.state_dict(), 'models/KsparseAE_%d.pt' % (k))

        print('[Epoch %d] train_loss: %.3f, test_loss: %.3f' % (epoch+1, train_loss, test_loss))

    # 1st experiment in the report: filter plot of the first weight matrix.
    filter_plot(model)

    if not supervised:
        hid = torch.cat(first_unit_acts) if first_unit_acts else torch.empty(0)
        print(hid.size())
        # 2nd experiment in the report: log-histogram of hidden unit activity.
        plot_activations(hid)

    if supervised:
        # Accuracy of the last evaluated epoch; zero when nothing was evaluated.
        accuracy = correct / length * 100 if length else 0.0
        print('accuracy is', accuracy,'%')

        # Losses are stored as plain floats, so the pickles load without a GPU.
        os.makedirs('loss', exist_ok=True)
        with open('loss/train_ELBO_%d.txt' % (k), 'wb') as f:
            pickle.dump(train_list, f)
        with open('loss/test_ELBO_%d.txt' % (k), 'wb') as f:
            pickle.dump(test_list, f)

In [18]:
if __name__ == '__main__':
    import os

    print(device)

    supervised = False

    # Run settings; the functions defined above read these as globals.
    k = 70          # number of hidden units kept active per sample
    lr = 1e-3       # Adam learning rate
    n_epoch = 100   # number of training epochs

    model = KsparseAE(supervised=supervised).to(device)

    # Resume from an earlier checkpoint when one exists; on a fresh run start
    # from the random initialisation instead of failing on the missing file.
    # map_location lets a checkpoint saved on another device load here.
    checkpoint_path = 'models/KsparseAE_%d.pt' % (k)
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    train(model=model, supervised=supervised)
    

cuda:1
[Epoch 1] train_loss: 0.000, test_loss: 0.000
[Epoch 2] train_loss: 0.000, test_loss: 0.000
[Epoch 3] train_loss: 0.000, test_loss: 0.000
[Epoch 4] train_loss: 0.000, test_loss: 0.000
[Epoch 5] train_loss: 0.000, test_loss: 0.000
[Epoch 6] train_loss: 0.000, test_loss: 0.000
[Epoch 7] train_loss: 0.000, test_loss: 0.000
[Epoch 8] train_loss: 0.000, test_loss: 0.000
[Epoch 9] train_loss: 0.000, test_loss: 0.000
[Epoch 10] train_loss: 0.000, test_loss: 0.000
[Epoch 11] train_loss: 0.000, test_loss: 0.000
[Epoch 12] train_loss: 0.000, test_loss: 0.000
[Epoch 13] train_loss: 0.000, test_loss: 0.000
[Epoch 14] train_loss: 0.000, test_loss: 0.000
[Epoch 15] train_loss: 0.000, test_loss: 0.000
[Epoch 16] train_loss: 0.002, test_loss: 0.000
[Epoch 17] train_loss: 0.000, test_loss: 0.000
[Epoch 18] train_loss: 0.000, test_loss: 0.000
[Epoch 19] train_loss: 0.000, test_loss: 0.000
[Epoch 20] train_loss: 0.000, test_loss: 0.000
[Epoch 21] train_loss: 0.000, test_loss: 0.000
[Epoch 22] trai