In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from cnn import SimpleCNN
import datetime

In [2]:
def get_mean_std():
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize images to 224x224
        transforms.ToTensor(),  # Convert images to tensors
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
    ])
    dataset = datasets.ImageFolder(root='archive', transform=transform)
    data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    channels_sum, channels_squared_sum, num_batches = 0, 0, 0

    for data, _ in data_loader:
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5

    return mean, std

mean, std = get_mean_std()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip with 50% probability
    transforms.RandomRotation(degrees=15),   # Random rotation within a range of 15 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Random changes in brightness, contrast, saturation, and hue
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

train_dataset = datasets.ImageFolder(root='archive', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_dataset = datasets.ImageFolder(root='testC', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)


n_class = len(train_dataset.classes)
model = SimpleCNN(num_classes=n_class)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

train_loss = []
train_accuracy = []


test_loss = []
test_accuracy = []

In [3]:
print("Training/Testing started...")
EPOCHS = 30
for epoch in range(EPOCHS):
    train_running_loss = 0.0
    train_correct_predictions = 0
    train_total_predictions = 0
    
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_running_loss += loss.item()
        
        _, predicted = torch.max(outputs, 1)
        train_correct_predictions += (predicted == labels).sum().item()
        train_total_predictions += labels.size(0)
    
    accuracy = train_correct_predictions / train_total_predictions * 100
    train_accuracy.append(accuracy)
    train_loss.append(train_running_loss)

    model.eval()
    with torch.no_grad():
        test_running_loss = 0.0
        test_correct_predictions = 0
        test_total_predictions = 0

        for inputs, labels in test_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_running_loss += loss.item()
            
            _, predicted = torch.max(outputs, 1)
            test_correct_predictions += (predicted == labels).sum().item()
            test_total_predictions += labels.size(0)
            
        test_loss.append(test_running_loss)
        accuracy = test_correct_predictions / test_total_predictions * 100
        print(f"Epoch: {epoch} Train Loss: {train_running_loss} Train Accuracy: {accuracy:.2f}%") 
        print(f"Epoch: {epoch} Test Loss: {test_running_loss} Test Accuracy: {accuracy:.2f}%\nxx")
        test_accuracy.append(accuracy)

Training/Testing started...
Epoch: 0 Train Loss: 57.37014329433441 Train Accuracy: 19.11%
Epoch: 0 Test Loss: 51.91836190223694 Test Accuracy: 19.11%

Epoch: 1 Train Loss: 48.07400929927826 Train Accuracy: 32.33%
Epoch: 1 Test Loss: 46.82590627670288 Test Accuracy: 32.33%



KeyboardInterrupt: 

In [None]:
# Creating subplots
fig, axs = plt.subplots(2, 1, figsize=(10, 8))
# Loss subplot
axs[0].plot(range(EPOCHS), train_loss, label='Train Loss')
axs[0].plot(range(EPOCHS), test_loss, label='Test Loss')
axs[0].set_title('Loss')
axs[0].set_xlabel('EPOCHS')
axs[0].set_ylabel('Loss')
axs[0].legend()
axs[0].grid(True)

# Accuracy subplot
axs[1].plot(range(EPOCHS), train_accuracy, label='Train Accuracy')
axs[1].plot(range(EPOCHS), test_accuracy, label='Test Accuracy')
axs[1].set_title('Accuracy')
axs[1].set_xlabel('EPOCHS')
axs[1].set_ylabel('Accuracy')
axs[1].legend()
axs[1].grid(True)

# Adjust layout
plt.tight_layout()
plt.show()

In [None]:
torch.save(model, f"{datetime.date.today()}_model.pth")