In [1]:
# Faster RCNN network to find and classify cells in images with Pytorch


# Imports
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.utils as utils
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
import time
import copy
import random
import math
import cv2
import pandas as pd
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from tqdm import tqdm

In [2]:
# read in the csv file as a pandas dataframe
df_train = pd.read_csv("QC/df_train.csv")
df_valid = pd.read_csv("QC/df_valid.csv")
df_test = pd.read_csv("QC/df_test.csv")

In [3]:
# Define the dataset class where the inputs are:
# df: a dataframe that contains the image id, classification and xmax, xmin, ymin, ymax for the bounding box
# image_dir and transforms
class CellDataset(Dataset):
    def __init__(self, df, image_dir, transforms=None):
        super().__init__()
        self.df = df
        self.image_ids = self.df["image_id"].unique()
        self.image_dir = image_dir
        self.transforms = transforms
        
    def __len__(self):
        return self.image_ids.shape[0]
    
    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        records = self.df[self.df["image_id"] == image_id]

        # get the image
        image_name = image_id + ".png"
        image = Image.open(os.path.join(self.image_dir, image_name)).convert("RGB")
        image = transforms.ToTensor()(image)

        # get the bounding box coordinates
        boxes = records[["xmin", "ymin", "xmax", "ymax"]].values
        boxes = torch.as_tensor(boxes, dtype=torch.int64)


        # get the classification
        labels = torch.tensor(records["super_classification"].values, dtype=torch.int64)
        

        # create a target dictionary
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = torch.tensor([idx])


        # apply the transformations
        if self.transforms is not None:
            image = self.transforms(image)
            
        return image, target, image_id
    


In [4]:
# Define the Faster RCNN model
model = models.detection.fasterrcnn_resnet50_fpn_v2(pretrained=True, progress=True, pretrained_backbone=True, 
                                trainable_backbone_layers=0, image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225])
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=5)

# Train the model for 5 epochs
num_epochs = 5

# Define the optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Define the learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# Define the loss function
loss_func = nn.CrossEntropyLoss()

# Define the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Move the model to the device
model.to(device)

# Define the training and validation datasets
train_dataset = CellDataset(df_train, "QC/rgb/", transforms=transforms.Compose([transforms.ToTensor()]))
valid_dataset = CellDataset(df_valid, "QC/rgb/", transforms=transforms.Compose([transforms.ToTensor()]))
test_dataset = CellDataset(df_test, "QC/rgb/", transforms=transforms.Compose([transforms.ToTensor()]))

# Define the training and validation dataloaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True, num_workers=4)

# Define the training loop
def train_model(model, loss_func, optimizer, lr_scheduler, num_epochs):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs - 1))
        print("-" * 10)

        # Each epoch has a training and validation phase
        for phase in ["train", "valid"]:
            if phase == "train":
                model.train()
            else:
                model.eval()

            running_loss = 0.0


            print("Phase: {}".format(phase))
            # Iterate over the data
            loader = train_loader if phase == "train" else valid_loader
            for images, targets, image_ids in tqdm(loader):
                print("58")
                images = list(image.to(device) for image in images)
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
                print("61")

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == "train"):
                    outputs = model(images, targets)
                    print("69")
                    loss = sum(loss for loss in outputs.values())

                    # backward + optimize only if in training phase
                    if phase == "train":
                        loss.backward()
                        print("75")
                        optimizer.step()

                # statistics
                running_loss += loss.item() * images.size(0)

            if phase == "train":
                lr_scheduler.step()

            epoch_loss = running_loss / len(loader.dataset)

            print("{} Loss: {:.4f}".format(phase, epoch_loss))

            # deep copy the model
            if phase == "valid" and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
    print("Best val loss: {:4f}".format(best_loss))


    # load best model weights
    model.load_state_dict(best_model_wts)

    return model
                

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


In [5]:
train_model(model, loss_func, optimizer, lr_scheduler, num_epochs)

  0%|          | 0/436 [00:00<?, ?it/s]

Epoch 0/4
----------
