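Ce script prépare un dataset PyTorch Geometric à partir de fichiers `.json` représentant des graphes de flot de contrôle (CFG). Leur contenu est lu comme du texte au format DOT (`Digraph { ... }`, arêtes `"a" -> "b"`, nœuds `"id" [label="..."]`). Le script extrait des features par nœud et les associe aux labels de comportement lus dans un fichier de métadonnées. Le modèle de features réserve des emplacements pour les instructions, les registres et les appels API. Cependant, seules les features structurelles (degrés entrant/sortant, nœud d'entrée/de sortie, nœud critique) sont actuellement calculées ; les autres restent à 0.
Il génère aussi un modèle de features sauvegardé (`feature_model.pkl`), un fichier de statistiques (`stats.json`) et les graphes encodés au format `.pt` dans un dossier `processed`.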

Remarque : le fichier de labels du train est un CSV (séparateur `;`), alors que celui du test est un fichier Excel. Le chargement des labels doit donc utiliser `pd.read_csv(..., sep=';')` pour le train et `pd.read_excel(...)` pour le test.

In [None]:
import os
import re
import json
import random
import numpy as np
import pandas as pd
import networkx as nx
import torch
import pickle
from collections import defaultdict, Counter
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch_geometric.data import Data, Dataset
from torch_geometric.utils import from_networkx
import warnings
warnings.filterwarnings('ignore')

# --- Configuration -----------------------------------------------------------
# Folder holding the CFG files (*.json) to encode.
DIGRAPH_DIR = 'folder_test_set'
# Metadata table: one row per sample `name`, plus the behaviour label columns.
LABELS_PATH = 'data/test_set_metadata_to_predict.xlsx'
# Dataset root: feature_model.pkl, stats.json and processed/*.pt go here.
OUTPUT_DIR = 'pt_output/test'
# Maximum number of files sampled to build the instruction statistics.
SAMPLE_SIZE = 5000

# Make sure the dataset root and its `processed` sub-folder exist up front
# (no error if they are already there).
for _needed_dir in (OUTPUT_DIR, os.path.join(OUTPUT_DIR, 'processed')):
    os.makedirs(_needed_dir, exist_ok=True)
del _needed_dir

# Node statement: "id" [label = "text"]
_DOT_NODE_RE = re.compile(r'"([^"]+)"\s*\[label\s*=\s*"([^"]+)"\]')


def _parse_dot_id(token):
    """Return the node id at the start of ``token`` (one side of a ``->``)."""
    token = token.strip()
    if token.startswith('"'):
        # Quoted id: everything up to the closing quote; trailing attributes
        # ([...]) or a statement terminator (;) are ignored.
        closing = token.find('"', 1)
        return token[1:closing] if closing != -1 else token[1:]
    # Unquoted id: stop at whitespace, an attribute list or the terminator.
    return re.split(r'[\s\[;]', token, maxsplit=1)[0]


def load_digraph(file_path):
    """Parse a DOT-style text file into a NetworkX directed graph.

    Despite the ``.json`` extension used by the callers, the content is read
    as plain text, one statement per line:

    * ``"src" -> "dst"`` adds an edge (quotes optional). Only the first two
      ids of the line are used. An attribute list ``[...]`` or a trailing
      ``;`` after the target id is ignored instead of being glued onto the id.
    * ``"id" [label = "text"]`` adds a node with a ``label`` attribute.
    * Blank lines and lines starting with ``Digraph``, ``{`` or ``}`` are
      skipped; any other line is ignored.

    Args:
        file_path: path of the file to read (UTF-8, undecodable bytes dropped).

    Returns:
        The ``nx.DiGraph``, or ``None`` if reading/parsing raised. The error is
        printed, and callers treat ``None`` as "skip this file".
    """
    try:
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
            content = f.read()

        G = nx.DiGraph()
        for raw_line in content.split('\n'):
            line = raw_line.strip()
            if not line or line.startswith(('Digraph', '{', '}')):
                continue

            if '->' in line:
                parts = line.split('->')
                G.add_edge(_parse_dot_id(parts[0]), _parse_dot_id(parts[1]))
            elif '[label' in line:
                node_match = _DOT_NODE_RE.match(line)
                if node_match:
                    G.add_node(node_match.group(1), label=node_match.group(2))

        return G
    except Exception as e:
        # Best-effort loader: one unreadable file must not stop a whole batch.
        print(f"Erreur lors du chargement de {file_path}: {str(e)}")
        return None

# x86 register names matched as whole words in each instruction.
_COMMON_REGISTERS = ['eax', 'ebx', 'ecx', 'edx', 'esi', 'edi', 'ebp', 'esp',
                     'rax', 'rbx', 'rcx', 'rdx', 'rsi', 'rdi', 'rbp', 'rsp',
                     'ah', 'al', 'bh', 'bl', 'ch', 'cl', 'dh', 'dl']
# Compiled once instead of rebuilding the patterns for every node.
_REGISTER_PATTERNS = [(reg, re.compile(r'\b' + reg + r'\b')) for reg in _COMMON_REGISTERS]


def analyze_instruction_stats(digraph_dir, max_files=SAMPLE_SIZE, seed=None):
    """Count instructions, registers, API calls and control-flow markers over a sample of CFG files.

    Up to ``max_files`` ``.json`` files of ``digraph_dir`` are parsed with
    ``load_digraph``. Files that fail to load or yield an empty graph are
    skipped. Only nodes carrying a ``label`` attribute ("valid" nodes) are
    analysed:

    * if the label contains ``"INST : "``, the lower-cased text after it is
      the instruction, analysed for mnemonic, memory access, registers,
      suspicious patterns and ``call`` targets;
    * otherwise the label is classified by the markers RET / JMP / JCC / CALL.

    Memory read/write classification assumes Intel syntax (destination is
    the first operand), consistent with the ``'eax, eax'`` pattern used
    below: a bracketed destination is a write, otherwise a read.

    Args:
        digraph_dir: directory containing the graph files.
        max_files: sample size; all files are used when there are fewer.
        seed: optional seed making the sample reproducible. ``None`` (default)
            samples with the global ``random`` state, as before.

    Returns:
        dict with the count dicts ``instruction_types``, ``registers_used``,
        ``api_calls``, ``memory_patterns``, ``control_flow_patterns`` and
        ``suspicious_insts``, plus the integers ``total_nodes`` (all nodes of
        the loaded graphs) and ``total_valid_nodes`` (labelled nodes).
    """
    all_files = [os.path.join(digraph_dir, f) for f in os.listdir(digraph_dir) if f.endswith('.json')]

    if len(all_files) > max_files:
        if seed is None:
            sample_files = random.sample(all_files, max_files)
        else:
            # Sort first: os.listdir order is filesystem-dependent, so a seed
            # alone would not make the sample reproducible.
            sample_files = random.Random(seed).sample(sorted(all_files), max_files)
    else:
        sample_files = all_files

    instruction_types = defaultdict(int)
    registers_used = defaultdict(int)
    api_calls = defaultdict(int)
    memory_patterns = defaultdict(int)
    control_flow_patterns = defaultdict(int)
    suspicious_insts = defaultdict(int)

    total_nodes = 0
    total_valid_nodes = 0

    for file_path in tqdm(sample_files, desc="Analyse des instructions"):
        G = load_digraph(file_path)
        if not G:  # None (load error) or an empty graph
            continue

        total_nodes += len(G.nodes)

        for _, attrs in G.nodes(data=True):
            if 'label' not in attrs:
                continue

            label = attrs['label']
            total_valid_nodes += 1

            if "INST" in label:
                _, sep, tail = label.partition("INST : ")
                if not sep:
                    # "INST" without the "INST : " separator: previously an
                    # IndexError that aborted the whole analysis.
                    continue
                instruction = tail.split("INST : ")[0].strip().lower()
                mnemonic, _, operands = instruction.partition(' ')

                # NOTE(review): operand-less mnemonics (e.g. "ret", "nop") are
                # not counted; kept as-is since this defines the feature set.
                if " " in instruction:
                    instruction_types[mnemonic] += 1

                if '[' in instruction and ']' in instruction:
                    memory_patterns['memory_access'] += 1

                    if 'mov' in mnemonic:
                        destination = operands.split(',', 1)[0]
                        if '[' in destination:
                            memory_patterns['memory_write'] += 1
                        else:
                            memory_patterns['memory_read'] += 1

                for reg, pattern in _REGISTER_PATTERNS:
                    if pattern.search(instruction):
                        registers_used[reg] += 1

                if 'xor' in instruction and 'eax, eax' in instruction:
                    suspicious_insts['xor_eax_eax'] += 1
                if 'push' in instruction and ('offset' in instruction or 'str' in instruction):
                    suspicious_insts['push_string'] += 1
                if 'call' in instruction and '[' in instruction:
                    suspicious_insts['indirect_call'] += 1

                if instruction.startswith('call'):
                    api_name = instruction[4:].strip()
                    # Direct calls to raw addresses are not API names.
                    if not api_name.startswith('0x'):
                        api_calls[api_name] += 1

            elif "RET" in label:
                control_flow_patterns['ret'] += 1
            elif "JMP" in label:
                control_flow_patterns['jmp'] += 1
            elif "JCC" in label:
                control_flow_patterns['conditional_jmp'] += 1
            elif "CALL" in label:
                control_flow_patterns['call'] += 1

    return {
        'instruction_types': dict(instruction_types),
        'registers_used': dict(registers_used),
        'api_calls': dict(api_calls),
        'memory_patterns': dict(memory_patterns),
        'control_flow_patterns': dict(control_flow_patterns),
        'suspicious_insts': dict(suspicious_insts),
        'total_nodes': total_nodes,
        'total_valid_nodes': total_valid_nodes
    }

def create_advanced_feature_model(stats, inst_threshold=0.5, reg_threshold=0.5, api_min_count=10):
    """Select the feature names used to encode graph nodes, from corpus statistics.

    The fixed categories (memory / control-flow / suspicious / structural /
    custom patterns) are always present. The data-driven ones are filtered:

    * ``instruction_types``: mnemonics whose count is at least
      ``inst_threshold`` percent of the labelled nodes;
    * ``registers``: registers whose count is at least ``reg_threshold``
      percent of the labelled nodes;
    * ``apis``: call targets seen at least ``api_min_count`` times. They are
      lower-cased and stripped of every character outside ``[a-z0-9_]``.
      Names that collide after this cleaning are kept once (first occurrence
      wins), so no feature name is duplicated.

    Args:
        stats: dict as returned by ``analyze_instruction_stats`` (reads
            ``total_valid_nodes``, ``instruction_types``, ``registers_used``
            and ``api_calls``).
        inst_threshold: minimum percentage for an instruction type.
        reg_threshold: minimum percentage for a register.
        api_min_count: minimum absolute count for an API name.

    Returns:
        dict mapping category -> list of feature names. The total number of
        names is the per-node feature dimension. With no labelled nodes
        (``total_valid_nodes == 0``), the percentage-based categories stay
        empty instead of raising ZeroDivisionError.
    """
    feature_model = {
        'instruction_types': [],
        'registers': [],
        'apis': [],
        'memory_patterns': ['memory_access', 'memory_read', 'memory_write'],
        'control_flow': ['is_ret', 'is_jmp', 'is_conditional_jmp', 'is_call'],
        'suspicious_patterns': ['xor_self', 'push_string', 'indirect_call', 'stack_manipulation'],
        'structural': ['in_degree', 'out_degree', 'is_entry', 'is_exit', 'is_critical_node'],
        'custom_patterns': ['loop_pattern', 'call_ret_sequence', 'api_call_sequence']
    }

    total_valid_nodes = stats['total_valid_nodes']

    if total_valid_nodes > 0:
        feature_model['instruction_types'] = [
            inst for inst, count in stats['instruction_types'].items()
            if count / total_valid_nodes * 100 >= inst_threshold
        ]
        feature_model['registers'] = [
            reg for reg, count in stats['registers_used'].items()
            if count / total_valid_nodes * 100 >= reg_threshold
        ]

    seen_apis = set()
    for api, count in stats['api_calls'].items():
        if count < api_min_count:
            continue
        clean_api = re.sub(r'[^a-zA-Z0-9_]', '', api.lower())
        if clean_api and clean_api not in seen_apis:
            seen_apis.add(clean_api)
            feature_model['apis'].append(clean_api)

    total_features = sum(len(category) for category in feature_model.values())

    print(f"Modèle de features avancé créé avec {total_features} features au total:")
    for category, features in feature_model.items():
        print(f"  {category}: {len(features)} features")

    return feature_model

def extract_advanced_features(G, feature_model, seed=None):
    """Build one fixed-length feature vector per node of ``G``.

    The vector length is the total number of feature names declared in
    ``feature_model`` (sum of the lengths of its lists). Every graph encoded
    with the same model therefore gets the same dimension.

    Only the 5 structural features are filled, in positions 0-4:
    in-degree, out-degree, is_entry (in-degree 0), is_exit (out-degree 0)
    and is_critical (betweenness centrality > 0.1).
    NOTE(review): all instruction / register / API / pattern slots are left
    at 0 -- node labels are never inspected. Filling them changes the
    feature values, so it must be done identically for train and test.

    Args:
        G: directed NetworkX graph.
        feature_model: dict mapping category -> list of feature names.
        seed: optional seed forwarded to ``nx.betweenness_centrality``.
            Betweenness is estimated from ``min(100, len(G))`` sampled pivot
            nodes, so for graphs with more than 100 nodes ``is_critical`` is
            random unless a seed is given. ``None`` keeps the previous
            (unseeded) behaviour.

    Returns:
        dict mapping node id -> ``np.float32`` array of length ``max_dim``.
    """
    max_dim = sum(len(names) for names in feature_model.values())

    centrality = nx.betweenness_centrality(G, k=min(100, len(G)), seed=seed)
    in_degrees = dict(G.in_degree())
    out_degrees = dict(G.out_degree())
    critical_nodes = {n for n, c in centrality.items() if c > 0.1}

    node_features = {}
    for node in G.nodes():
        features = np.zeros(max_dim, dtype=np.float32)
        in_degree = in_degrees.get(node, 0)
        out_degree = out_degrees.get(node, 0)
        structural = (
            in_degree,
            out_degree,
            int(in_degree == 0),         # entry node
            int(out_degree == 0),        # exit node
            int(node in critical_nodes),
        )
        # Guard for custom feature models declaring fewer than 5 features.
        n_written = min(len(structural), max_dim)
        features[:n_written] = structural[:n_written]
        node_features[node] = features

    return node_features

def calculate_graph_features(G, node_features):
    """Summarise a whole graph as one flat float32 vector.

    The result is the element-wise mean of all node feature vectors,
    followed by 12 graph-level statistics in this order: number of nodes,
    number of edges, density, mean in-degree, mean out-degree, max in-degree,
    max out-degree, number of nodes with in-degree 0, number of nodes with
    out-degree 0, average shortest path length (0 unless the graph is
    strongly connected with more than one node), number of weakly connected
    components, and size of the largest weakly connected component.

    An empty ``node_features`` yields ``np.zeros(1)``. Note that this length
    differs from the normal output.
    """
    if not node_features:
        return np.zeros(1, dtype=np.float32)

    mean_node_vector = np.mean(list(node_features.values()), axis=0)

    n = len(G)
    in_deg = [d for _, d in G.in_degree()]
    out_deg = [d for _, d in G.out_degree()]

    if nx.is_strongly_connected(G) and n > 1:
        avg_path = nx.average_shortest_path_length(G)
    else:
        avg_path = 0

    if n > 0:
        largest_component = max(map(len, nx.weakly_connected_components(G)))
    else:
        largest_component = 0

    graph_vector = np.array([
        n,
        len(G.edges()),
        nx.density(G),
        sum(in_deg) / max(1, n),
        sum(out_deg) / max(1, n),
        max(in_deg, default=0),
        max(out_deg, default=0),
        in_deg.count(0),
        out_deg.count(0),
        avg_path,
        nx.number_weakly_connected_components(G),
        largest_component,
    ], dtype=np.float32)

    return np.concatenate([mean_node_vector, graph_vector])

def convert_to_pytorch_geometric(G, node_features, label=None):
    """Convert a NetworkX graph and its node features into a PyG ``Data`` object.

    Nodes are numbered in ``G.nodes()`` order. The same numbering is used for
    the rows of ``x`` and for ``edge_index``.

    If the vectors in ``node_features`` do not all have the same length, they
    are zero-padded or truncated to the most frequent length. This is done on
    a copy, so the caller's ``node_features`` dict is not modified.

    Args:
        G: NetworkX (directed) graph.
        node_features: dict node -> 1-D numpy feature vector. Nodes of ``G``
            missing from it get an all-zero row.
        label: optional label vector, stored as float ``y``.

    Returns:
        ``Data(x, edge_index, y)`` where:

        * ``x`` is float with shape ``(num_nodes, feature_dim)``; ``feature_dim``
          is 0 when ``node_features`` is empty;
        * ``edge_index`` is long with shape ``(2, num_edges)``;
        * ``y`` is float, or ``None``.

        NOTE(review): ``y`` keeps the shape of ``label`` (1-D for one label
        row). Torch_geometric concatenates ``y`` along dim 0 when batching,
        so confirm the training code reshapes it as expected.
    """
    node_mapping = {node: i for i, node in enumerate(G.nodes())}

    edge_pairs = [
        [node_mapping[source], node_mapping[target]]
        for source, target in G.edges()
        if source in node_mapping and target in node_mapping
    ]
    if edge_pairs:
        edge_index = np.array(edge_pairs, dtype=np.int64).T
    else:
        edge_index = np.zeros((2, 0), dtype=np.int64)

    features = dict(node_features)  # local copy: never mutate the caller's dict
    dim_counts = Counter(vec.shape[0] for vec in features.values())
    if len(dim_counts) > 1:
        target_dim = dim_counts.most_common(1)[0][0]
        for node, vec in features.items():
            if vec.shape[0] < target_dim:
                features[node] = np.pad(vec, (0, target_dim - vec.shape[0]), 'constant')
            elif vec.shape[0] > target_dim:
                features[node] = vec[:target_dim]

    # No features at all: zero-width x instead of an IndexError.
    feature_dim = next(iter(features.values())).shape[0] if features else 0
    x = np.zeros((len(node_mapping), feature_dim), dtype=np.float32)
    for node, idx in node_mapping.items():
        if node in features:
            x[idx] = features[node]

    y = None if label is None else np.array(label, dtype=np.float32)

    return Data(
        x=torch.tensor(x, dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long),
        y=torch.tensor(y, dtype=torch.float) if y is not None else None,
    )

class CFGDataset(Dataset):
    """PyTorch Geometric dataset of control-flow graphs with behaviour labels.

    Every ``.json`` file of ``digraph_dir`` whose stem appears in the ``name``
    column of the labels table is processed as follows:

    * parsed with ``load_digraph``;
    * encoded with ``extract_advanced_features`` and
      ``convert_to_pytorch_geometric``;
    * cached as ``<root>/processed/data_<i>.pt``, where ``i`` is its position
      in ``self.files``.

    A graph's label vector is its row of the labels table restricted to
    ``self.behaviors``, i.e. every column except ``_NON_LABEL_COLUMNS``.
    """

    # Metadata columns of the labels table that are not prediction targets.
    _NON_LABEL_COLUMNS = ('name', 'num_antivirus_malicious', 'first_submission_date', 'suggested_threat_label')

    def __init__(self, digraph_dir, labels_path, feature_model, transform=None, pre_transform=None, root=None):
        """
        Args:
            digraph_dir: directory containing the graph files.
            labels_path: labels table. A ``.csv`` path is read with
                ``pd.read_csv(..., sep=';')``; anything else with
                ``pd.read_excel``.
            feature_model: dict category -> feature names, as produced by
                ``create_advanced_feature_model``.
            transform, pre_transform: forwarded to the torch_geometric base class.
            root: dataset root; processed files go to ``<root>/processed``.
        """
        self.digraph_dir = digraph_dir
        self.labels_path = labels_path
        self.feature_model = feature_model

        self.labels_df = self._read_labels(labels_path)

        self.behaviors = [col for col in self.labels_df.columns
                          if col not in self._NON_LABEL_COLUMNS]

        # name -> label row, built once. The previous per-file boolean mask
        # was O(files x rows). keep='first' matches the former `.values[0]`.
        first_rows = self.labels_df.drop_duplicates(subset='name', keep='first')
        label_lookup = dict(zip(first_rows['name'], first_rows[self.behaviors].values))

        # Sorted so that data_<i>.pt always refers to the same file, whatever
        # the os.listdir order of the filesystem.
        self.files = sorted(
            f for f in os.listdir(digraph_dir)
            if f.endswith('.json') and os.path.splitext(f)[0] in label_lookup
        )
        self.file_labels = {f: label_lookup[os.path.splitext(f)[0]] for f in self.files}

        super().__init__(root, transform, pre_transform)

        # process() skips graphs it cannot load or encode, so some
        # data_<i>.pt may be missing. Only expose indices that were written.
        self._available = [
            i for i in range(len(self.files))
            if os.path.exists(os.path.join(self.processed_dir, f'data_{i}.pt'))
        ]

    @staticmethod
    def _read_labels(labels_path):
        """Load the labels table: ';'-separated CSV for ``.csv``, Excel otherwise."""
        if str(labels_path).lower().endswith('.csv'):
            return pd.read_csv(labels_path, sep=';')
        return pd.read_excel(labels_path)

    @property
    def raw_file_names(self):
        return self.files

    @property
    def processed_file_names(self):
        return [f'data_{i}.pt' for i in range(len(self.files))]

    def download(self):
        # Files are read directly from digraph_dir; nothing to download.
        pass

    def process(self):
        """Encode every graph of ``self.files`` and save it as ``data_<i>.pt``.

        Graphs that fail to load, are empty, or raise during encoding are
        skipped (the error is printed), leaving no file for their index.
        NOTE(review): processed_file_names still lists that index, so
        torch_geometric presumably re-runs process() on the next
        instantiation whenever a graph was skipped -- confirm if this is slow.
        """
        for i, file in enumerate(tqdm(self.files, desc="Traitement des graphes")):
            G = load_digraph(os.path.join(self.digraph_dir, file))
            if G is None or len(G) == 0:
                continue

            try:
                node_features = extract_advanced_features(G, self.feature_model)
                labels = self.file_labels.get(file, np.zeros(len(self.behaviors)))
                data = convert_to_pytorch_geometric(G, node_features, labels)
                data.file_name = file  # lets downstream code recover the sample name
                torch.save(data, os.path.join(self.processed_dir, f'data_{i}.pt'))
            except Exception as e:
                print(f"Erreur lors du traitement de {file}: {str(e)}")

    def len(self):
        # Before __init__ finishes, fall back to the number of candidate files.
        available = getattr(self, '_available', None)
        return len(self.files) if available is None else len(available)

    def get(self, idx):
        file_idx = self._available[idx]
        # weights_only=False: the files hold pickled Data objects, not only
        # tensors (recent torch defaults to True). Only load trusted files.
        return torch.load(os.path.join(self.processed_dir, f'data_{file_idx}.pt'), weights_only=False)

def prepare_cfg_dataset(digraph_dir, labels_path, output_dir, sample_size=5000, feature_model_path=None):
    """Build the CFG dataset under ``output_dir`` and return ``(dataset, feature_model)``.

    Steps:
      1. sample up to ``sample_size`` graph files and collect instruction
         statistics (``analyze_instruction_stats``);
      2. build the feature model from those statistics. Alternatively, when
         ``feature_model_path`` is given, load that pickled model instead
         (e.g. the ``feature_model.pkl`` produced for the train set), so that
         train and test graphs share the same feature layout and dimension;
      3. save the model to ``<output_dir>/feature_model.pkl``;
      4. build ``CFGDataset`` rooted at ``output_dir``, which writes
         ``<output_dir>/processed/data_<i>.pt``;
      5. save the statistics to ``<output_dir>/stats.json``.

    Args:
        digraph_dir: directory containing the graph files.
        labels_path: labels table passed to ``CFGDataset``.
        output_dir: dataset root; created if missing.
        sample_size: number of files sampled for the statistics.
        feature_model_path: optional path of an existing pickled feature
            model. ``None`` (default) builds a new one, as before.
    """
    # Created here so the function also works for roots other than OUTPUT_DIR.
    os.makedirs(output_dir, exist_ok=True)

    print(f"Analyse des statistiques sur un échantillon de {sample_size} fichiers...")
    stats = analyze_instruction_stats(digraph_dir, max_files=sample_size)

    if feature_model_path is None:
        print("\nCréation du modèle de features...")
        feature_model = create_advanced_feature_model(stats, inst_threshold=1.0, reg_threshold=1.0, api_min_count=10)
    else:
        print(f"\nChargement du modèle de features depuis {feature_model_path}...")
        # SECURITY: pickle.load can execute arbitrary code -- only load a
        # feature model you produced yourself.
        with open(feature_model_path, 'rb') as f:
            feature_model = pickle.load(f)

    with open(os.path.join(output_dir, 'feature_model.pkl'), 'wb') as f:
        pickle.dump(feature_model, f)

    print("\nCréation du dataset...")
    dataset = CFGDataset(
        digraph_dir=digraph_dir,
        labels_path=labels_path,
        feature_model=feature_model,
        root=output_dir
    )

    print(f"Dataset créé avec {len(dataset)} graphes")

    # The stats values are plain ints and dicts, so they serialise directly.
    with open(os.path.join(output_dir, 'stats.json'), 'w') as f:
        json.dump(stats, f, indent=2)

    return dataset, feature_model

if __name__ == "__main__":
    # Entry point (also taken when run as a notebook cell): build the dataset
    # from the module-level DIGRAPH_DIR / LABELS_PATH / OUTPUT_DIR / SAMPLE_SIZE.
    print("Préparation du dataset")
    dataset, feature_model = prepare_cfg_dataset(
        DIGRAPH_DIR, LABELS_PATH, OUTPUT_DIR, sample_size=SAMPLE_SIZE,
    )
    print("Préparation terminée!")

In [None]:
import torch

# NOTE(review): the dataset cell writes its graphs to
# <OUTPUT_DIR>/processed/data_<i>.pt; this path only works if the .pt files
# were moved or generated here -- adjust if needed.
pt_file_path = "pt_output/data_2.pt"
# weights_only=False: the file holds a pickled torch_geometric Data object;
# only load files you produced yourself.
data = torch.load(pt_file_path, weights_only=False)

print("Attributs internes (data.__dict__):")
for attr_name, attr_value in vars(data).items():
    shown = f"shape = {attr_value.shape}" if hasattr(attr_value, 'shape') else attr_value
    print(f"{attr_name}: {shown}")


In [None]:
# Public attribute access instead of the private `_store` mapping.
print(data.file_name)

---




Ce script traite par batch les fichiers `.pt` contenant des graphes encodés (format `torch_geometric.data.Data`) afin d’en extraire des statistiques globales, notamment :
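- Des statistiques sur les vecteurs de features des nœuds (`x_mean_all`, `x_var_all`, normes L2 `x_l2_*`, etc.),
- Des propriétés structurelles (degrés entrant/sortant moyens, densité, présence de boucles sur un même nœud (self-loops), nombre de composantes fortement connexes, etc.).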

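Les résultats sont écrits dans plusieurs fichiers CSV (`<output_csv_base>_<n>.csv`, un par batch de `batch_size` fichiers). Un batch dont le CSV existe déjà est ignoré, et les `graph_id` déjà présents dans un CSV existant ne sont pas retraités, ce qui évite les doublons.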




In [None]:
import torch
import os
import pandas as pd
import numpy as np
import networkx as nx
from tqdm import tqdm
from glob import glob
from torch_geometric.utils import to_networkx, contains_self_loops

# NOTE(review): CFGDataset writes its graphs to <root>/processed/data_<i>.pt;
# this folder is only correct if the .pt files were moved or generated here
# -- confirm.
pt_folder = "pt_output"
csv_path = "data/test.csv"  # NOTE(review): not used by this cell.
output_csv_base = "data/test_features_batch"
batch_size = 2500


def _graph_feature_row(data, graph_id):
    """Return the CSV row (a dict) of summary statistics for one encoded graph.

    ``data`` is a torch_geometric ``Data`` with node features ``x`` (one row
    per node) and ``edge_index``. Errors raised by torch / networkx on
    degenerate input (e.g. ``x.max()`` on a graph without nodes) propagate;
    the caller logs them and skips the graph.
    """
    x = data.x
    x_l2 = x.norm(p=2, dim=1)  # L2 norm of each node's feature vector

    G_nx = to_networkx(data, to_undirected=False)

    return {
        "graph_id": graph_id,
        "feature_dim": x.shape[1],
        "x_mean_all": float(x.mean().item()),
        "x_var_all": float(x.var().item()),
        "x_abs_mean": float(x.abs().mean().item()),
        "x_max": float(x.max().item()),
        "x_min": float(x.min().item()),
        "x_l2_mean": float(x_l2.mean().item()),
        "x_l2_max": float(x_l2.max().item()),
        "x_l2_min": float(x_l2.min().item()),
        "x_l2_std": float(x_l2.std().item()),
        "x_l2_ratio_max_mean": float(x_l2.max().item()) / (x_l2.mean().item() + 1e-6),
        "num_nodes_l2_above_1": int((x_l2 > 1.0).sum().item()),
        "num_nodes_l2_below_0.1": int((x_l2 < 0.1).sum().item()),
        "num_nodes": data.num_nodes,
        "num_edges": data.num_edges,
        "in_degree_mean": np.mean([d for _, d in G_nx.in_degree()]),
        "out_degree_mean": np.mean([d for _, d in G_nx.out_degree()]),
        "has_self_loops": int(contains_self_loops(data.edge_index)),
        # NOTE: despite the column name, this counts *strongly* connected components.
        "num_connected_components": nx.number_strongly_connected_components(G_nx),
        "density": nx.density(G_nx),
    }


# Resume support: collect the graph_ids already written by previous runs.
existing_graph_ids = set()
existing_csvs = glob(f"{output_csv_base}_*.csv")
for csv_file in existing_csvs:
    try:
        df = pd.read_csv(csv_file, usecols=["graph_id"])
        existing_graph_ids.update(df["graph_id"].dropna().astype(str).tolist())
    except Exception as e:
        print(f"Erreur en lisant {csv_file} : {e}")

print(f"Graphs déjà traités : {len(existing_graph_ids)}")

# Sorted so batch k covers the same slice of files between runs.
# NOTE(review): a batch is skipped as soon as its CSV exists. If .pt files are
# added or removed between runs, the slices shift and some graphs may be missed.
all_files = sorted(f for f in os.listdir(pt_folder) if f.endswith(".pt"))
num_batches = (len(all_files) + batch_size - 1) // batch_size

for batch_idx in range(num_batches):
    output_csv_path = f"{output_csv_base}_{batch_idx + 1}.csv"
    if os.path.exists(output_csv_path):
        print(f"Batch {batch_idx + 1} déjà existant, on skip...")
        continue

    rows = []
    batch_files = all_files[batch_idx * batch_size : (batch_idx + 1) * batch_size]
    print(f"\n Traitement batch {batch_idx + 1}/{num_batches} ({len(batch_files)} fichiers)...")

    for fname in tqdm(batch_files):
        try:
            # weights_only=False: the .pt files hold pickled Data objects;
            # only load files you produced yourself.
            data = torch.load(os.path.join(pt_folder, fname), weights_only=False)
            file_name = getattr(data, 'file_name', None)
            if file_name is None:
                continue

            graph_id = file_name.replace(".json", "")
            if graph_id in existing_graph_ids:
                continue

            rows.append(_graph_feature_row(data, graph_id))
            existing_graph_ids.add(graph_id)

        except Exception as e:
            # Best effort: one corrupt or degenerate graph must not abort the batch.
            print(f"Erreur avec {fname} : {e}")

    # An empty DataFrame would be written without any header. The resume step
    # above (usecols=["graph_id"]) would then fail on it at every run, so keep
    # the graph_id column.
    features_df = pd.DataFrame(rows) if rows else pd.DataFrame(columns=["graph_id"])
    features_df.to_csv(output_csv_path, index=False)
    print(f"Batch {batch_idx + 1} sauvegardé dans {output_csv_path}")
