# PCA animation
Aesthetically pleasing GitHub repositories often feature an animation near the top of their `README.md`. In this notebook, we animate the phase-code solution as trajectories in PCA space and save the result as a GIF.

## Imports

In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.gridspec import GridSpec

import jax
import jax.numpy as jnp
from jax import random

import scipy.signal
import scipy.stats

from ctrnn_jax.model import CTRNNCell
from ctrnn_jax.training import ModelParameters, create_train_state
from ctrnn_jax.pca import compute_pca

from emergence_phase_codes.model import initialize_ctrnn
from emergence_phase_codes.task import ModuloNArithmetic
from emergence_phase_codes.pca import add_data

from emergence_phase_codes.animations.pca import PCATrajectoryAnimator
from emergence_phase_codes.animations.utils import interpolate_colors

In [2]:
# Root PRNG key for the notebook; later cells split their own subkeys off of it.
key = jax.random.PRNGKey(69)

## Configure parameters

In [3]:
# --- Modulo-3 arithmetic task configuration ---
# Leading axis of the init array below; also handed to the dataset builder.
BATCH_SIZE_M3A = 16

# Second positional argument of the task constructor and the trailing axis
# of every input array built in this notebook.
MOD = 3
CONGRUENT_NUMBER = 0
# Forwarded to the task as `time_length`; middle axis of the input arrays.
TIME_LENGTH = 50
NUM_TRIALS = 2500

# NOTE(review): these two are not referenced by any other cell shown here.
TRIALS_PER_POS = 100
TRIALS_PER_NEG = 50

# Passed through unchanged to the task constructor as `pulse_config`.
PULSE_CONFIG = dict(
    num_pulses=3,
    pulse_window=35,
    pulse_buffer=5,
    pulse_gap=5,
    pulse_amplitude=5,
)

# All-ones placeholder array used when creating the train state.
INIT_ARRAY_M3A = jnp.ones((BATCH_SIZE_M3A, TIME_LENGTH, MOD))

In [4]:
# --- CT-RNN configuration (consumed by `initialize_ctrnn`) ---
HIDDEN_FEATURES = 100  # forwarded as `hidden_features`
OUTPUT_FEATURES = 1    # forwarded as `output_features`
ALPHA = 1.0            # forwarded as `alpha`
NOISE_SCALAR = 0.0     # forwarded as `noise_const`

## Initialize MNA task

In [5]:
# Build the task from the configuration constants, giving it its own subkey
# so the root key stays available for later cells.
key, task_key = random.split(key)
task = ModuloNArithmetic(
    task_key, MOD,
    num_trials=NUM_TRIALS,
    time_length=TIME_LENGTH,
    congruent_number=CONGRUENT_NUMBER,
    pulse_config=PULSE_CONFIG,
)

# Show the sequence count reported by the task as the cell output.
task.return_number_of_sequences()

27

In [6]:
# Build the TensorFlow dataset from the task; BATCH_SIZE_M3A is the only
# argument passed (presumably the batch size — confirm against the task class).
tf_dataset_train = task.generate_tf_dataset(BATCH_SIZE_M3A)
# Display the length of the task's dataset as the cell output.
len(task.dataset)

2500

## Load CT-RNN

In [7]:
# Construct the CT-RNN from the model configuration constants.
ctrnn = initialize_ctrnn(
    alpha=ALPHA,
    noise_const=NOISE_SCALAR,
    hidden_features=HIDDEN_FEATURES,
    output_features=OUTPUT_FEATURES,
)

In [8]:
# Create the train state for the CT-RNN using a dedicated subkey and the
# all-ones init array.
# NOTE(review): the third positional argument (1e-4) is presumably the
# optimizer learning rate — confirm against `create_train_state`.
key, train_state_key = random.split(key)
train_state = create_train_state(train_state_key, ctrnn, 1e-4, INIT_ARRAY_M3A)

In [9]:
# Wrap the train state in a ModelParameters object, then deserialize the
# saved parameter file; later cells read the result via `params.params`.
weights_path = "../data/phase_code_solution_m3a_task.bin"
params = ModelParameters(train_state)
params.deserialize(weights_path)

## Compute PCA and null rates

In [10]:
# Run `compute_pca` with the loaded parameters over the training dataset.
# It returns the collected model behavior together with a PCA object that
# later cells use via `pca.transform`.
# NOTE(review): the trailing 3 is presumably the number of principal
# components to keep — confirm against `compute_pca`.
key, pca_key = random.split(key)
model_behavior, pca = compute_pca(pca_key, train_state, params.params, tf_dataset_train, 3)

In [11]:
# Feed the network a single all-zero input and project the resulting rates
# with the PCA object; this serves as the "null" trajectory in the animation.
key, test_key = random.split(key)
null_input = jnp.zeros((1, TIME_LENGTH, MOD))
output_null, rates_null = train_state.apply_fn(
    params.params,
    null_input,
    rngs={"noise_stream": test_key},
)

# Take the first (only) entry along the leading axis before transforming.
rates_pc_null = pca.transform(rates_null[0])

## Figure preliminaries

In [12]:
# Hand-built "congruent" example trial.
# NOTE(review): the first array looks like the presented integers
# (2 + 0 + 1 = 3, i.e. 0 mod 3, matching CONGRUENT_NUMBER) and the second like
# the time indices of the pulses — confirm against `create_trial_with_indices`.
congruent_example_input, _ = task.create_trial_with_indices(
    jnp.array([2, 0, 1]),
    jnp.array([10, 20, 30]),
)
# Prepend a length-1 leading axis so the trial can be fed to the model.
congruent_example_input = congruent_example_input[None, ...]

# Run the CT-RNN on the example with its own noise subkey.
key, congruent_key = random.split(key)
output_congruent, rates_congruent = train_state.apply_fn(
    params.params,
    congruent_example_input,
    rngs={"noise_stream": congruent_key},
)
rates_pc_congruent = pca.transform(rates_congruent[0])

# Register this trial's input, rates, and output in model_behavior.
model_behavior = add_data(
    model_behavior, pca, congruent_example_input, rates_congruent, output_congruent,
)

In [13]:
# Hand-built "incongruent" example trial.
# NOTE(review): the first array looks like the presented integers
# (0 + 0 + 2 = 2, i.e. not 0 mod 3) and the second like the time indices of
# the pulses — confirm against `create_trial_with_indices`.
incongruent_example_input, _ = task.create_trial_with_indices(
    jnp.array([0, 0, 2]),
    jnp.array([15, 25, 35]),
)
# Prepend a length-1 leading axis so the trial can be fed to the model.
incongruent_example_input = incongruent_example_input[None, ...]

# Run the CT-RNN on the incongruent example with its own noise subkey.
key, incongruent_key = random.split(key)
output_incongruent, rates_incongruent = train_state.apply_fn(
    params.params,
    incongruent_example_input,
    rngs={"noise_stream": incongruent_key},
)
rates_pc_incongruent = pca.transform(rates_incongruent[0])

# Register this trial's input, rates, and output in model_behavior.
model_behavior = add_data(
    model_behavior, pca, incongruent_example_input, rates_incongruent, output_incongruent,
)

In [14]:
# Trajectories to animate, addressed by negative index.
# NOTE(review): -2 / -1 presumably select the congruent / incongruent example
# trials that the two preceding cells added to model_behavior last — confirm.
trajectory_indices = [-2, -1]

# First output channel of each example trial, keyed by trajectory index.
output_congruent_dict = {-2: output_congruent[0, :, 0]}
output_incongruent_dict = {-1: output_incongruent[0, :, 0]}

# Per-time-step gradients starting from grey: blue for the congruent example,
# orange for the incongruent one.
blue_gradient = interpolate_colors("#7f7f7f", "#1f77b4", n_steps=TIME_LENGTH)
orange_gradient = interpolate_colors("#7f7f7f", "#ff7f0e", n_steps=TIME_LENGTH)
trajectory_colors = dict(zip(trajectory_indices, (blue_gradient, orange_gradient)))

# Colors keyed by single-element label tuples.
classification_colors = {
    (1,): "tab:blue",
    (-1,): "tab:orange",
}

In [15]:
# One color per integer stimulus: 0 -> green, 1 -> red, 2 -> purple.
stimulus_colors = {
    integer: color
    for integer, color in enumerate(("tab:green", "tab:red", "tab:purple"))
}

# Decode each example trial's input (leading axis dropped) with the task's decoder.
decoded_congruent = task.decode_integer_inputs(congruent_example_input[0])
decoded_incongruent = task.decode_integer_inputs(incongruent_example_input[0])

## PCA animation

In [16]:
# Output settings for the animation.
GIF_PATH = "../results/pca_animation.gif"
GIF_FPS = 7
# NOTE(review): frames rendered beyond the trial length — presumably so the
# final state lingers before the GIF loops; confirm against
# PCATrajectoryAnimator.update.
EXTRA_FRAMES = 10

# Create the figure with a single panel.
fig_anim = plt.figure(figsize=(6, 6))
gs = GridSpec(1, 1)

# Panel 1: animated trajectories of the two example trials plus the null
# trajectory, drawn by the project's PCATrajectoryAnimator.
ax = fig_anim.add_subplot(gs[0, 0])
animator = PCATrajectoryAnimator(
    ax, model_behavior, 1, 2,
    trajectory_indices, trajectory_colors,
    '', stimulus_colors=stimulus_colors,
    null_trajectory=rates_pc_null,
)

# Add decorations: color the integer-stimulus points of each example trial
# and hide the top/right spines.
animator.color_integer_points(-2, decoded_congruent)
animator.color_integer_points(-1, decoded_incongruent)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Create the animation using FuncAnimation.
ani = animation.FuncAnimation(
    fig_anim,
    animator.update,
    frames=TIME_LENGTH + EXTRA_FRAMES,
    blit=True,
)

# Save the animation as a GIF. The output directory is created first so a
# fresh checkout without `../results/` does not fail with FileNotFoundError.
Path(GIF_PATH).parent.mkdir(parents=True, exist_ok=True)
ani.save(GIF_PATH, writer="pillow", fps=GIF_FPS)

# Close the figure so it is not also rendered inline below the cell.
plt.close(fig_anim)