In [60]:
import dash
from dash import Dash, dcc, html, callback, callback_context, dash_table
import dash_bootstrap_components as dbc
from dash_bootstrap_templates import load_figure_template
import dash_daq as daq
from jupyter_dash import JupyterDash
from dash.dependencies import Input, Output, State

import plotly.express as px
import plotly.graph_objects as go

import geopandas as gpd
import pandas as pd

import statsmodels.api as sm
from sklearn.neighbors import KNeighborsRegressor

from Functions import read_data, fit_regression, get_metrics, get_predicted_GeoJSON

In [61]:
# Train/test split shared by every model fit in the dashboard callback.
X_train, X_test, y_train, y_test = read_data()

# NUTS level-2 region boundaries from Eurostat GISCO. Downloaded once, up front,
# so the map callback does not re-read the GeoJSON on every update.
NUTS2_GEOJSON_URL = 'https://gisco-services.ec.europa.eu/distribution/v2/nuts/geojson/NUTS_RG_01M_2016_4326_LEVL_2.geojson'
nuts2 = gpd.read_file(NUTS2_GEOJSON_URL)

In [62]:
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.DARKLY])

# Give plotly figures the same "darkly" look as the Bootstrap stylesheet above.
load_figure_template(["darkly"])
color = 'darkgreen'                   # accent colour: separator bars and table header
font_s = {'font-family': 'bahnschrift'}
PANEL_WIDTH = 440                     # px width of the widgets in the left-hand panel


def _labeled_numeric_input(label, input_id, min_value, max_value, value):
    """Return a flex row holding a bold ``label`` and a ``daq.NumericInput``.

    Parameters
    ----------
    label : str
        Text shown in bold before the input.
    input_id : str
        Component id of the NumericInput (referenced by the callback).
    min_value, max_value : int
        Allowed range of the input.
    value : int
        Initial value.
    """
    return html.Div([html.B(label, style={**font_s, 'width': PANEL_WIDTH}),
                     daq.NumericInput(min=min_value,
                                      max=max_value,
                                      value=value,
                                      style=dict(font_s),
                                      id=input_id)],
                    style={'display': 'flex', 'width': PANEL_WIDTH})


app.layout = html.Div([
    html.H1('Machine Learning Dashboard', style=font_s),
    html.Div([''], style={'height': 15, 'background-color': color}),
    html.Div([
        html.Div([' '], style={'width': 15}),
        # Left column: model choice, hyper-parameters and the metrics table.
        html.Div([
            html.Div([' '], style={'height': 20}),
            html.H2('Regression algorithms', style=font_s),
            html.H3('Choose a model:', style=font_s),
            dcc.Dropdown(['Linear Regression', 'KNN Regression'],
                         'Linear Regression',
                         id='Model',
                         style={**font_s, 'width': PANEL_WIDTH}),
            html.H4('Additional parameters:', style=font_s),
            # K is only meaningful for KNN; the callback toggles its `disabled` flag.
            _labeled_numeric_input('   K=', 'K', 1, 30, 3),
            _labeled_numeric_input('   YEAR=', 'YEAR', 1999, 2018, 2012),
            html.H5('Metrics:', style=font_s),
            dash_table.DataTable(id='metrics_table',
                                 style_header={'backgroundColor': color, 'fontWeight': 'bold'},
                                 style_cell={'textAlign': 'center', 'backgroundColor': 'gray'},
                                 style_table={'width': PANEL_WIDTH},
                                 cell_selectable=False,
                                 style_as_list_view=True),
        ]),
        html.Div([' '], style={'width': 15}),
        html.Div([' '], style={'width': 20, 'background-color': color}),
        html.Div([' '], style={'width': 15}),
        # Right column: prediction map, followed by the error section.
        html.Div([
            html.Div([''], style={'height': 15}),
            html.H3('Crop yield predictions', style=font_s),
            dcc.Graph(id="Yield_Pred", style={'width': 1380, 'height': 480}),
            html.Div([''], style={'height': 15}),
            html.H3('Error estimations', style=font_s),
            # TODO: error graph not wired up yet; intended placeholder was
            # dcc.Graph(id="error_Pred", style={'width': 1400, 'height': 450}).
        ]),
    ], style={'display': 'flex', 'width': 1880, 'height': 880, 'overflow': 'auto'}),
], style={'overflow': 'auto', 'height': 1000})


@app.callback(
    [Output('metrics_table', 'data'), Output('K', 'disabled'), Output('Yield_Pred', 'figure')],
    [Input('Model', 'value'), Input('K', 'value'), Input('YEAR', 'value')]
)
def update_metrics_n_map(model, K, year):
    """Refit the selected regression model and refresh the metrics table and map.

    Parameters
    ----------
    model : str or None
        Dropdown value: ``'Linear Regression'`` or ``'KNN Regression'``.
        ``None`` when the user clears the dropdown.
    K : int
        Number of neighbours; only passed on for ``'KNN Regression'``.
    year : int
        Forwarded to ``get_predicted_GeoJSON`` to select which predictions to map.

    Returns
    -------
    tuple
        ``(metrics, k_disabled, fig)``: the metrics from ``get_metrics`` for the
        ``metrics_table``, whether the K input should be disabled, and the
        choropleth figure for ``Yield_Pred``.

    Raises
    ------
    dash.exceptions.PreventUpdate
        When no known model is selected (e.g. the dropdown was cleared). The
        previous outputs are kept instead of failing with an unbound local.
    """
    from dash.exceptions import PreventUpdate

    if model == 'Linear Regression':
        k_disabled = True
        _fitted_model, y_pred = fit_regression('LR', X_train, X_test, y_train, y_test)
        df = get_predicted_GeoJSON(nuts2, 'LR', year)
    elif model == 'KNN Regression':
        k_disabled = False
        _fitted_model, y_pred = fit_regression('KNN', X_train, X_test, y_train, y_test, K)
        df = get_predicted_GeoJSON(nuts2, 'KNN', year, K)
    else:
        # dcc.Dropdown is clearable by default, so `model` can be None here.
        raise PreventUpdate

    fig = px.choropleth_mapbox(df,
                               geojson=df.geometry,
                               # Rows are matched to their polygons through the frame index.
                               locations=df.index,
                               color='est_crop_yield',
                               color_continuous_scale="greens",
                               # Fixed colour range keeps maps comparable across models/years.
                               range_color=(35, 55),
                               mapbox_style="carto-positron",
                               zoom=6, center={"lat": 51.97, "lon": 5.67}
                               )
    fig.update_layout(margin={"r": 0, "t": 0, "l": 0, "b": 0})

    return get_metrics(y_test, y_pred), k_disabled, fig
    

if __name__ == '__main__':
    # Start the dashboard server from the notebook on a fixed local port.
    app.run_server(port=8070, debug=True)

Dash app running on http://127.0.0.1:8070/
