In [1]:
cd ..

/home/ikboljonsobirov/hecktor/fusion_vit/fusion_vit_project


In [2]:
import os
import json
import shutil
import tempfile
import time
import random
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib

from monai.losses import DiceLoss, DiceCELoss, FocalLoss, GeneralizedDiceFocalLoss
from monai.inferers import sliding_window_inference
from monai import transforms
from monai.transforms import (
    AsDiscrete,
    Activations,
    Compose,
)

from monai.config import print_config
from monai.metrics import DiceMetric
from monai.utils.enums import MetricReduction
from monai.networks.nets import SwinUNETR, UNet, SegResNet, UNETR
from monai import data
from monai.metrics import DiceMetric
from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)
from functools import partial
from src.data.augmentations import *

import torch
import SimpleITK as sitk
from src.models.components.metrics import dice
from src.models.components.unetr.unetr import CustomUNETR

from einops import rearrange


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class AverageMeter(object):
    """Running, optionally weighted, average of a scalar or array-valued metric.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: running total of ``val * n``.
        count: running total of ``n``.
        avg: ``sum / count`` wherever ``count > 0``; ``sum`` elsewhere.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every accumulated statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Fold a new observation into the running statistics.

        Args:
            val: latest metric value (scalar or NumPy array).
            n: weight of ``val``, i.e. how many samples it stands for; a
                scalar or an array that broadcasts against ``val``.
        """
        self.val = val
        self.sum += val * n
        self.count += n

        total = np.asarray(self.sum)
        count = np.asarray(self.count)
        has_samples = count > 0
        # np.where evaluates both branches eagerly, so dividing by the raw
        # count would raise ZeroDivisionError for scalars (or emit divide /
        # invalid warnings for arrays) whenever a count is still zero.
        # Divide by a guarded denominator instead; entries without samples
        # fall back to the running sum, as before.
        safe_count = np.where(has_samples, count, 1)
        self.avg = np.where(has_samples, total / safe_count, total)


def datafold_read(datalist, basedir, fold=0, key="training"):
    """Load a JSON datalist and split its entries by fold.

    Args:
        datalist: path to the JSON file.
        basedir: directory joined in front of every path found in an entry.
        fold: entries whose ``"fold"`` field equals this value form the
            validation split.
        key: top-level key of the JSON document holding the list of entries.

    Returns:
        ``(tr, val)``: two lists of entry dicts. ``val`` holds the entries of
        the requested fold; ``tr`` holds everything else, including entries
        without a ``"fold"`` field.
    """
    with open(datalist) as handle:
        entries = json.load(handle)[key]

    def _anchor(value):
        # Lists are joined element-wise, non-empty strings directly; empty
        # strings and any other type pass through unchanged.
        if isinstance(value, list):
            return [os.path.join(basedir, item) for item in value]
        if isinstance(value, str) and len(value) > 0:
            return os.path.join(basedir, value)
        return value

    for entry in entries:
        for field in entry:
            entry[field] = _anchor(entry[field])

    def _in_fold(entry):
        return "fold" in entry and entry["fold"] == fold

    val = [entry for entry in entries if _in_fold(entry)]
    tr = [entry for entry in entries if not _in_fold(entry)]
    return tr, val

In [4]:
# Patch size, one entry per spatial axis. get_loader reads this global as the
# spatial size for both padding and random cropping.
roi = (96, 96, 96)

In [None]:
def get_loader(batch_size, sw_batch_size, data_dir, json_list, fold):
    """Build the training and validation datasets and loaders for one fold.

    Args:
        batch_size: batch size of the training loader (validation uses 1).
        sw_batch_size: not used by this function; kept so existing callers
            keep working.
        data_dir: base directory joined onto the paths in the datalist.
        json_list: path to the JSON datalist file.
        fold: index of the fold held out for validation.

    Returns:
        ``(train_loader, val_loader, train_ds, val_ds)``.

    Note:
        Reads the module-level ``roi`` for the pad / crop size.
    """
    train_files, validation_files = datafold_read(datalist=json_list, basedir=data_dir, fold=fold)

    # Every spatial transform must see all three volumes. Padding or cropping
    # only "image" and "label" would leave "image2" at its original extent,
    # out of step with the patch cut from "image".
    # NOTE(review): assumes "image2" is on the same voxel grid as "image" — confirm.
    volume_keys = ["image", "image2", "label"]

    train_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=volume_keys, ensure_channel_first=True),
            transforms.SpatialPadd(keys=volume_keys, spatial_size=roi, method='end'),
            transforms.RandCropByPosNegLabeld(
                keys=volume_keys,
                label_key="label",
                spatial_size=roi,
                pos=1,
                neg=1,
                num_samples=1,
                image_key="image",
                image_threshold=0,
            ),
            transforms.NormalizeIntensityd(keys=["image", "image2"]),
        ]
    )

    # NOTE(review): validation volumes are not intensity-normalised while the
    # training ones are. Confirm this matches how the checkpoint was trained
    # (or that normalisation happens downstream) before trusting the metrics.
    val_transform = transforms.Compose(
        [
            transforms.LoadImaged(keys=volume_keys, ensure_channel_first=True),
            transforms.SpatialPadd(keys=volume_keys, spatial_size=roi, method='end'),
        ]
    )

    train_ds = data.Dataset(data=train_files, transform=train_transform)
    train_loader = data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )

    val_ds = data.Dataset(data=validation_files, transform=val_transform)
    # Validation runs one whole volume at a time, in a fixed order.
    val_loader = data.DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )

    return train_loader, val_loader, train_ds, val_ds

In [None]:
# --- Run configuration -------------------------------------------------------
root_dir = ""

# Base directory joined onto the paths listed in the datalist.
data_dir = '/share/nvmedata/ikboljonsobirov/fusion_vit/hecktor2022_cropped/'
# JSON datalist whose entries are split into train / validation by fold.
datalist_json = (
    '/home/ikboljonsobirov/hecktor/fusion_vit/fusion_vit_project'
    '/files/train_json_new.json'
)

batch_size = 2       # training-loader batch size
sw_batch_size = 8    # passed to get_loader, which does not use it
fold = 0 # 0,1,2,3,4
infer_overlap = 0.5  # not used in the cells shown here

# Checkpoint to load.
chkpt_path = (
    '/home/ikboljonsobirov/hecktor/fusion_vit/fusion_vit_project'
    '/logs/train/runs/2023-10-27_12-13-58/checkpoints/epoch_076.ckpt'
)
# Alternative checkpoint, kept for reference:
# chkpt_path = '/home/ikboljonsobirov/hecktor/fusion_vit/fusion_vit_project/logs/train/runs/segresnet_fold0/checkpoints/epoch_088.ckpt'

In [None]:
# Build the datasets and loaders for the configured fold.
train_loader, val_loader, train_ds, val_ds = get_loader(
    batch_size=batch_size,
    sw_batch_size=sw_batch_size,
    data_dir=data_dir,
    json_list=datalist_json,
    fold=fold,
)