# Transformer Application

## Configuration

### Hyperparameters

In [None]:
from CustomLoss import CustomEmbeddingSliceLoss
import torch.nn as nn
import matplotlib.pyplot as plt

# HYPERPARAMETERS
# Values used for the single training run below; the Optuna cell at the end of
# the notebook samples its own learning rate, layer counts, dropout and
# training-loss weights.
BATCH_SIZE = 64
LEARNING_RATE = 0.000082

# Transformer
# NOTE(review): NUM_HEADS presumably has to divide FEATURE_DIM (270) evenly - confirm
# against AnimationTransformer.
NUM_HEADS = 90 # Divisors of 270: {1; 2; 3; 5; 6; 9; 10; 15; 18; 27; 30; 45; 54; 90; 135; 270}
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 5
DROPOUT=0.15

# Loss functions
# Validation and training use different weightings of the four loss terms
# (deep_svg, type, parameters, eos). The validation weights are the same ones
# hard-coded for validation in the hyperparameter-tuning cell.
loss_function_val = CustomEmbeddingSliceLoss(weight_deep_svg=10, weight_type=0.91, weight_parameters=650, weight_eos=2.7)
loss_function_train = CustomEmbeddingSliceLoss(weight_deep_svg=10,
                                               weight_type=8,
                                               weight_parameters=95,
                                               weight_eos=1.3)
# CONSTANTS
# Passed to AnimationTransformer as dim_model.
FEATURE_DIM = 270

## Load Prepared Tensors from Disk
Run file `prototype_dataset.ipynb` first

In [None]:
import torch

train_sequence_input = torch.load('data/prototype_dataset/train_sequence_input_len_20.pt')
train_sequence_output = torch.load('data/prototype_dataset/train_sequence_output_len_20.pt')
test_sequence_input = torch.load('data/prototype_dataset/test_sequence_input_len_20.pt')
test_sequence_output = torch.load('data/prototype_dataset/test_sequence_output_len_20.pt')

In [None]:
from prototype_dataset_helper import warn_if_contains_NaN

warn_if_contains_NaN(train_sequence_input)
warn_if_contains_NaN(train_sequence_output)
warn_if_contains_NaN(test_sequence_input)
warn_if_contains_NaN(test_sequence_output)

In [None]:
print(train_sequence_input.size())
print(train_sequence_output.size())
print(test_sequence_input.size())
print(test_sequence_output.size())

## Build Dataloader with Batches

In [None]:
from torch.utils.data import DataLoader, TensorDataset

# Wrap the (input, output) tensor pairs into batched loaders; tensors are cast
# to float32 first.
# Training batches are reshuffled every epoch; validation order stays fixed.
train_dataloader = DataLoader(TensorDataset(train_sequence_input.float(), train_sequence_output.float()),
                              batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# NOTE(review): drop_last=True also discards the final incomplete batch of the
# validation data, so up to BATCH_SIZE - 1 test samples are never evaluated -
# confirm this is intended.
val_dataloader = DataLoader(TensorDataset(test_sequence_input.float(), test_sequence_output.float()),
                            batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
print(f"Train batches: {len(train_dataloader)}\n"
      f"Validation batches: {len(val_dataloader)}")

## Initialize

In [None]:
# Enable autograd anomaly detection. This is a debugging aid: it adds checks to
# every backward pass and slows training down, so switch it off once the
# NaN/inf issues it is meant to catch are resolved.
torch.autograd.set_detect_anomaly(True)
print(torch.__version__)
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Uncomment to force CPU execution:
#device = "cpu"
print(device)

In [None]:
from torch import nn
from AnimationTransformer import AnimationTransformer

# Build the transformer from the hyperparameters defined at the top of the
# notebook and move it to the selected device.
model = AnimationTransformer(
    dim_model=FEATURE_DIM,
    num_heads=NUM_HEADS,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    dropout_p=DROPOUT,
    use_positional_encoder=False  # No improvement
).to(device)

# optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Count only parameters that receive gradients, so the number matches the
# "trainable" wording of the message even if some parameters are frozen.
total_param = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"The model has {total_param} trainable parameters")

## Training

In [None]:
from AnimationTransformer import fit

train_loss_list, validation_loss_list, variance_list = fit(model,
                                            optimizer,
                                            loss_function_train,
                                            loss_function_val,
                                            train_dataloader,
                                            val_dataloader,
                                            epochs=10,
                                            device=device)

In [None]:
# Define the number of additional epochs you want to train for
additional_epochs = 5

# Continue training the model for more epochs
new_train_loss, new_validation_loss, new_variance_list = fit(model,
                                          optimizer,
                                          loss_function_train,
                                          loss_function_val,
                                          train_dataloader,
                                          val_dataloader,
                                          epochs=additional_epochs,
                                          device=device)

# Extend the original loss lists with the new loss values
train_loss_list.extend(new_train_loss)
validation_loss_list.extend(new_validation_loss)
variance_list.extend(new_variance_list)

In [None]:
def print_for_excel(list, title=""):
    """Print a series of numbers as one semicolon-separated row for Excel.

    Each value is formatted with three decimals and a decimal comma. The row
    starts with ``title`` followed by a semicolon.

    Args:
        list: Iterable of numbers to print (the name shadows the builtin but
            is kept for compatibility with existing callers).
        title: Label written in the first cell of the row.
    """
    cells = []
    for element in list:
        # Decimal comma instead of decimal point.
        cells.append(f"{element:.3f}".replace('.', ','))
    print(f"{title};", ";".join(cells))

In [None]:
from CreativityLoss import dict_list_to_list_dict

# Per-epoch losses, one Excel row each.
print_for_excel(train_loss_list, title="Train Loss")
print_for_excel(validation_loss_list, title="Validation Loss")

# Turn the per-epoch list of metric dicts into one list per metric.
print_dict = dict_list_to_list_dict(variance_list)

# Metrics are printed in this order. The aggregate "batch_variance" and
# "sequence_variance" entries are intentionally left out.
for _metric_key, _metric_title in (
    ("val_loss_on_train", "Validation loss on train data"),
    ("batch_variance_deep_svg", "Batch Variance (deep_svg)"),
    ("batch_variance_type", "Batch Variance (type)"),
    ("batch_variance_parameters", "Batch Variance (parameter)"),
    ("batch_variance_eos", "Batch Variance (eos)"),
    ("sequence_variance_deep_svg", "Sequence Variance (deep_svg)"),
    ("sequence_variance_type", "Sequence Variance (type)"),
    ("sequence_variance_parameters", "Sequence Variance (parameter)"),
    ("sequence_variance_eos", "Sequence Variance (eos)"),
):
    print_for_excel(print_dict[_metric_key], title=_metric_title)

## Training and Validation Loss Plot

In [None]:
# Plot the training and validation loss, one point per epoch
plt.figure(figsize=(10, 6))
plt.plot(train_loss_list, label='Training Loss')
plt.plot(validation_loss_list, label='Validation Loss')

# Add title and labels
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')

# Add a legend
plt.legend()

# Show the plot
plt.show()

# Prediction

In [None]:
# Start-of-sequence token: a tensor of zeros with 270 elements
# (same size as FEATURE_DIM)
sos_token = torch.zeros(270)
# Set the last element (index 269) to 1
sos_token[269] = 1

In [None]:
examples = torch.load('data/prototype_dataset/examples.pt')

In [None]:
from AnimationTransformer import predict

predict(model, examples[30], sos_token=sos_token, device=device, max_length=10, eos_threshold=0.5)

In [None]:
import prototype_dataset_helper
from AnimationTransformer import create_pad_mask


def predict(model, source_sequence, sos_token: torch.Tensor, device, max_length=32, eos_threshold=0.5, silent=False, random=False):
    """Generate an animation sequence for one source sequence, token by token.

    Starting from ``sos_token``, the model is called repeatedly on the growing
    target sequence. The last output row of every step is split with
    ``prototype_dataset_helper.unpack_embedding`` into a path embedding,
    animation-type logits, animation parameters and EOS logits, and then
    turned into a clean token:

    * the path part is replaced by the closest path of ``source_sequence``
      (Euclidean distance over all features except the last 14),
    * the type becomes a one-hot vector (arg-max, or a sample if ``random``),
    * parameters not listed for the chosen type in
      ``prototype_dataset_helper.ANIMATION_PARAMETER_INDICES`` are set to -1,
    * the last two features are set to ``[1, 0]``.

    Args:
        model: Called as ``model(src, tgt, src_key_padding_mask=...)``; put
            into eval mode here.
        source_sequence: Source tensor. Assumed to be 2-D
            ``(sequence, features)`` because it is indexed row-wise and
            un-squeezed to a batch of one.
        sos_token: 1-D start token. It is also appended as the final row when
            the end of the animation is predicted.
        device: Device on which the prediction runs.
        max_length: Maximum number of generated tokens.
        eos_threshold: Generation stops as soon as the soft-maxed EOS
            probability (index 1 of the EOS logits) exceeds this value.
        silent: Suppress the progress prints when true.
        random: Sample the animation type from the softmax distribution
            instead of taking the most likely type.

    Returns:
        Tensor whose first row is ``sos_token`` followed by the generated
        tokens; if EOS was predicted, ``sos_token`` is appended as last row.
    """
    model.eval()

    source_sequence = source_sequence.float().to(device)
    sos_token = sos_token.float().to(device)
    y_input = sos_token.unsqueeze(0)

    # The source does not change between steps, so build its batch view,
    # padding mask and path embeddings once instead of once per token.
    source_batch = source_sequence.unsqueeze(0)  # un-squeeze for batch
    src_key_padding_mask = create_pad_mask(source_batch).to(device)
    source_paths = source_sequence[:, :-14]
    continue_flag = torch.tensor([1.0, 0.0], device=device)

    # Pure inference: no autograd graph is needed, and building one per step
    # only costs memory (especially with anomaly detection enabled).
    with torch.no_grad():
        for _ in range(max_length):
            prediction = model(source_batch, y_input.unsqueeze(0),  # un-squeeze for batch
                               src_key_padding_mask=src_key_padding_mask)

            next_embedding = prediction[0, -1, :]  # prediction on last token
            pred_deep_svg, pred_type, pred_parameters, pred_eos = (
                part.to(device)
                for part in prototype_dataset_helper.unpack_embedding(next_embedding, dim=0)
            )

            # === SOFTMAX ===
            type_softmax = torch.softmax(pred_type, dim=0)
            if random:
                animation_type = int(torch.multinomial(type_softmax, 1))
            else:
                animation_type = int(torch.argmax(type_softmax, dim=0))

            eos_softmax = torch.softmax(pred_eos, dim=0)
            if not silent:
                print(f"EOS: {eos_softmax[1] * 100:.1f}%")

            # Stop if EOS is likely enough
            if eos_softmax[1] > eos_threshold:
                if not silent:
                    print("END OF ANIMATION")
                return torch.cat((y_input, sos_token.unsqueeze(0)), dim=0)

            # One-hot type vector with the same width as the predicted slice
            one_hot_type = torch.zeros_like(type_softmax)
            one_hot_type[animation_type] = 1

            # === DEEP SVG ===
            # Find the closest path (first one wins on ties)
            distances = torch.norm(source_paths - pred_deep_svg, dim=1)
            closest_index = int(torch.argmin(distances))
            closest_token = source_sequence[closest_index]

            # === PARAMETERS ===
            # Overwrite parameters the chosen animation type does not use
            used_parameters = prototype_dataset_helper.ANIMATION_PARAMETER_INDICES[animation_type]
            for j in range(len(pred_parameters)):
                if j not in used_parameters:
                    pred_parameters[j] = -1

            # === SEQUENCE ===
            y_new = torch.cat([closest_token[:-14],
                               one_hot_type,
                               pred_parameters,
                               continue_flag],
                              dim=0)
            y_input = torch.cat((y_input, y_new.unsqueeze(0)), dim=0)

            # === INFO PRINT ===
            if not silent:
                # The type probability is scaled to percent to match its "%" unit.
                print(f"{int(y_input.size(0))}: Path {closest_index} ({round(float(distances[closest_index]), 3)}) "
                      f"got animation {animation_type} ({float(type_softmax[animation_type]) * 100:.1f}%) "
                      f"with parameters {[round(num, 2) for num in pred_parameters.tolist()]}")

    return y_input

In [None]:
def get_predicting_sequence_length(model, input_sequences, sos_token, device, max_length=10, eos_threshold=0.5):
    """Predict every input sequence and print animation-count statistics.

    For each sequence ``predict`` is run silently with random type sampling.
    Printed per sequence: the number of animations (sum of the second-to-last
    feature over all rows). Printed at the end: the average number of
    animations per sequence, the per-position sum of the type slots
    (features ``-14:-8``) and, per type, the number of sequences in which
    that type slot was non-zero at least once.

    Args:
        model: Model handed to ``predict``.
        input_sequences: Indexable collection of source sequences.
        sos_token: Start token handed to ``predict``.
        device: Device for prediction and accumulation.
        max_length: Maximum number of generated tokens per sequence.
        eos_threshold: EOS probability threshold handed to ``predict``.
    """
    num_sequences = len(input_sequences)
    if num_sequences == 0:
        # Nothing to average over; avoid a division by zero below.
        print("No input sequences given")
        return

    num_types = 6  # width of the type slice [-14:-8]
    # predict() returns at most max_length + 1 rows: the start token plus
    # max_length generated tokens, or fewer tokens plus the end token.
    animation_types_all = torch.zeros(max_length + 1, num_types, device=device)
    animation_types_div = torch.zeros(num_types, device=device)
    total_animations = 0.0

    for i in range(num_sequences):
        sequence = predict(model, input_sequences[i], sos_token=sos_token, device=device,
                           max_length=max_length, silent=True, eos_threshold=eos_threshold, random=True)
        types = sequence[:, -14:-8]
        # Sequences that stop early are shorter than the accumulator, so only
        # add to the rows that exist.
        animation_types_all[:types.size(0)] += types
        animation_types_div += (types.sum(dim=0) != 0).float()

        num_animations = float(sequence[:, -2].sum())
        total_animations += num_animations
        print(f"{i}: {num_animations} Animations")

    print(f"In average {total_animations / num_sequences} per sequence with eos_threshold {eos_threshold}")
    print(animation_types_all)
    print(animation_types_div)


get_predicting_sequence_length(model, examples, device=device, sos_token=sos_token, eos_threshold=0.99999, max_length=8)

In [None]:
# get_predicting_sequence_length(model, examples, device=device, sos_token=sos_token, eos_threshold=0.6)
get_predicting_sequence_length(model, examples, device=device, sos_token=sos_token, eos_threshold=0.7)
# get_predicting_sequence_length(model, examples, device=device, sos_token=sos_token, eos_threshold=0.8)
# get_predicting_sequence_length(model, examples, device=device, sos_token=sos_token, eos_threshold=0.9)

In [None]:
# get_predicting_sequence_length(model, examples, device=device, sos_token=sos_token, max_length=1, eos_threshold=0.99999)
# get_predicting_sequence_length(model, examples, device=device, sos_token=sos_token, max_length=2, eos_threshold=0.99999)
# get_predicting_sequence_length(model, examples, device=device, sos_token=sos_token, max_length=3, eos_threshold=0.99999)
# get_predicting_sequence_length(model, examples, device=device, sos_token=sos_token, max_length=4, eos_threshold=0.99999)
# get_predicting_sequence_length(model, examples, device=device, sos_token=sos_token, max_length=5, eos_threshold=0.99999)
# get_predicting_sequence_length(model, examples, device=device, sos_token=sos_token, max_length=6, eos_threshold=0.99999)
# get_predicting_sequence_length(model, examples, device=device, sos_token=sos_token, max_length=7, eos_threshold=0.99999)
get_predicting_sequence_length(model, examples[:2], device=device, sos_token=sos_token, max_length=8, eos_threshold=0.99999)

In [None]:
# Sequences with long input
test_sequence_input[220, :, 0]

In [None]:
# check sequence length / embedding
print(test_sequence_output[222, :, -14:])
print(test_sequence_input [222, :, -14:])

In [None]:
# torch.save(model.state_dict(), "data/prototype_transformer.pth")

# Hyperparameter Tuning

In [None]:
from AnimationTransformer import validation_loop, train_loop, creativity_loop
import optuna
from torch.utils.data import DataLoader

MAX_EPOCHS = 10

def objective(trial):
    """Optuna objective: train one model and return its last validation loss.

    Samples learning rate, encoder/decoder depth, dropout and the three
    training-loss weights from the trial, trains for up to ``MAX_EPOCHS``
    epochs, reports the validation loss to Optuna after every epoch (so that
    the trial can be pruned) and finally prints all collected series as Excel
    rows.

    Args:
        trial: The ``optuna`` trial providing the hyperparameter suggestions.

    Returns:
        The validation loss of the last completed epoch (``-1`` if no epoch
        ran).

    Raises:
        optuna.exceptions.TrialPruned: If Optuna decides to prune the trial.
    """
    # --- Search space (batch size, number of heads and the positional
    # --- encoder are currently fixed and not part of the search) ---
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-4, log=True)
    num_encoder_layers = trial.suggest_categorical('num_encoder_layers', [3, 4])
    num_decoder_layers = trial.suggest_categorical('num_decoder_layers', [5, 6])
    dropout = trial.suggest_float('dropout', 0.1, 0.3)

    loss_weight_type = trial.suggest_float('loss_weight_type', 0.1, 100, log=True)
    loss_weight_param = trial.suggest_float('loss_weight_param', 0.1, 100, log=True)
    loss_weight_eos = trial.suggest_float('loss_weight_eos', 0.1, 10, log=True)

    sampled_values = (learning_rate, num_encoder_layers, num_decoder_layers, dropout,
                      loss_weight_type, loss_weight_param, loss_weight_eos)
    print('Parameters selected')
    print('learning_rate; num_encoder_layers; num_decoder_layers; dropout; loss_weight_type; loss_weight_param; loss_weight_eos')
    print('; '.join(str(value) for value in sampled_values))

    # --- Model, optimizer and losses for this trial ---
    trial_model = AnimationTransformer(
        dim_model=FEATURE_DIM,
        num_heads=NUM_HEADS,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dropout_p=dropout,
        use_positional_encoder=False
    ).to(device)

    trial_optimizer = torch.optim.Adam(trial_model.parameters(), lr=learning_rate)

    # The validation loss keeps fixed weights; only the training loss uses
    # the sampled weights.
    validation_criterion = CustomEmbeddingSliceLoss(weight_deep_svg=10,
                                                    weight_type=0.91,
                                                    weight_parameters=650,
                                                    weight_eos=2.7)
    train_criterion = CustomEmbeddingSliceLoss(weight_deep_svg=10,
                                               weight_type=loss_weight_type,
                                               weight_parameters=loss_weight_param,
                                               weight_eos=loss_weight_eos)

    train_losses = []
    validation_losses = []
    variances = []

    # --- Training with per-epoch reporting so Optuna can prune ---
    last_validation_loss = -1
    for epoch in range(MAX_EPOCHS):
        print(f' =========== EPOCH {epoch} ===========')

        epoch_train_loss = train_loop(trial_model, trial_optimizer, train_criterion, train_dataloader, device)
        train_losses.append(epoch_train_loss)

        last_validation_loss = validation_loop(trial_model, validation_criterion, val_dataloader, device)
        validation_losses.append(last_validation_loss)

        variances.append(creativity_loop(trial_model, val_dataloader, device))

        print(f'Train Loss: {epoch_train_loss:.4f}, Validation Loss: {last_validation_loss:.4f}')

        trial.report(last_validation_loss, epoch)
        if trial.should_prune():
            print(f"PRUNING IN EPOCH {epoch}")
            raise optuna.exceptions.TrialPruned()

    # --- Summary ---
    # Despite the label, this is the loss of the last epoch, not the minimum.
    print(f'Best validation loss: {last_validation_loss}')
    print('loss_weight_type; loss_weight_param; loss_weight_eos')
    print(f'{loss_weight_type}; {loss_weight_param}; {loss_weight_eos}'.replace('.', ','))
    print_for_excel(train_losses, title="Train Loss")
    print_for_excel(validation_losses, title="Validation Loss")

    variance_series = dict_list_to_list_dict(variances)
    # The aggregate "batch_variance" and "sequence_variance" entries are
    # intentionally not printed.
    excel_rows = (
        ("batch_variance_deep_svg", "Batch Variance (deep_svg)"),
        ("batch_variance_type", "Batch Variance (type)"),
        ("batch_variance_parameters", "Batch Variance (parameter)"),
        ("batch_variance_eos", "Batch Variance (eos)"),
        ("sequence_variance_deep_svg", "Sequence Variance (deep_svg)"),
        ("sequence_variance_type", "Sequence Variance (type)"),
        ("sequence_variance_parameters", "Sequence Variance (parameter)"),
        ("sequence_variance_eos", "Sequence Variance (eos)"),
    )
    for series_key, row_title in excel_rows:
        print_for_excel(variance_series[series_key], title=row_title)

    return last_validation_loss

Optuna study names used so far:
- pick_and_animate_from_8     First Run
- pick_and_animate_from_8_v3  First Main Run
- pick_and_animate_from_8_loss_optimization
- pick_and_animate_from_8_loss_optimization_normalized
- pick_and_animate_from_8_lo_correction       corrected dataset

In [None]:
# Create the Optuna study, or resume it: with load_if_exists=True an existing
# study of the same name in the SQLite storage is loaded instead of raising.
my_study = optuna.create_study(
    direction='minimize',
    study_name='pick_and_animate_from_8_lo_correction', # IMPORTANT: change the name when using a new dataset
    storage='sqlite:///animate_svg_optuna.db',
    load_if_exists=True
)

In [None]:
my_study.optimize(objective, n_trials=200)

In [None]:
print("Best trial:")
trial = my_study.best_trial
print(f"  Value: {trial.value}")
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

In [None]:
from optuna.visualization import plot_optimization_history, plot_param_importances

plot_optimization_history(my_study)

In [None]:
plot_param_importances(my_study)

In [None]:
from optuna.visualization import plot_slice

plot_slice(my_study)

In [None]:
from optuna.visualization import plot_timeline

plot_timeline(my_study)