<a href="https://colab.research.google.com/github/ayyucedemirbas/RadiologyCLIP/blob/main/RadiologyCLIP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import random
import json
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.image import load_img, img_to_array

In [2]:
!wget https://openi.nlm.nih.gov/imgs/collections/NLMCXR_png.tgz

--2025-02-01 17:53:08--  https://openi.nlm.nih.gov/imgs/collections/NLMCXR_png.tgz
Resolving openi.nlm.nih.gov (openi.nlm.nih.gov)... 130.14.65.157, 2607:f220:41e:7065::157
Connecting to openi.nlm.nih.gov (openi.nlm.nih.gov)|130.14.65.157|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1360814128 (1.3G) [application/x-gzip]
Saving to: ‘NLMCXR_png.tgz’


2025-02-01 18:04:08 (1.97 MB/s) - ‘NLMCXR_png.tgz’ saved [1360814128/1360814128]



In [3]:
!wget https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz

--2025-02-01 18:04:27--  https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz
Resolving openi.nlm.nih.gov (openi.nlm.nih.gov)... 130.14.65.157, 2607:f220:41e:7065::157
Connecting to openi.nlm.nih.gov (openi.nlm.nih.gov)|130.14.65.157|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1112632 (1.1M) [application/x-gzip]
Saving to: ‘NLMCXR_reports.tgz’


2025-02-01 18:04:28 (1.69 MB/s) - ‘NLMCXR_reports.tgz’ saved [1112632/1112632]



In [4]:
!mkdir workdir

In [5]:
# Change into workdir so the archives in the next cell are extracted there.
%cd workdir

/content/workdir


In [None]:
!tar -xvzf /content/NLMCXR_png.tgz -C .
!tar -xvzf /content/NLMCXR_reports.tgz -C .

In [7]:
# Return to the parent directory; later cells use paths relative to it
# ("workdir/..." and "ecgen-radiology").
%cd ..

/content


In [8]:
# Move the extracted ecgen-radiology directory out of workdir so it sits next to it;
# create_dataset() walks "ecgen-radiology" for the report XML files.
# NOTE(review): this fails on a re-run once the directory has already been moved.
!mv workdir/ecgen-radiology .

In [9]:
# Sanity check: list the current directory to confirm the archives, workdir and
# ecgen-radiology are all in place.
!ls -al

total 1330472
drwxr-xr-x 1 root root       4096 Feb  1 18:05 .
drwxr-xr-x 1 root root       4096 Feb  1 17:48 ..
drwxr-xr-x 4 root root       4096 Jan 30 14:18 .config
drwxr-xr-x 2  929  212      94208 Feb  4  2016 ecgen-radiology
-rw-r--r-- 1 root root 1360814128 Sep  6  2016 NLMCXR_png.tgz
-rw-r--r-- 1 root root    1112632 Jun  6  2017 NLMCXR_reports.tgz
drwxr-xr-x 1 root root       4096 Jan 30 14:19 sample_data
drwxr-xr-x 2   48   48     352256 Feb  1 18:05 workdir


In [10]:
import xml.etree.ElementTree as ET

In [11]:
def find_valid_pairs(image_dir='workdir', report_dir='ecgen-radiology'):
    """Pair PNG images with XML reports that share the same base file name.

    Both directory trees are walked recursively. File names are compared
    case-insensitively and without their extension.

    Args:
        image_dir: Root directory searched for ``.png`` files.
        report_dir: Root directory searched for ``.xml`` files.

    Returns:
        A list of ``(image_path, report_path)`` tuples, in the order the images
        are encountered by ``os.walk``. If several reports share a base name,
        the one encountered last is used.
    """
    def _stem(filename):
        # Lower-cased file name without its extension: the matching key.
        return os.path.splitext(filename)[0].lower()

    reports_by_stem = {
        _stem(name): os.path.join(folder, name)
        for folder, _, names in os.walk(report_dir)
        for name in names
        if name.lower().endswith('.xml')
    }

    return [
        (os.path.join(folder, name), reports_by_stem[_stem(name)])
        for folder, _, names in os.walk(image_dir)
        for name in names
        if name.lower().endswith('.png') and _stem(name) in reports_by_stem
    ]


def parse_xml_report(xml_path):
    """Extract the FINDINGS and IMPRESSION sections of one XML report.

    The ``AbstractText`` elements are looked up without a namespace first. The
    ``pmc:``-prefixed lookup is kept only as a fallback: used on its own it
    matches nothing in un-namespaced reports, which made this function return
    ``None`` for them.

    Args:
        xml_path: Path of the XML report file.

    Returns:
        ``"FINDINGS: <findings>. IMPRESSION: <impression>"`` as a string, or
        ``None`` when both sections are missing/empty or the file cannot be
        read or parsed (the error is printed).
    """
    ns = {'pmc': 'http://www.ncbi.nlm.nih.gov/pmc/articles/PMC'}

    def _section_text(root, label):
        # Return the stripped text of the AbstractText element carrying this Label, or "".
        elem = root.find(f".//AbstractText[@Label='{label}']")
        if elem is None:
            elem = root.find(f".//pmc:AbstractText[@Label='{label}']", ns)
        if elem is None or not elem.text:
            return ""
        return elem.text.strip()

    try:
        root = ET.parse(xml_path).getroot()

        findings = _section_text(root, 'FINDINGS')
        impression = _section_text(root, 'IMPRESSION')

        if not findings and not impression:
            return None

        return f"FINDINGS: {findings}. IMPRESSION: {impression}"

    except Exception as e:
        # Best effort: report the problem and let the caller skip this file.
        print(f"Error parsing {xml_path}: {str(e)}")
        return None


In [12]:
import os
import xml.etree.ElementTree as ET
import pandas as pd

def parse_xml_report(xml_path, image_dir="workdir"):
    """Extract the report text and the referenced image paths from one XML report.

    Args:
        xml_path: Path of the XML report file.
        image_dir: Directory prefix used to build each image path
            (``"<image_dir>/<id>.png"``). Defaults to ``"workdir"``.

    Returns:
        A ``(report_text, image_paths)`` tuple. ``report_text`` has the form
        ``"FINDINGS: <findings>. IMPRESSION: <impression>"`` and
        ``image_paths`` has one entry per ``parentImage`` element that carries
        an ``id`` attribute. Returns ``(None, [])`` when both sections are
        missing/empty or the file cannot be read or parsed (the error is
        printed).
    """
    def _section_text(root, label):
        # Return the stripped text of the AbstractText element carrying this Label, or "".
        elem = root.find(f".//AbstractText[@Label='{label}']")
        if elem is None or not elem.text:
            return ""
        return elem.text.strip()

    try:
        root = ET.parse(xml_path).getroot()

        findings_text = _section_text(root, 'FINDINGS')
        impression_text = _section_text(root, 'IMPRESSION')

        if not findings_text and not impression_text:
            return None, []

        image_ids = [img.get('id') for img in root.findall('.//parentImage')]

        # Skip parentImage elements without an id; they would otherwise turn
        # into a bogus "<image_dir>/None.png" path.
        image_paths = [f"{image_dir}/{img_id}.png" for img_id in image_ids if img_id]

        return (
            f"FINDINGS: {findings_text}. IMPRESSION: {impression_text}",
            image_paths
        )

    except Exception as e:
        # Best effort: report the problem and let the caller skip this file.
        print(f"Error parsing {xml_path}: {str(e)}")
        return None, []

def create_dataset(report_dir="ecgen-radiology"):
    """Build the image/report table from the XML reports under ``report_dir``.

    Every ``.xml`` file found by walking ``report_dir`` is parsed with
    ``parse_xml_report``. One row is produced per referenced image that exists
    on disk; images that are referenced but missing are reported with a
    printed message.

    Args:
        report_dir: Root directory containing the XML reports. Defaults to
            ``"ecgen-radiology"``.

    Returns:
        A ``pandas.DataFrame`` with the columns ``image_path`` and ``report``.
        The columns are present even when no pair is found, so downstream
        column access does not raise ``KeyError`` on an empty result.
    """
    data = []

    for root_dir, _, files in os.walk(report_dir):
        for file in files:
            if not file.lower().endswith('.xml'):
                continue

            xml_path = os.path.join(root_dir, file)
            report_text, image_paths = parse_xml_report(xml_path)
            if not (report_text and image_paths):
                continue

            for img_path in image_paths:
                if os.path.exists(img_path):
                    data.append({
                        'image_path': img_path,
                        'report': report_text
                    })
                else:
                    print(f"Missing image: {img_path}")

    return pd.DataFrame(data, columns=['image_path', 'report'])

# Build the image/report table and print a short summary of what was found.
radiology_df = create_dataset()
print(f"Found {len(radiology_df)} valid image-report pairs")

if radiology_df.empty:
    # Nothing matched: list the usual places to look.
    print("\nNo valid pairs found. Check:")
    print("- XML files in ecgen-radiology directory")
    print("- Image files in workdir directory")
    print("- File naming consistency between XML and PNG files")
else:
    first_row = radiology_df.iloc[0]
    print("\nSample entry:")
    print(f"Image path: {first_row['image_path']}")
    print(f"Report: {first_row['report']}")

Found 7430 valid image-report pairs

Sample entry:
Image path: workdir/CXR1329_IM-0211-1001.png
Report: FINDINGS: Two nodules are noted in the right XXXX XXXX measuring 13 mm and one measuring 16 mm in diameter. The smaller one appears to be within the right upper lobe and the large XXXX appears to be within the left lower lobe. No focal consolidation and no other pulmonary nodules are identified. However, if a full evaluation for lung nodules is desired consider XXXX for further evaluation. No pleural effusions or pneumothoraces. Heart and mediastinum of normal size and contour.. IMPRESSION: At XXXX 2 right lung pulmonary nodules concerning for<BR>metastatic disease


In [13]:
# Word-level tokenizer for the report text.
# NOTE(review): oov_token is an empty string; a visible marker such as "<OOV>"
# may have been intended - confirm before changing, it is kept as-is here.
# The backslash in `filters` is now written as "\\": the previous "\]" was an
# invalid escape sequence that only happened to produce the same characters.
tokenizer = Tokenizer(
    num_words=5000,
    oov_token="",
    filters='!"#$%&()*+.,-/:;=?@[\\]^_`{|}~ '
)


tokenizer.fit_on_texts(radiology_df['report'])
# word_index holds every word seen during fitting, but texts_to_sequences only
# emits indices below num_words. Cap vocab_size accordingly so the text
# encoder's embedding table has no rows that can never be looked up.
vocab_size = min(tokenizer.num_words, len(tokenizer.word_index) + 1)
# Fixed sequence length used when padding the tokenized reports.
max_length = 100


# Spot check that key clinical terms made it into the vocabulary.
print([word for word in tokenizer.word_index if "pneumonia" in word or "effusion" in word][:10])

['effusion', 'effusions', 'pneumonia']


In [14]:
def preprocess_sample(sample):
    """Turn one ``{'image_path', 'report'}`` record into model inputs.

    Args:
        sample: Mapping with an ``'image_path'`` (image file to load) and a
            ``'report'`` (text to tokenize).

    Returns:
        ``(img, seq)``: the image loaded at 224x224 and divided by 255, and the
        report's token ids padded at the end to ``max_length``.
    """
    pixels = img_to_array(load_img(sample['image_path'], target_size=(224, 224)))
    pixels = pixels / 255.0

    # texts_to_sequences already returns a one-element batch; pad it as such.
    token_batch = tokenizer.texts_to_sequences([sample['report']])
    padded_batch = tf.keras.preprocessing.sequence.pad_sequences(
        token_batch, maxlen=max_length, padding='post'
    )
    return pixels, padded_batch[0]

def dataset_generator(df):
    """Yield one ``{'image_path': ..., 'report': ...}`` dict per row of ``df``."""
    wanted = ('image_path', 'report')
    yield from ({col: row[col] for col in wanted} for _, row in df.iterrows())

In [15]:
# Each generated element is a dict of two scalar string tensors.
output_signature = {
    field: tf.TensorSpec(shape=(), dtype=tf.string)
    for field in ('image_path', 'report')
}

# Stream the rows of radiology_df; the lambda creates a fresh generator each
# time the dataset is iterated.
ds = tf.data.Dataset.from_generator(
    lambda: dataset_generator(radiology_df),
    output_signature=output_signature,
)

In [16]:
def map_func(image_path, report):
    """Decode the two string tensors and run ``preprocess_sample`` on them.

    Called through ``tf.py_function``, so both arguments are eager tensors
    exposing ``.numpy()``.

    Returns:
        The ``(img, seq)`` pair produced by ``preprocess_sample``.
    """
    sample = {
        'image_path': image_path.numpy().decode('utf-8'),
        'report': report.numpy().decode('utf-8'),
    }
    img, seq = preprocess_sample(sample)
    return img, seq

def tf_map_func(image_path, report):
    """Wrap ``map_func`` in ``tf.py_function`` and restore static shapes.

    Returns:
        ``(img, seq)``: a float32 tensor with static shape ``(224, 224, 3)``
        and an int32 tensor with static shape ``(max_length,)``.
    """
    image_tensor, token_tensor = tf.py_function(
        func=map_func,
        inp=[image_path, report],
        Tout=[tf.float32, tf.int32]
    )
    # py_function outputs carry no static shape; set it so downstream
    # batching and the models see fully defined shapes.
    static_shapes = (
        (image_tensor, (224, 224, 3)),
        (token_tensor, (max_length,)),
    )
    for tensor, shape in static_shapes:
        tensor.set_shape(shape)
    return image_tensor, token_tensor

In [17]:
# Number of (image, report) pairs per training batch.
batch_size = 16
# Passed to map() and prefetch() so tf.data picks the parallelism / buffer size itself.
AUTOTUNE = tf.data.AUTOTUNE

In [18]:
# Training input pipeline: shuffle the (path, report) records over the whole
# table, load and tokenize in parallel, then batch and prefetch.
# NOTE: this reassigns `ds`, so run the cell once per fresh `ds`.
ds = (
    ds.shuffle(buffer_size=len(radiology_df))
    .map(
        lambda sample: tf_map_func(sample['image_path'], sample['report']),
        num_parallel_calls=AUTOTUNE,
    )
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

In [19]:
# Hyperparameters read as globals by the encoder builders and inference helpers.
embedding_dim = 256  # size of the embedding each encoder outputs
num_transformer_blocks = 4  # Transformer blocks stacked in each encoder
num_heads = 8  # attention heads per MultiHeadAttention layer
ff_dim = 512  # width of the hidden Dense layer in each block's feed-forward part
image_size = 224  # height and width of the image encoder's input

In [20]:
class _AddPositionEmbedding(layers.Layer):
    """Add a learned positional embedding to a ``(batch, seq_len, dim)`` tensor.

    The embedding table lives inside this layer, so it becomes a trainable
    weight of whichever model the layer is used in.
    """

    def __init__(self, seq_len, dim, **kwargs):
        super().__init__(**kwargs)
        self.seq_len = seq_len
        self.dim = dim
        # Pass an incoming mask (e.g. from Embedding(mask_zero=True)) through
        # unchanged instead of dropping it at this layer.
        self.supports_masking = True
        self.position_embedding = layers.Embedding(input_dim=seq_len, output_dim=dim)

    def call(self, inputs):
        positions = tf.range(start=0, limit=self.seq_len, delta=1)
        # (seq_len, dim) broadcasts over the batch dimension.
        return inputs + self.position_embedding(positions)

    def get_config(self):
        config = super().get_config()
        config.update({"seq_len": self.seq_len, "dim": self.dim})
        return config


def _transformer_block(x, dim):
    """Apply one pre-norm Transformer encoder block to ``x`` (last axis ``dim``)."""
    normed = layers.LayerNormalization(epsilon=1e-6)(x)
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=dim)(normed, normed)
    # Residual connection around attention: add back the block INPUT, not its
    # layer-normalized copy, so the skip path carries the signal unchanged.
    x = layers.Add()([x, attention_output])
    normed = layers.LayerNormalization(epsilon=1e-6)(x)
    ffn_output = layers.Dense(ff_dim, activation="relu")(normed)
    ffn_output = layers.Dense(dim)(ffn_output)
    return layers.Add()([x, ffn_output])


def create_vit_encoder(image_size):
    """Build a ViT-style encoder mapping an image to a sequence of patch features.

    The image is cut into non-overlapping 16x16 patches by a strided Conv2D,
    each projected to 768 features. A learned positional embedding is added
    and ``num_transformer_blocks`` Transformer blocks are applied.

    Args:
        image_size: Height and width of the (square, 3-channel) input image.

    Returns:
        A Keras model named ``"vit_encoder"`` with output shape
        ``(batch, (image_size // 16) ** 2, 768)``.
    """
    inputs = layers.Input(shape=(image_size, image_size, 3))
    patch_size = 16
    num_patches = (image_size // patch_size) ** 2
    projection_dim = 768

    # Strided convolution = split into patches + linear projection in one step.
    patches = layers.Conv2D(
        filters=projection_dim,
        kernel_size=patch_size,
        strides=patch_size,
        padding="valid"
    )(inputs)
    patches = layers.Reshape((num_patches, projection_dim))(patches)

    encoded_patches = _AddPositionEmbedding(num_patches, projection_dim)(patches)

    for _ in range(num_transformer_blocks):
        encoded_patches = _transformer_block(encoded_patches, projection_dim)

    model = models.Model(inputs, encoded_patches, name="vit_encoder")
    return model

def create_image_encoder(image_size, embedding_dim):
    """Build the image tower: ViT encoder, mean pooling, Dense projection.

    Args:
        image_size: Height and width of the (square, 3-channel) input image.
        embedding_dim: Size of the output embedding.

    Returns:
        A Keras model named ``"image_encoder"`` with output shape
        ``(batch, embedding_dim)``.
    """
    inputs = layers.Input(shape=(image_size, image_size, 3))
    vit = create_vit_encoder(image_size)
    x = vit(inputs)
    x = layers.GlobalAveragePooling1D()(x)
    # Project to shared embedding dimension
    outputs = layers.Dense(embedding_dim)(x)
    model = models.Model(inputs, outputs, name="image_encoder")
    return model

def create_text_encoder(vocab_size, embedding_dim, max_length):
    """Build the text tower: token + position embeddings, Transformer blocks, pooling.

    Token id 0 is treated as padding (``mask_zero=True``). The positional
    embedding layer lets that mask through to the layers after it.
    NOTE(review): with the mask active, a sequence consisting only of padding
    has nothing to average over - confirm callers never pass an empty text.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embedding_dim: Size of the token embeddings and of the output embedding.
        max_length: Length of the (padded) input token sequence.

    Returns:
        A Keras model named ``"text_encoder"`` with output shape
        ``(batch, embedding_dim)``.
    """
    inputs = layers.Input(shape=(max_length,))
    x = layers.Embedding(vocab_size, embedding_dim, mask_zero=True)(inputs)
    x = _AddPositionEmbedding(max_length, embedding_dim)(x)

    for _ in range(num_transformer_blocks):
        x = _transformer_block(x, embedding_dim)

    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(embedding_dim)(x)
    model = models.Model(inputs, outputs, name="text_encoder")
    return model


In [21]:
# Instantiate the two towers; both output embeddings of size embedding_dim.
image_encoder = create_image_encoder(image_size, embedding_dim)
text_encoder = create_text_encoder(vocab_size, embedding_dim, max_length)

In [22]:
class CLIPModel(tf.keras.Model):
    """Dual-encoder contrastive model trained with a symmetric cross-entropy loss.

    Images and texts are embedded by their respective encoders, L2-normalized,
    and compared by dot product. Within a batch, the i-th image and the i-th
    text form the only positive pair.

    Args:
        image_encoder: Model mapping a batch of images to embeddings.
        text_encoder: Model mapping a batch of token sequences to embeddings.
        temperature: Initial value of the learnable temperature that divides
            the similarity logits.
    """

    # Lower bound for the learnable temperature. Without one, training can push
    # it to zero or below, which breaks the division by it.
    _MIN_TEMPERATURE = 0.01

    def __init__(self, image_encoder, text_encoder, temperature=0.05):
        super(CLIPModel, self).__init__()
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        # Registered through add_weight so the model is guaranteed to track it
        # in trainable_variables; the constraint is applied by the optimizer
        # after each update.
        self.temperature = self.add_weight(
            name="temperature",
            shape=(),
            dtype="float32",
            initializer=tf.keras.initializers.Constant(temperature),
            trainable=True,
            constraint=lambda t: tf.maximum(t, CLIPModel._MIN_TEMPERATURE),
        )
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    def compile(self, optimizer):
        """Configure the model for training with the given optimizer."""
        # Hand the optimizer to Keras instead of overwriting the attribute
        # after a default compile().
        super(CLIPModel, self).compile(optimizer=optimizer)

    @property
    def metrics(self):
        # Listed here so Keras resets the tracker at the start of each epoch.
        return [self.loss_tracker]

    def train_step(self, data):
        """Run one optimization step on a batch of ``(images, texts)``.

        Returns:
            ``{"loss": <running mean loss>}``.
        """
        images, texts = data  # images: (batch, H, W, 3); texts: (batch, max_length)

        with tf.GradientTape() as tape:
            img_embeddings = self.image_encoder(images, training=True)
            txt_embeddings = self.text_encoder(texts, training=True)

            img_embeddings = tf.math.l2_normalize(img_embeddings, axis=1)
            txt_embeddings = tf.math.l2_normalize(txt_embeddings, axis=1)

            # Compute similarity logits: (batch, batch)
            logits = tf.matmul(img_embeddings, txt_embeddings, transpose_b=True)
            logits = logits / self.temperature

            # Ground truth: diagonal elements are the matching pairs.
            # tf.shape handles a smaller final batch.
            labels = tf.range(tf.shape(images)[0])

            # Cross-entropy in both directions: image->text and text->image.
            loss_i2t = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
            loss_t2i = tf.keras.losses.sparse_categorical_crossentropy(labels, tf.transpose(logits), from_logits=True)
            per_pair_loss = (loss_i2t + loss_t2i) / 2.0
            # Reduce to a scalar so the gradient is that of the batch mean
            # rather than of the implicit sum over the per-pair vector.
            loss = tf.reduce_mean(per_pair_loss)

        # Compute gradients and update weights
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Track per-pair values so the reported mean is weighted per sample.
        self.loss_tracker.update_state(per_pair_loss)
        return {"loss": self.loss_tracker.result()}

    def call(self, inputs):
        """Return L2-normalized ``(image_embeddings, text_embeddings)`` for ``(images, texts)``."""
        images, texts = inputs
        img_embeddings = tf.math.l2_normalize(self.image_encoder(images, training=False), axis=1)
        txt_embeddings = tf.math.l2_normalize(self.text_encoder(texts, training=False), axis=1)
        return img_embeddings, txt_embeddings

In [23]:
# Assemble the contrastive model from the two encoders and compile it with Adam.
clip_model = CLIPModel(image_encoder, text_encoder, temperature=0.05)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
clip_model.compile(optimizer=optimizer)

In [29]:
# NOTE(review): in the recorded output below the loss stays at 2.7726 from epoch 2
# onward. That matches ln(16), the cross-entropy of a uniform prediction over a
# batch of 16 (batch_size = 16), i.e. chance level. TODO: confirm whether the
# embeddings have collapsed before relying on this run or adding epochs.
epochs = 10  # can be increased
clip_model.fit(ds, epochs=epochs)

Epoch 1/10
[1m465/465[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m435s[0m 929ms/step - loss: 2.7765
Epoch 2/10
[1m465/465[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m440s[0m 925ms/step - loss: 2.7726
Epoch 3/10
[1m465/465[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m442s[0m 925ms/step - loss: 2.7726
Epoch 4/10
[1m465/465[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m442s[0m 926ms/step - loss: 2.7726
Epoch 5/10
[1m465/465[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m442s[0m 925ms/step - loss: 2.7726
Epoch 6/10
[1m465/465[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m441s[0m 924ms/step - loss: 2.7726
Epoch 7/10
[1m465/465[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m442s[0m 925ms/step - loss: 2.7726
Epoch 8/10
[1m465/465[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m442s[0m 925ms/step - loss: 2.7726
Epoch 9/10
[1m465/465[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m443s[0m 928ms/step - loss: 2.7726
Epoch 10/10
[1m465/465[0m [32m━━━━━━━━━━━━━━━━━━━━[

<keras.src.callbacks.history.History at 0x7eee8446d310>

In [30]:
def get_text_embedding(text):
    """Embed one text with ``text_encoder``.

    Args:
        text: Raw text to tokenize, pad to ``max_length`` and encode.

    Returns:
        The L2-normalized embedding as a tensor with a leading batch axis of 1.
    """
    # texts_to_sequences already yields a one-element batch for pad_sequences.
    token_batch = tokenizer.texts_to_sequences([text])
    padded_batch = tf.keras.preprocessing.sequence.pad_sequences(
        token_batch, maxlen=max_length, padding='post'
    )
    features = text_encoder(padded_batch, training=False)
    return tf.math.l2_normalize(features, axis=1)

def get_image_embedding(image_path):
    """Embed one image file with ``image_encoder``.

    Args:
        image_path: Path of the image to load at ``image_size`` x ``image_size``.

    Returns:
        The L2-normalized embedding as a tensor with a leading batch axis of 1.
    """
    pixels = img_to_array(load_img(image_path, target_size=(image_size, image_size)))
    # Scale by 1/255 and add the batch axis expected by the encoder.
    batch = (pixels / 255.0)[np.newaxis, ...]
    features = image_encoder(batch, training=False)
    return tf.math.l2_normalize(features, axis=1)

# Retrieve the report whose embedding is closest to a query image.
# NOTE(review): 'xr.jpeg' is not created anywhere in the visible cells; it has
# to be supplied next to the notebook before this cell is run.
test_image = 'xr.jpeg'
img_emb = get_image_embedding(test_image)
print("Image embedding:", img_emb.numpy())

all_reports = radiology_df['report'].tolist()

# Tokenize and pad all reports once, then run the text encoder on chunks.
# This replaces one eager encoder call per report with a handful of batched calls.
report_sequences = tf.keras.preprocessing.sequence.pad_sequences(
    tokenizer.texts_to_sequences(all_reports), maxlen=max_length, padding='post'
)
embed_batch_size = 256
all_text_embeddings = []
for start in range(0, len(report_sequences), embed_batch_size):
    chunk = report_sequences[start:start + embed_batch_size]
    chunk_emb = tf.math.l2_normalize(text_encoder(chunk, training=False), axis=1)
    all_text_embeddings.append(chunk_emb.numpy())
all_text_embeddings = np.concatenate(all_text_embeddings, axis=0)

# Both sides are L2-normalized, so the dot product is the cosine similarity.
cosine_sim = np.dot(all_text_embeddings, img_emb.numpy()[0])
best_idx = np.argmax(cosine_sim)
print("\nBest matching report:")
print(all_reports[best_idx])

Image embedding: [[ 0.05733508  0.00523869  0.06383529  0.05051493  0.04355091  0.02500839
  -0.07777718  0.1457428  -0.05198224  0.07125071 -0.0904338   0.06187361
  -0.09715363  0.12147623 -0.01999591  0.03804174 -0.03612426  0.1149757
  -0.01391032 -0.00195657  0.1456474  -0.07698674 -0.04502937 -0.03379296
   0.07548678  0.00504048  0.04394463  0.02888563 -0.03367941  0.08016434
  -0.11808249  0.11248635  0.13601714  0.00332339  0.01640047  0.04144336
   0.09689321  0.02218678  0.09288475  0.08762556  0.04284214 -0.04380634
   0.00071143 -0.04711069 -0.02884681  0.02897729  0.01118722 -0.03120163
   0.05351967 -0.04563073 -0.01525307  0.02657428  0.06119745 -0.07053658
  -0.07705502 -0.01606986  0.08667487 -0.05638002  0.02364861 -0.10039265
   0.03728671 -0.04434197 -0.0638153  -0.02112338 -0.0598181   0.09828798
   0.07277232  0.01066731  0.04956194 -0.01133192 -0.00738006  0.01844435
  -0.05654645  0.00830322 -0.07037811  0.03265889  0.0075481  -0.02276928
   0.00311648 -0.04058