In [1]:
# Packages
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [2]:

# Define dataset class
class SwatDataset(Dataset):
    ''' Sliding-window dataset built on a preprocessed SWaT numpy file.

    Every sample pairs a history window of `windows_size` consecutive rows
    with a single target row located shortly after that window.

    Args:
        - path: <str> preprocessed dataset numpy file path
        - feature_idx: <list<int>> column indices of the features to keep
        - start_idx: <int> first row (inclusive) of the period to use
        - end_idx: <int> last row (exclusive) of the period to use
        - windows_size: <int> history length in rows, must be >= 1
        - sliding: <int> step in rows between consecutive windows, must be >= 1

    Raises:
        ValueError: if `windows_size` or `sliding` is smaller than 1.
    '''

    def __init__(self, path,
                 feature_idx: list,
                 start_idx: int,
                 end_idx: int,
                 windows_size: int,
                 sliding: int = 1):
        # Fail early: sliding == 0 would otherwise surface later as a
        # ZeroDivisionError in __len__, and negative values give nonsense windows.
        if windows_size < 1:
            raise ValueError(f'windows_size must be >= 1, got {windows_size}')
        if sliding < 1:
            raise ValueError(f'sliding must be >= 1, got {sliding}')

        # SECURITY NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects, which can execute code. Only load files from a trusted source;
        # switch to allow_pickle=False if the file is a plain numeric array.
        raw = np.load(path, allow_pickle=True)

        # Slice the rows first so that only the requested period is copied by
        # take(); the resulting values are the same as take-then-slice.
        self.data = raw[start_idx:end_idx].take(feature_idx, axis=1)
        self.windows_size = windows_size
        self.sliding = sliding

    def __len__(self):
        '''Number of (window, target) samples; never negative.'''
        # The trailing -1 keeps the target row (end + 1 in __getitem__) inside
        # the array for every counted index.
        n_samples = (self.data.shape[0] - self.windows_size) // self.sliding - 1
        # len() raises ValueError on a negative result, so report an empty
        # dataset when the selected period is too short for a single sample.
        return max(n_samples, 0)

    def __getitem__(self, index):
        '''
        Args:
            index: <int> sample index; negative values count from the end

        Returns:
            input: <np.array> [windows_size, num_feature]
            output: <np.array> [num_feature]

        Raises:
            IndexError: if index is outside [-len(self), len(self))
        '''
        n_samples = len(self)
        if index < 0:
            # Python-style negative indexing instead of silently slicing an
            # empty window.
            index += n_samples
        if index < 0 or index >= n_samples:
            raise IndexError(f'index out of range for dataset of length {n_samples}')

        start = index * self.sliding
        end = start + self.windows_size
        # The window covers rows [start, end). The target is the row at end + 1,
        # so the row at `end` itself is skipped.
        # NOTE(review): confirm this one-row gap is intended.
        return self.data[start:end, :], self.data[end + 1, :]

In [3]:
# Quick sanity check of the dataset.
# For convenience, the data file has already been converted to a numpy file
# with the shape of [449919, 53]
dataset = SwatDataset(
    path='../data/swat-2015-data.npy',
    feature_idx=[0, 1, 2],
    start_idx=1000,
    end_idx=9000,
    windows_size=100,
    sliding=100
)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Get a single sample (via normal indexing rather than calling the dunder
# method directly) and a single shuffled batch
inp, tar = dataset[0]
batch_inp, batch_tar = next(iter(dataloader))

print(f'Dataset Length: {len(dataset)}')
print(f'Input shape: {inp.shape}')
print(f'Target shape: {tar.shape}')
print(f'Batched Input shape: {batch_inp.size()}')
print(f'Batched Target shape: {batch_tar.size()}')

Dataset Length: 78
Input shape: (100, 3)
Target shape: (3,)
Batched Input shape: torch.Size([32, 100, 3])
Batched Target shape: torch.Size([32, 3])
