In [None]:
import torch
import torch.nn as nn

from src.transformer import DecoderTransformer
from src.data import random_values, create_signals, sine, cosine_squared, create_batches
from src.train_eval import train_one_epoch, evaluate

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Seed torch's global RNG (CPU and CUDA) before the model is built, so weight
# initialisation and dropout are repeatable under Restart & Run All.
# NOTE(review): random_values is assumed to draw from torch's RNG as well - confirm;
# if it uses a different generator, that generator must be seeded too.
SEED = 0
torch.manual_seed(SEED)

# --- Signal / dataset configuration ---
MAX_SEQUENCE_LENGTH = 512                # samples per signal; also the model's max_seq_length
TIME_STEP = 0.5 / MAX_SEQUENCE_LENGTH    # MAX_SEQUENCE_LENGTH steps cover 0.5 time units
MIN_OMEGA = 85                           # lower bound handed to random_values for the sine omegas
MAX_OMEGA = 115                          # upper bound handed to random_values for the sine omegas
BATCH_SIZE = 128
TRAIN_COUNT = 64_000                     # signals generated per epoch
EPOCHS = 10

# Sanity checks on the configuration above.
# NOTE(review): the first check only considers MAX_OMEGA; the cosine_squared component
# (omega fixed at 100 in the training cell) is not covered by it - confirm it is safe.
assert 1 / TIME_STEP >= 2 * MAX_OMEGA, "sampling rate not high enough"
assert TRAIN_COUNT % BATCH_SIZE == 0, "batch size should divide train count"

# Model that maps a signal to a single output parameter (the training cell uses
# the sine's omega as the target).
transformer = DecoderTransformer(
    output_parameter_count=1,
    d_model=128,
    num_heads=16,
    num_layers=2,
    d_ff=512,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    dropout=0.1
).to(device)

optim = torch.optim.Adam(transformer.parameters(), lr=0.001)
# reduction='sum' adds up the absolute errors instead of averaging them.
criterion = nn.L1Loss(reduction='sum')

In [None]:
# Constant omega (100) used for every cosine-squared component.
# The evaluation cell further down reads this tensor as well.
only_100 = 100 * torch.ones(TRAIN_COUNT)

for _epoch in range(EPOCHS):
    # Fresh data every epoch. The recipe is free-form: plain sine waves, more
    # complicated waves, or mixtures with different phases / frequencies all work.
    frequencies = random_values(TRAIN_COUNT, MIN_OMEGA, MAX_OMEGA)
    phases = random_values(TRAIN_COUNT, 0, 2 * torch.pi)

    # Both components share the same length, time step and phases.
    shared_signal_args = dict(
        length=MAX_SEQUENCE_LENGTH,
        time_step=TIME_STEP,
        phases=phases,
    )
    sines = create_signals(omegas=frequencies, signal_function=sine, **shared_signal_args)
    cosines = create_signals(omegas=only_100, signal_function=cosine_squared, **shared_signal_args)

    # Weighted mix of the two components; the target is the sine's omega as a column.
    values = 0.7 * sines + 0.3 * cosines
    targets = frequencies.reshape(-1, 1)
    batched_values, batched_parameters = create_batches([values, targets], BATCH_SIZE)

    # One optimisation pass over the batches generated above.
    train_one_epoch(
        transformer,
        optim,
        criterion,
        device,
        batched_values,
        batched_parameters,
    )

In [None]:
# Build a fresh evaluation set with the same recipe as the training cell:
# random sine omegas and phases, mixed with a fixed-omega cosine-squared component.
frequencies = random_values(TRAIN_COUNT, MIN_OMEGA, MAX_OMEGA)
phases = random_values(TRAIN_COUNT, 0, 2 * torch.pi)

sines = create_signals(
    omegas=frequencies,
    signal_function=sine,
    length=MAX_SEQUENCE_LENGTH,
    time_step=TIME_STEP,
    phases=phases
)

# `only_100` is created in the training cell, so that cell must have been run first.
cosines = create_signals(
    omegas=only_100,
    signal_function=cosine_squared,
    length=MAX_SEQUENCE_LENGTH,
    time_step=TIME_STEP,
    phases=phases
)

values = 0.7 * sines + 0.3 * cosines
batched_values, batched_parameters = create_batches([values, frequencies.reshape(-1, 1)], BATCH_SIZE)

# The model is built with dropout, so switch it to eval mode and disable autograd
# for the duration of the evaluation; the previous mode is restored afterwards so
# the training cell can be re-run unchanged.
# NOTE(review): evaluate() may already do this internally - harmless if so.
was_training = transformer.training
transformer.eval()
try:
    with torch.no_grad():
        eval_loss = evaluate(transformer, criterion, device, batched_values, batched_parameters)
finally:
    transformer.train(was_training)

print(f"Evaluation loss: {eval_loss}")