
Commit

smaller performance improvements and bug fixes
mnick committed Mar 20, 2018
1 parent 056377f commit cfb55cd
Showing 5 changed files with 11 additions and 9 deletions.
2 changes: 1 addition & 1 deletion embed.py
@@ -113,7 +113,7 @@ def control(queue, log, types, data, fout, distfn, nepochs, processes):
parser.add_argument('-debug', help='Print debug output', action='store_true', default=False)
opt = parser.parse_args()

- th.set_default_tensor_type('torch.DoubleTensor')
+ th.set_default_tensor_type('torch.FloatTensor')
if opt.debug:
log_level = logging.DEBUG
else:
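The embed.py change flips the process-wide default from double to single precision. Illustration only (not part of the commit): any tensor created afterwards without an explicit type is a FloatTensor, at 4 bytes per element instead of 8.

    # illustration only: effect of the new default tensor type
    import torch as th

    th.set_default_tensor_type('torch.FloatTensor')
    x = th.zeros(10)
    print(x.type())  # torch.FloatTensor, half the memory of torch.DoubleTensor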
4 changes: 2 additions & 2 deletions example.sh
@@ -13,9 +13,9 @@ export OMP_NUM_THREADS=1
python3 embed.py \
-dim 5 \
-lr 0.3 \
- -epochs 200 \
+ -epochs 300 \
-negs 50 \
- -burnin 10 \
+ -burnin 20 \
-nproc "${NTHREADS}" \
-distfn poincare \
-dset wordnet/mammal_closure.tsv \
10 changes: 6 additions & 4 deletions model.py
@@ -23,7 +23,6 @@ def __init__(self, eps=eps):
self.eps = eps

def forward(self, x):
- self.save_for_backward(x)
self.z = th.sqrt(x * x - 1)
return th.log(x + self.z)
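Since d/dx acosh(x) = 1/sqrt(x^2 - 1) = 1/z, a backward pass only needs the z already stored on the instance, so the extra save_for_backward(x) was redundant work. Below is a sketch of such a Function, not the repository's code: it assumes the old instance-based autograd API that the file appears to use and a hypothetical eps default of 1e-5.

    # sketch only: acosh whose backward reuses the z computed in forward;
    # eps guards the division as x -> 1, where z -> 0
    import torch as th
    from torch.autograd import Function

    class AcoshSketch(Function):
        def __init__(self, eps=1e-5):  # eps default is an assumption
            super(AcoshSketch, self).__init__()
            self.eps = eps

        def forward(self, x):
            self.z = th.sqrt(x * x - 1)
            return th.log(x + self.z)

        def backward(self, g):
            return g / th.clamp(self.z, min=self.eps)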

@@ -130,7 +129,7 @@ def loss(self, preds, targets, weight=None, size_average=True):

class GraphDataset(Dataset):
_ntries = 10
- _dampening = 0.75
+ _dampening = 1

def __init__(self, idx, objects, nnegs, unigram_size=1e8):
print('Indexing data')
@@ -174,7 +173,10 @@ def __getitem__(self, i):
t, h, _ = self.idx[i]
negs = set()
ntries = 0
- while ntries < self.max_tries and len(negs) < self.nnegs:
+ nnegs = self.nnegs
+ if self.burnin:
+     nnegs *= 0.1
+ while ntries < self.max_tries and len(negs) < nnegs:
if self.burnin:
n = randint(0, len(self.unigram_table))
n = int(self.unigram_table[n])
@@ -186,7 +188,7 @@ def __getitem__(self, i):
if len(negs) == 0:
negs.add(t)
ix = [t, h] + list(negs)
- while len(ix) < self.nnegs + 2:
+ while len(ix) < nnegs + 2:
ix.append(ix[randint(2, len(ix))])
return th.LongTensor(ix).view(1, len(ix)), th.zeros(1).long()

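Two things change in the sampler: during burnin roughly ten times fewer negatives are drawn per example, and they are looked up in the unigram table, whose shape is controlled by _dampening (now 1, i.e. raw node frequencies, instead of 0.75). A sketch of the word2vec-style unigram-table idea, with a hypothetical helper name and table size, assuming _dampening is the exponent applied when the table is built:

    # sketch only: each node fills a share of the table proportional to
    # count ** dampening, so a uniformly random slot samples nodes by
    # (dampened) frequency; dampening = 1 keeps the raw distribution
    import numpy as np

    def build_unigram_table(counts, dampening=1.0, table_size=1000000):
        weights = np.asarray(counts, dtype=np.float64) ** dampening
        bounds = (np.cumsum(weights / weights.sum()) * table_size).astype(np.int64)
        table = np.zeros(table_size, dtype=np.int64)
        start = 0
        for node, end in enumerate(bounds):
            table[start:end] = node
            start = end
        table[start:] = len(weights) - 1  # absorb any rounding remainder
        return table

Drawing a negative is then table[randint(0, len(table))], which mirrors the lookup in the burnin branch above.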
2 changes: 1 addition & 1 deletion rsgd.py
@@ -9,7 +9,7 @@
import torch as th
from torch.optim.optimizer import Optimizer, required

- spten_t = th.sparse.DoubleTensor
+ spten_t = th.sparse.FloatTensor


def poincare_grad(p, d_p):
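For context on the function whose sparse tensor type changes above: in the Poincaré ball model, the Riemannian gradient is the Euclidean gradient rescaled by (1 - ||p||^2)^2 / 4, as described in the Poincaré embeddings paper. A dense sketch with a hypothetical helper name, not the repository's implementation:

    # sketch only: row-wise rescaling grad_R = ((1 - ||p||^2)^2 / 4) * grad_E
    import torch as th

    def poincare_grad_dense(p, d_p):
        sqnorm = th.sum(p * p, dim=-1, keepdim=True)
        return d_p * ((1 - sqnorm) ** 2 / 4).expand_as(d_p)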
2 changes: 1 addition & 1 deletion train.py
@@ -11,7 +11,7 @@
from torch.utils.data import DataLoader
import gc

- _lr_multiplier = 0.1
+ _lr_multiplier = 0.01

@zxch3n commented on Apr 28, 2018:

In the paper, this value is said to be 0.1. Does 0.01 work better?



def train_mp(model, data, optimizer, opt, log, rank, queue):
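_lr_multiplier presumably scales the learning rate during the burnin phase (the paper uses 0.1, hence the question above); 0.01 makes burnin more conservative still. A sketch of how such a multiplier is typically applied, with a hypothetical helper, since the call sites are not shown in this diff:

    # sketch only: reduced learning rate while still inside the burnin window
    _lr_multiplier = 0.01

    def effective_lr(base_lr, epoch, burnin_epochs):
        return base_lr * _lr_multiplier if epoch < burnin_epochs else base_lr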
