In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.distributions import MultivariateNormal
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

from fnope.flow_matching.fnope_1D import FNOPE_1D
from fnope.simulators.gp_priors import get_gaussian_process_prior_1d
from fnope.simulators.gp_priors import squared_exponential_kernel
from fnope.simulators.simulator import SIR


# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Seed torch's default random number generator so the sampling and training below are repeatable.
torch.manual_seed(0)


## SIR Model

The SIR example has a single function-valued (time-varying) parameter channel and a three-channel observation, as well as additional vector-valued parameters.

In [None]:
# Define prior

seq_len = 100  # number of simulation gridpoints
T = 50  # length of simulation

# Gaussian-process prior over the function-valued parameter
# (built with num_points=seq_len gridpoints and domain_length=T).
func_prior = get_gaussian_process_prior_1d(num_points=seq_len, domain_length=T, mean=0.0, lengthscale=7.0, sigma=2.5)

# Uniform(0, 0.5) prior on each of the two vector-valued parameters.
# Column 0 is passed to the simulator as `gamma`, column 1 as `delta`.
vec_prior = torch.distributions.Uniform(torch.tensor([0.0, 0.0]), torch.tensor([0.5, 0.5]))




In [None]:
#Generate prior samples

num_sims = 5000  # number of training simulations
num_obs = 10  # number of ground-truth (parameter, observation) pairs kept for evaluation
likelihood_scale = 0.05 #i.i.d noise


#sample evenly spaced simulation grid
sim_ts = torch.linspace(0, T, seq_len).to(device)

# We could already simulate each prior sample on a non-uniform grid, if our simulator supports this.
# sim_ts = torch.rand(seq_len).to(device)
# sim_ts = torch.sort(sim_ts)[0]
# sim_ts *= T


# Function-valued parameter: GP draws squashed through a sigmoid, so every value lies in (0, 1).
sim_theta_func = func_prior.sample((num_sims,)).to(device)
sim_theta_func = torch.sigmoid(sim_theta_func)

# Vector-valued parameters: column 0 is gamma, column 1 is delta.
sim_theta_vec = vec_prior.sample((num_sims,)).to(device)

sim_x = SIR(sim_theta_func, sim_ts, gamma=sim_theta_vec[:, 0], delta=sim_theta_vec[:, 1], likelihood_scale=likelihood_scale, device=device)

#Also generate ground truth (same priors and simulator settings as the training set)
theta_func_o = func_prior.sample((num_obs,)).to(device)
theta_func_o = torch.sigmoid(theta_func_o)
theta_vec_o = vec_prior.sample((num_obs,)).to(device)
x_o = SIR(theta_func_o, sim_ts, gamma=theta_vec_o[:, 0], delta=theta_vec_o[:, 1], likelihood_scale=likelihood_scale, device=device)


In [None]:
# Bring every time series into (Batch, Channel, Time) layout:
# one channel for the function-valued parameter, three channels for the observation.
theta_layout = (-1, 1, seq_len)
obs_layout = (-1, 3, seq_len)

# Training set
sim_theta_func = sim_theta_func.view(theta_layout)
sim_x = sim_x.view(obs_layout)

# Ground truth
theta_func_o = theta_func_o.view(theta_layout)
x_o = x_o.view(obs_layout)

## Define FNOPE model

In [None]:
modes_max = 32 #number of modes used by FNO blocks

# We normalize the grid to [0,1] - this means we also need to normalize the grid to [0,1] when evaluating!
normalized_grid = sim_ts/T

# When using FFT, padding is helpful to avoid artefacts; it is switched off here.
padding_cfg = {"type": "none", "pad_length": 0}

# Grid perturbation applied to each training sample.
# This becomes unnecessary if we are already passing in simulations on different grids.
point_noise_cfg = {
    "jitter": 0.001,  # scale of noise added to each position independently
    "target_gridsize": 50,  # amount of points left over for each training sample after masking
}

model = FNOPE_1D(
    # note: x in `FNOPE` models is the parameter (theta), `ctx` is the observation
    x=sim_theta_func,
    ctx=sim_x,
    simulation_grid=normalized_grid,
    # vector-valued parameters ("finite")
    x_finite=sim_theta_vec,
    modes=modes_max,
    conv_channels=16,
    ctx_embedding_channels=8,
    time_embedding_channels=4,
    position_embedding_channels=4,
    num_layers=5,
    # Type of distribution used as the base distribution for flow matching.
    # We use a Gaussian Process (the lengthscale is set depending on `modes`).
    base_dist='gp',
    padding=padding_cfg,
    training_point_noise=point_noise_cfg,
    # Set to True for FNOPE (fix) - always evaluating on the same grid as the training data
    always_equispaced=False,
    # If always_equispaced=True, can also specify whether parameters and observations are defined on the same grid.
    always_match_x_theta=False,
)
model = model.to(device)

# Report the model size.
total_params = sum(param.numel() for param in model.parameters())
print(f"Number of parameters: {total_params}")


In [None]:
batch_size = 512
learning_rate = 1e-3

# Custom training setup; a ready-made training function is also available at
# fnope.flow_matching.training.py::train_fnope
# Each dataset item is a (function-valued theta, vector-valued theta, observation) triple.
dataset = TensorDataset(sim_theta_func, sim_theta_vec, sim_x)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,  # keep the final, smaller batch
)

optimizer = Adam(model.parameters(), lr=learning_rate)


In [None]:
model.train()
num_epochs = 400
num_train = len(dataloader.dataset)

for epoch in range(num_epochs):
    # Running epoch loss: each batch loss is weighted by that batch's share of the dataset.
    avg_loss = 0.0
    for theta_func_batch, theta_vec_batch, x_batch in dataloader:
        optimizer.zero_grad()
        # When the model is defined with always_match_x_theta=False and always_equispaced=False,
        # the simulation positions must be passed for both theta and x
        # (simulation_positions and ctx_simulation_positions), EVEN if they are the same.
        # This holds for BOTH training AND evaluation.
        loss = model.loss(
            theta_func_batch,
            ctx=x_batch,
            x_finite=theta_vec_batch,
            simulation_positions=sim_ts/T,
            ctx_simulation_positions=sim_ts/T,
        )
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()*theta_func_batch.shape[0]/num_train
    print(f"Epoch {epoch}, total loss: {avg_loss:.4f}")


In [None]:
# Switch the model to evaluation mode before sampling; the trailing semicolon suppresses the cell's output.
model.eval();


## Evaluate Model

In [None]:
num_samples = 1000  # posterior samples drawn per observation
num_obs = x_o.shape[0]  # one posterior per ground-truth observation

# Sample buffers, indexed as (sample, observation, ...). The last axis of the vector buffer
# holds the two vector-valued parameters (gamma, delta).
fnope_func_samples = torch.zeros(num_samples, num_obs, seq_len, device=device)
fnope_vec_samples = torch.zeros(num_samples, num_obs, 2, device=device)

for true_idx in range(num_obs):
    print(f"Sampling for observation no. {true_idx}")
    # When the model is defined with always_match_x_theta=False and always_equispaced=False,
    # the simulation positions must be passed for both theta and x
    # (point_positions and ctx_point_positions), EVEN if they are the same.
    # This holds for BOTH training AND evaluation. The grid is normalized to [0,1] as in training.
    func_samples, vec_samples = model.sample(num_samples, x_o[true_idx].view(-1, seq_len), point_positions=sim_ts/T, ctx_point_positions=sim_ts/T, atol=1e-2, rtol=1e-2)
    fnope_func_samples[:, true_idx, :] = func_samples.view(num_samples, seq_len)
    fnope_vec_samples[:, true_idx, :] = vec_samples.view(num_samples, 2)

In [None]:
#Posterior predictive

true_idx = 2  # index of the ground-truth observation to inspect

# Push the posterior samples for this observation back through the simulator.
func_draws = fnope_func_samples[:, true_idx, :].view(-1, seq_len)
gamma_draws = fnope_vec_samples[:, true_idx, 0].view(-1)
delta_draws = fnope_vec_samples[:, true_idx, 1].view(-1)

fnope_predictive_samples = SIR(
    func_draws,
    sim_ts,
    gamma=gamma_draws,
    delta=delta_draws,
    likelihood_scale=likelihood_scale,
    device=device,
)

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(18, 5))
np_times = sim_ts.cpu().numpy()

# Posterior draws and ground truth for the selected observation (`true_idx`).
vec_post = fnope_vec_samples[:, true_idx, :].cpu().numpy()
vec_true = theta_vec_o[true_idx].cpu().numpy()
func_post = fnope_func_samples[:, true_idx, :]
pred_post = fnope_predictive_samples[:, 0, :]  # first observation channel only

# 1. Scatter plot of fnope_vec_samples with ground truth (pure matplotlib)
axs[0].scatter(vec_post[:, 0], vec_post[:, 1], alpha=0.3, label='Samples')
axs[0].scatter(vec_true[0], vec_true[1], color='red', label='Ground Truth')
axs[0].set(
    xlabel='gamma',
    ylabel='delta',
    title='Posterior samples (vec) with ground truth',
    xlim=(0, 0.5),  # axis limits match the Uniform(0, 0.5) prior range
    ylim=(0, 0.5),
)
axs[0].legend()

# 2. Distribution of fnope_func_samples (mean and pointwise 2.5%-97.5% quantile band)
mean_func = func_post.mean(dim=0).cpu().numpy()
lower = func_post.quantile(0.025, dim=0).cpu().numpy()
upper = func_post.quantile(0.975, dim=0).cpu().numpy()
axs[1].plot(np_times, mean_func, label='Posterior Mean')
axs[1].fill_between(np_times, lower, upper, alpha=0.3, label='95% CI')
axs[1].plot(np_times, theta_func_o[true_idx, 0, :].cpu().numpy(), color='red', label='Ground Truth')
axs[1].set_xlabel('time')
axs[1].set_title('Posterior func samples')
axs[1].legend()

# 3. Distribution of fnope_predictive_samples (mean and pointwise 2.5%-97.5% quantile band, first channel)
mean_pred = pred_post.mean(dim=0).cpu().numpy()
lower_pred = pred_post.quantile(0.025, dim=0).cpu().numpy()
upper_pred = pred_post.quantile(0.975, dim=0).cpu().numpy()
axs[2].plot(np_times, mean_pred, label='Posterior Predictive Mean')
axs[2].fill_between(np_times, lower_pred, upper_pred, alpha=0.3, label='95% CI')
axs[2].plot(np_times, x_o[true_idx, 0, :].cpu().numpy(), color='red', label='Ground Truth')
axs[2].set_xlabel('time')
axs[2].set_title('Posterior predictive samples')
axs[2].legend()

plt.tight_layout()
plt.show()

## Evaluate on nonuniform discretizations

In [None]:
# Create some artificial simulations on nonuniform discretizations
# (MultivariateNormal and squared_exponential_kernel are imported in the notebook's top import cell.)

uneven_seq_len = 100
num_uneven_obs = 10  # number of ground-truth pairs simulated on the uneven grid

# Random time points in [0, T), sorted in increasing order.
uneven_ts = torch.rand(uneven_seq_len).to(device)
uneven_ts = torch.sort(uneven_ts)[0] * T

# Zero-mean multivariate normal over the function values at the uneven time points.
# The kernel arguments 7.0 and 2.5 repeat the `lengthscale` and `sigma` values given to the
# equispaced prior (presumably the same meaning here - confirm against squared_exponential_kernel).
# A small multiple of the identity is added to the covariance (presumably for numerical stability).
mean = torch.zeros(uneven_seq_len, dtype=torch.float32, device=device)
cov = (
    squared_exponential_kernel(uneven_ts, uneven_ts, 7.0, 2.5)
    + torch.eye(uneven_seq_len, device=device) * 1e-5
)
mvn = MultivariateNormal(mean, covariance_matrix=cov)

# Same sigmoid transform, vector prior and simulator settings as for the equispaced data.
theta_func_1 = mvn.sample((num_uneven_obs,)).to(device)
theta_func_1 = torch.sigmoid(theta_func_1)
theta_vec_1 = vec_prior.sample((num_uneven_obs,)).to(device)
x_1 = SIR(theta_func_1, uneven_ts, gamma=theta_vec_1[:, 0], delta=theta_vec_1[:, 1], likelihood_scale=likelihood_scale, device=device)


In [None]:
num_samples = 1000  # posterior samples drawn per observation
num_obs_1 = x_1.shape[0]  # one posterior per observation on the uneven grid

# Sample buffers, indexed as (sample, observation, ...). The last axis of the vector buffer
# holds the two vector-valued parameters (gamma, delta).
fnope_func_samples_1 = torch.zeros(num_samples, num_obs_1, uneven_seq_len, device=device)
fnope_vec_samples_1 = torch.zeros(num_samples, num_obs_1, 2, device=device)

for true_idx in range(num_obs_1):
    print(f"Sampling for observation no. {true_idx}")
    # The uneven grid is normalized to [0,1] in the same way as the training grid, and is passed
    # for both the parameter positions and the observation positions.
    func_samples, vec_samples = model.sample(num_samples, x_1[true_idx].view(-1, uneven_seq_len), point_positions=uneven_ts/T, ctx_point_positions=uneven_ts/T, atol=1e-2, rtol=1e-2)
    fnope_func_samples_1[:, true_idx, :] = func_samples.view(num_samples, uneven_seq_len)
    fnope_vec_samples_1[:, true_idx, :] = vec_samples.view(num_samples, 2)

In [None]:
#Posterior predictive

true_idx = 2  # index of the ground-truth observation to inspect

# Push the uneven-grid posterior samples for this observation back through the simulator.
# NOTE: this rebinds `fnope_predictive_samples`, replacing the equispaced result computed earlier.
func_draws_1 = fnope_func_samples_1[:, true_idx, :].view(-1, uneven_seq_len)
gamma_draws_1 = fnope_vec_samples_1[:, true_idx, 0].view(-1)
delta_draws_1 = fnope_vec_samples_1[:, true_idx, 1].view(-1)

fnope_predictive_samples = SIR(
    func_draws_1,
    uneven_ts,
    gamma=gamma_draws_1,
    delta=delta_draws_1,
    likelihood_scale=likelihood_scale,
    device=device,
)

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(18, 5))
np_times = uneven_ts.cpu().numpy()

# Posterior draws and ground truth for the selected uneven-grid observation (`true_idx`).
vec_post_1 = fnope_vec_samples_1[:, true_idx, :].cpu().numpy()
vec_true_1 = theta_vec_1[true_idx].cpu().numpy()
func_post_1 = fnope_func_samples_1[:, true_idx, :]
pred_post_1 = fnope_predictive_samples[:, 0, :]  # first observation channel only

# 1. Scatter plot of fnope_vec_samples_1 with ground truth (pure matplotlib)
axs[0].scatter(vec_post_1[:, 0], vec_post_1[:, 1], alpha=0.3, label='Samples')
axs[0].scatter(vec_true_1[0], vec_true_1[1], color='red', label='Ground Truth')
axs[0].set(
    xlabel='gamma',
    ylabel='delta',
    title='Posterior samples (vec) with ground truth',
    xlim=(0, 0.5),  # axis limits match the Uniform(0, 0.5) prior range
    ylim=(0, 0.5),
)
axs[0].legend()

# 2. Distribution of fnope_func_samples_1 (mean and pointwise 2.5%-97.5% quantile band)
mean_func = func_post_1.mean(dim=0).cpu().numpy()
lower = func_post_1.quantile(0.025, dim=0).cpu().numpy()
upper = func_post_1.quantile(0.975, dim=0).cpu().numpy()
axs[1].plot(np_times, mean_func, label='Posterior Mean')
axs[1].fill_between(np_times, lower, upper, alpha=0.3, label='95% CI')
axs[1].plot(np_times, theta_func_1[true_idx, :].cpu().numpy(), color='red', label='Ground Truth')
axs[1].set_xlabel('time')
axs[1].set_title('Posterior func samples')
axs[1].legend()

# 3. Distribution of fnope_predictive_samples (mean and pointwise 2.5%-97.5% quantile band, first channel)
mean_pred = pred_post_1.mean(dim=0).cpu().numpy()
lower_pred = pred_post_1.quantile(0.025, dim=0).cpu().numpy()
upper_pred = pred_post_1.quantile(0.975, dim=0).cpu().numpy()
axs[2].plot(np_times, mean_pred, label='Posterior Predictive Mean')
axs[2].fill_between(np_times, lower_pred, upper_pred, alpha=0.3, label='95% CI')
axs[2].plot(np_times, x_1[true_idx, 0, :].cpu().numpy(), color='red', label='Ground Truth')
axs[2].set_xlabel('time')
axs[2].set_title('Posterior predictive samples')
axs[2].legend()

plt.tight_layout()
plt.show()