# Train a model for classifying chest X-rays into normal vs pneumonia (PyTorch)

## Import needed libraries

In [1]:
import urllib.request
import os
import tarfile
import pandas as pd
import shutil
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets, models
import matplotlib.pyplot as plt
import numpy as np
import random
from PIL import Image
import torch.nn.functional as F

Pillow version: 9.4.0


## Get the data ready

### Download the images

In [3]:
# Download the image archives listed below (12 gzip-compressed tar files)
# URLs for the archives
links = [
    'https://nihcc.box.com/shared/static/vfk49d74nhbxq3nqjg0900w5nvkorp5c.gz',
    'https://nihcc.box.com/shared/static/i28rlmbvmfjbl8p2n3ril0pptcmcu9d1.gz',
    'https://nihcc.box.com/shared/static/f1t00wrtdk94satdfb9olcolqx20z2jp.gz',
    'https://nihcc.box.com/shared/static/0aowwzs5lhjrceb3qp67ahp0rd1l1etg.gz',
    'https://nihcc.box.com/shared/static/v5e3goj22zr6h8tzualxfsqlqaygfbsn.gz',
    'https://nihcc.box.com/shared/static/asi7ikud9jwnkrnkj99jnpfkjdes7l6l.gz',
    'https://nihcc.box.com/shared/static/jn1b4mw4n6lnh74ovmcjb8y48h8xj07n.gz',
    'https://nihcc.box.com/shared/static/tvpxmn7qyrgl0w8wfh9kqfjskv6nmm1j.gz',
    'https://nihcc.box.com/shared/static/upyy3ml7qdumlgk2rfcvlb9k6gvqq2pj.gz',
    'https://nihcc.box.com/shared/static/l6nilvfa9cg3s28tqv1qc1olm3gnz54p.gz',
    'https://nihcc.box.com/shared/static/hhq8fkdgvcari67vfhs7ppg2w6ni4jze.gz',
    'https://nihcc.box.com/shared/static/ioqwiy20ihqwyr8pf4c24eazhh281pbu.gz',
]

# Folder that receives images_01.tar.gz, images_02.tar.gz, ...
# NOTE(review): this home directory differs from the one used by data_dir in
# the "Define the paths" cell -- confirm which location is intended.
download_dir = '/home/crisrbarreram/Documents/ccir_demo/imgs'

# urlretrieve does not create missing folders, so make sure the target exists.
os.makedirs(download_dir, exist_ok=True)

for idx, link in enumerate(links):
    # Archives are numbered from 1 with two digits.
    fn = '%s/images_%02d.tar.gz' % (download_dir, idx + 1)
    print('downloading'+fn+'...')
    urllib.request.urlretrieve(link, fn)  # download the archive


print("Download complete. Please check the checksums")

downloading/home/crisrbarreram/Documents/ccir_demo/imgs/images_01.tar.gz...
downloading/home/crisrbarreram/Documents/ccir_demo/imgs/images_02.tar.gz...
downloading/home/crisrbarreram/Documents/ccir_demo/imgs/images_03.tar.gz...
downloading/home/crisrbarreram/Documents/ccir_demo/imgs/images_04.tar.gz...
downloading/home/crisrbarreram/Documents/ccir_demo/imgs/images_05.tar.gz...
downloading/home/crisrbarreram/Documents/ccir_demo/imgs/images_06.tar.gz...
downloading/home/crisrbarreram/Documents/ccir_demo/imgs/images_07.tar.gz...
downloading/home/crisrbarreram/Documents/ccir_demo/imgs/images_08.tar.gz...
downloading/home/crisrbarreram/Documents/ccir_demo/imgs/images_09.tar.gz...
downloading/home/crisrbarreram/Documents/ccir_demo/imgs/images_10.tar.gz...
downloading/home/crisrbarreram/Documents/ccir_demo/imgs/images_11.tar.gz...
downloading/home/crisrbarreram/Documents/ccir_demo/imgs/images_12.tar.gz...
Download complete. Please check the checksums


### Define the paths

In [2]:
# Filesystem locations used by the rest of the notebook.
# Everything lives under one project folder; directory paths keep their
# trailing slash.
_project_root = '/home/cbarr23/Documents/ccir_demo'
data_dir = f'{_project_root}/imgs/'                      # downloaded archives
csv_path = f'{_project_root}/Data_Entry_2017_v2020.csv'  # label metadata
output_dir = f'{_project_root}/processed/'               # train/test folder tree
image_dir = f'{data_dir}images/'                         # source images to copy from

### Unpack the files

In [10]:
# Unpack every .tar.gz archive found in data_dir into its own subdirectory,
# named after the archive without the '.tar.gz' suffix.
archive_suffix = '.tar.gz'
tar_files = [f for f in os.listdir(data_dir) if f.endswith(archive_suffix)]
for tar_file in tar_files:
    tar_path = os.path.join(data_dir, tar_file)
    subdir = os.path.join(data_dir, tar_file[:-len(archive_suffix)])
    os.makedirs(subdir, exist_ok=True)
    with tarfile.open(tar_path, 'r:gz') as tar:
        # The archives were fetched from the network, so do not let members
        # escape the target folder (absolute paths, '..', unsafe links).
        # The 'data' extraction filter is only present on interpreters that
        # ship tarfile extraction filters; older ones fall back to the
        # previous behaviour.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=subdir, filter='data')
        else:
            tar.extractall(path=subdir)
    print(f'Unpacked {tar_file} into {subdir}')

Unpacked images_01.tar.gz into /home/cbarr23/Documents/ccir_demo/imgs/images_01
Unpacked images_02.tar.gz into /home/cbarr23/Documents/ccir_demo/imgs/images_02
Unpacked images_03.tar.gz into /home/cbarr23/Documents/ccir_demo/imgs/images_03
Unpacked images_05.tar.gz into /home/cbarr23/Documents/ccir_demo/imgs/images_05
Unpacked images_10.tar.gz into /home/cbarr23/Documents/ccir_demo/imgs/images_10
Unpacked images_09.tar.gz into /home/cbarr23/Documents/ccir_demo/imgs/images_09
Unpacked images_12.tar.gz into /home/cbarr23/Documents/ccir_demo/imgs/images_12
Unpacked images_06.tar.gz into /home/cbarr23/Documents/ccir_demo/imgs/images_06
Unpacked images_07.tar.gz into /home/cbarr23/Documents/ccir_demo/imgs/images_07
Unpacked images_04.tar.gz into /home/cbarr23/Documents/ccir_demo/imgs/images_04
Unpacked images_08.tar.gz into /home/cbarr23/Documents/ccir_demo/imgs/images_08
Unpacked images_11.tar.gz into /home/cbarr23/Documents/ccir_demo/imgs/images_11


In [11]:
# Build the <split>/<class> folder tree that will receive the processed images
for split_name in ('train', 'test'):
    for class_name in ('normal', 'pneumonia'):
        os.makedirs(os.path.join(output_dir, f'{split_name}/{class_name}'), exist_ok=True)

### Load the CSV file used to label images as normal vs pneumonia

In [12]:
# Load CSV
# Read the metadata table stored at csv_path into a DataFrame.
# NOTE(review): the filtering cell further down reads the same file again,
# so this load looks redundant -- confirm before removing either one.
df = pd.read_csv(csv_path)

### Filter for Pneumonia vs Normal

In [47]:
# Raw 'Finding Labels' value -> class/folder name used by this notebook
label_names = {'No Finding': 'normal', 'Pneumonia': 'pneumonia'}

# Read the metadata and keep only the rows whose finding is exactly one of
# the keys above (rows carrying any other label string are dropped).
df = pd.read_csv(csv_path)
keep_mask = df['Finding Labels'].isin(list(label_names))
filtered_df = df[keep_mask].copy()

# Add the class name as a new 'label' column
filtered_df.loc[:, 'label'] = filtered_df['Finding Labels'].map(label_names)


### Split into train and test

In [48]:
# Hold out 20% of the filtered rows for testing. The split is stratified on
# the 'label' column and uses a fixed seed so it is reproducible.
split_options = dict(test_size=0.2, stratify=filtered_df['label'], random_state=42)
train_df, test_df = train_test_split(filtered_df, **split_options)

### Copy images to the corresponding folders

In [49]:
# Function to copy images to their respective directories
def copy_images(df, split):
    """Copy the images listed in ``df`` into ``output_dir/<split>/<label>/``.

    Args:
        df: table with an 'Image Index' column (file name) and a 'label'
            column (destination class folder).
        split: name of the split folder under ``output_dir``.

    Files that do not exist under ``image_dir`` are skipped without any
    message.
    """
    for file_name, class_name in zip(df['Image Index'], df['label']):
        source = os.path.join(image_dir, file_name)
        if not os.path.exists(source):
            continue
        destination = os.path.join(output_dir, split, class_name, file_name)
        shutil.copy(source, destination)

# Populate the train and test folders from their respective splits
for split_name, split_df in (('train', train_df), ('test', test_df)):
    copy_images(split_df, split_name)


print("Images have been successfully filtered and organized.")

Images have been successfully filtered and organized.


## Get the model ready

### Parameters

In [28]:
# Training hyper-parameters
img_height = 224  # images are resized to img_height x img_width
img_width = 224
batch_size = 32   # samples per DataLoader batch
epochs = 10       # passes over the training set

### Generate the data for the DL model training

In [29]:
# Per-channel statistics passed to Normalize (the commonly used ImageNet values)
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]


def _make_transform():
    """Return a fresh resize -> tensor -> normalize pipeline."""
    steps = [
        transforms.Resize((img_height, img_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_norm_mean, std=_norm_std),
    ]
    return transforms.Compose(steps)


# Train and test use the same preprocessing, each with its own Compose object
train_transforms = _make_transform()
test_transforms = _make_transform()

# Datasets: the class of each image is the name of the folder that holds it
train_root = os.path.join(output_dir, 'train')
test_root = os.path.join(output_dir, 'test')
train_dataset = datasets.ImageFolder(root=train_root, transform=train_transforms)
test_dataset = datasets.ImageFolder(root=test_root, transform=test_transforms)

# Loaders: only the training data is shuffled
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


Found 48546 images belonging to 2 classes.
Found 12137 images belonging to 2 classes.


### Model building

In [30]:
# Define the model
class SimpleCNN(nn.Module):
    """Small convolutional network for binary image classification.

    Three stages of 3x3 convolution -> ReLU -> 2x2 max-pooling (each pooling
    halves the spatial size, so height and width shrink by a factor of 8),
    followed by a 512-unit hidden layer with dropout and a single sigmoid
    output. The output is therefore a probability in [0, 1] with shape
    (batch, 1), suitable for ``nn.BCELoss``.

    Args:
        height: input image height in pixels. Defaults to the module-level
            ``img_height``.
        width: input image width in pixels. Defaults to the module-level
            ``img_width``.
    """

    def __init__(self, height=None, width=None):
        super().__init__()
        # Resolve the input size once, so forward() no longer depends on
        # module globals that may change after construction.
        height = img_height if height is None else height
        width = img_width if width is None else width

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Number of features per sample after the three pooling steps.
        self.flat_features = 128 * (height // 8) * (width // 8)
        self.fc1 = nn.Linear(self.flat_features, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 1)

    def forward(self, x):
        """Return the predicted probability for each image in ``x``.

        Args:
            x: batch of 3-channel images shaped (batch, 3, height, width).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].

        Raises:
            ValueError: if the spatial size of ``x`` does not match the size
                the network was built for.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten everything except the batch dimension. A view(-1, N) would
        # silently fold a wrongly sized input into extra batch rows.
        x = torch.flatten(x, 1)
        if x.shape[1] != self.flat_features:
            raise ValueError(
                f'expected {self.flat_features} features per sample after the '
                f'convolutional stages, got {x.shape[1]}; check the input image size'
            )
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.sigmoid(self.fc2(x))
        return x

### Instantiate the model, loss function, and optimizer

In [31]:
# Build the network with its default input size
model = SimpleCNN()

# Binary cross-entropy on the model's sigmoid outputs, optimized with Adam
learning_rate = 0.001
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

### Train the model

In [None]:
# Training loop: one pass over train_loader per epoch, reporting the mean
# batch loss at the end of each epoch.
for epoch in range(epochs):
    model.train()  # enable dropout
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        # Drop only the trailing size-1 dimension: (batch, 1) -> (batch,).
        # A bare squeeze() would also remove the batch dimension when a batch
        # holds a single sample, and BCELoss would then reject the shape
        # mismatch with the labels.
        outputs = model(inputs).squeeze(1)
        # BCELoss expects float targets; the dataset yields integer class ids.
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f'Epoch {epoch + 1}/{epochs}, Loss: {running_loss/len(train_loader):.4f}')


Epoch 1/10


2024-06-24 07:14:22.447530: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype int32
	 [[{{node Placeholder/_0}}]]




2024-06-24 07:25:43.732492: I tensorflow/core/common_runtime/executor.cc:1197] [/device:CPU:0] (DEBUG INFO) Executor start aborting (this does not indicate an error and you can ignore this message): INVALID_ARGUMENT: You must feed a value for placeholder tensor 'Placeholder/_0' with dtype int32
	 [[{{node Placeholder/_0}}]]


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10

### Save the model

In [None]:
# Save the model weights (state dict only, not the full module object)
model_path = '/home/cbarr23/Documents/ccir_demo/model/cxr_classification_epoch_10_model.pth'
# torch.save does not create missing parent folders, so make sure it exists.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)

In [None]:
# Restore the trained weights into the existing model object, then switch it
# to evaluation mode.
checkpoint_path = '/home/cbarr23/Documents/ccir_demo/model/cxr_classification_epoch_10_model.pth'
state_dict = torch.load(checkpoint_path)
model.load_state_dict(state_dict)
model.eval()

# Grad-CAM function
def get_gradcam_heatmap(model, img_tensor, target_layer):
    """Compute a Grad-CAM heatmap for a single image.

    The feature maps produced by ``target_layer`` are weighted by the
    spatially averaged gradient of the predicted score with respect to those
    feature maps, averaged over channels, passed through ReLU and scaled so
    that the maximum is 1.

    Args:
        model: network to explain; it is called as ``model(img_tensor)``.
        img_tensor: input batch holding exactly one image.
        target_layer: module inside ``model`` whose output is a 4-D feature
            map (batch, channels, height, width).

    Returns:
        2-D NumPy array with the spatial size of the target layer's output
        and values in [0, 1]. It is all zeros when no location contributes
        positively.

    Raises:
        ValueError: if ``target_layer`` is not executed during the forward
            pass of ``model``.
    """
    model.zero_grad()
    features = []

    def hook(module, inputs, output):
        # Capture the layer output while it is still attached to the graph.
        features.append(output)

    hook_handle = target_layer.register_forward_hook(hook)
    try:
        # enable_grad makes this work even when the caller is inside no_grad().
        with torch.enable_grad():
            output = model(img_tensor)
            if not features:
                raise ValueError('target_layer was not executed during the forward pass')
            scores = output.reshape(output.shape[0], -1)
            # With a single-output model this is always index 0.
            pred_class = scores.argmax(dim=1).item()
            score = scores[0, pred_class]
            # Grad-CAM needs the gradient w.r.t. the activations, not w.r.t.
            # the layer weights.
            grads = torch.autograd.grad(score, features[0])[0]
    finally:
        # Always detach the hook, even if the forward/backward pass fails.
        hook_handle.remove()

    activations = features[0].detach()
    # One importance weight per channel: mean gradient over the spatial dims.
    pooled_grads = grads.mean(dim=(2, 3), keepdim=True)

    heatmap = F.relu((activations * pooled_grads).mean(dim=1)).squeeze(0)
    peak = heatmap.max()
    if peak > 0:
        # Skip the division for an all-zero map to avoid producing NaNs.
        heatmap = heatmap / peak

    return heatmap.cpu().numpy()

# Function to display Grad-CAM
def display_gradcam(img, heatmap, alpha=0.4,
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show an image with a Grad-CAM heatmap blended on top of it.

    Args:
        img: image tensor shaped (channels, height, width). It is expected to
            have been normalized with ``mean``/``std``; the normalization is
            undone here before display.
        heatmap: 2-D array with values in [0, 1]; it may be smaller than the
            image and is resized to match.
        alpha: weight of the heatmap in the blend.
        mean: per-channel mean used when the image was normalized, or None to
            skip de-normalization.
        std: per-channel standard deviation used when the image was
            normalized, or None to skip de-normalization.
    """
    img = img.detach().cpu()
    if mean is not None and std is not None:
        # Undo Normalize so pixel values are back in [0, 1].
        mean_t = torch.tensor(mean, dtype=img.dtype).view(-1, 1, 1)
        std_t = torch.tensor(std, dtype=img.dtype).view(-1, 1, 1)
        img = img * std_t + mean_t
    # Clamp before the uint8 cast: out-of-range values would wrap around.
    img = img.clamp(0, 1).permute(1, 2, 0).numpy()
    img = np.uint8(255 * img)

    # Bring the heatmap to the image resolution so the two can be blended.
    height, width = img.shape[:2]
    heat = torch.as_tensor(np.nan_to_num(np.asarray(heatmap)), dtype=torch.float32)
    heat = F.interpolate(heat[None, None], size=(height, width),
                         mode='bilinear', align_corners=False)[0, 0]
    heat = heat.clamp(0, 1).numpy()

    heatmap = np.uint8(255 * heat)
    # Map intensities to RGB with the "jet" colormap (alpha channel dropped).
    heatmap = np.uint8(plt.get_cmap("jet")(heatmap)[:, :, :3] * 255)

    # Clip the blend to the valid range so bright pixels do not wrap.
    superimposed_img = np.clip(heatmap * alpha + img, 0, 255)
    superimposed_img = np.uint8(superimposed_img)

    plt.figure(figsize=(10, 10))
    plt.imshow(superimposed_img)
    plt.axis('off')
    plt.show()

# Display random samples from the test set with Grad-CAM heatmaps
num_images = 4

# Classify every test sample once and record its dataset index.
# test_loader is built with shuffle=False, so the running sample count is the
# position of the sample in test_dataset.
correct_indices = []
incorrect_indices = []
model.eval()
with torch.no_grad():  # no gradients needed while collecting predictions
    for inputs, labels in test_loader:
        # (batch, 1) probabilities -> (batch,) hard 0/1 predictions
        predictions = torch.round(model(inputs).squeeze(1))
        for is_correct in (predictions == labels.float()).tolist():
            sample_idx = len(correct_indices) + len(incorrect_indices)
            if is_correct:
                correct_indices.append(sample_idx)
            else:
                incorrect_indices.append(sample_idx)

# Show up to num_images randomly chosen samples from each group
gradcam_groups = (
    ("Correct Predictions with Grad-CAM:", correct_indices,
     "No correct predictions to display."),
    ("Incorrect Predictions with Grad-CAM:", incorrect_indices,
     "No incorrect predictions to display."),
)
for title, indices, empty_message in gradcam_groups:
    if len(indices) > 0:
        print(title)
        for idx in random.sample(indices, min(num_images, len(indices))):
            img_tensor, _ = test_dataset[idx]
            img_tensor = img_tensor.unsqueeze(0)  # add the batch dimension
            heatmap = get_gradcam_heatmap(model, img_tensor, model.conv3)
            display_gradcam(img_tensor.squeeze(), heatmap)
    else:
        print(empty_message)