In [32]:
import os
from pathlib import Path

import equinox as eqx
import jax.numpy as jnp
import jax.random as jrandom
from matplotlib import pyplot as plt

from neuralodes.data.dataloader import get_dataloaders
from neuralodes.model.oderesnet.evaluation import evaluate
from neuralodes.model.oderesnet.odenet import ODENet

In [33]:
# Work from the repository root so relative paths such as models/... resolve,
# whether the kernel was started inside notebooks/ or at the root itself.
cwd = Path.cwd()
if cwd.name == "notebooks":
    os.chdir(cwd.parent)

In [37]:
# Build an ODENet with the same hyper-parameters to serve as the pytree template,
# then fill its leaves with the trained weights stored on disk.
# NOTE(review): the checkpoint filename hard-codes "Tsit5"; keep it in sync with
# `solver_name` if either is changed.
key = jrandom.PRNGKey(0)
solver_name = "tsit5"
checkpoint_path = Path("models", "oderesnet", "odenet_fashionmnist_Tsit5_64.eqx")
odenet_template = ODENet(key, solver_name)
odenet = eqx.tree_deserialise_leaves(checkpoint_path, odenet_template)

In [35]:
# Only the test split is used below; the training loader is kept under a
# throwaway name. (256 is presumably the batch size — confirm in get_dataloaders.)
_train_dataloader, test_dataloader = get_dataloaders("fashionmnist", 256)

In [36]:
# Evaluate the loaded checkpoint on the test split and display the result
# (a 2-tuple that looks like (loss, accuracy) — confirm in evaluate()).
test_metrics = evaluate(odenet, test_dataloader)
test_metrics

(0.24162933342158793, 0.921875)

In [38]:
# NOTE(review): this duplicates the evaluation cell above. The two recorded
# outputs differ because the notebook ran out of order: In[36] evaluated a model
# loaded before the checkpoint cell (In[37]) was re-run. Under Restart & Run All
# both cells should agree, assuming evaluate() is deterministic.
rerun_metrics = evaluate(odenet, test_dataloader)
rerun_metrics

(0.2494609896093607, 0.91884765625)

In [28]:
# Display the structure of the model's second layer (the ODE block).
ode_block = odenet.layers[1]
ode_block

ODEBlock(
  odefunc=ODEFunc(
    norm1=GroupNorm(
      groups=32,
      channels=64,
      eps=1e-05,
      channelwise_affine=True,
      weight=f32[64],
      bias=f32[64]
    ),
    relu=<wrapped function relu>,
    conv1=ConcatConv2D(
      layer=Conv2d(
        num_spatial_dims=2,
        weight=f32[64,65,3,3],
        bias=f32[64,1,1],
        in_channels=65,
        out_channels=64,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=((1, 1), (1, 1)),
        dilation=(1, 1),
        groups=1,
        use_bias=True
      )
    ),
    norm2=GroupNorm(
      groups=32,
      channels=64,
      eps=1e-05,
      channelwise_affine=True,
      weight=f32[64],
      bias=f32[64]
    ),
    conv2=ConcatConv2D(
      layer=Conv2d(
        num_spatial_dims=2,
        weight=f32[64,65,3,3],
        bias=f32[64,1,1],
        in_channels=65,
        out_channels=64,
        kernel_size=(3, 3),
        stride=(1, 1),
        padding=((1, 1), (1, 1)),
        dilation=(1, 1),


In [None]:
# Inspect the vector field f(t, x) of the ODE block at odenet.layers[1] (its repr
# is shown above) by evaluating it at two integration times.
ode_func = odenet.layers[1].odefunc
t_values = [0.0, 1.0]

# Probe input for the vector field. Its channel count must equal the block's
# feature width (the GroupNorm `channels` field, 64 in the repr above; the convs
# take 65 = 64 + presumably one concatenated time channel). The spatial size is
# arbitrary because both convolutions use stride 1 with padding 1.
# NOTE(review): a random probe is off-distribution. For a meaningful analysis,
# replace it with real activations entering layers[1], e.g. a test image passed
# through odenet.layers[0] (verify that layer's call signature first).
PROBE_HW = (7, 7)
# fold_in gives an independent stream; `key` itself already seeded model init.
x = jrandom.normal(jrandom.fold_in(key, 1), (ode_func.norm1.channels, *PROBE_HW))

# Define a function to visualize the sensitivity
def visualize_sensitivity(sensitivity, t_values):
    num_t_values = len(t_values)
    num_channels = sensitivity.shape[1]
    
    fig, axes = plt.subplots(num_channels, num_t_values, figsize=(4 * num_t_values, 4 * num_channels))
    
    for i, t in enumerate(t_values):
        for j in range(num_channels):
            ax = axes[j, i]
            ax.imshow(sensitivity[i, j], cmap='viridis', aspect='auto')
            ax.axis('off')
            if j == 0:
                ax.set_title(f't = {t}')

    plt.tight_layout()
    plt.show()

# Perturb the probe input slightly and measure how much the vector field
# f(t, x) moves in response, separately at each evaluation time.
EPSILON = 1e-2
# fold_in derives a fresh stream: `key` already seeded the model initialisation,
# and reusing one JAX key for unrelated draws correlates them.
perturbation_key = jrandom.fold_in(key, 2)
perturbation = EPSILON * jrandom.normal(perturbation_key, x.shape)
x_perturbed = x + perturbation

# Stack whole outputs (not a single channel) so `sensitivity` is (T, C, H, W),
# the layout visualize_sensitivity indexes as [time, channel].
# NOTE(review): assumes ode_func follows the diffrax f(t, y, args) convention and
# returns one (C, H, W) array for a (C, H, W) input — confirm in ODEFunc.
outputs = jnp.stack([ode_func(t, x, None) for t in t_values])
outputs_perturbed = jnp.stack([ode_func(t, x_perturbed, None) for t in t_values])

# Absolute change of the vector field caused by the perturbation.
sensitivity = jnp.abs(outputs_perturbed - outputs)

# Plot a readable subset of channels; one 4-inch row per channel otherwise
# produces an extremely tall figure for a 64-channel block.
N_CHANNELS_TO_PLOT = 8
visualize_sensitivity(sensitivity[:, :N_CHANNELS_TO_PLOT], t_values)