In [1]:
import torch
import timm
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor, Normalize, Resize, Compose
from sklearn.model_selection import train_test_split
from torchvision.datasets import ImageFolder
from collections import defaultdict
from torch.utils.data import random_split
import time
from torchinfo import summary
import numpy as np
import pandas as pd


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import random

class RandomOneTransform:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        # Randomly choose one transformation to apply
        transform = random.choice(self.transforms)
        return transform(img)

In [3]:
def save_model(model, filename="model_state_dict.pth"):
  torch.save(model, filename)
  torch.save(model.state_dict(), f"s_{filename}.pth")
    

In [4]:

def validate(device, model, val_loader):
    model.eval()
    num_classes = len(val_loader.dataset.classes)  # Assuming dataset classes are accessible like this
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=int)

    with torch.no_grad():
        correct = 0
        total = 0
        class_correct = {}
        class_total = {}
        false_positives = {}
        false_negatives = {}

        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            for label, prediction in zip(labels, predicted):
                confusion_matrix[label.item(), prediction.item()] += 1  # Update the confusion matrix
                if label == prediction:
                    class_correct[label.item()] = class_correct.get(label.item(), 0) + 1
                else:
                    false_negatives[label.item()] = false_negatives.get(label.item(), 0) + 1
                    false_positives[prediction.item()] = false_positives.get(prediction.item(), 0) + 1

                class_total[label.item()] = class_total.get(label.item(), 0) + 1

        # Calculate precision and recall
        precision_list = []
        recall_list = []

        for class_id in class_total.keys():
            tp = class_correct.get(class_id, 0)
            fp = false_positives.get(class_id, 0)
            fn = false_negatives.get(class_id, 0)

            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0

            precision_list.append(precision)
            recall_list.append(recall)

        # Calculate overall precision, recall
        overall_precision = sum(precision_list) / len(precision_list) if len(precision_list) > 0 else 0
        overall_recall = sum(recall_list) / len(recall_list) if len(recall_list) > 0 else 0
        accuracy = 100 * correct / total

        print(f'Accuracy: {accuracy:.2f}%')
        print(f'Precision: {overall_precision:.2f}')
        print(f'Recall: {overall_recall:.2f}')

        return accuracy, overall_precision, overall_recall, confusion_matrix
    

In [5]:
def train(device, model, train_loader, criterion, optimizer, num_epochs):
    model.to(device).float()
    model.train()
    train_losses = []
    start = time.time()
    
    for epoch in range(num_epochs):
        print(f'Start epoch {epoch+1}/{num_epochs}')
        running_loss = 0.0
        train_correct = 0
        train_total = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)
        train_accuracy = train_correct / train_total
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {train_accuracy:.2f}')

    end = time.time()
    computation_time = end - start 
    print(f'Training completed in {(end - start):.2f} seconds')
    print(f'Training accuracy: {train_accuracy:.2f}')
    print('--------------------------------')

    return train_accuracy, train_losses, computation_time


In [6]:
# Specify the device for training
device = "mps" if torch.backends.mps.is_available() else "cpu"


# Load the pre-trained MobileNetV3 model
 
model_name = 'shufflenet_v2_x1_0'
learning_rate = 0.001
batch_size = 64
is_pretrained = True
is_pretrained_str = "pretrained" if is_pretrained else "not_pretrained"
file_name = f"{model_name}-lr_{learning_rate}-batch_{batch_size}-pretrained_{is_pretrained_str}"

model = torch.hub.load('pytorch/vision:v0.10.0', 'shufflenet_v2_x1_0', pretrained=True)
    
num_classes = 8

for param in model.parameters():
    param.requires_grad = False

model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
model.fc.requires_grad = True


# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr= learning_rate)

# Define path to PlantVillage dataset
dataset_path_training = '../dataset-tomatoes/train'

# Load the PlantVillage dataset with appropriate transforms

data_transforms = transforms.Compose([
    RandomOneTransform([
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.5),
        transforms.ColorJitter(contrast=0.5),
        transforms.ColorJitter(saturation=0.5),
        transforms.RandomRotation(45),
        transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0), ratio=(0.75, 1.33))
    ]),
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

data_transforms_validation_test = transforms.Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


# Ebable training mode

# Fine-tune the model
num_epochs = 25

for name, param in model.named_parameters():
    if name.startswith('conv'):
        param.requires_grad = True

print("Training on:", device, ", pretrained:", is_pretrained_str)
print(f"Number of parameters in the model: {sum(p.numel() for p in model.parameters())}")
print(f"Number of trainable parameters in the model: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

total_params = sum(p.numel() for p in model.parameters())
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(total_params)

for name, param in model.named_parameters():
    print(name, param.requires_grad)

Training on: mps , pretrained: pretrained
Number of parameters in the model: 1261804
Number of trainable parameters in the model: 486080
1261804
conv1.0.weight True
conv1.1.weight True
conv1.1.bias True
stage2.0.branch1.0.weight False
stage2.0.branch1.1.weight False
stage2.0.branch1.1.bias False
stage2.0.branch1.2.weight False
stage2.0.branch1.3.weight False
stage2.0.branch1.3.bias False
stage2.0.branch2.0.weight False
stage2.0.branch2.1.weight False
stage2.0.branch2.1.bias False
stage2.0.branch2.3.weight False
stage2.0.branch2.4.weight False
stage2.0.branch2.4.bias False
stage2.0.branch2.5.weight False
stage2.0.branch2.6.weight False
stage2.0.branch2.6.bias False
stage2.1.branch2.0.weight False
stage2.1.branch2.1.weight False
stage2.1.branch2.1.bias False
stage2.1.branch2.3.weight False
stage2.1.branch2.4.weight False
stage2.1.branch2.4.bias False
stage2.1.branch2.5.weight False
stage2.1.branch2.6.weight False
stage2.1.branch2.6.bias False
stage2.2.branch2.0.weight False
stage2.2.bran

Using cache found in /Users/lorenzoperinello/.cache/torch/hub/pytorch_vision_v0.10.0


In [7]:
train_dataset = ImageFolder(root=dataset_path_training, transform=data_transforms)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Run training and validation
train_accuracy, train_losses, computation_time = train(device, model, train_loader, criterion, optimizer, num_epochs=25)

save_model(model, file_name)

Start epoch 1/25
Epoch 1/25, Loss: 0.7303, Accuracy: 0.78
Start epoch 2/25
Epoch 2/25, Loss: 0.2190, Accuracy: 0.93
Start epoch 3/25
Epoch 3/25, Loss: 0.1650, Accuracy: 0.95
Start epoch 4/25
Epoch 4/25, Loss: 0.1401, Accuracy: 0.95
Start epoch 5/25
Epoch 5/25, Loss: 0.1170, Accuracy: 0.96
Start epoch 6/25
Epoch 6/25, Loss: 0.1054, Accuracy: 0.97
Start epoch 7/25
Epoch 7/25, Loss: 0.1064, Accuracy: 0.97
Start epoch 8/25
Epoch 8/25, Loss: 0.0930, Accuracy: 0.97
Start epoch 9/25
Epoch 9/25, Loss: 0.0770, Accuracy: 0.98
Start epoch 10/25
Epoch 10/25, Loss: 0.0744, Accuracy: 0.98
Start epoch 11/25
Epoch 11/25, Loss: 0.0757, Accuracy: 0.98
Start epoch 12/25
Epoch 12/25, Loss: 0.0664, Accuracy: 0.98
Start epoch 13/25
Epoch 13/25, Loss: 0.0679, Accuracy: 0.98
Start epoch 14/25
Epoch 14/25, Loss: 0.0728, Accuracy: 0.98
Start epoch 15/25
Epoch 15/25, Loss: 0.0711, Accuracy: 0.98
Start epoch 16/25
Epoch 16/25, Loss: 0.0640, Accuracy: 0.98
Start epoch 17/25
Epoch 17/25, Loss: 0.0586, Accuracy: 0.9

In [8]:

validation_dataset = ImageFolder(root='../dataset-tomatoes/validation', transform=data_transforms_validation_test)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=True)

try:
    validation_accuracy, validation_precision, validation_recall, confusion_matrix = validate(device, model, validation_loader)
    print(f"Validation Accuracy: {validation_accuracy}%")
    print(f"Validation Precision: {validation_precision}")
    print(f"Validation Recall: {validation_recall}")
    print(f"Confusion Matrix:\n{confusion_matrix}")
except RuntimeError as e:
    print(f"Runtime error: {e}")

Accuracy: 97.44%
Precision: 0.96
Recall: 0.96
Validation Accuracy: 97.43875278396436%
Validation Precision: 0.9641523085455737
Validation Recall: 0.9644471327255658
Confusion Matrix:
[[224   0   3   0   7   0   1   0]
 [  4 115   2   1   9   0   0   0]
 [  0   2 205   1   0   0   2   0]
 [  0   0   0 118   1   0   0   1]
 [  3   1   1   2 202   1   0   0]
 [  0   0   0   0   0 623   1   0]
 [  0   1   0   0   0   0  64   1]
 [  0   0   0   0   0   0   1 199]]


In [9]:

test_dataset = ImageFolder(root='../dataset-tomatoes/test', transform=data_transforms_validation_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

test_accuracy, test_precision, test_recall, test_matrix = validate(device, model, validation_loader)



Accuracy: 97.44%
Precision: 0.96
Recall: 0.96


In [10]:


csv_path = '../results.csv'
# append to a csv model learning_rate	batch	accuracy (Tr)	precision (Tr)	accuracy (Va)	precision (Va)	accuracy (Te)	precision (Te)	time (s)
# check if the file exists

total_params = sum(p.numel() for p in model.parameters())
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

df = pd.read_csv(csv_path)
row = pd.DataFrame({
    'model': [model_name],
    'params': [total_params],
    'tr_params': [total_trainable_params],
    'learning_rate': [learning_rate],
    'batch': [batch_size],
    'accuracy_(Tr)': [train_accuracy],
    'accuracy_(Va)': [validation_accuracy],
    'precision_(Va)': [validation_precision],
    'recall_(Va)': [validation_recall],
    'accuracy_(Te)': [test_accuracy],
    'precision_(Te)': [test_precision],
    'recall_(Te)': [test_recall],
    'time_(s)': [computation_time]
}, index=[0])
print(df)

# add the row to the dataframe csv file
df = pd.concat([df, row], ignore_index=False)


df.to_csv(csv_path, index=False)





                          model     params  tr_params  learning_rate  batch  \
0          vit_tiny_patch16_224    5525960     149576          0.001     64   
1                         vgg16  134293320   16814088          0.001     64   
2  swin_tiny_patch4_window7_224   27525506    2366216          0.001     64   

   accuracy_(Tr)  accuracy_(Va)  precision_(Va)  recall_(Va)  accuracy_(Te)  \
0       0.921827      92.873051        0.910795     0.892480      92.873051   
1       0.979658      95.322940        0.941256     0.925700      95.322940   
2       0.978756      96.380846        0.954280     0.947698      96.380846   

   precision_(Te)  recall_(Te)     time_(s)  
0        0.910795     0.892480  1558.915665  
1        0.941256     0.925700  4195.398219  
2        0.954280     0.947698  4775.044521  
