In [None]:
import sys
import os
from  pathlib import Path
import yaml
import lightning.pytorch as pl

import torch
import pickle
torch.set_float32_matmul_precision('high')

import numpy as np
import pandas as pd
sys.path.insert(0, '../src')

from data_tools import WaVoDataModule
from models import WaVoLightningModule
import utility as ut

import plotly
import plotly.graph_objs as go
import ipywidgets as widgets


In [None]:
#Plotting functions

# Plot palette: rgba templates whose '_' placeholder is replaced by the opacity in gc().
colors = ['rgba(0,0,0,_)',
          'rgba(255,0,0,_)','rgba(0,255,0,_)','rgba(0,0,255,_)',
          'rgba(255,0,255,_)','rgba(255,255,0,_)','rgba(0,255,255,_)',
          'rgba(128,0,0,_)','rgba(0,128,0,_)','rgba(0,0,128,_)',
          'rgba(128,128,0,_)','rgba(128,0,128,_)','rgba(0,128,128,_)',
          'rgba(128,128,128,_)','rgba(192,192,192,_)']

def gc(i, op=1):
    """Return palette entry ``i`` (indices wrap around ``colors``) with opacity ``op``.

    Example: ``gc(1, 0.5) == 'rgba(255,0,0,0.5)'``.
    """
    template = colors[i % len(colors)]
    return template.replace('_', f'{op}')


def plot_normal(df, start, days=60, pred_hours=(1, 12, 24, 48)):
    """Plot the measured series and selected forecast horizons over a time window.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame as built by ``get_pred``: column ``0`` is the measured value, column ``n``
        the model output for horizon ``n``.
    start : str or datetime-like
        Window start; parsed with ``pd.to_datetime`` and turned into a row position
        with ``ut.get_start_index``.
    days : int
        Window length in days; the end row comes from ``ut.get_end_index``.
    pred_hours : iterable of int
        Horizons (column labels of ``df``) to draw, one dotted trace each; the legend
        labels them "<n> Stunden".

    Returns
    -------
    plotly.graph_objs.Figure
    """
    end_string = str(pd.to_datetime(start) + pd.Timedelta(days=days))
    start_idx = ut.get_start_index(start, df, 0)
    end_idx = ut.get_end_index(end_string, df)
    window = df.iloc[start_idx:end_idx]

    fig = go.Figure()
    fig.add_trace(go.Scatter(name='Gemessen', x=window.index, y=window[0], mode='lines',
                             marker=dict(color=gc(0)), showlegend=True))

    for i, n in enumerate(pred_hours):
        # Column n of row t is drawn n rows later (at t + n), presumably the time the
        # n-hour forecast refers to when rows are hourly.
        x = df.iloc[start_idx + n:end_idx + n].index
        # Near the end of df the shifted window is truncated; trim y so x/y stay paired.
        y = window[n].iloc[:len(x)]
        fig.add_trace(go.Scatter(name=f"{n} Stunden", x=x, y=y, mode='lines',
                                 marker=dict(color=gc(i + 1)), showlegend=True,
                                 line=dict(dash='dot')))
    return fig

In [None]:
# helper functions predictions/loading
def load_stuff(model_dir):
    """Load a trained WaVo model and the forecast data of its run.

    Reads ``hparams.yaml`` in ``model_dir`` and rebuilds a ``WaVoDataModule`` from it
    (un-pickling the stored scaler). It prepares that module for prediction and restores
    ``WaVoLightningModule`` from a checkpoint file in ``model_dir / 'checkpoints'``.

    Parameters
    ----------
    model_dir : pathlib.Path or str
        Lightning log directory of a single run, e.g. ``.../lightning_logs/version_7``.

    Returns
    -------
    tuple
        ``(model, data_loader, y_true)``. ``data_loader`` and ``y_true`` are the second and
        third values returned by ``WaVoDataModule.get_data_forecast()``.

    Raises
    ------
    FileNotFoundError
        If ``hparams.yaml``, the checkpoint directory, or any checkpoint file is missing.
    """
    model_dir = Path(model_dir)  # also accept plain string paths

    with open(model_dir / 'hparams.yaml', 'r') as file:
        yaml_data = yaml.load(file, Loader=yaml.FullLoader)
    # SECURITY: pickle.loads can execute arbitrary code embedded in the file; only load
    # runs from a trusted source.
    yaml_data['scaler'] = pickle.loads(yaml_data['scaler'])

    data_module = WaVoDataModule(**yaml_data)
    data_module.setup(stage='predict')
    _, data_loader, y_true = data_module.get_data_forecast()

    checkpoint_dir = model_dir / 'checkpoints'
    # iterdir() order is filesystem-dependent; sort so the choice is reproducible.
    checkpoints = sorted(p for p in checkpoint_dir.iterdir() if p.is_file())
    if not checkpoints:
        raise FileNotFoundError(f"no checkpoint file in {checkpoint_dir}")
    # NOTE(review): with several checkpoints this takes the first by file name, which is
    # not necessarily the best or latest one. Confirm which checkpoint is intended.
    model = WaVoLightningModule.load_from_checkpoint(checkpoints[0])

    return model, data_loader, y_true

def get_pred(model, data_loader, y_true, accelerator='gpu', devices=1):
    """Run ``model`` on ``data_loader`` and put the measured values next to the outputs.

    Parameters
    ----------
    model : lightning.pytorch.LightningModule
        Trained model, e.g. as returned by ``load_stuff``.
    data_loader :
        Loader passed to ``Trainer.predict``.
    y_true : pandas.Series
        Measured values, one entry per prediction row. Its index becomes the index of
        the returned frame.
    accelerator, devices : optional
        Forwarded to ``pl.Trainer``. The defaults keep the original single-GPU setup;
        pass ``accelerator='cpu'`` on machines without a GPU.

    Returns
    -------
    pandas.DataFrame
        Column ``0`` holds ``y_true``; columns ``1 .. k`` hold the ``k`` model outputs of
        each row.

    Raises
    ------
    ValueError
        If the number of prediction rows differs from ``len(y_true)``.
    """
    trainer = pl.Trainer(accelerator=accelerator, devices=devices, logger=False)
    # Trainer.predict returns a list of per-batch outputs; stack them into one array.
    # NOTE(review): np.concatenate needs host-side arrays/tensors; assumes the predict
    # step returns CPU data.
    pred = np.concatenate(trainer.predict(model, data_loader))
    # One row per sample; a 1-D output becomes a single column.
    pred = pred.reshape(len(pred), -1)
    if len(pred) != len(y_true):
        raise ValueError(
            f"model produced {len(pred)} prediction rows but y_true has {len(y_true)} entries"
        )

    values = np.concatenate([np.expand_dims(y_true, axis=1), pred], axis=1)
    # Derive column labels from the output width instead of hard-coding 49 (= 1 + 48).
    return pd.DataFrame(values, index=y_true.index, columns=range(values.shape[1]))

In [None]:
# Treia: load training run version_7 and compute its forecasts.
model_dir = Path('../../models_torch') / 'treia' / 'lightning_logs' / 'version_7'
model, data_loader, y_true = load_stuff(model_dir)
y_pred = get_pred(model=model, data_loader=data_loader, y_true=y_true)


In [None]:
# Treia, 15 days from 2022-02-15: measured series vs. 12/24/48-hour forecasts.
fig = plot_normal(y_pred, start="2022-02-15", days=15, pred_hours=[12, 24, 48])
fig.update_layout(
    template="plotly_white",
    width=1000,
    height=600,
    font_size=15,
    legend_x=0.8,
    legend_y=1,
    legend_font_size=20,
    margin_t=20,
    margin_b=70,
    margin_l=60,
    margin_r=20,
    xaxis_title="Zeit",
    yaxis_title="Wasserstand [cm]",
)


## LFU July 23

### Pötrau multiple models

In [None]:
# Pötrau: forecasts of runs version_1 ... version_19, one DataFrame per run, in version order.
y_pred_list = []
for i in range(1, 20):
    model_dir = Path('../../models_torch/pötrau/lightning_logs') / f'version_{i}'
    model, data_loader, y_true = load_stuff(model_dir)
    y_pred_list += [get_pred(model, data_loader, y_true)]

In [None]:
### Pötrau plot all models

In [None]:
# Overlay one forecast horizon from every Pötrau run on the measured series.
start = "2022-02-15"
days = 15
pred_hour = 1

# The measured values (column 0) are taken from the first run's frame.
df = y_pred_list[0]

end_string = str(pd.to_datetime(start) + pd.Timedelta(days=days))
start = ut.get_start_index(start, df, 0)
end = ut.get_end_index(end_string, df)

fig = go.Figure()
fig.add_trace(go.Scatter(name='Gemessen', x=df.iloc[start:end].index, y=df.iloc[start:end][0], mode='lines', marker=dict(color=gc(0)), showlegend=True))

n = pred_hour
for i, run_pred in enumerate(y_pred_list):
    # Column n is drawn n rows later, as in plot_normal. Near the end of the data the
    # shifted x window is shorter, so trim y to keep x/y paired.
    x = run_pred.iloc[start + n:end + n].index
    y = run_pred.iloc[start:end][n].iloc[:len(x)]
    # gc(0) is the black used for 'Gemessen'. Cycle through the other palette entries
    # only, so runs beyond the palette size are not drawn in the measured series' color.
    color = gc(1 + i % (len(colors) - 1))
    fig.add_trace(go.Scatter(name=f"{n} Stunden", x=x, y=y, mode='lines', marker=dict(color=color), showlegend=False, line=dict(dash='dot')))

fig.update_layout(
    template="plotly_white",
    width=1000,
    height=600,
    font_size=15,
    legend=dict(x=0.8, y=1, font=dict(size=20)),
    margin=dict(t=20, b=70, l=60, r=20),
    xaxis_title="Zeit",
    yaxis_title="Wasserstand [cm]"
)

### Pötrau: ensemble mean ± standard deviation

In [None]:


# Collect horizon n from every Pötrau run into one frame for ensemble statistics.
n = 24

# Measured series, taken from the first run (presumably identical across runs; TODO confirm).
true = y_pred_list[0][0]
# One column per run. `keys` labels each column with the run's position in y_pred_list;
# without it every column would carry the same label `n`.
ensemble_forecast = pd.concat(
    [run_pred[n] for run_pred in y_pred_list],
    axis=1,
    keys=range(len(y_pred_list)),
)


In [None]:

# Ensemble mean and mean ± one standard deviation across the Pötrau runs for horizon n.
start = "2022-02-15"
days = 15

df = true

end_string = str(pd.to_datetime(start) + pd.Timedelta(days=days))
start = ut.get_start_index(start, df, 0)
end = ut.get_end_index(end_string, df)

fig = go.Figure()
fig.add_trace(go.Scatter(name='Gemessen', x=df.iloc[start:end].index, y=df.iloc[start:end], mode='lines', marker=dict(color=gc(0)), showlegend=True))

# Forecast columns are drawn n rows later, as in plot_normal.
start_temp = start + n
end_temp = end + n
x = df.iloc[start_temp:end_temp].index

# NOTE(review): assumes row positions in ensemble_forecast line up with `true`
# (holds if all runs share one index). Confirm.
ensemble_window = ensemble_forecast.iloc[start:end]
# Compute the statistics once; trim to len(x) so x/y stay paired near the end of the data.
ensemble_mean = ensemble_window.mean(axis=1).iloc[:len(x)]
ensemble_std = ensemble_window.std(axis=1).iloc[:len(x)]
# Min/max envelopes are available via ensemble_window.min(axis=1) / .max(axis=1) if needed.

fig.add_trace(go.Scatter(name=f"{n} Stunden Mittelwert", x=x, y=ensemble_mean, mode='lines', marker=dict(color=gc(1)), showlegend=True, line=dict(dash='dot')))
# Upper and lower bands share one legend entry (and toggle together) instead of two identical entries.
fig.add_trace(go.Scatter(name=f"{n} Stunden std", x=x, y=ensemble_mean + ensemble_std, mode='lines', marker=dict(color=gc(2)), legendgroup='std', showlegend=True, line=dict(dash='dot')))
fig.add_trace(go.Scatter(name=f"{n} Stunden std", x=x, y=ensemble_mean - ensemble_std, mode='lines', marker=dict(color=gc(2)), legendgroup='std', showlegend=False, line=dict(dash='dot')))

fig.update_layout(
    template="plotly_white",
    width=1000,
    height=600,
    font_size=15,
    legend=dict(x=0.8, y=1, font=dict(size=20)),
    margin=dict(t=20, b=70, l=60, r=20),
    xaxis_title="Zeit",
    yaxis_title="Wasserstand [cm]"
)

### Treia

In [None]:
# Treia: reload the same run as the first Treia cell (version_7) and recompute its forecasts.
model_dir = Path('../../models_torch') / 'treia' / 'lightning_logs' / 'version_7'
model, data_loader, y_true = load_stuff(model_dir)
y_pred = get_pred(model=model, data_loader=data_loader, y_true=y_true)


In [None]:
# Treia, 15 days from 2022-02-15, exported as PDF (the figure is not displayed here).
# NOTE(review): the file name says "feb_23" but the window is February 2022. Confirm the name.
fig = plot_normal(y_pred, start="2022-02-15", days=15, pred_hours=[12, 24, 48])
fig.update_layout(
    template="plotly_white",
    width=1000,
    height=600,
    font_size=15,
    legend_x=0.8,
    legend_y=1,
    legend_font_size=20,
    margin_t=20,
    margin_b=70,
    margin_l=60,
    margin_r=20,
    xaxis_title="Zeit",
    yaxis_title="Wasserstand [cm]",
)
fig.write_image("treia_feb_23.pdf")


In [None]:
# Treia, 90 days from 2022-05-01, exported as PDF (the figure is not displayed here).
fig = plot_normal(y_pred, start="2022-05-01", days=90, pred_hours=[12, 24, 48])
fig.update_layout(
    template="plotly_white",
    width=1000,
    height=600,
    font_size=15,
    legend_x=0.8,
    legend_y=1,
    legend_font_size=20,
    margin_t=20,
    margin_b=70,
    margin_l=60,
    margin_r=20,
    xaxis_title="Zeit",
    yaxis_title="Wasserstand [cm]",
)
fig.write_image("treia_summer_22.pdf")


In [None]:
# Interactive HTML of the Treia record from 2014-01-07 on.
# start + days must stay within pandas' Timestamp range (max 2262-04-11). The previous
# 100000 days (~2287) overflowed and raised. 10000 days (~27 years, to 2041) is the value
# the Pötrau HTML export uses and presumably extends past the end of the data (TODO confirm).
fig = plot_normal(y_pred, start="2014-01-07", days=10000, pred_hours=[12, 24, 48])
fig.update_layout(xaxis_title="Zeit", yaxis_title="Wasserstand [cm]")
fig.write_html("treia.html")

### Pötrau

In [None]:
# Pötrau: load training run version_8 and compute its forecasts.
# Alternative run previously selected here: version_13.
model_dir = Path('../../models_torch') / 'pötrau' / 'lightning_logs' / 'version_8'
model, data_loader, y_true = load_stuff(model_dir)
y_pred = get_pred(model=model, data_loader=data_loader, y_true=y_true)


In [None]:
# Pötrau, 15 days from 2022-02-15 (PDF export currently disabled).
fig = plot_normal(y_pred, start="2022-02-15", days=15, pred_hours=[12, 24, 48])
fig.update_layout(
    template="plotly_white",
    width=1000,
    height=600,
    font_size=15,
    legend_x=0.8,
    legend_y=1,
    legend_font_size=20,
    margin_t=20,
    margin_b=70,
    margin_l=60,
    margin_r=20,
    xaxis_title="Zeit",
    yaxis_title="Wasserstand [cm]",
)
# fig.write_image("pötrau_feb_23.pdf")


In [None]:
# Pötrau, 90 days from 2020-09-01: exported as PDF and displayed.
# NOTE(review): the file name says "fall_21" but the window starts 2020-09-01. Confirm the name.
fig = plot_normal(y_pred, start="2020-09-01", days=90, pred_hours=[12, 24, 48])
fig.update_layout(
    template="plotly_white",
    width=1000,
    height=600,
    font_size=15,
    legend_x=0.8,
    legend_y=1,
    legend_font_size=20,
    margin_t=20,
    margin_b=70,
    margin_l=60,
    margin_r=20,
    xaxis_title="Zeit",
    yaxis_title="Wasserstand [cm]",
)
fig.write_image("pötrau_fall_21.pdf")
fig

In [None]:
# Interactive HTML of the Pötrau record from 2014-01-07 on (10000 days, about 27 years).
fig = plot_normal(y_pred, start="2014-01-07", days=10_000, pred_hours=[12, 24, 48])
fig.update_layout(xaxis_title_text="Zeit", yaxis_title_text="Wasserstand [cm]")
fig.write_html("pötrau.html")

## alt: manual loading (inline version of `load_stuff`)

In [None]:
# Inline version of load_stuff's steps. Unlike load_stuff it leaves yaml_data, data_module
# and df in the notebook namespace. It uses whichever `model_dir` the last executed
# loading cell set.
# SECURITY: pickle.loads can execute code embedded in hparams.yaml; only load trusted runs.
with (model_dir / 'hparams.yaml').open('r') as file:
    yaml_data = yaml.load(file, Loader=yaml.FullLoader)
    yaml_data['scaler'] = pickle.loads(yaml_data['scaler'])

data_module = WaVoDataModule(**yaml_data)
data_module.setup(stage='predict')
df, data_loader, y_true = data_module.get_data_forecast()
model = WaVoLightningModule.load_from_checkpoint(
    next((model_dir / 'checkpoints').iterdir())
)

