# Brain tumor simulation using PDEs

In [None]:
import matplotlib.pyplot as plt
import torch as th

from ipywidgets import interact, IntSlider, FloatSlider
from pathlib import Path

from phynn.dataloader import DynamicsElement, DataInterface, HDF5DirectlyFromFile
from phynn.pde import PDEEval, PDEStaticParams, FisherKolmogorovPDE

## Load data

In [None]:
# Preprocessed Brain-Tumor-Progression data, read from a single HDF5 file.
dataset_path = Path("../data/preprocessed/Brain-Tumor-Progression") / "2d.h5"
data_interface = HDF5DirectlyFromFile(dataset_path)

# Quick sanity report on what was loaded.
print(f"Number of loaded time series: {data_interface.times_shape[0]}")
print(f"Shape of each image: {data_interface.image_shape}")

## Reaction-diffusion (Fisher–Kolmogorov) simulation

In [None]:
def simulate_reaction_diffusion(data_interface: DataInterface):
    """Display an interactive viewer for Fisher-Kolmogorov simulations.

    Builds a ``PDEEval`` around ``FisherKolmogorovPDE``. The evaluator is
    configured with ``min_concentration=0.3`` and the boundary mask
    ``u > 0.02``; their exact semantics are defined in ``phynn.pde``.

    Sliders are exposed for:

    * the series index;
    * the simulation time;
    * the two static PDE parameters (labelled "Diffusion" and
      "Proliferation");
    * a slice index ``x`` along the first image axis.

    On every slider change the evaluator is re-run on the ``START`` element
    of the selected series, and slice ``x`` of the result is plotted. The
    colour range is taken from the whole simulated tensor, so every slice
    shares one scale.

    Args:
        data_interface: Source of the time series. ``times_shape[0]`` gives
            the number of series, and ``image_shape[0]`` bounds the slice
            slider.
    """
    n_series = data_interface.times_shape[0]

    reaction_diffusion = FisherKolmogorovPDE()
    # Shared, mutable parameter holder: the callback overwrites its values
    # before each evaluation, so the evaluator always sees the slider state.
    static_params = PDEStaticParams(0, 0)
    evaluator = PDEEval(
        reaction_diffusion,
        static_params,
        min_concentration=0.3,
        boundary_condition=lambda u: u > 0.02,
    )

    def plot_image(index: int, time: int, D: float, p: float, x: int):
        # Parameter names must match the keyword names passed to `interact`.
        static_params.values = (D, p)

        with th.no_grad():
            sample = data_interface.get(index, 0, 0)
            start = sample[DynamicsElement.START].cpu()
            # Add two leading singleton axes, and wrap time as a length-1 batch.
            batch = start[None, None]
            times = th.tensor([time])
            result = evaluator(batch, times)

        lo, hi = result.min().item(), result.max().item()
        # NOTE(review): indexing [0, 0, x] assumes the image keeps at least two
        # axes after `x`. If image_shape is 2-D, this is a 1-D row that imshow
        # cannot draw. Confirm against the dataset.
        plt.imshow(result[0, 0, x], vmin=lo, vmax=hi)
        plt.show()

    controls = {
        "index": IntSlider(min=0, max=n_series - 1, description="Index"),
        "time": IntSlider(min=0, max=100, step=5, description="Time"),
        "D": FloatSlider(
            value=0.0, min=0.0, max=2.0, step=0.025, description="Diffusion"
        ),
        "p": FloatSlider(
            value=0.0, min=0.0, max=4.0, step=0.025, description="Proliferation"
        ),
        "x": IntSlider(
            min=0, max=data_interface.image_shape[0] - 1, description="X"
        ),
    }

    interact(plot_image, **controls)

In [None]:
# Launch the interactive simulation viewer for the dataset loaded earlier in this notebook.
simulate_reaction_diffusion(data_interface)