In [None]:
from os.path import join
from sklearn.externals import joblib
from keras.models import model_from_json
import json
from IPython.display import display
from ipywidgets import interact
import ipywidgets as widgets
from IPython.display import display, clear_output, Markdown

# Output area shared by all the callbacks of this cell.
out_model_load = widgets.Output()

# Index of archived models read from <modelDirectory>/summary.json.
# load_show_models reads it as dict_models[target][model_name] -> model type.
dict_models = {}
with open(join(modelDirectory, 'summary.json'), 'r') as summary_file:
    dict_models = json.load(summary_file)

# Labels currently chosen in the model selector (rewritten by load_selectModels).
selectedModels = ()
def load_selectModels(selected_model):
    """Remember the labels picked in the model selector and echo them.

    Called by ``widgets.interactive`` whenever the ``load_models`` selection
    changes; the result is stored in the module-level ``selectedModels`` list
    that loadModel reads.

    Args:
        selected_model: iterable of selected labels (``'ARCHIVE_<name>'`` or
            ``'SESSION_<name>'`` as built by load_show_models).
    """
    global selectedModels
    with out_model_load:
        selectedModels = [label for label in selected_model]
        print(selectedModels)
    
def load_show_devices(Source):
    """Fill the device dropdown with the devices recorded for test ``Source``.

    The chosen test is also stored on the dropdown as ``load_device.source``,
    where load_calculateModel (and the date prefill) read it back.

    Args:
        Source: key of ``records.readings`` selected in the test dropdown.
    """
    # Record the test first: assigning ``options`` can change
    # ``load_device.value``, which immediately fires value observers (such as
    # the interactive wrapping load_show_dates) that may read ``.source``.
    load_device.source = Source
    load_device.options = list(records.readings[Source]['devices'].keys())
    
def load_show_dates(Source):
    """Prefill the start/end date boxes with the selected device's data range.

    This is bound with ``widgets.interactive(load_show_dates, Source=load_device)``,
    so ``Source`` receives the selected *device* name. The test that device
    belongs to is read from ``load_device.source`` (set by load_show_devices).

    Args:
        Source: device name currently selected in ``load_device``; ``None``
            while the dropdown has no options.
    """
    test_name = getattr(load_device, 'source', None)
    device_name = Source
    if test_name is None or device_name is None:
        # Nothing selected yet (e.g. the first call, before options exist).
        return

    devices = records.readings[test_name]['devices']
    if device_name not in devices:
        # Device list and recorded test can be briefly out of sync while the
        # test dropdown is switching; the next value change will refill.
        return

    index = devices[device_name]['data'].index
    # NOTE(review): ``_short_repr`` is a private pandas Timestamp attribute.
    load_min_date.value = index.min()._short_repr
    load_max_date.value = index.max()._short_repr

def _load_archived_model(model_name):
    """Read an archived model and its metadata from disk.

    Files are read from ``<modelDirectory>/<target>/<model_name>_*`` where the
    target and the model type come from ``load_target_drop`` / ``load_type_drop``.

    Returns:
        ``(model, params, metrics, features)``, or ``None`` when the selected
        model type is not one this loader knows how to read.
    """
    filename = join(modelDirectory, load_target_drop.value, model_name)
    model_type = load_type_drop.value
    # NOTE: joblib.load unpickles its input; only load from a trusted directory.
    if model_type == "LSTM":
        # Keras model: architecture stored as JSON, weights as HDF5.
        with open(filename + "_model.json", "r") as json_file:
            model = model_from_json(json_file.read())
        model.load_weights(filename + "_model.h5")
    elif model_type in ("OLS", "RF", "SVR"):
        model = joblib.load(filename + '_model.sav')
    else:
        print('Unsupported model type {}'.format(model_type))
        return None

    params = joblib.load(filename + '_parameters.sav')
    metrics = joblib.load(filename + '_metrics.sav')
    features = joblib.load(filename + '_features.sav')
    return model, params, metrics, features


def _load_session_model(label, model_name):
    """Fetch a model trained in this session from ``records.readings``.

    ``list_model_session`` and ``list_tests`` are the parallel lists built by
    load_show_models: the position of ``label`` in the first gives the test
    that owns the model in the second.

    Returns:
        ``(model, params, metrics, features)``.
    """
    test_source = list_tests[list_model_session.index(label)]
    print('Using model {} from current session'.format(model_name))
    entry = records.readings[test_source]['models'][model_name]
    return entry['model'], entry['parameters'], entry['metrics'], entry['features']


def loadModel(b):
    """Load the first selected model and publish it as module globals.

    Click handler for the "Load Model" button (``b`` is the button, unused).
    Labels prefixed ``'ARCHIVE_'`` are loaded from disk, labels prefixed
    ``'SESSION_'`` are taken from ``records``. On success the globals
    ``loaded_model``, ``loaded_params``, ``loaded_metrics``,
    ``loaded_features`` and ``loaded_model_name`` are all replaced together
    and summarised in ``out_model_load``.
    """
    global loaded_model
    global loaded_params
    global loaded_metrics
    global loaded_features
    global loaded_model_name

    archive_prefix = 'ARCHIVE_'
    session_prefix = 'SESSION_'

    with out_model_load:
        clear_output()

        if not selectedModels:
            print('Select one model to load')
            return

        # Only the first selected label is loaded.
        label = selectedModels[0]
        # startswith, not substring: a model *name* may itself contain "ARCHIVE".
        if label.startswith(archive_prefix):
            model_name = label[len(archive_prefix):]
            print('Loading model {} from disk'.format(model_name))
            result = _load_archived_model(model_name)
            if result is None:
                return
            print('Model loaded from disk')
        elif label.startswith(session_prefix):
            model_name = label[len(session_prefix):]
            result = _load_session_model(label, model_name)
            print('Model loaded from session')
        else:
            print('Unrecognised model label {}'.format(label))
            return

        # Publish everything at once so a failed load never leaves a mix of
        # old and new model state behind.
        loaded_model, loaded_params, loaded_metrics, loaded_features = result

        display(Markdown('## Model Load'))
        display(Markdown("Loaded " + model_name))
        display(Markdown('**Model Type** (*loaded_model*):' ))
        display(loaded_model)
        display(Markdown('**Model Parameters** (*loaded_params*)'))
        display(loaded_params)
        display(Markdown('**Model Metrics** (*loaded_metrics*)'))
        display(loaded_metrics)
        display(Markdown('**Model Features** (*loaded_features*)'))
        display(loaded_features)
        loaded_model_name = model_name
    
def load_show_models(target, mtype):
    """List the archived and in-session models matching ``target`` and ``mtype``.

    Rebuilds ``load_models.options`` with labels prefixed ``'ARCHIVE_'`` (from
    summary.json via ``dict_models``) or ``'SESSION_'`` (from
    ``records.readings``). Also rebuilds the parallel module-level lists
    ``list_model_session`` / ``list_tests`` that loadModel uses to find the
    test a session model belongs to.

    Args:
        target: model target selected in ``load_target_drop``.
        mtype: model type selected in ``load_type_drop``.
    """
    global list_tests
    global list_model_session

    with out_model_load:
        clear_output()

        # .get: a target with no archived models simply contributes nothing.
        list_models = ['ARCHIVE_' + name
                       for name, archived_type in dict_models.get(target, {}).items()
                       if archived_type == mtype]

        list_tests = []
        list_model_session = []
        for reading, reading_data in records.readings.items():
            if 'models' not in reading_data:
                continue
            session_models = reading_data['models']
            for model_name in session_models:
                try:
                    model_type = session_models[model_name]['model_type']
                except (KeyError, TypeError):
                    print ('Could not use model {} from current session. Model is not archived'.format(model_name))
                    continue
                if model_type == mtype:
                    label = 'SESSION_' + model_name
                    list_models.append(label)
                    list_tests.append(reading)
                    list_model_session.append(label)

        load_models.options = list_models
        
def load_calculateModel(b):
    """Run the currently loaded model on the selected test/device.

    Click handler for the "Predict channel" button (``b`` is the button,
    unused). Delegates to ``records.predict_channels`` with the selected test
    (``load_device.source``), device, result name, plot flag and date range.
    """
    with out_model_load:
        load_test_name = load_device.source
        load_device_name = load_device.value
        load_prediction_name = load_result_text.value

        clear_output()

        # loadModel defines these globals; before the first successful load
        # they do not exist and the call below would raise NameError.
        if 'loaded_model' not in globals() or 'loaded_model_name' not in globals():
            print('Load a model first')
            return

        # NOTE(review): the model type passed is the dropdown's *current*
        # value, which may differ from the loaded model's type if the dropdown
        # was changed after loading — confirm this is intended.
        records.predict_channels(load_test_name, load_device_name, loaded_model, loaded_features, loaded_params,
                                 load_type_drop.value, loaded_model_name, load_prediction_name, plot_result.value,
                                 load_min_date.value, load_max_date.value,
                                 clean_na = True, clean_na_method = 'fill', target_raster = '1Min')

display(widgets.HTML('<hr><h4>Import Local Models</h4>'))

# Widgets are created in dependency order: with ipywidgets 7+,
# widgets.interactive(...) calls its function once as soon as it is built,
# so every widget that function touches (load_models, load_device,
# load_min_date, load_max_date) must already exist at that point.

# Model type / target selectors that drive the model list
load_type_drop = widgets.Dropdown(options = ['LSTM', 'RF', 'OLS', 'SVR'],
                                  value = 'LSTM',
                                  description = 'Model Type',
                                  layout = widgets.Layout(width='300px'))

load_target_drop = widgets.Dropdown(options = ['ALPHASENSE', 'MICS', 'PMS'],
                                  value = 'MICS',
                                  description = 'Model Target',
                                  layout = widgets.Layout(width='300px'))

# Model list; its options are filled in by load_show_models.
# (``selected_labels`` is not a SelectMultiple trait in ipywidgets 7+, so it
# is not passed; the selection starts empty.)
load_models = widgets.SelectMultiple(layout = widgets.Layout(width='700px'))

load_model_type_drop = widgets.interactive(load_show_models,
                                target = load_target_drop,
                                mtype = load_type_drop,
                                layout = widgets.Layout(width='700px'))

# Keeps selectedModels in sync with the model list selection
load_models_interact = widgets.interactive(load_selectModels,
                                     selected_model = load_models,
                                     layout = widgets.Layout(width='700px'))

# Test dropdown
load_test_dd = widgets.Dropdown(options=list(records.readings.keys()),
                        layout=widgets.Layout(width='400px'),
                        description = 'Test')
# Alias kept for any other cell that refers to ``load_test``; it now points at
# the test dropdown that is actually displayed.
load_test = load_test_dd

# Date range boxes, prefilled by load_show_dates
load_min_date = widgets.Text(description='Start date:',
                         layout=widgets.Layout(width='330px'))
load_max_date = widgets.Text(description='End date:',
                         layout=widgets.Layout(width='330px'))

# Device dropdown; its options are filled in by load_show_devices
load_device = widgets.Dropdown(layout=widgets.Layout(width='200px'),
                        description = 'Device')

load_test_drop = widgets.interactive(load_show_devices,
                                Source=load_test_dd,
                                layout=widgets.Layout(width='400px'))

load_device_drop = widgets.interactive(load_show_dates,
                                Source=load_device,
                                layout=widgets.Layout(width='400px'))

# Name of the predicted channel
load_result_text = widgets.Text(layout = widgets.Layout(width='300px'),
                               description = 'Result name')

load_calculateButton = widgets.Button(description='Predict channel')
load_calculateButton.on_click(load_calculateModel)
plot_result = widgets.Checkbox(value=True,
                                     description='Plot Result',
                                     disabled=False,
                                     layout=widgets.Layout(width='300px'))

load_device_box = widgets.HBox([load_test_drop, load_device])
calculate_channel_box = widgets.HBox([load_result_text, load_calculateButton, plot_result])

load_B = widgets.Button(description='Load Model')
load_B.on_click(loadModel)

display(load_model_type_drop)
display(load_models)

buttonBox = widgets.VBox([load_B, load_device_box, load_min_date, load_max_date, calculate_channel_box])
display(buttonBox)
display(out_model_load)