Code for Martin Björklund's 2024 master's thesis in mathematical statistics.

# Decoder Architecture

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

'''
Each DecoderBlock decodes one block of ResNet 18.

For decoding the first checkpoint, one DecoderBlock is needed,
for decoding the second one two are needed and so on. Each block
contains 4 transposed convolutions and 2 residual connections.

Note that the dimensionality changes only in the final transposed convolution.
'''
#Decodes each block (two "BasicBlocks")
class DecoderBlock(nn.Module):
  """Decoder for one ResNet-18 stage, i.e. two mirrored "BasicBlocks".

  The block applies four 3x3 transposed convolutions (each followed by batch
  normalisation) and two residual connections. Only the last transposed
  convolution changes the number of channels and, when
  ``final_transpose_stride > 1``, the spatial size.

  Args:
    in_channels (int): Channels of the incoming feature map.
    out_channels (int): Channels produced by the block.
    final_transpose_stride (int): Stride of the last transposed convolution
      and of the projection shortcut.
    transpose_padding (int): ``padding`` of every 3x3 transposed convolution.
      With kernel size 3 and stride 1 the spatial size is only preserved for
      the default value 1, which the identity shortcuts rely on.
    output_padding (int): ``output_padding`` of the last transposed
      convolution and of the projection shortcut.
    device: Unused; kept so that existing callers passing it keep working.
  """

  #Constructor
  def __init__(self, in_channels, out_channels, final_transpose_stride = 1,
               transpose_padding = 1, output_padding = 0, device = None):
    super(DecoderBlock, self).__init__()

    #First "BasicBlock" decoder

    #First deconvolution
    #NOTE: "padding argument [in transpose convolution] effectively adds dilation * (kernel_size - 1) - padding amount of
    #      zero padding to both sizes of the input." (documentation)
    self.deconv1 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size = 3,
                                      stride = 1, padding = transpose_padding,
                                      bias=False)
    self.bn1 = nn.BatchNorm2d(in_channels)
    #Second deconvolution (channels and size are unchanged here as well)
    self.deconv2 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=3,
                                      stride=1, padding=transpose_padding,
                                      bias=False)
    self.bn2 = nn.BatchNorm2d(in_channels)
    #Residual connection
    #Identity is enough because the first sub-block keeps shape and channels
    self.shortcut1 = nn.Sequential()

    #Second "BasicBlock"
    self.deconv3 = nn.ConvTranspose2d(in_channels, in_channels, kernel_size = 3,
                                      stride = 1, padding = transpose_padding,
                                      bias=False)
    self.bn3 = nn.BatchNorm2d(in_channels)

    #NOTE: THIS IS WHERE WE CHANGE THE DIMENSIONS AND CHANNELS
    #NOTE: For decoding the second (and later blocks) block, this Transposed convolution is not exactly the
    #      one corresponding to the convolution used in Resnet18, as that type of transposed convolution
    #      does not seem to be implemented in PyTorch.
    self.deconv4 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 3,
                                      stride = final_transpose_stride,
                                      padding=transpose_padding,
                                      output_padding = output_padding, bias=False)
    self.bn4 = nn.BatchNorm2d(out_channels)

    #Second residual connection
    self.shortcut2 = nn.Sequential()
    if in_channels != out_channels or final_transpose_stride != 1:
      #1x1 projection. It must use the same output_padding as deconv4: for
      #transpose_padding = 1 both paths then produce (H - 1) * stride + 1 +
      #output_padding rows/columns. A hard-coded output_padding of 1 is
      #rejected by PyTorch when the stride is 1 and mismatches deconv4 when
      #output_padding is 0.
      self.shortcut2 = nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=1,
                           stride = final_transpose_stride, padding = 0,
                           output_padding = output_padding, bias=False),
        nn.BatchNorm2d(out_channels)
      )

  #Forward pass method
  def forward(self, input):
    """Decodes ``input`` and returns the feature map after the second residual sum and ReLU."""
    #First "Basic block"
    sub_block_1 = F.relu(self.bn1(self.deconv1(input)))
    sub_block_1 = self.bn2(self.deconv2(sub_block_1))
    sub_block_1 += self.shortcut1(input) #Adding residual connection
    sub_block_1 = F.relu(sub_block_1)

    #Second "Basic Block"
    out = F.relu(self.bn3(self.deconv3(sub_block_1)))
    out = self.bn4(self.deconv4(out))
    out += self.shortcut2(sub_block_1) #Adding second residual connection
    out = F.relu(out)

    return out

#Class for the whole decoder
class ResNetDecoder(nn.Module):
  """Decoder that maps a ResNet-18 checkpoint activation back to a 3-channel image.

  The model is a stack of ``DecoderBlock`` modules followed by one final
  transposed convolution with 3 output channels. For ``checkpoint_n >= 5``
  the input is first expanded to (512, 4, 4) and scaled by trainable
  unpooling weights; for ``checkpoint_n == 6`` a linear layer (10 -> 512)
  is applied before that.
  """

  def __init__(self, n_blocks, n_channels_in, n_channels_out, strides,
               final_stride, final_padding, output_padding, checkpoint_n = 0):
    '''
    Args:
      n_blocks (int): Number of DecoderBlocks needed, 1 <= n_blocks <= 4.

      n_channels_in: A list, where each element is the number of input channels to each
                    decoder block. (i.e. the number of output channels of
                    each resnet block, in reverse order)

      n_channels_out: A list where each element is the number of output channels of each
                      decoder block.

      strides: A list. The stride to be used in the final deconvolution of each block.

      final_stride, final_padding: Stride and padding to be used in the final
                                  deconvolution of the decoder
                                  When decoding only the first block, stride and
                                  padding should both be 1. (in our first experiment)

      output_padding (list): Sets the output_padding to be used in the final
                            deconvolution of each DecoderBlock

      checkpoint_n (int): Integer representing the checkpoint number. Only required
                          for checkpoints 5 and 6.
    '''
    super(ResNetDecoder, self).__init__()

    self.checkpoint_n = checkpoint_n

    #Last checkpoint requires a linear layer
    if checkpoint_n == 6:
      self.linear = nn.Linear(10, 512)

    #Trainable per-position weights for the unpooling (checkpoints 5 and 6).
    #They are created here, once, so that they are part of parameters() and
    #state_dict(), follow the module in .to(...), and are actually updated by
    #the optimizer. Creating them inside forward() would reset them to 1 on
    #every call. The leading dimension of 1 broadcasts over the minibatch.
    #Initialised to 1, which (almost) inverts the average pooling.
    #NOTE: state dicts that contain a batch-sized 'unpool_params' entry from
    #      the previous in-forward creation do not fit this shape.
    if checkpoint_n >= 5:
      self.unpool_params = nn.Parameter(torch.ones(1, 512, 4, 4))

    layers = [DecoderBlock(n_channels_in[i], n_channels_out[i],
                           final_transpose_stride=strides[i],
                           output_padding = output_padding[i])
              for i in range(0, n_blocks)]

    self.decoder_blocks = nn.Sequential(*layers)

    #Final deconvolution producing the 3 image channels
    self.deconv_final = nn.ConvTranspose2d(in_channels=n_channels_out[-1], out_channels = 3,
                                           stride = final_stride, padding = final_padding,
                                           kernel_size = 3)

  def forward(self, input):
    """Decodes a minibatch of checkpoint activations into images."""
    out = input

    #assuming dimension 0 is minibatch, we need to go from 2d to 3d
    if self.checkpoint_n == 6:
      #Fully connected layer
      #dims (minibatch, 10)
      out = out.view(-1, 1, out.size(1)) #dims (minibatch, 1, 10)
      out = self.linear(out) #dims (minibatch, 1, 512)
      out = out.view(-1, 512, 1).unsqueeze(3) #dims (minibatch, 512, 1, 1)

    #If we are decoding checkpoint 5 or 6, we need to go from (512, 1, 1) to
    #(512, 4, 4) (we do unpooling)
    if self.checkpoint_n >= 5:
      out = out.expand(-1, 512, 4, 4) #dims (minibatch, 512, 4, 4)
      #Scaling by the trainable unpooling weights (broadcast over the minibatch)
      out = out * self.unpool_params

    out = self.decoder_blocks(out)
    out = self.deconv_final(out)
    return out



# Loading Target Data

In [None]:
# CUDA for PyTorch
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True #Lets cuDNN benchmark convolution algorithms and pick the fastest one

In [None]:
from google.colab import drive
# Mount Google Drive so that the tensors stored there can be read below
drive.mount('/content/drive')

# Folder holding the image tensors and the per-checkpoint feature tensors
data_path = "/content/drive/My Drive/Colab Notebooks/Examensarbete/Data/"

# Images used as reconstruction targets by every decoder
targets_training = torch.load(data_path + "images_train.pt")
targets_testing = torch.load(data_path + "images_test.pt")
# Class labels; note that these files live one folder above data_path
train_labels = torch.load("/content/drive/My Drive/Colab Notebooks/Examensarbete/train_labels.pt")
test_labels = torch.load("/content/drive/My Drive/Colab Notebooks/Examensarbete/test_labels.pt")

In [None]:
# @title Plotting Example Images {display-mode: "form"}

import matplotlib.pyplot as plt

# Class names; position i is the name of integer label i
classes = ["airplane",
           "automobile",
           "bird",
           "cat",
           "deer",
           "dog",
           "frog",
           "horse",
           "ship",
           "truck"]

train_labels_list = list(train_labels.numpy())
train_labels_str = [classes[i] for i in train_labels_list]

# Index of the first training image of every class. Iterating over `classes`
# (rather than over a set, whose iteration order can change between sessions)
# keeps the layout of the grid reproducible.
class_indices_list = [train_labels_str.index(name) for name in classes]
class_indices = torch.tensor(class_indices_list).reshape(2, 5)

fig, axes = plt.subplots(2, 5, figsize=(15, 3))

for row in range(0, 2):
  for col in range(0, 5):
    ax = axes[row, col]
    img_idx = class_indices[row, col].item()

    #The tensors are of dimensions (3, 32, 32), but imshow expects
    #(32, 32, 3), so we need to permute the tensor
    image = targets_training[img_idx].permute(1, 2, 0)
    # Normalizing the pixel values (0-1 float expected by imshow)
    image = (image - image.min()) / (image.max() - image.min())

    ax.imshow(image)
    # Hiding the frame around each image
    for side in ('top', 'right', 'bottom', 'left'):
      ax.spines[side].set_visible(False)
    ax.set_xlabel(train_labels_str[img_idx],
                  rotation=0, ha = "center", size = "medium",
                  weight = "bold")

plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=-0.861, hspace=0.28)
plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
plt.show()




In [None]:
# Showing which training indices were picked as class examples (flat list and 2 x 5 grid)
print(class_indices_list)
print(class_indices)

# Functions for Training

## Training and validation functions

Functions that run one training epoch and one validation (testing) epoch.

In [None]:
import torch.optim as optim
import os
import matplotlib.pyplot as plt
from collections import OrderedDict

#Function for doing a training step
def train_step(model, training_DataLoader, device, loss_fn, optimizer):
  """Trains one epoch of the model.

  Args:
    model: The PyTorch model; it is put in training mode.
    training_DataLoader: Iterable with a length that yields (batch, targets) pairs.
    device: Device the batches are moved to before the forward pass.
    loss_fn: Callable returning a scalar loss tensor from (predictions, targets).
    optimizer: Optimizer that updates the parameters of model.

  Returns:
    (train_loss, train_loss_se): the mean of the per-batch losses and the
    standard error of that mean (NaN when fewer than two batches were seen).
  """

  # Puting model in training mode
  model.train()

  # Torch tensor for storing loss for each batch
  n_batches = len(training_DataLoader)
  train_loss_tensor = torch.zeros(n_batches)

  for i, (local_batch, local_targets) in enumerate(training_DataLoader):
    # Transfer to GPU
    local_batch, local_targets = local_batch.to(device), local_targets.to(device)

    # 1. Setting gradient to zero
    optimizer.zero_grad()

    # 2. Forward pass on training data
    y_pred = model(local_batch)

    # 3. Calculating and accumulating the loss
    temp_loss = loss_fn(y_pred, local_targets)
    train_loss_tensor[i] = temp_loss.item()

    # 4. Loss backwards
    temp_loss.backward()

    # 5. Progressing the optimizer
    optimizer.step()

  #Getting average loss per batch
  train_loss = train_loss_tensor.mean().item()

  #Standard error of the mean: sample standard deviation divided by sqrt(n),
  #not by n. It is undefined for fewer than two batches.
  if n_batches > 1:
    train_loss_se = train_loss_tensor.std().item() / n_batches ** 0.5
  else:
    train_loss_se = float("nan")

  return train_loss, train_loss_se

#Function for doing a validation step
def test_step(model, testing_DataLoader, device, loss_fn):
  "Tests one epoch of the model."

  #Putting model in eval mode
  model.eval()

  #Empty tensor for storing loss values
  test_loss_tensor = torch.zeros(len(testing_DataLoader))

  # Turn on inference context manager
  with torch.inference_mode():
    #Looping over batches
    for i, (local_batch, local_targets) in enumerate(testing_DataLoader):
      # Transfer to GPU
      local_batch, local_targets = local_batch.to(device), local_targets.to(device)

      #Forward pass
      preds = model(local_batch)

      #Calculating and accumulating loss
      temp_loss = loss_fn(preds, local_targets)
      test_loss_tensor[i] = temp_loss.item()

  #Calculating average loss
  test_loss = test_loss_tensor.mean().item()
  test_loss_se = test_loss_tensor.std().item() / len(test_loss_tensor)

  return test_loss, test_loss_se

In [None]:
#Training and validation function
def train_model(model, model_save_name, training_DataLoader, testing_DataLoader,
                device, loss_fn, optimizer, patience = 10, max_epochs = 30,
                save_path = "/content/drive/My Drive/Colab Notebooks/Examensarbete/Model parameters/"):
  """
    Trains a PyTorch model for a number of epochs, implements early stopping.
    Also plots the validation curve after each epoch.

    Two parameter files are written to save_path: "<model_save_name>.pt" holds
    the parameters with the lowest validation loss so far, and
    "<model_save_name>_best_trainloss.pt" those with the lowest training loss.

    Args:
        model: The PyTorch model.
        model_save_name (str): Name to use when saving the parameters of the best model.
        training_DataLoader: DataLoader for training
        testing_DataLoader: DataLoader for validation
        device: Device to store tensors on.
        loss_fn: Loss function to use
        optimizer: Optimizer to use
        patience (int): Patience used for early stopping.
        max_epochs (int): Maximum number of epochs to run.
        save_path (str): Where to store the best parameter values.

    Returns:
        (train_loss_values, test_loss_values): lists with one mean loss per
        completed epoch.
  """
  import time

  #Empty lists for tracking loss and epochs
  train_loss_values = []
  test_loss_values = []
  train_loss_se_values = []
  test_loss_se_values = []
  epoch_count = []

  #For early stopping
  best_loss = None
  best_train_loss = None
  best_epoch = None
  early_stop_count = 0

  # Loop over epochs
  for epoch in range(1, max_epochs + 1):
    start_time = time.time()
    train_loss, train_loss_se = train_step(model, training_DataLoader, device, loss_fn,
                            optimizer)

    test_loss, test_loss_se = test_step(model, testing_DataLoader, device, loss_fn)

    #Recording results. The standard errors are stored per epoch so that every
    #point of the curve gets its own error bar.
    epoch_count.append(epoch)
    train_loss_values.append(train_loss)
    test_loss_values.append(test_loss)
    train_loss_se_values.append(train_loss_se)
    test_loss_se_values.append(test_loss_se)
    end_time = round(time.time() - start_time, 2)

    print(f"Epoch: {epoch} | Time: {end_time} seconds \
          \n| MSE Train Loss: {train_loss} SE: {train_loss_se} \
          \n| MSE Test Loss: {test_loss} SE: {test_loss_se}")

    # Ploting the loss curves
    plt.errorbar(epoch_count, train_loss_values, train_loss_se_values, label="Train loss")
    plt.errorbar(epoch_count, test_loss_values, test_loss_se_values, label="Validation loss")
    plt.title("Training and validation loss curves")
    plt.ylabel("Loss")
    plt.xlabel("Epochs")
    plt.legend()
    plt.show()

    #Saving the parameters with the lowest training loss. The best value has
    #to be remembered, otherwise the file is overwritten after every epoch.
    if best_train_loss is None or train_loss < best_train_loss:
      best_train_loss = train_loss
      torch.save(model.state_dict(), save_path + model_save_name + "_best_trainloss.pt")

    #Early stopping
    if best_loss is None or test_loss < best_loss:
      best_loss = test_loss
      early_stop_count = 0
      best_epoch = epoch
      #Saving best model parameters
      torch.save(model.state_dict(), save_path + model_save_name + ".pt")

    else:
      early_stop_count += 1
      if early_stop_count >= patience:
        print(f"Validation loss not decreasing. Stopping early. Best epoch: {best_epoch}")
        break

  return train_loss_values, test_loss_values

Loss function (used for all decoders)

In [None]:
#Loss function: mean squared error between decoded image and target image
loss_fn = nn.MSELoss()

## Function for plotting results

In [None]:
import matplotlib.pyplot as plt

# Assuming that the images are of shape (3, 32, 32)
def plot_generated_imgs(model, left_col_data, right_col_data,
                        num_images_to_plot = 10, figsize = (15, 6),
                        img_indices_to_plot = None):
  """
  Plots original images (left column) next to decoded images (right column).

  Args:
    model: The PyTorch model used for decoding; it is put in evaluation mode.
    left_col_data (tensor): Original images shown in the left column. Each
      image is permuted with (1, 2, 0), i.e. expected channel-first.
    right_col_data (tensor): Model inputs. Sample i is passed to the model as
      a batch of one and the output is shown in the right column.
    num_images_to_plot (int): Number of images to plot (the first ones) when
      img_indices_to_plot is not given.
    figsize (tuple): Figure size passed to matplotlib.
    img_indices_to_plot (list): Optional. Indices of specific images to be decoded.
  """
  def to_displayable(image):
    # imshow expects (H, W, 3) with float values in 0-1
    image = image.permute(1, 2, 0)
    image = (image - image.min()) / (image.max() - image.min())
    return image.cpu().numpy() #Images need to be stored on CPU

  model.eval()
  if img_indices_to_plot is not None:
    indices = list(img_indices_to_plot)
  else:
    indices = list(range(num_images_to_plot))

  # squeeze = False keeps axes two-dimensional even for a single row
  fig, axes = plt.subplots(len(indices), 2, figsize = figsize, squeeze = False)

  # The subplot row is the position within indices, not the data index
  # itself; the two only coincide for the default range(...).
  for row, idx in enumerate(indices):
    # Original image (left column)
    axes[row, 0].imshow(to_displayable(left_col_data[idx]))
    axes[row, 0].axis('off')

    # Decoded image (right column)
    #Selecting observation idx, and then adding back the missing dimension
    observation = right_col_data[idx].unsqueeze(0)
    #Running model to decode an observation
    with torch.no_grad():
      decoded_image = model(observation)

    axes[row, 1].imshow(to_displayable(decoded_image.squeeze(0)))
    axes[row, 1].axis('off')

  plt.subplots_adjust(wspace=0, hspace=0)

  plt.show()

## Function for getting error of individual predictions

In [None]:
def get_squared_error(model, features, targets):
    """
    Returns the summed squared error of every individual sample.

    The model is put in evaluation mode. Sample i of features is passed
    through the model as a batch of one, the prediction is moved to the CPU
    and compared with targets[i]. The result is a 1-D tensor with one entry
    per target.
    """
    model.eval()

    criterion = nn.MSELoss(reduction = "sum")
    per_sample_error = torch.zeros(targets.size(0))

    for idx, target in enumerate(targets):
        single_batch = features[idx].unsqueeze(0)

        #Gradients are not needed for evaluation
        with torch.no_grad():
            prediction = model(single_batch).cpu()

        #Dropping the batch dimension again before comparing
        per_sample_error[idx] = criterion(prediction.squeeze(0), target)

    return per_sample_error

In [None]:
# Folder where the per-sample squared errors of each checkpoint are saved
err_save_path = "/content/drive/My Drive/Colab Notebooks/Examensarbete/Individual errors/"

In [None]:
# Folder from which model parameters are loaded (same value as the default save_path of train_model)
save_path = "/content/drive/My Drive/Colab Notebooks/Examensarbete/Model parameters/"

# Loading ResNet-18 Model

In [None]:
'''ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385

Code in this cell from:
    https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
'''


import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions with batch normalisation.

    The first convolution applies ``stride`` and maps ``in_planes`` to
    ``planes`` channels; the second keeps the shape. The input is added back
    through ``shortcut``, which is the identity unless the stride or the
    channel count changes, in which case it is a strided 1x1 convolution
    followed by batch normalisation.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main path
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip path: identity when the shapes already agree, projection otherwise
        shapes_agree = stride == 1 and in_planes == out_planes
        if shapes_agree:
            self.shortcut = nn.Sequential()
        else:
            projection = nn.Conv2d(in_planes, out_planes, kernel_size=1,
                                   stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection,
                                          nn.BatchNorm2d(out_planes))

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = F.relu(self.bn1(hidden))
        hidden = self.bn2(self.conv2(hidden))
        hidden += self.shortcut(x)
        return F.relu(hidden)

class ResNet(nn.Module):
    """ResNet with a 3x3 stem, four stages of ``block`` modules and a linear classifier.

    Args:
        block: Block class; it is called as ``block(in_planes, planes, stride)``
            and must expose an ``expansion`` attribute.
        num_blocks: Sequence with the number of blocks in each of the four stages.
        num_classes: Size of the output of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Channel count entering the next block; updated by _make_layer
        self.in_planes = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Stages layer1 .. layer4 as (width, stride of the first block)
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for stage_no, (planes, stride) in enumerate(stage_specs, start=1):
            stage = self._make_layer(block, planes, num_blocks[stage_no - 1],
                                     stride=stride)
            setattr(self, f"layer{stage_no}", stage)

        # Classifier
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Builds one stage; only its first block uses ``stride``, the rest use 1."""
        per_block_strides = [stride, *([1] * (num_blocks - 1))]
        blocks = []
        for block_stride in per_block_strides:
            blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        # 4x4 average pooling, then flattening to (batch, features)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)

In [None]:
def ResNet18():
    """Builds a ResNet of BasicBlocks with two blocks in each of the four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(block=BasicBlock, num_blocks=blocks_per_stage)

# Instantiating the classifier and moving it to the selected device
resnet = ResNet18().to(device)

In [None]:
# Loading the stored ResNet parameters (a state dict, see the key handling in the next cell)
resnet_params = torch.load(save_path + "resnet_params.pkl")

In [None]:
#Changing keys: the keys in resnet_params.pkl start with "module.", which the
#plain (unwrapped) resnet model does not use, so that prefix is removed
resnet_params = OrderedDict([(k.replace("module.", ""), v) for k, v in resnet_params.items()])

# Strict loading: every key has to match a parameter or buffer of resnet
resnet.load_state_dict(resnet_params)

## Function for making ResNet Prediction on Decoded Images

In [None]:
def decoded_predictions(data, decoder, prediction_model):
  """
    Decodes images in data using decoder, then returns output of
    prediction_model when run on decoded images.

    Both models are put in evaluation mode and the observations are processed
    one at a time without tracking gradients.

    Args:
      data (tensor): Decoder inputs; dimension 0 indexes the observations.
      decoder: Model that decodes a batch of one observation into an image.
      prediction_model: Model run on each decoded image; it is expected to
        return 10 values per image.

    Returns:
      A CPU tensor of shape (number of observations, 10).
  """
  decoder.eval()
  prediction_model.eval()
  n_obs = data.size(0)

  #Running on the device that holds the prediction model instead of relying on
  #a global variable (falls back to the CPU for a model without parameters)
  first_param = next(prediction_model.parameters(), None)
  run_device = first_param.device if first_param is not None else torch.device("cpu")

  #Empty tensor for predictions. The decoded images are only needed for one
  #iteration each, so they are not collected in a buffer.
  preds = torch.zeros(n_obs, 10).to(run_device)

  with torch.inference_mode():
    for i in range(n_obs):
      datapoint = data[i].to(run_device)
      decoded_image = decoder(datapoint.unsqueeze(0)) #batch of one
      preds[i] = prediction_model(decoded_image).squeeze(0)

  return preds.to("cpu")

# Training and Results

## Checkpoint 1

In [None]:
#Loading feature tensors (decoder inputs for checkpoint 1)
features_training_1 = torch.load(data_path + "c_1_train.pt")
features_testing_1 = torch.load(data_path + "c_1_test.pt")

# Parameters for the dataloaders
params = {'batch_size': 10,
          'shuffle': True,
          'num_workers': 0}

# Creating datasets and dataloaders
# Each sample pairs a feature tensor with the image it should be decoded into
training_set = torch.utils.data.TensorDataset(features_training_1, targets_training)
training_generator = torch.utils.data.DataLoader(training_set, **params)

# NOTE(review): the validation loader reuses params and is therefore shuffled as well
validation_set = torch.utils.data.TensorDataset(features_testing_1, targets_testing)
validation_generator = torch.utils.data.DataLoader(validation_set, **params)

In [None]:
#Initializing model and moving it to the GPU
# One DecoderBlock (64 -> 64 channels, stride 1) followed by the final 64 -> 3 deconvolution
decoder_1 = ResNetDecoder(n_blocks = 1, n_channels_in = [64], n_channels_out = [64],
                          strides = [1], final_stride=1, final_padding = 1,
                          output_padding = [0]).to(device)

# Wrapping in DataParallel; the state dict keys of the wrapper start with "module."
decoder_1 = torch.nn.DataParallel(decoder_1)

#Trying default parameter values for Adam
optimizer = optim.Adam(decoder_1.parameters(), lr = 0.001)

In [None]:
#Training the model
# The parameters with the lowest validation loss are saved as "decoder_1_2.pt" in the default save_path
train_model(decoder_1, "decoder_1_2", training_generator, validation_generator,
                device, loss_fn, optimizer, patience = 10, max_epochs = 30)

### Visualizing Images (Original and Decoded)

In [None]:
# Loading stored decoder parameters.
# NOTE(review): this reads "decoder_1_3.pt" while the training cell saves "decoder_1_2.pt" -
# presumably the file of another run; confirm which file is intended.
decoder_1.load_state_dict(torch.load(save_path + "decoder_1_3.pt"))

In [None]:
#Training data
# Left column: original images, right column: images decoded from the checkpoint 1 features
plot_generated_imgs(decoder_1, targets_training, features_training_1,
                        num_images_to_plot = 10, figsize = (16, 32))

In [None]:
#Validation data
# Left column: original images, right column: images decoded from the checkpoint 1 features
plot_generated_imgs(decoder_1, targets_testing, features_testing_1,
                        num_images_to_plot = 10, figsize = (16, 32))

### Getting error for each image

In [None]:
# Summed squared error of every individual training and test image
checkpoint_1_train_err = get_squared_error(decoder_1, features_training_1, targets_training)
checkpoint_1_test_err = get_squared_error(decoder_1, features_testing_1, targets_testing)

# Storing the per-image errors
torch.save(checkpoint_1_train_err, err_save_path + "checkpoint_1_train_err.pt")
torch.save(checkpoint_1_test_err, err_save_path + "checkpoint_1_test_err.pt")

In [None]:
#freeing up memory (the error tensors were saved to disk in the previous cell)
del(checkpoint_1_train_err)
del(checkpoint_1_test_err)

### Running Resnet on checkpoint

In [None]:
# ResNet outputs for the test images decoded from checkpoint 1
c_1_resnet = decoded_predictions(features_testing_1, decoder_1, resnet)

c_1_resnet_loss = F.cross_entropy(c_1_resnet, test_labels)
c_1_resnet_preds = F.softmax(c_1_resnet, dim = 1).argmax(dim = 1)
# Accuracy as the fraction of correct predictions; taking the mean avoids
# hard-coding the number of test images
c_1_resnet_acc = (c_1_resnet_preds == test_labels).float().mean()

## Checkpoint 2

In [None]:
#Loading feature tensors (decoder inputs for checkpoint 2)
features_training_2 = torch.load(data_path + "c_2_train.pt")
features_testing_2 = torch.load(data_path + "c_2_test.pt")

# Parameters for the dataloaders
params = {'batch_size': 10,
          'shuffle': True,
          'num_workers': 0}

# Creating datasets and dataloaders
# Each sample pairs a feature tensor with the image it should be decoded into
training_set = torch.utils.data.TensorDataset(features_training_2, targets_training)
training_generator = torch.utils.data.DataLoader(training_set, **params)

# NOTE(review): the validation loader reuses params and is therefore shuffled as well
validation_set = torch.utils.data.TensorDataset(features_testing_2, targets_testing)
validation_generator = torch.utils.data.DataLoader(validation_set, **params)

In [None]:
#Initializing model and moving it to the GPU
# Two DecoderBlocks: 128 -> 64 channels with stride 2, then 64 -> 64 with stride 1
decoder_2 = ResNetDecoder(n_blocks = 2, n_channels_in = [128, 64],
                          n_channels_out = [64, 64], strides = [2, 1],
                          final_stride=1, final_padding = 1,
                          output_padding = [1, 0]).to(device)

# Wrapping in DataParallel; the state dict keys of the wrapper start with "module."
decoder_2 = torch.nn.DataParallel(decoder_2)

In [None]:
#Loading weights from decoder_1
state_dict_1 = decoder_1.state_dict()
#Changing keys (necessary for loading weights into decoder_2):
#the single block of decoder_1 corresponds to the last block (index 1) of decoder_2
state_dict_1 = OrderedDict([(k.replace("blocks.0", "blocks.1"), v) for k, v in state_dict_1.items()])

#Initializing weights for decoder_2
#strict = False because the new first block has no counterpart in decoder_1
decoder_2.load_state_dict(state_dict_1, strict = False)

#Trying default parameter values for Adam
optimizer = optim.Adam(decoder_2.parameters(), lr = 0.001)

In [None]:
#Training the model
# The parameters with the lowest validation loss are saved as "decoder_2.pt" in the default save_path
train_model(decoder_2, "decoder_2", training_generator, validation_generator,
                device, loss_fn, optimizer, patience = 10, max_epochs = 30)

### Plotting results

In [None]:
# Loading stored decoder parameters.
# NOTE(review): this reads "decoder_2_2.pt" while the training cell saves "decoder_2.pt" -
# presumably the file of another run; confirm which file is intended.
decoder_2.load_state_dict(torch.load(save_path + "decoder_2_2.pt"))

In [None]:
#Training data
# Left column: original images, right column: images decoded from the checkpoint 2 features
plot_generated_imgs(decoder_2, targets_training, features_training_2,
                        num_images_to_plot = 10, figsize = (16, 32))

In [None]:
#Validation data
# Left column: original images, right column: images decoded from the checkpoint 2 features
plot_generated_imgs(decoder_2, targets_testing, features_testing_2,
                        num_images_to_plot = 10, figsize = (16, 32))


### Getting errors

In [None]:
# Summed squared error of every individual training and test image
checkpoint_2_train_err = get_squared_error(decoder_2, features_training_2, targets_training)
checkpoint_2_test_err = get_squared_error(decoder_2, features_testing_2, targets_testing)

# Storing the per-image errors
torch.save(checkpoint_2_train_err, err_save_path + "checkpoint_2_train_err.pt")
torch.save(checkpoint_2_test_err, err_save_path + "checkpoint_2_test_err.pt")

In [None]:
#freeing up memory (the error tensors were saved to disk in the previous cell)
del(checkpoint_2_train_err)
del(checkpoint_2_test_err)

### Running Resnet on checkpoint

In [None]:
# ResNet outputs for the test images decoded from checkpoint 2
c_2_resnet = decoded_predictions(features_testing_2, decoder_2, resnet)

c_2_resnet_loss = F.cross_entropy(c_2_resnet, test_labels)
c_2_resnet_preds = F.softmax(c_2_resnet, dim = 1).argmax(dim = 1)
# Accuracy as the fraction of correct predictions; taking the mean avoids
# hard-coding the number of test images
c_2_resnet_acc = (c_2_resnet_preds == test_labels).float().mean()

## Checkpoint 3

In [None]:
#Loading feature tensors (decoder inputs for checkpoint 3)
features_training_3 = torch.load(data_path + "c_3_train.pt")
features_testing_3 = torch.load(data_path + "c_3_test.pt")

# Parameters for the dataloaders
params = {'batch_size': 10,
          'shuffle': True,
          'num_workers': 0}

# Creating datasets and dataloaders
# Each sample pairs a feature tensor with the image it should be decoded into
training_set = torch.utils.data.TensorDataset(features_training_3, targets_training)
training_generator = torch.utils.data.DataLoader(training_set, **params)

# NOTE(review): the validation loader reuses params and is therefore shuffled as well
validation_set = torch.utils.data.TensorDataset(features_testing_3, targets_testing)
validation_generator = torch.utils.data.DataLoader(validation_set, **params)

In [None]:
#Initializing model and moving it to the GPU
# Three DecoderBlocks: 256 -> 128 and 128 -> 64 with stride 2, then 64 -> 64 with stride 1
decoder_3 = ResNetDecoder(n_blocks = 3, n_channels_in = [256, 128, 64],
                          n_channels_out = [128, 64, 64], strides = [2, 2, 1],
                          final_stride=1, final_padding = 1,
                          output_padding = [1, 1, 0]).to(device)

# Wrapping in DataParallel; the state dict keys of the wrapper start with "module."
decoder_3 = torch.nn.DataParallel(decoder_3)

In [None]:
#Loading weights from decoder_2
state_dict_2 = decoder_2.state_dict()
#Changing keys (necessary for loading weights into decoder_3)
#Every block moves one position back. The highest index is renamed first so
#that an already renamed key is not renamed a second time.
state_dict_2 = OrderedDict([(k.replace("blocks.1", "blocks.2"), v) for k, v in state_dict_2.items()])
state_dict_2 = OrderedDict([(k.replace("blocks.0", "blocks.1"), v) for k, v in state_dict_2.items()])

#Initializing weights for decoder_3
#strict = False because the new first block has no counterpart in decoder_2
decoder_3.load_state_dict(state_dict_2, strict = False)

#Trying default parameter values for Adam
optimizer = optim.Adam(decoder_3.parameters(), lr = 0.001)

In [None]:
#Training the model
# The parameters with the lowest validation loss are saved as "decoder_3.pt" in the default save_path
train_model(decoder_3, "decoder_3", training_generator, validation_generator,
                device, loss_fn, optimizer, patience = 20, max_epochs = 80)

### Plotting results

In [None]:
# Loading stored decoder parameters.
# NOTE(review): this reads "decoder_3_2.pt" while the training cell saves "decoder_3.pt" -
# presumably the file of another run; confirm which file is intended.
decoder_3.load_state_dict(torch.load(save_path + "decoder_3_2.pt"))

In [None]:
#Training data
# Left column: original images, right column: images decoded from the checkpoint 3 features
plot_generated_imgs(decoder_3, targets_training, features_training_3,
                        num_images_to_plot = 10, figsize = (16, 32))

In [None]:
#Validation data
# Left column: original images, right column: images decoded from the checkpoint 3 features
plot_generated_imgs(decoder_3, targets_testing, features_testing_3,
                        num_images_to_plot = 10, figsize = (16, 32))

### Getting errors

In [None]:
# Summed squared error of every individual training and test image
checkpoint_3_train_err = get_squared_error(decoder_3, features_training_3, targets_training)
checkpoint_3_test_err = get_squared_error(decoder_3, features_testing_3, targets_testing)

# Storing the per-image errors
torch.save(checkpoint_3_train_err, err_save_path + "checkpoint_3_train_err.pt")
torch.save(checkpoint_3_test_err, err_save_path + "checkpoint_3_test_err.pt")

In [None]:
#freeing up memory (the error tensors were saved to disk in the previous cell)
del(checkpoint_3_train_err)
del(checkpoint_3_test_err)

### Running Resnet on Checkpoint

In [None]:
# ResNet outputs for the test images decoded from checkpoint 3
c_3_resnet = decoded_predictions(features_testing_3, decoder_3, resnet)

c_3_resnet_loss = F.cross_entropy(c_3_resnet, test_labels)
c_3_resnet_preds = F.softmax(c_3_resnet, dim = 1).argmax(dim = 1)
# Accuracy as the fraction of correct predictions; taking the mean avoids
# hard-coding the number of test images
c_3_resnet_acc = (c_3_resnet_preds == test_labels).float().mean()

## Checkpoint 4

In [None]:
#Loading feature tensors (decoder inputs for checkpoint 4)
features_training_4 = torch.load(data_path + "c_4_train.pt")
features_testing_4 = torch.load(data_path + "c_4_test.pt")

# Parameters for the dataloaders
params = {'batch_size': 10,
          'shuffle': True,
          'num_workers': 0}

# Creating datasets and dataloaders
# Each sample pairs a feature tensor with the image it should be decoded into
training_set = torch.utils.data.TensorDataset(features_training_4, targets_training)
training_generator = torch.utils.data.DataLoader(training_set, **params)

# NOTE(review): the validation loader reuses params and is therefore shuffled as well
validation_set = torch.utils.data.TensorDataset(features_testing_4, targets_testing)
validation_generator = torch.utils.data.DataLoader(validation_set, **params)

In [None]:
#Initializing model and moving it to the GPU
# Four DecoderBlocks: 512 -> 256, 256 -> 128 and 128 -> 64 with stride 2, then 64 -> 64 with stride 1
decoder_4 = ResNetDecoder(n_blocks = 4, n_channels_in = [512, 256, 128, 64],
                          n_channels_out = [256, 128, 64, 64],
                          strides = [2, 2, 2, 1], final_stride=1,
                          final_padding = 1,
                          output_padding = [1, 1, 1, 0]).to(device)

# Wrapping in DataParallel; the state dict keys of the wrapper start with "module."
decoder_4 = torch.nn.DataParallel(decoder_4)

In [None]:
#Loading weights from decoder_3
state_dict_3 = decoder_3.state_dict()
#Changing keys (necessary for loading weights into decoder_4)
#Every block moves one position back. The highest index is renamed first so
#that an already renamed key is not renamed a second time.
state_dict_3 = OrderedDict([(k.replace("blocks.2", "blocks.3"), v) for k, v in state_dict_3.items()])
state_dict_3 = OrderedDict([(k.replace("blocks.1", "blocks.2"), v) for k, v in state_dict_3.items()])
state_dict_3 = OrderedDict([(k.replace("blocks.0", "blocks.1"), v) for k, v in state_dict_3.items()])

#Initializing weights for decoder_4
#strict = False because the new first block has no counterpart in decoder_3
decoder_4.load_state_dict(state_dict_3, strict = False)

#Trying default parameter values for Adam
optimizer = optim.Adam(decoder_4.parameters(), lr = 0.001)

In [None]:
#Training the model
# The parameters with the lowest validation loss are saved as "decoder_4_2.pt" in the default save_path
train_model(decoder_4, "decoder_4_2", training_generator, validation_generator,
                device, loss_fn, optimizer, patience = 80, max_epochs = 120)

### Plotting results

In [None]:
# NOTE(review): prints the literal text "/a"; this looks like a leftover scratch cell
print('/a')

In [None]:
# Loading the parameters stored by the training cell ("decoder_4_2.pt")
decoder_4.load_state_dict(torch.load(save_path + "decoder_4_2.pt"))

In [None]:
#Training data
# Left column: original images, right column: images decoded from the checkpoint 4 features
plot_generated_imgs(decoder_4, targets_training, features_training_4,
                        num_images_to_plot = 10, figsize = (16, 32))

In [None]:
#Validation data
# Left column: original images, right column: images decoded from the checkpoint 4 features
plot_generated_imgs(decoder_4, targets_testing, features_testing_4,
                        num_images_to_plot = 10, figsize = (16, 32))

### Getting error

In [None]:
# Summed squared error of every individual training and test image
checkpoint_4_train_err = get_squared_error(decoder_4, features_training_4, targets_training)
checkpoint_4_test_err = get_squared_error(decoder_4, features_testing_4, targets_testing)

# Storing the per-image errors (file names carry the same "_2" suffix as the model file)
torch.save(checkpoint_4_train_err, err_save_path + "checkpoint_4_2_train_err.pt")
torch.save(checkpoint_4_test_err, err_save_path + "checkpoint_4_2_test_err.pt")

In [None]:
#The error tensors are saved on disk; dropping the references frees the memory
del checkpoint_4_train_err, checkpoint_4_test_err

### Running Resnet on checkpoint

In [None]:
#Output of decoded_predictions for the checkpoint-4 test features; it is used
#below as class logits (presumably ResNet applied to the decoded images -
#confirm in decoded_predictions)
c_4_resnet = decoded_predictions(features_testing_4, decoder_4, resnet)

c_4_resnet_loss = F.cross_entropy(c_4_resnet, test_labels)
c_4_resnet_preds = F.softmax(c_4_resnet, dim = 1).argmax(dim = 1)
#Fraction of correct predictions; taking the mean avoids hard-coding the
#number of test observations
c_4_resnet_acc = (c_4_resnet_preds == test_labels).float().mean()

## Checkpoint 5

In [None]:
#Checkpoint-5 features for the training and test observations
features_training_5 = torch.load(data_path + "c_5_train.pt")
features_testing_5 = torch.load(data_path + "c_5_test.pt")

#Settings shared by both dataloaders
params = dict(batch_size = 10, shuffle = True, num_workers = 0)

#Each dataset pairs a feature tensor with the image it should be decoded into
training_set = torch.utils.data.TensorDataset(features_training_5, targets_training)
validation_set = torch.utils.data.TensorDataset(features_testing_5, targets_testing)

training_generator = torch.utils.data.DataLoader(training_set, **params)
validation_generator = torch.utils.data.DataLoader(validation_set, **params)

In [None]:
#Initializing model and moving it to the GPU
#Same block configuration as decoder_4, with checkpoint_n = 5 passed in addition
decoder_5 = ResNetDecoder(n_blocks = 4, n_channels_in = [512, 256, 128, 64],
                          n_channels_out = [256, 128, 64, 64], strides = [2, 2, 2, 1],
                          final_stride=1, final_padding = 1,
                          output_padding = [1, 1, 1, 0],
                          checkpoint_n = 5).to(device)

#Wrapping the model so that batches can be split across the available GPUs
decoder_5 = torch.nn.DataParallel(decoder_5)

In [None]:
#Loading weights from decoder_4
state_dict_4 = decoder_4.state_dict()

#Initializing weights for decoder_5
#strict = False: keys present in only one of the two models are skipped
#instead of raising an error
decoder_5.load_state_dict(state_dict_4, strict = False)

#Trying default parameter values for Adam
optimizer = optim.Adam(decoder_5.parameters(), lr = 0.001)

In [None]:
#Training the model
#The two values returned by train_model are kept here as the training and
#test loss of the "decoder_5_3" run
decoder_5_train_loss, decoder_5_test_loss = train_model(decoder_5, "decoder_5_3", training_generator,
                                                        validation_generator, device, loss_fn,
                                                        optimizer, patience = 200, max_epochs = 300)

### Plotting results

In [None]:
#This needs to run to ensure that unpooling weights are initialized
#(one forward pass on an all-zero batch of ten 512 x 1 x 1 inputs)
init = decoder_5(torch.zeros(10, 512, 1, 1))

In [None]:
#Reloading the weights saved for the "decoder_5_3" run. map_location remaps
#the stored tensors onto the current device (as is done for decoder_6), so
#weights saved on a GPU can also be loaded on a machine without one.
with torch.inference_mode():
  decoder_5.load_state_dict(torch.load(save_path + "decoder_5_3.pt", map_location = device))

In [None]:
#Training data: decoder_5 outputs plotted together with the training targets
plot_generated_imgs(decoder_5, targets_training, features_training_5,
                        num_images_to_plot = 10, figsize = (16, 32))

In [None]:
#Validation data (the test-set features and targets are used for validation)
plot_generated_imgs(decoder_5, targets_testing, features_testing_5,
                        num_images_to_plot = 10, figsize = (16, 32))

### Getting errors

In [None]:
#Errors of decoder_5 on the training and test sets, as returned by
#get_squared_error
checkpoint_5_train_err = get_squared_error(decoder_5, features_training_5, targets_training)
checkpoint_5_test_err = get_squared_error(decoder_5, features_testing_5, targets_testing)

#Saving the errors under the "checkpoint_5_3" run name
torch.save(checkpoint_5_train_err, err_save_path + "checkpoint_5_3_train_err.pt")
torch.save(checkpoint_5_test_err, err_save_path + "checkpoint_5_3_test_err.pt")

In [None]:
#The error tensors are saved on disk; dropping the references frees the memory
del checkpoint_5_train_err, checkpoint_5_test_err

### Running resnet on checkpoint

In [None]:
#Output of decoded_predictions for the checkpoint-5 test features; it is used
#below as class logits (presumably ResNet applied to the decoded images -
#confirm in decoded_predictions)
c_5_resnet = decoded_predictions(features_testing_5, decoder_5, resnet)

c_5_resnet_loss = F.cross_entropy(c_5_resnet, test_labels)
c_5_resnet_preds = F.softmax(c_5_resnet, dim = 1).argmax(dim = 1)
#Fraction of correct predictions; taking the mean avoids hard-coding the
#number of test observations
c_5_resnet_acc = (c_5_resnet_preds == test_labels).float().mean()

In [None]:
#Freeing the training features. features_testing_5 is deliberately kept: it
#is decoded again in the "Decoding Each Checkpoint" section.
del(features_training_5)

## Checkpoint 6

In [None]:
#Checkpoint-6 features for the training and test observations
features_training_6 = torch.load(data_path + "c_6_train.pt")
features_testing_6 = torch.load(data_path + "c_6_test.pt")

#Settings shared by both dataloaders
params = dict(batch_size = 10, shuffle = True, num_workers = 0)

#Each dataset pairs a feature tensor with the image it should be decoded into
training_set = torch.utils.data.TensorDataset(features_training_6, targets_training)
validation_set = torch.utils.data.TensorDataset(features_testing_6, targets_testing)

training_generator = torch.utils.data.DataLoader(training_set, **params)
validation_generator = torch.utils.data.DataLoader(validation_set, **params)

In [None]:
#Initializing model and moving it to the GPU
#Same block configuration as decoder_4, with checkpoint_n = 6 passed in addition
decoder_6 = ResNetDecoder(n_blocks = 4, n_channels_in = [512, 256, 128, 64],
                          n_channels_out = [256, 128, 64, 64], strides = [2, 2, 2, 1],
                          final_stride=1, final_padding = 1,
                          output_padding = [1, 1, 1, 0],
                          checkpoint_n = 6).to(device)

#Wrapping the model so that batches can be split across the available GPUs
decoder_6 = torch.nn.DataParallel(decoder_6)

In [None]:
#Loading weights from decoder_5
state_dict_5 = decoder_5.state_dict()

#Initializing weights for decoder_6
#strict = False: keys present in only one of the two models are skipped
#instead of raising an error
decoder_6.load_state_dict(state_dict_5, strict = False)

#Trying default parameter values for Adam
optimizer = optim.Adam(decoder_6.parameters(), lr = 0.001)

In [None]:
#Training the model
#"decoder_6_8" is the run name; a file "decoder_6_8.pt" is loaded from
#save_path further down, so train_model presumably saves under that name
#(TODO confirm in train_model)
train_model(decoder_6, "decoder_6_8", training_generator, validation_generator,
                device, loss_fn, optimizer, patience = 200, max_epochs = 300)

### Plotting results

In [None]:
#This needs to run to ensure that unpooling weights are initialized
#(one forward pass on an all-zero batch of ten 10-dimensional inputs)
init = decoder_6(torch.zeros(10, 10))

In [None]:
#Reloading the weights saved for the "decoder_6_8" run; map_location remaps
#the stored tensors onto the current device
with torch.inference_mode():
  decoder_6.load_state_dict(torch.load(save_path + "decoder_6_8.pt", map_location = device))

In [None]:
#Training data: decoder_6 outputs plotted together with the training targets
plot_generated_imgs(decoder_6, targets_training, features_training_6,
                        num_images_to_plot = 10, figsize = (16, 32))

In [None]:
#Validation data (the test-set features and targets are used for validation)
plot_generated_imgs(decoder_6, targets_testing, features_testing_6,
                        num_images_to_plot = 10, figsize = (16, 32))


### Getting errors

In [None]:
#Errors of decoder_6 on the training and test sets, as returned by
#get_squared_error
checkpoint_6_train_err = get_squared_error(decoder_6, features_training_6, targets_training)
checkpoint_6_test_err = get_squared_error(decoder_6, features_testing_6, targets_testing)

#Saving the errors under the "checkpoint_6_8" run name
torch.save(checkpoint_6_train_err, err_save_path + "checkpoint_6_8_train_err.pt")
torch.save(checkpoint_6_test_err, err_save_path + "checkpoint_6_8_test_err.pt")

In [None]:
#The error tensors are saved on disk; dropping the references frees the memory
del checkpoint_6_train_err, checkpoint_6_test_err

### Running resnet on checkpoint

In [None]:
#Output of decoded_predictions for the checkpoint-6 test features; it is used
#below as class logits (presumably ResNet applied to the decoded images -
#confirm in decoded_predictions)
c_6_resnet = decoded_predictions(features_testing_6, decoder_6, resnet)

c_6_resnet_loss = F.cross_entropy(c_6_resnet, test_labels)
c_6_resnet_preds = F.softmax(c_6_resnet, dim = 1).argmax(dim = 1)
#Fraction of correct predictions; taking the mean avoids hard-coding the
#number of test observations
c_6_resnet_acc = (c_6_resnet_preds == test_labels).float().mean()

In [None]:
#Cross-entropy of the ResNet outputs for each of the six checkpoints, in
#checkpoint order
resnet_losses = torch.tensor([c_1_resnet_loss,
                              c_2_resnet_loss,
                              c_3_resnet_loss,
                              c_4_resnet_loss,
                              c_5_resnet_loss,
                              c_6_resnet_loss])

#Corresponding accuracies, in the same order
resnet_acc = torch.tensor([c_1_resnet_acc,
                           c_2_resnet_acc,
                           c_3_resnet_acc,
                           c_4_resnet_acc,
                           c_5_resnet_acc,
                           c_6_resnet_acc,])

#Saving both summaries to Google Drive
torch.save(resnet_losses,
           "/content/drive/My Drive/Colab Notebooks/Examensarbete/resnet_losses.pt")

torch.save(resnet_acc,
           "/content/drive/My Drive/Colab Notebooks/Examensarbete/resnet_acc.pt")

# Decoding Each Checkpoint

In [None]:
#Function for decoding individual embeddings into images
import numpy as np

def decode_image(model, observation, show = False, device = "cpu"):
  """Decode a single embedding with `model`, and optionally plot the result.

  The model is put in eval mode and run without gradient tracking on
  `observation` with a batch dimension of size one added in front.

  Args:
    model: The decoder to run.
    observation: One embedding, without batch dimension.
    show: When equal to True the decoding is plotted and nothing is returned;
      otherwise the raw model output (batch dimension included) is returned.
    device: Not used by this function.
  """
  model.eval()
  with torch.no_grad():
    batch_out = model(observation.unsqueeze(0))

  if show != True:
    return batch_out

  # The tensor is of dimensions (3, 32, 32), but imshow expects (32, 32, 3)
  channels_last = batch_out.squeeze(0).permute(1, 2, 0)

  #Rescaling the pixel values linearly so that they span 0-1
  lowest, highest = channels_last.min(), channels_last.max()
  channels_last = ((channels_last - lowest) / (highest - lowest)).to("cpu")
  plt.imshow(channels_last)
  plt.show()

In [None]:
def decode_checkpoint(data, indices, decoder):
  """Decode the observations `data[indices[k]]` one at a time with `decoder`.

  Returns a tensor of shape (len(indices), 3, 32, 32) whose k-th entry is the
  decoding of `data[indices[k]]`.
  """
  decoded_imgs = torch.zeros(len(indices), 3, 32, 32)

  for slot, obs_idx in enumerate(indices):
    decoded_imgs[slot] = decode_image(decoder, data[obs_idx], show = False)

  return decoded_imgs

In [None]:
#The checkpoint-6 features are treated as class logits; softmax turns them
#into class probabilities
test_probs = F.softmax(features_testing_6, dim = 1)

#Tensor containing the probability that the network assigns to the correct class
#Row k picks column test_labels[k]. One indexing operation replaces the
#per-observation loop, and the number of observations is read from the data
#instead of being hard-coded. The result is a float32 CPU tensor, like the
#torch.zeros buffer that was filled before.
correct_class_probs = test_probs[torch.arange(test_probs.shape[0]), test_labels].float().cpu()

In [None]:
#Class names; the position in the list is the integer label of the class
classes = ["airplane",
           "automobile",
           "bird",
           "cat",
           "deer",
           "dog",
           "frog",
           "horse",
           "ship",
           "truck"]

In [None]:
#Test labels as a Python list, so that list.index can be used to find the
#first occurrence of a class
test_labels_list = list(test_labels.numpy())

## High probability observations

In [None]:
#Getting indices of observations that were easy to predict
#(probability of the correct class at least 0.9)
not_high_probs_idx = correct_class_probs < 0.9
high_probs_labels = test_labels.clone().detach()
high_probs_labels[not_high_probs_idx] = -1 #Setting low prob. observations to -1
high_probs_labels = list(high_probs_labels.numpy())

#Index of the first high-probability observation of each class. list.index
#raises ValueError if a class has no observation that passes the threshold.
high_probs_idx = [high_probs_labels.index(elem) for elem in set(test_labels_list)]

In [None]:
#Number of columns in the figure: the original image plus six checkpoints
num_checkpoints = 7

#Indices of the images we wish to show

#Alternative selection: first occurrence of each class
#image_idx = [test_labels_list.index(elem) for elem in set(test_labels_list)]

image_idx = high_probs_idx
num_images = len(image_idx)

In [None]:
#Correct-class probability of each selected image
class_prob = correct_class_probs[image_idx].numpy()

In [None]:
#Creating tensors containing decoded images
#(the selected test observations, decoded from each of the six checkpoints)
c_1_decoded = decode_checkpoint(features_testing_1, image_idx, decoder_1)
c_2_decoded = decode_checkpoint(features_testing_2, image_idx, decoder_2)
c_3_decoded = decode_checkpoint(features_testing_3, image_idx, decoder_3)
c_4_decoded = decode_checkpoint(features_testing_4, image_idx, decoder_4)
c_5_decoded = decode_checkpoint(features_testing_5, image_idx, decoder_5)
c_6_decoded = decode_checkpoint(features_testing_6, image_idx, decoder_6)

#First axis: 0 holds the original test images, 1-6 the decodings from
#checkpoints 1-6
images_to_plot = torch.zeros(num_checkpoints, num_images, 3, 32, 32)
images_to_plot[0] = targets_testing[tuple(image_idx), :, :, :]
images_to_plot[1] = c_1_decoded
images_to_plot[2] = c_2_decoded
images_to_plot[3] = c_3_decoded
images_to_plot[4] = c_4_decoded
images_to_plot[5] = c_5_decoded
images_to_plot[6] = c_6_decoded

In [None]:
#Running ResNet on every image of the grid (the originals and all decodings)
resnet_out = torch.zeros(num_images, num_checkpoints, 10)
resnet.eval()
with torch.no_grad():
  for img in range(num_images):
    for cp in range(num_checkpoints):
      #ResNet expects a batch, so a batch dimension of size one is added
      single_batch = images_to_plot[cp, img].unsqueeze(0).to(device)
      resnet_out[img, cp] = resnet(single_batch)

In [None]:
#Class probabilities and predicted class index for every plotted image
resnet_probs = F.softmax(resnet_out, dim = 2)
resnet_preds = resnet_probs.argmax(dim = 2)

#Predicted class names, arranged as (image, checkpoint)
resnet_preds_str = np.array(
    [[classes[class_id] for class_id in image_preds]
     for image_preds in resnet_preds.tolist()],
    dtype = "U10",
).reshape(num_images, num_checkpoints)

In [None]:
#Grid of decodings: one row per selected image, one column per checkpoint
fig, axes = plt.subplots(num_images, num_checkpoints, figsize = (15, 6))

for row, sample_idx in enumerate(image_idx):
  true_class = test_labels[sample_idx]

  #Row label: the true class of the image, placed left of the first column
  axes[row, 0].set_ylabel(classes[true_class], rotation=0, size='medium',
                          ha = "right", va = "center", weight = "bold")

  for col in range(num_checkpoints):
    ax = axes[row, col]

    #imshow expects channels last; pixel values are rescaled to span 0-1
    pixels = images_to_plot[col, row].permute(1, 2, 0).to("cpu")
    lowest, highest = pixels.min(), pixels.max()
    ax.imshow((pixels - lowest) / (highest - lowest))

    for side in ('top', 'right', 'bottom', 'left'):
      ax.spines[side].set_visible(False)

    #Caption: the class ResNet predicts for this image
    ax.set_xlabel(resnet_preds_str[row, col].item(),
                  rotation=0, ha = "center", size = "small")

#Column headers: the original image followed by the checkpoints
column_titles = ["Original"] + ["CP " + str(k) for k in range(1, num_checkpoints)]
for col, title in enumerate(column_titles):
  axes[0, col].set_title(title, rotation=0, size='medium', ha = "center",
                         weight = "bold")

#Removing the margins; the negative wspace pulls the columns together
plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=-0.93, hspace=0.5)

#No ticks on any of the axes
plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
plt.show()

## Low probability observations

In [None]:
#Getting indices of observations that were hard to predict
#(probability of the correct class at most 0.5)
not_low_probs_idx = correct_class_probs > 0.5
low_probs_labels = test_labels.clone().detach()
low_probs_labels[not_low_probs_idx] = -1 #Setting high prob. observations to -1
low_probs_labels = list(low_probs_labels.numpy())

#Index of the first low-probability observation of each class. list.index
#raises ValueError if a class has no observation that passes the threshold.
low_probs_idx = [low_probs_labels.index(elem) for elem in set(test_labels_list)]

In [None]:
#Number of columns in the figure: the original image plus six checkpoints
num_checkpoints = 7

#Indices of the images we wish to show

#Alternative selection: first occurrence of each class
#image_idx = [test_labels_list.index(elem) for elem in set(test_labels_list)]

image_idx = low_probs_idx
num_images = len(image_idx)

In [None]:
#Correct-class probability of each selected image
class_prob = correct_class_probs[image_idx].numpy()

In [None]:
#Creating tensors containing decoded images
#(the selected test observations, decoded from each of the six checkpoints)
c_1_decoded = decode_checkpoint(features_testing_1, image_idx, decoder_1)
c_2_decoded = decode_checkpoint(features_testing_2, image_idx, decoder_2)
c_3_decoded = decode_checkpoint(features_testing_3, image_idx, decoder_3)
c_4_decoded = decode_checkpoint(features_testing_4, image_idx, decoder_4)
c_5_decoded = decode_checkpoint(features_testing_5, image_idx, decoder_5)
c_6_decoded = decode_checkpoint(features_testing_6, image_idx, decoder_6)

#First axis: 0 holds the original test images, 1-6 the decodings from
#checkpoints 1-6
images_to_plot = torch.zeros(num_checkpoints, num_images, 3, 32, 32)
images_to_plot[0] = targets_testing[tuple(image_idx), :, :, :]
images_to_plot[1] = c_1_decoded
images_to_plot[2] = c_2_decoded
images_to_plot[3] = c_3_decoded
images_to_plot[4] = c_4_decoded
images_to_plot[5] = c_5_decoded
images_to_plot[6] = c_6_decoded

In [None]:
#Running ResNet on every image of the grid (the originals and all decodings)
resnet_out = torch.zeros(num_images, num_checkpoints, 10)
resnet.eval()
with torch.no_grad():
  for img in range(num_images):
    for cp in range(num_checkpoints):
      #ResNet expects a batch, so a batch dimension of size one is added
      single_batch = images_to_plot[cp, img].unsqueeze(0).to(device)
      resnet_out[img, cp] = resnet(single_batch)

In [None]:
#Class probabilities and predicted class index for every plotted image
resnet_probs = F.softmax(resnet_out, dim = 2)
resnet_preds = resnet_probs.argmax(dim = 2)

#Predicted class names, arranged as (image, checkpoint)
resnet_preds_str = np.array(
    [[classes[class_id] for class_id in image_preds]
     for image_preds in resnet_preds.tolist()],
    dtype = "U10",
).reshape(num_images, num_checkpoints)

In [None]:
#Grid of decodings: one row per selected image, one column per checkpoint
fig, axes = plt.subplots(num_images, num_checkpoints, figsize = (15, 6))

for row, sample_idx in enumerate(image_idx):
  true_class = test_labels[sample_idx]

  #Row label: the true class of the image, placed left of the first column
  axes[row, 0].set_ylabel(classes[true_class], rotation=0, size='medium',
                          ha = "right", va = "center", weight = "bold")

  for col in range(num_checkpoints):
    ax = axes[row, col]

    #imshow expects channels last; pixel values are rescaled to span 0-1
    pixels = images_to_plot[col, row].permute(1, 2, 0).to("cpu")
    lowest, highest = pixels.min(), pixels.max()
    ax.imshow((pixels - lowest) / (highest - lowest))

    for side in ('top', 'right', 'bottom', 'left'):
      ax.spines[side].set_visible(False)

    #Caption: the class ResNet predicts for this image
    ax.set_xlabel(resnet_preds_str[row, col].item(),
                  rotation=0, ha = "center", size = "small")

#Column headers: the original image followed by the checkpoints
column_titles = ["Original"] + ["CP " + str(k) for k in range(1, num_checkpoints)]
for col, title in enumerate(column_titles):
  axes[0, col].set_title(title, rotation=0, size='medium', ha = "center",
                         weight = "bold")

#Removing the margins; the negative wspace pulls the columns together
plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=-0.93, hspace=0.5)

#No ticks on any of the axes
plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
plt.show()

## Decoding artificial final layer values

In [None]:
#Function for decoding individual embeddings into images
import numpy as np

def decode_image(model, observation, show = False, device = "cpu"):
  """Decode a single embedding with `model`, and optionally plot the result.

  This is an identical redefinition of the decode_image function defined in
  the "Decoding Each Checkpoint" section.

  Args:
    model: The decoder to run; it is put in eval mode.
    observation: One embedding, without batch dimension (one is added here).
    show: When equal to True the decoding is plotted and nothing is returned;
      otherwise the raw model output (batch dimension included) is returned.
    device: Not used by this function.
  """
  model.eval()
  with torch.no_grad():
    decoded_image = model(observation.unsqueeze(0))

  if show == True:
    # The tensor is of dimensions (3, 32, 32), but imshow expects (32, 32, 3)
    decoded_image = decoded_image.squeeze(0).permute(1, 2, 0)

    #Normalizing pixel values to 0-1
    decoded_image = (decoded_image - decoded_image.min()) / (decoded_image.max() - decoded_image.min())
    decoded_image = decoded_image.to("cpu")
    plt.imshow(decoded_image)
    plt.show()

  else:
    return decoded_image

In [None]:
#Displaying the checkpoint-6 features of the first test observation, as a
#reference for the artificial values constructed below
test_obs = features_testing_6[0]
test_obs

In [None]:
#Artificial final-layer vector: 10 in position 3 ("cat" in the class list),
#zero everywhere else
test_obs = torch.zeros(10)
test_obs[3] = 10

In [None]:
#Decoding the artificial vector with decoder_6 and plotting the result
decode_image(decoder_6, test_obs, show = True)

In [None]:
#Creating artificial data for each class

num_classes = 10

#Number of artificial vectors per class. Note that this rebinds num_images,
#which earlier held the number of test images shown per figure.
num_images = 15

classes = ["airplane",
           "automobile",
           "bird",
           "cat",
           "deer",
           "dog",
           "frog",
           "horse",
           "ship",
           "truck"]

#For class i, only entry i is non-zero; it takes the values
#-2, -1, ..., num_images - 3 across the num_images vectors
artificial_data = torch.zeros(num_classes, num_images, 10)
for i in range(num_classes):
  artificial_data[i, :, i] = torch.arange(-2, (num_images - 2), 1)

In [None]:
#Decoding artificial data
#(every artificial vector is decoded with decoder_6, one at a time)
decoded_tensor = torch.zeros(num_classes, num_images, 3, 32, 32)
for i in range(num_classes):
  for j in range(num_images):
    decoded_tensor[i, j] = decode_image(decoder_6, artificial_data[i, j])

In [None]:
#Running ResNet on the decoding of every artificial vector
resnet_artificial_out = torch.zeros(num_classes, num_images, 10)
resnet.eval()
with torch.no_grad():
  for cls in range(num_classes):
    for step in range(num_images):
      #ResNet expects a batch, so a batch dimension of size one is added
      single_batch = decoded_tensor[cls, step].unsqueeze(0).to(device)
      resnet_artificial_out[cls, step] = resnet(single_batch)

In [None]:
#Class probabilities and predicted class index for every decoded vector
resnet_artificial_probs = F.softmax(resnet_artificial_out, dim = 2)
resnet_artificial_preds = resnet_artificial_probs.argmax(dim = 2)

#Predicted class names, arranged as (class, step)
resnet_artificial_preds_str = np.array(
    [[classes[class_id] for class_id in class_preds]
     for class_preds in resnet_artificial_preds.tolist()],
    dtype = "U10",
).reshape(num_classes, num_images)

In [None]:
#Grid of decodings: one row per class, one column per artificial value
fig, axes = plt.subplots(num_classes, num_images, figsize = (15, 6))

for row in range(num_classes):
  #Row label: the class whose entry is varied
  axes[row, 0].set_ylabel(classes[row], rotation=0, size='medium', ha = "right",
                          weight = "bold")

  for col in range(num_images):
    ax = axes[row, col]

    #imshow expects channels last; pixel values are rescaled to span 0-1
    pixels = decoded_tensor[row, col].permute(1, 2, 0).to("cpu")
    lowest, highest = pixels.min(), pixels.max()
    ax.imshow((pixels - lowest) / (highest - lowest))

    for side in ('top', 'right', 'bottom', 'left'):
      ax.spines[side].set_visible(False)

    #Caption: the class ResNet predicts for this decoding
    ax.set_xlabel(resnet_artificial_preds_str[row, col].item(),
                  rotation=0, ha = "center", size = "small")

#Column headers: the value of the varied entry, which starts at -2
for col in range(num_images):
  axes[0, col].set_title(col - 2, rotation=0, size='medium', ha = "center",
                         weight = "bold")

#Removing the margins; the negative wspace pulls the columns together
plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=-0.92, hspace=0.5)

#No ticks on any of the axes
plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
plt.show()

In [None]:
#Two rows of 12 artificial final-layer vectors. In both rows entry 3 ("cat")
#takes the values -2, -1, ..., 9; entry 5 ("dog") is 0.4 in the first row and
#0 in the second.
cat_varied_data = torch.zeros((2, 12, 10))
cat_varied_data[0, :, 5] = torch.tensor(0.4)
cat_varied_data[:, :, 3] = torch.arange(-2.0, 10.0)

In [None]:
#Decoding and plotting cat_varied_data: one row per dog value, one column per
#cat value. The grid size is read from the data.
n_rows, n_cols = cat_varied_data.shape[:2]
fig, axes = plt.subplots(n_rows, n_cols, figsize = (15, 6))

for i in range(n_rows):
  axes[i, 0].set_ylabel(f"Dog: {round(cat_varied_data[i, 0, 5].item(), 1)}",
                        rotation = 0, ha = "right")
  for j in range(n_cols):
    #decode_image returns a batch holding one image, (1, 3, 32, 32). The batch
    #dimension has to be removed (not added) before the channels are moved
    #last, since imshow expects (32, 32, 3).
    image = decode_image(decoder_6, cat_varied_data[i, j])
    image = image.squeeze(0).permute(1, 2, 0).to("cpu")
    #Rescaling pixel values to 0-1, as in the other decoding figures
    image = (image - image.min()) / (image.max() - image.min())
    axes[i, j].imshow(image)
    axes[i, j].spines['top'].set_visible(False)
    axes[i, j].spines['right'].set_visible(False)
    axes[i, j].spines['bottom'].set_visible(False)
    axes[i, j].spines['left'].set_visible(False)
    #.item() so that the label shows the number rather than "tensor(...)"
    axes[i, j].set_xlabel(f"Cat: {cat_varied_data[i, j, 3].item()}")
plt.setp(plt.gcf().get_axes(), xticks=[], yticks=[])
plt.show()

## Generating images between dog and cat

In [None]:
def get_mix(class1, class2, max_val = 9, show = True):
  """Interpolate between two classes in the final-layer space.

  Builds max_val artificial 10-dimensional final-layer vectors in which the
  class2 entry rises 0, 1, ..., max_val - 1 while the class1 entry falls
  max_val - 1, ..., 1, 0. All other entries are zero.

  Args:
    class1: Name of the class whose value decreases.
    class2: Name of the class whose value increases.
    max_val: Number of interpolation steps.
    show: If true, every vector is decoded with the global decoder_6 and the
      images are plotted in one row; nothing is returned. Otherwise the
      (max_val, 10) tensor of vectors is returned without decoding.

  Raises:
    KeyError: If class1 or class2 is not one of the class names below.
  """
  classes = {"airplane" : 0,
             "automobile" : 1,
             "bird" : 2,
             "cat" : 3,
             "deer" : 4,
             "dog" : 5,
             "frog" : 6,
             "horse": 7,
             "ship" : 8,
             "truck" : 9
           }

  mix_tensor = torch.zeros(max_val, 10)
  mix_tensor[:, classes[class2]] = torch.arange(0.0, max_val) #class2 values
  mix_tensor[:, classes[class1]] = -1 * mix_tensor[:, classes[class2]] + (max_val - 1)

  if not show:
    return mix_tensor

  #squeeze = False keeps axes two-dimensional, also when max_val is 1
  fig, axes = plt.subplots(1, max_val, figsize = (15, 6), squeeze = False)
  for i in range(max_val):
    ax = axes[0, i]

    #decode_image returns a batch holding one image, (1, 3, 32, 32), whereas
    #imshow expects (32, 32, 3) with values in 0-1 on the CPU
    image = decode_image(decoder_6, mix_tensor[i])
    image = image.squeeze(0).permute(1, 2, 0).to("cpu")
    image = (image - image.min()) / (image.max() - image.min())
    ax.imshow(image)
    ax.axis("off")

    val1 = mix_tensor[i, classes[class1]].item()
    val2 = mix_tensor[i, classes[class2]].item()
    ax.set_title(f"{class1}: {val1}\n{class2}: {val2}",
                 fontsize=12, ha = "center", va = "top")
  plt.show()


In [None]:
import plotly.express as px

def get_mix_animation(class1, class2, max_val = 9, show = True):
  """Animate the decodings of an interpolation between two classes.

  Builds 2 * max_val artificial 10-dimensional final-layer vectors in which
  the class2 entry rises from 0 in steps of 0.5 and the class1 entry is
  (max_val - 1) minus the class2 entry. Every vector is decoded with the
  global decoder_6, and the images become the frames of a plotly animation.

  Args:
    class1: Name of the class whose value decreases.
    class2: Name of the class whose value increases.
    max_val: Upper end (exclusive) of the class2 values.
    show: If equal to True the figure is displayed and nothing is returned;
      otherwise the figure is returned.

  Raises:
    KeyError: If class1 or class2 is not one of the class names below.
  """
  classes = {"airplane" : 0,
             "automobile" : 1,
             "bird" : 2,
             "cat" : 3,
             "deer" : 4,
             "dog" : 5,
             "frog" : 6,
             "horse": 7,
             "ship" : 8,
             "truck" : 9
           }

  mix_tensor = torch.zeros(max_val * 2, 10)
  mix_tensor[:, classes[class2]] = torch.arange(0.0, max_val, 0.5) #class2 values
  mix_tensor[:, classes[class1]] = -1 * mix_tensor[:, classes[class2]] + (max_val - 1)

  #Frames are stored channels last, which is the layout px.imshow expects
  res_tensor = torch.zeros(max_val * 2, 32, 32, 3)
  for i in range(max_val * 2):
    #decode_image returns a batch holding one image, (1, 3, 32, 32); drop the
    #batch dimension and move the channels last to match the frame layout
    decoded = decode_image(decoder_6, mix_tensor[i])
    res_tensor[i] = decoded.squeeze(0).permute(1, 2, 0)

  #Passing a NumPy array; the first axis indexes the animation frames
  fig = px.imshow(res_tensor.numpy(), animation_frame = 0)
  if show == True:
    fig.show()
  else:
    return fig

In [None]:
#Showing the cat-to-dog animation. With show = True the function displays the
#figure and returns nothing, so `animation` is None here.
animation = get_mix_animation("cat", "dog", show = True)

In [None]:
#Decodings along the interpolation from cat to dog
get_mix("cat", "dog")

In [None]:
#Decodings along the interpolation from airplane to ship
get_mix("airplane", "ship")

In [None]:
#Decodings along the interpolation from bird to horse
get_mix("bird", "horse")

In [None]:
#Decodings along the interpolation from ship to cat
get_mix("ship", "cat")

In [None]:
#Decodings along the interpolation from cat to horse
get_mix("cat", "horse")

In [None]:
#Decodings along the interpolation from automobile to horse
get_mix("automobile", "horse")

In [None]:
#Decodings along the interpolation from cat to frog
get_mix("cat", "frog")