In [None]:
import math
import os
import pickle
import time
import shutil
import sys
import uuid
from collections import defaultdict, Counter
from datetime import timedelta, date
from glob import glob

import leidenalg as la
import igraph as ig
import numpy as np
import pandas as pd
import psutil
from pyspark.sql import functions as sf, types as st
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel
from graphframes import GraphFrame

import settings as s

%load_ext autoreload
%autoreload 2

In [None]:
# Environment guard: abort early unless the interpreter is exactly the required version.
# The version in the error message is derived from the same tuple the check uses, so the
# two cannot drift apart (the message previously named 3.11.7 while the check required 3.11.8).
# NOTE(review): if 3.11.7 was the intended version, change REQUIRED_PYTHON instead.
REQUIRED_PYTHON = (3, 11, 8)
if sys.version_info[:3] != REQUIRED_PYTHON:
    raise EnvironmentError(
        "Only runs efficiently on Python "
        f"{'.'.join(map(str, REQUIRED_PYTHON))} | conda 24.1.2 | Apple M3 Pro"
    )

In [None]:
# Spark properties handed to the session builder as a list of (key, value) pairs.
# NOTE(review): "spark.worker.memory" does not look like a documented Spark property
# (executor memory is usually "spark.executor.memory") — confirm it has any effect.
config = list({
    "spark.jars.packages": "graphframes:graphframes:0.8.3-spark3.5-s_2.12",
    "spark.driver.memory": "8g",
    "spark.worker.memory": "8g",
}.items())
# Create the session named "testing", or reuse an already running one.
spark = (
    SparkSession.builder
    .appName("testing")
    .config(conf=SparkConf().setAll(config))
    .getOrCreate()
)

In [None]:
# Wall-clock start of the notebook run; the final cell subtracts it to report total runtime.
start_script = time.time()

In [None]:
# Per-currency minimum amount. When the data is loaded, only transactions whose source and
# target currency both equal the key and whose source amount is >= the value are kept.
TEMPORAL_CURRENCY_LIMITS = dict(
    btc=1,
    gbp=50,
    eur=100,
    usd=100,
    cad=100,
    aud=100,
    chf=100,
    sar=100,
    ils=100,
    cny=1_000,
    inr=5_000,
    rub=5_000,
    brl=5_000,
    jpy=10_000,
    mxn=10_000,
)

In [None]:
def left_column(column):
    """Return the name used for `column` on the left side of the self-join (`<column>_left`)."""
    return "{}_left".format(column)

def update_source_target(input_data):
    """Replace the `source` and `target` columns with an 8-character substring of themselves.

    Parameters:
        input_data: a Spark DataFrame holding `source` and `target` columns.

    Returns:
        The DataFrame with both columns overwritten by `substring(column, 0, 8)`.

    NOTE(review): Spark documents `substring` positions as 1-based; a start of 0 is
    presumably treated like 1 (i.e. the first 8 characters) — confirm.
    """
    shortened = input_data
    for endpoint in ("source", "target"):
        shortened = shortened.withColumn(endpoint, sf.substring(endpoint, 0, 8))
    return shortened

In [None]:
%%time

# Load the staged transactions and drop self-transfers and every format except ACH / Bitcoin.
data = spark.read.parquet(s.STAGED_DATA_LOCATION)
data = data.where(sf.col("source") != sf.col("target"))
data = data.where(sf.col("format").isin(["ACH", "Bitcoin"]))

# Keep a row when, for some configured currency, both sides use that currency and the
# source amount reaches that currency's limit. The per-currency conditions are OR-ed into
# ONE predicate instead of union-ing one filtered copy of `data` per currency: the kept
# rows are the same (a row can match at most one currency), but the plan stays a single
# filter rather than a 15-branch union over the same input.
within_limits = sf.lit(False)
for currency, limit in TEMPORAL_CURRENCY_LIMITS.items():
    within_limits = within_limits | (
        (sf.col("source_currency") == currency)
        & (sf.col("target_currency") == currency)
        & (sf.col("source_amount") >= limit)
    )
data_currency = data.where(within_limits)

# Repartition on the (still untruncated) endpoints, then project the columns used downstream.
data = data_currency.repartition(11, "source", "target")
data = data.select(
    "transaction_id", "timestamp", "source", "target",
    sf.col("source_currency").alias("currency"), sf.col("source_amount").alias("amount"),
)
# Shorten the endpoint ids and persist on disk, since the frame is re-read many times later.
data = update_source_target(data).persist(storageLevel=StorageLevel.DISK_ONLY)
print(f"\n{data.count():,}\n")
# Sorted list of the distinct calendar dates present in the data (collected to the driver).
dates = sorted(
    data.select(sf.col("timestamp").astype(st.DateType()).alias("x")).distinct().toPandas()["x"]
)

In [None]:
# Directory that receives one Parquet output per date (`start=<date>`) and is later read back as the edge set.
LOCATION = "staging-temporal"

In [None]:
%%time

# Days before and after each processed date within which candidate follow-up transactions are searched.
WINDOW_SIZE = 35
# Amount threshold separating the two tolerance levels below.
L1_AMOUNT = 1000
# Maximum relative amount difference when at least one of the two amounts is <= L1_AMOUNT.
MAX_ALLOWED_DIFF_L1 = 0.15
# Maximum relative amount difference when both amounts are > L1_AMOUNT.
MAX_ALLOWED_DIFF_L2 = 0.25

# Start from an empty output directory on every run.
shutil.rmtree(LOCATION, ignore_errors=True)
os.mkdir(LOCATION)

# Loop-invariant column expression: the calendar date of a transaction's timestamp.
transaction_date = sf.col("timestamp").astype(st.DateType())

start = time.time()
# The loop variable is deliberately NOT called `date`: that name is imported from
# `datetime` at the top of the notebook and would be rebound to the last processed day.
for index, current_date in enumerate(dates):
    start_date = current_date - timedelta(days=WINDOW_SIZE)
    end_date = current_date + timedelta(days=WINDOW_SIZE)

    # Left side: the transactions of the current day, every column renamed to `<name>_left`.
    day = data.where(transaction_date == current_date)
    day = day.select(*[sf.col(x).alias(left_column(x)) for x in day.columns])

    # Right side: all transactions within +/- WINDOW_SIZE days of the current day.
    window = data.where(
        (transaction_date >= start_date) &
        (transaction_date <= end_date)
    )

    # Pair a left transaction with every window transaction sent by the left one's target.
    joined = day.join(
        window,
        on=(
            (sf.col("source") == sf.col(left_column("target")))
        ),
        how="inner"
    )

    # Pairs in different currencies are always kept and get a diff of 0.
    join_diff_currency = joined.where(sf.col("currency") != sf.col(left_column("currency"))).withColumn(
        "diff", sf.lit(0)
    )
    # Same-currency pairs: diff = |a - b| / (a + b), the relative difference of the two amounts.
    join_same_currency = joined.where(sf.col("currency") == sf.col(left_column("currency"))).withColumn(
        "diff",
        sf.abs(sf.col("amount") - sf.col(left_column("amount"))) / (sf.col("amount") + sf.col(left_column("amount")))
    )
    # Level 1: at least one of the two amounts is at most L1_AMOUNT.
    join_same_currency_l1 = join_same_currency.where(
        (sf.col("amount") <= L1_AMOUNT) | (sf.col(left_column("amount")) <= L1_AMOUNT)
    )
    # Level 2: both amounts are above L1_AMOUNT.
    join_same_currency_l2 = join_same_currency.where(
        (sf.col("amount") > L1_AMOUNT) & (sf.col(left_column("amount")) > L1_AMOUNT)
    )

    # Edge list: src = left transaction id, dst = window transaction id,
    # delta = seconds from the left to the window transaction (negative if the window one is earlier).
    combined = join_diff_currency.union(
        join_same_currency_l1.where(sf.col("diff") <= MAX_ALLOWED_DIFF_L1)
    ).union(
        join_same_currency_l2.where(sf.col("diff") <= MAX_ALLOWED_DIFF_L2)
    ).select(
        sf.col(left_column("transaction_id")).alias("src"),
        sf.col("transaction_id").alias("dst"),
        (
            sf.unix_timestamp(sf.col("timestamp")) -
            sf.unix_timestamp(sf.col(left_column("timestamp")))
        ).alias("delta"),
        "diff",
    )
    combined.write.parquet(f"{LOCATION}/start={current_date}")

    # Progress report every 30 processed dates: the date and the seconds since the last report.
    if not (index % 30):
        print(current_date, round(time.time() - start))
        start = time.time()

In [None]:
# Vertices are the transactions themselves (GraphFrame expects an `id` column);
# edges are the src/dst pairs written to LOCATION by the edge-building loop.
nodes_temporal = data.withColumnRenamed("transaction_id", "id")
edges_temporal = spark.read.parquet(LOCATION)
# Build the graph and discard vertices that take part in no edge.
graph_temporal = GraphFrame(nodes_temporal, edges_temporal).dropIsolatedVertices()

In [None]:
# Collect the graph's vertices to pandas, store them as Parquet, and load that file back in.
location = "cyclic_transactions.parquet"
graph_temporal.vertices.toPandas().to_parquet(location)
cyclic_transactions = pd.read_parquet(location)

In [None]:
# Percentage of the filtered transactions in `data` that remain as graph vertices
# (i.e. were not dropped as isolated), rounded to two decimals.
round((cyclic_transactions.shape[0] / data.count()) * 100, 2)

In [None]:
# Load the flow tables from the Parquet files in the working directory.
flows, flow_stats = (
    pd.read_parquet(path) for path in ("flows.parquet", "flow_stats.parquet")
)

In [None]:
# Distinct transaction ids that survived as graph vertices.
trx_ids = cyclic_transactions["id"].unique().tolist()

# Distinct transaction ids of flows typed CYCLE or RANDOM, excluding sub type " Max 1 hops".
wanted_type = flows["type"].isin(["CYCLE", "RANDOM"])
wanted_sub_type = flows["sub_type"] != " Max 1 hops"
flows_trx_ids = flows.loc[wanted_type & wanted_sub_type, "transaction_id"].unique().tolist()

In [None]:
# Percentage of the selected flow transaction ids that also appear among the graph's
# transaction ids, rounded to two decimals.
round((len(set(trx_ids).intersection(flows_trx_ids)) / len(flows_trx_ids)) * 100, 2)

In [None]:
# Selected flow transaction ids that are absent from the graph's transaction ids.
not_found = set(flows_trx_ids).difference(trx_ids)
len(not_found)

In [None]:
# Rows of `flows` whose transaction id is in `not_found`, i.e. missing from the graph.
not_found_flows = flows[flows["transaction_id"].isin(not_found)]

In [None]:
def get_motif(number_of_hops):
    """Build a motif string describing a chain of `number_of_hops` consecutive edges.

    Each hop `i` contributes `(x{i})-[e{i}]->(x{i+1})`; hops are separated by `;`.
    For example, 2 hops give `(x0)-[e0]->(x1);(x1)-[e1]->(x2)`. Zero (or fewer)
    hops give an empty string.
    """
    hops = [f"(x{hop})-[e{hop}]->(x{hop + 1})" for hop in range(number_of_hops)]
    return ";".join(hops)

In [None]:
%%time

# Search the graph for the two-hop pattern (x0)-[e0]->(x1);(x1)-[e1]->(x2).
motifs = graph_temporal.find(get_motif(2))
# Keep only the `source` field of the first vertex (x0) of every match.
motif_edges = motifs.select("x0.source")
# Last expression of the cell: the number of matches found.
motif_edges.count()

In [None]:
# Whole seconds elapsed since `start_script` was recorded, printed as H:MM:SS via timedelta.
delta = round(time.time() - start_script)
print(f"Script executed in {timedelta(seconds=delta)}")

In [None]:
# NOTE(review): bare integer literal with no computation; presumably a count recorded from an
# earlier run — confirm what it refers to, or move it into a markdown cell.
35_178_540