In [None]:
import torch
import torch.nn as nn
from src.models import DecoderTransformer, count_parameters
from src.data import random_values, create_signals, sine, create_batches, add_noise
from src.train import train_one_epoch, evaluate

In [None]:
# Sequence length shared by the model (as max_seq_length) and by the data generation.
MAX_SEQUENCE_LENGTH = 512

def create_model(device):
    """Construct the ``DecoderTransformer`` used in this notebook and move it to ``device``.

    Args:
        device: Target handed to ``.to()`` (the notebook passes a ``torch.device``).

    Returns:
        The result of ``DecoderTransformer(...).to(device)``.
    """
    hyperparameters = dict(
        output_parameter_count=1,
        d_model=256,
        num_heads=16,
        num_layers=3,
        d_ff=512,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        dropout=0.2,
    )
    model = DecoderTransformer(**hyperparameters)
    return model.to(device)

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Model, optimiser and loss consumed by the training loop.
transformer = create_model(device)
optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
criterion = nn.MSELoss()

# Left as the cell's last expression so the notebook displays its value.
count_parameters(transformer)

cpu


1581569

In [None]:
# Sampling interval: MAX_SEQUENCE_LENGTH steps span 0.5 time units.
TIME_STEP = 0.5 / MAX_SEQUENCE_LENGTH
# Lower/upper bounds passed to random_values for every component's omega.
MIN_OMEGA = 80
MAX_OMEGA = 120
BATCH_SIZE = 128
# Noise level handed to add_noise for the model inputs.
NOISE = 0.025

def create_data(count: int, components: int = 3):
    """Generate batched (noisy input window, next clean sample) pairs.

    Each signal is the average of ``components`` sine waves whose omegas come from
    ``random_values(count, MIN_OMEGA, MAX_OMEGA)`` and whose phases come from
    ``random_values(count, 0, 2 * torch.pi)``. Signals are created with
    ``MAX_SEQUENCE_LENGTH + 1`` samples and split along ``dim=1`` into the first
    ``MAX_SEQUENCE_LENGTH`` samples (model input) and the final sample (target).
    Noise is applied to the inputs only; the target stays clean.

    Args:
        count: Number of signals to generate (forwarded to ``random_values``).
        components: How many sine waves are averaged per signal. Defaults to 3.

    Returns:
        Whatever ``create_batches([values, next_value], BATCH_SIZE)`` returns;
        callers in this notebook unpack it into two batched collections.

    Raises:
        ValueError: If ``components`` is smaller than 1.
    """
    if components < 1:
        raise ValueError(f"components must be >= 1, got {components}")

    # One extra sample per signal: it becomes the prediction target after the split.
    length = MAX_SEQUENCE_LENGTH + 1

    # All frequencies are drawn before any phase, so the order of random draws
    # is fixed regardless of how the signals are assembled afterwards.
    omegas = [random_values(count, MIN_OMEGA, MAX_OMEGA) for _ in range(components)]
    phases = [random_values(count, 0, 2 * torch.pi) for _ in range(components)]

    # Accumulate the components left to right, then average them.
    signals = None
    for component_omegas, component_phases in zip(omegas, phases):
        component = create_signals(
            omegas=component_omegas,
            signal_function=sine,
            length=length,
            time_step=TIME_STEP,
            phases=component_phases,
        )
        signals = component if signals is None else signals + component
    signals = signals / components

    values, next_value = signals.split((MAX_SEQUENCE_LENGTH, 1), dim=1)
    values = add_noise(values, NOISE)
    return create_batches([values, next_value], BATCH_SIZE)

In [None]:
EPOCHS = 50
TRAIN_COUNT = 64_000
EVAL_COUNT = 6_400

# A single evaluation set is generated once and reused after every epoch.
eval_batched_values, eval_batched_parameters = create_data(EVAL_COUNT)

for i in range(EPOCHS):
    # New training data is generated for every epoch.
    batched_values, batched_parameters = create_data(TRAIN_COUNT)
    train_one_epoch(
        transformer, optim, criterion, device, batched_values, batched_parameters
    )

    # Checkpoint the weights after each epoch; files are numbered from 1.
    checkpoint_path = f"{i+1}.pt"
    torch.save(transformer.state_dict(), checkpoint_path)

    evaluation = evaluate(
        transformer, criterion, device, eval_batched_values, eval_batched_parameters
    )
    print(f"Epoch {i + 1}: {evaluation}")