In [None]:
import os
import argparse
import json
import pathlib

import numpy as np
import torch
import torch.nn as nn
import torch.optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from skimage import img_as_ubyte
from matplotlib import pyplot as plt

from monai.utils import set_determinism
from monai.data import DataLoader, Dataset, CacheDataset
from monai.utils import first
from monai.networks.nets import UNet
from monai.losses import DiceLoss, DiceFocalLoss, FocalLoss
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference
from monai.metrics import compute_meandice
from monai.transforms import AsDiscrete
from monai.visualize.img2tensorboard import plot_2d_or_3d_image

In [None]:
import sys
libdir = "../"
# Put the parent dir (which holds `utils/` and `experiments/`) first on
# sys.path so `import utils...` below resolves to the local package. Drop any
# earlier copy first so re-running this cell does not keep growing sys.path.
sys.path[:] = [p for p in sys.path if p != libdir]
sys.path.insert(0, libdir)
import utils.data
from utils.data import get_surf_srep_split, get_srep_data_transform
import utils.misc as workspace

In [None]:
exp_name = "run4_1000_mod"
checkpoint = "latest"  # checkpoint identifier handed to workspace.load_model_checkpoint
experiment_dir = os.path.join(libdir, "experiments", exp_name)

# Checkpoint and evaluation output dirs live inside the experiment dir;
# exist_ok=True makes this cell safe to re-run and avoids the isdir/makedirs race.
checkpt_dir = os.path.join(experiment_dir, workspace.checkpoint_subdir)
eval_dir = os.path.join(experiment_dir, workspace.evaluation_subdir)
os.makedirs(checkpt_dir, exist_ok=True)
os.makedirs(eval_dir, exist_ok=True)

with open(os.path.join(experiment_dir, "specs.json"), "r") as f:
    specs = json.load(f)

# Report every missing key at once instead of failing on the first lookup.
required_spec_keys = {
    "DataSource", "LearningRate", "Epochs", "SaveEvery",
    "BatchSize", "Debug", "ResizeShape",
}
missing_spec_keys = required_spec_keys - specs.keys()
if missing_spec_keys:
    raise KeyError(f"specs.json is missing keys: {sorted(missing_spec_keys)}")

train_data_dir = specs["DataSource"]
learning_rate = specs["LearningRate"]
num_epochs = specs["Epochs"]
save_epoch = specs["SaveEvery"]
batch_size = specs["BatchSize"]
if_debug = specs["Debug"]
resize_shape = specs["ResizeShape"]
print(
    f'Learning Rate:{learning_rate} | Epochs:{num_epochs} | BatchSize:{batch_size}')
print(f"Training data dir: {train_data_dir}")


In [None]:
# Resolve a leading "~" so downstream loaders receive a concrete path,
# then echo the resolved directory.
train_data_dir = os.path.expanduser(train_data_dir)
print(f"{train_data_dir}")

In [None]:
# Alternative data source: the surface/s-rep split under specs["DataSource"]
# (train_data_dir). Kept for reference; this cell evaluates on hippocampi.
# data_transforms = get_srep_data_transform((resize_shape, resize_shape, resize_shape))
# trn_files, val_files, tst_files = get_surf_srep_split(train_data_dir, random_shuffle=False, debug=if_debug)
# all_files = trn_files + val_files
h_data_dir = "../data/hippocampi/"
val_files = utils.data.get_hippocampi_files(h_data_dir)
data_transforms = utils.data.get_hippocampi_transform((resize_shape, resize_shape, resize_shape))
# cache_rate=0.8: MONAI caches the transformed results of 80% of the samples.
val_ds = CacheDataset(data=val_files, transform=data_transforms, cache_rate=0.8, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)

# Fall back to CPU when no GPU is visible instead of hard-coding CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)

# NOTE(review): the losses and optimizer below are not used by the evaluation
# cells in this notebook; presumably carried over from the training script.
criterion1 = DiceLoss(sigmoid=True)
criterion2 = FocalLoss()

optimizer = torch.optim.Adam(model.parameters(), 1e-4)

In [None]:
# Peek at the identifier of the first hippocampus sample (shown as cell output).
first_val_file = val_files[0]
first_val_file["fname"]

In [None]:
# Wrap for multi-GPU inference exactly once. Re-running this cell must not
# nest DataParallel: that prefixes state-dict keys with "module.module." and
# would likely break loading the checkpoint into the wrapped model.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
# Use the configured device rather than hard-coding .cuda().
model = model.to(device)
model.eval()
saved_epoch = workspace.load_model_checkpoint(experiment_dir, checkpoint, model)
print(saved_epoch)

In [None]:
# Predictions, figures and .npy volumes are written to the experiment's
# evaluation dir. Reuse the eval_dir path built (and created) in the config
# cell instead of recomputing it, so the two names cannot drift apart.
eval_saving_dir = eval_dir
print(eval_saving_dir)


In [None]:
# Run inference over the whole loader and save each predicted probability
# volume (channel 0 of the sigmoid output) as "<fname>_eval.npy".
model.eval()  # BatchNorm must use running stats; safe if already in eval mode
with torch.no_grad():
    # Iterate the loader directly so tqdm can read its length and show a total.
    for val_data in tqdm(val_loader, desc="saving predictions"):
        val_inp = val_data["image"].to(device)
        out_img = torch.sigmoid(model(val_inp)).cpu()
        # Save every sample in the batch, not only the first (identical to the
        # previous behavior while val_loader uses batch_size=1).
        for fname, pred in zip(val_data["fname"], out_img[:, 0]):
            np.save(os.path.join(eval_saving_dir, f"{fname}_eval.npy"), pred.numpy())

In [None]:
# Log each label volume and the model's sigmoid output to TensorBoard as
# frame sequences along the last axis (frame_dim=-1), one step per sample.
# The context manager closes the writer so pending events are flushed to disk.
# NOTE(review): visualize_model has its label code commented out; confirm the
# hippocampi transform actually provides a "label" key.
with SummaryWriter(eval_dir) as sw, torch.no_grad():
    for i, val_data in enumerate(tqdm(val_loader, desc="tensorboard")):
        val_inp = val_data["image"].to(device)
        # The label is only plotted, so it does not need to go to the GPU.
        val_lab = val_data["label"]
        out_img = torch.sigmoid(model(val_inp)).cpu()
        plot_2d_or_3d_image(data=val_lab, step=i, writer=sw, frame_dim=-1, tag='label')
        # The prediction is logged under tag 'image' (kept for continuity with existing logs).
        plot_2d_or_3d_image(data=out_img, step=i, writer=sw, frame_dim=-1, tag='image')

In [None]:
def visualize_model(model, val_loader, device=torch.device("cuda:0"), slice=80):
    """Show and save one 2-D plane of each input next to the model's prediction.

    For every batch from ``val_loader`` the model is run on ``val_data["image"]``
    and the sigmoid of its output is taken. For the first sample of the batch,
    the plane at index ``slice`` of the last axis is drawn for the input, for
    the label (only when the batch has a ``"label"`` entry), and for the
    prediction. The figure is saved as ``hipp_slice{slice}_{i}.png`` and the full
    first-channel prediction as ``{fname}.npy``, both under the global
    ``eval_saving_dir``.

    Args:
        model: network applied to the batched image tensor; its output is
            indexed as ``[sample, channel, ..., last_axis]`` like the input.
        val_loader: iterable of dicts with an ``"image"`` tensor and a
            ``"fname"`` sequence of names, plus an optional ``"label"`` tensor.
        device: device the input is moved to before the forward pass.
        slice: index along the last axis to display. Shadows the builtin
            ``slice`` inside this function; the name is kept because callers
            pass it by keyword.

    The model is put in eval mode for the forward passes and restored to its
    previous train/eval mode afterwards.
    """
    was_training = model.training
    # The UNet uses BatchNorm, which must use running statistics at inference.
    model.eval()
    try:
        with torch.no_grad():
            for i, val_data in enumerate(val_loader):
                fname = val_data["fname"][0]
                val_inp = val_data["image"].to(device)
                out_img = torch.sigmoid(model(val_inp)).cpu()

                # (title, 2-D plane, colormap) per panel; the label panel is
                # only drawn when labels exist instead of leaving a blank axis.
                panels = [(f"image {i}", val_data["image"][0, 0, :, :, slice], "gray")]
                if "label" in val_data:
                    panels.append((f"label {i}", val_data["label"][0, 0, :, :, slice], None))
                panels.append(
                    (f"output {i}", img_as_ubyte(out_img[0, 0, :, :, slice].numpy()), None)
                )

                # A fresh figure per sample (instead of reusing the named figure
                # "check") so images do not pile up on the same axes.
                fig, axes = plt.subplots(
                    1, len(panels), figsize=(6 * len(panels), 6), squeeze=False
                )
                for ax, (title, plane, cmap) in zip(axes[0], panels):
                    ax.set_title(title)
                    ax.imshow(plane, cmap=cmap)

                # Save before showing: show() may clear or close the figure in
                # some backends, which would produce a blank PNG.
                fig.savefig(os.path.join(eval_saving_dir, f"hipp_slice{slice}_{i}.png"))
                plt.show()
                plt.close(fig)

                # Save the full predicted volume of the first sample as well.
                np.save(os.path.join(eval_saving_dir, f"{fname}.npy"), out_img[0, 0].numpy())
    finally:
        model.train(was_training)

In [None]:
# Default view: plane 80 of the last axis.
visualize_model(model, val_loader, device=device, slice=80)

In [None]:
# Same visualization at plane 55.
visualize_model(model, val_loader, device=device, slice=55)

In [None]:
# Same visualization at plane 102.
visualize_model(model, val_loader, device=device, slice=102)

In [None]:
# Same visualization at plane 67.
visualize_model(model, val_loader, device=device, slice=67)