In [None]:
import os
from pathlib import Path 
import wandb
# Log in to Weights & Biases; later cells query the benchmark runs through wandb.Api().
wandb.login()

In [None]:
# Run from the benchmark repository root: most cells below refer to train.py, configs/ and runs/
# by relative path. Set XRAYTO3D_BENCHMARK_ROOT to point at a different checkout location.
REPO_ROOT = Path(os.environ.get('XRAYTO3D_BENCHMARK_ROOT', '/mnt/SSD0/mahesh-home/xrayto3D-benchmark/'))
os.chdir(REPO_ROOT)
print(os.getcwd())

In [None]:
def filter_wandb_run(anatomy,project_name='msrepo/2d-3d-benchmark',tags=('model-compare',)):
    """Return the W&B runs of `project_name` matching `tags` whose config ANATOMY equals `anatomy`.

    Args:
        anatomy: value compared against each run's ``config['ANATOMY']``.
        project_name: W&B project path passed to ``wandb.Api().runs``.
        tags: iterable of tag names, sent to the API as a ``{'$in': [...]}`` tag filter.

    Returns:
        list of matching runs, in the order the API yields them. Runs whose config has no
        ANATOMY entry are skipped rather than raising KeyError.
    """
    api = wandb.Api()
    # Immutable default avoids the shared-mutable-default pitfall; the filter itself gets a fresh list.
    runs = api.runs(project_name, filters={'tags': {'$in': list(tags)}})
    return [run for run in runs if run.config.get('ANATOMY') == anatomy]

In [None]:
def get_latest_checkpoint(path, pattern='epoch=*.ckpt'):
    """Return (as str) the most recently written checkpoint file in directory `path`.

    Args:
        path: directory to search (non-recursive).
        pattern: glob pattern for checkpoint file names. The default matches ``epoch=...ckpt``
            files only, so a file named ``last.ckpt`` is not considered.

    Raises:
        FileNotFoundError: if no file in `path` matches `pattern`.
    """
    checkpoints = list(Path(path).glob(pattern))
    if not checkpoints:
        raise FileNotFoundError(f'no checkpoint matching {pattern!r} in {path}')
    # Rank by modification time. st_ctime on Linux is the inode-change time (bumped by chmod/rename,
    # and reset when files are copied with preserved timestamps), so it can misorder checkpoints.
    # stat() follows symlinks, so a linked checkpoint is ranked by its target's time.
    latest_checkpoint_path = max(checkpoints, key=lambda ckpt: ckpt.stat().st_mtime)
    return str(latest_checkpoint_path)

In [None]:
def get_experiment_name_from_model_name(model_name):
    """Map a model name to the value passed to train.py's --experiment_name.

    Raises KeyError (carrying `model_name`) for a model that is not listed below.
    """
    models_per_experiment = {
        'ParallelHeadsExperiment': (
            'OneDConcat',
            'OneDConcatModel',
            'TwoDPermuteConcat',
            'TwoDPermuteConcatModel',
            'MultiScale2DPermuteConcat',
        ),
        'VolumeAsInputExperiment': ('UNet', 'AttentionUnet'),
        'TLPredictorExperiment': ('TLPredictor',),
    }
    # Invert the grouping so each model name points at its experiment.
    experiment_of = {
        name: experiment
        for experiment, names in models_per_experiment.items()
        for name in names
    }
    return experiment_of[model_name]

In [None]:
def get_run_from_model_name(model_name,wandb_runs):
    """Return the first run in `wandb_runs` whose config MODEL_NAME equals `model_name`.

    Raises ValueError when no run matches.
    """
    not_found = object()
    match = next(
        (candidate for candidate in wandb_runs if candidate.config['MODEL_NAME'] == model_name),
        not_found,
    )
    if match is not_found:
        raise ValueError(f'{model_name} not found')
    return match


In [None]:
# Femur DRR splits (train+val / test), relative to the repository root.
train_csv_path, test_csv_path = (
    'configs/paths/femur/30k/TotalSegmentor-femur-left-DRR-30k_train+val.csv',
    'configs/paths/femur/30k/TotalSegmentor-femur-left-DRR-30k_test.csv',
)
# Device index, input size and resolution forwarded to train.py by the evaluation loop.
gpu, femur_img_size, femur_resolution = 0, 128, 1.0
# W&B runs selected by filter_wandb_run's default tags whose ANATOMY is 'femur'.
femur_runs = filter_wandb_run(anatomy='femur')
for run in femur_runs:
    print(run.id, run.config['MODEL_NAME'])

In [None]:
import shlex
import subprocess

from IPython.display import clear_output

# Re-evaluate every femur run: load its newest checkpoint, run train.py with --evaluate and
# write predictions/metrics to runs/2d-3d-benchmark/<run id>/evaluation.
accelerator = 'gpu'
batch_size = 2
# TLPredictor additionally loads the autoencoder checkpoint from run d2pnyx7v.
tl_autoencoder_checkpoint = 'runs/2d-3d-benchmark/d2pnyx7v/checkpoints/last.ckpt'
failed_runs = []
for run in femur_runs:
    model_name = run.config['MODEL_NAME']
    input_type = get_experiment_name_from_model_name(model_name)
    run_id = str(run.id)
    try:
        checkpoint_path = get_latest_checkpoint(f'runs/2d-3d-benchmark/{run_id}/checkpoints/')
    except (FileNotFoundError, ValueError) as err:
        # No matching checkpoint for this run: skip it, but report it in the summary below.
        failed_runs.append((run_id, model_name, f'no checkpoint: {err}'))
        continue
    output_dir = f'runs/2d-3d-benchmark/{run_id}/evaluation'
    # Argument list (no shell): interpolated paths are passed verbatim, never re-parsed by a shell.
    command = [
        'python', 'train.py', train_csv_path, test_csv_path,
        '--gpu', str(gpu),
        '--tags', 'model-compare',
        '--size', str(femur_img_size),
        '--batch_size', str(batch_size),
        '--accelerator', accelerator,
        '--res', str(femur_resolution),
        '--precision', '16',
        '--model_name', model_name,
        '--experiment_name', input_type,
        '--epochs', '-1',
        '--anatomy', 'femur',
        '--loss', 'DiceLoss',
        '--lr', '0.002',
        '--steps', '3000',
        '--evaluate',
        '--save_predictions',
        '--checkpoint_path', checkpoint_path,
        '--output_dir', output_dir,
    ]
    if model_name == 'TLPredictor':
        command += ['--load_autoencoder_from', tl_autoencoder_checkpoint]
    print(' '.join(shlex.quote(arg) for arg in command), '\n')
    completed = subprocess.run(command)
    if completed.returncode != 0:
        failed_runs.append((run_id, model_name, f'exit code {completed.returncode}'))
# Replace the per-run command echoes with a summary; failures used to be silently ignored.
clear_output(wait=True)
if failed_runs:
    print('Evaluation failed for:')
    for failed_id, failed_model, reason in failed_runs:
        print(f'  {failed_id} ({failed_model}): {reason}')
else:
    print(f'Evaluated {len(femur_runs)} run(s) successfully.')

In [None]:
# One-off re-evaluation of run odvwh69h (MultiScale2DPermuteConcat) with an explicitly pinned checkpoint;
# the arguments mirror the command built by the evaluation loop above.
# NOTE(review): duplicates that loop with hard-coded values — consider removing once no longer needed.
!python train.py  configs/paths/femur/30k/TotalSegmentor-femur-left-DRR-30k_train+val.csv configs/paths/femur/30k/TotalSegmentor-femur-left-DRR-30k_test.csv --gpu 0 --tags   model-compare --size 128 --batch_size 2 --accelerator gpu --res 1.0 --precision 16 --model_name MultiScale2DPermuteConcat --experiment_name ParallelHeadsExperiment --epochs -1 --anatomy femur --loss DiceLoss  --lr 0.002 --steps 3000 --evaluate --save_predictions --checkpoint_path runs/2d-3d-benchmark/odvwh69h/checkpoints/epoch=40-step=4000-val_loss=0.07-val_acc=0.93.ckpt --output_dir runs/2d-3d-benchmark/odvwh69h/evaluation 

In [77]:
import pandas as pd

# Build LaTeX table rows of test-set mean metrics per model. Each row starts with ' & ' (first column
# left empty) and ends with a real LaTeX row break '\\' plus a newline.
latex_table_row_template = " & {model_name} & {model_size} & {DSC:.2f}  & {HD95:.2f} & {ASD:.2f}  & {NSD:.2f} \\\\\n"
MODEL_NAMES = ['AttentionUnet','UNet','MultiScale2DPermuteConcat','TwoDPermuteConcatModel','OneDConcatModel','TLPredictor']
model_sizes = {'AttentionUnet':'1.5M','UNet':'1.2M','MultiScale2DPermuteConcat':'3.5M','TwoDPermuteConcatModel':'1.2M','OneDConcatModel':'40.6M','TLPredictor':'6.6M'}
latex_rows = []
for model in MODEL_NAMES:
    run = get_run_from_model_name(model,femur_runs)
    print(run.id,run.config['MODEL_NAME'])
    # Relative to the repository root, which the notebook chdir's into near the top.
    eval_log_csv_path = f'runs/2d-3d-benchmark/{run.id}/evaluation/metric-log.csv'
    df = pd.read_csv(eval_log_csv_path)
    # Average once per model instead of once per metric.
    mean_metrics = df.mean(numeric_only=True)
    latex_rows.append(latex_table_row_template.format(
        model_name=run.config['MODEL_NAME'],
        model_size=model_sizes[model],
        DSC=mean_metrics['DSC'],
        HD95=mean_metrics['HD95'],
        ASD=mean_metrics['ASD'],
        NSD=mean_metrics['NSD'],
    ))
latex_table = ''.join(latex_rows)
# print (not the bare repr) so the output can be pasted into LaTeX without doubled backslashes.
print(latex_table)

ceio7qj7 AttentionUnet
yp62xxp0 UNet
odvwh69h MultiScale2DPermuteConcat
axwfvhyg TwoDPermuteConcatModel
3x3ap0vm OneDConcatModel
iy85yu4l TLPredictor


' & AttentionUnet & 1.5M & 0.94  & 3.56 & 0.97  & 0.78\\ & UNet & 1.2M & 0.93  & 3.49 & 1.01  & 0.74\\ & MultiScale2DPermuteConcat & 3.5M & 0.93  & 3.35 & 1.06  & 0.77\\ & TwoDPermuteConcatModel & 1.2M & 0.94  & 3.01 & 0.88  & 0.78\\ & OneDConcatModel & 40.6M & 0.90  & 4.44 & 1.41  & 0.61\\ & TLPredictor & 6.6M & 0.88  & 5.03 & 1.95  & 0.52\\'