In [54]:
#collapse-hide

####### PACKAGES
import glob
import random

import numpy as np
import pandas as pd
from datasets import load_dataset

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import cv2
from pandas.core.common import flatten
from PIL import Image

from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline


####### PARAMS

device      = torch.device('cpu') 
num_workers = 4
image_size  = 32 
batch_size  = 32
data_path   = 'datas/train'

In [31]:
class WayangKulit(Dataset):
    def __init__(self, dataset_path, transform):
        self.dataset_path = dataset_path
        self.transform = transform
        self.dataset = load_dataset(path=self.dataset_path, name='wayang_kulit')

        self.train_image_paths = []
        self.classes = []

        for data_path in glob.glob(self.dataset_path + '/*'):
            self.classes.append(data_path.split('/')[-1])
            self.train_image_paths.append(glob.glob(data_path + '/*'))

        self.train_image_paths = list(flatten(self.train_image_paths))
        random.shuffle(self.train_image_paths)

    def __len__(self):
        return len(self.dataset['train'])

    def __getitem__(self, idx):
        image_filepath = self.train_image_paths[idx]

        image = cv2.imread(image_filepath)

        if self.transform:
            image = self.transform(image = image)['image']

        idx_to_class = {i: j for i, j in enumerate(self.classes)}
        class_to_idx = {value: key for key, value in idx_to_class.items()}

        label = image_filepath.split('/')[-2]
        label = class_to_idx[label]

        return image, label

In [32]:
augs = A.Compose([A.Resize(height = image_size, 
                           width  = image_size),
                  A.Normalize(mean = (0, 0, 0),
                              std  = (1, 1, 1)),
                  ToTensorV2()])

In [60]:
####### EXAMINE SAMPLE BATCH

# dataset
image_dataset = WayangKulit(data_path, transform=augs)

# data loader
image_loader = DataLoader(image_dataset, 
                          batch_size  = 32, 
                          shuffle     = False, 
                          num_workers = 0,
                          pin_memory  = True)

Resolving data files:   0%|          | 0/1350 [00:00<?, ?it/s]

In [61]:
####### COMPUTE MEAN / STD

# placeholders
psum    = torch.tensor([0.0, 0.0, 0.0])
psum_sq = torch.tensor([0.0, 0.0, 0.0])

# loop through images
for inputs in tqdm(image_loader):
    psum    += inputs.sum(axis        = [0, 2, 3])
    psum_sq += (inputs ** 2).sum(axis = [0, 2, 3])

  0%|                                                                                                                                                                               | 0/43 [00:00<?, ?it/s]


AttributeError: 'list' object has no attribute 'sum'