In [1]:
import os 
import torch
import shutil
import numpy as np
from pathlib import Path
from glob import glob
from PIL import Image
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision import transforms as T
# Seed every RNG the notebook imports so random_split and DataLoader shuffling
# are reproducible under Restart & Run All. The trailing ";" on the last line
# suppresses the returned torch.Generator repr in the cell output.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED);

<torch._C.Generator at 0x7f9666d5ba30>

## Custom Datasets

In [None]:
class MyDataset(Dataset):
    """Image-classification dataset laid out as ``<path>/<class_name>/<image_file>``.

    Every regular file found exactly two levels below ``path`` is one sample.
    The name of its parent directory is its class. Integer labels are assigned
    in the order classes first appear in the sorted list of image paths.
    If nothing matches, the dataset is simply empty; no error is raised.

    Args:
        path: Root directory containing one sub-directory per class.
        transformations: Optional callable applied to each RGB PIL image in
            ``__getitem__`` (e.g. a ``torchvision.transforms.Compose``).

    Attributes:
        transformations: The callable passed in (or None).
        img_paths: Sorted list of image file paths.
        class_names: Mapping ``class name -> integer label``.
        class_counts: Mapping ``class name -> number of images in that class``.
    """

    def __init__(self, path, transformations=None):
        self.transformations = transformations
        # os.path.join avoids a doubled separator when `path` ends with "/".
        # The isfile filter skips nested directories, which Image.open would
        # otherwise fail on in __getitem__.
        self.img_paths = sorted(
            p for p in glob(os.path.join(path, "*", "*")) if os.path.isfile(p)
        )
        self.class_names = {}
        self.class_counts = {}

        for img_path in self.img_paths:
            class_name = self.get_class(img_path)
            if class_name not in self.class_names:
                # First image of a new class: give it the next free label.
                self.class_names[class_name] = len(self.class_names)
                self.class_counts[class_name] = 0
            self.class_counts[class_name] += 1

    def get_class(self, path) -> str:
        """Return the class name of an image, i.e. the name of its parent directory."""
        # basename/dirname use the OS path separator, unlike splitting on "/",
        # so this also works with Windows paths returned by glob.
        return os.path.basename(os.path.dirname(path))

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_paths)

    def __getitem__(self, idx: int):
        """Load image ``idx`` as RGB and return ``(image, label)``.

        ``image`` is a PIL image, or whatever ``self.transformations`` returns
        when a transformation is set. ``label`` is the integer from ``class_names``.
        """
        img_path = self.img_paths[idx]
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image that stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        label = self.class_names[self.get_class(img_path)]

        if self.transformations is not None:
            image = self.transformations(image)
        return image, label
        


In [29]:
# Quick sanity check: build the dataset without transforms, report its size,
# and make sure a sample can be loaded. The path is relative (as in the config
# cell) rather than an absolute machine-specific path, so the notebook runs on
# any checkout started from the project root.
dataset = MyDataset("dataset")
print(len(dataset))

sample_image, sample_label = dataset[0]



8000


## Dataloaders

In [24]:
def create_dataloaders(path, transformations, batch_size, split: tuple = (0.9, 0.05, 0.05), num_workers: int = 4, seed=None):
    """Build train / validation / test DataLoaders from an image-folder dataset.

    One ``MyDataset`` is built from ``path`` and randomly partitioned with
    ``torch.utils.data.random_split``. All three splits share that single
    dataset, so ``transformations`` is applied to train, validation and test
    alike; it should not contain training-only augmentation.

    Args:
        path: Root directory with one sub-directory per class.
        transformations: Callable applied to every image (or None).
        batch_size: Batch size for the train and validation loaders.
        split: Non-negative fractions ``(train, val, test)`` that sum to 1.
            Train and val sizes are truncated to ints and the test split gets
            the remainder, so no sample is dropped.
        num_workers: ``num_workers`` passed to each DataLoader.
        seed: Optional seed for a dedicated ``torch.Generator`` used by the
            split. ``None`` (default) draws from the global torch RNG.

    Returns:
        ``(train_loader, val_loader, test_loader, class_names)``, where
        ``class_names`` maps class name -> integer label.

    Raises:
        ValueError: if ``split`` is not three non-negative fractions summing to 1.
    """
    if len(split) != 3:
        raise ValueError(f"split must contain 3 fractions (train, val, test), got {len(split)}")
    if any(fraction < 0 for fraction in split):
        raise ValueError(f"split fractions must be non-negative, got {split}")
    if abs(sum(split) - 1.0) > 1e-6:
        raise ValueError(f"split fractions must sum to 1, got {sum(split)}")

    dataset = MyDataset(path=path, transformations=transformations)

    dataset_len = len(dataset)
    train_len = int(dataset_len * split[0])
    val_len = int(dataset_len * split[1])
    # Remainder instead of int(dataset_len * split[2]). Truncating all three
    # lengths can leave them summing to less than dataset_len, and then
    # random_split raises.
    test_len = dataset_len - train_len - val_len
    lengths = [train_len, val_len, test_len]

    if seed is None:
        tr_ds, val_ds, test_ds = random_split(dataset, lengths=lengths)
    else:
        # A dedicated generator makes the split independent of how much of the
        # global RNG earlier cells consumed.
        tr_ds, val_ds, test_ds = random_split(
            dataset, lengths=lengths, generator=torch.Generator().manual_seed(seed)
        )

    tr_dl = DataLoader(dataset=tr_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # No shuffling for val/test, so evaluation order is reproducible.
    val_dl = DataLoader(dataset=val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    # Test images are evaluated one at a time.
    test_dl = DataLoader(dataset=test_ds, batch_size=1, shuffle=False, num_workers=num_workers)

    return tr_dl, val_dl, test_dl, dataset.class_names





## Initializing the dataset and dataloaders for our project

In [None]:
ds_path = "dataset"

# Per-channel (R, G, B) normalization statistics: the ImageNet mean/std
# commonly used with pretrained ConvNeXt / ResNet models.
# Each must be a flat list of 3 values; a stray extra list here makes
# T.Normalize fail when the first batch is loaded.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
resize = 224
assert len(mean) == len(std) == 3, "mean/std need exactly one value per RGB channel"

# Resize to (resize, resize), convert the PIL image to a float tensor,
# then normalize each channel with mean/std.
transforms = T.Compose([
    T.Resize((resize, resize)),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
])

# Create the train / validation / test dataloaders.
tr_dl, val_dl, test_dl, classes = create_dataloaders(path=ds_path, transformations=transforms, batch_size=32)


In [None]:
#Check the batches per dataloader
print(len(tr_dl)); print(len(val_dl)); print(len(test_dl)); print(classes)

225
13
400
{'Bridge': 0, 'Commercial': 1, 'Industrial': 2, 'Intersection': 3, 'Landmark': 4, 'Park': 5, 'Parking': 6, 'Playground': 7, 'Residential': 8, 'Stadium': 9}


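The numbers are as expected. With 8,000 images and a 90/5/5 split, we get:

- 7,200 training images: 7,200 / 32 = 225 batches.
- 400 validation images: 400 / 32 = 12.5, so 13 batches, the last one smaller.
- 400 test images: 400 batches, because the test loader uses a batch size of 1.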

## Data visualization