<a href="https://colab.research.google.com/github/jasleenkaursandhu/Reproducing-chest-xray-report-generation-boag/blob/main/knn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [68]:
# 1-NN Model for Chest X-ray Report Generation
# This notebook uses the single nearest neighbor (1-NN) approach to generate reports for chest X-rays

import numpy as np
import pandas as pd
import os
import tqdm
import pickle
from time import gmtime, strftime
from sklearn.metrics.pairwise import cosine_similarity

In [69]:
# Set up paths
# The project root used to be fixed to one user's home directory, which breaks
# the notebook on any other machine (including Colab). It can now be overridden
# with the MIMIC_CXR_PROJECT_DIR environment variable; when the variable is
# unset or empty the original location is used, so existing setups keep working.
base_path = (
    os.environ.get('MIMIC_CXR_PROJECT_DIR')
    or '/Users/simeon/Documents/DLH/content/mimic-cxr-project'
)
data_dir = os.path.join(base_path, 'data')
files_path = os.path.join(base_path, 'new_files')
output_dir = os.path.join(base_path, 'output')
features_dir = os.path.join(base_path, 'features')
reports_dir = os.path.join(base_path, 'reports')

# Only the output directory is created here; the other directories are inputs
# and are expected to exist already.
os.makedirs(output_dir, exist_ok=True)

# Import the report parser module
import sys
# NOTE(review): report_parser must already be importable (e.g. from the working
# directory). If it is not, the line below looks like the intended fallback
# location - confirm before enabling.
# sys.path.append(f"{base_path}/modules")
from report_parser import parse_report
print("Successfully imported report parser module")

Successfully imported report parser module


In [70]:
# Read the train and test splits; both are tab-separated tables in data_dir
train_df, test_df = (
    pd.read_csv(os.path.join(data_dir, tsv_name), sep='\t')
    for tsv_name in ('train.tsv', 'test.tsv')
)

print(f"Train data shape: {train_df.shape}")
print(f"Test data shape: {test_df.shape}")

# Build the dicom_id -> report identifier lookup.
# The report identifier is the study_id column in this dataset layout.
report_id_column = 'study_id'
if report_id_column not in train_df.columns:
    # Without the identifier column no lookup can be built; continue with an
    # empty mapping so the problem is visible in the printed warning.
    print(f"Warning: {report_id_column} not found in columns: {train_df.columns.tolist()}")
    rad_lookup = {}
else:
    rad_lookup = {
        dicom: report_id
        for dicom, report_id in train_df[['dicom_id', report_id_column]].values
    }
    print(f"Created lookup using {report_id_column}")

# Show the first few entries as a sanity check
print("Sample of lookup dictionary:")
print({dicom: rad_lookup[dicom] for dicom in list(rad_lookup)[:5]})

Train data shape: (4291, 3)
Test data shape: (1757, 3)
Created lookup using study_id
Sample of lookup dictionary:
{'7e95cd84-a6e61229-709150ad-10e6ad91-b535ad52': 52201331, '42401e7d-fae7b2ef-87642157-68beaada-014bfcc9': 54545361, '1ef3083e-7ed9110c-e9df3d65-480e18a8-9181ebde': 56347818, '5000f8fd-684ea279-a1e1308e-cfce9b0c-e1eeae50': 58559853, 'f349a7ef-ee518ad2-d5173f92-cbfa71b2-df530a25': 52888009}


In [71]:
# Location of the cached test-image -> nearest-training-image mapping
neighbors_path = os.path.join(output_dir, '1nn_neighbors.pkl')

if not os.path.exists(neighbors_path):
    # No cache yet: fall back to computing the neighbours from image features
    print(f"1-NN neighbors file not found at {neighbors_path}")
    print("We need to compute the 1-NN neighbors from the extracted features")

    # Pickled feature dictionaries for the train and test splits
    train_features_path = os.path.join(features_dir, 'densenet121_train.pkl')
    test_features_path = os.path.join(features_dir, 'densenet121_test.pkl')

    if not (os.path.exists(train_features_path) and os.path.exists(test_features_path)):
        # Nothing to compute from; leave the mapping empty
        print("Feature files not found. Please run the feature extraction notebook first.")
        neighbors = {}
    else:
        print("Loading DenseNet features")
        with open(train_features_path, 'rb') as fh:
            train_features_dict = pickle.load(fh)
        with open(test_features_path, 'rb') as fh:
            test_features_dict = pickle.load(fh)

        print("Computing 1-NN for each test image")
        neighbors = {}

        # Fix an ordering of the training ids and stack their features in that
        # same order, so a row index in the array maps back to a dicom id
        train_dicom_ids = list(train_features_dict.keys())
        train_features_array = np.array([train_features_dict[dicom_id] for dicom_id in train_dicom_ids])

        for query_dicom, query_feature in tqdm.tqdm(test_features_dict.items()):
            # Cosine similarity of this test image against every training image
            scores = cosine_similarity([query_feature], train_features_array)[0]

            # Keep only the best match, wrapped in a one-element list so the
            # stored structure stays a list of neighbours per test image
            neighbors[query_dicom] = [train_dicom_ids[np.argmax(scores)]]

        # Persist the mapping so later runs can take the cached branch
        with open(neighbors_path, 'wb') as fh:
            pickle.dump(neighbors, fh)
        print(f"Computed and saved 1-NN neighbors for {len(neighbors)} test images")
else:
    # Cached mapping is available: reuse it instead of recomputing
    print(f"Loading 1-NN neighbors from {neighbors_path}")
    with open(neighbors_path, 'rb') as fh:
        neighbors = pickle.load(fh)
    print(f"Loaded 1-NN neighbors for {len(neighbors)} test images")

# Container for the reports produced by the 1-NN approach (test dicom id -> text)
generated_reports = {}

Loading 1-NN neighbors from /Users/simeon/Documents/DLH/content/mimic-cxr-project/output/1nn_neighbors.pkl
Loaded 1-NN neighbors for 1757 test images


In [72]:
# Generate a report for every test image with the 1-NN baseline, save the
# results to a TSV file and print a short preview.

# Start from an empty result set so that re-running this cell on its own never
# keeps entries from a previous run.
generated_reports = {}

# Build the dicom_id -> subject_id lookup once instead of filtering train_df
# for every test image. Keeping the first occurrence of each dicom_id matches
# the previous behaviour of taking the first matching row.
subject_lookup = (
    train_df.drop_duplicates(subset='dicom_id')
    .set_index('dicom_id')['subject_id']
    .to_dict()
)

for pred_dicom in tqdm.tqdm(test_df.dicom_id, desc="Generating reports using 1-NN"):
    # Skip if we don't have neighbors for this test image
    if pred_dicom not in neighbors:
        print(f"Warning: No neighbors for {pred_dicom}")
        continue

    # Get the single nearest neighbor (1-NN approach)
    # As per the paper: "For this baseline, we 'generate' our text by returning the caption of
    # the training image with the largest cosine similarity to the test query image."
    nearest_dicom = neighbors[pred_dicom][0]

    # Skip if we don't have a report ID for this training image
    if nearest_dicom not in rad_lookup:
        print(f"Warning: No report ID for nearest neighbor {nearest_dicom}")
        continue

    nearest_train_report_id = rad_lookup[nearest_dicom]

    # Get corresponding subject_id for the training image
    if nearest_dicom not in subject_lookup:
        print(f"Warning: Cannot find subject_id for {nearest_dicom}")
        continue

    subject_id = subject_lookup[nearest_dicom]

    # Construct path to the report:
    #   <reports_dir>/files/p<first two characters of subject_id>/p<subject_id>/s<study_id>.txt
    subject_prefix = f"p{str(subject_id)[:2]}"
    subject_dir = f"p{subject_id}"
    study_dir = f"s{nearest_train_report_id}"

    report_path = os.path.join(reports_dir, 'files', subject_prefix, subject_dir, f"{study_dir}.txt")

    # Parse the report. Failures are reported and the image is skipped so one
    # bad report does not abort the whole run.
    try:
        if os.path.exists(report_path):
            report = parse_report(report_path)

            # If the report has a findings section, use it
            if 'findings' in report:
                generated_reports[pred_dicom] = report['findings']
            else:
                print(f"Warning: No findings section in report for {nearest_dicom}")
        else:
            print(f"Warning: Report file not found at {report_path}")
    except Exception as e:
        print(f"Error parsing report for {nearest_dicom}: {e}")

print(f"Generated reports for {len(generated_reports)}/{len(test_df)} test images")

# Save the generated reports to a TSV file (timestamps are printed in UTC)
print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

pred_file = os.path.join(output_dir, '1nn_reports.tsv')
print(f"Saving 1-NN generated reports to {pred_file}")

# Write with an explicit encoding so the file does not depend on the platform
# default and can be read back as UTF-8.
with open(pred_file, 'w', encoding='utf-8') as f:
    print('dicom_id\tgenerated', file=f)
    for dicom_id, generated in sorted(generated_reports.items()):
        # Keep each record on a single line with exactly two columns: tabs
        # would add columns and line breaks would split the record.
        cleaned_text = (
            generated.replace('\t', ' ')
            .replace('\r\n', ' ')
            .replace('\r', ' ')
            .replace('\n', ' ')
        )
        print(f'{dicom_id}\t{cleaned_text}', file=f)

print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

# Display sample of generated reports
sample_count = min(3, len(generated_reports))
sample_dicoms = list(generated_reports.keys())[:sample_count]

for dicom_id in sample_dicoms:
    print(f"\nSample report for {dicom_id}:")
    report_text = generated_reports[dicom_id]

    # Print preview of the report (first 200 characters)
    if len(report_text) > 200:
        print(report_text[:200] + "...")
    else:
        print(report_text)

Generating reports using 1-NN: 100%|██████████| 1757/1757 [00:00<00:00, 2762.53it/s]

Generated reports for 1757/1757 test images
2025-04-25 17:13:40
Saving 1-NN generated reports to /Users/simeon/Documents/DLH/content/mimic-cxr-project/output/1nn_reports.tsv
2025-04-25 17:13:40

Sample report for 20386a2d-1f7a8868-f12e22ac-0d625d27-4c38c8e2:
cardiac size is normal. peribronchial opacities in the left perihilar region have minimally increased. there is no pneumothorax or pleural effusion.

Sample report for 63100eab-9e8a8d90-392bc822-325de482-69a64e3b:
the patient is intubated. the endotracheal tube terminates about 6 cm above the carina. an orogastric tube terminates in the stomach although the sidehole indicator projects over the distal esophagus....

Sample report for 17269efa-b016a94d-1361e8df-ac428071-d1133672:
cardiac size is normal. the lungs are clear. there is no pneumothorax or pleural effusion.



