In [5]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets, models
from torch.utils.data import DataLoader
from torch import nn, optim
import os


In [3]:
# Preprocessing shared by every split: resize to the fixed 224x224 input size
# used for ResNet34, convert PIL images to tensors, then standardize each channel
# with the mean/std values conventionally used with ImageNet-pretrained
# torchvision models.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [6]:
# Locate the train / validation / test split folders next to the kernel's
# working directory and echo each path for a quick sanity check.
# NOTE(review): os.getcwd() depends on where the kernel was launched; consider
# a configurable DATA_DIR constant instead.
data_dir = os.getcwd()
print("data_dir", data_dir)

train_file, val_file, test_file = (
    os.path.join(data_dir, split) for split in ("train", "validation", "test")
)
for label, path in (("train_file", train_file), ("val_file", val_file), ("test_file", test_file)):
    print(label, path)


data_dir C:\Users\etson\PycharmProjects\pythonProject
train_file C:\Users\etson\PycharmProjects\pythonProject\train
val_file C:\Users\etson\PycharmProjects\pythonProject\validation
test_file C:\Users\etson\PycharmProjects\pythonProject\test


In [7]:
# Build one ImageFolder dataset per split; every image goes through the same
# `transform` pipeline.
train_dataset = datasets.ImageFolder(train_file, transform=transform)
valid_dataset = datasets.ImageFolder(val_file, transform=transform)
test_dataset = datasets.ImageFolder(test_file, transform=transform)

# ImageFolder derives label indices from each split's own class subfolders.
# If validation/test is missing a class folder, or names one differently, the
# same integer label would mean different classes across splits and the
# reported accuracies would be silently wrong. Fail fast instead.
for split_name, split_dataset in (("validation", valid_dataset), ("test", test_dataset)):
    if split_dataset.class_to_idx != train_dataset.class_to_idx:
        raise ValueError(
            f"{split_name} classes {split_dataset.classes} do not match "
            f"train classes {train_dataset.classes}"
        )


In [8]:
# Wrap each split in a DataLoader. Only the training loader shuffles; the
# validation and test loaders keep a fixed order so evaluation passes are
# comparable run to run.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader, test_loader = (
    DataLoader(split_ds, batch_size=BATCH_SIZE, shuffle=False)
    for split_ds in (valid_dataset, test_dataset)
)


In [9]:
# Load ResNet34 with ImageNet-pretrained weights.
# `pretrained=True` is deprecated (torchvision >= 0.13) in favour of the
# `weights=` enum API; IMAGENET1K_V1 is the checkpoint the legacy flag maps to,
# so the loaded weights are unchanged. Older torchvision builds lack the enum,
# so fall back to the legacy flag there.
if hasattr(models, "ResNet34_Weights"):
    model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
else:
    model = models.resnet34(pretrained=True)


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to C:\Users\etson/.cache\torch\hub\checkpoints\resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:08<00:00, 10.2MB/s]


In [10]:
# Replace the final fully connected layer so the head emits one logit per class
# found in the training folder (2 for a binary dataset). Deriving the count from
# the data keeps the head and CrossEntropyLoss consistent if classes are added.
num_classes = len(train_dataset.classes)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)


In [11]:
# Pick the first CUDA device when one is visible, otherwise fall back to the
# CPU, and move the model's parameters there.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model = model.to(device)


In [12]:
# Cross-entropy on the raw logits, optimized with momentum SGD over every model
# parameter (the whole network is fine-tuned; nothing is frozen here).
LEARNING_RATE = 0.001
MOMENTUM = 0.9

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)


In [13]:
def evaluate_model(model, data_loader, device=None):
    """Compute top-1 classification accuracy of ``model`` over ``data_loader``.

    Switches the model to eval mode and runs inference under ``torch.no_grad()``.
    The model is left in eval mode on return.

    Args:
        model: classifier expected to return per-class scores with the class
            axis at dim 1; the index of the max score is the prediction.
        data_loader: iterable yielding ``(inputs, labels)`` batches with
            integer class labels.
        device: where each batch is moved before the forward pass. Defaults to
            the device of the model's first parameter (CPU if it has none), so
            the function does not depend on a notebook-global ``device``.

    Returns:
        0-dim float64 tensor holding ``correct / total``.

    Raises:
        ValueError: if ``data_loader`` yields no samples (accuracy is undefined).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()  # disable dropout / use running batch-norm statistics
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("evaluate_model: data_loader yielded no samples")

    return torch.tensor(correct / total, dtype=torch.float64)


In [14]:
def train_model(model, criterion, optimizer, train_loader, valid_loader, num_epochs=5, device=None):
    """Fine-tune ``model`` for ``num_epochs`` epochs, printing one summary per epoch.

    Each epoch runs one optimisation pass over ``train_loader`` and then scores
    the model on ``valid_loader`` via ``evaluate_model``. The printed line
    reports mean training loss, training accuracy and validation accuracy.

    Args:
        model: classifier expected to return per-class scores with the class
            axis at dim 1.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to
            return the batch *mean* (as ``nn.CrossEntropyLoss`` does by default).
        optimizer: optimizer over ``model``'s parameters.
        train_loader: yields ``(inputs, labels)`` training batches.
        valid_loader: yields ``(inputs, labels)`` validation batches.
        num_epochs: number of passes over the training data.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter (CPU if it has none), so the function does not
            depend on a notebook-global ``device``.

    Returns:
        None. ``model`` and ``optimizer`` are updated in place.

    Raises:
        ValueError: if ``train_loader`` yields no samples in an epoch.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        # evaluate_model switches the model to eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()

        running_loss = 0.0
        running_corrects = 0
        seen = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = inputs.size(0)
            _, preds = torch.max(outputs, 1)
            # Undo the per-batch mean so the epoch average is weighted by samples.
            running_loss += loss.item() * batch_size
            running_corrects += (preds == labels).sum().item()
            seen += batch_size

        if seen == 0:
            raise ValueError("train_model: train_loader yielded no samples")

        # Average over the samples actually processed (differs from
        # len(dataset) when the loader uses drop_last=True).
        epoch_loss = running_loss / seen
        epoch_acc = running_corrects / seen

        # Evaluate on validation set
        valid_acc = evaluate_model(model, valid_loader)

        print(f'Epoch {epoch+1}/{num_epochs} Train Loss: {epoch_loss:.4f} Train Acc: {epoch_acc:.4f} Valid Acc: {valid_acc:.4f}')


In [ ]:
# Fine-tune on the training split, reporting validation accuracy after each epoch.
NUM_EPOCHS = 5
train_model(
    model,
    criterion,
    optimizer,
    train_loader=train_loader,
    valid_loader=valid_loader,
    num_epochs=NUM_EPOCHS,
)


KeyboardInterrupt: 

In [ ]:
# One-off final score on the held-out test split; test_loader is not used
# anywhere during training or model selection.
test_acc = evaluate_model(model, test_loader)
test_report = f'Test Accuracy: {test_acc:.4f}'
print(test_report)
