# Malware Classification with Graph Embeddings

This notebook builds an end-to-end workflow to detect malware from function-call graphs. We rely on the MalNet Tiny dataset distributed via PyTorch Geometric, transform the graphs into vector representations, explore their structure, and finally train both classical ML and GNN models for classification. The `MalNetTiny` helper downloads the data automatically (see the [PyTorch Geometric documentation](https://pytorch-geometric.readthedocs.io/en/2.4.0/generated/torch_geometric.datasets.MalNetTiny.html)), so no manual download is required.


In [None]:
from pathlib import Path
from collections import Counter, defaultdict

import torch
from karateclub import Graph2Vec
from umap.umap_ import UMAP
import networkx as nx
import plotly.express as px
from torch_geometric.datasets import MalNetTiny
from torch_geometric.utils import to_networkx, from_networkx

import pandas as pd
from pycaret.classification import (
    setup, compare_models, finalize_model,
    tune_model, evaluate_model, save_model, load_model)

## 1. Load MalNet Tiny Graphs
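Use the `MalNetTiny` dataset helper to download the graphs automatically, and keep a balanced subset of at most 200 graphs per class for faster experimentation. Each graph is converted to an undirected NetworkX graph whose nodes are relabeled as consecutive integers starting at 0, the indexing that karateclub's `Graph2Vec` expects.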


In [None]:
DATA_ROOT = Path('data/malnet_tiny')
DATA_SPLIT = 'train'
CLASSES = ['addisplay', 'adware', 'benign', 'downloader', 'trojan']
MAX_GRAPHS_PER_CLASS = 200

In [None]:
%%time

# The MalNet Tiny graphs are downloaded automatically via torch_geometric.datasets.
dataset = MalNetTiny(root=DATA_ROOT.as_posix(), split=DATA_SPLIT)
per_class_counts = defaultdict(int)

targets = []
graphs = []

for data in dataset:
    class_idx = int(data.y)
    class_name = CLASSES[class_idx]
    if per_class_counts[class_name] >= MAX_GRAPHS_PER_CLASS:
        continue

    graph = to_networkx(data, to_undirected=True)
    graph = nx.convert_node_labels_to_integers(graph, label_attribute='old_label')
    graphs.append(graph)
    targets.append(class_name)
    per_class_counts[class_name] += 1

    if len(per_class_counts) == len(CLASSES) and all(
        count >= MAX_GRAPHS_PER_CLASS for count in per_class_counts.values()
    ):
        break

summary = Counter(targets)
f"{len(graphs)} graphs loaded ({summary})"

If you have already downloaded the MalNet Tiny edge-list files and want to load them from a local folder instead:
```python
PATH_GRAPHS = 'malnet-graphs-tiny'  # local folder
CLASSES = ['addisplay', 'adware', 'benign', 'downloader', 'trojan']
MAX_GRAPHS_BY_CLASSE = 200

targets = []
graphs = []

for classe in CLASSES:
    # rglob also finds edge lists nested in per-family subfolders
    files = Path(PATH_GRAPHS, classe).rglob('*.edgelist')
    for i, file in enumerate(files):
        if i >= MAX_GRAPHS_BY_CLASSE:
            break
        targets.append(classe)
        G = nx.read_edgelist(file)
        G = nx.convert_node_labels_to_integers(G, label_attribute='old_label')
        graphs.append(G)

f"{len(graphs)} graphs loaded ({dict(Counter(targets))})"
```

## 2. Graph Embedding
Learn Graph2Vec representations that turn each graph into a dense vector suitable for downstream visualization and classification tasks.
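Graph2Vec treats each graph as a "document" whose "words" are Weisfeiler-Lehman (WL) subtree labels, then learns one vector per document with Doc2Vec; the `wl_iterations` parameter controls how many relabeling rounds are run. The sketch below shows only the WL feature-extraction step on a plain adjacency dict. It is a simplified illustration, not karateclub's code: using node degrees as initial labels and `hashlib.md5` for hashing are assumptions made for the example, and `wl_features` is a hypothetical helper.

```python
import hashlib
from collections import Counter


def wl_features(adj, iterations=2):
    """Return the bag of WL subtree labels ("words") for one graph.

    Hypothetical helper for illustration. adj maps each node to a list of
    its neighbours (undirected graph). Initial labels are node degrees,
    an assumption made for this sketch.
    """
    labels = {node: str(len(neigh)) for node, neigh in adj.items()}
    features = list(labels.values())  # iteration-0 words
    for _ in range(iterations):
        new_labels = {}
        for node, neigh in adj.items():
            # Combine the node's label with the sorted labels of its neighbours
            signature = labels[node] + "_" + "_".join(sorted(labels[n] for n in neigh))
            new_labels[node] = hashlib.md5(signature.encode()).hexdigest()
        labels = new_labels
        features.extend(labels.values())
    return Counter(features)


# Two tiny graphs: a triangle and a 3-node path
triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
path = {0: [1], 1: [0, 2], 2: [1]}
print(wl_features(triangle))
print(wl_features(path))
```

The two graphs share only the degree-2 word; every deeper label differs because their neighbourhoods differ. Doc2Vec then places graphs with similar bags of subtree labels close together.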


In [None]:
%%time
N_DIMENSIONS = 2

graph2vec = Graph2Vec(dimensions=N_DIMENSIONS)
graph2vec.fit(graphs)
embeddings = graph2vec.get_embedding()
print(embeddings.shape)
embeddings

**Plot the embeddings**

The interactive scatterplot gives a quick visual check of whether Graph2Vec separates the five classes.
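A scatterplot is a qualitative check. One simple quantitative complement, sketched here as an assumption rather than part of the original workflow, is k-nearest-neighbour label purity: for each embedded graph, the fraction of its k closest neighbours (Euclidean distance) that share its label. Values near 1 suggest well-separated classes, while values around 0.2 are what chance gives for five balanced classes. `knn_purity` is a hypothetical helper; with this notebook's variables it could be called as `knn_purity(embeddings.tolist(), targets)`.

```python
import math


def knn_purity(points, labels, k=5):
    """Mean fraction of each point's k nearest neighbours sharing its label.

    Hypothetical helper for illustration. Uses brute-force O(n^2)
    distances, which is acceptable at the scale of this notebook.
    """
    n = len(points)
    total = 0.0
    for i in range(n):
        # Distances from point i to every other point; ties broken by index
        dists = sorted(
            (math.dist(points[i], points[j]), j) for j in range(n) if j != i
        )
        neighbours = [j for _, j in dists[:k]]
        total += sum(labels[j] == labels[i] for j in neighbours) / len(neighbours)
    return total / n


# Toy data: two well-separated clusters
pts = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
labs = ["a", "a", "a", "b", "b", "b"]
print(knn_purity(pts, labs, k=2))  # 1.0
```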


In [None]:
fig = px.scatter(x=embeddings[:, 0], y=embeddings[:, 1], color=targets)
fig.show()

## 3. Dimensionality Reduction
Retrain Graph2Vec with 256 dimensions, save the embeddings for the classification step, and use UMAP to project them to 2D and 3D views that make cluster structures easier to inspect.


In [None]:
%%time
N_DIMENSIONS = 256

graph2vec = Graph2Vec(dimensions=N_DIMENSIONS)
graph2vec.fit(graphs)
embeddings = graph2vec.get_embedding()

In [None]:
df = pd.DataFrame(embeddings)
df['target'] = targets
df.to_csv('malware_emb_500_256.csv', index=False)

**Project embeddings down to 2 dimensions**

In [None]:
%%time
proj_2d = UMAP(n_components=2, init='random', random_state=0).fit_transform(embeddings)
proj_2d

**Plot the 2D embedding**

In [None]:
fig_2d = px.scatter(
    proj_2d, x=0, y=1,
    color=targets
)
fig_2d.show()

**Project embeddings down to 3 dimensions**

In [None]:
proj_3d = UMAP(n_components=3, init='random', random_state=0).fit_transform(embeddings)

**Plot the 3D embedding**

In [None]:
fig_3d = px.scatter_3d(
    proj_3d, x=0, y=1, z=2,
    color=targets
)
fig_3d.update_traces(marker_size=5)
fig_3d.show()

## 4. Classical Classification
Each graph is annotated with a malware type (or the benign label), so we can train a supervised classifier that predicts one of the five categories automatically. We rely on PyCaret to quickly compare algorithms, tune the best one, and persist the winning model for later inference.


**Load the saved embeddings**

Read the CSV file that stores the Graph2Vec representations alongside their labels.

In [None]:
df = pd.read_csv('malware_emb_500_256.csv')
df

**Initialize PyCaret**

In [None]:
setup(df, target="target", fold=3)

**Compare models**

`compare_models` cross-validates the classifiers available in PyCaret and returns the best one, ranked by accuracy by default, for the selected embedding dimensionality.
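With `setup(..., fold=3)`, every candidate model is scored with 3-fold cross-validation: each fold serves once as the validation set while the model trains on the other two. For classification, PyCaret's default `fold_strategy` is stratified k-fold, so each fold keeps roughly the same class proportions. The sketch below illustrates that idea with a hypothetical `stratified_folds` helper that deals each class's indices round-robin into the folds; scikit-learn's `StratifiedKFold` splitter assigns samples differently in detail.

```python
from collections import defaultdict


def stratified_folds(labels, n_folds=3):
    """Assign sample indices to folds so each fold mirrors the class mix.

    Hypothetical helper for illustration, not PyCaret or scikit-learn code.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    folds = [[] for _ in range(n_folds)]
    for indices in by_class.values():
        # Deal this class's samples round-robin across the folds
        for pos, idx in enumerate(indices):
            folds[pos % n_folds].append(idx)
    return folds


labels = ["benign"] * 6 + ["trojan"] * 3
for k, fold in enumerate(stratified_folds(labels)):
    # Each validation fold holds 2 benign and 1 trojan sample
    print(k, sorted(labels[i] for i in fold))
```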

In [None]:
best_model = compare_models()

**Tune the best model**

In [None]:
best_model_tuned = tune_model(best_model)

**Evaluation**

In [None]:
evaluate_model(best_model_tuned)
# addisplay: 0, adware: 1, benign: 2, downloader: 3, trojan: 4

**Save final model**

In [None]:
final_model = finalize_model(best_model_tuned)
save_model(final_model, 'ml_model')

In [None]:
best_model = load_model('ml_model')

## 5. GNN-Based Classification
Explore an end-to-end neural approach by training a Graph Convolutional Network (GCN) on the raw graphs instead of relying on precomputed embeddings.
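Each `GCNConv` layer applies the normalized propagation rule $X' = \hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2} X \Theta$, where $\hat{A} = A + I$ adds self-loops and $\hat{D}$ is the degree matrix of $\hat{A}$. The pure-Python sketch below computes only the aggregation $\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2} X$ (no weights, bias, or nonlinearity) on a 3-node path carrying the constant feature 1.0 used in this notebook. It shows why a constant feature is still informative: after one step, each node's value depends on its own degree and its neighbours' degrees. `gcn_propagate` is a hypothetical helper, not a PyG function.

```python
import math


def gcn_propagate(adj, features):
    """One normalized GCN aggregation step (no weights, bias, or activation).

    Hypothetical helper for illustration. adj maps each node to its
    neighbours; self-loops are added here. features maps node -> scalar.
    """
    # Degree including the added self-loop: d_hat(i) = deg(i) + 1
    d_hat = {i: len(neigh) + 1 for i, neigh in adj.items()}
    out = {}
    for i, neigh in adj.items():
        total = 0.0
        for j in list(neigh) + [i]:  # neighbours plus the self-loop
            total += features[j] / math.sqrt(d_hat[i] * d_hat[j])
        out[i] = total
    return out


# Path graph 0 - 1 - 2 with the constant feature 1.0 on every node
path = {0: [1], 1: [0, 2], 2: [1]}
print(gcn_propagate(path, {0: 1.0, 1: 1.0, 2: 1.0}))
```

The two end nodes receive about 0.91 and the middle node about 1.15, so structural differences already appear after one layer, before any learned weights are applied.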


**Convert NetworkX graphs into PyG Data objects**
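PyG stores connectivity as `edge_index`, a `2 x num_edges` tensor in coordinate (COO) format, and `from_networkx` lists each edge of an undirected NetworkX graph in both directions. The stdlib sketch below mimics that layout with plain lists so the two rows can be inspected without PyTorch; `to_coo` is a hypothetical helper, not part of PyG.

```python
def to_coo(edges, undirected=True):
    """Return [sources, targets] rows, listing each undirected edge both ways.

    Hypothetical helper for illustration, not part of PyG.
    """
    sources, targets = [], []
    for u, v in edges:
        sources.append(u)
        targets.append(v)
        if undirected and u != v:
            # Add the reverse direction so messages flow both ways
            sources.append(v)
            targets.append(u)
    return [sources, targets]


edge_index = to_coo([(0, 1), (1, 2)])
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```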

In [None]:
# Convert NetworkX graphs into PyG Data objects, adding placeholder node features and labels.
def convert_to_pyg(graphs, targets):
    data_list = []
    for i, graph in enumerate(graphs):
        for node in graph.nodes():
            graph.nodes[node]['x'] = [1.0]  # Constant node feature placeholder

        data = from_networkx(graph)
        data.y = torch.tensor([targets[i]], dtype=torch.long)
        data_list.append(data)
    return data_list

class_mapping = {'addisplay': 0, 'adware': 1, 'benign': 2, 'downloader': 3, 'trojan': 4}
encoded_targets = [class_mapping[label] for label in targets]

data_list = convert_to_pyg(graphs, encoded_targets)

**Create train/test splits**

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated
from torch_geometric.nn import BatchNorm, GCNConv, global_mean_pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

# Split the dataset into train and test partitions
train_data, test_data = train_test_split(
    data_list,
    test_size=0.3,
    stratify=encoded_targets,
    random_state=42,
)

# Build PyG dataloaders
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

**Define the GNN architecture**
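In a mini-batch, PyG concatenates the nodes of all graphs into one large disconnected graph and keeps a `batch` vector that maps each node to its graph index. `global_mean_pool(x, batch)` then averages the node embeddings of each graph, giving one vector per graph for the final linear layer. The sketch below reproduces that averaging with plain lists; `mean_pool` is a hypothetical stand-in, not the PyG implementation.

```python
def mean_pool(x, batch):
    """Average node feature vectors per graph, given each node's graph index.

    Hypothetical helper for illustration. Assumes every graph index
    0..max(batch) has at least one node.
    """
    num_graphs = max(batch) + 1
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for features, g in zip(x, batch):
        counts[g] += 1
        for d, value in enumerate(features):
            sums[g][d] += value
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]


# Two graphs in one batch: nodes 0-1 belong to graph 0, nodes 2-4 to graph 1
x = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [3.0, 3.0], [6.0, 3.0]]
batch = [0, 0, 1, 1, 1]
print(mean_pool(x, batch))  # [[2.0, 3.0], [3.0, 2.0]]
```

Averaging, rather than summing, keeps the graph vector on the same scale regardless of how many nodes a function-call graph has.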

In [None]:
from torch.nn import Dropout

# Deeper GNN classifier with normalization and dropout regularization
class ImprovedGNNClassifier(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3, dropout=0.5):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.norms.append(BatchNorm(hidden_channels))
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.norms.append(BatchNorm(hidden_channels))

        self.dropout = Dropout(dropout)
        self.fc = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        for conv, norm in zip(self.convs, self.norms):
            x = conv(x, edge_index)
            x = norm(x)
            x = F.relu(x)
            x = self.dropout(x)

        x = global_mean_pool(x, batch)
        return self.fc(x)

in_channels = 1
hidden_channels = 256  # Width of each hidden GCN layer
out_channels = len(class_mapping)
num_layers = 8  # Number of stacked GCN layers
dropout = 0.5

model = ImprovedGNNClassifier(in_channels, hidden_channels, out_channels, num_layers, dropout)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.005, weight_decay=1e-4)
criterion = torch.nn.CrossEntropyLoss()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [None]:
def train():
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(1, len(train_loader))

@torch.no_grad()
def test(loader):
    model.eval()
    correct = 0
    preds, labels = [], []
    for batch in loader:
        batch = batch.to(device)
        out = model(batch)
        predictions = out.argmax(dim=1)
        correct += (predictions == batch.y).sum().item()
        preds.extend(predictions.cpu().tolist())
        labels.extend(batch.y.cpu().tolist())
    accuracy = correct / len(loader.dataset)
    return accuracy, preds, labels



**Train the model**
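`ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)` halves the learning rate once the monitored value (here the training loss) has failed to improve for more than 5 consecutive epochs. The simplified sketch below, built around a hypothetical `PlateauSketch` class, mirrors that counting logic with PyTorch's default relative threshold of 1e-4 but omits the `cooldown`, `min_lr`, and `eps` handling of the real scheduler.

```python
class PlateauSketch:
    """Simplified 'min'-mode plateau scheduler (no cooldown, min_lr, or eps).

    Hypothetical class for illustration, not the PyTorch implementation.
    """

    def __init__(self, lr, factor=0.5, patience=5, threshold=1e-4):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.threshold = threshold
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, metric):
        # Relative improvement test, as with threshold_mode='rel'
        if metric < self.best * (1 - self.threshold):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce only after more than `patience` bad epochs in a row
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


sched = PlateauSketch(lr=0.005)
losses = [1.0, 0.8, 0.7] + [0.7] * 6  # improvement, then a 6-epoch plateau
for loss in losses:
    lr = sched.step(loss)
print(lr)  # 0.0025
```

Because the notebook steps the scheduler on the training loss, reductions track optimization progress rather than test accuracy.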

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

num_epochs = 100
for epoch in range(1, num_epochs + 1):
    loss = train()
    train_acc, _, _ = test(train_loader)
    test_acc, _, _ = test(test_loader)
    print(f"Epoch {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}")

    # Step the scheduler based on the latest loss
    scheduler.step(loss)

In [None]:
# Final classification report on the test set
_, all_preds, all_labels = test(test_loader)
print("\nClassification Report:")
print(classification_report(all_labels, all_preds, target_names=list(class_mapping.keys())))
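
`classification_report` prints per-class precision, recall, and F1. To make those columns concrete, the hypothetical `per_class_metrics` helper below computes them from lists like `all_labels` and `all_preds` with plain counting: precision = TP / (TP + FP), recall = TP / (TP + FN), and F1 is their harmonic mean. This is a sketch for illustration, not scikit-learn's implementation.

```python
from collections import Counter


def per_class_metrics(labels, preds):
    """Return {class: (precision, recall, f1)} from true and predicted labels.

    Hypothetical helper for illustration, not scikit-learn code.
    """
    tp, fp, fn = Counter(), Counter(), Counter()
    for y, p in zip(labels, preds):
        if y == p:
            tp[y] += 1
        else:
            fp[p] += 1  # predicted p, but the true class differs
            fn[y] += 1  # missed an instance of class y
    report = {}
    for c in sorted(set(labels) | set(preds)):
        precision = tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0
        recall = tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[c] = (precision, recall, f1)
    return report


# Toy example with integer class ids, as in encoded_targets
print(per_class_metrics([0, 0, 1, 1, 2], [0, 1, 1, 1, 2]))
```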