In [140]:
from pygam import LinearGAM, s, f

from datasets import diabetes_data

from plotly.subplots import make_subplots
import plotly.graph_objects as go
from plotly.offline import init_notebook_mode
# Switch plotly into offline notebook mode (connected=True) before any of the
# fig.show() calls below, so the figures are rendered in this notebook.
init_notebook_mode(connected=True)

In [141]:
# Load the diabetes data through the local `datasets` helper.
# NOTE(review): only original_X / original_y are used in the cells below; the
# train/test split (train_X, train_y, test_X, test_y) is never used, so every
# model here is fit and summarised in-sample.
original_X, original_y, train_X, train_y, test_X, test_y = diabetes_data()

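## Spline terms

Fit one smoothing spline `s(i)` per feature (all 10 columns) with pygam's default smoothing penalty, then inspect each term's partial dependence.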

In [142]:
# Additive model with one smoothing spline per feature column 0..9.
# sum() left-folds s(1)..s(9) onto s(0), i.e. s(0) + s(1) + ... + s(9).
model = LinearGAM(sum((s(feature) for feature in range(1, 10)), s(0)))
result = model.fit(original_X, original_y)

In [143]:
# Print pygam's summary of the all-features fit: effective DoF, log-likelihood,
# AIC/AICc, GCV, scale, pseudo R-squared and a per-term table.
result.summary()

LinearGAM                                                                                                 
Distribution:                        NormalDist Effective DoF:                                     92.0369
Link Function:                     IdentityLink Log Likelihood:                                 -3893.2146
Number of Samples:                          442 AIC:                                             7972.5028
                                                AICc:                                            8022.7891
                                                GCV:                                             4208.5528
                                                Scale:                                           2668.0171
                                                Pseudo R-Squared:                                   0.6438
Feature Function                  Lambda               Rank         EDoF         P > x        Sig. Code   
s(0)                              [0.

In [144]:
# Partial-dependence plots for the 10 spline terms, laid out on a 2 x 5 grid.
# Each panel shows the estimated term (line) and its 95% confidence band (shaded).
fig = make_subplots(rows=2, cols=5, subplot_titles=original_X.columns)

for i in range(10):
    # Plotly subplot positions are 1-based: term i goes to row i // 5, column i % 5.
    row, col = i // 5 + 1, i % 5 + 1

    XX = model.generate_X_grid(term=i)
    x = XX[:, i]  # term i is s(i), so column i is the feature being varied

    # With `width` given, partial_dependence yields (estimate, confidence interval);
    # take both from a single call rather than computing it twice.
    y, confidence = model.partial_dependence(term=i, X=XX, width=.95)
    lower, upper = confidence[:, 0], confidence[:, 1]

    fig.add_trace(go.Scatter(x=x, y=y, name='estimate'), row=row, col=col)

    # Invisible upper edge of the confidence band...
    fig.add_trace(
        go.Scatter(x=x, y=upper, line=dict(width=0), name='95% CI upper'),
        row=row, col=col)

    # ...which the lower edge fills up to via fill='tonexty' (previous trace).
    fig.add_trace(
        go.Scatter(x=x, y=lower, fill='tonexty', fillcolor='rgba(100,50,80,0.1)',
                   line=dict(width=0), name='95% CI lower'),
        row=row, col=col)

fig.update_layout(height=600, width=1000, showlegend=False)
fig.show()

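## Feature subset

Refit the GAM using only the features chosen by a separate feature-selection step, tuning the smoothing penalty with `gridsearch`.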

In [146]:
# Features kept by feature selection; their order fixes the model's column order
# (s(0) is fit on 'sex', s(1) on 'bmi', and so on).
selected_features = 'sex bmi bp s1 s3 s5'.split()

In [147]:
# One smoothing spline per selected feature. The number of terms is derived from
# `selected_features` so the model cannot drift out of sync with the column list
# (for the current 6 features this is s(0) + ... + s(5)).
model = LinearGAM(sum((s(j) for j in range(1, len(selected_features))), s(0)))

# gridsearch tries a set of candidate smoothing penalties (11 of them, per the
# progress output); the resulting model is summarised in the next cell.
result = model.gridsearch(original_X[selected_features].values, original_y.values)

100% (11 of 11) |########################| Elapsed Time: 0:00:00 Time:  0:00:00


In [148]:
# Summary of the grid-searched feature-subset model. In the saved output it has far
# fewer effective DoF than the all-features fit (about 11.9 vs 92.0). It also has
# lower AIC and GCV, and a lower pseudo R-squared (0.525 vs 0.644).
model.summary()

LinearGAM                                                                                                 
Distribution:                        NormalDist Effective DoF:                                      11.852
Link Function:                     IdentityLink Log Likelihood:                                 -3929.1686
Number of Samples:                          442 AIC:                                             7884.0412
                                                AICc:                                            7884.8728
                                                GCV:                                             3040.4484
                                                Scale:                                           2894.0579
                                                Pseudo R-Squared:                                    0.525
Feature Function                  Lambda               Rank         EDoF         P > x        Sig. Code   
s(0)                              [10

In [157]:
# Partial-dependence plots for the grid-searched subset model, one panel per
# selected feature. Titles come from `selected_features`: term i of this model is
# fit on column selected_features[i], not on original_X.columns[i]. Using the
# latter labelled the panels age/sex/bmi/bp/s1/s2, which is wrong.
n_terms = len(selected_features)
fig = make_subplots(rows=1, cols=n_terms, subplot_titles=selected_features)

for i in range(n_terms):
    col = i + 1  # plotly subplot columns are 1-based

    XX = model.generate_X_grid(term=i)
    x = XX[:, i]  # term i is s(i), so column i is the feature being varied

    # With `width` given, partial_dependence yields (estimate, confidence interval);
    # take both from a single call rather than computing it twice.
    y, confidence = model.partial_dependence(term=i, X=XX, width=.95)
    lower, upper = confidence[:, 0], confidence[:, 1]

    fig.add_trace(go.Scatter(x=x, y=y, name='estimate'), row=1, col=col)

    # Invisible upper edge of the confidence band...
    fig.add_trace(
        go.Scatter(x=x, y=upper, line=dict(width=0), name='95% CI upper'),
        row=1, col=col)

    # ...which the lower edge fills up to via fill='tonexty' (previous trace).
    fig.add_trace(
        go.Scatter(x=x, y=lower, fill='tonexty', fillcolor='rgba(100,50,80,0.1)',
                   line=dict(width=0), name='95% CI lower'),
        row=1, col=col)

fig.update_layout(height=400, width=1000, showlegend=False)
fig.show()