<a href="https://colab.research.google.com/github/hnipun/ColabProjects/blob/master/Stacked_Auto_Encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
import time
import random
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image

# Checks for the availability of GPU 
if torch.cuda.is_available():
    print("working on gpu!")
    device = 'cuda'
else:
    print("No gpu! only cpu ;)")
    device = 'cpu'
    
if device == 'cpu':    
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
elif device == 'cuda':
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = '0'

No gpu! only cpu ;)


In [0]:
def to_img(x):
  """
  convert a flatten input to 32 * 32 image
  """
    x = x.view(x.size(0), 1, 32, 32)
    return x

def flatten_img(x):
  """
  flatten the input x.
  """
    x = x.view(x.size(0), 1*32*32)
    return x

def get_data_indices(dataset, fine_tune_idx, labels_per_class = 10):
  """
  return the indices of given labels_per_class.
  Args:
      dataset : dataset downlaoded
      fine_tune_idx : list of intrested indices
      labels_per_class : number of labels per class
  """
    indices = []
    count = {}
    for i in range(10):
      count[i] = 0
    
    for index in fine_tune_idx:
        image, label = dataset[index]
        
        if count[label] < labels_per_class:
          count[label] += 1
          indices.append(index)
        continue

    return indices

In [0]:
class AutoEncoder(nn.Module):
    """
    Autoencoder layer for stacked autoencoders.
    Args:
        input_size: The number of features in the input
        hidden_size: The number of features in the hidden layer
    """
    def __init__(self, input_size, hidden_size):
        super(AutoEncoder, self).__init__()

        self.encode = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
        )
        self.decode = nn.Sequential(
            nn.Linear(hidden_size, input_size),
            nn.ReLU(),
        )

        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1)

    def forward(self, x):
        # Train each autoencoder individually
        x = x.detach()
        y = self.encode(x)

        if self.training:
            x_reconstruct = self.decode(y)
            loss = self.criterion(x_reconstruct, Variable(x.data, requires_grad=False))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
        return y.detach()

    def reconstruct(self, x):
        return self.decode(x)

In [0]:
class StackedAutoEncoder(nn.Module):
    """
    A stacked autoencoder made from the autoencoders above.
    Each autoencoder is trained independently at the same time.
    """

    def __init__(self):
        super(StackedAutoEncoder, self).__init__()

        self.ae1 = AutoEncoder(1024, 1000)
        self.ae2 = AutoEncoder(1000, 800)
        self.ae3 = AutoEncoder(800, 500)

    def forward(self, x):
        x  = flatten_img(x)
        a1 = self.ae1(x)
        a2 = self.ae2(a1)
        a3 = self.ae3(a2)

        if self.training:
            return a3
        else:
            return a3, self.reconstruct(a3)

    def reconstruct(self, x):
            a2_reconstruct = self.ae3.reconstruct(x)
            a1_reconstruct = self.ae2.reconstruct(a2_reconstruct)
            x_reconstruct = self.ae1.reconstruct(a1_reconstruct)
            return x_reconstruct

In [0]:
class Classifier(nn.Module):
    def __init__(self, num_classes, stacked_encoder):
            super(Classifier, self).__init__()
            
            self.features = stacked_encoder
            self.linear_layers = nn.Sequential(nn.Linear(500, 10))
          
    def forward(self, x):
      if self.training:
            x    = self.features(x)
      else:
            x, _ = self.features(x)
      x  = self.linear_layers(x)
            
      return x


In [8]:
if not os.path.exists('./imgs'):
    os.mkdir('./imgs')
    
batch_size = 128

img_transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0, hue=0),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])
dataset = CIFAR10('../data/cifar10/', train=True, transform=img_transform, download=True)

num_dataset = len(dataset)
indices = list(range(num_dataset))
np.random.shuffle(indices)
train_split = int(np.floor(0.2 * num_dataset))
test_split = int(np.floor(0.1 * num_dataset)) + train_split
train_idx, test_idx, fine_tune_idx = indices[:train_split], indices[train_split:test_split], indices[test_split:]

fine_tune_idx_10 = get_data_indices(dataset, fine_tune_idx, 10)
fine_tune_idx_100 = get_data_indices(dataset, fine_tune_idx, 100)

train_sampler = SubsetRandomSampler(train_idx)
test_sampler = SubsetRandomSampler(test_idx)
finetune_sampler_10 = SubsetRandomSampler(fine_tune_idx_10)
finetune_sampler_100 = SubsetRandomSampler(fine_tune_idx_100)

train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, drop_last=False)
finetune_loader_10 = DataLoader(dataset, batch_size=batch_size, sampler=finetune_sampler_10, drop_last=False)
finetune_loader_100 = DataLoader(dataset, batch_size=batch_size, sampler=finetune_sampler_100, drop_last=False)
test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler, drop_last=False)

Files already downloaded and verified


In [10]:
loss_ = []
mean_ = []
sparsity_ = []
max_ = []

num_epochs = 200
model = StackedAutoEncoder().to(device)
for epoch in range(num_epochs):
    model.train()
    total_time = time.time()
    for i, data in enumerate(train_loader):
        img, target = data
        target = Variable(target).to(device)
        img = Variable(img).to(device)
        features = model(img).detach()

    total_time = time.time() - total_time

    model.eval()
    img, _ = data
    img = Variable(img)
    features, x_reconstructed = model(img)
    reconstruction_loss = torch.mean((x_reconstructed.data - flatten_img(img).data)**2)
    loss_.append(reconstruction_loss)

    if epoch % 10 == 0:
        print("Saving epoch {}".format(epoch))
        orig = to_img(img.cpu().data)
        save_image(orig, './imgs/orig_{}.png'.format(epoch))
        pic = to_img(x_reconstructed.cpu().data)
        save_image(pic, './imgs/reconstruction_{}.png'.format(epoch))
    
    mean = torch.mean(features.data)
    sparsity = torch.sum(features.data == 0.0)*100 / features.data.numel()
    max_value = torch.max(features.data)
    mean_.append(mean)
    sparsity_.append(sparsity)
    max_.append(max_value)

    print("Epoch {} complete\tTime: {:.4f}s\t\tLoss: {:.4f}".format(epoch, total_time, reconstruction_loss))
    print("Feature Statistics\tMean: {:.4f}\t\tMax: {:.4f}\t\tSparsity: {:.4f}%".format(
        mean, max_value, sparsity))
    print("="*80)

torch.save(model.state_dict(), './CDAE.pth')

Saving epoch 0
Epoch 0 complete	Time: 8.3670s		Loss: 0.2094
Feature Statistics	Mean: 0.0838		Max: 0.9719		Sparsity: 53.0000%
Epoch 1 complete	Time: 8.4053s		Loss: 0.1702
Feature Statistics	Mean: 0.1975		Max: 2.5369		Sparsity: 58.0000%
Epoch 2 complete	Time: 8.3070s		Loss: 0.1222
Feature Statistics	Mean: 0.1882		Max: 2.2471		Sparsity: 59.0000%
Epoch 3 complete	Time: 8.4363s		Loss: 0.1284
Feature Statistics	Mean: 0.1952		Max: 3.2500		Sparsity: 58.0000%
Epoch 4 complete	Time: 8.3810s		Loss: 0.1379
Feature Statistics	Mean: 0.2062		Max: 2.6101		Sparsity: 58.0000%
Epoch 5 complete	Time: 8.3932s		Loss: 0.1566
Feature Statistics	Mean: 0.2190		Max: 2.5719		Sparsity: 58.0000%
Epoch 6 complete	Time: 8.3935s		Loss: 0.1280
Feature Statistics	Mean: 0.1962		Max: 2.8778		Sparsity: 58.0000%
Epoch 7 complete	Time: 8.3620s		Loss: 0.1643
Feature Statistics	Mean: 0.2328		Max: 2.6232		Sparsity: 58.0000%
Epoch 8 complete	Time: 8.3097s		Loss: 0.1356
Feature Statistics	Mean: 0.2140		Max: 3.2638		Sparsity: 58.0

In [0]:
def train_model(loader, images_per_class = 10, epochs=100):
  classifier = Classifier(num_classes = 10, stacked_encoder = model)
  classifier = classifier.to(device)

  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

  loss_ = []
  for epoch in range(epochs):
          training_loss = 0.0
          classifier.train()
         
          for (images,labels), in zip(loader):
              ## Move the images to the device
              images = images.to(device)
              ## Move the labels to the device
              labels = labels.to(device)
              ## Get the output of the model by passing input to the model
              output_train = classifier(images)
              ## Find the loss of the input batch by passing output & ground truth labels to the criterion
              loss_train = criterion(output_train, labels)
              training_loss += loss_train.item()
              ## clear the gradients
              optimizer.zero_grad()
              ## compute the gradients by backpropagating through the computational graph.
              loss_train.backward()
              ## update the parameters 
              optimizer.step()
              loss_.append(training_loss)

          print("Epoch {} complete\tTime: {:.4f}s\t\tLoss: {:.4f}".format(epoch, total_time, training_loss))
          print("="*80)

  return classifier, loss_

In [0]:
## Testing Loop
def count_correct(preds, labels):
  count = 0
  for pred, label in zip(preds, labels):
    if pred == label:
      count += 1
  return count

def test_model(classifier):
    '''
    A function to test the trained model on the test dataset and print the accuracy.
    
    Inputs:
        model: Trained model.
        
    outputs:
        None. Prints and returns the accuracy.
    '''
    classifier.eval()
    with torch.no_grad():
        correct = 0.0
        total_samples = 0.0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            output = classifier(images)
            softmax = torch.exp(output).cpu()
            prob = list(softmax.numpy())
            predictions = np.argmax(prob, axis=1)
          
            correct += count_correct(predictions, [label.item() for label in labels])

            total_samples += labels.size(0)
        
        accuracy = (correct/total_samples)*100
        print("Total Accuracy on the Test set: {} %".format(accuracy))
        return accuracy

In [13]:
classfier_10, loss_10 = train_model(finetune_loader_10, images_per_class = 10, epochs=2500)
accuracy_10 = test_model(classfier_10)

classfier_100, loss_100 = train_model(finetune_loader_100, images_per_class = 100, epochs=2500)
accuracy_100 = test_model(classfier_100)

Epoch 0 complete	Time: 8.3906s		Loss: 2.4825
Epoch 1 complete	Time: 8.3906s		Loss: 2.3750
Epoch 2 complete	Time: 8.3906s		Loss: 2.3227
Epoch 3 complete	Time: 8.3906s		Loss: 2.2974
Epoch 4 complete	Time: 8.3906s		Loss: 2.2985
Epoch 5 complete	Time: 8.3906s		Loss: 2.3043
Epoch 6 complete	Time: 8.3906s		Loss: 2.3111
Epoch 7 complete	Time: 8.3906s		Loss: 2.3154
Epoch 8 complete	Time: 8.3906s		Loss: 2.3082
Epoch 9 complete	Time: 8.3906s		Loss: 2.2933
Epoch 10 complete	Time: 8.3906s		Loss: 2.2829
Epoch 11 complete	Time: 8.3906s		Loss: 2.2749
Epoch 12 complete	Time: 8.3906s		Loss: 2.2618
Epoch 13 complete	Time: 8.3906s		Loss: 2.2500
Epoch 14 complete	Time: 8.3906s		Loss: 2.2383
Epoch 15 complete	Time: 8.3906s		Loss: 2.2371
Epoch 16 complete	Time: 8.3906s		Loss: 2.2337
Epoch 17 complete	Time: 8.3906s		Loss: 2.2352
Epoch 18 complete	Time: 8.3906s		Loss: 2.2387
Epoch 19 complete	Time: 8.3906s		Loss: 2.2266
Epoch 20 complete	Time: 8.3906s		Loss: 2.2156
Epoch 21 complete	Time: 8.3906s		Loss: 2.207

In [21]:
print("Accuracy for 10 samples per class  : ", accuracy_10)
print("Accuracy for 100 samples per class : ", accuracy_100)

Accuracy for 10 samples per class  :  17.740000000000002
Accuracy for 100 samples per class :  25.36
