In [25]:
import torch
import torchvision
from torch.utils.data import DataLoader, Subset
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.datasets import CocoDetection
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from torch.utils.data import Dataset
from pycocotools.coco import COCO
import os
import torchvision.transforms as transforms
from pycocotools.cocoeval import COCOeval
import numpy as np
import json
from tqdm import tqdm
import json
from evaluate_mAP import *
import time
from torch.utils.data.dataloader import default_collate

### Create custom dataset class

In [26]:
class Vegeta(Dataset):
    """COCO-format object-detection dataset.

    Each sample is an ``(image, target)`` pair where ``target`` is a dict with
    ``boxes`` (float32, one ``[xmin, ymin, xmax, ymax]`` row per annotation),
    ``labels`` (int64 category ids) and ``image_id`` (1-element tensor holding
    the COCO image id).

    Args:
        root: Directory that contains the image files.
        annotation: Path to the COCO-format annotation JSON file.
        transforms: Optional callable applied to the PIL image only.
        subset_size: If given, keep only the first ``subset_size`` image ids
            (in sorted order).
    """

    def __init__(self, root, annotation, transforms= None, subset_size=None):
        self.root = root  # image directory, joined with each file name on access
        self.annotation = annotation  # path of the annotation file
        self.transforms = transforms  # image-only transform pipeline (may be None)
        self.coco = COCO(annotation)  # pycocotools index over the annotation file
        # Sorted image ids so that index -> image id is deterministic.
        self.ids = list(sorted(self.coco.imgs.keys()))

        if subset_size is not None:  # only take a subset of the data
            self.ids = self.ids[:subset_size]

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``.

        Returns ``None`` (after printing a warning) when the image file is not
        present under ``root``; the collate function is expected to drop such
        samples.
        """
        coco = self.coco
        img_id = self.ids[index]
        ann_ids = coco.getAnnIds(imgIds=img_id)  # an image can own several annotations
        coco_annotation = coco.loadAnns(ann_ids)

        # loadImgs returns a list of metadata dicts; one id -> one entry.
        path = coco.loadImgs(img_id)[0]['file_name']
        # file_name may look like '/qa/2070-full-qa/<name>.jpg'; the images live
        # flat under `root`, so only the last path component is used.
        image_filename = path.split('/')[-1]

        try:
            img = Image.open(os.path.join(self.root, image_filename)).convert('RGB')
        except FileNotFoundError:
            print(f"Warning: Image file '{image_filename}' not found. Skipping this image.")
            return None  # Skip this image and continue

        boxes = []
        labels = []
        for ann in coco_annotation:
            # COCO bbox format is [x, y, width, height]; convert to corner format.
            xmin, ymin, width, height = ann['bbox']
            boxes.append([xmin, ymin, xmin + width, ymin + height])
            labels.append(ann['category_id'])

        # Build the target once, after ALL annotations were collected.
        # reshape(-1, 4) keeps the (0, 4) shape for images without annotations.
        target = {
            "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "image_id": torch.tensor([img_id]),
        }

        # NOTE(review): transforms are applied to the image only. If the pipeline
        # resizes the image, the boxes above stay in original-image pixels —
        # confirm whether boxes should be rescaled to match.
        if self.transforms is not None:
            img = self.transforms(img)
        return img, target

    def __len__(self):
        """Number of images in the (possibly truncated) dataset."""
        return len(self.ids)

#### Load the training and validation data (can load a subset as well)

In [None]:
def get_transform(train):
    """Build the image preprocessing pipeline.

    Every image is resized to 600x600 and converted to a tensor. ``train`` is
    accepted but currently has no effect: the random horizontal flip that it
    used to enable is disabled.

    NOTE(review): Resize changes the image size but this pipeline never sees
    the bounding boxes — confirm the boxes are rescaled elsewhere.
    """
    steps = []
    steps.append(transforms.Resize((600, 600)))
    steps.append(transforms.ToTensor())
    # Disabled augmentation (flip with probability 0.5 when `train` is set):
    #     steps.append(transforms.RandomHorizontalFlip(0.5))
    pipeline = transforms.Compose(steps)
    return pipeline


def collate_fn(batch):
    """Combine dataset samples into one batch, dropping missing samples.

    The dataset returns ``None`` for images it cannot load, so those entries
    are filtered out before batching.

    Args:
        batch: List of ``(image, target)`` tuples and/or ``None`` entries.

    Returns:
        ``(images, targets)`` where ``images`` is the result of
        ``default_collate`` on the remaining images and ``targets`` is the list
        of their target dicts. When no sample survives the filter the result is
        ``(None, None)`` so that callers which unpack two values and then test
        ``images is None`` keep working.
    """
    samples = [item for item in batch if item is not None]
    if not samples:
        # A bare None here would break `for images, targets in loader` unpacking.
        return None, None

    images = default_collate([sample[0] for sample in samples])
    targets = [sample[1] for sample in samples]
    return images, targets

In [35]:
train_image_folder_path = './guardrail-damage/images/train'
train_image_annotations_path = './guardrail-damage/annotation_files/train_filtered.json'

val_image_folder_path = './guardrail-damage/images/validation'
val_image_annotations_path = './guardrail-damage/annotation_files/validation.json'

# The dataset yields None for images it cannot load. `collate_fn` filters those
# out; the previous `lambda x: tuple(zip(*x))` raised TypeError on the first
# None sample in a batch.

# Load the training data (pass subset_size=N to Vegeta for a quick smoke run).
trainset = Vegeta(root=train_image_folder_path, annotation=train_image_annotations_path, transforms=get_transform(train=True))
train_loader = DataLoader(trainset, batch_size=4, shuffle=True, collate_fn=collate_fn)

# Load the validation data (pass subset_size=N to Vegeta for a quick smoke run).
validationset = Vegeta(root=val_image_folder_path, annotation=val_image_annotations_path, transforms=get_transform(train=False))
val_loader = DataLoader(validationset, batch_size=4, shuffle=False, collate_fn=collate_fn)

loading annotations into memory...
Done (t=0.20s)
creating index...
index created!
loading annotations into memory...
Done (t=0.39s)
creating index...
index created!


#### Load the pre-trained model

In [28]:
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device", device)

device cuda


In [None]:
def get_faster_rcnn_model_with_class(num_classes, model_checkpoint_path=None):
    """Build a Faster R-CNN (ResNet-50 FPN) whose box head predicts ``num_classes``.

    Args:
        num_classes: Number of output classes for the new box predictor.
        model_checkpoint_path: Optional path to a saved ``state_dict``. When
            given, the weights are loaded with ``map_location=device``. When
            omitted, the model keeps the pre-trained backbone and a freshly
            initialised box predictor, so the caller can load weights itself.

    Returns:
        The constructed model. Only the new box predictor is moved to the
        module-level ``device`` here; moving the whole model is left to the
        caller.
    """
    # Load pre-trained Faster R-CNN
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    # The replacement head must accept the same feature width as the old one.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Replace the pre-trained head with a new one sized for our classes.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    model.roi_heads.box_predictor.to(device)

    if model_checkpoint_path is not None:
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        state_dict = torch.load(model_checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)

    return model

In [29]:
# Two outputs for the box head: the single object class plus background.
num_classes = 2
model_checkpoint_path = "./guardrail/models/guardrail_models_guardrails.pth"

# Build the detector from the checkpoint, then move all of it to the device.
model = get_faster_rcnn_model_with_class(num_classes, model_checkpoint_path)
model = model.to(device)

  model.load_state_dict(torch.load(model_checkpoint_path, map_location=device))


#### Training the model and saving every epoch

In [30]:
# Prepare predictions and ground truth for COCO evaluation
def prepare_for_coco_eval(pred_boxes, pred_labels, pred_scores, gt_boxes, gt_labels, image_id):
    """Convert corner-format boxes of one image into COCO-style record lists.

    Args:
        pred_boxes, pred_labels, pred_scores: Parallel sequences describing the
            predictions; each box is ``(xmin, ymin, xmax, ymax)``.
        gt_boxes, gt_labels: Parallel sequences describing the ground truth,
            boxes in the same corner format.
        image_id: Id written into every record.

    Returns:
        ``(coco_preds, coco_gts)`` — two lists of dicts whose ``bbox`` is
        ``[xmin, ymin, width, height]``. Prediction records carry ``score``;
        ground-truth records carry ``iscrowd`` fixed to 0.
    """
    def _corners_to_xywh(corner_box):
        left, top, right, bottom = corner_box
        return [left, top, right - left, bottom - top]

    coco_preds = [
        {
            "image_id": image_id,
            "category_id": category,
            "bbox": _corners_to_xywh(corner_box),
            "score": confidence,
        }
        for corner_box, category, confidence in zip(pred_boxes, pred_labels, pred_scores)
    ]

    coco_gts = [
        {
            "image_id": image_id,
            "category_id": category,
            "bbox": _corners_to_xywh(corner_box),
            "iscrowd": 0,  # no crowd annotations are assumed
        }
        for corner_box, category in zip(gt_boxes, gt_labels)
    ]

    return coco_preds, coco_gts


In [31]:
from tqdm import tqdm
import numpy as np


def compute_mAP(model, val_loader, device, coco_gt, confidence_threshold = 0.5, iou_threshold = 0.5, iou_function = 'GIoU'):
    """Run the model over ``val_loader`` and score the detections.

    Args:
        model: Detection model; it is switched to eval mode and left there.
        val_loader: Iterable of ``(images, targets)`` batches; every target
            must carry an ``image_id`` tensor.
        device: Device the images are moved to before inference.
        coco_gt: Ground-truth annotations handed to ``evaluate_custom_mAP``
            (the training cells pass the ``annotations`` list of the
            validation JSON).
        confidence_threshold: Predictions with a score at or below this value
            are dropped.
        iou_threshold: Forwarded to ``evaluate_custom_mAP``.
        iou_function: Forwarded to ``evaluate_custom_mAP``.

    Returns:
        The value returned by ``evaluate_custom_mAP``, or ``0.0`` when no
        prediction survived the confidence filter.
    """
    model.eval()
    coco_preds = []

    progress_bar = tqdm(val_loader, desc="Processing Batches", unit="batch", leave=True, colour="green")

    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(progress_bar):
            # A collate function may hand back (None, None) for an empty batch.
            if images is None or targets is None:
                continue

            images = [image.to(device) for image in images]
            predictions = model(images)

            for i, prediction in enumerate(predictions):
                # Only the image id is needed from the target, so targets stay on CPU.
                image_id = targets[i]['image_id'].item()

                boxes = prediction['boxes'].cpu().numpy()
                scores = prediction['scores'].cpu().numpy()
                labels = prediction['labels'].cpu().numpy()

                # One mask shared by boxes, labels and scores.
                keep = scores > confidence_threshold
                filtered_boxes = boxes[keep]
                filtered_labels = labels[keep]
                filtered_scores = scores[keep]

                if len(filtered_boxes) == 0:
                    print(f"No predictions for image {image_id} after filtering.")
                    continue

                # NOTE(review): boxes are in the coordinate frame of the tensors fed
                # to the model; confirm it matches the frame of `coco_gt`.
                for pred_box, label, score in zip(filtered_boxes, filtered_labels, filtered_scores):
                    xmin, ymin, xmax, ymax = pred_box.tolist()
                    coco_preds.append({
                        'image_id': image_id,
                        'category_id': int(label),  # plain Python int, not numpy scalar
                        'bbox': [xmin, ymin, xmax - xmin, ymax - ymin],  # COCO (x, y, w, h)
                        'score': float(score),  # plain Python float, not numpy scalar
                    })

            progress_bar.set_postfix(batch=batch_idx + 1)

    print(f"Total predictions collected: {len(coco_preds)}")

    if not coco_preds:
        print("No predictions to evaluate. mAP cannot be computed.")
        return 0.0

    # Score against the ground truth passed in, not a notebook-level global.
    mAP = evaluate_custom_mAP(coco_preds, coco_gt, iou_threshold, iou_function)
    print("mAP:", mAP)
    return mAP

In [32]:
# --- Training / evaluation settings shared by the cells below ---
num_epochs = 5
best_mAP = 0.0
confidence_threshold = 0.5
iou_threshold = 0.5
iou_function = 'GIoU'

# Adam over all model parameters; the learning rate is multiplied by 0.1 every 3 epochs.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.001,
    weight_decay=0.0005,
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

# pycocotools index over the validation annotations.
coco_val_gt = COCO(val_image_annotations_path)

# Raw annotation list from the same file, used as ground truth by compute_mAP.
with open(val_image_annotations_path, "r") as f:
    ground_truths = json.load(f)['annotations']

loading annotations into memory...
Done (t=0.37s)
creating index...
index created!


In [33]:
# for epoch in range(0, num_epochs):
#     # Start timing for training
#     epoch_training_start_time = time.time()
#     model.train() #setting model to training mode
#     train_loss = 0  # Initialize epoch-level loss
#     batch_losses = []  # List to store batch losses for logging
    
#     for batch_idx, (images, targets) in enumerate(train_loader):  # Iterate over data
#         # Skip if the batch is empty (due to missing images)
#         if images is None or targets is None:
#             print(f"Skipping empty batch {batch_idx + 1}")
#             continue
#         # print("before", device)
#     # for images, targets in train_loader: #iterate over data
#         images = list(image.to(device) for image in images) #moving all the images to device
#         targets = [{key: value.to(device) for key, value in t.items()} for t in targets] #target will have labels and BB, used dict comprehension
#         # print("after", device)
        
#         #loss dictionary
#         optimizer.zero_grad() #makes gradient zero for next iteration
#         loss_dict = model(images, targets) #gets the loss 
#         losses = sum(loss for loss in loss_dict.values()) #adds the loss
        
#         losses.backward() #backprop
#         optimizer.step() #

#         train_loss += losses.item()

#         batch_loss = losses.item()
#         batch_losses.append(batch_loss)

#         # Print batch loss
#         print(f"Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Batch Loss: {batch_loss:.4f}")

#     # Calculate average epoch loss
#     avg_train_loss = train_loss / len(train_loader)

#     # End timing for training
#     epoch_training_end_time = time.time()
#     epoch_training_time = epoch_training_end_time - epoch_training_start_time
#     total_training_time += epoch_training_time

#     print(f"Epoch {epoch + 1}/{num_epochs} - Average Train Loss: {avg_train_loss:.4f}")
#     print(f"Epoch {epoch + 1}/{num_epochs} - Training Time: {epoch_training_time:.2f} seconds")

#     #update learning rate 
#     lr_scheduler.step()
    
#     # validation file 

#     # Start timing for validation
#     epoch_validation_start_time = time.time()
    
#     mAP = compute_mAP(model, val_loader, device, ground_truths, confidence_threshold, iou_threshold , iou_function)
#     # print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss/len(train_loader):.4f}, mAP: {mAP:.4f}, mAP_50: {mAP_50:.4f}, mAP_75: {mAP_75:.4f}")
#     # print(f"Epoch {epoch+1}/{num_epochs} -Train Loss: {train_loss/len(train_loader):.4f},  mAP: {mAP:.4f}")

#     # End timing for validation
#     epoch_validation_end_time = time.time()
#     epoch_validation_time = epoch_validation_end_time - epoch_validation_start_time
#     total_validation_time += epoch_validation_time

#     print(f"Epoch {epoch + 1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, mAP: {mAP:.4f}")
#     print(f"Epoch {epoch + 1}/{num_epochs} - Validation Time: {epoch_validation_time:.2f} seconds")
#     # Save the best model based on mAP
#     if mAP > best_mAP:
#         best_mAP = mAP
#         # torch.save(model.state_dict(), 'best_model.pth')
#         # Save the model's state dictionary after every epoch
#     model_path = f"fasterrcnn_resnet50_epoch_{epoch + 1}.pth"
#     torch.save(model.state_dict(), model_path)
#     # print(f"Epoch {epoch+1}/{num_epochs} mAP: {losses.item()}")

#### Training loop that also writes the progress messages to a log file

In [None]:
import time
import traceback
import torch

# Initialize total training and validation times
total_training_time = 0
total_validation_time = 0

# Open a file to write logs
with open('training_log_2.txt', 'a') as log_file:
    for epoch in range(0, num_epochs):
        try:
            # ---- training ----
            epoch_training_start_time = time.time()
            model.train()  # compute_mAP leaves the model in eval mode
            train_loss = 0  # sum of batch losses for this epoch
            batch_losses = []  # one entry per batch that was actually trained on

            for batch_idx, (images, targets) in enumerate(train_loader):
                # Skip if the batch is empty (due to missing images)
                if images is None or targets is None:
                    log_file.write(f"Skipping empty batch {batch_idx + 1}\n")
                    continue

                images = list(image.to(device) for image in images)
                targets = [{key: value.to(device) for key, value in t.items()} for t in targets]

                optimizer.zero_grad()
                loss_dict = model(images, targets)  # train mode returns a dict of losses
                losses = sum(loss for loss in loss_dict.values())

                losses.backward()
                optimizer.step()

                batch_loss = losses.item()
                train_loss += batch_loss
                batch_losses.append(batch_loss)

                # Write batch loss to file
                log_file.write(f"Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Batch Loss: {batch_loss:.4f}\n")
                print(f"Epoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Batch Loss: {batch_loss:.4f}")

            # Average over the batches that were trained on, so skipped batches
            # do not drag the mean down; max(...) guards an all-skipped epoch.
            avg_train_loss = train_loss / max(len(batch_losses), 1)

            epoch_training_end_time = time.time()
            epoch_training_time = epoch_training_end_time - epoch_training_start_time
            total_training_time += epoch_training_time

            log_file.write(f"Epoch {epoch + 1}/{num_epochs} - Average Train Loss: {avg_train_loss:.4f}\n")
            log_file.write(f"Epoch {epoch + 1}/{num_epochs} - Training Time: {epoch_training_time:.2f} seconds\n")

            # Update learning rate
            lr_scheduler.step()

            # ---- validation ----
            epoch_validation_start_time = time.time()

            mAP = compute_mAP(model, val_loader, device, ground_truths, confidence_threshold, iou_threshold, iou_function)

            epoch_validation_end_time = time.time()
            epoch_validation_time = epoch_validation_end_time - epoch_validation_start_time
            total_validation_time += epoch_validation_time

            log_file.write(f"Epoch {epoch + 1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, mAP: {mAP:.4f}\n")
            log_file.write(f"Epoch {epoch + 1}/{num_epochs} - Validation Time: {epoch_validation_time:.2f} seconds\n")

            # Keep the best checkpoint: the inference section loads 'best_model.pth'.
            if mAP > best_mAP:
                best_mAP = mAP
                torch.save(model.state_dict(), 'best_model.pth')
                log_file.write(f"Epoch {epoch + 1}/{num_epochs} - New best mAP {best_mAP:.4f}, saved best_model.pth\n")

            # Save the model's state dictionary after every epoch
            model_path = f"fasterrcnn_resnet50_epoch_{epoch + 1}.pth"
            torch.save(model.state_dict(), model_path)

            # Optional log for model save
            log_file.write(f"Epoch {epoch + 1}/{num_epochs} - Model saved to {model_path}\n")

        except Exception as e:
            # Best effort: record the failure (with traceback, so it can be
            # diagnosed from the log) and move on to the next epoch.
            log_file.write(f"Error during epoch {epoch + 1}: {str(e)}\n")
            log_file.write(traceback.format_exc())
            print(f"Error during epoch {epoch + 1}: {str(e)}")

# Log the total training and validation times
with open('training_log_2.txt', 'a') as log_file:
    log_file.write(f"Total Training Time: {total_training_time:.2f} seconds\n")
    log_file.write(f"Total Validation Time: {total_validation_time:.2f} seconds\n")

Epoch 1/5, Batch 1/7596, Batch Loss: 34.8204
Epoch 1/5, Batch 2/7596, Batch Loss: 42.2539
Epoch 1/5, Batch 3/7596, Batch Loss: 30.5249
Epoch 1/5, Batch 4/7596, Batch Loss: 26.7810
Epoch 1/5, Batch 5/7596, Batch Loss: 16.0477
Epoch 1/5, Batch 6/7596, Batch Loss: 29.0471
Epoch 1/5, Batch 7/7596, Batch Loss: 13.0665
Epoch 1/5, Batch 8/7596, Batch Loss: 13.8179
Epoch 1/5, Batch 9/7596, Batch Loss: 12.5253
Epoch 1/5, Batch 10/7596, Batch Loss: 14.2276
Epoch 1/5, Batch 11/7596, Batch Loss: 24.4384
Epoch 1/5, Batch 12/7596, Batch Loss: 13.1253
Epoch 1/5, Batch 13/7596, Batch Loss: 21.2757
Epoch 1/5, Batch 14/7596, Batch Loss: 18.7908
Epoch 1/5, Batch 15/7596, Batch Loss: 5.2461
Epoch 1/5, Batch 16/7596, Batch Loss: 10.0148
Epoch 1/5, Batch 17/7596, Batch Loss: 12.6599
Epoch 1/5, Batch 18/7596, Batch Loss: 6.8914
Epoch 1/5, Batch 19/7596, Batch Loss: 8.1478
Epoch 1/5, Batch 20/7596, Batch Loss: 13.8085
Epoch 1/5, Batch 21/7596, Batch Loss: 9.7520
Epoch 1/5, Batch 22/7596, Batch Loss: 11.3766
E

## Training finished; everything below is inference only

#### Loading the saved model and setting it for Inference

In [None]:
# Rebuild the detector and load the best checkpoint in one step. The factory
# loads the weights with map_location=device, so a checkpoint saved on GPU also
# loads on a CPU-only machine. 'best_model.pth' must already exist on disk.
model = get_faster_rcnn_model_with_class(num_classes, 'best_model.pth')  # 1 class + background
model.to(device)
model.eval()

#### Visualize the results

In [None]:
def visualize_predictions(image, boxes, labels, scores, threshold=0.5, size = (12, 8)):
    """Show ``image`` with every predicted box whose score exceeds ``threshold``.

    Args:
        image: Image in a layout accepted by ``Axes.imshow``.
        boxes: Boxes in corner format ``[xmin, ymin, xmax, ymax]`` — the format
            the detection model's ``prediction['boxes']`` is passed in.
        labels: Label per box, shown in the caption.
        scores: Confidence per box, shown in the caption with two decimals.
        threshold: Boxes with ``score <= threshold`` are not drawn.
        size: Figure size forwarded to ``plt.subplots``.
    """
    fig, ax = plt.subplots(1, figsize= size)
    ax.imshow(image)

    for box, label, score in zip(boxes, labels, scores):
        if score > threshold:
            xmin, ymin, xmax, ymax = box
            # Rectangle takes (x, y), width, height — derive both from the corners.
            rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            ax.text(xmin, ymin, f'{label}: {score:.2f}', color='white', backgroundcolor='red', fontsize=8)

    plt.axis('off')
    plt.title('Predicted bounding boxes', fontsize = 10)
    plt.show()
    plt.close()

#### Run the visualization loop over all the validation images

In [None]:
# Draw the predictions for every image that the validation loader yields.

threshold_value = 0.1
fig_size = (10,10)

for images, targets in val_loader:
    # Move the batch to the device as a list of per-image tensors.
    images = [image.to(device) for image in images]
    with torch.no_grad():
        predictions = model(images)

    for i, (image, prediction) in enumerate(zip(images, predictions)):
        # Tensors are channel-first; imshow wants channel-last.
        image = image.cpu().permute(1, 2, 0).numpy()
        boxes, labels, scores = (
            prediction[key].cpu().numpy() for key in ('boxes', 'labels', 'scores')
        )
        visualize_predictions(
            image,
            boxes,
            labels,
            scores,
            threshold=threshold_value,
            size=fig_size,
        )

## Validation on 1 image at a time


In [None]:
# Plot the image with bounding boxes
def plot_image_with_pred_boxes(image, boxes, labels, class_names=None):
    """Show ``image`` with the given corner-format boxes and their class names.

    Args:
        image: Image in a layout accepted by ``Axes.imshow``.
        boxes: Iterable of ``(xmin, ymin, xmax, ymax)`` boxes.
        labels: Integer class index per box.
        class_names: Optional sequence mapping class index -> display name.
            When omitted, the notebook-level ``category_id`` sequence is used
            if it has been defined; otherwise the numeric label is shown.
    """
    if class_names is None:
        # Backward compatible: earlier versions always read the global `category_id`.
        class_names = globals().get('category_id')

    fig, ax = plt.subplots(1, figsize= (10,10))
    ax.imshow(image)

    for box, label in zip(boxes, labels):
        xmin, ymin, xmax, ymax = box
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=1, edgecolor='w', facecolor='none')
        ax.add_patch(rect)

        index = int(label)
        # Fall back to the raw index when no name is available for it.
        if class_names is not None and 0 <= index < len(class_names):
            shown = class_names[index]
        else:
            shown = index
        ax.text(xmin, ymin, f'Class: {shown}', color='white', fontsize=12)

    plt.title('Image with predicted bounding boxes', fontsize= 10)
    plt.axis('off')
    plt.show()
    plt.close()

In [None]:
# Single-image inference: load one file, run the detector, plot what survives the filter.
category_id = ('Background', 'Vegeta')
confidence_threshold = 0.002
image_path = 'data/test/q5khrnrwimu71.jpg'

# The only preprocessing step is the conversion to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

image = Image.open(image_path).convert('RGB')
# unsqueeze(0) adds the batch dimension the model call is given.
image_tensor = transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    prediction = model(image_tensor)

# A single image was passed in, so take the single prediction back out.
prediction = prediction[0]
boxes, scores, labels = (
    prediction[key].cpu().numpy() for key in ('boxes', 'scores', 'labels')
)

# Keep only the detections whose score is above the threshold.
filtered_boxes, filtered_labels = (
    values[scores > confidence_threshold] for values in (boxes, labels)
)

plot_image_with_pred_boxes(image, filtered_boxes, filtered_labels)

#### Plot both the predicted and the ground-truth bounding boxes

In [None]:
# Plot the image with predicted and ground truth bounding boxes
def plot_image_with_gt_and_pred_boxes(image, pred_boxes, pred_labels, gt_boxes, gt_labels):
    """Overlay predictions (red) and ground truth (green) on ``image``.

    All boxes are ``(xmin, ymin, xmax, ymax)``. Predictions are drawn first,
    each captioned ``Pred: <label>``; ground-truth boxes follow, captioned
    ``GT: <label>``.
    """
    fig, ax = plt.subplots(1)
    ax.imshow(image)

    # (boxes, labels, colour, caption prefix) — predictions first, then ground truth.
    layers = (
        (pred_boxes, pred_labels, 'r', 'Pred'),
        (gt_boxes, gt_labels, 'g', 'GT'),
    )
    for layer_boxes, layer_labels, colour, prefix in layers:
        for corner_box, label in zip(layer_boxes, layer_labels):
            left, top, right, bottom = corner_box
            outline = patches.Rectangle(
                (left, top),
                right - left,
                bottom - top,
                linewidth=1,
                edgecolor=colour,
                facecolor='none',
            )
            ax.add_patch(outline)
            ax.text(left, top, f'{prefix}: {label}', color=colour, fontsize=10, backgroundcolor='white')

    plt.title('Image with predicted and ground truth bounding boxes', fontsize= 10)
    plt.axis('off')
    plt.show()
    plt.close()

In [None]:

from pycocotools.coco import COCO
import matplotlib.pyplot as plt

# Define transforms for the image
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load the image
image_path = 'data/val/majin-vegeta-and-other-dragon-ball-characters-tajyteiulwizv8zz.jpg'
image = Image.open(image_path).convert('RGB')
image_tensor = transform(image).unsqueeze(0).to(device)  # Add batch dimension and send to device

# Run the detector on THIS image. Without this step the plot below would show
# whatever `filtered_boxes` an earlier cell left behind for a different image.
with torch.no_grad():
    prediction = model(image_tensor)[0]  # single image in, single prediction out

boxes = prediction['boxes'].cpu().numpy()
scores = prediction['scores'].cpu().numpy()
labels = prediction['labels'].cpu().numpy()

# Filter predictions by confidence threshold (one mask shared by boxes and labels)
keep = scores > confidence_threshold
filtered_boxes = boxes[keep]
filtered_labels = labels[keep]

annotation_path = 'data/val_annot/val_an_coco.json'
coco = COCO(annotation_path)

# Find the image ID by matching the filename
image_filename = image_path.split('/')[-1]  # Extract filename from path
image_id = next(
    (img_info['id'] for img_info in coco.dataset['images'] if img_info['file_name'] == image_filename),
    None,
)

if image_id is None:
    raise ValueError(f"Image filename '{image_filename}' not found in annotations.")

# Load the ground-truth annotations that belong to this image
ann_ids = coco.getAnnIds(imgIds=image_id)
gt_annotations = coco.loadAnns(ann_ids)

# Convert COCO [x, y, width, height] boxes to corner format for plotting
gt_boxes = []
gt_labels = []
for ann in gt_annotations:
    xmin, ymin, width, height = ann['bbox']
    gt_boxes.append([xmin, ymin, xmin + width, ymin + height])
    gt_labels.append(ann['category_id'])

# Visualize the results
plot_image_with_gt_and_pred_boxes(image, filtered_boxes, filtered_labels, gt_boxes, gt_labels)