In [10]:
import torch
from torch.utils.cpp_extension import load

# This will compile and load the C++ extension.
skipgram_ops = load(name="skipgram_ops", sources=["skipgram_ops.cpp"], verbose=True)



Using /nfs/nfs2/home/gogandhi/.cache/torch_extensions/py39_cu124 as PyTorch extensions root...
No modifications detected for re-loaded extension module skipgram_ops_v2, skipping build step...
Loading extension module skipgram_ops_v2...


In [13]:
import time
import random
import numpy as np
import networkx as nx

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import lfr
import sys
import cProfile, pstats

In [14]:
##############################################
# Block 1: Setup and Random Walk Generation  #
##############################################

def generate_random_walks(G, num_walks=10, walk_length=10):
    """
    Generate random walks over the graph G.
    Note: For larger graphs, consider using more efficient vectorized implementations or libraries like 
    StellarGraph which implement efficient random walk sampling.
    """
    walks = []
    nodes = list(G.nodes())
    for _ in range(num_walks):
        random.shuffle(nodes)
        for node in nodes:
            walk = [node]
            for i in range(walk_length - 1):
                cur = walk[-1]
                neighbors = list(G.neighbors(cur))
                if len(neighbors) == 0:
                    break
                walk.append(random.choice(neighbors))
            walks.append(walk)
    return walks

##############################################
# Block 2: Generating SkipGram Pairs         #
##############################################

def generate_skipgram_pairs(walks, window_size=2):
    """
    For each walk, generate (center, context) pairs using a sliding window.
    Consider profiling this block using line_profiler if walks or window_size are large.
    """
    pairs = []
    for walk in walks:
        L = len(walk)
        for i, center in enumerate(walk):
            start = max(0, i - window_size)
            end = min(L, i + window_size + 1)
            for j in range(start, end):
                if i == j:
                    continue
                pairs.append((center, walk[j]))
    return pairs

##############################################
# Block 3: Dataset and DataLoader            #
##############################################

class SkipGramDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        center, context = self.pairs[idx]
        return center, context

In [None]:
class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, sim_type="dot"):
        """
        Parameters:
          vocab_size: int – number of nodes.
          embedding_dim: int – embedding dimensionality.
          sim_type: str – similarity measure; options: "dot", "euclidean", "cosine".
        """
        super(SkipGramModel, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        
        # Map similarity measure to an integer for the C++ extension.
        sim_type_lower = sim_type.lower()
        if sim_type_lower == "dot":
            self.sim_type = 0
        elif sim_type_lower == "euclidean":
            self.sim_type = 1
        elif sim_type_lower == "cosine":
            self.sim_type = 2
        else:
            raise ValueError("Unknown similarity type. Choose from 'dot', 'euclidean', or 'cosine'.")
        
        self.center_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.context_embeddings = nn.Embedding(vocab_size, embedding_dim)
        initrange = 0.5 / embedding_dim
        self.center_embeddings.weight.data.uniform_(-initrange, initrange)
        self.context_embeddings.weight.data.zero_()
        
        # Beta is used to scale the Euclidean distance.
        self.beta = nn.Parameter(torch.tensor(1.0))
    
    def forward(self, center_idxs, context_idxs, negative_idxs):
        """
        center_idxs: (B,) tensor of center nodes.
        context_idxs: (B,) tensor of positive context nodes.
        negative_idxs: (B, K) tensor of negative samples.
        """
        center_vecs = self.center_embeddings(center_idxs)         # (B, D)
        pos_context_vecs = self.context_embeddings(context_idxs)    # (B, D)
        neg_context_vecs = self.context_embeddings(negative_idxs)   # (B, K, D)
        
        # Call the compiled C++ extension.
        pos_scores, neg_scores = skipgram_ops.skipgram_forward(
            center_vecs, pos_context_vecs, neg_context_vecs, self.sim_type, self.beta.item()
        )
        return pos_scores, neg_scores

def skipgram_loss(pos_scores, neg_scores):
    """
    Compute loss:
      - For a positive pair: L_pos = -log(sigmoid(score))
      - For negative samples: L_neg = -sum(log(sigmoid(-score)))
    Using F.logsigmoid for numerical stability.
    """
    loss_pos = -F.logsigmoid(pos_scores)
    loss_neg = -torch.sum(F.logsigmoid(-neg_scores), dim=1)
    loss = loss_pos + loss_neg
    return loss.mean()

In [17]:
from torch.utils.data import DataLoader, Dataset
import time
import random

# (Assume you have implemented functions to generate random walks and skip-gram pairs, and a Dataset class.)
# For illustration, here’s a minimal dummy dataset:
class SkipGramDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx]
##############################################
# Block 5: Training Loop                     #
##############################################

def train_model(G, num_walks = 10, walk_length = 80, 
                window_size = 10, embedding_dim = 128, 
                negative_samples = 5, num_epochs = 1, sim_type = "dot"):
    # Options: "dot", "euclidean", "cosine"
    
    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)   
    
    # Generate random walks
    start_time = time.time()
    walks = generate_random_walks(G, num_walks, walk_length)
    print("Number of walks:", len(walks))
    print("Example walk:", walks[0])

    pairs = generate_skipgram_pairs(walks, window_size)
    print("Number of positive pairs:", len(pairs))
    print("Example pairs:", pairs[:5])
    print("Skipgram walks and pairs generated in {:.4f} seconds.".format(time.time() - start_time))

    # Create dataset and dataloader
    dataset = SkipGramDataset(pairs)
    # For larger datasets, consider setting num_workers > 0 for parallel data loading.
    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)

    # Hyperparameters
    vocab_size = G.number_of_nodes()   # Assuming node IDs are 0-indexed.    

    # Initialize model and optimizer
    model = SkipGramModel(vocab_size, embedding_dim, sim_type=sim_type).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print("\nStarting training...")
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        epoch_start = time.time()
        for center_idxs, context_idxs in dataloader:
            center_idxs = center_idxs.to(device)
            context_idxs = context_idxs.to(device)
            batch_size = center_idxs.size(0)
            # Use device variable to ensure compatibility even if the device is CPU.
            negative_idxs = torch.randint(0, vocab_size, (batch_size, negative_samples), device=device)
            optimizer.zero_grad()
            pos_scores, neg_scores = model(center_idxs, context_idxs, negative_idxs)
            loss = skipgram_loss(pos_scores, neg_scores)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * batch_size
        avg_loss = epoch_loss / len(dataset)
        print(f"Epoch {epoch+1}/{num_epochs} completed in {time.time() - epoch_start:.2f}s, Loss: {avg_loss:.4f}")

    # Save or return the trained model and embeddings as needed.
    return model, device

def evaluate_model(model, device):
    model.eval()
    with torch.no_grad():
        embeddings = model.center_embeddings.weight.cpu().numpy()
    print("Embeddings shape:", embeddings.shape)
    # print("First 5 embeddings:")
    # print(embeddings[:5])
    return embeddings

In [20]:
##############################################
# Main Function: Run the Training and Eval   #
##############################################

import lfr
def create_network(params={
    "N": 1000,     # number of nodes
    "k": 6,        # average degree
    "maxk": 25,    # maximum degree
    "minc": 25,    # minimum community size
    "maxc": 250,   # maximum community size
    "tau": 2,      # degree exponent
    "tau2": 1.5,   # community size exponent
    "mu": 0.01,     # mixing rate
}):
    ng = lfr.NetworkGenerator()
    data = ng.generate(**params)
    net = data["net"]                  # scipy.csr_sparse matrix
    community_table = data["community_table"]  # pandas DataFrame
    seed = data["seed"]                # Seed value
    return net, community_table, seed


# Create a sample graph (Karate Club)
#G = nx.karate_club_graph()
# Generate the network and the associated community table
A, community_labels, _ = create_network()
G = nx.from_scipy_sparse_array(A)

profiler = cProfile.Profile()
profiler.enable()
model, device = train_model(G, num_walks = 10, walk_length = 80, 
                            window_size = 10, embedding_dim = 128, 
                            negative_samples = 5, num_epochs = 1, sim_type = "euclidean" )

embeddings = evaluate_model(model, device)

profiler.disable()

Using device: cuda
Number of walks: 10000
Example walk: [548, 991, 548, 991, 229, 991, 126, 955, 991, 842, 266, 844, 616, 203, 616, 999, 280, 999, 395, 999, 324, 999, 678, 571, 678, 999, 756, 689, 860, 817, 860, 671, 860, 991, 860, 844, 860, 548, 765, 810, 765, 643, 879, 707, 777, 647, 777, 266, 844, 274, 842, 868, 810, 161, 868, 633, 880, 406, 817, 842, 604, 860, 441, 274, 441, 860, 731, 991, 765, 703, 996, 396, 897, 396, 924, 355, 924, 766, 589, 35]
Number of positive pairs: 14900000
Example pairs: [(548, 991), (548, 548), (548, 991), (548, 229), (548, 991)]
Skipgram walks and pairs generated in 8.3594 seconds.

Starting training...
Epoch 1/1 completed in 971.65s, Loss: 1.2678
Embeddings shape: (1000, 128)


In [21]:
stats = pstats.Stats(profiler).sort_stats('tottime')
stats.print_stats()

         188664265 function calls (183309504 primitive calls) in 980.483 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   232813  311.819    0.001  311.819    0.001 {method 'run_backward' of 'torch._C._EngineBase' objects}
   232813   79.033    0.000   79.033    0.000 {built-in method skipgram_ops_v2.skipgram_forward}
   698439   35.964    0.000   35.964    0.000 {built-in method torch.embedding}
   698442   30.413    0.000   30.413    0.000 {built-in method torch.tensor}
   232813   26.757    0.000   62.875    0.000 /tmp/ipykernel_1789051/280362172.py:54(skipgram_loss)
   465629   21.848    0.000   21.848    0.000 {method 'to' of 'torch._C.TensorBase' objects}
   698440   21.340    0.000   21.340    0.000 {built-in method torch._ops.profiler._record_function_enter_new}
  1396880   20.704    0.000   20.704    0.000 {method 'item' of 'torch._C.TensorBase' objects}
        1   19.764   19.764  980.016  980.016 /tmp/ipykerne

<pstats.Stats at 0x7b8c3e327490>