In [1]:
import torch
import os
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split

In [2]:
# Location of the ImageFolder-style dataset (one sub-directory per class).
root_dir = "../forest_fire/Training and Validation"

# Fixed seed so the train/val split, the new fc layer's initialization and the
# shuffle order are the same on every Restart & Run All. (cuDNN kernels may
# still make GPU training not bit-exact.)
SEED = 42
torch.manual_seed(SEED)

# Image transformations: resize to the 224x224 input used by ImageNet ResNets,
# then apply the standard ImageNet mean/std normalization the pretrained
# weights expect.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = datasets.ImageFolder(root=root_dir, transform=transform)

# 75% train / 25% validation (3:1). Both subsets share `dataset`'s transform,
# so any augmentation added above would also be applied to validation images.
train_size = int(0.75 * len(dataset))
val_size = len(dataset) - train_size

# A dedicated seeded generator keeps the split independent of any other code
# that consumes the global RNG.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# DataLoaders: reshuffle the training set each epoch; keep validation order fixed.
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# ImageNet-pretrained ResNet-50. These are the same weights `pretrained=True`
# selected; that argument is deprecated since torchvision 0.13. The final fc
# layer is replaced by a fresh head with one output per class folder found.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
num_features = model.fc.in_features
num_classes = len(dataset.classes)
model.fc = torch.nn.Linear(num_features, num_classes)

# Sanity check of the split sizes.
print(f"Total images: {len(dataset)}")
print(f"Training images: {len(train_dataset)}")
print(f"Validation images: {len(val_dataset)}")




Total images: 1832
Training images: 1374
Validation images: 458


In [3]:
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm

# Train on the GPU when one is available; the model and every batch are moved there.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Cross-entropy on the fc logits. Adam is given every parameter, so the whole
# pretrained backbone is fine-tuned, not just the new head.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 20


def _run_epoch(loader, desc, train):
    """Run one full pass over `loader`.

    Args:
        loader: DataLoader yielding (images, labels) batches.
        desc: label shown on the tqdm progress bar.
        train: if True, put the model in train mode and take an optimizer
            step after every batch; if False, use eval mode with gradients
            disabled.

    Returns:
        (mean loss per sample, accuracy) over every sample in the pass.

    Raises:
        ValueError: if the loader yields no samples.

    Uses the notebook-level `model`, `criterion`, `optimizer` and `device`.
    """
    model.train(train)  # train(False) is exactly what model.eval() does
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    # Gradients are only tracked for the training pass (no_grad for validation).
    with torch.set_grad_enabled(train):
        for images, labels in tqdm(loader, desc=desc):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # The criterion averages over the batch. Re-weight by batch size so
            # the epoch figure is a true per-sample mean even when the last
            # batch is smaller.
            batch_n = labels.size(0)
            loss_sum += loss.item() * batch_n
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += batch_n

    if n_seen == 0:
        raise ValueError(f"{desc}: loader yielded no samples")
    return loss_sum / n_seen, n_correct / n_seen


# Training / validation: one training pass then one validation pass per epoch.
for epoch in range(num_epochs):
    avg_train_loss, train_accuracy = _run_epoch(
        train_loader, f"Training Epoch {epoch+1}/{num_epochs}", train=True
    )
    avg_val_loss, val_accuracy = _run_epoch(val_loader, "Validating", train=False)

    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, "
          f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")


Training Epoch 1/20: 100%|██████████| 86/86 [00:13<00:00,  6.17it/s]
Validating: 100%|██████████| 29/29 [00:02<00:00, 10.52it/s]


Epoch [1/20], Train Loss: 0.2393, Train Accuracy: 0.9178, Val Loss: 0.0837, Val Accuracy: 0.9738


Training Epoch 2/20: 100%|██████████| 86/86 [00:05<00:00, 15.58it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 22.96it/s]


Epoch [2/20], Train Loss: 0.1237, Train Accuracy: 0.9585, Val Loss: 0.2056, Val Accuracy: 0.9541


Training Epoch 3/20: 100%|██████████| 86/86 [00:05<00:00, 16.37it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 23.78it/s]


Epoch [3/20], Train Loss: 0.1211, Train Accuracy: 0.9607, Val Loss: 0.0558, Val Accuracy: 0.9782


Training Epoch 4/20: 100%|██████████| 86/86 [00:05<00:00, 16.16it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 18.56it/s]


Epoch [4/20], Train Loss: 0.1107, Train Accuracy: 0.9680, Val Loss: 0.1096, Val Accuracy: 0.9716


Training Epoch 5/20: 100%|██████████| 86/86 [00:05<00:00, 15.03it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 22.63it/s]


Epoch [5/20], Train Loss: 0.1111, Train Accuracy: 0.9592, Val Loss: 0.0848, Val Accuracy: 0.9651


Training Epoch 6/20: 100%|██████████| 86/86 [00:05<00:00, 16.39it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 22.39it/s]


Epoch [6/20], Train Loss: 0.0682, Train Accuracy: 0.9767, Val Loss: 0.0618, Val Accuracy: 0.9891


Training Epoch 7/20: 100%|██████████| 86/86 [00:05<00:00, 15.68it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 22.13it/s]


Epoch [7/20], Train Loss: 0.0690, Train Accuracy: 0.9825, Val Loss: 0.2350, Val Accuracy: 0.9498


Training Epoch 8/20: 100%|██████████| 86/86 [00:06<00:00, 14.06it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 23.44it/s]


Epoch [8/20], Train Loss: 0.1035, Train Accuracy: 0.9716, Val Loss: 0.0746, Val Accuracy: 0.9891


Training Epoch 9/20: 100%|██████████| 86/86 [00:05<00:00, 16.89it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 24.48it/s]


Epoch [9/20], Train Loss: 0.0830, Train Accuracy: 0.9709, Val Loss: 0.0984, Val Accuracy: 0.9651


Training Epoch 10/20: 100%|██████████| 86/86 [00:05<00:00, 16.51it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 24.73it/s]


Epoch [10/20], Train Loss: 0.1081, Train Accuracy: 0.9636, Val Loss: 0.0448, Val Accuracy: 0.9847


Training Epoch 11/20: 100%|██████████| 86/86 [00:05<00:00, 15.70it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 24.20it/s]


Epoch [11/20], Train Loss: 0.0507, Train Accuracy: 0.9803, Val Loss: 0.0721, Val Accuracy: 0.9847


Training Epoch 12/20: 100%|██████████| 86/86 [00:05<00:00, 15.70it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 22.95it/s]


Epoch [12/20], Train Loss: 0.0597, Train Accuracy: 0.9731, Val Loss: 0.0498, Val Accuracy: 0.9847


Training Epoch 13/20: 100%|██████████| 86/86 [00:05<00:00, 15.76it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 22.60it/s]


Epoch [13/20], Train Loss: 0.0730, Train Accuracy: 0.9803, Val Loss: 0.0614, Val Accuracy: 0.9869


Training Epoch 14/20: 100%|██████████| 86/86 [00:05<00:00, 16.54it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 24.62it/s]


Epoch [14/20], Train Loss: 0.0464, Train Accuracy: 0.9862, Val Loss: 0.0366, Val Accuracy: 0.9869


Training Epoch 15/20: 100%|██████████| 86/86 [00:05<00:00, 17.04it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 21.93it/s]


Epoch [15/20], Train Loss: 0.0358, Train Accuracy: 0.9884, Val Loss: 0.0377, Val Accuracy: 0.9847


Training Epoch 16/20: 100%|██████████| 86/86 [00:06<00:00, 14.32it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 21.95it/s]


Epoch [16/20], Train Loss: 0.0706, Train Accuracy: 0.9760, Val Loss: 0.0969, Val Accuracy: 0.9694


Training Epoch 17/20: 100%|██████████| 86/86 [00:05<00:00, 16.71it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 25.14it/s]


Epoch [17/20], Train Loss: 0.0624, Train Accuracy: 0.9782, Val Loss: 0.0582, Val Accuracy: 0.9891


Training Epoch 18/20: 100%|██████████| 86/86 [00:05<00:00, 15.79it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 24.17it/s]


Epoch [18/20], Train Loss: 0.0216, Train Accuracy: 0.9956, Val Loss: 0.0366, Val Accuracy: 0.9913


Training Epoch 19/20: 100%|██████████| 86/86 [00:04<00:00, 17.27it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 24.66it/s]


Epoch [19/20], Train Loss: 0.0279, Train Accuracy: 0.9891, Val Loss: 0.0412, Val Accuracy: 0.9891


Training Epoch 20/20: 100%|██████████| 86/86 [00:05<00:00, 16.68it/s]
Validating: 100%|██████████| 29/29 [00:01<00:00, 24.47it/s]

Epoch [20/20], Train Loss: 0.0152, Train Accuracy: 0.9956, Val Loss: 0.0453, Val Accuracy: 0.9891





In [4]:
# Persist only the learned parameters (the state_dict), not the whole module
# object; reloading therefore needs the same architecture to be built first.
model_save_path = "wildfire_resnet50.pth"
trained_weights = model.state_dict()
torch.save(trained_weights, model_save_path)
print(f"Model saved to {model_save_path}")


Model saved to wildfire_resnet50.pth


In [None]:
# Restore the fine-tuned weights written by the save cell into the ResNet-50
# built in the setup cell (its fc head must match the saved one).
model_load_path = "wildfire_resnet50.pth"

# Recomputed here so this cell also works when the training cell is skipped.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location lets a checkpoint saved from GPU tensors load on a CPU-only
# machine. weights_only restricts unpickling to tensors and plain containers,
# so a tampered .pth file cannot execute arbitrary code.
state_dict = torch.load(model_load_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)  # strict=True by default: raises on key mismatch
model = model.to(device)

# Inference mode (BatchNorm running stats, no dropout). The training loop calls
# model.train() itself, so resuming training is unaffected.
model.eval()
print(f"Model loaded from {model_load_path} onto {device}")