# Preprocessing

In [None]:
import tensorflow as tf
import math, re, os, random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets
from tensorflow import keras
from functools import partial
from sklearn.model_selection import train_test_split
import tensorflow.keras.backend as K
from sklearn.model_selection import KFold


from kaggle_datasets import KaggleDatasets
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

print("Tensorflow version " + tf.__version__)

# Try to attach to a TPU. If anything in the TPU setup fails, fall back to
# the default distribution strategy so the notebook still runs.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as err:
    # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
    # SystemExit still propagate; the reason for the fallback is reported.
    print('TPU not available, using default strategy:', err)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

# Let tf.data pick parallelism / prefetch buffer sizes.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Root directory of the competition data; TFRecord files are globbed below it.
GCS_PATH = "../input/cassava-leaf-disease-classification"
# Global batch size: 16 examples per replica.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
IMAGE_SIZE = [512, 512]
CLASSES = ['0', '1', '2', '3', '4']
# NOTE(review): EPOCHS is reassigned to 10 right before training further down.
EPOCHS = 25


SEED = 752
# NOTE(review): SKIP_VALIDATION and TTA_NUM are not referenced anywhere in the
# visible part of this file.
SKIP_VALIDATION = False
TTA_NUM = 5

# Seed the Python, NumPy and TensorFlow random generators.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)


In [None]:
def decode_image(image):
    """Decode JPEG bytes into a float32 image of shape [*IMAGE_SIZE, 3].

    Pixel values are divided by 255.0 after the cast to float32.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    return tf.reshape(pixels / 255.0, [*IMAGE_SIZE, 3])

def read_tfrecord(example, labeled):
    """Parse one serialized TFRecord example.

    Args:
        example: serialized example holding an "image" bytes feature plus
            either "target" (labeled) or "image_name" (unlabeled).
        labeled: Python bool selecting which of the two schemas to parse.

    Returns:
        (image, label) with an int32 label when `labeled`, otherwise
        (image, image_name).
    """
    if labeled:
        second_key, second_dtype = "target", tf.int64
    else:
        second_key, second_dtype = "image_name", tf.string
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        second_key: tf.io.FixedLenFeature([], second_dtype),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])
    if not labeled:
        return image, parsed['image_name']
    return image, tf.cast(parsed['target'], tf.int32)

def load_dataset(filenames, labeled=True, ordered=False):
    """Build a parsed tf.data pipeline over the given TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: forwarded to `read_tfrecord` to pick the record schema.
        ordered: when False, deterministic ordering is switched off so
            records are used as soon as they stream in (faster).

    Returns:
        A dataset of parsed (image, label) or (image, image_name) pairs.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False  # disable order, increase speed
    parse_record = partial(read_tfrecord, labeled=labeled)
    return (
        # Reads from multiple files are interleaved automatically.
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
        .with_options(options)
        .map(parse_record, num_parallel_calls=AUTOTUNE)
    )



In [None]:
# Split the training TFRecord *files* (not individual images) into a training
# list and a validation list; 35% of the files are held out, with a fixed
# random_state so the split is reproducible.
TRAINING_FILENAMES, VALID_FILENAMES = train_test_split(
    tf.io.gfile.glob(GCS_PATH + '/train_tfrecords/ld_train*.tfrec'),
    test_size=0.35, random_state=5
)

# All test TFRecord files under the same data root.
TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test_tfrecords/ld_test*.tfrec')

In [None]:
def random_blockout(img, sl=0.1, sh=0.2, rl=0.4):
    """Randomly zero out one rectangular patch of the image (random erasing).

    The decision to erase is drawn with TensorFlow ops and applied through
    `tf.cond`, so it is re-sampled for every image. A Python-level
    `random.random()` would be evaluated only once when `Dataset.map` traces
    this function, freezing the choice for the whole dataset.

    Args:
        img: image tensor of spatial size IMAGE_SIZE with 3 channels.
        sl: lower bound factor for the erased area (fraction of the image).
        sh: upper bound factor for the erased area (fraction of the image).
        rl: aspect-ratio factor used to derive the side-length bounds.

    Returns:
        A tensor with the same dtype as `img`; with probability 0.75 one
        random rectangle is set to zero, otherwise `img` is returned as is.
    """
    def _erase():
        w, h, c = IMAGE_SIZE[0], IMAGE_SIZE[1], 3
        origin_area = tf.cast(h*w, tf.float32)

        # Smallest / largest side length of the erased rectangle.
        e_size_l = tf.cast(tf.round(tf.sqrt(origin_area * sl * rl)), tf.int32)
        e_size_h = tf.cast(tf.round(tf.sqrt(origin_area * sh / rl)), tf.int32)

        # Never let the rectangle exceed the image itself.
        e_height_h = tf.minimum(e_size_h, h)
        e_width_h = tf.minimum(e_size_h, w)

        erase_height = tf.random.uniform(shape=[], minval=e_size_l, maxval=e_height_h, dtype=tf.int32)
        erase_width = tf.random.uniform(shape=[], minval=e_size_l, maxval=e_width_h, dtype=tf.int32)

        erase_area = tf.zeros(shape=[erase_height, erase_width, c])
        erase_area = tf.cast(erase_area, tf.uint8)

        # Random placement: pad the zero patch with ones up to the full image size.
        pad_h = h - erase_height
        pad_top = tf.random.uniform(shape=[], minval=0, maxval=pad_h, dtype=tf.int32)
        pad_bottom = pad_h - pad_top

        pad_w = w - erase_width
        pad_left = tf.random.uniform(shape=[], minval=0, maxval=pad_w, dtype=tf.int32)
        pad_right = pad_w - pad_left

        erase_mask = tf.pad([erase_area], [[0,0],[pad_top, pad_bottom], [pad_left, pad_right], [0,0]], constant_values=1)
        erase_mask = tf.squeeze(erase_mask, axis=0)
        erased_img = tf.multiply(tf.cast(img,tf.float32), tf.cast(erase_mask, tf.float32))

        return tf.cast(erased_img, img.dtype)

    # Per-example draw: erase when p >= 0.25, otherwise pass the image through.
    p = tf.random.uniform(shape=[], minval=0.0, maxval=1.0, dtype=tf.float32)
    return tf.cond(p >= 0.25, _erase, lambda: tf.cast(img, img.dtype))




def data_augment(image, label):
    """Apply a random horizontal flip followed by random block-out; the label passes through."""
    # Thanks to the dataset.prefetch(AUTO) statement in the training pipeline this happens essentially for free on TPU.
    # Data pipeline code is executed on the "CPU" part of the TPU while the TPU itself is computing gradients.
    flipped = tf.image.random_flip_left_right(image)
    return random_blockout(flipped), label
def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Return the 3x3 transform matrix that maps pixel indices.

    `rotation` and `shear` are given in degrees and converted to radians here.
    Each argument is expected to be a length-1 float32 tensor (that is what
    `transform` passes in). The result is
    (rotation @ shear) @ (zoom @ shift).
    """
    def _as_3x3(*entries):
        # Nine length-1 tensors in row-major order -> one 3x3 matrix.
        return tf.reshape(tf.concat(list(entries), axis=0), [3, 3])

    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    # Degrees -> radians.
    rot_rad = math.pi * rotation / 180.
    shear_rad = math.pi * shear / 180.

    cos_r, sin_r = tf.math.cos(rot_rad), tf.math.sin(rot_rad)
    cos_s, sin_s = tf.math.cos(shear_rad), tf.math.sin(shear_rad)

    rotation_matrix = _as_3x3(cos_r, sin_r, zero,
                              -sin_r, cos_r, zero,
                              zero, zero, one)
    shear_matrix = _as_3x3(one, sin_s, zero,
                           zero, cos_s, zero,
                           zero, zero, one)
    zoom_matrix = _as_3x3(one / height_zoom, zero, zero,
                          zero, one / width_zoom, zero,
                          zero, zero, one)
    shift_matrix = _as_3x3(one, zero, height_shift,
                           zero, one, width_shift,
                           zero, zero, one)

    return K.dot(K.dot(rotation_matrix, shear_matrix), K.dot(zoom_matrix, shift_matrix))
def transform(image,label):
    """Apply one random rotation/shear/zoom/shift to a single image.

    Works by listing every destination pixel index, mapping it through the
    matrix from `get_mat` to a source index, and gathering the source pixels
    (nearest-neighbour via integer cast). The label is returned unchanged.
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated, sheared, zoomed, and shifted
    DIM = IMAGE_SIZE[0]
    XDIM = DIM%2 #fix for size 331
    
    # Random parameters, each a length-1 tensor drawn from tf.random.normal:
    # rotation/shear in degrees, zoom factors around 1.0, shifts in pixels.
    rot = 15. * tf.random.normal([1],dtype='float32')
    shr = 5. * tf.random.normal([1],dtype='float32') 
    h_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    w_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    h_shift = 16. * tf.random.normal([1],dtype='float32') 
    w_shift = 16. * tf.random.normal([1],dtype='float32') 
  
    # GET TRANSFORMATION MATRIX
    m = get_mat(rot,shr,h_zoom,w_zoom,h_shift,w_shift) 

    # LIST DESTINATION PIXEL INDICES
    # Coordinates are centred on the image; rows of `idx` are x, y and a
    # constant 1 (homogeneous coordinate), one column per pixel.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(m,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so the un-centred indices computed below stay inside the image.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES           
    # Undo the centring to get (row, col) indices into `image`.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image,tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3]),label


In [None]:
def get_training_dataset(TRAINING_FILENAMES):
    """Return the augmented, endlessly repeating, shuffled and batched training pipeline."""
    return (
        load_dataset(TRAINING_FILENAMES, labeled=True)
        .map(data_augment, num_parallel_calls=AUTOTUNE)
        .map(transform, num_parallel_calls=AUTOTUNE)
        .repeat()
        .shuffle(2048)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )
def get_validation_dataset(VALID_FILENAMES,ordered=False):
    """Return the labeled validation pipeline: batched, cached, prefetched (no augmentation)."""
    return (
        load_dataset(VALID_FILENAMES, labeled=True, ordered=ordered)
        .batch(BATCH_SIZE)
        .cache()
        .prefetch(AUTOTUNE)
    )
def get_test_dataset(TEST_FILENAMES,ordered=False):
    """Return the unlabeled test pipeline of (image, image_name) batches."""
    return (
        load_dataset(TEST_FILENAMES, labeled=False, ordered=ordered)
        .batch(BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )
def count_data_items(filenames):
    """Sum the per-file record counts encoded in TFRecord file names.

    Each file name is expected to contain "-<count>." (for example
    "ld_train00-1338.tfrec" holds 1338 records).

    Args:
        filenames: iterable of file names / paths.

    Returns:
        The total count as a NumPy integer; 0 (an integer, not 0.0) when
        `filenames` is empty, so that `//` on the result stays integral.

    Raises:
        AttributeError: if a file name does not contain the "-<count>." part.
    """
    # Compile once instead of once per file name.
    count_pattern = re.compile(r"-([0-9]*)\.")
    counts = [int(count_pattern.search(filename).group(1)) for filename in filenames]
    return np.sum(counts, dtype=np.int64)
# Image counts are derived from the "-<count>." part of each TFRecord file
# name, not by reading the records themselves.
NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
NUM_VALIDATION_IMAGES = count_data_items(VALID_FILENAMES)
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)

print('Dataset: {} training images, {} validation images, {} (unlabeled) test images'.format(
    NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))


# Efficient Net

In [None]:
import sys
# Make the vendored packages shipped as input datasets importable.
for package_path in ('../input/efficientnet/', '../input/kerasapplications'):
    sys.path.append(package_path)

In [None]:
# NOTE(review): presumably this import is needed for its side effects
# (registering EfficientNet custom layers) before the .h5 file can be
# deserialized — confirm before removing it.
import efficientnet.tfkeras
# from tensorflow.keras.models import load_model

# Load the previously trained EfficientNet checkpoint.
efficientnet_model = tf.keras.models.load_model('../input/cassava-leaf-disease-training/effcient_net.h5')

# Dense Net

In [None]:
# Load the previously trained DenseNet checkpoint.
densenet_model = tf.keras.models.load_model('../input/cassava-leaf-disease-training/dense_net.h5')

# Resnet50

In [None]:
# range(4, 5) yields only i == 4, so a single checkpoint (fold-4) is loaded.
resnet50_models = [tf.keras.models.load_model(f'../input/cassava-leaf-disease-resnet50/resnet50/fold-{i}.h5') for i in range(4,5)]

In [None]:
# Give every layer a name of the form resnet50_<model index>_<layer index>.
# NOTE(review): presumably done to avoid layer-name clashes once several
# loaded models are combined into one graph — confirm.
for model_index, model in enumerate(resnet50_models):
    for i, layer in enumerate(model.layers):
        layer._name = 'resnet50_' + str(model_index) + "_" + str(i)

# Resnet 101

In [None]:
# range(4, 5) yields only i == 4, so a single checkpoint (fold-4) is loaded.
resnet101_models = [tf.keras.models.load_model(f'../input/cassava-leaf-disease-resnet101/resnet101/fold-{i}.h5') for i in range(4,5)]

In [None]:
# Give every layer a name of the form resnet101_<model index>_<layer index>.
# NOTE(review): presumably done to avoid layer-name clashes once several
# loaded models are combined into one graph — confirm.
for model_index, model in enumerate(resnet101_models):
    for i, layer in enumerate(model.layers):
        layer._name = 'resnet101_' + str(model_index) + "_" + str(i)

# ResNext 101

In [None]:
# range(4, 5) yields only i == 4, so a single checkpoint (fold-4) is loaded.
resnext101_models = [tf.keras.models.load_model(f'../input/cassava-leaf-disease-resnext101/resnext101/fold-{i}.h5') for i in range(4,5)]

In [None]:
# Give every layer a name of the form resnext101_<model index>_<layer index>.
# NOTE(review): presumably done to avoid layer-name clashes once several
# loaded models are combined into one graph — confirm.
for model_index, model in enumerate(resnext101_models):
    for i, layer in enumerate(model.layers):
        layer._name = 'resnext101_' + str(model_index) + "_" + str(i)

# Ensemble

In [None]:
# Members of the ensemble. The ResNet50 and ResNet101 models are loaded above
# but deliberately left out here:
# all_models.extend(resnet50_models)
# all_models.extend(resnet101_models)
all_models = [efficientnet_model, densenet_model, *resnext101_models]

# Freeze layers

In [None]:
# Mark every layer of every base model as non-trainable so that only the
# layers added on top of the ensemble are updated during training.
for model in all_models:
    for layer in model.layers:
        layer.trainable = False

# Ensembled Model

In [None]:
with strategy.scope():
    # One shared input feeds every base model.
    input_ = tf.keras.layers.Input(shape=(512, 512, 3))

    ensemble_outputs = [model(input_) for model in all_models]

    # Concatenate the base-model outputs and learn a small dense head on top.
    merge = tf.keras.layers.Dense(100, activation='relu', name="layer_dense_prelim")(
        tf.keras.layers.concatenate(ensemble_outputs))
    output = tf.keras.layers.Dense(
        len(CLASSES), activation='softmax', name="ensembled_output")(merge)

    ensembled_model = tf.keras.models.Model(inputs=input_, outputs=output)

# Bi Tempered Loss

In [None]:
def for_loop(num_iters, body, initial_args):
  """Runs a simple for-loop with given body and initial_args.
  Args:
    num_iters: Maximum number of iterations.
    body: Body of the for-loop; called as `body(*args)` and must return the
      args for the next iteration.
    initial_args: Args to the body for the first iteration.
  Returns:
    Output of the final iteration, or `initial_args` unchanged when
    `num_iters` is 0 (previously this case raised UnboundLocalError).
  """
  outputs = initial_args
  for _ in range(num_iters):
    outputs = body(*outputs)
  return outputs


def log_t(u, t):
  """Compute log_t for `u`.

  Uses the natural log when `t == 1.0`, otherwise
  (u**(1 - t) - 1) / (1 - t).
  """

  def _internal_log_t(u, t):
    return (u**(1.0 - t) - 1.0) / (1.0 - t)

  # TF2 spelling: `tf.math.log` (there is no `tf.log`). `partial` is the name
  # this file imports from functools; the `functools` module itself is not
  # imported.
  return tf.cond(
      tf.equal(t, 1.0), lambda: tf.math.log(u),
      partial(_internal_log_t, u, t))


def exp_t(u, t):
  """Compute exp_t for `u`.

  Uses the ordinary exponential when `t == 1.0`, otherwise
  relu(1 + (1 - t) * u) ** (1 / (1 - t)).
  """

  def _internal_exp_t(u, t):
    return tf.nn.relu(1.0 + (1.0 - t) * u)**(1.0 / (1.0 - t))

  # `partial` is the name this file imports from functools; the `functools`
  # module itself is not imported.
  return tf.cond(
      tf.equal(t, 1.0), lambda: tf.exp(u),
      partial(_internal_exp_t, u, t))


def compute_normalization_fixed_point(activations, t, num_iters=5):
  """Returns the normalization value for each example (t > 1.0).
  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    t: Temperature 2 (> 1.0 for tail heaviness).
    num_iters: Number of iterations to run the method.
  Return: A tensor of same rank as activation with the last dimension being 1.
  """

  # TF2 keyword is `keepdims` (the TF1 spelling `keep_dims` is rejected).
  mu = tf.reduce_max(activations, -1, keepdims=True)
  normalized_activations_step_0 = activations - mu
  shape_normalized_activations = tf.shape(normalized_activations_step_0)

  def iter_body(i, normalized_activations):
    # One fixed-point step: rescale the shifted activations by the current
    # partition estimate.
    logt_partition = tf.reduce_sum(
        exp_t(normalized_activations, t), -1, keepdims=True)
    normalized_activations_t = tf.reshape(
        normalized_activations_step_0 * tf.pow(logt_partition, 1.0 - t),
        shape_normalized_activations)
    return [i + 1, normalized_activations_t]

  _, normalized_activations_t = for_loop(num_iters, iter_body,
                                         [0, normalized_activations_step_0])
  logt_partition = tf.reduce_sum(
      exp_t(normalized_activations_t, t), -1, keepdims=True)
  return -log_t(1.0 / logt_partition, t) + mu


def compute_normalization_binary_search(activations, t, num_iters=10):
  """Returns the normalization value for each example (t < 1.0).
  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    t: Temperature 2 (< 1.0 for finite support).
    num_iters: Number of iterations to run the method.
  Return: A tensor of same rank as activation with the last dimension being 1.
  """
  # TF2 keyword is `keepdims` (the TF1 spelling `keep_dims` is rejected).
  mu = tf.reduce_max(activations, -1, keepdims=True)
  normalized_activations = activations - mu
  shape_activations = tf.shape(activations)
  # Number of classes whose shifted activation exceeds -1 / (1 - t).
  effective_dim = tf.cast(
      tf.reduce_sum(
          tf.cast(
              tf.greater(normalized_activations, -1.0 / (1.0 - t)), tf.int32),
          -1,
          keepdims=True), tf.float32)
  shape_partition = tf.concat([shape_activations[:-1], [1]], 0)
  lower = tf.zeros(shape_partition)
  upper = -log_t(1.0 / effective_dim, t) * tf.ones(shape_partition)

  def iter_body(i, lower, upper):
    # Bisection step: keep the half of [lower, upper] selected by whether
    # the probabilities at the midpoint sum to less than 1.
    logt_partition = (upper + lower)/2.0
    sum_probs = tf.reduce_sum(exp_t(
        normalized_activations - logt_partition, t), -1, keepdims=True)
    update = tf.cast(tf.less(sum_probs, 1.0), tf.float32)
    lower = tf.reshape(lower * update + (1.0 - update) * logt_partition,
                       shape_partition)
    upper = tf.reshape(upper * (1.0 - update) + update * logt_partition,
                       shape_partition)
    return [i + 1, lower, upper]

  _, lower, upper = for_loop(num_iters, iter_body, [0, lower, upper])
  logt_partition = (upper + lower)/2.0
  return logt_partition + mu


def compute_normalization(activations, t, num_iters=5):
  """Returns the normalization value for each example.

  Dispatches to the binary-search variant when `t < 1.0` and to the
  fixed-point variant otherwise.

  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    t: Temperature 2 (< 1.0 for finite support, > 1.0 for tail heaviness).
    num_iters: Number of iterations to run the method.
  Return: A tensor of same rank as activation with the last dimension being 1.
  """
  # `partial` is the name this file imports from functools; the `functools`
  # module itself is not imported.
  return tf.cond(
      tf.less(t, 1.0),
      partial(compute_normalization_binary_search, activations, t,
              num_iters),
      partial(compute_normalization_fixed_point, activations, t,
              num_iters))


def _internal_bi_tempered_logistic_loss(activations, labels, t1, t2):
  """Computes the Bi-Tempered logistic loss.

  `t1` and `t2` are compared with Python `==`, so they are expected to be
  plain Python floats here.

  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    labels: batch_size
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness).
  Returns:
    A loss tensor for robust loss.
  """
  # TF2 spellings: `tf.math.log` and `keepdims` (TF1's `tf.log` / `keep_dims`
  # are not available).
  if t2 == 1.0:
    normalization_constants = tf.math.log(
        tf.reduce_sum(tf.exp(activations), -1, keepdims=True))
    if t1 == 1.0:
      # NOTE(review): `normalization_constants` keeps a trailing dimension of
      # 1 while the reduce_sum below drops it — confirm the intended
      # broadcast for this branch.
      return normalization_constants + tf.reduce_sum(
          tf.multiply(labels, tf.math.log(labels + 1e-10) - activations), -1)
    else:
      shifted_activations = tf.exp(activations - normalization_constants)
      one_minus_t1 = (1.0 - t1)
      one_minus_t2 = 1.0
  else:
    one_minus_t1 = (1.0 - t1)
    one_minus_t2 = (1.0 - t2)
    normalization_constants = compute_normalization(
        activations, t2, num_iters=5)
    shifted_activations = tf.nn.relu(1.0 + one_minus_t2 *
                                     (activations - normalization_constants))

  if t1 == 1.0:
    return tf.reduce_sum(
        tf.multiply(
            tf.math.log(labels + 1e-10) -
            tf.math.log(tf.pow(shifted_activations, 1.0 / one_minus_t2)),
            labels),
        -1)
  else:
    beta = 1.0 + one_minus_t1
    logt_probs = (tf.pow(shifted_activations, one_minus_t1 / one_minus_t2) -
                  1.0) / one_minus_t1
    return tf.reduce_sum(
        tf.multiply(log_t(labels, t1) - logt_probs, labels) - 1.0 / beta *
        (tf.pow(labels, beta) -
         tf.pow(shifted_activations, beta / one_minus_t2)), -1)


def tempered_sigmoid(activations, t, num_iters=5):
  """Tempered sigmoid function.

  Implemented as a two-class tempered softmax over [0, activation] and
  returning the probability of the second class.

  Args:
    activations: Activations for the positive class for binary classification.
    t: Temperature tensor > 0.0.
    num_iters: Number of iterations to run the method.
  Returns:
    A probabilities tensor.
  """
  t = tf.convert_to_tensor(t)
  input_shape = tf.shape(activations)
  activations_2d = tf.reshape(activations, [-1, 1])
  internal_activations = tf.concat(
      [tf.zeros_like(activations_2d), activations_2d], 1)
  # TF2 spellings (`tf.math.log`, `keepdims`); `partial` is the name this
  # file imports from functools.
  normalization_constants = tf.cond(
      # pylint: disable=g-long-lambda
      tf.equal(t, 1.0),
      lambda: tf.math.log(
          tf.reduce_sum(tf.exp(internal_activations), -1, keepdims=True)),
      partial(compute_normalization, internal_activations, t,
              num_iters))
  internal_probabilities = exp_t(internal_activations - normalization_constants,
                                 t)
  one_class_probabilities = tf.split(internal_probabilities, 2, axis=1)[1]
  return tf.reshape(one_class_probabilities, input_shape)


def tempered_softmax(activations, t, num_iters=5):
  """Tempered softmax function.
  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    t: Temperature tensor > 0.0.
    num_iters: Number of iterations to run the method.
  Returns:
    A probabilities tensor.
  """
  t = tf.convert_to_tensor(t)
  # TF2 spellings (`tf.math.log`, `keepdims`); `partial` is the name this
  # file imports from functools.
  normalization_constants = tf.cond(
      tf.equal(t, 1.0),
      lambda: tf.math.log(
          tf.reduce_sum(tf.exp(activations), -1, keepdims=True)),
      partial(compute_normalization, activations, t, num_iters))
  return exp_t(activations - normalization_constants, t)


def bi_tempered_binary_logistic_loss(activations,
                                     labels,
                                     t1,
                                     t2,
                                     label_smoothing=0.0,
                                     num_iters=5):
  """Bi-Tempered binary logistic loss.

  The binary problem is rewritten as a two-class problem (class 0 gets a
  fixed logit of zero) and handed to `bi_tempered_logistic_loss`.

  Args:
    activations: A tensor containing activations for class 1.
    labels: A tensor with shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    label_smoothing: Label smoothing
    num_iters: Number of iterations to run the method.
  Returns:
    A loss tensor with the same shape as `labels`.
  """
  with tf.name_scope('binary_bitempered_logistic'):
    t1 = tf.convert_to_tensor(t1)
    t2 = tf.convert_to_tensor(t2)
    flat_labels = tf.reshape(labels, [-1, 1])
    flat_logits = tf.reshape(activations, [-1, 1])
    two_class_labels = tf.concat([1.0 - flat_labels, flat_labels], 1)
    two_class_logits = tf.concat(
        [tf.zeros_like(flat_logits), flat_logits], 1)
    per_example = bi_tempered_logistic_loss(
        two_class_logits, two_class_labels, t1, t2, label_smoothing,
        num_iters)
    return tf.reshape(per_example, tf.shape(labels))


def bi_tempered_logistic_loss(activations,
                              labels,
                              t1,
                              t2,
                              label_smoothing=0.0,
                              num_iters=5):
  """Bi-Tempered Logistic Loss with custom gradient.

  When both temperatures equal 1.0 this reduces to
  `tf.nn.softmax_cross_entropy_with_logits`.

  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    labels: A tensor with shape and dtype as activations.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    label_smoothing: Label smoothing parameter between [0, 1).
    num_iters: Number of iterations to run the method.
  Returns:
    A loss tensor.
  """
  with tf.name_scope('bitempered_logistic'):
    t1 = tf.convert_to_tensor(t1)
    t2 = tf.convert_to_tensor(t2)
    if label_smoothing > 0.0:
      num_classes = tf.cast(tf.shape(labels)[-1], tf.float32)
      labels = (
          1 - num_classes /
          (num_classes - 1) * label_smoothing) * labels + label_smoothing / (
              num_classes - 1)

    @tf.custom_gradient
    def _custom_gradient_bi_tempered_logistic_loss(activations):
      """Bi-Tempered Logistic Loss with custom gradient.
      Args:
        activations: A multi-dimensional tensor with last dim `num_classes`.
      Returns:
        A loss tensor, grad.
      """
      with tf.name_scope('gradient_bitempered_logistic'):
        probabilities = tempered_softmax(activations, t2, num_iters)
        loss_values = tf.multiply(
            labels,
            log_t(labels + 1e-10, t1) -
            log_t(probabilities, t1)) - 1.0 / (2.0 - t1) * (
                tf.pow(labels, 2.0 - t1) - tf.pow(probabilities, 2.0 - t1))

        def grad(d_loss):
          """Explicit gradient calculation.
          Args:
            d_loss: Infinitesimal change in the loss value.
          Returns:
            Loss gradient.
          """
          delta_probs = probabilities - labels
          forget_factor = tf.pow(probabilities, t2 - t1)
          delta_probs_times_forget_factor = tf.multiply(delta_probs,
                                                        forget_factor)
          # TF2 keyword is `keepdims` (not `keep_dims`).
          delta_forget_sum = tf.reduce_sum(
              delta_probs_times_forget_factor, -1, keepdims=True)
          escorts = tf.pow(probabilities, t2)
          escorts = escorts / tf.reduce_sum(escorts, -1, keepdims=True)
          derivative = delta_probs_times_forget_factor - tf.multiply(
              escorts, delta_forget_sum)
          return tf.multiply(d_loss, derivative)

        return loss_values, grad

    # `partial` is the name this file imports from functools; the `functools`
    # module itself is not imported.
    loss_values = tf.cond(tf.logical_and(tf.equal(t1, 1.0), tf.equal(t2, 1.0)),
                          partial(
                              tf.nn.softmax_cross_entropy_with_logits,
                              labels=labels,
                              logits=activations),
                          partial(
                              _custom_gradient_bi_tempered_logistic_loss,
                              activations))
    # The custom-gradient branch returns per-class values; sum them over the
    # class axis. The plain cross-entropy branch is already reduced.
    reduce_sum_last = lambda x: tf.reduce_sum(x, -1)
    loss_values = tf.cond(tf.logical_and(tf.equal(t1, 1.0), tf.equal(t2, 1.0)),
                          partial(tf.identity, loss_values),
                          partial(reduce_sum_last, loss_values))
    return loss_values


def sparse_bi_tempered_logistic_loss(activations, labels, t1, t2, num_iters=5):
  """Sparse Bi-Tempered Logistic Loss with custom gradient.

  When both temperatures equal 1.0 this reduces to
  `tf.nn.sparse_softmax_cross_entropy_with_logits`.

  Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    labels: A tensor with dtype of int32.
    t1: Temperature 1 (< 1.0 for boundedness).
    t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
    num_iters: Number of iterations to run the method.
  Returns:
    A loss tensor.
  """
  with tf.name_scope('sparse_bitempered_logistic'):
    t1 = tf.convert_to_tensor(t1)
    t2 = tf.convert_to_tensor(t2)
    num_classes = tf.shape(activations)[-1]

    @tf.custom_gradient
    def _custom_gradient_sparse_bi_tempered_logistic_loss(activations):
      """Sparse Bi-Tempered Logistic Loss with custom gradient.
      Args:
        activations: A multi-dimensional tensor with last dim `num_classes`.
      Returns:
        A loss tensor, grad.
      """
      with tf.name_scope('gradient_sparse_bitempered_logistic'):
        probabilities = tempered_softmax(activations, t2, num_iters)
        # TODO(eamid): Replace one hot with gather.
        loss_values = -log_t(
            tf.reshape(
                tf.gather_nd(probabilities,
                             tf.where(tf.one_hot(labels, num_classes))),
                tf.shape(activations)[:-1]), t1) - 1.0 / (2.0 - t1) * (
                    1.0 - tf.reduce_sum(tf.pow(probabilities, 2.0 - t1), -1))

        def grad(d_loss):
          """Explicit gradient calculation.
          Args:
            d_loss: Infinitesimal change in the loss value.
          Returns:
            Loss gradient.
          """
          delta_probs = probabilities - tf.one_hot(labels, num_classes)
          forget_factor = tf.pow(probabilities, t2 - t1)
          delta_probs_times_forget_factor = tf.multiply(delta_probs,
                                                        forget_factor)
          # TF2 keyword is `keepdims` (not `keep_dims`).
          delta_forget_sum = tf.reduce_sum(
              delta_probs_times_forget_factor, -1, keepdims=True)
          escorts = tf.pow(probabilities, t2)
          escorts = escorts / tf.reduce_sum(escorts, -1, keepdims=True)
          derivative = delta_probs_times_forget_factor - tf.multiply(
              escorts, delta_forget_sum)
          return tf.multiply(d_loss, derivative)

        return loss_values, grad

    # `partial` is the name this file imports from functools; the `functools`
    # module itself is not imported.
    loss_values = tf.cond(
        tf.logical_and(tf.equal(t1, 1.0), tf.equal(t2, 1.0)),
        partial(tf.nn.sparse_softmax_cross_entropy_with_logits,
                labels=labels, logits=activations),
        partial(_custom_gradient_sparse_bi_tempered_logistic_loss,
                activations))
    return loss_values

# Compile Model

In [None]:
# Sparse (integer-label) cross-entropy and accuracy; the labels produced by
# the input pipeline are int32 class ids.
# `learning_rate` is the supported keyword; `lr` is a deprecated alias that
# newer Keras releases no longer accept.
ensembled_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), 
    loss='sparse_categorical_crossentropy', 
    metrics=["sparse_categorical_accuracy"])

In [None]:
# Print the layer-by-layer summary of the ensembled model.
ensembled_model.summary()

# Training

In [None]:
# Training pipeline (augmented, repeating) and validation pipeline
# (default ordered=False).
train_dataset = get_training_dataset(TRAINING_FILENAMES)
val_dataset =get_validation_dataset(VALID_FILENAMES)

In [None]:
# Overrides the EPOCHS value set in the preprocessing cell.
EPOCHS = 10

# Whole batches only: images beyond the last full batch are not counted.
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
VALID_STEPS = NUM_VALIDATION_IMAGES // BATCH_SIZE

# No `monitor` argument is given, so the callback's default monitored
# quantity is used; training stops after 3 epochs without improvement.
earlystopping = tf.keras.callbacks.EarlyStopping( patience=3)

# Training Fit model

In [None]:
# steps_per_epoch is required because the training dataset repeats forever.
history = ensembled_model.fit(train_dataset,
                    steps_per_epoch=STEPS_PER_EPOCH,
                    epochs=EPOCHS, 
                    validation_data=val_dataset,
                    validation_steps=VALID_STEPS,callbacks = [earlystopping])

In [None]:
# Persist the trained ensemble to the working directory.
ensembled_model.save("./ensembled_model.h5")

# Training Plot

In [None]:
history_frame = pd.DataFrame(history.history)
# The model is compiled with the "sparse_categorical_accuracy" metric, so the
# history columns carry that name (plus the "val_" prefix for validation);
# there are no plain 'accuracy' / 'val_accuracy' columns to select.
history_frame.loc[:, ['sparse_categorical_accuracy', 'val_sparse_categorical_accuracy']].plot();

# Prediction

In [None]:
# ordered=True is required here: the images and the image ids are read in two
# separate passes over `test_ds`, so the pipeline must yield records in the
# same order both times or predictions get paired with the wrong ids.
test_ds = get_test_dataset(TEST_FILENAMES, ordered=True)
test_images_ds = test_ds.map(lambda image, idnum: image)

In [None]:
# Per-class outputs of the softmax head for every test image.
ensemble_predictions = ensembled_model.predict(test_images_ds)

In [None]:
print('Calculating predictions...')

# Predicted class = index of the largest output along the last axis.
predictions = np.argmax(ensemble_predictions, axis=-1)

# Submission

In [None]:
print('Generating submission file...')
# Second pass over test_ds, this time keeping only the image ids.
# NOTE(review): the ids are matched to `predictions` purely by position, so
# this assumes test_ds yields records in the same order as the prediction
# pass — verify that the dataset was built with deterministic ordering.
test_ids_ds = test_ds.map(lambda image, idnum: idnum).unbatch()
test_ids = next(iter(test_ids_ds.batch(NUM_TEST_IMAGES))).numpy().astype('U') # all in one batch
# Two-column CSV: image_id,label.
np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), fmt=['%s', '%d'], delimiter=',', header='image_id,label', comments='')


In [None]:
!head submission.csv