In [None]:
# Pick GPUs for this kernel before torch is imported.
# NOTE(review): autocvd presumably selects free devices by setting
# CUDA_VISIBLE_DEVICES — confirm against the autocvd package.
from autocvd import autocvd

autocvd(num_gpus=5)

In [None]:
import os

# Root of the COMPASS repository checkout. Set the COMPASS_ROOT environment
# variable to override it instead of editing this absolute path.
COMPASS_ROOT = os.environ.get("COMPASS_ROOT", "/export/home/bguenes/COMPASS/")

# `src.compass` is imported relative to the repository root, so switch there for
# the import. Always return to the tutorials folder afterwards, even if the import
# fails, because later cells use paths relative to it (e.g. "data/predator_prey").
os.chdir(COMPASS_ROOT)
try:
    from src.compass import ScoreBasedInferenceModel as SBIm
    from src.compass import ModelTransfuser as MTf
finally:
    os.chdir(os.path.join(COMPASS_ROOT, "tutorials"))

In [None]:
# Alternative imports, presumably for when `compass` is installed as a package
# rather than imported from the repository's `src` folder; use these instead of
# the `src.compass` imports above in that case.
# from compass import ScoreBasedInferenceModel as SBIm
# from compass import ModelTransfuser as MTf

In [None]:
import torch

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# --- 1. General ODE Solver (using simple Euler method) ---
def solve_ode(model_func, initial_state, params, t_max, dt):
    """
    Integrate a system of ODEs with the explicit (forward) Euler method.

    After every step the state is clamped at zero, so populations never
    become negative.

    Args:
        model_func (callable): ``model_func(state, params) -> derivatives`` with
            ``derivatives`` shaped like ``state`` (e.g. lotka_volterra).
        initial_state (torch.Tensor): Starting values, shape (batch, n_vars);
            for the predator-prey models this is [[N0, P0], ...].
        params: Model parameters, passed unchanged to ``model_func`` each step.
        t_max (float): End time (exclusive).
        dt (float): The time step.

    Returns:
        (torch.Tensor, torch.Tensor): Time points, shape (n_steps,), and the
        state history, shape (batch, n_steps, n_vars), where ``history[:, 0]``
        equals ``initial_state``. Both live on ``initial_state``'s device.
    """
    device = initial_state.device
    time_steps = torch.arange(0, t_max, dt, device=device)
    n_steps = len(time_steps)
    # Size the state axis from the input instead of assuming two variables.
    history = torch.zeros(initial_state.shape[0], n_steps, initial_state.shape[-1], device=device)
    if n_steps == 0:
        # Empty horizon (t_max <= 0): nothing to record, not even the start.
        return time_steps, history
    history[:, 0, :] = initial_state

    # Work on a copy so the caller's tensor is never modified in place.
    current_state = initial_state.clone()

    for i in range(1, n_steps):
        derivatives = model_func(current_state, params)
        current_state += derivatives * dt
        # clamp_min works on any device and any number of state variables
        # (a hard-coded CPU tensor of two zeros does not).
        current_state = current_state.clamp_min(0.0)
        history[:, i, :] = current_state

    return time_steps, history

In [None]:
# --- 2. The Four Competing Model Functions ---
# Each model uses exactly four parameters: {alpha, beta, gamma, delta}

def lotka_volterra(state, params):
    """Model 1: Classic Lotka-Volterra dynamics.

        dN/dt = alpha * N - beta * N * P
        dP/dt = delta * N * P - gamma * P

    Args:
        state (torch.Tensor): [N (prey), P (predator)] in the last axis,
            shape (..., 2).
        params (torch.Tensor): [alpha, beta, gamma, delta] in the last axis,
            shape (..., 4) or (4,); broadcast against ``state``'s leading axes.

    Returns:
        torch.Tensor: [dN/dt, dP/dt] in the last axis, shape (..., 2).
    """
    # unbind(-1) splits the last axis for any number of leading batch axes,
    # whereas `.T` reverses every axis (and is deprecated for ndim != 2).
    N, P = state.unbind(-1)
    alpha, beta, gamma, delta = params.unbind(-1)

    dN_dt = alpha * N - beta * N * P
    dP_dt = delta * N * P - gamma * P
    return torch.stack([dN_dt, dP_dt], dim=-1)

def logistic_prey(state, params, cn_rate=0.5, capacity_scale=1000):
    """Model 2: Prey with logistic growth.

        K     = delta * capacity_scale          (prey carrying capacity)
        dN/dt = alpha * N * (1 - N / K) - beta * N * P
        dP/dt = cn_rate * beta * N * P - gamma * P

    Args:
        state (torch.Tensor): [N (prey), P (predator)] in the last axis,
            shape (..., 2).
        params (torch.Tensor): [alpha, beta, gamma, delta] in the last axis,
            shape (..., 4) or (4,); broadcast against ``state``'s leading axes.
        cn_rate (float): Fixed conversion efficiency of consumed prey into
            predators (default 0.5).
        capacity_scale (float): Factor turning ``delta`` into the carrying
            capacity K (default 1000).

    Returns:
        torch.Tensor: [dN/dt, dP/dt] in the last axis, shape (..., 2).
    """
    N, P = state.unbind(-1)
    alpha, beta, gamma, delta = params.unbind(-1)
    delta_lp = delta * capacity_scale  # prey carrying capacity K

    dN_dt = alpha * N * (1 - N / delta_lp) - beta * N * P
    dP_dt = cn_rate * beta * N * P - gamma * P
    return torch.stack([dN_dt, dP_dt], dim=-1)

def satiated_predator(state, params, cn_rate=0.5):
    """Model 3: Predator with satiation (Holling Type II).

        c     = beta * N / (1 + beta * delta * N)   (per-predator consumption)
        dN/dt = alpha * N - c * P
        dP/dt = cn_rate * c * P - gamma * P

    Here ``delta`` plays the role of the handling time in the Holling Type II
    functional response.

    Args:
        state (torch.Tensor): [N (prey), P (predator)] in the last axis,
            shape (..., 2).
        params (torch.Tensor): [alpha, beta, gamma, delta] in the last axis,
            shape (..., 4) or (4,); broadcast against ``state``'s leading axes.
        cn_rate (float): Fixed conversion efficiency (default 0.5).

    Returns:
        torch.Tensor: [dN/dt, dP/dt] in the last axis, shape (..., 2).
    """
    N, P = state.unbind(-1)
    alpha, beta, gamma, delta = params.unbind(-1)

    consumption = (beta * N) / (1 + beta * delta * N)
    dN_dt = alpha * N - consumption * P
    dP_dt = cn_rate * consumption * P - gamma * P
    return torch.stack([dN_dt, dP_dt], dim=-1)

def rosenzweig_macarthur(state, params, cn_rate=0.5, h_rate=0.1, capacity_scale=1000):
    """Model 4: Both logistic prey and satiated predator.

        K     = delta * capacity_scale               (prey carrying capacity)
        c     = beta * N / (1 + beta * h_rate * N)   (per-predator consumption)
        dN/dt = alpha * N * (1 - N / K) - c * P
        dP/dt = cn_rate * c * P - gamma * P

    Args:
        state (torch.Tensor): [N (prey), P (predator)] in the last axis,
            shape (..., 2).
        params (torch.Tensor): [alpha, beta, gamma, delta] in the last axis,
            shape (..., 4) or (4,); broadcast against ``state``'s leading axes.
        cn_rate (float): Fixed conversion efficiency (default 0.5).
        h_rate (float): Fixed handling time (default 0.1).
        capacity_scale (float): Factor turning ``delta`` into the carrying
            capacity K (default 1000).

    Returns:
        torch.Tensor: [dN/dt, dP/dt] in the last axis, shape (..., 2).
    """
    N, P = state.unbind(-1)
    alpha, beta, gamma, delta = params.unbind(-1)
    delta_rm = delta * capacity_scale  # prey carrying capacity K

    consumption = (beta * N) / (1 + beta * h_rate * N)
    dN_dt = alpha * N * (1 - N / delta_rm) - consumption * P
    dP_dt = cn_rate * consumption * P - gamma * P
    return torch.stack([dN_dt, dP_dt], dim=-1)

In [None]:
# --- 3. Simulation Setup ---
# Starting populations as a batch of one trajectory: [[N0 (prey), P0 (predator)]]
initial_state = torch.tensor([[20.0, 30.0]])

# Integration horizon and Euler step size
t_max, dt = 30, 0.01

# Candidate models keyed by display name. The names are also used to build the
# checkpoint paths "data/predator_prey/<name>_checkpoint.pt" when loading models.
models = dict([
    ("Lotka-Volterra", lotka_volterra),
    ("Logistic Prey", logistic_prey),
    ("Satiated Predator", satiated_predator),
    ("Rosenzweig-MacArthur", rosenzweig_macarthur),
])

## Lynx and Hare Population Dynamics

In [None]:
# Hare/lynx records, one whitespace-separated row per year: <year> <hare> <lynx>.
url = 'http://people.whitman.edu/~hundledr/courses/M250F03/LynxHare.txt'
df = pd.read_csv(url, sep=r'\s+', header=None, index_col=0)
df.index.name = 'Year'
df.columns = ['Hare', 'Lynx']

# Calendar years, used as the x-axis of the plot below ...
time = torch.tensor(df.index.values)
# ... then re-express the index as years since the first record (starting at 0),
# presumably so it shares the time axis of the simulations, which start at t = 0.
df.index = df.index - df.index.min()
index = torch.tensor(df.index.values)
hare = torch.tensor(df['Hare'].values)
lynx = torch.tensor(df['Lynx'].values)

fig, ax = plt.subplots(figsize=(15, 5))
for series, series_label in ((hare, 'Hare (Prey)'), (lynx, 'Lynx (Predator)')):
    sns.scatterplot(x=time, y=series, label=series_label, ax=ax)
    sns.lineplot(x=time, y=series, ax=ax)

ax.legend()
ax.set_xlabel('Time')
ax.set_ylabel('Population')
ax.set_title('Hare and Lynx Population Over Time')
fig.tight_layout()
plt.show()

## Create Training Data: Parameter Prior (sampled in log space)

In [None]:
class prior_distributions:
    """Independent Normal priors over the four model parameters.

    Samples are in log space: downstream cells apply ``torch.exp`` to a sample
    before passing it to the ODE models.
    """

    def __init__(self):
        Normal = torch.distributions.normal.Normal
        self.alpha = Normal(-0.125, 0.5)
        self.beta = Normal(-3, 0.5)
        self.gamma = Normal(-0.125, 0.5)
        self.delta = Normal(-3, 0.5)

    def sample(self, num_samples=1):
        """Draw ``num_samples`` parameter vectors.

        Each prior is sampled in turn (alpha, beta, gamma, delta).

        Returns:
            torch.Tensor: shape (num_samples, 4), columns ordered
            [alpha, beta, gamma, delta].
        """
        ordered_priors = (self.alpha, self.beta, self.gamma, self.delta)
        draws = [dist.sample((num_samples,)) for dist in ordered_priors]
        return torch.stack(draws, dim=-1)


prior = prior_distributions()

## Initialize COMPASS

In [None]:
# Model-comparison container: candidate models are registered with
# `mtf.add_model` and compared with `mtf.compare` in the cells below.
# NOTE(review): `path` is relative to the current working directory; what MTf
# reads/writes there is not visible from this notebook — confirm.
mtf = MTf(path="data/predator_prey")

In [None]:
# Load one pretrained score-based model per candidate and register it with the
# comparison under the same name used as its key in `models`.
for model_name in models:
    checkpoint_path = f"data/predator_prey/{model_name}_checkpoint.pt"
    sbim = SBIm.load(checkpoint_path, device="cuda")
    mtf.add_model(model_name, sbim)

In [None]:
# Observed data vector: interleave per year as [hare_0, lynx_0, hare_1, lynx_1, ...],
# keep the first 60 values (30 years), add a batch axis, and rescale by 1/100.
# NOTE(review): `data` is not what `mtf.compare` receives below (it gets the
# simulated `test_data`); confirm that is intended.
data = (
    torch.stack([hare, lynx], dim=-1)
    .flatten()[:60]
    .unsqueeze(0)
    .float()
) / 100

In [None]:
# Simulate a synthetic observation from a chosen model, using parameters drawn
# from the (log-space) prior, so the comparison can be run on data whose
# generating model is known.
# NOTE(review): the original cell used an undefined `model_func` (NameError on a
# fresh kernel); the generating model is now selected explicitly by name.
SEED = 0
true_model_name = "Lotka-Volterra"
model_func = models[true_model_name]

torch.manual_seed(SEED)  # make the drawn test parameters reproducible
test_params = prior.sample(1)
# `time` and `history` are also used by the final plotting cell.
time, history = solve_ode(model_func, initial_state, torch.exp(test_params), t_max, dt)

# Keep one sample per unit of time (t = 0, 1, 2, ...) by striding over the Euler
# steps rather than testing `time % 1 == 0` on floating-point times. Then flatten
# to [N0, P0, N1, P1, ...] and rescale by 1/100, as is done for `data`.
steps_per_unit = round(1 / dt)
assert abs(steps_per_unit * dt - 1) < 1e-9, "dt must evenly divide one time unit"
test_data = history[:, ::steps_per_unit].flatten(1) / 100

In [None]:
# Compare the registered models on the simulated `test_data`.
# NOTE(review): the observed hare/lynx vector `data` is prepared earlier but not
# used here; presumably pass x=data to compare on the real observations.
# The meaning of `timesteps`, `method="dpm"` and `order` is defined by COMPASS
# (presumably sampler settings) — confirm in its documentation.
mtf.compare(x=test_data, device="cuda", timesteps=1000, method="dpm", order=2)

In [None]:
# Plot the model-comparison results (presumably those computed by the preceding
# `mtf.compare` call — confirm).
mtf.plot_comparison()

In [None]:
# Labels for the attention plot: four parameter names, then one integer index
# per entry of the flattened observation vector `test_data`.
# NOTE(review): the models name their parameters alpha, beta, gamma, delta;
# confirm that "r", "a", "m", "K" are meant to map onto them in that order.
param_labels = ["r", "a", "m", "K"]
labels = param_labels + list(range(test_data.shape[1]))

In [None]:
# Plot attention using the parameter/observation labels built above.
# NOTE(review): which model(s) and which attention maps are shown is determined
# inside COMPASS and is not visible here — confirm.
mtf.plot_attention(labels=labels)

In [None]:
# MAP estimate for the Lotka-Volterra model, read from `mtf.stats` (presumably
# filled by `mtf.compare`). It is exponentiated before simulating, like prior
# samples.
# NOTE(review): what the [0, 0] index selects is not visible here — confirm.
a = mtf.stats["Lotka-Volterra"]["MAP"][0,0]

In [None]:
# Re-simulate a model with its MAP parameters for comparison with the test
# sample. `a` is exponentiated, just as prior samples are before simulating.
# Using a new name (rather than rebinding `a`) keeps this cell idempotent:
# re-running it does not re-wrap an already-converted tensor.
best_model_name = "Lotka-Volterra"  # must match the model whose MAP is in `a`
map_log_params = torch.as_tensor(a)
params = torch.exp(map_log_params)

best_model_fn = models[best_model_name]

# Reuse the simulation settings from the setup cell (initial_state, t_max, dt)
# instead of redefining the same values here.
time_best, history_best = solve_ode(best_model_fn, initial_state, params, t_max, dt)

In [None]:
# Overlay the simulated test sample (solid) with the MAP re-simulation (dashed).
# Set `show_observed = True` to also draw the recorded hare/lynx counts. `df`'s
# index counts years from the first record, so it starts at 0 like the
# simulation time axis.
show_observed = False

fig, ax = plt.subplots(figsize=(15, 5))

if show_observed:
    years = torch.tensor(df.index.values)
    sns.scatterplot(x=years, y=hare, label='Hare (Prey)', color='blue', ax=ax)
    sns.lineplot(x=years, y=hare, color='blue', alpha=0.5, ax=ax)
    sns.scatterplot(x=years, y=lynx, label='Lynx (Predator)', color='orange', ax=ax)
    sns.lineplot(x=years, y=lynx, color='orange', alpha=0.5, ax=ax)

# `time` / `history` are the time steps and trajectory of the simulated test sample.
ax.plot(time, history[0, :, 0], label='Hare (Sample)', color='blue')
ax.plot(time, history[0, :, 1], label='Lynx (Sample)', color='orange')

sns.lineplot(x=time_best, y=history_best[0, :, 0], label='Hare (Model)', color='blue', linestyle='--', ax=ax)
sns.lineplot(x=time_best, y=history_best[0, :, 1], label='Lynx (Model)', color='orange', linestyle='--', ax=ax)

ax.legend()
ax.set_xlabel('Time')
ax.set_ylabel('Population')
ax.set_title('Hare and Lynx Population Over Time')
fig.tight_layout()
plt.show()