In [2]:
import os

import torch
from torch import nn

In [3]:
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt

import albumentations as A
from albumentations.pytorch import ToTensorV2

import utils as ul

In [11]:
class ImageFolder(torch.utils.data.Dataset):
    """Map-style dataset over a ``root_dir/<class_name>/<image_file>`` tree.

    Each immediate sub-directory of ``root_dir`` is one class. Its position in
    the *sorted* list of sub-directory names is its integer label, so the
    mapping is reproducible across runs and platforms. Every regular file in a
    class directory is one sample.

    Args:
        root_dir: Directory whose sub-directories are the classes.
        transform: Optional albumentations-style callable. It is invoked as
            ``transform(image=img)`` and must return a mapping with an
            ``'image'`` key.
        *args, **kwargs: Forwarded to the base-class constructor.

    Attributes:
        class_names: Sorted class directory names; ``class_names[label]`` maps
            a label back to its class name.
        data: List of ``(file_name, label)`` pairs, one per sample.
    """

    def __init__(self, root_dir: str, transform=None, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.root_dir = root_dir
        self.transform = transform

        # os.listdir order is arbitrary: sort so labels are stable. Skip stray
        # files (e.g. .DS_Store) that would otherwise be treated as classes.
        self.class_names = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )

        self.data = []
        for label, name in enumerate(self.class_names):
            class_dir = os.path.join(root_dir, name)
            files = sorted(
                f for f in os.listdir(class_dir)
                if os.path.isfile(os.path.join(class_dir, f))
            )
            self.data.extend((f, label) for f in files)

    def __len__(self) -> int:
        """Return the number of samples across all classes."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Load sample ``idx`` and return ``(image, label)``.

        The file is read with OpenCV and converted from BGR to RGB. If a
        transform was given, the image is replaced by
        ``transform(image=img)['image']``. Otherwise the RGB array from OpenCV
        is returned as-is.

        Raises:
            IndexError: if ``idx`` is out of range. This is also what ends
                plain ``for x, y in dataset`` iteration.
            OSError: if OpenCV cannot decode the file.
        """
        img_file, label = self.data[idx]

        img_path = os.path.join(self.root_dir, self.class_names[label], img_file)
        img = cv.imread(img_path)
        if img is None:
            # cv.imread signals failure by returning None rather than raising;
            # fail here with the path instead of deep inside cvtColor.
            raise OSError(f'cv2.imread could not read image: {img_path}')
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)

        if self.transform is not None:
            img = self.transform(image=img)['image']

        return img, label

In [12]:
def _build_train_transform():
    """Build the training augmentation pipeline.

    Steps, applied in this order:
      1. Resize every image to 1920x1080 (aspect ratio is not preserved), then
         take a random 1280x720 crop.
      2. Rotate by a random angle within +/-40 degrees (p=0.9); exposed corners
         are filled with a constant border.
      3. Horizontal flip (p=0.5) and vertical flip (p=0.1).
      4. Per-channel RGB shift of up to +/-25 (p=0.9).
      5. Exactly one of Blur / ColorJitter; the OneOf group always fires (p=1.0).
      6. Normalize with mean 0 and std 1, i.e. only divide by 255.
      7. Convert the HWC numpy image into a CHW torch tensor.
    """
    out_w, out_h = 1280, 720  # size of the crop that reaches the model

    geometric = [
        A.Resize(height=1080, width=1920),
        A.RandomCrop(height=out_h, width=out_w),
        A.Rotate(limit=40, border_mode=cv.BORDER_CONSTANT, p=0.9),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
    ]

    photometric = [
        A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.9),
        # p=1.0 on the group: one of the two members is picked on every call.
        A.OneOf([A.Blur(blur_limit=3, p=0.5), A.ColorJitter(p=0.5)], p=1.0),
    ]

    to_tensor = [
        # mean 0 / std 1 reduces Normalize to a plain division by 255.
        A.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255),
        ToTensorV2(),
    ]

    return A.Compose(geometric + photometric + to_tensor)


transform = _build_train_transform()

In [13]:
dataset = ImageFolder(root_dir='./dataset', transform=transform)

# Smoke test: print the shape of every augmented sample.
for idx in range(len(dataset)):
    x, y = dataset[idx]
    print(f'{x.shape = }')

x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
x.shape = torch.Size([3, 720, 1280])
