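## Setup

Change into the `pirnns` package directory and put the repository on `sys.path` so the project modules can be imported.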

In [58]:
# autoreload
%load_ext autoreload
%autoreload 2
# jupyter black formatter
%load_ext jupyter_black

import subprocess
import os
import sys

gitroot_path = subprocess.check_output(
    ["git", "rev-parse", "--show-toplevel"], universal_newlines=True
)

os.chdir(os.path.join(gitroot_path[:-1], "pirnns"))
print("Working directory: ", os.getcwd())

sys_dir = os.path.dirname(os.getcwd())
sys.path.append(sys_dir)
print("Directory added to path: ", sys_dir)
sys.path.append(os.getcwd())
print("Directory added to path: ", os.getcwd())

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The jupyter_black extension is already loaded. To reload it, use:
  %reload_ext jupyter_black
Working directory:  /home/facosta/pirnns/pirnns
Directory added to path:  /home/facosta/pirnns
Directory added to path:  /home/facosta/pirnns/pirnns


## Load model

In [59]:
import os

from pirnns.paper_figs.load_models import load_experiment_sweep
import torch

# Sweep to analyse. Resolved relative to the working directory set by the setup
# cell (the `pirnns` package directory inside the repo) rather than a hard-coded
# user home path, so the notebook also runs on other machines / checkouts.
SWEEP_NAME = "timescales_sweep_20251002_055821"
sweep_dir = os.path.abspath(os.path.join("logs", "experiments", SWEEP_NAME))
assert os.path.isdir(sweep_dir), f"sweep directory not found: {sweep_dir}"

device = "cuda" if torch.cuda.is_available() else "cpu"

models, metadata, summary = load_experiment_sweep(
    sweep_dir=sweep_dir,
    device=device,
    use_lightning_checkpoint=True,
    checkpoint_type="best",
)

Loading 1 experiments with 1 seeds each...
Total models to load: 1
Using Lightning checkpoints

Loading experiment: discrete_single_05
  ✓ Loaded discrete_single_05/seed_0

Successfully loaded: 1/1 models


### Load Position Decoding Measurement

In [63]:
from pirnns.analysis.measurements import PositionDecodingMeasurement

# Record for the single loaded run (experiment "discrete_single_05", seed index 0);
# the measurement is built from the config stored with that run.
loaded_run = models["discrete_single_05"][0]
config = loaded_run["config"]

measurement = PositionDecodingMeasurement(config)

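### Out-of-distribution (OOD) trajectory-length analysis

Evaluate position-decoding error on test trajectories of length 20, 25, 50 and 100. The model was trained on length-20 trajectories, so the longer lengths probe generalization beyond the training distribution.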

In [66]:
from pirnns.analysis.analyses import OODAnalysis

# Trained network and its place-cell centres, taken from the same loaded run.
run_entry = models["discrete_single_05"][0]
model = run_entry["model"]
place_cell_centers = run_entry["place_cell_centers"]

# Test lengths start at the training length (20, per the result metadata)
# and extend to 5x longer trajectories.
analysis = OODAnalysis(
    config,
    test_lengths=[20, 25, 50, 100],
    place_cell_centers=place_cell_centers,
)
analysis.run(model, measurement)

AnalysisResult(test_conditions=[20, 25, 50, 100], measurements=[0.08945048522949219, 0.087663720703125, 0.10240396728515624, 0.1354678955078125], condition_name='trajectory_length', metadata={'training_length': 20, 'num_test_trajectories': 100})

In [65]:
# Last entry of the stored per-epoch position-decoding errors (presumably the
# error at the end of training), for comparison with the OOD length-20 result.
decoding_errors = models["discrete_single_05"][0]["position_decoding_errors"]
decoding_errors["position_errors_epoch"][-1]

0.08473280212402344