In [6]:
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession

##SPLITS PARTICLES INTO PARTITIONS
class FireflyAlgorithm:
    def __init__(self, n_fireflies=56, max_iter=20, alpha=0.3, beta0=1, gamma=0.04):
        self.n_fireflies = n_fireflies
        self.max_iter = max_iter
        self.alpha = alpha
        self.beta0 = beta0
        self.gamma = gamma
        self.lb = 0 
        self.ub = 100
        self.centroids = {}
        self.points = []

    def objective_function(self, x):
        return np.sum(np.linalg.norm(self.points-x, axis = 1))

    def find_center(self, fireflies):
        global_best = open("global_best", "w+")
        #initialize fireflies
        fireflies = list(map(lambda firefly: list(firefly), fireflies))
        n_fireflies = len(fireflies)
        #for i in range (n_fireflies):
            #print (f"Firefly {i} at: {fireflies[i]}")
        dim = len(fireflies[0])
        
        fitness = np.apply_along_axis(self.objective_function, 1, fireflies)
        
        
        #set arbitrary global best
        best_firefly = fireflies[0]
        best_fitness = fitness[0]
        
        for k in range(self.max_iter):
            k_alpha = self.alpha * (1-k/self.max_iter) # decreases alpha over time
           
            for i in range(n_fireflies):
                for j in range(n_fireflies):
                    ##Here check file for new best
                    lines = global_best.readlines()
                    if len(lines)>0:
                        lines = [0].split(',')
                        firefly = lines[:-1]
                        fitness = lines[-1]
                    if fitness < best_fitness:
                        best_fitness = fitness
                        best_firefly = firefly
                        fireflies[i] = firefly
                    if fitness[j] < fitness[i]:
                        #move firefly
                        r = np.linalg.norm(np.subtract(fireflies[i], fireflies[j])) #distance
                        beta = self.beta0 * np.exp(-self.gamma * r**2) #attractiveness
                        random_factor = k_alpha * (np.random.rand(dim) - 0.5) #randomness
                        #moves firefly based on equation 
                        fireflies[i] += beta * (np.subtract(fireflies[j],fireflies[i])) + random_factor
                        fireflies[i] = np.clip(fireflies[i], self.lb, self.ub) # keeps new loc within range

                        #update fitness
                        fitness[i] = self.objective_function(fireflies[i])
                        #update new best
                    
                        if fitness[i] < best_fitness:
                            #update global best
                            best_firefly = fireflies[i]
                            best_fitness = fitness[i]
                            
                            #update file
                            lines = global_best.readlines()
                            lines[0] = f"{best_firefly},{best_fitness}"
                            print(f"writing: {lines}")
                            file.writelines(lines)
        return best_firefly, best_fitness

    #returns string of classification
    def classify(self, row):
        distances = {}
        for key, points in self.centroids.items():
            coord = np.array(row)
            distances[key]= np.linalg.norm(points-coord)
        cls = min(distances, key = distances.get)
        return cls
    

    def run(self, file_name):
        # Create a SparkSession
        spark = SparkSession.builder \
            .appName("Firefly Algorithm with Spark") \
            .getOrCreate()

        sc = spark.sparkContext
        num_cores = sc.defaultParallelism  #Determine the number of available cores
        self.n_fireflies = max(self.n_fireflies, num_cores) 
        

        # Read the dataset from CSV file into a Spark DataFrame
        df = spark.read.csv(file_name, header=True, inferSchema=True)
        dim = len(df.columns)-1
        
        #split into training and test data
        train, test = df.randomSplit([0.8, 0.2], seed=12345)

        #initialize fireflies
        fireflies = np.random.uniform(self.lb, self.ub, (self.n_fireflies, dim))
        fireflies_rdd = sc.parallelize(fireflies, numSlices=num_cores)
        class_column = train.columns[-1]
        classes = train.select(class_column).distinct().collect()
        
        
        for cls in classes:
            cls = cls[class_column]
            data = train.filter(df[class_column] == cls).drop(class_column).collect()
            
            points = []
            for row in data:
                points.append(list(row))
            self.points = points
            center = fireflies_rdd.mapPartitions(lambda fireflies: [self.find_center(fireflies)]).collect()
            #clean appearance
            center = list(map(lambda point: list(point), center))
            
            #TODO replace code with collect?
            best_centroid = center[0][0]
            best_fitness = center[0][1]
            
            for centroid, fitness in center:
                if fitness < best_fitness:
                    best_fitness = fitness
                    best_centroid = centroid
            self.centroids[cls] = best_centroid
            print (f"Center of class {cls}: {center}")
    
        #test
        accuracy = 0
        count = 0
        for row in test.collect():
            row = list(row)
            cls = self.classify(row[:-1])
            if cls == row[-1]:
                accuracy +=1
            count +=1
        print("Accuracy: ", accuracy/count)
        # Stop the SparkSession
        spark.stop()

if __name__ == "__main__":
    fa = FireflyAlgorithm()
    fa.run("4Cluster2D.csv")

24/06/06 09:55:53 ERROR Executor: Exception in task 1.0 in stage 41.0 (TID 51)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/dahlink/anaconda3/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1247, in main
    process()
  File "/Users/dahlink/anaconda3/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1237, in process
    out_iter = func(split_index, iterator)
  File "/Users/dahlink/anaconda3/lib/python3.10/site-packages/pyspark/rdd.py", line 840, in func
    return f(iterator)
  File "/var/folders/gt/yzvtk3y13fb9m4fm78tk0n2dzbvzg_/T/ipykernel_9815/4076006303.py", line 120, in <lambda>
  File "/var/folders/gt/yzvtk3y13fb9m4fm78tk0n2dzbvzg_/T/ipykernel_9815/4076006303.py", line 48, in find_center
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handle

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 41.0 failed 1 times, most recent failure: Lost task 1.0 in stage 41.0 (TID 51) (dyn134-200.wireless-1725.ndsu.nodak.edu executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/dahlink/anaconda3/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1247, in main
    process()
  File "/Users/dahlink/anaconda3/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1237, in process
    out_iter = func(split_index, iterator)
  File "/Users/dahlink/anaconda3/lib/python3.10/site-packages/pyspark/rdd.py", line 840, in func
    return f(iterator)
  File "/var/folders/gt/yzvtk3y13fb9m4fm78tk0n2dzbvzg_/T/ipykernel_9815/4076006303.py", line 120, in <lambda>
  File "/var/folders/gt/yzvtk3y13fb9m4fm78tk0n2dzbvzg_/T/ipykernel_9815/4076006303.py", line 48, in find_center
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:784)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:766)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1049)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2438)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
	at java.base/java.lang.Thread.run(Thread.java:832)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2419)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2438)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2463)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1049)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1048)
	at org.apache.spark.api.python.PythonRDD$.collectAndServe(PythonRDD.scala:195)
	at org.apache.spark.api.python.PythonRDD.collectAndServe(PythonRDD.scala)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:564)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:832)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/dahlink/anaconda3/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1247, in main
    process()
  File "/Users/dahlink/anaconda3/lib/python3.10/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1237, in process
    out_iter = func(split_index, iterator)
  File "/Users/dahlink/anaconda3/lib/python3.10/site-packages/pyspark/rdd.py", line 840, in func
    return f(iterator)
  File "/var/folders/gt/yzvtk3y13fb9m4fm78tk0n2dzbvzg_/T/ipykernel_9815/4076006303.py", line 120, in <lambda>
  File "/var/folders/gt/yzvtk3y13fb9m4fm78tk0n2dzbvzg_/T/ipykernel_9815/4076006303.py", line 48, in find_center
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:784)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:766)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.rdd.RDD.$anonfun$collect$2(RDD.scala:1049)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2438)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
	... 1 more
