<a href="https://colab.research.google.com/github/jasleenkaursandhu/Reproducing-chest-xray-report-generation-boag/blob/main/ngram.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# N-gram Models for Report Generation
# This notebook implements 1-gram, 2-gram, and 3-gram language models for chest X-ray report generation.

In [2]:
# Import necessary libraries
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from PIL import Image
import tqdm
from collections import defaultdict, Counter
import pickle
import gzip
import random
import re
import warnings
!pip install pydicom
import pydicom
from time import gmtime, strftime

Collecting pydicom
  Downloading pydicom-3.0.1-py3-none-any.whl.metadata (9.4 kB)
Downloading pydicom-3.0.1-py3-none-any.whl (2.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pydicom
Successfully installed pydicom-3.0.1


In [3]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
base_path = '/content/drive/MyDrive/mimic-cxr-project'
!mkdir -p {base_path}/data
!mkdir -p {base_path}/output

Mounted at /content/drive


In [4]:
# Import the report parser module
import sys
sys.path.append(f"{base_path}/modules")
from report_parser import parse_report, MIMIC_RE
print("Successfully imported report parser module")

Successfully imported report parser module


In [5]:
# Load train and test data
data_dir = os.path.join(base_path, 'data')
files_path = os.path.join(base_path, 'files')
output_dir = os.path.join(base_path, 'output')

train_df = pd.read_csv(os.path.join(data_dir, 'train.tsv'), sep='\t')
test_df = pd.read_csv(os.path.join(data_dir, 'test.tsv'), sep='\t')

print(f"Train data shape: {train_df.shape}")
print(f"Test data shape: {test_df.shape}")

Train data shape: (2243, 3)
Test data shape: (871, 3)


In [6]:
# Load the top 100 neighbors for each test image
neighbors_path = os.path.join(output_dir, 'top100_neighbors.pkl')

if os.path.exists(neighbors_path):
    with open(neighbors_path, 'rb') as f:
        neighbors = pickle.load(f)

    print(f"Loaded neighbors for {len(neighbors)} test images")
    print(f"Sample neighbors for first test image: {list(neighbors.items())[0][1][:5]}...")
else:
    print(f"Warning: Neighbors file not found at {neighbors_path}")
    print("Please run the KNN model first to generate the neighbors file.")
    neighbors = {}
    for dicom_id in test_df.dicom_id.values:
        neighbors[dicom_id] = random.sample(train_df.dicom_id.tolist(), min(100, len(train_df)))

Loaded neighbors for 380 test images
Sample neighbors for first test image: ['15a0e62a-ac8edf75-3444949e-35cf275c-b22ee616', '1d00cfa1-29d99da7-c62126a2-18449dbb-6dd404f0', 'fd228853-2df84977-18361e36-f22ccc25-7d9a4046', '350acbc7-85b7cb9f-030eeac1-1e4ff930-a29191a1', '76bdc346-c4561bc4-c75ab157-4fdcde0e-843596ec']...


In [7]:
# Map each dicom to its study_id
report_id_column = 'study_id'
report_lookup = dict(train_df[['dicom_id', report_id_column]].values)
print(f"Created lookup dictionary for {len(report_lookup)} training images")

Created lookup dictionary for 2243 training images


In [8]:
# Define tokens for sequence boundaries
START = ''
END = ''

# Function to parse reports for a list of dicom_ids
def get_report_tokens(dicom_ids):
    """Get tokens from reports for a list of dicom_ids"""
    all_reports = []

    for dicom_id in dicom_ids:
        if dicom_id not in report_lookup:
            continue

        report_id = report_lookup[dicom_id]

        # Get corresponding subject_id
        subject_row = train_df[train_df.dicom_id == dicom_id]
        if len(subject_row) == 0:
            continue

        subject_id = subject_row.iloc[0]['subject_id']

        # Construct path to the report
        subject_prefix = f"p{str(subject_id)[:2]}"
        subject_dir = f"p{subject_id}"
        study_dir = f"s{report_id}"
        report_file = f"{study_dir}.txt"
        report_path = os.path.join(files_path, subject_prefix, subject_dir, report_file)

        # Parse the report
        try:
            if os.path.exists(report_path):
                parsed_report = parse_report(report_path)
                if 'findings' in parsed_report:
                    # Tokenize the findings text
                    toks = parsed_report['findings'].replace('.', ' . ').split()
                    all_reports.append(toks)
        except Exception as e:
            continue

    return all_reports

In [9]:
# 1-gram model implementation
def build_1gram_model(dicom_ids):
    """Build a 1-gram language model from the reports of the given dicom_ids"""
    word_counts = Counter()

    reports = get_report_tokens(dicom_ids)
    for tokens in reports:
        word_counts.update(tokens)
        # Add END token for each report
        word_counts[END] += 1

    return word_counts

def generate_with_1gram(dicom_id):
    """Generate a report using a 1-gram model"""
    if dicom_id not in neighbors:
        return ""

    # Build model from neighbors
    nn = neighbors[dicom_id]
    word_counts = build_1gram_model(nn)

    if not word_counts:
        return ""

    # Generate report
    generated_toks = []
    current = ""

    while current != END and len(generated_toks) < 100:
        words, counts = zip(*word_counts.items())
        total = sum(counts)
        P = np.array(counts) / total

        next_word = np.random.choice(words, p=P)
        if next_word != END:
            generated_toks.append(next_word)
        current = next_word

    return ' '.join(generated_toks)

In [10]:
# 2-gram model implementation
def build_2gram_model(dicom_ids):
    """Build a 2-gram language model from the reports of the given dicom_ids"""
    LM = defaultdict(Counter)

    reports = get_report_tokens(dicom_ids)
    for tokens in reports:
        # Add START and END tokens
        padded_toks = [START] + tokens + [END]

        # Build 2-gram model
        for i in range(len(padded_toks) - 1):
            context = padded_toks[i]
            target = padded_toks[i+1]
            LM[context][target] += 1

    return LM

def generate_with_2gram(dicom_id):
    """Generate a report using a 2-gram model"""
    if dicom_id not in neighbors:
        return ""

    # Build model from neighbors
    nn = neighbors[dicom_id]
    LM = build_2gram_model(nn)

    if not LM:
        return ""

    # Generate report
    generated_toks = [START]
    current = generated_toks[-1]

    while current != END and len(generated_toks) < 100:
        if current not in LM or not LM[current]:
            break

        words, counts = zip(*LM[current].items())
        total = sum(counts)
        P = np.array(counts) / total

        next_word = np.random.choice(words, p=P)
        generated_toks.append(next_word)
        current = next_word

    # Remove START and END tokens
    generated_toks = [t for t in generated_toks if t != START]
    if generated_toks and generated_toks[-1] == END:
        generated_toks = generated_toks[:-1]

    return ' '.join(generated_toks)

In [11]:
# 3-gram model implementation
def build_3gram_model(dicom_ids):
    """Build a 3-gram language model from the reports of the given dicom_ids"""
    LM = defaultdict(Counter)

    reports = get_report_tokens(dicom_ids)
    for tokens in reports:
        # Add START and END tokens
        padded_toks = [START, START] + tokens + [END]

        # Build 3-gram model
        for i in range(len(padded_toks) - 2):
            context = (padded_toks[i], padded_toks[i+1])
            target = padded_toks[i+2]
            LM[context][target] += 1

    return LM

def generate_with_3gram(dicom_id):
    """Generate a report using a 3-gram model"""
    if dicom_id not in neighbors:
        return ""

    # Build model from neighbors
    nn = neighbors[dicom_id]
    LM = build_3gram_model(nn)

    if not LM:
        return ""

    # Generate report
    generated_toks = [START, START]
    current = generated_toks[-1]

    while current != END and len(generated_toks) < 100:
        context = (generated_toks[-2], generated_toks[-1])
        if context not in LM or not LM[context]:
            break

        words, counts = zip(*LM[context].items())
        total = sum(counts)
        P = np.array(counts) / total

        next_word = np.random.choice(words, p=P)
        generated_toks.append(next_word)
        current = next_word

    # Remove START and END tokens
    generated_toks = [t for t in generated_toks if t != START]
    if generated_toks and generated_toks[-1] == END:
        generated_toks = generated_toks[:-1]

    return ' '.join(generated_toks)

In [12]:
# Generate and save reports for a specific n-gram model
def run_ngram_generation(n):
    """Run n-gram generation and save results"""
    if n not in [1, 2, 3]:
        print(f"Error: n must be 1, 2, or 3, got {n}")
        return

    print(f"Generating reports with {n}-gram model...")

    if n == 1:
        generator_func = generate_with_1gram
    elif n == 2:
        generator_func = generate_with_2gram
    else:  # n == 3
        generator_func = generate_with_3gram

    # Generate reports for all test images
    generated_reports = {}
    for pred_dicom in tqdm.tqdm(test_df.dicom_id.values):
        generated_text = generator_func(pred_dicom)
        if generated_text:
            generated_reports[pred_dicom] = generated_text

    print(f"Generated reports for {len(generated_reports)} test images")

    # Save the generated reports
    pred_file = os.path.join(output_dir, f'{n}-gram.tsv')
    print(f"Saving predictions to {pred_file}")

    with open(pred_file, 'w') as f:
        print('dicom_id\tgenerated', file=f)
        for dicom_id, generated in sorted(generated_reports.items()):
            cleaned_text = generated.replace('\t', ' ')
            print(f'{dicom_id}\t{cleaned_text}', file=f)

    # Display sample reports
    sample_count = min(3, len(generated_reports))
    sample_dicoms = list(generated_reports.keys())[:sample_count]

    for dicom_id in sample_dicoms:
        print(f"\nSample report for {dicom_id}:")
        report_text = generated_reports[dicom_id]

        if len(report_text) > 200:
            print(report_text[:200] + "...")
        else:
            print(report_text)

    return generated_reports

In [13]:
# 1-gram model
print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))
print("Starting 1-gram generation")
run_ngram_generation(1)
print("Finished 1-gram generation")
print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

# 2-gram model
print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))
print("Starting 2-gram generation")
run_ngram_generation(2)
print("Finished 2-gram generation")
print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

# 3-gram model
print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))
print("Starting 3-gram generation")
run_ngram_generation(3)
print("Finished 3-gram generation")
print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

2025-04-17 00:06:27
Starting 1-gram generation
Generating reports with 1-gram model...


100%|██████████| 871/871 [07:29<00:00,  1.94it/s]


Generated reports for 0 test images
Saving predictions to /content/drive/MyDrive/mimic-cxr-project/output/1-gram.tsv
Finished 1-gram generation
2025-04-17 00:13:56
2025-04-17 00:13:56
Starting 2-gram generation
Generating reports with 2-gram model...


100%|██████████| 871/871 [00:24<00:00, 35.74it/s]


Generated reports for 0 test images
Saving predictions to /content/drive/MyDrive/mimic-cxr-project/output/2-gram.tsv
Finished 2-gram generation
2025-04-17 00:14:20
2025-04-17 00:14:20
Starting 3-gram generation
Generating reports with 3-gram model...


100%|██████████| 871/871 [00:27<00:00, 32.13it/s]


Generated reports for 0 test images
Saving predictions to /content/drive/MyDrive/mimic-cxr-project/output/3-gram.tsv
Finished 3-gram generation
2025-04-17 00:14:48
