In [1]:
import os
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet50
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score



In [2]:
# Paths
images_dir = "../dfire_test/images"
labels_dir = "../dfire_test/labels"
model_path = "ff-det_resnet50.pth"  # Update with your model path

# Load dataset
def load_dataset(images_dir, labels_dir):
    dataset = []
    for img_file in os.listdir(images_dir):
        if img_file.endswith(('.jpg', '.jpeg', '.png')):
            img_path = os.path.join(images_dir, img_file)
            label_file = os.path.join(labels_dir, img_file.rsplit('.', 1)[0] + ".txt")
            
            if os.path.exists(label_file):
                # Check if the label file is blank
                with open(label_file, "r") as f:
                    label_content = f.read().strip()
                    label = 1 if label_content else 0  # Fire if not blank, No Fire if blank
                dataset.append((img_path, label))
    return dataset

# Load the dataset
dataset = load_dataset(images_dir, labels_dir)



In [3]:
# Transform for ResNet50
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, 2)  # Assuming binary classification
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()


  model.load_state_dict(torch.load(model_path, map_location=device))


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [6]:

# Test loop
def test_model(dataset, model, transform, device):
    y_true = []
    y_pred = []
    total_read = 0
    total_not_read = 0

    for img_path, label in dataset:
        try:
            image = Image.open(img_path).convert("RGB")
            input_tensor = transform(image).unsqueeze(0).to(device)

            # Predict
            with torch.no_grad():
                output = model(input_tensor)
                pred_label = torch.argmax(output, dim=1).item()

            y_true.append(label)
            y_pred.append(pred_label)
            total_read += 1

        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            total_not_read += 1

    return y_true, y_pred, total_read, total_not_read

# Run the test
y_true, y_pred, total_read, total_not_read = test_model(dataset, model, transform, device)

# Calculate metrics
accuracy = accuracy_score(y_true, y_pred) * 100
precision = precision_score(y_true, y_pred, pos_label=1)
recall = recall_score(y_true, y_pred, pos_label=1)
f1 = f1_score(y_true, y_pred, pos_label=1)

# Display results
print(f"Total images successfully read: {total_read}")
print(f"Total images not read: {total_not_read}")
print(f"Accuracy: {accuracy:.2f}%")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")


Total images successfully read: 4306
Total images not read: 0
Accuracy: 39.46%
Precision: 0.4469
Recall: 0.5602
F1 Score: 0.4972
