Set path to where `mlruns` directory is located (usually, the `CardiacCOMA` repository)

In [1]:
CARDIAC_COMA_REPO = "/home/rodrigo/CISTIB/repos/CardiacCOMA/"

In [59]:
import mlflow
import os, sys

import torch
import torch.nn.functional as F

import os; os.chdir(CARDIAC_COMA_REPO)
from config.load_config import load_yaml_config, to_dict

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import Image
from mlflow.tracking import MlflowClient

import pickle as pkl
import pytorch_lightning as pl

from argparse import Namespace
import matplotlib.pyplot as plt

#import surgeon_pytorch
#from surgeon_pytorch import Inspect, get_layers

import numpy as np
from IPython import embed
sys.path.insert(0, '..')

import model.Model3D
from utils.helpers import get_coma_args, get_lightning_module, get_datamodule
from copy import deepcopy
from pprint import pprint

In [3]:
TRACKING_URI = f"file://{CARDIAC_COMA_REPO}/mlruns"
mlflow.set_tracking_uri(TRACKING_URI)

# Select MLflow experiment

In [4]:
options = [exp.name for exp in mlflow.list_experiments()]

experiment_w = widgets.Select(
    options=options,
    value=options[1]
)
display(experiment_w)

Select(index=1, options=('Cardiac - ED', 'Synthetic data', 'Default'), value='Synthetic data')

Retrieve run data from MLflow

In [51]:
# runs_list = mlflow.search_runs(experiment_ids=[exp_id], output_format="list")
exp_id = mlflow.get_experiment_by_name(experiment_w.value).experiment_id

runs_df = mlflow.search_runs(experiment_ids=[exp_id],)

# Keep only the runs that ended successfully
runs_df = runs_df[runs_df.status == "FINISHED"].reset_index(drop=True)

# Use experiment ID and run ID as indices
runs_df = runs_df.set_index(["experiment_id", "run_id"])
runs_df.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,status,artifact_uri,start_time,end_time,metrics.loss,metrics.test_loss,metrics.recon_loss,metrics.training_loss,metrics.test_recon_loss,metrics.val_loss,...,params.monitor,params.min_delta,params.convolution_type,params.epochs,tags.mlflow.user,tags.Mode,tags.mlflow.source.git.commit,tags.mlflow.source.type,tags.mlflow.source.name,tags.mlflow.log-model.history
experiment_id,run_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
2,29d040d969394cd3b3c09a16790e3bdd,FINISHED,file:///home/rodrigo/CISTIB/repos/CardiacCOMA/...,2022-06-30 21:00:43.301000+00:00,2022-06-30 23:31:22.355000+00:00,497.427856,497.427856,497.427856,497.434357,497.427856,497.427856,...,val_loss,-0.0,Chebyshev,1000,scrb,testing,913e90992eb2e6f1a69e164eb7240d87fa228bc6,LOCAL,main.py,
2,5a49839a2b474e548d1643bb98741d88,FINISHED,file:///home/rodrigo/CISTIB/repos/CardiacCOMA/...,2022-06-30 18:22:47.917000+00:00,2022-06-30 20:59:24.088000+00:00,497.988251,497.988251,497.988251,498.003265,497.988251,497.98822,...,val_loss,-0.0,Chebyshev,1000,scrb,testing,913e90992eb2e6f1a69e164eb7240d87fa228bc6,LOCAL,main.py,
2,93aee9ef9a95401c875da1772154866b,FINISHED,file:///home/rodrigo/CISTIB/repos/CardiacCOMA/...,2022-06-30 15:35:33.312000+00:00,2022-06-30 18:21:25.677000+00:00,490.910797,490.910797,490.910797,490.916626,490.910797,490.910797,...,val_loss,-0.0,Chebyshev,1000,scrb,testing,913e90992eb2e6f1a69e164eb7240d87fa228bc6,LOCAL,main.py,"[{""run_id"": ""93aee9ef9a95401c875da1772154866b""..."
2,cb2ecda96b034689aa1bb7165164b22a,FINISHED,file:///home/rodrigo/CISTIB/repos/CardiacCOMA/...,2022-06-30 15:00:21.300000+00:00,2022-06-30 15:34:09.070000+00:00,478.635101,478.635101,478.635101,478.639282,478.635101,478.635071,...,val_loss,-0.0,Chebyshev,1000,scrb,testing,913e90992eb2e6f1a69e164eb7240d87fa228bc6,LOCAL,main.py,"[{""run_id"": ""cb2ecda96b034689aa1bb7165164b22a""..."
2,4fb99eff61af4a7da7d712271a6d4903,FINISHED,file:///home/rodrigo/CISTIB/repos/CardiacCOMA/...,2022-06-30 14:29:21.297000+00:00,2022-06-30 14:57:09.213000+00:00,494.52652,,494.52652,494.503204,,494.52652,...,val_loss,-0.0,Chebyshev,1000,scrb,training,913e90992eb2e6f1a69e164eb7240d87fa228bc6,LOCAL,main.py,


Choose the run with the minimum `val_recon_loss` and load pretained weights.

In [58]:
experiment_id, run_id = runs_df.index[runs_df["metrics.val_recon_loss"].argmin()]

# print(runs_df.loc[experiment_id, run_id].to_dict())

run_info = runs_df.loc[experiment_id, run_id].to_dict()
artifact_uri = run_info["artifact_uri"].replace("file://", "")
chkpt_dir = os.path.join(artifact_uri, "restored_model_checkpoint", os.listdir(chkpt_dir)[0])

model_pretrained_weights = torch.load(chkpt_file, map_location=torch.device('cpu'))["state_dict"]

# Remove "model." prefix from state_dict's keys.
_model_pretrained_weights = {k.replace("model.", ""): v for k, v in model_pretrained_weights.items()}
# print(_model_pretrained_weights)

NotADirectoryError: [Errno 20] Not a directory: '/home/rodrigo/CISTIB/repos/CardiacCOMA/mlruns/2/958c33a02570470aaf7e085e8d3c0cad/artifacts/restored_model_checkpoint/epoch=999-step=19999.ckpt'

In [61]:
def overwrite_ref_config(ref_config, run_info):
    
    '''
    This is a workaround for adjusting the configuration of those runs that didn't have a YAML configuration file logged as an artifact.
    '''
    
    config = deepcopy(ref_config)
    config.network_architecture.latent_dim = int(run_info["params.latent_dim"])
    config.loss.regularization.weight = float(run_info["params.w_kl"])
    config.optimizer.parameters.lr = float(run_info["params.lr"])
    config.sample_sizes = [100, 100, 100]
    
    return config


ref_config = load_yaml_config("config_files/config.yaml")
config = overwrite_ref_config(ref_config, run_info)
pprint(to_dict(config))

{'batch_size': 32,
 'dataset': {'data_type': 'cardiac',
             'filename': 'data/cardio/LV_meshes_at_ED_35k.pkl',
             'preprocessing': {'procrustes': 'data/cardio/procrustes_transforms_35k.pkl'},
             'template': 'data/cardio/faces.pkl'},
 'loss': {'reconstruction': {'type': 'MSE', 'weight': 1},
          'regularization': {'type': 'KL', 'weight': 0.0}},
 'mlflow': {'artifact_location': None,
            'experiment_name': 'Synthetic data',
            'run_name': None,
            'tracking_uri': None},
 'network_architecture': {'activation_function': ['ReLU',
                                                  'ReLU',
                                                  'ReLU',
                                                  'ReLU'],
                          'convolution': {'channels_dec': [16, 32, 64, 128],
                                          'channels_enc': [16, 32, 64, 128],
                                          'parameters': {'polynomial_degree': [6

In [8]:
client = MlflowClient()
local_dir = "/tmp/artifact_downloads"

if not os.path.exists(local_dir):
    os.mkdir(local_dir)

In [9]:
# client._tracking_client.list_artifacts(
#     runs_df.run_id[2], 'restored_model_checkp# oint'
# )

In [20]:
dm = get_datamodule(config, perform_setup=True)

In [11]:
model = get_lightning_module(config, dm)

0.07521533966064453: 5200
0.39846229553222656: 5100
0.40227174758911133: 5000
0.3949596881866455: 4900
0.4393291473388672: 4800
0.4231417179107666: 4700
0.39465951919555664: 4600
0.3945751190185547: 4500
0.6326649188995361: 4400
0.3608686923980713: 4300
0.34975743293762207: 4200
0.35294628143310547: 4100
0.35439634323120117: 4000
0.3509514331817627: 3900
0.35549402236938477: 3800
0.34957265853881836: 3700
0.35103297233581543: 3600
0.3348579406738281: 3500
0.33591556549072266: 3400
0.3592402935028076: 3300
0.33854174613952637: 3200
0.3330986499786377: 3100
0.4913671016693115: 3000
0.34344053268432617: 2900
0.3398416042327881: 2800
0.31565165519714355: 2700
0.019639968872070312: 2600
0.19584870338439941: 2500
0.19938421249389648: 2400
0.19098162651062012: 2300
0.19422602653503418: 2200
0.183305025100708: 2100
0.1861741542816162: 2000
0.17617464065551758: 1900
0.17900395393371582: 1800
0.1668233871459961: 1700
0.17716646194458008: 1600
0.1634671688079834: 1500
0.17041325569152832: 1400
0.

In [19]:
model.model.load_state_dict(_model_pretrained_weights)

encoder.layers.layer_0.graph_conv.lins.0.
encoder.layers.layer_0.graph_conv.lins.1.
encoder.layers.layer_0.graph_conv.lins.2.
encoder.layers.layer_0.graph_conv.lins.3.
encoder.layers.layer_0.graph_conv.lins.4.
encoder.layers.layer_0.graph_conv.lins.5.
encoder.layers.layer_1.graph_conv.lins.0.
encoder.layers.layer_1.graph_conv.lins.1.
encoder.layers.layer_1.graph_conv.lins.2.
encoder.layers.layer_1.graph_conv.lins.3.
encoder.layers.layer_1.graph_conv.lins.4.
encoder.layers.layer_1.graph_conv.lins.5.
encoder.layers.layer_2.graph_conv.lins.0.
encoder.layers.layer_2.graph_conv.lins.1.
encoder.layers.layer_2.graph_conv.lins.2.
encoder.layers.layer_2.graph_conv.lins.3.
encoder.layers.layer_2.graph_conv.lins.4.
encoder.layers.layer_2.graph_conv.lins.5.
encoder.layers.layer_3.graph_conv.lins.0.
encoder.layers.layer_3.graph_conv.lins.1.
encoder.layers.layer_3.graph_conv.lins.2.
encoder.layers.layer_3.graph_conv.lins.3.
encoder.layers.layer_3.graph_conv.lins.4.
encoder.layers.layer_3.graph_conv.

<All keys matched successfully>

{'s': tensor([[ 25.4802,   3.4458,  27.2371],
         [ 11.6291,   2.7515, -53.7377],
         [ 30.5974, -12.0060,  21.6333],
         ...,
         [-25.9343,  -0.9732,   6.9544],
         [ 31.9850,   7.3206,  -1.3706],
         [  9.9849,  -1.8175, -53.0963]])}

In [48]:
dm.dataset[1]['s']

tensor([[ 29.6958,   4.6812,  32.1512],
        [ 10.8751,   1.3605, -57.2060],
        [ 35.7236, -13.2848,  24.5868],
        ...,
        [-30.6115,  -0.9908,   6.7875],
        [ 38.8894,   8.9288,  -1.2236],
        [  9.0511,  -4.0243, -57.1594]])

In [34]:
s_hat = model(s)[0][0]

In [37]:
def mse(s1, s2=None):
    if s2 is None:
        s2 = torch.zeros_like(s1)
    return ((s1-s2)**2).sum(-1).mean(-1)

In [38]:
mse(s, s_hat)

tensor(83.7731, grad_fn=<MeanBackward1>)