In [None]:
import tensorflow as tf
import pathlib,os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import skimage.io
from sklearn.model_selection import StratifiedKFold
import time
import albumentations
import tensorflow_addons as tfa
import tensorflow_hub as hub
from tensorflow.keras import layers,models

# Let tf.data tune parallelism and prefetch buffer sizes at runtime.
AUTO = tf.data.experimental.AUTOTUNE

import random as python_random

# Seed every RNG the notebook touches so runs are repeatable.
SEED = 1234
np.random.seed(SEED)
python_random.seed(SEED)
tf.random.set_seed(SEED)
# NOTE: PYTHONHASHSEED is only read when an interpreter starts, so assigning it here does
# not change hash randomization of the running kernel; child processes inherit it.
os.environ['PYTHONHASHSEED'] = str(0)

from kaggle_datasets import KaggleDatasets
!pip install -q efficientnet
import efficientnet.tfkeras as efn
from tensorflow.keras.models import model_from_json
from sklearn.metrics import cohen_kappa_score



In [None]:
# Pick a distribution strategy: a TPUStrategy when a TPU cluster resolver can be built,
# otherwise TensorFlow's default strategy (CPU or a single GPU).
try:
    # No arguments needed when the TPU_NAME environment variable is set (always the case on Kaggle TPU kernels).
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    # Resolver construction failed -> no TPU available.
    tpu = None

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)


In [None]:
#GCS_DS_PATH = KaggleDatasets().get_gcs_path('prostate-cancer-grade-assessment') # you can list the bucket with "!gsutil ls $GCS_DS_PATH"
# Bucket holding the TFRecord shards, named <GCS_PATH>/lvl1-36-256x256-shard<i>.tfrec for i in [0, shard_num).
GCS_PATH = 'gs://lvl1-36-256x256'
shard_num = 50
files = [
    GCS_PATH + '/lvl1-36-256x256-shard' + str(i_shard) + '.tfrec'
    for i_shard in range(shard_num)
]

In [None]:
# Load the competition's train.csv and index it by image_id.
kaggle_data= '../input/prostate-cancer-grade-assessment'

label_dir=kaggle_data+'/train.csv'
train_labels = pd.read_csv(label_dir).set_index('image_id')
# Image IDs excluded from training and validation.
# NOTE(review): the reason for each exclusion is not recorded here (presumably slides with
# unusable images or unreliable labels) — TODO document where this list comes from.
black_list= ['3790f55cad63053e956fb73027179707','014006841b9807edc0ff277c4ab29b91','00d8a8c04886379e266406fdeff81c45','6f310463d3868e86be87adddeccdde19', '6ef357192c2530ee54e1bd38a3231e00', '511a33b7aeb1153407ed2d55cae001c8', '2fb68d6713a52e322692dbc99ac82444', '8a415b0e974861fa00ae9d88f9e3b980', 'eda92317b583435a810cac1dc7bb8025', 'c9dbdd9c9fc0eab0d235499488b26c53', '99d9afb22b65ea97fcd21ce67e1ddb6c', '0c46c60ae2ef49657bc843707162ba6e', '85b7018a9e5287342a1392fb02ce24a1', '4e80b5738f591c7d0d91889c2bdfd39d', 'f13cd8ec6e6fde523a0a065b62086d3c', 'aebe292567e0f8447fdc94994189a80a', '6b512a45bd8ca759e655c2d551dec2d9', 'be403ce415609008605d63396869ed2e', 'e3360926180928287a4d96973e10926a', '3752b697cae9f81a9d5ffe44dac58e7a', 'a0ac3589042f9e99d31b521b5b56ac06', 'f5675cb89a120e225ca8929b64a5af79', '6d569809f11e6e53918dfb1609fb0d83', '7817ae2d392ba4f6cb9104fbe70b6274', 'cd382fdc26516c634b2314aa870bfe80', '1cdd6def1e3099a9763938457cf0b4be', '0186f4811c9d089707d9dc7460160d88', 'e9d628364cf51891028163e0cfca628c', '9effeea56c413b92340b89d1240769c1', '7a0a36bc6119e3d78474e6c8ca875725', '374d5401159d9bf39ce20b395d82c0b4', '7fa4634ab59a7832bc877fef162eacaa', 'ee182a14e532b122f40d561d87eb2136', 'c0a0956a39319920d02c5c4eb30c5e10', '441265c6b4598e9bcd10bc10eb6293cc', '8ae069858aecbad846f4d69d405f9bd6', 'b13961504ea859ff34a150bc19fed335', '476f0dfb144aee7d5422dcc3b2b97a9f', 'bc93d165d96e4fa4883f130b3f7b9885', 'ac9d05fa3f4fafb474fb96f9f8ab71ac', '1438f19e07c389b47fd5219ca62f9f0a', '479200a381febadfd767615fbe77c3ea', 'fc6a695ba44f4b64425c522f590bac48', '046bac77a58c1be84a6418904e755280', '5c083ab21fc57c0954468ab46aa7fb16', 'cca735c397880e88192e97d68b97754e', 'a579110fe1e670847d9d146404597750', '8dbedd97ed2b7b01525d6800d52ae073', '004dd32d9cd167d9cc31c13b704498af', 'cdf40333dfe2afec1a4c54d9eeb1ec7a', '09d4be69a2330cd49298bf30d29cc4e5', 'f73951fddf77034c9fd44cb19f5fe6b5', 'aef75d4c390d838aabe56e2d601b6a13', 'b3a2dc7547bc580c6f3923c61db42051', '774c9b631a29f191836b1078a6c3a67c', 
'836ca5d73c88ad94fb980ca3e5e65da7', 'bfdbe56fb7fc4d7b3d151370f897d503', '06ef49a7b77e883f089cfdd80642d6f0', '8e25584bd03155d24a2adc00517a38e8', 'dc2ec851fcbf594f11b023387ac15003', 'a0150f4d6d9f6f3b2b5a240b099df000']
# DataFrame.drop (default errors='raise') raises KeyError if any blacklisted ID is not in train.csv.
train_labels=train_labels.drop(black_list)

In [None]:
def create_folds(train_labels, n_fold, debug=0, random_state=42):
    """Assign every image to one stratified cross-validation fold.

    Args:
        train_labels: DataFrame indexed by ``image_id`` with an ``isup_grade`` column
            (used as the stratification target).
        n_fold: number of folds (``StratifiedKFold.n_splits``).
        debug: when truthy, ``display`` the resulting frame (needs IPython).
        random_state: seed for the fold shuffle; defaults to the previously
            hard-coded value 42 so existing fold assignments are unchanged.

    Returns:
        A copy of ``train_labels`` with the index moved into a regular column
        (so the frame has a 0..n-1 RangeIndex) plus a ``test_fold`` column giving
        the fold in which each row is used for validation.
    """
    # reset_index(drop=False) moves image_id from the index into a column and gives the frame
    # a 0..n-1 RangeIndex, so the positional indices returned by skf.split can be used with .loc.
    input_DF = train_labels.copy().reset_index(drop=False)

    skf = StratifiedKFold(n_splits=n_fold, shuffle=True, random_state=random_state)
    for f, (_train_idx, test_idx) in enumerate(skf.split(input_DF, input_DF['isup_grade'])):
        # The test folds partition the rows, so each row is labelled exactly once.
        input_DF.loc[test_idx, 'test_fold'] = f

    if debug:
        display(input_DF)

    return input_DF



def get_train_val_list(fold, labels_DF, debug=0):
    """Split image IDs into training and validation lists for one fold.

    Args:
        fold: fold number whose rows form the validation set.
        labels_DF: DataFrame with ``image_id`` and ``test_fold`` columns.
        debug: when truthy, print the first two IDs and the size of each list.

    Returns:
        ``(train_list, val_list)``: the ``image_id`` values with ``test_fold != fold``
        and ``test_fold == fold`` respectively, in DataFrame row order.
    """
    train_list = list(labels_DF.loc[labels_DF['test_fold'] != fold, 'image_id'])
    val_list = list(labels_DF.loc[labels_DF['test_fold'] == fold, 'image_id'])

    if debug:
        print('train')
        for im_id in train_list[0:2]:
            print(im_id)
        print("Num of samples: ", len(train_list), "\n")

        print('test')
        # Previously iterated over an undefined `test_list`, raising NameError whenever debug was on.
        for im_id in val_list[0:2]:
            print(im_id)
        print("Num of samples: ", len(val_list))

    return train_list, val_list


In [None]:
def read_tfrecord(example):
    """Parse one serialized TFRecord into ``(image, label, im_ID)``.

    The record must contain a PNG-encoded ``image`` (bytes), an int64 scalar ``label``
    and a bytes ``im_ID``. The image is decoded as 3 channels, its channel axis is
    reversed, it is scaled to float32 in [0, 1], and it is given the static shape
    ``[IMAGE_DIM, IMAGE_DIM, 3]`` using the module-level ``IMAGE_DIM``.

    Returns:
        ``(image, label, im_ID)`` tensors: float32 image, int64 label, string id.
    """
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),  # PNG bytes (tf.string = bytestring)
        "label": tf.io.FixedLenFeature([], tf.int64),   # shape [] means scalar
        "im_ID": tf.io.FixedLenFeature([], tf.string),  # image id as bytes
    }
    example = tf.io.parse_single_example(example, features)

    # channels=3 forces a 3-channel decode. With channels=0 the file's own channel count was
    # kept, so a grayscale or RGBA PNG made the reshape below fail.
    image = tf.image.decode_png(example['image'], channels=3)
    # Reverse the channel order. NOTE(review): presumably the shards were written in BGR
    # (e.g. with OpenCV) and this restores RGB — confirm against the shard writer.
    image = tf.reverse(image, axis=[-1])

    image = tf.cast(image, tf.float32) / 255.0        # floats in [0, 1]
    image = tf.reshape(image, [IMAGE_DIM, IMAGE_DIM, 3])  # explicit static size needed for TPU

    return image, example['label'], example['im_ID']

def load_dataset(filenames, ordered = False):
    """Stream TFRecord files into a dataset of parsed ``(image, label, im_ID)`` triples.

    Args:
        filenames: TFRecord path(s). Files are read in parallel (``num_parallel_reads=AUTO``).
        ordered: when False, ``experimental_deterministic`` is switched off so elements
            are emitted as soon as they are ready (faster, but not reproducible in order).
            When True the default options are kept.

    Returns:
        A ``tf.data.Dataset`` mapped through ``read_tfrecord``.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads = AUTO)
    return records.with_options(options).map(read_tfrecord, num_parallel_calls = AUTO)

def data_augment(image, label):
    """Randomly flip and hue-shift one training image; the label is passed through unchanged.

    Args:
        image: 3-channel float image tensor (``read_tfrecord`` yields values in [0, 1]).
        label: the example's label.

    Returns:
        ``(augmented_image, label)``.
    """
    image = tf.image.random_flip_left_right(image, seed=SEED)
    # The result of random_hue used to be discarded, so the hue jitter never took effect.
    image = tf.image.random_hue(image, max_delta=0.2, seed=SEED)
    return image, label


def get_training_dataset(dataset,im_ID_list,do_aug=True):
    """Build the infinite, shuffled, batched training pipeline for one fold.

    Args:
        dataset: ``tf.data.Dataset`` of ``(image, label, im_ID)`` triples.
        im_ID_list: image IDs in the training split; records with other IDs are dropped.
        do_aug: map ``data_augment`` over the examples when True.

    Returns:
        A repeating dataset of ``(image, label)`` batches of ``BATCH_SIZE``
        (incomplete batches dropped), shuffled with a 512-element buffer and prefetched.
    """
    def _in_split(image, label, im_ID):
        # Membership test: compare this record's id against every id in im_ID_list.
        return tf.reduce_any(tf.equal(im_ID, im_ID_list))

    def _drop_id(image, label, im_ID):
        return [image, label]

    pipeline = dataset.filter(_in_split).map(_drop_id)
    if do_aug:
        pipeline = pipeline.map(data_augment, num_parallel_calls=AUTO)

    return (
        pipeline
        .repeat()  # training must iterate for several epochs
        .shuffle(512)
        .batch(BATCH_SIZE, drop_remainder=True)
        .prefetch(AUTO)
    )


def get_validation_dataset(val_dataset,im_ID_list):
    """Build the finite validation pipeline for one fold.

    Keeps only records whose ``im_ID`` is in ``im_ID_list`` and drops the id. Batches to
    ``BATCH_SIZE`` with ``drop_remainder=True``, so up to ``BATCH_SIZE - 1`` validation
    examples are never seen. Finally caches the batches and prefetches.

    NOTE(review): ``cache()`` without a filename keeps every decoded float32 batch in
    memory; at large IMAGE_DIM this can be very large — confirm the host RAM suffices.

    Returns:
        A dataset of ``(image, label)`` batches.
    """
    def _in_split(image, label, im_ID):
        return tf.reduce_any(tf.equal(im_ID, im_ID_list))

    def _drop_id(image, label, im_ID):
        return [image, label]

    batched = (
        val_dataset
        .filter(_in_split)
        .map(_drop_id)
        .batch(BATCH_SIZE, drop_remainder=True)
    )
    return batched.cache().prefetch(AUTO)

def display_9_images_from_dataset(dataset):
    """Plot the first nine ``(image, label, im_ID)`` elements of ``dataset`` in a 3x3 grid.

    Each panel shows ``image.numpy()`` with the UTF-8-decoded ``im_ID`` as its title,
    drawn in white (``color='w'``). Iteration stops after the ninth element.
    """
    plt.figure(figsize=(13, 13))
    for position, (image, _label, im_ID) in enumerate(dataset, start=1):
        plt.subplot(3, 3, position)
        plt.axis('off')
        plt.imshow(image.numpy())
        plt.title(im_ID.numpy().decode("utf-8"), fontsize=14, color='w')
        if position == 9:
            break
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()

In [None]:
# Learning-rate schedule parameters read by warm_up.
LR_START = 0.0005
LR_MAX = 0.00025 * strategy.num_replicas_in_sync  # peak rate scales with the replica count
LR_MIN = 0.00001
LR_RAMPUP_EPOCHS = 5
LR_SUSTAIN_EPOCHS = 0
LR_EXP_DECAY = .8

def warm_up(epoch):
    """Return the learning rate for ``epoch`` (for ``LearningRateScheduler``).

    - epochs [0, LR_RAMPUP_EPOCHS): linear from LR_START towards LR_MAX;
    - the following LR_SUSTAIN_EPOCHS epochs: constant LR_MAX;
    - afterwards: LR_MIN + (LR_MAX - LR_MIN) * LR_EXP_DECAY ** (epochs past those phases).

    Note: with a single replica LR_MAX (0.00025) is below LR_START (0.0005), so the
    "ramp-up" phase actually lowers the rate.
    """
    if epoch < LR_RAMPUP_EPOCHS:
        slope = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS
        return slope * epoch + LR_START
    if epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        return LR_MAX
    decay_epochs = epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS
    return (LR_MAX - LR_MIN) * LR_EXP_DECAY**decay_epochs + LR_MIN
    

def define_model(IMAGE_DIM):
    """Build and compile an EfficientNet-B0 classifier with 6 logit outputs.

    Args:
        IMAGE_DIM: side length of the square 3-channel input, i.e. input shape
            ``(IMAGE_DIM, IMAGE_DIM, 3)``.

    Returns:
        A compiled ``tf.keras.Sequential`` model. It is an ImageNet-initialised, fully
        trainable EfficientNet-B0 backbone with global max pooling, followed by a
        ``Dense(6)`` head (L2 0.01, no activation, so it outputs logits). It is compiled
        with Adam, sparse categorical cross-entropy (``from_logits=True``) and sparse
        categorical accuracy.
    """
    backbone = efn.EfficientNetB0(
        input_shape=(IMAGE_DIM, IMAGE_DIM, 3),
        weights='imagenet',
        include_top=False,
        pooling='max',
    )
    backbone.trainable = True

    head = tf.keras.layers.Dense(6, kernel_regularizer=tf.keras.regularizers.l2(0.01))
    model = tf.keras.Sequential([backbone, head])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['sparse_categorical_accuracy'],
    )
    return model


def run_training(model,train_batches,test_batches,n_epochs,steps_per_epoch):
    """Fit ``model`` with the ``warm_up`` LR schedule and early stopping, and print the wall time.

    Args:
        model: compiled Keras model.
        train_batches: training dataset (repeating, hence ``steps_per_epoch``).
        test_batches: dataset passed to ``fit`` as ``validation_data``.
        n_epochs: maximum number of epochs.
        steps_per_epoch: number of training batches per epoch.

    Returns:
        ``(model, history)``: the same model object (trained in place) and the Keras ``History``.
    """
    tic = time.perf_counter()

    lr_callback = tf.keras.callbacks.LearningRateScheduler(warm_up, verbose=True)
    # Stop after 3 epochs without a val_loss improvement of at least 0.001.
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0.001, patience=3)

    # Validate on the dataset passed in. This previously read the notebook-global
    # `val_batches` and silently ignored the `test_batches` argument.
    history = model.fit(
        train_batches,
        epochs=n_epochs,
        validation_data=test_batches,
        steps_per_epoch=steps_per_epoch,
        callbacks=[lr_callback, early_stop],
    )

    dt = time.perf_counter() - tic
    print(f"Operation took {dt//3600:0.0f} hours {(dt-((dt//3600)*3600))//60:0.0f} minutes {dt%60:0.4f} seconds\n")

    return model, history

def save_json_and_weights(model_name,fold,model=None):
    """Write a model's architecture (JSON) and weights (HDF5) for one fold.

    Files are written to the current working directory as
    ``<model_name>_fold<fold>.json`` and ``<model_name>_fold<fold>.h5``.

    Args:
        model_name: prefix for both file names.
        fold: fold number embedded in the file names.
        model: Keras model to save. When omitted, falls back to the notebook-global
            ``efnet_b0``, as the original two-argument calls expect. Pass the model
            explicitly to avoid depending on hidden kernel state.
    """
    if model is None:
        model = efnet_b0  # legacy behaviour: save whatever the global model currently is

    json_file = model_name + '_fold' + str(fold) + '.json'
    weight_file = model_name + '_fold' + str(fold) + '.h5'

    # Serialize before opening so a failing to_json() leaves no empty file behind.
    model_json = model.to_json()
    with open(json_file, "w") as json_handle:
        json_handle.write(model_json)

    model.save_weights(weight_file)
    print("Saved "+model_name+" model to disk")
    
    
def load_json_and_weights(model_name,fold):
    """Rebuild a model from ``<model_name>_fold<fold>.json`` and load ``<model_name>_fold<fold>.h5``.

    The rebuilt model is compiled with Adam, sparse categorical cross-entropy
    (``from_logits=True``) and sparse categorical accuracy. If the weight file is
    missing, the model is still returned with freshly initialised weights and a
    warning is printed.

    Raises:
        OSError (e.g. FileNotFoundError): if the JSON file cannot be opened.
    """
    json_file = model_name + '_fold' + str(fold) + '.json'
    weight_file = model_name + '_fold' + str(fold) + '.h5'

    with open(json_file, "r") as json_handle:
        loaded_model_json = json_handle.read()

    loaded_model = model_from_json(loaded_model_json)
    loaded_model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['sparse_categorical_accuracy']
    )

    if os.path.exists(weight_file):
        loaded_model.load_weights(weight_file)
        print("Loaded "+model_name+" model from disk")
    else:
        # Previously skipped silently while still reporting "Loaded ... from disk",
        # so an untrained model could be evaluated without notice.
        print("WARNING: weight file " + weight_file + " not found; " + model_name + " has untrained weights")

    return loaded_model

In [None]:
# Debug switch: when truthy, build the fold-0 validation pipeline and count its examples.
check_DS=0

n_fold=5
fold=0
IMAGE_DIM=1536
BATCH_SIZE = 8 * strategy.num_replicas_in_sync


if check_DS:
    input_DF=create_folds(train_labels,n_fold)
    train_list, val_list = get_train_val_list(fold, input_DF, debug=0)

    val_batches = get_validation_dataset(load_dataset(files,ordered=False),im_ID_list= val_list).unbatch()

    # Count by exhausting the dataset; this only terminates for a finite (non-repeated) dataset.
    # The old `for i, ...: pass; print(i)` printed count - 1 and raised NameError when empty.
    n_val_examples = sum(1 for _ in val_batches)
    print(n_val_examples)

In [None]:
# Training configuration. n_fold, BATCH_SIZE and IMAGE_DIM repeat values set in the debug cell above.
n_fold=5
debug=0
aug=True
input_dim= 1536  # NOTE(review): not referenced anywhere in this notebook; IMAGE_DIM is the one used
BATCH_SIZE = 8 * strategy.num_replicas_in_sync  # 8 examples per replica
IMAGE_DIM = 1536
model_name='efnet_b0'


n_epochs=25



input_DF=create_folds(train_labels,n_fold)

# Only fold 0 is trained here; widen the range to run every fold.
# After the loop, `fold`, `val_list`, `val_batches` and `efnet_b0` remain defined and later cells use them.
for fold in range(0,1):
    print("\n\nFOLD "+str(fold)+"\n\n")
    tic = time.perf_counter()

    train_list, val_list = get_train_val_list(fold, input_DF, debug)
    # Floor division: a partial last batch is not counted (batches are built with drop_remainder=True).
    steps_per_epoch = len(train_list)//BATCH_SIZE
    
    train_batches = get_training_dataset(load_dataset(files,ordered=False),im_ID_list= train_list, do_aug=aug)
    val_batches   = get_validation_dataset(load_dataset(files,ordered=False),im_ID_list= val_list)
    
    print("Starting training")
    # Create the model inside the strategy scope so its variables belong to the
    # selected distribution strategy (TPU replicas when a TPU was detected).
    with strategy.scope():
        efnet_b0 = define_model(IMAGE_DIM)
    
    efnet_b0,history= run_training(efnet_b0,train_batches,val_batches,n_epochs,steps_per_epoch)

    toc = time.perf_counter()
    dt  = toc-tic
    print(f"Training took {dt//3600:0.0f} hours {(dt-((dt//3600)*3600))//60:0.0f} minutes {dt%60:0.4f} seconds\n")
    
    # save_json_and_weights saves the global `efnet_b0` assigned above.
    save_json_and_weights(model_name,fold)


In [None]:
# Sanity check: reload the saved fold model from disk and evaluate it on the validation set.
# Relies on `model_name`, `fold` and `val_batches` left over from the training cell.
loaded_model= load_json_and_weights(model_name,fold)

# No `batch_size` here: val_batches is an already-batched tf.data dataset, and Keras
# documents that batch_size must not be specified for dataset inputs.
results = loaded_model.evaluate(val_batches)
print("test loss, test acc:", results)

In [None]:
# Time one full prediction pass over the validation fold.
tic = time.perf_counter()

# Rebuild the validation pipeline with ordered=True, which keeps tf.data's default
# deterministic element order. The labels cell re-iterates this same dataset, so
# predictions and labels are expected to line up.
val_batches_ordered = get_validation_dataset(load_dataset(files,ordered=True),im_ID_list= val_list)

# One row of 6 scores per example (the model's final Dense layer has no activation, i.e. logits).
pred = loaded_model.predict(val_batches_ordered)
# NOTE: despite its name, pred_probs holds argmax class indices, not probabilities.
pred_probs= np.argmax(pred, axis = 1)

toc = time.perf_counter()
dt  = toc-tic
print(f"Prediction took {dt//3600:0.0f} hours {(dt-((dt//3600)*3600))//60:0.0f} minutes {dt%60:0.4f} seconds\n")
    

In [None]:
# Ground-truth labels, one per example, read from the same ordered validation pipeline
# used for prediction. Relies on it yielding examples in the same order as during predict.
labels = [label for _image, label in val_batches_ordered.unbatch()]

actuals = np.array(labels)

In [None]:
# Quadratic-weighted Cohen's kappa between true and predicted class indices, then the 6x6 confusion matrix.
kappa = cohen_kappa_score(actuals, pred_probs, weights='quadratic')
print("\nValid Cohen's Kappa : {}".format(kappa))

conf_mat = tf.math.confusion_matrix(actuals, pred_probs, num_classes=6)
print(conf_mat.numpy())