In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import pandas as pd

In [None]:
def create_dataset(dataset_size=20000, time_horizon=10, features=1, low=0, high=10, rng=None):
    """Generate a toy sequence-sum regression dataset.

    Each sample is a ``(time_horizon, features)`` array of random integers drawn
    from ``[low, high)``; its target is the sum over the time axis.

    Args:
        dataset_size: number of samples.
        time_horizon: sequence length (number of time steps per sample).
        features: number of values per time step.
        low: smallest value that can be drawn (inclusive).
        high: upper bound of the drawn values (exclusive).
        rng: optional ``np.random.Generator``. When None, the legacy global
            ``np.random`` state is used, so ``np.random.seed`` still controls it.

    Returns:
        ``(dataset_x, dataset_y)`` with shapes
        ``(dataset_size, time_horizon, features)`` and ``(dataset_size, features)``.
    """
    size = (dataset_size, time_horizon, features)
    if rng is None:
        dataset_x = np.random.randint(low=low, high=high, size=size)
    else:
        dataset_x = rng.integers(low=low, high=high, size=size)
    dataset_y = dataset_x.sum(axis=1)  # sum over the time dimension
    return dataset_x, dataset_y

In [None]:
# Prefer a CUDA GPU when torch can see one; otherwise fall back to the CPU.
# (Apple's MPS backend could be checked with torch.backends.mps.is_available();
# it is not selected here.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Reference multi-head attention implementation, kept for study only; the model below uses torch's nn.MultiheadAttention instead.
class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head self-attention (reference implementation).

    Args:
        input_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads; each head gets
            ``input_dim // num_heads`` channels.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``input_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, num_heads, dropout):
        super(MultiHeadAttention, self).__init__()

        if input_dim % num_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads

        self.query_projection = nn.Linear(input_dim, input_dim)
        self.key_projection = nn.Linear(input_dim, input_dim)
        self.value_projection = nn.Linear(input_dim, input_dim)

        self.out_projection = nn.Linear(input_dim, input_dim)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t):
        """Reshape ``(batch, seq, input_dim)`` into ``(batch, heads, seq, head_dim)``."""
        batch_size = t.shape[0]
        # Only the feature axis is split into heads: view as (batch, seq, heads, head_dim)
        # first, then move the heads axis in front of the sequence axis.
        return t.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape ``(batch, seq, input_dim)``.
            mask: optional tensor broadcastable to ``(batch, heads, seq, seq)``;
                scores where ``mask == 0`` are replaced with -1e9 before softmax.

        Returns:
            ``(output, weights)``: ``output`` has shape ``(batch, seq, input_dim)``;
            ``weights`` has shape ``(batch, heads, seq, seq)`` (after dropout).
        """
        batch_size = x.shape[0]

        # Project queries, keys and values, then split each into heads.
        q = self._split_heads(self.query_projection(x))
        k = self._split_heads(self.key_projection(x))
        v = self._split_heads(self.value_projection(x))

        # (batch, heads, seq, seq): every query position scored against every key position.
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.head_dim**0.5

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Softmax over key positions gives the attention weights.
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        x = torch.matmul(weights, v)

        # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim) -> (batch, seq, input_dim)
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.input_dim)
        x = self.out_projection(x)

        return x, weights


In [None]:
class TransformerLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, dropout):
        super(TransformerLayer, self).__init__()

        self.multi_head_attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)

        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.ReLU(),
            nn.Linear(4 * hidden_dim, hidden_dim)
        )

        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(hidden_dim)
            for _ in range(2)
        ])

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x

        x = self.layer_norms[0](x)
        x, _ = self.multi_head_attention(x, x, x)
        x = self.dropout(x)
        x = residual + x

        residual = x

        x = self.layer_norms[1](x)
        x = self.feed_forward(x)
        x = self.dropout(x)
        x = residual + x

        return x

In [None]:
class Transformer(nn.Module):
    """Stack of TransformerLayer blocks that regresses one vector per sequence.

    Pipeline: project each time step ``input_dim -> hidden_dim``, run ``num_layers``
    TransformerLayer blocks, project back to ``input_dim``, keep only the first
    time step, and map that to ``output_dim``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, dropout, output_dim):
        super(Transformer, self).__init__()
        self.input_projection = nn.Linear(input_dim, hidden_dim)

        stack = [TransformerLayer(hidden_dim, num_heads, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stack)

        self.output_projection = nn.Linear(hidden_dim, input_dim)

        # Maps the pooled time step to output_dim. With input_dim == output_dim == 1
        # (as the Trainer uses it) it is not strictly needed. The attribute name,
        # spelling included, is kept so existing state_dict keys continue to load.
        self.final_proojection = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, time, input_dim)`` to ``(batch, output_dim)``.

        The time axis is reduced by taking position 0 after the output projection.
        Summing over time and then projecting would be an alternative reduction.
        """
        hidden = self.input_projection(x)
        for block in self.layers:
            hidden = block(hidden)
        projected = self.output_projection(hidden)
        first_step = projected[:, 0, :]
        return self.final_proojection(first_step)

In [None]:
class Trainer:
    """Owns the model, loss and optimizer for the sequence-sum regression task.

    The model is placed on the module-level ``device``. ``train`` and ``predict``
    move their numpy inputs onto whatever device the model's parameters live on.
    """

    def __init__(self):
        self.model = Transformer(dropout=0.1, hidden_dim=12, input_dim=1, num_heads=4, num_layers=2, output_dim=1)
        self.model.to(device)
        self.criterion = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters())

    def _model_device(self):
        """Return the device that holds the model's parameters."""
        return next(self.model.parameters()).device

    def train(self, x_mb, y_mb, clip_grad_norm=None):
        """Run one optimisation step on a minibatch.

        Args:
            x_mb: array-like model input (``(batch, time, features)`` in this notebook).
            y_mb: array-like targets with as many elements as the model output
                (``(batch, 1)`` in this notebook).
            clip_grad_norm: if not None, clip the total gradient norm to this value
                before the optimizer step.

        Returns:
            The minibatch loss as a detached scalar tensor.
        """
        self.model.train()
        dev = self._model_device()
        x_mb = torch.from_numpy(np.asarray(x_mb, dtype=np.float32)).to(dev)
        y_hat = self.model(x_mb)
        # Give the targets exactly the prediction's shape so MSELoss compares
        # element-for-element instead of relying on broadcasting.
        y_mb = torch.from_numpy(np.asarray(y_mb, dtype=np.float32)).to(dev).reshape(y_hat.shape)
        loss = self.criterion(y_hat, y_mb)
        self.optimizer.zero_grad()
        loss.backward()
        if clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_grad_norm)
        self.optimizer.step()
        # Detach so callers don't keep the autograd graph alive.
        return loss.detach()

    def predict(self, x):
        """Run the model in inference mode on an array.

        A single sequence (fewer than 3 dims) gets a leading batch axis added.
        Dropout is disabled and autograd is off for the call. The model's
        previous train/eval mode is restored afterwards.

        Returns:
            numpy array with the model outputs.
        """
        x = np.asarray(x, dtype=np.float32)
        if x.ndim < 3:
            x = np.expand_dims(x, 0)
        x = torch.from_numpy(x).to(self._model_device())
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                y_hat = self.model(x)
        finally:
            self.model.train(was_training)
        return y_hat.cpu().numpy()

In [None]:
def main(num_steps=10000, batch_size=32, log_every=1000, seed=None):
    """Train the Transformer on the sequence-sum task and print progress.

    Args:
        num_steps: number of optimisation steps; each step uses one random
            minibatch (these are steps, not full passes over the data).
        batch_size: samples per minibatch, drawn with replacement.
        log_every: print the loss and a sample prediction every this many steps.
        seed: if given, seeds numpy's and torch's global RNGs first, so the
            dataset, weight initialisation and minibatch sampling are reproducible.
    """
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    dataset_x, dataset_y = create_dataset()
    trainer = Trainer()

    for step in range(num_steps):
        idxs = np.random.randint(len(dataset_x), size=batch_size)
        loss = trainer.train(dataset_x[idxs], dataset_y[idxs]).item()
        if step % log_every == 0:
            print(f"For step {step} loss is {loss}")
            print(f"Given vector {dataset_x[0, :, 0]} the solution is {dataset_y[0, 0]}")
            one_y = trainer.predict(dataset_x[0, :, :])[0][0]
            print(f"The predicted solution is: {one_y}")