In [None]:
import math
import os
import pickle
import time
import shutil
import sys
import uuid
from collections import defaultdict, Counter
from copy import deepcopy
from datetime import timedelta, date
from glob import glob

import leidenalg as la
import igraph as ig
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import psutil
from pyspark.sql import functions as sf, types as st
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel
from graphframes import GraphFrame

import settings as s

%load_ext autoreload
%autoreload 2

In [None]:
# Guard: the notebook has only been tuned on one exact interpreter build.
if sys.version_info[:3] != (3, 11, 8):
    raise EnvironmentError("Only runs efficiently on Python 3.11.8 | conda 24.1.2 | Apple M3 Pro")

In [None]:
# Spark session with GraphFrames on the classpath.
config = [
    ("spark.jars.packages", "graphframes:graphframes:0.8.3-spark3.5-s_2.12"),
    ("spark.driver.memory", "8g"),
    # Executor memory is configured via `spark.executor.memory`; Spark has no
    # `spark.worker.memory` property. With the default local master the work
    # runs inside the driver JVM, so `spark.driver.memory` is what bounds it.
    ("spark.executor.memory", "8g"),
]
spark = SparkSession.builder.appName("testing").config(conf=SparkConf().setAll(config)).getOrCreate()

In [None]:
# Wall-clock start of the notebook run; the last cell reports elapsed time since this point.
start_script = time.time()

In [None]:
%%time

data = spark.read.parquet(s.STAGED_DATA_LOCATION)
currency_conversion = data.where(
    sf.col("source_currency") != sf.col("target_currency")
).where((sf.col("source_currency") == "usd") | (sf.col("target_currency") == "usd")).select(
    "timestamp", "source_amount", "target_amount", "source_currency", "target_currency"
)
currencies = set(
    data.select("source_currency").distinct().toPandas()["source_currency"]
).union(
    data.select("target_currency").distinct().toPandas()["target_currency"]
)

currency_rates = {}
for currency in currencies:
    if currency == "usd":
        currency_rates[currency] = np.float64(1)
        continue
    conversion_data = currency_conversion.where(
        sf.col("source_currency") == currency
    ).select(
        "timestamp", (sf.col("target_amount") / sf.col("source_amount")).alias("x")
    ).union(
        currency_conversion.where(
            sf.col("target_currency") == currency
        ).select(
            "timestamp", (sf.col("source_amount") / sf.col("target_amount")).alias("x")
        )
    ).groupby(sf.col("timestamp").astype("date").alias("date")).agg(
        sf.median("x").alias("rate")
    ).sort("date", ascending=True).toPandas()
    currency_rates[currency] = conversion_data["rate"].mean()

In [None]:
def left_column(column):
    """Return the alias used for ``column`` on the left side of a self-join.

    Example: ``left_column("amount") == "amount_left"``.
    """
    return f"{column}_left"


def update_source_target(input_data, length=8):
    """Truncate the ``source`` and ``target`` columns to their first ``length`` characters.

    Args:
        input_data: Spark DataFrame with string ``source`` and ``target`` columns.
        length: number of leading characters to keep (default 8).

    Returns:
        A new DataFrame with both columns replaced by their prefixes.

    NOTE(review): the reason for collapsing account ids to a prefix is not
    evident here; confirm it is intended.
    """
    # Spark's `substring` position is 1-based, so 1 means "start at the first character".
    return input_data.withColumn(
        "source", sf.substring("source", 1, length)
    ).withColumn(
        "target", sf.substring("target", 1, length)
    )

In [None]:
%%time

MIN_AMOUNT = 100
MAX_TRANSACTIONS_PER_ACCOUNT = 5_000

data = spark.read.parquet(s.STAGED_DATA_LOCATION)

#### [START] Seed selection ####
source_transactions = data.groupby("source").count().toPandas().sort_values("count", ascending=False)
target_transactions = data.groupby("target").count().toPandas().sort_values("count", ascending=False)
sources_to_remove = set(
    source_transactions[source_transactions["count"] > MAX_TRANSACTIONS_PER_ACCOUNT]["source"]
)
targets_to_remove = set(
    target_transactions[target_transactions["count"] > MAX_TRANSACTIONS_PER_ACCOUNT]["target"]
)
data = data.where(~sf.col("source").isin(sources_to_remove)).where(
    ~sf.col("target").isin(targets_to_remove)
)
data = data.where(sf.col("source") != sf.col("target"))
data = data.where(sf.col("format").isin(["ACH", "Bitcoin"]))
#### [END] Seed selection ####

data = data.where(sf.col("source_currency") == sf.col("target_currency"))
get_currency_rate = sf.udf(lambda x: float(currency_rates[x]), st.FloatType())
data = data.withColumn("amount", get_currency_rate("source_currency") * sf.col("source_amount"))
data = data.select("transaction_id", "timestamp", "source", "target", "amount")
data = update_source_target(data).persist(storageLevel=StorageLevel.DISK_ONLY)
data = data.where(sf.col("amount") >= MIN_AMOUNT)
print(f"\n{data.count():,}\n")
dates = sorted(
    data.select(sf.col("timestamp").astype(st.DateType()).alias("x")).distinct().toPandas()["x"]
)

In [None]:
flows = pd.read_parquet(s.STAGED_CASES_DATA_LOCATION)
# Keep "cycle" and "random" flows, excluding the "max 1 hops" sub-type.
is_cycle_or_random = flows["type"].isin(["cycle", "random"])
is_max_1_hop = flows["sub_type"].isin(["max 1 hops"])
filter_ = is_cycle_or_random & ~is_max_1_hop
flows = flows.loc[filter_, :].reset_index(drop=True)
# USD value of each flow leg (KeyError for a currency missing from `currency_rates`).
usd_rate = flows["source_currency"].map(lambda currency: currency_rates[currency])
flows.loc[:, "amount"] = usd_rate * flows["source_amount"]

In [None]:
# Directory of per-day temporal edge parquet partitions: written by the
# commented-out edge-generation cell that follows, and read back as `edges_temporal`.
LOCATION = "staging-temporal"

In [None]:
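# %%time

# Builds the temporal transaction-to-transaction edge list under LOCATION
# (one parquet partition per day in `dates`); note it first deletes LOCATION.
# Kept commented out, presumably because it is slow. The next cell reads
# LOCATION, so uncomment and run this once on a fresh checkout.
# Caveat: the loop variable `date` shadows `datetime.date` imported above.
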
# WINDOW_SIZE = 35
# MAX_ALLOWED_DIFF = 0.25
# CHRONOLOGICAL = True

# shutil.rmtree(LOCATION, ignore_errors=True)
# os.mkdir(LOCATION)
# start = time.time()
# for index, date in enumerate(dates):
#     start_date = date - timedelta(days=WINDOW_SIZE)
#     if CHRONOLOGICAL:
#         start_date = deepcopy(date)
#     end_date = date + timedelta(days=WINDOW_SIZE)
#     day = data.where(sf.col("timestamp").astype(st.DateType()) == date)
#     day = day.select(*[sf.col(x).alias(left_column(x)) for x in day.columns])
#     window = data.where(
#         (sf.col("timestamp").astype(st.DateType()) >= start_date) &
#         (sf.col("timestamp").astype(st.DateType()) <= end_date)
#     )
#     join_on = sf.col("source") == sf.col(left_column("target"))
#     if CHRONOLOGICAL:
#         join_on = join_on & (sf.col("timestamp") >= sf.col(left_column("timestamp")))
#     joined = day.join(window, on=join_on, how="inner")
#     joined = joined.withColumn(
#         "diff", 
#         sf.abs(sf.col("amount") - sf.col(left_column("amount"))) / (sf.col("amount") + sf.col(left_column("amount")))
#     ).where(sf.col("diff") <= MAX_ALLOWED_DIFF).select(
#         sf.col(left_column("transaction_id")).alias("src"),
#         sf.col("transaction_id").alias("dst"),
#         (
#             sf.unix_timestamp(sf.col("timestamp")) - 
#             sf.unix_timestamp(sf.col(left_column("timestamp")))
#         ).alias("delta"),
#         "diff",
#     )
#     joined.write.parquet(f"{LOCATION}/start={date}")
#     if not (index % 30):
#         print(date, round(time.time() - start))
#         start = time.time()

In [None]:
# The edge list only exists if the (commented-out) edge-generation cell has been
# run; fail with an actionable message instead of a Spark AnalysisException.
if not os.path.isdir(LOCATION):
    raise FileNotFoundError(
        f"Temporal edge list not found at {LOCATION!r}; run the edge-generation cell first."
    )
edges_temporal = spark.read.parquet(LOCATION)
nodes_temporal = data.withColumnRenamed("transaction_id", "id")
# Keep only transactions that take part in at least one temporal edge.
graph_temporal = GraphFrame(nodes_temporal, edges_temporal).dropIsolatedVertices()

In [None]:
%%time

location = "cyclic_transactions.parquet"
cyclic_transactions = graph_temporal.vertices.toPandas()
cyclic_transactions.to_parquet(location)
cyclic_transactions = pd.read_parquet(location)

In [None]:
edges_temporal_pd = edges_temporal.toPandas()


def _max_scaled(values):
    """Return ``|values|`` divided by its largest magnitude, so results lie in [0, 1].

    An all-zero (or empty) series maps to zeros instead of NaN/inf from 0/0.
    """
    magnitude = values.abs()
    peak = magnitude.max()
    return magnitude / peak if peak > 0 else magnitude * 0.0


# Edge weight used by the weighted shortest-path search: the time gap plus the
# relative amount difference, each scaled into [0, 1] for a comparable scale.
edges_temporal_pd.loc[:, "distance"] = (
    _max_scaled(edges_temporal_pd["delta"]) + _max_scaled(edges_temporal_pd["diff"])
)
graph = ig.Graph.DataFrame(edges_temporal_pd[["src", "dst", "distance"]], use_vids=False, directed=True)

In [None]:
# Vertex names in igraph's vertex-index order; `mapping` turns a vertex index
# back into its name (the transaction id used in the next cell).
nodes = graph.vs["name"]
mapping = dict(enumerate(nodes))

In [None]:
# Attach transaction attributes from `data` to every graph vertex, indexed by transaction id.
transactions = (
    spark.createDataFrame(pd.DataFrame({"transaction_id": nodes}))
    .join(data, on="transaction_id", how="left")
    .toPandas()
    .set_index("transaction_id")
)

In [None]:
cc = graph.connected_components(mode="weak")
cc = sorted([(x, len(x)) for x in cc], reverse=True, key=lambda x: x[1])
cc_sizes = [size for _, size in cc]
# Tolerate graphs with fewer than two weak components (plain indexing raised IndexError).
largest_component_size = cc_sizes[0] if cc_sizes else 0
second_largest_component_size = cc_sizes[1] if len(cc_sizes) > 1 else 0
largest_and_second_largest = largest_component_size + second_largest_component_size
between_500_and_3000 = sum(x for x in cc_sizes if 500 < x < 3000)
between_100_and_500 = sum(x for x in cc_sizes if 100 <= x <= 500)
less_than_100 = sum(x for x in cc_sizes if x < 100)

# Each bar is the total number of vertices in the matching components.
# NOTE(review): buckets overlap (the largest component is also in "largest_2" and
# may fall into a size bucket), and components of size >= 3000 other than the
# two largest fall into no bucket; confirm this is intended.
bplot = pd.DataFrame(
    {
        "Count": [
            largest_component_size,
            largest_and_second_largest,
            between_500_and_3000,
            between_100_and_500,
            less_than_100,
        ]
    },
    index=[
        "largest_component",
        "largest_2_components",
        "between_500_and_3000",
        "between_100_and_500",
        "less_than_100",
    ],
)
sns.set_theme(rc={"figure.figsize": (8.7, 5.27)})
fig, ax = plt.subplots()
bplot.plot.barh(color=["C2"], ax=ax)
ax.set(
    title="Weakly connected components of the temporal graph",
    xlabel="Vertices in bucket",
);

In [None]:
# Per account: the list of graph transaction ids it sent (`source_transactions`)
# and received (`target_transactions`). These names are reused: the seed-selection
# cell bound them to per-account count frames.
source_transactions, target_transactions = (
    transactions.reset_index().groupby(account_column)["transaction_id"].agg(list)
    for account_column in ("source", "target")
)

In [None]:
%%time

all_cycles = []
start = time.time()
for counter, (index, sources) in enumerate(source_transactions.items()):
    if not (counter % 100_000):
        print(counter, round(time.time() - start))
        start = time.time()
    targets = set(target_transactions.get(index, [])).difference(sources)
    if not targets:
        continue
    neighbors = graph.neighborhood(sources, mode="out", order=12, mindist=1)
    neighbors = [(source, {mapping[x] for x in i}.intersection(targets)) for source, i in zip(sources, neighbors)]
    neighbors = [x for x in neighbors if x[1]]
    if neighbors:
        all_cycles += neighbors

In [None]:
%%time

cycle_neighbours = graph.neighborhood([x[0] for x in all_cycles], mode="out", order=12, mindist=0)
all_cycles_paths = {}
start = time.time()
for index, (source, end_transactions) in enumerate(all_cycles):
    end_transactions = list(end_transactions)
    ind_graph = graph.induced_subgraph(cycle_neighbours[index])
    cycle_paths = ind_graph.get_all_shortest_paths(
        source, to=end_transactions, mode="out", weights="distance"
    )
    mapping_new = {x.index: x["name"] for x in ind_graph.vs()}
    cycle_paths = [[mapping_new[x] for x in i] + [source] for i in cycle_paths]
    all_cycles_paths[source] = cycle_paths
    if not (index % 50_000):
        print(index, round(time.time() - start))
        start = time.time()

In [None]:
%%time

all_cycles_count = [len(x) for x in all_cycles_paths.values()]
all_cycles_flattened = [set(x) for y in all_cycles_paths.values() for x in y]
len(all_cycles_flattened)

In [None]:
sorted(flows.loc[flows["type"].isin(["cycle", "random"]), "sub_type"].unique())

In [None]:
# Transaction-id set of every staged flow (the cycle/random flows kept when `flows` was loaded).
flows_trx_ids = [set(x) for x in flows.groupby("id")["transaction_id"].agg(list)]

In [None]:
# Split flows by whether all their transactions survived into the graph.
all_ids = set(transactions.index)
found_flow, not_found_flow = [], []
for flow_trx_ids in flows_trx_ids:
    bucket = found_flow if flow_trx_ids.issubset(all_ids) else not_found_flow
    bucket.append(flow_trx_ids)

In [None]:
tuple(len(group) for group in (not_found_flow, found_flow, flows_trx_ids))

In [None]:
flows_trx_ids = [set(x) for x in flows[(flows["type"] == "cycle")].groupby("id")["transaction_id"].agg(list)]

# Inverted index over the detected cycles: transaction id -> positions in
# `all_cycles_flattened` of the cycles that contain it.
cycles_by_transaction = defaultdict(list)
for position, cycle in enumerate(all_cycles_flattened):
    for transaction_id in cycle:
        cycles_by_transaction[transaction_id].append(position)


def _is_covered(flow_ids):
    """Return True if some detected cycle contains every transaction id in ``flow_ids``."""
    if not flow_ids:
        # The empty set is a subset of any cycle.
        return bool(all_cycles_flattened)
    # A covering cycle must contain every flow transaction, so it is enough to
    # scan the cycles that contain the flow's least frequent transaction.
    anchor = min(flow_ids, key=lambda t: len(cycles_by_transaction.get(t, ())))
    return any(
        flow_ids.issubset(all_cycles_flattened[position])
        for position in cycles_by_transaction.get(anchor, ())
    )


# `found` holds the indices (into `flows_trx_ids`) of cycle flows fully covered by a detected cycle.
found = set()
start = time.time()
for i, flow_trx_ids in enumerate(flows_trx_ids):
    if _is_covered(flow_trx_ids):
        found.add(i)
    if not (i % 100):
        print(i, round(time.time() - start))
        start = time.time()

In [None]:
# Whole-notebook wall-clock time, rounded to seconds.
delta = round(time.time() - start_script)
print("Script executed in " + str(timedelta(seconds=delta)))