# Phase 3: Graph Construction and Preprocessing

This notebook implements the graph construction phase for the AML Multi-GNN project.

## Objectives:
1. Load processed data from Phase 2
2. Construct transaction graphs with multiple views
3. Create node and edge features
4. Implement graph preprocessing and validation
5. Save graph data for model training

## Graph Views:
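- **Transaction View**: Directed account-to-account edges carrying the attributes of the transaction between them
- **Account View**: Undirected account connections weighted by transaction count, with aggregated per-account features
- **Temporal View**: Directed links between accounts that appear in consecutive transactions (ordered by timestamp)
- **Amount View**: Undirected account connections weighted by the total amount received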


In [None]:
# Phase 3: Graph Construction and Preprocessing
# Setup cell: prints the phase banner, imports the libraries used by the
# later cells, and configures warning and plotting defaults.
print("=" * 60)
print("AML Multi-GNN - Phase 3: Graph Construction")
print("=" * 60)

# Import required libraries
import pandas as pd
import numpy as np
import networkx as nx
import torch
import torch_geometric
from torch_geometric.data import Data, HeteroData
from torch_geometric.utils import to_networkx, from_networkx
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
import json
import os
import gc
from datetime import datetime
import warnings
# NOTE: this silences every warning category for the rest of the session,
# including deprecation notices from the libraries imported above.
warnings.filterwarnings('ignore')

# Set up plotting
# NOTE(review): 'seaborn-v0_8' is presumably only available in recent
# matplotlib releases - confirm the installed version provides this style.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

print("✓ Libraries imported successfully")


In [None]:
# Load processed data from Phase 2
print("Loading processed data from Phase 2...")

# Every Phase 2 artifact lives in the same Drive folder.
processed_dir = "/content/drive/MyDrive/LaunDetection/data/processed"
unified_data_path = f"{processed_dir}/unified_data.csv"
structure_path = f"{processed_dir}/data_structure.json"
quality_path = f"{processed_dir}/quality_metrics.json"


def _read_json(path):
    """Return the parsed contents of the JSON file at *path*."""
    with open(path, 'r') as f:
        return json.load(f)


# Each artifact is loaded in its own try block so that a problem with an
# optional JSON file cannot discard a DataFrame that loaded successfully.

# Unified transaction table: the later cells only run when df is not None.
df = None
try:
    if os.path.exists(unified_data_path):
        df = pd.read_csv(unified_data_path)
        print(f"✓ Loaded unified data: {df.shape}")
    else:
        print("✗ Unified data not found. Please run Phase 2 first.")
except Exception as e:
    print(f"✗ Error loading processed data: {e}")
    df = None

# Dataset structure description: falls back to a default when the file is
# absent, and to None when the file exists but cannot be read.
data_structure = None
try:
    if os.path.exists(structure_path):
        data_structure = _read_json(structure_path)
        print(f"✓ Loaded data structure: {data_structure}")
    else:
        print("⚠️  Data structure not found, using default structure")
        data_structure = {
            'transactions': 'HI-Small_Trans',
            'accounts': 'HI-Small_accounts',
            'labels': None
        }
except Exception as e:
    print(f"✗ Error loading data structure: {e}")
    data_structure = None

# Quality metrics: optional, None when missing or unreadable.
quality_metrics = None
try:
    if os.path.exists(quality_path):
        quality_metrics = _read_json(quality_path)
        print(f"✓ Loaded quality metrics")
    else:
        print("⚠️  Quality metrics not found")
except Exception as e:
    print(f"✗ Error loading quality metrics: {e}")
    quality_metrics = None


In [None]:
# Data validation and preprocessing
# Reports basic facts about df, fills missing values, parses timestamps and
# prints the share of laundering transactions. Rebinds df when values are
# filled; otherwise df is only modified through the 'Timestamp' column.
if df is not None:
    print("Validating and preprocessing data...")

    # Display basic info
    print(f"Dataset shape: {df.shape}")
    print(f"Columns: {list(df.columns)}")
    print(f"Data types:\n{df.dtypes}")

    # Check for required columns. The list covers every column read by the
    # graph construction functions, not just the identifying ones, so a
    # missing column is reported here instead of as a KeyError later.
    required_cols = [
        'Timestamp', 'From Bank', 'Account', 'To Bank', 'Account.1',
        'Amount Received', 'Receiving Currency', 'Amount Paid',
        'Payment Currency', 'Payment Format', 'Is Laundering',
    ]
    missing_cols = [col for col in required_cols if col not in df.columns]

    if missing_cols:
        print(f"⚠️  Missing required columns: {missing_cols}")
    else:
        print("✓ All required columns present")

    # Handle missing values
    missing_per_column = df.isnull().sum()
    if missing_per_column.sum() > 0:
        print(f"Missing values found:\n{missing_per_column}")
        # Forward-fill, then back-fill whatever is still missing at the top.
        # ffill()/bfill() replace the deprecated fillna(method=...) form.
        df = df.ffill().bfill()
        print("✓ Missing values filled")
    else:
        print("✓ No missing values")

    # Convert timestamp to datetime
    if 'Timestamp' in df.columns:
        df['Timestamp'] = pd.to_datetime(df['Timestamp'])
        print("✓ Timestamp converted to datetime")

    # Check for laundering cases
    if 'Is Laundering' in df.columns:
        laundering_count = df['Is Laundering'].sum()
        total_count = len(df)
        if total_count > 0:
            print(f"Laundering cases: {laundering_count}/{total_count} ({laundering_count/total_count*100:.1f}%)")
        else:
            # Avoid dividing by zero when the table has no rows.
            print(f"Laundering cases: {laundering_count}/{total_count}")

        if laundering_count == 0:
            print("⚠️  No laundering cases in current sample - this is expected for exploration")

    print("✓ Data validation complete")
else:
    print("✗ Cannot proceed without data")


In [None]:
# Graph Construction Functions

def create_transaction_graph(df):
    """
    Create a transaction-based graph where:
    - Nodes represent accounts
    - Edges represent transactions
    - Edge features include transaction details

    The graph is a ``nx.DiGraph``, which holds at most one edge per ordered
    (sender, receiver) pair. When several transactions share the same pair,
    the edge keeps the attributes of the last such row in ``df``; the
    returned ``edge_features`` list still contains one entry per row.

    Args:
        df: Transaction table with the columns 'Account', 'Account.1',
            'Amount Received', 'Amount Paid', 'Receiving Currency',
            'Payment Currency', 'Payment Format', 'Timestamp',
            'Is Laundering', 'From Bank' and 'To Bank'.

    Returns:
        Tuple ``(G, edge_features)``: the directed graph and a list with one
        feature dict per transaction row, in row order.

    Raises:
        KeyError: If any required column is missing from ``df``.
    """
    print("Creating transaction graph...")

    # Source column -> edge attribute name, in the order the attributes are
    # stored on each edge.
    column_to_attr = {
        'Amount Received': 'amount_received',
        'Amount Paid': 'amount_paid',
        'Receiving Currency': 'receiving_currency',
        'Payment Currency': 'payment_currency',
        'Payment Format': 'payment_format',
        'Timestamp': 'timestamp',
        'Is Laundering': 'is_laundering',
        'From Bank': 'from_bank',
        'To Bank': 'to_bank',
    }
    needed_columns = ['Account', 'Account.1', *column_to_attr]
    missing = [col for col in needed_columns if col not in df.columns]
    if missing:
        # Fail before building anything, with the full list of absent columns.
        raise KeyError(f"Transaction graph requires missing columns: {missing}")

    # Create NetworkX graph
    G = nx.DiGraph()

    # Add nodes (accounts). The node order comes from this exact set
    # expression, which the other graph builders in this file also use.
    all_accounts = list(set(df['Account'].unique()) | set(df['Account.1'].unique()))
    G.add_nodes_from(all_accounts, node_type='account')

    print(f"Added {len(all_accounts)} account nodes")

    # Add edges (transactions). Plain tuples avoid building one Series per
    # row, which is what makes iterrows slow on large tables.
    attr_names = list(column_to_attr.values())
    edge_features = []

    for from_acc, to_acc, *values in df[needed_columns].itertuples(index=False, name=None):
        edge_feat = dict(zip(attr_names, values))

        # Re-adding an existing edge overwrites its attributes (last row wins).
        G.add_edge(from_acc, to_acc, **edge_feat)
        edge_features.append(edge_feat)

    # Report both numbers: the edge count is lower than the row count
    # whenever an ordered account pair transacts more than once.
    print(f"Added {G.number_of_edges()} transaction edges from {len(edge_features)} transactions")

    return G, edge_features

def create_account_graph(df):
    """
    Create an account-based graph where:
    - Nodes represent accounts with aggregated features
    - Edges represent account relationships

    Outgoing statistics are computed from 'Amount Paid' grouped by the
    sending account ('Account'); incoming statistics are computed from
    'Amount Received' grouped by the receiving account ('Account.1').

    Args:
        df: Transaction table with the columns 'Account', 'Account.1',
            'Amount Paid', 'Amount Received', 'From Bank', 'To Bank' and
            'Is Laundering'.

    Returns:
        Tuple ``(G, account_features)``: an undirected graph whose edge
        attribute ``weight`` counts the transactions between two accounts,
        and a dict mapping each account to its feature dict (keys
        'total_sent', 'total_received', 'avg_sent', 'avg_received',
        'transaction_count', 'laundering_count', 'primary_bank').
    """
    print("Creating account graph...")

    # Outgoing side: what each account paid as a sender.
    sent_stats = df.groupby('Account').agg({
        'Amount Paid': ['sum', 'mean', 'count'],
        'From Bank': 'first',
        'Is Laundering': 'sum'
    }).round(2)

    # Incoming side: what each account received as a recipient.
    received_stats = df.groupby('Account.1').agg({
        'Amount Received': ['sum', 'mean', 'count'],
        'To Bank': 'first',
        'Is Laundering': 'sum'
    }).round(2)

    # One dict lookup per account instead of repeated .loc indexing.
    # Each value is (sum, mean, count, bank, laundering_sum), matching the
    # column order produced by the aggregations above.
    sent_lookup = dict(zip(sent_stats.index, sent_stats.itertuples(index=False, name=None)))
    received_lookup = dict(zip(received_stats.index, received_stats.itertuples(index=False, name=None)))

    # Node order comes from this exact set expression, which the other graph
    # builders in this file also use.
    all_accounts = list(set(df['Account'].unique()) | set(df['Account.1'].unique()))

    account_features = {}
    for account in all_accounts:
        features = {
            'total_sent': 0,
            'total_received': 0,
            'avg_sent': 0,
            'avg_received': 0,
            'transaction_count': 0,
            'laundering_count': 0,
            'primary_bank': None
        }

        outgoing = sent_lookup.get(account)
        if outgoing is not None:
            total, mean, count, bank, flagged = outgoing
            features['total_sent'] = total
            features['avg_sent'] = mean
            features['transaction_count'] += count
            features['laundering_count'] += flagged
            features['primary_bank'] = bank

        incoming = received_lookup.get(account)
        if incoming is not None:
            total, mean, count, bank, flagged = incoming
            features['total_received'] = total
            features['avg_received'] = mean
            features['transaction_count'] += count
            features['laundering_count'] += flagged
            # The receiving bank is only used for accounts that never send.
            if features['primary_bank'] is None:
                features['primary_bank'] = bank

        account_features[account] = features

    # Create NetworkX graph
    G = nx.Graph()

    # Add nodes with features
    for account, features in account_features.items():
        G.add_node(account, **features)

    # Add edges based on transaction relationships; weight is the number of
    # transactions between the two accounts, regardless of direction.
    for from_acc, to_acc in zip(df['Account'], df['Account.1']):
        if G.has_edge(from_acc, to_acc):
            G[from_acc][to_acc]['weight'] += 1
        else:
            G.add_edge(from_acc, to_acc, weight=1)

    print(f"Added {len(account_features)} account nodes with features")
    print(f"Added {G.number_of_edges()} account relationships")

    return G, account_features

def create_temporal_graph(df):
    """
    Build a directed graph linking accounts of consecutive transactions.

    Rows are ordered by 'Timestamp'. For every pair of neighbouring rows,
    each account of the earlier row (sender and receiver) is connected to
    each different account of the later row. The edge attribute
    ``temporal_weight`` counts how often that connection occurred.

    Args:
        df: Transaction table with 'Timestamp', 'Account' and 'Account.1'.

    Returns:
        The directed temporal graph (``nx.DiGraph``).
    """
    print("Creating temporal graph...")

    chronological = df.sort_values('Timestamp')

    G = nx.DiGraph()

    # Every account becomes a node, even if it ends up without edges.
    account_ids = list(set(df['Account'].unique()) | set(df['Account.1'].unique()))
    G.add_nodes_from(account_ids, node_type='account')

    # (sender, receiver) of each transaction in chronological order.
    endpoints = list(zip(chronological['Account'].tolist(),
                         chronological['Account.1'].tolist()))

    # Walk neighbouring transactions pairwise.
    for earlier, later in zip(endpoints, endpoints[1:]):
        for curr_acc in earlier:
            for next_acc in later:
                if curr_acc == next_acc:
                    continue  # no self-links
                if G.has_edge(curr_acc, next_acc):
                    G[curr_acc][next_acc]['temporal_weight'] += 1
                else:
                    G.add_edge(curr_acc, next_acc, temporal_weight=1)

    print(f"Added {G.number_of_edges()} temporal connections")

    return G

def create_amount_graph(df):
    """
    Build an undirected graph whose edges accumulate transferred amounts.

    Each transaction adds its 'Amount Received' to the ``amount_weight``
    attribute of the edge between its two accounts.

    Args:
        df: Transaction table with 'Account', 'Account.1' and
            'Amount Received'.

    Returns:
        The undirected amount graph (``nx.Graph``).
    """
    print("Creating amount-based graph...")

    G = nx.Graph()

    # Every account becomes a node.
    account_ids = list(set(df['Account'].unique()) | set(df['Account.1'].unique()))
    G.add_nodes_from(account_ids, node_type='account')

    # Accumulate the received amount per account pair.
    transfers = zip(df['Account'], df['Account.1'], df['Amount Received'])
    for sender, receiver, amount in transfers:
        if not G.has_edge(sender, receiver):
            G.add_edge(sender, receiver, amount_weight=amount)
        else:
            G[sender][receiver]['amount_weight'] += amount

    print(f"Added {G.number_of_edges()} amount-based connections")

    return G

print("✓ Graph construction functions defined")


In [None]:
# Build all graph views
# Runs the four graph builders on df and collects the results in `graphs`
# (None when no data is loaded). Also keeps edge_features and
# account_features, which the conversion cell reads.
if df is None:
    print("✗ Cannot build graphs without data")
    graphs = None
else:
    print("Building multiple graph views...")

    # Filled view by view so the mapping reflects what has been built so far.
    graphs = {}

    # View 1: directed transaction graph plus per-transaction edge features.
    print("\n1. Transaction Graph:")
    trans_graph, edge_features = create_transaction_graph(df)
    graphs['transaction'] = trans_graph

    # View 2: undirected account graph plus aggregated account features.
    print("\n2. Account Graph:")
    account_graph, account_features = create_account_graph(df)
    graphs['account'] = account_graph

    # View 3: directed temporal graph.
    print("\n3. Temporal Graph:")
    temporal_graph = create_temporal_graph(df)
    graphs['temporal'] = temporal_graph

    # View 4: undirected amount graph.
    print("\n4. Amount Graph:")
    amount_graph = create_amount_graph(df)
    graphs['amount'] = amount_graph

    print("\n✓ All graph views created successfully")

    # One summary line per view.
    print("\nGraph Statistics:")
    for view_name, view in graphs.items():
        node_total = view.number_of_nodes()
        edge_total = view.number_of_edges()
        print(f"{view_name.capitalize()} Graph: {node_total} nodes, {edge_total} edges")


In [None]:
# Convert to PyTorch Geometric format
# Turns every NetworkX view into a PyG data object, attaches node features
# (account view), edge features (transaction view) and binary node labels,
# and stores the results in `pytorch_graphs` (None for a failed view).
if graphs is not None:
    print("Converting graphs to PyTorch Geometric format...")

    pytorch_graphs = {}

    # Accounts that appear, as sender or receiver, in at least one flagged
    # transaction. Built on first use and shared by all views, so the
    # DataFrame is scanned once instead of once per node per view.
    laundering_accounts = None

    for name, graph in graphs.items():
        print(f"\nConverting {name} graph...")

        try:
            # Convert to PyTorch Geometric Data object
            # NOTE(review): the tensors added below follow graph.nodes() /
            # graph.edges() order; presumably from_networkx keeps the same
            # order - confirm against the installed torch_geometric version.
            pyg_data = from_networkx(graph)

            # Add node features if available
            if name == 'account' and account_features:
                # NOTE(review): 'laundering_count' is derived from the same
                # 'Is Laundering' column as the labels below, so it leaks the
                # target into the inputs. Kept to preserve the 7-column
                # layout; consider dropping it before training.
                node_features = []
                for node in graph.nodes():
                    features = account_features.get(node)
                    if features is not None:
                        node_features.append([
                            features['total_sent'],
                            features['total_received'],
                            features['avg_sent'],
                            features['avg_received'],
                            features['transaction_count'],
                            features['laundering_count'],
                            features['primary_bank'] or 0
                        ])
                    else:
                        # Default features for nodes without account features
                        node_features.append([0, 0, 0, 0, 0, 0, 0])

                pyg_data.x = torch.tensor(node_features, dtype=torch.float)
                print(f"✓ Added node features: {pyg_data.x.shape}")

            # Add edge features if available
            if name == 'transaction' and edge_features:
                # NOTE(review): 'is_laundering' is label information as well;
                # see the note on 'laundering_count' above.
                edge_features_tensor = [
                    [
                        edge_data.get('amount_received', 0),
                        edge_data.get('amount_paid', 0),
                        edge_data.get('is_laundering', 0),
                        edge_data.get('from_bank', 0),
                        edge_data.get('to_bank', 0)
                    ]
                    for _, _, edge_data in graph.edges(data=True)
                ]

                pyg_data.edge_attr = torch.tensor(edge_features_tensor, dtype=torch.float)
                print(f"✓ Added edge features: {pyg_data.edge_attr.shape}")

            # Add labels (laundering information)
            if 'Is Laundering' in df.columns:
                if laundering_accounts is None:
                    # Assumes 'Is Laundering' holds non-negative flags (0/1),
                    # so "sum over the account's rows > 0" is the same as
                    # "the account appears in a row with a positive flag".
                    flagged_rows = df[df['Is Laundering'] > 0]
                    laundering_accounts = set(flagged_rows['Account']) | set(flagged_rows['Account.1'])

                # An account is labelled 1 if it took part in any flagged
                # transaction, on either side.
                labels = [1 if node in laundering_accounts else 0 for node in graph.nodes()]
                pyg_data.y = torch.tensor(labels, dtype=torch.long)
                print(f"✓ Added labels: {pyg_data.y.shape}")

            pytorch_graphs[name] = pyg_data
            print(f"✓ {name} graph converted successfully")

        except Exception as e:
            # Best effort: a failing view is recorded as None and the
            # remaining views are still converted.
            print(f"✗ Error converting {name} graph: {e}")
            pytorch_graphs[name] = None

    print("\n✓ All graphs converted to PyTorch Geometric format")

    # Display PyTorch Geometric graph info
    print("\nPyTorch Geometric Graph Information:")
    for name, pyg_data in pytorch_graphs.items():
        if pyg_data is not None:
            print(f"\n{name.capitalize()} Graph:")
            print(f"  Nodes: {pyg_data.num_nodes}")
            print(f"  Edges: {pyg_data.num_edges}")
            if hasattr(pyg_data, 'x') and pyg_data.x is not None:
                print(f"  Node features: {pyg_data.x.shape}")
            if hasattr(pyg_data, 'edge_attr') and pyg_data.edge_attr is not None:
                print(f"  Edge features: {pyg_data.edge_attr.shape}")
            if hasattr(pyg_data, 'y') and pyg_data.y is not None:
                print(f"  Labels: {pyg_data.y.shape}")
                print(f"  Label distribution: {torch.bincount(pyg_data.y)}")
        else:
            print(f"\n{name.capitalize()} Graph: Failed to convert")

else:
    print("✗ Cannot convert graphs without data")
    pytorch_graphs = None


In [None]:
# Graph visualization and analysis
# Draws a 2x2 overview (at most 50 nodes per view) and prints size, density,
# feature-range and label statistics for every converted view.
if pytorch_graphs is not None:
    print("Analyzing graph properties...")

    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Graph Views Analysis', fontsize=16)

    graph_names = ['transaction', 'account', 'temporal', 'amount']

    for i, name in enumerate(graph_names):
        ax = axes[i//2, i%2]

        # A view is drawn only if it was converted successfully and its
        # NetworkX graph is still available. Compare against None explicitly:
        # an empty NetworkX graph is falsy.
        G = graphs.get(name) if pytorch_graphs.get(name) is not None else None

        if G is None:
            ax.text(0.5, 0.5, f'{name.capitalize()}\nGraph not available',
                   ha='center', va='center', transform=ax.transAxes)
            ax.set_title(f'{name.capitalize()} Graph')
            continue

        # Create subgraph for visualization (limit nodes for clarity)
        if G.number_of_nodes() > 50:
            # Keep the first 50 nodes in graph order
            nodes_to_keep = list(G.nodes())[:50]
            G_viz = G.subgraph(nodes_to_keep)
        else:
            G_viz = G

        # Plot graph
        pos = nx.spring_layout(G_viz, k=1, iterations=50)
        nx.draw(G_viz, pos, ax=ax, node_size=50, node_color='lightblue',
               edge_color='gray', arrows=True, arrowsize=10)

        ax.set_title(f'{name.capitalize()} Graph\n({G_viz.number_of_nodes()} nodes, {G_viz.number_of_edges()} edges)')

    plt.tight_layout()
    plt.show()

    # Graph statistics
    print("\nDetailed Graph Statistics:")
    for name, pyg_data in pytorch_graphs.items():
        if pyg_data is None:
            continue

        print(f"\n{name.capitalize()} Graph:")
        print(f"  Nodes: {pyg_data.num_nodes}")
        print(f"  Edges: {pyg_data.num_edges}")

        # Density is undefined for fewer than two nodes; report 0 instead of
        # dividing by zero.
        possible_edges = pyg_data.num_nodes * (pyg_data.num_nodes - 1)
        density = pyg_data.num_edges / possible_edges if possible_edges > 0 else 0.0
        print(f"  Density: {density:.4f}")

        if hasattr(pyg_data, 'x') and pyg_data.x is not None:
            print(f"  Node features: {pyg_data.x.shape}")
            # min()/max() raise on empty tensors
            if pyg_data.x.numel() > 0:
                print(f"  Node feature range: [{pyg_data.x.min():.2f}, {pyg_data.x.max():.2f}]")

        if hasattr(pyg_data, 'edge_attr') and pyg_data.edge_attr is not None:
            print(f"  Edge features: {pyg_data.edge_attr.shape}")
            if pyg_data.edge_attr.numel() > 0:
                print(f"  Edge feature range: [{pyg_data.edge_attr.min():.2f}, {pyg_data.edge_attr.max():.2f}]")

        if hasattr(pyg_data, 'y') and pyg_data.y is not None:
            print(f"  Labels: {pyg_data.y.shape}")
            label_counts = torch.bincount(pyg_data.y)
            print(f"  Label distribution: {dict(zip(range(len(label_counts)), label_counts.tolist()))}")

            if len(label_counts) > 1:
                negatives = label_counts[0].item()
                positives = label_counts[1].item()
                if negatives > 0:
                    print(f"  Class imbalance ratio: {positives / negatives:.4f}")
                else:
                    # Every node is positive: the ratio has no finite value.
                    print("  Class imbalance ratio: undefined (no negative samples)")
            elif len(label_counts) == 1:
                print(f"  Only one class present (all {label_counts[0].item()} samples)")
            else:
                # bincount of an empty label tensor has no entries.
                print("  No labelled nodes")

    print("\n✓ Graph analysis complete")

else:
    print("✗ Cannot analyze graphs without data")


In [None]:
# Save graph data for model training
# Writes one .pt file per converted view plus two JSON files (metadata and
# per-view statistics) into the graphs output folder. Any error is reported
# and stops the remaining writes.
if pytorch_graphs is None:
    print("✗ Cannot save graphs without data")
else:
    print("Saving graph data for model training...")

    try:
        # Make sure the target folder exists.
        output_dir = "/content/drive/MyDrive/LaunDetection/data/processed/graphs"
        os.makedirs(output_dir, exist_ok=True)

        # Only views that converted successfully are written.
        converted = {
            view_name: view_data
            for view_name, view_data in pytorch_graphs.items()
            if view_data is not None
        }

        # One file per view.
        for view_name, view_data in converted.items():
            graph_path = os.path.join(output_dir, f"{view_name}_graph.pt")
            torch.save(view_data, graph_path)
            print(f"✓ Saved {view_name} graph to {graph_path}")

        # Metadata describing this run.
        metadata = {
            'num_graphs': len(converted),
            'graph_names': list(pytorch_graphs.keys()),
            'creation_time': datetime.now().isoformat(),
            'data_shape': df.shape if df is not None else None,
            'data_structure': data_structure
        }

        metadata_path = os.path.join(output_dir, "graph_metadata.json")
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        print(f"✓ Saved graph metadata to {metadata_path}")

        # Per-view statistics.
        stats = {}
        for view_name, view_data in converted.items():
            node_x = getattr(view_data, 'x', None)
            edge_x = getattr(view_data, 'edge_attr', None)
            node_y = getattr(view_data, 'y', None)

            entry = {
                'num_nodes': view_data.num_nodes,
                'num_edges': view_data.num_edges,
                'has_node_features': node_x is not None,
                'has_edge_features': edge_x is not None,
                'has_labels': node_y is not None
            }

            # Shapes are only recorded for tensors that exist.
            if node_x is not None:
                entry['node_feature_shape'] = list(node_x.shape)

            if edge_x is not None:
                entry['edge_feature_shape'] = list(edge_x.shape)

            if node_y is not None:
                entry['label_shape'] = list(node_y.shape)
                entry['label_distribution'] = torch.bincount(node_y).tolist()

            stats[view_name] = entry

        stats_path = os.path.join(output_dir, "graph_statistics.json")
        with open(stats_path, 'w') as f:
            json.dump(stats, f, indent=2)
        print(f"✓ Saved graph statistics to {stats_path}")

        print("\n✓ All graph data saved successfully")

    except Exception as e:
        print(f"✗ Error saving graph data: {e}")


In [None]:
# Memory cleanup and final summary
# Triggers garbage collection, then prints either a recap of the created
# views or troubleshooting hints when no graphs were produced.
print("Cleaning up memory...")
gc.collect()

separator = "=" * 60
print("\n" + separator)
print("Phase 3 - Graph Construction Completed!")
print(separator)

if pytorch_graphs is None:
    print("\n✗ Graph construction failed - please check data loading")
    print("\nTroubleshooting:")
    print("1. Ensure Phase 2 completed successfully")
    print("2. Check that unified_data.csv exists in processed folder")
    print("3. Verify data structure and quality metrics")
else:
    # Views recorded as None failed to convert and are not counted.
    created_total = sum(1 for view in pytorch_graphs.values() if view is not None)

    print("\nSummary:")
    print(f"✓ Created {created_total} graph views")
    print(f"✓ All graphs converted to PyTorch Geometric format")
    print(f"✓ Graph data saved to /content/drive/MyDrive/LaunDetection/data/processed/graphs/")

    print("\nNext steps:")
    print("1. Review the graph construction results above")
    print("2. Proceed to Phase 4: Multi-GNN Architecture")
    print("3. Run: %run notebooks/03_multi_gnn_architecture.ipynb")

    print("\nGraph Views Created:")
    for view_name, view in pytorch_graphs.items():
        if view is None:
            print(f"  - {view_name.capitalize()}: Failed to create")
        else:
            print(f"  - {view_name.capitalize()}: {view.num_nodes} nodes, {view.num_edges} edges")

print("\n" + separator)
