<a href="https://colab.research.google.com/github/lorenzopaoria/Smoking-detection-and-distance-analysis/blob/main/cigarettes_model_load.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from torchvision import models, transforms
from PIL import Image
from tqdm import tqdm

In [9]:
# Mount Google Drive at /content/drive; the model checkpoint, input photos and
# output folder used later in this notebook all live under that mount point.
# The google.colab module is only importable inside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def load_model(model_path, num_classes=2, device='cpu'):
    """Build a Faster R-CNN (ResNet-50 FPN) detector and load trained weights into it.

    Args:
        model_path: Path to a file containing the model's ``state_dict``
            as saved with ``torch.save``.
        num_classes: Number of outputs of the replaced box-predictor head.
            Defaults to 2 (presumably background + cigarette; confirm against
            the training setup).
        device: Device (string or ``torch.device``) the weights are mapped to
            and the model is moved onto. Defaults to ``'cpu'``. Inputs passed
            to the returned model must live on the same device.

    Returns:
        The model in evaluation mode, located on ``device``.

    Raises:
        RuntimeError: From ``load_state_dict`` if the checkpoint keys or
            shapes do not match the constructed model.
    """
    device = torch.device(device)

    # NOTE(review): the COCO weights fetched here are replaced by the
    # checkpoint below. The argument is kept as-is so the architecture that
    # torchvision constructs stays exactly what it was.
    model = models.detection.fasterrcnn_resnet50_fpn(
        weights=models.detection.FasterRCNN_ResNet50_FPN_Weights.COCO_V1
    )

    # Swap the COCO classification head for one sized for our class count.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = models.detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model

In [None]:
def get_predictions(model, image, threshold=0.5):
    """Run the detector on one image and keep the confident detections.

    Args:
        model: Detection model whose output for a batch is a list of dicts
            with ``'boxes'``, ``'scores'`` and ``'labels'`` entries.
        image: Image accepted by ``transforms.ToTensor`` (e.g. a PIL image).
        threshold: Detections are kept only when their score is strictly
            greater than this value.

    Returns:
        Tuple ``(boxes, labels, scores)`` of CPU tensors for the kept detections.
    """
    image_tensor = transforms.ToTensor()(image).unsqueeze(0)

    # Feed the input on whatever device the model's weights live on, so a
    # model that was moved to a GPU does not fail with a device mismatch.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)

    with torch.no_grad():
        prediction = model(image_tensor)[0]

    scores = prediction['scores']
    keep = scores > threshold

    # Hand results back on the CPU so callers can use them without caring
    # where inference ran (a no-op when the model is already on the CPU).
    boxes = prediction['boxes'][keep].cpu()
    labels = prediction['labels'][keep].cpu()
    scores = scores[keep].cpu()
    return boxes, labels, scores

In [None]:
def draw_boxes(image, boxes, labels, scores):
    """Draw detection boxes with confidence captions and return a BGR image.

    Args:
        image: RGB image (e.g. a PIL image) convertible with ``np.array``.
        boxes: Iterable of boxes; each must provide ``.tolist()`` yielding
            ``[x1, y1, x2, y2]``.
        labels: Per-box labels. Iterated in step with ``boxes`` but not used
            for the caption, which is always "Cigarette".
        scores: Per-box confidence scores (anything ``float()`` accepts).

    Returns:
        A NumPy array in OpenCV's BGR channel order with the boxes and
        captions drawn on it. The input ``image`` is not modified.
    """
    # OpenCV expects BGR channel order, the input is RGB.
    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Drawing settings are the same for every box, so set them once.
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 2
    padding = 5
    red = (0, 0, 255)          # BGR
    white = (255, 255, 255)

    for box, _label, score in zip(boxes, labels, scores):
        x1, y1, x2, y2 = (int(v) for v in box.tolist())
        cv2.rectangle(image_cv, (x1, y1), (x2, y2), red, 2)

        label_text = f"Cigarette: {float(score):.2f}"
        (text_width, text_height), _baseline = cv2.getTextSize(
            label_text, font, font_scale, thickness
        )

        # The caption normally sits directly above the box. If the box is too
        # close to the top edge there is no room there and the caption would
        # be clipped away, so push it down until it fits inside the image.
        label_bottom = max(y1, text_height + padding)
        label_top = label_bottom - text_height - padding

        # Filled background so the white text stays readable.
        cv2.rectangle(image_cv,
                      (x1, label_top),
                      (x1 + text_width, label_bottom),
                      red,
                      -1)
        cv2.putText(image_cv,
                    label_text,
                    (x1, label_bottom - padding),
                    font,
                    font_scale,
                    white,
                    thickness)

    return image_cv

In [None]:
def process_images(model_path, images_folder, output_folder):
    """Run detection on every image in a folder and save annotated copies.

    Args:
        model_path: Path to the trained weights, passed to ``load_model``.
        images_folder: Folder scanned (non-recursively) for files ending in
            ``.jpg``, ``.png`` or ``.jpeg`` (case-insensitive).
        output_folder: Folder receiving the results; created if missing.
            Each result is written as ``output_<original file name>``.

    Returns:
        None. A warning line is printed for every result that could not be
        written.
    """
    model = load_model(model_path)
    os.makedirs(output_folder, exist_ok=True)

    image_extensions = ('.jpg', '.png', '.jpeg')
    # os.listdir returns entries in arbitrary order; sort for reproducible runs.
    image_files = sorted(
        name for name in os.listdir(images_folder)
        if name.lower().endswith(image_extensions)
    )

    for image_name in tqdm(image_files, desc="Processing images"):
        image_path = os.path.join(images_folder, image_name)

        # Image.open keeps the file open; convert() reads the pixels into a
        # new image, so the file handle can be released right away.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")

        boxes, labels, scores = get_predictions(model, image)

        result_image = draw_boxes(image, boxes, labels, scores)
        result_image_path = os.path.join(output_folder, f"output_{image_name}")

        # cv2.imwrite signals failure through its return value rather than an
        # exception; report it instead of silently losing the result.
        if not cv2.imwrite(result_image_path, result_image):
            tqdm.write(f"Warning: could not write {result_image_path}")

In [None]:
def main():
    """Annotate the test photos on Google Drive using the trained detector.

    Reads the checkpoint and the input photos from fixed Drive locations and
    writes the annotated images to a fixed Drive output folder.
    """
    process_images(
        model_path='/content/drive/MyDrive/pth_cigarette_detect/fasterrcnn_cigarette_final.pth',
        images_folder='/content/drive/MyDrive/Photo/test',
        output_folder='/content/drive/MyDrive/test_trained',
    )


if __name__ == "__main__":
    main()