<a href="https://colab.research.google.com/github/dakilaledesma/arcface-classifier/blob/main/ArcFace_TF2_LayerImp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%time
# Extract the training archive into /content/ (-q suppresses unzip's per-file listing).
# NOTE(review): the archive is read from /content/drive/MyDrive, so this assumes Google
# Drive is already mounted in the Colab session; no mount step is visible in this
# notebook, so confirm before relying on Restart & Run All.
! unzip -q /content/drive/MyDrive/UNC/H2022/orchidaceae_train.zip -d /content/

CPU times: user 145 ms, sys: 17 ms, total: 162 ms
Wall time: 20.7 s


In [2]:
! pip install tensorflow-addons

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow-addons
  Downloading tensorflow_addons-0.17.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 4.1 MB/s 
Installing collected packages: tensorflow-addons
Successfully installed tensorflow-addons-0.17.1


In [3]:
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input

from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Dense, Concatenate, Flatten, GlobalAveragePooling2D, Input
from tensorflow.keras import regularizers
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.preprocessing import image
from tensorflow.keras.utils import to_categorical

import numpy as np

import tensorflow_addons as tfa
import tensorflow as tf

from tqdm.notebook import tqdm
from glob import glob
import os

In [4]:
# From https://github.com/ozora-ogino/asoftmax-tf/blob/main/asoftmax.py
class ASoftmax(tf.keras.layers.Layer):
    """Angular-margin softmax classification head.

    Features and class weights are both L2-normalised, so their product is a
    cosine similarity per class. In training mode an additive margin is applied
    to the angle of the target class (``cos(theta + margin)``) before the
    cosines are scaled and passed through softmax.

    Inputs:
        training:  ``[features, one_hot_labels]``.
        inference: the feature tensor alone; a ``[features, labels]`` pair is
                   also accepted, in which case the labels are ignored.

    Output:
        Softmax probabilities over ``n_classes``.
    """

    def __init__(
        self,
        n_classes=10,
        scale=30.0,
        margin=0.50,
        regularizer=None,
        **kwargs,
    ):
        """[ASoftmax]
        Args:
            n_classes (int, optional): Number of class. Defaults to 10.
            scale (float, optional): Float variable for scaling. Defaults to 30.0.
            margin (float, optional): Float variable of margin. Defaults to 0.50.
            regularizer (function, optional): keras.regularizers. Defaults to None.
        """

        super(ASoftmax, self).__init__(**kwargs)
        self.n_classes = n_classes
        self.scale = scale
        self.margin = margin
        self.regularizer = regularizers.get(regularizer)

    def build(self, input_shape):
        """Create the class-weight matrix ``W`` of shape (feature_dim, n_classes)."""
        # With [features, labels] input Keras passes a list of shapes; with a bare
        # feature tensor it passes a single shape whose first entry is the batch
        # dimension (an int or None). Support both so the layer can also be built
        # by an inference-mode call.
        first = input_shape[0]
        if isinstance(first, (tf.TensorShape, list, tuple)):
            feature_shape = first
        else:
            feature_shape = input_shape

        super(ASoftmax, self).build(feature_shape)
        self.W = self.add_weight(
            name="W",
            shape=(feature_shape[-1], self.n_classes),
            initializer="glorot_uniform",
            trainable=True,
            regularizer=self.regularizer,
        )

    def _cosine_logits(self, x):
        """Return cosine similarities between each feature row and each class weight."""
        x = tf.nn.l2_normalize(x, axis=1)
        W = tf.nn.l2_normalize(self.W, axis=0)
        return x @ W

    def _train_op(self, inputs):
        """Softmax over scaled cosines with the margin applied to the target class."""
        x, y = inputs

        logits = self._cosine_logits(x)
        # Labels are used as a 0/1 mask, so they must share the logits' dtype.
        y = tf.cast(y, logits.dtype)

        # Add margin and clip logits to prevent zero division when backward
        theta = tf.acos(K.clip(logits, -1.0 + K.epsilon(), 1.0 - K.epsilon()))
        target_logits = tf.cos(theta + self.margin)
        # Keep the plain cosine for non-target classes, the margin version for the target.
        logits = logits * (1 - y) + target_logits * y

        # Rescale the feature
        logits *= self.scale
        return tf.nn.softmax(logits)

    def _predict_op(self, inputs):
        """Softmax over scaled cosines, without any margin."""
        logits = self._cosine_logits(inputs)
        # Apply the same scale as in training; without it the softmax runs over raw
        # cosines in [-1, 1] and is nearly flat across classes.
        logits *= self.scale
        return tf.nn.softmax(logits)

    def call(self, inputs, training=False):
        """Dispatch to the training or inference path.

        Raises:
            ValueError: if ``training`` is true but no labels were supplied.
        """
        has_labels = isinstance(inputs, (list, tuple))
        if training:
            if not has_labels:
                raise ValueError(
                    "ASoftmax expects [features, labels] as input when training=True."
                )
            return self._train_op(inputs)

        # Inference ignores labels if they were passed along with the features.
        features = inputs[0] if has_labels else inputs
        return self._predict_op(features)

In [5]:
# Number of target classes.
# NOTE(review): presumably equals the number of class folders under orchidaceae_train/ — confirm.
num_classes = 300

In [6]:
# Map each entry directly under orchidaceae_train/ (by its base name) to a
# consecutive integer id, following the sorted path order.
class_paths = sorted(glob("orchidaceae_train/*"))
cat_to_int = {os.path.basename(path): idx for idx, path in enumerate(class_paths)}

In [7]:
# Load every image file under orchidaceae_train/, resize it to 224x224, apply the
# ResNet50 preprocessing and collect it together with its integer class id.
o = []  # preprocessed image arrays, in the same order as `orc`
y = []  # integer class ids aligned with `o`
orc = sorted(glob("orchidaceae_train/**/*.*", recursive=True))
for fn in tqdm(orc):  # tqdm takes the total from the list's length
  # The name of the folder that directly contains the file is the class label.
  cat = cat_to_int[os.path.basename(os.path.dirname(fn))]

  img = image.load_img(fn, target_size=(224, 224))
  x = image.img_to_array(img)
  x = preprocess_input(x)
  o.append(x)

  y.append(cat)

  0%|          | 0/9419 [00:00<?, ?it/s]

In [8]:
# Stack the per-image arrays into one array and one-hot encode the integer labels.
# The class count comes from the `num_classes` constant instead of a repeated literal.
# NOTE: this cell rebinds `o` and `y`; re-running it without re-running the loading
# cell would one-hot encode labels that are already one-hot.
o = np.array(o)
y = to_categorical(y, num_classes)

In [None]:
class AFModel(Model):
    """ResNet50 backbone -> global average pooling -> Dense(512) -> ASoftmax head.

    Args:
        num_classes: number of output classes for the ASoftmax head.
        weight_decay: L2 regularisation factor for the ASoftmax class weights.

    Call inputs:
        training:  ``[images, one_hot_labels]`` (the margin needs the labels).
        inference: either the image tensor alone or the same
                   ``[images, labels]`` pair, so ``evaluate``/validation can reuse
                   the training input structure; the labels are then ignored.
    """

    def __init__(self, num_classes=300, weight_decay=1e-4):
        super(AFModel, self).__init__()
        # `classes` is not passed: it only applies when include_top=True.
        self.backbone = ResNet50(input_shape=(224, 224, 3), weights='imagenet', include_top=False)
        self.layer_1 = GlobalAveragePooling2D()
        # NOTE(review): the ReLU makes every embedding component non-negative before
        # ASoftmax L2-normalises it; confirm this restriction is intended.
        self.layer_2 = Dense(512, activation='relu')

        self.out = ASoftmax(
            n_classes=num_classes,
            regularizer=regularizers.l2(weight_decay),
        )

    def call(self, x, training=False):
        """Run the network; returns the ASoftmax class probabilities.

        Raises:
            ValueError: if ``training`` is true but no labels were supplied.
        """
        # Split off the labels whenever they are supplied, not only in training.
        if isinstance(x, (list, tuple)):
            y = x[1] if len(x) > 1 else None
            x = x[0]
        else:
            y = None

        if training and y is None:
            raise ValueError("AFModel expects [images, labels] as input when training=True.")

        x = self.backbone(x, training=training)
        x = self.layer_1(x)
        x = self.layer_2(x)

        if training:
            # When training, you need to pass label to ASoftmax
            return self.out([x, y], training=True)
        return self.out(x, training=False)

# Build the model with the class count from the `num_classes` constant.
model = AFModel(num_classes=num_classes)

opt = tfa.optimizers.AdaBelief(learning_rate=1e-3)
# The 'accuracy' reported during fit is computed on the training-mode outputs,
# i.e. with the angular margin applied to the target class.
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

# No validation data is passed to fit, so the default monitor ('val_loss') would never
# be available and save_best_only would skip every save; track the training loss instead.
# AFModel is a subclassed model, which the HDF5 format can only store as weights.
checkpoint = ModelCheckpoint('model.hdf5',
                             monitor='loss',
                             verbose=1,
                             save_best_only=True,
                             save_weights_only=True)

# The labels are passed twice: as the second model input (needed for the margin)
# and as the loss target.
model.fit([o, y],
          y,
          batch_size=64,
          epochs=24,
          verbose=1,
          callbacks=[checkpoint])

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
Epoch 1/24
Epoch 2/24
 33/148 [=====>........................] - ETA: 27s - loss: 16.2203 - accuracy: 0.0000e+00