In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
from tqdm import tqdm

# Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define transformations
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load datasets
def load_datasets(train_dir, test_dir, transform):
    train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
    test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    return train_loader, test_loader

# Define the CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 16 * 16, 128)  # Adjusted for input size of 64
        self.fc2 = nn.Linear(128, 2)  # Binary classification: output size 2 for using CrossEntropyLoss
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 64 * 16 * 16)  # Adjusted flatten layer to match output of last conv layer
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Training the model with tqdm visualization
def train_model(model, train_loader, criterion, optimizer, num_epochs=10):
    model.train()
    model.to(device)
    for epoch in range(num_epochs):
        # Adding tqdm progress bar
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit='batch') as pbar:
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                pbar.set_postfix(loss=loss.item())
                pbar.update()

# Custom metric calculation for accuracy and false alarms
'''def evaluate_model(model, test_loader):
    model.eval()
    model.to(device)
    total = correct = false_alarms = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            # Counting false alarms: Predicted as H (class 0) when actually N (class 1)
            false_alarms += ((predicted == 0) & (labels == 1)).sum().item()

    accuracy = correct / total * 100
    return accuracy, false_alarms'''



'def evaluate_model(model, test_loader):\n    model.eval()\n    model.to(device)\n    total = correct = false_alarms = 0\n    with torch.no_grad():\n        for images, labels in test_loader:\n            images, labels = images.to(device), labels.to(device)\n            outputs = model(images)\n            _, predicted = torch.max(outputs, 1)\n            correct += (predicted == labels).sum().item()\n            total += labels.size(0)\n            # Counting false alarms: Predicted as H (class 0) when actually N (class 1)\n            false_alarms += ((predicted == 0) & (labels == 1)).sum().item()\n\n    accuracy = correct / total * 100\n    return accuracy, false_alarms'

In [18]:
def evaluate_model(model, test_loader):
    model.eval()
    model.to(device)
    total_hotspots = correct_hotspots = false_alarms = 0
    # 使用tqdm显示评估进度
    with torch.no_grad(), tqdm(total=len(test_loader), desc='Evaluating', unit='batch') as pbar:
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            # 更新正确预测的热点计数
            correct_hotspots += ((predicted == 0) & (labels == 0)).sum().item()
            # 更新总热点计数
            total_hotspots += (labels == 0).sum().item()
            # 计算误报：预测为H类（热点，class 0）但实际为N类（无热点，class 1）
            false_alarms += ((predicted == 0) & (labels == 1)).sum().item()
            pbar.update()

    hotspot_accuracy = (correct_hotspots / total_hotspots * 100) if total_hotspots > 0 else 0
    return hotspot_accuracy, false_alarms

In [19]:
train_loader, test_loader = load_datasets(r'dataset/train', r'dataset/test', transform)
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_model(model, train_loader, criterion, optimizer)


Epoch 1/10: 100%|██████████| 572/572 [06:28<00:00,  1.47batch/s, loss=0.0188]  
Epoch 2/10: 100%|██████████| 572/572 [06:30<00:00,  1.47batch/s, loss=0.024]   
Epoch 3/10: 100%|██████████| 572/572 [06:28<00:00,  1.47batch/s, loss=0.0146]  
Epoch 4/10: 100%|██████████| 572/572 [06:29<00:00,  1.47batch/s, loss=0.0148]  
Epoch 5/10: 100%|██████████| 572/572 [06:30<00:00,  1.46batch/s, loss=0.00696] 
Epoch 6/10: 100%|██████████| 572/572 [06:29<00:00,  1.47batch/s, loss=0.0191]  
Epoch 7/10: 100%|██████████| 572/572 [06:30<00:00,  1.46batch/s, loss=0.0121]  
Epoch 8/10: 100%|██████████| 572/572 [06:29<00:00,  1.47batch/s, loss=0.0133]  
Epoch 9/10: 100%|██████████| 572/572 [06:32<00:00,  1.46batch/s, loss=0.00245] 
Epoch 10/10: 100%|██████████| 572/572 [06:32<00:00,  1.46batch/s, loss=0.00181] 


In [21]:
accuracy, false_alarms = evaluate_model(model, test_loader)
print(accuracy, false_alarms)
# 指定文件路径
file_path = 'results.txt'
# 打开文件以追加模式写入
with open(file_path, 'a') as file:
    # 将变量写入文件，每个变量占一行
    file.write(f"Accuracy: {accuracy}\n")
    file.write(f"False Alarms: {false_alarms}\n")


数据已追加到文件中。
