In [None]:
import numpy as np 
import pandas as pd 
import slidingwindow as sw 
import skimage 
from skimage import io 
from pandas import DataFrame
import matplotlib.pyplot as plt
%matplotlib inline

import os 
import re #2 
import random #3
from time import time #4

from utils import *

import torch # 7
from torch.utils.data  import Dataset # 8

from fastai2.data.core import DataLoaders # 5
from fastai2.vision.all import * #6
from fastai2.vision.widgets import *
from fastai2.data.external import untar_data,URLs
from fastai2.data.transforms import get_image_files # 1


In [2]:
# Data locations.
# NOTE(review): both paths below are absolute and machine-specific; adjust them
# before running this notebook on another machine.
BASE_FOLDER = "/project/data/"   # if on gcloud

# Label table; label_func2 reads its 'image_id' and 'isup_grade' columns.
train = pd.read_csv(os.path.join(BASE_FOLDER, 'train.csv'))

# Directory that get_image_files(path) scans for the training images.
image_dir = '/home/abharani/data/train_images/'
path = Path(image_dir)

1. Wrap all image pre-processing (resizing, conversion to a tensor, division by 255 and reordering of the channels) into a single step using a helper function.
2. Read the label of each image from its file name (or look it up via the `image_id` derived from the file name).


In [None]:
def pre_process_image(fname, window_size=256, stride=64, k=16):
    """Load an image file and turn its selected patches into one CHW float tensor.

    Args:
        fname: path of the image file; passed to ``skimage.io.MultiImage`` as a string.
        window_size (int): side length, in pixels, of each square patch.
        stride (int): step, in pixels, between neighbouring candidate patches.
        k (int): number of patches to keep and glue together.

    Returns:
        torch.Tensor: float32 tensor of shape (3, H, W), i.e. the glued patch mosaic
        with its channel axis moved to the front and every value divided by 255.
    """
    # Only the last entry of the MultiImage is used.
    # NOTE(review): presumably the lowest-resolution level of the slide - confirm.
    image = skimage.io.MultiImage(str(fname))[-1]
    _, _, best_regions = generate_patches(image, window_size=window_size, stride=stride, k=k)
    glued_image = glue_to_one_picture(best_regions, window_size=window_size, k=k)
    # HWC array -> float32 tensor -> CHW, scaled by 1/255.
    pixels = torch.as_tensor(np.asarray(glued_image), dtype=torch.float32)
    return pixels.permute(2, 0, 1) / 255.0


def label_func1(fname):
    """Return the part of a file name that precedes its trailing ``_<digits>.jpg``.

    Args:
        fname: path-like object exposing a ``name`` attribute (e.g. ``pathlib.Path``).

    Returns:
        str: for example ``'abc'`` for a file named ``abc_12.jpg``.

    Raises:
        ValueError: if the file name does not end in ``_<digits>.jpg``.
    """
    # The dot is escaped so that only a literal ".jpg" extension is accepted.
    match = re.match(r'^(.*)_\d+\.jpg$', fname.name)
    if match is None:
        raise ValueError("unexpected image file name: {!r}".format(fname.name))
    return match.group(1)

def label_func2(filepath):
    """Look up the ``isup_grade`` of an image in the module-level ``train`` table.

    Args:
        filepath: path of the image file (str or path-like); its base name without
            the extension is used as the ``image_id``.

    Returns:
        The ``isup_grade`` value of the first ``train`` row whose ``image_id`` matches.

    Raises:
        IndexError: if no row of ``train`` has that ``image_id``.
    """
    image_id = os.path.splitext(os.path.basename(str(filepath)))[0]
    # Single .loc call (row mask + column) instead of chained indexing.
    grades = train.loc[train['image_id'] == image_id, 'isup_grade'].values
    if len(grades) == 0:
        raise IndexError("image_id {!r} not found in train".format(image_id))
    return grades[0]



def label_func3(file_path):
    """Parse the integer that follows the last underscore of a ``.jpg`` file name.

    Args:
        file_path: path of the image (str or path-like), e.g.
            /home/abharani/data/train_images/6fc63d2394ebade5d7e09856eab1f726_0.jpg

    Returns:
        int: the numeric suffix of the file name (0 for the example above).
    """
    base_name = str(file_path).rsplit("/", 1)[-1]
    stem = base_name.replace(".jpg", "")
    return int(stem.rsplit("_", 1)[-1])


#### Generate the list of files in the image directory, pick random indexes and split them into a training set and a validation set.

In [None]:
# Work on a small subset of the image files and split it 80/20 into train/validation.
N_FILES = 200      # number of image files used for this experiment
SPLIT_SEED = 42    # fixed seed so that a fresh kernel reproduces the same split

files = get_image_files(path)[0:N_FILES]
# A dedicated RandomState makes the shuffle reproducible without touching NumPy's global RNG.
idxs = np.random.RandomState(SPLIT_SEED).permutation(len(files))
cut = int(0.8 * len(files))
train_files = files[idxs[:cut]]
valid_files = files[idxs[cut:]]
print("Training set images {}, Validation set images {}".format(len(train_files),len(valid_files)))

#### Let's check the unique labels in the dataset and the distribution of each label

In [None]:
# Collect the distinct labels encoded in the file names and report how many there are.
labels = list({label_func3(image_file) for image_file in files})
print("distinct labels {}".format(len(labels)))

## Approach I  - Purely Pytorch 
Following from https://dev.fast.ai/tutorial.siamese


#### We can use the files above to create a Dataset

In [None]:
class BiopsyDataset(Dataset):
    """Map-style dataset yielding ``(image_tensor, label_tensor)`` pairs from image files.

    Args:
        files: indexable collection of image file paths.
        is_valid (bool): stored on the instance; not used by the dataset itself.
        verbose (bool): when True (the default), print how long the pre-processing
            of each item took.
    """

    def __init__(self, files, is_valid=False, verbose=True):
        self.files = files
        self.is_valid = is_valid
        self.verbose = verbose

    def __getitem__(self, i):
        """Return the pre-processed image and its label (``torch.long``) for index ``i``."""
        # Imported locally: the notebook's import cell binds the bare name `time`
        # to the function (``from time import time``), so `time.time()` only works
        # if a later star-import happens to rebind `time` to the module.
        from time import perf_counter

        file_path = self.files[i]
        tic = perf_counter()
        processed_image = pre_process_image(file_path)
        toc = perf_counter()
        if self.verbose:
            print("Time took to pre-process {} secs".format(toc - tic))
        cls = label_func3(file_path)
        y_tensor = torch.tensor(cls, dtype=torch.long)
        return (processed_image, y_tensor)

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.files)
    
    
# Wrap both file lists in datasets; only the validation one is flagged as such.
train_ds: Dataset = BiopsyDataset(train_files)
valid_ds: Dataset = BiopsyDataset(valid_files, is_valid=True)

# Sanity check: load the first training samples and print the shape and label of each.
for i in range(len(train_ds)):
    sample = train_ds[i]
    image_tensor, label_tensor = sample[0], sample[1]
    print(i, image_tensor.shape, label_tensor)

    # Keep going until the sample with index 3 has been printed, then stop.
    if i != 3:
        continue
    plt.show()
    break
    

#### Create the DataLoaders with the factory method `DataLoaders.from_dsets`

We can change the batch size depending on the GPU memory available.

In [None]:
# Build the train/validation DataLoaders directly from the two datasets
# (batch size 5, 4 worker processes).
dls = DataLoaders.from_dsets(train_ds, valid_ds,bs=5,num_workers=4)

#### Move the DataLoaders to the GPU and inspect one batch of data

In [None]:
# Move the DataLoaders to the GPU, then pull a single batch as a smoke test.
dls = dls.cuda()
b = dls.one_batch()

##### Create cnn_learner using pre-trained resnet50 model

In [None]:
# ResNet-50 learner with a 6-output head and cross-entropy loss; accuracy is the reported metric.
learn = cnn_learner(dls, resnet50, metrics=[accuracy],n_out=6,loss_func=F.cross_entropy)
# Train via fastai's fine_tune with epochs=10.
learn.fine_tune(10)

### End of Approach I
What is a bit annoying is that we have to rewrite everything that is already in fastai if we want to normalize our images or apply data augmentation.

### Approach II - Fastai
Following from https://dev.fast.ai/tutorial.siamese

A dataset like the one above can easily be converted into a fastai `Transform` by just renaming the `__getitem__` method to `encodes`.

So three things changed:

1. `__len__` disappeared; we won't need it
2. `__getitem__` became `encodes`
3. we return a `TensorImage` for our images

We still wrap all image pre-processing (resizing, conversion to a tensor, division by 255 and reordering of the channels) into one step using a helper function,
and we still generate the label of each image from its file name (or via the `image_id` derived from the file name).

In [None]:
class BiopsyTransform(Transform):
    """fastai ``Transform`` that maps an integer index to an ``(image, label)`` pair.

    Args:
        files: indexable collection of image file paths.
        is_valid (bool): stored on the instance; not used by the transform itself.
    """

    def __init__(self, files, is_valid=False):
        self.files, self.is_valid = files, is_valid

    def encodes(self, i):
        """Pre-process file number ``i``; return ``(TensorImage, torch.long label tensor)``."""
        image_file = self.files[i]
        image = TensorImage(pre_process_image(image_file))
        label = torch.tensor(label_func3(image_file), dtype=torch.long)
        return (image, label)
    

##### How do we build a dataset with this? We will use TfmdLists. It's just an object that lazily applies a collection of Transforms on a list. Here since our transform takes integers, we will pass simple ranges for this list. 

In [None]:
# The transforms take integer indexes, so each TfmdLists wraps a plain range;
# element i is produced by BiopsyTransform.encodes(i) when it is accessed.
train_tl= TfmdLists(range(len(train_files)), BiopsyTransform(train_files))
valid_tl= TfmdLists(range(len(valid_files)), BiopsyTransform(valid_files, is_valid=True))

##### Then, when we create a DataLoader, we can add any transform we like.


In [None]:
# Same DataLoaders factory as in Approach I, now with fastai item- and batch-level transforms.
# NOTE(review): the items produced by BiopsyTransform are already tensors; confirm that
# Resize(224) and ToTensor actually act on them, and that listing Resize(224) a second
# time in after_batch is intended.
dls = DataLoaders.from_dsets(train_tl, valid_tl, bs=5,num_workers=4,after_item=[Resize(224), ToTensor],
                             after_batch=[Resize(224),Normalize.from_stats(*imagenet_stats), *aug_transforms()])
dls = dls.cuda()
b = dls.one_batch()
# Print the shape of the image batch and the label tensor.
print(b[0].shape,b[1])

In [None]:
# for i, sample in enumerate(dls):
#     print(sample)
# dls.show_batch()

##### Create cnn_learner using pre-trained resnet50 model

In [None]:
# ResNet-50 learner with a 6-output head and cross-entropy loss; accuracy is the reported metric.
learn = cnn_learner(dls, resnet50, metrics=[accuracy],n_out=6,loss_func=F.cross_entropy)
# Train via fastai's fine_tune with epochs=10.
learn.fine_tune(10)

In [None]:
# Build an interpretation object from the trained learner and plot its confusion matrix.
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()

In [None]:
# Save the trained learner under the file name that the load_learner cell further
# down expects; a bare export() would write the default 'export.pkl' instead.
learn.export('export_resnet50.pkl')

In [None]:
# List the pickle files (*.pkl) in the current working directory.
# NOTE: this rebinds `path`, which earlier in the notebook pointed at the training-image
# directory; re-running earlier cells after this one will pick up the new value.
path = Path()
path.ls(file_exts='.pkl')

In [None]:
# Reload the exported learner for inference.
# NOTE(review): this loads a pickle file - only load exports you created or trust.
learn_inf = load_learner(path/'export_resnet50.pkl')

In [None]:
def generate_patches(image, window_size=200, stride=128, k=20):
    """Slide a square window over ``image`` and keep the ``k`` best-scoring patches.

    Args:
        image (numpy.ndarray): 3-D array; the window moves along its first two axes.
        window_size (int): side length of the square window.
        stride (int): step between consecutive window positions.
        k (int): number of patches to keep.

    Returns:
        tuple: ``(image, coordinates, patches)`` - the unchanged input, the list returned
        by ``select_k_best_regions`` and the dict returned by ``get_k_best_regions``.
    """
    extent_axis0, extent_axis1 = image.shape[0], image.shape[1]

    candidates = []
    # Visiting order: the axis-1 offset advances in the outer loop, the axis-0 offset inside.
    for offset1 in range(0, extent_axis1 - window_size + 1, stride):
        for offset0 in range(0, extent_axis0 - window_size + 1, stride):
            window = image[offset0:offset0 + window_size, offset1:offset1 + window_size, :]
            white_ratio, green, blue = compute_statistics(window)
            candidates.append((offset0, offset1, white_ratio, green, blue))

    best_coordinates = select_k_best_regions(candidates, k=k)
    best_patches = get_k_best_regions(best_coordinates, image, window_size)

    return image, best_coordinates, best_patches


def compute_statistics(image, white_threshold=620):
    """Summarise how white, green and blue an RGB patch is.

    Args:
        image (numpy.ndarray): array of shape (rows, cols, channels) with the channels
            in R, G, B order.
        white_threshold (int): a pixel counts as white when the sum of its channel
            values exceeds this value (a pure white RGB pixel sums to 3 * 255 = 765).

    Returns:
        tuple: ``(ratio_white_pixels, green_concentration, blue_concentration)`` -
        the fraction of white pixels, the mean of the green channel and the mean of
        the blue channel.
    """
    num_pixels = image.shape[0] * image.shape[1]

    # Per-pixel sum over the channel axis; bright (near-white) pixels have a large sum.
    channel_sum = np.sum(image, axis=-1)
    num_white_pixels = np.count_nonzero(channel_sum > white_threshold)
    ratio_white_pixels = num_white_pixels / num_pixels

    # Index the channel axis (the last one): image[1] / image[2] would select
    # pixel rows 1 and 2 across all channels, not the green and blue channels.
    green_concentration = np.mean(image[..., 1])
    blue_concentration = np.mean(image[..., 2])

    return ratio_white_pixels, green_concentration, blue_concentration

def select_k_best_regions(regions, k=20):
    """Keep the ``k`` candidate regions with the lowest white-pixel ratio.

    Args:
        regions (list): tuples ``(x, y, ratio_white_pixels, green_concentration,
            blue_concentration)``.
        k (int): number of regions to select.

    Returns:
        list: the regions whose green and blue concentration both exceed 180,
        ordered by ascending white-pixel ratio and truncated to ``k`` entries.
    """
    candidates = []
    for region in regions:
        if region[3] > 180 and region[4] > 180:
            candidates.append(region)

    # list.sort is stable, so regions with equal ratios keep their input order.
    candidates.sort(key=lambda region: region[2])
    return candidates[:k]


def display_images(regions, title):
    """Draw patches on a fixed 5x4 grid of axes under a common figure title.

    Args:
        regions (dict): maps an integer position ``i`` to an image accepted by
            ``imshow``; patch ``i`` is drawn at row ``i // 4``, column ``i % 4``.
        title (str): title placed above the whole figure.
    """
    n_cols = 4
    fig, axes = plt.subplots(5, n_cols, figsize=(15, 15))

    for position, patch in regions.items():
        row, col = divmod(position, n_cols)
        axes[row, col].imshow(patch)

    fig.suptitle(title)
    
    
def get_k_best_regions(coordinates, image, window_size=512):
    """Cut one square patch out of ``image`` for every coordinate entry.

    Args:
        coordinates: sequence of entries whose first two items are the offsets of the
            patch along axis 0 and axis 1 of ``image``.
        image (numpy.ndarray): 3-D array to cut the patches from.
        window_size (int): side length of each patch.

    Returns:
        dict: maps the position of each entry in ``coordinates`` to its patch.
    """
    def crop(offset0, offset1):
        return image[offset0:offset0 + window_size, offset1:offset1 + window_size, :]

    return {rank: crop(entry[0], entry[1]) for rank, entry in enumerate(coordinates)}


def glue_to_one_picture(image_patches, window_size=200, k=16):
    """Arrange square patches on a square grid and return them as one image.

    Args:
        image_patches (dict): maps an integer position to a patch of shape
            (window_size, window_size, 3).
        window_size (int): side length of each patch.
        k (int): number of grid cells; the grid has ``int(sqrt(k))`` cells per side.

    Returns:
        numpy.ndarray: int16 array of shape (side * window_size, side * window_size, 3);
        cells without a patch stay zero.
    """
    grid = int(np.sqrt(k))
    canvas = np.zeros((grid * window_size, grid * window_size, 3), dtype=np.int16)

    # Patches fill the grid row by row: position -> (row, column).
    for position, tile in image_patches.items():
        row, col = divmod(position, grid)
        top, left = row * window_size, col * window_size
        canvas[top:top + window_size, left:left + window_size, :] = tile

    return canvas