In [None]:
class WindowGenerator:
    """Cut fixed-size (inputs, labels) windows out of paired feature/target arrays.

    A window spans ``total_window_size = input_width + shift`` consecutive
    steps. The first ``input_width`` steps are the model inputs and the last
    ``label_width`` steps are the labels, so with ``shift == label_width`` the
    labels start immediately after the inputs.

    Args:
        input_width: Number of steps fed to the model.
        label_width: Number of steps to predict; must not exceed
            ``input_width + shift``.
        shift: Offset from the start of the labels' window end relative to
            the end of the inputs (i.e. the window ends ``shift`` steps after
            the inputs do).
        x_data: Feature array. ``create_4d_windows`` expects it to be
            ``(num_samples, num_days, num_features)``.
        y_data: Target array aligned with ``x_data``. ``create_4d_windows``
            expects it to be ``(num_samples, num_days, num_targets)``.

    Raises:
        ValueError: If ``label_width`` exceeds ``input_width + shift``.
    """

    def __init__(self, input_width, label_width, shift, x_data, y_data):
        self.x_data = x_data
        self.y_data = y_data

        self.input_width = input_width
        self.label_width = label_width
        self.shift = shift
        self.total_window_size = input_width + shift

        if label_width > self.total_window_size:
            raise ValueError(
                f"label_width ({label_width}) cannot exceed the total window "
                f"size input_width + shift ({self.total_window_size})."
            )

        self.input_slice = slice(0, input_width)
        # Labels are the last `label_width` steps of the window; both
        # split_window and create_4d_windows derive their label positions
        # from this slice so they always agree.
        self.label_slice = slice(self.total_window_size - label_width, self.total_window_size)

    def __repr__(self):
        return '\n'.join([
            f'Total window size: {self.total_window_size}',
            f'Input slice: {self.input_slice}',
            f'Label slice: {self.label_slice}'
        ])

    def make_dataset(self):
        """Build a shuffled ``tf.data.Dataset`` of ``(inputs, labels)`` windows.

        Windows of ``total_window_size`` consecutive entries are taken along
        the first axis of ``x_data`` and ``y_data`` with stride 1, unbatched.
        Feature and label windows are paired *before* shuffling so every input
        window stays matched with the label window from the same position.

        Returns:
            A dataset whose elements are ``split_window(features, labels)``.
        """
        def windows(array):
            return tf.keras.utils.timeseries_dataset_from_array(
                data=array,
                targets=None,  # required argument; we only want the windows
                sequence_length=self.total_window_size,
                sequence_stride=1,
                shuffle=False,  # shuffle once, after zipping, to keep pairs aligned
                batch_size=None,
            )

        dataset = tf.data.Dataset.zip((windows(self.x_data), windows(self.y_data)))
        # A buffer as large as the number of windows gives a full shuffle,
        # at the cost of buffering every window in memory.
        num_windows = max(self.x_data.shape[0] - self.total_window_size + 1, 1)
        dataset = dataset.shuffle(buffer_size=num_windows)
        return dataset.map(self.split_window)

    def split_window(self, features, labels):
        """Split a window (or a batch of windows) into inputs and labels.

        Slicing is applied to the second-to-last axis, assumed to be the time
        axis for both a single window ``(time, channels)`` and a batch
        ``(batch, time, channels)``.

        Returns:
            ``(inputs, labels)`` holding ``input_width`` and ``label_width``
            steps respectively.
        """
        inputs = features[..., self.input_slice, :]
        labels = labels[..., self.label_slice, :]
        return inputs, labels

    def create_4d_windows(self):
        """Materialise every sliding window as dense NumPy arrays.

        Returns:
            ``(x_windows, y_windows)`` with shapes
            ``(num_samples, num_windows, input_width, num_features)`` and
            ``(num_samples, num_windows, label_width, num_targets)``, where
            ``num_windows = num_days - total_window_size + 1`` (0 when the
            series is shorter than that). Both arrays are float64.
        """
        num_samples, num_days, num_features = self.x_data.shape
        num_targets = self.y_data.shape[2]
        num_windows = max(num_days - self.total_window_size + 1, 0)
        label_start = self.label_slice.start

        x_windows = np.zeros((num_samples, num_windows, self.input_width, num_features))
        # Label windows hold label_width steps, not input_width.
        y_windows = np.zeros((num_samples, num_windows, self.label_width, num_targets))

        # Fill all samples for a given window position with one slice assignment.
        for window_idx in range(num_windows):
            x_windows[:, window_idx] = self.x_data[:, window_idx:window_idx + self.input_width]
            label_from = window_idx + label_start
            y_windows[:, window_idx] = self.y_data[:, label_from:label_from + self.label_width]

        return x_windows, y_windows