In [11]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import os

In [9]:
# Resolve project paths relative to the directory the kernel was started in.
root_dir = os.getcwd()
# Build data_dir from root_dir (the previously referenced `root_path` was never defined,
# so this cell failed with NameError on a fresh kernel).
data_dir = os.path.join(root_dir, "data")
root_dir, data_dir

('/home/douglas.ta/ResiliTree/ResiliTree',
 '/home/douglas.ta/ResiliTree/ResiliTree/data')

In [25]:
# Augmentation + preprocessing pipeline applied to every loaded image:
#   1. mirror horizontally half of the time,
#   2. rotate by a random angle within +/-15 degrees,
#   3. resize to a fixed 200x200 so that samples are consistent with each other,
#   4. convert the PIL image to a PyTorch tensor.
transformations = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(0.5),
        transforms.RandomRotation(15),
        transforms.Resize((200, 200)),
        transforms.ToTensor(),
    ]
)

In [28]:
class CustomImageDataset(Dataset):
    """Image-classification dataset laid out as one sub-folder per class.

    Every regular file directly inside ``<data_dir>/<folder_name>`` whose name
    ends with one of ``extensions`` (compared case-insensitively) becomes one
    sample, labelled with ``class_to_idx[folder_name]``. Files are collected in
    sorted order so sample indices are reproducible across machines.

    Args:
        data_dir: Root directory containing one sub-folder per class.
        transform: Optional callable applied to each loaded RGB PIL image
            (e.g. a ``transforms.Compose`` pipeline).
        class_to_idx: Optional mapping of folder name -> integer label.
            Defaults to ``DEFAULT_CLASS_TO_IDX`` (the six tree classes).
        extensions: File-name suffixes accepted as images.

    Raises:
        FileNotFoundError: (from ``os.listdir``) if a folder named in
            ``class_to_idx`` does not exist under ``data_dir``.
    """

    # Folder name -> label used when no explicit mapping is passed.
    DEFAULT_CLASS_TO_IDX = {
        "eucalyptus": 0,
        "live_oak": 1,
        "maple": 2,
        "queenpalm": 3,
        "sabel_palm": 4,
        "south_magnolia": 5,
    }
    # Suffixes accepted as images (matched against the lower-cased file name).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.jfif', '.webp')

    def __init__(self, data_dir, transform=None, class_to_idx=None, extensions=IMAGE_EXTENSIONS):
        self.data_dir = data_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Copy so neither the caller's dict nor the class default is shared/mutated.
        mapping = self.DEFAULT_CLASS_TO_IDX if class_to_idx is None else class_to_idx
        self.class_to_idx = dict(mapping)
        suffixes = tuple(ext.lower() for ext in extensions)

        # Collect all image paths and their corresponding labels.
        for folder_name, label in self.class_to_idx.items():
            folder_path = os.path.join(data_dir, folder_name)
            # os.listdir order is filesystem-dependent; sort for reproducible indices.
            for img_file in sorted(os.listdir(folder_path)):
                img_path = os.path.join(folder_path, img_file)
                # Lower-case the name so e.g. "IMG_001.JPG" is not silently skipped;
                # isfile() excludes sub-directories that merely look like images.
                if img_file.lower().endswith(suffixes) and os.path.isfile(img_path):
                    self.image_paths.append(img_path)
                    self.labels.append(label)

    def __len__(self):
        """Number of collected image samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The file is loaded as an RGB PIL image and passed through
        ``self.transform`` when one was provided.
        """
        img_path = self.image_paths[idx]
        # torchvision's PIL loader opens the file inside a `with` block (closing
        # the handle) and returns the image converted to RGB.
        image = datasets.folder.pil_loader(img_path)
        label = self.labels[idx]

        # Apply transformations if any.
        if self.transform is not None:
            image = self.transform(image)

        return image, label


In [29]:
# Initialize the dataset with the augmentation pipeline `transformations`
# (not the `transforms` module itself, which is not callable on an image).
dataset = CustomImageDataset(data_dir=data_dir, transform=transformations)
# Fail early with a clear message instead of a sampler error on an empty dataset.
assert len(dataset) > 0, f"no images found under {data_dir}"

# Load the dataset into a shuffling DataLoader.
data_loader = DataLoader(dataset, batch_size=5, shuffle=True)
