# In [None]:
import os
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
from pycocotools.coco import COCO
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import numpy as np

import matplotlib.pyplot as plt

# Locations of the COCO 2017 image folders / annotation files, plus the
# directory where the trained weights and plots are written.
TRAIN_PATH = 'D:/Download/JDownloader/MSCOCO/images/train2017'
VAL_PATH = 'D:/Download/JDownloader/MSCOCO/images/val2017'
ANNOTATIONS_PATH = 'D:/Download/JDownloader/MSCOCO/annotations'
WORKING_DIR = 'D:/Projetos/Mestrado/2024_Topicos_Esp_Sist_Informacao/ARTIGO_FINAL/object_detection_model_compare/working'

# Categories to detect. Label 0 is reserved for background, so the detector
# head needs one output more than there are categories.
CATEGORIES = ['person', 'cat', 'dog']
NUM_CLASSES = 1 + len(CATEGORIES)
NUM_IMAGES_PER_CLASS = 1000

# Loop settings.
BATCH_SIZE, NUM_EPOCHS = 4, 10

# SGD settings.
LEARNING_RATE, MOMENTUM, WEIGHT_DECAY = 0.005, 0.9, 0.0005

# StepLR: the learning rate is multiplied by GAMMA every STEP_SIZE scheduler steps.
STEP_SIZE, GAMMA = 3, 0.1

# Data Preprocessing
class CocoDataset(torch.utils.data.Dataset):
    """Detection dataset over a subset of COCO categories.

    Each item is ``(image, target)``:

    * ``image`` is a float32 tensor of shape (3, H, W) with values in [0, 1].
    * ``target`` is a dict with:
      - ``'boxes'``: float32 tensor of shape (N, 4), ``[x_min, y_min, x_max, y_max]``.
      - ``'labels'``: int64 tensor of shape (N,), holding
        ``self.cat_ids.index(category_id) + 1``. 0 is left for background.

    Args:
        root: Directory containing the image files named in the annotation file.
        annotation: Path to a COCO instances JSON file.
        categories: COCO category names to keep.
        transforms: Optional callable ``(image, target) -> (image, target)``.
        max_images_per_class: Maximum number of images taken for each category.
            The dataset is the union of those per-category picks. Defaults to
            ``NUM_IMAGES_PER_CLASS``.
    """

    def __init__(self, root, annotation, categories, transforms=None,
                 max_images_per_class=None):
        self.root = root
        self.coco = COCO(annotation)
        self.transforms = transforms
        self.cat_ids = self.coco.getCatIds(catNms=categories)
        if max_images_per_class is None:
            max_images_per_class = NUM_IMAGES_PER_CLASS
        selected = set()
        for cat_id in self.cat_ids:
            # Sort before slicing so the chosen subset is reproducible. Capping
            # each category separately keeps a frequent category from crowding
            # out the rarer ones, which capping only the union would allow.
            cat_img_ids = sorted(self.coco.getImgIds(catIds=cat_id))
            selected.update(cat_img_ids[:max_images_per_class])
        self.img_ids = sorted(selected)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        path = os.path.join(self.root, img_info['file_name'])
        # Force 3 channels so grayscale files also give (3, H, W). Then convert
        # uint8 [0, 255] to float [0, 1].
        image = torchvision.io.read_image(path, mode=torchvision.io.ImageReadMode.RGB)
        image = F.convert_image_dtype(image, torch.float32)

        ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=self.cat_ids, iscrowd=None)
        boxes = []
        labels = []
        for ann in self.coco.loadAnns(ann_ids):
            x_min, y_min, width, height = ann['bbox']
            # torchvision detection models raise on boxes with non-positive
            # width/height, and annotation files can contain such boxes.
            if width <= 0 or height <= 0:
                continue
            boxes.append([x_min, y_min, x_min + width, y_min + height])
            labels.append(self.cat_ids.index(ann['category_id']) + 1)
        # reshape keeps the (N, 4) shape even when no box survives the filter.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        target = {'boxes': boxes, 'labels': labels}
        if self.transforms:
            image, target = self.transforms(image, target)
        return image, target

    def __len__(self):
        return len(self.img_ids)

# Load datasets
train_dataset = CocoDataset(TRAIN_PATH, os.path.join(ANNOTATIONS_PATH, 'instances_train2017.json'), CATEGORIES)
val_dataset = CocoDataset(VAL_PATH, os.path.join(ANNOTATIONS_PATH, 'instances_val2017.json'), CATEGORIES)

# Split the *training* images into train/validation subsets.
# - The indices are positions in train_dataset, so both Subsets must wrap
#   train_dataset. Applying val_indices to val_dataset (a different, shorter
#   image list) selects unrelated images and can raise IndexError.
# - val_dataset (val2017) is kept for separate evaluation.
train_indices, val_indices = train_test_split(list(range(len(train_dataset))), test_size=0.2, random_state=42)
train_subset = Subset(train_dataset, train_indices)
val_subset = Subset(train_dataset, val_indices)


def collate_fn(batch):
    """Turn a list of ``(image, target)`` pairs into ``(images, targets)`` tuples.

    Detection images differ in size, so they are kept as a tuple instead of
    being stacked. A named module-level function, unlike a lambda, can be
    pickled when DataLoader workers are started with the 'spawn' method.
    """
    return tuple(zip(*batch))


# On Windows, worker processes are spawned and must re-import the dataset
# class and collate function. Definitions made inside a notebook usually
# cannot be found there, so load in-process on Windows.
NUM_WORKERS = 0 if os.name == 'nt' else 4

# Data loaders
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, collate_fn=collate_fn)
val_loader = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=collate_fn)

# Model Setup: a pretrained Faster R-CNN (ResNet-50 + FPN). Its box predictor
# is swapped for a fresh head with NUM_CLASSES outputs.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`. Confirm against the installed version.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, NUM_CLASSES)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Optimizer and learning rate scheduler: SGD over trainable parameters only.
# StepLR decays the LR by GAMMA every STEP_SIZE calls to lr_scheduler.step().
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

# Training function
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10):
    """Run one training epoch and return the mean total loss per batch.

    Args:
        model: Model that, in train mode, is called as ``model(images, targets)``
            and returns a dict of scalar loss tensors.
        optimizer: Optimizer stepped once per batch.
        data_loader: Iterable of ``(images, targets)`` batches. ``images`` is a
            sequence of tensors; ``targets`` is a sequence of dicts of tensors.
        device: Device each batch is moved to before the forward pass.
        epoch: Epoch index; used only in the non-finite-loss error message.
        print_freq: Unused; kept for interface compatibility.

    Returns:
        The sum of all loss terms averaged over batches, or ``nan`` when
        ``data_loader`` yields no batches.

    Raises:
        FloatingPointError: If the total loss is NaN or infinite. Stepping on
            such a loss would corrupt the weights.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for images, targets in data_loader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        if not torch.isfinite(losses):
            details = {k: v.item() for k, v in loss_dict.items()}
            raise FloatingPointError(f"Non-finite training loss in epoch {epoch}: {details}")
        total_loss += losses.item()
        num_batches += 1
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
    # Count batches while iterating so an empty loader returns nan instead of
    # dividing by zero.
    return total_loss / num_batches if num_batches else float('nan')

# Validation function
def evaluate(model, data_loader, device):
    """Return the mean total validation loss per batch without updating weights.

    torchvision detection models return their loss dict only in training mode;
    in eval mode they return a list of predictions. The losses are therefore
    computed with the model temporarily in train mode under ``torch.no_grad()``.
    The caller's train/eval mode is restored afterwards.

    NOTE(review): this matches "eval" behaviour only if the model has no layers
    that act differently in train mode (e.g. dropout, or batch norm updating
    running stats). Confirm for the backbone in use.

    Args:
        model: Model that, in train mode, is called as ``model(images, targets)``
            and returns a dict of scalar loss tensors.
        data_loader: Iterable of ``(images, targets)`` batches.
        device: Device each batch is moved to.

    Returns:
        The sum of all loss terms averaged over batches, or ``nan`` when
        ``data_loader`` yields no batches.
    """
    was_training = model.training
    model.train()
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for images, targets in data_loader:
                images = [image.to(device) for image in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())
                total_loss += losses.item()
                num_batches += 1
    finally:
        model.train(was_training)
    return total_loss / num_batches if num_batches else float('nan')

# Training loop
# Per-epoch history of mean train/validation losses (used for plotting).
train_losses, val_losses = [], []
for epoch in range(NUM_EPOCHS):
    train_loss = train_one_epoch(model, optimizer, train_loader, device, epoch)
    val_loss = evaluate(model, val_loader, device)
    # The scheduler advances once per epoch, so STEP_SIZE is counted in epochs.
    lr_scheduler.step()
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

# Save the trained model's parameters (state_dict only).
# Create WORKING_DIR first: if it is missing, torch.save fails, and the
# result of the whole training run is lost.
os.makedirs(WORKING_DIR, exist_ok=True)
torch.save(model.state_dict(), os.path.join(WORKING_DIR, 'fasterrcnn_model.pth'))

# Plotting functions
def plot_losses(train_losses, val_losses, save_path):
    """Plot per-epoch training and validation losses and save the figure.

    Args:
        train_losses: Mean training loss for each epoch, in order.
        val_losses: Mean validation loss for each epoch, in order.
        save_path: Output image path, passed to ``savefig``.
    """
    fig, ax = plt.subplots()
    try:
        # Number epochs from 1 so the axis matches the "Epoch k/N" log lines.
        # Markers keep a single-epoch run visible.
        ax.plot(range(1, len(train_losses) + 1), train_losses, marker='o', label='Train Loss')
        ax.plot(range(1, len(val_losses) + 1), val_losses, marker='o', label='Val Loss')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.set_title('Training and Validation Loss')
        fig.savefig(save_path)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)

# Write the loss curves next to the saved model weights.
plot_losses(
    train_losses,
    val_losses,
    save_path=os.path.join(WORKING_DIR, 'loss_plot.png'),
)

# Note: For ROC curve and mAP, additional code is required to calculate these metrics.
# This script focuses on training and saving the model and loss plots.

# Cell output:
# loading annotations into memory...
# Done (t=7.93s)
# creating index...
# index created!
# loading annotations into memory...
# Done (t=0.36s)
# creating index...
# index created!
