In [None]:
import torch
import torch.nn as nn
from src.models import DecoderTransformer, count_parameters
from src.data import random_values, create_signals, sine, cosine_squared, create_batches
from src.train import train_one_epoch, evaluate

In [None]:
# Length of every generated signal; also handed to the model as its maximum
# sequence length.
MAX_SEQUENCE_LENGTH = 512


def make_model(device):
    """Build a DecoderTransformer with this notebook's fixed hyperparameters.

    Args:
        device: target passed to ``.to(...)`` on the freshly built model.

    Returns:
        The result of calling ``.to(device)`` on the new model.
    """
    hyperparameters = dict(
        output_parameter_count=1,  # a single predicted parameter; raise this to predict several
        d_model=128,
        num_heads=16,
        num_layers=2,
        d_ff=512,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        dropout=0.1,
    )
    model = DecoderTransformer(**hyperparameters)
    return model.to(device)

# Seed PyTorch's global RNG before anything random happens so that a fresh
# Restart & Run All starts from the same generator state every time.
# NOTE(review): this only covers randomness drawn from torch; if the helpers
# in src.data use a different generator, seed that one here as well.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

transformer = make_model(device)
optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
criterion = nn.L1Loss()  # mean absolute error (default "mean" reduction)

# Kept as the cell's last expression so the notebook displays its result.
count_parameters(transformer)

In [None]:
# Spacing between consecutive samples: MAX_SEQUENCE_LENGTH steps of this size
# add up to 0.5 time units.
TIME_STEP = 0.5 / MAX_SEQUENCE_LENGTH
# Bounds passed to random_values when drawing the sine frequencies.
MIN_OMEGA = 85
MAX_OMEGA = 115
BATCH_SIZE = 128


def create_data(count: int):
    """Generate `count` synthetic signals together with their frequencies, batched.

    Each signal is a 0.7 / 0.3 weighted sum of a sine with a randomly drawn
    frequency and a squared cosine whose omega is fixed at 100. Both parts are
    built with the same randomly drawn phases.

    Args:
        count: number of signals to generate.

    Returns:
        Whatever ``create_batches`` produces for ``[signals, frequencies]`` with
        ``BATCH_SIZE``; the frequencies are reshaped to a single column first.
    """
    # Draw order (frequencies first, then phases) is kept on purpose.
    frequencies = random_values(count, MIN_OMEGA, MAX_OMEGA)
    phases = random_values(count, 0, 2 * torch.pi)

    # Arguments common to both create_signals calls.
    shared_signal_kwargs = dict(
        length=MAX_SEQUENCE_LENGTH,
        time_step=TIME_STEP,
        phases=phases,
    )

    sine_part = create_signals(
        omegas=frequencies, signal_function=sine, **shared_signal_kwargs
    )

    fixed_omegas = torch.full((count,), 100.0)
    cosine_part = create_signals(
        omegas=fixed_omegas, signal_function=cosine_squared, **shared_signal_kwargs
    )

    # Weighted average so the signal does not get too complicated for this
    # fairly small model; the cosine part acts like noise.
    mixed_signals = 0.7 * sine_part + 0.3 * cosine_part
    targets = frequencies.reshape(-1, 1)
    return create_batches([mixed_signals, targets], BATCH_SIZE)

In [None]:
EPOCHS = 10
TRAIN_COUNT = 64_000
EVAL_COUNT = 6_400

# Built once, before the loop, so every epoch is scored on the same batches.
eval_batched_values, eval_batched_parameters = create_data(EVAL_COUNT)

for i in range(EPOCHS):
    # A new set of training signals is generated for every epoch.
    batched_values, batched_parameters = create_data(TRAIN_COUNT)
    train_one_epoch(
        transformer,
        optim,
        criterion,
        device,
        batched_values,
        batched_parameters,
    )

    # Checkpoint after each epoch, named by the 1-based epoch number and
    # written relative to the current working directory.
    checkpoint_path = f"{i+1}.pt"
    torch.save(transformer.state_dict(), checkpoint_path)

    evaluation = evaluate(
        transformer,
        criterion,
        device,
        eval_batched_values,
        eval_batched_parameters,
    )
    print(f"Epoch {i + 1}: {evaluation}")