In [11]:
# Assembly jar (plugin plus bundled dependencies) produced by the project build
package_jar = (
    '../target/'
    'spark-data-repair-plugin_2.12_spark3.2_0.1.0-EXPERIMENTAL-with-dependencies.jar'
)

In [12]:
import numpy as np
import pandas as pd
from pyspark.sql import *
from pyspark.sql.types import *
from pyspark.sql import functions as f

# Start (or reuse) a Hive-enabled Spark session with the plugin jar on the classpath.
# 'spark.driver.memory' only takes effect if the driver JVM is not already running
# (e.g. on a fresh kernel); an existing session/JVM keeps its original heap size.
spark = (
    SparkSession.builder
    .config('spark.jars', package_jar)
    .config('spark.driver.memory', '8g')
    .enableHiveSupport()
    .getOrCreate()
)

# Suppresses user warning messages in Python
import warnings
warnings.simplefilter("ignore", UserWarning)

# Suppresses `WARN` messages in JVM
spark.sparkContext.setLogLevel("ERROR")

In [13]:
from repair.api import Scavenger

# Show the plugin version reported through the Python API; it should match the jar configured above
version = Scavenger().version()
version

'0.1.0-spark3.2-EXPERIMENTAL'

In [14]:
# Register the raw (dirty) hospital CSV as the `hospital` temp view and show its schema
(spark.read
    .option("header", True)
    .csv("../testdata/hospital.csv")
    .createOrReplaceTempView("hospital"))
spark.table('hospital').printSchema()

root
 |-- tid: string (nullable = true)
 |-- ProviderNumber: string (nullable = true)
 |-- HospitalName: string (nullable = true)
 |-- Address1: string (nullable = true)
 |-- Address2: string (nullable = true)
 |-- Address3: string (nullable = true)
 |-- City: string (nullable = true)
 |-- State: string (nullable = true)
 |-- ZipCode: string (nullable = true)
 |-- CountyName: string (nullable = true)
 |-- PhoneNumber: string (nullable = true)
 |-- HospitalType: string (nullable = true)
 |-- HospitalOwner: string (nullable = true)
 |-- EmergencyService: string (nullable = true)
 |-- Condition: string (nullable = true)
 |-- MeasureCode: string (nullable = true)
 |-- MeasureName: string (nullable = true)
 |-- Score: string (nullable = true)
 |-- Sample: string (nullable = true)
 |-- Stateavg: string (nullable = true)



In [15]:
import altair as alt

pdf = spark.table('hospital').toPandas()
plot_columns = [col for col in pdf.columns if col != 'tid']

# One value-frequency bar chart per attribute (row id excluded), laid out side by side
charts = []
for col in plot_columns:
    bars = alt.Chart(pdf).mark_bar()
    freq = alt.Y('count()', axis=alt.Axis(title='freq'))
    charts.append(bars.encode(x=alt.X(col), y=freq).properties(width=300, height=300))

alt.hconcat(*charts)

In [16]:
# Register the known error cells and count them per attribute.
# NOTE(review): this path uses ../bin/testdata while the other data files come from
# ../testdata — confirm which location is intended.
(spark.read
    .option("header", True)
    .csv("../bin/testdata/hospital_error_cells.csv")
    .createOrReplaceTempView("hospital_error_cells"))
errors_per_attribute = spark.table('hospital_error_cells').groupBy('attribute').count()
errors_per_attribute.show()

+----------------+-----+
|       attribute|count|
+----------------+-----+
|     MeasureCode|   29|
|  ProviderNumber|   28|
|    HospitalName|   24|
|      CountyName|   39|
|           State|   26|
|   HospitalOwner|   27|
|          Sample|   31|
|     PhoneNumber|   34|
|     MeasureName|   36|
|         ZipCode|   30|
|           Score|   23|
|    HospitalType|   32|
|EmergencyService|   27|
|       Condition|   32|
|        Stateavg|   27|
|        Address1|   31|
|            City|   33|
+----------------+-----+



In [17]:
from repair.model import RepairModel

# Attributes considered for repair (tid, Address2, Address3 and Sample are not listed)
target_columns = [
    'ProviderNumber', 'HospitalName', 'Address1', 'City', 'State', 'ZipCode',
    'CountyName', 'PhoneNumber', 'HospitalType', 'HospitalOwner', 'EmergencyService',
    'Condition', 'MeasureCode', 'MeasureName', 'Score', 'Stateavg',
]
model = (
    RepairModel()
    .setTableName('hospital')
    .setRowId('tid')
    .setDiscreteThreshold(100)
)
error_cells_df = spark.table('hospital_error_cells')

# NOTE(review): `_prepare_repair_base_cells` is a private RepairModel helper, and the meaning of
# the positional arguments 1000 and 20 is not visible here — confirm against the RepairModel source.
repair_base_df = model._prepare_repair_base_cells('hospital', error_cells_df, target_columns, 1000, 20)

In [18]:
import altair as alt

# Per-column value-frequency bar charts of the repair base cells (row id excluded)
pdf = repair_base_df.toPandas()
charts = [
    alt.Chart(pdf)
    .mark_bar()
    .encode(x=alt.X(c), y=alt.Y('count()', axis=alt.Axis(title='freq')))
    .properties(width=300, height=300)
    for c in pdf.columns
    if c != 'tid'
]

alt.hconcat(*charts)

In [19]:
# Attribute whose cells are modelled and repaired in the rest of this notebook
target = "ZipCode"

In [20]:
pdf = repair_base_df.toPandas()
is_missing = pdf[target].isna()

# Rows with a NULL target cell are the ones to predict; keep 'tid' to join results back later
X_test = pdf.loc[is_missing].drop(columns=[target]).reset_index(drop=True)

# The remaining rows form the training set (features exclude the row id and the target)
pdf = pdf.loc[~is_missing]
X = pdf.drop(columns=['tid', target]).reset_index(drop=True)
y = pdf[target].reset_index(drop=True)

In [21]:
import category_encoders as ce

# Fit the ordinal encoder on the training features only, then apply the same mapping to the
# rows to predict. NOTE(review): handle_unknown='impute' presumably maps categories unseen
# during fit to a fallback code — confirm for the installed category_encoders version.
se = ce.OrdinalEncoder(handle_unknown='impute')
X = se.fit_transform(X)
encoded_test = se.transform(X_test[X.columns]).copy(deep=True)

# Re-attach the row id in front of the encoded test features
X_test = pd.concat([X_test[['tid']], encoded_test], axis=1)

In [22]:
import altair as alt

cols = ['ProviderNumber', 'HospitalName', 'Address1', 'City', 'State', 'ZipCode', 'CountyName', 'PhoneNumber', 'HospitalType', 'HospitalOwner', 'EmergencyService', 'Condition', 'MeasureCode', 'MeasureName', 'Score', 'Sample', 'Stateavg']

# Integer-code the target labels. pd.factorize numbers values in order of first appearance,
# which is the same mapping as enumerate(y.unique()). It also avoids Series.replace's
# deprecated implicit object->int downcasting.
_y = pd.Series(pd.factorize(y)[0], index=y.index, name=y.name)
pdf = pd.concat([X, _y], axis=1)

# Scatter-plot matrix over all pairs of encoded attributes, coloured by target class.
# NOTE: 17 columns -> 17x17 panels of 200x200px; this renders slowly.
alt.Chart(pdf).mark_circle().encode(
    alt.X(alt.repeat("column"), type='quantitative'),
    alt.Y(alt.repeat("row"), type='quantitative'),
    color=f'{target}:N'
).properties(width=200, height=200).repeat(row=cols, column=cols)

In [23]:
import itertools

from minepy import MINE

# Feature columns only: the target was dropped from X when it was built
cols = [
    'ProviderNumber', 'HospitalName', 'Address1', 'City', 'State', 'ZipCode',
    'CountyName', 'PhoneNumber', 'HospitalType', 'HospitalOwner', 'EmergencyService',
    'Condition', 'MeasureCode', 'MeasureName', 'Score', 'Sample', 'Stateavg',
]
cols.remove(target)

mine = MINE(alpha=0.6, c=15, est="mic_approx")

# Maximal information coefficient for every unordered pair of feature columns
results = []
for pair in itertools.combinations(cols, 2):
    left, right = pair
    mine.compute_score(X[left], X[right])
    results.append((pair, mine.mic()))

# Three most strongly associated feature pairs
ranked = sorted(results, key=lambda entry: entry[1], reverse=True)
print(ranked[:3])

[(('ProviderNumber', 'Address1'), 0.9781291797686797), (('ProviderNumber', 'HospitalName'), 0.9772579845385087), (('HospitalName', 'PhoneNumber'), 0.9768402151203731)]


In [24]:
from minepy import MINE

# Integer-code the target (first-appearance order, same mapping as enumerate(y.unique())),
# using pd.factorize instead of Series.replace with its deprecated implicit downcasting.
_y = pd.Series(pd.factorize(y)[0], index=y.index, name=y.name)
cols = ['ProviderNumber', 'HospitalName', 'Address1', 'City', 'State', 'ZipCode', 'CountyName', 'PhoneNumber', 'HospitalType', 'HospitalOwner', 'EmergencyService', 'Condition', 'MeasureCode', 'MeasureName', 'Score', 'Sample', 'Stateavg']

mine = MINE(alpha=0.6, c=15, est="mic_approx")

# MIC between the encoded target and each feature column
results = []
for c in [c for c in cols if c != target]:
    mine.compute_score(_y, X[c])
    results.append(((target, c), mine.mic()))

# Three features most strongly associated with the target
print(sorted(results, key=lambda x: x[1], reverse=True)[0:3])

[(('ZipCode', 'ProviderNumber'), 0.9998602850914529), (('ZipCode', 'HospitalName'), 0.9998602850914529), (('ZipCode', 'Address1'), 0.9997515888915909)]


In [25]:
from sklearn.ensemble import RandomForestClassifier
from boruta import BorutaPy


def _new_rf():
    """Return a fresh random forest with a fixed seed so the scores below are reproducible."""
    return RandomForestClassifier(n_jobs=-1, max_depth=5, random_state=0)


# Replace missing feature values with a sentinel before fitting the forests
# (older scikit-learn forests reject NaN inputs)
_X = X.fillna(-255.0)

# NOTE: rf.score() below is accuracy on the training rows themselves (in-sample), not held-out accuracy
rf = _new_rf()
rf.fit(_X, y)
print('SCORE with ALL Features: %1.2f' % rf.score(_X, y))

# Boruta all-relevant feature selection; perc=80 compares against the 80th percentile of
# shadow-feature importances instead of the maximum
fs = BorutaPy(_new_rf(), n_estimators='auto', random_state=0, perc=80, two_step=False, max_iter=500)
fs.fit(_X.values, y.values)

selected = fs.support_  # boolean mask over _X.columns
print('Selected Features: %s' % ','.join(_X.columns[selected]))

# Refit on the selected features only
X_selected = _X[_X.columns[selected]]
rf = _new_rf()
rf.fit(X_selected, y)
print('SCORE with selected Features: %1.2f' % rf.score(X_selected, y))

SCORE with ALL Features: 0.93
Selected Features: ProviderNumber,HospitalName,Address1,City,State,CountyName,PhoneNumber,HospitalOwner,EmergencyService,Score,Sample,Stateavg
SCORE with selected Features: 0.95


In [26]:
# t-SNE: a non-linear 2-D embedding of the Boruta-selected features, coloured by target label
from sklearn.manifold import TSNE
import altair as alt

# NOTE(review): scikit-learn 1.5 renamed `n_iter` to `max_iter` (the old name was later
# removed); switch the keyword if this runs on a newer release.
tsne = TSNE(n_components=2, random_state=0, perplexity=300, n_iter=500)

# Embed only fully-populated rows, and take the labels of exactly those rows, so the colour
# column stays aligned with the embedding even when some rows are dropped.
complete_rows = X_selected.notna().all(axis=1)
embedding = tsne.fit_transform(X_selected[complete_rows])
print('KL divergence: {}'.format(tsne.kl_divergence_))

_X = pd.DataFrame({
    'tSNE-X': embedding[:, 0],
    'tSNE-Y': embedding[:, 1],
    target: y[complete_rows].to_numpy(),
})
alt.Chart(_X).mark_point().encode(x='tSNE-X', y='tSNE-Y', color=f'{target}:N').properties(width=600, height=400).interactive()

KL divergence: 0.1684424877166748


In [27]:
import json

from repair import train

# Hyper-parameter search options forwarded to train.build_model
params = {'hp.timeout': '3600', 'hp.no_progress_loss': '30'}
selected_columns = X.columns[selected]
(clf, score), _ = train.build_model(X[selected_columns], y, is_discrete=True, num_class=len(y.unique()), n_jobs=-1, opts=params)
print(f'Score: {score}')

# Predict a PMF over the known classes for every row to repair, then keep the top-k candidates
top_k = 3
classes = clf.classes_.tolist()  # identical for every row, so convert once
probs = clf.predict_proba(X_test[selected_columns])
# Each row's PMF is shipped to Spark as a JSON string and parsed with an explicit schema below
pmf = [json.dumps({"classes": classes, "probs": p.tolist()}) for p in probs]
df = spark.createDataFrame(pd.DataFrame({'tid': X_test['tid'], 'pmf': pd.Series(pmf)}))
df = df.selectExpr('tid', 'from_json(pmf, "classes array<string>, probs array<double>") pmf')
df = df.selectExpr('tid', 'arrays_zip(pmf.classes, pmf.probs) pmf')

# Sort (class, prob) pairs by descending prob. array_sort's comparator must return a
# negative/zero/positive value consistently, so equal probabilities compare as 0.
sort_desc = ('(left, right) -> case when left.`1` > right.`1` then -1 '
             'when left.`1` < right.`1` then 1 else 0 end')
df = df.selectExpr('tid', f'slice(array_sort(pmf, {sort_desc}), 1, {top_k}) top_k_pmf')

# The most probable class becomes the repaired value
df = df.selectExpr('tid', f'top_k_pmf[0].`0` `{target}`', 'top_k_pmf')
predicted = df.toPandas()

Score: 0.98989898989899


                                                                                

In [28]:
# Register the ground-truth (clean) cells
(spark.read
    .option("header", True)
    .csv("../testdata/hospital_clean.csv")
    .createOrReplaceTempView("hospital_clean"))

# Correct values for the target attribute only
pdf_clean = (spark.table('hospital_clean')
    .where(f'attribute = "{target}"')
    .selectExpr('tid', 'correct_val')
    .toPandas())

# Inner-join predictions with the ground truth on tid and flag exact matches
result = pd.merge(predicted, pdf_clean, on='tid')
result = result.assign(is_correct=result[target] == result['correct_val'])
pd.set_option("display.max_colwidth", 300)
result

Unnamed: 0,tid,ZipCode,top_k_pmf,correct_val,is_correct
0,44,35957,"[(35957, 0.9643642190152275), (35976, 0.0037511336570521052), (35631, 0.0019414713540415861)]",35957,True
1,63,35957,"[(35957, 0.958682814509401), (35976, 0.005340659019176744), (35631, 0.0022468494683604567)]",35957,True
2,70,35631,"[(35631, 0.9717431037727254), (35150, 0.0027634746138696944), (35205, 0.0015281527144411398)]",35631,True
3,93,35631,"[(35631, 0.9303228849611292), (36854, 0.004438165021521895), (35555, 0.0030274421083783926)]",35631,True
4,137,36049,"[(36049, 0.9814816339887623), (99559, 0.0012441618678298658), (36784, 0.0009656955490383377)]",36049,True
5,139,36049,"[(36049, 0.9815205895250855), (99559, 0.00124421124928659), (36784, 0.0009707648218641633)]",36049,True
6,149,35640,"[(35640, 0.9718736117531187), (35609, 0.0012987548229854943), (36201, 0.0011960965557767534)]",35640,True
7,157,35640,"[(35640, 0.9656177108168658), (35968, 0.0027037542700482854), (35609, 0.0014407535007517661)]",35640,True
8,196,35235,"[(35235, 0.926993149109579), (35233, 0.014271018353715867), (35205, 0.012640523308065778)]",35235,True
9,232,35968,"[(35968, 0.9785267950693758), (35007, 0.003499732991825204), (36116, 0.001141813077218991)]",35968,True


In [29]:
# Fraction of repaired cells whose predicted value matches the ground truth
n_correct = int(result['is_correct'].sum())
print('Accuracy: {}'.format(n_correct / len(result)))

Accuracy: 1.0
