In [None]:
import os
import sys

import pyspark


# Fail fast when the PySpark visible to this kernel is not a 3.x release.
spark_version = pyspark.__version__
if not spark_version.startswith("3"):
    raise EnvironmentError(
        f"Can only execute this notebook on the kernel with Spark 3+ installed | Found: {spark_version}"
    )
# Make the parent directory importable (presumably where the `inference` package imported below lives — confirm).
sys.path.append(os.path.abspath(".."))

In [None]:
!/home/PyPIconfig.sh
!pip install mlopstools==1.0.37
!pip install pyspark==3.1.1
# Potentially need to restart kernel after executing this cell to make package available on kernel

In [None]:
# Reload imported modules automatically before executing code, so edits to local modules are picked up.
%load_ext autoreload
%autoreload 2

In [None]:
import time
import uuid
from collections import defaultdict, Counter
from datetime import timedelta, datetime

import boto3
import pyspark
import numpy as np
import pandas as pd
from graphframes import GraphFrame
from graphframes.lib import AggregateMessages as AM  # noqa
from pyspark.sql import Window
from pyspark.sql import functions as sf
from pyspark.sql import types as st
from pyspark.ml.fpm import FPGrowth

from inference.src.settings import MIN_TRX_DATE, MAX_TRX_DATE, get_logger
from inference.jobs.utils import setup_job


# The install cell pins PySpark 3.1.1; a different version here means the kernel still holds the old package.
if pyspark.__version__ != "3.1.1":
    raise EnvironmentError("PySpark 3.1.1 not yet loaded | Please restart the kernel")

# Notebook-wide logger obtained from the project's settings module.
LOGGER = get_logger(__name__)

In [None]:
# Spark session setup

_, spark = setup_job("exploration", str(uuid.uuid4()), str(uuid.uuid4()))

region = "eu-west-1"

# Assume the execution role of the first SageMaker domain through the regional STS endpoint.
sts_client = boto3.client("sts", region_name=region, endpoint_url=f"https://sts.{region}.amazonaws.com")
sagemaker_client = boto3.client("sagemaker")
domain_id = sagemaker_client.list_domains()["Domains"][0]["DomainId"]
execution_role_arn = sagemaker_client.describe_domain(DomainId=domain_id)["DefaultUserSettings"]["ExecutionRole"]
credentials = sts_client.assume_role(RoleArn=execution_role_arn, RoleSessionName="sagemaker-pyspark")

# Hand the temporary credentials to the S3A connector via the Hadoop configuration of the session.
temporary_credentials = credentials["Credentials"]
hadoop_configuration = spark.sparkContext._jsc.hadoopConfiguration()  # noqa
for hadoop_key, hadoop_value in (
    ("fs.s3a.aws.credentials.provider", "org.apache.hadoop.fs.s3a.TemporaryAWSCredentialsProvider"),
    ("fs.s3a.access.key", temporary_credentials["AccessKeyId"]),
    ("fs.s3a.secret.key", temporary_credentials["SecretAccessKey"]),
    ("fs.s3a.session.token", temporary_credentials["SessionToken"]),
):
    hadoop_configuration.set(hadoop_key, hadoop_value)
spark.sparkContext.setLogLevel("ERROR")

In [None]:
# S3 bucket under which every input and intermediate result of this notebook is read and written.
BUCKET = "tmnl-prod-data-scientist-sagemaker-data-intermediate"

In [None]:
# Input location as an `s3a://` URI (used by Spark) and as an `s3://` URI (used by the AWS CLI below).
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-input"
# `str.lstrip("s3a")` removes any leading run of the characters "s", "3" and "a" rather than the prefix,
# and only produced the right result because ":" follows the scheme; swap the scheme explicitly instead.
location_s3 = location.replace("s3a://", "s3://", 1)

In [None]:
available_dates = !aws s3 ls {location_s3}/
available_dates = sorted(
    set([x.strip().replace("PRE transaction_date=", "").replace("/", "")[:10] for x in available_dates])
)[:-1]
LOGGER.info(len(available_dates), available_dates[0], available_dates[-1])
available_dates = set(available_dates)

In [None]:
# Number of days after a transaction date whose partitions are loaded alongside it for the self-join.
HORIZON = 21  # days
# Accounts with this many (or more) outgoing or incoming transactions in the loaded data are filtered out.
MAX_CENTRALITY = 1000
# A transaction is kept only if its source's total outgoing or its target's total incoming amount reaches this.
MIN_TOTAL_TRANSACTION_AMOUNT = 500

In [None]:
def filter_data(dataframe):
    """Keep transactions that carry enough volume and do not touch a very high-degree account.

    Per-`source` and per-`target` window aggregates (sum of `amount`, row count) are attached
    temporarily. A row is kept when the source's outgoing total or the target's incoming total is
    at least MIN_TOTAL_TRANSACTION_AMOUNT, and both the source's outgoing count and the target's
    incoming count are below MAX_CENTRALITY.

    :param dataframe: Spark DataFrame with `source`, `target` and `amount` columns.
    :return: the filtered DataFrame with the helper columns removed again.
    """
    per_source = Window.partitionBy("source")
    per_target = Window.partitionBy("target")
    helper_columns = {
        "total_out": sf.sum("amount").over(per_source),
        "total_in": sf.sum("amount").over(per_target),
        "count_out": sf.count("source").over(per_source),
        "count_in": sf.count("target").over(per_target),
    }
    enriched = dataframe
    for column_name, expression in helper_columns.items():
        enriched = enriched.withColumn(column_name, expression)
    enough_volume = (sf.col("total_out") >= MIN_TOTAL_TRANSACTION_AMOUNT) | (
        sf.col("total_in") >= MIN_TOTAL_TRANSACTION_AMOUNT
    )
    below_out_limit = sf.col("count_out") < MAX_CENTRALITY
    below_in_limit = sf.col("count_in") < MAX_CENTRALITY
    return enriched.where(enough_volume & below_out_limit & below_in_limit).drop(*helper_columns)


def rename_columns(dataframe, names):
    """Rename columns one at a time according to a mapping.

    :param dataframe: object exposing `withColumnRenamed(old, new)` (a Spark DataFrame).
    :param names: mapping of existing column name -> new column name.
    :return: the DataFrame after all renames were applied in mapping order.
    """
    renamed = dataframe
    for current_name in names:
        renamed = renamed.withColumnRenamed(current_name, names[current_name])
    return renamed


def max_timestamp(dt):
    """Return the POSIX timestamp of midnight following the given day.

    :param dt: date string in `YYYY-MM-DD` form.
    :return: float timestamp of the start of the next day; the naive datetime is interpreted by
        `datetime.timestamp()`, i.e. in the local timezone of the process.
    """
    year, month, day = (int(part) for part in dt.split("-"))
    next_day = datetime(year, month, day) + timedelta(days=1)
    return next_day.timestamp()

In [None]:
# Mapping `<column>` -> `<column>_left` for every column of the input data; used to rename the left side of
# the self-join so its columns do not clash with the right side.
left_columns = {x.name: f"{x.name}_left" for x in spark.read.parquet(location).schema}

In [None]:
# For each transaction date, self-join that day's transactions (left) with the transactions of the same day
# plus the following HORIZON days (right) on `left.target == right.source`, and store the result per date.
start_time = time.time()
location_output = f"s3a://{BUCKET}/community-detection/exploration/ftm-joins/"
dates = [str(x.date()) for x in sorted(pd.date_range(MIN_TRX_DATE, MAX_TRX_DATE))]
# `enumerate` provides the position directly instead of an O(n) `dates.index(...)` lookup per iteration.
for start_index, transaction_date in enumerate(dates):
    if transaction_date not in available_dates:
        raise Exception(f"[NotAvailable] {transaction_date}")
    end_index = start_index + HORIZON + 1
    # Only partitions that actually exist are read. `transaction_date` itself is guaranteed to be among them
    # by the check above, so the list can never be empty.
    right_dates = sorted(available_dates.intersection(dates[start_index:end_index]))
    right_location = [f"{location}/transaction_date={x}" for x in right_dates]
    right = filter_data(spark.read.parquet(*right_location)).cache()
    # Materialise the cache once; both sides of the join reuse it.
    _ = right.count()
    # Left side: rows whose timestamp lies before the end of `transaction_date`, with `_left`-suffixed columns.
    left = rename_columns(right.where(right.transaction_timestamp < max_timestamp(transaction_date)), left_columns)
    join = left.join(right, left.target_left == right.source, "inner")
    join = join.withColumn("delta", join.transaction_timestamp - join.transaction_timestamp_left)
    # Keep pairs where the right transaction does not precede the left one (delta >= 0 for integer timestamps).
    join = join.where(join.delta > -1)
    join.write.parquet(f"{location_output}date={transaction_date}", mode="overwrite")
    LOGGER.info(f"[{transaction_date}] Ran in {timedelta(seconds=round(time.time() - start_time))}")
    start_time = time.time()
    right.unpersist()
    del right
    # NOTE(review): this `break` stops after the first date, so only one `date=` partition is written.
    # It looks like a debugging leftover — confirm before removing.
    break

In [None]:
# Load the joined transaction pairs. Each row is a left transaction (`*_left` columns) followed by a right one.
data = spark.read.parquet(location_output)

# TODO: This is to avoid cycles in rare instances -> Implement a better solution
data = data.where(data.transaction_timestamp > data.transaction_timestamp_left)

# The join output was written under `date=<left transaction date>`; give that partition column an explicit
# name and derive the right transaction's date from its unix timestamp.
data = data.withColumnRenamed("date", "transaction_date_left").withColumn(
    "transaction_date", sf.from_unixtime("transaction_timestamp").cast(st.DateType())
)

# Graph vertices are individual transactions: collect them once from the left and once from the right side.
location_nodes_1 = f"s3a://{BUCKET}/community-detection/exploration/ftm-node-1"
location_nodes_2 = f"s3a://{BUCKET}/community-detection/exploration/ftm-node-2"
node_columns = ["id", "source", "target", "transaction_date", "transaction_timestamp", "amount"]
nodes_1 = (
    (
        data.select(
            sf.col("id_left").alias("id"),
            sf.col("source_left").alias("source"),
            sf.col("target_left").alias("target"),
            sf.col("transaction_timestamp_left").alias("transaction_timestamp"),
            sf.col("amount_left").alias("amount"),
            sf.col("transaction_date_left").alias("transaction_date"),
        )
    )
    # Re-select so that the column order matches `nodes_2` (required for the positional `union` below).
    .select(*node_columns)
    .drop_duplicates(subset=["id"])
)
# Both halves are written out and read back before being combined.
nodes_1.write.mode("overwrite").parquet(location_nodes_1)
nodes_2 = data.select(*node_columns).drop_duplicates(subset=["id"])
nodes_2.write.mode("overwrite").parquet(location_nodes_2)
nodes_1 = spark.read.parquet(location_nodes_1)
nodes_2 = spark.read.parquet(location_nodes_2)
nodes = nodes_1.union(nodes_2).drop_duplicates(subset=["id"])

# Graph edges connect a left transaction to a right transaction; `delta` is the timestamp difference.
edges = data.select(
    sf.col("id_left").alias("src"),
    sf.col("id").alias("dst"),
    sf.col("transaction_date_left").alias("src_date"),
    sf.col("transaction_date").alias("dst_date"),
    "delta",
)

In [None]:
%%time

# Persist vertices partitioned by transaction date and edges partitioned by (source date, destination date),
# so later cells can read individual days by path.
nodes_location = f"s3a://{BUCKET}/community-detection/exploration/ftm-nodes/"
edges_location = f"s3a://{BUCKET}/community-detection/exploration/ftm-edges/"

nodes.repartition("transaction_date").write.partitionBy("transaction_date").mode("overwrite").parquet(nodes_location)
partition_by = ["src_date", "dst_date"]
edges.repartition(*partition_by).write.partitionBy(*partition_by).mode("overwrite").parquet(edges_location)

In [None]:
def pattern_for(hops):
    """Build a GraphFrames motif string describing a directed chain of `hops` edges.

    Vertices are named `x0` .. `x{hops}` and edges `e0` .. `e{hops - 1}`; the individual
    `(xi) - [ei] -> (xi+1)` steps are separated by "; ".

    :param hops: number of edges in the chain; 0 yields an empty string.
    :return: the motif pattern string.
    """
    steps = [f"(x{index}) - [e{index}] -> (x{index + 1})" for index in range(hops)]
    return "; ".join(steps)


def select_columns(hops, result):
    """Project a motif result onto the `account` field of its first `hops` vertices.

    :param hops: number of vertex columns (`x0` .. `x{hops - 1}`) to select.
    :param result: DataFrame with struct columns `x0`, `x1`, ... (e.g. the output of `GraphFrame.find`).
    :return: DataFrame with one column per selected vertex, named `x0`, `x1`, ...

    NOTE(review): the vertices written by this notebook carry `id`, `source`, `target`,
    `transaction_date`, `transaction_timestamp` and `amount`, but no `account` field. A pattern from
    `pattern_for(hops)` also contains a vertex `x{hops}` that is not selected here. Confirm both are
    intended.
    """
    return result.select(*(sf.col(f"x{index}.account").alias(f"x{index}") for index in range(hops)))

In [None]:
# This is important | Without setting the checkpoint directory, GraphFrames will fail
# NOTE(review): "." is the current working directory of the driver — confirm it is reachable/writable
# for the cluster mode in use.
spark.sparkContext.setCheckpointDir(".")

# Re-read the persisted vertices and edges so the graph is backed by the stored Parquet data.
nodes = spark.read.parquet(nodes_location)
edges = spark.read.parquet(edges_location)

graph = GraphFrame(nodes, edges)

In [None]:
# Number of edges in the chain motif searched for below.
HOPS = 5

pattern = pattern_for(HOPS)
# Collects the matches on the driver as a pandas DataFrame.
# NOTE(review): `toPandas()` on an unrestricted 5-hop motif result may be very large — confirm this is feasible.
results = select_columns(HOPS, graph.find(pattern)).toPandas()

In [None]:
# UDF summing an array column after converting its elements to integers; the result is declared as a
# 32-bit IntegerType.
array_sum = sf.udf(lambda x: int(np.array(x, dtype=int).sum()), st.IntegerType())
# Two-transaction motif: x1 and x2 are transactions linked by edge e1.
pattern = "(x1) - [e1] -> (x2)"
# Despite the name this is a duration in seconds (7 days); it is compared against the edge `delta`.
number_of_days_allowed = 7 * 24 * 60 * 60

In [None]:
# Type 1 flows
# Grouped by the source account of the second transaction: keep accounts that, within the allowed delta,
# are fed by more than 2 distinct sources and pay out to more than 2 distinct targets, with a total
# outgoing amount above 10000.

results_t1 = (
    graph.find(pattern)
    .select("x1", "x2", "e1.delta")
    .where(sf.col("e1.delta") < number_of_days_allowed)
    # NOTE(review): after the select above the column is presumably named `delta`, so dropping the literal
    # name "e1.delta" may be a no-op — confirm.
    .drop("e1.delta")
    .groupby("x2.source")
    .agg(
        sf.collect_set("x1").alias("left"),
        sf.collect_set("x2").alias("right"),
        sf.countDistinct("x1.source").alias("sources"),
        sf.countDistinct("x2.target").alias("targets"),
    )
    .where((sf.col("sources") > 2) & (sf.col("targets") > 2))
    # Sum of the amounts of the distinct outgoing (right-hand) transactions.
    .withColumn("total", array_sum(sf.col("right.amount")))
    .where(sf.col("total") > 10000)
    # NOTE(review): the grouping column is presumably exposed as `source`, in which case this rename does
    # nothing and no `middle` column is produced — confirm against the written schema.
    .withColumnRenamed("x2.source", "middle")
    .withColumn("start_transactions", sf.size("left"))
    .withColumn("end_transactions", sf.size("right"))
)
location_t1 = f"s3a://{BUCKET}/community-detection/exploration/ftm-results-type-1"
results_t1.write.mode("overwrite").parquet(location_t1)

In [None]:
# Type 2 flows
# Grouped by (source of the first transaction, target of the second): keep pairs connected through more than
# 2 distinct intermediary accounts within the allowed delta, with a total second-leg amount above 10000.

results_t2 = (
    graph.find(pattern)
    .select("x1", "x2", "e1.delta")
    .where(sf.col("e1.delta") < number_of_days_allowed)
    # NOTE(review): the selected column is presumably named `delta`; dropping "e1.delta" may be a no-op — confirm.
    .drop("e1.delta")
    .groupby("x1.source", "x2.target")
    .agg(
        sf.collect_set("x1").alias("left"),
        sf.collect_set("x2").alias("right"),
        sf.countDistinct("x1.target").alias("intermediaries"),
    )
    .where(sf.col("intermediaries") > 2)
    # Sum of the amounts of the distinct second-leg (right-hand) transactions.
    .withColumn("total", array_sum(sf.col("right.amount")))
    .where(sf.col("total") > 10000)
    .withColumn("start_transactions", sf.size("left"))
    .withColumn("end_transactions", sf.size("right"))
)
location_t2 = f"s3a://{BUCKET}/community-detection/exploration/ftm-results-type-2"
results_t2.write.mode("overwrite").parquet(location_t2)

In [None]:
# Ratio of the second-leg total to the first-leg total. This column is added after `results_t2` was written,
# so it is not part of the stored Parquet output.
results_t2 = results_t2.withColumn("percentage_transferred", results_t2.total / array_sum("left.amount"))

In [None]:
# 3-hop networks
# Chains of three transactions whose two deltas sum to less than the allowed duration, turned into
# 4-item "baskets" (prefixed account ids) for frequent-pattern mining.

pattern = "(x1) - [e1] -> (x2); (x2) - [e2] -> (x3)"

results_h3 = graph.find(pattern)
result_h3 = results_h3.where((results_h3.e1.delta + results_h3.e2.delta) < number_of_days_allowed).select(
    # Prefixes: "s-" first source, "m-" the two intermediaries, "t-" final target.
    # NOTE(review): FPGrowth expects unique items per row; if `x1.target == x2.target` the two "m-" items
    # coincide — confirm this cannot happen or deduplicate.
    sf.array(
        sf.concat(sf.lit("s-"), "x1.source"),
        sf.concat(sf.lit("m-"), "x1.target"),
        sf.concat(sf.lit("m-"), "x2.target"),
        sf.concat(sf.lit("t-"), "x3.target"),
    ).alias("items")
)
location_h3 = f"s3a://{BUCKET}/community-detection/exploration/ftm-results-3-hops-input"
result_h3.write.mode("overwrite").parquet(location_h3)

# Fit FP-Growth on the `items` column (the estimator's default input column) and store the model.
fp = FPGrowth(minSupport=0.000001, minConfidence=0.5)
fpm = fp.fit(result_h3)
location_model = f"s3a://{BUCKET}/community-detection/exploration/ftm-fpm-model"
fpm.write().overwrite().save(location_model)

# Association rules are collected on the driver as a pandas DataFrame.
rules = fpm.associationRules.sort("antecedent", "consequent").toPandas()

In [None]:
# Output locations of the unpivoted connection lists (t0: all, t1 / t2: amount-filtered variants).
location_fpm_t0 = f"s3a://{BUCKET}/community-detection/exploration/ftm-input-fpm-t0"
location_fpm_t1 = f"s3a://{BUCKET}/community-detection/exploration/ftm-input-fpm-t1"
location_fpm_t2 = f"s3a://{BUCKET}/community-detection/exploration/ftm-input-fpm-t2"

# Output locations of the corresponding pruned edge selections (written by `select_edges`).
l0 = f"s3a://{BUCKET}/community-detection/exploration/ftm-input-fpm-pruned-t0"
l1 = f"s3a://{BUCKET}/community-detection/exploration/ftm-input-fpm-pruned-t1"
l2 = f"s3a://{BUCKET}/community-detection/exploration/ftm-input-fpm-pruned-t2"

In [None]:
# FPM All Connections

# Message sent along every edge to its destination vertex: the pair
# ["<src.source>-<src.target>", "<src.target>-<dst.target>"], i.e. the account hop of the source transaction
# and the account hop formed with the destination transaction's target.
message = sf.array(
    sf.concat(AM.src["source"], sf.lit("-"), AM.src["target"]),
    sf.concat(AM.src["target"], sf.lit("-"), AM.dst["target"]),
)
# Per destination vertex, the set of distinct messages received.
t0 = (
    graph.aggregateMessages(sf.collect_set(AM.msg).alias("items"), sendToDst=message)
    .select("id", "items")
    .repartition(1024, "id")
    .cache()
)
LOGGER.info(f"{t0.count():,} nodes processed")

# Output schema of `unpivot`.
schema = st.StructType(
    [
        st.StructField("left", st.StringType(), nullable=False),
        st.StructField("right", st.StringType(), nullable=False),
        st.StructField("degree", st.IntegerType(), nullable=False),
        st.StructField("id", st.IntegerType(), nullable=False),
    ]
)


@sf.pandas_udf(schema, sf.PandasUDFType.GROUPED_MAP)
def unpivot(input_data):
    """Explode the `items` pairs of one vertex into one output row per pair.

    Only the first row of the group is read (callers group by `id`).

    :param input_data: pandas DataFrame with columns `id` and `items`, where `items` holds
        two-element [left, right] entries.
    :return: pandas DataFrame with columns `left`, `right`, `degree` (number of pairs) and `id`.

    NOTE(review): the schema declares `id` as IntegerType — confirm the vertex ids are integers.
    """
    row = input_data.iloc[0]
    result = pd.DataFrame(row["items"].tolist(), columns=["left", "right"])
    result.loc[:, "degree"] = result.shape[0]
    result.loc[:, "id"] = row["id"]
    return result


# One group per vertex id; write the exploded (left, right, degree, id) rows.
t0.groupby("id").apply(unpivot).write.mode("overwrite").parquet(location_fpm_t0)

In [None]:
# Edges to keep -> Top 20% most frequent

# Output schema of `select_edges`. Note that this rebinds the module-level name `schema`.
schema = st.StructType(
    [
        st.StructField("src_left", st.StringType(), nullable=False),
        st.StructField("dst_left", st.StringType(), nullable=False),
        st.StructField("src_right", st.StringType(), nullable=False),
        st.StructField("dst_right", st.StringType(), nullable=False),
    ]
)


@sf.pandas_udf(schema, sf.PandasUDFType.GROUPED_MAP)
def select_edges(input_data):
    """For one `right` connection, keep the `left` connections whose frequency ranks in the top 20%.

    :param input_data: pandas DataFrame with string columns `left` and `right`, both of the form
        "<src>-<dst>"; all rows of a group share the same `right` (callers group by `right`).
    :return: pandas DataFrame with columns `src_left`, `dst_left`, `src_right`, `dst_right`;
        empty when no `left` value exceeds the percentile threshold.

    NOTE(review): splitting on "-" assumes account identifiers never contain "-" — confirm.
    """
    columns = ["src_left", "dst_left", "src_right", "dst_right"]
    src_right, dst_right = input_data.iloc[0]["right"].split("-")
    # Occurrence count per distinct `left` value; the counts are then replaced by their percentile rank.
    result = pd.DataFrame.from_dict(Counter(input_data["left"]), "index", columns=["percentile_rank"]).reset_index()
    result.loc[:, "percentile_rank"] = result.loc[:, "percentile_rank"].rank(pct=True)
    result = result.loc[result.percentile_rank > 0.7999, :]
    if result.empty:
        return pd.DataFrame(columns=columns)
    result.loc[:, "src_right"] = src_right
    result.loc[:, "dst_right"] = dst_right
    # `index` holds the original `left` strings; split them back into source and destination.
    left_side = np.array(result.loc[:, "index"].str.split("-").tolist())
    result.loc[:, "src_left"] = left_side[:, 0]
    result.loc[:, "dst_left"] = left_side[:, 1]
    return result.loc[:, columns]


# Re-read the unpivoted pairs from storage and select the most frequent left connections per right connection.
t0 = spark.read.parquet(location_fpm_t0)
t0.select("left", "right").groupby("right").apply(select_edges).write.mode("overwrite").parquet(l0)

In [None]:
# FPM "At least" 60% of the amount carried forward

# Same two connection strings as before plus the amounts of the source and destination transactions.
# NOTE(review): `sf.array` needs a common element type, so the amounts are presumably coerced to strings
# alongside the connection strings — the UDFs below parse them back with `int(...)`; confirm the amounts
# are integer-formatted.
message = sf.array(
    sf.concat(AM.src["source"], sf.lit("-"), AM.src["target"]),
    sf.concat(AM.src["target"], sf.lit("-"), AM.dst["target"]),
    AM.src["amount"],
    AM.dst["amount"],
)


@sf.udf(st.ArrayType(st.ArrayType(st.StringType())))
def filter_connections(rows):
    """Keep the incoming connections that contributed more than ~60% of the forwarded amount.

    :param rows: non-empty list of messages `[left, right, left_amount, right_amount]`; the right
        connection and right amount are taken from the first message only.
    :return: list of `(left, right)` pairs whose summed left amount exceeds 0.5999 times the
        forwarded amount.
    """
    head = rows[0]
    threshold = int(head[3]) * 0.5999
    # Total amount per distinct left connection.
    amount_per_left = defaultdict(int)
    for message_row in rows:
        amount_per_left[message_row[0]] += int(message_row[2])
    right_node = head[1]
    return [(left_node, right_node) for left_node, total in amount_per_left.items() if total > threshold]


# Per destination vertex: collect all messages (duplicates kept, since amounts are summed), filter them with
# the `filter_connections` UDF defined in this cell and drop vertices without any remaining connection.
t1 = (
    graph.aggregateMessages(sf.collect_list(AM.msg).alias("items_raw"), sendToDst=message)
    .withColumn("items", filter_connections(sf.col("items_raw")))
    .select("id", "items")
    .where(sf.col("items") != sf.array())
    .repartition(1024, "id")
    .cache()
)
LOGGER.info(f"{t1.count():,} nodes processed")

t1.groupby("id").apply(unpivot).write.mode("overwrite").parquet(location_fpm_t1)

# Re-read from storage, then select the most frequent left connections per right connection.
t1 = spark.read.parquet(location_fpm_t1)
t1.select("left", "right").groupby("right").apply(select_edges).write.mode("overwrite").parquet(l1)

In [None]:
# FPM "At most" 40% of the amount carried forward


@sf.udf(st.ArrayType(st.ArrayType(st.StringType())))
def filter_connections(rows):
    """Keep the incoming connections that contributed less than ~40% of the forwarded amount.

    This redefines the UDF of the same name from the previous cell with a different threshold and
    comparison direction.

    :param rows: non-empty list of messages `[left, right, left_amount, right_amount]`; the right
        connection and right amount are taken from the first message only.
    :return: list of `(left, right)` pairs whose summed left amount is below 0.4001 times the
        forwarded amount.
    """
    head = rows[0]
    threshold = int(head[3]) * 0.4001
    # Total amount per distinct left connection.
    amount_per_left = defaultdict(int)
    for message_row in rows:
        amount_per_left[message_row[0]] += int(message_row[2])
    right_node = head[1]
    return [(left_node, right_node) for left_node, total in amount_per_left.items() if total < threshold]


# Same pipeline as for `t1`, using the `filter_connections` UDF redefined in this cell and the `message`
# expression defined in the previous cell.
t2 = (
    graph.aggregateMessages(sf.collect_list(AM.msg).alias("items_raw"), sendToDst=message)
    .withColumn("items", filter_connections(sf.col("items_raw")))
    .select("id", "items")
    .where(sf.col("items") != sf.array())
    .repartition(1024, "id")
    .cache()
)
LOGGER.info(f"{t2.count():,} nodes processed")

t2.groupby("id").apply(unpivot).write.mode("overwrite").parquet(location_fpm_t2)

# Re-read from storage, then select the most frequent left connections per right connection.
t2 = spark.read.parquet(location_fpm_t2)
t2.select("left", "right").groupby("right").apply(select_edges).write.mode("overwrite").parquet(l2)

In [None]:
# `Connections` graph / community detection

# For each start date, build a graph from that day's edges and the vertices of the following days, and store
# the distinct (start, middle, end) account triples of every two-transaction chain.
pattern = "(x1) - [e1] -> (x2)"
nodes_days = HORIZON + 1
for start_date in pd.date_range("2021-04-01", freq="d", periods=365):
    start_date = start_date.date()
    start_time = time.time()
    nodes_dates = [str(x.date()) for x in sorted(pd.date_range(start_date, periods=nodes_days, freq="d"))]
    # NOTE(review): this keeps only dates up to MIN_TRX_DATE; an upper bound such as MAX_TRX_DATE looks more
    # plausible here, and an empty list makes the vertex read below fail — confirm.
    nodes_dates = [x for x in nodes_dates if x <= MIN_TRX_DATE]
    nodes_locations = [f"{nodes_location}transaction_date={x}/" for x in nodes_dates]
    day_edges = spark.read.parquet(f"{edges_location}src_date={start_date}/")
    day_nodes = spark.read.parquet(*nodes_locations)
    # Rebinds the module-level `graph` to the per-day graph.
    graph = GraphFrame(day_nodes, day_edges)
    graph.find(pattern).select(
        sf.col("x1.source").alias("start"),
        sf.col("x1.target").alias("middle"),
        sf.col("x2.target").alias("end"),
    ).dropDuplicates().write.mode("overwrite").parquet(
        f"s3a://{BUCKET}/community-detection/exploration/ftm-fpm-input/date={start_date}"
    )
    LOGGER.info(f"[{start_date}] Ran in {timedelta(seconds=round(time.time() - start_time))}")

# Output schema of `create_connection_edges`. Note that this rebinds the module-level name `schema`.
schema = st.StructType(
    [
        st.StructField("src", st.StringType(), nullable=False),
        st.StructField("dst", st.StringType(), nullable=False),
        st.StructField("weight", st.FloatType(), nullable=False),
    ]
)


@sf.pandas_udf(schema, sf.PandasUDFType.GROUPED_MAP)
def create_connection_edges(input_data):
    """Turn the (start, middle, end) rows of one (middle, end) group into weighted connection edges.

    :param input_data: pandas DataFrame with columns `start`, `middle`, `end` and
        `dst_count_per_src`; callers group by (`middle`, `end`), so `end` is constant per group.
    :return: pandas DataFrame with one row per distinct `start`: `src` = "<start>-<middle>",
        `dst` = "<middle>-<end>", and `weight` = rows of that start in the group divided by the
        first `dst_count_per_src` value seen for that start.
    """
    end = input_data.iloc[0]["end"]
    # After this aggregation the `end` column holds a row count, not the account id (kept in `end` above).
    input_data = (
        input_data.groupby("start").agg({"middle": "first", "end": "count", "dst_count_per_src": "first"}).reset_index()
    )
    input_data.loc[:, "weight"] = input_data.loc[:, "end"] / input_data.loc[:, "dst_count_per_src"]
    input_data.loc[:, "src"] = input_data.start + "-" + input_data.middle
    input_data.loc[:, "dst"] = input_data.middle + f"-{end}"
    return input_data.loc[:, ["src", "dst", "weight"]]


# All stored (start, middle, end) triples across dates, and the number of rows per (start, middle) pair.
edges_connections = spark.read.parquet(f"s3a://{BUCKET}/community-detection/exploration/ftm-fpm-input/")
dst_counts = edges_connections.groupby("start", "middle").agg(sf.count("end").alias("dst_count_per_src")).cache()
LOGGER.info(f"{dst_counts.count():,} `sources` found")
dst_counts.write.mode("overwrite").parquet(f"s3a://{BUCKET}/community-detection/exploration/ftm-temp-1")

# Attach the per-(start, middle) count to every triple.
edges_connections = (
    dst_counts.alias("left")
    .join(
        edges_connections.alias("right"),
        (dst_counts.start == edges_connections.start) & (dst_counts.middle == edges_connections.middle),
        "inner",
    )
    .select(sf.col("left.start"), sf.col("left.middle"), sf.col("right.end"), "dst_count_per_src")
    .cache()
)
LOGGER.info(f"{edges_connections.count():,} `edges` found")
edges_connections.write.mode("overwrite").parquet(f"s3a://{BUCKET}/community-detection/exploration/ftm-temp-2")

# Weighted edges between connections ("<start>-<middle>" -> "<middle>-<end>").
# Note: `location` is rebound here and several more times below; it no longer refers to the input data.
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-connection-edges"
edges_connections.groupby("middle", "end").apply(create_connection_edges).write.mode("overwrite").parquet(location)

edges_connections = spark.read.parquet(location)
LOGGER.info(f"{edges_connections.count():,} `edges` left")

# Vertices of the connection graph: every distinct `src` or `dst` string.
nodes_connections = (
    edges_connections.select(sf.col("src").alias("id")).union(edges_connections.select(sf.col("dst").alias("id")))
).dropDuplicates()
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-connection-nodes"
nodes_connections.write.mode("overwrite").parquet(location)

nodes_connections = spark.read.parquet(location)
# Rebinds the module-level `graph` to the connection graph.
graph = GraphFrame(nodes_connections, edges_connections)

# Label propagation with a single iteration; the resulting (id, label) pairs are stored.
communities = graph.labelPropagation(maxIter=1)
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-connection-communities"
communities.select("id", "label").write.mode("overwrite").parquet(location)

In [None]:
# Processing a window

# Window definition: edges whose source date falls in the first `edges_days` days, with destination dates
# up to `edges_days` days after each source date; vertices for twice that span.
start_date = "2021-04-01"
edges_days = 21
nodes_days = int(edges_days * 2)

nodes_dates = [str(x.date()) for x in sorted(pd.date_range(start_date, periods=nodes_days, freq="d"))]
nodes_locations = [f"{nodes_location}transaction_date={x}/" for x in nodes_dates]

edges_locations = []
for src_date in [str(x.date()) for x in sorted(pd.date_range(start_date, periods=edges_days, freq="d"))]:
    dst_dates = [str(x.date()) for x in sorted(pd.date_range(src_date, periods=edges_days, freq="d"))]
    for dst_date in dst_dates:
        edges_locations.append(f"{edges_location}src_date={src_date}/dst_date={dst_date}")

# NOTE(review): reading explicit partition paths fails if any of them does not exist — confirm all
# (src_date, dst_date) combinations are present.
nodes = spark.read.parquet(*nodes_locations)
edges = spark.read.parquet(*edges_locations)

# Transaction-level graph of the window.
graph = GraphFrame(nodes, edges)
pattern = "(x1) - [e1] -> (x2)"

connection_edges = spark.read.parquet(f"s3a://{BUCKET}/community-detection/exploration/ftm-connection-edges")

# Distinct connection pairs ("<source>-<target>" of both transactions) observed inside the window.
window_connections = (
    graph.find(pattern)
    .select(
        sf.concat(sf.col("x1.source"), sf.lit("-"), sf.col("x1.target")).alias("src"),
        sf.concat(sf.col("x2.source"), sf.lit("-"), sf.col("x2.target")).alias("dst"),
    )
    .dropDuplicates()
    .cache()
)
LOGGER.info(f"{window_connections.count():,} `connections` found")

# Attach the globally computed weights to the window's connection pairs.
# (`widow_...` is a typo for `window_...`, used consistently below.)
widow_connection_edges = window_connections.join(connection_edges, ["src", "dst"], "inner")

location = f"s3a://{BUCKET}/community-detection/exploration/ftm-window-connection-edges"
widow_connection_edges.write.mode("overwrite").parquet(location)
widow_connection_edges = spark.read.parquet(location)
LOGGER.info(f"{widow_connection_edges.count():,} `edges` found")

window_connection_nodes = (
    widow_connection_edges.select(sf.col("src").alias("id")).union(
        widow_connection_edges.select(sf.col("dst").alias("id"))
    )
).dropDuplicates()
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-window-connection-nodes"
window_connection_nodes.write.mode("overwrite").parquet(location)

# NOTE(review): `graph` is still the transaction-level window graph built above, not a graph of the window
# connections — confirm which one the connected components are meant for.
LOGGER.info("Getting connected components")
location = f"s3a://{BUCKET}/community-detection/exploration/connected-components"
graph.connectedComponents().write.mode("overwrite").parquet(location)

LOGGER.info("Loading connected components")
cc = spark.read.parquet(location)
count = cc.select("component").distinct().count()
LOGGER.info(f"Found {count:,} connected components")

import igraph as ig
import leidenalg as la


# Edges with a weight at or below this value are excluded where the filter is applied.
MIN_WEIGHT = 0.000999

# Load the window connection edges chunk by chunk (by part-file prefix), keeping only edges above MIN_WEIGHT.
# NOTE(review): `DataFrame.append` inside a loop copies the frame on every iteration and is not available in
# pandas 2.x — collecting the chunks in a list and calling `pd.concat` once would avoid both.
data = pd.DataFrame()
for x in range(20):  # Correct this number based on number of files
    loc = f"s3a://{BUCKET}/community-detection/exploration/ftm-window-connection-edges/part-{f'{x}'.zfill(4)}*"
    chunk = spark.read.parquet(loc)
    data = data.append(chunk.where(chunk.weight > MIN_WEIGHT).toPandas(), ignore_index=True)
    LOGGER.info(f"Processed chunk number {x} | {data.shape[0]:,}")

# NOTE(review): this overwrites the `data` assembled by the loop above with an unfiltered read of the same
# location, so the loop's work and the MIN_WEIGHT filter are discarded — confirm which of the two is intended.
data = pd.read_parquet(f"s3a://{BUCKET}/community-detection/exploration/ftm-window-connection-edges/")

spark.catalog.clearCache()

# Directed igraph graph from the edge list (presumably the first two columns, `src` and `dst`, are used as
# vertex names — confirm), followed by Leiden modularity partitioning.
graph = ig.Graph.DataFrame(data, use_vids=False, directed=True)
LOGGER.info("Graph Loaded")
communities = la.find_partition(
    graph, la.ModularityVertexPartition, weights="weight", n_iterations=5, max_comm_size=100
)
LOGGER.info("Communities Detected")
communities_output = graph.get_vertex_dataframe()
communities_output.loc[:, "label"] = communities.membership
# Second level: the graph of communities is split into its clusters and each community label is mapped to
# the cluster it belongs to.
cluster_graph = communities.cluster_graph()
mapping = dict(zip(cluster_graph.get_vertex_dataframe().index, cluster_graph.clusters().membership))
communities_output.loc[:, "label_cluster"] = communities_output.loc[:, "label"].apply(mapping.get)
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-window-communities"
spark.createDataFrame(communities_output).write.mode("overwrite").parquet(location)

# TODO: Instead of using the weights (or weight filter), keep "all" edges b/w nodes
# ...for the LPA community detection on Spark
# Alternative community detection on Spark: label propagation (5 iterations) on the window connection graph,
# restricted to edges above MIN_WEIGHT.
nodes = spark.read.parquet(f"s3a://{BUCKET}/community-detection/exploration/ftm-window-connection-nodes")
edges = spark.read.parquet(f"s3a://{BUCKET}/community-detection/exploration/ftm-window-connection-edges")
edges = edges.where(edges.weight > MIN_WEIGHT).cache()
LOGGER.info(f"Running LPA for {edges.count():,} edges")
# Rebinds the module-level `graph`; its vertices only carry an `id` column.
graph = GraphFrame(nodes, edges)
communities = graph.labelPropagation(maxIter=5)
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-window-lpa-communities"
communities.select("id", "label").write.mode("overwrite").parquet(location)

In [None]:
# Combine `connection communities` with flows

# Method 1 (LPA): attach the community label of each connection ("<source>-<target>") to the window's
# transactions.
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-window-lpa-communities"
communities = spark.read.parquet(location)

# Relies on `nodes_locations` defined in the window-processing cell.
nodes = spark.read.parquet(*nodes_locations)
results = communities.join(
    nodes.withColumnRenamed("id", "id_actual"),
    communities.id == sf.concat(sf.col("source"), sf.lit("-"), sf.col("target")),
    "inner",
)
# Replace the connection id by the transaction id.
results = results.drop("id").withColumnRenamed("id_actual", "id")
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-window-communities-m1"
results.write.mode("overwrite").parquet(location)

# Method 2 (Leiden): same join, using the `name` column of the igraph vertex table as connection id.
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-window-communities"
communities = spark.read.parquet(location)
results = communities.join(
    nodes,
    communities.name == sf.concat(sf.col("source"), sf.lit("-"), sf.col("target")),
    "inner",
)
results = results.drop("name")
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-window-communities-m2"
results.write.mode("overwrite").parquet(location)

communities = spark.read.parquet(f"s3a://{BUCKET}/community-detection/exploration/ftm-window-communities-m2")
# Output schema of `community_summary`. Note that this rebinds the module-level name `schema`.
schema = st.StructType(
    [
        st.StructField("label", st.IntegerType(), nullable=False),
        st.StructField("label_cluster", st.IntegerType(), nullable=False),
        st.StructField("dispensers", st.IntegerType(), nullable=False),
        st.StructField("intermediates", st.IntegerType(), nullable=False),
        st.StructField("sinks", st.IntegerType(), nullable=False),
        st.StructField("forwarded", st.IntegerType(), nullable=False),
        st.StructField("received", st.IntegerType(), nullable=False),
        st.StructField("percentage_forwarded", st.FloatType(), nullable=False),
        st.StructField("diameter", st.IntegerType(), nullable=False),
        st.StructField("components", st.IntegerType(), nullable=False),
    ]
)


@sf.pandas_udf(schema, sf.PandasUDFType.GROUPED_MAP)
def community_summary(input_data):
    """Summarise the transactions of one community into a single row.

    Accounts are classified as dispensers (only ever a source), intermediates (both source and
    target) and sinks (only ever a target) within the group.

    :param input_data: pandas DataFrame of one `label` group with at least the columns `label`,
        `label_cluster`, `source`, `target`, `transaction_timestamp` and `amount`.
    :return: one-row pandas DataFrame with the role counts, the amount sent by dispensers
        (`forwarded`), the amount received by sinks (`received`), their ratio, and the diameter and
        number of weakly connected components of the group's account graph.

    NOTE(review): the `or 1` fallback protects `received`, which is the numerator of
    `percentage_forwarded`; a zero `forwarded` (denominator) is not guarded — confirm the intent.
    """
    first = input_data.iloc[0].to_dict()
    sources = set(input_data.loc[:, "source"])
    targets = set(input_data.loc[:, "target"])
    dispensers = sources.difference(targets)
    intermediates = sources.intersection(targets)
    sinks = targets.difference(sources)
    result = pd.DataFrame([first["label"]], columns=["label"])
    result.loc[:, "label_cluster"] = first["label_cluster"]
    result.loc[:, "dispensers"] = len(dispensers)
    result.loc[:, "intermediates"] = len(intermediates)
    result.loc[:, "sinks"] = len(sinks)
    # Flag transactions that leave a dispenser or arrive at a sink.
    input_data.loc[:, "is_dispenser"] = input_data.loc[:, "source"].isin(dispensers)
    input_data.loc[:, "is_sink"] = input_data.loc[:, "target"].isin(sinks)
    result.loc[:, "forwarded"] = sum(input_data.loc[input_data["is_dispenser"], "amount"])
    result.loc[:, "received"] = sum(input_data.loc[input_data["is_sink"], "amount"]) or 1
    result.loc[:, "percentage_forwarded"] = result.loc[:, "received"] / result.loc[:, "forwarded"]
    # Account-level directed graph of the community (source -> target per transaction).
    columns = ["source", "target", "transaction_timestamp", "amount"]
    graph_flow = ig.Graph.DataFrame(input_data.loc[:, columns], use_vids=False, directed=True)
    result.loc[:, "diameter"] = graph_flow.diameter()
    result.loc[:, "components"] = len(graph_flow.clusters(mode="weak").sizes())
    return result


# One summary row per community label of the Leiden-based (m2) result.
window_communities_summary = communities.groupby("label").apply(community_summary)
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-window-communities-summary-m2"
window_communities_summary.write.mode("overwrite").parquet(location)

In [None]:
# Pruning

# Connection pairs selected by `select_edges` for the "all connections" (t0) variant.
selected = spark.read.parquet(l0)

pattern = "(x1) - [e1] -> (x2)"
# NOTE(review): this uses whatever `graph` was bound last. The join below needs vertices with `source` and
# `target` columns (a transaction-level graph), whereas the most recent `graph` assignment above is a
# connection graph whose vertices only have `id` — confirm the execution order this cell relies on.
original = graph.find(pattern)
# Keep the transaction pairs whose two account hops match a selected connection pair.
join_on = (
    (original.x1.source == selected.src_left)
    & (original.x1.target == selected.dst_left)
    & (original.x2.source == selected.src_right)
    & (original.x2.target == selected.dst_right)
)
joined = original.join(selected, join_on, "inner").select(sf.col("e1.src").alias("src"), sf.col("e1.dst").alias("dst"))
output_location = f"s3a://{BUCKET}/community-detection/exploration/ftm-t0-edges"
joined.write.mode("overwrite").parquet(output_location)

edges = spark.read.parquet(output_location)

LOGGER.info(f"{edges.count():,} edges selected!")

# Vertices of the pruned graph: every id that appears as `src` or `dst`; only an `id` column is kept.
nodes = edges.select(sf.col("src").alias("id")).union(edges.select(sf.col("dst").alias("id"))).dropDuplicates()
output_location = f"s3a://{BUCKET}/community-detection/exploration/ftm-t0-nodes"
nodes.write.mode("overwrite").parquet(output_location)

In [None]:
# Community detection

# Label propagation (single iteration) on the pruned t0 graph; stores the (id, label) pairs.
nodes = spark.read.parquet(f"s3a://{BUCKET}/community-detection/exploration/ftm-t0-nodes")
edges = spark.read.parquet(f"s3a://{BUCKET}/community-detection/exploration/ftm-t0-edges")

graph = GraphFrame(nodes, edges)

communities = graph.labelPropagation(maxIter=1)
output_location = f"s3a://{BUCKET}/community-detection/exploration/ftm-t0-communities"
communities.select("id", "label").write.mode("overwrite").parquet(output_location)

In [None]:
# Single node (all) flows

# Placeholder account identifier to inspect.
node = "x"
# NOTE(review): `nodes` needs `source` and `target` columns here, but the `nodes` read in the previous cell
# (ftm-t0-nodes) only has `id` — confirm which `nodes` / `graph` this cell is meant to run against.
ids = nodes.where((nodes.source == node) | (nodes.target == node)).select("id").toPandas()["id"].tolist()
LOGGER.info(f"Selected nodes = {len(ids):,}")
# Chain of five transactions (x0 .. x4) linked by four edges.
pattern = "(x0) - [e0] -> (x1); (x1) - [e1] -> (x2); (x2) - [e2] -> (x3); (x3) - [e3] -> (x4)"
# For each position in the chain, keep the chains in which a transaction of the selected account sits at
# that position, and store the distinct account paths (a -> b -> c -> d -> e -> f).
for hop in range(5):
    flows = graph.find(pattern).where(sf.col(f"x{hop}.id").isin(ids))
    flows = flows.select(
        flows.x0.source.alias("a"),
        flows.x0.target.alias("b"),
        flows.x1.target.alias("c"),
        flows.x2.target.alias("d"),
        flows.x3.target.alias("e"),
        flows.x4.target.alias("f"),
    ).dropDuplicates()
    output_location = f"s3a://{BUCKET}/community-detection/exploration/ftm-single-node-flows/hop={hop}"
    flows.write.mode("overwrite").parquet(output_location)
    LOGGER.info(f"Processed hop #{hop}")

In [None]:
# Personalised PageRank
# Runs parallel personalised PageRank from every connection that involves the selected account, once on the
# connection graph and once on the graph with reversed edges, and stores the edges with weight above 0.4.

nodes = spark.read.parquet(f"s3a://{BUCKET}/community-detection/exploration/ftm-connection-nodes")
edges = spark.read.parquet(f"s3a://{BUCKET}/community-detection/exploration/ftm-connection-edges")
edges_reversed = edges.select(
    edges.dst.alias("src"),
    edges.src.alias("dst"),
)

# Placeholder account identifier to inspect.
node = "x"
# Connection ids have the form "<source>-<target>"; split them to match either side against `node`.
nodes = nodes.select("id", sf.split(nodes.id, "-", 0).alias("ids"))
source_ids = nodes.where((nodes.ids.getItem(0) == node) | (nodes.ids.getItem(1) == node)).select("id").toPandas()
source_ids = source_ids["id"].tolist()

LOGGER.info(f"`sources` count = {len(source_ids)}")

graph_forward = GraphFrame(nodes, edges)
graph_backward = GraphFrame(nodes, edges_reversed)

# `parallelPersonalizedPageRank` returns a GraphFrame, which has neither a `where` method nor a `weight`
# attribute, so the weight filter is applied to the resulting edges DataFrame instead.
ppr_forward = graph_forward.parallelPersonalizedPageRank(resetProbability=0.15, sourceIds=source_ids, maxIter=1)
ppr_forward_edges = ppr_forward.edges.where(sf.col("weight") > 0.4)
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-ppr-forward"
ppr_forward_edges.select("src", "dst", "weight").write.mode("overwrite").parquet(location)

ppr_backward = graph_backward.parallelPersonalizedPageRank(resetProbability=0.15, sourceIds=source_ids, maxIter=1)
ppr_backward_edges = ppr_backward.edges.where(sf.col("weight") > 0.4)
location = f"s3a://{BUCKET}/community-detection/exploration/ftm-ppr-backward"
ppr_backward_edges.select("src", "dst", "weight").write.mode("overwrite").parquet(location)