In [None]:
import pandas as pd
import numpy as np
import random
import pydicom

from sklearn.manifold import TSNE
import re
import matplotlib.pyplot as plt
import matplotlib.cm as cm
seed = 2024

import warnings
warnings.filterwarnings("ignore")

# ML tools 

import tensorflow as tf
import keras #; keras.config.set_dtype_policy("mixed_float16")
import keras_cv
import keras_nlp

import cv2
from skimage.io import imread
keras.utils.set_random_seed(seed)
import tensorflow_io as tfio
from kaggle_datasets import KaggleDatasets
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import tensorflow_decision_forests as tfdf

print(f"Tensorflow version : {tf.__version__}")
# Some Keras builds may not expose __version__; skip that line instead of failing,
# but do not swallow unrelated errors (a bare except also eats KeyboardInterrupt).
try:
    print(f"Keras version : {keras.__version__}")
except AttributeError:
    pass

from keras import Input, Model, ops
from keras.models import load_model

from keras.layers import Conv2D, DepthwiseConv2D, Dense, Activation, BatchNormalization, LayerNormalization, MultiHeadAttention, Embedding, Subtract, Add, Multiply, GlobalAveragePooling2D, GlobalAveragePooling1D, LayerNormalization
from keras.utils import load_img, img_to_array
from keras.applications import *
import os
# NOTE(review): `tensorflow` is imported earlier in this cell, before this line runs.
# TensorFlow presumably reads TF_CPP_MIN_LOG_LEVEL when its native runtime
# initializes, so this assignment likely does not silence C++ logs here.
# Move it above the first `import tensorflow` if that is the intent. TODO confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.model_selection import train_test_split
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
from tqdm.notebook import tqdm
import wandb
from wandb.keras import WandbCallback, WandbModelCheckpoint, WandbMetricsLogger
def wandb_config():
    """Log in to Weights & Biases with the ``wandb_key`` Kaggle secret.

    Only the W&B key is needed, so only that secret is fetched. The login goes
    through ``wandb.login`` in-process rather than an IPython ``!wandb login``
    shell call. That keeps the key off shell command lines and keeps this
    function valid Python. If the secret is missing, ``get_secret`` raises
    instead of the error being silently swallowed.
    """
    from kaggle_secrets import UserSecretsClient

    user_secrets = UserSecretsClient()
    wandb.login(key=user_secrets.get_secret("wandb_key"))
    

# Global hyper-parameters shared by the data pipeline, models and run configs.
res = 256          # input image resolution (height = width)
small_res = 64     # logged as "local view resolution" in the run config
batch_size = 32    # per-replica batch size; multiplied by the replica count later
embed_dims = 768   # embedding width passed to the feature extractors / SSL trainers
n_multicrop = 6    # number of views produced by the multi-crop map function

def auto_select_accelerator():
    """Choose a tf.distribute strategy for the attached hardware.

    Tries to connect to a TPU first. If that raises ValueError (presumably the
    resolver's "no TPU found" signal), it falls back to MirroredStrategy, which
    covers single/multi-GPU and CPU.

    Returns:
        (tpu, strategy): the TPU cluster resolver (or False when no TPU is used)
        and the distribution strategy to build and compile models under.
    """
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        # tf.distribute.experimental.TPUStrategy is a deprecated alias of this class.
        strategy = tf.distribute.TPUStrategy(tpu)
        print("Running on TPU:", tpu.master())
    except ValueError:
        tpu = False
        strategy = tf.distribute.MirroredStrategy()  # for GPU or multi-GPU machines
    print(f"Running on {strategy.num_replicas_in_sync} replicas")

    return tpu, strategy

tpu, strategy = auto_select_accelerator()
# From here on `batch_size` is the global batch: per-replica size x replica count.
batch_size *= strategy.num_replicas_in_sync
print('batch size', batch_size)

In [None]:
import ssl_module
from ssl_module import get_map_fn, get_gcvit_configs, get_flops, att_visualize, get_full_model, AttentionPooling, BarlowModel, VICRegModel, Moco, SimSiam, CLIP, SigLIP
import nas_ftp_module
from nas_ftp_module import upload_file, download_file

# Data import (with a generator)
- Goal: inject chest X-ray (CXR) prior knowledge into the feature-map generator via SwAV
- Do not use the bounding-box information; use external data as well

In [None]:
metainfo_dir = "/kaggle/input/chexdet-image-and-annotations/ChestXDet_Metainformations/ChestX-Det-Dataset-main"
train_det_dir = "/kaggle/input/chexdet-image-and-annotations/train_data/train"
val_det_dir = "/kaggle/input/chexdet-image-and-annotations/test_data/test"

# ChestX-Det: keep only the image paths (the annotations are intentionally unused).
df_det_train = pd.read_json(os.path.join(metainfo_dir, "ChestX_Det_train.json"))
df_det_train["file_name"] = [os.path.join(train_det_dir, fname) for fname in df_det_train.file_name.values]
df_det_train = df_det_train.loc[:, ["file_name"]]

df_det_val = pd.read_json(os.path.join(metainfo_dir, "ChestX_Det_test.json"))
df_det_val["file_name"] = [os.path.join(val_det_dir, fname) for fname in df_det_val.file_name.values]
df_val_cxr = df_det_val.loc[:, ["file_name"]]

# VinBigData external CXRs (training only). os.listdir order is arbitrary, so sort
# the names to keep row order, and therefore the seeded shuffles downstream, reproducible.
ext_dir = "/kaggle/input/vinbigdata-chest-xray-original-png/train"
df_ext = pd.DataFrame({"file_name": [os.path.join(ext_dir, fname) for fname in sorted(os.listdir(ext_dir))]})
df_train_cxr = pd.concat([df_det_train, df_ext], axis=0, ignore_index=True)
print(f"Total training cases for CXR : {len(df_train_cxr)} cases, Validation case : {len(df_val_cxr)} case")

> Build the DeepLesion metadata DataFrame

In [None]:
base_img_dir = '/kaggle/input/nih-deeplesion-subset/minideeplesion'
ct_fname = []
for dirname, _, filenames in tqdm(os.walk(base_img_dir)):
    ct_fname.extend(os.path.join(dirname, filename) for filename in filenames)
# os.walk order depends on the filesystem; sort so the seeded split below is
# reproducible across machines and runs.
ct_fname.sort()

df_ct_whole = pd.DataFrame({"file_name": ct_fname})

# An integer test_size holds out that many slices (134) for validation.
df_ct_train, df_ct_val = train_test_split(df_ct_whole,
                                          test_size=134,
                                          random_state=seed)
print(f"Total training cases of Chest/Abdomen CT : {len(df_ct_train)} cases, Validation case : {len(df_ct_val)} case")

> Import the RSNA ICH dataset metadata DataFrame

In [None]:
dicom_dir = "/kaggle/input/rsna-intracranial-hemorrhage-detection/rsna-intracranial-hemorrhage-detection/stage_2_train"
df_train_brainct = pd.read_csv("/kaggle/input/rsna-ich-detection-metadata/df_train_split.csv")
df_val_brainct = pd.read_csv("/kaggle/input/rsna-ich-detection-metadata/df_val_splt.csv").head(300)


def _attach_dicom_paths(frame):
    """Return `frame` with a `file_name` column: dicom_dir/<SOPInstanceUID>.dcm per row."""
    paths = [os.path.join(dicom_dir, uid + ".dcm") for uid in frame['SOPInstanceUID']]
    return frame.assign(file_name=paths)


df_train_brainct = _attach_dicom_paths(df_train_brainct)
df_val_brainct = _attach_dicom_paths(df_val_brainct)

print(f"Total training cases of Brain, NonCE CT : {len(df_train_brainct)} cases, Validation case : {len(df_val_brainct)} case")

In [None]:
# One path list per split across all DataFrame-backed sources. join="inner" keeps
# only the columns shared by every frame, i.e. `file_name`. ignore_index gives a
# unique 0..n-1 index instead of repeating each source's labels.
df_train = pd.concat([df_ct_train, df_train_cxr, df_train_brainct], axis=0, join='inner', ignore_index=True)
df_val = pd.concat([df_val_cxr, df_ct_val, df_val_brainct], axis=0, join='inner', ignore_index=True)

df_train.to_csv("df_train_ER_SSL.csv", index=False)
df_val.to_csv("df_val_ER_SSL.csv", index=False)

# Seeded so the preview is reproducible on re-run.
df_train.sample(10, random_state=seed)

# Building the data loader in Keras 3 style
- Merge two kinds of datasets: original image files listed in a pandas DataFrame, and TFRecords
    - using this [tf dataset method](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#sample_from_datasets)
    - [reference code](https://www.kaggle.com/code/calebeverett/combining-dataset-examples#Sample)

# Original files listed in a DataFrame
- loaded with `keras.utils.Sequence`

In [None]:
class ImageDataLoader(keras.utils.Sequence):
    """Yield batches of RGB uint8 images read from paths stored in a DataFrame.

    File handling is dispatched on the path:
      * ``*.dcm``: DICOM. Pixel data is rescaled with RescaleSlope/Intercept,
        clipped to the header window (first value if multi-valued), resized,
        replicated to 3 channels and min-max scaled to 0-255.
      * paths with a ``minideeplesion`` directory component: read with
        ``skimage.io.imread``, shifted by -32768, resized, clipped to
        [-750, 700], min-max scaled to 0-255 and replicated to 3 channels.
      * anything else: ``keras.utils.load_img`` at ``res x res``.

    ``__getitem__`` returns a *list* of ``batch_size`` uint8 tensors of shape
    (res, res, 3). ``__len__`` drops the trailing partial batch.

    Args:
        dataframe: DataFrame holding the image paths.
        x_col: name of the path column.
        res: output height and width in pixels.
        batch_size: images per batch.
        y_col: stored but unused.
        shuffle: shuffle sample order on construction and in ``on_epoch_end``.
    """

    def __init__(self, dataframe, x_col, res, batch_size, y_col = None, shuffle = True):
        # Keras 3's PyDataset (which `keras.utils.Sequence` aliases) expects its
        # __init__ to run. The call is a harmless no-op on older Keras.
        super().__init__()
        self.df = dataframe
        self.x_col = x_col
        self.y_col = y_col
        self.res = res
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def dicom_to_tensor(self, dicom_path):
        """Read a DICOM file and return a windowed (res, res, 3) float tensor in [0, 255]."""
        dataset = pydicom.dcmread(dicom_path)
        # Stored pixel values -> HU via the header's linear rescale.
        hu = dataset.RescaleSlope * np.asarray(dataset.pixel_array) + dataset.RescaleIntercept

        def _first(value):
            # Window attributes may be multi-valued; use the first window.
            # Checked per attribute, since center and width need not both be multi-valued.
            return value[0] if isinstance(value, pydicom.multival.MultiValue) else value

        center = float(_first(dataset.WindowCenter))
        width = float(_first(dataset.WindowWidth))
        lbound, ubound = center - 0.5 * width, center + 0.5 * width
        hu = np.clip(hu, lbound, ubound).astype(np.float32)

        tensor = tf.image.resize(hu[:, :, np.newaxis], [self.res, self.res], antialias=True)
        tensor = tf.image.grayscale_to_rgb(tensor)  # the DICOM slice is single-channel
        # Min-max scale the windowed values to the 0-255 range.
        t_min, t_max = tf.reduce_min(tensor), tf.reduce_max(tensor)
        return (tensor - t_min) / (t_max - t_min + 1e-4) * 255.0

    def image_to_tensor(self, path):
        """Load one path into a (res, res, 3) image with values in the 0-255 range."""
        if path.split(".")[-1] == "dcm":
            return self.dicom_to_tensor(path)

        if "minideeplesion" in str(path).split("/"):
            # Slices are stored with a +32768 offset (presumably HU + 32768); undo it.
            image = imread(path).astype(np.float32) - 32768
            image = tf.image.resize(image[..., np.newaxis], [self.res, self.res], antialias=True)
            image = tf.clip_by_value(image, -750.0, 700.0)
            image = (image - tf.reduce_min(image)) / (tf.reduce_max(image) - tf.reduce_min(image) + 1e-3)
            image = image * 255.0
        else:
            image = img_to_array(load_img(path, target_size=[self.res, self.res]))

        if tf.shape(image)[-1] == 1:  # grayscale -> RGB
            image = tf.image.grayscale_to_rgb(image)
        return image

    def on_epoch_end(self):
        """Reset (and optionally shuffle) the sample order."""
        self.indexes = np.arange(len(self.df))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        # Number of *full* batches; the remainder is dropped.
        return int(np.floor(len(self.df) / self.batch_size))

    def __iter__(self):
        """Yield each full batch exactly once, in the current index order.

        ``tf.data.Dataset.from_generator`` iterates this object. Without an
        explicit ``__iter__`` (if the Keras base class does not supply one),
        Python falls back to calling ``__getitem__(0), (1), ...`` until
        ``IndexError``. This class never raises that (out-of-range indices give
        an empty list), so that iteration would not terminate.
        """
        for index in range(len(self)):
            yield self[index]

    def __data_generation(self, img_name):
        """Load every path in `img_name`; return the list of uint8 image tensors to feed."""
        return [tf.cast(tf.convert_to_tensor(self.image_to_tensor(fname)), tf.uint8)
                for fname in img_name]

    def __getitem__(self, index):
        batch_idx = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
        # Positional lookup on the path column (same values as df.iloc[k].loc[x_col]).
        img_name = self.df[self.x_col].iloc[batch_idx].tolist()
        return self.__data_generation(img_name)
    
def _make_image_loader(frame):
    """ImageDataLoader over `frame`'s file_name column using the global res / batch_size."""
    return ImageDataLoader(frame, x_col="file_name", res=res, batch_size=batch_size)


def get_train_gen():
    """Fresh (reshuffled) loader over the training DataFrame; used by from_generator."""
    return _make_image_loader(df_train)


def get_val_gen():
    """Fresh (reshuffled) loader over the validation DataFrame; used by from_generator."""
    return _make_image_loader(df_val)


# Each loader item is already one full batch, so the element spec is the whole batch.
# `output_signature` replaces the deprecated `output_types` / `output_shapes` pair.
# NOTE(review): an exception inside the loader ends the current generator pass;
# ignore_errors() + repeat() then start a fresh (reshuffled) pass.
_image_batch_spec = tf.TensorSpec(shape=(batch_size, res, res, 3), dtype=tf.uint8)
train_ds = (tf.data.Dataset.from_generator(get_train_gen, output_signature=_image_batch_spec)
            .ignore_errors().prefetch(tf.data.AUTOTUNE).repeat())
val_ds = (tf.data.Dataset.from_generator(get_val_gen, output_signature=_image_batch_spec)
          .ignore_errors().prefetch(tf.data.AUTOTUNE).repeat())

# Spinal X-ray dataset in TFrecords

In [None]:
#spinal xray dataset

# VinDr-SpineXR finding classes, keyed by integer class id.
label_map = dict(enumerate([
    'Disc space narrowing',
    'Foraminal stenosis',
    'No finding',
    'Osteophytes',
    'Other lesions',
    'Spondylolysthesis',
    'Surgical implant',
    'Vertebral collapse',
]))

labels = sorted(label_map.values())
n_labels = len(label_map)


def deserialize_example(serialized_string, train = True):
    """Parse one SpineXR TFRecord example into ``(image, label)``.

    Args:
        serialized_string: serialized ``tf.train.Example`` with string features
            ``image`` (a serialized float32 tensor, read with
            ``tf.io.parse_tensor``) and ``label`` (raw int32 bytes).
        train: unused; kept for interface compatibility.

    Returns:
        image: uint8 tensor of shape (res, res, 3), min-max scaled to 0-255.
        label: int32 tensor of shape (n_labels,).
    """
    image_feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.string),
    }
    parsed_record = tf.io.parse_single_example(serialized_string, image_feature_description)

    image = tf.io.parse_tensor(parsed_record["image"], tf.float32)
    # The record files are named *_384, while `res` is configurable (256 here).
    # A bare reshape to (res, res, 3) fails whenever the stored size differs, so
    # resize instead. Assumes HWC storage with 3 channels. TODO confirm.
    image = tf.ensure_shape(image, [None, None, 3])
    image = tf.image.resize(image, [res, res], antialias=True)
    image = (image - tf.reduce_min(image)) / (tf.reduce_max(image) - tf.reduce_min(image) + 1e-4)
    image = tf.cast(image * 255.0, tf.uint8)

    label = tf.io.decode_raw(parsed_record['label'], tf.int32)
    label = ops.reshape(label, [n_labels,])
    return image, label
    
def load_dataset(filenames):
    """Stream GZIP-compressed TFRecords and decode them with `deserialize_example`.

    Order is deliberately non-deterministic: records are consumed as soon as
    they stream in, rather than in file order, for throughput.
    """
    options = tf.data.Options()
    # `deterministic` supersedes the deprecated `experimental_deterministic`.
    options.deterministic = False
    dataset = tf.data.TFRecordDataset(filenames, compression_type="GZIP",
                                      num_parallel_reads=tf.data.AUTOTUNE)  # interleaves reads across files
    dataset = dataset.with_options(options)
    return dataset.map(deserialize_example, num_parallel_calls=tf.data.AUTOTUNE)

_spine_dir = "/kaggle/input/tfrecords-vindr-spinexr-tfrecords"
original_train_ds = load_dataset(os.path.join(_spine_dir, "train_gzip_384.tfrecord"))
original_val_ds = load_dataset(os.path.join(_spine_dir, "val_gzip_384.tfrecord"))


def _spine_images_only(records):
    """Batch (image, label) records, drop the labels, and stream them endlessly."""
    return (records
            .batch(batch_size, drop_remainder=True)
            .map(lambda image, label: image)
            .ignore_errors()
            .repeat()
            .prefetch(tf.data.AUTOTUNE))


spine_train_ds = _spine_images_only(original_train_ds)
spine_val_ds = _spine_images_only(original_val_ds)

# Merging the two datasets

In [None]:
def _interleave_half_and_half(loader_ds, spine_ds):
    """Draw single images 50/50 from both sources and re-batch them into an endless stream."""
    mixed = tf.data.Dataset.sample_from_datasets([loader_ds.unbatch(), spine_ds.unbatch()],
                                                 weights=[0.5, 0.5])
    return mixed.batch(batch_size).ignore_errors().repeat().prefetch(tf.data.AUTOTUNE)


merged_train_ds = _interleave_half_and_half(train_ds, spine_train_ds)
merged_val_ds = _interleave_half_and_half(val_ds, spine_val_ds)

> Calculate train and validation steps per epoch

- Spinal dataset case counts: [see here](https://www.kaggle.com/code/khsmdjjys/self-supervised-learning-with-tfrecord)

In [None]:
# Sources 1: DataFrame-backed images. Source 2: SpineXR TFRecords, whose sizes
# are hard-coded (counts from the linked SpineXR notebook).
train_1, val_1 = len(df_train), len(df_val)
train_2, val_2 = 8389, 2077

total_train, total_val = train_1 + train_2, val_1 + val_2
train_steps = total_train // batch_size
val_steps = total_val // batch_size

print(f"Total Train cases, Val cases : {train_1 + train_2, val_1 + val_2}")

# Applying SSL functions
- A. Basic function: returns 2 global views (g = 2)
- B. SwAV-like strategy: returns 2 global views + additional local views (l = 4)
- Use `get_map_fn` from the SSL module:
> parameters of `get_map_fn`:
    - res = image resolution
    - input_type = "without_label" or "supervised"
    - output_type = "ssl" or "ssl_with_label"
    - n_view = number of views. When n_view >= 3, the first and second images carry relatively global information, and the remaining images are local crops (1/2 of the width/height).

In [None]:
multiview_fn = get_map_fn(res=res, input_type="without_label", output_type="ssl",
                          n_view=n_multicrop)
two_view_fn = get_map_fn(res=res, input_type="without_label", output_type="ssl",
                         n_view=2)


def _to_views(merged_ds, view_fn):
    """Map every image of `merged_ds` to its augmented SSL views and re-batch."""
    return (merged_ds.unbatch()
            .map(view_fn, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(batch_size)
            .ignore_errors()
            .prefetch(tf.data.AUTOTUNE))


train_ds = _to_views(merged_train_ds, two_view_fn)
val_ds = _to_views(merged_val_ds, two_view_fn)
train_ds_multiview = _to_views(merged_train_ds, multiview_fn)
val_ds_multiview = _to_views(merged_val_ds, multiview_fn)

In [None]:
# One validation batch of multi-view crops. `sample_img` is its first view;
# `test_set` keeps the first two samples of every view (used to call models once).
images = next(iter(val_ds_multiview.take(1)))
sample_img = images[0]
test_set = tuple(view[:2] for view in images)

> Visually inspect (curate) the augmented views

In [None]:
# Set to True to plot the augmented views in the next cells (one figure per sample).
view_curation = False

In [None]:
def _show_two_view_batch(ds, title):
    """Plot each (first view, second view) pair of one batch from a two-view SSL dataset.

    Args:
        ds: dataset yielding ``(views_a, views_b)`` batches.
        title: heading printed before the figures.
    """
    print(title)
    for originals, augs in ds.take(1):
        for origin, aug in zip(originals, augs):
            fig, axes = plt.subplots(1, 2, figsize=(16, 8))
            for ax, img, name in zip(axes.flatten(), (origin, aug),
                                     ("ORIGINAL", "GLOBAL VIEW AUGMENTATION")):
                ax.imshow(ops.cast(img, "uint8"))
                ax.set_title(name)
            plt.show()
            plt.close(fig)  # one figure per sample; free it once shown


if view_curation:
    _show_two_view_batch(train_ds, "Training dataset Curation, with Basic SSL Fn (2 global views)")
    _show_two_view_batch(val_ds, "Validation dataset Curation, with Basic SSL Fn (2 global views)")

In [None]:
if view_curation:
    for multiset in train_ds_multiview.take(1):
        global_views = multiset[:2]
        local_views = multiset[2:]
    # The number of local views follows n_multicrop; lay them out two per row.
    n_local = len(local_views)
    n_rows = -(-n_local // 2)  # ceil(n_local / 2)
    for idx in tqdm(range(len(global_views[0]))):
        print(f"=================\nBatch No.{idx}\n===================")
        print("Global Views")
        fig, axes = plt.subplots(1, 2, figsize=(20, 10))
        for ax, view in zip(axes.flatten(), global_views):
            ax.imshow(ops.cast(ops.squeeze(view[idx]), "uint8"))
        plt.show()
        plt.close(fig)

        print("=================\nLocal Views\n===================")
        if n_local == 0:
            continue
        fig, axes = plt.subplots(n_rows, 2, figsize=(16, 8 * n_rows), squeeze=False)
        axes = axes.flatten()
        for ax, view in zip(axes, local_views):
            ax.imshow(ops.cast(ops.squeeze(view[idx]), "uint8"))
        for ax in axes[n_local:]:
            ax.axis("off")  # hide the unused slot when n_local is odd
        plt.show()
        plt.close(fig)

# SSL experiment : Information-Maximization

> Callbacks for model saving and attention-map visualization

In [None]:
class ModelSaveCallback(keras.callbacks.Callback):
    """Save `.keras` checkpoints under /kaggle/working/model_save and upload some to the NAS.

    * Every epoch end: save ``{exp_name}_{model.name}_keras_v3_Epoch{epoch}.keras``,
      and upload it when ``(epoch + 1) % epoch_upload_every == 0``.
    * Every ``batch_save_every``-th batch index (excluding 0): save a
      ``..._Batch{batch}.keras`` checkpoint, and upload it when the index is also a
      multiple of ``batch_upload_every``.

    Save/upload errors are printed and never interrupt training.

    Args:
        exp_name: prefix for checkpoint file names.
        epoch_upload_every: upload cadence for epoch checkpoints.
        batch_save_every: batch cadence for intra-epoch checkpoints.
        batch_upload_every: upload cadence for intra-epoch checkpoints.
        **kwargs: forwarded to ``keras.callbacks.Callback``.
    """

    target_dir = '/kaggle/working/model_save'

    def __init__(self, exp_name, epoch_upload_every=5, batch_save_every=5000,
                 batch_upload_every=10000, **kwargs):
        super().__init__(**kwargs)
        self.exp_name = exp_name
        self.epoch_upload_every = epoch_upload_every
        self.batch_save_every = batch_save_every
        self.batch_upload_every = batch_upload_every

    def _save(self, tag, upload):
        """Save the model as ``{exp}_{model}_keras_v3_{tag}.keras``; optionally upload it."""
        try:
            # Only touch the filesystem when actually saving, not on every batch.
            os.makedirs(self.target_dir, exist_ok=True)
            print("\nModel Saving to local notebook...")
            file_name = f"{self.exp_name}_{self.model.name}_keras_v3_{tag}.keras"
            filepath = os.path.join(self.target_dir, file_name)
            self.model.save(filepath, overwrite=True)
            if upload:
                print("\nModel Uploading to NAS...")
                upload_file(file_name, filepath)
                print("\nModel Saved to Local NAS")
        except Exception as e:
            print('Model Saving Error:\n', e)

    def on_epoch_end(self, epoch, logs=None):
        self._save(f"Epoch{epoch}", upload=(epoch + 1) % self.epoch_upload_every == 0)

    def on_train_batch_end(self, batch, logs=None):
        if batch != 0 and batch % self.batch_save_every == 0:
            self._save(f"Batch{batch}", upload=batch % self.batch_upload_every == 0)
                
                
class TrainingViz(keras.callbacks.Callback):
    """Log attention-map tables for the global `sample_img` batch to W&B.

    A ``wandb.Table`` is logged at batch 0, at batch ``train_steps // 2`` and at
    every epoch end. It has one row per sample: the original image followed by
    one image per attention head. Reads the notebook globals ``sample_img``,
    ``res`` and ``train_steps``. Visualization errors are printed and never
    interrupt training.

    Args:
        run: the active W&B run; stored on the callback (logging uses ``wandb.log``).
    """

    # SSL methods whose whole model, rather than `model.feature_extractor`,
    # is handed to att_visualize.
    _WHOLE_MODEL_METHODS = ("CLIP", "SigLIP", "SPARC")

    def __init__(self, run):
        super().__init__()
        self.run = run

    def _log_attention(self, key_prefix):
        """Visualize attention on `sample_img`; log it under ``{key_prefix}_{method}_result``."""
        import gc

        method = self.model.get_config()["SSL_method"]
        if method in self._WHOLE_MODEL_METHODS:
            feature_extractor = self.model
        else:
            feature_extractor = self.model.feature_extractor
        viz_weights = np.array(
            ssl_module.att_visualize(feature_extractor, sample_img, res, thresholding=True)
        )  # batch, heads, res, res, 3
        heads = viz_weights.shape[1]
        columns = ["Original Image"] + [f"Head{h + 1}" for h in range(heads)]
        rows = [
            [wandb.Image(sample_img[i])] + [wandb.Image(head_map) for head_map in per_sample]
            for i, per_sample in enumerate(viz_weights)
        ]
        wandb.log({f"{key_prefix}_{method}_result": wandb.Table(columns=columns, data=rows)})
        # No keras.backend.clear_session() here: the model is mid-training, and
        # clear_session resets Keras global state (in Keras 3 this includes global
        # settings such as the dtype policy and the seed from set_random_seed).
        # Just drop the large intermediates.
        del feature_extractor, viz_weights, rows
        gc.collect()

    def _safe_log(self, key_prefix):
        """Run `_log_attention`, printing instead of raising on failure."""
        try:
            self._log_attention(key_prefix)
        except Exception as e:
            print("Error code in callback : ", e)

    def on_epoch_end(self, epoch, logs=None):
        self._safe_log(f"Epoch{epoch+1}")

    def on_train_batch_end(self, batch, logs=None):
        if batch == 0:
            self._safe_log("ZeroBatch")
        elif batch == train_steps // 2:
            self._safe_log("MidEpoch")

> Training/logging helper function

In [None]:
def run_exp(model, train_ds = train_ds, val_ds = val_ds, epochs = 10, note= None):
    """Start a W&B run for `model`, log its configuration, and train it.

    Args:
        model: compiled SSL trainer whose ``get_config()`` contains ``"SSL_method"``.
        train_ds: training dataset (defaults to the two-view dataset at definition time).
        val_ds: validation dataset, or None to train without validation.
        epochs: number of epochs.
        note: free-text note attached to the W&B run.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    try:
        wandb.finish()  # close any run left open by a previous experiment
    except Exception:
        pass

    wandb_config()
    configs = model.get_config()
    method = configs["SSL_method"]
    # Call the model once on sample inputs (presumably to build its weights)
    # before FLOP counting and the summary.
    if method in ['CLIP', "SigLIP", "SPARC"]:
        # NOTE(review): example_images / example_reports are not defined in this
        # notebook; this branch raises NameError unless they are created elsewhere.
        _ = model((example_images[:2], example_reports[:2]))
        feature_extractor_flops = get_flops(model.image_encoder, [example_images[:1]])
    else:
        _ = model(test_set)
        feature_extractor_flops = get_flops(model.feature_extractor, [sample_img[:1]])
    configs.update({
        "batch_size": batch_size,
        "original resolution": res,
        "local view resolution": small_res,
        "Training steps": train_steps,
        "Val steps": val_steps,
        "train cases": (train_1 + train_2),
        "val cases": (val_1 + val_2),
        "embed_dims": embed_dims,
        "Image resolution": res,
        "(Image) Encoder Flops(G)": feature_extractor_flops,
        "dtype": keras.mixed_precision.dtype_policy(),
        "Optimizer configs": model.optimizer.get_config(),
        "Multicrop N": n_multicrop,
    })

    print(configs, "\n\n")
    model.summary()
    run = wandb.init(project="FusionFocus",
                     entity="gongbungkim", config=configs, notes=note)

    callbacks = [
        keras.callbacks.TerminateOnNaN(),
        WandbMetricsLogger(log_freq=100),
        ModelSaveCallback(f"FF_SSL_{method}"),
        TrainingViz(run),
    ]
    validation_kwargs = {}
    if val_ds is not None:
        validation_kwargs = {"validation_data": val_ds, "validation_steps": val_steps}
    return model.fit(train_ds,
                     steps_per_epoch=train_steps,
                     epochs=epochs,
                     verbose=1,
                     callbacks=callbacks,
                     **validation_kwargs)

> Feature extractor setting
- Global Context Vision Transformer (GC ViT)
- Convolution-based models:
    - EfficientNetV2B0, Small
    - ConvNeXtTiny, Small

In [None]:
# GC ViT variants with overridden per-level block depths.
_gc_level_depths = {
    "GC_ViT_xxtiny": [1, 1, 1, 2],
    "GC_ViT_tiny": [1, 1, 2, 4],
    "GC_ViT_small": [1, 2, 4, 6],
}


def _gc_configs(variant):
    """Base config for `variant` at the global `res`, with its level depths overridden."""
    cfg = get_gcvit_configs(res, 64, variant)
    cfg["level_depth"] = list(_gc_level_depths[variant])  # fresh list per config
    return cfg


gc_xxtiny_configs = _gc_configs("GC_ViT_xxtiny")
gc_tiny_configs = _gc_configs("GC_ViT_tiny")
gc_small_configs = _gc_configs("GC_ViT_small")

- Build the final feature extractors

In [None]:
# Arguments shared by every feature extractor built below.
# NOTE(review): all extractors are built eagerly, but the visible cells only train
# `eff_small`; consider building only the one in use to save memory.
_extractor_kwargs = dict(res=res, att_depth=2, embed_dims=embed_dims)

# GC ViT backbones: no extra positional encoding.
gcvit_xxtiny = get_full_model(gc_xxtiny_configs, pe_type=None, **_extractor_kwargs)
gcvit_tiny = get_full_model(gc_tiny_configs, pe_type=None, **_extractor_kwargs)
gcvit_small = get_full_model(gc_small_configs, pe_type=None, **_extractor_kwargs)

# Convolutional backbones: learnable positional encoding.
eff_tiny = get_full_model("effnet", pe_type='learnable', **_extractor_kwargs)
eff_small = get_full_model("effnet_small", pe_type='learnable', **_extractor_kwargs)
conv_tiny = get_full_model("convnext", pe_type='learnable', **_extractor_kwargs)
conv_small = get_full_model("convnext_small", pe_type='learnable', **_extractor_kwargs)

# Mixer backbones: get_full_model's default pe_type.
mlpmixer = get_full_model("mlpmixer_16_4_512", **_extractor_kwargs)
convmixer = get_full_model("convmixer_16_4_512", **_extractor_kwargs)

> Learning-rate schedules

In [None]:
def _cosine_restarts(initial_lr):
    """Cosine decay restarting every `train_steps` steps, each restart peaking at half the previous peak."""
    return keras.optimizers.schedules.CosineDecayRestarts(initial_lr, train_steps,
                                                          t_mul=1.0, m_mul=0.5)


cosine_decay = _cosine_restarts(2e-4)
cosine_decay_high_lr = _cosine_restarts(2e-3)

In [None]:
def ssl_train(module, feature_extractor, embed_dims = embed_dims, multiview = True, gradient_accumulation = None,
             note = "", epochs = 100, learning_rate = None):
    """Wrap `feature_extractor` in the SSL trainer `module`, compile it and train it.

    Trains on `train_ds_multiview` without validation.

    Args:
        module: SSL trainer class, called as ``module(feature_extractor, embed_dims=..., multiview=...)``.
        feature_extractor: backbone model to pre-train.
        embed_dims: embedding width passed to the trainer.
        multiview: forwarded to the trainer.
        gradient_accumulation: ``gradient_accumulation_steps`` for AdamW (None disables it).
        note: extra text appended to the W&B run note.
        epochs: number of training epochs.
        learning_rate: LR or schedule for AdamW; defaults to the global `cosine_decay`.

    Returns:
        The History returned by `run_exp`.
    """
    if learning_rate is None:
        learning_rate = cosine_decay
    ssl_trainer = module(feature_extractor, embed_dims=embed_dims, multiview=multiview)
    ssl_trainer.compile(optimizer=keras.optimizers.AdamW(learning_rate=learning_rate,
                                                         clipvalue=1.0,
                                                         gradient_accumulation_steps=gradient_accumulation))
    base_note = "Without validation d/t lack of resources"
    # Separate the caller's note with a space (plain concatenation glued the words together).
    full_note = f"{base_note} {note}" if note else base_note
    return run_exp(ssl_trainer, train_ds_multiview, None, epochs=epochs, note=full_note)

# Barlow Twins

In [None]:
#barlow_trainer = BarlowModel(convmixer, 
#                             embed_dims = embed_dims, multiview = True)
#barlow_trainer.compile(optimizer = keras.optimizers.AdamW(learning_rate = cosine_decay,
#                                                         clipvalue = 1.0,
#                                                         #amsgrad = True
#                                                         )
#                      )
#run_exp(barlow_trainer, train_ds_multiview, None, epochs = 100,
#       note = "Without validation d/t lack of resources")

# VICReg

In [None]:
#vic_trainer = VICRegModel(eff_tiny, 
#                             embed_dims = embed_dims, multiview = True)
#vic_trainer.compile(optimizer = keras.optimizers.AdamW(learning_rate = cosine_decay,
#                                                         clipvalue = 1.0,
#                                                         #amsgrad = True
#                                                         )
#                      )
#run_exp(vic_trainer, train_ds_multiview, None, epochs = 100,
#       note = "Without validation d/t lack of resources")

# SimSiam
- Collapsed almost immediately.

In [None]:
#ssl_train(ssl_module.SimSiam, eff_tiny)

# SimCLR

In [None]:
# SimCLR pre-training of the EfficientNet-small extractor, accumulating gradients over 8 steps.
ssl_train(module=ssl_module.SimCLR,
          feature_extractor=eff_small,
          gradient_accumulation=8,
          note="+ Attentional pooling with register")