# [Neural CDE](https://docs.kidger.site/diffrax/examples/neural_cde/)
Neural CDE は次のような式で表現されるモデルである。

$$y(t) = y(0) + \int_0^t f_\theta(y(s)) \frac{\mathrm{d}x}{\mathrm{d}s}(s) \mathrm{d}s$$

元の diffrax の例では Neural CDE を用いて時計回りの渦と反時計回りの渦の分類を行うが、このノートブックでは同じモデルを MNIST のストローク系列データ（`MNISTStrokeDataset`）の分類に適用する。

In [5]:
import json
import time
import datetime
import gc
from typing import Sequence, Tuple, Union, Callable, Optional

import jax
import jax.random as jr
from jaxtyping import Array, Float, PRNGKeyArray
import optax  # https://github.com/deepmind/optax
import equinox as eqx
from memray import Tracker

from tools._dataset.datasets import MNISTStrokeDataset
from tools._dataset.dataloader import dataloader_ununiformed_sequence
from tools._model.neural_cde import NeuralCDE
from tools._model.discrete_cde import RNN
from tools._loss.cross_entropy import bce_loss, nll_loss
from tools.config import ExperimentConfig
from tools._train.update import Trainer, run_train
from tools._eval.evaluation import run_eval

# Train and Eval

In [6]:
def experiment(
    *,
    dataset_size_train: int,
    dataset_size_test: int,
    noise_ratio: float,
    input_format: str,
    interpolation: str,
    neural_model_name: str,
    out_size: int,
    hidden_size: int,
    width_size: int,
    depth: int,
    batch_size: int,
    lr: float,
    steps: int,
    seed: int,
    output_model_path: str,
) -> dict[str, float]:
    """Train a sequence classifier, save it, reload it and evaluate it.

    The function builds a training set with ``MNISTStrokeDataset``, trains the
    selected model with ``run_train`` (Adam optimiser), serialises the trained
    model to ``output_model_path``, frees the training objects, then rebuilds
    the model skeleton, loads the saved weights and runs ``run_eval`` on a
    test set.

    Args:
        dataset_size_train: Forwarded as ``dataset_size`` for the train split.
        dataset_size_test: Forwarded as ``dataset_size`` for the test split.
        noise_ratio: Forwarded to ``MNISTStrokeDataset``.
        input_format: Forwarded to ``MNISTStrokeDataset``.
        interpolation: Forwarded to both the dataset and the model.
        neural_model_name: ``'NeuralCDE'`` or ``'RNN'``.
        out_size: Number of model outputs. ``2`` selects ``bce_loss``; any
            larger value selects ``nll_loss``.
        hidden_size: Forwarded to the model constructor.
        width_size: Forwarded to the model constructor.
        depth: Forwarded to the model constructor.
        batch_size: Forwarded to ``run_train``.
        lr: Learning rate passed to ``optax.adam``.
        steps: Number of training steps forwarded to ``run_train``.
        seed: Seed for the root PRNG key; all sub-keys are split from it.
        output_model_path: File the trained model leaves are written to and
            read back from.

    Returns:
        A dict with the float entries ``train_time_avg``, ``test_loss_avg``,
        ``test_acc_avg`` and ``test_time_avg``.

    Raises:
        ValueError: If ``neural_model_name`` is not a supported model name or
            if ``out_size`` is smaller than 2.
    """
    # Validate the cheap-to-check options up front so a typo fails immediately
    # instead of after the dataset has been loaded or the model trained.
    if neural_model_name not in ('NeuralCDE', 'RNN'):
        raise ValueError(
            "The `neural_model_name` must be 'NeuralCDE' or 'RNN'. "
            f'But now {neural_model_name!r}'
        )
    if out_size < 2:
        raise ValueError(f'The `out_size` must be greater than or equal to 2. But now {out_size}')

    key = jr.PRNGKey(seed)
    # One independent key per random consumer (data, model init, run loop) x (train, test).
    train_data_key, test_data_key, train_model_key, test_model_key, train_run_key, test_run_key = jr.split(key, 6)
    experiment_result = {}

    # Load the train dataset
    dataset = MNISTStrokeDataset(dataset_size=dataset_size_train, mode_train=True, input_format=input_format, noise_ratio=noise_ratio, interpolation=interpolation, key=train_data_key)
    ts, _, coeffs, labels, in_size = dataset.make_dataset()

    # Initialize the model (the name was validated above)
    Model = NeuralCDE if neural_model_name == 'NeuralCDE' else RNN
    model = Model(in_size, out_size, hidden_size, width_size, depth, interpolation=interpolation, key=train_model_key)

    # Choose the loss function: two outputs -> bce_loss, more -> nll_loss
    loss_func = bce_loss if out_size == 2 else nll_loss

    # Training
    model, train_time_avg = run_train(steps, batch_size, (ts, *coeffs, labels), model, loss_func, optax.adam(lr), out_size, key=train_run_key)
    experiment_result['train_time_avg'] = float(train_time_avg)

    # Save the model
    eqx.tree_serialise_leaves(output_model_path, model)

    # Drop every reference to the training objects and clear the Equinox/JAX
    # caches before the test data is loaded.
    del model, dataset, ts, coeffs, labels
    eqx.clear_caches()
    jax.clear_caches()
    gc.collect()

    # Load the test dataset
    dataset = MNISTStrokeDataset(dataset_size=dataset_size_test, mode_train=False, input_format=input_format, noise_ratio=noise_ratio, interpolation=interpolation, key=test_data_key)
    ts, _, coeffs, labels, _ = dataset.make_dataset()

    # Load the trained model: build the structure with `filter_eval_shape`,
    # then fill it with the leaves saved above.
    model = eqx.filter_eval_shape(Model, in_size, out_size, hidden_size, width_size, depth, interpolation=interpolation, key=test_model_key)
    model = eqx.tree_deserialise_leaves(output_model_path, model)

    # Evaluation
    test_loss_avg, test_acc_avg, test_time_avg = run_eval((ts, *coeffs, labels), model, loss_func, out_size, key=test_run_key)
    experiment_result['test_loss_avg'] = float(test_loss_avg)
    experiment_result['test_acc_avg'] = float(test_acc_avg)
    experiment_result['test_time_avg'] = float(test_time_avg)

    return experiment_result

In [7]:
def main() -> None:
    """Run one configured experiment and write its config and results to disk.

    Steps:
      1. Clear the Equinox/JAX caches and run the garbage collector.
      2. Build an ``ExperimentConfig`` and override a few fields.
      3. Call ``experiment`` under a memray ``Tracker`` writing to
         ``config.output_memray_path``.
      4. Dump the full config as JSON to ``config.output_config_path``.
      5. Dump the result dict, plus a few identifying fields, as JSON to
         ``config.output_result_path``.
    """
    # NOTE: the original carried a commented-out `%%memray_flamegraph` cell
    # magic here; profiling is done with the `Tracker` context manager below.
    eqx.clear_caches()
    jax.clear_caches()
    gc.collect()

    config = ExperimentConfig()

    config.steps = 10000
    # NOTE(review): -1 presumably means "use the whole split" - confirm
    # against MNISTStrokeDataset.
    config.dataset_size_train = -1
    config.dataset_size_test = -1

    config.noise_ratio = 0.5
    config.neural_model_name = 'NeuralCDE'

    date = datetime.datetime.now().strftime('%Y年%m月%d日-%H:%M:%S')
    config.output_model_name = f'/{config.neural_model_name}-{date}'

    # `experiment` takes every config field except the `output_*` ones, plus
    # the model path it serialises to.
    include_fields = {name for name in config.model_dump() if 'output_' not in name}
    include_fields.add('output_model_path')
    experiment_condition = config.model_dump(include=include_fields)
    with Tracker(config.output_memray_path):
        experiment_result = experiment(**experiment_condition)

    # The dumped config contains the non-ASCII date string, so fix the
    # encoding explicitly instead of relying on the platform default
    # (consistent with the results file below).
    with open(config.output_config_path, "w", encoding="utf-8") as o:
        print(config.model_dump_json(indent=4), file=o)

    experiment_result['date'] = date
    experiment_result['neural_model_name'] = config.neural_model_name
    experiment_result['interpolation'] = config.interpolation
    experiment_result['noise_ratio'] = config.noise_ratio

    with open(config.output_result_path, mode="w", encoding="utf-8") as o:
        json.dump(experiment_result, o, ensure_ascii=False, indent=4)

In [8]:
# Run the experiment defined by `main`; the captured cell output follows.
main()

  0%|                                                              | 48/60000 [00:10<2:24:18,  6.92it/s]2024-05-06 12:37:05.711014: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 15m49.192389s

********************************
[Compiling module jit__shuffle] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
100%|███████████████████████████████████████████████████████████| 60000/60000 [3:07:10<00:00,  5.34it/s]


Step: 0, Loss: 8.166421890258789, Accuracy: 0.09375, Computation time: 4.532424211502075
Step: 1, Loss: 6.043856620788574, Accuracy: 0.125, Computation time: 4.851641893386841
Step: 2, Loss: 2.9999325275421143, Accuracy: 0.125, Computation time: 4.959268093109131
Step: 3, Loss: 3.5801517963409424, Accuracy: 0.15625, Computation time: 4.858425140380859
Step: 4, Loss: 3.2347958087921143, Accuracy: 0.25, Computation time: 0.8257510662078857
Step: 5, Loss: 3.60296630859375, Accuracy: 0.125, Computation time: 5.070477724075317
Step: 6, Loss: 2.9986965656280518, Accuracy: 0.125, Computation time: 4.378866910934448
Step: 7, Loss: 3.673431634902954, Accuracy: 0.0625, Computation time: 0.5008561611175537
Step: 8, Loss: 3.1192572116851807, Accuracy: 0.125, Computation time: 4.413270711898804
Step: 9, Loss: 3.1442630290985107, Accuracy: 0.03125, Computation time: 0.5458641052246094
Step: 10, Loss: 3.2677061557769775, Accuracy: 0.09375, Computation time: 5.133293151855469
Step: 11, Loss: 2.4693684

100%|█████████████████████████████████████████████████████████████| 10000/10000 [05:59<00:00, 27.85it/s]


  0%|          | 0/10000 [00:00<?, ?it/s]

Test loss: 0.5493088960647583, Test Accuracy: 0.8355000019073486, Computation time: 0.008337065577507019
