<a href="https://colab.research.google.com/github/nickprock/influencer/blob/master/notebook/HITS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch

In [None]:
# Install the influencer package straight from its GitHub repository; it
# provides influencer.centrality and influencer.torch_centrality, whose
# `hits` functions are benchmarked below.
!pip install git+https://github.com/nickprock/influencer.git

In [None]:
import influencer
# Display the installed influencer version as this cell's output.
influencer.__version__

In [None]:
# Upgrade JAX, used by the jhits implementation defined below.
# NOTE(review): on a GPU runtime a plain `pip install --upgrade jax jaxlib`
# may install a CPU-only jaxlib, in which case the "JAX" timings are CPU
# timings; consider the CUDA-specific extra (e.g. `jax[cuda12]`) and confirm
# against the JAX installation guide for the runtime in use. A runtime
# restart may also be needed for the upgraded version to be imported.
!pip install --upgrade jax jaxlib

In [None]:
from influencer.centrality import hits as npHITS
from influencer.torch_centrality import hits as torchHITS

In [None]:
# Check that PyTorch can see a CUDA GPU: the benchmark moves the adjacency
# matrix to device 0 (`.to(0)`) and will fail without one.
torch.cuda.is_available()

In [None]:
# lazy_centrality version

import jax.numpy as jnp
from jax import jit

def jhits(adjMatrix, p: int = 100):
    """Run the HITS (hubs and authorities) power iteration with JAX.

    Starting from all-ones hub and authority vectors, each step sets
    ``h = A @ a`` and then ``a = A.T @ h``, L2-normalising each vector.

    Args:
        adjMatrix: square (n, n) adjacency matrix (NumPy or JAX array). It is
            converted to a JAX array once, up front, so a NumPy input is
            copied to the device once instead of on every product.
        p: number of rows in the returned histories; ``p - 1`` update steps
            are performed (row 0 is the all-ones starting vector).

    Returns:
        ``(hub, authority, h, a)``:
        ``hub`` / ``authority`` map each node index as a string
        (``"0"`` .. ``str(n - 1)``) to its score from the final row;
        ``h`` / ``a`` are ``(max(p, 1), n)`` arrays with one row per iteration.

    Note:
        If a product is the zero vector (e.g. a graph with no edges), the
        normalisation divides by zero and that row becomes NaN.
    """
    adj = jnp.asarray(adjMatrix)
    n = adj.shape[0]

    # Collect one row per iteration and stack once at the end; vstack-ing in
    # the loop would copy the whole growing history on every step.
    h_rows = [jnp.ones(n)]
    a_rows = [jnp.ones(n)]
    for _ in range(1, p):
        # Each product is computed once and reused for its own norm.
        h_k = jnp.dot(adj, a_rows[-1])
        h_k = h_k / jnp.linalg.norm(h_k)
        # For a 1-D vector, h @ A == A.T @ h; this avoids materialising A.T.
        a_k = jnp.dot(h_k, adj)
        a_k = a_k / jnp.linalg.norm(a_k)
        h_rows.append(h_k)
        a_rows.append(a_k)

    h = jnp.stack(h_rows)
    a = jnp.stack(a_rows)

    # NOTE: each a[-1, i] / h[-1, i] is a separate device operation, so for
    # large n this dictionary construction can take a noticeable share of
    # the runtime.
    authority = {str(i): a[-1, i] for i in range(n)}
    hub = {str(i): h[-1, i] for i in range(n)}

    return hub, authority, h, a

In [None]:
# Compiled variant of jhits. `p` drives a Python `range` loop, so it must be
# a static (compile-time) argument: without static_argnames, passing p would
# hand jit a tracer and `range(1, p)` would raise. Each distinct p (and input
# shape) triggers a separate compilation.
# NOTE: tracing also unrolls jhits' per-node dict construction, so compiling
# for large graphs can be slow; the benchmark below calls plain jhits.
jit_jhits = jit(jhits, static_argnames=("p",))

In [None]:
import time

In [None]:
# Seed NumPy's global (legacy) RNG so the random adjacency matrices drawn with
# np.random.rand in the benchmark loop are reproducible.
np.random.seed(42)

# Graph sizes to benchmark: 500, 1000, ..., 14500 nodes.
num_nodes = list(range(500, 15000, 500))

# Execution times in seconds, one entry per graph size, one list per backend.
time_np = []
time_torch = []
time_torch_cpu = []
time_jnp = []

In [None]:
# Time each HITS implementation (p=10) on random dense 0/1 adjacency matrices
# of growing size.
#
# Measurement notes:
# * time.perf_counter() is monotonic and high-resolution, unlike time.time().
# * PyTorch CUDA and JAX dispatch work asynchronously, so each timed region
#   waits for the device (torch.cuda.synchronize / block_until_ready) before
#   the clock stops; otherwise mostly launch overhead would be measured.
# * Input preparation (dtype conversion, host->device copies) happens before
#   the clock starts for every backend, so all timings cover the algorithm
#   only. JAX timings may still include one-off per-shape compilation of its
#   eager ops, since every N is a new shape.
# * Requires a CUDA GPU: the GPU tensor is placed on device 0.
for N in num_nodes:
    # Same values as thresholding in place: 1.0 where rand > 0.5, else 0.0.
    adjM = (np.random.rand(N, N) > 0.5).astype(np.float64)

    # NumPy implementation.
    start_time1 = time.perf_counter()
    npHITS(adjM, p=10)
    exe_time1 = time.perf_counter() - start_time1

    # PyTorch on the GPU (float32).
    MT = torch.from_numpy(adjM).float().to(0)
    torch.cuda.synchronize()  # make sure the host->device copy has finished
    start_time2 = time.perf_counter()
    torchHITS(MT, p=10)
    torch.cuda.synchronize()  # wait for all queued CUDA kernels
    exe_time2 = time.perf_counter() - start_time2
    del MT

    # PyTorch on the CPU (float32).
    MT_cpu = torch.from_numpy(adjM).float()
    start_time3 = time.perf_counter()
    torchHITS(MT_cpu, p=10, device='cpu')
    exe_time3 = time.perf_counter() - start_time3
    del MT_cpu

    # JAX on its default backend; convert once, outside the timed region.
    MJ = jnp.asarray(adjM)
    MJ.block_until_ready()
    start_time4 = time.perf_counter()
    _, _, h_jax, a_jax = jhits(MJ, p=10)
    h_jax.block_until_ready()
    a_jax.block_until_ready()
    exe_time4 = time.perf_counter() - start_time4
    del MJ, h_jax, a_jax

    # Append only after every backend has run, so the four lists always have
    # equal length even if an iteration fails part way through.
    time_np.append(exe_time1)
    time_torch.append(exe_time2)
    time_torch_cpu.append(exe_time3)
    time_jnp.append(exe_time4)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Scatter plot of execution time vs. graph size for all four implementations;
# each series carries its own legend label, in plotting order.
plt.figure(figsize=(18, 10))
plt.plot(num_nodes, time_np, 'bo', label="numpy")
plt.plot(num_nodes, time_torch, 'ro', label="torch")
plt.plot(num_nodes, time_torch_cpu, 'go', label="torch_CPU")
plt.plot(num_nodes, time_jnp, 'ko', label="JAX")
plt.gca().set(
    xlabel="nodes",
    ylabel="seconds",
    title="HITS algorithm execution time",
)
plt.legend()
plt.show()

In [None]:
# Same comparison without the JAX series, presumably so the NumPy and
# PyTorch timings are readable on a tighter y-axis.
plt.figure(figsize=(18, 10))
plt.plot(num_nodes, time_np, 'bo', label="numpy")
plt.plot(num_nodes, time_torch, 'ro', label="torch")
plt.plot(num_nodes, time_torch_cpu, 'go', label="torch_CPU")
plt.gca().set(
    xlabel="nodes",
    ylabel="seconds",
    title="HITS algorithm execution time",
)
plt.legend()
plt.show()