<a href="https://colab.research.google.com/github/git-akhtari/graph/blob/main/TemporalNode2vecBcdTgb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import numpy as np
!pip install torch-geometric
from torch_geometric.loader import TemporalDataLoader
!pip install py-tgb
from tgb.utils.utils import get_args
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
import node2vec_walks

# Generate Walks
def generate_walks(nx_G, length, p, q, num_walks=1, is_directed=False):
    """Sample node2vec random walks on ``nx_G`` using ``node2vec_walks.Graph``.

    Args:
        nx_G: NetworkX graph to walk on.
        length: Forwarded to ``simulate_walks`` as ``walk_length``.
        p: Forwarded unchanged to ``node2vec_walks.Graph`` (node2vec's
            return parameter, as interpreted by that module).
        q: Forwarded unchanged to ``node2vec_walks.Graph`` (node2vec's
            in-out parameter, as interpreted by that module).
        num_walks: Forwarded to ``simulate_walks``. Defaults to 1, the value
            this function previously hard-coded.
        is_directed: Forwarded to ``node2vec_walks.Graph``. Defaults to
            ``False``, the value this function previously hard-coded.

    Returns:
        Whatever ``Graph.simulate_walks`` returns. This is presumably a list
        of walks, each a sequence of node ids (``compute_ppmi`` indexes a
        matrix with them); confirm against ``node2vec_walks``.
    """
    walker = node2vec_walks.Graph(nx_G, is_directed=is_directed, p=p, q=q)
    # Transition probabilities are preprocessed before any walk is sampled.
    walker.preprocess_transition_probs()
    return walker.simulate_walks(num_walks=num_walks, walk_length=length)

# def generate_walks(graph, walk_length):
#     walks = []
#     for node in graph.nodes():
#         walk = [node]
#         for _ in range(walk_length - 1):
#             neighbors = list(graph.neighbors(walk[-1]))
#             if len(neighbors) > 0:
#                 walk.append(np.random.choice(neighbors))
#             else:
#                 break
#         walks.append(walk)
#     return walks

# def generate_walks(graph, length, p):
# walks = []
# for node in graph.nodes():
#     walk = [node]
#     while len(walk) < length:
#         current_node = walk[-1]
#         neighbors = list(graph.neighbors(current_node))
#         if not neighbors:
#             break
#         if random.random() < p:
#             walk.append(walk[-2] if len(walk) > 1 else random.choice(neighbors))
#         else:
#             next_node = random.choice(neighbors)
#             walk.append(next_node)
#     walks.append(walk)
# return walks

# Compute PPMI Matrix
def compute_ppmi(walks, window_size, num_nodes):
    """Build a dense positive pointwise mutual information (PPMI) matrix from walks.

    Co-occurrence counts ``C`` are accumulated over every pair of distinct
    positions ``(i, j)`` with ``|i - j| <= window_size`` inside each walk. The
    counts are then turned into

        PPMI[a, b] = max(0, log(C[a, b] * D / (r[a] * r[b])))

    where ``r`` are the row sums of ``C`` and ``D = sum(r)``.

    Args:
        walks: Iterable of walks. Each walk is a sequence of integer node ids
            usable as row/column indices of a ``num_nodes x num_nodes`` array.
        window_size: Maximum positional distance at which two walk entries
            count as co-occurring.
        num_nodes: Side length of the square output matrix.

    Returns:
        A ``torch.float32`` CPU tensor of shape ``(num_nodes, num_nodes)``.
        Pairs with no co-occurrence are 0. This includes every row and column
        of a node that never appears in ``walks``; the previous dense formula
        produced NaN (0/0) there.
    """
    cooccurrence = np.zeros((num_nodes, num_nodes))
    for walk in walks:
        walk_len = len(walk)
        for i, node in enumerate(walk):
            lo = max(i - window_size, 0)
            hi = min(i + window_size + 1, walk_len)
            for j in range(lo, hi):
                if i != j:
                    cooccurrence[node, walk[j]] += 1

    row_sums = cooccurrence.sum(axis=1)
    total = row_sums.sum()
    ppmi = np.zeros((num_nodes, num_nodes), dtype=np.float32)

    # PMI is only defined where the count is positive. Evaluating the log on
    # those entries alone avoids 0/0 NaNs for absent nodes. It also avoids
    # materialising the dense N x N outer product of row sums. Every other
    # entry would have been clamped to 0 anyway.
    rows, cols = np.nonzero(cooccurrence)
    if rows.size:
        pmi = np.log(
            cooccurrence[rows, cols] * total / (row_sums[rows] * row_sums[cols])
        )
        ppmi[rows, cols] = np.maximum(pmi, 0.0)
    return torch.from_numpy(ppmi)

def _bcd_step(Yt, Ut, Wp, Wn, gamma, llambda, tau, idx):
    """Solve the regularised least-squares subproblem for one embedding row.

    Returns ``x`` satisfying the normal equations

        (Ut^T Ut + (gamma + llambda + 2*tau) I) x
            = Ut^T y + gamma * Ut[idx] + tau * (Wp + Wn)

    which makes ``x`` the minimiser of

        ||y - Ut x||^2 + llambda ||x||^2 + gamma ||x - Ut[idx]||^2
            + tau (||x - Wp||^2 + ||x - Wn||^2).

    Args:
        Yt: Target vector ``y`` of length ``Ut.shape[0]``. A 2-D stack of such
            rows is also accepted; one solution row is returned per input row.
        Ut: Fixed factor matrix of shape ``(n, r)``.
        Wp: Row of length ``r`` the solution is pulled towards with weight ``tau``.
        Wn: Row of length ``r`` the solution is pulled towards with weight ``tau``.
        gamma: Weight pulling the solution towards ``Ut[idx]``.
        llambda: Ridge (L2) weight on the solution.
        tau: Weight of each of the ``Wp`` / ``Wn`` terms.
        idx: Row of ``Ut`` used in the ``gamma`` term.

    Returns:
        Tensor with the same shape as ``Yt @ Ut``.
    """
    r = Ut.shape[1]
    gram = Ut.T @ Ut
    A = gram + (gamma + llambda + 2 * tau) * torch.eye(r, device=Ut.device, dtype=Ut.dtype)
    B = Yt @ Ut + gamma * Ut[idx, :] + tau * (Wp + Wn)

    if B.dim() == 1:
        # Vector right-hand side: solve directly. Calling ``.T`` on a 1-D
        # tensor is a no-op that recent PyTorch versions warn is deprecated.
        return torch.linalg.solve(A, B)
    # Stack of row vectors: solve for all columns of B^T at once. A is
    # symmetric, so the result equals B @ A^{-1}.
    return torch.linalg.solve(A, B.T).T

def construct_graph_from_batch(batch, directed=False):
    """Build a NetworkX graph whose edges are the ``(src, dst)`` pairs of ``batch``.

    Args:
        batch: Object with equal-length 1-D tensor attributes ``src`` and
            ``dst``. This is presumably a ``TemporalData`` batch from
            ``TemporalDataLoader``.
        directed: When True, build an ``nx.DiGraph`` (the previous behaviour).
            The default, False, builds an undirected ``nx.Graph``. That matches
            ``generate_walks`` running node2vec with ``is_directed=False``. If
            ``node2vec_walks`` walks via ``G.neighbors`` (an assumption, since
            that module is not visible here), a DiGraph exposes successors
            only. Walks would then stop at any node without an outgoing
            interaction.

    Returns:
        The graph. Repeated interactions between the same pair collapse into
        a single edge with ``weight=1``, so the graph stays unweighted as
        before.
    """
    import networkx as nx

    G = nx.DiGraph() if directed else nx.Graph()
    src = batch.src.cpu().tolist()
    dst = batch.dst.cpu().tolist()
    # NOTE(review): the reference node2vec implementation reads
    # G[u][v]['weight'] on every edge. Setting weight=1 is harmless if
    # node2vec_walks ignores it; confirm what that module requires.
    G.add_edges_from(zip(src, dst), weight=1)
    return G

# Initialize embeddings
def initialize_parameters(num_nodes, embedding_dim, device=None):
    """Create the two embedding matrices ``U`` and ``W`` with standard-normal entries.

    Args:
        num_nodes: Number of rows in each matrix.
        embedding_dim: Number of columns in each matrix.
        device: Device to allocate on. When omitted, the module-level
            ``device`` global is used, as the original implementation did.
            If no such global exists (e.g. this module was imported rather
            than run as a script), CPU is used instead of raising NameError.

    Returns:
        ``(U, W)``: two independent tensors of shape
        ``(num_nodes, embedding_dim)`` that do not require gradients.
    """
    if device is None:
        device = globals().get("device", torch.device("cpu"))
    U = torch.randn(num_nodes, embedding_dim, device=device)
    W = torch.randn(num_nodes, embedding_dim, device=device)
    return U, W

# Training function with BCD
def _adjacent_rows(M, idx, num_rows):
    """Return ``(M[idx - 1], M[idx + 1])``, substituting zero rows past either end."""
    prev_row = M[idx - 1, :] if idx > 0 else torch.zeros_like(M[idx, :])
    next_row = M[idx + 1, :] if idx < num_rows - 1 else torch.zeros_like(M[idx, :])
    return prev_row, next_row


def train_with_bcd(train_loader, num_nodes, walk_length, window_size, lambda_reg, tau_reg, gamma_reg, U, W, device, p=0.1, q=0.1):
    """Run one epoch of block coordinate descent (BCD) over ``train_loader``.

    For each batch, the function:

    1. builds a graph from the batch's ``(src, dst)`` pairs;
    2. samples node2vec walks on it;
    3. converts the walks into a PPMI matrix;
    4. sweeps every row ``idx`` in ``range(num_nodes)``, updating ``W[idx]``
       and then ``U[idx]`` in place with ``_bcd_step``.

    NOTE(review): the "previous/next" terms (``Wp/Wn``, ``Up/Un``) are rows
    ``idx - 1`` and ``idx + 1``. Those are neighbouring *node ids*, not the
    same node at neighbouring time steps. If temporal smoothing across
    snapshots was intended, they should come from earlier/later snapshot
    embeddings; confirm the intended model.

    NOTE(review): each batch performs ``2 * num_nodes`` calls to
    ``_bcd_step``, and each call recomputes ``Ut.T @ Ut`` over all
    ``num_nodes`` rows. The cost per batch therefore grows quadratically in
    ``num_nodes``.

    Args:
        train_loader: Iterable of batches supporting ``.to(device)`` and
            exposing ``src`` / ``dst``.
        num_nodes: Number of rows in ``U``/``W`` and side of the PPMI matrix.
        walk_length: Length of each sampled walk.
        window_size: Co-occurrence window used by ``compute_ppmi``.
        lambda_reg: Forwarded to ``_bcd_step`` as ``llambda``.
        tau_reg: Forwarded to ``_bcd_step`` as ``tau``.
        gamma_reg: Forwarded to ``_bcd_step`` as ``gamma``.
        U: Embedding matrix, updated in place.
        W: Embedding matrix, updated in place.
        device: Device each batch is moved to.
        p: node2vec ``p`` forwarded to ``generate_walks``. Defaults to 0.1,
            the value previously hard-coded.
        q: node2vec ``q`` forwarded to ``generate_walks``. Defaults to 0.1,
            the value previously hard-coded.

    Returns:
        Mean over batches of the Frobenius norm ``||PPMI - U @ W.T||``, or
        0.0 if the loader yields no batches.
    """
    total_loss = 0.0
    num_batches = 0

    for batch in train_loader:
        batch = batch.to(device)

        graph = construct_graph_from_batch(batch)
        walks = generate_walks(graph, walk_length, p, q)
        # compute_ppmi returns a CPU tensor. U and W may live on an
        # accelerator, and the matmuls in _bcd_step need matching device and
        # dtype.
        PPMI = compute_ppmi(walks, window_size, num_nodes).to(device=U.device, dtype=U.dtype)

        for idx in range(num_nodes):
            # All four neighbour rows are read before row idx is rewritten.
            Wp, Wn = _adjacent_rows(W, idx, num_nodes)
            Up, Un = _adjacent_rows(U, idx, num_nodes)

            W[idx, :] = _bcd_step(PPMI[idx, :], U, Wp, Wn, gamma_reg, lambda_reg, tau_reg, idx)
            U[idx, :] = _bcd_step(PPMI[idx, :], W, Up, Un, gamma_reg, lambda_reg, tau_reg, idx)

        # Frobenius norm: a 2-D input with ord=None is flattened and its
        # 2-norm is taken.
        loss = torch.linalg.norm(PPMI - U @ W.T)
        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else 0.0


# Main script
if __name__ == "__main__":
    # Script entry point. It loads the TGB dataset named by DATA, splits it
    # with the dataset's train/val/test masks, and builds TemporalDataLoaders.
    # It then initialises U/W and runs num_epochs calls of train_with_bcd,
    # printing the mean per-batch loss each epoch.

    # Parameters
    DATA = "tgbl-wiki"
    BATCH_SIZE = 200
    embedding_dim = 128
    teta = 1  # NOTE(review): not referenced anywhere in the visible code.
    window_size = 5
    walk_length = 10
    # Passed to train_with_bcd (and from there to _bcd_step) as the
    # llambda / tau / gamma weights.
    lambda_reg = 0.1
    tau_reg = 0.1
    gamma_reg = 0.1
    num_epochs = 5

    # Device setup
    # Defined at module scope on purpose: initialize_parameters, as called
    # below without a device argument, reads this global `device`.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load dataset
    dataset = PyGLinkPropPredDataset(name=DATA, root="datasets")
    train_mask = dataset.train_mask
    val_mask = dataset.val_mask
    test_mask = dataset.test_mask
    data = dataset.get_TemporalData()
    data = data.to(device)

    train_data = data[train_mask]
    val_data = data[val_mask]
    test_data = data[test_mask]

    train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
    # NOTE(review): val_loader / test_loader are not used in the visible
    # training loop (no evaluation step yet).
    val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
    test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

    # Initialize embeddings
    # Assumes dataset.num_nodes covers every id in data.src / data.dst, since
    # compute_ppmi indexes a num_nodes x num_nodes matrix with them. TODO:
    # confirm.
    num_nodes = dataset.num_nodes
    U, W = initialize_parameters(num_nodes, embedding_dim)

    # Training loop
    # U and W are not reassigned here: train_with_bcd writes their rows in
    # place and returns only the epoch's mean loss.
    for epoch in range(num_epochs):
        loss = train_with_bcd(train_loader, num_nodes, walk_length, window_size, lambda_reg, tau_reg, gamma_reg, U, W, device)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss:.4f}")