# X-ray Classification with TorchXRayVision

This notebook implements zero-shot classification on chest X-ray images using pretrained models from TorchXRayVision.

## Setup and Imports

In [18]:
# Colab-specific setup (uncomment if using Colab)
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/Shareddrives/CS231N/chestxray-classification

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/Shareddrives/CS231N/chestxray-classification


In [19]:
# Install dependencies (uncomment if needed)
!pip install torchxrayvision python-dotenv scikit-learn matplotlib tqdm --quiet

In [20]:
# Standard imports
import os
import sys
import logging
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
import torch
import torchxrayvision as xrv
from PIL import Image



In [None]:
# Add the parent directory to path (if running in notebooks/)
import shutil

src = "/content/drive/Shareddrives/CS231N/chestxray-classification"
dst = "/content/chestxray-classification"

shutil.copytree(
    src,
    dst,
    ignore=shutil.ignore_patterns(
        "assignment4", ".git", "__pycache__", "*.ipynb_checkpoints"
    ),
    dirs_exist_ok=True
)
sys.path.insert(0, "/content/chestxray-classification")


In [None]:
# Import custom modules
from utils.config import Config
from utils.data_utils import load_and_prepare_data, get_test_set
from models.inference import run_inference
from models.xray_models_load import load_models
from evaluation.metrics import evaluate_predictions
from evaluation.visualizations import plot_results, plot_roc_curves, plot_pr_curves

In [None]:
# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    stream=sys.stdout,      # Redirect to notebook output
    force=True
)
logger = logging.getLogger("xray_evaluation")

## Configuration

In [None]:
# Check if paths are set correctly
print(f"Data path: {Config.data_path}")
print(f"Image path: {Config.image_path}")
print(f"CSV path: {Config.csv_path}")
print(f"Output path: {Config.output_path}")
print(f"Using device: {Config.device}")
print(f"Target labels: {Config.target_labels}")

Data path: /content/drive/Shareddrives/CS231N/assignment4/cs231n/datasets/nih-chestxray
Image path: /content/drive/Shareddrives/CS231N/assignment4/cs231n/datasets/nih-chestxray/images
CSV path: /content/drive/Shareddrives/CS231N/assignment4/cs231n/datasets/nih-chestxray/Data_Entry_2017_v2020.csv
Output path: results
Using device: cpu
Target labels: ['Cardiomegaly', 'Atelectasis', 'Effusion', 'Pneumothorax']


In [None]:
# Update paths if needed
# Config.data_path = "/path/to/data"
# Config.image_path = os.path.join(Config.data_path, "images")
# Config.csv_path = os.path.join(Config.data_path, "Data_Entry_2017_v2020.csv")

# Create output directory
os.makedirs(Config.output_path, exist_ok=True)

## Load Models

In [None]:
# List available models
print("Available pretrained models:")
for model_name in xrv.models.model_urls.keys():
    print(f"- {model_name}")

Available pretrained models:
- all
- densenet121-res224-all
- nih
- densenet121-res224-nih
- pc
- densenet121-res224-pc
- chex
- densenet121-res224-chex
- rsna
- densenet121-res224-rsna
- mimic_nb
- densenet121-res224-mimic_nb
- mimic_ch
- densenet121-res224-mimic_ch
- resnet50-res512-all


In [None]:
# Load models
model_mimic, model_chex, pathologies = load_models(Config.device)

2025-05-20 17:44:45 - INFO - models.xray_models_load - Loading DenseNet121 pretrained on MIMIC-CXR...
2025-05-20 17:44:45 - INFO - models.xray_models_load - Loading DenseNet121 pretrained on CheXpert...
2025-05-20 17:44:46 - INFO - models.xray_models_load - Models loaded successfully with 18 disease classes


In [None]:
# Display pathologies
print("\nSupported pathologies:")
for i, pathology in enumerate(pathologies):
    print(f"{i}: {pathology}")

# Highlight target pathologies
print("\nTarget pathologies:")
for label in Config.target_labels:
    idx = pathologies.index(label)
    print(f"{idx}: {label}")


Supported pathologies:
0: Atelectasis
1: Consolidation
2: 
3: Pneumothorax
4: Edema
5: 
6: 
7: Effusion
8: Pneumonia
9: 
10: Cardiomegaly
11: 
12: 
13: 
14: Lung Lesion
15: Fracture
16: Lung Opacity
17: Enlarged Cardiomediastinum

Target pathologies:
10: Cardiomegaly
0: Atelectasis
7: Effusion
3: Pneumothorax


## Load and Prepare Data

In [27]:
# Load and prepare data
metadata_df, label_indices = load_and_prepare_data(Config, pathologies)

2025-05-20 17:50:53 - INFO - utils.data_utils - Loading metadata from /content/drive/Shareddrives/CS231N/assignment4/cs231n/datasets/nih-chestxray/Data_Entry_2017_v2020.csv
2025-05-20 17:50:54 - INFO - utils.data_utils - Raw metadata contains 112120 entries
2025-05-20 17:51:01 - INFO - utils.data_utils - Found 54999 PNG images in directory
2025-05-20 17:51:01 - INFO - utils.data_utils - Filtered to 54999 entries with available images


In [31]:
import os
from dotenv import load_dotenv

load_dotenv(".env")  # loads GH_TOKEN into os.environ

# Now safely retrieve the token
token = os.getenv("GH_TOKEN")
username = "your-username"
repo = "chestxray-classification"
push_url = f"https://{username}:{token}@github.com/{username}/{repo}.git"

In [None]:
# Load and prepare test_data set
positive_df = get_test_set(metadata_df,label_indices, Config)

2025-05-20 17:45:08 - INFO - utils.data_utils - Found 42511 cases negative for all target diseases


ValueError: Unable to coerce to Series, length must be 14: given 1

In [None]:
# Display dataset composition
print("\nDataset Composition:")
for label in Config.target_labels:
    count = test_df[label].sum()
    print(f"{label}: {count} positive samples")

# Display negative cases
negative_count = (test_df[Config.target_labels].sum(axis=1) == 0).sum()
print(f"No target diseases: {negative_count} samples")

In [None]:
# Visualize class distribution
plt.figure(figsize=(10, 5))
test_df[Config.target_labels].sum().plot(kind="bar", color="lightgreen")
plt.title("Positive Case Count per Disease")
plt.ylabel("Count")
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

## Visualize Sample Images

In [None]:
# Function to display sample X-ray images
def display_sample_images(test_df, image_path, num_samples=3):
    """Display sample X-ray images for each disease category"""
    fig, axes = plt.subplots(len(Config.target_labels) + 1, num_samples,
                            figsize=(num_samples * 4, (len(Config.target_labels) + 1) * 4))

    # For each disease, show samples
    for i, disease in enumerate(Config.target_labels):
        disease_samples = test_df[test_df[disease] == 1].sample(num_samples)

        for j, (_, row) in enumerate(disease_samples.iterrows()):
            img_file = os.path.join(image_path, row["Image Index"])
            img = Image.open(img_file).convert('L')
            axes[i, j].imshow(img, cmap='gray')
            axes[i, j].set_title(f"{disease}\n{row['Image Index']}")
            axes[i, j].axis('off')

    # Show negative samples (no target disease)
    negative_samples = test_df[(test_df[Config.target_labels] == 0).all(axis=1)].sample(num_samples)

    for j, (_, row) in enumerate(negative_samples.iterrows()):
        img_file = os.path.join(image_path, row["Image Index"])
        img = Image.open(img_file).convert('L')
        axes[-1, j].imshow(img, cmap='gray')
        axes[-1, j].set_title(f"No Target Disease\n{row['Image Index']}")
        axes[-1, j].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Display sample images
display_sample_images(test_df, Config.image_path, num_samples=3)

## Run Inference

In [None]:
# Run inference
y_true, y_pred = run_inference(test_df, model_mimic, model_chex, label_indices, Config)

## Evaluate Results

In [None]:
# Evaluate predictions
results_df = evaluate_predictions(y_true, y_pred, Config.target_labels, Config.output_path)

In [None]:
# Display results dataframe
results_df

## Visualize Results

In [None]:
# Plot summary results
plot_results(results_df, test_df, Config.target_labels, Config.output_path)

In [None]:
# Plot ROC curves
plot_roc_curves(y_true, y_pred, Config.target_labels, Config.output_path)

In [None]:
# Plot Precision-Recall curves
plot_pr_curves(y_true, y_pred, Config.target_labels, Config.output_path)

## Conclusion

Summary of findings and next steps.