# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import pandas as pd
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.preprocessing import StandardScaler
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt

# Select the compute device once at import time: CUDA when PyTorch reports
# it as available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class DataSprawlGNN(nn.Module):
    """Two-headed GCN: per-node sensitivity classifier plus per-edge sprawl scorer.

    Node features are projected to ``hidden_dim``, passed through three
    ``GCNConv`` layers, and the resulting node embeddings feed two linear heads:

    * ``sensitive_classifier`` -- two logits per node.
    * ``sprawl_predictor`` -- one logit per edge, computed from the
      concatenated embeddings of the edge's source and target nodes.

    Args:
        num_node_features: Width of ``data.x``.
        num_edge_features: Width of the edge attributes (sizes ``edge_encoder``).
        hidden_dim: Width of every hidden representation.
    """

    def __init__(self, num_node_features, num_edge_features, hidden_dim=128):
        super().__init__()
        # Node feature processing
        self.node_encoder = nn.Linear(num_node_features, hidden_dim)

        # Edge feature processing.
        # NOTE(review): this layer is kept so the parameter set (and therefore
        # the state_dict layout of saved checkpoints) stays unchanged, but
        # forward() does not feed edge features into either prediction head.
        self.edge_encoder = nn.Linear(num_edge_features, hidden_dim)

        # Graph convolutional layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # Prediction heads
        self.sensitive_classifier = nn.Linear(hidden_dim, 2)  # Binary classification for sensitive data
        self.sprawl_predictor = nn.Linear(hidden_dim * 2, 1)  # Edge scoring for sprawl pathways

        # Dropout for regularization (applied to the encoded node features only)
        self.dropout = nn.Dropout(0.3)

    def forward(self, data):
        """Compute node logits and edge logits for one graph.

        Args:
            data: Object exposing ``x`` (node feature matrix) and
                ``edge_index`` (2 x num_edges index tensor).

        Returns:
            Tuple ``(sensitive_pred, edge_scores)``: ``sensitive_pred`` holds
            two logits per node; ``edge_scores`` holds one logit per edge and
            is always one-dimensional.
        """
        x, edge_index = data.x, data.edge_index

        # Encode node features
        x = F.relu(self.node_encoder(x))
        x = self.dropout(x)

        # Graph convolutions (no activation after the last layer, so the
        # heads see unclipped embeddings)
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Sensitive data classification
        sensitive_pred = self.sensitive_classifier(x)

        # Sprawl pathway prediction (edge scoring)
        src, dst = edge_index
        edge_scores = self.sprawl_predictor(torch.cat([x[src], x[dst]], dim=1))

        # Squeeze only the trailing singleton dimension: a bare squeeze()
        # would also collapse the edge dimension when the graph has exactly
        # one edge, producing a 0-dim tensor that cannot be masked or ranked.
        return sensitive_pred, edge_scores.squeeze(-1)

def load_and_preprocess_data(node_file, edge_file):
    """Load node/edge CSV files and assemble a PyG ``Data`` graph.

    The node CSV must contain ``node_id`` and ``sensitive_label`` columns; the
    edge CSV must contain ``source``, ``target`` and ``sprawl_risk`` columns.
    Every remaining column is treated as a feature and standardised with a
    ``StandardScaler`` fitted on the full table.

    Edge endpoints are resolved through ``node_id`` to the row position of the
    node in the node CSV, so ids do not have to be ``0..N-1`` in file order.

    Args:
        node_file: Path (or anything ``pandas.read_csv`` accepts) of the node CSV.
        edge_file: Path (or anything ``pandas.read_csv`` accepts) of the edge CSV.

    Returns:
        A ``Data`` object with ``x``, ``edge_index``, ``edge_attr``, ``y``
        (node labels, long) and ``edge_y`` (edge labels, float).

    Raises:
        ValueError: If ``node_id`` values are not unique, or an edge refers to
            a ``source``/``target`` that is not a known ``node_id``.
    """
    # Load node data
    node_df = pd.read_csv(node_file)

    # Load edge data
    edge_df = pd.read_csv(edge_file)

    # Preprocess node features
    node_features = node_df.drop(['node_id', 'sensitive_label'], axis=1).values
    node_labels = node_df['sensitive_label'].values

    # Preprocess edge features
    edge_features = edge_df.drop(['source', 'target', 'sprawl_risk'], axis=1).values
    edge_labels = edge_df['sprawl_risk'].values

    # Normalize features
    node_scaler = StandardScaler()
    node_features = node_scaler.fit_transform(node_features)

    edge_scaler = StandardScaler()
    edge_features = edge_scaler.fit_transform(edge_features)

    # Create edge index: translate node ids into row positions so that
    # edge_index always addresses rows of the feature matrix ``x``.
    node_ids = pd.Index(node_df['node_id'])
    if not node_ids.is_unique:
        raise ValueError("node_id values must be unique")

    # get_indexer returns -1 for any label that is absent from node_ids.
    src_pos = node_ids.get_indexer(edge_df['source'])
    dst_pos = node_ids.get_indexer(edge_df['target'])
    unresolved = (src_pos < 0) | (dst_pos < 0)
    if unresolved.any():
        raise ValueError(
            f"{int(unresolved.sum())} edge(s) reference a source/target "
            "that is not present in the node_id column"
        )

    edge_index = torch.tensor(np.vstack([src_pos, dst_pos]), dtype=torch.long)

    # Create PyG Data object
    data = Data(
        x=torch.tensor(node_features, dtype=torch.float),
        edge_index=edge_index,
        edge_attr=torch.tensor(edge_features, dtype=torch.float),
        y=torch.tensor(node_labels, dtype=torch.long),
        edge_y=torch.tensor(edge_labels, dtype=torch.float)
    )

    return data

def _random_split_masks(num_items, device):
    """Build boolean train/val/test masks for a random 60/20/20 split."""
    order = torch.randperm(num_items)
    train_end = int(0.6 * num_items)
    val_end = int(0.8 * num_items)

    masks = []
    for chosen in (order[:train_end], order[train_end:val_end], order[val_end:]):
        mask = torch.zeros(num_items, dtype=torch.bool)
        mask[chosen] = True
        masks.append(mask.to(device))
    return tuple(masks)


def train_model(data, epochs=100, batch_size=32):
    """Train a ``DataSprawlGNN`` on one graph and return the best model.

    Nodes and edges are each split randomly 60/20/20 into train/val/test; the
    six masks are attached to ``data`` (``train_mask``, ``val_mask``,
    ``test_mask``, ``edge_train_mask``, ``edge_val_mask``, ``edge_test_mask``).
    Training is full-batch: every epoch runs one forward pass over the whole
    graph and optimises the sum of the node cross-entropy loss and the edge
    binary cross-entropy loss on the training masks.

    Side effects: writes ``best_model.pt`` whenever the validation loss
    improves, writes ``training_curves.png`` at the end, and prints one line
    per epoch.

    Args:
        data: Graph with ``x``, ``edge_index``, ``edge_attr``, ``y``, ``edge_y``.
        epochs: Number of training epochs.
        batch_size: Unused; the graph is processed as a single batch. Kept so
            existing callers that pass it keep working.

    Returns:
        The model carrying the weights of the epoch with the lowest validation
        loss (the final-epoch weights if the validation loss never improved).
    """
    # Keep the model and all masks on the same device as the graph tensors so
    # the function works whether or not the caller moved ``data`` beforehand.
    target_device = data.x.device

    num_nodes = data.x.size(0)
    num_edges = data.edge_index.size(1)

    # Random splits (60/20/20); nodes are drawn before edges.
    data.train_mask, data.val_mask, data.test_mask = _random_split_masks(
        num_nodes, target_device
    )
    data.edge_train_mask, data.edge_val_mask, data.edge_test_mask = _random_split_masks(
        num_edges, target_device
    )

    # Initialize model
    model = DataSprawlGNN(
        num_node_features=data.x.size(1),
        num_edge_features=data.edge_attr.size(1)
    ).to(target_device)

    # Loss functions and optimizer
    criterion_node = nn.CrossEntropyLoss()
    criterion_edge = nn.BCEWithLogitsLoss()
    optimizer = Adam(model.parameters(), lr=0.01)
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

    # Training loop
    best_val_loss = float('inf')
    best_state = None
    train_losses = []
    val_losses = []

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        # Forward pass
        sensitive_pred, edge_scores = model(data)

        # Compute losses
        node_loss = criterion_node(sensitive_pred[data.train_mask], data.y[data.train_mask])
        edge_loss = criterion_edge(edge_scores[data.edge_train_mask], data.edge_y[data.edge_train_mask])
        total_loss = node_loss + edge_loss

        # Backward pass
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        # Validation
        model.eval()
        with torch.no_grad():
            val_sensitive_pred, val_edge_scores = model(data)
            val_node_loss = criterion_node(val_sensitive_pred[data.val_mask], data.y[data.val_mask])
            val_edge_loss = criterion_edge(val_edge_scores[data.edge_val_mask], data.edge_y[data.edge_val_mask])
            val_total_loss = val_node_loss + val_edge_loss

        # Save losses for plotting
        train_loss_value = total_loss.item()
        val_loss_value = val_total_loss.item()
        train_losses.append(train_loss_value)
        val_losses.append(val_loss_value)

        # Checkpoint the best model so far. The weights are cloned because
        # state_dict() returns references that later optimizer steps overwrite.
        if val_loss_value < best_val_loss:
            best_val_loss = val_loss_value
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            torch.save(model.state_dict(), 'best_model.pt')

        print(f'Epoch {epoch+1}/{epochs} | Train Loss: {train_loss_value:.4f} | Val Loss: {val_loss_value:.4f}')

    # Plot training curves
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss Curves')
    plt.savefig('training_curves.png')
    plt.close()

    # Hand back the checkpointed weights rather than whatever the last epoch
    # happened to produce.
    if best_state is not None:
        model.load_state_dict(best_state)

    return model

def evaluate_model(model, data):
    """Report test-split metrics for both prediction heads.

    Runs the model once in eval mode without gradients, scores the node
    predictions on ``data.test_mask`` and the thresholded (sigmoid > 0.5) edge
    predictions on ``data.edge_test_mask``, prints both summaries and returns
    them.

    Args:
        model: Callable returning ``(node_logits, edge_logits)`` for ``data``.
        data: Graph carrying ``y``, ``edge_y``, ``test_mask`` and
            ``edge_test_mask``.

    Returns:
        ``{'node_metrics': {...}, 'edge_metrics': {...}}`` where each inner
        dict has ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """

    def _score(truth, guess):
        # Accuracy stays a tensor until it is reported; the scikit-learn
        # metrics receive CPU copies of the same two tensors.
        accuracy = (guess == truth).float().mean()
        truth_cpu, guess_cpu = truth.cpu(), guess.cpu()
        return (
            accuracy,
            precision_score(truth_cpu, guess_cpu),
            recall_score(truth_cpu, guess_cpu),
            f1_score(truth_cpu, guess_cpu),
        )

    def _as_report(accuracy, precision, recall, f1):
        return {
            'accuracy': accuracy.item(),
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

    model.eval()
    with torch.no_grad():
        node_logits, edge_logits = model(data)

        # Node classification metrics
        node_selector = data.test_mask
        node_guess = node_logits.argmax(dim=1)
        node_acc, node_precision, node_recall, node_f1 = _score(
            data.y[node_selector], node_guess[node_selector]
        )

        # Edge prediction metrics
        edge_selector = data.edge_test_mask
        edge_guess = (torch.sigmoid(edge_logits) > 0.5).float()
        edge_acc, edge_precision, edge_recall, edge_f1 = _score(
            data.edge_y[edge_selector], edge_guess[edge_selector]
        )

        print("\nNode Classification Results:")
        print(f"Accuracy: {node_acc:.4f} | Precision: {node_precision:.4f} | Recall: {node_recall:.4f} | F1: {node_f1:.4f}")

        print("\nEdge Prediction Results:")
        print(f"Accuracy: {edge_acc:.4f} | Precision: {edge_precision:.4f} | Recall: {edge_recall:.4f} | F1: {edge_f1:.4f}")

        return {
            'node_metrics': _as_report(node_acc, node_precision, node_recall, node_f1),
            'edge_metrics': _as_report(edge_acc, edge_precision, edge_recall, edge_f1)
        }

def generate_mitigation_recommendations(model, data, node_df, edge_df, top_k=10):
    """Print and return the highest-scoring sensitive nodes and risky edges.

    Node scores are the softmax probability of class 1; edge scores are the
    sigmoid of the edge logits. Indices in the output are row positions: they
    are looked up in ``node_df`` / ``edge_df`` with ``iloc``.

    Args:
        model: Callable returning ``(node_logits, edge_logits)`` for ``data``.
        data: Graph passed straight through to ``model``.
        node_df: Node table, row-aligned with the node logits.
        edge_df: Edge table, row-aligned with the edge logits.
        top_k: Maximum number of nodes and of edges to report. Fewer are
            reported when the graph is smaller than ``top_k``.

    Returns:
        ``{'sensitive_nodes': [(index, score), ...],
        'risky_edges': [(index, score), ...]}`` sorted by descending score.

    Raises:
        ValueError: If ``top_k`` is negative.
    """
    if top_k < 0:
        raise ValueError("top_k must be non-negative")

    model.eval()
    with torch.no_grad():
        sensitive_pred, edge_scores = model(data)

        # Get top sensitive nodes. k is clamped because torch.topk raises when
        # asked for more elements than the tensor holds.
        sensitive_probs = torch.softmax(sensitive_pred, dim=1)[:, 1]
        top_sensitive_nodes = torch.topk(
            sensitive_probs, k=min(top_k, sensitive_probs.numel())
        )

        # Get top risky edges. reshape(-1) guards against a 0-dim score tensor
        # (a single-edge graph whose score lost its edge dimension).
        edge_risks = torch.sigmoid(edge_scores).reshape(-1)
        top_risky_edges = torch.topk(edge_risks, k=min(top_k, edge_risks.numel()))

        # Convert once to plain Python numbers for both printing and returning.
        ranked_nodes = list(zip(top_sensitive_nodes.indices.tolist(),
                                top_sensitive_nodes.values.tolist()))
        ranked_edges = list(zip(top_risky_edges.indices.tolist(),
                                top_risky_edges.values.tolist()))

        # Descriptive columns are optional in the report: a missing column is
        # shown as N/A instead of aborting the whole listing.
        missing = 'N/A'

        print("\nTop Sensitive Data Assets:")
        for rank, (node_idx, score) in enumerate(ranked_nodes, start=1):
            node_info = node_df.iloc[node_idx]
            print(f"{rank}. Node {node_idx} (Score: {score:.4f})")
            print(f"   Type: {node_info.get('type', missing)}, Size: {node_info.get('size', missing)}, Last Accessed: {node_info.get('last_accessed', missing)}")

        print("\nTop Risky Sprawl Pathways:")
        for rank, (edge_idx, score) in enumerate(ranked_edges, start=1):
            edge_info = edge_df.iloc[edge_idx]
            print(f"{rank}. Edge {edge_idx} (Score: {score:.4f})")
            print(f"   Source: {edge_info.get('source', missing)} -> Target: {edge_info.get('target', missing)}")
            print(f"   Access Frequency: {edge_info.get('access_frequency', missing)}, Sharing Level: {edge_info.get('sharing_level', missing)}")

        return {
            'sensitive_nodes': ranked_nodes,
            'risky_edges': ranked_edges
        }

# Script entry point: only prints usage guidance until dataset files exist.
if __name__ == "__main__":
    # The dataset is supplied by the user, so describe what is expected
    # instead of running the pipeline.
    for banner_line in (
        "Please provide the paths to your node and edge CSV files when ready.",
        "Expected format:",
        "Node CSV: node_id, [features...], sensitive_label",
        "Edge CSV: source, target, [features...], sprawl_risk",
    ):
        print(banner_line)

    # Example usage (uncomment when files are available):
    """
    node_file = "nodes.csv"
    edge_file = "edges.csv"
    
    # Load and preprocess data
    data = load_and_preprocess_data(node_file, edge_file)
    data = data.to(device)
    
    # Load the node and edge DataFrames for recommendations
    node_df = pd.read_csv(node_file)
    edge_df = pd.read_csv(edge_file)
    
    # Train model
    model = train_model(data)
    
    # Evaluate
    metrics = evaluate_model(model, data)
    
    # Generate recommendations
    recommendations = generate_mitigation_recommendations(model, data, node_df, edge_df)
    """