<a href="https://colab.research.google.com/github/lorenzopaoria/Smoking-detection-and-distance-analysis/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Train models for cigarette, smoker, and non-smoker detection

In [None]:
import torch
import torchvision
import psutil
import os
import sys
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from torch.utils.data import DataLoader, Dataset
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision import transforms
from PIL import Image
import numpy as np
import time
import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
import cv2
import pandas as pd

In [None]:
# Mount Google Drive at /content/drive (Google Colab only); the dataset paths
# used further down in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Absolute path to the YOLO folder on the mounted Drive. The leading slash
# matters: without it the path is resolved relative to the current working
# directory instead of the /content/drive mount point.
yolov10_path = '/content/drive/MyDrive/yolov10'

In [None]:
class DualDetectionDataset(Dataset):
    """COCO-backed dataset that serves targets for two detectors at once.

    Every sample carries:
      * a Faster R-CNN target holding only the ``cigarette`` boxes
        (absolute ``[x1, y1, x2, y2]`` pixels, label 1), and
      * a YOLO-style target holding only the ``smoker`` / ``nonSmoker`` boxes
        (normalized ``[cx, cy, w, h]``, labels 1 = smoker, 2 = nonSmoker).

    Args:
        coco_annotation_file: Path to the COCO-format JSON annotation file.
        image_dir: Directory containing the images named in the annotations.
        transform: Optional callable applied to the PIL image to build
            ``image_tensor``. Defaults to ``transforms.ToTensor()``.
    """

    def __init__(self, coco_annotation_file, image_dir, transform=None):
        self.coco = COCO(coco_annotation_file)
        self.image_dir = image_dir
        # torchvision detection models expect images in the 0-1 range and
        # apply ImageNet mean/std normalization internally, so the default
        # transform must NOT normalize again.
        self.transform = transform if transform is not None else transforms.Compose([
            transforms.ToTensor(),
        ])

        # Get cigarette and person categories separately
        self.cigarette_ids = self.coco.getCatIds(catNms=['cigarette'])
        self.person_ids = self.coco.getCatIds(catNms=['smoker', 'nonSmoker'])

        # Category id -> name, resolved once instead of once per annotation.
        self._cat_names = {
            cat['id']: cat['name']
            for cat in self.coco.loadCats(self.coco.getCatIds())
        }

        # COCO.getImgIds(catIds=[a, b]) returns images containing ALL listed
        # categories (intersection). Query one category at a time and take
        # the union so images with only a smoker or only a non-smoker are kept.
        wanted_ids = set()
        for cat_id in self.cigarette_ids + self.person_ids:
            wanted_ids.update(self.coco.getImgIds(catIds=[cat_id]))
        # Sorted for a stable index -> image mapping across runs.
        self.image_ids = sorted(wanted_ids)

    def __len__(self):
        """Return the number of images that contain a category of interest."""
        return len(self.image_ids)

    @staticmethod
    def _as_box_tensor(boxes):
        """Convert a list of 4-element boxes to a float32 tensor of shape (N, 4).

        The reshape keeps the second dimension even when ``boxes`` is empty,
        giving ``(0, 4)`` instead of ``(0,)``; detection models require it.
        """
        return torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

    def __getitem__(self, idx):
        """Load one image and build both detector targets.

        Returns:
            dict with keys ``image_tensor`` (transformed image), ``yolo_image``
            (RGB ``numpy`` array, HWC), ``image_path``, ``frcnn_target`` and
            ``yolo_target``.
        """
        img_id = self.image_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        image_path = os.path.join(self.image_dir, img_info['file_name'])
        # The context manager releases the file handle; convert() returns an
        # independent, fully loaded copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Get annotations
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        annotations = self.coco.loadAnns(ann_ids)

        cigarette_boxes, cigarette_labels = [], []
        person_boxes, person_labels = [], []

        for ann in annotations:
            cat_name = self._cat_names.get(ann['category_id'])
            # COCO boxes are [x, y, width, height]; convert to corner format.
            x, y, box_w, box_h = ann['bbox']
            corners = [x, y, x + box_w, y + box_h]

            if cat_name == 'cigarette':
                cigarette_boxes.append(corners)
                cigarette_labels.append(1)
            elif cat_name in ('smoker', 'nonSmoker'):
                person_boxes.append(corners)
                person_labels.append(1 if cat_name == 'smoker' else 2)

        # Transform image for FRCNN
        image_tensor = self.transform(image)

        # Prepare YOLO format - center/size normalized by the image dimensions
        yolo_image = np.array(image)
        img_h, img_w = yolo_image.shape[:2]
        yolo_boxes = [
            [
                (x1 + x2) / (2 * img_w),  # center x
                (y1 + y2) / (2 * img_h),  # center y
                (x2 - x1) / img_w,        # width
                (y2 - y1) / img_h,        # height
            ]
            for x1, y1, x2, y2 in person_boxes
        ]

        frcnn_target = {
            'boxes': self._as_box_tensor(cigarette_boxes),
            'labels': torch.as_tensor(cigarette_labels, dtype=torch.int64),
            'image_id': torch.tensor([img_id])
        }

        yolo_target = {
            'boxes': self._as_box_tensor(yolo_boxes),
            'labels': torch.as_tensor(person_labels, dtype=torch.int64),
            'image_id': torch.tensor([img_id]),
            'orig_size': (img_h, img_w)
        }

        return {
            'image_tensor': image_tensor,
            'yolo_image': yolo_image,
            'image_path': image_path,
            'frcnn_target': frcnn_target,
            'yolo_target': yolo_target
        }

Evaluation metrics:
AP (Average Precision) at various IoU thresholds,
AR (Average Recall) for various numbers of detections,
mAP (mean Average Precision).

In [None]:
class DualDetector:
    """Two-model detector: Faster R-CNN for cigarettes, YOLOv8 for people.

    The Faster R-CNN head is replaced with a 2-class predictor
    (background + cigarette). The YOLO model is loaded from ``yolov8x.pt``.
    """

    # Directory that receives every Faster R-CNN checkpoint.
    FRCNN_CHECKPOINT_DIR = 'pth_cigarette_detect'

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.frcnn = fasterrcnn_resnet50_fpn(weights='DEFAULT')
        in_features = self.frcnn.roi_heads.box_predictor.cls_score.in_features
        self.frcnn.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, 2)
        self.frcnn = self.frcnn.to(self.device)

        # NOTE(review): `YOLO` is not imported in the visible part of this
        # file; it presumably comes from the `ultralytics` package - confirm
        # and add the import, otherwise this line raises NameError.
        self.yolo = YOLO('yolov8x.pt')

    @staticmethod
    def _xyxy_to_xywh(box):
        """Convert ``[x1, y1, x2, y2]`` to the COCO ``[x, y, width, height]`` layout."""
        x1, y1, x2, y2 = box
        return [x1, y1, x2 - x1, y2 - y1]

    @staticmethod
    def _map50(coco_gt, detections, img_ids, cat_ids, title):
        """Run COCOeval on ``detections`` and return AP@IoU=0.50 as a float.

        Evaluation is limited to ``img_ids`` (the images actually predicted on)
        and, when given, ``cat_ids`` (the categories the model is responsible
        for), so the score is not diluted by classes the model never predicts.
        Returns 0 when there are no detections.
        """
        if not detections:
            return 0
        coco_dt = coco_gt.loadRes(detections)
        coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
        coco_eval.params.imgIds = list(img_ids)
        if cat_ids:
            coco_eval.params.catIds = list(cat_ids)
        coco_eval.evaluate()
        coco_eval.accumulate()
        print(title)
        coco_eval.summarize()
        # stats[1] is AP at IoU=0.50; cast so checkpoints store a plain float.
        return float(coco_eval.stats[1])

    def train(self, dataset, num_epochs=10, val_dataset=None, yolo_data=None):
        """Train both models for ``num_epochs`` epochs.

        Args:
            dataset: Training dataset whose items provide ``image_tensor`` and
                ``frcnn_target`` and which exposes a ``coco`` attribute.
            num_epochs: Number of epochs (each runs one Faster R-CNN pass and
                one single-epoch YOLO training call).
            val_dataset: Optional dataset passed to :meth:`evaluate` after
                every epoch; the best Faster R-CNN weights are saved by mAP@0.5.
            yolo_data: Optional dataset configuration passed as ``data=`` to
                ``YOLO.train``. When omitted, ``dataset.coco.filename`` is used
                if that attribute exists; otherwise YOLO training is skipped.
        """
        # torch.save does not create missing parent directories.
        os.makedirs(self.FRCNN_CHECKPOINT_DIR, exist_ok=True)

        # pycocotools' COCO object exposes no `filename`; resolve the YOLO
        # data source up front so a missing value does not abort the run
        # after a full Faster R-CNN epoch.
        if yolo_data is None:
            yolo_data = getattr(dataset.coco, 'filename', None)
        if yolo_data is None:
            print("Warning: no YOLO dataset configuration available "
                  "(pass yolo_data=...); skipping YOLOv8 training.")

        frcnn_optimizer = torch.optim.Adam(self.frcnn.parameters(), lr=0.0005)
        # Mixed precision is only used when CUDA is available.
        scaler = torch.amp.GradScaler() if torch.cuda.is_available() else None

        data_loader = DataLoader(
            dataset,
            batch_size=8,
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=2,
            pin_memory=True
        )

        best_frcnn_map = 0

        for epoch in range(num_epochs):
            # evaluate() leaves the model in eval mode, so reset it each epoch.
            self.frcnn.train()
            total_loss = 0

            print(f"\nEpoch {epoch+1}/{num_epochs}")
            print("Training Faster R-CNN for cigarette detection...")

            # Train Faster R-CNN
            for batch in tqdm(data_loader, desc="FRCNN Training"):
                images = [item['image_tensor'].to(self.device) for item in batch]
                frcnn_targets = [{k: v.to(self.device) for k, v in item['frcnn_target'].items()} for item in batch]

                frcnn_optimizer.zero_grad()

                if scaler is not None:
                    with torch.amp.autocast(device_type='cuda'):
                        loss_dict = self.frcnn(images, frcnn_targets)
                        frcnn_loss = sum(loss_dict.values())

                    scaler.scale(frcnn_loss).backward()
                    scaler.step(frcnn_optimizer)
                    scaler.update()
                else:
                    loss_dict = self.frcnn(images, frcnn_targets)
                    frcnn_loss = sum(loss_dict.values())
                    frcnn_loss.backward()
                    frcnn_optimizer.step()

                total_loss += frcnn_loss.item()

            mean_loss = total_loss / len(data_loader)

            # Train YOLOv8 (one epoch per outer epoch)
            if yolo_data is not None:
                print("\nTraining YOLOv8 for person detection...")
                self.yolo.train(
                    data=yolo_data,
                    epochs=1,
                    imgsz=640,
                    batch=8,
                    device=self.device,
                    project='pth_person_detect',
                    name=f'epoch_{epoch+1}'
                )

            # Validation
            if val_dataset:
                print("\nValidating models...")
                frcnn_map, _yolo_map = self.evaluate(val_dataset)

                # Save FRCNN if it's the best so far
                if frcnn_map > best_frcnn_map:
                    best_frcnn_map = frcnn_map
                    checkpoint_path = os.path.join(self.FRCNN_CHECKPOINT_DIR, 'best_frcnn_model.pth')
                    torch.save({
                        'epoch': epoch + 1,
                        'model_state_dict': self.frcnn.state_dict(),
                        'optimizer_state_dict': frcnn_optimizer.state_dict(),
                        'map': frcnn_map,
                    }, checkpoint_path)
                    print(f"New best FRCNN model saved with mAP: {frcnn_map:.4f}")

                # YOLOv8 handles its own best model saving

            # Save regular checkpoints
            frcnn_checkpoint_path = os.path.join(self.FRCNN_CHECKPOINT_DIR, f'frcnn_epoch_{epoch+1}.pth')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': self.frcnn.state_dict(),
                'optimizer_state_dict': frcnn_optimizer.state_dict(),
                'loss': mean_loss,
            }, frcnn_checkpoint_path)

            print(f"\nEpoch {epoch+1} complete:")
            print(f"FRCNN Loss: {mean_loss:.4f}")
            print(f"FRCNN checkpoint saved: {frcnn_checkpoint_path}")
            if yolo_data is not None:
                print(f"YOLO results saved in: pth_person_detect/epoch_{epoch+1}")

    def evaluate(self, dataset, score_threshold=0.5):
        """Evaluate both models on ``dataset`` with COCO bbox metrics.

        Args:
            dataset: Dataset whose items provide ``image_tensor``,
                ``yolo_image`` (RGB array) and ``frcnn_target['image_id']``,
                and which exposes ``coco``. If it also exposes
                ``cigarette_ids`` / ``person_ids`` they are used as the
                category ids for evaluation.
            score_threshold: Detections scoring at or below this value are
                discarded before evaluation.

        Returns:
            Tuple ``(frcnn_map, yolo_map)`` with AP@IoU=0.50 for each model
            (0 when a model produced no detections).
        """
        self.frcnn.eval()
        coco_dt_frcnn = []
        coco_dt_yolo = []
        evaluated_ids = []

        # Use the annotation file's real category ids when the dataset exposes
        # them; fall back to the previous hard-coded cigarette id of 1.
        cigarette_ids = list(getattr(dataset, 'cigarette_ids', None) or [1])
        person_ids = list(getattr(dataset, 'person_ids', None) or [])
        cigarette_cat_id = cigarette_ids[0]

        with torch.no_grad():
            for idx in tqdm(range(len(dataset)), desc="Evaluating"):
                data = dataset[idx]
                image_id = data['frcnn_target']['image_id'].item()
                evaluated_ids.append(image_id)

                # FRCNN predictions (the model takes a list of CHW tensors)
                prediction = self.frcnn([data['image_tensor'].to(self.device)])[0]
                keep = prediction['scores'] > score_threshold
                for box, score in zip(prediction['boxes'][keep].tolist(),
                                      prediction['scores'][keep].tolist()):
                    coco_dt_frcnn.append({
                        'image_id': image_id,
                        'category_id': cigarette_cat_id,  # cigarette
                        'bbox': self._xyxy_to_xywh(box),
                        'score': score
                    })

                # YOLO predictions: Ultralytics treats numpy input as BGR,
                # while the dataset yields RGB, so flip the channel axis.
                bgr_image = np.ascontiguousarray(data['yolo_image'][:, :, ::-1])
                yolo_results = self.yolo.predict(bgr_image)

                # Process YOLO predictions (persons)
                for result in yolo_results:
                    for box in result.boxes:
                        confidence = float(box.conf)
                        if confidence > score_threshold:
                            # NOTE(review): the +2 offset assumes YOLO class 0/1
                            # map to COCO ids 2/3 - confirm against the
                            # annotation file and the YOLO class order.
                            coco_dt_yolo.append({
                                'image_id': image_id,
                                'category_id': int(box.cls) + 2,  # 2=smoker, 3=nonSmoker
                                'bbox': self._xyxy_to_xywh(box.xyxy[0].tolist()),
                                'score': confidence
                            })

        coco_gt = dataset.coco

        # Evaluate FRCNN
        frcnn_map = self._map50(
            coco_gt, coco_dt_frcnn, evaluated_ids, cigarette_ids,
            "\nFaster R-CNN Results (Cigarette Detection):"
        )

        # Evaluate YOLO
        yolo_map = self._map50(
            coco_gt, coco_dt_yolo, evaluated_ids, person_ids,
            "\nYOLOv8 Results (Person Detection):"
        )

        return frcnn_map, yolo_map

In [None]:
def collate_fn(batch):
    """Return the batch unchanged, as the list of samples the loader built.

    Passed to ``DataLoader`` in place of the default collation so that each
    sample is handed to the training loop as-is rather than being merged
    with the others.
    """
    return batch

In [None]:
if __name__ == "__main__":
    # Path configuration: image folders and their COCO annotation files
    train_image_dir = '/content/drive/MyDrive/Photo/train'
    train_coco_annotation_file = '/content/drive/MyDrive/Photo/train/_annotations.coco.json'
    valid_image_dir = '/content/drive/MyDrive/Photo/valid'
    valid_coco_annotation_file = '/content/drive/MyDrive/Photo/valid/_annotations.coco.json'
    
    # Initialize the training and validation datasets
    train_dataset = DualDetectionDataset(train_coco_annotation_file, train_image_dir)
    val_dataset = DualDetectionDataset(valid_coco_annotation_file, valid_image_dir)
    
    # Initialize the detector and train it, validating after each epoch
    detector = DualDetector()
    detector.train(train_dataset, num_epochs=10, val_dataset=val_dataset)