In [1]:
cd ..

/home/ikboljonsobirov/hecktor/fusion_vit/fusion_vit_project


In [2]:
import os
import sys
import pathlib

import numpy as np
import pandas as pd
import SimpleITK as sitk
from tqdm.notebook import tqdm
import torch
import matplotlib.pyplot as plt
%matplotlib inline
import nrrd
from einops import  rearrange
from monai.networks.nets import UNETR, SwinUNETR, SegResNet

from src.models.components.models import BaselineUNet, FastSmoothSENormDeepUNet_supervision_skip_no_drop
from src.models.sega_module import SegaModule
from src.data.components.sega_dataset import SegaDataset
from src.models.components.metrics import dice
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, Subset
from sklearn.model_selection import KFold


  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'src.data.components.sega_dataset'

In [3]:
def read_nifti(path):
    """Read a NIfTI image. Return a SimpleITK Image."""
    nifti = sitk.ReadImage(str(path))
    return nifti


def write_nifti(sitk_img, path):
    """Save a SimpleITK Image to disk in NRRD format."""
    writer = sitk.ImageFileWriter()
    writer.SetImageIO("NrrdImageIO")
    writer.SetFileName(str(path))
    writer.Execute(sitk_img)


In [4]:
data_path = pathlib.Path('/share/sda/nurenzhaksylyk/segaorta_resampled/')

chkpt_path = '/home/ikboljonsobirov/hecktor/fusion_vit/fusion_vit_project/logs/train/runs/2023-09-27_21-13-02/checkpoints/epoch_099.ckpt'

output_path = pathlib.Path('/home/ikboljonsobirov/lightning-hydra-template/data/sega_val_pred/')

if not os.path.exists(output_path):
            os.makedirs(output_path, exist_ok=True)


In [5]:
from torch.utils.data import DataLoader
from src.data.augmentations import *

trans_pred = Compose([
                    NormalizeIntensity(),
                    ToTensor(), 
                    Resizing(z=256,x=256,y=256),
                    ])

trans_orig = Compose([
                    NormalizeIntensity(),
                    ToTensor(), 
                    ]) 

dataset_pred = SegaDataset(data_path, transforms=trans_pred)
dataset_orig = SegaDataset(data_path, transforms=trans_orig)


full_indices = range(len(dataset_pred))

kf = KFold(n_splits=5, shuffle=True, random_state=786)

train_idx = {}
test_idx = {}

key = 1
for i,j in kf.split(full_indices):
    train_idx[key] = i
    test_idx[key] = j

    key += 1

_, val_dataset_pred = Subset(dataset_pred, train_idx[1]), Subset(dataset_pred, test_idx[1])
_, val_dataset_orig = Subset(dataset_orig, train_idx[1]), Subset(dataset_orig, test_idx[1])

print(len(dataset_pred), len(val_dataset_orig), len(val_dataset_pred))



pred_loader = DataLoader(
            dataset=val_dataset_pred,
            batch_size=1,
            num_workers=1,
            pin_memory=True,
            shuffle=False,
            drop_last=True,
        )

orig_loader = DataLoader(
            dataset=val_dataset_orig,
            batch_size=1,
            num_workers=1,
            pin_memory=True,
            shuffle=False,
            drop_last=True,
        )


56 12 12


In [6]:
val_dataset_pred[0]['input'].shape

In [None]:
data_pred = next(iter(pred_loader))

In [12]:
data_orig = next(iter(orig_loader))

In [13]:
print(data_pred['id'])
print(data_orig['id'])

['K7']
['K7']


In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [24]:
model = SegResNet(in_channels=1, out_channels=1)

checkpoint = torch.load(chkpt_path)
state_dict = checkpoint['state_dict']
for key in list(state_dict):
    state_dict[key.replace("model.", "")] = state_dict.pop(key)

model.load_state_dict(state_dict)

model = model.to(device)

model.eval()


SegResNet(
  (act_mod): ReLU(inplace=True)
  (convInit): Convolution(
    (conv): Conv3d(1, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  )
  (down_layers): ModuleList(
    (0): Sequential(
      (0): Identity()
      (1): ResBlock(
        (norm1): GroupNorm(8, 8, eps=1e-05, affine=True)
        (norm2): GroupNorm(8, 8, eps=1e-05, affine=True)
        (act): ReLU(inplace=True)
        (conv1): Convolution(
          (conv): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        )
        (conv2): Convolution(
          (conv): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        )
      )
    )
    (1): Sequential(
      (0): Convolution(
        (conv): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
      )
      (1): ResBlock(
        (norm1): GroupNorm(8, 16, eps=1e-05, affine=True)
        (norm2): GroupNorm(8, 16, eps=1e-05, affine=Tr

In [25]:
metric_dictionary = {}
metric_dictionary['id'] = []
metric_dictionary['resize_metric'] = []


In [28]:
with torch.no_grad():
    for sample, orig in zip(pred_loader, orig_loader):

        input = sample['input']

        output = model(input)

        # y_pred = output.float()
        y_pred = torch.where(output>0.5, 1, 0).float()
        # Upsample back to original size
        y_pred = y_pred.squeeze(0)
        y_orig = tio.Resize(orig['input'].squeeze().shape, image_interpolation='nearest')(y_pred)
        # Save prediction:

        
        metric_value_rs = dice(y_pred.unsqueeze(0).detach(), sample['target'].detach())
        metric_value_or = dice(y_orig.unsqueeze(0).detach(), orig['target'].detach())

        metric_dictionary['id'].append(sample['id'][0])
        metric_dictionary['resize_metric'].append(metric_value_rs.item())
        metric_dictionary['orig_metric'].append(metric_value_or.item())
        
        print(f"id: {sample['id'][0]}, metric_rs: {metric_value_rs}, metric_org: {metric_value_or}")

        if not os.path.exists(output_path / sample['id'][0]):
            os.makedirs(output_path / sample['id'][0], exist_ok=True)

        # write_nifti(sitk.GetImageFromArray(y_pred.squeeze()),  str(output_path / sample['id'][0] / (sample['id'][0] +'_resize.nrrd')))
        # write_nifti(sitk.GetImageFromArray(y_orig.squeeze()),  str(output_path / sample['id'][0] / (sample['id'][0] +'.nrrd')))
        sitk.WriteImage(sitk.GetImageFromArray(y_orig.squeeze()),  str(output_path / sample['id'][0] / (sample['id'][0] +'.nrrd')), useCompression=True)
        break


KeyboardInterrupt: 