# Multi-input models

In [1]:
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

from torchvision import transforms

import torch.optim as optim

from torchmetrics import Accuracy

## Two-input dataset

In [5]:
class OmniglotDataset(Dataset):
  """Map-style dataset yielding ``(image, alphabet, label)`` triples.

  Args:
    transform: callable applied to each image after it is converted to
      Pillow mode 'L' (e.g. a torchvision ``Compose``).
    samples: sequence of ``(img_path, alphabet, label)`` tuples. ``alphabet``
      and ``label`` are returned unchanged alongside the transformed image.
  """

  def __init__(self, transform, samples):
    self.transform = transform
    self.samples = samples

  def __len__(self):
    return len(self.samples)

  def __getitem__(self, idx):
    img_path, alphabet, label = self.samples[idx]
    # Open with a context manager so the file handle is closed
    # deterministically rather than relying on Pillow's lazy-load internals.
    # convert() hands back a separate image object, so it remains usable
    # after the source image is closed.
    with Image.open(img_path) as img:
      converted = img.convert('L')
    return self.transform(converted), alphabet, label

## Two-input model

In [6]:
class Net(nn.Module):
  """Two-input classifier fusing image features with an alphabet vector.

  The image branch produces a 128-d embedding and the alphabet branch an
  8-d embedding. The two are concatenated and mapped to 964 output scores.
  """

  def __init__(self):
    super().__init__()
    # Image branch: conv (padding=1 keeps spatial size) -> 2x2 max-pool ->
    # ELU -> flatten -> dense. The 16*32*32 input width means the pooled
    # map must be 32x32, i.e. 64x64 inputs such as the Resize((64, 64))
    # used for the training dataset.
    image_stages = [
      nn.Conv2d(1, 16, kernel_size=3, padding=1),
      nn.MaxPool2d(kernel_size=2),
      nn.ELU(),
      nn.Flatten(),
      nn.Linear(16 * 32 * 32, 128),
    ]
    self.image_layer = nn.Sequential(*image_stages)
    # Alphabet branch: 30 input features -> 8-d embedding.
    # NOTE(review): presumably a one-hot over 30 alphabets — confirm.
    self.alphabet_layer = nn.Sequential(nn.Linear(30, 8), nn.ELU())
    # Head over the concatenated (128 + 8)-d feature vector.
    self.classifier = nn.Sequential(nn.Linear(128 + 8, 964))

  def forward(self, x_image, x_alphabet):
    image_features = self.image_layer(x_image)
    alphabet_features = self.alphabet_layer(x_alphabet)
    fused = torch.cat((image_features, alphabet_features), dim=1)
    return self.classifier(fused)

# Multi-output models

## Two-output Dataset and DataLoader

In [None]:
# Peek at one raw record; OmniglotDataset.__getitem__ unpacks each record as
# (img_path, alphabet, label). Assumes `samples` has at least 101 entries.
print(samples[100])

# Training dataset: each image becomes a tensor and is resized to 64x64, the
# size the 16*32*32 dense layer in Net.image_layer expects after pooling.
dataset_train = OmniglotDataset(
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((64, 64)),
    ]),
    samples=samples,
)

# Shuffled mini-batches of 32 for training.
dataloader_train = DataLoader(
  dataset_train, shuffle=True, batch_size=32
)

## Two-output model architecture

In [None]:
class Net(nn.Module):
  """Two-output classifier: one shared image trunk, two linear heads.

  ``forward`` returns ``(alphabet_scores, character_scores)`` with 30 and
  964 columns respectively.
  """

  def __init__(self):
    super().__init__()
    # Shared trunk: conv -> 2x2 max-pool -> ELU -> flatten -> 128-d dense.
    # 16*32*32 corresponds to 64x64 single-channel inputs halved by pooling.
    trunk = [
      nn.Conv2d(1, 16, kernel_size=3, padding=1),
      nn.MaxPool2d(kernel_size=2),
      nn.ELU(),
      nn.Flatten(),
      nn.Linear(16 * 32 * 32, 128),
    ]
    self.image_layer = nn.Sequential(*trunk)
    # Independent heads reading the same 128-d embedding.
    self.classifier_alpha = nn.Linear(128, 30)
    self.classifier_char = nn.Linear(128, 964)

  def forward(self, x):
    shared = self.image_layer(x)
    return self.classifier_alpha(shared), self.classifier_char(shared)

## Training multi-output model

In [None]:
# Train the two-output network. One CrossEntropyLoss instance scores both
# heads, and the per-head losses are summed with equal weight.
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.05)

num_epochs = 1
for epoch in range(num_epochs):
  for batch in dataloader_train:
    images, labels_alpha, labels_char = batch
    optimizer.zero_grad()
    outputs_alpha, outputs_char = net(images)
    loss_alpha, loss_char = (
      criterion(outputs_alpha, labels_alpha),
      criterion(outputs_char, labels_char),
    )
    loss = loss_alpha + loss_char
    loss.backward()
    optimizer.step()

# Evaluation of multi-output models and loss weighting

## Multi-output evaluation

In [None]:
def evaluate_model(model, dataloader=None):
  """Print and return alphabet and character accuracy of a two-output model.

  Args:
    model: module whose forward takes a batch of images and returns a pair
      ``(alphabet_scores, character_scores)``.
    dataloader: iterable of ``(images, labels_alpha, labels_char)`` batches.
      Defaults to the module-level ``dataloader_test``, looked up at call
      time, so the original one-argument call keeps working.

  Returns:
    Tuple ``(alphabet_accuracy, character_accuracy)`` as computed by
    torchmetrics ``Accuracy``.

  The model's previous train/eval mode is restored on exit (even on error),
  so evaluating between training steps does not leave it in eval mode.
  """
  if dataloader is None:
    dataloader = dataloader_test

  acc_alpha = Accuracy(task="multiclass", num_classes=30)
  acc_char = Accuracy(task="multiclass", num_classes=964)

  was_training = model.training
  model.eval()
  try:
    with torch.no_grad():
      for images, labels_alpha, labels_char in dataloader:
        outputs_alpha, outputs_char = model(images)
        # Predicted class = column index of the highest score in each row.
        acc_alpha(outputs_alpha.argmax(dim=1), labels_alpha)
        acc_char(outputs_char.argmax(dim=1), labels_char)
  finally:
    model.train(was_training)

  alpha_score = acc_alpha.compute()
  char_score = acc_char.compute()
  print(f"Alphabet: {alpha_score}")
  print(f"Character: {char_score}")
  return alpha_score, char_score