# (CNN) Autoencoder

In [7]:
import torch  
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 

from torch.utils.data import DataLoader
from torchvision import transforms, datasets

from typing import List

### Helper Class for logging trainings data

In [8]:
class TrainDataLogger():
    '''Helper class that stores informatoion during training'''
    def __init__(self) -> None:
        self.ls_losses:List[float] = []

        '''important for visualization'''
        self.ls_epochs:List[int]    = []
        self.ls_imgs:List[torch.Tensor]  = []
        self.ls_recimgs:List[torch.Tensor]  = [] 
        pass 

### Visualization

In [9]:
import matplotlib.pyplot as plt 

In [10]:
# plot loss per epoch
def plot_losses(losses) -> None:
    '''plots loss per epoch'''
    epochs = range(1, len(losses)+1)

    xticks = range(min(epochs), max(epochs)+1) # transforms into integer

    plt.plot(epochs, losses)
    plt.title('Loss per epoch'), plt.ylabel('Loss'),  plt.xlabel('Epoch')
    plt.xticks(xticks)
    plt.legend()
    plt.show()

    return

### Data

In [11]:
# DATA LOADER
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5), (0.5))
    ])

data_train = datasets.MNIST('./data', train=True, download=True, transform=transform)
data_test  = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_dataloader = DataLoader(data_train, batch_size=64, shuffle=True)
test_dataloader = DataLoader(data_test, batch_size=64, shuffle=True)

In [12]:
# DATA: look at pixels value range (important for activation function in decoder)
imgs, labels = next(iter(train_dataloader))

print( torch.min(imgs)  , torch.max(imgs) )

tensor(-1.) tensor(1.)


### Class: Autoencoder (Conv.)

In [13]:
# determine output size
conv1 = nn.Conv2d(1, 16, 3, stride=2, padding=1)
conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
conv3 = nn.Conv2d(32, 64, 7)

tconv1 = nn.ConvTranspose2d(64, 32, 7)
tconv2 = nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1)
tconv3 = nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1)


imgs, _ = next(iter(train_dataloader))

# PRINTS
print(f'''

        Normal:
            {list(imgs.size())}
        Convolutional:
            {list(conv1(imgs).size())}, 
            {list(conv2(conv1(imgs)).size())}, 
            {list(conv3(conv2(conv1(imgs))).size())},
        Tranpose:    
            {list(tconv1(conv3(conv2(conv1(imgs)))).size())}
            {list(tconv2(tconv1(conv3(conv2(conv1(imgs))))).size())}, 
            {list(tconv3(tconv2(tconv1(conv3(conv2(conv1(imgs)))))).size())},    
            
''')




        Normal:
            [64, 1, 28, 28]
        Convolutional:
            [64, 16, 14, 14], 
            [64, 32, 7, 7], 
            [64, 64, 1, 1],
        Tranpose:    
            [64, 32, 7, 7]
            [64, 16, 14, 14], 
            [64, 1, 28, 28],    
            



In [14]:
class Autoencoder_Conv(nn.Module):
    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            conv1, nn.ReLU(),
            conv2, nn.ReLU(),
            conv3, 
        )

        self.decoder = nn.Sequential(
            tconv1, nn.ReLU(),
            tconv2, nn.ReLU(),
            tconv3, 
            nn.Sigmoid() # making pixel in range [0,1] (gray scaled)
        )

        pass 

    def forward(self, x):
        return self.decoder(self.encoder(x))
        

### Train & Test functions

In [15]:
import math 

def train(epochs: int, model: nn.Module, dataloader: DataLoader, optimizer: optim.Optimizer, criterion) -> TrainDataLogger:
    n_batches = len(dataloader)
    train_data: TrainDataLogger = TrainDataLogger()
    
    model.train()
    for i_epoch in range(epochs):
        loss_epoch = 0
        for i_batch, (X,_) in enumerate(dataloader):

            # adjust shahpe of tensor, so that it fits into first layer
            # X = X.reshape(-1, 28*28) 

            # pred (encode & decode)
            recon = model(X)

            # loss (how much difference between each pixel)
            loss = criterion(recon, X)
            if math.isnan(loss): 
                print(f'NAN @ epoch={i_epoch}, batch={i_batch}')
                return
            loss_epoch += loss

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # print
            if i_batch % 250 == 0:
                print(f'Epoch\t{i_epoch+1}/{epochs}\t\tBatch\t{i_batch}/{n_batches}\t({(100. * i_batch / n_batches):.1f}%)\t\tLoss\t{loss:.4f}')
        
        # Save training data per epoch
        train_data.ls_losses.append(loss_epoch / n_batches)
        train_data.ls_imgs.append(X)
        train_data.ls_recimgs.append(recon)
        train_data.ls_epochs.append(i_epoch)


    return train_data


In [16]:
def test(model:nn.Module, dataloader:DataLoader, criterion) -> float:
    model.eval()

    n_batches = len(dataloader)
    loss = 0

    with torch.no_grad():
        for _, (X,_) in enumerate(dataloader):
           # X = X.reshape(-1,  28*28)

            # pred 
            recon = model(X)

            # loss
            loss += criterion(recon, X)

        loss = loss / n_batches

    return loss

### Training

In [17]:
model = Autoencoder_Conv() # Autoencoder_Linear()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

In [18]:
train_data:TrainDataLogger = train(3, model, train_dataloader, optimizer, criterion)

Epoch	1/3		Batch	0/938	(0.0%)		Loss	1.8554
NAN @ epoch=0, batch=122


In [19]:
plot_losses(torch.Tensor(train_data.ls_losses))

AttributeError: 'NoneType' object has no attribute 'ls_losses'

In [None]:
test(model, test_dataloader,  criterion).item()

### Show generated images

In [None]:
num_epochs = len(train_data.ls_epochs)

num_imgs_per_row = 9 
num_rows = 2


for i_epoch in range(0, num_epochs):
    plt.figure(figsize=(9,2))
    plt.gray()
    imgs    =  train_data.ls_imgs[i_epoch].detach().numpy()
    recimgs  = train_data.ls_recimgs[i_epoch].detach().numpy()

    for i, imgs_batch in enumerate(imgs):
        if i >= num_imgs_per_row : break

        plt.subplot(num_rows,num_imgs_per_row, i+1)
        imgs_batch = imgs_batch.reshape(-1, 28, 28)
        plt.imshow(imgs_batch[0])

    for i, imgs_batch in enumerate(recimgs):
        if i >= num_imgs_per_row : break

        plt.subplot(num_rows,num_imgs_per_row, i+1+num_imgs_per_row)
        imgs_batch = imgs_batch.reshape(-1, 28, 28)
        plt.imshow(imgs_batch[0])

plt.show()         