# Import libs


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pvlib
import json
import os
from pvlib.pvsystem import PVSystem, Array, FixedMount
from pvlib.location import Location
from pvlib.modelchain import ModelChain
from pvlib.temperature import TEMPERATURE_MODEL_PARAMETERS
import plotly.graph_objects as go
import plotly.io as pio
from tqdm import tqdm
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_percentage_error, mean_absolute_error, root_mean_squared_error, r2_score
from sklearn.inspection import permutation_importance
import forestci as fci
from sklearn.model_selection import KFold
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import GridSearchCV
import threading
from sklearn.metrics import make_scorer
import dash
from dash import dcc, html
import plotly.graph_objects as go
from dash.dependencies import Input, Output
import webbrowser
from threading import Timer
import requests
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import datetime
from sklearn.utils.parallel import Parallel, delayed
from sklearn.utils.validation import (
    check_is_fitted,
)
from sklearn.ensemble._base import _partition_estimators

import pickle
import joblib
import os

# Setup


## Setup Directories


In [None]:
pio.renderers.default = "browser"  # render plotly figures in browser

# Root folder of all datasets, provided through the environment so the notebook
# is not tied to one machine. An empty value is treated like a missing one.
PARENT_DATA_DIR = os.getenv('PARENT_DATA_DIR')
if not PARENT_DATA_DIR:
    raise ValueError("PARENT_DATA_DIR environment variable is not set")

# Build the path with os.path.join instead of a hard-coded "\PRiOT\..." suffix:
# backslash separators only work on Windows and yield a bogus filename elsewhere.
dataDirpath = os.path.join(PARENT_DATA_DIR, "PRiOT", "dataExport_1")
dataCacheDirpath = os.path.join(dataDirpath, "cache")
logsDirpath = "../logs"  # relative to the notebook's working directory

# Create the output folders if needed (no-op when they already exist).
os.makedirs(logsDirpath, exist_ok=True)
os.makedirs(dataCacheDirpath, exist_ok=True)

## Setup Parameters


In [None]:
# Run-wide switches and seed (consumed by cells outside this view)
useCached, forceTrain = False, True
tuneMaxProductionEstimators = True
random_state = 42

# Minimum history a system must have: the training window plus the
# minimum testing window, both expressed in days.
trainingDays, minTestingDays = 50, 10
minMeasurements = minTestingDays + trainingDays

# Functions


## Serializer


In [None]:
# https://scikit-learn.org/stable/model_persistence.html
class ModelSerializer:
    def _save_model(self, model, serial_type, save_params):
        serial_type.dump(model, save_params)

    def _retrieve_model(self, serial_type, retrieve_params):
        return serial_type.load(retrieve_params)


class JoblibSerializer(ModelSerializer):
    def save_model(self, model, save_model_path, filename):
        super()._save_model(model, joblib, os.path.join(save_model_path, filename + ".joblib"))

    def retrieve_model(self, save_model_path, filename):
        return super()._retrieve_model(joblib, os.path.join(save_model_path, filename + '.joblib'))


class PickleSerializer(ModelSerializer):
    def save_model(self, model, save_model_path, filename):
        with open(os.path.join(save_model_path, filename + ".pkl"), 'wb') as f:
            super()._save_model(model, pickle, f)

    def retrieve_model(self, save_model_path, filename):
        with open(os.path.join(save_model_path, filename + ".pkl"), 'rb') as f:
            return super()._retrieve_model(pickle, f)

## Utils


In [None]:
def get_altitude_from_wgs84(longitude, latitude, timeout=30):
    """Return the altitude at a WGS84 point using the swisstopo web services.

    The point is first converted to LV95 with the geodesy.geo.admin.ch REFRAME
    service, then the height at the LV95 coordinates is queried from
    api3.geo.admin.ch.

    Args:
        longitude: WGS84 longitude, sent as the ``easting`` parameter.
        latitude: WGS84 latitude, sent as the ``northing`` parameter.
        timeout: seconds to wait for each HTTP request (forwarded to
            ``requests.get``). Without it a stalled server would block the
            caller indefinitely.

    Returns:
        The ``height`` field of the altitude service response, as a float.

    Raises:
        RuntimeError: if either service answers with a status code other than 200.
        requests.exceptions.RequestException: network errors and timeouts
            propagate unchanged.
    """
    # Convert WGS84 to LV95
    lv95_url = "https://geodesy.geo.admin.ch/reframe/wgs84tolv95"
    params_lv95 = {
        "easting": longitude,
        "northing": latitude,
        "format": "json"
    }

    response_lv95 = requests.get(lv95_url, params=params_lv95, timeout=timeout)
    if response_lv95.status_code != 200:
        raise RuntimeError("Error converting WGS84 to LV95: " + response_lv95.text)

    lv95_data = response_lv95.json()
    lv95_easting = lv95_data["easting"]
    lv95_northing = lv95_data["northing"]

    # Get altitude from LV95 coordinates
    altitude_url = "https://api3.geo.admin.ch/rest/services/height"
    params_altitude = {
        "easting": lv95_easting,
        "northing": lv95_northing
    }

    response_altitude = requests.get(altitude_url, params=params_altitude, timeout=timeout)
    if response_altitude.status_code != 200:
        raise RuntimeError("Error retrieving altitude: " + response_altitude.text)

    altitude_data = response_altitude.json()
    # The service may return the height as a string; normalise to float.
    return float(altitude_data["height"])

# Import


## Import metadata


## Import data


In [None]:
class DataHandler:
    """Load, post-process and validate PRiOT system metadata and daily measures.

    Attributes:
        dataDirpath: directory holding ``metadata.json`` and the per-system CSV exports.
        dataCacheDirpath: cache directory (stored only; not used by the methods below).
        measures: DataFrame of daily energy, one column per system, indexed by date
            (``None`` until :meth:`load_csv` has run).
        metadata: dict parsed from ``metadata.json`` (``None`` until :meth:`load_metadata`).
        metadataFilepath: path of ``metadata.json``.
    """

    def __init__(self, dataDirpath, dataCacheDirpath):
        self.dataDirpath = dataDirpath
        self.dataCacheDirpath = dataCacheDirpath

        self.measures = None
        self.metadata = None

        self.metadataFilepath = os.path.join(dataDirpath, "metadata.json")

    def get_systems_names(self):
        """Return the column labels of ``measures`` (one column per system)."""
        return self.measures.columns

    @staticmethod
    def _split_array_key(key):
        """Parse a per-array metadata key of the form ``<prefix>_mod<N>[_<rest>]``.

        Returns ``(array_num, new_key)`` where ``array_num`` is ``str(int(N))`` and
        ``new_key`` is the key without the ``mod<N>`` part (e.g. ``'pv_mod12_tilt'``
        -> ``('12', 'pv_tilt')``), or ``None`` if ``key`` does not follow that form.
        """
        parts = key.split('_')
        if len(parts) < 2 or not parts[1].startswith('mod'):
            return None
        digits = parts[1][len('mod'):]
        # isdecimal (not isdigit) guarantees int() accepts the characters
        if not digits.isdecimal():
            return None
        return str(int(digits)), '_'.join(parts[:1] + parts[2:])

    def load_metadata(self):
        """Read ``metadata.json``, normalise every system entry in place and save it back.

        For each system's ``'metadata'`` dict:
          * ``loc_altitude`` is fetched with :func:`get_altitude_from_wgs84` when it is
            missing and both ``loc_longitude`` and ``loc_latitude`` are present;
          * ``loss`` defaults to 0;
          * keys ``<prefix>_mod<N>_<rest>`` are moved to
            ``systemMetadata['arrays'][str(N)]['<prefix>_<rest>']``.

        Note: the file at ``metadataFilepath`` is overwritten with the result.
        """
        with open(self.metadataFilepath, 'r', encoding='utf-8') as f:
            self.metadata = json.load(f)

        for systemMetadata in tqdm(self.metadata.values(), desc="Post-processing metadata"):
            metadata = systemMetadata['metadata']

            # Add altitude to metadata, if not already present
            # (TODO: improve with multi-threading — one pair of blocking HTTP calls per system)
            if "loc_altitude" not in metadata and "loc_longitude" in metadata and "loc_latitude" in metadata:
                metadata["loc_altitude"] = get_altitude_from_wgs84(metadata["loc_longitude"], metadata["loc_latitude"])

            # Add the default loss to metadata if not already present
            metadata.setdefault('loss', 0)

            # Move the per-array keys into systemMetadata['arrays'][<array number>].
            # Iterate over a snapshot of the keys because matching keys are popped.
            for key in list(metadata):
                parsed = self._split_array_key(key)
                if parsed is None:
                    continue
                array_num, new_key = parsed
                arrays = systemMetadata.setdefault('arrays', {})
                arrays.setdefault(array_num, {})[new_key] = metadata.pop(key)

        # Save metadata with new format and values
        self.save_metadata()

    def save_metadata(self):
        """Write ``self.metadata`` to ``metadataFilepath`` as indented JSON."""
        with open(self.metadataFilepath, 'w', encoding='utf-8') as f:
            json.dump(self.metadata, f, indent=4)

    @staticmethod
    def _read_system_csv(filepath):
        """Read one system's CSV export and return ``(daily_energy, duplicate_rows)``.

        ``daily_energy`` is the ``tt_forward_active_energy_total_toDay`` column summed
        per local date; a date whose values are all missing stays NaN. ``duplicate_rows``
        holds every row whose date occurs more than once (for logging).
        """
        systemMeasures = pd.read_csv(filepath)
        # Convert the millisecond timestamp to a timezone-aware local datetime
        systemMeasures['Datetime'] = pd.to_datetime(systemMeasures['Timestamp'], unit='ms', utc=True).dt.tz_convert('Europe/Zurich')
        # Keep only the date, as the production is a daily value. PRiOT normally exports
        # the energy of a day at midnight local time of that day (e.g. July 1st energy at
        # July 1st 00:00 Europe/Zurich), but daylight saving time is not always handled
        # correctly and the export is sometimes stamped 23:00 the day before. Adding 1h
        # maps both cases onto the correct date.
        systemMeasures['Date'] = (systemMeasures['Datetime'] + pd.Timedelta(hours=1)).dt.date
        systemMeasures.set_index('Date', inplace=True)

        duplicatedMask = systemMeasures.index.duplicated(keep=False)
        duplicates = systemMeasures[duplicatedMask]

        # Sum duplicated dates. min_count=1 keeps all-NaN dates as NaN: with the
        # default min_count=0 a missing measurement would become 0 production.
        dailyEnergy = systemMeasures['tt_forward_active_energy_total_toDay'].groupby(level='Date').sum(min_count=1)
        return dailyEnergy, duplicates

    def load_csv(self):
        """Load every ``*.csv`` file of ``dataDirpath`` into ``self.measures``.

        The system name is the part of the filename before the first ``'_'``.
        Rows with a duplicated date are written to ``<logsDirpath>/measureDuplicates.csv``
        (``logsDirpath`` is a module-level setting), with a ``System`` column added.
        """
        measures_dic = {}
        duplicates_list = []
        # Sorted for a deterministic column order (os.listdir order is arbitrary)
        csvFilenames = sorted(name for name in os.listdir(self.dataDirpath) if name.endswith(".csv"))
        for filename in tqdm(csvFilenames, desc="Loading CSV files"):
            systemName = filename.split('_')[0]
            dailyEnergy, duplicates = self._read_system_csv(os.path.join(self.dataDirpath, filename))
            if not duplicates.empty:
                duplicates_list.append(duplicates.assign(System=systemName))
            measures_dic[systemName] = dailyEnergy

        # Convert the dictionary of series to a DataFrame (columns aligned on date)
        self.measures = pd.DataFrame(measures_dic)

        # Log the duplicates; pd.concat raises on an empty list, hence the guard
        duplicates_df = pd.concat(duplicates_list) if duplicates_list else pd.DataFrame()
        logFilename = os.path.join(logsDirpath, "measureDuplicates.csv")
        print(f"Number of duplicate dates found: {len(duplicates_df)} (see log file {logFilename} for more details)")
        duplicates_df.to_csv(logFilename, index=True)

    @staticmethod
    def _coerce_number(mapping, key):
        """Make ``mapping[key]`` numeric in place; return ``False`` if that is impossible.

        Values that are already ``int``/``float`` are left as-is; otherwise ``int`` is
        tried first, then ``float``. Non-convertible values (including ``None`` or lists,
        which raise ``TypeError``) are left unchanged.
        """
        value = mapping[key]
        if isinstance(value, (int, float)):
            return True
        for cast in (int, float):
            try:
                mapping[key] = cast(value)
            except (TypeError, ValueError):
                continue
            return True
        return False

    def check_integrity(self):
        """Drop from ``measures`` every system lacking measures or required metadata.

        Required: at least one non-missing measure; numeric ``loc_latitude``,
        ``loc_longitude``, ``loc_altitude`` and ``pv_kwp``; at least one array, each with
        numeric ``pv_tilt``, ``pv_azimut``, ``pv_wp`` and ``pv_number``. Numeric strings
        are converted in place in ``self.metadata``. Problems are printed per system.

        Raises:
            ValueError: if metadata or measures have not been loaded yet.
        """
        # Check if the metadata is loaded
        if self.metadata is None:
            raise ValueError("Metadata not loaded. Please load the metadata first.")

        # Check if the measures are loaded
        if self.measures is None:
            raise ValueError("Measures not loaded. Please load the measures first.")

        invalidSystems = []

        for systemName in tqdm(self.get_systems_names(), desc="Checking data integrity"):
            invalidSystem = False

            # Check if the system has measures
            if systemName not in self.measures or self.measures[systemName].count() == 0:
                invalidSystem = True
                print(f"System {systemName} : No measures found")
            # Check if the system has metadata
            if systemName not in self.metadata:
                invalidSystem = True
                print(f"System {systemName} : No metadata found")
            else:
                # Check metadata for the entire system
                systemMetadata = self.metadata[systemName]
                for key in ['loc_latitude', 'loc_longitude', 'loc_altitude', 'pv_kwp']:
                    if key not in systemMetadata['metadata']:
                        invalidSystem = True
                        print(f"System {systemName} : No '{key}' found")
                    elif not self._coerce_number(systemMetadata['metadata'], key):
                        invalidSystem = True
                        print(f"System {systemName} : The key-value '{key}:{systemMetadata['metadata'][key]}' is not a number")

                # Check metadata for the arrays
                if 'arrays' not in systemMetadata or len(systemMetadata['arrays']) == 0:
                    print(f"System {systemName} : No PV arrays found")
                    invalidSystem = True
                else:
                    for array_num, arrayData in systemMetadata['arrays'].items():
                        for key in ['pv_tilt', 'pv_azimut', 'pv_wp', 'pv_number']:
                            if key not in arrayData:
                                invalidSystem = True
                                print(f"System {systemName} : No '{key}' found for array {array_num}")
                            elif not self._coerce_number(arrayData, key):
                                invalidSystem = True
                                print(f"System {systemName} : The key-value '{key}:{arrayData[key]}' is not a number for array {array_num}")
            if invalidSystem:
                invalidSystems.append(systemName)

        if invalidSystems:
            # remove the invalid systems from the measures
            nbrSystems = len(self.get_systems_names())
            print(f"Number of systems with all the necessary data: {nbrSystems - len(invalidSystems)}/{nbrSystems}")
            self.measures.drop(columns=invalidSystems, inplace=True)

    def get_missing_value(self, sorted=True):
        """Return a boolean DataFrame, True where a measure is missing.

        With ``sorted=True`` the columns are ordered by ascending number of missing
        values. (The parameter name shadows the builtin but is kept for callers.)
        """
        if sorted:
            # Sort columns by number of missing values
            sorted_columns = self.measures.isnull().sum().sort_values().index
            sorted_measures = self.measures[sorted_columns]

            # Create a boolean DataFrame where True indicates missing values
            missing_values = sorted_measures.isnull()
        else:
            missing_values = self.measures.isnull()
        return missing_values

    def create_train_test_set(self, test_size=None, train_size=None, random_state=None, shuffle=True):
        """Split the rows (dates) of ``measures`` with sklearn's ``train_test_split``.

        Returns:
            ``(measuresTrain, measuresTest)``.

        Raises:
            NotImplementedError: if ``train_size`` is given (selection strategy unfinished).
        """
        if train_size is not None:
            # TODO: planned strategy — sort the observations by number of missing values
            # in ascending order, then look at the number of missing values of the
            # train_size-th element. Not implemented yet (the previous body was an empty
            # `if`, which made this whole cell a SyntaxError).
            raise NotImplementedError("train_size-based selection is not implemented yet")

        measuresTrain, measuresTest = train_test_split(self.measures, test_size=test_size, random_state=random_state, shuffle=shuffle)
        return measuresTrain, measuresTest

In [None]:
dataHandler = DataHandler(dataDirpath, dataCacheDirpath)
# Pipeline order matters: metadata first, then the CSV measures, then the
# integrity check that needs both.
for pipelineStep in (dataHandler.load_metadata, dataHandler.load_csv, dataHandler.check_integrity):
    pipelineStep()