In [3]:
# Load Dataset 

import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class RazorbackDataset(Dataset):
    """Binary image dataset: official Razorback logos vs. everything else.

    ``root_dir`` must contain two sub-folders:

    * ``official_logo/`` -- every accepted image is labelled ``1``
    * ``not_official/``  -- every accepted image is labelled ``0``

    Only regular files whose names end in one of ``extensions``
    (case-insensitive) are included. Each folder is listed in sorted order,
    because ``os.listdir`` order is arbitrary. Sorting keeps sample indices
    stable across machines and runs.

    Args:
        root_dir: Directory containing the two class sub-folders.
        transform: Optional callable applied to each RGB ``PIL.Image`` in
            ``__getitem__`` (e.g. a ``torchvision.transforms.Compose``).
        extensions: Accepted file-name suffixes. They are compared
            case-insensitively, and a list or a tuple is accepted.

    Attributes:
        images: Full paths of the accepted image files, official first.
        labels: Integer label for each entry of ``images``.

    Raises:
        FileNotFoundError: If either class sub-folder is missing
            (propagated from ``os.listdir``).
    """

    # (sub-folder name, integer label) pairs, in load order.
    CLASS_DIRS = (('official_logo', 1), ('not_official', 0))
    DEFAULT_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None, extensions=DEFAULT_EXTENSIONS):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []

        # str.endswith needs a tuple; lower-case so matching is case-insensitive.
        suffixes = tuple(ext.lower() for ext in extensions)
        for folder_name, label in self.CLASS_DIRS:
            folder = os.path.join(root_dir, folder_name)
            for path in self._list_images(folder, suffixes):
                self.images.append(path)
                self.labels.append(label)

    @staticmethod
    def _list_images(folder, suffixes):
        """Return sorted full paths of regular files in ``folder`` ending in ``suffixes``."""
        paths = []
        for name in sorted(os.listdir(folder)):
            path = os.path.join(folder, name)
            # The isfile check skips sub-directories that happen to be named e.g. "x.png".
            if name.lower().endswith(suffixes) and os.path.isfile(path):
                paths.append(path)
        return paths

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as 3-channel RGB and passed through
        ``self.transform`` when one was given.
        """
        # The context manager releases the file handle promptly.
        # convert('RGB') loads the pixel data and normalises
        # grayscale/RGBA/palette images to 3 channels.
        with Image.open(self.images[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label


In [4]:
# Pre-process the dataset with transformations and build a seeded DataLoader.

import os

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# --- Configuration -------------------------------------------------------
# NOTE(review): the fallback is an absolute, user-specific path. On any other
# machine, set the RAZORBACK_DATA_DIR environment variable instead.
DATA_ROOT = os.environ.get(
    'RAZORBACK_DATA_DIR',
    '/Users/eliseeldridge/desktop/GitRepo/DASC-41103-Group-22/Project3/razorback_dataset',
)
IMAGE_SIZE = (500, 500)  # resize to 500x500 is REQUIRED by the project
BATCH_SIZE = 8
SEED = 42                # fixes the DataLoader shuffle order across runs

# Standard torchvision ImageNet per-channel statistics.
# NOTE(review): presumably chosen for an ImageNet-pretrained model downstream -- confirm.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# --- Transformations -----------------------------------------------------
data_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# --- Dataset -------------------------------------------------------------
dataset = RazorbackDataset(root_dir=DATA_ROOT, transform=data_transforms)
# Fail early with a clear message; otherwise the shuffling sampler raises an
# obscure "num_samples" error on an empty dataset.
assert len(dataset) > 0, f"No images found under {DATA_ROOT!r}"

# --- DataLoader ----------------------------------------------------------
# A seeded generator makes shuffle=True reproducible under Restart & Run All.
data_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)

# --- Verification --------------------------------------------------------
n_official = sum(dataset.labels)  # labels are 1 for official, 0 otherwise
print(f"Total images: {len(dataset)}")
print(f"Official logos: {n_official}")
print(f"Non-official images: {len(dataset) - n_official}")

Total images: 91
Official logos: 43
Non-official images: 48
