In [14]:
import os
from pathlib import Path 
import wandb
wandb.login()

True

In [None]:
os.chdir('..')

In [4]:
def filter_wandb_run(anatomy,project_name='msrepo/2d-3d-benchmark',tags=['model-compare']):
    api = wandb.Api()
    runs = api.runs(project_name,filters={
        'tags':{'$in':tags}
    })
    filtered_runs = [ run for run in runs if run.config['ANATOMY'] == anatomy]
    return filtered_runs

In [20]:
def get_latest_checkpoint(path):
    checkpoints = list(Path(path).glob('epoch=*.ckpt'))
    latest_checkpoint_path = max(checkpoints,key=lambda x: x.lstat().st_ctime)
    return str(latest_checkpoint_path)

In [22]:
def get_experiment_name_from_model_name(model_name):
    expt_dict = {'OneDConcat':'ParallelHeadsExperiment','MultiScale2DPermuteConcat':'ParallelHeadsExperiment','TwoDPermuteConcat':'ParallelHeadsExperiment','AttentionUnet':'VolumeAsInputExperiment',
    'UNet':'VolumeAsInputExperiment'}
    return expt_dict[model_name]

In [27]:
def get_run_from_model_name(model_name,wandb_runs):
    for run in wandb_runs:
        if run.config['MODEL_NAME'] == model_name:
            return run
    raise ValueError(f'{model_name} not found')


In [None]:
train_csv_path = 'configs/paths/totalsegmentator_ribs/TotalSegmentor-ribs-DRR-full_train+val.csv'
test_csv_path = 'configs/paths/totalsegmentator_ribs/TotalSegmentor-ribs-DRR-full_test.csv'
gpu = 0
rib_img_size = 128
rib_resolution = 2.5
rib_runs = filter_wandb_run(anatomy='ribs')
for run in rib_runs:
    model_name = run.config['MODEL_NAME']
    input_type = get_experiment_name_from_model_name(model_name)
    run_id = str(run.id)
    checkpoint_path = get_latest_checkpoint(f'runs/2d-3d-benchmark/{run_id}/checkpoints/') 
    output_dir = f'runs/2d-3d-benchmark/{run_id}/evaluation'
    command = f'python train.py  {train_csv_path} {test_csv_path} --gpu {gpu} --tags   model-compare --size {rib_img_size} --batch_size 8 --accelerator gpu --res {rib_resolution} --precision 16 --model_name {model_name} --experiment_name {input_type} --epochs -1 --anatomy ribs --loss DiceLoss  --lr 0.002 --steps 3000 --evaluate --save_predictions --checkpoint_path {checkpoint_path} --output_dir {output_dir}'
    os.system(command)
clear_output(wait=True)

In [28]:
import pandas as pd
latex_table = ""
latex_table_row_template = " & {model_name} &  & {DSC:.2f}  & {HD95:.2f} & {ASD:.2f}  & {NSD:.2f}\\"
MODEL_NAMES = ['AttentionUnet','UNet','MultiScale2DPermuteConcat','TwoDPermuteConcat','OneDConcat']
for model in MODEL_NAMES:
    run = get_run_from_model_name(model,rib_runs)
    print(run.id,run.config['MODEL_NAME'])
    eval_log_csv_path = f'/mnt/SSD0/mahesh-home/xrayto3D-benchmark/runs/2d-3d-benchmark/{run.id}/evaluation/metric-log.csv'
    df = pd.read_csv(eval_log_csv_path)
    latex_table += latex_table_row_template.format(
        model_name=run.config['MODEL_NAME'],
        DSC = df.mean(numeric_only=True).DSC,
        HD95 = df.mean(numeric_only=True).HD95,
        ASD = df.mean(numeric_only=True).ASD,
        NSD = df.mean(numeric_only=True).NSD)
latex_table

31nmkymw AttentionUnet
r83x5x6c UNet
iaqut884 MultiScale2DPermuteConcat
7lmns8ui TwoDPermuteConcat
4pmso51m OneDConcat


' & AttentionUnet &  & 0.52  & 4.48 & 1.23  & 0.52\\ & UNet &  & 0.46  & 4.76 & 1.35  & 0.47\\ & MultiScale2DPermuteConcat &  & 0.40  & 9.52 & 1.47  & 0.40\\ & TwoDPermuteConcat &  & 0.49  & 5.85 & 0.94  & 0.47\\ & OneDConcat &  & 0.28  & 16.88 & 2.62  & 0.28\\'