## Minibatch k-means

In [2]:
from sklearn.datasets import fetch_kddcup99

In [3]:
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, udf, array, min as smin, lit, count, isnan, when, sum as ssum
from pyspark.sql.types import IntegerType, FloatType
from pyspark.ml.feature import MinMaxScaler, VectorAssembler

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [4]:
spark = SparkSession.builder \
    .master("spark://spark-master:7077") \
    .appName("k-meaner") \
    .config("spark.executor.memory", "1024m") \
    .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
    .config("spark.sql.execution.arrow.pyspark.fallback.enabled", "false") \
    .getOrCreate()

sc = spark.sparkContext
sc.setLogLevel("ERROR")

#clear old data if rerunning
spark.catalog.clearCache() 
for (id, rdd) in sc._jsc.getPersistentRDDs().items():
    rdd.unpersist()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/07/11 09:22:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


### Data preprocessing

In [5]:
kdd = spark.createDataFrame(fetch_kddcup99(as_frame=True)["frame"])
kdd = kdd.drop("protocol_type")
kdd = kdd.drop("service") 
kdd = kdd.drop("flag")
kdd.createOrReplaceTempView("kdd_table")

for c in kdd.columns:
    try:
        limits = spark.sql("SELECT min(" + c + "), max(" + c + ") FROM kdd_table").collect()
        mn, mx = limits[0][0], limits[0][1]
        
        if mn==mx:
            continue
        
        
        kdd = kdd.withColumn(c, (col(c)-mn)/(mx-mn) )
    except Exception as e:
        print(e)

kdd = kdd.persist()

  [(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
                                                                                

unsupported operand type(s) for -: 'bytearray' and 'bytearray'


### Clusters

In [11]:
res = spark.sql("SELECT labels, count(1) FROM kdd_table GROUP BY labels").collect()
attacks = np.array([r[0] for r in res])
counts = np.array([r[1] for r in res])
attack_sort = np.argsort(counts)
attacks = attacks[attack_sort]
counts = counts[attack_sort]

  attacks = np.array([r[0] for r in res])


### Lloyd algorithm

In [12]:
index_udf = udf(lambda row:attacks.tolist().index(row) )
kdd = kdd.withColumn("labels", index_udf(col("labels")).cast("int"))

ignored_cols = ["labels", "centr", "cost", "p"]

kdd = kdd.withColumn("centr", lit(-1))
kdd = kdd.withColumn("cost", lit(0.))
kdd = kdd.withColumn("p", lit(0.))
ncols = len(kdd.columns) - len(ignored_cols)

In [7]:
def dist(x,y):
    try:
        #return (((np.array(x)-y)**2).sum())**0.5
        return np.linalg.norm(np.array(x)-np.array(y))
    except Exception as e:
        print(e)
        return ncols

def argcomp(comp,func, arr, *params):
    res = [func(arr[i], *params) for i in range(len(arr))]
    #print(res)
    return res.index(comp(res))

#udist=udf(dist, FloatType())

#distance_udf = udf(lambda x,y:  np.linalg.norm(x-y), FloatType())

In [8]:
###################################### actual k-means 
def kmeans(data, centers, max_iter=10, weighted=False, local_centr=False):
    count = 0
    if weighted and "w" not in ignored_cols:
        ignored_cols.append("w")
    while True:
        if local_centr:
            cc = centers
            local_centr = False
        else:
            cc = centers.collect()
        #print(cc)
        
        argmindist_udf = udf(lambda row: argcomp(min,dist,cc,row), IntegerType())
        
        data = data.withColumn("centr", argmindist_udf(array([c for c in data.columns if c not in ignored_cols])))
        
        if not weighted:
            newcenters = data.groupBy("centr").mean().select(*[col("avg("+c+")") for c in data.columns if c not in ignored_cols])
        else:
            #newcenters = data.groupBy("centr").agg(ssum(array([c for c in data.columns if c not in ignored_cols]) * col("w"))/ssum(col("w")))
            newcenters = data.rdd.map(
                lambda x: (x["centr"],(x[[c for c in data.columns if c not in ignored_cols]],x["w"]))
            ).reduceByKey(
                lambda x,y: ((x[0]*x[1]+y[0]*y[1])/x[1]+y[1], x[1]+y[1])
            ).map(lambda x: x[1][0]).collect()
        
        centers = newcenters
        
        count+=1
        if count>max_iter:#(newcenters-centers).mean() < 0.01:
            break
            
    
    return data, centers

### K-means mini batch
Since the idea is to never directly look at the whole dataset we will draw a random sample from the data to act as the starting centers. We will then compute new center using k means on a mini batch of the original data, then the new centers for the following iteration will be computed as old_c+l.rate * new_c until convergence or max iterations

In [18]:
def mini_b(data,batch_size,l_rate,max_it,n_cent):
    #first centroids initialization
    initialCentroids = data.sample(batch_size).limit(n_cent).select(*[col(c) for c in data.columns if c not in ignored_cols])
    initialCentroids = np.array(initialCentroids.collect())
    
    #centroids update
    
    for _ in range(max_it):
        miniBatch = data.sample(batch_size)
        
        _,newCentroids = kmeans(miniBatch,initialCentroids, local_centr=True, max_iter=2)
        
        newCentroids = np.array(newCentroids.collect())
        
        #sort centroids to minimize reciprocal distance by finding the optimal permutation on newCentroids
        
        distances = np.linalg.norm(initialCentroids[:, np.newaxis] - newCentroids, axis=2)
        permutation = np.argmin(np.sum(distances, axis=0))
        newCentroids = newCentroids[permutation]
        
        initialCentroids = (1-l_rate) * initialCentroids + newCentroids * l_rate

    return initialCentroids

In [20]:
newcenters = mini_b(kdd,0.2,0.05,10,len(attacks))

                                                                                

In [21]:
newcenters

array([[2.12813540e-06, 3.88276581e-07, 6.35528231e-04, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 5.74969954e-05, 0.00000000e+00,
        9.99952876e-01, 1.13271891e-06, 1.80429783e-04, 0.00000000e+00,
        0.00000000e+00, 1.65407906e-05, 0.00000000e+00, 6.32633936e-05,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.47430719e-02,
        1.58971684e-02, 3.34553619e-04, 3.47966401e-04, 1.86421510e-04,
        1.96102369e-04, 9.99753859e-01, 4.92281411e-04, 2.60241204e-03,
        1.73229664e-01, 5.06785918e-01, 9.97211558e-01, 1.58312026e-03,
        2.41876338e-02, 1.38478492e-02, 4.76740852e-04, 4.07541513e-04,
        5.61560255e-04, 3.60326099e-04],
       [2.12813540e-06, 3.83959030e-07, 6.24843692e-04, 0.00000000e+00,
        0.00000000e+00, 0.00000000e+00, 5.74969954e-05, 0.00000000e+00,
        9.99952876e-01, 1.13271891e-06, 1.80429783e-04, 0.00000000e+00,
        0.00000000e+00, 1.65407906e-05, 0.00000000e+00, 6.32633936e-05,
        0.00000000e+00,