In [14]:
import os
import glob
import numpy as np
import xml.etree.ElementTree as ET
from PIL import Image
from torch.utils.data import Dataset
import torch
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
from torchvision.transforms import functional as F

# Load the pre-trained Faster R-CNN model
def load_model():
    """Build the pre-trained Faster R-CNN detector in inference mode.

    Returns:
        The ``fasterrcnn_resnet50_fpn`` model initialised with
        ``FasterRCNN_ResNet50_FPN_Weights.DEFAULT`` and switched to eval mode.
    """
    pretrained_weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
    detector = fasterrcnn_resnet50_fpn(weights=pretrained_weights)
    # nn.Module.eval() returns the module itself, so it can be returned directly.
    return detector.eval()

# Define the dataset class
class MaskedFaceTestDataset(Dataset):
    """Images (``*.png``) paired with same-named ``*.xml`` annotation files.

    Each item is ``(img, true_counts, img_path)`` where ``true_counts`` is a
    length-3 integer array holding the number of annotated objects per class,
    in the column order given by ``CLASS_MAP``.

    Args:
        root: Directory that contains the images and their annotations.
        transform: Optional callable applied to the RGB PIL image.
    """

    # Annotation class name -> column index in the count vector.
    CLASS_MAP = {'with_mask': 0, 'without_mask': 1, 'mask_weared_incorrect': 2}

    def __init__(self, root, transform=None):
        self.root = root
        # Escape the directory so characters such as '[' or '*' in its name
        # are matched literally instead of being read as glob wildcards.
        pattern = os.path.join(glob.escape(root), '*.png')
        self.imgs = sorted(glob.glob(pattern))
        self.transform = transform

    def __len__(self):
        """Return the number of images found under ``root``."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Load image ``index`` and the per-class counts from its annotation."""
        img_path = self.imgs[index]
        annotation_path = self._annotation_path(img_path)
        img = Image.open(img_path).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        true_counts = self.parse_annotation(annotation_path)
        return img, true_counts, img_path

    @staticmethod
    def _annotation_path(img_path):
        """Return the annotation path that sits next to ``img_path``.

        Only the final extension is swapped; a plain ``str.replace`` would
        also rewrite any ``.png`` appearing in a directory name.
        """
        return os.path.splitext(img_path)[0] + '.xml'

    def parse_annotation(self, annotation_path):
        """Count the annotated objects of each known class in one XML file.

        Args:
            annotation_path: Path of an XML file whose root holds ``object``
                elements, each with a ``name`` child giving the class.

        Returns:
            Integer array ``[with_mask, without_mask, mask_weared_incorrect]``.
            Objects with a missing, empty or unknown name are ignored.
        """
        counts = np.zeros(len(self.CLASS_MAP), dtype=int)
        for member in ET.parse(annotation_path).getroot().findall('object'):
            # findtext returns None when <name> is absent, which avoids the
            # AttributeError that ``find('name').text`` would raise.
            class_name = (member.findtext('name') or '').strip()
            class_id = self.CLASS_MAP.get(class_name)
            if class_id is not None:
                counts[class_id] += 1
        return counts

# Prediction function using the Faster R-CNN model
def get_predictions_for_image(img, model, threshold=0.8):
    """Count the confident detections per class for a single image.

    Args:
        img: Either an image accepted by ``F.to_tensor`` (e.g. a PIL image)
            or an already converted ``torch.Tensor`` of shape (C, H, W).
        model: Detection model. It is moved to the selected device, called
            with a batch containing this one image, and must return a list
            whose first entry is a dict with ``'labels'`` and ``'scores'``.
        threshold: Detections are kept only when their score is strictly
            greater than this value.

    Returns:
        Length-3 integer array; entry ``k`` is the number of kept detections
        whose label equals ``k + 1``. Other labels are ignored.
    """
    # NOTE(review): labels 1, 2 and 3 are mapped to the three count columns.
    # With stock pre-trained detection weights those ids are presumably generic
    # object categories rather than the mask classes - TODO confirm the model
    # was fine-tuned before trusting these counts.

    # Accept tensors too, so a dataset built with a tensor transform works.
    img_tensor = img if isinstance(img, torch.Tensor) else F.to_tensor(img)
    # Out-of-place unsqueeze: never reshape a tensor owned by the caller.
    img_tensor = img_tensor.unsqueeze(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    img_tensor = img_tensor.to(device)

    with torch.no_grad():
        predictions = model(img_tensor)

    pred_classes = predictions[0]['labels'].cpu().numpy()
    pred_scores = predictions[0]['scores'].cpu().numpy()

    kept = pred_classes[pred_scores > threshold]
    kept = kept[(kept >= 1) & (kept <= 3)]
    # Shift labels 1..3 to columns 0..2 and tally them in one pass.
    return np.bincount(kept - 1, minlength=3).astype(int)

# Function to count masks and calculate MAPE
def count_masks(dataset, model):
    """Predict per-class counts for every image and score them with MAPE.

    Args:
        dataset: Iterable yielding ``(img, true_counts, img_path)`` triples.
        model: Detection model passed through to
            ``get_predictions_for_image``.

    Returns:
        ``(counts_array, mean_mape)``: ``counts_array`` has one row of
        predicted counts per image, and ``mean_mape`` is the mean over images
        of the per-image percentage error. For an empty dataset the result is
        an empty ``(0, 3)`` array and ``nan``.
    """
    counts = []
    mape_scores = []

    for img, true_counts, _img_path in dataset:
        predicted_counts = np.asarray(get_predictions_for_image(img, model))
        true_counts = np.asarray(true_counts)

        # The denominator is clamped to 1 so classes with no ground-truth
        # objects contribute their absolute error instead of dividing by zero.
        pct_error = np.abs(true_counts - predicted_counts) / np.maximum(true_counts, 1)
        counts.append(predicted_counts)
        mape_scores.append(pct_error.mean() * 100)

    if not counts:
        # Nothing to average: return a well-shaped empty result without
        # triggering numpy's "mean of empty slice" warning.
        return np.zeros((0, 3), dtype=int), float('nan')

    return np.array(counts), np.mean(mape_scores)


# Evaluation driver: load the detector, index the validation images, then
# predict and score every image.
model = load_model()
# Hard-coded folder holding the *.png images and their same-named *.xml
# annotations; adjust this path for your environment.
root_dir = "/content/drive/MyDrive/val"
# No transform is passed, so the dataset yields PIL images.
dataset = MaskedFaceTestDataset(root=root_dir)
counts_array, mean_mape = count_masks(dataset, model)

# counts_array holds one row of predicted per-class counts per image.
print(f"Counts per class: {counts_array}")
print(f"Mean Absolute Percentage Error (MAPE): {mean_mape:.2f}%")


Counts per class: [[15  0  0]
 [10  0  0]
 [ 2  0  0]
 [ 3  0  0]
 [10  0  0]
 [10  0  0]
 [ 4  0  0]
 [14  0  0]
 [ 1  0  1]
 [ 6  0  0]
 [ 9  0  0]
 [ 1  0  0]
 [ 6  0  0]
 [ 2  0  0]
 [ 7  0  0]
 [13  0  0]
 [19  0  1]
 [ 1  0  0]
 [ 6  0  0]
 [10  0  0]
 [12  0  0]
 [ 1  0  0]
 [ 1  0  0]
 [ 6  0  0]
 [13  0  0]
 [ 1  0  0]
 [11  0  0]
 [23  0  0]
 [ 2  0  0]
 [13  0  0]
 [ 1  0  0]
 [13  0  0]
 [25  0  0]
 [ 2  0  0]
 [ 1  0  0]
 [13  0  0]
 [ 2  0  1]
 [ 3  0  1]
 [30  0  0]
 [14  0  0]
 [ 6  0  0]
 [ 1  0  0]
 [ 1  0  0]
 [10  0  0]
 [ 1  0  0]
 [ 1  0  0]
 [12  0  0]
 [ 5  0  0]
 [ 1  0  0]
 [10  0  0]
 [16  0  0]
 [11  0  0]
 [13  0  0]
 [ 5  0  2]
 [12  0  0]
 [ 8  0  0]
 [ 3  0  0]
 [ 1  0  0]
 [ 7  0  0]
 [ 4  0  0]
 [ 7  0  0]
 [ 1  0  0]
 [ 5  0  0]
 [ 1  0  0]
 [ 2  0  0]
 [ 9  0  0]
 [ 1  0  0]
 [ 4  0  0]
 [ 2  0  0]
 [ 4  0  0]
 [ 3  0  0]
 [ 9  0  0]
 [ 1  0  0]
 [16  0  0]
 [13  0  0]
 [11  0  0]
 [ 1  0  0]
 [ 1  0  0]
 [ 2  0  0]
 [ 1  0  0]
 [10  0  0]
 [ 4  0  0