In [1]:
import os
import numpy as np
import tensorflow as tf
from torch.utils.data import Dataset, DataLoader
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.datasets as dset
import matplotlib.pyplot as plt
from sklearn import svm
from sklearn.metrics import accuracy_score
import pickle

In [20]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Hyperparameters

In [5]:
SEED = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(SEED)
alpha = 1                  
lr = 1e-3                  
epochs = 25                 
batch_size = 100          


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using {}'.format(device))

Using cuda


Dataset

In [6]:
class MnistDataSet:
  def __init__(self,n=0,batch_size=1000):
    self.train_data = dset.FashionMNIST(root='./data', download=True,train = True,
                       transform=transforms.Compose([
                           transforms.Resize(28),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,)),
                       ]))
    self.test_data = dset.FashionMNIST(root='./data', download=True,train = False,
                       transform=transforms.Compose([
                           transforms.Resize(28),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5,), (0.5,)),
                       ]))
 
    self.train_loader = self.get_indices(n,self.train_data,batch_size)
    self.test_loader  = self.get_indices(0,self.test_data,batch_size)


  @staticmethod
  def get_indices(n,my_dataset,batch_size):
    if n!=0:
      indices = []
      labels = my_dataset.targets.numpy()
      images_perlabel = int(n/10)
      random.seed(1)
      
      for i in range(10):
        per_label = np.where(labels==i)[0].tolist()
        index_sample = random.sample(per_label,images_perlabel)
        indices.extend(index_sample)
      return DataLoader(torch.utils.data.Subset(my_dataset,indices),batch_size=batch_size,shuffle=True)
    else:
      return DataLoader(my_dataset,batch_size=batch_size,shuffle=True)

Model

In [7]:
class Encoder(nn.Module):
    def __init__(self,channels=64,latent_dim_size=50):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=channels, kernel_size=4, stride=2, padding=1) 
        self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels*2, kernel_size=4, stride=2, padding=1) 
        self.fc_mu = nn.Linear(in_features=channels*2*7*7, out_features=latent_dim_size)
        self.fc_logvar = nn.Linear(in_features=channels*2*7*7, out_features=latent_dim_size)
        self.relu = nn.ReLU()   
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1) 
        x_mu = self.fc_mu(x)
        x_logvar = self.fc_logvar(x)
        return x_mu, x_logvar

class Decoder(nn.Module):
    def __init__(self,latent_dim_size=50,channels = 64):
        super(Decoder, self).__init__()
        
        self.fc = nn.Linear(in_features=latent_dim_size, out_features=channels*2*7*7)
        self.conv1 = nn.ConvTranspose2d(in_channels=channels*2, out_channels=channels, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.ConvTranspose2d(in_channels=channels, out_channels=1, kernel_size=4, stride=2, padding=1)
        self.relu = nn.ReLU()     
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 64*2, 7, 7) 
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmoid(x) 
        return x
    
class VariationalAutoencoder(nn.Module):
    def __init__(self):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
    
    def forward(self, x):
        latent_mu, latent_logvar = self.encoder(x)
        latent = self.latent_sample(latent_mu, latent_logvar)
        x_recon = self.decoder(latent)
        return x_recon,latent_mu, latent_logvar,latent
    
    def latent_sample(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.empty_like(std).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu
    
def vae_loss(recon_x, x, mu, logvar,alpha):
    recon_loss = F.binary_cross_entropy(recon_x.view(-1, 784), x.view(-1, 784), reduction='sum')
    kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + alpha * kldivergence

def train(vae_model,train_dataloader):
  optimizer = torch.optim.Adam(params=vae_model.parameters(), lr=lr, weight_decay=1e-5)

  vae_model.train()

  train_loss_avg = []

  for epoch in range(epochs):
      train_loss_avg.append(0)
      num_batches = 0
      
      for image_batch, _ in train_dataloader:
          
          image_batch = image_batch.to(device)
          image_batch_recon, latent_mu, latent_logvar, _ = vae_model(image_batch)
          
          loss = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar,alpha)
          optimizer.zero_grad()
          loss.backward()
          
          optimizer.step()
          
          train_loss_avg[-1] += loss.item()
          num_batches += 1
          
      train_loss_avg[-1] /= num_batches
      print('epoch %d/%d: loss: %.2f' % (epoch+1, epochs, train_loss_avg[-1]))
    

VAE Training

In [8]:
data_loader = MnistDataSet()

vae_model = VariationalAutoencoder()
vae_model = vae_model.to(device)

train_losses = train(vae_model,data_loader.train_loader)
model_weights = 'VAE_weights.p'
torch.save(vae_model.state_dict(), model_weights)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

epoch 1/25: loss: -18025460.30
epoch 2/25: loss: -26556841.33
epoch 3/25: loss: -29215195.57
epoch 4/25: loss: -31391921.40
epoch 5/25: loss: -33167408.60
epoch 6/25: loss: -34382967.33
epoch 7/25: loss: -35476232.00
epoch 8/25: loss: -36243495.40
epoch 9/25: loss: -37005495.00
epoch 10/25: loss: -37666190.07
epoch 11/25: loss: -38229617.00
epoch 12/25: loss: -38750481.20
epoch 13/25: loss: -39212377.00
epoch 14/25: loss: -39621216.87
epoch 15/25: loss: -39948480.40
epoch 16/25: loss: -40318984.13
epoch 17/25: loss: -40562551.60
epoch 18/25: loss: -40910458.53
epoch 19/25: loss: -41089545.60
epoch 20/25: loss: -41354111.87
epoch 21/25: loss: -41551871.80
epoch 22/25: loss: -41798064.40
epoch 23/25: loss: -41989220.00
epoch 24/25: loss: -42200797.07
epoch 25/25: loss: -42372306.00


In [9]:
vae_model = VariationalAutoencoder()
fweights = 'VAE_weights.p'

vae_weights = torch.load(fweights)
vae_model.load_state_dict(vae_weights)
vae_model.eval();

Train and Test SVMs for different data sizes

In [10]:
data_sizes = [100, 600, 1000, 3000]

for data_size in data_sizes:
  print('Training SVM on {} labels...'.format(data_size))
  data_loader = MnistDataSet(n=data_size,batch_size=data_size).train_loader
  inputs, classes = next(iter(data_loader))  

  with torch.no_grad():
    _,_,_,x_svm = vae_model(inputs)
  x_svm = x_svm.detach().numpy()
  y_svm = classes.detach().numpy()
  _svm = svm.SVC(kernel='rbf')
  _svm.fit(x_svm, y_svm)
  filename = 'svm_{}.sav'.format(data_size)
  pickle.dump(_svm, open(filename, 'wb'))
  test_loader = MnistDataSet(n=0,batch_size=10000).test_loader
  inputs, classes = next(iter(test_loader))  

  _,_,_,latent_test = vae_model(inputs)
  latent_test = latent_test.detach().numpy()
  classes = classes.detach().numpy()
  y = _svm.predict(latent_test)
  acc  = accuracy_score(y, classes)
  print("Test SVM accuracy %.2f" % (acc*100))

Training SVM on 100 labels...
Test SVM accuracy 66.34
Training SVM on 600 labels...
Test SVM accuracy 76.14
Training SVM on 1000 labels...
Test SVM accuracy 77.58
Training SVM on 3000 labels...
Test SVM accuracy 81.38


Test the SVM on MNIST dataset

In [11]:
vae_model = VariationalAutoencoder()
fweights = 'VAE_weights.p'

vae_weights = torch.load(fweights)
vae_model.load_state_dict(vae_weights)
vae_model.eval();

data_sizes = [100, 600, 1000, 3000]

for data_size in data_sizes:
  print('{} labels:'.format(data_size))
  data_loader = MnistDataSet(n=data_size,batch_size=data_size).test_loader
  inputs, classes = next(iter(data_loader))  

  _svm = pickle.load(open('svm_{}.sav'.format(data_size), 'rb'))  

  with torch.no_grad():

    _,_,_,latent_test = vae_model(inputs)
    latent_test = latent_test.detach().numpy()
    classes = classes.detach().numpy()
    y = _svm.predict(latent_test)
    acc  = accuracy_score(y, classes)
    print("SVM Test accuracy %.2f" % (acc*100))

100 labels:
SVM Test accuracy 68.00
600 labels:
SVM Test accuracy 77.83
1000 labels:
SVM Test accuracy 78.00
3000 labels:
SVM Test accuracy 81.43
