## Introduction
We are ready to train the Cardiac Detection Model now!

## Imports:

* torch and torchvision for model and dataloader creation
* pytorch lightning for efficient and easy training implementation
* ModelCheckpoint and TensorBoardLogger for checkpoint saving and logging
* numpy for data loading
* cv2 for drawing rectangles on images
* torchvision.transforms.v2 for the augmentation pipeline (replaces the earlier imgaug version)
* Our CardiacDataset



In [1]:
import torch
import torchvision
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np
import cv2
# import imgaug.augmenters as iaa

from torchvision.tv_tensors import BoundingBoxes
from torchvision.transforms import v2
from dataset import CardiacDataset


We define the normalization statistics and the augmentation pipelines, then create the dataset objects.

In [2]:
# Normalization statistics (mean, std) of the processed images, taken from the
# preprocessing notebook; used by the Normalize transforms and to undo them for display.
mu, std_mg = 0.49, 0.082

In [3]:
# All processed heart-detection data lives under one directory.
_data_dir = "../Data/rsna-pneumonia-detection-challenge/Processed-Heart-Detection/"

train_root_path = f"{_data_dir}train/"
train_subjects = f"{_data_dir}train_subjects_det.npy"
val_root_path = f"{_data_dir}val/"
val_subjects = f"{_data_dir}val_subjects_det.npy"




In [170]:
# train_root_path = "Processed-Heart-Detection/train/"
# train_subjects = "train_subjects.npy"
# val_root_path = "Processed-Heart-Detection/val/"
# val_subjects = "val_subjects.npy"

# train_transforms = iaa.Sequential([
#                                 iaa.GammaContrast(),
#                                 iaa.Affine(
#                                     scale=(0.8, 1.2),
#                                     rotate=(-10, 10),
#                                     translate_px=(-10, 10)
#                                 )
#                             ])

In [5]:
# Training pipeline: random intensity/geometric augmentation first, Normalize LAST.
# RandomAutocontrast rescales a float image to [0, 1]; if it ran after Normalize
# (as before), augmented samples would be squashed back to [0, 1] instead of staying
# on the normalized scale every other sample (and every val sample) uses.
train_transforms = v2.Compose([
    # v2.ToImage(),  # Convert numpy array to tensor
    v2.RandomAutocontrast(),
    # NOTE(review): translate=(0, 0.05) allows no horizontal shift, only vertical —
    # confirm this is intended.
    v2.RandomAffine(degrees=(-10, 10), translate=(0, 0.05), scale=(0.8, 1.2)),  # Data Augmentation
    v2.RandomResizedCrop((224, 224), scale=(0.35, 1)),
    v2.Normalize(mean=[mu], std=[std_mg]),  # Use mean and std from preprocessing notebook
])

# Validation pipeline: normalization only, no random augmentation.
val_transforms = v2.Compose([
    # v2.ToImage(),  # Convert numpy array to tensor
    v2.Normalize(mean=[mu], std=[std_mg]),  # Use mean and std from preprocessing notebook
])



In [6]:
# The validation set must be normalized exactly like the training set; val_transforms
# applies only Normalize (no random augmentation). Passing None here left validation
# inputs un-normalized while the model is trained on normalized inputs.
train_dataset = CardiacDataset("./rsna_heart_detection.csv", train_subjects, train_root_path, train_transforms)
val_dataset = CardiacDataset("./rsna_heart_detection.csv", val_subjects, val_root_path, val_transforms)

In [7]:
# Inspect the first training sample: image tensor and its bounding-box tensor.
img, bbox = train_dataset[0]
print(f"{img.shape} {bbox.shape}")

torch.Size([1, 224, 224]) torch.Size([4])


  img = torch.tensor(img)


In [8]:
# Small batches, single-process loading.
batch_size, num_workers = 4, 0

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,  # reshuffle training samples every epoch
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,  # fixed order for validation
)

In [9]:
# Build the datasets. The validation set gets val_transforms (Normalize only, no
# random augmentation) so its inputs are on the same scale as the training inputs,
# which train_transforms normalizes. With augs=None the model was validated on
# un-normalized images.
train_dataset = CardiacDataset(
    "./rsna_heart_detection.csv",
    train_subjects,
    train_root_path,
    augs=train_transforms)

val_dataset = CardiacDataset(
    "./rsna_heart_detection.csv",
    val_subjects,
    val_root_path,
    augs=val_transforms)

print(f"There are {len(train_dataset)} train images and {len(val_dataset)} val images")

There are 400 train images and 96 val images


Adapt batch size and num_workers according to your computing hardware.

In [10]:
# Display one augmented training sample: (image tensor, bounding box).
sample_idx = 2
train_dataset[sample_idx]

  img = torch.tensor(img)


(tensor([[[ 2.8254,  1.1521,  0.1458,  ..., -1.0988, -0.2859,  0.1934],
          [-1.1464, -3.0578, -4.0627,  ..., -5.2581, -5.1628, -4.6842],
          [-4.1580, -5.0191, -5.4494,  ..., -5.4974, -5.4494, -5.5450],
          ...,
          [-5.3061, -5.3061, -5.3061,  ..., -5.4017, -5.4017, -5.4974],
          [-5.3061, -5.3061, -5.3061,  ..., -5.4017, -5.4017, -5.4494],
          [-5.3061, -5.3061, -5.3061,  ..., -5.4017, -5.4017, -5.4494]]]),
 tensor([ 61,  59, 169, 176]))

In [None]:
# TODO: adapt to your hardware. Larger batches need more memory; more workers
# speed up data loading (see the num_workers warning in the training output).
batch_size = 8
num_workers = 0

loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)


## Model Creation

We use the same architecture as in the classification task, with some small adaptations:

1. 4 outputs: Instead of predicting a binary label we need to estimate the location of the heart (xmin, ymin, xmax, ymax).
2. Loss function: Instead of using a cross entropy loss, we are going to use the L2 loss (Mean Squared Error), as we are dealing with continuous values.

In [None]:
# class CardiacDetectionModel(pl.LightningModule):
#     def __init__(self):
#         super().__init__()
        
#         self.model = torchvision.models.resnet18(pretrained=True)
#         self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#         self.model.fc = torch.nn.Linear(in_features=512 ,out_features=4)
        
#         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
#         self.loss_fn = torch.nn.MSELoss()
        
#     def forward(self, data):
#         return self.model(data)
    
#     def training_step(self, batch, batch_idx):
#         x_ray, label = batch
#         label = label.float()
#         pred = self(x_ray)
#         loss = self.loss_fn(pred, label)
        
#         self.log("Train Loss", loss)
        
#         if batch_idx % 50 == 0:
#             self.log_images(x_ray.cpu(), pred.cpu(), label.cpu(), "Train")
#         return loss
    
#     def validation_step(self, batch, batch_idx):
#         x_ray, label = batch
#         label = label.float()
#         pred = self(x_ray)
#         loss = self.loss_fn(pred, label)
        
#         self.log("Val Loss", loss)
        
#         if batch_idx % 50 == 0:
#             self.log_images(x_ray.cpu(), pred.cpu(), label.cpu(), "Val")
#         return loss
    
#     def log_images(self, x_ray, pred, label, name):
#         results = []
        
#         for i in range(4):
#             coords_labels = label[i]
#             coords_pred = pred[i]
            
#             img = ((x_ray[i] * 0.252)+0.494).numpy()[0]
            
#             x0, y0 = coords_labels[0].int().item(), coords_labels[1].int().item()
#             x1, y1 = coords_labels[2].int().item(), coords_labels[3].int().item()
#             img = cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 0), 2)
            
#             x0, y0 = coords_pred[0].int().item(), coords_pred[1].int().item()
#             x1, y1 = coords_pred[2].int().item(), coords_pred[3].int().item()
#             img = cv2.rectangle(img, (x0, y0), (x1, y1), (1, 1, 1), 2)
            
#             results.append(torch.tensor(img).unsqueeze(0))
        
#         grid = torchvision.utils.make_grid(results, 2)
#         self.logger.experiment.add_image(name, grid, self.global_step)
        
#     def configure_optimizers(self):
#         #Caution! You always need to return a list here (just pack your optimizer into one :))
#         return [self.optimizer]


In [11]:
class CardiacDetectionModel(pl.LightningModule):
    """ResNet-18 that regresses one heart bounding box from a chest X-ray.

    The backbone is adapted to single-channel input, and its classification head
    is replaced by a 4-unit linear layer predicting (xmin, ymin, xmax, ymax).
    Training minimizes the mean squared error against the labelled box.

    Args:
        lr: Learning rate of the Adam optimizer (default 1e-4, as before).
        num_logged_images: Maximum number of samples drawn into each
            prediction-vs-label grid logged to TensorBoard (default 4, as before).
    """

    def __init__(self, lr=1e-4, num_logged_images=4):
        super().__init__()
        # Store the constructor arguments in checkpoints so load_from_checkpoint
        # rebuilds the model with the same settings.
        self.save_hyperparameters()
        self.num_logged_images = num_logged_images

        # ImageNet-pretrained ResNet-18. `weights=` is the current spelling of the
        # deprecated `pretrained=True` flag and loads the same IMAGENET1K_V1 weights.
        self.model = torchvision.models.resnet18(weights="IMAGENET1K_V1")

        # Change conv1 from 3 to 1 input channels. This layer is newly
        # initialized, so it does not carry pretrained weights.
        self.model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

        # Change out_features of the last fully connected layer (called fc in resnet18) from 1000 to 4
        self.model.fc = torch.nn.Linear(in_features=512, out_features=4)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, data):
        """Return the model's 4 box coordinates for each sample in ``data``."""
        return self.model(data)

    def _shared_step(self, batch, batch_idx, stage):
        """Compute and log the MSE loss for one batch; periodically log images.

        ``stage`` is "Train" or "Val"; it prefixes the logged loss key
        ("Train Loss" / "Val Loss", the latter monitored by ModelCheckpoint).
        """
        x_ray, label = batch
        label = label.float()  # MSELoss needs float targets; labels arrive as ints
        pred = self(x_ray)
        loss = self.loss_fn(pred, label)

        self.log(f"{stage} Loss", loss)
        if batch_idx % 50 == 0:
            # detach: images are for logging only, no need to keep the graph alive
            self.log_images(x_ray.cpu(), pred.detach().cpu(), label.cpu(), stage)
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning training hook: returns the loss to back-propagate."""
        return self._shared_step(batch, batch_idx, "Train")

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook: same computation as training, logged as "Val"."""
        return self._shared_step(batch, batch_idx, "Val")

    @staticmethod
    def _draw_box(img, coords, color):
        """Draw the box given by the first 4 entries of ``coords`` onto ``img`` (2 px)."""
        x0, y0, x1, y1 = coords[:4].int().tolist()
        return cv2.rectangle(img, (x0, y0), (x1, y1), color, 2)

    def log_images(self, x_ray, pred, label, name):
        """Log a grid of X-rays with label (black) and prediction (white) boxes.

        Args:
            x_ray: Normalized input batch on the CPU; ``x_ray[i]`` is (1, H, W).
            pred: Predicted boxes, one row of 4 coordinates per sample.
            label: Ground-truth boxes in the same layout as ``pred``.
            name: Prefix for the TensorBoard image tag.
        """
        # Nothing to log to (e.g. Trainer(logger=False)).
        if self.logger is None:
            return

        # Partial batches can hold fewer samples than num_logged_images;
        # the old fixed range(4) raised IndexError there.
        num_images = min(self.num_logged_images, len(x_ray))
        if num_images == 0:
            return

        results = []
        for i in range(num_images):
            # Undo Normalize (module-level mu / std_mg must match the transforms).
            img = ((x_ray[i] * std_mg) + mu).numpy()[0]
            img = self._draw_box(img, label[i], (0, 0, 0))
            img = self._draw_box(img, pred[i], (1, 1, 1))
            results.append(torch.tensor(img).unsqueeze(0))

        grid = torchvision.utils.make_grid(results, 2)
        # add_image is the TensorBoard SummaryWriter API (TensorBoardLogger is used here).
        self.logger.experiment.add_image(f"{name} Prediction vs Label", grid, self.global_step)

    def configure_optimizers(self):
        # Caution! You always need to return a list here (just pack your optimizer into one :))
        return [self.optimizer]



In [12]:
# Create the model object (ImageNet-pretrained ResNet-18 with a 4-output box head)
model = CardiacDetectionModel()  # Instantiate the model



In [None]:
# # Testing

# random_input = torch.randn(1, 1, 224, 224)
# print(random_input.shape)
# output = model(random_input)
# output.shape
# # assert output.shape == torch.Size([1, 4])  # one (xmin, ymin, xmax, ymax) box per input


In [13]:
# Checkpoint callback: keep the 10 checkpoints with the lowest validation loss
# in ./weights. "Val Loss" must match the key logged during validation.
checkpoint_callback = ModelCheckpoint(
    dirpath='./weights',
    monitor='Val Loss',
    mode='min',
    save_top_k=10,
)

Train for at least 50 epochs to get a decent result.
100 epochs lead to great results.

You can train this on a CPU!

In [14]:
# Create the trainer.
# accelerator='auto' lets Lightning pick the available hardware (GPU/MPS/CPU).
# max_epochs=5 is only a quick run; at least 50 epochs are recommended for decent results.
trainer = pl.Trainer(
    accelerator='auto',
    logger=TensorBoardLogger("./logs"),
    log_every_n_steps=1,
    callbacks=[checkpoint_callback],
    max_epochs=5,
)



GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [17]:
# Train the detection model on train_loader, validating on val_loader after each epoch;
# checkpoints are written by checkpoint_callback, logs by the TensorBoardLogger.
trainer.fit(model, train_loader, val_loader)

/opt/anaconda3/envs/pytorchenvAImed/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/mgabr001/Documents/UDEMY/AI-IN-MEDICAL-MATERIALS/05-Detection/weights exists and is not empty.

  | Name    | Type    | Params | Mode 
--------------------------------------------
0 | model   | ResNet  | 11.2 M | train
1 | loss_fn | MSELoss | 0      | train
--------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.689    Total estimated model params size (MB)
69        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/opt/anaconda3/envs/pytorchenvAImed/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
  img = torch.tensor(img)
/opt/anaconda3/envs/pytorchenvAImed/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


## Evaluation

In [18]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches


In [None]:
# Evaluate on the first CUDA GPU if one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Path to a trained checkpoint. For the run above,
# checkpoint_callback.best_model_path points at the best checkpoint in ./weights.
ckpt_path = "weight.ckpt"

# load_from_checkpoint is a classmethod that RETURNS a new model: call it on the
# class and assign the result (Lightning 2.x raises TypeError when it is called on an
# instance). map_location lets a checkpoint saved on GPU/MPS load on this device.
model = CardiacDetectionModel.load_from_checkpoint(ckpt_path, map_location=device)
model.eval();
model.to(device)

Compute predictions for all validation samples.

In [None]:
# Run the model on every validation sample, one at a time, and collect the
# predicted and ground-truth boxes into (num_samples, 4) tensors.
preds = []
labels = []

with torch.no_grad():
    for sample_img, sample_label in val_dataset:
        batch = sample_img.to(device).float()[None]  # add a batch dimension of 1
        preds.append(model(batch)[0].cpu())
        labels.append(sample_label)

preds = torch.stack(preds)
labels = torch.stack(labels)

Compute the mean absolute deviation between predictions and labels for each coordinate.

In [None]:
# Mean absolute error per box coordinate across the validation set.
(preds - labels).abs().mean(dim=0)

Example prediction:

In [None]:
IDX = 60  # Feel free to inspect all validation samples by changing the index
img, label = val_dataset[IDX]
current_pred = preds[IDX]
x_min, y_min, x_max, y_max = current_pred.tolist()

# Show the X-ray with the predicted heart box in red; the label is printed below.
fig, axis = plt.subplots(1, 1)
axis.imshow(img[0], cmap="bone")
heart = patches.Rectangle(
    (x_min, y_min),
    x_max - x_min,
    y_max - y_min,
    linewidth=1,
    edgecolor='r',
    facecolor='none',
)
axis.add_patch(heart)

print(label)

Awesome, looks like we got a working heart detection!