# NEURAL NETWORKS AND DEEP LEARNING

## Homework 2 - Unsupervised Deep Learning

Puppin Michele - 1227474

In [None]:
# Import packages
import matplotlib.pyplot as plt 
import numpy as np 
import pandas as pd 
import random 
import os 
from tqdm import tqdm
from google.colab import files

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F

import itertools

# Set random seeds for reproducibility.
# Python's own `random` module is seeded too, because random.choice is used
# later in the notebook to pick the test images that get plotted.
random.seed(25)
np.random.seed(25)
torch.manual_seed(25)

In [None]:
# Pick the compute device: the GPU when CUDA is available, the CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training device: {device}")

## Dataset

In [None]:
# Build the MNIST train and test splits (stored under ./dataset, downloaded if missing)
train_dataset, test_dataset = (
    torchvision.datasets.MNIST('dataset', train=is_train_split, download=True)
    for is_train_split in (True, False)
)

In [None]:
# Attach an image-to-tensor conversion to both splits
for _split in (train_dataset, test_dataset):
    _split.transform = transforms.Compose([transforms.ToTensor()])

# Mini-batch loaders: training batches are reshuffled, test batches keep dataset order
train_dataloader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

## Early stopping

In [None]:
# Early stopping class definition
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Call the instance once per epoch with the current validation loss and the
    model. When the loss improves, the model's ``state_dict`` is saved to
    ``path`` and the patience counter is reset; otherwise the counter is
    incremented and ``early_stop`` becomes True once it reaches ``patience``.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How many non-improving calls to tolerate before stopping.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): Function used to emit progress messages.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf is the supported spelling; the np.Inf alias was removed in NumPy 2.0
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one validation result and update the stopping state.

        Args:
            val_loss: Validation loss of the current epoch (lower is better).
            model: Module whose ``state_dict`` is checkpointed on improvement.
        """
        # Work with a score where higher is better
        score = -val_loss

        if self.best_score is None:
            # First call: nothing to compare against, always checkpoint
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not better than the best score by at least `delta`
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save the model's ``state_dict`` to ``self.path`` and remember ``val_loss``."""
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# Autoencoders Definition

In [None]:
# Encoder definition

class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to `encoded_space_dim` features.

    Three strided convolutions are followed by a flatten and a two-layer MLP.
    The first linear layer expects 288 inputs, i.e. the 32 x 3 x 3 feature map
    that the convolutions produce for 1 x 28 x 28 images.
    """

    def __init__(self, encoded_space_dim):
        super().__init__()

        # Convolutional feature extractor: every convolution halves the resolution
        conv_stack = [
            nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=0),
            nn.ReLU(inplace=True),
        ]
        self.encoder_cnn = nn.Sequential(*conv_stack)

        # Collapse (channels, height, width) into one feature axis
        self.flatten = nn.Flatten(start_dim=1)

        # Fully connected projection onto the latent space
        self.encoder_lin = nn.Sequential(
            nn.Linear(288, 64),
            nn.ReLU(),
            nn.Linear(64, encoded_space_dim),
        )

    def forward(self, x):
        """Return the latent code for the image batch `x`."""
        features = self.flatten(self.encoder_cnn(x))
        return self.encoder_lin(features)

In [None]:
# Decoder definition

class Decoder(nn.Module):
    """Convolutional decoder mapping latent codes back to single-channel images.

    A two-layer MLP expands the code to 288 features, which are reshaped to a
    32 x 3 x 3 map and upsampled by three transposed convolutions. A final
    sigmoid keeps every output value between 0 and 1.
    """

    def __init__(self, encoded_space_dim):
        super().__init__()

        # Fully connected expansion of the latent code
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 3 * 3 * 32),
            nn.ReLU(inplace=True),
        )

        # Restore the (channels, height, width) layout expected by the convolutions
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(32, 3, 3))

        # Transposed convolutions that progressively upsample the feature map
        upsampling_stack = [
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, output_padding=0),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
        ]
        self.decoder_conv = nn.Sequential(*upsampling_stack)

    def forward(self, x):
        """Return the reconstructed image batch for the latent codes `x`."""
        feature_map = self.unflatten(self.decoder_lin(x))
        # Sigmoid forces the output into the valid pixel range [0, 1]
        return torch.sigmoid(self.decoder_conv(feature_map))

### Training and Testing

In [None]:
### Training function
def train_epoch(encoder, decoder, device, dataloader, loss_func, optimizer):
    """Run one training epoch of the autoencoder.

    Args:
        encoder, decoder: Modules trained jointly; both are put in train mode.
        device: Device every image batch is moved to.
        dataloader: Iterable of (images, labels) pairs; the labels are ignored.
            It must yield at least one batch.
        loss_func: Callable comparing the reconstruction with the input batch.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The loss tensor (``loss.data``) of the last batch of the epoch.
    """
    for network in (encoder, decoder):
        network.train()

    # Labels are unused: the reconstruction target is the input itself
    for images, _ in dataloader:
        images = images.to(device)
        reconstruction = decoder(encoder(images))
        loss = loss_func(reconstruction, images)

        # Standard optimisation step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss.data

In [None]:
### Testing function
def test_epoch(encoder, decoder, device, dataloader, loss_func):
    # Set evaluation mode for encoder and decoder
    encoder.eval()
    decoder.eval()
    with torch.no_grad(): # No need to track the gradients
        # Define the lists to store the outputs for each batch
        conc_out = []
        conc_label = []
        for image_batch, _ in dataloader:
            # Move tensor to the proper device
            image_batch = image_batch.to(device)
            # Encode data
            encoded_data = encoder(image_batch)
            # Decode data
            decoded_data = decoder(encoded_data)
            # Append the network output and the original image to the lists
            conc_out.append(decoded_data.cpu())
            conc_label.append(image_batch.cpu())
        # Create a single tensor with all the values in the lists
        conc_out = torch.cat(conc_out)
        conc_label = torch.cat(conc_label) 
        # Evaluate global loss
        val_loss = loss_func(conc_out, conc_label)
    return val_loss.data

In [None]:
# Fit function
def fit(num_epochs, encoder, decoder, device, train_dataloader, loss_func, optim, test_dataloader, test_dataset):
    """Train the autoencoder for up to `num_epochs` epochs with early stopping.

    Each epoch runs `train_epoch` on `train_dataloader` and `test_epoch` on
    `test_dataloader`. The validation loss is passed to an `EarlyStopping`
    instance (patience 5), which is given the decoder to checkpoint; training
    stops as soon as it signals. After every epoch that does not trigger the
    stop, the encoder, decoder and optimizer states are saved to
    'encoder_params.pth', 'decoder_params.pth' and 'optim_params.pth'.

    `test_dataset` is accepted but not used.

    Returns:
        (train_loss_log, val_loss_log): per-epoch losses as returned by
        `train_epoch` and `test_epoch`.
    """
    stopper = EarlyStopping(patience=5, verbose=False)
    history = {'train': [], 'val': []}
    # Objects whose state is saved at the end of each completed epoch
    saved_states = (
        (encoder, 'encoder_params.pth'),
        (decoder, 'decoder_params.pth'),
        (optim, 'optim_params.pth'),
    )

    for epoch_idx in range(num_epochs):
        epoch_number = epoch_idx + 1
        print('EPOCH %d/%d' % (epoch_number, num_epochs))

        # Training pass
        history['train'].append(
            train_epoch(
                encoder=encoder,
                decoder=decoder,
                device=device,
                dataloader=train_dataloader,
                loss_func=loss_func,
                optimizer=optim,
            )
        )

        # Validation pass
        val_loss = test_epoch(
            encoder=encoder,
            decoder=decoder,
            device=device,
            dataloader=test_dataloader,
            loss_func=loss_func,
        )
        history['val'].append(val_loss)

        # Report the validation loss
        print('\n\n\t VALIDATION - EPOCH %d/%d - loss: %f\n\n' % (epoch_number, num_epochs, val_loss))

        # Let early stopping checkpoint on improvement and decide whether to stop
        stopper(val_loss, decoder)
        if stopper.early_stop:
            print("Early stopping")
            break

        # Save network and optimizer parameters
        for stateful, filename in saved_states:
            torch.save(stateful.state_dict(), filename)

    return history['train'], history['val']

### Grid Search for model selection

In [None]:
# Hyperparameter grid; the key order fixes the order of the values in each combination
dict_params = dict(
    LearningRate=[1e-1, 1e-2, 1e-3, 1e-4],
    Regularization=[1e-3, 1e-4, 1e-5, 1e-6],
    Epochs=[1000],
    EncodedSpaceDim=[4, 8, 16],
)

In [None]:
# Every combination of the grid values, one tuple per combination
comb_params = [*itertools.product(*dict_params.values())]

In [None]:
# Per-combination logs: the hyperparameters and the last-epoch train/validation loss
par_log = []
train_loss_log = []
val_loss_log = []

# `run_idx` (rather than `iter`) avoids shadowing the builtin iter()
for run_idx, params in enumerate(comb_params):
    print('Iteration:', run_idx)

    par_log.append(params)

    loss_func = nn.MSELoss()
    lr, lamb, num_epochs,  encoded_space_dim = params

    # Fresh networks for every combination
    encoder = Encoder(encoded_space_dim = encoded_space_dim)
    decoder = Decoder(encoded_space_dim = encoded_space_dim)

    # One optimizer over both networks; `lamb` is used as Adam's weight decay
    autoenc_params = [{'params': encoder.parameters()}, {'params': decoder.parameters()}]
    optim = torch.optim.Adam(autoenc_params, lr=lr, weight_decay=lamb)

    encoder.to(device)
    decoder.to(device)

    # Training
    train_loss, val_loss = fit(num_epochs = num_epochs,
                                       encoder = encoder,
                                       decoder = decoder,
                                       device = device,
                                       train_dataloader = train_dataloader,
                                       loss_func = loss_func,
                                       optim = optim,
                                       test_dataloader = test_dataloader,
                                       test_dataset = test_dataset)

    # Keep only the loss of the last epoch that was run
    train_loss_log.append(train_loss[-1])
    val_loss_log.append(val_loss[-1])

In [None]:
# Select best parameters
# val_loss_log stores one scalar (the last-epoch validation loss) per
# combination, so each entry is converted with float() instead of being indexed.
best_params = par_log[int(np.argmin([float(v) for v in val_loss_log]))]
best_params

### Train with best parameters

In [None]:
# Hyperparameters as (learning rate, weight decay, max epochs, encoded space dim).
# This hard-coded tuple replaces whatever value `best_params` held before.
best_params = (0.001, 1e-5, 1000, 16)

In [None]:
# Unpack the chosen hyperparameters
lr, lamb, num_epochs, encoded_space_dim = best_params
loss_func = nn.MSELoss()

# Fresh encoder/decoder pair with the selected latent size
encoder = Encoder(encoded_space_dim=encoded_space_dim)
decoder = Decoder(encoded_space_dim=encoded_space_dim)

# One Adam optimizer with a parameter group per network
autoenc_params = [{'params': net.parameters()} for net in (encoder, decoder)]
optim = torch.optim.Adam(autoenc_params, lr=lr, weight_decay=lamb)

encoder.to(device)
decoder.to(device)

# Training
train_loss, val_loss = fit(
    num_epochs=num_epochs,
    encoder=encoder,
    decoder=decoder,
    device=device,
    train_dataloader=train_dataloader,
    loss_func=loss_func,
    optim=optim,
    test_dataloader=test_dataloader,
    test_dataset=test_dataset,
)

In [None]:
# Report the loss of the last epoch for both curves
for caption, history in (('Training loss:', train_loss), ('Validation loss:', val_loss)):
    print(caption, history[-1].detach().cpu().numpy())

# Save network parameters
for stateful, filename in (
    (encoder, 'encoder_params.pth'),
    (decoder, 'decoder_params.pth'),
    (optim, 'optim_params.pth'),
):
    torch.save(stateful.state_dict(), filename)

In [None]:
# Plot Training and Validation loss
# The per-epoch losses are 0-d tensors (the training ones stay on the training
# device), so they are converted to plain floats before plotting; matplotlib
# cannot convert CUDA tensors itself.
plt.plot([float(v) for v in train_loss], label='Training')
plt.plot([float(v) for v in val_loss], label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.savefig('TrainValLoss_AE.pdf', bbox_inches='tight')
files.download('TrainValLoss_AE.pdf')
plt.show()

## Test the trained model

In [None]:
# Loss of the trained autoencoder over the whole test loader
test_loss = test_epoch(encoder, decoder, device, test_dataloader, loss_func)

print('Test loss', test_loss)

In [None]:
# Example of reconstructed images
n_examples = 2  # one subplot row per example
originals = [test_dataset[i][0].unsqueeze(0).to(device) for i in range(n_examples)]

encoder.eval()
decoder.eval()
with torch.no_grad():
    reconstructions = [decoder(encoder(img)) for img in originals]

# Plot every original next to its reconstruction
fig, axs = plt.subplots(n_examples, 2, figsize=(10,11))
for row, img, rec_img in zip(axs, originals, reconstructions):
    row[0].imshow(img.cpu().squeeze().numpy(), cmap='gist_gray')
    row[0].set_title('Original image')
    row[1].imshow(rec_img.cpu().squeeze().numpy(), cmap='gist_gray')
    row[1].set_title('Reconstructed image')

# Save figures
fig.savefig('Reconstruct_AE.pdf', bbox_inches='tight')
files.download('Reconstruct_AE.pdf')
plt.show()

## Denoiser

In [None]:
# Introduce noise on the images in the training and testing functions

# Training function
def train_epoch(encoder, decoder, device, dataloader, loss_func, optimizer):
    """Run one training epoch of the denoising autoencoder.

    Every batch is corrupted with Gaussian noise before being encoded, while
    the loss compares the reconstruction with the clean batch, so the networks
    learn to remove the noise.

    Args:
        encoder, decoder: Modules trained jointly; both are put in train mode.
        device: Device every image batch is moved to.
        dataloader: Iterable of (images, labels) pairs; the labels are ignored.
            It must yield at least one batch.
        loss_func: Callable comparing the reconstruction with the clean batch.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The loss tensor (``loss.data``) of the last batch of the epoch.
    """
    # Set train mode for both the encoder and the decoder
    encoder.train()
    decoder.train()
    # Labels are not needed: this is unsupervised learning
    for image_batch, _ in dataloader:
        # Move tensor to the proper device
        image_batch = image_batch.to(device)
        # Add noise whose offset and scale are redrawn for every batch
        mean = torch.randn(1).to(device)
        std = torch.randn(1).to(device) * 0.5 + 0.5
        noisy_image = image_batch + torch.randn(image_batch.size()).to(device) * std + mean
        # Encode the corrupted input (not the clean one)
        encoded_data = encoder(noisy_image)
        # Decode data
        decoded_data = decoder(encoded_data)
        # The target is the clean image
        loss = loss_func(decoded_data, image_batch)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return loss.data

# Testing function
def test_epoch(encoder, decoder, device, dataloader, loss_func):
    # Set evaluation mode for encoder and decoder
    encoder.eval()
    decoder.eval()
    with torch.no_grad(): # No need to track the gradients
        # Define the lists to store the outputs for each batch
        conc_out = []
        conc_label = []
        for image_batch, _ in dataloader:
            # Move tensor to the proper device
            image_batch = image_batch.to(device)
            # Add noise
            mean = torch.randn(1).to(device)
            std = torch.randn(1).to(device) * 0.5 + 0.5
            noisy_image = image_batch + torch.randn(image_batch.size()).to(device) * std + mean 
            # Encode data
            encoded_data = encoder(image_batch)
            # Decode data
            decoded_data = decoder(encoded_data)
            # Append the network output and the original image to the lists
            conc_out.append(decoded_data.cpu())
            conc_label.append(image_batch.cpu())
        # Create a single tensor with all the values in the lists
        conc_out = torch.cat(conc_out)
        conc_label = torch.cat(conc_label) 
        # Evaluate global loss
        val_loss = loss_func(conc_out, conc_label)
    return val_loss.data

### Train with best parameters

In [None]:
# Hyperparameters as (learning rate, weight decay, max epochs, encoded space dim)
# used for the denoising run.
best_params = (0.001, 1e-5, 150, 16)

In [None]:
# Unpack the chosen hyperparameters
lr, lamb, num_epochs, encoded_space_dim = best_params
loss_func = nn.MSELoss()

# Fresh encoder/decoder pair with the selected latent size
encoder = Encoder(encoded_space_dim=encoded_space_dim)
decoder = Decoder(encoded_space_dim=encoded_space_dim)

# One Adam optimizer with a parameter group per network
autoenc_params = [{'params': net.parameters()} for net in (encoder, decoder)]
optim = torch.optim.Adam(autoenc_params, lr=lr, weight_decay=lamb)

encoder.to(device)
decoder.to(device)

# Training
train_loss, val_loss = fit(
    num_epochs=num_epochs,
    encoder=encoder,
    decoder=decoder,
    device=device,
    train_dataloader=train_dataloader,
    loss_func=loss_func,
    optim=optim,
    test_dataloader=test_dataloader,
    test_dataset=test_dataset,
)

In [None]:
# Report the loss of the last epoch for both curves
for caption, history in (('Training loss:', train_loss), ('Validation loss:', val_loss)):
    print(caption, history[-1].detach().cpu().numpy())

# Save network parameters
for stateful, filename in (
    (encoder, 'encoder_params.pth'),
    (decoder, 'decoder_params.pth'),
    (optim, 'optim_params.pth'),
):
    torch.save(stateful.state_dict(), filename)

In [None]:
# Plot Training and Validation loss
# The per-epoch losses are 0-d tensors (the training ones stay on the training
# device), so they are converted to plain floats before plotting; matplotlib
# cannot convert CUDA tensors itself.
plt.plot([float(v) for v in train_loss], label='Training')
plt.plot([float(v) for v in val_loss], label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.savefig('TrainValLoss_DAE.pdf', bbox_inches='tight')
files.download('TrainValLoss_DAE.pdf')
plt.show()

## Test the trained model

In [None]:
# Loss of the trained denoiser over the whole test loader
test_loss = test_epoch(encoder, decoder, device, test_dataloader, loss_func)

print('Test loss', test_loss)

Test loss tensor(0.0094)


In [None]:
# Show three random test digits: clean, corrupted, and denoised by the autoencoder
fig, axs = plt.subplots(3, 3, figsize=(12,12))

# Evaluation mode is set once; it does not change inside the loop
encoder.eval()
decoder.eval()

for ax in axs:
    img, label = random.choice(test_dataset)
    # Build the noisy copy explicitly (standard normal noise) instead of
    # mutating the clean image in place through a view
    noisy_img = img + torch.randn_like(img)
    # Denoise the corrupted image
    with torch.no_grad():
        dec_img = decoder(encoder(noisy_img.unsqueeze(0).to(device)))

    panels = (
        ('Original', img[0]),
        ('Noisy', noisy_img[0]),
        ('Decoded', dec_img.detach().cpu()[0][0]),
    )
    for panel, (title, pixels) in zip(ax, panels):
        panel.imshow(np.array(pixels), cmap='gist_gray')
        panel.set_xticks([])
        panel.set_yticks([])
        panel.set_title(title)

plt.savefig('Reconstruct_DAE.pdf', bbox_inches='tight')
files.download('Reconstruct_DAE.pdf')

## Fine Tuning with Supervised Classification

In [None]:
# Initialize the encoder and load the trained weights saved in 'encoder_params.pth'
encoder = Encoder(encoded_space_dim = 16)
encoder.load_state_dict(torch.load('encoder_params.pth'))

In [None]:
class Classification(nn.Module):
    """Ten-class classifier: a given encoder followed by a small readout head.

    The head maps the `encoded_space_dim` encoder outputs to 10
    log-probabilities (LogSoftmax over the last dimension).
    """

    def __init__(self, encoded_space_dim, pretrained_encoder):
        super().__init__()
        # Encoder supplied by the caller
        self.encoder = pretrained_encoder
        # Readout head producing class log-probabilities
        head_layers = [
            nn.Linear(encoded_space_dim, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 10),
            nn.LogSoftmax(dim=-1),
        ]
        self.lin_readout = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class log-probabilities for the input batch `x`."""
        return self.lin_readout(self.encoder(x))

    def train_nn(self, train_loader, optimizer, loss_func, device):
        """Train for one pass over `train_loader`; return the per-batch losses.

        Each batch is indexed as (inputs, labels). The losses are returned as a
        list of NumPy values, one per batch.
        """
        self.train()
        batch_losses = []
        for batch in train_loader:
            inputs, targets = batch[0].to(device), batch[1].to(device)
            loss = loss_func(self.forward(inputs), targets)
            self.zero_grad()
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.detach().cpu().numpy())
        return batch_losses

    def validation_nn(self, val_loader, loss_func, device):
        """Evaluate on `val_loader` without gradients; return the per-batch losses."""
        self.eval()
        batch_losses = []
        with torch.no_grad():
            for batch in val_loader:
                inputs, targets = batch[0].to(device), batch[1].to(device)
                loss = loss_func(self.forward(inputs), targets)
                batch_losses.append(loss.detach().cpu().numpy())
        return batch_losses

    def fit(self, train_loader, val_loader, optimizer, loss_func, epochs, device):
        """Train for up to `epochs` epochs with early stopping (patience 5).

        Returns:
            (train_loss_log, val_loss_log): mean batch loss per epoch.
        """
        stopper = EarlyStopping(patience = 5, verbose = False)
        train_loss_log, val_loss_log = [], []

        for epoch in range(epochs):
            print(epoch)
            # Training
            epoch_train_losses = self.train_nn(train_loader, optimizer, loss_func, device)
            train_loss_log.append(np.mean(epoch_train_losses))
            # Validation
            epoch_val_losses = self.validation_nn(val_loader, loss_func, device)
            val_loss_log.append(np.mean(epoch_val_losses))

            # Checkpoint on improvement, stop when patience runs out
            stopper(np.mean(epoch_val_losses), self)
            if stopper.early_stop:
                print("Early stopping")
                break

        return train_loss_log, val_loss_log

    def predict(self, input_loader, loss_func, device):
        """Run the model over `input_loader` and gather everything.

        Returns:
            (inputs, outputs, labels, test_loss): the concatenated inputs,
            model outputs and labels of all batches, and the loss computed
            once over the concatenated outputs and labels.
        """
        all_inputs, all_outputs, all_labels = [], [], []
        self.eval()
        with torch.no_grad():
            for batch in input_loader:
                batch_inputs = batch[0].to(device)
                batch_labels = batch[1].to(device)
                all_inputs.append(batch_inputs)
                all_outputs.append(self.forward(batch_inputs))
                all_labels.append(batch_labels)
        inputs = torch.cat(all_inputs)
        outputs = torch.cat(all_outputs)
        labels = torch.cat(all_labels)
        test_loss = loss_func(outputs, labels)
        return inputs, outputs, labels, test_loss

### Train the model

In [None]:
# Classifier built on top of the encoder loaded above; 16 is the encoder's output size
model_class = Classification(encoded_space_dim=16, pretrained_encoder=encoder)

In [None]:
# Freeze the encoder: its weights receive no gradient, so only the readout head is trained
for frozen_param in model_class.encoder.parameters():
    frozen_param.requires_grad = False

In [None]:
# Negative log-likelihood loss, matching the LogSoftmax output of the classifier
loss_func = nn.NLLLoss()
# Adam over all parameters of the classifier, with weight decay
optim = torch.optim.Adam(model_class.parameters(), lr=0.001, weight_decay=1e-5) 

# Move the whole classifier (encoder included) to the selected device
model_class.to(device) 

In [None]:
# Prepare dataset
# NOTE(review): the first split (20% of the data) is used for training and the
# remaining 80% for validation — confirm that this ratio is intended.
n_samples = len(train_dataset)
split_sizes = [int(0.2 * n_samples), int(0.8 * n_samples)]
train_set, val_set = torch.utils.data.random_split(train_dataset, split_sizes)

# Loaders for the two splits; only the training one is shuffled
train_dataloader = DataLoader(dataset=train_set, batch_size=256, shuffle=True)
val_dataloader = DataLoader(dataset=val_set, batch_size=256, shuffle=False)

In [None]:
# Train the classifier for at most 20 epochs (early stopping may end it sooner)
train_loss, val_loss = model_class.fit(
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    optimizer=optim,
    loss_func=loss_func,
    epochs=20,
    device=device,
)

In [None]:
# Plot the per-epoch training and validation loss of the classifier
for curve, curve_label in ((train_loss, 'Training'), (val_loss, 'Validation')):
    plt.plot(curve, label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Save, download and show the figure
plt.savefig('TrainValLoss_CLassAE.pdf', bbox_inches='tight')
files.download('TrainValLoss_CLassAE.pdf')
plt.show()

### Test the model

In [None]:
# Run the classifier over the test set in its original order
test_dataloader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

inputs, outputs, labels, test_loss = model_class.predict(
    input_loader=test_dataloader,
    loss_func=loss_func,
    device=device,
)

print("Test loss:", test_loss)

In [None]:
# Compute test accuracy
outputs = outputs.detach().cpu().numpy()
labels = labels.detach().cpu().numpy()

# Predicted class = index of the largest output in each row
predicted_labels = outputs.argmax(axis=1)
# Count the mismatches in one vectorised comparison
wrong = np.count_nonzero(predicted_labels != labels)
test_accuracy = 1 - wrong/len(outputs)
print("Test accuracy: ", test_accuracy)

### Confusion matrix for the test set

In [None]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues,
                          save_path='models/'):
    """Plot a confusion matrix as an annotated heat map and save it.

    Args:
        cm: Square confusion matrix (rows: true labels, columns: predictions).
        classes: Tick labels, one per class.
        normalize: If True, every row is divided by its sum before plotting.
        title: Figure title.
        cmap: Matplotlib colormap used for the heat map.
        save_path: Prefix of the output file; "_picConfMatrix.png" is appended.
    """
    if normalize:
        # Row-normalise so that each row sums to 1
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    plt.figure(figsize=(15, 15))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, fontsize=30)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45, fontsize=15)
    plt.yticks(tick_marks, classes, fontsize=15)

    # Annotate each cell, switching text colour on dark cells for legibility
    fmt = '.3f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), size=11,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label', fontsize=30)
    plt.xlabel('Predicted label', fontsize=30)
    # Apply the layout before saving so the saved file matches the displayed figure
    plt.tight_layout()
    plt.savefig(save_path+"_picConfMatrix.png", dpi=400)

In [None]:
# Confusion Matrix
from sklearn.metrics import confusion_matrix

# The ten digit classes, used as tick labels
categories = list(range(10))
cm = confusion_matrix(labels, predicted_labels)
plot_confusion_matrix(cm, categories, normalize=False, save_path='./confusion.pdf')

# Explore the latent space structure

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import DBSCAN 
import plotly.express as px

In [None]:
# Load network parameters
# Restores the encoder and decoder weights from the files written during training.
encoder.load_state_dict(torch.load('encoder_params.pth'))
decoder.load_state_dict(torch.load('decoder_params.pth'))

<All keys matched successfully>

In [None]:
# Get the encoded representation of the test samples
encoded_samples = []
# Evaluation mode and gradient-free context are set once for the whole loop
encoder.eval()
with torch.no_grad():
    for sample in tqdm(test_dataset):
        img = sample[0].unsqueeze(0).to(device)
        label = sample[1]
        # Encode image and flatten the code into a 1-D NumPy vector
        encoded_img = encoder(img).flatten().cpu().numpy()
        # One column per latent variable, plus the digit label
        encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
        encoded_sample['label'] = label
        encoded_samples.append(encoded_sample)

# Convert to a dataframe
encoded_samples = pd.DataFrame(encoded_samples)

In [None]:
# Scatter plot of the first two latent variables, coloured by digit label
px.scatter(encoded_samples, x='Enc. Variable 0', y='Enc. Variable 1', color=encoded_samples.label.astype(str), opacity=0.7)

In [None]:
# Project the latent codes onto their first two principal components
n_components = 2
latent_columns = encoded_samples.iloc[:, 0:encoded_space_dim]

pca = PCA(n_components=n_components)
pca.fit(latent_columns)

component_names = [f'Enc. Variable {i}' for i in range(n_components)]
pca_transf_samples = pd.DataFrame(pca.transform(latent_columns), columns=component_names)

# Scatter plot of the projection, coloured by digit label
fig = px.scatter(
    pca_transf_samples,
    x='Enc. Variable 0',
    y='Enc. Variable 1',
    color=encoded_samples.label.astype(str),
    opacity=0.7,
)

fig.show()

In [None]:
# Embed the latent codes in two dimensions with t-SNE
n_components = 2
tsne = TSNE(n_components=n_components)

# fit_transform both fits and returns the embedding; a separate fit() call
# beforehand would only repeat the same expensive computation
tsne_tranf_samples = pd.DataFrame(tsne.fit_transform(encoded_samples.iloc[:, 0:encoded_space_dim]),
                                  columns=[f'Enc. Variable {i}' for i in range(n_components)])

# Scatter plot of the embedding, coloured by digit label
fig = px.scatter(tsne_tranf_samples,
                 x='Enc. Variable 0',
                 y='Enc. Variable 1',
                 color=encoded_samples.label.astype(str),
                 opacity=0.7)
fig.show()

# Generate new samples

In [None]:
# Draw three random latent codes (16 values each, scaled by 50)
latent_draws = [np.random.randn(16) * 50 for _ in range(3)]
custom_encoded_sample1, custom_encoded_sample2, custom_encoded_sample3 = latent_draws
encoded_value1, encoded_value2, encoded_value3 = (
    torch.tensor(draw).float().unsqueeze(0).to(device) for draw in latent_draws
)

# Decode the three codes into images
decoder.eval()
with torch.no_grad():
    generated_img1, generated_img2, generated_img3 = (
        decoder(code) for code in (encoded_value1, encoded_value2, encoded_value3)
    )

# Plot the generated images side by side
fig, axs = plt.subplots(1, 3, figsize=(12,6))
for axis, generated in zip(axs, (generated_img1, generated_img2, generated_img3)):
    axis.imshow(generated.cpu().squeeze().numpy(), cmap='gist_gray')
plt.tight_layout()
# Save figures
fig.savefig('GenSamp_AE.pdf', bbox_inches='tight')
files.download('GenSamp_AE.pdf')
plt.show()

# Variational Autoencoders

In [None]:
# Define Variational AutoEncoder
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps an input of size `x_dim` through two hidden layers to the
    mean and log-variance of a `z_dim`-dimensional Gaussian; the decoder
    mirrors it and ends with a sigmoid.
    """

    def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
        super().__init__()

        # Encoder layers; fc31 and fc32 are the two heads (mean, log-variance)
        self.fc1 = nn.Linear(x_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # Decoder layers
        self.fc4 = nn.Linear(z_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)

        # Activations: ReLU for the hidden layers, sigmoid for the output
        self.act1 = nn.ReLU()
        self.act2 = nn.Sigmoid()

    def encoder(self, x):
        """Return (mu, log_var) of the latent distribution for `x`."""
        hidden = self.act1(self.fc2(self.act1(self.fc1(x))))
        mu = self.fc31(hidden)
        log_var = self.fc32(hidden)
        return mu, log_var

    def sampling(self, mu, log_var):
        """Draw z = mu + eps * std with eps sampled from a standard normal."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def decoder(self, z):
        """Map latent codes `z` to outputs in [0, 1]."""
        hidden = self.act1(self.fc5(self.act1(self.fc4(z))))
        return self.act2(self.fc6(hidden))

    def forward(self, x):
        """Encode `x` (flattened to rows of 784 values), sample, and decode.

        Returns:
            (reconstruction, mu, log_var)
        """
        mu, log_var = self.encoder(x.view(-1, 784))
        return self.decoder(self.sampling(mu, log_var)), mu, log_var

def train():
    """Train the global `vae` for one epoch over the global `train_dataloader`.

    Uses the global `optimizer` and `loss_function`.

    Returns:
        The summed loss of the epoch divided by the number of samples in the
        dataloader's dataset.
    """
    vae.train()
    # Send batches to wherever the model lives instead of assuming CUDA
    model_device = next(vae.parameters()).device
    train_loss = 0
    for data, _ in train_dataloader:
        # `data` is already the image batch: use all of it, not only data[0]
        data = data.to(model_device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    return train_loss / len(train_dataloader.dataset)



def test():
    """Evaluate the global `vae` over the global `test_dataloader`.

    Uses the global `loss_function`; no gradients are tracked.

    Returns:
        The summed loss divided by the number of samples in the dataloader's
        dataset.
    """
    vae.eval()
    # Send batches to wherever the model lives instead of assuming CUDA
    model_device = next(vae.parameters()).device
    test_loss = 0
    with torch.no_grad():
        for data, _ in test_dataloader:
            # `data` is already the image batch: use all of it, not only data[0]
            data = data.to(model_device)
            recon, mu, log_var = vae(data)

            # sum up batch loss
            test_loss += loss_function(recon, data, mu, log_var).item()

    test_loss /= len(test_dataloader.dataset)

    return test_loss

In [None]:
# VAE with 784 inputs, hidden layers of 512 and 256 units, and a 2-dimensional latent space
vae = VAE(x_dim=784, h_dim1= 512, h_dim2=256, z_dim=2)

In [None]:
# Move the VAE to the selected device
vae.to(device)

In [None]:
# Adam over all VAE parameters, with the optimizer's default settings
optimizer = torch.optim.Adam(vae.parameters())

def loss_function(recon_x, x, mu, log_var):
    """Return the VAE loss: summed reconstruction BCE plus the KL divergence term.

    Args:
        recon_x: Reconstruction with values in [0, 1], one row of 784 values per sample.
        x: Original batch; it is reshaped to rows of 784 values.
        mu, log_var: Mean and log-variance of the latent Gaussian.
    """
    target = x.view(-1, 784)
    reconstruction_term = F.binary_cross_entropy(recon_x, target, reduction='sum')
    # KL divergence between N(mu, exp(log_var)) and the standard normal prior
    kl_term = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return reconstruction_term + kl_term

In [None]:
# Per-epoch loss histories
train_loss, test_loss = [], []

# Train for 50 epochs, evaluating on the test set after each one
for epoch in range(1, 51):
    print('Epoch: ', epoch)
    epoch_train_loss = train()
    train_loss.append(epoch_train_loss)
    epoch_test_loss = test()
    test_loss.append(epoch_test_loss)

In [None]:
# Plot the per-epoch training and validation loss of the VAE
for curve, curve_label in ((train_loss, 'Training'), (test_loss, 'Validation')):
    plt.plot(curve, label=curve_label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Save, download and show the figure
plt.savefig('TrainValLoss_VAE.pdf', bbox_inches='tight')
files.download('TrainValLoss_VAE.pdf')
plt.show()

In [None]:
# Generate custom samples from the prior.
# The latent size is read from the decoder's first layer so that the samples
# always match the dimension the VAE was built with.
latent_dim = vae.fc4.in_features
encoded_value1 = torch.randn(1, latent_dim).to(device)
encoded_value2 = torch.randn(1, latent_dim).to(device)
encoded_value3 = torch.randn(1, latent_dim).to(device)


# Decode the samples with the VAE in evaluation mode
vae.eval()
with torch.no_grad():
    generated_img1 = vae.decoder(encoded_value1)
    generated_img2 = vae.decoder(encoded_value2)
    generated_img3 = vae.decoder(encoded_value3)

# Reshape the flat outputs into 28 x 28 images
generated_img1 = generated_img1.view(1, 1, 28, 28)[0].cpu().squeeze().numpy()
generated_img2 = generated_img2.view(1, 1, 28, 28)[0].cpu().squeeze().numpy()
generated_img3 = generated_img3.view(1, 1, 28, 28)[0].cpu().squeeze().numpy()

# Plot the generated images
fig, axs = plt.subplots(1, 3, figsize=(12,6))
axs[0].imshow(generated_img1, cmap='gist_gray')
axs[1].imshow(generated_img2, cmap='gist_gray')
axs[2].imshow(generated_img3, cmap='gist_gray')
plt.tight_layout()
# Save figures

plt.savefig('GenSamp_VAE.pdf', bbox_inches='tight')
files.download('GenSamp_VAE.pdf')
plt.show()