In [1]:
import hashlib
import os
import sys

import pandas as pd
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql import functions as sf
from pyspark.sql import types as st
from pyspark.storagelevel import StorageLevel

import settings as s

In [2]:
# Abort early unless the interpreter is exactly CPython 3.9.19.
REQUIRED_PYTHON = (3, 9, 19)
if sys.version_info[:3] != REQUIRED_PYTHON:
    raise EnvironmentError("Only runs efficiently on Python 3.9.19 | conda 24.1.2 | Apple M3 Pro")

In [3]:
# Spark settings applied to the session created below.
# NOTE(review): "spark.worker.memory" does not look like a documented Spark
# property (executors use "spark.executor.memory") -- confirm it has any effect.
config = [
    ("spark.jars.packages", "graphframes:graphframes:0.8.3-spark3.5-s_2.13"),
    ("spark.driver.memory", "8g"),
    ("spark.worker.memory", "8g"),
]

spark_conf = SparkConf()
for conf_key, conf_value in config:
    spark_conf.set(conf_key, conf_value)

spark = (
    SparkSession.builder
    .appName("testing")
    .config(conf=spark_conf)
    .getOrCreate()
)

:: loading settings :: url = jar:file:/opt/anaconda3/envs/redirect-39/lib/python3.9/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /Users/haseeb.tariq/.ivy2/cache
The jars for the packages stored in: /Users/haseeb.tariq/.ivy2/jars
graphframes#graphframes added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-93bd8962-362a-4288-9f21-6d6187c1764f;1.0
	confs: [default]
	found graphframes#graphframes;0.8.3-spark3.5-s_2.13 in spark-packages
	found org.slf4j#slf4j-api;1.7.16 in central
:: resolution report :: resolve 67ms :: artifacts dl 2ms
	:: modules in use:
	graphframes#graphframes;0.8.3-spark3.5-s_2.13 from spark-packages in [default]
	org.slf4j#slf4j-api;1.7.16 from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	---------------------------------------------------------------------
	|      default     |   2   |   0   |   0   |   0   ||   2   |   0   |
	--------------------

In [4]:
def get_id(value):
    """Return a stable identifier of the form ``id-<16 hex chars>`` for ``value``.

    The identifier is the first 64 bits of the SHA-256 digest of the UTF-8
    encoded text of ``value``. Unlike the built-in ``hash()``, whose result for
    strings is salted per interpreter process, the digest is identical across
    kernel restarts and machines, so ids written by one run can be matched
    against ids produced by another.

    Parameters
    ----------
    value : object
        Value to identify; it is converted with ``str()`` before hashing.

    Returns
    -------
    str
        ``"id-"`` followed by 16 lowercase hexadecimal characters.
    """
    digest = hashlib.sha256(str(value).encode("utf-8")).hexdigest()
    # 16 hex chars == 64 bits, the same width as the built-in hash().
    return f"id-{digest[:16]}"

In [5]:
%%time

# Stage the raw transactions file as a header-less CSV whose first column is
# the generated transaction id, followed by the original (stripped) line.
#
# Rows are streamed straight to the output file instead of being accumulated
# in a multi-gigabyte string, so memory use stays flat regardless of file size.
PROGRESS_EVERY = 20_000_000  # input lines between progress printouts

# Opening with "w" truncates any staging file left by a previous run.
with open(s.STAGED_DATA_CSV_LOCATION, "w") as out_file, open(s.DATA_FILE) as in_file:
    for cnt, line in enumerate(in_file, start=1):
        if cnt == 1:
            # The first input line is skipped (not staged).
            continue
        line = line.strip()
        # Newlines go *between* rows so the file never ends with a blank line.
        separator = "" if cnt == 2 else "\n"
        out_file.write(f"{separator}{get_id(line)},{line}")
        if cnt % PROGRESS_EVERY == 0:
            print(cnt)

20000000
40000000
60000000
80000000
100000000
120000000
140000000
160000000
CPU times: user 1min 14s, sys: 6.52 s, total: 1min 21s
Wall time: 1min 24s


In [6]:
# Remove any staged patterns file from an earlier run; a missing file is fine.
try:
    os.remove(s.STAGED_PATTERNS_CSV_LOCATION)
except FileNotFoundError:
    pass

# Lines whose first four characters are numeric get a generated id prepended;
# every other line is copied through unchanged.
staged_rows = []
with open(s.PATTERNS_FILE) as in_file:
    for raw_line in in_file:
        row = raw_line.strip()
        staged_rows.append(f"{get_id(row)},{row}" if row[:4].isnumeric() else row)

# Leading/trailing whitespace of the whole document is trimmed before writing.
lines = "\n".join(staged_rows).strip()
with open(s.STAGED_PATTERNS_CSV_LOCATION, "a") as out_file:
    out_file.write(lines)
    del lines

In [7]:
# (column name, Spark type) pairs in the column order of the staged CSV.
# Every field is declared non-nullable.
_FIELD_TYPES = [
    ("transaction_id", st.StringType()),
    ("timestamp", st.TimestampType()),
    ("source_bank", st.StringType()),
    ("source", st.StringType()),
    ("target_bank", st.StringType()),
    ("target", st.StringType()),
    ("received_amount", st.FloatType()),
    ("receiving_currency", st.StringType()),
    ("sent_amount", st.FloatType()),
    ("sending_currency", st.StringType()),
    ("format", st.StringType()),
    ("is_laundering", st.IntegerType()),
]

schema = st.StructType(
    [st.StructField(field_name, field_type, False) for field_name, field_type in _FIELD_TYPES]
)
# Plain list of the column names, in schema order.
columns = [field.name for field in schema.fields]

In [8]:
# Parse the staged patterns file into a Spark DataFrame `cases` with columns
# (transaction_id, id, type, sub_type).
#
# The file is a sequence of blocks separated by blank lines. Each block is:
#   <header line containing " - <name>">
#   <one comma-separated transaction row per line>
#   <trailing marker line, discarded>
# `<name>` may have the form "<type>: <sub_type>"; otherwise it is used for both.
with open(s.STAGED_PATTERNS_CSV_LOCATION, "r") as fl:
    patterns = fl.read()

case_frames = []
# Blocks are numbered from 1 by position, so empty blocks still consume an id.
for case_id, pattern in enumerate(patterns.split("\n\n"), start=1):
    if not pattern.strip():
        continue
    header, *rows = pattern.split("\n")
    name = header.split(" - ")[1]
    category, sub_category = name, name
    if ": " in name:
        # maxsplit=1 keeps any further ": " inside the sub type instead of
        # failing to unpack.
        category, sub_category = name.split(": ", 1)
    # rows[:-1] drops the block's trailing marker line.
    # Building the frame with the full column list validates the field count
    # of every row; only the id column is kept afterwards so that just the
    # columns actually used are shipped to Spark.
    transactions = pd.DataFrame([row.split(",") for row in rows[:-1]], columns=columns)
    case_frames.append(
        transactions[["transaction_id"]].assign(
            id=case_id, type=category, sub_type=sub_category
        )
    )

if not case_frames:
    raise ValueError(f"No patterns found in {s.STAGED_PATTERNS_CSV_LOCATION}")

cases = spark.createDataFrame(pd.concat(case_frames, ignore_index=True))
cases = cases.select("transaction_id", "id", "type", "sub_type")

In [9]:
# Currency name as it appears in the data -> lowercase three-letter code.
CURRENCY_MAPPING = dict(
    [
        ("Australian Dollar", "aud"),
        ("Bitcoin", "btc"),
        ("Brazil Real", "brl"),
        ("Canadian Dollar", "cad"),
        ("Euro", "eur"),
        ("Mexican Peso", "mxn"),
        ("Ruble", "rub"),
        ("Rupee", "inr"),
        ("Saudi Riyal", "sar"),
        ("Shekel", "ils"),
        ("Swiss Franc", "chf"),
        ("UK Pound", "gbp"),
        ("US Dollar", "usd"),
        ("Yen", "jpy"),
        ("Yuan", "cny"),
    ]
)


def _lookup_currency_code(currency_name):
    """Return the code for ``currency_name``; raises ``KeyError`` if unmapped."""
    return CURRENCY_MAPPING[currency_name]


# Spark UDF wrapper around the dictionary lookup; yields a string column.
currency_code = sf.udf(_lookup_currency_code, st.StringType())

In [10]:
%%time

# Load the staged transactions, collapse rows that share every attribute
# except the amounts, and persist the result on disk.
transactions_raw = spark.read.csv(
    s.STAGED_DATA_CSV_LOCATION, header=False, schema=schema, timestampFormat=s.TIMESTAMP_FORMAT
)

# Rows that agree on all of these columns are merged into one transaction.
group_by = [
    "timestamp",
    "source_bank",
    "source",
    "target_bank",
    "target",
    "receiving_currency",
    "sending_currency",
    "format",
]

transactions_grouped = (
    transactions_raw.groupby(group_by)
    .agg(
        # min() instead of first(): first() depends on row order, which is not
        # guaranteed after a shuffle, so the representative id could differ
        # between runs. min() always picks the same member of the group.
        sf.min("transaction_id").alias("transaction_id"),
        # All original ids in the group, kept so they can be traced back later.
        sf.collect_set("transaction_id").alias("transaction_ids"),
        sf.sum("received_amount").alias("received_amount"),
        sf.sum("sent_amount").alias("sent_amount"),
        # A merged row is flagged if any of its members is flagged.
        sf.max("is_laundering").alias("is_laundering"),
    )
    .withColumn("source_currency", currency_code(sf.col("sending_currency")))
    .withColumn("target_currency", currency_code(sf.col("receiving_currency")))
)

# No left join against `cases` here: none of its columns are part of the
# selection below, so such a join could only duplicate rows and add a shuffle.
data = (
    transactions_grouped.repartition(128, "transaction_id")
    .select(
        "transaction_id",
        "transaction_ids",
        "timestamp",
        # Account ids are suffixed with the currency code.
        sf.concat(sf.col("source"), sf.lit("-"), sf.col("source_currency")).alias("source"),
        sf.concat(sf.col("target"), sf.lit("-"), sf.col("target_currency")).alias("target"),
        "source_bank",
        "target_bank",
        "source_currency",
        "target_currency",
        sf.col("sent_amount").alias("source_amount"),
        sf.col("received_amount").alias("target_amount"),
        "format",
        "is_laundering",
    )
    .persist(StorageLevel.DISK_ONLY)
)
# Running an action materialises the persisted frame; the count is displayed.
data.count()

24/06/23 17:11:15 WARN TaskSetManager: Stage 1 contains a task of very large size (1849 KiB). The maximum recommended task size is 1000 KiB.

CPU times: user 211 ms, sys: 74 ms, total: 285 ms
Wall time: 6min 38s


                                                                                

179504480

In [11]:
# Number of original ids minus number of grouped rows, i.e. how many ids were
# absorbed into another row by the grouping.
exploded_id_count = data.select(sf.explode("transaction_ids")).count()
grouped_row_count = data.count()
exploded_id_count - grouped_row_count

                                                                                

197599

In [12]:
# One row per original transaction id: the grouped row's own id is parked in
# "x" while the exploded member id takes over the "transaction_id" name so it
# can be matched against `cases`.
data_by_original_id = (
    data.withColumnRenamed("transaction_id", "x")
    .drop(*cases.columns)
    .select(sf.explode("transaction_ids").alias("transaction_id"), "*")
)

# Attach the grouped transaction to every case row, then restore the grouped
# row's id as "transaction_id".
cases_data = (
    cases.join(data_by_original_id, on="transaction_id", how="left")
    .drop("is_laundering", "transaction_id", "transaction_ids")
    .withColumnRenamed("x", "transaction_id")
)

# Round-trip through parquet; from here on `cases_data` is a pandas DataFrame.
cases_data.toPandas().to_parquet(s.STAGED_CASES_DATA_LOCATION)
cases_data = pd.read_parquet(s.STAGED_CASES_DATA_LOCATION)

24/06/23 17:15:51 WARN TaskSetManager: Stage 30 contains a task of very large size (1849 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

In [13]:
# Write the grouped transactions (without the id list) to parquet, replacing
# any previous output, and continue from the parquet-backed frame.
data = data.drop("transaction_ids")
data.write.mode("overwrite").parquet(s.STAGED_DATA_LOCATION)
data = spark.read.parquet(s.STAGED_DATA_LOCATION)

                                                                                

In [24]:
# Every staged row must carry a distinct transaction_id.
row_count = data.count()
distinct_id_count = data.select("transaction_id").distinct().count()
assert row_count == distinct_id_count

                                                                                

In [14]:
# (staged rows, case rows, distinct transactions in cases, distinct case ids)
(
    data.count(),
    cases_data.shape[0],
    cases_data["transaction_id"].nunique(),
    cases_data["id"].nunique(),
)

(179504480, 137936, 137933, 16467)