In [1]:
import torch as th
import torch.nn as nn
import geoopt as gt

from net.HyperIM import HyperIM
from util import train, evalu, data

## Params

In [2]:
default_dtype = th.float64
th.set_default_dtype(default_dtype)

if th.cuda.is_available():
    cuda_device = th.device('cuda:0')
    th.cuda.set_device(device=cuda_device)
else:
    raise Exception('No CUDA device found.')
    
data_path = './data/sample/'

# for the sample
label_num = 103
vocab_size = 50000
word_num = 200

if_gru = True # otherwise use rnn
if_log = True # log result

epoch = 1
embed_dim = 10

train_batch_size = 50
test_batch_size = 50
lr = 1e-4

In [3]:
# use pre-trained embed if avalible    
word_embed = th.Tensor(vocab_size, embed_dim)
label_embed = th.Tensor(label_num, embed_dim)

net = HyperIM(word_num, word_embed, label_embed, hidden_size=embed_dim, if_gru=if_gru)
net.to(cuda_device)

loss = nn.BCEWithLogitsLoss()
optim = gt.optim.RiemannianAdam(net.parameters(), lr=lr)

train_data_loader, test_data_loader = data.load_data(data_path, train_batch_size, test_batch_size, word_num)

X_train shape torch.Size([50, 200]) y_train shape torch.Size([50, 103])
train_batch_num 40
X_test shape torch.Size([50, 200]) y_test shape torch.Size([50, 103])
test_batch_num 4


In [4]:
train.train(epoch, net, loss, optim, if_neg_samp=False, train_data_loader=train_data_loader)

train epoch:   0%|          | 0/1 [00:00<?, ?it/s]
batch: 0it [00:00, ?it/s][A
batch: 1it [00:06,  6.18s/it][A
batch: 2it [00:11,  5.92s/it][A
batch: 3it [00:16,  5.76s/it][A
batch: 4it [00:23,  6.05s/it][A
batch: 5it [00:29,  6.00s/it][A
batch: 6it [00:35,  6.02s/it][A
batch: 7it [00:41,  6.01s/it][A
batch: 8it [00:47,  6.05s/it][A
batch: 9it [00:53,  6.05s/it][A
batch: 10it [00:59,  6.00s/it][A
batch: 11it [01:05,  6.00s/it][A
batch: 12it [01:11,  5.96s/it][A
batch: 13it [01:17,  5.99s/it][A
batch: 14it [01:23,  5.96s/it][A
batch: 15it [01:29,  5.94s/it][A
batch: 16it [01:35,  5.91s/it][A
batch: 17it [01:41,  5.96s/it][A
batch: 18it [01:47,  5.95s/it][A
batch: 19it [01:53,  5.96s/it][A
batch: 20it [01:59,  5.98s/it][A
batch: 21it [02:05,  5.95s/it][A
batch: 22it [02:11,  6.00s/it][A
batch: 23it [02:16,  5.79s/it][A
batch: 24it [02:21,  5.51s/it][A
batch: 25it [02:26,  5.37s/it][A
batch: 26it [02:31,  5.28s/it][A
batch: 27it [02:36,  5.17s/it][A
batch: 28it

epoch 1 	loss 26.941045098798682


In [5]:
evalu.evaluate(net, if_log=if_log, test_data_loader=test_data_loader)

evaluating: 4it [00:06,  1.51s/it]

P@1	49.000		P@3	33.500		P@5	27.300
nDCG@1	49.000		nDCG@3	40.526		nDCG@5	44.617



