In [10]:
import numpy as np
import os 
import glob
import tensorflow as tf
import random

WORKING_DIR = os.path.join("..", "..")
PROCESSED_PATH = os.path.join(WORKING_DIR, "data", "03_processed")
# glob() returns entries in arbitrary (filesystem-dependent) order, so sort
# before picking one; otherwise "the first dataset" can differ between machines.
_dataset_candidates = sorted(glob.glob(os.path.join(PROCESSED_PATH, "*")))
if not _dataset_candidates:
    raise FileNotFoundError("No processed dataset found under {!r}".format(PROCESSED_PATH))
DATASET_PATH = _dataset_candidates[0]

class Train_Validation_Generators:
    """Split the ``development`` part of a processed dataset into train/validation generators.

    Positive samples are read from ``<dataset_path>/development/1/*.npz`` and
    negative samples from ``<dataset_path>/development/0/*.npz``. Both lists are
    shuffled; the first ``int(train_size * m)`` entries of each (``m`` = size of
    the smaller class) go to training and the remainder go to validation.

    Args:
        dataset_path: Root directory of one processed dataset.
        view_IDs: View identifiers forwarded to ``Data_generator`` (copied).
        train_size: Fraction of the balanced sample count used for training.
            Values above 1 are clamped to 1; ``-1`` means "use everything for training".
        batch_size: Forwarded to ``Data_generator``.
        shuffle: Forwarded to ``Data_generator``.
        RGB: Forwarded to ``Data_generator``.
    """

    def __init__(self, dataset_path, view_IDs, train_size, batch_size=32, shuffle=True, RGB=False):
        self.devel_path = os.path.join(dataset_path, "development")
        # Normalise train_size BEFORE storing it: the split below reads
        # self.train_size, so clamping only a local copy would have no effect.
        if train_size > 1:
            train_size = 1
        if train_size == -1:
            train_size = 1.0
        self.train_size = train_size
        self.validation_size = 1 - train_size
        self.view_IDs = view_IDs.copy()
        self.batch_size = batch_size
        self.RGB = RGB
        self.shuffle = shuffle

        def _read_samples_in_subfolder(zero_or_one):
            # NOTE(review): np.load on a .npz returns a lazily-reading NpzFile that
            # keeps its file open; with very many samples this may hit the
            # open-file limit — consider loading into dicts if that happens.
            fp_list = glob.glob(os.path.join(self.devel_path, str(zero_or_one), "*.npz"))
            return [np.load(f) for f in fp_list]

        pos_samples = _read_samples_in_subfolder(1)
        neg_samples = _read_samples_in_subfolder(0)
        random.shuffle(pos_samples)
        random.shuffle(neg_samples)
        # Split on the smaller class so the training set is class-balanced.
        m = min(len(pos_samples), len(neg_samples))
        split_pt = int(self.train_size * m)
        self.train_pos_samples, self.train_neg_samples = pos_samples[:split_pt], neg_samples[:split_pt]
        self.valid_pos_samples, self.valid_neg_samples = pos_samples[split_pt:], neg_samples[split_pt:]

    def get_train(self):
        """Return a ``Data_generator`` over the training split."""
        return Data_generator(self.train_pos_samples, self.train_neg_samples, view_IDs=self.view_IDs, batch_size=self.batch_size, shuffle=self.shuffle, RGB=self.RGB)

    def get_valid(self):
        """Return a ``Data_generator`` over the validation split."""
        return Data_generator(self.valid_pos_samples, self.valid_neg_samples, view_IDs=self.view_IDs, batch_size=self.batch_size, shuffle=self.shuffle, RGB=self.RGB)
        
class Data_generator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields class-balanced, multi-view batches.

    Every batch takes ``batch_size // 2`` positive and the same number of
    negative samples, appended in (pos, neg, pos, neg, ...) order; an odd
    ``batch_size`` therefore yields batches of ``batch_size - 1`` samples.

    Each sample must support ``sample[key]`` for ``array_ID<view>`` (one key per
    entry of ``view_IDs``), ``array_IDRGB`` when ``RGB`` is truthy, ``one_hot``
    and ``frames``.

    Args:
        pos_samples: Positive samples (shallow-copied, so shuffling here does not
            reorder the caller's list).
        neg_samples: Negative samples (shallow-copied).
        view_IDs: View identifiers used to build the ``array_ID<view>`` keys.
        batch_size: Nominal batch size; must be at least 2.
        shuffle: Reshuffle both sample lists at the end of each epoch.
        RGB: Also emit the ``array_IDRGB`` input as the last element of ``x``.

    Raises:
        ValueError: If ``batch_size < 2`` (each batch would contain no samples).
    """

    def __init__(self, pos_samples, neg_samples, view_IDs, batch_size=32, shuffle=True, RGB=False):
        if batch_size < 2:
            raise ValueError(
                "batch_size must be >= 2 so every batch holds at least one "
                "positive and one negative sample, got {}".format(batch_size))
        # Initialise the Sequence base (required by newer Keras versions; harmless otherwise).
        super().__init__()
        self.pos_samples = pos_samples.copy()
        self.neg_samples = neg_samples.copy()
        self.view_IDs = view_IDs.copy()
        self.RGB = RGB
        self.keys = ['array_ID{}'.format(view_id) for view_id in self.view_IDs]
        if self.RGB:
            self.keys += ['array_IDRGB']
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Number of samples drawn from EACH class per batch.
        self.pos_in_batch_n = self.batch_size // 2
        self.neg_in_batch_n = self.pos_in_batch_n
        self.on_epoch_end()

    def __len__(self):
        """Number of complete balanced batches per epoch."""
        m = min(len(self.pos_samples), len(self.neg_samples))
        # Each batch consumes pos_in_batch_n samples per class (not batch_size),
        # so dividing by batch_size would leave half of the data unused.
        return m // self.pos_in_batch_n

    def __getitem__(self, index):
        """Return batch ``index`` as ``(x, Y)`` (see ``_load_data``)."""
        start = index * self.pos_in_batch_n
        indices = list(range(start, start + self.pos_in_batch_n))
        return self._load_data(indices)

    def _load_data(self, indices):
        """Stack the samples at ``indices`` from both classes into model inputs.

        Returns:
            ``x``: list with one array per entry of ``self.keys`` (views, then RGB).
            ``Y``: the ``one_hot`` labels repeated over the frame axis, [B, F, 2].
        """
        data = [[] for _ in self.keys]
        Y = []
        for idx in indices:
            for sample in (self.pos_samples[idx], self.neg_samples[idx]):
                for view_data, key in zip(data, self.keys):
                    view_data.append(sample[key])
                Y.append(sample['one_hot'])
        # Y [B, 2] -> [B, F, 2]. F is taken from the last sample visited;
        # assumes every sample has the same frame count — TODO confirm.
        Y = np.tile(np.expand_dims(np.array(Y), axis=1), [1, sample['frames'], 1])
        x = [np.array(view) for view in data]
        assert all(len(view) == len(Y) for view in x)
        return x, Y

    def on_epoch_end(self):
        """Reshuffle the positive and negative lists independently when ``shuffle`` is set."""
        if self.shuffle:
            random.shuffle(self.pos_samples)
            random.shuffle(self.neg_samples)
          
# Smoke test: build generators without and with the RGB stream and check that
# each batch carries one input array per view, plus one more when RGB is on.
VIEW_IDS = ["121", "122", "123"]
for RGB in (0, 1):
    generators = Train_Validation_Generators(dataset_path=DATASET_PATH, view_IDs=VIEW_IDS, train_size=1, batch_size=4, RGB=RGB)
    train_generator = generators.get_train()
    valid_generator = generators.get_valid()
    x, y = train_generator[0]
    print(RGB, len(x))
    assert len(x) == len(VIEW_IDS) + RGB, "unexpected number of model inputs"


0 3
1 4


In [None]:


    def __len__(self):
        """Return how many complete batches fit in one epoch (any remainder is dropped)."""
        n_ids = len(self.list_IDs)
        full_batches = np.floor(n_ids / self.batch_size)
        return int(full_batches)

    def __getitem__(self, index):
        """Build batch number ``index`` from the current (possibly shuffled) index order."""
        start = index * self.batch_size
        stop = start + self.batch_size
        # Map positions in the shuffled order back to the actual sample IDs.
        batch_ids = [self.list_IDs[k] for k in self.indexes[start:stop]]
        X, y = self.__data_generation(batch_ids)
        return X, y

    def on_epoch_end(self):
        """Reset ``self.indexes`` to 0..len(list_IDs)-1, shuffled in place when ``self.shuffle == True``."""
        order = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(order)
        self.indexes = order

    def __data_generation(self, list_IDs_temp):
        """Load one batch of samples and their one-hot labels.

        Args:
            list_IDs_temp: IDs for this batch; each sample is read from ``data/<ID>.npy``
                and its class from ``self.labels[ID]``.

        Returns:
            ``(X, Y)``: ``X`` has shape ``(len(list_IDs_temp), *self.dim, self.n_channels)``;
            ``Y`` is the one-hot encoding of the labels with ``self.n_classes`` columns.
        """
        # Size by the IDs actually supplied: np.empty leaves uninitialised
        # garbage in any row that is never filled, so never pad to batch_size.
        n_samples = len(list_IDs_temp)
        X = np.empty((n_samples, *self.dim, self.n_channels))
        y = np.empty(n_samples, dtype=int)

        for i, ID in enumerate(list_IDs_temp):
            X[i,] = np.load('data/' + ID + '.npy')
            y[i] = self.labels[ID]

        # `keras` is not imported anywhere visible in this notebook (only
        # `tensorflow as tf`), so go through the tf.keras namespace.
        return X, tf.keras.utils.to_categorical(y, num_classes=self.n_classes)