In [6]:
import os
import torch
import pandas as pd
import ast
import torchvision
import numpy as np
from datetime import datetime
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from torchvision.io import read_image
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
from torchvision.transforms import v2 as T
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.utils import draw_bounding_boxes
import utils
from engine import train_one_epoch, evaluate

def box_area(boxes):
    """
    Return the area of every box in ``boxes``.

    Each row is read as four values; the area of a row is
    (column 2 - column 0) * (column 3 - column 1).
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights

def box_iou(boxes1, boxes2):
    """
    Compute the pairwise intersection-over-union (Jaccard index) of two box sets.

    Args:
        boxes1 (Tensor[N, 4]): first set of boxes
        boxes2 (Tensor[M, 4]): second set of boxes

    Returns:
        iou (Tensor[N, M]): entry (i, j) is the IoU of boxes1[i] and boxes2[j]
    """
    # A middle axis on boxes1 broadcasts each of its boxes against all of boxes2.
    expanded = boxes1[:, None, :]

    # Corners of the overlap rectangle for every (boxes1, boxes2) pair
    top_left = torch.max(expanded[..., :2], boxes2[:, :2])
    bottom_right = torch.min(expanded[..., 2:], boxes2[:, 2:])

    # A negative extent means the pair does not overlap, so clamp it to zero
    extent = (bottom_right - top_left).clamp(min=0)
    overlap = extent[..., 0] * extent[..., 1]

    # union = area(first) + area(second) - overlap
    total = box_area(boxes1)[:, None] + box_area(boxes2) - overlap

    return overlap / total

def save_model(model, working_dir):
    """
    Save the model's ``state_dict`` to a timestamped ``.pth`` file.

    Args:
        model: object exposing ``state_dict()``; its result is what gets
            serialized with ``torch.save``.
        working_dir: directory that receives the file. It is created
            (including parents) when it does not exist yet, so a missing
            output folder cannot make the save fail after training.

    Returns:
        The path of the written file, ``<working_dir>/model_<YYYYmmdd_HHMMSS>.pth``.
    """
    os.makedirs(working_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_path = os.path.join(working_dir, f'model_{timestamp}.pth')
    torch.save(model.state_dict(), model_path)
    return model_path

def plot_metrics(metrics, metric_name, working_dir):
    """
    Plot a sequence of metric values and save the figure as a PNG.

    Args:
        metrics: sequence of values, plotted against their index (one value
            per epoch in this script).
        metric_name: label used in the title and y-axis; its lower-cased form
            is the file-name prefix.
        working_dir: directory that receives the image. Created if missing.

    Returns:
        The path of the saved image,
        ``<working_dir>/<metric_name.lower()>_<YYYYmmdd_HHMMSS>.png``.
    """
    os.makedirs(working_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    plot_path = os.path.join(working_dir, f'{metric_name.lower()}_{timestamp}.png')

    plt.figure(figsize=(10, 6))
    try:
        plt.plot(metrics)
        plt.title(f'Model {metric_name} over epochs')
        plt.xlabel('Epochs')
        plt.ylabel(metric_name)
        plt.grid(True)
        plt.savefig(plot_path)
    finally:
        # Always release the figure, even when plotting or saving fails
        plt.close()
    return plot_path

def plot_roc_curve(model, data_loader_test, device, working_dir, num_classes):
    """
    Plot one ROC curve per foreground class and save the figure as a PNG.

    The samples fed to ``roc_curve`` are built per image and per class:

    * every ground-truth instance of the class adds a positive sample whose
      score is the highest confidence predicted for that class in the image
      (0.0 when the class was not predicted at all);
    * every detection predicted as a *different* class adds a negative sample
      carrying that detection's own confidence.

    No box matching (IoU) is involved, so this is a label-level summary rather
    than a localisation-aware evaluation.

    Args:
        model: detection model; called with a list of image tensors and
            expected to return one dict per image with 'scores' and 'labels'.
        data_loader_test: iterable yielding ``(images, targets)`` batches,
            where each target has a 'labels' tensor.
        device: device the images are moved to before inference.
        working_dir: directory that receives the image. Created if missing.
        num_classes: number of classes including background (class 0);
            classes ``1 .. num_classes - 1`` are plotted.

    Returns:
        The path of the saved image,
        ``<working_dir>/roc_curve_<YYYYmmdd_HHMMSS>.png``.
    """
    model.eval()
    foreground_classes = range(1, num_classes)  # class 0 is background
    class_scores = {class_id: [] for class_id in foreground_classes}
    class_labels = {class_id: [] for class_id in foreground_classes}

    with torch.no_grad():
        for images, targets in data_loader_test:
            images = [img.to(device) for img in images]
            predictions = model(images)

            for prediction, target in zip(predictions, targets):
                pred_scores = prediction['scores'].cpu().numpy()
                pred_labels = prediction['labels'].cpu().numpy()
                target_labels = target['labels'].cpu().numpy()

                for class_id in foreground_classes:
                    is_class = pred_labels == class_id
                    n_targets = int((target_labels == class_id).sum())

                    # Positives: one sample per ground-truth instance, scored
                    # with the best confidence predicted for this class
                    if n_targets:
                        best = float(pred_scores[is_class].max()) if is_class.any() else 0.0
                        class_labels[class_id].extend([1] * n_targets)
                        class_scores[class_id].extend([best] * n_targets)

                    # Negatives: every detection assigned to another class
                    negatives = pred_scores[~is_class]
                    class_labels[class_id].extend([0] * len(negatives))
                    class_scores[class_id].extend(negatives.tolist())

    # Names for the classes used in this notebook; other ids get a generic name
    class_names = {1: 'person', 2: 'cat', 3: 'dog'}

    os.makedirs(working_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    plot_path = os.path.join(working_dir, f'roc_curve_{timestamp}.png')

    plt.figure(figsize=(10, 8))
    try:
        any_curve = False
        for class_id in foreground_classes:
            labels = class_labels[class_id]
            # A ROC curve is only defined when both positives and negatives exist
            if len(set(labels)) < 2:
                continue
            fpr, tpr, _ = roc_curve(labels, class_scores[class_id])
            roc_auc = auc(fpr, tpr)
            name = class_names.get(class_id, f'class {class_id}')
            plt.plot(
                fpr,
                tpr,
                lw=2,
                label=f'{name} (AUC = {roc_auc:.2f})'
            )
            any_curve = True

        # Chance-level diagonal
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC) Curve per Class')
        if any_curve:
            plt.legend(loc="lower right")
        plt.grid(True)

        plt.savefig(plot_path)
    finally:
        # Always release the figure, even when plotting or saving fails
        plt.close()
    return plot_path

def calculate_map(model, data_loader_test, device, num_classes=None):
    """
    Score the model as the mean per-image, per-class recall at IoU > 0.5.

    NOTE: despite the name, this is not the COCO/PASCAL mean average
    precision. For every image and every class present in its ground truth,
    the value recorded is the fraction of ground-truth boxes of that class
    that overlap some same-class prediction with IoU > 0.5. The function
    returns the mean of those values. Prediction confidences do not
    influence the result.

    A class that is present in the ground truth but has no prediction of
    that class contributes 0.0, so missed detections lower the score instead
    of being ignored.

    Args:
        model: detection model; called with a list of image tensors and
            expected to return one dict per image with 'boxes' and 'labels'.
        data_loader_test: iterable yielding ``(images, targets)`` batches,
            where each target has 'boxes' and 'labels' tensors.
        device: device the images are moved to before inference.
        num_classes: optional number of classes including background. When
            given, only labels in ``1 .. num_classes - 1`` are evaluated;
            when omitted, every positive label found in the ground truth is.

    Returns:
        The mean score as a float, or 0.0 when nothing could be evaluated.
    """
    model.eval()
    recalls = []

    with torch.no_grad():
        for images, targets in data_loader_test:
            images = [img.to(device) for img in images]
            predictions = model(images)

            for prediction, target in zip(predictions, targets):
                pred_boxes = prediction['boxes'].cpu()
                pred_labels = prediction['labels'].cpu()

                target_boxes = target['boxes'].cpu()
                target_labels = target['labels'].cpu()

                if len(target_boxes) == 0:
                    continue  # nothing to recall in this image

                for class_id in torch.unique(target_labels).tolist():
                    if class_id < 1:
                        continue  # background is never scored
                    if num_classes is not None and class_id >= num_classes:
                        continue

                    gt_boxes = target_boxes[target_labels == class_id]
                    candidates = pred_boxes[pred_labels == class_id]

                    if len(candidates) == 0:
                        # Every ground-truth box of this class was missed
                        recalls.append(0.0)
                        continue

                    # Best-overlapping same-class prediction per ground-truth box
                    best_iou, _ = box_iou(candidates, gt_boxes).max(dim=0)
                    recalls.append((best_iou > 0.5).float().mean().item())

    return float(np.mean(recalls)) if recalls else 0.0

def get_transform(train):
    """
    Build the transform pipeline applied to each (image, target) sample.

    When ``train`` is truthy a random horizontal flip (p=0.5) comes first.
    In every case the pipeline then converts to float with value scaling
    and finishes with ``ToPureTensor``.
    """
    augmentations = [T.RandomHorizontalFlip(0.5)] if train else []
    return T.Compose(
        augmentations
        + [T.ToDtype(torch.float, scale=True), T.ToPureTensor()]
    )

class CocoSubsetDataset(torch.utils.data.Dataset):
    """
    Detection dataset backed by a CSV of annotations, one row per box.

    The CSV columns read are 'image_id', 'image' (file name relative to
    ``images_path``), 'label' (one of the keys of ``category_to_id``) and
    'bbox' (a string holding a list in COCO ``[x, y, width, height]`` form).
    One dataset item corresponds to one distinct 'image_id', in order of
    first appearance in the CSV.
    """

    def __init__(self, csv_path, images_path, transforms=None):
        """
        Args:
            csv_path: path of the annotation CSV.
            images_path: directory containing the image files.
            transforms: optional callable applied as
                ``transforms(img, target)`` and returning the transformed pair.
        """
        self.transforms = transforms
        self.images_path = images_path
        # Load CSV data
        self.df = pd.read_csv(csv_path)
        # Convert string representation of bbox to list
        self.df['bbox'] = self.df['bbox'].apply(ast.literal_eval)

        # Create category to id mapping (0 is left for the background class)
        self.category_to_id = {
            'person': 1,
            'cat': 2,
            'dog': 3
        }

        # Index the annotations once: image id -> row positions in self.df.
        # Without this, every __getitem__ would rescan the whole DataFrame.
        # Dict insertion order keeps image ids in order of first appearance.
        self._rows_by_image = {}
        for position, image_id in enumerate(self.df['image_id'].tolist()):
            self._rows_by_image.setdefault(image_id, []).append(position)
        self.image_ids = list(self._rows_by_image)

    def __getitem__(self, idx):
        """
        Return ``(img, target)`` for the ``idx``-th distinct image.

        ``target`` holds 'boxes' (XYXY bounding boxes), 'labels' (int64 class
        ids) and 'image_id' (a one-element tensor containing ``idx``).
        Annotations whose width or height is not positive are dropped.

        Raises:
            IndexError: if ``idx`` is out of range.
            KeyError: if an annotation's label is not in ``category_to_id``.
        """
        # Get all annotations for this image
        img_annots = self.df.iloc[self._rows_by_image[self.image_ids[idx]]]

        # Load image; decoding straight to RGB expands grayscale files and
        # drops any alpha channel, so the tensor always has 3 channels
        img_path = os.path.join(self.images_path, img_annots['image'].iloc[0])
        img = read_image(img_path, mode=torchvision.io.ImageReadMode.RGB)

        # Get boxes and labels
        boxes = []
        labels = []
        for bbox, label in zip(img_annots['bbox'], img_annots['label']):
            x, y, width, height = bbox[:4]
            if width <= 0 or height <= 0:
                # torchvision detection models reject degenerate boxes in training
                continue
            # Convert COCO bbox [x,y,width,height] to [x1,y1,x2,y2]
            boxes.append([x, y, x + width, y + height])
            labels.append(self.category_to_id[label])

        # Convert to tensor; reshape keeps the (0, 4) shape when no box is left
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        # Prepare target dict
        img = tv_tensors.Image(img)
        target = {}
        target["boxes"] = tv_tensors.BoundingBoxes(boxes, format="XYXY", canvas_size=F.get_size(img))
        target["labels"] = labels
        target["image_id"] = torch.tensor([idx])

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of distinct image ids in the CSV."""
        return len(self.image_ids)

def get_model_detection(num_classes):
    """
    Build a Faster R-CNN (ResNet-50 FPN) detector with a fresh box head.

    The model is created with ``weights="DEFAULT"`` and its box predictor is
    replaced by a new ``FastRCNNPredictor`` with ``num_classes`` outputs.
    The new head keeps the input feature size of the head it replaces.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

    # Swap the classification/regression head, preserving its input width
    roi_heads = detector.roi_heads
    feature_dim = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)

    return detector

# Training setup
TRAIN_IMAGES_PATH = 'D:/Download/JDownloader/MSCOCO/images/train2017'
FILTERED_PATH = 'D:/Projetos/pythonlib/filtered-coco-dataset'
WORKING_DIR = 'D:/Projetos/pythonlib/working'
CSV_PATH = os.path.join(FILTERED_PATH, 'filtered_coco.csv')

# Create the output directory up front, so a missing folder is handled before
# training starts rather than when the model and plots are written afterwards.
os.makedirs(WORKING_DIR, exist_ok=True)

# Number of classes (background + person + cat + dog)
num_classes = 4

# Create train and test datasets. Both read the same CSV and differ only in
# their transforms (random flip for training, none for testing).
dataset = CocoSubsetDataset(CSV_PATH, TRAIN_IMAGES_PATH, get_transform(train=True))
dataset_test = CocoSubsetDataset(CSV_PATH, TRAIN_IMAGES_PATH, get_transform(train=False))

# Split dataset: one shared random permutation keeps the two subsets disjoint
indices = torch.randperm(len(dataset)).tolist()
train_size = int(0.8 * len(dataset))  # 80% for training
dataset = torch.utils.data.Subset(dataset, indices[:train_size])
dataset_test = torch.utils.data.Subset(dataset_test, indices[train_size:])

# Create data loaders
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=utils.collate_fn
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    collate_fn=utils.collate_fn
)

# Initialize model, optimizer and scheduler
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = get_model_detection(num_classes)
model.to(device)

# Optimize all trainable parameters
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# NOTE(review): with step_size=3 and num_epochs=3 below, the scheduler's first
# decay happens on the step after the last epoch, so it never affects
# training here. Confirm this is intended.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# Training loop
num_epochs = 3

# Per-epoch histories used for the plots below
accuracies = []
losses = []
maps = []

for epoch in range(num_epochs):
    # Train for one epoch and collect metrics
    metric_logger = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)

    # Store the epoch's average total loss
    epoch_loss = metric_logger.meters['loss'].global_avg
    losses.append(epoch_loss)

    # Calculate and store the calculate_map() score on the test split
    epoch_map = calculate_map(model, data_loader_test, device)
    maps.append(epoch_map)

    # "Accuracy" is the same value as the mAP score; it is stored separately
    # only so that it gets its own plot
    accuracies.append(epoch_map)

    # Update learning rate
    lr_scheduler.step()

    # Evaluate on test dataset
    #evaluate(model, data_loader_test, device=device)

# Save model
model_path = save_model(model, WORKING_DIR)
print(f"Model saved to: {model_path}")

# Generate and save plots
plot_metrics(accuracies, 'Accuracy', WORKING_DIR)
plot_metrics(losses, 'Loss', WORKING_DIR)
plot_metrics(maps, 'mAP', WORKING_DIR)
plot_roc_curve(model, data_loader_test, device, WORKING_DIR, num_classes)

print("Training and plotting completed!")

# Visualization function
def visualize_predictions(model, image_path, device, transform, score_threshold=0.5):
    """
    Run the detector on one image file and draw the confident detections.

    Args:
        model: detection model; called with a one-element list holding the
            transformed image and expected to return a list whose first item
            is a dict with 'boxes', 'labels' and 'scores'. It is switched to
            eval mode and left in that mode.
        image_path: path of the image file to read.
        device: device the transformed image is moved to for inference.
        transform: callable applied to the image tensor before inference.
        score_threshold: detections with a score above this value are drawn
            (default 0.5).

    Returns:
        A uint8 image tensor with the kept boxes drawn in red, each captioned
        ``"<class name>: <score>"``. When no detection passes the threshold,
        the contrast-stretched image is returned without boxes.
    """
    image = read_image(image_path)
    image = image[:3, ...]  # Remove alpha channel if present
    # If image is grayscale, convert to RGB so the model receives 3 channels
    if image.shape[0] == 1:
        image = image.repeat(3, 1, 1)

    # Transform image
    image_transformed = transform(image).to(device)

    # Get predictions
    model.eval()
    with torch.no_grad():
        pred = model([image_transformed])[0]

    # Stretch the image to the full 0-255 range for display. A constant image
    # has no range to stretch, and is kept as is to avoid dividing by zero.
    lowest, highest = image.min(), image.max()
    if highest > lowest:
        image = (255.0 * (image - lowest) / (highest - lowest)).to(torch.uint8)

    # Keep confident detections; move them to the CPU, where the image lives
    keep = pred["scores"] > score_threshold
    pred_boxes = pred["boxes"][keep].long().cpu()
    pred_labels = pred["labels"][keep].cpu().tolist()
    pred_scores = pred["scores"][keep].cpu().tolist()

    if not pred_labels:
        return image  # nothing to draw

    # Convert numeric labels to text; unknown ids get a generic name
    id_to_label = {1: 'person', 2: 'cat', 3: 'dog'}
    pred_label_texts = []
    for label, score in zip(pred_labels, pred_scores):
        name = id_to_label.get(label, f'class {label}')
        pred_label_texts.append(f"{name}: {score:.2f}")

    # Draw boxes
    output_image = draw_bounding_boxes(
        image,
        pred_boxes,
        pred_label_texts,
        colors="red"
    )

    return output_image

  with torch.cuda.amp.autocast(enabled=scaler is not None):


Epoch: [0]  [   0/1142]  eta: 0:16:13  lr: 0.000010  loss: 1.8931 (1.8931)  loss_classifier: 1.5760 (1.5760)  loss_box_reg: 0.2338 (0.2338)  loss_objectness: 0.0742 (0.0742)  loss_rpn_box_reg: 0.0091 (0.0091)  time: 0.8522  data: 0.0070  max mem: 4952
Epoch: [0]  [  10/1142]  eta: 0:16:37  lr: 0.000060  loss: 1.6158 (1.6368)  loss_classifier: 1.4720 (1.4440)  loss_box_reg: 0.1529 (0.1410)  loss_objectness: 0.0517 (0.0475)  loss_rpn_box_reg: 0.0032 (0.0043)  time: 0.8813  data: 0.0091  max mem: 5062
Epoch: [0]  [  20/1142]  eta: 0:15:41  lr: 0.000110  loss: 1.4043 (1.3614)  loss_classifier: 1.1727 (1.1662)  loss_box_reg: 0.1395 (0.1512)  loss_objectness: 0.0265 (0.0391)  loss_rpn_box_reg: 0.0027 (0.0050)  time: 0.8380  data: 0.0084  max mem: 5062
Epoch: [0]  [  30/1142]  eta: 0:15:30  lr: 0.000160  loss: 0.7375 (1.0831)  loss_classifier: 0.5213 (0.8946)  loss_box_reg: 0.1395 (0.1467)  loss_objectness: 0.0259 (0.0361)  loss_rpn_box_reg: 0.0036 (0.0057)  time: 0.8118  data: 0.0070  max me