# Torch GLRT Confidence Intervals

Use the generalized likelihood ratio test (GLRT) and attribution regularization to calculate confidence intervals (CIs) for the attributions of a Torch model.

This is demonstrated here with an ordinary least squares (OLS) linear regression model.

The GLRT results are compared below with the bootstrap and analytic results; they should not be expected to match them exactly.

In [1]:
import numpy as np
import torch
from bootstrapCoefficients import bootstrapCis
from analyticLinearRegressionCIs import analyticLinearCis
from glrtTorch import glrtTorchCis
from DataGeneration import default_data
from sklearn.linear_model import LinearRegression
from torch_linear import TorchLinear

In [2]:
# Demo dataset from DataGeneration: features and target shared by every model below.
(X, y) = default_data()

## Sklearn model
Bootstrap and analytic confidence intervals for a scikit-learn `LinearRegression` fit.

In [4]:
# Reference OLS fit; scikit-learn's fit() returns the estimator itself, so it can be chained.
LR = LinearRegression().fit(X, y)
print("Coefficients:", LR.coef_)

# Bootstrap CIs: bootstrapCis is handed the estimator class (not the fitted LR),
# presumably so it can refit a fresh model on each of the 1000 resamples.
print("Bootstrapping")
lcb_LR, ucb_LR = bootstrapCis(
    LinearRegression,
    X,
    y,
    alpha=0.05,
    replicates=1000,
)
print("Lower bounds:", lcb_LR, "\nUpper bounds:", ucb_LR)

# Closed-form CIs computed from the already-fitted LR.
print("Analytic solution")
lcb_LR_a, ucb_LR_a = analyticLinearCis(LR, X, y, alpha=0.05)
print("Lower bounds:", lcb_LR_a, "\nUpper bounds:", ucb_LR_a)

Coefficients: [0.99906024 1.02556399]
Bootstrapping
Lower bounds: [0.85987423 0.86049457] 
Upper bounds: [1.13705651 1.17785711]
Analytic solution
Lower bounds: [0.85619486 0.86764879] 
Upper bounds: [1.14192561 1.18347918]


## Torch model
Bootstrap and analytic confidence intervals for the `TorchLinear` model.

In [4]:
# Single source of truth for the learning rate, so the fitted model and every
# bootstrap replicate are guaranteed to be trained with the same setting.
TORCH_LR = 0.003

TL = TorchLinear(lr=TORCH_LR)
TL.fit(X, y)

print("Bootstrapping")
# Slow (~4 min in the recorded run): bootstrapCis receives a factory, presumably
# building and training a new TorchLinear for each of the 1000 replicates.
lcb_TL, ucb_TL = bootstrapCis(lambda: TorchLinear(lr=TORCH_LR), X=X, y=y, alpha=0.05, replicates=1000)
print("Lower bounds:", lcb_TL, "\nUpper bounds:", ucb_TL)

# NOTE(review): in the recorded output these Torch intervals are centred noticeably
# differently from the sklearn ones (coefficient 0 midpoint ~0.94 vs OLS 0.999);
# TorchLinear may not have fully converged at this learning rate — TODO confirm.
print("Analytic solution")
lcb_TL_a, ucb_TL_a = analyticLinearCis(TL, X, y, alpha=0.05)
print("Lower bounds:", lcb_TL_a, "\nUpper bounds:", ucb_TL_a)

Bootstrapping
Lower bounds: [0.80575109 0.88677126] 
Upper bounds: [1.06291306 1.20051801]
Analytic solution
Lower bounds: [0.79458542 0.87648867] 
Upper bounds: [1.08140473 1.18232514]


## GLRT results

In [None]:
# GLRT confidence intervals. glrtTorchCis only receives a model factory, so no
# pre-fitted model is needed here (the earlier refit of TL was unused work).
print("GLRT")

# Grid of 101 log-spaced lambda values over [1e-10, 1e10], passed to the GLRT search
# as search_kwargs['lmbds'] (presumably regularization strengths — confirm in glrtTorch).
glrt_lambdas = np.logspace(-10, 10, 101)

# glrtTorchCis returns four values: the lower/upper bounds, followed by what are
# presumably the per-bound search results — TODO confirm in glrtTorch.
lcb_GLRT, ucb_GLRT, lcb_Results, ucb_Results = glrtTorchCis(
    lambda: TorchLinear(lr=0.003),
    X=X,
    y=y,
    alpha=0.05,
    search_kwargs={'lmbds': glrt_lambdas},
)
print("Lower bounds:", lcb_GLRT, "\nUpper bounds:", ucb_GLRT)

GLRT
