In [1]:
import json
import os
from time import time

import numpy as np
import torch

from datasets import WordNetDataset
from poincare import PoincareDistance, PoincareEmbedding, RiemannianSGD

In [2]:
# Load the run configuration; later cells read the keys 'data', 'batch_size',
# 'n_epochs', 'n_burn_in', 'lr' and 'c' from it.
with open('demo/config.json', 'r') as config_file:
    config = json.load(config_file)

In [3]:
# Dataset file comes from the config; the loader batches it in dataset order
# (no `shuffle` argument is passed, so DataLoader's default applies).
data = WordNetDataset(filename=config['data'])
dataloader = torch.utils.data.DataLoader(
    dataset=data,
    batch_size=config['batch_size'],
)

In [4]:
# Embedding model sized by the number of items in the dataset, with its
# weights set up by the project's own initializer.
model = PoincareEmbedding(data.n_items)
model.initialize_embedding()

# Riemannian SGD over every model parameter; the learning rate is not fixed
# here but supplied on each `optimizer.step(lr=...)` call in the training loop.
optimizer = RiemannianSGD(model.parameters())

In [None]:
# Train the embedding, snapshotting the weights periodically so the frames can
# be rendered later.
# NOTE(review): this cell resets `emb_log` and the timers, but not `model` or
# `model.log`. Re-running it therefore continues training and keeps appending
# to `model.log`, so per-epoch losses no longer line up with the `emb_log`
# snapshots downstream. Re-run the model/optimizer cell first for a clean run.
emb_log = []
snapshot_every = 5  # keep in sync with the stride used to build `loss_log`

total_time = 0
n_epochs = config['n_epochs']
for epoch in range(n_epochs):
    epoch_loss = []
    start = time()

    # Burn-in: the first `n_burn_in` epochs use the learning rate divided by
    # `c`; afterwards the full configured rate is used.
    lr = config['lr'] / config['c'] if epoch < config['n_burn_in'] else config['lr']

    for x, y in dataloader:
        optimizer.zero_grad()
        preds = model(x, y)
        loss = model.loss(preds)
        loss.backward()
        optimizer.step(lr=lr)
        epoch_loss.append(loss.item())

    if epoch % snapshot_every == 0:
        # .copy() is required: Tensor.numpy() shares memory with the parameter,
        # which later optimizer steps may modify.
        emb_log.append(model.embedding.weight.data.numpy().copy())

    time_per_epoch = time() - start
    total_time += time_per_epoch

    model.log.append(np.mean(epoch_loss))

    # Extrapolate the remaining time from the mean epoch duration so far.
    estimated_time = total_time / (epoch + 1) * (n_epochs - epoch - 1)
    minutes_left, seconds_left = divmod(int(estimated_time), 60)

    # Single status line, overwritten in place each epoch via end='\r'.
    print('Epoch', epoch + 1, '/', n_epochs, '|',
          'loss:', '%.4f' % model.log[-1], '|',
          'time per epoch:', '%.2f' % time_per_epoch, 'sec.', '|',
          'estimated training time:', minutes_left, 'min.', seconds_left, 'sec.', '|',
          'embedding sum:', '%.4f' % np.sum(model.embedding.weight.data.numpy()),
          end='\r')

Epoch 6 / 1000 | loss: 3.9247 | time per epoch: 1.57 sec. | estimated training time: 26 min. 12 sec. 0.0122

In [None]:
import plotly.offline as plt
import plotly.graph_objs as go
import plotly.io as pio
from tqdm import tqdm, trange

In [None]:
# Edges to draw in `render_graph`: one tab-separated pair of synset names per line.
with open('data/wordnet/mammal_hierarchy.tsv', 'r') as f:
    # Skip blank lines. They would split into [''] and break the two-way
    # unpacking `for s0, s1 in edgelist` used when drawing edges.
    edgelist = [line.strip().split('\t') for line in f if line.strip()]

In [None]:
# Keep the mean loss of every 5th epoch (epochs 0, 5, 10, ...). These are the
# same epochs at which the training loop stores an `emb_log` snapshot, so
# loss_log[i] pairs with emb_log[i].
loss_log = list(model.log)[::5]

In [None]:
# Number of sampled loss values. This should equal len(emb_log) so that every
# rendered frame has a matching loss.
len(loss_log)

In [None]:
# Number of embedding snapshots (one every 5 epochs). The render loop pairs
# emb_log[i] with loss_log[i].
len(emb_log)

In [None]:
# Render one PNG per embedding snapshot into images/. Zero-padded frame numbers
# keep the files sorted in frame order.
# NOTE: `render_graph` is defined in the next cell; run that cell before this
# one. Under "Run All" this cell otherwise fails with a NameError.
os.makedirs('images', exist_ok=True)

# Bound by the data actually collected instead of a hard-coded 200, so an
# interrupted or shorter training run does not raise IndexError.
n_frames = min(len(emb_log), len(loss_log))
for i in trange(n_frames):
    render_graph(vis=emb_log[i], loss=loss_log[i], filename='images/%03d.png' % i)

In [None]:
def render_graph(vis, loss, filename):
    """Render one frame of the embedding visualisation and save it as an image.

    Draws every pair in the global ``edgelist`` as a thin grey line segment,
    every item of the global ``data`` dataset as a small marker (hover text is
    the synset name up to its first '.'), and text labels for a fixed list of
    synsets. The figure is then written with ``plotly.io.write_image``.

    Args:
        vis: embedding snapshot indexed by item id (e.g. an entry of
            ``emb_log``). Each row must unpack into exactly two values
            ``(x, y)``, so a 2-D embedding is assumed.
        loss: loss value shown in the figure title as ``loss:%.2f``.
        filename: output path handed to ``plotly.io.write_image``.
    """
    item2id = data.item2id

    def _xy(name):
        """Return the (x, y) coordinates of synset ``name`` in this snapshot."""
        x, y = vis[item2id[name]]
        return x, y

    # Collect all coordinates first and build each trace in a single call.
    # Growing trace properties with `trace['x'] += ...` copies and
    # re-validates the whole sequence on every append. That is quadratic in
    # the number of edges/nodes and is repeated for every rendered frame.
    edge_x, edge_y = [], []
    for s0, s1 in edgelist:
        (x0, y0), (x1, y1) = _xy(s0), _xy(s1)
        # The trailing None breaks the line so each edge is its own segment.
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none',
        mode='lines')

    # Materialize once: the names are used for coordinates and for hover text.
    node_names = list(data.items)
    node_xy = [_xy(name) for name in node_names]
    node_trace = go.Scatter(
        x=[x for x, _ in node_xy],
        y=[y for _, y in node_xy],
        text=[name.split('.')[0] for name in node_names],
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=False,
            reversescale=True,
            color='#8b9dc3',
            size=2))

    # Synsets that get a visible text label.
    display_list = ['placental.n.01',
                    'primate.n.02',
                    'mammal.n.01',
                    'carnivore.n.01',
                    'canine.n.02',
                    'dog.n.01',
                    'pug.n.01',
                    'homo_erectus.n.01',
                    'homo_sapiens.n.01',
                    'terrier.n.01',
                    'rodent.n.01',
                    'ungulate.n.01',
                    'odd-toed_ungulate.n.01',
                    'even-toed_ungulate.n.01',
                    'monkey.n.01',
                    'cow.n.01',
                    'welsh_pony.n.01',
                    'feline.n.01',
                    'cheetah.n.01',
                    'mouse.n.01']
    # Labels are decorative: skip names absent from this dataset rather than
    # aborting the whole frame with a lookup error.
    label_names = [name for name in display_list if name in item2id]
    label_xy = [_xy(name) for name in label_names]
    label_trace = go.Scatter(
        x=[x for x, _ in label_xy],
        y=[y for _, y in label_xy],
        mode='text',
        text=[name.split('.')[0] for name in label_names],
        textposition='top center',
        textfont=dict(
            family='sans serif',
            size=13,
            color='#000000'))

    fig = go.Figure(
        data=[edge_trace, node_trace, label_trace],
        layout=go.Layout(
            # Nested title.font replaces the deprecated `titlefont` attribute.
            title=dict(text='loss:%.2f' % loss, font=dict(size=16)),
            width=700,
            height=700,
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False)))

    pio.write_image(fig, filename)