<a href="https://colab.research.google.com/github/jasleenkaursandhu/Reproducing-chest-xray-report-generation-boag/blob/main/3gram.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# N-gram Model for Report Generation
# This notebook implements an n-gram language model for chest X-ray report generation.
# The model selects the most similar training images for each test image and builds a language model from their reports.


In [2]:
# docker image code to made files labelled for chexpert
# For 1-gram
##docker run --platform linux/amd64 -v /Users/jasleensandhu/Desktop/CS598DLH:/data uwizeye2/chexpert-labeler:amd64 python label.py --reports_path /data/1gram_headerless.csv --output_path /data/output/labeled_1gram.csv --verbose

# For 2-gram
##docker run --platform linux/amd64 -v /Users/jasleensandhu/Desktop/CS598DLH:/data uwizeye2/chexpert-labeler:amd64 python label.py --reports_path /data/2gram_headerless.csv --output_path /data/output/labeled_2gram.csv --verbose

# For 3-gram
##docker run --platform linux/amd64 -v /Users/jasleensandhu/Desktop/CS598DLH:/data uwizeye2/chexpert-labeler:amd64 python label.py --reports_path /data/3gram_headerless.csv --output_path /data/output/labeled_3gram.csv --verbose

In [3]:
# Import necessary libraries
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from PIL import Image
import tqdm
from collections import defaultdict, Counter
import pickle
import gzip
import random
import re
import warnings
!pip install pydicom
import pydicom
from time import gmtime, strftime

Collecting pydicom
  Downloading pydicom-3.0.1-py3-none-any.whl.metadata (9.4 kB)
Downloading pydicom-3.0.1-py3-none-any.whl (2.4 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.2/2.4 MB[0m [31m6.2 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━[0m [32m2.1/2.4 MB[0m [31m30.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pydicom
Successfully installed pydicom-3.0.1


In [4]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
base_path = '/content/drive/MyDrive/mimic-cxr-project'
!mkdir -p {base_path}/data
!mkdir -p {base_path}/output

Mounted at /content/drive


In [5]:
# Import the report parser module
import sys
sys.path.append(f"{base_path}/modules")
from report_parser import parse_report, MIMIC_RE
print("Successfully imported report parser module")

Successfully imported report parser module


In [6]:
# Load train and test data
data_dir = os.path.join(base_path, 'data')
files_path = os.path.join(base_path, 'files')
output_dir = os.path.join(base_path, 'output')
features_dir = os.path.join(base_path, 'features')

train_df = pd.read_csv(os.path.join(base_path, 'local_output/data/train.tsv'), sep='\t')
test_df = pd.read_csv(os.path.join(base_path, 'local_output/data/test.tsv'), sep='\t')

print(f"Train data shape: {train_df.shape}")
print(f"Test data shape: {test_df.shape}")

Train data shape: (4291, 3)
Test data shape: (1757, 3)


In [7]:
# Load the top 100 neighbors for each test image
neighbors_path = os.path.join(output_dir, 'top100_neighbors.pkl')

if os.path.exists(neighbors_path):
    with open(neighbors_path, 'rb') as f:
        neighbors = pickle.load(f)

    print(f"Loaded neighbors for {len(neighbors)} test images")
    print(f"Sample neighbors for first test image: {list(neighbors.items())[0][1][:5]}...")
else:
    print(f"Warning: Neighbors file not found at {neighbors_path}")
    print("Please run the KNN model first to generate the neighbors file.")
    neighbors = {}
    for dicom_id in test_df.dicom_id.values:
        neighbors[dicom_id] = random.sample(train_df.dicom_id.tolist(), min(100, len(train_df)))

Loaded neighbors for 868 test images
Sample neighbors for first test image: ['9368d351-90051a0c-c8dc80c0-96a2254d-2a884177', '6cad5cfd-bf8d1805-f3a3906a-588e0bfe-fe5c153d', 'd6c4c0d6-f12c2415-a5b3d2df-bcc4a202-442d3523', 'd18b8527-4c5a26fd-ac1041dd-8f9fd75f-4327f46f', '6083cad9-c1727fa2-b948d959-763ca27c-6391e63b']...


In [8]:
# Map each dicom to its study_id
report_id_column = 'study_id'
report_lookup = dict(train_df[['dicom_id', report_id_column]].values)
print(f"Created lookup dictionary for {len(report_lookup)} training images")

Created lookup dictionary for 4291 training images


In [9]:
# Define tokens for sequence boundaries
START = '<START>'
END = '<END>'

# Build n-gram language model from neighbors
def fit(dicom_ids, n=3):
    """Build language model from the reports of the given dicom_ids"""
    # Language model maps context (n-1 previous words) to possible next words
    LM = defaultdict(Counter)

    for dicom_id in dicom_ids:
        if dicom_id not in report_lookup:
            continue

        report_id = report_lookup[dicom_id]

        # Get corresponding subject_id
        subject_row = train_df[train_df.dicom_id == dicom_id]
        if len(subject_row) == 0:
            continue

        subject_id = subject_row.iloc[0]['subject_id']

        # Construct path to the report
        subject_prefix = f"p{str(subject_id)[:2]}"
        subject_dir = f"p{subject_id}"
        study_dir = f"s{report_id}"
        report_file = f"{study_dir}.txt"
        report_path = os.path.join(files_path, subject_prefix, subject_dir, report_file)

        # Parse the report
        try:
            if os.path.exists(report_path):
                parsed_report = parse_report(report_path)

                if 'findings' in parsed_report:
                    # Tokenize the findings text
                    toks = parsed_report['findings'].replace('.', ' . ').split()

                    # Add padding tokens at the beginning and END token at the end
                    padded_toks = [START for _ in range(n-1)] + toks + [END]

                    # Build n-gram model by counting follow words for each context
                    for i in range(len(padded_toks) - n + 1):
                        context = tuple(padded_toks[i:i+n-1])
                        target = padded_toks[i+n-1]
                        sim = 1
                        LM[context][target] += sim
        except Exception as e:
            continue

    return LM

In [10]:
# Sample from the n-gram model
def sample(LM, seq_so_far, n):
    """Sample the next word based on the n-gram language model"""
    last = tuple(seq_so_far[-(n-1):])

    if last not in LM or not LM[last]:
        # If context not found in model, return END token
        return END

    words, counts = zip(*LM[last].items())
    total = sum(counts)
    P = np.array(counts) / total

    # Sample next word based on probabilities
    choice = np.random.choice(words, p=P)
    return choice

In [11]:
# Set n-gram size
n = 3

# Generate reports for test images
generated_reports = {}

print(f"Generating reports with {n}-gram model...")
for pred_dicom in tqdm.tqdm(test_df.dicom_id.values):
    # Skip if we don't have neighbors for this test image
    if pred_dicom not in neighbors:
        print(f"Warning: No neighbors for {pred_dicom}")
        continue

    # Build n-gram model from the neighbors' reports
    nn = neighbors[pred_dicom]
    LM = fit(nn, n=n)

    # Skip if model is empty
    if not LM:
        print(f"Warning: Empty language model for {pred_dicom}")
        continue

    # Handle initialization differently for different n values
    if n == 1:
        # For 1-gram, we don't need context
        generated_toks = []
        current = START  # Start token won't actually be used for context
    else:
        # For n > 1, initialize with n-1 START tokens
        generated_toks = [START for _ in range(n-1)]
        current = generated_toks[-1]

    # Generate until END token or max length
    while current != END and len(generated_toks) < 100:
        next_word = sample(LM, generated_toks, n)
        generated_toks.append(next_word)
        current = next_word

    # Remove START tokens (if any) and potentially END token
    if n > 1:
        generated_toks = generated_toks[n-1:]
    if generated_toks and generated_toks[-1] == END:
        generated_toks = generated_toks[:-1]

    # Join tokens into text
    generated_text = ' '.join(generated_toks)
    generated_reports[pred_dicom] = generated_text

print(f"Generated reports for {len(generated_reports)} test images")

Generating reports with 3-gram model...


  0%|          | 0/1757 [00:00<?, ?it/s]



  0%|          | 5/1757 [06:00<23:43:04, 48.74s/it]



  1%|          | 14/1757 [06:22<5:17:35, 10.93s/it]



  1%|          | 18/1757 [06:49<4:37:45,  9.58s/it]



  2%|▏         | 38/1757 [08:17<3:20:48,  7.01s/it]



  3%|▎         | 44/1757 [08:37<2:13:10,  4.66s/it]



  3%|▎         | 49/1757 [08:42<1:20:42,  2.84s/it]



  4%|▍         | 72/1757 [09:11<47:33,  1.69s/it]



  7%|▋         | 125/1757 [09:46<37:52,  1.39s/it]



  8%|▊         | 134/1757 [10:10<49:53,  1.84s/it]



  9%|▊         | 150/1757 [10:14<28:43,  1.07s/it]



 10%|▉         | 167/1757 [10:21<21:30,  1.23it/s]



 10%|▉         | 170/1757 [10:26<23:47,  1.11it/s]



 11%|█         | 192/1757 [10:50<1:09:06,  2.65s/it]



 12%|█▏        | 212/1757 [11:11<33:07,  1.29s/it]



 13%|█▎        | 222/1757 [11:14<21:48,  1.17it/s]



 14%|█▍        | 252/1757 [11:18<07:45,  3.23it/s]



 15%|█▌        | 265/1757 [11:30<12:50,  1.94it/s]



 16%|█▌        | 277/1757 [11:31<09:22,  2.63it/s]



 16%|█▌        | 279/1757 [11:35<11:58,  2.06it/s]



 17%|█▋        | 294/1757 [11:44<15:04,  1.62it/s]



 18%|█▊        | 313/1757 [12:18<38:04,  1.58s/it]



 19%|█▊        | 326/1757 [12:20<11:59,  1.99it/s]



 20%|█▉        | 343/1757 [12:23<08:09,  2.89it/s]



 21%|██        | 367/1757 [12:24<03:33,  6.52it/s]



 21%|██        | 372/1757 [12:25<03:49,  6.04it/s]



 22%|██▏       | 383/1757 [12:30<09:28,  2.42it/s]



 22%|██▏       | 395/1757 [12:34<10:11,  2.23it/s]



 23%|██▎       | 399/1757 [12:34<07:17,  3.11it/s]



 24%|██▎       | 416/1757 [12:36<04:27,  5.01it/s]



 24%|██▍       | 421/1757 [12:42<13:35,  1.64it/s]



 24%|██▍       | 429/1757 [12:44<10:22,  2.13it/s]



 25%|██▍       | 431/1757 [12:45<08:28,  2.61it/s]



 26%|██▌       | 454/1757 [12:46<02:03, 10.56it/s]



 28%|██▊       | 490/1757 [12:46<00:43, 29.44it/s]



 28%|██▊       | 496/1757 [12:47<00:58, 21.66it/s]



 29%|██▊       | 501/1757 [12:50<03:02,  6.88it/s]



 29%|██▊       | 504/1757 [12:53<05:22,  3.89it/s]



 29%|██▉       | 510/1757 [12:54<04:23,  4.72it/s]



 29%|██▉       | 516/1757 [12:54<03:47,  5.45it/s]



 30%|██▉       | 519/1757 [12:55<04:00,  5.14it/s]



 30%|██▉       | 527/1757 [12:55<02:41,  7.61it/s]



 30%|███       | 529/1757 [13:01<11:02,  1.85it/s]



 30%|███       | 535/1757 [13:03<08:48,  2.31it/s]



 32%|███▏      | 561/1757 [13:03<01:37, 12.21it/s]



 32%|███▏      | 566/1757 [13:08<04:30,  4.40it/s]



 34%|███▎      | 589/1757 [13:08<01:50, 10.55it/s]



 34%|███▍      | 602/1757 [13:09<01:53, 10.14it/s]



 37%|███▋      | 642/1757 [13:18<03:59,  4.66it/s]



 37%|███▋      | 646/1757 [13:20<04:42,  3.93it/s]



 37%|███▋      | 652/1757 [13:21<03:30,  5.25it/s]



 38%|███▊      | 660/1757 [13:23<05:24,  3.38it/s]



 39%|███▊      | 679/1757 [13:27<02:13,  8.07it/s]



 39%|███▉      | 691/1757 [13:28<02:30,  7.08it/s]



 41%|████▏     | 726/1757 [13:29<00:38, 26.89it/s]



 42%|████▏     | 740/1757 [13:36<04:08,  4.10it/s]



 43%|████▎     | 760/1757 [13:37<02:18,  7.22it/s]



 45%|████▌     | 798/1757 [13:38<00:47, 20.03it/s]



 46%|████▌     | 806/1757 [13:39<01:15, 12.67it/s]



 46%|████▋     | 817/1757 [13:41<01:45,  8.87it/s]



 48%|████▊     | 852/1757 [13:41<00:38, 23.60it/s]



 49%|████▉     | 860/1757 [13:42<00:49, 18.30it/s]



 50%|█████     | 884/1757 [13:44<00:43, 19.93it/s]



 51%|█████     | 890/1757 [13:46<01:21, 10.61it/s]



 51%|█████     | 894/1757 [13:46<01:40,  8.56it/s]



 51%|█████     | 899/1757 [13:47<01:48,  7.94it/s]



 52%|█████▏    | 905/1757 [13:48<01:53,  7.48it/s]



 54%|█████▎    | 940/1757 [13:50<01:05, 12.43it/s]



 56%|█████▌    | 980/1757 [13:51<00:34, 22.22it/s]



 56%|█████▋    | 990/1757 [13:51<00:35, 21.71it/s]



 58%|█████▊    | 1013/1757 [13:52<00:33, 22.34it/s]



 58%|█████▊    | 1016/1757 [13:53<00:35, 20.95it/s]



 60%|█████▉    | 1049/1757 [13:53<00:15, 44.57it/s]



 60%|██████    | 1062/1757 [13:55<00:45, 15.39it/s]



 62%|██████▏   | 1097/1757 [13:56<00:38, 17.10it/s]



 65%|██████▍   | 1136/1757 [13:57<00:17, 35.81it/s]



 66%|██████▌   | 1160/1757 [13:57<00:15, 39.32it/s]



 66%|██████▋   | 1168/1757 [13:58<00:17, 33.92it/s]



 67%|██████▋   | 1176/1757 [13:58<00:16, 35.98it/s]



 68%|██████▊   | 1188/1757 [13:58<00:16, 34.54it/s]



 68%|██████▊   | 1193/1757 [14:00<00:39, 14.24it/s]



 70%|███████   | 1233/1757 [14:00<00:12, 42.16it/s]



 71%|███████   | 1245/1757 [14:00<00:11, 45.58it/s]



 73%|███████▎  | 1282/1757 [14:02<00:15, 30.10it/s]



 74%|███████▍  | 1304/1757 [14:04<00:20, 21.88it/s]



 75%|███████▌  | 1319/1757 [14:06<00:35, 12.41it/s]



 76%|███████▋  | 1341/1757 [14:07<00:21, 19.43it/s]



 77%|███████▋  | 1357/1757 [14:08<00:20, 19.57it/s]



 78%|███████▊  | 1366/1757 [14:08<00:20, 18.96it/s]



 79%|███████▊  | 1380/1757 [14:09<00:17, 22.07it/s]



 79%|███████▉  | 1391/1757 [14:10<00:18, 19.72it/s]



 79%|███████▉  | 1394/1757 [14:10<00:18, 19.11it/s]



 80%|███████▉  | 1397/1757 [14:10<00:20, 17.24it/s]



 82%|████████▏ | 1444/1757 [14:12<00:12, 25.62it/s]



 85%|████████▌ | 1501/1757 [14:12<00:04, 57.36it/s]



 86%|████████▌ | 1512/1757 [14:14<00:10, 22.45it/s]



 87%|████████▋ | 1522/1757 [14:14<00:09, 25.27it/s]



 88%|████████▊ | 1553/1757 [14:15<00:05, 35.37it/s]



 91%|█████████ | 1597/1757 [14:15<00:02, 64.95it/s]



 93%|█████████▎| 1626/1757 [14:16<00:01, 66.53it/s]



 98%|█████████▊| 1723/1757 [14:16<00:00, 145.27it/s]



100%|██████████| 1757/1757 [14:16<00:00,  2.05it/s] 

Generated reports for 286 test images





In [12]:
# Save the generated reports
print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

# Create output directory if it doesn't exist
pred_dir = os.path.join(base_path, 'output')
os.makedirs(pred_dir, exist_ok=True)

# Save the generated reports
pred_file = os.path.join(pred_dir, f'{n}-gram.tsv')
print(f"Saving predictions to {pred_file}")

with open(pred_file, 'w') as f:
    print('dicom_id\tgenerated', file=f)
    for dicom_id, generated in sorted(generated_reports.items()):
        # Clean up the text (remove any tabs)
        cleaned_text = generated.replace('\t', ' ')
        print(f'{dicom_id}\t{cleaned_text}', file=f)

print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

2025-04-21 05:16:33
Saving predictions to /content/drive/MyDrive/mimic-cxr-project/output/3-gram.tsv
2025-04-21 05:16:34


In [13]:
# Display sample of generated reports
sample_count = min(3, len(generated_reports))
sample_dicoms = list(generated_reports.keys())[:sample_count]

for dicom_id in sample_dicoms:
    print(f"\nSample report for {dicom_id}:")
    report_text = generated_reports[dicom_id]

    # Print preview of the report (first 200 characters)
    if len(report_text) > 200:
        print(report_text[:200] + "...")
    else:
        print(report_text)


Sample report for 63100eab-9e8a8d90-392bc822-325de482-69a64e3b:
the patient in the appropriate clinical setting . there is bibasilar atelectasis is noted above the thoracic inlet which may reflect atelectasis but could be developing aspiration pneumonia . minimal ...

Sample report for 17269efa-b016a94d-1361e8df-ac428071-d1133672:
extensive right lower lobe . with the endotracheal tube tip courses below the field of view .

Sample report for 247d5e7b-66c77989-ca5fec41-608aaa71-eab4c699:
cardiomediastinal and hilar silhouettes are stable . no effusion
