# How do I train the model?
Training the model is straightforward because the code is highly modular. Training follows this pipeline:
1. Create the SET datasets.
2. Instantiate the RNN model.
3. Call the training function.
4. Save the model and metrics history.

## Imports

```python
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from src.task import SETDataset
from src.model import EulerCTRNNCell
from src.training import create_train_state, train_model

key = random.PRNGKey(0)
```

## Create SET datasets

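```python
key, subkey = random.split(key)
# 30 and 5 appear to match the per-combination trial counts in the grids
# printed below; see src.task for the full signature (including 108).
set_dataset = SETDataset(subkey, 30, 5, 108)
# Optional grok and corruption splits. Both are disabled in this run,
# so the grok and corrupt metrics are reported as None during training.
#set_dataset.grok_SET(2)
#set_dataset.corrupt_SET(3)
set_dataset.print_training_testing()
training_tf_dataset, testing_tf_dataset, grok_tf_dataset, corrupt_tf_dataset = set_dataset.tf_datasets()
```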

```
TRAINING DATA

Accepting Grid:
SET_combinations | Number of Trials | Status
GGG | 60 | 
GPR | 60 | 
GRP | 60 | 
PGR | 60 | 
PPP | 60 | 
PRG | 60 | 
RGP | 60 | 
RPG | 60 | 
RRR | 60 | 

Rejecting Grid:
SET_combinations | Number of Trials | Status
GGP | 30 | 
GGR | 30 | 
GPG | 30 | 
GPP | 30 | 
GRG | 30 | 
GRR | 30 | 
PGG | 30 | 
PGP | 30 | 
PPG | 30 | 
PPR | 30 | 
PRP | 30 | 
PRR | 30 | 
RGG | 30 | 
RGR | 30 | 
RPP | 30 | 
RPR | 30 | 
RRG | 30 | 
RRP | 30 | 

----------

TESTING DATA

Accepting Grid:
SET_combinations | Number of Trials | Status
GGG | 5 | 
GPR | 5 | 
GRP | 5 | 
PGR | 5 | 
PPP | 5 | 
PRG | 5 | 
RGP | 5 | 
RPG | 5 | 
RRR | 5 | 

Rejecting Grid:
SET_combinations | Number of Trials | Status
GGP | 5 | 
GGR | 5 | 
GPG | 5 | 
GPP | 5 | 
GRG | 5 | 
GRR | 5 | 
PGG | 5 | 
PGP | 5 | 
PPG | 5 | 
PPR | 5 | 
PRP | 5 | 
PRR | 5 | 
RGG | 5 | 
RGR | 5 | 
RPP | 5 | 
RPR | 5 | 
RRG | 5 | 
RRP | 5 | 

----------

GROK DATA

Accepting Grid:
SET_combinations | Number of Trials | Status

Reje
```
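The accepting and rejecting grids above are consistent with the usual SET rule for a single attribute: a triple is accepted when all three letters are the same or all three differ. The plain-Python sketch below re-derives the printed grids under that assumption, using the letters G, P and R exactly as they appear in the output; it is an illustration of the rule, not the code in `src.task`. It also totals the trials read off the grids.

```python
from itertools import product

COLORS = "GPR"  # letters as printed in the grids above

def is_set(combo: str) -> bool:
    """SET rule for one attribute: all the same or all different."""
    return len(set(combo)) in (1, 3)

# All 27 ordered triples, in the same lexicographic order as the printout.
combos = ["".join(c) for c in product(COLORS, repeat=3)]
accepting = [c for c in combos if is_set(c)]
rejecting = [c for c in combos if not is_set(c)]

# Trial counts per combination, read off the TRAINING and TESTING grids.
TRAIN_ACCEPT, TRAIN_REJECT, TEST_EACH = 60, 30, 5

train_accept_total = len(accepting) * TRAIN_ACCEPT
train_reject_total = len(rejecting) * TRAIN_REJECT
test_total = len(combos) * TEST_EACH

print(len(accepting), len(rejecting))          # 9 18
print(train_accept_total, train_reject_total)  # 540 540
print(test_total)                              # 135
```

Because 9 × 60 = 18 × 30 = 540, the doubled trial count for accepting combinations leaves the two classes the same size in the training set; whether that balancing is the intent of `SETDataset` is an inference from the printout.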

## Instantiate RNN model

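```python
features = 100            # number of hidden units in the recurrent cell
alpha = jnp.float32(1.0)  # integration factor passed to EulerCTRNNCell (see src.model)
noise = jnp.float32(0.1)  # noise level passed to EulerCTRNNCell (see src.model)

# nn.RNN scans the cell over the time axis of each input sequence.
ctrnn = nn.RNN(EulerCTRNNCell(features=features, alpha=alpha, noise=noise))
```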
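`EulerCTRNNCell` is defined in `src.model`, and its exact update is not shown in this notebook. As a point of reference, a continuous-time RNN `tau * dh/dt = -h + W_rec tanh(h) + W_in x + b`, discretized with forward Euler and `alpha = dt / tau`, gives `h[t+1] = (1 - alpha) * h[t] + alpha * (W_rec tanh(h[t]) + W_in x[t] + b)`. The JAX sketch below implements that textbook update and unrolls it with `jax.lax.scan`. The parameter names (`W_rec`, `W_in`, `b`), the `tanh` nonlinearity, and adding Gaussian noise directly to the hidden state are all assumptions, not necessarily what `EulerCTRNNCell` does.

```python
import jax
import jax.numpy as jnp
from jax import random

def init_params(key, n_in, n_hidden):
    """Random weights for the illustrative CTRNN (hypothetical layout)."""
    k_rec, k_in = random.split(key)
    return {
        "W_rec": random.normal(k_rec, (n_hidden, n_hidden)) / jnp.sqrt(n_hidden),
        "W_in": random.normal(k_in, (n_hidden, n_in)) / jnp.sqrt(n_in),
        "b": jnp.zeros(n_hidden),
    }

def euler_ctrnn_step(params, h, x, key, alpha, noise):
    """One forward-Euler step of tau*dh/dt = -h + W_rec tanh(h) + W_in x + b,
    with alpha = dt / tau and additive Gaussian noise (assumed form)."""
    drive = params["W_rec"] @ jnp.tanh(h) + params["W_in"] @ x + params["b"]
    h_next = (1.0 - alpha) * h + alpha * drive
    return h_next + noise * random.normal(key, h.shape)

def run_ctrnn(params, xs, key, alpha, noise):
    """Unroll over a (T, n_in) input sequence; returns (T, n_hidden) states."""
    h0 = jnp.zeros(params["b"].shape)
    step_keys = random.split(key, xs.shape[0])  # one noise key per time step

    def body(h, inputs):
        x, k = inputs
        h = euler_ctrnn_step(params, h, x, k, alpha, noise)
        return h, h  # carry the state and also record it

    _, hs = jax.lax.scan(body, h0, (xs, step_keys))
    return hs

key = random.PRNGKey(0)
k_params, k_noise = random.split(key)
params = init_params(k_params, n_in=3, n_hidden=100)
xs = jnp.ones((10, 3))  # toy input: 10 time steps, 3 input channels
hs = run_ctrnn(params, xs, k_noise, alpha=jnp.float32(1.0), noise=jnp.float32(0.1))
print(hs.shape)  # (10, 100)
```

In this form, `alpha = 1.0` (the value set above) cancels the leak term, so the update becomes `h[t+1] = W_rec tanh(h[t]) + W_in x[t] + b`, an ordinary discrete-time RNN plus noise. Whether the project's cell behaves the same way depends on its definition in `src.model`.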

## Train model

```python
lr = 0.001
epochs = 250
```

```python
key, subkey = random.split(key)
state = create_train_state(ctrnn, subkey, lr)
```

```python
key, subkey = random.split(key)
model_params, metrics_history = train_model(
    subkey,
    state,
    training_tf_dataset,
    testing_tf_dataset,
    grok_tf_dataset,
    corrupt_tf_dataset,
    epochs,
)
```

```
Metrics after epoch 50:
train_loss: 0.9260943531990051
train_accuracy: 0.059259265661239624
test_loss: 1.0503238439559937
test_accuracy: 0.014814814552664757
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None


Metrics after epoch 100:
train_loss: 0.006110505200922489
train_accuracy: 0.9981481432914734
test_loss: 0.03263946622610092
test_accuracy: 0.9481481313705444
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None


Metrics after epoch 150:
train_loss: 0.0028154884930700064
train_accuracy: 0.9990741014480591
test_loss: 0.002217269502580166
test_accuracy: 1.0
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None


Metrics after epoch 200:
train_loss: 0.0015888254856690764
train_accuracy: 1.0
test_loss: 0.002964671002700925
test_accuracy: 0.9925925731658936
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None


Metrics after epoch 250:
train_loss: 0.0012388726463541389
train_accuracy:
```

Simple, right?
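The logged accuracies can be sanity-checked against the dataset sizes. Assuming accuracy is the fraction of correctly classified trials, the printed grids give 9 × 60 + 18 × 30 = 1080 training trials and 27 × 5 = 135 testing trials, so each reported value should be a whole number of trials divided by one of those totals. The plain-Python check below uses only the numbers copied from the epoch 50 to 200 logs (the epoch-250 accuracy is cut off) and does not call any project code.

```python
TRAIN_TRIALS = 9 * 60 + 18 * 30  # accepting + rejecting grids (TRAINING DATA)
TEST_TRIALS = 27 * 5             # 27 combinations x 5 trials (TESTING DATA)

# Accuracies copied from the epoch 50, 100, 150 and 200 logs.
train_acc = [0.059259265661239624, 0.9981481432914734, 0.9990741014480591, 1.0]
test_acc = [0.014814814552664757, 0.9481481313705444, 1.0, 0.9925925731658936]

def as_trial_count(acc: float, total: int, tol: float = 1e-3):
    """Return acc * total as an int if it is (nearly) whole, else None."""
    count = acc * total
    nearest = round(count)
    return nearest if abs(count - nearest) < tol else None

print([as_trial_count(a, TRAIN_TRIALS) for a in train_acc])  # [64, 1078, 1079, 1080]
print([as_trial_count(a, TEST_TRIALS) for a in test_acc])    # [2, 128, 135, 134]
```

Read this way, the model misclassifies one training trial at epoch 150 and one test trial at epoch 200, which puts the small dips in the log in perspective.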

## Save model and metrics

```python
save_loc = '../results/script_examples/params.bin'
model_params.serialize(save_loc)
```

```python
save_loc = '../results/script_examples/metrics_history.csv'
metrics_history.save_to_csv(save_loc)
```