In [1]:
# Packages used below: caret (model training), base64enc (payload encoding),
# httr (HTTP client) and mlbench (example data). The `quietly` argument is
# spelled out in full instead of relying on partial matching of `quiet`.
library(caret, quietly = TRUE)
library(base64enc)
library(httr, quietly = TRUE)

library(mlbench)


Attaching package: ‘httr’

The following object is masked from ‘package:caret’:

    progress



# Build a Model

In [2]:
## https://www.machinelearningplus.com/machine-learning/caret-package/
## https://en.wikipedia.org/wiki/Multivariate_adaptive_regression_splines

# Load the Boston housing data and coerce `chas` to numeric.
# NOTE(review): if `chas` is stored as a factor, as.numeric() returns its
# integer level codes rather than the level labels -- confirm this is intended.
data(BostonHousing)
BostonHousing[["chas"]] <- as.numeric(BostonHousing[["chas"]])

set.seed(1960)

# Predictor-only view: everything except column 14 (`medv`, the response).
dataset <- BostonHousing[, -14]

#' Fit a MARS ("earth") regression of `medv` on all other columns via caret.
#'
#' The tuning values are fixed (`degree = 2`, `nprune = 20`) and the training
#' control uses `method = "none"`, so exactly one model is fitted with these
#' values rather than resampling over a grid.
#'
#' @param data Data frame holding the response `medv` and the predictors.
#'   Defaults to the global `BostonHousing`, i.e. the data the function always
#'   used before this parameter existed.
#' @return The fitted object returned by `train()`.
create_model <- function(data = BostonHousing) {
    grid <- data.frame(degree = 2, nprune = 20)
    ctrl <- trainControl(method = "none")
    train(medv ~ ., data = data, method = "earth",
          trControl = ctrl, tuneGrid = grid)
}


In [3]:
# Fit the model once; the following cells all reuse `model`.
model <- create_model()
# cat(model$feature_names)
# print(model)

Loading required package: earth
Loading required package: plotmo
Loading required package: plotrix
Loading required package: TeachingDemos


In [4]:
# In-sample check: predict on the training predictors and compare the result
# with the observed `medv` (column 14).
pred_labels <- predict(model, BostonHousing[, -14], type = "raw")
df <- data.frame(medv = BostonHousing[, 14])
# The prediction appears to come back as a one-column matrix (summary() printed
# "Estimator.y" / "Error.y" headers); flatten it so the data frame holds plain
# numeric columns.
df$Estimator <- as.numeric(pred_labels)
df$Error <- df$Estimator - df$medv
# Mean absolute percentage error, expressed as a fraction (not scaled by 100).
MAPE <- mean(abs(df$Error / df$medv))
summary(df)
MAPE

      medv          Estimator.y           Error.y       
 Min.   : 5.00   Min.   : 1.73657   Min.   :-10.929540  
 1st Qu.:17.02   1st Qu.:17.09398   1st Qu.: -1.494189  
 Median :21.20   Median :21.29315   Median :  0.228028  
 Mean   :22.53   Mean   :22.53281   Mean   :  0.000000  
 3rd Qu.:25.00   3rd Qu.:25.21287   3rd Qu.:  1.647407  
 Max.   :50.00   Max.   :62.31410   Max.   : 12.314102  

# SQL Code Generation

In [5]:

#' Ask the sklearn2sql web service to translate a fitted model into SQL.
#'
#' The model is serialized with `serialize()`, base64-encoded and POSTed as a
#' JSON document requesting the "postgresql" dialect in "caret" mode.
#'
#' @param mod Fitted model object to translate.
#' @param ws_url Service endpoint. Defaults to the local instance, which is
#'   the URL that was effectively used before this parameter existed; the
#'   hosted alternative is "https://sklearn2sql.herokuapp.com/model".
#' @return The SQL text found in the first `SQLGenrationResult` entry of the
#'   response. Raises an error on an HTTP failure or when no SQL is returned.
test_ws_sql_gen <- function(mod, ws_url = "http://localhost:1888/model") {
    model_serialized <- serialize(mod, NULL)
    b64_data <- base64encode(model_serialized)
    # NOTE(review): the name says "xgboost" although this notebook sends an
    # earth model; presumably it is only a label -- confirm before renaming.
    payload <- list(Name = "xgboost_test_model", SerializedModel = b64_data,
                    SQLDialect = "postgresql", Mode = "caret")
    response <- POST(ws_url, body = payload, encode = "json")
    # Fail loudly on HTTP errors instead of silently returning NULL below.
    stop_for_status(response)
    parsed <- content(response)
    # "SQLGenrationResult" (sic) is the key spelling used in the response.
    results <- parsed$model$SQLGenrationResult
    if (length(results) == 0 || is.null(results[[1]]$SQL)) {
        stop("Web service response contains no generated SQL", call. = FALSE)
    }
    results[[1]]$SQL
}

In [6]:
# Request the SQL translation of the fitted model and print it verbatim.
lModelSQL <- test_ws_sql_gen(model)
cat(lModelSQL)


WITH earth_input AS 
(SELECT "ADS"."KEY" AS "KEY", CAST("ADS"."Feature_0" AS FLOAT) AS "Feature_0", CAST("ADS"."Feature_1" AS FLOAT) AS "Feature_1", CAST("ADS"."Feature_2" AS FLOAT) AS "Feature_2", CAST("ADS"."Feature_3" AS FLOAT) AS "Feature_3", CAST("ADS"."Feature_4" AS FLOAT) AS "Feature_4", CAST("ADS"."Feature_5" AS FLOAT) AS "Feature_5", CAST("ADS"."Feature_6" AS FLOAT) AS "Feature_6", CAST("ADS"."Feature_7" AS FLOAT) AS "Feature_7", CAST("ADS"."Feature_8" AS FLOAT) AS "Feature_8", CAST("ADS"."Feature_9" AS FLOAT) AS "Feature_9", CAST("ADS"."Feature_10" AS FLOAT) AS "Feature_10", CAST("ADS"."Feature_11" AS FLOAT) AS "Feature_11", CAST("ADS"."Feature_12" AS FLOAT) AS "Feature_12" 
FROM "INPUT_DATA" AS "ADS"), 
earth_model_cte AS 
(SELECT earth_input."KEY" AS "KEY", -0.7769068137603704 * greatest(earth_input."Feature_12" - 6.12, 0) + 10.58947737074077 * greatest(earth_input."Feature_5" - 6.431, 0) + 29.2750852369069 * greatest(6.431 - earth_input."Feature_5", 0) + -7.262109920559491

# Execute the SQL Code

In [7]:
# Open the ODBC connection on which the generated SQL is executed, then
# enable auto-commit for it.
library(RODBC)
conn <- odbcConnect(dsn = "pgsql", uid = "db", pwd = "db", case = "nochange")
odbcSetAutoCommit(channel = conn, autoCommit = TRUE)

In [8]:
# Stage the predictors in the database under the generic column names used by
# the generated SQL ("Feature_0", "Feature_1", ...) plus a 1-based row KEY.
df_sql <- dataset
# seq_len() yields an empty sequence for zero columns/rows, whereas
# 0:(n - 1) and seq.int(n) would count backwards.
names(df_sql) <- sprintf("Feature_%d", seq_len(ncol(df_sql)) - 1L)
df_sql$KEY <- seq_len(nrow(dataset))

# Replace any table left over from a previous run; errors = FALSE keeps a
# failed drop (e.g. the table does not exist yet) from stopping the run.
sqlDrop(conn, "INPUT_DATA", errors = FALSE)
sqlSave(conn, df_sql, tablename = "INPUT_DATA", verbose = FALSE)

# df_sql

In [9]:
# Display the column names of the staged table (Feature_* columns plus KEY).
colnames(df_sql)
# odbcGetInfo(conn)
# sqlTables(conn)

In [10]:
# Execute the generated SQL in the database and preview the result by KEY.
df_sql_out <- sqlQuery(conn, lModelSQL)
# On failure sqlQuery() returns error messages (or -1) instead of a data
# frame; stop with those messages rather than failing later on `$KEY`.
if (!is.data.frame(df_sql_out)) {
    stop("SQL execution failed: ", paste(df_sql_out, collapse = "\n"),
         call. = FALSE)
}
head(df_sql_out[order(df_sql_out$KEY), ])

KEY,Estimator
1,25.83124
2,22.26612
3,34.64964
4,35.18817
5,31.93754
6,25.13966


In [11]:
# df_sql_out

# R Earth (MARS) Output

In [12]:
# Score the same predictors in R and key the rows by position, matching the
# KEY column written to the SQL input table.
estimator <- predict(model, dataset, type = "raw")
df_r_out <- data.frame(estimator)
names(df_r_out) <- "Estimator"

# seq_len() is also correct for an empty data set, where seq.int(0) would
# yield c(1, 0).
df_r_out$KEY <- seq_len(nrow(dataset))
head(df_r_out)


Estimator,KEY
25.83124,1
22.26612,2
34.64964,3
35.18817,4
31.93754,5
25.13966,6


# Compare R and SQL output

In [13]:
# Full outer join of the R and SQL predictions on KEY. Both inputs carry an
# "Estimator" column, so the suffixes turn them into Estimator_1 (R) and
# Estimator_2 (SQL).
df_merge <- merge(x = df_r_out, y = df_sql_out, by = "KEY", all = TRUE,
                  suffixes = c("_1", "_2"))
head(df_merge)

KEY,Estimator_1,Estimator_2
1,25.83124,25.83124
2,22.26612,22.26612
3,34.64964,34.64964
4,35.18817,35.18817
5,31.93754,31.93754
6,25.13966,25.13966


In [14]:
# Signed and absolute difference between the R and the SQL prediction.
prediction_gap <- with(df_merge, Estimator_1 - Estimator_2)
df_merge[["Error"]] <- prediction_gap
df_merge[["AbsError"]] <- abs(prediction_gap)
head(df_merge)


KEY,Estimator_1,Estimator_2,Error,AbsError
1,25.83124,25.83124,-1.065814e-14,1.065814e-14
2,22.26612,22.26612,-1.065814e-14,1.065814e-14
3,34.64964,34.64964,-7.105427e-15,7.105427e-15
4,35.18817,35.18817,-7.105427e-15,7.105427e-15
5,31.93754,31.93754,-7.105427e-15,7.105427e-15
6,25.13966,25.13966,-3.552714e-15,3.552714e-15


In [15]:
# Rows where R and SQL disagree by more than 1e-4, plus rows whose error is NA
# (a KEY present on only one side of the outer join). Testing is.na()
# explicitly keeps those rows intact; an NA in the logical index would instead
# produce all-NA placeholder rows.
is_mismatch <- is.na(df_merge$AbsError) | df_merge$AbsError > 0.0001
df_merge_largest_errors <- df_merge[is_mismatch, ]
df_merge_largest_errors

KEY,Estimator_1,Estimator_2,Error,AbsError


In [16]:
# Show the number of mismatching rows, then abort if there is any.
nrow(df_merge_largest_errors)
stopifnot(nrow(df_merge_largest_errors) == 0)


In [17]:
# Summary statistics of the predictions computed by the database.
summary(df_sql_out)

      KEY          Estimator     
 Min.   :  1.0   Min.   : 1.737  
 1st Qu.:127.2   1st Qu.:17.094  
 Median :253.5   Median :21.293  
 Mean   :253.5   Mean   :22.533  
 3rd Qu.:379.8   3rd Qu.:25.213  
 Max.   :506.0   Max.   :62.314  

In [18]:
# Summary statistics of the predictions computed in R.
summary(df_r_out)

   Estimator           KEY       
 Min.   : 1.737   Min.   :  1.0  
 1st Qu.:17.094   1st Qu.:127.2  
 Median :21.293   Median :253.5  
 Mean   :22.533   Mean   :253.5  
 3rd Qu.:25.213   3rd Qu.:379.8  
 Max.   :62.314   Max.   :506.0  

In [19]:
# Summary statistics of the joined table, including Error and AbsError.
summary(df_merge)

      KEY         Estimator_1      Estimator_2         Error           
 Min.   :  1.0   Min.   : 1.737   Min.   : 1.737   Min.   :-5.684e-14  
 1st Qu.:127.2   1st Qu.:17.094   1st Qu.:17.094   1st Qu.:-7.105e-15  
 Median :253.5   Median :21.293   Median :21.293   Median :-7.105e-15  
 Mean   :253.5   Mean   :22.533   Mean   :22.533   Mean   :-7.504e-17  
 3rd Qu.:379.8   3rd Qu.:25.213   3rd Qu.:25.213   3rd Qu.: 0.000e+00  
 Max.   :506.0   Max.   :62.314   Max.   :62.314   Max.   : 1.670e-13  
    AbsError        
 Min.   :0.000e+00  
 1st Qu.:5.329e-15  
 Median :7.105e-15  
 Mean   :1.193e-14  
 3rd Qu.:1.421e-14  
 Max.   :1.670e-13  

In [20]:
# Print the underlying fitted model held inside the caret train object.
model$finalModel

Selected 20 of 27 terms, and 9 of 13 predictors
Termination condition: Reached nk 27
Importance: rm, lstat, ptratio, tax, dis, nox, crim, age, b, zn-unused, ...
Number of terms at each degree of interaction: 1 7 12
GCV 9.015343    RSS 3729.185    GRSq 0.8936296    RSq 0.9126988

In [21]:
# Show the model metadata stored on the train object (e.g. tuning parameters).
model$modelInfo

parameter,class,label
nprune,numeric,#Terms
degree,numeric,Product Degree


In [22]:
# Keep a handle on the underlying fitted model for the inspections below.
earth1 <- model$finalModel

In [23]:
# Coefficient of each selected term (intercept and hinge-function products).
earth1$coefficients

Unnamed: 0,y
(Intercept),25.57688
h(lstat-6.12),-0.7769068
h(rm-6.431),10.58948
h(6.431-rm),29.27509
h(rm-6.431)*h(ptratio-18.6),-7.26211
h(rm-6.431)*h(18.6-ptratio),0.4992664
h(tax-305)*h(6.12-lstat),0.01732757
h(305-tax)*h(6.12-lstat),0.02197435
h(0.713-nox)*h(lstat-6.12),2.387959
h(6.431-rm)*h(dis-1.8209),66.81135


In [24]:
# Basis matrix: one column per selected term.
# NOTE(review): rows presumably correspond to training observations -- confirm.
earth1$bx

(Intercept),h(lstat-6.12),h(rm-6.431),h(6.431-rm),h(rm-6.431)*h(ptratio-18.6),h(rm-6.431)*h(18.6-ptratio),h(tax-305)*h(6.12-lstat),h(305-tax)*h(6.12-lstat),h(0.713-nox)*h(lstat-6.12),h(6.431-rm)*h(dis-1.8209),h(6.431-rm)*h(1.8209-dis),h(crim-4.42228),h(4.42228-crim),h(dis-1.3567),h(1.3567-dis),h(6.43-rm)*h(dis-1.3567),h(6.431-rm)*h(lstat-19.31),h(4.42228-crim)*h(224-tax),h(98.4-age)*h(b-240.16),h(98.4-age)*h(240.16-b)
1,0.00,0.144,0.000,0.0000,0.4752,0,10.26,0.00000,0.0000000,0,0,4.41596,2.7333,0,0.0000000,0.00000,0.00000,5203.768,0
1,3.02,0.000,0.010,0.0000,0.0000,0,0.00,0.73688,0.0314620,0,0,4.39497,3.6104,0,0.0324936,0.00000,0.00000,3056.430,0
1,0.00,0.754,0.000,0.0000,0.6032,0,131.67,0.00000,0.0000000,0,0,4.39499,3.6104,0,0.0000000,0.00000,0.00000,5694.591,0
1,0.00,0.567,0.000,0.0567,0.0000,0,263.94,0.00000,0.0000000,0,0,4.38991,4.7055,0,0.0000000,0.00000,8.77982,8125.122,0
1,0.00,0.716,0.000,0.0716,0.0000,0,65.57,0.00000,0.0000000,0,0,4.35323,4.7055,0,0.0000000,0.00000,8.70646,6927.908,0
1,0.00,0.000,0.001,0.0000,0.0000,0,75.53,0.00000,0.0042413,0,0,4.39243,4.7055,0,0.0000000,0.00000,8.78486,6112.212,0
1,6.31,0.000,0.419,0.0000,0.0000,0,0.00,1.19259,1.5668924,0,0,4.33399,4.2038,0,1.7571884,0.00000,0.00000,4942.992,0
1,13.03,0.000,0.259,0.0000,0.0000,0,0.00,2.46267,1.0695664,0,0,4.27773,4.5938,0,1.1852004,0.00000,0.00000,360.502,0
1,23.81,0.000,0.800,0.0000,0.0000,0,0.00,4.50009,3.4089600,0,0,4.21104,4.7254,0,3.7755946,8.49600,0.00000,0.000,0
1,10.98,0.000,0.427,0.0000,0.0000,0,0.00,2.07522,2.0373024,0,0,4.25224,5.2354,0,2.2302804,0.00000,0.00000,1831.875,0


In [25]:
# Term-by-predictor table of cut points.
# NOTE(review): presumably the hinge knots, with 0 where a predictor is not
# part of the term -- confirm.
earth1$cuts

Unnamed: 0,crim,zn,indus,chas,nox,rm,age,dis,rad,tax,ptratio,b,lstat
(Intercept),0.0,0,0,0,0.0,0.0,0.0,0.0,0,0,0.0,0.0,0.0
h(lstat-6.12),0.0,0,0,0,0.0,0.0,0.0,0.0,0,0,0.0,0.0,6.12
h(6.12-lstat),0.0,0,0,0,0.0,0.0,0.0,0.0,0,0,0.0,0.0,6.12
h(rm-6.431),0.0,0,0,0,0.0,6.431,0.0,0.0,0,0,0.0,0.0,0.0
h(6.431-rm),0.0,0,0,0,0.0,6.431,0.0,0.0,0,0,0.0,0.0,0.0
h(rm-6.431)*h(ptratio-18.6),0.0,0,0,0,0.0,6.431,0.0,0.0,0,0,18.6,0.0,0.0
h(rm-6.431)*h(18.6-ptratio),0.0,0,0,0,0.0,6.431,0.0,0.0,0,0,18.6,0.0,0.0
h(tax-305)*h(6.12-lstat),0.0,0,0,0,0.0,0.0,0.0,0.0,0,305,0.0,0.0,6.12
h(305-tax)*h(6.12-lstat),0.0,0,0,0,0.0,0.0,0.0,0.0,0,305,0.0,0.0,6.12
h(nox-0.713)*h(lstat-6.12),0.0,0,0,0,0.713,0.0,0.0,0.0,0,0,0.0,0.0,6.12


In [26]:
# Term-by-predictor table of hinge directions.
# NOTE(review): 1 / -1 appear to mark h(x - cut) / h(cut - x), with 0 where a
# predictor is not part of the term -- confirm.
earth1$dirs

Unnamed: 0,crim,zn,indus,chas,nox,rm,age,dis,rad,tax,ptratio,b,lstat
(Intercept),0,0,0,0,0,0,0,0,0,0,0,0,0
h(lstat-6.12),0,0,0,0,0,0,0,0,0,0,0,0,1
h(6.12-lstat),0,0,0,0,0,0,0,0,0,0,0,0,-1
h(rm-6.431),0,0,0,0,0,1,0,0,0,0,0,0,0
h(6.431-rm),0,0,0,0,0,-1,0,0,0,0,0,0,0
h(rm-6.431)*h(ptratio-18.6),0,0,0,0,0,1,0,0,0,0,1,0,0
h(rm-6.431)*h(18.6-ptratio),0,0,0,0,0,1,0,0,0,0,-1,0,0
h(tax-305)*h(6.12-lstat),0,0,0,0,0,0,0,0,0,1,0,0,-1
h(305-tax)*h(6.12-lstat),0,0,0,0,0,0,0,0,0,-1,0,0,-1
h(nox-0.713)*h(lstat-6.12),0,0,0,0,1,0,0,0,0,0,0,0,1
