In [2]:
import numpy as np
import pandas as pd

from pyspark.sql import *
from pyspark.sql.types import *
from pyspark.sql import functions as f

# Create (or reuse) the Hive-enabled Spark session used by every cell below.
# The option key must be spelled `spark.driver.memory`; the previous spelling
# (`spark.deriver.memory`) is not a Spark setting, so the 8g request was never
# applied to the driver.
spark = SparkSession.builder \
    .config('spark.driver.memory', '8g') \
    .enableHiveSupport() \
    .getOrCreate()

In [3]:
def display(pdf):
    """Build an Altair circle-mark scatter chart of columns ``X`` vs ``Y``.

    Args:
        pdf: Data handed straight to ``alt.Chart``; the callers in this
            notebook pass pandas DataFrames with ``X`` and ``Y`` columns.

    Returns:
        The encoded Altair chart object (rendered by the notebook when it is
        the last expression of a cell).
    """
    import altair as alt

    scatter = alt.Chart(pdf).mark_circle()
    x_axis = alt.X("X", type='quantitative')
    y_axis = alt.Y("Y", type='quantitative')
    return scatter.encode(x_axis, y_axis)

def display_with_class(pdf):
    """Build an Altair scatter chart of ``X`` vs ``Y`` colored by column ``c``.

    Args:
        pdf: Data handed straight to ``alt.Chart``; the callers in this
            notebook pass pandas DataFrames with ``X``, ``Y`` and ``c``
            columns.

    Returns:
        The encoded Altair chart object, with ``c`` encoded as a nominal
        color channel (``'c:N'``).
    """
    import altair as alt

    scatter = alt.Chart(pdf).mark_circle()
    x_axis = alt.X("X", type='quantitative')
    y_axis = alt.Y("Y", type='quantitative')
    return scatter.encode(x_axis, y_axis, color='c:N')

In [4]:
import numpy as np
import pandas as pd

num_data = 100000
# Fixed seed so a fresh "Restart & Run All" regenerates the same synthetic points.
seed = 42

# Parameters of the 2-D Gaussian the synthetic (X, Y) points are drawn from.
mean = np.array([20.0, 40.0])
cov = np.array([[30.0, 10.0], [10.0, 100.0]])

# Use a local, seeded generator instead of the unseeded global NumPy RNG.
rng = np.random.default_rng(seed)
data = rng.multivariate_normal(mean, cov, size=num_data)

# Move the points into Spark and tag each row with a generated `uid` column.
orig_df = (
    spark.createDataFrame(pd.DataFrame(data, columns=['X', 'Y']))
    .withColumn('uid', f.monotonically_increasing_id())
    .select('uid', 'X', 'Y')
)
orig_df.show()

+---+------------------+------------------+
|uid|                 X|                 Y|
+---+------------------+------------------+
|  0|18.279142948878313|  41.1259355156916|
|  1|25.523427332943342|  43.8754414663049|
|  2|  24.2193153319294|21.444139723258886|
|  3|25.170100188357885| 35.72942970897056|
|  4|25.982318733782904|33.839888666004214|
|  5| 21.24485505406872| 40.09851525121294|
|  6|15.047132144325616|  39.8350182184421|
|  7| 13.28339189058167|28.512979520566482|
|  8| 22.99980931949321| 62.76967886252962|
|  9| 19.53868298751242|29.454117936527247|
| 10|21.593564153394034| 53.86814994876163|
| 11|20.369215762180893|26.111074466178746|
| 12| 8.538434107372675| 38.65465185328122|
| 13|14.147794345737683|31.913829892150922|
| 14|14.833223563853558| 45.91755179502277|
| 15|23.385993455854152| 44.48381163985997|
| 16| 18.53627470264322|  45.9689622161429|
| 17| 25.34510376339181|44.361551991827355|
| 18| 23.51955110122912|18.784575180662156|
| 19| 33.17962276561825| 54.9068

In [5]:
# Plot a random subset of about 1000 rows; `/` already yields a float here.
display(orig_df.sample(fraction=1000 / num_data).toPandas())

In [6]:
num_cluster = 12
num_train = 4000

# Pull a training subset of the (X, Y) points into pandas. Fraction-based
# sampling returns only approximately `num_train` rows. The explicit seed keeps
# the training set, and every cluster derived from it, stable across re-runs.
X = (
    orig_df
    .sample(fraction=num_train / num_data, seed=42)
    .select('X', 'Y')
    .toPandas()
)

In [15]:
from size_constrained_clustering import fcm, equal, minmax, shrinkage
# Fit the FCM model from size_constrained_clustering on the training sample.
cls = fcm.FCM(num_cluster)
cls.fit(X)

# Attach the fitted labels as column `c` on a copy, leaving `X` untouched.
pdf = X.assign(c=cls.labels_)

# Number of training points per cluster.
pdf['c'].value_counts()

7     458
5     418
6     399
8     374
0     368
9     341
4     316
3     307
11    306
1     234
10    225
2     200
Name: c, dtype: int64

In [16]:
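# Disabled alternative: cluster with scikit-learn's AgglomerativeClustering
# instead of FCM. Uncomment to produce `cls` / `pdf` that way.
# from sklearn.cluster import AgglomerativeClustering
# cls = AgglomerativeClustering(n_clusters=num_cluster).fit(X)
# pdf = X.copy(deep=True)
# pdf['c'] = cls.labels_
# pdf['c'].value_counts()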

In [17]:
# Scatter plot of the training sample, colored by the cluster label in column `c`.
display_with_class(pdf)

In [18]:
from sklearn.neighbors import KNeighborsClassifier
# Train an 8-nearest-neighbour classifier on the training sample, using the
# cluster labels as targets, so it can assign a cluster to any (X, Y) point.
clf = KNeighborsClassifier(n_neighbors=8)
clf.fit(X=X, y=cls.labels_)

KNeighborsClassifier(n_neighbors=8)

In [19]:
# Wrap the fitted classifier in a Spark broadcast variable; the UDF below
# reads it back through `.value`.
broadcasted_clf = spark.sparkContext.broadcast(clf)


@f.pandas_udf('int')
def assign_group_id(X: pd.Series, Y: pd.Series) -> pd.Series:
    """Predict a cluster id for each (X, Y) pair in a batch.

    Args:
        X: Batch of X coordinates.
        Y: Batch of Y coordinates, aligned with ``X``.

    Returns:
        A Series holding the broadcast classifier's prediction for every row.
    """
    # Column names match the ones the classifier was fitted on ('X', 'Y').
    features = pd.DataFrame({'X': X, 'Y': Y})
    model = broadcasted_clf.value
    predicted = model.predict(features)
    return pd.Series(predicted)

In [20]:
# Label every row of the full dataset with the classifier's cluster id `c`.
df = orig_df.withColumn('c', assign_group_id(orig_df['X'], orig_df['Y']))
df.show()

+---+------------------+------------------+---+
|uid|                 X|                 Y|  c|
+---+------------------+------------------+---+
|  0|18.279142948878313|  41.1259355156916|  0|
|  1|25.523427332943342|  43.8754414663049|  6|
|  2|  24.2193153319294|21.444139723258886|  2|
|  3|25.170100188357885| 35.72942970897056|  4|
|  4|25.982318733782904|33.839888666004214|  4|
|  5| 21.24485505406872| 40.09851525121294|  7|
|  6|15.047132144325616|  39.8350182184421|  0|
|  7| 13.28339189058167|28.512979520566482|  9|
|  8| 22.99980931949321| 62.76967886252962|  1|
|  9| 19.53868298751242|29.454117936527247|  8|
| 10|21.593564153394034| 53.86814994876163|  1|
| 11|20.369215762180893|26.111074466178746|  8|
| 12| 8.538434107372675| 38.65465185328122|  0|
| 13|14.147794345737683|31.913829892150922|  9|
| 14|14.833223563853558| 45.91755179502277|  0|
| 15|23.385993455854152| 44.48381163985997|  6|
| 16| 18.53627470264322|  45.9689622161429|  5|
| 17| 25.34510376339181|44.3615519918273

In [21]:
# Plot about 1000 labelled rows, colored by cluster; `/` already yields a float.
display_with_class(df.sample(fraction=1000 / num_data).toPandas())

In [22]:
# Rows per cluster across the full dataset.
df.groupBy('c').agg(f.count(f.lit(1)).alias('count')).show()

+---+-----+
|  c|count|
+---+-----+
|  1| 5717|
|  6| 9978|
|  3| 7602|
|  5|10615|
|  9| 8912|
|  4| 8252|
|  8| 9472|
|  7|11601|
| 10| 5835|
| 11| 8001|
|  2| 4304|
|  0| 9711|
+---+-----+



In [23]:
@f.pandas_udf(
    returnType='uid long, X double, Y double, c int',
    functionType=f.PandasUDFType.GROUPED_MAP,
)
def do_something(pdf: pd.DataFrame) -> pd.DataFrame:
    """Placeholder per-cluster step: receives one group's rows, returns them as is.

    Args:
        pdf: The rows of a single ``c`` group as a pandas DataFrame.

    Returns:
        The same DataFrame; the declared output schema is
        ``uid long, X double, Y double, c int``.
    """
    # Do something here...
    return pdf


# Run the grouped-map UDF once per cluster and show the combined result.
df.groupBy('c').apply(do_something).show()



+---+------------------+------------------+---+
|uid|                 X|                 Y|  c|
+---+------------------+------------------+---+
|  8| 22.99980931949321| 62.76967886252962|  1|
| 10|21.593564153394034| 53.86814994876163|  1|
| 30|23.225525454548723| 55.35163768790345|  1|
| 39|19.129975955141052| 59.06133988353858|  1|
| 80|29.013496332791036|61.580655001784116|  1|
| 90|23.575561427588205| 56.81122856275795|  1|
|147|16.923100560422487| 63.65137128158777|  1|
|155|19.369924272875167| 59.77281994464738|  1|
|164|20.701433928173874| 67.23899581264192|  1|
|172|27.484018898779656| 60.41225952472757|  1|
|198|22.752847946932143| 53.82022213774504|  1|
|201| 20.70926009483112|54.969343552249995|  1|
|227| 22.47895129870433| 67.25165887959957|  1|
|230|22.324978380403103|60.194894405834575|  1|
|237|33.765575676221715|58.885274826951864|  1|
|239|21.123854691251445| 55.03799906081855|  1|
|264|22.677490609754468| 59.76679615961652|  1|
|300|26.198905428868454|55.7175982789180