# Hyperparameter search
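The purpose of this notebook is to set up a hyperparameter search to understand which settings lead to successful RNNs. It samples random hyperparameter configurations and writes each one to its own `task_<id>/params.json` file under `../../data/hyperparameter`.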

## Imports

In [1]:
import json
import os

import jax.numpy as jnp
from jax import random

## File and JSON generation

In [2]:
# Output root: every sampled configuration is written to task_<id>/params.json here.
job_folder = "../../data/hyperparameter"
# Root PRNG key from which every task's random draws are split off.
key = random.PRNGKey(0)

# Settings that are identical for every task.
alpha = 0.1
train_batch_size = 128
testing_trials = 640

# Candidate values; each task draws one value from each list
# (mod_amount decides how many entries are drawn from mod_set).
noise = [0.0, 0.05, 0.1]
pulse_mean = list(range(8, 13))                      # 8 .. 12
mod_amount = list(range(1, 7))                       # 1 .. 6
mod_set = list(range(2, 12))                         # 2 .. 11
training_trials = [3200 * k for k in range(1, 4)]    # 3200, 6400, 9600
lr = [0.01, 0.005, 0.001]
epochs = list(range(500, 2501, 500))                 # 500 .. 2500

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
# Draw one random hyperparameter configuration per task and write it to
# <job_folder>/task_<id>/params.json (an existing params.json is overwritten).
# Each task draws independently, so two tasks may share the same sampled
# hyperparameters; only 'seed' (the task index) is guaranteed to differ.

n_tasks = 384  # number of task_<id> folders to generate

# Advance a local copy of the root key rather than the global `key`, so
# re-running this cell regenerates exactly the same configurations instead of
# continuing the random stream from where the previous run stopped.
task_key = key


def _pick(subkey, options):
    """Return one element of the Python sequence ``options`` chosen at random.

    The index is taken from ``random.permutation(subkey, len(options))``, which
    shuffles positions exactly as permuting ``jnp.array(options)`` would. The
    element is then read from the original sequence, so it keeps its exact Python
    value and type. There is no float32 round-trip, so no rounding is needed.
    """
    return options[random.permutation(subkey, len(options))[0].item()]


def _pick_subset(subkey, options, size):
    """Return ``size`` distinct elements of ``options`` in a random order."""
    order = random.permutation(subkey, len(options))[:size].tolist()
    return [options[i] for i in order]


for task_id in range(n_tasks):
    task_folder = os.path.join(job_folder, f"task_{task_id}")
    os.makedirs(task_folder, exist_ok=True)

    # keys[0] is carried to the next task; keys[1..7] drive this task's draws.
    # Plain keys[i] works for both raw uint32 keys and typed key arrays.
    keys = random.split(task_key, num=8)
    task_key = keys[0]

    mod_amt = _pick(keys[1], mod_amount)  # how many entries of mod_set to draw
    params = {
        'seed': task_id,
        'alpha': alpha,
        'noise': _pick(keys[2], noise),
        'pulse_mean': _pick(keys[3], pulse_mean),
        'mod_set': _pick_subset(keys[4], mod_set, mod_amt),
        'training_trials': _pick(keys[5], training_trials),
        'train_batch_size': train_batch_size,
        'testing_trials': testing_trials,
        'lr': _pick(keys[6], lr),
        'epochs': _pick(keys[7], epochs),
    }

    json_path = os.path.join(task_folder, "params.json")
    with open(json_path, 'w') as f:
        json.dump(params, f)

In [4]:
# Sanity check: after the generation loop, task_id holds the last task index.
print(f"{task_id}")

383
