In [None]:
import torch
import numpy as np


from fnope.simulators.darcy import Darcy2D
from fnope.flow_matching.fnope_2D import FNOPE_2D


from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam

import matplotlib.pyplot as plt


# Run on the GPU when one is available; tensors are moved to `device` below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook may touch. torch.manual_seed also seeds all CUDA
# devices; numpy is seeded too in case the simulator draws from it.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)


## Darcy Simulator

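In this tutorial we use the Darcy flow problem to showcase FNOPE (fix) for parameter fields
varying in two dimensions. "FNOPE (fix)" denotes the variant in which the model is always evaluated on the same fixed grid as the training data.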

In [None]:

# Hyper-parameters of the simulator's "Darcy_GP" prior, passed straight
# through to Darcy2D (their exact meaning is defined by the simulator).
prior_params_darcy = dict(tau=9.0, alpha=2.0, scale=1000.0)

batch_size = 256  # simulations produced per simulator call
resolution = 64   # smaller than the reported experiment so this notebook runs faster

# Parameter and observation fields are later reshaped onto a
# (resolution + 1) x (resolution + 1) grid.
theta_x_size = theta_y_size = resolution + 1

darcy = Darcy2D(
    prior="Darcy_GP",
    prior_params_darcy=prior_params_darcy,
    snr=30.0,
    batch_size=batch_size,
    resolution=resolution,
)

In [None]:
n_sim = 1000 #num training_sims

# Each simulator call is assumed to return `batch_size` simulations, so draw
# ceil(n_sim / batch_size) batches and trim the final one to reach n_sim exactly.
n_full_batches = n_sim // batch_size
last_batch_size = n_sim % batch_size
n_batches = n_full_batches + (1 if last_batch_size > 0 else 0)

all_theta = []
all_x = []

# Generate all batches and collect them on the CPU.
for i in range(n_batches):
    theta_temp, theta_res, x_temp = darcy.sample_darcy()
    n_keep = min(batch_size, n_sim - i * batch_size)  # smaller only for the last, partial batch
    all_x.append(x_temp.cpu()[:n_keep])
    all_theta.append(theta_temp.cpu()[:n_keep]) # theta is the parameter on the grid we sample the prior from.
    print(f"batch {i+1}/{n_batches} done.")

# Concatenate all
sim_theta = torch.cat(all_theta, dim=0)
sim_x = torch.cat(all_x, dim=0)
assert sim_theta.shape[0] == sim_x.shape[0] == n_sim, "unexpected number of simulations"

# Held-out "observed" data: the first n_obs simulations of a fresh batch.
n_obs = 10
theta_o, theta_res_o, x_temp_o = darcy.sample_darcy()
theta_o = theta_o[:n_obs]
x_o = x_temp_o[:n_obs]

In [None]:
# Sanity check: report whether the simulator produced any NaNs before training.
for name, tensor in (("sim_theta", sim_theta), ("sim_x", sim_x)):
    print(f"{name} contains NaNs:", tensor.isnan().any().item())

In [None]:
# Give every field an explicit singleton channel axis, (N, 1, H, W), and move
# it to the compute device.
grid_shape = (-1, 1, theta_x_size, theta_y_size)
sim_theta, sim_x, theta_o, x_o = (
    field.view(grid_shape).to(device) for field in (sim_theta, sim_x, theta_o, x_o)
)

In [None]:
modes_max = 16  # number of Fourier modes used by each FNO block

# Naming convention of FNOPE_2D: `x` is the parameter field (theta) being
# inferred, and `ctx` is the conditioning observation.
model = FNOPE_2D(
    x=sim_theta,
    ctx=sim_x,
    # No explicit simulation grid: this example always evaluates on the same
    # discretization, demonstrating FNOPE (fix).
    simulation_grid=None,
    # No vector-valued ("finite") parameters in this problem.
    x_finite=None,
    # Architecture sizes.
    modes=modes_max,
    conv_channels=16,
    ctx_embedding_channels=16,
    time_embedding_channels=4,
    position_embedding_channels=4,
    num_layers=5,
    # Flow-matching base distribution: a Gaussian process whose lengthscale
    # is set depending on `modes`.
    base_dist='gp',
    # Zero padding helps avoid artefacts when using the FFT.
    padding={"type": "zero", "pad_length": 10},
    # FNOPE (fix): always evaluate on the same grid as the training data ...
    always_equispaced=True,
    # ... and parameters/observations are defined on the same grid.
    always_match_x_theta=True,
).to(device)

total_params = sum(map(torch.numel, model.parameters()))
print(f"Number of parameters: {total_params}")


In [None]:
training_batch_size = 256
learning_rate = 1e-3

# Custom training set-up; a packaged version lives in
# fnope.flow_matching.training.py::train_fnope.
dataset = TensorDataset(sim_theta, sim_x)
# Use the *training* batch size here, not the simulator's `batch_size`
# (the two only happen to share the value 256 in this notebook).
dataloader = DataLoader(dataset, batch_size=training_batch_size, shuffle=True, drop_last=False)

optimizer = Adam(model.parameters(), lr=learning_rate)


In [None]:
model.train()
num_epochs = 200

loss_history = []  # per-epoch mean training loss, kept for later inspection/plotting
n_train = len(dataloader.dataset)

for epoch in range(num_epochs):
    # Weight each batch by its size so the last (possibly smaller) batch counts
    # proportionally; assumes model.loss returns a per-batch mean.
    avg_loss = 0.0
    for theta_batch, x_batch in dataloader:
        optimizer.zero_grad()
        loss = model.loss(theta_batch, ctx=x_batch)
        avg_loss += loss.item() * theta_batch.shape[0] / n_train

        loss.backward()
        optimizer.step()
    loss_history.append(avg_loss)
    print(f"Epoch {epoch}, mean loss: {avg_loss:.4f}")


In [None]:
# Switch the model to evaluation mode before sampling; nn.Module.eval() is
# defined as train(False). The trailing `;` suppresses the returned-module repr.
model.train(False);


## Evaluation

In [None]:
num_samples = 100

# One slice of posterior samples per observation: (num_samples, n_obs, H, W).
n_obs = x_o.shape[0]
fnope_samples = torch.zeros(num_samples, n_obs, theta_x_size, theta_y_size, device=device)

# Samples are only used for plotting and simulation, so don't build an autograd
# graph while sampling. NOTE(review): assumes model.sample does not itself need autograd.
with torch.no_grad():
    for true_idx in range(n_obs):
        print(f"Sampling for observation no. {true_idx}")
        func_samples = model.sample(
            num_samples,
            x_o[true_idx].view(-1, theta_x_size, theta_y_size),
            atol=1e-2,
            rtol=1e-2,
        )
        fnope_samples[:, true_idx] = func_samples.view(num_samples, theta_x_size, theta_y_size)

In [None]:
# Posterior predictive check: push the posterior samples for one observation
# back through the simulator.
true_idx = 0

# Second simulator whose batch size matches the number of posterior samples;
# simulate_darcy is called once on all of them below.
darcy_predictive_check = Darcy2D(
    resolution=resolution,
    batch_size=num_samples,
    snr=30.0,
)

posterior_draws = fnope_samples[:, true_idx, :, :]
t_temp, t_res_temp, fnope_predictive_samples = darcy_predictive_check.simulate_darcy(posterior_draws)

In [None]:
view = (slice(None), slice(None))
show_exp = False #show permeability or log-permeability
sample_idx = 1 #which sample to show

# Map log-permeability to permeability when requested; identity otherwise.
to_display = torch.exp if show_exp else (lambda t: t)

# Detach and move to CPU once; every panel below reuses these tensors.
theta_true = to_display(theta_o[true_idx, 0].detach().cpu())[view]          # (H, W)
post_samples = to_display(fnope_samples[:, true_idx].detach().cpu())        # (num_samples, H, W)
shown_sample = post_samples[sample_idx][view]

if show_exp:
    # Shared colour scale, anchored at 0, for the true field and the displayed
    # sample of the *same* observation (`true_idx`).
    vmin = 0
    vmax = max(shown_sample.max().item(), theta_true.max().item())
else:
    vmin = None
    vmax = None

fig, ax = plt.subplots(1, 4, figsize=(16, 5))

colors = ax[0].imshow(theta_true.numpy(), vmin=vmin, vmax=vmax)
ax[0].set_title("True theta")

colors2 = ax[1].imshow(shown_sample.numpy(), vmin=vmin, vmax=vmax)
ax[1].set_title("FNOPE")

colors3 = ax[2].imshow(torch.abs(post_samples.mean(dim=0)[view] - theta_true).numpy())
ax[2].set_title("Abs error of mean")

colors4 = ax[3].imshow(post_samples.std(0)[view].numpy())
ax[3].set_title("FNOPE std")

# Loop ends the cell, so no Colorbar repr is echoed as output.
for axis, image in zip(ax, (colors, colors2, colors3, colors4)):
    plt.colorbar(image, ax=axis)


In [None]:
view = (slice(None), slice(None))
sample_idx = 1 #which sample to show

vmin = None
vmax = None

# Detach and move to CPU once so all panels share one device and no repeated
# conversions are needed.
x_true = x_o[true_idx, 0].detach().cpu()[view]               # (H, W)
pred_samples = fnope_predictive_samples.detach().cpu()       # indexed as (num_samples, H, W) below

fig, ax = plt.subplots(1, 4, figsize=(16, 5))

colors = ax[0].imshow(x_true.numpy(), vmin=vmin, vmax=vmax)
ax[0].set_title("True x")

colors2 = ax[1].imshow(pred_samples[sample_idx][view].numpy(), vmin=vmin, vmax=vmax)
ax[1].set_title("FNOPE predictive")

colors3 = ax[2].imshow(torch.abs(pred_samples.mean(dim=0)[view] - x_true).numpy())
ax[2].set_title("Abs error of mean")

colors4 = ax[3].imshow(pred_samples.std(0)[view].numpy())
ax[3].set_title("FNOPE predictive std")

# Loop ends the cell, so no Colorbar repr is echoed as output.
for axis, image in zip(ax, (colors, colors2, colors3, colors4)):
    plt.colorbar(image, ax=axis)
