In [1]:
import warnings
warnings.filterwarnings("ignore")
import os
import calendar
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from SSA import ssa_trend

In [2]:
# CSS injected into the notebook by the GUI cell (wrapped in widgets.HTML and placed
# as the first child of the dashboard). The classes are attached to widgets with
# add_class(): 'label_style' and 'data_input_style' on the control panel,
# 'box_style' on the control-panel HBox.
# NOTE(review): 'block_header_style' is not referenced by any cell shown here —
# confirm it is still needed.
# The string is emitted verbatim at runtime; do not reformat it.
styles = '''
<style>
.label_style {   
    font-weight: normal;
    color: black;
    font-size: 12px;
}
.block_header_style {   
    font-weight: bold !important;
    color: #3e8e41 !important;
    font-size: 14px !important;
    min-width: max-content !important;
}
.data_input_style input { background-color:white !important; }
.box_style{
    width:auto;
    border: 1px solid black !important;
    height: auto;
    background-color:#E5E8E8 !important;
    margin: 0px 0px 10px 0px !important;
    padding: 2px 10px 2px 10px !important;
}
</style>'''

In [3]:
# All CSV inputs (station data here, CMIP6/ERA5 tables in the next cell) are read
# by bare filename from the ./data folder next to the notebook.
# Remember the directory the notebook started in on the first run so that
# re-running this cell changes into the same folder instead of trying to
# enter data/data (which the original os.chdir(os.getcwd() + 'data') did).
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, 'data'))
obs_data = pd.read_csv("station_clim_data.csv")

In [105]:
# Keys of the data dictionaries: month abbreviations ('Jan' .. 'Dec') followed by
# multi-month aggregation periods. The GUI builds its Season dropdown from this
# key order, so months must come first.
month_abbrs = [calendar.month_abbr[num] for num in range(1, 13)]
season_names = ['winter', 'spring', 'summer', 'fall', 'warm', 'cold', 'annual']

# Filename templates per GUI variable key; '{period}' is replaced by a month
# abbreviation or a season name. Files are read relative to the current directory.
CMIP_FILE_TEMPLATES = {
    'temp': 'tas_{period}.csv',
    'prec': 'pr_{period}.csv',
    'prec_era': 'pr_{period}_original.csv',
    'evapo': 'evspsbl_{period}.csv',
}
ERA_FILE_TEMPLATES = {
    'temp': 'era_t2m_{period}.csv',
    'prec': 'era_tp_{period}.csv',
    'prec_era': 'era_tp_{period}_original.csv',
    'evapo': 'era_e_{period}.csv',
}


def load_period_tables(templates, periods):
    """Read one CSV per (period, variable) pair.

    Parameters
    ----------
    templates : dict[str, str]
        Maps a variable key to a filename template containing ``{period}``.
    periods : iterable of str
        Period keys, in the order they should appear in the result.

    Returns
    -------
    dict[str, dict[str, pandas.DataFrame]]
        ``result[period][variable]`` is the table read from the formatted filename.
    """
    return {
        period: {var: pd.read_csv(template.format(period=period))
                 for var, template in templates.items()}
        for period in periods
    }


cmip_data = load_period_tables(CMIP_FILE_TEMPLATES, month_abbrs + season_names)
era_data = load_period_tables(ERA_FILE_TEMPLATES, month_abbrs + season_names)

In [141]:
class GUI(widgets.VBox):
    """Interactive dashboard comparing observed/ERA5 data with CMIP6 model output.

    Layout (top to bottom): the CSS from the module-level ``styles`` string, a row
    of dropdown selectors, three Plotly number indicators (climatic norm, raw CMIP6
    and bias-corrected CMIP6 means for the future period) and a matplotlib
    time-series chart. Data come from the module-level ``cmip_data`` and
    ``era_data`` dicts, indexed as ``[season][variable]``.
    """

    # Resolution of the matplotlib chart. 650 dpi on a 9x4 inch figure yields a very
    # large PNG on every redraw; lower this if interaction feels slow.
    FIG_DPI = 650
    # Width applied to every dropdown in the control panel.
    DROPDOWN_WIDTH = "130px"

    def __init__(self):
        super().__init__()

        control_panel = self.create_control_panel()
        control_panel.add_class('label_style')
        control_panel.add_class('data_input_style')

        # The indicator panel is initialised from these values, so compute them first.
        self.compute()
        self.variable_change()

        indicators = self.create_indicator_panel()
        chart = self.create_chart()

        self.children = [widgets.HTML(styles), control_panel, indicators, chart]

        self.plot_graph()

        self.variable.observe(self.on_variable, "value")
        self.season.observe(self.on_season, "value")
        self.region.observe(self.on_region, "value")
        self.experiment.observe(self.on_experiment, "value")
        self.base_period.observe(self.on_base_period, "value")
        self.future_period.observe(self.on_future_period, "value")

    # ------------------------------------------------------------------ layout

    def _labeled_dropdown(self, label, options, value):
        """Return ``(dropdown, box)``: a fixed-width Dropdown stacked under a Label."""
        dropdown = widgets.Dropdown(
            options=options,
            value=value,
            disabled=False,
            layout=widgets.Layout(width=self.DROPDOWN_WIDTH),
        )
        return dropdown, widgets.VBox([widgets.Label(label), dropdown])

    def create_control_panel(self):
        """Create the selector dropdowns (stored on ``self``) and return them in one HBox."""
        self.variable, w_variable = self._labeled_dropdown(
            "Variable",
            [("Temperature", "temp"), ("Precipitation", "prec"),
             ("Precipitation (ERA5)", "prec_era"), ("Evaporation (ERA5)", "evapo")],
            "temp",
        )

        # Labels shown to the user: 'warm'/'cold' become 'Warm period'/'Cold period',
        # every other key is just title-cased ('Jan', 'Winter', 'Annual', ...).
        season_keys = list(cmip_data.keys())
        season_labels = [key.title() + ' period' if key in ['cold', 'warm'] else key.title()
                         for key in season_keys]
        self.season, w_season = self._labeled_dropdown(
            "Season", list(zip(season_labels, season_keys)), "annual")

        self.region, w_region = self._labeled_dropdown(
            "Region", ["Brest", "Vitebsk", "Gomel", "Grodno", "Minsk", "Mogilev"], "Brest")

        self.experiment, w_experiment = self._labeled_dropdown(
            "Experiment", ['SSP1-2.6', 'SSP2-4.5', 'SSP3-7.0', 'SSP5-8.5'], 'SSP1-2.6')

        # Option values are inclusive [first_year, last_year] pairs (used with Series.between).
        base_options = [(f"{n}-{n+29}", [n, n+29]) for n in [1961, 1981, 1991]]
        self.base_period, w_base = self._labeled_dropdown(
            "Base Period", base_options, [1961, 1990])

        future_options = [(f"{n}-{n+19}", [n, n+19]) for n in [2021, 2041, 2061, 2081]]
        self.future_period, w_future = self._labeled_dropdown(
            "Future Period", future_options, [2041, 2060])

        panel = widgets.HBox(
            [w_variable, w_season, w_region, w_experiment, w_base, w_future],
            layout=widgets.Layout(display='flex', justify_content='space-around', width='90%'),
        ).add_class('box_style')
        return panel

    def create_chart(self):
        """Create the Output widget that receives the matplotlib chart."""
        self.graph = widgets.Output(
            layout=widgets.Layout(
                display='flex',
                flex_flow='row',
                justify_content='center'))
        chart = widgets.VBox(
            [self.graph],
            layout=widgets.Layout(display='flex', justify_content='center', width='90%')
        )
        return chart

    def _value_suffix(self):
        """Unit suffix for the indicators: degrees Celsius for temperature, mm otherwise."""
        return u' \N{DEGREE SIGN}C' if self.variable.value == 'temp' else " mm"

    def create_indicator_panel(self):
        """Build the three-indicator Plotly FigureWidget (stored as ``self.indicators``).

        Trace order is relied on by ``update_indicators``: data[0] is the base-period
        norm, data[1] the raw CMIP6 future mean and data[2] the corrected CMIP6 mean.
        """
        suffix = self._value_suffix()

        fig = make_subplots(
            rows=1, cols=6,
            specs=[
                [None, {"type": "indicator"}, {"type": "indicator"}, {"type": "indicator"}, None, None],
            ],
        )
        fig.add_trace(
            go.Indicator(
                mode="number",
                value=self.base_value,
                number={'suffix': suffix, 'font': {'size': 20}},
                title="<span style='font-size:3em'>Climatic norm</span>",
            ),
            row=1, col=2
        )
        # Both CMIP6 indicators show their difference from the base-period norm.
        for col, title, value in [(3, "CMIP6 original", self.future_value),
                                  (4, "CMIP6 corrected", self.future_corrected_value)]:
            fig.add_trace(
                go.Indicator(
                    mode="number+delta",
                    value=value,
                    number={'suffix': suffix, 'font': {'size': 20}},
                    delta={'reference': self.base_value, 'font': {'size': 20}},
                    title=f"<span style='font-size:3em'>{title}</span>",
                ),
                row=1, col=col
            )

        fig.update_layout(
            height=110,
            margin=dict(l=0, r=0, t=10, b=0)
        )
        self.indicators = go.FigureWidget(fig)

        indicators = widgets.VBox(
            [self.indicators],
            layout=widgets.Layout(display='flex', justify_content='center', width='90%')
        )
        return indicators

    # ---------------------------------------------------------------- callbacks

    def _refresh(self, recompute_series=True):
        """Recompute state for the current selections and update the display.

        Period selections only change the averaged indicator values, so their
        handlers pass ``recompute_series=False`` to skip re-slicing the data and
        redrawing the chart.
        """
        if recompute_series:
            self.compute()
        self.variable_change()
        self.update_indicators()
        if recompute_series:
            self.plot_graph()

    # traitlets has already stored change.new on the widget before these observers
    # run, so the handlers do not need to assign it again.
    def on_variable(self, change):
        """Observer for the Variable dropdown."""
        self._refresh()

    def on_season(self, change):
        """Observer for the Season dropdown."""
        self._refresh()

    def on_region(self, change):
        """Observer for the Region dropdown."""
        self._refresh()

    def on_experiment(self, change):
        """Observer for the Experiment dropdown."""
        self._refresh()

    def on_base_period(self, change):
        """Observer for the Base Period dropdown (indicators only)."""
        self._refresh(recompute_series=False)

    def on_future_period(self, change):
        """Observer for the Future Period dropdown (indicators only)."""
        self._refresh(recompute_series=False)

    # ----------------------------------------------------------------- drawing

    def plot_graph(self):
        """Clear the chart Output widget and redraw the chart inside it."""
        self.graph.clear_output()
        with self.graph:
            self.plot_requested_data()

    def update_indicators(self):
        """Push the current base/future values and unit suffix into the indicators."""
        suffix = self._value_suffix()
        # batch_update sends all property changes to the front end in one message.
        with self.indicators.batch_update():
            base, original, corrected = self.indicators.data

            base.value = self.base_value
            base.number['suffix'] = suffix

            original.value = self.future_value
            original.delta['reference'] = self.base_value
            original.number['suffix'] = suffix

            corrected.value = self.future_corrected_value
            corrected.delta['reference'] = self.base_value
            corrected.number['suffix'] = suffix

    def plot_requested_data(self):
        """Draw the observed/ERA5 series, its trend and the CMIP6 quantile bands.

        Intended to run inside ``self.graph`` (see ``plot_graph``) so the figure
        is rendered into the output widget.
        """
        fig, ax = plt.subplots(1, 1, figsize=(9, 4))
        fig.set_dpi(self.FIG_DPI)
        real_data = 'ERA5' if self.variable.value in ['prec_era', 'evapo'] else 'Observation'
        experiment = self.experiment.value

        ax.plot(self.era_years, self.era_variable, color='#D95319', alpha=0.5, label=real_data)
        ax.plot(self.era_years, self.era_trend, color='#D95319', label=f'{real_data} non-linear trend')
        ax.plot(self.hist_years, self.hist_50, color='gray', label='Historical 50th quantile')
        ax.fill_between(self.hist_years, self.hist_90, self.hist_10,
                        alpha=0.3, color='gray',
                        label='10th and 90th quantile range')
        ax.plot(self.hist_corr_years, self.hist_corr, color='k', alpha=0.75,
                label='Historical with correction')
        ax.plot(self.proj_years, self.proj_50, color='#0072BD', label=f'{experiment} 50th quantile')
        ax.fill_between(self.proj_years, self.proj_90, self.proj_10,
                        alpha=0.2, color='#0072BD',
                        label=f'{experiment} 10th and 90th quantile range')
        ax.plot(self.proj_years, self.proj_corr, color='#77AC30',
                label=f'{experiment} with correction')

        # Human-readable period name: 'Jan' -> 'January', 'warm' -> 'Warm Period'.
        period = self.season.value
        if period in month_abbrs:
            period = calendar.month_name[month_abbrs.index(period) + 1]
        if period in ['warm', 'cold']:
            period += ' period'
        period = period.title()

        if self.variable.value == 'temp':
            ax.set_title(f'{real_data} and Modelled Average {period} Temperature in {self.region.value}')
            # Raw string: '\c' is not a valid Python escape sequence.
            ax.set_ylabel(r'Temperature, $^\circ$C')
        else:
            param = 'Precipitation' if self.variable.value in ['prec', 'prec_era'] else 'Evapotranspiration'
            ax.set_title(f'{real_data} and Modelled {period} {param} in {self.region.value}')
            ax.set_ylabel(f'{param}, mm')

        # Legend placed outside the axes, to the right.
        ax.legend(fontsize=8, loc="upper left", bbox_to_anchor=(1, 1))
        ax.grid(linestyle='--')
        plt.show()

    # --------------------------------------------------------------- computing

    def compute(self):
        """Slice the CMIP6 and ERA5/observation tables for the current selections.

        Sets the series attributes read by ``plot_requested_data`` and the
        ``era_df`` / ``cmip_df`` / ``cmip_corr_df`` frames (columns 'year' and
        'variable') that ``variable_change`` averages.
        """
        variable = self.variable.value
        season = self.season.value
        region = self.region.value
        experiment = self.experiment.value

        # ERA5 value column is '<prefix>_<season>'; CMIP6 columns are
        # '<cmipvar>_<stat>' with stat in {10, 50, 90, corr}.
        era_prefix = {'temp': 't2m', 'prec': 'tp', 'prec_era': 'tp', 'evapo': 'e'}[variable]
        cmipvar = {'temp': 'tas', 'prec': 'pr', 'prec_era': 'pr', 'evapo': 'evspsbl'}[variable]
        eravar = f'{era_prefix}_{season}'

        cmip = cmip_data[season][variable]
        in_region = cmip['region'] == region
        cmip_proj = cmip[in_region & (cmip['experiment'] == experiment)]
        cmip_hist = cmip[in_region & (cmip['experiment'] == 'Historical')]
        self.proj_years = cmip_proj['year']
        self.hist_years = cmip_hist['year']

        era = era_data[season][variable]
        era = era[era['region'] == region].dropna()
        self.era_years = era['year']
        self.era_variable = era[eravar]
        self.era_trend, _ = ssa_trend(era[eravar].values, L=20, n_components=1, n_forcast=0)

        self.hist_50 = cmip_hist[f"{cmipvar}_50"]
        self.hist_10 = cmip_hist[f"{cmipvar}_10"]
        self.hist_90 = cmip_hist[f"{cmipvar}_90"]
        # Corrected historical values are shown only from the first observed year on.
        # min() instead of iloc[0] so this does not depend on the CSV's row order.
        overlap = cmip_hist['year'] >= era['year'].min()
        self.hist_corr = cmip_hist.loc[overlap, f"{cmipvar}_corr"]
        self.hist_corr_years = cmip_hist.loc[overlap, 'year']

        self.proj_50 = cmip_proj[f"{cmipvar}_50"]
        self.proj_10 = cmip_proj[f"{cmipvar}_10"]
        self.proj_90 = cmip_proj[f"{cmipvar}_90"]
        self.proj_corr = cmip_proj[f"{cmipvar}_corr"]

        self.era_df = era[['year', eravar]].rename(columns={eravar: 'variable'})
        self.cmip_df = cmip_proj[['year', f"{cmipvar}_50"]].rename(columns={f"{cmipvar}_50": 'variable'})
        self.cmip_corr_df = cmip_proj[['year', f"{cmipvar}_corr"]].rename(columns={f"{cmipvar}_corr": 'variable'})

    @staticmethod
    def _period_mean(df, period):
        """Mean of ``df['variable']`` over rows whose 'year' lies in the inclusive ``[start, end]``."""
        start, end = period
        return df.loc[df['year'].between(start, end), 'variable'].mean()

    def variable_change(self):
        """Recompute the three indicator values from the frames built by ``compute``.

        ``base_value`` is the mean observed/ERA5 value over the base period;
        ``future_value`` / ``future_corrected_value`` are the means of the CMIP6
        50th-quantile and corrected series over the future period.
        """
        self.base_value = self._period_mean(self.era_df, self.base_period.value)
        self.future_value = self._period_mean(self.cmip_df, self.future_period.value)
        self.future_corrected_value = self._period_mean(self.cmip_corr_df, self.future_period.value)

In [142]:
# Build the dashboard; as the cell's last expression it is rendered below.
dashboard = GUI()
dashboard

GUI(children=(HTML(value='\n<style>\n.label_style {   \n    font-weight: normal;\n    color: black;\n    font-…