# Transformer Application

## Configuration

### Hyperparameters

In [None]:
from CustomLoss import CustomEmbeddingSliceLoss
import torch.nn as nn
import matplotlib.pyplot as plt

# HYPERPARAMETERS
BATCH_SIZE = 32
LEARNING_RATE = 0.01

# Transformer architecture
# NUM_HEADS is passed to the model together with dim_model=FEATURE_DIM (270).
# Standard multi-head attention requires the model dimension to be divisible by
# the head count, so pick one of the divisors of 270:
# {1, 2, 3, 5, 6, 9, 10, 15, 18, 27, 30, 45, 54, 90, 135, 270}
NUM_HEADS = 6
NUM_ENCODER_LAYERS = 8
NUM_DECODER_LAYERS = 8
DROPOUT = 0.1

# Training components
# The optimizer (SGD) is created after the model is built, in the Initialize section.
loss_function = CustomEmbeddingSliceLoss()

### Constants

In [None]:
# CONSTANTS
# Width of the feature vectors; passed to AnimationTransformer as dim_model.
FEATURE_DIM: int = 270

## Load Prepared Tensors from Disk
Run `prototype_dataset.ipynb` first; it generates the tensor files loaded below.

In [None]:
import torch

def _load_prepared(path):
    """Load one tensor file written by prototype_dataset.ipynb.

    The path is relative to the current working directory. torch.load
    unpickles its input, so only point it at files you produced yourself.
    """
    return torch.load(path)


train_sequence_input = _load_prepared('data/prototype_dataset/train_sequence_input.pt')
train_sequence_output = _load_prepared('data/prototype_dataset/train_sequence_output.pt')
test_sequence_input = _load_prepared('data/prototype_dataset/test_sequence_input.pt')
test_sequence_output = _load_prepared('data/prototype_dataset/test_sequence_output.pt')

In [None]:
from dataset_helper import warn_if_contains_NaN

# Sanity-check every loaded split for NaN values before training on it.
for _tensor in (
    train_sequence_input,
    train_sequence_output,
    test_sequence_input,
    test_sequence_output,
):
    warn_if_contains_NaN(_tensor)
del _tensor  # don't leave the loop variable in the notebook namespace

In [None]:
# Show the shape of every split. Each line is labelled so that a mix-up
# between splits is visible. (Previously the test output shape was printed
# twice and the test input shape never appeared.)
print("train_sequence_input: ", train_sequence_input.size())
print("train_sequence_output:", train_sequence_output.size())
print("test_sequence_input:  ", test_sequence_input.size())
print("test_sequence_output: ", test_sequence_output.size())

## Build Dataloader with Batches

In [None]:
from torch.utils.data import DataLoader, TensorDataset

def _make_loader(inputs, targets, *, shuffle):
    """Cast an (input, target) tensor pair to float32 and wrap it in a batched DataLoader."""
    dataset = TensorDataset(inputs.float(), targets.float())
    return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, drop_last=True)


# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_dataloader = _make_loader(train_sequence_input, train_sequence_output, shuffle=True)
# NOTE(review): drop_last=True here skips the final partial batch (up to
# BATCH_SIZE - 1 test samples). It yields no batches at all if the test split
# is smaller than BATCH_SIZE. Confirm that is intended for validation.
val_dataloader = _make_loader(test_sequence_input, test_sequence_output, shuffle=False)

## Initialize

In [None]:
# Enable anomaly detection
# Turn on autograd anomaly detection. A failing backward pass then reports the
# traceback of the forward op that created it, and NaNs produced during backward
# raise an error. This is a debugging aid that slows training noticeably, so
# consider switching it off once the loss is stable.
torch.autograd.set_detect_anomaly(mode=True)
# Echo the installed PyTorch version as the cell output.
torch.__version__

In [None]:
# Use the first CUDA GPU when one is present, otherwise fall back to the CPU.
# To force CPU while debugging, set device = torch.device("cpu") instead.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
from torch import nn
from AnimationTransformer import AnimationTransformer

# Build the transformer from the hyperparameters above and move it to the chosen device.
model = AnimationTransformer(
    dim_model=FEATURE_DIM,
    num_heads=NUM_HEADS,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    dropout_p=DROPOUT,
    # NOTE(review): presumably the longest sequence the positional encoding
    # supports. Confirm it covers the dataset's sequence length.
    pos_encoder_max_len=10
).to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

# Report the total parameter count and, separately, the parameters that
# require gradients. Previously every parameter was reported as "trainable"
# without checking requires_grad.
total_param = sum(p.numel() for p in model.parameters())
trainable_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"The model has {total_param} parameters ({trainable_param} trainable)")

## Training

In [None]:
from AnimationTransformer import fit

# First training run; fit returns the recorded training and validation losses.
initial_epochs = 10
train_loss_list, validation_loss_list = fit(
    model,
    optimizer,
    loss_function,
    train_dataloader,
    val_dataloader,
    epochs=initial_epochs,
    device=device,
)

In [None]:
# Define the number of additional epochs you want to train for
# Resume training the same model and optimizer for a few more epochs, then
# append the new loss history to the running lists so the plot shows the
# complete run.
additional_epochs = 5

new_train_loss, new_validation_loss = fit(
    model,
    optimizer,
    loss_function,
    train_dataloader,
    val_dataloader,
    epochs=additional_epochs,
    device=device,
)

train_loss_list.extend(new_train_loss)
validation_loss_list.extend(new_validation_loss)

## Training and Validation Loss Plot

In [None]:
import matplotlib.pyplot as plt

# Plot the training and validation loss
plt.figure(figsize=(10, 6))
plt.plot(train_loss_list, label='Training Loss')
plt.plot(validation_loss_list, label='Validation Loss')

# Add title and labels
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')

# Add a legend
plt.legend()

# Show the plot
plt.show()

# Prediction

In [None]:
# Start-of-sequence token: a FEATURE_DIM-wide vector of zeros with a single
# element set to 1. It reuses FEATURE_DIM rather than a hard-coded 270, so the
# token stays in sync with the model width.
# NOTE(review): SOS_INDEX presumably marks a dedicated "start" channel in the
# feature layout. Confirm against prototype_dataset.ipynb.
SOS_INDEX = 256
sos_token = torch.zeros(FEATURE_DIM)
sos_token[SOS_INDEX] = 1

In [None]:
from AnimationTransformer import predict

# Run the predict helper on one held-out test input; the result is shown as the cell output.
prediction_input = test_sequence_input[45]
predict(
    model,
    prediction_input,
    sos_token=sos_token,
    device=device,
    max_length=5,
)

In [None]:
# Display one test input sequence for manual inspection.
# NOTE(review): this shows sample 50, while the prediction cell above uses
# sample 45. Confirm which index is intended, and whether test_sequence_output
# was meant for comparison with the prediction.
inspect_index = 50
test_sequence_input[inspect_index]