In [None]:
#%pip install torch torchvision torchaudio

Collecting torch
  Downloading torch-2.7.0-cp312-cp312-win_amd64.whl.metadata (29 kB)
Collecting torchvision
  Downloading torchvision-0.22.0-cp312-cp312-win_amd64.whl.metadata (6.3 kB)
Collecting torchaudio
  Downloading torchaudio-2.7.0-cp312-cp312-win_amd64.whl.metadata (6.7 kB)
Collecting filelock (from torch)
  Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting typing-extensions>=4.10.0 (from torch)
  Downloading typing_extensions-4.13.2-py3-none-any.whl.metadata (3.0 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch)
  Downloading networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Downloading jinja2-3.1.6-py3-none-any.whl.metadata (2.9 kB)
Collecting fsspec (from torch)
  Downloading fsspec-2025.3.2-py3-none-any.whl.metadata (11 kB)
Collecting pillow!=8.3.*,>=5.3.0 (from torchvision)
  Downloading pillow-11.2.1-cp312-cp312-win_amd64.whl.metad

In [1]:
import torch
print(f"PyTorch: {torch.__version__}")

PyTorch: 2.7.0+cpu


In [2]:
import torch
from torch_geometric.data import Data   
import pandas as pd
import random
from itertools import combinations, islice

def create_edges(feature_column, df, max_edges_per_group=100):
    edge_list = []
    groups = df.groupby(feature_column).indices
    for _, indices in groups.items():
        n = len(indices)   
        if n < 2:
            continue
        
        pair_generator = combinations(indices, 2)
        limited_pairs = list(islice(pair_generator, max_edges_per_group))
        edge_list.extend(limited_pairs)

    if not edge_list:
        return torch.empty((2, 0), dtype=torch.int32)
    return torch.tensor(edge_list, dtype=torch.int32).t().contiguous()


X = pd.read_csv("reduced_features.csv") 
y = pd.read_csv("balanced_labels.csv").values

#create edges for relational features
edge_index = torch.empty((2, 0), dtype=torch.int32)
edge_features = ['card1', 'addr1', 'addr2', 'P_emaildomain', 'DeviceType', 'id_17', 'id_28']
for feature in edge_features:
    edges = create_edges(feature, X, max_edges_per_group=100)
    edge_index = torch.cat([edge_index, edges], dim=1)

x_node = X.drop(columns= edge_features)

data = Data(
    x = torch.tensor(x_node.values,
                    dtype=torch.float32),
                    edge_index = edge_index,
                    y = torch.tensor(y, dtype=torch.float32)
    )


In [3]:
from torch_geometric.nn import GCNConv
#from torch_geometric.transforms import RandomNodeSplit

#transform = RandomNodeSplit(split="train_rest", num_val=0.15, num_test=0.15)
#data = transform(data)

class FraudGNN(torch.nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, 64)
        self.conv2 = GCNConv(64, 32)
        self.classifier = torch.nn.Linear(32, 1)
            
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return torch.sigmoid(self.classifier(x))
        

In [None]:
from torch_geometric.nn import SAGEConv


class FraudGraphSAGE(torch.nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.conv1 = SAGEConv(input_dim, 64)
        self.conv2 = SAGEConv(64, 32)
        self.classifier = torch.nn.Linear(32, 1)
            
    def forward(self, data):
        x = data.x
        edge_index = data.edge_index.to(torch.long)
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return torch.sigmoid(self.classifier(x))

In [None]:
from torch_geometric.nn import GATConv

class FraudGraphGAT(torch.nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.conv1 = GATConv(input_dim, 64, heads=2, concat=True)
        self.conv2 = GCNConv(64*2, 32, heads=2, concat=True)
        self.classifier = torch.nn.Linear(32*2, 1)
            
    def forward(self, data):
        x = data.x
        edge_index = data.edge_index.to(torch.long)
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).relu()
        return torch.sigmoid(self.classifier(x))

In [8]:
from sklearn.model_selection import KFold
import pickle
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

def cross_validate(model_class, data, save_path, num_folds=5, num_epochs=200, lr=0.001, 
                    device=None, verbose=True):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    best_model_state = None
    best_model = None
    best_metric = float('-inf')


    node_indices = torch.arange(data.num_nodes)
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)

    results = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(node_indices)):
        if verbose:
            print(f"Fold {fold + 1}/{num_folds}")

        train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)

        train_mask[train_idx] = True
        val_mask[val_idx] = True

   

        model = model_class(input_dim=data.num_node_features).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        loss_fn = torch.nn.BCELoss()

        for epoch in range(num_epochs):
            model.train()
            optimizer.zero_grad()
            out = model(data.to(device))
            loss = loss_fn(out[train_mask], data.y[train_mask].to(device))
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            
            logits = model(data).squeeze()
            probs = logits[val_mask].cpu().numpy()
            preds = (probs > 0.5).astype(int)
            labels = data.y[val_mask].cpu().numpy().flatten()

            accuracy = (preds == labels).sum() / len(preds)
            precision = precision_score(labels, preds, zero_division=0)
            recall = recall_score(labels, preds, zero_division=0)
            f1 = f1_score(labels, preds, zero_division=0)
            roc_auc = roc_auc_score(labels, probs)

        if verbose:
            print(f"Accuracy: {accuracy}")
            print(f"Precision: {precision}")
            print(f"Recall: {recall}")
            print(f"f1: {f1}")
            print((f"ROC-AUC Score: {roc_auc}"))

        results.append({
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'roc_auc': roc_auc
        })

        if accuracy > best_metric:
            best_metric = accuracy
            best_model_state = model.state_dict()
            best_model = model

        if best_model_state is not None:
            torch.save(best_model_state, f"{save_path}.pt" )
            with open(f"{save_path}.pkl", "wb") as f:
                pickle.dump(best_model, f)


        metrics = {}
        for key in results[0].keys():
            values = [fold[key] for fold in results]
            metrics[f"mean_{key}"] = sum(values) / num_folds
            metrics[f"std_{key}"] = (sum((x - metrics[f"mean_{key}"])**2 for x in values) / num_folds) ** 0.5
   

   

    return metrics, results
        


In [None]:
print("GNC Graph:")
metrics, results = cross_validate(FraudGNN, data, "fraudgnn")
print("GNC Metrics :", metrics)


GNC Graph:
Fold 1/5
Accuracy: 0.9823076011949936
Precision: 0.9992646124235106
Recall: 0.9653385021531877
f1: 0.9820086276504151
ROC-AUC Score: 0.986480117980127
Fold 2/5
Accuracy: 0.9828559646590714
Precision: 0.9992461604134311
Recall: 0.9664016865035794
f1: 0.9825495204243843
ROC-AUC Score: 0.9859743962401298
Fold 3/5
Accuracy: 0.9827594526893938
Precision: 0.9992193315299287
Recall: 0.9662570224719101
f1: 0.9824617773850643
ROC-AUC Score: 0.9866839618411103
Fold 4/5
Accuracy: 0.9823777917183956
Precision: 0.9987860449525742
Recall: 0.9659779903970841
f1: 0.9821080996815358
ROC-AUC Score: 0.9875337678350189
Fold 5/5
Accuracy: 0.9831761351173502
Precision: 0.9988765673697181
Recall: 0.9674367996630485
f1: 0.9829053352292736
ROC-AUC Score: 0.9877217140024369
GNC Metrics : {'mean_accuracy': np.float64(0.9826953890758409), 'std_accuracy': np.float64(0.0003200790349043443), 'mean_precision': 0.9990785433378326, 'std_precision': 0.0002043957439799105, 'mean_recall': 0.966282400237762, 'st

In [9]:
print("GraphSAGE Graph:")
gs_metrics, gs_results = cross_validate(FraudGraphSAGE, data, "graphSage")
print("GraphSAGE Metrics :", gs_metrics)

GraphSAGE Graph:
Fold 1/5


RuntimeError: scatter(): Expected dtype int64 for index

In [None]:
print("GAT Graph:")
gat_metrics, gat_results = cross_validate(FraudGraphGAT, data, "GAT")
print("GAT Metrics :", gat_metrics)