In [1]:
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler


def objective_function(model, X, y):
    """Score a linear threshold classifier on (X, y); lower is better.

    ``model`` holds the per-feature weights followed by a trailing bias term.
    Rows whose decision value ``X . weights + bias`` is >= 0 are predicted as
    1, all others as 0.  The result is the mean of the squared differences
    between ``y`` and those predictions (for 0/1 labels this is the
    misclassification rate).
    """
    weights = model[:-1]
    bias = model[-1]
    decision = np.dot(X, weights) + bias
    predicted = (decision >= 0).astype(int)
    residuals = np.subtract(y, predicted)
    return np.mean(residuals ** 2)

def bat(X, y, *, num_bats=70, num_gens=100, lower=-5, upper=5, q_min=0,
        q_max=None, pulse_rate=0.1, loudness=0.9, random_state=None,
        objective=None):
    """Search for linear-classifier parameters with the Bat Algorithm.

    Each bat is a candidate vector of length ``X.shape[1] + 1``: per-feature
    weights followed by a bias term (the layout used by ``objective_function``
    and ``predict``).  The swarm evolves for ``num_gens`` generations and the
    best vector ever evaluated is returned.

    Args:
        X: 2-D feature array; its second dimension sets the vector length.
        y: target array, passed unchanged to ``objective``.
        num_bats: swarm size.
        num_gens: number of generations.
        lower, upper: bounds of the uniform initialisation of positions.
            Positions are not clipped to these bounds afterwards.
        q_min, q_max: frequency range; ``q_max`` defaults to
            ``(upper - lower) / num_bats``.
        pulse_rate: probability of replacing a move by a small random walk
            around the current global best.
        loudness: probability of accepting a move that improves a bat.
        random_state: None (use the global ``np.random`` state, as the
            original code did), an int seed, or a ``np.random.RandomState``.
        objective: callable ``objective(model, X, y) -> float`` to minimise;
            defaults to ``objective_function``.

    Returns:
        1-D ndarray of length ``X.shape[1] + 1`` with the best parameters
        found (a fresh copy, not a view into the swarm).
    """
    if objective is None:
        objective = objective_function
    if random_state is None:
        # Module-level functions draw from NumPy's global RandomState.
        rng = np.random
    elif isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)
    if q_max is None:
        q_max = (upper - lower) / num_bats

    dim = X.shape[1] + 1
    positions = rng.uniform(lower, upper, (num_bats, dim))
    velocities = np.zeros((num_bats, dim))
    fitness = np.apply_along_axis(objective, 1, positions, X, y)

    best_idx = np.argmin(fitness)
    # Copy: a plain row index returns a view, which would silently change if
    # positions[best_idx] were overwritten in place below.
    gbest_position = positions[best_idx].copy()
    gbest_fitness = np.min(fitness)

    for _ in range(num_gens):
        for i in range(num_bats):
            # Frequency-tuned move relative to the global best.
            freq = q_min + (q_max - q_min) * rng.rand()
            velocities[i] += (positions[i] - gbest_position) * freq
            candidate = positions[i] + velocities[i]

            # NOTE(review): the canonical Bat Algorithm (Yang, 2010) appears
            # to do this local walk when rand > pulse_rate; the original `<`
            # is kept, so the walk happens with probability pulse_rate.
            if rng.rand() < pulse_rate:
                candidate = gbest_position + 0.001 * rng.randn(dim)

            cand_fitness = objective(candidate, X, y)

            # Improving moves are accepted only with probability `loudness`;
            # `and` short-circuits, so the extra draw happens only on
            # improvement.
            if cand_fitness < fitness[i] and rng.rand() < loudness:
                positions[i] = candidate
                fitness[i] = cand_fitness
            # The global best tracks every evaluated candidate, accepted or not.
            if cand_fitness < gbest_fitness:
                gbest_position = candidate.copy()
                gbest_fitness = cand_fitness

    return gbest_position

def predict(model, X):
    """Classify the rows of ``X`` with a linear threshold model.

    ``model`` is a sequence whose last element is the bias and whose leading
    elements are the per-feature weights.  Returns an int array holding 1
    where ``X . weights + bias >= 0`` and 0 elsewhere.
    """
    weights, bias = model[:-1], model[-1]
    decision = np.dot(X, weights) + bias
    is_positive = decision >= 0
    return is_positive.astype(int)

def run(file_name):
    """Train per-partition bat-algorithm classifiers with Spark and print accuracy.

    Reads a CSV with a header row (last column = label, every other column =
    a feature), label-encodes the target, standardises the features, runs
    ``bat`` independently on each RDD partition, averages the per-partition
    parameter vectors into a single model and prints that model's accuracy.

    ``predict`` only emits 0/1, so this presumably expects a binary target.

    NOTE(review): accuracy is measured on the same rows the model was fitted
    on (no hold-out split), so the printed figure is a training-set score.

    Args:
        file_name: path to the CSV file, as accepted by ``spark.read.csv``.

    Raises:
        ValueError: if no partition produced a model (e.g. an empty file).
    """
    spark = (
        SparkSession.builder
        .appName("Bat Algorithm with Spark")
        .getOrCreate()
    )
    # Always release the Spark session, even if reading/training fails.
    try:
        sc = spark.sparkContext

        df = spark.read.csv(file_name, header=True, inferSchema=True)
        feature_cols, label_col = df.columns[:-1], df.columns[-1]
        X = np.array(df.select(feature_cols).collect())
        y = np.array(df.select(label_col).collect()).flatten()

        y = LabelEncoder().fit_transform(y)  # labels -> integer codes
        X = StandardScaler().fit_transform(X)  # zero mean, unit variance

        # RDD of (feature_row, label) pairs.
        data_rdd = sc.parallelize(list(zip(X, y)))

        def bat_partition(partition):
            """Run the bat algorithm on one partition; yield nothing if empty."""
            rows = list(partition)
            if not rows:
                return []
            X_part, y_part = zip(*rows)
            return [bat(np.array(X_part), np.array(y_part))]

        # One parameter vector per non-empty partition.
        weights = data_rdd.mapPartitions(bat_partition).collect()
        if not weights:
            raise ValueError(f"no model could be trained from {file_name!r}")
        # Element-wise average of the per-partition parameter vectors.
        model = np.mean(weights, axis=0)

        y_pred = predict(model, X)
        accuracy = accuracy_score(y, y_pred)
        print(f'Accuracy: {accuracy * 100:.2f}%')
    finally:
        spark.stop()

if __name__ == "__main__":
    # Script entry point: train and evaluate on Behavior.csv, resolved
    # relative to the current working directory.
    dataset_path = 'Behavior.csv'
    run(dataset_path)


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/21 11:33:58 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/06/21 11:34:19 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
[Stage 4:>                                                          (0 + 4) / 4]

Accuracy: 96.20%


                                                                                