In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as f
from math import log, exp
import math
import datetime
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("darkgrid")

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last dimension (``d_model``) is divided evenly between ``num_heads``
    heads. Each head attends independently; the per-head results are then
    concatenated and passed through a final linear projection.
    """

    def __init__(self, d_model, num_heads):
        """Create the query/key/value/output projections.

        Args:
            d_model: Size of the last dimension of the inputs and of the output.
            num_heads: Number of attention heads; must divide ``d_model``.

        Raises:
            AssertionError: If ``d_model`` is not a multiple of ``num_heads``.
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # width of one head

        # Attribute names double as state_dict keys, so they must stay stable.
        self.W_q = nn.Linear(d_model, d_model)  # query projection
        self.W_k = nn.Linear(d_model, d_model)  # key projection
        self.W_v = nn.Linear(d_model, d_model)  # value projection
        self.W_o = nn.Linear(d_model, d_model)  # projection of the merged heads

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Attend from ``Q`` to ``K`` and return the weighted sum of ``V``.

        Args:
            Q, K, V: Per-head tensors as produced by ``split_heads``.
            mask: Optional tensor broadcastable to the score matrix; entries
                equal to 0 are excluded from attention.

        Returns:
            The attention output, with the same leading dimensions as ``Q``.
        """
        scale = math.sqrt(self.d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

        if mask is not None:
            # A large negative score becomes (numerically) zero after softmax.
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = scores.softmax(dim=-1)
        return torch.matmul(weights, V)

    def split_heads(self, x):
        """Reshape ``(batch, seq, d_model)`` into ``(batch, num_heads, seq, d_k)``."""
        batch_size, seq_length, _ = x.size()
        per_head = x.view(batch_size, seq_length, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of ``split_heads``: back to ``(batch, seq, d_model)``."""
        batch_size, _, seq_length, _ = x.size()
        # transpose() yields a non-contiguous tensor; view() needs contiguous memory.
        merged = x.transpose(1, 2).contiguous()
        return merged.view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Project the inputs, attend per head, then merge and project the result.

        Args:
            Q, K, V: Tensors of shape ``(batch, seq, d_model)``.
            mask: Optional attention mask, see ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape ``(batch, seq, d_model)``.
        """
        queries = self.split_heads(self.W_q(Q))
        keys = self.split_heads(self.W_k(K))
        values = self.split_heads(self.W_v(V))

        attended = self.scaled_dot_product_attention(queries, keys, values, mask)
        return self.W_o(self.combine_heads(attended))

class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied along the last dimension."""

    def __init__(self, d_model, d_ff):
        """Create the layers.

        Args:
            d_model: Input and output width.
            d_ff: Width of the hidden layer.
        """
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Expand ``x`` to ``d_ff`` channels, apply ReLU, and project back to ``d_model``."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding that is added to its input.

    Even channels hold ``sin(position * div_term)`` and odd channels hold
    ``cos(position * div_term)``. The table is stored as the non-trainable
    buffer ``pe`` of shape ``(1, max_seq_length, d_model)``.
    """

    def __init__(self, d_model, max_seq_length):
        """Precompute the encoding table.

        Args:
            d_model: Number of channels of the encoding (even or odd).
            max_seq_length: Number of positions to precompute.
        """
        super().__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # For an odd d_model there is one cosine column fewer than sine columns,
        # so the angle table is truncated to match; for an even d_model the
        # slice covers every column and the values are unchanged.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # The buffer name is a state_dict key; keep it stable for saved checkpoints.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding of the first ``x.size(1)`` positions to ``x``.

        The addition relies on broadcasting between ``x`` and the
        ``(1, x.size(1), d_model)`` slice of the table.
        """
        return x + self.pe[:, :x.size(1)]

class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention followed by a feed-forward network.

    Both sub-layers are wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        """Create the sub-layers.

        Args:
            d_model: Channel width of the block's input and output.
            num_heads: Number of self-attention heads.
            d_ff: Hidden width of the feed-forward network.
            dropout: Dropout probability applied to each sub-layer output.
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        # Attribute names (including ``norm3``) are state_dict keys; renaming
        # them would make previously saved checkpoints fail to load.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, tgt_mask):
        """Apply the block to ``x``.

        Args:
            x: Input tensor; used as query, key and value of the self-attention.
            tgt_mask: Attention mask forwarded to the self-attention layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attended = self.self_attn(x, x, x, tgt_mask)
        after_attention = self.norm1(x + self.dropout(attended))

        transformed = self.feed_forward(after_attention)
        return self.norm3(after_attention + self.dropout(transformed))


def generate_mask(tgt):
    """Build a causal ("no peek") attention mask for ``tgt``.

    Args:
        tgt: Tensor whose second dimension is the sequence length.

    Returns:
        Boolean tensor of shape ``(1, seq_length, seq_length)`` on ``tgt``'s
        device; entry ``[0, i, j]`` is ``True`` exactly when ``j <= i``.
    """
    seq_length = tgt.size(1)
    all_pairs = torch.ones(1, seq_length, seq_length, device=tgt.device)
    # Keeping the lower triangle (diagonal included) lets every position see
    # itself and earlier positions only.
    return torch.tril(all_pairs).bool()


class DecoderTransformer(nn.Module):
    """Decoder-only transformer mapping a sequence of scalars to one output vector.

    There is no learned input embedding: every scalar gets a trailing channel
    dimension of size 1 and is broadcast across ``d_model`` channels when the
    positional encoding is added. The result passes through ``num_layers``
    causally masked decoder layers, and the representation at the last
    position is projected to ``output_parameter_count`` values.
    """

    def __init__(self, output_parameter_count, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        """Create the model.

        Args:
            output_parameter_count: Size of the output vector per sequence.
            d_model: Channel width used inside the decoder layers.
            num_heads: Attention heads per decoder layer.
            num_layers: Number of stacked decoder layers.
            d_ff: Hidden width of each layer's feed-forward network.
            max_seq_length: Longest sequence the positional encoding covers.
            dropout: Dropout probability (inputs and every sub-layer).
        """
        super(DecoderTransformer, self).__init__()
        # Attribute names are state_dict keys; keep them stable for checkpoints.
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, output_parameter_count)
        self.dropout = nn.Dropout(dropout)
        self.max_seq_length = max_seq_length

    def forward(self, tgt):
        """Predict the output vector for each sequence in ``tgt``.

        Args:
            tgt: Tensor of shape ``(batch, seq_length)`` with
                ``seq_length <= max_seq_length``.

        Returns:
            Tensor of shape ``(batch, output_parameter_count)``.

        Raises:
            ValueError: If the sequence is longer than ``max_seq_length``.
        """
        seq_length = tgt.size(1)
        if seq_length > self.max_seq_length:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_seq_length {self.max_seq_length}"
            )

        tgt_mask = generate_mask(tgt)
        # Reshape by the actual sequence length (the same length the mask was
        # built from) rather than max_seq_length, so shorter sequences work and
        # rows of the batch can never be merged into each other.
        tgt_embedded = self.dropout(self.positional_encoding(tgt.reshape(-1, seq_length, 1)))

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, tgt_mask)

        # Only the last position is used to produce the prediction.
        output = self.fc(dec_output[:, -1, :])
        return output

def predict_next_values(model: nn.Module, values: torch.Tensor, sequence_length: int, count: int):
    """Autoregressively extend ``values`` by ``count`` model predictions.

    At every step the most recent ``sequence_length`` values (earlier
    predictions included) are fed to ``model`` and its single output is
    appended to the sequence.

    NOTE: a later cell of this notebook defines ``predict_next_values`` again
    (returning only the predictions); once that cell has run, it replaces this
    definition.

    Args:
        model: Module that maps a ``(1, sequence_length)`` tensor to a
            single-element tensor.
        values: Seed sequence; its first dimension must equal
            ``sequence_length`` (presumably 1-D).
        sequence_length: Length of the window fed to the model.
        count: Number of values to predict.

    Returns:
        Tuple ``(inputs, predictions)`` of CPU tensors with lengths
        ``sequence_length`` and ``count``.

    Raises:
        AssertionError: If ``values`` does not have ``sequence_length`` entries.
    """
    assert values.size(0) == sequence_length, f"length of values should be {sequence_length}"

    model_device = next(model.parameters()).device

    all_values = torch.empty(sequence_length + count, device=model_device)
    all_values[:sequence_length] = values.to(model_device)

    # Inference must run in eval mode: in train mode dropout stays active and
    # randomly perturbs every prediction. The previous mode is restored
    # afterwards so callers that are in the middle of training are unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for step in range(count):
                window = all_values[step:step + sequence_length]
                all_values[step + sequence_length] = model(window.reshape(1, -1)).item()
    finally:
        model.train(was_training)

    return all_values.cpu().split((sequence_length, count))

def random_values(count: int, minimum, maximum):
    """Draw ``count`` uniform random samples scaled to the range ``minimum``..``maximum``.

    Samples come from ``torch.rand`` (global generator) and are returned as a
    1-D tensor of length ``count``.
    """
    span = maximum - minimum
    unit_samples = torch.rand(count)
    return unit_samples * span + minimum

def create_signals(omegas: torch.Tensor, signal_function, length: int, time_step: float, phases: torch.Tensor | None = None) -> torch.Tensor:
    count = omegas.size(0)
    waves = torch.empty(count, length)

    if phases is None:
        phases = torch.zeros(count)

    times = torch.arange(0, length) * time_step
    for i in range(count):
        wave = times * omegas[i] + phases[i]
        waves[i] = signal_function(wave)
    
    return waves

def sine(inputs: torch.Tensor):
    """Return the element-wise sine of ``inputs``."""
    return torch.sin(inputs)

def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` are not counted.
    """
    total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            total += parameter.numel()
    return total

def create_batches(unbatched_list: list[torch.Tensor], batch_size: int):
    """Reshape every tensor from ``(N, ...)`` to ``(N // batch_size, batch_size, ...)``.

    The number of batches is derived from the first tensor. When ``N`` is not
    a multiple of ``batch_size`` the trailing ``N % batch_size`` samples are
    dropped, consistent with the floor division used for the batch count.

    Args:
        unbatched_list: Tensors that share the same first dimension ``N``.
        batch_size: Number of samples per batch.

    Returns:
        Tuple with one batched tensor per input tensor, in the same order
        (an empty tuple for an empty input list).

    Raises:
        RuntimeError: If a tensor's first dimension differs from the first
            tensor's, so that it cannot be reshaped into the same batches.
    """
    if not unbatched_list:
        return ()

    sample_count = unbatched_list[0].size(0)
    batch_count = sample_count // batch_size
    remainder = sample_count - batch_count * batch_size

    result = []
    for unbatched in unbatched_list:
        # Trim only the incomplete last batch. A tensor whose length does not
        # match the first one still fails in reshape instead of being silently
        # truncated to fit.
        usable = unbatched[:unbatched.size(0) - remainder]
        result.append(usable.reshape(batch_count, batch_size, *unbatched.shape[1:]))

    return tuple(result)

def add_noise(values: torch.Tensor, noise_strength: float):
    """Return ``values`` plus standard-normal noise scaled by ``noise_strength``.

    The input tensor is not modified; a new tensor is returned.
    """
    gaussian = torch.randn_like(values)
    return values + noise_strength * gaussian

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Number of samples in one model input window (passed to the model as max_seq_length).
MAX_SEQUENCE_LENGTH = 512

def create_model(model_device):
    """Build a freshly initialised ``DecoderTransformer`` on ``model_device``.

    The hyperparameters are fixed here; a checkpoint can only be loaded into
    the returned model if it was saved from a model built with the same values.
    """
    hyperparameters = dict(
        output_parameter_count=1,
        d_model=256,
        num_heads=16,
        num_layers=3,
        d_ff=512,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        dropout=0.2,
    )
    return DecoderTransformer(**hyperparameters).to(model_device)

In [None]:
# Spacing between consecutive samples: MAX_SEQUENCE_LENGTH samples span 0.5 time units.
TIME_STEP = 0.5 / MAX_SEQUENCE_LENGTH
# Bounds of the uniform range from which create_data draws each sine component's omega.
MIN_OMEGA = 80
MAX_OMEGA = 120
# Strength of the Gaussian noise that create_data adds to the input half of each signal.
NOISE = 0.025

def create_data(count: int):
    """Generate ``count`` random signals, split into a noisy input and a clean continuation.

    Every signal is the average of three sine waves whose omega is drawn
    uniformly from ``[MIN_OMEGA, MAX_OMEGA]`` and whose phase is drawn
    uniformly from ``[0, 2 * pi]``. ``2 * MAX_SEQUENCE_LENGTH`` samples are
    generated per signal: the first half gets Gaussian noise of strength
    ``NOISE`` added, the second half is returned unchanged.

    Returns:
        Tuple ``(values, next_value)``, each of shape
        ``(count, MAX_SEQUENCE_LENGTH)``.
    """
    component_count = 3

    # All frequencies are drawn before any phase; changing this order would
    # change the data produced for a given random seed.
    frequencies = [random_values(count, MIN_OMEGA, MAX_OMEGA) for _ in range(component_count)]
    phases = [random_values(count, 0, 2 * torch.pi) for _ in range(component_count)]

    components = [
        create_signals(
            omegas=component_omegas,
            signal_function=sine,
            length=MAX_SEQUENCE_LENGTH * 2,
            time_step=TIME_STEP,
            phases=component_phases,
        )
        for component_omegas, component_phases in zip(frequencies, phases)
    ]

    # Accumulate left to right so floating-point rounding is the same as for
    # a plain ``a + b + c`` expression.
    signals = components[0]
    for component in components[1:]:
        signals = signals + component
    signals = signals / component_count

    clean_inputs, continuation = signals.split((MAX_SEQUENCE_LENGTH, MAX_SEQUENCE_LENGTH), dim=1)
    return add_noise(clean_inputs, NOISE), continuation

In [None]:
def predict_next_values(model: nn.Module, values: torch.Tensor, sequence_length: int, count: int):
    """Autoregressively predict the ``count`` values that follow ``values``.

    At every step the most recent ``sequence_length`` values (earlier
    predictions included) are fed to ``model`` and its single output is
    appended to the sequence.

    NOTE: this re-defines the ``predict_next_values`` from the first cell of
    the notebook. Unlike that version, only the predictions are returned, not
    the ``(inputs, predictions)`` pair.

    Args:
        model: Module that maps a ``(1, sequence_length)`` tensor to a
            single-element tensor.
        values: Seed sequence; its first dimension must equal
            ``sequence_length`` (presumably 1-D).
        sequence_length: Length of the window fed to the model.
        count: Number of values to predict.

    Returns:
        CPU tensor of length ``count`` holding the predictions in order.

    Raises:
        AssertionError: If ``values`` does not have ``sequence_length`` entries.
    """
    assert values.size(0) == sequence_length, f"length of values should be {sequence_length}"

    model_device = next(model.parameters()).device

    all_values = torch.empty(sequence_length + count, device=model_device)
    all_values[:sequence_length] = values.to(model_device)

    # Inference must run in eval mode: in train mode dropout stays active and
    # randomly perturbs every prediction. The previous mode is restored
    # afterwards so callers that are in the middle of training are unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for step in range(count):
                window = all_values[step:step + sequence_length]
                all_values[step + sequence_length] = model(window.reshape(1, -1)).item()
    finally:
        model.train(was_training)

    return all_values.cpu().split((sequence_length, count))[1]

In [None]:
best_model = create_model(device)
# This notebook only runs inference: switch to eval mode so dropout is disabled.
best_model.eval()
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(security): torch.load unpickles the file; only load checkpoints from a
# trusted source (or pass weights_only=True on PyTorch versions that support it).
best_model.load_state_dict(torch.load("../models/next-value-predictor.pt", map_location=device))

In [None]:
EVAL_COUNT = 1_024
EVAL_SEED = 0

# create_data draws from PyTorch's global random generator; seeding it here
# makes the evaluation set (and every figure derived from it) reproducible.
torch.manual_seed(EVAL_SEED)
values, next_values = create_data(EVAL_COUNT)

In [None]:
# Colour palette shared by the plots below, which pick colours from it by index.
cmap = sns.color_palette("Paired")

In [None]:
# Illustrate the task: a noisy input wave followed by its noise-free continuation.
index = 4

# NaN padding shifts the continuation so that it starts where the input ends.
nan_padding = torch.full_like(values[index], float("nan"))

fig, ax = plt.subplots()
ax.plot(values[index], color=cmap[1], label="Wave")
ax.plot(torch.cat((nan_padding, next_values[index])), color=cmap[0], label="Theoretical continuation")
ax.legend()
ax.set_xticks([])
fig.savefig("next_value_prediction_task.png")
plt.show()

In [None]:
# One autoregressive rollout of MAX_SEQUENCE_LENGTH steps per evaluation wave.
predictions = torch.empty(EVAL_COUNT, MAX_SEQUENCE_LENGTH)
for wave_index in range(EVAL_COUNT):
    predictions[wave_index] = predict_next_values(
        model=best_model,
        values=values[wave_index],
        sequence_length=MAX_SEQUENCE_LENGTH,
        count=MAX_SEQUENCE_LENGTH,
    )

In [None]:
# Compare the model's rollout with the noise-free continuation for one wave.
index = 4

# NaN padding shifts both continuations so that they start where the input ends.
nan_padding = torch.full_like(values[index], float("nan"))

fig, ax = plt.subplots()
ax.plot(values[index], color=cmap[1], label="Wave")
ax.plot(torch.cat((nan_padding, next_values[index])), color=cmap[0], label="Theoretical continuation")
ax.plot(torch.cat((nan_padding, predictions[index])), color=cmap[9], label="Predicted")
ax.legend()
ax.set_xticks([])
fig.savefig("next_value_prediction_example.png")
plt.show()

In [None]:
# Step numbers 1..MAX_SEQUENCE_LENGTH (x axis of the loss plots).
sequence_lengths = torch.arange(1, MAX_SEQUENCE_LENGTH + 1)

# Absolute error of every predicted value against the noise-free continuation.
diffs = torch.abs(predictions - next_values)

# Mean absolute error at each rollout step, averaged over the evaluation waves.
avg_loss = diffs.sum(dim=0) / EVAL_COUNT

# Running mean of avg_loss over the first n steps.
rolling_avg_loss = avg_loss.cumsum(dim=0) / sequence_lengths

In [None]:
# Per-step error: how the L1 loss develops the further the rollout gets.
fig, ax = plt.subplots()
ax.plot(sequence_lengths, avg_loss)
ax.set_title("Average L1 loss during the nth step")
ax.set_xlabel("Steps")
ax.set_ylabel("L1 loss")
fig.savefig("next_value_prediction_loss.png")
plt.show()

In [None]:
# Running mean of the per-step L1 loss over the first n rollout steps.
fig, ax = plt.subplots()
ax.plot(sequence_lengths, rolling_avg_loss)
ax.set_title("Average L1 loss after n steps")
ax.set_xlabel("Steps")
ax.set_ylabel("L1 loss")
fig.savefig("next_value_prediction_cum_loss.png")
plt.show()