<a href="https://colab.research.google.com/github/jasleenkaursandhu/Reproducing-chest-xray-report-generation-boag/blob/main/knn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
# KNN (nearest-neighbor) report generation for chest X-rays
# This notebook implements the retrieval step of a KNN-based approach for chest X-ray report generation.
# It loads precomputed nearest neighbors (the feature extraction and nearest neighbor search must be run first)
# and, for each test image, reuses the findings section of the report of its most similar training image.

In [13]:
# Import necessary libraries
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import tqdm
from collections import defaultdict
import pickle
!pip install pydicom
import pydicom  # Note: Original used 'dicom' but we'll use 'pydicom' which is the current version
from time import gmtime, strftime



In [14]:
# Mount Google Drive and make sure the project's data/output folders exist
from google.colab import drive

drive.mount('/content/drive')

base_path = '/content/drive/MyDrive/mimic-cxr-project'

# Pure-Python equivalent of `mkdir -p` for both working directories
for subfolder in ('data', 'output'):
    os.makedirs(f"{base_path}/{subfolder}", exist_ok=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [15]:
# Make the project's helper modules importable, then load the report parser
import sys

modules_dir = f"{base_path}/modules"
sys.path.append(modules_dir)

from report_parser import MIMIC_RE, parse_report

print("Successfully imported report parser module")

Successfully imported report parser module


In [18]:
# Project directories (all located under base_path)
data_dir, files_path, output_dir, features_dir = (
    os.path.join(base_path, folder)
    for folder in ('data', 'files', 'output', 'features')
)

# The train/test splits are tab-separated files stored side by side
splits_dir = os.path.join(base_path, 'local_output/data')
train_df = pd.read_csv(os.path.join(splits_dir, 'train.tsv'), sep='\t')
test_df = pd.read_csv(os.path.join(splits_dir, 'test.tsv'), sep='\t')

for split_name, split_df in (("Train", train_df), ("Test", test_df)):
    print(f"{split_name} data shape: {split_df.shape}")

Train data shape: (4291, 3)
Test data shape: (1757, 3)


In [19]:
# Build a dicom_id -> report identifier mapping from the training split.
# The original code keyed reports by 'rad_id'; this dataset uses 'study_id'.
report_id_column = 'study_id'

if report_id_column not in train_df.columns:
    # Without the identifier column there is nothing to map; fall back to an empty lookup
    print(f"Warning: {report_id_column} not found in columns: {train_df.columns.tolist()}")
    rad_lookup = {}
else:
    id_pairs = train_df[['dicom_id', report_id_column]].values
    rad_lookup = dict(id_pairs)
    print(f"Created lookup using {report_id_column}")

# Show the first few entries as a sanity check
print("Sample of lookup dictionary:")
first_entries = list(rad_lookup.items())[:5]
print(dict(first_entries))

Created lookup using study_id
Sample of lookup dictionary:
{'7e95cd84-a6e61229-709150ad-10e6ad91-b535ad52': 52201331, '42401e7d-fae7b2ef-87642157-68beaada-014bfcc9': 54545361, '1ef3083e-7ed9110c-e9df3d65-480e18a8-9181ebde': 56347818, '5000f8fd-684ea279-a1e1308e-cfce9b0c-e1eeae50': 58559853, 'f349a7ef-ee518ad2-d5173f92-cbfa71b2-df530a25': 52888009}


In [20]:
# Path to the precomputed neighbors file (test dicom_id -> candidate training dicom_ids)
neighbors_path = os.path.join(base_path, 'local_output/output', '1nn_neighbors.pkl')

# Fail fast when the file is missing. Merely printing a message would leave
# `neighbors` undefined (or stale from an earlier kernel state), so the report
# generation cell would either crash with a NameError or silently use old data.
if not os.path.exists(neighbors_path):
    raise FileNotFoundError(
        f"Neighbors file not found at {neighbors_path}. "
        "Please run the feature extraction and nearest neighbor search first."
    )

print(f"Loading neighbors from {neighbors_path}")
# NOTE: pickle.load can execute arbitrary code; only load files produced by this project.
with open(neighbors_path, 'rb') as f:
    neighbors = pickle.load(f)
print(f"Loaded neighbors for {len(neighbors)} test images")

Loading neighbors from /content/drive/MyDrive/mimic-cxr-project/local_output/output/1nn_neighbors.pkl
Loaded neighbors for 1757 test images


In [21]:
# Define the path to the reports directory
files_path = os.path.join(base_path, 'files')

# dicom_id -> subject_id, built once so the loop below does a dict lookup
# instead of scanning the whole training frame for every candidate neighbor.
# keep='first' reproduces the previous "first matching row" behavior.
subject_lookup = (
    train_df.drop_duplicates(subset='dicom_id', keep='first')
    .set_index('dicom_id')['subject_id']
    .to_dict()
)

# Generate reports for each test image: test dicom_id -> findings text
generated_reports = {}

# (report_path, error) pairs for reports that could not be parsed; these are
# skipped (best effort) but recorded so failures are visible afterwards.
parse_failures = []

for pred_dicom in tqdm.tqdm(test_df.dicom_id):
    # Skip if we don't have neighbors for this test image
    if pred_dicom not in neighbors:
        print(f"Warning: No neighbors for {pred_dicom}")
        continue

    # Walk the candidates in the order they are stored (closest neighbor first)
    # and stop at the first one whose report has a findings section.
    for nearest_dicom in neighbors[pred_dicom]:
        # Skip if we don't have a report ID or subject for this training image
        if nearest_dicom not in rad_lookup or nearest_dicom not in subject_lookup:
            continue

        nearest_train_report_id = rad_lookup[nearest_dicom]
        subject_id = subject_lookup[nearest_dicom]

        # Report layout: <files>/p<first two digits of subject>/p<subject>/s<study>.txt
        report_path = os.path.join(
            files_path,
            f"p{str(subject_id)[:2]}",
            f"p{subject_id}",
            f"s{nearest_train_report_id}.txt",
        )
        if not os.path.exists(report_path):
            continue

        has_findings = False
        try:
            report = parse_report(report_path)
            has_findings = 'findings' in report
            if has_findings:
                generated_reports[pred_dicom] = report['findings']
        except Exception as err:
            # Keep going with the next neighbor, but remember what went wrong
            parse_failures.append((report_path, repr(err)))
            continue

        if has_findings:
            break
    else:
        # No candidate produced a usable report (also covers an empty neighbor list)
        print(f"Warning: Could not find a valid report for {pred_dicom}")

if parse_failures:
    print(f"Warning: {len(parse_failures)} report(s) could not be parsed; see parse_failures")

print(f"Generated reports for {len(generated_reports)}/{len(test_df)} test images")

100%|██████████| 1757/1757 [09:11<00:00,  3.19it/s]

Generated reports for 1757/1757 test images





In [22]:
# Save the generated reports to a TSV file (one row per test image)
print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

pred_dir = os.path.join(base_path, 'local_output/output')
os.makedirs(pred_dir, exist_ok=True)

pred_file = os.path.join(pred_dir, 'knn.tsv')
print(f"Saving predictions to {pred_file}")

# Tabs would add extra columns and line breaks would split a record across
# several rows, so replace each of them with a single space before writing.
tsv_unsafe_chars = str.maketrans({'\t': ' ', '\n': ' ', '\r': ' '})

# Explicit encoding keeps the file identical regardless of the platform default
with open(pred_file, 'w', encoding='utf-8') as f:
    print('dicom_id\tgenerated', file=f)
    # Sorted by dicom_id so the output file is deterministic
    for dicom_id, generated in sorted(generated_reports.items()):
        cleaned_text = generated.translate(tsv_unsafe_chars)
        print(f'{dicom_id}\t{cleaned_text}', file=f)

print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

2025-04-21 03:23:48
Saving predictions to /content/drive/MyDrive/mimic-cxr-project/local_output/output/knn.tsv
2025-04-21 03:23:48


In [23]:
# Preview up to three generated reports, truncated to 200 characters each
sample_count = min(3, len(generated_reports))
sample_dicoms = [key for key in generated_reports][:sample_count]

for dicom_id in sample_dicoms:
    report_text = generated_reports[dicom_id]
    print(f"\nSample report for {dicom_id}:")

    # Long reports are cut off and marked with an ellipsis
    is_long = len(report_text) > 200
    preview = report_text[:200] + "..." if is_long else report_text
    print(preview)


Sample report for 20386a2d-1f7a8868-f12e22ac-0d625d27-4c38c8e2:
cardiac size is normal. peribronchial opacities in the left perihilar region have minimally increased. there is no pneumothorax or pleural effusion.

Sample report for 63100eab-9e8a8d90-392bc822-325de482-69a64e3b:
the patient is intubated. the endotracheal tube terminates about 6 cm above the carina. an orogastric tube terminates in the stomach although the sidehole indicator projects over the distal esophagus....

Sample report for 17269efa-b016a94d-1361e8df-ac428071-d1133672:
cardiac size is normal. the lungs are clear. there is no pneumothorax or pleural effusion.
