In [None]:
import os
import pandas as pd
import torchvision.transforms as T

import torch
import random
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms.functional import to_tensor
from torchvision.io import read_image
from tqdm import tqdm
from torch.utils.data import Dataset

from torch import nn as nn
from torch.nn import functional as Fn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

import json
from PIL import ImageDraw, ImageFont

from tqdm import tqdm
from datetime import datetime

import wandb

In [None]:
# Image size in pixels used when mapping label centres onto the detection grid.
img_width = 576
img_height = 576

# Detection grid: each image is divided into S x S cells and every cell
# carries B bounding-box slots.
S = 9
B = 2

In [None]:
def normalize(data):
    """Standardise an image tensor channel-wise.

    Args:
        data: Image tensor accepted by ``TF.normalize`` (three channels).

    Returns:
        The tensor with the per-channel mean subtracted and divided by the
        per-channel std (the values are the commonly used ImageNet statistics).
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return TF.normalize(data, mean=channel_mean, std=channel_std)

def augment(data):
    """Randomly translate, scale and colour-jitter a single image.

    Args:
        data: A PIL image, or a tensor that ``TF.to_pil_image`` can convert.

    Returns:
        The augmented image as a tensor produced by ``TF.to_tensor``.

    NOTE(review): only the image is transformed. Callers that pair the result
    with bounding-box labels computed from the un-shifted image end up with
    labels that no longer line up with the pixels -- confirm this is intended.
    """
    def _signed_unit():
        # Uniform draw in [-1, 1).
        return 2 * random.random() - 1

    # The functional transforms below are applied to a PIL image.
    if isinstance(data, torch.Tensor):
        data = TF.to_pil_image(data)

    width, height = data.size
    max_translation = 0.2  # shift by at most 20% of each image dimension

    # Draw order matters for reproducibility with a seeded RNG:
    # x shift, y shift, scale, brightness, saturation.
    x_shift = max_translation * width * _signed_unit()
    y_shift = max_translation * height * _signed_unit()
    scale = 1.0 + 0.2 * _signed_unit()  # zoom factor in [0.8, 1.2)

    # Translation and zoom only; no rotation and no shear.
    data = TF.affine(data, angle=0.0, scale=scale, translate=(x_shift, y_shift), shear=0.0)

    # Brightness factor in [0.5, 1.5).
    brightness = 1.0 + 0.5 * _signed_unit()
    data = TF.adjust_brightness(data, brightness)

    # Saturation factor in [0.5, 1.5).
    saturation = 1.0 + 0.5 * _signed_unit()
    data = TF.adjust_saturation(data, saturation)

    return TF.to_tensor(data)


def read_image(img_path):
    """Load an image file from disk as an RGB tensor.

    Note: this definition shadows ``torchvision.io.read_image`` imported at the
    top of the notebook; every later call to ``read_image`` resolves here.

    Args:
        img_path: Path of the image file to open.

    Returns:
        The image converted to RGB and passed through ``to_tensor``.
    """
    # Use a context manager so the underlying file handle is closed right away
    # instead of lingering until garbage collection (this matters when a
    # Dataset opens thousands of files). convert() returns a new, fully loaded
    # image, so it stays valid after the file is closed.
    with Image.open(img_path) as source:
        rgb_image = source.convert("RGB")
    return to_tensor(rgb_image)


class GditDataset(Dataset):
    """Dataset yielding ``(image, grid_labels)`` pairs for YOLO-style training.

    Images are read from ``.../Dataset/{set_type}/images`` and labels from the
    matching ``.txt`` file in ``.../Dataset/{set_type}/labels``. Each label line
    is parsed as five floats ``class cx cy w h`` (presumably normalised to
    [0, 1], as the grid mapping below assumes -- verify against the dataset).

    The label tensor has shape ``(S, S, B * 5)``; box slot ``b`` of a cell
    occupies ``[b*5 : b*5+5]`` and stores ``[cx, cy, w, h, 1.0]``.
    """

    def __init__(self, set_type, normalize=None, augment=None, target_transform=None):
        """
        Args:
            set_type: Sub-directory of the dataset to read (e.g. 'train').
            normalize: Optional callable applied to the image tensor after augmentation.
            augment: Optional callable applied to the image tensor before normalisation.
            target_transform: Optional callable ``(labels, img_w, img_h) -> labels``.
        """
        self.set_type = set_type
        self.normalize = normalize
        self.augment = augment
        self.target_transform = target_transform
        # Define directories for images and labels based on set type
        self.images_dir = f"/kaggle/input/gdit-dataset/Dataset/{set_type}/images"
        self.labels_dir = f"/kaggle/input/gdit-dataset/Dataset/{set_type}/labels"

        # os.listdir() order is arbitrary; sort so that index -> file is
        # deterministic across runs and machines.
        self.image_files = sorted(f for f in os.listdir(self.images_dir) if f.endswith('.jpg'))

    def __len__(self):
        """Number of ``.jpg`` images found in the image directory."""
        return len(self.image_files)

    def __getitem__(self, idx, S=S, B=B):
        """Return the resized image and its ``(S, S, B*5)`` label grid.

        Args:
            idx: Index into the sorted image list.
            S: Grid size (defaults to the module-level ``S``).
            B: Box slots per cell (defaults to the module-level ``B``).
        """
        file_name = self.image_files[idx]
        image = read_image(os.path.join(self.images_dir, file_name))
        # T.Resize takes (height, width).
        image = T.Resize((img_height, img_width))(image)

        # Same base name as the image, with a .txt extension. splitext only
        # touches the real extension, unlike str.replace which would also
        # rewrite a '.jpg' appearing earlier in the name.
        label_file = os.path.join(self.labels_dir, os.path.splitext(file_name)[0] + '.txt')

        labels = torch.zeros(S, S, B * 5)

        with open(label_file, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank / trailing empty lines.
                    continue
                _class_id, cx_norm, cy_norm, w_norm, h_norm = map(float, fields)

                # Cell containing the box centre. Clamp so that a centre lying
                # exactly on the right/bottom edge (coordinate == 1.0) maps to
                # the last cell instead of raising IndexError, and a slightly
                # negative coordinate does not wrap around to the far side.
                grid_x = min(max(int(cx_norm * S), 0), S - 1)
                grid_y = min(max(int(cy_norm * S), 0), S - 1)

                # Put the box in the first free slot of that cell. If all B
                # slots are already taken the box is dropped.
                for b in range(B):
                    offset = b * 5
                    if labels[grid_y, grid_x, offset + 4] == 0:
                        labels[grid_y, grid_x, offset:offset + 5] = torch.tensor(
                            [cx_norm, cy_norm, w_norm, h_norm, 1.0]
                        )
                        break

        # NOTE(review): `augment` only receives the image, so any geometric
        # change it makes is not reflected in `labels` -- confirm intended.
        if self.augment:
            image = self.augment(image)
        if self.normalize:
            image = self.normalize(image)

        # Image tensors are laid out (channels, height, width).
        _, img_h, img_w = image.shape
        if self.target_transform:
            labels = self.target_transform(labels, img_w, img_h)

        return image, labels



In [None]:
def bbox_to_coords(bbox):
    """Convert centre-format boxes to corner format.

    Args:
        bbox: Tensor whose last dimension holds ``[cx, cy, w, h, confidence]``.

    Returns:
        Tensor with ``[x1, y1, x2, y2, confidence]`` along the last dimension;
        leading dimensions are unchanged.
    """
    centre_x, centre_y = bbox[..., 0], bbox[..., 1]
    half_w = bbox[..., 2] / 2
    half_h = bbox[..., 3] / 2
    corners = [
        centre_x - half_w,   # x1
        centre_y - half_h,   # y1
        centre_x + half_w,   # x2
        centre_y + half_h,   # y2
        bbox[..., 4],        # confidence is passed through untouched
    ]
    return torch.stack(corners, dim=-1)

# Function to plot an image with bounding boxes, grid, and a center dot
def visualize_image_with_boxes(image, boxes, S=9, B=1):
    """Show an image with the S x S grid, its bounding boxes and their centres.

    Args:
        image: PIL image, or a tensor convertible with ``T.ToPILImage``.
        boxes: Label grid indexed as ``boxes[row, col, b*5 : b*5+5]`` holding
            ``[cx, cy, w, h, confidence]`` with coordinates normalised to the
            image size.
        S: Number of grid cells per side.
        B: Number of box slots per cell to draw. With the default of 1 only
            the first slot of every cell is drawn.
    """
    # Convert the tensor to a PIL image if needed
    if isinstance(image, torch.Tensor):
        image = T.ToPILImage()(image)

    width_px, height_px = image.size

    fig, ax = plt.subplots(1, figsize=(8, 8))
    ax.imshow(image)

    # Grid spacing follows the actual image size and S rather than a fixed
    # 64 px, so the overlay stays correct for other resolutions / grid sizes.
    cell_w = width_px / S
    cell_h = height_px / S
    for k in range(S + 1):
        # Vertical line
        ax.plot([k * cell_w, k * cell_w], [0, height_px], color="blue", linewidth=1, linestyle="--")
        # Horizontal line
        ax.plot([0, width_px], [k * cell_h, k * cell_h], color="blue", linewidth=1, linestyle="--")

    for row in range(S):
        for col in range(S):
            for slot in range(B):
                # Plain floats keep matplotlib independent of tensor types.
                cx_norm, cy_norm, w_norm, h_norm, confidence = (
                    float(v) for v in boxes[row, col, slot * 5: slot * 5 + 5]
                )
                if confidence > 0:  # Only draw boxes with non-zero confidence
                    # De-normalise to pixel units
                    cx = cx_norm * width_px
                    cy = cy_norm * height_px
                    w = w_norm * width_px
                    h = h_norm * height_px

                    # Rectangle is anchored at its top-left corner
                    rect = patches.Rectangle(
                        (cx - w / 2, cy - h / 2), w, h, linewidth=2, edgecolor='r', facecolor='none'
                    )
                    ax.add_patch(rect)

                    # Red dot marking the box centre
                    ax.scatter(cx, cy, color="red", s=5)

    plt.show()
    # Release the figure so repeated calls in a loop do not accumulate memory.
    plt.close(fig)

# Example dataset and visualization
dataset = GditDataset(set_type='train')

# Visualize (up to) the first 10 images with their bounding boxes.
# S and B are passed explicitly: the visualiser defaults to B=1, which would
# silently skip the second box slot that the dataset fills when two objects
# share a grid cell.
for i in range(min(10, len(dataset))):
    image, boxes = dataset[i]
    visualize_image_with_boxes(image, boxes, S=S, B=B)


In [None]:
class YOLOv1(nn.Module):
    """YOLOv1-style detector predicting ``B`` boxes per cell of an ``S x S`` grid.

    The convolutional stack reduces the spatial size by a factor of 64
    (two stride-2 convolutions and four 2x2 max-pools). The first linear layer
    expects ``S * S * 1024`` features, so with ``S = 9`` the input must be
    576 x 576 pixels.

    The order of the layers inside ``self.model`` is part of the checkpoint
    format (``state_dict`` keys are positional) and must not be changed.
    """

    def __init__(self):
        super().__init__()
        # Each box has 5 values (cx, cy, w, h, confidence).
        self.depth = 5 * B
        layers = [
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),                   # Conv 1
            nn.LeakyReLU(negative_slope=0.1),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 192, kernel_size=3, padding=1),                           # Conv 2
            nn.LeakyReLU(negative_slope=0.1),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(192, 128, kernel_size=1),                                     # Conv 3
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2d(256, 256, kernel_size=1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ]

        # NOTE(review): in Conv 4 and Conv 5 the 1x1 convolution feeds the 3x3
        # convolution with no activation in between, unlike Conv 1-3 above.
        # Left as-is so that existing checkpoints keep loading.
        for _ in range(4):                                                          # Conv 4
            layers += [
                nn.Conv2d(512, 256, kernel_size=1),
                nn.Conv2d(256, 512, kernel_size=3, padding=1),
                nn.LeakyReLU(negative_slope=0.1)
            ]
        layers += [
            nn.Conv2d(512, 512, kernel_size=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.MaxPool2d(kernel_size=2, stride=2)
        ]

        for _ in range(2):                                                          # Conv 5
            layers += [
                nn.Conv2d(1024, 512, kernel_size=1),
                nn.Conv2d(512, 1024, kernel_size=3, padding=1),
                nn.LeakyReLU(negative_slope=0.1)
            ]
        layers += [
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.1),
        ]

        for _ in range(2):                                                          # Conv 6
            layers += [
                nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
                nn.LeakyReLU(negative_slope=0.1)
            ]

        layers += [
            nn.Flatten(),
            nn.Linear(S * S * 1024, 4096),                            # Linear 1
            nn.Dropout(),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(4096, S * S * self.depth),                      # Linear 2
        ]

        self.model = nn.Sequential(*layers)  # Unpack layers into arguments

    def forward(self, x):
        """Run the network.

        Args:
            x: Image batch of shape ``(N, 3, H, W)`` with H = W = 64 * S.

        Returns:
            Tensor of shape ``(N, S, S, 5 * B)``.
        """
        # Call the module itself rather than .forward() so that any registered
        # forward hooks on the Sequential are honoured.
        flat = self.model(x)
        return flat.reshape(x.size(dim=0), S, S, self.depth)

In [None]:
class SumSquaredErrorLoss(nn.Module):
    """Sum-of-squared-errors detection loss for the YOLO-style label grid.

    Both inputs use the interleaved per-box layout that the dataset, the
    visualiser and the evaluator use: along the last dimension, box ``b``
    occupies ``[b*5 : b*5+5]`` and holds ``[cx, cy, w, h, confidence]``.
    """

    def __init__(self, S=9, B=1, lambda_coord=5, lambda_noobj=0.5):
        """
        Args:
            S: Number of grid cells per side (e.g. 9 for a 9x9 grid).
            B: Number of bounding boxes per grid cell.
            lambda_coord: Weight for the localization loss.
            lambda_noobj: Weight for the no-object confidence loss.

        ``S`` and ``B`` are stored for reference only: ``forward`` derives the
        number of boxes from the tensors it receives, so a mismatch between
        these attributes and the data cannot silently scramble the loss.
        """
        super().__init__()
        self.S = S
        self.B = B
        self.lambda_coord = lambda_coord
        self.lambda_noobj = lambda_noobj
        # reduction="sum": every term is a sum (not a mean) of squared errors.
        self.mse = nn.MSELoss(reduction="sum")

    def forward(self, predictions, ground_truth):
        """Compute the loss between predictions and ground truth.

        Args:
            predictions: Tensor of shape (N, S, S, B*5).
            ground_truth: Tensor of the same shape; a box slot is considered
                occupied when its confidence entry is > 0.

        Returns:
            Scalar tensor: ``lambda_coord * localization + object confidence
            + lambda_noobj * no-object confidence``.

        Raises:
            ValueError: if the two tensors differ in shape or the last
                dimension is not a multiple of 5.
        """
        if predictions.shape != ground_truth.shape:
            raise ValueError(
                f"predictions {tuple(predictions.shape)} and ground_truth "
                f"{tuple(ground_truth.shape)} must have the same shape"
            )
        if predictions.shape[-1] % 5 != 0:
            raise ValueError("last dimension must be a multiple of 5 (cx, cy, w, h, confidence per box)")

        # The loss is a plain sum over boxes, so flatten to one row per box.
        # Because each box's 5 values are contiguous in the last dimension,
        # this keeps every box intact for any B.
        pred = predictions.reshape(-1, 5)
        true = ground_truth.reshape(-1, 5)

        pred_boxes, pred_confidence = pred[:, :4], pred[:, 4]
        true_boxes, true_confidence = true[:, :4], true[:, 4]

        has_object = true_confidence > 0
        no_object = true_confidence == 0

        # Localization loss: only box slots that contain an object.
        localization_loss = self.mse(pred_boxes[has_object], true_boxes[has_object])

        # Confidence loss, split so empty slots can be down-weighted.
        confidence_loss_object = self.mse(pred_confidence[has_object], true_confidence[has_object])
        confidence_loss_noobject = self.mse(pred_confidence[no_object], true_confidence[no_object])

        total_loss = (
            self.lambda_coord * localization_loss
            + confidence_loss_object
            + self.lambda_noobj * confidence_loss_noobject
        )

        return total_loss

In [None]:
# SECURITY: never hard-code the Weights & Biases API key in the notebook.
# The key that was previously committed on this line must be treated as leaked
# and revoked in the W&B account settings.
# The key is read from the WANDB_API_KEY environment variable; when it is not
# set, key=None lets wandb fall back to its own credential lookup / prompt.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
# Training hyper-parameters (also recorded in the W&B run config below).
BATCH_SIZE = 8
EPOCHS = 2
LEARNING_RATE = 1e-4

# Close any run left over from a previous execution of this cell, then start
# a fresh one.
wandb.finish()
wandb.init(
    project="YOLOv1 on GDIT",
    config=dict(
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        learning_rate=LEARNING_RATE,
    ),
    reinit=True,
)

In [None]:
if __name__ == '__main__':      # Prevent recursive subprocess creation

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Anomaly detection reports the op that produced a NaN in backward;
    # it slows training, so disable it once the run is stable.
    torch.autograd.set_detect_anomaly(True)

    model = YOLOv1()
    model.to(device)
    # Pass the grid configuration explicitly: the loss defaults to B=1 while
    # the model and dataset use the module-level B.
    loss_function = SumSquaredErrorLoss(S=S, B=B)

    # Adam was found to work better than SGD with momentum here.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=LEARNING_RATE
    )

    # Load the dataset. Only the training split is augmented; validation must
    # be deterministic so that losses are comparable between epochs.
    train_set = GditDataset('train', normalize=normalize, augment=augment)
    val_set = GditDataset('valid', normalize=normalize, augment=None)

    train_loader = DataLoader(
        train_set,
        batch_size=BATCH_SIZE,
        shuffle=True
    )
    val_loader = DataLoader(
        val_set,
        batch_size=BATCH_SIZE,
        drop_last=True
    )

    # Checkpoint location. The directory is derived from the save path so the
    # folder that gets created is the one torch.save() writes into, whatever
    # the current working directory is.
    timestamp = datetime.now().strftime('%M-%H-%d-%m-%Y')
    save_path = f'/kaggle/working/checkpoints/yolov1-{timestamp}'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    #####################
    #       Train       #
    #####################

    max_grad_norm = 1.0  # Maximum allowed gradient norm
    epoch_bar = tqdm(total=EPOCHS, desc="Total Progress")
    # Start from +inf: the loss is a *sum* of squared errors and can exceed any
    # fixed starting value, which would prevent a checkpoint from ever being saved.
    best_val_loss = float('inf')

    for epoch in range(EPOCHS):
        # Training loop
        model.train()
        train_loss = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_function(outputs, labels)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            optimizer.step()

            train_loss += loss.item()  # Summed (not averaged) loss of this batch

        train_loss /= len(train_loader)  # Mean summed-loss per batch

        # Validation loop
        val_loss = 0
        model.eval()
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)

                outputs = model(images)
                loss = loss_function(outputs, labels)
                val_loss += loss.item()

        # drop_last=True yields zero batches when the split is smaller than
        # BATCH_SIZE; guard the division in that case.
        val_loss /= max(len(val_loader), 1)

        print(f"Epoch [{epoch+1}/{EPOCHS}], Train Loss: {train_loss:.10f}, Val Loss: {val_loss:.10f}")

        # Keep only the best model seen so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': val_loss,
            }
            torch.save(checkpoint, save_path)

        epoch_bar.update(1)
        # `_is_finished` is a private wandb attribute; read it defensively so a
        # wandb release without it does not crash the training loop.
        if wandb.run is not None and not getattr(wandb.run, '_is_finished', False):
            wandb.log({'Val_loss': val_loss, 'Train_loss': train_loss})
    epoch_bar.close()


In [None]:
# Marker cell: prints a fixed string and has no effect on the model or data.
print("Dang beo")

In [None]:
def custom_collate_fn(batch):
    """Collate samples whose label tensors differ in length.

    Args:
        batch: Sequence of ``(image, labels)`` pairs; all images must share
            one shape, labels may have any shape.

    Returns:
        ``(images, labels)`` where ``images`` is a single tensor stacked along
        a new leading dimension and ``labels`` is a plain list holding each
        sample's label tensor unchanged.
    """
    # Images are uniform, so they can be stacked into one batch tensor.
    images = torch.stack([image for image, _ in batch], dim=0)

    # Labels stay a list because the number of boxes varies per image.
    labels = [label for _, label in batch]

    return images, labels



class GditTestDataset(Dataset):
    """Dataset yielding ``(image, boxes)`` with the raw list of ground-truth boxes.

    Unlike ``GditDataset`` the labels are not binned into a grid: ``boxes`` is
    a ``(num_boxes, 4)`` tensor of ``[cx, cy, w, h]`` exactly as read from the
    label file (class id dropped), so no object is lost when several fall in
    the same cell.
    """

    def __init__(self, set_type, normalize=None, augment=None, target_transform=None):
        """
        Args:
            set_type: Sub-directory of the dataset to read (e.g. 'test').
            normalize: Optional callable applied to the image tensor after augmentation.
            augment: Optional callable applied to the image tensor before normalisation.
            target_transform: Optional callable ``(labels, img_w, img_h) -> labels``.
        """
        self.set_type = set_type
        self.normalize = normalize
        self.augment = augment
        self.target_transform = target_transform
        # Define directories for images and labels based on set type
        self.images_dir = f"/kaggle/input/gdit-dataset/Dataset/{set_type}/images"
        self.labels_dir = f"/kaggle/input/gdit-dataset/Dataset/{set_type}/labels"

        # os.listdir() order is arbitrary; sort for a deterministic index -> file mapping.
        self.image_files = sorted(f for f in os.listdir(self.images_dir) if f.endswith('.jpg'))

    def __len__(self):
        """Number of ``.jpg`` images found in the image directory."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the resized image and a ``(num_boxes, 4)`` tensor of its boxes."""
        file_name = self.image_files[idx]
        image = read_image(os.path.join(self.images_dir, file_name))
        # Resize like GditDataset does, so images can be stacked into a batch
        # and pixel-space boxes share the scale of the model input.
        # T.Resize takes (height, width).
        image = T.Resize((img_height, img_width))(image)

        # Same base name as the image, with a .txt extension.
        label_file = os.path.join(self.labels_dir, os.path.splitext(file_name)[0] + '.txt')

        rows = []
        with open(label_file, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank / trailing empty lines.
                    continue
                # The first value (class id) is not used.
                _, cx_norm, cy_norm, w_norm, h_norm = map(float, fields)
                rows.append([cx_norm, cy_norm, w_norm, h_norm])

        # Keep a well-formed (0, 4) tensor for images without any object.
        labels = torch.tensor(rows) if rows else torch.empty(0, 4)

        # Apply optional transformations
        if self.augment:
            image = self.augment(image)
        if self.normalize:
            image = self.normalize(image)

        # Image tensors are laid out (channels, height, width).
        _, img_h, img_w = image.shape
        if self.target_transform:
            labels = self.target_transform(labels, img_w, img_h)

        return image, labels

In [None]:
def compute_iou(box1, box2):
    """Compute the intersection-over-union of two axis-aligned boxes.

    Args:
        box1: Tensor of shape (4), format [x1, y1, x2, y2].
        box2: Tensor of shape (4), format [x1, y1, x2, y2].

    Returns:
        0-dim tensor with the IoU. When the union has no area (e.g. both boxes
        are degenerate) a zero tensor is returned, so the return type is a
        tensor on every path.
    """
    # Corners of the overlap rectangle
    x1 = torch.max(box1[0], box2[0])
    y1 = torch.max(box1[1], box2[1])
    x2 = torch.min(box1[2], box2[2])
    y2 = torch.min(box1[3], box2[3])

    # Clamp at zero: disjoint boxes give a negative extent on some axis.
    inter_area = torch.clamp(x2 - x1, min=0) * torch.clamp(y2 - y1, min=0)

    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union_area = box1_area + box2_area - inter_area

    if union_area > 0:
        return inter_area / union_area
    # Avoid dividing by zero; keep the result a tensor like the normal path.
    return torch.zeros_like(union_area)

def non_max_suppression(boxes, iou_threshold=0.5):
    """Greedy non-maximum suppression.

    Args:
        boxes: List of boxes, each ``[x1, y1, x2, y2, confidence]``.
        iou_threshold: A box is discarded when its IoU with an already kept,
            more confident box is greater than or equal to this value.

    Returns:
        The kept boxes, ordered from most to least confident.
    """
    if len(boxes) == 0:
        return []

    # Most confident candidates first.
    remaining = sorted(boxes, key=lambda candidate: candidate[4], reverse=True)

    kept = []
    while remaining:
        # The head of the list is the best surviving candidate: keep it.
        best, *others = remaining
        kept.append(best)

        # Only candidates that overlap the kept box little enough survive.
        best_corners = torch.tensor(best[:4])
        remaining = [
            candidate for candidate in others
            if compute_iou(best_corners, torch.tensor(candidate[:4])) < iou_threshold
        ]

    return kept


In [None]:
def calculate_ap(recalls, precisions):
    """Area under the precision-recall curve (all-point interpolation).

    Args:
        recalls: 1-D tensor of recall values, one per ranked prediction.
        precisions: 1-D tensor of precision values, same length.

    Returns:
        0-dim tensor holding the average precision.
    """
    zero = torch.tensor([0.0])
    one = torch.tensor([1.0])

    # Pad so the curve starts at recall 0 and ends at recall 1.
    recalls = torch.cat([zero, recalls, one])
    precisions = torch.cat([zero, precisions, zero])

    # Sweep right-to-left so each precision becomes the maximum of everything
    # to its right, i.e. the curve becomes monotonically non-increasing.
    for right in reversed(range(1, len(precisions))):
        precisions[right - 1] = torch.max(precisions[right - 1], precisions[right])

    # Sum rectangle areas at the positions where recall actually changes.
    step_starts = torch.where(recalls[1:] != recalls[:-1])[0]
    widths = recalls[step_starts + 1] - recalls[step_starts]
    heights = precisions[step_starts + 1]
    return torch.sum(widths * heights)

def mean_average_precision(pred_boxes, gt_boxes, iou_threshold=0.5):
    """Average precision for a single class at one IoU threshold.

    Args:
        pred_boxes: List of predictions ``[image_id, x1, y1, x2, y2, confidence, ...]``.
            The list is not modified.
        gt_boxes: List of ground-truth boxes ``[image_id, x1, y1, x2, y2, ...]``.
        iou_threshold: A prediction is a true positive when its best-overlapping
            ground-truth box in the same image has IoU above this value and has
            not already been claimed by a more confident prediction.

    Returns:
        0-dim tensor with the AP. Returns 0.0 when there is no ground truth,
        since recall is undefined in that case.
    """
    # Rank on a copy so the caller's list keeps its order.
    ranked = sorted(pred_boxes, key=lambda box: box[5], reverse=True)

    if len(gt_boxes) == 0:
        # Avoid dividing by zero below (which would yield a NaN score).
        return torch.tensor(0.0)

    tp = torch.zeros(len(ranked))
    fp = torch.zeros(len(ranked))

    # Build the ground-truth tensors once instead of once per prediction.
    gt_entries = [(gt_box[0], torch.tensor(gt_box[1:5])) for gt_box in gt_boxes]

    # Indices (into gt_boxes) of ground-truth boxes already matched.
    matched_gt = set()

    for i, pred_box in enumerate(ranked):
        image_id = pred_box[0]
        pred_corners = torch.tensor(pred_box[1:5])
        best_iou = 0
        best_gt_idx = None

        # Best-overlapping ground truth within the same image.
        for j, (gt_image_id, gt_corners) in enumerate(gt_entries):
            if image_id != gt_image_id:
                continue

            iou = compute_iou(pred_corners, gt_corners)
            if iou > best_iou:
                best_iou = iou
                best_gt_idx = j

        if best_iou > iou_threshold and best_gt_idx not in matched_gt:
            tp[i] = 1
            matched_gt.add(best_gt_idx)
        else:
            # Too little overlap, or a duplicate detection of a matched object.
            fp[i] = 1

    tp_cumsum = torch.cumsum(tp, dim=0)
    fp_cumsum = torch.cumsum(fp, dim=0)

    recalls = tp_cumsum / len(gt_boxes)
    # Every ranked prediction is either TP or FP, so the denominator is >= 1.
    precisions = tp_cumsum / (tp_cumsum + fp_cumsum)

    return calculate_ap(recalls, precisions)


In [None]:
def evaluate_model(model, iou_threshold=0.5, nms_threshold=0.5):
    """Evaluate the detector on the 'test' split and return its AP.

    Args:
        model: Trained YOLO model returning ``(N, S, S, B*5)`` predictions.
        iou_threshold: IoU threshold for a prediction to count as correct.
        nms_threshold: IoU threshold used by non-maximum suppression.

    Returns:
        AP score as returned by ``mean_average_precision``.
    """
    model.eval()
    # Run on whatever device the model lives on instead of relying on a
    # notebook-level `device` variable.
    device = next(model.parameters()).device

    pred_boxes = []
    true_boxes = []

    # Model inputs come from the grid dataset (resized + normalised images).
    yolo_test_set = GditDataset('test', normalize=normalize, augment=None)
    # The raw box list comes from the test dataset; only its labels are used
    # here (each lookup also loads the image, which is discarded).
    test_set = GditTestDataset('test', normalize=None, augment=None)
    # Force both datasets to enumerate files in the same order so that index i
    # refers to the same image in each of them.
    test_set.image_files = list(yolo_test_set.image_files)

    # No shuffling and no drop_last: every test image is evaluated, in order.
    yolo_test_loader = DataLoader(yolo_test_set, batch_size=BATCH_SIZE)

    image_idx = 0  # unique id per image, shared by its predictions and ground truth
    with torch.no_grad():
        for images, _ in yolo_test_loader:
            predictions = model(images.to(device)).cpu()  # Shape: (N, S, S, B*5)
            height, width = images.shape[2], images.shape[3]

            for n in range(images.shape[0]):
                # Collect confident predictions as pixel-space corner boxes.
                candidates = []
                for i in range(predictions.shape[1]):  # S
                    for j in range(predictions.shape[2]):  # S
                        for b in range(predictions.shape[3] // 5):  # B
                            cx, cy, w, h, confidence = predictions[n, i, j, b * 5:b * 5 + 5].tolist()
                            if confidence > 0.5:  # Confidence threshold
                                candidates.append([
                                    (cx - w / 2) * width,
                                    (cy - h / 2) * height,
                                    (cx + w / 2) * width,
                                    (cy + h / 2) * height,
                                    confidence,
                                ])

                for box in non_max_suppression(candidates, nms_threshold):
                    pred_boxes.append([image_idx] + box)

                # Ground truth of the same image, scaled to the same pixel
                # space as the predictions.
                _, gt = test_set[image_idx]
                for cx, cy, w, h in gt[:, :4].tolist():
                    true_boxes.append([
                        image_idx,
                        (cx - w / 2) * width,
                        (cy - h / 2) * height,
                        (cx + w / 2) * width,
                        (cy + h / 2) * height,
                    ])

                image_idx += 1

    return mean_average_precision(pred_boxes, true_boxes, iou_threshold)


In [None]:
# Evaluate the in-memory `model` on the test split and report the score.
# The training cell must have been run first so that `model` exists.
map_score = evaluate_model(model)
print(f"mAP Score: {map_score}")

In [None]:
# def plot_test_images(MODEL_DIR):
#     model.eval()
#     test_set = GditDataset('test', normalize=normalize, augment=False)
#     test_loader = DataLoader(test_set, batch_size=8, shuffle=True)

#     model = YOLOv1()
#     model.eval()
#     checkpoint=""
#     model.load_state_dict(torch.load(checkpoint))

#     with torch.no_grad():
#         test_loss = 0
#         for images, labels in test_loader:
#             images = images.to(device)
#             labels = labels.to(device)
            
#             labels = labels.squeeze(dim=1).long()
#             outputs = model(images)

#             test_loss += loss_function(outputs, labels).item()

#     print(f'Test_loss : {test_loss / len(test_loader)}')

In [None]:
# plot_test_images(MODEL_DIR='models/yolo_v1/08_19_2022/08_42_58')