In [1]:
from bs4 import BeautifulSoup
import logging
import re
import requests

In [2]:
class CountryLoader:
    '''
    Load and clean the list of United Nations Member States.

    Names are scraped from the UN member-states page (self.base_url) and then
    shortened to the forms used in the Atlas game.
    '''

    # Official UN names that are more commonly referred to by another name.
    # Lookups ignore case and the straight/curly apostrophe difference
    # (see _lookup_key), because both apostrophe styles appear in these keys.
    NAME_MAPPING = {
        'Democratic People\'s Republic of Korea': 'North Korea',
        'Democratic Republic of the Congo': 'DR Congo',
        'Lao People’s Democratic Republic': 'Laos',
        'Republic of Korea': 'South Korea',
        'Republic of Moldova': 'Moldova',
        'Russian Federation': 'Russia',
        'Syrian Arab Republic': 'Syria',
        'United Kingdom of Great Britain and Northern Ireland': 'United Kingdom',
        'United Republic of Tanzania': 'Tanzania',
    }

    def __init__(self):
        '''Store the source URL and set up the logger used by every method.'''
        self.base_url = 'https://www.un.org/en/about-us/member-states'

        # Log messages with a severity level of INFO or higher.
        # basicConfig does nothing if the root logger already has handlers.
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def fetch_and_parse(self, url: str) -> 'BeautifulSoup':
        '''
        Fetch `url` and return the response body parsed with the lxml parser.

        Raises:
            requests.RequestException: on network errors, timeouts (10 s) or
                HTTP error statuses; the error is logged before re-raising.
        '''
        try:
            # A browser-like User-Agent is sent with the request
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/132.0.0.0 Safari/537.36'
            }
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            return BeautifulSoup(response.content, 'lxml')

        except requests.RequestException as e:
            self.logger.error(f'Failed to fetch URL {url}: {str(e)}')
            raise

    def extract_countries(self, soup: 'BeautifulSoup'): # -> List[str]
        '''
        Extract the raw country names from the parsed member-states page.

        Returns:
            List of non-empty, stripped names in document order. The list may
            be empty (a warning is logged) if the block holds no names.

        Raises:
            ValueError: if the expected 'mb-2' container is missing.
        '''
        countries = []

        try:
            # 'mb-2' is a unique CSS class, not present elsewhere in the HTML
            # This div contains the countries inside 'col-md-12' divs
            # Names are contained in h2 elements with class 'mb-0'
            block = soup.find('div', class_='mb-2') # Works as of 22nd January 2025

            if block is None:
                # Fail with a clear message instead of crashing on None below
                raise ValueError(
                    f'No div with class "mb-2" found; the page structure of {self.base_url} has likely changed'
                )

            for heading in block.find_all('h2', class_='mb-0'): # Works as of 22nd January 2025
                name = heading.text.strip()
                if name:
                    countries.append(name)

            # Check if the countries list is populated
            if not countries:
                self.logger.warning('No country names found in "mb-2" block')

        except Exception as e:
            self.logger.error(f'An error occured while extracting country names: {str(e)}')
            raise

        return countries

    @staticmethod
    def _lookup_key(name: str) -> str:
        '''Normalize a name for NAME_MAPPING lookups (apostrophe style, outer whitespace, case).'''
        return name.replace('\u2019', "'").strip().casefold()

    def clean_country_name(self, name: str) -> str:
        '''
        Standardize a country name for clarity and consistency.

        Steps, in order:
          1. Replace official names listed in NAME_MAPPING by their common name.
          2. Remove any designation in parentheses.
          3. Remove everything from the first comma onwards.
          4. Strip surrounding whitespace.
        '''
        wanted = CountryLoader._lookup_key(name)
        for official, common in CountryLoader.NAME_MAPPING.items():
            if CountryLoader._lookup_key(official) == wanted:
                name = common
                break

        # Some country names have official designations in brackets
        # Other country names like Venezuela have designations after a comma
        # Remove these since they aren't relevant in the game

        # Match a sequence of characters in parentheses and remove it
        name = re.sub(r'\([^)]*\)', '', name)

        # Remove everything after (and including) a comma
        name = re.sub(r',.*', '', name)

        return name.strip()

    def load_countries(self): # -> List[str]
        '''
        Fetch the page at self.base_url, extract the names and clean each one.

        Any exception from the steps is logged and re-raised.
        '''
        try:
            soup = self.fetch_and_parse(self.base_url)
            countries = self.extract_countries(soup)
            countries = [self.clean_country_name(country) for country in countries]

            return countries

        except Exception as e:
            self.logger.error(f'Failed to load countries: {str(e)}')
            raise

In [3]:
def load_country_data(): # -> List[str]
    '''Build a fresh CountryLoader and return the cleaned country list it loads.'''
    return CountryLoader().load_countries()

# Test the loader: fetch the list, report how many names were extracted and
# print them alphabetically. `countries` becomes a notebook-level variable
# that the graph-building cells below read.
try:
    countries = load_country_data()
    print(f'Successfully extracted {len(countries)} countries')
    print('\n'.join(sorted(countries)))

# NOTE: the error is only printed, not re-raised. If loading fails, `countries`
# stays undefined and later cells that use it will raise NameError.
except Exception as e:
    print(f'Error during testing: {str(e)}')

Successfully extracted 192 countries
Afghanistan
Albania
Algeria
Andorra
Angola
Antigua and Barbuda
Argentina
Armenia
Australia
Austria
Azerbaijan
Bahamas
Bahrain
Bangladesh
Barbados
Belarus
Belgium
Belize
Benin
Bhutan
Bolivia
Bosnia and Herzegovina
Botswana
Brazil
Brunei Darussalam
Bulgaria
Burkina Faso
Burundi
Cabo Verde
Cambodia
Cameroon
Canada
Central African Republic
Chad
Chile
China
Colombia
Comoros
Congo
Costa Rica
Croatia
Cuba
Cyprus
Czechia
Côte D'Ivoire
DR Congo
Denmark
Djibouti
Dominica
Dominican Republic
Ecuador
Egypt
El Salvador
Equatorial Guinea
Eritrea
Estonia
Eswatini
Ethiopia
Fiji
Finland
France
Gabon
Gambia
Georgia
Germany
Ghana
Greece
Grenada
Guatemala
Guinea
Guinea Bissau
Guyana
Haiti
Honduras
Hungary
Iceland
India
Indonesia
Iran
Iraq
Ireland
Israel
Italy
Jamaica
Japan
Jordan
Kazakhstan
Kenya
Kiribati
Kuwait
Kyrgyzstan
Laos
Latvia
Lebanon
Lesotho
Liberia
Libya
Liechtenstein
Lithuania
Luxembourg
Madagascar
Malawi
Malaysia
Maldives
Mali
Malta
Marshall Islands
Mauritan

In [4]:
# Helper: Build a directed graph from the given list of names as per the rules of the game
def build_graph(names):
    '''
    Build the Atlas move graph.

    Every name becomes a node (stripped and lower-cased). There is an edge
    a -> b when a != b and the last character of a equals the first character
    of b.

    Args:
        names: iterable of name strings.

    Returns:
        nx.DiGraph whose node labels are the normalized names.
    '''
    G = nx.DiGraph()
    # Normalize names (stripping, and converting the entire name to lower case).
    # Names that are empty after stripping are dropped: they have no first or
    # last letter, so indexing them below would raise IndexError.
    names_norm = [norm for norm in (name.strip().lower() for name in names) if norm]
    for name in names_norm:
        G.add_node(name)
    # Insertion order of the edges is kept as-is; downstream code shuffles the
    # edge list with a fixed seed, so the order affects reproducibility.
    for a in names_norm:
        for b in names_norm:
            if a != b and a[-1] == b[0]:
                G.add_edge(a, b)
    return G

In [5]:
import hdbscan
import networkx as nx
import numpy as np
from pyvis.network import Network
import random
import string
import torch
from torch_geometric.nn import GCNConv, GAE
from torch_geometric.utils import from_networkx
import torch.nn.functional as F
import torch.optim as optim

########################################
# 1. Build the Atlas Directed Graph
########################################

# Create a directed graph using NetworkX.
# Each entry of `countries` (bound by the loader cell above) becomes one node,
# labelled with the country name exactly as it appears in the list.
G_nx = nx.DiGraph()
G_nx.add_nodes_from(countries)

# Atlas game rule: B may follow A when A's last letter matches B's first letter.
def can_move(a, b):
    '''Return True if the last character of `a` equals the first character of `b`, ignoring case.'''
    tail, head = a[-1], b[0]
    return tail.lower() == head.lower()

# Add edges following the rule.
# Every ordered pair of distinct countries is checked, so the loop performs
# len(countries) ** 2 calls at most; self-loops are excluded by `a != b`.
for a in countries:
    for b in countries:
        if a != b and can_move(a, b):
            G_nx.add_edge(a, b)

# Prints the complete node and edge lists (not just their sizes).
print('Nodes in graph:', list(G_nx.nodes()))
print('Edges in graph:', list(G_nx.edges()))

########################################
# 2. Create Node Features (52-dimensional)
########################################

# Letters a-z; a letter's position in this list is its one-hot index.
alphabet = list(string.ascii_lowercase)

def one_hot_first_last(country):
    '''
    Encode a country name as a 52-element list of 0/1 values.

    Slots 0-25 mark the first letter and slots 26-51 mark the last letter of
    the lower-cased name. A letter outside a-z leaves its half all zeros.
    '''
    lowered = country.lower()
    edge_letters = (lowered[0], lowered[-1])
    encoding = [0] * 52
    # half 0 -> first letter, half 1 -> last letter (offset by 26)
    for half, letter in enumerate(edge_letters):
        if letter in alphabet:
            encoding[26 * half + alphabet.index(letter)] = 1
    return encoding

# Build the feature matrix: one float32 row of 52 values per entry of `countries`.
features = np.array([one_hot_first_last(country) for country in countries], dtype=np.float32)
# Prints one line per country with its full feature vector.
print('\nFeatures for each country:')
for country, feat in zip(countries, features):
    print(f'{country}: {feat}')

########################################
# 3. Convert to PyTorch Geometric Data Object
########################################

# Convert the NetworkX graph into a PyTorch Geometric data object.
data = from_networkx(G_nx)
# Assign the computed 52-dimensional features.
# NOTE(review): rows follow the order of `countries`; this assumes the node
# order of `data` matches that order (no duplicate names) - TODO confirm.
data.x = torch.tensor(features)

########################################
# 4. Define and Train a GNN (Graph Autoencoder)
########################################

# Two-layer GCN encoder used as the encoder of the graph autoencoder.
class GCNEncoder(torch.nn.Module):
    '''Map node features to embeddings with two GCNConv layers and a ReLU in between.'''

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        '''Return the output of conv2 applied to the ReLU-activated output of conv1.'''
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)

# Seed every RNG this cell may draw from before the model is created, so that
# weight initialisation and the negative sampling inside recon_loss (and hence
# the embeddings and clusters derived from them) are repeatable across runs.
# The same seed value is used by the link-prediction cell further down.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Set dimensions.
in_channels = data.num_features  # 52 from our one-hot encoding.
hidden_channels = 8              # Hidden dimension (adjustable).
latent_dim = 4                   # Dimension of the embedding space.

# Create the Graph Autoencoder (GAE) model.
encoder = GCNEncoder(in_channels, hidden_channels, latent_dim)
model = GAE(encoder)

# Set device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
data = data.to(device)

# Set up the optimizer.
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Train the model: full-batch, one optimizer step per epoch.
model.train()
epochs = 200
for epoch in range(epochs):
    optimizer.zero_grad()
    z = model.encode(data.x, data.edge_index)  # Compute embeddings.
    loss = model.recon_loss(z, data.edge_index)  # Reconstruction loss using inner product decoder.
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')

########################################
# 5. Extract Embeddings and Cluster with HDBSCAN
########################################

# Switch to evaluation mode and compute the final embeddings without tracking gradients.
model.eval()
with torch.no_grad():
    z = model.encode(data.x, data.edge_index)

# Convert embeddings to a NumPy array (moved to the CPU first).
embeddings = z.cpu().detach().numpy()
print('\nLearned Embeddings:\n', embeddings)

# Cluster the embeddings using HDBSCAN; a cluster needs at least 2 members.
clusterer = hdbscan.HDBSCAN(min_cluster_size=2)
cluster_labels = clusterer.fit_predict(embeddings)
print('\nHDBSCAN Cluster Labels:', cluster_labels)

# Map cluster labels back to the original country names.
# NOTE(review): label i is paired with countries[i]; this assumes embedding
# row i belongs to countries[i] - TODO confirm node order.
cluster_mapping = {country: cluster_labels[i] for i, country in enumerate(countries)}
for country, label in cluster_mapping.items():
    print(f'{country}: Cluster {label}')

########################################
# 6. Visualize Communities Using Pyvis
########################################

# Create a Pyvis network.
net = Network(height='800px', width='100%', directed=True)
# Display options passed to Pyvis as a string: node font size, arrow heads on
# edges, and the forceAtlas2Based physics solver settings.
net.set_options("""
    var options = {
        "nodes": {
            "font": {
                "size": 12
            }
        },
        "edges": {
            "arrows": {
                "to": {
                    "enabled": true,
                    "scaleFactor": 0.5
                }
            },
            "smooth": {
                "type": "continuous",
                "forceDirection": "none"
            }
        },
        "physics": {
            "forceAtlas2Based": {
                "gravitationalConstant": -100,
                "centralGravity": 0.01,
                "springLength": 200,
                "springConstant": 0.08,
                "damping": 0.4,
                "avoidOverlap": 1
            },
            "minVelocity": 23,
            "solver": "forceAtlas2Based"
        }
    }
""")

# Create a color mapping for clusters.
# For noise (cluster -1), use black.
unique_clusters = np.unique(cluster_labels)
color_map = {}
for cl in unique_clusters:
    if cl == -1:
        color_map[cl] = '#000000'
    else:
        # Generate a random hex color.
        # Colors are drawn independently, so two clusters may end up with
        # similar (or identical) colors.
        color_map[cl] = '#{:06x}'.format(random.randint(0, 0xFFFFFF))

# Add nodes with the cluster color; the hover title shows the cluster id.
for node in G_nx.nodes():
    cl = cluster_mapping[node]
    net.add_node(node, label=node, title=f"Cluster {cl}", color=color_map[cl])

# Add edges.
for source, target in G_nx.edges():
    net.add_edge(source, target)

# Show the interactive visualization (written to an HTML file with this name).
fileName = 'gnn_communities.html'
net.show(fileName, notebook=False)

Nodes in graph: ['Afghanistan', 'Albania', 'Algeria', 'Andorra', 'Angola', 'Antigua and Barbuda', 'Argentina', 'Armenia', 'Australia', 'Austria', 'Azerbaijan', 'Bahamas', 'Bahrain', 'Bangladesh', 'Barbados', 'Belarus', 'Belgium', 'Belize', 'Benin', 'Bhutan', 'Bolivia', 'Bosnia and Herzegovina', 'Botswana', 'Brazil', 'Brunei Darussalam', 'Bulgaria', 'Burkina Faso', 'Burundi', 'Cabo Verde', 'Cambodia', 'Cameroon', 'Canada', 'Central African Republic', 'Chad', 'Chile', 'China', 'Colombia', 'Comoros', 'Congo', 'Costa Rica', "Côte D'Ivoire", 'Croatia', 'Cuba', 'Cyprus', 'Czechia', 'North Korea', 'DR Congo', 'Denmark', 'Djibouti', 'Dominica', 'Dominican Republic', 'Ecuador', 'Egypt', 'El Salvador', 'Equatorial Guinea', 'Eritrea', 'Estonia', 'Eswatini', 'Ethiopia', 'Fiji', 'Finland', 'France', 'Gabon', 'Gambia', 'Georgia', 'Germany', 'Ghana', 'Greece', 'Grenada', 'Guatemala', 'Guinea', 'Guinea Bissau', 'Guyana', 'Haiti', 'Honduras', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Iran', 'Iraq', 



In [6]:
# Offer the generated visualization as a browser download when running in
# Google Colab. Outside Colab the google.colab package is not importable, so
# the cell reports that instead of failing a full notebook run.
try:
    from google.colab import files
except ImportError:
    print('google.colab is not available; skipping download of gnn_communities.html')
else:
    files.download('gnn_communities.html')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [7]:
import networkx as nx
import numpy as np
import random
from sklearn.metrics import roc_auc_score
import string
import torch
from torch_geometric.nn import GCNConv, Node2Vec
from torch_geometric.utils import from_networkx, negative_sampling
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from typing import Tuple, List, Dict

# Set seeds for reproducibility: Python's `random` (edge shuffling and
# negative-edge sampling below), NumPy and PyTorch all use the same seed.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def relabel_with_attributes(G):
    '''
    Copy G into a new DiGraph whose nodes are consecutive integers.

    Nodes are numbered 0, 1, 2, ... in G's node order; the original label is
    kept on each new node as the 'name' attribute.

    Returns:
        (G_int, mapping) where mapping is {original label: integer index}.
    '''
    mapping = {label: position for position, label in enumerate(G.nodes())}

    G_int = nx.DiGraph()

    # Nodes first, so that nodes without edges are kept as well
    for original_label in mapping:
        G_int.add_node(mapping[original_label], name=original_label)

    # Then the edges, translated to the integer labels
    for edge in G.edges():
        source, target = edge
        G_int.add_edge(mapping[source], mapping[target])

    return G_int, mapping

def create_node_features(G):
    '''
    Build the 58-column node feature matrix for G.

    Columns:
        0-25   one-hot of the first letter of the node's name
        26-51  one-hot of the last letter of the node's name
        52     in-degree / number of nodes
        53     out-degree / number of nodes
        54     clustering coefficient in the undirected version of G
        55     PageRank
        56     betweenness centrality
        57     name length / 15

    Args:
        G: directed graph whose nodes are the integers 0..n-1 (they are used
           as row indices) and carry a 'name' attribute.

    Returns:
        torch.FloatTensor with one row per node.
    '''
    num_nodes = G.number_of_nodes()

    # One-hot encoding for first and last letters (52 features)
    letter_features = np.zeros((num_nodes, 52))

    # Graph structural features (5 features)
    structural_features = np.zeros((num_nodes, 5))

    # Word length features (1 feature)
    length_features = np.zeros((num_nodes, 1))

    # Graph-wide measures do not depend on the node being processed, so they
    # are computed once here instead of once per node inside the loop.
    clustering = nx.clustering(G.to_undirected())
    pr = nx.pagerank(G)
    bc = nx.betweenness_centrality(G)

    for node in G.nodes():
        name = G.nodes[node]['name'].lower()

        # Letter features (letters outside a-z leave their half all zeros)
        if name[0] in string.ascii_lowercase:
            letter_features[node, ord(name[0]) - ord('a')] = 1
        if name[-1] in string.ascii_lowercase:
            letter_features[node, 26 + (ord(name[-1]) - ord('a'))] = 1

        # Structural features
        structural_features[node, 0] = G.in_degree(node) / num_nodes # Normalized in-degree
        structural_features[node, 1] = G.out_degree(node) / num_nodes # Normalized out-degree
        structural_features[node, 2] = clustering[node] # Clustering coefficient
        structural_features[node, 3] = pr[node] # PageRank
        structural_features[node, 4] = bc[node] # Betweenness centrality

        # Word length (normalized)
        length_features[node, 0] = len(name) / 15 # Normalize by typical max length

    # Combine all features
    features = np.hstack([letter_features, structural_features, length_features])
    return torch.FloatTensor(features)

def train_test_split_edges(G, test_frac: float = 0.2, max_attempts: int = 1000):
    '''
    Split the edges of G into a training graph and held-out test edges, and
    sample one negative (non-existent) edge per test edge.

    The edge list is reshuffled until removing the test edges leaves the
    training graph weakly connected.

    Args:
        G: directed graph whose nodes carry a 'name' attribute.
        test_frac: fraction of edges to hold out (at least one edge).
        max_attempts: maximum number of shuffles tried for the split; also
            scales the budget of random draws for negative sampling.

    Returns:
        (G_train, test_edges, neg_edges)

    Raises:
        ValueError: if G itself is not weakly connected (no split can be).
        RuntimeError: if no connected split or not enough negative edges are
            found within the attempt budget.
    '''
    edges = list(G.edges())
    n_test = max(1, int(len(edges) * test_frac))

    # Removing edges can never connect a graph, so fail fast instead of
    # shuffling forever.
    if not nx.is_weakly_connected(G):
        raise ValueError('Graph is not weakly connected; cannot create a connected training split')

    # Try splitting until we get a connected training graph
    for _ in range(max_attempts):
        random.shuffle(edges)
        test_edges = edges[:n_test]

        G_train = G.copy()
        G_train.remove_edges_from(test_edges)

        if nx.is_weakly_connected(G_train):
            break
    else:
        raise RuntimeError(f'No weakly connected training split found in {max_attempts} attempts')

    # Sample negative edges: random node pairs that are not edges of G and
    # whose names do not satisfy the last-letter/first-letter rule.
    neg_edges = []
    existing_edges = set(G.edges())
    nodes = list(G.nodes())

    # Upper bound on random target draws, so a graph with too few valid
    # negative pairs raises instead of looping forever.
    draw_budget = max_attempts * max(1, len(test_edges))
    draws = 0

    while len(neg_edges) < len(test_edges):
        source = random.choice(nodes)
        source_name = G.nodes[source]['name']

        # Find target that starts with a letter different from source's last letter
        while True:
            if draws >= draw_budget:
                raise RuntimeError(f'Could not sample {len(test_edges)} negative edges in {draw_budget} draws')
            draws += 1

            target = random.choice(nodes)
            if target != source:
                target_name = G.nodes[target]['name']
                if source_name[-1] != target_name[0] and (source, target) not in existing_edges:
                    neg_edges.append((source, target))
                    break

    return G_train, test_edges, neg_edges

class GNNEncoder(nn.Module):
    '''
    Three-layer GCN encoder for link prediction.

    The first two layers are each followed by batch normalization, ReLU and
    dropout (p=0.3, active only in training mode); the third layer produces
    the output embeddings directly.
    '''

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, out_channels)

        self.batch_norm1 = nn.BatchNorm1d(hidden_channels)
        self.batch_norm2 = nn.BatchNorm1d(hidden_channels)

    def forward(self, x, edge_index):
        '''Return node embeddings for features `x` on the graph given by `edge_index`.'''
        hidden_stages = (
            (self.conv1, self.batch_norm1),
            (self.conv2, self.batch_norm2),
        )
        for conv, norm in hidden_stages:
            activated = F.relu(norm(conv(x, edge_index)))
            x = F.dropout(activated, p=0.3, training=self.training)

        # Output layer: no normalization, activation or dropout
        return self.conv3(x, edge_index)

def margin_loss(pos_score, neg_score, margin=0.5):
    '''
    Mean hinge loss for link prediction.

    Each positive score is paired element-wise with a negative score; a pair
    contributes max(0, margin - pos + neg).
    '''
    violation = margin - pos_score + neg_score
    hinge = F.relu(violation)
    return hinge.mean()

def train_node2vec(G_train, embedding_dim=128):
    '''
    Train a Node2Vec model on the training graph.

    Args:
        G_train: NetworkX graph with integer node labels.
        embedding_dim: size of each node embedding.

    Returns:
        The trained Node2Vec model; calling it with no arguments yields the
        embedding matrix (this is how main() uses it).
    '''
    data = from_networkx(G_train)

    model = Node2Vec(
        data.edge_index,
        embedding_dim=embedding_dim,
        walk_length=20,
        context_size=10,
        walks_per_node=20,
        p=0.2,  # Return parameter: p < 1 makes stepping back to the previous node more likely
        q=2.0,  # In-out parameter: q > 1 keeps walks close to the previous node (BFS-like)
        # Pass the node count explicitly so nodes without any training edge
        # still get an embedding row (otherwise it is inferred from edge_index).
        num_nodes=data.num_nodes,
        sparse=True
    ).to(device)

    # sparse=True produces sparse gradients, which SparseAdam handles
    optimizer = torch.optim.SparseAdam(model.parameters(), lr=0.01)
    loader = model.loader(batch_size=128, shuffle=True, num_workers=4)

    model.train()
    for epoch in range(200):
        epoch_loss = 0
        for pos_rw, neg_rw in loader:
            optimizer.zero_grad()
            loss = model.loss(pos_rw.to(device), neg_rw.to(device))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # Average loss over the batches of this epoch
        avg_loss = epoch_loss / len(loader)
        if epoch % 10 == 0:
            print(f'Node2Vec Epoch: {epoch:02d}, Loss: {avg_loss:.4f}')

    return model

def evaluate_model(embeddings, test_edges, neg_edges):
    '''
    Score link prediction with ROC AUC.

    An edge (src, dst) is scored by the dot product of the two node
    embeddings. Test edges are labelled 1 and negative edges 0.

    Args:
        embeddings: Node embeddings, indexable by node id
        test_edges: Positive test edges as (src, dst) pairs
        neg_edges: Negative test edges as (src, dst) pairs

    Returns:
        ROC AUC of the dot-product scores against the 0/1 labels.
    '''
    def dot_score(edge):
        src, dst = edge
        return float((embeddings[src] * embeddings[dst]).sum())

    # Positives first, then negatives
    labelled_edges = [(edge, 1) for edge in test_edges] + [(edge, 0) for edge in neg_edges]
    scores = [dot_score(edge) for edge, _ in labelled_edges]
    labels = [label for _, label in labelled_edges]

    return roc_auc_score(labels, scores)

def analyze_predictions(model, data, test_edges, neg_edges, mapping):
    '''
    Analyze model predictions in detail and print a summary.

    Args:
        model: Trained model (GNN); called as model(data.x, data.edge_index)
        data: PyG data object providing `x` and `edge_index`
        test_edges: Positive test edges as (src, dst) node indices
        neg_edges: Negative test edges as (src, dst) node indices
        mapping: Name -> node index mapping (inverted internally)

    Returns:
        dict with 'pos_predictions', 'neg_predictions' (lists of dicts with
        'src', 'dst', 'score', 'correct') and 'metrics'. Metrics are NaN for
        an empty edge list instead of raising ZeroDivisionError.

    Note:
        'correct' only checks the edge against the game rule (last letter of
        src equals first letter of dst for positives, differs for negatives).
        It does not depend on the model's score, so the printed "accuracy"
        measures label validity rather than prediction quality.
    '''
    model.eval()
    reverse_mapping = {idx: name for name, idx in mapping.items()}

    def describe(edges, embeddings, expect_valid):
        # One record per edge: names, dot-product score and rule check
        records = []
        for src, dst in edges:
            src_name = reverse_mapping[src]
            dst_name = reverse_mapping[dst]
            follows_rule = src_name[-1].lower() == dst_name[0].lower()
            records.append({
                'src': src_name,
                'dst': dst_name,
                'score': float((embeddings[src] * embeddings[dst]).sum()),
                'correct': follows_rule == expect_valid
            })
        return records

    def fraction_correct(records):
        # NaN rather than ZeroDivisionError when there is nothing to count
        if not records:
            return float('nan')
        return sum(1 for r in records if r['correct']) / len(records)

    def stat(fn, values):
        # np.mean / np.std of an empty list warn and give NaN; return NaN quietly
        return fn(values) if values else float('nan')

    with torch.no_grad():
        embeddings = model(data.x, data.edge_index)

        pos_predictions = describe(test_edges, embeddings, expect_valid=True)
        neg_predictions = describe(neg_edges, embeddings, expect_valid=False)

    # Sort by scores: best positives first, lowest-scoring negatives first
    pos_predictions.sort(key=lambda x: x['score'], reverse=True)
    neg_predictions.sort(key=lambda x: x['score'])

    print('\nTop 5 Highest Scoring Valid Moves:')
    for pred in pos_predictions[:5]:
        print(f"{pred['src']} -> {pred['dst']}: {pred['score']:.4f}")

    print('\nTop 5 Lowest Scoring Invalid Moves:')
    for pred in neg_predictions[:5]:
        print(f"{pred['src']} -> {pred['dst']}: {pred['score']:.4f}")

    # Calculate metrics
    pos_accuracy = fraction_correct(pos_predictions)
    neg_accuracy = fraction_correct(neg_predictions)

    print(f"\nAccuracy on valid moves: {pos_accuracy:.4f}")
    print(f"Accuracy on invalid moves: {neg_accuracy:.4f}")

    # Score distributions
    pos_scores = [p['score'] for p in pos_predictions]
    neg_scores = [p['score'] for p in neg_predictions]

    print(f'\nScore Statistics:')
    print(f'Valid moves - Mean: {stat(np.mean, pos_scores):.4f}, Std: {stat(np.std, pos_scores):.4f}')
    print(f'Invalid moves - Mean: {stat(np.mean, neg_scores):.4f}, Std: {stat(np.std, neg_scores):.4f}')

    return {
        'pos_predictions': pos_predictions,
        'neg_predictions': neg_predictions,
        'metrics': {
            'pos_accuracy': pos_accuracy,
            'neg_accuracy': neg_accuracy,
            'pos_mean': stat(np.mean, pos_scores),
            'neg_mean': stat(np.mean, neg_scores)
        }
    }

def main():
    '''
    Run the link-prediction experiment on the Atlas graph.

    Builds the graph from the notebook-level `countries` list, holds out test
    edges, trains Node2Vec and the GCN encoder on the training graph, prints
    ROC AUC for both, and finally prints a detailed analysis of the GNN.
    '''
    G = build_graph(countries)
    G_int, mapping = relabel_with_attributes(G)

    # Split edges
    G_train, test_edges, neg_edges = train_test_split_edges(G_int)

    print('Graph Statistics:')
    print(f'Nodes: {G.number_of_nodes()}')
    print(f'Total Edges: {G.number_of_edges()}')
    print(f'Training Edges: {len(G_train.edges())}')
    print(f'Test Edges: {len(test_edges)}')

    # Train Node2Vec
    print('\nTraining Node2Vec...')
    n2v_model = train_node2vec(G_train)
    # Read the embeddings in evaluation mode and without autograd tracking,
    # the same way the GNN is evaluated below.
    n2v_model.eval()
    with torch.no_grad():
        n2v_embeddings = n2v_model()
    n2v_auc = evaluate_model(n2v_embeddings, test_edges, neg_edges)
    print(f'Node2Vec ROC AUC: {n2v_auc:.4f}')

    # Train GNN
    print('\nTraining GNN...')
    data = from_networkx(G_train)
    data.x = create_node_features(G_train).to(device)
    data.edge_index = data.edge_index.to(device)

    # Take the input width from the feature matrix instead of hard-coding it,
    # so changes to create_node_features cannot silently mismatch the model.
    gnn_model = GNNEncoder(in_channels=data.x.size(1), hidden_channels=128, out_channels=64).to(device)
    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.7, patience=10)

    best_auc = 0
    # NOTE(review): no_improve counts evaluations (one per 10 epochs). With
    # 150 epochs there are at most 15 evaluations, so a patience of 20 can
    # never trigger early stopping as written.
    patience = 20
    no_improve = 0

    for epoch in tqdm(range(150)):
        gnn_model.train()
        optimizer.zero_grad()

        z = gnn_model(data.x, data.edge_index)

        # Get positive edges
        pos_edge_index = data.edge_index

        # Sample as many random non-edges as there are positive edges
        neg_edge_index = negative_sampling(
            edge_index=pos_edge_index,
            num_nodes=data.num_nodes,
            num_neg_samples=pos_edge_index.size(1)
        )

        # Compute scores (dot product of the two endpoint embeddings)
        pos_score = (z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=1)
        neg_score = (z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=1)

        # Margin loss between positive and negative scores
        loss = margin_loss(pos_score, neg_score)

        loss.backward()
        optimizer.step()

        # Evaluate every 10 epochs
        if epoch % 10 == 0:
            gnn_model.eval()
            with torch.no_grad():
                z = gnn_model(data.x, data.edge_index)
                auc = evaluate_model(z, test_edges, neg_edges)
                print(f'GNN Epoch {epoch:03d}, Loss: {loss.item():.4f}, AUC: {auc:.4f}')

                # The learning-rate schedule and the best-AUC tracking both
                # use the AUC measured on the test edges.
                scheduler.step(auc)

                if auc > best_auc:
                    best_auc = auc
                    no_improve = 0
                else:
                    no_improve += 1

                if no_improve >= patience:
                    print('Early stopping!')
                    break

    # 'GNN AUC' is the best AUC seen at any evaluation, not the final one
    print(f'\nFinal Results:')
    print(f'Node2Vec AUC: {n2v_auc:.4f}')
    print(f'GNN AUC: {best_auc:.4f}')

    # Analyze predictions of the model as it is after the last training step
    print('\nAnalyzing GNN Predictions...')
    analyze_predictions(gnn_model, data, test_edges, neg_edges, mapping)

if __name__ == "__main__":
    main()

Graph Statistics:
Nodes: 192
Total Edges: 2001
Training Edges: 1601
Test Edges: 400

Training Node2Vec...




Node2Vec Epoch: 00, Loss: 9.5942
Node2Vec Epoch: 10, Loss: 5.3078
Node2Vec Epoch: 20, Loss: 3.9629
Node2Vec Epoch: 30, Loss: 3.3052
Node2Vec Epoch: 40, Loss: 2.8333
Node2Vec Epoch: 50, Loss: 2.5292
Node2Vec Epoch: 60, Loss: 2.3453
Node2Vec Epoch: 70, Loss: 2.2168
Node2Vec Epoch: 80, Loss: 2.1266
Node2Vec Epoch: 90, Loss: 2.0564
Node2Vec Epoch: 100, Loss: 2.0305
Node2Vec Epoch: 110, Loss: 1.9442
Node2Vec Epoch: 120, Loss: 1.9640
Node2Vec Epoch: 130, Loss: 1.8989
Node2Vec Epoch: 140, Loss: 1.8641
Node2Vec Epoch: 150, Loss: 1.8507
Node2Vec Epoch: 160, Loss: 1.8094
Node2Vec Epoch: 170, Loss: 1.8145
Node2Vec Epoch: 180, Loss: 1.7647
Node2Vec Epoch: 190, Loss: 1.7825
Node2Vec ROC AUC: 0.6637

Training GNN...


  7%|▋         | 11/150 [00:00<00:02, 53.44it/s]

GNN Epoch 000, Loss: 4.5730, AUC: 0.7587
GNN Epoch 010, Loss: 1.3956, AUC: 0.7635


 21%|██        | 31/150 [00:00<00:02, 58.52it/s]

GNN Epoch 020, Loss: 1.0382, AUC: 0.7692
GNN Epoch 030, Loss: 0.8104, AUC: 0.7661


 34%|███▍      | 51/150 [00:00<00:01, 58.36it/s]

GNN Epoch 040, Loss: 0.7492, AUC: 0.7662
GNN Epoch 050, Loss: 0.5519, AUC: 0.7694


 47%|████▋     | 71/150 [00:01<00:01, 58.58it/s]

GNN Epoch 060, Loss: 0.4507, AUC: 0.7726
GNN Epoch 070, Loss: 0.4338, AUC: 0.7730


 61%|██████    | 91/150 [00:01<00:01, 55.72it/s]

GNN Epoch 080, Loss: 0.3786, AUC: 0.7736
GNN Epoch 090, Loss: 0.3029, AUC: 0.7746


 74%|███████▍  | 111/150 [00:01<00:00, 56.84it/s]

GNN Epoch 100, Loss: 0.3416, AUC: 0.7888
GNN Epoch 110, Loss: 0.3011, AUC: 0.7855


 87%|████████▋ | 131/150 [00:02<00:00, 58.73it/s]

GNN Epoch 120, Loss: 0.3081, AUC: 0.7815
GNN Epoch 130, Loss: 0.2448, AUC: 0.7789


100%|██████████| 150/150 [00:02<00:00, 58.72it/s]


GNN Epoch 140, Loss: 0.2490, AUC: 0.7817

Final Results:
Node2Vec AUC: 0.6637
GNN AUC: 0.7888

Analyzing GNN Predictions...

Top 5 Highest Scoring Valid Moves:
algeria -> angola: 13.5792
austria -> afghanistan: 12.9691
angola -> albania: 12.9390
angola -> antigua and barbuda: 12.8098
antigua and barbuda -> azerbaijan: 12.6619

Top 5 Lowest Scoring Invalid Moves:
burkina faso -> bolivia: -0.7259
netherlands -> papua new guinea: -0.6581
germany -> afghanistan: -0.5496
norway -> panama: -0.4401
bosnia and herzegovina -> peru: -0.3487

Accuracy on valid moves: 1.0000
Accuracy on invalid moves: 1.0000

Score Statistics:
Valid moves - Mean: 2.1814, Std: 2.3763
Invalid moves - Mean: 0.8663, Std: 0.7194
