In [1]:
import os
import random
import json
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.image import load_img, img_to_array

In [2]:
!wget https://openi.nlm.nih.gov/imgs/collections/NLMCXR_png.tgz

--2025-01-24 20:11:57--  https://openi.nlm.nih.gov/imgs/collections/NLMCXR_png.tgz
Resolving openi.nlm.nih.gov (openi.nlm.nih.gov)... 130.14.65.157, 2607:f220:41e:7065::157
Connecting to openi.nlm.nih.gov (openi.nlm.nih.gov)|130.14.65.157|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1360814128 (1.3G) [application/x-gzip]
Saving to: ‘NLMCXR_png.tgz’


2025-01-24 20:22:51 (1.99 MB/s) - ‘NLMCXR_png.tgz’ saved [1360814128/1360814128]



In [3]:
!wget https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz

--2025-01-24 20:22:51--  https://openi.nlm.nih.gov/imgs/collections/NLMCXR_reports.tgz
Resolving openi.nlm.nih.gov (openi.nlm.nih.gov)... 130.14.65.157, 2607:f220:41e:7065::157
Connecting to openi.nlm.nih.gov (openi.nlm.nih.gov)|130.14.65.157|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1112632 (1.1M) [application/x-gzip]
Saving to: ‘NLMCXR_reports.tgz’


2025-01-24 20:22:52 (1.26 MB/s) - ‘NLMCXR_reports.tgz’ saved [1112632/1112632]



In [4]:
!mkdir workdir

In [5]:
%cd workdir

/content/workdir


In [None]:
!tar -xvzf /content/NLMCXR_png.tgz -C .
!tar -xvzf /content/NLMCXR_reports.tgz -C .

In [7]:
%cd ..

/content


In [8]:
import xml.etree.ElementTree as ET

In [9]:
!mv workdir/ecgen-radiology .

In [10]:
def find_valid_pairs(image_dir='workdir', report_dir='ecgen-radiology'):
    valid_pairs = []
    report_files = {}

    for root, _, files in os.walk(report_dir):
        for f in files:
            if f.lower().endswith('.xml'):
                base_name = os.path.splitext(f)[0].lower()
                report_files[base_name] = os.path.join(root, f)

    for root, _, files in os.walk(image_dir):
        for f in files:
            if f.lower().endswith('.png'):
                base_name = os.path.splitext(f)[0].lower()
                if base_name in report_files:
                    valid_pairs.append((
                        os.path.join(root, f),
                        report_files[base_name]
                    ))
    return valid_pairs

In [11]:
def parse_xml_report(xml_path):
    try:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        ns = {'pmc': 'http://www.ncbi.nlm.nih.gov/pmc/articles/PMC'}

        findings_elem = root.find(".//pmc:AbstractText[@Label='FINDINGS']", ns)
        impression_elem = root.find(".//pmc:AbstractText[@Label='IMPRESSION']", ns)

        findings = findings_elem.text.strip() if (findings_elem is not None and findings_elem.text) else ""
        impression = impression_elem.text.strip() if (impression_elem is not None and impression_elem.text) else ""

        if not findings and not impression:
            return None

        return f"FINDINGS: {findings}. IMPRESSION: {impression}"

    except Exception as e:
        print(f"Error parsing {xml_path}: {str(e)}")
        return None

In [12]:
import os
import xml.etree.ElementTree as ET
import pandas as pd

def parse_xml_report(xml_path):
    try:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        ns = {'pmc': 'http://www.ncbi.nlm.nih.gov/pmc/articles/PMC'}

        findings = root.find(".//AbstractText[@Label='FINDINGS']", ns)
        impression = root.find(".//AbstractText[@Label='IMPRESSION']", ns)

        findings_text = findings.text.strip() if (findings is not None and findings.text) else ""
        impression_text = impression.text.strip() if (impression is not None and impression.text) else ""

        if not findings_text and not impression_text:
            return None, []

        image_ids = [img.get('id') for img in root.findall('.//parentImage')]

        image_paths = [f"workdir/{img_id}.png" for img_id in image_ids]

        return (
            f"FINDINGS: {findings_text}. IMPRESSION: {impression_text}",
            image_paths
        )

    except Exception as e:
        print(f"Error parsing {xml_path}: {str(e)}")
        return None, []

def create_dataset():
    data = []

    for root_dir, _, files in os.walk("ecgen-radiology"):
        for file in files:
            if file.lower().endswith('.xml'):
                xml_path = os.path.join(root_dir, file)
                report_text, image_paths = parse_xml_report(xml_path)

                if report_text and image_paths:
                    for img_path in image_paths:
                        if os.path.exists(img_path):
                            data.append({
                                'image_path': img_path,
                                'report': report_text
                            })
                        else:
                            print(f"Missing image: {img_path}")

    return pd.DataFrame(data)

radiology_df = create_dataset()
print(f"Found {len(radiology_df)} valid image-report pairs")

if not radiology_df.empty:
    print("\nSample entry:")
    print(f"Image path: {radiology_df.iloc[0]['image_path']}")
    print(f"Report: {radiology_df.iloc[0]['report']}")
else:
    print("\nNo valid pairs found. Check:")
    print("- XML files in ecgen-radiology directory")
    print("- Image files in workdir directory")
    print("- File naming consistency between XML and PNG files")

Found 7430 valid image-report pairs

Sample entry:
Image path: workdir/CXR839_IM-2363-1001.png
Report: FINDINGS: Heart size normal. No pleural effusions or pneumothorax. Lungs are clear. Soft tissues and XXXX are unremarkable.. IMPRESSION: Normal chest.


In [13]:
tokenizer = Tokenizer(
    num_words=5000,
    oov_token="<unk>",
    filters='!"#$%&()*+.,-/:;=?@[\]^_`{|}~ '
)

In [14]:
tokenizer.fit_on_texts(radiology_df['report'])
vocab_size = len(tokenizer.word_index) + 1
max_length = 100  # Increase it for longer reports

In [None]:
print([word for word in tokenizer.word_index if "pneumonia" in word or "effusion" in word][:10])

In [16]:
from tensorflow.keras.utils import Sequence
import numpy as np

class RadiologyDataGenerator(Sequence):
    def __init__(self,
                 dataframe,
                 tokenizer,
                 batch_size=8,
                 image_size=224,
                 max_length=100,
                 shuffle=True):

        self.df = dataframe
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.image_size = image_size
        self.max_length = max_length
        self.shuffle = shuffle
        self.indexes = np.arange(len(self.df))

        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __len__(self):
        return int(np.ceil(len(self.df) / self.batch_size))

    def __getitem__(self, index):
        batch_indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        batch_df = self.df.iloc[batch_indexes]

        batch_images = []
        batch_sequences = []

        for _, row in batch_df.iterrows():
            img = load_img(row['image_path'], target_size=(self.image_size, self.image_size))
            img_array = img_to_array(img) / 255.0
            batch_images.append(img_array)

            sequence = self.tokenizer.texts_to_sequences([row['report']])[0]
            padded_sequence = tf.keras.preprocessing.sequence.pad_sequences(
                [sequence], maxlen=self.max_length, padding='post'
            )[0]
            batch_sequences.append(padded_sequence)

        batch_images = np.array(batch_images)
        batch_sequences = np.array(batch_sequences)

        batch_targets = np.zeros_like(batch_sequences)
        batch_targets[:, :-1] = batch_sequences[:, 1:]

        return {'input_layer_2': batch_images, 'input_layer_3': batch_sequences}, batch_targets

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def get_sample(self, index=0):
        if index >= self.__len__():
            index = 0
        return self.__getitem__(index)

In [17]:
image_size = 224
vocab_size = len(tokenizer.word_index) + 1
max_length = 100
batch_size = 16  # Reduced for Colab memory
#vocab_size = 5001
#max_length = 50
embedding_dim = 256
num_heads = 8
ff_dim = 512
num_transformer_blocks = 4

In [18]:
def create_vit_encoder(image_size):
    inputs = layers.Input(shape=(image_size, image_size, 3))

    patch_size = 16
    num_patches = (image_size // patch_size) ** 2
    projection_dim = 768

    patches = layers.Conv2D(
        filters=projection_dim,
        kernel_size=patch_size,
        strides=patch_size,
        padding="valid"
    )(inputs)

    patches = layers.Reshape((num_patches, projection_dim))(patches)

    positional_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)
    positions = tf.range(start=0, limit=num_patches, delta=1)
    encoded_patches = patches + positional_embedding(positions)

    for _ in range(num_transformer_blocks):
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim)(x1, x1)
        x2 = layers.Add()([x1, attention_output])

        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        ffn_output = layers.Dense(ff_dim, activation="relu")(x3)
        ffn_output = layers.Dense(projection_dim)(ffn_output)
        encoded_patches = layers.Add()([x2, ffn_output])

    model = models.Model(inputs, encoded_patches)
    return model

def create_text_decoder(vocab_size, embedding_dim, max_length):
    inputs = layers.Input(shape=(max_length,))

    word_embeddings = layers.Embedding(vocab_size, embedding_dim)(inputs)
    positional_embeddings = layers.Embedding(max_length, embedding_dim)(tf.range(start=0, limit=max_length, delta=1))
    embeddings = word_embeddings + positional_embeddings

    x = embeddings
    for _ in range(num_transformer_blocks):
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embedding_dim)(x1, x1)
        x2 = layers.Add()([x1, attention_output])

        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        ffn_output = layers.Dense(ff_dim, activation="relu")(x3)
        ffn_output = layers.Dense(embedding_dim)(ffn_output)
        x = layers.Add()([x2, ffn_output])

    outputs = layers.Dense(vocab_size, activation="softmax")(x)
    model = models.Model(inputs, outputs)
    return model

def create_image_captioning_model(image_size, vocab_size, embedding_dim, max_length):
    vit_encoder = create_vit_encoder(image_size)
    text_decoder = create_text_decoder(vocab_size, embedding_dim, max_length)

    image_inputs = layers.Input(shape=(image_size, image_size, 3), name='input_layer_2')
    text_inputs = layers.Input(shape=(max_length,), name='input_layer_3')

    encoded_image = vit_encoder(image_inputs)

    encoded_image = layers.GlobalAveragePooling1D()(encoded_image)
    encoded_image = layers.Dense(embedding_dim, activation="relu")(encoded_image)
    encoded_image = layers.RepeatVector(max_length)(encoded_image)


    embeddings = layers.Concatenate(axis=2)([encoded_image, text_decoder(text_inputs)])

    outputs = layers.Dense(vocab_size, activation="softmax")(embeddings)

    model = models.Model(inputs=[image_inputs, text_inputs], outputs=outputs)
    return model


In [19]:
train_generator = RadiologyDataGenerator(
    dataframe=radiology_df,
    tokenizer=tokenizer,
    batch_size=8,
    image_size=224,
    max_length=100,
    shuffle=True
)

sample_batch = train_generator.get_sample(0)
print("Batch shapes:")
print(f"Images: {sample_batch[0]['input_layer_2'].shape}")
print(f"Sequences: {sample_batch[0]['input_layer_3'].shape}")
print(f"Targets: {sample_batch[1].shape}")

Batch shapes:
Images: (8, 224, 224, 3)
Sequences: (8, 100)
Targets: (8, 100)


In [20]:
model = create_image_captioning_model(
    image_size=image_size,
    vocab_size=vocab_size,
    embedding_dim=256,
    max_length=max_length
)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

In [None]:
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=20
)

Epoch 1/20


  self._warn_if_super_not_called()


[1m795/929[0m [32m━━━━━━━━━━━━━━━━━[0m[37m━━━[0m [1m42s[0m 318ms/step - accuracy: 0.6028 - loss: 2.8754

In [21]:
def generate_radiology_report(model, image_path):
    img = load_img(image_path, target_size=(image_size, image_size))
    img = img_to_array(img) / 255.0
    input_seq = np.zeros((1, max_length))

    for i in range(max_length):
        pred = model.predict({'input_layer_2': np.expand_dims(img, 0), 'input_layer_3': input_seq})
        predicted_id = np.argmax(pred[0, i])
        if predicted_id == tokenizer.word_index.get('<end>', -1) or i == max_length-1:
            break
        input_seq[0, i] = predicted_id

    return ' '.join([tokenizer.index_word.get(int(id), '') for id in input_seq[0] if id != 0])

test_image = 'abdomen.jpg'
print("\nGenerated Report:")
print(generate_radiology_report(model, test_image))


Generated Report:
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 5s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m