<a href="https://colab.research.google.com/github/jasleenkaursandhu/Reproducing-chest-xray-report-generation-boag/blob/main/knn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# KNN Model for Feature Extraction and Nearest Neighbor Analysis
# This notebook implements a KNN-based approach for chest X-ray report generation.
# It extracts DenseNet features from images and finds the most similar training images for each test image.

In [2]:
# Import necessary libraries
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import tqdm
from collections import defaultdict
import pickle
!pip install pydicom
import pydicom  # Note: Original used 'dicom' but we'll use 'pydicom' which is the current version
from time import gmtime, strftime

Collecting pydicom
  Downloading pydicom-3.0.1-py3-none-any.whl.metadata (9.4 kB)
Downloading pydicom-3.0.1-py3-none-any.whl (2.4 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.4 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━[0m [32m1.5/2.4 MB[0m [31m44.3 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m37.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pydicom
Successfully installed pydicom-3.0.1


In [3]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
base_path = '/content/drive/MyDrive/mimic-cxr-project'
!mkdir -p {base_path}/data
!mkdir -p {base_path}/output

Mounted at /content/drive


In [4]:
# Import the report parser module
import sys
sys.path.append(f"{base_path}/modules")
from report_parser import parse_report, MIMIC_RE
print("Successfully imported report parser module")

Successfully imported report parser module


In [5]:
# Load train and test data
data_dir = os.path.join(base_path, 'data')
files_path = os.path.join(base_path, 'files')
output_dir = os.path.join(base_path, 'output')
features_dir = os.path.join(base_path, 'features')

train_df = pd.read_csv(os.path.join(data_dir, 'train.tsv'), sep='\t')
test_df = pd.read_csv(os.path.join(data_dir, 'test.tsv'), sep='\t')

print(f"Train data shape: {train_df.shape}")
print(f"Test data shape: {test_df.shape}")

Train data shape: (2243, 3)
Test data shape: (871, 3)


In [6]:
# Map each dicom to its report identifier (study_id in our case)
# Note: Original code used 'rad_id', we'll use 'study_id' based on our data structure
report_id_column = 'study_id'
if report_id_column in train_df.columns:
    rad_lookup = dict(train_df[['dicom_id', report_id_column]].values)
    print(f"Created lookup using {report_id_column}")
else:
    print(f"Warning: {report_id_column} not found in columns: {train_df.columns.tolist()}")
    rad_lookup = {}

print("Sample of lookup dictionary:")
print(dict(list(rad_lookup.items())[:5]))

Created lookup using study_id
Sample of lookup dictionary:
{'f214e3d8-00d5e538-62592a72-37cf2660-0ec426a1': 58033516, 'c6f2ef94-1c46417f-1cdb50d7-bbae3a47-5dd45401': 57744330, 'f7a4e18f-004ac053-81024d1b-568dbb86-b20f308b': 50301279, '500bbbdf-966e91a0-474e045e-81e494ac-7c6124f7': 57457041, 'c0eb8f9c-b404b698-4b47abf9-cea216fd-27bea26f': 59435834}


In [8]:
# Path to neighbors file
neighbors_path = os.path.join(base_path, 'local_output/output', '1nn_neighbors.pkl')

# Check if neighbors file exists
if os.path.exists(neighbors_path):
    print(f"Loading neighbors from {neighbors_path}")
    with open(neighbors_path, 'rb') as f:
        neighbors = pickle.load(f)
    print(f"Loaded neighbors for {len(neighbors)} test images")
else:
    print(f"Neighbors file not found at {neighbors_path}")
    print("Please run the feature extraction and nearest neighbor search first.")

Loading neighbors from /content/drive/MyDrive/mimic-cxr-project/local_output/output/1nn_neighbors.pkl
Loaded neighbors for 1757 test images


In [9]:
# Define the path to the reports directory
files_path = os.path.join(base_path, 'files')

# Generate reports for each test image
generated_reports = {}

for pred_dicom in tqdm.tqdm(test_df.dicom_id):
    # Skip if we don't have neighbors for this test image
    if pred_dicom not in neighbors:
        print(f"Warning: No neighbors for {pred_dicom}")
        continue

    nn = neighbors[pred_dicom]

    found = False
    i = 0  # Start with the closest neighbor

    while not found and i < len(nn):
        nearest_dicom = nn[i]

        # Skip if we don't have a report ID for this training image
        if nearest_dicom not in rad_lookup:
            i += 1
            continue

        nearest_train_report_id = rad_lookup[nearest_dicom]

        # Get corresponding subject_id for the training image
        subject_row = train_df[train_df.dicom_id == nearest_dicom]
        if len(subject_row) == 0:
            i += 1
            continue

        subject_id = subject_row.iloc[0]['subject_id']

        # Construct path to the report
        subject_prefix = f"p{str(subject_id)[:2]}"
        subject_dir = f"p{subject_id}"
        study_dir = f"s{nearest_train_report_id}"
        report_file = f"{study_dir}.txt"
        report_path = os.path.join(files_path, subject_prefix, subject_dir, report_file)

        # Parse the report
        try:
            if os.path.exists(report_path):
                report = parse_report(report_path)

                # If the report has a findings section, use it
                if 'findings' in report:
                    found = True
                    generated_reports[pred_dicom] = report['findings']
        except Exception as e:
            pass

        i += 1

    if not found:
        print(f"Warning: Could not find a valid report for {pred_dicom}")

print(f"Generated reports for {len(generated_reports)}/{len(test_df)} test images")

  0%|          | 0/871 [00:00<?, ?it/s]



  1%|▏         | 12/871 [00:54<1:16:30,  5.34s/it]



  2%|▏         | 19/871 [01:21<1:04:37,  4.55s/it]



  3%|▎         | 29/871 [01:46<54:54,  3.91s/it]



  4%|▍         | 37/871 [02:10<47:26,  3.41s/it]



  5%|▌         | 45/871 [02:58<1:07:52,  4.93s/it]



  6%|▌         | 52/871 [03:23<57:26,  4.21s/it]  



  8%|▊         | 68/871 [03:47<33:23,  2.50s/it]



  8%|▊         | 70/871 [03:48<30:32,  2.29s/it]



  9%|▉         | 80/871 [03:49<18:16,  1.39s/it]



  9%|▉         | 82/871 [03:50<16:28,  1.25s/it]



 10%|▉         | 86/871 [03:51<12:20,  1.06it/s]



 11%|█         | 96/871 [04:12<19:39,  1.52s/it]



 13%|█▎        | 112/871 [04:13<09:13,  1.37it/s]



 14%|█▍        | 124/871 [04:13<06:05,  2.04it/s]



 16%|█▌        | 138/871 [04:15<04:03,  3.01it/s]



 16%|█▌        | 141/871 [04:15<04:01,  3.03it/s]



 18%|█▊        | 158/871 [04:17<02:29,  4.77it/s]



 19%|█▉        | 165/871 [04:18<02:18,  5.12it/s]



 20%|█▉        | 172/871 [04:18<02:05,  5.57it/s]



 21%|██        | 181/871 [04:19<01:46,  6.46it/s]



 21%|██▏       | 186/871 [04:20<01:48,  6.31it/s]



 22%|██▏       | 194/871 [04:21<01:36,  6.99it/s]



 23%|██▎       | 200/871 [04:22<01:34,  7.11it/s]



 24%|██▍       | 207/871 [04:23<01:32,  7.18it/s]



 25%|██▍       | 217/871 [04:24<01:19,  8.26it/s]



 25%|██▌       | 220/871 [04:24<01:22,  7.87it/s]



 26%|██▌       | 225/871 [04:25<01:28,  7.29it/s]



 26%|██▌       | 227/871 [04:26<01:54,  5.61it/s]



 26%|██▋       | 229/871 [04:27<02:21,  4.53it/s]



 27%|██▋       | 234/871 [04:28<02:40,  3.96it/s]



 29%|██▉       | 252/871 [04:29<01:04,  9.61it/s]



 32%|███▏      | 283/871 [04:30<00:33, 17.70it/s]



 33%|███▎      | 291/871 [04:31<00:37, 15.38it/s]



 35%|███▌      | 305/871 [04:33<00:54, 10.36it/s]



 37%|███▋      | 325/871 [04:35<00:52, 10.41it/s]



 38%|███▊      | 328/871 [04:35<01:06,  8.13it/s]



 38%|███▊      | 335/871 [04:37<01:37,  5.47it/s]



 40%|████      | 351/871 [04:39<01:18,  6.62it/s]



 41%|████      | 354/871 [04:40<01:28,  5.85it/s]



 42%|████▏     | 365/871 [04:41<01:01,  8.21it/s]



 42%|████▏     | 367/871 [04:42<01:19,  6.33it/s]



 44%|████▍     | 384/871 [04:43<00:48, 10.05it/s]



 45%|████▌     | 394/871 [04:43<00:46, 10.36it/s]



 48%|████▊     | 414/871 [04:44<00:32, 13.95it/s]



 49%|████▉     | 427/871 [04:45<00:34, 13.05it/s]



 49%|████▉     | 429/871 [04:46<00:43, 10.06it/s]



 49%|████▉     | 431/871 [04:47<00:58,  7.49it/s]



 50%|████▉     | 433/871 [04:48<01:11,  6.15it/s]



 52%|█████▏    | 456/871 [04:49<00:30, 13.52it/s]



 55%|█████▍    | 479/871 [04:50<00:22, 17.76it/s]



 56%|█████▋    | 490/871 [04:50<00:22, 16.64it/s]



 58%|█████▊    | 504/871 [04:52<00:32, 11.35it/s]



 58%|█████▊    | 506/871 [04:53<00:37,  9.82it/s]



 58%|█████▊    | 508/871 [04:53<00:49,  7.35it/s]



 59%|█████▉    | 518/871 [04:54<00:35, 10.03it/s]



 63%|██████▎   | 546/871 [04:55<00:18, 17.77it/s]



 63%|██████▎   | 548/871 [04:56<00:25, 12.63it/s]



 63%|██████▎   | 550/871 [04:57<00:34,  9.26it/s]



 64%|██████▍   | 561/871 [04:57<00:26, 11.76it/s]



 69%|██████▊   | 598/871 [04:58<00:10, 25.87it/s]



 69%|██████▉   | 605/871 [05:00<00:25, 10.32it/s]



 70%|███████   | 610/871 [05:01<00:30,  8.63it/s]



 71%|███████   | 616/871 [05:03<00:44,  5.78it/s]



 73%|███████▎  | 635/871 [05:04<00:22, 10.51it/s]



 73%|███████▎  | 639/871 [05:05<00:25,  8.95it/s]



 75%|███████▌  | 654/871 [05:06<00:19, 11.38it/s]



 77%|███████▋  | 673/871 [05:07<00:13, 14.98it/s]



 78%|███████▊  | 677/871 [05:07<00:15, 12.24it/s]



 80%|███████▉  | 694/871 [05:08<00:10, 17.15it/s]



 84%|████████▎ | 729/871 [05:09<00:05, 24.86it/s]



 84%|████████▍ | 732/871 [05:10<00:10, 13.38it/s]



 86%|████████▌ | 750/871 [05:11<00:07, 15.54it/s]



 88%|████████▊ | 767/871 [05:12<00:05, 18.54it/s]



 88%|████████▊ | 770/871 [05:12<00:06, 16.43it/s]



 89%|████████▊ | 772/871 [05:13<00:08, 11.12it/s]



 89%|████████▉ | 778/871 [05:15<00:12,  7.31it/s]



 90%|█████████ | 788/871 [05:15<00:08, 10.00it/s]



 92%|█████████▏| 799/871 [05:16<00:06, 11.48it/s]



 95%|█████████▌| 829/871 [05:17<00:02, 18.48it/s]



 96%|█████████▌| 835/871 [05:18<00:02, 14.43it/s]



 97%|█████████▋| 842/871 [05:19<00:02, 12.64it/s]



 98%|█████████▊| 854/871 [05:20<00:01, 13.18it/s]



100%|██████████| 871/871 [05:21<00:00,  2.71it/s]

Generated reports for 107/871 test images





In [10]:
# Save the generated reports to a TSV file
print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

pred_dir = os.path.join(base_path, 'local_output/output')
os.makedirs(pred_dir, exist_ok=True)

pred_file = os.path.join(pred_dir, 'knn.tsv')
print(f"Saving predictions to {pred_file}")

with open(pred_file, 'w') as f:
    print('dicom_id\tgenerated', file=f)
    for dicom_id, generated in sorted(generated_reports.items()):
        # Clean up the text (remove any tabs)
        cleaned_text = generated.replace('\t', ' ')
        print(f'{dicom_id}\t{cleaned_text}', file=f)

print(strftime("%Y-%m-%d %H:%M:%S", gmtime()))

2025-04-21 01:11:57
Saving predictions to /content/drive/MyDrive/mimic-cxr-project/local_output/output/knn.tsv
2025-04-21 01:11:57


In [11]:
# Display sample of generated reports
sample_count = min(3, len(generated_reports))
sample_dicoms = list(generated_reports.keys())[:sample_count]

for dicom_id in sample_dicoms:
    print(f"\nSample report for {dicom_id}:")
    report_text = generated_reports[dicom_id]

    # Print preview of the report (first 200 characters)
    if len(report_text) > 200:
        print(report_text[:200] + "...")
    else:
        print(report_text)


Sample report for 6497aa38-6297f53a-fa4befd4-ba1a4958-64aedec6:
tracheostomy tube again noted. ng tube again noted, extending beneath the diaphragm to overlie the stomach. right subclavian picc line tip lies near the svc/ ra junction, similar to prior. cardiomedia...

Sample report for 814416f0-c76e5d61-9a01f44f-4d5ebd41-87097577:
the et tube is no longer visualized. spinal fixation device and ng tube are present. there is complete opacification of the right lower lobe compatible with a combination of infiltrate and volume loss...

Sample report for 001bb54b-a4e0bb99-48a28f4c-9df85f1b-e1606587:
left picc tip terminates in the lower svc, unchanged. right-sided dual-lumen central venous catheter terminates in the right atrium, unchanged. tracheostomy tube tip remains in stable position. the pa...
