# How do I train the model?
Training the model is straightforward because the code is highly modular. Training follows this pipeline:
1. Create the SET datasets.
2. Instantiate the RNN model.
3. Call the training function.
4. Save the model and metrics history.

## Imports

In [9]:
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from src.task import SETDataset
from src.model import EulerCTRNNCell
from src.training import create_train_state, train_model, serialize_parameters, save_metrics_to_csv

key = random.PRNGKey(0)
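The cells below repeatedly call `random.split(key)` before each step that needs randomness. As a minimal sketch of why (the variable names `subkey_a`, `subkey_b`, `a`, `b` and `a_again` are illustrative only): JAX random functions are deterministic given a key, so each consumer should get its own fresh subkey.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Split once per consumer: keep `key` for future splits, hand out the subkey.
key, subkey_a = random.split(key)
key, subkey_b = random.split(key)

a = random.normal(subkey_a, (3,))
b = random.normal(subkey_b, (3,))

# Reusing a subkey reproduces exactly the same numbers.
a_again = random.normal(subkey_a, (3,))

print(bool(jnp.array_equal(a, a_again)))  # same subkey, same draw
print(bool(jnp.array_equal(a, b)))        # different subkeys, different draws
```

Reusing the same key for the dataset, the parameter initialisation and the training noise would correlate them, which is why every step in this notebook gets its own `subkey`.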

## Create SET datasets

In [10]:
key, subkey = random.split(key)
set_dataset = SETDataset(subkey, 15, 5, 5, 108)
#set_dataset.grok_SET(2)
#set_dataset.corrupt_SET(3)
set_dataset.print_training_testing()
training_tf_dataset, testing_tf_dataset, grok_tf_dataset, corrupt_tf_dataset = set_dataset.tf_datasets()


TRAINING DATA

Accepting Grid:
SET_combinations | Number of Trials | Status
GGG | 30 | 
GPR | 30 | 
GRP | 30 | 
PGR | 30 | 
PPP | 30 | 
PRG | 30 | 
RGP | 30 | 
RPG | 30 | 
RRR | 30 | 

Rejecting Grid:
SET_combinations | Number of Trials | Status
GGP | 15 | 
GGR | 15 | 
GPG | 15 | 
GPP | 15 | 
GRG | 15 | 
GRR | 15 | 
PGG | 15 | 
PGP | 15 | 
PPG | 15 | 
PPR | 15 | 
PRP | 15 | 
PRR | 15 | 
RGG | 15 | 
RGR | 15 | 
RPP | 15 | 
RPR | 15 | 
RRG | 15 | 
RRP | 15 | 

----------

TESTING DATA

Accepting Grid:
SET_combinations | Number of Trials | Status
GGG | 5 | 
GPR | 5 | 
GRP | 5 | 
PGR | 5 | 
PPP | 5 | 
PRG | 5 | 
RGP | 5 | 
RPG | 5 | 
RRR | 5 | 

Rejecting Grid:
SET_combinations | Number of Trials | Status
GGP | 5 | 
GGR | 5 | 
GPG | 5 | 
GPP | 5 | 
GRG | 5 | 
GRR | 5 | 
PGG | 5 | 
PGP | 5 | 
PPG | 5 | 
PPR | 5 | 
PRP | 5 | 
PRR | 5 | 
RGG | 5 | 
RGR | 5 | 
RPP | 5 | 
RPR | 5 | 
RRG | 5 | 
RRP | 5 | 

----------

GROK DATA

Accepting Grid:
SET_combinations | Number of Trials | Status

Reje
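The accepting grids printed above are consistent with the usual SET rule: a three-symbol combination is accepted when its symbols are all the same or all different. The sketch below assumes that rule (it is inferred from the printed grids, not taken from `SETDataset`) and recomputes the grids and trial totals from the per-combination counts shown in the output. The names `SYMBOLS` and `is_set` are hypothetical.

```python
from itertools import product

SYMBOLS = "GPR"  # the three symbols appearing in the printed grids


def is_set(combo):
    """Assumed SET rule: all symbols identical or all symbols distinct."""
    n_distinct = len(set(combo))
    return n_distinct == 1 or n_distinct == len(combo)


combos = ["".join(c) for c in product(SYMBOLS, repeat=3)]
accepting = [c for c in combos if is_set(c)]
rejecting = [c for c in combos if not is_set(c)]

# Per-combination trial counts copied from the printed output above.
train_total = len(accepting) * 30 + len(rejecting) * 15
test_total = len(accepting) * 5 + len(rejecting) * 5

print(accepting)
print(len(rejecting), train_total, test_total)
```

Under this reading the training set is balanced between accepting and rejecting trials (270 each), while the testing set has 45 accepting and 90 rejecting trials.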

## Instantiate RNN model

In [3]:
features = 100
alpha = jnp.float32(1.0)
noise = jnp.float32(0.1)

ctrnn = nn.RNN(EulerCTRNNCell(features=features, alpha=alpha, noise=noise))
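The internals of `EulerCTRNNCell` are not shown in this notebook. For orientation, here is a plain-JAX sketch of one common Euler-discretised CTRNN update, `h_next = (1 - alpha) * h + alpha * (W_rec tanh(h) + W_in x + b + noise * eps)`. This form, the `tanh` nonlinearity, and the names `euler_ctrnn_step`, `run_ctrnn`, `w_rec`, `w_in` and `b` are all assumptions for illustration and may differ from the repository's cell.

```python
import jax
import jax.numpy as jnp
from jax import random


def euler_ctrnn_step(params, h, x, alpha, noise, key):
    """One Euler step of a generic CTRNN (assumed form, not the repo's code)."""
    drive = jnp.tanh(h) @ params["w_rec"] + x @ params["w_in"] + params["b"]
    eps = noise * random.normal(key, h.shape)  # additive Gaussian noise
    return (1.0 - alpha) * h + alpha * (drive + eps)


def run_ctrnn(params, xs, alpha, noise, key):
    """Unroll the cell over the leading (time) axis of xs with lax.scan."""
    h0 = jnp.zeros_like(params["b"])
    step_keys = random.split(key, xs.shape[0])  # one noise key per time step

    def body(h, inputs):
        x, step_key = inputs
        h_next = euler_ctrnn_step(params, h, x, alpha, noise, step_key)
        return h_next, h_next

    _, hs = jax.lax.scan(body, h0, (xs, step_keys))
    return hs


# Small illustrative sizes; the notebook itself uses features = 100.
features, n_inputs, n_steps = 4, 3, 6
k_rec, k_in, k_x, k_noise = random.split(random.PRNGKey(0), 4)
params = {
    "w_rec": random.normal(k_rec, (features, features)) / jnp.sqrt(features),
    "w_in": random.normal(k_in, (n_inputs, features)),
    "b": jnp.zeros((features,)),
}
xs = random.normal(k_x, (n_steps, n_inputs))

hs = run_ctrnn(params, xs, alpha=1.0, noise=0.1, key=k_noise)
hs_frozen = run_ctrnn(params, xs, alpha=0.0, noise=0.1, key=k_noise)
print(hs.shape)
```

In this sketch `alpha` controls how much of the state is replaced at each step: with `alpha = 1.0` (the value used above) the new state depends on the old one only through the recurrent drive, and with `alpha = 0.0` the state never moves from its initial value.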

## Train model

In [4]:
lr = 0.0001
epochs = 1000

In [5]:
key, subkey = random.split(key)
state = create_train_state(ctrnn, subkey, lr)
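What `create_train_state` and `train_model` do internally is not shown here. Conceptually, a train state holds the current parameters, and each training step computes gradients of a loss and applies an update. The sketch below illustrates that idea with plain SGD on a toy linear regression. The loss, the optimiser, and the names `loss_fn` and `train_step` are assumptions for illustration and are not the repository's training code.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Mean squared error of a toy linear model."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y, lr):
    """One gradient step: returns updated parameters and the pre-update loss."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


# Toy data generated from y = 3x + 0.5.
x = jnp.linspace(-1.0, 1.0, 32).reshape(-1, 1)
y = 3.0 * x[:, 0] + 0.5
params = {"w": jnp.zeros((1,)), "b": jnp.float32(0.0)}

losses = []
for _ in range(200):
    params, loss = train_step(params, x, y, 0.1)
    losses.append(float(loss))

print(losses[0], losses[-1])
```

The per-step losses collected in `losses` play the same role as the `metrics_history` returned by `train_model`, in a much reduced form.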

In [6]:
key, subkey = random.split(key)
trained_state, metrics_history = train_model(
    subkey, 
    state, 
    training_tf_dataset, 
    testing_tf_dataset, 
    grok_tf_dataset, 
    corrupt_tf_dataset, 
    epochs,
)

Metrics after epoch 0:
train_loss: 1.0897804498672485
train_accuracy: 0.05740740895271301
test_loss: 1.1486366987228394
test_accuracy: 0.06666666269302368
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None


Metrics after epoch 50:
train_loss: 0.9749795794487
train_accuracy: 0.025925925001502037
test_loss: 1.0523685216903687
test_accuracy: 0.029629629105329514
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None


Metrics after epoch 100:
train_loss: 0.9664098024368286
train_accuracy: 0.024074073880910873
test_loss: 1.0609228610992432
test_accuracy: 0.04444444179534912
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None


Metrics after epoch 150:
train_loss: 0.9361637830734253
train_accuracy: 0.07222221791744232
test_loss: 1.0967377424240112
test_accuracy: 0.06666666269302368
grok_loss: None
grok_accuracy: None
corrupt_loss: None
corrupt_accuracy: None


Metrics after epoch 200:
train_loss: 0.8980261087417603

Simple, right?

## Save model and metrics

In [7]:
save_loc = '../results/script_examples/params.bin'
params = {'params': trained_state.params}
serialize_parameters(params, save_loc)

In [8]:
save_loc = '../results/script_examples/metrics_history.csv'
save_metrics_to_csv(metrics_history, save_loc)
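The layout of the CSV written by `save_metrics_to_csv` is not shown here. As a rough sketch of one plausible layout, the code below assumes `metrics_history` is a dictionary mapping metric names to per-epoch lists, and writes one row per epoch. The helper names `write_metrics_csv` and `read_metrics_csv`, the `epoch` column and the toy values are all hypothetical.

```python
import csv
import os
import tempfile

# Toy history with made-up values, for illustration only.
toy_history = {"train_loss": [1.0, 0.5, 0.25], "train_accuracy": [0.1, 0.5, 0.9]}


def write_metrics_csv(history, path):
    """Write one row per epoch, one column per metric (assumed layout)."""
    names = list(history)
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch"] + names)
        for epoch, values in enumerate(zip(*(history[n] for n in names))):
            writer.writerow([epoch, *values])


def read_metrics_csv(path):
    """Read the file back into a dict of per-epoch float lists."""
    with open(path, newline="") as f:
        rows = list(csv.DictReader(f))
    names = [n for n in rows[0] if n != "epoch"]
    return {n: [float(r[n]) for r in rows] for n in names}


with tempfile.TemporaryDirectory() as tmp_dir:
    csv_path = os.path.join(tmp_dir, "metrics_history.csv")
    write_metrics_csv(toy_history, csv_path)
    loaded = read_metrics_csv(csv_path)

print(loaded)
```

Metrics that are `None` in the training log above (the grok and corrupt entries) would need their own handling, for example writing an empty cell, which this sketch leaves out.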