In [None]:
from torchvision import datasets, transforms
import os
from torchvision.datasets import ImageFolder
from PIL import Image
import random

In [85]:
# PRINT THE NUMBER OF IMAGES IN THE DATASETS 

dataset_path = './dataset-tomatoes'
train_path = os.path.join(dataset_path, 'train')
val_path = os.path.join(dataset_path, 'validation')
test_path = os.path.join(dataset_path, 'test')


train_dataset = ImageFolder(root=train_path)
val_dataset = ImageFolder(root=val_path)
test_dataset = ImageFolder(root=test_path)

print(f"Number of training images: {len(train_dataset)}")
print(f"Number of validation images: {len(val_dataset)}")
print(f"Number of test images: {len(test_dataset)}")

total = len(train_dataset) + len(val_dataset) + len(test_dataset)
print(f"Total number of images: {total}")



Number of training images: 14404
Number of validation images: 1796
Number of test images: 1796
Total number of images: 17996


In [None]:
path = './dataset-tomatoes'

full_dataset = ImageFolder(root=path)

print(len(full_dataset))

classes = full_dataset.classes

for i in range(len(classes)):
    print(f"Number of images in {classes[i]}: {len(os.listdir(os.path.join(path, classes[i])))}")
    

In [81]:
# === SPLITTING THE DATASET === 

import os
import shutil
import json
from random import shuffle

data_config = '''
{
    "Tomato___Bacterial_spot": {"total": 2406, "training_set": 1930, "validation_set": 235, "test_set": 241},
    "Tomato___Early_blight": {"total": 1214, "training_set": 966, "validation_set": 131, "test_set": 117},
    "Tomato___Late_blight": {"total": 2129, "training_set": 1710, "validation_set": 210, "test_set": 209},
    "Tomato___Leaf_Mold": {"total": 1244, "training_set": 999, "validation_set": 120, "test_set": 125},
    "Tomato___Septoria_leaf_spot": {"total": 2204, "training_set": 1771, "validation_set": 210, "test_set": 223},
    "Tomato___Tomato_mosaic_virus": {"total": 634, "training_set": 509, "validation_set": 66, "test_set": 59},
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus": {"total": 6179, "training_set": 4930, "validation_set": 624, "test_set": 625},
    "Tomato___healthy": {"total": 1986, "training_set": 1589, "validation_set": 200, "test_set": 197}
}
'''
config = json.loads(data_config)

dataset_path = './dataset-tomatoes/'

for set_type in ['train', 'validation', 'test']:
    set_path = os.path.join(dataset_path, set_type)
    if not os.path.exists(set_path):
        os.makedirs(set_path)

for class_name, sets in config.items():
    class_dir = os.path.join(dataset_path, class_name)
    images = os.listdir(class_dir)
    shuffle(images)  

    for set_type in ['train', 'validation', 'test']:
        set_class_path = os.path.join(dataset_path, set_type, class_name)
        if not os.path.exists(set_class_path):
            os.makedirs(set_class_path)

    validation_count = sets['validation_set']
    test_count = sets['test_set']
    training_count = sets['training_set']

    validation_images = images[:validation_count]
    test_images = images[validation_count:validation_count + test_count]
    training_images = images[validation_count + test_count:]

    for image in validation_images:
        shutil.copy(os.path.join(class_dir, image), os.path.join(dataset_path, 'validation', class_name))
    for image in test_images:
        shutil.copy(os.path.join(class_dir, image), os.path.join(dataset_path, 'test', class_name))
    for image in training_images:
        shutil.copy(os.path.join(class_dir, image), os.path.join(dataset_path, 'train', class_name))


In [84]:
# === PRINTING THE NUMBER OF IMAGES IN EACH CLASS ===
import os
path = './dataset-tomatoes/train'
full_dataset = ImageFolder(root=path)

print(len(full_dataset))

classes = full_dataset.classes

for i in range(len(classes)):
    print(f"Number of images in {classes[i]}: {len(os.listdir(os.path.join(path, classes[i])))}")
    

14404
Number of images in Tomato___Bacterial_spot: 1930
Number of images in Tomato___Early_blight: 966
Number of images in Tomato___Late_blight: 1710
Number of images in Tomato___Leaf_Mold: 999
Number of images in Tomato___Septoria_leaf_spot: 1771
Number of images in Tomato___Tomato_Yellow_Leaf_Curl_Virus: 4930
Number of images in Tomato___Tomato_mosaic_virus: 509
Number of images in Tomato___healthy: 1589


In [86]:
# === AUGMENTING THE DATASET ===
import os
import random
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
import json

# Load JSON configuration
data_config = '''
{
    "Tomato___Bacterial_spot": {"total": 2406, "training_set": 1930, "validation_set": 235, "test_set": 241},
    "Tomato___Early_blight": {"total": 1214, "training_set": 966, "validation_set": 131, "test_set": 117},
    "Tomato___Late_blight": {"total": 2129, "training_set": 1710, "validation_set": 210, "test_set": 209},
    "Tomato___Leaf_Mold": {"total": 1244, "training_set": 999, "validation_set": 120, "test_set": 125},
    "Tomato___Septoria_leaf_spot": {"total": 2204, "training_set": 1771, "validation_set": 210, "test_set": 223},
    "Tomato___Tomato_mosaic_virus": {"total": 634, "training_set": 509, "validation_set": 66, "test_set": 59},
    "Tomato___Tomato_Yellow_Leaf_Curl_Virus": {"total": 6179, "training_set": 4930, "validation_set": 624, "test_set": 625},
    "Tomato___healthy": {"total": 1986, "training_set": 1589, "validation_set": 200, "test_set": 197}
}
'''
config = json.loads(data_config)

dataset_path = './dataset-tomatoes/train'
full_dataset = ImageFolder(root=dataset_path)
classes = full_dataset.classes

# Define transformations
transformations = {
    "horizontal_flip": transforms.RandomHorizontalFlip(p=1),
    "color_jitter": transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.5),
    "contrast_enhancement": transforms.ColorJitter(brightness=0, contrast=1.5, saturation=0, hue=0),
    "saturation_enhancement": transforms.ColorJitter(brightness=0, contrast=0, saturation=1.5, hue=0),
    "brightness_enhancement": transforms.ColorJitter(brightness=1.5, contrast=0, saturation=0, hue=0),
    "rotation": transforms.RandomRotation((45, 45)),
    "resized_crop": transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0), ratio=(1.0, 1.0))
}

transformations = list(transformations.values())
print(transformations)

# Augment images for each class to meet training set requirements
for class_name in classes:
    class_folder = os.path.join(dataset_path, class_name)
    images = [os.path.join(class_folder, img) for img in os.listdir(class_folder) if img.lower().endswith(('.jpg', '.jpeg', '.png'))]
    current_count = len(images)
    training_target = config[class_name]['training_set']

    needed_augmentations = training_target - current_count
    
    print(f"Augmenting {class_name} with {needed_augmentations} images")
    print("In total we ll have", current_count + needed_augmentations, "images")
    print("-------------------")
    
    for i in range(needed_augmentations):
        print("Augmenting image", i+1)
        random_image_path = random.choice(images)
        image = Image.open(random_image_path)
        transform = transformations[i % len(transformations)]
        transformed_image = transform(image)
        suffix = transform.__class__.__name__.lower()
        save_path = os.path.join(class_folder, f"img_{suffix}_{i+1}_{class_name}.jpg")
        transformed_image.save(save_path)

print("Augmentation complete!")


[RandomHorizontalFlip(p=1), ColorJitter(brightness=None, contrast=None, saturation=None, hue=(-0.5, 0.5)), ColorJitter(brightness=None, contrast=(0.0, 2.5), saturation=None, hue=None), ColorJitter(brightness=None, contrast=None, saturation=(0.0, 2.5), hue=None), ColorJitter(brightness=(0.0, 2.5), contrast=None, saturation=None, hue=None), RandomRotation(degrees=[45.0, 45.0], interpolation=nearest, expand=False, fill=0), RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0), ratio=(1.0, 1.0), interpolation=bilinear, antialias=warn)]
Augmenting Tomato___Bacterial_spot with 178 images
In total we ll have 1930 images
-------------------
Augmenting image 1
Augmenting image 2
Augmenting image 3
Augmenting image 4
Augmenting image 5
Augmenting image 6
Augmenting image 7
Augmenting image 8
Augmenting image 9
Augmenting image 10
Augmenting image 11
Augmenting image 12
Augmenting image 13
Augmenting image 14
Augmenting image 15
Augmenting image 16
Augmenting image 17
Augmenting image 18
Augmenting