In [1]:
!pip install torchvision --index-url https://download.pytorch.org/whl/cu118

Looking in indexes: https://download.pytorch.org/whl/cu118


In [2]:
# Mount Google Drive at /content/drive; the dataset directory used later in
# this notebook lives under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import copy
import os
import time
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import precision_score, recall_score, f1_score
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models, transforms

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [5]:
# Training configuration
num_classes = 26  # Number of classes in the dataset
num_epochs = 50
batch_size = 32
learning_rate = 0.001

# Seed PyTorch's random number generator so that weight initialisation of the
# new classifier layer, data shuffling and random augmentations start from the
# same state on every run.
seed = 42
torch.manual_seed(seed)

In [6]:
# Per-channel mean / std applied by Normalize in both pipelines
# (the standard ImageNet statistics used with torchvision's pre-trained models).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]


def _build_transform(augment):
    """Return the preprocessing pipeline; `augment` adds random flip + rotation."""
    steps = [transforms.Resize((224, 224))]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.RandomRotation(15))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(imagenet_mean, imagenet_std))
    return transforms.Compose(steps)


# Augmentation is applied to the training split only; validation is deterministic.
data_transforms = {
    'train': _build_transform(augment=True),
    'val': _build_transform(augment=False),
}

In [9]:
# Dataset root on the mounted drive; replace with your own dataset path.
# It is expected to contain one sub-directory per split.
data_dir = '/content/drive/MyDrive/malevis_train_val_224x224'
train_dir, val_dir = (os.path.join(data_dir, split) for split in ('train', 'val'))

In [10]:
# One ImageFolder per split, each paired with the transform of the same name.
split_dirs = {'train': train_dir, 'val': val_dir}
image_datasets = {
    split: datasets.ImageFolder(root, transform=data_transforms[split])
    for split, root in split_dirs.items()
}

In [11]:
# One DataLoader per split; only the training split is shuffled.
dataloaders = {
    split: DataLoader(
        image_datasets[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=4,
    )
    for split in ('train', 'val')
}

# AlexNet initialised with torchvision's default pre-trained weights
model = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)

Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:01<00:00, 212MB/s]


In [12]:
# Freeze every parameter of the pre-trained network (feature extractor and
# classifier alike); the layers to be trained are re-enabled afterwards.
for frozen_param in model.parameters():
    frozen_param.requires_grad_(False)

In [13]:
# Replace the final classifier layer with one that outputs `num_classes` logits.
# A freshly constructed nn.Linear is trainable by default.
head_in_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(head_in_features, num_classes)

# Re-enable gradients for the last convolutional block (features[10:]) so it is fine-tuned
for conv_param in model.features[10:].parameters():
    conv_param.requires_grad_(True)

In [14]:
# Move model to device
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()

# Hand the optimizer only the parameters that are still trainable
# (requires_grad=True). Frozen weights never receive gradients, so they do not
# belong in its parameter groups. Rebuild the optimizer if more layers are
# unfrozen later.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=learning_rate)

In [15]:
# Fine-tune the model, evaluating on the validation split after every epoch.
# The weights of the epoch with the highest validation accuracy are snapshotted
# and restored once training finishes, so that `model` corresponds to the
# reported "best" metrics rather than to whatever the last epoch produced.

# Best validation metrics seen so far (selected by accuracy)
best_metrics = {
    'epoch': -1,
    'accuracy': 0.0,
    'precision': 0.0,
    'recall': 0.0,
    'f1_score': 0.0,
    'loss': float('inf')
}
# Copy of the weights that produced `best_metrics`; deep-copied because
# state_dict() returns references to the live parameters.
best_model_wts = copy.deepcopy(model.state_dict())

start_time = time.time()

# Training and validation
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    print('-' * 10)

    # Each epoch has a training and validation phase
    for phase in ['train', 'val']:
        is_train = phase == 'train'
        # train(True) enables dropout etc.; train(False) is equivalent to eval()
        model.train(is_train)

        running_loss = 0.0
        running_corrects = 0
        all_labels = []
        all_preds = []

        # Iterate over data
        for inputs, labels in dataloaders[phase]:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Track gradients only while training
            with torch.set_grad_enabled(is_train):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)

                # Backward + optimize only if in training phase
                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            # Statistics (loss is a batch mean, so weight it by the batch size)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()
            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

        num_samples = len(image_datasets[phase])
        epoch_loss = running_loss / num_samples
        epoch_acc = running_corrects / num_samples

        # Calculate precision, recall, and F1 score
        epoch_precision = precision_score(all_labels, all_preds, average='weighted')
        epoch_recall = recall_score(all_labels, all_preds, average='weighted')
        epoch_f1 = f1_score(all_labels, all_preds, average='weighted')

        print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
        print(f'{phase} Precision: {epoch_precision:.4f} Recall: {epoch_recall:.4f} F1 Score: {epoch_f1:.4f}')

        # Keep the metrics and the weights of the best epoch by validation accuracy
        if not is_train and epoch_acc > best_metrics['accuracy']:
            best_metrics = {
                'epoch': epoch + 1,
                'accuracy': epoch_acc,
                'precision': epoch_precision,
                'recall': epoch_recall,
                'f1_score': epoch_f1,
                'loss': epoch_loss
            }
            best_model_wts = copy.deepcopy(model.state_dict())

elapsed_time = time.time() - start_time
elapsed_mins, elapsed_secs = divmod(int(elapsed_time), 60)

# Restore the best-performing weights before the model is used for inference
model.load_state_dict(best_model_wts)

print('Training complete')
print(f'Training complete in {elapsed_mins}m {elapsed_secs}s')
print(f'Best val Epoch: {best_metrics["epoch"]}')
print(f'Best val Acc: {best_metrics["accuracy"]:.4f}')
print(f'Best val Loss: {best_metrics["loss"]:.4f}')
print(f'Best val Precision: {best_metrics["precision"]:.4f}')
print(f'Best val Recall: {best_metrics["recall"]:.4f}')
print(f'Best val F1 Score: {best_metrics["f1_score"]:.4f}')

Epoch 1/50
----------
train Loss: 0.7921 Acc: 0.7887
train Precision: 0.7829 Recall: 0.7887 F1 Score: 0.7848
val Loss: 1.0443 Acc: 0.7749
val Precision: 0.8396 Recall: 0.7749 F1 Score: 0.7738
Epoch 2/50
----------
train Loss: 0.4061 Acc: 0.8827
train Precision: 0.8822 Recall: 0.8827 F1 Score: 0.8821
val Loss: 0.9875 Acc: 0.7616
val Precision: 0.8495 Recall: 0.7616 F1 Score: 0.7584
Epoch 3/50
----------
train Loss: 0.3193 Acc: 0.9068
train Precision: 0.9067 Recall: 0.9068 F1 Score: 0.9066
val Loss: 1.0520 Acc: 0.7981
val Precision: 0.8588 Recall: 0.7981 F1 Score: 0.7986
Epoch 4/50
----------
train Loss: 0.2733 Acc: 0.9182
train Precision: 0.9182 Recall: 0.9182 F1 Score: 0.9181
val Loss: 0.8557 Acc: 0.8137
val Precision: 0.8672 Recall: 0.8137 F1 Score: 0.8140
Epoch 5/50
----------
train Loss: 0.2475 Acc: 0.9273
train Precision: 0.9269 Recall: 0.9273 F1 Score: 0.9270
val Loss: 0.9262 Acc: 0.7938
val Precision: 0.8597 Recall: 0.7938 F1 Score: 0.7916
Epoch 6/50
----------
train Loss: 0.2095

In [16]:
from PIL import Image

# Prediction function
def predict_image(image_path, model, class_names):
    """Classify a single image file and return the predicted class name.

    Args:
        image_path: Path of the image file to classify.
        model: Network used for inference. It is switched to eval mode and
            left in eval mode when the function returns.
        class_names: Sequence mapping a model output index to a class name.

    Returns:
        The element of ``class_names`` at the index of the highest model output.

    Note:
        The input tensor is moved to the notebook-level ``device``; the model
        is expected to already be on that device.
    """
    model.eval()

    # Same deterministic preprocessing as the validation pipeline
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Open inside a context manager so the file handle is closed promptly;
    # convert() returns a new, fully loaded image, so it stays usable afterwards.
    # Converting to RGB guarantees 3 channels.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    # Add the batch dimension and move the tensor to the target device
    image = transform(image).unsqueeze(0).to(device)

    # Make prediction without tracking gradients
    with torch.no_grad():
        outputs = model(image)
        _, preds = torch.max(outputs, 1)

    # Map the predicted index to its class name
    predicted_class = class_names[preds.item()]
    return predicted_class

# Demo: classify one image from the validation split.
# Class names are taken from the training dataset (index -> class name).
class_names = image_datasets['train'].classes
# Replace with the path to your image
image_path = '/content/drive/MyDrive/malevis_train_val_224x224/val/Neoreklami/c0fa7411c52094305ef8130c86eeafaf072111d4resized_image.png'
predicted_class = predict_image(
    image_path=image_path,
    model=model,
    class_names=class_names,
)
print(f'Predicted class: {predicted_class}')

Predicted class: Neoreklami
