In [None]:
# Environment setup (Colab/Kaggle shell magics).
# NOTE(review): versions are unpinned, and neither wandb nor natsort is imported
# anywhere in this notebook — consider dropping them and pinning tensorflow_addons.
!pip install -q -U wandb
!pip install -q natsort tensorflow_addons

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
import matplotlib.pyplot as plt
from mpl_toolkits import axes_grid1

import warnings
warnings.simplefilter('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
from tensorflow import keras
from google.cloud import storage
import tensorflow_datasets as tfds  

from albumentations import (
    Compose, RandomBrightness, JpegCompression, HueSaturationValue, RandomContrast, HorizontalFlip,
    Rotate, Normalize
)

import tensorflow_addons as tfa

In [None]:
# Runtime / accelerator configuration.
MIXED_PRECISION = False
XLA_ACCELERATE = False

try:  # detect TPUs
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)
    DEVICE = 'TPU'
except ValueError:  # no TPU reachable: fall back to the default (GPU/CPU) strategy
    strategy = tf.distribute.get_strategy()
    DEVICE = 'GPU'

if DEVICE == "GPU":
    physical_devices = tf.config.list_physical_devices('GPU')
    print("Num GPUs Available: ", len(physical_devices))
    # Apply memory growth to every visible GPU (TF expects the same setting on all
    # of them). On a CPU-only runtime the loop is simply empty instead of the
    # previous IndexError being silently swallowed.
    for gpu in physical_devices:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            # Raised when the GPUs were already initialised earlier in this kernel.
            print(f"Could not enable memory growth on {gpu.name}: {err}")

if MIXED_PRECISION:
    dtype = 'mixed_bfloat16' if DEVICE == "TPU" else 'mixed_float16'
    tf.keras.mixed_precision.set_global_policy(dtype)
    # Keep the compute dtype consistent with the policy chosen above
    # (bfloat16 on TPU, float16 on GPU).
    dtype_model = tf.bfloat16 if DEVICE == "TPU" else tf.float16
    print('Mixed precision enabled')
else:
    dtype_model = tf.float32


if XLA_ACCELERATE:
    tf.config.optimizer.set_jit(True)
    print('Accelerated Linear Algebra enabled')

AUTO = tf.data.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync

print('REPLICAS           : ', REPLICAS)
print('TensorFlow Version : ', tf.__version__)
print('Eager Mode Status  : ', tf.executing_eagerly())
# is_built_with_cuda is a function: call it so the boolean is printed,
# not the function object.
print('TF Cuda Built Test : ', tf.test.is_built_with_cuda())
print(
    'TF Device Detected : ',
    'Running on TPU' if DEVICE == "TPU" else tf.test.gpu_device_name()
)

try:
    build_info = tf.sysconfig.get_build_info()
    print('TF System Cuda V.  : ', build_info["cuda_version"])
    print('TF System CudNN V. : ', build_info["cudnn_version"])
except (AttributeError, KeyError):
    # Older TF has no get_build_info; non-CUDA builds may not report these keys.
    pass

In [None]:
# The global batch size scales with the number of replicas (TPU cores / GPUs).
batch_size = 32 * REPLICAS
image_shape = (256, 256)

# Split sizes for data part 0 (the part-1 training split has 655167 examples).
train_set_len = 28662
valid_set_len = 7171
# Training steps use ceil division so every example is covered once per epoch;
# validation uses only full batches (floor division).
train_step = (train_set_len + batch_size - 1) // batch_size
val_step = valid_set_len // batch_size

In [None]:
GCS_DS_PATH_fer = "gs://kds-39f11653dcf6e5ceb7705acf4ad46d05a1c259959db3a920906fe2f0"
train_shard_suffix = 'train_*-3.tfrec'
test_shard_suffix = 'test_*-1.tfrec'
train_set_path_fer = sorted(tf.io.gfile.glob(GCS_DS_PATH_fer + f'/train/{train_shard_suffix}'))
test_set_path_fer = sorted(tf.io.gfile.glob(GCS_DS_PATH_fer + f'/test/{test_shard_suffix}'))

# Fail fast when a glob matches nothing (wrong bucket, expired Kaggle GCS path, typo
# in the suffix); an empty file list otherwise only fails later inside tf.data
# with a much less obvious error.
for _split, _paths in (("train", train_set_path_fer), ("test", test_set_path_fer)):
    if not _paths:
        raise FileNotFoundError(
            f"No {_split} TFRecord shards matched under {GCS_DS_PATH_fer}/{_split}/"
        )

In [None]:
class Augmentation(keras.layers.Layer):
    """Base layer for the stochastic augmentations below; provides a Bernoulli gate."""

    def __init__(self):
        super().__init__()

    @tf.function
    def random_execute(self, prob: float) -> bool:
        """Return a scalar boolean tensor that is True with probability ``prob``."""
        sample = tf.random.uniform([], minval=0, maxval=1)
        return sample < prob


class RandomToGrayscale(Augmentation):
    """With probability 0.2, convert an RGB image to grayscale replicated over 3 channels.

    The tile multiples ``[1, 1, 3]`` imply a single rank-3 (H, W, C) image.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        if self.random_execute(0.2):
            gray = tf.image.rgb_to_grayscale(x)
            x = tf.tile(gray, [1, 1, 3])
        return x


class RandomColorJitter(Augmentation):
    """With probability 0.8, shift brightness by a random delta (``max_delta=0.8``).

    Only brightness is jittered; contrast, saturation and hue are not applied.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        apply_jitter = self.random_execute(0.8)
        if apply_jitter:
            x = tf.image.random_brightness(x, max_delta=0.8)
        return x

        


class RandomFlip(Augmentation):
    """Horizontally flip the input with probability 0.5."""

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        # The gate already decides *whether* to flip, so flip deterministically here.
        # tf.image.random_flip_left_right would flip only half of the gated inputs,
        # giving an effective probability of 0.25 instead of 0.5.
        if self.random_execute(0.5):
            x = tf.image.flip_left_right(x)
        return x



class RandomResizedCrop(Augmentation):
    """Randomly crop a square RGB patch and resize it back to ``image_size``.

    The crop side is drawn uniformly from ``[int(0.8 * image_size), image_size)``
    (``tf.random.uniform`` excludes ``maxval``), so the crop is always strictly
    smaller than the full image. The crop shape ``(size, size, 3)`` implies a single
    rank-3, 3-channel image.
    """

    def __init__(self, image_size):
        # Plain super() so Augmentation.__init__ runs too; the previous
        # super(Augmentation, self) skipped straight to keras.layers.Layer.
        super().__init__()
        self.image_size = image_size

    def call(self, x: tf.Tensor) -> tf.Tensor:
        rand_size = tf.random.uniform(
            shape=[],
            minval=int(0.8 * self.image_size),
            maxval=1 * self.image_size,
            dtype=tf.int32,
        )

        crop = tf.image.random_crop(x, (rand_size, rand_size, 3))
        crop_resize = tf.image.resize(crop, (self.image_size, self.image_size))
        return crop_resize


class RandomSolarize(Augmentation):
    """With probability 0.2, replace every value >= 10 by ``255 - value``.

    Values below 10 are kept unchanged, so almost the whole 0-255 range is inverted.
    NOTE(review): conventional solarize thresholds are around 128, and the 255
    constant assumes 0-255 pixel values — confirm the intent before enabling.
    """

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        if self.random_execute(0.2):
            keep_original = x < 10
            x = tf.where(keep_original, x, 255 - x)
        return x


class RandomBlur(Augmentation):
    """With probability 0.2, apply a 3x3 Gaussian blur with a freshly sampled sigma.

    The blur is built directly from a depthwise convolution (3x3 kernel, REFLECT
    padding) so that sigma can be a tensor sampled inside the graph. The previous
    ``np.random.random()`` inside ``tf.function`` was evaluated once at trace time,
    freezing sigma for every call, and could also be 0, which makes the kernel NaN.
    """

    @staticmethod
    def _gaussian_blur(x: tf.Tensor, sigma: tf.Tensor) -> tf.Tensor:
        """Blur an (H, W, C) or (N, H, W, C) image with a 3x3 Gaussian of std ``sigma``."""
        image = tf.convert_to_tensor(x)
        in_dtype = image.dtype
        image = tf.cast(image, tf.float32)
        single = image.shape.rank == 3
        if single:
            image = image[tf.newaxis]

        taps = tf.range(-1.0, 2.0)  # [-1, 0, 1]
        weights = tf.exp(-tf.square(taps) / (2.0 * tf.square(sigma)))
        weights = weights / tf.reduce_sum(weights)
        kernel = tf.tensordot(weights, weights, axes=0)  # separable 3x3 outer product

        channels = image.shape[-1]
        if channels is None:
            channels = tf.shape(image)[-1]
        kernel = tf.tile(kernel[:, :, tf.newaxis, tf.newaxis], [1, 1, channels, 1])

        padded = tf.pad(image, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="REFLECT")
        blurred = tf.nn.depthwise_conv2d(padded, kernel, strides=[1, 1, 1, 1], padding="VALID")
        if single:
            blurred = blurred[0]
        return tf.cast(blurred, in_dtype)

    @tf.function
    def call(self, x: tf.Tensor) -> tf.Tensor:
        if self.random_execute(0.2):
            # Sampled with TF ops so a new sigma is drawn on every call; the lower
            # bound keeps the Gaussian finite.
            sigma = tf.random.uniform([], minval=0.1, maxval=1.0)
            return self._gaussian_blur(x, sigma)
        return x


class RandomAugmentor(keras.Model):
    """Chain of random augmentations applied to an image.

    Applied in ``call``: random resized crop -> horizontal flip -> brightness jitter
    -> Gaussian blur. The grayscale and solarize layers are constructed but not
    applied, so they can be re-enabled without changing the tracked sub-layers.
    """

    def __init__(self, image_size: int):
        super(RandomAugmentor, self).__init__()
        self.image_size = image_size
        self.random_resized_crop = RandomResizedCrop(image_size)
        self.random_flip = RandomFlip()
        self.random_color_jitter = RandomColorJitter()
        self.random_blur = RandomBlur()
        self.random_to_grayscale = RandomToGrayscale()
        self.random_solarize = RandomSolarize()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        x = self.random_resized_crop(x)
        x = self.random_flip(x)
        x = self.random_color_jitter(x)
        x = self.random_blur(x)
        # No rotation step: the former self.random_rotage call referenced an attribute
        # that is never created (RandomRotage is not defined), raising AttributeError.
        return x


bt_augmentor = RandomAugmentor(image_size=image_shape[0])  # square side taken from image_shape

In [None]:
import tensorflow.keras.backend as K
import math
IMAGE_SIZE = tuple(image_shape)  # (height, width) used by transform()

def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Return the 3x3 matrix ``(R @ S) @ (Z @ T)`` used by ``transform`` to map indices.

    R rotates, S shears, Z scales (by the reciprocal of the zoom factors) and T
    translates. ``rotation`` and ``shear`` are in degrees. All arguments are expected
    as shape-[1] float32 tensors, which is how ``transform`` draws them.
    """
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    def as_3x3(*entries):
        # Concatenate nine shape-[1] tensors row-major into a 3x3 matrix.
        return tf.reshape(tf.concat(entries, axis=0), [3, 3])

    rotation_rad = math.pi * rotation / 180.
    shear_rad = math.pi * shear / 180.

    cos_r, sin_r = tf.math.cos(rotation_rad), tf.math.sin(rotation_rad)
    rotation_matrix = as_3x3(cos_r, sin_r, zero,
                             -sin_r, cos_r, zero,
                             zero, zero, one)

    cos_s, sin_s = tf.math.cos(shear_rad), tf.math.sin(shear_rad)
    shear_matrix = as_3x3(one, sin_s, zero,
                          zero, cos_s, zero,
                          zero, zero, one)

    zoom_matrix = as_3x3(one / height_zoom, zero, zero,
                         zero, one / width_zoom, zero,
                         zero, zero, one)

    shift_matrix = as_3x3(one, zero, height_shift,
                          zero, one, width_shift,
                          zero, zero, one)

    left = tf.linalg.matmul(rotation_matrix, shear_matrix)
    right = tf.linalg.matmul(zoom_matrix, shift_matrix)
    return tf.linalg.matmul(left, right)

def transform(image):
    """Randomly rotate, shear, zoom and shift one square image (nearest-pixel warp).

    Args:
        image: a single image tensor of shape [DIM, DIM, 3] (not a batch), where
            DIM = IMAGE_SIZE[0]. Source pixels are gathered with row/column indices
            clipped into [0, DIM - 1], so a larger input would only be sampled from
            its top-left DIM x DIM region.

    Returns:
        Tensor of shape [DIM, DIM, 3]. Each output pixel copies the source pixel that
        the inverse affine map points to. Coordinates are truncated to int (tf cast
        semantics) and clamped to the image border.

    The random parameters are drawn in this order, one standard-normal sample each:
    rotation 15*N degrees, shear 5*N degrees, height/width zoom 1 + N/10, and
    height/width shift 16*N (in index units).
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated, sheared, zoomed, and shifted
    DIM = IMAGE_SIZE[0]
    XDIM = DIM%2 #fix for size 331
    
    rot = 15. * tf.random.normal([1],dtype='float32')
    shr = 5. * tf.random.normal([1],dtype='float32') 
    h_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    w_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    h_shift = 16. * tf.random.normal([1],dtype='float32') 
    w_shift = 16. * tf.random.normal([1],dtype='float32') 

    # GET TRANSFORMATION MATRIX
    m = get_mat(rot,shr,h_zoom,w_zoom,h_shift,w_shift) 

    # LIST DESTINATION PIXEL INDICES
    # Centered coordinates. Note `-DIM//2` parses as (-DIM)//2 (floor division).
    # Rows run from DIM//2 down, columns from -DIM//2 up; z=1 is the homogeneous term.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(m,tf.cast(idx,dtype='float32'))
    # Cast truncates toward zero (no rounding); clip keeps indices inside the image.
    idx2 = K.cast(idx2,dtype='int32')
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES           
    # Convert centered coordinates back to 0-based (row, col) indices.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image,tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

In [None]:
# Per-class loss weights for the 7 classes, looked up by integer label to form the
# per-example sample weight in tfrecords_loader_fer.
_CLASS_WEIGHTS = [
    1.02660468, 9.40661861, 1.00104606, 0.56843877,
    0.84912748, 1.29337298, 0.82603942,
]
WEIGHT = tf.constant(_CLASS_WEIGHTS)

def tfrecords_loader_fer(files_path, shuffle=False, is_train=False, shuffle_buffer=2048):
    """Build a tf.data pipeline of ``(image, one_hot_label, sample_weight)`` from FER TFRecords.

    Each record holds a JPEG-encoded ``image`` (bytes) and an int64 ``label``.

    Args:
        files_path: list of TFRecord file paths.
        shuffle: if True, shuffle the file order and then the serialized records
            (buffer of ``shuffle_buffer``) before decoding.
        is_train: if True, apply the random affine ``transform`` augmentation.
        shuffle_buffer: record-level shuffle buffer size; only used when ``shuffle``.

    Returns:
        Dataset yielding three values per example:
        - a float32 image of shape ``image_shape + (3,)``, passed through
          ``mobilenet_v2.preprocess_input``;
        - a one-hot label over ``len(WEIGHT)`` classes;
        - that class's entry of ``WEIGHT``, used as the sample weight.
    """
    num_classes = int(WEIGHT.shape[0])

    datasets = tf.data.Dataset.from_tensor_slices(files_path)
    if shuffle:
        datasets = datasets.shuffle(len(files_path))
    datasets = datasets.flat_map(tf.data.TFRecordDataset)
    if shuffle:
        # File-order shuffling alone keeps each shard's internal order. Shuffle the
        # small serialized records here rather than decoded images, which keeps the
        # buffer cheap.
        datasets = datasets.shuffle(shuffle_buffer)

    def deserialization_fn_fer(serialized_example):
        features = tf.io.parse_single_example(
            serialized_example,
            features={
                'image': tf.io.FixedLenFeature([], tf.string),
                'label': tf.io.FixedLenFeature([], tf.int64),
            })

        image = tf.image.decode_jpeg(features['image'], channels=3)
        # Resize straight to the model resolution. `transform` gathers pixels
        # assuming a square image of side IMAGE_SIZE[0] (== image_shape[0]). The
        # former 265x265 intermediate made it sample only the top-left corner.
        image = tf.image.resize(image, image_shape)
        if is_train:
            image = transform(image)

        image = tf.cast(image, tf.float32)
        image = tf.keras.applications.mobilenet_v2.preprocess_input(image)

        label = features['label']  # integer class index in [0, num_classes)
        sample_weight = tf.gather(WEIGHT, indices=label)
        label = tf.one_hot(label, num_classes)

        return image, label, sample_weight

    return datasets.map(deserialization_fn_fer, num_parallel_calls=AUTO)

In [None]:
# Shuffle the training data (the evaluation set keeps its on-disk order).
train_datasets_fer = tfrecords_loader_fer(train_set_path_fer, shuffle=True, is_train=True)
test_datasets_fer = tfrecords_loader_fer(test_set_path_fer)

In [None]:
train_datasets_fer.take(count=10)

In [None]:
# repeat() makes both pipelines infinite, so drop_remainder=True never discards data;
# it only gives the batch dimension a static size, which TPU compilation relies on.
train_dataset_encode = (
    train_datasets_fer.repeat().batch(batch_size, drop_remainder=True).prefetch(AUTO)
)
test_dataset_encode = (
    test_datasets_fer.repeat().batch(batch_size, drop_remainder=True).prefetch(AUTO)
)

In [None]:
test_labels_peek = next(iter(test_dataset_encode))[1]
train_labels_peek = next(iter(train_dataset_encode))[1]
test_labels_peek, train_labels_peek

In [None]:
import matplotlib.pyplot as plt

iter_data = iter(train_datasets_fer)
x = next(iter_data)[0].numpy()
# Undo mobilenet_v2.preprocess_input (x / 127.5 - 1) to get displayable 0-255 pixels.
x = (x * 127.5 + 127.5).astype('uint8')
print(x.min(), x.max())
plt.imshow(x)
plt.show()

In [None]:
# Hyper-parameters shared by the MobileViT builders below.
image_size = 256        # square input resolution
patch_size = 4          # pixels per 2x2 patch when unfolding for the Transformer blocks
expansion_factor = 4    # channel expansion factor of the MobileNetV2 (MV2) blocks

from tensorflow.keras import layers 


def conv_block(x, filters=16, kernel_size=3, strides=2, bn_momentum=0.9):
    """Conv2D ('same' padding) -> BatchNormalization -> swish.

    Args:
        x: input feature map.
        filters, kernel_size, strides: forwarded to ``layers.Conv2D``.
        bn_momentum: Keras BatchNormalization momentum. Keras updates its moving
            statistics as ``moving = moving * momentum + batch * (1 - momentum)``,
            so 0.9 corresponds to PyTorch's ``momentum=0.1``. The previous Keras
            value of 0.1 made the moving statistics follow almost only the latest batch.

    Returns:
        The activated feature map.
    """
    x = layers.Conv2D(
        filters, kernel_size, strides=strides, padding="same"
    )(x)
    x = layers.BatchNormalization(momentum=bn_momentum)(x)
    x = tf.nn.swish(x)
    return x


# Reference: https://git.io/JKgtC


def inverted_residual_block(x, expanded_channels, output_channels, strides=1, bn_momentum=0.9):
    """MobileNetV2 inverted residual: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    Args:
        x: input feature map; the last axis is treated as channels.
        expanded_channels: width of the 1x1 expansion conv.
        output_channels: width of the 1x1 projection conv.
        strides: 1, or 2 to downsample (explicit zero padding from ``correct_pad``
            followed by a 'valid' depthwise conv).
        bn_momentum: Keras BatchNormalization momentum. 0.9 matches PyTorch's
            ``momentum=0.1``; Keras weights the *old* moving statistic by ``momentum``.

    Returns:
        The projected feature map, plus an identity skip connection when the channel
        count is unchanged and ``strides == 1``.
    """
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization(momentum=bn_momentum)(m)
    m = tf.nn.swish(m)

    if strides == 2:
        m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)
    m = layers.DepthwiseConv2D(
        3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
    )(m)
    m = layers.BatchNormalization(momentum=bn_momentum)(m)
    m = tf.nn.swish(m)

    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization(momentum=bn_momentum)(m)

    # Static channel counts are plain Python ints: compare them directly instead of
    # building a tf.math.equal tensor and relying on its truthiness.
    if x.shape[-1] == output_channels and strides == 1:
        return layers.Add()([m, x])
    return m

def mlp(x, hidden_units, dropout_rate):
    """Apply a Dense(swish) + Dropout pair for each width in ``hidden_units``."""
    out = x
    for width in hidden_units:
        out = layers.Dense(width, activation=tf.nn.swish)(out)
        out = layers.Dropout(dropout_rate)(out)
    return out


def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    """Stack ``transformer_layers`` pre-norm Transformer encoder layers on ``x``.

    Each layer is LayerNorm -> self-attention -> residual add, then LayerNorm ->
    2-layer MLP (width 2C then C, where C is the last axis of the input) -> residual add.
    """
    hidden = x
    for _ in range(transformer_layers):
        normed = layers.LayerNormalization(epsilon=1e-6)(hidden)
        attended = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(normed, normed)
        after_attention = layers.Add()([attended, hidden])

        feed_forward = layers.LayerNormalization(epsilon=1e-6)(after_attention)
        width = hidden.shape[-1]
        feed_forward = mlp(feed_forward, hidden_units=[width * 2, width], dropout_rate=0.1,)
        hidden = layers.Add()([feed_forward, after_attention])

    return hidden


def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    """MobileViT block with an additional squeeze-and-excitation branch.

    The steps are:
    1. Local 3x3 and 1x1 convs.
    2. A Transformer over the feature map reshaped into ``patch_size`` groups of tokens.
    3. A reshape back to the spatial layout and a 1x1 conv to ``x``'s channel count.
    4. Concatenation of [x, global features, SE-gated local features], fused by a
       3x3 conv to ``projection_dim`` channels.

    Layer creation order matches the original so checkpoint restoration is unaffected.
    """
    local = conv_block(x, filters=projection_dim, strides=strides)
    se_branch = SqueezeExcitation(projection_dim, projection_dim, 0.25)(local)
    local = conv_block(
        local, filters=projection_dim, kernel_size=1, strides=strides
    )

    # "Unfold": reshape the H x W map into (patch_size, H*W/patch_size) tokens.
    height, width = local.shape[1], local.shape[2]
    num_patches = int((height * width) / patch_size)
    tokens = layers.Reshape((patch_size, num_patches, projection_dim))(local)
    tokens = transformer_block(tokens, num_blocks, projection_dim)

    # "Fold" the tokens back to the spatial layout of the local features.
    spatial_dims = local.shape[1:-1]
    folded = layers.Reshape((*spatial_dims, projection_dim))(tokens)
    folded = conv_block(
        folded, filters=x.shape[-1], kernel_size=1, strides=strides
    )

    fused = layers.Concatenate(axis=-1)([x, folded, se_branch])
    return conv_block(fused, filters=projection_dim, strides=strides)

class SqueezeExcitation(tf.keras.layers.Layer):
    """Squeeze-and-excitation layer: gate each channel by a function of its spatial mean.

    Args:
        in_filters: sizes the bottleneck as ``max(1, int(in_filters * se_ratio))``.
        out_filters: channels of the gate; should equal the input's channel count
            for the gate to broadcast over ``inputs``.
        se_ratio: bottleneck ratio.
    """

    def __init__(self, in_filters, out_filters, se_ratio, **kwargs):
        super().__init__(**kwargs)
        self._in_filters = in_filters
        self._out_filters = out_filters
        self._se_ratio = se_ratio

    def build(self, input_shape):
        # The attribute names _se_reduce / _se_expand act as checkpoint keys; keep them.
        bottleneck = max(1, int(self._in_filters * self._se_ratio))
        conv_kwargs = dict(kernel_size=1, strides=1, padding='same', use_bias=True)
        self._se_reduce = tf.keras.layers.Conv2D(filters=bottleneck, **conv_kwargs)
        self._se_expand = tf.keras.layers.Conv2D(filters=self._out_filters, **conv_kwargs)
        super().build(input_shape)

    def get_config(self):
        base_config = super().get_config()
        return {
            **base_config,
            'in_filters': self._in_filters,
            'out_filters': self._out_filters,
            'se_ratio': self._se_ratio,
        }

    def call(self, inputs):
        squeezed = tf.reduce_mean(inputs, [1, 2], keepdims=True)
        reduced = tf.nn.swish(self._se_reduce(squeezed))
        gate = tf.nn.sigmoid(self._se_expand(reduced))
        return gate * inputs

def correct_pad(inputs, kernel_size):
    """Return ``((top, bottom), (left, right))`` zero padding for a stride-2 'valid' conv.

    The padding is asymmetric (one pixel less before) when a spatial size is even.
    If the height is unknown, both axes are treated as even.
    """
    channels_first = tf.keras.backend.image_data_format() == "channels_first"
    spatial_start = 2 if channels_first else 1
    input_size = tf.keras.backend.int_shape(inputs)[spatial_start : (spatial_start + 2)]

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    if input_size[0] is None:
        adjust = (1, 1)
    else:
        adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)

    half = (kernel_size[0] // 2, kernel_size[1] // 2)
    return tuple((half[axis] - adjust[axis], half[axis]) for axis in range(2))
    

def create_mobilevit(num_classes=1000):
    """Build the MobileViT-style classifier for ``image_size`` x ``image_size`` RGB input.

    The architecture is, in order:
    - a stride-2 stem conv;
    - four MV2 blocks;
    - three stages of (stride-2 MV2 block -> MobileViT block);
    - a 1x1 conv to 384 channels, global average pooling, and a float32 softmax head.

    Layers are created in a fixed order, because get_encoder restores these weights
    from a checkpoint.
    """
    inputs = keras.Input((image_size, image_size, 3))

    x = conv_block(inputs, filters=16)

    # (expansion base, output channels, stride) of the leading MV2 blocks.
    stem_mv2 = ((16, 32, 1), (32, 48, 2), (48, 48, 1), (48, 48, 1))
    for base, out_channels, stride in stem_mv2:
        x = inverted_residual_block(
            x, expanded_channels=base * expansion_factor,
            output_channels=out_channels, strides=stride,
        )

    # (expansion base, MV2 output channels, transformer depth, projection dim) per stage.
    vit_stages = ((48, 64, 2, 96), (80, 80, 4, 120), (96, 96, 3, 144))
    for base, out_channels, depth, projection in vit_stages:
        x = inverted_residual_block(
            x, expanded_channels=base * expansion_factor,
            output_channels=out_channels, strides=2,
        )
        x = mobilevit_block(x, num_blocks=depth, projection_dim=projection)

    x = conv_block(x, filters=384, kernel_size=1, strides=1)

    pooled = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax', dtype='float32')(pooled)
    return keras.Model(inputs, outputs)


In [None]:
def get_encoder(ckpt_dir="gs://cuong_tpu/logs/mobilevit/ckpts"):
    """Build a MobileViT, restore its latest checkpoint, and return the feature encoder.

    Args:
        ckpt_dir: directory holding ``tf.train.Checkpoint(model=...)`` checkpoints.
            Defaults to the path the notebook previously hard-coded.

    Returns:
        A keras.Model from the MobileViT input to the output of its second-to-last
        layer, i.e. the features just before the classification Dense layer.
    """
    mobile = create_mobilevit()

    latest = tf.train.latest_checkpoint(ckpt_dir)
    if latest is None:
        # Checkpoint.restore(None) is a silent no-op, which would quietly leave the
        # encoder randomly initialised; make that visible.
        print(f"WARNING: no checkpoint found in {ckpt_dir}; encoder is randomly initialised.")
    else:
        tf.train.Checkpoint(model=mobile).restore(latest)

    features = mobile.layers[-2].output
    return keras.Model(mobile.input, features)


In [None]:
def get_projection_head(dims=(1280, 1024 * 2, 1024 * 5, 1024 * 5)):
    """Return the projection MLP: Dense-BN-ReLU, Dense-BN-ReLU, Dense.

    Args:
        dims: ``(input_dim, hidden_1, hidden_2, output_dim)``. Any indexable sequence
            of four ints works. The default is a tuple rather than a list to avoid a
            shared mutable default argument.
    """
    return keras.Sequential(
        [
            keras.Input(shape=(dims[0],)),
            keras.layers.Dense(dims[1]),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            keras.layers.Dense(dims[2]),
            keras.layers.BatchNormalization(),
            keras.layers.ReLU(),
            keras.layers.Dense(dims[3]),
        ],
        name="projection_head",
    )

In [None]:
def get_linear_probe(z_dim, num_classes=7):
    """Return a single linear layer mapping ``z_dim`` features to ``num_classes`` logits.

    Args:
        z_dim: width of the input feature vector.
        num_classes: number of output logits. The default of 7 keeps the previous
            behaviour.
    """
    return keras.Sequential(
        [
            keras.layers.Input(shape=(z_dim,)),
            keras.layers.Dense(num_classes),
        ],
        name="linear_probe",
    )

In [None]:
class BarlowModel(keras.Model):
    """Barlow Twins pre-training wrapper around the MobileViT encoder.

    The model holds three parts:
    - the encoder;
    - a projection head, whose outputs feed ``compute_loss``;
    - a linear probe, updated in ``test_step``.

    NOTE(review): ``compute_loss`` is not defined in this notebook. It must be
    provided (presumably the Barlow Twins cross-correlation loss) before ``fit``.
    NOTE(review): ``test_step`` applies gradients to the linear probe, so the probe
    is fitted on whatever data is passed for validation/evaluation.
    """

    def __init__(self):
        super().__init__()

        self.encoder = get_encoder()
        # Size both heads from the encoder's actual feature width (the pooled
        # MobileViT features) instead of a hard-coded 1280. The old value did not
        # match, so the projection head rejected the encoder output.
        z_dim = self.encoder.output_shape[-1]
        self.projection_head = get_projection_head([z_dim, 1024 * 2, 1024 * 5, 1024 * 5])
        self.linear_probe = get_linear_probe(z_dim)

    def compile(self, main_optimizer, probe_optimizer, **kwargs):
        """Store the encoder/projection and probe optimizers and create the trackers."""
        super().compile(**kwargs)

        self.main_optimizer = main_optimizer
        self.probe_optimizer = probe_optimizer

        self.main_loss_tracker = keras.metrics.Mean(name="c_loss")
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.SUM
        )
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy(name="p_acc")

    @property
    def metrics(self):
        return [
            self.main_loss_tracker,
            self.probe_accuracy,
        ]

    def train_step(self, batch):
        """One Barlow Twins update on a pair of views ``(y_a, y_b)``."""
        y_a, y_b = batch
        variables = self.encoder.trainable_weights + self.projection_head.trainable_weights

        with tf.GradientTape() as tape:
            z_a_1, z_b_1 = self.encoder(y_a, training=True), self.encoder(y_b, training=True)
            z_a_2, z_b_2 = (
                self.projection_head(z_a_1, training=True),
                self.projection_head(z_b_1, training=True),
            )
            main_loss = compute_loss(z_a_2, z_b_2)

        gradients = tape.gradient(main_loss, variables)
        self.main_optimizer.apply_gradients(zip(gradients, variables))

        self.main_loss_tracker.update_state(main_loss)
        # The former "assa" entry read self.counter, which is never created
        # (AttributeError on the first step); it has been removed.
        return {"loss": self.main_loss_tracker.result()}

    def test_step(self, batch):
        """Fit the linear probe on frozen encoder features and report its accuracy."""
        imgs, labels = batch
        with tf.GradientTape() as tape:
            features = self.encoder(imgs, training=False)
            class_logits = self.linear_probe(features, training=True)
            probe_loss = self.probe_loss(labels, class_logits)
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        self.probe_accuracy.update_state(labels, class_logits)

        return {"acc": self.probe_accuracy.result()}

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Not used for inference; defined only so ``build`` can trace the model."""
        return None



In [None]:
class MyModel(keras.Model):
    """Functional classifier with a custom train/test loop and weighted per-example loss.

    Batches are expected as ``(x, y_one_hot, sample_weight)``. The model outputs
    logits, and the compiled loss is expected to return per-example values (as
    configured with ``reduction=NONE``). Tracks loss, accuracy, precision and recall.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.acc_metric = keras.metrics.CategoricalAccuracy(name="acc")
        self.pre_metric = tf.keras.metrics.Precision(name="pre")
        self.rec_metric = tf.keras.metrics.Recall(name="rec")

    @property
    def metrics(self):
        return [self.loss_tracker, self.acc_metric, self.pre_metric, self.rec_metric]

    @staticmethod
    def _reduce_loss(per_example_loss):
        """Average per-example losses over the *global* batch (distribution-safe)."""
        if per_example_loss.shape.rank == 0:
            # The compiled loss already reduced to a scalar; use it as-is.
            return per_example_loss
        return tf.nn.compute_average_loss(per_example_loss)

    def _update_metrics(self, y, logits, per_example_loss):
        """Update all trackers and return their current values keyed by name."""
        # Precision/Recall threshold their inputs at 0.5, which only makes sense for
        # probabilities, not raw logits. Accuracy (argmax) is unaffected either way.
        probs = tf.nn.softmax(logits)
        self.loss_tracker.update_state(per_example_loss)
        self.acc_metric.update_state(y, logits)
        self.pre_metric.update_state(y, probs)
        self.rec_metric.update_state(y, probs)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        x, y, sample_weight = data

        with tf.GradientTape() as tape:
            logits = self(x, training=True)
            per_example_loss = self.loss(y, logits, sample_weight=sample_weight)
            # Differentiate a properly averaged scalar. Taking the gradient of the
            # raw per-example vector implicitly summed it, so the step size
            # scaled with the batch size.
            loss = self._reduce_loss(per_example_loss)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return self._update_metrics(y, logits, per_example_loss)

    def test_step(self, data):
        x, y, sample_weight = data

        logits = self(x, training=False)
        per_example_loss = self.loss(y, logits, sample_weight=sample_weight)

        return self._update_metrics(y, logits, per_example_loss)


# Build the models inside the distribution-strategy scope so their variables are
# created under the strategy. The Barlow Twins wrapper is constructed here, but only
# the fine-tuning classifier below is compiled and trained in this notebook.
with strategy.scope():
    pretraining_model = BarlowModel()
    pretraining_model.build(input_shape=(256, 256, 3))

    encoder_model = get_encoder()

    # Classification head on the encoder features: ReLU -> Dropout(0.2) -> 7 logits.
    x = keras.layers.Dropout(0.2)(keras.layers.Activation('relu')(encoder_model.output))
    out = keras.layers.Dense(7)(x)

    model_finetuning = MyModel(encoder_model.input, out)
    model_finetuning.compile(
        loss=keras.losses.CategoricalCrossentropy(
            from_logits=True,
            reduction=tf.keras.losses.Reduction.NONE,
        ),
        optimizer=keras.optimizers.Adam(),
    )

In [None]:
from tensorflow.keras import backend as K

def _total_params(weights):
    """Total number of scalar parameters across a list of weight variables."""
    return np.sum([K.count_params(w) for w in weights])


trainable_count = _total_params(model_finetuning.trainable_weights)
non_trainable_count = _total_params(model_finetuning.non_trainable_weights)
total_count = trainable_count + non_trainable_count

print('Total params: {:,}'.format(total_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))

In [None]:
# Keep only the best epoch by validation accuracy, and reload it from the same path
# it was written to. Previously every epoch was saved to "model.h5" (relative to the
# working directory) but "/content/model.h5" was loaded, i.e. the last epoch, and
# only on Colab.
CHECKPOINT_PATH = "model.h5"
cb_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    CHECKPOINT_PATH,
    monitor="val_acc",
    mode="max",
    save_best_only=True,
    save_weights_only=True,
    save_freq="epoch",
    verbose=0,
)

history = model_finetuning.fit(
    train_dataset_encode, epochs=50,
    steps_per_epoch=train_step,
    validation_data=test_dataset_encode,
    validation_steps=val_step,
    callbacks=[cb_checkpoint],
)

model_finetuning.load_weights(CHECKPOINT_PATH)
# Keep the fit History intact and read evaluation metrics by name, not by position.
# (batch_size is not passed: it is ignored for tf.data inputs.)
eval_results = model_finetuning.evaluate(test_dataset_encode, steps=val_step, return_dict=True)
precision, recall = eval_results["pre"], eval_results["rec"]
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
print(f"Acc: {eval_results['acc']}, F1: {f1}")