In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as f
from math import log, exp
import math
import datetime
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Global seaborn style (dark background with grid lines) for the plots produced in this notebook.
sns.set_style("darkgrid")

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Q, K and V are each passed through their own ``d_model -> d_model`` linear
    projection, split into ``num_heads`` heads of width ``d_k = d_model // num_heads``,
    attended per head, concatenated back together and passed through an output
    projection ``W_o``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # width of one head

        # Construction order (q, k, v, o) fixes the RNG draws used for weight init
        # and the state_dict key order, so it is kept as-is.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Return ``softmax(Q K^T / sqrt(d_k)) V``.

        Positions where ``mask == 0`` get a score of -1e9 before the softmax, which
        drives their attention weight to (effectively) zero.
        """
        scale = math.sqrt(self.d_k)
        scores = Q @ K.transpose(-2, -1) / scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        weights = scores.softmax(dim=-1)
        return weights @ V

    def split_heads(self, x):
        """Reshape ``(batch, seq, d_model)`` into ``(batch, num_heads, seq, d_k)``."""
        batch_size, seq_length, _ = x.size()
        per_head = x.view(batch_size, seq_length, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """Inverse of :meth:`split_heads`: ``(batch, num_heads, seq, d_k)`` -> ``(batch, seq, d_model)``."""
        batch_size, _, seq_length, _ = x.size()
        # transpose makes the tensor non-contiguous, and view needs contiguous memory.
        merged = x.transpose(1, 2).contiguous()
        return merged.view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` over ``K``/``V`` and return a ``(batch, seq_q, d_model)`` tensor."""
        # Projections are applied in q, k, v order, exactly as written out longhand.
        queries, keys, values = (
            self.split_heads(projection(source))
            for projection, source in ((self.W_q, Q), (self.W_k, K), (self.W_v, V))
        )
        attended = self.scaled_dot_product_attention(queries, keys, values, mask)
        return self.W_o(self.combine_heads(attended))

class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position: ``fc2(relu(fc1(x)))``.

    Expands from ``d_model`` to ``d_ff`` features and projects back to ``d_model``.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # fc1 is created before fc2 so weight-init RNG draws match the original order.
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc2(activated)

class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal positional table to the input.

    ``pe[pos, 2i]   = sin(pos / 10000**(2i / d_model))``
    ``pe[pos, 2i+1] = cos(pos / 10000**(2i / d_model))``

    The table is stored as a non-trainable buffer named ``pe`` of shape
    ``(1, max_seq_length, d_model)``, so it moves with ``.to(device)`` and is part
    of the state_dict. Odd ``d_model`` is supported (the last column is a sine).
    """

    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # 10000 ** (-2i / d_model), computed in log space for numerical stability.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there are ceil(d/2) sine columns but only floor(d/2)
        # cosine columns, so the frequency vector must be trimmed to match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x + pe`` for the first ``x.size(1)`` positions (broadcast over batch).

        ``x.size(1)`` must not exceed ``max_seq_length``; the last dimension of ``x``
        must broadcast against ``d_model``.
        """
        return x + self.pe[:, :x.size(1)]

class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer.

    Two sub-layers run in sequence: unmasked self-attention, then a position-wise
    feed-forward network. Each one's output passes through dropout, is added back
    to its input (residual), and is layer-normalised.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        # The second norm is called `norm3` (not `norm2`); the notebook saves and
        # reloads state_dicts, so renaming it would break existing checkpoints.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def _residual(self, x, sublayer_output, norm):
        # Dropout on the sub-layer output, residual add, then LayerNorm.
        return norm(x + self.dropout(sublayer_output))

    def forward(self, x):
        x = self._residual(x, self.self_attn(x, x, x), self.norm1)
        return self._residual(x, self.feed_forward(x), self.norm3)


def generate_mask(tgt):
    """Build a causal (no-peek) mask for a ``(batch, seq, ...)`` tensor.

    Returns a boolean tensor of shape ``(1, seq, seq)`` on ``tgt``'s device that is
    True on and below the diagonal: position ``i`` may attend to positions ``<= i``.
    """
    seq_length = tgt.size(1)
    lower_triangle = torch.ones(seq_length, seq_length, device=tgt.device).tril()
    return lower_triangle.bool().unsqueeze(0)


class EncoderTransformer(nn.Module):
    """Encoder-only Transformer over a sequence of scalar samples.

    The input is reshaped to ``(batch, max_seq_length, 1)``. Each scalar is added
    to the ``(1, max_seq_length, d_model)`` positional table, which broadcasts it
    to ``d_model`` features. The result runs through ``num_layers`` encoder layers,
    and the last position's hidden state is projected to ``output_parameter_count``
    outputs.
    """

    def __init__(self, output_parameter_count, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()
        # Sub-modules are registered in the original order so weight-init RNG
        # draws and state_dict keys are unchanged.
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
        self.encoder_layers = nn.ModuleList(
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        )
        self.fc = nn.Linear(d_model, output_parameter_count)
        self.dropout = nn.Dropout(dropout)
        self.max_seq_length = max_seq_length

    def forward(self, src):
        """Map ``src`` (``max_seq_length`` values per sample) to ``(batch, output_parameter_count)``."""
        tokens = src.reshape(-1, self.max_seq_length, 1)
        hidden = self.dropout(self.positional_encoding(tokens))

        for layer in self.encoder_layers:
            hidden = layer(hidden)

        # Only the final position's representation feeds the output head.
        return self.fc(hidden[:, -1, :])

def train_one_epoch(model: nn.Module, optimizer: torch.optim.Optimizer, loss_function, device, batched_samples: torch.Tensor, batched_params: torch.Tensor):
    """Run one optimisation pass over pre-batched data.

    ``batched_samples`` and ``batched_params`` are iterated in lock-step along their
    first dimension; each element is one batch of inputs and its targets. For every
    batch: move it to ``device``, compute ``loss_function(model(inputs), targets)``,
    backpropagate and take one optimizer step. The model is put in training mode
    first. Nothing is returned.
    """
    model.train()

    for batch_inputs, batch_targets in zip(batched_samples, batched_params):
        inputs_on_device = batch_inputs.to(device)
        targets_on_device = batch_targets.to(device)

        optimizer.zero_grad()
        batch_loss = loss_function(model(inputs_on_device), targets_on_device)
        batch_loss.backward()
        optimizer.step()

def train(model: nn.Module, optimizer: torch.optim.Optimizer, loss_function, device, batched_samples: torch.Tensor, batched_params: torch.Tensor, epochs: int, eval_step: int | None = None):
    """Train for ``epochs`` passes over the same batches.

    If ``eval_step`` is given, every ``eval_step``-th epoch (1-based) the model is
    evaluated on the *training* batches. The loss and the wall-clock time elapsed
    since training started are printed.
    """
    started_at = datetime.datetime.now()

    for epoch_number in range(1, epochs + 1):
        train_one_epoch(model, optimizer, loss_function, device, batched_samples, batched_params)

        should_report = eval_step is not None and epoch_number % eval_step == 0
        if not should_report:
            continue

        evaluation = evaluate(model, loss_function, batched_samples, batched_params)
        elapsed = datetime.datetime.now() - started_at
        print(f"Epoch {epoch_number}: {evaluation} after {elapsed}")

def evaluate(model: nn.Module, loss_function, batched_eval_samples: torch.Tensor, batched_eval_params: torch.Tensor):
    model.eval()

    total_loss = 0.0
    for samples, params in zip(batched_eval_samples, batched_eval_params):
        samples = samples.to(device)
        params = params.to(device)
        
        with torch.no_grad():
            predictions = model(samples)
            loss = loss_function(predictions, params).item()
            total_loss += loss

    return total_loss / (batched_eval_samples.size(0) * batched_eval_samples.size(1))

def predict_next_values(model: nn.Module, values: torch.Tensor, sequence_length: int, count: int):
    """Autoregressively extend ``values`` by ``count`` predicted samples.

    At each step the last ``sequence_length`` values (observed or already predicted)
    are fed to the model as a ``(1, sequence_length)`` tensor. Its output must be a
    single element (``.item()`` is called on it), which is appended to the series.

    The model runs in eval mode, so dropout does not make the forecast random; its
    previous train/eval mode is restored afterwards.

    Returns:
        ``(history, forecast)`` on the CPU: the original ``sequence_length`` values
        and the ``count`` predicted values.
    """
    assert values.size(0) == sequence_length, f"length of values should be {sequence_length}"

    device = next(model.parameters()).device
    values = values.to(device)

    all_values = torch.empty(sequence_length + count, device=device)
    all_values[:sequence_length] = values

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for i in range(count):
                window = all_values[i:i + sequence_length]
                all_values[i + sequence_length] = model(window.reshape(1, -1)).item()
    finally:
        model.train(was_training)

    return all_values.cpu().split((sequence_length, count))

def random_values(count: int, minimum, maximum):
    """Draw ``count`` values uniformly from ``[minimum, maximum)`` using torch's global RNG."""
    span = maximum - minimum
    return minimum + span * torch.rand(count)

def create_signals(omegas: torch.Tensor, signal_function, length: int, time_step: float, phases: torch.Tensor | None = None) -> torch.Tensor:
    count = omegas.size(0)
    waves = torch.empty(count, length)

    if phases is None:
        phases = torch.zeros(count)

    times = torch.arange(0, length) * time_step
    for i in range(count):
        wave = times * omegas[i] + phases[i]
        waves[i] = signal_function(wave)
    
    return waves

def sine(inputs: torch.Tensor):
    """Element-wise sine of ``inputs``."""
    return torch.sin(inputs)

def cosine_squared(inputs: torch.Tensor):
    """Element-wise ``cos(x) ** 2`` of ``inputs``."""
    return torch.square(torch.cos(inputs))

def sine_plus_cosine_squared(cosine_weight: float):
    """Return a signal function computing ``sin(x) + cosine_weight * cos(x) ** 2``."""
    def weighted_sum(inputs: torch.Tensor):
        sine_part = sine(inputs)
        cosine_part = cosine_squared(inputs)
        return sine_part + cosine_weight * cosine_part

    return weighted_sum

def lengthen_tensors(first: torch.Tensor, second: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:
    """NaN-pad two 1-D tensors onto a shared axis so they plot one after the other.

    Both returned float64 arrays have length ``len(first) + len(second)``. The first
    holds ``first`` followed by NaNs; the second holds NaNs followed by ``second``.
    """
    split = first.size(0)
    total = split + second.size(0)

    padded_first = np.full(total, np.nan)
    padded_first[:split] = first

    padded_second = np.full(total, np.nan)
    padded_second[split:] = second

    return padded_first, padded_second

def plot_prediction(values: torch.Tensor, predicted_values: torch.Tensor, time_step: float):
    """Plot an input window followed by its predicted continuation on one time axis.

    Args:
        values: 1-D tensor of input samples, drawn first.
        predicted_values: 1-D tensor of predicted samples, drawn right after ``values``.
        time_step: spacing between consecutive samples on the x axis.

    Returns:
        The matplotlib ``Figure`` containing the labelled plot.
    """
    observed, predicted = lengthen_tensors(values, predicted_values)
    time = np.arange(observed.shape[0]) * time_step

    fig, ax = plt.subplots()
    # The NaN padding leaves each series blank outside its own span.
    ax.plot(time, observed, label="input")
    ax.plot(time, predicted, label="prediction")
    ax.set_xlabel("time")
    ax.set_ylabel("value")
    ax.legend()
    return fig

def count_parameters(model):
    """Return the number of scalar parameters in ``model`` that have ``requires_grad`` set."""
    total = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            total += parameter.numel()
    return total

def create_batches(unbatched_list: list[torch.Tensor], batch_size: int):
    """Split each tensor along dim 0 into full batches of ``batch_size`` samples.

    Each tensor of shape ``(N, *rest)`` becomes ``(N // batch_size, batch_size, *rest)``.
    Trailing samples that do not fill a complete batch are dropped (like
    ``DataLoader(drop_last=True)``), instead of failing inside ``reshape``.

    Args:
        unbatched_list: tensors that share the same leading (sample) dimension.
        batch_size: number of samples per batch; must be positive.

    Returns:
        A tuple with one batched tensor per input, in input order (empty for empty input).

    Raises:
        ValueError: if ``batch_size`` is not positive or the tensors differ in length.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if not unbatched_list:
        return ()

    sample_count = unbatched_list[0].size(0)
    if any(t.size(0) != sample_count for t in unbatched_list):
        raise ValueError("all tensors must have the same number of samples along dim 0")

    batch_count = sample_count // batch_size
    kept = batch_count * batch_size

    return tuple(
        unbatched[:kept].reshape(batch_count, batch_size, *unbatched.shape[1:])
        for unbatched in unbatched_list
    )

def add_noise(values: torch.Tensor, noise_strength: float):
    """Return ``values`` plus i.i.d. Gaussian noise with std ``noise_strength`` (input is not modified)."""
    return values + torch.randn_like(values) * noise_strength

def remove_samples(values: torch.Tensor, missing_segment_length: int, missing_segment_count: int, fill_value: float = -1e9):
    """Punch random contiguous "holes" into each row of ``values``.

    For every row (dim 0), ``missing_segment_count`` segments of
    ``missing_segment_length`` consecutive positions along dim 1 are overwritten
    with ``fill_value``. Segment starts are drawn uniformly from every valid
    position, including the one where the hole ends exactly at the last sample.
    Segments may overlap. The input is not modified.

    Args:
        values: tensor shaped ``(samples, length, ...)``.
        missing_segment_length: hole width; ``0 <= width <= length``.
        missing_segment_count: holes per row.
        fill_value: sentinel written into the holes (default ``-1e9``).

    Returns:
        A detached copy of ``values`` with the holes filled.

    Raises:
        ValueError: if the hole is wider than the sequence or negative.
    """
    sample_count = values.size(0)
    sample_length = values.size(1)
    if not 0 <= missing_segment_length <= sample_length:
        raise ValueError(
            f"missing_segment_length must be in [0, {sample_length}], got {missing_segment_length}"
        )

    new_values = values.detach().clone()

    # randint's upper bound is exclusive: +1 makes start = length - width reachable.
    # Drawn on the CPU generator (like the original loop), then moved to values' device.
    starts = torch.randint(
        sample_length - missing_segment_length + 1, (sample_count, missing_segment_count)
    ).to(values.device)

    # Build a (samples, length) boolean hole mask one segment at a time. This keeps
    # memory at O(samples * length) instead of materialising all segments at once.
    positions = torch.arange(sample_length, device=values.device)
    in_hole = torch.zeros(sample_count, sample_length, dtype=torch.bool, device=values.device)
    for k in range(missing_segment_count):
        start = starts[:, k:k + 1]
        in_hole |= (positions >= start) & (positions < start + missing_segment_length)

    new_values[in_hole] = fill_value
    return new_values

In [2]:
MAX_SEQUENCE_LENGTH = 512
SEED = 0

# Seed torch's global RNG once, early, so weight init, synthetic data, noise and hole
# positions are repeatable across Restart & Run All (GPU kernels may still add
# small nondeterminism).
torch.manual_seed(SEED)

def create_model(device: torch.device | str) -> EncoderTransformer:
    """Build the encoder used throughout this notebook and move it to ``device``.

    The model reads a window of MAX_SEQUENCE_LENGTH samples and outputs
    MAX_SEQUENCE_LENGTH values (one per position of the window).
    """
    return EncoderTransformer(
        output_parameter_count=MAX_SEQUENCE_LENGTH,
        d_model=256,
        num_heads=16,
        num_layers=3,
        d_ff=512,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        dropout=0
    ).to(device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
transformer = create_model(device)
optim = torch.optim.Adam(transformer.parameters(), lr=0.0001)
# reduction="sum" sums squared error over every element of a batch; evaluate()
# later divides by the number of samples, giving a per-sample summed error.
criterion = nn.MSELoss(reduction="sum")
count_parameters(transformer)

cuda


1712896

In [3]:
TIME_STEP = 0.5 / MAX_SEQUENCE_LENGTH
MIN_OMEGA = 80
MAX_OMEGA = 120
HOLE_WIDTH = 50
HOLE_COUNT = 2
NOISE = 0.025
SINE_COMPONENTS = 3

def create_data(count: int, component_count: int = SINE_COMPONENTS):
    """Generate ``count`` synthetic (corrupted input, clean target) signal pairs.

    Each clean signal is the average of ``component_count`` sines. Each sine has an
    angular frequency drawn from [MIN_OMEGA, MAX_OMEGA) and a phase drawn from
    [0, 2*pi), sampled at MAX_SEQUENCE_LENGTH points spaced TIME_STEP apart. The
    corrupted copy gets Gaussian noise (std NOISE) and then HOLE_COUNT holes of
    HOLE_WIDTH samples filled by ``remove_samples``.

    Returns:
        ``(messy_signals, signals)``, both shaped ``(count, MAX_SEQUENCE_LENGTH)``.

    Raises:
        ValueError: if ``component_count`` is less than 1.
    """
    if component_count < 1:
        raise ValueError(f"component_count must be >= 1, got {component_count}")

    # All frequencies are drawn before all phases, matching the original
    # hand-unrolled draw order on torch's global RNG.
    omegas = [random_values(count, MIN_OMEGA, MAX_OMEGA) for _ in range(component_count)]
    phases = [random_values(count, 0, 2 * torch.pi) for _ in range(component_count)]

    signals = None
    for omega, phase in zip(omegas, phases):
        component = create_signals(
            omegas=omega,
            signal_function=sine,
            length=MAX_SEQUENCE_LENGTH,
            time_step=TIME_STEP,
            phases=phase
        )
        # Left-to-right accumulation, same as writing s1 + s2 + s3.
        signals = component if signals is None else signals + component
    signals = signals / component_count

    messy_signals = add_noise(signals, NOISE)
    # NOTE(review): holes use remove_samples' in-band -1e9 sentinel, which reaches the
    # model input unscaled; a separate mask channel may train more stably.
    messy_signals = remove_samples(messy_signals, HOLE_WIDTH, HOLE_COUNT)
    return messy_signals, signals

In [4]:
EPOCHS = 60
TRAIN_COUNT = 64_000  # fresh synthetic samples generated for every epoch
BATCH_SIZE = 128      # 64_000 / 128 = 500 full batches per epoch

# Each epoch trains on a newly generated dataset, then writes the weights to
# "<epoch>.pt" in the working directory; the evaluation cell reloads these files.
for i in range(EPOCHS):
    # values: noisy, hole-punched inputs; parameters: the matching clean signals (targets).
    values, parameters = create_data(TRAIN_COUNT)
    batched_values, batched_parameters = create_batches([values, parameters], BATCH_SIZE)
    train_one_epoch(transformer, optim, criterion, device, batched_values, batched_parameters)
    torch.save(transformer.state_dict(), f"{i+1}.pt")
    print(f"Epoch {i+1} finished")

Epoch 1 finished
Epoch 2 finished
Epoch 3 finished
Epoch 4 finished
Epoch 5 finished
Epoch 6 finished
Epoch 7 finished
Epoch 8 finished
Epoch 9 finished
Epoch 10 finished
Epoch 11 finished
Epoch 12 finished
Epoch 13 finished
Epoch 14 finished
Epoch 15 finished
Epoch 16 finished
Epoch 17 finished
Epoch 18 finished
Epoch 19 finished
Epoch 20 finished
Epoch 21 finished
Epoch 22 finished
Epoch 23 finished
Epoch 24 finished
Epoch 25 finished
Epoch 26 finished
Epoch 27 finished
Epoch 28 finished
Epoch 29 finished
Epoch 30 finished
Epoch 31 finished
Epoch 32 finished
Epoch 33 finished
Epoch 34 finished
Epoch 35 finished
Epoch 36 finished
Epoch 37 finished
Epoch 38 finished
Epoch 39 finished
Epoch 40 finished
Epoch 41 finished
Epoch 42 finished
Epoch 43 finished
Epoch 44 finished
Epoch 45 finished
Epoch 46 finished
Epoch 47 finished
Epoch 48 finished
Epoch 49 finished
Epoch 50 finished
Epoch 51 finished
Epoch 52 finished
Epoch 53 finished
Epoch 54 finished
Epoch 55 finished
Epoch 56 finished
E

In [5]:
EVAL_COUNT = 6_400

eval_values, eval_parameters = create_data(EVAL_COUNT)
eval_values_batched, eval_parameters_batched = create_batches([eval_values, eval_parameters], BATCH_SIZE)

# Build one model and load each epoch's weights into it, rather than allocating a
# fresh model per checkpoint.
model = create_model(device)
for i in range(EPOCHS):
    checkpoint_path = f"{i+1}.pt"
    try:
        # map_location lets checkpoints saved on one device load on another;
        # weights_only avoids unpickling arbitrary objects from the file.
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    except FileNotFoundError:
        # Training can be interrupted before all EPOCHS checkpoints are written.
        print(f"Epoch {i + 1}: checkpoint {checkpoint_path} not found, stopping")
        break
    model.load_state_dict(state_dict)
    evaluation = evaluate(model, criterion, eval_values_batched, eval_parameters_batched)
    print(f"Epoch {i + 1}: {evaluation}")

Epoch 1: 79.21369369506836
Epoch 2: 14.727992687225342
Epoch 3: 13.475112857818603
Epoch 4: 10.661662921905517
Epoch 5: 9.200533695220948
Epoch 6: 8.218801498413086
Epoch 7: 7.386648387908935
Epoch 8: 6.856926965713501
Epoch 9: 6.10939736366272
Epoch 10: 5.694736909866333
Epoch 11: 5.25014513015747
Epoch 12: 4.921695713996887
Epoch 13: 4.465054459571839
Epoch 14: 4.0948410320281985
Epoch 15: 3.930643048286438
Epoch 16: 3.7203274631500243
Epoch 17: 3.3511900520324707
Epoch 18: 3.354081754684448
Epoch 19: 3.0760944747924803
Epoch 20: 2.97194571018219
Epoch 21: 2.7260362100601196
Epoch 22: 2.7847671365737914
Epoch 23: 2.811449580192566
Epoch 24: 2.591637215614319
Epoch 25: 2.3846767020225523
Epoch 26: 2.39386549949646
Epoch 27: 2.198474397659302
Epoch 28: 2.170709316730499
Epoch 29: 2.0648192739486695
Epoch 30: 2.6998743295669554
Epoch 31: 2.1910663652420044
Epoch 32: 2.0210211205482485
Epoch 33: 2.0074137353897097
Epoch 34: 2.02333377122879
Epoch 35: 2.281569519042969
Epoch 36: 2.3574787