In [1]:
import numpy
import pandas
import pyspark
from sklearn import datasets

from pybda.glm import GLM
from pybda.spark.features import assemble

In [2]:
# Single-node ("local") Spark with explicit 1g driver/executor memory settings.
conf = (pyspark.SparkConf()
         .setMaster("local")
         .set("spark.driver.memory", "1g")
         .set("spark.executor.memory", "1g"))
# getOrCreate() reuses an already-running context/session, so re-running this
# cell in a live kernel no longer fails with
# "Cannot run multiple SparkContexts at once".
sc = pyspark.SparkContext.getOrCreate(conf=conf)
spark = pyspark.sql.SparkSession.builder.config(conf=conf).getOrCreate()

In [4]:
# NOTE(review): datasets.load_boston was removed in scikit-learn 1.2, so this
# cell needs an older scikit-learn to run as written.
boston = datasets.load_boston()
features = list(boston.feature_names)
response = "response"
# column_stack promotes the 1-D target to a column by itself, so the target
# lands as the last column next to the feature matrix.
df = pandas.DataFrame(
    numpy.column_stack((boston.data, boston.target)),
    columns=[*features, response],
)

In [5]:
# Preview the leading rows of the pandas frame: the feature columns followed by "response".
df.head()

Unnamed: 0,CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PTRATIO,B,LSTAT,response
0,0.00632,18.0,2.31,0.0,0.538,6.575,65.2,4.09,1.0,296.0,15.3,396.9,4.98,24.0
1,0.02731,0.0,7.07,0.0,0.469,6.421,78.9,4.9671,2.0,242.0,17.8,396.9,9.14,21.6
2,0.02729,0.0,7.07,0.0,0.469,7.185,61.1,4.9671,2.0,242.0,17.8,392.83,4.03,34.7
3,0.03237,0.0,2.18,0.0,0.458,6.998,45.8,6.0622,3.0,222.0,18.7,394.63,2.94,33.4
4,0.06905,0.0,2.18,0.0,0.458,7.147,54.2,6.0622,3.0,222.0,18.7,396.9,5.33,36.2


In [6]:
# Convert the pandas frame to a Spark DataFrame and pack the feature columns
# into a single vector column in one step.
# NOTE(review): the output of the next cell shows only `response` and `features`,
# so the trailing True presumably drops the original feature columns — confirm
# against pybda.spark.features.assemble.
spark_df = assemble(spark.createDataFrame(df), features, True)

In [7]:
# Fetch five rows to check the assembled layout (response plus a features vector).
spark_df.take(5)

[Row(response=24.0, features=DenseVector([0.0063, 18.0, 2.31, 0.0, 0.538, 6.575, 65.2, 4.09, 1.0, 296.0, 15.3, 396.9, 4.98])),
 Row(response=21.6, features=DenseVector([0.0273, 0.0, 7.07, 0.0, 0.469, 6.421, 78.9, 4.9671, 2.0, 242.0, 17.8, 396.9, 9.14])),
 Row(response=34.7, features=DenseVector([0.0273, 0.0, 7.07, 0.0, 0.469, 7.185, 61.1, 4.9671, 2.0, 242.0, 17.8, 392.83, 4.03])),
 Row(response=33.4, features=DenseVector([0.0324, 0.0, 2.18, 0.0, 0.458, 6.998, 45.8, 6.0622, 3.0, 222.0, 18.7, 394.63, 2.94])),
 Row(response=36.2, features=DenseVector([0.069, 0.0, 2.18, 0.0, 0.458, 7.147, 54.2, 6.0622, 3.0, 222.0, 18.7, 396.9, 5.33]))]

In [8]:
# Set up a GLM on the response/feature column names and fit it to the assembled Spark frame.
# NOTE(review): no family or link is passed, so whatever pybda's GLM defaults to applies — confirm.
model = GLM(spark, response, features)
fit = model.fit(spark_df)

In [9]:
# p-values exposed by the fit; 14 values are shown for 13 features, so one is
# presumably the intercept — TODO confirm which position it occupies.
fit.p_values

array([1.08681010e-03, 7.78109688e-04, 7.38288071e-01, 1.92503033e-03,
       4.24564381e-06, 0.00000000e+00, 9.58229309e-01, 6.01296790e-13,
       5.07052902e-06, 1.11163672e-03, 1.30873090e-12, 5.72859167e-04,
       0.00000000e+00, 3.28337357e-12])

In [10]:
# Goodness-of-fit value the fit object exposes as `r2`.
fit.r2

0.7406426641094093

In [11]:
# Shut down the Spark session now that the analysis is finished.
spark.stop()

In [12]:
# Stop the SparkContext from the setup cell.
# NOTE(review): spark.stop() in the previous cell likely already stopped it; this looks redundant — confirm.
sc.stop()