### A Demo About Temporal Walk Matrix Maintaining
In this demo, we show how to explicitly and implicitly maintain different temporal walk matrices, both at the single-interaction level and at the batch-interaction level. 

An element of the $k$-hop temporal walk matrix is defined as
$$
A_{u,v}^{(k)}(t) = \sum_{W\in M_{u,v}^{k}(t)} s(W),
$$
where $M_{u,v}^k(t)$ is the set of all k-step temporal walks from u to v, and $s(\cdot)$ is the score function.

Denoting a temporal walk as $W=[(w_0,t_0),(w_1,t_1),...,(w_k,t_k)]$ and current time as $t$, we consider the following two types of temporal walk matrices.
- Sum Matrix: its score function is $s(W)=\prod_{i=0}^{k} \text{exp}(-\lambda(t-t_i))$.
- Norm Matrix: its score function is $s(W)=\prod_{i=0}^{k-1} \frac{\text{exp}(-\lambda(t_{i}-t_{i+1}))}{\sum_{(\{w',w\},t')\in \mathcal{E}_{w_i,t_i}} \text{exp}(-\lambda(t_i-t'))}$.

The sum matrix corresponds to the matrix of TPNet, and the norm matrix corresponds to the matrix of CAWN.

#### Basic Utils
The basic utilities include a function to generate random temporal graphs and a function to compute the temporal walk matrices by brute force.

In [4]:
import numpy as np
from functools import reduce
import sys
import math
import torch

def generate_graph(node_num, edge_num):
    """
    Sample a random temporal graph.

    Each of the ``edge_num`` interactions picks a source node uniformly from
    ``range(node_num)``, a destination uniformly from the remaining nodes
    (so there are no self-loops), and a timestamp that exceeds the previous
    one by a random integer drawn with ``np.random.randint(1, 5)``.

    Returns three numpy arrays: source ids, destination ids and timestamps.
    """
    candidates = list(range(node_num))
    edges = []
    clock = 0
    for _ in range(edge_num):
        src = np.random.choice(candidates)
        # exclude the source itself so that the edge is never a self-loop
        dst = np.random.choice(candidates[:src] + candidates[src + 1:])
        # timestamps are strictly increasing along the edge list
        clock = clock + np.random.randint(1, 5)
        edges.append((src, dst, clock))

    src_node_ids = [edge[0] for edge in edges]
    dst_node_ids = [edge[1] for edge in edges]
    node_interact_times = [edge[2] for edge in edges]
    return np.array(src_node_ids), np.array(dst_node_ids), np.array(node_interact_times)

def get_matrix_by_brute_force(src_node_ids, dst_node_ids, interact_times, matrix_type,num_layer,lam,node_num):
    """
    Given a temporal graph G(t), generate the temporal walk matrices at t+1 by brute force.

    Every temporal walk of at most ``num_layer`` steps is enumerated with a
    depth-first search that starts from each node at time ``t+1`` and only
    follows interactions whose timestamp is strictly smaller than the time at
    which the current node was reached.

    :param src_node_ids: source node id of every interaction
    :param dst_node_ids: destination node id of every interaction
    :param interact_times: timestamp of every interaction; the per-node time
        lists are binary-searched, so interactions are expected to be listed in
        non-decreasing time order
    :param matrix_type: 'sum' or 'norm', selects the score function of a walk
    :param num_layer: maximum number of hops
    :param lam: time decay coefficient
    :param node_num: number of nodes
    :return: list of ``num_layer + 1`` tensors of shape (node_num, node_num);
        entry ``k`` is the k-hop temporal walk matrix
    :raises ValueError: if ``matrix_type`` is neither 'sum' nor 'norm'
    """
    # validate eagerly instead of in the middle of the search
    if matrix_type not in ('norm', 'sum'):
        raise ValueError("Not Implemented Matrix Type")

    # undirected adjacency lists; neighbours and times are aligned by position
    adj_node = [[] for _ in range(node_num)]
    adj_time = [[] for _ in range(node_num)]
    for u, v, t in zip(src_node_ids, dst_node_ids, interact_times):
        adj_node[u].append(v)
        adj_time[u].append(t)
        adj_node[v].append(u)
        adj_time[v].append(t)

    matrices = [torch.zeros((node_num, node_num)) for _ in range(num_layer + 1)]
    # an empty graph has no walks besides the 0-hop ones, so any time works
    last_time = interact_times[-1] + 1 if len(interact_times) > 0 else 1

    def dfs(start_node, now_node, now_time, hop, score):
        # the walk that ends here contributes its score to the hop-th matrix
        matrices[hop][start_node, now_node] += score
        if hop >= num_layer:
            return
        times = adj_time[now_node]
        # only interactions strictly earlier than now_time can extend the walk
        pos = np.searchsorted(times, now_time, 'left')
        if pos == 0:
            return
        if matrix_type == 'norm':
            # the normaliser is only needed by the 'norm' score function
            normalize_weight = np.sum(np.exp(-lam * (now_time - np.array(times[:pos]))))
        for next_node, next_time in zip(adj_node[now_node][:pos], times[:pos]):
            if matrix_type == 'norm':
                weight = np.exp(-lam * (now_time - next_time)) / normalize_weight
            else:
                weight = np.exp(-lam * (last_time - next_time))
            dfs(start_node, next_node, next_time, hop + 1, score * weight)

    for i in range(node_num):
        dfs(i, i, last_time, 0, 1)

    return matrices

#### Matrix Updating Function for Single Interaction Updating
Given a temporal graph $G(t)$, if different interactions have different timestamps, the following two functions compute the corresponding temporal walk matrices at $t+1$ by updating the matrices incrementally, one interaction at a time.

In [5]:
def get_norm_matrix(src_node_ids, dst_node_ids, interact_times, num_layer, node_num, lam, use_projection, dimension):
    """
    Incrementally build the norm temporal walk matrices of G(t) at time t+1,
    processing one interaction at a time.

    With ``use_projection`` the matrices are maintained implicitly as
    (node_num, dimension) random projections and are turned back into
    (node_num, node_num) estimates by an inner product with the 0-hop
    projection; otherwise the (node_num, node_num) matrices are kept explicitly.

    Returns a list of ``num_layer + 1`` matrices, one per hop.
    """
    if use_projection:
        base = torch.normal(0, 1 / math.sqrt(dimension), (node_num, dimension))
        width = dimension
    else:
        base = torch.eye(node_num)
        width = node_num
    matrices = [base] + [torch.zeros((node_num, width)) for _ in range(num_layer)]

    # time-decayed number of interactions seen so far by every node
    degree = torch.zeros(node_num)
    last_seen = 0
    for u, v, t in zip(src_node_ids, dst_node_ids, interact_times):
        # advance the clock to t: all previous interactions fade
        degree = degree * np.exp(-lam * (t - last_seen))
        # go from the highest hop down so that hop-1 still holds its old rows
        for hop in reversed(range(1, num_layer + 1)):
            upper, lower = matrices[hop], matrices[hop - 1]
            upper[u] = (upper[u] * degree[u] + lower[v]) / (degree[u] + 1)
            upper[v] = (upper[v] * degree[v] + lower[u]) / (degree[v] + 1)
        # the new interaction enters both endpoints with weight one
        degree[u] = degree[u] + 1
        degree[v] = degree[v] + 1
        last_seen = t

    # estimate the matrices by inner product with the 0-hop projection
    if use_projection:
        anchor = matrices[0]
        matrices = [matrix @ anchor.T for matrix in matrices[:num_layer + 1]]
    return matrices

def get_sum_matrix(src_node_ids, dst_node_ids, interact_times, num_layer, node_num, lam, use_projection, dimension):
    """
    Incrementally build the sum temporal walk matrices of G(t) at time t+1,
    processing one interaction at a time.

    With ``use_projection`` the matrices are maintained implicitly as
    (node_num, dimension) random projections and are turned back into
    (node_num, node_num) estimates by an inner product with the 0-hop
    projection; otherwise the (node_num, node_num) matrices are kept explicitly.

    Returns a list of ``num_layer + 1`` matrices, one per hop.
    """
    if use_projection:
        base = torch.normal(0, 1 / math.sqrt(dimension), (node_num, dimension))
        width = dimension
    else:
        base = torch.eye(node_num)
        width = node_num
    matrices = [base] + [torch.zeros((node_num, width)) for _ in range(num_layer)]

    # hops are always visited from the highest one down to 1
    hops = range(num_layer, 0, -1)

    current_time = 0
    for u, v, t in zip(src_node_ids, dst_node_ids, interact_times):
        # advance the clock to t: a j-hop walk score holds j decaying factors
        step_decay = np.exp(-lam * (t - current_time))
        for hop in hops:
            matrices[hop] = matrices[hop] * np.power(step_decay, hop)
        # add the interaction; hop-1 rows are still untouched when they are read
        for hop in hops:
            upper, lower = matrices[hop], matrices[hop - 1]
            upper[u] = upper[u] + lower[v]
            upper[v] = upper[v] + lower[u]
        current_time = t

    # move the clock one unit past the last interaction
    unit_decay = np.exp(-lam * 1)
    for hop in hops:
        matrices[hop] = matrices[hop] * np.power(unit_decay, hop)

    # estimate the matrices by inner product with the 0-hop projection
    if use_projection:
        anchor = matrices[0]
        matrices = [matrix @ anchor.T for matrix in matrices[:num_layer + 1]]
    return matrices

#### Matrix Updating Function for Batch Interaction Updating
Given a temporal graph $G(t)$, the following two functions compute the corresponding temporal walk matrices at $t+1$ by updating the matrices incrementally, one batch of interactions at a time, if the following two conditions are satisfied
- The timestamps of interactions in the previous batch are smaller than those in the current batch
- Only using the interactions in current batch will not produce a temporal walk of length larger than 1



In [6]:
def get_norm_matrix_by_batch_updating(src_node_ids, dst_node_ids, interact_times, num_layer, node_num, lam,
                                      use_projection, dimension, batch_size, device):
    """
    Given a temporal graph G(t), generate the norm temporal walk matrices at t+1,
    consuming ``batch_size`` interactions per update step.

    The result matches the single-interaction update only if the timestamps of
    an earlier batch are smaller than those of a later batch and the
    interactions inside one batch cannot form a temporal walk longer than 1.

    :param src_node_ids: source node id of every interaction (any integer dtype)
    :param dst_node_ids: destination node id of every interaction (any integer dtype)
    :param interact_times: timestamp of every interaction
    :param num_layer: maximum number of hops
    :param node_num: number of nodes
    :param lam: time decay coefficient
    :param use_projection: maintain (node_num, dimension) random projections
        instead of the explicit (node_num, node_num) matrices
    :param dimension: projection dimension, only read when ``use_projection``
    :param batch_size: number of interactions per update step
    :param device: torch device on which the matrices are maintained
    :return: list of ``num_layer + 1`` CPU tensors of shape (node_num, node_num)
    """
    if use_projection:
        matrices = [torch.normal(0, 1 / math.sqrt(dimension), (node_num, dimension)).to(device)]
        matrices = matrices + [torch.zeros(node_num, dimension).to(device) for _ in range(num_layer)]
    else:
        matrices = [torch.eye(node_num).to(device)] + [torch.zeros((node_num, node_num)).to(device) for _ in
                                                       range(num_layer)]

    degree = torch.zeros(node_num).to(device)
    previous_time = 0
    for start in range(0, len(src_node_ids), batch_size):
        stop = min(start + batch_size, len(src_node_ids))
        batch_src_node_ids = src_node_ids[start:stop]
        batch_dst_node_ids = dst_node_ids[start:stop]
        batch_node_interact_times = interact_times[start:stop]
        next_time = batch_node_interact_times[-1]
        # move current timestamp to next_time
        degree = degree * np.exp(-lam * (next_time - previous_time))

        # every interaction updates both endpoints: (target <- source)
        concat_target_nodes = np.concatenate([batch_src_node_ids, batch_dst_node_ids])
        concat_source_nodes = np.concatenate([batch_dst_node_ids, batch_src_node_ids])
        # scatter_add_ only accepts int64 indices, whatever dtype the ids came in;
        # the index tensors are built once per batch and shared by all layers
        target_index = torch.from_numpy(concat_target_nodes.astype(np.int64)).to(device)
        source_index = torch.from_numpy(concat_source_nodes.astype(np.int64)).to(device)
        row_index = target_index[:, None].expand(-1, matrices[0].shape[1])

        link_weight = np.exp(-lam * (next_time - np.tile(batch_node_interact_times, 2)))
        link_weight = torch.from_numpy(link_weight).to(device=device, dtype=torch.float32)
        delta_degree = torch.zeros_like(degree)
        delta_degree.scatter_add_(dim=0, src=link_weight, index=target_index)
        # normalise by the degree the target has after the whole batch is added
        link_weight = link_weight / (degree[target_index] + delta_degree[target_index])

        # highest hop first so that hop j-1 is read before it is updated
        for j in range(num_layer, 0, -1):
            # adding w * (new - old) to a row equals the weighted average update
            message = link_weight[:, None] * (-matrices[j][target_index] + matrices[j - 1][source_index])
            matrices[j].scatter_add_(dim=0, src=message, index=row_index)
        degree = degree + delta_degree
        previous_time = next_time

    # estimate the matrix by inner product
    if use_projection:
        matrices = [matrices[i] @ matrices[0].T for i in range(num_layer + 1)]
    # move matrices to cpu
    matrices = [matrices[i].cpu() for i in range(num_layer + 1)]
    return matrices


def get_sum_matrix_by_batch_updating(src_node_ids, dst_node_ids, interact_times, num_layer, node_num, lam,
                                     use_projection, dimension, batch_size, device):
    """
    Given a temporal graph G(t), generate the sum temporal walk matrices at t+1,
    consuming ``batch_size`` interactions per update step.

    The result matches the single-interaction update only if the timestamps of
    an earlier batch are smaller than those of a later batch and the
    interactions inside one batch cannot form a temporal walk longer than 1.

    :param src_node_ids: source node id of every interaction (any integer dtype)
    :param dst_node_ids: destination node id of every interaction (any integer dtype)
    :param interact_times: timestamp of every interaction
    :param num_layer: maximum number of hops
    :param node_num: number of nodes
    :param lam: time decay coefficient
    :param use_projection: maintain (node_num, dimension) random projections
        instead of the explicit (node_num, node_num) matrices
    :param dimension: projection dimension, only read when ``use_projection``
    :param batch_size: number of interactions per update step
    :param device: torch device on which the matrices are maintained
    :return: list of ``num_layer + 1`` CPU tensors of shape (node_num, node_num)
    """
    if use_projection:
        matrices = [torch.normal(0, 1 / math.sqrt(dimension), (node_num, dimension)).to(device)]
        matrices = matrices + [torch.zeros(node_num, dimension).to(device) for _ in range(num_layer)]
    else:
        matrices = [torch.eye(node_num).to(device)] + [torch.zeros((node_num, node_num)).to(device) for _ in
                                                       range(num_layer)]

    previous_time = 0
    for start in range(0, len(src_node_ids), batch_size):
        stop = min(start + batch_size, len(src_node_ids))
        batch_src_node_ids = src_node_ids[start:stop]
        batch_dst_node_ids = dst_node_ids[start:stop]
        batch_node_interact_times = interact_times[start:stop]
        next_time = batch_node_interact_times[-1]
        # move current timestamp to next_time: a j-hop score holds j decaying factors
        time_decay = np.exp(-lam * (next_time - previous_time))
        for j in range(num_layer, 0, -1):
            matrices[j] = matrices[j] * np.power(time_decay, j)

        # every interaction updates both endpoints: (target <- source)
        concat_target_nodes = np.concatenate([batch_src_node_ids, batch_dst_node_ids])
        concat_source_nodes = np.concatenate([batch_dst_node_ids, batch_src_node_ids])
        # scatter_add_ only accepts int64 indices, whatever dtype the ids came in;
        # the index tensors are built once per batch and shared by all layers
        target_index = torch.from_numpy(concat_target_nodes.astype(np.int64)).to(device)
        source_index = torch.from_numpy(concat_source_nodes.astype(np.int64)).to(device)
        row_index = target_index[:, None].expand(-1, matrices[0].shape[1])

        link_weight = np.exp(-lam * (next_time - np.tile(batch_node_interact_times, 2)))
        link_weight = torch.from_numpy(link_weight).to(device=device, dtype=torch.float32)
        # highest hop first so that hop j-1 is read before it is updated
        for j in range(num_layer, 0, -1):
            matrices[j].scatter_add_(dim=0, src=matrices[j - 1][source_index] * link_weight[:, None],
                                     index=row_index)
        previous_time = next_time
    # move time to previous_time + 1
    for j in range(num_layer, 0, -1):
        matrices[j] = matrices[j] * np.power(np.exp(-lam * 1), j)

    # estimate the matrix by inner product
    if use_projection:
        matrices = [matrices[i] @ matrices[0].T for i in range(num_layer + 1)]
    # move matrices to cpu
    matrices = [matrices[i].cpu() for i in range(num_layer + 1)]
    return matrices

#### Unify All Things Together
In this part, we show that the incremental updating mechanism generates the same temporal walk matrices as the brute-force method.

We first show the correctness of the updating mechanism for single interactions.

In [7]:
# hyperparameters shared by every method in this cell
node_num, edge_num = 100, 500
lam = 0.0001
num_layer = 3
dimension = 50

# draw a random temporal graph
src_node_ids, dst_node_ids, node_interact_times = generate_graph(node_num=node_num, edge_num=edge_num)

# keyword arguments that are identical for all matrix builders
graph_kwargs = dict(src_node_ids=src_node_ids, dst_node_ids=dst_node_ids, interact_times=node_interact_times,
                    lam=lam, num_layer=num_layer, node_num=node_num)

# ground truth by exhaustive walk enumeration
sum_matrices_by_brute_force = get_matrix_by_brute_force(matrix_type='sum', **graph_kwargs)
norm_matrices_by_brute_force = get_matrix_by_brute_force(matrix_type='norm', **graph_kwargs)
# explicit matrices maintained one interaction at a time
sum_matrices_by_single_update = get_sum_matrix(use_projection=False, dimension=0, **graph_kwargs)
norm_matrices_by_single_update = get_norm_matrix(use_projection=False, dimension=0, **graph_kwargs)

# the explicit incremental matrices must agree with the ground truth at every hop
for i in range(num_layer + 1):
    sum_agrees = torch.allclose(sum_matrices_by_single_update[i], sum_matrices_by_brute_force[i],
                                rtol=1e-5, atol=1e-5)
    assert sum_agrees, f"{i}\n{sum_matrices_by_brute_force[i]}\n{sum_matrices_by_single_update[i]}"
    norm_agrees = torch.allclose(norm_matrices_by_single_update[i], norm_matrices_by_brute_force[i],
                                 rtol=1e-5, atol=1e-5)
    assert norm_agrees, f"{i}\n{norm_matrices_by_brute_force[i]}\n{norm_matrices_by_single_update[i]}"

# implicitly maintain the temporal walk matrices by random projections
projected_sum_matrices_by_single_update = get_sum_matrix(use_projection=True, dimension=dimension, **graph_kwargs)
projected_norm_matrices_by_single_update = get_norm_matrix(use_projection=True, dimension=dimension, **graph_kwargs)

def get_error_ratio(estimated_matrix,ground_truth_matrix):
    r"""
    Mean relative error of an estimated temporal walk matrix.

    The estimated entry A'_{u,v}^{(k)} is calculated as <h_u^{(k)},h_v^{(0)}>
    in the functions above, where h_u^{(k)} is the projection of A_u^{(k)}.
    This function computes
    $\frac{|<h_u^{(k)},h_v^{(0)}>-<A_u^{(k)},A_v^{(0)}>|}{0.5*(||A_u^{(k)}||_2^2+||A_v^{(0)}||_2^2)}$
    for every entry, using 1 for the $||A_v^{(0)}||_2^2$ term, and returns the
    mean over all entries. The ratio corresponds to $\epsilon$ in theorem 2 of
    the original paper.
    """
    # squared L2 norm of every ground-truth row, kept as a column for broadcasting
    row_energy = (ground_truth_matrix ** 2).sum(dim=1, keepdim=True)
    scale = 0.5 * (row_energy + 1)
    absolute_error = (estimated_matrix - ground_truth_matrix).abs()
    return (absolute_error / scale).mean()
    
# the random-projection estimates must stay close to the ground truth at every hop
for i, (projected_norm, exact_norm, projected_sum, exact_sum) in enumerate(zip(
        projected_norm_matrices_by_single_update, norm_matrices_by_brute_force,
        projected_sum_matrices_by_single_update, sum_matrices_by_brute_force)):
    ratio1 = get_error_ratio(projected_norm, exact_norm)
    ratio2 = get_error_ratio(projected_sum, exact_sum)
    assert ratio1 < 0.2, f"norm{i}: {ratio1}"
    assert ratio2 < 0.2, f"sum{i}: {ratio2}"

We then show the correctness of the updating mechanism for batch interactions.

In [10]:
# set relevant hyperparameters
node_num = 100
edge_num = 500
lam = 0.0001
num_layer = 3
dimension = 50
batch_size = 10
# use the GPU when one is present, otherwise run the same demo on the CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# generate a random graph
src_node_ids,dst_node_ids,node_interact_times = generate_graph(node_num=node_num,edge_num=edge_num)
# change the timestamps to satisfy the conditions of batch updating mechanism:
# all interactions of one batch share a timestamp, and later batches get larger timestamps
assert (edge_num // batch_size) * batch_size == edge_num
node_interact_times = np.repeat(np.arange(1, edge_num//batch_size+1),batch_size)

# get matrices by different methods
sum_matrices_by_brute_force = get_matrix_by_brute_force(src_node_ids=src_node_ids,dst_node_ids=dst_node_ids,interact_times=node_interact_times,
                                         matrix_type='sum',lam=lam,num_layer=num_layer,node_num=node_num)
norm_matrices_by_brute_force = get_matrix_by_brute_force(src_node_ids=src_node_ids,dst_node_ids=dst_node_ids,interact_times=node_interact_times,
                                         matrix_type='norm',lam=lam,num_layer=num_layer,node_num=node_num)    
sum_matrices_by_batch_update = get_sum_matrix_by_batch_updating(src_node_ids=src_node_ids,dst_node_ids=dst_node_ids,
                                                                interact_times=node_interact_times,lam=lam,num_layer=num_layer,
                                                                node_num=node_num,use_projection=False,dimension=0,
                                                               batch_size=batch_size,device=device)
norm_matrices_by_batch_update = get_norm_matrix_by_batch_updating(src_node_ids=src_node_ids,dst_node_ids=dst_node_ids,
                                                                   interact_times=node_interact_times,lam=lam,num_layer=num_layer,
                                                                   node_num=node_num,use_projection=False,dimension=0,
                                                                  batch_size=batch_size,device=device)

# the explicit batch-updated matrices must agree with the ground truth at every hop
for i in range(num_layer+1):
    assert torch.allclose(sum_matrices_by_batch_update[i],sum_matrices_by_brute_force[i],rtol=1e-5,atol=1e-5),\
        f"{i}\n{sum_matrices_by_brute_force[i]}\n{sum_matrices_by_batch_update[i]}"
    assert torch.allclose(norm_matrices_by_batch_update[i],norm_matrices_by_brute_force[i],rtol=1e-5,atol=1e-5),\
    f"{i}\n{norm_matrices_by_brute_force[i]}\n{norm_matrices_by_batch_update[i]}"

# implicitly maintain the temporal walk matrices by random projections
projected_sum_matrices_by_batch_update = get_sum_matrix_by_batch_updating(src_node_ids=src_node_ids,dst_node_ids=dst_node_ids,
                                                                interact_times=node_interact_times,lam=lam,num_layer=num_layer,
                                                                node_num=node_num,use_projection=True,dimension=dimension,
                                                               batch_size=batch_size,device=device)
projected_norm_matrices_by_batch_update = get_norm_matrix_by_batch_updating(src_node_ids=src_node_ids,dst_node_ids=dst_node_ids,
                                                                   interact_times=node_interact_times,lam=lam,num_layer=num_layer,
                                                                   node_num=node_num,use_projection=True,dimension=dimension,
                                                                  batch_size=batch_size,device=device)
    
# the random-projection estimates must stay close to the ground truth at every hop
for i in range(num_layer+1):
    ratio1 = get_error_ratio(projected_norm_matrices_by_batch_update[i],norm_matrices_by_brute_force[i])
    ratio2 = get_error_ratio(projected_sum_matrices_by_batch_update[i],sum_matrices_by_brute_force[i])
    assert ratio1 < 0.2, f"norm{i}: {ratio1}"
    assert ratio2 < 0.2, f"sum{i}: {ratio2}"