In [1]:
import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
from pprint import pprint

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Local SQLite tracking store shared by every cell in this notebook.
TRACKING_URI = "sqlite:///mlruns.db"
mlflow.set_tracking_uri(TRACKING_URI)

client = MlflowClient()

experiments = client.search_experiments()

# Columns are passed explicitly so that an empty experiment list still yields a
# frame with a "name" column; otherwise sort_values("name") raises KeyError.
EXP_COLUMNS = ["experiment_id", "name", "lifecycle_stage", "artifact_location"]
exp_df = pd.DataFrame(
    [{col: getattr(e, col) for col in EXP_COLUMNS} for e in experiments],
    columns=EXP_COLUMNS,
)

exp_df.sort_values("name")



2026/02/06 12:41:09 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.schemas
2026/02/06 12:41:09 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.tables
2026/02/06 12:41:09 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.types
2026/02/06 12:41:09 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.constraints
2026/02/06 12:41:09 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.defaults
2026/02/06 12:41:09 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.comments
2026/02/06 12:41:10 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2026/02/06 12:41:10 INFO alembic.runtime.migration: Will assume non-transactional DDL.


Unnamed: 0,experiment_id,name,lifecycle_stage,artifact_location
1,0,Default,active,/home/julian/Projects/MasterThesis/Pilot Decod...
0,1,grid-representations,active,/home/julian/Projects/MasterThesis/Pilot Decod...


In [3]:
EXPERIMENT_NAME = "grid-representations"

experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
if experiment is None:
    # get_experiment_by_name returns None (it does not raise) for unknown names;
    # fail here with a readable message instead of an AttributeError below.
    raise ValueError(
        f"No MLflow experiment named {EXPERIMENT_NAME!r} at tracking URI "
        f"{mlflow.get_tracking_uri()!r}; check the experiment table above."
    )
experiment_id = experiment.experiment_id

experiment


<Experiment: artifact_location='/home/julian/Projects/MasterThesis/Pilot Decoder/mlruns/1', creation_time=1770104346161, experiment_id='1', last_update_time=1770104346161, lifecycle_stage='active', name='grid-representations', tags={'mlflow.experimentKind': 'custom_model_development'}>

In [4]:
# All runs of the selected experiment, newest first, so runs_df.iloc[0] is the
# most recently started run.
runs_df = mlflow.search_runs(
    [experiment_id],
    order_by=["attributes.start_time DESC"],
)
runs_df.head(5)

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.k2/md_60_max,metrics.k3/positivity_geco,metrics.k0/sm_60_mean,metrics.k1/md_90_max,...,metrics.k4/lg_60_mean,metrics.k4/sm_60_max,metrics.k4/md_60_max,tags.experiment,tags.mlflow.source.name,tags.mlflow.source.git.commit,tags.mlflow.user,tags.mlflow.runName,tags.mlflow.source.type,tags.mlflow.parentRunId
0,93a8b89b51d74dcd8d89dc7bbfd86baa,1,FAILED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-03 17:06:30.485000+00:00,2026-02-03 17:51:29.364000+00:00,0.787433,2.358251,0.073496,1.060948,...,,,,seq_200,sweep.py,7d3216c753b2593856d0ee188b47f10063c21b37,julian,seq_200,LOCAL,185f96546a554258a58ec5ed7302c3f4
1,db507b84a26049d4b0d5ecbea7afb2aa,1,RUNNING,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-03 16:29:33.460000+00:00,2026-02-03 17:06:30.452000+00:00,0.575461,1.373189,0.088102,0.989037,...,-0.093204,1.025579,0.362797,seq_100,sweep.py,7d3216c753b2593856d0ee188b47f10063c21b37,julian,seq_100,LOCAL,185f96546a554258a58ec5ed7302c3f4
2,3b82052374494e3a93a27a60c208fbb7,1,FINISHED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-03 16:01:14.223000+00:00,2026-02-03 16:29:33.428000+00:00,0.446142,-0.437249,0.104838,0.442456,...,0.0281,0.402249,0.214604,seq_50,sweep.py,7d3216c753b2593856d0ee188b47f10063c21b37,julian,seq_50,LOCAL,185f96546a554258a58ec5ed7302c3f4
3,9afb583aa144430a9db0612d3c2d869f,1,FINISHED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-03 15:41:57.563000+00:00,2026-02-03 16:01:14.194000+00:00,0.013208,-2.232798,0.048751,0.02925,...,0.01052,0.393195,0.017943,seq_10,sweep.py,7d3216c753b2593856d0ee188b47f10063c21b37,julian,seq_10,LOCAL,185f96546a554258a58ec5ed7302c3f4
4,93810a36a2654961aebe64d9ef4034ed,1,FINISHED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-03 15:38:03.823000+00:00,2026-02-03 15:41:57.533000+00:00,,,0.01333,,...,,,,seq_5,sweep.py,7d3216c753b2593856d0ee188b47f10063c21b37,julian,seq_5,LOCAL,185f96546a554258a58ec5ed7302c3f4


In [5]:
RUN_ID = "db507b84a26049d4b0d5ecbea7afb2aa"  # Set to specific run ID or None for latest
# RUN_ID = None

if RUN_ID is None:
    if runs_df.empty:
        raise ValueError(f"Experiment {EXPERIMENT_NAME!r} has no runs to choose from.")
    RUN_ID = runs_df.iloc[0]["run_id"]
    print(f"Using latest run: {RUN_ID}")

run = client.get_run(RUN_ID)

# mlflow.start_run raises if a run is already active, which would make this
# cell fail on re-run. Reuse the active run if it is already RUN_ID; otherwise
# end whatever is active before resuming RUN_ID.
# NOTE(review): resuming a run marks it RUNNING (see the status of db507b84...
# in the runs table above). Call mlflow.end_run() when finished; note that
# end_run() records the run as FINISHED, overwriting e.g. a FAILED status.
current_run = mlflow.active_run()
if current_run is not None and current_run.info.run_id != RUN_ID:
    mlflow.end_run()
    current_run = None
if current_run is None:
    current_run = mlflow.start_run(run_id=RUN_ID)

current_run

<ActiveRun: >

# Generate analysis plots

In [6]:
import hydra
import torch
from omegaconf import OmegaConf

# Index of the model within the run to analyse. It selects the
# models/model_k{k}_state_dict.pt artifact below and the "k{k}/..." metric keys
# read by the loss-plot cells further down.
k = 0

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Rebuild the architecture from the Hydra config logged with the run.
config_path = client.download_artifacts(run_id=RUN_ID, path="config.yaml")
cfg = OmegaConf.load(config_path)
model = hydra.utils.instantiate(cfg.model)
model = model.to(device)

# Restore the trained weights for model k onto the same device.
# NOTE(review): torch.load unpickles the file; since this appears to be a plain
# state dict, weights_only=True would be safer. Confirm the installed torch
# version supports it.
state_dict_path = client.download_artifacts(
    run_id=RUN_ID, path=f"models/model_k{k}_state_dict.pt"
)
state_dict = torch.load(state_dict_path, map_location=device)
model.load_state_dict(state_dict)

model.eval()

Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 4310.69it/s] 
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 1470.14it/s]


ActionableRGM()

In [None]:
from analysis import (
    get_ratemaps, quantitative_analysis, loss_plots,
    neuron_plotter_2d, frequency_plot, s_matrix_plot,
    create_loss_plots_from_mlflow
)
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Learned model parameters, detached and copied to host memory as NumPy arrays.
om = model.om.detach().cpu().numpy()
S = model.S.detach().cpu().numpy()

# Frequency plot
fig_freq = frequency_plot(om)
fig_freq.show()

# S matrix analysis
fig_S = s_matrix_plot(S)
fig_S.show()

# Ratemaps at resolution `res` for each width. SIZE_LABELS[i] names widths[i]
# and is the prefix of the score keys (e.g. "sm_60") used for the neuron plots,
# so the two tuples must stay aligned.
res = 70
widths = (1, 2, 4)
SIZE_LABELS = ("sm", "md", "lg")
assert len(SIZE_LABELS) == len(widths), "need exactly one size label per width"
Vs = get_ratemaps(model, res, widths)
V_small, V_medium, V_large = Vs

# Grid scores histogram
fig_score, scores = quantitative_analysis(Vs, widths, res)
fig_score.show()

# One neuron figure per width, each given that width's "<label>_60" scores.
fig_neurons = {}
for label, V in zip(SIZE_LABELS, Vs):
    fig_neurons[label] = neuron_plotter_2d(V, res, scores[f"{label}_60"])
    fig_neurons[label].show()

In [7]:
from analysis import loss_plots

from mlflow.tracking import MlflowClient
import numpy as np

client = MlflowClient()

# Loss plots from MLflow metrics
def fetch_metric(metric_key: str, run_id: "str | None" = None, mlflow_client=None):
    """Return the logged history of one MLflow metric as a NumPy array.

    Args:
        metric_key: Metric name as logged, e.g. ``f"k{k}/train_loss"``.
        run_id: Run to read from. Defaults to the notebook-level ``RUN_ID``.
        mlflow_client: ``MlflowClient`` to query. Defaults to the
            notebook-level ``client``.

    Returns:
        1-D array of metric values ordered by step (ties broken by timestamp),
        or ``None`` if the run has no history for ``metric_key``.
    """
    # Resolve defaults at call time, not def time, so later reassignments of
    # RUN_ID / client in the notebook are picked up.
    if run_id is None:
        run_id = RUN_ID
    if mlflow_client is None:
        mlflow_client = client

    history = mlflow_client.get_metric_history(run_id, metric_key)
    if not history:
        return None
    # Sort explicitly rather than relying on the order returned by the client.
    # A resumed run can log the same step more than once, so break ties by
    # timestamp to keep the result deterministic.
    ordered = sorted(history, key=lambda m: (m.step, m.timestamp))
    return np.array([m.value for m in ordered])

# Training curves for model index `k`; curves that were never logged are dropped.
train_losses = {
    name: series
    for name, series in (
        ("loss", fetch_metric(f"k{k}/train_loss")),
        ("separation", fetch_metric(f"k{k}/separation")),
        ("positivity", fetch_metric(f"k{k}/positivity_geco")),
        ("norm", fetch_metric(f"k{k}/norm_geco")),
    )
    if series is not None
}

val_losses = {
    name: series
    for name, series in [("loss", fetch_metric(f"k{k}/val_loss"))]
    if series is not None
}

# Constraint multiplier histories (None when not logged).
lambda_pos, lambda_norm = (
    fetch_metric(f"k{k}/{metric}") for metric in ("lambda_pos", "lambda_norm")
)

if train_losses:
    fig_loss = loss_plots(
        train_losses,
        val_losses if val_losses else None,
        lambda_pos,
        lambda_norm,
    )
    fig_loss.show()

In [None]:
from analysis import create_loss_plots_from_mlflow

# Pass the notebook-level model index `k` instead of a literal 0 so this cell
# stays in sync with the model loaded above.
# NOTE(review): assumes the argument is the model index k (the original call
# passed 0 while k == 0). Confirm against analysis.create_loss_plots_from_mlflow.
create_loss_plots_from_mlflow(k)

In [None]:
from analysis import generate_2d_plots

# Keep the result under its own name: `res` already holds the ratemap
# resolution (70) set in the analysis-plots cell, and reusing it for a
# different kind of object invites hidden-state bugs.
plots_2d = generate_2d_plots(model, k)

# Sweep score distributions

In [1]:
from analysis import sweep_score_distributions_mlflow

# Parent (sweep) run to summarise. TRACKING_URI points at the same SQLite store
# used earlier in the notebook.
PARENT_RUN_ID = "739dfc50c9cb41d3875e79a55b35ca09"
TRACKING_URI = "sqlite:///mlruns.db"

print(f"Generating sweep score distributions for parent run: {PARENT_RUN_ID}")
figures = sweep_score_distributions_mlflow(
    parent_run_id=PARENT_RUN_ID,
    tracking_uri=TRACKING_URI,
    x_param="data.seq_len",
    log_x=True,
)
for fig in figures.values():
    fig.show()

  from .autonotebook import tqdm as notebook_tqdm
2026/02/09 15:18:20 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.schemas
2026/02/09 15:18:20 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.tables
2026/02/09 15:18:20 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.types
2026/02/09 15:18:20 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.constraints
2026/02/09 15:18:20 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.defaults
2026/02/09 15:18:20 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.comments


Generating sweep score distributions for parent run: 739dfc50c9cb41d3875e79a55b35ca09


2026/02/09 15:18:21 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2026/02/09 15:18:21 INFO alembic.runtime.migration: Will assume non-transactional DDL.
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 3504.01it/s]
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 3125.41it/s]
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 2779.53it/s]
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 4485.89it/s] 
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 3320.91it/s] 
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 3983.19it/s] 
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 6345.39it/s] 
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 3501.09it/s] 


Sweep score distribution plots logged to MLflow for run 739dfc50c9cb41d3875e79a55b35ca09
