In [None]:
import argparse, os, os.path as osp, json
import functools

import numpy as np
from PIL import Image
import torch

from data.aug import val_trfm
from data import ts_spinelspelvic, pengwin, ctpelvic1k
from util import color_seg, compact_image_grid, std_to_rgb
from modules import build_model
from evaluation import pred_volume
# TODO(review): `binarise_totalseg_label` and `TOTALSEG_CLS_SET` (used by the
# bin_fn cell) are not imported anywhere; add their import here.

In [None]:
# Presumably the GPU intended for inference in this notebook.
# NOTE(review): `device` is not referenced by any cell below. The model cells
# call `.cuda()`, which targets the current default CUDA device, not
# necessarily "cuda:1". Either pass it explicitly or call
# `torch.cuda.set_device(device)` before building models. TODO confirm intent.
device = "cuda:1"

In [None]:
# Partially applied label binarisers for the "bone" and "spine" class sets.
# NOTE(review): `binarise_totalseg_label` and `TOTALSEG_CLS_SET` are not
# imported anywhere in this notebook, so this cell raises NameError on a fresh
# kernel (Restart & Run All). TODO: import them from their defining module, or
# delete this cell. Neither `bin_fn_bone` nor `bin_fn_spine` is used below.
bin_fn_bone = functools.partial(binarise_totalseg_label, coi=TOTALSEG_CLS_SET["bone"])
bin_fn_spine = functools.partial(binarise_totalseg_label, coi=TOTALSEG_CLS_SET["spine"])

In [None]:
@torch.no_grad()
def draw_vol(model, loader, save_dir="paper_fig/quali/stage2"):
    """Predict a whole volume and save each slice along its third axis as a PNG.

    Args:
        model: segmentation network, passed unchanged to ``pred_volume``.
        loader: DataLoader over one volume, passed unchanged to ``pred_volume``.
        save_dir: output directory (created if missing). Slice ``i`` along the
            third (IS) axis is rendered with ``color_seg`` and written to
            ``<save_dir>/<i>.png``.
    """
    os.makedirs(save_dir, exist_ok=True)
    _, pred_vol = pred_volume(model, loader)  # axis order: (LR, AP, IS)
    n_slices = pred_vol.shape[2]
    # shape[2] is an int; iterate over its indices, not over the int itself.
    for i in range(n_slices):
        color_seg(pred_vol[:, :, i]).save(osp.join(save_dir, f"{i}.png"))
        print(i, end='\r')
    print()  # move past the carriage-return progress line

# totalseg-spineLSpelvic-small

In [None]:
# Which volume to render, where its data lives, and which training run to load.
# `~` is expanded here because plain file APIs (open, os.listdir, ...) do not
# expand it. expanduser is a no-op on an absolute path, so this stays correct
# even if VolumeDataset also expands it.
volume_id = "s1423"
data_root = osp.expanduser("~/sd10t/totalsegmentator")
log_path = "log/totalseg-spineLSpelvic-small"

In [None]:
# Reload the hyper-parameters saved in this run's log directory as a Namespace.
# JSON text is UTF-8 (RFC 8259). Pass the encoding explicitly so parsing does
# not depend on the platform's locale default.
with open(osp.join(log_path, "config.json"), "r", encoding="utf-8") as f:
    args = argparse.Namespace(**json.load(f))

In [None]:
# Validation-time transform (per the `val_trfm` name), built from the loaded
# config and passed to VolumeDataset below.
val_trans = val_trfm(args)

In [None]:
# Dataset for the selected volume, using the window / level / width settings
# from the training config. shuffle=False keeps items in order, presumably
# required for pred_volume to reassemble the volume.
ds = ts_spinelspelvic.VolumeDataset(
    volume_id,
    "full",
    val_trans,
    args.window,
    args.window_level,
    args.window_width,
    data_root,
)
loader = torch.utils.data.DataLoader(ds, shuffle=False, batch_size=32)

In [None]:
# Build the network, load the best-validation checkpoint of this run, and
# switch to inference mode.
model = build_model(args).cuda()
# map_location="cpu": deserialise onto the CPU so loading does not depend on
# which GPU the checkpoint was saved from. load_state_dict then copies the
# weights into the parameters that already live on the GPU.
model.load_state_dict(
    torch.load(osp.join(log_path, "2nd-re_stu/best_val.pth"), map_location="cpu")["model"]
)
model.eval();  # trailing ';' suppresses the (very long) module repr

In [None]:
# Render every predicted slice of this volume for the qualitative figure.
draw_vol(model, loader, save_dir=f"paper_fig/quali/totalseg-spineLSpelvic-small/stage2/{volume_id}")

# pengwin

In [None]:
# Which volume to render, where its data lives, and which training run to load.
# `~` is expanded here because plain file APIs (open, os.listdir, ...) do not
# expand it. expanduser is a no-op on an absolute path, so this stays correct
# even if VolumeDataset also expands it.
volume_id = "085"
data_root = osp.expanduser("~/sd10t/pengwin")
log_path = "log/pengwin"

In [None]:
# Reload the hyper-parameters saved in this run's log directory as a Namespace.
# JSON text is UTF-8 (RFC 8259). Pass the encoding explicitly so parsing does
# not depend on the platform's locale default.
with open(osp.join(log_path, "config.json"), "r", encoding="utf-8") as f:
    args = argparse.Namespace(**json.load(f))

In [None]:
# Validation-time transform (per the `val_trfm` name), built from the loaded
# config and passed to VolumeDataset below.
val_trans = val_trfm(args)

In [None]:
# Dataset for the selected volume, using the window / level / width settings
# from the training config. shuffle=False keeps items in order, presumably
# required for pred_volume to reassemble the volume.
ds = pengwin.VolumeDataset(
    volume_id,
    "full",
    val_trans,
    args.window,
    args.window_level,
    args.window_width,
    data_root,
)
loader = torch.utils.data.DataLoader(ds, shuffle=False, batch_size=32)

In [None]:
# Build the network, load the best-validation checkpoint of this run, and
# switch to inference mode.
model = build_model(args).cuda()
# map_location="cpu": deserialise onto the CPU so loading does not depend on
# which GPU the checkpoint was saved from. load_state_dict then copies the
# weights into the parameters that already live on the GPU.
model.load_state_dict(
    torch.load(osp.join(log_path, "2nd-re_stu/best_val.pth"), map_location="cpu")["model"]
)
model.eval();  # trailing ';' suppresses the (very long) module repr

In [None]:
# Render every predicted slice of this volume for the qualitative figure.
draw_vol(model, loader, save_dir=f"paper_fig/quali/pengwin/stage2/{volume_id}")

# ctpelvic1k

In [None]:
# Which volume to render, where its data lives, and which training run to load.
# `~` is expanded here because plain file APIs (open, os.listdir, ...) do not
# expand it. expanduser is a no-op on an absolute path, so this stays correct
# even if VolumeDataset also expands it.
volume_id = "d1_0065"
data_root = osp.expanduser("~/sd10t/ctpelvic1k")
log_path = "log/ctpelvic1k"

In [None]:
# Reload the hyper-parameters saved in this run's log directory as a Namespace.
# JSON text is UTF-8 (RFC 8259). Pass the encoding explicitly so parsing does
# not depend on the platform's locale default.
with open(osp.join(log_path, "config.json"), "r", encoding="utf-8") as f:
    args = argparse.Namespace(**json.load(f))

In [None]:
# Validation-time transform (per the `val_trfm` name), built from the loaded
# config and passed to VolumeDataset below.
val_trans = val_trfm(args)

In [None]:
# Dataset for the selected volume, using the window / level / width settings
# from the training config. shuffle=False keeps items in order, presumably
# required for pred_volume to reassemble the volume.
ds = ctpelvic1k.VolumeDataset(
    volume_id,
    "full",
    val_trans,
    args.window,
    args.window_level,
    args.window_width,
    data_root,
)
loader = torch.utils.data.DataLoader(ds, shuffle=False, batch_size=32)

In [None]:
# Build the network, load the best-validation checkpoint of this run, and
# switch to inference mode.
model = build_model(args).cuda()
# map_location="cpu": deserialise onto the CPU so loading does not depend on
# which GPU the checkpoint was saved from. load_state_dict then copies the
# weights into the parameters that already live on the GPU.
model.load_state_dict(
    torch.load(osp.join(log_path, "2nd-re_stu/best_val.pth"), map_location="cpu")["model"]
)
model.eval();  # trailing ';' suppresses the (very long) module repr

In [None]:
# Render every predicted slice of this volume for the qualitative figure.
draw_vol(model, loader, save_dir=f"paper_fig/quali/ctpelvic1k/stage2/{volume_id}")