In [10]:
import timeit
import torch
from tqdm import tqdm
import numpy as np
import logging
from hype.sn import Embedding 
from hype import train
from hype.graph import load_edge_list, eval_reconstruction
from hype.rsgd import RiemannianSGD
from hype.Poincare import PoincareManifold
from hype.Halfspace import HalfspaceManifold
import sys
import json
import torch.multiprocessing as mp
from hype.graph_dataset import BatchedDataset

# Everything runs on the CPU in float64.
device = torch.device('cpu')
# torch.set_default_tensor_type() is deprecated since PyTorch 2.1.
# set_default_dtype is the supported equivalent here because CPU is already
# the default device. It also governs the dtype of the legacy torch.Tensor(...)
# constructor used by train().
torch.set_default_dtype(torch.float64)
# Seed torch and numpy RNGs so runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)

In [11]:
## parameters; these are global in the notebook!
opt_maxnorm = 500000; opt_debug = False; opt_epochs = 1000; 
opt_dim = 2; opt_com_n = 1;
opt_negs = 50; opt_batchsize = 10; opt_eval_each = 20;
opt_sparse = True; opt_ndproc = 4;  opt_burnin = 20;
opt_dampening = 0.75; opt_neg_multiplier = 1.0; 
opt_burnin_multiplier = 0.01; opt_lr = 0.3;
# opt_manifold = "Poincare"
opt_manifold = "Halfspace"
#######################################

In [12]:
# Maps the opt_manifold setting to the manifold class instantiated for training.
MANIFOLDS = dict(
    Poincare=PoincareManifold,
    Halfspace=HalfspaceManifold,
)

### Initializing logging and data loading

In [13]:
# Log bare messages to stdout; switch to DEBUG when opt_debug is set.
log_level = logging.INFO
if opt_debug:
    log_level = logging.DEBUG
logging.basicConfig(level=log_level, format='%(message)s', stream=sys.stdout)
log = logging.getLogger('MCF')

log.info('Using edge list dataloader')
# NOTE(review): the meaning of the second positional argument (False) is
# defined in hype.graph.load_edge_list -- confirm there.
idx, objects, weights = load_edge_list("wordnet/mammal_closure.csv", False)

Using edge list dataloader


### Initializing model

In [14]:
def init_model(manifold, idx, objects, weights, sparse=True):
    """Build the batched dataset and the embedding model.

    Args:
        manifold: manifold instance handed to ``Embedding``.
        idx, objects, weights: edge-list data as returned by ``load_edge_list``;
            forwarded to ``BatchedDataset``.
        sparse: forwarded as ``sparse=`` to ``Embedding``.

    Returns:
        ``(model, data, mname, conf)``, where ``mname`` is a descriptive name
        built from ``opt_manifold``, ``opt_dim`` and ``opt_com_n``, and
        ``conf`` is an empty list (kept for interface compatibility).
    """
    conf = []
    # The previous template '%s_dim%d%com_n' parsed '%c' as a character
    # conversion, so com_n=1 became '\x01' ("Halfspace_dim2\x01om_n").
    mname = '%s_dim%d_com_n%d' % (opt_manifold, opt_dim, opt_com_n)
    # `opt_burnin > 0` presumably toggles burn-in support in the dataset --
    # confirm against hype.graph_dataset.BatchedDataset.
    data = BatchedDataset(idx, objects, weights, opt_negs, opt_batchsize,
        opt_ndproc, opt_burnin > 0, opt_dampening)
    model = Embedding(len(data.objects), opt_dim, manifold, sparse=sparse, com_n=opt_com_n)
    # Restore the caller's object list on the dataset after sizing the model.
    data.objects = objects
    return model, data, mname, conf

def adj_matrix(data):
    """Group the edges found in ``data`` by source node.

    ``data`` is iterated as ``(inputs, targets)`` pairs. For every row of
    ``inputs`` the edge ``row[0] -> row[1]`` is recorded; the remaining
    columns and ``targets`` are ignored.

    Returns:
        dict mapping ``row[0].item()`` to the set of ``row[1].item()`` values.
    """
    neighbours = {}
    for batch, _unused_targets in data:
        for row in batch:
            src, dst = row[0].item(), row[1].item()
            neighbours.setdefault(src, set()).add(dst)
    return neighbours

### Training

In [15]:
def data_loader_lr(data, epoch, progress = False, base_lr=None,
                   burnin_epochs=None, lr_multiplier=None):
  """Set the dataset's burn-in flag for ``epoch`` and pick the learning rate.

  Epochs below ``burnin_epochs`` run in burn-in mode (``data.burnin = True``)
  with learning rate ``base_lr * lr_multiplier``; later epochs use ``base_lr``.

  The multiplier used to be read from ``train._lr_multiplier``, an attribute
  attached to the ``train`` function (which shadows the ``hype.train``
  import). It now comes from ``opt_burnin_multiplier``, the value that
  attribute was set from.

  Args:
    data: dataset object; its ``burnin`` attribute is overwritten.
    epoch: zero-based epoch index.
    progress: wrap the returned iterable in ``tqdm``.
    base_lr: base learning rate; defaults to ``opt_lr``.
    burnin_epochs: number of burn-in epochs; defaults to ``opt_burnin``.
    lr_multiplier: burn-in lr scale; defaults to ``opt_burnin_multiplier``.

  Returns:
    ``(loader_iter, lr)``: the iterable to train over and the learning rate.
  """
  if base_lr is None:
    base_lr = opt_lr
  if burnin_epochs is None:
    burnin_epochs = opt_burnin
  if lr_multiplier is None:
    lr_multiplier = opt_burnin_multiplier

  in_burnin = epoch < burnin_epochs
  data.burnin = in_burnin
  lr = base_lr * lr_multiplier if in_burnin else base_lr
  loader_iter = tqdm(data) if progress else data
  return loader_iter, lr

In [16]:
def train(device, model, data, optimizer, progress=False):
  """Run the training loop for ``opt_epochs`` epochs.

  For each epoch this:
  - prints the largest absolute embedding weight;
  - gets the (possibly burn-in) loader and learning rate from ``data_loader_lr``;
  - performs one optimizer step per batch;
  - logs a ``json_stats`` line with the epoch's wall-clock time and mean batch loss.

  NOTE: this function shadows the ``hype.train`` module imported at the top
  of the notebook.

  Args:
    device: torch device the batches are moved to.
    model: model exposing ``lt.weight`` and ``loss(preds, targets, size_average=...)``.
    data: batched dataset; ``len(data)`` sizes the per-epoch loss buffer.
    optimizer: optimizer whose ``step`` accepts an ``lr`` keyword.
    progress: wrap the loader in a tqdm progress bar.

  Returns:
    numpy array of length ``opt_epochs`` holding each epoch's mean batch loss.
    It previously returned None, so callers that ignore the result are
    unaffected.
  """
  # Per-batch losses of the current epoch (dtype follows the torch default).
  epoch_loss = torch.Tensor(len(data))
  LOSS = np.zeros(opt_epochs)

  for epoch in range(opt_epochs):
    largest_weight_emb = round(torch.abs(model.lt.weight.data).max().item(), 6)
    print(largest_weight_emb, "is the largest absolute weight in the embedding")

    epoch_loss.fill_(0)
    t_start = timeit.default_timer()
    # handling burnin, get loader_iter and learning rate
    loader_iter, lr = data_loader_lr(data, epoch, progress=progress)

    n_batches = 0
    for i_batch, (inputs, targets) in enumerate(loader_iter):
      inputs = inputs.to(device)
      targets = targets.to(device)
      optimizer.zero_grad()
      preds = model(inputs)
      loss = model.loss(preds, targets, size_average=True)
      loss.backward()
      optimizer.step(lr=lr)
      epoch_loss[i_batch] = loss.cpu().item()
      n_batches = i_batch + 1

    # Measure after the loop so the whole epoch, including the last batch,
    # is counted. This also keeps `elapsed` defined for an empty loader.
    elapsed = timeit.default_timer() - t_start
    # Average only the batches actually processed, so unused buffer slots
    # (if the loader yields fewer than len(data) batches) don't dilute it.
    if n_batches:
      LOSS[epoch] = epoch_loss[:n_batches].mean().item()
    else:
      LOSS[epoch] = float('nan')
    # Format unchanged from the original; note the trailing ", }" means the
    # payload is not strictly valid JSON.
    log.info('json_stats: {' f'"epoch": {epoch}, '
             f'"elapsed": {elapsed}, ' f'"loss": {LOSS[epoch]}, ' '}')
  return LOSS

# Training embedding

In [None]:
# --- Model: instantiate the manifold chosen by opt_manifold, then the
# dataset/embedding pair built on top of it.
manifold = MANIFOLDS[opt_manifold](
    debug=opt_debug,
    max_norm=opt_maxnorm,
    com_n=opt_com_n,
)
model, data, model_name, conf = init_model(
    manifold, idx, objects, weights, sparse=opt_sparse)
data.neg_multiplier = opt_neg_multiplier
# Burn-in lr multiplier, attached to the `train` function object (which
# shadows the hype.train import) because data_loader_lr may read it there.
train._lr_multiplier = opt_burnin_multiplier
model = model.to(device)
print('the total dimension', model.lt.weight.data.size(-1),
      'com_n', opt_com_n)

# --- Optimizer over the manifold-aware parameter groups.
optimizer = RiemannianSGD(model.optim_params(manifold), lr=opt_lr)

# --- Edge lookup, consumed later by eval_reconstruction.
adj = adj_matrix(data)

# --- Training run, timed end to end.
start_time = timeit.default_timer()
train(device, model, data, optimizer, progress=False)
print("Total training time is:",
      timeit.default_timer() - start_time)

>>>>>> The size of embedding: Embedding(1180, 2, sparse=True)
the total dimension 2 com_n 1
1.0001 is the largest absolute weight in the embedding
json_stats: {"epoch": 0, "elapsed": 1.0876862667500973, "loss": 3.930980864846105, }
1.008234 is the largest absolute weight in the embedding
json_stats: {"epoch": 1, "elapsed": 0.9176706080324948, "loss": 3.928893942473242, }
1.015815 is the largest absolute weight in the embedding
json_stats: {"epoch": 2, "elapsed": 0.9252970539964736, "loss": 3.9267565433052654, }
1.02358 is the largest absolute weight in the embedding
json_stats: {"epoch": 3, "elapsed": 0.9352685082703829, "loss": 3.924634707879188, }
1.031369 is the largest absolute weight in the embedding
json_stats: {"epoch": 4, "elapsed": 0.9831414399668574, "loss": 3.922485272402206, }
1.039123 is the largest absolute weight in the embedding
json_stats: {"epoch": 5, "elapsed": 0.9427397740073502, "loss": 3.9203554630891952, }
1.047044 is the largest absolute weight in the embedding


11.667213 is the largest absolute weight in the embedding
json_stats: {"epoch": 57, "elapsed": 0.9711592691019177, "loss": 2.5088971888323837, }
12.664697 is the largest absolute weight in the embedding
json_stats: {"epoch": 58, "elapsed": 0.966441951226443, "loss": 2.4915726436848473, }
13.782805 is the largest absolute weight in the embedding
json_stats: {"epoch": 59, "elapsed": 0.9803574127145112, "loss": 2.485219297448103, }
14.632995 is the largest absolute weight in the embedding
json_stats: {"epoch": 60, "elapsed": 0.9207399049773812, "loss": 2.473143853820206, }
15.960282 is the largest absolute weight in the embedding
json_stats: {"epoch": 61, "elapsed": 0.9877834278158844, "loss": 2.458539430789335, }
16.97113 is the largest absolute weight in the embedding
json_stats: {"epoch": 62, "elapsed": 0.9857727992348373, "loss": 2.4414890897527632, }
18.093299 is the largest absolute weight in the embedding
json_stats: {"epoch": 63, "elapsed": 0.9070373508147895, "loss": 2.4438569049

107.444035 is the largest absolute weight in the embedding
json_stats: {"epoch": 114, "elapsed": 0.97931591514498, "loss": 1.9869619424308222, }
108.918797 is the largest absolute weight in the embedding
json_stats: {"epoch": 115, "elapsed": 0.9541259459219873, "loss": 1.9694199458738817, }
110.779343 is the largest absolute weight in the embedding
json_stats: {"epoch": 116, "elapsed": 0.9028250090777874, "loss": 1.945453033968791, }
111.752875 is the largest absolute weight in the embedding
json_stats: {"epoch": 117, "elapsed": 0.9595736251212656, "loss": 1.953721267195683, }
112.653467 is the largest absolute weight in the embedding
json_stats: {"epoch": 118, "elapsed": 0.9552928740158677, "loss": 1.9314068152475647, }
115.716554 is the largest absolute weight in the embedding
json_stats: {"epoch": 119, "elapsed": 0.9560549128800631, "loss": 1.9344059483147655, }
117.025476 is the largest absolute weight in the embedding
json_stats: {"epoch": 120, "elapsed": 0.9413367039524019, "loss

226.235408 is the largest absolute weight in the embedding
json_stats: {"epoch": 170, "elapsed": 0.9312850870192051, "loss": 1.18146835012265, }
227.839101 is the largest absolute weight in the embedding
json_stats: {"epoch": 171, "elapsed": 0.9656376680359244, "loss": 1.1546922751470299, }
228.972145 is the largest absolute weight in the embedding
json_stats: {"epoch": 172, "elapsed": 0.9344279118813574, "loss": 1.1308344367146455, }
229.876182 is the largest absolute weight in the embedding
json_stats: {"epoch": 173, "elapsed": 0.9295264570973814, "loss": 1.1289467176964707, }
232.901229 is the largest absolute weight in the embedding
json_stats: {"epoch": 174, "elapsed": 0.9120643553324044, "loss": 1.1101533373645476, }
234.565041 is the largest absolute weight in the embedding
json_stats: {"epoch": 175, "elapsed": 0.9344237470068038, "loss": 1.0692102542619817, }
235.544297 is the largest absolute weight in the embedding
json_stats: {"epoch": 176, "elapsed": 0.9658458190970123, "lo

425.799821 is the largest absolute weight in the embedding
json_stats: {"epoch": 226, "elapsed": 0.9615197749808431, "loss": 0.45315033870508675, }
426.867591 is the largest absolute weight in the embedding
json_stats: {"epoch": 227, "elapsed": 0.9291023351252079, "loss": 0.46693713686679633, }
427.91958 is the largest absolute weight in the embedding
json_stats: {"epoch": 228, "elapsed": 0.925714819226414, "loss": 0.4526859364719992, }
428.315533 is the largest absolute weight in the embedding
json_stats: {"epoch": 229, "elapsed": 0.9283575308509171, "loss": 0.4439630273911217, }
428.514579 is the largest absolute weight in the embedding
json_stats: {"epoch": 230, "elapsed": 0.9700395078398287, "loss": 0.4620232678245277, }
429.286747 is the largest absolute weight in the embedding
json_stats: {"epoch": 231, "elapsed": 0.9593754620291293, "loss": 0.4405315834484256, }
429.410652 is the largest absolute weight in the embedding
json_stats: {"epoch": 232, "elapsed": 0.9432393028400838, "

# Evaluate embedding

In [45]:
# Trained embedding table. Each consumer below gets its own copy, so the
# model's weights are never handed out directly.
emb_table = model.lt.weight.data
meanrank, maprank = eval_reconstruction(
    adj, emb_table.clone(), manifold.distance, workers=opt_ndproc)

# NOTE(review): named "sqnorms" after the log keys; confirm that
# manifold.pnorm actually returns squared norms.
sqnorms = manifold.pnorm(emb_table.clone())
sqnorm_lo = round(sqnorms.min().item(), 6)
sqnorm_mid = round(sqnorms.mean().item(), 6)
sqnorm_hi = round(sqnorms.max().item(), 6)

log.info(
    'json_stats final test: \n{'
    f'"sqnorm_min": {sqnorm_lo}, '
    f'"sqnorm_avg": {sqnorm_mid}, '
    f'"sqnorm_max": {sqnorm_hi}, \n'
    f'"mean_rank": {round(meanrank, 6)}, '
    f'"map": {round(maprank, 6)}, '
    '}'
)
print(emb_table[0])

json_stats final test: 
{"sqnorm_min": 0.052636, "sqnorm_avg": 0.997684, "sqnorm_max": 1.0, 
"mean_rank": 1.973547, "map": 0.818664, }
tensor([-0.3037, -0.9528])
