In [1]:
import torch
import timm
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Normalize, Resize, Compose
from sklearn.model_selection import train_test_split
from torchvision.datasets import ImageFolder
from collections import defaultdict
from torch.utils.data import random_split
import time
from torchinfo import summary
import numpy as np
import pandas as pd


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import random

class RandomOneTransform:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Randomly choose one transformation to apply
        transform = random.choice(self.transforms)
        return transform(img)

In [None]:
def save_model(model, filename="model_state_dict.pth"):
  torch.save(model, filename)
  torch.save(model.state_dict(), f"s_{filename}.pth")
    

In [None]:

def validate(device, model, val_loader):
    model.eval()
    num_classes = len(val_loader.dataset.classes) 
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=int)

    with torch.no_grad():
        correct = 0
        total = 0
        class_correct = {}
        class_total = {}
        false_positives = {}
        false_negatives = {}

        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            for label, prediction in zip(labels, predicted):
                confusion_matrix[label.item(), prediction.item()] += 1  # Update the confusion matrix
                if label == prediction:
                    class_correct[label.item()] = class_correct.get(label.item(), 0) + 1
                else:
                    false_negatives[label.item()] = false_negatives.get(label.item(), 0) + 1
                    false_positives[prediction.item()] = false_positives.get(prediction.item(), 0) + 1

                class_total[label.item()] = class_total.get(label.item(), 0) + 1

        # Calculate precision and recall
        precision_list = []
        recall_list = []

        for class_id in class_total.keys():
            tp = class_correct.get(class_id, 0)
            fp = false_positives.get(class_id, 0)
            fn = false_negatives.get(class_id, 0)

            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0

            precision_list.append(precision)
            recall_list.append(recall)

        # Calculate overall precision, recall
        overall_precision = sum(precision_list) / len(precision_list) if len(precision_list) > 0 else 0
        overall_recall = sum(recall_list) / len(recall_list) if len(recall_list) > 0 else 0
        accuracy = 100 * correct / total

        print(f'Accuracy: {accuracy:.2f}%')
        print(f'Precision: {overall_precision:.2f}')
        print(f'Recall: {overall_recall:.2f}')

        return accuracy, overall_precision, overall_recall, confusion_matrix
    

In [None]:
def train(device, model, train_loader, criterion, optimizer, num_epochs):
    model.to(device).float()
    model.train()
    train_losses = []
    start = time.time()
    
    for epoch in range(num_epochs):
        print(f'Start epoch {epoch+1}/{num_epochs}')
        running_loss = 0.0
        train_correct = 0
        train_total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)
        train_accuracy = train_correct / train_total
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {train_accuracy:.2f}')

    end = time.time()
    computation_time = end - start 
    print(f'Training completed in {(end - start):.2f} seconds')
    print(f'Training accuracy: {train_accuracy:.2f}')
    print('--------------------------------')

    return train_accuracy, train_losses, computation_time


In [4]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
 
model_name = 'shufflenet_v2_x1_0'
learning_rate = 0.001
batch_size = 64
is_pretrained = True
is_pretrained_str = "pretrained" if is_pretrained else "not_pretrained"
file_name = f"{model_name}-lr_{learning_rate}-batch_{batch_size}-pretrained_{is_pretrained_str}"

model = torch.hub.load('pytorch/vision:v0.10.0', 'shufflenet_v2_x1_0', pretrained=True)
    
num_classes = 8

for param in model.parameters():
    param.requires_grad = False

model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model.fc.requires_grad = True


criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr= learning_rate)

dataset_path_training = './dataset-tomatoes/train'

data_transforms = transforms.Compose([
    RandomOneTransform([
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.5),
        transforms.ColorJitter(contrast=0.5),
        transforms.ColorJitter(saturation=0.5),
        transforms.RandomRotation(45),
        transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0), ratio=(0.75, 1.33))
    ]),
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

data_transforms_validation_test = transforms.Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


num_epochs = 25
print("Training on:", device, ", pretrained:", is_pretrained_str)
print(f"Number of parameters in the model: {sum(p.numel() for p in model.parameters())}")
print(f"Number of trainable parameters in the model: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

total_params = sum(p.numel() for p in model.parameters())
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

summary(model=model, input_size=(1, 3, 224, 224))


Training on: mps , pretrained: pretrained
Number of parameters in the model: 1261804
Number of trainable parameters in the model: 8200


Using cache found in /Users/lorenzoperinello/.cache/torch/hub/pytorch_vision_v0.10.0


Layer (type:depth-idx)                   Output Shape              Param #
ShuffleNetV2                             [1, 8]                    --
├─Sequential: 1-1                        [1, 24, 112, 112]         --
│    └─Conv2d: 2-1                       [1, 24, 112, 112]         (648)
│    └─BatchNorm2d: 2-2                  [1, 24, 112, 112]         (48)
│    └─ReLU: 2-3                         [1, 24, 112, 112]         --
├─MaxPool2d: 1-2                         [1, 24, 56, 56]           --
├─Sequential: 1-3                        [1, 116, 28, 28]          --
│    └─InvertedResidual: 2-4             [1, 116, 28, 28]          --
│    │    └─Sequential: 3-1              [1, 58, 28, 28]           (1,772)
│    │    └─Sequential: 3-2              [1, 58, 28, 28]           (5,626)
│    └─InvertedResidual: 2-5             [1, 116, 28, 28]          --
│    │    └─Sequential: 3-3              [1, 58, 28, 28]           (7,598)
│    └─InvertedResidual: 2-6             [1, 116, 28, 28]        

In [30]:
train_dataset = ImageFolder(root=dataset_path_training, transform=data_transforms)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

train_accuracy, train_losses, computation_time = train(device, model, train_loader, criterion, optimizer, num_epochs=25)

save_model(model, file_name)

Start epoch 1/25
Epoch 1/25, Loss: 1.6996, Accuracy: 0.38
Start epoch 2/25
Epoch 2/25, Loss: 1.3150, Accuracy: 0.65
Start epoch 3/25
Epoch 3/25, Loss: 1.0809, Accuracy: 0.73
Start epoch 4/25
Epoch 4/25, Loss: 0.9318, Accuracy: 0.77
Start epoch 5/25
Epoch 5/25, Loss: 0.8337, Accuracy: 0.79
Start epoch 6/25
Epoch 6/25, Loss: 0.7491, Accuracy: 0.81
Start epoch 7/25
Epoch 7/25, Loss: 0.6936, Accuracy: 0.83
Start epoch 8/25
Epoch 8/25, Loss: 0.6418, Accuracy: 0.84
Start epoch 9/25
Epoch 9/25, Loss: 0.6036, Accuracy: 0.85
Start epoch 10/25
Epoch 10/25, Loss: 0.5761, Accuracy: 0.86
Start epoch 11/25
Epoch 11/25, Loss: 0.5482, Accuracy: 0.86
Start epoch 12/25
Epoch 12/25, Loss: 0.5286, Accuracy: 0.87
Start epoch 13/25
Epoch 13/25, Loss: 0.5019, Accuracy: 0.87
Start epoch 14/25
Epoch 14/25, Loss: 0.4772, Accuracy: 0.88
Start epoch 15/25
Epoch 15/25, Loss: 0.4682, Accuracy: 0.88
Start epoch 16/25
Epoch 16/25, Loss: 0.4522, Accuracy: 0.88
Start epoch 17/25
Epoch 17/25, Loss: 0.4371, Accuracy: 0.8

In [31]:

validation_dataset = ImageFolder(root='./dataset-tomatoes/validation', transform=data_transforms_validation_test)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=True)

try:
    validation_accuracy, validation_precision, validation_recall, confusion_matrix = validate(device, model, validation_loader)
    print(f"Validation Accuracy: {validation_accuracy}%")
    print(f"Validation Precision: {validation_precision}")
    print(f"Validation Recall: {validation_recall}")
    print(f"Confusion Matrix:\n{confusion_matrix}")
except RuntimeError as e:
    print(f"Runtime error: {e}")

Accuracy: 92.09%
Precision: 0.91
Recall: 0.88
Validation Accuracy: 92.09354120267261%
Validation Precision: 0.9060383792801197
Validation Recall: 0.8814497618840745
Confusion Matrix:
[[214   1   4   2  10   4   0   0]
 [  9  90   9   3  12   3   2   3]
 [  0   8 188   1   5   5   2   1]
 [  4   1   2 109   1   3   0   0]
 [ 14   3   4   3 180   3   1   2]
 [  2   0   0   0   0 621   0   1]
 [  0   2   1   0   3   6  53   1]
 [  0   1   0   0   0   0   0 199]]


In [32]:

test_dataset = ImageFolder(root='./dataset-tomatoes/test', transform=data_transforms_validation_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

test_accuracy, test_precision, test_recall, test_matrix = validate(device, model, validation_loader)



Accuracy: 92.09%
Precision: 0.91
Recall: 0.88
