# Tabular Dataloaders

> Dataloaders for tabular data: feature (X) and target (Y) arrays with optional train/validation/test splits.

In [None]:
#| default_exp dataloaders.tabular

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

import numpy as np
from abc import ABC, abstractmethod
from typing import Union

from ddopnew.dataloaders.base import BaseDataLoader

In [None]:
#| export
class XYDataLoader(BaseDataLoader):
    """Dataloader over in-memory feature (`X`) and target (`Y`) arrays.

    The rows can optionally be divided into contiguous train / validation /
    test segments by giving the row index at which the validation and/or
    test segment starts. The active segment is selected with `train()`,
    `val()` and `test()`; indexing is always relative to the start of the
    active segment.

    Parameters
    ----------
    X : np.ndarray
        Features, one row per sample. A 1-D array is reshaped to a single
        column.
    Y : np.ndarray
        Targets, one row per sample. A 1-D array is reshaped to a single
        column. Must have the same length as `X`.
    val_index_start : int or None
        Row index of the first validation sample, or None for no
        validation segment.
    test_index_start : int or None
        Row index of the first test sample, or None for no test segment.

    Attributes
    ----------
    train_index_end : int
        Row index of the last training sample (inclusive).
    dataset_type : str
        Currently active segment: "train", "val" or "test".
    num_SKUs : int
        Number of columns of the (2-D) target array.
    """

    def __init__(self,
        X: np.ndarray,
        Y: np.ndarray,
        val_index_start: Union[int, None] = None,
        test_index_start: Union[int, None] = None,
    ):
        # Promote 1-D inputs to column vectors *before* anything reads the
        # second axis, so that num_SKUs below is well defined.
        if len(X.shape) == 1:
            X = X.reshape(-1, 1)

        if len(Y.shape) == 1:
            Y = Y.reshape(-1, 1)

        self.X = X
        self.Y = Y

        self.val_index_start = val_index_start
        self.test_index_start = test_index_start

        # The training segment ends right before the first held-out segment,
        # or spans all rows when no split is defined.
        if self.val_index_start is not None:
            self.train_index_end = self.val_index_start-1
        elif self.test_index_start is not None:
            self.train_index_end = self.test_index_start-1
        else:
            self.train_index_end = len(Y)-1

        self.dataset_type = "train"

        assert len(X) == len(Y), 'X and Y must have the same length'

        # Read from the reshaped array: the raw argument may be 1-D.
        self.num_SKUs = self.Y.shape[1]

        super().__init__()

    def _val_index_stop(self) -> int:
        """Return the exclusive end row of the validation segment.

        The validation segment runs up to the test segment if one is
        defined, otherwise up to the end of the data.
        """
        if self.test_index_start is not None:
            return self.test_index_start
        return len(self.X)

    def _active_split_bounds(self):
        """Return `(start, stop)` row bounds (stop exclusive) of the active segment.

        Raises
        ------
        ValueError
            If `dataset_type` is not one of "train", "val", "test", or if the
            selected segment has not been defined.
        """
        if self.dataset_type == "train":
            return 0, self.train_index_end+1

        if self.dataset_type == "val":
            if self.val_index_start is None:
                raise ValueError('no validation set defined')
            return self.val_index_start, self._val_index_stop()

        if self.dataset_type == "test":
            if self.test_index_start is None:
                raise ValueError('no test set defined')
            return self.test_index_start, len(self.X)

        raise ValueError('dataset_type not set')

    def __getitem__(self, idx):
        """Return the `(X, Y)` rows at position `idx` of the active segment.

        `idx` is an integer relative to the start of the active segment.
        Negative values count from the end of the active segment, so they
        can never reach rows belonging to another segment.

        Raises
        ------
        IndexError
            If `idx` lies outside the active segment.
        ValueError
            If `dataset_type` is not a known segment name.
        """
        start, stop = self._active_split_bounds()
        split_len = stop - start

        if idx < 0:
            idx = idx + split_len

        if idx < 0 or idx >= split_len:
            raise IndexError('index out of range')

        # Translate the segment-relative position into an absolute row.
        row = start + idx

        return self.X[row], self.Y[row]

    def __len__(self):
        """Return the total number of rows across all segments."""
        return len(self.X)

    @property
    def X_shape(self):
        """Shape of the full (2-D) feature array."""
        return self.X.shape

    @property
    def Y_shape(self):
        """Shape of the full (2-D) target array."""
        return self.Y.shape

    @property
    def len_train(self):
        """Number of rows in the training segment."""
        return self.train_index_end+1

    @property
    def len_val(self):
        """Number of rows in the validation segment.

        Raises `ValueError` if no validation segment is defined.
        """
        if self.val_index_start is None:
            raise ValueError('no validation set defined')
        # Without a test segment the validation segment runs to the end.
        return self._val_index_stop()-self.val_index_start

    @property
    def len_test(self):
        """Number of rows in the test segment.

        Raises `ValueError` if no test segment is defined.
        """
        if self.test_index_start is None:
            raise ValueError('no test set defined')
        return len(self.X)-self.test_index_start

    def val(self):
        """Make the validation segment active and return `self`.

        Raises `ValueError` if no validation segment is defined.
        """
        if self.val_index_start is None:
            raise ValueError('no validation set defined')

        self.dataset_type = "val"

        return self

    def test(self):
        """Make the test segment active and return `self`.

        Raises `ValueError` if no test segment is defined.
        """
        if self.test_index_start is None:
            raise ValueError('no test set defined')

        self.dataset_type = "test"

        return self

    def train(self):
        """Make the training segment active and return `self`."""
        self.dataset_type = "train"

        return self

In [None]:
# Synthetic regression data: 100 samples, 2 features, 1 target column.
X = np.random.standard_normal(size=(100, 2))
Y = np.random.standard_normal(size=(100, 1))
# Turn the target into a noisy linear function of the two features.
Y += 2 * X[:, [0]] + 3 * X[:, [1]]

# No split indices are given, so every row belongs to the training segment.
dataloader = XYDataLoader(X=X, Y=Y)

sample_X, sample_Y = dataloader[0]
print("sample:", sample_X, sample_Y)
print("sample shape Y:", sample_Y.shape)

print("length:", len(dataloader))

sample: [1.56245182 1.14748893] [6.75542795]
sample shape Y: (1,)
length: 100


In [None]:
# Synthetic regression data: 10 samples, 2 features, 1 target column.
X = np.random.standard_normal(size=(10, 2))
Y = np.random.standard_normal(size=(10, 1))
Y += 2 * X[:, [0]] + 3 * X[:, [1]]

# Rows 0-5 are train, rows 6-7 validation, rows 8-9 test.
dataloader = XYDataLoader(X=X, Y=Y, val_index_start=6, test_index_start=8)

sample_X, sample_Y = dataloader[0]

print(
    "length train:", dataloader.len_train,
    "length val:", dataloader.len_val,
    "length test:", dataloader.len_test,
)

# One entry per segment: heading to print, method that activates the
# segment, and the name of the property holding its length.
splits = (
    ("### Data from train set ###", dataloader.train, "len_train"),
    ("### Data from val set ###", dataloader.val, "len_val"),
    ("### Data from test set ###", dataloader.test, "len_test"),
)

for heading, activate_split, length_attr in splits:
    activate_split()

    print("")
    print(heading)
    for i in range(getattr(dataloader, length_attr)):
        sample_X, sample_Y = dataloader[i]
        print("idx:", i, "data:", sample_X, sample_Y)

length train: 6 length val: 2 length test: 2

### Data from train set ###
idx: 0 data: [ 0.15150781 -0.21567278] [0.6324219]
idx: 1 data: [0.69960546 0.39280064] [0.62274738]
idx: 2 data: [1.13297111 2.36949139] [9.64112181]
idx: 3 data: [ 0.29039519 -0.24950271] [-1.94061223]
idx: 4 data: [-1.11941705 -0.4216381 ] [-3.81505835]
idx: 5 data: [-0.67085536  1.06870613] [0.46132296]

### Data from val set ###
idx: 0 data: [ 0.50917985 -0.31464408] [0.37783885]
idx: 1 data: [-1.47151837 -0.67879661] [-3.36486836]

### Data from test set ###
idx: 0 data: [ 0.56259617 -1.09672819] [-2.29198518]
idx: 1 data: [-2.45385219 -0.05892748] [-5.09045412]


In [None]:
#| hide
# Run nbdev's notebook export step.
import nbdev
nbdev.nbdev_export()