# Region-based convolutional neural networks (R-CNN)

### Overview

Fine-tune a PyTorch R-CNN model to extract text from natural scenes (see figure below).   

Prepare:
  1. validation figures that show machine-predicted bounding boxes vs. the ground-truth boxes
  1. loss-curves that show component-wise R-CNN losses as a function of epoch

![Text regions](attachment:4_images_with_boxes_and_text-2.png)

### Dataset

Train your model using the **Text** MS-COCO dataset.  Text MS-COCO is a subset of the original MS-COCO dataset (2014) with additional text annotations.  Read more: [COCO Text](https://bgshih.github.io/cocotext/)

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Helpers

In [2]:
#@title Imports
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch
from torch.optim import SGD, Adam
import os
import cv2
import numpy as np
from PIL import Image
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
import torch
import shutil
import time
from torch.cuda.amp import autocast, GradScaler

In [3]:
#@title COCO_Text class: Handles the annotation file
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

__author__ = 'andreasveit'
__version__ = '1.1'
# Interface for accessing the COCO-Text dataset.

# COCO-Text is a large dataset designed for text detection and recognition.
# This is a Python API that assists in loading, parsing and visualizing the
# annotations. The format of the COCO-Text annotations is also described on
# the project website http://vision.cornell.edu/se3/coco-text/. In addition to this API, please download both
# the COCO images and annotations.
# This dataset is based on Microsoft COCO. Please visit http://mscoco.org/
# for more information on COCO, including for the image data, object annotatins
# and caption annotations.

# An alternative to using the API is to load the annotations directly
# into Python dictionary:
# with open(annotation_filename) as json_file:
#     coco_text = json.load(json_file)
# Using the API provides additional utility functions.

# The following API functions are defined:
#  COCO_Text  - COCO-Text api class that loads COCO annotations and prepare data structures.
#  getAnnIds  - Get ann ids that satisfy given filter conditions.
#  getImgIds  - Get img ids that satisfy given filter conditions.
#  loadAnns   - Load anns with the specified ids.
#  loadImgs   - Load imgs with the specified ids.
#  showAnns   - Display the specified annotations.
#  loadRes    - Load algorithm results and create API for accessing them.
# Throughout the API "ann"=annotation, "cat"=category, and "img"=image.

# COCO-Text Toolbox.        Version 1.1
# Data and  paper available at:  http://vision.cornell.edu/se3/coco-text/
# Code based on Microsoft COCO Toolbox Version 1.0 by Piotr Dollar and Tsung-Yi Lin
# extended and adapted by Andreas Veit, 2016.
# Licensed under the Simplified BSD License [see bsd.txt]

import json
import datetime
import matplotlib.pyplot as plt
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle, PathPatch
from matplotlib.path import Path
import numpy as np
import copy
import os

class COCO_Text:
    def __init__(self, annotation_file=None):
        """
        Constructor of COCO-Text helper class for reading and visualizing annotations.
        :param annotation_file (str): location of annotation file
        :return:
        """
        # load dataset
        self.dataset = {}
        self.anns = {}
        self.imgToAnns = {}
        self.catToImgs = {}
        self.imgs = {}
        self.cats = {}
        self.val = []
        self.test = []
        self.train = []
        if not annotation_file == None:
            assert os.path.isfile(annotation_file), "file does not exist"
            print('loading annotations into memory...')
            time_t = datetime.datetime.utcnow()
            dataset = json.load(open(annotation_file, 'r'))
            print(datetime.datetime.utcnow() - time_t)
            self.dataset = dataset
            self.createIndex()

    def createIndex(self):
        # create index
        print('creating index...')
        self.imgToAnns = {int(cocoid): self.dataset['imgToAnns'][cocoid] for cocoid in self.dataset['imgToAnns']}
        self.imgs      = {int(cocoid): self.dataset['imgs'][cocoid] for cocoid in self.dataset['imgs']}
        self.anns      = {int(annid): self.dataset['anns'][annid] for annid in self.dataset['anns']}
        self.cats      = self.dataset['cats']
        self.val       = [int(cocoid) for cocoid in self.dataset['imgs'] if self.dataset['imgs'][cocoid]['set'] == 'val']
        self.test      = [int(cocoid) for cocoid in self.dataset['imgs'] if self.dataset['imgs'][cocoid]['set'] == 'test']
        self.train     = [int(cocoid) for cocoid in self.dataset['imgs'] if self.dataset['imgs'][cocoid]['set'] == 'train']
        print('index created!')

    def info(self):
        """
        Print information about the annotation file.
        :return:
        """
        for key, value in self.dataset['info'].items():
            print('%s: %s'%(key, value))

    def filtering(self, filterDict, criteria):
        return [key for key in filterDict if all(criterion(filterDict[key]) for criterion in criteria)]

    def getAnnByCat(self, properties):
        """
        Get ann ids that satisfy given properties
        :param properties (list of tuples of the form [(category type, category)] e.g., [('readability','readable')]
            : get anns for given categories - anns have to satisfy all given property tuples
        :return: ids (int array)       : integer array of ann ids
        """
        return self.filtering(self.anns, [lambda d, x=a, y=b:d[x] == y for (a,b) in properties])

    def getAnnIds(self, imgIds=[], catIds=[], areaRng=[]):
        """
        Get ann ids that satisfy given filter conditions. default skips that filter
        :param imgIds  (int array)     : get anns for given imgs
               catIds  (list of tuples of the form [(category type, category)] e.g., [('readability','readable')]
                : get anns for given cats
               areaRng (float array)   : get anns for given area range (e.g. [0 inf])
        :return: ids (int array)       : integer array of ann ids
        """
        imgIds = imgIds if type(imgIds) == list else [imgIds]
        catIds = catIds if type(catIds) == list else [catIds]

        if len(imgIds) == len(catIds) == len(areaRng) == 0:
            anns = list(self.anns.keys())
        else:
            if not len(imgIds) == 0:
                anns = sum([self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns],[])
            else:
                anns = list(self.anns.keys())
            anns = anns if len(catIds)  == 0 else list(set(anns).intersection(set(self.getAnnByCat(catIds))))
            anns = anns if len(areaRng) == 0 else [ann for ann in anns if self.anns[ann]['area'] > areaRng[0] and self.anns[ann]['area'] < areaRng[1]]
        return anns

    def getImgIds(self, imgIds=[], catIds=[]):
        '''
        Get img ids that satisfy given filter conditions.
        :param imgIds (int array) : get imgs for given ids
        :param catIds (int array) : get imgs with all given cats
        :return: ids (int array)  : integer array of img ids
        '''
        imgIds = imgIds if type(imgIds) == list else [imgIds]
        catIds = catIds if type(catIds) == list else [catIds]

        if len(imgIds) == len(catIds) == 0:
            ids = list(self.imgs.keys())
        else:
            ids = set(imgIds)
            if not len(catIds) == 0:
                ids  = ids.intersection(set([self.anns[annid]['image_id'] for annid in self.getAnnByCat(catIds)]))
        return list(ids)

    def loadAnns(self, ids=[]):
        """
        Load anns with the specified ids.
        :param ids (int array)       : integer ids specifying anns
        :return: anns (object array) : loaded ann objects
        """
        if type(ids) == list:
            return [self.anns[id] for id in ids]
        elif type(ids) == int:
            return [self.anns[ids]]

    def loadImgs(self, ids=[]):
        """
        Load anns with the specified ids.
        :param ids (int array)       : integer ids specifying img
        :return: imgs (object array) : loaded img objects
        """
        if type(ids) == list:
            return [self.imgs[id] for id in ids]
        elif type(ids) == int:
            return [self.imgs[ids]]

    def showAnns(self, anns, show_polygon=False):
        """
        Display the specified annotations.
        :param anns (array of object): annotations to display
        :return: None
        """
        if len(anns) == 0:
            return 0
        ax = plt.gca()
        boxes = []
        color = []
        for ann in anns:
            c = np.random.random((1, 3)).tolist()[0]
            if show_polygon:
                tl_x, tl_y, tr_x, tr_y, br_x, br_y, bl_x, bl_y = ann['polygon']
                verts = [(tl_x, tl_y), (tr_x, tr_y), (br_x, br_y), (bl_x, bl_y), (0, 0)]
                codes = [Path.MOVETO, Path.LINETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY]
                path = Path(verts, codes)
                patch = PathPatch(path, facecolor='none')
                boxes.append(patch)
                left, top = tl_x, tl_y
            else:
                left, top, width, height = ann['bbox']
                boxes.append(Rectangle([left,top],width,height,alpha=0.4))
            color.append(c)
            if 'utf8_string' in list(ann.keys()):
                ax.annotate(ann['utf8_string'],(left,top-4),color=c)
        p = PatchCollection(boxes, facecolors=color, edgecolors=(0,0,0,1), linewidths=3, alpha=0.4)
        ax.add_collection(p)

    def loadRes(self, resFile):
        """
        Load result file and return a result api object.
        :param   resFile (str)     : file name of result file
        :return: res (obj)         : result api object
        """
        res = COCO_Text()
        res.dataset['imgs'] = [img for img in self.dataset['imgs']]

        print('Loading and preparing results...     ')
        time_t = datetime.datetime.utcnow()
        if type(resFile) == str:
            anns = json.load(open(resFile))
        else:
            anns = resFile
        assert type(anns) == list, 'results in not an array of objects'
        annsImgIds = [int(ann['image_id']) for ann in anns]

        if set(annsImgIds) != (set(annsImgIds) & set(self.getImgIds())):
            print('Results do not correspond to current coco set')
            print('skipping ', str(len(set(annsImgIds)) - len(set(annsImgIds) & set(self.getImgIds()))), ' images')
        annsImgIds = list(set(annsImgIds) & set(self.getImgIds()))

        res.imgToAnns = {cocoid : [] for cocoid in annsImgIds}
        res.imgs = {cocoid: self.imgs[cocoid] for cocoid in annsImgIds}

        assert anns[0]['bbox'] != [], 'results have incorrect format'
        for id, ann in enumerate(anns):
            if ann['image_id'] not in annsImgIds:
                continue
            bb = ann['bbox']
            ann['area'] = bb[2]*bb[3]
            ann['id'] = id
            res.anns[id] = ann
            res.imgToAnns[ann['image_id']].append(id)
        print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()))

        return res

### CocoDataset class

In [4]:
#@title Dataset class: use images and annotation file to implement a dataset that can used with Dataloader

class CocoDataset(Dataset):
    def __init__(self, root_dir, annFile, transform=None, cuda=True):
        self.root_dir = root_dir
        self.transform = transform
        self.imgs = os.listdir(root_dir)
        # annotations
        self.ct = COCO_Text(annFile)
        self.imgIds = self.ct.getImgIds(imgIds=self.ct.train,
                    catIds=[('legibility','legible'),('class','machine printed')])

        for imgId in self.imgIds:
            file_name = self.ct.loadImgs(imgId)[0]['file_name']
            if file_name not in self.imgs:
                self.imgIds.remove(imgId)
        # manual exclude
        self.imgIds.remove(275939)
        self.imgIds.remove(443671)

        # remaining images
        print(f"remaining images in ann file: {len(self.imgIds)}, remaining images in folder: {len(self.imgs)}")

        self.imgIds.sort()
        # sort the images in same order as the annotations
        self.imgs = [self.ct.loadImgs(imgId)[0]['file_name'] for imgId in self.imgIds]

        self.img_h = 224
        self.img_w = 224
        self.cuda = cuda


    def __len__(self):
        return len(self.imgs)


    def __getitem__(self, idx):
        # 1. Get the image ID
        img_id = self.imgIds[idx]

        # 2. Load image file
        img_info = self.ct.loadImgs(img_id)[0]
        img_path = os.path.join(self.root_dir, img_info['file_name'])
        image = Image.open(img_path).convert("RGB")  # Ensure image is in RGB format

        # Get the original image size before any resizing
        original_width, original_height = image.size

        # 3. Get annotation IDs for the image
        ann_ids = self.ct.getAnnIds(imgIds=[img_id])
        anns = self.ct.loadAnns(ann_ids)

        # 4. Initialize lists to hold bounding boxes and labels
        boxes = []
        labels = []

        # 5. Process each annotation
        for ann in anns:
            # Get bbox (in COCO format: [x_min, y_min, width, height])
            bbox = ann['bbox']
            xmin, ymin, width, height = bbox
            xmax = xmin + width
            ymax = ymin + height

            # Append the bbox in [xmin, ymin, xmax, ymax] format
            boxes.append([xmin, ymin, xmax, ymax])

            # For text detection, all annotations are labeled as "text" (class 1)
            label = 1  # 1 represents "text" class
            labels.append(label)

        # 6. Convert to torch tensors
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)  # Convert labels to integer tensor

        # 7. Create the target dictionary
        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([img_id])
        }

        # 8. Apply transformations if any (e.g., data augmentation, resizing)
        if self.transform:
            # Apply the transformations to the image
            image = self.transform(image)

            # If resizing, we need to adjust bounding boxes based on new image size
            # After transformation, image is now in tensor format with [C, H, W]
            _, new_height, new_width = image.shape

            # Calculate scaling factors based on new size
            scale_x = new_width / original_width
            scale_y = new_height / original_height

            # Adjust the bounding boxes based on scaling
            boxes = target['boxes']
            boxes[:, [0, 2]] *= scale_x  # Scale xmin and xmax
            boxes[:, [1, 3]] *= scale_y  # Scale ymin and ymax

            # Update the target dictionary with scaled bounding boxes
            target['boxes'] = boxes

        return image, target



# coalate_fn is used to collate the data into batches
def collate_fn(batch):
    images = []
    targets = []
    for item in batch:
        images.append(item[0])
        targets.append(item[1])
    images = torch.stack(images, 0)
    return images, targets

# Evaluate / region visualization

Load four images and predict regions using model.
Create a figure that shows the images as subplots.
Overlay model predicted regions (blue) and target regions (red).
Save the output as an image -- use epoch in filename to preserve order.

In [5]:
# def evaluate(model, dataloader, device, epoch):
#     # TODO: create four-panel subplots showing bounding boxes

#     raise NotImplementedError("evaluate()")

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch

def denormalize(image, mean, std):
    """Denormalize an image tensor that has been normalized using mean and std."""
    image = image.clone().cpu()
    for t, m, s in zip(image, mean, std):
        t.mul_(s).add_(m)  # Reverse the normalization: x' = (x * std) + mean
    return image

def evaluate(model, dataloader, device, epoch):
    """
    Evaluates the model on four images and visualizes the predicted and target bounding boxes.

    Args:
        model: The Faster R-CNN model.
        dataloader: The DataLoader for the dataset.
        device: The device to run the model on (e.g., 'cuda' or 'cpu').
        epoch: The current epoch, used to save output with appropriate filename.
    """

    # Set model to evaluation mode
    model.eval()

    # Select four images from the dataloader
    images, targets = next(iter(dataloader))

    # Move images and model to the specified device
    images = [img.to(device) for img in images]
    model.to(device)

    # Perform inference using the pre-trained model (no gradients required)
    with torch.no_grad():
        predictions = model(images)

    # Mean and std used for normalization (ImageNet values)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Create a figure for the 4 subplots
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))

    for i in range(4):
        # Denormalize the image for visualization
        img = denormalize(images[i], mean, std).permute(1, 2, 0).numpy()  # [C, H, W] -> [H, W, C]
        img = (img * 255).astype('uint8')  # Convert to uint8 for proper display

        # Display image
        axes[i].imshow(img)

        # Plot target (ground truth) boxes in red - these boxes should NOT be normalized
        for box in targets[i]['boxes']:
            xmin, ymin, xmax, ymax = box.cpu().numpy()
            rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                     linewidth=2, edgecolor='r', facecolor='none')
            axes[i].add_patch(rect)

        # Plot predicted boxes in blue - these are also based on the original pixel values
        for box in predictions[i]['boxes']:
            xmin, ymin, xmax, ymax = box.cpu().numpy()
            rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                     linewidth=2, edgecolor='b', facecolor='none')
            axes[i].add_patch(rect)

        # Remove axis labels for cleaner visualization
        axes[i].axis('off')

    # Save the figure as an image
    plt.tight_layout()
    plt.savefig(f'/content/drive/MyDrive/track_image/evaluation_epoch_{epoch}.png')
    plt.show()



# Model factory

[Example](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)

In [6]:
# def get_model():
#     # TODO: create and return a PyTorch model

#     raise NotImplementedError("get_model()")

def get_model(num_classes=2):
    """
    Load a pretrained Faster R-CNN model and modify it for the COCO-Text dataset.

    Args:
        num_classes (int): The number of classes (including background) in your dataset.

    Returns:
        model: A modified Faster R-CNN model ready for training on your dataset.
    """

    # Load the pre-trained Faster R-CNN model on COCO dataset
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # Get the number of input features for the classifier (this is the final layer of Faster R-CNN)
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # Replace the pre-trained head with a new one (num_classes includes the background class)
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

# Setup

In [7]:
# transform = transforms.Compose([
#   # TODO: e.g. transforms.Normalize()
# ])

# # TODO: dataset = CocoDataset(..., transform=transform)
# dataset = None
# # TODO: dataloader = DataLoader(..., collate_fn=collate_fn)
# dataloader = None

# model = get_model()
# # revise as needed
# device = torch.device('cpu')
# model.to(device)

root_dir = "/content/drive/MyDrive/EE641_Dataset/mscoco_text_cleaned_v01/data/train"
annFile = "/content/drive/MyDrive/EE641_Dataset/COCO_Text.json"

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize all images to 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = CocoDataset(root_dir, annFile, transform=transform, cuda=True)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8, collate_fn=collate_fn)
dataloader_evaluate = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

model = get_model(num_classes=2)  # For text detection (2 classes: background + text)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model.to(device)

loading annotations into memory...
0:00:03.538181
creating index...
index created!
remaining images in ann file: 8738, remaining images in folder: 8738


Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:02<00:00, 78.9MB/s]


FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

# Training

In [11]:
# # TODO: complete optimizer (optional, learning rate scheduler)
# # optimizer =

# num_epochs = 100

# for epoch in range(num_epochs):
#     for idx, data in enumerate(dataloader):
#         model.train()

#         # TODO: implement training loop: inference, backpropagate, update

#         # https://pytorch.org/vision/stable/models/faster_rcnn.html
#         #
#         # See: https://github.com/pytorch/vision/blob/main/torchvision/models/detection/faster_rcnn.py
#         #   class FasterRCNN

#         # TODO: track componentwise (4) RCNN loss vs. epoch.


#     # TODO: checkpoint model weights, optimizer state, and other as needed to resume
#     # TODO: plot each of the 4 RCNN loss components vs epoch-number

#     # VALIDATE
#     model.eval()
#     evaluate(model, dataloader, device, epoch)

# Define optimizer and optional learning rate scheduler
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

# Optionally, define a learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

# Define number of epochs
num_epochs = 100

scaler = torch.amp.GradScaler('cuda')

# Initialize lists to store loss components for plotting later
loss_classifier_list = []
loss_box_reg_list = []
loss_objectness_list = []
loss_rpn_box_reg_list = []

# # Stepwise unfreezing of RPN and Backbone
# def freeze_rpn_backbone(model):
#     """ Freeze RPN and Backbone parameters. """
#     for param in model.backbone.parameters():
#         param.requires_grad = False
#     for param in model.rpn.parameters():
#         param.requires_grad = False

# def unfreeze_rpn(model):
#     """ Unfreeze only RPN, keep Backbone frozen. """
#     for param in model.rpn.parameters():
#         param.requires_grad = True

# def unfreeze_backbone(model):
#     """ Unfreeze Backbone. """
#     for param in model.backbone.parameters():
#         param.requires_grad = True

# # Start with RPN and Backbone frozen
# freeze_rpn_backbone(model)

# Training loop
for epoch in range(1, num_epochs + 1):
    # # Unfreeze RPN after 5 epochs
    # if epoch == 2:
    #     print(f"Unfreezing RPN at epoch {epoch}")
    #     unfreeze_rpn(model)

    # # Unfreeze Backbone after 10 epochs
    # if epoch == 3:
    #     print(f"Unfreezing Backbone at epoch {epoch}")
    #     unfreeze_backbone(model)

    model.train()  # Set model to training mode
    epoch_loss = {'loss_classifier': 0, 'loss_box_reg': 0, 'loss_objectness': 0, 'loss_rpn_box_reg': 0}

    # Training loop over batches
    for idx, (images, targets) in enumerate(dataloader):
        # Move images and targets to device (CPU or GPU)
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Zero the gradients before the backward pass
        optimizer.zero_grad()

        with autocast():
        # Forward pass (inference)
          loss_dict = model(images, targets)

          # Compute total loss by summing the individual loss components
          total_loss = sum(loss for loss in loss_dict.values())

        # Backward pass (backpropagation) with mixed precision scaling
        scaler.scale(total_loss).backward()

        # Update model weights
        scaler.step(optimizer)
        scaler.update()

        # Track losses for each component
        for key in epoch_loss:
            epoch_loss[key] += loss_dict[key].item()

        # Every 500 batches, print losses
        if idx % 50 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Batch [{idx}/{len(dataloader)}]")
            print(f"Total Loss: {total_loss.item():.4f}, "
                  f"Classifier Loss: {loss_dict['loss_classifier'].item():.4f}, "
                  f"Box Reg Loss: {loss_dict['loss_box_reg'].item():.4f}, "
                  f"Objectness Loss: {loss_dict['loss_objectness'].item():.4f}, "
                  f"RPN Box Reg Loss: {loss_dict['loss_rpn_box_reg'].item():.4f}")

    # Average the loss components for the epoch
    num_batches = len(dataloader)
    for key in epoch_loss:
        epoch_loss[key] /= num_batches

    # Store loss values for plotting
    loss_classifier_list.append(epoch_loss['loss_classifier'])
    loss_box_reg_list.append(epoch_loss['loss_box_reg'])
    loss_objectness_list.append(epoch_loss['loss_objectness'])
    loss_rpn_box_reg_list.append(epoch_loss['loss_rpn_box_reg'])

    # Step the learning rate scheduler if defined
    if lr_scheduler:
        lr_scheduler.step()

    if epoch % 10 == 0:
        # Checkpoint model weights, optimizer state, and epoch
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'scaler_state_dict': scaler.state_dict()
        }
        torch.save(checkpoint, f"/content/drive/MyDrive/check_point/checkpoint_epoch_{epoch}.pth")

    # VALIDATE (run evaluation)
    model.eval()
    evaluate(model, dataloader_evaluate, device, epoch)

# Plot the 4 RCNN loss components over epochs
epochs_range = range(1, num_epochs + 1)
plt.figure(figsize=(10, 5))
plt.plot(epochs_range, loss_classifier_list, label='Classifier Loss')
plt.plot(epochs_range, loss_box_reg_list, label='Box Regression Loss')
plt.plot(epochs_range, loss_objectness_list, label='Objectness Loss')
plt.plot(epochs_range, loss_rpn_box_reg_list, label='RPN Box Regression Loss')
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("RCNN Loss Components Over Epochs")
plt.show()

Output hidden; open in https://colab.research.google.com to view.