<a href="https://colab.research.google.com/github/lorenzopaoria/Smoking-detection-and-distance-analysis/blob/main/model_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Create a model and import weights from pretained model


In [19]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from torchvision.ops import nms
from torchvision.utils import draw_bounding_boxes
import os
from PIL import Image
from pathlib import Path

In [20]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [21]:
CLASS_NAMES = ["cigarette", "smoker", "nonSmoker"]
NUM_CLASSES = len(CLASS_NAMES) + 1
MODEL_PATH = '/content/drive/MyDrive/pth_epoch/best_model.pth'
TEST_DIR = '/content/drive/MyDrive/Photo/test'
OUTPUT_DIR = '/content/drive/MyDrive/test_trained'

os.makedirs(OUTPUT_DIR, exist_ok=True)

In [22]:
def get_model():
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(Weights=False)

    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        in_features, NUM_CLASSES)

    return model

In [23]:
def load_model(model_path):
    """Carica il modello salvato"""
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Il file del modello non esiste: {model_path}")

    print(f"Caricamento del modello da {model_path}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Utilizzo device: {device}")

    try:
        model = get_model()
        checkpoint = torch.load(model_path, map_location=device)

        if "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
        else:
            model.load_state_dict(checkpoint)

        model.to(device)
        model.eval()
        print("Modello caricato con successo")
        return model, device

    except Exception as e:
        print(f"Errore nel caricamento del modello: {str(e)}")
        raise

In [24]:
class TestDataset(Dataset):
    def __init__(self, test_dir, transform=None):
        self.test_dir = Path(test_dir)
        if not self.test_dir.exists():
            raise FileNotFoundError(f"La directory {test_dir} non esiste!")

        self.image_files = list(self.test_dir.glob('*.[jp][pn][g]'))
        if not self.image_files:
            raise FileNotFoundError(f"Nessuna immagine trovata in {test_dir}")

        print(f"Trovate {len(self.image_files)} immagini in {test_dir}")

        self.transform = transform or transforms.Compose([
            transforms.Resize((800, 800)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = str(self.image_files[idx])
        try:
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, img_path
        except Exception as e:
            print(f"Errore nel caricamento dell'immagine {img_path}: {str(e)}")
            raise

In [25]:
def draw_predictions(image, boxes, labels, scores, output_path):
    """Disegna le bounding box sulle immagini e salva il risultato"""
    try:
        image = image * torch.tensor([0.229, 0.224, 0.225])[:, None, None]
        image = image + torch.tensor([0.485, 0.456, 0.406])[:, None, None]
        image = (image * 255).byte()

        def get_color(label):
            class_name = CLASS_NAMES[int(label)-1]
            if class_name == 'cigarette':
                return 'red'
            elif class_name == 'nonSmoker':
                return 'green'
            elif class_name == 'smoker':
                return 'blue'
            else:
                return 'white'

        colors = [get_color(label) for label in labels]
        labels_text = [f"{CLASS_NAMES[int(l)-1]}: {s:.2f}" for l, s in zip(labels, scores)]

        image_with_boxes = draw_bounding_boxes(
            image,
            boxes,
            labels=labels_text,
            colors=colors,
            width=2
        )

        image_with_boxes = image_with_boxes.permute(1, 2, 0).numpy()
        plt.figure(figsize=(12, 8))
        plt.imshow(image_with_boxes)
        plt.axis('off')
        plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
        plt.close()

        print(f"Salvata immagine con predizioni in {output_path}")
    except Exception as e:
        print(f"Errore nel disegno delle predizioni: {str(e)}")
        raise

In [26]:
def evaluate_model(model, device, test_loader):
    """Valuta il modello e salva le immagini con le predizioni"""
    model.to(device)

    print("Inizio valutazione del modello...")
    with torch.no_grad():
        for i, (images, img_paths) in enumerate(test_loader):
            print(f"\nProcessing batch {i+1}/{len(test_loader)}")

            try:
                images = images.to(device)
                predictions = model(images)

                for j, (prediction, img_path) in enumerate(zip(predictions, img_paths)):
                    boxes = prediction['boxes']
                    scores = prediction['scores']
                    labels = prediction['labels']

                    mask = scores > 0.7
                    boxes = boxes[mask]
                    scores = scores[mask]
                    labels = labels[mask]

                    if len(boxes) > 0:
                        base_name = os.path.basename(img_path)
                        output_filename = f"pred_{base_name}"
                        output_path = os.path.join(OUTPUT_DIR, output_filename)

                        draw_predictions(
                            images[j].cpu(),
                            boxes.cpu(),
                            labels.cpu(),
                            scores.cpu(),
                            output_path
                        )

                        print(f"Processata immagine: {base_name}")
                    else:
                        print(f"Nessuna predizione sopra la soglia per {os.path.basename(img_path)}")
            except Exception as e:
                print(f"Errore nel processing del batch {i+1}: {str(e)}")
                continue

In [27]:
if __name__ == "__main__":
    try:
        print("Inizializzazione...")

        os.makedirs(OUTPUT_DIR, exist_ok=True)

        for path in [os.path.dirname(MODEL_PATH), TEST_DIR]:
            if not os.path.exists(path):
                raise FileNotFoundError(f"Directory non trovata: {path}")

        print("\nCreazione dataset di test...")
        test_dataset = TestDataset(TEST_DIR)
        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

        print("\nCaricamento modello...")
        model, device = load_model(MODEL_PATH)

        print("\nInizio valutazione...")
        evaluate_model(model, device, test_loader)

        print("\nProcessing completato con successo!")

    except Exception as e:
        print(f"\nErrore durante l'esecuzione: {str(e)}")

Inizializzazione...

Creazione dataset di test...
Trovate 73 immagini in /content/drive/MyDrive/Photo/test

Caricamento modello...
Caricamento del modello da /content/drive/MyDrive/pth_epoch/best_model.pth
Utilizzo device: cuda


  checkpoint = torch.load(model_path, map_location=device)


Modello caricato con successo

Inizio valutazione...
Inizio valutazione del modello...

Processing batch 1/73
Salvata immagine con predizioni in /content/drive/MyDrive/test_trained/pred_20240928_140539_jpg.rf.a3f59ec5ebcc4dd94ed524a3659c45ec.jpg
Processata immagine: 20240928_140539_jpg.rf.a3f59ec5ebcc4dd94ed524a3659c45ec.jpg

Processing batch 2/73
Salvata immagine con predizioni in /content/drive/MyDrive/test_trained/pred_20240928_133303_jpg.rf.6f615ecc8690a9068e521e60ebbc982b.jpg
Processata immagine: 20240928_133303_jpg.rf.6f615ecc8690a9068e521e60ebbc982b.jpg

Processing batch 3/73
Salvata immagine con predizioni in /content/drive/MyDrive/test_trained/pred_20241120_110416_jpg.rf.479214375879a2a91444dbc7997d4af8.jpg
Processata immagine: 20241120_110416_jpg.rf.479214375879a2a91444dbc7997d4af8.jpg

Processing batch 4/73
Salvata immagine con predizioni in /content/drive/MyDrive/test_trained/pred_20240928_122456_jpg.rf.499c83da73f30227543e134377dfc208.jpg
Processata immagine: 20240928_1224