In [None]:
import tensorflow as tf
import pathlib,os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import skimage.io
from sklearn.model_selection import StratifiedKFold
import time
import albumentations
import tensorflow_addons as tfa
import tensorflow_hub as hub
from tensorflow.keras import layers,models

AUTO = tf.data.experimental.AUTOTUNE   # passed as num_parallel_calls so tf.data picks the parallelism

import random as python_random
# Seed numpy, Python's `random` module and TensorFlow for repeatable runs.
np.random.seed(123)
python_random.seed(123)
tf.random.set_seed(1234)
# NOTE(review): setting PYTHONHASHSEED here does not change hash randomization of
# this already-running interpreter; it only reaches subprocesses started later.
# Export it before the kernel starts if hash-order reproducibility matters.
os.environ['PYTHONHASHSEED']=str(0)

import cv2
import IPython.display as display
from google.cloud import storage
from tqdm import tqdm

In [None]:
kaggle_data = '../input/prostate-cancer-grade-assessment'

# Training images live under train_images/; ds_dir is the same folder as a Path.
data_dir = f'{kaggle_data}/train_images'
ds_dir = pathlib.Path(data_dir)

# Label table indexed by image_id so rows can be looked up / dropped by ID.
label_dir = f'{kaggle_data}/train.csv'
train_labels = pd.read_csv(label_dir).set_index('image_id')

test_path = f'{kaggle_data}/test_images'

In [None]:
# Slide IDs removed from the label table before any folds or TFRecords are built.
# NOTE(review): the reason each ID is excluded is not recorded here — document it.
black_list = [
    '3790f55cad63053e956fb73027179707', '014006841b9807edc0ff277c4ab29b91',
    '00d8a8c04886379e266406fdeff81c45', '6f310463d3868e86be87adddeccdde19',
    '6ef357192c2530ee54e1bd38a3231e00', '511a33b7aeb1153407ed2d55cae001c8',
    '2fb68d6713a52e322692dbc99ac82444', '8a415b0e974861fa00ae9d88f9e3b980',
    'eda92317b583435a810cac1dc7bb8025', 'c9dbdd9c9fc0eab0d235499488b26c53',
    '99d9afb22b65ea97fcd21ce67e1ddb6c', '0c46c60ae2ef49657bc843707162ba6e',
    '85b7018a9e5287342a1392fb02ce24a1', '4e80b5738f591c7d0d91889c2bdfd39d',
    'f13cd8ec6e6fde523a0a065b62086d3c', 'aebe292567e0f8447fdc94994189a80a',
    '6b512a45bd8ca759e655c2d551dec2d9', 'be403ce415609008605d63396869ed2e',
    'e3360926180928287a4d96973e10926a', '3752b697cae9f81a9d5ffe44dac58e7a',
    'a0ac3589042f9e99d31b521b5b56ac06', 'f5675cb89a120e225ca8929b64a5af79',
    '6d569809f11e6e53918dfb1609fb0d83', '7817ae2d392ba4f6cb9104fbe70b6274',
    'cd382fdc26516c634b2314aa870bfe80', '1cdd6def1e3099a9763938457cf0b4be',
    '0186f4811c9d089707d9dc7460160d88', 'e9d628364cf51891028163e0cfca628c',
    '9effeea56c413b92340b89d1240769c1', '7a0a36bc6119e3d78474e6c8ca875725',
    '374d5401159d9bf39ce20b395d82c0b4', '7fa4634ab59a7832bc877fef162eacaa',
    'ee182a14e532b122f40d561d87eb2136', 'c0a0956a39319920d02c5c4eb30c5e10',
    '441265c6b4598e9bcd10bc10eb6293cc', '8ae069858aecbad846f4d69d405f9bd6',
    'b13961504ea859ff34a150bc19fed335', '476f0dfb144aee7d5422dcc3b2b97a9f',
    'bc93d165d96e4fa4883f130b3f7b9885', 'ac9d05fa3f4fafb474fb96f9f8ab71ac',
    '1438f19e07c389b47fd5219ca62f9f0a', '479200a381febadfd767615fbe77c3ea',
    'fc6a695ba44f4b64425c522f590bac48', '046bac77a58c1be84a6418904e755280',
    '5c083ab21fc57c0954468ab46aa7fb16', 'cca735c397880e88192e97d68b97754e',
    'a579110fe1e670847d9d146404597750', '8dbedd97ed2b7b01525d6800d52ae073',
    '004dd32d9cd167d9cc31c13b704498af', 'cdf40333dfe2afec1a4c54d9eeb1ec7a',
    '09d4be69a2330cd49298bf30d29cc4e5', 'f73951fddf77034c9fd44cb19f5fe6b5',
    'aef75d4c390d838aabe56e2d601b6a13', 'b3a2dc7547bc580c6f3923c61db42051',
    '774c9b631a29f191836b1078a6c3a67c', '836ca5d73c88ad94fb980ca3e5e65da7',
    'bfdbe56fb7fc4d7b3d151370f897d503', '06ef49a7b77e883f089cfdd80642d6f0',
    '8e25584bd03155d24a2adc00517a38e8', 'dc2ec851fcbf594f11b023387ac15003',
    'a0150f4d6d9f6f3b2b5a240b099df000',
]
# errors='ignore' keeps this cell re-runnable: train_labels is reassigned here,
# so on a second run the IDs are already gone and a plain drop() raises KeyError.
train_labels = train_labels.drop(index=black_list, errors='ignore')

In [None]:
def read_tiff(img_path,level):
    """Read one level of a multi-level TIFF and look up its ISUP grade.

    Args:
        img_path: path to the ``.tiff`` file (``str`` or ``os.PathLike``). Its
            file name without extension is used as the ``image_id``.
        level: index into ``skimage.io.MultiImage`` selecting the level to read.

    Returns:
        ``(img, label, im_ID)``: the image array at ``level``, the
        ``isup_grade`` value for this ID from the global ``train_labels``
        table, and the ID string itself.

    Raises:
        KeyError: if the ID is not in ``train_labels`` (e.g. it was blacklisted).
    """
    img_path = os.fspath(img_path)        # accept pathlib.Path as well as str
    img = skimage.io.MultiImage(img_path)[level]
    # basename/splitext work for any path form and keep dots inside the stem.
    im_ID = os.path.splitext(os.path.basename(img_path))[0]
    label = train_labels.loc[im_ID, 'isup_grade']
    return img, label, im_ID

def augment(image_array):
    """Apply random flips, rotation, hue shift and blur to one image.

    Args:
        image_array: image as a NumPy array (the array returned by ``read_tiff``).

    Returns:
        The augmented array produced by the albumentations pipeline.
    """
    op_train = albumentations.Compose([
        albumentations.VerticalFlip(p=0.5),
        albumentations.HorizontalFlip(p=0.5),
        # Rotate by up to +/-90 degrees. Uncovered border pixels are filled by
        # mirroring the image (cv2.BORDER_REFLECT_101, integer value 4); this
        # border mode does not wrap around.
        albumentations.Rotate(limit=90, border_mode=cv2.BORDER_REFLECT_101, p=0.5),
        # Hue-only shift; saturation and value are left untouched.
        albumentations.HueSaturationValue(hue_shift_limit=(0, 20), sat_shift_limit=0,
                                          val_shift_limit=0, p=0.5),
        albumentations.GaussianBlur(blur_limit=3, p=0.25),
    ])
    return op_train(image=image_array)['image']


def tile_tiff(img,level,n_tiles):
    """Build a square mosaic from the ``n_tiles`` lowest-intensity tiles of ``img``.

    The image is padded with white (255) to a multiple of the tile size and cut
    into square tiles. Tiles are ranked by total pixel sum, lowest first; this
    presumes a bright background, so the darkest tiles hold the most tissue.
    The first ``k*k`` ranked tiles (``k = int(sqrt(n_tiles))``) are laid out
    row-major on the mosaic.

    Args:
        img: ``(H, W, C)`` image array with values in ``[0, 255]``.
        level: level the image was read from; selects the tile size
            (1 -> 256 px, 2 -> 128 px).
        n_tiles: number of tiles to keep. The mosaic is ``int(sqrt(n_tiles))``
            tiles on a side.

    Returns:
        ``float64`` array of shape ``(tile_size*k, tile_size*k, C)``. If the
        image yields fewer than ``n_tiles`` tiles, the missing ones are white.

    Raises:
        ValueError: if ``level`` is not 1 or 2.
    """
    tile_sizes = {1: 256, 2: 128}      # tile size depends on the downsampling of the level
    if level not in tile_sizes:
        raise ValueError("level is not 1 or 2")
    tile_size = tile_sizes[level]

    ranked_tiles = _tissue_tiles(img, tile_size, n_tiles)
    return _assemble_mosaic(ranked_tiles, tile_size, n_tiles)


def _tissue_tiles(img, tile_size, n_tiles):
    """Cut ``img`` into tiles and return exactly ``n_tiles`` of them, lowest pixel sum first."""
    h, w, c = img.shape
    pad_h = (tile_size - h % tile_size) % tile_size
    pad_w = (tile_size - w % tile_size) % tile_size
    padded = np.pad(img,
                    [[pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2], [0, 0]],
                    constant_values=255)
    tiles = (padded
             .reshape(padded.shape[0] // tile_size, tile_size,
                      padded.shape[1] // tile_size, tile_size, c)
             .transpose(0, 2, 1, 3, 4)
             .reshape(-1, tile_size, tile_size, c))
    # Top up with blank (white) tiles when the image is too small, so callers
    # always receive n_tiles tiles.
    if len(tiles) < n_tiles:
        tiles = np.pad(tiles, [[0, n_tiles - len(tiles)], [0, 0], [0, 0], [0, 0]],
                       constant_values=255)
    order = np.argsort(tiles.reshape(len(tiles), -1).sum(-1))[:n_tiles]
    return tiles[order]


def _assemble_mosaic(tiles, tile_size, n_tiles):
    """Place the first ``k*k`` tiles (``k = int(sqrt(n_tiles))``) row-major on a float64 canvas."""
    n_row_tiles = int(np.sqrt(n_tiles))
    channels = tiles.shape[-1]
    mosaic = np.zeros((tile_size * n_row_tiles, tile_size * n_row_tiles, channels))
    for i in range(n_row_tiles * n_row_tiles):
        row, col = divmod(i, n_row_tiles)
        mosaic[row * tile_size:(row + 1) * tile_size,
               col * tile_size:(col + 1) * tile_size] = tiles[i]
    return mosaic
    
def tile_and_aug_tiff(img_path_tensor,level=1,aug=1,n_tiles=36):
    """Read an image, optionally augment it, and return its tile mosaic.

    Args:
        img_path_tensor: path to the image file, passed through to ``read_tiff``.
        level: level to read and tile.
        aug: if truthy, run ``augment`` on the image before tiling.
        n_tiles: number of tiles kept in the mosaic.

    Returns:
        ``(mosaic, label, im_ID)`` from ``tile_tiff`` and ``read_tiff``.
    """
    image, label, im_ID = read_tiff(img_path_tensor, level)
    if aug:
        image = augment(image)
    mosaic = tile_tiff(image, level, n_tiles)
    return mosaic, label, im_ID

In [None]:
def create_folds(train_labels,n_fold,debug=0,random_state=42):
    """Assign every row of ``train_labels`` to a stratified cross-validation fold.

    Args:
        train_labels: label table with an ``isup_grade`` column. Its index is
            moved back into a column (``image_id`` for the table loaded above).
        n_fold: number of folds for ``StratifiedKFold``.
        debug: if truthy, render the resulting table.
        random_state: shuffle seed for ``StratifiedKFold``. The default 42
            reproduces the historical split.

    Returns:
        A new DataFrame with a RangeIndex and an integer ``test_fold`` column.
        ``test_fold`` holds the fold (0..n_fold-1) in which that row is held out.
    """
    # A positional RangeIndex lets skf's positional indices be used with .loc.
    input_DF = train_labels.reset_index(drop=False)
    input_DF['test_fold'] = -1          # integer column; every row is overwritten below

    skf = StratifiedKFold(n_splits=n_fold, shuffle=True, random_state=random_state)
    for fold, (_, test_idx) in enumerate(skf.split(input_DF, input_DF['isup_grade'])):
        # Test indices never overlap across folds, so each row is set exactly once.
        input_DF.loc[test_idx, 'test_fold'] = fold

    if debug:
        # The module-level name `display` is bound to the IPython.display *module*,
        # which is not callable; import the display function explicitly.
        from IPython.display import display as show_frame
        show_frame(input_DF)

    return input_DF

In [None]:
def get_list(fold, labels_DF,debug=0):
    """Return the ``.tiff`` path under ``data_dir`` for every ``image_id`` in ``labels_DF``.

    ``fold`` is accepted but currently ignored: every row of ``labels_DF`` is
    included regardless of its ``test_fold``. When ``debug`` is truthy, the
    first two paths and the total count are printed.
    """
    paths = [os.path.join(data_dir, image_id + '.tiff')
             for image_id in labels_DF.loc[:, 'image_id']]

    if debug:
        print('train')
        for sample_path in paths[:2]:
            print(sample_path)
        print("Num of samples: ", len(paths), "\n")

    return paths

In [None]:
# The following functions can be used to convert a value to a type compatible
# with tf.Example.
def _bytes_feature(value):
    """Wrap one bytes/str value in a ``tf.train.Feature`` holding a BytesList.

    Eager tensors are unwrapped with ``.numpy()`` first, because BytesList
    cannot unpack a string from an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

def _int64_feature(value):
    """Wrap one bool/enum/int/uint value in a ``tf.train.Feature`` holding an Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

def _float_feature(list_of_floats):
    """Wrap a sequence of floats in a ``tf.train.Feature`` holding a FloatList (float32)."""
    float_list = tf.train.FloatList(value=list_of_floats)
    return tf.train.Feature(float_list=float_list)

#=========================================

def image_example(image_string, label,im_ID_bytes):
    """Pack an encoded image, its integer label and its ID into a ``tf.train.Example``.

    The feature keys are ``"image"``, ``"label"`` and ``"im_ID"``, which are
    the keys ``read_tfrecord`` parses back.
    """
    image_feature = _bytes_feature(image_string)
    label_feature = _int64_feature(label)
    id_feature = _bytes_feature(im_ID_bytes)
    feature_map = {"image": image_feature, "label": label_feature, "im_ID": id_feature}
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

#=========================================

def to_tfrecord(rec_file,path_list,level,aug,n_tiles):
    """Write one serialized ``tf.train.Example`` per image in ``path_list`` to ``rec_file``.

    Each image is read, optionally augmented and tiled by ``tile_and_aug_tiff``,
    then PNG-encoded with OpenCV. OpenCV interprets the last axis in BGR order,
    so the stored PNG has its channels reversed relative to the in-memory
    array; ``read_tfrecord`` undoes this with ``tf.reverse``.

    Raises:
        RuntimeError: if OpenCV reports that PNG encoding failed.
    """
    with tf.io.TFRecordWriter(rec_file) as writer:
        for img_path in tqdm(path_list):
            img, label, im_ID = tile_and_aug_tiff(img_path, level, aug, n_tiles)
            # tile_tiff returns a float64 canvas. Round and clip to uint8 explicitly
            # so the PNG is 8-bit regardless of the mosaic's dtype.
            img_u8 = np.clip(np.rint(img), 0, 255).astype(np.uint8)
            # Compression level is 0-9; 1 favours speed while still compressing.
            ok, encoded = cv2.imencode('.png', img_u8, [cv2.IMWRITE_PNG_COMPRESSION, 1])
            if not ok:
                raise RuntimeError('PNG encoding failed for {}'.format(img_path))
            image_string = encoded.tobytes()          # ndarray.tostring() is deprecated
            im_ID_bytes = tf.compat.as_bytes(im_ID)

            tf_example = image_example(image_string, label, im_ID_bytes)
            writer.write(tf_example.SerializeToString())


In [None]:
def read_tfrecord(example):
    """Parse one serialized example written by ``to_tfrecord``.

    Returns:
        ``(image, label, im_ID)``. ``image`` is the decoded PNG with its channel
        axis reversed, undoing the BGR order OpenCV wrote, and converted to
        float32. ``label`` is an int64 scalar and ``im_ID`` is the raw ID bytes.
    """
    schema = {
        "image": tf.io.FixedLenFeature([], tf.string),   # encoded PNG bytes
        "label": tf.io.FixedLenFeature([], tf.int64),    # scalar label
        "im_ID": tf.io.FixedLenFeature([], tf.string),   # image ID as bytes
    }
    parsed = tf.io.parse_single_example(example, schema)

    decoded = tf.image.decode_png(parsed['image'], channels=0)   # keep the file's channel count
    reordered = tf.reverse(decoded, axis=[-1])
    # Scaling happens here because the augmentation step can't handle normalization.
    image = tf.image.convert_image_dtype(reordered, tf.float32)

    return image, parsed['label'], parsed['im_ID']

def display_9_images_from_dataset(dataset):
    """Plot up to the first nine ``(image, label, im_ID)`` records in a 3x3 grid.

    Each panel is titled with the decoded image ID (white text) and has its
    axes hidden.
    """
    plt.figure(figsize=(13, 13))
    for position, (image, label, im_ID) in enumerate(dataset, start=1):
        plt.subplot(3, 3, position)
        plt.axis('off')
        plt.imshow(image.numpy())
        plt.title(im_ID.numpy().decode("utf-8"), fontsize=14, color='w')
        if position == 9:
            break
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
        


In [None]:
input_DF = create_folds(train_labels, n_fold=5)
train_list = get_list(fold=0, labels_DF=input_DF, debug=0)
train_list_small = train_list[:9]          # a handful of images for a quick end-to-end check
print(train_list_small)

In [None]:
# define a filename to store preprocessed image data:
record_file = 'images.tfrecords'
level=1
file_list=train_list_small
list_DF=train_labels
aug=1
n_tiles=36

to_tfrecord(record_file,file_list,level,aug,n_tiles)
!du -sh {record_file}

In [None]:
# Read the smoke-test file back with TFRecordDataset and show a few decoded mosaics.
image_dataset = tf.data.TFRecordDataset(record_file)
debug = 1

if debug:
    parsed_dataset = image_dataset.map(read_tfrecord, num_parallel_calls=AUTO)
    display_9_images_from_dataset(parsed_dataset)

In [None]:
# `import http` alone does not load the http.client submodule; import it explicitly
# instead of relying on another library having imported it first.
import http.client
http.client.HTTPConnection.debuglevel = 0                 # set it to 5 for debugging
STORAGE_CLIENT = storage.Client(project='panda-285106')   # For uploading to GCS buckets
WORKING_DIRECTORY = "./"
FILE_PATTERN = data_dir + '/*.tiff'

def create_bucket(dataset_name):
    """Create a GCS bucket named ``dataset_name`` and print its name.

    https://cloud.google.com/storage/docs/
    """
    new_bucket = STORAGE_CLIENT.create_bucket(dataset_name)
    print('Bucket {} created'.format(new_bucket.name))

def upload_blob(bucket_name, source_file_name, destination_blob_name):
    """Upload a local file to ``bucket_name`` as ``destination_blob_name``.

    The upload is streamed in 1 MB chunks. https://cloud.google.com/storage/docs/
    """
    one_mib = 1 * 1024 * 1024
    target_bucket = STORAGE_CLIENT.get_bucket(bucket_name)
    target = target_bucket.blob(destination_blob_name, chunk_size=one_mib)
    target.upload_from_filename(source_file_name)
    print('File {} uploaded to {}.'.format(source_file_name, destination_blob_name))
    
def list_blobs(bucket_name):
    """Print the name of every blob in ``bucket_name``. https://cloud.google.com/storage/docs/"""
    for blob_name in (b.name for b in STORAGE_CLIENT.list_blobs(bucket_name)):
        print(blob_name)
        
def download_to_kaggle(bucket_name,destination_directory,file_name):
    """Download the blob ``file_name`` from a GCS bucket into a local directory.

    Args:
        bucket_name: name of the GCS bucket.
        destination_directory: local directory to write into (created if missing).
        file_name: name of the blob to fetch; also used as the local file name.

    Returns:
        The local path the blob was written to.
    """
    os.makedirs(destination_directory, exist_ok=True)
    full_file_path = os.path.join(destination_directory, file_name)
    # Fetch only the requested object. Looping over every blob in the bucket
    # would download all of them and overwrite this one path each time.
    blob = STORAGE_CLIENT.bucket(bucket_name).blob(file_name)
    blob.download_to_filename(full_file_path)
    return full_file_path

In [None]:
bucket_name = 'lvl1-36-256x256'
try:
    create_bucket(bucket_name)
except Exception as exc:   # presumably the bucket already exists on re-runs; show why instead of hiding it
    print('Couldnt create a bucket: {}'.format(exc))
    

In [None]:
level=1
list_DF=train_labels
aug=0                           #can't augment the validation set
n_tiles=36

shard_count=50
shard_size=len(train_list)//shard_count
print('',len(train_list),shard_size)
start=0
end=shard_size

for shard in range(shard_count):
    tfr_file="{}-shard{}.tfrec".format(bucket_name,shard)
    record_file = WORKING_DIRECTORY + tfr_file
    file_list=train_list[start:end]
    start+= shard_size
    end += shard_size
    print(record_file,'n_img: ',len(file_list))

    to_tfrecord(record_file,file_list,level,aug,n_tiles)
    !du -sh {record_file}

    tic = time.perf_counter()
    upload_blob(bucket_name, record_file, tfr_file)
    toc = time.perf_counter()
    dt  = toc-tic
    print(f"Upload took {dt//3600:0.0f} hours {(dt-((dt//3600)*3600))//60:0.0f} minutes {dt%60:0.4f} seconds\n")
    !rm $tfr_file


In [None]:
debug = 1

if debug:
    # Fetch `tfr_file` (the last shard name from the upload loop) back from the
    # bucket and show a few of its mosaics.
    destination_directory = '/kaggle/working/download/'
    download_to_kaggle(bucket_name, destination_directory, tfr_file)

    image_dataset = tf.data.TFRecordDataset(os.path.join(destination_directory, tfr_file))
    parsed_dataset = image_dataset.map(read_tfrecord, num_parallel_calls=AUTO)
    display_9_images_from_dataset(parsed_dataset)