In [None]:
import os
from collections import Counter

import matplotlib.pyplot as plt
import pandas as pd
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Directory holding the training images and their `_annotations.csv`.
# The default is the original machine-specific location; set the
# DATASET_TRAIN_DIR environment variable to run the notebook elsewhere
# without editing this cell.
data_dir = os.environ.get(
    "DATASET_TRAIN_DIR",
    r"C:\Users\himan\OneDrive - Dundalk Institute of Technology\y5 sem3\dataset\train",
)
annotations_file = os.path.join(data_dir, '_annotations.csv')

# Load annotations
df = pd.read_csv(annotations_file)

print("First few rows:")
print(df.head())

# Count rows per class once and reuse the result for both the printed table
# and the bar chart below.
class_counts = df['class'].value_counts()
print("\nClass distribution:")
print(class_counts)

class_counts.plot(kind='bar', title='Class Distribution')
plt.xlabel('Class')
plt.ylabel('Count')
plt.show()

# Custom Dataset
class CustomImageDataset(Dataset):
    """Image dataset driven by an annotations CSV.

    Every CSV row produces one sample: the image whose file name is stored in
    the first column (joined onto ``img_dir``) together with the value found
    in the label column of that row.

    Args:
        annotations_file: Path of the annotations CSV.
        img_dir: Directory that the file names in the CSV are joined onto.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            the sample is returned.
        label_column: Name of the CSV column that holds the label. When the
            CSV has no column with this name, the second column is used,
            which matches the earlier position-based behaviour.
    """

    def __init__(self, annotations_file, img_dir, transform=None, label_column='class'):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform

        # Resolve the label by column name instead of assuming it is the
        # second column. NOTE(review): annotation exports laid out as
        # `filename,width,height,class,...` keep the image width in column 1,
        # so a positional lookup would return the width as the "label" --
        # confirm against the actual CSV header.
        column_names = list(self.img_labels.columns)
        if label_column in column_names:
            self._label_pos = column_names.index(label_column)
        else:
            self._label_pos = 1

    def __len__(self):
        """Return the number of annotation rows (one sample per row)."""
        return len(self.img_labels)

    def _label_at(self, idx):
        """Return the label stored for annotation row ``idx``."""
        return self.img_labels.iloc[idx, self._label_pos]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for annotation row ``idx``.

        The image is loaded from disk, converted to RGB and, if a transform
        was supplied, passed through it.
        """
        img_name = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # convert() returns a new in-memory image, so the file handle can be
        # released as soon as the with-block ends.
        with Image.open(img_name) as source_image:
            image = source_image.convert("RGB")
        label = self._label_at(idx)

        if self.transform:
            image = self.transform(image)
        return image, label

# Baseline (non-augmented) preprocessing: resize to 128x128, convert to a
# tensor, then normalise each channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

dataset = CustomImageDataset(annotations_file=annotations_file, img_dir=data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Preview the first 16 images of one shuffled batch as a 4x4 grid.
images, labels = next(iter(dataloader))
# Undo Normalize (x * std + mean) before plotting. imshow expects float RGB
# data in [0, 1]; the normalised tensor spans [-1, 1] and would be clipped,
# which renders the preview too dark.
preview_images = images[:16] * 0.5 + 0.5
img_grid = torchvision.utils.make_grid(preview_images, nrow=4)
plt.figure(figsize=(8, 8))
# Move the channel axis last, the layout imshow expects for RGB data.
plt.imshow(img_grid.permute(1, 2, 0))
plt.title("Sample Images")
plt.axis('off')
plt.show()

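# Check for class imbalance.
# NOTE(review): presumably left disabled because iterating `dataset` opens and
# transforms every image just to read its label, whereas the
# `df['class'].value_counts()` call earlier in this cell gets the per-class
# counts from the CSV alone -- confirm before re-enabling.
#label_counts = Counter([label for _, label in dataset])
#print("Class distribution:", label_counts)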

In [None]:
# Augmented preprocessing: same resize and normalisation as the baseline
# pipeline, with a random horizontal flip, a random rotation (degrees=15) and
# brightness/contrast jitter applied to the PIL image before it is converted
# to a tensor.
augmented_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
])

# NOTE: this rebinds `dataset` and `dataloader` to the augmented versions.
dataset = CustomImageDataset(annotations_file=annotations_file, img_dir=data_dir, transform=augmented_transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Preview the first 16 augmented images of one shuffled batch as a 4x4 grid.
images, labels = next(iter(dataloader))
# Undo Normalize (x * std + mean) before plotting. imshow expects float RGB
# data in [0, 1]; the normalised tensor spans [-1, 1] and would be clipped,
# which renders the preview too dark.
preview_images = images[:16] * 0.5 + 0.5
img_grid = torchvision.utils.make_grid(preview_images, nrow=4)
plt.figure(figsize=(8, 8))
# Move the channel axis last, the layout imshow expects for RGB data.
plt.imshow(img_grid.permute(1, 2, 0))
plt.title("Sample Images")
plt.axis('off')
plt.show()