# Hyperparameter search v2
This hyperparameter search iterates on v1: it now evaluates model accuracy on trials of length 500.

## Imports

In [1]:
import json
import os

import jax.numpy as jnp
from jax import random

## File and JSON generation

In [2]:
# Output root: the generation loop writes task_<id>/params.json beneath it.
job_folder = "../../data/hyperparameter_2"

# Root PRNG key; the generation loop splits it to drive every random draw.
key = random.PRNGKey(2)

# --- Sampled search space: one entry (or subset) is drawn per task ---
alpha = [0.1]
noise = [0.05]
# Repeated entries weight the draw, because the loop takes the first element
# of a uniform permutation: 1 with prob 1/2, 2 with prob 1/3, 3 with prob 1/6.
mod_amount = [1] * 3 + [2] * 2 + [3]
# Pool from which `mod_amt` distinct values are drawn for each task.
mod_set = list(range(3, 9))
training_trials = [3200, 6400]
lr = [1e-2, 5e-3]
epochs = [1500, 2000]
weight_decay = [1e-3, 1e-4]
l2_penalty = [1e-2, 1e-3, 1e-4]

# --- Fixed settings, identical for every task ---
pulse_mean = 12
train_batch_size = 128
testing_trials = 640

In [3]:
def _pick(subkey, options):
    """Return one element of *options*, chosen uniformly at random by *subkey*.

    The element is read from the original Python list rather than from a
    float32 ``jnp.array`` copy. Values such as 0.1 or 1e-7 therefore reach the
    JSON file exactly, and no precision-specific ``round`` is needed. JAX
    shuffles a 1-D array and ``arange(len)`` with the same key-derived
    permutation, so this selects the same element that
    ``random.permutation(subkey, jnp.array(options))[0]`` did.
    NOTE(review): spot-check against an existing params.json if exact
    reproducibility of earlier runs matters.
    """
    return options[int(random.permutation(subkey, len(options))[0])]


def _pick_k(subkey, options, k):
    """Return the first *k* elements of a random permutation of *options*.

    The result is a plain Python list, taken exactly from *options*.
    """
    order = random.permutation(subkey, len(options))[:k]
    return [options[int(i)] for i in order]


# Number of random hyperparameter configurations (one task folder each).
n_tasks = 384

for task_id in range(n_tasks):
    task_folder = os.path.join(job_folder, f"task_{task_id}")
    os.makedirs(task_folder, exist_ok=True)

    # keys[0] carries the PRNG chain forward and keys[1..9] drive the draws
    # below. keys[10] is unused, but `num` stays 11: changing it can change
    # every derived key and hence every sampled configuration.
    keys = random.split(key, num=11)
    key = keys[0, :]

    mod_amt = _pick(keys[1, :], mod_amount)
    params = {
        'seed': task_id,
        'alpha': _pick(keys[2, :], alpha),
        'noise': _pick(keys[3, :], noise),
        'pulse_mean': pulse_mean,
        'mod_set': _pick_k(keys[4, :], mod_set, mod_amt),
        'training_trials': _pick(keys[5, :], training_trials),
        'train_batch_size': train_batch_size,
        'testing_trials': testing_trials,
        'lr': _pick(keys[6, :], lr),
        'epochs': _pick(keys[7, :], epochs),
        'weight_decay': _pick(keys[8, :], weight_decay),
        'l2_penalty': _pick(keys[9, :], l2_penalty),
        # NOTE(review): the notebook header says v2 evaluates trials of length
        # 500, but 100 is written here; confirm which length this field controls.
        'trial_length': 100,
    }

    json_path = os.path.join(task_folder, "params.json")
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(params, f)

In [4]:
# Sanity check: show the last task index produced by the generation loop.
print(f"{task_id}")

383
