In [None]:
!pip install torch torchvision torchaudio
!pip install torch-geometric
!pip install infomap
!pip install ace_tools
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.0.0+cu118.html


Looking in links: https://data.pyg.org/whl/torch-2.0.0+cu118.html


In [None]:
import networkx as nx
from infomap import Infomap
import community as community_louvain
import matplotlib.pyplot as plt
from collections import defaultdict

from IPython.display import display

import json
import pandas as pd
import torch
from torch_geometric.utils import negative_sampling
from itertools import combinations

import torch.nn.functional as F
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
import os
import torch
import networkx as nx
from torch_geometric.utils import from_networkx, add_self_loops
from torch_geometric.data import Data
from collections import defaultdict
import random
import networkx as nx
import torch
from torch_geometric.utils import negative_sampling



## **Import grafo ottenuto dal primo modulo**

In [None]:
import json
file_path = '/kaggle/input/graph_data.json'
with open(file_path, "r") as f:
    data = json.load(f)

## **Preprocessing**

Il grafo importato si presenta nella forma: Soggetto,Verbo,Oggetto; dunque va preprocessato per creare un grafo NetworkX (G) in modo da poter lavorare con grafi.

Vengono utilizzati due algoritmi per la rilevazione delle comunità:

- **Infomap**: Un algoritmo basato sulla teoria dell'informazione.

- **Louvain**: Un algoritmo di ottimizzazione della modularità.

Le comunità vengono estratte e salvate in due dizionari: **infomap_communities** e **louvain_communities**.

In [None]:
def preprocess_graph_data(data):
    G = nx.MultiGraph()

    for entry in data:
        nodo1 = entry["nodo1"]["id"]
        nodo2 = entry["nodo2"]["id"]

        if nodo1 != nodo2:
            G.add_edge(nodo1, nodo2)

    return G

# Community: Infomap
def infomap_community_detection(G):
    im = Infomap()
    for u, v in G.edges():
        im.add_link(u, v)
    im.run()

    return {node_id: module_id for node_id, module_id in im.modules}

# Community detection: Louvain
def louvain_community_detection(G):
    return community_louvain.best_partition(G)

G = preprocess_graph_data(data)
infomap_communities = infomap_community_detection(G)
louvain_communities = louvain_community_detection(G)

# Abbiamo output del tipo: {node_id: community_id}

In [None]:
print(G)

MultiGraph with 2701 nodes and 18996 edges


## **Calcolo le statistiche delle community estaratte**

Vengono calcolate statistiche come il numero di nodi, archi, densità, grado medio e centralità per ogni comunità.

In [None]:
def compute_community_stats(G, community_dict):
    community_nodes = defaultdict(list)
    for node, comm_id in community_dict.items():
        community_nodes[comm_id].append(node)

    community_stats = []
    for comm_id, nodes in community_nodes.items():
        subgraph = G.subgraph(nodes)
        num_nodes = subgraph.number_of_nodes()
        num_edges = subgraph.number_of_edges()
        density = nx.density(subgraph)
        avg_degree = np.mean([d for _, d in subgraph.degree()]) if num_nodes > 0 else 0
        centrality = np.mean(list(nx.pagerank(subgraph).values())) if num_nodes > 1 else 0

        community_stats.append({
            "community": comm_id,
            "num_nodi": num_nodes,
            "edges": num_edges,
            "density": density,
            "avg_degree": avg_degree,
            "centrality": centrality
        })

    return pd.DataFrame(community_stats)


infomap_df = compute_community_stats(G, infomap_communities)
louvain_df = compute_community_stats(G, louvain_communities)

In [None]:
infomap_df

Unnamed: 0,community,num_nodi,edges,density,avg_degree,centrality
0,25,18,44,0.287582,4.888889,0.055556
1,60,11,22,0.400000,4.000000,0.090909
2,65,8,18,0.642857,4.500000,0.125000
3,1,230,2230,0.084678,19.391304,0.004348
4,5,37,204,0.306306,11.027027,0.027027
...,...,...,...,...,...,...
347,157,4,6,1.000000,3.000000,0.250000
348,292,2,2,2.000000,2.000000,0.500000
349,193,3,4,1.333333,2.666667,0.333333
350,271,2,2,2.000000,2.000000,0.500000


In [None]:
louvain_df

Unnamed: 0,community,num_nodi,edges,density,avg_degree,centrality
0,28,162,554,0.042481,6.839506,0.006173
1,1,397,1858,0.023637,9.360202,0.002519
2,4,247,1800,0.059248,14.574899,0.004049
3,27,145,568,0.054406,7.834483,0.006897
4,8,112,292,0.046976,5.214286,0.008929
...,...,...,...,...,...,...
143,68,2,2,2.000000,2.000000,0.500000
144,69,2,2,2.000000,2.000000,0.500000
145,70,3,4,1.333333,2.666667,0.333333
146,75,2,2,2.000000,2.000000,0.500000


## **Selzione delle community piu significative**
Vengono selezionate le comunità più significative basandosi su:

### **1. Densità**: Comunità con alta densità


In [None]:
def select_diverse_dense_communities(df, num_top=15, diversity_threshold=0.6):
    df_filtered = df[(df["num_nodi"] > 20) & (df["edges"] > 0)]

    df_sorted = df_filtered.nlargest(num_top * 5, "density")

    selected = []
    selected_communities = set()

    for _, row in df_sorted.iterrows():
        row_nodes = row["num_nodi"]

        # Verifica la diversità rispetto ai grafi già selezionati
        is_diverse = True
        for existing in selected:
            if abs(row_nodes - existing["num_nodi"]) <= diversity_threshold * max(row_nodes, existing["num_nodi"]):
                is_diverse = False
                break

        if is_diverse:
            row_dict = row.to_dict()
            row_dict["criterion"] = "largest_density"
            selected.append(row_dict)
            selected_communities.add(row["community"])

        if len(selected) >= num_top:
            break

    return pd.DataFrame(selected)

### **2. Dimensione**: Comunità con molti nodi

In [None]:
def select_diverse_largest_communities(df, num_selected=15, diversity_threshold=0.6):
    df_filtered = df[(df["num_nodi"] > 20) & (df["edges"] > 0)]

    df_sorted = df_filtered.sort_values(by="num_nodi", ascending=False)

    selected = []
    selected_communities = set()

    for _, row in df_sorted.iterrows():
        if len(selected) >= num_selected:
            break

        # Verifica la diversità rispetto ai grafi già selezionati
        is_diverse = True
        for existing in selected:
            if abs(row["num_nodi"] - existing["num_nodi"]) <= diversity_threshold * max(row["num_nodi"], existing["num_nodi"]):
                is_diverse = False
                break

        if is_diverse:
            row_dict = row.to_dict()
            row_dict["criterion"] = "largest_size"
            selected.append(row_dict)
            selected_communities.add(row["community"])

    return pd.DataFrame(selected)

### **3. Centralità**: Comunità con nodi centrali



In [None]:
def select_diverse_central_graphs(df, top_n=15, diversity_threshold=0.6):
    df_filtered = df[(df["num_nodi"] > 20) & (df["edges"] > 0)]

    top_graphs = df_filtered.nlargest(top_n * 5, "centrality")

    selected = []
    selected_communities = set()

    for _, row in top_graphs.iterrows():
        community = row["community"]

        # Verifica la diversità rispetto ai grafi già selezionati
        is_diverse = True
        for existing in selected:
            if abs(row["num_nodi"] - existing["num_nodi"]) <= diversity_threshold * max(row["num_nodi"], existing["num_nodi"]):
                is_diverse = False
                break

        if is_diverse:
            row_dict = row.to_dict()
            row_dict["criterion"] = "largest_centrality"
            selected.append(row_dict)
            selected_communities.add(community)

        if len(selected) == top_n:
            break

    return pd.DataFrame(selected)

Si estraggono i grafi con le statistiche descritte

In [None]:
def combine_graph_selections(dense_df, large_df, central_df):
    dense_df['selection_order'] = range(len(dense_df))
    large_df['selection_order'] = range(len(large_df))
    central_df['selection_order'] = range(len(central_df))

    combined = pd.concat([dense_df, large_df, central_df])

    combined = combined.sort_values(['community', 'selection_order'])
    combined_unique = combined.drop_duplicates(subset=['community', 'criterion'])
    combined_unique = combined_unique.drop('selection_order', axis=1)

    return combined_unique.reset_index(drop=True)

# INFOMAP
dense_communities_infomap = select_diverse_dense_communities(infomap_df, num_top=15)
largest_communities_infomap = select_diverse_largest_communities(infomap_df, num_selected=15)
central_graphs_infomap = select_diverse_central_graphs(infomap_df, top_n=15)

final_selection_infomap = combine_graph_selections(dense_communities_infomap, largest_communities_infomap, central_graphs_infomap)
selected_communities_infomap = final_selection_infomap.drop_duplicates(subset='community')


# LOUVAIN
dense_communities_louvain = select_diverse_dense_communities(louvain_df, num_top=15)
largest_communities_louvain = select_diverse_largest_communities(louvain_df, num_selected=15)
central_graphs_louvain = select_diverse_central_graphs(louvain_df, top_n=15)

final_selection_louvain = combine_graph_selections(dense_communities_louvain, largest_communities_louvain, central_graphs_louvain)
selected_communities_louvain = final_selection_louvain.drop_duplicates(subset='community')

In [None]:
selected_communities_infomap

Unnamed: 0,community,num_nodi,edges,density,avg_degree,centrality,criterion
0,1.0,230.0,2230.0,0.084678,19.391304,0.004348,largest_size
3,3.0,75.0,382.0,0.137658,10.186667,0.013333,largest_size
4,6.0,59.0,240.0,0.140269,8.135593,0.016949,largest_centrality
5,9.0,31.0,172.0,0.369892,11.096774,0.032258,largest_density
6,19.0,29.0,74.0,0.182266,5.103448,0.034483,largest_size
7,32.0,21.0,52.0,0.247619,4.952381,0.047619,largest_centrality


In [None]:
selected_communities_louvain

Unnamed: 0,community,num_nodi,edges,density,avg_degree,centrality,criterion
0,1.0,397.0,1858.0,0.023637,9.360202,0.002519,largest_size
1,4.0,247.0,1800.0,0.059248,14.574899,0.004049,largest_density
2,27.0,145.0,568.0,0.054406,7.834483,0.006897,largest_size
3,28.0,162.0,554.0,0.042481,6.839506,0.006173,largest_centrality
4,36.0,59.0,186.0,0.108708,6.305085,0.016949,largest_density


### **4.Grafi meno densi, centrali e piccoli**: Vengono anche selezionate comunità meno dense e meno centrali per un confronto.



In [None]:
def select_diverse_communities(df, num_per_criterion=3):
    # Filtra per rimuovere comunità con pochi nodi o senza archi
    df_filtered = df[(df["num_nodi"] > 25) & (df["edges"] > 0)]

    selected = []
    selected_communities = set()

    def add_communities(selection_df, criterion):
        count = 0
        for _, row in selection_df.iterrows():
            if row["community"] not in selected_communities:
                row_dict = row.to_dict()
                row_dict["criterion"] = criterion
                selected.append(row_dict)
                selected_communities.add(row["community"])
                count += 1
            if count >= num_per_criterion:
                break

    if len(df_filtered) >= num_per_criterion:
        add_communities(df_filtered.nsmallest(num_per_criterion, "num_nodi"), "smallest_size")
        add_communities(df_filtered.nsmallest(num_per_criterion, "density"), "smallest_density")
        add_communities(df_filtered.nsmallest(num_per_criterion, "centrality"), "smallest_centrality")

    return pd.DataFrame(selected)

selected_communities_infomap_small = select_diverse_communities(infomap_df, num_per_criterion=3)
selected_communities_infomap = pd.concat([selected_communities_infomap, selected_communities_infomap_small])
selected_communities_infomap = selected_communities_infomap.drop_duplicates(subset='community')

selected_communities_louvain_small = select_diverse_communities(louvain_df, num_per_criterion=3)
selected_communities_louvain = pd.concat([selected_communities_louvain, selected_communities_louvain_small])
selected_communities_louvain = selected_communities_louvain.drop_duplicates(subset='community')

In [None]:
selected_communities_infomap

Unnamed: 0,community,num_nodi,edges,density,avg_degree,centrality,criterion
0,1.0,230.0,2230.0,0.084678,19.391304,0.004348,largest_size
3,3.0,75.0,382.0,0.137658,10.186667,0.013333,largest_size
4,6.0,59.0,240.0,0.140269,8.135593,0.016949,largest_centrality
5,9.0,31.0,172.0,0.369892,11.096774,0.032258,largest_density
6,19.0,29.0,74.0,0.182266,5.103448,0.034483,largest_size
7,32.0,21.0,52.0,0.247619,4.952381,0.047619,largest_centrality
0,16.0,27.0,68.0,0.193732,5.037037,0.037037,smallest_size
2,22.0,29.0,64.0,0.157635,4.413793,0.034483,smallest_size
3,2.0,230.0,1176.0,0.044655,10.226087,0.004348,smallest_density
5,8.0,65.0,246.0,0.118269,7.569231,0.015385,smallest_density


In [None]:
selected_communities_louvain

Unnamed: 0,community,num_nodi,edges,density,avg_degree,centrality,criterion
0,1.0,397.0,1858.0,0.023637,9.360202,0.002519,largest_size
1,4.0,247.0,1800.0,0.059248,14.574899,0.004049,largest_density
2,27.0,145.0,568.0,0.054406,7.834483,0.006897,largest_size
3,28.0,162.0,554.0,0.042481,6.839506,0.006173,largest_centrality
4,36.0,59.0,186.0,0.108708,6.305085,0.016949,largest_density
1,33.0,111.0,306.0,0.050123,5.513514,0.009009,smallest_size
2,8.0,112.0,292.0,0.046976,5.214286,0.008929,smallest_size
4,11.0,182.0,590.0,0.035821,6.483516,0.005495,smallest_density
5,7.0,246.0,1108.0,0.036768,9.00813,0.004065,smallest_density


## **Preparazione dei dati per le GNN**

- I grafi delle comunità selezionate vengono convertiti in formati adatti per l'addestramento delle GNN.

- Vengono aggiunte feature strutturali come grado, PageRank e coefficiente di clustering


Mi salvo per ogni ID del grafo le caratteristiche associate ad esse :(community_subgraph_infomap, community_subgraph_louvain )

In [None]:
def get_community_subgraph(G, community_dict, target_community):

    community_nodes = [node for node, comm_id in community_dict.items() if comm_id == target_community]
    subgraph = G.subgraph(community_nodes)

    return subgraph


community_subgraph_infomap = []
community_subgraph_louvain = []

for _, community_id in selected_communities_infomap.iterrows():
    subgraph = get_community_subgraph(G, infomap_communities, community_id["community"])
    community_subgraph_infomap.append((community_id["community"], subgraph))

for _, community_id in selected_communities_louvain.iterrows():
    subgraph = get_community_subgraph(G, louvain_communities, community_id["community"])
    community_subgraph_louvain.append((community_id["community"], subgraph))


In [None]:
id_community, subgraph = community_subgraph_infomap[0]

print(f"Community ID: {id_community}")
print("Edges:", len(list(subgraph.edges())))


Community ID: 1.0
Edges: 2230


### **1. Analizziamo i grafi estratti per prepararli alle GNN**



In [None]:
def extract_subgraph_info(data, subgraph_edges):
    subgraph_info = []

    edge_pairs = {(n1, n2) for n1, n2 in subgraph_edges}
    edge_pairs.update({(n2, n1) for n1, n2 in subgraph_edges})

    for entry in data:
        nodo1_id = entry["nodo1"]["id"]
        nodo2_id = entry["nodo2"]["id"]

        if (nodo1_id, nodo2_id) in edge_pairs:
            subgraph_info.append(entry)

    return subgraph_info

def process_community_subgraphs(data, community_subgraphs):
    processed_subgraphs = []

    for community_id, subgraph in community_subgraphs:
        subgraph_edges = list(subgraph.edges())
        real_subgraph = extract_subgraph_info(data, subgraph_edges)
        processed_subgraphs.append((community_id, real_subgraph))

    return processed_subgraphs

subgraph_infomap = process_community_subgraphs(data, community_subgraph_infomap)
subgraph_louvain = process_community_subgraphs(data, community_subgraph_louvain)

# IN subgraph_infomap HO LA LISTA DEGLI ID_COMMUNITY E IL GRAFO ASSOCIATO [(1.0,[{'nodo1': {'id': 2, 'labels': [{'categoria_id': 2, '"entita_id"': 52}]}, 'nodo2': {'id': 3, 'labels': [{'categoria_i

def check_multiple_edges(subgraph_info):
    """
    Verifica e conta i collegamenti multipli tra coppie di nodi.
    """
    edge_count = {}
    for entry in subgraph_info:
        edge = (entry["nodo1"]["id"], entry["nodo2"]["id"])
        edge_count[edge] = edge_count.get(edge, 0) + 1

    return {edge: count for edge, count in edge_count.items() if count > 1}

# verifica:
for community_id, subgraph_info in subgraph_infomap[:3]:
    multiple_edges = check_multiple_edges(subgraph_info)
    print(f"\nCommunity {community_id}:")
    print(f"Totale collegamenti: {len(subgraph_info)}")
    print(f"Coppie di nodi con collegamenti multipli: {len(multiple_edges)}")
    if multiple_edges:
        print("Esempi di collegamenti multipli:")
        for edge, count in list(multiple_edges.items())[:3]:
            print(f"Nodi {edge}: {count} collegamenti")


Community 1.0:
Totale collegamenti: 2230
Coppie di nodi con collegamenti multipli: 567
Esempi di collegamenti multipli:
Nodi (5, 340): 28 collegamenti
Nodi (5, 670): 14 collegamenti
Nodi (5, 32): 4 collegamenti

Community 3.0:
Totale collegamenti: 382
Coppie di nodi con collegamenti multipli: 115
Esempi di collegamenti multipli:
Nodi (17, 97): 2 collegamenti
Nodi (17, 189): 4 collegamenti
Nodi (19, 17): 4 collegamenti

Community 6.0:
Totale collegamenti: 240
Coppie di nodi con collegamenti multipli: 90
Esempi di collegamenti multipli:
Nodi (55, 208): 4 collegamenti
Nodi (55, 566): 2 collegamenti
Nodi (55, 1240): 2 collegamenti


In [None]:
for _, community_id in selected_communities_infomap.iterrows():
    subgraph = get_community_subgraph(G, infomap_communities, community_id["community"])
    community_subgraph_infomap.append((community_id["community"], subgraph))
    #Save the subgraph
    nx.write_graphml(subgraph, f"/kaggle/working/{int(community_id['community'])}_{community_id['criterion']}.graphml")

### **2. Conversione Graph in Data per le GNN**

In [None]:
def prepare_data(graph_data):
    edges = []
    node_features = defaultdict(list)
    all_nodes = set()

    for row in graph_data:
        n1 = row["nodo1"]["id"]
        n2 = row["nodo2"]["id"]
        edges.append((n1, n2))
        all_nodes.update([n1, n2])

        for nodo, node_id in [("nodo1", n1), ("nodo2", n2)]:
            for categoria in row[nodo].get("labels", []):
                categoria_id = categoria.get("categoria_id")
                entita_id = categoria.get("entita_id")
                if categoria_id is not None and entita_id is not None:
                    node_features[node_id].append((categoria_id, entita_id))

    category_to_idx = {}
    entita_to_idx = {}

    def get_or_assign_id(mapping, key):
        if key not in mapping:
            mapping[key] = len(mapping)
        return mapping[key]

    for pairs in node_features.values():
        for categoria_id, entita_id in pairs:
            get_or_assign_id(category_to_idx, categoria_id)
            get_or_assign_id(entita_to_idx, entita_id)

    node_to_idx = {node_id: idx for idx, node_id in enumerate(all_nodes)}
    edges = [(node_to_idx[n1], node_to_idx[n2]) for n1, n2 in edges]

    num_nodes = len(node_to_idx)
    num_categories = len(category_to_idx)
    num_entities = len(entita_to_idx)
    num_features = num_categories + num_entities

    if num_features == 0:
        num_features = 1
        x = torch.rand((num_nodes, num_features)) * 0.1
    else:
        x = torch.zeros((num_nodes, num_features))
        for node_id, pairs in node_features.items():
            if node_id in node_to_idx:
                node_idx = node_to_idx[node_id]
                for categoria_id, entita_id in pairs:
                    cat_idx = category_to_idx[categoria_id]
                    ent_idx = entita_to_idx[entita_id]
                    x[node_idx, cat_idx] = 1
                    x[node_idx, num_categories + ent_idx] = 1
        x = F.normalize(x + 1e-8, p=2, dim=1)

    G = nx.Graph()
    G.add_edges_from(edges)

    degree_dict = dict(G.degree())
    degree_values = torch.tensor([degree_dict.get(node_id, 0) for node_id in node_to_idx], dtype=torch.float).view(-1, 1)

    pagerank_dict = nx.pagerank(G)
    pagerank_values = torch.tensor([pagerank_dict.get(node_id, 0) for node_id in node_to_idx], dtype=torch.float).view(-1, 1)

    clustering_dict = nx.clustering(G)
    clustering_values = torch.tensor([clustering_dict.get(node_id, 0) for node_id in node_to_idx], dtype=torch.float).view(-1, 1)

    structural_features = torch.cat([degree_values, pagerank_values, clustering_values], dim=1)
    structural_features = F.normalize(structural_features, p=2, dim=0)

    x = torch.cat([x, structural_features], dim=1)

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

    return Data(x=x, edge_index=edge_index, node_to_idx=node_to_idx)

def process_subgraphs(subgraph_infomap):
    prepared_graphs = {}

    for community_id, graph_data in subgraph_infomap:
        prepared_graphs[community_id] = prepare_data(graph_data)

    return prepared_graphs

gnn_graphs_infomap = process_subgraphs(subgraph_infomap)
gnn_graphs_louvain = process_subgraphs(subgraph_louvain)

# ORA ABBIAMO x=[N, num_features + 3]
# num_features è il numero di categorie ed entità,
# 3 sono le nuove feature (Grado, PageRank, Clustering Coefficient).


Salviamo i i grafi generati

In [None]:
import torch

torch.save(gnn_graphs_infomap, "/kaggle/working/gnn_graphs_infomap.pt")
torch.save(gnn_graphs_louvain, "/kaggle/working/gnn_graphs_louvain.pt")

# **Link Prediction con GNN**

Vengono definiti tre modelli GNN:

- **GCN** (Graph Convolutional Network)
- **GAT** (Graph Attention Network)
- **GraphSAGE** (Graph Sample and Aggregate)

Viene implementato il task di link prediction, dove il modello deve prevedere se esiste un arco tra due nodi.

Vengono utilizzate diverse strategie di negative sampling:

- **Uniforme**: Campionamento casuale di nodi non connessi.
- **Hard**: Campionamento di nodi vicini ma non connessi.
- **Centralità**: Campionamento di nodi con alta centralità.
- **Ibrida**: Combinazione delle strategie precedenti.



## **Negative Sampling**

In [None]:
from torch_geometric.utils import negative_sampling

def advanced_negative_sampling(data, num_neg_samples=None, strategy="hard", strategy_weights=None):
    """
    Genera esempi negativi per link prediction con 4 strategie:
    1. Uniforme (casuale)
    2. Hard Negative (campiona nodi difficili)
    3. Basato su Centralità (sceglie nodi centrali)
    4. Hybrid (combinazione pesata delle tre strategie)
    """
    num_nodes = data.x.shape[0]
    G = nx.Graph()
    G.add_edges_from(data.edge_index.t().tolist())

    if num_neg_samples is None:
        num_neg_samples = data.edge_index.shape[1] // 2

    neg_edges = []

    if strategy == "uniform":
        neg_edges = negative_sampling(data.edge_index, num_nodes=num_nodes, num_neg_samples=num_neg_samples)

    elif strategy == "hard":
        hard_neg_edges = []
        nodes = list(G.nodes())

        while len(hard_neg_edges) < num_neg_samples:
            node = random.choice(nodes)
            neighbors = set(G.neighbors(node))

            if len(neighbors) < 2:
                continue

            for _ in range(5):
                neg_node = random.choice(nodes)
                if neg_node not in neighbors and neg_node != node:
                    hard_neg_edges.append((node, neg_node))
                    break

        neg_edges = torch.tensor(hard_neg_edges, dtype=torch.long).t().contiguous()

    elif strategy == "centrality":
        pagerank = nx.pagerank(G)
        sorted_nodes = sorted(pagerank.keys(), key=lambda x: pagerank[x], reverse=True)
        neg_edges_list = []
        nodes = list(G.nodes())

        while len(neg_edges_list) < num_neg_samples:
            high_centrality_node = random.choice(sorted_nodes[:len(sorted_nodes)//5])
            random_node = random.choice(nodes)

            if not G.has_edge(high_centrality_node, random_node):
                neg_edges_list.append((high_centrality_node, random_node))

        neg_edges = torch.tensor(neg_edges_list, dtype=torch.long).t().contiguous()

    elif strategy == "hybrid":
        if strategy_weights is None:
            raise ValueError("Per la strategia 'hybrid', è necessario fornire 'strategy_weights'.")

        total_weight = sum(strategy_weights.values())
        strategy_weights = {k: v / total_weight for k, v in strategy_weights.items()}

        num_uniform = int(num_neg_samples * strategy_weights.get("uniform", 0))
        num_hard = int(num_neg_samples * strategy_weights.get("hard", 0))
        num_centrality = int(num_neg_samples * strategy_weights.get("centrality", 0))

        if num_uniform > 0:
            uniform_neg_edges = negative_sampling(data.edge_index, num_nodes=num_nodes, num_neg_samples=num_uniform).t().tolist()
            neg_edges.extend(uniform_neg_edges)

        if num_hard > 0:
            hard_neg_edges = []
            nodes = list(G.nodes())

            while len(hard_neg_edges) < num_hard:
                node = random.choice(nodes)
                neighbors = set(G.neighbors(node))

                if len(neighbors) < 2:
                    continue

                for _ in range(5):
                    neg_node = random.choice(nodes)
                    if neg_node not in neighbors and neg_node != node:
                        hard_neg_edges.append((node, neg_node))
                        break

            neg_edges.extend(hard_neg_edges)

        if num_centrality > 0:
            pagerank = nx.pagerank(G)
            sorted_nodes = sorted(pagerank.keys(), key=lambda x: pagerank[x], reverse=True)
            neg_edges_list = []
            nodes = list(G.nodes())

            while len(neg_edges_list) < num_centrality:
                high_centrality_node = random.choice(sorted_nodes[:len(sorted_nodes)//5])
                random_node = random.choice(nodes)

                if not G.has_edge(high_centrality_node, random_node):
                    neg_edges_list.append((high_centrality_node, random_node))

            neg_edges.extend(neg_edges_list)

        neg_edges = torch.tensor(neg_edges, dtype=torch.long).t().contiguous()

    else:
        raise ValueError("Strategia non valida. Scegli tra 'uniform', 'hard', 'centrality' o 'hybrid'.")

    return neg_edges


In [None]:
def create_train_test_edges(data, test_ratio=0.1, min_edges=2, neg_sampling_strategy="uniform", strategy_weights=None):
    """
    Divide gli archi in training e test set con strategia di sampling configurabile.
    """
    device = data.edge_index.device
    num_edges = data.edge_index.shape[1]
    num_test = max(int(test_ratio * num_edges), min_edges)

    perm = torch.randperm(num_edges, device=device)
    test_edges = data.edge_index[:, perm[:num_test]]
    train_edges = data.edge_index[:, perm[num_test:]]

    neg_edges = advanced_negative_sampling(
        data,
        num_neg_samples=num_test,
        strategy=neg_sampling_strategy,
        strategy_weights=strategy_weights if neg_sampling_strategy == "hybrid" else None
    )

    neg_edges = neg_edges.to(device)

    return train_edges.long(), test_edges.long(), neg_edges.long()

## **Definizione delle GNN**

In [None]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GATConv
from torch_geometric.nn import SAGEConv

class GCN(torch.nn.Module):
    def __init__(self, in_feats, hidden_feats):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_feats, hidden_feats)
        self.conv2 = GCNConv(hidden_feats, hidden_feats)

    def forward(self, x, edge_index):
        x = torch.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

class GAT(torch.nn.Module):
    def __init__(self, in_feats, hidden_feats, heads=2):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_feats, hidden_feats, heads=heads, concat=True)
        self.conv2 = GATConv(hidden_feats * heads, hidden_feats, heads=1, concat=False)

    def forward(self, x, edge_index):
        x = torch.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

class GraphSAGE(torch.nn.Module):
    def __init__(self, in_feats, hidden_feats):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(in_feats, hidden_feats)
        self.conv2 = SAGEConv(hidden_feats, hidden_feats)

    def forward(self, x, edge_index):
        x = torch.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

## **Addestramento delle GNN per Link Prediction**

I modelli vengono addestrati e valutati su diverse comunità, utilizzando le metriche AUC (Area Under Curve) e AP (Average Precision) per misurare le prestazioni.

In [None]:
def link_predictor(embeddings, edge_index):
    """Dot product tra coppie di nodi per predire il link."""
    src, dst = edge_index
    src = src.long()
    dst = dst.long()
    return (embeddings[src] * embeddings[dst]).sum(dim=1)

# directory per salvare i modelli finali
os.makedirs("trained_models", exist_ok=True)

def train_link_prediction(graphs, models, epochs=100, lr=0.01, neg_sampling_strategy="hard", strategy_weights=None, community_type=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = {}

    for model_name, Model in models.items():
        print(f"\nTraining {model_name} with negative sampling: {neg_sampling_strategy}")

        data_example = next(iter(graphs.values()))
        model = Model(in_feats=data_example.x.shape[1], hidden_feats=32).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)

        auc_scores = []
        ap_scores = []

        for epoch in range(epochs):
            model.train()
            optimizer.zero_grad()
            total_loss = torch.tensor(0.0, device=device)  # Inizializziamo come tensor
            num_graphs = 0

            curr_weights = strategy_weights if neg_sampling_strategy == "hybrid" else None

            for graph_id, data in graphs.items():
                try:
                    data = data.to(device)
                    train_edges, test_edges, neg_edges = create_train_test_edges(
                        data,
                        neg_sampling_strategy=neg_sampling_strategy,
                        strategy_weights=curr_weights
                    )

                    if test_edges.size(1) < 2 or neg_edges.size(1) < 2:
                        print(f"Skipping Graph {graph_id} - Test set too small")
                        continue

                    embeddings = model(data.x, train_edges)
                    pos_pred = link_predictor(embeddings, train_edges)
                    neg_pred = link_predictor(embeddings, neg_edges)

                    labels = torch.cat([torch.ones(pos_pred.size(0)), torch.zeros(neg_pred.size(0))]).to(device)
                    pred = torch.cat([pos_pred, neg_pred])

                    loss = F.binary_cross_entropy_with_logits(pred, labels)
                    total_loss += loss
                    num_graphs += 1

                except Exception as e:
                    print(f"Errore nell'elaborazione di Graph {graph_id}: {str(e)}")
                    continue

            if num_graphs > 0:
                total_loss /= num_graphs
                total_loss.backward()
                optimizer.step()

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Avg Loss: {total_loss.item()}")

        # Valutazione
        model.eval()
        for graph_id, data in graphs.items():
            try:
                data = data.to(device)
                train_edges, test_edges, neg_edges = create_train_test_edges(
                    data,
                    neg_sampling_strategy=neg_sampling_strategy,
                    strategy_weights=strategy_weights if neg_sampling_strategy == "hybrid" else None
                )

                with torch.no_grad():
                    embeddings = model(data.x, train_edges)
                    pos_test_pred = link_predictor(embeddings, test_edges)
                    neg_test_pred = link_predictor(embeddings, neg_edges)

                    test_labels = torch.cat([torch.ones(pos_test_pred.size(0)),
                                             torch.zeros(neg_test_pred.size(0))]).cpu().numpy()
                    test_pred = torch.cat([pos_test_pred, neg_test_pred]).cpu().numpy()

                    if len(test_labels) < 2 or len(set(test_labels)) < 2:
                        print(f"Attenzione: Test set di Graph {graph_id} non valido per la valutazione")
                        continue

                    auc = roc_auc_score(test_labels, test_pred)
                    ap = average_precision_score(test_labels, test_pred)

                    auc_scores.append(auc)
                    ap_scores.append(ap)

            except Exception as e:
                print(f"Errore nella valutazione di Graph {graph_id}: {str(e)}")
                continue

        # Calcoliamo le metriche
        avg_auc = np.nanmean(auc_scores) if auc_scores else np.nan
        avg_ap = np.nanmean(ap_scores) if ap_scores else np.nan
        print(f"Final results - {model_name} | Strategy {neg_sampling_strategy} - AUC: {avg_auc:.4f}, AP: {avg_ap:.4f}")

        model_path = f"trained_models/{model_name}_strategy_{neg_sampling_strategy}_{community_type}.pth"
        torch.save(model.state_dict(), model_path)

        results[model_name] = {
            "AUC": avg_auc,
            "AP": avg_ap
        }

    return results

## **Salvataggio dei risultati**
I risultati delle diverse strategie e modelli vengono salvati e visualizzati in una tabella.



In [None]:
results_summary = []

models = {
    "GCN": GCN,
    "GAT": GAT,
    "GraphSAGE": GraphSAGE
}

weights = {"uniform": 0.3, "hard": 0.4, "centrality": 0.3}

for strategy in ["uniform", "hard", "centrality", "hybrid"]:
    print(f"\n====== Training for Strategy: {strategy} ======")

    # Add community_type parameter for Infomap
    results_infomap = train_link_prediction(
        gnn_graphs_infomap,
        models,
        neg_sampling_strategy=strategy,
        strategy_weights=weights if strategy == "hybrid" else None,
        community_type="infomap"
    )

    # Add community_type parameter for Louvain
    results_louvain = train_link_prediction(
        gnn_graphs_louvain,
        models,
        neg_sampling_strategy=strategy,
        strategy_weights=weights if strategy == "hybrid" else None,
        community_type="louvain"
    )

    for model in models.keys():
            results_summary.append({
                "Strategy": strategy,
                "Model": model,
                "AUC (Infomap)": results_infomap[model]["AUC"],
                "AP (Infomap)": results_infomap[model]["AP"],
                "AUC (Louvain)": results_louvain[model]["AUC"],
                "AP (Louvain)": results_louvain[model]["AP"]
            })
results_df = pd.DataFrame(results_summary)



Training GCN with negative sampling: uniform
Epoch 0, Avg Loss: 0.6763001680374146
Epoch 10, Avg Loss: 0.343402236700058
Epoch 20, Avg Loss: 0.27392566204071045
Epoch 30, Avg Loss: 0.24475303292274475
Epoch 40, Avg Loss: 0.272144079208374
Epoch 50, Avg Loss: 0.2511940896511078
Epoch 60, Avg Loss: 0.25781551003456116
Epoch 70, Avg Loss: 0.24915817379951477
Epoch 80, Avg Loss: 0.2502090632915497
Epoch 90, Avg Loss: 0.25756195187568665
Final results - GCN | Strategy uniform - AUC: 0.8299, AP: 0.8681

Training GAT with negative sampling: uniform
Epoch 0, Avg Loss: 0.6913192272186279
Epoch 10, Avg Loss: 0.39373549818992615
Epoch 20, Avg Loss: 0.3292187750339508
Epoch 30, Avg Loss: 0.32082661986351013
Epoch 40, Avg Loss: 0.32144269347190857
Epoch 50, Avg Loss: 0.3195691704750061
Epoch 60, Avg Loss: 0.319161057472229
Epoch 70, Avg Loss: 0.3195987939834595
Epoch 80, Avg Loss: 0.31944945454597473
Epoch 90, Avg Loss: 0.31933513283729553
Final results - GAT | Strategy uniform - AUC: 0.5718, AP:

In [None]:
from IPython.display import display
display(results_df)

Unnamed: 0,Strategy,Model,AUC (Infomap),AP (Infomap),AUC (Louvain),AP (Louvain)
0,uniform,GCN,0.829931,0.868115,0.781706,0.827711
1,uniform,GAT,0.571808,0.592459,0.522153,0.542354
2,uniform,GraphSAGE,0.656063,0.722856,0.702457,0.75051
3,hard,GCN,0.851896,0.884866,0.800488,0.841821
4,hard,GAT,0.497449,0.524565,0.517514,0.544273
5,hard,GraphSAGE,0.662,0.733666,0.717043,0.770561
6,centrality,GCN,0.786554,0.828669,0.691834,0.763269
7,centrality,GAT,0.497558,0.55713,0.520303,0.538181
8,centrality,GraphSAGE,0.602434,0.663054,0.642618,0.692835
9,hybrid,GCN,0.772823,0.848463,0.733016,0.780996
