In [None]:
# ==============================================================================
# Step 0: Install Dependencies
# ==============================================================================
# This notebook requires PyTorch Geometric and sentence-transformers.
!pip install ipywidgets pandas torch_geometric matplotlib sentence-transformers -q

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HGTConv, Linear
from torch_geometric.data import HeteroData
import networkx as nx
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
from sklearn.manifold import TSNE
import warnings
warnings.filterwarnings("ignore")

print(f"PyTorch Version: {torch.__version__}")

In [None]:
# ==============================================================================
# Step 1: Generate the Temporal Network Data
# ==============================================================================
# We will create a network of 20 devices with a failure that spreads.

NUM_NODES = 20
NUM_TIMESTEPS = 15
DEVICE_TYPES = ['router', 'switch', 'firewall']
SEED = 42

# Seed numpy's legacy global RNG once. The device types, the reconnection seeds and every
# CPU/memory sample below are drawn from it; without this, each Restart & Run All produced
# a different network, patient zero and training data.
np.random.seed(SEED)

# Create a random graph structure; resample until connected so the failure can reach any node.
G = nx.erdos_renyi_graph(NUM_NODES, p=0.2, seed=SEED)
while not nx.is_connected(G):
    G = nx.erdos_renyi_graph(NUM_NODES, p=0.2, seed=np.random.randint(1000))

# Assign device types and initial states.
# Row position i == node id i; the .iloc/.at lookups below rely on this.
node_data = [
    {'id': i, 'type': np.random.choice(DEVICE_TYPES), 'status': 'normal'}
    for i in range(NUM_NODES)
]
df_nodes = pd.DataFrame(node_data)

# --- Simulate Failure Propagation ---
# From T=FAILURE_START_TIME the first switch fails. One step later its direct neighbours become
# 'impacted'. Both states are sticky, and 'impacted' does not spread any further.
FAILURE_START_TIME = 5
switch_ids = df_nodes.loc[df_nodes['type'] == 'switch', 'id']
if switch_ids.empty:
    raise ValueError("No 'switch' device was generated, so there is no patient zero; change SEED or NUM_NODES.")
PATIENT_ZERO_ID = int(switch_ids.iloc[0])

# Per-status metric ranges ((cpu_lo, cpu_hi), (mem_lo, mem_hi)) and the log line emitted.
STATUS_PROFILES = {
    'failure': ((0.9, 1.0), (0.7, 0.9), "CRITICAL: High CPU Load - Port Flapping"),
    'impacted': ((0.6, 0.8), (0.5, 0.7), "WARNING: High Latency Detected - Potential Congestion"),
    'normal': ((0.1, 0.3), (0.2, 0.4), "INFO: Nominal Operation - BGP Session Active"),
}

temporal_snapshots = []

for t in range(NUM_TIMESTEPS):
    # 1) Update node states in id order, reading the live table.
    for i in range(NUM_NODES):
        if i == PATIENT_ZERO_ID and t >= FAILURE_START_TIME:
            df_nodes.at[i, 'status'] = 'failure'
        elif df_nodes.at[i, 'status'] == 'normal' and t > FAILURE_START_TIME:
            # A normal node becomes impacted once any direct neighbour is in failure.
            if any(df_nodes.at[n, 'status'] == 'failure' for n in G.neighbors(i)):
                df_nodes.at[i, 'status'] = 'impacted'

    # 2) Sample this timestep's metrics and log text from each node's status profile.
    #    Draw order (cpu then mem, node by node) is what the seeded RNG reproduces.
    cpu_usage, mem_usage, status_text = [], [], []
    for status in df_nodes['status']:
        (cpu_lo, cpu_hi), (mem_lo, mem_hi), text = STATUS_PROFILES[status]
        cpu_usage.append(np.random.uniform(cpu_lo, cpu_hi))
        mem_usage.append(np.random.uniform(mem_lo, mem_hi))
        status_text.append(text)

    temporal_snapshots.append({
        'cpu': np.array(cpu_usage),
        'mem': np.array(mem_usage),
        'text': status_text,
        'status': df_nodes['status'].copy().values,
    })

print(f"Generated {len(temporal_snapshots)} temporal snapshots for {NUM_NODES} devices.")
print(f"Failure introduced at T={FAILURE_START_TIME} on Node {PATIENT_ZERO_ID} (a {df_nodes.iloc[PATIENT_ZERO_ID]['type']}).")

In [None]:
# ==============================================================================
# Step 1.5: Visualize the Demo Network Topology
# ==============================================================================
# Create a visual representation of the network graph with device types

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Create network visualization
plt.figure(figsize=(14, 10))

# Spring layout; the fixed seed keeps node positions identical across re-runs
pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

# Define device type colors, marker shapes and sizes
device_colors = {'router': '#FF6B6B', 'switch': '#4ECDC4', 'firewall': '#45B7D1'}
device_shapes = {'router': 's', 'switch': 'o', 'firewall': '^'}
device_sizes = {'router': 800, 'switch': 600, 'firewall': 700}
# Human-readable names for the matplotlib marker codes, shown in the legend
shape_names = {'s': 'square', 'o': 'circle', '^': 'triangle'}
type_abbrevs = {'router': 'R', 'switch': 'S', 'firewall': 'F'}

# Draw edges first so they appear behind nodes
nx.draw_networkx_edges(G, pos, edge_color='lightgray', width=2, alpha=0.6)

# Draw nodes one device type at a time (networkx takes a single marker shape per call)
for device_type in DEVICE_TYPES:
    nodes_of_type = df_nodes.loc[df_nodes['type'] == device_type, 'id'].tolist()
    if nodes_of_type:
        nx.draw_networkx_nodes(G, pos,
                               nodelist=nodes_of_type,
                               node_color=device_colors[device_type],
                               node_shape=device_shapes[device_type],
                               node_size=device_sizes[device_type],
                               alpha=0.8,
                               edgecolors='black',
                               linewidths=2)

# Highlight patient zero with a star drawn on top of its node
patient_zero_pos = pos[PATIENT_ZERO_ID]
plt.scatter(patient_zero_pos[0], patient_zero_pos[1],
            s=1000, c='red', marker='*',
            edgecolors='darkred', linewidths=3,
            alpha=0.9, zorder=10, label='Patient Zero')

# Label each node with its id and device-type abbreviation
labels = {
    node_id: f"{node_id}\n({type_abbrevs[df_nodes.iloc[node_id]['type']]})"
    for node_id in G.nodes()
}
nx.draw_networkx_labels(G, pos, labels, font_size=8, font_weight='bold')

# Legend: one colour patch per device type (with its marker shape spelled out) + patient zero
legend_elements = [
    mpatches.Patch(color=device_colors[device_type],
                   label=f'{device_type.capitalize()} ({shape_names[device_shapes[device_type]]})')
    for device_type in DEVICE_TYPES
]
legend_elements.append(plt.Line2D([0], [0], marker='*', color='w',
                                  markerfacecolor='red', markersize=15,
                                  label='Patient Zero (Failure Source)',
                                  markeredgecolor='darkred', markeredgewidth=2))

plt.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(0.02, 0.98))

# The node count comes from NUM_NODES so the title stays correct if the config changes
plt.title(f'Demo Network Topology\n{NUM_NODES} Devices Connected via Erdős–Rényi Random Graph',
          fontsize=16, fontweight='bold', pad=20)

# Add network statistics
stats_text = f"""Network Statistics:
• Total Nodes: {NUM_NODES}
• Total Edges: {G.number_of_edges()}
• Average Degree: {2 * G.number_of_edges() / NUM_NODES:.1f}
• Network Diameter: {nx.diameter(G) if nx.is_connected(G) else 'N/A'}
• Clustering Coefficient: {nx.average_clustering(G):.3f}

Device Distribution:"""

for device_type in DEVICE_TYPES:
    count = len(df_nodes[df_nodes['type'] == device_type])
    stats_text += f"\n• {device_type.capitalize()}: {count} devices"

patient_zero_type = df_nodes.iloc[PATIENT_ZERO_ID]['type']
stats_text += f"""

Failure Simulation:
• Patient Zero: Node {PATIENT_ZERO_ID} ({patient_zero_type})
• Failure Start Time: T={FAILURE_START_TIME}
• Total Simulation Time: {NUM_TIMESTEPS} timesteps"""

plt.text(1.02, 0.98, stats_text, transform=plt.gca().transAxes,
         fontsize=10, verticalalignment='top',
         bbox=dict(boxstyle='round,pad=0.5', facecolor='lightblue', alpha=0.7))

plt.axis('off')
plt.tight_layout()
plt.show()

# Print adjacency information for patient zero
print(f"\n{'='*60}")
print(f"PATIENT ZERO NETWORK ANALYSIS")
print(f"{'='*60}")
print(f"Patient Zero (Node {PATIENT_ZERO_ID}) is a {patient_zero_type}")
print(f"Direct neighbors: {list(G.neighbors(PATIENT_ZERO_ID))}")

neighbor_types = []
for neighbor in G.neighbors(PATIENT_ZERO_ID):
    neighbor_type = df_nodes.iloc[neighbor]['type']
    neighbor_types.append(neighbor_type)
    print(f"  - Node {neighbor}: {neighbor_type}")

print(f"\nNeighbor device types: {set(neighbor_types)}")
print(f"Patient zero degree (connections): {G.degree(PATIENT_ZERO_ID)}")

# Show how failure will spread (mirrors the propagation rule in the data-generation cell)
print(f"\nFailure Propagation Path:")
print(f"T=0-{FAILURE_START_TIME-1}: All nodes normal")
print(f"T={FAILURE_START_TIME}: Node {PATIENT_ZERO_ID} fails")
print(f"T={FAILURE_START_TIME+1}+: Neighbors become 'impacted':")
for neighbor in G.neighbors(PATIENT_ZERO_ID):
    print(f"  - Node {neighbor} ({df_nodes.iloc[neighbor]['type']}) will show degraded performance")

In [None]:
# ==============================================================================
# Step 2: Feature Engineering with Sentence-Transformers
# ==============================================================================
# Convert text status messages into meaningful embeddings.

@torch.no_grad()
def embed_text_features(snapshots, model_name='all-MiniLM-L6-v2', encoder=None, device=None):
    """Attach a sentence embedding of every node's status text to each snapshot.

    Each distinct message is encoded only once, and the result is looked up per node.

    Args:
        snapshots: list of dicts, each holding a 'text' list with one string per node.
            Modified in place: a 'text_embeddings' tensor (one row per text) is added.
        model_name: SentenceTransformer model to load when ``encoder`` is not given.
        encoder: optional pre-loaded encoder exposing
            ``encode(texts, convert_to_tensor=True, device=...)``. Pass one to avoid
            reloading the model every time the cell is re-run.
        device: device to run encoding on; defaults to CUDA when available, else CPU.

    Returns:
        The same ``snapshots`` list. Embeddings are always stored on CPU.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if encoder is None:
        encoder = SentenceTransformer(model_name)

    # Sorted for a deterministic encoding order
    unique_texts = sorted({text for snapshot in snapshots for text in snapshot['text']})
    embeddings = encoder.encode(unique_texts, convert_to_tensor=True, device=device)

    # Move to CPU: the numeric features and the model downstream live on CPU, and
    # torch.cat across devices raises a RuntimeError on GPU machines.
    text_to_embedding = dict(zip(unique_texts, embeddings.cpu()))

    for snapshot in snapshots:
        snapshot['text_embeddings'] = torch.stack([text_to_embedding[text] for text in snapshot['text']])
    return snapshots

temporal_snapshots = embed_text_features(temporal_snapshots)
# Width of one text embedding (column count of the per-node embedding matrix)
TEXT_EMBED_DIM = temporal_snapshots[0]['text_embeddings'].size(1)
print(f"Text features embedded into vectors of size {TEXT_EMBED_DIM}.")



In [None]:
# ==============================================================================
# Step 3: Create PyTorch Geometric HeteroData Objects
# ==============================================================================
# We convert our list of snapshots into a list of HeteroData objects.

pyg_snapshots = []
node_type_map = {i: df_nodes.iloc[i]['type'] for i in range(NUM_NODES)}

# The topology and the device-type assignment are identical in every snapshot. So the per-type
# masks, id maps and edge lists are computed once here instead of per snapshot (previously
# every snapshot rebuilt the maps and re-scanned all edges once per (src, dst) type pair).
type_masks = {node_type: (df_nodes['type'] == node_type).values for node_type in DEVICE_TYPES}
present_types = [node_type for node_type in DEVICE_TYPES if type_masks[node_type].any()]

# global node id <-> (node_type, row index inside that type's feature matrix)
global_to_local = {}
local_to_global = {}
for node_type in present_types:
    for local_idx, global_id in enumerate(df_nodes.loc[type_masks[node_type], 'id']):
        global_to_local[int(global_id)] = (node_type, local_idx)
        local_to_global[(node_type, local_idx)] = int(global_id)

# Split each undirected edge into u->v then v->u, bucketed by (src_type, dst_type).
# Within each bucket this gives the same edge order as the per-pair scan it replaces.
edge_lists = {}
for u, v in G.edges():
    u_info, v_info = global_to_local[u], global_to_local[v]
    for (src_type, src_local), (dst_type, dst_local) in ((u_info, v_info), (v_info, u_info)):
        sources, targets = edge_lists.setdefault((src_type, dst_type), ([], []))
        sources.append(src_local)
        targets.append(dst_local)

# Iterate (src, dst) in DEVICE_TYPES order so the edge types (and metadata()) keep their order.
static_edge_index = {
    (src_type, 'connects_to', dst_type): torch.tensor(edge_lists[(src_type, dst_type)], dtype=torch.int64)
    for src_type in present_types
    for dst_type in present_types
    if (src_type, dst_type) in edge_lists
}

for snapshot in temporal_snapshots:
    data = HeteroData()

    # Node features per device type: [cpu, mem, text embedding...]
    for node_type in present_types:
        mask = type_masks[node_type]
        numeric_feats = torch.tensor(np.column_stack([snapshot['cpu'][mask], snapshot['mem'][mask]]),
                                     dtype=torch.float32)
        # .cpu() guards against embeddings produced on a GPU; torch.cat across devices would raise.
        text_feats = snapshot['text_embeddings'].cpu()[mask]
        data[node_type].x = torch.cat([numeric_feats, text_feats], dim=1)
        data[node_type].node_ids = torch.tensor(df_nodes[mask].id.values)  # original ids, in row order

    # Each snapshot gets its own copy of the (static) edge indices
    for edge_type, edge_index in static_edge_index.items():
        data[edge_type].edge_index = edge_index.clone()

    # NOTE: data.y is in global node-id order, unlike the per-type x matrices above.
    data.y = torch.tensor(snapshot['cpu'], dtype=torch.float32)
    data.status = snapshot['status']  # For visualization later

    pyg_snapshots.append(data)

print(f"Created {len(pyg_snapshots)} PyG HeteroData snapshots.")
print("\nExample snapshot at T=0:")
print(pyg_snapshots[0])
print(f"Node types in first snapshot: {pyg_snapshots[0].node_types}")
print(f"Edge types in first snapshot: {pyg_snapshots[0].edge_types}")

# Print device type distribution
for node_type in DEVICE_TYPES:
    if node_type in pyg_snapshots[0].node_types:
        print(f"{node_type}: {pyg_snapshots[0][node_type].x.shape[0]} nodes")
    else:
        print(f"{node_type}: 0 nodes")

In [None]:
# ==============================================================================
# Step 4: Define the Simplified THGAT Model
# ==============================================================================

class TemporalHGT(nn.Module):
    """Spatio-temporal model: an HGTConv per snapshot, then a per-node-type GRU over time.

    Args:
        in_channels: input feature size passed to HGTConv (a single int, so every node
            type's ``x`` must have this width).
        hidden_channels: HGTConv output size and GRU hidden size.
        out_channels: output size of the prediction head.
        node_types: node types to process. Output rows are concatenated in this order.
        metadata: ``(node_types, edge_types)`` tuple, as passed to HGTConv.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, node_types, metadata):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.node_types = node_types

        # Spatial layer: HGTConv handles the heterogeneous node/edge types
        self.hgt_conv = HGTConv(in_channels, hidden_channels, metadata, heads=2)

        # Temporal layer: one GRU per node type over that type's embedding history
        self.gru_dict = nn.ModuleDict({
            node_type: nn.GRU(hidden_channels, hidden_channels, batch_first=True)
            for node_type in node_types
        })

        # Output layer: predicts the next CPU value
        self.out = Linear(hidden_channels, out_channels)

    def forward(self, snapshot_sequence):
        """Run the spatial layer on each snapshot, then the GRUs over the sequence.

        Args:
            snapshot_sequence: non-empty list of snapshots exposing ``node_types``,
                ``edge_types`` and ``snapshot[key].x`` / ``snapshot[key].edge_index``.

        Returns:
            (pred, combined_embeddings). ``combined_embeddings`` has one row per node,
            grouped by node type in ``self.node_types`` order. Within a type, rows follow
            that type's ``x`` order. This is NOT original node-id order.
            ``pred`` is the head output with the last dim squeezed (1-D when
            out_channels == 1).
        """
        embeddings_over_time = {node_type: [] for node_type in self.node_types}

        for snapshot in snapshot_sequence:
            x_dict = {nt: snapshot[nt].x for nt in self.node_types if nt in snapshot.node_types}
            edge_index_dict = {et: snapshot[et].edge_index for et in snapshot.edge_types}

            # Spatial message passing
            x_dict_out = self.hgt_conv(x_dict, edge_index_dict)

            for node_type, embeddings in x_dict_out.items():
                # Some PyG versions return None for node types that receive no messages
                if embeddings is not None:
                    embeddings_over_time[node_type].append(embeddings)

        # Temporal modelling: keep the last GRU output for each node type
        final_embeddings_by_type = {}
        for node_type, history in embeddings_over_time.items():
            if history:
                temporal_input = torch.stack(history, dim=1)  # (num_nodes, seq_len, hidden)
                gru_output, _ = self.gru_dict[node_type](temporal_input)
                final_embeddings_by_type[node_type] = gru_output[:, -1, :]

        # Concatenate (rather than filling a CPU zeros buffer) so the result stays on the
        # model's device and keeps the autograd graph.
        first_snapshot = snapshot_sequence[0]
        ordered = [
            final_embeddings_by_type[nt] for nt in self.node_types
            if nt in final_embeddings_by_type and nt in first_snapshot.node_types
        ]
        if ordered:
            combined_embeddings = torch.cat(ordered, dim=0)
        else:
            combined_embeddings = torch.zeros(0, self.hidden_channels, device=self.out.weight.device)

        # Prediction
        pred = self.out(combined_embeddings).squeeze(-1)
        return pred, combined_embeddings

In [None]:
# ==============================================================================
# Step 5: Train the Model
# ==============================================================================

# Get the actual input feature dimensions from the data
sample_snapshot = pyg_snapshots[0]
feature_dims = {nt: sample_snapshot[nt].x.shape[1] for nt in sample_snapshot.node_types}
# HGTConv is built with a single int in_channels, so every node type must share one feature size.
if len(set(feature_dims.values())) > 1:
    raise ValueError(f"Node types have different feature sizes, but HGTConv expects one: {feature_dims}")
actual_in_channels = next(iter(feature_dims.values()), None)

print(f"Actual input feature dimensions: {actual_in_channels}")
print(f"This includes: 2 numeric features (CPU, memory) + {TEXT_EMBED_DIM} text embedding features")

# Model hyperparameters
in_channels = actual_in_channels  # Use actual feature dimensions
hidden_channels = 32
out_channels = 1  # Predicting CPU usage
num_epochs = 10
learning_rate = 0.01
T = 3  # Temporal window size
MODEL_SEED = 42

# Get metadata from the first snapshot
metadata = pyg_snapshots[0].metadata()
node_types = list(metadata[0])  # Node types, in metadata order

print(f"Node types in metadata: {node_types}")
print(f"Edge types in metadata: {list(metadata[1])}")

# Seed torch before building the model so weight init (and the loss curve / embeddings
# analysed below) is reproducible across Restart & Run All.
torch.manual_seed(MODEL_SEED)
model = TemporalHGT(in_channels, hidden_channels, out_channels, node_types, metadata)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Training loop: each window of T snapshots predicts the CPU usage of the following snapshot
model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    num_sequences = 0

    for i in range(len(pyg_snapshots) - T):
        sequence = pyg_snapshots[i:i + T]
        next_snapshot = pyg_snapshots[i + T]

        pred = model(sequence)[0]

        # CPU is the first feature column. Collect targets in the same node-type order
        # the model uses for its output rows.
        cpu_targets = [next_snapshot[nt].x[:, 0] for nt in node_types if nt in next_snapshot.node_types]
        if not cpu_targets:
            continue
        target = torch.cat(cpu_targets, dim=0)

        # Rows only line up if both sides cover the same nodes in the same order.
        # Fail loudly instead of silently truncating to the shorter one.
        if pred.shape != target.shape:
            raise RuntimeError(f"Prediction shape {tuple(pred.shape)} does not match target shape {tuple(target.shape)}")

        loss = criterion(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_sequences += 1

    avg_loss = total_loss / max(1, num_sequences)
    print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}')

print("Training completed!")

# Test the model on the last sequence
model.eval()
with torch.no_grad():
    test_sequence = pyg_snapshots[-T:]
    test_pred, test_embeddings = model(test_sequence)
    print(f"Test prediction shape: {test_pred.shape}")
    print(f"Test embeddings shape: {test_embeddings.shape}")
    print(f"Sample predictions: {test_pred[:5].cpu().numpy()}")

In [None]:
# ==============================================================================
# Step 6: Visualize Results with Device Type Preservation
# ==============================================================================

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import numpy as np

# Get final embeddings from the trained model
model.eval()
with torch.no_grad():
    final_sequence = pyg_snapshots[-T:]
    _, final_embeddings = model(final_sequence)

# Convert to numpy for visualization
embeddings_np = final_embeddings.cpu().numpy()

# Embedding rows are grouped by node type (in `node_types` order), NOT ordered by node id.
# Record each row's device type and ORIGINAL node id from the stored `node_ids`.
device_colors = {'router': 'red', 'switch': 'blue', 'firewall': 'green'}
device_labels = []
color_list = []
row_node_ids = []

final_snapshot = pyg_snapshots[-1]
for node_type in node_types:
    if node_type in final_snapshot.node_types:
        ids_of_type = final_snapshot[node_type].node_ids.tolist()
        device_labels.extend([node_type] * len(ids_of_type))
        color_list.extend([device_colors[node_type]] * len(ids_of_type))
        row_node_ids.extend(ids_of_type)

assert len(device_labels) == len(embeddings_np), "row labels do not match embedding rows"

# Apply t-SNE for dimensionality reduction (perplexity must be < number of samples)
print("Applying t-SNE for visualization...")
tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings_np)-1))
embeddings_2d = tsne.fit_transform(embeddings_np)

# Create visualization
plt.figure(figsize=(12, 8))

# Plot each device type separately for better legend
labels_arr = np.array(device_labels)
for device_type, color in device_colors.items():
    mask = labels_arr == device_type
    if mask.any():
        plt.scatter(embeddings_2d[mask, 0], embeddings_2d[mask, 1],
                    c=color, label=f'{device_type.capitalize()} Devices',
                    alpha=0.7, s=60)

plt.title('Network Device Embeddings by Type\n(Temporal Heterogeneous Graph Attention)',
          fontsize=14, fontweight='bold')
plt.xlabel('t-SNE Dimension 1')
plt.ylabel('t-SNE Dimension 2')
plt.legend()
plt.grid(True, alpha=0.3)

# Annotate points with their ORIGINAL node id (the row index is not the node id).
# Small graphs get every node labelled, large ones every 10th row; patient zero always.
label_stride = 1 if len(row_node_ids) <= 30 else 10
for row, ((x, y), node_id) in enumerate(zip(embeddings_2d, row_node_ids)):
    if row % label_stride == 0 or node_id == PATIENT_ZERO_ID:
        plt.annotate(f'Node {node_id}', (x, y), xytext=(5, 5),
                     textcoords='offset points', fontsize=8, alpha=0.7)

plt.tight_layout()
plt.show()

# Print summary statistics
print("\n" + "="*60)
print("NETWORK FAILURE ANALYSIS SUMMARY")
print("="*60)
print(f"Total nodes analyzed: {len(embeddings_np)}")
print(f"Device type distribution:")
for device_type in device_colors.keys():
    count = device_labels.count(device_type)
    print(f"  - {device_type.capitalize()}: {count} devices")

print(f"\nModel architecture:")
print(f"  - Input features: {in_channels}")
print(f"  - Hidden dimensions: {hidden_channels}")
print(f"  - Temporal window: {T} timesteps")
print(f"  - Node types preserved: {len(node_types)}")

print(f"\nTraining completed with {num_epochs} epochs")
print(f"Final embedding dimension: {embeddings_np.shape}")

# Analyze failure propagation patterns
print(f"\nFailure propagation analysis:")
print(f"  - Patient zero was node {PATIENT_ZERO_ID}")
print(f"  - Total simulation timesteps: {NUM_TIMESTEPS}")
print(f"  - Failure started at timestep: {FAILURE_START_TIME}")

# Average L2 norm of the final embeddings per device type
print(f"\nDevice type performance patterns:")
for node_type in node_types:
    mask = labels_arr == node_type
    if mask.any():
        avg_norm = np.mean(np.linalg.norm(embeddings_np[mask], axis=1))
        print(f"  - {node_type.capitalize()}: Average embedding norm = {avg_norm:.3f}")

In [None]:
# ==============================================================================
# Step 6.5: Enhanced Clustering Analysis and Alternative Visualizations
# ==============================================================================
# Let's investigate why device type clustering isn't clear and try different approaches

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix, silhouette_score
import numpy as np

print("="*80)
print("ENHANCED CLUSTERING ANALYSIS")
print("="*80)

# Hoisted once: per-row device type as an array so boolean masks are cheap below
labels_arr = np.array(device_labels)
n_samples = len(embeddings_np)
device_types = ['router', 'switch', 'firewall']

# First, let's analyze the embedding space directly
print("1. Raw Embedding Space Analysis:")
print(f"   Embedding shape: {embeddings_np.shape}")
print(f"   Embedding range: [{np.min(embeddings_np):.4f}, {np.max(embeddings_np):.4f}]")

# Per-class statistics and intra-class distances
device_type_means = {}
device_type_stds = {}

for device_type in device_types:
    mask = labels_arr == device_type
    if not mask.any():
        continue
    type_embeddings = embeddings_np[mask]
    device_type_means[device_type] = np.mean(type_embeddings, axis=0)
    device_type_stds[device_type] = np.std(type_embeddings, axis=0)

    # All pairwise Euclidean distances within the class (upper triangle: i < j)
    pairwise = np.linalg.norm(type_embeddings[:, None, :] - type_embeddings[None, :, :], axis=-1)
    intra_distances = pairwise[np.triu_indices(len(type_embeddings), k=1)]
    # A single-member class has no pairs; np.mean([]) would just print nan
    intra_text = f"{np.mean(intra_distances):.4f}" if intra_distances.size else "n/a (single device)"

    print(f"   {device_type.capitalize()}:")
    print(f"     - Count: {mask.sum()}")
    print(f"     - Mean embedding norm: {np.linalg.norm(device_type_means[device_type]):.4f}")
    print(f"     - Average intra-class distance: {intra_text}")

# Inter-class distances between class means
print("\n2. Inter-class Distances:")
for i, type1 in enumerate(device_types):
    for type2 in device_types[i + 1:]:
        if type1 in device_type_means and type2 in device_type_means:
            dist = np.linalg.norm(device_type_means[type1] - device_type_means[type2])
            print(f"   {type1} <-> {type2}: {dist:.4f}")

# Try different dimensionality reduction techniques
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# 1. PCA (preserves global structure)
print("\n3. Trying PCA for better global structure preservation...")
pca = PCA(n_components=2, random_state=42)
embeddings_pca = pca.fit_transform(embeddings_np)

axes[0, 0].set_title('PCA Visualization\n(Preserves Global Variance)')
for device_type, color in device_colors.items():
    mask = labels_arr == device_type
    if mask.any():
        axes[0, 0].scatter(embeddings_pca[mask, 0], embeddings_pca[mask, 1],
                           c=color, label=f'{device_type.capitalize()}', alpha=0.7, s=60)
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
print(f"   PCA explained variance ratio: {pca.explained_variance_ratio_}")

# 2. t-SNE with different perplexity.
# Top-row columns 1..(ncols-1) are free, and t-SNE needs perplexity < n_samples. Keep only
# perplexities that satisfy both, so a larger graph can't index past the last subplot.
print("\n4. Trying different t-SNE parameters...")
perplexities = [5, 15, 50]
tsne_slots = axes.shape[1] - 1
valid_perplexities = [p for p in perplexities if p < n_samples][:tsne_slots]
skipped = [p for p in perplexities if p not in valid_perplexities]
if skipped:
    print(f"   Skipping perplexities {skipped} (need < {n_samples} samples and a free subplot)")

for col_idx, perplexity in enumerate(valid_perplexities, start=1):
    tsne_alt = TSNE(n_components=2, random_state=42, perplexity=perplexity,
                    learning_rate=200, max_iter=1000)
    embeddings_tsne_alt = tsne_alt.fit_transform(embeddings_np)

    axes[0, col_idx].set_title(f't-SNE (perplexity={perplexity})')
    for device_type, color in device_colors.items():
        mask = labels_arr == device_type
        if mask.any():
            axes[0, col_idx].scatter(embeddings_tsne_alt[mask, 0], embeddings_tsne_alt[mask, 1],
                                     c=color, label=f'{device_type.capitalize()}', alpha=0.7, s=60)
    axes[0, col_idx].legend()
    axes[0, col_idx].grid(True, alpha=0.3)

# 3. Analyze clustering quality with K-means.
# The silhouette score needs 2 <= n_clusters <= n_samples - 1.
print("\n5. Clustering Quality Analysis:")
for n_clusters in [2, 3, 4]:
    if n_clusters >= n_samples:
        print(f"   K-means with {n_clusters} clusters - skipped (only {n_samples} samples)")
        continue
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    cluster_labels = kmeans.fit_predict(embeddings_np)
    silhouette_avg = silhouette_score(embeddings_np, cluster_labels)
    print(f"   K-means with {n_clusters} clusters - Silhouette Score: {silhouette_avg:.4f}")

# 4. Visualize K-means clustering (k=3) on the t-SNE coordinates from the previous cell
kmeans_3 = KMeans(n_clusters=3, random_state=42)
cluster_labels = kmeans_3.fit_predict(embeddings_np)

axes[1, 0].set_title('K-means Clustering (k=3)')
scatter = axes[1, 0].scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                             c=cluster_labels, cmap='viridis', alpha=0.7, s=60)
axes[1, 0].set_xlabel('t-SNE Dimension 1')
axes[1, 0].set_ylabel('t-SNE Dimension 2')
axes[1, 0].grid(True, alpha=0.3)

# 5. Create a confusion matrix between device types and clusters
print("\n6. Device Type vs Cluster Analysis:")
device_type_to_num = {'router': 0, 'switch': 1, 'firewall': 2}
true_labels = [device_type_to_num[label] for label in device_labels]

cm = confusion_matrix(true_labels, cluster_labels)
print("   Confusion Matrix (rows=device_types, cols=clusters):")
print("   Device Types: router(0), switch(1), firewall(2)")
print(f"   {cm}")

# 6. Feature importance analysis: variance of the per-type means, per embedding dimension
print("\n7. Feature Importance Analysis:")
present_types = [dt for dt in device_types if (labels_arr == dt).any()]
if len(present_types) > 1:
    type_mean_matrix = np.stack([embeddings_np[labels_arr == dt].mean(axis=0) for dt in present_types])
    feature_variances = type_mean_matrix.var(axis=0)
else:
    feature_variances = np.zeros(embeddings_np.shape[1])

top_discriminative_dims = np.argsort(feature_variances)[-10:][::-1]
print(f"   Top 10 most discriminative dimensions: {top_discriminative_dims}")
variances_list = [feature_variances[dim] for dim in top_discriminative_dims]
print(f"   Their variances: {[f'{var:.6f}' for var in variances_list]}")

# 7. Visualize using only top discriminative features
if len(top_discriminative_dims) >= 2:
    axes[1, 1].set_title('Top 2 Discriminative Dimensions')
    for device_type, color in device_colors.items():
        mask = labels_arr == device_type
        if mask.any():
            axes[1, 1].scatter(embeddings_np[mask, top_discriminative_dims[0]],
                               embeddings_np[mask, top_discriminative_dims[1]],
                               c=color, label=f'{device_type.capitalize()}', alpha=0.7, s=60)
    axes[1, 1].set_xlabel(f'Dimension {top_discriminative_dims[0]}')
    axes[1, 1].set_ylabel(f'Dimension {top_discriminative_dims[1]}')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)

# 8. Try UMAP if available (n_neighbors must be smaller than the number of samples)
try:
    import umap
    print("\n8. Trying UMAP for better local structure preservation...")
    reducer = umap.UMAP(n_components=2, random_state=42,
                        n_neighbors=min(15, n_samples - 1), min_dist=0.1)
    embeddings_umap = reducer.fit_transform(embeddings_np)

    axes[1, 2].set_title('UMAP Visualization')
    for device_type, color in device_colors.items():
        mask = labels_arr == device_type
        if mask.any():
            axes[1, 2].scatter(embeddings_umap[mask, 0], embeddings_umap[mask, 1],
                               c=color, label=f'{device_type.capitalize()}', alpha=0.7, s=60)
    axes[1, 2].legend()
    axes[1, 2].grid(True, alpha=0.3)
except ImportError:
    print("\n8. UMAP not available (pip install umap-learn to try)")
    axes[1, 2].text(0.5, 0.5, 'UMAP not available\npip install umap-learn',
                    ha='center', va='center', transform=axes[1, 2].transAxes)
    axes[1, 2].set_title('UMAP (Not Available)')

plt.tight_layout()
plt.show()

# Summary and recommendations
print("\n" + "="*80)
print("CLUSTERING ANALYSIS SUMMARY")
print("="*80)
print("Possible reasons for poor device type clustering:")
print("1. Model may be learning temporal patterns rather than device-type patterns")
print("2. All devices see similar failure propagation, reducing type-specific signatures")
print("3. Text embeddings may dominate over device-type specific features")
print("4. Network topology may be more important than device type for this task")
print("\nRecommendations:")
print("- Try training with device-type specific objectives")
print("- Add device-type classification as an auxiliary task")
print("- Experiment with different loss functions that encourage type separation")
print("- Consider using device type as an explicit feature rather than just for heterogeneous processing")

In [None]:
# ==============================================================================
# Step 6.6: Retrain Model with Device Type Separation Objective
# ==============================================================================
# Add a contrastive loss to encourage device type clustering

import torch.nn.functional as F

class ImprovedTemporalHGT(nn.Module):
    """Temporal heterogeneous GNN with an auxiliary device-type classification head.

    Each snapshot of a sequence goes through one shared ``HGTConv`` layer (spatial,
    type-aware message passing). The per-snapshot embeddings of every node type are
    then stacked in time order and fed to a GRU dedicated to that node type. The GRU
    state at the last timestep is the node's final embedding, from which two heads read:

    * ``out`` -- CPU head with ``out_channels`` outputs (the training loop in this
      notebook fits it to next-step CPU usage), and
    * ``device_classifier`` -- ``len(node_types)`` device-type logits per node.

    Args:
        in_channels: Input feature size forwarded to ``HGTConv``.
        hidden_channels: Size of the HGT output, GRU hidden state and final embeddings.
        out_channels: Output size of the CPU head.
        node_types: Node-type names; also fixes the row order of the returned tensors.
        metadata: Graph metadata forwarded to ``HGTConv``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, node_types, metadata):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.node_types = node_types

        # Spatial layer: HGTConv handles node/edge-type heterogeneity (2 attention heads)
        self.hgt_conv = HGTConv(in_channels, hidden_channels, metadata, heads=2)

        # Temporal layer: one GRU per node type
        self.gru_dict = nn.ModuleDict({
            node_type: nn.GRU(hidden_channels, hidden_channels, batch_first=True)
            for node_type in node_types
        })

        # Output heads
        self.out = Linear(hidden_channels, out_channels)  # CPU prediction
        self.device_classifier = Linear(hidden_channels, len(node_types))  # Device type classification

    def forward(self, snapshot_sequence):
        """Embed a time-ordered list of HeteroData snapshots and apply both heads.

        Args:
            snapshot_sequence: Non-empty list of snapshots, oldest first. The node
                counts of the FIRST snapshot determine the number of output rows.

        Returns:
            ``(cpu_pred, combined_embeddings, device_type_logits)``. Rows are grouped by
            ``self.node_types`` order; ``cpu_pred`` has its last dim squeezed (shape
            ``[N]`` when ``out_channels == 1``), ``combined_embeddings`` is
            ``[N, hidden_channels]`` and ``device_type_logits`` is ``[N, len(node_types)]``.
        """
        embeddings_over_time = {node_type: [] for node_type in self.node_types}

        # Spatial step: one HGTConv pass per snapshot
        for snapshot in snapshot_sequence:
            x_dict = {
                node_type: snapshot[node_type].x
                for node_type in self.node_types
                if node_type in snapshot.node_types
            }
            edge_index_dict = {
                edge_type: snapshot[edge_type].edge_index
                for edge_type in snapshot.edge_types
            }
            x_dict_out = self.hgt_conv(x_dict, edge_index_dict)

            for node_type, embeddings in x_dict_out.items():
                embeddings_over_time[node_type].append(embeddings)

        # Temporal step: stacking on dim=1 gives [num_nodes_of_type, T, hidden];
        # the GRU output at the last timestep is each node's final embedding.
        final_embeddings_by_type = {}
        for node_type in self.node_types:
            if embeddings_over_time[node_type]:
                temporal_input = torch.stack(embeddings_over_time[node_type], dim=1)
                gru_output, _ = self.gru_dict[node_type](temporal_input)
                final_embeddings_by_type[node_type] = gru_output[:, -1, :]

        # Concatenate per-type embeddings in self.node_types order
        first_snapshot = snapshot_sequence[0]
        total_nodes = sum(first_snapshot[nt].x.shape[0] for nt in first_snapshot.node_types)
        # Allocate on the same device/dtype as the GRU outputs: a bare torch.zeros()
        # is always a CPU tensor and breaks the slice assignment when the model is on GPU.
        reference = next(iter(final_embeddings_by_type.values()), None)
        if reference is not None:
            combined_embeddings = reference.new_zeros(total_nodes, self.hidden_channels)
        else:
            combined_embeddings = torch.zeros(total_nodes, self.hidden_channels)

        global_idx = 0
        for node_type in self.node_types:
            if node_type in final_embeddings_by_type and node_type in first_snapshot.node_types:
                num_nodes_of_type = final_embeddings_by_type[node_type].shape[0]
                combined_embeddings[global_idx:global_idx + num_nodes_of_type] = final_embeddings_by_type[node_type]
                global_idx += num_nodes_of_type

        # Predictions from both heads
        cpu_pred = self.out(combined_embeddings).squeeze(-1)
        device_type_logits = self.device_classifier(combined_embeddings)

        return cpu_pred, combined_embeddings, device_type_logits

print("="*80)
print("RETRAINING WITH DEVICE TYPE SEPARATION")
print("="*80)

# Integer class ids for the auxiliary device-type head.
device_type_to_idx = {'router': 0, 'switch': 1, 'firewall': 2}
# NOTE(review): assumes `device_labels` lists nodes in the same order the model
# concatenates embeddings (node_types order) -- confirm against the cell that builds it.
device_type_labels_tensor = torch.tensor([device_type_to_idx[label] for label in device_labels])

# Number of passes over all sliding windows. Used by both the loop and the progress
# log, which previously hard-coded 15 in two places.
IMPROVED_NUM_EPOCHS = 15
# Weight of the device-type cross-entropy term in the combined loss.
DEVICE_LOSS_WEIGHT = 0.5

# Initialize improved model
improved_model = ImprovedTemporalHGT(in_channels, hidden_channels, out_channels, node_types, metadata)
improved_optimizer = torch.optim.Adam(improved_model.parameters(), lr=learning_rate)
cpu_criterion = nn.MSELoss()
device_criterion = nn.CrossEntropyLoss()

print("Training improved model with multi-task objectives...")

# Training loop with multi-task loss: each window of T snapshots predicts the
# CPU feature (column 0) of the snapshot that immediately follows it.
improved_model.train()
for epoch in range(IMPROVED_NUM_EPOCHS):
    total_cpu_loss = 0
    total_device_loss = 0
    total_combined_loss = 0
    num_sequences = 0
    
    for i in range(len(pyg_snapshots) - T):
        sequence = pyg_snapshots[i:i+T]
        
        # Forward pass
        cpu_pred, embeddings, device_logits = improved_model(sequence)
        
        # Targets come from the snapshot right after the window
        next_snapshot_idx = i + T
        next_snapshot = pyg_snapshots[next_snapshot_idx]
        
        cpu_targets = []
        for node_type in node_types:
            if node_type in next_snapshot.node_types:
                cpu_targets.append(next_snapshot[node_type].x[:, 0])
        
        if cpu_targets:
            cpu_target = torch.cat(cpu_targets, dim=0)
            # Truncate to the common length; a mismatch here would indicate that node
            # counts changed between snapshots.
            min_size = min(cpu_pred.shape[0], cpu_target.shape[0])
            
            cpu_loss = cpu_criterion(cpu_pred[:min_size], cpu_target[:min_size])
            device_loss = device_criterion(device_logits[:min_size], device_type_labels_tensor[:min_size])
            
            combined_loss = cpu_loss + DEVICE_LOSS_WEIGHT * device_loss
            
            improved_optimizer.zero_grad()
            combined_loss.backward()
            improved_optimizer.step()
            
            total_cpu_loss += cpu_loss.item()
            total_device_loss += device_loss.item()
            total_combined_loss += combined_loss.item()
            num_sequences += 1
    
    if num_sequences > 0:
        avg_cpu_loss = total_cpu_loss / num_sequences
        avg_device_loss = total_device_loss / num_sequences
        avg_combined_loss = total_combined_loss / num_sequences
        
        print(f'Epoch {epoch+1}/{IMPROVED_NUM_EPOCHS}: CPU Loss: {avg_cpu_loss:.4f}, '
              f'Device Loss: {avg_device_loss:.4f}, Combined: {avg_combined_loss:.4f}')

print("\nImproved training completed!")

# Embed the most recent window of T snapshots in inference mode (no autograd graph).
final_sequence = pyg_snapshots[-T:]
improved_model.eval()
with torch.no_grad():
    _, improved_embeddings, improved_device_logits = improved_model(final_sequence)
improved_embeddings_np = improved_embeddings.cpu().numpy()

# Fraction of nodes whose argmax logit equals their device-type label. The labels are
# the same ones used in training, so this is an in-sample (training) accuracy.
device_predictions = improved_device_logits.argmax(dim=1)
device_accuracy = device_predictions.eq(device_type_labels_tensor).float().mean()
print(f"\nDevice type classification accuracy: {device_accuracy:.3f}")

# Visualize improved embeddings: original t-SNE, improved t-SNE, improved PCA


def _scatter_by_device_type(ax, coords_2d, title):
    """Scatter 2-D node coordinates on `ax`, one colour per device type.

    Uses the notebook-level `device_colors` (type -> colour) and `device_labels`
    (one type string per row of `coords_2d`). Types with no nodes are skipped.
    """
    ax.set_title(title)
    type_per_node = np.array(device_labels)
    for device_type, color in device_colors.items():
        mask = type_per_node == device_type
        if mask.any():
            ax.scatter(coords_2d[mask, 0], coords_2d[mask, 1],
                       c=color, label=f'{device_type.capitalize()}', alpha=0.7, s=60)
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Original embeddings (perplexity must stay below the number of samples)
tsne_orig = TSNE(n_components=2, random_state=42, perplexity=min(15, len(embeddings_np)-1))
embeddings_2d_orig = tsne_orig.fit_transform(embeddings_np)
_scatter_by_device_type(axes[0], embeddings_2d_orig, 'Original Model Embeddings')

# Improved embeddings
tsne_improved = TSNE(n_components=2, random_state=42, perplexity=min(15, len(improved_embeddings_np)-1))
embeddings_2d_improved = tsne_improved.fit_transform(improved_embeddings_np)
_scatter_by_device_type(axes[1], embeddings_2d_improved,
                        'Improved Model Embeddings\n(with Device Type Objective)')

# PCA comparison
pca_improved = PCA(n_components=2, random_state=42)
embeddings_pca_improved = pca_improved.fit_transform(improved_embeddings_np)
_scatter_by_device_type(axes[2], embeddings_pca_improved, 'Improved Model PCA')

plt.tight_layout()
plt.show()

# Compare clustering quality: KMeans with 3 clusters (one per device type) on both
# embedding sets. Each KMeans is fit once and its labels reused for silhouette and
# ARI; previously the original-embedding KMeans was refit for every metric.
from sklearn.metrics import adjusted_rand_score

original_clusters = KMeans(n_clusters=3, random_state=42).fit_predict(embeddings_np)
kmeans_improved = KMeans(n_clusters=3, random_state=42)
improved_clusters = kmeans_improved.fit_predict(improved_embeddings_np)

original_silhouette = silhouette_score(embeddings_np, original_clusters)
improved_silhouette = silhouette_score(improved_embeddings_np, improved_clusters)

print(f"\nClustering Quality Comparison:")
print(f"Original model silhouette score: {original_silhouette:.4f}")
print(f"Improved model silhouette score: {improved_silhouette:.4f}")

# Device type vs cluster alignment.
# NOTE(review): assumes `true_labels` holds one device-type id per embedding row, in
# the same order as both embedding matrices -- confirm against where it is defined.
improved_ari = adjusted_rand_score(true_labels, improved_clusters)
original_ari = adjusted_rand_score(true_labels, original_clusters)

print(f"\nDevice Type - Cluster Alignment (Adjusted Rand Index):")
print(f"Original model ARI: {original_ari:.4f}")
print(f"Improved model ARI: {improved_ari:.4f}")
print("(Higher ARI = better alignment between device types and clusters)")

print("\n" + "="*80)

In [None]:
# ==============================================================================
# Step 7: Analyze Failed Node Deviation from Average Embedding
# ==============================================================================

import numpy as np
from scipy.spatial.distance import euclidean, cosine

print("="*80)
print("FAILED NODE EMBEDDING ANALYSIS")
print("="*80)

# Find patient zero's row in the final embeddings. Rows are grouped by node type in
# `node_types` order, so walk the types, accumulating an offset, until its global id
# shows up in a type's `node_ids`.
patient_zero_embedding_idx = None
# Initialised so the report below (and the `else` branch further down) still works
# when patient zero is not found; previously this raised NameError.
patient_zero_type = None
current_idx = 0

for node_type in node_types:
    if node_type not in final_snapshot.node_types:
        continue
    node_ids = final_snapshot[node_type].node_ids
    num_nodes_of_type = len(node_ids)

    for local_idx, global_id in enumerate(node_ids):
        # int() accepts 0-d tensors, NumPy integers and plain ints alike
        if int(global_id) == PATIENT_ZERO_ID:
            patient_zero_embedding_idx = current_idx + local_idx
            patient_zero_type = node_type
            break

    current_idx += num_nodes_of_type
    if patient_zero_embedding_idx is not None:
        break

print(f"Patient Zero Analysis:")
print(f"  - Node ID: {PATIENT_ZERO_ID}")
print(f"  - Device Type: {patient_zero_type}")
print(f"  - Position in embeddings: {patient_zero_embedding_idx}")

if patient_zero_embedding_idx is not None:
    # Embedding row of patient zero
    patient_zero_emb = embeddings_np[patient_zero_embedding_idx]
    
    # Per-dimension statistics over all nodes
    all_embeddings_mean = np.mean(embeddings_np, axis=0)
    all_embeddings_std = np.std(embeddings_np, axis=0)
    
    # Per-dimension statistics over nodes sharing patient zero's device type
    same_type_mask = np.array(device_labels) == patient_zero_type
    same_type_embeddings = embeddings_np[same_type_mask]
    same_type_mean = np.mean(same_type_embeddings, axis=0)
    same_type_std = np.std(same_type_embeddings, axis=0)
    
    # Signed per-dimension deviations
    deviation_from_all = patient_zero_emb - all_embeddings_mean
    deviation_from_same_type = patient_zero_emb - same_type_mean
    
    # Whole-vector distances to each reference mean
    euclidean_dist_all = euclidean(patient_zero_emb, all_embeddings_mean)
    euclidean_dist_same_type = euclidean(patient_zero_emb, same_type_mean)
    cosine_dist_all = cosine(patient_zero_emb, all_embeddings_mean)
    cosine_dist_same_type = cosine(patient_zero_emb, same_type_mean)
    
    # |z|-scores per dimension; 1e-8 guards against zero-variance dimensions
    z_scores_all = np.abs(deviation_from_all) / (all_embeddings_std + 1e-8)
    z_scores_same_type = np.abs(deviation_from_same_type) / (same_type_std + 1e-8)
    
    print(f"\nDistance Analysis:")
    print(f"  Euclidean distance from all nodes average: {euclidean_dist_all:.4f}")
    print(f"  Euclidean distance from same type average: {euclidean_dist_same_type:.4f}")
    print(f"  Cosine distance from all nodes average: {cosine_dist_all:.4f}")
    print(f"  Cosine distance from same type average: {cosine_dist_same_type:.4f}")
    
    print(f"\nDeviation Analysis:")
    print(f"  Max absolute deviation from all nodes: {np.max(np.abs(deviation_from_all)):.4f}")
    print(f"  Max absolute deviation from same type: {np.max(np.abs(deviation_from_same_type)):.4f}")
    print(f"  Average z-score vs all nodes: {np.mean(z_scores_all):.4f}")
    print(f"  Average z-score vs same type: {np.mean(z_scores_same_type):.4f}")
    print(f"  Max z-score vs all nodes: {np.max(z_scores_all):.4f}")
    print(f"  Max z-score vs same type: {np.max(z_scores_same_type):.4f}")
    
    # Dimensions ranked by absolute deviation, largest first. Keep the top 10: the text
    # report prints 5, while the "Top 10" chart below and the next cell slice [:10].
    # Previously only 5 were kept, so those "Top 10" views silently showed 5.
    top_deviating_dims_all = np.argsort(np.abs(deviation_from_all))[::-1][:10]
    top_deviating_dims_same = np.argsort(np.abs(deviation_from_same_type))[::-1][:10]
    
    print(f"\nTop 5 dimensions with highest deviation from all nodes:")
    for i, dim in enumerate(top_deviating_dims_all[:5]):
        print(f"  {i+1}. Dimension {dim}: deviation = {deviation_from_all[dim]:.4f}, z-score = {z_scores_all[dim]:.4f}")
    
    print(f"\nTop 5 dimensions with highest deviation from same device type:")
    for i, dim in enumerate(top_deviating_dims_same[:5]):
        print(f"  {i+1}. Dimension {dim}: deviation = {deviation_from_same_type[dim]:.4f}, z-score = {z_scores_same_type[dim]:.4f}")
    
    # Visualization of deviations
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Plot 1: Embedding magnitude comparison
    axes[0, 0].bar(['Patient Zero', 'All Avg', 'Same Type Avg'], 
                   [np.linalg.norm(patient_zero_emb), 
                    np.linalg.norm(all_embeddings_mean), 
                    np.linalg.norm(same_type_mean)],
                   color=['red', 'blue', 'green'])
    axes[0, 0].set_title('Embedding Vector Magnitudes')
    axes[0, 0].set_ylabel('L2 Norm')
    
    # Plot 2: Distribution of z-scores vs all nodes
    axes[0, 1].hist(z_scores_all, bins=20, alpha=0.7, color='blue', edgecolor='black')
    axes[0, 1].axvline(np.mean(z_scores_all), color='red', linestyle='--', 
                       label=f'Mean: {np.mean(z_scores_all):.2f}')
    axes[0, 1].set_title('Z-Scores vs All Nodes Average')
    axes[0, 1].set_xlabel('Z-Score')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].legend()
    
    # Plot 3: Distribution of z-scores vs same type
    axes[1, 0].hist(z_scores_same_type, bins=20, alpha=0.7, color='green', edgecolor='black')
    axes[1, 0].axvline(np.mean(z_scores_same_type), color='red', linestyle='--', 
                       label=f'Mean: {np.mean(z_scores_same_type):.2f}')
    axes[1, 0].set_title('Z-Scores vs Same Device Type Average')
    axes[1, 0].set_xlabel('Z-Score')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].legend()
    
    # Plot 4: Top deviating dimensions
    top_dims = top_deviating_dims_all[:10]
    axes[1, 1].bar(range(len(top_dims)), [np.abs(deviation_from_all[dim]) for dim in top_dims])
    axes[1, 1].set_title('Top 10 Deviating Dimensions (vs All Nodes)')
    axes[1, 1].set_xlabel('Dimension Rank')
    axes[1, 1].set_ylabel('Absolute Deviation')
    axes[1, 1].set_xticks(range(len(top_dims)))
    axes[1, 1].set_xticklabels([f'D{dim}' for dim in top_dims], rotation=45)
    
    plt.tight_layout()
    plt.show()
    
    # Comparison with other nodes of same type
    print(f"\nComparison with other {patient_zero_type} devices:")
    same_type_indices = np.where(same_type_mask)[0]
    same_type_distances = []
    
    for idx in same_type_indices:
        if idx != patient_zero_embedding_idx:
            dist = euclidean(patient_zero_emb, embeddings_np[idx])
            same_type_distances.append(dist)
    
    if same_type_distances:
        print(f"  Average distance to other {patient_zero_type}s: {np.mean(same_type_distances):.4f}")
        print(f"  Min distance to other {patient_zero_type}s: {np.min(same_type_distances):.4f}")
        print(f"  Max distance to other {patient_zero_type}s: {np.max(same_type_distances):.4f}")
        
        # Rank = 1 + number of same-type nodes strictly farther from the type mean than
        # patient zero (1 = most deviant). Counting gives the same rank as the previous
        # sorted(...).index(...) but cannot raise on a float that fails to match exactly.
        pz_dist_to_type_mean = euclidean(same_type_mean, patient_zero_emb)
        all_same_type_dists = [euclidean(same_type_mean, embeddings_np[idx]) for idx in same_type_indices]
        patient_zero_dist_rank = 1 + sum(d > pz_dist_to_type_mean for d in all_same_type_dists)
        print(f"  Patient zero ranks #{patient_zero_dist_rank} out of {len(same_type_indices)} {patient_zero_type} devices")
        print(f"  (1 = most deviant from type average)")

else:
    print("Error: Could not find patient zero in the embeddings!")

print("\n" + "="*80)

In [None]:
# ==============================================================================
# Patient Zero Dimensional Comparison Visualization
# ==============================================================================

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

print("="*80)
print("PATIENT ZERO DIMENSIONAL COMPARISON")
print("="*80)

# Fail loudly if patient zero was not located: indexing a NumPy array with None
# does not raise, it silently inserts a new axis, which would corrupt every panel below.
if patient_zero_embedding_idx is None:
    raise RuntimeError("Patient zero was not found in the embeddings; re-run Step 7 first.")

# Get patient zero information (same_type_mask comes from Step 7)
patient_zero_emb = embeddings_np[patient_zero_embedding_idx]
same_type_embeddings = embeddings_np[same_type_mask]
all_embeddings = embeddings_np

# One figure with a 3x3 grid of comparison panels
fig = plt.figure(figsize=(20, 16))

# 1. Radar Chart - Top Discriminative Dimensions
ax1 = plt.subplot(3, 3, 1, projection='polar')
top_dims_subset = top_discriminative_dims[:8]  # Use top 8 for cleaner visualization
angles = np.linspace(0, 2*np.pi, len(top_dims_subset), endpoint=False)

# Values of the three profiles on the selected dimensions
patient_zero_values = patient_zero_emb[top_dims_subset]
same_type_mean_values = same_type_mean[top_dims_subset]
all_mean_values = all_embeddings_mean[top_dims_subset]

# Min-max normalise all three profiles with ONE shared scale so they stay comparable
all_values = np.concatenate([patient_zero_values, same_type_mean_values, all_mean_values])
min_val, max_val = np.min(all_values), np.max(all_values)
if max_val != min_val:
    patient_zero_norm = (patient_zero_values - min_val) / (max_val - min_val)
    same_type_norm = (same_type_mean_values - min_val) / (max_val - min_val)
    all_mean_norm = (all_mean_values - min_val) / (max_val - min_val)
else:
    patient_zero_norm = np.ones_like(patient_zero_values) * 0.5
    same_type_norm = np.ones_like(same_type_mean_values) * 0.5
    all_mean_norm = np.ones_like(all_mean_values) * 0.5


def _close_polygon(values):
    """Repeat the first element at the end so a polar line returns to its first spoke."""
    return np.concatenate([values, values[:1]])


# Without the repeated first point, plot() leaves the last spoke unconnected to the first
radar_angles = _close_polygon(angles)
type_average_label = f'{str(patient_zero_type).capitalize()} Average'
ax1.plot(radar_angles, _close_polygon(patient_zero_norm), 'r-', linewidth=3, label='Patient Zero', marker='o')
ax1.plot(radar_angles, _close_polygon(same_type_norm), 'g--', linewidth=2, label=type_average_label, marker='s')
ax1.plot(radar_angles, _close_polygon(all_mean_norm), 'b:', linewidth=2, label='All Nodes Average', marker='^')
ax1.fill(radar_angles, _close_polygon(patient_zero_norm), 'red', alpha=0.2)
ax1.set_ylim(0, 1)
ax1.set_title('Top Discriminative Dimensions\n(Radar Chart)', pad=20)
ax1.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))

# One spoke label per dimension (angles without the closing duplicate)
dim_labels = [f'Dim {dim}' for dim in top_dims_subset]
ax1.set_thetagrids(angles * 180/np.pi, dim_labels)

# 2. Bar chart of patient zero's absolute deviation from the all-node mean,
#    coloured by how many all-node standard deviations each bar spans.
ax2 = plt.subplot(3, 3, 2)
top_10_dims = top_deviating_dims_all[:10]
deviations = np.abs(patient_zero_emb[top_10_dims] - all_embeddings_mean[top_10_dims])


def _deviation_color(dev, std):
    """Red beyond 2 std, orange beyond 1 std, otherwise yellow."""
    if dev > 2*std:
        return 'red'
    if dev > std:
        return 'orange'
    return 'yellow'


colors = [_deviation_color(dev, all_embeddings_std[dim]) for dev, dim in zip(deviations, top_10_dims)]

bars = ax2.bar(range(len(top_10_dims)), deviations, color=colors, alpha=0.7)
ax2.set(xlabel='Top Deviating Dimensions',
        ylabel='Absolute Deviation',
        title='Patient Zero Deviations\n(vs All Nodes Mean)')
ax2.set_xticks(range(len(top_10_dims)))
ax2.set_xticklabels([f'D{dim}' for dim in top_10_dims], rotation=45)

# Annotate each bar with its value just above the bar top
for rect, value in zip(bars, deviations):
    ax2.text(rect.get_x() + rect.get_width()/2., rect.get_height() + 0.01,
             f'{value:.3f}', ha='center', va='bottom', fontsize=8)

# 3. 3D Scatter Plot - Top 3 Discriminative Dimensions
ax3 = plt.subplot(3, 3, 3, projection='3d')
top_3_dims = top_discriminative_dims[:3]

# All nodes of the same type as patient zero (patient zero included); the box plot
# further down reads this array too.
switch_embeddings = embeddings_np[same_type_mask]

# The "Other Switches" cloud excludes patient zero, which is drawn separately as a star
other_switches_mask = np.array(same_type_mask, dtype=bool, copy=True)
other_switches_mask[patient_zero_embedding_idx] = False
other_switch_embeddings = embeddings_np[other_switches_mask]
ax3.scatter(other_switch_embeddings[:, top_3_dims[0]], 
           other_switch_embeddings[:, top_3_dims[1]], 
           other_switch_embeddings[:, top_3_dims[2]], 
           c='lightblue', alpha=0.6, s=60, label='Other Switches')

# Patient zero
ax3.scatter(patient_zero_emb[top_3_dims[0]], 
           patient_zero_emb[top_3_dims[1]], 
           patient_zero_emb[top_3_dims[2]], 
           c='red', s=200, marker='*', label='Patient Zero', edgecolors='black')

# Type mean
ax3.scatter(same_type_mean[top_3_dims[0]], 
           same_type_mean[top_3_dims[1]], 
           same_type_mean[top_3_dims[2]], 
           c='green', s=150, marker='s', label='Switch Mean', edgecolors='black')

ax3.set_xlabel(f'Dimension {top_3_dims[0]}')
ax3.set_ylabel(f'Dimension {top_3_dims[1]}')
ax3.set_zlabel(f'Dimension {top_3_dims[2]}')
ax3.set_title('3D View: Top 3 Discriminative Dims')
ax3.legend()

# 4. Heatmap of pairwise Euclidean distances among nodes of patient zero's type
ax4 = plt.subplot(3, 3, 4)
switch_indices = np.where(same_type_mask)[0]
n_switches = len(switch_indices)

distance_matrix = np.zeros((n_switches, n_switches))
for row, row_node in enumerate(switch_indices):
    for col, col_node in enumerate(switch_indices):
        distance_matrix[row, col] = np.linalg.norm(embeddings_np[row_node] - embeddings_np[col_node])

# Row/column of patient zero within the switch-only matrix
patient_zero_switch_pos = np.flatnonzero(switch_indices == patient_zero_embedding_idx)[0]

im = ax4.imshow(distance_matrix, cmap='viridis', interpolation='nearest')
ax4.set(title='Switch-to-Switch Distance Matrix',
        xlabel='Switch Index',
        ylabel='Switch Index')

# Cross-hair marking patient zero's row and column
for draw_line, coord_name in ((ax4.axhline, 'y'), (ax4.axvline, 'x')):
    draw_line(**{coord_name: patient_zero_switch_pos}, color='red', linewidth=2, alpha=0.7)
ax4.text(0.02, 0.98, f'Patient Zero: Row/Col {patient_zero_switch_pos}', 
         transform=ax4.transAxes, color='red', fontweight='bold', 
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.colorbar(im, ax=ax4, label='Euclidean Distance')

# 5. Every embedding dimension: patient zero vs. its type mean, with a ±1 std band
ax5 = plt.subplot(3, 3, 5)
dim_indices = np.arange(embeddings_np.shape[1])
ax5.plot(dim_indices, patient_zero_emb, 'r-', label='Patient Zero',
         linewidth=2, marker='o', markersize=3, alpha=0.8)
ax5.plot(dim_indices, same_type_mean, 'g--', label='Switch Average',
         linewidth=2, marker='s', markersize=3, alpha=0.8)
ax5.fill_between(dim_indices,
                 same_type_mean - same_type_std,
                 same_type_mean + same_type_std,
                 color='green', alpha=0.3, label='±1 STD (Switches)')
ax5.set(xlabel='Embedding Dimension',
        ylabel='Activation Value',
        title='Full Dimensional Profile Comparison')
ax5.legend()
ax5.grid(True, alpha=0.3)

# 6. Z-Score Distribution
ax6 = plt.subplot(3, 3, 6)
# Signed per-dimension z-score of patient zero against its own device type. The 1e-8
# (same epsilon as Step 7) avoids inf/NaN on dimensions with zero spread, e.g. when
# only one node of that type exists.
z_scores = (patient_zero_emb - same_type_mean) / (same_type_std + 1e-8)
colors = ['red' if abs(z) > 2 else 'orange' if abs(z) > 1 else 'green' for z in z_scores]
bars = ax6.bar(dim_indices, z_scores, color=colors, alpha=0.7)
ax6.axhline(y=2, color='red', linestyle='--', alpha=0.7, label='±2σ threshold')
ax6.axhline(y=-2, color='red', linestyle='--', alpha=0.7)
ax6.axhline(y=1, color='orange', linestyle='--', alpha=0.7, label='±1σ threshold')
ax6.axhline(y=-1, color='orange', linestyle='--', alpha=0.7)
ax6.set_xlabel('Embedding Dimension')
ax6.set_ylabel('Z-Score vs Switch Mean')
ax6.set_title('Patient Zero Z-Score Distribution')
ax6.legend()
ax6.grid(True, alpha=0.3)

# 7. Box Plot Comparison for Top Deviating Dimensions: distribution across nodes of
#    patient zero's type, with patient zero's own value overlaid as a star.
ax7 = plt.subplot(3, 3, 7)
top_5_dims = top_deviating_dims_all[:5]
box_data = []
patient_zero_values_top5 = []

for dim in top_5_dims:
    box_data.append(switch_embeddings[:, dim])
    patient_zero_values_top5.append(patient_zero_emb[dim])

# Tick labels are set explicitly: boxplot's `labels=` keyword is deprecated since
# Matplotlib 3.9 (renamed `tick_labels`), and the deprecation warning is hidden
# because warnings are filtered in this notebook. Boxes sit at positions 1..n.
bp = ax7.boxplot(box_data, patch_artist=True)
ax7.set_xticks(range(1, len(top_5_dims) + 1))
ax7.set_xticklabels([f'D{dim}' for dim in top_5_dims])
for patch in bp['boxes']:
    patch.set_facecolor('lightblue')
    patch.set_alpha(0.7)

# Overlay patient zero values at the matching box positions
for i, val in enumerate(patient_zero_values_top5):
    ax7.scatter(i+1, val, color='red', s=100, marker='*', zorder=10, 
               edgecolors='black', linewidth=1)

ax7.set_title('Top 5 Deviating Dimensions\nBox Plot Comparison')
ax7.set_ylabel('Activation Value')
ax7.grid(True, alpha=0.3)

# Add legend
red_star = plt.Line2D([0], [0], marker='*', color='w', markerfacecolor='red', 
                      markersize=10, markeredgecolor='black', label='Patient Zero')
ax7.legend(handles=[red_star], loc='upper right')

# 8. Density histograms: distances to patient zero vs. distances among the other switches
ax8 = plt.subplot(3, 3, 8)

# Distance from patient zero to each other node of its type
distances_to_patient_zero = [
    np.linalg.norm(embeddings_np[switch_idx] - patient_zero_emb)
    for switch_idx in switch_indices
    if switch_idx != patient_zero_embedding_idx
]

# Every unordered pair (i < j) of the remaining nodes, patient zero excluded
other_switch_indices = [idx for idx in switch_indices if idx != patient_zero_embedding_idx]
other_switch_distances = [
    np.linalg.norm(embeddings_np[first] - embeddings_np[second])
    for pos, first in enumerate(other_switch_indices)
    for second in other_switch_indices[pos + 1:]
]

ax8.hist(distances_to_patient_zero, bins=10, alpha=0.7, color='red', 
         label=f'Distances to Patient Zero\n(mean: {np.mean(distances_to_patient_zero):.3f})', 
         density=True)
ax8.hist(other_switch_distances, bins=10, alpha=0.7, color='blue', 
         label=f'Inter-Switch Distances\n(mean: {np.mean(other_switch_distances):.3f})', 
         density=True)
ax8.set(xlabel='Euclidean Distance',
        ylabel='Density',
        title='Distance Distribution Comparison')
ax8.legend()
ax8.grid(True, alpha=0.3)

# 9. Summary table: patient zero vs. its type average, plus z-score and rank summaries
ax9 = plt.subplot(3, 3, 9)
ax9.axis('off')


def _compare_row(label, pz_value, ref_value):
    """One table row: metric, patient-zero value, type-average value, signed difference."""
    return [label, f'{pz_value:.3f}', f'{ref_value:.3f}', f'{pz_value - ref_value:.3f}']


# Note: the 'Std Deviation' row compares the spread of patient zero across dimensions
# with the mean per-dimension spread across switches -- related but not identical.
stats_data = [
    ['Metric', 'Patient Zero', 'Switch Average', 'Deviation'],
    _compare_row('Embedding Norm', np.linalg.norm(patient_zero_emb), np.linalg.norm(same_type_mean)),
    _compare_row('Mean Activation', np.mean(patient_zero_emb), np.mean(same_type_mean)),
    _compare_row('Max Activation', np.max(patient_zero_emb), np.max(same_type_mean)),
    _compare_row('Min Activation', np.min(patient_zero_emb), np.min(same_type_mean)),
    _compare_row('Std Deviation', np.std(patient_zero_emb), np.mean(same_type_std)),
    ['Avg |Z-Score|', f'{np.mean(np.abs(z_scores)):.3f}', '-', '-'],
    ['Max |Z-Score|', f'{np.max(np.abs(z_scores)):.3f}', '-', '-'],
    ['Dims > 2σ', f'{np.sum(np.abs(z_scores) > 2)}', '-', '-'],
    ['Distance Rank', f'{patient_zero_dist_rank}/{len(switch_indices)}', '-', '-']
]

# Header row becomes the column labels, so data row k sits at table row k
table = ax9.table(cellText=stats_data[1:], colLabels=stats_data[0], 
                  cellLoc='center', loc='center')
table.auto_set_font_size(False)
table.set_fontsize(8)
table.scale(1, 2)

# Shade the Deviation cell of the first six data rows by magnitude; '-' cells are skipped
for row_idx, row in enumerate(stats_data[1:7], start=1):
    try:
        dev_val = float(row[3])
    except ValueError:
        continue
    if abs(dev_val) > 0.1:
        table[(row_idx, 3)].set_facecolor('#ffcccc')  # Light red for high deviation
    elif abs(dev_val) > 0.05:
        table[(row_idx, 3)].set_facecolor('#ffffcc')  # Light yellow for medium deviation

ax9.set_title('Patient Zero Summary Statistics', pad=20)

plt.tight_layout()
plt.show()

# Print detailed analysis
print(f"\nDetailed Dimensional Analysis:")
print(f"Patient Zero Embedding Index: {patient_zero_embedding_idx}")
print(f"Patient Zero Device Type: {patient_zero_type}")
print(f"Total Switch Devices: {len(switch_indices)}")
print(f"\nTop 10 Most Deviating Dimensions (vs Switch Average):")
for i, dim in enumerate(top_deviating_dims_same[:10]):
    z_score = z_scores[dim]
    deviation = patient_zero_emb[dim] - same_type_mean[dim]
    print(f"  {i+1:2d}. Dimension {dim:2d}: "
          f"PZ={patient_zero_emb[dim]:7.3f}, "
          f"Avg={same_type_mean[dim]:7.3f}, "
          f"Dev={deviation:7.3f}, "
          f"Z={z_score:6.2f}")

# Mean distances; NaN when there are too few switches to form the distance pairs
mean_dist_to_pz = np.mean(distances_to_patient_zero) if distances_to_patient_zero else float('nan')
mean_dist_between_others = np.mean(other_switch_distances) if other_switch_distances else float('nan')

print(f"\nDistance Comparison:")
print(f"  Average distance to other switches: {mean_dist_to_pz:.3f}")
print(f"  Average distance between other switches: {mean_dist_between_others:.3f}")
if np.isfinite(mean_dist_to_pz) and np.isfinite(mean_dist_between_others) and mean_dist_between_others > 0:
    distance_ratio = mean_dist_to_pz / mean_dist_between_others
    if distance_ratio >= 1:
        print(f"  Patient zero is {distance_ratio:.2f}x more distant")
    else:
        # A ratio below 1 means patient zero is closer to its peers than they are to each other
        print(f"  Patient zero is {distance_ratio:.2f}x as distant (closer than typical switch pairs)")
else:
    print("  Not enough switches to compare distances")

print("\n" + "="*80)

# Heterogeneous Graph Attention Network (HGAT) Architecture Explanation

## Overview
The **Temporal Heterogeneous Graph Attention Network (T-HGAT)** in this notebook combines spatial heterogeneous graph processing with temporal sequence modeling for network failure analysis. It's designed to handle multiple device types (routers, switches, firewalls) and learn both their spatial relationships and temporal patterns.

## Architecture Components

### 1. **Input Layer Structure**
```
Input Features per Node:
├── Numeric Features (2D): [CPU Usage, Memory Usage]
├── Text Embeddings (384D): Sentence-BERT encoded status messages
└── Total Input Dimension: 386D per node
```

### 2. **Heterogeneous Node Types**
The model treats different device types as separate node types:
- **Router Nodes**: Network routing devices
- **Switch Nodes**: Network switching devices  
- **Firewall Nodes**: Security filtering devices

Each node type has its own feature space and learning parameters.

### 3. **Spatial Processing Layer: HGTConv**
The **Heterogeneous Graph Transformer (HGT)** layer handles different node and edge types:

#### Key Features:
- **Multi-head Attention**: 2 attention heads for learning different relationship patterns
- **Type-specific Transformations**: Different weight matrices for each node type
- **Heterogeneous Edges**: Handles connections between different device types
- **Message Passing**: Aggregates information from neighboring nodes of potentially different types

#### Mathematical Formulation:
```
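For each target node t of type τ(t), each neighbour s ∈ N(t) linked by an edge e of
relation φ(e), and each attention head i (heads = 2 here):
    Q^i(t)     = Q-Linear^i_{τ(t)}( H^(l-1)[t] )
    K^i(s)     = K-Linear^i_{τ(s)}( H^(l-1)[s] )
    α^i(s,e,t) = softmax_{s ∈ N(t)}( K^i(s) · W^ATT_{φ(e)} · Q^i(t)ᵀ · μ_{⟨τ(s),φ(e),τ(t)⟩} / √d )
    MSG^i(s,e) = M-Linear^i_{τ(s)}( H^(l-1)[s] ) · W^MSG_{φ(e)}
H^(l)[t] = A-Linear_{τ(t)}( σ( ‖_i Σ_{s ∈ N(t)} α^i(s,e,t) · MSG^i(s,e) ) )
(Q/K/M/A-Linear are node-type-specific projections; W^ATT, W^MSG and the prior μ are relation-specific; ‖ concatenates the heads.)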
```

### 4. **Temporal Processing Layer: Type-specific GRU**
Each device type has its own **Gated Recurrent Unit (GRU)**:

#### Purpose:
- **Temporal Sequence Learning**: Captures how device states evolve over time
- **Type-specific Patterns**: Different device types may have different temporal behaviors
- **Memory Retention**: GRU maintains information about past states

#### Architecture:
```
For each device type τ:
GRU_τ: (hidden_size=32, batch_first=True)
Input: [num_nodes_of_type, sequence_length (T snapshots), hidden_dim]  (each node of type τ is one batch element)
Output: [num_nodes_of_type, hidden_dim] (GRU output at the final timestep)
```

### 5. **Output Layers**

#### Original Model:
- **CPU Prediction**: Linear(32 → 1) for predicting next CPU usage

#### Improved Multi-task Model:
- **CPU Prediction**: Linear(32 → 1) for predicting next-step CPU usage (same target as the original model)
- **Device Classification**: Linear(32 → 3) for device type classification
- **Combined Loss**: CPU_loss + 0.5 × Device_loss

In [None]:
# ==============================================================================
# HGAT Architecture Visualization and Flow Diagram
# ==============================================================================
# Draws a three-panel schematic of the temporal heterogeneous graph network:
#   1. overall data flow,
#   2. HGT attention detail,
#   3. multi-task objective.
# It then prints a text summary.
# The sizes below describe the architecture as documented in this notebook.
# They are NOT read from a model object, so keep them in sync by hand.

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import FancyBboxPatch, Rectangle
import numpy as np
# NOTE(review): `patches`, `Rectangle` and `np` are not used in this cell.
# They are kept in case later cells rely on these names.

# --- Architecture sizes shown in the diagram and the summary (single source of truth)
ARCH_NUMERIC_DIM = 2            # CPU usage + memory usage
ARCH_TEXT_EMBED_DIM = 384       # Sentence-BERT embedding of the status message
ARCH_INPUT_DIM = ARCH_NUMERIC_DIM + ARCH_TEXT_EMBED_DIM
ARCH_HIDDEN_DIM = 32
ARCH_NUM_HEADS = 2
ARCH_SEQ_LEN = 3
ARCH_DEVICE_TYPES = ['router', 'switch', 'firewall']
ARCH_NUM_DEVICE_TYPES = len(ARCH_DEVICE_TYPES)
ARCH_DEVICE_LOSS_WEIGHT = 0.5   # λ in L = L_cpu + λ·L_device


def _prepare_canvas(ax, xlim, ylim, title, fontsize, **title_kwargs):
    """Set data limits, hide the axes frame and add a bold title to `ax`."""
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.axis('off')
    ax.set_title(title, fontsize=fontsize, fontweight='bold', **title_kwargs)


def _draw_boxes(ax, boxes, fontsize, linewidth):
    """Draw each box spec as a rounded, labelled rectangle on `ax`.

    Each spec is a dict with these keys:
      'name'  : label text
      'pos'   : lower-left (x, y) in data coordinates
      'size'  : (width, height)
      'color' : face colour
    """
    for box in boxes:
        (x, y), (w, h) = box['pos'], box['size']
        ax.add_patch(FancyBboxPatch(
            (x, y), w, h,
            boxstyle="round,pad=0.1",
            facecolor=box['color'],
            edgecolor='black',
            linewidth=linewidth,
        ))
        ax.text(x + w / 2, y + h / 2, box['name'],
                ha='center', va='center', fontsize=fontsize, fontweight='bold')


def _draw_arrows(ax, arrows, lw, color):
    """Draw a '->' arrow on `ax` for every (start, end) pair of data coordinates."""
    for start, end in arrows:
        ax.annotate('', xy=end, xytext=start,
                    arrowprops=dict(arrowstyle='->', lw=lw, color=color))


print("="*80)
print("HETEROGENEOUS GRAPH ATTENTION NETWORK (HGAT) ARCHITECTURE")
print("="*80)

fig = plt.figure(figsize=(18, 14))

# Colours for the different components
colors = {
    'input': '#E8F4FD',
    'spatial': '#B3E5FC',
    'temporal': '#81C784',
    'output': '#FFAB91',
    'router': '#FF6B6B',
    'switch': '#4ECDC4',
    'firewall': '#45B7D1',
    'attention': '#FFE082'
}

# ------------------------------------------------------------------------------
# Panel 1: end-to-end data flow (top row, spanning both columns).
# Flow: inputs -> per-type nodes -> HGTConv -> per-type GRU -> node embedding
#       -> task heads.
# The heads are applied to the GRU embedding, matching the multi-task panel.
# ------------------------------------------------------------------------------
ax1 = fig.add_subplot(2, 2, (1, 2))
_prepare_canvas(ax1, (0, 10), (0, 12),
                'Temporal Heterogeneous Graph Attention Network Architecture',
                fontsize=16, pad=20)

components = [
    # Input layer
    {'name': 'Input Features\n(Per Node)', 'pos': (1, 10), 'size': (1.5, 1.5), 'color': colors['input']},
    {'name': f'CPU: 0.2\nMem: 0.3\nText: [{ARCH_TEXT_EMBED_DIM}D]', 'pos': (0.5, 8.5), 'size': (1, 1), 'color': colors['input']},
    *[{'name': f'{dtype.capitalize()}\nNodes', 'pos': (3, y), 'size': (1, 0.8), 'color': colors[dtype]}
      for dtype, y in zip(ARCH_DEVICE_TYPES, (10.5, 9.5, 8.5))],

    # Spatial processing
    {'name': 'HGTConv Layer\n(Multi-head Attention)', 'pos': (5.5, 9.5), 'size': (2, 1.5), 'color': colors['spatial']},
    {'name': f'{ARCH_NUM_HEADS} Attention Heads\nType-specific Weights', 'pos': (5.5, 8), 'size': (2, 0.8), 'color': colors['attention']},

    # Temporal processing: one GRU per device type
    *[{'name': f'GRU_{dtype}', 'pos': (x, 6), 'size': (1.2, 0.8), 'color': colors['temporal']}
      for dtype, x in zip(ARCH_DEVICE_TYPES, (5, 6.5, 8))],

    # The GRU's final hidden state is the node embedding
    {'name': f'Node Embeddings\n[N × {ARCH_HIDDEN_DIM}D]', 'pos': (6.1, 3.6), 'size': (2, 1), 'color': colors['temporal']},

    # Task heads applied to the embedding
    {'name': f'CPU Prediction\nLinear({ARCH_HIDDEN_DIM}→1)', 'pos': (5.2, 1.2), 'size': (1.8, 1), 'color': colors['output']},
    {'name': f'Device Classification\nLinear({ARCH_HIDDEN_DIM}→{ARCH_NUM_DEVICE_TYPES})', 'pos': (7.5, 1.2), 'size': (1.8, 1), 'color': colors['output']},
]
_draw_boxes(ax1, components, fontsize=9, linewidth=1.5)

arrows = [
    # Input to node types
    ((2.5, 10.2), (3, 10.2)),
    ((2.5, 9.5), (3, 9.5)),
    ((2.5, 8.8), (3, 8.8)),

    # Node types to HGT
    ((4, 10.5), (5.5, 10.2)),
    ((4, 9.5), (5.5, 9.5)),
    ((4, 8.5), (5.5, 8.8)),

    # HGT to GRUs
    ((6.5, 8), (5.5, 6.8)),
    ((6.5, 8), (7, 6.8)),
    ((6.5, 8), (8.5, 6.8)),

    # Every GRU feeds the node embedding
    ((5.6, 5.9), (6.6, 4.7)),
    ((7.1, 5.9), (7.1, 4.7)),
    ((8.6, 5.9), (7.6, 4.7)),

    # The embedding feeds both task heads
    ((6.6, 3.5), (6.1, 2.3)),
    ((7.6, 3.5), (8.4, 2.3)),
]
_draw_arrows(ax1, arrows, lw=2, color='darkblue')

# Temporal dimension indicator
ax1.text(1, 6, f'Temporal\nDimension\n(T={ARCH_SEQ_LEN} timesteps)', ha='center', va='center',
         fontsize=10, bbox=dict(boxstyle="round,pad=0.3", facecolor='lightyellow', alpha=0.8))

# ------------------------------------------------------------------------------
# Panel 2: HGT attention for one relation.
# The query comes from the target node type; key and value come from the
# source node type. softmax(Q·Kᵀ) weights the values, giving the messages
# that are aggregated into the target embedding.
# ------------------------------------------------------------------------------
ax2 = fig.add_subplot(2, 2, 3)
_prepare_canvas(ax2, (0, 8), (0, 6), 'HGT Attention Mechanism Detail', fontsize=14)

attention_components = [
    {'name': 'Query (Q)\nRouter (target)', 'pos': (0.5, 4.5), 'size': (1.5, 1), 'color': colors['attention']},
    {'name': 'Key (K)\nSwitch (source)', 'pos': (2.5, 4.5), 'size': (1.5, 1), 'color': colors['attention']},
    {'name': 'Value (V)\nSwitch (source)', 'pos': (4.5, 4.5), 'size': (1.5, 1), 'color': colors['attention']},
    {'name': 'Attention\nWeights', 'pos': (2.5, 2.5), 'size': (1.5, 1), 'color': colors['spatial']},
    {'name': 'Weighted\nMessage', 'pos': (5.5, 2.5), 'size': (1.5, 1), 'color': colors['spatial']},
    {'name': 'Aggregated\nEmbedding', 'pos': (3.5, 0.5), 'size': (1.5, 1), 'color': colors['output']}
]
_draw_boxes(ax2, attention_components, fontsize=8, linewidth=1)

attention_arrows = [
    ((2, 4.5), (2.5, 3.5)),        # Q -> attention weights
    ((3.25, 4.5), (3.25, 3.5)),    # K -> attention weights
    ((5.25, 4.5), (6.0, 3.5)),     # V -> weighted message
    ((4.1, 3.0), (5.4, 3.0)),      # attention weights -> weighted message
    ((5.5, 2.5), (4.5, 1.5)),      # weighted message -> aggregated embedding
]
_draw_arrows(ax2, attention_arrows, lw=1.5, color='darkred')

# ------------------------------------------------------------------------------
# Panel 3: multi-task objective built on the shared node embedding
# ------------------------------------------------------------------------------
ax3 = fig.add_subplot(2, 2, 4)
_prepare_canvas(ax3, (0, 8), (0, 6), 'Multi-task Learning Objectives', fontsize=14)

multitask_components = [
    {'name': f'Node Embedding\n[{ARCH_HIDDEN_DIM}D]', 'pos': (3.5, 4.6), 'size': (1.5, 0.9), 'color': colors['temporal']},
    {'name': 'CPU Prediction\nTask', 'pos': (1.5, 3.0), 'size': (1.5, 0.9), 'color': colors['output']},
    {'name': 'Device Type\nClassification', 'pos': (5, 3.0), 'size': (1.5, 0.9), 'color': colors['output']},
    {'name': 'MSE Loss\n(Failure Impact)', 'pos': (1.5, 1.5), 'size': (1.5, 0.9), 'color': '#FFCDD2'},
    {'name': 'CrossEntropy Loss\n(Device Type)', 'pos': (5, 1.5), 'size': (1.5, 0.9), 'color': '#FFCDD2'},
    # On its own row below the two losses so the boxes do not overlap
    {'name': f'Combined Loss\nL = L_cpu + {ARCH_DEVICE_LOSS_WEIGHT}×L_device', 'pos': (3, 0.2), 'size': (2, 0.7), 'color': '#F8BBD9'}
]
_draw_boxes(ax3, multitask_components, fontsize=8, linewidth=1)

multitask_arrows = [
    ((3.9, 4.5), (2.5, 4.0)),      # embedding -> CPU task
    ((4.6, 4.5), (5.5, 4.0)),      # embedding -> device task
    ((2.25, 2.9), (2.25, 2.5)),    # CPU task -> MSE loss
    ((5.75, 2.9), (5.75, 2.5)),    # device task -> cross-entropy loss
    ((2.25, 1.4), (3.2, 1.0)),     # MSE loss -> combined loss
    ((5.75, 1.4), (4.8, 1.0)),     # cross-entropy loss -> combined loss
]
_draw_arrows(ax3, multitask_arrows, lw=1.5, color='darkgreen')

plt.tight_layout()
plt.show()

# ------------------------------------------------------------------------------
# Text summary
# ------------------------------------------------------------------------------
print("\n" + "="*60)
print("ARCHITECTURE COMPONENTS BREAKDOWN")
print("="*60)

print("\n1. INPUT PROCESSING:")
print(f"   • Numeric Features: {ARCH_NUMERIC_DIM}D (CPU usage, Memory usage)")
print(f"   • Text Embeddings: {ARCH_TEXT_EMBED_DIM}D (Sentence-BERT encoded status messages)")
print(f"   • Total Input Dimension: {ARCH_INPUT_DIM}D per node")
print(f"   • Node Types: {ARCH_NUM_DEVICE_TYPES} ({', '.join(ARCH_DEVICE_TYPES)})")

print("\n2. SPATIAL PROCESSING (HGTConv):")
print(f"   • Input Channels: {ARCH_INPUT_DIM}D")
print(f"   • Hidden Channels: {ARCH_HIDDEN_DIM}D")
print(f"   • Attention Heads: {ARCH_NUM_HEADS}")
print("   • Heterogeneous Edge Types: Multiple (router-router, router-switch, etc.)")
print("   • Purpose: Learn spatial relationships between different device types")

print("\n3. TEMPORAL PROCESSING (GRU):")
print(f"   • Hidden Size: {ARCH_HIDDEN_DIM}D")
print(f"   • Sequence Length: {ARCH_SEQ_LEN} timesteps")
print(f"   • Separate GRU per device type: {ARCH_NUM_DEVICE_TYPES} GRUs")
print("   • Purpose: Capture temporal evolution of device states")

print("\n4. OUTPUT LAYERS:")
print("   Original Model:")
print(f"     - CPU Prediction: Linear({ARCH_HIDDEN_DIM} → 1)")
print("   Improved Multi-task Model:")
print(f"     - CPU Prediction: Linear({ARCH_HIDDEN_DIM} → 1)")
print(f"     - Device Classification: Linear({ARCH_HIDDEN_DIM} → {ARCH_NUM_DEVICE_TYPES})")

print("\n5. TRAINING OBJECTIVES:")
print("   Original Model:")
print("     - MSE Loss for CPU prediction")
print("   Improved Model:")
print(f"     - Combined Loss = MSE(CPU) + {ARCH_DEVICE_LOSS_WEIGHT} × CrossEntropy(Device)")
print("     - Encourages device type separation in embeddings")

print("\n6. KEY INNOVATIONS:")
print("   • Heterogeneous graph processing with different device types")
print("   • Temporal modeling with type-specific GRUs")
print("   • Multi-task learning for better representation learning")
print("   • Attention mechanism for learning device interactions")

# These metrics are hard-coded values recorded from an earlier training run;
# they are NOT recomputed here. Update them if the models are retrained.
# The relative gain is derived from them so the printout stays self-consistent.
recorded_metrics = {
    'silhouette_original': 0.73,
    'silhouette_improved': 0.85,
    'device_accuracy_pct': 100,
    'ari_original': 0.01,
    'ari_improved': 1.00,
}
silhouette_gain_pct = 100 * (
    recorded_metrics['silhouette_improved'] - recorded_metrics['silhouette_original']
) / recorded_metrics['silhouette_original']

print("\n7. PERFORMANCE METRICS (recorded from an earlier run; not recomputed in this cell):")
print(f"   • Original Model - Silhouette Score: {recorded_metrics['silhouette_original']:.2f}")
print(f"   • Improved Model - Silhouette Score: {recorded_metrics['silhouette_improved']:.2f} ({silhouette_gain_pct:+.1f}%)")
print(f"   • Device Classification Accuracy: {recorded_metrics['device_accuracy_pct']}%")
print(f"   • Adjusted Rand Index: {recorded_metrics['ari_original']:.2f} → {recorded_metrics['ari_improved']:.2f} (perfect clustering)")

print("\n" + "="*60)

## Mathematical Formulation

### 1. **Heterogeneous Graph Transformer (HGT) Layer**

The HGT layer processes heterogeneous graphs with multiple node and edge types. For each node type τ and edge type φ:

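#### **Attention Computation:**
```
For an edge j → i from a source node j (type τ_j) to a target node i (type τ_i)
along relation φ = (τ_j → τ_i):

Q_i = H^(l)[i] × W^Q_τ_i    (Query: projected with the target node's type)
K_j = H^(l)[j] × W^K_τ_j    (Key: projected with the source node's type)
V_j = H^(l)[j] × W^V_τ_j    (Value: projected with the source node's type)

Attention(i, j) = softmax over j ∈ N(i) of ( Q_i × W^ATT_φ × K_j^T / √d_k )
```

#### **Message Passing:**
```
Message_i = Σ_j∈N(i) Attention(i, j) × (V_j × W^MSG_φ)

H^(l+1)[τ_i] = A-Linear_τ_i(GELU(Message_i))   (+ a skip connection to H^(l)[τ_i] when the input and output widths match)
```

Where:
- `H^(l)[τ]` = embeddings for node type τ at layer l; `H^(l)[i]` = embedding of node i
- `W^Q_τ`, `W^K_τ`, `W^V_τ` = learnable type-specific weight matrices
- `A-Linear_τ` = type-specific output projection
- `W^ATT_φ`, `W^MSG_φ` = learnable relation-specific matrices for relation φ
- `N(i)` = incoming neighborhood of node i, across all relations
- `d_k` = dimension of key vectors per attention head; each head is computed separately and the results are concatenated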

### 2. **Temporal GRU Processing**

For each device type τ, the GRU processes temporal sequences:

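#### **GRU Equations:**
```
r_t = σ(W_r × [h_{t-1}, x_t])         (Reset gate)
z_t = σ(W_z × [h_{t-1}, x_t])         (Update gate)
h̃_t = tanh(W_h × [r_t ⊙ h_{t-1}, x_t]) (Candidate state)
h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t (Hidden state)
```
Bias terms are omitted above. `torch.nn.GRU` uses a closely related convention with two differences:
- The roles of z_t and 1 - z_t are swapped: h_t = (1 - z_t) ⊙ h̃_t + z_t ⊙ h_{t-1}.
- The reset gate is applied after the hidden-state projection: h̃_t = tanh(W_in x_t + r_t ⊙ (W_hn h_{t-1})).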

Where:
- `x_t` = spatial embedding at timestep t from HGT layer
- `h_t` = hidden state at timestep t
- `σ` = sigmoid function
- `⊙` = element-wise multiplication

### 3. **Multi-task Loss Function**

#### **Combined Objective:**
```
L_total = L_CPU + λ × L_device

L_CPU = MSE(ŷ_cpu, y_cpu) = (1/N) Σ(ŷ_cpu - y_cpu)²

L_device = CrossEntropy(ŷ_device, y_device)
         = -(1/N) Σ_i Σ_c y_device[i,c] × log(ŷ_device[i,c])
           (ŷ_device = softmax of the classifier logits; y_device = one-hot device type)

Where λ = 0.5 (device task weight)
```

### 4. **Model Complexity Analysis**

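#### **Parameter Count:**
Notation: |T| = number of node (device) types, |R| = number of relation (edge) types, d_in = input dim, d_h = hidden dim, h = attention heads.
- **HGT Layer**: O(|T| × (d_in × d_h + d_h²) + |R| × d_h²/h).
  - Each node type has its own Q/K/V/output projections.
  - Each relation has its own attention/message matrices of size (d_h/h × d_h/h) per head.
  - Splitting into heads does not multiply the parameter count.
- **GRU Layers**: O(|T| × 3 × 2 × d_h²) = O(|T| × d_h²). There is one GRU per device type (input width d_h); each of its 3 gates has input-to-hidden and hidden-to-hidden weights.
- **Output Layers**: O(d_h × (1 + |T|)), covering CPU regression plus device-type classification with one class per device type.

#### **Computational Complexity:**
Notation: |N| = total nodes, |E| = number of edges, T = sequence length.
- **Spatial**: O(|N| × d_in × d_h + |E| × d_h) per timestep: per-node projections plus per-edge attention and messages.
- **Temporal**: O(T × |N| × d_h²).
- **Total**: O(T × (|N| × d_in × d_h + |E| × d_h + |N| × d_h²)).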

## Key Advantages of This Architecture

### 1. **Heterogeneous Processing**
- **Type-specific Parameters**: Different learnable weights for each device type
- **Cross-type Interactions**: Can model router→switch, switch→firewall relationships
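- **Extensible**: A new device type only needs its own type-specific parameters (HGT projections and a GRU) and one more output class in the device-classification head. The overall architecture stays the same.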

### 2. **Temporal Modeling**
- **Memory Retention**: GRU maintains information about past device states
- **Type-specific Patterns**: Each device type learns its own temporal dynamics
- **Failure Propagation**: Can track how failures spread through the network over time

### 3. **Multi-task Learning Benefits**
- **Better Representations**: Device classification forces embeddings to be type-discriminative
- **Regularization**: The additional objective can reduce overfitting to the CPU prediction task.
- **Interpretability**: Clear separation between device types in embedding space

### 4. **Attention Mechanism**
- **Adaptive Weights**: Learns which connections are most important for each prediction
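- **Interpretable**: Attention weights can indicate which neighboring devices the model relies on most. This is a useful diagnostic, though attention is not a guaranteed explanation of influence.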
- **Dynamic**: Attention can change based on network state and failure conditions