In [2]:
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [34]:
class AlbumentationsDataset(Dataset):
    def __init__(self,file_paths,labels,transform=None):
        self.file_paths = file_paths
        # self.file_paths = os.listdir(file_paths)
        self.labels = labels
        self.transform = transform
        
    def __getitem__(self,index):
        label = self.labels[index]
        file_path = self.file_paths[index]
        
        img = cv2.imread(file_path)
        print(img)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        if self.transform:
            augmented = self.transform(image=img)
            img = augmented['image']
            
        return img, label
        
    def __len__(self):
        return len(self.file_paths)

In [35]:
albumentation_transform = A.Compose([
    A.Resize(256,256),
    A.RandomCrop(224,224),
    A.HorizontalFlip(),
    A.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ToTensorV2()
])

In [36]:
data = AlbumentationsDataset(
    ["./data/sample_data_01\\train\\dew\\2208.jpg",
    "./data/sample_data_01\\train\\fogsmog\\4075.jpg",
    "./data/sample_data_01\\train\\frost\\3600.jpg"],
    [0,1,2],
    transform = albumentation_transform
)

In [37]:
for i in data:
    print(i)

[[[  2 120  77]
  [  3 121  78]
  [  2 120  77]
  ...
  [  2 105  47]
  [  3 106  48]
  [  2 105  47]]

 [[  0 118  75]
  [  1 119  76]
  [  0 118  75]
  ...
  [  3 106  48]
  [  3 106  48]
  [  1 104  46]]

 [[  0 118  75]
  [  0 118  75]
  [  1 119  76]
  ...
  [  3 106  49]
  [  2 105  48]
  [  1 104  47]]

 ...

 [[  0  49  11]
  [  0  48  10]
  [  0  49  11]
  ...
  [  0 113  56]
  [  0 113  55]
  [  0 113  55]]

 [[  0  49  11]
  [  0  47   9]
  [  0  48  10]
  ...
  [  0 113  56]
  [  0 113  55]
  [  0 113  55]]

 [[  0  49  11]
  [  0  47   9]
  [  0  48  10]
  ...
  [  0 111  55]
  [  0 112  54]
  [  0 112  54]]]
(tensor([[[-1.0562, -1.0048, -1.0048,  ..., -0.7993, -0.8164, -0.7822],
         [-1.0219, -1.0390, -1.0219,  ..., -0.8164, -0.7822, -0.7993],
         [-1.0733, -1.0390, -1.0219,  ..., -0.8164, -0.7822, -0.7650],
         ...,
         [-1.2617, -1.3302, -1.4158,  ..., -2.1008, -2.1008, -2.0837],
         [-1.2788, -1.3473, -1.3987,  ..., -2.1008, -2.1008, -2.1008],
