In [None]:
import tensorflow as tf

In [None]:
# Detect hardware and pick a matching tf.distribute strategy.
try:
    # tpu='local' asks for a TPU attached directly to this host rather than one
    # reached over the network.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
    print('Running on TPU ', tpu.master())
except ValueError:
    # No TPU could be resolved: fall back to the default strategy below.
    tpu = None

if tpu is not None:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # Stable API; tf.distribute.experimental.TPUStrategy is a deprecated alias.
    strategy = tf.distribute.TPUStrategy(tpu)
else:
    # Default strategy: a single device (CPU or one GPU).
    strategy = tf.distribute.get_strategy()

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
pip install patchify

In [None]:
# Import standard dependencies
import random
import math
import os
import math
import numpy as np
import cv2
from collections import Counter
import re
from patchify import patchify
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.model_selection import StratifiedKFold, KFold, train_test_split
import warnings
warnings.filterwarnings('ignore')


# image preprocessing
from tensorflow.keras.preprocessing.image import img_to_array, ImageDataGenerator, load_img

# to build model
from keras.models import Model, Sequential
from tensorflow.keras.layers import Resizing, Rescaling, Input, Conv2D, Dense, GlobalAveragePooling2D, Dropout, BatchNormalization, Lambda, Layer, Flatten
from tensorflow.keras.activations import softmax

# cost function / optimizer
from tensorflow.keras.optimizers import SGD, Adamax
from sklearn.metrics import confusion_matrix

In [None]:
# Seed every random number generator the notebook relies on so runs repeat.
SEED = 178
tf.keras.utils.set_random_seed(SEED)  # Keras-level seeding
for _seed_fn in (tf.random.set_seed, np.random.seed, random.seed):
    _seed_fn(SEED)  # explicit per-library seeding, in the same order as before

In [None]:
MAIN_DIR = "/kaggle/input/bach-breast-cancer-histology-images/ICIAR2018_BACH_Challenge/ICIAR2018_BACH_Challenge/Photos"
OUTPUT_DIR = "/kaggle/working/processed_images_train/"  # Directory for saving processed images
os.makedirs(OUTPUT_DIR, exist_ok=True)

num_classes = 4
one_hot = True
img_size = (224, 224)
# Number of images per category held out for evaluation: the first N `.tif`
# files of each category in sorted filename order. 0 disables the hold-out.
images_per_category = 10

Categories = ["Benign", "InSitu", "Invasive", "Normal"]
if one_hot:
    encoding = {'Normal': [1,0,0,0], 'Benign': [0,1,0,0], 'InSitu': [0,0,1,0], 'Invasive': [0,0,0,1]}
else:
    encoding = {'Normal': 0, 'Benign': 1, 'InSitu': 2, 'Invasive': 3}

# Patch extraction: 1400x1400x3 windows every 92 px, each resized to img_size.
PATCH_SIZE = (1400, 1400, 3)
PATCH_STEP = 92
# TIFF compression flag 1 = no compression.
TIFF_WRITE_PARAMS = [cv2.IMWRITE_TIFF_COMPRESSION, 1]


def _read_image_bgr(path):
    """Read an image with OpenCV (BGR order), raising instead of returning None."""
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {path}")
    return img


def _write_tiff(path, img):
    """Write `img` to `path`; cv2.imwrite reports failure only via its return value."""
    if not cv2.imwrite(path, img, TIFF_WRITE_PARAMS):
        raise OSError(f"cv2.imwrite failed for: {path}")


def _augmented_views(img):
    """Return the six fixed views saved per training patch, in filename suffix order _0.._5."""
    return [
        img,                                              # original
        cv2.flip(img, 0),                                 # flip up-down
        cv2.flip(img, 1),                                 # flip left-right
        cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE),         # rotate 90 (clockwise)
        cv2.rotate(img, cv2.ROTATE_180),                  # rotate 180
        cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE),  # rotate 270 (clockwise)
    ]


# Paths/labels of every original image (hold-out images included).
original_images_paths = []
original_images_labels = []

# Paths/labels of the held-out evaluation images.
eval_images_paths = []
eval_images_labels = []

# Paths/labels of the augmented, resized patches written to OUTPUT_DIR.
processed_images_paths = []
processed_images_labels = []

for category in Categories:
    category_dir = os.path.join(MAIN_DIR, category)
    label = encoding[category]
    # os.listdir order is arbitrary; sorting makes the eval hold-out reproducible.
    tif_names = sorted(name for name in os.listdir(category_dir) if name.endswith(".tif"))
    for idx, image_name in enumerate(tif_names):
        image_path = os.path.join(category_dir, image_name)

        original_images_paths.append(image_path)
        original_images_labels.append(label)

        if idx < images_per_category:
            eval_images_paths.append(image_path)
            eval_images_labels.append(label)
            continue

        img = _read_image_bgr(image_path)
        patches = patchify(img, patch_size=PATCH_SIZE, step=PATCH_STEP)
        for i in range(patches.shape[0]):
            for j in range(patches.shape[1]):
                img_resized = cv2.resize(patches[i, j, 0], img_size, interpolation=cv2.INTER_LANCZOS4)
                filenames = [f'{OUTPUT_DIR}{category}_{image_name[:-4]}_{i:02d}_{j:02d}_{k}.tif' for k in range(6)]
                for filename, view in zip(filenames, _augmented_views(img_resized)):
                    _write_tiff(filename, view)
                # extend() instead of `list = list + new`, which copied the whole list each time.
                processed_images_paths.extend(filenames)
                processed_images_labels.extend([label] * len(filenames))

# Converting to NumPy arrays for cross-validation
processed_images_paths = np.array(processed_images_paths)
processed_images_labels = np.array(processed_images_labels)
original_images_paths = np.array(original_images_paths)
original_images_labels = np.array(original_images_labels)
eval_images_paths = np.array(eval_images_paths)
eval_images_labels = np.array(eval_images_labels)
# Display the count of processed and original images and labels
print("Number of processed images:", len(processed_images_paths))
print("Number of processed labels:", len(processed_images_labels))
print("Number of original images:", len(original_images_paths))
print("Number of original labels:", len(original_images_labels))
print("Number of test images:", len(eval_images_paths))
print("Number of test labels:", len(eval_images_labels))

In [None]:
TEST_DIR = "/kaggle/input/bach-breast-cancer-histology-images/ICIAR2018_BACH_Challenge_TestDataset/ICIAR2018_BACH_Challenge_TestDataset/Photos"
OUTPUT_DIR = "/kaggle/working/processed_images_test/"  # Directory for saving processed images
os.makedirs(OUTPUT_DIR, exist_ok=True)

img_size_test = (224, 224)

# Whole test images: paths and the numeric id parsed from 'test<N>.tif'.
test_images_paths = []
test_images_numbers = []

# Resized patches of the test images, each tagged with its source image id.
test_patch_images_paths = []
test_patch_images_numbers = []

# os.listdir order is arbitrary; a fixed order keeps outputs reproducible.
for image_name in sorted(os.listdir(TEST_DIR)):
    if not image_name.endswith(".tif"):
        continue

    match = re.search(r'test(\d+)\.tif', image_name)
    if match is None:
        raise ValueError(f"Unexpected test file name {image_name!r}; expected 'test<N>.tif'")
    image_number = int(match.group(1))  # Extracted number as integer

    image_path = os.path.join(TEST_DIR, image_name)
    test_images_paths.append(image_path)
    test_images_numbers.append(image_number)

    img = cv2.imread(image_path)
    if img is None:  # cv2.imread signals failure by returning None
        raise FileNotFoundError(f"cv2.imread could not read image: {image_path}")
    patches = patchify(img, patch_size=(1400, 1400, 3), step=92)
    for i in range(patches.shape[0]):
        for j in range(patches.shape[1]):
            img_resized = cv2.resize(patches[i, j, 0], img_size_test, interpolation=cv2.INTER_LANCZOS4)
            img_name = f'{OUTPUT_DIR}{image_name[:-4]}_{i:02d}_{j:02d}.tif'
            # Uncompressed TIFF; cv2.imwrite reports failure only via its return value.
            if not cv2.imwrite(img_name, img_resized, [cv2.IMWRITE_TIFF_COMPRESSION, 1]):
                raise OSError(f"cv2.imwrite failed for: {img_name}")
            test_patch_images_paths.append(img_name)
            test_patch_images_numbers.append(image_number)

print("Number of test images:", len(test_images_paths))
print("Number of test ids:", len(test_images_numbers))

print("Number of processed test images:", len(test_patch_images_paths))
print("Number of processed test ids:", len(test_patch_images_numbers))

In [None]:
# (height, width, channels) used by preprocess_image: cv2.resize receives
# (imgShape[1], imgShape[0]) as (width, height), and the loaded tensor's
# static shape is set to imgShape.
imgShape = (224, 224, 3)

In [None]:
def preprocess_image(img_path, label=None, resize=False, normalize=False):
    """Load one image file as a float32 RGB tensor, for use in tf.data `map`.

    Args:
        img_path: scalar string tensor holding the file path.
        label: optional tensor; when given it is cast to float32 and returned
            alongside the image.
        resize: if True, resize to imgShape's (height, width) with cv2.resize.
        normalize: if True, map pixel values from [0, 255] to [-1, 1];
            otherwise values stay in [0, 255] as float32.

    Returns:
        The image tensor (static shape declared as imgShape), or
        (image, float32 label) when `label` is not None. Without resize=True
        the file is assumed to already be imgShape-sized (not checked here).

    Raises:
        FileNotFoundError (from inside tf.py_function) if OpenCV cannot read
        the file.
    """
    def load_and_preprocess_image(img_path):
        # Convert the img_path from tensor to string
        path = img_path.numpy().decode('UTF-8')
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread returns None instead of raising; fail with a clear message
            # rather than a cryptic cvtColor assertion.
            raise FileNotFoundError(f"cv2.imread could not read image: {path}")
        # OpenCV loads BGR; the rest of the pipeline expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if resize:
            img = cv2.resize(img, (imgShape[1], imgShape[0]))

        if normalize:
            img = img / 127.5 - 1.0

        return img.astype(np.float32)

    # tf.py_function lets the OpenCV-based loader run inside a tf.data graph.
    img = tf.py_function(load_and_preprocess_image, [img_path], tf.float32)
    # py_function loses shape information; restore it for downstream layers.
    img.set_shape(imgShape)

    return (img, tf.cast(label, tf.float32)) if label is not None else img

In [None]:
# Let tf.data pick parallelism / prefetch sizes dynamically.
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
def _load_resized(path, target):
    # Decode and resize to imgShape; `target` is a class label or a test-image id.
    return preprocess_image(path, target, resize=True)


# Full-size original images (resized on load) with their labels.
orig_ds = (tf.data.Dataset
           .from_tensor_slices((original_images_paths, original_images_labels))
           .map(_load_resized, num_parallel_calls=AUTOTUNE))

# Held-out evaluation images; only built when a hold-out was requested.
eval_ds = None
if images_per_category:
    eval_ds = (tf.data.Dataset
               .from_tensor_slices((eval_images_paths, eval_images_labels))
               .map(_load_resized, num_parallel_calls=AUTOTUNE))

# Whole test images (resized on load) paired with their numeric ids.
test_ds = (tf.data.Dataset
           .from_tensor_slices((test_images_paths, test_images_numbers))
           .map(_load_resized, num_parallel_calls=AUTOTUNE))

# Pre-resized test patches paired with their source image ids (no resize).
patch_test_ds = (tf.data.Dataset
                 .from_tensor_slices((test_patch_images_paths, test_patch_images_numbers))
                 .map(preprocess_image, num_parallel_calls=AUTOTUNE))

In [None]:
class TrivialAugmentWide(tf.keras.layers.Layer):
    """TrivialAugmentWide data augmentation as a Keras layer.

    For each image, one operation is drawn uniformly from the enabled ops and
    one magnitude is drawn uniformly from that op's `num_magnitude_bins`
    values (negated with probability 1/2 for signed ops). Work is done on
    uint8 pixels in [0, 255]; the result is cast back to the input dtype.
    Accepts a single image [H, W, 3] or a batch [B, H, W, 3].

    Args:
        num_magnitude_bins: number of discrete magnitude levels per op.
        interpolation: stored for reference only; affine ops always use
            bilinear interpolation.
        fill: constant value for pixels uncovered by affine ops (None -> 0).
        exclude_ops: names of ops to disable, e.g. ['Rotate'].
        **kwargs: forwarded to tf.keras.layers.Layer.
    """

    # Branch order used by tf.switch_case in `_apply_op`.
    _OP_ORDER = (
        "ShearX", "ShearY", "TranslateX", "TranslateY", "Rotate",
        "Brightness", "Color", "Contrast", "Sharpness", "Posterize",
        "Solarize", "AutoContrast", "Equalize", "Identity",
    )

    def __init__(self, num_magnitude_bins=31, interpolation='nearest', fill=None, exclude_ops=None, **kwargs):
        super().__init__(**kwargs)
        self.num_magnitude_bins = num_magnitude_bins
        self.interpolation = interpolation
        self.fill = fill
        self.exclude_ops = exclude_ops if exclude_ops else []
        self.ops_space = self._augmentation_space(num_magnitude_bins)
        self.AFFINE_TRANSFORM_INTERPOLATIONS = ("nearest", "bilinear")
        self.AFFINE_TRANSFORM_FILL_MODES = ("constant", "nearest", "wrap", "reflect")

        # Names of the enabled ops, as a string tensor so they can be indexed in-graph.
        self.op_names = tf.constant([name for name in self.ops_space.keys() if name not in self.exclude_ops])

        # Scalar magnitudes (Identity/AutoContrast/Equalize) are broadcast to
        # [num_magnitude_bins] so every op's magnitudes stack into one
        # [num_ops, num_magnitude_bins] float32 tensor.
        def pad_magnitude(magnitude):
            if magnitude.shape.rank == 0:
                return tf.fill([self.num_magnitude_bins], magnitude)
            return magnitude

        self.op_magnitudes = tf.stack([
            pad_magnitude(tf.cast(params[0], tf.float32))
            for name, params in self.ops_space.items() if name not in self.exclude_ops
        ])

        # Per-op flag: whether the sampled magnitude may be negated.
        self.op_signed = tf.constant([params[1] for name, params in self.ops_space.items() if name not in self.exclude_ops])

    def get_config(self):
        """Serialize constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update({
            "num_magnitude_bins": self.num_magnitude_bins,
            "interpolation": self.interpolation,
            "fill": self.fill,
            "exclude_ops": list(self.exclude_ops),
        })
        return config

    def _augmentation_space(self, num_bins):
        """Map op name -> (magnitude values, signed flag)."""
        return {
            "Identity": (tf.constant(0.0), False),
            "ShearX": (tf.linspace(0.0, 0.99, num_bins), True),
            "ShearY": (tf.linspace(0.0, 0.99, num_bins), True),
            "TranslateX": (tf.linspace(0.0, 32.0, num_bins), True),
            "TranslateY": (tf.linspace(0.0, 32.0, num_bins), True),
            "Rotate": (tf.linspace(0.0, 135.0, num_bins), True),
            "Brightness": (tf.linspace(0.0, 0.99, num_bins), True),
            "Color": (tf.linspace(0.0, 0.99, num_bins), True),
            "Contrast": (tf.linspace(0.0, 0.99, num_bins), True),
            "Sharpness": (tf.linspace(0.0, 0.99, num_bins), True),
            "Posterize": (tf.cast(8 - tf.round(tf.cast(tf.range(num_bins), tf.float32) / ((num_bins - 1) / 6)), tf.uint8), False),
            "Solarize": (tf.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (tf.constant(0.0), False),
            "Equalize": (tf.constant(0.0), False),
        }

    def _augment_one(self, img):
        """Apply one randomly chosen op to a single [H, W, 3] image."""
        # Work in uint8 [0, 255]; remember the input dtype to restore it at the end.
        input_image_type = img.dtype
        if img.dtype != tf.uint8:
            img = tf.cast(tf.clip_by_value(img, 0, 255), tf.uint8)

        fill = self.fill if self.fill is not None else 0.0

        # Draw the op uniformly among the enabled ops.
        op_index = tf.random.uniform(shape=(), maxval=tf.shape(self.op_names)[0], dtype=tf.int32)
        op_name = tf.gather(self.op_names, op_index)
        magnitudes = tf.gather(self.op_magnitudes, op_index)
        signed = tf.gather(self.op_signed, op_index)

        # Draw one magnitude bin uniformly.
        magnitude = tf.cond(
            tf.greater(tf.rank(magnitudes), 0),
            lambda: tf.gather(magnitudes, tf.random.uniform(shape=(), maxval=tf.shape(magnitudes)[0], dtype=tf.int32)),
            lambda: tf.constant(0.0, dtype=tf.float32)
        )

        # For signed ops, negate the magnitude with probability 1/2.
        magnitude = tf.cond(
            signed,
            lambda: magnitude * tf.cond(tf.random.uniform(shape=()) > 0.5, lambda: 1.0, lambda: -1.0),
            lambda: magnitude
        )

        img = self._apply_op(img, op_name, magnitude, fill=fill)

        # Cast back to the caller's dtype.
        return tf.cast(img, dtype=input_image_type)

    @tf.function
    def call(self, img):
        """Augment a single image (rank 3) or each image of a batch (rank 4)."""
        if img.shape.rank is not None:
            if img.shape.rank == 4:
                return tf.map_fn(self._augment_one, img)
            else:
                return self._augment_one(img)
        else:
            # Rank unknown at trace time: decide in-graph.
            return tf.cond(
                tf.equal(tf.rank(img), 4),
                lambda: tf.map_fn(self._augment_one, img),
                lambda: self._augment_one(img)
            )

    def _apply_op(self, img, op_name, magnitude, fill):
        """Dispatch `op_name` (string tensor) to its implementation."""
        # Brightness/Color/Contrast/Sharpness use a blend factor of 1 + magnitude.
        branches = {
            "ShearX": lambda: self.shear_x(img, magnitude),
            "ShearY": lambda: self.shear_y(img, magnitude),
            "TranslateX": lambda: self.translate_x(img, magnitude),
            "TranslateY": lambda: self.translate_y(img, magnitude),
            "Rotate": lambda: self.rotate(img, magnitude),
            "Brightness": lambda: self.adjust_brightness(img, magnitude + 1.0),
            "Color": lambda: self.adjust_saturation(img, magnitude + 1.0),
            "Contrast": lambda: self.adjust_contrast(img, magnitude + 1.0),
            "Sharpness": lambda: self.adjust_sharpness(img, magnitude + 1.0),
            "Posterize": lambda: self.posterize(img, magnitude),
            "Solarize": lambda: self.solarize(img, magnitude),
            "AutoContrast": lambda: self.autocontrast(img),
            "Equalize": lambda: self.equalize(img),
            "Identity": lambda: img,
        }

        matches = tf.equal(op_name, tf.constant(self._OP_ORDER))
        # argmax over an all-False vector is 0 (ShearX); route unknown names to
        # an out-of-range index instead so tf.switch_case runs `default` (identity).
        branch_index = tf.where(
            tf.reduce_any(matches),
            tf.cast(tf.argmax(tf.cast(matches, tf.int32)), tf.int32),
            tf.constant(-1, dtype=tf.int32),
        )
        return tf.switch_case(
            branch_index=branch_index,
            branch_fns={i: branches[name] for i, name in enumerate(self._OP_ORDER)},
            default=lambda: img,
        )

    def blend(self, image1, image2, factor):
        """Return uint8 image1 + factor * (image2 - image1), clipped to [0, 255]."""
        image1 = tf.cast(image1, tf.float32)
        image2 = tf.cast(image2, tf.float32)
        factor = tf.cast(factor, tf.float32)

        temp = image1 + factor * (image2 - image1)
        return tf.cast(tf.clip_by_value(temp, 0.0, 255.0), tf.uint8)

    def _affine_transform(self, img, transform_matrix):
        """Apply a 2x3 affine transform given as [a0, a1, a2, b0, b1, b2].

        The matrix is passed to keras.ops.image.affine_transform as its
        8-parameter transform with c0 = c1 = 0. Uses bilinear interpolation and
        constant fill (self.fill, default 0). Returns uint8.
        """
        a0, a1, a2 = transform_matrix[0], transform_matrix[1], transform_matrix[2]
        b0, b1, b2 = transform_matrix[3], transform_matrix[4], transform_matrix[5]

        transforms = tf.stack([a0, a1, a2, b0, b1, b2, 0.0, 0.0])
        transforms = tf.reshape(transforms, [1, 8])  # Shape [1, 8]

        # The Keras op expects a float batch.
        img_batched = tf.expand_dims(tf.cast(img, tf.float32), axis=0)

        transformed = tf.keras.ops.image.affine_transform(
            img_batched,
            transforms,
            interpolation='bilinear',
            fill_mode='constant',
            fill_value=self.fill if self.fill is not None else 0.0
        )
        transformed = tf.squeeze(transformed, axis=0)
        return tf.cast(tf.clip_by_value(transformed, 0, 255), tf.uint8)

    def shear_x(self, img, magnitude):
        transform_matrix = tf.stack([
            1.0, magnitude, 0.0,  # a0, a1, a2
            0.0, 1.0, 0.0         # b0, b1, b2
        ])
        return self._affine_transform(img, transform_matrix)

    def shear_y(self, img, magnitude):
        transform_matrix = tf.stack([
            1.0, 0.0, 0.0,        # a0, a1, a2
            magnitude, 1.0, 0.0   # b0, b1, b2
        ])
        return self._affine_transform(img, transform_matrix)

    def translate_x(self, img, magnitude):
        transform_matrix = tf.stack([
            1.0, 0.0, magnitude,  # a0, a1, a2
            0.0, 1.0, 0.0         # b0, b1, b2
        ])
        return self._affine_transform(img, transform_matrix)

    def translate_y(self, img, magnitude):
        transform_matrix = tf.stack([
            1.0, 0.0, 0.0,        # a0, a1, a2
            0.0, 1.0, magnitude   # b0, b1, b2
        ])
        return self._affine_transform(img, transform_matrix)

    def rotate(self, img, angle):
        """Rotate about the image centre; `angle` is in degrees."""
        angle_rad = -math.pi * angle / 180.0
        cos_a = tf.cos(angle_rad)
        sin_a = tf.sin(angle_rad)

        h, w = tf.shape(img)[0], tf.shape(img)[1]
        cx, cy = tf.cast(w, tf.float32) / 2.0, tf.cast(h, tf.float32) / 2.0

        transform_matrix = tf.stack([
            cos_a, -sin_a, (1 - cos_a) * cx + sin_a * cy,  # a0, a1, a2
            sin_a, cos_a, -sin_a * cx + (1 - cos_a) * cy   # b0, b1, b2
        ])
        return self._affine_transform(img, transform_matrix)

    def adjust_brightness(self, img, brightness_factor):
        """Blend towards black: factor 0 -> black, 1 -> unchanged."""
        degenerate = tf.zeros_like(img)
        return self.blend(degenerate, img, brightness_factor)

    def adjust_contrast(self, img, contrast_factor):
        """Blend towards the image's mean grey level."""
        # Take the mean in float32: reduce_mean on the uint8 grayscale image
        # returns an integer-truncated uint8 mean (torchvision's
        # adjust_contrast also averages in float).
        gray = tf.cast(tf.image.rgb_to_grayscale(img), tf.float32)
        degenerate = tf.math.reduce_mean(gray, axis=[-3, -2, -1], keepdims=True)
        return self.blend(degenerate, img, contrast_factor)

    def adjust_saturation(self, img, saturation_factor):
        """Blend towards the grayscale version of the image."""
        degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(img))
        return self.blend(degenerate, img, saturation_factor)

    def posterize(self, img, bits):
        """Keep only the top `bits` bits of each channel."""
        bits = tf.cast(bits, img.dtype)
        shift = 8 - bits
        return tf.bitwise.left_shift(tf.bitwise.right_shift(img, shift), shift)

    def solarize(self, img, threshold):
        """Invert every pixel value >= threshold."""
        threshold = tf.cast(threshold, img.dtype)
        return tf.where(img >= threshold, 255 - img, img)

    def autocontrast(self, img):
        """Stretch each channel independently to span [0, 255]."""
        def scale_channel(image: tf.Tensor) -> tf.Tensor:
            """Scale one 2-D channel so its min maps to 0 and its max to 255."""
            lo = tf.cast(tf.reduce_min(image), tf.float32)
            hi = tf.cast(tf.reduce_max(image), tf.float32)

            def scale_values(im):
                scale = 255.0 / (hi - lo)
                offset = -lo * scale
                im = tf.cast(im, tf.float32) * scale + offset
                im = tf.clip_by_value(im, 0.0, 255.0)
                return tf.cast(im, tf.uint8)

            # A flat channel (hi == lo) is returned unchanged.
            return tf.cond(hi > lo, lambda: scale_values(image), lambda: image)

        # Assumes RGB: scale each channel independently, then restack.
        s1 = scale_channel(img[..., 0])
        s2 = scale_channel(img[..., 1])
        s3 = scale_channel(img[..., 2])
        return tf.stack([s1, s2, s3], -1)

    def equalize(self, image):
        """Implements Equalize function from PIL using TF ops."""

        def scale_channel(im, c):
            """Histogram-equalize channel `c` of `im`."""
            im = tf.cast(im[..., c], tf.int32)
            histo = tf.histogram_fixed_width(im, [0, 255], nbins=256)

            # For the purposes of computing the step, filter out the nonzeros.
            nonzero = tf.where(tf.not_equal(histo, 0))
            nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1])
            step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255

            def build_lut(histo, step):
                # Cumulative sum shifted by step // 2, then normalised by step.
                lut = (tf.cumsum(histo) + (step // 2)) // step
                # Shift lut, prepending with 0.
                lut = tf.concat([[0], lut[:-1]], 0)
                # Clip the counts to be in range (as PIL's C code does).
                return tf.clip_by_value(lut, 0, 255)

            # If step is zero, return the original channel; otherwise index the lut.
            result = tf.cond(
                tf.equal(step, 0), lambda: im,
                lambda: tf.gather(build_lut(histo, step), im))

            return tf.cast(result, tf.uint8)

        # Assumes RGB: equalize each channel independently, then restack.
        s1 = scale_channel(image, 0)
        s2 = scale_channel(image, 1)
        s3 = scale_channel(image, 2)
        return tf.stack([s1, s2, s3], -1)

    def adjust_sharpness(self, image, factor):
        """Blend between a PIL-style SMOOTH-filtered image and the original."""
        orig_image = image
        image = tf.cast(image, tf.float32)
        # Make image 4D for conv operation.
        image = tf.expand_dims(image, 0)
        # SMOOTH PIL kernel.
        kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
                             dtype=tf.float32,
                             shape=[3, 3, 1, 1]) / 13.
        # Tile across channel dimension.
        kernel = tf.tile(kernel, [1, 1, 3, 1])
        strides = [1, 1, 1, 1]
        degenerate = tf.nn.depthwise_conv2d(
            image, kernel, strides, padding='VALID', dilations=[1, 1])
        degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
        degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])

        # VALID convolution shrinks the image by one pixel per side; keep the
        # original pixels on that border.
        mask = tf.ones_like(degenerate)
        paddings = [[0, 0]] * (orig_image.shape.rank - 3)
        padded_mask = tf.pad(mask, paddings + [[1, 1], [1, 1], [0, 0]])
        padded_degenerate = tf.pad(degenerate, paddings + [[1, 1], [1, 1], [0, 0]])
        result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)

        return self.blend(result, orig_image, factor)

In [None]:
# Goals for the training input pipeline:
#   - well shuffled,
#   - batched,
#   - batches available as soon as possible.
if tpu is None:
    batch_size = 8
else:
    # Scale the global batch with the number of TPU replicas.
    batch_size_per_replica = 64
    batch_size = batch_size_per_replica * strategy.num_replicas_in_sync

# Augmentation policy for the training pipeline; rotation is disabled.
TAW = TrivialAugmentWide(exclude_ops=['Rotate'])

def configure_for_performance(ds, shuffle=False, augment=False, drop_remainder=True, shuffle_buffer=10000):
    """Cache, optionally shuffle, batch, optionally augment, and prefetch `ds`.

    Args:
        ds: dataset of (image, label) elements.
        shuffle: reshuffle element order every epoch.
        augment: apply the TrivialAugmentWide layer `TAW` to each batch.
        drop_remainder: drop the final partial batch.
        shuffle_buffer: shuffle buffer size, in elements.

    Returns:
        The configured dataset, batched with the global `batch_size`.
    """
    # cache() before shuffle(): a cache placed after shuffle() stores the first
    # epoch's order and replays it, so later epochs would never be reshuffled.
    ds = ds.cache()
    if shuffle:
        ds = ds.shuffle(shuffle_buffer)
    ds = ds.batch(batch_size, drop_remainder=drop_remainder)

    if augment:
        # Augmentation runs after cache(), so each epoch gets fresh random ops.
        @tf.function
        def augment_fn(imgs, labels):
            return TAW(imgs), labels
        ds = ds.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)

    return ds.prefetch(buffer_size=tf.data.AUTOTUNE)  # overlap input prep with training

In [None]:
# Batch the non-training datasets, keeping each final partial batch
# (drop_remainder=False) so no image is dropped.
orig_ds, test_ds, patch_test_ds = (
    ds.batch(batch_size, drop_remainder=False)
    for ds in (orig_ds, test_ds, patch_test_ds)
)

if images_per_category:
    eval_ds = eval_ds.batch(batch_size, drop_remainder=False)

In [None]:
from tensorflow.keras import ops

EPOCHS = 70
# Steps derived from an 80/20 split of the processed patches.
steps_per_epoch = int((len(processed_images_paths)*0.8) // batch_size)
validation_steps = int((len(processed_images_paths)*0.2) / batch_size)
total_steps = steps_per_epoch * EPOCHS
warmup_steps = int(0.1 * total_steps)
hold_steps = int(0.45 * total_steps)


def lr_warmup_cosine_decay(global_step, warmup_steps, hold=0, total_steps=0, start_lr=0.0, target_lr=1e-2):
    """Linear warm-up to `target_lr`, optional hold, then cosine decay to 0.

    Args:
        global_step: current optimizer step (int or tensor).
        warmup_steps: steps of linear warm-up from 0 to target_lr.
        hold: steps to stay at target_lr after warm-up.
        total_steps: step at which the cosine reaches 0.
        start_lr: NOTE(review): currently unused; warm-up always starts at 0.
        target_lr: peak learning rate.

    Returns:
        float32 learning-rate tensor.
    """
    step = ops.cast(global_step, "float32")
    # Guard the denominators so degenerate configs (e.g. tiny datasets giving
    # 0 warm-up or 0 decay steps) do not produce inf/NaN.
    warmup_denominator = float(max(warmup_steps, 1))
    decay_denominator = float(max(total_steps - warmup_steps - hold, 1))

    # Cosine decay
    learning_rate = 0.5 * target_lr * (1 + ops.cos(math.pi * (step - warmup_steps - hold) / decay_denominator))

    warmup_lr = target_lr * (step / warmup_denominator)

    if hold > 0:
        learning_rate = ops.where(step > warmup_steps + hold, learning_rate, target_lr)

    learning_rate = ops.where(step < warmup_steps, warmup_lr, learning_rate)
    return learning_rate


class WarmUpCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Keras schedule wrapping `lr_warmup_cosine_decay`; returns 0 after `total_steps`."""

    def __init__(self, warmup_steps, total_steps, hold, start_lr=0.0, target_lr=1e-2):
        super().__init__()
        self.start_lr = start_lr
        self.target_lr = target_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.hold = hold

    def __call__(self, step):
        lr = lr_warmup_cosine_decay(
            global_step=step,
            total_steps=self.total_steps,
            warmup_steps=self.warmup_steps,
            start_lr=self.start_lr,
            target_lr=self.target_lr,
            hold=self.hold,
        )
        return ops.where(step > self.total_steps, 0.0, lr)

    def get_config(self):
        # Required for serializing an optimizer that uses this schedule; the
        # LearningRateSchedule base class does not implement it.
        return {
            "warmup_steps": self.warmup_steps,
            "total_steps": self.total_steps,
            "hold": self.hold,
            "start_lr": self.start_lr,
            "target_lr": self.target_lr,
        }


schedule = WarmUpCosineDecay(
    start_lr=0.001,
    target_lr=0.0001,
    warmup_steps=warmup_steps,
    total_steps=total_steps,
    hold=hold_steps,
)

In [None]:
def new_tta(images):
    """
    Build the test-time-augmentation (TTA) views of a batch of images.

    Args:
        images (tf.Tensor): batch in channels-last format, shape [B, H, W, C].

    Returns:
        list of tf.Tensor: seven batches, in this order:
            [original, original, flip left-right, flip up-down,
             rotate 90, rotate 180, rotate 270] (tf.image.rot90 with k=1, 2, 3).
        Every view keeps the input shape when H == W; the 90/270 rotations swap
        H and W otherwise.

    NOTE(review): the original batch appears twice, so a plain mean over the
    returned list weights it 2/7. Confirm this weighting is intended.
    """
    return [
        images,
        images,
        tf.image.flip_left_right(images),
        tf.image.flip_up_down(images),
        tf.image.rot90(images, k=1),
        tf.image.rot90(images, k=2),
        tf.image.rot90(images, k=3),
    ]

In [None]:
class GeM(tf.keras.layers.Layer):
    """
    Generalized Mean (GeM) pooling over the spatial axes.

    Computes (mean over H, W of max(x, eps) ** p) ** (1 / p) per channel.
    Input is channels-last (batch, height, width, channels); output keeps the
    pooled axes: (batch, 1, 1, channels).
    """
    def __init__(self, p=3.0, eps=1e-6, **kwargs):
        """
        Args:
            p (float): The power to raise the inputs.
            eps (float): Lower bound applied to the inputs before the power.
        """
        super(GeM, self).__init__(**kwargs)
        self.p = p
        self.eps = eps

    def call(self, inputs):
        # Lower-bound only. The previous clip_by_value(x, eps, max(x)) had
        # min > max whenever every activation was below eps (e.g. all zeros).
        x = tf.maximum(inputs, self.eps)
        x = tf.pow(x, self.p)
        x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        return tf.pow(x, 1.0 / self.p)

    def get_config(self):
        config = super().get_config()
        config.update({"p": self.p, "eps": self.eps})
        return config

In [None]:
class ArcMarginProduct(Layer):
    """Cosine-similarity head for ArcFace-style training.

    Outputs cos(theta) between each L2-normalised input row and each
    L2-normalised weight column, i.e. a (batch, units) matrix. The margin
    itself is applied in the loss.

    Args:
        units: number of output columns (classes).
    """

    def __init__(self, units=1000, **kwargs):
        super(ArcMarginProduct, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # One weight column per class.
        self.w = self.add_weight(
            name="norm_dense_w",
            shape=(input_shape[-1], self.units),
            initializer=tf.keras.initializers.GlorotUniform(),
            trainable=True
        )

    def call(self, inputs, **kwargs):
        norm_w = tf.nn.l2_normalize(self.w, axis=0, epsilon=1e-5)  # each column is a weight vector
        norm_embedding = tf.nn.l2_normalize(inputs, axis=1, epsilon=1e-5)
        return tf.linalg.matmul(norm_embedding, norm_w, name='cos_theta')

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.units)

    def get_config(self):
        config = super().get_config()
        config["units"] = self.units
        return config

In [None]:
class ArcfaceLoss(tf.keras.losses.Loss):
    """Combined-margin (ArcFace-style) softmax loss on cosine logits.

    For the target class, the logit cos(theta) is replaced by
    cos(margin1 * theta + margin2) - margin3. All logits are multiplied by
    `scale` and passed to a from-logits cross-entropy.

    Args:
        margin1: multiplicative angular margin.
        margin2: additive angular margin (radians).
        margin3: additive cosine margin.
        scale: logit scale.
        smoothing: label smoothing (one-hot mode only).
        one_hot: True if y_true is one-hot; False if it holds integer class ids.
    """

    def __init__(self, margin1=1.0, margin2=0.5, margin3=0.0, scale=64.0, smoothing=0.0, one_hot=True, **kwargs):
        super(ArcfaceLoss, self).__init__(**kwargs)
        self.margin1, self.margin2, self.margin3, self.scale = margin1, margin2, margin3, scale
        self.smoothing = smoothing
        self.one_hot = one_hot
        self.gamma = 0.2  # weight for a disabled auxiliary loss term; currently unused
        if one_hot:
            self.loss = tf.keras.losses.CategoricalCrossentropy(reduction="sum_over_batch_size", from_logits=True, label_smoothing=smoothing)
        else:
            self.loss = tf.keras.losses.SparseCategoricalCrossentropy(reduction="sum_over_batch_size", from_logits=True)

    def call(self, y_true, cos_theta):
        eps = tf.keras.backend.epsilon()
        # Keep acos away from +/-1, where its gradient is infinite.
        theta = tf.math.acos(tf.clip_by_value(cos_theta, -1.0 + eps, 1.0 - eps))
        target_logits = tf.cos(theta * self.margin1 + self.margin2) - self.margin3

        # The margin must be applied through a one-hot mask. Integer labels were
        # previously used directly as the mask, which mis-broadcasts against
        # the (batch, classes) cosine matrix.
        if self.one_hot:
            target_mask = tf.cast(y_true, cos_theta.dtype)
        else:
            class_ids = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
            target_mask = tf.one_hot(class_ids, depth=tf.shape(cos_theta)[-1], dtype=cos_theta.dtype)

        logits = (cos_theta * (1.0 - target_mask) + target_logits * target_mask) * self.scale
        return self.loss(y_true, logits)

    def get_config(self):
        config = super().get_config()
        config.update({
            "margin1": self.margin1,
            "margin2": self.margin2,
            "margin3": self.margin3,
            "scale": self.scale,
            "smoothing": self.smoothing,
            "one_hot": self.one_hot,
        })
        return config

In [None]:
# Model Definition
class Model(tf.keras.Model):
    """Pretrained CNN backbone with a GeM / embedding / ArcMarginProduct head.

    Forward pass: optional resize -> backbone-specific ``preprocess_input`` ->
    backbone (ImageNet weights, no top) -> GeM pooling -> flatten ->
    bias-free Dense(embedding_size) -> BatchNormalization -> ArcMarginProduct.
    The output is whatever ``ArcMarginProduct`` returns (cosine scores, one
    column per class); pair it with ``ArcfaceLoss`` for training.

    NOTE: this name shadows ``keras.models.Model`` imported in the notebook's
    import cell; it is kept because callers construct ``Model(...)`` directly.

    Args:
        backbone_name: one of 'densenet', 'resnet', 'efficientnet',
            'efficientnetv2', 'xception', 'inc_res'.
        num_classes: number of output classes.
        embedding_size: width of the embedding layer.
        **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name=``).

    Raises:
        ValueError: if ``backbone_name`` is not one of the names above.
    """

    def __init__(self, backbone_name, num_classes, embedding_size=512, **kwargs):
        super(Model, self).__init__(**kwargs)
        self.backbone_name = backbone_name
        self.num_classes = num_classes
        self.embedding_size = embedding_size

        self.preprocess_fn = None
        # Set to e.g. tf.keras.layers.Resizing(224, 224) to resize inside the model.
        self.resize_layer = None
        # self.cust_norm = CustomNormalization(mean=[0.7165332113314071, 0.6166415441176472, 0.8421312542595161], std_dev=[0.16585887053987372, 0.18836123162530533, 0.11803739988823378])
        # Builds the backbone and, as a side effect, sets self.preprocess_fn.
        self.backbone = self._initialize_extractor()

        self.pool = GeM()  # Default: p=3, eps=1e-6
        self.flatten = Flatten()  # flatten the GeM output
        self.fc = Dense(embedding_size, use_bias=False)
        self.bn = BatchNormalization()
        self.arcMargin = ArcMarginProduct(self.num_classes)

    def _initialize_extractor(self):
        """Build the ImageNet-pretrained backbone named by ``self.backbone_name``.

        Side effect: sets ``self.preprocess_fn`` to the preprocessing function
        that matches the chosen backbone.

        Returns:
            The backbone model (``include_top=False``) with every layer trainable.

        Raises:
            ValueError: if ``self.backbone_name`` is not recognised.
        """
        if self.backbone_name == 'densenet':
            self.preprocess_fn = tf.keras.applications.densenet.preprocess_input
            model = tf.keras.applications.DenseNet169(include_top=False, weights="imagenet")
        elif self.backbone_name == 'resnet':
            # ResNet50V2 needs resnet_v2.preprocess_input (scales pixels to [-1, 1]);
            # resnet50.preprocess_input is the caffe-style BGR/mean-subtraction
            # scheme for ResNet v1 and does not match these weights.
            self.preprocess_fn = tf.keras.applications.resnet_v2.preprocess_input
            model = tf.keras.applications.ResNet50V2(include_top=False, weights="imagenet")
        elif self.backbone_name == 'efficientnet':
            self.preprocess_fn = tf.keras.applications.efficientnet.preprocess_input
            model = tf.keras.applications.EfficientNetB2(include_top=False, weights="imagenet")
        elif self.backbone_name == 'efficientnetv2':
            self.preprocess_fn = tf.keras.applications.efficientnet_v2.preprocess_input
            model = tf.keras.applications.EfficientNetV2B1(include_top=False, weights="imagenet")
        elif self.backbone_name == 'xception':
            self.preprocess_fn = tf.keras.applications.xception.preprocess_input
            model = tf.keras.applications.Xception(include_top=False, weights="imagenet")
        elif self.backbone_name == 'inc_res':
            self.preprocess_fn = tf.keras.applications.inception_resnet_v2.preprocess_input
            model = tf.keras.applications.InceptionResNetV2(include_top=False, weights="imagenet")
        else:
            raise ValueError('Invalid backbone specified')

        # Unfreeze every backbone layer so the whole network is fine-tuned.
        for layer in model.layers:
            layer.trainable = True

        return model

    def call(self, inputs, training=False):
        """Forward pass through the model.

        Args:
            inputs: batch of images in the value range expected by the
                backbone's ``preprocess_input`` (presumably raw 0-255 pixels;
                TODO confirm against the data pipeline).
            training: whether the backbone and BatchNormalization run in
                training mode.

        Returns:
            The ``ArcMarginProduct`` output for the batch.
        """
        # Optionally resize inside the model.
        if self.resize_layer is not None:
            inputs = self.resize_layer(inputs)
        # Backbone-specific normalisation.
        inputs = self.preprocess_fn(inputs)
        # inputs = self.cust_norm(inputs)
        x = self.backbone(inputs, training=training)
        x = self.pool(x)
        x = self.flatten(x)
        # Embedding head: bias-free projection followed by batch normalisation.
        x = self.fc(x)
        x = self.bn(x, training=training)
        # Per-class cosine scores.
        x = self.arcMargin(x)
        return x

In [None]:
# K-fold cross-validation: shuffled folds, seeded with SEED so the splits are reproducible
n_splits = 5  # number of folds
kf = KFold(
    n_splits=n_splits,
    shuffle=True,
    random_state=SEED,
)

In [None]:
# Per-fold loop: build train/val datasets, train a fresh model, evaluate it,
# predict on the test set (with and without TTA), and write per-fold CSVs.
for fold_no, (train_index, val_index) in enumerate(kf.split(processed_images_paths, np.argmax(processed_images_labels, axis=1) if one_hot else processed_images_labels)): # NOTE(review): KFold ignores y, so these integer labels only matter if kf is switched to StratifiedKFold
    print(f'Training on fold {fold_no}...')
    
    # Split the data into training and validation.
    # Assumes processed_images_paths / processed_images_labels are numpy arrays (fancy indexing) -- TODO confirm.
    train_paths, val_paths = processed_images_paths[train_index], processed_images_paths[val_index]
    train_labels, val_labels = processed_images_labels[train_index], processed_images_labels[val_index]

    # Create tf.data.Dataset objects of (path, label) pairs
    train_ds = tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
    val_ds = tf.data.Dataset.from_tensor_slices((val_paths, val_labels))

    # Load and preprocess each image from its path
    train_ds = train_ds.map(lambda img, label: preprocess_image(img, label, resize=True), num_parallel_calls=AUTOTUNE)
    val_ds = val_ds.map(lambda img, label: preprocess_image(img, label, resize=True), num_parallel_calls=AUTOTUNE)

    # Only the training split is augmented and shuffled; drop_remainder=True on the
    # training split presumably keeps a fixed batch size (e.g. for TPU) -- confirm.
    train_ds = configure_for_performance(train_ds, augment=True, shuffle=True, drop_remainder=True)
    val_ds = configure_for_performance(val_ds, drop_remainder=False)
    # Create a new model for each fold, inside strategy.scope() so its variables
    # are created under the active distribution strategy
    with strategy.scope():
        model = Model(backbone_name='xception', num_classes=num_classes, embedding_size=512)
        # The accuracy metric is computed on the model's cosine scores (no margin applied).
        model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=schedule, weight_decay=5e-4),
                      loss=ArcfaceLoss(one_hot=one_hot, smoothing=0.2),
                      metrics=[tf.keras.metrics.CategoricalAccuracy('acc') if one_hot else tf.keras.metrics.SparseCategoricalAccuracy('acc')])
    
    history = model.fit(train_ds, validation_data=val_ds, steps_per_epoch=steps_per_epoch, validation_steps=validation_steps, epochs=EPOCHS, verbose=1)
    
    # NOTE(review): orig_ds presumably holds the full original image set, which would
    # include this fold's training images; if so, these scores are optimistic -- confirm.
    scores_orig = model.evaluate(orig_ds, verbose=0, return_dict=True)
    print(f'Score for fold {fold_no} for all data: {scores_orig}')

    if images_per_category:
        scores_eval = model.evaluate(eval_ds, verbose=0, return_dict=True)
        print(f'Score for fold {fold_no} for evaluation data: {scores_eval}')

    # Test phase
    predictions = []
    predictions_tta = []
    indices = []
    
    # test_ds is iterated as (images, case indices) batches.
    for batch_images, batch_indices in test_ds:
        # Without TTA. NOTE(review): no verbose=0 here (unlike the TTA calls below),
        # so predict prints a progress bar for every batch.
        batch_preds = model.predict(batch_images)
        # Predicted class = argmax over the per-class scores.
        predictions = predictions + np.argmax(batch_preds, axis=1).tolist()
        
        tta_images_list = new_tta(batch_images)  # List of tensors with shape [B, H, W, C]
        tta_preds = []
        # Loop over each augmented image batch
        for aug_images in tta_images_list:
            preds = model.predict(aug_images, verbose=0)  # Each `preds` is of shape [B, num_classes]
            tta_preds.append(preds)
        # Average the scores over all augmented views, then take the argmax.
        tta_preds = np.mean(tta_preds, axis=0)
        predictions_tta = predictions_tta + np.argmax(tta_preds, axis=1).tolist()
        
        indices.extend(batch_indices.numpy().astype('int32'))  # Collect indices to restore order later
    # Shift 0-based class indices to 1-based labels for the CSVs.
    predictions = [ x+1 for x in predictions ]
    predictions_tta = [ x+1 for x in predictions_tta ]
    
    # Patch-based test predictions: plain, TTA, and their case indices.
    predictions_patches, predictions_patches_tta, indices_patches = patches_test(model, patch_test_ds)
    # Convert predictions to DataFrame and save as CSV (one file per fold)
    df = pd.DataFrame({'case': indices, 'class': predictions})
    df.to_csv(f'predictions{fold_no}.csv', index=False, sep=',')
    # Convert predictions to DataFrame and save as CSV with TTA
    df = pd.DataFrame({'case': indices, 'class': predictions_tta})
    df.to_csv(f'predictions{fold_no}_TTA.csv', index=False, sep=',')