In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from collections import defaultdict
from shutil import copyfile

In [2]:
data_dir = "Dataset"
batch_size = 32
num_epochs = 10
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [4]:
full_dataset = datasets.ImageFolder(root=os.path.join(data_dir), transform=transform)

class_names = full_dataset.classes
print("Class Names:", class_names)

Class Names: ['Apple__Apple_scab', 'Apple__Black_rot', 'Apple__Cedar_apple_rust', 'Apple__healthy', 'Background_without_leaves', 'Bean__Blight', 'Bean__Healthy', 'Bean__Mosaic_Virus', 'Bean__Rust', 'Blueberry__healthy', 'Cherry__Powdery_mildew', 'Cherry__healthy', 'Corn__Cercospora_leaf_spot Gray_leaf_spot', 'Corn__Common_rust', 'Corn__Northern_Leaf_Blight', 'Corn__healthy', 'Cowpea__Bacterial_wilt', 'Cowpea__Healthy', 'Cowpea__Mosaic_virus', 'Cowpea__Septoria_leaf_spot', 'Grape__Black_rot', 'Grape__Esca_(Black_Measles)', 'Grape__Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape__healthy', 'Orange__Haunglongbing_(Citrus_greening)', 'Peach__Bacterial_spot', 'Peach__healthy', 'Pepper__bell__Bacterial_spot', 'Pepper__bell__healthy', 'Potato__Early_blight', 'Potato__Late_blight', 'Potato__healthy', 'Raspberry__healthy', 'Soybean__healthy', 'Squash__Powdery_mildew', 'Strawberry__Leaf_scorch', 'Strawberry__healthy', 'Tomato__Bacterial_spot', 'Tomato__Early_blight', 'Tomato__Late_blight', 'Tomato__

In [5]:
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size
train_data, val_data, test_data = random_split(full_dataset, [train_size, val_size, test_size])

In [6]:
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [7]:
# Define the output directory for exported images
output_dir = "exported_test_images"
os.makedirs(output_dir, exist_ok=True)

# Get the mapping of classes from the dataset
class_names = full_dataset.classes  # Retrieves class names from the dataset
class_counts = defaultdict(int)  # To track the number of images per class

# Loop through the original dataset to export only 5 images per class
for image_path, label in full_dataset.samples:  # .samples contains (path, class_label)
    if class_counts[label] < 5:  # Limit to 5 images per class
        # Get the class name for the label
        class_name = class_names[label]
        
        # Create a class-specific directory in the output directory
        class_dir = os.path.join(output_dir, class_name)
        os.makedirs(class_dir, exist_ok=True)
        
        # Copy the original image to the new directory
        output_path = os.path.join(class_dir, os.path.basename(image_path))
        copyfile(image_path, output_path)
        
        # Increment the count for the class
        class_counts[label] += 1

    # Stop if all classes have 5 images
    if len(class_counts) == len(class_names) and all(count >= 5 for count in class_counts.values()):
        break

print(f"Exported 5 images per class to {output_dir}")

Exported 5 images per class to exported_test_images
