In [1]:
!pip install torchxrayvision matplotlib numpy pandas ipywidgets scikit-learn scikit-image seaborn



[TorchXRayVision: A library of chest X-ray datasets and models](https://arxiv.org/pdf/2111.00595)

Related project: https://github.com/naitik2314/Chest-X-Ray-Medical-Diagnosis-with-Deep-Learning

## Imports

In [None]:
import os
import tarfile
import urllib.request
from concurrent.futures import ThreadPoolExecutor

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import skimage
import skimage.transform
import torch
import torchvision
import torchxrayvision as xrv
from sklearn.model_selection import train_test_split
from tqdm.autonotebook import tqdm


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

use_amp = True
scaler = torch.amp.GradScaler(enabled=use_amp)

## Paths

In [None]:
# Root folder for the NIH chest X-ray data, models and checkpoints.
# Override it with the XRAY_ROOT environment variable instead of editing the notebook.
xray_root = os.environ.get("XRAY_ROOT", "/mnt/b/Xray")

path_dataset = os.path.join(xray_root, "images", "")  # trailing separator, as before
path_train_val_list = os.path.join(xray_root, "train_val_list.txt")
path_test_list = os.path.join(xray_root, "test_list.txt")

os.makedirs(path_dataset, exist_ok=True)

path_models = os.path.join(xray_root, "models", "")
checkpoint_dir = os.path.join(xray_root, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)

## Classes

## Functions

In [5]:
def download_file(link, folder, idx):
    """
    Download one archive from ``link`` into ``folder`` as ``images_<idx+1:03d>.tar.gz``.

    The data is first written to ``<name>.part`` and only renamed to its final name once the
    transfer finished. An interrupted download therefore never leaves a truncated file under
    the final name, which a later run would otherwise skip as "already exists".
    Failures are reported with ``print`` and signalled by returning ``None`` (best-effort).

    Args:
        link (str): URL of the archive.
        folder (str): Destination directory (must already exist).
        idx (int): Zero-based archive index, used to build the file name.

    Returns:
        str | None: Path of the downloaded (or already present) file, or ``None`` on failure.
    """
    file_name = f'images_{idx+1:03d}.tar.gz'
    file_path = os.path.join(folder, file_name)
    if os.path.exists(file_path):
        print(f"{file_name} already exists, skipping download.")
        return file_path
    partial_path = file_path + '.part'
    try:
        print(f"Downloading {file_name}...")
        urllib.request.urlretrieve(link, partial_path)
        # Atomic rename: the final name only ever refers to a complete download.
        os.replace(partial_path, file_path)
        print(f"{file_name} downloaded successfully.")
        return file_path
    except Exception as e:
        print(f"Failed to download {file_name}: {e}")
        # Drop any partial data so the next run starts the download from scratch.
        try:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        except OSError:
            pass
        return None

In [6]:
def extract_file(file_path, folder):
    """
    Extract a .tar.gz archive into ``folder``, skipping archives already extracted.

    Completion is recorded in a marker file next to the archive (the archive path with its
    ``.tar.gz`` suffix replaced by ``_extracted.flag``). The marker is written only after a
    successful extraction, so an interrupted extraction is retried on the next run.
    Errors are reported with ``print`` and not re-raised (best-effort).

    Args:
        file_path (str): Path to the archive.
        folder (str): Directory to extract into.
    """
    archive_name = os.path.basename(file_path)
    # Strip the suffix instead of using str.replace. For a path that does not end in
    # '.tar.gz', str.replace would leave it unchanged, so the "flag" would be the archive
    # itself and the archive would always look already extracted.
    suffix = '.tar.gz'
    stem = file_path[:-len(suffix)] if file_path.endswith(suffix) else file_path
    extracted_flag = stem + '_extracted.flag'
    if os.path.exists(extracted_flag):
        print(f"{archive_name} already extracted, skipping.")
        return
    try:
        print(f"Extracting {archive_name}...")
        with tarfile.open(file_path, 'r:gz') as tar:
            # The archives come from the network. When the running Python provides
            # extraction filters, 'data' refuses members that would land outside `folder`
            # (absolute paths, '..' components, links pointing outside).
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(path=folder, filter='data')
            else:
                tar.extractall(path=folder)
        with open(extracted_flag, 'w') as f:
            f.write('extracted')
        print(f"{archive_name} extracted successfully.")
    except Exception as e:
        print(f"Failed to extract {archive_name}: {e}")

In [7]:
def process_link(idx, link):
    """Fetch archive number ``idx`` from ``link`` into ``path_dataset`` and unpack it there.

    Extraction is attempted only when the download step produced a file path.
    """
    archive = download_file(link, path_dataset, idx)
    if not archive:
        return
    extract_file(archive, path_dataset)

In [None]:
def load_checkpoint(checkpoint_path, model, optimizer, scaler, map_location="cpu"):
    """
    Load the model, optimizer, and scaler states from a checkpoint file.

    The checkpoint must be a dict with the keys 'model_state_dict', 'optimizer_state_dict',
    'scaler_state_dict', 'epoch' and 'loss'.

    Args:
        checkpoint_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model to load the state into.
        optimizer (torch.optim.Optimizer): The optimizer to load the state into.
        scaler (torch.amp.GradScaler): The gradient scaler to load the state into.
        map_location: Forwarded to ``torch.load``. The default ``"cpu"`` lets a checkpoint
            written on a GPU be opened on a CPU-only machine. ``load_state_dict`` then copies
            the values into the model's / optimizer's own tensors on their current device.

    Returns:
        int: The starting epoch to resume training.
        float: The loss at the checkpoint.
    """
    # NOTE(review): recent PyTorch versions default to torch.load(weights_only=True), which
    # rejects numpy scalars. Checkpoints that stored 'loss' as np.float64 may fail to load here.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    start_epoch = int(checkpoint['epoch'])
    loss = float(checkpoint['loss'])
    print(f"Resuming from epoch {start_epoch}, loss: {loss:.4f}")
    return start_epoch, loss


## Download and extract data

In [None]:
# The twelve NIH image archives; archive i is saved as images_{i+1:03d}.tar.gz.
links = [
    'https://nihcc.box.com/shared/static/vfk49d74nhbxq3nqjg0900w5nvkorp5c.gz',
    'https://nihcc.box.com/shared/static/i28rlmbvmfjbl8p2n3ril0pptcmcu9d1.gz',
    'https://nihcc.box.com/shared/static/f1t00wrtdk94satdfb9olcolqx20z2jp.gz',
    'https://nihcc.box.com/shared/static/0aowwzs5lhjrceb3qp67ahp0rd1l1etg.gz',
    'https://nihcc.box.com/shared/static/v5e3goj22zr6h8tzualxfsqlqaygfbsn.gz',
    'https://nihcc.box.com/shared/static/asi7ikud9jwnkrnkj99jnpfkjdes7l6l.gz',
    'https://nihcc.box.com/shared/static/jn1b4mw4n6lnh74ovmcjb8y48h8xj07n.gz',
    'https://nihcc.box.com/shared/static/tvpxmn7qyrgl0w8wfh9kqfjskv6nmm1j.gz',
    'https://nihcc.box.com/shared/static/upyy3ml7qdumlgk2rfcvlb9k6gvqq2pj.gz',
    'https://nihcc.box.com/shared/static/l6nilvfa9cg3s28tqv1qc1olm3gnz54p.gz',
    'https://nihcc.box.com/shared/static/hhq8fkdgvcari67vfhs7ppg2w6ni4jze.gz',
    'https://nihcc.box.com/shared/static/ioqwiy20ihqwyr8pf4c24eazhh281pbu.gz',
]

# Download and extract up to five archives concurrently. The iterator returned by
# executor.map must be consumed: that is where an exception raised inside a worker is
# re-raised. Left unconsumed, such errors are silently dropped.
with ThreadPoolExecutor(max_workers=5) as executor:
    list(executor.map(process_link, range(len(links)), links))

print("Download and extraction complete. Please check the extracted files.")

images_001.tar.gz already exists, skipping download.
images_003.tar.gz already exists, skipping download.
images_002.tar.gz already exists, skipping download.
images_002.tar.gz already extracted, skipping.
images_003.tar.gz already extracted, skipping.
images_004.tar.gz already exists, skipping download.
images_005.tar.gz already exists, skipping download.
images_001.tar.gz already extracted, skipping.
Extracting images_005.tar.gz...
images_006.tar.gz already exists, skipping download.
images_007.tar.gz already exists, skipping download.
images_008.tar.gz already exists, skipping download.
images_004.tar.gz already extracted, skipping.
Extracting images_006.tar.gz...
Extracting images_007.tar.gz...
Extracting images_008.tar.gz...
images_009.tar.gz already exists, skipping download.
Extracting images_009.tar.gz...
images_005.tar.gz extracted successfully.
images_010.tar.gz already exists, skipping download.
Extracting images_010.tar.gz...
images_006.tar.gz extracted successfully.
images

## Load data

In [None]:
# Use XRV transforms to center-crop and resize the images to 224 px
transforms = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)])

# NOTE(review): the NIH archives may unpack into an "images/" sub-folder of path_dataset.
# Confirm that imgpath points at the directory that directly contains the PNG files.
# NOTE(review): path_train_val_list / path_test_list are not used by any visible cell,
# so training currently sees every image, including the official test split.
dataset = xrv.datasets.NIH_Dataset(imgpath=path_dataset, transform=transforms)

# shuffle=True gives every training epoch a fresh batch order.
# Pinned host memory only helps when batches are copied to a CUDA device.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=device.type == "cuda",
    prefetch_factor=4,
)

## Load model

In [None]:
# Load the pre-trained DenseNet and replace its classifier with a fresh single-output head
model = xrv.models.DenseNet(weights="densenet121-res224-all").to(device)
model.op_threshs = None  # prevent pre-trained model calibration
# Only the classifier head is handed to the optimizer, so freeze every pre-trained weight.
# Otherwise backward() still computes gradients for the whole backbone. Those gradients are
# never used, and optimizer.zero_grad() never clears them, so they keep accumulating.
model.requires_grad_(False)
model.classifier = torch.nn.Linear(1024, 1).to(device)  # reinitialize classifier (trainable)

## Optimizer

In [None]:
# Adam with its default hyper-parameters, updating only the classifier head.
optimizer = torch.optim.Adam(params=model.classifier.parameters())

## Loss

In [None]:
# Binary cross-entropy that applies the sigmoid itself, so the model's raw scores go in directly.
criterion = torch.nn.BCEWithLogitsLoss().to(device=device)

## Training

In [None]:
# Resume training from a saved checkpoint, if one exists.
# resume_epoch = None -> use the checkpoint with the highest epoch number in checkpoint_dir;
# set an int (e.g. 5) to resume from checkpoint_epoch_<n>.pth explicitly.
resume_epoch = None

ckpt_prefix, ckpt_suffix = "checkpoint_epoch_", ".pth"
if resume_epoch is None:
    saved_epochs = [
        int(name[len(ckpt_prefix):-len(ckpt_suffix)])
        for name in os.listdir(checkpoint_dir)
        if name.startswith(ckpt_prefix)
        and name.endswith(ckpt_suffix)
        and name[len(ckpt_prefix):-len(ckpt_suffix)].isdigit()
    ]
    resume_epoch = max(saved_epochs) if saved_epochs else None

checkpoint_path = (
    os.path.join(checkpoint_dir, f"{ckpt_prefix}{resume_epoch}{ckpt_suffix}")
    if resume_epoch is not None
    else None
)
if checkpoint_path is not None and os.path.exists(checkpoint_path):
    start_epoch, _ = load_checkpoint(checkpoint_path, model, optimizer, scaler)
else:
    start_epoch = 0  # Start from scratch if no checkpoint is found

In [None]:
# Training loop
num_epochs = 10  # train (or resume) until this many epochs have been completed
target_pathology = "Lung Opacity"  # label predicted by the single-logit classifier head
log_interval = 10  # print the batch loss and show sample images every `log_interval` batches

# NOTE(review): no visible cell reads this; the name looks like a leftover from another
# experiment. Kept in case later cells depend on it.
model_name = "food11_stamp_kernel3_stride2_lossweigthed_maxnormdyn1_512_64-128-256-512-1024"

# Resolve the label column once (not on every batch). Fail early with a readable message
# if the dataset does not provide it.
# NOTE(review): confirm that NIH_Dataset exposes this label.
if target_pathology not in dataset.pathologies:
    raise ValueError(
        f"{target_pathology!r} is not in dataset.pathologies: {list(dataset.pathologies)}"
    )
target_idx = dataset.pathologies.index(target_pathology)

losses = []  # mean training loss of each epoch run in this session
for epoch in range(start_epoch, num_epochs):
    epoch_losses = []

    # ========================
    # Train Model
    # ========================
    for i, batch in enumerate(dataloader):
        # Move tensors to the training device; leave any non-tensor entries untouched.
        batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
        optimizer.zero_grad()

        # Mixed precision training with autocast (float16 is only requested on CUDA).
        with torch.autocast(device_type='cuda', dtype=torch.float16,
                            enabled=use_amp and device.type == 'cuda'):
            outputs = model(batch["img"])
            # `None` keeps a trailing axis so targets match the (batch, 1) head output.
            targets = batch["lab"][:, target_idx, None]
            loss = criterion(outputs, targets)

        # Backpropagation. scaler.step() unscales the gradients itself; an explicit
        # scaler.unscale_() is only needed before gradient clipping.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Log loss
        batch_loss = loss.item()
        epoch_losses.append(batch_loss)

        # ========================
        # Visualize
        # ========================
        if i % log_interval == 0:
            print(f"Batch {i}, Loss: {batch_loss:.4f}")
            n_shown = min(4, batch["img"].size(0))
            fig, axes = plt.subplots(1, 4, figsize=(16, 4))
            for j, ax in enumerate(axes):
                ax.axis('off')
                if j >= n_shown:
                    continue
                img = batch["img"][j].cpu().numpy()  # channel-first, as the original transpose assumed
                # imshow rejects (H, W, 1) arrays, so a single channel is drawn as a 2-D image.
                img = img[0] if img.shape[0] == 1 else img.transpose(1, 2, 0)
                lo, hi = img.min(), img.max()
                if hi > lo:  # avoid 0/0 (NaNs) on a constant image
                    img = (img - lo) / (hi - lo)
                ax.imshow(img, cmap='gray')
                ax.set_title(f"Target: {targets[j].item():.4f}")
            plt.show()
            plt.close(fig)  # release the figure; one is created every `log_interval` batches

    # Store average loss for the epoch. Use a plain float (not np.float64) so the checkpoint
    # stays readable by torch.load's weights_only unpickler.
    avg_loss = float(np.mean(epoch_losses))
    losses.append(avg_loss)
    print(f"Epoch {epoch + 1} Loss: {avg_loss:.4f}")

    # ========================
    # Logging
    # ========================
    checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth")
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'loss': avg_loss,
    }, checkpoint_path)
    print(f"Checkpoint saved: {checkpoint_path}")


# ========================
# Loss Curve
# ========================
# x-axis: the 1-based numbers of the epochs actually trained in this session, so the curve
# stays aligned with `losses` after a resume.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(start_epoch + 1, start_epoch + 1 + len(losses)), losses, marker='o')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Curve')
ax.grid()
plt.show()
