In [1]:
!pip install torchxrayvision matplotlib numpy pandas ipywidgets scikit-learn scikit-image seaborn



[TorchXRayVision: A library of chest X-ray datasets and models](https://arxiv.org/pdf/2111.00595)

https://github.com/naitik2314/Chest-X-Ray-Medical-Diagnosis-with-Deep-Learning

## Imports

In [2]:
# General utilities
import os
import numpy as np
import pandas as pd
import tqdm
import urllib.request
import tarfile
from concurrent.futures import ThreadPoolExecutor

# PyTorch and data handling
import torch, torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Image processing
import skimage
import skimage.transform

# Machine learning metrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
    balanced_accuracy_score,
    precision_recall_curve,
    auc,
)

# TorchXRayVision library
import torchxrayvision as xrv

  from tqdm.autonotebook import tqdm


In [3]:
num_workers = max(1, os.cpu_count() - 2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

use_amp = True
scaler = torch.amp.GradScaler(enabled=use_amp)

In [4]:
batch_size = 64
num_samples = 1000
num_epochs = 10

## Paths

In [5]:
images_dir = "/mnt/b/Xray/images/"

path_dataset = "/mnt/b/Xray/images/images/" 
path_train_val_list = "/mnt/b/Xray/train_val_list.txt"
path_test_list = "/mnt/b/Xray/test_list.txt"

os.makedirs(images_dir, exist_ok=True)

path_models = "/mnt/b/Xray/models/"
checkpoint_dir = "/mnt/b/Xray/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

## Classes

## Functions

In [6]:
def download_file(link, folder, idx):
    """Downloads a file from a link to the specified folder."""
    file_name = f'images_{idx+1:03d}.tar.gz'
    file_path = os.path.join(folder, file_name)
    if os.path.exists(file_path):
        print(f"{file_name} already exists, skipping download.")
        return file_path
    try:
        print(f"Downloading {file_name}...")
        urllib.request.urlretrieve(link, file_path)
        print(f"{file_name} downloaded successfully.")
        return file_path
    except Exception as e:
        print(f"Failed to download {file_name}: {e}")
        return None

In [7]:
def extract_file(file_path, folder):
    """Extracts a .tar.gz file to the specified folder."""
    extracted_flag = file_path.replace('.tar.gz', '_extracted.flag')
    if os.path.exists(extracted_flag):
        print(f"{os.path.basename(file_path)} already extracted, skipping.")
        return
    try:
        print(f"Extracting {os.path.basename(file_path)}...")
        with tarfile.open(file_path, 'r:gz') as tar:
            tar.extractall(path=folder)
        with open(extracted_flag, 'w') as f:
            f.write('extracted')
        print(f"{os.path.basename(file_path)} extracted successfully.")
    except Exception as e:
        print(f"Failed to extract {os.path.basename(file_path)}: {e}")

In [8]:
def process_link(idx, link):
    """Handles downloading and extracting a single link."""
    file_path = download_file(link, images_dir, idx)
    if file_path:
        extract_file(file_path, images_dir)

In [9]:
def load_checkpoint(checkpoint_path, model, optimizer, scaler):
    """
    Load the model, optimizer, and scaler states from a checkpoint file.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model to load the state into.
        optimizer (torch.optim.Optimizer): The optimizer to load the state into.
        scaler (torch.cuda.amp.GradScaler): The gradient scaler to load the state into.

    Returns:
        int: The starting epoch to resume training.
        float: The loss at the checkpoint.
    """
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    start_epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Resuming from epoch {start_epoch}, loss: {loss:.4f}")
    return start_epoch, loss


## Download data

In [10]:
links = [
    'https://nihcc.box.com/shared/static/vfk49d74nhbxq3nqjg0900w5nvkorp5c.gz',
    'https://nihcc.box.com/shared/static/i28rlmbvmfjbl8p2n3ril0pptcmcu9d1.gz',
    'https://nihcc.box.com/shared/static/f1t00wrtdk94satdfb9olcolqx20z2jp.gz',
	'https://nihcc.box.com/shared/static/0aowwzs5lhjrceb3qp67ahp0rd1l1etg.gz',
    'https://nihcc.box.com/shared/static/v5e3goj22zr6h8tzualxfsqlqaygfbsn.gz',
	'https://nihcc.box.com/shared/static/asi7ikud9jwnkrnkj99jnpfkjdes7l6l.gz',
	'https://nihcc.box.com/shared/static/jn1b4mw4n6lnh74ovmcjb8y48h8xj07n.gz',
    'https://nihcc.box.com/shared/static/tvpxmn7qyrgl0w8wfh9kqfjskv6nmm1j.gz',
	'https://nihcc.box.com/shared/static/upyy3ml7qdumlgk2rfcvlb9k6gvqq2pj.gz',
	'https://nihcc.box.com/shared/static/l6nilvfa9cg3s28tqv1qc1olm3gnz54p.gz',
	'https://nihcc.box.com/shared/static/hhq8fkdgvcari67vfhs7ppg2w6ni4jze.gz',
	'https://nihcc.box.com/shared/static/ioqwiy20ihqwyr8pf4c24eazhh281pbu.gz'
]

with ThreadPoolExecutor(max_workers=5) as executor:
    executor.map(lambda args: process_link(*args), enumerate(links))

print("Download and extraction complete. Please check the extracted files.")

images_001.tar.gz already exists, skipping download.
images_001.tar.gz already extracted, skipping.
images_002.tar.gz already exists, skipping download.
images_002.tar.gz already extracted, skipping.
images_003.tar.gz already exists, skipping download.
images_004.tar.gz already exists, skipping download.
images_005.tar.gz already exists, skipping download.
images_003.tar.gz already extracted, skipping.
images_005.tar.gz already extracted, skipping.
images_004.tar.gz already extracted, skipping.
images_006.tar.gz already exists, skipping download.
images_010.tar.gz already exists, skipping download.
images_007.tar.gz already exists, skipping download.
images_010.tar.gz already extracted, skipping.
images_008.tar.gz already exists, skipping download.
images_011.tar.gz already exists, skipping download.
images_008.tar.gz already extracted, skipping.
images_009.tar.gz already exists, skipping download.
images_011.tar.gz already extracted, skipping.
images_007.tar.gz already extracted, skip

## Load data

In [11]:
transforms = torchvision.transforms.Compose([
    xrv.datasets.XRayCenterCrop(),
    xrv.datasets.XRayResizer(224)
])

In [12]:
dataset = xrv.datasets.NIH_Dataset(
    imgpath=path_dataset,
    transform=transforms,
    views=["PA", "AP"],
    unique_patients=True, # One image per patient
)
xrv.datasets.relabel_dataset(xrv.datasets.default_pathologies, dataset)

Lung Lesion doesn't exist. Adding nans instead.
Fracture doesn't exist. Adding nans instead.
Lung Opacity doesn't exist. Adding nans instead.
Enlarged Cardiomediastinum doesn't exist. Adding nans instead.


In [13]:
# Subsample dataset to the desired number of samples
if num_samples < len(dataset):
    subset_indices = np.random.choice(len(dataset), num_samples, replace=False)
else:
    subset_indices = np.arange(len(dataset))  # Use the entire dataset if num_samples exceeds its size

# Create a SubsetDataset
dataset = xrv.datasets.SubsetDataset(dataset, subset_indices)

In [14]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    prefetch_factor=num_workers*2
)

In [15]:
# Example: Iterate through a few batches to verify
for batch in dataloader:
    print("Image batch shape:", batch["img"].shape)
    print("Labels batch shape:", batch["lab"].shape)
    break

Image batch shape: torch.Size([64, 1, 224, 224])
Labels batch shape: torch.Size([64, 18])


In [16]:
# sample = dataset[40]
# plt.imshow(sample["img"][0], cmap="Greys_r")
# dict(zip(dataset.pathologies,sample["lab"]))

## Load model

In [17]:
model = xrv.models.DenseNet(weights="nih").to(device)
model.op_threshs = None # Disable calibrated thresholds for the model

dict(zip(model.pathologies,xrv.datasets.default_pathologies))

{'Atelectasis': 'Atelectasis',
 'Consolidation': 'Consolidation',
 'Infiltration': 'Infiltration',
 'Pneumothorax': 'Pneumothorax',
 'Edema': 'Edema',
 'Emphysema': 'Emphysema',
 'Fibrosis': 'Fibrosis',
 'Effusion': 'Effusion',
 'Pneumonia': 'Pneumonia',
 'Pleural_Thickening': 'Pleural_Thickening',
 'Cardiomegaly': 'Cardiomegaly',
 'Nodule': 'Nodule',
 'Mass': 'Mass',
 'Hernia': 'Hernia',
 '': 'Enlarged Cardiomediastinum'}

In [18]:
# Align dataset pathologies to model pathologies
common_pathologies = list(set(dataset.pathologies) & set(model.pathologies))
num_common_pathologies = len(common_pathologies)
print(f"Common Pathologies: {common_pathologies}")

# Map dataset indices to model indices
dataset_to_model_indices = {dataset.pathologies.index(p): model.pathologies.index(p) for p in common_pathologies}
print(f"Dataset to Model Index Mapping: {dataset_to_model_indices}")

Common Pathologies: ['Pleural_Thickening', 'Edema', 'Cardiomegaly', 'Infiltration', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Fibrosis', 'Emphysema', 'Nodule', 'Mass', 'Atelectasis', 'Effusion', 'Hernia']
Dataset to Model Index Mapping: {9: 9, 4: 4, 10: 10, 2: 2, 8: 8, 3: 3, 1: 1, 6: 6, 5: 5, 11: 11, 12: 12, 0: 0, 7: 7, 13: 13}


In [19]:
# Determine the number of output features dynamically
dummy_input = torch.zeros(1, 1, 224, 224)  # Batch size 1, single channel, 224x224 input
if torch.cuda.is_available():
    dummy_input = dummy_input.cuda()

# Get the output shape of the feature extractor
with torch.no_grad():
    num_features = model.features(dummy_input).shape[1]  # The second dimension is the feature size

# Update the classifier to match the number of pathologies
model.classifier = torch.nn.Linear(num_features, num_common_pathologies).to(device)
print(f"Updated classifier to output {num_common_pathologies} pathologies.")


Updated classifier to output 14 pathologies.


In [20]:
# Update classifier to match the number of common pathologies
model.classifier = torch.nn.Linear(num_features, num_common_pathologies).to(device)
print(f"Updated classifier to output {num_common_pathologies} pathologies.")

Updated classifier to output 14 pathologies.


## Optimizers

In [21]:
optimizer = torch.optim.Adam(model.classifier.parameters()) # only train classifier

## Loss

In [22]:
criterion = torch.nn.BCEWithLogitsLoss().to(device)

## Training

In [23]:
num_epochs = 10
model_name = "first_model"
checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_{model_name}.pth")

# Load checkpoint if available
if os.path.exists(checkpoint_path):
    start_epoch, _ = load_checkpoint(checkpoint_path, model, optimizer, scaler)
else:
    start_epoch = 0  # Start from scratch if no checkpoint is found

In [None]:
# Training loop
losses = []  # Store losses for visualization
all_results = []  # Store evaluation metrics

for epoch in range(start_epoch, num_epochs):
    epoch_losses = []
    model.train()

    # ========================
    # Train Model
    # ========================
    for i, batch in enumerate(dataloader):
        
        # Move data to device
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        optimizer.zero_grad()

        # Extract targets for common pathologies
        dataset_indices = list(dataset_to_model_indices.keys())
        model_indices = list(dataset_to_model_indices.values())

        # Ensure targets have the correct data type
        targets = batch["lab"][:, dataset_indices].float()  # Convert to torch.float
        targets_aligned = torch.zeros((targets.size(0), len(model.pathologies)), device=device, dtype=torch.float)
        targets_aligned[:, model_indices] = targets

        # Mixed precision training with autocast
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            outputs = model(batch["img"])
            loss = criterion(outputs[:, model_indices], targets_aligned[:, model_indices])
            
        # Backpropagation
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()

        # Log loss
        epoch_losses.append(loss.item())
    
    # ========================
    # Evaluation Metrics
    # ========================
    model.eval()
    outs, labs = [], []
    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(batch["img"])[:, model_indices].cpu().numpy()
            outs.extend(outputs)
            labs.extend(batch["lab"][:, dataset_indices].cpu().numpy())

    outs = np.array(outs)
    labs = np.array(labs)

    results = []
    for idx, pathology in enumerate(common_pathologies):
        result = {"Pathology": pathology}
        if len(np.unique(labs[:, idx])) > 1:  # Only calculate metrics if labels are not constant
            labels = labs[:, idx].astype(bool)
            preds = outs[:, idx]
            result["AUC"] = roc_auc_score(labels, preds)
            result["Acc"] = accuracy_score(labels, preds > 0.5)
            result["F1"] = f1_score(labels, preds > 0.5)
            result["Precision"] = precision_score(labels, preds > 0.5)
            result["Recall"] = recall_score(labels, preds > 0.5)
            tn, fp, fn, tp = confusion_matrix(labels, preds > 0.5).ravel()
            result["Specificity"] = tn / (tn + fp)
            result["Balanced Accuracy"] = balanced_accuracy_score(labels, preds > 0.5)
            precision, recall, _ = precision_recall_curve(labels, preds)
            result["PR AUC"] = auc(recall, precision)
        results.append(result)

    # Log results for this epoch
    df_results = pd.DataFrame(results)
    print(f"Epoch {epoch + 1} Evaluation Results:")
    display(df_results)

    all_results.append(df_results)

    # Store average loss for the epoch
    avg_loss = np.mean(epoch_losses)
    losses.append(avg_loss)
    print(f"Epoch {epoch + 1} Loss: {avg_loss:.4f}")

    # ========================
    # Visualize
    # ========================
    
    print(f"Batch {epoch}, Loss: {loss.item():.4f}")
    # Plot sample images
    fig, axes = plt.subplots(1, 4, figsize=(16, 4))
    for j in range(4):  # Display 4 images
        if j < batch["img"].size(0):  # Ensure the batch has enough images
            img = batch["img"][j].cpu().numpy().transpose(1, 2, 0)  # Convert to HWC
            img = (img - img.min()) / (img.max() - img.min())  # Normalize
            axes[j].imshow(img, cmap='gray')
            axes[j].axis('off')
            axes[j].set_title(f"Targets: {targets[j].nonzero(as_tuple=True)[0].tolist()}")
    plt.show()

    # Store average loss for the epoch
    avg_loss = np.mean(epoch_losses)
    losses.append(avg_loss)
    print(f"Epoch {epoch + 1} Loss: {avg_loss:.4f}")

    # ========================
    # Logging
    # ========================
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'loss': avg_loss,
    }, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")

# ========================
# Loss Curve
# ========================
plt.figure(figsize=(10, 6))
plt.plot(range(1, num_epochs + 1), losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.grid()
plt.show()

# ========================
# Final Results
# ========================
final_results = pd.concat(all_results, keys=range(1, num_epochs + 1))
final_results.to_csv(os.path.join(checkpoint_dir, "evaluation_results.csv"), index=False)