In [7]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import os

In [8]:
def get_dataloaders(data_dir, batch_size=32, image_size=224, num_workers=0):
    """Build train/val/test DataLoaders from an ImageFolder-style split directory.

    ``data_dir`` must contain ``train``, ``val`` and ``test`` subdirectories,
    each laid out as ``<split>/<class_name>/<image files>``.

    Args:
        data_dir: root directory holding the three split folders.
        batch_size: samples per batch for all three loaders.
        image_size: side length (pixels) every image is resized to (square).
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        ``(train_loader, val_loader, test_loader, num_classes)``, where
        ``num_classes`` is the number of class folders found under ``train``.
        Only the train loader shuffles.

    Raises:
        ValueError: if val/test do not have exactly the same class-to-index
            mapping as train. ImageFolder derives label indices from the
            folders present in each split, so a missing or extra class folder
            would otherwise silently shift the labels of that split.
    """
    # One deterministic preprocessing pipeline for every split (no augmentation).
    # The mean/std are the commonly used ImageNet statistics, matching the
    # ImageNet-pretrained backbone used later in this notebook.
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = datasets.ImageFolder(os.path.join(data_dir, "train"), transform)
    val_dataset   = datasets.ImageFolder(os.path.join(data_dir, "val"), transform)
    test_dataset  = datasets.ImageFolder(os.path.join(data_dir, "test"), transform)

    # Guard against label misalignment between splits.
    for split_name, split_dataset in (("val", val_dataset), ("test", test_dataset)):
        if split_dataset.class_to_idx != train_dataset.class_to_idx:
            raise ValueError(
                f"Class folders in '{split_name}' do not match 'train': "
                f"{sorted(split_dataset.class_to_idx)} vs "
                f"{sorted(train_dataset.class_to_idx)}"
            )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=num_workers)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=num_workers)

    return train_loader, val_loader, test_loader, len(train_dataset.classes)

In [9]:
# Root folder containing the train/val/test ImageFolder splits.
# Set the FARM_INSECTS_DATA_DIR environment variable to point elsewhere
# instead of editing this machine-specific default path.
data_dir = os.environ.get(
    "FARM_INSECTS_DATA_DIR",
    r"D:\STAI_PROJECT\github\Dangerous-Farm-Insects-Classification-main\data\processed\farm_insects\splits",
)
train_loader, val_loader, test_loader, num_classes = get_dataloaders(data_dir)

In [10]:
# Seed torch's global RNG (torch.manual_seed seeds CPU and CUDA) before any
# random operation. This fixes the random init of the new classifier head
# below. DataLoader shuffling without an explicit generator also draws from
# this RNG, so a fresh Restart & Run All is repeatable (up to
# nondeterministic CUDA kernels).
SEED = 42
torch.manual_seed(SEED)

# ImageNet-pretrained GoogLeNet as the starting point.
model = models.googlenet(weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
# Disable auxiliary classifier outputs for simplicity.
# NOTE(review): this only flips the flag; it does not delete aux1/aux2 here.
# Recent torchvision builders may already drop them when aux_logits is not
# requested — confirm for the installed version.
model.aux_logits = False

# Replace the 1000-way ImageNet head with one sized for this dataset.
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

In [11]:
# Linear probe: freeze every pretrained parameter ...
for param in model.parameters():
    param.requires_grad = False

# ... then unfreeze only the final (new) layer.
for param in model.fc.parameters():
    param.requires_grad = True

criterion = nn.CrossEntropyLoss()
# Give the optimizer only the trainable (head) parameters.
optimizer = optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.001, momentum=0.9
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

num_epochs = 10


def run_epoch(model, loader, criterion, optimizer, device, train):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: network to run; put into train() or eval() mode here.
        loader: DataLoader yielding ``(inputs, labels)`` batches.
        criterion: loss applied to ``(logits, labels)``; assumed to return a
            batch-mean loss (the CrossEntropyLoss default).
        optimizer: zeroed and stepped once per batch when ``train`` is True;
            untouched otherwise.
        device: device every batch is moved to.
        train: True for a training pass (gradients + optimizer steps),
            False for an evaluation pass (autograd disabled).

    Returns:
        Sample-weighted mean loss and fraction of correct predictions, as
        Python floats. Both are 0.0 for an empty dataset instead of raising.
    """
    # NOTE(review): train mode also affects normalization/dropout layers. If
    # the backbone has BatchNorm (torchvision's GoogLeNet does — confirm),
    # its running statistics still update here even though its weights are
    # frozen. Put those layers in eval() if a fully fixed backbone is wanted.
    model.train(train)

    running_loss = 0.0
    running_corrects = 0
    num_samples = len(loader.dataset)

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            outputs = model(inputs)

            # GoogLeNet can return a tuple if aux_logits=True; keep main logits.
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = criterion(outputs, labels)
            preds = outputs.argmax(dim=1)

            if train:
                loss.backward()
                optimizer.step()

        # The loss is a batch mean: re-weight by batch size so the epoch value
        # is a per-sample mean even when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        # Accumulate as a Python int rather than a device tensor.
        running_corrects += (preds == labels).sum().item()

    if num_samples == 0:
        return 0.0, 0.0
    return running_loss / num_samples, running_corrects / num_samples


for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    print("-" * 10)

    for phase, loader in [('train', train_loader), ('val', val_loader)]:
        epoch_loss, epoch_acc = run_epoch(
            model, loader, criterion, optimizer, device, train=(phase == 'train')
        )
        print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

Epoch 1/10
----------
train Loss: 2.7068 Acc: 0.0833
val Loss: 2.6100 Acc: 0.1937
Epoch 2/10
----------
train Loss: 2.5253 Acc: 0.2448
val Loss: 2.4491 Acc: 0.3676
Epoch 3/10
----------
train Loss: 2.3482 Acc: 0.4054
val Loss: 2.3010 Acc: 0.4308
Epoch 4/10
----------
train Loss: 2.1896 Acc: 0.4757
val Loss: 2.1783 Acc: 0.4941
Epoch 5/10
----------
train Loss: 2.0511 Acc: 0.5302
val Loss: 2.0641 Acc: 0.5217
Epoch 6/10
----------
train Loss: 1.9481 Acc: 0.5768
val Loss: 1.9628 Acc: 0.5534
Epoch 7/10
----------
train Loss: 1.8471 Acc: 0.6065
val Loss: 1.8878 Acc: 0.5573
Epoch 8/10
----------
train Loss: 1.7716 Acc: 0.6095
val Loss: 1.8176 Acc: 0.5573
Epoch 9/10
----------
train Loss: 1.6925 Acc: 0.6412
val Loss: 1.7626 Acc: 0.5692
Epoch 10/10
----------
train Loss: 1.6323 Acc: 0.6264
val Loss: 1.6933 Acc: 0.6008


In [12]:
# Final, one-off evaluation of the trained model on the held-out test split.
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # Mirror the training loop: keep only the main logits if a tuple
        # (main + auxiliary outputs) is ever returned.
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Avoid a ZeroDivisionError when the test split is empty.
if total == 0:
    print("Test Accuracy: n/a (empty test set)")
else:
    print(f"Test Accuracy: {100 * correct / total:.2f}%")

Test Accuracy: 62.34%


In [13]:
# Persist only the learned weights (state_dict) of the final-epoch model;
# reload with model.load_state_dict(torch.load(checkpoint_path)).
checkpoint_path = "googlenet_finetuned_last.pth"
weights = model.state_dict()
torch.save(weights, checkpoint_path)