In [4]:
import time
import logging
from util import create_parser, set_seed, logger_setup
from train_util import extract_param
from data_loading import get_data
from training import get_model
from inference import AddEgoIds, add_arange_ids, get_loaders
import json
import wandb 
from torch_geometric.nn import to_hetero
import onnxruntime
import torch

In [5]:
import sys
# Emulate a command-line launch so the project's argument parser can be
# reused unchanged inside the notebook.
cli_options = [
    "--data", "data",
    "--model", "gin",
    "--emlps",
    "--reverse_mp",
    "--ports",
    "--batch_size", "8",
]

# Element 0 stands in for the script name; the options follow it.
sys.argv = ["notebook", *cli_options]

# parse_args() is called without arguments, so it picks the options up from
# sys.argv (assuming create_parser() returns an argparse-style parser).
parser = create_parser()
args = parser.parse_args()


In [6]:


# Load the dataset configuration that get_data() receives below.
with open('data_config.json', 'r') as config_file:
    data_config = json.load(config_file)

# Report the run's unique name through both logging and stdout.
unique_name_msg = f"Unique name: {args.unique_name}"
logging.info(unique_name_msg)
print(unique_name_msg)

# Logging setup is intentionally left disabled in this notebook:
# logger_setup()

# Seed the RNGs from the parsed arguments before any data is loaded.
set_seed(args.seed)

# Fetch the train / validation / test graphs and their index sets.
print("Retrieving data")
# logging.info("Retrieving data")
t1 = time.perf_counter()  # start timestamp; not read again in this cell

(
    tr_data,
    val_data,
    te_data,
    tr_inds,
    val_inds,
    te_inds,
) = get_data(args, data_config)  # ports added here

Unique name: None
Retrieving data


In [12]:
# Build the three loaders from the splits returned by get_data().
# NOTE(review): the seventh positional argument is passed as None; confirm
# against get_loaders' signature what it disables.
tr_loader, val_loader, te_loader = get_loaders(
    tr_data, val_data, te_data,
    tr_inds, val_inds, te_inds,
    None,
    args,
)

# Pull a single training batch; it is used later to build the model and as
# the example input for the ONNX export.
tr_batches = iter(tr_loader)
sample_batch = next(tr_batches)


In [17]:

# Start a Weights & Biases run; it is disabled when args.testing is set.
# The resulting `config` object is what get_model() receives later on.
wandb.init(
    mode="disabled" if args.testing else "online",
    # NOTE(review): placeholder project name; set the real one before logging online.
    project="your_proj_name",

    config={
        "epochs": args.n_epochs,
        # Record the batch size the run was actually launched with instead of
        # a hard-coded 2, so the logged config matches the parsed arguments.
        "batch_size": args.batch_size,
        "model": args.model,
        "data": args.data,
        "num_neighbors": args.num_neighs,
        "lr": extract_param("lr", args),
        "n_hidden": extract_param("n_hidden", args),
        "n_gnn_layers": extract_param("n_gnn_layers", args),
        "loss": "ce",
        "w_ce1": extract_param("w_ce1", args),
        "w_ce2": extract_param("w_ce2", args),
        "dropout": extract_param("dropout", args),
        "final_dropout": extract_param("final_dropout", args),
        # n_heads is only looked up for the GAT model.
        "n_heads": extract_param("n_heads", args) if args.model == 'gat' else None
    })
config = wandb.config

In [18]:
import torch.nn as nn
class WrapperModel(nn.Module):
    """Expose a dict-based heterogeneous model through positional list inputs.

    The wrapped model is called as ``model(x_dict, edge_index_dict,
    edge_attr_dict)`` and must return a mapping indexed by edge type. This
    wrapper rebuilds those three dicts from ordered lists and returns only
    the entry for ``target_edge_type``, so the model can be driven with plain
    tensor lists (e.g. when exporting).

    Args:
        model: Module taking the three dicts described above.
        node_keys: Node-type keys, in the order ``x_list`` will be given.
        edge_keys: Edge-type keys, in the order ``ei_list`` and ``ea_list``
            will be given.
        target_edge_type: Key of the model output to return. Defaults to
            ``('node', 'to', 'node')``.
    """

    def __init__(self, model, node_keys, edge_keys,
                 target_edge_type=('node', 'to', 'node')):
        super().__init__()
        self.model = model
        # Copy the key sequences so later changes to the caller's lists
        # cannot silently alter the input ordering expected by forward().
        self.node_keys = list(node_keys)
        self.edge_keys = list(edge_keys)
        self.target_edge_type = tuple(target_edge_type)

    def forward(self, x_list, ei_list, ea_list):
        """Run the wrapped model and return its ``target_edge_type`` output.

        Args:
            x_list: One node-feature tensor per entry of ``node_keys``.
            ei_list: One edge-index tensor per entry of ``edge_keys``.
            ea_list: One edge-attribute tensor per entry of ``edge_keys``.

        Returns:
            ``model(...)[target_edge_type]``.

        Raises:
            ValueError: If a list's length does not match its key list.
                Without this check ``zip`` would silently drop the extra
                keys or tensors.
        """
        if len(x_list) != len(self.node_keys):
            raise ValueError(
                f"expected {len(self.node_keys)} node feature tensors, "
                f"got {len(x_list)}"
            )
        if len(ei_list) != len(self.edge_keys):
            raise ValueError(
                f"expected {len(self.edge_keys)} edge index tensors, "
                f"got {len(ei_list)}"
            )
        if len(ea_list) != len(self.edge_keys):
            raise ValueError(
                f"expected {len(self.edge_keys)} edge attribute tensors, "
                f"got {len(ea_list)}"
            )

        x = dict(zip(self.node_keys, x_list))
        ei = dict(zip(self.edge_keys, ei_list))
        ea = dict(zip(self.edge_keys, ea_list))
        out = self.model(x, ei, ea)
        return out[self.target_edge_type]

In [19]:
# Build the base model from the untouched sample batch, then convert it to a
# heterogeneous model using the test graph's metadata.
raw_model = to_hetero(
    get_model(sample_batch, config, args),
    te_data.metadata(),
    aggr='mean',
)

# AFTER GET MODEL: drop the first edge-attribute column on both edge types.
# NOTE(review): this rewrites sample_batch in place, so re-running this cell
# strips one more column each time; re-create sample_batch before re-running.
for edge_type in (('node', 'to', 'node'), ('node', 'rev_to', 'node')):
    sample_batch[edge_type].edge_attr = sample_batch[edge_type].edge_attr[:, 1:]

# Fix an explicit ordering of node and edge types; the tensor lists below
# follow exactly this order.
x_by_type = sample_batch.x_dict
ei_by_type = sample_batch.edge_index_dict
ea_by_type = sample_batch.edge_attr_dict

node_keys = list(x_by_type.keys())
edge_keys = list(ei_by_type.keys())

x_list = [x_by_type[k] for k in node_keys]
ei_list = [ei_by_type[k] for k in edge_keys]

ea_list = []
for k, ei in zip(edge_keys, ei_list):
    ea = ea_by_type.get(k, None)
    # Edge types without a tensor of edge features get a (num_edges x 0)
    # float placeholder on the CPU so every edge type has an entry.
    ea_list.append(
        ea if isinstance(ea, torch.Tensor)
        else torch.empty((ei.size(1), 0), device='cpu')
    )

In [21]:
# for et in [('node','to','node'), ('node','rev_to','node')]:
#     ea = sample_batch[et].edge_attr
#     print(et, ea[:, 1:].shape[1])

In [None]:
# Wrap the heterogeneous model so it accepts the three tensor lists.
model = WrapperModel(raw_model, node_keys, edge_keys)

# The exporter flattens the nested (x_list, ei_list, ea_list) argument into
# one graph input per tensor, and input names are assigned in that flattened
# order. Supplying only three names would label the second edge-index tensor
# as "edge_attr" whenever there is more than one edge type, so generate one
# name per tensor, in the same order the lists were built.
onnx_input_names = (
    [f"node_features_{node_type}" for node_type in node_keys]
    + [f"edge_index_{'__'.join(edge_type)}" for edge_type in edge_keys]
    + [f"edge_attr_{'__'.join(edge_type)}" for edge_type in edge_keys]
)

torch.onnx.export(
    model.to('cpu'),
    (x_list, ei_list, ea_list),
    'GIN.onnx',
    input_names=onnx_input_names,

    verbose=True,
    # NOTE(review): opset 9 is kept as in the original; confirm it is required
    # by the target runtime, since newer opsets support more operators.
    keep_initializers_as_inputs=True, opset_version=9
)

In [12]:
import pandas as pd

# NOTE(review): both paths are absolute and machine-specific; consider a
# configurable data directory before sharing this notebook.
transactions_csv = '/Users/leo/Desktop/AML/data/formatted_transactions.csv'
sample_csv = '/Users/leo/Desktop/AML/data/formatted_transactions_sample30.csv'

# Read the full formatted transaction table.
df = pd.read_csv(transactions_csv)

# Write its first 30 rows to a separate file, without the index column.
first_thirty = df.head(30)
first_thirty.to_csv(sample_csv, index=False)

In [9]:
import json
import pandas as pd

# Work on a copy so the casts below do not touch the frame loaded earlier.
df = df.copy()

# Columns to cast to strings; names missing from the frame are skipped.
id_cols = ["EdgeID", "from_id", "to_id", "Sent Currency", "Received Currency"]  # adjust as needed
present_id_cols = [col for col in id_cols if col in df.columns]
for col in present_id_cols:
    df[col] = df[col].astype(str)

# Build a payload holding the column names and the values of the first row.
first_row = df.head(1)
payload_entry = {
    "fields": list(first_row.columns),
    "values": first_row.values.tolist(),
}
data_json = {"input_data": [payload_entry]}

print(json.dumps(data_json, indent=4))

{
    "input_data": [
        {
            "fields": [
                "EdgeID",
                "from_id",
                "to_id",
                "Timestamp",
                "Amount Sent",
                "Sent Currency",
                "Amount Received",
                "Received Currency",
                "Payment Format",
                "Is Laundering"
            ],
            "values": [
                [
                    "2",
                    "3",
                    "3",
                    10,
                    14675.57,
                    "0",
                    14675.57,
                    "0",
                    0,
                    0
                ]
            ]
        }
    ]
}
