In [204]:
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
import numpy as np
import os
import cv2
from matplotlib.patches import Rectangle
from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from datetime import datetime
import json,itertools
from typing import Optional
from glob import glob
from tensorflow import keras
import tensorflow as tf
import keras
from sklearn.model_selection import StratifiedKFold
import matplotlib.gridspec as gridspec
import matplotlib.patches as mpatches
import matplotlib as mpl


In [205]:
# Root folders of the competition images. The second assignment used to
# re-bind `train_image_path`, silently replacing the train folder with the
# test folder; the two locations now live in separate names.
train_image_path = '../input/uw-madison-gi-tract-image-segmentation/train/'
test_image_path = '../input/uw-madison-gi-tract-image-segmentation/test/'

# Long-format annotations (the code below reads the `id` and `segmentation`
# columns) and the submission template.
df = pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/train.csv')
sample_submission = pd.read_csv('../input/uw-madison-gi-tract-image-segmentation/sample_submission.csv')
def get_case(row):
    """Return the case number encoded in an id such as ``case123_day20_slice_0001``.

    Args:
        row: id string whose first ``_``-separated token is ``case<number>``.

    Returns:
        The characters following the ``case`` prefix, as a string (``"123"``).
    """
    case_token = row.split("_")[0]
    # "case" is four characters long; slicing from 5 dropped the first digit
    # ("case123" -> "23"), which made different cases collide.
    return case_token[len("case"):]
def get_day(row):
    """Return the day number from an id such as ``case123_day20_slice_0001``.

    The second ``_``-separated token is taken and its first three characters
    (the ``day`` prefix) are removed.
    """
    day_token = row.split("_")[1]
    return day_token[len("day"):]
def get_slice(row):
    """Return the slice number from an id such as ``case123_day20_slice_0001``.

    The third and fourth ``_``-separated tokens are concatenated
    (``"slice" + "0001"``) and the first five characters are dropped.
    """
    tokens = row.split("_")
    merged = "{}{}".format(tokens[2], tokens[3])
    return merged[5:]

# Training image root and every PNG found four directory levels below it.
path = '../input/uw-madison-gi-tract-image-segmentation/train'
train_image_paths = glob(os.path.join(path, '**', '**', '*', '*.png'))

# Split the composite id into its case / day / slice parts.
df['case'] = df['id'].apply(get_case)
df['day'] = df['id'].apply(get_day)
df['slice'] = df['id'].apply(get_slice)

# Rows without / with an RLE annotation.
df_without = df[df['segmentation'].isna()]
df_with = df[df['segmentation'].notna()]


In [206]:
def get_temp_path(temp):
    """Build the path prefix (without size suffix or extension) for an id.

    For ``case123_day20_slice_0001`` the result is
    ``<path>/case123/case123_day20/scans/slice_0001`` where ``path`` is the
    module-level training root. It is used as the join key against the real
    file names.
    """
    tokens = temp.split("_")
    case, day, slice_no = tokens[0], tokens[1], tokens[-1]
    return f"{path}/{case}/{case}_{day}/scans/slice_{slice_no}"
# Join key: expected path prefix of the scan that belongs to each id.
df['temp_path'] = df['id'].apply(get_temp_path)

def get_path_list(temp):
    """Reduce a real image path to ``<directory>/slice_<number>``.

    Only the second ``_``-separated token of the file name is kept, so the
    result has the same form as the value produced by ``get_temp_path``.
    """
    directory, _, filename = temp.rpartition("/")
    slice_no = filename.split("_")[1]
    return f"{directory}/slice_{slice_no}"

# Map each file found on disk to its join key, then attach the real file
# path to every annotation row (left join keeps rows without a match).
temp_orginal_path = list(map(get_path_list, train_image_paths))
original_path_df = pd.DataFrame({
    'original_path': train_image_paths,
    'temp_path': temp_orginal_path,
})
df = df.merge(original_path_df, how='left', on='temp_path')

In [207]:
def get_width(path):
    """Return the third ``_``-separated token of the file name (as a string).

    For ``.../slice_0001_266_234_1.50_1.50.png`` this is ``"266"``.
    """
    filename = path.rsplit("/", 1)[-1]
    return filename.split("_")[2]
def get_height(path):
    """Return the fourth ``_``-separated token of the file name (as a string).

    For ``.../slice_0001_266_234_1.50_1.50.png`` this is ``"234"``.
    """
    filename = path.rsplit("/", 1)[-1]
    return filename.split("_")[3]

# Image dimensions are parsed from the file name of each scan.
df['width'] = df['original_path'].apply(get_width)
df['height'] = df['original_path'].apply(get_height)

In [217]:
# Collapse the long annotation table (three consecutive rows per id) into one
# row per slice with one RLE column per organ.
# NOTE(review): the stride-3 slicing assumes every id occupies exactly three
# consecutive rows in the order large bowel, small bowel, stomach — confirm
# against the `class` column of train.csv.
final_df = pd.DataFrame({'id' : df['id'][::3]})
final_df['lb_seg'] = df['segmentation'][::3].values
final_df['sb_seg'] = df['segmentation'][1::3].values
final_df['st_seg'] =  df['segmentation'][2::3].values
final_df['case'] =  df['case'][::3].values
final_df['day'] =  df['day'][::3].values
final_df['slice'] =  df['slice'][::3].values
final_df['original_path'] =  df['original_path'][::3].values
final_df['width'] =  df['width'][::3].values
final_df['height'] =  df['height'][::3].values

train_indexes = []
# All slices are kept, including those without any annotation.
df_train = final_df.copy()
df_train=df_train.reset_index(drop=True)

# A slice is "empty" when none of its three organs has a mask. This must be
# derived from the wide frame itself: assigning `df['segmentation'].isna()`
# aligned the 3N-row long frame to the N-row wide frame by index, so row i
# received the flag of an unrelated annotation row.
df_train['empty'] = (
    df_train[['lb_seg', 'sb_seg', 'st_seg']].isna().all(axis=1).astype(int)
)

# Missing values become '' (rle_decode treats an empty string as "no mask").
df_train = df_train.fillna('')

In [218]:
def rle_decode(temp, shape):
    """Decode a run-length-encoded mask into a 0/1 array.

    Args:
        temp: space-separated ``start length start length ...`` string.
            Starts are 1-indexed positions in the flattened (row-major)
            mask, as in the competition's annotation format. An empty string
            or a non-string value (e.g. NaN) yields an all-zero mask.
        shape: shape of the returned array.

    Returns:
        ``float64`` array of ``shape`` with 1 inside the runs, 0 elsewhere.
    """
    flat = np.zeros(shape=shape).reshape(-1)
    tokens = temp.split() if isinstance(temp, str) else []
    starts = list(map(int, tokens[0::2]))
    lengths = list(map(int, tokens[1::2]))
    for start, length in zip(starts, lengths):
        # RLE starts count from 1; convert to a 0-based offset. Using the
        # start directly shifted every run by one pixel.
        begin = start - 1
        flat[begin:begin + length] = 1
    return flat.reshape(shape)

def rle_encode(arr):
    """Run-length encode a binary mask.

    Args:
        arr: array of any shape; non-zero entries are foreground. It is
            flattened in row-major order.

    Returns:
        Space-separated ``start length`` pairs with 1-indexed starts
        (the inverse of ``rle_decode``); ``""`` for an empty mask.
    """
    flat = (np.asarray(arr).reshape(-1) != 0).astype(np.uint8)
    # Pad with background on both sides so runs touching the first or last
    # pixel still produce a start/end transition. Without padding such runs
    # were dropped or paired with the wrong boundary.
    padded = np.concatenate([[0], flat, [0]])
    # Index i in `padded` is position i (1-based) in `flat`, so the +1 turns
    # each transition into a 1-indexed start or a one-past-the-end position.
    changes = np.where(padded[1:] != padded[:-1])[0] + 1
    starts = changes[0::2]
    lengths = changes[1::2] - starts
    runs = np.empty(starts.size * 2, dtype=np.int64)
    runs[0::2] = starts
    runs[1::2] = lengths
    return ' '.join(map(str, runs.tolist()))


def show_img(img, mask=None):
    """Draw an image on the current matplotlib axes, optionally with a mask.

    Args:
        img: image array passed to ``plt.imshow`` with the ``bone`` colormap.
        mask: optional array drawn semi-transparently on top of ``img``.
            A legend with red / green / blue patches labelled large bowel /
            small bowel / stomach is added when a mask is given.
    """
    # The unused CLAHE object that was created here on every call has been
    # removed; it was never applied to the image.
    plt.imshow(img, cmap='bone')

    if mask is not None:
        plt.imshow(mask, alpha=0.5)
        legend_colors = [(0.667,0.0,0.0), (0.0,0.667,0.0), (0.0,0.0,0.667)]
        handles = [Rectangle((0,0),1,1, color=_c) for _c in legend_colors]
        labels = ["Large Bowel", "Small Bowel", "Stomach"]
        plt.legend(handles,labels)
    plt.axis('off')

In [219]:
class CFG:
    """Notebook-wide configuration constants."""
    # Target size that images are resized to; unpacked into A.Resize and
    # passed to cv2.resize further down in this notebook.
    img_size = [224, 224]

In [220]:
def _build_data_transforms(size):
    """Create the albumentations pipelines for the "train" and "valid" splits.

    Both pipelines start by resizing to ``size``; the training pipeline then
    adds random flips, affine jitter, grid/elastic distortion and coarse
    dropout.
    """
    train_steps = [
        A.Resize(*size, interpolation=cv2.INTER_NEAREST),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
        A.OneOf(
            [
                A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
                A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0),
            ],
            p=0.25,
        ),
        A.CoarseDropout(
            max_holes=8,
            max_height=size[0]//20,
            max_width=size[1]//20,
            min_holes=5,
            fill_value=0,
            mask_fill_value=0,
            p=0.5,
        ),
    ]
    valid_steps = [
        A.Resize(*size, interpolation=cv2.INTER_NEAREST),
    ]
    return {
        "train": A.Compose(train_steps, p=1.0),
        "valid": A.Compose(valid_steps, p=1.0),
    }


data_transforms = _build_data_transforms(CFG.img_size)

In [275]:
class Segmentation_data(torch.utils.data.Dataset):
    """Dataset yielding a normalised 3-channel image and, for training, its masks.

    Args:
        df: frame with an ``original_path`` column and, when ``flag`` is
            ``'train'``, ``height``, ``width``, ``lb_seg``, ``sb_seg`` and
            ``st_seg`` columns (RLE strings, ``''`` for no mask).
        flag: ``'train'`` returns ``(image, masks)``; any other value returns
            only the image.
        transforms: optional callable invoked as
            ``transforms(image=..., mask=...)`` on channel-last arrays and
            returning a mapping with ``'image'`` and ``'mask'``.

    Items are channel-first tensors sized by ``CFG.img_size``.
    """

    # One mask channel per organ, in this order.
    MASK_COLUMNS = ('lb_seg', 'sb_seg', 'st_seg')

    def __init__(self, df, flag, transforms=None):
        self.flag=flag
        self.df = df
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        # CFG.img_size is unpacked positionally into A.Resize, i.e. it is
        # (height, width); cv2.resize wants (width, height).
        out_h, out_w = CFG.img_size
        img = self.__load_image(row['original_path'])

        masks = np.zeros(shape=(out_h, out_w, len(self.MASK_COLUMNS)))
        if self.flag == 'train':
            h = int(row['height'])
            w = int(row['width'])
            for i, label in enumerate(self.MASK_COLUMNS):
                mask_ = rle_decode(row[label], (h, w))
                # Nearest-neighbour keeps the mask strictly 0/1; the default
                # bilinear interpolation produced fractional labels.
                masks[:, :, i] = cv2.resize(
                    mask_, (out_w, out_h), interpolation=cv2.INTER_NEAREST
                )

        # Augment while the arrays are still channel-last, and keep the
        # transformed mask (it used to be stored in an unused variable, so
        # image and mask went out of sync).
        if self.transforms:
            data = self.transforms(image=img, mask=masks)
            img = data['image']
            masks = data['mask']

        img = np.transpose(img, (2,0,1))      # HWC -> CHW for PyTorch
        masks = np.transpose(masks, (2,0,1))  # HWC -> CHW for PyTorch

        if self.flag == 'train':
            return torch.tensor(img), torch.tensor(masks)
        return torch.tensor(img)

    def __load_image(self, image_path):
        """Read a scan, min-max scale it to [0, 1], resize, and repeat to 3 channels."""
        img = cv2.imread(image_path,  cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {image_path!r}")
        img = img.astype('float32')
        lo, hi = img.min(), img.max()
        if hi > lo:
            img = (img - lo) / (hi - lo)
        else:
            # Constant image: avoid 0/0 and return an all-zero slice.
            img = np.zeros_like(img)
        out_h, out_w = CFG.img_size
        img = cv2.resize(img, (out_w, out_h))
        img = np.tile(img[...,None], [1, 1, 3])
        return img.astype('float32')

In [276]:
# 5-fold split stratified on the `empty` flag, keeping every `case` inside a
# single fold. Each row is tagged with the fold in which it is validated.
skf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
for fold, (_, val_idx) in enumerate(
    skf.split(df_train, df_train['empty'], groups=df_train["case"])
):
    df_train.loc[val_idx, 'fold'] = fold

# Number of rows per (fold, empty) combination.
display(df_train.groupby(['fold', 'empty'])['id'].count())

In [277]:
# Mini-batch size shared by the data loaders.
BATCH_SIZE = 64
# Un-augmented dataset over every row of df_train.
dataset = Segmentation_data(df=df_train, flag='train', transforms=None)

In [283]:
def prepare_loaders(fold, debug=False):
    """Build the training and validation loaders for one fold.

    Rows of ``df_train`` whose ``fold`` differs from ``fold`` are used for
    training, the remaining rows for validation. Both datasets use the
    ``'train'`` flag so that masks are returned; no transforms are applied.

    Args:
        fold: fold number held out for validation.
        debug: accepted but not used.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    held_out = df_train['fold'] == fold
    train_df = df_train[~held_out].reset_index(drop=True)
    valid_df = df_train[held_out].reset_index(drop=True)

    train_dataset = Segmentation_data(train_df, 'train')
    valid_dataset = Segmentation_data(valid_df, 'train')

    shared_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_dataset, shuffle=True, drop_last=False, **shared_kwargs)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **shared_kwargs)

    return train_loader, valid_loader

In [284]:
# Loaders with fold 0 held out for validation (the `debug` flag is currently
# ignored by prepare_loaders).
train_loader, valid_loader = prepare_loaders(fold=0, debug=True)

In [285]:
# Pull a single batch from the training loader for a visual sanity check.
imgs, msks = next(iter(train_loader))
# Not displayed: only the last expression of a cell is echoed, and this one
# is followed by more code.
imgs.size(), msks.size()

def plot_batch(imgs, msks, size=3):
    """Show the first ``size`` image/mask pairs of a batch side by side.

    Args:
        imgs: batch of channel-first image tensors.
        msks: batch of channel-first mask tensors, aligned with ``imgs``.
        size: number of samples to draw; capped at the batch length.
    """
    # The grid used to be fixed at five columns, so size > 5 raised inside
    # plt.subplot and size > len(imgs) indexed past the batch.
    n_panels = min(size, len(imgs))
    if n_panels <= 0:
        return
    plt.figure(figsize=(5 * n_panels, 5))
    for idx in range(n_panels):
        plt.subplot(1, n_panels, idx + 1)
        # CHW tensor -> HWC array for matplotlib.
        img = imgs[idx,].permute((1, 2, 0)).numpy()
        msk = msks[idx,].permute((1, 2, 0)).numpy()
        show_img(img, msk)
    plt.tight_layout()
    plt.show()

# Visual check: first five samples of the batch with their masks overlaid.
plot_batch(imgs, msks, size=5)