# [Neural CDE](https://docs.kidger.site/diffrax/examples/neural_cde/)
Neural CDE は次のような式で表現されるモデルである。

$$y(t) = y(0) + \int_0^t f_\theta(y(s)) \frac{\mathrm{d}x}{\mathrm{d}s}(s) \mathrm{d}s$$

ここでは、Neural CDE を用いて、MNIST の手書き数字をストローク（点列）として表現した系列データを 10 クラスに分類する。

In [4]:
import math
import time
from typing import Sequence, Tuple, Union, Callable, Optional

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.lax as lax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
from jaxtyping import Array, Float, PRNGKeyArray
import matplotlib
import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax
from memory_profiler import memory_usage

from tools._dataset.datasets import MNISTStrokeDataset
from tools._dataset.dataloader import dataloader_ununiformed_sequence
from tools._model.neural_cde import NeuralCDE
from tools._loss.cross_entropy import nll_loss

In [5]:
%load_ext memray

# Train and Eval

In [6]:
def main(
    dataset_size=128,
    noise_ratio=0.1,
    input_format='point_sequence',
    interpolation='cubic',
    out_size=10,
    batch_size=32,
    lr=3e-3,
    steps=50,
    hidden_size=8,
    width_size=128,
    depth=3,
    seed=5678,
    test_dataset_size=256,
):
    """Train a NeuralCDE on MNISTStrokeDataset, then print test loss and accuracy.

    Per-step loss, accuracy and wall-clock time are printed during training;
    the mean loss and accuracy over the test samples are printed at the end.
    Nothing is returned.

    Args:
        dataset_size: ``dataset_size`` passed to the training MNISTStrokeDataset.
        noise_ratio: Forwarded to MNISTStrokeDataset (training and test).
        input_format: Forwarded to MNISTStrokeDataset (training and test).
        interpolation: Forwarded to MNISTStrokeDataset and to NeuralCDE.
        out_size: Output size given to NeuralCDE; also passed to nll_loss.
        batch_size: Batch size handed to dataloader_ununiformed_sequence.
        lr: Learning rate for the Adam optimiser.
        steps: Upper bound on the number of training steps. The loop zips
            ``range(steps)`` with the dataloader, so it ends as soon as either
            one is exhausted.
        hidden_size: Forwarded to NeuralCDE.
        width_size: Forwarded to NeuralCDE.
        depth: Forwarded to NeuralCDE.
        seed: Seed for the root PRNG key from which all other keys are split.
        test_dataset_size: ``dataset_size`` passed to the test
            MNISTStrokeDataset. Defaults to 256, the previously hard-coded
            value.
    """
    key = jr.PRNGKey(seed)
    train_data_key, test_data_key, model_key, loader_key, test_loss_key = jr.split(key, 5)

    dataset = MNISTStrokeDataset(dataset_size=dataset_size, mode_train=True, input_format=input_format, noise_ratio=noise_ratio, interpolation=interpolation, key=train_data_key)
    ts, _, coeffs, labels, in_size = dataset.make_dataset()

    model = NeuralCDE(in_size, out_size, hidden_size, width_size, depth, interpolation=interpolation, key=model_key)

    # nll_loss returns (loss, aux); has_aux=True differentiates the loss only.
    grad_loss = eqx.filter_value_and_grad(nll_loss, has_aux=True)
    # Defined before make_step so the closure's dependency is visible up front.
    optim = optax.adam(lr)

    @eqx.filter_jit
    def make_step(model: eqx.Module, data: Array, opt_state: Tuple, *args, key:PRNGKeyArray) -> Tuple[Float, Float, eqx.Module, Tuple]:
        """Run one optimisation step; return (loss, accuracy, model, opt_state)."""
        ts, *coeffs, labels = data
        (xe, acc), grads = grad_loss(model, (ts, coeffs), labels, *args, key=key)
        updates, opt_state = optim.update(grads, opt_state, model)
        model = eqx.apply_updates(model, updates)
        return xe, acc, model, opt_state

    # Switch the model to training mode (inference flag off) for the loop below.
    model = eqx.nn.inference_mode(model, value=False)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    loader_key, _loader_key = jr.split(loader_key, 2)
    for step, data in zip(
        range(steps), dataloader_ununiformed_sequence(((ts, *coeffs), labels), batch_size, key=_loader_key)
    ):
        # Fresh sub-key for every step.
        loader_key, _loader_key = jr.split(loader_key, 2)
        start = time.perf_counter()
        xe, acc, model, opt_state = make_step(model, data, opt_state, out_size, key=_loader_key)
        # JAX dispatches work asynchronously: make_step can return before the
        # computation has finished. Wait for the results so the interval
        # measures the actual step rather than just the dispatch.
        xe, acc = jax.block_until_ready((xe, acc))
        end = time.perf_counter()
        print(
            f"Step: {step}, Loss: {xe}, Accuracy: {acc}, Computation time: "
            f"{end - start}"
        )

    # Evaluation: inference mode on, one sample at a time.
    model = eqx.nn.inference_mode(model)
    dataset = MNISTStrokeDataset(dataset_size=test_dataset_size, mode_train=False, input_format=input_format, noise_ratio=noise_ratio, interpolation=interpolation, key=test_data_key)
    ts, _, coeffs, labels, _ = dataset.make_dataset()
    list_xe = []
    list_acc = []
    for data in zip(ts, *coeffs, labels):
        # Add a leading axis of size 1 so each sample is passed as a batch of one.
        data = [jnp.expand_dims(x, axis=0) for x in data]
        _ts, *_coeffs, _label = data
        xe, acc = nll_loss(model, (_ts, _coeffs), _label, out_size, key=test_loss_key)
        list_xe.append(lax.stop_gradient(xe))
        list_acc.append(lax.stop_gradient(acc))
    print(f"Test loss: {jnp.mean(jnp.array(list_xe))}, Test Accuracy: {jnp.mean(jnp.array(list_acc))}")

In [7]:
#%%memray_flamegraph
# The commented-out line above is the memray flamegraph cell magic (the
# extension is loaded earlier with `%load_ext memray`). To profile this cell,
# uncomment it; a cell magic must remain the first line of the cell.
# Clear Equinox's and JAX's caches before the run — presumably so that each
# execution of this cell starts from a fresh compilation state (TODO confirm).
eqx.clear_caches()
jax.clear_caches()
# Run training and evaluation with the default hyperparameters.
main()

100%|█████████████████████████████████████████| 128/128 [00:08<00:00, 15.68it/s]


Step: 0, Loss: 3.394502639770508, Accuracy: 0.03125, Computation time: 3.0246801376342773
Step: 1, Loss: 2.6009562015533447, Accuracy: 0.03125, Computation time: 3.0226399898529053


100%|█████████████████████████████████████████| 256/256 [00:05<00:00, 48.49it/s]


Test loss: 3.051206111907959, Test Accuracy: 0.125
