In [1]:
# Reload edited modules (e.g. the local modularRNN package) before each cell runs,
# so changes to the task implementation are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Testing ModularArithmeticTask
This notebook exercises the methods of `ModularArithmeticTask` (from `modularRNN.task`) one at a time on small inputs, printing each result so it can be inspected.

## Imports

In [2]:
from functools import partial

import jax
import jax.numpy as jnp
from jax import random
from modularRNN.task import ModularArithmeticTask

import matplotlib.pyplot as plt

## Set up the task instance

In [3]:
# Task configuration. The PRNG key is seeded explicitly so runs are reproducible
# (assuming the task derives all of its randomness from this key).
mod_set = jnp.array([2, 3, 4, 7, 9])
pulse_distribution = partial(random.poisson, lam=12)  # Poisson(lam=12) sampler; takes a PRNG key

training_trials = 12800
testing_trials = 1280
train_batch_size = 128
key = random.PRNGKey(0)

modtask = ModularArithmeticTask(
    key,
    training_trials,
    testing_trials,
    train_batch_size,
    mod_set,
    pulse_distribution,
)

## Test methods

In [4]:
# Test the pulse_distribution sampler: draw twice, each time with a fresh
# subkey taken from the task, and print each sample.
for _draw in range(2):
    print(modtask.pulse_distribution(modtask.generate_subkey()))

13
11


In [5]:
# Test the sample_modular_value method by printing two successive draws.
for _draw in range(2):
    print(modtask.sample_modular_value())

4
9


In [6]:
# Test generate_pulse_indicies method.
# Each call must return exactly `num_pulses` indices, and the task's own validator
# (test_pulse_indicies) must accept them. The validator result is still printed so
# the visible output is unchanged; the asserts make a failure stop Run All.
generated_indicies = []
for num_pulses in (13, 10):
    indicies = modtask.generate_pulse_indicies(num_pulses)
    print(indicies)
    is_valid = modtask.test_pulse_indicies(indicies)
    print(is_valid)
    assert len(indicies) == num_pulses, f"expected {num_pulses} indices, got {len(indicies)}"
    assert is_valid, f"test_pulse_indicies rejected {indicies}"
    generated_indicies.append(indicies)
# Keep the original per-call names available to later cells.
pulse_indicies_0, pulse_indicies_1 = generated_indicies

[ 9 15 18 20 25 32 34 53 61 71 76 86 95]
True
[23 27 36 42 49 55 59 69 83 91]
True


In [7]:
# Test generate_pulse_values method.
# The first argument is the number of pulses: the call must return one value per pulse.
# NOTE(review): the second argument is presumably the modulus the values are drawn
# under — TODO confirm and add a range check once its meaning is pinned down.
for num_pulses, modulus in ((12, 8), (10, 5)):
    values = modtask.generate_pulse_values(num_pulses, modulus)
    print(values)
    assert len(values) == num_pulses, f"expected {num_pulses} values, got {len(values)}"

[5 6 6 1 6 3 6 2 4 4 3 3]
[1 2 4 1 4 4 3 2 1 4]


In [8]:
# Test create_pulses_and_cumulative_mod method on a hand-written trial.
pulse_indicies = jnp.array([4, 8, 13, 28, 31, 39, 42, 84, 90, 94])
pulse_values = jnp.array([2, 2, 1, 3, 3, 4, 1, 3, 4, 4])
pulses, cumulative_mod = modtask.create_pulses_and_cumulative_mod(pulse_indicies, pulse_values, 5)
print(pulses)

# Each value must land at its index. Because every pulse value is non-zero, the
# non-zero count must equal the number of pulses, i.e. all other time steps are 0.
assert (pulses[pulse_indicies] == pulse_values).all(), "pulse values not placed at their indices"
assert int(jnp.count_nonzero(pulses)) == len(pulse_indicies), "unexpected non-zero time steps in pulses"
# NOTE(review): cumulative_mod is returned but never checked; presumably it is the
# running sum of `pulses` taken mod 5 — TODO confirm and assert it here.

[0 0 0 0 2 0 0 0 2 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 3 0 0 0 0 0
 0 0 4 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 3 0 0 0 0 0 4 0 0 0 4 0 0 0 0 0]


In [9]:
# Test create_input_output_tensors method on the hand-written trial built above.
input_tensor, output_tensor = modtask.create_input_output_tensors(pulses, cumulative_mod, 5)
print(input_tensor.shape)
print(output_tensor.shape)

# The leading axis of both tensors should match the length of `pulses`
# (100 time steps in the recorded output).
assert input_tensor.shape[0] == pulses.shape[0], "input tensor length does not match pulses"
assert output_tensor.shape[0] == pulses.shape[0], "output tensor length does not match pulses"

(100, 20)
(100, 10)


In [10]:
# Test generate_task_trial method on a randomly generated trial.
input_tensor, output_tensor = modtask.generate_task_trial()
print(input_tensor.shape)
print(output_tensor.shape)

# Input and output must cover the same number of time steps.
assert input_tensor.shape[0] == output_tensor.shape[0], "input/output time lengths differ"

(100, 20)
(100, 10)


In [11]:
# Test generate_trials method.
# NOTE: this is the slowest cell (about 3.5 min for 6400 trials in the recorded run);
# lower num_trials for a quicker smoke test.
num_trials = 6400
inputs_tensor, outputs_tensor = modtask.generate_trials(num_trials)
print(inputs_tensor.shape)
print(outputs_tensor.shape)

# Trials are stacked along a new leading axis, and inputs and outputs must agree on
# both the trial count and the per-trial time length.
assert inputs_tensor.shape[0] == num_trials, f"expected {num_trials} input trials, got {inputs_tensor.shape[0]}"
assert outputs_tensor.shape[0] == num_trials, f"expected {num_trials} output trials, got {outputs_tensor.shape[0]}"
assert inputs_tensor.shape[:2] == outputs_tensor.shape[:2], "input/output (trials, time) dims differ"

100%|███████████████████████████████████████| 6400/6400 [03:38<00:00, 29.31it/s]


(6400, 100, 20)
(6400, 100, 10)
