In [1]:
import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
from pprint import pprint

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Tracking store lives next to the notebook; the URI is relative to the kernel's CWD.
TRACKING_URI = "sqlite:///mlruns.db"
mlflow.set_tracking_uri(TRACKING_URI)

client = MlflowClient()

experiments = client.search_experiments()

# Columns are declared explicitly so an empty tracking store still yields a frame
# with a "name" column (otherwise sort_values("name") raises KeyError).
EXPERIMENT_COLUMNS = ["experiment_id", "name", "lifecycle_stage", "artifact_location"]

exp_df = pd.DataFrame(
    [{col: getattr(e, col) for col in EXPERIMENT_COLUMNS} for e in experiments],
    columns=EXPERIMENT_COLUMNS,
)

exp_df.sort_values("name")



2026/02/11 10:49:47 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.schemas
2026/02/11 10:49:47 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.tables
2026/02/11 10:49:47 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.types
2026/02/11 10:49:47 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.constraints
2026/02/11 10:49:47 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.defaults
2026/02/11 10:49:47 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.comments
2026/02/11 10:49:47 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2026/02/11 10:49:47 INFO alembic.runtime.migration: Will assume non-transactional DDL.


Unnamed: 0,experiment_id,name,lifecycle_stage,artifact_location
1,0,Default,active,/home/julian/Projects/MasterThesis/Pilot Decod...
0,1,grid-representations,active,/home/julian/Projects/MasterThesis/Pilot Decod...


In [3]:
EXPERIMENT_NAME = "grid-representations"

experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
# get_experiment_by_name returns None (rather than raising) for an unknown name;
# fail here with a readable message instead of an AttributeError on the next line.
if experiment is None:
    raise ValueError(
        f"Experiment {EXPERIMENT_NAME!r} not found in tracking store "
        f"{mlflow.get_tracking_uri()!r}"
    )
experiment_id = experiment.experiment_id

experiment


<Experiment: artifact_location='/home/julian/Projects/MasterThesis/Pilot Decoder/mlruns/1', creation_time=1770104346161, experiment_id='1', last_update_time=1770104346161, lifecycle_stage='active', name='grid-representations', tags={'mlflow.experimentKind': 'custom_model_development'}>

In [4]:
# All runs of the experiment, newest first: row 0 is the most recently started run.
runs_df = mlflow.search_runs(
    experiment_ids=[experiment_id], order_by=["attributes.start_time DESC"]
)
runs_df.head()

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.k0/sm_60_mean,metrics.k1/md_60_max,metrics.k2/md_60_mean,metrics.k3/val_loss,...,metrics.k7/lg_90_max,metrics.k5/lambda_norm,metrics.k6/lg_60_mean,tags.mlflow.parentRunId,tags.mlflow.user,tags.mlflow.runName,tags.mlflow.source.type,tags.experiment,tags.mlflow.source.name,tags.mlflow.source.git.commit
0,cafe708190a044eb8ccd5dd5802e011b,1,FINISHED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-10 13:52:53.051000+00:00,2026-02-10 16:44:34.329000+00:00,0.033828,0.864461,0.076642,6.831656,...,,,,88b33e9d84894d87bbcb9422a4a7c54a,julian,seq_500,LOCAL,seq_500,sweep.py,381bdb53bd967e415a8a3948795a7a8604ac3d3b
1,b537cc6c946e44a08e37d0d564dc1514,1,FINISHED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-10 12:26:40.099000+00:00,2026-02-11 08:56:55.666000+00:00,0.05387,0.946434,0.190975,1.002849,...,,,,88b33e9d84894d87bbcb9422a4a7c54a,julian,seq_200,LOCAL,seq_200,sweep.py,381bdb53bd967e415a8a3948795a7a8604ac3d3b
2,5a3b2bea9972432d8244385c0428b349,1,RUNNING,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-10 11:27:46.065000+00:00,2026-02-11 07:01:00.309000+00:00,0.068336,0.667255,0.256517,0.021578,...,,,,88b33e9d84894d87bbcb9422a4a7c54a,julian,seq_100,LOCAL,seq_100,sweep.py,381bdb53bd967e415a8a3948795a7a8604ac3d3b
3,6e0d499a382747b9baf3e4f555a83468,1,FINISHED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-10 10:43:44.673000+00:00,2026-02-10 11:27:46.033000+00:00,0.084968,0.867514,0.263127,0.001286,...,,,,88b33e9d84894d87bbcb9422a4a7c54a,julian,seq_50,LOCAL,seq_50,sweep.py,381bdb53bd967e415a8a3948795a7a8604ac3d3b
4,25003dca68294720ad36ee8b035c6fb3,1,FINISHED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-10 10:12:38.533000+00:00,2026-02-10 10:43:44.640000+00:00,0.050564,0.021501,0.015707,4e-06,...,,,,88b33e9d84894d87bbcb9422a4a7c54a,julian,seq_10,LOCAL,seq_10,sweep.py,381bdb53bd967e415a8a3948795a7a8604ac3d3b


In [5]:
RUN_ID = "5a3b2bea9972432d8244385c0428b349"  # Set to specific run ID or None for latest
# RUN_ID = None

if RUN_ID is None:
    if runs_df.empty:
        raise ValueError(f"No runs found in experiment {EXPERIMENT_NAME!r}")
    # runs_df is ordered by start_time DESC, so row 0 is the latest run.
    RUN_ID = runs_df.iloc[0]["run_id"]
    print(f"Using latest run: {RUN_ID}")

# Without this, re-running the cell fails with "Run ... is already active".
if mlflow.active_run() is not None:
    mlflow.end_run()

run = client.get_run(RUN_ID)
# Reopening an existing run sets its status to RUNNING. Remember the original
# status so it can be restored with mlflow.end_run(status=RUN_ORIGINAL_STATUS)
# once the analysis cells are done logging to this run.
RUN_ORIGINAL_STATUS = run.info.status
mlflow.start_run(run_id=RUN_ID)

<ActiveRun: >

# Generate analysis plots

In [6]:
import hydra
import torch
from omegaconf import OmegaConf

# Which model of the run to analyse: selects the models/model_k{k}_state_dict.pt
# artifact here and the "k{k}/" metric prefix used by the loss cells.
k = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the model architecture from the config.yaml artifact logged with the run.
config_path = client.download_artifacts(run_id=RUN_ID, path="config.yaml")
cfg = OmegaConf.load(config_path)
model = hydra.utils.instantiate(cfg.model).to(device)

state_dict_path = client.download_artifacts(
    run_id=RUN_ID, path=f"models/model_k{k}_state_dict.pt"
)
# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict needs, so loading an artifact cannot execute arbitrary pickled code.
state_dict = torch.load(state_dict_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 4002.20it/s] 
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 169.67it/s]


ActionableRGM()

In [None]:
from analysis import (
    get_ratemaps, quantitative_analysis, loss_plots,
    neuron_plotter_2d, frequency_plot, s_matrix_plot,
    create_loss_plots_from_mlflow
)

# Learned model parameters, detached and moved to host memory for plotting.
om = model.om.detach().cpu().numpy()
S = model.S.detach().cpu().numpy()

# Frequency plot
fig_freq = frequency_plot(om)
fig_freq.show()

# S matrix analysis
fig_S = s_matrix_plot(S)
fig_S.show()

# Ratemaps at resolution `res`, one set per entry of `widths`.
res = 70
widths = (1, 2, 4)
Vs = get_ratemaps(model, res, widths)
V_small, V_medium, V_large = Vs  # also checks that exactly three sets came back

# Grid scores histogram
fig_score, scores = quantitative_analysis(Vs, widths, res)
fig_score.show()

# Neuron plots: one figure per ratemap set, paired with that set's "*_60" scores.
fig_neurons = {}
for size, V in zip(("sm", "md", "lg"), Vs):
    fig_neurons[size] = neuron_plotter_2d(V, res, scores[f"{size}_60"])
    fig_neurons[size].show()
fig_neurons_sm, fig_neurons_md, fig_neurons_lg = (
    fig_neurons["sm"], fig_neurons["md"], fig_neurons["lg"]
)

In [7]:
import numpy as np
from mlflow.tracking import MlflowClient

from analysis import loss_plots

# New client bound to whatever tracking URI is currently configured.
client = MlflowClient()

# Loss plots from MLflow metrics
def fetch_metric(metric_key: str, run_id: "str | None" = None, mlflow_client=None):
    """Return the logged history of one MLflow metric as a step-ordered array.

    Args:
        metric_key: Metric name, e.g. ``f"k{k}/train_loss"``.
        run_id: Run to read from; defaults to the notebook-level ``RUN_ID``.
        mlflow_client: Client used for the lookup; defaults to the notebook-level
            ``client``.

    Returns:
        A 1-D ``np.ndarray`` of metric values ordered by step, or ``None`` when the
        metric has no logged history for the run.
    """
    if run_id is None:
        run_id = RUN_ID
    if mlflow_client is None:
        mlflow_client = client
    history = mlflow_client.get_metric_history(run_id, metric_key)
    if not history:
        return None
    # Break step ties by logging time so values logged more than once for the same
    # step (e.g. after a run was resumed) come out in a deterministic order.
    ordered = sorted(history, key=lambda m: (m.step, m.timestamp))
    return np.array([m.value for m in ordered])

# Training curves for model k; metrics that were never logged are left out.
train_losses = {
    name: values
    for name, values in (
        ("loss", fetch_metric(f"k{k}/train_loss")),
        ("separation", fetch_metric(f"k{k}/separation")),
        ("positivity", fetch_metric(f"k{k}/positivity_geco")),
        ("norm", fetch_metric(f"k{k}/norm_geco")),
    )
    if values is not None
}

val_losses = {
    name: values
    for name, values in (("loss", fetch_metric(f"k{k}/val_loss")),)
    if values is not None
}

# Multiplier traces; each is None when it was not logged.
lambda_pos = fetch_metric(f"k{k}/lambda_pos")
lambda_norm = fetch_metric(f"k{k}/lambda_norm")

# Only plot when at least one training curve exists; an empty validation dict is
# passed on as None.
if train_losses:
    fig_loss = loss_plots(train_losses, val_losses or None, lambda_pos, lambda_norm)
    fig_loss.show()

In [None]:
from analysis import create_loss_plots_from_mlflow

# Use the notebook's model index instead of a hard-coded 0, so changing `k`
# in the model-loading cell keeps every plot pointed at the same model.
create_loss_plots_from_mlflow(k)

In [7]:
from analysis import generate_2d_plots

# Bound to its own name: `res` is the ratemap resolution used by the plotting
# cell and must not be overwritten by this result.
plots_2d = generate_2d_plots(model, k)


invalid value encountered in power


invalid value encountered in power

Scoring Small Ratemaps: 100%|██████████| 65/65 [00:20<00:00,  3.12it/s]
Scoring Medium Ratemaps: 100%|██████████| 65/65 [00:20<00:00,  3.11it/s]
Scoring Large Ratemaps: 100%|██████████| 65/65 [00:20<00:00,  3.11it/s]

invalid value encountered in power


invalid value encountered in power



Found 7 Modules (Score: 0.768)


# Sweep score distributions

In [1]:
from analysis import sweep_score_distributions_mlflow

# Parent run of the sequence-length sweep; its child runs carry this ID in the
# mlflow.parentRunId tag.
PARENT_RUN_ID = "88b33e9d84894d87bbcb9422a4a7c54a"
TRACKING_URI = "sqlite:///mlruns.db"

print(f"Generating sweep score distributions for parent run: {PARENT_RUN_ID}")
figures = sweep_score_distributions_mlflow(
    parent_run_id=PARENT_RUN_ID,
    x_param="data.seq_len",
    log_x=True,  # the child runs (seq_10 ... seq_500) span orders of magnitude
    tracking_uri=TRACKING_URI,
)
for fig in figures.values():
    fig.show()

  from .autonotebook import tqdm as notebook_tqdm
2026/02/11 10:00:59 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.schemas
2026/02/11 10:00:59 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.tables
2026/02/11 10:00:59 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.types
2026/02/11 10:00:59 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.constraints
2026/02/11 10:00:59 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.defaults
2026/02/11 10:00:59 INFO alembic.runtime.plugins: setup plugin alembic.autogenerate.comments


Generating sweep score distributions for parent run: 88b33e9d84894d87bbcb9422a4a7c54a


2026/02/11 10:01:00 INFO alembic.runtime.migration: Context impl SQLiteImpl.
2026/02/11 10:01:00 INFO alembic.runtime.migration: Will assume non-transactional DDL.
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 2680.07it/s]
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 4245.25it/s] 
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 2642.91it/s]
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 2898.62it/s]
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 3542.49it/s]
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 2788.77it/s]
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 3086.32it/s]


Sweep score distribution plots logged to MLflow for run 88b33e9d84894d87bbcb9422a4a7c54a
