In [23]:
import os
import shutil
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

from ray import train, tune
from ray.tune.schedulers import ASHAScheduler
import time

In [24]:
# Prefer Apple's Metal (MPS) backend when it is usable, otherwise run on CPU.
# torch.backends.mps.is_available() is the documented availability check and
# is present on every PyTorch release that ships MPS support, whereas
# torch.mps.is_available() is missing from older releases.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [25]:
# Single seed shared by every stochastic step in the notebook.
SEED = 42

# Defining the constant is not enough: the global NumPy and PyTorch generators
# must be seeded too, otherwise random augmentations, shuffling and weight
# initialisation differ on every run.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [26]:
# Build a stratified 80/20 train/test split of the raw image folders and
# materialise it on disk as Vehicles_split/<train|test>/<class_name>/<file>.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

image_paths = []
labels = []

# Only sub-directories are classes. os.listdir also returns stray files
# (e.g. .DS_Store on macOS), which must not be given a class index.
classes = {
    v: i
    for i, v in enumerate(
        sorted(
            entry for entry in os.listdir('Vehicles/')
            if os.path.isdir(os.path.join('Vehicles/', entry))
        )
    )
}

for root, dirs, files in os.walk('Vehicles/'):
    # os.walk yields entries in filesystem order; sorting both the traversal
    # and the files makes the seeded split below reproducible everywhere.
    dirs.sort()
    label = os.path.basename(root)
    if label in classes:
        for file in sorted(files):
            if file.lower().endswith(IMAGE_EXTENSIONS):
                image_paths.append(os.path.join(root, file))
                labels.append(label)

if not image_paths:
    # Fail with an actionable message instead of an opaque sklearn error.
    raise FileNotFoundError("No images found under 'Vehicles/'")

X_train, X_test, y_train, y_test = train_test_split(
    image_paths, labels, test_size=0.2, random_state=SEED, stratify=labels
)

# Start from an empty output tree: copies left over from an earlier run
# (different SEED, test_size or source data) would otherwise linger and could
# place the same image in both the train and the test split.
shutil.rmtree('Vehicles_split', ignore_errors=True)

for split_name, paths, split_labels in [('train', X_train, y_train),
                                         ('test', X_test, y_test)]:
    for path, label in zip(paths, split_labels):
        dest_dir = os.path.join('Vehicles_split', split_name, label)
        os.makedirs(dest_dir, exist_ok=True)
        shutil.copy2(path, os.path.join(dest_dir, os.path.basename(path)))

In [27]:
class VehiclesDataset:
    """Map-style image dataset read from a ``<data_path>/<class_name>/<image>`` tree.

    Every sub-directory directly under ``data_path`` is a class; class indices
    follow the alphabetical order of the directory names. Images are collected
    from any directory below ``data_path`` whose base name is a class name.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, data_path: str, transforms=None):
        """Index all image files below ``data_path``.

        Args:
            data_path: root directory containing one sub-directory per class.
            transforms: optional callable applied to each RGB PIL image in
                ``__getitem__``; ``None`` returns the PIL image unchanged.
        """
        self.data_path = data_path
        self.transforms = transforms

        # Only directories are classes. Stray files returned by os.listdir
        # (e.g. .DS_Store on macOS) would otherwise take a class index and
        # shift the index of every real class.
        class_names = sorted(
            entry for entry in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, entry))
        )
        # Mapping class name -> integer label.
        self.classes = {v: i for i, v in enumerate(class_names)}

        # List of (image path, integer label) pairs, in deterministic order.
        self.image_paths = []
        for root, dirs, files in os.walk(data_path):
            # Sort traversal and files: os.walk order is filesystem-dependent,
            # which would make sample indices differ between machines.
            dirs.sort()
            label = os.path.basename(root)
            if label in self.classes:
                for file in sorted(files):
                    if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
                        self.image_paths.append((os.path.join(root, file), self.classes[label]))

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and passed through ``self.transforms``
        when one was supplied.
        """
        image_path, label = self.image_paths[idx]
        # The context manager closes the underlying file handle right after
        # the pixel data has been converted, instead of leaving it to the GC.
        with Image.open(image_path) as source:
            image = source.convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)
        return image, label

In [28]:
def get_mean_and_std(dataloader):
    """Compute the per-channel mean and standard deviation over a dataloader.

    Statistics are accumulated as sums over every value of each channel, so
    batches of different sizes (typically a smaller final batch) contribute
    in proportion to the data they contain.

    Args:
        dataloader: iterable yielding ``(data, target)`` pairs. ``data`` is
            reduced over dims 0, 2 and 3, i.e. it is expected to be a 4-D
            batch with channels on dim 1; ``target`` is ignored.

    Returns:
        Tuple ``(mean, std)`` of 1-D tensors with one entry per channel.
        ``std`` is the population (biased) standard deviation.

    Raises:
        ValueError: if the dataloader yields no data.
    """
    channels_sum, channels_squared_sum, num_values = 0, 0, 0
    for data, _ in dataloader:
        channels_sum += torch.sum(data, dim=[0, 2, 3])
        channels_squared_sum += torch.sum(data ** 2, dim=[0, 2, 3])
        # Number of values reduced per channel in this batch (N * H * W).
        num_values += data.size(0) * data.size(2) * data.size(3)

    if num_values == 0:
        raise ValueError("Cannot compute mean/std: the dataloader yielded no data")

    mean = channels_sum / num_values

    # Var[x] = E[x^2] - E[x]^2; rounding can push it marginally below zero for
    # near-constant channels, which would turn the square root into NaN.
    variance = torch.clamp(channels_squared_sum / num_values - mean ** 2, min=0)
    std = variance ** 0.5

    return mean, std

# Estimate the normalisation statistics from the training split only, so that
# no test-set information leaks into the preprocessing. Images are resized to
# a small fixed size just so they can be stacked into batches for this pass.
dataset = VehiclesDataset('Vehicles_split/train', transforms=transforms.Compose([transforms.Resize((100, 100)), transforms.ToTensor()]))
loader = DataLoader(dataset, batch_size=16, num_workers=0, shuffle=False)

# Per-channel mean / std used by the Normalize step of the pipelines.
MEAN, STD = get_mean_and_std(loader)



In [30]:
# Final steps shared by both pipelines: PIL image -> tensor, then per-channel
# standardisation with the dataset statistics.
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
]

# Training-only steps: upscale, take a random 224x224 crop, then apply random
# geometric and colour augmentations.
_train_augmentations = [
    transforms.Resize((256, 256)),
    transforms.RandomCrop((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
]

# Training pipeline: augmentations followed by the shared final steps.
train_transforms = transforms.Compose(_train_augmentations + _tensor_and_normalize)

# Evaluation pipeline: deterministic resize to the network input size only.
test_transforms = transforms.Compose(
    [transforms.Resize((224, 224))] + _tensor_and_normalize
)