In [15]:
import os
import yaml
import torch

from fimodemix.utils.experiment_files import ExperimentsFiles

from fimodemix.configs.config_classes.fim_sde_config import FIMSDEpModelParams
from fimodemix.data.datasets import FIMSDEpDatabatchTuple
from fimodemix.data.dataloaders import FIMSDEpDataLoader
from fimodemix.models.fim_sdep import FIMSDEp
from fimodemix.utils.grids import define_mesh_points,define_mask

from fimodemix.data.generation_sde import constant_diffusion

# Read Experiment

In [2]:
# Define the experiment directory where everything (params, checkpoints) is located.
# Set the FIMODEMIX_EXPERIMENT_DIR environment variable to point at a different
# experiment or to run this notebook on another machine; the absolute path below
# is only the fallback default.
experiment_dir = os.environ.get(
    "FIMODEMIX_EXPERIMENT_DIR",
    r"C:\Users\cesar\Desktop\Projects\FoundationModels\fimodemix\results\1729141498",
)
# delete=False: presumably keeps the existing experiment folder intact — confirm in ExperimentsFiles
experiment_files = ExperimentsFiles(experiment_dir,delete=False)
checkpoint_path = experiment_files.get_lightning_checkpoint_path("best")

In [4]:
# Load the training hyper-parameters and the "best" checkpoint for inference on CPU.
params = FIMSDEpModelParams.from_yaml(experiment_files.params_yaml)
# map_location remaps tensors saved on a GPU so the checkpoint also loads on a
# CPU-only machine; eval() switches off training-only behaviour (dropout etc., if any)
# because the model is only used for simulation below.
model = FIMSDEp.load_from_checkpoint(checkpoint_path, map_location=torch.device("cpu")).to(torch.device("cpu"))
model.eval()
# A single batch from the dataloaders is used as the model's context.
dataloaders = FIMSDEpDataLoader(params)
databatch = dataloaders.one_batch
# First hypercube location of each batch element — presumably shape [B, D]; confirm.
X = databatch.hypercube_locations[:,0,:]

  data: FIMSDEpDatabatch = torch.load(file_path)  # Adjust loading method as necessary


Max Hypercube Size: 1024
Max Dimension: 3
Max Num Steps: 129
Max Hypercube Size: 1024
Max Dimension: 3
Max Num Steps: 129
Max Hypercube Size: 1024
Max Dimension: 3
Max Num Steps: 129


In [13]:

#mask = define_mask(X,databatch)
#grid_d = define_mesh_points(total_points=num_grid_points,n_dims=2)
#all_grid = torch.zeros(X.size(0),1024,)

torch.Size([32, 1024, 3])

# Solver

In [37]:
from tqdm import tqdm  # Import tqdm for the progress bar
from torch import Tensor
from typing import Tuple
from fimodemix.data.generation_sde import constant_diffusion

def model_as_drift_n_diffusion(
        model:FIMSDEp,
        X:Tensor,
        time:Tensor=None,
        databatch:FIMSDEpDatabatchTuple=None
    )->Tuple[Tensor,Tensor]:
    """
    Evaluates the model drift and the constant diffusion at the states X and
    zeroes both outside each process' active dimensions (the padding).

    Args:
        model (FIMSDEp): model called as ``model(databatch, X[:, None, :])``
        X (Tensor[B,D]): current states
        time: not used by this function; kept so the signature looks like an
            SDE (state, time) callback
        databatch (FIMSDEpDatabatchTuple): context batch; its
            ``diffusion_parameters`` and ``process_dimension`` fields are read
    Returns:
        drift (Tensor[B,D]), diffusion (masked output of ``constant_diffusion``)
    """
    B, D = X.size(0), X.size(1)
    # The model is evaluated on a single location per batch element: [B, 1, D].
    drift = model(databatch, X.unsqueeze(1))
    # reshape rather than squeeze(): squeeze() would also drop the batch axis
    # when B == 1 (or the state axis when D == 1) and break the [B, D] layout.
    # Assumes the model returns B*D values (e.g. [B, 1, D]) — confirm.
    drift = drift.reshape(B, D)
    # Pass the original 2-D states; X is already [B, D].
    diffusion = constant_diffusion(X, None, databatch.diffusion_parameters)
    # True for the dimensions a process actually uses, False for padding.
    # Assumes process_dimension broadcasts against [B, D] (e.g. shape [B, 1]) — confirm.
    mask = torch.arange(D, device=X.device).expand(B, -1) < databatch.process_dimension
    drift = drift * mask.to(drift.dtype)  # zero out padded dimensions
    diffusion = diffusion * mask.to(diffusion.dtype)  # zero out padded dimensions
    return drift,diffusion

def model_euler_maruyama_step(states,dt,model:FIMSDEp,databatch:FIMSDEpDatabatchTuple):
    """
    Advances every state by one Euler-Maruyama step, treating the diffusion
    as diagonal (applied element-wise):

        x_{t+dt} = x_t + f(x_t) * dt + g(x_t) * sqrt(dt) * eps,  eps ~ N(0, I)

    Args:
        states (Tensor[B,D]): current states
        dt (float): step size
        model (FIMSDEp): model that provides the drift
        databatch (FIMSDEpDatabatchTuple): context batch forwarded to the model
    Returns:
        Tensor[B,D]: the states after one step
    """
    drift, diffusion = model_as_drift_n_diffusion(model, states, None, databatch)
    # Deterministic (drift) contribution.
    stepped = states + dt * drift
    # Stochastic contribution, added in place; noise is drawn after the forward pass.
    stepped += diffusion * torch.tensor(dt).sqrt() * torch.randn_like(states)
    return stepped

def model_euler_maruyama_loop(
        num_steps: int = 100,
        dt: float = 0.01,
        model: FIMSDEp = None,
        databatch: FIMSDEpDatabatchTuple = None,
):
    """
    Simulates one path per batch element from the model using the
    Euler-Maruyama method.

    This is expensive: every Euler-Maruyama step requires a full forward pass
    of the model. States are detached after each step so the autograd graph
    is not chained through all ``num_steps`` forward passes, which would make
    memory grow with every step.

    Args:
        num_steps (int): number of Euler-Maruyama steps
        dt (float): step size
        model (FIMSDEp): model that provides the drift
        databatch (FIMSDEpDatabatchTuple): context batch; ``obs_values.size(0)``
            sets the number of paths and ``obs_values.size(2)`` the state dimension
    Returns:
        paths (Tensor[B, num_steps + 1, D]): simulated states; index 0 is the initial state
        times (Tensor[B, num_steps + 1]): the grid 0, dt, ..., num_steps * dt, repeated per path
    """
    num_paths = databatch.obs_values.size(0)
    dimensions = databatch.obs_values.size(2)
    # Keep everything on the same device as the data instead of hard-coding CPU.
    device = databatch.obs_values.device

    # Random initial states squashed into (0, 1) by the sigmoid.
    states = torch.sigmoid(torch.normal(0., 1., size=(num_paths, dimensions))).to(device)

    paths = torch.zeros((num_paths, num_steps + 1, dimensions), device=device)  # +1 for initial state
    paths[:, 0] = states  # slice assignment copies, no clone needed

    times = torch.linspace(0., num_steps * dt, num_steps + 1, device=device)
    times = times[None, :].repeat(num_paths, 1)

    for step in tqdm(range(num_steps), desc="Simulating steps", unit="step"):
        # Detach so this step's graph can be freed before the next forward pass.
        states = model_euler_maruyama_step(states, dt, model, databatch).detach()
        paths[:, step + 1] = states

    return paths,times

In [36]:
# Simulate 100 Euler-Maruyama steps of size 0.01 for every element of the batch.
# The loop returns a (paths, times) tuple, so unpack both values.
paths, times = model_euler_maruyama_loop(num_steps=100, dt=0.01, model=model, databatch=databatch)

Simulating steps:   0%|          | 0/100 [00:00<?, ?step/s]

Simulating steps: 100%|██████████| 100/100 [00:04<00:00, 21.89step/s]
