In [1]:
from data_modules.ocelot import OcelotDataModule
from models.dynunet import DynUNetModel
import monai.transforms as T
from lightning.pytorch import seed_everything
import torch
from system import System
from monai.inferers import SlidingWindowInferer
from monai.metrics import (
    DiceMetric,
    HausdorffDistanceMetric,
)
import os
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from ocelot_util import normalize_crop_coords_batch
import torch.nn as nn
from torchvision.utils import save_image
from torchvision.transforms.functional import to_pil_image

# Fix the global random seed up front so the run is reproducible.
SEED = 42
seed_everything(SEED)

[rank: 0] Seed set to 42


42

In [2]:
# Report whether a CUDA device is visible to this kernel (False => inference runs on CPU).
cuda_available = torch.cuda.is_available()
print(cuda_available)

False


In [3]:
# One checkpoint per cross-validation fold.
# NOTE(review): presumably paths[k] is the checkpoint trained with fold_index=k — confirm,
# since later cells pair models['fold_k'] with the fold_k validation loader.
paths = [
    './lightning_logs/4lsj3sdu/checkpoints/model-epoch=98-val_tissue_dice=0.75.ckpt',
    './lightning_logs/sg9vsaxz/checkpoints/model-epoch=200-val_tissue_dice=0.80.ckpt',
    './lightning_logs/m3ruzf72/checkpoints/model-epoch=149-val_tissue_dice=0.78.ckpt',
    './lightning_logs/7jddlwlw/checkpoints/model-epoch=176-val_tissue_dice=0.85.ckpt',
    './lightning_logs/3rmq2rbn/checkpoints/model-epoch=206-val_tissue_dice=0.85.ckpt',
]

# Build a sliding-window inferer (896x896 windows, 50% overlap, Gaussian blending) per fold.
# NOTE(review): every loaded checkpoint is bound to the single name `tissue_segmentation`
# and replaced on the next iteration, so after this cell only the network from paths[4]
# is still referenced; the `models` dict stores inferers only, not networks.
models = {}
for fold_idx, ckpt_path in enumerate(paths):
    tissue_segmentation = System.load_from_checkpoint(checkpoint_path=ckpt_path)
    models[f'fold_{fold_idx}'] = SlidingWindowInferer(
        roi_size=(896, 896),
        overlap=0.5,
        sw_batch_size=1,
        mode="gaussian",
    )

Attribute 'net' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['net'])`.


In [4]:
# One validation DataLoader per cross-validation fold, keyed 'fold_0' .. 'fold_4'.
# NOTE(review): these splits must match the ones each checkpoint was trained with;
# confirm OcelotDataModule splits deterministically for a given fold_index.
folds = {}
for fold_index in range(5):
    data_module = OcelotDataModule(
        batch_size=1,
        num_workers=1,
        cv_folds=5,
        fold_index=fold_index,
    )
    data_module.prepare_data()
    data_module.setup()
    folds[f'fold_{fold_index}'] = data_module.val_dataloader()

Train: 320 | Val: 80 | Test: 126
Train: 320 | Val: 80 | Test: 126
Train: 320 | Val: 80 | Test: 126
Train: 320 | Val: 80 | Test: 126
Train: 320 | Val: 80 | Test: 126


In [5]:
# Map network outputs (presumably logits) to a binary mask:
# element-wise sigmoid, then 0/1 at a 0.5 threshold.
to_probabilities = T.Activations(sigmoid=True, softmax=False)
binarize = T.AsDiscrete(threshold=0.5)
post_transform = T.Compose([to_probabilities, binarize])

In [6]:
# Destination for the per-sample cropped tissue masks; created if it does not exist yet.
save_dir = "data/ocelot_tissue/training_cropped_labels"
os.makedirs(name=save_dir, exist_ok=True)

In [7]:
# For every cross-validation fold: predict a tissue mask for each validation image,
# crop it to the box given by the sample's crop metadata, resize the crop back to the
# tissue-image size and save it as `<sample_id>.png` in `save_dir`.
#
# Fix: the previous loop passed the global `tissue_segmentation` to every fold's inferer.
# After the model-loading cell that name only holds the LAST checkpoint (paths[4]), so all
# five folds were predicted by one network -- presumably one that saw folds 0-3's
# validation images during training. Each fold now loads its own checkpoint paths[i],
# switches it to eval mode and runs without autograd.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for i, ckpt_path in enumerate(paths):
    loader = folds[f'fold_{i}']
    tissue_inferer = models[f'fold_{i}']

    # This fold's own network; eval() puts dropout/batch-norm layers in inference mode.
    tissue_segmentation = System.load_from_checkpoint(checkpoint_path=ckpt_path)
    tissue_segmentation = tissue_segmentation.to(device).eval()

    # No autograd graph is needed for pure inference: saves memory and time.
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            img = batch["img_tissue"].to(device)
            pred = tissue_inferer(img, tissue_segmentation)
            pred_post = post_transform(pred)  # sigmoid + threshold at 0.5 -> binary mask

            meta_batch = batch["meta"]
            # Loaders are built with batch_size=1, so index 0 is the only sample.
            sample_id = meta_batch["sample_id"][0]

            x1, y1, x2, y2 = normalize_crop_coords_batch(meta_batch)[0]
            tissue_crop = pred_post[0:1, :, y1:y2, x1:x2]
            # NOTE(review): bilinear resampling turns the binary mask's edges into
            # fractional values (saved as grey levels); use mode="nearest" if hard 0/1
            # labels are required downstream.
            tissue_crop_resized = nn.functional.interpolate(
                tissue_crop, size=img.shape[2:], mode="bilinear"
            )

            # Drop the batch dim and move to CPU so torchvision can convert it to PIL.
            img_pil = to_pil_image(tissue_crop_resized.squeeze(0).cpu())
            img_pil.save(os.path.join(save_dir, f"{sample_id}.png"))
        

Evaluating: 100%|██████████| 80/80 [20:13<00:00, 15.17s/it]
Evaluating: 100%|██████████| 80/80 [20:34<00:00, 15.43s/it]
Evaluating: 100%|██████████| 80/80 [20:43<00:00, 15.54s/it]
Evaluating: 100%|██████████| 80/80 [20:52<00:00, 15.65s/it]
Evaluating: 100%|██████████| 80/80 [20:54<00:00, 15.68s/it]
