In [6]:
import holidays
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
import pywt

class DatasetPreprocess:
    """Preprocessing pipeline for a 15-minute household power-consumption frame.

    The pipeline (see ``preprocess_data``) denoises ``total_power_consumption``,
    derives calendar/holiday/running-mean features, thresholds the ``ev_car``
    column and repairs totals that are lower than ``total_usage``.
    """

    def __init__(self):
        # Torch device selection; not used by the preprocessing methods below.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def wavelet_denoising(self, signal, wavelet='db5', threshold=0.04):
        """Denoise ``signal`` by soft-thresholding its wavelet coefficients.

        Args:
            signal: 1-D array-like to denoise.
            wavelet: PyWavelets wavelet name used for both decomposition and
                reconstruction.
            threshold: Absolute soft-threshold applied to every coefficient
                array (approximation and details alike).

        Returns:
            ndarray of the reconstructed signal, with exactly ``len(signal)``
            samples.
        """
        # Multilevel discrete wavelet transform (DWT) at the maximum level.
        coeffs = pywt.wavedec(signal, wavelet)

        coeffs_thresholded = [pywt.threshold(c, threshold, mode='soft') for c in coeffs]

        # Reconstruct from the *thresholded* coefficients.
        reconstructed_signal = pywt.waverec(coeffs_thresholded, wavelet)
        # waverec yields one extra trailing sample for odd-length inputs; trim it
        # so the result can be assigned back to the originating column.
        return reconstructed_signal[:len(signal)]

    def refine_ev_charging_values(self, reconstructed, bin_rank=1):
        """Zero out ``ev_car`` values below a histogram-derived threshold.

        ``ev_car`` is binned with NumPy's ``'doane'`` rule and the bins are
        ranked by sample count (rank 0 = most populated). The left edge of the
        bin at position ``bin_rank`` becomes the threshold: values ``>=`` it are
        kept, all others are set to 0.

        Args:
            reconstructed: DataFrame with an ``ev_car`` column; modified in place.
            bin_rank: Rank of the bin whose left edge is the threshold. The
                default ``1`` (second most populated bin) matches the
                original behaviour.

        Returns:
            The same DataFrame. If the histogram has too few bins to pick
            ``bin_rank`` (e.g. a constant column collapses to a single bin),
            it is returned unchanged.
        """
        counts, bin_edges = np.histogram(reconstructed['ev_car'], bins='doane')

        # Bin indices ordered from most to least populated.
        sorted_indices = np.argsort(counts)[::-1]
        if bin_rank >= len(sorted_indices):
            # Not enough bins to derive a threshold; nothing to refine.
            return reconstructed

        charging_bin_start = bin_edges[sorted_indices[bin_rank]]

        reconstructed['ev_car'] = np.where(reconstructed['ev_car'] >= charging_bin_start, reconstructed['ev_car'], 0)
        return reconstructed

    def process_faulty_values(self, dt):
        """Raise ``total_power_consumption`` to at least ``total_usage``.

        Where the recorded total is strictly smaller than ``total_usage``
        (presumably the sum of the individual appliance readings — confirm
        against the data source), the total is replaced by ``total_usage``.
        Rows where either value is NaN keep their original total, because the
        comparison is False.

        Args:
            dt: DataFrame with ``total_power_consumption`` and ``total_usage``;
                modified in place.

        Returns:
            The same DataFrame.
        """
        dt['total_power_consumption'] = np.where(dt['total_power_consumption'] < dt['total_usage'], dt['total_usage'], dt['total_power_consumption'])
        return dt

    def preprocess_data(self, data):
        """Run the full preprocessing pipeline on ``data`` (modified in place).

        Order matters: ``Mean`` is computed in ``extract_features`` from the
        denoised totals *before* ``process_faulty_values`` adjusts them.

        Returns:
            The processed DataFrame.
        """
        # Denoise the aggregate consumption signal.
        data['total_power_consumption'] = self.wavelet_denoising(data['total_power_consumption'])
        data = self.extract_features(data)
        # Suppress low ev_car readings below the histogram-derived threshold.
        data = self.refine_ev_charging_values(data)

        data = self.process_faulty_values(data)

        return data

    def extract_features(self, df):
        """Add calendar, US-holiday and expanding-mean features to ``df``.

        ``local_15min`` is parsed to timezone-aware UTC timestamps in place,
        then ``Month``, ``Day_of_week`` (1 = Monday ... 7 = Sunday), ``Day``,
        ``Hour``, ``holiday`` (1/0) and ``Mean`` (expanding mean of
        ``total_power_consumption``) columns are added.

        Returns:
            The same DataFrame.
        """
        # NOTE(review): timestamps are converted to UTC, so the calendar fields
        # and the holiday lookup use the UTC date/hour, not local time — confirm
        # this is intended given the ``local_`` column name.
        df['local_15min'] = pd.to_datetime(df['local_15min'], utc=True)
        timestamps = df['local_15min'].dt
        df['Month'] = timestamps.month
        df['Day_of_week'] = timestamps.day_of_week + 1
        df['Day'] = timestamps.day
        df['Hour'] = timestamps.hour
        # Build the holiday calendar once and reuse it, instead of constructing
        # a new holidays.US() object for every row.
        us_holidays = holidays.US()
        df['holiday'] = df['local_15min'].apply(lambda ts: 1 if ts in us_holidays else 0)
        df['Mean'] = df['total_power_consumption'].expanding().mean()
        return df