In [None]:
import torch
from fastmri.pl_modules.unet_module import UnetModule
from modules.roi_unet_module import RoiUnetModule
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import yaml
import glob
from fastmri.data.transforms import UnetDataTransform
from fastmri.pl_modules import FastMriDataModule
from fastmri.data.subsample import create_mask_for_mask_type
import pandas as pd

In [None]:
def _load_hparams(path):
    """Read one Lightning ``hparams.yaml`` file, echo it, and return the dict."""
    with open(path, "r") as stream:
        params = yaml.safe_load(stream)
        print(params)
    return params


# Hyper-parameters of the baseline run and of the ROI-weighted run.
config_unet = _load_hparams("roi_vs_base/logs/alpha1.0_roi100_epoch30/version_1/hparams.yaml")
config_roi = _load_hparams("roi_vs_base/logs/alpha1.25_roi100_epoch30/version_6/hparams.yaml")

In [None]:
# !ls unet_checkpoints/logs/version_3/checkpoints

In [None]:
def find_checkpoint(pattern):
    """Return the first checkpoint path matching the glob ``pattern``.

    Matches are sorted so the result is deterministic when ``pattern`` contains
    wildcards (``glob.glob`` returns matches in arbitrary order).

    Raises:
        FileNotFoundError: if nothing matches (instead of an opaque IndexError
            from indexing an empty list).
    """
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No checkpoint matches {pattern!r}")
    return matches[0]


# Alternative: use the epoch-named checkpoints instead of last.ckpt, e.g.
# unet_ckpt = find_checkpoint("roi_vs_base/alpha1.0_roi100_epoch30/checkpoints/epochepoch=*.ckpt")
# roi_ckpt = find_checkpoint("roi_vs_base/alpha1.25_roi100_epoch30/checkpoints/epochepoch=*.ckpt")
unet_ckpt = find_checkpoint("roi_vs_base/alpha1.0_roi100_epoch30/checkpoints/last.ckpt")
roi_ckpt = find_checkpoint("roi_vs_base/alpha1.25_roi100_epoch30/checkpoints/last.ckpt")

In [None]:
# Sanity check: a non-empty list means the ROI checkpoint file exists.
list(glob.iglob("roi_vs_base/alpha1.25_roi100_epoch30/checkpoints/last.ckpt"))

In [None]:
# Architecture / optimiser hyper-parameters shared by both U-Net variants,
# forwarded from each run's hparams.yaml to load_from_checkpoint.
UNET_HPARAM_KEYS = (
    "in_chans",
    "out_chans",
    "chans",
    "num_pool_layers",
    "drop_prob",
    "lr",
    "lr_step_size",
    "lr_gamma",
    "weight_decay",
)

baseline_model = UnetModule.load_from_checkpoint(
    unet_ckpt,
    **{key: config_unet[key] for key in UNET_HPARAM_KEYS},
)

# NOTE(review): `ROISobelUnetModule` is not imported in the imports cell (only
# `RoiUnetModule` is), so this raises NameError on a fresh kernel — TODO add the
# correct import. The previous variant was loaded via `ROIUnetModule` and also
# passed `roi_size=config_roi["roi_size"]`; that argument is currently omitted.
roi_model = ROISobelUnetModule.load_from_checkpoint(
    roi_ckpt,
    alpha=config_roi["alpha"],
    threshold=70,  # hard-coded here rather than read from config_roi
    **{key: config_roi[key] for key in UNET_HPARAM_KEYS},
)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Put both models on the same device the evaluation batches are moved to
# (previously hard-coded to "cuda:0", which failed on CPU-only machines) and
# switch them to inference mode. `.to()` and `.eval()` return the module, so
# the last line still displays roi_model.
baseline_model.to(device).eval()
roi_model.to(device).eval()

In [None]:
import os

# Root of the fastMRI data. Override it with the FASTMRI_DATA environment
# variable instead of editing this absolute, user-specific default path.
DATA_PATH = Path(os.environ.get("FASTMRI_DATA", "/home/hice1/ltupac3/scratch/fastMRI_data"))

# Undersampling mask: "random" mask type, 8% centre fraction, 4x acceleration.
mask = create_mask_for_mask_type(
    mask_type_str="random",
    center_fractions=[0.08],
    accelerations=[4],
)

# use_seed=True presumably derives the mask seed from the file name so that
# validation masks are reproducible across runs — confirm in fastMRI docs.
transform = UnetDataTransform(which_challenge="singlecoil", mask_func=mask, use_seed=True)

data_module = FastMriDataModule(
    data_path=DATA_PATH,
    challenge="singlecoil",
    train_transform=transform,
    val_transform=transform,
    test_transform=transform,
    batch_size=1,  # the evaluation loop below reads element [0] of each batch
    num_workers=0,
    sample_rate=None,
    distributed_sampler=False,
)

data_module.setup("validate")
val_loader = data_module.val_dataloader()

In [None]:
# Per-slice evaluation of both models on the validation split.
# Each `results` row records:
#   fname:           name of the file the slice comes from
#   slice:           which 2-D slice of that file
#   baseline_l1:     L1 error of the baseline model over the full image
#   baseline_roi_l1: L1 error of the baseline model inside the ROI mask
#   roi_l1:          L1 error of the ROI model over the full image
#   roi_roi_l1:      L1 error of the ROI model inside the ROI mask
# `image_outputs` keeps numpy copies of the images for the plots further down.
EDGE_PERCENTILE = 70  # keep in sync with `threshold=70` used when loading roi_model

results = []
image_outputs = []

for batch in val_loader:
    image = batch.image
    target = batch.target
    fname = batch.fname
    slice_num = int(batch.slice_num[0])
    image = image.to(device)
    target = target.to(device)

    with torch.no_grad():
        out_base = baseline_model(image)
        out_roi = roi_model(image)

        # Edge-based ROI mask computed from the ground truth. These calls were
        # commented out, which left `roi_mask` undefined (NameError) below.
        # NOTE(review): helper names/arguments come from that commented-out
        # code — confirm they match the ROI module's API.
        roi_mask = roi_model.sobel_edge_map(target)
        roi_mask = roi_model.threshold_edges(roi_mask, percentile=EDGE_PERCENTILE)

    # Drop singleton dims, then broadcast to the prediction's shape so the mask
    # can index the error maps whatever leading dims the helpers return.
    roi_mask = roi_mask.squeeze().expand_as(out_base).bool()

    abs_base = torch.abs(out_base - target)
    abs_roi = torch.abs(out_roi - target)

    # If the mask is empty the ROI means are NaN; pandas' mean() skips NaN.
    results.append({
        "fname": fname[0],
        "slice": slice_num,
        "baseline_l1": abs_base.mean().item(),
        "baseline_roi_l1": abs_base[roi_mask].mean().item(),
        "roi_l1": abs_roi.mean().item(),
        "roi_roi_l1": abs_roi[roi_mask].mean().item(),
    })

    image_outputs.append({
        "target": target.detach().cpu().squeeze().numpy(),
        "baseline": out_base.detach().cpu().squeeze().numpy(),
        "roi": out_roi.detach().cpu().squeeze().numpy(),
        "fname": fname[0],
        "slice": slice_num,
    })



In [None]:
# One row per validation slice, one column per recorded metric.
df = pd.DataFrame.from_records(results)

In [None]:
# Mean full-image L1 error of each model over all validation slices.
df.loc[:, ["baseline_l1", "roi_l1"]].mean()

In [None]:
# Mean ROI-restricted L1 error of each model over all validation slices.
df.loc[:, ["baseline_roi_l1", "roi_roi_l1"]].mean()

In [None]:
# Compare the per-slice ROI L1 distributions of both models. Both histograms
# share the same bin edges so bar heights are directly comparable, and NaN
# rows (slices with an empty ROI mask) are dropped before binning.
baseline_vals = df["baseline_roi_l1"].dropna()
roi_vals = df["roi_roi_l1"].dropna()
shared_bins = np.histogram_bin_edges(pd.concat([baseline_vals, roi_vals]), bins=30)

fig, ax = plt.subplots()
ax.hist(baseline_vals, bins=shared_bins, alpha=0.5, label="Baseline ROI L1")
ax.hist(roi_vals, bins=shared_bins, alpha=0.5, label="ROI-weighted ROI L1")
ax.set(title="ROI L1 Loss Distribution", xlabel="ROI L1 error", ylabel="Number of slices")
ax.legend()
plt.show()

In [None]:
# plt.plot(df["roi_roi_l1"], label="ROI Unet ROI L1")
# plt.plot(df["baseline_roi_l1"], label="Baseline ROI L1")
# plt.legend()
# plt.title("Per-slice ROI L1 Loss")
# plt.xlabel("Sample index")
# plt.ylabel("ROI L1")
# plt.show()

In [None]:
# Slice where the ROI model's ROI-restricted L1 error improves most on the
# baseline (positive delta = ROI model better). idxmax skips NaN by default.
delta = df["baseline_roi_l1"] - df["roi_roi_l1"]
# idxmax returns an index *label*, so look it up with .loc, not positional .iloc.
biggest_gain = df.loc[delta.idxmax()]
print("Most improved slice:", biggest_gain["fname"], "Slice", biggest_gain["slice"])

In [None]:
# Images of the last validation slice processed by the evaluation loop, taken
# from `image_outputs` instead of loop variables that leaked out of the loop.
assert image_outputs, "evaluation loop produced no outputs"
last_output = image_outputs[-1]
target_np = last_output["target"]
base_np = last_output["baseline"]
roi_np = last_output["roi"]

In [None]:
def roi_box_bounds(roi_size, img_shape):
    """Return the geometry of an ROI box centred in an image.

    Args:
        roi_size: Box size as ``(height, width)``, or a single number for a
            square box.
        img_shape: Image shape; only the first two entries ``(height, width)``
            are used, so ``(H, W, C)`` shapes also work.

    Returns:
        ``(start_w, start_h, rw, rh)``: the top-left corner in matplotlib
        ``(x, y)`` order, followed by the box width and height.
    """
    try:
        rh, rw = roi_size
    except TypeError:
        # Scalar size (int, numpy scalar, 0-d tensor): square box.
        rh = rw = roi_size
    h, w = img_shape[:2]
    start_h = h // 2 - rh // 2
    start_w = w // 2 - rw // 2
    return start_w, start_h, rw, rh


def overlay_roi_box(ax, roi_size=(128, 128), img_shape=(320, 320), color='r'):
    """Draw an unfilled rectangle marking a centred ROI on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        roi_size: ``(height, width)`` of the box, or a single number for a square.
        img_shape: Shape of the displayed image (first two entries are H, W).
        color: Edge colour of the rectangle.

    Returns:
        The added ``Rectangle`` patch.
    """
    start_w, start_h, rw, rh = roi_box_bounds(roi_size, img_shape)
    rect = plt.Rectangle((start_w, start_h), rw, rh, linewidth=2, edgecolor=color, facecolor='none')
    ax.add_patch(rect)
    return rect

In [None]:
# fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# axs[0].imshow(target_np, cmap="gray")
# axs[0].set_title("Ground Truth")

# axs[1].imshow(base_np, cmap="gray")
# axs[1].set_title("Baseline Output")
# overlay_roi_box(axs[1], roi_size=roi_model.roi_size, img_shape=base_np.shape)

# axs[2].imshow(roi_np, cmap="gray")
# axs[2].set_title("ROI-Weighted Output")
# overlay_roi_box(axs[2], roi_size=roi_model.roi_size, img_shape=roi_np.shape)

# for ax in axs:
#     ax.axis("off")

# plt.tight_layout()
# plt.show()


## Display the validation slice at a chosen index

In [None]:
idx = 1000
sample = image_outputs[idx]

# The ROI box overlay only makes sense for a model with a fixed box-shaped ROI.
# The edge-based ROI model may not define `roi_size` (it is commented out when
# loading), which previously raised AttributeError here, so the box is drawn
# only when the attribute exists.
roi_box = getattr(roi_model, "roi_size", None)

# (image key, panel title, draw ROI box?)
panels = (
    ("target", f"Ground Truth\n{sample['fname']} Slice {sample['slice']}", False),
    ("baseline", "Baseline Output", True),
    ("roi", "ROI-weighted Output", True),
)

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
for ax, (key, title, show_box) in zip(axs, panels):
    ax.imshow(sample[key], cmap="gray")
    ax.set_title(title)
    if show_box and roi_box is not None:
        overlay_roi_box(ax, roi_size=roi_box, img_shape=sample[key].shape)
    ax.axis("off")

plt.tight_layout()
plt.show()


## Slice with the largest ROI improvement

In [None]:
# Find the stored images for the most-improved slice. A missing match raises a
# descriptive LookupError instead of a bare StopIteration.
slice_match = next(
    (
        s for s in image_outputs
        if s["fname"] == biggest_gain["fname"] and s["slice"] == biggest_gain["slice"]
    ),
    None,
)
if slice_match is None:
    raise LookupError(
        f"No stored output for {biggest_gain['fname']} slice {biggest_gain['slice']}"
    )


In [None]:
# Volume name of the most-improved slice (also display slice_match["target"]
# to inspect its pixels).
best_fname = slice_match["fname"]
best_fname

In [None]:
# Side-by-side view of the slice with the largest ROI improvement.
panel_specs = (
    ("target", "Ground Truth"),
    ("baseline", "Baseline Output"),
    ("roi", "ROI-Weighted Output"),
)

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
for ax, (key, title) in zip(axs, panel_specs):
    ax.imshow(slice_match[key], cmap="gray")
    ax.set_title(title)
    ax.axis("off")

plt.tight_layout()
plt.show()


In [None]:
# TODO: look up file1000022.h5 in the `fname` column of df