In [2]:
import pandas as pd
from pyspark.sql.types import *
from pyspark.sql import *

In [3]:
def evaluate_model(model_name, model_path, weights_path, validation_data_set, evaluation, input_col='serie', label_col='sales'):
  """Score a stored model on a validation table and record its metrics.

  The model is loaded, run on the validation series, and its predictions are
  evaluated for RMSE, MSE and MAE. The three metrics are printed and then
  handed to `store_model`.

  NOTE(review): `load_model`, `prepare_collected_data`, `spark` and the three
  `*_evaluator` objects are not defined in this cell; they must already exist
  in the notebook namespace. The evaluators presumably read the columns
  'label' (ground truth) and `label_col` (prediction) -- confirm against
  their definitions.

  Args:
    model_name: Identifier stored alongside the metrics.
    model_path: Passed to `load_model` as the model location.
    weights_path: Passed to `load_model` as the weights location.
    validation_data_set: Name of the table queried for validation rows.
    evaluation: Forwarded to `store_model` as the results table name.
    input_col: Column of the validation table holding the input series.
    label_col: Column of the validation table holding the ground truth. It is
      also used as the name of the prediction column in the frame handed to
      the evaluators, so it must not be 'label'.

  Raises:
    ValueError: If `label_col` is 'label'.
  """
  # The ground truth is always stored under the fixed name 'label' and the
  # predictions under `label_col`. If both names are equal the predictions
  # would silently overwrite the ground truth, so fail before doing any work.
  if label_col == 'label':
    raise ValueError(
      "label_col must not be 'label': that name is reserved for the "
      "ground-truth column of the evaluation frame")
  # Load model
  model = load_model(model_path, weights_path)
  # Load validation data set
  validation = spark.sql("select * from %s" % validation_data_set)
  # Get validation series and labels (collected to the driver)
  validation_x, validation_y = prepare_collected_data(validation.select(input_col, label_col).collect())
  # Make predictions
  predictions = model.predict(validation_x)
  # Turn predictions into a Spark dataframe: ground truth in 'label',
  # predictions in `label_col`
  df = pd.DataFrame(validation_y, columns=['label'])
  df[label_col] = predictions
  df_predictions = spark.createDataFrame(df)
  # Evaluate model
  rmse = rmse_evaluator.evaluate(df_predictions)
  mse = mse_evaluator.evaluate(df_predictions)
  mae = mae_evaluator.evaluate(df_predictions)
  print("RMSE: %f, MSE: %f, MAE: %f" % (rmse, mse, mae))
  store_model(model_name, rmse, mse, mae, evaluation)

In [4]:
def store_model(model_name, validation_rmse, validation_mse, validation_mae, evaluation):
  """Append one row of validation metrics to the `evaluation` table.

  The existing table is read, the new row is appended, and the result is
  written back to the same table.

  Args:
    model_name: Identifier of the evaluated model.
    validation_rmse: RMSE on the validation set (number or None).
    validation_mse: MSE on the validation set (number or None).
    validation_mae: MAE on the validation set (number or None).
    evaluation: Name of the table that holds the stored results. A scratch
      table named '<evaluation>_temp' is created (or overwritten) as well.
  """
  def _as_float(value):
    # FloatType columns are given Python floats; ints and numpy scalars are
    # converted here, while None is kept for the nullable columns.
    return None if value is None else float(value)

  data = [Row(
    model_name,
    _as_float(validation_rmse),
    _as_float(validation_mse),
    _as_float(validation_mae))]

  schema = StructType([
    StructField("Model name", StringType(), True),
    StructField("RMSE", FloatType(), True),
    StructField("MSE", FloatType(), True),
    StructField("MAE", FloatType(), True)])

  new_df = spark.createDataFrame(data, schema)

  # union() matches columns by position, so the resulting column names are
  # those of the existing table, not the ones declared in `schema`.
  history = spark.table(evaluation)
  updated = history.union(new_df)

  # A table cannot be overwritten while the same job is still reading from
  # it, so the result is first materialised in a scratch table and then
  # copied back over the original.
  temp_table = '%s_temp' % evaluation
  updated.write.saveAsTable(temp_table, mode='overwrite')
  spark.table(temp_table).write.saveAsTable(evaluation, mode='overwrite')

In [5]:
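# One-time setup, intentionally left commented out: creates an empty
# 'model_evaluation' table with the schema below. Running it again overwrites
# the existing table and discards every stored result. `someData` is only a
# sample row and is not used when the table is created.
# someData = [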
#   Row('model1', 1, 2, 3)
# ]

# someSchema = [
#   StructField("Model_name", StringType(), True),
#   StructField("RMSE", FloatType(), True),
#   StructField("MSE", FloatType(), True),
#   StructField("MAE", FloatType(), True)
# ]

# model_evaluation = spark.createDataFrame(
#   spark.sparkContext.parallelize([]),
#   StructType(someSchema)
# )
# model_evaluation.write.saveAsTable('model_evaluation', mode='overwrite')