In [1]:
import os, sys
# TF_CPP_MIN_LOG_LEVEL is read when TensorFlow's C++ runtime is loaded, so it must be set
# before tensorflow (or keras / keras_nlp / keras_cv / tensorflow_io, which import it) is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import pandas as pd
import numpy as np
import random
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
seed = 2024

import warnings
warnings.filterwarnings("ignore")

# ML tools 

import tensorflow as tf
import keras #; keras.config.set_dtype_policy("mixed_float16")
import keras_nlp
import keras_cv
from keras import ops

# Seeds Python `random`, NumPy and the backend in one call.
keras.utils.set_random_seed(seed)

import cv2
import tensorflow_io as tfio
from kaggle_datasets import KaggleDatasets
import tensorflow_datasets as tfds

from keras import Input, Model, layers
from keras.models import load_model
from keras.layers import Layer
from keras.layers import Conv2D, DepthwiseConv2D, Dense, Activation, BatchNormalization, LayerNormalization, MultiHeadAttention, Embedding, Subtract, Add, Multiply, GlobalAveragePooling2D, GlobalAveragePooling1D
from keras.preprocessing.image import load_img, img_to_array
# NOTE(review): star import kept because later cells may rely on the application models it
# exposes; prefer importing the specific architectures that are actually used.
from keras.applications import *
from sklearn.model_selection import train_test_split
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
from tqdm.notebook import tqdm
print(f"Requirements loaded, keras : v{keras.__version__}, Tensorflow : v{tf.__version__}")

2024-07-14 16:34:47.189676: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-14 16:34:47.189801: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-14 16:34:47.356311: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Requirements loaded, keras : v3.4.1, Tensorflow : v2.15.0


# Setting hyperparameters

In [2]:
# --- image input ---
res = 384          # side length images are resized/padded to in the input pipeline
batch_size = 16

# --- text input ---
seq_len = 128      # sequence length handed to the GPT-2 preprocessor

# --- cross-attentive decoder ---
att_heads = 8
att_depth = 4
att_dims = 32 * att_heads   # scales with the number of heads (x32)
use_bias = False

# --- dataset sizes; floor division drops the final partial batch ---
train_cases, val_cases = 70_108, 9_972
train_steps, val_steps = train_cases // batch_size, val_cases // batch_size

train_dataset_dir = '/kaggle/input/roco-v2-tfrecord-dataset/Tot70108cases_RoCoV2_radiology_train_GZIP.tfrecord'
val_dataset_dir = '/kaggle/input/roco-v2-tfrecord-dataset/Tot9972cases_RoCoV2_radiology_test_GZIP.tfrecord'

# Parsing Dataset

In [3]:
def _parse_tfrecord(res = res):
    """Build a ``tf.data`` map function that decodes one serialized example.

    Each record is expected to hold a JPEG-encoded ``image`` bytes feature and a
    ``report`` string feature. The returned function maps a serialized record to
    ``(image, report)``, where the image is decoded to 3 channels and passed through
    ``_transform_images(res=res)`` (resize-with-pad to ``res`` x ``res``, cast to uint8).
    """
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'report': tf.io.FixedLenFeature([], tf.string),
    }

    def parse_tfrecord(tfrecord):
        example = tf.io.parse_single_example(tfrecord, feature_spec)
        decoded = tf.image.decode_jpeg(example['image'], channels=3)
        resized = _transform_images(res = res)(decoded)
        return resized, tf.cast(example["report"], tf.string)

    return parse_tfrecord


def _transform_images(res = res):
    """Return a function that resizes an image with padding to ``res`` x ``res`` and casts it to uint8.

    ``tf.image.resize_with_pad`` (with antialiasing) produces a float image; it is
    converted back with ``tf.cast``, which truncates rather than rounds.
    """
    target = res

    def transform_images(x_train):
        padded = tf.image.resize_with_pad(x_train, target, target, antialias = True)
        return tf.cast(padded, tf.uint8)

    return transform_images

def load_tfrecord_dataset(tfrecord_name, res = res, batch_size = batch_size, shuffle=True, buffer_size=10240):
    """Load an infinite, batched ``(image, report)`` dataset from GZIP-compressed TFRecords.

    Args:
        tfrecord_name: TFRecord file name(s) passed to ``tf.data.TFRecordDataset``.
        res: side length images are resized/padded to; forwarded to the record parser.
        batch_size: number of examples per batch.
        shuffle: whether to shuffle records (pass ``False`` for validation/test data).
        buffer_size: size of the shuffle buffer.

    Returns:
        A ``tf.data.Dataset`` that repeats indefinitely, so callers must bound each
        epoch with ``steps_per_epoch`` / ``validation_steps``.
    """
    raw_dataset = tf.data.TFRecordDataset(tfrecord_name, compression_type = "GZIP")
    if shuffle:
        # Shuffle *before* repeat: each pass then emits every record exactly once (reshuffled
        # per pass) instead of mixing records from consecutive passes.
        raw_dataset = raw_dataset.shuffle(buffer_size=buffer_size)
    raw_dataset = raw_dataset.repeat()
    dataset = raw_dataset.map(
        # Forward `res` so the caller's resolution is actually used by the parser.
        _parse_tfrecord(res = res),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset

train_ds_ = load_tfrecord_dataset(train_dataset_dir)
# Validation is read in file order: this keeps every evaluation pass deterministic, and because
# the pipeline repeats indefinitely, a shuffle applied after `repeat()` could otherwise mix
# records from consecutive passes (some reports seen twice, others skipped).
val_ds_ = load_tfrecord_dataset(val_dataset_dir, shuffle=False) #image, report 2 outputs

# Preparing the natural-language decoder (GPT-2, pretrained on plain text)

In [4]:
gpt_freeset = "gpt2_base_en"  # KerasNLP preset shared by the backbone and its preprocessor

# Pretrained GPT-2 backbone, loaded with trainable=False.
text_only_decoder = keras_nlp.models.Backbone.from_preset(gpt_freeset, trainable=False)

# Matching GPT-2 preprocessor with sequence length `seq_len`.
preprocessor = keras_nlp.models.GPT2Preprocessor.from_preset(gpt_freeset, sequence_length=seq_len)
preprocessor.trainable = False

# Attach rank-8 LoRA adapters to the backbone.
# NOTE(review): confirm the adapters end up trainable given the backbone was loaded with
# trainable=False (check the trainable-parameter count in the summary below).
text_only_decoder.enable_lora(8)
text_only_decoder.summary()

Attaching 'model.safetensors' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'model.safetensors.index.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'config.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'model.weights.h5' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'model.safetensors' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'model.safetensors.index.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Kaggle notebook...
Attaching 'metadata.json' from model 'keras/gpt2/keras/gpt2_base_en/2' to your Ka

In [5]:
# Width of the GPT-2 hidden states. Read from the backbone so it stays correct if the preset
# changes; falls back to 768 (the value from the gpt2_base_en model summary) if unavailable.
text_embed_dims = getattr(text_only_decoder, "hidden_dim", 768)
# Tokenizer vocabulary size; used as the output width of MedicalCaptioner.mlp_text.
n_vocab = preprocessor.tokenizer.vocabulary_size()

# Modelling
- Building a single model that serves as a testbed:
    - composed of an image encoder, a text encoder (GPT-2), and cross-attention layers
    - the model takes an image and raw text as input:
        - image -> encoder -> returns 3 outputs: [CLS_token, encoded_patches, attention_weight]
        - raw text -> GPT2 -> encoded_text, plus its mask
    - cross_attended_text = TransformerDecoder(query = raw_text, key = encoded_patches, value = encoded_patches, mask = mask) -> Dense -> measure perplexity and accuracy
    - alongside the above, compute the batch-wise cross-correlation between cls_token and the pooled encoded_text to obtain a contrastive accuracy

In [6]:
class MedicalCaptioner(keras.Model): # original CLIP + CLIP surgery
    """Image-captioning testbed combining an image encoder with a GPT-2 text encoder.

    Work in progress. ``__init__`` builds the stack of cross-attention
    ``TransformerDecoder`` layers described in the "Modelling" notes. However,
    ``train_step`` / ``test_step`` / ``get_full_model`` still implement the contrastive
    (CLIP-style) objective this class was derived from. Those methods use attributes that
    ``__init__`` does NOT create and that must be attached before ``fit`` / ``evaluate`` /
    ``get_full_model`` are called: ``pe_fn``, ``mlp_image``, ``image_pooler``,
    ``text_pooler`` and ``get_clip_loss``.

    Args:
        image_encoder: Keras model applied to the image batch. The step methods unpack its
            output shape into four values, i.e. they expect a rank-4 feature map.
        preprocessor: text preprocessor (e.g. the GPT-2 preprocessor); kept as ``self.preprocessor``.
        text_encoder: text backbone applied to the ``text`` part of each batch.
        att_heads: attention heads per cross-attention ``TransformerDecoder``.
        att_dims: intermediate (feed-forward) width of each ``TransformerDecoder``.
        att_depth: number of stacked ``TransformerDecoder`` layers.
        embed_dims: embedding width stored as ``self.embed_dims``; defaults to
            ``text_embed_dims`` (the GPT-2 hidden width). NOTE(review): confirm this is
            the intended width.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, image_encoder, 
                 preprocessor, text_encoder,
                 att_heads = att_heads, att_dims = att_dims, att_depth = att_depth,
                 embed_dims = text_embed_dims,
                 **kwargs):
        super().__init__(**kwargs)
        self.image_encoder = image_encoder
        self.preprocessor = preprocessor
        self.text_encoder = text_encoder
        self.embed_dims = embed_dims
        # Projection of pooled text features to vocabulary size.
        self.mlp_text = keras.layers.Dense(units = n_vocab)
        self.att_heads = att_heads
        self.att_dims = att_dims
        self.att_depth = att_depth
        # One cross-attention decoder per depth level (iterate over range(att_depth), not the int).
        self.cross_attention_layers = [
            keras_nlp.layers.TransformerDecoder(att_dims, att_heads, name = f"CrossMHADecoder{i+1}",
                                                dropout = 0.2, activation = "gelu")
            for i in range(att_depth)
        ]
        # Running mean of the per-batch loss reported by train_step / test_step.
        self.loss_tracker = keras.metrics.Mean(name = "loss")

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it at the start of each epoch / evaluation.
        return [self.loss_tracker]
        
    def get_config(self):
        """Return a descriptive summary of the model setup (keys are not ``from_config``-compatible)."""
        return {"Image encoder name": self.image_encoder.name,
               "Text encoder name" : self.text_encoder.name,
               "cross attention heads" : self.att_heads, "intermediate dims" : self.att_dims, "cross attention depth" : self.att_depth
               }

    def _clip_forward(self, image, text, training):
        """Shared forward pass of the contrastive objective used by train_step / test_step.

        Returns the pooled, projected ``(image_vector, text_vector)`` pair that is fed to
        ``self.get_clip_loss``.
        """
        # NOTE(review): the input pipeline yields raw report strings; the text encoder presumably
        # needs `self.preprocessor(text)` first — confirm before training.
        image_feature = self.image_encoder(image, training = training)
        text_feature = self.text_encoder(text, training = training)

        # Flatten the spatial grid into a sequence: (batch, w, h, dims) -> (batch, w*h, dims).
        _, w, h, dims = ops.shape(image_feature)
        batch_size = ops.shape(image_feature)[0]
        image_feature = ops.reshape(image_feature, [batch_size, w*h, dims])
        image_feature = self.pe_fn(image_feature)

        text_vector = keras.layers.GlobalAveragePooling1D()(text_feature)
        image_rank = len(ops.shape(image_feature))
        if image_rank == 3:
            image_vector = keras.layers.GlobalAveragePooling1D()(image_feature)
        elif image_rank == 4:
            image_vector = keras.layers.GlobalAveragePooling2D()(image_feature)
        else:
            raise ValueError(f"Expected a rank-3 or rank-4 image feature after pe_fn, got rank {image_rank}.")

        image_vector = self.mlp_image(image_vector, training = training)
        text_vector = self.mlp_text(text_vector, training = training)
        image_vector = self.image_pooler([image_vector, image_feature], training = training)[0]
        text_vector = self.text_pooler([text_vector, text_feature], training = training)[0]
        return image_vector, text_vector
        
    def train_step(self, dataset):
        """Run one optimisation step on an ``(image, text)`` batch; returns ``{loss_name: running mean}``."""
        # TODO(wip): migrate from the contrastive loss to the cross-attentive captioning objective.
        image, text = dataset
        with tf.GradientTape() as tape: 
            image_vector, text_vector = self._clip_forward(image, text, training = True)
            loss = self.get_clip_loss(image_vector, text_vector)
        
        # Only these sub-modules are updated; the cross-attention decoders are not yet in the list.
        encoder_weights = self.image_encoder.trainable_weights + self.text_encoder.trainable_weights
        mlp_weights = self.mlp_image.trainable_weights + self.mlp_text.trainable_weights
        pool_weights = self.image_pooler.trainable_weights + self.text_pooler.trainable_weights
        trainable_weights = encoder_weights + mlp_weights + pool_weights + self.pe_fn.trainable_weights
        
        grads = tape.gradient(loss, trainable_weights)
        self.optimizer.apply_gradients(zip(grads, trainable_weights))
        self.loss_tracker.update_state(loss)
        return {self.loss_tracker.name : self.loss_tracker.result()}
    
    def test_step(self, dataset):
        """Evaluate one ``(image, text)`` batch without updating weights; returns ``{loss_name: running mean}``."""
        image, text = dataset
        image_vector, text_vector = self._clip_forward(image, text, training = False)
        loss = self.get_clip_loss(image_vector, text_vector)
        self.loss_tracker.update_state(loss)
        return {self.loss_tracker.name : self.loss_tracker.result()}

    def call(self, dataset):
        # NOTE(review): delegates to test_step, so calling the model also updates loss_tracker
        # and returns the metrics dict rather than model outputs.
        return self.test_step(dataset)

    def get_full_model(self):
        """Build a functional ``keras.Model`` from the image encoder's inputs to the image pooler's output."""
        inputs = self.image_encoder.inputs
        feature = self.image_encoder.output
        if len(ops.shape(feature)) == 4:
            # Flatten (batch, w, h, dims) -> (batch, w*h, dims); -1 keeps the batch dimension symbolic.
            batch_size, w, h, dims = ops.shape(feature)
            batch_size = ops.shape(feature)[0]
            feature = ops.reshape(feature, [-1, w*h, dims])
        feature = self.pe_fn(feature)
        image_vector = keras.layers.GlobalAveragePooling1D()(feature)
        z_image = self.mlp_image(image_vector)
        outputs = self.image_pooler([z_image, feature])
        return keras.Model(inputs, outputs,
                          name = f"FullModel_{self.image_encoder.name}")