<a href="https://colab.research.google.com/github/caetano-dev/PixFraudDetection/blob/main/TCC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install pandas
!pip install pyarrow



## Data Loading and Preprocessing with DuckDB

1. **Mount Google Drive:** Mounts your Google Drive to access the data files.
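2. **Define File Paths:** Sets up paths for the raw input files (transactions, laundering patterns, accounts) and for the output directory of processed data. The HI-Small paths are active; the HI-Large paths are kept commented out so you can switch datasets by swapping the two blocks.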
3. **Define Column Structures:** Specifies the standard column names. Every column is first read as text (VARCHAR) and cast to its real type later in SQL.
4. **Parse Patterns File:** Reads and parses a text file containing money laundering patterns into a pandas DataFrame.
5. **Define Currencies:** Lists the specific currencies that will be processed.
6. **Initialize DuckDB:** Sets up an in-memory DuckDB database for efficient data processing.
7. **Define SQL Queries:** Creates SQL queries using DuckDB syntax to read the raw transaction data, parse timestamps, and cast columns to appropriate types.
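8. **Filter and Save Data by Currency:** For each listed currency, keeps only ACH transactions whose sent and received currencies both match it, and writes three Parquet files:
   - (1) normal transactions.
   - (2) laundering transactions: those listed in the patterns file, plus any transactions flagged as laundering in the CSV that have no matching pattern entry. These extra rows are tagged `UNLISTED`, and the match is on timestamp, banks, accounts and amounts.
   - (3) the account records of every account involved in files 1 or 2.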

In [8]:
import os
import re
import duckdb
import pandas as pd
from google.colab import drive

drive.mount('/content/drive')
DRIVE_DIR = '/content/drive/MyDrive/AML'

# Dataset selection: the HI-Small files below are active. To process the
# HI-Large dataset instead, comment out the four HI-Small assignments and
# uncomment these.
#PROCESSED_DIR = os.path.join(DRIVE_DIR, 'processed/')
#TX_CSV = os.path.join(DRIVE_DIR, 'HI-Large_Trans.csv')
#PATTERNS_TXT = os.path.join(DRIVE_DIR, 'HI-Large_Patterns.txt')
#ACCOUNTS_CSV = os.path.join(DRIVE_DIR, 'HI-Large_Accounts.csv')

PROCESSED_DIR = os.path.join(DRIVE_DIR, 'processed/small')
TX_CSV = os.path.join(DRIVE_DIR, 'HI-Small_Trans.csv')
PATTERNS_TXT = os.path.join(DRIVE_DIR, 'HI-Small_Patterns.txt')
ACCOUNTS_CSV = os.path.join(DRIVE_DIR, 'HI-Small_Accounts.csv')

# Validate every input before creating output folders or doing any work.
# The accounts CSV is otherwise first read in Step 3, i.e. only after Steps
# 1-2 of the first currency have already written their Parquet files.
for _label, _path in (
    ('Transaction', TX_CSV),
    ('Patterns', PATTERNS_TXT),
    ('Accounts', ACCOUNTS_CSV),
):
    if not os.path.exists(_path):
        raise FileNotFoundError(f"{_label} file not found: {_path}")

os.makedirs(PROCESSED_DIR, exist_ok=True)
print(f"Found data folder: {DRIVE_DIR}")
print("-" * 50)

# Canonical column order shared by the raw transactions CSV and the
# transaction lines of the patterns file.
standard_columns = [
    'timestamp',
    'from_bank', 'from_account',
    'to_bank', 'to_account',
    'amount_received', 'currency_received',
    'amount_sent', 'currency_sent',
    'payment_type',
    'is_laundering',
]

# Every column is read as text (VARCHAR); real types are applied later in SQL
# via strptime / try_cast. Insertion order follows standard_columns.
column_types = dict.fromkeys(standard_columns, 'VARCHAR')

# Header line of an attempt block; group 1 captures the attempt type, if any.
_ATTEMPT_HEADER_RE = re.compile(r'BEGIN LAUNDERING ATTEMPT\s*-\s*(.+)$')


def parse_patterns_file(file_path, columns=None):
    """Parse a laundering-patterns text file into one row per transaction.

    The file is read line by line (whitespace-stripped, blank lines skipped):

    * A line starting with ``BEGIN LAUNDERING ATTEMPT`` opens an attempt. Its
      type is the text after ``-`` (``'UNKNOWN'`` if absent). Attempts are
      numbered 1, 2, ... in order of their BEGIN lines.
    * A line starting with ``END LAUNDERING ATTEMPT`` closes the open attempt
      and keeps its transactions.
    * Any other line inside an open attempt is split on commas. It becomes a
      transaction when it has at least ``len(columns)`` fields; extra fields
      are ignored and shorter lines are skipped.

    Lines outside an attempt are ignored. An attempt that is never closed, or
    that is interrupted by another BEGIN line, is discarded.

    Args:
        file_path: Path of the patterns text file.
        columns: Names for the leading comma-separated fields. Defaults to the
            module-level ``standard_columns``.

    Returns:
        DataFrame with ``columns + ['attempt_id', 'attempt_type']``. All
        transaction fields are strings. The frame is empty (with those columns)
        when no attempt is found.
    """
    if columns is None:
        columns = standard_columns
    columns = list(columns)
    n_fields = len(columns)

    rows = []
    current = None  # open attempt: {'attempt_id', 'attempt_type', 'transactions'}
    attempt_counter = 0

    with open(file_path, 'r', encoding='utf-8') as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if line.startswith('BEGIN LAUNDERING ATTEMPT'):
                attempt_counter += 1
                m = _ATTEMPT_HEADER_RE.search(line)
                current = {
                    'attempt_id': attempt_counter,
                    'attempt_type': m.group(1).strip() if m else 'UNKNOWN',
                    'transactions': [],
                }
            elif line.startswith('END LAUNDERING ATTEMPT'):
                if current is not None:
                    rows.extend(current['transactions'])
                current = None
            elif current is not None:
                parts = [p.strip() for p in line.split(',')]
                if len(parts) >= n_fields:
                    tx = dict(zip(columns, parts[:n_fields]))
                    tx['attempt_id'] = current['attempt_id']
                    tx['attempt_type'] = current['attempt_type']
                    current['transactions'].append(tx)

    return pd.DataFrame(rows, columns=columns + ['attempt_id', 'attempt_type'])

# Currencies to process, one output folder each under PROCESSED_DIR (spaces
# replaced by underscores). Each name is compared case-insensitively with both
# currency_sent and currency_received by currency_filter_sql.
CURRENCIES = [
    "US Dollar",
    "Euro",
    "Yuan",
    "Shekel",
    "Canadian Dollar",
    "UK Pound",
    "Ruble",
    "Australian Dollar",
    "Swiss Franc",
    "Yen",
    "Mexican Peso",
    "Rupee",
    "Brazil Real",
    "Saudi Riyal"
]

# In-memory DuckDB connection shared by every query in this cell.
con = duckdb.connect(database=':memory:')
# NOTE(review): hard-coded thread count; confirm 8 suits the runtime
# (e.g. a small Colab VM) or drop it to use DuckDB's own default.
con.execute("PRAGMA threads=8")

# SQL that reads the raw transactions CSV with every column as VARCHAR.
# `columns={column_types}` interpolates the Python dict repr, which DuckDB
# accepts as a struct literal mapping column name -> type.
# header=false: a header row, if the file has one, is read as data. Its
# timestamp cell presumably does not have length 16/19, so ts_parse_sql turns
# it into NULL and the later `timestamp IS NOT NULL` filters drop it. Verify
# this against the file.
read_tx_csv_sql = f"""
  SELECT * FROM read_csv_auto(
    '{TX_CSV}',
    delim=',',
    header=false,
    columns={column_types},
    all_varchar=true
  )
"""

# SQL expression parsing `timestamp` as 'YYYY/MM/DD HH:MM' (16 chars) or
# 'YYYY/MM/DD HH:MM:SS' (19 chars); any other length yields NULL.
# NOTE(review): strptime raises on a malformed value of the right length,
# which aborts the whole query; DuckDB's try_strptime would yield NULL instead.
ts_parse_sql = """
CASE
  WHEN length(trim(timestamp)) = 16 THEN strptime(trim(timestamp), '%Y/%m/%d %H:%M')
  WHEN length(trim(timestamp)) = 19 THEN strptime(trim(timestamp), '%Y/%m/%d %H:%M:%S')
  ELSE NULL
END
"""

# Typed projection of the transactions CSV:
# - timestamp parsed as above;
# - text columns trimmed;
# - amounts cast to DOUBLE (empty/unparseable -> NULL);
# - is_laundering cast to INTEGER (empty/unparseable -> 0).
# It embeds read_csv_auto, so every query that includes it re-reads the CSV.
typed_tx_sql = f"""
WITH raw AS ({read_tx_csv_sql})
SELECT
  {ts_parse_sql}::TIMESTAMP AS timestamp,
  trim(from_bank) AS from_bank,
  trim(from_account) AS from_account,
  trim(to_bank) AS to_bank,
  trim(to_account) AS to_account,
  try_cast(nullif(trim(amount_received), '') AS DOUBLE) AS amount_received,
  trim(currency_received) AS currency_received,
  try_cast(nullif(trim(amount_sent), '') AS DOUBLE) AS amount_sent,
  trim(currency_sent) AS currency_sent,
  trim(payment_type) AS payment_type,
  coalesce(try_cast(nullif(trim(is_laundering), '') AS INTEGER), 0) AS is_laundering
FROM raw
"""

def currency_filter_sql(currency_name, payment_type='ACH'):
    """Build a SQL boolean fragment selecting same-currency transfers of one payment type.

    The fragment matches rows where:
    - ``currency_sent`` and ``currency_received`` both equal *currency_name*
      (case-insensitive, surrounding whitespace ignored);
    - ``payment_type`` equals *payment_type*.

    Args:
        currency_name: Currency label as it appears in the data, e.g. ``"US Dollar"``.
        payment_type: Payment format to keep. Defaults to ``'ACH'``, the only
            format the pipeline selected before this became a parameter. It is
            trimmed and upper-cased because the SQL side compares with
            ``upper(trim(payment_type))``.

    Returns:
        A SQL expression without a leading or trailing ``AND``, meant to be
        spliced into a ``WHERE`` clause.
    """
    # Both values are interpolated into SQL text; doubling single quotes keeps
    # each literal well-formed (standard SQL string escaping).
    cur = str(currency_name).replace("'", "''")
    pay = str(payment_type).strip().upper().replace("'", "''")
    return f"""
    upper(trim(currency_sent)) = upper('{cur}') AND
    upper(trim(currency_received)) = upper('{cur}') AND
    upper(trim(payment_type)) = '{pay}'
    """

# Parse the laundering-patterns file once and expose it to DuckDB as the
# table `patterns_df`, which the Step 2 SQL reads.
patterns_df = parse_patterns_file(PATTERNS_TXT)
if patterns_df.empty:
    # Defensive: keep the columns Step 2 selects. parse_patterns_file already
    # returns these columns when it finds no attempts.
    patterns_df = pd.DataFrame(columns=standard_columns + ['attempt_id', 'attempt_type'])
con.register('patterns_df', patterns_df)

# For each currency, write three Parquet files under PROCESSED_DIR/<Currency_Name>/:
#   1_filtered_normal_transactions.parquet      - CSV rows with is_laundering = 0
#   2_filtered_laundering_transactions.parquet  - pattern rows + CSV-only laundering rows
#   3_filtered_accounts.parquet                 - account rows for accounts used in 1 or 2
# Steps 1 and 2 apply currency_filter_sql (both legs in this currency, ACH only).
# NOTE(review): typed_tx_sql embeds read_csv_auto, so the full transactions CSV
# is re-read and re-parsed in Step 1 and again in Step 2 for every currency.
# Materializing the typed/filtered rows once would avoid these repeated scans.
for currency in CURRENCIES:
    cur_dirname = currency.replace(' ', '_')
    OUT_DIR = os.path.join(PROCESSED_DIR, cur_dirname)
    os.makedirs(OUT_DIR, exist_ok=True)

    OUT_STEP1 = os.path.join(OUT_DIR, '1_filtered_normal_transactions.parquet')
    OUT_STEP2 = os.path.join(OUT_DIR, '2_filtered_laundering_transactions.parquet')
    OUT_STEP3 = os.path.join(OUT_DIR, '3_filtered_accounts.parquet')

    filt_sql = currency_filter_sql(currency)

    # Step 1: normal transactions for this currency
    # (parseable timestamp, currency/ACH filter, is_laundering = 0).
    con.execute(f"""
      COPY (
        WITH typed AS ({typed_tx_sql})
        SELECT
          timestamp, from_bank, from_account, to_bank, to_account,
          amount_received, currency_received, amount_sent, currency_sent,
          payment_type, is_laundering
        FROM typed
        WHERE timestamp IS NOT NULL
          AND {filt_sql}
          AND is_laundering = 0
      ) TO '{OUT_STEP1}' (FORMAT PARQUET, COMPRESSION ZSTD)
    """)

    step1_rows = con.execute(f"SELECT COUNT(*) FROM read_parquet('{OUT_STEP1}')").fetchone()[0]
    print(f"[{currency}] Step 1: Saved normal transactions to '{OUT_STEP1}' (rows={step1_rows:,})")

    # Step 2: laundering transactions (from patterns + missing from CSV) for this currency
    #   pat_filt: pattern-file rows, typed like the CSV and filtered the same way.
    #   raw_pos:  CSV rows flagged is_laundering = 1 under the same filter.
    #   missing:  raw_pos rows with no pattern row matching on timestamp, banks,
    #             accounts and both amounts. Amounts are compared as rounded
    #             integer cents, presumably to avoid float-equality mismatches.
    #   Output: all pat_filt rows, plus `missing` rows with attempt_id NULL and
    #   attempt_type 'UNLISTED'.
    con.execute(f"""
      COPY (
        WITH
          pat_raw AS (
            SELECT
              timestamp, from_bank, from_account, to_bank, to_account,
              amount_received, currency_received, amount_sent, currency_sent,
              payment_type, is_laundering,
              attempt_id,
              attempt_type
            FROM patterns_df
          ),
          pat_typed AS (
            SELECT
              {ts_parse_sql}::TIMESTAMP AS timestamp,
              trim(from_bank) AS from_bank,
              trim(from_account) AS from_account,
              trim(to_bank) AS to_bank,
              trim(to_account) AS to_account,
              try_cast(nullif(trim(amount_received), '') AS DOUBLE) AS amount_received,
              trim(currency_received) AS currency_received,
              try_cast(nullif(trim(amount_sent), '') AS DOUBLE) AS amount_sent,
              trim(currency_sent) AS currency_sent,
              trim(payment_type) AS payment_type,
              coalesce(try_cast(nullif(trim(is_laundering), '') AS INTEGER), 0) AS is_laundering,
              try_cast(attempt_id AS BIGINT) AS attempt_id,
              trim(attempt_type) AS attempt_type
            FROM pat_raw
          ),
          pat_filt AS (
            SELECT
              timestamp, from_bank, from_account, to_bank, to_account,
              amount_received, currency_received, amount_sent, currency_sent,
              payment_type, is_laundering, attempt_id, attempt_type,
              CAST(round(amount_sent * 100) AS BIGINT) AS amount_sent_c,
              CAST(round(amount_received * 100) AS BIGINT) AS amount_received_c
            FROM pat_typed
            WHERE timestamp IS NOT NULL
              AND {filt_sql}
              AND is_laundering = 1
          ),
          raw_pos AS (
            WITH typed AS ({typed_tx_sql})
            SELECT
              timestamp, from_bank, from_account, to_bank, to_account,
              amount_received, currency_received, amount_sent, currency_sent,
              payment_type, is_laundering,
              CAST(round(amount_sent * 100) AS BIGINT) AS amount_sent_c,
              CAST(round(amount_received * 100) AS BIGINT) AS amount_received_c
            FROM typed
            WHERE timestamp IS NOT NULL
              AND {filt_sql}
              AND is_laundering = 1
          ),
          missing AS (
            SELECT raw_pos.*
            FROM raw_pos
            LEFT JOIN pat_filt
              ON raw_pos.timestamp = pat_filt.timestamp
              AND raw_pos.from_bank = pat_filt.from_bank
              AND raw_pos.from_account = pat_filt.from_account
              AND raw_pos.to_bank = pat_filt.to_bank
              AND raw_pos.to_account = pat_filt.to_account
              AND raw_pos.amount_received_c = pat_filt.amount_received_c
              AND raw_pos.amount_sent_c = pat_filt.amount_sent_c
            WHERE pat_filt.timestamp IS NULL
          ),
          unioned AS (
            SELECT
              timestamp, from_bank, from_account, to_bank, to_account,
              amount_received, currency_received, amount_sent, currency_sent,
              payment_type, is_laundering,
              attempt_id, attempt_type
            FROM pat_filt
            UNION ALL
            SELECT
              timestamp, from_bank, from_account, to_bank, to_account,
              amount_received, currency_received, amount_sent, currency_sent,
              payment_type, is_laundering,
              NULL::INTEGER AS attempt_id, 'UNLISTED' AS attempt_type
            FROM missing
          )
        SELECT * FROM unioned
      ) TO '{OUT_STEP2}' (FORMAT PARQUET, COMPRESSION ZSTD)
    """)

    # Counts for the log line: pattern-derived rows vs CSV-only ('UNLISTED') rows.
    # A row with NULL attempt_type would fall into neither bucket.
    base_count = con.execute("""
      WITH x as (SELECT attempt_type FROM read_parquet(?) WHERE attempt_type <> 'UNLISTED')
      SELECT COUNT(*) FROM x
    """, [OUT_STEP2]).fetchone()[0]
    added_count = con.execute("""
      WITH x as (SELECT attempt_type FROM read_parquet(?) WHERE attempt_type = 'UNLISTED')
      SELECT COUNT(*) FROM x
    """, [OUT_STEP2]).fetchone()[0]
    total_count = con.execute(f"SELECT COUNT(*) FROM read_parquet('{OUT_STEP2}')").fetchone()[0]
    print(f"[{currency}] Step 2: Saved laundering transactions to '{OUT_STEP2}' (patterns={base_count:,}, added_from_csv={added_count:,}, total={total_count:,})")

    # Step 3: Filter accounts involved in either step1 or step2 for this currency
    # Accounts are matched on trimmed account id only (the bank is not part of
    # the join). The accounts CSV is read with header=false, all columns VARCHAR.
    con.execute(f"""
      COPY (
        WITH all_tx AS (
          SELECT
            timestamp, from_bank, from_account, to_bank, to_account,
            amount_received, currency_received, amount_sent, currency_sent,
            payment_type, is_laundering,
            NULL::INTEGER AS attempt_id, NULL::VARCHAR AS attempt_type
          FROM read_parquet('{OUT_STEP1}')
          UNION ALL
          SELECT
            timestamp, from_bank, from_account, to_bank, to_account,
            amount_received, currency_received, amount_sent, currency_sent,
            payment_type, is_laundering,
            attempt_id, attempt_type
          FROM read_parquet('{OUT_STEP2}')
        ),
        involved AS (
          SELECT DISTINCT from_account AS account FROM all_tx WHERE from_account IS NOT NULL
          UNION
          SELECT DISTINCT to_account AS account FROM all_tx WHERE to_account IS NOT NULL
        ),
        accounts AS (
          SELECT * FROM read_csv_auto(
            '{ACCOUNTS_CSV}',
            delim=',',
            header=false,
            columns={{'bank_name': 'VARCHAR', 'bank_id': 'VARCHAR', 'account_id_hex': 'VARCHAR', 'entity_id': 'VARCHAR', 'entity_name': 'VARCHAR'}},
            all_varchar=true
          )
        )
        SELECT a.*
        FROM accounts a
        INNER JOIN involved i
          ON trim(a.account_id_hex) = trim(i.account)
      ) TO '{OUT_STEP3}' (FORMAT PARQUET, COMPRESSION ZSTD)
    """)

    step3_rows = con.execute(f"SELECT COUNT(*) FROM read_parquet('{OUT_STEP3}')").fetchone()[0]
    print(f"[{currency}] Step 3: Saved filtered account details to '{OUT_STEP3}' (rows={step3_rows:,})")

# Release the DuckDB connection.
# NOTE(review): not reached if any step above raises; try/finally would guarantee it.
con.close()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Found data folder: /content/drive/MyDrive/AML
--------------------------------------------------


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[US Dollar] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/US_Dollar/1_filtered_normal_transactions.parquet' (rows=199,982)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[US Dollar] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/US_Dollar/2_filtered_laundering_transactions.parquet' (patterns=1,178, added_from_csv=485, total=1,663)
[US Dollar] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/US_Dollar/3_filtered_accounts.parquet' (rows=93,102)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Euro] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/Euro/1_filtered_normal_transactions.parquet' (rows=125,228)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Euro] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/Euro/2_filtered_laundering_transactions.parquet' (patterns=886, added_from_csv=320, total=1,206)
[Euro] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/Euro/3_filtered_accounts.parquet' (rows=57,220)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Yuan] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/Yuan/1_filtered_normal_transactions.parquet' (rows=21,877)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Yuan] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/Yuan/2_filtered_laundering_transactions.parquet' (patterns=107, added_from_csv=55, total=162)
[Yuan] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/Yuan/3_filtered_accounts.parquet' (rows=10,088)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Shekel] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/Shekel/1_filtered_normal_transactions.parquet' (rows=20,461)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Shekel] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/Shekel/2_filtered_laundering_transactions.parquet' (patterns=25, added_from_csv=56, total=81)
[Shekel] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/Shekel/3_filtered_accounts.parquet' (rows=9,377)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Canadian Dollar] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/Canadian_Dollar/1_filtered_normal_transactions.parquet' (rows=15,732)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Canadian Dollar] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/Canadian_Dollar/2_filtered_laundering_transactions.parquet' (patterns=76, added_from_csv=37, total=113)
[Canadian Dollar] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/Canadian_Dollar/3_filtered_accounts.parquet' (rows=6,939)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[UK Pound] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/UK_Pound/1_filtered_normal_transactions.parquet' (rows=19,186)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[UK Pound] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/UK_Pound/2_filtered_laundering_transactions.parquet' (patterns=71, added_from_csv=35, total=106)
[UK Pound] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/UK_Pound/3_filtered_accounts.parquet' (rows=8,557)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Ruble] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/Ruble/1_filtered_normal_transactions.parquet' (rows=16,430)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Ruble] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/Ruble/2_filtered_laundering_transactions.parquet' (patterns=72, added_from_csv=43, total=115)
[Ruble] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/Ruble/3_filtered_accounts.parquet' (rows=7,410)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Australian Dollar] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/Australian_Dollar/1_filtered_normal_transactions.parquet' (rows=14,522)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Australian Dollar] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/Australian_Dollar/2_filtered_laundering_transactions.parquet' (patterns=69, added_from_csv=42, total=111)
[Australian Dollar] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/Australian_Dollar/3_filtered_accounts.parquet' (rows=6,711)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Swiss Franc] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/Swiss_Franc/1_filtered_normal_transactions.parquet' (rows=25,236)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Swiss Franc] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/Swiss_Franc/2_filtered_laundering_transactions.parquet' (patterns=114, added_from_csv=50, total=164)
[Swiss Franc] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/Swiss_Franc/3_filtered_accounts.parquet' (rows=11,538)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Yen] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/Yen/1_filtered_normal_transactions.parquet' (rows=16,586)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Yen] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/Yen/2_filtered_laundering_transactions.parquet' (patterns=89, added_from_csv=43, total=132)
[Yen] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/Yen/3_filtered_accounts.parquet' (rows=7,696)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Mexican Peso] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/Mexican_Peso/1_filtered_normal_transactions.parquet' (rows=11,552)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Mexican Peso] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/Mexican_Peso/2_filtered_laundering_transactions.parquet' (patterns=53, added_from_csv=26, total=79)
[Mexican Peso] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/Mexican_Peso/3_filtered_accounts.parquet' (rows=5,323)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Rupee] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/Rupee/1_filtered_normal_transactions.parquet' (rows=20,858)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Rupee] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/Rupee/2_filtered_laundering_transactions.parquet' (patterns=111, added_from_csv=34, total=145)
[Rupee] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/Rupee/3_filtered_accounts.parquet' (rows=9,390)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Brazil Real] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/Brazil_Real/1_filtered_normal_transactions.parquet' (rows=7,885)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Brazil Real] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/Brazil_Real/2_filtered_laundering_transactions.parquet' (patterns=21, added_from_csv=24, total=45)
[Brazil Real] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/Brazil_Real/3_filtered_accounts.parquet' (rows=3,412)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Saudi Riyal] Step 1: Saved normal transactions to '/content/drive/MyDrive/AML/processed/small/Saudi_Riyal/1_filtered_normal_transactions.parquet' (rows=9,197)


FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

[Saudi Riyal] Step 2: Saved laundering transactions to '/content/drive/MyDrive/AML/processed/small/Saudi_Riyal/2_filtered_laundering_transactions.parquet' (patterns=336, added_from_csv=25, total=361)
[Saudi Riyal] Step 3: Saved filtered account details to '/content/drive/MyDrive/AML/processed/small/Saudi_Riyal/3_filtered_accounts.parquet' (rows=4,095)


In [13]:
import os
from pathlib import Path
DRIVE_BASE = Path('/content/drive/MyDrive/AML/processed/small/US_Dollar')

# Show schema and first rows of each Parquet file written for this currency.
# `df` is left bound to the last frame read (the accounts table).
for _title, _filename in (
    ("Normal transactions", '1_filtered_normal_transactions.parquet'),
    ("Laundering transactions", '2_filtered_laundering_transactions.parquet'),
    ("Bank accounts", '3_filtered_accounts.parquet'),
):
    print(_title)
    df = pd.read_parquet(DRIVE_BASE / _filename)
    df.info()
    print(df.head())
del _title, _filename

Normal transactions
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 199982 entries, 0 to 199981
Data columns (total 11 columns):
 #   Column             Non-Null Count   Dtype         
---  ------             --------------   -----         
 0   timestamp          199982 non-null  datetime64[us]
 1   from_bank          199982 non-null  object        
 2   from_account       199982 non-null  object        
 3   to_bank            199982 non-null  object        
 4   to_account         199982 non-null  object        
 5   amount_received    199982 non-null  float64       
 6   currency_received  199982 non-null  object        
 7   amount_sent        199982 non-null  float64       
 8   currency_sent      199982 non-null  object        
 9   payment_type       199982 non-null  object        
 10  is_laundering      199982 non-null  int32         
dtypes: datetime64[us](1), float64(2), int32(1), object(7)
memory usage: 16.0+ MB
            timestamp from_bank from_account to_bank to_acc

In [1]:
!pip install networkx
!pip install python-louvain
!pip install community



In [9]:
import pandas as pd
import networkx as nx
import numpy as np
from pathlib import Path
from datetime import timedelta
from networkx.algorithms.community import louvain_communities as nx_louvain_communities
from google.colab import drive

# Mount Google Drive; with force_remount=False an existing mount is reused.
drive.mount('/content/drive', force_remount=False)

# Folder holding the per-currency outputs of the preprocessing cell
# (small dataset, US Dollar). The commented line is the large-dataset variant.
DRIVE_BASE = Path('/content/drive/MyDrive/AML/processed/small/US_Dollar')
#DRIVE_BASE = Path('/content/drive/MyDrive/AML/processed/US_Dollar')
proc = DRIVE_BASE
p_norm = proc / '1_filtered_normal_transactions.parquet'  # Step 1 output: normal transactions
p_pos  = proc / '2_filtered_laundering_transactions.parquet'  # Step 2 output: laundering transactions
p_acct = proc / '3_filtered_accounts.parquet'  # Step 3 output: involved accounts

# Whether and where to persist the canonical multigraph; the save itself is
# outside this view. NOTE(review): networkx 3.x removed nx.write_gpickle, so
# the save presumably needs the stdlib pickle module. Confirm.
SAVE_GPICKLE = True
GPICKLE_PATH = proc / 'G_all_multi.gpickle'

# Sliding-window settings: window lengths (days) to evaluate, days between
# consecutive window starts, and a cap on windows per length.
WINDOW_DAYS_LIST = [3, 7]
WINDOW_STRIDE_DAYS = 1
MAX_WINDOWS_PER_SETTING = 5  # set None to process all windows

# Louvain community-detection parameters (presumably passed to run_louvain).
LOUVAIN_RESOLUTION = 1.0
LOUVAIN_SEED = 42

def parse_ts(s: pd.Series) -> pd.Series:
    """Parse ``YYYY/MM/DD HH:MM[:SS]`` text into a datetime Series.

    Each value is stripped and tried first as ``%Y/%m/%d %H:%M``. If that
    fails, it is tried as ``%Y/%m/%d %H:%M:%S``. Values matching neither
    format become ``NaT``.

    A Series that already has a datetime dtype is returned unchanged (as a
    copy). Stringifying it would produce ISO ``YYYY-MM-DD ...`` text, which
    matches neither format and would turn every value into ``NaT``.
    """
    if pd.api.types.is_datetime64_any_dtype(s):
        return s.copy()
    text = s.astype(str).str.strip()
    parsed = pd.to_datetime(text, format='%Y/%m/%d %H:%M', errors='coerce')
    missing = parsed.isna().to_numpy()
    if missing.any():
        fallback = pd.to_datetime(text[missing], format='%Y/%m/%d %H:%M:%S', errors='coerce')
        # Assign positionally (NumPy mask + NumPy values). Label-based .loc
        # assignment would try to align on the index and fail when the index
        # has duplicate labels.
        parsed[missing] = fallback.to_numpy()
    return parsed

def to_cents(s: pd.Series) -> pd.Series:
    """Convert an amount Series to integer cents as nullable ``Int64``.

    Values are coerced to numbers (unparseable entries become missing),
    multiplied by 100 and rounded with ``Series.round()`` before the cast.
    As a result, float noise such as ``4.56 * 100 == 455.99999999999994``
    still yields 456.
    """
    as_number = pd.to_numeric(s, errors='coerce')
    scaled = as_number * 100
    return scaled.round().astype('Int64')

def summarize_graph(G: nx.MultiDiGraph, name='Graph'):
    """Print node count, edge count and laundering-edge count for *G*.

    An edge counts as positive when its ``is_laundering`` attribute (default 0)
    converts to the integer 1.
    """
    positive_edges = 0
    for _u, _v, _k, attrs in G.edges(keys=True, data=True):
        if int(attrs.get('is_laundering', 0)) == 1:
            positive_edges += 1
    print(f"{name}: {G.number_of_nodes():,} nodes, {G.number_of_edges():,} edges, positive_edges={positive_edges:,}")

def derive_node_labels(G: nx.Graph) -> int:
    """Flag every node that touches a laundering edge; return how many are flagged.

    The node attribute ``is_laundering_involved`` is first reset to 0 on all
    nodes. It is then set to 1 on both endpoints of every edge whose
    ``is_laundering`` attribute (default 0) is 1. *G* is modified in place.

    Returns:
        Sum of ``is_laundering_involved`` over all nodes.
    """
    nx.set_node_attributes(G, 0, 'is_laundering_involved')
    flagged_edges = (
        (a, b)
        for a, b, attrs in G.edges(data=True)
        if int(attrs.get('is_laundering', 0)) == 1
    )
    for a, b in flagged_edges:
        for endpoint in (a, b):
            G.nodes[endpoint]['is_laundering_involved'] = 1
    return sum(attrs.get('is_laundering_involved', 0) for _, attrs in G.nodes(data=True))

def graph_time_range(G: nx.MultiDiGraph):
    """Return ``(earliest, latest)`` edge ``timestamp`` in *G*.

    Edges whose ``timestamp`` attribute is missing or ``None`` are ignored.
    Returns ``(None, None)`` when no edge has a timestamp.
    """
    stamps = []
    for _, _, attrs in G.edges(data=True):
        ts = attrs.get('timestamp')
        if ts is not None:
            stamps.append(ts)
    if not stamps:
        return (None, None)
    return (min(stamps), max(stamps))

def iter_windows(start, end, window_days=3, stride_days=1):
    """Yield sliding ``(window_start, window_end)`` pairs for starts in ``[start, end)``.

    Each window spans ``window_days`` days and consecutive starts are
    ``stride_days`` apart. A window is yielded for every start strictly before
    *end*, so the last windows may extend past *end*.

    Yields nothing when *start* or *end* is ``None``, which is the value
    ``graph_time_range`` returns for a graph without timestamps.

    Raises:
        ValueError: if ``stride_days`` is not positive (the window start would
            never advance, so the loop would not terminate). The check runs on
            the first iteration, as this is a generator.
    """
    if stride_days <= 0:
        raise ValueError(f"stride_days must be positive, got {stride_days!r}")
    if start is None or end is None:
        return
    window = timedelta(days=window_days)
    stride = timedelta(days=stride_days)
    cur = start
    while cur < end:
        yield (cur, cur + window)
        cur += stride
def get_windowed_graph_fast(df_slice: pd.DataFrame, base_graph: nx.MultiDiGraph) -> nx.MultiDiGraph:
    """Build a window multigraph directly from a slice of the transaction table.

    Nodes: one per account appearing in ``from_account`` / ``to_account`` of
    *df_slice* (as str). Attributes are copied from *base_graph* when the node
    exists there, except ``is_laundering_involved``.

    Edges: one per row, keyed by the row's index label (``int(idx)``), carrying
    the row's timestamp, amounts, banks, labels and payment/currency fields.
    Besides the raw transaction columns, the slice must contain the columns
    added by ``load_processed``: ``amount_sent_c``, ``amount_received_c`` and
    ``same_bank``.

    Returns:
        An empty graph named ``"win_empty"`` for an empty slice. Otherwise a
        graph named after the dates of the slice's min/max timestamp.
    """
    if len(df_slice) == 0:
        return nx.MultiDiGraph(name="win_empty")
    ts_min = df_slice['timestamp'].min()
    ts_max = df_slice['timestamp'].max()
    H = nx.MultiDiGraph(name=f"win_{ts_min:%Y%m%d}_{ts_max:%Y%m%d}")
    # NOTE(review): nodes are added in set order, which for str keys can vary
    # between interpreter runs (hash randomization).
    nodes_needed = set(df_slice['from_account'].astype(str)) | set(df_slice['to_account'].astype(str))
    for node in nodes_needed:
        if node in base_graph.nodes:
            attrs = dict(base_graph.nodes[node])
            # Drop the label set by derive_node_labels so window nodes do not
            # carry a label computed outside this window.
            attrs.pop('is_laundering_involved', None)
            H.add_node(node, **attrs)
        else:
            H.add_node(node)
    # NOTE(review): iterrows is slow on large slices; itertuples would be much
    # faster (check the resulting attribute value types if switching).
    for idx, r in df_slice.iterrows():
        u = str(r['from_account']); v = str(r['to_account'])
        # t_ns: pandas' integer epoch value of the timestamp (None if missing).
        ts = r['timestamp']; t_ns = int(ts.value) if pd.notna(ts) else None
        attempt_id = r.get('attempt_id', None)
        attempt_type = r.get('attempt_type', None)
        # Missing attempt metadata: id -> None, type -> 'UNLISTED'.
        # NOTE(review): normal (is_laundering = 0) rows have no attempt columns
        # and also end up 'UNLISTED', the same tag the preprocessing uses for
        # CSV-only laundering rows; use is_laundering to tell them apart.
        if pd.isna(attempt_id): attempt_id = None
        if pd.isna(attempt_type): attempt_type = 'UNLISTED'
        H.add_edge(
            u, v,
            key=int(idx),
            timestamp=ts,
            t_ns=t_ns,
            is_laundering=int(r['is_laundering']),
            attempt_id=attempt_id,
            attempt_type=attempt_type,
            amount_sent_c=r['amount_sent_c'],
            amount_received_c=r['amount_received_c'],
            amount_sent=r['amount_sent'],
            amount_received=r['amount_received'],
            from_bank=str(r['from_bank']),
            to_bank=str(r['to_bank']),
            same_bank=bool(r['same_bank']),
            # Defaults apply only when the column is absent from the slice.
            payment_type=r.get('payment_type', 'ACH'),
            currency_sent=r.get('currency_sent', 'US Dollar'),
            currency_received=r.get('currency_received', 'US Dollar'),
        )
    return H

def get_windowed_graph(G: nx.MultiDiGraph, start, end) -> nx.MultiDiGraph:
    """Return the sub-multigraph of *G*'s edges with ``start <= timestamp < end``.

    Only nodes incident to a selected edge are included. Edge keys and edge
    attributes are copied unchanged. Node attributes are copied except
    ``is_laundering_involved``, a label set by ``derive_node_labels``
    (presumably on the full graph). Dropping it matches
    ``get_windowed_graph_fast`` and keeps windows from inheriting labels
    computed outside the window.

    Edges without a ``timestamp`` (or with one that does not compare inside
    the interval, e.g. ``NaT``) are skipped.
    """
    H = nx.MultiDiGraph(name=f"win_{start:%Y%m%d}_{end:%Y%m%d}")

    def _node_attrs(n):
        # Copy so popping the label never mutates G's node data.
        attrs = dict(G.nodes[n])
        attrs.pop('is_laundering_involved', None)
        return attrs

    for u, v, k, d in G.edges(keys=True, data=True):
        ts = d.get('timestamp')
        if ts is None or not (start <= ts < end):
            continue
        if u not in H:
            H.add_node(u, **_node_attrs(u))
        if v not in H:
            H.add_node(v, **_node_attrs(v))
        H.add_edge(u, v, key=k, **d)
    return H

def _cents_or_zero(value) -> int:
    """Return *value* as an ``int``, mapping None, NaN, NaT and ``pd.NA`` to 0.

    ``int(value or 0)`` alone is not enough: ``pd.NA or 0`` raises
    ``TypeError`` (NA has no truth value) and ``int(float('nan'))`` raises
    ``ValueError``. Both can appear on edges because ``to_cents`` returns
    nullable ``Int64`` values.
    """
    if value is None or pd.isna(value):
        return 0
    return int(value or 0)


def aggregate_graph(G: nx.MultiDiGraph, directed=False) -> nx.Graph:
    """Collapse a transaction multigraph into a weighted simple graph.

    Parallel edges between the same pair are merged. When ``directed`` is
    False, both directions are merged into one sorted ``(a, b)`` pair. Each
    merged edge carries:

    * ``w_count``: number of original edges;
    * ``w_amount``: sum of ``amount_received_c`` (missing -> 0);
    * ``w_amount_log``: ``log1p(w_amount)``, or 0.0 when the sum is not positive;
    * ``first_ts`` / ``last_ts`` / ``span_seconds``: timestamp range of the
      merged edges (``None`` / 0 when there are none);
    * ``reciprocated`` (directed only): 1 if the reverse edge also exists.

    Nodes keep their attributes from *G*; isolated nodes of *G* are not
    carried over. Each node also gets ``in_amount_sum``, ``out_amount_sum``,
    ``in_deg``, ``out_deg``, ``in_tx_count``, ``out_tx_count`` and
    ``in_out_amount_ratio = (in + 1) / (out + 1)``. In undirected mode the
    in/out values are both the node's totals, so the ratio is 1.

    Returns:
        ``nx.DiGraph`` if *directed* else ``nx.Graph``.
    """
    H = nx.DiGraph() if directed else nx.Graph()
    edge_data = {}
    for u, v, d in G.edges(data=True):
        a, b = (u, v) if directed else tuple(sorted((u, v)))
        if not H.has_node(a): H.add_node(a, **G.nodes[a])
        if not H.has_node(b): H.add_node(b, **G.nodes[b])
        edge_data.setdefault((a, b), []).append(d)
    for (a, b), edge_list in edge_data.items():
        w_count = len(edge_list)
        w_amount = sum(_cents_or_zero(d.get('amount_received_c', 0)) for d in edge_list)
        w_amount_log = np.log1p(w_amount) if w_amount > 0 else 0.0
        timestamps = [d.get('timestamp') for d in edge_list if d.get('timestamp') is not None]
        first_ts = min(timestamps) if timestamps else None
        last_ts = max(timestamps) if timestamps else None
        span_seconds = (last_ts - first_ts).total_seconds() if (first_ts and last_ts) else 0
        H.add_edge(a, b, w_count=w_count, w_amount=w_amount, w_amount_log=w_amount_log,
                   first_ts=first_ts, last_ts=last_ts, span_seconds=span_seconds)
    if directed:
        for u, v in H.edges():
            H[u][v]['reciprocated'] = 1 if H.has_edge(v, u) else 0
    for node in H.nodes():
        if directed:
            in_edges = [(pred, node) for pred in H.predecessors(node)]
            out_edges = [(node, succ) for succ in H.successors(node)]
            in_amount_sum  = sum(H[u][v].get('w_amount', 0) for u, v in in_edges)
            out_amount_sum = sum(H[u][v].get('w_amount', 0) for u, v in out_edges)
            in_deg = len(in_edges); out_deg = len(out_edges)
            in_tx_count  = sum(H[u][v].get('w_count', 0) for u, v in in_edges)
            out_tx_count = sum(H[u][v].get('w_count', 0) for u, v in out_edges)
        else:
            adj_edges = [(node, nbr) if node <= nbr else (nbr, node) for nbr in H.neighbors(node)]
            total_amount = sum(H[u][v].get('w_amount', 0) for u, v in adj_edges)
            total_tx     = sum(H[u][v].get('w_count', 0) for u, v in adj_edges)
            deg = len(adj_edges)
            in_amount_sum = out_amount_sum = total_amount
            in_deg = out_deg = deg
            in_tx_count = out_tx_count = total_tx
        # +1 smoothing avoids division by zero for nodes with no outflow.
        in_out_amount_ratio = (in_amount_sum + 1) / (out_amount_sum + 1)
        H.nodes[node].update({
            'in_amount_sum': in_amount_sum,
            'out_amount_sum': out_amount_sum,
            'in_deg': in_deg, 'out_deg': out_deg,
            'in_tx_count': in_tx_count, 'out_tx_count': out_tx_count,
            'in_out_amount_ratio': in_out_amount_ratio
        })
    return H

def run_louvain(H: nx.Graph, resolution=1.0, seed=42, weight='w_amount_log'):
    """Partition ``H`` into communities with the Louvain method.

    Args:
        H: graph to partition.
        resolution: Louvain resolution parameter (higher -> smaller communities).
        seed: random seed forwarded to the Louvain implementation.
        weight: edge attribute used as the edge weight.

    Returns:
        ``(partition, communities)`` where ``partition`` maps each node to the
        index of its community in ``communities``, and ``communities`` is a list
        of node sets.
    """
    raw_comms = nx_louvain_communities(H, weight=weight, resolution=resolution, seed=seed)
    partition = {}
    for idx, members in enumerate(raw_comms):
        partition.update(dict.fromkeys(members, idx))
    return partition, list(map(set, raw_comms))

def score_communities_unsupervised(H, comms: list, min_size=3):
    """Score communities with a label-free heuristic.

    For a community of ``n`` nodes the score is
    ``(0.35*density + 0.2*clustering + 0.45*amount) * (1 - exp(-n/10))`` where

    * density    = internal edges / (n*(n-1)/2)
    * clustering = ``nx.average_clustering`` of the induced subgraph weighted by
      ``w_amount_log`` (0.0 if that call raises)
    * amount     = min(1, log1p(sum of internal ``w_amount``) / 20)

    Communities that are empty or have fewer than ``min_size`` nodes score 0.0.

    Args:
        H: graph the communities were found in.
        comms: list of node collections, one per community.
        min_size: minimum community size to be scored.

    Returns:
        Dict mapping community index (position in ``comms``) to its score.
    """
    result = {}
    for cid, members in enumerate(comms):
        size = len(members)
        if size == 0 or size < min_size:
            result[cid] = 0.0
            continue

        sub = H.subgraph(members)
        pair_count = size*(size-1)/2 if size > 1 else 0
        density = (sub.number_of_edges()/pair_count) if pair_count else 0.0

        try:
            clustering = nx.average_clustering(sub, weight='w_amount_log')
        except Exception:
            clustering = 0.0

        amount_total = sum(attrs.get('w_amount', 0) for *_, attrs in sub.edges(data=True))
        amount_part = min(1.0, np.log1p(amount_total)/20)
        boost = 1 - np.exp(-size/10)
        result[cid] = (0.35*density + 0.2*clustering + 0.45*amount_part) * boost
    return result

# -----------------------
# Data loading and build
# -----------------------
def load_processed():
    """Load processed transactions and accounts from ``p_norm``, ``p_pos``, ``p_acct``.

    Returns:
        ``(df, acct)``:

        * ``df`` is the ``p_norm`` and ``p_pos`` transactions concatenated
          (fresh RangeIndex), then sorted by ``timestamp`` (index labels keep
          their pre-sort values). ``is_laundering`` is coerced to int8
          (unparseable -> 0). ``amount_sent_c`` and ``amount_received_c`` are
          produced by ``to_cents``. ``same_bank`` is True when ``from_bank``
          equals ``to_bank`` compared as strings.
        * ``acct`` is the account table de-duplicated on ``account_id_hex`` and
          indexed by it.
    """
    frames = [pd.read_parquet(path) for path in (p_norm, p_pos)]
    df = pd.concat(frames, ignore_index=True).sort_values('timestamp')

    df['is_laundering'] = (
        pd.to_numeric(df['is_laundering'], errors='coerce').fillna(0).astype('int8')
    )
    df['amount_sent_c'] = to_cents(df['amount_sent'])
    df['amount_received_c'] = to_cents(df['amount_received'])
    df['same_bank'] = df['from_bank'].astype(str).eq(df['to_bank'].astype(str))

    acct = (
        pd.read_parquet(p_acct)
        .drop_duplicates(subset=['account_id_hex'])
        .set_index('account_id_hex')
    )
    return df, acct

def build_canonical_graph(df: pd.DataFrame, acct: pd.DataFrame) -> nx.MultiDiGraph:
    """Build the canonical transaction multigraph.

    Nodes are account ids as strings. First come all accounts in ``acct``, with
    ``bank_id``, ``entity_id`` and ``entity_name`` attributes (stringified).
    Then come, in sorted order, any accounts referenced by a transaction but
    absent from ``acct``; these have no attributes. Each row of ``df`` becomes
    one directed edge ``from_account -> to_account`` whose key is the row's
    index label.

    Edge attributes:

    * ``timestamp`` and ``t_ns`` (nanoseconds; None when the timestamp is missing)
    * ``is_laundering`` and ``attempt_id`` (None if missing)
    * ``attempt_type`` ('UNLISTED' if missing)
    * sent/received amounts, both raw and in cents
    * ``from_bank``/``to_bank`` and ``same_bank``
    * ``payment_type``, ``currency_sent`` and ``currency_received``, falling back
      to 'ACH' / 'US Dollar' only when the column is absent

    Args:
        df: transactions as produced by ``load_processed``.
        acct: account table indexed by account id.

    Returns:
        ``nx.MultiDiGraph`` named 'G_all_multi'.
    """
    G = nx.MultiDiGraph(name='G_all_multi')
    for acc_id, row in acct.iterrows():
        G.add_node(
            str(acc_id),
            bank_id=str(row.get('bank_id', '')),
            entity_id=str(row.get('entity_id', '')),
            entity_name=str(row.get('entity_name', ''))
        )

    # Accounts that only appear in transactions. They are inserted in sorted
    # order because iterating a set of str follows hash order, which Python
    # randomizes per process (PYTHONHASHSEED). Otherwise G's node order would
    # differ between runs. Order-sensitive steps on graphs derived from G (e.g.
    # seeded Louvain, which shuffles the node list with its seed) could then
    # vary despite the fixed seed.
    known = set(G.nodes)
    for col in ['from_account', 'to_account']:
        missing = set(df[col].astype(str)) - known
        for acc_id in sorted(missing):
            G.add_node(acc_id)
            known.add(acc_id)

    for idx, r in df.iterrows():
        u = str(r['from_account']); v = str(r['to_account'])
        ts = r['timestamp']; t_ns = int(ts.value) if pd.notna(ts) else None
        attempt_id = r.get('attempt_id', None)
        attempt_type = r.get('attempt_type', None)
        if pd.isna(attempt_id): attempt_id = None
        # Transactions without a listed laundering pattern get a sentinel type.
        if pd.isna(attempt_type): attempt_type = 'UNLISTED'
        G.add_edge(
            u, v, key=int(idx),
            timestamp=ts, t_ns=t_ns,
            is_laundering=int(r['is_laundering']),
            attempt_id=attempt_id, attempt_type=attempt_type,
            amount_sent_c=r['amount_sent_c'], amount_received_c=r['amount_received_c'],
            amount_sent=r['amount_sent'], amount_received=r['amount_received'],
            from_bank=str(r['from_bank']), to_bank=str(r['to_bank']),
            same_bank=bool(r['same_bank']),
            payment_type=r.get('payment_type', 'ACH'),
            currency_sent=r.get('currency_sent', 'US Dollar'),
            currency_received=r.get('currency_received', 'US Dollar'),
        )
    return G

def build_all(save_gpickle: bool = True):
    """Load the processed data, build the canonical multigraph and report on it.

    Prints a graph summary, the number of positive nodes over the full period
    and the time range. When both ``save_gpickle`` and the module flag
    ``SAVE_GPICKLE`` are truthy, also pickles the graph to ``GPICKLE_PATH``.

    Args:
        save_gpickle: per-call switch for saving the graph.

    Returns:
        ``(df, G, tmin, tmax)``: the transactions frame, the multigraph and the
        graph's time range.
    """
    df, acct = load_processed()
    G = build_canonical_graph(df, acct)
    summarize_graph(G, "G_all_multi")

    # derive_node_labels returns the number of positive nodes it labelled.
    print(f"Positive nodes (full period): {derive_node_labels(G):,}")

    # The per-call flag is checked first, so SAVE_GPICKLE is only read when needed.
    should_save = save_gpickle and SAVE_GPICKLE
    if should_save:
        save_graph_gpickle(G, GPICKLE_PATH)
        print(f"Saved G_all_multi to {GPICKLE_PATH}")

    tmin, tmax = graph_time_range(G)
    print(f"Time range: {tmin} → {tmax}")
    return df, G, tmin, tmax

def save_graph_gpickle(G, path):
    """Pickle graph ``G`` to ``path``, creating parent directories as needed.

    Uses NetworkX's ``write_gpickle`` when it can be imported; it was removed in
    NetworkX 3.0. If the import or the write fails, falls back to
    ``pickle.dump`` with the highest protocol.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    try:
        from networkx.readwrite.gpickle import write_gpickle
        write_gpickle(G, str(target))
    except Exception:
        import pickle
        with open(target, 'wb') as fh:
            pickle.dump(G, fh, protocol=pickle.HIGHEST_PROTOCOL)

def load_graph_gpickle(path):
    """Load a graph written by ``save_graph_gpickle``.

    Tries NetworkX's ``read_gpickle`` first. If it cannot be imported or it
    fails, the file is read with ``pickle.load``.

    Security: unpickling can execute arbitrary code; only load trusted files.
    """
    try:
        from networkx.readwrite.gpickle import read_gpickle
        return read_gpickle(str(path))
    except Exception:
        pass
    import pickle
    with open(path, 'rb') as fh:
        return pickle.load(fh)

# Inside the notebook kernel __name__ is "__main__", so executing this cell
# builds and (if enabled) saves the graph. The returned (df, G, tmin, tmax)
# tuple is discarded here.
if __name__ == "__main__":
    build_all(save_gpickle=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
G_all_multi: 93,101 nodes, 201,645 edges, positive_edges=1,663
Positive nodes (full period): 2,464
Saved G_all_multi to /content/drive/MyDrive/AML/processed/small/US_Dollar/G_all_multi.gpickle
Time range: 2022-09-01 00:00:00 → 2022-09-18 16:18:00


In [17]:
from sklearn.metrics import average_precision_score
# average_precision_score is imported unconditionally, so AP is always
# available; evaluation code still checks this flag.
SKLEARN_OK = True

# When run outside the notebook, these names come from a standalone `graph`
# module. In the notebook they are presumably defined by earlier cells:
#
# from graph import (
#     proc, p_norm, p_pos, p_acct,
#     WINDOW_DAYS_LIST, WINDOW_STRIDE_DAYS,
#     LOUVAIN_RESOLUTION, LOUVAIN_SEED,
#     build_all, iter_windows, get_windowed_graph_fast,
#     derive_node_labels, aggregate_graph, run_louvain, score_communities_unsupervised
# )

# -----------------------
# Metrics configs
# -----------------------
METRICS_DIR = proc / "metrics"
METRICS_DIR.mkdir(exist_ok=True, parents=True)
RESULTS_CSV = METRICS_DIR / "window_metrics.csv"

# Fractions of the ranked population used for precision@k and attempt
# coverage: 0.5%, 1% and 2%.
K_FRACS = (0.005, 0.01, 0.02)

# Global seeds are drawn from the first 20% of the timeline.
SEED_CUTOFF_FRAC = 0.2

# Optional cap on windows processed per window size; None means no cap.
MAX_WINDOWS_PER_SETTING = None

# -----------------------
# Metrics helpers
# -----------------------
def precision_at_k(y_true, y_score, k_frac=0.01):
    """Return the fraction of positives among the top-k scored items.

    ``k = max(1, int(len(y_true) * k_frac))``. Items are taken in descending
    score order; ties are ordered by ``np.argsort``'s default algorithm.

    Args:
        y_true: binary labels (1 = positive).
        y_score: scores aligned with ``y_true``; higher means ranked earlier.
        k_frac: fraction of the population to inspect.

    Returns:
        Precision in [0, 1] as a float, or NaN when ``y_true`` is empty (there
        is nothing to rank).

    Raises:
        ValueError: if ``y_true`` and ``y_score`` have different lengths.
    """
    y_true = np.asarray(y_true); y_score = np.asarray(y_score)
    if len(y_true) != len(y_score):
        raise ValueError(
            f"y_true and y_score must have the same length ({len(y_true)} != {len(y_score)})"
        )
    if len(y_true) == 0:
        # Explicit NaN instead of numpy's "mean of empty slice" warning path.
        return float('nan')
    n = max(1, int(len(y_true) * k_frac))
    idx = np.argsort(-y_score)[:n]
    return float(y_true[idx].mean())

def eval_scores(nodes, y_true_dict, score_dict, k_fracs=(0.005, 0.01, 0.02), exclude_nodes=None):
    """Evaluate every node ranking in ``score_dict`` against binary node labels.

    Args:
        nodes: candidate nodes; their order decides tie handling in the rankings.
        y_true_dict: node -> 0/1 label (missing nodes count as 0).
        score_dict: method name -> {node: score} (missing nodes score 0.0).
        k_fracs: fractions for precision@k; each yields a key such as
            'p_at_1.0pct'.
        exclude_nodes: nodes removed before evaluation (e.g. seed nodes).

    Returns:
        method -> metrics dict with these keys:

        * 'ap': average precision, or None unless sklearn is available and the
          evaluated labels contain both classes
        * '_eval_nodes' and '_eval_pos': size and positive count of the
          evaluated population
        * one 'p_at_*' entry per k
        * '_ranked_nodes': evaluated nodes in descending score order

        Keys starting with '_' are bookkeeping.
    """
    skip = set() if exclude_nodes is None else exclude_nodes
    kept = [node for node in nodes if node not in skip]
    labels = np.array([y_true_dict.get(node, 0) for node in kept], dtype=int)
    ap_defined = len(set(labels)) > 1

    report = {}
    for method_name, node_scores in score_dict.items():
        values = np.array([node_scores.get(node, 0.0) for node in kept], dtype=float)
        entry = {
            'ap': average_precision_score(labels, values) if SKLEARN_OK and ap_defined else None,
            '_eval_nodes': len(kept),
            '_eval_pos': int(labels.sum()),
        }
        for frac in k_fracs:
            entry[f"p_at_{int(frac*1000)/10:.1f}pct"] = precision_at_k(labels, values, frac)
        entry['_ranked_nodes'] = [kept[pos] for pos in np.argsort(-values)]
        report[method_name] = entry
    return report

def run_centrality_baselines(H_dir: nx.DiGraph):
    """Score every node of ``H_dir`` with simple unsupervised baselines.

    Returns a dict method -> {node: score}, in this order:

    * pagerank_wlog: PageRank weighted by ``w_amount_log`` (alpha=0.9)
    * hits_hub, hits_auth: HITS hub and authority scores
    * in_deg, out_deg, in_tx, out_tx, in_amt, out_amt: the node attributes
      in_deg, out_deg, in_tx_count, out_tx_count, in_amount_sum and
      out_amount_sum (0 if missing)
    * collector: in_amount_sum / (out_amount_sum + 1)
    * distributor: out_amount_sum / (in_amount_sum + 1)

    If PageRank or HITS raises (e.g. no convergence within ``max_iter``), the
    affected methods get an empty dict, so every node scores 0 in evaluation.
    A WARNING is printed so that degenerate baseline is not mistaken for a real
    result.
    """
    scores = {}
    try:
        scores['pagerank_wlog'] = nx.pagerank(H_dir, weight='w_amount_log', alpha=0.9, max_iter=100, tol=1e-6)
    except Exception as exc:
        print(f"WARNING: pagerank_wlog failed ({type(exc).__name__}: {exc}); all its scores default to 0")
        scores['pagerank_wlog'] = {}
    try:
        hubs, auth = nx.hits(H_dir, max_iter=500, tol=1e-8, normalized=True)
        scores['hits_hub'] = hubs; scores['hits_auth'] = auth
    except Exception as exc:
        print(f"WARNING: HITS failed ({type(exc).__name__}: {exc}); hits_hub/hits_auth scores default to 0")
        scores['hits_hub'] = {}; scores['hits_auth'] = {}

    # (method name, node attribute) pairs, read directly from the aggregated graph.
    attr_methods = (
        ('in_deg', 'in_deg'), ('out_deg', 'out_deg'),
        ('in_tx', 'in_tx_count'), ('out_tx', 'out_tx_count'),
        ('in_amt', 'in_amount_sum'), ('out_amt', 'out_amount_sum'),
    )
    for method, attr in attr_methods:
        scores[method] = {n: H_dir.nodes[n].get(attr, 0) for n in H_dir}

    # Money concentrators vs. spreaders; the +1 avoids division by zero.
    scores['collector'] = {n: (H_dir.nodes[n].get('in_amount_sum',0)) / (H_dir.nodes[n].get('out_amount_sum',0)+1) for n in H_dir}
    scores['distributor'] = {n: (H_dir.nodes[n].get('out_amount_sum',0)) / (H_dir.nodes[n].get('in_amount_sum',0)+1) for n in H_dir}
    return scores

def get_attempt_nodes_map(H_win: nx.MultiDiGraph):
    """Map each laundering attempt id to the endpoints of its positive edges.

    An edge contributes its two endpoints when ``is_laundering`` is 1 and its
    ``attempt_id`` is usable (neither None nor a float NaN).

    Returns:
        Dict attempt_id -> set of node ids.
    """
    mapping = {}
    for src, dst, _key, attrs in H_win.edges(keys=True, data=True):
        if int(attrs.get('is_laundering', 0)) != 1:
            continue
        attempt = attrs.get('attempt_id', None)
        if attempt is None:
            continue
        if isinstance(attempt, float) and np.isnan(attempt):
            continue
        mapping.setdefault(attempt, set()).update((src, dst))
    return mapping

def attempt_coverage(nodes_ranked, attempt_nodes_map: dict, k_frac=0.01):
    """Fraction of laundering attempts touched by the top-k ranked nodes.

    ``k = max(1, int(len(nodes_ranked) * k_frac))``. An attempt counts as
    covered when at least one of its nodes is among the first k entries of
    ``nodes_ranked``.

    Returns:
        Coverage in [0, 1], or None when ``attempt_nodes_map`` is empty.
    """
    if not attempt_nodes_map:
        return None
    budget = max(1, int(len(nodes_ranked) * k_frac))
    flagged = set(nodes_ranked[:budget])
    hits = sum(not flagged.isdisjoint(members) for members in attempt_nodes_map.values())
    return hits / len(attempt_nodes_map)

def pretty_metrics(results: dict):
    """Return a print-friendly copy of ``eval_scores`` results.

    Bookkeeping keys (names starting with '_') are dropped. Numeric values
    (int/float, including NumPy scalars and bool) become floats rounded to 4
    decimals. None and other values pass through unchanged.
    """
    numeric_types = (int, float, np.integer, np.floating)

    def _clean(value):
        if value is None:
            return None
        if isinstance(value, numeric_types):
            return round(float(value), 4)
        return value

    return {
        method: {key: _clean(val) for key, val in metrics.items() if not str(key).startswith('_')}
        for method, metrics in results.items()
    }

def get_seeded_pagerank_scores(H_agg_dir: nx.DiGraph, seed_nodes: set, weight='w_amount_log', alpha=0.9):
    """Personalized PageRank from ``seed_nodes``, averaged over both edge directions.

    Teleport mass is spread uniformly over the seeds present in the graph.
    PageRank runs on ``H_agg_dir`` and on its reverse, and each node's score is
    the mean of the two runs.

    Args:
        H_agg_dir: aggregated directed graph.
        seed_nodes: known-positive nodes to propagate from.
        weight: edge attribute used as the weight.
        alpha: damping factor.

    Returns:
        Dict node -> score, or {} when there are no seeds, none of them is in the
        graph, or PageRank fails (a WARNING is printed in that last case).
    """
    if not seed_nodes:
        return {}
    # NetworkX requires at least one non-zero personalization entry. With no
    # seed in this graph there is nothing to propagate from, so return early
    # instead of relying on (and swallowing) the resulting exception.
    if not any(s in H_agg_dir for s in seed_nodes):
        return {}
    personalization = {n: (1.0/len(seed_nodes) if n in seed_nodes else 0.0) for n in H_agg_dir}
    try:
        pr_fwd = nx.pagerank(H_agg_dir, personalization=personalization, weight=weight, alpha=alpha, max_iter=100, tol=1e-6)
        pr_rev = nx.pagerank(H_agg_dir.reverse(copy=False), personalization=personalization, weight=weight, alpha=alpha, max_iter=100, tol=1e-6)
    except Exception as exc:
        print(f"WARNING: seeded PageRank failed ({type(exc).__name__}: {exc}); returning no scores")
        return {}
    return {n: 0.5*(pr_fwd.get(n,0.0) + pr_rev.get(n,0.0)) for n in H_agg_dir}

# -----------------------
# Build graph and data
# -----------------------
# NOTE(review): this rebuilds G from the parquet files and re-saves the gpickle,
# although the graph-building cell (its `if __name__ == "__main__"` guard)
# already did both. Consider capturing that cell's return value, or using
# load_graph_gpickle, to avoid the duplicate build.
df, G, tmin, tmax = build_all(save_gpickle=True)

# -----------------------
# Temporal windows quick summary (kept for parity)
# -----------------------
# For every non-empty sliding window, print node/edge counts plus the number of
# positive (is_laundering == 1) edges and positive nodes. Nothing is stored.
for window_days in WINDOW_DAYS_LIST:
    print(f"\n-- {window_days}-day windows, stride={WINDOW_STRIDE_DAYS}d --")
    for i, (ws, we) in enumerate(iter_windows(tmin, tmax, window_days=window_days, stride_days=WINDOW_STRIDE_DAYS)):
        # Half-open window [ws, we).
        df_slice = df[(df['timestamp'] >= ws) & (df['timestamp'] < we)]
        H_win = get_windowed_graph_fast(df_slice, G)
        if H_win.number_of_edges() == 0:
            continue
        # derive_node_labels returns the positive-node count (printed below).
        pos_nodes_win = derive_node_labels(H_win)
        pos_e = sum(1 for *_ , d in H_win.edges(keys=True, data=True) if int(d.get('is_laundering',0))==1)
        print(f"[{i:03d}] {ws:%Y-%m-%d} → {we:%Y-%m-%d}: nodes={H_win.number_of_nodes():,}, edges={H_win.number_of_edges():,}, pos_edges={pos_e:,}, pos_nodes={pos_nodes_win:,}")
        # NOTE(review): the cap is compared with the window index i, which also
        # counts skipped empty windows; the later per-window loops count only
        # processed (non-empty) windows.
        if MAX_WINDOWS_PER_SETTING is not None and i + 1 >= MAX_WINDOWS_PER_SETTING:
            break

# -----------------------
# Full-period community baseline and seeded PR (parity with earlier prints)
# -----------------------
# Louvain on the undirected full-period aggregate; the top 5 communities by the
# unsupervised heuristic are printed.
H_full = aggregate_graph(G, directed=False)
print(f"\nEnhanced community baseline on full period:")
print(f"Aggregated graph: {H_full.number_of_nodes():,} nodes, {H_full.number_of_edges():,} edges")
partition_full, comms_full = run_louvain(H_full, resolution=LOUVAIN_RESOLUTION, seed=LOUVAIN_SEED, weight='w_amount_log')
comm_scores = score_communities_unsupervised(H_full, comms_full)
print("Unsupervised community scoring (top 5 by heuristic score):")
sorted_comms = sorted(comm_scores.items(), key=lambda x: x[1], reverse=True)
for cid, score in sorted_comms[:5]:
    # comm_scores is keyed by index into comms_full, so this guard is always true.
    size = len(comms_full[cid]) if cid < len(comms_full) else 0
    print(f"  cid={cid:>4}  score={score:.3f}  size={size:>6}")

# Seeded personalized PageRank on the full period. Seeds are the first 20%
# (sorted by node id) of the full-period positive nodes, and the remaining nodes
# are evaluated. Seeds and evaluation labels come from the same period, so this
# number is optimistic; the time-based variant below avoids that leakage.
H_full_directed = aggregate_graph(G, directed=True)
pos_nodes_set = {n for n, d in H_full_directed.nodes(data=True) if int(d.get('is_laundering_involved', 0)) == 1}
if len(pos_nodes_set) > 0:
    seed_size = max(1, len(pos_nodes_set) // 5)
    seed_nodes = set(sorted(pos_nodes_set)[:seed_size])
    # Compute seeded PR-AUC on full period (non-leakage variant is below for windows)
    pr_scores_full = get_seeded_pagerank_scores(H_full_directed, seed_nodes, weight='w_amount_log', alpha=0.9)
    nodes_full = list(H_full_directed.nodes())
    # Seed nodes are excluded from evaluation.
    y_true_full = [int(H_full_directed.nodes[n].get('is_laundering_involved', 0)) for n in nodes_full if n not in seed_nodes]
    y_score_full = [pr_scores_full.get(n, 0.0) for n in nodes_full if n not in seed_nodes]
    if SKLEARN_OK and len(set(y_true_full)) > 1:
        pr_auc_full = average_precision_score(y_true_full, y_score_full)
        print(f"PersonalizedPageRank baseline PR-AUC: {pr_auc_full:.4f}")
        print(f"(Used {len(seed_nodes)} seed nodes out of {len(pos_nodes_set)} total positive nodes)")

# -----------------------
# Build fixed time-based seeds (no same-window leakage)
# -----------------------
# Seeds are the nodes labelled positive using only transactions in [tmin, T),
# where T lies SEED_CUTOFF_FRAC of the way through the timeline. The per-window
# seeded PageRank only scores windows that start at or after T (ws >= T), so
# seed labels never come from the window being evaluated.
if tmin is None or tmax is None:
    raise RuntimeError("Time range unavailable; cannot build seeds.")

T = tmin + (tmax - tmin) * SEED_CUTOFF_FRAC
df_seed = df[(df['timestamp'] >= tmin) & (df['timestamp'] < T)]
H_seed = get_windowed_graph_fast(df_seed, G)
# Labels nodes of H_seed (the aggregate below reads 'is_laundering_involved').
derive_node_labels(H_seed)
H_seed_dir = aggregate_graph(H_seed, directed=True)
seed_nodes_global = {n for n, d in H_seed_dir.nodes(data=True) if int(d.get('is_laundering_involved', 0)) == 1}
print(f"Global seeds cutoff T={T} | seed_nodes={len(seed_nodes_global)}")

# -----------------------
# Per-window analysis: printed summary + full per-window metrics -> CSV
# -----------------------
# Each window graph is built, aggregated, scored and partitioned exactly once.
# Both the console summary and the CSV rows are derived from that single pass,
# so the expensive per-window work (aggregation, PageRank/HITS, Louvain) is not
# repeated.


def _pct_label(kf):
    """Column suffix for a k fraction, e.g. 0.005 -> '0.5pct', 0.01 -> '1.0pct'.

    Same formula as the 'p_at_*' keys produced by eval_scores, so the p_at_*
    and attcov_at_* columns line up for every k and every method.
    """
    return f"{int(kf*1000)/10:.1f}pct"


print("\nPer-window enhanced analysis (first few windows per setting):")
rows = []
for window_days in WINDOW_DAYS_LIST:
    print(f"\n-- {window_days}-day windows --")
    count = 0
    for ws, we in iter_windows(tmin, tmax, window_days=window_days, stride_days=WINDOW_STRIDE_DAYS):
        df_slice = df[(df['timestamp'] >= ws) & (df['timestamp'] < we)]
        H_win = get_windowed_graph_fast(df_slice, G)
        if H_win.number_of_edges() == 0:
            continue
        derive_node_labels(H_win)
        H_agg = aggregate_graph(H_win, directed=False)
        H_agg_dir = aggregate_graph(H_win, directed=True)
        nodes = list(H_agg_dir.nodes())
        y_true_dict = {n: int(H_agg_dir.nodes[n].get('is_laundering_involved', 0)) for n in nodes}
        att_nodes_map = get_attempt_nodes_map(H_win)

        # Centrality baselines (printed before the seeded method is merged in).
        score_dict = run_centrality_baselines(H_agg_dir)
        results = eval_scores(nodes, y_true_dict, score_dict, k_fracs=K_FRACS)
        print("  Centrality baselines:", pretty_metrics(results))

        # Seeded PageRank with the fixed time-based seeds. It is only computed
        # for windows starting at or after the seed cutoff T, and seed nodes are
        # excluded from evaluation.
        pr_auc_win = None
        if ws >= T and seed_nodes_global:
            seeded_scores = get_seeded_pagerank_scores(H_agg_dir, seed_nodes_global, weight='w_amount_log', alpha=0.9)
            seeded_res = eval_scores(nodes, y_true_dict, {'seeded_pr': seeded_scores}, k_fracs=K_FRACS, exclude_nodes=seed_nodes_global)
            results.update(seeded_res)
            pr_auc_win = seeded_res['seeded_pr']['ap']

        # Unsupervised communities.
        _, comms_win = run_louvain(H_agg, resolution=LOUVAIN_RESOLUTION, seed=LOUVAIN_SEED, weight='w_amount_log')
        comm_scores_win = score_communities_unsupervised(H_agg, comms_win)
        avg_comm_score = np.mean(list(comm_scores_win.values())) if comm_scores_win else 0

        print(f"[{ws:%Y-%m-%d} → {we:%Y-%m-%d}] nodes={H_agg.number_of_nodes():,}, edges={H_agg.number_of_edges():,}")
        print(f"  Avg community score: {avg_comm_score:.4f}")
        if pr_auc_win is not None:
            print(f"  PersonalizedPageRank PR-AUC: {pr_auc_win:.4f}")

        # For each k, take whole communities in descending heuristic score until
        # at least k% of the window's nodes are selected.
        # NOTE: the last community taken can overshoot the k% budget by a wide
        # margin, so communities_unsup coverage is not a strict top-k%
        # comparison with the node-ranking methods.
        comm_ranked_nodes_cache = {}
        if comm_scores_win:
            comm_order = sorted(comm_scores_win.items(), key=lambda x: x[1], reverse=True)
            total_nodes = len(H_agg)
            for kf in K_FRACS:
                target = max(1, int(total_nodes * kf))
                selected = set()
                for cid, _score in comm_order:
                    selected |= set(comms_win[cid])
                    if len(selected) >= target:
                        break
                comm_ranked_nodes_cache[kf] = list(selected)

        base = {
            'window_days': window_days, 'ws': ws, 'we': we,
            'nodes': H_agg_dir.number_of_nodes(), 'edges': H_agg_dir.number_of_edges(),
            'pos_nodes': int(sum(y_true_dict.values()))
        }
        for method, m in results.items():
            row = dict(base); row['method'] = method; row['ap'] = m.get('ap', None)

            # Evaluation population (smaller than the full one for seeded methods).
            eval_nodes_count = m.get('_eval_nodes', len(nodes))
            eval_pos_count = m.get('_eval_pos', base['pos_nodes'])
            row['eval_nodes'] = eval_nodes_count
            row['eval_pos_nodes'] = eval_pos_count
            row['prevalence_eval'] = (eval_pos_count / eval_nodes_count) if eval_nodes_count > 0 else np.nan

            ranked_nodes = m.get('_ranked_nodes', [])
            for kf in K_FRACS:
                label = _pct_label(kf)
                row[f"p_at_{label}"] = m.get(f"p_at_{label}", None)
                # Same suffix as p_at_* (e.g. attcov_at_0.5pct, attcov_at_1.0pct),
                # which is also what the communities rows and the summary use.
                row[f"attcov_at_{label}"] = attempt_coverage(ranked_nodes, att_nodes_map, k_frac=kf)
            rows.append(row)

        if comm_ranked_nodes_cache:
            row = dict(base); row['method'] = 'communities_unsup'; row['ap'] = None
            row['eval_nodes'] = base['nodes']  # Communities use full population
            row['eval_pos_nodes'] = base['pos_nodes']
            row['prevalence_eval'] = base['pos_nodes'] / base['nodes'] if base['nodes'] > 0 else np.nan
            for kf in K_FRACS:
                label = _pct_label(kf)
                row[f"p_at_{label}"] = None
                # The cached list already is the selection, so cover all of it.
                row[f"attcov_at_{label}"] = attempt_coverage(comm_ranked_nodes_cache[kf], att_nodes_map, k_frac=1.0)
            rows.append(row)

        count += 1
        if MAX_WINDOWS_PER_SETTING is not None and count >= MAX_WINDOWS_PER_SETTING:
            break

df_metrics = pd.DataFrame(rows)

def add_random_baseline(dfm: pd.DataFrame) -> pd.DataFrame:
    """Append one 'random' ranking row per window to a per-window metrics table.

    A uniformly random ranking over a window's full node population has
    expected AP and precision@k equal to that population's prevalence
    (pos_nodes / nodes). The random row therefore uses it for 'ap', every
    'p_at_*' column and 'prevalence_eval', with eval_nodes/eval_pos_nodes set to
    the full population. It is computed from the window totals, so it does not
    depend on which method's row happens to come first in the window. Every
    'attcov_at_*' column is NaN: there is no closed form here, and NaN (not
    None) keeps the columns numeric for groupby medians. Other window-level
    columns (nodes, edges, ...) are copied from the window's first row.

    Args:
        dfm: metrics table with at least window_days, ws, we, nodes and
            pos_nodes columns.

    Returns:
        A new DataFrame: ``dfm`` followed by the random rows, with ``dfm``'s
        columns. An empty ``dfm`` is returned as a copy.
    """
    if dfm.empty:
        return dfm.copy()
    cols = list(dfm.columns)
    p_cols = [c for c in cols if c.startswith('p_at_')]
    cov_cols = [c for c in cols if c.startswith('attcov_at_')]
    rand_rows = []
    for _, first in dfm.groupby(['window_days', 'ws', 'we']).head(1).iterrows():
        row = {c: first.get(c, None) for c in cols}
        row['method'] = 'random'
        n_nodes = row.get('nodes', 0)
        n_pos = row.get('pos_nodes', 0)
        prevalence = (n_pos / n_nodes) if n_nodes else np.nan
        row['ap'] = prevalence  # AP of a random ranking ≈ prevalence
        row['prevalence_eval'] = prevalence
        row['eval_nodes'] = n_nodes
        row['eval_pos_nodes'] = n_pos
        for c in p_cols:
            row[c] = prevalence
        for c in cov_cols:
            row[c] = np.nan
        rand_rows.append(row)
    # columns=cols keeps exactly dfm's columns (extra keys above are dropped).
    rand_df = pd.DataFrame(rand_rows, columns=cols)
    return pd.concat([dfm, rand_df], ignore_index=True)

if df_metrics.empty:
    # No window produced rows. The random baseline and the column-based steps
    # below would raise KeyError on a column-less frame.
    print("WARNING: no per-window metrics were produced; skipping random baseline, lift and validation checks.")
else:
    df_metrics = add_random_baseline(df_metrics)

    # Compute prevalence and lift metrics.
    # A window without positives has prevalence 0, where lift is undefined.
    # Dividing by NaN there (-> NaN, ignored by the medians) avoids inf or 0/0.
    df_metrics['prevalence'] = df_metrics['pos_nodes'] / df_metrics['nodes']
    prevalence_eval = pd.to_numeric(df_metrics['prevalence_eval'], errors='coerce')
    prevalence_nz = df_metrics['prevalence'].where(df_metrics['prevalence'] > 0)
    prevalence_eval_nz = prevalence_eval.where(prevalence_eval > 0)
    p_cols = [c for c in df_metrics.columns if c.startswith('p_at_')]
    for col in p_cols:
        df_metrics[f'lift_{col}'] = df_metrics[col] / prevalence_nz
        df_metrics[f'lift_eval_{col}'] = df_metrics[col] / prevalence_eval_nz

    # -----------------------
    # Validation checks
    # -----------------------
    # Check that within each window, nodes and pos_nodes are identical across methods
    chk = (df_metrics.groupby(['window_days','ws','we'])
           .agg(nodes_nunique=('nodes','nunique'),
                pos_nodes_nunique=('pos_nodes','nunique'))
           .reset_index())
    bad = chk[(chk.nodes_nunique != 1) | (chk.pos_nodes_nunique != 1)]
    if not bad.empty:
        print("WARNING: nodes/pos_nodes inconsistent across methods:")
        print(bad.to_string(index=False))

    # Sanity check for random baseline: lift should be ≈ 1
    random_rows = df_metrics[df_metrics.method == 'random']
    if not random_rows.empty and 'lift_p_at_1.0pct' in random_rows.columns:
        random_lift_median = random_rows['lift_p_at_1.0pct'].median()
        if abs(random_lift_median - 1.0) > 0.05:
            print(f"WARNING: Random baseline lift_p_at_1.0pct median = {random_lift_median:.3f}, expected ≈ 1.0")

    # Seeded PR sanity: prevalence_eval should be reasonable
    seeded_rows = df_metrics[df_metrics.method == 'seeded_pr']
    if not seeded_rows.empty:
        high_prev = seeded_rows[seeded_rows.prevalence_eval > 0.5]
        if not high_prev.empty:
            print(f"WARNING: {len(high_prev)} seeded_pr rows have prevalence_eval > 0.5 (potentially degenerate)")

df_metrics.to_csv(RESULTS_CSV, index=False)
print(f"\nSaved per-window metrics to {RESULTS_CSV}")

if not df_metrics.empty:
    summary = (df_metrics
               .groupby(['window_days', 'method'])
               .agg(ap_median=('ap','median'),
                        p01_median=('p_at_1.0pct','median'),
                        lift_p01_median=('lift_p_at_1.0pct','median'),
                        lift_eval_p01_median=('lift_eval_p_at_1.0pct','median'),
                        attcov01_median=('attcov_at_1.0pct','median'),
                        prevalence_median=('prevalence','median'),
                        windows=('ws','count'))
               .reset_index()
               .sort_values(['window_days', 'ap_median'], ascending=[True, False]))
    print("\nSummary (median across windows):")
    print(summary.to_string(index=False))

G_all_multi: 93,101 nodes, 201,645 edges, positive_edges=1,663
Positive nodes (full period): 2,464
Saved G_all_multi to /content/drive/MyDrive/AML/processed/small/US_Dollar/G_all_multi.gpickle
Time range: 2022-09-01 00:00:00 → 2022-09-18 16:18:00

-- 3-day windows, stride=1d --
[000] 2022-09-01 → 2022-09-04: nodes=72,767, edges=73,646, pos_edges=332, pos_nodes=581
[001] 2022-09-02 → 2022-09-05: nodes=59,716, edges=54,448, pos_edges=382, pos_nodes=646
[002] 2022-09-03 → 2022-09-06: nodes=38,434, edges=39,451, pos_edges=415, pos_nodes=671
[003] 2022-09-04 → 2022-09-07: nodes=38,067, edges=48,316, pos_edges=464, pos_nodes=743
[004] 2022-09-05 → 2022-09-08: nodes=38,102, edges=57,661, pos_edges=463, pos_nodes=748
[005] 2022-09-06 → 2022-09-09: nodes=38,146, edges=57,613, pos_edges=473, pos_nodes=780
[006] 2022-09-07 → 2022-09-10: nodes=55,428, edges=69,010, pos_edges=453, pos_nodes=738
[007] 2022-09-08 → 2022-09-11: nodes=55,983, edges=60,367, pos_edges=487, pos_nodes=800
[008] 2022-09-09 

KeyboardInterrupt: 