In [6]:
import matplotlib.pyplot as plt
from typing import List, Tuple, Union
import math
import os

import torch
import plotly.express as px

from train import evaluate_model_copy, CopyConfig, inference_generate
from rnn import RNN, GRULayer

torch.manual_seed(8)
device = "cpu"

In [2]:
# 1024 achieved 1.0 no issues, 256 took some tries and the best is at: "run_256_fixed_lower_wdlr_ctd". 
d_hidden = 256
run_name = "run_256_fixed_lower_wdlr_ctd"
copy_test_cfg = CopyConfig(run_name=run_name, d_hidden=d_hidden, gru=True)

model = RNN(31, copy_test_cfg.d_hidden, out_size=30, 
            out_act= lambda x: x, use_gru=copy_test_cfg.gru)
model.load_state_dict(torch.load(f"models/copy_train/{copy_test_cfg.run_name}/{copy_test_cfg.run_name}.ckpt", weights_only=True, map_location=device))

ModuleList(
  (0): GRULayer(
    (input_to_reset): Linear(in_features=31, out_features=256, bias=False)
    (hidden_to_reset): Linear(in_features=256, out_features=256, bias=True)
    (input_to_update): Linear(in_features=31, out_features=256, bias=False)
    (hidden_to_update): Linear(in_features=256, out_features=256, bias=True)
    (input_to_new): Linear(in_features=31, out_features=256, bias=False)
    (hidden_to_new): Linear(in_features=256, out_features=256, bias=True)
  )
  (1): Linear(in_features=256, out_features=30, bias=True)
)


<All keys matched successfully>

In [4]:
evaluate_model_copy(copy_test_cfg, model)

100%|██████████| 5000/5000 [01:09<00:00, 72.09it/s]


(0.00040266723594573897, 1.0)

plot gru gate movement

In [10]:
def plot_update_gate_heatmap(model, sequence, loss_mask, layer=-1):
    # Run forward with gate recording enabled
    outputs, h_t, r_t_all, z_t_all = inference_generate(model, sequence, 
                                                        discrete=True,
                                                        record_gates=True)


    layers, batch_size, seq_len, hidden_size = z_t_all.shape
    if layer > 0:
        z_t_all, r_t_all = z_t_all[layer].unsqueeze(0), r_t_all[layer].unsqueeze(0)


    mask = loss_mask.unsqueeze(0).unsqueeze(-1) # (1, b, l, 1)
    z_t_all_masked = z_t_all * mask  # invalid positions become 0

     # Count valid positions per layer, timestep, and hidden unit across batches
    valid_counts = mask.sum(dim=1)  # sum over batch dimension, shape (1, seq_len, 1)
    valid_counts = valid_counts.clamp(min=1)  # avoid division by zero
    
    # Sum over batches to get total activations per layer, timestep, hidden unit
    z_t_sum = z_t_all_masked.sum(dim=1)  # sum over batch dimension, shape (layers, seq_len, hidden_size)
    
    # Average by dividing sum by counts of valid batches
    # Note: valid_counts shape (1, seq_len, 1) broadcasts over layers and hidden_size
    z_t_avg = z_t_sum / valid_counts

    # Plot one heatmap per layer
    for layer in range(layers):
        plt.figure(figsize=(10, 6))
        # Transpose so x axis = time steps, y axis = hidden units
        plt.imshow(z_t_avg[layer].T.cpu(), aspect='auto', cmap='viridis', vmin=0, vmax=1)
        plt.colorbar(label="Update gate z_t")
        plt.xlabel("Time step")
        plt.ylabel("Hidden unit")
        plt.title(f"Update Gate Activations for Layer {layer}")
        plt.show()

In [11]:
# data to viz:
test_dataset = torch.load(f"data/copy_test/{copy_test_cfg.run_name}.pt")
batches = [1]
test_data = [test_dataset[batch] for batch in batches]
test_data, test_loss_masks = zip(*test_data)
test_data, test_loss_masks = torch.stack(test_data), torch.stack(test_loss_masks)

In [12]:

plot_update_gate_heatmap(model, test_data, test_loss_masks)

ValueError: not enough values to unpack (expected 4, got 2)