In [44]:
import pandas as pd
import numpy as np
import random
import pydicom

from sklearn.manifold import TSNE
import re
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.cm as cm
seed = 2024

import warnings
warnings.filterwarnings("ignore")

# ML tools 

import tensorflow as tf
import keras
from keras import ops #keras 3부터는 tf연산이 ops 하위 attribute로 옮겨감
import keras_cv
import keras_nlp

import cv2
keras.utils.set_random_seed(seed)
import tensorflow_io as tfio
from kaggle_datasets import KaggleDatasets
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import tensorflow_decision_forests as tfdf

print(f"Tensorflow version : {tf.__version__}")
try:
    print(f"Keras version : {keras.__version__}")
except:
    pass
from tensorflow.keras import Input, Model
from tensorflow.keras.models import load_model

from tensorflow.keras.layers import Conv2D, DepthwiseConv2D, Dense, Activation, BatchNormalization, LayerNormalization, MultiHeadAttention, Embedding, Subtract, Add, Multiply, GlobalAveragePooling2D, GlobalAveragePooling1D, LayerNormalization
from tensorflow.keras.utils import load_img, img_to_array
from tensorflow.keras.applications import *
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
from tqdm.notebook import tqdm

import wandb
def wandb_config():
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    try:
        secret_value_0 = user_secrets.get_secret("__gcloud_sdk_auth__")
        secret_value_1 = user_secrets.get_secret("huggingface_key")
        secret_value_2 = user_secrets.get_secret("wandb_key")
        !wandb login $secret_value_2
    except:
        secret_value_0 = user_secrets.get_secret("huggingface_key")
        secret_value_1 = user_secrets.get_secret("wandb_key")
        !wandb login $secret_value_1
    

res = int(1.5*256)
batch_size = 16
embed_dims = 768

def auto_select_accelerator():
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
        print("Running on TPU:", tpu.master())
    except ValueError:
        tpu = False
        strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines
    print(f"Running on {strategy.num_replicas_in_sync} replicas")
    
    return tpu, strategy

tpu, strategy = auto_select_accelerator()
batch_size = strategy.num_replicas_in_sync * batch_size
print('batch size', batch_size)

Tensorflow version : 2.15.0
Keras version : 3.4.1
Running on 1 replicas
batch size 16


In [2]:
channels = 3
all_labels = ['Aortic enlargement', 'Atelectasis', 'Calcification', 'Cardiomegaly', 'Consolidation', 'ILD', 'Infiltration', 'Lung Opacity', 'Nodule/Mass', 'Other lesion', 'Pleural effusion', 'Pleural thickening', 'Pneumothorax', 'Pulmonary fibrosis', 'No finding']

configs = {'batch size' : batch_size,
          "image size" : res,
          'image channel' : channels,
          f"labels({len(all_labels)})" : all_labels}

labelencoder = sklearn.preprocessing.MultiLabelBinarizer().fit([all_labels])

# Parsing dataset and curation
- Metadataset link : https://www.kaggle.com/datasets/financekim/vincxr-metadataset

In [3]:
mother_dir = "/kaggle/input/vinbigdata-chest-xray-abnormalities-detection/train"
df_train = pd.read_csv("/kaggle/input/vinbigdata-chest-xray-abnormalities-detection/train.csv")
df_train["image_id"] = [os.path.join(mother_dir, ids+".dicom") for ids in df_train["image_id"]]

In [4]:
def get_labels(image_path):
    labels = np.unique(df_train.loc[df_train["image_id"] == image_path, "class_name"])
    labels = labelencoder.transform([labels])
    return labels[0]
img_paths = np.unique(df_train["image_id"].values)
df_train_new = pd.DataFrame({"image_id" : img_paths,
                            "class_id" : [get_labels(path) for path in tqdm(img_paths)]})

  0%|          | 0/15000 [00:00<?, ?it/s]

In [5]:
df_train, df_val = train_test_split(df_train_new, test_size = 20*batch_size,
                                   random_state = seed)
train_cases, test_cases = len(df_train), len(df_val)
train_steps, test_steps = train_cases//batch_size, test_cases//batch_size

#setting global image decoding functions
def get_image_tensor(filepath, res = res) :
    #filepath = str(filepath)
    image_bytes = tf.io.read_file(filepath)
    if tf.strings.split(filepath, sep = '.')[-1] == "dicom" or tf.strings.split(filepath, sep = '.')[-1] == "dcm":
        image = tfio.image.decode_dicom_image(image_bytes, dtype=tf.uint16)
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = image[0]
    else:
        image = tf.io.decode_png(image_bytes)
        image = tf.cast(image, tf.float32)
    image = tf.image.resize_with_pad(image, res, res)
    image = (image - tf.reduce_min(image)) / (tf.reduce_max(image) - tf.reduce_min(image))
    image = image * 255.0
    if channels == 3:
        try:
            image = tf.image.grayscale_to_rgb(image)
        except:
            pass
    return image

def load_data(image_paths, labels):
    image_paths = tf.convert_to_tensor(image_paths, dtype=tf.string)
    labels = tf.convert_to_tensor(labels, tf.int32)
    
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))
    
    def _parse_function(image_path, label):
        image = get_image_tensor(image_path)
        image = ops.cast(image, 'uint8')
        return image, label
    
    dataset = dataset.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE).ignore_errors()
    return dataset

train_ds = load_data(df_train["image_id"].values, np.stack(df_train["class_id"].values))
val_ds = load_data(df_val["image_id"].values, np.stack(df_val["class_id"].values))
train_ds = train_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE).repeat().cache()
val_ds = val_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE).repeat().cache()

for img, lab in val_ds.take(1):
    sample_images = img
    sample_labels = lab
    print(ops.shape(sample_images))
    print(lab)

W: invalid value for 'BitsAllocated' (16), > 8 for OB encoded uncompressed 'PixelData'


(16, 384, 384, 3)
tf.Tensor(
[[0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [1 0 0 1 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [1 0 0 1 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 1 1 1 0 0 0 1 1 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 1 0 0 0 0 0 0]], shape=(16, 15), dtype=int32)


# Modelling and Report callback

In [9]:
def get_classifier_model(base_model):
    base_name = base_model.name
    inputs = Input([res,res,channels], name = "VinCXR_Images")
    outputs = base_model(inputs)
    if len(base_model.outputs) == 3:
        cls_token, encoded_patches, weights = outputs
        
    elif len(base_model.outputs) == 2:
        cls_token, encoded_patches = outputs
        
    else:
        encoded_patches = outputs
        if len(ops.shape(encoded_patches)) == 4:
            cls_token = keras.layers.GlobalAveragePooling2D(name = "CLS_bottleneck")(encoded_patches)
        elif len(ops.shape(encoded_patches)) == 3:
            cls_token = keras.layers.GlobalAveragePooling1D(name = "CLS_bottleneck")(encoded_patches)
    if ops.shape(encoded_patches) == 3:
        n_patches = ops.shape(encoded_patches)[1] ; dims = ops.shape(encoded_patches)[2]
        w_ = ops.sqrt(ops.cast(n_patches, "float32")
                     )
        encoded_patches = ops.reshape(encoded_patches, [-1, w_, w_, dims])
    encoded_patches = keras.layers.Identity(name = 'Patches')(encoded_patches)
    
    classifier_layer = Dense(units = len(all_labels), activation = "sigmoid", name = "DiseaseClassifier")(cls_token)
    model = Model(inputs, [classifier_layer, encoded_patches],
                 name = f"{base_name}_based_NaiveClassifier")
    metric_set = [[keras.metrics.AUC(curve = 'ROC', name = "AUROC", multi_label = True), 
                  keras.metrics.AUC(curve = 'PR', name = "AUPRC", multi_label = True), 
                  keras.metrics.Precision(name = "Precision"), 
                  keras.metrics.Recall(name = "Recall"), 
                  keras.metrics.F1Score(average = 'weighted', name = "F1score")], None]
    
    model.compile(optimizer = keras.optimizers.AdamW(learning_rate = 1e-4),
                 loss = ["binary_crossentropy", None],
                 metrics = metric_set,
                 jit_compile = False)
    return model

class ClassificationReportCallback(keras.callbacks.Callback):
    def __init__(self, test_dataset):
        super().__init__()
        self.test_dataset = test_dataset

    def on_epoch_end(self, epoch, logs=None):
        y_true = []
        y_pred = []

        for idx, batch in tqdm(enumerate(self.test_dataset), total = test_steps +1):
            X_batch, y_batch = batch
            y_batch_pred = (self.model.predict(X_batch, verbose = 0)[0] > 0.5).astype(int)
            y_true.extend(y_batch.numpy())
            y_pred.extend(y_batch_pred)
            if idx > test_steps:
                break

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        
        report = classification_report(y_true, y_pred, output_dict=True, target_names=[f'class_{name}' for name in all_labels])
        report_df = pd.DataFrame(report).transpose()
        colnames = all_labels + ["MicroAVG", "MacroAVG", "WeightedAVG", "SamplesAVG"]
        report_df["ItemNames"] = colnames
        # wandb에 업로드
        wandb.log({'classification_report': wandb.Table(dataframe=report_df)})
    def on_epoch_begin(self, epoch, logs=None):
        self.on_epoch_end(epoch, logs)

In [72]:
def get_full_cam(model, image):
    model.trainable = True
    model_ = keras.Model(model.inputs,
                        [model.get_layer("Patches").output,
                        model.get_layer("DiseaseClassifier").output]
                        )
    if len(ops.shape(image)) == 3:
        image = image[tf.newaxis, ...]
    cam = []
    for idx in range(len(all_labels)):
        with tf.GradientTape() as tape:
            encoded_patches, cls_proba = model_(image)
            cls_proba = cls_proba[:, idx]

        cls_grads = tape.gradient(cls_proba, 
                                  encoded_patches)
        cls_grads = ops.mean(cls_grads, axis = [1,2])
        cls_grads = cls_grads[:, tf.newaxis, :] #batch, 1, dims

        cam_ = ops.einsum('bwhd, bcd -> bwhc', encoded_patches, cls_grads)
        cam_ = (cam_ - ops.min(cam_, axis = [1, 2], keepdims = True)) / (ops.max(cam_, axis = [1,2], keepdims = True) - ops.min(cam_, axis = [1,2], keepdims = True)
                                                                        )
        cam_ = ops.squeeze(cam_, axis = -1)
        # Rescale heatmap to a range 0-255
        cam_ = np.uint8(255 * cam_)

        # Use jet colormap to colorize heatmap
        jet = mpl.colormaps["jet"]

        # Use RGB values of the colormap
        jet_colors = jet(np.arange(256))[:, :3]
        cam_ = jet_colors[cam_]

        # Create an image with RGB colorized heatmap
        cam_ = tf.image.resize(cam_, (image.shape[2], image.shape[1]),
                              method = 'gaussian')
        
        cam.append(cam_)
        tf.keras.backend.clear_session()
        
    cam = ops.stack(cam, axis = 1)
    cam = ops.clip(cam, 0, 1)
    cam *= 255.0
    cam = ops.cast(cam, "uint8")
    try:
        image = tf.image.grayscale_to_rgb(image)
    except:
        pass
    
    superimposed = 0.6*ops.cast(image[:, tf.newaxis, :, :, :], "float32") + 0.4*ops.cast(cam, "float32")
    superimposed = ops.cast(superimposed, "uint8")
    return cam, superimposed

# Experiment

In [11]:
def run_exp(base_model, notes = None, exp_name = None, epochs = 100,
           ):
    try:
        wandb.finish()
    except:
        pass
    print(configs)
    model = get_classifier_model(base_model)
    if exp_name is None:
        encname = base_model.name
        exp_name = f"{encname}_CXRClassification"
    pass_error = keras.callbacks.TerminateOnNaN()
    
    
    wandb_config()
    run = wandb.init(project="Eval_RadImageNet_VinCXRClassif", 
                         entity="gongbungkim", config = configs, notes = notes,
                        name = exp_name)
    wb_callback = wandb.keras.WandbMetricsLogger(log_freq = 100)
    vizcallback = ClassificationReportCallback(val_ds)
    callbacks = [pass_error, wb_callback, vizcallback]
    hist = model.fit(train_ds, steps_per_epoch = train_steps, epochs = epochs, verbose = 1, callbacks = callbacks)
    return hist, model

In [13]:
run_exp(keras.applications.EfficientNetV2M(input_shape = [res,res,channels], include_top = False)
       )

{'batch size': 16, 'image size': 384, 'image channel': 3, 'labels(15)': ['Aortic enlargement', 'Atelectasis', 'Calcification', 'Cardiomegaly', 'Consolidation', 'ILD', 'Infiltration', 'Lung Opacity', 'Nodule/Mass', 'Other lesion', 'Pleural effusion', 'Pleural thickening', 'Pneumothorax', 'Pulmonary fibrosis', 'No finding']}
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


[34m[1mwandb[0m: Currently logged in as: [33mgongbungkim[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/21 [00:00<?, ?it/s]

W: invalid value for 'BitsAllocated' (16), > 8 for OB encoded uncompressed 'PixelData'
W: invalid value for 'BitsAllocated' (16), > 8 for OB encoded uncompressed 'PixelData'
W: invalid value for 'BitsAllocated' (16), > 8 for OB encoded uncompressed 'PixelData'
W: invalid value for 'BitsAllocated' (16), > 8 for OB encoded uncompressed 'PixelData'
W: invalid value for 'BitsAllocated' (16), > 8 for OB encoded uncompressed 'PixelData'
W: invalid value for 'BitsAllocated' (16), > 8 for OB encoded uncompressed 'PixelData'
W: invalid value for 'BitsAllocated' (16), > 8 for OB encoded uncompressed 'PixelData'
W: invalid value for 'BitsAllocated' (16), > 8 for OB encoded uncompressed 'PixelData'
W: invalid value for 'BitsAllocated' (16), > 8 for OB encoded uncompressed 'PixelData'
W: invalid value for 'BitsAllocated' (16), > 8 for OB encoded uncompressed 'PixelData'
W: invalid value for 'BitsAllocated' (16), > 8 for OB encoded uncompressed 'PixelData'
W: invalid value for 'BitsAllocated' (16), 

Epoch 1/100


KeyboardInterrupt: 