In [2]:
import cv2
import torch
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from tqdm import tqdm
import random
from matplotlib import pyplot as plt
import matplotlib.patches as patches
import os

In [3]:
class ImageFolder(Dataset):
  """Map-style dataset over a ``root_dir/<class_name>/<image_file>`` tree.

  Each visible sub-directory of ``root_dir`` is one class, and its integer label
  is its position in ``self.class_names``. Class names and file names are
  sorted, because ``os.listdir`` returns entries in arbitrary,
  filesystem-dependent order. Without sorting, labels and sample order would
  differ between machines.

  Hidden entries (names starting with ``.``, such as ``.ipynb_checkpoints`` or
  ``.DS_Store``) are ignored. Non-directory entries directly under
  ``root_dir`` and non-file entries inside a class directory are also ignored.

  Args:
    root_dir: directory containing one sub-directory per class.
    transform: optional callable invoked as ``transform(image=ndarray)``. It
      must return a dict holding the transformed image under ``"image"``.

  Attributes:
    class_names: sorted list of class sub-directory names; list index == label.
    data: list of ``(file_name, label)`` pairs, grouped by class in label order.
  """

  def __init__(self, root_dir, transform=None):
    super().__init__()
    self.data = []
    self.root_dir = root_dir
    self.transform = transform
    self.class_names = sorted(
        name for name in os.listdir(root_dir)
        if not name.startswith(".")
        and os.path.isdir(os.path.join(root_dir, name))
    )

    for label, name in enumerate(self.class_names):
      class_dir = os.path.join(root_dir, name)
      files = sorted(
          file_name for file_name in os.listdir(class_dir)
          if not file_name.startswith(".")
          and os.path.isfile(os.path.join(class_dir, file_name))
      )
      self.data.extend((file_name, label) for file_name in files)

  def __len__(self):
    """Return the number of ``(file_name, label)`` samples."""
    return len(self.data)

  def __getitem__(self, index):
    """Load sample ``index`` and return ``(image, label)``.

    The file is decoded with PIL into a NumPy array. If a transform is set,
    the ``"image"`` entry of its result is returned instead of the raw array.
    """
    img_file, label = self.data[index]
    # The label doubles as the index of the class sub-directory holding the file.
    class_dir = os.path.join(self.root_dir, self.class_names[label])
    # Close the file handle deterministically instead of leaving it to the GC.
    # np.array() forces PIL to load the pixel data before the file is closed.
    with Image.open(os.path.join(class_dir, img_file)) as img:
      image = np.array(img)

    if self.transform is not None:
      image = self.transform(image=image)["image"]

    return image, label


In [4]:
# Training-time augmentation pipeline; each sample is passed through it in order.
transform = A.Compose(
    [
        # Geometry. Resize stretches every image to 1920x720 (aspect ratio is not
        # preserved). RandomCrop then keeps the full 720-pixel height, so the
        # random 1280x720 window only varies horizontally.
        A.Resize(height=720, width=1920),
        A.RandomCrop(height=720, width=1280),
        A.Rotate(limit=45, border_mode=cv2.BORDER_CONSTANT, p=0.9),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        # Photometric. OneOf (p=1.0) picks one of its children every time; the
        # children's p values act as relative weights (equal here).
        A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.9),
        A.OneOf(
            [A.Blur(blur_limit=3, p=0.5), A.ColorJitter(p=0.5)],
            p=1.0,
        ),
        # Normalize computes (img - mean * max_pixel_value) / (std * max_pixel_value).
        # With mean 0 and std 1 this only divides by 255, rescaling 8-bit pixels
        # into [0, 1]. It does NOT apply the ImageNet mean/std normalization usually
        # used with pretrained torchvision models.
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255),
        # HWC numpy array -> CHW torch tensor.
        ToTensorV2(),
    ]
)

In [5]:
# Smoke test: push every sample through the augmentation pipeline and print the
# resulting shape. Indexing over range(len(dataset)) visits exactly the samples
# that `for x, y in dataset` would, without relying on the implicit
# __getitem__/IndexError iteration protocol.
# NOTE(review): hard-coded Colab path; consider a DATA_DIR constant in a config cell.
dataset = ImageFolder(root_dir="/content/images/cat_dogs", transform=transform)
for sample_idx in range(len(dataset)):
  x, y = dataset[sample_idx]
  print(x.shape)

torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
torch.Size([3, 720, 1280])
