# How do I train the model?
Training the model is straightforward because the code is modular. Training follows this pipeline:
1. Create the SET datasets.
2. Instantiate the RNN model.
3. Call the training function.
4. Save the model and metrics history.

## Imports

```python
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from src.task import SETDataset
from src.model import EulerCTRNNCell
from src.training import create_train_state, train_model, serialize_parameters, save_metrics_to_csv

# Fixed seed; a fresh subkey is split off this key for every random step below
key = random.PRNGKey(0)
```

## Create SET datasets

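```python
key, subkey = random.split(key)
# Judging by the printed grids below, the positional arguments set per-combination
# trial counts (32 accepting / 15 rejecting for training, 5 each for testing).
# Check the SETDataset signature for the exact order.
set_dataset = SETDataset(subkey, 15, 5, 5, 32)
# Hold 2 accepting combinations out of training (marked "Grokked" in the testing grid)
set_dataset.grok_SET(2)
# Corrupt 3 combinations (marked "Corrupted"); the grids suggest their labels are flipped
set_dataset.corrupt_SET(3)
set_dataset.print_training_testing()
training_tf_dataset, testing_tf_dataset, grok_tf_dataset, corrupt_tf_dataset = set_dataset.tf_datasets()
```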


```text
TRAINING DATA

Accepting Grid:
SET_combinations | Number of Trials | Status
GGG | 32 |
GGR | 32 | Corrupted
GPR | 32 |
PGR | 32 |
PPP | 32 |
RGG | 32 | Corrupted
RGP | 32 |
RPG | 32 |

Rejecting Grid:
SET_combinations | Number of Trials | Status
GGP | 15 |
GPG | 15 |
GPP | 15 |
GRG | 15 |
GRR | 15 |
PGG | 15 |
PGP | 15 |
PPG | 15 |
PPR | 15 |
PRG | 15 | Corrupted
PRP | 15 |
PRR | 15 |
RGR | 15 |
RPP | 15 |
RPR | 15 |
RRG | 15 |
RRP | 15 |

----------

TESTING DATA

Accepting Grid:
SET_combinations | Number of Trials | Status
GGG | 5 |
GGR | 5 | Corrupted
GPR | 5 |
GRP | 5 | Grokked
PGR | 5 |
PPP | 5 |
RGG | 5 | Corrupted
RGP | 5 |
RPG | 5 |
RRR | 5 | Grokked

Rejecting Grid:
SET_combinations | Number of Trials | Status
GGP | 5 |
GPG | 5 |
GPP | 5 |
GRG | 5 |
GRR | 5 |
PGG | 5 |
PGP | 5 |
PPG | 5 |
PPR | 5 |
PRG | 5 | Corrupted
PRP | 5 |
PRR | 5 |
RGR | 5 |
RPP | 5 |
RPR | 5 |
RRG | 5 |
RRP | 5 |

----------

GROK DATA

Accepting Grid:
SET_co
```
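Reading the grids above, the accepting combinations are exactly the triples whose three letters are all the same (GGG, PPP, RRR) or all different (the permutations of G, P and R). The "Corrupted" entries (GGR, RGG, PRG) sit on the wrong side, and the "Grokked" entries (GRP, RRR) are missing from training. The sketch below reproduces that reading in plain Python. It is our reconstruction from the printed output, not the `SETDataset` implementation: `is_set` and `build_grids` are hypothetical names, and the label flip for corrupted combinations is inferred from the grids.

```python
from itertools import product

LETTERS = "GPR"  # the three symbols that appear in the grids


def is_set(combo: str) -> bool:
    """SET rule as read from the grids: all three symbols equal or all different."""
    return len(set(combo)) in (1, 3)


def build_grids(corrupted=(), grokked=()):
    """Hypothetical helper: split all 27 combinations into training grids.

    Corrupted combinations get their label flipped (inferred, not confirmed);
    grokked combinations are held out of training entirely.
    """
    accepting, rejecting = [], []
    for letters in product(LETTERS, repeat=3):
        combo = "".join(letters)
        label = is_set(combo)
        if combo in corrupted:
            label = not label  # flipped label
        if combo in grokked:
            continue  # held out of the training set
        (accepting if label else rejecting).append(combo)
    return accepting, rejecting


accepting, rejecting = build_grids(
    corrupted={"GGR", "RGG", "PRG"},
    grokked={"GRP", "RRR"},
)
print(accepting)       # ['GGG', 'GGR', 'GPR', 'PGR', 'PPP', 'RGG', 'RGP', 'RPG']
print(len(rejecting))  # 17
```

Both lists match the training grids printed above, including their order, which supports (but does not prove) this reading of the task.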

## Instantiate RNN model

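```python
features = 100            # number of hidden units in the recurrent layer
alpha = jnp.float32(0.1)  # Euler integration constant of the CTRNN (commonly dt / tau)
noise = jnp.float32(0.1)  # scale of the noise injected into the recurrent dynamics

# nn.RNN scans the cell over the time axis of each input sequence
ctrnn = nn.RNN(EulerCTRNNCell(features=features, alpha=alpha, noise=noise))
```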
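`EulerCTRNNCell` lives in the project's `src.model` module, so its exact equations are not shown here. As a hedged illustration of what an Euler-discretised continuous-time RNN usually computes, the sketch below assumes the common form `tau * dh/dt = -h + tanh(h W_rec + x W_in + b)`, stepped with `alpha = dt / tau`, and adds Gaussian noise of scale `noise` to the pre-activation. The noise placement, the parameter names (`W_rec`, `W_in`, `b`) and the initialisation are our assumptions. `jax.lax.scan` plays the role that `nn.RNN` plays for the real cell: it applies the step function along the time axis.

```python
import jax
import jax.numpy as jnp


def init_params(key, n_in, n_hidden):
    """Hypothetical parameter layout for this sketch (not the repo's)."""
    k_rec, k_in = jax.random.split(key)
    return {
        "W_rec": jax.random.normal(k_rec, (n_hidden, n_hidden)) / jnp.sqrt(n_hidden),
        "W_in": jax.random.normal(k_in, (n_in, n_hidden)) / jnp.sqrt(n_in),
        "b": jnp.zeros((n_hidden,)),
    }


def euler_ctrnn_step(params, h, x, key, alpha=0.1, noise=0.1):
    """One forward-Euler step of tau * dh/dt = -h + tanh(h W_rec + x W_in + b)."""
    eps = noise * jax.random.normal(key, h.shape)  # assumed: noise on the pre-activation
    drive = jnp.tanh(h @ params["W_rec"] + x @ params["W_in"] + params["b"] + eps)
    # h + alpha * (-h + drive) == (1 - alpha) * h + alpha * drive
    return (1.0 - alpha) * h + alpha * drive


def run_ctrnn(params, xs, key, alpha=0.1, noise=0.1):
    """Scan the step over a (time, n_in) input sequence, starting from h = 0."""
    h0 = jnp.zeros_like(params["b"])
    step_keys = jax.random.split(key, xs.shape[0])  # one noise key per time step

    def step(h, inputs):
        x, k = inputs
        h_next = euler_ctrnn_step(params, h, x, k, alpha, noise)
        return h_next, h_next  # carry the state and also record it

    _, hs = jax.lax.scan(step, h0, (xs, step_keys))
    return hs  # (time, n_hidden) hidden-state trajectory


init_key, run_key = jax.random.split(jax.random.PRNGKey(0))
params = init_params(init_key, n_in=3, n_hidden=100)
hs = run_ctrnn(params, jnp.ones((20, 3)), run_key, alpha=0.1, noise=0.1)
print(hs.shape)  # (20, 100)
```

With `alpha` in [0, 1] and no noise, each step is a convex combination of the previous state and a `tanh` output, so the state stays bounded in [-1, 1]; a smaller `alpha` makes the state change more slowly.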

## Train model

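```python
lr = 0.001   # learning rate
epochs = 25  # number of passes over the training set
```

```python
key, subkey = random.split(key)
# Build the initial training state for the model
state = create_train_state(ctrnn, subkey, lr)
```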

```python
key, subkey = random.split(key)
# Train for `epochs` epochs, reporting train/test/grok/corrupt metrics after each one
trained_state, metrics_history = train_model(
    subkey,
    state,
    training_tf_dataset,
    testing_tf_dataset,
    grok_tf_dataset,
    corrupt_tf_dataset,
    epochs,
)
```

```text
Metrics after epoch 1:
train_loss: 0.9106971025466919
train_accuracy: 0.13750000298023224
test_loss: 1.1535935401916504
test_accuracy: 0.08888888359069824
grok_loss: 1.5786027908325195
grok_accuracy: 0.0
corrupt_loss: 0.5053710341453552
corrupt_accuracy: 0.3333333432674408

Metrics after epoch 2:
train_loss: 0.8939483165740967
train_accuracy: 0.13958333432674408
test_loss: 0.9020090103149414
test_accuracy: 0.16296295821666718
grok_loss: 2.017728567123413
grok_accuracy: 0.0
corrupt_loss: 0.5072855949401855
corrupt_accuracy: 0.06666667014360428

Metrics after epoch 3:
train_loss: 0.8865842223167419
train_accuracy: 0.19583334028720856
test_loss: 0.9001325964927673
test_accuracy: 0.1111111119389534
grok_loss: 1.7196682691574097
grok_accuracy: 0.0
corrupt_loss: 0.4213934540748596
corrupt_accuracy: 0.06666667014360428

Metrics after epoch 4:
train_loss: 0.8553716540336609
train_accuracy: 0.14791665971279144
test_loss: 0.9416314363479614
test_accuracy: 0.08888888359069824
grok_loss: 1.7057
```

Simple, right?
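`train_model` hides the per-batch parameter update. As a hedged, self-contained sketch of what such an update typically looks like in JAX, here is plain gradient descent on a toy least-squares problem. The repository's own loss and optimiser are not shown and may well differ (Flax projects commonly pair a `TrainState` with an Optax optimiser instead); `loss_fn` and `train_step` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    """Toy mean-squared-error loss for a linear model (stand-in for the real loss)."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y, lr):
    """One gradient-descent step: params <- params - lr * grad(loss)."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss


x = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3  # toy targets from a known linear map
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}

initial_loss = loss_fn(params, x, y)
for epoch in range(25):
    # lr = 0.1 keeps this toy demo fast; the notebook above uses lr = 0.001
    params, loss = train_step(params, x, y, 0.1)
```

`jax.tree_util.tree_map` applies the update to every leaf of the parameter tree, so the same step works unchanged for a nested parameter dictionary such as the one an RNN cell produces.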

## Save model and metrics

```python
# Serialise the trained parameters to a binary file
save_loc = '../results/script_examples/params.bin'
params = {'params': trained_state.params}
serialize_parameters(params, save_loc)
```

```python
# Write the per-epoch metrics history to CSV
save_loc = '../results/script_examples/metrics_history.csv'
save_metrics_to_csv(metrics_history, save_loc)
```
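The printed log suggests `metrics_history` tracks eight metrics per epoch (loss and accuracy on the train, test, grok and corrupt sets). Assuming it is a dictionary mapping each metric name to a list with one value per epoch (our assumption; the source does not show its structure), an equivalent CSV can be written with the standard library alone. The values below are the first two epochs from the log above, and `write_metrics_csv` is a hypothetical name, not the project's `save_metrics_to_csv`.

```python
import csv
import io


def write_metrics_csv(metrics_history, f):
    """Write one row per epoch and one column per metric (hypothetical layout)."""
    names = list(metrics_history)
    n_epochs = len(metrics_history[names[0]])
    writer = csv.writer(f)
    writer.writerow(["epoch"] + names)
    for epoch in range(n_epochs):
        writer.writerow([epoch + 1] + [metrics_history[name][epoch] for name in names])


metrics_history = {
    "train_loss": [0.9106971025466919, 0.8939483165740967],
    "train_accuracy": [0.13750000298023224, 0.13958333432674408],
    "test_loss": [1.1535935401916504, 0.9020090103149414],
    "test_accuracy": [0.08888888359069824, 0.16296295821666718],
    "grok_loss": [1.5786027908325195, 2.017728567123413],
    "grok_accuracy": [0.0, 0.0],
    "corrupt_loss": [0.5053710341453552, 0.5072855949401855],
    "corrupt_accuracy": [0.3333333432674408, 0.06666667014360428],
}

buffer = io.StringIO()  # swap for open(path, "w", newline="") to write a real file
write_metrics_csv(metrics_history, buffer)
print(buffer.getvalue())
```

One row per epoch keeps the file easy to load later for plotting learning curves.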