# 6.2 Staging Data

## S3 Credentials

In [3]:
import os

# Credentials come from the environment so they never appear in the notebook.
# The legacy s3n connector and the s3a connector read different Hadoop keys;
# later cells read and write s3a:// paths, so configure both sets of keys.
hadoop_conf = sc._jsc.hadoopConfiguration()
access_key = os.environ['AWS_ACCESS_KEY']
secret_key = os.environ['AWS_SECRET_KEY']

hadoop_conf.set("fs.s3n.awsAccessKeyId", access_key)
hadoop_conf.set("fs.s3n.awsSecretAccessKey", secret_key)
hadoop_conf.set("fs.s3a.access.key", access_key)
hadoop_conf.set("fs.s3a.secret.key", secret_key)

In [4]:
# Expanded games data: ten binary game-play flags (G1..G10) plus a label.
games_csv_path = "s3://dsp-ch6-00/csv/games-expand.csv"
games_df = spark.read.csv(games_csv_path, header=True, inferSchema=True)
display(games_df)

G1,G2,G3,G4,G5,G6,G7,G8,G9,G10,label
0,0,0,1,0,0,0,0,0,0,0
0,0,0,0,1,0,0,0,0,0,0
0,0,1,0,0,0,0,0,0,0,0
0,0,1,0,0,1,1,0,0,1,1
0,0,1,0,1,1,0,1,1,0,1
1,0,1,0,1,0,0,0,0,0,0
0,0,1,0,0,0,0,0,0,0,0
1,0,1,0,1,0,0,0,0,0,0
1,1,0,1,0,1,1,1,0,0,0
0,0,1,0,0,0,0,0,0,0,0


# 6.3 A PySpark Primer

In [6]:
# Per-game skater statistics, used throughout the PySpark primer below.
skater_stats_path = "s3://dsp-ch6-00/csv/game_skater_stats.csv"
stats_df = spark.read.csv(skater_stats_path, header=True, inferSchema=True)
display(stats_df)

## Persisting Dataframes

In [8]:
# Avro is a row-oriented format that is good for streaming; it is used here
# mainly as an introduction to the different cloud file formats.
# Spark 2.4+ ships a built-in "avro" data source (the final cell of this
# notebook already uses it), so the legacy "com.databricks.spark.avro" name
# is not needed.

# Avro write
avro_path = "s3://dsp-ch6-00/avro/game_skater_stats/"
stats_df.write.mode('overwrite').format("avro").save(avro_path)

# Avro read
avro_df = spark.read.format("avro").load(avro_path)
display(avro_df)

In [9]:
# Parquet (columnar): write the Avro-loaded frame, then read it back.
parquet_path = "s3a://dsp-ch6-00/games-parquet/"
parquet_writer = avro_df.write.mode('overwrite')
parquet_writer.parquet(parquet_path)

parquet_df = spark.read.parquet(parquet_path)
display(parquet_df)

In [10]:
# ORC (columnar): write the Parquet-loaded frame, then read it back.
orc_path = "s3a://dsp-ch6-00/games-orc/"
orc_writer = parquet_df.write.mode('overwrite')
orc_writer.orc(orc_path)

orc_df = spark.read.orc(orc_path)
display(orc_df)

In [11]:
# CSV: coalesce(1) so the output folder holds a single part file with a
# header row, then read it back to finish the format round trip.
csv_path = "s3a://dsp-ch6-00/games-csv-out/"
single_file_df = orc_df.coalesce(1)
single_file_df.write.mode('overwrite').option("header", "true").csv(csv_path)

csv_df = spark.read.csv(path=csv_path, header=True, inferSchema=True)
display(csv_df)

## Converting Data Frames

In [13]:
# Collect the Spark frame onto the driver as a pandas DataFrame
# (toPandas() brings every row to the driver).
stats_pd = stats_df.toPandas()
display(stats_pd)

In [14]:
# ... and distribute the pandas frame back out as a Spark DataFrame.
stats_df = spark.createDataFrame(stats_pd)

In [15]:
import databricks.koalas as ks

# After importing koalas, Spark DataFrames gain `to_koalas()`: convert to a
# koalas frame (pandas-style API executed by Spark) and back again.
stats_ks = stats_df.to_koalas()
stats_df = stats_ks.to_spark()

# pandas-style operations on the koalas frame
mean_time_on_ice = stats_ks['timeOnIce'].mean()
print(mean_time_on_ice)
first_player_id = stats_ks.iloc[:1, 1:2]
print(first_player_id)

# Expected output:
# 993.6149113898216
#
#    player_id
# 0    8467412


In [16]:
# Round trip 1: spark -> koalas -> pandas -> spark
via_koalas_pd = stats_df.to_koalas().toPandas()
stats_df = spark.createDataFrame(via_koalas_pd)

# Round trip 2: spark -> pandas -> koalas -> spark
via_pandas_ks = ks.from_pandas(stats_df.toPandas())
stats_df = via_pandas_ks.to_spark()

## Transforming Data

In [18]:
# Re-read the raw skater stats and expose them to Spark SQL as `stats`.
stats_df = spark.read.csv("s3://dsp-ch6-00/csv/game_skater_stats.csv",
                          header=True, inferSchema=True)
stats_df.createOrReplaceTempView("stats")

# Summary stats: the five highest goal scorers and how many games each played.
top_scorers_query = """
  SELECT
      player_id,
      COUNT(game_id) AS games,
      SUM(goals) AS goals
  FROM
      stats
  GROUP BY
      player_id
  ORDER BY
      goals DESC
  LIMIT 5
"""
new_df = spark.sql(top_scorers_query)
display(new_df)

In [19]:
# Distribution of goals per shot, bucketed to 0.02 (CAST(x * 50 AS INT) / 50.0),
# over players whose total goals in `stats` is at least five.
goals_per_shot_query = """
  SELECT
      CAST(goals/shots * 50 AS INT)/50.0 AS goals_per_shot,
      COUNT(player_id) AS players 
  FROM (
    SELECT
        player_id, 
        SUM(shots) AS shots, 
        SUM(goals) AS goals
    FROM
        stats
    GROUP BY
        player_id
    HAVING
      goals >= 5
  )  
  GROUP BY
      goals_per_shot
  ORDER BY
      goals_per_shot
"""
goals_per_shot_df = spark.sql(goals_per_shot_query)
display(goals_per_shot_df)

In [20]:
from pyspark.sql.functions import lit

# Basic column operations, one step at a time.
without_ids_df = stats_df.drop('game_id', 'player_id')            # dropping columns
selected_df = without_ids_df.select('assists', 'goals', 'shots')  # selecting columns
copy_df = selected_df.withColumn("league", lit('NHL'))            # adding a constant column

display(copy_df)

In [21]:
# Join: tag each (game_id, player_id) key with a constant league column and
# inner-join it back onto the full stats using that composite key.
league_keys_df = stats_df.select('game_id', 'player_id').withColumn("league", lit('NHL'))
df = league_keys_df.join(stats_df, on=['game_id', 'player_id'], how='inner')
display(df)

In [22]:
# Group by player: average time on ice and total goals.
aggregations = {'timeOnIce': 'avg', 'goals': 'sum'}
summary_df = stats_df.groupby("player_id").agg(aggregations)
display(summary_df)

## Pandas UDFs

In [24]:
# Pull every row for one player (player_id 8471214) onto the driver.
single_player_query = """
  SELECT
      *
  FROM
      stats
  WHERE
      player_id = 8471214
"""
sample_pd = spark.sql(single_player_query).toPandas()

# Curve-fitting libraries
from scipy.optimize import leastsq
import numpy as np

def fit(params, x, y):
    """Residuals of the linear model ``y ~ params[0] + x * params[1]``.

    Used as the objective for ``scipy.optimize.leastsq``, which minimizes
    the sum of squares of the returned residuals.

    Args:
        params: sequence ``(intercept, slope)``.
        x: predictor values (array-like supporting vectorized arithmetic).
        y: observed responses, aligned with ``x``.

    Returns:
        ``y - (intercept + x * slope)``, element-wise.
    """
    intercept = params[0]
    slope = params[1]
    predicted = intercept + x * slope
    return y - predicted
  
# Fit hits as a linear function of shots, starting from intercept=1, slope=0.
initial_params = [1, 0]
result = leastsq(fit, initial_params, args=(sample_pd['shots'], sample_pd['hits']))
print(result)

In [25]:
# Load necessary libraries
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import *
import pandas as pd

# Schema of the frame returned per player: the id plus the two fitted
# parameters (p0 = intercept, p1 = slope of hits against shots).
schema = StructType([StructField('id', LongType(), True),
                     StructField('p0', DoubleType(), True),
                     StructField('p1', DoubleType(), True)])

# Grouped-map UDF: input and output are pandas DataFrames.
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def analyze_player(sample_pd):
    """Fit ``hits ~ p0 + p1 * shots`` over one player's rows.

    Args:
        sample_pd: pandas DataFrame with all rows of a single player_id.

    Returns:
        One-row pandas DataFrame with columns ``id``, ``p0``, ``p1``.
        Players with at most one row get ``p0 = p1 = 0.0`` instead of a fit.
    """
    # Positional access, so it does not depend on the group's index labels.
    player_id = sample_pd['player_id'].iloc[0]

    # Return zero params if there is not enough data to fit two parameters.
    # Use float zeros so the values match the DoubleType schema columns.
    if len(sample_pd) <= 1:
        return pd.DataFrame({'id': [player_id], 'p0': [0.0], 'p1': [0.0]})

    # Perform curve fitting; `fit` is the residual function defined earlier.
    # Without full_output, leastsq returns (solution, status flag).
    params, _ = leastsq(fit, [1, 0],
                        args=(sample_pd['shots'], sample_pd['hits']))

    # Return the parameters as a pandas DataFrame
    return pd.DataFrame({'id': [player_id],
                         'p0': [float(params[0])], 'p1': [float(params[1])]})

# Apply the UDF once per player and show the results
player_df = stats_df.groupby('player_id').apply(analyze_player)
display(player_df)

# 6.4 MLlib Batch Pipeline

In [27]:
# Reload the games data and register it for SQL.
games_df = spark.read.csv("s3://dsp-ch6-00/csv/games-expand.csv",
                          header=True, inferSchema=True)
games_df.createOrReplaceTempView("games")

# Add a synthetic user_id (row number in a seeded random order) and a seeded
# test flag (~30% of rows). The window has no PARTITION BY, so Spark computes
# ROW_NUMBER over a single partition.
games_with_ids_query = """
    SELECT
        *,
        ROW_NUMBER() OVER (ORDER BY RAND(42)) AS user_id,
        CASE WHEN RAND(42) > 0.7 THEN 1 ELSE 0 END AS test
    FROM
        games
"""
games_df = spark.sql(games_with_ids_query)
display(games_df)

In [28]:
# Alternative split that uses the `test` flag added in the previous cell.
# Kept for reference only; the next cell splits with randomSplit instead.
# train_df = games_df.filter("test == 0")
# test_df = games_df.filter("test == 1")
# print("Train", train_df.count())
# print("Test", test_df.count())

In [29]:
# Seeded 80/20 train/test split.
train_df, test_df = games_df.randomSplit([0.8, 0.2], seed=42)
for split_name, split_df in (("Train", train_df), ("Test", test_df)):
    print(split_name, split_df.count())

## Vector Columns

In [31]:
# The first ten columns (G1..G10) are the model inputs.
train_df.columns[:10]

In [32]:
from pyspark.ml.feature import VectorAssembler

# Spark ML estimators take a single vector column, so pack the ten feature
# columns into `features`.
feature_cols = train_df.schema.names[:10]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

train_vec = assembler.transform(train_df).select('label', 'features')
# user_id stays on the test set so predictions can be traced back to users
test_vec = assembler.transform(test_df).select('label', 'features', 'user_id')

display(test_vec)

## Model Application

In [34]:
from pyspark.ml.classification import LogisticRegression

# Logistic regression on the assembled feature vector.
lr = LogisticRegression(featuresCol='features', labelCol='label')
model = lr.fit(train_vec)            # fit on training data
pred_df = model.transform(test_vec)  # score the held-out test data
display(pred_df)

In [35]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Evaluator defaults: area under ROC, computed from `rawPrediction` and `label`.
evaluator = BinaryClassificationEvaluator()
roc = evaluator.evaluate(pred_df)
print(roc)

In [36]:
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

def _second_element(vector):
    # For this binary label, index 1 of `probability` is P(label = 1).
    return float(vector[1])

get_second_element = udf(_second_element, FloatType())

# Add the positive-class probability as a scalar `propensity` column.
propensity_col = get_second_element("probability").alias("propensity")
pred_df = pred_df.select("*", propensity_col)
display(pred_df)

In [37]:
# Persist only user ids and propensity scores to S3 as Parquet, then read
# them back to confirm the write.
results_path = "s3a://dsp-ch6-00/game-predictions/"
results_df = pred_df.select("user_id", "propensity")
results_df.write.mode('overwrite').parquet(results_path)

display(spark.read.parquet(results_path))

In [38]:
# Bucket propensity scores into 0.01-wide bins and count users per
# (bucket, label); the resulting table can be charted via display().
pred_df.createOrReplaceTempView("pred_df")

# GROUP BY / ORDER BY use column ordinals (1 = propensity bucket, 2 = label),
# which relies on Spark's default spark.sql.groupByOrdinal and
# spark.sql.orderByOrdinal settings (both true).
plot_df = spark.sql("""
  SELECT
      CAST(propensity*100 AS INT)/100 AS propensity,
      label,
      COUNT(user_id) AS users
  FROM
      pred_df
  GROUP BY
      1, 2
  ORDER BY
      1, 2
""")

# table output
display(plot_df)

# 6.5 Distributed Deep Learning

## Model Training

In [41]:
from sklearn.model_selection import train_test_split

# Build the Keras model on the driver node: collect the Spark training split.
train_pd = train_df.toPandas()
train_features_pd = train_pd.iloc[:, :10]  # G1..G10
train_labels_pd = train_pd['label']

# Stratified 67/33 split into (train + validation) and a local test set ...
x_train_val, x_test, y_train_val, y_test = train_test_split(
    train_features_pd, train_labels_pd,
    test_size=0.33, random_state=42, stratify=train_labels_pd,
)

# ... then a stratified 20% validation set carved from the remainder.
x_train, x_val, y_train, y_val = train_test_split(
    x_train_val, y_train_val,
    test_size=0.2, random_state=42, stratify=y_train_val,
)

In [42]:
import tensorflow as tf
import tensorflow.keras as K

def get_model():
    """Build the (uncompiled) feed-forward binary classifier.

    Ten inputs, three 64-unit ReLU layers (dropout 0.1 after the first two)
    and a single sigmoid output unit.
    """
    layers = [
        K.layers.Dense(64, activation='relu', input_shape=(10,)),
        K.layers.Dropout(0.1),
        K.layers.Dense(64, activation='relu'),
        K.layers.Dropout(0.1),
        K.layers.Dense(64, activation='relu'),
        K.layers.Dense(1, activation='sigmoid'),
    ]
    return K.models.Sequential(layers)

# Define the network, then compile and fit it, tracking AUC on the
# validation set each epoch.
model = get_model()
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=[K.metrics.AUC(name='auc')])
history = model.fit(x_train, y_train,
                    validation_data=(x_val, y_val),
                    epochs=100, batch_size=100, verbose=0)

In [43]:
import matplotlib.pyplot as plt

# Training vs. validation AUC per epoch; a growing gap between the curves
# can indicate overfitting.
auc = history.history['auc']
val_auc = history.history['val_auc']
epochs = range(1, len(auc) + 1)

# rc_context scopes the larger font to this figure instead of permanently
# changing the global rcParams for every later plot in the session.
with plt.rc_context({'font.size': 20}):
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(epochs, auc, 'bo', label='Training AUC')
    ax.plot(epochs, val_auc, 'b', label='Validation AUC')
    ax.set(title='Training and validation AUC', xlabel='Epoch', ylabel='AUC')
    ax.legend()
    plt.show()

In [44]:
import matplotlib.pyplot as plt

# Training vs. validation loss (binary cross-entropy) per epoch.
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(loss) + 1)

# rc_context scopes the larger font to this figure instead of permanently
# changing the global rcParams for every later plot in the session.
with plt.rc_context({'font.size': 20}):
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(epochs, loss, 'bo', label='Training Loss')
    ax.plot(epochs, val_loss, 'b', label='Validation Loss')
    ax.set(title='Training and validation loss', xlabel='Epoch', ylabel='Loss')
    ax.legend()
    plt.show()

## Model Application

In [46]:
# Set up partitioning for the *test* data frame: assign each row to one of
# 100 random groups so the grouped Keras scoring UDF that follows runs as
# many independent tasks. RAND is seeded so the assignment is reproducible.
test_df.createOrReplaceTempView("test_df")

partitioned_df = spark.sql("""
  SELECT
      *,
      CAST(RAND(42)*100 AS INT) AS partition_id
  FROM
      test_df
""")

display(partitioned_df)

In [47]:
# Based on https://docs.databricks.com/applications/machine-learning/model-inference/resnet-model-inference-keras.html

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import *
import pandas as pd

schema = StructType([StructField('user_id', LongType(), True),
                     StructField('propensity', DoubleType(), True)])

# Broadcast the driver-trained weights once instead of shipping them with
# every task.
bc_model_weights = sc.broadcast(model.get_weights())

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def apply_keras(pd_df):
    """Score one group of test rows with the driver-trained Keras model.

    Rebuilds the network with `get_model()`, loads the broadcast weights and
    predicts on the first ten columns (the G1..G10 features).

    Returns:
        pandas DataFrame with `user_id` and the predicted `propensity`.
    """
    model = get_model()
    model.set_weights(bc_model_weights.value)

    # The network ends in a single sigmoid unit, so predict() returns an
    # (n, 1) array; flatten it to one value per row and build a new frame
    # rather than mutating the input batch.
    propensity = model.predict(pd_df.iloc[:, :10]).ravel().astype('float64')
    return pd.DataFrame({'user_id': pd_df['user_id'].values,
                         'propensity': propensity})

results_df = partitioned_df.groupby('partition_id').apply(apply_keras)
display(results_df)

# 6.6 Distributed Feature Engineering

In [49]:
# Load the play-by-play data, drop unused columns, fill missing values with 0
# and bring a ~0.3% sample to the driver. The sample is used to build the
# feature definitions below, so it is seeded to keep those definitions
# stable across re-runs.
plays_pd = (
  spark.read.csv(
    path="s3://dsp-ch6-00/csv/game_plays.csv",
    inferSchema=True,
    header=True,
    nullValue='NA'
  )
  .drop('secondaryType', 'periodType', 'dateTime', 'rink_side')
  .fillna(0)
  .sample(withReplacement=False, fraction=0.003, seed=42)
  .toPandas()
)

plays_pd.shape

## Feature Generation

In [51]:
import featuretools as ft
from featuretools import Feature

# Single-entity EntitySet over the sampled plays; `event` and `description`
# are declared categorical so they get encoded below.
es = ft.EntitySet(id="plays")
es = es.entity_from_dataframe(
  entity_id="plays",
  dataframe=plays_pd,
  index="play_id",
  variable_types={
    "event": ft.variable_types.Categorical,
    "description": ft.variable_types.Categorical
  }
)

# One identity feature per variable. encode_features expands the
# categoricals (top 10 values each) and also returns the feature
# definitions `defs`, which are re-applied to the full data later.
features = [Feature(variable) for variable in es["plays"].variables]
encoded, defs = ft.encode_features(plays_pd, features, top_n=10)
encoded = encoded.reset_index(drop=True)
encoded.head()

Unnamed: 0,play_id,game_id,play_num,team_id_for,team_id_against,x,y,period,periodTime,periodTimeRemaining,goals_away,goals_home,st_x,st_y,event = Faceoff,event = Shot,event = Hit,event = Stoppage,event = Blocked Shot,event = Missed Shot,event = Giveaway,event = Takeaway,event = Penalty,event = Goal,event is unknown,description = Goalie Stopped,description = Icing,description = Offside,description = Puck in Netting,description = Period Ready,description = Period Start,description = Period Official,description = Puck Frozen,description = Puck in Crowd,description = Puck in Benches,description is unknown
0,2011030221_280,2011030221,280,4,1,-69,22,3,738,462,3,3,69,-22,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True
1,2011030111_31,2011030111,31,3,9,-72,14,1,338,862,0,0,72,-14,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True
2,2011030111_34,2011030111,34,3,9,-76,-5,1,345,855,0,0,76,5,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True
3,2011030222_234,2011030222,234,1,4,-73,14,3,223,977,1,1,-73,14,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True
4,2011030224_165,2011030224,165,1,4,-80,5,2,574,626,2,2,80,-5,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True


In [52]:
# Re-index the encoded plays and add a parent `games` entity keyed on
# game_id, so Deep Feature Synthesis can aggregate plays per game.
es = ft.EntitySet(id="plays")
es = es.entity_from_dataframe(entity_id="plays", dataframe=encoded, index="play_id")
es = es.normalize_entity(base_entity_id="plays", new_entity_id="games", index="game_id")

# `features` is the per-game feature matrix; `transform` holds the feature
# definitions, re-applied later to the full data set.
features, transform = ft.dfs(entityset=es, target_entity="games", max_depth=2)
features = features.fillna(0).reset_index()

In [53]:
# Strip "(", ")", ".", " " and "=" from column names so they are usable in
# Spark. regex=True is required: since pandas 2.0, str.replace treats the
# pattern literally by default, which would leave the names unchanged.
features.columns = features.columns.str.replace("[(). =]", "", regex=True)

# The Spark schema of the cleaned frame is the output schema of the
# feature-generation UDF.
schema = sqlContext.createDataFrame(features).schema
features.columns

## Feature Application

In [55]:
# Full play-by-play data, prepared like the driver sample but not sampled.
plays_reader = spark.read
plays_df = (
    plays_reader.csv("s3://dsp-ch6-00/csv/game_plays.csv",
                     header=True, inferSchema=True, nullValue='NA')
    .drop('secondaryType', 'periodType', 'dateTime', 'rink_side')
    .fillna(0)
)

In [56]:
# Bucket game IDs into 1000 groups so each UDF call receives every play of
# the games it handles. PMOD always yields a value in [0, 999], whereas
# ABS(HASH(...)) overflows when HASH returns -2^31 (a negative result, or an
# arithmetic error when ANSI mode is enabled).
plays_df.createOrReplaceTempView("plays_df")
plays_df = spark.sql("""
  SELECT
      *,
      PMOD(HASH(game_id), 1000) AS partition_id
  FROM
      plays_df
""")

In [57]:
# Featuretools in action: https://habr.com/ru/company/leroy_merlin/blog/511792/
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def generate_features(plays_pd):
    """Compute per-game features for one partition_id bucket of plays.

    Re-applies the definitions built on the driver sample: `defs` (encoding)
    and `transform` (DFS aggregations). Column names are cleaned the same
    way as the driver-side feature frame so they line up with `schema`.

    Args:
        plays_pd: pandas DataFrame with every play of the games whose
            game_id falls into this bucket.

    Returns:
        pandas DataFrame with one row per game_id.
    """
    # Encoding
    es = ft.EntitySet(id="plays")
    es = es.entity_from_dataframe(
      entity_id="plays",
      dataframe=plays_pd,
      index="play_id",
      variable_types={
        "event": ft.variable_types.Categorical,
        "description": ft.variable_types.Categorical
      }
    )
    encoded_features = ft.calculate_feature_matrix(defs, es)
    encoded_features.reset_index(drop=True, inplace=True)

    # Aggregation: plays -> games
    es = ft.EntitySet(id="plays")
    es = es.entity_from_dataframe(entity_id="plays",
                                  dataframe=encoded_features,
                                  index="play_id")
    es = es.normalize_entity(base_entity_id="plays",
                             new_entity_id="games",
                             index="game_id")
    generated = ft.calculate_feature_matrix(transform, es)
    generated.fillna(0, inplace=True)
    generated.reset_index(inplace=True)

    # Renaming: regex=True is required because since pandas 2.0 str.replace
    # matches the pattern literally by default, leaving names that would no
    # longer match `schema`.
    generated.columns = generated.columns.str.replace("[(). =]", "", regex=True)

    return generated

features_df = plays_df.groupby('partition_id').apply(generate_features)
display(features_df)
display(features_df)

game_id,COUNTplays,MAXplaysgoals_away,MAXplaysgoals_home,MAXplaysperiod,MAXplaysperiodTime,MAXplaysperiodTimeRemaining,MAXplaysplay_num,MAXplaysst_x,MAXplaysst_y,MAXplaysteam_id_against,MAXplaysteam_id_for,MAXplaysx,MAXplaysy,MEANplaysgoals_away,MEANplaysgoals_home,MEANplaysperiod,MEANplaysperiodTime,MEANplaysperiodTimeRemaining,MEANplaysplay_num,MEANplaysst_x,MEANplaysst_y,MEANplaysteam_id_against,MEANplaysteam_id_for,MEANplaysx,MEANplaysy,MINplaysgoals_away,MINplaysgoals_home,MINplaysperiod,MINplaysperiodTime,MINplaysperiodTimeRemaining,MINplaysplay_num,MINplaysst_x,MINplaysst_y,MINplaysteam_id_against,MINplaysteam_id_for,MINplaysx,MINplaysy,PERCENT_TRUEplaysdescriptionGoalieStopped,PERCENT_TRUEplaysdescriptionIcing,PERCENT_TRUEplaysdescriptionOffside,PERCENT_TRUEplaysdescriptionPeriodOfficial,PERCENT_TRUEplaysdescriptionPeriodReady,PERCENT_TRUEplaysdescriptionPeriodStart,PERCENT_TRUEplaysdescriptionPuckFrozen,PERCENT_TRUEplaysdescriptionPuckinBenches,PERCENT_TRUEplaysdescriptionPuckinCrowd,PERCENT_TRUEplaysdescriptionPuckinNetting,PERCENT_TRUEplaysdescriptionisunknown,PERCENT_TRUEplayseventBlockedShot,PERCENT_TRUEplayseventFaceoff,PERCENT_TRUEplayseventGiveaway,PERCENT_TRUEplayseventGoal,PERCENT_TRUEplayseventHit,PERCENT_TRUEplayseventMissedShot,PERCENT_TRUEplayseventPenalty,PERCENT_TRUEplayseventShot,PERCENT_TRUEplayseventStoppage,PERCENT_TRUEplayseventTakeaway,PERCENT_TRUEplayseventisunknown,SKEWplaysgoals_away,SKEWplaysgoals_home,SKEWplaysperiod,SKEWplaysperiodTime,SKEWplaysperiodTimeRemaining,SKEWplaysplay_num,SKEWplaysst_x,SKEWplaysst_y,SKEWplaysteam_id_against,SKEWplaysteam_id_for,SKEWplaysx,SKEWplaysy,STDplaysgoals_away,STDplaysgoals_home,STDplaysperiod,STDplaysperiodTime,STDplaysperiodTimeRemaining,STDplaysplay_num,STDplaysst_x,STDplaysst_y,STDplaysteam_id_against,STDplaysteam_id_for,STDplaysx,STDplaysy,SUMplaysgoals_away,SUMplaysgoals_home,SUMplaysperiod,SUMplaysperiodTime,SUMplaysperiodTimeRemaining,SUMplaysplay_num,SUMplaysst_x,SUMplaysst_y,SUMplaysteam_id_against,SUMplaysteam_id_for,SUMplaysx,SUMplaysy
2014030321,400,1,4,3,1200,1200,400,99,41,24,24,99,42,0.3325,1.5825,1.9575,605.3,594.7,200.5,10.7525,-0.7675,15.7,16.2,-1.6025,-0.2225,0,0,1,0,0,1,-99,-42,0,0,-98,-41,0.0475,0.0375,0.0225,0.0075,0.0075,0.0075,0.0,0.0125,0.01,0.02,0.8275,0.0775,0.1975,0.07,0.0125,0.195,0.08,0.01,0.1375,0.1675,0.0175,0.035,0.7137674166876996,0.0992073266925701,0.0787454884646239,-0.0160046351043167,0.0160046351043167,0.0,-0.2694975224149191,0.0336141174971594,-0.8815389221678194,-0.9333517632315638,0.1216143479611207,0.0564877742463343,0.47169905660283,0.9672293323993196,0.8199845649259451,355.61973157251737,355.6197315725174,115.61430130683084,60.28425767194653,20.548455935991477,8.686693652739304,8.917331270760396,61.21701807508278,20.561613351316147,133,633,783,242120,237880,80200,4301,-307,6280,6480,-641,-89
2012020451,338,2,0,3,1200,1200,338,99,42,26,26,99,41,0.3550295857988165,0.0,2.0088757396449703,619.189349112426,580.810650887574,169.5,10.976331360946746,1.3757396449704142,20.5207100591716,20.67159763313609,-6.153846153846154,1.955621301775148,0,0,1,0,0,1,-99,-41,0,0,-99,-42,0.0532544378698224,0.029585798816568,0.0236686390532544,0.0088757396449704,0.0088757396449704,0.0088757396449704,0.0118343195266272,0.0118343195266272,0.0059171597633136,0.0118343195266272,0.8254437869822485,0.0680473372781065,0.1804733727810651,0.0532544378698224,0.0059171597633136,0.2071005917159763,0.0857988165680473,0.014792899408284,0.1745562130177514,0.150887573964497,0.0177514792899408,0.0414201183431952,1.687683557787693,0.0,-0.0157760788215841,-0.0759059044502532,0.0759059044502532,0.0,-0.2954456347607327,0.0664390013278368,-1.5609615324900206,-1.5611409034761996,0.1512708530531762,-0.0150291315833425,0.7255335695516889,0.0,0.7912229440080853,362.5603065835543,362.56030658355417,97.7164264594239,58.601349237839344,21.016263383452284,10.0376671940779,10.11133091488264,59.30407484961338,20.970116813402225,120,0,679,209286,196314,57291,3710,465,6936,6987,-2080,661
2017020624,307,0,2,3,1200,1200,307,98,41,29,29,98,41,0.0,0.4364820846905538,1.925081433224756,605.1172638436482,594.8827361563518,154.0,9.094462540716613,0.8045602605863192,19.224755700325733,19.537459283387623,-2.6970684039087947,1.228013029315961,0,0,1,0,0,1,-98,-41,0,0,-97,-40,0.0651465798045602,0.0390879478827361,0.013029315960912,0.009771986970684,0.009771986970684,0.009771986970684,0.019543973941368,0.003257328990228,0.003257328990228,0.0260586319218241,0.8013029315960912,0.0879478827361563,0.2149837133550488,0.0260586319218241,0.006514657980456,0.1042345276872964,0.0586319218241042,0.0260586319218241,0.2084690553745928,0.1791530944625407,0.0423452768729641,0.0456026058631921,0.0,0.818735206441671,0.1386183415443986,-0.0180980318075557,0.0180980318075557,0.0,-0.2971533149041141,-0.0745347238189895,-0.961976949942234,-0.9793483715782376,0.0905031331045042,-0.0563573919861285,0.0,0.5586848488181041,0.8150481620240944,348.7959203422409,348.7959203422409,88.7674865402117,56.00665986852362,19.622605479464987,10.950496430785478,11.11030835276853,56.6782810289556,19.60059015117502,0,134,591,185771,182629,47278,2792,247,5902,5998,-828,377
2017020171,375,4,3,4,1200,1200,375,99,40,53,53,99,41,2.232,0.5546666666666666,2.1813333333333333,571.6266666666667,546.7733333333333,188.0,7.336,0.544,22.888,22.104,-2.288,-2.341333333333333,0,0,1,0,0,1,-98,-41,0,0,-98,-41,0.0453333333333333,0.024,0.0213333333333333,0.0106666666666666,0.0106666666666666,0.0106666666666666,0.0266666666666666,0.0026666666666666,0.008,0.024,0.816,0.0746666666666666,0.192,0.0773333333333333,0.0186666666666666,0.16,0.0693333333333333,0.0186666666666666,0.1546666666666666,0.1626666666666666,0.024,0.048,-0.7950797414462947,1.7400561097292973,0.2639432120196857,0.0839872402179541,0.2018361621216212,0.0,-0.2165588455233464,-0.0874708162820165,0.3874474617957769,0.4558198331105207,0.0748382976664662,0.0498248598211206,0.7436289109624858,0.9373836231986336,0.9529988319082758,369.8820007436877,361.50175380427294,108.397416943394,56.62853119633729,21.00833957148452,24.800341553625238,24.609296969361058,56.42578050751583,20.88421795794123,837,208,818,214360,205040,70500,2751,204,8583,8289,-858,-878
2017020177,333,2,3,3,1200,1200,333,98,42,28,28,98,42,0.8378378378378378,0.9129129129129128,2.03003003003003,622.4714714714714,577.5285285285286,167.0,1.78978978978979,-0.2942942942942942,14.09009009009009,16.036036036036037,0.1021021021021021,-1.987987987987988,0,0,1,0,0,1,-98,-41,0,0,-98,-42,0.045045045045045,0.045045045045045,0.018018018018018,0.009009009009009,0.009009009009009,0.009009009009009,0.006006006006006,0.012012012012012,0.015015015015015,0.018018018018018,0.8138138138138138,0.1381381381381381,0.1921921921921922,0.051051051051051,0.015015015015015,0.093093093093093,0.069069069069069,0.012012012012012,0.1561561561561561,0.1651651651651651,0.066066066066066,0.042042042042042,-0.660458942706818,0.5852284668935415,-0.0571396954188085,-0.0797027955367811,0.0797027955367811,0.0,-0.1354230286871961,0.0984148205686004,0.2535575288530706,-0.1070405115858599,-0.0505430371333458,0.069831651693454,0.4500443192474239,0.9607892000700822,0.8461459394824927,352.4894212673032,352.4894212673032,96.27304918823336,55.957257927052815,20.22651539880734,10.735777548467208,11.429138354082944,55.985866561159895,20.13044448182901,279,304,676,207283,192317,55611,596,-98,4692,5340,34,-662
2018020078,287,0,3,3,1200,1200,287,97,39,25,25,98,41,0.0,1.9442508710801396,2.0034843205574915,574.8885017421603,625.1114982578397,144.0,12.777003484320558,-0.6759581881533101,11.156794425087108,9.317073170731708,-0.2822299651567944,0.89198606271777,0,0,1,0,0,1,-98,-42,0,0,-97,-42,0.0522648083623693,0.048780487804878,0.0104529616724738,0.0104529616724738,0.0104529616724738,0.0104529616724738,0.0,0.0034843205574912,0.0104529616724738,0.0209059233449477,0.8222996515679443,0.048780487804878,0.2125435540069686,0.0313588850174216,0.0104529616724738,0.0975609756097561,0.0940766550522648,0.0452961672473867,0.1916376306620209,0.1637630662020905,0.0557491289198606,0.048780487804878,0.0,-0.3715189847699592,-0.0065527240502726,0.0770320283064118,-0.0770320283064118,0.0,-0.3905267193841678,-0.0500370786933667,0.2729724221099911,0.604059525667239,0.0146873906905835,-0.0745273228124472,0.0,1.015798308020795,0.8341414825227097,365.65886712219896,365.658867122199,82.99397568498571,52.922087085441085,19.327221292967508,12.100676778575204,11.671516092916752,54.44713029909906,19.318425891073428,0,558,575,164993,179407,41328,3667,-194,3202,2674,-81,256
2013020720,310,0,1,3,1200,1200,310,99,41,27,27,99,41,0.0,0.7741935483870968,1.9,581.8709677419355,618.1290322580645,155.5,12.264516129032256,1.4548387096774194,20.135483870967743,20.18709677419355,-8.264516129032257,-0.3774193548387097,0,0,1,0,0,1,-99,-41,0,0,-97,-41,0.0548387096774193,0.0129032258064516,0.0096774193548387,0.0096774193548387,0.0096774193548387,0.0096774193548387,0.0032258064516129,0.0064516129032258,0.0096774193548387,0.0161290322580645,0.8580645161290322,0.0741935483870967,0.1870967741935484,0.0354838709677419,0.0032258064516129,0.1354838709677419,0.1129032258064516,0.064516129032258,0.164516129032258,0.1483870967741935,0.0290322580645161,0.0451612903225806,0.0,-1.317964312323889,0.1788986318775755,0.0546330187229056,-0.0546330187229056,0.0,-0.3456055803261321,-0.075205758030353,-1.437671345991127,-1.4386728704068297,0.2112647340657866,-0.1740608636887282,0.0,0.4187883137555024,0.7880581188627533,362.63253057169845,362.63253057169845,89.63351307779176,57.16660230998536,20.7759560232352,10.042677487818397,10.0675949100149,57.88265972370896,20.823564720951367,0,240,589,180380,191620,48205,3802,451,6242,6258,-2562,-117
2016021139,291,5,3,3,1200,1200,291,96,40,21,21,98,40,2.9759450171821307,1.5945017182130583,2.0309278350515463,638.9656357388316,561.0343642611684,146.0,4.972508591065292,0.4295532646048109,14.505154639175258,14.56701030927835,-13.446735395189004,-3.59106529209622,0,0,1,0,0,1,-98,-40,0,0,-97,-40,0.0618556701030927,0.0171821305841924,0.0103092783505154,0.0103092783505154,0.0103092783505154,0.0103092783505154,0.0171821305841924,0.0,0.0068728522336769,0.0171821305841924,0.8384879725085911,0.1099656357388316,0.1924398625429553,0.0549828178694158,0.0274914089347079,0.0721649484536082,0.0618556701030927,0.0206185567010309,0.2130584192439862,0.1443298969072164,0.0549828178694158,0.0481099656357388,-0.5848769739828282,0.1108586371531564,-0.0575545984851103,-0.0842194291106631,0.0842194291106631,0.0,-0.1970950525232189,-0.0726670238187052,-1.1048228745529685,-1.1106420158652397,0.4061654566393765,0.0653879499927081,1.3605673450493885,1.0572430033348463,0.8236203709622039,363.4334623746776,363.4334623746776,84.1486779456457,55.72839863148858,19.551462229638485,7.589635007361671,7.617910678125162,54.30494752333811,19.22250268697848,866,464,591,185939,163261,42486,1447,125,4221,4239,-3913,-1045
2017020515,327,3,5,3,1200,1200,327,97,40,28,28,98,40,1.6788990825688073,2.400611620795107,2.1131498470948014,601.6391437308869,598.3608562691131,164.0,3.889908256880734,1.2813455657492354,20.238532110091743,19.6697247706422,5.743119266055046,0.382262996941896,0,0,1,0,0,1,-98,-38,0,0,-97,-38,0.0428134556574923,0.0336391437308868,0.0275229357798165,0.0091743119266055,0.0091743119266055,0.0091743119266055,0.0030581039755351,0.0152905198776758,0.0061162079510703,0.0214067278287461,0.8226299694189603,0.1070336391437308,0.2079510703363914,0.0856269113149847,0.0244648318042813,0.1162079510703364,0.073394495412844,0.018348623853211,0.128440366972477,0.1590214067278287,0.036697247706422,0.0428134556574923,-0.16873472773634,-0.5685981742018927,-0.2109489575225345,-0.0396850711787112,0.0396850711787112,0.0,-0.0662266617017816,0.0228116463018756,-1.262138251280417,-1.2405382300624197,-0.0818271613523521,-0.021927566200502,0.8883782765112335,1.4969497387948736,0.8155194356173296,354.6126782729329,354.61267827293267,94.54099639838792,58.124945720596,19.257387731283256,10.535565605318292,10.258668379369782,57.97070399373011,19.29630254208436,549,785,691,196736,195664,53628,1272,419,6618,6432,1878,125
2016020381,270,1,4,3,1200,1200,270,97,41,53,53,97,41,0.8629629629629629,1.6333333333333333,2.011111111111111,597.0370370370371,602.9629629629629,135.5,12.681481481481482,0.0407407407407407,34.72222222222222,31.788888888888888,4.837037037037037,0.7666666666666667,0,0,1,0,0,1,-97,-40,0,0,-96,-40,0.0259259259259259,0.0333333333333333,0.0148148148148148,0.0111111111111111,0.0111111111111111,0.0111111111111111,0.0259259259259259,0.0037037037037037,0.0037037037037037,0.0074074074074074,0.8518518518518519,0.0666666666666666,0.1888888888888888,0.0518518518518518,0.0185185185185185,0.1333333333333333,0.0888888888888888,0.0407407407407407,0.174074074074074,0.137037037037037,0.0444444444444444,0.0555555555555555,-2.122757953996236,0.3718422370600307,-0.0213151466357389,0.0619061371811773,-0.0619061371811773,0.0,-0.3254429936560654,-0.0335134648081591,-0.656353642540441,-0.4295691378506323,-0.1523829500614367,-0.0351480841012711,0.3445250468642865,1.1547541942208448,0.8556974299935969,372.7920856772641,372.7920856772641,78.08649050892222,55.866696339400974,21.282185496880075,19.90928933274164,18.72903702161766,57.08784171332469,21.26835966116287,233,441,543,161200,162800,36585,3424,11,9375,8583,1306,207


# 6.7 GCP Model Pipeline

## BigQuery Export

## GCP Credentials

In [61]:
# https://docs.databricks.com/data/data-sources/google/bigquery.html
# just specify GOOGLE_APPLICATION_CREDENTIALS and you'll get access to GCS and BigQuery

# Copy the credentials file stored at GCP_CREDENTIALS_S3_PATH to the local
# path named by GOOGLE_APPLICATION_CREDENTIALS.
creds_file = os.environ['GOOGLE_APPLICATION_CREDENTIALS']
creds = sc.textFile(os.environ['GCP_CREDENTIALS_S3_PATH'])

# collect() copies every line; the previous take(100) silently truncated
# longer files. The file holds secrets, so restrict it to the owner.
fd = os.open(creds_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
os.chmod(creds_file, 0o600)  # also tighten a pre-existing file
with os.fdopen(fd, 'w') as file:
    for line in creds.collect():
        file.write(line + "\n")

## Model Pipeline

In [63]:
# Read the natality table from BigQuery in the configured GCP project.
project_id = os.environ['GCP_PROJECT_ID']
table = f"{project_id}.dsp_demo.natality"
natality_df = spark.read.format("bigquery").option("table", table).load()
display(natality_df)

source_year,year,month,day,wday,state,is_male,child_race,weight_pounds,plurality,apgar_1min,apgar_5min,mother_residence_state,mother_race,mother_age,gestation_weeks,lmp,mother_married,mother_birth_state,cigarette_use,cigarettes_per_day,alcohol_use,drinks_per_week,weight_gain_pounds,born_alive_alive,born_alive_dead,born_dead,ever_born,father_race,father_age,record_weight
2002,2002,11,,5.0,IN,False,9.0,7.5618555866,1.0,99.0,9.0,IN,1.0,24,39.0,02012002,True,Mexico,,,False,,24.0,1.0,0.0,0.0,2.0,1.0,26,1
1978,1978,6,26.0,,NY,True,1.0,7.936641432,1.0,9.0,10.0,NY,1.0,24,38.0,09301987,,NY,,,,,,0.0,0.0,0.0,1.0,1.0,23,1
1974,1974,2,2.0,,NC,True,2.0,7.31273323054,1.0,,,NC,2.0,18,38.0,05111943,True,NC,,,,,,0.0,0.0,0.0,1.0,2.0,20,2
1974,1974,2,23.0,,WI,False,1.0,5.37486994756,1.0,,,WI,1.0,24,,88881948,True,WI,,,,,,2.0,0.0,0.0,3.0,1.0,24,2
2004,2004,11,,3.0,CT,False,,6.8122838958,1.0,,9.0,CT,1.0,39,39.0,02992004,True,YY,False,,False,,10.0,1.0,0.0,1.0,2.0,1.0,41,1
1977,1977,3,31.0,,TX,True,1.0,8.87581066812,1.0,,,TX,1.0,18,,88881978,,TX,,,,,,0.0,0.0,0.0,1.0,1.0,20,1
1997,1997,1,,2.0,SC,True,9.0,7.3744626639,1.0,99.0,9.0,SC,2.0,35,40.0,03261996,True,SC,,,False,,16.0,3.0,0.0,0.0,4.0,2.0,29,1
1981,1981,4,17.0,,MA,False,1.0,7.06361087448,1.0,8.0,9.0,MA,1.0,26,42.0,06241910,True,CT,,,,,,2.0,0.0,0.0,3.0,1.0,27,1
2005,2005,3,,3.0,,True,,6.75055446244,1.0,,9.0,,1.0,30,39.0,06202004,True,,False,,False,,25.0,0.0,0.0,1.0,1.0,1.0,35,1
1999,1999,3,,2.0,IL,False,9.0,6.6248909731,1.0,99.0,9.0,IL,2.0,15,39.0,06271998,False,IL,,,False,,20.0,0.0,0.0,0.0,1.0,9.0,99,1


In [64]:
natality_df.createOrReplaceTempView("natality_df")

# Keep the modelling columns, encode mother_married as 0/1, rename the
# target to `weight`, and flag roughly half of the rows as test (test = 1).
# RAND is seeded so the split is the same on every run of the notebook.
# NOTE(review): natality_df is not cached, so each action below re-evaluates
# RAND; the split is only stable if the BigQuery read returns the same
# partition layout every time — consider .cache() if that cannot be assumed.
natality_df = spark.sql("""
    SELECT
        year,
        plurality,
        apgar_5min,
        mother_age,
        father_age,
        gestation_weeks,
        ever_born,
        CASE WHEN mother_married = true THEN 1 ELSE 0 END AS mother_married,
        weight_pounds AS weight,
        CASE WHEN RAND(42) < 0.5 THEN 1 ELSE 0 END AS test
    FROM
        natality_df
""").fillna(0)

train_df = natality_df.filter("test == 0")
test_df = natality_df.filter("test == 1")
print("Train", train_df.count())
print("Test", test_df.count())

display(natality_df)

year,plurality,apgar_5min,mother_age,father_age,gestation_weeks,ever_born,mother_married,weight,test
2002,1,9,24,26,39,2,1,7.5618555866,1
1978,1,10,24,23,38,1,0,7.936641432,1
1974,1,0,18,20,38,1,1,7.31273323054,1
1974,1,0,24,24,0,3,1,5.37486994756,0
2004,1,9,39,41,39,2,1,6.8122838958,0
1977,1,0,18,20,0,1,0,8.87581066812,1
1997,1,9,35,29,40,4,1,7.3744626639,0
1981,1,9,26,27,42,3,1,7.06361087448,1
2005,1,9,30,35,39,1,1,6.75055446244,1
1999,1,9,15,99,39,1,0,6.6248909731,1


In [65]:
from pyspark.ml.feature import VectorAssembler

# The first eight columns (year .. mother_married) are the regression
# features; `weight` is the target and `test` is left out.
feature_cols = train_df.schema.names[:8]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

train_vec, test_vec = (assembler.transform(split).select('weight', 'features')
                       for split in (train_df, test_df))

display(test_vec)

weight,features
7.5618555866,"List(1, 8, List(), List(2002.0, 1.0, 9.0, 24.0, 26.0, 39.0, 2.0, 1.0))"
7.936641432,"List(1, 8, List(), List(1978.0, 1.0, 10.0, 24.0, 23.0, 38.0, 1.0, 0.0))"
7.31273323054,"List(1, 8, List(), List(1974.0, 1.0, 0.0, 18.0, 20.0, 38.0, 1.0, 1.0))"
8.87581066812,"List(1, 8, List(), List(1977.0, 1.0, 0.0, 18.0, 20.0, 0.0, 1.0, 0.0))"
7.06361087448,"List(1, 8, List(), List(1981.0, 1.0, 9.0, 26.0, 27.0, 42.0, 3.0, 1.0))"
6.75055446244,"List(1, 8, List(), List(2005.0, 1.0, 9.0, 30.0, 35.0, 39.0, 1.0, 1.0))"
6.6248909731,"List(1, 8, List(), List(1999.0, 1.0, 9.0, 15.0, 99.0, 39.0, 1.0, 0.0))"
6.062712205,"List(1, 8, List(), List(1987.0, 1.0, 99.0, 34.0, 36.0, 39.0, 2.0, 1.0))"
5.93704871566,"List(1, 8, List(), List(1994.0, 1.0, 7.0, 26.0, 33.0, 38.0, 3.0, 0.0))"
4.7509617461,"List(1, 8, List(), List(1980.0, 1.0, 9.0, 25.0, 24.0, 32.0, 1.0, 1.0))"


In [66]:
from pyspark.ml.tuning import ParamGridBuilder
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.tuning import CrossValidator
from pyspark.ml.evaluation import RegressionEvaluator

# Hyper-parameter search: 3-fold cross-validation over the number of trees
# and the maximum depth (2 x 2 = 4 candidates), scored with
# RegressionEvaluator's default metric (RMSE).
folds = 3
rf_trees = [50, 100]
rf_depth = [4, 5]
rf_seed = 42  # fixes the forest's random sampling and the fold assignment

rf = RandomForestRegressor(featuresCol='features', labelCol='weight', seed=rf_seed)

param_grid = (ParamGridBuilder()
              .addGrid(rf.numTrees, rf_trees)
              .addGrid(rf.maxDepth, rf_depth)
              .build())
crossval = CrossValidator(
  estimator=rf,
  estimatorParamMaps=param_grid,
  evaluator=RegressionEvaluator(labelCol='weight'),
  numFolds=folds,
  seed=rf_seed
)
rf_model = crossval.fit(train_vec)

# Score the held-out split with the best model found by cross-validation.
preds_df = rf_model.transform(test_vec).select("weight", "prediction")
display(preds_df)

weight,prediction
7.5618555866,7.688853486585796
7.936641432,7.21488497776436
7.31273323054,7.013529378960051
8.87581066812,6.453919369939452
7.06361087448,7.736195720249044
6.75055446244,7.663675622808915
6.6248909731,6.991829372177746
6.062712205,7.710976626890192
5.93704871566,7.273825817965194
4.7509617461,5.8517969484955055


In [67]:
import time

# Each run writes to its own folder, suffixed with the epoch time in ms.
timestamp_ms = int(time.time() * 1000)
out_path = f"gs://dsp_model_store_00/natality/preds-{timestamp_ms}/"
preds_df.write.mode('overwrite').format("avro").save(out_path)

print(out_path)
# display(spark.read.format("avro").load(out_path))