In [1]:
import sys
sys.path.append("..")

from torch.utils.data import DataLoader
import numpy as np
import torch as t

In [2]:
from loss.sgns import SGNSLoss
from datasets.newsgroups import NewsgroupsDataset
from model import Lda2vec
from train import get_args

In [3]:
# Configuration: start from the training script's arguments (`train.get_args`)
# and override the few settings this exploration needs.
args = get_args()
args.load_dataset = 'dataset.pth'
args.dataset_dir = '../data/'
args.toy = False
args.batch_size = 2

# Seed the RNGs so the negative-sample draws (and therefore the losses shown
# below) are reproducible under Restart & Run All.
# NOTE(review): assumes the samplers draw from torch's and/or numpy's global RNG -- confirm.
SEED = 0
t.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
# Build the dataset and an unshuffled loader, so the first batch is the same on
# every run. The batch size is read from `args` so the config cell is the single
# source of truth.
dataset = NewsgroupsDataset(args)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

In [5]:
# Sizes handed to Lda2vec: number of distinct terms and number of documents
# (assumed meaning of its first two positional args -- TODO confirm).
vocab_size = len(dataset.term_freq_dict)
num_docs = len(dataset.files)
model = Lda2vec(vocab_size, num_docs, args)

# The loss is given the model's own word-embedding module.
sgns = SGNSLoss(dataset, model.word_embeds, args.device)

In [6]:
# Grab the first batch. Use the builtin next(): the `.next()` method was a
# Python-2-era alias that newer PyTorch DataLoader iterators no longer provide.
(center, doc_id), target = next(iter(dataloader))

In [7]:
# Show the raw id tensors of the first batch, one labelled line each.
for label, value in (('Target', target), ('Center', center), ('Doc Id', doc_id)):
    print(f'{label}: {value}')

Target: tensor([[1],
        [2]])
Center: tensor([[0],
        [0]])
Doc Id: tensor([[1750],
        [1750]])


In [59]:
# Forward pass on the (center word ids, document ids) pair from the first batch.
model_inputs = (center, doc_id)
context = model(model_inputs)

# One (stochastic) evaluation of the SGNS loss against the embedded targets.
embedded_targets = model.word_embeds(target)
sgns(context, embedded_targets).item()

-33.19685363769531

In [16]:
# Average many evaluations of the SGNS loss. Each evaluation draws its own
# negative samples inside SGNSLoss, so the mean smooths out sampling noise.
# NOTE(review): the saved run of this cell raised
#   RuntimeError: invalid argument 1: must be >= 0 and <= 1
# from torch's RNG code. This presumably means a probability outside [0, 1] is
# reaching a sampler inside SGNSLoss -- TODO investigate the unigram table.
N_LOSS_TRIALS = 10000
with t.no_grad():  # evaluation only: don't build N_LOSS_TRIALS autograd graphs
    embedded_targets = model.word_embeds(target)  # loop-invariant, look up once
    losses = [sgns(context, embedded_targets).item() for _ in range(N_LOSS_TRIALS)]
float(np.mean(losses))

RuntimeError: invalid argument 1: must be >= 0 and <= 1 at /pytorch/aten/src/TH/THRandom.cpp:320

In [None]:
"""
TESTING FORWARD PASS
"""
#target = model.word_embeds(target)
context, target = context.squeeze(), target.squeeze()  # batch_size x embed_size

print(f'context: {context.size()}')
print(f'target: {target.size()}')

# compute non-sampled portion
dots = (context * target).sum(-1)  # batch_size
print(f'dots: {dots.size()}')
log_targets = t.nn.functional.logsigmoid(dots)
print(f'log_targets: {log_targets.size()}')

# END NONSAMPLES PORTION WITH tensor(batch_size)
print()
print(log_targets)

In [None]:
# Sampled (negative) portion of the SGNS objective: for each sampled word
# vector s, accumulate log sigmoid(-context . s) per batch row.
n_neg_samples = 2
samples = sgns.get_unigram_samples(n_neg_samples)  # NUM_SAMPLES x EMBEDDING SIZE
print(f'samples: {samples.size()}')

# Per-sample terms go into their own list, so `log_samples` is only ever a tensor.
sample_terms = []
for sample_vec in samples:
    dot = (t.neg(context) * sample_vec).sum(-1)  # batch_size
    # clamp keeps log() finite when the sigmoid saturates at 0 or 1
    log_sample = t.log(t.sigmoid(dot).clamp(min=1e-5, max=1-1e-5))  # batch_size
    sample_terms.append(log_sample)
    print(log_sample)

print()
log_samples = t.stack(sample_terms).sum(0)  # batch_size
print(log_samples)
# END SAMPLES PORTION WITH tensor(batch_size)
# END SAMPLES PORTION WITH tensor(batch_size)

In [None]:
# Positive term plus negative-sample terms per batch row, averaged over the batch.
# Presumably this mirrors what SGNSLoss returns -- compare with the sgns(...) value above.
(log_targets + log_samples).mean()

In [None]:
# Peek at a raw draw of two items from the loss's unigram table, moved to the loss's device.
unigram_table = sgns.unigram_table
unigram_table.draw(2).to(sgns.device)