## Name: Raffaello Baluyot
## Course: DT8058

<center><h1 style="font-size:40px;">Image Classification using CNNs</h1></center>

Welcome to the *Fourth* lab for Deep Learning!

In this lab you will build a CNN to classify RGB images. Image classification is the task of assigning a class label to an image. In this lab the *dataset* consists of multiple images, each with a target label for classification.

Good luck!


In [None]:
import torch
from torch import nn
import numpy as np
import os
import imageio
import torchvision
import math
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from collections import OrderedDict
import copy
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
import seaborn as sns
import pandas as pd
import numpy as np
from typing import List

# Apply seaborn's default plot theme to every matplotlib figure below.
sns.set_theme()

### In case you have uploaded a zip file unzip it first. 

In [None]:
# Prefer the first GPU when CUDA is available, otherwise fall back to the CPU,
# and make it the default device for newly created tensors/modules.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)

batch_size = 256  # mini-batch size shared by every DataLoader below

In [None]:
# !unzip ../data/FlyingObjectDataset_10K.zip -d ../data/

Then let's define the ```path``` to the dataset. Make sure you explore the directories of the dataset and get familiar with it!

In [None]:
# Root folder of each dataset split (each contains an `image/` sub-folder).
training_img_dir, validation_img_dir, testing_img_dir = (
    "../data/FlyingObjectDataset_10K/" + split
    for split in ("training", "validation", "testing")
)

Optionally, we will start using ```tensorboard```. Using tensorboard is optional, but we recommend that students use it.

In [None]:
# dummy_train_loss, _ = torch.sort(torch.rand(100) * 5, descending=True)
# dummy_val_loss, _ = torch.sort(torch.rand(100) * 6, descending=True )
# dummy_train_acc, _ = torch.sort(torch.rand(100) * 100)
# dummy_val_acc, _ = torch.sort(torch.rand(100) * 80)
# for i in range(100): 
#     writer.add_scalar('Loss/Train', dummy_train_loss[i], i)
#     writer.add_scalar('Loss/Val', dummy_val_loss[i], i)
#     writer.add_scalar('Acc/Train', dummy_train_acc[i], i)
#     writer.add_scalar('Acc/Val', dummy_val_acc[i], i)


Please make sure to read the [doc](https://pytorch.org/tutorials/intermediate/tensorboard_tutorial.html) to understand how to correctly plot your ```losses``` and ```metrics``` to tensorboard

OK, now that we have the paths to the three different splits, let's start by defining our ```Dataset``` class!
The two main methods we need to define when subclassing ```Dataset``` are ```__getitem__``` and ```__len__```:

In [None]:
class FlyingObjects(torch.utils.data.Dataset):
    """Flying Object Dataset for the single-label classification task.

    Every ``.png`` file found (recursively) under ``<root>/image`` is one
    sample; the file list is sorted by path. The label is encoded in the
    file name: split on ``_``, field 1 is the shape and field 2 the colour.

    :param root: split directory that contains an ``image`` folder.
    :param fine_grained: if True, use the 12 ``shape_colour`` classes;
        otherwise the 4 coarse classes (3 shapes + ``background``).
    :param transform: optional callable applied to the array returned by
        ``imageio.imread``; without it the array is wrapped as a tensor as-is.
    """

    def __init__(self, root, fine_grained=False, transform=None):
        super().__init__()
        self.root = root
        self.transform = transform
        self.fine_grained = fine_grained

        search_dir = os.path.expanduser(self.root + '/image')
        found = []
        for folder, _subdirs, filenames in os.walk(search_dir):
            found.extend(
                os.path.join(folder, fname)
                for fname in filenames
                if fname.endswith(".png")
            )
        self.images = sorted(found)

        if self.fine_grained:
            # shape-major ordering: square_red, square_green, ..., circular_yellow
            self.classes = [
                f"{shape}_{colour}"
                for shape in ('square', 'triangle', 'circular')
                for colour in ('red', 'green', 'blue', 'yellow')
            ]
        else:
            self.classes = ['square', 'triangle', 'circular', 'background']
        self.labels = [self.__extract_label(path) for path in self.images]

    def get_classes(self):
        """Return the list of class names; a label is an index into it."""
        return self.classes

    def __extract_label(self, image_file):
        """Map an image path to its integer class index.

        :raises IndexError: if the file name has too few ``_`` fields.
        :raises ValueError: if the derived class name is not in ``self.classes``.
        """
        fields = os.path.basename(image_file).split(".")[0].split("_")
        if self.fine_grained:
            class_name = fields[1] + "_" + fields[2]
        else:
            class_name = fields[1]

        if class_name not in self.classes:
            raise ValueError("ERROR: Label " + str(class_name) + " is not defined!")
        return self.classes.index(class_name)

    def __getitem__(self, index):
        """Return ``(image_tensor_as_float, class_index)`` for sample ``index``."""
        raw = imageio.imread(self.images[index])
        image = self.transform(raw) if self.transform else torch.from_numpy(raw)
        return image.float(), self.labels[index]

    def __len__(self):
        return len(self.images)

We can define our transformations. Note that not all transformations are considered ```Data Augmentation```.
The following transformations are used to convert our data to ```Tensor``` and resize our images to the desired size!

In [None]:
# No augmentation yet: both pipelines only convert the image to a CHW float
# tensor and resize it to 64x64.
train_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.Resize((64, 64))]
)
test_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.Resize((64, 64))]
)

# Question 1

Define the three dataloaders for the three splits: ```train```, ```validation``` and ```test``` and visualize data from each split. A function to visualize the images with their labels is given. You are free to use it or make your own visualization tools.  

In [None]:
def image_with_labels(dataloader, title:str=None, nimages:int=10, nrows:int=2, fig_dimension=1, title_size=10, label_prefix="Label: ", label_transform=lambda l: l):
    """Show a grid of randomly chosen samples (with labels) from ``dataloader.dataset``.

    Samples are drawn without replacement directly from the dataset (the
    loader's batching/shuffling is not used). If the dataset holds fewer than
    ``nimages`` samples, all of them are shown. Images must be CHW tensors,
    i.e. the dataset needs a ``ToTensor``-style transform.

    :param dataloader: loader whose ``.dataset`` is sampled.
    :param title: optional figure title. (Default value = None)
    :param nimages: number of images to display. (Default value = 10)
    :param nrows: number of grid rows; reduced when fewer rows suffice. (Default value = 2)
    :param fig_dimension: width in inches of one grid column. (Default value = 1)
    :param title_size: accepted for backward compatibility; currently unused. (Default value = 10)
    :param label_prefix: text placed before each label. (Default value = "Label: ")
    :param label_transform: callable turning a raw label into printable form. (Default value = identity)
    """
    dataset = dataloader.dataset
    # Clamp *before* sampling: np.random.choice(..., replace=False) raises
    # ValueError when asked for more samples than the population holds.
    nimages = min(nimages, len(dataset))
    if nimages == 0:
        return
    indices = np.random.choice(len(dataset), nimages, replace=False)
    data, labels = zip(*[dataset[int(i)] for i in indices])
    # (N, C, H, W) -> (N, H, W, C) for imshow; .cpu() is defensive so
    # matplotlib always receives host memory.
    data = torch.stack(list(data)).permute(0, 2, 3, 1).cpu()

    image_ratio = data.shape[1] / data.shape[2]  # height / width
    columns = math.ceil(nimages / nrows)
    if nrows * columns > nimages:
        nrows = math.ceil(nimages / columns)

    fig = plt.figure(figsize=(fig_dimension * columns, 1.4 * fig_dimension * nrows * image_ratio))
    for i in range(nimages):
        ax = fig.add_subplot(nrows, columns, i + 1)
        ax.imshow(data[i].numpy())
        ax.set_xlabel(f"{label_prefix}{label_transform(labels[i])}")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)

    fig.suptitle(title)
    fig.subplots_adjust(top=0.9, wspace=0)
    fig.tight_layout(pad=0, h_pad=0, w_pad=0)
    plt.show()

In [None]:
# The training split is reshuffled every epoch. The sampler's generator is
# created on `device` (the default device set above) instead of a hard-coded
# "cuda", which fails on the CPU fallback.
train_loader = DataLoader(
    FlyingObjects(root=training_img_dir, transform=train_transform),
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator(device=device),
)
valid_loader = DataLoader(FlyingObjects(root=validation_img_dir, transform=test_transform), batch_size=batch_size)
test_loader = DataLoader(FlyingObjects(root=testing_img_dir, transform=test_transform), batch_size=batch_size)

In [None]:
image_with_labels(train_loader, nrows=3, nimages=15)  # 3x5 grid of training samples

In [None]:
image_with_labels(valid_loader, nrows=3, nimages=15)  # 3x5 grid of validation samples

In [None]:
image_with_labels(test_loader, nrows=3, nimages=15)  # 3x5 grid of test samples

## Let's start with a very simple network

In [None]:
class SimpleModel(torch.nn.Module):
    """Small two-block CNN classifier.

    Each conv block is Conv3x3(padding=1) -> LeakyReLU -> MaxPool(2), so each
    spatial dimension is floor-divided by 4 before the fully-connected head.

    :param num_channels: number of input image channels.
    :param num_classes: number of output logits.
    :param input_shape: (height, width) of the input images.
    """
    def __init__(self, num_channels, num_classes, input_shape=(64, 64)):
        super().__init__()
        self.conv_layer1 = self._conv_layer_set(num_channels, 32)
        self.conv_layer2 = self._conv_layer_set(32, 64)
        # Flattened size after two MaxPool(2): computed per dimension, so
        # non-square inputs and sizes not divisible by 4 are handled correctly.
        flat_shape = 64 * (input_shape[0] // 4) * (input_shape[1] // 4)
        self.fc1 = nn.Linear(flat_shape, 64)
        self.fc2 = nn.Linear(64, num_classes)
        self.drop = nn.Dropout(0.5)
        self.act = nn.ReLU(inplace=False)

    def _conv_layer_set(self, in_c, out_c):
        """Conv3x3 (same padding) -> LeakyReLU -> 2x2 max-pool."""
        conv_layer = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)),
            ('leakyrelu', nn.LeakyReLU()),
            ('maxpool', nn.MaxPool2d(2)),
        ]))
        return conv_layer

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        out = self.conv_layer1(x)
        out = self.conv_layer2(out)
        out = out.view(out.size(0), -1)

        out = self.fc1(out)
        # NOTE(review): self.act is defined but not applied here, so fc1 and
        # fc2 compose into a single linear map (dropout aside). Kept as-is so
        # the reported results stay reproducible; consider enabling it.
        out = self.drop(out)
        out = self.fc2(out)
        return out

# Question 2

Take inspiration from the code you wrote in the previous lab and create your ```train function```. It might be useful to think about having a ```predict``` function too. Check the code of the previous lab if you need ideas!

Do not forget, to train you need an ```optimizer```, ```loss function``` and an instance of your ```model```! If you need more inspiration check [here](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)!

In [None]:
def train_epoch(
        optimizer:torch.optim.Optimizer, 
        loss: torch.nn.Module, 
        model: torch.nn.Module, 
        train_loader: DataLoader
):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to the global ``device``, scored with ``loss`` and
    used for one ``optimizer.step()``.

    :returns: the batch losses weighted by batch size and averaged over all
        samples (the per-sample mean, assuming ``loss`` returns a batch mean).
    """
    model.train(True)
    weighted_sum, seen = 0, 0
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        batch_loss = loss(model(batch_x), batch_y)
        batch_loss.backward()
        optimizer.step()

        n = len(batch_x)
        weighted_sum += batch_loss.item() * n
        seen += n
    return weighted_sum / seen

def validate_epoch(
        loss: torch.nn.Module, 
        model: torch.nn.Module, 
        val_loader: DataLoader
):
    """Evaluate ``model`` on ``val_loader`` in eval mode without gradients.

    :returns: the batch losses weighted by batch size and averaged over all
        samples.
    """
    model.eval()
    weighted_sum, seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)
            n = len(batch_x)
            weighted_sum += loss(logits, batch_y).item() * n
            seen += n
    return weighted_sum / seen


def training_loop(num_epoch, model, optimizer, loss, train_loader, val_loader):
    """Train for ``num_epoch`` epochs and keep the best model by validation loss.

    A deep copy of ``model`` is stored whenever the epoch's validation loss is
    strictly lower than all previous ones. The copy is taken right after
    ``validate_epoch``, which leaves the model in eval mode.

    :returns: ``(best_model, train_losses, val_losses)``; the loss lists hold
        one value per epoch, and ``best_model`` stays None if no epoch ran.
    """
    history_train, history_val = [], []
    lowest_val = np.inf
    snapshot = None
    for epoch_no in range(1, num_epoch + 1):
        epoch_train = train_epoch(optimizer, loss, model, train_loader)
        epoch_val = validate_epoch(loss, model, val_loader)
        history_train.append(epoch_train)
        history_val.append(epoch_val)

        if epoch_val < lowest_val:
            lowest_val = epoch_val
            snapshot = copy.deepcopy(model)
        print(f"epoch {epoch_no}: loss: {epoch_train:0.4f} val loss: {epoch_val:0.4f}")
    return snapshot, history_train, history_val


def predict(model: torch.nn.Module, test_loader: DataLoader):
    """Return ``(y_true, y_pred)`` NumPy arrays of class indices for ``test_loader``.

    The model is switched to eval mode first so dropout is disabled during
    inference; it is left in eval mode afterwards. Predictions are the argmax
    over the last dimension of the model output.
    """
    model.eval()
    true = []
    pred = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            true.append(labels)
            pred.append(outputs.argmax(-1))

    return torch.cat(true).cpu().numpy(), torch.cat(pred).cpu().numpy()


# Question 3

Now that you have your train function. Train the network until it overfits. Which ```hyperparameters``` allowed you to overfit?

## Answer

It's hard to say which ```hyperparameters``` caused the overfitting. The size of the network played a role, since a small network will underfit and will not have enough capacity to memorize the input. The number of training epochs and the choice of optimization algorithm also played a role in overfitting.

After experimenting with the dataset for a while, it appears that a network trained on the un-augmented training distribution has difficulty generalizing to the validation and test data. That's why adding transformations improves the result.

To help you visualize the results we provide a ```confusion matrix function```. 

In [None]:
def matrix(y_true, y_pred, classes):
    """Plot a confusion matrix with one row/column per entry of ``classes``.

    ``labels`` is pinned to ``0 .. len(classes) - 1`` so the matrix always
    matches ``display_labels``. Without it, a class absent from both
    ``y_true`` and ``y_pred`` shrinks the matrix, and plotting typically fails
    with a tick-label count mismatch. Samples whose label falls outside that
    range are ignored.

    :param y_true: ground-truth class indices.
    :param y_pred: predicted class indices.
    :param classes: display names, indexed by class index.
    """
    cf_matrix = confusion_matrix(y_true, y_pred, labels=np.arange(len(classes)))
    disp = ConfusionMatrixDisplay(confusion_matrix=cf_matrix, display_labels=classes)
    disp.plot()

In [None]:
model = SimpleModel(num_channels=3, num_classes=3).to(device)  # 3 shape classes
critereon = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

In [None]:
torch.cuda.empty_cache()  # release cached, unused GPU memory from earlier runs
best_model, train_losses, val_losses = training_loop(
    num_epoch=10,
    model=model,
    optimizer=optimizer,
    loss=critereon,
    train_loader=train_loader,
    val_loader=valid_loader,
)

In [None]:
ax = plt.gca()
for curve, curve_label in ((train_losses, "train"), (val_losses, "val")):
    ax.plot(curve, label=curve_label)
ax.set_title("Cross Entropy Loss")
ax.legend()
plt.show()

In [None]:
test_true, test_pred = predict(best_model, test_loader)
shown_classes = train_loader.dataset.classes[:-1]  # drop 'background' (model has 3 outputs)
matrix(test_true, test_pred, shown_classes)

# Question 4
Go through the [doc](https://pytorch.org/vision/stable/transforms.html) about data augmentation transformations and use some (Try at least 5 augmentations) in your pipeline. Did the ones you try improve your model? Why? 

Along with ```torchvision``` you can also explore ```https://albumentations.ai/``` for advanced augmentation. 

## Answer

The augmentations improve the model in the sense that it no longer overfits. However, it still does not learn as much. They help because the transformations broaden the distribution covered by the training data; as a consequence, the model now has a harder time fitting it.

In [None]:
# Augmentations: one of four fixed right-angle rotations (fill=1 pads with the
# maximum intensity of the [0, 1] tensor, presumably the white background —
# verify), random horizontal/vertical flips, then resize to 64x64.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.RandomChoice([
        torchvision.transforms.RandomRotation((0, 0), fill=1),
        torchvision.transforms.RandomRotation((180, 180), fill=1),
        torchvision.transforms.RandomRotation((90, 90), fill=1),
        torchvision.transforms.RandomRotation((270, 270), fill=1),
    ]),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomVerticalFlip(),
    torchvision.transforms.Resize((64, 64)),
])

# Shuffle like the other training loaders: the image list is sorted by path,
# so without shuffling every epoch sees identical batches in the same order.
# The generator must live on the default device set via set_default_device.
train_loader = DataLoader(
    FlyingObjects(root=training_img_dir, transform=train_transform),
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator(device=device),
)

In [None]:
image_with_labels(train_loader, nrows=3, nimages=15)  # inspect augmented samples

In [None]:
model = SimpleModel(num_channels=3, num_classes=3).to(device)  # fresh weights for the augmented run
critereon = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

In [None]:
torch.cuda.empty_cache()  # release cached, unused GPU memory from earlier runs
best_model, train_losses, val_losses = training_loop(
    num_epoch=20,
    model=model,
    optimizer=optimizer,
    loss=critereon,
    train_loader=train_loader,
    val_loader=valid_loader,
)

In [None]:
ax = plt.gca()
for curve, curve_label in ((train_losses, "train"), (val_losses, "val")):
    ax.plot(curve, label=curve_label)
ax.set_title("Cross Entropy Loss")
ax.legend()
plt.show()

In [None]:
test_true, test_pred = predict(best_model, test_loader)
shown_classes = train_loader.dataset.classes[:-1]  # drop 'background' (model has 3 outputs)
matrix(test_true, test_pred, shown_classes)

# Question 5

Redo the previous questions with an image size of ```128x128```. Make sure to note what changed and why. If you decided to use tensorboard, compare both versions on ```Tensorboard``` plots.

## Answer

The model can now overfit even with the augmentations present. I guess the model has more details to work with and it was able to pick up some of the features from the higher resolution images. 

In [None]:
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # one of four fixed right-angle rotations
    torchvision.transforms.RandomChoice([
        torchvision.transforms.RandomRotation((angle, angle), fill=1)
        for angle in (0, 180, 90, 270)
    ]),
    # pad with fill=1 or center-crop by a few pixels; the final Resize
    # brings the image back to 128x128
    torchvision.transforms.RandomChoice(
        [torchvision.transforms.Pad(pad, fill=1) for pad in (0, 3, 6, 9)]
        + [torchvision.transforms.CenterCrop((128 - crop, 128 - crop)) for crop in (3, 6, 9)]
    ),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomVerticalFlip(),
    torchvision.transforms.Resize((128, 128)),
])

test_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((128, 128)),
])

train_loader = DataLoader(
    FlyingObjects(root=training_img_dir, transform=train_transform),
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator(device=device),
)
valid_loader, test_loader = (
    DataLoader(FlyingObjects(root=split_dir, transform=test_transform), batch_size=batch_size)
    for split_dir in (validation_img_dir, testing_img_dir)
)

In [None]:
image_with_labels(train_loader, nrows=3, nimages=15)  # inspect 128x128 augmented samples

In [None]:
model = SimpleModel(num_channels=3, num_classes=3, input_shape=(128, 128)).to(device)
critereon = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

In [None]:
torch.cuda.empty_cache()  # release cached, unused GPU memory from earlier runs
best_model, train_losses, val_losses = training_loop(
    num_epoch=20,
    model=model,
    optimizer=optimizer,
    loss=critereon,
    train_loader=train_loader,
    val_loader=valid_loader,
)

In [None]:
ax = plt.gca()
for curve, curve_label in ((train_losses, "train"), (val_losses, "val")):
    ax.plot(curve, label=curve_label)
ax.set_title("Cross Entropy Loss")
ax.legend()
plt.show()

In [None]:
test_true, test_pred = predict(best_model, test_loader)
shown_classes = train_loader.dataset.classes[:-1]  # drop 'background' (model has 3 outputs)
matrix(test_true, test_pred, shown_classes)

# Question 6

Once you have a good model for ```128x128``` without ```fine grain``` redo the experiments with ```fine grain```. How did the performance change? And why?

## Answer

The model's performance has degraded. First, the problem is harder: with more classes there are more ways for the model to confuse them. 

Second, the model takes longer to reach a relatively acceptable performance during training. This makes sense, since it also needs to learn to separate the different colors.

## Good 128x128 Model

In [None]:
class DeepModel(torch.nn.Module):
    """Four-block CNN classifier with a single linear head.

    Each conv block is Conv3x3(padding=1) -> LeakyReLU -> MaxPool(2), so each
    spatial dimension is floor-divided by 16 before flattening.

    :param num_channels: number of input image channels.
    :param num_classes: number of output logits.
    :param input_shape: (height, width) of the input images.
    """
    def __init__(self, num_channels, num_classes, input_shape=(64, 64)):
        super().__init__()
        self.conv_layer1 = self._conv_layer_set(num_channels, 16)
        self.conv_layer2 = self._conv_layer_set(16, 32)
        self.conv_layer3 = self._conv_layer_set(32, 64)
        self.conv_layer4 = self._conv_layer_set(64, 128)
        # Flattened size after four MaxPool(2): computed per dimension (the
        # parentheses matter) so sizes not divisible by 16 are handled.
        flat_shape = 128 * (input_shape[0] // 2**4) * (input_shape[1] // 2**4)
        self.fc1 = nn.Linear(flat_shape, num_classes)
        self.drop = nn.Dropout(0.5)

    def _conv_layer_set(self, in_c, out_c):
        """Conv3x3 (same padding) -> LeakyReLU -> 2x2 max-pool."""
        conv_layer = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)),
            ('leakyrelu', nn.LeakyReLU()),
            ('maxpool', nn.MaxPool2d(2)),
        ]))
        return conv_layer

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        out = self.conv_layer1(x)
        out = self.conv_layer2(out)
        out = self.conv_layer3(out)
        out = self.conv_layer4(out)

        out = out.view(out.size(0), -1)

        out = self.drop(out)
        out = self.fc1(out)
        return out

In [None]:
model = DeepModel(num_channels=3, num_classes=3, input_shape=(128, 128)).to(device)
critereon = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

In [None]:
torch.cuda.empty_cache()  # release cached, unused GPU memory from earlier runs
best_model, train_losses, val_losses = training_loop(
    num_epoch=20,
    model=model,
    optimizer=optimizer,
    loss=critereon,
    train_loader=train_loader,
    val_loader=valid_loader,
)

In [None]:
ax = plt.gca()
for curve, curve_label in ((train_losses, "train"), (val_losses, "val")):
    ax.plot(curve, label=curve_label)
ax.set_title("Cross Entropy Loss")
ax.legend()
plt.show()

In [None]:
test_true, test_pred = predict(best_model, test_loader)
shown_classes = train_loader.dataset.classes[:-1]  # drop 'background' (model has 3 outputs)
matrix(test_true, test_pred, shown_classes)

In [None]:
report = classification_report(y_true=test_true, y_pred=test_pred)
print(report)

## Fine Grained Model

In [None]:
# Same 128x128 pipelines, but with the 12 fine-grained shape_colour labels.
train_loader = DataLoader(
    FlyingObjects(root=training_img_dir, transform=train_transform, fine_grained=True),
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator(device=device),
)
valid_loader, test_loader = (
    DataLoader(
        FlyingObjects(root=split_dir, transform=test_transform, fine_grained=True),
        batch_size=batch_size,
    )
    for split_dir in (validation_img_dir, testing_img_dir)
)

In [None]:
model = DeepModel(num_channels=3, num_classes=12, input_shape=(128, 128)).to(device)  # 12 fine classes
critereon = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())

In [None]:
torch.cuda.empty_cache()  # release cached, unused GPU memory from earlier runs
best_model, train_losses, val_losses = training_loop(
    num_epoch=100,
    model=model,
    optimizer=optimizer,
    loss=critereon,
    train_loader=train_loader,
    val_loader=valid_loader,
)

In [None]:
ax = plt.gca()
for curve, curve_label in ((train_losses, "train"), (val_losses, "val")):
    ax.plot(curve, label=curve_label)
ax.set_title("Cross Entropy Loss")
ax.legend()
plt.show()

In [None]:
test_true, test_pred = predict(best_model, test_loader)
# Abbreviate e.g. "square_red" -> "sr" so the 12 tick labels fit.
short_names = [
    "".join(part[0] for part in class_name.split("_"))
    for class_name in train_loader.dataset.classes
]
matrix(test_true, test_pred, short_names)

In [None]:
report = classification_report(y_true=test_true, y_pred=test_pred)
print(report)

# Question 7
Change the model and dataset to predict both shape and color of the flying object separately. Hint: The model may have 2 output heads. One should predict the color and another should predict the shape. Report the changes you made along with the results. 

## Answer

Made the model larger to accommodate two heads. While the performance of the two-headed model on each individual task is fine, combining the results does not make it better than the previous model.

In [None]:
class FlyingObjectsMulti(torch.utils.data.Dataset):
    """Flying Object Dataset with separate shape and colour targets.

    Every ``.png`` file found (recursively) under ``<root>/image`` is one
    sample; the file list is sorted by path. Split on ``_``, field 1 of the
    file name is the shape and field 2 the colour. The target is a
    ``LongTensor`` ``[shape_index, colour_index]``.

    :param root: split directory that contains an ``image`` folder.
    :param transform: optional callable applied to the array returned by
        ``imageio.imread``; without it the array is wrapped as a tensor as-is.
    """
    def __init__(self, root, transform=None):
        super().__init__()
        self.root = root
        self.transform = transform

        search_dir = os.path.expanduser(self.root + '/image')
        self.images = sorted(
            os.path.join(folder, fname)
            for folder, _subdirs, fnames in os.walk(search_dir)
            for fname in fnames
            if fname.endswith(".png")
        )

        self.color_classes = ["red", "green", "blue", "yellow"]
        self.shape_classes = ["square", "triangle", "circular"]
        self.labels = [self.__extract_label(path) for path in self.images]

    def __extract_label(self, image_file):
        """Return ``LongTensor([shape_index, colour_index])`` for ``image_file``.

        :raises IndexError: if the file name has too few ``_`` fields.
        :raises ValueError: if the shape or colour name is unknown (from list.index).
        """
        fields = os.path.basename(image_file).split(".")[0].split("_")
        shape_name, color_name = fields[1], fields[2]
        return torch.LongTensor([
            self.shape_classes.index(shape_name),
            self.color_classes.index(color_name),
        ])

    def __getitem__(self, index):
        """Return ``(image_tensor_as_float, [shape_index, colour_index])``."""
        raw = imageio.imread(self.images[index])
        image = self.transform(raw) if self.transform else torch.from_numpy(raw)
        return image.float(), self.labels[index]

    def __len__(self):
        return len(self.images)

In [None]:
# Multi-target loaders: each label is [shape_index, colour_index].
train_loader = DataLoader(
    FlyingObjectsMulti(root=training_img_dir, transform=train_transform),
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator(device=device),
)
valid_loader, test_loader = (
    DataLoader(FlyingObjectsMulti(root=split_dir, transform=test_transform), batch_size=batch_size)
    for split_dir in (validation_img_dir, testing_img_dir)
)

In [None]:
image_with_labels(train_loader, nrows=3, nimages=15, label_transform=lambda target: target.tolist())

In [None]:
def train_epoch_multi(
        optimizer:torch.optim.Optimizer, 
        losses: List[torch.nn.Module], 
        model: torch.nn.Module, 
        train_loader: DataLoader
):
    """Run one optimisation pass for a multi-head model.

    ``model(inputs)`` must return one output per head. Head ``i`` is scored by
    ``losses[i]`` against column ``i`` of the label tensor, and the objective
    is the sum of the *squared* per-head losses.

    :returns: that objective weighted by batch size and averaged over all samples.
    """
    model.train(True)
    weighted_sum, seen = 0, 0
    for batch_x, batch_y in train_loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        head_outputs = model(batch_x)
        objective = sum(
            (criterion(head_outputs[i], batch_y[:, i]) ** 2 for i, criterion in enumerate(losses)),
            0.0,
        )
        objective.backward()
        optimizer.step()

        n = len(batch_x)
        weighted_sum += objective.item() * n
        seen += n
    return weighted_sum / seen

def validate_epoch_multi(
        losses: List[torch.nn.Module], 
        model: torch.nn.Module, 
        val_loader: DataLoader
):
    """Evaluate a multi-head model in eval mode without gradients.

    Uses the same objective as training: the sum over heads of the squared
    per-head loss.

    :returns: that objective weighted by batch size and averaged over all samples.
    """
    model.eval()
    weighted_sum, seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            head_outputs = model(batch_x)
            objective = sum(
                (criterion(head_outputs[i], batch_y[:, i]) ** 2 for i, criterion in enumerate(losses)),
                0.0,
            )
            n = len(batch_x)
            weighted_sum += objective.item() * n
            seen += n
    return weighted_sum / seen


def training_loop_multi(num_epoch, model, optimizer, losses, train_loader, val_loader):
    """Train a multi-head model for ``num_epoch`` epochs and keep the best copy.

    A deep copy of ``model`` is stored whenever the epoch's validation
    objective is strictly lower than all previous ones.

    :returns: ``(best_model, train_losses, val_losses)``; the loss lists hold
        one value per epoch, and ``best_model`` stays None if no epoch ran.
    """
    history_train, history_val = [], []
    lowest_val = np.inf
    snapshot = None
    for epoch_no in range(1, num_epoch + 1):
        epoch_train = train_epoch_multi(optimizer, losses, model, train_loader)
        epoch_val = validate_epoch_multi(losses, model, val_loader)
        history_train.append(epoch_train)
        history_val.append(epoch_val)

        if epoch_val < lowest_val:
            lowest_val = epoch_val
            snapshot = copy.deepcopy(model)
        print(f"epoch {epoch_no}: loss: {epoch_train:0.4f} val loss: {epoch_val:0.4f}")
    return snapshot, history_train, history_val


def predict_multi(model: torch.nn.Module, test_loader: DataLoader):
    """Return ``(y_true, y_pred)`` NumPy arrays of shape (N, num_heads).

    Column ``i`` of ``y_pred`` is the argmax of head ``i``'s output. The model
    is switched to eval mode first so dropout is disabled during inference;
    it is left in eval mode afterwards.
    """
    model.eval()
    true = []
    pred = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            predictions = torch.stack([output.argmax(-1) for output in outputs], dim=-1)

            true.append(labels)
            pred.append(predictions)

    return torch.cat(true).cpu().numpy(), torch.cat(pred).cpu().numpy()

In [None]:
class DeepMultiModel(torch.nn.Module):
    """Four-block CNN backbone shared by several linear classification heads.

    Each conv block is Conv3x3(padding=1) -> LeakyReLU -> MaxPool(2), so the
    flattened feature size is ``256 * (H // 16) * (W // 16)``.

    :param num_channels: number of input image channels.
    :param head_num_classes: number of output logits for each head, in order.
    :param input_shape: (height, width) of the input images.
    """
    def __init__(self, num_channels: int, head_num_classes: List[int], input_shape=(64, 64)):
        super().__init__()
        self.conv_layer1 = self._conv_layer_set(num_channels, 32)
        self.conv_layer2 = self._conv_layer_set(32, 64)
        self.conv_layer3 = self._conv_layer_set(64, 128)
        self.conv_layer4 = self._conv_layer_set(128, 256)

        # Per-dimension floor division so sizes not divisible by 16 work.
        flat_shape = 256 * (input_shape[0] // 2**4) * (input_shape[1] // 2**4)
        # nn.ModuleList (not a plain list) registers the heads as sub-modules,
        # so their weights appear in model.parameters() (and are trained by the
        # optimizer), move with .to(), and are saved in state_dict().
        self.fc1 = nn.ModuleList(nn.Linear(flat_shape, n) for n in head_num_classes)
        self.drop = nn.Dropout(0.7)
        self.act = nn.ReLU()  # currently unused in forward

    def _conv_layer_set(self, in_c, out_c):
        """Conv3x3 (same padding) -> LeakyReLU -> 2x2 max-pool."""
        conv_layer = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)),
            ('leakyrelu', nn.LeakyReLU()),
            ('maxpool', nn.MaxPool2d(2)),
        ]))
        return conv_layer

    def forward(self, x):
        """Return a list with one logits tensor (batch, n_i) per head."""
        out = self.conv_layer1(x)
        out = self.conv_layer2(out)
        out = self.conv_layer3(out)
        out = self.conv_layer4(out)

        out = out.view(out.size(0), -1)
        out = self.drop(out)

        return [fc(out) for fc in self.fc1]

In [None]:
# Move to `device` explicitly, like the other model cells; one cross-entropy
# loss per head (shape: 3 classes, colour: 4 classes).
model = DeepMultiModel(3, [3, 4], input_shape=(128, 128)).to(device)
critereon = [torch.nn.CrossEntropyLoss(), torch.nn.CrossEntropyLoss()]
optimizer = torch.optim.Adam(model.parameters())

In [None]:
torch.cuda.empty_cache()  # release cached, unused GPU memory from earlier runs
best_model, train_losses, val_losses = training_loop_multi(
    num_epoch=100,
    model=model,
    optimizer=optimizer,
    losses=critereon,
    train_loader=train_loader,
    val_loader=valid_loader,
)

In [None]:
# The multi-head objective is the sum of the *squared* per-head cross-entropy
# losses (see train_epoch_multi / validate_epoch_multi), so label it as such.
plt.plot(train_losses, label="train")
plt.plot(val_losses, label="val")
plt.title("Sum of Squared Cross Entropy Losses")
plt.legend()
plt.show()

In [None]:
test_true, test_pred = predict_multi(best_model, test_loader)
# Column 0 holds the shape head, column 1 the colour head.
for head_idx, head_classes in enumerate(
        [train_loader.dataset.shape_classes, train_loader.dataset.color_classes]):
    matrix(test_true[:, head_idx], test_pred[:, head_idx], head_classes)

In [None]:
# Per-head reports (shape, then colour) ...
for head_idx in (0, 1):
    print(classification_report(test_true[:, head_idx], test_pred[:, head_idx]))
# ... then the joint (shape, colour) label, encoded as shape * 10 + colour.
joint_true = test_true[:, 0] * 10 + test_true[:, 1]
joint_pred = test_pred[:, 0] * 10 + test_pred[:, 1]
print(classification_report(joint_true, joint_pred))