In [1]:
import os
import time

import torch
import shutil
import numpy as np
import torch.nn as nn

from pathlib import Path
import pytorch_lightning as pl
from dataclasses import dataclass
from pytorch_lightning import Trainer
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from fimodemix.data.datasets import (
    FIMSDEDataset,
    FIMSDEDatabatch
)

from fimodemix.utils.experiment_files import ExperimentsFiles
from fimodemix.models.blocks import (
    TimeEncoding,
    Mlp,
    TransformerModel
)

from typing import Any, Dict, Optional, Union, List,Tuple
from dataclasses import dataclass,asdict, field

# Utils

In [2]:
from fimodemix import results_path

class ExperimentsFiles:
    """
    """
    def __init__(self,experiment_dir=None,experiment_indentifier=None,delete=False):
        self.delete = delete
        self.define_experiment_folder(experiment_dir,experiment_indentifier)
        self.create_directories()

    def define_experiment_folder(self,experiment_dir=None,experiment_indentifier=None):
        if experiment_dir is None:
            results_dir = str(results_path)
            if experiment_indentifier is None:
                experiment_indentifier = str(int(time.time()))
            self.experiment_dir = os.path.join(results_dir, experiment_indentifier)        
        self.tensorboard_dir = os.path.join(self.experiment_dir, "logs")
        self.checkpoints_dir = os.path.join(self.experiment_dir, "checkpoints")
    
    def create_directories(self):
        if not Path(self.experiment_dir).exists():
            os.makedirs(self.experiment_dir)
        else:
            if self.delete:
                shutil.rmtree(self.experiment_dir)
                os.makedirs(self.experiment_dir)
            else:
                raise Exception("Folder Exist no Experiments Created Set Delete to True")
            
        if not os.path.isdir(self.tensorboard_dir):
            os.makedirs(self.tensorboard_dir)
        
        if not os.path.isdir(self.checkpoints_dir):
            os.makedirs(self.checkpoints_dir)

## Inference/Sample Model

# Model

## Parameters

In [3]:
@dataclass
class FIMSDEModelParams:
    # data
    input_size: int = 1  # Original input size

    max_dimension:int = 1
    max_hypercube_size:int = 1
    max_num_steps:int = 1

    # model architecture
    dim_time:int = 19

    # phi_0 / first data encoding
    x0_hidden_layers: List[int] = field(default_factory=lambda:[50,50])
    x0_out_features: int = 21
    x0_dropout: float = 0.2

    encoding0_dim:int = 40 #  x0_out_features + dim_time

    #psi_1 / first transformer
    psi1_nhead:int = 2
    psi1_hidden_dim:int = 300
    psi1_nlayers:int = 2

    #Multiheaded Attention 1 / first path summary
    query_dim:int = 10

    n_heads: int = 4
    hidden_dim: int = 64
    output_size: int = 1
    batch_size: int = 32
    seq_length: int = 10

    # training
    num_epochs: int = 10
    learning_rate: float = 0.001
    embed_dim: int = 8  # New embedding dimension

    def __post__init__(self):
        self.encoding0_dim = self.x0_out_features + self.dim_time

In [4]:
params = FIMSDEModelParams()

## Architecture

In [5]:
# 1. Define your query generation model (a simple linear layer can work)
class QueryGenerator(nn.Module):
    def __init__(self, input_dim, query_dim):
        super(QueryGenerator, self).__init__()
        self.linear = nn.Linear(input_dim, query_dim)

    def forward(self, x):
        return self.linear(x)
    
# 2. Define a static query matrix as a learnable parameter
class StaticQuery(nn.Module):
    def __init__(self, num_steps, query_dim):
        super(StaticQuery, self).__init__()
        self.queries = nn.Parameter(torch.randn(num_steps, query_dim))  # Learnable queries

    def forward(self):
        return self.queries

# 3. Model Following FIM conventions
class FIMSDE_p(pl.LightningModule):
    """
    This is the more simple architecture for 

    Stochastic Differential Equation Trainining

    """
    def __init__(
            self, 
            params: str|FIMSDEModelParams,
            device:torch.device = None
        ):
        super(FIMSDE_p, self).__init__()
        self._create_model(params)
        if device is not None:
            self.to(device)

    def _create_model(
        self,
        params: dict | FIMSDEModelParams,
    ):
        # Architecture ---------
        self.phi_t0 = TimeEncoding(params.dim_time)

        self.phi_x0 = Mlp(in_features=params.max_dimension,
                          out_features=params.x0_out_features,
                          hidden_layers=params.x0_hidden_layers,
                          output_act=nn.SiLU())

        self.phi_1 = Mlp(in_features=params.max_dimension,
                         out_features=params.max_dimension,
                         hidden_layers=params.x0_hidden_layers)

        self.phi_2 = Mlp(in_features=params.encoding0_dim,
                         out_features=params.max_dimension,
                         hidden_layers=params.x0_hidden_layers)

        self.psi1 = TransformerModel(input_dim=params.encoding0_dim, 
                                     nhead=params.psi1_nhead, 
                                     hidden_dim=params.psi1_hidden_dim, 
                                     nlayers=params.psi1_nlayers)
        
        #self.queries = nn.Parameter(torch.randn(1, params.encoding0_dim))
        self.query_1x = QueryGenerator(input_dim=params.max_dimension,
                                       query_dim=params.encoding0_dim)

        self.query_1 =  StaticQuery(num_steps=params.max_num_steps,
                            query_dim=params.encoding0_dim)

        # Create the MultiheadAttention module
        self.omega_1 = nn.MultiheadAttention(params.encoding0_dim, params.psi1_nhead)

        # Loss ------------------
        self.criterion = nn.MSELoss()


    def forward(
            self, 
            hypercube_locations:torch.tensor, 
            obs_values:torch.tensor, 
            obs_times:torch.tensor,
            observation_mask:torch.tensor,
            training:bool=True,
            ) -> Tuple[torch.tensor,torch.tensor]:
        """
        Args:
            
            hypercube_locations [B, H, D] observation values. optionally with noise.
            obs_values [B, T, D] observation times
            obs_times [B, T, D] 
            observation_mask, dtype: bool (0: value is observed, 1: value is masked out)
            training (bool): flag indicating if model is in training mode. Has an impact on the output.
            
            with B: batch size, T: number of observation times, D: dimensionsm, H: number of fine grid points (locations)

        Returns:
            if training:
                dict: losses
            else:
                dict: losses (if target drift is provided), metrics, visualizations data
        """
        batch_size = obs_times.size(0)
        num_steps = obs_times.size(1)
        dimensions = obs_values.size(2)
        num_hyper = hypercube_locations.size(1)

        # Encoding Paths -----------------
        time_encoding_ = self.phi_t0(obs_times.reshape(batch_size*num_steps,-1)) #(batch_size*num_steps,dim_time)
        x_enconding = self.phi_x0(obs_values.reshape(batch_size*num_steps,-1)) #(batch_size*num_steps,x0_out_features)
        H = torch.cat([time_encoding_,x_enconding],dim=1) #(batch_size*num_steps,encoding0_dim)
        H  = H.reshape(batch_size,num_steps,params.encoding0_dim) 
        H = self.psi1(torch.transpose(H,0,1)) # (seq_lenght,batch_size,encoding0_dim)

        # Trunk Queries ------------------
        hypercube_locations = hypercube_locations.reshape(batch_size*num_hyper,dimensions)
        tx = self.query_1x(hypercube_locations)  # Shape: (batch_size*num_steps, encoding0_dim)
        # Reshape queries to match the attention requirements
        tx = tx.reshape(num_hyper, batch_size, params.encoding0_dim)  # Shape: (num_hyper, batch_size, encoding0_dim)

        # Representation per path
        # attn_output, _ = multihead_attn(queries[:,None,:].repeat(1,batch_size,1), H, H) # Shape: (1, batch_size, query_dim)
        attn_output, _ = self.omega_1(tx, H, H) # Shape: (num_hyper, batch_size, query_dim)
        attn_output = torch.transpose(attn_output,1,0) # Shape: (num_hyper, batch_size, query_dim)
        attn_output = attn_output.reshape(num_hyper*batch_size,params.encoding0_dim)

        # obtain all heads
        f_hat = self.phi_2(attn_output).reshape(batch_size,num_hyper,dimensions)
        
        return f_hat

    def loss(
            self,
            f_hat:torch.tensor = None,
            g_hat:torch.tensor = None,
            f_var_hat:torch.tensor = None,
            g_var_hat:torch.tensor = None,
            mask:torch.tensor = None,
            drift_at_hypercube:torch.tensor = None,
            diffusion_at_hypercube:torch.tensor = None,
        ):
        """
        obs_values, obs_times, diffusion_at_hypercube, drift_at_hypercube, hypercube_locations, mask
        Compute the loss of the FIMODE_mix model (in original space).

        The loss consists of supervised losses
            - negative log-likelihood of the vector field values at fine grid points
            - negative log-likelihood of the initial condition
        and an unsupervised loss
            - one-step ahead prediction loss.
        The total loss is a weighted sum of all losses. The weights are defined in the loss_configs. (loss_scale_drift, loss_scale_init_cond, loss_scale_unsuperv_loss)

        Args:
            f_hat (tuple): mean and log standard deviation of the vector field concepts (in original space) ([B, L, D], [B, L, D])

        Returns:
            dict: llh_drift, llh_init_cond, unsupervised_loss, loss = weighted sum of all losses
        """
        llh_drift = self.criterion(f_hat,drift_at_hypercube)
        return llh_drift
    
    def training_step(
            self, 
            batch, 
            batch_idx
        ):
        obs_values, obs_times, diffusion_at_hypercube, drift_at_hypercube, hypercube_locations, mask = batch
        f_hats = self.forward(hypercube_locations=hypercube_locations,
                               obs_values=obs_values,
                               obs_times=obs_times,
                               observation_mask=mask,
                               training=True)
        loss = self.loss(f_hats, 
                         drift_at_hypercube=drift_at_hypercube)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

# Training Code

## Set up

In [6]:
# experiment
experiment_files = ExperimentsFiles(experiment_indentifier="test",delete=True)

# Create dataset and DataLoader using ModelParams
params = FIMSDEModelParams(seq_length=10, 
                           batch_size=32,
                           num_epochs=2)

# Define Data Set
dataset = FIMSDEDataset(params=params)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=24, shuffle=True)
databatch = next(data_loader.__iter__())
obs_values, obs_times, diffusion_at_hypercube, drift_at_hypercube, hypercube_locations, mask = databatch

batch_size = obs_times.size(0)
num_steps = obs_times.size(1)
dimensions = obs_values.size(2)
num_hyper = hypercube_locations.size(1)

Max Hypercube Size: 1024
Max Dimension: 3
Max Num Steps: 129


  data: FIMSDEDatabatch = torch.load(file_path)  # Adjust loading method as necessary


In [7]:
# Set up TensorBoard logger
logger = TensorBoardLogger(experiment_files.tensorboard_dir, 
                           name="time_series_transformer")

# Set up Model Checkpointing
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    dirpath=experiment_files.checkpoints_dir,
    filename='best-checkpoint',
    save_top_k=1,
    mode='min',
    save_weights_only=True,
    every_n_train_steps=100  # Save checkpoint every 100 training steps
)

## Training

In [8]:
# Instantiate the model and train
model = FIMSDE_p(
    params
)

f_hats = model(hypercube_locations, 
            obs_values, 
            obs_times,
            mask)
f_hats.shape



torch.Size([24, 1024, 3])

In [9]:
model.loss(f_hats, drift_at_hypercube=drift_at_hypercube)

tensor(151414.3281, grad_fn=<MseLossBackward0>)

In [10]:
trainer = Trainer(
    max_epochs=params.num_epochs,
    logger=logger,
    callbacks=[checkpoint_callback]
)

trainer.fit(model, data_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | phi_t0    | TimeEncoding       | 38     | train
1 | phi_x0    | Mlp                | 3.8 K  | train
2 | phi_1     | Mlp                | 2.9 K  | train
3 | phi_2     | Mlp                | 4.8 K  | train
4 | psi1      | TransformerModel   | 93.2 K | train
5 | query_1x  | QueryGenerator     | 160    | train
6 | query_1   | StaticQuery        | 5.2 K  

Epoch 0:   2%|▏         | 8/417 [00:02<02:11,  3.12it/s, v_num=0]

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Epoch 1: 100%|██████████| 417/417 [00:07<00:00, 52.63it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 417/417 [00:07<00:00, 52.63it/s, v_num=0]
