In [None]:
import torch
from src.models import DecoderTransformer
from src.data import random_values, create_signals, sine, add_noise
from src.evaluation import predict_next_values
import matplotlib.pyplot as plt
import seaborn as sns

# Plot styling: seaborn dark grid, plus three named colours picked from the "Paired" palette.
sns.set_style("darkgrid")
cmap = sns.color_palette("Paired")
light_blue, blue, purple = cmap[0], cmap[1], cmap[9]

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Sequence length shared by the model (max_seq_length) and the data windows.
MAX_SEQUENCE_LENGTH = 512

def create_model(model_device):
    """Construct the notebook's DecoderTransformer and move it to ``model_device``.

    The hyperparameters are fixed in this function; only the target device
    is configurable.

    Args:
        model_device: Device (or anything accepted by ``.to``) the model is moved to.

    Returns:
        The result of calling ``.to(model_device)`` on the new DecoderTransformer.
    """
    hyperparameters = dict(
        output_parameter_count=1,
        d_model=256,
        num_heads=16,
        num_layers=3,
        d_ff=512,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        dropout=0.2,
    )
    network = DecoderTransformer(**hyperparameters)
    return network.to(model_device)

In [None]:
# Sampling interval handed to create_signals; MAX_SEQUENCE_LENGTH steps of this
# size add up to 0.5 (the time unit is not specified in this notebook).
TIME_STEP = 0.5 / MAX_SEQUENCE_LENGTH
# Lower and upper bounds passed to random_values when drawing each component's omega.
MIN_OMEGA = 80
MAX_OMEGA = 120
# Second argument to add_noise for the input window
# (presumably the noise scale -- TODO confirm against src.data.add_noise).
NOISE = 0.025

def create_data(count: int):
    """Generate ``count`` averaged three-sine signals, split into input and target windows.

    Three components are built with ``create_signals`` using ``sine``; their
    omegas come from ``random_values(count, MIN_OMEGA, MAX_OMEGA)`` and their
    phases from ``random_values(count, 0, 2 * torch.pi)``. The components are
    averaged, and each signal of length ``2 * MAX_SEQUENCE_LENGTH`` is split
    along dim 1 into two windows of ``MAX_SEQUENCE_LENGTH`` steps.

    Args:
        count: Number of signals requested from the helpers.

    Returns:
        A ``(values, next_value)`` pair: the first window after ``add_noise``
        with ``NOISE``, and the second window left untouched.
    """
    component_count = 3
    total_length = MAX_SEQUENCE_LENGTH * 2

    # All omegas are drawn before any phase, so the sequence of random_values
    # calls is: three omega draws, then three phase draws.
    omegas = [random_values(count, MIN_OMEGA, MAX_OMEGA) for _ in range(component_count)]
    phases = [random_values(count, 0, 2 * torch.pi) for _ in range(component_count)]

    components = [
        create_signals(
            omegas=omega,
            signal_function=sine,
            length=total_length,
            time_step=TIME_STEP,
            phases=phase,
        )
        for omega, phase in zip(omegas, phases)
    ]

    # Left-to-right accumulation starting from the first component, then average.
    first, *rest = components
    signals = sum(rest, first) / component_count

    window = MAX_SEQUENCE_LENGTH
    clean_inputs, targets = signals.split((window, window), dim=1)
    # Only the input window receives noise; the target window stays clean.
    return add_noise(clean_inputs, NOISE), targets

In [None]:
# Rebuild the architecture, then restore the saved weights into it.
best_model = create_model(device)
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU machine also loads on a CPU-only one.
state_dict = torch.load("models/next-value-predictor.pt", map_location=device)
# Evaluation mode: the model is built with dropout=0.2, which must be inactive
# while generating predictions.
best_model.eval()
# Kept as the last expression so the cell still displays the key-matching report.
best_model.load_state_dict(state_dict)

In [None]:
# Number of signals generated for the evaluation below.
EVAL_COUNT = 1024
values, next_values = create_data(count=EVAL_COUNT)

In [None]:
# Illustrate the task on one evaluation signal: the observed window followed by its true continuation.
index = 4
observed = values[index]
# NaN padding of the same length as the observed window shifts the continuation
# so that it is drawn after the observed part on the shared x-axis.
padding = torch.full_like(observed, float("nan"))
continuation = torch.cat((padding, next_values[index]))

plt.plot(observed, color=blue, label="Wave")
plt.plot(continuation, color=light_blue, label="Theoretical continuation")
plt.legend()
plt.xticks([], [])
plt.savefig("next_value_prediction_task.png")
plt.show()

In [None]:
# Forecasts are collected in a preallocated (EVAL_COUNT, MAX_SEQUENCE_LENGTH) buffer, one row per signal.
predictions = torch.empty(EVAL_COUNT, MAX_SEQUENCE_LENGTH)
# Inference only: no_grad stops autograd from recording the rollouts, which saves
# memory and keeps `predictions` free of grad history so it can be plotted later.
with torch.no_grad():
    for i in range(EVAL_COUNT):
        predictions[i] = predict_next_values(best_model, values[i], MAX_SEQUENCE_LENGTH, MAX_SEQUENCE_LENGTH)

In [None]:
# Compare the model's forecast with the true continuation for one evaluation signal.
index = 4
observed = values[index]
# NaN padding of the same length as the observed window shifts both continuations
# so that they are drawn after the observed part on the shared x-axis.
padding = torch.full_like(observed, float("nan"))
true_continuation = torch.cat((padding, next_values[index]))
predicted_continuation = torch.cat((padding, predictions[index]))

plt.plot(observed, color=blue, label="Wave")
plt.plot(true_continuation, color=light_blue, label="Theoretical continuation")
plt.plot(predicted_continuation, color=purple, label="Predicted")
plt.legend()
plt.xticks([], [])
plt.savefig("next_value_prediction_example.png")
plt.show()

In [None]:
# Step numbers 1..MAX_SEQUENCE_LENGTH, used as plot x-values and as running-mean divisors.
sequence_lengths = torch.arange(1, MAX_SEQUENCE_LENGTH + 1)
# Absolute error for every signal and every predicted step.
diffs = (predictions - next_values).abs()
# Per-step absolute error averaged over the EVAL_COUNT signals (dim 0).
avg_loss = diffs.sum(dim=0) / EVAL_COUNT
# Running mean of the per-step error over the first n steps.
rolling_avg_loss = avg_loss.cumsum(dim=0) / sequence_lengths

In [None]:
# Per-step L1 error, averaged over the evaluation signals.
plt.plot(sequence_lengths, avg_loss)
plt.gca().set(
    title="Average L1 loss during the nth step",
    xlabel="Steps",
    ylabel="L1 loss",
)
plt.savefig("next_value_prediction_loss.png")
plt.show()

In [None]:
# Running mean of the L1 error over the first n predicted steps.
plt.plot(sequence_lengths, rolling_avg_loss)
plt.gca().set(
    title="Average L1 loss after n steps",
    xlabel="Steps",
    ylabel="L1 loss",
)
plt.savefig("next_value_prediction_cum_loss.png")
plt.show()