In [None]:
!pip install -q -U wandb
!pip install -q natsort tensorflow_addons

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import matplotlib.pyplot as plt
from mpl_toolkits import axes_grid1
# import tensorflow_models as tfm 
import warnings
warnings.simplefilter('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers 
from tensorflow.keras.callbacks import Callback
from google.cloud import storage
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard, CSVLogger
import tensorflow_datasets as tfds  

from albumentations import (
    Compose, RandomBrightness, JpegCompression, HueSaturationValue, RandomContrast, HorizontalFlip,
    Rotate, Normalize
)

import tensorflow_addons as tfa

In [None]:
MIXED_PRECISION = True
XLA_ACCELERATE = True

# Prefer a TPU when one can be resolved; otherwise fall back to the default
# (single-device) strategy and treat the machine as a GPU host.
try:  # detect TPUs
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)
    DEVICE = 'TPU'
except ValueError:  # detect GPUs
    strategy = tf.distribute.get_strategy()
    DEVICE = 'GPU'

if DEVICE == "GPU":
    physical_devices = tf.config.list_physical_devices('GPU')
    print("Num GPUs Available: ", len(physical_devices))
    # Memory growth is enabled on every visible GPU (an empty list is simply a no-op).
    for gpu in physical_devices:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except (ValueError, RuntimeError) as err:
            # Invalid device or cannot modify virtual devices once initialized.
            # Best effort: report and continue instead of silently swallowing.
            print('Could not enable memory growth on', gpu.name, ':', err)

if MIXED_PRECISION:
    dtype = 'mixed_bfloat16' if DEVICE == "TPU" else 'mixed_float16'
    tf.keras.mixed_precision.set_global_policy(dtype)
    # Keep the compute dtype in sync with the policy chosen above
    # (bfloat16 on TPU, float16 on GPU).
    dtype_model = tf.bfloat16 if DEVICE == "TPU" else tf.float16
    print('Mixed precision enabled')
else:
    dtype_model = tf.float32


if XLA_ACCELERATE:
    tf.config.optimizer.set_jit(True)
    print('Accelerated Linear Algebra enabled')

AUTO  = tf.data.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync

print('REPLICAS           : ', REPLICAS)
print('TensorFlow Version : ', tf.__version__)
print('Eager Mode Status  : ', tf.executing_eagerly())
# Call the function: printing the bare attribute only shows the function object.
print('TF Cuda Built Test : ', tf.test.is_built_with_cuda())
print(
    'TF Device Detected : ',
    'Running on TPU' if DEVICE == "TPU" else tf.test.gpu_device_name()
)

try:
    print('TF System Cuda V.  : ', tf.sysconfig.get_build_info()["cuda_version"])
    print('TF System CudNN V. : ', tf.sysconfig.get_build_info()["cudnn_version"])
except (KeyError, AttributeError):
    # Build info carries no CUDA/cuDNN entries (or is unavailable): nothing to report.
    pass

In [None]:
DATA_DIR_PATH = 'gs://kds-0fe1708952e30744d498ab5049717990e1a2816a0f9ba56acf700c28'

In [None]:
# Glob pattern and sorted shard list for the image-only set used by the
# two-view (self-supervised) pipeline.
train_shard_suffix = 'data_*-9.tfrec'
train_set_path = sorted(tf.io.gfile.glob(DATA_DIR_PATH + f'/{train_shard_suffix}'))

# Global batch sizes: a per-replica size multiplied by the number of replicas.
batch_size = 128 * REPLICAS
batch_size_fer = 32 * REPLICAS
image_shape = (256, 256)

train_set_len = 202599 # for part 0 and for part 1: 655167
# -(-a // b) is integer ceiling division: number of batches needed to cover the set.
train_step_epoch = -(-train_set_len // batch_size)

val_set_len = 28662
# NOTE(review): this divides by `batch_size`, while the labelled dataset is later
# batched with `batch_size_fer`; confirm which batch size is intended here.
val_step_epoch = -(-val_set_len // batch_size)

In [None]:
# Bucket and sorted shard list for the labelled set read by tfrecords_loader_fer.
GCS_DS_PATH_fer = "gs://kds-e0b005301382bc95f485d3d97efd910f02aa1f607d7e1f74cf878f2f"
# NOTE(review): this rebinds `train_shard_suffix`, which an earlier cell set to the
# pattern of the other dataset; re-running that earlier cell after this one would
# still work only because it assigns the name again before using it.
train_shard_suffix = 'train_*-3.tfrec'
train_set_path_fer = sorted(tf.io.gfile.glob(GCS_DS_PATH_fer + f'/train/{train_shard_suffix}'))

In [None]:
class Augmentation(keras.layers.Layer):
    """Base layer for the random augmentations; provides a shared coin flip."""

    def __init__(self):
        super().__init__()

    @tf.function
    def random_execute(self, prob: float) -> bool:
        """Return a boolean tensor that is True when a uniform [0, 1) draw is below `prob`."""
        draw = tf.random.uniform([], minval=0, maxval=1)
        return draw < prob

class RandomToGrayscale(Augmentation):
    """With probability 0.2, converts the image to grayscale replicated over 3 channels."""

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        drop_colour = self.random_execute(0.2)
        if drop_colour:
            # rgb_to_grayscale leaves a single channel; tile it back to three.
            x = tf.tile(tf.image.rgb_to_grayscale(x), [1, 1, 3])
        return x

class RandomColorJitter(Augmentation):
    """With probability 0.8, applies random brightness, contrast, saturation and hue."""

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        apply_jitter = self.random_execute(0.8)
        if apply_jitter:
            # The four perturbations are chained in this fixed order.
            x = tf.image.random_brightness(x, 0.8)
            x = tf.image.random_contrast(x, 0.2, 0.8)
            x = tf.image.random_saturation(x, 0.4, 1.6)
            x = tf.image.random_hue(x, 0.2)
        return x

class RandomFlip(Augmentation):
    """With probability 0.8, hands the image to `tf.image.random_flip_left_right`."""

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        maybe_flip = self.random_execute(0.8)
        if maybe_flip:
            # random_flip_left_right makes its own random decision on top of the gate above.
            x = tf.image.random_flip_left_right(x)
        return x

class RandomResizedCrop(Augmentation):
    """Takes a random square crop and resizes it back to `image_size` x `image_size`.

    Args:
        image_size: side length (in pixels) of the square output image.
    """

    def __init__(self, image_size):
        # Initialise through the direct parent. The previous
        # `super(Augmentation, self).__init__()` started the lookup *after*
        # Augmentation and therefore skipped Augmentation.__init__ entirely.
        super().__init__()
        self.image_size = image_size

    def call(self, x: tf.Tensor) -> tf.Tensor:
        # Crop side is drawn from [int(0.75 * image_size), image_size); maxval is exclusive.
        rand_size = tf.random.uniform(
            shape=[],
            minval=int(0.75 * self.image_size),
            maxval=1 * self.image_size,
            dtype=tf.int32,
        )
        crop = tf.image.random_crop(x, (rand_size, rand_size, 3))
        crop_resize = tf.image.resize(crop, (self.image_size, self.image_size))
        return crop_resize

class RandomSolarize(Augmentation):
    """With probability 0.2, inverts (255 - v) every value that is not below 10."""

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        should_solarize = self.random_execute(0.2)
        if should_solarize:
            # Values below 10 are kept as they are; every other value v becomes 255 - v.
            x = tf.where(x < 10, x, 255 - x)
        return x

class RandomBlur(Augmentation):
    """With probability 0.2, applies `tfa.image.gaussian_filter2d` to the image."""
    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        if self.random_execute(0.2):
            # NOTE(review): np.random.random() is plain NumPy, so inside a
            # tf.function it presumably runs when the function is traced rather
            # than on every call; sigma would then be one fixed value per trace
            # instead of a fresh draw per image -- confirm whether that is intended.
            s = np.random.random()
            return tfa.image.gaussian_filter2d(image=x, sigma=s)
        return x


import tensorflow.keras.backend as K
import math
IMAGE_SIZE = [256,256]

def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Build the 3x3 matrix used to transform pixel indices.

    `rotation` and `shear` are given in degrees. Every argument is concatenated
    with shape-[1] float32 constants below, so each is expected to be a
    shape-[1] float32 tensor as well.

    Returns:
        (rotation @ shear) @ (zoom @ shift) as a [3, 3] tensor.
    """
    def as_3x3(*entries):
        # Nine shape-[1] tensors, listed row by row, become one [3, 3] matrix.
        return tf.reshape(tf.concat(list(entries), axis=0), [3, 3])

    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    # Degrees -> radians.
    rotation_rad = math.pi * rotation / 180.
    shear_rad = math.pi * shear / 180.

    cos_rot = tf.math.cos(rotation_rad)
    sin_rot = tf.math.sin(rotation_rad)
    rotation_matrix = as_3x3(
        cos_rot, sin_rot, zero,
        -sin_rot, cos_rot, zero,
        zero, zero, one,
    )

    cos_shear = tf.math.cos(shear_rad)
    sin_shear = tf.math.sin(shear_rad)
    shear_matrix = as_3x3(
        one, sin_shear, zero,
        zero, cos_shear, zero,
        zero, zero, one,
    )

    zoom_matrix = as_3x3(
        one / height_zoom, zero, zero,
        zero, one / width_zoom, zero,
        zero, zero, one,
    )

    shift_matrix = as_3x3(
        one, zero, height_shift,
        zero, one, width_shift,
        zero, zero, one,
    )

    rotate_then_shear = K.dot(rotation_matrix, shear_matrix)
    zoom_then_shift = K.dot(zoom_matrix, shift_matrix)
    return K.dot(rotate_then_shear, zoom_then_shift)

def transform(image):
    """Apply one random rotation/shear/zoom/shift to a single image.

    The side length is taken from the module-level IMAGE_SIZE[0]; the result is
    reshaped to [DIM, DIM, 3]. Pixels are looked up with `tf.gather_nd`, i.e.
    each destination pixel copies one source pixel (no interpolation).
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated, sheared, zoomed, and shifted
    DIM = IMAGE_SIZE[0]
    XDIM = DIM%2 #fix for size 331
    
    # Random parameters: normal draws scaled to degrees (rot, shr), a factor
    # around 1.0 (zooms) and pixels (shifts).
    rot = 15. * tf.random.normal([1],dtype='float32')
    shr = 5. * tf.random.normal([1],dtype='float32') 
    h_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    w_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    h_shift = 16. * tf.random.normal([1],dtype='float32') 
    w_shift = 16. * tf.random.normal([1],dtype='float32') 

    # GET TRANSFORMATION MATRIX
    m = get_mat(rot,shr,h_zoom,w_zoom,h_shift,w_shift) 

    # LIST DESTINATION PIXEL INDICES
    # Homogeneous coordinates (x, y, 1), centred on the middle of the image.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(m,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so that the indices computed below stay inside the image.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES           
    # Shift the centred coordinates back to array indices before gathering.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image,tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

class RandomRotage(Augmentation):
    """With probability 0.7, applies the random affine warp implemented by `transform`."""

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        apply_warp = self.random_execute(0.7)
        if apply_warp:
            # Random rotation, shear, zoom and shift in one step.
            x = transform(x)
        return x

class RandomAugmentor(keras.Model):
    """Chains the random augmentation layers used to create one augmented view.

    All seven augmentation layers are instantiated, but `call` currently runs
    only rotate/warp -> flip -> colour jitter -> blur -> grayscale; the resized
    crop and solarize layers are created and left unused.

    NOTE(review): a mixed-precision global policy is set earlier in this file, so
    these layers presumably autocast their inputs to half precision; confirm that
    is acceptable for raw 0-255 pixel values.
    """

    def __init__(self, image_size: int):
        super().__init__()
        self.image_size = image_size
        self.random_resized_crop = RandomResizedCrop(image_size)
        self.random_flip = RandomFlip()
        self.random_color_jitter = RandomColorJitter()
        self.random_blur = RandomBlur()
        self.random_to_grayscale = RandomToGrayscale()
        self.random_solarize = RandomSolarize()
        self.random_rotage = RandomRotage()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        # Disabled stages: self.random_resized_crop (would run first),
        # self.random_solarize and a final tf.clip_by_value(x, 0, 1).
        active_stages = (
            self.random_rotage,
            self.random_flip,
            self.random_color_jitter,
            self.random_blur,
            self.random_to_grayscale,
        )
        for stage in active_stages:
            x = stage(x)
        return x

bt_augmentor = RandomAugmentor(image_shape[0])

In [None]:

def deserialization_fn(serialized_example):
    """Decode one serialized example into two independently augmented views.

    The example must carry a JPEG-encoded 'image' feature. The image is decoded
    to 3 channels, resized to `image_shape`, passed twice through `bt_augmentor`
    and each view is then run through MobileNetV2's `preprocess_input`.

    Returns:
        (image_1, image_2): the two augmented, preprocessed views.
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(serialized_example, features=feature_spec)

    decoded = tf.image.decode_jpeg(parsed['image'], channels=3)
    resized = tf.image.resize(decoded, image_shape)

    # Two separate passes through the augmentor give two different random views.
    view_a = bt_augmentor(resized)
    view_b = bt_augmentor(resized)

    preprocess = tf.keras.applications.mobilenet_v2.preprocess_input
    return preprocess(view_a), preprocess(view_b)

In [None]:
def tfrecords_loader(files_path, shuffle=False):
    """Build the two-view dataset from a list of TFRecord shard paths.

    Args:
        files_path: list of TFRecord file paths.
        shuffle: when True, shuffle the order of the *files* (not of the records).

    Returns:
        A dataset of (view_1, view_2) pairs produced by `deserialization_fn`.
    """
    shard_paths = tf.data.Dataset.from_tensor_slices(files_path)
    if shuffle:
        shard_paths = shard_paths.shuffle(len(files_path))
    records = shard_paths.flat_map(tf.data.TFRecordDataset)
    return records.map(deserialization_fn, num_parallel_calls=AUTO)

train_datasets = tfrecords_loader(train_set_path)
# Display the dataset itself; a trailing comma here would wrap it in a 1-tuple.
train_datasets

(<ParallelMapDataset element_spec=(TensorSpec(shape=(256, 256, 3), dtype=tf.bfloat16, name=None), TensorSpec(shape=(256, 256, 3), dtype=tf.bfloat16, name=None))>,)

In [None]:
def tfrecords_loader_fer(files_path, shuffle=False):
    """Build the labelled (image, label) dataset from a list of TFRecord shard paths.

    Args:
        files_path: list of TFRecord file paths.
        shuffle: when True, shuffle the order of the *files* (not of the records).

    Returns:
        A dataset of (image, label) pairs; images are resized to `image_shape`
        and run through MobileNetV2's `preprocess_input`, labels are int64.
    """
    def deserialization_fn_fer(serialized_example):
        feature_spec = {
            'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64),
        }
        parsed = tf.io.parse_single_example(serialized_example, features=feature_spec)

        pixels = tf.image.decode_jpeg(parsed['image'], channels=3)
        pixels = tf.image.resize(pixels, image_shape)
        pixels = tf.keras.applications.mobilenet_v2.preprocess_input(pixels)

        # NOTE(review): the label range is not visible here; the linear probe in
        # this file has 7 outputs, so labels are presumably in [0, 6] -- confirm.
        target = tf.cast(parsed['label'], tf.int64)
        return pixels, target

    shard_paths = tf.data.Dataset.from_tensor_slices(files_path)
    if shuffle:
        shard_paths = shard_paths.shuffle(len(files_path))
    records = shard_paths.flat_map(tf.data.TFRecordDataset)
    return records.map(deserialization_fn_fer, num_parallel_calls=AUTO)

In [None]:
# Labelled (image, label) dataset; shard order is not shuffled (shuffle defaults to False).
train_datasets_fer = tfrecords_loader_fer(train_set_path_fer)
train_datasets_fer

In [None]:
# train_dataset_encode = train_datasets.repeat().batch(batch_size).prefetch(AUTO)
# train_dataset_fer_encode = train_datasets_fer.repeat().batch(batch_size).prefetch(AUTO)
def get_dataset(dataset, batch_size):
    """Turn `dataset` into an endless stream of fixed-size, prefetched batches.

    Args:
        dataset: the unbatched source dataset.
        batch_size: number of elements per batch.

    Returns:
        `dataset` repeated indefinitely, batched and prefetched.
    """
    # The stream is repeated without end, so a short final batch can never occur;
    # drop_remainder=True therefore discards nothing and only makes the batch
    # dimension statically known.
    return dataset.repeat().batch(batch_size, drop_remainder=True).prefetch(AUTO)

In [None]:
# Keras checkpoint callback: weights only, keyed on the training loss, written
# under /tmp/checkpoints with epoch and loss in the file name.
# NOTE(review): `period` is passed together with `save_freq`; it is presumably a
# deprecated argument -- confirm against the installed Keras version.
cp_callback = ModelCheckpoint(
    filepath='/tmp/checkpoints/model.{epoch:02d}-{loss:.2f}.hdf5',
    monitor='loss',
    save_best_only=True,
    save_weights_only=True,
    save_freq='epoch',
    period=20,
    verbose=1,
)


In [None]:
# Learning-rate schedule parameters.
start_lr = 0.001      # value at epoch 0
min_lr = 0.0005       # floor the decay approaches
max_lr = 0.005        # peak reached at the end of the ramp-up
rampup_epochs = 10    # length of the linear ramp from start_lr to max_lr
sustain_epochs = 0    # epochs held at max_lr after the ramp-up
exp_decay = .8        # per-epoch multiplier applied to (lr - min_lr) afterwards
EPOCHS = 200


def lrfn(epoch):
    """Return the learning rate for `epoch`: linear ramp-up, optional hold, exponential decay."""
    if epoch < rampup_epochs:
        return (max_lr - start_lr)/rampup_epochs * epoch + start_lr
    if epoch < rampup_epochs + sustain_epochs:
        return max_lr
    return (max_lr - min_lr) * exp_decay**(epoch-rampup_epochs-sustain_epochs) + min_lr

# Keras callback that sets the learning rate from lrfn at the start of each epoch.
lr_callback = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: lrfn(epoch),
    verbose=True,
)

# Plot the schedule over the configured number of epochs.
rang = np.arange(EPOCHS)
y = list(map(lrfn, rang))
plt.plot(rang, y)
print('Learning rate per epoch:')

In [None]:
# Start from empty log/checkpoint directories: remove leftovers, then recreate.
!rm -rf /tmp/tensorboard
!rm -rf /tmp/checkpoints

!mkdir /tmp/tensorboard
!mkdir /tmp/checkpoints

In [None]:
from natsort import natsorted

class WandbCustom:
    """Callable that periodically logs a checkpoint directory as a W&B artifact.

    Each call advances an internal counter. With ``fred > 1`` the upload happens
    on the 2nd call and then every ``fred`` calls (counter values 1, 1 + fred,
    ...). With ``fred == 1`` the directory is uploaded on every call.

    Args:
        path_ckpt: directory whose contents are uploaded.
        fred: upload period, in calls; must be a positive integer.

    Raises:
        ValueError: if ``fred`` is smaller than 1.
    """

    def __init__(self, path_ckpt, fred=1):
        if fred < 1:
            raise ValueError(f"fred must be a positive integer, got {fred!r}")
        self.path_ckpt = path_ckpt
        self.counter = 0
        self.fred = fred

    def upload_file_to_gcs(self, src_path: str):
        """Log the directory `src_path` as a W&B artifact named 'model'.

        Despite the method name, the upload goes through `wandb.log_artifact`.
        """
        artifact = wandb.Artifact('model', type='model')
        artifact.add_dir(src_path)
        wandb.log_artifact(artifact)

    def __call__(self):
        # `1 % self.fred` keeps the original phase (upload when counter % fred == 1)
        # for fred > 1 while also matching when fred == 1, where the remainder is
        # always 0 and a comparison against the literal 1 could never be true.
        if self.counter % self.fred == 1 % self.fred:
            self.upload_file_to_gcs(self.path_ckpt)
            print(f"Uploaded {self.path_ckpt}\n")
        self.counter += 1
        
        

In [None]:
import wandb
from wandb.keras import WandbMetricsLogger
from wandb.keras import WandbCallback
run = wandb.init(project="...", name="...")

In [None]:
def off_diagonal(x):
    """Return the off-diagonal entries of the square matrix `x` as a flat tensor."""
    side = tf.shape(x)[0]
    # Drop the last element and view the rest as (side - 1) rows of (side + 1)
    # values: every diagonal entry then sits in column 0 and is sliced away.
    without_last = tf.reshape(x, [-1])[:-1]
    shifted_rows = tf.reshape(without_last, (side - 1, side + 1))
    return tf.reshape(shifted_rows[:, 1:], [-1])

def normalize_repr(z, eps=1e-5):
    """Standardise each column of `z` over axis 0 (zero mean, unit variance).

    Args:
        z: tensor whose axis 0 is reduced (the batch axis in `compute_loss`).
        eps: small constant added to the variance before the square root so a
            column with zero variance yields zeros instead of a division by
            zero (NaN/inf in the loss and its gradients).

    Returns:
        Tensor with the same shape as `z`.
    """
    centered = z - tf.reduce_mean(z, axis=0)
    variance = tf.math.reduce_variance(z, axis=0)
    z_norm = centered * tf.math.rsqrt(variance + eps)
    return z_norm

def compute_loss(z_a, z_b, lambd=5e-3):
    """Loss on the cross-correlation matrix of two batches of representations.

    The diagonal of the matrix is pulled towards 1 and the off-diagonal entries
    towards 0, the latter weighted by `lambd`.

    Args:
        z_a, z_b: rank-2 tensors of matching shape; axis 0 is the batch axis.
        lambd: weight of the off-diagonal term.

    Returns:
        Scalar float32 loss.
    """
    # A mixed-precision policy hands over half-precision activations; do the
    # statistics and the large reductions in float32 for numerical stability.
    z_a = tf.cast(z_a, tf.float32)
    z_b = tf.cast(z_b, tf.float32)
    batch_size = tf.cast(tf.shape(z_a)[0], tf.float32)

    # Normalize the representations along the batch dimension.
    z_a_norm = normalize_repr(z_a)
    z_b_norm = normalize_repr(z_b)

    # Cross-correlation matrix.
    c = tf.matmul(z_a_norm, z_b_norm, transpose_a=True) / batch_size

    # Loss.
    on_diag = tf.reduce_sum(tf.pow(tf.linalg.diag_part(c) - 1, 2))
    off_diag = tf.reduce_sum(tf.pow(off_diagonal(c), 2))
    loss = on_diag + (lambd * off_diag)
    return loss

In [None]:
patch_size = 4  # 2x2, for the Transformer blocks.
image_size = 256
expansion_factor = 4  # expansion factor for the MobileNetV2 blocks.




def conv_block(x, filters=16, kernel_size=3, strides=2):
    """Conv2D ('same' padding) -> BatchNormalization -> swish."""
    convolution = layers.Conv2D(filters, kernel_size, strides=strides, padding="same")
    batch_norm = layers.BatchNormalization(momentum=0.1)
    return tf.nn.swish(batch_norm(convolution(x)))

def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    """Inverted residual block: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Args:
        x: input feature map.
        expanded_channels: channels produced by the expanding 1x1 convolution.
        output_channels: channels produced by the projecting 1x1 convolution.
        strides: stride of the depthwise convolution; for 2 the input is
            explicitly zero-padded and the convolution uses 'valid' padding.

    Returns:
        The projected feature map, added to `x` when the stride is 1 and the
        channel counts match.
    """
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization(momentum=0.1)(m)
    m = tf.nn.swish(m)

    if strides == 2:
        m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)
    m = layers.DepthwiseConv2D(
        3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
    )(m)
    m = layers.BatchNormalization(momentum=0.1)(m)
    m = tf.nn.swish(m)

    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization(momentum=0.1)(m)

    # Both sides are static Python values, so compare them directly instead of
    # building a TensorFlow op and coercing its result to bool.
    if strides == 1 and x.shape[-1] == output_channels:
        return layers.Add()([m, x])
    return m

def mlp(x, hidden_units, dropout_rate):
    """Stack one swish-activated Dense layer plus Dropout per entry of `hidden_units`."""
    out = x
    for width in hidden_units:
        out = layers.Dense(width, activation=tf.nn.swish)(out)
        out = layers.Dropout(dropout_rate)(out)
    return out


def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    """Apply `transformer_layers` pre-norm attention + MLP layers with skip connections.

    Args:
        x: input tensor; its last dimension sets the MLP widths.
        transformer_layers: number of layers to stack.
        projection_dim: `key_dim` of the multi-head attention.
        num_heads: number of attention heads.
    """
    for _ in range(transformer_layers):
        # Self-attention on the normalised input, then first skip connection.
        normed = layers.LayerNormalization(epsilon=1e-6)(x)
        attended = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(normed, normed)
        after_attention = layers.Add()([attended, x])

        # MLP (2x width, then back) on the normalised result, then second skip.
        width = x.shape[-1]
        hidden = layers.LayerNormalization(epsilon=1e-6)(after_attention)
        hidden = mlp(hidden, hidden_units=[width * 2, width], dropout_rate=0.1)
        x = layers.Add()([hidden, after_attention])

    return x


def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    """MobileViT block with an additional squeeze-and-excitation branch.

    Local convolutional features are reshaped into `patch_size` groups, passed
    through `num_blocks` transformer layers, folded back to a feature map and
    concatenated with the input `x` and the squeeze-excitation features before
    a final fusing convolution.

    NOTE(review): `strides` is forwarded to every conv_block; the concatenation
    with `x` presumably only lines up spatially for strides=1 -- confirm before
    using another value.
    """
    # Local projection with convolutions.
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    # Channel re-weighting of the first local projection; joined again at the concat below.
    se_features = SqueezeExcitation(projection_dim, projection_dim, 0.25)(local_features)
    local_features = conv_block(
        local_features, filters=projection_dim, kernel_size=1, strides=strides
    )

    # Unfold into patches and then pass through Transformers.
    # patch_size is a module-level constant; H * W positions are split into
    # patch_size groups of num_patches positions each.
    num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
    non_overlapping_patches = layers.Reshape((patch_size, num_patches, projection_dim))(
        local_features
    )
    global_features = transformer_block(
        non_overlapping_patches, num_blocks, projection_dim
    )

    # Fold into conv-like feature-maps.
    folded_feature_map = layers.Reshape((*local_features.shape[1:-1], projection_dim))(
        global_features
    )

    # Apply point-wise conv -> concatenate with the input features.
    folded_feature_map = conv_block(
        folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
    )
    local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map, se_features])

    # Fuse the local and global features using a convoluion layer.
    local_global_features = conv_block(
        local_global_features, filters=projection_dim, strides=strides
    )

    return local_global_features

class SqueezeExcitation(tf.keras.layers.Layer):
  """Squeeze-and-excitation layer: rescales the input by a learned per-channel gate.

  Args:
    in_filters: channel count used to size the reduced (bottleneck) projection.
    out_filters: channel count of the gate; it multiplies the input in `call`.
    se_ratio: fraction of `in_filters` kept in the bottleneck (at least 1 filter).
  """

  def __init__(self, in_filters, out_filters, se_ratio, **kwargs):
    super().__init__(**kwargs)
    self._in_filters = in_filters
    self._out_filters = out_filters
    self._se_ratio = se_ratio

  def build(self, input_shape):
    # Both projections are 1x1 convolutions with bias; only the width differs.
    pointwise = dict(kernel_size=1, strides=1, padding='same', use_bias=True)
    reduced_width = max(1, int(self._in_filters * self._se_ratio))

    self._se_reduce = tf.keras.layers.Conv2D(filters=reduced_width, **pointwise)
    self._se_expand = tf.keras.layers.Conv2D(filters=self._out_filters, **pointwise)

    super().build(input_shape)

  def get_config(self):
    # Constructor arguments are layered on top of the base Layer config.
    return {
        **super().get_config(),
        'in_filters': self._in_filters,
        'out_filters': self._out_filters,
        'se_ratio': self._se_ratio,
    }

  def call(self, inputs):
    # Squeeze: mean over axes 1 and 2; excite: bottleneck -> sigmoid gate.
    squeezed = tf.reduce_mean(inputs, [1, 2], keepdims=True)
    bottleneck = tf.nn.swish(self._se_reduce(squeezed))
    gate = tf.nn.sigmoid(self._se_expand(bottleneck))
    return gate * inputs

def correct_pad(inputs, kernel_size):
    """Compute the zero-padding tuple for a strided 2D convolution.

    Args:
      inputs: Input tensor.
      kernel_size: An integer or tuple/list of 2 integers.
    Returns:
      A tuple ((top, bottom), (left, right)) suitable for ZeroPadding2D.
    """
    channels_first = tf.keras.backend.image_data_format() == "channels_first"
    first_spatial_axis = 2 if channels_first else 1
    spatial_dims = tf.keras.backend.int_shape(inputs)[
        first_spatial_axis : (first_spatial_axis + 2)
    ]

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    # Unknown size is treated like an even one; even sizes need one pixel less
    # of padding on the leading side.
    if spatial_dims[0] is None:
        adjust = (1, 1)
    else:
        adjust = (1 - spatial_dims[0] % 2, 1 - spatial_dims[1] % 2)

    half_kernel = (kernel_size[0] // 2, kernel_size[1] // 2)
    rows = (half_kernel[0] - adjust[0], half_kernel[0])
    cols = (half_kernel[1] - adjust[1], half_kernel[1])
    return (rows, cols)
    

def create_mobilevit(num_classes=1000):
    """Build the MobileViT classifier (with squeeze-excitation inside each MobileViT block).

    Args:
        num_classes: width of the final softmax layer.

    Returns:
        A Keras model mapping (image_size, image_size, 3) images to class
        probabilities; the layer before the classifier is a global average pool
        over 384 channels.
    """
    inputs = keras.Input((image_size, image_size, 3))

    # Initial conv-stem -> MV2 block.
    x = conv_block(inputs, filters=16)
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=32
    )

    # Downsampling with MV2 block, followed by two stride-1 MV2 blocks.
    x = inverted_residual_block(
        x, expanded_channels=32 * expansion_factor, output_channels=48, strides=2
    )
    for _ in range(2):
        x = inverted_residual_block(
            x, expanded_channels=48 * expansion_factor, output_channels=48
        )

    # Three (downsampling MV2 -> MobileViT) stages:
    # (expansion base, MV2 output channels, transformer layers, projection dim)
    stage_specs = (
        (48, 64, 2, 96),
        (80, 80, 4, 120),
        (96, 96, 3, 144),
    )
    for expansion_base, mv2_channels, num_blocks, projection_dim in stage_specs:
        x = inverted_residual_block(
            x,
            expanded_channels=expansion_base * expansion_factor,
            output_channels=mv2_channels,
            strides=2,
        )
        x = mobilevit_block(x, num_blocks=num_blocks, projection_dim=projection_dim)

    x = conv_block(x, filters=384, kernel_size=1, strides=1)

    # Classification head (kept in float32 regardless of the global dtype policy).
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(x)

    return keras.Model(inputs, outputs)


In [None]:
import tensorflow_hub as hub
def get_encoder():
    """Return the MobileViT backbone without its classification layer.

    The output is taken from the second-to-last layer of `create_mobilevit()`,
    i.e. the layer feeding the final Dense classifier.
    """
    backbone = create_mobilevit()

    # Optional warm start from a stored checkpoint (currently disabled):
    # dir_work = "gs://cuong_tpu/logs/mobilevit_SE/"
    # ckpt = tf.train.Checkpoint(model=mobile)
    # ckpt.restore(tf.train.latest_checkpoint(dir_work+'ckpts'))

    pooled_features = backbone.layers[-2].output
    return keras.Model(backbone.input, pooled_features)


In [None]:
def get_projection_head(dims=[1280, 1024*2, 1024*5, 1024*5]):
    """Build the projection MLP applied on top of the encoder output.

    Args:
        dims: four sizes -- input width, two hidden widths, output width. Each
            hidden Dense layer is followed by BatchNormalization and ReLU; the
            output Dense layer is linear.
    """
    stack = [keras.Input(shape=(dims[0],))]
    for hidden_width in (dims[1], dims[2]):
        stack.append(keras.layers.Dense(hidden_width))
        stack.append(keras.layers.BatchNormalization())
        stack.append(keras.layers.ReLU())
    stack.append(keras.layers.Dense(dims[3]))
    return keras.Sequential(stack, name="projection_head")

In [None]:
def get_linear_probe(z_dim, num_classes=7):
    """Build the linear classifier trained on top of the encoder features.

    Args:
        z_dim: width of the feature vector fed to the probe.
        num_classes: number of output logits (defaults to 7, the value that was
            previously hard-coded).

    Returns:
        A Sequential model with one Dense layer producing raw logits.
    """
    return keras.Sequential(
        [
            keras.layers.Input(shape=(z_dim,)),
            keras.layers.Dense(num_classes),
        ],
        name="linear_probe",
    )

In [None]:
# Width of the encoder output (the 384-channel global average pool).
z_dim = 384
with strategy.scope():
    # Everything that owns variables is created inside the strategy scope.
    encoder = get_encoder()
    projection_head = get_projection_head([z_dim, 1024*2, 1024*2, 1024*5])
    linear_probe = get_linear_probe(z_dim)

    main_optimizer=tfa.optimizers.LAMB()
    probe_optimizer=keras.optimizers.Adam()

    main_loss_tracker = keras.metrics.Mean(name="c_loss")
    probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
    probe_accuracy = keras.metrics.SparseCategoricalAccuracy(name="p_acc")

per_replica_batch_size = batch_size // strategy.num_replicas_in_sync
per_replica_batch_size_fer = batch_size_fer // strategy.num_replicas_in_sync

# `distribute_datasets_from_function` is the current name of the deprecated
# `experimental_distribute_datasets_from_function`; the dataset function must
# batch with the per-replica batch size.
train_dataset = strategy.distribute_datasets_from_function(
    lambda _: get_dataset(train_datasets, per_replica_batch_size))

# Despite its name, this stream is the labelled set the linear probe is trained on.
valid_dataset = strategy.distribute_datasets_from_function(
    lambda _: get_dataset(train_datasets_fer, per_replica_batch_size_fer))


@tf.function
def train_step(iterator):
    """Run one self-supervised step on the next batch of `iterator` on every replica.

    Each replica encodes both augmented views, projects them, computes the
    cross-correlation loss and updates the encoder and projection head.
    """
    def step_fn(batch):
        y_a, y_b = batch
        trainable = encoder.trainable_weights + projection_head.trainable_weights
        with tf.GradientTape() as tape:
            z_a_1, z_b_1 = encoder(y_a, training=True), encoder(y_b, training=True)
            z_a_2, z_b_2 = projection_head(z_a_1, training=True), projection_head(z_b_1, training=True)
            main_loss = compute_loss(z_a_2, z_b_2)
            # Gradients are summed over replicas when applied, so divide the
            # per-replica loss by the replica count to optimise the average.
            scaled_loss = main_loss / strategy.num_replicas_in_sync

        gradients = tape.gradient(scaled_loss, trainable)
        main_optimizer.apply_gradients(zip(gradients, trainable))
        # Track the unscaled per-replica loss so logged values keep their meaning.
        main_loss_tracker.update_state(main_loss)

    strategy.run(step_fn, args=(next(iterator),))

@tf.function
def valid_step(iterator):
    """Run one linear-probe step on the next labelled batch of `iterator`.

    The encoder is run with training=False and only the linear probe's weights
    receive gradients; the probe accuracy metric is updated on the same batch.
    """
    def step_fn(batch):
        imgs, labels = batch
        with tf.GradientTape() as tape:
            features = encoder(imgs, training=False)
            class_logits = linear_probe(features, training=True)
            loss = probe_loss(labels, class_logits)
            # probe_loss sums over the per-replica batch and gradients are then
            # summed over replicas: divide by the global batch size so the
            # update corresponds to the mean per-example loss.
            global_batch = tf.shape(labels)[0] * strategy.num_replicas_in_sync
            loss = loss / tf.cast(global_batch, loss.dtype)
        gradients = tape.gradient(loss, linear_probe.trainable_weights)
        probe_optimizer.apply_gradients(
            zip(gradients, linear_probe.trainable_weights)
        )
        probe_accuracy.update_state(labels, class_logits)

    strategy.run(step_fn, args=(next(iterator),))

# Iterators over the two distributed (endlessly repeating) datasets.
train_iterator = iter(train_dataset)
valid_iterator = iter(valid_dataset)
import time 
# Start of the wall-clock timer reported per epoch by the training loop.
time1 = time.time()
# Checkpoint I/O is routed through the local host.
local_device_option = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
path_checkpoint = "/content/log/mobilevit_SE_no"
# The checkpoint tracks only a step variable and the encoder.
# NOTE(review): the projection head, the linear probe and both optimizers are not
# part of it, so restoring presumably resumes with those freshly initialised.
ckpt = tf.train.Checkpoint(step=tf.Variable(1), encoder=encoder)
manager = tf.train.CheckpointManager(ckpt, path_checkpoint, max_to_keep=1)
# Uploads the checkpoint directory to W&B with a period of 5 calls.
wandb_custom = WandbCustom(path_checkpoint, 5)

In [None]:
EPOCH = 500

# Restore once, before training starts. Restoring inside the epoch loop only
# re-read the checkpoint written at the end of the previous epoch.
ckpt.restore(manager.latest_checkpoint, options=local_device_option)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

for epoch in range(EPOCH):
    print('Epoch: {}/{}'.format(epoch, EPOCH))

    # One pass of self-supervised steps, then the linear-probe steps.
    for _ in range(train_step_epoch):
        train_step(train_iterator)
    for _ in range(val_step_epoch):
        valid_step(valid_iterator)

    step_running = main_optimizer.iterations.numpy()
    loss_running = round(float(main_loss_tracker.result()), 4)
    acc_running = round(float(probe_accuracy.result()), 2)
    time_running = time.time()-time1

    wandb.log(
        {
            "epoch": epoch,
            "step_running": step_running,
            "loss_running": loss_running,
            "acc_running": acc_running,
            "time_running": time_running,
        }
    )
    # Metrics are per-epoch: reset them and restart the timer.
    main_loss_tracker.reset_states()
    probe_accuracy.reset_states()
    time1 = time.time()
    manager.save(options=local_device_option)
    wandb_custom()

    # Report the optimizer step under the "Current step" label (the epoch index
    # is already printed at the top of the loop).
    print('Current step: {}, training loss: {}, accuracy: {}, time: {}'.format(
            step_running,
            loss_running,
            acc_running,
            time_running))
    