## Please download the following dataset:
* https://www.kaggle.com/datasets/ealtman2019/ibm-transactions-for-anti-money-laundering-aml
## Then extract the downloaded files to `./synthetic-data/`

In [None]:
import os
import sys

import pandas as pd
from pyspark.sql import types as st
from pyspark.sql import functions as sf
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.storagelevel import StorageLevel

import settings as s

In [None]:
# Guard: refuse to run unless the interpreter's (major, minor, micro) version
# is exactly 3.9.19.
if tuple(sys.version_info[:3]) != (3, 9, 19):
    raise EnvironmentError(
        "Only runs efficiently on Python 3.9.19 (Tested on: Conda 24.1.2 | Apple M3 Pro)"
    )

In [None]:
# (key, value) pairs handed to SparkConf for the notebook's Spark session.
# NOTE(review): "spark.worker.memory" may not be a setting the session honours
# (executor memory is usually "spark.executor.memory") — confirm.
config = [
    ("spark.driver.memory", "8g"),
    ("spark.worker.memory", "8g"),
]
spark_conf = SparkConf().setAll(config)
session_builder = SparkSession.builder.appName("testing").config(conf=spark_conf)
# Reuses an already-running session if there is one, otherwise starts a new one.
spark = session_builder.getOrCreate()

In [None]:
def get_id(value):
    """Return a lookup key of the form ``id-<hash>`` built from ``hash(value)``.

    The key is only as stable as Python's built-in ``hash`` for ``value``.
    """
    return "id-" + str(hash(value))

In [None]:
%%time

# Stage the raw transactions file as a CSV whose first column is the 1-based
# row number of each data line (the header line is skipped). While doing so,
# record in `mapping` which row number each line's `get_id` key received, so
# the patterns file can be tagged with the same numbers afterwards.
#
# Rows are streamed straight to the output file. The previous version built
# them up in one string (up to 20,000,000 rows between flushes) and reopened
# the output file for every flush. Opening with "w" truncates any earlier
# staged file, which replaces the explicit remove-then-append sequence.
mapping = {}
with open(s.DATA_FILE) as in_file, open(
    s.STAGED_DATA_CSV_LOCATION, "w"
) as out_file:
    next(in_file, None)  # skip the header line (no-op on an empty file)
    separator = ""  # becomes "\n" after the first row: no trailing newline
    for cnt, line in enumerate(in_file, start=1):
        line = line.strip()
        # Identical lines share a key, so the last occurrence wins.
        mapping[get_id(line)] = cnt
        out_file.write(f"{separator}{cnt},{line}")
        separator = "\n"
        # Progress report every 20 million rows (integer test, not `% 2e7`).
        if cnt % 20_000_000 == 0:
            print(cnt)

In [None]:
# Start from a clean slate: drop any previously staged patterns file.
try:
    os.remove(s.STAGED_PATTERNS_CSV_LOCATION)
except FileNotFoundError:
    pass

# Copy the patterns file line by line. Lines whose first four characters are
# numeric are treated as transaction rows and get the row number recorded in
# `mapping` prepended; every other line is copied through unchanged.
staged_rows = []
with open(s.PATTERNS_FILE) as in_file:
    for raw in in_file:
        text = raw.strip()
        if not text[:4].isnumeric():
            staged_rows.append(text)
            continue
        # Raises KeyError if the row was not seen while staging the data file.
        staged_rows.append(f"{mapping[get_id(text)]},{text}")

# Leading/trailing whitespace of the whole document is stripped before writing.
with open(s.STAGED_PATTERNS_CSV_LOCATION, "a") as out_file:
    out_file.write("\n".join(staged_rows).strip())
del staged_rows

In [None]:
# Release the key -> row-number dict built while staging the data file; none of
# the cells that follow read it.
del mapping

In [None]:
# Column layout used to read the staged transactions CSV, as (name, Spark type)
# pairs in column order. Every field is declared non-nullable.
schema = st.StructType(
    [
        st.StructField(field_name, field_type, False)
        for field_name, field_type in (
            ("transaction_id", st.IntegerType()),
            ("timestamp", st.TimestampType()),
            ("source_bank", st.StringType()),
            ("source", st.StringType()),
            ("target_bank", st.StringType()),
            ("target", st.StringType()),
            ("received_amount", st.FloatType()),
            ("receiving_currency", st.StringType()),
            ("sent_amount", st.FloatType()),
            ("sending_currency", st.StringType()),
            ("format", st.StringType()),
            ("is_laundering", st.IntegerType()),
        )
    ]
)
# Plain list of the column names, in schema order.
columns = [field.name for field in schema.fields]

In [None]:
# Parse the staged patterns file into one row per (case, transaction).
with open(s.STAGED_PATTERNS_CSV_LOCATION, "r") as fl:
    patterns = fl.read()

# Cases are separated by blank lines. Each case is a header line, the
# transaction rows, and a final line that is discarded.
case_frames = []
for case_id, block in enumerate(patterns.split("\n\n"), start=1):
    # Empty chunks still consume a case id, so ids may have gaps.
    if not block.strip():
        continue
    header, *rows = block.split("\n")
    # The case name is whatever follows the first " - " in the header line.
    name = header.split(" - ")[1]
    if ": " in name:
        category, sub_category = name.split(": ")
    else:
        category = sub_category = name
    rows.pop()  # drop the last line of the case
    # Values come from str.split, so every CSV column here holds strings.
    case = pd.DataFrame([row.split(",") for row in rows], columns=columns)
    case.loc[:, "id"] = case_id
    case.loc[:, "type"] = category.strip().lower()
    case.loc[:, "sub_type"] = sub_category.strip().lower()
    case_frames.append(case)

# Hand the combined frame to Spark and keep only the case-membership columns.
cases = (
    spark.createDataFrame(pd.concat(case_frames, ignore_index=True))
    .withColumn("timestamp", sf.to_timestamp("timestamp", s.TIMESTAMP_FORMAT))
    .select("transaction_id", "id", "type", "sub_type")
)

In [None]:
# Currency names as they appear in the dataset -> lower-case three-letter codes.
CURRENCY_MAPPING = {
    "Australian Dollar": "aud",
    "Bitcoin": "btc",
    "Brazil Real": "brl",
    "Canadian Dollar": "cad",
    "Euro": "eur",
    "Mexican Peso": "mxn",
    "Ruble": "rub",
    "Rupee": "inr",
    "Saudi Riyal": "sar",
    "Shekel": "ils",
    "Swiss Franc": "chf",
    "UK Pound": "gbp",
    "US Dollar": "usd",
    "Yen": "jpy",
    "Yuan": "cny",
}


@sf.udf(returnType=st.StringType())
def currency_code(currency_name):
    """Spark UDF: look up the code for a currency name (KeyError if unknown)."""
    return CURRENCY_MAPPING[currency_name]

In [None]:
%%time

# Load the staged transactions with the explicit schema (the file has no
# header row) and the project's timestamp format.
data = spark.read.csv(
    s.STAGED_DATA_CSV_LOCATION,
    header=False,
    schema=schema,
    timestampFormat=s.TIMESTAMP_FORMAT,
)
# Rows that agree on all of these columns are merged into a single record.
group_by = [
    "timestamp",
    "source_bank",
    "source",
    "target_bank",
    "target",
    "receiving_currency",
    "sending_currency",
    "format",
]
data = data.groupby(group_by).agg(
    # `min` rather than `first`: `first` depends on row order, which is not
    # guaranteed after a shuffle, so the representative id could change
    # between runs. `min` always picks the same member of the group.
    sf.min("transaction_id").alias("transaction_id"),
    # Every original id in the group, kept so cases can be matched later.
    sf.collect_set("transaction_id").alias("transaction_ids"),
    sf.sum("received_amount").alias("received_amount"),
    sf.sum("sent_amount").alias("sent_amount"),
    # A merged record is flagged if any of its members was flagged.
    sf.max("is_laundering").alias("is_laundering"),
)
# Translate currency names into short codes.
data = data.withColumn(
    "source_currency", currency_code(sf.col("sending_currency"))
).withColumn(
    "target_currency",
    currency_code(sf.col("receiving_currency")),
)
data = data.join(cases, on="transaction_id", how="left").repartition(
    128, "transaction_id"
)
data = data.select(
    "transaction_id",
    "transaction_ids",
    "timestamp",
    # Account identifiers are suffixed with the currency code.
    sf.concat(sf.col("source"), sf.lit("-"), sf.col("source_currency")).alias("source"),
    sf.concat(sf.col("target"), sf.lit("-"), sf.col("target_currency")).alias("target"),
    "source_bank",
    "target_bank",
    "source_currency",
    "target_currency",
    sf.col("sent_amount").alias("source_amount"),
    sf.col("received_amount").alias("target_amount"),
    "format",
    "is_laundering",
).persist(StorageLevel.DISK_ONLY)
# Action that materialises the persisted result; also shown as the cell output.
data.count()

In [None]:
# Number of original transactions absorbed by the grouping step: total ids
# across all `transaction_ids` sets minus the number of grouped rows.
exploded_id_count = data.select(sf.explode("transaction_ids")).count()
exploded_id_count - data.count()

In [None]:
# Expand each grouped record into one row per original transaction id, keeping
# the record's representative id under the temporary column name "x".
exploded = (
    data.withColumnRenamed("transaction_id", "x")
    .drop(*cases.columns)
    .select(sf.explode("transaction_ids").alias("transaction_id"), "*")
)
# Attach the grouped record to every case row via the original transaction id,
# then expose the representative id as "transaction_id" again.
cases_data = cases.join(exploded, on="transaction_id", how="left")
cases_data = cases_data.drop("is_laundering", "transaction_id", "transaction_ids")
cases_data = cases_data.withColumnRenamed("x", "transaction_id")
# Round-trip through parquet: persist to disk and continue with a pandas frame.
cases_data.toPandas().to_parquet(s.STAGED_CASES_DATA_LOCATION)
cases_data = pd.read_parquet(s.STAGED_CASES_DATA_LOCATION)

In [None]:
# Write the final transactions table (without the helper id-set column) to
# parquet and continue from the on-disk copy.
persisted = data  # keep a handle on the DataFrame that was persisted earlier
data = data.drop("transaction_ids")
data.write.parquet(s.STAGED_DATA_LOCATION, mode="overwrite")
# The cached copy is no longer needed once the parquet file exists; release it
# instead of leaving it allocated for the rest of the session.
persisted.unpersist()
del persisted
data = spark.read.parquet(s.STAGED_DATA_LOCATION)

In [None]:
# Sanity check: every row must carry a distinct transaction_id.
total_rows = data.count()
distinct_ids = data.select("transaction_id").distinct().count()
assert total_rows == distinct_ids

In [None]:
# Summary shown as the cell output: (transaction rows, case rows,
# distinct transactions in cases, distinct case ids).
(
    data.count(),
    cases_data.shape[0],
    cases_data["transaction_id"].nunique(),
    cases_data["id"].nunique(),
)
# (179504480, 137936, 137933, 16467)