In [1]:
import numpy as np
import pandas as pd

from pyspark.sql import *
from pyspark.sql.types import *
from pyspark.sql import functions as f

# Create (or reuse) the Hive-enabled session used by every cell below.
# The driver memory key must be spelled exactly 'spark.driver.memory':
# Spark accepts arbitrary config keys, so a misspelled key does not raise
# and the setting is simply never applied.
spark = (
    SparkSession.builder
    .config('spark.driver.memory', '8g')
    .enableHiveSupport()
    .getOrCreate()
)

NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark classes ahead of assembly.
22/01/07 17:00:12 WARN Utils: Your hostname, maropus-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.3.3 instead (on interface en0)
22/01/07 17:00:12 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/01/07 17:00:12 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
def display(pdf):
    """Return an Altair scatter chart of the ``X`` and ``Y`` fields of ``pdf``.

    ``pdf`` is handed to ``alt.Chart`` as-is (the notebook passes a pandas
    DataFrame). Points are drawn as circles and both axes are encoded as
    quantitative.
    """
    import altair as alt

    x_axis = alt.X("X", type='quantitative')
    y_axis = alt.Y("Y", type='quantitative')
    scatter = alt.Chart(pdf).mark_circle()
    return scatter.encode(x_axis, y_axis)

def display_with_class(pdf):
    """Return an Altair scatter chart of ``X`` vs ``Y`` colored by field ``c``.

    ``pdf`` is handed to ``alt.Chart`` as-is (the notebook passes a pandas
    DataFrame). Points are drawn as circles, both axes are quantitative, and
    the color channel uses ``c`` as a nominal field (``'c:N'``).
    """
    import altair as alt

    x_axis = alt.X("X", type='quantitative')
    y_axis = alt.Y("Y", type='quantitative')
    scatter = alt.Chart(pdf).mark_circle()
    return scatter.encode(x_axis, y_axis, color='c:N')

In [3]:
import numpy as np
import pandas as pd

num_data = 100000

# Parameters of the 2-D Gaussian that the synthetic points are drawn from.
mean = np.array([20.0, 40.0])
cov = np.array([[30.0, 10.0], [10.0, 100.0]])

# Use a seeded Generator instead of the unseeded global np.random state so
# that re-running the notebook regenerates exactly the same points.
rng = np.random.default_rng(42)
data = rng.multivariate_normal(mean, cov, size=num_data)

# Load the points into Spark and attach a row identifier.
# monotonically_increasing_id() yields unique, increasing ids; per the Spark
# documentation they are not guaranteed to be consecutive.
orig_df = (
    spark.createDataFrame(pd.DataFrame(data, columns=['X', 'Y']))
    .withColumn('uid', f.expr('monotonically_increasing_id()'))
    .selectExpr('uid', 'X', 'Y')
)
orig_df.show()

[Stage 0:>                                                          (0 + 1) / 1]

+---+------------------+------------------+
|uid|                 X|                 Y|
+---+------------------+------------------+
|  0|17.543388196101606| 32.50755265068612|
|  1|13.073828699686494| 41.89754991721496|
|  2| 21.74682936358031|29.745709194488285|
|  3|28.921057167736773|53.387357263234115|
|  4|15.778109509327795| 43.77650218791468|
|  5|16.979443829731753| 44.52423933921543|
|  6|19.230007902133103| 31.23080272655929|
|  7|20.095231856946945| 32.89480827058366|
|  8|16.794984085093102|36.376262085817636|
|  9|23.179199541060928|40.261420510591144|
| 10|18.303158090764246|43.596051469265824|
| 11| 13.67300055729404| 36.91341691868316|
| 12|21.287650476524387| 51.87615285319929|
| 13| 24.25560643628013| 51.32024220202169|
| 14|  18.7382341643601| 47.71909548564216|
| 15|23.224830889754585| 53.30655756954307|
| 16| 16.57116461598929| 55.73341364362252|
| 17|11.526821153235327| 48.05389702459854|
| 18|23.535779015301483|28.285447001803476|
| 19| 17.70215424593329|34.59978

                                                                                

In [6]:
# Plot a random subset of roughly 1000 points; `fraction` is the share of
# rows to keep, so 1000 / num_data targets about 1000 rows.
display(
    orig_df
    .sample(fraction=1000 / num_data)
    .toPandas()
)

In [34]:
from sklearn.cluster import AgglomerativeClustering

num_cluster = 12
num_train = 4000

# Pull a training subset to the driver. `fraction` is a sampling rate, so the
# result has approximately (not exactly) num_train rows. The fixed seed makes
# the subset, and therefore the clustering below, repeatable across re-runs
# (assuming the input DataFrame is partitioned the same way).
X = (
    orig_df
    .sample(fraction=num_train / num_data, seed=42)
    .selectExpr('X', 'Y')
    .toPandas()
)

# Hierarchical clustering of the sampled points into num_cluster groups.
cls = AgglomerativeClustering(n_clusters=num_cluster).fit(X)

# Keep X untouched (it is reused to train the classifier); label a copy.
pdf = X.copy(deep=True)
pdf['c'] = cls.labels_
pdf['c'].value_counts()

0     788
6     618
1     423
4     423
2     419
8     416
3     215
7     214
10    180
11    149
5     104
9      68
Name: c, dtype: int64

In [12]:
display_with_class(pdf)

In [14]:
from sklearn.neighbors import KNeighborsClassifier
# Train a k-nearest-neighbours classifier (k=8) on the sampled points, using
# the cluster labels from the agglomerative clustering as targets.
clf = KNeighborsClassifier(n_neighbors=8)
clf.fit(X=X, y=cls.labels_)

In [35]:
# Wrap the fitted classifier in a Spark broadcast variable; the UDF below
# reads it back through `.value`.
broadcasted_clf = spark.sparkContext.broadcast(clf)


@f.pandas_udf('int')
def assign_group_id(X: pd.Series, Y: pd.Series) -> pd.Series:
    """Predict a cluster id for every (X, Y) pair in the batch.

    The two input series are combined into a frame with columns ``X`` and
    ``Y`` and passed to the broadcast classifier's ``predict``; the
    predictions are returned as a new series.
    """
    model = broadcasted_clf.value
    features = pd.DataFrame(dict(X=X, Y=Y))
    predicted = model.predict(features)
    return pd.Series(predicted)

In [44]:
# Label every row of the full data set with the cluster id predicted by the
# classifier UDF, stored in a new column 'c'.
df = orig_df.withColumn(
    'c',
    assign_group_id(orig_df['X'], orig_df['Y']),
)
df.show()

[Stage 36:>                                                         (0 + 1) / 1]

+---+------------------+------------------+---+
|uid|                 X|                 Y|  c|
+---+------------------+------------------+---+
|  0|17.543388196101606| 32.50755265068612|  8|
|  1|13.073828699686494| 41.89754991721496|  0|
|  2| 21.74682936358031|29.745709194488285|  4|
|  3|28.921057167736773|53.387357263234115|  9|
|  4|15.778109509327795| 43.77650218791468| 10|
|  5|16.979443829731753| 44.52423933921543|  0|
|  6|19.230007902133103| 31.23080272655929|  4|
|  7|20.095231856946945| 32.89480827058366|  8|
|  8|16.794984085093102|36.376262085817636|  8|
|  9|23.179199541060928|40.261420510591144|  5|
| 10|18.303158090764246|43.596051469265824| 10|
| 11| 13.67300055729404| 36.91341691868316|  2|
| 12|21.287650476524387| 51.87615285319929|  6|
| 13| 24.25560643628013| 51.32024220202169|  6|
| 14|  18.7382341643601| 47.71909548564216|  0|
| 15|23.224830889754585| 53.30655756954307|  6|
| 16| 16.57116461598929| 55.73341364362252|  3|
| 17|11.526821153235327| 48.053897024598

                                                                                

In [38]:
# Plot a random subset of roughly 1000 labelled points, colored by cluster.
display_with_class(
    df
    .sample(fraction=1000 / num_data)
    .toPandas()
)

                                                                                

In [43]:
# Number of rows assigned to each cluster id.
(
    df
    .groupBy('c')
    .count()
    .show()
)

[Stage 33:>                                                         (0 + 8) / 8]

+---+-----+
|  c|count|
+---+-----+
|  1| 7816|
|  6| 9888|
|  3| 3891|
|  5|11548|
|  9| 3190|
|  4|12278|
|  8|12653|
|  7| 3959|
| 10| 7856|
| 11| 4223|
|  2|10790|
|  0|11908|
+---+-----+



                                                                                

In [47]:
def do_something(pdf: pd.DataFrame) -> pd.DataFrame:
    """Process all rows of one cluster as a single pandas DataFrame.

    Placeholder implementation: the input frame is returned unchanged. The
    returned frame must match the schema passed to ``applyInPandas`` below.
    """
    # Do something here...
    return pdf


# applyInPandas is the Spark 3.x replacement for the deprecated combination of
# pandas_udf(..., PandasUDFType.GROUPED_MAP) and GroupedData.apply: it takes a
# plain Python function plus the output schema.
(
    df.groupBy('c')
    .applyInPandas(do_something, schema='uid long, X double, Y double, c int')
    .show()
)



+---+-------------------+------------------+---+
|uid|                  X|                 Y|  c|
+---+-------------------+------------------+---+
| 23| 22.348426781881518|21.809423359401084|  1|
| 33|-1.8273907661046636|17.230763238875507|  1|
| 38| 17.917783148279305|23.955451541126543|  1|
| 64| 13.922828655772381| 22.76557479398088|  1|
| 69|  24.89322313556103| 23.61690326855955|  1|
| 98|  19.61168827442075|19.600320948688047|  1|
|142|  16.16319100791148|14.203552332460617|  1|
|175|  16.45040882397412|  20.6239800929125|  1|
|184| 25.423859432519425|21.902722101360347|  1|
|215|  9.844827607909565|25.918887241944567|  1|
|220|  7.630648935424064|25.057776182481796|  1|
|224|  23.56437776588308|24.382588842371085|  1|
|239| 25.602478427836306|21.817590769884664|  1|
|252| 23.119010116656696|25.592709739328143|  1|
|265|  7.978230127137223|26.867342345002946|  1|
|276| 15.346978228744577| 16.20254709184993|  1|
|285|  22.23504955131264| 25.20706683485073|  1|
|294| 19.71654782149

                                                                                