<a href="https://colab.research.google.com/github/git-akhtari/graph/blob/main/TemporalNode2vecTgb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import numpy as np
!pip install torch-geometric
from torch_geometric.loader import TemporalDataLoader
!pip install py-tgb
from tgb.utils.utils import get_args
from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
import random
from sklearn.linear_model import LogisticRegression
import os
from google.colab import drive
from tgb.utils.utils import set_random_seed, split_by_time, save_results

import sys
import os
from google.colab import drive

# Mount Google Drive once: it holds both the node2vec helper module and the dataset cache.
drive.mount('/content/drive')

# Make the project-local `node2vec_walks` helper importable from Drive.
# The membership check keeps re-runs of this cell from appending the same directory again.
file_path = '/content/drive/MyDrive/TGB-TEST/node2vec_walks.py'
file_dir = os.path.dirname(file_path)
if file_dir not in sys.path:
    sys.path.append(file_dir)
import node2vec_walks

# Dataset name and on-Drive cache location; the same name is used to build the dataset.
DATA = "tgbl-wiki"
dataset_path = '/content/drive/MyDrive/TGB-TEST/tgbl_wiki'
dataset = PyGLinkPropPredDataset(name=DATA, root=dataset_path)


# Generate Walks
def generate_walks(nx_G, length, p, q, num_walks=1, is_directed=False):
    """Sample node2vec-style random walks over ``nx_G`` using the local ``node2vec_walks`` module.

    Args:
        nx_G: networkx graph to walk on.
        length: number of steps per walk (forwarded as ``walk_length``).
        p: node2vec return parameter, forwarded to ``node2vec_walks.Graph``.
        q: node2vec in-out parameter, forwarded to ``node2vec_walks.Graph``.
        num_walks: walks started per node (default 1, the previously hard-coded value).
        is_directed: forwarded to ``node2vec_walks.Graph`` (default False, as before).

    Returns:
        Whatever ``node2vec_walks.Graph.simulate_walks`` returns. This is presumably a list of walks,
        each a list of node ids. NOTE(review): confirm against node2vec_walks.py.
    """
    G = node2vec_walks.Graph(nx_G, is_directed=is_directed, p=p, q=q)
    # Transition (alias) tables must be built before walks can be simulated.
    G.preprocess_transition_probs()
    return G.simulate_walks(num_walks=num_walks, walk_length=length)

# def generate_walks(graph, walk_length):
#     walks = []
#     for node in graph.nodes():
#         walk = [node]
#         for _ in range(walk_length - 1):
#             neighbors = list(graph.neighbors(walk[-1]))
#             if len(neighbors) > 0:
#                 walk.append(np.random.choice(neighbors))
#             else:
#                 break
#         walks.append(walk)
#     return walks

# def generate_walks(graph, length, p):
# walks = []
# for node in graph.nodes():
#     walk = [node]
#     while len(walk) < length:
#         current_node = walk[-1]
#         neighbors = list(graph.neighbors(current_node))
#         if not neighbors:
#             break
#         if random.random() < p:
#             walk.append(walk[-2] if len(walk) > 1 else random.choice(neighbors))
#         else:
#             next_node = random.choice(neighbors)
#             walk.append(next_node)
#     walks.append(walk)
# return walks

# Compute PPMI Matrix
def compute_ppmi(walks, window_size, num_nodes):
    """Build a dense positive-PMI matrix from windowed co-occurrences inside ``walks``.

    Two positions of the same walk co-occur when they are at most ``window_size`` apart.
    A position is never paired with itself. With ``c(i, j)`` the co-occurrence count,
    ``c(i)`` the row sum of node ``i`` and ``D`` the total count::

        PPMI[i, j] = max(0, log(c(i, j) * D / (c(i) * c(j)) + 1e-8))   if c(i, j) > 0
        PPMI[i, j] = 0                                                  otherwise

    Args:
        walks: iterable of walks; each walk is a sequence of integer node ids in
            ``[0, num_nodes)``.
        window_size: maximum positional distance that counts as a co-occurrence.
        num_nodes: side length of the returned matrix.

    Returns:
        ``torch.float32`` tensor of shape ``(num_nodes, num_nodes)``, on CPU, with no NaNs.
        Rows of nodes that never appear in a walk are all zero.
    """
    counts = np.zeros((num_nodes, num_nodes))
    for walk in walks:
        walk_len = len(walk)
        for i, node in enumerate(walk):
            lo = max(i - window_size, 0)
            hi = min(i + window_size + 1, walk_len)
            for j in range(lo, hi):
                if j != i:
                    counts[node, walk[j]] += 1

    row_sums = counts.sum(axis=1)
    total = row_sums.sum()
    ppmi = np.zeros_like(counts)

    # Evaluate PMI only where a pair was actually observed. Elsewhere the clipped value is 0
    # anyway (log(1e-8) < 0). For nodes absent from every walk the full formula is 0/0 = NaN,
    # which a `< 0` clip does not remove. Because the window is symmetric, c(i, j) > 0 implies
    # both c(i) > 0 and c(j) > 0, so the denominator below is never zero.
    rows, cols = np.nonzero(counts)
    if rows.size:
        ratio = counts[rows, cols] * total / (row_sums[rows] * row_sums[cols])
        ppmi[rows, cols] = np.maximum(np.log(ratio + 1e-8), 0.0)

    return torch.tensor(ppmi, dtype=torch.float32)

def _bcd_step_batch(Yt, Ut, Wp, Wn, gamma, llambda, tau):
    UtU = Ut.T @ Ut
    r = UtU.shape[0]
    A = UtU + (gamma + llambda + 2 * tau) * torch.eye(r, device=Ut.device)
    B = Yt @ Ut.T + gamma * Ut.T + tau * (Wp + Wn).T

    return torch.linalg.solve(A, B).T

def construct_graph(batch, directed=True, weight=1.0):
    """Build a networkx graph from the ``src``/``dst`` edge tensors of ``batch``.

    Args:
        batch: object exposing ``src`` and ``dst`` tensors of equal length, for example a
            ``TemporalData`` batch.
        directed: build an ``nx.DiGraph`` (default, the previous behaviour) or an
            undirected ``nx.Graph``.
        weight: value stored as the ``'weight'`` attribute of every edge.
            NOTE(review): the reference node2vec ``Graph`` reads ``G[u][v]['weight']`` when it
            builds alias tables. Confirm that node2vec_walks.py does too; it would raise
            ``KeyError`` on unweighted edges.

    Returns:
        networkx graph whose nodes are plain Python ints. Repeated (src, dst) pairs collapse
        into one edge.
    """
    import networkx as nx

    G = nx.DiGraph() if directed else nx.Graph()
    # tolist() converts the whole tensor once, with no per-element numpy scalars.
    G.add_edges_from(zip(batch.src.tolist(), batch.dst.tolist()), weight=weight)
    return G

# Initialize embeddings
def initialize_parameters(num_nodes, embedding_dim, device=None):
    """Draw standard-normal matrices U and W of shape ``(num_nodes, embedding_dim)``.

    Args:
        num_nodes: number of rows.
        embedding_dim: number of columns.
        device: target device. When omitted, the module-level ``device`` is used if it
            exists (the previous implicit behaviour). Otherwise CUDA is used if available,
            else CPU.

    Returns:
        ``(U, W)``: two independent random tensors that do not require gradients.

    NOTE(review): ``initialize_parameters_from_ppmi`` produces the transposed layout
    ``(embedding_dim, num_nodes)``. Confirm which layout callers of this helper expect.
    """
    if device is None:
        device = globals().get("device")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    U = torch.randn(num_nodes, embedding_dim, device=device, requires_grad=False)
    W = torch.randn(num_nodes, embedding_dim, device=device, requires_grad=False)
    return U, W

def initialize_parameters_from_ppmi(PPMI, embedding_dim):
    """Initialise U and W from the highest-variance rows of the row-normalised PPMI.

    Each row is L1-normalised and the rows are ranked by their (unbiased) variance. The top
    ``embedding_dim`` rows are copied into both U and W, so the two start out identical.

    Args:
        PPMI: 2-D tensor; ``(num_nodes, num_nodes)`` when produced by ``compute_ppmi``.
        embedding_dim: number of rows to keep.

    Returns:
        ``(U, W)``: independent detached copies of shape
        ``(min(embedding_dim, PPMI.shape[0]), PPMI.shape[1])``.
    """
    # NaN entries would make a row's variance NaN. torch's sort treats NaN as larger than
    # any number, so such rows would be picked first. Treat them as zero instead.
    PPMI = torch.where(torch.isnan(PPMI), torch.zeros_like(PPMI), PPMI)

    # Normalize PPMI rows using the L1 norm. All-zero rows (nodes absent from the walks)
    # would otherwise be 0/0 = NaN; dividing them by 1 keeps them at zero.
    row_sums = PPMI.abs().sum(dim=1, keepdim=True)
    row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums)
    PPMI_normalized = PPMI / row_sums

    # Rank rows by variance and keep the top `embedding_dim`.
    row_variances = PPMI_normalized.var(dim=1)
    top_indices = torch.argsort(row_variances, descending=True)[:embedding_dim]

    U = PPMI_normalized[top_indices, :].clone().detach()
    W = PPMI_normalized[top_indices, :].clone().detach()
    return U, W

def save_checkpoint(U, W, filename):
    """Write U and W to ``filename`` with ``torch.save`` as a dict keyed ``"U"``/``"W"``, then log the path."""
    payload = dict(U=U, W=W)
    torch.save(payload, filename)
    print(f"Checkpoint saved: {filename}")

def load_checkpoint(filename, device="cpu"):
    """Load ``(U, W)`` saved by ``save_checkpoint``, mapping tensors onto ``device``.

    Returns ``(None, None)`` and logs a message when ``filename`` does not exist.
    """
    if not os.path.exists(filename):
        print("No checkpoint found, starting from scratch.")
        return None, None

    state = torch.load(filename, map_location=device)
    print(f"Checkpoint loaded: {filename}")
    return state["U"], state["W"]

# Training function with BCD
def train_with_bcd(train_loader, PPMI_All, num_nodes, walk_length, window_size, lambda_reg, tau_reg, gamma_reg, U, W, device, save_every=5,
                   checkpoint_path="/content/drive/MyDrive/TGB-TEST/bcd_checkpoint.pth", p=0.1, q=0.1):
    """Run one block-coordinate-descent (BCD) pass over ``train_loader``.

    For every batch:
      1. build a graph from the batch edges;
      2. sample walks and turn them into a ``(num_nodes, num_nodes)`` PPMI matrix;
      3. update W with U fixed, then U with W fixed (``_bcd_step_batch``);
      4. report ``||PPMI - U.T @ W||_F``.

    U and W must have ``num_nodes`` columns, for example the ``(embedding_dim, num_nodes)``
    matrices returned by ``initialize_parameters_from_ppmi``. That is the layout for which
    ``PPMI @ U.T`` inside ``_bcd_step_batch`` is defined.

    Args:
        train_loader: sized iterable of batches with ``.to(device)``, ``src`` and ``dst``
            (e.g. a ``TemporalDataLoader``).
        PPMI_All: unused; kept for call compatibility.
        num_nodes: side of the per-batch PPMI matrix.
        walk_length, window_size: walk and co-occurrence window parameters.
        lambda_reg, tau_reg, gamma_reg: regularisation weights passed to ``_bcd_step_batch``.
        U, W: initial factors. They are replaced by the checkpoint if one exists at
            ``checkpoint_path``.
        device: device for the batches, PPMI and factors.
        save_every: save a checkpoint every this many batches (and once at the end).
        checkpoint_path: checkpoint file (defaults to the previously hard-coded path).
        p, q: node2vec walk parameters (default 0.1, the previously hard-coded values).

    Returns:
        ``(mean_loss, U)``. ``mean_loss`` is NaN when the loader is empty. W is only
        persisted via the checkpoint file.
    """
    # NOTE(review): a checkpoint left by an earlier call/run silently overrides U and W.
    U_loaded, W_loaded = load_checkpoint(checkpoint_path, device=device)
    if U_loaded is not None and W_loaded is not None:
        U, W = U_loaded.to(device), W_loaded.to(device)

    num_batches = len(train_loader)
    total_loss = 0.0
    for batch_idx, batch in enumerate(train_loader):
        batch = batch.to(device)

        # Build graph for the current batch, sample walks and compute its PPMI.
        graph = construct_graph(batch)
        walks = generate_walks(graph, walk_length, p, q)
        # compute_ppmi returns a CPU tensor, but U and W live on `device`.
        PPMI = compute_ppmi(walks, window_size, num_nodes).to(device)

        # Zero-padded one-row shifts of W and U, used as the `tau` neighbour terms.
        # NOTE(review): in dynamic-embedding BCD these are usually the previous/next time
        # step's embeddings; here they shift along the first axis. Confirm this is intended.
        Wp = torch.cat([torch.zeros(1, W.shape[1], device=device), W[:-1]])
        Wn = torch.cat([W[1:], torch.zeros(1, W.shape[1], device=device)])
        Up = torch.cat([torch.zeros(1, U.shape[1], device=device), U[:-1]])
        Un = torch.cat([U[1:], torch.zeros(1, U.shape[1], device=device)])

        W = _bcd_step_batch(PPMI, U, Wp, Wn, gamma_reg, lambda_reg, tau_reg)
        U = _bcd_step_batch(PPMI, W, Up, Un, gamma_reg, lambda_reg, tau_reg)

        # U and W have num_nodes columns, so U.T @ W is the (num_nodes, num_nodes)
        # reconstruction. U @ W.T would be embedding-sized and cannot be subtracted from PPMI.
        loss = torch.norm(PPMI - U.T @ W, p='fro').item()  # Frobenius norm
        total_loss += loss
        print('---------------bcd-------------------')
        print(f"Batch {batch_idx + 1}/{num_batches} - Loss: {loss}")

        if (batch_idx + 1) % save_every == 0:
            save_checkpoint(U, W, checkpoint_path)

    save_checkpoint(U, W, checkpoint_path)
    mean_loss = total_loss / num_batches if num_batches else float("nan")
    return mean_loss, U

def train(embeddings, train_loader, device, min_dst_idx=None, max_dst_idx=None):
    """Fit a logistic-regression link classifier on concatenated node embeddings.

    Examples are built per batch of ``train_loader``:
    - positives: each edge ``(src, dst)`` gives the feature ``[emb[src], emb[dst]]`` with label 1;
    - negatives: ``src`` is paired with a destination drawn uniformly from
      ``[min_dst_idx, max_dst_idx]``, with label 0.

    Args:
        embeddings: node-indexed embedding matrix, where row ``v`` is node ``v``'s vector.
            It can be a tensor (any device) or anything ``np.asarray`` turns into a 2-D array.
        train_loader: iterable of batches with ``.to(device)``, ``src`` and ``dst``.
        device: device used for the batches and the negative sampling.
        min_dst_idx, max_dst_idx: inclusive range for negative destinations. When omitted,
            the module-level values of the same name are used (the previous implicit behaviour).

    Returns:
        The fitted ``sklearn.linear_model.LogisticRegression``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    if min_dst_idx is None:
        min_dst_idx = globals()["min_dst_idx"]
    if max_dst_idx is None:
        max_dst_idx = globals()["max_dst_idx"]

    # Convert once. NumPy cannot concatenate CUDA tensors, and per-edge .item() indexing is slow.
    if torch.is_tensor(embeddings):
        emb = embeddings.detach().cpu().numpy()
    else:
        emb = np.asarray(embeddings)

    X_parts, y_parts = [], []
    for batch in train_loader:
        batch = batch.to(device)
        src = batch.src

        neg_dst = torch.randint(
            min_dst_idx,
            max_dst_idx + 1,
            (src.size(0),),
            dtype=torch.long,
            device=device,
        )

        src_emb = emb[src.cpu().numpy()]
        pos_emb = emb[batch.dst.cpu().numpy()]
        neg_emb = emb[neg_dst.cpu().numpy()]

        # Row order per batch: all positives, then all negatives.
        X_parts.append(np.concatenate([src_emb, pos_emb], axis=1))
        y_parts.append(np.ones(len(src_emb), dtype=int))
        X_parts.append(np.concatenate([src_emb, neg_emb], axis=1))
        y_parts.append(np.zeros(len(src_emb), dtype=int))

    if not X_parts:
        raise ValueError("train_loader yielded no batches; cannot fit the link classifier.")

    X = np.concatenate(X_parts)
    y = np.concatenate(y_parts)

    model = LogisticRegression()
    model.fit(X, y)

    print('---------------- Logistic Regression Trained ------------------')
    return model

def val_test(model, embeddings, loader, neg_sampler, evaluator, split_mode, metric=None):
    """Evaluate one-vs-many link ranking with a fitted pairwise classifier.

    For every positive edge in ``loader``:
    1. ``neg_sampler.query_batch`` supplies that edge's negative destinations.
    2. The true destination and the negatives are scored by ``model.predict_proba`` on
       ``[emb[src], emb[candidate]]``, and the column-1 probability is used.
    3. The scores go to ``evaluator.eval``.

    Args:
        model: classifier exposing ``predict_proba``.
        embeddings: node-indexed embedding matrix, either a tensor (any device) or anything
            ``np.asarray`` turns into a 2-D array.
        loader: iterable of batches exposing ``src``, ``dst`` and ``t``.
        neg_sampler: object with ``query_batch(src, dst, t, split_mode=...)`` returning one
            collection of negative destination ids per positive edge.
        evaluator: object with ``eval(input_dict)`` returning a dict keyed by metric name.
        split_mode: forwarded to ``neg_sampler.query_batch`` (e.g. ``"val"``/``"test"``).
        metric: metric name. When omitted, the module-level ``metric`` is used (the previous
            implicit behaviour).

    Returns:
        Mean of the per-edge metric as a float; NaN when nothing was evaluated.
    """
    if metric is None:
        metric = globals()["metric"]

    if torch.is_tensor(embeddings):
        emb = embeddings.detach().cpu().numpy()
    else:
        emb = np.asarray(embeddings)

    perf_list = []
    for pos_batch in loader:
        pos_src, pos_dst, pos_t = pos_batch.src, pos_batch.dst, pos_batch.t

        neg_batch_list = neg_sampler.query_batch(pos_src, pos_dst, pos_t, split_mode=split_mode)

        src_ids = pos_src.cpu().numpy()
        dst_ids = pos_dst.cpu().numpy()
        for idx, neg_dst_nodes in enumerate(neg_batch_list):
            # Candidate 0 is the true destination; the rest are negatives. Scoring them in one
            # predict_proba call gives the same per-row probabilities as separate calls.
            candidates = np.concatenate(([dst_ids[idx]], np.asarray(neg_dst_nodes, dtype=np.int64)))
            cand_emb = emb[candidates]
            src_emb = np.broadcast_to(emb[src_ids[idx]], cand_emb.shape)
            scores = model.predict_proba(np.hstack([src_emb, cand_emb]))[:, 1]

            input_dict = {
                "y_pred_pos": scores[:1],
                "y_pred_neg": scores[1:],
                "eval_metric": [metric],
            }
            perf_list.append(evaluator.eval(input_dict)[metric])
        print('-------------val/test---------------------')

    return float(np.mean(perf_list)) if perf_list else float("nan")

# Main script
if __name__ == "__main__":
    # ---- Hyper-parameters ----
    DATA = "tgbl-wiki"
    BATCH_SIZE = 200
    embedding_dim = 20
    teta = 1  # NOTE(review): currently unused.
    window_size = 5
    walk_length = 3
    lambda_reg = 0.1
    tau_reg = 0.1
    gamma_reg = 0.1
    num_epochs = 1
    p = 0.1
    q = 0.1

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ---- Data ----
    # `dataset` is the PyGLinkPropPredDataset created in the setup code near the top of the file.
    train_mask = dataset.train_mask
    val_mask = dataset.val_mask
    test_mask = dataset.test_mask
    data = dataset.get_TemporalData().to(device)
    num_nodes = dataset.num_nodes

    # Only sample actual destination nodes as negatives. `train` reads these names as
    # module-level globals, so they must stay at this scope.
    min_dst_idx, max_dst_idx = int(data.dst.min()), int(data.dst.max())

    train_data = data[train_mask]
    val_data = data[val_mask]
    test_data = data[test_mask]

    train_loader = TemporalDataLoader(train_data, batch_size=BATCH_SIZE)
    val_loader = TemporalDataLoader(val_data, batch_size=BATCH_SIZE)
    test_loader = TemporalDataLoader(test_data, batch_size=BATCH_SIZE)

    metric = dataset.eval_metric  # also read as a module-level global by `val_test`
    evaluator = Evaluator(name=DATA)
    neg_sampler = dataset.negative_sampler

    # ---- Per-snapshot embedding training ----
    # NOTE(review): construct_graph and TemporalDataLoader below need each snapshot to expose
    # `src`/`dst` like a TemporalData. Confirm that split_by_time returns that for this input.
    train_list = split_by_time(train_data)
    num_snapshots = len(train_list)
    Embeddings_list = []  # detached U of each finished snapshot, in time order

    for t, train_data_t in enumerate(train_list):
        print(f"Processing snapshot {t + 1}/{num_snapshots}...")
        graph = construct_graph(train_data_t)
        walks = generate_walks(graph, walk_length, p, q)
        PPMI = compute_ppmi(walks, window_size, num_nodes).to(device)

        # Initialize embeddings from the snapshot's PPMI.
        U, W = initialize_parameters_from_ppmi(PPMI, embedding_dim)

        # train_with_bcd iterates mini-batches and calls len() on its input, so give it a
        # loader over the snapshot rather than the snapshot itself.
        snapshot_loader = TemporalDataLoader(train_data_t, batch_size=BATCH_SIZE)
        for epoch in range(num_epochs):
            loss, U = train_with_bcd(snapshot_loader, PPMI, num_nodes, walk_length, window_size, lambda_reg, tau_reg, gamma_reg, U, W, device)
            print(f"Snapshot {t+1}, Epoch {epoch + 1}/{num_epochs}, Loss: {loss:.4f}")

        # Temporal-smoothness diagnostic against the *previous* snapshot's U. BCD updates are
        # closed-form (no autograd graph, no optimizer), so this term is reported, not optimized.
        if Embeddings_list:
            U_prev = Embeddings_list[-1]
            kl = torch.nn.functional.kl_div(
                torch.nn.functional.log_softmax(U_prev, dim=1),
                torch.nn.functional.softmax(U, dim=1),
                reduction="batchmean",
            )
            print(f"Snapshot {t+1}, KL(U_t || U_t-1): {kl.item():.4f}")

        Embeddings_list.append(U.detach().clone())

    if not Embeddings_list:
        raise RuntimeError("split_by_time produced no training snapshots; nothing to evaluate.")

    # U is (embedding_dim, num_nodes), as produced by initialize_parameters_from_ppmi and kept
    # by _bcd_step_batch. The link classifier indexes embeddings by node id, so pass it the
    # node-major (num_nodes, embedding_dim) matrix of the last snapshot.
    node_embeddings = Embeddings_list[-1].T.cpu().numpy()

    model = train(node_embeddings, train_loader, device)

    # validation
    # loading the validation negative samples
    dataset.load_val_ns()
    perf_metric_val = val_test(model, node_embeddings, val_loader, neg_sampler, evaluator, split_mode="val")
    print(f"\tValidation {metric}: {perf_metric_val: .4f}")

    # test
    # loading the test negative samples
    dataset.load_test_ns()
    perf_metric_test = val_test(model, node_embeddings, test_loader, neg_sampler, evaluator, split_mode="test")
    print(f"INFO: Test: Evaluation Setting: >>> ONE-VS-MANY <<< ")
    print(f"\tTest: {metric}: {perf_metric_test: .4f}")

