# [Neural CDE](https://docs.kidger.site/diffrax/examples/neural_cde/)
Neural CDE は次のような式で表現されるモデルである。ここで $x$ は観測された時系列を補間して得られる連続な制御経路、$f_\theta$ はパラメータ $\theta$ を持つニューラルネットワークである。

$$y(t) = y(0) + \int_0^t f_\theta(y(s)) \frac{\mathrm{d}x}{\mathrm{d}s}(s) \mathrm{d}s$$

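元の diffrax の例では時計回りの渦と反時計回りの渦の分類を行うが、ここでは MNIST のストロークデータセット (`MNISTStrokeDataset`) を用いて、Neural CDE (`neural_model_name` により RNN も選択可能) による系列の分類を行う。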

In [1]:
import json
import time
import datetime
import gc
from typing import Sequence, Tuple, Union, Callable, Optional

import jax
import jax.random as jr
from jaxtyping import Array, Float, PRNGKeyArray
import optax  # https://github.com/deepmind/optax
import equinox as eqx
from memray import Tracker

from tools._dataset.datasets import MNISTStrokeDataset
from tools._dataset.dataloader import dataloader_ununiformed_sequence
from tools._model.neural_cde import NeuralCDE
from tools._model.discrete_cde import RNN
from tools._loss.cross_entropy import bce_loss, nll_loss
from tools.config import ExperimentConfig
from tools._train.update import Trainer, run_train
from tools._eval.evaluation import run_eval



# Train and Eval

In [2]:
def experiment(
    *,
    dataset_size_train: int,
    dataset_size_test: int,
    noise_ratio: float,
    input_format: str,
    interpolation: str,
    neural_model_name: str,
    out_size: int,
    hidden_size: int,
    width_size: int,
    depth: int,
    batch_size: int,
    lr: float,
    steps: int,
    seed: int,
    output_model_path: str,
) -> dict[str, float]:
    """Train a model on the MNISTStrokeDataset train split and evaluate it on the test split.

    The trained model is written to ``output_model_path`` with
    ``eqx.tree_serialise_leaves`` and then read back into a freshly shaped
    skeleton (``eqx.filter_eval_shape``) for evaluation. The evaluated weights
    are therefore exactly the ones saved to disk.

    Args:
        dataset_size_train: ``dataset_size`` for the training dataset.
        dataset_size_test: ``dataset_size`` for the test dataset.
        noise_ratio: Forwarded to both datasets.
        input_format: Forwarded to both datasets.
        interpolation: Forwarded to both datasets and to the model constructor.
        neural_model_name: ``'NeuralCDE'`` or ``'RNN'``.
        out_size: Number of model outputs. ``2`` selects ``bce_loss``;
            anything larger selects ``nll_loss``.
        hidden_size, width_size, depth: Model constructor arguments.
        batch_size: Training batch size passed to ``run_train``.
        lr: Adam learning rate.
        steps: Number of training steps passed to ``run_train``.
        seed: Seed for the single PRNG key that is split for data, model
            initialisation and train/eval runs.
        output_model_path: Where the trained model leaves are serialised.

    Returns:
        Dict with float values under ``'train_time_avg'``, ``'test_loss_avg'``,
        ``'test_acc_avg'`` and ``'test_time_avg'``.

    Raises:
        ValueError: If ``neural_model_name`` is not supported or ``out_size`` is
            less than 2. Both checks run before any data is loaded.
    """
    # Validate the configuration before the (potentially slow) dataset load.
    # Previously an unknown model name left `Model` unbound and surfaced only
    # later as an UnboundLocalError.
    supported_models = ('NeuralCDE', 'RNN')
    if neural_model_name not in supported_models:
        raise ValueError(
            f'Unknown `neural_model_name`: {neural_model_name!r}. '
            f'Expected one of {supported_models}.'
        )
    if out_size < 2:
        raise ValueError(f'The `out_size` must be greater than or equal to 2. But now {out_size}')

    key = jr.PRNGKey(seed)
    train_data_key, test_data_key, train_model_key, test_model_key, train_run_key, test_run_key = jr.split(key, 6)
    experiment_result: dict[str, float] = {}

    # Load the train dataset
    dataset = MNISTStrokeDataset(dataset_size=dataset_size_train, mode_train=True, input_format=input_format, noise_ratio=noise_ratio, interpolation=interpolation, key=train_data_key)
    ts, _, coeffs, labels, in_size = dataset.make_dataset()

    # Initialize the model (name already validated above)
    Model = NeuralCDE if neural_model_name == 'NeuralCDE' else RNN
    model = Model(in_size, out_size, hidden_size, width_size, depth, interpolation=interpolation, key=train_model_key)

    # Choose the loss function: binary for two outputs, NLL for more (out_size >= 2 checked above)
    loss_func = bce_loss if out_size == 2 else nll_loss

    # Training
    model, train_time_avg = run_train(steps, batch_size, (ts, *coeffs, labels), model, loss_func, optax.adam(lr), out_size, key=train_run_key)
    experiment_result['train_time_avg'] = float(train_time_avg)

    # Save the model
    eqx.tree_serialise_leaves(output_model_path, model)

    # Release training data and compiled-function caches before loading the test split
    del model, dataset, ts, coeffs, labels
    eqx.clear_caches()
    jax.clear_caches()
    gc.collect()

    # Load the test dataset (`in_size` is taken from the train split)
    dataset = MNISTStrokeDataset(dataset_size=dataset_size_test, mode_train=False, input_format=input_format, noise_ratio=noise_ratio, interpolation=interpolation, key=test_data_key)
    ts, _, coeffs, labels, _ = dataset.make_dataset()

    # Rebuild a shape-only skeleton of the model, then fill it with the saved leaves
    model = eqx.filter_eval_shape(Model, in_size, out_size, hidden_size, width_size, depth, interpolation=interpolation, key=test_model_key)
    model = eqx.tree_deserialise_leaves(output_model_path, model)

    # Evaluation
    test_loss_avg, test_acc_avg, test_time_avg = run_eval((ts, *coeffs, labels), model, loss_func, out_size, key=test_run_key)
    experiment_result['test_loss_avg'] = float(test_loss_avg)
    experiment_result['test_acc_avg'] = float(test_acc_avg)
    experiment_result['test_time_avg'] = float(test_time_avg)

    return experiment_result

In [5]:
def main() -> None:
    """Run one train/eval experiment and persist its configuration and results.

    Builds an ``ExperimentConfig`` with the overrides below and writes it as
    JSON to ``config.output_config_path`` *before* training. A run that
    crashes (e.g. a solver hitting ``max_steps``) therefore still leaves its
    configuration on disk. It then runs :func:`experiment` under a memray
    ``Tracker`` writing to ``config.output_memray_path``. Finally it writes the
    result dict, augmented with a few identifying fields, to
    ``config.output_result_path``.
    """
    # Drop compiled-function caches left over from earlier notebook cells.
    eqx.clear_caches()
    jax.clear_caches()
    gc.collect()

    config = ExperimentConfig()

    config.steps = 10000
    # NOTE(review): presumably -1 means "use the whole split" — confirm in MNISTStrokeDataset.
    config.dataset_size_train = -1
    config.dataset_size_test = -1

    config.noise_ratio = 0.75
    config.neural_model_name = 'NeuralCDE'

    # NOTE(review): the timestamp contains ':' which is not allowed in Windows file names.
    date = datetime.datetime.now().strftime('%Y年%m月%d日-%H:%M:%S')
    config.output_model_name = f'/{config.neural_model_name}-{date}'

    # `experiment` takes every non-output field plus the single output path it writes to.
    include_fields = {key for key in config.model_dump() if 'output_' not in key}
    include_fields.add('output_model_path')
    experiment_condition = config.model_dump(include=include_fields)

    # Save the configuration first so failed runs remain reproducible.
    # Explicit UTF-8 because the config contains the non-ASCII date string.
    with open(config.output_config_path, "w", encoding="utf-8") as o:
        print(config.model_dump_json(indent=4), file=o)

    with Tracker(config.output_memray_path):
        experiment_result = experiment(**experiment_condition)

    experiment_result['date'] = date
    experiment_result['neural_model_name'] = config.neural_model_name
    experiment_result['interpolation'] = config.interpolation
    experiment_result['noise_ratio'] = config.noise_ratio

    with open(config.output_result_path, mode="w", encoding="utf-8") as o:
        json.dump(experiment_result, o, ensure_ascii=False, indent=4)

In [6]:
# Run the experiment only when executed directly (a notebook kernel also has
# __name__ == "__main__"), not when this file is imported.
if __name__ == "__main__":
    main()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60000/60000 [1:07:11<00:00, 14.88it/s]


Step: 0, Loss: 10.188117980957031, Accuracy: 0.0625, Computation time: 4.353410005569458
Step: 1, Loss: 7.19356107711792, Accuracy: 0.0, Computation time: 4.817702770233154
Step: 2, Loss: 3.3303074836730957, Accuracy: 0.0625, Computation time: 4.23955512046814
Step: 3, Loss: 3.830678701400757, Accuracy: 0.21875, Computation time: 4.750100135803223
Step: 4, Loss: 3.833340883255005, Accuracy: 0.09375, Computation time: 0.566748857498169
Step: 5, Loss: 3.784780263900757, Accuracy: 0.09375, Computation time: 4.225117206573486
Step: 6, Loss: 3.485246181488037, Accuracy: 0.0625, Computation time: 4.745680809020996
Step: 7, Loss: 4.207230567932129, Accuracy: 0.0625, Computation time: 0.448012113571167
Step: 8, Loss: 3.5348446369171143, Accuracy: 0.15625, Computation time: 4.707114934921265
Step: 9, Loss: 2.786353826522827, Accuracy: 0.0625, Computation time: 0.44878268241882324
Step: 10, Loss: 2.901895761489868, Accuracy: 0.15625, Computation time: 4.876110076904297
Step: 11, Loss: 3.08238458

jax.pure_callback failed
Traceback (most recent call last):
  File "/Users/tomoki.fujihara/Desktop/test_diffrax/test_diffrax/.venv/lib/python3.11/site-packages/jax/_src/callback.py", line 77, in pure_callback_impl
    return callback(*args)
           ^^^^^^^^^^^^^^^
  File "/Users/tomoki.fujihara/Desktop/test_diffrax/test_diffrax/.venv/lib/python3.11/site-packages/jax/_src/callback.py", line 65, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tomoki.fujihara/Desktop/test_diffrax/test_diffrax/.venv/lib/python3.11/site-packages/equinox/_errors.py", line 70, in raises
    raise EqxRuntimeError(msgs[_index.item()])
equinox._errors.EqxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.


ValueError: not enough values to unpack (expected 2, got 1)