# India power consumption

### Data split

We use the `_long_data.csv` file from the dataset (available [here](https://kaggle.com/twinkle0705/state-wise-power-consumption-in-india); its rows are already sorted by timestamp) and split the rows into a training set and a held-out test set.

In [None]:
import os
import pandas as pd

In [None]:
# Read the CSV and cast every column to str.
path = './datasets/india_power_consumption'
data = pd.read_csv(os.path.join(path, 'data.csv')).astype(str)

# Rows are split in file order: the leading fraction is used for training,
# the remaining rows are held out for testing.
train_fraction = 0.8
train_data = data.iloc[:int(len(data) * train_fraction)]
test_data = data.iloc[int(len(data) * train_fraction):]

### Training

In [None]:
import mindsdb_native

In [None]:
mdb = mindsdb_native.Predictor(name='pted_usecase2')

# Column roles and window size, reused by the forecasting cells below.
params = {
    'order': 'Dates',
    'target': 'Usage',
    'group': 'States',
    'window': 5,
}

# Train on the training rows only, passing the time-series settings
# (ordering column, window, grouping column) taken from `params`.
mdb.learn(
    from_data=train_data,
    to_predict=params['target'],
    timeseries_settings={
        'order_by': [params['order']],
        'window': params['window'],
        'group_by': [params['group']],
    },
)

## Predict + visualize

In [None]:
# To see plots in separate web browser tabs, run: import plotly.io as pio; pio.renderers.default = "browser"

def plotter(time, real, predicted, confa, confb, labels, anomalies=None):
    """
    Build a Plotly figure with the real series, the predicted series and,
    optionally, a confidence band and shaded anomaly regions.

    time: x-axis values shared by every trace
    real: y-values of the 'Real' line
    predicted: y-values of the 'Predicted' line
    confa, confb: y-values of the two confidence bounds; the band is drawn
        only when both are not None
    labels: dict with the keys 'title', 'xtitle', 'ytitle' and 'legend_title'
    anomalies: optional sequence of flags paired positionally with `time`;
        every truthy flag shades the span between the neighbouring time values

    Returns the figure; it is not shown here.
    """
    # Only graph_objects is needed; imported locally so the module loads
    # without plotly installed until a plot is actually requested.
    import plotly.graph_objects as go

    fig = go.Figure()

    if confa is not None and confb is not None:
        # Two invisible (width=0) lines; the second one fills the area down
        # to the first ('tonexty'), which renders the confidence band.
        for bound, fill in ((confa, None), (confb, 'tonexty')):
            fig.add_trace(go.Scatter(x=time, y=bound,
                                     name='Confidence',
                                     fill=fill,
                                     mode='lines',
                                     line=dict(color='#919EA5', width=0)))

    fig.add_trace(go.Scatter(x=time, y=real,
                             name='Real',
                             line=dict(color='rgba(0,176,109,1)', width=3)))

    fig.add_trace(go.Scatter(x=time, y=predicted,
                             name='Predicted',
                             showlegend=True,
                             line=dict(color='rgba(103,81,173,1)', width=3)))

    # `is not None` instead of truthiness so array-like inputs (NumPy arrays,
    # pandas Series) do not raise on an ambiguous boolean test.
    if anomalies is not None:
        last = len(time) - 1
        for t_idx, (t, anomaly) in enumerate(zip(time, anomalies)):
            if not anomaly:
                continue
            # Shade from the previous to the next time value, clamped at the ends.
            t1 = time[t_idx - 1] if t_idx > 0 else t
            t3 = time[t_idx + 1] if t_idx < last else t
            fig.add_vrect(x0=t1, x1=t3, line_width=0, opacity=0.25, fillcolor="orange")

    # Styling shared by both axes.
    tick_font = dict(family='Source Sans Pro', size=14, color='rgb(44, 38, 63)')
    axis_style = dict(
        showline=True,
        showgrid=True,
        showticklabels=True,
        gridwidth=1,
        gridcolor='rgb(232,232,232)',
        linecolor='rgb(181, 181, 181)',
        linewidth=2,
    )

    fig.update_layout(
        xaxis=dict(axis_style, ticks='outside', tickfont=dict(tick_font)),
        yaxis=dict(axis_style, zeroline=True, tickfont=dict(tick_font)),
        autosize=True,
        showlegend=True,
        plot_bgcolor='white',
        hovermode='x',

        font_family="Courier New",
        font_color='rgba(0,176,109,1)',
        title_font_family="Times New Roman",
        title_font_color='rgba(0,176,109,1)',
        legend_title_font_color='rgba(0,176,109,1)',

        title=labels['title'],
        xaxis_title=labels['xtitle'],
        yaxis_title=labels['ytitle'],
        legend_title=labels['legend_title'],
    )

    return fig

In [None]:
def plot_call(r, target, order, titles, show_anomaly=False, idx=None, n=1, window=0):
    """
    Calls the plotter using predictor results and shows the figure.

    r: prediction result; its `_data` dict is read (and, for nested
        predictions, flattened in place)
    target: name of the predicted column
    order: name of the column used to sort the results and as the x-axis
    titles: label dict forwarded to the plotter
    show_anomaly: when True, the `<target>_anomaly` column is forwarded too
    idx: for t+n predictors, specifies at which point to forecast each test series
    n: number of predictions
    window: number of leading (sorted) rows left out of the plot
    """
    target_key = f'{target}'
    observed_key = f'__observed_{target}'

    if isinstance(r._data[target_key], list):
        forecasting_window = len(r._data[target_key])
    else:
        forecasting_window = len(r._data[target_key][0])

    # Use the `target` argument here (not a global `params`) so the function
    # works for whatever target the caller passes in.
    for key in (target_key, observed_key):
        if isinstance(r._data[key][0], list) and (forecasting_window == 1 or idx is None):
            # Keep only the first element of every nested entry.
            r._data[key] = [p[0] for p in r._data[key]]

    results = pd.DataFrame.from_dict(r._data).sort_values(order)
    results = results[window:]
    # Take the time axis AFTER dropping the first `window` rows so that it has
    # the same length as, and lines up with, the series plotted against it.
    time_target = results[order].values

    if idx is None:
        real_target = [float(value) for value in results[observed_key]]
        pred_target = list(r._data[target_key])[window:]
        confidence = results[f'{target}_confidence_range']
        conf_lower = [bounds[0] for bounds in confidence]
        conf_upper = [bounds[1] for bounds in confidence]
    else:
        # Pad with None up to `idx`, then append the predictions stored there.
        pred_target = [None] * idx + list(r._data[target_key][idx])
        real_target = [float(value) for value in results[observed_key]][:idx + n]
        conf_lower = None
        conf_upper = None

    anomalies = list(results[f'{target}_anomaly']) if show_anomaly else None
    fig = plotter(time_target, real_target, pred_target, conf_lower, conf_upper,
                  labels=titles, anomalies=anomalies)
    fig.show()

In [None]:
def forecast(model, df, params, limit=-1, state=None, show_anomaly=False):
    """
    Predicts every selected group separately and plots each result.

    model: predictor exposing `predict(when_data=..., advanced_args=...)`
    df: data to predict on; it is cast to str before being used
    params: dict with the keys 'group', 'target', 'order' and 'window'
    limit: rows kept are `df.iloc[:limit]`, so the default of -1 leaves out
        the last row; pass None to keep every row
    state: a single group to forecast, or None to forecast all groups
    show_anomaly: enables anomaly detection and anomaly shading in the plots
    """
    group_col = params['group']
    data = df.iloc[:limit].astype(str)
    advanced_args = {'anomaly_error_rate': 0.01, 'anomaly_cooldown': 1, 'anomaly_detection': show_anomaly}

    # Take the group labels from the string-cast rows that will actually be
    # predicted: the labels then compare equal to the column values whatever
    # the original dtype was, and no selected group can be empty.
    groups = data[group_col].unique()
    if state is not None:
        groups = [group for group in groups if group == str(state)]

    # Run all predictions first, then plot them in the same order.
    predictions = {}
    for group in groups:
        group_rows = data[data[group_col] == group]
        predictions[group] = model.predict(when_data=group_rows, advanced_args=advanced_args)

    for group, result in predictions.items():
        titles = { 'title': f'MindsDB t+1 forecast for State {group}',
                  'xtitle': 'Date (Unix timestamp)',
                  'ytitle': params['target'],
                  'legend_title': 'Legend'
                 }
        plot_call(result, params['target'], params['order'], titles,
                  show_anomaly=show_anomaly, window=params['window'])

In [None]:
# Set `state` to a single state name (e.g. 'Pondy') to forecast only that
# state; None predicts for all states.
state = None
forecast(mdb, test_data, params, state=state, show_anomaly=True)