In [None]:
import math
import os
import pickle
import shutil
import subprocess
import sys
import time
import uuid
from collections import defaultdict, Counter
from itertools import combinations, product
from datetime import timedelta, date
from glob import glob

import leidenalg as la
import igraph as ig
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import psutil
from pyspark.sql import functions as sf, types as st
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel

import settings as s

%load_ext autoreload
%autoreload 2

In [None]:
# Default seaborn look: white background, "talk"-sized fonts and lines.
sns.set_theme(context="talk", style="white")

In [None]:
# Abort early unless the interpreter is exactly Python 3.11.8, the environment
# named in the error message below.
if sys.version_info[:3] != (3, 11, 8):
    raise EnvironmentError("Only runs efficiently on Python 3.11.8 | conda 24.1.2 | Apple M3 Pro")

In [None]:
# Spark session with enlarged memory limits.
# "spark.worker.memory" (used previously) is not an application-level Spark
# property, so it had no effect; per-executor memory is "spark.executor.memory".
# NOTE(review): when running in local mode the executor lives inside the driver
# JVM, so spark.driver.memory is the limit that actually matters here.
config = [
    ("spark.driver.memory", "16g"),
    ("spark.executor.memory", "16g"),
    ("spark.driver.maxResultSize", "16g"),
]
spark = SparkSession.builder.appName("testing").config(conf=SparkConf().setAll(config)).getOrCreate()

In [None]:
# Wall-clock start time; the final cell reports the total runtime relative to it.
start_script = time.time()

In [None]:
%%time

MAX_TRANSACTIONS_PER_ACCOUNT = 5_000

data = spark.read.parquet(s.STAGED_DATA_LOCATION)

#### [START] Seed selection ####
data = data.where(sf.col("source") != sf.col("target")) 
data = data.where(sf.col("format").isin(["ACH", "Wire", "Bitcoin"]))
#### [END] Seed selection ####

In [None]:
def aggregate_edges(data_input):
    """Collapse transactions into weighted, directed account-to-account edges.

    Amounts are summed per (source, target) pair in Spark and collected to
    pandas. An edge's weight is the share of the source's total outgoing
    amount carried by the edge plus the share of the target's total incoming
    amount carried by it, so for positive amounts it lies in (0, 2].

    Afterwards, account ids are truncated to their first 8 characters and any
    edge that becomes a self-loop after truncation is dropped. Distinct long
    ids sharing an 8-character prefix are NOT re-aggregated, so the result may
    contain repeated (source, target) rows.

    Parameters
    ----------
    data_input : pyspark.sql.DataFrame
        Transactions with ``source``, ``target``, ``source_amount`` and
        ``target_amount`` columns.

    Returns
    -------
    pandas.DataFrame
        Columns ``source``, ``target`` and ``weight``.
    """
    data_aggregated = data_input.groupby(["source", "target"]).agg(
        sf.sum("source_amount").alias("source_amount"),
        sf.sum("target_amount").alias("target_amount"),
    ).toPandas()

    # Broadcast each account's totals back onto its rows in one vectorised pass
    # (previously: a dict built per groupby plus a Python-level lookup per row).
    data_aggregated["total_sent_by_source"] = (
        data_aggregated.groupby("source")["source_amount"].transform("sum")
    )
    data_aggregated["total_received_by_target"] = (
        data_aggregated.groupby("target")["target_amount"].transform("sum")
    )
    # NOTE(review): an account whose summed amount is 0 yields 0/0 -> NaN weight.
    data_aggregated["weight"] = (
        data_aggregated["source_amount"] / data_aggregated["total_sent_by_source"]
        + data_aggregated["target_amount"] / data_aggregated["total_received_by_target"]
    )

    data_aggregated["source"] = data_aggregated["source"].str.slice(0, 8)
    data_aggregated["target"] = data_aggregated["target"].str.slice(0, 8)
    filter_self = data_aggregated["source"] != data_aggregated["target"]
    data_aggregated = data_aggregated.loc[filter_self, :].reset_index(drop=True)
    return data_aggregated.loc[:, ["source", "target", "weight"]]

In [None]:
%%time

edges = aggregate_edges(data)
graph = ig.Graph.DataFrame(edges, use_vids=False, directed=True)
nodes = [x["name"] for x in graph.vs()]

In [None]:
def get_processes(ids):
    """Return running processes whose command line contains any token in ``ids``.

    Parameters
    ----------
    ids : set of str
        Tokens to look for among each process's command-line arguments
        (in this notebook: the UUIDs passed to each worker on launch).

    Returns
    -------
    list of psutil.Process
        Matching processes, in ``psutil.process_iter()`` order.
    """

    def _arguments_of(candidate):
        # Best effort: any process whose command line cannot be read (it
        # exited, access denied, ...) is treated as having no arguments.
        try:
            return candidate.cmdline()
        except Exception:
            return []

    return [
        candidate
        for candidate in psutil.process_iter()
        if ids.intersection(_arguments_of(candidate))
    ]

In [None]:
%%time

NUMBER_OF_PROCESSES = 10

shutil.rmtree("staging", ignore_errors=True)
os.mkdir("staging")
chunks = np.array_split(nodes, NUMBER_OF_PROCESSES)

filename = "graph.pickle"
with open(filename, "wb") as f:
    pickle.dump(graph, f, protocol=pickle.HIGHEST_PROTOCOL)

filename = "nodes.pickle"
with open(filename, "wb") as f:
    pickle.dump(chunks, f, protocol=pickle.HIGHEST_PROTOCOL)

process_ids = set()
process_name = "communities.py"
for chunk_number in range(NUMBER_OF_PROCESSES):
    process_id = str(uuid.uuid4())
    process_ids = process_ids.union({process_id})
    os.system(f"{sys.executable} {process_name} {chunk_number} {process_id} &")

while get_processes(process_ids):
    time.sleep(5)

In [None]:
# Safety net: kill any worker process still tagged with one of this run's ids.
for proc in get_processes(process_ids):
    try:
        proc.kill()
    except psutil.NoSuchProcess:
        pass  # exited between listing and kill

# glob() returns files in arbitrary order; sort so the concatenated community
# order — and hence every community index used in later cells — is reproducible.
part_files = sorted(glob("./staging/*.pickle"))
if not part_files:
    raise FileNotFoundError("No community part files found in ./staging")

communities = []
for filename in part_files:
    with open(filename, "rb") as f:
        communities += pickle.load(f)

filename = "communities.pickle"
with open(filename, "wb") as f:
    pickle.dump(communities, f)
# Keep only element 1 of each staged entry.
# NOTE(review): presumably (seed node, community members) pairs — confirm against communities.py.
communities = [x[1] for x in communities]

In [None]:
len(communities)

In [None]:
sizes = [len(x) for x in communities]
round(np.mean(sizes)), round(np.median(sizes)), round(np.max(sizes)), sum(sizes)

In [None]:
sns.set_theme(rc={"figure.figsize":(12.7, 7.27)})
sns.histplot(data=pd.DataFrame(sizes, columns=["Size"]), x="Size", kde=True)

In [None]:
sns.set_theme(rc={"figure.figsize":(10.7, 5.27)})
sns.boxplot(x=sizes)

In [None]:
flows = pd.read_parquet("flows.parquet")
flow_stats = pd.read_parquet("flow_stats.parquet")

In [None]:
%%time

search_hash = defaultdict(list)
for index, community in enumerate(communities):
    for node in community:
        search_hash[node].append(index)

In [None]:
%%time

percentages = []
start = time.time()
for index, (group, grouped) in enumerate(flows.groupby("id")):
    flow_nodes = set(grouped["source"]).union(grouped["target"])
    size = len(flow_nodes)
    matches = []
    perc = 0
    for node in flow_nodes:
        for i in search_hash[node]:
            try:
                matched_size = len(set(communities[i]).intersection(flow_nodes))
            except KeyError:
                continue
            perc = matched_size / size
            matches.append((node, perc))
            if perc == 1:
                break
        if perc == 1:
            break
    matched_node_comm, perc = sorted(matches, reverse=True, key=lambda x: x[1])[0]
    stats = flow_stats.loc[flow_stats["id"] == group, :].iloc[0].to_dict()
    stats["score"] = perc
    stats["matched_node_comm"] = matched_node_comm
    percentages.append(dict(stats))
    if not (index % 2_000):
        print(index, round(time.time() - start))
        start = time.time()

percentages = pd.DataFrame(percentages)

In [None]:
round(percentages["score"].mean(), 2) * 100

In [None]:
round(percentages[percentages["score"] == 1].shape[0] / percentages.shape[0], 2) * 100

In [None]:
percentages.groupby("type").agg({"score": "mean"}).sort_values("score").plot.bar()

In [None]:
filter_ = percentages["number_components"] == 1
percentages_scope = percentages.loc[filter_, :].reset_index(drop=True)
percentages_scope.groupby("type").agg({"score": "mean"}).sort_values("score").plot.bar()

In [None]:
percentages_scope[percentages_scope["score"] < 1].shape[0]

In [None]:
percentages_scope[percentages_scope["score"] < 1].groupby("sub_type")["type"].count()

In [None]:
%%time

columns = [
    "transaction_id", "source", "target", 
    sf.col("timestamp").alias("ts"),
    sf.col("source_amount").alias("amount"),
    sf.col("source_currency").alias("currency"),
]
transactions = data.where(
    sf.col("source_currency") == sf.col("target_currency")
).select(*columns).toPandas()
transactions.loc[:, "source"] = transactions["source"].str.slice(0, 8)
transactions.loc[:, "target"] = transactions["target"].str.slice(0, 8)
transactions = transactions.loc[transactions["source"] != transactions["target"], :].reset_index(drop=True)
transactions.loc[:, "edge"] = transactions.apply(
    lambda x: tuple(sorted([x["source"], x["target"]])), axis=1
)
transactions.loc[:, "timestamp"] = (
    transactions["ts"].astype(np.int64) / int(1e6)
).astype(np.int64)
del transactions["ts"]
transactions.set_index("edge", inplace=True)

In [None]:
%%time

location = "transactions_communities"
shutil.rmtree(location, ignore_errors=True)
os.mkdir(location)

number_of_chunks = int(np.ceil(len(communities) / 100_000))
chunks = np.array_split(communities, number_of_chunks)
for index, chunk in enumerate(chunks):
    edge_combinations = pd.DataFrame([
        x for y in
        [product([i], combinations(sorted(x), 2)) for i, x in enumerate(chunk)]
        for x in y
    ], columns=["id", "edge"]).set_index("edge")
    edge_combinations.join(transactions, how="inner").to_parquet(
        f"{location}/part-{index}.parquet", index=False
    )
    print(index)

In [None]:
transactions_communities = spark.read.parquet(location)
transactions_communities.count()

In [None]:
delta = round(time.time() - start_script)
print(f"Script executed in {timedelta(seconds=delta)}")