# Solutions

## Tutorial and homework 1: Classification and encoder for MNIST

### Installation

We will be using torch for the exercises. It is recommended to create a separate conda environment for the course and install torch, torchvision, Python 3.12 and matplotlib there. Check the online manuals for the installation on your computer.

We also use three other packages, torchinfo, torcheval and wandb, which can be installed using pip.
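Before running the imports, a quick check can show which packages are still missing. The following is a minimal sketch; the helper name `missing_packages` is hypothetical and not part of the course material. It only uses `importlib.util.find_spec` from the standard library.

```python
import importlib.util

# Hypothetical helper: returns the names of top-level packages that cannot be found
def missing_packages(names):
    return [name for name in names if importlib.util.find_spec(name) is None]

# Packages imported in this tutorial
required = ["torch", "torchvision", "torchinfo", "torcheval",
            "numpy", "matplotlib", "wandb"]
print("missing:", missing_packages(required))
```

An empty list means that all packages can be found in the active environment.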

The imports at the start of the tutorial load all the packages used. If they produce any error messages, check your installation.

### Tutorial outline 

The tutorial recaps how to design and train a neural network using PyTorch. As an exercise, the tutorial should be extended by designing and training a different network that encodes and decodes the input image and also does the classification.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchinfo import summary
from torcheval.metrics import MulticlassAccuracy

import numpy as np
import matplotlib.pyplot as plt

import wandb

### Data preparation
We will use data sets from torchvision. These data sets have to be transformed into tensors for torch and must also be normalized if they are not already. `ToTensor` already scales the 8-bit pixel values to the range [0, 1], so no further normalization is applied here. We will use 1-channel (intensity) images from the (probably too well known :-) ) MNIST data set for this tutorial.

In [None]:
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()])

Generate the train and test data sets. Download the data if necessary.

In [None]:
data_train = torchvision.datasets.MNIST(root='data/mnist', download=True, transform=transform)

In [None]:
data_test = torchvision.datasets.MNIST(root='data/mnist', train=False, download=True, transform=transform)

In [None]:
print(f'train: {len(data_train)}')

### Train and validation data sets

The training data set should be further split into training and validation subsets. A random split is sufficient here because the data set is balanced.

In [None]:
len_train = int(0.8 * len(data_train))
len_val = len(data_train) - len_train

print(len_train, len_val, len(data_train))

data_train_subset, data_val_subset = torch.utils.data.random_split(
        data_train, [len_train, len_val])
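The split above is different on every run. If a reproducible split is wanted, `random_split` accepts a `generator` argument. The sketch below shows this on a small toy data set; `toy_data` and the helper `seeded_split` are hypothetical names used only for illustration.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy stand-in for data_train: 10 samples with one feature each
toy_data = TensorDataset(torch.arange(10).float().unsqueeze(1))

def seeded_split(dataset, train_fraction, seed):
    # Same length arithmetic as above, plus a seeded generator
    len_train = int(train_fraction * len(dataset))
    len_val = len(dataset) - len_train
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [len_train, len_val], generator=generator)

train_a, val_a = seeded_split(toy_data, 0.8, seed=42)
train_b, val_b = seeded_split(toy_data, 0.8, seed=42)

print(len(train_a), len(val_a))
# The same seed gives the same sample indices in both splits
print(train_a.indices == train_b.indices)
```

With a fixed seed, results from different runs are evaluated on the same validation samples, which makes them easier to compare.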

Construct data loaders for the three data sets.

In [None]:
BATCH_SIZE = 64

data_train_loader = torch.utils.data.DataLoader(dataset=data_train_subset, shuffle=True, batch_size=BATCH_SIZE)
data_val_loader = torch.utils.data.DataLoader(dataset=data_val_subset, shuffle=False, batch_size=BATCH_SIZE)
data_test_loader = torch.utils.data.DataLoader(dataset=data_test, shuffle=False, batch_size=BATCH_SIZE)

### Verify the images in the data set
Verify the images and also check that the value range is correct (it should be [0, 1] after `ToTensor`).

In [None]:
train_iter = iter(data_train_loader)
images, labels = next(train_iter)

image = images[0].numpy().squeeze()
print(f'max: {np.max(image)}, min: {np.min(image)}')
plt.imshow(image, cmap='gray')
plt.show()

In [None]:
plt.imshow(np.transpose(torchvision.utils.make_grid(images), (1, 2, 0)))

## Define a CNN for classification
Define a convolutional neural network for classification. 

We will use a simple network for this example. This should be refined in the exercise.


In [None]:
class MyCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, out_channels=4, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(4, 8, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(8, 8, 3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(72, 120),
            nn.ReLU(),
            nn.Linear(120, 10)
        )
        

    def forward(self, x):
        # call the module itself rather than .forward() so that hooks are run
        return self.layers(x)


my_cnn = MyCNN()
print(my_cnn)
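The input size of 72 for the first `nn.Linear` layer follows from the spatial size of the feature maps after the last convolution. The sketch below recomputes it with the standard output-size formula for convolution and pooling layers without dilation; the helper `conv_out` is a hypothetical name, not a torch function.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output size along one spatial dimension for Conv2d / MaxPool2d (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                   # MNIST images are 28 x 28
size = conv_out(size, kernel=3)             # Conv2d(1, 4, 3)
size = conv_out(size, kernel=2, stride=2)   # MaxPool2d(2, 2)
size = conv_out(size, kernel=3)             # Conv2d(4, 8, 3)
size = conv_out(size, kernel=2, stride=2)   # MaxPool2d(2, 2)
size = conv_out(size, kernel=3)             # Conv2d(8, 8, 3)

# 8 output channels times the remaining spatial size
flat_features = 8 * size * size
print(size, flat_features)
```

If the kernel sizes, strides or channel counts are changed in the exercise, the first argument of `nn.Linear` has to be adapted accordingly.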

### Displaying the network
The print function does not display the resulting sizes and number of parameters. The summary function from torchinfo provides output similar to the model summary in Keras.

In [None]:
summary(my_cnn, input_size=(64, 1, 28, 28))

### Loss function and optimizer
Next, we define the loss function and the optimizer to use. We will use a simple optimizer for the moment: SGD with momentum.

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(my_cnn.parameters(), lr=0.001, momentum=0.9)
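`nn.CrossEntropyLoss` takes the raw network outputs (logits) and applies the log-softmax internally, which is why `MyCNN` ends with a plain linear layer. The following plain-Python sketch shows the computation for a single sample; the function name `cross_entropy_from_logits` is hypothetical, and the default batch averaging of the torch loss is not included.

```python
import math

def cross_entropy_from_logits(logits, target):
    # Numerically stable log-sum-exp of the logits
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    # Negative log-probability of the target class
    return log_sum_exp - logits[target]

# All classes equally likely: the loss equals log(10) for 10 classes
uniform = cross_entropy_from_logits([0.0] * 10, target=3)
# One class with a much larger logit: the loss is close to zero
confident = cross_entropy_from_logits([0.0] * 9 + [10.0], target=9)

print(uniform, math.log(10))
print(confident)
```

An untrained network is expected to start near the uniform case, so a first training loss of roughly log(10) is a useful sanity check.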

### Device

In PyTorch, we must specify which device to use and move the input and the model to the device. Let's first write a simple function to get the device.

In [None]:
def get_device():
    if torch.cuda.is_available():
        device = torch.device('cuda')
        # test if it worked
        x = torch.ones(1, device=device)
        print('Using CUDA device')

    elif torch.backends.mps.is_available():
        device = torch.device('mps')
        x = torch.ones(1, device=device)
        print('Using MPS device')
    else:
        print('Using CPU')
        device = torch.device('cpu')
    return device

In [None]:
device = get_device()

### Simple train function
Let's define a simple train function for one epoch.

In [None]:
def train_one_epoch(epoch_index, model, loss_function, optimizer, device):
    model.to(device)
    model.train(True)
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(data_train_loader) instead of
    # iter(data_train_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(data_train_loader):
        # Every data instance is an input + label pair
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_function(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report every 100 batches
        running_loss += loss.item()
        if i % 100 == 99:
            last_loss = running_loss / 100  # mean loss per batch over the last 100 batches
            print(f'  batch {i + 1} loss: {last_loss:.5f}')
            running_loss = 0.

    return last_loss

In [None]:
EPOCHS = 2
for epoch in range(EPOCHS):
    print(f'EPOCH {epoch + 1}')

    # Make sure gradient tracking is on, and do a pass over the data
    my_cnn.train(True)
    avg_loss = train_one_epoch(epoch, my_cnn, criterion, optimizer, device)
    print(f'EPOCH {epoch + 1} Loss: {avg_loss:.5f}')


## Training and evaluation

The code above works, but a few things are missing:
- First, it would be more informative to calculate some metrics in addition to the loss.
- Second, we should evaluate the loss and the metrics on the validation set.
- Third, we would like to monitor the loss and the metrics graphically, especially for longer training runs.

We will use Weights & Biases (wandb) for the display. Another option is to use TensorBoard.

In [None]:
wandb.login()

### Training and evaluation loop

We define a training loop that receives the model, loss function, optimizer, metrics and the device as parameters.

There are different ways to do this. With wandb, I find it easier to do training and evaluation in the same loop.

In [None]:
def train(epochs: int, model, loss_function, optimizer, metrics, device):
    # define the project and store some settings for the run to compare results later
    run = wandb.init(project="mnist-example", config={'epochs': epochs, 
                                                       'batch_size': data_train_loader.batch_size}
                                                       )
    input_count = 0
    step_count = 0
    model = model.to(device)
    
    for epoch in range(epochs):
        model.train()
        metrics.reset()
        for step, (inputs, labels) in enumerate(data_train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Zero your gradients for every batch!
            optimizer.zero_grad()
            # calculate results
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            train_loss = loss_function(outputs, labels)
            train_loss.backward()
            optimizer.step()

            metrics.update(predicted, labels)
            train_acc = metrics.compute()

            # wandb will store the metrics with the step on the x axis, so we also store the epoch
            train_metrics = {'train/train_loss': train_loss.item(),
                       'train/train_acc': train_acc,
                       'train/epoch': epoch}

            step_count += 1

            wandb.log(train_metrics, step=step_count)

        model.eval()
        metrics.reset()
        val_loss = []
        val_steps = 0
        for step, (inputs, labels) in enumerate(data_val_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)

                val_loss.append(loss_function(outputs, labels).item())
                metrics.update(predicted, labels)

        val_acc = metrics.compute()
        val_loss_mean = np.mean(val_loss)
        val_metrics = {'val/val_loss': val_loss_mean,
                       'val/val_acc' : val_acc}
        # log both metrics
        wandb.log(val_metrics, step=step_count)

        print(f"Epoch {epoch:02} Train Loss: {train_loss:.3f}, Valid Loss: {val_loss_mean:.3f}, Train Accuracy: {train_acc:.2f} Valid Acc: {val_acc:.2f}")
    wandb.finish()
                       
            

In [None]:
my_cnn = MyCNN()
my_metrics = MulticlassAccuracy(num_classes=10)
my_optimizer = optim.Adam(my_cnn.parameters(), lr=0.001)
my_loss = nn.CrossEntropyLoss()

train(10, my_cnn, my_loss, my_optimizer, my_metrics, device)
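The training loop calls `metrics.update` for every batch and `metrics.reset` only at the start of each phase, so the logged accuracy is a running value over all batches seen since the last reset, not the accuracy of the current batch. The plain-Python class below is a stand-in written for illustration (it is not the torcheval implementation) that mimics this update / compute / reset pattern.

```python
class RunningAccuracy:
    """Illustrative stand-in: accumulates counts until reset() is called."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.correct = 0
        self.total = 0

    def update(self, predicted, labels):
        # Count matching predictions in this batch and add them to the totals
        self.correct += sum(int(p == t) for p, t in zip(predicted, labels))
        self.total += len(labels)

    def compute(self):
        return self.correct / self.total if self.total else 0.0


acc = RunningAccuracy()
acc.update([1, 2, 3, 4], [1, 2, 3, 0])   # batch 1: 3 of 4 correct
first = acc.compute()
acc.update([5, 6, 7, 8], [5, 6, 7, 8])   # batch 2: 4 of 4 correct
second = acc.compute()                   # accuracy over both batches
print(first, second)
```

This also explains why the training accuracy printed at the end of an epoch is an average over the whole epoch, while the validation accuracy is computed with the final weights of that epoch.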

## Exercise 1.1: Autoencoder

In exercise 1.1, you should implement an autoencoder. An autoencoder compresses the input image into a small representation (the encoder part) and then decodes this representation back into an image (the decoder part).

The resulting image should be similar to the input image.

The size of the encoded representation, the so-called latent variables, should be a parameter of the class.

In [None]:
class Autoencoder(nn.Module):
    def __init__(self, latent_dim: int):
        super(Autoencoder,self).__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            # 28 x 28
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            # 14 x 14
            nn.Conv2d(32, 16, 3, 2, padding=1),
            nn.ReLU(True),
            # 7 x 7
            nn.Flatten(),
            nn.Linear(7 * 7 * 16, latent_dim),
            )
        self.decoder = nn.Sequential(
            # latent_dim
            nn.Linear(latent_dim, 7 * 7 * 16),
            nn.ReLU(),
            nn.Unflatten(1, (16, 7, 7)),
            nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            # 28 x 28 
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
            )
    def forward(self, x):
        enc = self.encoder(x)
        dec = self.decoder(enc)
        return dec
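The size comments in the encoder and decoder can be checked with the standard output-size formulas for strided and transposed convolutions without dilation. The sketch below does this in plain Python; `conv_out` and `conv_transpose_out` are hypothetical helper names.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Conv2d output size along one spatial dimension (dilation 1)
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    # ConvTranspose2d output size along one spatial dimension (dilation 1)
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Encoder: two convolutions with kernel 3, stride 2, padding 1
enc_1 = conv_out(28, kernel=3, stride=2, padding=1)
enc_2 = conv_out(enc_1, kernel=3, stride=2, padding=1)

# Decoder: two transposed convolutions with output_padding=1
dec_1 = conv_transpose_out(enc_2, kernel=3, stride=2, padding=1, output_padding=1)
dec_2 = conv_transpose_out(dec_1, kernel=3, stride=2, padding=1, output_padding=1)

print(enc_1, enc_2, dec_1, dec_2)
```

Without `output_padding=1`, each transposed convolution would produce an output one pixel too small (13 and then 25), so the reconstruction would not match the 28 x 28 input.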
    

Next steps:

- Adapt the training loop above for the autoencoder
- Train the autoencoder, adjusting hyperparameters if necessary
- Visualize the results, i.e. the decoded image
- Try different numbers of latent variables, for example 2 and 10
- Visualize the latent variables (for 2)

In [None]:
def train_ae(epochs: int, model, loss_function, optimizer, device):
    # define the project and store some settings for the run to compare results later
    run = wandb.init(project="mnist-ae", config={'epochs': epochs, 
                                                 'batch_size': data_train_loader.batch_size}
                    )
    input_count = 0
    step_count = 0
    model = model.to(device)
    
    for epoch in range(epochs):
        model.train()
        for step, (inputs, labels) in enumerate(data_train_loader):
            inputs = inputs.to(device)
            # we don't need the labels here
            
            # Zero your gradients for every batch!
            optimizer.zero_grad()
            # calculate results
            outputs = model(inputs)

            # the loss is computed between the outputs (prediction) and the inputs (target)
            train_loss = loss_function(outputs, inputs)
            train_loss.backward()
            optimizer.step()

            # wandb will store the metrics with the step across x, so we also store the epoch
            train_metrics = {'train/train_loss': train_loss.item(),
                       'train/epoch': epoch}

            step_count += 1

            wandb.log(train_metrics, step=step_count)

        model.eval()
        val_loss = []
        for step, (inputs, labels) in enumerate(data_val_loader):
            inputs = inputs.to(device)

            with torch.no_grad():
                outputs = model(inputs)
                val_loss.append(loss_function(outputs, inputs).item())

        val_loss_mean = np.mean(val_loss)
        val_metrics = {'val/val_loss': val_loss_mean}

        wandb.log(val_metrics, step=step_count)

        print(f"Epoch {epoch:02} Train Loss: {train_loss:.3f}, Valid Loss: {val_loss_mean:.3f}")
    wandb.finish()
        

In [None]:
auto_encoder = Autoencoder(latent_dim=10)
optimizer = torch.optim.RMSprop(auto_encoder.parameters(), lr=0.001)
loss_mse = nn.MSELoss()
summary(auto_encoder, input_size=(64, 1, 28, 28))

In [None]:
auto_encoder.to(device)
train_ae(10, auto_encoder, loss_mse, optimizer, device)

In [None]:
train_iter = iter(data_train_loader)
images, labels = next(train_iter)

In [None]:
plt.figure(figsize = (20,20))
plt.subplot(2, 1, 1)
plt.imshow(np.transpose(torchvision.utils.make_grid(images), (1, 2, 0)))
plt.subplot(2, 1, 2)
plt.imshow(np.transpose(torchvision.utils.make_grid(auto_encoder(images.to(device)).detach().cpu()), (1, 2, 0)))

In [None]:
auto_encoder = Autoencoder(latent_dim=2)
optimizer = torch.optim.RMSprop(auto_encoder.parameters(), lr=0.001)
loss_mse = nn.MSELoss()
summary(auto_encoder, input_size=(64, 1, 28, 28))
auto_encoder.to(device)
train_ae(10, auto_encoder, loss_mse, optimizer, device)

In [None]:
train_iter = iter(data_train_loader)
images, labels = next(train_iter)
plt.figure(figsize = (20,20))
plt.subplot(2, 1, 1)
plt.imshow(np.transpose(torchvision.utils.make_grid(images), (1, 2, 0)))
plt.subplot(2, 1, 2)
plt.imshow(np.transpose(torchvision.utils.make_grid(auto_encoder(images.to(device)).detach().cpu()), (1, 2, 0)))

In [None]:
# Plot latent space
def plot_latent_space(ae, n, figsize, device):
    # display a n*n 2D manifold of digits
    digit_size = 28
    scale = 2.5
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-scale, scale, n)
    grid_y = np.linspace(-scale, scale, n)[::-1]

    with torch.no_grad():
        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                z_sample = torch.tensor([[xi, yi]], dtype=torch.float32, device=device)
                x_decoded = ae.decoder(z_sample)
                digit = x_decoded[0].reshape(digit_size, digit_size).cpu()
                figure[
                    i * digit_size : (i + 1) * digit_size,
                    j * digit_size : (j + 1) * digit_size,
                ] = digit

    plt.figure(figsize=(figsize, figsize))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap="Greys_r")
    plt.show()


plot_latent_space(auto_encoder, n=20, figsize=10, device=device)
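The plot above decodes a regular grid of latent points. A complementary view is to encode real images and scatter their latent variables, colored by digit class. The sketch below collects the latent variables; `collect_latents`, `toy_encoder` and `toy_loader` are hypothetical names, and the toy stand-ins are only there so that the sketch runs on its own.

```python
import torch
import torch.nn as nn

def collect_latents(encoder, loader, device):
    # Encode all batches of the loader and return latents and labels as NumPy arrays
    encoder.eval()
    latents, targets = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            latents.append(encoder(inputs.to(device)).cpu())
            targets.append(labels)
    return torch.cat(latents).numpy(), torch.cat(targets).numpy()

# Toy stand-ins: an untrained encoder with 2 latent variables and 3 random batches
toy_encoder = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 2))
toy_loader = [(torch.rand(4, 1, 28, 28), torch.randint(0, 10, (4,))) for _ in range(3)]

z, y = collect_latents(toy_encoder, toy_loader, torch.device("cpu"))
print(z.shape, y.shape)

# With the trained model one could call, for example:
# z, y = collect_latents(auto_encoder.encoder, data_val_loader, device)
# plt.scatter(z[:, 0], z[:, 1], c=y, cmap="tab10", s=2)
# plt.colorbar()
# plt.show()
```

For the trained autoencoder with `latent_dim=2`, such a scatter plot shows how well the digit classes separate in the latent space and which range of values is sensible for the `scale` parameter of `plot_latent_space`.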

## Exercise 1.2: Two-headed network with autoencoder and classifier

Next we want to train a model that can do both the autoencoder and the classification. One idea of this approach is to force the model to learn a representation that works well for reconstruction and classification. The classification head should branch off after the encoding part.

In [None]:
class EncoderAndClassificator(nn.Module):
    def __init__(self, latent_dim: int):
        super(EncoderAndClassificator,self).__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            # 28 x 28
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(True),
            # 14 x 14
            nn.Conv2d(32, 16, 3, 2, padding=1),
            nn.ReLU(True),
            # 7 x 7
            nn.Flatten(),
            nn.Linear(7 * 7 * 16, latent_dim),
            )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 7 * 7 * 16),
            nn.ReLU(),
            nn.Unflatten(1, (16, 7, 7)),
            nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            # 28 x 28 
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid(),
            )
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, 32),
            nn.ReLU(),
            # no Softmax here: nn.CrossEntropyLoss expects raw logits
            nn.Linear(32, 10),
        )
    def forward(self, x):
        enc = self.encoder(x)
        dec = self.decoder(enc)
        logits = self.classifier(enc)
        return dec, logits

Next steps:
- Adapt the training loop, how is the loss calculated?
- Train the model and compare it to the two separate models above
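One possible way to calculate the loss is a weighted sum of the reconstruction loss and the classification loss, with a single coefficient between 0 and 1. The plain-Python sketch below shows only the arithmetic; the function name `combined_loss` and the example loss values are made up for illustration.

```python
def combined_loss(rec_loss, cls_loss, loss_coeff):
    # loss_coeff = 1.0 uses only the reconstruction loss,
    # loss_coeff = 0.0 uses only the classification loss
    return loss_coeff * rec_loss + (1.0 - loss_coeff) * cls_loss

# Made-up example values for the two losses
total = combined_loss(rec_loss=0.02, cls_loss=0.5, loss_coeff=0.95)
print(total)
```

The two losses usually have different magnitudes (a mean squared error on pixels in [0, 1] tends to be much smaller than a cross-entropy), so the coefficient is a hyperparameter worth tuning.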

In [None]:
def train_enc_class(epochs: int, model, cls_loss, rec_loss, loss_coeff, optimizer, device):
    run = wandb.init(project="mnist-encode-cls", config={'epochs': epochs, 
                                                       'batch_size': data_train_loader.batch_size,
                                                        'latent_dim': model.latent_dim}
                                                       )
    input_count = 0
    step_count = 0

    metrics = MulticlassAccuracy(num_classes=10)
    
    for epoch in range(epochs):
        #
        # Training Loop
        #
        model.train()
        metrics.reset()
        for step, (inputs, labels) in enumerate(data_train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Zero your gradients for every batch!
            optimizer.zero_grad()
            
            # calculate results
            reconstruction, class_prob = model(inputs)
            _, predicted = torch.max(class_prob, 1)

            # calculate metrics
            metrics.update(predicted, labels)
            train_acc = metrics.compute()

            # calculate and add losses
            # weighted reconstruction loss (prediction first, target second)
            loss_1 = loss_coeff * rec_loss(reconstruction, inputs)
            # weighted classification loss
            loss_2 = (1.0-loss_coeff) * cls_loss(class_prob, labels)
            train_total_loss = loss_1 + loss_2
            
            train_total_loss.backward()
            optimizer.step()

            train_metrics = \
                {'train/train_total_loss': train_total_loss.item(),
                 'train/train_rec_loss': loss_1.item(),
                 'train/train_cls_loss': loss_2.item(),
                 'train/train_acc': train_acc,
                 'train/epoch': epoch}

            wandb.log(train_metrics)
        #
        # Evaluation Loop
        #
        model.eval()
        metrics.reset()
        val_loss = []
        val_steps = 0
        for step, (inputs, labels) in enumerate(data_val_loader):
            with torch.no_grad():
                inputs = inputs.to(device)
                labels = labels.to(device)

                # calculate results
                reconstruction, class_prob = model(inputs)
                _, predicted = torch.max(class_prob, 1)

                # calculate metrics
                metrics.update(predicted, labels)

                loss_1 = loss_coeff * rec_loss(inputs, reconstruction)
                loss_2 = (1.0-loss_coeff) * cls_loss(class_prob, labels)
                val_total_loss = loss_1 + loss_2
                val_loss.append(val_total_loss.item())

        # compute the accuracy once over the whole validation set
        val_acc = metrics.compute()

        val_loss_mean = np.mean(val_loss)
        val_metrics = {'val/val_total_loss': val_loss_mean,
                       'val/val_acc' : val_acc}
        # log both metrics
        wandb.log({**train_metrics, **val_metrics})

        print(f"Epoch {epoch:02} Train Loss: {train_total_loss:.3f}, Train Acc: {train_acc:.3f}, Val Loss: {val_loss_mean:.3f}, Val Acc: {val_acc:.3f}")
    wandb.finish()
                
            

In [None]:
model = EncoderAndClassificator(10)
classifier_loss = nn.CrossEntropyLoss()
reconstruction_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_enc_class(20, model.to(device), classifier_loss, reconstruction_loss, 0.95, optimizer, device)