In [1]:
import os
import cv2
import torch
import torchvision.transforms as transforms
from PIL import Image
from models import shufflenetv2, nasnet_mobile_onfire
from tqdm import tqdm


In [3]:
def load_model(model_name, weight_path, device):
    """Build a fire-detection network, load its saved weights and ready it for inference.

    Args:
        model_name: either "shufflenetonfire" or "nasnetonfire".
        weight_path: path to a state dict saved for that same architecture.
        device: torch device (or device string) to map the weights onto and move the model to.

    Returns:
        The network, in eval mode, on ``device``. Both architectures are built with a
        single output (``num_classes=1``).

    Raises:
        ValueError: if ``model_name`` is not one of the two supported names.
    """
    # Reject unknown names before touching either architecture module.
    if model_name not in ("shufflenetonfire", "nasnetonfire"):
        raise ValueError("Invalid model name. Choose 'shufflenetonfire' or 'nasnetonfire'.")

    if model_name == "nasnetonfire":
        network = nasnet_mobile_onfire.nasnetamobile(num_classes=1, pretrained=False)
    else:
        # Reduced ShuffleNetV2 (x0.5) variant; the layer/channel layout must match the checkpoint.
        network = shufflenetv2.shufflenet_v2_x0_5(
            pretrained=False,
            layers=[4, 8, 4],
            output_channels=[24, 48, 96, 192, 64],
            num_classes=1,
        )

    # map_location lets GPU-saved weights load on a CPU-only machine (and vice versa).
    state_dict = torch.load(weight_path, map_location=device)
    network.load_state_dict(state_dict)
    network.eval()
    network.to(device)
    return network

def get_transform(model_name):
    """Return the tensor conversion + normalisation pipeline matching ``model_name``.

    Args:
        model_name: either "shufflenetonfire" or "nasnetonfire" (same names as ``load_model``).

    Returns:
        A ``transforms.Compose`` applying ``ToTensor()`` followed by per-channel ``Normalize``.

    Raises:
        ValueError: if ``model_name`` is not one of the two supported names. Previously an
            unknown name silently returned ``None``, which only failed later when the
            transform was called.
    """
    if model_name == 'shufflenetonfire':
        # Standard ImageNet channel mean / std.
        mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    elif model_name == 'nasnetonfire':
        # Maps ToTensor's [0, 1] range to [-1, 1].
        mean, std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
    else:
        raise ValueError("Invalid model name. Choose 'shufflenetonfire' or 'nasnetonfire'.")
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

def preprocess_image(image_path, transform, device, size=(224, 224)):
    """Load an image file and turn it into a model-ready batch of one.

    Args:
        image_path: path of the image file to read with OpenCV.
        transform: callable applied to the RGB ``PIL.Image`` (see ``get_transform``).
        device: device the resulting tensor is moved to.
        size: ``(width, height)`` passed to ``cv2.resize`` before the transform;
            defaults to 224x224.

    Returns:
        The transformed image as a float tensor with a leading batch dimension of 1.

    Raises:
        OSError: if OpenCV cannot read ``image_path`` (missing, unreadable or not an image).
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise OSError(f"Could not read image: {image_path}")
    # `interpolation` must be passed by keyword: cv2.resize's third positional
    # parameter is `dst`, so a positional INTER_AREA would not select area interpolation.
    image = cv2.resize(image, size, interpolation=cv2.INTER_AREA)
    # OpenCV decodes to BGR; the PIL/torchvision pipeline expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)
    return transform(image).float().unsqueeze(0).to(device)

def detect_fire(image_path, model, transform, device):
    """Classify a single image file as "Fire" or "No Fire".

    Args:
        image_path: path of the image to classify (read via ``preprocess_image``).
        model: network producing a single logit for a batch of one image
            (presumably from ``load_model``, which builds with ``num_classes=1``).
        transform: preprocessing pipeline matching ``model`` (see ``get_transform``).
        device: device the model lives on; the input tensor is moved there.

    Returns:
        "Fire" when the rounded sigmoid score is 0, otherwise "No Fire".
    """
    image_tensor = preprocess_image(image_path, transform, device)
    # Pure inference: disable autograd so no graph or saved activations are kept.
    with torch.no_grad():
        output = model(image_tensor)
    # sigmoid(logit) <= 0.5 rounds to 0 (torch.round rounds half to even) -> "Fire".
    prediction = torch.round(torch.sigmoid(output)).item()
    return "Fire" if prediction == 0 else "No Fire"

In [6]:
# --- Set Parameters ---
# Architecture to evaluate: "shufflenetonfire" or "nasnetonfire".
# weight_path must point at a state dict saved for that same architecture.
model_name = "shufflenetonfire"
weight_path = "weights/shufflenet_ff.pt"  # adjust path if needed
folder_path = "data/test"                 # directory holding the test images
ground_truth = "Fire"                     # label shared by every image: "Fire" or "No Fire"

# Use the first CUDA GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model = load_model(model_name, weight_path, device)
transform = get_transform(model_name)

In [7]:
# Evaluate every image in `folder_path`; all of them share the single `ground_truth` label.
if ground_truth not in ("Fire", "No Fire"):
    # Fail fast on a typo instead of scoring against a label detect_fire never returns.
    raise ValueError(f"ground_truth must be 'Fire' or 'No Fire', got {ground_truth!r}")

valid_extensions = (".jpg", ".jpeg", ".png", ".bmp")

# Filter up front so the progress bar counts only images; sort because
# os.listdir order is arbitrary and results should be reproducible.
image_files = sorted(
    name for name in os.listdir(folder_path)
    if name.lower().endswith(valid_extensions)
    and os.path.isfile(os.path.join(folder_path, name))
)
total_images = len(image_files)

# Every disagreement with the ground truth is an error: a false positive when the
# ground truth is "No Fire", a false negative when it is "Fire".
false_count = 0
for filename in tqdm(image_files, desc="Processing images"):
    image_path = os.path.join(folder_path, filename)
    result = detect_fire(image_path, model, transform, device)
    if result != ground_truth:
        false_count += 1

# Calculate correct classifications and accuracy
correct_count = total_images - false_count
accuracy = (correct_count / total_images * 100) if total_images > 0 else 0

if total_images == 0:
    print(f"No images with extensions {valid_extensions} found in {folder_path!r}")
print(f"Total images tested: {total_images}")
print(f"Correct classifications: {correct_count}")
if ground_truth == "No Fire":
    print(f"False Positives: {false_count}")
else:
    print(f"False Negatives: {false_count}")
print(f"Accuracy: {accuracy:.2f}%")







Processing images: 100%|██████████| 1/1 [00:00<00:00, 10.24it/s]

Total images tested: 1
Correct classifications: 1
False Negatives: 0
Accuracy: 100.00%



