In [1]:
import codecs
import importlib
import os
import pickle

import dill
import numpy as np
import pandas as pd
from pyspark import SparkConf, SparkContext, SQLContext
from pyspark.ml.pipeline import Pipeline, Transformer
from pyspark.sql import DataFrame, Window
from pyspark.sql.functions import avg, col, countDistinct, lit, month, udf, when, year
from pyspark.sql.types import IntegerType

In [2]:
# Reuse the already-running SparkContext when this cell is executed again:
# constructing a second SparkContext in the same kernel raises
# "Cannot run multiple SparkContexts at once".
sc = SparkContext.getOrCreate(SparkConf().setAppName('jlg'))
# SQL entry point used by the cells below to create and read DataFrames.
sqlcontext = SQLContext(sc)

In [3]:
# Build a small synthetic PySpark DataFrame for the demo.

# Seeded generator so that Restart & Run All reproduces the same rows.
rng = np.random.default_rng(42)

# Four independent feature columns, each of shape (1000, 1), uniform in [0, 1).
X1 = rng.random((1000, 1))
X2 = rng.random((1000, 1))
X3 = rng.random((1000, 1))
X4 = rng.random((1000, 1))
# The label is an exact linear combination of the features (no noise term).
Y = X1*2 + X2*4 + X3*2 + X4*1

m = np.hstack([X1, X2, X3, X4, Y])
dataset = pd.DataFrame(m, columns=['X1', 'X2', 'X3', 'X4', 'Label'])

# Linear_Scaler reads this CSV back, so make sure the target folder exists
# before writing (to_csv does not create missing directories).
os.makedirs('data', exist_ok=True)
dataset.to_csv('data/foo.csv', index=False)

# The Spark copy renames the feature columns positionally to F1..F4;
# the CSV on disk keeps the X1..X4 names.
df = sqlcontext.createDataFrame(dataset, schema=["F1", "F2", "F3", "F4", "Label"])

df.show(5)

+------------------+-------------------+-------------------+-------------------+-----------------+
|                F1|                 F2|                 F3|                 F4|            Label|
+------------------+-------------------+-------------------+-------------------+-----------------+
|0.5205162467747544|0.37892804297373994|0.24542089328813999| 0.4846753984143245|3.532261850435073|
|  0.71781738252966| 0.8611885956954083|0.19367754242737611|  0.623370484839035| 5.89111471753474|
|0.1516094297235261| 0.9886550399625831| 0.6741350311317537| 0.7634411717429604|6.369550253303853|
|0.6650589318917749| 0.9345007492864892| 0.7694777876551938|0.22692208146915405|6.833998517709048|
|0.6933133498352892| 0.9020107922484703| 0.8321117333942878|   0.52579363720328|7.184686972656315|
+------------------+-------------------+-------------------+-------------------+-----------------+
only showing top 5 rows



In [4]:
# create custom transformer

def Linear_Scaler(params):
    """
    Custom transformer that multiplies one column of a CSV file by a constant.

    The CSV is loaded through ``params['context']`` (with header row and
    schema inference enabled) and a new column holding
    ``inputCol * alpha`` is appended to it.

    Parameters
    ----------
    params : dict
        Required keys:
          * ``'context'``   -- object exposing ``.read.csv(...)``, e.g. the
            notebook's ``SQLContext``.
          * ``'inputCol'``  -- name of the existing column to scale.
          * ``'outputCol'`` -- name of the column that receives the result.
          * ``'alpha'``     -- multiplicative factor.
        Optional keys:
          * ``'path'``      -- CSV file to read; defaults to ``'data/foo.csv'``.

    Returns
    -------
    The DataFrame read from the CSV with ``outputCol`` added.

    Raises
    ------
    KeyError
        If a required key is missing from ``params``. All required keys are
        looked up before the file is read, so a bad call fails without I/O.
    """
    # Resolve every parameter up front so a malformed call fails fast.
    context = params['context']
    alpha = params['alpha']
    inputCol = params['inputCol']
    outputCol = params['outputCol']
    path = params.get('path', 'data/foo.csv')

    df = context.read.csv(path, header='true', inferSchema='true')

    # do transform
    return df.withColumn(outputCol, df[inputCol] * alpha)

    

In [5]:
# Run the transformer directly: scale CSV column X1 by 2.0 into a new column F1.
foo = Linear_Scaler(
    {
        'context': sqlcontext,
        'inputCol': 'X1',
        'outputCol': 'F1',
        'alpha': 2.0,
    }
)
foo.show(3)


+------------------+-------------------+-------------------+------------------+-----------------+------------------+
|                X1|                 X2|                 X3|                X4|            Label|                F1|
+------------------+-------------------+-------------------+------------------+-----------------+------------------+
|0.5205162467747544|0.37892804297373994|0.24542089328813999|0.4846753984143245|3.532261850435073|1.0410324935495088|
|  0.71781738252966| 0.8611885956954083|0.19367754242737611| 0.623370484839035| 5.89111471753474|  1.43563476505932|
|0.1516094297235261| 0.9886550399625831| 0.6741350311317537|0.7634411717429604|6.369550253303853|0.3032188594470522|
+------------------+-------------------+-------------------+------------------+-----------------+------------------+
only showing top 3 rows



In [69]:
# Reload edited project modules automatically so changes under src/ are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

from src.core.store import Store
from src.core.feature import Feature

# Build the store from its JSON config file; the recorded output below shows
# the configuration it reports on initialization.
store = Store('store_config.json')
# Feature object named 'foo_scaler'; it is passed to store.register() in a later cell.
f = Feature('foo_scaler')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
== Store Initialized: ==
{'store_name': "Kai's Feature Store", 'root_dir': '/Users/kai/repository/nebula/storage', 'book_keeper': {'type': 'default', 'params': {'folder_name': 'catalog', 'file_name': 'catalog.nbl'}}, 'writers': {'default': {'folder_name': 'features'}}, 'readers': {'default': {'folder_name': 'features'}}, 'serializers': {'default': {}}, 'deserializers': {'default': {}}}


In [70]:
# Register the Linear_Scaler function with the store under feature `f` ('foo_scaler').
# NOTE(review): presumably the store persists the function so it can be
# checked out later -- confirm in src.core.store.
store.register(f, Linear_Scaler)

In [71]:
# List what the store knows about; the recorded output shows the feature name
# followed by its id.
store.list_features()

foo_scaler 	 046cc688-6857-4c3b-b50c-443db103eae1


In [77]:
# Check out the registered feature using the id printed by list_features().
# NOTE(review): the id is copied by hand from a previous run's output; confirm
# whether re-running store.register() yields a new id, in which case this
# literal has to be updated.
p = store.checkout('046cc688-6857-4c3b-b50c-443db103eae1')

In [78]:
# Call the checked-out object with the same parameters used for the local
# Linear_Scaler run, so the two outputs can be compared.
p(
    {
        'context': sqlcontext,
        'inputCol': 'X1',
        'outputCol': 'F1',
        'alpha': 2.0,
    }
).show(3)

+------------------+-------------------+-------------------+------------------+-----------------+------------------+
|                X1|                 X2|                 X3|                X4|            Label|                F1|
+------------------+-------------------+-------------------+------------------+-----------------+------------------+
|0.5205162467747544|0.37892804297373994|0.24542089328813999|0.4846753984143245|3.532261850435073|1.0410324935495088|
|  0.71781738252966| 0.8611885956954083|0.19367754242737611| 0.623370484839035| 5.89111471753474|  1.43563476505932|
|0.1516094297235261| 0.9886550399625831| 0.6741350311317537|0.7634411717429604|6.369550253303853|0.3032188594470522|
+------------------+-------------------+-------------------+------------------+-----------------+------------------+
only showing top 3 rows



In [13]:
# Inspect which deserializer type the store is configured with.
# NOTE(review): this cell carries an older execution count (In [13]) than the
# cells above, and the config printed at store initialization shows a
# 'deserializers' key (plural) with no 'type' field -- confirm this lookup
# still works against the current store_config.json.
store.config['deserializer']['type']

'default'