In [41]:
import yaml
import os
from functions import print_function
import numpy as np
from keras.datasets import mnist, cifar10
from sklearn.model_selection import train_test_split


class StaticLinear:
    """Load, split and cache the datasets described by a yaml parameter file.

    The yaml file is read once in ``__init__``. ``load_data`` then either
    prepares user-supplied arrays or one of the predefined use cases (MNIST,
    CIFAR-10) and exposes the result as ``X_train``/``X_test``/``X_val`` and
    ``Y_train``/``Y_test``/``Y_val`` attributes, which are also written to disk
    as ``.npy`` files.
    """

    # Names of the six arrays that make up a split. Each one is stored on disk
    # as ``<name>.npy`` and on the instance as ``self.<name>``.
    _SPLIT_NAMES = ('X_train', 'X_test', 'X_val', 'Y_train', 'Y_test', 'Y_val')

    # Predefined use cases: yaml name -> (cache sub-directory, message printed
    # when no cached copy can be read). The keras loaders are looked up lazily
    # in ``_load_predefined`` so that defining the class has no side effects.
    _PREDEFINED = {
        'mnist': ('mnist', 'MNIST data not found. Downloading data.'),
        'cifar-10': ('cifar10', 'CIFAR-10 data not found. Downloading data instead.'),
    }

    def __init__(self, path):
        """Read the yaml parameter file at ``path`` and seed the random state.

        Args:
            path: Location of the yaml file. It must contain a ``general``
                section with a ``seed_value`` entry.
        """
        self.path = path

        print_function('Retrieving yaml parameters')

        with open(self.path, "r") as stream:
            self.load_yaml = yaml.load(stream, yaml.SafeLoader)

        self.general = self.load_yaml['general']

        # Seeded generator used for sub-sampling user-supplied data.
        self.random_state = np.random.RandomState(self.general['seed_value'])

    def load_data(self, X=None, Y=None, use_case_name=None):
        """Prepare the train/test/validation arrays and cache them on disk.

        For the predefined use cases (``mnist``, ``cifar-10``, selected by
        ``data.use_case`` in the yaml) a cached copy is loaded when available;
        otherwise the data is downloaded, the test set is halved into test and
        validation, and the result is saved.

        When both ``X`` and ``Y`` are supplied they take precedence over the
        yaml use case: the data is optionally sub-sampled
        (``data.data_sample_size``), split into training, testing and
        validation, and saved.

        Args:
            X: Input data, either array-like or a path to a ``.npy`` file.
            Y: Target data, either array-like or a path to a ``.npy`` file.
            use_case_name: Sub-directory name for supplied data; defaults to
                ``own_use_case``.

        Raises:
            Exception: If neither data nor a valid predefined use case is
                given, or if the target directory for supplied data already
                exists and ``data.replace`` is not enabled.
            ValueError: If the supplied ``X`` and ``Y`` have different lengths
                or are too small to sub-sample.
        """
        self.data = self.load_yaml['data']

        use_case = self.data['use_case'].lower()

        if self.data['save_data']['set_save_path']:
            save_root = self.data['save_data']['path'] + '/data/'
        else:
            save_root = os.getcwd() + '/data/'

        # Identity checks on purpose: ``X == None`` is evaluated element-wise
        # on numpy arrays and cannot be used as a condition.
        data_supplied = X is not None and Y is not None
        predefined = use_case in self._PREDEFINED

        if X is None and Y is None and not predefined:
            raise Exception('No predefined use case provided and no data supplied. Please choose a predefined case or supply data.')

        if data_supplied:
            folder = use_case_name if use_case_name is not None else 'own_use_case'
            self._load_supplied(X, Y, save_root + folder + '/')
            return

        if not predefined:
            raise Exception('Incorrect data use case provided in yaml. Check spelling or choose one of the available use cases: mnist or cifar-10.')

        print_function(f"Loading data for use case: {use_case.upper()}")
        self._load_predefined(use_case, save_root)

    def _load_supplied(self, X, Y, save_path):
        """Sub-sample (optionally), split and save user-supplied data."""
        # ``== True`` is kept so that only a real yaml boolean enables replacing.
        replace = self.data['replace'] == True  # noqa: E712
        if os.path.exists(save_path) and not replace:
            raise Exception(
                f"Could not save provided data. '{save_path}' already exists "
                "and 'replace' is not enabled in the yaml."
            )

        X = self._as_array(X)
        Y = self._as_array(Y)
        if X.shape[0] != Y.shape[0]:
            raise ValueError(
                f"X and Y must have the same number of rows, got {X.shape[0]} and {Y.shape[0]}."
            )

        sample_size = self.data['data_sample_size']
        if sample_size is not None:
            X, Y = self._sample_halves(X, Y, int(sample_size))

        self.X_train, self.X_test, self.Y_train, self.Y_test = train_test_split(
            X,
            Y,
            test_size=self.data['test_set_size'],
            random_state=self.general['seed_value'],
            shuffle=True,
        )
        self._hold_out_validation()

        os.makedirs(save_path, exist_ok=True)
        self._save_splits(save_path)

    def _load_predefined(self, use_case, save_root):
        """Load a predefined use case from cache, downloading it if needed."""
        folder, missing_message = self._PREDEFINED[use_case]
        save_path = save_root + folder + '/'

        try:
            self._load_splits(save_path)
        except (OSError, ValueError, EOFError):
            # A missing or unreadable cache falls back to a fresh download.
            print_function(missing_message)

            dataset = mnist if use_case == 'mnist' else cifar10
            (self.X_train, self.Y_train), (self.X_test, self.Y_test) = dataset.load_data()
            self._hold_out_validation()

            os.makedirs(save_path, exist_ok=True)
            self._save_splits(save_path)

    def _sample_halves(self, X, Y, sample_size):
        """Draw ``sample_size`` rows, half from each half of the data.

        Indices are drawn with replacement from the seeded ``random_state``,
        so rows can repeat and the draw is reproducible.
        NOTE(review): presumably the two halves of X hold two different groups
        that should be represented equally - confirm against the data layout.
        """
        n_rows = X.shape[0]
        half = n_rows // 2
        if half < 1:
            raise ValueError('At least two rows are required to sub-sample the data.')

        first_count = sample_size // 2
        second_count = sample_size - first_count

        # ``randint`` excludes the upper bound, so ``half`` and ``n_rows`` make
        # every row of each half eligible.
        first_idx = self.random_state.randint(0, half, first_count)
        second_idx = self.random_state.randint(half, n_rows, second_count)

        picked = np.concatenate([first_idx, second_idx])
        return X[picked], Y[picked]

    def _hold_out_validation(self):
        """Split the current test set in half into test and validation sets."""
        self.X_test, self.X_val, self.Y_test, self.Y_val = train_test_split(
            self.X_test,
            self.Y_test,
            test_size=0.5,
            random_state=self.general['seed_value'],
            shuffle=True,
        )

    def _save_splits(self, save_path):
        """Write the six split arrays to ``save_path`` as ``.npy`` files."""
        for name in self._SPLIT_NAMES:
            np.save(save_path + name + '.npy', getattr(self, name))

    def _load_splits(self, save_path):
        """Read the six split arrays from ``save_path`` onto the instance."""
        # Load everything first so a partial cache never leaves the instance
        # with a mix of old and new arrays.
        loaded = {name: np.load(save_path + name + '.npy') for name in self._SPLIT_NAMES}
        for name, array in loaded.items():
            setattr(self, name, array)

    @staticmethod
    def _as_array(value):
        """Return ``value`` as a numpy array, loading it first if it is a path."""
        if isinstance(value, (str, os.PathLike)):
            try:
                return np.load(value)
            except (OSError, ValueError, EOFError) as err:
                raise Exception(f"Unable to load data from: {value}") from err
        return np.asarray(value)

In [42]:
# Build the loader: reads 'parameters.yaml' and seeds its random state.
data = StaticLinear('parameters.yaml')

Retrieving yaml parameters...
__________________________________________________


In [43]:
# No X/Y supplied, so the use case named in the yaml is loaded (MNIST in this run).
data.load_data()

Loading data for use case: MNIST...
__________________________________________________


In [44]:
# Inspect the shape of the training inputs that load_data() stored on the instance.
data.X_train.shape

(60000, 28, 28)