In [8]:
import mlflow
from mlflow.tracking import MlflowClient
import pandas as pd
from pprint import pprint

In [9]:
# Select the local SQLite tracking store, then create a client
# (no explicit URI is passed to the client itself).
mlflow.set_tracking_uri("sqlite:///mlruns.db")
client = MlflowClient()

# Experiments returned by the tracking store.
experiments = client.search_experiments()

# One row per experiment, keeping only the attributes shown in the overview.
exp_df = pd.DataFrame(
    [
        {
            field: getattr(exp, field)
            for field in (
                "experiment_id",
                "name",
                "lifecycle_stage",
                "artifact_location",
            )
        }
        for exp in experiments
    ]
)

# Display the overview alphabetically by experiment name.
exp_df.sort_values(by="name")



Unnamed: 0,experiment_id,name,lifecycle_stage,artifact_location
1,0,Default,active,/home/julian/Projects/MasterThesis/Pilot Decod...
0,1,grid-representations,active,/home/julian/Projects/MasterThesis/Pilot Decod...


In [10]:
# Name of the MLflow experiment whose runs are analysed in this notebook.
EXPERIMENT_NAME = "grid-representations"

# get_experiment_by_name returns None (it does not raise) when no experiment
# has this name, so fail here with a readable message instead of an
# AttributeError on the `.experiment_id` access below.
experiment = client.get_experiment_by_name(EXPERIMENT_NAME)
if experiment is None:
    raise ValueError(
        f"No MLflow experiment named {EXPERIMENT_NAME!r} found in the tracking store"
    )
experiment_id = experiment.experiment_id

# Show the experiment metadata as the cell output.
experiment


<Experiment: artifact_location='/home/julian/Projects/MasterThesis/Pilot Decoder/mlruns/1', creation_time=1770104346161, experiment_id='1', last_update_time=1770104346161, lifecycle_stage='active', name='grid-representations', tags={'mlflow.experimentKind': 'custom_model_development'}>

In [11]:
# Query the runs of the selected experiment as a DataFrame, newest first
# (sorted on the run's start_time attribute, descending).
runs_df = mlflow.search_runs(
    order_by=["attributes.start_time DESC"],
    experiment_ids=[experiment_id],
)

# Preview the five most recent runs.
runs_df.head(n=5)

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.k0/train_loss,metrics.k0/sm_60_mean,metrics.k0/sm_90_mean,metrics.k0/lg_60_mean,...,metrics.k3/lg_60_max,metrics.k1/norm,tags.mlflow.runName,tags.mlflow.source.name,tags.mlflow.source.type,tags.mlflow.source.git.commit,tags.mlflow.user,tags.mlflow.parentRunId,tags.experiment,tags.param.regularization.norm.k
0,8f6092c9e75d4edfbe7b9df98fb9e5ef,1,FINISHED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-03 14:13:41.015000+00:00,2026-02-03 14:18:14.826000+00:00,0.804669,0.158165,0.179349,-0.093568,...,,,auspicious-bear-616,train.py,LOCAL,75a524b88bdf1cf32659315ed4b7d156577cdc07,julian,,,
1,5d06566c2ff24ae79dd4c0f75d298475,1,FAILED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-03 14:07:03.766000+00:00,2026-02-03 14:09:24.265000+00:00,0.859195,,,,...,,,k=1,sweep.py,LOCAL,75a524b88bdf1cf32659315ed4b7d156577cdc07,julian,62239bd12490444b83df5bf926559560,default,1.0
2,7a7284161d9a4fe283eeee41aa6b5123,1,FINISHED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-03 14:02:33.879000+00:00,2026-02-03 14:07:03.738000+00:00,0.574523,0.114235,0.240922,0.142707,...,,,k=-1,sweep.py,LOCAL,75a524b88bdf1cf32659315ed4b7d156577cdc07,julian,62239bd12490444b83df5bf926559560,default,-1.0
3,f0b7cfb4f5c6453eb094ead840fd9b3b,1,FINISHED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-03 13:58:04.684000+00:00,2026-02-03 14:02:33.850000+00:00,0.473632,0.237947,0.206793,0.349499,...,,,k=-2,sweep.py,LOCAL,75a524b88bdf1cf32659315ed4b7d156577cdc07,julian,62239bd12490444b83df5bf926559560,default,-2.0
4,77b5c8836d064f158a6872e812789896,1,FINISHED,/home/julian/Projects/MasterThesis/Pilot Decod...,2026-02-03 13:53:31.995000+00:00,2026-02-03 13:58:04.653000+00:00,0.472837,0.104333,0.327268,0.097586,...,,,k=-4,sweep.py,LOCAL,75a524b88bdf1cf32659315ed4b7d156577cdc07,julian,62239bd12490444b83df5bf926559560,default,-4.0


In [12]:
RUN_ID = None  # Set to specific run ID or None for latest

if RUN_ID is None:
    # An empty run table would otherwise surface as a bare IndexError from .iloc[0].
    if runs_df.empty:
        raise ValueError("No runs found in runs_df; set RUN_ID explicitly")
    # runs_df is sorted by start_time descending, so row 0 is the most recent run.
    # NOTE: the most recent run is not necessarily FINISHED (FAILED runs are listed too).
    RUN_ID = runs_df.iloc[0]["run_id"]
    print(f"Using latest run: {RUN_ID}")

# Run metadata fetched through the client.
run = client.get_run(RUN_ID)

# mlflow.start_run raises if a run is already active, which made this cell fail
# when executed a second time. Only (re)attach when RUN_ID is not already the
# active run; a *different* active run still raises from start_run, as before.
active_run = mlflow.active_run()
if active_run is None or active_run.info.run_id != RUN_ID:
    active_run = mlflow.start_run(run_id=RUN_ID)

# Show the active run as the cell output.
active_run

Using latest run: 8f6092c9e75d4edfbe7b9df98fb9e5ef


<ActiveRun: >

In [None]:
import torch
from omegaconf import OmegaConf
import hydra

# Index used to pick the checkpoint file `models/model_k{k}_state_dict.pt`.
# NOTE(review): presumably matches the `k0`, `k1`, ... prefixes of the logged metrics — confirm.
k = 0

# Load config.yaml
config_path = client.download_artifacts(
    run_id=RUN_ID,
    path="config.yaml"  # or the directory
)

cfg = OmegaConf.load(config_path)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model from the `model` section of the run's config and move it to the device.
model = hydra.utils.instantiate(cfg.model).to(device)

# Fetch the saved weights for the selected k from the run's artifacts.
state_dict_path = client.download_artifacts(
    run_id=RUN_ID,
    path=f"models/model_k{k}_state_dict.pt"
)

# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict needs; it avoids executing arbitrary pickled code on torch
# versions where this is not yet the default.
state_dict = torch.load(state_dict_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Switch to evaluation mode; the returned module is displayed as the cell output.
model.eval()

Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 1407.96it/s]
Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00, 1907.37it/s]


ActionableRGM()

In [None]:
from analysis import (
    create_loss_plots_from_mlflow,
    generate_2d_plots,
    get_ratemaps,
    neuron_plotter_2d,
)

# Disabled alternative, kept for reference: compute ratemaps directly and plot the first one.
# maps = get_ratemaps(model, res = 70, widths=tuple([4]))
# neuron_plotter_2d(maps[0], 70)

# Run the project's 2D plot generation for the loaded model and the selected k.
res = generate_2d_plots(model, k)

Scoring Small Ratemaps: 100%|██████████| 65/65 [00:20<00:00,  3.11it/s]
Scoring Medium Ratemaps: 100%|██████████| 65/65 [00:21<00:00,  3.06it/s]
Scoring Large Ratemaps: 100%|██████████| 65/65 [00:21<00:00,  3.09it/s]
