In [5]:
import tensorflow as tf
import os

from datetime import datetime

import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import *

In [4]:
class NiHClassifier(tf.keras.Model):
    """Image classifier: ImageNet-pretrained ResNet50 backbone + small dense head.

    The backbone starts frozen. Its feature map is global-average-pooled and fed
    to ``Dense(256, relu) -> Dense(number_of_output_classes, sigmoid)``. Each
    output is an independent sigmoid, so the model fits multi-label targets.

    NOTE(review): the ImageNet ResNet50 weights expect inputs processed by
    ``tf.keras.applications.resnet50.preprocess_input``. This model does not
    apply it itself; confirm the input pipeline does.

    Args:
        number_of_output_classes: Number of sigmoid outputs.
        image_shape: ``(height, width, channels)`` of the input images.
    """

    def __init__(self, number_of_output_classes: int,
                 image_shape: Tuple[int, int, int] = (224, 224, 3)):
        super().__init__()
        self.number_of_output_classes: int = number_of_output_classes
        self.image_shape: Tuple[int, int, int] = image_shape

        # Downloads the ImageNet weights on first use; no classification top.
        self.pretrained_resnet50 = tf.keras.applications.resnet50.ResNet50(include_top=False,
                                                                           weights="imagenet",
                                                                           input_shape=image_shape)
        # Start with the backbone frozen; see unfreeze_top_layers() for fine-tuning.
        self.pretrained_resnet50.trainable = False

        self.global_average_pooling = tf.keras.layers.GlobalAveragePooling2D()
        self.prediction_layer = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.Dense(self.number_of_output_classes, activation=tf.keras.activations.sigmoid)
        ])

        # Build eagerly so weights exist (e.g. for summary() / loading) before the first call.
        self.build(input_shape=(None, image_shape[0], image_shape[1], image_shape[2]))

    def unfreeze_top_layers(self, fine_tune_top_n: int) -> None:
        """Make only the last ``fine_tune_top_n`` backbone layers trainable.

        Re-compile the model after calling this: Keras only takes changes to
        ``trainable`` into account when ``compile()`` is called.

        Args:
            fine_tune_top_n: Number of backbone layers, counted from the output
                end, to unfreeze. A value of 0 leaves every backbone layer
                frozen; values >= the number of backbone layers unfreeze all.

        Raises:
            ValueError: If ``fine_tune_top_n`` is negative.
        """
        if fine_tune_top_n < 0:
            raise ValueError(f"fine_tune_top_n must be non-negative, got {fine_tune_top_n}")

        self.pretrained_resnet50.trainable = True

        backbone_layers = self.pretrained_resnet50.layers
        # Clamp at 0 so an oversized request simply unfreezes everything.
        layers_to_freeze: int = max(len(backbone_layers) - fine_tune_top_n, 0)
        for layer in backbone_layers[:layers_to_freeze]:
            layer.trainable = False

    def call(self, inputs, training=None, mask=None):
        """Return per-class sigmoid probabilities for a batch of images."""
        # Always run the backbone in inference mode so its BatchNormalization layers
        # keep their pretrained moving statistics, even after unfreeze_top_layers().
        # Letting them update on a small fine-tuning set can destroy the learned features.
        resnet_features = self.pretrained_resnet50(inputs, training=False)
        avg_pooling_features = self.global_average_pooling(resnet_features)
        predictions = self.prediction_layer(avg_pooling_features, training=training)
        return predictions


In [6]:
def read_image_data_augmentation(file_path, label):
    """Load an image file and randomly mirror it horizontally (training augmentation).

    Args:
        file_path: Path of the image file (string or scalar string tensor).
        label: Returned unchanged.

    Returns:
        ``(image, label)``, where ``image`` is a float32 tensor of shape
        ``(height, width, 3)``, flipped left-right with probability 0.5.
    """
    raw_bytes = tf.io.read_file(file_path)
    # expand_animations=False guarantees a 3-D result (GIFs yield their first frame),
    # so the tensor has a known rank inside tf.data pipelines.
    image = tf.image.decode_image(raw_bytes, channels=3, dtype=tf.float32,
                                  expand_animations=False)
    # Graph-safe random flip; no Python `if` on a tensor, so it does not depend on
    # AutoGraph converting the branch when traced by Dataset.map / tf.function.
    image = tf.image.random_flip_left_right(image)
    return image, label

In [8]:
def read_image(file_path, label):
    """Load an image file without augmentation (validation / test pipeline).

    Args:
        file_path: Path of the image file (string or scalar string tensor).
        label: Returned unchanged.

    Returns:
        ``(image, label)``, where ``image`` is a float32 tensor of shape
        ``(height, width, 3)``.
    """
    raw_bytes = tf.io.read_file(file_path)
    # expand_animations=False guarantees a 3-D result (GIFs yield their first frame),
    # so the tensor has a known rank inside tf.data pipelines.
    image = tf.image.decode_image(raw_bytes, channels=3, dtype=tf.float32,
                                  expand_animations=False)
    return image, label

In [9]:
def scheduler(epoch: int, lr: float, hold_epochs: int = 10, decay_rate: float = 0.1) -> float:
    """Learning-rate schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    Keeps ``lr`` unchanged while ``epoch < hold_epochs``, then multiplies it by
    ``exp(-decay_rate)``. The callback passes in the optimizer's current rate
    each epoch, so the decay compounds epoch over epoch.

    Args:
        epoch: Epoch index (the callback counts from 0).
        lr: Current learning rate.
        hold_epochs: Number of initial epochs with a constant rate.
        decay_rate: Exponential decay factor applied per epoch after that.

    Returns:
        The learning rate for ``epoch`` as a plain Python float.
    """
    if epoch < hold_epochs:
        return float(lr)
    # Return a Python float rather than a tf.Tensor: newer Keras versions reject
    # non-float schedule outputs, and the annotation promises a float.
    return float(lr * np.exp(-decay_rate))