## View Attack (Malicious) Time Windows

In [None]:
import json, sys, os
from typing import List, Dict, Any, Tuple, Optional
from datetime import datetime, timezone
import pandas as pd

# ---------- helpers ----------
def _unwrap(v):
    if isinstance(v, dict) and len(v) == 1:
        return next(iter(v.values()))
    return v

def _iso_from_ns(ns: Optional[int]) -> Optional[str]:
    try:
        return datetime.fromtimestamp(int(ns) / 1_000_000_000, tz=timezone.utc).isoformat()
    except Exception:
        return None

# String spellings of a "not malicious" label (compared case-insensitively).
_FALSE_LABEL_STRINGS = frozenset({"", "0", "false", "f", "no", "n"})


def _int_or_none(value: Any) -> Optional[int]:
    """Return *value* (after single-key-dict unwrapping) as an int, or None."""
    if value is None:
        return None
    try:
        return int(_unwrap(value))
    except (TypeError, ValueError, OverflowError):
        return None


def _label_from(value: Any) -> int:
    """Map a raw ``malicious`` field to 0/1.

    None (field absent) means 1, because every interval in this file is an
    attack window by definition. Wrapped values are unwrapped first. Strings
    are parsed, so ``"false"``/``"0"`` give 0 rather than ``bool("0") == True``.
    Everything else uses its truthiness.
    """
    value = _unwrap(value)
    if value is None:
        return 1
    if isinstance(value, str):
        text = value.strip().lower()
        if text in _FALSE_LABEL_STRINGS:
            return 0
        try:
            return int(float(text) != 0)
        except ValueError:
            return 1  # unrecognized non-empty string: keep the attack default
    return int(bool(value))


def _add_row(rows: List[Dict[str, Any]], name: str, start_ns: Any, end_ns: Any,
             start_iso: Optional[str], end_iso: Optional[str], malicious: Optional[int], meta: Dict[str, Any]):
    """Normalize one attack window and append it to *rows* (mutated in place).

    Args:
        rows: Output list that receives one dict per call.
        name: Display name of the window.
        start_ns, end_ns: Epoch-nanosecond bounds; may be wrapped in a
            single-key dict or be strings. Unparseable values become None.
        start_iso, end_iso: Optional ISO strings (possibly wrapped). When
            missing, they are derived from the ns values.
        malicious: Raw label (see ``_label_from``); None means 1.
        meta: Extra columns (e.g. ``host``, ``idx``) merged into the row last.
    """
    s = _int_or_none(start_ns)
    e = _int_or_none(end_ns)

    # ISO fields may arrive wrapped just like the numeric ones.
    start_iso = _unwrap(start_iso)
    end_iso = _unwrap(end_iso)
    if start_iso is None and s is not None:
        start_iso = _iso_from_ns(s)
    if end_iso is None and e is not None:
        end_iso = _iso_from_ns(e)

    rows.append({
        "name": name,
        "start_ns": s,
        "end_ns": e,
        "start_iso": start_iso,
        "end_iso": end_iso,
        "malicious": _label_from(malicious),
        **meta
    })

# ---------- main loader ----------
def _first_set(*values: Any) -> Any:
    """Return the first value that is neither None nor "" (0 is kept)."""
    for value in values:
        if value is not None and value != "":
            return value
    return None


def _window_fields(a: Dict[str, Any], window_first: bool = False) -> Tuple[Any, Any, Any, Any]:
    """Extract ``(start_ns, end_ns, start_iso, end_iso)`` from one attack record.

    Values can sit at the top level (``start_ns``/``end_ns``/``start_iso``/
    ``end_iso``) or in a nested ``window`` dict (``start_unix_ns``/
    ``end_unix_ns``/``start_iso``/``end_iso``). Top-level values win unless
    *window_first* is True. A missing or null ``window`` is treated as empty.
    """
    w = a.get("window")
    if not isinstance(w, dict):
        w = {}
    pairs = [
        (a.get("start_ns"), w.get("start_unix_ns")),
        (a.get("end_ns"), w.get("end_unix_ns")),
        (a.get("start_iso"), w.get("start_iso")),
        (a.get("end_iso"), w.get("end_iso")),
    ]
    if window_first:
        pairs = [(nested, top) for top, nested in pairs]
    start_ns, end_ns, start_iso, end_iso = (_first_set(*p) for p in pairs)
    return start_ns, end_ns, start_iso, end_iso


def load_attack_windows(path: str) -> List[Dict[str, Any]]:
    """Load attack time windows from a JSON file into normalized row dicts.

    Supported layouts (checked in this order):
      A: ``{"attacks_by_host": {"HOST": [record, ...], ...}}`` (rows get ``host``)
      B: ``{"attacks": [{"window": {...}, "name": ..., "malicious": ...}, ...]}``
      C: ``{"intervals": [[start_ns, end_ns], ...]}``
      D: ``[record, ...]``
    Non-dict records and malformed interval pairs are skipped.

    Args:
        path: JSON file path (str or os.PathLike).

    Returns:
        Row dicts with ``name``, ``start_ns``, ``end_ns``, ``start_iso``,
        ``end_iso``, ``malicious`` and ``idx`` (plus ``host`` for layout A).

    Raises:
        ValueError: if the JSON matches none of the layouts above.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    rows: List[Dict[str, Any]] = []

    # Schema A: per-host lists of records.
    if isinstance(data, dict) and isinstance(data.get("attacks_by_host"), dict):
        for host, arr in data["attacks_by_host"].items():
            if not isinstance(arr, list):
                continue
            for i, a in enumerate(arr):
                if not isinstance(a, dict):
                    continue
                start_ns, end_ns, start_iso, end_iso = _window_fields(a)
                name = a.get("name") or a.get("label") or host
                _add_row(rows, name, start_ns, end_ns, start_iso, end_iso, a.get("malicious"),
                         {"host": host, "idx": i})
        return rows

    # Schema B: records whose bounds live primarily in a nested "window" dict.
    if isinstance(data, dict) and isinstance(data.get("attacks"), list):
        for i, a in enumerate(data["attacks"]):
            if not isinstance(a, dict):
                continue
            start_ns, end_ns, start_iso, end_iso = _window_fields(a, window_first=True)
            name = a.get("name") or a.get("label") or f"attack_{i}"
            _add_row(rows, name, start_ns, end_ns, start_iso, end_iso, a.get("malicious"), {"idx": i})
        return rows

    # Schema C: bare [start_ns, end_ns] pairs, all treated as malicious.
    if isinstance(data, dict) and isinstance(data.get("intervals"), list):
        for i, pair in enumerate(data["intervals"]):
            if isinstance(pair, (list, tuple)) and len(pair) == 2:
                _add_row(rows, f"attack_{i}", pair[0], pair[1], None, None, 1, {"idx": i})
        return rows

    # Schema D: a top-level list of records.
    if isinstance(data, list):
        for i, a in enumerate(data):
            if not isinstance(a, dict):
                continue
            start_ns, end_ns, start_iso, end_iso = _window_fields(a)
            name = a.get("name") or a.get("label") or f"attack_{i}"
            _add_row(rows, name, start_ns, end_ns, start_iso, end_iso, a.get("malicious"), {"idx": i})
        return rows

    # f-string so a pathlib.Path argument does not raise TypeError here.
    raise ValueError(f"Unsupported attack windows schema in file: {path}")

# ---------- main ----------
path = r"C:\Users\Ali\Desktop\ChatGPT Scripts\attack_windows_e5_iso_and_ns_cleaned.json"
if not os.path.isfile(path):
    print(f"[ERROR] File not found: {path}")
else:
    rows = load_attack_windows(path)

    # Sort by start_ns; windows without a start go last.
    rows.sort(key=lambda r: (r["start_ns"] is None, r["start_ns"]))

    # Show summary
    total = len(rows)
    mal = sum(r.get("malicious", 0) == 1 for r in rows)
    print(f"\nLoaded attack windows: {total} | malicious=1 count: {mal}\n")

    # Tabular display via pandas (imported unconditionally at the top of this cell).
    cols = ["name", "start_iso", "end_iso", "start_ns", "end_ns", "malicious"]
    df = pd.DataFrame(rows)
    # Layout A rows carry a 'host' column; show it right after the name.
    if "host" in df.columns:
        cols.insert(1, "host")
    if df.empty:
        print("(no attack windows found)")
    else:
        print(df[[c for c in cols if c in df.columns]].to_string(index=False))

## Check Malicious Events in Malicious Folder

In [None]:
import pandas as pd
from pathlib import Path

malicious_dir = Path(r"C:\Users\Ali\Desktop\ChatGPT Scripts\out_parquet_time_only\events\malicious")

total_malicious = 0
results = {}

if not malicious_dir.is_dir():
    print(f"[ERROR] Directory not found: {malicious_dir}")

# Sorted so the per-file report is in a stable order.
for parquet_file in sorted(malicious_dir.glob("*.parquet")):
    df = pd.read_parquet(parquet_file, columns=["malicious"])

    # Coerce to numbers; values that cannot be parsed become NaN.
    s = pd.to_numeric(df["malicious"], errors="coerce")

    cnt_malicious = int((s == 1).sum())
    # `NaN != 1` is True, so count unparsed values separately instead of
    # silently reporting them as non-malicious.
    cnt_unparsed = int(s.isna().sum())
    cnt_non = len(s) - cnt_malicious - cnt_unparsed

    results[parquet_file.name] = {
        "malicious": cnt_malicious,
        "non_malicious": cnt_non,
        "unparsed": cnt_unparsed,
    }
    total_malicious += cnt_malicious

# Print per-file summary
print("=== Malicious Folder File-by-File Check ===")
for fname, counts in results.items():
    print(f"{fname}: malicious={counts['malicious']}, non_malicious={counts['non_malicious']}, "
          f"unparsed={counts['unparsed']}")

print("\nTOTAL malicious events:", total_malicious)


## Check Non-Malicious Events in Non-Malicious Folder

In [None]:
import pandas as pd
from pathlib import Path

non_malicious_dir = Path(r"C:\Users\Ali\Desktop\ChatGPT Scripts\out_parquet_time_only\events\non_malicious")

total_non_malicious = 0
results = {}

if not non_malicious_dir.is_dir():
    print(f"[ERROR] Directory not found: {non_malicious_dir}")

# Sorted so the per-file report is in a stable order.
for parquet_file in sorted(non_malicious_dir.glob("*.parquet")):
    df = pd.read_parquet(parquet_file, columns=["malicious"])

    # Coerce to numbers; values that cannot be parsed become NaN.
    s = pd.to_numeric(df["malicious"], errors="coerce")

    cnt_non = int((s == 0).sum())
    # `NaN != 0` is True, so count unparsed values separately instead of
    # silently reporting them as malicious.
    cnt_unparsed = int(s.isna().sum())
    cnt_malicious = len(s) - cnt_non - cnt_unparsed

    results[parquet_file.name] = {
        "non_malicious": cnt_non,
        "malicious": cnt_malicious,
        "unparsed": cnt_unparsed,
    }
    total_non_malicious += cnt_non

# Print per-file summary
print("=== Non-Malicious Folder File-by-File Check ===")
for fname, counts in results.items():
    print(f"{fname}: non_malicious={counts['non_malicious']}, malicious={counts['malicious']}, "
          f"unparsed={counts['unparsed']}")

print("\nTOTAL non-malicious events:", total_non_malicious)


## Inspect Features for each node type

In [None]:
"""
inspect_features.py

Explore what columns (features) exist for each node type (subjects, fileobjects, etc.)
and for events in the preprocessed Parquet dataset. Also prints a few sample rows.

Usage:
  python inspect_features.py --input-dir ./out_parquet_time_only --samples 3
"""

import argparse
from pathlib import Path
import pyarrow.parquet as pq

def inspect_folder(parquet_dir: Path, n_samples: int = 3):
    """Return ``(column_names, sample_rows)`` for the first parquet file in *parquet_dir*.

    Only the alphabetically first ``*.parquet`` file is inspected. Up to
    *n_samples* rows are returned as dicts.

    Returns:
        ``(None, None)`` if the folder has no parquet files. Otherwise the
        top-level column names and a (possibly empty) list of row dicts.
    """
    files = sorted(parquet_dir.glob("*.parquet"))
    if not files:
        return None, None
    pf = pq.ParquetFile(files[0])
    # Top-level Arrow field names. ParquetSchema.names lists flattened leaf
    # columns, which differ from top-level names for nested types.
    cols = pf.schema_arrow.names
    if n_samples <= 0:
        return cols, []
    # Stream just the first few rows instead of materializing a whole row
    # group; this also works for files with zero row groups.
    first_batch = next(pf.iter_batches(batch_size=n_samples), None)
    rows = first_batch.to_pylist()[:n_samples] if first_batch is not None else []
    return cols, rows

#------ main -----------
input_dir = r"./out_parquet_time_only"
samples = 2
base = Path(input_dir)
if not base.exists():
    raise FileNotFoundError(f"{base} not found")

# Every immediate subdirectory is treated as one parquet table.
folders = [entry for entry in base.iterdir() if entry.is_dir()]
print(f"Inspecting {len(folders)} parquet folders under {base}\n")

for folder in folders:
    cols, rows = inspect_folder(folder, samples)
    if cols is None:
        print(f"[{folder.name}] No parquet files found\n")
        continue

    # Assemble the report for this folder, ending with a blank separator line.
    report = [
        folder.name.upper(),
        f"Total Columns: {len(cols)}",
        "Columns: " + ", ".join(cols),
    ]
    report.extend(f"Sample Row {n}: {row}" for n, row in enumerate(rows, start=1))
    report.append("")
    print("\n".join(report))


## Build a 10-Minute Window Graph with UUID → Type Mapping (.gpickle Output)

Quick test script to extract a single 10-minute provenance graph
from your events parquet folder. Saves to .gpickle and prints:
- number of nodes
- number of edges
- top-degree nodes
- sample edges
- time taken

Now includes UUID → type mapping so nodes get real types
(subjects, fileobjects, netflows, etc.) instead of 'unknown'.

In [None]:
import time
import logging
import pickle
import networkx as nx
import pyarrow.dataset as ds
import pyarrow.parquet as pq
from pathlib import Path
from datetime import datetime, timezone

# ----------- Config -----------
# Root folder of the preprocessed parquet output. "events" holds the event
# rows; every other subfolder is scanned by build_uuid_type_map.
BASE_DIR = Path("./out_parquet_time_only")
INPUT_DIR = BASE_DIR / "events"
# Destination of the pickled single-window graph.
OUTPUT_FILE = Path("./test_window.gpickle")
# Window length in minutes (converted to ns below as WINDOW_MIN * NS_PER_MIN).
WINDOW_MIN = 10
# Rows per batch when streaming a parquet file with iter_batches.
BATCH_SIZE = 100_000
LOG_LEVEL = logging.INFO
# ------------------------------

NS_PER_SEC = 1_000_000_000
NS_PER_MIN = 60 * NS_PER_SEC

# Root-logger setup for this cell. NOTE: basicConfig does nothing if the root
# logger already has handlers (e.g. configured by an earlier cell).
logging.basicConfig(
    format="[%(levelname)s %(asctime)s] %(message)s",
    datefmt="%H:%M:%S",
    level=LOG_LEVEL
)
log = logging.getLogger("test-one-window")

def ns_to_iso(ns: int) -> str:
    """Format a Unix-epoch nanosecond timestamp as ``YYYY-MM-DDTHH:MM:SSZ`` (UTC).

    Sub-second digits are truncated with integer division. A float
    ``ns / 1e9`` cannot represent 19-digit values exactly and may round
    e.g. ``...999999999`` ns up to the next whole second.
    """
    seconds = int(ns) // 1_000_000_000
    return datetime.fromtimestamp(seconds, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

def save_graph(G, path: Path):
    """Write graph *G* to *path* as a pickle.

    Uses ``nx.write_gpickle`` when the installed networkx still provides it
    (it was removed in networkx 3.0); otherwise pickles the graph directly.
    """
    writer = getattr(nx, "write_gpickle", None)
    if writer is not None:
        writer(G, path)
        return
    with open(path, "wb") as fh:
        pickle.dump(G, fh)

def build_uuid_type_map(base: Path) -> dict:
    """Map each node UUID to the name of the folder (node type) it appears in.

    Every subfolder of *base* except ``events`` is opened as a parquet dataset,
    and only its ``uuid`` column is read. Folders are visited in name order, so
    when a UUID occurs in several folders the first folder alphabetically wins
    deterministically (``iterdir`` order is arbitrary). Unreadable folders and
    files are logged and skipped; empty/None UUIDs are ignored.
    """
    uuid2type = {}
    node_folders = sorted(
        (p for p in base.iterdir() if p.is_dir() and p.name.lower() != "events"),
        key=lambda p: p.name,
    )
    for folder in node_folders:
        try:
            dataset = ds.dataset(folder, format="parquet")
        except Exception as e:
            log.warning(f"Skipping {folder}: {e}")
            continue
        for frag in dataset.get_fragments():
            try:
                # Pull the single column as a flat list; avoids a dict per row.
                uuids = pq.read_table(frag.path, columns=["uuid"]).column("uuid").to_pylist()
            except Exception as e:
                log.warning(f"Error reading {frag.path}: {e}")
                continue
            for u in uuids:
                if u:
                    uuid2type.setdefault(u, folder.name)  # first folder seen wins
    log.info(f"UUID→type map built with {len(uuid2type):,} entries")
    return uuid2type

# ----------- main -----------
win_ns = WINDOW_MIN * NS_PER_MIN
graphs = {}

t0 = time.time()
files = sorted(INPUT_DIR.glob("*.parquet"))
if not files:
    raise FileNotFoundError(f"No parquet files in {INPUT_DIR}")

log.info(f"Found {len(files)} parquet files; building one {WINDOW_MIN}-minute window")

# Build UUID→type map first
log.info("START: Building UUID→type map from node folders...")
uuid2type = build_uuid_type_map(BASE_DIR)
log.info("Completed: Building UUID→type map")

EVENT_COLUMNS = ["timestampNanos", "event_type", "subject_uuid", "object1_uuid", "object2_uuid", "malicious"]
ws0 = None              # start (ns) of the target window, chosen from the first non-empty batch
G = None                # graph of the target window
window_complete = False

for file in files:
    log.info(f"Reading file {file.name}")
    pf = pq.ParquetFile(file)
    cols = [c for c in EVENT_COLUMNS if c in pf.schema_arrow.names]

    for batch in pf.iter_batches(batch_size=BATCH_SIZE, columns=cols):
        events = [(int(r["timestampNanos"]), r) for r in batch.to_pylist()
                  if r.get("timestampNanos") is not None]
        if not events:
            continue
        if ws0 is None:
            # Target the earliest window present in the first non-empty batch.
            ws0 = (min(ts for ts, _ in events) // win_ns) * win_ns
            G = nx.MultiDiGraph()
            graphs[ws0] = G
        ws_end = ws0 + win_ns

        for ts, r in events:
            if not (ws0 <= ts < ws_end):
                continue  # event belongs to a different window
            u, v1, v2 = r.get("subject_uuid"), r.get("object1_uuid"), r.get("object2_uuid")
            for node in (u, v1, v2):
                if node and not G.has_node(node):
                    G.add_node(node, node_type=uuid2type.get(node, "unknown"))
            et = r.get("event_type")
            mal = bool(r.get("malicious", False))
            if u and v1:
                G.add_edge(u, v1, event_type=et, timestamp=ts, malicious=mal)
            if u and v2:
                G.add_edge(u, v2, event_type=et, timestamp=ts, malicious=mal)

        # Stop once the stream has moved past the end of the window.
        # NOTE(review): assumes events are stored in (roughly) chronological
        # order across batches/files; otherwise later events of this window
        # that appear after this point are missed.
        if max(ts for ts, _ in events) >= ws_end:
            window_complete = True
            break
    if window_complete:
        break

if G is None:
    log.warning("No events with timestampNanos found; nothing saved")
else:
    if not window_complete:
        log.info("Reached end of input before the window closed; saving what was collected")
    save_graph(G, OUTPUT_FILE)
    elapsed = time.time() - t0
    log.info(f"Saved {OUTPUT_FILE}")
    log.info(f"Window start: {ns_to_iso(ws0)}")
    log.info(f"Nodes: {G.number_of_nodes()}, Edges: {G.number_of_edges()}")
    log.info(f"Time taken: {elapsed:.2f} sec")

    # ---- Visual inspection ----
    log.info("Top 10 nodes by degree:")
    degs = sorted(G.degree, key=lambda x: x[1], reverse=True)[:10]
    for node, d in degs:
        log.info(f"  {node} ({G.nodes[node]['node_type']}) -> degree {d}")

    log.info("Sample edges (3):")
    for i, (u, v, edata) in enumerate(G.edges(data=True)):
        if i >= 3:
            break
        log.info(f"  {u} ({G.nodes[u]['node_type']}) -> {v} ({G.nodes[v]['node_type']}), attrs={edata}")

## Visualize a small subgraph for inspection.

In [None]:
import pickle
import networkx as nx
import matplotlib.pyplot as plt
from pathlib import Path

def load_graph(path: Path):
    """Load a pickled networkx graph from *path*.

    Uses ``nx.read_gpickle`` when the installed networkx still provides it
    (removed in networkx 3.0); otherwise unpickles the file directly.
    Only load trusted files: unpickling can execute arbitrary code.
    """
    reader = getattr(nx, "read_gpickle", None)
    if reader is not None:
        return reader(path)
    with open(path, "rb") as fh:
        return pickle.load(fh)

# Arguments
gpickle_path = Path("test_window_withUUID_to_Type.gpickle")
# gpickle_path = Path("test_window_withoutUUID_to_Type.gpickle")
max_nodes = 50

# gpickle_path is already a Path, so it is passed straight through.
G = load_graph(gpickle_path)
print(f"Loaded graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")

# Rank nodes by total degree (highest first) and keep the top max_nodes.
nodes_sorted = sorted(G.degree, key=lambda node_deg: node_deg[1], reverse=True)
keep_nodes = [node for node, _deg in nodes_sorted[:max_nodes]]
H = G.subgraph(keep_nodes).copy()

print(f"Visualizing subgraph with {H.number_of_nodes()} nodes and {H.number_of_edges()} edges")

# One color per node type; anything not listed is drawn gray.
color_map = {
    "subjects": "skyblue",
    "fileobjects": "orange",
    "netflows": "green",
    "memory": "purple",
    "ipc": "pink",
    "registrykeys": "red",
    "unknown": "gray",
}
node_types = nx.get_node_attributes(H, "node_type")
node_colors = [color_map.get(node_types.get(n, "unknown"), "gray") for n in H.nodes]

# Fixed seed keeps the layout reproducible between runs.
pos = nx.spring_layout(H, seed=42, k=0.3)
plt.figure(figsize=(10, 8))
nx.draw_networkx_nodes(H, pos, node_size=300, node_color=node_colors)
nx.draw_networkx_edges(H, pos, alpha=0.4)
nx.draw_networkx_labels(H, pos, font_size=7)

plt.title(f"Subgraph of {gpickle_path} (top {max_nodes} nodes by degree)")
plt.axis("off")
plt.tight_layout()
plt.show()

## Check Edge Dimensions

In [None]:
# check_edge_dims.py
import torch
from pathlib import Path

data_dir = Path(r"C:\Users\Ali\Desktop\ChatGPT Scripts\win_tensors_balanced")
pt_files = sorted(data_dir.glob("*.pt"))  # all files, in a stable order

if not pt_files:
    print(f"No .pt files found in {data_dir}")

for pt_file in pt_files:
    try:
        # weights_only=False also allows PyG Data objects / pickled metadata.
        # Only use on trusted files: unpickling can execute code.
        data = torch.load(pt_file, weights_only=False)
    except Exception as e:
        print(f"{pt_file.name}: failed to load ({e})")
        continue

    # Resolve edge_attr for *this* file whether it holds a dict (torch.save of
    # a dict) or a PyG Data object, so a previous file's value is never reused.
    if isinstance(data, dict):
        edge_attr = data.get('edge_attr')
    else:
        edge_attr = getattr(data, 'edge_attr', None)

    if edge_attr is not None:
        print(f"{pt_file.name}: edge_attr.shape = {edge_attr.shape}")
        # Optional stricter check, e.g. for a fixed feature width of 63:
        # assert edge_attr.shape[1] == 63, f"Inconsistent dim in {pt_file.name}"
    else:
        print(f"{pt_file.name}: edge_attr not found!")

## Check Graph Dimensions

In [None]:
# check_graph_dims.py
import torch
from pathlib import Path

# Point this to your directory containing the .pt files
data_dir = Path(r"C:\Users\Ali\Desktop\ChatGPT Scripts\win_tensors_balanced")

# All .pt files in a stable order; slice (e.g. [:10]) to check only a few for speed.
pt_files = sorted(data_dir.glob("*.pt"))

MAX_EXAMPLES = 3            # number of leading files whose x.shape is always printed
unique_x_dims = set()       # distinct node-feature widths (x.shape[1]) seen
examples_printed = 0

print("Checking dimensions for node features (x) in .pt files...")
print("-" * 20)

if not pt_files:
    print(f"No .pt files found in {data_dir}")

for pt_file in pt_files:
    try:
        # weights_only=False also allows PyG Data objects; trusted files only.
        loaded = torch.load(pt_file, weights_only=False)
    except Exception as e:
        print(f"Error loading {pt_file.name}: {e}")
        continue

    # Files may hold a dict or a PyG Data object; adjust the key if your
    # preprocessing stores node features under a different name.
    x = loaded.get('x') if isinstance(loaded, dict) else getattr(loaded, 'x', None)
    if x is None:
        print(f"{pt_file.name}: 'x' tensor not found in the loaded data.")
        continue
    if len(x.shape) != 2:
        print(f"{pt_file.name}: expected 2-D x [num_nodes, num_features], got shape {tuple(x.shape)}")
        continue

    num_node_features = x.shape[1]
    is_new_dim = num_node_features not in unique_x_dims
    unique_x_dims.add(num_node_features)
    # Print the first few files, plus any file that introduces a new width.
    if examples_printed < MAX_EXAMPLES or is_new_dim:
        print(f"{pt_file.name}: x.shape = {x.shape}")
        examples_printed += 1

print("-" * 20)
print(f"Unique node feature dimensions (x.shape[1]) found: {sorted(unique_x_dims)}")

if len(unique_x_dims) > 1:
    print("\n*** INCONSISTENCY DETECTED ***")
    print("Graphs have different numbers of node features.")
    print("This will cause a batching error in gine_train.py.")
elif not unique_x_dims:
    print("\nNo node feature tensors were checked; consistency cannot be assessed.")
else:
    print("\nNode feature dimensions appear consistent (based on checked files).")
    print("If gine_train.py still fails, the issue might be elsewhere or in unchecked files.")


## Analyze TCCDMDatum.json

In [None]:
import json
import argparse
from pathlib import Path

def walk_schema(obj, found):
    """Recursively walk an Avro-style JSON schema, collecting record and enum names.

    Args:
        obj: A schema node. A dict is an inline type (record, enum, array,
            map, ...). A list is a union or a list of schemas. A string is a
            primitive or a reference to an already-named type and is ignored.
        found: Accumulator mutated in place. ``found["records"]`` is a list of
            record names (duplicates possible); ``found["enums"]`` maps enum
            name to its symbols.
    """
    if isinstance(obj, list):
        for item in obj:
            walk_schema(item, found)
        return
    if not isinstance(obj, dict):
        return

    t = obj.get("type")
    name = obj.get("name")

    # If it's an enum, record its symbols
    if t == "enum" and name:
        found["enums"][name] = obj.get("symbols", [])

    # If it's a record, record its name and walk each field's type
    if t == "record" and name:
        found["records"].append(name)
        for f in obj.get("fields", []):
            if isinstance(f, dict):
                walk_schema(f.get("type"), found)

    # Container types declare their element/value schema under a separate key;
    # nested records and enums are often defined there.
    if t == "array":
        walk_schema(obj.get("items"), found)
    elif t == "map":
        walk_schema(obj.get("values"), found)

    # "type" itself may be a union (list) or an inline schema (dict)
    if isinstance(t, (list, dict)):
        walk_schema(t, found)

# ----------- main -----------
schema = r"C:\Users\Ali\Desktop\ChatGPT Scripts\TCCDMDatum.json"
path = Path(schema)
# Parse the schema file; `schema` is rebound from the path string to the parsed JSON.
schema = json.loads(path.read_text(encoding="utf-8"))

found = {"records": [], "enums": {}}
walk_schema(schema, found)

# De-duplicate and sort record names for display.
record_names = sorted(set(found["records"]))
print("=== Record Types ===")
for record_name in record_names:
    print(" -", record_name)

print("\n=== Enums and Symbols ===")
for enum_name, symbols in found["enums"].items():
    print(f"{enum_name}: {symbols}")


## Analyze .pt Files

In [None]:
"""
inspect_pt_file.py

Inspect a .pt file from your DARPA E5 dataset to diagnose KeyError in gine_train.py.
Prints top-level keys and detailed contents (tensor shapes, dict structures).
Optionally scans multiple files for key consistency.

Usage:
  python inspect_pt_file.py --pt-file "path/to/sample.pt"
  # or scan all .pt files in a directory:
  python inspect_pt_file.py --data-dir "path/to/win_tensors_balanced"

Requires:
  - Python 3.10+
  - PyTorch >= 2.2
  - PyTorch Geometric >= 2.5
"""

import argparse
from pathlib import Path
from typing import Any, Dict, List
import torch
from torch_geometric.data import Data

def print_tensor_info(name: str, tensor: Any) -> None:
    """Print one indented line for *tensor*: shape/dtype for torch tensors, else value and type."""
    if not isinstance(tensor, torch.Tensor):
        print(f"  {name}: {tensor} (type={type(tensor).__name__})")
        return
    shape = tuple(tensor.shape)
    print(f"  {name}: shape={shape}, dtype={tensor.dtype}")

def inspect_pt_file(pt_path: Path) -> None:
    """Print the structure of a single .pt file.

    PyG ``Data`` objects are listed attribute by attribute (shape/dtype).
    Plain dicts list their top-level keys and contents, with nested dicts
    summarized by their keys. Any other type is reported by name. Load errors
    are printed, not raised.
    """
    print(f"\nInspecting: {pt_path.name}")
    try:
        # weights_only=False: PyTorch >= 2.6 defaults to weights_only=True,
        # which refuses to unpickle PyG Data objects. Trusted files only.
        obj = torch.load(pt_path, map_location="cpu", weights_only=False)
    except Exception as e:
        print(f"Error loading {pt_path.name}: {str(e)}")
        return

    if isinstance(obj, Data):
        print("Top-level keys: PyG Data object")
        print("Attributes in Data object:")
        for key, value in obj.items():
            print_tensor_info(key, value)
    elif isinstance(obj, dict):
        print("Top-level keys:", list(obj.keys()))
        print("Contents of dict:")
        for key, value in obj.items():
            if isinstance(value, dict):
                print(f"  {key}: dict with keys {list(value.keys())}")
            else:
                print_tensor_info(key, value)
    else:
        print(f"Unexpected type: {type(obj).__name__}")

def scan_directory(data_dir: Path) -> None:
    """Report, for each key seen in any .pt file under *data_dir*, which files lack it.

    Top-level keys are collected per file. Plain dicts contribute their dict
    keys. Objects exposing ``keys`` (e.g. PyG ``Data``) contribute those; in
    recent PyG ``keys`` is a method, in older releases a property. Files that
    fail to load, or expose no keys, are reported and excluded from the
    comparison.
    """
    pt_files = sorted(p for p in data_dir.glob("*.pt") if p.is_file())
    if not pt_files:
        print(f"No .pt files found in {data_dir}")
        return

    key_sets: Dict[str, set] = {}
    for pt_file in pt_files:
        try:
            # weights_only=False also allows PyG Data objects; trusted files only.
            obj = torch.load(pt_file, map_location="cpu", weights_only=False)
        except Exception as e:
            print(f"Error loading {pt_file.name}: {str(e)}")
            continue
        raw_keys = getattr(obj, "keys", None)
        if callable(raw_keys):
            raw_keys = raw_keys()
        if raw_keys is None:
            print(f"Skipping {pt_file.name}: object of type {type(obj).__name__} has no keys")
            continue
        key_sets[pt_file.name] = set(raw_keys)

    if not key_sets:
        print("\nNo files could be loaded; nothing to compare.")
        return

    # Compare against the files that actually loaded, not every file found.
    n_loaded = len(key_sets)
    print("\nKey consistency across files:")
    for key in sorted(set.union(*key_sets.values()), key=str):
        files_missing = [fname for fname, fkeys in key_sets.items() if key not in fkeys]
        if files_missing:
            print(f"Key '{key}' missing in {len(files_missing)} files: {', '.join(files_missing[:3])}{'...' if len(files_missing) > 3 else ''}")
        else:
            print(f"Key '{key}' present in all {n_loaded} files")

# Set exactly one of these: a single file to inspect, or a directory to scan.
pt_file = None
data_dir = r"C:\Users\Ali\Desktop\ChatGPT Scripts\win_tensors_balanced"
if pt_file:
    pt_path = Path(pt_file)
    if pt_path.exists():
        inspect_pt_file(pt_path)
    else:
        print(f"File not found: {pt_path}")
elif data_dir:
    data_dir = Path(data_dir)
    if data_dir.is_dir():
        scan_directory(data_dir)
    else:
        print(f"Directory not found: {data_dir}")
else:
    print("Please set pt_file or data_dir")

In [None]:
from pathlib import Path
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import logging

def setup_logging():
    """Configure the root logger at INFO level with a ``[time][LEVEL] message`` format."""
    log_format = "[%(asctime)s][%(levelname)s] %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

def print_tensor_info(name: str, tensor: Any) -> None:
    """Log (INFO, root logger) one line for *tensor*: shape/dtype for tensors, else value and type."""
    if isinstance(tensor, torch.Tensor):
        message = f"  {name}: shape={tuple(tensor.shape)}, dtype={tensor.dtype}"
    else:
        message = f"  {name}: {tensor} (type={type(tensor).__name__})"
    logging.info(message)

def inspect_meta(pt_path: Path) -> None:
    """Log the contents of a dict-style .pt file, expanding its ``meta`` entry.

    Nested dicts inside ``meta`` are summarized by size and first three keys.
    Every other value, inside ``meta`` and at the top level, is logged via
    ``print_tensor_info``. Non-dict payloads are reported by type only.
    Load errors are logged rather than raised.
    """
    logging.info(f"Inspecting: {pt_path.name}")
    try:
        # weights_only=False also allows non-tensor metadata; trusted files only.
        obj = torch.load(pt_path, map_location="cpu", weights_only=False)
    except Exception as e:
        logging.error(f"Error loading {pt_path.name}: {str(e)}")
        return

    # Type check first: non-dict payloads (tensors, lists, ...) may have no .keys().
    if not isinstance(obj, dict):
        logging.info(f"Expected dict, got {type(obj).__name__}")
        return

    logging.info(f"Top-level keys: {list(obj.keys())}")
    meta = obj.get("meta", {})
    if isinstance(meta, dict):
        logging.info("Contents of 'meta':")
        for key, value in meta.items():
            if isinstance(value, dict):
                more = "..." if len(value) > 3 else ""
                logging.info(f"  {key}: dict with {len(value)} keys (e.g., {list(value.keys())[:3]}{more})")
            else:
                print_tensor_info(key, value)
    for key, value in obj.items():
        if key != "meta":
            print_tensor_info(key, value)

def test_batching(data_dir: Path, batch_size: int = 2, max_files: int = 10) -> None:
    """Collate one PyG batch from the first few .pt files to check they batch together.

    Each file must be a dict with ``x``, ``edge_index``, ``edge_attr`` and a
    graph label under ``label`` (or ``y``). The outcome is logged. A failure
    is logged with its traceback instead of being raised.

    Args:
        data_dir: Directory containing ``*.pt`` files (used in sorted order).
        batch_size: Graphs per batch for the PyG ``DataLoader``.
        max_files: Only the first *max_files* files are used (default 10).
    """
    class TempDataset(torch.utils.data.Dataset):
        def __init__(self, data_dir: Path, limit: int):
            self.files = sorted(p for p in data_dir.glob("*.pt") if p.is_file())[:limit]
            if not self.files:
                raise ValueError(f"No .pt files in {data_dir}")

        def __len__(self):
            return len(self.files)

        def __getitem__(self, i: int) -> Data:
            path = self.files[i]
            # weights_only=False also allows non-tensor metadata; trusted files only.
            obj = torch.load(path, map_location="cpu", weights_only=False)
            missing = [k for k in ("x", "edge_index", "edge_attr") if k not in obj]
            if missing:
                raise KeyError(f"{path.name}: missing {missing}")
            lab = obj.get("label", obj.get("y", None))
            if lab is None:
                raise KeyError(f"{path.name}: missing 'label'/'y'")
            # Cast to the dtypes a GNN expects; y is a 0-dim long tensor.
            return Data(
                x=obj["x"].to(torch.float32),
                edge_index=obj["edge_index"].to(torch.long),
                edge_attr=obj["edge_attr"].to(torch.float32),
                y=torch.tensor(int(lab), dtype=torch.long),
            )

    logging.info(f"Testing batching on {data_dir}")
    try:
        dataset = TempDataset(data_dir, max_files)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        batch = next(iter(loader))
        logging.info("Batching successful! Batch details:")
        print_tensor_info("x", batch.x)
        print_tensor_info("edge_index", batch.edge_index)
        print_tensor_info("edge_attr", batch.edge_attr)
        print_tensor_info("y", batch.y)
        print_tensor_info("batch", batch.batch)
    except Exception as e:
        # logging.exception includes the traceback, which is the useful part here.
        logging.exception(f"Batching failed: {str(e)}")

setup_logging()

pt_file = None
data_dir = r"C:\Users\Ali\Desktop\ChatGPT Scripts\win_tensors_balanced"
if pt_file:
    pt_path = Path(pt_file)
    if pt_path.exists():
        inspect_meta(pt_path)
    else:
        logging.error(f"File not found: {pt_path}")
if data_dir:
    data_dir = Path(data_dir)
    pt_files = sorted(data_dir.glob("*.pt")) if data_dir.is_dir() else []
    if not data_dir.is_dir():
        logging.error(f"Directory not found: {data_dir}")
    elif not pt_files:
        logging.error(f"No .pt files found in {data_dir}")
    else:
        # glob() already yields paths under data_dir; re-joining them onto
        # data_dir would duplicate the prefix for relative directories.
        inspect_meta(pt_files[0])  # Inspect first file
        test_batching(data_dir)

## Check Each .pt File in a Directory

In [None]:
import os
import torch
from torch_geometric.data import Data
import glob

def check_pt_file(file_path):
    """Load one .pt file and print a summary of its contents.

    Recognizes PyG ``Data`` objects (attributes with shapes and the label),
    bare tensors (shape/dtype) and dicts (key, type, shape). Anything else is
    reported as unknown. Errors are printed, not raised.
    """
    print(f"\nChecking file: {file_path}")
    try:
        # weights_only=False: PyTorch >= 2.6 defaults to weights_only=True,
        # which refuses to unpickle PyG Data objects. Trusted files only.
        data = torch.load(file_path, map_location="cpu", weights_only=False)
        print(f"  File loaded successfully!")
        print(f"  Type of loaded object: {type(data)}")

        if isinstance(data, Data):
            print("  PyG Data object detected with attributes:")
            # `Data.keys` is a method in recent PyG and a property in older releases.
            keys = data.keys() if callable(data.keys) else data.keys
            for key in keys:
                value = data[key]
                print(f"    {key}: {type(value)}, Shape: {getattr(value, 'shape', 'N/A')}")
            if hasattr(data, 'y'):
                print(f"    Label (y): {data.y}")
        elif isinstance(data, torch.Tensor):
            print(f"  Tensor detected with shape: {data.shape}")
            print(f"  Data type: {data.dtype}")
        elif isinstance(data, dict):
            print("  Dictionary detected with keys:")
            for key, value in data.items():
                print(f"    {key}: {type(value)}, Shape: {getattr(value, 'shape', 'N/A')}")
        else:
            print(f"  Unknown object type: {type(data)}")
    except Exception as e:
        print(f"  Error loading file: {str(e)}")

def main():
    """Run check_pt_file on every ``*.pt`` file in ./win_tensors_balanced."""
    directory = "./win_tensors_balanced"  # Adjust if your files are elsewhere
    pattern = os.path.join(directory, "*.pt")
    pt_files = glob.glob(pattern)

    if pt_files:
        print(f"Found {len(pt_files)} .pt files")
        for file_path in pt_files:
            check_pt_file(file_path)
    else:
        print(f"No .pt files found in {directory}")


if __name__ == "__main__":
    main()

## Checking GINE Compatibility

In [None]:
import torch
from torch_geometric.data import Data

def check_gine_compatibility(file_path):
    """Check whether a .pt graph file provides what a GINE model needs.

    Accepts a PyG ``Data`` object or a plain dict saved with ``torch.save``.
    For dicts, ``label`` is used as ``y`` when ``y`` is absent. The check
    prints shapes/dtypes of ``x``, ``edge_index``, ``edge_attr`` and ``y``, and
    reports these problems:
      - empty graphs;
      - a malformed ``edge_index``;
      - out-of-range node indices;
      - an ``edge_attr`` row count that differs from the edge count.
    Problems are printed, never raised.
    """
    import types

    print(f"\nChecking GINE compatibility for: {file_path}")
    try:
        # weights_only=False: needed for PyG Data objects on PyTorch >= 2.6
        # (default became weights_only=True). Only load trusted files.
        data = torch.load(file_path, map_location="cpu", weights_only=False)
    except Exception as e:
        print(f"  Error loading file: {str(e)}")
        return

    # Dicts are handled before the Data check, so plain-dict files also work.
    if isinstance(data, dict):
        fields = {str(k): v for k, v in data.items()}
        if fields.get("y") is None and "label" in fields:
            fields["y"] = fields["label"]
        data = types.SimpleNamespace(**fields)
    elif not isinstance(data, Data):
        print("  Error: Not a PyTorch Geometric Data object")
        return

    # Check required attributes for GINE (a None value counts as missing)
    required_attrs = ['x', 'edge_index', 'edge_attr', 'y']
    missing_attrs = [attr for attr in required_attrs if getattr(data, attr, None) is None]
    if missing_attrs:
        print(f"  Missing attributes: {missing_attrs}")
        return

    try:
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        y = torch.as_tensor(data.y)  # plain int labels become tensors so .dtype works
        print("  All required attributes present:")
        print(f"    Node features (x): {x.shape}, Type: {x.dtype}")
        print(f"    Edge index (edge_index): {edge_index.shape}, Type: {edge_index.dtype}")
        print(f"    Edge attributes (edge_attr): {edge_attr.shape}, Type: {edge_attr.dtype}")
        print(f"    Label (y): {y}, Type: {y.dtype}")

        # Verify data integrity
        num_nodes = x.shape[0]
        if num_nodes == 0:
            print("  Warning: No nodes (x.shape[0] == 0)")
        if edge_index.dim() != 2 or edge_index.shape[0] != 2:
            print(f"  Error: edge_index must have shape [2, num_edges], got {tuple(edge_index.shape)}")
            return
        num_edges = edge_index.shape[1]
        if num_edges == 0:
            print("  Warning: No edges (edge_index.shape[1] == 0)")
        # min()/max() raise on empty tensors, so indices are only checked when edges exist.
        elif edge_index.min() < 0 or edge_index.max() >= num_nodes:
            print("  Error: Edge indices reference non-existent nodes")
        if edge_attr.shape[0] != num_edges:
            print(f"  Error: edge_attr has {edge_attr.shape[0]} rows but there are {num_edges} edges")
    except Exception as e:
        print(f"  Error checking file: {str(e)}")

def main():
    """Run the GINE compatibility check on one hand-picked .pt file."""
    # Replace with the path of a file that fails during training.
    target = (
        r"C:\Users\Ali\Desktop\ChatGPT Scripts\win_tensors_balanced"
        r"\win_2019-05-07T18-40-00Z__2019-05-07T18-50-00Z__ben_non_malicious__ta1-marple-2-e5-official-1.bin__part-000003.pt"
    )
    check_gine_compatibility(target)


if __name__ == "__main__":
    main()

In [None]:
# Quick interactive look at one window file; the bare `data` on the last
# line makes the notebook display the loaded object.
# NOTE(review): other cells pass weights_only=False to torch.load; on
# PyTorch >= 2.6 this call may need it too if the file holds non-tensor
# objects — confirm.
data = torch.load(r"C:\Users\Ali\Desktop\ChatGPT Scripts\win_tensors_balanced\win_2019-05-07T18-40-00Z__2019-05-07T18-50-00Z__ben_non_malicious__ta1-marple-2-e5-official-1.bin__part-000003.pt")
data

## Checking Dataset Stats

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import torch
from torch_geometric.data import Data
import glob
import logging
import numpy as np

def setup_logging():
    """Configure INFO-level logging to stderr and return this module's logger."""
    console = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s][%(levelname)s] %(message)s",
        handlers=[console],
    )
    return logging.getLogger(__name__)


logger = setup_logging()

def _float_stats(t):
    """Return ``[min, max, mean, std]`` of tensor *t* as floats, or None if *t* is empty.

    The tensor is cast to float64 first: ``mean``/``std`` raise on integer
    dtypes, and ``min``/``max`` raise on empty tensors. ``std`` is torch's
    default (unbiased) estimate, which is NaN for a single element.
    """
    if t.numel() == 0:
        return None
    t = t.detach().to(torch.float64)
    return [t.min().item(), t.max().item(), t.mean().item(), t.std().item()]


def _log_feature_stats(log, label, stats):
    """Log the global min/max and the average per-file mean/std of *stats* rows."""
    if not stats:
        log.info(f"{label}: no non-empty tensors")
        return
    arr = np.array(stats, dtype=np.float64)
    log.info(f"{label}: min={arr[:, 0].min():.4f}, max={arr[:, 1].max():.4f}, "
             f"mean={arr[:, 2].mean():.4f}, avg_std={np.nanmean(arr[:, 3]):.4f}")


def analyze_dataset(data_dir: str):
    """Log label, size and feature statistics for the dict-style .pt files in *data_dir*.

    Each file must be a dict with ``x``, ``edge_index``, ``edge_attr`` and
    ``label``; other files are skipped with a warning. The following are
    logged:
      - class counts and inverse-count class weights;
      - node and edge count summaries;
      - feature ranges: global min/max, plus the unweighted averages of the
        per-file means (``mean``) and per-file standard deviations
        (``avg_std``).
    Empty feature tensors still count toward node/edge counts but are left
    out of the feature statistics.
    """
    log = logging.getLogger(__name__)  # same logger object that setup_logging returns

    pt_files = sorted(glob.glob(os.path.join(data_dir, "*.pt")))
    if not pt_files:
        log.error(f"No .pt files found in {data_dir}")
        return

    labels, node_counts, edge_counts = [], [], []
    x_stats, edge_attr_stats = [], []

    for file_path in pt_files:
        try:
            # weights_only=False also allows non-tensor metadata; trusted files only.
            loaded_data = torch.load(file_path, map_location="cpu", weights_only=False)
        except Exception as e:
            log.warning(f"Failed to load {file_path}: {str(e)}")
            continue
        if not isinstance(loaded_data, dict):
            log.warning(f"Skipping {file_path}: Not a dict")
            continue

        x = loaded_data.get('x')
        edge_index = loaded_data.get('edge_index')
        edge_attr = loaded_data.get('edge_attr')
        label = loaded_data.get('label')
        if x is None or edge_index is None or edge_attr is None or label is None:
            log.warning(f"Skipping {file_path}: Missing required keys")
            continue

        # Compute everything before appending, so a bad file leaves no partial row.
        try:
            lab = int(label)
            n_nodes = int(x.shape[0])
            n_edges = int(edge_index.shape[1])
            xs = _float_stats(x)
            es = _float_stats(edge_attr)
        except Exception as e:
            log.warning(f"Skipping {file_path}: {e}")
            continue

        labels.append(lab)
        node_counts.append(n_nodes)
        edge_counts.append(n_edges)
        if xs is not None:
            x_stats.append(xs)
        if es is not None:
            edge_attr_stats.append(es)

    if not labels:
        log.error("No valid data found")
        return

    labels = np.array(labels)
    if labels.min() < 0:
        # np.bincount rejects negative values; report raw label counts instead.
        values, counts = np.unique(labels, return_counts=True)
        log.info(f"Class distribution (label: count): {dict(zip(values.tolist(), counts.tolist()))}")
    else:
        class_counts = np.bincount(labels)
        log.info(f"Class distribution: {class_counts}")
        log.info(f"Class weights: {1.0 / (class_counts + 1e-6)}")

    node_counts = np.array(node_counts)
    edge_counts = np.array(edge_counts)
    log.info(f"Node counts: min={node_counts.min()}, max={node_counts.max()}, mean={node_counts.mean():.2f}, std={node_counts.std():.2f}")
    log.info(f"Edge counts: min={edge_counts.min()}, max={edge_counts.max()}, mean={edge_counts.mean():.2f}, std={edge_counts.std():.2f}")
    _log_feature_stats(log, "Node features (x)", x_stats)
    _log_feature_stats(log, "Edge features (edge_attr)", edge_attr_stats)

# Script entry point; in a notebook __name__ is "__main__", so this also runs
# when the cell is executed.
if __name__ == "__main__":
    # Directory of per-window .pt tensor files to summarize (hard-coded local path).
    data_dir = r"C:\Users\Ali\Desktop\ChatGPT Scripts\win_tensors_balanced"
    analyze_dataset(data_dir)