In [None]:
%%capture

import warnings 
warnings.filterwarnings("ignore")
import os
from os.path import join
import time
from tqdm import tqdm

import numpy as np
from numpy.random import choice
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import cohen_kappa_score

import PIL
from PIL import Image
import cv2


import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models as md



# Data pre-processing

In [None]:
# Root folder of the competition input; every path below is derived from it.
DATA_DIR = '../input/aptos2019-blindness-detection'

# Folder with the training images.
train_dir = os.path.join(DATA_DIR, 'train_images')
# Label table; later cells read its `id_code` and `diagnosis` columns.
label_df = pd.read_csv(os.path.join(DATA_DIR, 'train.csv'))


def train_validation_split(df, val_fraction=0.1, random_state=None):
    """
    Randomly split a label table into disjoint training and validation parts.

    :param df:           DataFrame with an `id_code` column identifying each row
    :param val_fraction: share of rows moved to the validation part; exactly
                         int(len(df) * val_fraction) ids are drawn
    :param random_state: optional seed for a private RandomState. When None,
                         the global numpy RNG is used (so np.random.seed applies).
    :return: tuple (train_df, val_df); row order of `df` is preserved in both
    """
    n_val = int(len(df) * val_fraction)
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    # replace=False is essential: numpy samples WITH replacement by default,
    # which yields duplicate ids and silently shrinks the validation set
    # below the requested size.
    val_ids = rng.choice(df.id_code.values, size=n_val, replace=False)

    in_val = df.id_code.isin(val_ids)
    return df[~in_val], df[in_val]


# Seed numpy's global RNG so the random train/validation split is the same
# on every "Restart & Run All"; otherwise validation scores of different
# runs are computed on different images and cannot be compared.
SPLIT_SEED = 42
np.random.seed(SPLIT_SEED)

train_df, val_df = train_validation_split(label_df)
print(train_df.shape, val_df.shape)
train_df.head()

In [None]:
def crop_image_from_gray(img, tol=7):
    """
    Crop away the dark border of an image.

    Every row and every column in which no gray value exceeds `tol` is
    removed. If nothing in the image is brighter than `tol`, the input is
    returned unchanged (for 2-D and 3-D input alike) instead of an empty array.

    Adapted from:
    https://www.kaggle.com/ratthachat/aptos-updatedv14-preprocessing-ben-s-cropping

    :param img: 2-D grayscale array, or 3-D array that is converted to gray
                with cv2.COLOR_RGB2GRAY (i.e. expected in RGB channel order)
    :param tol: brightness threshold; pixels with gray value <= tol count as
                background
    :return: cropped array (new array), the original `img` if it is entirely
             dark, or None if `img` is neither 2-D nor 3-D
    """
    if img.ndim == 2:
        gray_img = img
    elif img.ndim == 3:
        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    else:
        # Unsupported rank: keep the historic behaviour of returning None.
        return None

    mask = gray_img > tol
    keep_rows, keep_cols = mask.any(1), mask.any(0)

    if not keep_rows.any():
        # Image is so dark that everything would be cropped away.
        return img

    # Indexing the first two axes keeps the channel axis of 3-D input intact,
    # so no per-channel crop + stack is needed.
    return img[np.ix_(keep_rows, keep_cols)]


# Shared converter used by the helpers in this notebook to turn numpy image
# arrays (e.g. the output of cv2.cvtColor) into PIL images.
cv_to_pil = transforms.ToPILImage()

    
def center_crop(image: "PIL.Image.Image") -> "PIL.Image.Image":
    """
    Cut the largest centred square out of a (possibly rectangular) image.

    No resizing takes place, so differently sized inputs produce differently
    sized squares.

    :param image: PIL image; only `.size` (width, height) and `.crop(box)`
                  are used
    :return: whatever `image.crop` returns for the centred square box
    """
    width, height = image.size
    side = min(width, height)

    # One of the two margins is always 0; the other one centres the square
    # along the longer edge.
    left = (width - side) // 2
    top = (height - side) // 2

    return image.crop((left, top, left + side, top + side))


def process_image_ratio_invariant(cv2_image, size=256, do_center_crop=True):
    """
    Pre-process an image without distorting its aspect ratio.

    Pipeline: BGR -> RGB, crop the dark border, cut out the centre square,
    resize to (size, size), then blend the image with a Gaussian-blurred copy
    of itself (weights 4 / -4, offset 128; blur sigma = size / 30).

    :param cv2_image:      image array in BGR channel order
    :param size:           edge length of the square output image
    :param do_center_crop: if this is exactly `False`, stop after the
                           border crop
    :return: PIL image of size (size, size); or, when `do_center_crop is
             False`, the border-cropped RGB numpy array (not resized)
    """
    rgb = crop_image_from_gray(cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB))

    if do_center_crop is not False:
        # Largest centred square first, so the resize below cannot squash
        # the image in one direction.
        square = center_crop(cv_to_pil(rgb))
        resized = cv2.resize(np.array(square).copy(), (size, size))

        # Blur strength grows with the output size.
        blurred = cv2.GaussianBlur(resized, (0, 0), size/30)
        return cv_to_pil(cv2.addWeighted(resized, 4, blurred, -4, 128))

    return rgb

In [None]:
%%time

    
class Diabetic_Retionopathy_Data(Dataset):
    """
    Dataset of pre-processed images, optionally with their `diagnosis` label.

    Items are `(X, y, id)` when `train` is true and `(X, id)` otherwise, where
    `X` is the transformed image, `y` the diagnosis as float and `id` the file
    name without extension.
    """

    def __init__(self,
                 image_dir: str,
                 label_df: pd.DataFrame,
                 train=True,
                 transform=None,
                 sample_n=None,
                 in_memory=False,
                 write_images=False):
        """
        :param image_dir:    path to directory with images
        :param label_df:     DataFrame with an `id_code` column; only files whose
                             name (without extension) is listed there are used.
                             A `diagnosis` column is read when `train` is true.
                             The passed frame is not modified.
        :param train:        if true, items include the label
        :param transform:    callable applied to every pre-processed image;
                             None (default) means `transforms.ToTensor()`
        :param sample_n:     if not None, only use (at most) that many rows of
                             `label_df`, drawn at random
        :param in_memory:    if true, read and pre-process all images up front
                             and keep them in memory
        :param write_images: if true (and `in_memory`), save every pre-processed
                             image under its file name in the working directory
        """
        self.image_dir = image_dir
        # Resolve the default lazily instead of building a transform object
        # once at class-definition time.
        self.transform = transforms.ToTensor() if transform is None else transform
        self.train     = train
        self.in_memory = in_memory

        if sample_n:
            label_df = label_df.sample(n=min(sample_n, len(label_df)))

        ids = set(label_df.id_code)
        # sorted(): os.listdir returns entries in arbitrary order, which would
        # make the index -> image mapping differ between runs and machines.
        self.img_files = sorted(f for f in os.listdir(image_dir)
                                if f.split('.')[0] in ids)
        # set_index returns a new frame, so the caller's DataFrame keeps its
        # own index and its `id_code` column.
        self.label_df = label_df.set_index('id_code')

        if in_memory:

            self.id2image = {}
            for i, file_name in enumerate(self.img_files):

                if i and i % 500 == 0:
                    print(f'{i} / {len(self.img_files)}')

                image = self._read_process_image(join(image_dir, file_name))
                self.id2image[file_name.split('.')[0]] = image

                if write_images:
                    image.save(file_name)

        print(f'Initialized datatset with {len(self.img_files)} images.\n')

    @staticmethod
    def _read_process_image(file_path: str, size=256):
        """
        Read an image file and run `process_image_ratio_invariant` on it.

        :raises IOError: if the file cannot be read as an image
        """
        image = cv2.imread(file_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising;
            # fail here with the offending path rather than deep inside cv2.
            raise IOError(f'Could not read image file: {file_path}')
        return process_image_ratio_invariant(image, size=size)

    def __getitem__(self, idx):
        """Return `(X, y, id)` if `self.train` else `(X, id)` for position `idx`."""
        file_name = self.img_files[idx]
        id_ = file_name.split('.')[0]

        if self.in_memory:
            img = self.id2image[id_]
        else:
            img = self._read_process_image(join(self.image_dir, file_name))

        X = self.transform(img)

        if not self.train:
            return X, id_

        y = float(self.label_df.loc[id_, 'diagnosis'])
        return X, y, id_

    def __len__(self):
        """Number of image files that matched an id in the label table."""
        return len(self.img_files)


class RandomCenterCrop(transforms.CenterCrop):
    """
    Crops the PIL Image at the center with a randomly drawn square size.

    :param min_size: smallest possible edge length of the crop
    :param max_size: largest possible edge length of the crop (inclusive)
    """
    def __init__(self, min_size: int, max_size: int):
        if min_size > max_size:
            # Fail at construction time rather than on the first call.
            raise ValueError(f'min_size ({min_size}) must not exceed max_size ({max_size})')

        # Initialise the base class so inherited state (e.g. `size`) exists.
        # In newer torchvision releases transforms are nn.Module subclasses
        # and expect their __init__ to have run.
        super().__init__((max_size, max_size))
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped.
        Returns:
            PIL Image: Cropped image; edge length drawn uniformly from
            [min_size, max_size] on every call.
        """
        # randint's upper bound is exclusive, hence the + 1
        size = np.random.randint(self.min_size, self.max_size + 1)
        crop = transforms.CenterCrop( (size, size) )
        return crop(img)

    def __repr__(self):
        return f'{self.__class__.__name__}: (min-size={self.min_size}, max-size={self.max_size})'


batchsize = 3

# due to the large amount of data, random transformations might not be necessary...
train_transform = transforms.Compose([
    RandomCenterCrop(min_size=200, max_size=256),
    transforms.Resize( (256, 256) ),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation( (-20, 20) ),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Validation must be deterministic: with random crops / flips / rotations the
# same model would get a different validation score on every pass. Only the
# non-random steps (resize, tensor conversion, same normalisation) are kept.
val_transform = transforms.Compose([
    transforms.Resize( (256, 256) ),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train = Diabetic_Retionopathy_Data(train_dir,
                                   train_df,
                                   transform=train_transform,
                                   in_memory=True,
                                   write_images=False)
val   = Diabetic_Retionopathy_Data(train_dir,
                                   val_df,
                                   transform=val_transform,
                                   in_memory=True,
                                   write_images=False)

train_loader = DataLoader(train, batch_size=batchsize, num_workers=4, shuffle=True)
val_loader   = DataLoader(val,   batch_size=batchsize, num_workers=3, shuffle=False)

# Sanity check: pull one validation batch and report the tensor shapes.
X, y, _ = next(iter(val_loader))
print(f'batch-dimension:\nX = {X.shape},\ny = {y.shape}')
print(f'number of batches:\ntrain: {len(train_loader)}\nvalidation: {len(val_loader)}')

### Check pre-processing: compare raw vs. pre-processed images

In [None]:
def show_processed_images(image_dir, n=5, label_df=None, tf=None):
    """
    Plot randomly chosen images next to their pre-processed version.

    :param image_dir: directory to sample image files from
    :param n:         number of images to show (capped at the number of files
                      in `image_dir`); every file is shown at most once
    :param label_df:  optional DataFrame with `id_code` and `diagnosis`
                      columns; used for the figure title, otherwise the
                      diagnosis is reported as 'unknown'
    :param tf:        optional callable taking a file path and returning the
                      processed image; if None the right panel stays empty
    """
    all_files = os.listdir(image_dir)
    # replace=False: numpy samples with replacement by default, which could
    # display the same image several times.
    sample_files = np.random.choice(all_files,
                                    size=min(n, len(all_files)),
                                    replace=False)

    for file_name in sample_files:

        if label_df is not None:
            id_ = file_name.split('.')[0]
            diagnosis = label_df.query('id_code == @id_').diagnosis.item()
        else:
            diagnosis = 'unknown'

        file_path = join(image_dir, file_name)
        image     = cv2.imread(file_path)
        raw_image = cv_to_pil(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        fig, (ax1, ax2) = plt.subplots(1, 2)
        fig.set_size_inches(10, 5)

        ax1.imshow(raw_image)
        if tf is not None:
            ax2.imshow(tf(file_path))
        ax1.set_title('raw')
        ax2.set_title('processed')

        fig.suptitle(f'diagnosis = {diagnosis}')
        plt.show()
        # Release the figure explicitly; figures created in a loop otherwise
        # accumulate on backends that do not close them on show().
        plt.close(fig)
        
    
# Visual check of the pre-processing on both image folders. Labels are only
# passed for the training folder; the test preview is shown without them.
for split_title, image_folder, split_labels in (
        ('TRAINING DATA:', 'train_images', pd.concat([train_df, val_df])),
        ('TEST DATA:',     'test_images',  None)):
    print(split_title)
    show_processed_images(join(DATA_DIR, image_folder),
                          label_df=split_labels,
                          tf=train._read_process_image)