In [22]:


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score

import os

In [5]:
# Define your custom dataset class
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transform=None):
        self.data = ImageFolder(data_dir, transform=transform)
        self.classes = self.data.classes

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]
        return img, label

In [24]:
# Set up data transformations
transform = transforms.Compose([
    transforms.Resize((640, 640)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


In [25]:
data_dir = '../Roadviewer/Dataset'
# Load your custom dataset
train_data = CustomDataset(data_dir + '/train', transform=transform)
test_data = CustomDataset(data_dir + '/test', transform=transform)
val_data = CustomDataset(data_dir + '/valid', transform=transform)


In [26]:
# Create data loaders
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [27]:
# Define ResNet model
model = resnet18(pretrained=True)
num_classes = len(train_data.classes)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Set up optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


In [28]:
# Training loop
num_epochs = 100
save_interval = 10
best_val_accuracy = 0.0
best_epoch = 0
best_f1_socre = 0
save_path = r'../Roadviewer/TrainingRes'

for epoch in range(num_epochs):
    train_loss = 0.0
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0.0
    
    correct = 0
    total = 0
    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            true_labels.extend(labels.tolist())
            predicted_labels.extend(predicted.tolist())

    val_accuracy = 100 * correct / total
    precision = precision_score(true_labels, predicted_labels, average='binary')
    recall = recall_score(true_labels, predicted_labels, average='binary')
    f1 = f1_score(true_labels, predicted_labels, average='binary')

    print(f"Epoch [{epoch+1}/{num_epochs}]: Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")
    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")

    # Save best model checkpoint
    if f1 > best_f1_socre:
        best_f1_socre = f1
        best_epoch = epoch
        torch.save(model.state_dict(), os.path.join(save_path,'resnet_custom_model_best.pth'))


    torch.save(model.state_dict(), os.path.join(save_path,f'last_{epoch}epoch.pth'))
    print(f"Saved model checkpoint at epoch {epoch+1} to {save_path}")

# Final save after all epochs
torch.save(model.state_dict(), 'resnet_custom_model_final.pth')
print(f"Best model achieved at epoch {best_epoch+1} with validation accuracy {best_val_accuracy:.2f}%")

Epoch [1/100]: Val Loss: 3.4774, Val Acc: 81.68%
Precision: 0.6364, Recall: 0.4565, F1-score: 0.5316
Saved model checkpoint at epoch 1 to ../Roadviewer/TrainingRes
Epoch [2/100]: Val Loss: 3.4349, Val Acc: 81.19%
Precision: 0.6429, Recall: 0.3913, F1-score: 0.4865
Saved model checkpoint at epoch 2 to ../Roadviewer/TrainingRes


KeyboardInterrupt: 