In [1]:
import numpy as np
import pandas as pd
# pd.options.plotting.backend = "plotly"
import random
from glob import glob
import os, shutil
from tqdm import tqdm
tqdm.pandas()
import time
import copy
# import joblib
from collections import defaultdict
import gc
from IPython import display as ipd

# visualization
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Sklearn
# from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold

# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# class indices

LARGE_BOWEL = 0
SMALL_BOWEL = 1
STOMACH = 2
MASK_INDICES = {'large_bowel': LARGE_BOWEL, 'small_bowel':SMALL_BOWEL, 'stomach':STOMACH}

In [3]:
def load_scans(scan_path):
    image_files = [f for f in os.listdir(scan_path) if os.path.isfile(os.path.join(scan_path, f))]
    
    image_files.sort(key=lambda x: int(x.split('_')[1]))
    
    scan_slices = []
    for img_file in image_files:
        img_file_path = os.path.join(scan_path, img_file)
        scan_slice = cv2.imread(img_file_path, cv2.IMREAD_UNCHANGED)
        scan_slice = scan_slice.astype('float32')
        scan_slices.append(scan_slice)
    
    img = np.stack(scan_slices)
    max_val = np.max(img)
    if max_val:
        img /= max_val

    return img

def load_msk(path):
    msk = np.load(path)
    msk = msk.astype('float32')
    return msk

## Dataset

In [4]:
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, df, label=True, transforms=None):
        self.df         = df
        self.label      = label
        self.img_paths  = df['image_path'].tolist()
        self.msk_paths  = df['mask_path'].tolist()
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        print('get item')
        img_path  = self.img_paths[index]
        img = []
        img = load_scans(img_path)
        print('scan loaded')
        
        if self.label:
            msk_path = self.msk_paths[index]
            msk = load_msk(msk_path)
            if self.transforms:
                data = self.transforms(image=img, mask=msk)
                img  = data['image']
                msk  = data['mask']
            return torch.tensor(img), torch.tensor(msk)
        else:
            if self.transforms:
                data = self.transforms(image=img)
                img  = data['image']
            return torch.tensor(img)

In [5]:
# Taking a look at the dataset
df_train = pd.read_csv('./input/uw-madison-gi-tract-image-segmentation/train.csv')
train_path = './input/uw-madison-gi-tract-image-segmentation/train'
df_train['case_day'] = df_train['id'].map(lambda x: x.split('_slice')[0])
df_train.drop(columns=['class', 'segmentation', 'id'], inplace=True)
df_train.drop_duplicates(inplace=True)
df_train['image_path'] = df_train['case_day'].map(lambda x: f'{train_path}/' + x.split('_')[0] + f'/{x}/scans')
df_train['mask_path'] =  df_train['case_day'].map(lambda x: f'{train_path}/' + x.split('_')[0] + f'/{x}/masks3D/{x}.npy')
df_train.head(4)

Unnamed: 0,case_day,image_path,mask_path
0,case123_day20,./input/uw-madison-gi-tract-image-segmentation...,./input/uw-madison-gi-tract-image-segmentation...
432,case123_day22,./input/uw-madison-gi-tract-image-segmentation...,./input/uw-madison-gi-tract-image-segmentation...
864,case123_day0,./input/uw-madison-gi-tract-image-segmentation...,./input/uw-madison-gi-tract-image-segmentation...
1296,case77_day20,./input/uw-madison-gi-tract-image-segmentation...,./input/uw-madison-gi-tract-image-segmentation...


In [6]:
train_dataset = ImageDataset(df_train)

In [7]:
img, msk = train_dataset[20]

get item
scan loaded


In [8]:
img.shape

torch.Size([144, 266, 266])

In [9]:
msk.shape

torch.Size([266, 266, 144, 3])

In [10]:
gc.collect()

11

In [11]:
if __name__ == '__main__':
    train_loader = DataLoader(train_dataset, batch_size=2, num_workers=1, drop_last=True)
    imgs, msks = next(iter(train_loader))
    imgs.size(), msks.size()

RuntimeError: DataLoader worker (pid(s) 10920) exited unexpectedly