In [None]:
import os
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.datasets import CocoDetection
from torch.utils.data import DataLoader, Subset
from torchvision import transforms as T
import torchvision.transforms.functional as F
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random
from PIL import Image
from tqdm import tqdm
import cv2

# Visualize random training images

In [None]:
# Root of the dataset; image folders such as "train2017" live directly beneath it.
root_folder = ""
# Pass the parts separately so os.path.join inserts a separator when root_folder is set.
train_folder = os.path.join(root_folder, "train2017")
image_files = [
    name for name in os.listdir(train_folder)
    if name.lower().endswith(('.jpg', '.jpeg', '.png'))
]

# Show up to 6 random training images in a 2x3 grid
# (random.sample raises ValueError if asked for more items than exist).
num_samples = min(6, len(image_files))
random_images = random.sample(image_files, num_samples)

fig, axes = plt.subplots(2, 3, figsize=(15, 5))
# Hide every panel's axes up front so unused panels stay blank.
for ax in axes.flat:
    ax.axis('off')

for ax, img_name in zip(axes.flat, random_images):
    img = cv2.imread(os.path.join(train_folder, img_name))
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        ax.set_title(f"{img_name} (unreadable)")
        continue
    # OpenCV loads images as BGR; matplotlib expects RGB.
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ax.set_title(img_name)

plt.tight_layout()
plt.show()


# Prepare dataset

In [None]:
class CocoDetectionDataset(CocoDetection):
    """COCO detection dataset yielding ``(image_tensor, target)`` pairs for torchvision detectors.

    ``target`` is a dict with:

    * ``boxes``    -- float32 ``[N, 4]`` in ``(x_min, y_min, x_max, y_max)`` pixels,
      clipped to the image; boxes with zero width or height are dropped.
    * ``labels``   -- int64 ``[N]`` raw COCO ``category_id`` values.
    * ``image_id`` -- ``[1]`` tensor holding the COCO image id.
    * ``area``     -- float32 ``[N]`` box areas computed from ``boxes``.
    * ``iscrowd``  -- int64 ``[N]`` the annotation's ``iscrowd`` flag (0 if absent).

    Args:
        img_folder: Directory containing the image files.
        ann_file: Path to the COCO-format annotation JSON.
        train: If true, apply a random horizontal flip (p=0.5) to the image
            and its boxes together.
    """

    def __init__(self, img_folder, ann_file, train=False):
        super().__init__(img_folder, ann_file)
        self.train = train

    @staticmethod
    def _format_target(annotations, image_id, image_width, image_height):
        """Convert a list of raw COCO annotation dicts into a detection target dict."""
        # reshape(-1, 4) keeps the shape [0, 4] when the image has no annotations.
        boxes = torch.as_tensor(
            [obj["bbox"] for obj in annotations], dtype=torch.float32
        ).reshape(-1, 4)
        # COCO stores [x, y, width, height]; convert to [x_min, y_min, x_max, y_max].
        boxes[:, 2:] += boxes[:, :2]
        # Clip to the image so no box extends past its borders.
        boxes[:, 0::2].clamp_(min=0, max=image_width)
        boxes[:, 1::2].clamp_(min=0, max=image_height)

        labels = torch.as_tensor(
            [obj["category_id"] for obj in annotations], dtype=torch.int64
        )
        iscrowd = torch.as_tensor(
            [obj.get("iscrowd", 0) for obj in annotations], dtype=torch.int64
        )

        # Drop degenerate (zero-width or zero-height) boxes: torchvision detection
        # models raise on them in training mode.
        keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
        boxes, labels, iscrowd = boxes[keep], labels[keep], iscrowd[keep]

        area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

        return {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([image_id]),
            "area": area,
            "iscrowd": iscrowd,
        }

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for the ``idx``-th image."""
        img, annotations = super().__getitem__(idx)

        # PIL reports (width, height); read it before converting to a tensor.
        image_width, image_height = img.size

        target = self._format_target(
            annotations, self.ids[idx], image_width, image_height
        )

        # Training-time augmentation: horizontal flip with probability 0.5.
        if self.train and random.random() < 0.5:
            img = F.hflip(img)
            boxes = target["boxes"]
            # new x_min = W - old x_max, new x_max = W - old x_min.
            # The right-hand side indexes a copy, so the swap is safe.
            boxes[:, [0, 2]] = image_width - boxes[:, [2, 0]]

        return F.to_tensor(img), target

In [None]:
def get_transform(train):
    """Build an image-only torchvision transform pipeline.

    NOTE: ``T.RandomHorizontalFlip`` flips the image only and leaves bounding
    boxes untouched, so this pipeline must not be used for detection targets.
    ``CocoDetectionDataset`` performs its own box-aware flip and tensor
    conversion. This function is not called anywhere in the visible code.

    Args:
        train: If truthy, include a random horizontal flip (p=0.5).

    Returns:
        A ``T.Compose`` pipeline when ``train`` is truthy, otherwise ``None``.
        The pipeline does NOT convert the image to a tensor.
    """
    transforms = [T.RandomHorizontalFlip(0.5)] if train else []
    return T.Compose(transforms) if transforms else None


# --- 3. Model Definition ---
def get_model(num_classes, weights="DEFAULT"):
    """Build a Faster R-CNN (ResNet-50 FPN) with a freshly initialised box predictor.

    Args:
        num_classes: Number of output classes, including the background class.
        weights: Passed straight to ``fasterrcnn_resnet50_fpn``. ``"DEFAULT"``
            loads torchvision's default COCO-pretrained checkpoint; ``None``
            builds the network without pretrained detection weights.

    Returns:
        The detection model, with ``roi_heads.box_predictor`` replaced by an
        untrained ``FastRCNNPredictor`` sized for ``num_classes``.
    """
    # `weights` selects the pretrained checkpoint. The former `pretrain=True`
    # argument was dropped: it is not a parameter of this builder (the legacy
    # flag is spelled `pretrained`) and had no effect.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)

    # Keep the feature width feeding the classifier, then swap in a new head
    # with the requested number of classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

# --- 4. Collate Function for Dataloader ---
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    ``[(img0, tgt0), (img1, tgt1), ...]`` becomes
    ``((img0, img1, ...), (tgt0, tgt1, ...))``. Nothing is stacked, so
    variable-size images and dict targets can share a batch. An empty batch
    yields ``()``.
    """
    transposed = zip(*batch)
    return tuple(transposed)

# Training

In [None]:
def _train_one_epoch(model, data_loader, optimizer, device, desc):
    """Run one optimisation pass over ``data_loader``.

    Returns the mean, over batches, of the summed detection losses.
    """
    model.train()
    total_loss = 0.0
    loop = tqdm(data_loader, desc=desc)

    for images, targets in loop:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        current_loss = losses.item()
        total_loss += current_loss
        # Show the latest batch loss on the progress bar.
        loop.set_postfix(loss=f"{current_loss:.4f}")

    return total_loss / max(len(data_loader), 1)


def _evaluate_loss(model, data_loader, device, desc):
    """Return the mean summed detection loss over ``data_loader`` without updating weights.

    torchvision detection models return a loss dict only in training mode. In
    eval mode, ``model(images, targets)`` returns a list of detections instead.
    The model is therefore put in train mode here, with gradients disabled and
    no optimizer step.
    """
    # NOTE(review): this assumes a train-mode forward pass changes no state.
    # With pretrained weights, torchvision builds this backbone with frozen
    # batch norm. Re-check if the model is ever built with weights=None.
    model.train()
    total_loss = 0.0
    loop = tqdm(data_loader, desc=desc)

    with torch.no_grad():
        for images, targets in loop:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            loss_dict_val = model(images, targets)
            losses_val = sum(loss for loss in loss_dict_val.values())

            current_val_loss = losses_val.item()
            total_loss += current_val_loss
            # Show the latest batch val_loss on the progress bar.
            loop.set_postfix(val_loss=f"{current_val_loss:.4f}")

    return total_loss / max(len(data_loader), 1)


if __name__ == "__main__":
    # Seed Python's and torch's RNGs so shuffling and flip augmentation are reproducible.
    SEED = 42
    random.seed(SEED)
    torch.manual_seed(SEED)

    # Standard COCO 2017 layout; adjust if the annotation files live elsewhere.
    train_img = os.path.join(root_folder, 'train2017')
    train_ann = os.path.join(root_folder, 'annotations', 'instances_train2017.json')

    val_img = os.path.join(root_folder, 'val2017')
    val_ann = os.path.join(root_folder, 'annotations', 'instances_val2017.json')

    dataset = CocoDetectionDataset(train_img, train_ann, train=True)
    dataset_val = CocoDetectionDataset(val_img, val_ann, train=False)

    data_loader = DataLoader(
        dataset, batch_size=8, shuffle=True, num_workers=4, collate_fn=collate_fn
    )
    data_loader_val = DataLoader(
        dataset_val, batch_size=4, shuffle=False, num_workers=4, collate_fn=collate_fn
    )

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Labels are raw COCO category_ids and label 0 is background, so the head
    # needs max(category_id) + 1 outputs (91 for COCO 2017, whose max id is 90).
    num_classes = max(dataset.coco.getCatIds()) + 1
    model = get_model(num_classes)
    model.to(device)

    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=0.0005)

    num_epochs = 30
    patience = 5
    best_val_loss = float('inf')
    # Must exist before the first epoch, in case the first val loss does not improve (e.g. NaN).
    epochs_no_improve = 0
    best_model_path = '/kaggle/working/fasterrcnn_best_model.pth'

    print("Bắt đầu quá trình training...")
    for epoch in range(num_epochs):
        avg_train_loss = _train_one_epoch(
            model, data_loader, optimizer, device,
            desc=f"Epoch {epoch+1}/{num_epochs} [Training]",
        )
        avg_val_loss = _evaluate_loss(
            model, data_loader_val, device,
            desc=f"Epoch {epoch+1}/{num_epochs} [Validation]",
        )

        # Per-epoch averages.
        print(f"\nEpoch #{epoch+1} Summary: Avg Train Loss: {avg_train_loss:.4f} | Avg Val Loss: {avg_val_loss:.4f}")

        # Early stopping: keep the checkpoint with the lowest validation loss.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), best_model_path)
            print(f"-> Validation loss cải thiện. Đã lưu mô hình tốt nhất vào '{best_model_path}'\n")
        else:
            epochs_no_improve += 1
            print(f"-> Validation loss không cải thiện. Đã {epochs_no_improve}/{patience} epoch không cải thiện.\n")

        if epochs_no_improve >= patience:
            print(f"Early stopping! Dừng training vì validation loss không cải thiện trong {patience} epoch.")
            break

    print("\nHoàn tất training!")