In [1]:
import pandas as pd
import numpy as np
import random
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
seed = 2024

import warnings
warnings.filterwarnings("ignore")

# ML tools 

import tensorflow as tf
import keras #; keras.config.set_dtype_policy("mixed_float16")
import keras_nlp
import keras_cv
from keras import ops

keras.utils.set_random_seed(seed)

import cv2
import tensorflow_io as tfio
from kaggle_datasets import KaggleDatasets
import tensorflow_datasets as tfds

from keras import Input, Model, layers
from keras.models import load_model
from keras.layers import Layer
from keras.layers import Conv2D, DepthwiseConv2D, Dense, Activation, BatchNormalization, LayerNormalization, MultiHeadAttention, Embedding, Subtract, Add, Multiply, GlobalAveragePooling2D, GlobalAveragePooling1D, LayerNormalization
from keras.preprocessing.image import load_img, img_to_array
from keras.applications import *
import os, sys
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from sklearn.model_selection import train_test_split
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
from tqdm.notebook import tqdm
print(f"Requirements loaded, keras : v{keras.__version__}, Tensorflow : v{tf.__version__}")

2024-07-15 07:56:39.179151: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-15 07:56:39.179277: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-15 07:56:39.326230: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Requirements loaded, keras : v3.4.1, Tensorflow : v2.15.0


# Setting hyperparameters

In [2]:
# hyperparameters about images
res = 384
batch_size = 16
# hyperparameters about text
seq_len = 128
# hyperparameters about cross attentive decoder
att_depth = 4
att_heads = 8
att_dims = att_heads * 32
use_bias = False

train_cases = 70108
val_cases = 9972
train_steps = train_cases//batch_size
val_steps = val_cases//batch_size

train_dataset_dir = '/kaggle/input/roco-v2-tfrecord-dataset/Tot70108cases_RoCoV2_radiology_train_GZIP.tfrecord'
val_dataset_dir = '/kaggle/input/roco-v2-tfrecord-dataset/Tot9972cases_RoCoV2_radiology_test_GZIP.tfrecord'

# Parsing Dataset

In [3]:
def _parse_tfrecord(c, res = res):
    def parse_tfrecord(tfrecord):
        features = {'image': tf.io.FixedLenFeature([], tf.string),
                    'report': tf.io.FixedLenFeature([], tf.string),
                    }
        x = tf.io.parse_single_example(tfrecord, features)
        image_train = tf.image.decode_jpeg(x['image'], channels=c)
        image_train = _transform_images(res = res)(image_train)
        report = tf.cast(x["report"], tf.string)
        return image_train, report
    
    return parse_tfrecord


def _transform_images(res = res):
    def transform_images(x_train):
        x_train = tf.image.resize_with_pad(x_train, res, res, antialias = True)
        x_train = tf.cast(x_train, tf.uint8)
        return x_train
    return transform_images

def load_tfrecord_dataset(tfrecord_name, res = res, batch_size = batch_size, shuffle=True, buffer_size=10240, grayscale = False):
    """load dataset from tfrecord"""
    c = 1 if grayscale else 3
    raw_dataset = tf.data.TFRecordDataset(tfrecord_name, compression_type = "GZIP")
    raw_dataset = raw_dataset.repeat()
    if shuffle:
        raw_dataset = raw_dataset.shuffle(buffer_size=buffer_size)
    dataset = raw_dataset.map(
        _parse_tfrecord(c = c, res = res),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

train_ds = load_tfrecord_dataset(train_dataset_dir)
val_ds = load_tfrecord_dataset(val_dataset_dir) #image, report 2 outputs

for a, b in val_ds.take(1):
    val_images = a
    val_texts = b

# Preparing Natural language decoder (pretrained on plain texts)

In [4]:
gpt_freeset = "gpt2_base_en"
text_only_decoder = keras_nlp.models.Backbone.from_preset(
    gpt_freeset,
    trainable=False,
)
preprocessor = keras_nlp.models.GPT2Preprocessor.from_preset(
    gpt_freeset,
    sequence_length=seq_len,
    add_start_token = False
)
preprocessor.trainable = False
text_only_decoder.enable_lora(8)
text_only_decoder.summary()

Attaching 'model.safetensors' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'model.safetensors.index.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'model.safetensors' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'model.safetensors.index.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Ka

In [5]:
#according to summarized GPT2 model structure:
text_embed_dims = 768
n_vocab = preprocessor.tokenizer.vocabulary_size()
pad_token = preprocessor.tokenizer.pad_token_id
start_packer = keras_nlp.layers.StartEndPacker(
    sequence_length=seq_len,
    start_value=None,
)
prompt_words_ = preprocessor.tokenizer(["This image shows"]) ; start_word_idx = len(prompt_words_[0]) 
prompt = start_packer(prompt_words_)

# Modelling
- building a single model that functions as a testbed;
    - image encoder, text encoder(GPT2), cross-attentive attention layer로 구성
    - 모델은 image와 raw text를 input으로 받고
        - image -> encoder -> [CLS_token, encoded_patches, attention_weight] 3개의 output return
        - raw text -> GPT2 -> encoded_text & get mask
    - cross_attended_text = TransformerDecoder(query = raw_text, key = encoded_patches, value = encoded_patches, mask = mask) -> Dense -> Perplexity, accuracy 측정
    - 위 과정과 cls_token 및 pooled encoded_text의 batchwise cross-correlation 구해서 contrastive accuracy 구하기

In [21]:
class MedicalCaptioner(keras.Model): 
    def __init__(self, image_encoder, 
                 preprocessor, text_encoder,
                 att_heads = att_heads, att_dims = att_dims, att_depth = att_depth,
                 **kwargs):
        super().__init__(**kwargs)
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder ; 
        self.text_preprocessor = preprocessor ; self.text_preprocessor.trainable = False
        
        self.mlp_text = keras.layers.Dense(units = n_vocab, activation = "softmax", name = "VocabClassifier")
        self.att_heads = att_heads
        self.att_dims = att_dims
        self.att_depth = att_depth
        self.cross_attention_layers = [keras_nlp.layers.TransformerDecoder(att_dims, att_heads, name = f"CrossMHADecoder{i+1}",
                                                                          dropout = 0.2, activation = "gelu") for i in range(att_depth)]
        self.compute_perplexity = keras_nlp.metrics.Perplexity(mask_token_id=pad_token)
        self.compute_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.compute_softmax_loss = keras.losses.SparseCategoricalCrossentropy(ignore_class = pad_token, reduction = None)
    def get_config(self):
        return {"Image encoder name": self.image_encoder.name,
               "Text encoder name" : self.text_encoder.name,
               "cross attention heads" : att_heads, "intermediate dims" : self.att_dims, "cross attention depth" : self.att_depth
               }
    def call(self, image, text):
        if len(self.image_encoder.outputs) == 2 :
            image_token, encoded_patches = self.image_encoder(image)
        elif len(self.image_encoder.outputs) == 3 :
            image_token, encoded_patches, image_attention_weights = self.image_encoder(image)
        elif len(self.image_encoder.outputs) == 1:
            encoded_patches = self.image_encoder(image)
        
        if len(ops.shape(encoded_patches)) == 4:
            _, w, h, dims = ops.shape(encoded_patches)
            encoded_patches = ops.reshape(encoded_patches, [-1, w*h, dims])

        preprocessed_text = self.text_preprocessor(text) ; text_mask = preprocessed_text["padding_mask"]
        original_token = preprocessed_text["token_ids"]
        encoded_text = self.text_encoder(preprocessed_text)
        
        # Cross-Attention
        for idx, decoder in enumerate(self.cross_attention_layers):
            encoded_text = decoder(decoder_sequence = encoded_text,
                                  encoder_sequence = encoded_patches,
                                  decoder_padding_mask = text_mask)
        vocab_p = self.mlp_text(encoded_text)
        
        return vocab_p, original_token, text_mask
    
    def get_next_proba_fn(self, image, prompt):
        
        def next(prompt, cache, index):
            prompt = self.text_preprocessor.tokenizer.detokenize(prompt)
            vocab_proba = self(image, prompt)[0]
            logits = vocab_proba[:, index - 1, :]
            # Ignore hidden states for now; only needed for contrastive search.
            hidden_states = None
            return logits, hidden_states, cache
        return next
    def infer(self, image):
        if len(ops.shape(image)) == 3:
            image = image[tf.newaxis, ...]
        next_fn = self.get_next_proba_fn(image = image, prompt = prompt)
        
        greedy_sampler = keras_nlp.samplers.GreedySampler()
        nuc_sampler = keras_nlp.samplers.TopPSampler(p=0.5, k = 10)
        
        greedy_tokens = greedy_sampler(next=next_fn,
                                prompt=prompt,
                                index=start_word_idx)
        
        nuc_tokens = nuc_sampler(next=next_fn,
                                prompt=prompt,
                                index=start_word_idx)
        
        greedy_words, nuc_words = self.text_preprocessor.tokenizer.detokenize(greedy_tokens), self.text_preprocessor.tokenizer.detokenize(nuc_tokens)
        greedy_words, nuc_words = greedy_words.numpy(), nuc_words.numpy()
        greedy_words = [w.decode() for w in greedy_words]
        nuc_words = [w.decode() for w in nuc_words]
        return {"greedy_sampling" : greedy_words,
               "nucleus_sampling" : nuc_words}
    
    def train_step(self, dataset): 
        image, text = dataset
        with tf.GradientTape() as tape: 
            vocab_p, original_tokens, text_mask = self(image, text)
            
            loss = self.compute_softmax_loss(y_true = original_tokens, y_pred = vocab_p)
            perplexity = self.compute_perplexity(y_true = original_tokens, y_pred = vocab_p)
            accuracy = self.compute_accuracy(y_true = original_tokens, y_pred = vocab_p, sample_weight = text_mask)
            loss = ops.mean(loss)
            
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        
        return {"VocabSoftmaxLoss" : loss,
               "Perplexity" : perplexity,
               "TokenAccuracy" : accuracy}

In [22]:
#effnet = keras.applications.EfficientNetV2B0(include_top = False,
#                                            input_shape = [res,res,3])
#model = MedicalCaptioner(effnet, preprocessor, text_only_decoder)
#model.compile(optimizer = keras.optimizers.AdamW(learning_rate = 1e-4),
#             jit_compile = False)
#model.fit(train_ds, epochs = 1, steps_per_epoch = 100)
#model.infer(val_images[0]), val_texts[0]

[1m100/100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m168s[0m 728ms/step - Perplexity: 13411.0898 - TokenAccuracy: 0.0585 - VocabSoftmaxLoss: 2.1031


({'greedy_sampling': ['This image shows of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of of..........<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|> (<|endoftext|><|endoftext|> ( ( ( ( ( ( ( ( (.. (.. ( (... (.. ( (.. ( ( (.. ( (.. ( ( ( ( ( ( ( ( ( ('],
  'nucleus_sampling': ['This image shows in the) a,, (<|endoftext|>,arrow., a of (<|endoftext|> of.<|endoftext|>)arrow., (arrowarrow witharrow,,,- with. in<|endoftext|>arrow<|endoftext|> of of a<|endoftext|> of with<|endoftext|><|endoftext|>. in witharrow of--arrow (i. of in ( of (,arrow<|endoftext|>,arrow of<|endoftext|> with with) with in)<|endoftext|> of<|endoftext|>-)<|endoftext|> the) and<|endoftext|>- (.<|endoftext|> the of the., the- ofarrow,-),,.- aarrow,. of-<|endoftext|><|endoftext|>-,. a of in. in,,<|endoftext|> a']},
 <tf.Tensor: shape=(), dtype=string, numpy=b'Initial CT an