In [1]:
import os

#  Select the GPU *before* importing JAX or any package built on it (cardiax, deepx, helx).
#  JAX initialises its GPU backend lazily, and CUDA_VISIBLE_DEVICES is only honoured if it
#  is set before that happens — any device computation triggered at import time would
#  otherwise lock in the default device.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import jax
import jax.numpy as jnp
from jax.experimental import optimizers
import h5py
import numpy as onp
import time
import cardiax
import deepx
from deepx import optimise
import helx
from helx.optimise.optimisers import Optimiser
import json
import wandb
import pickle
import IPython
from IPython.display import display
from IPython.display import display_javascript
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import matplotlib.animation as animation
from matplotlib import rc, rcParams
import scipy.io
from functools import partial

#  Render matplotlib animations inline as interactive JS/HTML widgets.
rc('animation', html='jshtml')
rc('text', usetex=False)
#  `animation.embed_limit` is a size cap in MB for embedded animations; set it absurdly
#  high so long sequences are never truncated.
rcParams['animation.embed_limit'] = 2**128

## Setup parameters

In [2]:
#  Spatial grids: the solver domain, and the resolution passed to
#  `deepx.generate.sequence(..., reshape=...)` for the saved frames.
shape, reshape = (1200, 1200), (256, 256)
state = cardiax.solve.init(shape)

#  Integration window, in solver steps; one frame is saved every `step` steps
#  (500 steps * dt = 5 milliseconds).
start, stop, step = 0, 200_000, 500
dt, dx = 0.01, 0.01
paramset = cardiax.params.PARAMSET_3
#  Number of autoregressive model calls (scan length) used by `infer`.
n_refeed = 100


def load_model(url):
    """Restore a trained ResNet checkpoint.

    Args:
        url: checkpoint identifier accepted by `deepx.optimise.TrainState.restore`
            (e.g. "p3aetobr/9336").

    Returns:
        A `(model, hparams, params)` tuple:
        - model: a `deepx.resnet.ResNet` rebuilt from the stored hyper-parameters
          (`hidden_channels`, depth) with 1 as its second constructor argument,
        - hparams: the stored hyper-parameters,
        - params: the parameters unpacked from the stored optimiser state.
    """
    checkpoint = deepx.optimise.TrainState.restore(url)
    hparams = checkpoint.hparams
    net = deepx.resnet.ResNet(hparams.hidden_channels, 1, hparams.depth)
    #  An Adam optimiser is built only for its `params_fn`, which unpacks the saved
    #  optimiser state; the learning rate plays no role in that extraction.
    params = optimizers.adam(0.001).params_fn(checkpoint.opt_state)
    return net, hparams, params


# def infer(xs):
#     model, hparams, params = load_model("p3aetobr/9336")
# #     model, hparams, params = load_model("p3aetobr/9996")
# #     model, hparams, params = load_model("p3aetobr/10326")
# #     model, hparams, params = load_model("p3aetobr/10976")
#     start = time.time()
#     ys_hat = deepx.optimise.infer(model, n_refeed, params, xs)
#     print("Solved forward propagation to {}ms in: {}s".format(n_refeed * 5, time.time() - start))
#     if ys_hat.shape[0] == 1:
#         ys_hat = jnp.swapaxes(ys_hat, -3, -4)[None]
#     return ys_hat


def infer(xs, a_min=None, a_max=None):
    """Autoregressively roll the trained model forward for `n_refeed` steps.

    Loads the checkpoint "p3aetobr/9336" on every call and reads the global
    `n_refeed` for the number of model calls.

    Args:
        xs: input window. In this notebook callers pass `read(...)[:, :, :2]`.
            NOTE(review): the layout is inferred from `read`, which yields
            (1, 1, T, C, H, W). Confirm the layout expected by deepx.
        a_min, a_max: optional bounds used to clip every prediction before it is
            fed back as input. When both are None, the stock
            `deepx.optimise.infer` is used. Otherwise a local rollout pads the
            spatial domain, predicts, crops, clips and refeeds.

    Returns:
        The stacked predictions. If their leading axis has size 1, axes -3 and
        -4 are swapped and a new leading axis is prepended.
    """
    #  Border (pixels) added around the spatial domain with edge replication
    #  before calling the model; it is cropped off the prediction afterwards.
    pad = 30

    def _pad_spatial(x):
        #  Pad only the last two (spatial) axes. Batch/time/channel axes must
        #  not be padded, or the model sees extra fabricated frames/channels.
        widths = [(0, 0)] * (x.ndim - 2) + [(pad, pad), (pad, pad)]
        return jnp.pad(x, widths, mode='edge')

    def _crop_spatial(y):
        #  Undo `_pad_spatial` on the prediction (assumes the model preserves the
        #  spatial size of its input — TODO confirm for deepx.resnet.ResNet).
        return y[..., pad:-pad, pad:-pad]

    #  NOTE(review): `params` is mapped over axis 0, so it must already carry a
    #  leading device axis (e.g. replicated parameters). Confirm against the
    #  checkpoint format.
    @partial(
        jax.pmap,
        in_axes=(None, None, 0, 0),
        static_broadcasted_argnums=(0, 1),
        axis_name="device",
    )
    def _infer(model, n_refeed, params, xs):
        def body_fun(x, _):
            y_hat = model.apply(params, _pad_spatial(x))
            y_hat = _crop_spatial(y_hat)
            #  At least one bound is set on this code path.
            y_hat = jnp.clip(y_hat, a_min, a_max)
            #  Refeed into the *unpadded* input so the scan carry keeps its shape.
            x = deepx.optimise.refeed(x, y_hat)  #  add the new pred to the inputs
            return x, y_hat

        _, ys_hat = jax.lax.scan(body_fun, xs, xs=jnp.arange(n_refeed))
        ys_hat = jnp.swapaxes(jnp.squeeze(ys_hat), 0, 1)
        return ys_hat

    model, hparams, params = load_model("p3aetobr/9336")
    #  Alternative checkpoints: "p3aetobr/9996", "p3aetobr/10326", "p3aetobr/10976".
    start = time.time()
    if (a_min is None) and (a_max is None):
        ys_hat = deepx.optimise.infer(model, n_refeed, params, xs)
    else:
        ys_hat = _infer(model, n_refeed, params, xs)
    #  JAX dispatches work asynchronously. Wait for the result before reading
    #  the clock, otherwise only the dispatch time is reported.
    if hasattr(ys_hat, "block_until_ready"):
        ys_hat.block_until_ready()
    print("Solved forward propagation to {}ms in: {}s".format(n_refeed * 5, time.time() - start))
    if ys_hat.shape[0] == 1:
        ys_hat = jnp.swapaxes(ys_hat, -3, -4)[None]
    return ys_hat


def evaluate(xs, ys):
    """Run `infer` on `xs` (no clipping) and score it against the reference `ys`.

    Returns:
        `(ys_hat, loss)`, where `loss` is the squared error averaged over
        axes 0, 1, 4 and 5. For 6-D inputs this leaves one value per entry of
        axes 2 and 3.

    Raises:
        AssertionError: if the prediction and the reference differ in shape.
    """
    ys_hat = infer(xs)
    assert ys_hat.shape == ys.shape, "ys_hat and ys are of different shapes {} and {}".format(ys_hat.shape, ys.shape)
    squared_error = (ys_hat - ys) ** 2
    return ys_hat, jnp.mean(squared_error, axis=(0, 1, 4, 5))


def read(filename, t, normalise=False):
    """Load a simulated sequence and append the diffusivity map as an extra channel.

    Args:
        filename: HDF5 file with "states" and "diffusivity" datasets.
        t: index of the first frame to keep (frames `t:` are returned).
        normalise: if True, the diffusivity is multiplied by 500 before stacking.

    Returns:
        A numpy array with two leading singleton axes. The diffusivity is
        repeated once per frame and concatenated after the state channels on
        axis -3. Assumes "states" is (T, C, H, W) and "diffusivity" is (H, W),
        giving (1, 1, T, C + 1, H, W) — TODO confirm against
        deepx.generate.sequence.
    """
    with h5py.File(filename, "r") as sequence:
        states = onp.array(sequence["states"][t:])
        d = onp.array(sequence["diffusivity"])
    if normalise:
        d = d * 500
    #  one diffusivity channel per saved frame
    d = onp.tile(d[None, None], (len(states), 1, 1, 1))
    return onp.concatenate([states, d], axis=-3)[None, None]
    
    
def save_mat(xs, filename):
    """Write channels 0-3 of `xs` (axis 1) to a MATLAB .mat file beside `filename`.

    The extension of `filename` is replaced by ".mat". Channels 0, 1, 2 and 3
    are stored under the keys "v", "w", "u" and "d" respectively.

    Returns:
        Whatever `scipy.io.savemat` returns.
    """
    stem, _ = os.path.splitext(filename)
    channels = dict(zip(("v", "w", "u", "d"), (xs[:, k] for k in range(4))))
    return scipy.io.savemat(stem + ".mat", channels)


def animate_state(a, d, filename=None):
    """Animate a sequence of states together with the diffusivity map `d`.

    Args:
        a: array whose squeezed leading axis indexes frames. Each frame is
            unpacked positionally into `cardiax.solve.State`.
        d: diffusivity map forwarded to `cardiax.plot.animate_state`.
        filename: if given, the animation is also saved to "data/<filename>".

    Returns:
        The animation object, which is also displayed inline.
    """
    frames = [cardiax.solve.State(*frame) for frame in a.squeeze()]
    anim = cardiax.plot.animate_state(frames, d, cmap="Blues", vmin=0, vmax=1)
    display(anim)
    if filename is None:
        return anim
    anim.save("data/{}".format(filename))
    return anim

# 1. Heterogeneous tissue

## 1.1 Linear regime

In [34]:
#  setup
#  Output path template: the first placeholder is the RNG seed, the second the method
#  ("fd" for the finite-difference reference, "nn" for the neural-network prediction).
filename = "rebuttal/heterogeneous_linear_{}-{}.hdf5"

### a) Finite difference

In [15]:
#  Finite difference simulation
#  Four stimulus protocols whose first argument is 0, 40_000, 80_000 and 120_000
#  (presumably the start step; with dt = 0.01 that is one stimulus every 400 ms).
p1 = cardiax.stimulus.Protocol(40_000 * 0, 2, 1e9)
p2 = cardiax.stimulus.Protocol(40_000 * 1, 2, 1e9)
p3 = cardiax.stimulus.Protocol(40_000 * 2, 2, 1e9)
p4 = cardiax.stimulus.Protocol(40_000 * 3, 2, 1e9)
for seed in range(10):
    rng = jax.random.PRNGKey(seed)
    #  One random angle in [0, 180). Index the (1,)-shaped sample explicitly: calling
    #  int() on a non-scalar array is deprecated in NumPy and disallowed by newer JAX.
    angle = int(jax.random.randint(rng, (1,), 0, 180)[0])
    s1 = cardiax.stimulus.triangular(shape, cardiax.stimulus.Direction.NORTH, angle, 0.2, 20, p1)
    s2 = cardiax.stimulus.triangular(shape, cardiax.stimulus.Direction.NORTH, angle, 0.2, 20, p2)
    s3 = cardiax.stimulus.triangular(shape, cardiax.stimulus.Direction.NORTH, angle, 0.2, 20, p3)
    s4 = cardiax.stimulus.triangular(shape, cardiax.stimulus.Direction.NORTH, angle, 0.2, 20, p4)
    stimuli = [s1, s2, s3, s4]
    #  NOTE(review): the same key `rng` also seeds the diffusivity map. This is kept
    #  as-is so the generated datasets stay reproducible.
    diffusivity = deepx.generate.random_diffusivity(rng, shape)
    deepx.generate.sequence(
        start, stop, step, dt, dx, paramset, diffusivity, stimuli, filename.format(seed, "fd"), reshape=reshape, use_memory=True, plot_while=False
    )

### b) Neural network

In [47]:
#  Neural-network rollout per seed. Frames are read from the one at step 40_000 * 3
#  (the 4th protocol above): multiplying by dt gives ms, and dividing by 5 gives the
#  frame index, since one frame is saved every 5 ms.
for seed in range(10):
    xs = read(filename.format(seed, "fd"), int(40_000 * 3 * dt / 5), normalise=True)
    #  Export the finite-difference reference (with the x500 diffusivity) to .mat.
    save_mat(xs.squeeze(), filename.format(seed, "fd"))
    #  Given read's (1, 1, T, C, H, W) layout, axis 2 is time: the first two frames
    #  form the network's input window.
    ys_hat = infer(xs[:, :, :2]).squeeze()
    save_mat(ys_hat.squeeze(), filename.format(seed, "nn"))
    #  xs.squeeze()[0, -1]: frame 0, last channel (the diffusivity appended by `read`).
    animate_state(ys_hat.squeeze(), xs.squeeze()[0, -1])

In [None]:
#  setup
filename = "rebuttal/heterogeneous_linear_{}-{}.hdf5"

#  Compare, for a single seed, the finite-difference reference with the network rollout.
seed = 0
xs = read(filename.format(seed, "fd"), int(40_000 * 3 * dt / 5), normalise=True)
ys_hat = infer(xs[:, :, :2]).squeeze()
#  Reference: all frames, state channels 0-2 only. Channel 3 is the diffusivity stacked
#  by `read`, so it must not be unpacked into `State`. This matches the other
#  `animate_state` calls; the previous `[0:3]` sliced three frames instead.
animate_state(xs.squeeze()[:, 0:3], xs.squeeze()[0, -1])
animate_state(ys_hat.squeeze(), xs.squeeze()[0, -1])

## 1.2 Spiral regime

In [3]:
#  setup
#  Output path template: the first placeholder is the RNG seed, the second the method
#  (e.g. "fd" for the finite-difference reference).
filename = "rebuttal/heterogeneous_spiral_{}-{}.hdf5"

### a) Finite difference

In [68]:
#  Finite difference simulation
#  Overrides the global `stop` (200_000 in the setup cell); later cells that do not
#  reset it inherit 100_000.
stop = 100_000
p1 = cardiax.stimulus.Protocol(40_000 * 0, 2, 1e9)
p2 = cardiax.stimulus.Protocol(40_000 * 0 + 40_000, 2, 1e9)
for seed in range(10):
    rng = jax.random.PRNGKey(seed)
    #  One random angle in [0, 180). Index the (1,)-shaped sample explicitly: calling
    #  int() on a non-scalar array is deprecated in NumPy and disallowed by newer JAX.
    angle = int(jax.random.randint(rng, (1,), 0, 180)[0])
    s1 = cardiax.stimulus.triangular(shape, cardiax.stimulus.Direction.NORTH, angle, 0.2, 20, p1)
    #  The second stimulus is rotated by 90 degrees relative to the first.
    s2 = cardiax.stimulus.triangular(shape, cardiax.stimulus.Direction.NORTH, angle + 90, 0.8, 20, p2)
    stimuli = [s1, s2]
    #  NOTE(review): the same key `rng` also seeds the diffusivity map; kept as-is
    #  for reproducibility.
    diffusivity = deepx.generate.random_diffusivity(rng, shape)
    deepx.generate.sequence(
        start, stop, step, dt, dx, paramset, diffusivity, stimuli, filename.format(seed, "fd"), reshape=reshape, use_memory=True, plot_while=False
    )
    #  Re-read the whole sequence (un-normalised) and animate state channels 0-2.
    xs = read(filename.format(seed, "fd"), 0)
    animate_state(xs.squeeze()[:, 0:3], xs.squeeze()[0, -1])

In [62]:
#  Finite difference simulation
#  NOTE(review): this variant writes to the same "fd" files as the two-stimulus cell
#  above. With `stop = 200_000` commented out, it uses whatever global `stop` is set
#  (100_000 after the cell above).
# stop = 200_000
p1 = cardiax.stimulus.Protocol(40_000 * 0, 2, 1e9)
p2 = cardiax.stimulus.Protocol(40_000 * 1, 2, 1e9)
p3 = cardiax.stimulus.Protocol(40_000 * 2, 2, 1e9)
p4 = cardiax.stimulus.Protocol(40_000 * 2 + 35_000, 2, 1e9)
for seed in range(10):
    rng = jax.random.PRNGKey(seed)
    #  One random angle in [0, 180). Index the (1,)-shaped sample explicitly: calling
    #  int() on a non-scalar array is deprecated in NumPy and disallowed by newer JAX.
    angle = int(jax.random.randint(rng, (1,), 0, 180)[0])
    s1 = cardiax.stimulus.triangular(shape, cardiax.stimulus.Direction.NORTH, angle, 0.2, 20, p1)
    s2 = cardiax.stimulus.triangular(shape, cardiax.stimulus.Direction.NORTH, angle, 0.2, 20, p2)
    s3 = cardiax.stimulus.triangular(shape, cardiax.stimulus.Direction.NORTH, angle, 0.2, 20, p3)
    #  The last stimulus is rotated by 90 degrees relative to the first three.
    s4 = cardiax.stimulus.triangular(shape, cardiax.stimulus.Direction.NORTH, angle + 90, 0.2, 20, p4)
    stimuli = [s1, s2, s3, s4]
    diffusivity = deepx.generate.random_diffusivity(rng, shape)
    deepx.generate.sequence(
        start, stop, step, dt, dx, paramset, diffusivity, stimuli, filename.format(seed, "fd"), reshape=reshape, use_memory=True, plot_while=False
    )
    xs = read(filename.format(seed, "fd"), 0)
    animate_state(xs.squeeze()[:, 0:3], xs.squeeze()[0, -1])
    #  Only seed 0 is generated: the loop stops after the first iteration.
    break

### b) Neural network

**TODO:** neural-network inference for the spiral regime has not been added yet.

# 2. Homogeneous tissue

# Pre-rebuttal data

## Generate linear wave

In [None]:
stop = 60_000
#  A single linear stimulus travelling from the SOUTH side.
p1 = cardiax.stimulus.Protocol(0, 2, 1e9)
s1 = [cardiax.stimulus.linear(shape, cardiax.stimulus.Direction.SOUTH, 0.2, 20.0, p1)]
stimuli = s1

filename = "data/linear.hdf5"

#  Cached: the simulation runs only if the output file does not exist yet.
#  NOTE(review): `diffusivity` is not defined in this cell. On Restart & Run All it is
#  whatever the last heterogeneous-tissue cell left behind (a random map). Confirm which
#  diffusivity data/linear.hdf5 was generated with and define it explicitly here.
if not os.path.isfile(filename):
    deepx.generate.sequence(
        start, stop, step, dt, dx, paramset, diffusivity, stimuli, filename, reshape=reshape, use_memory=True, plot_while=True
    )

## Generate spiral

In [None]:
stop = 200_000
#  SOUTH linear stimulus at 0, followed by a WEST linear stimulus (Protocol first arg 30000).
p1 = cardiax.stimulus.Protocol(0, 2, 1e9)
s1 = [cardiax.stimulus.linear(shape, cardiax.stimulus.Direction.SOUTH, 0.2, 20.0, p1)]
p2 = cardiax.stimulus.Protocol(30000, 2, 1e9)
s2 = [cardiax.stimulus.linear(shape, cardiax.stimulus.Direction.WEST, 0.2, 20.0, p2)]
stimuli = s1 + s2

filename = "data/spiral.hdf5"
#  NOTE(review): presumably frame indices marking regimes in data/spiral.hdf5. They are
#  not used in this cell — confirm where they are consumed.
linear_wave_start = 1
spiral_start = 62
broken_spiral_start = 82

#  Cached: the simulation runs only if the output file does not exist yet.
#  NOTE(review): `diffusivity` is not defined in this cell and leaks from an earlier
#  cell; define it explicitly to make a fresh-kernel run reproducible.
if not os.path.isfile(filename):
    deepx.generate.sequence(
        start, stop, step, dt, dx, paramset, diffusivity, stimuli, filename, reshape=reshape, use_memory=True, plot_while=True
    )

## Generate linear wave in homogeneous conductivity

In [None]:
stop = 100_000
#  A single linear stimulus from the SOUTH side.
p1 = cardiax.stimulus.Protocol(0, 2, 1e9)
s1 = [cardiax.stimulus.linear(shape, cardiax.stimulus.Direction.SOUTH, 0.2, 20.0, p1)]
stimuli = s1

#  Uniform diffusivity of 0.001 over the whole grid.
diffusivity = jnp.ones(shape) * 0.001
filename = "data/linear_homogeneous.hdf5"

#  Cached: the simulation runs only if the output file does not exist yet.
if not os.path.isfile(filename):
    deepx.generate.sequence(
        start, stop, step, dt, dx, paramset, diffusivity, stimuli, filename, reshape=reshape, use_memory=True, plot_while=True
    )

## Generate spiral wave in homogeneous conductivity

In [None]:
stop = 100_000
#  SOUTH linear stimulus at 0, then a WEST linear stimulus (Protocol first arg 35000).
p1 = cardiax.stimulus.Protocol(0, 2, 1e9)
s1 = [cardiax.stimulus.linear(shape, cardiax.stimulus.Direction.SOUTH, 0.2, 20.0, p1)]
p2 = cardiax.stimulus.Protocol(35000, 2, 1e9)
s2 = [cardiax.stimulus.linear(shape, cardiax.stimulus.Direction.WEST, 0.6, 20.0, p2)]
stimuli = s1 + s2

#  Uniform diffusivity of 0.001 over the whole grid.
diffusivity = jnp.ones(shape) * 0.001
filename = "data/spiral_homogeneous.hdf5"

#  Cached: the simulation runs only if the output file does not exist yet.
if not os.path.isfile(filename):
    deepx.generate.sequence(
        start, stop, step, dt, dx, paramset, diffusivity, stimuli, filename, reshape=reshape, use_memory=True, plot_while=True
    )

In [None]:
%%time
stop = 100_000
p1 = cardiax.stimulus.Protocol(0, 2, 1e9)
s1 = [cardiax.stimulus.linear(shape, cardiax.stimulus.Direction.SOUTH, 0.2, 20.0, p1)]
p2 = cardiax.stimulus.Protocol(35000, 2, 1e9)
s2 = [cardiax.stimulus.linear(shape, cardiax.stimulus.Direction.WEST, 0.6, 20.0, p2)]
stimuli = s1 + s2

diffusivity = jnp.ones(shape) * 0.001
filename = "dev/null"

if not os.path.isfile(filename):
    deepx.generate.sequence(
        start, stop, step, dt, dx, paramset, diffusivity, stimuli, filename, reshape=reshape, use_memory=True, plot_while=False
    )