In [None]:
import pandas as pd
from typing import Tuple
from sklearn.preprocessing import StandardScaler
import io
import sys
import sympy as sp
import numpy as np
import pandas as pd
import seaborn as sns
from scipy import stats
from typing import Tuple
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, make_scorer, mean_squared_error
from sklearn.model_selection import (
    train_test_split,
    cross_val_score,
    GridSearchCV,
    RandomizedSearchCV,
)
import boilerplate.util as util

In [2]:
def nmae(y_true, y_pred):
    """Normalized mean absolute error: MAE divided by the mean of ``y_true``.

    This used to be a verbatim copy of ``nmae_fun``; it is kept only as a
    backward-compatible alias and delegates to it so the definition lives in
    one place. Note that the training cell rebinds the global name ``nmae``
    to a float (``mae, nmae, rf_model = ...``), so new code should call
    ``nmae_fun`` directly.
    """
    return nmae_fun(y_true, y_pred)
def nmae_fun(y_true, y_pred):
    """Return MAE(y_true, y_pred) normalized by the mean of the true values.

    0 means a perfect fit; the value is unbounded above and is not finite
    when ``y_true`` has zero mean.
    """
    return mean_absolute_error(y_true, y_pred) / np.mean(y_true)

def nrmse_fun(y_true, y_pred):
    """Return RMSE(y_true, y_pred) normalized by the mean of the true values.

    Requires ``mean_squared_error`` from ``sklearn.metrics`` to be in scope.
    """
    squared_error = mean_squared_error(y_true, y_pred)
    return np.sqrt(squared_error) / np.mean(y_true)

def get_datasets(dataset_name="sinusoid", data_dir="assets/data") -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load the INT log and the DASH log for one load scenario.

    Parameters
    ----------
    dataset_name : {"sinusoid", "mix", "flashcrowd"}
        Scenario to load. Unknown names raise ``ValueError`` instead of
        silently falling back to the sinusoid data.
    data_dir : str, default "assets/data"
        Directory containing ``log_INT_<source>.txt`` and ``dash_<source>.log``.

    Returns
    -------
    (data_log, data_dash)
        ``data_log`` is the INT log with every space removed from its column
        names; ``data_dash`` is the DASH log as read from disk.

    Raises
    ------
    ValueError
        If ``dataset_name`` is not one of the known scenarios.
    """
    sources = {
        "sinusoid": "sinusoid_8h",
        "mix": "mix_5h",
        "flashcrowd": "flashcrowd_6h",
    }
    try:
        source_data = sources[dataset_name]
    except KeyError:
        raise ValueError(
            f"unknown dataset_name {dataset_name!r}; expected one of {sorted(sources)}"
        ) from None

    data_log = pd.read_csv(f"{data_dir}/log_INT_{source_data}.txt", delimiter=",")

    # Strip spaces from the INT headers so names such as 'timestamp' match the
    # DASH log for the later merge (presumably the raw headers are padded).
    data_log.columns = data_log.columns.str.replace(" ", "")

    data_dash = pd.read_csv(f"{data_dir}/dash_{source_data}.log", sep=",")

    return data_log, data_dash

def remove_useless_attribute(dataset):
    """Return ``dataset`` without columns that carry no information.

    A column is dropped when it has at most one distinct non-null value, i.e.
    it is constant or entirely NaN. All-NaN columns must go as well:
    ``change_NaN_to_mean`` cannot fill them (their mean is NaN), and the
    ``dropna()`` in ``merge_dataset`` would then discard every row.

    The input frame is not modified; use the returned frame.
    """
    distinct_counts = dataset.nunique()  # NaN is not counted as a value
    return dataset.drop(columns=dataset.columns[distinct_counts <= 1])

def remove_outlier_IQR(df, k=1.5):
    """Mask per-column outliers using Tukey's IQR fences.

    For every column, values below ``Q1 - k * IQR`` or above ``Q3 + k * IQR``
    are replaced with NaN. Rows are *not* removed; the shape is unchanged
    (in this notebook the NaNs are subsequently filled by
    ``change_NaN_to_mean``).

    Parameters
    ----------
    df : DataFrame
        Numeric frame to clean.
    k : float, default 1.5
        Fence multiplier; 1.5 is Tukey's conventional value, larger values
        mask fewer points.
    """
    q1 = df.quantile(0.25)
    q3 = df.quantile(0.75)
    iqr = q3 - q1
    lower = q1 - k * iqr
    upper = q3 + k * iqr
    # Comparisons against NaN are False, so pre-existing NaNs stay NaN.
    return df.where(~((df < lower) | (df > upper)))

def change_NaN_to_mean(dataset):
    """Return a copy of ``dataset`` with NaNs replaced by their column mean.

    Only numeric columns are averaged (``numeric_only=True``), so mixed-type
    frames no longer make ``mean`` raise (pandas >= 2.0); NaNs in non-numeric
    columns are left as they are. A column that is entirely NaN stays NaN
    because its mean is NaN.
    """
    column_means = dataset.mean(numeric_only=True)
    return dataset.fillna(column_means)

def merge_dataset(data_log, data_dash):
    """Join INT features with DASH labels on ``timestamp`` and standardize them.

    Parameters
    ----------
    data_log : DataFrame
        Feature table containing ``timestamp``. Its FIRST column is skipped
        positionally, which assumes it is ``timestamp`` — TODO confirm for
        every source file.
    data_dash : DataFrame
        Must contain ``timestamp`` and the label column ``framesDisplayedCalc``.

    Returns
    -------
    (features, labels)
        ``features``: standardized array of the ``data_log`` columns after the
        first; ``labels``: array of ``framesDisplayedCalc``.

    Notes
    -----
    * Rows without a matching timestamp in ``data_dash``, or with NaN in any
      column (including DASH columns not used as features), are dropped.
    * The scaler is fit on all rows before any train/validation split.
      NOTE(review): tree ensembles are largely insensitive to this, but other
      models would leak validation statistics.
    """
    merged = data_log.merge(data_dash, on="timestamp", how="left").dropna()

    # A merge keeps the left frame's columns first and in order, so this
    # slice selects the data_log columns minus the first one.
    feature_block = merged.iloc[:, 1:len(data_log.columns)]
    labels = merged["framesDisplayedCalc"].values

    return normalization(feature_block.values), labels

def normalization(X):
    """Standardize every column of ``X`` to zero mean and unit variance.

    A fresh ``StandardScaler`` is fit on ``X`` itself, so the statistics come
    from the same data that is transformed. Returns the transformed array.
    """
    return StandardScaler().fit_transform(X)

In [3]:
def default_random_forest_model(
    features: pd.DataFrame, labels: pd.Series, model_params, return_cv_score=False
):
    """Cross-validate, fit and evaluate a RandomForestRegressor.

    The data is split once into train/validation (``model_params.test_size``).
    K-fold cross-validation scored with NMAE runs on the training part; the
    model is then refit on the whole training part and scored on the held-out
    validation part.

    Parameters
    ----------
    features, labels
        Model inputs (``merge_dataset`` returns both as arrays).
    model_params
        Object exposing, as attributes, the RandomForest hyper-parameters
        (``n_estimators``, ``max_depth``, ``min_samples_split``,
        ``min_samples_leaf``, ``bootstrap``, ``verbose``, ``max_features``,
        ``n_jobs``) plus ``test_size``, ``n_splits``, ``shuffle`` and
        ``random_state``.
    return_cv_score : bool, default False
        If True, also return the mean cross-validated NMAE (positive, lower
        is better); otherwise it is computed but not returned.

    Returns
    -------
    (mae, nmae, model) or (mae, nmae, model, cv_nmae)
        Validation MAE, validation NMAE (``nmae_fun``), the fitted model and,
        optionally, the mean CV NMAE.
    """
    X_train, X_validation, y_train, y_validation = train_test_split(
        features,
        labels,
        test_size=model_params.test_size,
        random_state=model_params.random_state,
        shuffle=model_params.shuffle,
    )

    rf_model = RandomForestRegressor(
        n_estimators=model_params.n_estimators,
        max_depth=model_params.max_depth,
        min_samples_split=model_params.min_samples_split,
        min_samples_leaf=model_params.min_samples_leaf,
        bootstrap=model_params.bootstrap,
        verbose=model_params.verbose,
        max_features=model_params.max_features,
        n_jobs=model_params.n_jobs,
        random_state=model_params.random_state,
    )

    # Use nmae_fun rather than the global `nmae`: the training cell rebinds
    # `nmae` to a float, which broke the scorer whenever that cell re-ran.
    nmae_scorer = make_scorer(nmae_fun, greater_is_better=False)

    # Recent scikit-learn versions reject a random_state when shuffle=False,
    # so it is only forwarded when shuffling.
    kf = KFold(
        n_splits=model_params.n_splits,
        shuffle=model_params.shuffle,
        random_state=model_params.random_state if model_params.shuffle else None,
    )
    cv_scores = cross_val_score(
        rf_model, X_train, y_train, cv=kf, scoring=nmae_scorer
    )
    # greater_is_better=False makes the scorer return negated NMAE; undo it.
    cv_nmae = -np.mean(cv_scores)

    rf_model.fit(X_train, y_train)

    predictions = rf_model.predict(X_validation)
    mae_rf = mean_absolute_error(y_validation, predictions)
    nmae_rf = nmae_fun(y_validation, predictions)

    if return_cv_score:
        return mae_rf, nmae_rf, rf_model, cv_nmae
    return mae_rf, nmae_rf, rf_model

In [4]:
# Load and clean the scenarios used for training.
# NOTE(review): get_datasets() returns (data_log, data_dash), so the names
# below are swapped: `*_dash` holds the INT log and `*_log` holds the DASH
# log. merge_dataset(data_log, data_dash) still receives the INT log first,
# so the computation is correct despite the misleading names.
sinusoid_dash, sinusoid_log = get_datasets()
sinusoid_dash, sinusoid_log = pd.DataFrame(sinusoid_dash), pd.DataFrame(sinusoid_log)
sinusoid_dash = remove_useless_attribute(sinusoid_dash)
sinusoid_dash = remove_outlier_IQR(sinusoid_dash)
sinusoid_dash = change_NaN_to_mean(sinusoid_dash)
# NOTE(review): the cleaned sinusoid data is not used anywhere in the visible
# notebook — `features`/`labels` below come from the flashcrowd scenario.
# TODO confirm which scenario the model is meant to be trained on.


flashcrowd_dash, flashcrowd_log = get_datasets("flashcrowd")
flashcrowd_dash, flashcrowd_log = pd.DataFrame(flashcrowd_dash), pd.DataFrame(flashcrowd_log)
flashcrowd_dash = remove_useless_attribute(flashcrowd_dash)
flashcrowd_dash = remove_outlier_IQR(flashcrowd_dash)
flashcrowd_dash = change_NaN_to_mean(flashcrowd_dash)
# Training inputs: standardized INT features and DASH frame-rate labels.
features, labels = merge_dataset(flashcrowd_dash, flashcrowd_log)


# NOTE(review): disabled block. As written it calls get_datasets() without
# "mix" (so it would load sinusoid) and then cleans/merges the flashcrowd
# variables instead of the mix ones.
# mix_dash, flashcrowd_log = get_datasets()
# flashcrowd_dash, flashcrowd_log = pd.DataFrame(flashcrowd_dash), pd.DataFrame(flashcrowd_log)
# flashcrowd_dash = remove_useless_attribute(flashcrowd_dash)
# flashcrowd_dash = remove_outlier_IQR(flashcrowd_dash)
# flashcrowd_dash = change_NaN_to_mean(flashcrowd_dash)
# features, labels = merge_dataset(flashcrowd_dash, flashcrowd_log)

In [5]:
# Train and evaluate the Random Forest with the tuned hyper-parameters.
# NOTE(review): `rf_model_params` is not defined or imported anywhere in the
# visible notebook (perhaps from a deleted cell or from `boilerplate.util`);
# on a fresh kernel this cell raises NameError until it is defined.
# TODO confirm its source.
best_params = rf_model_params(
    n_estimators=90,
    min_samples_split=5,
    min_samples_leaf=1,
    max_features='sqrt',
    max_depth=44,
    bootstrap=True,
    n_splits=5,
    random_state=42,
    shuffle=False,  # no shuffling: validation is the last 20% of rows in file order
    test_size=0.2,
    verbose = 2,  # prints a "building tree" line per tree for every fit (CV folds + final fit)
    n_jobs=8
)

# NOTE(review): this rebinds the global `nmae` (the metric function defined
# earlier) to a float; later code that expects the function `nmae` breaks.
mae, nmae, rf_model = default_random_forest_model(features, labels, best_params)

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.


building tree 1 of 90
building tree 2 of 90
building tree 3 of 90
building tree 4 of 90
building tree 5 of 90
building tree 6 of 90
building tree 7 of 90
building tree 8 of 90
building tree 9 of 90
building tree 10 of 90
building tree 11 of 90
building tree 12 of 90
building tree 13 of 90
building tree 14 of 90
building tree 15 of 90
building tree 16 of 90
building tree 17 of 90
building tree 18 of 90
building tree 19 of 90
building tree 20 of 90
building tree 21 of 90
building tree 22 of 90
building tree 23 of 90
building tree 24 of 90
building tree 25 of 90
building tree 26 of 90
building tree 27 of 90
building tree 28 of 90
building tree 29 of 90
building tree 30 of 90
building tree 31 of 90
building tree 32 of 90
building tree 33 of 90
building tree 34 of 90


[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:   13.3s


building tree 35 of 90
building tree 36 of 90
building tree 37 of 90
building tree 38 of 90
building tree 39 of 90
building tree 40 of 90
building tree 41 of 90
building tree 42 of 90
building tree 43 of 90
building tree 44 of 90
building tree 45 of 90
building tree 46 of 90
building tree 47 of 90
building tree 48 of 90
building tree 49 of 90
building tree 50 of 90
building tree 51 of 90
building tree 52 of 90
building tree 53 of 90
building tree 54 of 90
building tree 55 of 90
building tree 56 of 90
building tree 57 of 90
building tree 58 of 90
building tree 59 of 90
building tree 60 of 90
building tree 61 of 90
building tree 62 of 90
building tree 63 of 90
building tree 64 of 90
building tree 65 of 90
building tree 66 of 90
building tree 67 of 90
building tree 68 of 90
building tree 69 of 90
building tree 70 of 90
building tree 71 of 90
building tree 72 of 90
building tree 73 of 90
building tree 74 of 90
building tree 75 of 90
building tree 76 of 90
building tree 77 of 90
building tr

[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:   39.9s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.


building tree 1 of 90building tree 2 of 90

building tree 3 of 90
building tree 4 of 90
building tree 5 of 90
building tree 6 of 90
building tree 7 of 90
building tree 8 of 90
building tree 9 of 90
building tree 10 of 90
building tree 11 of 90
building tree 12 of 90
building tree 13 of 90
building tree 14 of 90
building tree 15 of 90
building tree 16 of 90
building tree 17 of 90
building tree 18 of 90
building tree 19 of 90
building tree 20 of 90
building tree 21 of 90
building tree 22 of 90
building tree 23 of 90
building tree 24 of 90
building tree 25 of 90
building tree 26 of 90
building tree 27 of 90
building tree 28 of 90
building tree 29 of 90
building tree 30 of 90
building tree 31 of 90
building tree 32 of 90
building tree 33 of 90


[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:   12.2s


building tree 34 of 90
building tree 35 of 90
building tree 36 of 90
building tree 37 of 90
building tree 38 of 90
building tree 39 of 90
building tree 40 of 90
building tree 41 of 90
building tree 42 of 90
building tree 43 of 90
building tree 44 of 90
building tree 45 of 90
building tree 46 of 90
building tree 47 of 90
building tree 48 of 90
building tree 49 of 90
building tree 50 of 90
building tree 51 of 90
building tree 52 of 90
building tree 53 of 90
building tree 54 of 90
building tree 55 of 90
building tree 56 of 90
building tree 57 of 90
building tree 58 of 90
building tree 59 of 90
building tree 60 of 90
building tree 61 of 90
building tree 62 of 90
building tree 63 of 90
building tree 64 of 90
building tree 65 of 90
building tree 66 of 90
building tree 67 of 90
building tree 68 of 90
building tree 69 of 90
building tree 70 of 90
building tree 71 of 90
building tree 72 of 90
building tree 73 of 90
building tree 74 of 90
building tree 75 of 90
building tree 76 of 90
building tr

[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:   36.6s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.


building tree 1 of 90
building tree 2 of 90
building tree 3 of 90
building tree 4 of 90
building tree 5 of 90
building tree 6 of 90
building tree 7 of 90
building tree 8 of 90
building tree 9 of 90
building tree 10 of 90
building tree 11 of 90
building tree 12 of 90
building tree 13 of 90
building tree 14 of 90
building tree 15 of 90
building tree 16 of 90
building tree 17 of 90
building tree 18 of 90
building tree 19 of 90
building tree 20 of 90
building tree 21 of 90
building tree 22 of 90
building tree 23 of 90
building tree 24 of 90
building tree 25 of 90
building tree 26 of 90
building tree 27 of 90
building tree 28 of 90
building tree 29 of 90
building tree 30 of 90
building tree 31 of 90
building tree 32 of 90
building tree 33 of 90


[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:   12.4s


building tree 34 of 90
building tree 35 of 90
building tree 36 of 90
building tree 37 of 90
building tree 38 of 90
building tree 39 of 90
building tree 40 of 90
building tree 41 of 90
building tree 42 of 90
building tree 43 of 90
building tree 44 of 90
building tree 45 of 90
building tree 46 of 90
building tree 47 of 90
building tree 48 of 90
building tree 49 of 90
building tree 50 of 90
building tree 51 of 90
building tree 52 of 90
building tree 53 of 90
building tree 54 of 90
building tree 55 of 90
building tree 56 of 90
building tree 57 of 90
building tree 58 of 90
building tree 59 of 90
building tree 60 of 90
building tree 61 of 90
building tree 62 of 90
building tree 63 of 90
building tree 64 of 90
building tree 65 of 90
building tree 66 of 90
building tree 67 of 90
building tree 68 of 90
building tree 69 of 90
building tree 70 of 90
building tree 71 of 90
building tree 72 of 90
building tree 73 of 90
building tree 74 of 90
building tree 75 of 90
building tree 76 of 90
building tr

[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:   37.2s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:    0.1s
[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:    0.2s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.


building tree 1 of 90
building tree 2 of 90
building tree 3 of 90
building tree 4 of 90
building tree 5 of 90
building tree 6 of 90
building tree 7 of 90
building tree 8 of 90
building tree 9 of 90
building tree 10 of 90
building tree 11 of 90
building tree 12 of 90
building tree 13 of 90
building tree 14 of 90
building tree 15 of 90
building tree 16 of 90
building tree 17 of 90
building tree 18 of 90
building tree 19 of 90
building tree 20 of 90
building tree 21 of 90
building tree 22 of 90
building tree 23 of 90
building tree 24 of 90
building tree 25 of 90
building tree 26 of 90
building tree 27 of 90
building tree 28 of 90
building tree 29 of 90
building tree 30 of 90
building tree 31 of 90
building tree 32 of 90
building tree 33 of 90
building tree 34 of 90


[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:   12.6s


building tree 35 of 90
building tree 36 of 90
building tree 37 of 90
building tree 38 of 90
building tree 39 of 90
building tree 40 of 90
building tree 41 of 90
building tree 42 of 90
building tree 43 of 90
building tree 44 of 90
building tree 45 of 90
building tree 46 of 90
building tree 47 of 90
building tree 48 of 90
building tree 49 of 90
building tree 50 of 90
building tree 51 of 90
building tree 52 of 90
building tree 53 of 90
building tree 54 of 90
building tree 55 of 90
building tree 56 of 90
building tree 57 of 90
building tree 58 of 90
building tree 59 of 90
building tree 60 of 90
building tree 61 of 90
building tree 62 of 90
building tree 63 of 90
building tree 64 of 90
building tree 65 of 90
building tree 66 of 90
building tree 67 of 90
building tree 68 of 90
building tree 69 of 90
building tree 70 of 90
building tree 71 of 90
building tree 72 of 90
building tree 73 of 90
building tree 74 of 90
building tree 75 of 90
building tree 76 of 90
building tree 77 of 90
building tr

[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:   39.8s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.


building tree 1 of 90
building tree 2 of 90
building tree 3 of 90
building tree 4 of 90
building tree 5 of 90
building tree 6 of 90
building tree 7 of 90
building tree 8 of 90
building tree 9 of 90
building tree 10 of 90
building tree 11 of 90
building tree 12 of 90
building tree 13 of 90
building tree 14 of 90
building tree 15 of 90
building tree 16 of 90
building tree 17 of 90
building tree 18 of 90
building tree 19 of 90
building tree 20 of 90
building tree 21 of 90
building tree 22 of 90
building tree 23 of 90
building tree 24 of 90
building tree 25 of 90
building tree 26 of 90
building tree 27 of 90
building tree 28 of 90
building tree 29 of 90
building tree 30 of 90
building tree 31 of 90
building tree 32 of 90
building tree 33 of 90
building tree 34 of 90


[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:   11.6s


building tree 35 of 90
building tree 36 of 90
building tree 37 of 90
building tree 38 of 90
building tree 39 of 90
building tree 40 of 90
building tree 41 of 90
building tree 42 of 90
building tree 43 of 90
building tree 44 of 90
building tree 45 of 90
building tree 46 of 90
building tree 47 of 90
building tree 48 of 90
building tree 49 of 90
building tree 50 of 90
building tree 51 of 90
building tree 52 of 90
building tree 53 of 90
building tree 54 of 90
building tree 55 of 90
building tree 56 of 90
building tree 57 of 90
building tree 58 of 90
building tree 59 of 90
building tree 60 of 90
building tree 61 of 90
building tree 62 of 90
building tree 63 of 90
building tree 64 of 90
building tree 65 of 90
building tree 66 of 90
building tree 67 of 90
building tree 68 of 90
building tree 69 of 90
building tree 70 of 90
building tree 71 of 90
building tree 72 of 90
building tree 73 of 90
building tree 74 of 90
building tree 75 of 90
building tree 76 of 90
building tree 77 of 90
building tr

[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:   37.0s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:    0.1s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.


building tree 1 of 90
building tree 2 of 90
building tree 3 of 90
building tree 4 of 90
building tree 5 of 90
building tree 6 of 90
building tree 7 of 90
building tree 8 of 90
building tree 9 of 90
building tree 10 of 90
building tree 11 of 90
building tree 12 of 90
building tree 13 of 90
building tree 14 of 90
building tree 15 of 90
building tree 16 of 90
building tree 17 of 90
building tree 18 of 90
building tree 19 of 90
building tree 20 of 90
building tree 21 of 90
building tree 22 of 90
building tree 23 of 90
building tree 24 of 90
building tree 25 of 90
building tree 26 of 90
building tree 27 of 90
building tree 28 of 90
building tree 29 of 90
building tree 30 of 90
building tree 31 of 90
building tree 32 of 90
building tree 33 of 90
building tree 34 of 90


[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:   15.8s


building tree 35 of 90
building tree 36 of 90
building tree 37 of 90
building tree 38 of 90
building tree 39 of 90
building tree 40 of 90
building tree 41 of 90
building tree 42 of 90
building tree 43 of 90
building tree 44 of 90
building tree 45 of 90
building tree 46 of 90
building tree 47 of 90
building tree 48 of 90
building tree 49 of 90
building tree 50 of 90
building tree 51 of 90
building tree 52 of 90
building tree 53 of 90
building tree 54 of 90
building tree 55 of 90
building tree 56 of 90
building tree 57 of 90
building tree 58 of 90
building tree 59 of 90
building tree 60 of 90
building tree 61 of 90
building tree 62 of 90
building tree 63 of 90
building tree 64 of 90
building tree 65 of 90
building tree 66 of 90
building tree 67 of 90
building tree 68 of 90
building tree 69 of 90
building tree 70 of 90
building tree 71 of 90
building tree 72 of 90
building tree 73 of 90
building tree 74 of 90
building tree 75 of 90
building tree 76 of 90
building tree 77 of 90
building tr

[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:   48.7s finished
[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:    0.0s
[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:    0.1s finished


In [6]:
# Validation MAE and NMAE on the held-out 20% split, plus the fitted model's repr.
print(mae, nmae, rf_model)

18.019752419323876 0.82627170852777 RandomForestRegressor(max_depth=44, max_features='sqrt', min_samples_split=5,
                      n_estimators=90, n_jobs=8, random_state=42, verbose=2)


In [7]:
# Reload and clean the flashcrowd scenario for evaluation.
# NOTE(review): this repeats the same loading and cleaning that produced the
# training `features`/`labels`, so it should yield the same arrays. As in the
# loading cell, `flashcrowd_dash` actually holds the INT log (returned first
# by get_datasets()).
flashcrowd_dash, flashcrowd_log = get_datasets("flashcrowd")
flashcrowd_dash, flashcrowd_log = pd.DataFrame(flashcrowd_dash), pd.DataFrame(flashcrowd_log)
flashcrowd_dash = remove_useless_attribute(flashcrowd_dash)
flashcrowd_dash = remove_outlier_IQR(flashcrowd_dash)
flashcrowd_dash = change_NaN_to_mean(flashcrowd_dash)
features, labels = merge_dataset(flashcrowd_dash, flashcrowd_log)

In [8]:
# Score the trained model on the whole flashcrowd dataset.
# NOTE(review): as the notebook stands, the model was trained on this same
# data (80% train / 20% validation), so most of these rows were seen during
# training. This MAE/NMAE is therefore optimistic and is not a held-out
# estimate; compare it with the validation scores printed earlier.
predictions = rf_model.predict(features)
mae_rf = mean_absolute_error(labels, predictions)
nmae_rf = nmae_fun(labels, predictions)
print(mae_rf, nmae_rf)

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:    0.4s


4.391815521277021 0.2034137416036891


[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:    1.0s finished


In [10]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error

def create_modern_prediction_plot(labels, predictions, model_name="Random Forest", show=True):
    """Build a 2x2 Plotly dashboard comparing predictions with true values.

    Panels: predicted vs. true (colored by absolute error), residuals vs.
    predicted, residual histogram, and both series over the sample index.
    The title reports MAE, RMSE, R² and NMAE, where NMAE is MAE divided by
    the mean of the true values (the same definition as ``nmae_fun``).

    Parameters
    ----------
    labels : array-like
        True values.
    predictions : array-like
        Predicted values, same length as ``labels``.
    model_name : str
        Model name shown in the figure title.
    show : bool, default True
        Call ``fig.show()`` before returning. Pass False when the returned
        figure is a cell's last expression, to avoid rendering it twice.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # Work on plain arrays so lists, Series (with any index) and ndarrays all
    # behave the same; `labels - predictions` fails when both are lists.
    y_true = np.asarray(labels)
    y_pred = np.asarray(predictions)

    # Metrics. NMAE uses MAE / mean(y_true), matching nmae_fun, so the title
    # agrees with the NMAE printed elsewhere (it was MAE / range before).
    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    r2 = r2_score(y_true, y_pred)
    nmae = mae / np.mean(y_true)

    # Tabular view shared by all panels.
    residuals = y_true - y_pred
    df = pd.DataFrame({
        'Real': y_true,
        'Predito': y_pred,
        'Erro_Absoluto': np.abs(residuals),
        'Residuos': residuals,
    })

    # 2x2 subplot grid
    fig = make_subplots(
        rows=2, cols=2,
        subplot_titles=[
            '🎯 Predito vs Real',
            '📊 Análise de Resíduos',
            '📈 Distribuição dos Resíduos',
            '⚡ Evolução das Predições'
        ],
        specs=[[{"secondary_y": False}, {"secondary_y": False}],
               [{"secondary_y": False}, {"secondary_y": False}]]
    )

    # 1. Predicted vs. true scatter
    min_val = min(df['Real'].min(), df['Predito'].min())
    max_val = max(df['Real'].max(), df['Predito'].max())

    # Points colored by absolute error
    fig.add_trace(
        go.Scatter(
            x=df['Real'],
            y=df['Predito'],
            mode='markers',
            name='Predições',
            marker=dict(
                size=8,
                color=df['Erro_Absoluto'],
                colorscale='RdYlGn_r',
                showscale=True,
                colorbar=dict(
                    title="Erro Absoluto",
                    x=0.47,
                    len=0.4
                ),
                opacity=0.8,
                line=dict(width=1, color='white')
            ),
            hovertemplate='<b>Real:</b> %{x:.2f}<br><b>Predito:</b> %{y:.2f}<br><b>Erro:</b> %{marker.color:.2f}<extra></extra>'
        ),
        row=1, col=1
    )

    # Identity line (perfect prediction)
    fig.add_trace(
        go.Scatter(
            x=[min_val, max_val],
            y=[min_val, max_val],
            mode='lines',
            name='Linha Perfeita',
            line=dict(color='#2E86AB', width=3, dash='dash'),
            hoverinfo='skip'
        ),
        row=1, col=1
    )

    # 2. Residuals vs. predicted
    fig.add_trace(
        go.Scatter(
            x=df['Predito'],
            y=df['Residuos'],
            mode='markers',
            name='Resíduos',
            marker=dict(
                size=8,
                color=np.abs(df['Residuos']),
                colorscale='Viridis',
                opacity=0.7,
                line=dict(width=1, color='white')
            ),
            hovertemplate='<b>Predito:</b> %{x:.2f}<br><b>Resíduo:</b> %{y:.2f}<extra></extra>'
        ),
        row=1, col=2
    )

    # Zero-residual reference line
    fig.add_trace(
        go.Scatter(
            x=[df['Predito'].min(), df['Predito'].max()],
            y=[0, 0],
            mode='lines',
            name='Linha Zero',
            line=dict(color='red', width=2, dash='dash'),
            hoverinfo='skip'
        ),
        row=1, col=2
    )

    # 3. Residual histogram
    fig.add_trace(
        go.Histogram(
            x=df['Residuos'],
            name='Distribuição',
            nbinsx=30,
            marker=dict(
                color='#A23B72',
                opacity=0.7,
                line=dict(color='white', width=1)
            ),
            hovertemplate='<b>Resíduo:</b> %{x:.2f}<br><b>Frequência:</b> %{y}<extra></extra>'
        ),
        row=2, col=1
    )

    # 4. True and predicted values over the sample index
    indices = np.arange(len(df))
    fig.add_trace(
        go.Scatter(
            x=indices,
            y=df['Real'],
            mode='lines+markers',
            name='Valores Reais',
            line=dict(color='#2E86AB', width=2),
            marker=dict(size=4),
            hovertemplate='<b>Índice:</b> %{x}<br><b>Real:</b> %{y:.2f}<extra></extra>'
        ),
        row=2, col=2
    )

    fig.add_trace(
        go.Scatter(
            x=indices,
            y=df['Predito'],
            mode='lines+markers',
            name='Predições',
            line=dict(color='#A23B72', width=2),
            marker=dict(size=4),
            hovertemplate='<b>Índice:</b> %{x}<br><b>Predito:</b> %{y:.2f}<extra></extra>'
        ),
        row=2, col=2
    )

    # Global layout
    fig.update_layout(
        title={
            'text': f'<b>📊 Análise de Performance - {model_name}</b><br><span style="font-size:14px;">MAE: {mae:.3f} | RMSE: {rmse:.3f} | R²: {r2:.3f} | NMAE: {nmae:.3f}</span>',
            'x': 0.5,
            'xanchor': 'center',
            'font': {'size': 20, 'color': '#2E86AB'}
        },
        template='plotly_white',
        height=800,
        showlegend=True,
        legend=dict(
            orientation="v",
            yanchor="top",
            y=1,
            xanchor="left",
            x=1.02
        ),
        font=dict(family="Arial, sans-serif", size=12),
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)'
    )

    # Axis titles per panel, all with the same light grid.
    grid_style = dict(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
    axis_titles = {
        (1, 1): ("<b>Valores Reais</b>", "<b>Valores Preditos</b>"),
        (1, 2): ("<b>Valores Preditos</b>", "<b>Resíduos</b>"),
        (2, 1): ("<b>Resíduos</b>", "<b>Frequência</b>"),
        (2, 2): ("<b>Índice da Amostra</b>", "<b>Valores</b>"),
    }
    for (row, col), (x_title, y_title) in axis_titles.items():
        fig.update_xaxes(title_text=x_title, row=row, col=col, **grid_style)
        fig.update_yaxes(title_text=y_title, row=row, col=col, **grid_style)

    if show:
        fig.show()

    return fig

def create_simple_scatter_plot(labels, predictions, show=True):
    """Plot a single styled predicted-vs-true scatter with an identity line.

    Points are colored by absolute error. The title reports MAE, R² and NMAE,
    where NMAE is MAE divided by the mean of the true values (the same
    definition as ``nmae_fun``).

    Parameters
    ----------
    labels : array-like
        True values.
    predictions : array-like
        Predicted values, same length as ``labels``.
    show : bool, default True
        Call ``fig.show()`` before returning. Pass False when the returned
        figure is a cell's last expression, to avoid rendering it twice.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # Plain arrays: works for lists, Series and ndarrays alike.
    y_true = np.asarray(labels)
    y_pred = np.asarray(predictions)

    df = pd.DataFrame({
        'Real': y_true,
        'Predito': y_pred,
        'Erro': np.abs(y_true - y_pred)
    })

    # Metrics (NMAE = MAE / mean(y_true), consistent with nmae_fun)
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    nmae = mae / np.mean(y_true)

    fig = px.scatter(
        df,
        x='Real',
        y='Predito',
        color='Erro',
        color_continuous_scale='RdYlGn_r',
        title=f'🎯 Predições vs Valores Reais<br><span style="font-size:14px;">MAE: {mae:.3f} | R²: {r2:.3f} | NMAE: {nmae:.3f}</span>',
        labels={'Real': 'Valores Reais', 'Predito': 'Valores Preditos', 'Erro': 'Erro Absoluto'},
        hover_data={'Erro': ':.3f'}
    )

    # Identity line (perfect prediction)
    min_val = min(df['Real'].min(), df['Predito'].min())
    max_val = max(df['Real'].max(), df['Predito'].max())

    fig.add_trace(
        go.Scatter(
            x=[min_val, max_val],
            y=[min_val, max_val],
            mode='lines',
            name='Linha Perfeita',
            line=dict(color='blue', width=3, dash='dash')
        )
    )

    fig.update_layout(
        template='plotly_white',
        height=600,
        title_x=0.5,
        font=dict(family="Arial, sans-serif", size=12),
        plot_bgcolor='rgba(0,0,0,0)',
        paper_bgcolor='rgba(0,0,0,0)'
    )

    # Applies to every trace; the identity line is lines-only, so its
    # markers are not drawn.
    fig.update_traces(
        marker=dict(size=8, opacity=0.8, line=dict(width=1, color='white'))
    )

    if show:
        fig.show()
    return fig

# Example usage: score the trained model and plot its predictions.
# NOTE(review): `features`/`labels` are the full flashcrowd set, most of which
# the model was trained on, so these scores repeat the optimistic numbers of
# the earlier full-dataset evaluation rather than a held-out estimate.
predictions = rf_model.predict(features)
mae_rf = mean_absolute_error(labels, predictions)
nmae_rf = nmae_fun(labels, predictions)

print(f"MAE: {mae_rf:.4f}")
print(f"NMAE: {nmae_rf:.4f}")

# # Build the full 2x2 dashboard
# fig_completa = create_modern_prediction_plot(labels, predictions, "Random Forest")

# Or the simplified scatter, limited to the first 5000 samples
# (presumably to keep the interactive figure responsive)
fig_simples = create_simple_scatter_plot(labels[:5000], predictions[:5000])

# Save the figures (optional)
# fig_completa.write_html("analise_completa.html")
# fig_simples.write_html("scatter_plot.html")
# fig_completa.write_image("analise_completa.png", width=1200, height=800, scale=2)

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:    0.3s


MAE: 4.3918
NMAE: 0.2034


[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:    0.9s finished


[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:    0.3s
18.828863850194555 0.8720880071940826
[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:    0.7s finished

[Parallel(n_jobs=8)]: Using backend ThreadingBackend with 8 concurrent workers.
[Parallel(n_jobs=8)]: Done  25 tasks      | elapsed:    0.4s
4.391815521277021 0.2034137416036891
[Parallel(n_jobs=8)]: Done  90 out of  90 | elapsed:    1.2s finished


