# How do I train the model?
Training the model is straightforward because the code is modular. The training pipeline has four steps:
1. Create the SET datasets.
2. Instantiate the RNN model.
3. Call the training function.
4. Save the model and metrics history.

## Imports

```python
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from src.task import SETDataset
from src.model import EulerCTRNNCell
from src.training import create_train_state, train_model, serialize_parameters, deserialize_parameters

# Root PRNG key; each later cell splits it before use.
key = random.PRNGKey(0)
```

## Create SET datasets

```python
key, subkey = random.split(key)
set_dataset = SETDataset(subkey, 15, 5, 5, 32)
# Four TensorFlow datasets: training, testing, grok, and corrupt splits.
training_tf_dataset, testing_tf_dataset, grok_tf_dataset, corrupt_tf_dataset = set_dataset.tf_datasets()
```
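Every cell that needs randomness starts with `key, subkey = random.split(key)`: the `subkey` is consumed and the new `key` is carried forward, so no key is ever reused. A minimal standalone illustration using only `jax.random` (the variables `a`, `b`, and `replay` are illustrative):

```python
import jax.numpy as jnp
from jax import random

# Start from a fixed seed so results are reproducible.
key = random.PRNGKey(0)

# Split before each use: `subkey` is consumed, `key` is carried forward.
key, subkey = random.split(key)
a = random.normal(subkey, (3,))

key, subkey = random.split(key)
b = random.normal(subkey, (3,))

# Different subkeys give different draws...
print(bool(jnp.allclose(a, b)))  # False

# ...while replaying the same seed and split sequence reproduces `a` exactly.
_, replay = random.split(random.PRNGKey(0))
print(bool(jnp.allclose(a, random.normal(replay, (3,)))))  # True
```

Because JAX random functions are pure, this explicit key threading is what makes the dataset, initialization, and training cells reproducible from the single seed `0`.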

## Instantiate RNN model

```python
features = 100
alpha = jnp.float32(0.1)
noise = jnp.float32(0.1)

ctrnn = nn.RNN(EulerCTRNNCell(features=features, alpha=alpha, noise=noise))
```
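`EulerCTRNNCell` comes from `src.model`, whose source isn't shown here. For intuition: a continuous-time RNN $\tau \dot{h} = -h + \tanh(W_{\text{rec}} h + W_{\text{in}} x + b)$, discretized with a forward Euler step of size $\Delta t$ and $\alpha = \Delta t / \tau$, gives

$$h_{t+1} = (1 - \alpha)\,h_t + \alpha \tanh(W_{\text{rec}} h_t + W_{\text{in}} x_t + b) + \sigma\,\xi_t, \qquad \xi_t \sim \mathcal{N}(0, I).$$

This is one common form and only an assumption about what the cell computes; in particular, treating `noise` as the additive standard deviation $\sigma$ is a guess. `nn.RNN` then scans the cell over the time axis. The plain-JAX sketch below uses hypothetical names (`euler_ctrnn_step`, `run_ctrnn`, `W_rec`, `W_in`, `b`) and `jax.lax.scan` in place of `nn.RNN`:

```python
import jax
import jax.numpy as jnp
from jax import random


def euler_ctrnn_step(h, x, params, alpha, noise, key):
    """One forward-Euler CTRNN step (hypothetical sketch, not src.model)."""
    W_rec, W_in, b = params["W_rec"], params["W_in"], params["b"]
    drive = jnp.tanh(h @ W_rec.T + x @ W_in.T + b)
    # Leaky integration: keep (1 - alpha) of the old state, mix in alpha of the drive.
    h_new = (1.0 - alpha) * h + alpha * drive
    # Additive Gaussian noise with standard deviation `noise` (assumed form).
    return h_new + noise * random.normal(key, h.shape)


def run_ctrnn(params, xs, alpha, noise, key):
    """Scan the step over a (time, input_dim) sequence starting from h = 0."""
    h0 = jnp.zeros(params["W_rec"].shape[0])
    keys = random.split(key, xs.shape[0])  # one noise key per time step

    def step(h, inputs):
        x, k = inputs
        h = euler_ctrnn_step(h, x, params, alpha, noise, k)
        return h, h  # carry the state forward and also emit it as output

    _, hs = jax.lax.scan(step, h0, (xs, keys))
    return hs  # shape (time, features)


features, input_dim, steps = 100, 4, 20
k1, k2, k3, k4 = random.split(random.PRNGKey(0), 4)
params = {
    "W_rec": random.normal(k1, (features, features)) / jnp.sqrt(features),
    "W_in": random.normal(k2, (features, input_dim)),
    "b": jnp.zeros(features),
}
xs = random.normal(k3, (steps, input_dim))
hs = run_ctrnn(params, xs, alpha=0.1, noise=0.1, key=k4)
print(hs.shape)  # (20, 100)
```

A small `alpha` makes the state change slowly relative to the input, which is the usual reason for choosing values like `0.1`. Unlike this sketch, `nn.RNN` also handles a leading batch dimension, with inputs shaped `(batch, time, input_dim)` by default.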

## Train model

```python
lr = 0.001
epochs = 15
```

```python
key, subkey = random.split(key)
state = create_train_state(ctrnn, subkey, lr)
```

```python
key, subkey = random.split(key)
trained_state, metrics_history = train_model(
    subkey,
    state,
    training_tf_dataset,
    testing_tf_dataset,
    grok_tf_dataset,
    corrupt_tf_dataset,
    epochs,
)
```

```text
Metrics after epoch 1:
train_loss: 1.017378807067871
train_accuracy: 0.0
test_loss: 0.9785944223403931
test_accuracy: 0.0
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None
Metrics after epoch 2:
train_loss: 0.9955981969833374
train_accuracy: 0.0
test_loss: 1.079990029335022
test_accuracy: 0.0
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None
Metrics after epoch 3:
train_loss: 0.9964108467102051
train_accuracy: 0.0
test_loss: 0.9241047501564026
test_accuracy: 0.0
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None
Metrics after epoch 4:
train_loss: 1.0010318756103516
train_accuracy: 0.0
test_loss: 1.0197197198867798
test_accuracy: 0.0
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None
Metrics after epoch 5:
train_loss: 0.9834483861923218
train_accuracy: 0.0
test_loss: 0.9421354532241821
test_accuracy: 0.0
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None
```

Simple, right?