<a href="https://colab.research.google.com/github/hnipun/ColabProjects/blob/master/Stacked_Auto_Encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
import time
import random
import numpy as np

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image

# Checks for the availability of GPU 
if torch.cuda.is_available():
    print("working on gpu!")
    device = 'cuda'
else:
    print("No gpu! only cpu ;)")
    device = 'cpu'
    
if device == 'cpu':    
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
elif device == 'cuda':
    random.seed(0)
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = '0'

No gpu! only cpu ;)


In [0]:
def to_img(x):
    x = x.view(x.size(0), 1, 32, 32)
    return x

def flatten_img(x):
    x = x.view(x.size(0), 1*32*32)
    return x

In [0]:
class AutoEncoder(nn.Module):
    """
    Convolutional denoising autoencoder layer for stacked autoencoders.
    This module is automatically trained when in model.training is True.
    Args:
        input_size: The number of features in the input
        output_size: The number of features to output
        stride: Stride of the convolutional layers.
    """
    def __init__(self, input_size, hidden_size):
        super(AutoEncoder, self).__init__()

        self.encode = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
        )
        self.decode = nn.Sequential(
            nn.Linear(hidden_size, input_size),
            nn.ReLU(),
        )

        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1)

    def forward(self, x):
        # Train each autoencoder individually
        x = x.detach()
        # Add noise, but use the original lossless input as the target.
        y = self.encode(x)

        if self.training:
            x_reconstruct = self.decode(y)
            loss = self.criterion(x_reconstruct, Variable(x.data, requires_grad=False))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            
        return y.detach()

    def reconstruct(self, x):
        return self.decode(x)

In [0]:
class StackedAutoEncoder(nn.Module):
    """
    A stacked autoencoder made from the convolutional denoising autoencoders above.
    Each autoencoder is trained independently and at the same time.
    """

    def __init__(self):
        super(StackedAutoEncoder, self).__init__()

        self.ae1 = AutoEncoder(1024, 1000)
        self.ae2 = AutoEncoder(1000, 800)
        self.ae3 = AutoEncoder(800, 500)

    def forward(self, x):
        x  = flatten_img(x)
        a1 = self.ae1(x)
        a2 = self.ae2(a1)
        a3 = self.ae3(a2)

        if self.training:
            return a3
        else:
            return a3, self.reconstruct(a3)

    def reconstruct(self, x):
            a2_reconstruct = self.ae3.reconstruct(x)
            a1_reconstruct = self.ae2.reconstruct(a2_reconstruct)
            x_reconstruct = self.ae1.reconstruct(a1_reconstruct)
            return x_reconstruct

In [8]:
if not os.path.exists('./imgs'):
    os.mkdir('./imgs')

num_epochs = 250
batch_size = 128

img_transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0, hue=0),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

dataset = CIFAR10('../data/cifar10/', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8)

model = StackedAutoEncoder().to(device)
for epoch in range(num_epochs):
    model.train()
    total_time = time.time()
    for i, data in enumerate(dataloader):
        img, target = data
        target = Variable(target).to(device)
        img = Variable(img).to(device)
        features = model(img).detach()

    total_time = time.time() - total_time

    model.eval()
    img, _ = data
    img = Variable(img)
    features, x_reconstructed = model(img)
    reconstruction_loss = torch.mean((x_reconstructed.data - flatten_img(img).data)**2)

    if epoch % 10 == 0:
        print("Saving epoch {}".format(epoch))
        orig = to_img(img.cpu().data)
        save_image(orig, './imgs/orig_{}.png'.format(epoch))
        pic = to_img(x_reconstructed.cpu().data)
        save_image(pic, './imgs/reconstruction_{}.png'.format(epoch))

    print("Epoch {} complete\tTime: {:.4f}s\t\tLoss: {:.4f}".format(epoch, total_time, reconstruction_loss))
    print("Feature Statistics\tMean: {:.4f}\t\tMax: {:.4f}\t\tSparsity: {:.4f}%".format(
        torch.mean(features.data), torch.max(features.data), torch.sum(features.data == 0.0)*100 / features.data.numel()))
    print("="*80)

torch.save(model.state_dict(), './CDAE.pth')

Files already downloaded and verified
Saving epoch 0
Epoch 0 complete	Time: 39.0332s		Loss: 0.1524
Feature Statistics	Mean: 0.2260		Max: 3.2095		Sparsity: 55.0000%
Epoch 1 complete	Time: 38.8934s		Loss: 0.1565
Feature Statistics	Mean: 0.2432		Max: 3.3754		Sparsity: 54.0000%
Epoch 2 complete	Time: 39.0764s		Loss: 0.1409
Feature Statistics	Mean: 0.2384		Max: 2.7379		Sparsity: 53.0000%
Epoch 3 complete	Time: 38.9912s		Loss: 0.1412
Feature Statistics	Mean: 0.2489		Max: 3.0562		Sparsity: 52.0000%
Epoch 4 complete	Time: 38.9941s		Loss: 0.1312
Feature Statistics	Mean: 0.2480		Max: 2.7943		Sparsity: 51.0000%
Epoch 5 complete	Time: 38.8695s		Loss: 0.1311
Feature Statistics	Mean: 0.2634		Max: 3.0299		Sparsity: 49.0000%
Epoch 6 complete	Time: 38.9022s		Loss: 0.1312
Feature Statistics	Mean: 0.2768		Max: 2.7660		Sparsity: 48.0000%
Epoch 7 complete	Time: 38.9841s		Loss: 0.1240
Feature Statistics	Mean: 0.2736		Max: 3.3342		Sparsity: 47.0000%
Epoch 8 complete	Time: 39.1508s		Loss: 0.1200
Feature Stati