In [1]:
import pandas as pd
import numpy as np
import os, sys
import random
import pydicom

from sklearn.manifold import TSNE
import re
import matplotlib.pyplot as plt
import matplotlib.cm as cm
seed = 42

import warnings
warnings.filterwarnings("ignore")

# ML tools 
sys.path.append("/kaggle/input/kimm-keras-image-model-repository"
               )

import tensorflow as tf
import keras# ; keras.config.set_dtype_policy("mixed_float16")
import kimm
import keras_cv
import keras_nlp

import cv2
from skimage.io import imread
keras.utils.set_random_seed(seed)
import tensorflow_io as tfio
from kaggle_datasets import KaggleDatasets
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import tensorflow_decision_forests as tfdf

print(f"Tensorflow version : {tf.__version__}")
# Look the attribute up instead of wrapping the print in a bare `except:` (which would
# also swallow KeyboardInterrupt/SystemExit); builds without `__version__` skip the line.
keras_version = getattr(keras, "__version__", None)
if keras_version is not None:
    print(f"Keras version : {keras_version}")

from keras import Input, Model, ops
from keras.models import load_model

from keras.layers import Conv2D, DepthwiseConv2D, Dense, Activation, BatchNormalization, LayerNormalization, MultiHeadAttention, Embedding, Subtract, Add, Multiply, GlobalAveragePooling2D, GlobalAveragePooling1D, LayerNormalization
from keras.utils import load_img, img_to_array
from keras.applications import *
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.model_selection import train_test_split
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
from tqdm.notebook import tqdm
import wandb
#from wandb.keras import WandbCallback, WandbModelCheckpoint, WandbMetricsLogger
def wandb_config():
    """Authenticate this process with Weights & Biases using Kaggle secrets.

    Reads the ``wandb_key`` secret and logs in through the Python API.

    The previous version had three problems:
    - It also fetched ``__gcloud_sdk_auth__`` and ``huggingface_key`` into locals
      that were never used.
    - It relied on bare ``except:`` for control flow.
    - It passed the key as an argument to a ``!wandb login`` shell command, which
      exposes it on the process command line.

    Raises:
        Whatever ``UserSecretsClient.get_secret`` raises when ``wandb_key`` is not
        attached to the notebook.
    """
    from kaggle_secrets import UserSecretsClient

    wandb_key = UserSecretsClient().get_secret("wandb_key")
    wandb.login(key=wandb_key)
    
def auto_select_accelerator():
    """Pick a tf.distribute strategy for the available hardware.

    Tries to resolve, connect to and initialise a TPU. If that raises ValueError,
    falls back to ``MirroredStrategy`` (for GPU or multi-GPU machines).

    Returns:
        (tpu, strategy): the TPU cluster resolver (``False`` when no TPU was
        found) and the distribution strategy to build models under.
    """
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        # `tf.distribute.TPUStrategy` is the stable API; the `experimental` alias is deprecated.
        strategy = tf.distribute.TPUStrategy(tpu)
        print("Running on TPU:", tpu.master())
    except ValueError:
        # Assumes the resolver signals "no TPU available" with ValueError.
        tpu = False
        strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines
    print(f"Running on {strategy.num_replicas_in_sync} replicas")

    return tpu, strategy


tpu, strategy = auto_select_accelerator()
import ssl_module
from ssl_module import feature_visualize, get_masking_fn, get_map_fn, get_gcvit_configs, get_flops, att_visualize, get_full_model, AttentionPooling, BarlowModel, VICRegModel, Moco, SimSiam, CLIP, SigLIP
import nas_ftp_module
from nas_ftp_module import upload_file, download_file
import PIL
from PIL import Image as PILImage

Tensorflow version : 2.17.0
Keras version : 3.4.1
Running on 1 replicas
Requirements loaded, keras : v3.4.1, Tensorflow : v2.17.0
RandAug Component in this SSL module :  ['random_contrast_1', 'random_brightness_1', 'random_shear', 'random_shear_1', 'random_translation', 'random_translation_1']


# 실험 계획 (Experiment plan)
- Token mixer: gMLP vs. gaMLP vs. attention
- ConvNeXt vs. pure MetaFormer
     - If ConvNeXt: FE(특징 추출기) 뒤 token mixer 개수에 따른 변화 (effect of the number of token mixers after the feature extractor)
     - If ConvNeXt: ImageNet weights vs. random initialization

# Setting hyperparameters

In [2]:
# ---- Data / batching ---------------------------------------------------------
# Global batch = per-replica batch (8) x number of replicas in the active strategy.
batch_size = 8 * strategy.num_replicas_in_sync
print('batch size', batch_size)

res = 2 * 256      # training resolution (512)
small_res = 64

n_multicrop = 2    # number of views per example for the multi-view SSL map fn
# Photometric-only RandAugment (geometric=False) on 0-255 images.
randaug = keras_cv.layers.RandAugment(
    value_range=(0, 255), magnitude=0.1, magnitude_stddev=0.1, geometric = False
)

# ---- Encoder -----------------------------------------------------------------
grayscale = False  # False if using pretrained model, True if from scratch
patch_size = 12
heads = 8
att_dims = 64
embed_dims = 512

if grayscale:
    c = 1
else:
    c = 3

if grayscale:
    # Single-channel MetaFormer trained from scratch: no CNN backbone.
    pretrained_encoder = None
    depth, registers = 2, 2
    pretrained_note = "gray_metaformer"
else:
    depth, registers = 0, 0
    # weights=None -> ConvNeXt-Tiny starts from random initialisation.
    pretrained_encoder = kimm.models.ConvNeXtTiny(
        input_shape = [res, res, 3], include_top = False, weights = None,
    )
    patch_size = 32  # overrides the value above for the ConvNeXt path
    # ViT-Ti/32, built with kimm's default `weights` argument.
    pretrained_vit = kimm.models.VisionTransformerTiny32(input_shape = [res, res, 3], include_top = False)
    # Run both backbones under the mixed_float16 dtype policy.
    for backbone in (pretrained_encoder, pretrained_vit):
        for layer in backbone.layers:
            layer.dtype_policy = keras.mixed_precision.Policy('mixed_float16')
    # NOTE(review): the tag mentions both "ImageNet" and "RandomInit" while the
    # ConvNeXt above uses weights=None - confirm which description is intended.
    pretrained_note = f"ConvWithMetaEncoder_ImageNet_TM{depth}_RandomInit"
    # Other configurations tried earlier:
    # - RegNetY040 truncated at "s4_b1_conv1"
    # - ViT-B/16 and ViT-L/16 (patch_size = 16)
    # - an RGB MetaFormer without CNN backbone (depth = 6, registers = 0, patch_size = 24)

batch size 8


In [3]:
def get_dual_encoder(cnn_encoder = None, vit_encoder = None, input_res = None):
    """Build a model that concatenates CNN feature-map tokens with ViT patch tokens.

    Args:
        cnn_encoder: backbone returning a 4-D (batch, w, h, dims) feature map.
            Defaults to the global ``pretrained_encoder`` (ConvNeXt-Tiny in the RGB config).
        vit_encoder: backbone returning a 3-D (batch, tokens, d) sequence; its first
            token is dropped. Defaults to the global ``pretrained_vit``.
        input_res: square input side length. Defaults to the global ``res``.

    Returns:
        keras.Model mapping an (input_res, input_res, 3) image to the channel-wise
        concatenation of both token sequences (output layer ``MergedFeatureMap``).

    NOTE(review): concatenating on the last axis requires w*h to equal the number of
    remaining ViT tokens (e.g. a stride-32 CNN with a patch-32 ViT at the same resolution).
    """
    f1 = pretrained_encoder if cnn_encoder is None else cnn_encoder
    f2 = pretrained_vit if vit_encoder is None else vit_encoder
    input_res = res if input_res is None else input_res

    input_tensor = Input([input_res, input_res, 3], name = "DualEncoderInputImg")
    m1 = f1(input_tensor)              # CNN feature map
    m2 = f2(input_tensor)[:, 1:, :]    # ViT tokens without the first (presumably CLS) token
    # Flatten the CNN grid into a token sequence.
    _, w, h, dims = ops.shape(m1)
    m1 = ops.reshape(m1, [-1, w * h, dims])
    combined_features = keras.layers.Identity(name = "MergedFeatureMap")(
        ops.concatenate([m1, m2], axis = -1))

    return Model(input_tensor, combined_features,
                 name = f"{f1.name}With{f2.name}_dualencoder")

# Example usage
#dual_model = get_dual_encoder()
#dual_model.summary()

- RadImageNet TFRecord keys: `image` (JPEG bytes), `label`
- NIH CXR TFRecord keys: `image_raw` (JPEG bytes), `label`

# RadImageNet decoding

In [4]:
def _parse_tfrecord(res = res, image_key = 'image'):
    """Return a tf.data map fn that decodes one serialized (image, label) example.

    Args:
        res: output side length passed to ``_transform_images``.
        image_key: name of the feature holding the JPEG bytes. RadImageNet records
            use ``'image'``; the NIH CXR records use ``'image_raw'``.

    Returns:
        parse_tfrecord(tfrecord) -> (uint8 image of shape (res, res, 1), int32 label).
    """
    def parse_tfrecord(tfrecord):
        features = {image_key: tf.io.FixedLenFeature([], tf.string),
                    'label': tf.io.FixedLenFeature([], tf.int64),
                    }
        x = tf.io.parse_single_example(tfrecord, features)
        # Decode as a single grayscale channel.
        image_train = tf.image.decode_jpeg(x[image_key], channels=1)
        image_train = _transform_images(res = res)(image_train)
        label = tf.cast(x["label"], tf.int32)
        return (image_train, label)

    return parse_tfrecord


def _transform_images(res = res):
    """Return a fn that letterboxes an image to (res, res) and casts it to uint8."""
    def transform_images(image):
        # Aspect-preserving resize with zero padding, then back to uint8 pixels.
        padded = tf.image.resize_with_pad(image, res, res, antialias = True)
        return tf.cast(padded, tf.uint8)
    return transform_images

def load_tfrecord_dataset(tfrecord_name, res = res, batch_size = batch_size, shuffle=True,
                          buffer_size=10240, compression_type = "GZIP"):
    """Build an infinite (image, label) dataset from a RadImageNet TFRecord.

    Args:
        tfrecord_name: TFRecord path(s).
        res: output image side length, forwarded to the parser.
        batch_size: batch size (remainder dropped); a falsy value skips batching.
        shuffle: shuffle examples with a ``buffer_size`` buffer.
        buffer_size: shuffle buffer size.
        compression_type: TFRecord compression; the RadImageNet files are GZIP.

    Returns:
        tf.data.Dataset repeated indefinitely and prefetched.
    """
    raw_dataset = tf.data.TFRecordDataset(tfrecord_name, compression_type = compression_type)
    raw_dataset = raw_dataset.repeat()
    if shuffle:
        raw_dataset = raw_dataset.shuffle(buffer_size=buffer_size)
    dataset = raw_dataset.map(
        # Forward `res`: the parser used to be built with its default, so this
        # function's `res` argument was silently ignored.
        _parse_tfrecord(res = res),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    if batch_size:
        dataset = dataset.batch(batch_size, drop_remainder = True)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

train_radimagenet_ds = load_tfrecord_dataset("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RagImageNet_Train_GZIP.tfrecord")
val_ds = load_tfrecord_dataset("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RagImageNet_Test_GZIP.tfrecord")

# NIH CXR decoding

In [5]:
def _parse_tfrecord(res = res, image_key = 'image_raw'):
    """Return a tf.data map fn that decodes one serialized NIH CXR (image, label) example.

    This cell redefines the RadImageNet parser. The NIH records store the JPEG
    bytes under ``'image_raw'`` (RadImageNet uses ``'image'``). Datasets built in
    the RadImageNet cell already captured the earlier definition.

    Args:
        res: output side length passed to ``_transform_images``.
        image_key: name of the feature holding the JPEG bytes.

    Returns:
        parse_tfrecord(tfrecord) -> (uint8 image of shape (res, res, 1), int32 label).
    """
    def parse_tfrecord(tfrecord):
        features = {image_key: tf.io.FixedLenFeature([], tf.string),
                    'label': tf.io.FixedLenFeature([], tf.int64),
                    }
        x = tf.io.parse_single_example(tfrecord, features)
        # Decode as a single grayscale channel.
        image_train = tf.image.decode_jpeg(x[image_key], channels=1)
        image_train = _transform_images(res = res)(image_train)
        label = tf.cast(x["label"], tf.int32)
        return (image_train, label)

    return parse_tfrecord


def _transform_images(res = res):
    """Return a fn that pads/resizes an image to (res, res) and converts it to uint8."""
    def transform_images(img):
        letterboxed = tf.image.resize_with_pad(img, res, res, antialias = True)
        return tf.cast(letterboxed, tf.uint8)
    return transform_images

def load_tfrecord_dataset(tfrecord_name, res = res, batch_size = batch_size, shuffle=True,
                          buffer_size=10240, compression_type = None):
    """Build an infinite (image, label) dataset from the NIH CXR TFRecord.

    Args:
        tfrecord_name: TFRecord path(s).
        res: output image side length, forwarded to the parser.
        batch_size: batch size (remainder dropped); a falsy value skips batching.
        shuffle: shuffle examples with a ``buffer_size`` buffer.
        buffer_size: shuffle buffer size.
        compression_type: TFRecord compression; ``None`` = uncompressed (NIH file).

    Returns:
        tf.data.Dataset repeated indefinitely and prefetched.
    """
    raw_dataset = tf.data.TFRecordDataset(tfrecord_name, compression_type = compression_type)
    raw_dataset = raw_dataset.repeat()
    if shuffle:
        raw_dataset = raw_dataset.shuffle(buffer_size=buffer_size)
    dataset = raw_dataset.map(
        # Forward `res`; previously the parser always used its own default.
        _parse_tfrecord(res = res),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    if batch_size:
        dataset = dataset.batch(batch_size, drop_remainder = True)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

nih_cxr_ds = load_tfrecord_dataset("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/nih_cxr_images.tfrecords")

# Merging the two datasets (RadImageNet 75% / NIH CXR 25%)

In [6]:
def _mix_radimagenet_and_nih():
    """Unbatched stream drawing RadImageNet with weight 0.75 and NIH CXR with weight 0.25."""
    return tf.data.Dataset.sample_from_datasets(
        [train_radimagenet_ds.unbatch(), nih_cxr_ds.unbatch()], weights = [0.75, 0.25])

# train ds output : ([batch_size, res, res, 1], [batch_size,])
train_ds = (_mix_radimagenet_and_nih()
            .batch(batch_size, drop_remainder = True)
            .repeat()
            .prefetch(tf.data.AUTOTUNE))

# One fixed batch of 16 training examples, reused by the visualisation callbacks.
val_ds_ = _mix_radimagenet_and_nih().batch(16, drop_remainder = True).prefetch(tf.data.AUTOTUNE)
for images, labels in val_ds_.take(1):
    sample_img = images if grayscale else tf.image.grayscale_to_rgb(images)
del val_ds_

In [7]:
def get_sobel_fn():
    """Return a tf.data map fn that randomly edge-augments a single-channel image.

    ``keras.random.randint`` draws an integer in [1, 10). When it is > 5 (4 of 9 values):
    - the output is ``[gray, sobel_dy, sobel_dx]``;
    - the two Sobel responses are jointly min-max scaled to 0..255.

    Otherwise the gray channel is replicated to RGB via ``grayscale_to_rgb``. Both
    branches return uint8.

    Returns:
        sobel_merge(image, label) -> (image, label). ``image`` is assumed to be an
        unbatched (H, W, 1) tensor - TODO confirm for every caller.
    """
    def sobel_merge(image, label):
        image = image[tf.newaxis, ...]  # sobel_edges needs a batch axis
        rand_num = keras.random.randint(shape = (), minval = 1, maxval = 10)
        if rand_num > 5:
            image = ops.cast(image, 'float32')
            ed = tf.image.sobel_edges(image)[..., 0, :]  # (dy, dx) of channel 0
            ed_min = ops.min(ed)
            # divide_no_nan: a flat image has zero edge range; plain division gave
            # 0/0 = NaN, which is undefined after the uint8 cast.
            ed_norm = 255.0 * tf.math.divide_no_nan(ed - ed_min, ops.max(ed) - ed_min)
            ed_norm = ops.cast(ed_norm, "uint8")
            image = ops.concatenate([image, ed_norm], axis = -1)
            image = ops.cast(image, "uint8")
        else:
            try:
                image = tf.image.grayscale_to_rgb(image)
            except Exception:
                # Best effort, as before: leave the image unchanged if it cannot be converted.
                pass
        image = image[0]
        return image, label
    return sobel_merge
sobel_merge = get_sobel_fn()

# Convert supervised dataset into SSL dataset

In [8]:
# `n_multicrop` augmented views per example for the multi-view SSL objectives.
multiview_fn = get_map_fn(
    res = res,
    input_type = "supervised",
    output_type = "ssl",
    n_view = n_multicrop,
    grayscale = grayscale,
)

train_ds_multiview = (train_ds
                      .map(multiview_fn, num_parallel_calls=tf.data.AUTOTUNE)
                      .prefetch(tf.data.AUTOTUNE))
val_ds_multiview = (val_ds
                    .map(multiview_fn, num_parallel_calls=tf.data.AUTOTUNE)
                    .prefetch(tf.data.AUTOTUNE))

# 50% patch masking on a grid of `patch_size` patches (from the hyperparameter cell).
mask_map_fn_ = get_masking_fn(grayscale = grayscale, masking_rate = 0.5, patch_size = patch_size)

# Edge-augmented stream: per-example Sobel merge, then re-batch.
train_edge_ds = (train_ds
                 .unbatch()
                 .map(sobel_merge, num_parallel_calls=tf.data.AUTOTUNE)
                 .batch(batch_size, drop_remainder = True)
                 .prefetch(tf.data.AUTOTUNE))

def masking_function(image, label):
    """Map fn: apply ``mask_map_fn_`` to the image; the label is not forwarded."""
    masked = mask_map_fn_(image)
    return masked

train_ds_masked = (
    train_ds
    .map(masking_function, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
    .repeat()
)
# Disabled variants kept for reference:
#train_ds_edge_masked = train_edge_ds.unbatch().map(masking_function, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size, drop_remainder = True).prefetch(tf.data.AUTOTUNE).repeat()
#train_ds_edge_multiview = train_edge_ds.map(multiview_fn, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
#train_ds_edge_simple_multiview = train_edge_ds.map(simple_aug_fn, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

----------
# Experiment - helper functions

In [9]:
# Case counts come from the split CSVs, because the TFRecord datasets repeat indefinitely.
df_train_rad = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RadImgNet_train.csv")
df_train_nih = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/nih_trainval_split.csv")
df_val_rad = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RadImgNet_test.csv")
df_val_nih = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/nih_test_split.csv")

# NOTE(review): the NIH test split (df_val_nih) is counted as training data,
# presumably because nih_cxr_images.tfrecords holds every NIH image - confirm.
train_cases = sum(len(frame) for frame in (df_train_rad, df_train_nih, df_val_nih))
val_cases = len(df_val_rad)

# Steps per epoch for the repeated datasets.
train_steps, val_steps = train_cases // batch_size, val_cases // batch_size
print(f"Total train cases : {train_cases}, validation cases : {val_cases}")

Total train cases : 1303237, validation cases : 163796


In [10]:
def _centralize_gradients(grads):
    """Gradient Centralization: subtract from each gradient its mean over all but the last axis.

    Only gradients of rank > 1 (kernels) are centralized. Rank-1 gradients (biases,
    norm scales) and ``None`` gradients (variables not reached by the loss) pass
    through unchanged.
    """
    centralized = []
    for grad in grads:
        if grad is not None and len(grad.shape) > 1:
            reduce_axes = list(range(len(grad.shape) - 1))
            grad = grad - ops.mean(grad, axis=reduce_axes, keepdims=True)
        centralized.append(grad)
    return centralized


class GCAdamW(keras.optimizers.AdamW):
    """AdamW with Gradient Centralization applied before every update.

    Keras 3 optimizers have no ``get_gradients`` hook. The previous override was
    never invoked, and it would have failed anyway: ``super().get_gradients()``
    was called without arguments, and ``keep_dims`` is not an ``ops.mean`` keyword.
    ``apply_gradients`` delegates to ``apply``, so centralization is done there.
    """

    def apply(self, grads, trainable_variables=None):
        return super().apply(_centralize_gradients(grads), trainable_variables)


class GCAdam(keras.optimizers.Adam):
    """Adam with Gradient Centralization applied before every update (see ``GCAdamW``)."""

    def apply(self, grads, trainable_variables=None):
        return super().apply(_centralize_gradients(grads), trainable_variables)

In [11]:
class ModelSaveCallback(keras.callbacks.Callback):
    """Save ``model.feature_extractor`` to /kaggle/working/model_save and upload it to the NAS.

    Saves after every epoch and every ``save_every_n_batches`` batches of an epoch
    (batch 0 excluded). Failures are printed and never interrupt training.

    Args:
        exp_name: experiment tag embedded in the file name.
        message: free-text file-name suffix (defaults to a single space).
        save_every_n_batches: mid-epoch save interval in batches (default 25000).
    """

    target_dir = '/kaggle/working/model_save'

    def __init__(self, exp_name, message = None, save_every_n_batches = 25000, **kwargs):
        super().__init__(**kwargs)
        self.exp_name = exp_name
        self.message = message if message is not None else " "
        self.save_every_n_batches = save_every_n_batches

    def _save_and_upload(self, tag):
        """Save as ``<extractor>_FE<exp_name>_<tag>_<message>.keras`` and upload it; print on failure."""
        try:
            # Looked up only when saving, inside the try, so a missing attribute
            # cannot abort every training step.
            feature_ext_name = self.model.feature_extractor.name
            os.makedirs(self.target_dir, exist_ok = True)
            print("\nModel Saving to local notebook...")
            file_name = f"{feature_ext_name}_FE{self.exp_name}_{tag}_{self.message}.keras"
            filepath = os.path.join(self.target_dir, file_name)
            self.model.feature_extractor.save(filepath, overwrite=True)
            print("\nModel Uploading to NAS...")
            upload_file(file_name, filepath)
            print("\nModel Saved to Local NAS")
        except Exception as e:
            print('Model Saving Error:\n', e)

    def on_epoch_end(self, epoch, logs=None):
        self._save_and_upload(f"Epoch{epoch}")

    def on_train_batch_end(self, batch, logs=None):
        if batch != 0 and batch % self.save_every_n_batches == 0:
            self._save_and_upload(f"Batch{batch}")
                
class TemperatureScheduler(keras.callbacks.Callback):
    """Exponentially decay the model attribute ``t`` during training.

    - If the model has no ``t`` yet, it receives ``initial_t`` as a float32 tensor.
    - Afterwards ``t`` is multiplied by ``decay_rate`` whenever the batch index is
      a positive multiple of 5000.
    - Keras batch indices count within the current epoch.

    NOTE(review): if the model reads ``self.t`` inside a compiled train step, the
    tensor may be captured when the step is traced, so later reassignments might
    not reach the graph. A ``keras.Variable`` updated with ``assign`` would avoid
    this; confirm against the model code in ``ssl_module``.
    """

    def __init__(self, initial_t = 0.5, decay_rate = 0.99):
        super().__init__()
        self.initial_t = initial_t
        self.decay_rate = decay_rate

    def on_train_batch_begin(self, batch, logs=None):
        model = self.model
        if hasattr(model, 't'):
            is_decay_step = batch > 0 and batch % 5000 == 0
            if is_decay_step:
                model.t = model.t * self.decay_rate
            return
        model.t = ops.convert_to_tensor(self.initial_t, dtype='float32')

In [12]:
class TrainingViz(keras.callbacks.Callback):
    """Log attention maps and a 2-D feature scatter of the global ``sample_img`` to W&B.

    - Epoch end: per-head attention maps plus the merged map, and the feature scatter.
    - Every ``viz_every_n_batches`` batches (batch 0 included): per-head maps, the
      merged map, the top-k head merging map, and the feature scatter.

    Relies on the notebook globals ``sample_img`` and ``res``. Errors are printed,
    not raised, so a visualisation failure does not stop training.

    Args:
        run: stored as ``self.run``; logging itself goes through the global ``wandb``.
        viz_every_n_batches: mid-epoch logging interval (default 10000).
    """

    def __init__(self, run, viz_every_n_batches = 10000):
        super().__init__()
        self.run = run
        self.viz_every_n_batches = viz_every_n_batches

    def _get_method_and_extractor(self):
        """Return (SSL method name, model used for attention visualisation)."""
        method = self.model.get_env_config()["SSL_method"]
        if method in ["CLIP" , "SigLIP", "SPARC"]:
            return method, self.model
        try:
            return method, self.model.feature_extractor
        except AttributeError:
            return method, self.model.get_full_model(res = res)

    def _log_feature_scatter(self, key):
        """Log the 2-D projection from ``feature_visualize`` of ``sample_img`` as a scatter plot."""
        embed_v = feature_visualize(self.model, sample_img)
        data = [[x, y] for (x, y) in zip(embed_v[..., 0], embed_v[..., 1])]
        table = wandb.Table(data=data, columns = ["x", "y"])
        wandb.log({key : wandb.plot.scatter(table, "x", "y", title="TSNE Scatter Plot")})
        # NOTE(review): clear_session() resets Keras global state while fit() runs;
        # kept from the original - confirm it is needed.
        tf.keras.backend.clear_session()

    def on_epoch_end(self, epoch, logs=None):
        try:
            method, feature_extractor = self._get_method_and_extractor()
            viz_weights, merged_weights = ssl_module.att_visualize(feature_extractor, sample_img, res,
                                                                   thresholding = 0)
            viz_weights = np.array(viz_weights)  # batch, heads, res, res, 3
            merged_weights = np.array(merged_weights)
            n_heads = viz_weights.shape[1]
            col = ["Original Image", "Merged"] + [f"Head{h + 1}" for h in range(n_heads)]

            visualize_data = []
            for idx, weights in enumerate(viz_weights):
                row = [wandb.Image(sample_img[idx]), wandb.Image(merged_weights[idx])]
                row += [wandb.Image(weights[h]) for h in range(n_heads)]
                visualize_data.append(row)
            tbl = wandb.Table(columns = col, data = visualize_data)
            wandb.log({f"Epoch{epoch+1}_{method}_result": tbl})
            del feature_extractor, tbl
            tf.keras.backend.clear_session()

            self._log_feature_scatter(f"Epoch{epoch+1}_{method}_FeatureViz")
        except Exception as e:
            # This handler used to report the failure as a "Model Saving Error".
            print('Visualization error in TrainingViz.on_epoch_end:\n', e)

    def on_train_batch_end(self, batch, logs=None):
        if batch % self.viz_every_n_batches != 0:
            return
        try:
            method, feature_extractor = self._get_method_and_extractor()
            viz_weights, merged_weights = ssl_module.att_visualize(feature_extractor, sample_img, res,
                                                                   thresholding = False)
            _, rollout_merged_image = ssl_module.att_visualize_merged(feature_extractor, sample_img, res)
            viz_weights = np.array(viz_weights)  # batch, heads, res, res, 3
            merged_weights = np.array(merged_weights)
            n_heads = viz_weights.shape[1]
            col = (["Original Image", "MergedMap", 'Top-K head merging map']
                   + [f"Head{h + 1}" for h in range(n_heads)])

            visualize_data = []
            for idx, weights in enumerate(viz_weights):  # weights: heads, res, res, 3
                row = [wandb.Image(sample_img[idx]),
                       wandb.Image(merged_weights[idx]),
                       wandb.Image(rollout_merged_image[idx])]
                row += [wandb.Image(weights[h]) for h in range(n_heads)]
                visualize_data.append(row)
            tbl = wandb.Table(columns = col, data = visualize_data)
            if batch == 0:
                wandb.log({f"ZeroBatch_{method}_result": tbl})
            else:
                wandb.log({f"MidEpoch_{method}_result": tbl})
            del feature_extractor, tbl
            tf.keras.backend.clear_session()

            self._log_feature_scatter(f"Batch{batch}_{method}_FeatureViz")
        except Exception as e:
            print("Error code in callback : ", e)

> Real-world evaluation and segmentation callbacks

In [13]:
real_world_dir = "/kaggle/input/real-world-medical-image-dataset-for-evaluation" ; filenames_ = os.listdir(real_world_dir)
filenames_.sort()
# Label = file name up to the first '.'.
labels_ = [name.split('.')[0] for name in filenames_]
real_world_files = [os.path.join(real_world_dir, paths) for paths in filenames_]
def get_img_tensor(path, res = res) :
    """Read an image file and letterbox it to a (res, res, c) uint8 tensor.

    ``c`` is 1 when the global ``grayscale`` is True, otherwise 3.
    """
    file = tf.io.read_file(path)
    c = 1 if grayscale else 3
    # expand_animations=False: an animated GIF would otherwise decode to a 4-D
    # (frames, H, W, C) tensor and break the tf.stack below; keep the first frame.
    image = tf.io.decode_image(file, channels=c, expand_animations=False)
    image = tf.image.resize_with_pad(image, res, res, antialias = True)
    image = ops.cast(image, "uint8")
    return image
real_world_images = tf.stack([get_img_tensor(f) for f in real_world_files],
                             axis = 0)


In [14]:
class RealWorldViz(keras.callbacks.Callback):
    """Periodically log attention maps, a feature scatter and patch clustering for the real-world images.

    Every ``viz_every_n_batches`` batches (batch 0 included) three W&B artefacts are
    logged for the global ``real_world_images`` / ``labels_``:
    1. an attention table;
    2. a 2-D feature scatter;
    3. per-image agglomerative-cluster maps of the patch features.

    Errors are printed, not raised.

    Args:
        run: stored as ``self.run``; logging goes through the global ``wandb``.
        viz_every_n_batches: logging interval in batches (default 5000).
        n_clusters: number of clusters passed to ``ssl_module.H_clustering`` (default 200).
    """

    def __init__(self, run, viz_every_n_batches = 5000, n_clusters = 200):
        super().__init__()
        self.run = run
        self.viz_every_n_batches = viz_every_n_batches
        self.n_clusters = n_clusters

    def _get_method_and_extractor(self):
        """Return (SSL method name, model used for attention visualisation)."""
        method = self.model.get_env_config()["SSL_method"]
        if method in ["CLIP" , "SigLIP", "SPARC"]:
            return method, self.model
        try:
            return method, self.model.feature_extractor
        except AttributeError:
            return method, self.model.get_full_model(res = res)

    def _log_attention_table(self, feature_extractor, method, batch):
        """Log image, label, merged map, top-k head merge and per-head maps."""
        viz_weights, merged_weights = ssl_module.att_visualize(feature_extractor, real_world_images, res,
                                                               thresholding = False)
        _, rollout_merged_image = ssl_module.att_visualize_merged(feature_extractor, real_world_images, res)
        viz_weights = np.array(viz_weights)  # batch, heads, res, res, 3
        merged_weights = np.array(merged_weights)
        n_heads = viz_weights.shape[1]
        col = (["Original Image", "Original Label", "MergedMap", 'Top-K head merging map']
               + [f"Head{h + 1}" for h in range(n_heads)])
        visualize_data = []
        for idx, weights in enumerate(viz_weights):
            row = [wandb.Image(real_world_images[idx]),
                   labels_[idx],
                   wandb.Image(merged_weights[idx]),
                   wandb.Image(rollout_merged_image[idx])]
            row += [wandb.Image(weights[h]) for h in range(n_heads)]
            visualize_data.append(row)
        tbl = wandb.Table(columns = col, data = visualize_data)
        if batch == 0:
            wandb.log({f"RW_ZeroBatch_{method}_result": tbl})
        else:
            wandb.log({f"RW_Batch{batch}_{method}_result": tbl})

    def _log_feature_scatter(self, method, batch):
        """Log the 2-D projection from ``feature_visualize`` of the real-world images."""
        embed_v = feature_visualize(self.model, real_world_images)
        data = [[x, y] for (x, y) in zip(embed_v[..., 0], embed_v[..., 1])]
        table = wandb.Table(data=data, columns = ["x", "y"])
        wandb.log({f"RW_Batch{batch}_{method}_FeatureViz" : wandb.plot.scatter(table, "x", "y",
                                                                                title=f"RW_Batch{batch}_TSNE")})

    def _log_clusters(self, feature_extractor, method, batch):
        """Cluster patch features with ``ssl_module.H_clustering`` and log one cluster map per image."""
        feature_map = feature_extractor(real_world_images)[1]
        n_patch = feature_map.shape[1]
        # Assumes the patch tokens form a square grid (no CLS/register tokens) - TODO confirm.
        w_ = ops.cast(ops.sqrt(ops.cast(n_patch, "float32")), "int32")
        clustering_output = ssl_module.H_clustering(n_clusters = self.n_clusters)(feature_map)
        clustering_output = ops.reshape(clustering_output, [-1, w_, w_, 1])
        data = []
        for i, sample_image in enumerate(real_world_images):
            cluster_plot = ops.cast(tf.convert_to_tensor(clustering_output[i]), "float32")
            lo, hi = ops.min(cluster_plot), ops.max(cluster_plot)
            # divide_no_nan: if every patch falls in one cluster (hi == lo) the old
            # min-max scaling computed 0/0 = NaN.
            cluster_plot = tf.math.divide_no_nan(cluster_plot - lo, hi - lo) * 255
            cluster_plot = tf.image.grayscale_to_rgb(ops.cast(cluster_plot, "uint8"))
            cluster_plot = PILImage.fromarray(np.array(cluster_plot), mode="RGB")
            data.append([wandb.Image(sample_image), wandb.Image(cluster_plot)])
        table = wandb.Table(data=data, columns = ["Original_image", "AgglomerativeCluster"])
        wandb.log({f"Cluster_RW_Batch{batch}_{method}_result": table})

    def on_train_batch_end(self, batch, logs=None):
        if batch % self.viz_every_n_batches != 0:
            return
        try:
            method, feature_extractor = self._get_method_and_extractor()
            self._log_attention_table(feature_extractor, method, batch)
            # NOTE(review): clear_session() resets Keras global state mid-fit;
            # kept from the original - confirm it is needed.
            tf.keras.backend.clear_session()
            self._log_feature_scatter(method, batch)
            self._log_clusters(feature_extractor, method, batch)
            tf.keras.backend.clear_session()
            del feature_extractor
        except Exception as e:
            print("Error code in callback : ", e)
        

-------------
- Special callback for QNCLR

In [15]:
class QRealWorldViz(keras.callbacks.Callback):
    """QNCLR-specific: log per-learnable-query attention maps for the real-world images.

    Every ``viz_every_n_batches`` batches (batch 0 included), ``ssl_module.q_visualize``
    runs on ``model.feature_extractor``. A W&B table is logged with, per image:
    the image, its label, the merged map, and one map per query. Errors are
    printed, not raised, matching the other visualisation callbacks.

    Args:
        run: stored as ``self.run``; logging goes through the global ``wandb``.
        viz_every_n_batches: logging interval in batches (default 5000).
    """

    def __init__(self, run, viz_every_n_batches = 5000):
        super().__init__()
        self.run = run
        self.viz_every_n_batches = viz_every_n_batches

    def on_train_batch_end(self, batch, logs=None):
        if batch % self.viz_every_n_batches != 0:
            return
        try:
            # `method` was previously never defined here (NameError in the log keys).
            method = self.model.get_env_config()["SSL_method"]
            feature_extractor = self.model.feature_extractor
            # New name: assigning to `real_world_images` inside this method made it
            # a local variable and raised UnboundLocalError on the cast itself.
            images_f32 = ops.cast(real_world_images, "float32")
            q_attention_weights, q_batch_merged = ssl_module.q_visualize(feature_extractor, images_f32, res,
                                                                        thresholding = False)
            q_attention_weights = np.array(q_attention_weights)  # batch, N_Q, res, res, 3
            q_batch_merged = np.array(q_batch_merged)  # batch, res, res, 3

            n_queries = q_attention_weights.shape[1]
            col = (["Original Image", "Original Label", "MergedMap"]
                   + [f"LearnableQuery{q + 1}" for q in range(n_queries)])

            visualize_data = []
            for idx, weights in enumerate(q_attention_weights):
                row = [wandb.Image(real_world_images[idx]),
                       labels_[idx],
                       wandb.Image(q_batch_merged[idx])]
                row += [wandb.Image(weights[q]) for q in range(n_queries)]
                visualize_data.append(row)

            tbl = wandb.Table(columns = col, data = visualize_data)
            if batch == 0:
                wandb.log({f"RW_ZeroBatch_{method}_result": tbl})
            else:
                wandb.log({f"RW_Batch{batch}_{method}_result": tbl})
            del feature_extractor, tbl
            tf.keras.backend.clear_session()
        except Exception as e:
            print("Error code in callback : ", e)

In [16]:
class RealWorldPatchViz(keras.callbacks.Callback):
    """Log the PCA patch visualisation (``ssl_module.pca_patch_viz``) of the real-world images.

    Every ``viz_every_n_batches`` batches (batch 0 included), a W&B table is logged
    with, per image: the image, its label, the merged image and the encoded-patch
    heatmap. Errors are printed, not raised, so a visualisation failure does not
    abort ``fit()``.

    Args:
        run: stored as ``self.run``; logging goes through the global ``wandb``.
        viz_every_n_batches: logging interval in batches (default 5000).
    """

    def __init__(self, run, viz_every_n_batches = 5000):
        super().__init__()
        self.run = run
        self.viz_every_n_batches = viz_every_n_batches

    def on_train_batch_end(self, batch, logs=None):
        if batch % self.viz_every_n_batches != 0:
            return
        try:
            feature_extractor = self.model.feature_extractor
            patch_heatmap, patch_merged_images = ssl_module.pca_patch_viz(feature_extractor, real_world_images)
            patch_heatmap = np.array(patch_heatmap)
            patch_merged_images = np.array(patch_merged_images)

            col = ["Original image", "Original Label", "Merged image", "Encoded Patches"]
            visualize_data = [
                [wandb.Image(real_world_images[idx]),
                 labels_[idx],
                 wandb.Image(m_img),
                 wandb.Image(patch_heatmap[idx])]
                for idx, m_img in enumerate(patch_merged_images)
            ]

            tbl = wandb.Table(columns = col, data = visualize_data)
            if batch == 0:
                wandb.log({"RW_ZeroBatch_Patch_result": tbl})
            else:
                wandb.log({f"RW_Batch{batch}_Patch_result": tbl})
            del feature_extractor, tbl
            tf.keras.backend.clear_session()
        except Exception as e:
            print("Error code in callback : ", e)

In [17]:
class SegViz(keras.callbacks.Callback):
    """Periodically log unsupervised-segmentation overlays to W&B.

    Active only for models whose ``get_env_config()["SSL_method"]`` is
    ``"UnsupSeg"`` or ``"MixedUnsupSeg"``. Every ``log_every`` batches it
    calls ``self.model.get_segments(images)`` and logs a table of
    (original image, label, segmentation overlay). Errors raised while
    building or logging the table are printed and swallowed, so that a
    visualisation failure does not abort training.

    Args:
        run: the active W&B run (kept for reference).
        images: indexable batch of images passed to ``get_segments``.
        labels: optional per-image labels; a placeholder string is logged
            when omitted.
        log_every: batch interval between visualisations. Defaults to
            10000, the previously hard-coded value.
    """

    SEG_METHODS = ("UnsupSeg", "MixedUnsupSeg")

    def __init__(self, run, images, labels = None, log_every = 10000):
        super().__init__()
        self.run = run
        self.images = images
        self.labels = labels
        self.log_every = log_every

    def on_train_batch_end(self, batch, logs=None):
        # Check the cheap frequency condition first. Previously
        # get_env_config() was queried on every training step, even when
        # nothing would be logged.
        if batch % self.log_every != 0:
            return
        method = self.model.get_env_config()["SSL_method"]
        if method not in self.SEG_METHODS:
            return
        try:
            _, superimposed_images = self.model.get_segments(self.images)
            col = ["Original Image", "Original Label", "Segmentation Result"]
            visualize_data = []
            for idx, sup_img in enumerate(superimposed_images):
                origin_img = wandb.Image(self.images[idx])
                lab = "Label not provided." if self.labels is None else self.labels[idx]
                visualize_data.append([origin_img, lab, wandb.Image(sup_img)])
            tbl = wandb.Table(columns = col, data = visualize_data)
            if batch == 0:
                wandb.log({f"Seg_ZeroBatch_{method}_result": tbl})
            else:
                wandb.log({f"Seg_MidEpoch_{method}_result": tbl})

            tf.keras.backend.clear_session()
        except Exception as e:
            print("Error code in Segmentation callback : ", e)

In [18]:
def run_exp(model, train_ds = train_ds, val_ds = val_ds, epochs = 10, note= None, exp_name = None):
    """Train an SSL trainer with the notebook's standard W&B logging and callbacks.

    Steps:
    1. Close any open W&B run and log in via ``wandb_config()``.
    2. Build a run config: ``model.get_env_config()`` updated with the
       notebook-level hyper-parameters, the optimizer config and the image
       encoder's FLOPs.
    3. Start a W&B run in project ``RadImageNet``.
    4. Call ``model.fit``.

    Note: the ``train_ds`` / ``val_ds`` defaults are bound to the notebook
    globals of the same name when this cell is executed, not at call time.

    Args:
        model: a compiled SSL trainer. It must expose ``get_env_config()``
            and either a ``feature_extractor`` attribute or a
            ``get_full_model(res=...)`` method.
        train_ds: training dataset passed to ``fit``.
        val_ds: validation dataset. When None, ``fit`` runs without
            validation.
        epochs: number of epochs of ``train_steps`` steps each.
        note: free-text W&B run note.
        exp_name: W&B run name.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    try:
        wandb.finish()
    except Exception:
        pass  # best effort: there may be no run to close

    wandb_config()
    configs = model.get_env_config()
    method = configs["SSL_method"]
    try:
        feature_extractor = model.feature_extractor
    except Exception:
        feature_extractor = model.get_full_model(res = res)

    if method in ['CLIP', "SigLIP", "SPARC"]:
        # Image-text trainers get one forward pass on two examples before
        # FLOPs are measured (presumably to build their weights).
        _ = model((example_images[:2], example_reports[:2]))

    try:
        feature_extractor_flops = get_flops(feature_extractor, [tf.random.normal([1,res,res,c])])
    except Exception:
        feature_extractor_flops = "Uncheck"  # FLOP counting is best effort
    del feature_extractor

    env_config = {"batch_size" : batch_size, "Patch size": patch_size,
                  "original resolution" : res, "local view resolution" : small_res,
                  "Training steps" : train_steps,
                  "Val steps" : val_steps,
                  "train cases" : train_cases,
                  "val cases" : val_cases,
                  "embed_dims" : embed_dims,
                  "Image resolution" : res,
                  "(Image) Encoder Flops(G)" : feature_extractor_flops,
                  "dtype" : keras.mixed_precision.dtype_policy(),
                  "Optimizer configs" : model.optimizer.get_config(),
                  "Multicrop N" : n_multicrop, "metaencoder depth" : depth, 'embedding dims' : embed_dims,
                  }
    configs.update(env_config)
    print(configs, "\n\n")

    run = wandb.init(project="RadImageNet",
                     entity="gongbungkim", config = configs, notes = note,
                     name = exp_name)
    wandb.run.log_code(".")

    pass_error = keras.callbacks.TerminateOnNaN()
    wb_callback = wandb.keras.WandbMetricsLogger(log_freq = 100)
    if isinstance(model, ssl_module.QNCLR):
        callbacks = [pass_error, wb_callback, ModelSaveCallback(f"RI_SSL_{method}", note),
                     QRealWorldViz(run),
                     TemperatureScheduler()]
    else:
        callbacks = [pass_error, wb_callback, ModelSaveCallback(f"RI_SSL_{method}", note),
                     TrainingViz(run),
                     RealWorldViz(run), RealWorldPatchViz(run),
                     SegViz(run, images = sample_img),
                     SegViz(run, images = real_world_images, labels = labels_),
                     TemperatureScheduler()]

    fit_kwargs = dict(steps_per_epoch = train_steps,
                      epochs = epochs,
                      verbose = 1,
                      callbacks = callbacks)
    if val_ds is not None:
        fit_kwargs.update(validation_data = val_ds, validation_steps = val_steps)
    hist = model.fit(train_ds, **fit_kwargs)
    return hist

In [19]:
# Learning-rate schedules used by the training cells below.
#
# cosine_decay (used by the MixedMIM cell): per the Keras CosineDecay docs, it
# warms up linearly from 1e-6 to 2e-4 over train_steps - int(0.3 * train_steps)
# steps (roughly 70% of one epoch), then cosine-decays over
# int(0.5 * train_steps) steps.
cosine_decay = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-6,
    warmup_target=2e-4,
    warmup_steps=train_steps - int(0.3 * train_steps),
    decay_steps=int(0.5 * train_steps),
    alpha=1e-5,
    name='CosineDecay',
)

# lr_schedule (default of ssl_train; used by the iBOT and DINO cells):
# 2e-4, multiplied by 0.75 after every full block of 20000 steps
# (staircase=True).
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=2e-4,
    decay_rate=0.75,
    decay_steps=20000,
    staircase=True,
)

In [20]:
def ssl_train(module, feature_extractor, learning_rate = lr_schedule,
              embed_dims = embed_dims, multiview = True, gradient_accumulation = None, use_ema = False,
             note = "", name = "",
             apply_barlow = False, apply_simclr = False,
             train_ds = None, epochs = 100):
    """Build, compile and train an SSL trainer through ``run_exp``.

    ``module`` is first called with ``apply_barlow`` / ``apply_simclr``. If
    that raises, the error is printed and construction is retried without
    those two flags, for trainers that do not accept them. A warning is
    printed if either flag was requested, since it is then not passed on.

    The trainer is compiled with Adam (clipnorm 0.5, optional gradient
    accumulation and EMA, ``jit_compile=False``) and trained without
    validation.

    Args:
        module: SSL trainer class (or factory) from ``ssl_module``.
        feature_extractor: backbone passed to ``module``.
        learning_rate: float or Keras LR schedule for Adam.
        embed_dims, multiview: forwarded to ``module``.
        gradient_accumulation: Adam ``gradient_accumulation_steps``
            (None disables accumulation).
        use_ema: forwarded to Adam's ``use_ema``.
        note, name: W&B run note and run name.
        apply_barlow, apply_simclr: optional trainer flags (see above).
        train_ds: training dataset. Defaults to the notebook global
            ``train_ds_multiview``, the previously hard-coded dataset.
        epochs: number of epochs. Defaults to 100, the previously
            hard-coded value.

    Returns:
        The ``History`` returned by ``run_exp`` (previously discarded).
    """
    try:
        ssl_trainer = module(feature_extractor, embed_dims = embed_dims, multiview = multiview,
                             apply_barlow = apply_barlow, apply_simclr = apply_simclr)
    except Exception as e:
        print("Error : ",e)
        if apply_barlow or apply_simclr:
            # Previously the requested flags were dropped silently.
            print(f"Warning: retrying without apply_barlow={apply_barlow}, "
                  f"apply_simclr={apply_simclr}; these flags are NOT applied.")
        ssl_trainer = module(feature_extractor, embed_dims = embed_dims, multiview = multiview)
    ssl_trainer.compile(optimizer = keras.optimizers.Adam(learning_rate = learning_rate,
                                                         clipnorm = 0.5,
                                                         gradient_accumulation_steps=gradient_accumulation,
                                                         use_ema = use_ema),
                        jit_compile = False
                      )

    if train_ds is None:
        train_ds = train_ds_multiview
    return run_exp(ssl_trainer, train_ds, None, epochs = epochs,
                   note = note, exp_name = name)

In [21]:
# Experiment switches read by the `if <flag>:` cells below (truthy = run that cell).
ibot = other = mim = qnclr = 0   # disabled experiments
sobel_other = 1                  # enables the Moco cell

In [22]:
if mim:
    # MixedMIM pre-training of the attention Metaformer on `train_ds_edge_masked`,
    # using AdamW with the warmup + cosine schedule and EMA weights.
    model_ = 'attention'
    vanilla_model = ssl_module.get_metaformer(
        model_,
        res = res,
        embed_dims = embed_dims,
        att_depth = depth,
        att_heads = heads,
        att_dims = att_dims,
        grayscale = grayscale,
        patch_size = patch_size,
        register_tokens = registers,
        pretrained_encoder = pretrained_encoder,
        return_patches = True,
    )
    vanilla_model.summary()

    ssl_trainer = ssl_module.MixedMIM(vanilla_model, grayscale = grayscale, patch_size = patch_size)
    mim_optimizer = keras.optimizers.AdamW(learning_rate = cosine_decay,
                                           clipnorm = 1.0,
                                           use_ema = True)
    ssl_trainer.compile(optimizer = mim_optimizer, jit_compile = False)

    configs = ssl_trainer.get_env_config()
    method = configs["SSL_method"]
    run_exp(ssl_trainer, train_ds_edge_masked, None,
            note = pretrained_note + "_" + model_,
            exp_name = f"SobelMerging_Patch{patch_size}_{method}_{model_}")

In [23]:
if qnclr:
    # QNCLR on the encoder-decoder model (with `use_mim` set), trained on
    # `train_ds_masked` with AdamW and 32-step gradient accumulation.
    vanilla_model = ssl_module.get_encdec_model(
        pretrained_encoder,
        res = res,
        att_dims = embed_dims,
        q_size = 8,
        encoder_trainable = True,
    )
    ssl_trainer = ssl_module.QNCLR(vanilla_model, embed_dims = embed_dims, t = 0.05)
    ssl_trainer.use_mim = True

    qnclr_optimizer = keras.optimizers.AdamW(learning_rate = 5e-5,
                                             clipnorm = 1.0,
                                             gradient_accumulation_steps = 32)
    ssl_trainer.compile(optimizer = qnclr_optimizer, jit_compile = False)

    method = "Q_NNCLR"
    run_exp(ssl_trainer, train_ds_masked, None,
            note = pretrained_note, exp_name = f"{method}_StrongAug")

In [None]:
if sobel_other:
    # Moco pre-training of the attention Metaformer on `train_ds_masked`
    # (SGD with momentum and weight decay).
    model_ = 'attention'
    # Backbone hyper-parameters shared by the vanilla and the optional dual model.
    metaformer_kwargs = dict(res = res, embed_dims = embed_dims,
                             att_depth = depth, att_heads = heads, att_dims = att_dims,
                             grayscale = grayscale, patch_size = patch_size,
                             register_tokens = registers,
                             return_patches = True)
    vanilla_model = ssl_module.get_metaformer(model_, pretrained_encoder = pretrained_encoder,
                                              **metaformer_kwargs)

    # The dual-encoder model used to be built (and summarised) unconditionally,
    # although only commented-out code used it. It is now built on request only.
    use_dual_encoder = False
    if use_dual_encoder:
        dual_model = ssl_module.get_metaformer(model_, pretrained_encoder = get_dual_encoder(),
                                               **metaformer_kwargs)
        dual_model.summary()

    # Sanity check: patch features of the real-world samples, clustered with
    # ssl_module.H_clustering and reshaped onto the square patch grid.
    # This uses local names only. The original re-bound the global
    # `embed_dims` here, silently changing what later cells (and run_exp's
    # logged config) read.
    feature_map = vanilla_model(real_world_images)[1]
    n_patch = feature_map.shape[1]
    w_ = int(round(n_patch ** 0.5))
    assert w_ * w_ == n_patch, f"Expected a square patch grid, got {n_patch} patches"
    clustering_output = ssl_module.H_clustering(n_clusters = 100)(feature_map)
    clustering_output = ops.reshape(clustering_output, [-1, w_, w_, 1])

    if use_dual_encoder:
        ssl_trainer = ssl_module.Moco(dual_model, use_dino = True)
    else:
        ssl_trainer = ssl_module.Moco(vanilla_model, use_dino = False, q_size = 4096)
    ssl_trainer.compile(optimizer = keras.optimizers.SGD(learning_rate = 5e-5,
                                                         momentum = 0.9,
                                                         weight_decay = 0.0001,
                                                         ),
                        jit_compile = False,
                        )
    configs = ssl_trainer.get_env_config()
    method = configs["SSL_method"]

    run_exp(ssl_trainer, train_ds_masked, None,
            note = pretrained_note+"_"+model_, exp_name = f"{method}_{model_}_StrongAug_SimplerNCLR")


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
{'feature_extractor_name': 'ConvNeXtTiny_Metaformer_res512_type_attention', 'embed_dims': 512, 'SSL_method': 'Moco', 'Queue size': 32768, 'temperature': 0.1, 'batch_size': 8, 'Patch size': 32, 'original resolution': 512, 'local view resolution': 64, 'Training steps': 162904, 'Val steps': 20474, 'train cases': 1303237, 'val cases': 163796, 'Image resolution': 512, '(Image) Encoder Flops(G)': 47.843845173, 'dtype': <DTypePolicy "float32">, 'Optimizer configs': {'name': 'SGD', 'learning_rate': 4.999999873689376e-05, 'weight_decay': None, 'clipnorm': 1.0, 'global_clipnorm': None, 'clipvalue': None, 'use_ema': False, 'ema_momentum': 0.99, 'ema_overwrite_frequency': None, 'loss_scale_factor': None, 'gradient_accumulation_steps': None, 'momentum': 0.0, 'nesterov': False}, 'Multicrop N': 2, 'metaencoder depth': 0, 'embedding dims': 512} 




[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mgongbungkim[0m. Use [1m`wandb login --relogin`[0m to force relogin




Epoch 1/10
Using Raw model with Att pooling 
 Possible error: 
 'Functional' object has no attribute 'get_full_model'
(16, 8, 256) tf.Tensor(16, shape=(), dtype=int32) 8
Using Raw model with Att pooling 
 Possible error: 
 'Functional' object has no attribute 'get_full_model'
(22, 8, 256) tf.Tensor(16, shape=(), dtype=int32) 8
['Original image', 'Original Label', 'Merged image', 'Encoded Patches']
[1m  5000/162904[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m32:14:02[0m 735ms/step - Attention_weight_entropy: 0.0217 - Moco_loss: 9.6725 - Q_DINO_loss: 8.4864 - Total_loss: 9.6725Using Raw model with Att pooling 
 Possible error: 
 'Functional' object has no attribute 'get_full_model'
(22, 8, 256) tf.Tensor(16, shape=(), dtype=int32) 8
['Original image', 'Original Label', 'Merged image', 'Encoded Patches']
[1m  6298/162904[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m32:14:52[0m 741ms/step - Attention_weight_entropy: 0.0217 - Moco_loss: 10.6622 - Q_DINO_loss: 8.4688 - Total_loss: 10.6622

In [None]:
if ibot:
    # iBOT on the multiview dataset for 100 epochs (grayscale input,
    # 2048-d embeddings), AdamW with the exponential-decay schedule and EMA.
    ssl_trainer = ssl_module.iBOT(
        att_depth = depth,
        att_dims = att_dims,
        att_heads = heads,
        embed_dims = 2048,
        patch_size = patch_size,
        multiview = True,
        apply_simclr = False,
        grayscale = True,
    )
    ibot_optimizer = keras.optimizers.AdamW(learning_rate = lr_schedule,
                                            clipnorm = 1.0,
                                            use_ema = True)
    ssl_trainer.compile(optimizer = ibot_optimizer, jit_compile = False)
    run_exp(ssl_trainer, train_ds_multiview, None, epochs = 100,
            note = "+ NEW aug, New Patching", exp_name = "iBOT_VanillaViT")

In [None]:
if other:
    # DINO (2-view) on a gMLP Metaformer via ssl_train.
    # The grayscale setting must match whether a pretrained encoder is used.
    model_ = 'gMLP'
    if pretrained_encoder is not None:
        note = f"With pretrained {pretrained_encoder.name}"
        assert grayscale is False, "If using pretrained network, make sure [grayscale = False]"
    else:
        note = "From Scratch"
        assert grayscale is True, "If building from scratch, make sure [grayscale = True]"

    vanilla_model = ssl_module.get_metaformer(
        model_,
        res = res,
        embed_dims = embed_dims,
        att_depth = depth,
        att_heads = heads,
        att_dims = att_dims,
        grayscale = grayscale,
        patch_size = patch_size,
        register_tokens = 4,
        pretrained_encoder = pretrained_encoder,
    )
    ssl_train(ssl_module.DINO, vanilla_model,
              note = note + " / 2-view",
              name = f"DINO_{model_}_reg",
              learning_rate = lr_schedule,
              multiview = False,
              gradient_accumulation = 32)