In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# Preprocessing pipeline applied to every image as it is loaded.
image_width = 64  # target side length in pixels; images become 64x64
transform = transforms.Compose(
    [
        # Force a fixed square resolution so samples can be stacked into batches.
        transforms.Resize((image_width,) * 2),
        # Convert the loaded image into a PyTorch tensor.
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 on each of the three channels: [0, 1] -> [-1, 1].
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# ImageFolder derives the labels from the sub-directory names under the root.
dataset_path = 'PlantVillage'
dataset = datasets.ImageFolder(dataset_path, transform=transform)

# Shuffled loader over the complete (unsplit) dataset.
batch_size = 64
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Show which integer label was assigned to each class folder.
class_to_idx = dataset.class_to_idx
print("Class to index mapping:", class_to_idx)


Class to index mapping: {'Pepper__bell___Bacterial_spot': 0, 'Pepper__bell___healthy': 1, 'Potato___Early_blight': 2, 'Potato___Late_blight': 3, 'Potato___healthy': 4, 'Tomato_Bacterial_spot': 5, 'Tomato_Early_blight': 6, 'Tomato_Late_blight': 7, 'Tomato_Leaf_Mold': 8, 'Tomato_Septoria_leaf_spot': 9, 'Tomato_Spider_mites_Two_spotted_spider_mite': 10, 'Tomato__Target_Spot': 11, 'Tomato__Tomato_YellowLeaf__Curl_Virus': 12, 'Tomato__Tomato_mosaic_virus': 13, 'Tomato_healthy': 14}


In [None]:
import numpy as np
from collections import Counter
from torch.utils.data import Subset
from torch.utils.data import random_split

# Count instances per class.
# ImageFolder stores the integer label of every sample in `dataset.targets`
# (same order as the samples), so the counts can be read straight from that
# list. Iterating over `dataset` itself would open, decode, resize and
# normalise every image on disk only to throw the pixels away.
class_counts = Counter(dataset.targets)
print("Class counts before downsampling:", class_counts)

# Fractions of the full dataset held out for evaluation.
test_split = 0.2  # 20% for testing
val_split = 0.1   # 10% for validation

test_size = int(test_split * len(dataset))
val_size = int(val_split * len(dataset))
# Training takes the remainder, so the three sizes always sum to len(dataset)
# even though int() truncates the two held-out sizes.
train_size = len(dataset) - test_size - val_size

# Seed the split explicitly: without a fixed generator every kernel restart
# produces a different train/val/test partition, which makes metrics
# irreproducible and can leak former test samples into training when a
# previously trained checkpoint is reused.
split_seed = 42
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create DataLoaders. Only the training loader shuffles; validation and test
# keep a fixed order so evaluation runs are comparable.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=3)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=3)