In [None]:
import os
import random

from pyspark.sql import SparkSession
from pyspark.sql.functions import split, when, lit, row_number, udf, col
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
from pyspark.sql.window import Window

In [None]:
# One SparkSession for the whole notebook; getOrCreate() returns the existing session on re-run.
spark = (
    SparkSession.builder
    .appName("BeeHive")
    .getOrCreate()
)

In [None]:
# Column names of the CSV in file order; the 'remove*' columns are dropped after loading.
headers = [
    'remove', 'Bee ID', 'remove1', 'DaughtersEfficiencyScore', 'remove2',
    'Father SIZE', 'Father TYPE', 'remove3', 'X', 'Y', 'Z',
]
# Columns parsed as doubles / integers; every other column is read as a string.
doubles = ['DaughtersEfficiencyScore', 'X', 'Y', 'Z']
integers = ['Father SIZE']

def struct_field(header, doubles, integers):
    """Return a StructField named ``header`` whose type is chosen by list membership.

    ``doubles`` is checked before ``integers``; a header in neither list is a string.
    """
    if header in doubles:
        field_type = DoubleType()
    elif header in integers:
        field_type = IntegerType()
    else:
        field_type = StringType()
    return StructField(header, field_type)

fields = list(map(lambda header: struct_field(header, doubles, integers), headers))
schema = StructType(fields)

In [None]:
# Input CSV location. Set the BEEHIVE_DATA_PATH environment variable to point elsewhere
# instead of editing this developer-specific absolute path.
DATA_PATH = os.environ.get(
    'BEEHIVE_DATA_PATH',
    '/Users/daniel/dev/duds/pep-data/src/jupyter/BeeHiveTestData.csv',
)

df = spark.read.schema(schema).csv(DATA_PATH)
df.show(5)

In [None]:
# Drop the placeholder columns (every column whose name contains 'remove').
cols_to_keep = list(filter(lambda name: 'remove' not in name, df.columns))
df = df.select(cols_to_keep)

df.show(5)

In [None]:
# 'Bee ID' is split on '_': the first part is the cycle, the second the ID within the cycle.
# Split once and reuse the resulting array column for both parts.
bee_id_parts = split(col('Bee ID'), '_')

# Cycle is cast to int so every ordering by it (the Continuous ID window, groupBy keys,
# the sorted cycle list) is numeric; as a string, cycle '10' would sort before cycle '2'.
# NOTE(review): a non-numeric cycle prefix becomes null after the cast.
df_cleaned = (
    df
    .withColumn('Cycle', bee_id_parts.getItem(0).cast('int'))
    .withColumn('Cycle ID', bee_id_parts.getItem(1))
)

df_cleaned.show(5)

In [None]:
# cycles = sorted([i[0] for i in df_cleaned.select('Cycle').distinct().collect()])
# print(cycles)
#
# def father_cycle(cycle):
#     cycle_index = cycles.index(cycle)
#     n=3
#     father_cyc = None
#
#     if cycle_index == 0:
#         return father_cyc
#
#     if cycle_index > n:
#         father_cyc = random.randint(cycle_index-n, cycle_index-1)
#
#     elif cycle_index <= n:
#         father_cyc = random.randint(0, cycle_index-1)
#
#     return cycles[father_cyc]

In [None]:
# from pyspark.sql.functions import udf, col
#
# convertUDF = udf(lambda z: father_cycle(z))
#
# df_cleaned.withColumn("ParentCycle", convertUDF(col('Cycle'))).show()

In [None]:
# Number every row 1..N in Cycle order, so each cycle occupies a contiguous block of
# Continuous IDs. 'Bee ID' breaks ties inside a cycle: row_number() over a non-unique
# ordering is not deterministic from run to run.
# NOTE: a window without partitionBy moves all rows into a single partition -- fine for
# small test data, but it will not scale.
w = Window.orderBy('Cycle', 'Bee ID')
df_cleaned = df_cleaned.withColumn('Continuous ID', row_number().over(w))

df_cleaned.show()

In [None]:
# Smallest Continuous ID of each cycle, i.e. where that cycle's block of IDs begins.
# NOTE(review): the block order follows the Cycle column's type -- as a string column,
# cycle '10' sorts (and is numbered) before cycle '2'; as an int column the order is numeric.
cycle_min_rows = df_cleaned.groupBy('Cycle').min('Continuous ID').collect()
continuous_min_id_per_cycle = dict((cycle, min_id) for cycle, min_id in cycle_min_rows)

continuous_min_id_per_cycle

In [None]:
# Cycles listed in the same order as their Continuous ID blocks (ascending first ID).
# Deriving the order from continuous_min_id_per_cycle, instead of re-sorting the distinct
# Cycle values in Python, guarantees it matches the window ordering whatever the Cycle
# column's type, avoids comparing None with other values, and needs no extra Spark job.
cycles = sorted(continuous_min_id_per_cycle, key=continuous_min_id_per_cycle.get)

cycles

In [None]:
#change func according to changes above
def assert_parent_bee_id(cycle):
    n = 3
    cycle_index = cycles.index(cycle)

    if  not cycle_index :
        return None

    min_cycle_index = 0

    if cycle_index > n:
        min_cycle_index = cycle_index - n
        
    start = continuous_min_id_per_cycle[cycles[min_cycle_index]]
    end = continuous_min_id_per_cycle[cycles[cycle_index]] - 1
    
    parent_bee_id = random.randint(start, end)

    return parent_bee_id

In [None]:
# Wrap the parent picker as a UDF. It returns a Continuous ID, so declare IntegerType:
# udf() defaults to StringType, which made 'Parent Temp' a string column.
# It draws random numbers, so mark it non-deterministic so the optimizer does not assume
# repeated evaluations agree. cache() keeps the drawn values for later cells (evicted
# partitions would be recomputed with fresh draws), and re-running this cell re-draws them.
convertUDF = udf(assert_parent_bee_id, IntegerType()).asNondeterministic()
df_cleaned = (
    df_cleaned
    .withColumn("Parent Temp", convertUDF(col('Cycle')))
    .cache()
)

df_cleaned.show()

In [None]:
from treelib import Node, Tree


# Pull the rows to the driver once and index them by their parent's Continuous ID,
# instead of launching one Spark filter+collect job per node while walking the tree.
all_rows = df_cleaned.collect()

children_by_parent_id = {}
for row in all_rows:
    parent_id = row['Parent Temp']
    if parent_id is None:
        continue
    # 'Parent Temp' is a string column when the UDF keeps its default StringType,
    # so normalise it to int before matching it against 'Continuous ID'.
    children_by_parent_id.setdefault(int(parent_id), []).append(row)

# The bee with Continuous ID 1 is the root of the tree.
root = [row for row in all_rows if row['Continuous ID'] == 1][0]

working_list = [root]

tree = Tree()
tree.create_node(root['Bee ID'], root['Bee ID'])

# Depth-first walk from the root: attach each bee under the bee whose Continuous ID it
# drew as its parent. Rows with no chain of parents back to the root are not added.
while working_list:
    parent = working_list.pop()
    for kid in children_by_parent_id.get(parent['Continuous ID'], []):
        working_list.append(kid)
        tree.create_node(kid['Bee ID'], kid['Bee ID'], parent['Bee ID'])

tree.show()

In [None]:
def test2(a, b):
    """Return ``a + b``; used as the fold operator below."""
    return a + b

# RDD.fold applies zeroValue once per partition and once more when merging the partition
# results, so it must be the identity of the operator (0 for addition). With 1 the result
# depended on how many partitions the RDD had.
spark.sparkContext.parallelize([1, 2, 3, 4, 5]).fold(0, test2)

In [None]:
def test1(tree_dict, row):
    """Placeholder fold operator: ignores ``row`` and returns the accumulator unchanged.

    With ``RDD.fold`` below, the fold therefore just yields its zero value (``[]``).
    An earlier draft experimented with matching rows on their 'Continuous ID'
    (positional index 10) and appending them to the accumulator; nothing is
    accumulated yet.
    """
    accumulator = tree_dict
    return accumulator


root = df_cleaned.head()
df_cleaned.rdd.fold([], test1)

In [None]:
d = {1: ['a'], 2: ['b']}
# Lists stored as dict values are mutable, so they can be extended in place.
d[1].extend(['c'])
print(d)