In [None]:
DEBUG = True
import os
import shutil

In [None]:
import gc
import ast
import cv2
import time
from timm0412 import timm as timm # timm0412 means timm v0.4.12
# import timm

import pickle
import random
import pydicom
import argparse
import warnings
import threading
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
from glob import glob

import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp

import torch
import torch.nn as nn
import torch.optim as optim
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from pylab import rcParams

%matplotlib inline




In [None]:
timm.__version__  # sanity check: the vendored timm0412 package should report 0.4.12

In [None]:
import monai.transforms as transforms
from monai.transforms import Resize

In [None]:
# pip3 install dask-cuda
# conda install -c rapidsai -c conda-forge -c nvidia cuml=21.10
import cudf
from dask_cuda import LocalCUDACluster
from dask.distributed import Client, wait

In [None]:
# Expose only GPU 0 to this process (and, via the cluster below, to the Dask workers).
# NOTE(review): this only takes effect if no CUDA context has been created yet in this process.
CUDA_VISIBLE_DEVICES = "0"
os.environ.update(CUDA_VISIBLE_DEVICES=CUDA_VISIBLE_DEVICES)

In [None]:
# Dask scratch space: prefer the RAID volume when the machine has one (e.g. a DGX),
# otherwise fall back to /tmp.
local_directory = '/raid' if os.path.exists('/raid') else '/tmp'

dask_workdir = os.path.join(local_directory, 'dask-workdir')
print('Dask dir:', dask_workdir)

In [None]:
# Start every run from an empty Dask scratch directory.
workdir_present = os.path.isdir(dask_workdir)
if workdir_present:
    shutil.rmtree(dask_workdir)  # removes the directory even when it is not empty
os.mkdir(dask_workdir)

<p>Dask is often used when the data are too big to fit in memory. In these cases the data are split into chunks (partitions); each task operates on one chunk, and the partial results are then aggregated.</p>

<p>For chunked data, if each worker is able to comfortably hold one data chunk in memory and do some computation on that data, then the number of chunks should be a multiple of the number of workers. This ensures that there is always enough work for a worker to do.</p>

In [None]:
# LocalCUDACluster launches a scheduler and GPU workers on this single machine,
# restricted to the GPUs listed in CUDA_VISIBLE_DEVICES.
# Docs:
#   https://dask-cuda.readthedocs.io/en/stable/api.html#dask_cuda.LocalCUDACluster
#   https://docs.rapids.ai/api/dask-cuda/stable/api.html
#   https://docs.dask.org/en/latest/deploying.html
GPU_SPILL_BYTES = 7 * 1024 ** 3  # 7 GiB == 7516192768 bytes: spill to host memory past this (8 GB GPU)

cluster = LocalCUDACluster(
    # dashboard_address=':8800',
    CUDA_VISIBLE_DEVICES=CUDA_VISIBLE_DEVICES,
    rmm_pool_size=None,                   # do not initialise an RMM pool
    device_memory_limit=GPU_SPILL_BYTES,
    local_directory=dask_workdir,
)

# Creating a Client registers it as the default Dask scheduler, so every later
# .compute() call runs on this cluster.
# https://distributed.dask.org/en/latest/quickstart.html
client = Client(cluster)

In [None]:
# # Initialize RMM pool on ALL workers
# def _rmm_pool():
#     rmm.reinitialize(
#         pool_allocator=True,
#         initial_pool_size=None, # Use default size
#     )

# # https://medium.com/rapids-ai/reading-larger-than-memory-csvs-with-rapids-and-dask-e6e27dfa6c0f
# # https://github.com/rapidsai/rmm
# # https://docs.rapids.ai/api/rmm/stable/basics.html


# client.run(_rmm_pool)
# client

In [None]:
# Switch every worker's cuDF allocator from "default" to CUDA managed memory so
# GPU allocations can overflow into host RAM.
# https://github.com/rapidsai/cudf/blob/4e66281f48c55735edb4b610e0f859ee2de32a75/python/cudf/cudf/utils/utils.py#L193
# https://distributed.dask.org/en/stable/api.html#distributed.Client.run
# NOTE(review): newer RAPIDS releases may no longer ship cudf.set_allocator — confirm for the installed version.
per_worker_result = client.run(cudf.set_allocator, "managed")

client

In [None]:
# Reproducible cuDNN behaviour: no kernel autotuning, deterministic algorithms only.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda')

In [None]:
# ---- paths ----
data_dir = './'
model_dir_seg = './kaggle'

# ---- segmentation stage: each study is resampled to this volume ----
image_size_seg = (128, 128, 128)
msk_size = image_size_seg[0]   # side length of the predicted mask grid
batch_size_seg = 1
num_workers = 12               # DataLoader workers (previously 2)

# ---- classification crops ----
image_size_cls = 224           # crops are resized to image_size_cls x image_size_cls
n_slice_per_c = 15             # slabs sampled per mask channel
n_ch = 5                       # neighbouring slices stacked as channels in each slab

# MONAI Resize: 'trilinear' is the interpolation mode for 3-D output.
# NOTE(review): R is not referenced anywhere else in this notebook.
R = Resize([2, 224, 6], mode="trilinear")

In [None]:
# One row per study listed in train_seg.csv (train.csv is the other candidate source).
# NOTE: when DEBUG is False this notebook never defines `df`, so the cells below fail.
if DEBUG:
    seg_labels = pd.read_csv(os.path.join(data_dir, 'train_seg.csv'))
    study_uids = seg_labels['StudyInstanceUID'].unique().tolist()
    df = pd.DataFrame({'StudyInstanceUID': study_uids})
    df['image_folder'] = [os.path.join(data_dir, 'train_images', uid) for uid in study_uids]

df.tail()

In [None]:
df.head(n=5)  # first five studies and their image folders

# Dataset

In [None]:
def load_dicom(path):
    """Read one DICOM slice and resize it to the in-plane segmentation size.

    Args:
        path: path to a single DICOM file.

    Returns:
        The slice's pixel array resized with ``cv2.INTER_AREA`` to shape
        (image_size_seg[1], image_size_seg[0]).
    """
    # pydicom.read_file is a deprecated alias (removed in pydicom 3.0); dcmread is the supported reader.
    dicom = pydicom.dcmread(path)
    data = dicom.pixel_array
    # cv2.resize takes dsize as (width, height).
    return cv2.resize(data, (image_size_seg[0], image_size_seg[1]), interpolation=cv2.INTER_AREA)


def load_dicom_line_par(path):
    """Load a study as a fixed-depth stack of resized slices scaled to uint8.

    Slice files in ``path`` are ordered by their integer file stem, then
    ``image_size_seg[2]`` of them are picked at evenly spaced positions (rounded
    to the nearest existing slice), so every study yields the same depth.

    Args:
        path: directory containing one DICOM file per slice, named ``<int>.<ext>``.

    Returns:
        images: uint8 array of shape (image_size_seg[2], H, W), min-max scaled to [0, 255].
        indices: int array with the positions (in the sorted file list) of the picked slices.

    Raises:
        FileNotFoundError: if ``path`` contains no files.
    """
    t_paths = sorted(
        glob(os.path.join(path, "*")),
        # basename rather than split('/') so the numeric sort does not depend on the path separator
        key=lambda p: int(os.path.basename(p).split(".")[0]),
    )
    n_scans = len(t_paths)
    if n_scans == 0:
        raise FileNotFoundError(f"no DICOM slices found in {path!r}")

    # Evenly spaced positions over [0, n_scans - 1], rounded to an existing slice index.
    indices = np.quantile(list(range(n_scans)), np.linspace(0., 1., image_size_seg[2])).round().astype(int)
    images = np.stack([load_dicom(t_paths[i]) for i in indices], 0)

    # Shift/scale in float: subtracting the minimum of a signed 16-bit array in
    # its own dtype can overflow when the intensity range exceeds that dtype.
    images = images.astype(np.float64)
    images = images - np.min(images)
    images = images / (np.max(images) + 1e-4)
    images = (images * 255).astype(np.uint8)

    return images, indices


class SegTestDataset(Dataset):
    """Inference dataset yielding one resampled 3-D study per row of ``df``.

    Each item is ``(volume, slice_indices)``:
      * ``volume``: float32 tensor of shape (3, D, H, W) with values in [0, 1];
        the grey volume from ``load_dicom_line_par`` repeated to 3 channels.
      * ``slice_indices``: positions of the picked slices in the study's sorted
        file list, as returned by ``load_dicom_line_par``.

    ``df`` must provide an ``image_folder`` column.
    """

    def __init__(self, df):
        # drop=True: renumber rows 0..n-1 without adding a stray 'index' column.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        row = self.df.iloc[index]

        # Separate name so the slice positions do not shadow the ``index`` argument.
        image, slice_indices = load_dicom_line_par(row.image_folder)
        if image.ndim < 4:
            # (D, H, W) -> (3, D, H, W): the encoder is built with in_chans=3.
            image = np.expand_dims(image, axis=0).repeat(3, axis=0)

        image = image / 255.
        return torch.tensor(image).float(), slice_indices


In [None]:
# dataset_seg = SegTestDataset(df)

In [None]:
# if DEBUG:
#     rcParams['figure.figsize'] = 20,8
#     for i in range(1):
#         f, axarr = plt.subplots(1,4)
#         for p in range(4):
#             idx = i*4+p
#             img = dataset_seg[idx]
#             # img.shape => torch.Size([3, 128, 128, 128])

#             img = img[:, 60, :, :] # picking 60th dcm image of a patient/folder.
#             # img.shape => torch.Size([3, 128, 128])
            
#             axarr[p].imshow(img.transpose(0, 1).transpose(1,2).squeeze())
#             # img.transpose(0, 1).shape => torch.Size([128, 3, 128])
#             # img.transpose(0, 1).transpose(1,2).shape => torch.Size([128, 128, 3])

# Model

In [None]:
from conv2d_same import Conv2dSame
from conv3d_same import Conv3dSame


def _copy_inflated_conv_params(conv2d, conv3d):
    """Give ``conv3d`` the 2-D kernel repeated along a new trailing depth axis, plus the 2-D bias."""
    depth = conv2d.kernel_size[0]
    # The kernel is repeated, not divided by depth, matching the original inflation scheme.
    conv3d.weight = torch.nn.Parameter(conv2d.weight.unsqueeze(-1).repeat(1, 1, 1, 1, depth))
    if conv2d.bias is not None:
        # The freshly constructed 3-D conv has a randomly initialised bias; reuse the trained one.
        conv3d.bias = conv2d.bias


def convert_3d(module):
    """Recursively inflate a 2-D network into its 3-D counterpart.

    Supported layers are swapped for their 3-D equivalents:

    * BatchNorm2d -> BatchNorm3d, sharing affine parameters and running statistics.
    * Conv2dSame -> Conv3dSame and Conv2d -> Conv3d; the k x k kernel is repeated
      k times along a new trailing depth axis and the bias (if any) is carried over.
    * MaxPool2d -> MaxPool3d and AvgPool2d -> AvgPool3d with the same settings.

    Other modules are kept. Their children are converted recursively and
    re-registered under the same names via ``add_module``, so state-dict keys do
    not change.

    NOTE(review): only index 0 of kernel_size/stride/padding/dilation is read,
    so square 2-D hyper-parameters are assumed.

    Args:
        module: the (sub)module to convert.

    Returns:
        The converted module (``module`` itself when its type is not swapped).
    """
    module_output = module
    if isinstance(module, torch.nn.BatchNorm2d):
        module_output = torch.nn.BatchNorm3d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats,
        )
        # weight/bias only exist when affine=True.
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, "qconfig"):
            module_output.qconfig = module.qconfig

    elif isinstance(module, Conv2dSame):
        # Tested before nn.Conv2d: Conv2dSame is presumably a Conv2d subclass.
        module_output = Conv3dSame(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
        )
        _copy_inflated_conv_params(module, module_output)

    elif isinstance(module, torch.nn.Conv2d):
        module_output = torch.nn.Conv3d(
            in_channels=module.in_channels,
            out_channels=module.out_channels,
            kernel_size=module.kernel_size[0],
            stride=module.stride[0],
            padding=module.padding[0],
            dilation=module.dilation[0],
            groups=module.groups,
            bias=module.bias is not None,
            padding_mode=module.padding_mode,
        )
        _copy_inflated_conv_params(module, module_output)

    elif isinstance(module, torch.nn.MaxPool2d):
        module_output = torch.nn.MaxPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            dilation=module.dilation,
            ceil_mode=module.ceil_mode,
        )
    elif isinstance(module, torch.nn.AvgPool2d):
        module_output = torch.nn.AvgPool3d(
            kernel_size=module.kernel_size,
            stride=module.stride,
            padding=module.padding,
            ceil_mode=module.ceil_mode,
        )

    for name, child in module.named_children():
        module_output.add_module(name, convert_3d(child))

    return module_output



class TimmSegModel(nn.Module):
    """timm encoder + segmentation_models_pytorch U-Net decoder predicting 7 mask channels.

    The network is built in 2-D and is meant to be inflated with ``convert_3d``
    before use; ``forward`` pads 5-D (N, C, D, H, W) feature maps.

    Args:
        backbone: timm model name, instantiated with ``features_only=True``.
        segtype: decoder type; only ``'unet'`` is implemented.
        pretrained: whether timm loads pretrained encoder weights.
        n_blocks: number of encoder stages fed to the decoder. The default of 4
            matches the notebook-level ``n_blocks`` used with this checkpoint.

    Raises:
        ValueError: for any ``segtype`` other than ``'unet'``.
    """

    def __init__(self, backbone, segtype='unet', pretrained=False, n_blocks=4):
        super(TimmSegModel, self).__init__()
        if segtype != 'unet':
            raise ValueError(f"unsupported segtype {segtype!r}; only 'unet' is implemented")
        self.n_blocks = n_blocks

        self.encoder = timm.create_model(
            backbone,
            in_chans=3,
            features_only=True,
            pretrained=pretrained
        )
        # Probe the (still 2-D) encoder once to read the channel count of every stage.
        probe = self.encoder(torch.rand(1, 3, 64, 64))
        encoder_channels = [1] + [feat.shape[1] for feat in probe]
        decoder_channels = [160, 64, 48, 24, 16]
        self.decoder = smp.unet.decoder.UnetDecoder(
            encoder_channels=encoder_channels[:n_blocks + 1],
            decoder_channels=decoder_channels[:n_blocks],
            n_blocks=n_blocks,
            attention_type='scse',
        )
        self.segmentation_head = nn.Conv2d(decoder_channels[n_blocks - 1], 7, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    def forward(self, x):
        """Return raw (pre-sigmoid) 7-channel mask logits for the input volume ``x``."""
        enc_features = self.encoder(x)[:self.n_blocks]
        # With the 128^3 input used here the stages come out one voxel short of an
        # even size (63/31/15/7). Append one zero plane at the end of every spatial
        # axis, presumably so the skip connections match the decoder's upsampled maps.
        # F.pad's tuple lists (left, right) pairs starting from the last dimension.
        enc_features = [F.pad(feat, (0, 1, 0, 1, 0, 1)) for feat in enc_features]
        # Placeholder in the first slot; the decoder is presumed not to use it.
        global_features = [0] + enc_features
        seg_features = self.decoder(*global_features)
        return self.segmentation_head(seg_features)
  


# Load Models

In [None]:
# Load every cross-validation fold of the 3-D segmentation model; their predictions are averaged later.
models_seg = []

kernel_type = 'timm3d_effv2_unet4b_128_128_128_dsv2_flip12_shift333p7_gd1p5_bs4_lr3e4_20x50ep'
backbone = 'tf_efficientnetv2_s_in21ft1k'

n_blocks = 4  # encoder stages fed to the U-Net decoder (presumably the "unet4b" in kernel_type)
n_folds = 5
for fold in range(n_folds):
    model = TimmSegModel(backbone, pretrained=False)
    model = convert_3d(model)
    model = model.to(device)
    load_model_file = os.path.join(model_dir_seg, f'{kernel_type}_fold{fold}_best.pth')
    # map_location='cpu': checkpoints saved on another GPU index still load, and
    # load_state_dict copies the tensors into the parameters already on `device`.
    # NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
    sd = torch.load(load_model_file, map_location='cpu')

    # Some checkpoints wrap the weights together with extra training information.
    if 'model_state_dict' in sd.keys():
        sd = sd['model_state_dict']

    # Strip the 'module.' prefix that (Distributed)DataParallel adds to parameter names.
    sd = {k[7:] if k.startswith('module.') else k: v for k, v in sd.items()}

    model.load_state_dict(sd, strict=True)
    model.eval()
    models_seg.append(model)
len(models_seg)

# Predict

In [None]:
# dataset_seg = SegTestDataset(df[6:7])
# loader_seg = torch.utils.data.DataLoader(dataset_seg, batch_size=batch_size_seg, shuffle=False, num_workers=num_workers)

In [None]:
# # plotting some prediction


# plt.rcParams["figure.figsize"] = (20,8)
# bar = tqdm(loader_seg)
# # bar => 0%|                                   | 0/2019 [00:00<?, ?it/s]
# # type(bar) => <class 'tqdm.std.tqdm'>

# with torch.no_grad():
#     f, axarr = plt.subplots(1,4)    
    
#     for batch_id, images in enumerate(bar):
#         # images.shape => torch.Size([1, 3, 128, 128, 128])
        
#         images = images.cuda()

#         # SEG
#         pred_masks = []
#         for model in models_seg:
#             pmask = model(images).sigmoid()
#             pred_masks.append(pmask)
            
#             # pred_masks[0].shape => torch.Size([1, 7, 128, 128, 128])
#             # torch.stack(pred_masks, 0).shape) => torch.Size([5, 1, 7, 128, 128, 128])
# #         pred_masks = torch.stack(pred_masks, 0).mean(0).cpu().numpy()            
#         pred_masks = torch.stack(pred_masks, 0).mean(dim = 0).cpu() #  taking mean of all 5 predictions
#             # pred_masks.shape => (1, 7, 128, 128, 128)
        
#         images = images.cpu()
#         images = images.squeeze() # numpy like squeeze on torch tensor.        
#         # images.shape => torch.Size([3, 128, 128, 128])               
        
#         # converting mask_values to 0.0 and 1.0
#         pred_masks = (pred_masks>0.5).float().squeeze()
#             # pred_masks.shape => torch.Size([7, 128, 128, 128])         
        
#         c = 0
#         for i in range(0,2):
#             if i==0:
#                 img = images[:, 60, :, :].detach().clone() # checking 60th dcm image for a particular patient       
#                 masks = pred_masks[:, 60, :, :].detach().clone() # checking corresponding 60th slice 
#                     # images.shape, pred_masks.shape => (3, 128, 128) (7, 128, 128)        
#             else:
#                 img = images[:, :, :, 60].detach().clone()
#                 masks = pred_masks[:, :, :, 60].detach().clone()          

#             # merging 7 channels in order to reduce to 3
#             masks[0] = masks[0] + masks[3] + masks[6] # merging C1, C4 and C7
#             masks[1] = masks[1] + masks[4] # merging C2, C5
#             masks[2] = masks[2] + masks[5] # merging C3, C6

#             masks = masks[:3] # selecting only 3 sequence / channels out of 7.        

#             axarr[c].imshow(img.transpose(0, 1).transpose(1,2)) # squeeze(); removes axes of length one.
#             axarr[c+1].imshow(masks.transpose(0, 1).transpose(1,2))         
#             c += 2
        
#         del img, masks,# images#, pred_masks
        
#         torch.cuda.empty_cache()
#         _ = gc.collect()
        

#         break          
        


# Predict Mask and Crop Slices on Train Data

In [None]:
def load_bone(msk, cid, t_paths, cropped_images, index, inst_id):
    """Build and cache the classifier slabs for mask channel ``cid`` of one study.

    Meant to run in a worker thread (see ``load_cropped_images``). Concurrency
    is bounded by the module-level semaphore ``sema``.

    Steps:
      1. Threshold ``msk[cid]`` at 0.2 (falling back to 0.02 if nothing survives)
         and take the bounding box of the positive voxels on the mask grid.
      2. Map the box's depth range onto the full slice list and pick
         ``n_slice_per_c`` evenly spaced slices.
      3. For each pick, stack ``n_ch`` neighbouring full-resolution slices as
         channels, min-max scale to uint8, crop to the in-plane box, resize to
         ``image_size_cls``, and append the matching mask slice as a last channel.

    On any failure (e.g. an empty mask) the channel gets all-zero slabs instead
    and a warning is emitted. The result is stored in ``cropped_images[cid]`` as
    a uint8 tensor of shape (n_slice_per_c, image_size_cls, image_size_cls, n_ch + 1)
    and saved to ``<data_dir>/numpy_2/<inst_id>_<cid + 1>.npy``.

    Args:
        msk: predicted soft masks for the study, shape (7, D, H, W) on the mask grid.
        cid: mask channel to process (0..6).
        t_paths: the study's slice files, sorted by integer file name.
        cropped_images: shared list; slot ``cid`` receives the result.
        index: unused; kept for call compatibility.
        inst_id: study id used in the cache file name.
    """
    with sema:  # released even if anything below raises
        try:
            bone = _crop_bone_slabs(msk, cid, t_paths)
        except Exception as exc:
            # Best effort: fall back to zeros. Start from a fresh list so slabs
            # built before the failure are not mixed with the padding.
            warnings.warn(f"load_bone: study {inst_id} channel {cid + 1} -> zero slabs ({exc!r})")
            bone = [torch.zeros((image_size_cls, image_size_cls, n_ch + 1)).to(torch.uint8)
                    for _ in range(n_slice_per_c)]

        cropped_images[cid] = torch.stack(bone, 0)

        # Local cache of the crops.
        image_file = os.path.join(data_dir, f'numpy_2/{inst_id}_{cid+1}.npy')
        os.makedirs(os.path.dirname(image_file), exist_ok=True)
        np.save(image_file, cropped_images[cid].numpy())


def _bone_bbox(mask):
    """Return (x1, x2, y1, y2, z1, z2), the bounding box of a soft mask of shape (D, H, W).

    Tries threshold 0.2 (best_threshold from train_1_efficient.ipynb), then 0.02.
    Raises IndexError if even the loose mask is empty.
    """
    for threshold in (0.2, 0.02):
        positive = mask > threshold
        z = np.where(positive.sum(1).sum(1) > 0)[0]  # mask axis 0 (slices)
        x = np.where(positive.sum(0).sum(1) > 0)[0]  # mask axis 1 (image rows)
        y = np.where(positive.sum(0).sum(0) > 0)[0]  # mask axis 2 (image columns)
        if len(x) and len(y) and len(z):
            break
    depth, rows, cols = mask.shape
    x1, x2 = max(0, x[0]), min(rows - 1, x[-1])
    y1, y2 = max(0, y[0]), min(cols - 1, y[-1])
    z1, z2 = max(2, z[0]), min(depth - 3, z[-1])
    return x1, x2, y1, y2, z1, z2


def _read_slab(t_paths, center):
    """Stack the ``n_ch`` slices around position ``center`` as channels -> (H, W, n_ch).

    Neighbours outside the study, or files that cannot be read, become zero planes.
    Negative positions count as outside; plain list indexing would silently wrap
    around to the end of the study.
    """
    half = n_ch // 2
    planes = []
    for pos in range(center - half, center - half + n_ch):
        plane = None
        if 0 <= pos < len(t_paths):
            try:
                plane = pydicom.dcmread(t_paths[pos]).pixel_array
            except Exception:
                plane = None
        planes.append(plane)
    # Zero planes take the shape of the slices that were read (512 x 512 if none were).
    shape = next((p.shape for p in planes if p is not None), (512, 512))
    planes = [np.zeros(shape) if p is None else p for p in planes]
    return np.stack(planes, -1)


def _crop_bone_slabs(msk, cid, t_paths):
    """Build ``n_slice_per_c`` uint8 tensors of shape (image_size_cls, image_size_cls, n_ch + 1)."""
    n_scans = len(t_paths)
    x1, x2, y1, y2, z1, z2 = _bone_bbox(msk[cid])

    # Depth picks on the mask grid (inds_) and the matching positions in the full
    # slice list (inds), obtained by scaling with n_scans / msk_size.
    zz1, zz2 = int(z1 / msk_size * n_scans), int(z2 / msk_size * n_scans)
    inds_ = np.linspace(z1, z2, n_slice_per_c).astype(int)
    inds = np.linspace(zz1, zz2, n_slice_per_c).astype(int)

    size = (image_size_cls, image_size_cls)
    bone = []
    for ind, ind_ in zip(inds, inds_):
        # Float first: min-shifting a signed 16-bit stack in its own dtype can overflow.
        data = _read_slab(t_paths, ind).astype(np.float64)
        data = data - np.min(data)
        data = data / (np.max(data) + 1e-4)
        data = (data * 255).astype(np.uint8)  # convert to uint8 before any resizing

        # Scale the mask-grid box to the full-resolution slice and crop.
        xx1 = int(x1 / msk_size * data.shape[0])
        xx2 = int(x2 / msk_size * data.shape[0])
        yy1 = int(y1 / msk_size * data.shape[1])
        yy2 = int(y2 / msk_size * data.shape[1])
        data = data[xx1:xx2, yy1:yy2]
        data = np.stack([cv2.resize(data[:, :, c], size, interpolation=cv2.INTER_LINEAR) for c in range(n_ch)], -1)

        # Last channel: channel cid's soft mask on the same crop, scaled to 0..255.
        msk_this = msk[cid, ind_, x1:x2, y1:y2]
        msk_this = (msk_this * 255).astype(np.uint8)
        msk_this = cv2.resize(msk_this, size, interpolation=cv2.INTER_LINEAR)

        data = np.concatenate([data, msk_this[:, :, np.newaxis]], -1)
        bone.append(torch.tensor(data))
    return bone

def load_cropped_images(msk, image_folder, index, inst_id, n_ch=n_ch):
    """Crop the slabs of all 7 mask channels of one study in parallel and concatenate them.

    One ``load_bone`` thread is started per channel. Each fills its own slot of a
    per-call result list and writes its ``.npy`` cache file.

    Args:
        msk: predicted soft masks for one study, shape (7, D, H, W).
        image_folder: directory with the study's slices named ``<int>.<ext>``.
        index: forwarded to ``load_bone`` (which does not use it).
        inst_id: study id used for the cache file names.
        n_ch: unused; kept for backward compatibility.

    Returns:
        Tensor of shape (7 * n_slice_per_c, image_size_cls, image_size_cls, n_ch + 1),
        channel 0's slabs first.
    """
    t_paths = sorted(
        glob(os.path.join(image_folder, "*")),
        key=lambda p: int(os.path.basename(p).split(".")[0]),
    )

    # Per-call containers instead of notebook globals, so consecutive calls
    # (e.g. several studies in one batch) cannot clobber each other's results.
    n_bones = 7
    cropped_images = [None] * n_bones
    threads = [
        threading.Thread(target=load_bone, args=(msk, cid, t_paths, cropped_images, index, inst_id))
        for cid in range(n_bones)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()

    return torch.cat(cropped_images, 0)


In [None]:
# At most 12 load_bone calls may hold the semaphore (and do slice I/O) at once.
sema = threading.Semaphore(value=12)

# Every study in df order (shuffle=False), so batch positions map back to df rows.
dataset_seg = SegTestDataset(df)
loader_seg = torch.utils.data.DataLoader(
    dataset_seg,
    batch_size=batch_size_seg,
    shuffle=False,
    num_workers=num_workers,
)

In [None]:

# Segment every study with the fold ensemble, then crop and cache the per-channel slabs.
bar = tqdm(loader_seg)
with torch.no_grad():
    for batch_id, (images, indices) in enumerate(bar):
        indices = indices.numpy()
        images = images.to(device)

        # SEG: mean of the per-fold sigmoid masks -> numpy array (batch, 7, D, H, W).
        pred_masks = torch.stack([model(images).sigmoid() for model in models_seg], 0).mean(0).cpu().numpy()

        for i in range(pred_masks.shape[0]):
            # shuffle=False, so the batch position maps straight back to the df row.
            row = df.iloc[batch_id * batch_size_seg + i]
            # Reset per study: load_cropped_images may use these module-level names as scratch space.
            threads = [None] * 7
            cropped_images = [None] * 7
            cropped_images = load_cropped_images(pred_masks[i], row.image_folder, indices[i], row.StudyInstanceUID)



In [None]:
# [113 117 122 126 131 135 140 144 149 153 158 162 167 171 176]
# [114, 118, 122, 126, 130, 134, 138, 146, 150, 154, 157, 161, 165, 169, 177]

In [None]:
# print(cropped_images.shape)
# # print(torch.min(cropped_images), torch.max(cropped_images))

In [None]:
# Visual check of the crops of the last processed study, one mask channel at a time.
# cropped_images has 7 * 15 rows: rows 15*k .. 15*k + 14 belong to channel k (C1..C7).
# Each row holds n_ch neighbouring slices followed by the mask as the last channel.
# If the crops look wrong, revisit the mask thresholds in load_bone (0.2, falling back to 0.02).
#
# C2 (channel index 1 -> rows 15..29): plot rows 17..26.
plt.rcParams["figure.figsize"] = (20,8)
for i in range(17, 27):
    fig, axes = plt.subplots(1, n_ch + 1)
    fig.suptitle(f"C2 - cropped_images[{i}]")
    for j, ax in enumerate(axes):
        ax.imshow(cropped_images[i][:, :, j])
        ax.set_title("mask" if j == n_ch else f"slice {j - n_ch // 2:+d}")

In [None]:
# C1 (channel index 0 -> rows 0..14): plot rows 1..10, one figure per slab
# showing its stacked slices followed by the mask channel.
plt.rcParams["figure.figsize"] = (20,8)
for slab_idx in range(1, 11):
    _, axes = plt.subplots(1, 6)
    for ch, ax in enumerate(axes):
        ax.imshow(cropped_images[slab_idx][:, :, ch])

In [None]:
# C5 (channel index 4 -> rows 60..74 of cropped_images): plot rows 65..74.
# Each figure shows the n_ch stacked slices followed by the mask channel.
plt.rcParams["figure.figsize"] = (20,8)
for i in range(65, 75):
    fig, axes = plt.subplots(1, n_ch + 1)
    fig.suptitle(f"C5 - cropped_images[{i}]")
    for j, ax in enumerate(axes):
        ax.imshow(cropped_images[i][:, :, j])
        ax.set_title("mask" if j == n_ch else f"slice {j - n_ch // 2:+d}")

In [None]:
# C7 (channel index 6 -> rows 90..104): plot rows 95..104, one figure per slab
# showing its stacked slices followed by the mask channel.
plt.rcParams["figure.figsize"] = (20,8)
for slab_idx in range(95, 105):
    _, axes = plt.subplots(1, 6)
    for ch, ax in enumerate(axes):
        ax.imshow(cropped_images[slab_idx][:, :, ch])

In [None]:
# C6 (channel index 5 -> rows 75..89 of cropped_images): plot rows 80..89.
# Each figure shows the n_ch stacked slices followed by the mask channel.
plt.rcParams["figure.figsize"] = (20,8)
for i in range(80, 90):
    fig, axes = plt.subplots(1, n_ch + 1)
    fig.suptitle(f"C6 - cropped_images[{i}]")
    for j, ax in enumerate(axes):
        ax.imshow(cropped_images[i][:, :, j])
        ax.set_title("mask" if j == n_ch else f"slice {j - n_ch // 2:+d}")