In [1]:
import os, cv2
import numpy as np
from matplotlib import pyplot as plt

from utils import *

In [2]:
class DataLoader():
    """Serve batches of (downscaled input, full-size target) image pairs from a folder.

    Image files found directly inside ``path`` are sorted by file name and
    split into a training part and a validation part: the last ``num_val``
    files form the validation set, everything before them the training set.

    Attributes:
        path: Directory that was scanned for images.
        batch_size: Maximum number of images per yielded batch.
        num_val: Number of trailing images reserved for validation.
        img_list: Sorted full paths of every image file found in ``path``.
        train_list: Paths used by ``train_generator``.
        val_list: Paths used by ``validation_generator``.
    """

    # Lower-cased file extensions (including the dot) accepted as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')
    # Second argument handed to ``img_read`` for every file.
    # NOTE(review): presumably the target side length in pixels -- confirm in utils.
    IMG_SIZE = 256
    # Factor applied to both axes when deriving the network input from the target.
    DOWNSCALE = 0.5

    def __init__(self, path, batch_size=16, num_val=100):
        """Scan ``path`` for images and split them into train/validation lists.

        Args:
            path: Directory containing the image files.
            batch_size: Number of images per batch; must be at least 1.
            num_val: How many images (taken from the end of the sorted
                listing) to hold out for validation; must not be negative.

        Raises:
            ValueError: If ``batch_size`` < 1 or ``num_val`` < 0.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be at least 1')
        if num_val < 0:
            raise ValueError('num_val must not be negative')

        self.path = path
        self.batch_size = batch_size
        self.num_val = num_val
        # splitext is used instead of a fixed-width suffix slice so that
        # four-letter extensions such as ".jpeg" are recognised as well.
        self.img_list = [
            os.path.join(self.path, name)
            for name in sorted(os.listdir(self.path))
            if os.path.splitext(name)[1].lower() in self.IMAGE_EXTENSIONS
        ]
        # An explicit split index keeps num_val == 0 working ("[:-0]" would
        # produce an empty training list) and clamps oversized num_val.
        split = max(len(self.img_list) - num_val, 0)
        self.train_list = self.img_list[:split]
        self.val_list = self.img_list[split:]

    def _path_batches(self, paths, shuffle):
        """Endlessly yield consecutive slices of ``paths`` of at most ``batch_size`` items.

        When ``shuffle`` is true, ``paths`` is reshuffled in place before every
        pass. The final slice of a pass may be shorter than ``batch_size``.

        Raises:
            ValueError: If ``paths`` is empty (raised on the first ``next``),
                since the loop could otherwise never yield.
        """
        if not paths:
            raise ValueError('no images available to build batches from')
        while True:
            if shuffle:
                np.random.shuffle(paths)
            # Slicing clamps at the end of the list, so the trailing partial
            # batch needs no special case.
            for start in range(0, len(paths), self.batch_size):
                yield paths[start:start + self.batch_size]

    def _load_batch(self, batch_paths):
        """Read the given files and return ``(x, y)`` arrays for one batch.

        ``y`` stacks the results of ``img_read``; ``x`` stacks the same images
        resized by ``DOWNSCALE`` along both axes.
        """
        y_data = [img_read(img_path, self.IMG_SIZE) for img_path in batch_paths]
        x_data = [cv2.resize(img, None, fx=self.DOWNSCALE, fy=self.DOWNSCALE)
                  for img in y_data]
        return np.array(x_data), np.array(y_data)

    def train_generator(self, shuffle=True):
        """Infinite generator over the training images.

        Args:
            shuffle: Reshuffle ``train_list`` in place before every pass.

        Yields:
            ``(x, y, paths)``: the downscaled inputs, the full-size targets
            and the list of file paths the batch was built from.
        """
        for batch_paths in self._path_batches(self.train_list, shuffle):
            x_batch, y_batch = self._load_batch(batch_paths)
            yield x_batch, y_batch, batch_paths

    def validation_generator(self, shuffle=True):
        """Infinite generator over the validation images.

        Args:
            shuffle: Reshuffle ``val_list`` in place before every pass.

        Yields:
            ``(x, y)``: the downscaled inputs and the full-size targets.
        """
        for batch_paths in self._path_batches(self.val_list, shuffle):
            yield self._load_batch(batch_paths)

In [3]:
# Build the loader over the image folder, naming each setting explicitly.
data_loader = DataLoader(path='./data/val2017/', batch_size=16, num_val=200)

In [4]:
# Training batches; shuffle=True is the method's default, spelled out here.
train_gen = data_loader.train_generator(shuffle=True)

In [26]:
# Pull a single batch (inputs, targets, source paths) and show both array shapes.
img, lab, img_list = next(train_gen)
print(img.shape, lab.shape, sep='\n')

(16, 128, 128, 3)
(16, 256, 256, 3)
