In [1]:
import os
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from keras import layers
from tensorflow.keras.applications import ResNet101V2
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import numpy as np
from tqdm import tqdm
import cv2
from keras.models import load_model
import keras
from keras import backend as K
import matplotlib.pyplot as plt
from skimage.morphology import skeletonize

import pickle
import datetime

2024-06-18 06:19:50.776669: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-18 06:19:50.807554: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-18 06:19:50.807575: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-18 06:19:50.808399: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-18 06:19:50.813455: I tensorflow/core/platform/cpu_feature_guar

In [2]:
# Side length (pixels) of the square window placed around each sampled
# centerline point; the vectorization functions use WINDOW_SIZE // 2 as the
# half-window.
WINDOW_SIZE = 20
# Extent (pixels) scanned by the vectorization functions' grid loops.
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512
# Number of (x, y, pixel_sum) rows allocated per image by the vectorization functions.
MAX_VECTORS = 40

# Preprocess data

In [3]:
def iou_score(y_pred, y_true, smooth=1):
    """Smoothed intersection-over-union reduced along the last axis.

    Note the parameter order is ``(y_pred, y_true)``; the expression below is
    symmetric in the two tensors, so the order does not affect the result.

    Args:
        y_pred: Prediction tensor.
        y_true: Ground-truth tensor, multiplied element-wise with ``y_pred``.
        smooth: Constant added to numerator and denominator.

    Returns:
        Tensor of ``(intersection + smooth) / (union + smooth)`` with the last
        axis reduced away.
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    combined = K.sum(y_true, axis=-1) + K.sum(y_pred, axis=-1)
    return (overlap + smooth) / (combined - overlap + smooth)

def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient over all elements of both tensors.

    Both inputs are flattened, so the result is a single scalar tensor:
    ``(2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)``.
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    numerator = 2. * K.sum(truth * guess) + smooth
    denominator = K.sum(truth) + K.sum(guess) + smooth
    return numerator / denominator


def dice_coef_loss(y_true, y_pred):
    """Dice loss: one minus the smoothed Dice coefficient of the two tensors."""
    score = dice_coef(y_true, y_pred)
    return 1 - score
# Names Keras has to resolve while deserializing the saved models below.
custom_objects = {"dice_coef_loss": dice_coef_loss, 'iou_score': iou_score}

# `model` is the global that get_vessel() and remove_catheter() pass to
# predict_one_image().
with keras.saving.custom_object_scope(custom_objects):
    model = load_model("models/r2u_attention_80e.h5")
    
# `catheter_model` is the second network used by remove_catheter().
with keras.saving.custom_object_scope(custom_objects):
    catheter_model = load_model("models/catheter_detect.h5")

2024-06-18 06:19:55.479525: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 22502 MB memory:  -> device: 0, name: NVIDIA RTX A5000, pci bus id: 0000:19:00.0, compute capability: 8.6
2024-06-18 06:19:55.479947: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 22502 MB memory:  -> device: 1, name: NVIDIA RTX A5000, pci bus id: 0000:8d:00.0, compute capability: 8.6


In [4]:
def match(img, pred_img):
    """Return the element-wise product of ``pred_img`` and ``img``."""
    return pred_img * img

# function to average color
def average(img, i, j, r):
  """Mean value of the (2r+1) x (2r+1) window centred on ``img[i][j]``.

  Pixels of the window that fall outside the image are treated as zero, and
  the divisor is always ``(2r+1)**2``. The previous element-by-element loop
  let negative indices wrap around to the opposite edge of the image and
  raised IndexError past the far edge.

  Args:
    img: 2-D (or H x W x C) array-like of pixel values.
    i: Row index of the window centre.
    j: Column index of the window centre.
    r: Window radius in pixels.

  Returns:
    The window sum divided by ``(2r+1)**2`` (per channel for H x W x C input).
  """
  pixels = np.asarray(img)
  # Clamp the window to the image so negative starts cannot wrap around.
  top, bottom = max(i - r, 0), min(i + r + 1, pixels.shape[0])
  left, right = max(j - r, 0), min(j + r + 1, pixels.shape[1])
  window = pixels[top:bottom, left:right]
  # ndarray.sum widens small integer dtypes, so uint8 input cannot overflow.
  return window.sum(axis=(0, 1)) / ((2 * r + 1) ** 2)

# function to match the segmentation back to the origin image, but we just get a square area
def square_match(img, pred_img, target):
  """Restrict ``img`` to the bounding square of the segmentation ``pred_img``.

  The mask is scanned over rows/columns 0..509. A position counts as a hit
  when ``average(pred_img, i, j, 2) > 0.7``. The first hit from each side
  gives the top/right/bottom/left boundary, widened by a 10-pixel margin and
  clamped to [0, 512].

  Args:
    img: Image to mask or crop (modified in place when ``target == "full"``).
    pred_img: Segmentation mask that is scanned for hits.
    target: ``"full"`` zeroes every pixel of ``img`` outside the box and
      returns ``img``; ``"crop"`` returns the boxed slice of ``img``, or ``0``
      when no hit was found on any side. Any other value returns ``img``
      untouched.
  """
  def _first_hit(outer_positions, outer_is_row):
    # First outer position (row or column) that contains a hit, else None.
    for outer in outer_positions:
      for inner in range(0, 510):
        i, j = (outer, inner) if outer_is_row else (inner, outer)
        if average(pred_img, i, j, 2) > 0.7:
          return outer
    return None

  # Same scan order as before: top, right, bottom, left.
  hit = _first_hit(range(0, 510), True)
  top_index = 512 if hit is None else max(0, hit - 10)

  hit = _first_hit(reversed(range(0, 510)), False)
  right_index = 0 if hit is None else min(hit + 10, 512)

  hit = _first_hit(reversed(range(0, 510)), True)
  bottom_index = 0 if hit is None else min(512, hit + 10)

  hit = _first_hit(range(0, 510), False)
  left_index = 512 if hit is None else max(0, hit - 10)

  if target == "full":
    for i in range(512):
      for j in range(512):
        inside = top_index <= i <= bottom_index and left_index <= j <= right_index
        if not inside:
          img[i][j] = 0

  if target == "crop":
    # All four boundaries still at their defaults means nothing was found.
    if (left_index, top_index, right_index, bottom_index) == (512, 512, 0, 0):
      return 0
    return img[top_index:bottom_index, left_index:right_index]

  return img

def predict_one_image(img, model):
  """Run ``model`` on one single-channel image and binarize the output.

  The image is resized to 512 x 512, scaled by 1/255, reshaped to
  (1, 512, 512, 1) and then reordered to (1, 1, 512, 512) before prediction.
  Outputs above 0.5 become 1, everything else 0.

  Args:
    img: Single-channel image array accepted by ``cv2.resize``.
    model: Object exposing ``predict(batch, verbose=0)``.

  Returns:
    ``(pred_img, match_img)``: the 512 x 512 mask with values 0/255, and that
    mask multiplied element-wise with the resized input image.
  """
  side = 512
  resized_img = cv2.resize(img, (side, side))
  batch = np.reshape(resized_img, (1, resized_img.shape[0], resized_img.shape[1], 1)) / 255
  # Move the trailing channel axis to position 1.
  batch = np.moveaxis(batch, 3, 1)
  scores = model.predict(batch, verbose=0)
  # Hard threshold at 0.5, keeping the model's output dtype.
  binary = (scores > 0.5).astype(scores.dtype)
  pred_img = np.reshape(binary[0] * 255, (side, side))
  return pred_img, pred_img * resized_img

def remove_catheter(image):
    """Segment the vessel in ``image`` and drop the catheter prediction from it.

    The catheter mask is subtracted from the vessel mask, the difference is
    thresholded, and only the largest connected contour is kept.

    Args:
        image: Single-channel input image (any size; resized to 512 x 512).

    Returns:
        ``(mask, masked_image)``: the kept vessel region scaled to the 0..1
        range, and the resized input multiplied by that mask. Both are all
        zero when no contour is found.
    """
    vessel_img, _ = predict_one_image(image, model)
    catheter_img, _ = predict_one_image(image, catheter_model)
    # NOTE(review): pixels predicted as catheter but not vessel are -255 here;
    # they are removed by the threshold but can survive inside a filled contour.
    subtract_image = vessel_img - catheter_img
    _, binary = cv2.threshold(subtract_image, 50, 255, cv2.THRESH_BINARY)
    binary = binary.astype(np.uint8)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    largest_contour = find_largest_contour(contours)
    mask = np.zeros_like(binary)
    # With no usable contour the mask stays empty instead of crashing drawContours.
    if largest_contour is not None:
        cv2.drawContours(mask, [largest_contour], -1, 255, thickness=cv2.FILLED)
    vessel_img = cv2.bitwise_and(subtract_image, subtract_image, mask=mask)
    # Resize the argument, not the notebook-global `img`.
    resized_img = cv2.resize(image, (512, 512))
    return vessel_img/255., resized_img*vessel_img/255.

def get_vessel(image):
    """Segment the vessel in ``image`` with the global segmentation ``model``.

    Args:
        image: Single-channel input image (any size; resized to 512 x 512 by
            ``predict_one_image``).

    Returns:
        ``(mask, masked_image)``: the binary vessel mask scaled to 0/1, and
        the resized input image multiplied by that mask.
    """
    # predict_one_image already returns mask * resized image, so the argument
    # does not need a second resize here (the old code resized the
    # notebook-global `img` instead of `image`).
    vessel_img, match_img = predict_one_image(image, model)
    return vessel_img/255., match_img/255.

def find_largest_contour(contours):
    """Return the contour with the largest ``cv2.contourArea``.

    Ties go to the earliest contour. ``None`` is returned only when
    ``contours`` is empty. Previously a non-empty input whose contours all
    had zero area also produced ``None``.
    """
    return max(contours, key=cv2.contourArea, default=None)

def neighbours(x,y,image):
    """Return the 8 neighbours of ``image[x][y]`` as a list ``[P2, ..., P9]``.

    The order starts at the pixel above (row ``x - 1``) and proceeds clockwise:
    up, up-right, right, down-right, down, down-left, left, up-left.
    """
    up, down = x - 1, x + 1
    left, right = y - 1, y + 1
    clockwise = (
        (up, y), (up, right), (x, right), (down, right),      # P2,P3,P4,P5
        (down, y), (down, left), (x, left), (up, left),       # P6,P7,P8,P9
    )
    return [image[row][col] for row, col in clockwise]

def transitions(neighbours):
    """Count 0 -> 1 steps in the circular sequence P2, P3, ..., P9, P2."""
    ring = neighbours + neighbours[0:1]      # close the circle: ..., P9, P2
    count = 0
    for current, following in zip(ring, ring[1:]):
        if (current, following) == (0, 1):
            count += 1
    return count

def zhangSuen(image):
    """Thin a binary (0/1) 2-D array with the two-step Zhang-Suen iteration.

    The input is copied, so the caller's array is left untouched. Border
    rows and columns are never examined. Each round runs step 1 then step 2
    and clears the pixels each step marked; the loop ends when a full round
    marks nothing.

    Args:
        image: 2-D array exposing ``.copy()`` and ``.shape`` whose object
            pixels equal 1.

    Returns:
        The thinned copy.
    """
    thinned = image.copy()
    rows, columns = thinned.shape               # x indexes rows, y indexes columns

    def _marked_for_removal(first_step):
        # Collect the pixels that the given sub-iteration would delete.
        marked = []
        for x in range(1, rows - 1):
            for y in range(1, columns - 1):
                P2, P3, P4, P5, P6, P7, P8, P9 = ring = neighbours(x, y, thinned)
                if first_step:
                    products = (P2, P4, P6), (P4, P6, P8)
                else:
                    products = (P2, P4, P8), (P2, P6, P8)
                (a, b, c), (d, e, f) = products
                if (thinned[x][y] == 1 and        # pixel belongs to the object
                        2 <= sum(ring) <= 6 and   # 2 <= N(P1) <= 6
                        transitions(ring) == 1 and  # S(P1) == 1
                        a * b * c == 0 and
                        d * e * f == 0):
                    marked.append((x, y))
        return marked

    while True:
        step1 = _marked_for_removal(True)
        for x, y in step1:
            thinned[x][y] = 0
        step2 = _marked_for_removal(False)
        for x, y in step2:
            thinned[x][y] = 0
        if not step1 and not step2:
            break
    return thinned

def distance(x1,y1,x2,y2):
  """Euclidean distance between the points (x1, y1) and (x2, y2)."""
  dx = x1 - x2
  dy = y1 - y2
  return np.sqrt(dx**2 + dy**2)

def min_distance(x,y, vector, limit):
  """Smallest distance from (x, y) to the first ``limit`` rows of ``vector``.

  Only columns 0 and 1 of each row are read. When ``limit`` is not positive
  there is nothing to compare against and the sentinel 1000 is returned.
  """
  if limit <= 0:
    return 1000
  return sorted(distance(x, y, vector[k][0], vector[k][1]) for k in range(limit))[0]

def vectorize_one_image_using_center_line(img):
  """Sample up to MAX_VECTORS windows along the thinned vessel centerline.

  NOTE(review): a function with the same name is defined again in the next
  cell, so this definition is replaced as soon as that cell runs.

  Returns ``(vector, img_with_rectangles, centerline_with_rect)`` on the normal
  path, but only ``(zeros, None)`` when the predicted mask is empty.
  NOTE(review): callers that unpack three values will fail on the empty path.
  """
  # One row per window: (x, y, sum of mask pixels inside the window).
  vector = np.zeros((MAX_VECTORS, 3), dtype=np.float32)
  # Grid cell size used when scanning the centerline.
  STEP = 5
  # NOTE(review): resized_img is not used anywhere below.
  resized_img = cv2.resize(img, (512, 512))
  pred_img, match_img = get_vessel(img)

  centerline = zhangSuen(pred_img.astype(int))

  all_zeros = np.all(pred_img == 0)
  if all_zeros:
    return np.zeros((MAX_VECTORS, 3)), None
  # Colour copies used only for drawing the selected windows.
  img_with_rectangles = cv2.cvtColor(match_img, cv2.COLOR_GRAY2BGR)
  centerline_with_rect = cv2.cvtColor(centerline.astype(np.float32)*255, cv2.COLOR_GRAY2BGR)
  index = 0
  # Half window size.
  WS12 = WINDOW_SIZE//2

  for y in range(0, IMAGE_HEIGHT, STEP):
      for x in range(0, IMAGE_WIDTH, STEP):
          window = centerline[y:y + STEP, x:x + STEP]
          # NOTE(review): np.where returns (row_indices, col_indices), so x_arr holds
          # row offsets and y_arr column offsets here -- the names look swapped.
          x_arr, y_arr = np.where(window==1)
          if len(x_arr) > 2:
            # centre of the centerline pixels found in this grid cell
            x_w = int(x_arr.mean()) + x
            y_w = int(y_arr.mean()) + y

            # distance to the closest window that was already accepted
            m_distance = min_distance(x_w, y_w, vector, index)
            # accept only when far enough from every earlier window
            if m_distance >  WS12 * 1.5:
              # NOTE(review): the indices are not clamped, so a negative start wraps.
              window = pred_img[y_w - WS12 : y_w + WS12, x_w - WS12 : x_w + WS12]
              upper_left = (x_w - WS12, y_w - WS12)
              lower_right = (x_w + WS12, y_w + WS12)
              # clear the centerline inside the accepted window so it is not reused
              centerline[upper_left[1]:lower_right[1]+1,upper_left[0]:lower_right[0]+1] = 0

              cv2.rectangle(img_with_rectangles, upper_left, lower_right, (0, 255, 0), 1)
              cv2.rectangle(centerline_with_rect, upper_left, lower_right, (0, 255, 0), 1)

              # Despite the name this is the sum of the window, not a mean.
              average_color = np.sum(window)
              vector[index] = np.array([x_w, y_w, average_color])
              index+=1
              if index>=MAX_VECTORS:
                break
      if index>=MAX_VECTORS:
        break
  return vector, img_with_rectangles, centerline_with_rect

In [5]:
def vectorize_one_image_using_center_line(img):
    """Sample up to MAX_VECTORS windows along the thinned vessel centerline.

    The vessel mask from ``get_vessel`` is thinned with ``zhangSuen`` and the
    centerline is scanned on a STEP x STEP grid. A grid cell with more than
    two centerline pixels proposes a window centred on their mean position.
    The window is accepted when its centre is more than 1.5 half-windows from
    every accepted centre. Accepted windows are clamped to the image and
    their centerline pixels are cleared.

    Args:
        img: Single-channel input image.

    Returns:
        Always a 3-tuple ``(vector, img_with_rectangles, centerline_with_rect)``:
        ``vector`` is a float32 array of shape (MAX_VECTORS, 3) holding
        ``(x, y, sum of mask pixels in the window)`` per accepted window, with
        unused rows left at zero. The two images show the accepted windows as
        green rectangles. Both images are ``None`` when the predicted mask is
        empty.
    """
    STEP = 5
    half_window = WINDOW_SIZE // 2
    vector = np.zeros((MAX_VECTORS, 3), dtype=np.float32)

    pred_img, match_img = get_vessel(img)
    # Bail out before the (slow) thinning when there is nothing to thin.
    if np.all(pred_img == 0):
        return vector, None, None

    centerline = zhangSuen(pred_img.astype(int))
    img_with_rectangles = cv2.cvtColor(match_img, cv2.COLOR_GRAY2BGR)
    centerline_with_rect = cv2.cvtColor(centerline.astype(np.float32) * 255, cv2.COLOR_GRAY2BGR)

    index = 0
    for y in range(0, IMAGE_HEIGHT, STEP):
        for x in range(0, IMAGE_WIDTH, STEP):
            cell = centerline[y:y + STEP, x:x + STEP]
            if np.count_nonzero(cell) <= 2:
                continue
            # np.where yields (row offsets, column offsets).
            y_arr, x_arr = np.where(cell == 1)
            x_w = int(x_arr.mean()) + x
            y_w = int(y_arr.mean()) + y

            # Skip candidates too close to an already accepted window.
            if min_distance(x_w, y_w, vector, index) <= half_window * 1.5:
                continue

            upper_left = (max(0, x_w - half_window), max(0, y_w - half_window))
            lower_right = (min(IMAGE_WIDTH, x_w + half_window), min(IMAGE_HEIGHT, y_w + half_window))
            if lower_right[0] <= upper_left[0] or lower_right[1] <= upper_left[1]:
                continue

            window = pred_img[upper_left[1]:lower_right[1], upper_left[0]:lower_right[0]]
            # Clear the used centerline so the same stretch is not sampled twice.
            centerline[upper_left[1]:lower_right[1], upper_left[0]:lower_right[0]] = 0

            cv2.rectangle(img_with_rectangles, upper_left, lower_right, (0, 255, 0), 1)
            cv2.rectangle(centerline_with_rect, upper_left, lower_right, (0, 255, 0), 1)

            vector[index] = [x_w, y_w, window.sum()]
            index += 1
            if index >= MAX_VECTORS:
                # Same 3-tuple shape as the normal exit below.
                return vector, img_with_rectangles, centerline_with_rect

    return vector, img_with_rectangles, centerline_with_rect

In [6]:
def adjust_boxes(row):
    """Rescale a bounding box from the source image size to 512 x 512.

    ``row`` must support item access for ``width``, ``height``, ``xmin``,
    ``ymin``, ``xmax`` and ``ymax``. The four coordinates are overwritten in
    place with their scaled, ``int``-truncated values and the same ``row``
    object is returned.
    """
    scale_x = 512 / row['width']
    scale_y = 512 / row['height']

    for column, factor in (('xmin', scale_x), ('ymin', scale_y),
                           ('xmax', scale_x), ('ymax', scale_y)):
        row[column] = int(row[column] * factor)

    return row

In [7]:
def is_window_overlap(x, y, window_size, box):
    """Return 1 if the square window centred on (x, y) overlaps ``box``, else 0.

    Args:
        x, y: Window centre.
        window_size: Full side length of the window.
        box: ``(xmin, ymin, xmax, ymax)``.

    Windows that merely touch the box edge do not count as overlapping
    (the comparisons are strict).
    """
    half = window_size / 2
    box_xmin, box_ymin, box_xmax, box_ymax = box

    overlaps_x = x - half < box_xmax and x + half > box_xmin
    overlaps_y = y - half < box_ymax and y + half > box_ymin
    return 1 if overlaps_x and overlaps_y else 0

In [None]:
# filenames = df_train.filename.values
# vectors = np.zeros((len(filenames), 40, 3))
# labels = np.zeros((len(filenames), 40))
# boxes = df_train[['xmin', 'ymin', 'xmax', 'ymax']].values

# for index, filename in tqdm(enumerate(filenames)):
#     img = cv2.imread(os.path.join("data", filename), 0)
#     resized_img = cv2.resize(img, (512, 512))
#     pred_img, match_img = get_vessel(img)
#     skeletion = skeletonize(pred_img.astype(int))
#     vector, img_with_rectangles, centerline_with_rect = vectorize_one_image_using_center_line(img)
#     label = []
#     for v in vector:
#         x, y, color = v
#         label.append(is_window_overlap(x, y, WINDOW_SIZE//2, boxes[index]))
#     vectors[index] = vector
#     label = np.array(label)
#     labels[index] = label
#     break
# #     np.save("label.npy", labels)
# #     np.save("vector.npy", vectors)

In [8]:
# Load the arrays stored in label.npy / vector.npy (presumably written by the
# commented-out cell above -- TODO confirm) and keep the first 546 entries of each.
labels = np.load("label.npy")
vectors = np.load("vector.npy")
vectors = vectors[:546]
labels = labels[:546]

In [9]:
# Read the annotation table and keep the first 546 file names, the same count
# used to slice `vectors` and `labels`.
df_train = pd.read_csv('train_labels.csv')
filenames = df_train.filename.values[:546]
# One stack of 40 patches of 20 x 20 pixels per image, filled by the loop in the next cell.
images_list = np.zeros((len(filenames), 40, 20, 20))

In [10]:
# Cut one patch per stored window centre out of each image's vessel mask.
# The patch geometry is taken from the pre-allocated `images_list`, so it
# cannot drift from the array being filled.
_, n_windows, patch_h, patch_w = images_list.shape
half_h, half_w = patch_h // 2, patch_w // 2

for index, filename in tqdm(enumerate(filenames), total=len(filenames)):
    image_path = os.path.join("data", filename)
    img = cv2.imread(image_path, 0)
    # cv2.imread signals a missing or unreadable file by returning None.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    pred_img, match_img = get_vessel(img)
    vector = vectors[index]
    images = np.zeros((n_windows, patch_h, patch_w))
    for i, v in enumerate(vector):
        x, y, pixel_count = v
        x = int(x)
        y = int(y)
        top, left = y - half_h, x - half_w
        # Windows that stick out of the mask (including the zero-padded rows
        # of `vector`) are skipped and stay all-zero.
        if top < 0 or left < 0:
            continue
        small_image = pred_img[top:top + patch_h, left:left + patch_w]
        if small_image.shape[0] != patch_h or small_image.shape[1] != patch_w:
            continue
        images[i] = small_image
    images_list[index] = images

0it [00:00, ?it/s]2024-06-18 06:20:08.020214: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8902
546it [00:34, 15.80it/s]


In [12]:
# np.save("images_list.npy", images_list)

In [13]:
# Replace the in-memory patch array with the copy stored in images_list.npy
# (presumably saved by the commented-out np.save in the previous cell -- TODO confirm).
images_list = np.load("images_list.npy")
images_list.shape

(546, 40, 20, 20)

In [14]:
def multi_label_accuracy(y_true, y_pred):
    """Mean fraction of labels per sample whose rounded prediction equals the target.

    ``y_pred`` is rounded to the nearest integer before comparison.
    NOTE(review): ``create_vit_classifier`` ends in a ``Dense`` layer without
    activation, so with that model the values rounded here are raw logits, not
    probabilities -- confirm this is intended.
    """
    hard_pred = tf.round(y_pred)
    hits = tf.cast(tf.equal(y_true, hard_pred), tf.float32)
    n_labels = tf.cast(tf.shape(y_true)[-1], tf.float32)
    per_sample = tf.reduce_sum(hits, axis=-1) / n_labels
    return tf.reduce_mean(per_sample)

In [15]:
import tensorflow as tf
from tensorflow.keras import backend as K

# Additive mask value that effectively removes an entry from a log-sum-exp.
nINF = -100.0

class TwoWayLoss(tf.keras.losses.Loss):
    """Two-way multi-label loss with separate positive/negative temperatures.

    For every class (axis 0) and every sample (axis 1) that has at least one
    positive label, a log-sum-exp over the positive entries is combined with
    one over the negative entries through a softplus. The result is the mean
    of the class-wise terms plus the mean of the sample-wise terms.

    Note: a class with the same name is defined again in a later cell of this
    notebook and replaces this one once that cell runs.
    """

    def __init__(self, Tp=4.0, Tn=1.0, name="two_way_loss"):
        super().__init__(name=name)
        self.Tp = Tp
        self.Tn = Tn

    def call(self, y_true, y_pred):
        is_positive = y_true > 0
        # Only classes / samples with at least one positive label contribute.
        class_mask = tf.reduce_any(is_positive, axis=0)
        sample_mask = tf.reduce_any(is_positive, axis=1)

        # Mask the complementary entries out of each log-sum-exp.
        positive_scores = -y_pred / self.Tp + tf.where(is_positive, 0.0, nINF)
        negative_scores = y_pred / self.Tn + tf.where(y_true == 0, 0.0, nINF)

        def _pooled(scores, temperature, axis, mask):
            pooled = tf.reduce_logsumexp(scores, axis=axis) * temperature
            return tf.boolean_mask(pooled, mask)

        plogit_class = _pooled(positive_scores, self.Tp, 0, class_mask)
        plogit_sample = _pooled(positive_scores, self.Tp, 1, sample_mask)
        nlogit_class = _pooled(negative_scores, self.Tn, 0, class_mask)
        nlogit_sample = _pooled(negative_scores, self.Tn, 1, sample_mask)

        class_term = tf.reduce_mean(tf.nn.softplus(nlogit_class + plogit_class))
        sample_term = tf.reduce_mean(tf.nn.softplus(nlogit_sample + plogit_sample))
        return class_term + sample_term

def get_criterion(Tp, Tn):
    """Build a ``TwoWayLoss`` with the given positive/negative temperatures."""
    criterion = TwoWayLoss(Tp=Tp, Tn=Tn)
    return criterion

In [16]:
# Training / architecture hyper-parameters.
# NOTE(review): `learning_rate` is not passed to the optimizer in the visible
# training cell (it uses `lr_schedule` instead).
learning_rate = 1e-4
weight_decay = 1e-4
batch_size = 2**7
num_epochs = 20
# NOTE(review): patch_size / image_size are only used to derive num_patches
# here, and num_patches is reassigned to 40 in a later cell.
patch_size = 6
image_size = 32
num_patches = (image_size // patch_size) ** 2
projection_dim = 2**6
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim
] # Size of the transformer layers

# Number of transformer blocks stacked in create_vit_classifier().
transformer_layers = 8
mlp_head_units = [
    2 ** 11,
    2 ** 10,
] # Size of the dense layers of the final classifier

In [17]:
def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of GELU ``Dense`` layers, each followed by ``Dropout``.

    Args:
        x: Input tensor.
        hidden_units: Iterable with the width of each ``Dense`` layer, in order.
        dropout_rate: Rate used by every ``Dropout`` layer.

    Returns:
        The output tensor of the last ``Dropout`` layer (``x`` itself when
        ``hidden_units`` is empty).
    """
    for width in hidden_units:
        dense = layers.Dense(width, activation=keras.activations.gelu)
        hidden = dense(x)
        x = layers.Dropout(dropout_rate)(hidden)
    return x

In [18]:
import tensorflow as tf
from tensorflow.keras import layers

class PatchEncoder(layers.Layer):
    """Linearly project flattened patches and add a learned position embedding.

    Args:
        num_patches: Number of patches per sample (size of the position table).
        projection_dim: Output width of the projection and of the embedding.
        **kwargs: Standard ``Layer`` arguments (``name``, ``dtype``, ...). These
            are needed so that ``from_config`` can rebuild the layer from
            ``get_config()``.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patches):
        # Positions 0..num_patches-1 with a leading axis of size 1, so the
        # embedding broadcasts over the batch.
        positions = tf.expand_dims(
            tf.range(start=0, limit=self.num_patches, delta=1), axis=0
        )
        projected_patches = self.projection(patches)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

    def get_config(self):
        # Both constructor arguments must be present, otherwise from_config()
        # cannot call __init__.
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config

In [19]:
def create_vit_classifier(num_classes=40):
    """Build the transformer classifier that consumes already-encoded patches.

    The model takes a tensor of shape ``INPUT_SHAPE`` per sample and applies
    ``transformer_layers`` pre-norm transformer blocks (multi-head
    self-attention plus MLP, each with a residual connection). The result is
    then normalized, flattened and passed through the ``mlp_head_units``
    head.

    Args:
        num_classes: Width of the final ``Dense`` layer.

    Returns:
        An uncompiled ``keras.Model`` whose outputs are raw (un-activated)
        scores of shape ``(batch, num_classes)``.
    """
    # Keep a handle on the Input tensor: `encoded_patches` is rebound in the
    # loop, and building the Model from the rebound tensor would no longer root
    # the model at the Input.
    inputs = keras.Input(shape=INPUT_SHAPE)
    encoded_patches = inputs

    for _ in range(transformer_layers):
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        x2 = layers.Add()([attention_output, encoded_patches])
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        encoded_patches = layers.Add()([x3, x2])

    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    logits = layers.Dense(num_classes)(features)
    return keras.Model(inputs=inputs, outputs=logits)

# Add index to images

In [21]:
# Dimensions of `images_list`: N images, each with num_patches patches of
# height x width pixels and a single channel.
N = 546  
num_patches = 40  
height = 20  
width = 20  
channels = 1  
projection_dim = 64  

patches = images_list

# Flatten every patch into one vector of height * width * channels values.
patches = tf.reshape(patches, (N, num_patches, height * width * channels))

# NOTE(review): the encoder is applied here, outside the Keras model built by
# create_vit_classifier(), so its projection / position-embedding weights do
# not appear to be updated by model.fit() -- confirm this is intended.
encoder = PatchEncoder(num_patches=num_patches, projection_dim=projection_dim)
encoded_patches = encoder(patches)
print(encoded_patches.shape) 

(546, 40, 64)


In [22]:
import tensorflow as tf
from tensorflow.keras import layers

# Additive mask value that effectively removes an entry from a log-sum-exp.
nINF = -100.0

class TwoWayLoss(tf.keras.losses.Loss):
    """Two-way multi-label loss with separate positive/negative temperatures.

    For every class (axis 0) and every sample (axis 1) that has at least one
    positive label, a log-sum-exp over the positive entries is combined with
    one over the negative entries through a softplus. The loss is the mean of
    the class-wise terms plus the mean of the sample-wise terms. A batch with
    no positive label yields 0 instead of NaN.

    Args:
        Tp: Temperature applied to the positive scores.
        Tn: Temperature applied to the negative scores.
        name: Name of the loss.
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``reduction``),
            so the loss can be rebuilt from ``get_config()``.
    """

    def __init__(self, Tp=4.0, Tn=1.0, name="two_way_loss", **kwargs):
        super().__init__(name=name, **kwargs)
        self.Tp = Tp
        self.Tn = Tn

    @staticmethod
    def _mean_or_zero(values):
        # reduce_mean of an empty tensor is NaN; return 0 in that case.
        count = tf.cast(tf.size(values), values.dtype)
        return tf.math.divide_no_nan(tf.reduce_sum(values), count)

    def call(self, y_true, y_pred):
        y_true = tf.convert_to_tensor(y_true)
        y_pred = tf.convert_to_tensor(y_pred)

        # Classes / samples that contain at least one positive label.
        class_mask = tf.reduce_any(y_true > 0, axis=0)
        sample_mask = tf.reduce_any(y_true > 0, axis=1)

        # Hard positive logits: negatives are masked out of the log-sum-exp.
        pmask = tf.where(y_true > 0, 0.0, nINF)
        plogit_class = tf.reduce_logsumexp(-y_pred / self.Tp + pmask, axis=0) * self.Tp
        plogit_class = tf.boolean_mask(plogit_class, class_mask)
        plogit_sample = tf.reduce_logsumexp(-y_pred / self.Tp + pmask, axis=1) * self.Tp
        plogit_sample = tf.boolean_mask(plogit_sample, sample_mask)

        # Hard negative logits: positives are masked out of the log-sum-exp.
        nmask = tf.where(y_true == 0, 0.0, nINF)
        nlogit_class = tf.reduce_logsumexp(y_pred / self.Tn + nmask, axis=0) * self.Tn
        nlogit_class = tf.boolean_mask(nlogit_class, class_mask)
        nlogit_sample = tf.reduce_logsumexp(y_pred / self.Tn + nmask, axis=1) * self.Tn
        nlogit_sample = tf.boolean_mask(nlogit_sample, sample_mask)

        class_term = self._mean_or_zero(tf.nn.softplus(nlogit_class + plogit_class))
        sample_term = self._mean_or_zero(tf.nn.softplus(nlogit_sample + plogit_sample))
        return class_term + sample_term

    def get_config(self):
        config = super().get_config()
        config.update({"Tp": self.Tp, "Tn": self.Tn})
        return config

def get_criterion(Tp=4.0, Tn=1.0):
    """Build a ``TwoWayLoss``.

    Accepts both call styles used in this notebook: ``get_criterion()`` for the
    default temperatures and ``get_criterion(Tp, Tn)`` for explicit ones.
    """
    return TwoWayLoss(Tp=Tp, Tn=Tn)

In [24]:
# Shape of one encoded sample: (patches per image, projection_dim).
INPUT_SHAPE = (40, 64)
train_size = len(images_list)
initial_learning_rate = 0.001
final_learning_rate = 0.00001
# Per-epoch factor that takes the LR from its initial to its final value over num_epochs.
learning_rate_decay_factor = (final_learning_rate / initial_learning_rate) ** (1 / num_epochs)
# fit() also runs the last, partial batch, so round up; rounding down made the
# staircase decay fire before each epoch had finished.
steps_per_epoch = int(np.ceil(train_size / batch_size))
transformer_units = [
    projection_dim * 2,
    projection_dim
]

lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=initial_learning_rate,
                decay_steps=steps_per_epoch,
                decay_rate=learning_rate_decay_factor,
                staircase=True)

optimizer = keras.optimizers.AdamW(
    learning_rate=lr_schedule, weight_decay=weight_decay
)

# NOTE(review): this rebinds the global `model`, which get_vessel() and
# remove_catheter() read at call time. After this cell those helpers would call
# the ViT classifier instead of the segmentation network loaded earlier.
model = create_vit_classifier()
model.compile(
    optimizer=optimizer,
    loss=get_criterion(),
    metrics=[
        multi_label_accuracy
    ],
)

# ModelCheckpoint does not create the target directory on its own.
os.makedirs("ViT_weights", exist_ok=True)
weight_filename = str(datetime.datetime.now().strftime("%d %b %Y %I:%M%p")) + '__checkpoint.weights.h5'
checkpoint_filepath = "ViT_weights/" + str(weight_filename)

# fit() below gets no validation data and no F1 metric is compiled, so a
# "val_f1_score" monitor never exists: nothing would be checkpointed and early
# stopping could never trigger. Monitor the training loss, which is available.
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    checkpoint_filepath,
    monitor="loss",
    mode="min",
    save_best_only=True,
    save_weights_only=True,
)
# Name kept for compatibility with later cells; it now watches the training loss.
f1_early_stopping_callback = keras.callbacks.EarlyStopping(
    monitor='loss',
    patience=10,
    mode='min',
    restore_best_weights=True,
    verbose=1
)
history = model.fit(
    encoded_patches,
    labels,
    batch_size=batch_size,
    epochs=num_epochs,
    callbacks=[checkpoint_callback,
               f1_early_stopping_callback],
)


Epoch 1/20


2024-06-18 06:23:08.584522: I external/local_xla/xla/service/service.cc:168] XLA service 0x7efb340c5b70 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2024-06-18 06:23:08.584550: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA RTX A5000, Compute Capability 8.6
2024-06-18 06:23:08.584572: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (1): NVIDIA RTX A5000, Compute Capability 8.6
2024-06-18 06:23:08.589289: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:00:1718691788.684198 2718641 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
