In [None]:
import pandas as pd
import numpy as np
import os, sys
import random
import pydicom
import sklearn
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.preprocessing import StandardScaler
import re
import matplotlib.pyplot as plt
import matplotlib.cm as cm
seed = 42

import warnings
warnings.filterwarnings("ignore")

import tensorflow as tf
import keras# ; keras.config.set_dtype_policy("mixed_float16")
from keras import layers, Model
from keras import ops, layers, models, losses, optimizers, metrics
import keras_hub
import keras_cv
import keras_nlp

import cv2
from skimage.io import imread
keras.utils.set_random_seed(seed)
import tensorflow_io as tfio
from kaggle_datasets import KaggleDatasets
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

print(f"Tensorflow version : {tf.__version__}")
# The Keras version lookup is best-effort: it is skipped when the installed
# keras module has no `__version__` attribute. Only AttributeError is caught so
# unrelated errors (and KeyboardInterrupt) are not silenced.
try:
    print(f"Keras version : {keras.__version__}")
except AttributeError:
    pass

from keras import Input, Model, ops
from keras.models import load_model

from keras.layers import Conv2D, DepthwiseConv2D, Dense, Activation, BatchNormalization, LayerNormalization, MultiHeadAttention, Embedding, Subtract, Add, Multiply, GlobalAveragePooling2D, GlobalAveragePooling1D, LayerNormalization
from keras.utils import load_img, img_to_array
from keras.applications import *
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.model_selection import train_test_split
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
from tqdm.notebook import tqdm
import wandb
def wandb_config():
    """Log in to Weights & Biases with the ``wandb_key`` Kaggle secret.

    The login happens in-process through ``wandb.login`` rather than a
    ``!wandb login <key>`` shell command, so the API key never appears on a
    command line and the function also works outside IPython.

    Raises:
        Whatever ``UserSecretsClient.get_secret`` raises when the
        ``wandb_key`` secret is not attached to the notebook.
    """
    from kaggle_secrets import UserSecretsClient

    wandb_api_key = UserSecretsClient().get_secret("wandb_key")
    wandb.login(key=wandb_api_key)
    
def auto_select_accelerator():
    """Pick a ``tf.distribute`` strategy for the available hardware.

    First tries to resolve, connect to and initialise a TPU. If that raises
    ``ValueError`` (e.g. no TPU is attached), it falls back to
    ``MirroredStrategy``, which uses all visible GPUs (or the CPU).

    Returns:
        tuple: ``(tpu, strategy)``. ``tpu`` is the ``TPUClusterResolver`` when
        a TPU was found, otherwise ``False``.
    """
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        # Stable API; `tf.distribute.experimental.TPUStrategy` is a deprecated alias.
        strategy = tf.distribute.TPUStrategy(tpu)
        print("Running on TPU:", tpu.master())
    except ValueError:
        tpu = False
        strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines
    print(f"Running on {strategy.num_replicas_in_sync} replicas")

    return tpu, strategy

# Create the distribution strategy once; model-building cells later in the
# notebook enter `strategy.scope()`.
tpu, strategy = auto_select_accelerator()

import PIL
from PIL import Image as PILImage
import matplotlib as mpl
import matplotlib.pyplot as plt

import pprint
from pprint import pprint as pp

# Side length in pixels that images are resized/padded to; also the ViT input size.
res = 384
# Batch size used by the tf.data pipelines (load_tfrecord_dataset default).
batch_size = 16

In [None]:
def _parse_tfrecord(res = res):
    """Return a ``tf.data`` map function that decodes one serialized example.

    Each record holds a JPEG-encoded ``image`` (bytes) and an int64 ``label``.
    The image is decoded as a single channel and passed through
    ``_transform_images(res)``; the label is cast to ``int32``.

    Args:
        res: target side length forwarded to ``_transform_images``.

    Returns:
        Callable mapping a serialized ``tf.train.Example`` to ``(image, label)``.
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }

    def parse_tfrecord(tfrecord):
        example = tf.io.parse_single_example(tfrecord, feature_spec)
        decoded = tf.image.decode_jpeg(example['image'], channels=1)
        resized = _transform_images(res = res)(decoded)
        return (resized, tf.cast(example["label"], tf.int32))

    return parse_tfrecord


def _transform_images(res = res):
    """Return a function that letterboxes an image to ``res`` x ``res`` uint8.

    ``tf.image.resize_with_pad`` keeps the aspect ratio, pads with zeros and
    returns float32. The result is rounded and saturated back to ``uint8``.
    A plain ``tf.cast`` would truncate toward zero, a systematic -0.5 bias on
    every interpolated pixel, and is not guaranteed to clamp out-of-range values.

    Args:
        res: output height and width in pixels.

    Returns:
        Callable mapping an image tensor (presumably ``(H, W, C)``) to a
        ``(res, res, C)`` uint8 tensor.
    """
    def transform_images(x_train):
        resized = tf.image.resize_with_pad(x_train, res, res, antialias = True)
        return tf.saturate_cast(tf.round(resized), tf.uint8)
    return transform_images

def load_tfrecord_dataset(tfrecord_name, res = res, batch_size = batch_size, shuffle=True, buffer_size=10240):
    """Build an infinite, batched ``(image, label)`` pipeline from GZIP TFRecords.

    Args:
        tfrecord_name: path (or list of paths) to GZIP-compressed TFRecord file(s).
        res: side length of the decoded images; forwarded to ``_parse_tfrecord``.
        batch_size: examples per batch; the last partial batch is dropped.
        shuffle: whether to shuffle records with a buffer of ``buffer_size``.
        buffer_size: shuffle buffer size in records.

    Returns:
        tf.data.Dataset: repeated forever, so consumers must bound iteration
        (e.g. ``steps_per_epoch`` in ``fit``).
    """
    records = tf.data.TFRecordDataset(tfrecord_name, compression_type = "GZIP")
    records = records.repeat()
    if shuffle:
        records = records.shuffle(buffer_size=buffer_size)
    # `res` must be passed explicitly; otherwise the parser uses the
    # module-level default and this argument has no effect.
    dataset = records.map(
        _parse_tfrecord(res = res),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.batch(batch_size, drop_remainder = True)
    return dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

# Infinite, shuffled (image, label) pipelines for the RadImageNet train / test splits.
train_radimagenet_ds, val_ds = (
    load_tfrecord_dataset(tfrecord_path)
    for tfrecord_path in (
        "/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RagImageNet_Train_GZIP.tfrecord",
        "/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RagImageNet_Test_GZIP.tfrecord",
    )
)

In [None]:
# Label distribution of the training split, plus the label -> name table.
df_train = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RadImgNet_train.csv")
# Labels are assumed to be the integers 0..164 (the classifier heads have 165
# units). Bin edges at k - 0.5 give every label its own bar; integer edges
# would put the last two labels into a single bin.
plt.hist(df_train["label"], bins = np.arange(-0.5, 165))
plt.title("Training dataset label-wise distribution")
plt.xlabel("label")
plt.ylabel("count")

# print() rather than pprint: pprint renders strings as quoted reprs.
print("+="*50)
print(f"Total Training case : {len(df_train)}")
print("                                     LABELS")
df_label = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RadImgNet_label_encoding.csv")
print(df_label)
print("+="*50)
plt.show()

In [None]:
# Visual sanity check: one training batch with label ids and names.
imgs, labs = next(iter(train_radimagenet_ds))
fig, axes = plt.subplots(4,4, figsize = (15,15))
axes = axes.flatten()
for ax in axes:
    ax.axis("off")
# zip stops at the shorter sequence, so a batch smaller than 16 leaves the
# remaining panels empty instead of raising IndexError.
for ax, img, lab in zip(axes, imgs, labs):
    lab_ = int(lab)
    # Images are decoded with channels=1; drop that axis for a 2-D imshow.
    ax.imshow(np.squeeze(np.asarray(img), axis=-1), cmap = "bone")
    name = df_label.loc[df_label.index == lab_, 'name'].values[0]
    ax.set_title(f"{lab_} : {name}")
plt.show()

In [None]:
def get_vit(att_heads = 8, att_depth = 6, embed_dims = 512, input_res = None, patch_size = 16, input_scale = 1.0 / 255.0):
    """Build a small pre-LayerNorm Vision Transformer for single-channel images.

    Patches are embedded with a strided convolution and rotated with RoPE.
    A mean-pooled token is prepended as the CLS token. Each of the
    ``att_depth`` blocks computes ``x = x + MHA(LN(x))`` followed by
    ``x = x + Dense(LN(x))``.

    Args:
        att_heads: number of attention heads per block.
        att_depth: number of transformer blocks.
        embed_dims: token width; also the size of the returned feature vector.
        input_res: input height/width in pixels; ``None`` uses the module-level ``res``.
        patch_size: patch size in pixels (conv kernel size and stride).
        input_scale: multiplier applied to the raw input. ``1/255`` maps the
            uint8 [0, 255] pipeline output to [0, 1]; ``None`` disables it.

    Returns:
        keras.Model with outputs ``[feature_vector, att_weights]``.
        ``feature_vector`` is the CLS token ``(batch, embed_dims)``.
        ``att_weights`` maps block index -> that block's attention scores.
    """
    if input_res is None:
        input_res = res
    inputs = Input((input_res, input_res, 1), name = "RadImgNetInput")
    x = inputs
    if input_scale is not None:
        # The tf.data pipeline yields raw uint8 intensities; without this the
        # first conv sees values up to 255.
        x = layers.Rescaling(input_scale, name = "InputRescaling")(x)
    patches = Conv2D(filters = embed_dims, padding = 'SAME', kernel_size = patch_size, strides = patch_size, name = "PatchingConv", activation = "gelu")(x)
    _, grid_h, grid_w, _ = ops.shape(patches)
    seq_len = grid_h * grid_w
    patches = ops.reshape(patches, [-1, seq_len, embed_dims])
    # positional encoding with RoPE
    patches = keras_hub.layers.RotaryEmbedding(name = "RoPE")(patches)
    # keepdims=True yields (batch, 1, embed_dims) so the token can be joined
    # along the sequence axis. keras.ops is used because tf.concat cannot
    # consume symbolic Keras 3 tensors.
    cls_token = GlobalAveragePooling1D(keepdims = True)(patches)
    patches = ops.concatenate([cls_token, patches], axis = 1)
    att_weights = {}
    for idx in range(att_depth):
        x0 = LayerNormalization(name = f'preLN{idx}')(patches)
        x1, att_score = MultiHeadAttention(att_heads, embed_dims//att_heads, name = f"MHA{idx}")(query = x0, key = x0, value = x0,
                                                                                     return_attention_scores = True)
        att_weights[idx] = att_score
        x2 = x1 + patches
        x3 = LayerNormalization(name = f'postLN{idx}')(x2)
        x4 = Dense(units = embed_dims, name = f"MLP{idx}", activation = "gelu")(x3)
        patches = x2 + x4
        patches = keras.layers.Identity(name = f"EncodedPatches_{idx}")(patches)
    # The CLS token (position 0) is the image-level feature vector.
    feature_vector = patches[:, 0, :]
    model = Model(inputs, [feature_vector, att_weights],
                 name = f"SimpleViT_depth{att_depth}_heads{att_heads}_dims{embed_dims}")
    return model


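# 실험계획 (Experiment plan)
- 비교군 (baselines): original GitHub 구현의 결과 및 ViT 실험 — the results reported by the original GitHub implementation, plus a plain ViT experiment
- Supervised contrastive learning
- SSL을 N회 수행한 뒤 classifier까지 포함한 학습을 1회 수행하며, 이 한 주기를 1 "set"이라 부르고 반복한다 — one "set" = N rounds of self-supervised (SSL) training followed by one round that also trains the classifier; sets are repeated.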

In [None]:
# dino trainer
resize_fn = keras.layers.Resizing(res,res)
# Crop layers are created once. Building new layers inside the function would
# create fresh layer objects on every call or trace.
_global_crop_1 = keras.layers.RandomCrop(256,256)
_global_crop_2 = keras.layers.RandomCrop(180,180)

def get_two_views(images):
    """Return two random crops of ``images`` (256 px and 180 px), both resized to ``res``.

    Args:
        images: image batch, presumably ``(batch, H, W, C)`` with H, W >= 256
            (TODO confirm for new inputs).

    Returns:
        tuple: ``(global_view_1, global_view_2)``, each resized to ``res`` x ``res``.
    """
    # training=True keeps the crops random regardless of call context; in
    # inference mode RandomCrop returns a deterministic crop instead.
    global_view_1 = resize_fn(_global_crop_1(images, training=True))
    global_view_2 = resize_fn(_global_crop_2(images, training=True))
    return (global_view_1, global_view_2)



class MiniDINO(Model):
    """
    Miniature DINO model using a pre-existing Vision Transformer backbone.
    This implementation follows the principles from the DINO and DINOv2 papers.

    The student (backbone + projector) is trained by gradient descent so that
    its output on one view matches the centered, sharpened teacher output on
    the other view. The teacher receives no gradients. After every step its
    weights move toward the student's by an exponential moving average (EMA).
    """
    def __init__(
        self,
        student_backbone,
        teacher_backbone,
        projection_dim=1024,
        latent_dim=512,
        teacher_momentum=0.996,
        center_momentum=0.9,
        student_temp=0.1,
        teacher_temp=0.04,
        **kwargs
    ):
        """
        Args:
            student_backbone (Model): The student ViT model. It may return the
                feature tensor itself or a list/tuple whose first element is
                the feature tensor (as ``get_vit`` models do).
            teacher_backbone (Model): The teacher ViT model (should have the same architecture).
            projection_dim (int): The dimension of the projection head's hidden layer.
            latent_dim (int): The output dimension of the projection head (and input from backbone).
            teacher_momentum (float): The momentum for the teacher network update.
            center_momentum (float): The momentum for the center update.
            student_temp (float): The temperature for the student's softmax.
            teacher_temp (float): The temperature for the teacher's softmax (sharpening).
        """
        super().__init__(**kwargs)
        self.student_backbone = student_backbone
        self.teacher_backbone = teacher_backbone
        self.teacher_momentum = teacher_momentum
        self.center_momentum = center_momentum
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.latent_dim = latent_dim
        self.resize_fn = keras.layers.Resizing(res,res)
        self.rc1 = keras.layers.RandomCrop(256,256)
        self.rc2 = keras.layers.RandomCrop(180,180)
        # DINO Projection Head
        self.student_projector = self._build_projector(latent_dim, projection_dim)
        self.teacher_projector = self._build_projector(latent_dim, projection_dim)

        # Initialize teacher weights with student weights
        self.teacher_backbone.set_weights(self.student_backbone.get_weights())
        self.teacher_projector.set_weights(self.student_projector.get_weights())

        # Running mean of teacher projector outputs, subtracted before the
        # teacher softmax (DINO centering).
        self.center = self.add_weight(
            shape=(1, latent_dim), initializer="zeros", trainable=False, name="center"
        )

        # Trackers for metrics
        self.loss_tracker = keras.metrics.Mean(name="dino_loss")

    @property
    def metrics(self):
        # Listed so Keras resets the tracker at the start of every epoch.
        return [self.loss_tracker]

    @staticmethod
    def _features(backbone_output):
        """Return the feature tensor from a backbone output.

        ``get_vit`` backbones return ``[feature_vector, attention_dict]``;
        backbones that return the feature tensor directly are passed through.
        """
        if isinstance(backbone_output, (list, tuple)):
            return backbone_output[0]
        return backbone_output

    def _get_two_views(self, images):
        """Return two random crops (256 px and 180 px) of ``images``, each resized to ``res``."""
        # training=True keeps the crops random regardless of call context; in
        # inference mode RandomCrop returns a deterministic crop instead.
        global_view_1 = self.resize_fn(self.rc1(images, training=True))
        global_view_2 = self.resize_fn(self.rc2(images, training=True))
        return (global_view_1, global_view_2)

    def _build_projector(self, latent_dim, projection_dim):
        """Builds the MLP projection head as described in DINO."""
        # A simpler 2-layer MLP for this miniature version.
        # Original DINO uses a 3-layer MLP.
        return keras.Sequential(
            [
                layers.Input(shape=(latent_dim,)),
                layers.Dense(projection_dim, activation="gelu"),
                layers.Dense(latent_dim),
            ],
            name="projector",
        )

    def _update_teacher(self):
        """Update teacher network weights using EMA of student weights."""
        m = self.teacher_momentum
        pairs = (
            (self.student_backbone, self.teacher_backbone),
            (self.student_projector, self.teacher_projector),
        )
        for student, teacher in pairs:
            for student_w, teacher_w in zip(student.weights, teacher.weights):
                teacher_w.assign(m * teacher_w + (1 - m) * student_w)

    @tf.function
    def _update_center(self, teacher_output):
        """Update the center vector using EMA of teacher outputs."""
        batch_center = tf.reduce_mean(teacher_output, axis=0, keepdims=True)
        self.center.assign(self.center_momentum * self.center + (1 - self.center_momentum) * batch_center)

    def compile(self, optimizer, **kwargs):
        """Compile with ``optimizer`` registered through Keras (tracked and built normally)."""
        super().compile(optimizer=optimizer, **kwargs)

    def train_step(self, data):
        """One DINO step on a batch of ``(images, labels)``; labels are unused.

        Returns:
            dict: ``{"loss": running mean of the DINO loss}``.
        """
        # For simplicity, this example uses two global crops and no local crops.
        images, label = data
        view1, view2 = self._get_two_views(images)

        # === Teacher forward pass ===
        # Outside the tape: the teacher is only updated by EMA, so recording its
        # ops would cost memory without contributing any gradient.
        teacher_repr1 = self._features(self.teacher_backbone(view1, training=False))
        teacher_repr2 = self._features(self.teacher_backbone(view2, training=False))

        teacher_proj1 = self.teacher_projector(teacher_repr1, training=False)
        teacher_proj2 = self.teacher_projector(teacher_repr2, training=False)

        # Center and sharpen teacher outputs
        teacher_out1 = tf.nn.softmax((teacher_proj1 - self.center) / self.teacher_temp, axis=-1)
        teacher_out2 = tf.nn.softmax((teacher_proj2 - self.center) / self.teacher_temp, axis=-1)

        with tf.GradientTape() as tape:
            # === Student Forward Pass ===
            student_repr1 = self._features(self.student_backbone(view1, training=True))
            student_repr2 = self._features(self.student_backbone(view2, training=True))

            student_proj1 = self.student_projector(student_repr1, training=True)
            student_proj2 = self.student_projector(student_repr2, training=True)

            student_out1 = tf.nn.log_softmax(student_proj1 / self.student_temp, axis=-1)
            student_out2 = tf.nn.log_softmax(student_proj2 / self.student_temp, axis=-1)

            # === Compute DINO Loss (Cross-Entropy) ===
            # The student predicts the teacher's output for a different view.
            loss1 = -tf.reduce_mean(tf.reduce_sum(teacher_out2 * student_out1, axis=-1))
            loss2 = -tf.reduce_mean(tf.reduce_sum(teacher_out1 * student_out2, axis=-1))
            total_loss = (loss1 + loss2) / 2

        # === Gradient Descent ===
        trainable_vars = self.student_backbone.trainable_variables + self.student_projector.trainable_variables
        grads = tape.gradient(total_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # === EMA Updates ===
        self._update_teacher()
        self._update_center(tf.concat([teacher_proj1, teacher_proj2], axis=0))

        self.loss_tracker.update_state(total_loss)
        return {"loss": self.loss_tracker.result()}

    def call(self, inputs):
        # For inference, only the student backbone is used.
        return self.student_backbone(inputs)

# Plain (unsupervised) DINO on the default ViT. Models are built inside
# `strategy.scope()` so their variables are created by the active distribution
# strategy (required when running under TPUStrategy).
with strategy.scope():
    student_vit = get_vit()
    teacher_vit = get_vit()

    dino_model = MiniDINO(
        student_backbone=student_vit,
        teacher_backbone=teacher_vit,
        # latent_dim defaults to 512, matching get_vit's default embed_dims.
    )

    dino_model.compile(optimizer=keras.optimizers.SGD(learning_rate=1e-4))
# Training is disabled; uncomment for a short plain-DINO run:
#dino_model.fit(train_radimagenet_ds, epochs = 1, steps_per_epoch = 500)

# DINO with supervision

In [None]:
class MiniSupDINO(Model):
    """DINO self-distillation with an optional supervised classification head.

    In ``mode="sup"`` / ``"supervised"``, 165-way linear classifiers on the
    backbone feature vector are also trained with two extra losses:

    - sparse cross-entropy against the labels (student only);
    - a DINO-style cross-view distillation loss between the centered,
      sharpened teacher class logits and the student class logits.

    Any other ``mode`` (e.g. ``"dino"``) trains with the projector DINO loss only.

    The teacher backbone, projector and (in supervised mode) classifier receive
    no gradients; they track the student by EMA after every step. Backbones
    must return ``(feature_vector, extra)`` like ``get_vit`` models, with
    ``feature_vector`` of size ``latent_dim``.
    """
    _SUPERVISED_MODES = ("sup", "supervised")

    def __init__(
        self,
        student_backbone,
        teacher_backbone,
        mode = 'supervised', #supervised or dino
        projection_dim=1024,
        latent_dim=512,
        teacher_momentum=0.996,
        center_momentum=0.9,
        student_temp=0.1,
        teacher_temp=0.04,
        **kwargs
    ):
        """
        Args:
            student_backbone (Model): student network returning ``(features, extra)``.
            teacher_backbone (Model): teacher network with the same architecture.
            mode (str): ``"sup"``/``"supervised"`` enables the classifier losses;
                any other value trains DINO only.
            projection_dim (int): hidden width of the projector MLP.
            latent_dim (int): backbone feature size and projector output size.
            teacher_momentum (float): EMA momentum for the teacher weights.
            center_momentum (float): EMA momentum for both center vectors.
            student_temp (float): softmax temperature for student outputs.
            teacher_temp (float): softmax temperature for teacher outputs (sharpening).
        """
        super().__init__(**kwargs)
        self.mode = mode
        self.student_backbone = student_backbone
        self.teacher_backbone = teacher_backbone
        self.teacher_momentum = teacher_momentum
        self.center_momentum = center_momentum
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.latent_dim = latent_dim
        self.rc = keras.layers.RandomCrop(int(0.75*res),int(0.75*res))
        # Created once here. _get_two_views runs while train_step is traced,
        # and instantiating a layer there would build a new one on each trace.
        self.resize_fn = keras.layers.Resizing(res,res)
        # DINO Projection Head
        self.student_projector = self._build_projector(latent_dim, projection_dim)
        self.teacher_projector = self._build_projector(latent_dim, projection_dim)
        # Supervision (Classifier) Head
        self.student_classifier = keras.layers.Dense(units = 165, name = "Student_RadImgNetClassifier")
        self.teacher_classifier = keras.layers.Dense(units = 165, name = "Teacher_RadImgNetClassifier")
        # An unbuilt Dense layer has no weights, so copying would be a no-op and
        # the heads would start from different random kernels. Build both on the
        # backbone feature size (== latent_dim, as the projector input assumes).
        self.student_classifier.build((None, latent_dim))
        self.teacher_classifier.build((None, latent_dim))
        # Initialize teacher weights with student weights
        self.teacher_backbone.set_weights(self.student_backbone.get_weights())
        self.teacher_projector.set_weights(self.student_projector.get_weights())
        self.teacher_classifier.set_weights(self.student_classifier.get_weights())

        # Running means of teacher projector outputs and teacher class logits,
        # subtracted before the teacher softmax (DINO centering).
        self.center = self.add_weight(
            shape=(1, latent_dim), initializer="zeros", trainable=False, name="center"
        )
        self.center_cls = self.add_weight(
            shape=(1, 165), initializer="zeros", trainable=False, name="center_class_proba"
        )

        # Trackers for metrics
        self.loss_tracker = keras.metrics.Mean(name="Total_loss")
        self.dino_loss_tracker = keras.metrics.Mean(name="DINO_loss")
        self.student_cls_loss_tracker = keras.metrics.Mean(name="Student_class_loss")
        self.cls_distil_loss_tracker = keras.metrics.Mean(name="Class_distil_loss")

        self.student_acc_tracker = keras.metrics.Mean(name = "Student_class_Accuracy")
        self.teacher_acc_tracker = keras.metrics.Mean(name = "Teacher_class_Accuracy")

    @property
    def metrics(self):
        # Listed so Keras resets these trackers at the start of every epoch.
        return [
            self.loss_tracker,
            self.dino_loss_tracker,
            self.student_cls_loss_tracker,
            self.cls_distil_loss_tracker,
            self.student_acc_tracker,
            self.teacher_acc_tracker,
        ]

    def _is_supervised(self):
        """True when the classifier heads and their losses are active."""
        return self.mode in self._SUPERVISED_MODES

    @staticmethod
    def _batch_accuracy(labels, logits):
        """Fraction of rows whose argmax(logits) equals the integer label.

        Stateless on purpose: a shared stateful Keras accuracy metric would
        accumulate over every call and mix student and teacher predictions.
        """
        predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
        hits = tf.equal(predictions, tf.cast(tf.reshape(labels, [-1]), tf.int32))
        return tf.reduce_mean(tf.cast(hits, tf.float32))

    def _get_two_views(self, images):
        """Return ``(images, a random 75% crop of images resized back to res)``."""
        global_view_1 = images
        # training=True keeps the crop random regardless of call context; in
        # inference mode RandomCrop returns a deterministic crop instead.
        global_view_2 = self.resize_fn(self.rc(images, training=True))
        return (global_view_1, global_view_2)

    def _build_projector(self, latent_dim, projection_dim):
        """Builds the MLP projection head as described in DINO."""
        # A simpler 2-layer MLP for this miniature version.
        # Original DINO uses a 3-layer MLP.
        return keras.Sequential(
            [
                layers.Input(shape=(latent_dim,)),
                layers.Dense(projection_dim, activation="gelu"),
                layers.Dense(latent_dim),
            ],
            name="projector",
        )

    def _update_teacher(self):
        """Update teacher network weights using EMA of student weights."""
        m = self.teacher_momentum
        pairs = [
            (self.student_backbone, self.teacher_backbone),
            (self.student_projector, self.teacher_projector),
        ]
        if self._is_supervised():
            pairs.append((self.student_classifier, self.teacher_classifier))
        for student, teacher in pairs:
            for student_w, teacher_w in zip(student.weights, teacher.weights):
                teacher_w.assign(m * teacher_w + (1 - m) * student_w)

    @tf.function
    def _update_center(self, teacher_output, teacher_proba_output = None):
        """Update the center vectors by EMA of teacher outputs.

        Args:
            teacher_output: teacher projector outputs (pre-softmax), stacked over views.
            teacher_proba_output: teacher class *logits* (pre-softmax), stacked over
                views; only used in supervised mode.
        """
        batch_center = tf.reduce_mean(teacher_output, axis=0, keepdims=True)
        self.center.assign(self.center_momentum * self.center + (1 - self.center_momentum) * batch_center)
        if self._is_supervised():
            batch_center_proba = tf.reduce_mean(teacher_proba_output, axis = 0, keepdims = True)
            self.center_cls.assign(self.center_momentum * self.center_cls + (1 - self.center_momentum)*batch_center_proba)

    def compile(self, optimizer, **kwargs):
        """Compile with ``optimizer`` registered through Keras (tracked and built normally)."""
        super().compile(optimizer=optimizer, **kwargs)

    def train_step(self, data):
        """One optimisation step on a batch of ``(images, labels)``.

        Returns:
            dict: running means of the total / DINO / classification /
            distillation losses and of the student / teacher batch accuracies.
        """
        images, label = data
        view1, view2 = self._get_two_views(images)
        supervised = self._is_supervised()

        # Values reported in DINO-only mode.
        student_cls_loss = 0.0
        cls_distil_loss = 0.0
        student_accuracy = 0.0
        teacher_accuracy = 0.0

        # === Teacher forward pass ===
        # Outside the tape: the teacher is only updated by EMA, so recording its
        # ops would cost memory without contributing any gradient.
        teacher_repr1, _ = self.teacher_backbone(view1, training=False)
        teacher_repr2, _ = self.teacher_backbone(view2, training=False)

        teacher_proj1 = self.teacher_projector(teacher_repr1, training=False)
        teacher_proj2 = self.teacher_projector(teacher_repr2, training=False)

        # Center and sharpen teacher outputs
        teacher_out1 = tf.nn.softmax((teacher_proj1 - self.center) / self.teacher_temp, axis=-1)
        teacher_out2 = tf.nn.softmax((teacher_proj2 - self.center) / self.teacher_temp, axis=-1)

        if supervised:
            # Raw teacher class logits. They feed the accuracy and the class-center
            # EMA, mirroring how raw projector outputs feed `self.center`.
            cls_t_logits_1 = self.teacher_classifier(teacher_repr1, training = False)
            cls_t_logits_2 = self.teacher_classifier(teacher_repr2, training = False)
            teacher_accuracy = 0.5 * (self._batch_accuracy(label, cls_t_logits_1)
                                      + self._batch_accuracy(label, cls_t_logits_2))
            # Centered + sharpened teacher class distributions (distillation targets).
            cls_t_1 = tf.nn.softmax((cls_t_logits_1 - self.center_cls) / self.teacher_temp, axis=-1)
            cls_t_2 = tf.nn.softmax((cls_t_logits_2 - self.center_cls) / self.teacher_temp, axis=-1)

        with tf.GradientTape() as tape:
            # === Student Forward Pass ===
            student_repr1, _ = self.student_backbone(view1, training=True)
            student_repr2, _ = self.student_backbone(view2, training=True)

            student_proj1 = self.student_projector(student_repr1, training=True)
            student_proj2 = self.student_projector(student_repr2, training=True)

            student_out1 = tf.nn.log_softmax(student_proj1 / self.student_temp, axis=-1)
            student_out2 = tf.nn.log_softmax(student_proj2 / self.student_temp, axis=-1)

            # === Compute DINO Loss (Cross-Entropy) ===
            # The student predicts the teacher's output for a different view.
            loss1 = -tf.reduce_mean(tf.reduce_sum(teacher_out2 * student_out1, axis=-1))
            loss2 = -tf.reduce_mean(tf.reduce_sum(teacher_out1 * student_out2, axis=-1))
            dino_loss = (loss1 + loss2) / 2
            if supervised:
                # student class logits
                cls_s_1 = self.student_classifier(student_repr1, training = True)
                cls_s_2 = self.student_classifier(student_repr2, training = True)
                # supervised loss, averaged over the batch and over both views
                student_cls_loss = 0.5 * (
                    keras.ops.mean(keras.losses.sparse_categorical_crossentropy(label, cls_s_1, from_logits = True))
                    + keras.ops.mean(keras.losses.sparse_categorical_crossentropy(label, cls_s_2, from_logits = True))
                )
                # classifier dino loss: each student view predicts the other teacher view
                cls_s_log_1 = tf.nn.log_softmax(cls_s_1 / self.student_temp, axis=-1)
                cls_s_log_2 = tf.nn.log_softmax(cls_s_2 / self.student_temp, axis=-1)
                cls_distil_loss_1 = -tf.reduce_mean(tf.reduce_sum(cls_t_2 * cls_s_log_1, axis=-1))
                cls_distil_loss_2 = -tf.reduce_mean(tf.reduce_sum(cls_t_1 * cls_s_log_2, axis=-1))
                cls_distil_loss = (cls_distil_loss_1 + cls_distil_loss_2) / 2

                total_loss = dino_loss + student_cls_loss + cls_distil_loss
            else:
                total_loss = dino_loss

        if supervised:
            student_accuracy = 0.5 * (self._batch_accuracy(label, cls_s_1)
                                      + self._batch_accuracy(label, cls_s_2))

        # === Gradient Descent ===
        trainable_vars = self.student_backbone.trainable_variables + self.student_projector.trainable_variables
        if supervised:
            trainable_vars = trainable_vars + self.student_classifier.trainable_variables
        grads = tape.gradient(total_loss, trainable_vars)
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # === EMA Updates ===
        self._update_teacher()
        teacher_proj_all = tf.concat([teacher_proj1, teacher_proj2], axis=0)
        if supervised:
            self._update_center(teacher_proj_all,
                                tf.concat([cls_t_logits_1, cls_t_logits_2], axis = 0))
        else:
            self._update_center(teacher_proj_all)

        self.loss_tracker.update_state(total_loss)
        self.dino_loss_tracker.update_state(dino_loss)
        self.student_cls_loss_tracker.update_state(student_cls_loss)
        self.cls_distil_loss_tracker.update_state(cls_distil_loss)

        self.student_acc_tracker.update_state(student_accuracy)
        self.teacher_acc_tracker.update_state(teacher_accuracy)

        return {"total_loss": self.loss_tracker.result(),
               "dino_loss" : self.dino_loss_tracker.result(),
               "student_classification_loss" : self.student_cls_loss_tracker.result(),
               "CLS_distil_loss" : self.cls_distil_loss_tracker.result(),

               "Student_Classification_Accuracy" : self.student_acc_tracker.result(),
               "Teacher_Classification_Accuracy" : self.teacher_acc_tracker.result()
               }

    def call(self, inputs):
        # For inference, only the student backbone is used.
        return self.student_backbone(inputs)

    def get_teacher_model(self):
        """Return the teacher components and an end-to-end teacher model.

        Returns:
            dict with ``feature_extractor`` (teacher backbone), ``classifier``,
            ``projector`` and ``whole_model``. ``whole_model`` maps an image to
            ``[class_logits, attention_weights]``; the classifier has no
            softmax, so the first output is logits.
        """
        teacher_input = self.teacher_backbone.input
        teacher_output, attention_weights = self.teacher_backbone.output
        teacher_proba = self.teacher_classifier(teacher_output)
        whole_teacher_model = Model(teacher_input, [teacher_proba, attention_weights],
                                   name = f'{self.teacher_backbone.name}_DINO_Teacher')
        result = {'feature_extractor' : self.teacher_backbone,
                 "classifier" : self.teacher_classifier,
                 "projector" : self.teacher_projector,
                 "whole_model" : whole_teacher_model}
        return result

In [None]:
# Supervised DINO run: build the student/teacher ViTs and the wrapper inside the
# distribution-strategy scope, then train a single 500-step epoch.
with strategy.scope():
    student_vit, teacher_vit = (get_vit(embed_dims = 512) for _ in range(2))

    dino_model = MiniSupDINO(
        student_backbone = student_vit,
        teacher_backbone = teacher_vit,
        mode = "sup",
        latent_dim = 512,
    )
    dino_model.compile(
        optimizer = keras.optimizers.AdamW(learning_rate = 1e-4)
    )
    dino_model.fit(
        train_radimagenet_ds,
        steps_per_epoch = 500,
        epochs = 1,
    )