In [1]:
import sys
sys.path.append('../model_analysis')
from model_utils import *
from run_grid_search import load_json
import pandas as pd 
import numpy as np 
import os
from typing import Tuple
from tqdm import tqdm
import seaborn as sns
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.metrics import r2_score
sns.set(rc={"figure.figsize": (20, 10)})

In [2]:
FEATURES = COLUMNS
IGNORE = ["knn", "svr", "ensamble2", "ensamble3"]

In [3]:
def df_to_train_data(df: pd.DataFrame) -> Tuple[np.array, np.array]:
    return df.drop(columns=["paredao", "nome", "rejeicao"], axis=1).to_numpy(), df.drop(columns=df.columns[:-1], axis=1)

def run_evaluation() -> pd.DataFrame:

    model_infos_df = pd.DataFrame()

    for model_name in MODELS.keys():
        if model_name in IGNORE: continue

        model = MODELS[model_name]
        norm = NORMALIZE[model_name]
        params = PARAMETERS[model_name]
        feat = FEATURES

        reg = model(**params)

        data_df = get_data(feat, normalize=norm)
        cols = data_df.columns
        X, y = df_to_train_data(data_df)
        y = np.ravel(y)
        
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
        reg.fit(X_train, y_train)
        r2 = r2_score(y_test, reg.predict(X_test))

        importances = reg.coef_ if model_name not in ["random_forest", "ada_boost"] else reg.feature_importances_.reshape(-1,)
        index = importances.argmax() 
        print(index, model_name)
        most_important = cols[index]

        model_infos_df = model_infos_df.append({
                "model": model_name,
                "r2": r2,
                "most_important_feature": most_important
            },
            ignore_index=True
        )
    return model_infos_df

In [4]:
infos = run_evaluation()

13 linear_regression
16 ada_boost
16 random_forest
8 lasso
7 ridge
8 elastic_net
8 sgd


In [5]:
infos

Unnamed: 0,model,most_important_feature,r2
0,linear_regression,retweets,0.291353
1,ada_boost,fora,0.482884
2,random_forest,fora,0.630162
3,lasso,negativos_global_pct,0.13879
4,ridge,neutros_global_pct,0.351417
5,elastic_net,negativos_global_pct,0.127405
6,sgd,negativos_global_pct,0.272415
