In [1]:
import pandas as pd
import numpy as np
import random
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
seed = 2024

import warnings
warnings.filterwarnings("ignore")

# ML tools 

import tensorflow as tf
import keras #; keras.config.set_dtype_policy("mixed_float16")
import keras_nlp
import keras_cv
from keras import ops

keras.utils.set_random_seed(seed)

import cv2
import tensorflow_io as tfio
from kaggle_datasets import KaggleDatasets
import tensorflow_datasets as tfds

from keras import Input, Model, layers
from keras.models import load_model
from keras.layers import Layer
from keras.layers import Conv2D, DepthwiseConv2D, Dense, Activation, BatchNormalization, LayerNormalization, MultiHeadAttention, Embedding, Subtract, Add, Multiply, GlobalAveragePooling2D, GlobalAveragePooling1D, LayerNormalization
from keras.preprocessing.image import load_img, img_to_array
from keras.applications import *
import os, sys
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.model_selection import train_test_split
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
from tqdm.notebook import tqdm
import wandb
print(f"Requirements loaded, keras : v{keras.__version__}, Tensorflow : v{tf.__version__}")

2024-07-17 01:31:41.221056: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-17 01:31:41.221189: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-17 01:31:41.338335: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Requirements loaded, keras : v3.4.1, Tensorflow : v2.15.0


# Setting hyperparameters

In [2]:
# hyperparameters about images
res = 384
batch_size = 16
# hyperparameters about text
seq_len = 64
# hyperparameters about cross attentive decoder
att_depth = 4
att_heads = 8
att_dims = att_heads * 24
use_bias = False

train_cases = 70108
val_cases = 9972
train_steps = train_cases//batch_size
val_steps = val_cases//batch_size

train_dataset_dir = '/kaggle/input/roco-v2-tfrecord-dataset/Tot70108cases_RoCoV2_radiology_train_GZIP.tfrecord'
val_dataset_dir = '/kaggle/input/roco-v2-tfrecord-dataset/Tot9972cases_RoCoV2_radiology_test_GZIP.tfrecord'

# Parsing Dataset

In [3]:
def _parse_tfrecord(c, res = res):
    def parse_tfrecord(tfrecord):
        features = {'image': tf.io.FixedLenFeature([], tf.string),
                    'report': tf.io.FixedLenFeature([], tf.string),
                    }
        x = tf.io.parse_single_example(tfrecord, features)
        image_train = tf.image.decode_jpeg(x['image'], channels=c)
        image_train = _transform_images(res = res)(image_train)
        report = tf.cast(x["report"], tf.string)
        return image_train, report
    
    return parse_tfrecord


def _transform_images(res = res):
    def transform_images(x_train):
        x_train = tf.image.resize_with_pad(x_train, res, res, antialias = True)
        x_train = tf.cast(x_train, tf.uint8)
        return x_train
    return transform_images

def load_tfrecord_dataset(tfrecord_name, res = res, batch_size = batch_size, shuffle=True, buffer_size=10240, grayscale = False):
    """load dataset from tfrecord"""
    c = 1 if grayscale else 3
    raw_dataset = tf.data.TFRecordDataset(tfrecord_name, compression_type = "GZIP")
    raw_dataset = raw_dataset.repeat()
    if shuffle:
        raw_dataset = raw_dataset.shuffle(buffer_size=buffer_size)
    dataset = raw_dataset.map(
        _parse_tfrecord(c = c, res = res),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

train_ds = load_tfrecord_dataset(train_dataset_dir)
val_ds = load_tfrecord_dataset(val_dataset_dir) #image, report 2 outputs

for a, b in val_ds.take(1):
    val_images = a
    val_texts = b

# Preparing Natural language decoder (pretrained on plain texts)

In [4]:
gpt_freeset = "gpt2_base_en"
text_only_decoder = keras_nlp.models.Backbone.from_preset(
    gpt_freeset,
    trainable=False,
)
preprocessor = keras_nlp.models.GPT2Preprocessor.from_preset(
    gpt_freeset,
    sequence_length=seq_len,
    add_start_token = False
)
preprocessor.trainable = False
text_only_decoder.enable_lora(8)
text_only_decoder.summary()

Attaching 'model.safetensors' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'model.safetensors.index.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'model.safetensors' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'model.safetensors.index.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Ka

In [5]:
#according to summarized GPT2 model structure:
text_embed_dims = 768
n_vocab = preprocessor.tokenizer.vocabulary_size()
pad_token = preprocessor.tokenizer.pad_token_id
start_packer = keras_nlp.layers.StartEndPacker(
    sequence_length=seq_len,
    start_value=None,
)
prompt_words_ = preprocessor.tokenizer(["This image shows"]) ; start_word_idx = len(prompt_words_[0]) 
prompt = start_packer(prompt_words_)

# Modelling
- building a single model that functions as a testbed;
    - image encoder, text encoder(GPT2), cross-attentive attention layer로 구성
    - 모델은 image와 raw text를 input으로 받고
        - image -> encoder -> [CLS_token, encoded_patches, attention_weight] 3개의 output return
        - raw text -> GPT2 -> encoded_text & get mask
    - cross_attended_text = TransformerDecoder(query = raw_text, key = encoded_patches, value = encoded_patches, mask = mask) -> Dense -> Perplexity, accuracy 측정
    - 위 과정과 cls_token 및 pooled encoded_text의 batchwise cross-correlation 구해서 contrastive accuracy 구하기

In [6]:
class MedicalCaptioner(keras.Model): 
    def __init__(self, image_encoder, 
                 preprocessor, text_encoder,
                 att_heads = att_heads, att_dims = att_dims, att_depth = att_depth,
                 **kwargs):
        super().__init__(**kwargs)
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder ; 
        self.text_preprocessor = preprocessor ; self.text_preprocessor.trainable = False
        
        self.mlp_text = keras.layers.Dense(units = n_vocab, activation = "softmax", name = "VocabClassifier")
        self.att_heads = att_heads
        self.att_dims = att_dims
        self.att_depth = att_depth
        self.cross_attention_layers = [keras_nlp.layers.TransformerDecoder(att_dims, att_heads, name = f"CrossMHADecoder{i+1}",
                                                                          dropout = 0.2, activation = "gelu") for i in range(att_depth)]
        self.compute_perplexity = keras_nlp.metrics.Perplexity(mask_token_id=pad_token)
        self.compute_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.compute_softmax_loss = keras.losses.SparseCategoricalCrossentropy(ignore_class = pad_token, reduction = None)
    def get_config(self):
        return {"Image encoder name": self.image_encoder.name,
               "Text encoder name" : self.text_encoder.name,
               "cross attention heads" : att_heads, "intermediate dims" : self.att_dims, "cross attention depth" : self.att_depth
               }
    def call(self, image, text):
        if len(self.image_encoder.outputs) == 2 :
            image_token, encoded_patches = self.image_encoder(image)
        elif len(self.image_encoder.outputs) == 3 :
            image_token, encoded_patches, image_attention_weights = self.image_encoder(image)
        elif len(self.image_encoder.outputs) == 1:
            encoded_patches = self.image_encoder(image)
        
        if len(ops.shape(encoded_patches)) == 4:
            _, w, h, dims = ops.shape(encoded_patches)
            encoded_patches = ops.reshape(encoded_patches, [-1, w*h, dims])

        preprocessed_text = self.text_preprocessor(text) ; text_mask = preprocessed_text["padding_mask"]
        original_token = preprocessed_text["token_ids"]
        encoded_text = self.text_encoder(preprocessed_text)
        
        # Cross-Attention
        for idx, decoder in enumerate(self.cross_attention_layers):
            encoded_text = decoder(decoder_sequence = encoded_text,
                                  encoder_sequence = encoded_patches,
                                  decoder_padding_mask = text_mask)
        vocab_p = self.mlp_text(encoded_text)
        
        return vocab_p, original_token, text_mask
    
    def get_next_proba_fn(self, image, prompt):
        
        def next(prompt, cache, index):
            prompt = self.text_preprocessor.tokenizer.detokenize(prompt)
            vocab_proba = self(image, prompt)[0]
            logits = vocab_proba[:, index - 1, :]
            # Ignore hidden states for now; only needed for contrastive search.
            hidden_states = None
            return logits, hidden_states, cache
        return next
    def infer(self, image):
        if len(ops.shape(image)) == 3:
            image = image[tf.newaxis, ...]
        
        next_fn = self.get_next_proba_fn(image = image, prompt = prompt)
        
        greedy_sampler = keras_nlp.samplers.GreedySampler()
        nuc_sampler = keras_nlp.samplers.TopPSampler(p=0.5, k = 10)
        
        greedy_tokens = greedy_sampler(next=next_fn,
                                prompt=prompt,
                                index=start_word_idx)
        
        nuc_tokens = nuc_sampler(next=next_fn,
                                prompt=prompt,
                                index=start_word_idx)
        
        greedy_words, nuc_words = self.text_preprocessor.tokenizer.detokenize(greedy_tokens), self.text_preprocessor.tokenizer.detokenize(nuc_tokens)
        greedy_words, nuc_words = greedy_words.numpy(), nuc_words.numpy()
        try:
            greedy_words = [w.decode() for w in greedy_words]
        except:
            pass
        try:
            nuc_words = [w.decode() for w in nuc_words]
        except:
            nuc_words = str(nuc_words[0])
        return {"greedy_sampling" : greedy_words,
               "nucleus_sampling" : nuc_words}
    
    def train_step(self, dataset): 
        image, text = dataset
        with tf.GradientTape() as tape: 
            vocab_p, original_tokens, text_mask = self(image, text)
            
            loss = self.compute_softmax_loss(y_true = original_tokens, y_pred = vocab_p)
            perplexity = self.compute_perplexity(y_true = original_tokens, y_pred = vocab_p)
            accuracy = self.compute_accuracy(y_true = original_tokens, y_pred = vocab_p, sample_weight = text_mask)
            loss = ops.mean(loss)
            
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        
        return {"VocabSoftmaxLoss" : loss,
               "Perplexity" : perplexity,
               "TokenAccuracy" : accuracy}

In [7]:
def wandb_config():
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    try:
        secret_value_0 = user_secrets.get_secret("__gcloud_sdk_auth__")
        secret_value_1 = user_secrets.get_secret("huggingface_key")
        secret_value_2 = user_secrets.get_secret("wandb_key")
        !wandb login $secret_value_2
    except:
        secret_value_0 = user_secrets.get_secret("huggingface_key")
        secret_value_1 = user_secrets.get_secret("wandb_key")
        !wandb login $secret_value_1

In [8]:
class TrainingViz(keras.callbacks.Callback):
    def __init__(self, run):
        super().__init__()
        self.run = run
    def on_epoch_end(self, epoch, logs=None):
        try:
            origin = ["Original Image"]
            col = ["Original image", "Ground truth caption", "Greedy caption", "Nucleus caption"]
            visualize_data = []
            for idx, original_image in tqdm(enumerate(val_images)):
                origin_img = [wandb.Image(original_image)]
                gt_caption = [val_texts[idx].numpy()]
                outputs = self.model.infer(val_images[idx])
                greedy = [outputs["greedy_sampling"]]
                nuc = [outputs["nucleus_sampling"]]
                tmp = origin_img + gt_caption + greedy + nuc
                visualize_data.append(tmp)
                del tmp, origin_img
            tbl = wandb.Table(columns = col, data = visualize_data)
            wandb.log({f"Epoch{epoch+1}_result": tbl})
            del tbl
            tf.keras.backend.clear_session()
        except Exception as e: 
            print('visualization error', e)
        
    def on_train_batch_end(self, batch, logs=None):
        if (batch % 1500 == 0) and (batch != 0):
            if True:
                origin = ["Original Image"]
                col = ["Original image", "Ground truth caption", "Greedy caption", "Nucleus caption"]
                visualize_data = []
                len_ = 3 if batch == 0 else batch_size
                for idx, original_image in tqdm(enumerate(val_images[:len_])):
                    origin_img = [wandb.Image(original_image)]
                    gt_caption = [val_texts[idx].numpy()]
                    outputs = self.model.infer(val_images[idx])
                    greedy = [outputs["greedy_sampling"]]
                    nuc = [outputs["nucleus_sampling"]]
                    tmp = origin_img + gt_caption + greedy + nuc
                    visualize_data.append(tmp)
                    del tmp, origin_img
                tbl = wandb.Table(columns = col, data = visualize_data)
                wandb.log({f"batch{batch+1}_result": tbl})
                del tbl
                tf.keras.backend.clear_session()
            #except Exception as e: 
            #    print('visualization error', e)
        else:
            pass

In [9]:
def run_exp(image_encoder, notes = None, exp_name = None, epochs = 1,
           image_encoder_trainable = True):
    image_encoder.trainable = image_encoder_trainable
    model = MedicalCaptioner(image_encoder, preprocessor, text_only_decoder)
    model.compile(optimizer = keras.optimizers.Adam(learning_rate = 1e-4),
             jit_compile = False)
    configs = model.get_config()
    configs["finetune"] = "full fine tuning" if image_encoder_trainable else 'frozen, zeroshot'
    configs["batch size"] = batch_size
    configs["caption sequence length"] = seq_len
    try:
        wandb.finish()
    except:
        pass
    print(configs)
    
    if exp_name is None:
        encname = configs["Image encoder name"]
        tune_ = configs["finetune"]
        exp_name = f"{encname}_captioning_{tune_}"
    pass_error = keras.callbacks.TerminateOnNaN()
    
    
    wandb_config()
    run = wandb.init(project="Eval_RadImageNet_caption", 
                         entity="gongbungkim", config = configs, notes = notes,
                        name = exp_name)
    wb_callback = wandb.keras.WandbMetricsLogger(log_freq = 100)
    vizcallback = TrainingViz(run)
    callbacks = [pass_error, wb_callback, vizcallback]
    hist = model.fit(train_ds, steps_per_epoch = train_steps, epochs = epochs, verbose = 1, callbacks = callbacks)
    return hist, model

In [10]:
run_exp(keras.applications.EfficientNetV2B1(include_top = False, input_shape = [res,res,3]),
       epochs = 1)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/efficientnet_v2/efficientnetv2-b1_notop.h5
[1m28456008/28456008[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
{'Image encoder name': 'efficientnetv2-b1', 'Text encoder name': 'gpt2_backbone', 'cross attention heads': 8, 'intermediate dims': 192, 'cross attention depth': 4, 'finetune': 'full fine tuning', 'batch size': 16, 'caption sequence length': 64}
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Currently logged in as: [33mgongbungkim[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.17.4
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240717_013237-v5i6vvdq[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mefficientnetv2-b1_captioning_full fine tuning[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/gongbungkim/Eval_RadImageNet_caption[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/gongbungkim/Eval_RadImageNet_caption/runs/v5i6vvdq[0m
[34m[1mwandb[0m: [32m[41mERROR[0m Unable to log learning rate.


[1m1500/4381[0m [32m━━━━━━[0m[37m━━━━━━━━━━━━━━[0m [1m29:11[0m 608ms/step - Perplexity: 1079.0686 - TokenAccuracy: 0.4591 - VocabSoftmaxLoss: 1.1109

0it [00:00, ?it/s]

I0000 00:00:1721181003.991465      24 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m3000/4381[0m [32m━━━━━━━━━━━━━[0m[37m━━━━━━━[0m [1m26:31[0m 1s/step - Perplexity: 542.1666 - TokenAccuracy: 0.6289 - VocabSoftmaxLoss: 0.6232

0it [00:00, ?it/s]

wandb: Network error (ReadTimeout), entering retry loop.


[1m4381/4381[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - Perplexity: 372.1570 - TokenAccuracy: 0.7059 - VocabSoftmaxLoss: 0.4468

0it [00:00, ?it/s]

[1m4381/4381[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7623s[0m 2s/step - Perplexity: 372.0726 - TokenAccuracy: 0.7059 - VocabSoftmaxLoss: 0.4467


(<keras.src.callbacks.history.History at 0x7f1f75b96cb0>,
 <MedicalCaptioner name=medical_captioner, built=True>)