<a href="https://colab.research.google.com/github/bhanup6663/chest_x_ray_reporting/blob/main/fasterCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# !unzip /content/drive/MyDrive/resized_images.zip -d /content/

In [None]:
# !pip install bbox_visualizer
# !pip install torchvision
# !pip install pydicom

In [None]:
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from sklearn import model_selection
from sklearn.model_selection import StratifiedGroupKFold
from torch.optim.lr_scheduler import ReduceLROnPlateau

import cv2
from skimage import io, exposure
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import bbox_visualizer as bbv

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
from glob import glob
from skimage import exposure
from collections import defaultdict

import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SequentialSampler
import torchvision

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from albumentations.augmentations.dropout import CoarseDropout
from torchvision.ops import box_iou

import shutil

import warnings

warnings.filterwarnings('ignore')

In [None]:
dataset = pd.read_csv(os.path.join("train1.csv"))
dataset.head()

Unnamed: 0,image_id,class_name,class_id,rad_id,x_min,y_min,x_max,y_max,width,height
0,50a418190bc3fb1ef1633bf9678929b3,No finding,14,R11,,,,,2332.0,2580.0
1,21a10246a5ec7af151081d0cd6d65dc9,No finding,14,R7,,,,,2954.0,3159.0
2,9a5094b2563a1ef3ff50dc5c7ff71345,Cardiomegaly,3,R10,0.332212,0.588613,0.794712,0.783818,2080.0,2336.0
3,051132a778e61a86eb147c7c6f564dfe,Aortic enlargement,0,R10,0.548611,0.257986,0.699219,0.353819,2304.0,2880.0
4,063319de25ce7edb9b1c6b8881290140,No finding,14,R10,,,,,2540.0,3072.0


In [None]:
dataset_new = dataset[dataset.class_name!='No finding'].reset_index(drop=True)

In [None]:
class_brands = {
    0: 'Aortic enlargement',
    1: 'Atelectasis',
    2: 'Calcification',
    3: 'Cardiomegaly',
    4: 'Consolidation',
    5: 'ILD',
    6: 'Infiltration',
    7: 'Lung Opacity',
    8: 'Nodule/Mass',
    9: 'Other lesion',
    10: 'Pleural effusion',
    11: 'Pleural thickening',
    12: 'Pneumothorax',
    13: 'Pulmonary fibrosis'
}

In [None]:
num_classes = 15

In [None]:
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

in_features = model.roi_heads.box_predictor.cls_score.in_features

In [None]:
class FastRCNNPredictorWithDropout(nn.Module):
    def __init__(self, in_features, num_classes, dropout_p=0.3):
        super(FastRCNNPredictorWithDropout, self).__init__()
        self.fc = nn.Linear(in_features, in_features)
        self.dropout = nn.Dropout(p=dropout_p)
        self.cls_score = nn.Linear(in_features, num_classes)
        self.bbox_pred = nn.Linear(in_features, num_classes * 4)

    def forward(self, x):
        x = self.fc(x)
        x = torch.relu(x)
        x = self.dropout(x)  # Add dropout here
        scores = self.cls_score(x)
        bbox_deltas = self.bbox_pred(x)
        return scores, bbox_deltas

In [None]:
model.roi_heads.box_predictor = FastRCNNPredictorWithDropout(in_features, num_classes)

In [None]:
def set_device():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return device

In [None]:
device=set_device()

In [None]:
model.to(device)
params = [p for p in model.parameters() if p.requires_grad]

In [None]:
def get_train_transform():
    return A.Compose([ToTensorV2(),], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})

def get_valid_transform():
    return A.Compose([ToTensorV2(),], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})

In [None]:
class LungsAnnotationDataset(Dataset):
    def __init__(self, dataframe, image_dir, transforms=None):
        super().__init__()
        self.image_ids = dataframe['image_id'].unique()
        self.df = dataframe
        self.image_dir = image_dir
        self.transforms = transforms

    def __getitem__(self, index: int):
        image_id = self.image_ids[index]
        records = self.df[self.df['image_id'] == image_id]

        image = io.imread(f'{self.image_dir}/{image_id}.png')

        # Normalize the image
        image = image / 255.0  # pixel values are in the range [0, 255]
        image = exposure.equalize_hist(image)
        image = image.astype('float32')

        # If the image has 3 channels already (like RGB), no need to stack, else ensure 3 channels
        if image.ndim == 2:  # If the image is grayscale, convert to 3 channels
            image = np.stack([image, image, image], axis=-1)

        # Ensure the image is in the correct (C, H, W) format
        if image.shape[2] == 3:  # Check if image is in (H, W, C)
            image = image.transpose(2, 0, 1)  # Convert from (H, W, C) to (C, H, W)

        # Get bounding boxes and other details
        boxes = records[['x_min', 'y_min', 'x_max', 'y_max']].values

        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        area = torch.as_tensor(area, dtype=torch.float32)

        labels = records.class_id.values + 1
        iscrowd = torch.zeros((records.shape[0],), dtype=torch.int64)

        target = {
            'boxes': torch.tensor(boxes, dtype=torch.float32),
            'labels': torch.tensor(labels, dtype=torch.int64),
            'area': area,
            'iscrowd': iscrowd
        }

        # Apply transformations if available (pass normalized boxes)
        if self.transforms:
            sample = {
                'image': image,
                'bboxes': target['boxes'],
                'labels': labels
            }
            sample = self.transforms(**sample)
            image = sample['image']
            target['boxes'] = torch.tensor(sample['bboxes'], dtype=torch.float32)

        # Denormalize boxes AFTER transformations (if you need pixel coordinates)
        target['boxes'][:, [0, 2]] = target['boxes'][:, [0, 2]] * 512
        target['boxes'][:, [1, 3]] = target['boxes'][:, [1, 3]] * 512

        return image, target


    def __len__(self):
        return len(self.image_ids)

In [None]:
def collate_fn(batch):
    images, targets = zip(*batch)
    images = [image.permute(1, 2, 0) if image.shape[0] != 3 else image for image in images]
    return torch.stack(images), targets


In [None]:
class Averager:
    def __init__(self):
        self.current_total = 0.0
        self.iterations = 0.0

    def send(self, value):
        self.current_total += value
        self.iterations += 1

    @property
    def value(self):
        if self.iterations == 0:
            return 0
        else:
            return 1.0 * self.current_total / self.iterations

    def reset(self):
        self.current_total = 0.0
        self.iterations = 0.0

In [None]:
class EarlyStopping:
    def __init__(self, patience=10, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.delta = delta

    def __call__(self, val_loss, model, fold, epoch):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0
            # Save model checkpoint as the best so far
            checkpoint_path = f"model_fasterRCNN_fold{fold}_best.pth"
            print(f"Validation improved, saving model at epoch {epoch}...")
            torch.save(model.state_dict(), checkpoint_path)
            shutil.copy(checkpoint_path, f'/content/drive/MyDrive/x_ray_models/{checkpoint_path}')


In [None]:
def calculate_iou(pred_boxes, gt_boxes):
    """
    Calculate IoU (Intersection over Union) between predicted and ground truth boxes.

    Args:
        pred_boxes (Tensor): Predicted bounding boxes, shape [num_pred_boxes, 4].
        gt_boxes (Tensor): Ground truth bounding boxes, shape [num_gt_boxes, 4].

    Returns:
        Tensor: IoU scores, shape [num_pred_boxes, num_gt_boxes].
    """
    return box_iou(pred_boxes, gt_boxes)

In [None]:
def calculate_classwise_iou(outputs, targets, num_classes):
    """
    Calculate IoU for each class based on predicted and ground truth boxes.
    """
    classwise_iou = defaultdict(list)

    for output, target in zip(outputs, targets):
        pred_boxes = output['boxes']
        pred_labels = output['labels']
        gt_boxes = target['boxes']
        gt_labels = target['labels']

        for cls_id in range(1, num_classes + 1):  # Assuming class IDs are 1-based
            # Get boxes for the current class (both predicted and ground truth)
            pred_class_boxes = pred_boxes[pred_labels == cls_id]
            gt_class_boxes = gt_boxes[gt_labels == cls_id]

            if len(pred_class_boxes) > 0 and len(gt_class_boxes) > 0:
                # Calculate IoU between predicted and ground truth boxes for this class
                iou = box_iou(pred_class_boxes, gt_class_boxes)
                classwise_iou[cls_id].append(iou.mean().item())  # Store the mean IoU for this class

    return classwise_iou

In [None]:
lr_scheduler = None

num_epochs = 50

In [None]:
def train_model(train_dataset, val_dataset, fold, start_epoch=0, resume=False, checkpoint_path=None):
    # Initialize data loaders
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=12,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_fn,
        prefetch_factor=4
    )

    val_data_loader = DataLoader(
        val_dataset,
        batch_size=3,
        shuffle=False,
        num_workers=4,
        collate_fn=collate_fn,
        prefetch_factor=4
    )

    # Initialize Averager instances for loss tracking
    loss_hist = Averager()
    val_loss_hist = Averager()

    # Initialize optimizer and scheduler
    optimizer = torch.optim.AdamW(params, lr=0.0005, weight_decay=0.0005,betas=(0.9, 0.999), eps=1e-08)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=10, factor=0.5, verbose=True)

    # Early stopping
    early_stopping = EarlyStopping(patience=10)

    # Resume from checkpoint
    if resume and checkpoint_path:
        print(f"Resuming training from checkpoint {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_path)

        model.load_state_dict(checkpoint, strict=False)

    for epoch in range(start_epoch, num_epochs):
        loss_hist.reset()
        model.train()

        for itr, (images, targets) in enumerate(train_data_loader, 1):
            optimizer.zero_grad()
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict = model(images, targets)

            # Calculate the loss
            if isinstance(loss_dict, dict):
                loss_classifier = loss_dict['loss_classifier']
                loss_box_reg = loss_dict['loss_box_reg']
                loss_objectness = loss_dict['loss_objectness']
                loss_rpn_box_reg = loss_dict['loss_rpn_box_reg']

                losses = (loss_objectness +
                          10 * loss_classifier +
                          10 * loss_rpn_box_reg +
                          0.5 * loss_box_reg ** 2)

                loss_value = losses.item()
                loss_hist.send(loss_value)
                losses.backward()
                optimizer.step()

            if itr % 100 == 0:
                print(f"Fold #{fold} Epoch #{epoch+1} Iteration #{itr}/{len(train_data_loader)} loss: {loss_hist.value:.4f}")

        # Validation phase
        model.eval()
        with torch.no_grad():
            iou_hist = defaultdict(Averager)  # Track IoU per class using a dictionary
            class_iterations = defaultdict(int)  # Track the number of instances for each class

            for images, targets in val_data_loader:
                images = [image.to(device) for image in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                outputs = model(images)

                # Compute IoU for each class
                classwise_ious = calculate_classwise_iou(outputs, targets, num_classes)

                # Update IoU history for each class and count the number of instances
                for cls_id, ious in classwise_ious.items():
                    for iou in ious:
                        iou_hist[cls_id].send(iou)
                    class_iterations[cls_id] += len(ious)  # Count how many IoU values are calculated per class

            # After validation, print IoU for each class
            print(f"Validation Stats for Fold #{fold}:")
            class_names = ['Aortic enlargement', 'Atelectasis', 'Calcification', 'Cardiomegaly', 'Consolidation',
                          'ILD', 'Infiltration', 'Lung Opacity', 'No Finding', 'Nodule/Mass', 'Other lesion',
                          'Pleural effusion', 'Pleural thickening', 'Pneumothorax', 'Pulmonary fibrosis']

            for i, class_name in enumerate(class_names, start=1):
                avg_iou = iou_hist[i].value  # Get average IoU for this class
                print(f"{class_name:30} | {avg_iou:8f} | {class_iterations[i]}")

        scheduler.step(avg_iou)

        # Early stopping check
        early_stopping(avg_iou, model, fold, epoch)
        if early_stopping.early_stop:
            print(f"Early stopping at epoch {epoch+1}")
            break

        # Save model checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint_path = f"model_fasterRCNN_fold{fold}_epoch{epoch+1}.pth"
            torch.save(model.state_dict(), checkpoint_path)
            shutil.copy(checkpoint_path, f'/content/drive/MyDrive/x_ray_models/{checkpoint_path}')

    # Save final model state
    final_model_path = f"model_fasterRCNN_fold{fold}_final.pth"
    torch.save(model.state_dict(), final_model_path)
    shutil.copy(final_model_path, f'/content/drive/MyDrive/x_ray_models/{final_model_path}')

In [None]:
DIR_TRAIN = os.path.join( "resized_images")

In [25]:
k = 1
df = dataset_new.sample(frac=1).reset_index(drop=True)
kfold = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)

# Get the target classes (class_id) and the groups (image_id)
X = dataset_new['image_id'].values
y = dataset_new['class_id'].values
groups = dataset_new['image_id'].values

for train_index, val_index in kfold.split(X, y, groups):
    train_df = dataset_new.loc[train_index].reset_index(drop=True)
    valid_df = dataset_new.loc[val_index].reset_index(drop=True)

    train_dataset = LungsAnnotationDataset(train_df, DIR_TRAIN, get_train_transform())
    val_dataset = LungsAnnotationDataset(valid_df, DIR_TRAIN, get_valid_transform())

    # Train and validate the model
    train_model(train_dataset, val_dataset, k, resume=True, start_epoch=31, checkpoint_path="model_fasterRCNN_fold1_final.pth")
    break


Resuming training from checkpoint model_fasterRCNN_fold1_final.pth...
Fold #1 Epoch #32 Iteration #100/294 loss: 1.1210
Fold #1 Epoch #32 Iteration #200/294 loss: 1.2334
Validation Stats for Fold #1:
Aortic enlargement             | 0.600629 | 584
Atelectasis                    | 0.303527 | 42
Calcification                  | 0.302834 | 76
Cardiomegaly                   | 0.568327 | 449
Consolidation                  | 0.353246 | 65
ILD                            | 0.265027 | 73
Infiltration                   | 0.320484 | 118
Lung Opacity                   | 0.232725 | 271
No Finding                     | 0.273553 | 151
Nodule/Mass                    | 0.133426 | 215
Other lesion                   | 0.290234 | 203
Pleural effusion               | 0.166092 | 392
Pleural thickening             | 0.337350 | 17
Pneumothorax                   | 0.185023 | 331
Pulmonary fibrosis             | 0.000000 | 0
Fold #1 Epoch #33 Iteration #100/294 loss: 1.4334
Fold #1 Epoch #33 Iteration #200/294 