# 🛡️ Malware Classification Using Graph Neural Networks on Control Flow Graphs

This notebook implements a complete pipeline for binary malware classification using:
- **Control Flow Graph (CFG) extraction** with angr
- **Graph Neural Networks (GNN)** for classification
- **PyTorch Geometric** for graph deep learning

## 📋 Table of Contents
1. [Setup & Installation](#setup)
2. [Data Upload & Preparation](#data)
3. [CFG Extraction](#extraction)
4. [Feature Engineering](#features)
5. [Model Definition](#model)
6. [Training](#training)
7. [Evaluation & Visualization](#evaluation)

---

## ⚙️ Runtime Configuration

**Important:** Enable GPU for faster training!
- Go to: **Runtime → Change runtime type → GPU**

## 1. Setup & Installation

Install all required dependencies.

In [None]:
%%capture
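# Install PyTorch and PyTorch Geometric
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install torch-geometric
# Optional compiled extensions. The wheel index below targets torch 2.0.0 + CUDA 11.8,
# so adjust the URL if a different torch version was installed above.
!pip install pyg-lib torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cu118.html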

# Install other dependencies
!pip install angr networkx scikit-learn pandas matplotlib seaborn pyyaml tqdm

In [None]:
# Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader  # DataLoader lives in torch_geometric.loader since PyG 2.0
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_max_pool
from torch_geometric.utils import from_networkx

import angr
import networkx as nx
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    roc_curve,
    auc,
    precision_recall_curve,
    average_precision_score
)

import os
import json
import hashlib
import math
import warnings
from pathlib import Path
from tqdm.auto import tqdm
from datetime import datetime
from google.colab import files

warnings.filterwarnings('ignore')

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Using device: {device}")
if torch.cuda.is_available():
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

### Configuration

In [None]:
# Configuration
CONFIG = {
    'seed': 42,
    'model_type': 'gcn',  # Options handled by create_model(): gcn, gat
    'num_features': 10,
    'hidden_channels': 64,
    'num_classes': 2,
    'dropout': 0.5,
    'pooling': 'mean',
    'epochs': 100,
    'batch_size': 32,
    'learning_rate': 0.001,
    'weight_decay': 0.0005,
    'early_stopping_patience': 15,
    'train_ratio': 0.7,
    'val_ratio': 0.15,
    'test_ratio': 0.15,
    'use_class_weights': True
}

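# Set random seeds for reproducibility
torch.manual_seed(CONFIG['seed'])
torch.cuda.manual_seed_all(CONFIG['seed'])  # silently ignored when CUDA is unavailable
np.random.seed(CONFIG['seed'])

print("✅ Configuration loaded")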
print(json.dumps(CONFIG, indent=2))
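
The split ratios and the model type in `CONFIG` are easy to mistype, and the notebook does not check them. A minimal sanity check, assuming the keys shown above; `validate_config` and `SUPPORTED_MODELS` are hypothetical names that do not exist in the notebook:

```python
import math

# Hypothetical helper: the notebook itself does not validate CONFIG.
SUPPORTED_MODELS = {"gcn", "gat"}  # the two types handled by create_model()


def validate_config(config):
    """Return a list of human-readable problems (empty if none were found)."""
    problems = []
    ratio_sum = config["train_ratio"] + config["val_ratio"] + config["test_ratio"]
    if not math.isclose(ratio_sum, 1.0, abs_tol=1e-9):
        problems.append(f"split ratios sum to {ratio_sum:.3f}, expected 1.0")
    if config["model_type"] not in SUPPORTED_MODELS:
        problems.append(f"unsupported model_type: {config['model_type']!r}")
    if not 0.0 <= config["dropout"] < 1.0:
        problems.append("dropout must be in [0, 1)")
    return problems


example = {"train_ratio": 0.7, "val_ratio": 0.15, "test_ratio": 0.15,
           "model_type": "gcn", "dropout": 0.5}
print(validate_config(example))  # no problems expected for these values
```

Calling `validate_config(CONFIG)` right after the configuration cell would surface a bad value before any CFG extraction time is spent.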

## 2. Data Upload & Preparation

### Option A: Upload Binary Files

Upload your executable files (benign and malware samples).

In [None]:
# Create directory structure
!mkdir -p data/raw/benign data/raw/malware data/processed

print("📁 Directory structure created:")
print("   - data/raw/benign/    (place benign executables here)")
print("   - data/raw/malware/   (place malware executables here)")
print("   - data/processed/     (processed CFG files)")

In [None]:
# Upload binary files
print("📤 Upload your binary files")
print("   1. First, upload BENIGN executables")
print("   2. Then, upload MALWARE executables")
print("\n⚠️  Make sure to organize them correctly!\n")

# Uncomment the section you want to upload

# Upload benign files
# print("Uploading BENIGN files...")
# uploaded = files.upload()
# for filename in uploaded.keys():
#     !mv "{filename}" data/raw/benign/
# print(f"✅ Uploaded {len(uploaded)} benign files\n")

# Upload malware files
# print("Uploading MALWARE files...")
# uploaded = files.upload()
# for filename in uploaded.keys():
#     !mv "{filename}" data/raw/malware/
# print(f"✅ Uploaded {len(uploaded)} malware files")

print("\n💡 TIP: You can also mount Google Drive and use files from there!")

### Option B: Mount Google Drive

If you have files in Google Drive, mount it here.

In [None]:
# Mount Google Drive (optional)
# from google.colab import drive
# drive.mount('/content/drive')

# # Copy files from Drive
# !cp /content/drive/MyDrive/your_benign_folder/* data/raw/benign/
# !cp /content/drive/MyDrive/your_malware_folder/* data/raw/malware/

In [None]:
# Check uploaded files
benign_files = !ls data/raw/benign/ 2>/dev/null | wc -l
malware_files = !ls data/raw/malware/ 2>/dev/null | wc -l

benign_count = int(benign_files[0]) if benign_files else 0
malware_count = int(malware_files[0]) if malware_files else 0

print("📊 Dataset Summary:")
print(f"   Benign samples:  {benign_count}")
print(f"   Malware samples: {malware_count}")
print(f"   Total:           {benign_count + malware_count}")

if benign_count == 0 and malware_count == 0:
    print("\n⚠️  No files found! Please upload binary files first.")

## 3. CFG Extraction

Extract Control Flow Graphs from binary executables using angr.

In [None]:
def get_file_hash(file_path):
    """Calculate SHA256 hash of a file"""
    sha256_hash = hashlib.sha256()
    with open(file_path, "rb") as f:
        for byte_block in iter(lambda: f.read(4096), b""):
            sha256_hash.update(byte_block)
    return sha256_hash.hexdigest()


def strip_none_attributes(G):
    """Remove None attributes from graph (required for GraphML export)"""
    for node, attrs in list(G.nodes(data=True)):
        for k, v in list(attrs.items()):
            if v is None:
                del attrs[k]
    
    for u, v, attrs in list(G.edges(data=True)):
        for k, val in list(attrs.items()):
            if val is None:
                del attrs[k]


def extract_cfg_from_binary(binary_path, label):
    """
    Extract CFG from a single binary
    
    Args:
        binary_path: Path to the binary file
        label: 0 for benign, 1 for malware
    
    Returns:
        tuple: (networkx graph, metadata dict)
    """
    file_hash = get_file_hash(binary_path)
    
    metadata = {
        'filename': os.path.basename(binary_path),
        'file_hash': file_hash,
        'label': label,
        'status': 'failed',
        'num_nodes': 0,
        'num_edges': 0
    }
    
    try:
        # Load binary with angr
        proj = angr.Project(
            binary_path,
            load_options={'auto_load_libs': False}
        )
        
        # Generate CFG
        cfg = proj.analyses.CFGFast(normalize=True)
        G = cfg.graph
        
        # Strip None attributes
        strip_none_attributes(G)
        
        metadata['status'] = 'success'
        metadata['num_nodes'] = G.number_of_nodes()
        metadata['num_edges'] = G.number_of_edges()
        
        return G, metadata
    
    except Exception as e:
        metadata['error'] = str(e)
        return None, metadata


print("✅ CFG extraction functions defined")

In [None]:
# Extract CFGs from all binaries
print("🔍 Extracting CFGs from binaries...\n")

all_graphs = []
all_metadata = []

# Process benign files
benign_dir = 'data/raw/benign'
if os.path.exists(benign_dir):
    benign_files = [f for f in os.listdir(benign_dir) if os.path.isfile(os.path.join(benign_dir, f))]
    print(f"Processing {len(benign_files)} benign files...")
    
    for filename in tqdm(benign_files, desc="Benign"):
        file_path = os.path.join(benign_dir, filename)
        graph, metadata = extract_cfg_from_binary(file_path, label=0)
        
        if graph is not None:
            all_graphs.append((graph, 0))
        all_metadata.append(metadata)

# Process malware files
malware_dir = 'data/raw/malware'
if os.path.exists(malware_dir):
    malware_files = [f for f in os.listdir(malware_dir) if os.path.isfile(os.path.join(malware_dir, f))]
    print(f"\nProcessing {len(malware_files)} malware files...")
    
    for filename in tqdm(malware_files, desc="Malware"):
        file_path = os.path.join(malware_dir, filename)
        graph, metadata = extract_cfg_from_binary(file_path, label=1)
        
        if graph is not None:
            all_graphs.append((graph, 1))
        all_metadata.append(metadata)

# Statistics
successful = sum(1 for m in all_metadata if m['status'] == 'success')
failed = len(all_metadata) - successful

print("\n✅ CFG Extraction Complete:")
print(f"   Successful: {successful}/{len(all_metadata)}")
print(f"   Failed:     {failed}/{len(all_metadata)}")
print(f"   Total graphs: {len(all_graphs)}")
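
When many binaries fail to load, it helps to see which errors dominate. A small sketch that groups the failed records by the `error` string stored in each metadata dict; `summarize_failures` is a hypothetical helper and the records below are made up for illustration:

```python
from collections import Counter


def summarize_failures(metadata_records):
    """Count failed extractions, grouped by their error message."""
    return Counter(
        m.get("error", "unknown error")
        for m in metadata_records
        if m["status"] != "success"
    )


# Made-up records in the same shape as the metadata dicts built above.
records = [
    {"filename": "a.exe", "status": "success"},
    {"filename": "b.exe", "status": "failed", "error": "unsupported format"},
    {"filename": "c.exe", "status": "failed", "error": "unsupported format"},
    {"filename": "d.exe", "status": "failed"},
]
for error, count in summarize_failures(records).most_common():
    print(f"{count:3d}  {error}")
```

In the notebook, the same function could be called on `all_metadata` once extraction has finished.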

## 4. Feature Engineering

Extract node features from CFGs and convert to PyTorch Geometric format.

In [None]:
def extract_node_features(node_id, graph, node_attrs):
    """
    Extract features for a single node (basic block)
    
    Returns:
        list: Feature vector [10 features]
    """
    features = []
    
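    # Feature 1: Basic-block size. angr CFG nodes are CFGNode objects, so read
    # their instruction list (falling back to the byte size) when available;
    # string IDs ending in "[N]" (e.g. after a GraphML round trip) are parsed.
    size = 1
    if isinstance(node_id, str) and '[' in node_id and ']' in node_id:
        try:
            size = int(node_id.split('[')[-1].split(']')[0])
        except ValueError:
            size = 1
    else:
        insn_addrs = getattr(node_id, 'instruction_addrs', None)
        size = len(insn_addrs) if insn_addrs else (getattr(node_id, 'size', None) or 1)
    features.append(float(size))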
    
    # Feature 2: In-degree
    in_degree = graph.in_degree(node_id)
    features.append(float(in_degree))
    
    # Feature 3: Out-degree
    out_degree = graph.out_degree(node_id)
    features.append(float(out_degree))
    
    # Feature 4: Is entry node
    features.append(1.0 if in_degree == 0 else 0.0)
    
    # Feature 5: Is exit node
    features.append(1.0 if out_degree == 0 else 0.0)
    
    # Feature 6: Is hub node
    features.append(1.0 if (in_degree > 2 and out_degree > 2) else 0.0)
    
    # Feature 7: Degree ratio
    degree_ratio = float(out_degree) / (float(in_degree) + 1.0)
    features.append(degree_ratio)
    
    # Feature 8: Is branching node
    features.append(1.0 if out_degree > 1 else 0.0)
    
    # Feature 9: Is merge node
    features.append(1.0 if in_degree > 1 else 0.0)
    
    # Feature 10: Log of size
    log_size = math.log(size + 1)
    features.append(log_size)
    
    return features


def cfg_to_pyg_data(graph, label):
    """
    Convert NetworkX CFG to PyTorch Geometric Data object
    
    Args:
        graph: NetworkX graph
        label: 0 for benign, 1 for malware
    
    Returns:
        Data: PyTorch Geometric Data object
    """
    try:
        if graph.number_of_nodes() == 0:
            return None
        
        # Create integer node mapping
        node_list = list(graph.nodes())
        node_to_idx = {node: idx for idx, node in enumerate(node_list)}
        
        # Relabel nodes
        G = nx.relabel_nodes(graph, node_to_idx)
        
        # Extract node features. Degrees are read from the original graph because
        # the original node objects are not keys of the relabeled graph G.
        node_features = []
        for original_node in node_list:
            attrs = graph.nodes[original_node]
            features = extract_node_features(original_node, graph, attrs)
            node_features.append(features)
        
        # Convert to tensors
        x = torch.tensor(node_features, dtype=torch.float)
        
        # Create edge index
        edge_list = list(G.edges())
        if len(edge_list) == 0:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        else:
            edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
        
        # Create label
        y = torch.tensor([label], dtype=torch.long)
        
        # Create PyG Data object
        data = Data(x=x, edge_index=edge_index, y=y)
        data.num_nodes = len(node_list)
        
        return data
    
    except Exception as e:
        print(f"Error processing graph: {e}")
        return None


print("✅ Feature extraction functions defined")

In [None]:
# Convert all graphs to PyG Data objects
print("🔧 Converting CFGs to PyTorch Geometric format...\n")

dataset = []
labels = []

for graph, label in tqdm(all_graphs, desc="Converting"):
    data = cfg_to_pyg_data(graph, label)
    if data is not None:
        dataset.append(data)
        labels.append(label)

print("\n✅ Conversion Complete:")
print(f"   Total samples: {len(dataset)}")
print(f"   Benign:  {labels.count(0)}")
print(f"   Malware: {labels.count(1)}")

if len(dataset) > 0:
    print("\n📊 Sample statistics:")
    print(f"   Features per node: {dataset[0].x.shape[1]}")
    print(f"   Average nodes per graph: {np.mean([d.num_nodes for d in dataset]):.1f}")
    print(f"   Min nodes: {min([d.num_nodes for d in dataset])}")
    print(f"   Max nodes: {max([d.num_nodes for d in dataset])}")
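
To make the structural features concrete, here is a dependency-free sketch that computes a subset of them (degrees plus the entry, exit, branch, and merge flags) for a toy if/else "diamond" CFG. It works on a plain edge list rather than a NetworkX graph, and `degree_features` is a hypothetical name used only for this illustration:

```python
from collections import defaultdict


def degree_features(edges):
    """Map each node to [in_deg, out_deg, is_entry, is_exit, is_branch, is_merge]."""
    in_deg, out_deg = defaultdict(int), defaultdict(int)
    nodes = set()
    for src, dst in edges:
        out_deg[src] += 1
        in_deg[dst] += 1
        nodes.update((src, dst))
    return {
        n: [
            float(in_deg[n]),
            float(out_deg[n]),
            1.0 if in_deg[n] == 0 else 0.0,   # entry: no predecessors
            1.0 if out_deg[n] == 0 else 0.0,  # exit: no successors
            1.0 if out_deg[n] > 1 else 0.0,   # branching block
            1.0 if in_deg[n] > 1 else 0.0,    # merge block
        ]
        for n in sorted(nodes)
    }


# Toy if/else diamond: A branches to B and C, which merge again in D.
diamond = [("A", "B"), ("A", "C"), ("B", "D"), ("C", "D")]
for node, feats in degree_features(diamond).items():
    print(node, feats)
```

Block A is flagged as an entry and a branch, D as an exit and a merge, and B and C carry no flags, which matches how `extract_node_features()` defines those features.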

## 5. Model Definition

Define Graph Neural Network architectures for malware classification.
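
For orientation, the GCN variant can be summarized in two formulas: the propagation rule that `GCNConv` applies in each layer, and the mean readout that turns node embeddings into one vector per graph. The notation is ours and does not appear in the notebook:

```latex
\begin{align*}
X' &= \hat{D}^{-1/2}\,\hat{A}\,\hat{D}^{-1/2}\,X\,\Theta,
  \qquad \hat{A} = A + I \\
h_G &= \frac{1}{|V_G|} \sum_{v \in V_G} x'_v
\end{align*}
```

Here $A$ is the adjacency matrix implied by `edge_index`, $\hat{D}$ is the diagonal degree matrix of $\hat{A}$, $\Theta$ is the layer's weight matrix, and $V_G$ is the set of basic blocks of graph $G$. The model applies ReLU and dropout after each layer and feeds $h_G$ to a linear layer followed by a log-softmax. Because the CFG edges are directed and the notebook does not symmetrize them, messages travel only along the edge direction.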

In [None]:
class MalwareGCN(nn.Module):
    """
    Graph Convolutional Network for Malware Classification
    """
    def __init__(self, num_node_features, hidden_channels=64, num_classes=2,
                 dropout=0.5, pooling='mean'):
        super(MalwareGCN, self).__init__()
        
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.fc = nn.Linear(hidden_channels, num_classes)
        
        self.dropout = dropout
        self.pooling = pooling
    
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # First GCN layer
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Second GCN layer
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Global pooling
        if self.pooling == 'mean':
            x = global_mean_pool(x, batch)
        else:
            x = global_max_pool(x, batch)
        
        # Classification
        x = self.fc(x)
        return F.log_softmax(x, dim=1)


class MalwareGAT(nn.Module):
    """
    Graph Attention Network for Malware Classification
    """
    def __init__(self, num_node_features, hidden_channels=64, num_classes=2,
                 heads=4, dropout=0.5, pooling='mean'):
        super(MalwareGAT, self).__init__()
        
        self.conv1 = GATConv(num_node_features, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels * heads, hidden_channels, heads=1)
        self.fc = nn.Linear(hidden_channels, num_classes)
        
        self.dropout = dropout
        self.pooling = pooling
    
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # First GAT layer
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Second GAT layer
        x = self.conv2(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        # Global pooling
        if self.pooling == 'mean':
            x = global_mean_pool(x, batch)
        else:
            x = global_max_pool(x, batch)
        
        # Classification
        x = self.fc(x)
        return F.log_softmax(x, dim=1)


def create_model(config):
    """Factory function to create model"""
    if config['model_type'] == 'gcn':
        return MalwareGCN(
            num_node_features=config['num_features'],
            hidden_channels=config['hidden_channels'],
            num_classes=config['num_classes'],
            dropout=config['dropout'],
            pooling=config['pooling']
        )
    elif config['model_type'] == 'gat':
        return MalwareGAT(
            num_node_features=config['num_features'],
            hidden_channels=config['hidden_channels'],
            num_classes=config['num_classes'],
            dropout=config['dropout'],
            pooling=config['pooling']
        )
    else:
        raise ValueError(f"Unknown model type: {config['model_type']}")


print("✅ Model architectures defined")

### Data Splitting

Split dataset into train, validation, and test sets.
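
The split is done in two stages: first train versus a held-out pool, then that pool is divided into validation and test. A rough arithmetic sketch of the resulting sizes, assuming the default ratios; `split_sizes` is a hypothetical helper, and the exact rounding in practice is decided by scikit-learn:

```python
def split_sizes(n, val_ratio=0.15, test_ratio=0.15):
    """Approximate (train, val, test) sizes for the two-stage split."""
    holdout_fraction = val_ratio + test_ratio
    # Share of the held-out pool that becomes the test set in stage two.
    test_share = 1 - val_ratio / holdout_fraction
    n_holdout = round(n * holdout_fraction)
    n_test = round(n_holdout * test_share)
    return n - n_holdout, n_holdout - n_test, n_test


print(split_sizes(1000))  # (700, 150, 150)
```

With equal validation and test ratios, the second stage is simply a 50/50 split of the held-out 30%.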

In [None]:
# Split dataset
if len(dataset) > 0:
    indices = list(range(len(dataset)))
    
    # First split: train vs (val + test)
    train_indices, temp_indices = train_test_split(
        indices,
        test_size=(CONFIG['val_ratio'] + CONFIG['test_ratio']),
        stratify=labels,
        random_state=CONFIG['seed']
    )
    
    # Second split: val vs test
    temp_labels = [labels[i] for i in temp_indices]
    val_size = CONFIG['val_ratio'] / (CONFIG['val_ratio'] + CONFIG['test_ratio'])
    
    val_indices, test_indices = train_test_split(
        temp_indices,
        test_size=(1 - val_size),
        stratify=temp_labels,
        random_state=CONFIG['seed']
    )
    
    # Create subsets
    train_dataset = [dataset[i] for i in train_indices]
    val_dataset = [dataset[i] for i in val_indices]
    test_dataset = [dataset[i] for i in test_indices]
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=CONFIG['batch_size'], shuffle=False)
    
    print("✅ Data split complete:")
    print(f"   Train: {len(train_dataset)} samples")
    print(f"   Val:   {len(val_dataset)} samples")
    print(f"   Test:  {len(test_dataset)} samples")
    
    # Calculate class weights from the training labels only, so the
    # validation and test label distributions do not influence training
    if CONFIG['use_class_weights']:
        train_labels = [labels[i] for i in train_indices]
        num_benign = train_labels.count(0)
        num_malware = train_labels.count(1)
        total = len(train_labels)
        
        weight_benign = total / (2 * num_benign) if num_benign > 0 else 1.0
        weight_malware = total / (2 * num_malware) if num_malware > 0 else 1.0
        
        class_weights = torch.tensor([weight_benign, weight_malware], dtype=torch.float).to(device)
        print(f"\n   Class weights: [{weight_benign:.3f}, {weight_malware:.3f}]")
    else:
        class_weights = None
else:
    print("❌ No data available for splitting!")
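
The class-weight formula in the cell above is `total / (num_classes * count)`. A worked example on a hypothetical 80/20 dataset shows its effect; `balanced_class_weights` is an illustrative name, not a function from the notebook:

```python
def balanced_class_weights(counts):
    """weight_c = total / (num_classes * count_c); 1.0 for an empty class."""
    total = sum(counts)
    return [total / (len(counts) * c) if c > 0 else 1.0 for c in counts]


# Hypothetical imbalanced dataset: 80 benign samples, 20 malware samples.
weights = balanced_class_weights([80, 20])
print(weights)  # [0.625, 2.5]
```

Each class then contributes the same total weight to the loss (80 × 0.625 = 20 × 2.5 = 50), so the minority class is not drowned out.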

## 6. Training

Train the GNN model.
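
Since the models return log-probabilities, the training functions use `F.nll_loss`. With class weights and the default mean reduction, the loss over a batch of $B$ graphs is a weighted average; the symbols below are our notation:

```latex
\mathcal{L} = -\,\frac{\sum_{i=1}^{B} w_{y_i}\,\log p_{i,\,y_i}}{\sum_{i=1}^{B} w_{y_i}}
```

Here $p_{i,\,y_i}$ is the predicted probability of the true class $y_i$ for graph $i$, and $w_{y_i}$ is that class's weight. When `class_weights` is `None`, every $w$ equals 1 and the expression reduces to the plain mean negative log-likelihood.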

In [None]:
def train_epoch(model, train_loader, optimizer, device, class_weights=None):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        
        out = model(data)
        
        if class_weights is not None:
            loss = F.nll_loss(out, data.y, weight=class_weights)
        else:
            loss = F.nll_loss(out, data.y)
        
        loss.backward()
        optimizer.step()
        
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
        total += data.y.size(0)
        total_loss += loss.item()
    
    return total_loss / len(train_loader), correct / total


@torch.no_grad()
def evaluate(model, loader, device, class_weights=None):
    """Evaluate the model"""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    all_preds = []
    all_labels = []
    all_probs = []
    
    for data in loader:
        data = data.to(device)
        out = model(data)
        
        if class_weights is not None:
            loss = F.nll_loss(out, data.y, weight=class_weights)
        else:
            loss = F.nll_loss(out, data.y)
        
        pred = out.argmax(dim=1)
        probs = torch.exp(out)
        
        all_preds.extend(pred.cpu().numpy())
        all_labels.extend(data.y.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())
        
        correct += (pred == data.y).sum().item()
        total += data.y.size(0)
        total_loss += loss.item()
    
    return (total_loss / len(loader), correct / total, 
            np.array(all_preds), np.array(all_labels), np.array(all_probs))


class EarlyStopping:
    """Early stopping to prevent overfitting"""
    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
    
    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0


print("✅ Training functions defined")
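
The interaction between `patience` and `min_delta` is easiest to see on a short trace. The function below mirrors the decision logic of `EarlyStopping` as a standalone sketch; `stopping_epoch` and the loss values are made up for illustration:

```python
def stopping_epoch(val_losses, patience=3, min_delta=0.001):
    """Return the 1-based epoch at which training would stop, or None."""
    best, counter = None, 0
    for epoch, loss in enumerate(val_losses, start=1):
        if best is None:
            best = loss
        elif loss > best - min_delta:  # no improvement of at least min_delta
            counter += 1
            if counter >= patience:
                return epoch
        else:                          # real improvement: reset the counter
            best, counter = loss, 0
    return None


losses = [0.70, 0.60, 0.61, 0.5995, 0.62, 0.55]
print(stopping_epoch(losses))  # 5
```

Epoch 4 improves on the best loss by only 0.0005, which is below `min_delta`, so it still counts toward the patience budget and training stops at epoch 5.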

In [None]:
# Train the model
if len(dataset) > 0:
    print("🚀 Starting training...\n")
    
    # Create model
    model = create_model(CONFIG).to(device)
    
    # Count parameters
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"📊 Model: {CONFIG['model_type'].upper()}")
    print(f"   Parameters: {num_params:,}\n")
    
    # Optimizer and scheduler
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay']
    )
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )
    
    # Early stopping
    early_stopping = EarlyStopping(patience=CONFIG['early_stopping_patience'])
    
    # Training history
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }
    
    best_val_acc = 0.0
    best_model_state = None
    
    # Training loop
    for epoch in range(CONFIG['epochs']):
        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, device, class_weights
        )
        
        # Validate
        val_loss, val_acc, _, _, _ = evaluate(
            model, val_loader, device, class_weights
        )
        
        # Update scheduler
        scheduler.step(val_loss)
        
        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
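        # Save best model. Tensors are cloned because state_dict().copy() is a
        # shallow copy whose weights would keep changing as training continues.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}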
        
        # Print progress
        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Epoch {epoch+1:3d}/{CONFIG['epochs']} | "
                  f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
        
        # Early stopping
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print(f"\n⏹️  Early stopping at epoch {epoch+1}")
            break
    
    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    
    print("\n✅ Training complete!")
    print(f"   Best validation accuracy: {best_val_acc:.4f}")
else:
    print("❌ No data available for training!")

## 7. Evaluation & Visualization

Evaluate the trained model and visualize results.

In [None]:
# Evaluate on test set
if len(dataset) > 0:
    print("📊 Evaluating on test set...\n")
    
    test_loss, test_acc, test_preds, test_labels, test_probs = evaluate(
        model, test_loader, device, class_weights
    )
    
    print(f"Test Results:")
    print(f"  Loss:     {test_loss:.4f}")
    print(f"  Accuracy: {test_acc:.4f}")
    
    # Classification report
    print("\n📋 Classification Report:\n")
    print(classification_report(
        test_labels, test_preds,
        target_names=['Benign', 'Malware'],
        digits=4
    ))

In [None]:
# Plot training history
if len(dataset) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    
    epochs = range(1, len(history['train_loss']) + 1)
    
    # Loss plot
    axes[0].plot(epochs, history['train_loss'], 'b-', label='Training Loss', linewidth=2)
    axes[0].plot(epochs, history['val_loss'], 'r-', label='Validation Loss', linewidth=2)
    axes[0].set_xlabel('Epoch', fontsize=12)
    axes[0].set_ylabel('Loss', fontsize=12)
    axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    axes[0].legend(fontsize=10)
    axes[0].grid(alpha=0.3)
    
    # Accuracy plot
    axes[1].plot(epochs, history['train_acc'], 'b-', label='Training Accuracy', linewidth=2)
    axes[1].plot(epochs, history['val_acc'], 'r-', label='Validation Accuracy', linewidth=2)
    axes[1].set_xlabel('Epoch', fontsize=12)
    axes[1].set_ylabel('Accuracy', fontsize=12)
    axes[1].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    axes[1].legend(fontsize=10)
    axes[1].grid(alpha=0.3)
    
    plt.tight_layout()
    plt.show()

In [None]:
# Confusion matrix
if len(dataset) > 0:
    # labels=[0, 1] keeps the matrix 2x2 even if one class is missing from the test set
    cm = confusion_matrix(test_labels, test_preds, labels=[0, 1])
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Benign', 'Malware'],
                yticklabels=['Benign', 'Malware'],
                cbar_kws={'label': 'Count'})
    
    plt.title('Confusion Matrix', fontsize=16, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    
    # Add percentages
    total = np.sum(cm)
    for i in range(2):
        for j in range(2):
            percentage = cm[i, j] / total * 100
            plt.text(j + 0.5, i + 0.7, f'({percentage:.1f}%)',
                    ha='center', va='center', fontsize=10, color='gray')
    
    plt.tight_layout()
    plt.show()
    
    # Print confusion matrix details
    tn, fp, fn, tp = cm.ravel()
    print("\n📊 Confusion Matrix Details:")
    print(f"   True Negatives (TN):  {tn} (Correctly identified benign)")
    print(f"   False Positives (FP): {fp} (Benign misclassified as malware)")
    print(f"   False Negatives (FN): {fn} (Malware misclassified as benign) ⚠️")
    print(f"   True Positives (TP):  {tp} (Correctly identified malware)")
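
The four counts above determine the usual detection rates. A short sketch that derives them, treating malware as the positive class; `binary_metrics` is a hypothetical helper and the counts are invented for illustration:

```python
def binary_metrics(tn, fp, fn, tp):
    """Derive common rates from confusion-matrix counts (malware = positive)."""
    def safe_div(num, den):
        return num / den if den else 0.0

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    return {
        "precision": precision,
        "recall": recall,
        "f1": safe_div(2 * precision * recall, precision + recall),
        "false_positive_rate": safe_div(fp, fp + tn),
        "false_negative_rate": safe_div(fn, fn + tp),
    }


# Made-up counts for illustration only.
metrics = binary_metrics(tn=90, fp=10, fn=5, tp=45)
for name, value in metrics.items():
    print(f"{name:20s} {value:.3f}")
```

For malware detection the false negative rate usually deserves the closest look, since it measures the share of malware that slips through as benign.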

In [None]:
# ROC Curve
if len(dataset) > 0:
    malware_probs = test_probs[:, 1]
    fpr, tpr, thresholds = roc_curve(test_labels, malware_probs)
    roc_auc = auc(fpr, tpr)
    
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=3,
             label=f'ROC curve (AUC = {roc_auc:.4f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--',
             label='Random classifier')
    
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title('Receiver Operating Characteristic (ROC) Curve',
              fontsize=14, fontweight='bold')
    plt.legend(loc="lower right", fontsize=11)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    print(f"\n🎯 ROC AUC Score: {roc_auc:.4f}")
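
One way to read the ROC AUC: it is the probability that a randomly chosen malware sample receives a higher malware score than a randomly chosen benign one. The sketch below computes it by brute force over all pairs, which is only practical for small test sets; `pairwise_auc` is an illustrative name and the scores are made up:

```python
def pairwise_auc(labels, scores):
    """AUC as the fraction of (malware, benign) pairs ranked correctly; ties count half."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one sample of each class")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(pairwise_auc(labels, scores))  # 0.75
```

Three of the four malware/benign pairs are ordered correctly here, hence 0.75. On real data this should agree with the `roc_auc` value computed from `roc_curve` and `auc` above.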

In [None]:
# Precision-Recall Curve
if len(dataset) > 0:
    malware_probs = test_probs[:, 1]  # recomputed so this cell does not depend on the ROC cell
    precision, recall, _ = precision_recall_curve(test_labels, malware_probs)
    avg_precision = average_precision_score(test_labels, malware_probs)
    
    plt.figure(figsize=(8, 6))
    plt.plot(recall, precision, color='blue', lw=3,
             label=f'PR curve (AP = {avg_precision:.4f})')
    
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall', fontsize=12)
    plt.ylabel('Precision', fontsize=12)
    plt.title('Precision-Recall Curve', fontsize=14, fontweight='bold')
    plt.legend(loc="lower left", fontsize=11)
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()
    
    print(f"\n🎯 Average Precision Score: {avg_precision:.4f}")

## 📊 Summary

Generate a comprehensive summary of the results.

In [None]:
# Summary statistics
if len(dataset) > 0:
    print("=" * 60)
    print("📊 FINAL SUMMARY")
    print("=" * 60)
    print(f"\nDataset:")
    print(f"  Total samples:    {len(dataset)}")
    print(f"  Benign samples:   {labels.count(0)}")
    print(f"  Malware samples:  {labels.count(1)}")
    print(f"\nModel:")
    print(f"  Architecture:     {CONFIG['model_type'].upper()}")
    print(f"  Parameters:       {num_params:,}")
    print(f"  Hidden channels:  {CONFIG['hidden_channels']}")
    print(f"\nTraining:")
    print(f"  Epochs trained:   {len(history['train_loss'])}")
    print(f"  Best val acc:     {best_val_acc:.4f}")
    print(f"\nTest Performance:")
    print(f"  Accuracy:         {test_acc:.4f}")
    print(f"  ROC AUC:          {roc_auc:.4f}")
    print(f"  Avg Precision:    {avg_precision:.4f}")
    print(f"\nConfusion Matrix:")
    print(f"  True Negatives:   {tn}")
    print(f"  False Positives:  {fp}")
    print(f"  False Negatives:  {fn}")
    print(f"  True Positives:   {tp}")
    print("\n" + "=" * 60)
    print("🎉 Analysis Complete!")
    print("=" * 60)

## 💾 Save Model & Results

Save the trained model and results for later use.

In [None]:
# Save model and results
if len(dataset) > 0:
    # Save model
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': CONFIG,
        'test_accuracy': test_acc,
        'roc_auc': roc_auc
    }, 'malware_gnn_model.pt')
    
    # Save results
    results = {
        'config': CONFIG,
        'history': history,
        'test_accuracy': float(test_acc),
        'test_loss': float(test_loss),
        'roc_auc': float(roc_auc),
        'avg_precision': float(avg_precision),
        'confusion_matrix': cm.tolist()
    }
    
    with open('results.json', 'w') as f:
        json.dump(results, f, indent=2)
    
    print("✅ Model and results saved!")
    print("   - malware_gnn_model.pt")
    print("   - results.json")
    
    # Download files
    print("\n📥 Download files:")
    files.download('malware_gnn_model.pt')
    files.download('results.json')

## 🚀 Next Steps

To improve your model:

1. **Collect more data** - Aim for 1000+ samples per class
2. **Try different architectures** - Change `CONFIG['model_type']` to 'gat'
3. **Add more features** - Modify `extract_node_features()` function
4. **Tune hyperparameters** - Adjust learning rate, hidden channels, dropout
5. **Multi-class classification** - Classify by malware family

---

## ⚠️ Safety Reminder

**Always work with malware in isolated environments!**
- Use VMs with no network access
- Never execute malware samples
- Take regular snapshots

---

**Happy Malware Hunting! 🛡️🔍**