This notebook is based on the scikit-learn example
[Effect of transforming the targets in regression model](https://scikit-learn.org/stable/auto_examples/compose/plot_transformed_target.html).


In [1]:
from typing import Dict, Union

import numpy as np
import pandas as pd

import plotly
import plotly.express as px

from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.compose import TransformedTargetRegressor
from sklearn.model_selection import train_test_split
from sklearn import metrics

from sklearn.tree import DecisionTreeRegressor

In [2]:
class PredictedAccuracy:
    """Metrics and plots to evaluate the accuracy of regression models"""

    def __init__(self, y_series: pd.Series, yhat_series: Union[np.ndarray, pd.Series]):
        """Keep the true and predicted targets as two distinctly named Series.

        Args:
          y_series (pd.Series): values of true y
          yhat_series (Union[np.ndarray, pd.Series]): values of predicted y.
            Anything that is not a pd.Series (ndarray, list, column vector of
            shape (n, 1), ...) is flattened and wrapped in a Series that
            shares the index of `y_series` and is named "predicted <y name>".
            A pd.Series is kept as is, unless it has no name or the same name
            as `y_series`; then a renamed copy is stored so that the two
            columns can be told apart in `plot_scatter`.

        Raises:
          ValueError: if a non-Series `yhat_series` does not contain exactly
            as many values as `y_series`.
        """
        predicted_name = f"predicted {y_series.name}"
        if not isinstance(yhat_series, pd.Series):
            # np.ravel also accepts lists and (n, 1) arrays, which some
            # regressors return for a single target.
            yhat_series = pd.Series(
                np.ravel(yhat_series), index=y_series.index, name=predicted_name
            )
        elif yhat_series.name is None or yhat_series.name == y_series.name:
            # The plot selects columns by name: a missing or duplicated name
            # would make the x/y columns ambiguous.
            yhat_series = yhat_series.rename(predicted_name)

        self.y_series = y_series
        self.yhat_series = yhat_series

    @staticmethod
    def regression_accuracy_metrics(y: pd.Series, yhat: pd.Series) -> Dict[str, float]:
        """Metrics to evaluate the accuracy of regression models.

        Ref: https://www.datatechnotes.com/2019/10/accuracy-check-in-python-mae-mse-rmse-r.html

        Args:
          y (pd.Series): values of true y
          yhat (pd.Series): values of predicted y

        Return:
          Dict[str, float]: dictionary metric name / metric value,
            with keys "MAE" and "r2"
        """
        return {
            "MAE": metrics.mean_absolute_error(y, yhat),
            "r2": metrics.r2_score(y, yhat),
        }

    def metrics(self) -> Dict[str, float]:
        """Compute the metrics for the stored true / predicted values.

        The result is recomputed on every call; nothing is cached.

        Return:
          Dict[str, float]: dictionary metric name / metric value
        """
        return PredictedAccuracy.regression_accuracy_metrics(
            self.y_series, self.yhat_series
        )

    def pretty_metrics(self, decimals: int = 2, separation_string: str = ", ") -> str:
        """Format the metrics as a single line of text.

        Args:
          decimals (int): decimal digits to print. Default: 2
          separation_string (str): text between two consecutive metrics. Default: ", "

        Return:
          str: text with metric values, e.g. "MAE: 1.23, r2: 0.98"
        """
        # The fixed-point format spec already rounds to `decimals` digits,
        # so no separate round() call is needed.
        return separation_string.join(
            f"{name}: {value:.{decimals}f}" for name, value in self.metrics().items()
        )

    # The return annotation is quoted so that defining the class does not
    # require `plotly` to be resolved at class-creation time.
    def plot_scatter(self, main_title: str = "Actual vs predicted") -> "plotly.graph_objs.Figure":
        """Scatterplot to compare actual vs predicted y

        Predicted values are on the x axis and actual values on the y axis.
        A thin diagonal reference line (x == y) marks perfect predictions.

        Args:
          main_title (str): Plot title prefix; the name of y and the metrics
            are appended to it. Default: "Actual vs predicted"

        Return:
          plotly.graph_objs.Figure: actual vs predicted plot
        """
        y_min, y_max = self.y_series.min(), self.y_series.max()
        x_min, x_max = self.yhat_series.min(), self.yhat_series.max()
        # Extend the reference line 10% of each value range beyond the data.
        x_padding = 0.1 * (x_max - x_min)
        y_padding = 0.1 * (y_max - y_min)
        axis_min = min(x_min - x_padding, y_min - y_padding)
        axis_max = max(x_max + x_padding, y_max + y_padding)

        # One column per series, labelled with the series names.
        plot_data = pd.concat([self.yhat_series, self.y_series], axis=1)
        scatter = px.scatter(
            plot_data,
            x=self.yhat_series.name,
            y=self.y_series.name,
            title=f"{main_title} {self.y_series.name}<br>{self.pretty_metrics()}",
        )
        scatter.add_shape(
            # Line reference to the axes
            type="line",
            xref="x",
            yref="y",
            x0=axis_min,
            y0=axis_min,
            x1=axis_max,
            y1=axis_max,
            line=dict(color="LightSeaGreen", width=1),
        )
        return scatter


## Parameters

- Tip: It's useful to define all your parameters in one cell. (See [papermill](https://papermill.readthedocs.io))
- Homework: Use papermill to run this notebook and load the variables of the next cell from a yaml file.

In [3]:
# Single seed, passed as `random_state` to the data generation and to the
# train/test split so that both are repeatable.
RANDOM_STATE = 42

# Generate Data

In [4]:
# Synthetic regression problem; seeded so the same data is drawn every run.
raw_X, raw_y = make_regression(
    n_samples=10000,
    noise=100,
    random_state=RANDOM_STATE,
)
# Shift the target so that no value is negative, rescale it, and exponentiate:
# the resulting target is non-negative and non-linear in the features.
raw_y = np.expm1((raw_y + np.abs(raw_y.min())) / 200)
# log1p undoes expm1, giving the target on the transformed (log) scale.
raw_y_trans = np.log1p(raw_y)

In [5]:
# Wrap the raw arrays in labelled pandas objects; feature columns are named
# f0, f1, ... in the order of the columns of raw_X.
X = pd.DataFrame(raw_X, columns=[f"f{i}" for i in range(raw_X.shape[1])])
y = pd.Series(raw_y, name="y")
y_trans = pd.Series(raw_y_trans, name="y_trans")

# Basic EDA

In [6]:
# Peek at the first rows of the feature matrix.
X.head()

Unnamed: 0,f0,f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13,f14,f15,f16,f17,f18,f19,f20,f21,f22,f23,f24,f25,f26,f27,f28,f29,f30,f31,f32,f33,f34,f35,f36,f37,f38,f39,...,f60,f61,f62,f63,f64,f65,f66,f67,f68,f69,f70,f71,f72,f73,f74,f75,f76,f77,f78,f79,f80,f81,f82,f83,f84,f85,f86,f87,f88,f89,f90,f91,f92,f93,f94,f95,f96,f97,f98,f99
0,0.557844,-1.724489,-1.167184,2.224393,-1.246964,-0.23574,-1.443182,-1.384978,0.457306,-2.102986,1.62792,-0.202725,0.880303,-1.360783,1.232872,-0.805989,-0.859401,0.041562,-1.489334,-1.368615,0.098582,-1.589133,0.106932,0.797858,1.082711,-0.505796,0.179485,-0.095469,0.467696,2.382619,0.523375,1.047347,-0.933485,0.593943,-1.558754,-0.589063,-0.447264,-0.039089,-0.917758,-1.922051,...,0.235055,-0.569273,2.381935,0.160941,-0.011087,0.597486,-0.059568,-0.209217,-0.456469,-0.243482,0.403876,0.241691,1.573838,0.161304,0.177474,-0.80232,-2.072791,-0.408156,-0.691763,-1.289532,2.093388,-1.975659,1.133847,-0.918091,0.037253,-0.392052,0.756057,-1.67955,1.915273,0.128236,-3.014441,0.654831,0.236957,-1.658472,-0.269015,-1.617438,0.54608,-0.092814,0.884266,-0.544326
1,-1.619158,-0.002266,1.744685,-1.441394,-0.523619,-0.355589,-0.691155,-0.531119,1.407624,-1.560975,0.252974,-0.508065,0.082912,-0.495421,-1.648922,-0.224903,-1.365242,-0.299377,2.080383,0.472105,-1.962253,-0.193108,0.468858,-0.542476,1.161629,-0.374601,-0.343934,0.653985,1.448897,1.199646,-0.516866,0.469688,0.326106,1.090064,-0.307865,-1.264109,0.061812,1.042117,-1.481469,1.022299,...,-0.141898,-0.882701,1.845312,1.630183,-0.198079,0.310761,0.477301,0.753367,1.138531,0.793348,0.563286,1.216392,-0.443871,-0.596202,0.274796,-0.740001,-1.099386,1.088729,1.389612,-1.393309,-0.381371,-1.50957,0.674899,-0.47148,-1.222221,-0.835594,-0.675204,-0.945598,-0.091693,0.671523,-0.32055,1.469404,-0.656189,0.299135,0.06588,0.061685,-1.019211,-0.37412,0.003552,-0.798982
2,-0.238216,-0.446124,-3.136928,0.514585,-1.473993,2.439183,-0.09669,0.900328,-0.193925,-0.409793,-0.248506,-0.731891,0.512077,0.021225,-0.875998,-0.884863,-0.581347,-1.076467,1.397625,-0.907689,-1.013732,0.931956,-1.62288,-1.113668,0.147266,-1.041246,0.813178,0.319805,1.377995,-0.591596,0.025667,-0.872351,0.41663,-0.870087,0.238463,1.689889,-0.8621,-1.406294,1.536525,0.991428,...,-0.215998,1.262096,-0.547,0.316772,0.440006,-1.332768,-0.986406,0.037204,-0.47604,-0.610111,-0.53396,0.072295,-0.856795,-0.12244,-0.565727,0.028828,-1.172051,0.227667,-1.038803,-0.525039,0.453016,2.063694,0.806634,-3.061308,-0.359408,0.042112,-0.506296,0.034752,-0.383642,-0.17607,-0.347426,0.044564,-0.398709,-0.396784,0.943146,-0.26125,-1.118885,0.886143,-0.963336,-0.469642
3,-2.459902,0.404984,-0.30269,-0.418191,0.218405,1.197645,0.42898,-0.11825,-0.346083,0.246743,-0.149391,0.777065,0.479185,-0.088405,2.116067,0.406487,1.836861,0.326296,-0.721814,1.472259,-1.20767,-1.128299,-1.182893,0.030918,-0.00852,-0.582973,0.733728,-1.075897,-0.293191,-0.265107,-0.763175,0.706391,2.033413,0.48047,-0.104271,-0.154948,-0.173566,0.962017,1.676903,1.401844,...,-0.012757,-0.989583,0.543935,-0.830348,0.741044,0.669718,-1.678539,-0.38139,1.902972,0.422388,-2.312834,-0.504127,1.410422,-0.576923,-1.7184,0.116155,-1.196769,-0.108577,0.456661,0.160147,-0.316268,-0.719715,0.813096,0.042347,-1.040486,0.211496,0.251726,-0.196126,0.815997,-1.374668,0.975917,0.548807,-1.116362,1.461193,0.003115,0.381408,0.449947,0.089221,0.434039,-0.509651
4,-0.742791,-0.903054,-0.122735,-0.757815,0.284558,-1.663422,0.126213,-0.818931,-1.307882,1.45212,-0.004145,0.397629,0.068839,0.826038,2.897792,-0.882433,1.014298,0.574055,-0.412128,0.103317,0.463745,-0.147511,-1.162399,0.291257,1.864461,1.140567,0.737794,-0.327951,0.358609,-0.175574,-0.48911,1.083852,0.883294,0.941584,0.191363,-0.011513,-0.531249,-0.89149,-0.968602,-0.444561,...,-1.894444,0.078526,2.425649,1.966844,-0.919193,0.365803,0.369484,-0.266691,0.090573,-0.547439,0.128097,-1.000311,-1.028841,1.905664,-0.306231,-1.335083,-0.92031,-0.88128,-0.711862,0.456489,0.943848,0.946261,0.292784,1.393807,-1.210293,-0.942415,0.496517,0.845756,-0.456596,0.864807,0.23613,0.079278,0.900468,2.060264,0.225111,-2.029287,0.428644,-1.626833,1.539008,0.925264


In [7]:
# Check that px.histogram returns a plotly Figure, i.e. the type used as
# return annotation of PredictedAccuracy.plot_scatter.
isinstance(px.histogram(y, x="y"), plotly.graph_objs.Figure)

True

In [8]:
# Show the concrete class of the object returned by px.histogram.
type(px.histogram(y, x="y"))

plotly.graph_objs._figure.Figure

In [9]:
# Distribution of the log1p-transformed target.
px.histogram(y_trans, x="y_trans")

# Prediction

In [10]:
# Hold out a test set; the split is seeded with RANDOM_STATE so it is the same
# on every run. Note that the untransformed target `y` is the one being split.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=RANDOM_STATE)

### Linear Regression (no previous transformation)

In [11]:
# Baseline: plain linear model fitted directly on the untransformed target
# (`fit` returns the fitted estimator, so it can be chained).
regr = LinearRegression().fit(X_train, y_train)
y_test_pred = regr.predict(X_test)

In [12]:
# Compare the held-out targets with the baseline predictions; the figure
# title carries the metrics.
pa = PredictedAccuracy(y_series=y_test, yhat_series=y_test_pred)
pa.plot_scatter()

### Linear Regression (y transformed)

In [13]:
# Same linear model, but wrapped so that log1p is applied to the target
# before fitting and expm1 to the predictions: y_test_trans_pred is therefore
# expressed on the original scale of y.
regr_trans = TransformedTargetRegressor(
    regressor=LinearRegression(),
    func=np.log1p,
    inverse_func=np.expm1,
).fit(X_train, y_train)
y_test_trans_pred = regr_trans.predict(X_test)

In [14]:
# Same evaluation as for the baseline, now with the predictions of the
# model trained on the transformed target.
pa_trans = PredictedAccuracy(y_series=y_test, yhat_series=y_test_trans_pred)
pa_trans.plot_scatter()

# Bonus/Homework

- Decision Trees
  - In the previous example, we explain/predict `y` by using a linear regression model.
  - Can we use a Decision Tree model?
    - Is the model affected by the transformation?
      - Train a model to explain/predict `y` by using a DecisionTree
      - Transform `y` first, and then explain/predict the transformed `y` by using a DecisionTree
        - Did you notice a huge improvement in the metrics?
        - Why?
- Reproducibility/replicability is crucial in Data Analysis. It's important therefore to report python/package versions. Moreover, it will be helpful for debugging purposes.
  - Install and use `watermark` to report the python/package versions at the end of this notebook

# Watermark

In [15]:
!pip install watermark

Collecting watermark
  Downloading https://files.pythonhosted.org/packages/60/fe/3ed83b6122e70dce6fe269dfd763103c333f168bf91037add73ea4fe81c2/watermark-2.0.2-py2.py3-none-any.whl
Installing collected packages: watermark
Successfully installed watermark-2.0.2


In [16]:
# Register the %watermark magic, then print the timestamp, the interpreter
# versions and the machine information.
%load_ext watermark
%watermark

2021-02-01T10:03:11+00:00

CPython 3.6.9
IPython 5.5.0

compiler   : GCC 8.4.0
system     : Linux
release    : 4.19.112+
machine    : x86_64
processor  : x86_64
CPU cores  : 2
interpreter: 64bit


In [17]:
%watermark --iversions

plotly 4.4.1
numpy  1.19.5
pandas 1.1.5

