In [None]:
import pandas as pd
import numpy as np
import os, sys
import random
import pydicom
import sklearn
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.preprocessing import StandardScaler
try:
    import umap
except:
    !pip install umap
    import umap
    
import re
import matplotlib.pyplot as plt
import matplotlib.cm as cm
seed = 42

import warnings
warnings.filterwarnings("ignore")

# ML tools 
sys.path.append("/kaggle/input/kimm-keras-image-model-repository"
               )

import tensorflow as tf
import keras# ; keras.config.set_dtype_policy("mixed_float16")
from keras import ops, layers, models, losses, optimizers, metrics
import keras_hub
import kimm
import keras_cv
import keras_nlp

import cv2
from skimage.io import imread
keras.utils.set_random_seed(seed)
import tensorflow_io as tfio
from kaggle_datasets import KaggleDatasets
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import tensorflow_decision_forests as tfdf

print(f"Tensorflow version : {tf.__version__}")
try:
    print(f"Keras version : {keras.__version__}")
except:
    pass

from keras import Input, Model, ops
from keras.models import load_model

from keras.layers import Conv2D, DepthwiseConv2D, Dense, Activation, BatchNormalization, LayerNormalization, MultiHeadAttention, Embedding, Subtract, Add, Multiply, GlobalAveragePooling2D, GlobalAveragePooling1D, LayerNormalization
from keras.utils import load_img, img_to_array
from keras.applications import *
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.model_selection import train_test_split
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
from tqdm.notebook import tqdm
import wandb
#from wandb.keras import WandbCallback, WandbModelCheckpoint, WandbMetricsLogger
def wandb_config():
    """Log in to Weights & Biases with the key stored in Kaggle user secrets.

    Reads the "wandb_key" secret through ``kaggle_secrets.UserSecretsClient``
    and hands it to ``wandb.login``. The key is passed in-process instead of
    being interpolated into a shell command, so it does not appear on a
    command line.
    """
    from kaggle_secrets import UserSecretsClient

    user_secrets = UserSecretsClient()
    # Only the W&B key is needed for the login; the other secrets the previous
    # version fetched ("__gcloud_sdk_auth__", "huggingface_key") were unused.
    wandb.login(key=user_secrets.get_secret("wandb_key"))
    
def auto_select_accelerator():
    """Pick a ``tf.distribute`` strategy for the available hardware.

    Tries to resolve and initialise a TPU cluster; if the resolver raises
    ``ValueError``, falls back to ``tf.distribute.MirroredStrategy``.

    Returns:
        (tpu, strategy): ``tpu`` is the ``TPUClusterResolver`` when a TPU was
        initialised and ``False`` otherwise; ``strategy`` is the selected
        distribution strategy.
    """
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        # tf.distribute.TPUStrategy is the stable replacement for the
        # deprecated tf.distribute.experimental.TPUStrategy.
        strategy = tf.distribute.TPUStrategy(tpu)
        print("Running on TPU:", tpu.master())
    except ValueError:
        tpu = False
        strategy = tf.distribute.MirroredStrategy()  # for GPU or multi-GPU machines
    print(f"Running on {strategy.num_replicas_in_sync} replicas")

    return tpu, strategy

tpu, strategy = auto_select_accelerator()
import ssl_module
from ssl_module import feature_visualize, get_masking_fn, get_map_fn, get_gcvit_configs, get_flops, att_visualize, get_full_model, BarlowModel, VICRegModel, Moco, SimSiam, CLIP, SigLIP
import nas_ftp_module
from nas_ftp_module import upload_file, download_file
import PIL
from PIL import Image as PILImage
import matplotlib as mpl
import matplotlib.pyplot as plt

Collecting umap
  Downloading umap-0.1.1.tar.gz (3.2 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: umap
  Building wheel for umap (setup.py) ... [?25l[?25hdone
  Created wheel for umap: filename=umap-0.1.1-py3-none-any.whl size=3542 sha256=815dc3c9dbfe62b224520e0c7a576c9650d3f4af55c0816f05ea6dbb3655d1b3
  Stored in directory: /root/.cache/pip/wheels/15/f1/28/53dcf7a309118ed35d810a5f9cb995217800f3f269ab5771cb
Successfully built umap
Installing collected packages: umap
Successfully installed umap-0.1.1


# 실험 계획
- token mixer : gMLP vs gaMLP vs Attention
- ConvNeXt vs pure-metaformer
     - if convnext, FE 후 token mixer의 갯수에 따른 변화
     - if convnext, ImageNet weight vs randomly initialized

# HybridViT

 - from [here](https://www.kaggle.com/code/khs224025/hybridvit)

In [None]:
class TransformerEncoderLayer(layers.Layer):
    """Post-norm Transformer encoder block (self-attention + two Dense layers).

    Args:
        embed_dims: Token embedding width; also the output width of the
            feed-forward stack.
        num_heads: Number of attention heads; each head uses
            ``embed_dims // num_heads`` key dimensions.
        ff_dim: Width of the first feed-forward Dense layer.
        return_attention_scores: When truthy, ``call`` returns
            ``(output, attention_scores)`` instead of just ``output``.
        orthogonal_factor: ``factor`` passed to the ``OrthogonalRegularizer``
            attached to the attention kernels and biases.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, embed_dims, num_heads, ff_dim, return_attention_scores,
                 orthogonal_factor = 0.01,
                 **kwargs):
        super().__init__(**kwargs)
        self.attn = keras.layers.MultiHeadAttention(num_heads, embed_dims//num_heads,
                                             kernel_regularizer=keras.regularizers.OrthogonalRegularizer(factor = orthogonal_factor),
                                             bias_regularizer=keras.regularizers.OrthogonalRegularizer(factor = orthogonal_factor)
                                               )
        # NOTE(review): the two Dense layers have no activation between them,
        # so the feed-forward stack is a composition of two linear maps -
        # confirm this is intended.
        self.ffn = models.Sequential([
            layers.Dense(ff_dim, use_bias = False),
            layers.Dense(embed_dims, use_bias = False)
        ])
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.return_attention_scores = return_attention_scores

    def call(self, inputs):
        """Apply self-attention and the feed-forward stack with residual + LayerNorm.

        Returns:
            ``output`` or, when ``return_attention_scores`` is set,
            ``(output, attention_scores)``.
        """
        if self.return_attention_scores:
            attn_out, attn_weights = self.attn(inputs, inputs, return_attention_scores=True)
        else:
            attn_out = self.attn(inputs, inputs, return_attention_scores=False)
        x = self.norm1(inputs + attn_out)
        # Evaluate the second residual + norm once and reuse the result.
        output = self.norm2(x + self.ffn(x))
        if self.return_attention_scores:
            return output, attn_weights
        return output
            
class RotaryEmbedding2D(keras.layers.Layer):
    """Rotary position embedding over the two spatial axes of a feature map.

    The channel axis is split into two equal halves: the first half is rotated
    along axis 2 and the second half along axis 1, each by its own
    ``keras_hub.layers.RotaryEmbedding``; the halves are then concatenated
    back along the channel axis.

    Args:
        max_wavelength: Forwarded to both ``RotaryEmbedding`` layers.
        scaling_factor: Forwarded to both ``RotaryEmbedding`` layers.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(
        self,
        max_wavelength=10000,
        scaling_factor=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.max_wavelength = max_wavelength
        self.scaling_factor = scaling_factor
        # One rotary layer per spatial direction.
        self.horizontal_rope = keras_hub.layers.RotaryEmbedding(
            max_wavelength=max_wavelength,
            scaling_factor=scaling_factor,
            sequence_axis=2,  # horizontal (width) axis
            feature_axis=-1
        )
        self.vertical_rope = keras_hub.layers.RotaryEmbedding(
            max_wavelength=max_wavelength,
            scaling_factor=scaling_factor,
            sequence_axis=1,  # vertical (height) axis
            feature_axis=-1
        )

    def call(self, inputs):
        """Rotate ``inputs`` of layout [batch, height, width, channels].

        The channel count must be even because the tensor is split into two
        equal parts along the last axis.
        """
        x_h, x_v = ops.split(inputs, 2, axis=-1)

        x_h = self.horizontal_rope(x_h)  # positions taken along the width axis
        x_v = self.vertical_rope(x_v)    # positions taken along the height axis

        # keras.ops keeps the layer backend-agnostic like the rest of its body.
        return ops.concatenate([x_h, x_v], axis=-1)
        
class HybridViT(models.Model):
    def __init__(self, image_size, patch_size=16, num_layers=12,num_heads=12, embed_dims=768, 
                 proj_dims = 128, split_layer=4, n_register = 0,
                 q_size = 2**15, t = 0.07,
                 **kwargs):
        """Create the encoder layers, projection heads, feature queue and augmenters.

        Args:
            image_size: Input resolution; used to size the random crop of the
                whole-view augmenter.
            patch_size: Kernel size and stride of the patch-embedding
                convolution; ``patch_size**2`` is the MIM head output width.
            num_layers: Total layer count; ``num_layers - split_layer``
                attention layers are created for the late stage.
            num_heads: Attention heads in every attention layer.
            embed_dims: Token embedding width.
            proj_dims: Output width of the contrastive projection head and of
                each row of the feature queue.
            split_layer: Number of ``TransformerEncoderLayer`` blocks in the
                early stage.
            n_register: Number of register tokens to allocate (0 disables them).
            q_size: Number of rows in the feature queue.
            t: Temperature dividing the similarities in ``get_neighbor`` and
                ``compute_nclr_loss``.
            **kwargs: Forwarded to ``keras.Model``.
        """
        super().__init__(**kwargs)
        self.n_reg = n_register
        self.image_size = image_size
        self.res = image_size  # alias kept for convenience
        self.heads = num_heads
        self.patch_size = patch_size
        self.embed_dims = embed_dims
        self.split_layer = split_layer

        # 1. Learnable CLS token and optional register tokens
        self.cls_token = self.add_weight(
            shape=(1, 1, embed_dims),
            initializer=keras.initializers.RandomNormal(),
            name="cls_token"
        )
        if self.n_reg > 0:
            # Distinct name: this weight previously reused "cls_token", which
            # made the two variables indistinguishable by name.
            self.reg_tokens = self.add_weight(
                shape=(1, self.n_reg, embed_dims),
                initializer=keras.initializers.RandomNormal(),
                name="reg_tokens"
            )

        # 2. Patch embedding and RoPE
        self.patch_embed = layers.Conv2D(embed_dims//2, patch_size, strides=patch_size, name = "PatchConvolution")
        self.middle_conv = layers.Conv2D(embed_dims, 5, padding = 'same', name = "mid_sep_conv2d")
        self.rope = RotaryEmbedding2D(name = "RoPE_2Dims")

        # 3. Transformer layers: self-attention blocks first, then plain
        #    attention layers that `call` drives with the CLS token as query.
        self.early_layers = keras.Sequential([TransformerEncoderLayer(embed_dims, num_heads, 2*embed_dims, False, orthogonal_factor = 0.01)
                            for _ in range(split_layer)],
                                            name = "EarlyTREncoder")
        self.late_layers = keras.Sequential([keras.layers.MultiHeadAttention(num_heads, embed_dims//num_heads,
                                             kernel_regularizer=keras.regularizers.OrthogonalRegularizer(factor = 0.01),
                                             bias_regularizer=keras.regularizers.OrthogonalRegularizer(factor = 0.01))
                           for _ in range(num_layers - split_layer)], name = "LateTREncoder"
                                           )
        # 4. Projection heads
        self.mim_head = layers.Dense(patch_size**2, activation = "sigmoid", name = "MIM_Regressor")
        self.nnclr_proj = layers.Dense(proj_dims, name = "NCLR_Projector")

        # 5. Feature queue for NNCLR/SNCLR: a (q_size, proj_dims) matrix of
        #    L2-normalised rows, updated FIFO-style in compute_nclr_loss.
        self.q_size = q_size
        self.feature_q = keras.Variable(
            keras.utils.normalize(
                keras.random.normal(shape=(self.q_size, proj_dims)),
                axis=1,
                order=2,
            ),
            trainable=False, dtype = "float32"
        )
        self.t = t

        # 6. Augmenters for contrastive learning.
        # Loss = CL + MIM
        #   CL  : SNCLR between the original image and a down-sampled whole
        #         view, plus DenseCL between two dense crops.
        #   MIM : SimMIM-style reconstruction of masked patches.
        self.whole_view_augmenter = keras.Sequential([
                            keras.layers.RandomFlip(),
                            keras.layers.RandomRotation(0.2),
                            keras.layers.RandomCrop(height = int(0.8 * image_size), width = int(0.8 * image_size)),
                            keras.layers.Resizing(288,288)
                        ],
                                         name = "WholeViewDownsampler")

        # Candidate (height, width) sizes for the DenseCL crops.
        self.possible_sizes = [(192, 192), (192, 384), (384, 192), 
                              (384, 256),(256,384),
                               (192, 256),(256,192), (256,256)]
        
    #### FOR SNCLR ###
    def get_neighbor(self, projections):
        """Return a soft nearest neighbour of each projection from the feature queue.

        Similarities between ``projections`` and every queue row are divided
        by the temperature ``self.t`` and soft-maxed; the queue rows are then
        averaged with those weights. The result carries the neighbour's value
        in the forward pass, while the ``stop_gradient`` term makes gradients
        flow as if the output were ``projections`` itself.
        """
        queue = self.feature_q
        similarity = ops.matmul(projections, ops.transpose(queue))  # (batch, q_size)
        weights = ops.softmax(similarity / self.t)                  # (batch, q_size)
        support = ops.matmul(weights, queue)                        # (batch, proj_dims)
        return projections + ops.stop_gradient(support - projections)

    def compute_nclr_loss(self, projections, training = True):
        """SimCLR-shaped contrastive loss using queue neighbours as positives.

        Args:
            projections: Pair ``[original, augmented]`` of projected features.
            training: When truthy, the normalised original projections are
                pushed to the front of the feature queue and the oldest rows
                are dropped.

        Returns:
            Mean over the batch of the sum of four cross-entropy terms
            (both similarity matrices, each also transposed).
        """
        anchor = ops.normalize(projections[0])
        augmented = ops.normalize(projections[1])

        anchor_nn = self.get_neighbor(anchor)
        augmented_nn = self.get_neighbor(augmented)

        logits_a = ops.matmul(anchor, ops.transpose(augmented_nn)) / self.t
        logits_b = ops.matmul(anchor_nn, ops.transpose(augmented)) / self.t

        batch_size = ops.shape(anchor)[0]
        # Sample i is the positive of sample i, so the targets are the row indices.
        targets = ops.arange(batch_size)

        xent = keras.losses.sparse_categorical_crossentropy
        per_sample = xent(targets, logits_a, from_logits = True)
        for logits in (ops.transpose(logits_a), logits_b, ops.transpose(logits_b)):
            per_sample = per_sample + xent(targets, logits, from_logits = True)

        if training:
            refreshed_queue = ops.concatenate(
                [anchor, self.feature_q[:-batch_size]], axis=0
            )
            self.feature_q.assign(refreshed_queue)

        return ops.mean(per_sample)
    ### for DenseCL ###
    def compute_densecl_loss(self, f1, f2):
        """Dense contrastive loss between the token maps of two views.

        Every token of one view is matched with its most similar token of the
        other view; the loss is ``-log(max / sum)`` of the exponentiated
        similarities, taken in both directions and averaged.

        NOTE(review): only the first direction adds the 1e-5 epsilon inside
        the log - confirm whether the asymmetry is intended.
        """
        tokens_a = ops.normalize(f1)
        tokens_b = ops.normalize(f2)
        affinity = ops.exp(ops.einsum("bld, bnd -> bln", tokens_a, tokens_b))

        # Direction a -> b reduces over the tokens of view b (last axis).
        best_ab = ops.max(affinity, axis = -1)
        total_ab = ops.sum(affinity, axis = -1)
        # Direction b -> a reduces over the tokens of view a (axis 1).
        best_ba = ops.max(affinity, axis = 1)
        total_ba = ops.sum(affinity, axis = 1)

        loss_ab = -ops.log(1e-5 + best_ab/total_ab)
        loss_ba = -ops.log(best_ba/total_ba)
        return ops.mean(loss_ab)+ops.mean(loss_ba)
    ### For MIM ###

    def _dynamic_masking(self, images, base_ratio=0.7):
        """Mask patches for masked image modelling, favouring high-contrast patches.

        Args:
            images: Batch of square images, [batch, size, size, channels].
            base_ratio: Base masking probability; each patch is masked with
                probability ``base_ratio * (0.5 + normalised_contrast)``.

        Returns:
            masked_images: ``images`` (as float32) with masked patches zeroed.
            mask_flat: [batch, num_patches] float mask, 1 = masked.
            patches: [batch, num_patches, P*P*channels] patches of the
                unmasked input.

        Random numbers come from ``tf.random.uniform`` so that a fresh mask is
        drawn on every execution; NumPy randomness would be evaluated only
        once when this code is traced into a graph.
        """
        P = self.patch_size
        images = ops.cast(images, "float32")
        batch_size = ops.shape(images)[0]
        image_size = ops.shape(images)[1]  # square images are assumed
        c = ops.shape(images)[-1]
        num_patches_h = image_size // P
        num_patches_w = image_size // P

        # 1. Local contrast
        if c == 3:
            # RGB -> grayscale with luma weights
            gray = ops.sum(images * ops.convert_to_tensor([[[[0.299, 0.587, 0.114]]]]), axis=-1, keepdims=True)
        else:
            # already single-channel (or another channel count): use as is
            gray = images

        # Sobel edge magnitude (TensorFlow op)
        sobel = tf.image.sobel_edges(gray)
        sobel_magnitude = ops.sqrt(ops.sum(ops.square(ops.convert_to_tensor(sobel)), axis=-1))

        # Down-sample the contrast map to one value per patch
        contrast_patches = ops.image.extract_patches(
            images=sobel_magnitude,
            size=(P, P),
            strides=(P, P),
            dilation_rate=1,
            padding="valid"
        )
        contrast_map = ops.mean(contrast_patches, axis=-1)
        contrast_map = ops.reshape(contrast_map, [batch_size, num_patches_h, num_patches_w, 1])

        # 2. Per-image min-max normalisation, then contrast-dependent mask probability
        contrast_min = ops.min(contrast_map, axis=[1, 2, 3], keepdims=True)
        contrast_max = ops.max(contrast_map, axis=[1, 2, 3], keepdims=True)
        normalized_contrast = (contrast_map - contrast_min) / (contrast_max - contrast_min + 1e-8)
        mask_ratio = base_ratio * (0.5 + normalized_contrast)

        # 3. Stochastic mask (1 = masked, 0 = kept)
        grid_shape = tf.shape(mask_ratio)
        mask = ops.cast(tf.random.uniform(grid_shape) < mask_ratio, "float32")

        # Top-up when less than half of the patches are masked: every patch
        # gets an extra chance of (0.5 - current ratio). When the ratio is
        # already >= 0.5 the threshold is <= 0, no extra patch is selected and
        # the mask is unchanged, so no explicit branch is required.
        current_mask_ratio = ops.mean(mask)
        extra = ops.cast(tf.random.uniform(grid_shape) < (0.5 - current_mask_ratio), "float32")
        mask = ops.maximum(mask, extra)

        # 4. Up-sample the patch mask to pixel resolution and apply it
        mask_up = ops.repeat(mask, P, axis=1)
        mask_up = ops.repeat(mask_up, P, axis=2)
        mask_up = ops.reshape(mask_up, [batch_size, image_size, image_size, 1])
        masked_images = images * (1.0 - mask_up)

        # 5. Ground-truth patches from the unmasked images
        patches = ops.image.extract_patches(
            images=images,
            size=(P, P),
            strides=(P, P),
            dilation_rate=1,
            padding="valid"
        )
        patches = ops.reshape(patches, [batch_size, num_patches_h * num_patches_w, P*P*c])

        # 6. Flatten the mask for the loss
        mask_flat = ops.reshape(mask, [batch_size, -1])

        return masked_images, mask_flat, patches

    def _compute_mim_loss(self, gt, pred, mask):
        """Squared reconstruction error restricted to masked patches.

        The squared error is summed over every element of the masked patches
        and divided by the number of masked patches (plus 1e-8).
        """
        patch_mask = ops.cast(mask, "float32")
        patch_mask = ops.expand_dims(patch_mask, -1)  # broadcast over the pixel axis
        masked_sq_err = ops.square(gt - pred) * patch_mask
        denom = ops.sum(patch_mask) + 1e-8
        return ops.sum(masked_sq_err) / denom
    def compute_headwise_sim(self, attn_weights):
        """Average pairwise similarity between the attention maps of different heads.

        ``attn_weights`` is flattened to [batch, heads, -1], each head's map is
        L2-normalised (1e-8 is added after the normalisation) and the
        head-by-head dot products are computed. Diagonal (self-similarity)
        entries are zeroed and the mean is taken over all heads x heads
        entries, zeroed diagonal included.

        The number of heads must be statically known because it sizes the
        identity matrix.
        """
        dims = ops.shape(attn_weights)
        # [B, H, N]: one flattened attention map per head
        per_head = ops.reshape(attn_weights, (dims[0], dims[1], -1))

        unit = ops.normalize(per_head, axis=-1)  + 1e-8

        # [B, H, H] head-to-head similarity matrix
        gram = ops.einsum("bhi,bji->bhj", unit, unit)

        # 1 everywhere except on the diagonal
        off_diagonal = 1 - ops.eye(unit.shape[1])

        return ops.mean(off_diagonal[tf.newaxis, ...] * gram)

    def call(self, inputs):
        """Encode a batch of images.

        Returns:
            cls_feature: [batch, embed_dims] CLS feature after the late layers.
            encoded_patches: [batch, n_patches, embed_dims] patch tokens after
                the early layers.
            attn_weights: Attention scores of the last late layer (CLS query
                against the patch tokens); ``None`` if there are no late layers.

        NOTE(review): register tokens (``n_register > 0``) are allocated in
        ``__init__`` but are not injected into the token sequence here, and
        the CLS token is not concatenated to the patch tokens - it only acts
        as the query of the late attention layers.
        """
        # 1. Patch embedding
        inputs = ops.cast(inputs, 'float32')
        patches = self.patch_embed(inputs)
        patches = self.middle_conv(patches)
        # The spatial grid must be statically known to flatten it below.
        _, grid_a, grid_b, _ = patches.shape
        n_patches = grid_a * grid_b

        # 2. Rotary position embedding, then flatten the grid into a sequence
        patches = self.rope(patches)
        patches = ops.reshape(patches, (ops.shape(patches)[0], n_patches, self.embed_dims))

        # 3. One CLS token per sample
        cls_token = ops.repeat(self.cls_token, ops.shape(inputs)[0], axis=0)

        # 4. Early layers (patch tokens only; used for MIM)
        encoded_patches = self.early_layers(patches)

        # 5. Late layers (NNCLR): the CLS token attends to the encoded patches
        attn_weights = None
        for layer in self.late_layers.layers:
            cls_token, attn_weights = layer(query = cls_token, key = encoded_patches, value = encoded_patches,
                                            return_attention_scores = True)

        return cls_token[:,0,:], encoded_patches, attn_weights

    def train_step(self, data):
        """Run one self-supervised step combining MIM, NNCLR and DenseCL losses.

        Args:
            data: A batch of images, or a tuple/list whose first element is
                the image batch (e.g. ``(images, labels)``); anything after
                the images is ignored.

        Returns:
            Dict with the individual losses, the head-wise attention
            similarity diagnostic and the total loss.

        NOTE(review): the dense-crop sizes are chosen with Python's ``random``
        module, so when this method is traced into a graph the choice is made
        once per trace rather than once per step. ``call`` needs a static
        spatial grid, so the sizes are left as Python values.
        NOTE(review): ``self.losses`` (e.g. the layer regularizers) is not
        added to the optimised loss.
        """
        images = data[0] if isinstance(data, (tuple, list)) else data
        batch_size = ops.shape(images)[0]
        channels = ops.shape(images)[-1]
        thumbnail = self.whole_view_augmenter(images)

        def random_dense_view():
            """Crop one randomly sized window (same window for the whole batch)."""
            crop_h, crop_w = random.choice(self.possible_sizes)
            view = tf.image.random_crop(
                images,
                # explicit size for every axis; the batch axis is kept whole
                size=tf.concat([[batch_size], [crop_h], [crop_w], [channels]], axis=0)
            )
            return ops.reshape(view, [batch_size, crop_h, crop_w, channels])

        dense_view1 = random_dense_view()
        dense_view2 = random_dense_view()

        masked_images, mask_flat, gt_patches = self._dynamic_masking(images)
        gt_patches = gt_patches / 255

        # Invert the thumbnail intensities with probability 0.6 (the chance
        # that randint(0, 10) <= 5). Drawn with a TensorFlow op so the decision
        # is re-made on every step instead of being fixed at trace time.
        invert = tf.random.uniform(()) < 0.6
        thumbnail = ops.where(invert, 255 - thumbnail, thumbnail)

        with tf.GradientTape(watch_accessed_variables=True) as tape:
            # 1. SimMIM: reconstruct the masked patches
            _, encoded_patches, _ = self(masked_images)
            pred_patches = self.mim_head(encoded_patches)
            mim_loss = self._compute_mim_loss(gt_patches, pred_patches, mask_flat)

            # 2. NCLR between the original images and the thumbnail
            cls_token, _, attn_weights = self(images)
            aug_token, _, _ = self(thumbnail)
            proj_features = self.nnclr_proj(cls_token)
            proj_aug_features = self.nnclr_proj(aug_token)
            nnclr_loss = self.compute_nclr_loss([proj_features, proj_aug_features])

            # 3. DenseCL between the token maps of the two dense crops
            _, f1, _ = self(dense_view1)
            _, f2, _ = self(dense_view2)
            dense_cl_loss = self.compute_densecl_loss(f1, f2)

            loss = 0.2*mim_loss + nnclr_loss + dense_cl_loss

        # Diagnostic only: it is not part of `loss`, so it is computed outside the tape.
        headwise_att_sim = self.compute_headwise_sim(attn_weights)

        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        return {"mim_loss": mim_loss, "nnclr_loss": nnclr_loss, 'DenseCL_loss' : dense_cl_loss,
                "HeadwiseAttn_w_Cos_sim" : headwise_att_sim,
                'total_loss' : loss}

    def mimmax_scale(self, data):
        """Min-max scale ``data`` using its global minimum and maximum.

        The range is padded with 1e-4 so constant input does not divide by zero.
        """
        values = np.array(data)
        low = np.min(values)
        span = np.max(values) - low + 1e-4
        return (values - low) / span
    def get_segmentation_maps(self, feature_map, n_clusters=20, cluster_cmap="tab20"):
        """Turn patch features into RGB maps via PCA, t-SNE and two clusterings.

        Args:
            feature_map: [batch, seq_len, dims] tokens (``seq_len`` must be a
                perfect square) or a [batch, w, h, dims] feature map.
            n_clusters: Cluster count for agglomerative clustering and k-means.
            cluster_cmap: Matplotlib colormap name used to colour cluster labels.

        Returns:
            Dict with keys "pca", "tsne", "agglomerative" and "kmeans"; every
            value is an array of shape (batch, w, h, 3).

        Raises:
            ValueError: If ``feature_map`` is not rank 3 or 4, or a rank-3
                input has a non-square token count.
        """
        feature_map = np.array(feature_map)
        if feature_map.ndim == 3:
            batch_size, seq_len, embed_dims = feature_map.shape
            # Plain Python ints: they are later handed to numpy reshape and sklearn.
            side = int(round(np.sqrt(seq_len)))
            if side * side != seq_len:
                raise ValueError(f"Token count {seq_len} is not a perfect square.")
            w, h = side, side
        elif feature_map.ndim == 4:
            batch_size, w, h, embed_dims = feature_map.shape
        else:
            raise ValueError(
                f"Expected a rank-3 or rank-4 feature map, got shape {feature_map.shape}."
            )
        flatten_map = feature_map.reshape((-1, embed_dims))
        map_shape = (batch_size, w, h, 3)

        result_maps = {}
        scaler = StandardScaler()
        scaled_features = scaler.fit_transform(flatten_map)

        # 1. PCA
        pca = PCA(n_components=3)
        pca_result = pca.fit_transform(scaled_features)
        result_maps["pca"] = self.mimmax_scale(pca_result).reshape(map_shape)

        # 2. t-SNE (perplexity is kept at 1 or above for very small grids)
        tsne = TSNE(n_components=3, random_state=42, perplexity=min(30, max(1, (w*h)//5)))
        tsne_result = tsne.fit_transform(scaled_features)
        result_maps["tsne"] = self.mimmax_scale(tsne_result).reshape(map_shape)

        # 3. UMAP (currently disabled)
        #reducer = umap.UMAP(n_components=3, random_state=42)
        #umap_result = reducer.fit_transform(scaled_features)
        #umap_result_norm = self.mimmax_scale(umap_result)
        #result_maps["umap"] = umap_result_norm.reshape((batch_size, w, h, 3))

        # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which newer
        # Matplotlib releases no longer provide.
        cmap = plt.get_cmap(cluster_cmap, n_clusters)

        def colorize(labels):
            """Map integer cluster labels to RGB colours of shape map_shape."""
            return cmap(labels / (n_clusters - 1 + 1e-5))[..., :3].reshape(map_shape)

        # 4. Agglomerative clustering
        agg_clustering = AgglomerativeClustering(n_clusters=n_clusters)
        result_maps['agglomerative'] = colorize(agg_clustering.fit_predict(scaled_features))

        # 5. K-means
        kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
        result_maps['kmeans'] = colorize(kmeans.fit_predict(scaled_features))
        return result_maps
        
    def infer(self, images):
        """Build attention heat-maps, image overlays and segmentation maps.

        Args:
            images: Batch of square images whose side is a multiple of
                ``self.patch_size``. Must be concrete (eager) values because
                the heat-maps are rendered through NumPy and PIL.

        Returns:
            Dict with "attn_map" (per-head heat-maps), "merged_original"
            (per-head overlays), "attn_map_head_merged" (head-averaged
            heat-map), "head_merged_original" (head-averaged overlay),
            "encoded_patches", plus every entry returned by
            ``get_segmentation_maps``.
        """
        input_shape = ops.shape(images)
        # Python ints are required by PIL's resize below.
        batch_size, res = int(input_shape[0]), int(input_shape[1])
        print("Heatmap calculation...", '\n')
        _, encoded_patches, attn_weights = self(images)

        small_res = res // self.patch_size

        # Keep the single (CLS) query row and flatten to one row per (sample, head).
        attn_weights = attn_weights[:, :, 0, :]
        attn_weights = ops.reshape(attn_weights, [-1, small_res * small_res])

        # Per-row min-max scaling to [0, 255]; the range is floored so a
        # constant attention row yields zeros instead of NaNs.
        attn_min = ops.min(attn_weights, axis = -1, keepdims = True)
        attn_max = ops.max(attn_weights, axis = -1, keepdims = True)
        attn_weights = (attn_weights - attn_min) / ops.maximum(attn_max - attn_min, 1e-8)
        attn_weights *= 255.0

        # Suppress everything below the per-row median.
        threshold = ops.median(attn_weights, axis = -1, keepdims = True)
        attn_weights = ops.where(attn_weights < threshold, 0.0, attn_weights)

        heatmap = ops.reshape(attn_weights, [batch_size*self.heads, small_res, small_res])
        heatmap = np.array(heatmap).astype("int32")
        cmap = mpl.colormaps["jet"]
        # Use RGB values of the colormap
        _colors = cmap(np.arange(256))[:, :3]
        heatmap = _colors[heatmap]
        # Up-sample every head's heat-map to the input resolution through PIL.
        resized = []
        for head_map in heatmap:
            head_img = keras.utils.array_to_img(head_map).resize((res,res))
            resized.append(keras.utils.img_to_array(head_img))
        heatmap = np.array(resized)
        heatmap = ops.reshape(heatmap, [batch_size, self.heads, res, res, 3])
        mean_heatmap = ops.mean(heatmap, axis = 1)
        mean_merged = ops.cast(images, 'float32') * 0.5 + ops.cast(mean_heatmap, 'float32') * 0.5

        # Add a head axis so the images broadcast against the per-head heat-maps.
        images = images[:, tf.newaxis, ...]
        merged = ops.cast(images, 'float32') * 0.5 + ops.cast(heatmap, 'float32') * 0.5

        merged = ops.cast(merged, 'uint8')
        mean_merged = ops.cast(mean_merged, 'uint8')
        heatmap = ops.cast(heatmap, 'uint8')
        mean_heatmap = ops.cast(mean_heatmap, 'uint8')
        print("Segmentation...")
        seg_maps = self.get_segmentation_maps(encoded_patches)
        result = {"attn_map" : heatmap, "merged_original" : merged,
               "attn_map_head_merged" : mean_heatmap,
               "head_merged_original": mean_merged,
                "encoded_patches" : encoded_patches,
               }
        result.update(seg_maps)

        return result


# Setting hyperparameters

In [None]:
# Global batch size: 8 samples per replica.
batch_size = 8 * strategy.num_replicas_in_sync
print('batch size', batch_size)

# Input resolutions
res = int(3.0 * 256)
small_res = 64

# Augmentation
n_multicrop = 2
randaug = keras_cv.layers.RandAugment(
    value_range=(0, 255),
    magnitude=0.1,
    magnitude_stddev=0.1,
    geometric=False,
)

# Model hyperparameters
grayscale = False  # False if using pretrained model, True if from scratch
patch_size = 12    # replaced by 32 below when the pretrained backbones are used
heads = 8
att_dims = 64
embed_dims = 512

c = 1 if grayscale else 3
if grayscale:
    pretrained_encoder = None
    depth = 2
    registers = 2
    pretrained_note = "gray_metaformer"
else:
    depth = 3
    registers = 0
    pretrained_encoder = keras.applications.ConvNeXtTiny(
        input_shape=[res, res, 3],
        include_top=False,
        #weights = None,
    )
    patch_size = 32

    pretrained_vit = kimm.models.VisionTransformerTiny32(input_shape=[res, res, 3], include_top=False)
    # Alternative backbones tried previously:
    #pretrained_regnet = kimm.models.RegNetY040(input_shape = [res,res,3], include_top = False); patch_size = 32
    #pretrained_regnet = keras.Model(inputs = pretrained_regnet.input, outputs = pretrained_regnet.get_layer("s4_b1_conv1").output,
    #                    name = f"{pretrained_regnet.name}_upsample")
    #pretrained_vit = kimm.models.VisionTransformerBase16(input_shape = [res,res,3], include_top = False) ; patch_size = 16
    #pretrained_vit = kimm.models.VisionTransformerLarge16(input_shape = [res,res,3], include_top = False) ; patch_size = 16

    # Assign the mixed_float16 dtype policy to every layer of both backbones.
    for backbone in (pretrained_encoder, pretrained_vit):
        for layer in backbone.layers:
            layer.dtype_policy = keras.mixed_precision.Policy('mixed_float16')
    #for layer in pretrained_regnet.layers:
    #    layer.dtype_policy = keras.mixed_precision.Policy('mixed_float16')
    pretrained_note = f"ConvWithMetaEncoder_ImageNet_TM{depth}"
    # Pure-metaformer RGB variant (disabled):
    #depth = 6
    #registers = 0
    #pretrained_encoder = None
    #pretrained_note = f"RGB_metaformer_TM{depth}"
    #patch_size = 24

In [None]:
def get_dual_encoder():
    """Build a model that concatenates features of the two pretrained backbones.

    Uses the notebook-level ``res``, ``pretrained_encoder`` and
    ``pretrained_vit``. The first backbone's 4-D feature map is flattened into
    a token sequence; the second backbone's output drops its first token
    (presumably the class token - verify against the kimm model) and the two
    sequences are concatenated along the channel axis, so both must yield the
    same number of tokens.

    Returns:
        A ``keras.Model`` mapping a [res, res, 3] image to the merged tokens.
    """
    image_in = Input([res,res,3], name = "DualEncoderInputImg")
    conv_backbone, vit_backbone = pretrained_encoder, pretrained_vit

    conv_map = conv_backbone(image_in)             # 4-D feature map
    vit_tokens = vit_backbone(image_in)[:, 1:, :]  # token sequence without its first token

    _, grid_w, grid_h, channels = ops.shape(conv_map)
    conv_tokens = ops.reshape(conv_map, [-1, grid_w*grid_h, channels])

    merged = ops.concatenate([conv_tokens, vit_tokens], axis = -1)
    combined_features = keras.layers.Identity(name = "MergedFeatureMap")(merged)

    return Model(image_in, combined_features,
                name = f"{conv_backbone.name}With{vit_backbone.name}_dualencoder")

# Example usage
#dual_model = get_dual_encoder()
#dual_model.summary()

- radimagenet tfrecord key : image, label
- nih cxr tfrecord key : image_raw, label

# RadImageNet decoding

In [None]:
def _parse_tfrecord(res = res):
    """Return a parser for RadImageNet records (features 'image' and 'label').

    The returned function decodes the JPEG as a single-channel image, resizes
    it with ``_transform_images(res=res)`` and returns ``(image, label)`` with
    the label cast to int32.
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }

    def parse_tfrecord(tfrecord):
        example = tf.io.parse_single_example(tfrecord, feature_spec)
        decoded = tf.image.decode_jpeg(example['image'], channels=1)
        resized = _transform_images(res = res)(decoded)
        return (resized, tf.cast(example["label"], tf.int32))

    return parse_tfrecord


def _transform_images(res = res):
    """Return a function that pads/resizes an image to ``res`` x ``res`` as uint8."""
    def transform_images(x_train):
        # Aspect-preserving resize with padding, then back to uint8.
        padded = tf.image.resize_with_pad(x_train, res, res, antialias = True)
        return tf.cast(padded, tf.uint8)
    return transform_images

def load_tfrecord_dataset(tfrecord_name, res = res, batch_size = batch_size, shuffle=True, buffer_size=10240):
    """Load a GZIP-compressed TFRecord file as an endlessly repeating, batched dataset.

    Args:
        tfrecord_name: Path of the TFRecord file.
        res: Target image resolution, forwarded to the record parser.
        batch_size: Batch size; incomplete batches are dropped.
        shuffle: Whether to shuffle the (repeated) record stream.
        buffer_size: Shuffle buffer size.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
    """
    raw_dataset = tf.data.TFRecordDataset(tfrecord_name, compression_type = "GZIP")
    raw_dataset = raw_dataset.repeat()
    if shuffle:
        raw_dataset = raw_dataset.shuffle(buffer_size=buffer_size)
    dataset = raw_dataset.map(
        # Forward `res`; it was previously accepted but never used.
        _parse_tfrecord(res = res),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.batch(batch_size, drop_remainder = True)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

# RadImageNet train/test splits. Both use the loader defaults, so each dataset
# repeats indefinitely and is shuffled (including val_ds).
train_radimagenet_ds = load_tfrecord_dataset("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RagImageNet_Train_GZIP.tfrecord")
val_ds = load_tfrecord_dataset("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RagImageNet_Test_GZIP.tfrecord")

# NIH CXR decoding

In [None]:
def _parse_tfrecord(res = res):
    """Build a per-record parser for TFRecords that store the JPEG under ``image_raw``.

    Args:
        res: Target side length forwarded to ``_transform_images``.

    Returns:
        A function mapping one serialized example to ``(image, label)``, where
        the label is cast to int32.
    """
    feature_spec = {
        'image_raw': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
    }

    def parse_tfrecord(tfrecord):
        example = tf.io.parse_single_example(tfrecord, feature_spec)
        # Decode as a single-channel image, then pad/resize to res x res.
        decoded = tf.image.decode_jpeg(example['image_raw'], channels=1)
        resized = _transform_images(res = res)(decoded)
        return (resized, tf.cast(example["label"], tf.int32))

    return parse_tfrecord


def _transform_images(res = res):
    """Return a function that pads/resizes an image to ``res`` x ``res`` and casts it to uint8."""
    def transform_images(x_train):
        padded = tf.image.resize_with_pad(x_train, res, res, antialias = True)
        return tf.cast(padded, tf.uint8)
    return transform_images

def load_tfrecord_dataset(tfrecord_name, res = res, batch_size = batch_size, shuffle=True, buffer_size=10240):
    """Load an uncompressed TFRecord file as an endlessly repeating dataset.

    Args:
        tfrecord_name: Path (or list of paths) of the TFRecord file(s).
        res: Target side length of the decoded images. It is forwarded to
            ``_parse_tfrecord`` so that a non-default value actually takes effect.
        batch_size: Batch size; a falsy value (``None`` / 0) leaves the dataset unbatched.
        shuffle: Whether to shuffle the raw records before parsing.
        buffer_size: Shuffle buffer size, in records.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image, label)`` elements.
    """
    raw_dataset = tf.data.TFRecordDataset(tfrecord_name)
    # Repeat before shuffling so the stream never runs dry.
    raw_dataset = raw_dataset.repeat()
    if shuffle:
        raw_dataset = raw_dataset.shuffle(buffer_size=buffer_size)
    dataset = raw_dataset.map(
        _parse_tfrecord(res = res),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    if batch_size:
        dataset = dataset.batch(batch_size, drop_remainder = True)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

# NIH CXR stream, built with the load_tfrecord_dataset defined in this cell (no compression_type).
nih_cxr_ds = load_tfrecord_dataset("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/nih_cxr_images.tfrecords")

# Merging 2 datasets

In [None]:
# Mix the two training streams sample-by-sample (weights 0.75 RadImageNet / 0.25 NIH CXR),
# then re-batch. Both inputs already repeat, so the trailing .repeat() is a safeguard only.
train_ds = tf.data.Dataset.sample_from_datasets([train_radimagenet_ds.unbatch(), nih_cxr_ds.unbatch()], weights = [0.75, 0.25]).batch(batch_size, drop_remainder = True).repeat().prefetch(tf.data.AUTOTUNE)
# NOTE(review): despite its name, val_ds_ samples from the same *training* streams, not from val_ds.
# It is only used below to grab one fixed batch of 6 images and is deleted afterwards.
val_ds_ = tf.data.Dataset.sample_from_datasets([train_radimagenet_ds.unbatch(), nih_cxr_ds.unbatch()], weights = [0.75, 0.25]).batch(6, drop_remainder = True).prefetch(tf.data.AUTOTUNE)
#train ds output : ([batch_size, res, res, 1], [batch_size,])
# train data curation
# Keep a single batch as `sample_img` (and its `labels`) as notebook globals.
for images, labels in val_ds_.take(1):
    sample_img = images
    labels = labels  # no-op self-assignment; `labels` is already bound by the loop
    
del val_ds_

In [None]:
def get_sobel_fn():
    """Build a per-sample map function that randomly appends Sobel edge channels.

    Returns:
        ``sobel_merge(image, label)``. It draws a random integer with
        ``keras.random.randint(minval=1, maxval=10)``. When the draw is greater
        than 5, the image is concatenated with its min-max normalised Sobel
        edge maps; otherwise it is converted with ``tf.image.grayscale_to_rgb``.
        The label is returned unchanged.

    NOTE(review): assumes an unbatched single-channel (H, W, 1) image, so both
    branches yield 3 channels - confirm against the caller.
    """
    def sobel_merge(image, label):
        # Add a leading batch axis for the image ops below; removed again before returning.
        image = image[tf.newaxis, ...]
        rand_num = keras.random.randint(shape = (), minval = 1, maxval = 10)
        if rand_num > 5:
            image = ops.cast(image, 'float32')
            # Index 0 on the second-to-last axis of the Sobel output, keeping the last axis whole.
            ed = tf.image.sobel_edges(image)[..., 0, :]
            ed_min = ops.min(ed)
            # Clamp the range so a constant edge map (e.g. a blank image) gives zeros, not NaN.
            ed_range = ops.maximum(ops.max(ed) - ed_min, 1e-6)
            ed_norm = 255.0 * (ed - ed_min) / ed_range
            ed_norm = ops.cast(ed_norm, "uint8")
            image = ops.concatenate([image, ed_norm], axis = -1)
            image = ops.cast(image, "uint8")
        else:
            try:
                image = tf.image.grayscale_to_rgb(image)
            except Exception:
                # Best effort: leave the image untouched if it cannot be converted.
                pass
        image = image[0]
        return image, label
    return sobel_merge
sobel_merge = get_sobel_fn()

# Convert supervised dataset into SSL dataset

In [None]:
# Map function that turns supervised (image, label) batches into SSL inputs with
# `n_multicrop` views; get_map_fn is defined elsewhere in the notebook.
multiview_fn = get_map_fn(res = res, input_type = "supervised", output_type = "ssl",
                         n_view = n_multicrop, grayscale = grayscale)

train_ds_multiview = train_ds.map(multiview_fn, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
val_ds_multiview = val_ds.map(multiview_fn, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

# Masking map function (masking_rate 0.5); get_masking_fn is defined elsewhere in the notebook.
mask_map_fn_ = get_masking_fn(grayscale = grayscale, masking_rate = 0.5, patch_size = patch_size)

# Edge-augmented stream: sobel_merge works per sample, hence unbatch -> map -> re-batch.
train_edge_ds = train_ds.unbatch().map(sobel_merge, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size, drop_remainder = True).prefetch(tf.data.AUTOTUNE)

def masking_function(image, label):
    """Apply ``mask_map_fn_`` to the image; the label is discarded."""
    return mask_map_fn_(image)

train_ds_masked = train_ds.map(masking_function, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE).repeat()
#train_ds_edge_masked = train_edge_ds.unbatch().map(masking_function, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size, drop_remainder = True).prefetch(tf.data.AUTOTUNE).repeat()
#train_ds_edge_multiview = train_edge_ds.map(multiview_fn, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
#train_ds_edge_simple_multiview = train_edge_ds.map(simple_aug_fn, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

----------
# Experiment - helper functions

In [None]:
# The split CSVs are loaded only to count cases; the images themselves come from the TFRecords.
df_train_rad = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RadImgNet_train.csv")
df_train_nih = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/nih_trainval_split.csv"
                          )
df_val_rad = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RadImgNet_test.csv")
df_val_nih = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/nih_test_split.csv")


# NOTE(review): the NIH *test* split is counted as training data, while validation uses only
# the RadImageNet test split - presumably because the NIH TFRecord holds every NIH image; confirm.
train_cases = len(df_train_rad) + len(df_train_nih) + len(df_val_nih)
val_cases = len(df_val_rad)

# The datasets repeat forever, so an "epoch" is defined by these step counts.
train_steps = train_cases//batch_size
val_steps = val_cases//batch_size
print(f"Total train cases : {train_cases}, validation cases : {val_cases}")

In [None]:
def _centralize_gradient(grad):
    """Subtract from ``grad`` its mean over every axis except the last one.

    ``None`` gradients and tensors of rank <= 1 are returned untouched.
    """
    if grad is None:
        return grad
    rank = len(grad.shape)
    if rank <= 1:
        return grad
    return grad - ops.mean(grad, axis=list(range(rank - 1)), keepdims=True)


class GCAdamW(keras.optimizers.AdamW):
    """AdamW whose ``get_gradients`` returns centralized gradients."""
    def get_gradients(self, loss, params):
        # Only get_gradients() is overridden: the parent computes the gradients and each one
        # is then centralized.
        # NOTE(review): the Keras 3 optimizer base class does not appear to define or call
        # get_gradients(); confirm this hook is actually reached, otherwise the
        # centralization has to live in apply_gradients()/apply() instead.
        return [_centralize_gradient(grad) for grad in super().get_gradients(loss, params)]


class GCAdam(keras.optimizers.Adam):
    """Adam whose ``get_gradients`` returns centralized gradients."""
    def get_gradients(self, loss, params):
        # Only get_gradients() is overridden: the parent computes the gradients and each one
        # is then centralized.
        # NOTE(review): see GCAdamW.get_gradients - confirm Keras calls this hook.
        return [_centralize_gradient(grad) for grad in super().get_gradients(loss, params)]

In [None]:
class ModelSaveCallback(keras.callbacks.Callback):
    """Save ``model.feature_extractor`` and upload it after every epoch and every 20000 batches.

    Args:
        exp_name: Experiment tag embedded in the checkpoint file name.
        message: Optional free-text suffix for the file name (defaults to a single space).
        **kwargs: Forwarded to ``keras.callbacks.Callback``.
    """
    TARGET_DIR = '/kaggle/working/model_save'
    BATCH_SAVE_INTERVAL = 20000

    def __init__(self, exp_name, message = None, **kwargs):
        super().__init__(**kwargs)
        self.exp_name = exp_name
        self.message = message if message is not None else " "

    def _save_and_upload(self, tag):
        """Save the feature extractor as ``..._{tag}_...keras`` and upload it.

        Any failure is printed rather than raised, so a save or upload problem
        does not abort training.
        """
        try:
            os.makedirs(self.TARGET_DIR, exist_ok = True)
            feature_extractor = self.model.feature_extractor
            print("\nModel Saving to local notebook...")
            file_name = f"{feature_extractor.name}_FE{self.exp_name}_{tag}_{self.message}.keras"
            filepath = os.path.join(self.TARGET_DIR, file_name)
            feature_extractor.save(filepath, overwrite=True)
            print("\nModel Uploading to NAS...")
            # upload_file is defined elsewhere in the notebook.
            upload_file(file_name, filepath)
            print("\nModel Saved to Local NAS")
        except Exception as e:
            print('Model Saving Error:\n', e)

    def on_epoch_end(self, epoch, logs=None):
        self._save_and_upload(f"Epoch{epoch}")

    def on_train_batch_end(self, batch, logs=None):
        # Skip batch 0 so the untrained model is not saved.
        if (batch % self.BATCH_SAVE_INTERVAL == 0) and (batch != 0):
            self._save_and_upload(f"Batch{batch}")
                
class TemperatureScheduler(keras.callbacks.Callback):
    """Keep a temperature attribute ``t`` on the model and decay it over training.

    Args:
        initial_t: Value ``model.t`` is initialised with when the attribute is missing.
        decay_rate: Factor ``model.t`` is multiplied by on each decay step.
    """
    def __init__(self, initial_t = 0.5, decay_rate = 0.99):
        super().__init__()
        self.decay_rate = decay_rate
        self.initial_t = initial_t

    def on_train_batch_begin(self, batch, logs=None):
        if hasattr(self.model, 't'):
            # Decay on every 5000th batch index, except batch 0.
            is_decay_step = (batch > 0) and (batch % 5000 == 0)
            if is_decay_step:
                self.model.t = self.model.t * self.decay_rate
            return
        # First call on a model without `t`: create it as a float32 tensor.
        self.model.t = ops.convert_to_tensor(self.initial_t, dtype='float32')

In [None]:
class SimpleModelSaveCallback(keras.callbacks.Callback):
    """Save the whole model and upload it every 20000 training batches.

    Args:
        exp_name: Experiment tag embedded in the checkpoint file name.
        message: Optional free-text suffix for the file name (defaults to a single space).
        **kwargs: Forwarded to ``keras.callbacks.Callback``.
    """
    def __init__(self, exp_name, message = None, **kwargs):
        super().__init__(**kwargs)
        self.exp_name = exp_name
        self.message = message if message is not None else " "

    def on_train_batch_end(self, batch, logs=None):
        # Skip batch 0 so the untrained model is not saved.
        if not ((batch % 20000 == 0) and (batch != 0)):
            return
        target_dir = '/kaggle/working/model_save'
        try:
            os.makedirs(target_dir, exist_ok = True)
            print("\nModel Saving to local notebook...")
            file_name = f"{self.model.name}_FE{self.exp_name}_Batch{batch}_{self.message}.keras"
            filepath = os.path.join(target_dir, file_name)
            self.model.save(filepath, overwrite=True)
            print("\nModel Uploading to NAS...")
            # upload_file is defined elsewhere in the notebook.
            upload_file(file_name, filepath)
            print("\nModel Saved to Local NAS")
        except Exception as e:
            # A failed save or upload is reported but must not abort training.
            print('Model Saving Error:\n', e)
                

In [None]:
class TrainingViz(keras.callbacks.Callback):
    """Log attention-map tables and an embedding scatter plot for ``sample_img`` to W&B.

    Runs at the end of every epoch and on every 10000th training batch (including
    batch 0). Any failure is printed and swallowed so visualization problems never
    interrupt training.

    Args:
        run: The W&B run object; stored on the instance but not otherwise used here.
    """
    def __init__(self, run):
        super().__init__()
        self.run = run

    def _resolve_feature_extractor(self, method):
        """Return the model whose attention is visualized for the given SSL method."""
        if method in ["CLIP" , "SigLIP", "SPARC"]:
            return self.model
        try:
            return self.model.feature_extractor
        except Exception:
            # Trainers without a `feature_extractor` attribute expose get_full_model instead.
            return self.model.get_full_model(res = res)

    def _log_feature_scatter(self, key):
        """Log the first two components returned by ``feature_visualize`` as a scatter plot."""
        embed_v = feature_visualize(self.model, sample_img)
        data = [[x, y] for (x, y) in zip(embed_v[..., 0], embed_v[..., 1])]
        table = wandb.Table(data=data, columns = ["x", "y"])
        wandb.log({key : wandb.plot.scatter(table, "x", "y", title="TSNE Scatter Plot")})
        tf.keras.backend.clear_session()

    def on_epoch_end(self, epoch, logs=None):
        try:
            method = self.model.get_env_config()["SSL_method"]
            feature_extractor = self._resolve_feature_extractor(method)
            viz_weights, merged_weights = ssl_module.att_visualize(feature_extractor, sample_img, res,
                                                  thresholding = 0)
            # presumably (batch, heads, res, res, 3), per the original author's note
            viz_weights = np.array(viz_weights)
            merged_weights = np.array(merged_weights)
            heads = viz_weights.shape[1]
            col = ["Original Image", "Merged"] + [f"Head{h + 1}" for h in range(heads)]

            # One table row per sample: original image, merged map, then one map per head.
            visualize_data = []
            for sample_idx, head_maps in enumerate(viz_weights):
                row = [wandb.Image(sample_img[sample_idx]), wandb.Image(merged_weights[sample_idx])]
                row += [wandb.Image(head_maps[h]) for h in range(heads)]
                visualize_data.append(row)
            tbl = wandb.Table(columns = col, data = visualize_data)
            wandb.log({f"Epoch{epoch+1}_{method}_result": tbl})
            del feature_extractor, tbl
            tf.keras.backend.clear_session()

            # feature vector visualization
            self._log_feature_scatter(f"Epoch{epoch+1}_{method}_FeatureViz")
        except Exception as e:
            print("Error code in callback : ", e)

    def on_train_batch_end(self, batch, logs=None):
        if batch % 10000 != 0:
            return
        try:
            method = self.model.get_env_config()["SSL_method"]
            feature_extractor = self._resolve_feature_extractor(method)
            viz_weights, merged_weights = ssl_module.att_visualize(feature_extractor, sample_img, res,
                                                  thresholding = False)
            _, rollout_merged_image = ssl_module.att_visualize_merged(feature_extractor, sample_img, res)
            # presumably (batch, heads, res, res, 3), per the original author's note
            viz_weights = np.array(viz_weights)
            merged_weights = np.array(merged_weights)
            heads = viz_weights.shape[1]
            col = ["Original Image", "MergedMap", 'Top-K head merging map'] + [f"Head{h + 1}" for h in range(heads)]

            # One table row per sample: original, merged map, rollout map, then one map per head.
            visualize_data = []
            for sample_idx, head_maps in enumerate(viz_weights):
                row = [wandb.Image(sample_img[sample_idx]),
                       wandb.Image(merged_weights[sample_idx]),
                       wandb.Image(rollout_merged_image[sample_idx])]
                row += [wandb.Image(head_maps[h]) for h in range(heads)]
                visualize_data.append(row)
            tbl = wandb.Table(columns = col, data = visualize_data)
            if batch == 0:
                wandb.log({f"ZeroBatch_{method}_result": tbl})
            else:
                wandb.log({f"MidEpoch_{method}_result": tbl})
            del feature_extractor, tbl
            tf.keras.backend.clear_session()

            self._log_feature_scatter(f"Batch{batch}_{method}_FeatureViz")
        except Exception as e:
            print("Error code in callback : ", e)

> Real world evaluation and Segmentation callback

In [None]:
# List the evaluation images in a stable (sorted) order so labels, paths and the stacked
# tensor built below stay aligned.
real_world_dir = "/kaggle/input/real-world-medical-image-dataset-for-evaluation/radiopaedia_example" ; filenames_ = os.listdir(real_world_dir)
filenames_.sort()
# The label is the file name up to its first '.'.
labels_ = [name.split('.')[0] for name in filenames_]
real_world_files = [os.path.join(real_world_dir, paths) for paths in filenames_]
def get_img_tensor(path, res = res) :
    """Read an image file and return it padded/resized to ``res`` x ``res`` as uint8.

    The image is decoded with 1 channel when the global ``grayscale`` flag is
    truthy, otherwise with 3 channels.
    """
    raw_bytes = tf.io.read_file(path)
    n_channels = 1 if grayscale else 3
    decoded = tf.io.decode_image(raw_bytes, channels=n_channels)
    padded = tf.image.resize_with_pad(decoded, res, res, antialias = True)
    return ops.cast(padded, "uint8")
# Stack all evaluation images into one batch tensor, in the same order as `labels_`.
real_world_images = tf.stack([get_img_tensor(f) for f in real_world_files],
                             axis = 0)


> medpix external evaluation

In [None]:
def _parse_function_medpix(proto):
    """Parse one MedPix example into a min-max rescaled uint8 image and its raw caption.

    Returns:
        ``(image, caption)``; the caption is returned as the parsed string tensor
        and is decoded to text by the caller.
    """
    parsed = tf.io.parse_single_example(
        proto,
        {
            'image': tf.io.FixedLenFeature([], tf.string),
            'caption': tf.io.FixedLenFeature([], tf.string)
        },
    )

    # Decode the image bytes with one channel (the image is assumed to be grayscale).
    pixels = ops.cast(tf.image.decode_image(parsed['image'], channels=1), "float32")
    # Min-max rescale to [0, 255]; the 1e-4 keeps the denominator non-zero for flat images.
    lo = ops.min(pixels)
    hi = ops.max(pixels)
    rescaled = 255.0 * ((pixels - lo) / (hi - lo + 1e-4))

    # The caption is a bytes value, so it is returned as-is (decoded later).
    return ops.cast(rescaled, "uint8"), parsed['caption']



# Read the TFRecord file with the GZIP compression option.
tfrecord_file = '/kaggle/input/real-world-medical-image-dataset-for-evaluation/medpix_val.tfrecord'
raw_dataset = tf.data.TFRecordDataset(tfrecord_file, compression_type='GZIP')
medpix_val = raw_dataset.map(_parse_function_medpix)
medpix_images = []
medpix_captions = []

# Materialize up to 60 examples eagerly: pad/resize each image and decode its caption to str.
# NOTE(review): the target size is hard-coded to 512 rather than `res`, and the resized image
# is not cast back to uint8 (unlike get_img_tensor) - confirm this is intended.
for i, c in medpix_val.take(60):
    i = tf.image.resize_with_pad(i, 512,512, antialias = True)
    c = c.numpy().decode('utf-8')
    medpix_images.append(i)
    medpix_captions.append(c)
medpix_images = tf.stack(medpix_images, axis = 0)

In [None]:
class RealWorldViz(keras.callbacks.Callback):
    """Log attention maps, an embedding scatter plot and patch clusters for ``real_world_images``.

    Runs on every 5000th training batch (including batch 0). Any failure is printed
    and swallowed so visualization problems never interrupt training.

    Args:
        run: The W&B run object; stored on the instance but not otherwise used here.
    """
    def __init__(self, run):
        super().__init__()
        self.run = run

    def _resolve_feature_extractor(self, method):
        """Return the model whose attention is visualized for the given SSL method."""
        if method in ["CLIP" , "SigLIP", "SPARC"]:
            return self.model
        try:
            return self.model.feature_extractor
        except Exception:
            # Trainers without a `feature_extractor` attribute expose get_full_model instead.
            return self.model.get_full_model(res = res)

    def on_train_batch_end(self, batch, logs=None):
        if batch % 5000 != 0:
            return
        try:
            method = self.model.get_env_config()["SSL_method"]
            feature_extractor = self._resolve_feature_extractor(method)
            viz_weights, merged_weights = ssl_module.att_visualize(feature_extractor, real_world_images, res,
                                                  thresholding = False)
            _, rollout_merged_image = ssl_module.att_visualize_merged(feature_extractor,
                                                                      real_world_images, res)
            # presumably (batch, heads, res, res, 3), per the original author's note
            viz_weights = np.array(viz_weights)
            merged_weights = np.array(merged_weights)

            heads = viz_weights.shape[1]
            col = (["Original Image", "Original Label", "MergedMap", 'Top-K head merging map']
                   + [f"Head{h + 1}" for h in range(heads)])
            # One row per image: original, label, merged map, rollout map, then one map per head.
            visualize_data = []
            for sample_idx, head_maps in enumerate(viz_weights):
                row = [wandb.Image(real_world_images[sample_idx]),
                       labels_[sample_idx],
                       wandb.Image(merged_weights[sample_idx]),
                       wandb.Image(rollout_merged_image[sample_idx])]
                row += [wandb.Image(head_maps[h]) for h in range(heads)]
                visualize_data.append(row)
            tbl = wandb.Table(columns = col, data = visualize_data)
            if batch == 0:
                wandb.log({f"RW_ZeroBatch_{method}_result": tbl})
            else:
                wandb.log({f"RW_Batch{batch}_{method}_result": tbl})
            del tbl
            tf.keras.backend.clear_session()

            embed_v = feature_visualize(self.model, real_world_images)
            data = [[x, y] for (x, y) in zip(embed_v[..., 0], embed_v[..., 1])]
            table = wandb.Table(data=data, columns = ["x", "y"])
            wandb.log({f"RW_Batch{batch}_{method}_FeatureViz" : wandb.plot.scatter(table, "x", "y",
                                                                                   title=f"RW_Batch{batch}_TSNE")})
            #########Hierarchical clustering##########
            # Output index 1 of the extractor is used as the per-patch feature map.
            feature_map = feature_extractor(real_world_images)[1]
            # The reshape below assumes the patches form a square w_ x w_ grid.
            n_patch = feature_map.shape[1]
            w_ = ops.cast(ops.sqrt(ops.cast(n_patch, "float32")), "int32")
            clustering_output = ssl_module.H_clustering(n_clusters = 200)(feature_map)
            clustering_output = ops.reshape(clustering_output, [-1, w_, w_,1])
            data = []
            for i, sample_image in enumerate(real_world_images):
                cluster_plot = tf.convert_to_tensor(clustering_output[i])
                # NOTE(review): this min-max scaling divides by zero when every patch of an
                # image gets the same cluster id - confirm whether that can happen.
                cluster_plot = (cluster_plot - ops.min(cluster_plot)) / (ops.max(cluster_plot) - ops.min(cluster_plot))
                cluster_plot *= 255 ; cluster_plot = tf.image.grayscale_to_rgb(ops.cast(cluster_plot, "uint8"))
                cluster_plot = np.array(cluster_plot)
                cluster_plot = PILImage.fromarray(cluster_plot, mode="RGB")
                data.append([wandb.Image(sample_image), wandb.Image(cluster_plot)])

            table = wandb.Table(data=data, columns = ["Original_image", "AgglomerativeCluster"])
            wandb.log({f"Cluster_RW_Batch{batch}_{method}_result": table})
            tf.keras.backend.clear_session()
            del feature_extractor
        except Exception as e:
            print("Error code in callback : ", e)
        

In [None]:
class AttVizForHybridViT(keras.callbacks.Callback):
    """Log attention and feature-map visualizations from ``model.infer`` to W&B.

    On every 5000th training batch (including batch 0), ``model.infer`` is run on
    ``sample_img`` and ``medpix_images`` and one W&B table is logged per dataset.

    Args:
        run: Optional W&B run object; stored but not otherwise used here. It defaults
            to ``None`` so the callback can also be constructed without arguments.
    """
    # Keys of the `model.infer` output that hold per-image feature maps, with their column titles.
    FEATURE_KEYS = ['pca', 'tsne',  'agglomerative', 'kmeans']
    FEATURE_COLUMNS = ["PCA map", "TSNE map", "Agglomerative c. map", "Kmeans map"]

    def __init__(self, run = None):
        super().__init__()
        self.run = run

    def _build_row(self, image, outputs, idx, heads, caption = None):
        """Build one table row: image, [caption], merged map, feature maps, per-head maps."""
        row = [wandb.Image(image)]
        if caption is not None:
            row.append(caption)
        row.append(wandb.Image(outputs["head_merged_original"][idx]))
        row += [wandb.Image(outputs[name][idx]) for name in self.FEATURE_KEYS]
        row += [wandb.Image(outputs["merged_original"][idx, h]) for h in range(heads)]
        return row

    def on_train_batch_end(self, batch, logs=None):
        if batch % 5000 != 0:
            return
        print("Callback for visualization...", "\n")
        print("val dataset infer...\n")
        infer_output_val = self.model.infer(sample_img)
        print("MedPix dataset infer...\n")
        infer_output_medpix = self.model.infer(medpix_images)

        # 'merged_original' is indexed as [sample, head] below; axis 0 and 1 give the counts.
        heads = ops.shape(infer_output_val['merged_original'])[1]
        batch_size_val = ops.shape(infer_output_val['merged_original'])[0]
        batch_size_medpix = ops.shape(infer_output_medpix['merged_original'])[0]
        head_columns = [f"Head{k}" for k in range(heads)]

        # MedPix table (includes the caption column)
        col_medpix = ['Original Image', 'Captions', 'Headwise merged att map'] + self.FEATURE_COLUMNS + head_columns
        print("MedPix dataset uploading...\n")
        viz_medpix = [
            self._build_row(medpix_images[idx], infer_output_medpix, idx, heads,
                            caption = medpix_captions[idx])
            for idx in range(batch_size_medpix)
        ]
        tbl = wandb.Table(columns = col_medpix, data = viz_medpix)
        wandb.log({f"MedPix_data_viz_{batch}batch": tbl})

        # Sample-batch table (no captions)
        col_val = ['Original Image', 'Headwise merged att map'] + self.FEATURE_COLUMNS + head_columns
        print("Val dataset uploading...\n")
        viz_val = [
            self._build_row(sample_img[idx], infer_output_val, idx, heads)
            for idx in range(batch_size_val)
        ]
        tbl = wandb.Table(columns = col_val, data = viz_val)
        wandb.log({f"val_data_viz_{batch}batch": tbl})

        tf.keras.backend.clear_session()
        print("Done!\n")

-------------
- Special callback for QNCLR

In [None]:
class QRealWorldViz(keras.callbacks.Callback):
    """Log learnable-query attention maps for ``real_world_images`` to W&B.

    Runs on every 5000th training batch (including batch 0).

    Args:
        run: The W&B run object; stored on the instance but not otherwise used here.
    """
    def __init__(self, run):
        super().__init__()
        self.run = run

    def on_train_batch_end(self, batch, logs=None):
        if batch % 5000 != 0:
            return
        # The method name is used in the W&B keys; it was previously an undefined name here.
        method = self.model.get_env_config()["SSL_method"]
        feature_extractor = self.model.feature_extractor
        # Bind the float copy to a new local name: assigning to `real_world_images` inside
        # this method would shadow the global and raise UnboundLocalError.
        images_f32 = ops.cast(real_world_images, "float32")
        q_attention_weights, q_batch_merged = ssl_module.q_visualize(feature_extractor, images_f32, res,
                                              thresholding = False)
        # presumably (batch, N_Q, res, res, 3), per the original author's note
        q_attention_weights = np.array(q_attention_weights)
        # presumably (batch, res, res, 3), per the original author's note
        q_batch_merged = np.array(q_batch_merged)

        n_queries = q_attention_weights.shape[1]
        col = (["Original Image", "Original Label", "MergedMap"]
               + [f"LearnableQuery{q + 1}" for q in range(n_queries)])

        # One row per image: original, label, merged map, then one map per learnable query.
        visualize_data = []
        for sample_idx, query_maps in enumerate(q_attention_weights):
            row = [wandb.Image(real_world_images[sample_idx]),
                   labels_[sample_idx],
                   wandb.Image(q_batch_merged[sample_idx])]
            row += [wandb.Image(query_maps[q]) for q in range(n_queries)]
            visualize_data.append(row)

        tbl = wandb.Table(columns = col, data = visualize_data)
        if batch == 0:
            wandb.log({f"RW_ZeroBatch_{method}_result": tbl})
        else:
            wandb.log({f"RW_Batch{batch}_{method}_result": tbl})
        del feature_extractor, tbl
        tf.keras.backend.clear_session()

In [None]:
class RealWorldPatchViz(keras.callbacks.Callback):
    """Log ``ssl_module.pca_patch_viz`` outputs for ``real_world_images`` to W&B.

    Runs on every 5000th training batch (including batch 0).

    Args:
        run: The W&B run object; stored on the instance but not otherwise used here.
    """
    def __init__(self, run):
        super().__init__()
        self.run = run

    def on_train_batch_end(self, batch, logs=None):
        if batch % 5000 != 0:
            return
        #real_world_images = ops.cast(real_world_images, "uint8")
        feature_extractor = self.model.feature_extractor
        patch_heatmap, patch_merged_images = ssl_module.pca_patch_viz(feature_extractor, real_world_images)
        heatmaps = np.array(patch_heatmap)
        merged_images = np.array(patch_merged_images)

        col = ["Original image", "Original Label", "Merged image", "Encoded Patches"]
        print(col)

        # One row per image: original, label, merged overlay, encoded-patch heatmap.
        visualize_data = []
        for sample_idx, merged in enumerate(merged_images):
            visualize_data.append([
                wandb.Image(real_world_images[sample_idx]),
                labels_[sample_idx],
                wandb.Image(merged),
                wandb.Image(heatmaps[sample_idx]),
            ])

        tbl = wandb.Table(columns = col, data = visualize_data)
        log_key = "RW_ZeroBatch_Patch_result" if batch == 0 else f"RW_Batch{batch}_Patch_result"
        wandb.log({log_key: tbl})
        del feature_extractor, tbl
        tf.keras.backend.clear_session()

In [None]:
class SegViz(keras.callbacks.Callback):
    """Log ``model.get_segments`` overlays for a fixed image batch to W&B.

    Only active for the "UnsupSeg" and "MixedUnsupSeg" methods, on every 10000th
    training batch (including batch 0). Failures are printed and swallowed.

    Args:
        run: The W&B run object; stored on the instance but not otherwise used here.
        images: Image batch passed to ``model.get_segments``.
        labels: Optional per-image labels shown in the table.
    """
    def __init__(self, run, images, labels = None):
        super().__init__()
        self.run = run
        self.images = images
        self.labels = labels

    def on_train_batch_end(self, batch, logs=None):
        method = self.model.get_env_config()["SSL_method"]
        is_due = (batch % (10000) == 0) and (method in ["UnsupSeg", "MixedUnsupSeg"])
        if not is_due:
            return
        try:
            # The first return value of get_segments is not used here.
            _, superimposed_images = self.model.get_segments(self.images)
            col = ["Original Image", "Original Label", "Segmentation Result"]
            visualize_data = []
            for sample_idx, overlay in enumerate(superimposed_images):
                label_cell = "Label not provided." if self.labels is None else self.labels[sample_idx]
                visualize_data.append(
                    [wandb.Image(self.images[sample_idx]), label_cell, wandb.Image(overlay)]
                )
            tbl = wandb.Table(columns = col, data = visualize_data)
            log_key = f"Seg_ZeroBatch_{method}_result" if batch == 0 else f"Seg_MidEpoch_{method}_result"
            wandb.log({log_key: tbl})

            tf.keras.backend.clear_session()
        except Exception as e:
            print("Error code in Segmentation callback : ", e)

In [None]:
def run_exp(model, train_ds = train_ds, val_ds = val_ds, epochs = 10, note= None, exp_name = None):
    """Start a fresh W&B run and fit ``model`` with the notebook's visualization callbacks.

    Args:
        model: A compiled trainer exposing ``get_env_config()`` (with an "SSL_method"
            entry) and ``optimizer``.
        train_ds: Training dataset; ``train_steps`` batches make up one epoch.
        val_ds: Validation dataset, or ``None`` to train without validation.
        epochs: Number of epochs to fit.
        note: Notes attached to the W&B run; also used as the checkpoint file-name suffix.
        exp_name: Display name of the W&B run.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    # Close any run left over from a previous experiment; a failure here is harmless.
    try:
        wandb.finish()
    except Exception:
        pass

    wandb_config()
    configs = model.get_env_config()
    method = configs["SSL_method"]
    try:
        feature_extractor = model.feature_extractor
    except Exception:
        feature_extractor = model.get_full_model(res = res)

    if method in ['CLIP', "SigLIP", "SPARC"]:
        # Image-text models are called once on two example pairs before training.
        _ = model((example_images[:2], example_reports[:2]))

    # FLOPs are informational only; record a placeholder when they cannot be computed.
    try:
        feature_extractor_flops = get_flops(feature_extractor, [tf.random.normal([1,res,res,c])])
    except Exception:
        feature_extractor_flops = "Uncheck"
    del feature_extractor

    env_config = {"batch_size" : batch_size, "Patch size": patch_size,
                  "original resolution" : res, "local view resolution" : small_res,
                 "Training steps" : train_steps,
                 "Val steps" : val_steps,
                 "train cases" : train_cases,
                 "val cases" : val_cases,
                 "embed_dims" : embed_dims,
                 "Image resolution" : res,
                 "(Image) Encoder Flops(G)" : feature_extractor_flops,
                 "dtype" : keras.mixed_precision.dtype_policy(),
                  "Optimizer configs" : model.optimizer.get_config(),
                  "Multicrop N" : n_multicrop, "metaencoder depth" : depth, 'embedding dims' : embed_dims,
                 }
    configs.update(env_config)
    print(configs, "\n\n")

    run = wandb.init(project="RadImageNet",
                     entity="gongbungkim", config = configs, notes = note,
                    name = exp_name)
    wandb.run.log_code(".")
    pass_error = keras.callbacks.TerminateOnNaN()
    wb_callback = wandb.keras.WandbMetricsLogger(log_freq = 100)
    if isinstance(model, ssl_module.QNCLR):
        callbacks = [pass_error, wb_callback, ModelSaveCallback(f"RI_SSL_{method}", note),
                    QRealWorldViz(run),
                    TemperatureScheduler()]
    elif isinstance(model, HybridViT):
        # AttVizForHybridViT takes the run as its constructor argument, like the other callbacks.
        callbacks = [pass_error, wb_callback, AttVizForHybridViT(run)]
    else:
        callbacks = [pass_error, wb_callback, ModelSaveCallback(f"RI_SSL_{method}", note),
                     TrainingViz(run),
                    RealWorldViz(run), RealWorldPatchViz(run),
                    SegViz(run, images = sample_img),
                    SegViz(run, images = real_world_images, labels = labels_),
                    TemperatureScheduler()]

    fit_kwargs = dict(steps_per_epoch = train_steps,
                      epochs = epochs,
                      verbose = 1,
                      callbacks = callbacks)
    if val_ds is not None:
        fit_kwargs.update(validation_data = val_ds, validation_steps = val_steps)
    hist = model.fit(train_ds, **fit_kwargs)
    return hist

In [None]:
# Cosine decay preceded by a warm-up from 1e-6 to 2e-4 over the first 10000 steps.
# NOTE(review): `train_steps` is the per-epoch step count, so decay_steps becomes
# non-positive if train_steps <= 10000 - confirm against the dataset size.
cosine_decay = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate = 1e-6,
    decay_steps = train_steps - 10000,
    alpha=1e-5,
    name='CosineDecay',
    warmup_target=2e-4,
    warmup_steps=10000
)
# Cosine decay without warm-up, starting at 3e-5 and spanning `train_steps` steps.
simple_cos_decay = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate = 3e-5,
    decay_steps = train_steps,
    alpha=1e-5,
    name='SimpleCosineDecay',
    #warmup_target=1e-3,
    #warmup_steps=train_steps - int(0.3*train_steps)
)
# Staircase exponential decay: the rate is multiplied by 0.75 every 20000 steps.
# This is the default learning rate of ssl_train.
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate = 2e-4,
    decay_steps=20000,
    decay_rate=0.75,
    staircase=True)

In [None]:
def ssl_train(module, feature_extractor, learning_rate = lr_schedule,
              embed_dims = embed_dims, multiview = True, gradient_accumulation = None, use_ema = False,
             note = "", name = "",
             apply_barlow = False, apply_simclr = False, epochs = 100):
    """Wrap `feature_extractor` in an SSL trainer, compile it with Adam and train it.

    Args:
        module: Trainer class/factory. It is called as
            `module(feature_extractor, embed_dims=..., multiview=..., apply_barlow=...,
            apply_simclr=...)`; if that call raises, it is retried without the
            `apply_barlow` / `apply_simclr` keywords.
        feature_extractor: Backbone handed to `module` as its first argument.
        learning_rate: Learning rate or schedule for Adam. The default is the
            `lr_schedule` object that existed when this function was defined.
        embed_dims: Forwarded to `module`. The default is the notebook-level
            `embed_dims` value at definition time.
        multiview: Forwarded to `module`.
        gradient_accumulation: Passed to Adam as `gradient_accumulation_steps`.
        use_ema: Passed to Adam as `use_ema`.
        note: Forwarded to `run_exp` as `note`.
        name: Forwarded to `run_exp` as `exp_name`.
        apply_barlow: Forwarded to `module` on the first construction attempt.
        apply_simclr: Forwarded to `module` on the first construction attempt.
        epochs: Number of epochs forwarded to `run_exp` (default 100, the value
            that used to be hard-coded).

    Returns:
        Whatever `run_exp` returns for this run.
    """
    try:
        ssl_trainer = module(feature_extractor, embed_dims = embed_dims, multiview = multiview,
                            apply_barlow = apply_barlow, apply_simclr = apply_simclr)
    except Exception as e:
        # Deliberate best-effort fallback: report the failure and retry with the
        # reduced keyword set for trainers that do not accept the barlow/simclr flags.
        print("Error : ",e)
        ssl_trainer = module(feature_extractor, embed_dims = embed_dims, multiview = multiview)
    ssl_trainer.compile(optimizer = keras.optimizers.Adam(learning_rate = learning_rate,
                                                         clipnorm = 0.5,
                                                         gradient_accumulation_steps = gradient_accumulation,
                                                         use_ema = use_ema),
                        jit_compile = False
                      )

    # Trains on the notebook-level `train_ds_multiview`; `None` is passed as the
    # third positional argument (presumably the validation dataset - confirm
    # against `run_exp`'s signature).
    return run_exp(ssl_trainer, train_ds_multiview, None, epochs = epochs,
                   note = note, exp_name = name)

# Hybrid ViT training

In [None]:
def simple_train_map(image, lab):
    """Return only `image` from an (image, label) pair; `lab` is ignored.

    The image is passed through unchanged (no grayscale -> RGB conversion).
    """
    return image


# Image-only view of `train_ds` for the Hybrid ViT run, prefetched with AUTOTUNE.
train_ds_vit = train_ds.map(simple_train_map)
train_ds_vit = train_ds_vit.prefetch(tf.data.AUTOTUNE)

In [None]:
# Build, compile and train a HybridViT inside the distribution-strategy scope,
# logging the run to Weights & Biases.
# NOTE(review): the names assigned here (e.g. `heads`, `model`, `run`, `callbacks`)
# are notebook globals and stay bound after this cell has run.
with strategy.scope():
    # Model hyper-parameters for this run.
    reg_token = 0
    vit_layers = 8
    cutoff_layer = 4
    embed_dims_vit = 768
    heads = 12
    # Patch size actually given to HybridViT. It gets its own name so the
    # notebook-level `patch_size` is left untouched, and so the W&B config below
    # records the value this model really uses instead of the global one.
    patch_size_vit = 32
    model = HybridViT(image_size = res,
                      num_layers = vit_layers,
                      split_layer = cutoff_layer,
                      num_heads = heads,
                      embed_dims = embed_dims_vit,
                      patch_size = patch_size_vit,
                      n_register = reg_token)

    # Run the model once on sample images and call both heads on a dummy input
    # (presumably so every sub-layer is built before summary()/compile()).
    _, encoded_patches, w = model(tf.image.rgb_to_grayscale(real_world_images))
    print(encoded_patches.shape, w.shape)
    dummy_input = tf.zeros((1, embed_dims_vit))
    _ = model.mim_head(dummy_input)
    _ = model.nnclr_proj(dummy_input)

    opt = keras.optimizers.SGD(learning_rate = simple_cos_decay)

    # Capture the textual model summary so it can be attached to the W&B run.
    model_summary = []
    model.summary(print_fn=lambda x: model_summary.append(x))
    summary_str = "\n".join(model_summary)

    model.compile(optimizer = opt)

    wandb_config()
    env_config = {"batch_size" : batch_size,
                  "Patch size": patch_size_vit,
                  "Training steps" : train_steps,
                  "Val steps" : val_steps,
                  "train cases" : train_cases,
                  "val cases" : val_cases,
                  "embed_dims" : embed_dims_vit,
                  "Image resolution" : res,
                  "dtype" : keras.mixed_precision.dtype_policy(),
                  "Optimizer configs" : model.optimizer.get_config(),
                  "metaencoder depth" : vit_layers,
                  'embedding dims' : embed_dims_vit,
                  }
    run = wandb.init(project="RadImageNet",
                     entity="gongbungkim", config = env_config,
                     name = f'HybridVit_depth{vit_layers}_{heads}heads_cutoff{cutoff_layer}',
                     notes = f"{reg_token}reg, Dynamic masking, Variable Resolution with denseCL, gray input, res{res}")
    run.summary["model_summary"] = summary_str

    # Callbacks: stop on NaN loss, log metrics every 100 batches, attention
    # visualisation, and model saving under the run's name.
    pass_error = keras.callbacks.TerminateOnNaN()
    wb_callback = wandb.keras.WandbMetricsLogger(log_freq = 100)
    callbacks = [pass_error, wb_callback, AttVizForHybridViT(run),
                 SimpleModelSaveCallback(f"HybridVit_depth{vit_layers}_{heads}heads_cutoff{cutoff_layer}")]

    model.fit(train_ds_vit, epochs = 1, steps_per_epoch = train_steps,
              verbose = 1, callbacks = callbacks)

# Other SSL experiments (each cell below is enabled by its flag)

In [None]:
# Run switches for the experiment cells that follow. Each cell is guarded by
# `if <flag>:`, so 0 skips the cell and any truthy value runs it.
ibot = other = mim = qnclr = sobel_other = 0

In [None]:
if mim:
    # MixedMIM run: an attention metaformer that also returns patch tokens,
    # trained on the edge-masked dataset with AdamW + EMA.
    model_ = 'attention'
    vanilla_model = ssl_module.get_metaformer(
        model_,
        res = res,
        embed_dims = embed_dims,
        att_depth = depth,
        att_heads = heads,
        att_dims = att_dims,
        grayscale = grayscale,
        patch_size = patch_size,
        register_tokens = registers,
        pretrained_encoder = pretrained_encoder,
        return_patches = True,
    )
    vanilla_model.summary()

    ssl_trainer = ssl_module.MixedMIM(vanilla_model, grayscale = grayscale, patch_size = patch_size)
    ssl_trainer.compile(
        optimizer = keras.optimizers.AdamW(
            learning_rate = cosine_decay,
            clipnorm = 1.0,
            use_ema = True,
        ),
        jit_compile = False,
    )

    # The experiment name embeds the method name reported by the trainer itself.
    configs = ssl_trainer.get_env_config()
    method = configs["SSL_method"]
    run_exp(
        ssl_trainer, train_ds_edge_masked, None,
        note = pretrained_note+"_"+model_,
        exp_name = f"SobelMerging_Patch{patch_size}_{method}_{model_}",
    )

In [None]:
if qnclr:
    # QNCLR run: encoder-decoder model with a trainable encoder, trained on the
    # masked dataset with AdamW and 32-step gradient accumulation.
    vanilla_model = ssl_module.get_encdec_model(
        pretrained_encoder,
        res = res,
        att_dims = embed_dims,
        q_size = 8,
        encoder_trainable = True,
    )

    ssl_trainer = ssl_module.QNCLR(vanilla_model, embed_dims = embed_dims, t = 0.05)
    # Set on the trainer after construction rather than through its constructor.
    ssl_trainer.use_mim = True
    ssl_trainer.compile(
        optimizer = keras.optimizers.AdamW(
            learning_rate = 5e-5,
            clipnorm = 1.0,
            gradient_accumulation_steps = 32,
        ),
        jit_compile = False,
    )

    method = "Q_NNCLR"
    run_exp(
        ssl_trainer, train_ds_masked, None,
        note = pretrained_note,
        exp_name = f"{method}_StrongAug",
    )

In [None]:
if sobel_other:
    # SNCLR + SwAV run on an attention metaformer that also returns patch tokens.
    # Alternatives previously tried in this cell: other pretrained encoders
    # (regnet / ViT), a dual-encoder metaformer, and the DINO_MIM, NCLR (nnclr + MIM)
    # and Moco trainers.
    model_ = 'attention'
    vanilla_model = ssl_module.get_metaformer(
        model_,
        res = res,
        embed_dims = embed_dims,
        att_depth = depth,
        att_heads = heads,
        att_dims = att_dims,
        grayscale = grayscale,
        patch_size = patch_size,
        register_tokens = registers,
        pretrained_encoder = pretrained_encoder,
        return_patches = True,
    )

    # Probe the model once on sample images; element [1] of its output is taken
    # as the feature map.
    feature_map = vanilla_model(real_world_images)[1]
    n_patch = feature_map.shape[1]
    # int32 square root of the patch count (presumably the side of a square
    # patch grid - TODO confirm).
    w_ = ops.cast(ops.sqrt(ops.cast(n_patch, "float32")), "int32")
    # NOTE: this rebinds the notebook-level `embed_dims` to the model's last
    # feature dimension.
    embed_dims = feature_map.shape[-1]

    ssl_trainer = ssl_module.SNCLR_SwAV(
        vanilla_model,
        swav_weight = 10,
        snclr_weight = 1,
        q_size = 2**15,
    )
    ssl_trainer.compile(
        optimizer = keras.optimizers.SGD(
            learning_rate = 1e-4,
            momentum = 0.9,
            weight_decay = 0.0001,
        ),
        jit_compile = False,
    )

    # The experiment name embeds the method name reported by the trainer itself.
    configs = ssl_trainer.get_env_config()
    method = configs["SSL_method"]

    run_exp(
        ssl_trainer, train_ds_masked, None,
        note = pretrained_note+"_"+model_,
        exp_name = f"{method}_{model_}_StrongAug_SimplerNCLR",
    )


In [None]:
if ibot:
    # iBOT run on the multi-view dataset for 100 epochs, using AdamW with EMA
    # and the staircase exponential schedule.
    ssl_trainer = ssl_module.iBOT(
        att_depth = depth,
        att_dims = att_dims,
        att_heads = heads,
        embed_dims = 2048,
        patch_size = patch_size,
        multiview = True,
        apply_simclr = False,
        grayscale = True,
    )
    ssl_trainer.compile(
        optimizer = keras.optimizers.AdamW(
            learning_rate = lr_schedule,
            clipnorm = 1.0,
            use_ema = True,
        ),
        jit_compile = False,
    )
    run_exp(
        ssl_trainer, train_ds_multiview, None,
        epochs = 100,
        note = "+ NEW aug, New Patching",
        exp_name = "iBOT_VanillaViT",
    )

In [None]:
if other:
    # DINO run on a gMLP metaformer with 4 register tokens, launched via `ssl_train`.
    model_ = 'gMLP'

    # The `grayscale` setting must agree with whether a pretrained encoder is
    # used; the run note records which of the two cases applies.
    if pretrained_encoder is not None:
        note = f"With pretrained {pretrained_encoder.name}"
        assert grayscale is False, "If using pretrained network, make sure [grayscale = False]"
    else:
        note = "From Scratch"
        assert grayscale is True, "If building from scratch, make sure [grayscale = True]"

    vanilla_model = ssl_module.get_metaformer(
        model_,
        res = res,
        embed_dims = embed_dims,
        att_depth = depth,
        att_heads = heads,
        att_dims = att_dims,
        grayscale = grayscale,
        patch_size = patch_size,
        register_tokens = 4,
        pretrained_encoder = pretrained_encoder,
    )
    ssl_train(
        ssl_module.DINO,
        vanilla_model,
        learning_rate = lr_schedule,
        multiview = False,
        gradient_accumulation = 32,
        note = note + " / 2-view",
        name = f"DINO_{model_}_reg",
    )