# Prediction 
The class `models` takes the following inputs:
   - a DataFrame
   - the parent hexagon column of that DataFrame (as a Series, e.g. `df.h6`)
   - the parent column name, as a string
   - the child column name, as a string
   - the minimum number of rows (child-hexagon records) a parent hexagon must have to be included in modeling

It provides methods to fit, and display the results of, the lasso and random forest models implemented below.

In [16]:
import seaborn as sns
import pandas as pd
import numpy as np
import h3
from sklearn.linear_model import LassoLarsIC
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split, cross_val_score, KFold
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error
pd.set_option('display.max_columns', None)
import warnings
warnings.filterwarnings('ignore')

In [83]:
# Load the modelling table that is passed to `models` below.
df = pd.read_csv(filepath_or_buffer='model10.csv')

In [84]:
# Model every h6 parent hexagon that has at least 2000 rows.
# NOTE: `models` is defined in the cell tagged In [79], which appears later in
# this export; that cell must be run before this one.
example_inst = models(
    df=df,
    parent=df['h6'],
    parent_str='h6',
    child_str='h10',
    numrows_cond=2000,
)

In [79]:
class models:
    """Fit per-parent-hexagon lasso and random-forest models of a target column.

    Rows of ``df`` are grouped by the parent hexagon id in ``parent``; every
    parent with enough rows gets its own train/test split and its own models.

    Parameters
    ----------
    df : pandas.DataFrame
        Modelling table. By default the target is the ``'alerts'`` column and
        the features are every column from position 5 onward
        (``df.iloc[:, 5:]``). NOTE(review): this assumes the target column
        sits before ``feature_start``; otherwise it leaks into the features
        -- confirm against the CSV layout.
    parent : pandas.Series
        Parent-hexagon id for each row, index-aligned with ``df``
        (e.g. ``df.h6``).
    parent_str, child_str : str
        Names of the parent and child hexagon columns. Stored on the instance
        for reference; the fitting code does not use them.
    numrows_cond : int, default 0
        Minimum number of rows a parent hexagon needs in order to be modelled.
    target : str, default 'alerts'
        Keyword-only. Name of the response column.
    feature_start : int, default 5
        Keyword-only. Positional index of the first feature column.
    test_size : float, default 0.3
        Keyword-only. Fraction of each parent's rows held out for testing.
    random_state : int, default 0
        Keyword-only. Seed for the train/test split and the random forest.
    """

    def __init__(self, df, parent, parent_str, child_str, numrows_cond=0, *,
                 target='alerts', feature_start=5, test_size=0.3,
                 random_state=0):
        self.parent = parent
        self.data = df
        self.parent_str = parent_str
        self.child_str = child_str
        self.numrows_cond = numrows_cond
        self.target = target
        self.feature_start = feature_start
        self.test_size = test_size
        self.random_state = random_state
        # Unique parent hexagon ids, in order of first appearance.
        self.hexagons = parent.unique()

    def hexagons_used(self):
        """Return the parent hexagons that have at least ``numrows_cond`` rows.

        Ids keep their order of first appearance in ``parent``. Missing ids
        (NaN/None) are always excluded, even when ``numrows_cond`` is 0,
        because they match no rows and no model can be fitted on them.
        """
        # One counting pass over the column instead of one boolean filter of
        # the whole frame per hexagon. value_counts() drops missing ids.
        counts = self.parent.value_counts().to_dict()
        threshold = max(self.numrows_cond, 1)
        return [h for h in self.hexagons if counts.get(h, 0) >= threshold]

    def _split(self, h):
        """Return ``(X, X_train, X_test, y_train, y_test)`` for parent hexagon *h*.

        ``y_train`` / ``y_test`` are one-column DataFrames holding the target.
        """
        d = self.data.loc[self.parent == h, :]
        X = d.iloc[:, self.feature_start:]
        y = d[[self.target]]
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=self.test_size, random_state=self.random_state)
        return X, X_train, X_test, y_train, y_test

    def fit_forest(self):
        """Fit one random forest per eligible parent hexagon.

        Returns
        -------
        dict
            ``{hexagon: {'R2': test-set R^2, 'forest': fitted
            RandomForestRegressor}}``
        """
        model_info = {}
        for h in self.hexagons_used():
            _, X_train, X_test, y_train, y_test = self._split(h)
            # Fit on a 1-D target; a one-column frame would only be raveled
            # with a DataConversionWarning. Seeding the forest as well as the
            # split makes the reported R^2 reproducible.
            forest = RandomForestRegressor(
                min_samples_split=8, random_state=self.random_state,
            ).fit(X_train, y_train[self.target])
            model_info[h] = {
                'R2': forest.score(X_test, y_test),
                'forest': forest,
            }
        return model_info

    def display_forest(self):
        """Fit the random forests and print each hexagon's test-set R^2."""
        model_info = self.fit_forest()
        for h, info in model_info.items():
            print('')
            # str() so non-string hexagon ids do not raise TypeError.
            print('For Hexagon: ' + str(h))
            print('-------------------------------------------')
            print("R^2: {}".format(info['R2']))
            print('-------------------------------------------')
            print('')
            print('')

    def fit_lasso(self):
        """Fit one AIC-selected ``LassoLarsIC`` model per eligible parent hexagon.

        Returns
        -------
        dict
            ``{hexagon: info}`` where ``info`` holds ``'X_train'``,
            ``'X_test'``, ``'y_train'``, ``'y_test'``, ``'y_pred'`` (test-set
            predictions), ``'lasso'`` (the fitted model), ``'nonzero'`` (the
            feature columns with non-zero coefficients) and ``'coef'`` (those
            coefficients as a Series indexed by column name).
        """
        model_info = {}
        for h in self.hexagons_used():
            X, X_train, X_test, y_train, y_test = self._split(h)
            # NOTE(review): `normalize` was deprecated in scikit-learn 1.0 and
            # removed in 1.2, where this call raises TypeError. The documented
            # replacement (a scaling Pipeline) changes the fitted coefficients
            # and the object stored under 'lasso', so it is left for a
            # deliberate migration.
            lasso = LassoLarsIC(criterion='aic', normalize=True).fit(
                X_train, y_train[self.target])
            nonzero = lasso.coef_ != 0
            model_info[h] = {
                'X_test': X_test,
                'X_train': X_train,
                'y_train': y_train,
                'y_pred': lasso.predict(X_test),
                'y_test': y_test,
                'lasso': lasso,
                'nonzero': X.columns[nonzero],
                'coef': pd.Series(lasso.coef_[nonzero],
                                  index=X.columns[nonzero]),
            }
        return model_info

    def display_lasso(self):
        """Fit the lasso models and print R^2, selected features and coefficients.

        Under "Feature Importance" the non-zero coefficients are printed in
        descending order of absolute value.
        """
        model_info = self.fit_lasso()
        for h, info in model_info.items():
            print('')
            print('For Hexagon: ' + str(h))
            print('-------------------------------------------')
            print("R^2: {}".format(info['lasso'].score(info['X_test'],
                                                       info['y_test'])))
            print('')
            print(len(info['nonzero']), 'non-zero column(s):')
            print(list(info['nonzero']))
            print('')
            print('Feature Importance')
            coef = info['coef']
            if coef.empty:
                print('(no non-zero coefficients)')
            else:
                # Stable sort so equal-magnitude coefficients keep column order.
                order = np.argsort(-np.abs(coef.values), kind='stable')
                print(coef.iloc[order].to_string())
            print('')