In [2]:
import torch
import timm
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Normalize, Resize, Compose
from sklearn.model_selection import train_test_split
from torchvision.datasets import ImageFolder
from collections import defaultdict
from torch.utils.data import random_split
import time
from torchinfo import summary
import numpy as np
import pandas as pd

In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import random

class RandomOneTransform:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        transform = random.choice(self.transforms)
        return transform(img)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def save_model(model, filename="model_state_dict.pth"):
  torch.save(model, filename)
  torch.save(model.state_dict(), f"s_{filename}.pth")
    

In [3]:

def validate(device, model, val_loader):
    model.eval()
    num_classes = len(val_loader.dataset.classes)  
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=int)

    with torch.no_grad():
        correct = 0
        total = 0
        class_correct = {}
        class_total = {}
        false_positives = {}
        false_negatives = {}

        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            for label, prediction in zip(labels, predicted):
                confusion_matrix[label.item(), prediction.item()] += 1  # Update the confusion matrix
                if label == prediction:
                    class_correct[label.item()] = class_correct.get(label.item(), 0) + 1
                else:
                    false_negatives[label.item()] = false_negatives.get(label.item(), 0) + 1
                    false_positives[prediction.item()] = false_positives.get(prediction.item(), 0) + 1

                class_total[label.item()] = class_total.get(label.item(), 0) + 1

        # Calculate precision and recall
        precision_list = []
        recall_list = []

        for class_id in class_total.keys():
            tp = class_correct.get(class_id, 0)
            fp = false_positives.get(class_id, 0)
            fn = false_negatives.get(class_id, 0)

            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0

            precision_list.append(precision)
            recall_list.append(recall)

        # Calculate overall precision, recall
        overall_precision = sum(precision_list) / len(precision_list) if len(precision_list) > 0 else 0
        overall_recall = sum(recall_list) / len(recall_list) if len(recall_list) > 0 else 0
        accuracy = 100 * correct / total

        print(f'Accuracy: {accuracy:.2f}%')
        print(f'Precision: {overall_precision:.2f}')
        print(f'Recall: {overall_recall:.2f}')

        return accuracy, overall_precision, overall_recall, confusion_matrix
    

In [5]:
def train(device, model, train_loader, criterion, optimizer, num_epochs):
    model.to(device).float()
    model.train()
    train_losses = []
    start = time.time()
    
    for epoch in range(num_epochs):
        print(f'Start epoch {epoch+1}/{num_epochs}')
        running_loss = 0.0
        train_correct = 0
        train_total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)
        train_accuracy = train_correct / train_total
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {train_accuracy:.2f}')

    end = time.time()
    computation_time = end - start 
    print(f'Training completed in {(end - start):.2f} seconds')
    print(f'Training accuracy: {train_accuracy:.2f}')
    print('--------------------------------')

    return train_accuracy, train_losses, computation_time


In [6]:
device = "mps" if torch.backends.mps.is_available() else "cpu"


model_name = 'efficientnetv2_rw_s'
learning_rate = 0.001
batch_size = 64
is_pretrained = True
is_pretrained_str = "pretrained" if is_pretrained else "not_pretrained"
file_name = f"{model_name}-lr_{learning_rate}-batch_{batch_size}-pretrained_{is_pretrained_str}"

model = timm.create_model(model_name, pretrained=True).to(device)
for param in model.parameters():
    param.requires_grad = False
    
num_classes = 8
model.classifier = nn.Linear(model.classifier.in_features, num_classes).to(device)
print("Default input size for efficientnetv2_rw_s:", model.default_cfg['input_size'])


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr= learning_rate)
dataset_path_training = '../dataset-tomatoes/train'


data_transforms = transforms.Compose([
    RandomOneTransform([
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.5),
        transforms.ColorJitter(contrast=0.5),
        transforms.ColorJitter(saturation=0.5),
        transforms.RandomRotation(45),
        transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0), ratio=(0.75, 1.33))
    ]),
    Resize((288, 288)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

data_transforms_validation_test = transforms.Compose([
    Resize((288, 288)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


num_epochs = 25

print("Training on:", device, ", pretrained:", is_pretrained_str)
print(f"Number of parameters in the model: {sum(p.numel() for p in model.parameters())}")
print(f"Number of trainable parameters in the model: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

total_params = sum(p.numel() for p in model.parameters())
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

summary(model=model, input_size=(1, 3, 224, 224))


Default input size for efficientnetv2_rw_s: (3, 288, 288)
Training on: mps , pretrained: pretrained
Number of parameters in the model: 22162640
Number of trainable parameters in the model: 14344


Layer (type:depth-idx)                        Output Shape              Param #
EfficientNet                                  [1, 8]                    --
├─Conv2d: 1-1                                 [1, 24, 112, 112]         (648)
├─BatchNormAct2d: 1-2                         [1, 24, 112, 112]         48
│    └─Identity: 2-1                          [1, 24, 112, 112]         --
│    └─SiLU: 2-2                              [1, 24, 112, 112]         --
├─Sequential: 1-3                             [1, 272, 7, 7]            --
│    └─Sequential: 2-3                        [1, 24, 112, 112]         --
│    │    └─EdgeResidual: 3-1                 [1, 24, 112, 112]         (5,856)
│    │    └─EdgeResidual: 3-2                 [1, 24, 112, 112]         (5,856)
│    └─Sequential: 2-4                        [1, 48, 56, 56]           --
│    │    └─EdgeResidual: 3-3                 [1, 48, 56, 56]           (25,632)
│    │    └─EdgeResidual: 3-4                 [1, 48, 56, 56]           (92,

In [7]:
train_dataset = ImageFolder(root=dataset_path_training, transform=data_transforms)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Run training and validation
train_accuracy, train_losses, computation_time = train(device, model, train_loader, criterion, optimizer, num_epochs=25)

save_model(model, file_name)

Start epoch 1/25
Epoch 1/25, Loss: 0.9006, Accuracy: 0.75
Start epoch 2/25
Epoch 2/25, Loss: 0.5051, Accuracy: 0.86
Start epoch 3/25
Epoch 3/25, Loss: 0.4166, Accuracy: 0.88
Start epoch 4/25
Epoch 4/25, Loss: 0.3691, Accuracy: 0.89
Start epoch 5/25
Epoch 5/25, Loss: 0.3455, Accuracy: 0.90
Start epoch 6/25
Epoch 6/25, Loss: 0.3272, Accuracy: 0.90
Start epoch 7/25
Epoch 7/25, Loss: 0.3076, Accuracy: 0.91
Start epoch 8/25
Epoch 8/25, Loss: 0.2960, Accuracy: 0.91
Start epoch 9/25
Epoch 9/25, Loss: 0.2830, Accuracy: 0.91
Start epoch 10/25
Epoch 10/25, Loss: 0.2802, Accuracy: 0.92
Start epoch 11/25
Epoch 11/25, Loss: 0.2665, Accuracy: 0.92
Start epoch 12/25
Epoch 12/25, Loss: 0.2627, Accuracy: 0.92
Start epoch 13/25
Epoch 13/25, Loss: 0.2517, Accuracy: 0.92
Start epoch 14/25
Epoch 14/25, Loss: 0.2601, Accuracy: 0.92
Start epoch 15/25
Epoch 15/25, Loss: 0.2394, Accuracy: 0.93
Start epoch 16/25
Epoch 16/25, Loss: 0.2393, Accuracy: 0.93
Start epoch 17/25
Epoch 17/25, Loss: 0.2467, Accuracy: 0.9

In [9]:
validation_dataset = ImageFolder(root='/Users/lorenzoperinello/Desktop/Uni/VCS/vcs-tomatoes/dataset-tomatoes/validation', transform=data_transforms_validation_test)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=True)

try:
    validation_accuracy, validation_precision, validation_recall, confusion_matrix = validate(device, model, validation_loader)
    print(f"Validation Accuracy: {validation_accuracy}%")
    print(f"Validation Precision: {validation_precision}")
    print(f"Validation Recall: {validation_recall}")
    print(f"Confusion Matrix:\n{confusion_matrix}")
except RuntimeError as e:
    print(f"Runtime error: {e}")

Accuracy: 92.20%
Precision: 0.89
Recall: 0.89
Validation Accuracy: 92.20489977728285%
Validation Precision: 0.8931997572971996
Validation Recall: 0.891759851677882
Confusion Matrix:
[[213   5   3   3   7   3   1   0]
 [  5  95   9   6  10   0   1   5]
 [  1  13 186   3   2   1   1   3]
 [  3   1   2 113   1   0   0   0]
 [  8  11   2   3 177   1   2   6]
 [  4   2   0   0   0 617   0   1]
 [  2   0   1   2   4   0  56   1]
 [  0   0   0   1   0   0   0 199]]


In [9]:
test_dataset = ImageFolder(root='../dataset-tomatoes/test', transform=data_transforms_validation_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

test_accuracy, test_precision, test_recall, test_matrix = validate(device, model, validation_loader)

Accuracy: 92.32%
Precision: 0.90
Recall: 0.89
