In [1]:
import os
import json
import shutil
import tempfile
import time
import random
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib

from monai.losses import DiceLoss, DiceCELoss, FocalLoss, GeneralizedDiceFocalLoss
from monai.inferers import sliding_window_inference
from monai import transforms
from monai.transforms import (
    AsDiscrete,
    Activations,
    Compose,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.utils.enums import MetricReduction
from monai.networks.nets import SwinUNETR, UNet, SegResNet, UNETR
from monai import data
from monai.metrics import DiceMetric
from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)
from functools import partial

import torch
import SimpleITK as sitk
from einops import rearrange


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def datafold_read(datalist, basedir, fold=0, key="training"):
    """Read a JSON datalist and split its entries into train / validation lists.

    Args:
        datalist: path to a JSON file containing a top-level ``key`` entry that
            holds a list of dicts.
        basedir: directory joined in front of every non-empty string value and
            every element of list values in each entry.
            NOTE(review): this prefixes *all* string fields, not only paths —
            confirm the datalist holds no non-path string metadata.
        fold: entries whose ``"fold"`` value equals this go to validation.
        key: top-level JSON key holding the entry list.

    Returns:
        ``(train_entries, validation_entries)``, preserving file order. Entries
        without a ``"fold"`` field always land in the training list.
    """
    with open(datalist) as handle:
        entries = json.load(handle)[key]

    def _prefixed(value):
        # Lists are assumed to hold path strings; empty strings are left alone.
        if isinstance(value, list):
            return [os.path.join(basedir, item) for item in value]
        if isinstance(value, str) and value:
            return os.path.join(basedir, value)
        return value

    for entry in entries:
        # Re-assigning existing keys does not change the dict size, so this is
        # safe while iterating.
        for field, value in entry.items():
            entry[field] = _prefixed(value)

    train_entries, validation_entries = [], []
    for entry in entries:
        held_out = "fold" in entry and entry["fold"] == fold
        (validation_entries if held_out else train_entries).append(entry)
    return train_entries, validation_entries

In [3]:
# Training patch size in voxels per spatial axis (used as the random-crop size).
roi = (192,) * 3

In [5]:
class LoadImageTorch:
    """Transform that loads one pre-saved sample written with ``torch.save``.

    Intended as the first step of a MONAI ``Compose`` pipeline. The incoming
    datalist entry is a dict whose ``"path"`` value is a file path. The object
    returned by ``torch.load`` replaces the entry.

    NOTE(review): the validation batch inspected in this notebook carries
    ``"ctpt"``, ``"seg"`` and ``"id"`` keys, which suggests every saved file is
    such a dict — confirm for all files.
    """

    def __init__(self, **load_kwargs):
        # Optional keyword arguments forwarded verbatim to ``torch.load``, e.g.
        # ``map_location="cpu"``, or ``weights_only=False`` on torch versions
        # whose default refuses arbitrary pickled objects. Empty by default,
        # which reproduces the plain ``torch.load(path)`` call.
        self.load_kwargs = load_kwargs

    def __call__(self, path_to_data_dir):
        """Load the sample referenced by a datalist entry.

        Args:
            path_to_data_dir: datalist entry dict; only its ``"path"`` value is
                read (despite the name, it is not a directory).

        Returns:
            Whatever ``torch.load`` returns for that file.

        Raises:
            ValueError: if ``path_to_data_dir`` is None. Printing and returning
                None here would only fail later inside the next transform or
                the collate function, far from the cause.
            KeyError: if the entry has no ``"path"`` key.
        """
        if path_to_data_dir is None:
            raise ValueError('Please provide directory to the data path')
        return self.read_torch_file(path_to_data_dir['path'], **self.load_kwargs)

    @staticmethod
    def read_torch_file(path, **load_kwargs):
        """Return ``torch.load(path, **load_kwargs)``.

        SECURITY: ``torch.load`` unpickles the file and can execute arbitrary
        code; only load files from a trusted location.
        """
        return torch.load(path, **load_kwargs)


In [6]:
def get_loader(batch_size, sw_batch_size, data_dir, json_list, fold,
               image_key="ctpt", label_key="seg", roi_size=None, num_workers=8):
    """Build train/validation datasets and loaders for one cross-validation fold.

    Args:
        batch_size: training batch size (validation always uses 1).
        sw_batch_size: accepted for call-site compatibility; not used here.
        data_dir: base directory joined in front of the datalist paths.
        json_list: path to the JSON datalist read by ``datafold_read``.
        fold: entries whose ``"fold"`` equals this value form the validation set.
        image_key: key of the image tensor in the dicts produced by
            ``LoadImageTorch``. Defaults to ``"ctpt"``, the key the loaded
            samples actually carry. The earlier ``"image"`` key did not exist,
            so the training crop raised ``KeyError`` on the first batch.
        label_key: key of the label tensor; defaults to ``"seg"`` for the same
            reason.
        roi_size: spatial size of the random training crop. Defaults to the
            notebook-level ``roi``.
        num_workers: worker processes for each DataLoader.

    Returns:
        ``(train_loader, val_loader, train_ds, val_ds)``.
    """
    spatial_size = roi if roi_size is None else roi_size

    train_files, validation_files = datafold_read(datalist=json_list, basedir=data_dir, fold=fold)

    train_transform = transforms.Compose(
        [
            LoadImageTorch(),
            # Positive/negative-balanced random patch crop. The flip / rotate /
            # zoom / intensity augmentations previously left commented out here
            # are disabled; if re-added, they must use image_key / label_key.
            transforms.RandCropByPosNegLabeld(
                keys=[image_key, label_key],
                label_key=label_key,
                spatial_size=spatial_size,
                pos=1,
                neg=1,
                num_samples=1,
                image_key=image_key,
                image_threshold=0,
            ),
        ]
    )

    # Validation samples are loaded whole (no cropping).
    val_transform = transforms.Compose([LoadImageTorch()])

    train_ds = data.Dataset(data=train_files, transform=train_transform)
    train_loader = data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )

    val_ds = data.Dataset(data=validation_files, transform=val_transform)
    val_loader = data.DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader, train_ds, val_ds

In [8]:
# Run configuration. The absolute cluster paths below are only defaults: set the
# HECKTOR_DATA_DIR / HECKTOR_DATALIST environment variables to run elsewhere.
root_dir = ""  # NOTE(review): not used by any visible cell

data_dir = os.environ.get(
    "HECKTOR_DATA_DIR",
    '/share/nvmedata/ikboljonsobirov/fusion_vit/hecktor2022_torch/',
)
datalist_json = os.environ.get(
    "HECKTOR_DATALIST",
    '/home/ikboljonsobirov/hecktor/fusion_vit/fusion_vit_project/files/train_json_torch.json',
)
batch_size = 1
sw_batch_size = 1  # passed to get_loader; presumably meant for sliding-window inference
fold = 0  # validation fold: one of 0, 1, 2, 3, 4
# For quick debugging, a smaller crop such as roi = (64, 64, 64) can be set here.
infer_overlap = 0.5  # NOTE(review): not used by any visible cell; presumably sliding-window overlap

In [9]:
# Build the loaders for the configured fold (keyword arguments make the mapping explicit).
train_loader, val_loader, train_ds, val_ds = get_loader(
    batch_size=batch_size,
    sw_batch_size=sw_batch_size,
    data_dir=data_dir,
    json_list=datalist_json,
    fold=fold,
)

In [10]:
# Grab one validation batch for a sanity check. The temporary iterator is
# dropped right away so its worker processes are released.
val_batches = iter(val_loader)
a = next(val_batches)
del val_batches

In [12]:
# Label values present, shape of the 'ctpt' image batch, and the case id(s).
(torch.unique(a["seg"]), a["ctpt"].shape, a["id"])

(metatensor([0., 1.]), torch.Size([1, 2, 200, 200, 310]), ['CHUV-008'])