In [1]:
import pandas as pd
import numpy as np
import os, sys
import random
import pydicom

from sklearn.manifold import TSNE
import re
import matplotlib.pyplot as plt
import matplotlib.cm as cm
seed = 2024

import warnings
warnings.filterwarnings("ignore")

# ML tools 
sys.path.append("/kaggle/input/kimm-keras-image-model-repository"
               )

import tensorflow as tf
import keras# ; keras.config.set_dtype_policy("mixed_float16")
import kimm
import keras_cv
import keras_nlp

import cv2
from skimage.io import imread
keras.utils.set_random_seed(seed)
import tensorflow_io as tfio
from kaggle_datasets import KaggleDatasets
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import tensorflow_decision_forests as tfdf

print(f"Tensorflow version : {tf.__version__}")
try:
    print(f"Keras version : {keras.__version__}")
except:
    pass

from keras import Input, Model, ops
from keras.models import load_model

from keras.layers import Conv2D, DepthwiseConv2D, Dense, Activation, BatchNormalization, LayerNormalization, MultiHeadAttention, Embedding, Subtract, Add, Multiply, GlobalAveragePooling2D, GlobalAveragePooling1D, LayerNormalization
from keras.utils import load_img, img_to_array
from keras.applications import *
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.model_selection import train_test_split
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
from tqdm.notebook import tqdm
import wandb
#from wandb.keras import WandbCallback, WandbModelCheckpoint, WandbMetricsLogger
def wandb_config():
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    try:
        secret_value_0 = user_secrets.get_secret("__gcloud_sdk_auth__")
        secret_value_1 = user_secrets.get_secret("huggingface_key")
        secret_value_2 = user_secrets.get_secret("wandb_key")
        !wandb login $secret_value_2
    except:
        secret_value_0 = user_secrets.get_secret("huggingface_key")
        secret_value_1 = user_secrets.get_secret("wandb_key")
        !wandb login $secret_value_1
    

res = int(1.5*256)
small_res = 64
batch_size = 8
embed_dims = 1024
n_multicrop = 4

def auto_select_accelerator():
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
        print("Running on TPU:", tpu.master())
    except ValueError:
        tpu = False
        strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines
    print(f"Running on {strategy.num_replicas_in_sync} replicas")
    
    return tpu, strategy

tpu, strategy = auto_select_accelerator()
batch_size = strategy.num_replicas_in_sync * batch_size
print('batch size', batch_size)


2024-06-19 04:33:24.181312: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-19 04:33:24.181435: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-19 04:33:24.307312: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Tensorflow version : 2.15.0
Keras version : 3.3.3
Running on 1 replicas
batch size 8


In [2]:
import ssl_module
from ssl_module import get_map_fn, get_gcvit_configs, get_flops, att_visualize, get_full_model, AttentionPooling, BarlowModel, VICRegModel, Moco, SimSiam, CLIP, SigLIP
import nas_ftp_module
from nas_ftp_module import upload_file, download_file
ssl_module.available_models()

Requirements loaded, keras : v3.3.3, Tensorflow : v2.15.0
RandAug Component in this SSL module :  ['random_contrast', 'random_brightness', 'random_shear', 'random_shear_1', 'random_translation', 'random_translation_1']


{'models_from_kimm': ['ConvMixer1024D20',
  'ConvMixer1536D20',
  'ConvMixer736D32',
  'ConvNeXtAtto',
  'ConvNeXtBase',
  'ConvNeXtFemto',
  'ConvNeXtLarge',
  'ConvNeXtNano',
  'ConvNeXtPico',
  'ConvNeXtSmall',
  'ConvNeXtTiny',
  'ConvNeXtXLarge',
  'DenseNet121',
  'DenseNet161',
  'DenseNet169',
  'DenseNet201',
  'EfficientNetB0',
  'EfficientNetB1',
  'EfficientNetB2',
  'EfficientNetB3',
  'EfficientNetB4',
  'EfficientNetB5',
  'EfficientNetB6',
  'EfficientNetB7',
  'EfficientNetLiteB0',
  'EfficientNetLiteB1',
  'EfficientNetLiteB2',
  'EfficientNetLiteB3',
  'EfficientNetLiteB4',
  'EfficientNetV2B0',
  'EfficientNetV2B1',
  'EfficientNetV2B2',
  'EfficientNetV2B3',
  'EfficientNetV2L',
  'EfficientNetV2M',
  'EfficientNetV2S',
  'EfficientNetV2XL',
  'GhostNet050',
  'GhostNet100',
  'GhostNet100V2',
  'GhostNet130',
  'GhostNet130V2',
  'GhostNet160V2',
  'HGNetBase',
  'HGNetSmall',
  'HGNetTiny',
  'HGNetV2B0',
  'HGNetV2B1',
  'HGNetV2B2',
  'HGNetV2B3',
  'HGNetV2B4'

- radimagenet tfrecord key : image, label
- nih cxr tfrecord key : image_raw, label

# RadImageNet decoding

In [3]:
def _parse_tfrecord(res = res):
    def parse_tfrecord(tfrecord):
        features = {'image': tf.io.FixedLenFeature([], tf.string),
                    'label': tf.io.FixedLenFeature([], tf.int64),
                    }
        x = tf.io.parse_single_example(tfrecord, features)
        image_train = tf.image.decode_jpeg(x['image'], channels=1)
        image_train = _transform_images(res = res)(image_train)
        label = tf.cast(x["label"], tf.int32)
        return (image_train, label)
    
    return parse_tfrecord


def _transform_images(res = res):
    def transform_images(x_train):
        x_train = tf.image.resize_with_pad(x_train, res, res, antialias = True)
        x_train = tf.cast(x_train, tf.uint8)
        return x_train
    return transform_images

def load_tfrecord_dataset(tfrecord_name, res = res, batch_size = batch_size, shuffle=True, buffer_size=10240):
    """load dataset from tfrecord"""
    raw_dataset = tf.data.TFRecordDataset(tfrecord_name, compression_type = "GZIP")
    raw_dataset = raw_dataset.repeat()
    if shuffle:
        raw_dataset = raw_dataset.shuffle(buffer_size=buffer_size)
    dataset = raw_dataset.map(
        _parse_tfrecord(),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

train_radimagenet_ds = load_tfrecord_dataset("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RagImageNet_Train_GZIP.tfrecord")
val_ds = load_tfrecord_dataset("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RagImageNet_Test_GZIP.tfrecord")
for img, label in val_ds.take(1):
    print(ops.shape(img))
    print(label)

(8, 384, 384, 1)
tf.Tensor([114  30 114 157  83 146 107 151], shape=(8,), dtype=int32)


# NIH CXR decoding

In [4]:
def _parse_tfrecord(res = res):
    def parse_tfrecord(tfrecord):
        features = {'image_raw': tf.io.FixedLenFeature([], tf.string),
                    'label': tf.io.FixedLenFeature([], tf.int64),
                    }
        x = tf.io.parse_single_example(tfrecord, features)
        image_train = tf.image.decode_jpeg(x['image_raw'], channels=1)
        image_train = _transform_images(res = res)(image_train)
        label = tf.cast(x["label"], tf.int32)
        return (image_train, label)
    
    return parse_tfrecord


def _transform_images(res = res):
    def transform_images(x_train):
        x_train = tf.image.resize_with_pad(x_train, res, res, antialias = True)
        x_train = tf.cast(x_train, tf.uint8)
        return x_train
    return transform_images

def load_tfrecord_dataset(tfrecord_name, res = res, batch_size = batch_size, shuffle=True, buffer_size=10240):
    """load dataset from tfrecord"""
    raw_dataset = tf.data.TFRecordDataset(tfrecord_name)
    raw_dataset = raw_dataset.repeat()
    if shuffle:
        raw_dataset = raw_dataset.shuffle(buffer_size=buffer_size)
    dataset = raw_dataset.map(
        _parse_tfrecord(),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    if batch_size:
        dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

nih_cxr_ds = load_tfrecord_dataset("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/nih_cxr_images.tfrecords")

# Merging 2 datasets

In [5]:
train_ds = tf.data.Dataset.sample_from_datasets([train_radimagenet_ds.unbatch(), nih_cxr_ds.unbatch()], weights = [0.75, 0.25]).batch(batch_size).repeat().prefetch(tf.data.AUTOTUNE)
val_ds_ = tf.data.Dataset.sample_from_datasets([train_radimagenet_ds.unbatch(), nih_cxr_ds.unbatch()], weights = [0.75, 0.25]).batch(32).prefetch(tf.data.AUTOTUNE)
# train data curation
for images, labels in val_ds_.take(1):
    sample_img = images
    labels = labels
del val_ds_

# Convert supervised dataset into SSL dataset

In [6]:
multiview_fn = get_map_fn(res = res, input_type = "supervised", output_type = "ssl",
                         n_view = n_multicrop)
two_view_fn = get_map_fn(res = res, input_type = "supervised", output_type = "ssl",
                         n_view = 2)
train_ds_multiview = train_ds.unbatch().map(multiview_fn, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size).prefetch(tf.data.AUTOTUNE)
val_ds_multiview = val_ds.unbatch().map(multiview_fn, num_parallel_calls=tf.data.AUTOTUNE).batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [7]:
for test_set in val_ds_multiview.take(1):
    test_set = test_set

----------
# Experiment - helper functions

In [8]:
df_train_rad = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RadImgNet_train.csv")
df_train_nih = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/nih_trainval_split.csv"
                          )
df_val_rad = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/RadImgNet_test.csv")
df_val_nih = pd.read_csv("/kaggle/input/radimagenet-and-nih-cxr-dataset-tfrecord/nih_test_split.csv")


train_cases = len(df_train_rad) + len(df_train_nih) + len(df_val_nih)
val_cases = len(df_val_rad)

train_steps = train_cases//batch_size
val_steps = val_cases//batch_size
print(f"Total train cases : {train_cases}, validation cases : {val_cases}")

Total train cases : 1303237, validation cases : 163796


In [9]:
class ModelSaveCallback(keras.callbacks.Callback):
    def __init__(self, exp_name, **kwargs):
        super().__init__(**kwargs)
        self.exp_name = exp_name
    def on_epoch_end(self, epoch, logs=None):
        save_dir = "/kaggle/working/" ; target_dir = '/kaggle/working/model_save'
        os.makedirs(target_dir, exist_ok = True)
        if (epoch % 1 == 0):
            try:
                print("\nModel Saving to local notebook...")
                file_name = f"{self.exp_name}_{self.model.name}_keras_v3_Epoch{epoch}.keras"
                filepath = os.path.join(target_dir, file_name)
                saved_dir = self.model.save(filepath, overwrite=True)
                if (epoch+1) % 5 == 0:
                    print("\nModel Uploading to NAS...")
                    upload_file(file_name, filepath)
                    print("\nModel Saved to Local NAS")
            except Exception as e: 
                print('Model Saving Error:\n', e)
    def on_train_batch_end(self, batch, logs=None):
        save_dir = "/kaggle/working/" ; target_dir = '/kaggle/working/model_save'
        os.makedirs(target_dir, exist_ok = True)
        if (batch % 50000 == 0) and (batch != 0): 
            try:
                print("\nModel Saving to local notebook...")
                file_name = f"{self.exp_name}_{self.model.name}_keras_v3_Batch{batch}.keras"
                filepath = os.path.join(target_dir, file_name)
                saved_dir = self.model.save(filepath, overwrite=True)
                if (batch % 10000 == 0):
                    print("\nModel Uploading to NAS...")
                    upload_file(file_name, filepath)
                    print("\nModel Saved to Local NAS")
            except Exception as e: 
                print('Model Saving Error:\n', e)
                
                
class TrainingViz(keras.callbacks.Callback):
    def __init__(self, run):
        super().__init__()
        self.run = run
    def on_epoch_end(self, epoch, logs=None):
        try:
            configs = self.model.get_config() ; method = configs["SSL_method"]
            if method in ["CLIP" , "SigLIP", "SPARC"]:
                feature_extractor = self.model
            else:
                try:
                    feature_extractor = self.model.feature_extractor
                except:
                    feature_extractor = self.model.get_full_model(res = res)
            viz_weights = ssl_module.att_visualize(feature_extractor, sample_img, res,
                                                  thresholding = True)
            viz_weights = np.array(viz_weights) #batch, heads, res, res, 3
            heads = viz_weights.shape[1]
            origin = ["Original Image"]
            col = [f"Head{idx + 1}" for idx in range(heads)]
            col = origin + col

            visualize_data = []
            for idx, weights in enumerate(viz_weights):
                origin_img = [wandb.Image(sample_img[idx])]
                tmp = [wandb.Image(weights[idx]) for idx in range(heads)]
                tmp = origin_img + tmp
                visualize_data.append(tmp)
                del tmp, origin_img
            tbl = wandb.Table(columns = col, data = visualize_data)
            wandb.log({f"Epoch{epoch+1}_{method}_result": tbl})
            del feature_extractor, tbl
            tf.keras.backend.clear_session()
        except Exception as e: 
                print('Model Saving Error:\n', e)
        
    def on_train_batch_end(self, batch, logs=None):
        if (batch % (10000) == 0) : 
            try:
                configs = self.model.get_config() ; method = configs["SSL_method"]
                if method in ["CLIP" , "SigLIP", "SPARC"]:
                    feature_extractor = self.model
                else:
                    try:
                        feature_extractor = self.model.feature_extractor
                    except:
                        feature_extractor = self.model.get_full_model(res = res)
                viz_weights = ssl_module.att_visualize(feature_extractor, sample_img, res,
                                                      thresholding = True)
                viz_weights = np.array(viz_weights) #batch, heads, res, res, 3
                heads = viz_weights.shape[1]
                origin = ["Original Image"]
                col = [f"Head{idx + 1}" for idx in range(heads)]
                col = origin + col
                visualize_data = []
                for idx, weights in enumerate(viz_weights):
                    origin_img = [wandb.Image(sample_img[idx])]
                    tmp = [wandb.Image(weights[idx]) for idx in range(heads)]
                    tmp = origin_img + tmp
                    visualize_data.append(tmp)
                    del tmp, origin_img
                tbl = wandb.Table(columns = col, data = visualize_data)
                if batch == 0:
                    wandb.log({f"ZeroBatch_{method}_result": tbl})
                else:
                    wandb.log({f"MidEpoch_{method}_result": tbl})
                del feature_extractor, tbl
                tf.keras.backend.clear_session()
            except Exception as e:
                print("Error code in callback : ", e)
           
        else:
            pass

In [10]:
def run_exp(model, train_ds = train_ds, val_ds = val_ds, epochs = 10, note= None, exp_name = None):
    try:
        wandb.finish()
    except:
        pass
    
    if True :
        wandb_config()
        configs = model.get_config()
        method = configs["SSL_method"]
        try:
            feature_extractor = model.feature_extractor
        except:
            feature_extractor = model.get_full_model(res = res)
        
        if method in ['CLIP', "SigLIP", "SPARC"]:
            _ = model((example_images[:2], example_reports[:2]))
        else:
            _ = model(test_set)
        feature_extractor_flops = get_flops(feature_extractor, [tf.random.normal([1,res,res,1])])
        del feature_extractor
        
        env_config = {"batch_size" : batch_size, "original resolution" : res, "local view resolution" : small_res,
                     "Training steps" : train_steps,
                     "Val steps" : val_steps,
                     "train cases" : train_cases,
                     "val cases" : val_cases,
                     "embed_dims" : embed_dims,
                     "Image resolution" : res,
                     "(Image) Encoder Flops(G)" : feature_extractor_flops,
                     "dtype" : keras.mixed_precision.dtype_policy(),
                      "Optimizer configs" : model.optimizer.get_config(),
                      "Multicrop N" : n_multicrop,
                     }
        configs.update(env_config)
        
        wd = "/kaggle/working/"
        file_name = os.path.join(wd, f"{method}_radimgnet_mini.keras")
        print(configs, "\n\n")
        model.summary()
        run = wandb.init(project="RadImageNet", 
                         entity="gongbungkim", config = configs, notes = note,
                        name = exp_name)

        pass_error = keras.callbacks.TerminateOnNaN()
        wb_callback = wandb.keras.WandbMetricsLogger(log_freq = 100)
        
        callbacks = [pass_error, wb_callback, ModelSaveCallback(f"RI_SSL_{method}"), 
                     TrainingViz(run)]
        if val_ds is not None:
            hist = model.fit(train_ds, 
                             steps_per_epoch = train_steps, 
                             epochs = epochs, 
                             validation_data = val_ds, 
                             validation_steps = val_steps, 
                             verbose = 1,
                             callbacks = callbacks)
        else:
            hist = model.fit(train_ds, 
                         steps_per_epoch = train_steps, 
                         epochs = epochs, 
                         verbose = 1,
                         callbacks = callbacks)
    return hist

In [11]:
cosine_decay = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate = 1e-6,
    decay_steps = int(0.5*train_steps),
    alpha=1e-5,
    name='CosineDecay',
    warmup_target=2e-4,
    warmup_steps=train_steps - int(0.3*train_steps)
)

lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate = 2e-4,
    decay_steps=20000,
    decay_rate=0.75,
    staircase=True)

In [12]:
def ssl_train(module, feature_extractor, learning_rate = lr_schedule,
              embed_dims = embed_dims, multiview = True, gradient_accumulation = None, use_ema = False,
             note = "", name = "",
             apply_barlow = False, apply_simclr = False):
    try:
        ssl_trainer = module(feature_extractor, embed_dims = 2048, multiview = multiview,
                            apply_barlow = apply_barlow, apply_simclr = apply_simclr)
    except Exception as e:
        print("Error : ",e)
        ssl_trainer = module(feature_extractor, embed_dims = 2048, multiview = multiview)
    ssl_trainer.compile(optimizer = keras.optimizers.Adam(learning_rate = learning_rate,
                                                         clipnorm = 0.5,
                                                         #amsgrad = True,
                                                           gradient_accumulation_steps=gradient_accumulation,
                                                         use_ema = use_ema),
                        jit_compile = False
                      )
    
    run_exp(ssl_trainer, train_ds_multiview, None, epochs = 100,
       note = "Without validation d/t lack of resources" + note, exp_name = name)

In [13]:
ibot = 0
other = 1

depth = 8
heads = 8
att_dims = heads * 64
patch_size = 24 #16, 24, 32

In [14]:
if ibot:
    

    ssl_trainer = ssl_module.iBOT(att_depth = depth, att_dims = att_dims, att_heads = heads,
                                  embed_dims = 2048, patch_size = patch_size,

                                  multiview = True, apply_simclr = False,
                                  grayscale = True
                                 )
    ssl_trainer.compile(optimizer = keras.optimizers.AdamW(learning_rate = lr_schedule,
                                                         clipnorm = 1.0, use_ema = True),
                       jit_compile = False)
    run_exp(ssl_trainer, train_ds_multiview, None, epochs = 100,
           note = "+ NEW aug, New Patching", exp_name = "iBOT_VanillaViT")

In [15]:
if other:
    model_ = 'attention'
    vanilla_model = ssl_module.get_metaformer(model_, res = res, embed_dims = 512, 
                                              att_depth = depth, att_heads = heads,att_dims = att_dims,
                                              grayscale = True, patch_size = patch_size, register_tokens = 2)
    ssl_train(ssl_module.VICRegModel, vanilla_model, 
             note = "+ register early, NEW aug, New Patching",
             name = f"VICReg_{model_}_reg",
             #apply_barlow = 0, apply_simclr = 0,
             learning_rate = lr_schedule, use_ema = True,
             gradient_accumulation = 16)

Error :  Unrecognized keyword arguments passed to VICRegModel: {'apply_barlow': False, 'apply_simclr': False}
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
{'feature_extractor_name': 'Metaformer_res384_type_attention', 'embed_dims': 1024, 'Multiview(>2)': True, 'Variance_coefficient': 20, 'Invariance_coefficient': 20, 'Covariance_coefficient': 1, 'Variance_gamma': 5.0, 'SSL_method': 'VICReg', 'Linear Probe': False, 'N_Categories': 0, 'Probe Activation': 'NA', 'batch_size': 8, 'original resolution': 384, 'local view resolution': 64, 'Training steps': 162904, 'Val steps': 20474, 'train cases': 1303237, 'val cases': 163796, 'Image resolution': 384, '(Image) Encoder Flops(G)': 45.607580264, 'dtype': <FloatDTypePolicy "float32">, 'Optimizer configs': {'name': 'adam', 'learning_rate': {'module': 'keras.optimizers.schedules', 'class_name': 'ExponentialDecay', 'config': {'initial_learning_rate': 0.0002, 'decay_steps': 20000, 'decay_rate': 0.75, 'staircase

[34m[1mwandb[0m: Currently logged in as: [33mgongbungkim[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: wandb version 0.17.2 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
[34m[1mwandb[0m: Tracking run with wandb version 0.17.0
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240619_043511-6p45tj39[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mVICReg_attention_reg[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/gongbungkim/RadImageNet[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/gongbungkim/RadImageNet/runs/6p45tj39[0m


Epoch 1/100


[34m[1mwandb[0m: [32m[41mERROR[0m Unable to log learning rate.


Using Raw model with Att pooling 
 Possible error: 
 'Functional' object has no attribute 'get_full_model'


I0000 00:00:1718771795.322026      69 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
W0000 00:00:1718771795.344470      69 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update


[1m  1552/162904[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m46:36:10[0m 1s/step - covariance_loss: 5.6187e-04 - invariance_loss: 0.0180 - loss: 199.2760 - variance_loss: 199.2573Batch 1552: Invalid loss, terminating training
[1m  1553/162904[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m46:36:14[0m 1s/step - covariance_loss: nan - invariance_loss: nan - loss: nan - variance_loss: nan                    
Model Saving to local notebook...
Using Raw model with Att pooling 
 Possible error: 
 'Functional' object has no attribute 'get_full_model'
[1m162904/162904[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1728s[0m 10ms/step - covariance_loss: nan - invariance_loss: nan - loss: nan - variance_loss: nan 
