In [1]:
!pip install pandas networkx






# Graph Type and Purpose

This notebook constructs a **heterogeneous directed multigraph** with `NetworkX`'s `MultiDiGraph()` to model interactions in network traffic. This design suits cybersecurity applications such as:

- **Graph-based threat detection**
- **Anomaly identification in multi-modal behaviors**
- **Learning embeddings for heterogeneous entities**

## Key Characteristics

- **Heterogeneous nodes**  
  Represents diverse entities: IP addresses, domain names, HTTP URIs, SSL certificate subjects/issuers, protocol violation types, etc.

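- **Multi-view relationships**  
  Edges are keyed by view (`flow`, `dns_query`, `http_request`, ...), so the same ordered node pair can carry several directed edges of different types (e.g., `ssl_subject` and `ssl_issuer` edges to the same DN for a self-signed certificate).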

- **Directed edges**  
  Encode the **direction of each interaction** (e.g., `src_ip ➝ dst_ip`, `IP ➝ domain`), i.e., which side initiated it. Edges carry no timestamps, so temporal order is not modeled.
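A minimal sketch of how `MultiDiGraph` edge keys act as view labels, using made-up addresses and a hypothetical DN. One NetworkX detail matters for the construction code later in this notebook: calling `add_edge` again with a key that already exists between the same pair updates that edge's attributes instead of adding a parallel edge.

```python
import networkx as nx

G = nx.MultiDiGraph()

# Self-signed certificate (hypothetical DN): subject and issuer are the same
# node, so one (IP, DN) pair carries two edges told apart only by their keys.
dn = "/CN=self-signed.example"
G.add_edge("10.0.0.5", dn, key="ssl_subject", ssl_version="TLSv12")
G.add_edge("10.0.0.5", dn, key="ssl_issuer", ssl_version="TLSv12")

# Two flows between the same hosts, both added under key="flow":
# the second call updates the existing edge instead of adding a parallel one.
G.add_edge("10.0.0.5", "10.0.0.9", key="flow", src_bytes=120)
G.add_edge("10.0.0.5", "10.0.0.9", key="flow", src_bytes=300)

print(G.number_of_edges())                      # 3
print(sorted(G["10.0.0.5"][dn]))                # ['ssl_issuer', 'ssl_subject']
print(G.edges["10.0.0.5", "10.0.0.9", "flow"])  # {'src_bytes': 300}
```

If every flow should survive as its own edge, one option is to omit `key` (NetworkX then assigns integer keys 0, 1, ...) and store the view name in an edge attribute instead.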

# Node Types (Entities)

Each node represents a real-world entity, extracted from one or more dataset columns:

| Node Type         | Source Column(s)    | Description                                                                 |
|-------------------|---------------------|-----------------------------------------------------------------------------|
| **IP Address**     | `src_ip`, `dst_ip`  | Devices or interfaces on the network (e.g., `192.168.1.37`).                |
| **Domain Name**    | `dns_query`         | Fully qualified domain names queried by IPs (e.g., `www.example.com`).      |
| **HTTP URI**       | `http_uri`          | HTTP resource paths (e.g., `/login`, `/index.html`).                        |
| **SSL Subject**    | `ssl_subject`       | Distinguished Name of the certificate subject (e.g., `/C=US/O=Let's Encrypt`). |
| **SSL Issuer**     | `ssl_issuer`        | Distinguished Name of the certificate issuer (e.g., `/C=US/O=Google Trust Services`). |
| **Protocol Violation** | `weird_name`     | Descriptive label of detected anomalies (e.g., `bad_TCP_checksum`).         |
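The table can be turned into a small typing helper. This is a sketch, not part of the notebook: the notebook's graph stores no node-type attribute, and the type names in `COLUMN_TO_TYPE` are illustrative. `is_ip` uses the standard-library `ipaddress` module to recognize IPv4/IPv6 literals, which is stricter than checking for a dot.

```python
import ipaddress
import math

# Column -> node type, following the table above (type names are illustrative).
COLUMN_TO_TYPE = {
    "src_ip": "ip",
    "dst_ip": "ip",
    "dns_query": "domain",
    "http_uri": "http_uri",
    "ssl_subject": "ssl_subject",
    "ssl_issuer": "ssl_issuer",
    "weird_name": "protocol_violation",
}


def is_ip(value) -> bool:
    """True for IPv4/IPv6 literals; False for domains, URIs, DNs and labels."""
    try:
        ipaddress.ip_address(str(value))
        return True
    except ValueError:
        return False


def typed_nodes(row: dict) -> dict:
    """Return {node: node_type} for the non-missing entity columns of one row."""
    nodes = {}
    for column, node_type in COLUMN_TO_TYPE.items():
        value = row.get(column)
        if value is None or (isinstance(value, float) and math.isnan(value)):
            continue  # missing value: no node for this column
        nodes[value] = node_type
    return nodes


row = {"src_ip": "192.168.1.37", "dst_ip": "8.8.8.8",
       "dns_query": "www.example.com", "http_uri": float("nan")}
print(typed_nodes(row))
print(is_ip("192.168.1.37"), is_ip("www.example.com"), is_ip("fe80::1"))  # True False True
```

Applied to an existing graph, `[n for n in G.nodes if is_ip(n)]` would select IP nodes only, including IPv6 addresses.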

---

# Edge Types (Views)

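Each directed edge represents an interaction or behavioral relationship, often enriched with protocol metadata. Besides the four views below, the construction code also adds `ssl_subject` and `ssl_issuer` edges (IP ➝ certificate subject/issuer DN) with `ssl_version`, `ssl_cipher`, `ssl_resumed`, and `ssl_established` attributes.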

## 1. `flow` — (IP ➝ IP)

Represents a network flow between two IP addresses.

- **Source:** `src_ip`  
- **Target:** `dst_ip`  
- **Attributes:**
  - `proto`, `service`, `duration`, `conn_state`
  - `src_bytes`, `dst_bytes`
  - `label`, `attack_type` (taken from the dataset's `type` column)

**Usefulness:**  
Defines the **structural backbone** of the graph, enabling analysis of traffic patterns and attack topologies.
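To look at traffic patterns on this backbone, the `flow` view can be collapsed into a weighted `DiGraph`. A minimal sketch on a toy graph with made-up hosts and byte counts; summing `src_bytes + dst_bytes` per pair is one assumed choice of weight.

```python
import networkx as nx


def flow_traffic_graph(G: nx.MultiDiGraph) -> nx.DiGraph:
    """Collapse G's 'flow' edges into a DiGraph whose 'bytes' weight sums
    src_bytes + dst_bytes over all flow edges for each (src, dst) pair."""
    H = nx.DiGraph()
    for u, v, k, d in G.edges(keys=True, data=True):
        if k != "flow":
            continue  # other views (dns_query, http_request, ...) are ignored
        nbytes = d.get("src_bytes", 0) + d.get("dst_bytes", 0)
        if H.has_edge(u, v):  # only happens if flows are stored as parallel edges
            H[u][v]["bytes"] += nbytes
        else:
            H.add_edge(u, v, bytes=nbytes)
    return H


# Toy multigraph with made-up hosts and byte counts.
G = nx.MultiDiGraph()
G.add_edge("10.0.0.1", "10.0.0.9", key="flow", src_bytes=500, dst_bytes=100)
G.add_edge("10.0.0.2", "10.0.0.9", key="flow", src_bytes=50, dst_bytes=10)
G.add_edge("10.0.0.1", "www.example.com", key="dns_query")

H = flow_traffic_graph(G)
# Rank hosts by total bytes on the flows they initiated (weighted out-degree).
top_talkers = sorted(H.nodes, key=lambda n: H.out_degree(n, weight="bytes"), reverse=True)
print(top_talkers[0], H.out_degree(top_talkers[0], weight="bytes"))  # 10.0.0.1 600
```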

## 2. `dns_query` — (IP ➝ Domain Name)

Represents a DNS lookup initiated by a host.

- **Source:** `src_ip`  
- **Target:** `dns_query`  
- **Attributes:**
  - `qclass`, `qtype`, `rcode`
  - `dns_AA`, `dns_RD`, `dns_RA`, `dns_rejected`

**Usefulness:**  
Reveals **host intent** and can indicate access to suspicious or malicious domains.
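A sketch of the "suspicious domain" use case: map each domain on a watchlist to the hosts that queried it. The watchlist is a hypothetical stand-in for a threat-intelligence feed, and the graph is a toy example.

```python
import networkx as nx


def hosts_querying(G: nx.MultiDiGraph, watchlist: set) -> dict:
    """Map each watch-listed domain to the set of source IPs that have a
    dns_query edge pointing at it."""
    hits = {}
    for ip, domain, k in G.edges(keys=True):
        if k == "dns_query" and domain in watchlist:
            hits.setdefault(domain, set()).add(ip)
    return hits


# Toy graph; the watchlist entries are made up.
G = nx.MultiDiGraph()
G.add_edge("10.0.0.1", "bad.example", key="dns_query", qtype=1, rcode=0)
G.add_edge("10.0.0.2", "bad.example", key="dns_query", qtype=1, rcode=0)
G.add_edge("10.0.0.3", "www.example.com", key="dns_query", qtype=1, rcode=0)

hits = hosts_querying(G, watchlist={"bad.example", "c2.example"})
print({domain: sorted(ips) for domain, ips in hits.items()})
# {'bad.example': ['10.0.0.1', '10.0.0.2']}
```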

## 3. `http_request` — (IP ➝ HTTP URI)

Captures web resource requests made by a host.

- **Source:** `src_ip`  
- **Target:** `http_uri`  
- **Attributes:**
  - `method`, `version`, `status_code`
  - `trans_depth`, `req_body_len`, `resp_body_len`
  - `user_agent`, `orig_mime`, `resp_mime`

**Usefulness:**  
Reflects **web behavior**; useful for detecting scanning, reconnaissance, and probing activity.
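One simple signal for scanning is URI fan-out: a host requesting many distinct paths, many of which return 404. A sketch under assumptions: the threshold `min_distinct=3` is hypothetical and would need tuning on real data, and the requests below are made up.

```python
from collections import defaultdict

import networkx as nx


def uri_fanout(G: nx.MultiDiGraph, min_distinct: int = 3) -> dict:
    """Return {ip: (distinct URIs, 404 responses)} for IPs whose http_request
    edges reach at least `min_distinct` distinct URIs."""
    uris = defaultdict(set)
    not_found = defaultdict(int)
    for ip, uri, k, d in G.edges(keys=True, data=True):
        if k != "http_request":
            continue
        uris[ip].add(uri)
        if d.get("status_code") == 404:
            not_found[ip] += 1
    return {ip: (len(s), not_found[ip]) for ip, s in uris.items() if len(s) >= min_distinct}


# Toy graph: one host probes several paths, another fetches a single page.
G = nx.MultiDiGraph()
for path, status in [("/admin", 404), ("/.env", 404), ("/wp-login.php", 404), ("/index.html", 200)]:
    G.add_edge("10.0.0.7", path, key="http_request", method="GET", status_code=status)
G.add_edge("10.0.0.8", "/index.html", key="http_request", method="GET", status_code=200)

print(uri_fanout(G))  # {'10.0.0.7': (4, 3)}
```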

## 4. `protocol_violation` — (IP ➝ Violation Label)

Links an IP to a protocol anomaly observed during communication.

- **Source:** `src_ip`  
- **Target:** `weird_name`  
- **Attributes:**
  - `weird_addl`, `weird_notice`

**Usefulness:**  
Highlights **anomalous or misconfigured hosts**. Such events can be early indicators of compromise or malicious activity, although misconfigured but benign hosts produce them too.
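Because violations and flows live in the same graph, the two views can be cross-checked per host. A sketch on a toy graph: for every IP with `protocol_violation` edges, it lists the violation names and whether any of the host's outgoing flows carries `label == 1` (the dataset's attack label).

```python
import networkx as nx


def violation_report(G: nx.MultiDiGraph) -> dict:
    """For each IP with protocol_violation edges, return (sorted violation
    names, whether any of its outgoing flow edges has label == 1)."""
    report = {}
    for ip in G.nodes:
        out = list(G.out_edges(ip, keys=True, data=True))
        violations = sorted(v for _, v, k, _ in out if k == "protocol_violation")
        if violations:
            attack_flow = any(k == "flow" and d.get("label") == 1 for _, _, k, d in out)
            report[ip] = (violations, attack_flow)
    return report


# Toy graph with made-up hosts.
G = nx.MultiDiGraph()
G.add_edge("10.0.0.3", "bad_TCP_checksum", key="protocol_violation", weird_notice=False)
G.add_edge("10.0.0.3", "10.0.0.9", key="flow", label=1)
G.add_edge("10.0.0.4", "10.0.0.9", key="flow", label=0)

print(violation_report(G))  # {'10.0.0.3': (['bad_TCP_checksum'], True)}
```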

# Semantic Graph Properties

- **IP nodes are central:**  
  Every edge type starts at an IP (`src_ip`), and `flow` edges also end at one, so IP addresses dominate the graph topology.

- **Multi-modal behavioral modeling:**  
  Combines flow, DNS, HTTP, SSL, and protocol-violation information into one unified representation.

- **Multi-view learning ready:**  
  The graph supports training models on **protocol-specific subgraphs or jointly across views**.

- **Directional interpretation:**  
  Directed edges preserve **who initiated the interaction**, enabling traceability and behavioral profiling.
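Protocol-specific subgraphs can be read straight off the edge keys. A minimal sketch using `MultiDiGraph.edge_subgraph`, which takes `(u, v, key)` triples and returns a read-only view (no data is copied); the toy graph is made up.

```python
import networkx as nx


def split_views(G: nx.MultiDiGraph) -> dict:
    """Return {view key: read-only subgraph view holding only that view's edges}."""
    edges_by_view = {}
    for u, v, k in G.edges(keys=True):
        edges_by_view.setdefault(k, []).append((u, v, k))
    # MultiDiGraph.edge_subgraph expects (u, v, key) triples.
    return {k: G.edge_subgraph(edges) for k, edges in edges_by_view.items()}


G = nx.MultiDiGraph()
G.add_edge("10.0.0.1", "10.0.0.9", key="flow")
G.add_edge("10.0.0.1", "www.example.com", key="dns_query")
G.add_edge("10.0.0.2", "www.example.com", key="dns_query")

views = split_views(G)
for name, sub in sorted(views.items()):
    print(name, sub.number_of_nodes(), sub.number_of_edges())
# dns_query 3 2
# flow 2 1
```

Each view can then be analysed on its own, or the views can be handed together to a heterogeneous model.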


## 1. Community Detection
Apply clustering or community detection algorithms on specific views:
- flow → group IPs that communicate frequently
- dns_query → group IPs that query similar domains (may reveal coordinated beaconing)
- http_request → group clients that request similar URLs
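A sketch of both flavours, assuming NetworkX 2.8 or later (which provides `louvain_communities`). Louvain is one possible algorithm; the text does not prescribe one. The flow view is treated as undirected. For `dns_query` or `http_request`, the IP-to-target bipartite graph is projected onto the IPs with `bipartite.weighted_projected_graph`, so two IPs are linked with weight equal to the number of targets they share. The toy data is made up.

```python
import networkx as nx
from networkx.algorithms import bipartite
from networkx.algorithms.community import louvain_communities


def flow_communities(G: nx.MultiDiGraph, seed: int = 42) -> list:
    """Louvain communities on the flow view, with direction ignored."""
    H = nx.Graph()
    H.add_edges_from((u, v) for u, v, k in G.edges(keys=True) if k == "flow")
    return louvain_communities(H, seed=seed)


def shared_target_communities(G: nx.MultiDiGraph, view: str, seed: int = 42) -> list:
    """Group source IPs that reach the same targets in `view` (e.g. 'dns_query'
    or 'http_request'): project the IP-target bipartite graph onto the IPs,
    weighting each IP pair by the number of shared targets, then run Louvain."""
    B = nx.Graph()
    sources = set()
    for u, v, k in G.edges(keys=True):
        if k == view:
            B.add_edge(u, v)
            sources.add(u)
    P = bipartite.weighted_projected_graph(B, sources)
    return louvain_communities(P, weight="weight", seed=seed)


G = nx.MultiDiGraph()
# Two tightly connected groups of hosts joined by a single flow.
for a, b in [("10.0.0.1", "10.0.0.2"), ("10.0.0.2", "10.0.0.3"), ("10.0.0.3", "10.0.0.1"),
             ("10.0.1.1", "10.0.1.2"), ("10.0.1.2", "10.0.1.3"), ("10.0.1.3", "10.0.1.1"),
             ("10.0.0.3", "10.0.1.1")]:
    G.add_edge(a, b, key="flow")
# Two hosts querying the same made-up domains, plus one unrelated host.
for ip in ("10.0.0.1", "10.0.0.2"):
    for domain in ("c2.example", "beacon.example"):
        G.add_edge(ip, domain, key="dns_query")
G.add_edge("10.0.1.1", "www.example.com", key="dns_query")

print(sorted(sorted(c) for c in flow_communities(G)))
print(sorted(sorted(c) for c in shared_target_communities(G, "dns_query")))
```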

## 2. Node Centrality Analysis
Compute betweenness centrality, eigenvector centrality, or PageRank on:
- Flow view → who routes/relays the most traffic?
- DNS view → which domains are queried the most?
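A minimal sketch of both questions on a toy graph. Each view is collapsed into a `DiGraph` whose `weight` counts parallel edges. Betweenness here counts shortest directed paths through a node in the flow graph, not bytes relayed. Domain popularity is the weighted in-degree of domain nodes. `nx.pagerank(flow)` could be swapped in for betweenness; recent NetworkX versions need SciPy for it, so it is not called here.

```python
import networkx as nx


def view_digraph(G: nx.MultiDiGraph, view: str) -> nx.DiGraph:
    """Collapse one view into a DiGraph; 'weight' counts parallel edges per pair."""
    H = nx.DiGraph()
    for u, v, k in G.edges(keys=True):
        if k == view:
            w = H[u][v]["weight"] + 1 if H.has_edge(u, v) else 1
            H.add_edge(u, v, weight=w)
    return H


# Toy graph: three hosts send flows to a relay, which forwards to a server.
G = nx.MultiDiGraph()
for host in ("10.0.0.1", "10.0.0.2", "10.0.0.3"):
    G.add_edge(host, "10.0.0.254", key="flow")
    G.add_edge(host, "popular.example", key="dns_query")
G.add_edge("10.0.0.254", "10.0.0.100", key="flow")
G.add_edge("10.0.0.1", "rare.example", key="dns_query")

flow = view_digraph(G, "flow")
# Betweenness: normalized fraction of shortest directed paths passing through a node.
betweenness = nx.betweenness_centrality(flow)
print(max(betweenness, key=betweenness.get))  # 10.0.0.254

dns = view_digraph(G, "dns_query")
# Domain popularity: weighted in-degree = number of query edges per domain.
popularity = {n: dns.in_degree(n, weight="weight") for n in dns if dns.in_degree(n) > 0}
print(sorted(popularity.items(), key=lambda kv: -kv[1]))
# [('popular.example', 3), ('rare.example', 1)]
```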


## TODO: DIMKA
Write the code for the two analyses above: community detection (flow, dns_query, and http_request views) and node centrality analysis (flow and DNS views).

In [3]:
!pip install torch torch_geometric






In [14]:
import torch
import pandas as pd
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder, StandardScaler
import networkx as nx

df = pd.read_csv("../datasets/train_test_network.csv")

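G = nx.MultiDiGraph()

# Each view is stored under a fixed edge key ("flow", "dns_query", ...). NetworkX updates an
# existing (u, v, key) edge instead of adding a parallel one, so repeated rows for the same
# node pair and view collapse into one edge that keeps the last row's attributes.
for _, row in df.iterrows():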
    src_ip = row['src_ip']
    dst_ip = row['dst_ip']

    G.add_edge(
        src_ip, dst_ip,
        key="flow",
        proto=row.get("proto"),
        service=row.get("service"),
        duration=row.get("duration"),
        src_bytes=row.get("src_bytes"),
        dst_bytes=row.get("dst_bytes"),
        conn_state=row.get("conn_state"),
        label=row.get("label"),
        attack_type=row.get("type")
    )

    if pd.notna(row.get("dns_query")):
        dns_domain = row["dns_query"]
        G.add_edge(
            src_ip, dns_domain,
            key="dns_query",
            qclass=row.get("dns_qclass"),
            qtype=row.get("dns_qtype"),
            rcode=row.get("dns_rcode"),
            dns_AA=row.get("dns_AA"),
            dns_RD=row.get("dns_RD"),
            dns_RA=row.get("dns_RA"),
            dns_rejected=row.get("dns_rejected")
        )

    if pd.notna(row.get("http_uri")):
        http_target = row["http_uri"]
        G.add_edge(
            src_ip, http_target,
            key="http_request",
            method=row.get("http_method"),
            version=row.get("http_version"),
            status_code=row.get("http_status_code"),
            trans_depth=row.get("http_trans_depth"),
            req_body_len=row.get("http_request_body_len"),
            resp_body_len=row.get("http_response_body_len"),
            user_agent=row.get("http_user_agent"),
            orig_mime=row.get("http_orig_mime_types"),
            resp_mime=row.get("http_resp_mime_types")
        )

    if pd.notna(row.get("ssl_subject")):
        G.add_edge(
            src_ip, row["ssl_subject"],
            key="ssl_subject",
            ssl_version=row.get("ssl_version"),
            ssl_cipher=row.get("ssl_cipher"),
            ssl_resumed=row.get("ssl_resumed"),
            ssl_established=row.get("ssl_established")
        )

    if pd.notna(row.get("ssl_issuer")):
        G.add_edge(
            src_ip, row["ssl_issuer"],
            key="ssl_issuer",
            ssl_version=row.get("ssl_version"),
            ssl_cipher=row.get("ssl_cipher"),
            ssl_resumed=row.get("ssl_resumed"),
            ssl_established=row.get("ssl_established")
        )

    if pd.notna(row.get("weird_name")):
        G.add_edge(
            src_ip, row["weird_name"],
            key="protocol_violation",
            weird_addl=row.get("weird_addl"),
            weird_notice=row.get("weird_notice")
        )

print(f"Graph built with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")
print("Edge types (views) include:", set(k for _, _, k in G.edges(keys=True)))

Graph built with 1605 nodes and 2554 edges.
Edge types (views) include: {'dns_query', 'protocol_violation', 'http_request', 'ssl_subject', 'ssl_issuer', 'flow'}


In [17]:
label_counts = df['label'].value_counts()
print(f"Label 0 (Normal): {label_counts.get(0, 0)}")
print(f"Label 1 (Attack): {label_counts.get(1, 0)}")

Label 0 (Normal): 50000
Label 1 (Attack): 161043


In [8]:
from torch_geometric.nn import SAGEConv
import torch
import torch.nn.functional as F

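# Convert the IP part of G (a MultiDiGraph) to PyG format.
# Heuristic IP filter: any string node containing '.'. This also admits domain names and any
# URI or SSL DN with a dot, and drops IPv6 addresses; ipaddress.ip_address() is stricter.
ip_nodes = [n for n in G.nodes if isinstance(n, str) and '.' in n]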

node_to_idx = {node: i for i, node in enumerate(ip_nodes)}
edge_index = []

features = []
labels = []

for node in ip_nodes:
    out_deg = len([1 for _, _, k in G.out_edges(node, keys=True) if k == "flow"])
    in_deg = len([1 for _, _, k in G.in_edges(node, keys=True) if k == "flow"])

    features.append([in_deg, out_deg])

    # Node label: "Attack" if any outgoing flow edge has label == 1
    # (the dataset's `label` column is 0 = normal, 1 = attack), else "Normal".
    is_attack = any(
        k == "flow" and d.get("label") == 1
        for _, _, k, d in G.out_edges(node, keys=True, data=True)
    )
    labels.append("Attack" if is_attack else "Normal")

# Encode features and labels
X = StandardScaler().fit_transform(features)
y = LabelEncoder().fit_transform(labels)  # classes sorted alphabetically: Attack -> 0, Normal -> 1
X = torch.tensor(X, dtype=torch.float)
y = torch.tensor(y, dtype=torch.long)

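# Build edge index. G.edges() on a MultiDiGraph yields one (u, v) pair per keyed edge,
# so this keeps every view (not only "flow") between nodes that are in node_to_idx.
for u, v in G.edges():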
    if u in node_to_idx and v in node_to_idx:
        edge_index.append([node_to_idx[u], node_to_idx[v]])

edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

# Define PyG Data
data = Data(x=X, edge_index=edge_index, y=y)

# Split train/test
torch.manual_seed(42)
num_nodes = data.num_nodes
perm = torch.randperm(num_nodes)
train_idx = perm[:int(0.8 * num_nodes)]
test_idx = perm[int(0.8 * num_nodes):]

data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.train_mask[train_idx] = True
data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask[test_idx] = True


# GNN Architectures
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x


class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=4):
        super().__init__()
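        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)  # concatenated: hidden_channels * heads
        # Single head on the output layer so it returns out_channels logits
        # (heads=4 with the default concat=True would return out_channels * 4).
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1)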

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x


class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x


# Training and Evaluation
def train(model, data, epochs=100, lr=0.01):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        out = model(data.x, data.edge_index)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
        pred = logits[data.test_mask].argmax(dim=1)
        true = data.y[data.test_mask]
        # LabelEncoder sorts classes alphabetically, so index 0 = Attack and 1 = Normal.
        report = classification_report(true.cpu(), pred.cpu(), target_names=["Attack", "Normal"],
                                       output_dict=True, zero_division=0)
        return report


# Experiment configurations
configs = [
    ("GCN", GCN(2, 16, 2)),
    ("GCN_deep", GCN(2, 64, 2)),  # wider hidden layer (64 units); still two GCNConv layers
    ("GAT", GAT(2, 8, 2, heads=4)),
    ("GraphSAGE", GraphSAGE(2, 16, 2)),
]

for name, model in configs:
    report = train(model, data, epochs=100)
    acc = report['accuracy']
    prec = report['weighted avg']['precision']
    rec = report['weighted avg']['recall']
    f1 = report['weighted avg']['f1-score']
    print(f"{name}, {acc:.4f}, {prec:.4f}, {rec:.4f}, {f1:.4f}")

GCN, 0.9849, 0.9700, 0.9849, 0.9774
GCN_deep, 0.9849, 0.9700, 0.9849, 0.9774
GAT, 0.9849, 0.9700, 0.9849, 0.9774
GraphSAGE, 0.9887, 0.9888, 0.9887, 0.9853
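The GCN, GCN_deep and GAT rows are identical, and their numbers match a constant majority-class predictor. If the majority class makes up a share p of the test nodes, always predicting it gives weighted recall = accuracy = p, weighted precision = p², and weighted F1 = 2p²/(1+p). For p = 0.9849 these evaluate to 0.9700 and 0.9774, the values printed above. A small sketch for computing that baseline; the toy labels are made up, and passing `data.y[data.test_mask].tolist()` would give the baseline for this notebook's split.

```python
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


def constant_baseline(y_true):
    """Accuracy and weighted precision/recall/F1 of a model that always
    predicts the most frequent class in y_true."""
    y_true = list(y_true)
    majority = max(set(y_true), key=y_true.count)
    y_pred = [majority] * len(y_true)
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0
    )
    return accuracy_score(y_true, y_pred), prec, rec, f1


# Toy labels with a 98% majority class.
acc, prec, rec, f1 = constant_baseline([1] * 98 + [0] * 2)
print(f"{acc:.4f}, {prec:.4f}, {rec:.4f}, {f1:.4f}")  # 0.9800, 0.9604, 0.9800, 0.9701
```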


## Code for Holdouts & Cross-validation

In [9]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x


def train(model, data, train_mask, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()


from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score


def test(model, data, test_mask):
    model.eval()
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
        preds = logits[test_mask].argmax(dim=1).cpu()
        labels = data.y[test_mask].cpu()

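    # precision/recall/F1 use the default pos_label=1, i.e. the "Normal" class
    # (LabelEncoder mapped Attack -> 0, Normal -> 1); pass pos_label=0 to score Attack.
    return (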
        accuracy_score(labels, preds),
        precision_score(labels, preds, zero_division=0),
        recall_score(labels, preds, zero_division=0),
        f1_score(labels, preds, zero_division=0),
    )


from sklearn.model_selection import train_test_split, StratifiedKFold


def run_holdout(data, test_sizes=None):
    if test_sizes is None:
        test_sizes = [0.1, 0.3, 0.5]
    results = []
    y = data.y.cpu().numpy()

    for test_size in test_sizes:
        train_idx, test_idx = train_test_split(
            range(len(y)), test_size=test_size, stratify=y, random_state=42
        )
        train_mask = torch.zeros(len(y), dtype=torch.bool)
        test_mask = torch.zeros(len(y), dtype=torch.bool)
        train_mask[train_idx] = True
        test_mask[test_idx] = True

        model = GraphSAGE(data.num_node_features, 32, int(data.y.max().item()) + 1).to(data.x.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
        criterion = torch.nn.CrossEntropyLoss()

        for epoch in range(100):
            train(model, data, train_mask, optimizer, criterion)

        acc, prec, rec, f1 = test(model, data, test_mask)
        label = f"{int((1 - test_size) * 100)}/{int(test_size * 100)}"
        results.append((label, acc, prec, rec, f1))
    return results


def run_cv(data, splits=None):
    if splits is None:
        splits = [5, 10]
    results = []
    X = data.x.cpu().numpy()
    y = data.y.cpu().numpy()

    for k in splits:
        skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)
        accs, precs, recs, f1s = [], [], [], []

        for train_idx, test_idx in skf.split(X, y):
            train_mask = torch.zeros(len(y), dtype=torch.bool)
            test_mask = torch.zeros(len(y), dtype=torch.bool)
            train_mask[train_idx] = True
            test_mask[test_idx] = True

            model = GraphSAGE(data.num_node_features, 32, int(data.y.max().item()) + 1).to(data.x.device)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
            criterion = torch.nn.CrossEntropyLoss()

            for epoch in range(100):
                train(model, data, train_mask, optimizer, criterion)

            acc, prec, rec, f1 = test(model, data, test_mask)
            accs.append(acc)
            precs.append(prec)
            recs.append(rec)
            f1s.append(f1)

        results.append((str(k), sum(accs) / k, sum(precs) / k, sum(recs) / k, sum(f1s) / k))
    return results


# Run evaluations
holdout_results = run_holdout(data)
cv_results = run_cv(data)

print("Split/CV,Accuracy,Precision,Recall,F1-score")
for r in holdout_results + cv_results:
    print(f"{r[0]},{r[1]:.4f},{r[2]:.4f},{r[3]:.4f},{r[4]:.4f}")

Split/CV,Accuracy,Precision,Recall,F1-score
90/10,0.9925,0.9924,1.0000,0.9962
70/30,0.9950,0.9949,1.0000,0.9974
50/50,0.9909,0.9924,0.9985,0.9954
5,0.9909,0.9909,1.0000,0.9954
10,0.9886,0.9909,0.9977,0.9943
