In [1]:
import torch
import numpy as np
import torch.nn as nn
from pprint import pprint
from dataclasses import dataclass,asdict, field
from pytorch_lightning import Trainer

from typing import List
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from fimodemix.data.datasets import (
    FIMSDEDataset,
    FIMSDEDatabatch
)

from fimodemix.utils.experiment_files import ExperimentsFiles
from fimodemix.models.blocks import (
    TimeEncoding,
    Mlp,
    TransformerModel
)

@dataclass
class ModelParams:
    """Hyperparameters for the data, the encoder architecture and training.

    ``encoding0_dim`` is a derived quantity: ``__post_init__`` always sets it
    to ``x0_out_features + dim_time`` so that it matches the width of the
    concatenated value/time encoding, even when either component is overridden.
    """
    # data
    input_size: int = 1  # Original input size

    # data-dependent sizes; the defaults of 1 look like placeholders
    # (presumably overwritten once a dataset is loaded -- TODO confirm)
    max_dimension: int = 1
    max_hypercube_size: int = 1
    max_num_steps: int = 1

    # model architecture
    dim_time: int = 19

    # phi_0 / first data encoding
    x0_hidden_layers: List[int] = field(default_factory=lambda: [50, 50])
    x0_out_features: int = 21
    x0_dropout: float = 0.2

    # derived in __post_init__: x0_out_features + dim_time
    encoding0_dim: int = 40

    # psi_1 / first transformer
    psi1_nhead: int = 2
    psi1_hidden_dim: int = 300
    psi1_nlayers: int = 2

    # Multiheaded Attention 1 / first path summary
    query_dim: int = 10

    n_heads: int = 4
    hidden_dim: int = 64
    output_size: int = 1
    batch_size: int = 32
    seq_length: int = 10

    # training
    num_epochs: int = 10
    learning_rate: float = 0.001
    embed_dim: int = 8  # New embedding dimension

    def __post_init__(self):
        """Recompute ``encoding0_dim`` from its two components.

        The dataclass machinery only calls a hook named exactly
        ``__post_init__``; the previous spelling (``__post__init__``) was
        never invoked, so the derived field silently kept its default.
        """
        self.encoding0_dim = self.x0_out_features + self.dim_time

In [2]:
class TimeSeriesTransformer(pl.LightningModule):
    """Single-layer self-attention regressor for fixed-length sequences.

    The input is embedded linearly, passed through one multi-head
    self-attention layer, mean-pooled over the sequence axis and mapped to the
    output by a two-layer MLP. Training minimises the mean squared error.

    Args:
        input_size: Number of features per time step.
        n_heads: Number of attention heads; must divide ``embed_dim``.
        hidden_dim: Width of the hidden layer in the output MLP.
        output_size: Number of regression targets.
        embed_dim: Width of the embedding fed to the attention layer.
        learning_rate: Step size for the Adam optimiser. Defaults to 0.001,
            the value that was previously hard-coded.
    """

    def __init__(self, input_size, n_heads, hidden_dim, output_size, embed_dim,
                 learning_rate=0.001):
        super().__init__()

        # Validate before building any layer so a bad configuration fails
        # with this message rather than inside nn.MultiheadAttention.
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"

        self.learning_rate = learning_rate

        self.embedding = nn.Linear(input_size, embed_dim)  # Transform input to embed_dim
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=n_heads)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_size)
        self.relu = nn.ReLU()
        self.criterion = nn.MSELoss()

    def forward(self, x):
        """Map a batch of sequences to one prediction per sequence.

        Args:
            x: Tensor laid out as (batch_size, seq_length, input_size).

        Returns:
            Tensor of shape (batch_size, output_size).
        """
        x = self.embedding(x)  # (batch_size, seq_length, embed_dim)
        # The attention layer is built with the default sequence-first layout.
        x = x.permute(1, 0, 2)  # (seq_length, batch_size, embed_dim)

        attn_output, _ = self.attention(x, x, x)
        x = attn_output.permute(1, 0, 2)  # Back to (batch_size, seq_length, embed_dim)
        x = x.mean(dim=1)  # Global average pooling over the sequence axis

        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE loss for one ``(inputs, targets)`` batch."""
        inputs, targets = batch
        # Call the module rather than .forward() so registered hooks run.
        outputs = self(inputs)
        loss = self.criterion(outputs, targets)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters using ``learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [3]:
# Input-dependent queries: a single affine projection into query space
class QueryGenerator(nn.Module):
    """Produce attention queries by linearly projecting the input features."""

    def __init__(self, input_dim, query_dim):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=query_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` (last axis: input_dim -> query_dim)."""
        projected = self.linear(x)
        return projected
    
# Input-independent queries: one learnable vector per step
class StaticQuery(nn.Module):
    """Hold a trainable query matrix of shape (num_steps, query_dim)."""

    def __init__(self, num_steps, query_dim):
        super().__init__()
        initial_queries = torch.randn(num_steps, query_dim)
        self.queries = nn.Parameter(initial_queries)  # Learnable queries

    def forward(self):
        """Return the query parameter itself (no copy, no inputs)."""
        return self.queries
    
# Cross-attention wrapper: one tensor serves as both keys and values
class MultiHeadAttention(nn.Module):
    """Thin wrapper around ``nn.MultiheadAttention`` that drops the weights."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, query, key_value):
        """Attend from ``query`` over ``key_value``; return only the output."""
        attended, _weights = self.attention(query, key=key_value, value=key_value)
        return attended

# Get Data

In [4]:
# Build the dataset from the default hyperparameters and draw one shuffled batch
params = ModelParams()
dataset = FIMSDEDataset(params=params)
data_loader = DataLoader(dataset, batch_size=24, shuffle=True)
databatch = next(iter(data_loader))

# A batch unpacks into six fields, in this order
(
    obs_values,
    obs_times,
    diffusion_at_hypercube,
    drift_at_hypercube,
    hypercube_locations,
    mask,
) = databatch

Max Hypercube Size: 1024
Max Dimension: 3
Max Num Steps: 129


  data: FIMSDEDatabatch = torch.load(file_path)  # Adjust loading method as necessary


# Test Stuff

In [8]:
# Sizes read off the sampled batch (axis 0 = batch, axis 1 = steps / locations)
batch_size, num_steps = obs_times.size(0), obs_times.size(1)
dimensions = obs_values.shape[2]
num_hyper = hypercube_locations.shape[1]

In [9]:
# NOTE: modules are created in a fixed order; with a seeded RNG the order
# determines each module's initial weights, so keep it stable.

# phi^t_0: temporal embedding
phi_t0 = TimeEncoding(params.dim_time)

# phi^s_0: spatial embedding (MLP with SiLU output activation)
phi_x0 = Mlp(
    in_features=params.max_dimension,
    out_features=params.x0_out_features,
    hidden_layers=params.x0_hidden_layers,
    output_act=nn.SiLU(),
)

# phi_1: MLP acting in state space (max_dimension -> max_dimension)
phi_1 = Mlp(
    in_features=params.max_dimension,
    out_features=params.max_dimension,
    hidden_layers=params.x0_hidden_layers,
)

# phi_2: final projection from the attention width back to state space
phi_2 = Mlp(
    in_features=params.encoding0_dim,
    out_features=params.max_dimension,
    hidden_layers=params.x0_hidden_layers,
)

# psi_1: sequence-processing transformer applied to each encoded path
psi1 = TransformerModel(
    input_dim=params.encoding0_dim,
    nhead=params.psi1_nhead,
    hidden_dim=params.psi1_hidden_dim,
    nlayers=params.psi1_nlayers,
)

# A single learnable query vector of width encoding0_dim
queries = nn.Parameter(torch.randn(1, params.encoding0_dim))

# Location-dependent queries: linear map from state space to attention width
query_1x = QueryGenerator(
    input_dim=params.max_dimension,
    query_dim=params.encoding0_dim,
)

# Step-indexed learnable queries
query_1 = StaticQuery(
    num_steps=params.max_num_steps,
    query_dim=params.encoding0_dim,
)

# Omega_1: multi-head attention used to summarise each path
omega_1 = nn.MultiheadAttention(
    embed_dim=params.encoding0_dim,
    num_heads=params.psi1_nhead,
)



\begin{equation}
\left| G'(u)(y) - \sum_{k=1}^{p} \sum_{i=1}^{n} c_{i}^{k} \sigma \left( \sum_{j=1}^{m} \xi_{ij}^{k} u(x_{j}) + \theta_{i}^{k} \right) \sigma(w_{k} \cdot y + \zeta_{k}) \right| < \epsilon
\end{equation}

In [10]:
# G -----------------------------------

# Encode every observation: time and value are embedded separately on the
# flattened (batch_size*num_steps) axis and then concatenated feature-wise.
time_encoding_ = phi_t0(obs_times.reshape(batch_size*num_steps, -1))  # (batch_size*num_steps, dim_time)
x_enconding = phi_x0(obs_values.reshape(batch_size*num_steps, -1))  # (batch_size*num_steps, x0_out_features)
H = torch.cat([time_encoding_, x_enconding], dim=1)  # (batch_size*num_steps, encoding0_dim)
H = H.reshape(batch_size, num_steps, params.encoding0_dim)
# psi1 is fed sequence-first input; its output is assumed to keep that layout,
# i.e. (num_steps, batch_size, encoding0_dim) -- TODO confirm against TransformerModel
H = psi1(torch.transpose(H, 0, 1))

# Flattened observations (kept available for building value-based queries)
obs_values_reshaped = obs_values.reshape(batch_size * num_steps, -1)

# Trunk Queries ------------------
# Re-running this cell is safe: flattening an already flat tensor is a no-op.
hypercube_locations = hypercube_locations.reshape(batch_size*num_hyper, dimensions)
tx = query_1x(hypercube_locations)  # Shape: (batch_size*num_hyper, encoding0_dim)

# The flat axis is batch-major (row = b*num_hyper + h), so it must first be
# unflattened as (batch_size, num_hyper, ...) and only then transposed to the
# sequence-first layout the attention layer expects. Reshaping directly to
# (num_hyper, batch_size, ...) would pair queries with the wrong batch element.
tx = tx.reshape(batch_size, num_hyper, params.encoding0_dim)
tx = torch.transpose(tx, 0, 1)  # Shape: (num_hyper, batch_size, encoding0_dim)

In [11]:
# Representation per path: the trunk queries attend over the encoded paths H
attn_output, _ = omega_1(query=tx, key=H, value=H)  # (num_hyper, batch_size, encoding0_dim)
attn_output = attn_output.transpose(0, 1)  # (batch_size, num_hyper, encoding0_dim)
attn_output = attn_output.reshape(batch_size * num_hyper, params.encoding0_dim)

# Project every summary vector to state space and restore the batch layout
f_hat = phi_2(attn_output)
f_hat = f_hat.reshape(batch_size, num_hyper, dimensions)

In [12]:
# Shape of the drift tensor that came with the batch (for comparison with f_hat)
drift_at_hypercube.shape

torch.Size([24, 1024, 3])

In [13]:
# f_hat was reshaped to (batch_size, num_hyper, dimensions); display to verify
f_hat.shape

torch.Size([24, 1024, 3])

\textbf{Architecture}:
\begin{enumerate}[(i)]
    \item \textit{Spatial embedding} $\phi^s_0$. Currently $\phi^s_0$ is an MLP with SiLU activation.
    
    \item \textit{Temporal embedding} $\phi^t_0$. We use the time embedding of~\citet{shukla2020multitime}.

    \item \textit{Trunk net equivalent}. It's given by an MLP $\phi_1$ which takes as input the embedded evaluation point. We denote it with
    $$
    \mathbf{t}(\mathbf{x}) = \phi_1(\phi^s_0(\mathbf{x}))
    $$    
    
    \item \textit{Embedded input}. Let us denote the $i$th element of the $k$th time series in our input time series with 
    %
    $$
        \mathbf{u}_{ki} = \text{Concat}(\phi^s_0(\mathbf{x}_{ki}, \Delta \mathbf{x}_{ki}, \Delta \mathbf{x}_{ki}^2, \theta), \phi^t_0(\tau_{ki}, \theta)).
    $$
    
    \item \textit{Sequence processing network} $\psi_1$. We process each path with a Transformer network $\psi_1$ as follows
    %
    $$\mathbf{h}_{k1}, \dots, \mathbf{h}_{kl} = \psi_1(\mathbf{u}_{k1}, \dots, \mathbf{u}_{kl}, \theta), \, \, \text{with} \, \, \mathbf{h}_{ki} \in \mathbb{R}^{d_{att}}$$ 

    \item \textit{Path embedding}. Let's denote the output sequence of vectors ($\mathbf{h}_{k1}, \dots, \mathbf{h}_{kl}$) for the $k$th path with the matrix $H_k \in \mathbb{R}^{l \times d_{att}}$. We summarize each path with an attention network

    \begin{equation}    
    \mathbf{h}_k(\mathbf{x}) = \Omega_1(\mathbf{t}(\mathbf{x}), H_k, H_k),
    \label{eq:simple-attention-q}
    \end{equation}    
    where $\mathbf{t}(\mathbf{x}) \in \mathbb{R}^{1 \times d_{att}}$ is the output of Trunk net. 
    %
    In this way, we have a path embedding which depends on the location at which we want to evaluate the function.    
    Note that, as usual, the $i$th attention network is given by
    $$
    \Omega_i(Q, K, V) = \text{softmax} \left( d_{att}^{-1/2} Q \cdot K^T\right) \cdot V. 
    $$    

    \item \textit{Summary over paths}. We summarize the set of $K$ \textit{state-dependent} path embeddings with
    %
    $$
    \mathbf{b}(\mathbf{x}) = \Omega_2(\mathbf{q}, H(\mathbf{x}), H(\mathbf{x})),
    $$
    with $\mathbf{q} \in \mathbb{R}^{1 \times d_{att}}$ a learnable query, and $H(\mathbf{x}) = \mathbf{h}_1(\mathbf{x}), \dots, \mathbf{h}_K(\mathbf{x})$.
    
    \item \textit{Final layer}: We use a simple MLP $\phi_2$ to project the final embedding: 
    $$
    \mathbf{\hat f}(\mathbf{x}), \log \text{Var}(\mathbf{\hat f})(\mathbf{x}),  \mathbf{\hat g}(\mathbf{x}), \log \text{Var}(\mathbf{\hat g})(\mathbf{x}) = \phi_2(\mathbf{b}(\mathbf{x})).
    $$    
\end{enumerate}