In [18]:
import torch
import pandas as pd
import numpy as npx
import matplotlib.pyplot as plt
from hybrid_model import HybridDQN
from vanet_dataset import VanetGraphDataset
from feature_extraction import load_ground_truth, process_logs_to_dataframe, calculate_kinematic_features

# 1. Setup paths
# Raw simulation logs, plus the two checkpoints that STAGE 3 loads.
LOG_FOLDER = "./logs/"
GNN_WEIGHTS = "pretrained_sybil_gnn.pth"
DQN_WEIGHTS = "sybil_trained_dqn.pth"
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# 2. Preprocess Logs
print("--- STAGE 1: RAW LOG PREPROCESSING ---")
gt_map = load_ground_truth(LOG_FOLDER)

# Log-level labelling (use_vehicle_level_labeling=False): is_attack=1 marks a
# message whose content is falsified, rather than every message from an attacker.
raw_df = process_logs_to_dataframe(
    LOG_FOLDER,
    gt_map,
    set(),
    set(),
    use_vehicle_level_labeling=False,
)
df_featured = calculate_kinematic_features(raw_df)

# Subset of columns shown in the preview table below.
preview_columns = ['messageID', 'rcvTime', 'rel_pos_x', 'rel_pos_y', 'distance_diff', 'is_attack']

print(f"\nExtracted {len(df_featured)} messages.")
print("Sample of Feature-Engineered Logs (First 5):")
display(df_featured.loc[:, preview_columns].head(5))

# Persist the featured frame to disk: the PyG dataset in STAGE 2 reads this CSV.
df_featured.to_csv("demo_inference.csv", index=False)

--- STAGE 1: RAW LOG PREPROCESSING ---
Loaded 244099 ground truth records.

Extracted 1471 messages.
Sample of Feature-Engineered Logs (First 5):


Unnamed: 0,messageID,rcvTime,rel_pos_x,rel_pos_y,distance_diff,is_attack
0,17038,25208.812286,-1.181646,10.612802,0.0,0
1,17243,25208.925714,-5.258251,-223.833589,0.0,0
2,17460,25209.812276,-1.126122,12.15198,0.83527,0
3,17585,25209.925703,-5.517845,-220.562794,-0.045763,0
4,22042,25210.812278,2.005638,14.301253,0.696559,0


In [19]:
# All messages that the STAGE 1 log-level labelling marked as attacks (is_attack == 1).
df_featured.loc[df_featured["is_attack"].eq(1)]

Unnamed: 0,rcvTime,sendTime,receiver,sender,messageID,rel_pos_x,rel_pos_y,rel_spd_x,rel_spd_y,pos_x,...,acl_y,hed_x,hed_y,is_attack,speed,acceleration,distance_diff,beacon_rate,avg_speed_1s,stddev_speed_1s
49,25208.891879,25208.891879,21,15,17141,0.000000,0.000000,0.000000,0.000000,135.115420,...,-2.552403,0.093309,-0.995637,1,0.508598,2.563606,-9.692420,6.000071,1.356618,1.199281
51,25209.391888,25209.391888,21,15,17346,-4.437479,-232.716574,-0.246637,2.630646,131.038815,...,0.000041,0.998371,0.057050,1,0.000000,0.000058,-234.194446,4.999987,0.169533,0.293639
52,25209.891882,25209.891882,21,15,17483,0.000000,0.000000,0.000000,0.000000,135.476294,...,-1.537555,0.092600,-0.995703,1,2.642183,1.544310,-232.001820,4.666677,0.880728,1.525465
54,25210.391890,25210.391890,21,15,21912,-4.686821,-229.059396,-0.422275,4.504717,131.084571,...,0.000041,0.998452,0.055626,1,0.000000,0.000058,-231.999128,4.499989,0.880728,1.525465
55,25210.891884,25210.891884,21,15,22066,-0.295098,3.655378,-0.175638,1.874071,135.476294,...,-1.537555,0.092600,-0.995703,1,2.642183,1.544310,-231.999153,4.400001,0.880728,1.525465
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1395,25230.092349,25230.092349,9,75,64562,0.750399,-7.724254,-1.136189,2.967308,220.437754,...,-4.450997,-0.029735,0.999558,1,10.050640,4.499913,0.000006,7.999938,10.620255,0.989936
1396,25230.592345,25230.592345,9,75,64948,56.111082,-164.542842,13.826465,-3.429919,275.798438,...,0.631513,0.974112,0.226068,1,13.943608,2.483978,0.000008,6.000006,11.451093,1.847833
1398,25231.092348,25231.092348,9,75,74031,52.558941,-171.709126,8.204818,-3.938943,275.798438,...,0.631513,0.974112,0.226068,1,13.943608,2.483978,0.000008,5.333325,11.997574,2.247086
1399,25231.592343,25231.592343,9,75,74545,-3.552142,-7.166285,-5.621647,-0.509024,219.687355,...,2.596957,0.068848,0.997627,1,6.982533,2.600063,0.000004,5.000006,11.230547,3.374165


In [28]:
print("--- STAGE 2: GRAPH SNAPSHOT CONSTRUCTION ---")
dataset = VanetGraphDataset("demo_inference.csv", time_window=2.0)

# Inspect one snapshot: the built graph plus the dataframe group it was built from.
# dataset.grouped[i] unpacks as ((receiver_id, time_bin), group) for the same index i.
snapshot_idx = 0
sample_graph = dataset[snapshot_idx]
snapshot_key, original_group = dataset.grouped[snapshot_idx]
receiver_id, time_bin = snapshot_key

# Node 0 is treated as the ego receiver; every other node is counted as a sender.
sender_count = sample_graph.num_nodes - 1

print(f"Snapshot for Receiver: {receiver_id} at Time Bin: {time_bin}")
print(f"Nodes in Graph: {sample_graph.num_nodes} (1 Ego + {sender_count} Senders)")
print(f"Edge Index (Star Graph Connections):\n{sample_graph.edge_index}")

# Node feature matrix X (one row per node); row 1 is the first sender.
node_features = sample_graph.x
print("\nNode Feature Tensor (X) Shape:", node_features.shape)
print("First Sender's Features (Node 1):", node_features[1].numpy())
sample_graph

--- STAGE 2: GRAPH SNAPSHOT CONSTRUCTION ---
Loading and grouping data...
Created 225 graph snapshots.
Snapshot for Receiver: 9 at Time Bin: 25212.0
Nodes in Graph: 3 (1 Ego + 2 Senders)
Edge Index (Star Graph Connections):
tensor([[1, 2, 0, 0],
        [0, 0, 1, 2]])

Node Feature Tensor (X) Shape: torch.Size([3, 7])
First Sender's Features (Node 1): [  7.894916  -95.68901     2.1848454 -12.137279    2.0615287   0.
   0.       ]


Data(x=[3, 7], edge_index=[2, 4], y=[3], baseline_speed_mean=2.190430958616985, baseline_speed_std=0.0, receiver_id=9)

In [29]:
print("--- STAGE 3: GNN EMBEDDING EXTRACTION ---")

# Load Model: HybridDQN receives the pretrained GNN checkpoint path, then the DQN
# checkpoint is loaded as a state_dict over the whole model.
model = HybridDQN(gnn_weights_path=GNN_WEIGHTS).to(DEVICE)
# weights_only=True limits unpickling to tensors and plain containers, which is all a
# state_dict needs. It avoids executing arbitrary pickled code from the file and
# silences the FutureWarning torch 2.4/2.5 emit when the argument is omitted
# (2.6+ already defaults to True).
dqn_state = torch.load(DQN_WEIGHTS, map_location=DEVICE, weights_only=True)
model.load_state_dict(dqn_state)
model.eval()

# Pass the STAGE 2 sample graph through the GNN backbone only (not the DQN head).
with torch.no_grad():
    x = sample_graph.x.to(DEVICE)
    edge_index = sample_graph.edge_index.to(DEVICE)
    # A single graph, so every node is assigned to batch 0.
    batch = torch.zeros(x.shape[0], dtype=torch.long, device=DEVICE)

    # Access the GNN specifically; it returns (context vector, node logits, hidden states).
    context_vector, node_logits, hidden_h = model.gnn(x, edge_index, batch)

print(f"Graph Context Vector (Summary of neighborhood) - Shape: {context_vector.shape}")
print(f"Hidden Node Embeddings (Processed features) - Shape: {hidden_h.shape}")
print("\nSample Context Vector (First 10 dims):", context_vector[0, :10].cpu().numpy())

--- STAGE 3: GNN EMBEDDING EXTRACTION ---
Graph Context Vector (Summary of neighborhood) - Shape: torch.Size([1, 32])
Hidden Node Embeddings (Processed features) - Shape: torch.Size([3, 64])

Sample Context Vector (First 10 dims): [ 1.0539933  -1.6607968   2.9181864   2.1430743  -3.5749018  -4.4474335
 -0.85841227 -4.2523427   3.6220648   1.5867747 ]


In [31]:
import torch
import pandas as pd
from tqdm import tqdm # Useful for tracking progress over many snapshots

print("--- STAGE 5: FULL DATASET INFERENCE ---")

# Set to False to suppress the one-line-per-sender table on large datasets;
# the summary and confusion table below are always shown.
PRINT_PER_MESSAGE = True

all_results = []
correct_count = 0
total_targets = 0

# Reuses `dataset` (STAGE 2), `model` (STAGE 3) and `DEVICE` (STAGE 1) from earlier cells.
model.eval()

if PRINT_PER_MESSAGE:
    print(f"{'Message ID':<15} | {'DQN Action':<12} | {'Actual Label':<12} | {'Correct?'}")
    print("-" * 65)

with torch.no_grad():
    # Iterate through every graph snapshot in the dataset
    for i in range(len(dataset)):
        graph = dataset[i]
        (receiver_id, time_bin), original_group = dataset.grouped[i]

        # Prepare data for model; a single graph, so every node is in batch 0.
        x = graph.x.to(DEVICE)
        edge_index = graph.edge_index.to(DEVICE)
        batch = torch.zeros(x.shape[0], dtype=torch.long, device=DEVICE)

        # Node 0 is the ego receiver; nodes 1..N are the senders to be judged.
        num_senders = x.shape[0] - 1

        for sender_idx in range(1, num_senders + 1):
            # Forward pass through HybridDQN for this target sender node.
            q_values = model(x, edge_index, batch, sender_idx)

            # Action 0 = Accept, Action 1 = Reject.
            # NOTE(review): the [0, 0] / [0, 1] indexing below assumes q_values is (1, 2).
            action = q_values.argmax(dim=1).item()
            actual = int(graph.y[sender_idx].item())

            # Select the column first, then the row. `.iloc[row]['messageID']` builds a
            # mixed-dtype row Series, which upcasts the integer ID to float (e.g. 31028.0).
            # NOTE(review): assumes sender node k maps to row k-1 of `original_group`.
            # If VanetGraphDataset aggregates several messages per sender into one node,
            # this picks an unrelated message; verify against the dataset implementation.
            msg_id = original_group["messageID"].iloc[sender_idx - 1]

            is_correct = action == actual
            if is_correct:
                correct_count += 1
            total_targets += 1

            if PRINT_PER_MESSAGE:
                pred_str = "REJECT (1)" if action == 1 else "ACCEPT (0)"
                act_str = "ATTACK (1)" if actual == 1 else "NORMAL (0)"
                status = "✓" if is_correct else "✗"
                print(f"{msg_id:<15} | {pred_str:<12} | {act_str:<12} | {status}")

            # Store for final metrics
            all_results.append({
                "messageID": msg_id,
                "prediction": action,
                "actual": actual,
                "q_accept": q_values[0, 0].item(),
                "q_reject": q_values[0, 1].item(),
            })

print("\n--- INFERENCE COMPLETE ---")
print(f"Total Messages Evaluated: {total_targets}")
if total_targets:
    print(f"Overall Accuracy: {(correct_count/total_targets)*100:.2f}%")
else:
    # Avoid ZeroDivisionError when no snapshot contained any sender node.
    print("Overall Accuracy: n/a (no sender nodes were evaluated)")

# Materialise the collected records. Accuracy alone hides class imbalance, so also show
# the label-vs-action breakdown (0 = accept/normal, 1 = reject/attack).
results_df = pd.DataFrame(
    all_results,
    columns=["messageID", "prediction", "actual", "q_accept", "q_reject"],
)
if not results_df.empty:
    display(
        pd.crosstab(
            results_df["actual"],
            results_df["prediction"],
            rownames=["actual"],
            colnames=["predicted"],
        )
    )

--- STAGE 5: FULL DATASET INFERENCE ---
Message ID      | DQN Action   | Actual Label | Correct?
-----------------------------------------------------------------
31028.0         | REJECT (1)   | NORMAL (0)   | ✗
31552.0         | REJECT (1)   | NORMAL (0)   | ✗
31901.0         | REJECT (1)   | NORMAL (0)   | ✗
32250.0         | REJECT (1)   | NORMAL (0)   | ✗
32599.0         | REJECT (1)   | NORMAL (0)   | ✗
37270.0         | REJECT (1)   | NORMAL (0)   | ✗
37662.0         | REJECT (1)   | NORMAL (0)   | ✗
38054.0         | REJECT (1)   | NORMAL (0)   | ✗
38446.0         | REJECT (1)   | NORMAL (0)   | ✗
38838.0         | REJECT (1)   | NORMAL (0)   | ✗
39230.0         | REJECT (1)   | NORMAL (0)   | ✗
47990.0         | REJECT (1)   | NORMAL (0)   | ✗
48291.0         | REJECT (1)   | NORMAL (0)   | ✗
48620.0         | REJECT (1)   | NORMAL (0)   | ✗
48841.0         | REJECT (1)   | NORMAL (0)   | ✗
53274.0         | REJECT (1)   | NORMAL (0)   | ✗
53520.0         | REJECT (1)   | NORM