In [None]:
import os
import torch
import gc

from src.model import XRDict
from src.data import XRDictDataset, device, data_processing
from src.train import train

from src.vocab import Vocab
import torch
import json

In [None]:
# Build the target vocabulary from './embeddings/vec_inuse.json', then
# instantiate XRDict from the checkpoint file, handing it that vocabulary.
vocab = Vocab(
    './embeddings/vec_inuse.json',
)
model = XRDict(
    vocab=vocab,
    ckpt_path='checkpoints/mlm_tlm_xnli15_1024.pth',
)

In [None]:
# Read './data/train_bpe.json'. The context manager guarantees the file
# handle is closed as soon as parsing finishes (the bare
# `json.load(open(...))` form leaves it open until garbage collection).
with open('./data/train_bpe.json', 'r', encoding='utf-8') as bpe_file:
    bpe = json.load(bpe_file)

# Split the records into train/valid/test using the model's input dictionary
# and the target vocabulary's word-to-id mapping.
train_data, valid_data, test_data = data_processing(bpe, model.dico.word2id, vocab.word2id)

# Wrap each split in the project's dataset class.
train_dataset = XRDictDataset(train_data)
valid_dataset = XRDictDataset(valid_data)
test_dataset = XRDictDataset(test_data)

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Run training. The third dataset slot previously repeated `train_dataset`,
# leaving the freshly built `test_dataset` unused; pass the held-out test
# split instead so the final slot is not fed training data.
# NOTE(review): assumes the tuple order is (train, valid, test) and that the
# second positional argument is the epoch count — confirm against src.train.
train(
    model,
    1,
    (train_dataset, valid_dataset, test_dataset),
    optimizer,
    batch_size=2,
)

In [None]:
# Collect the raw 'definitions' string of every record.
sentences_bpe = [entry['definitions'] for entry in bpe]

# Whitespace-tokenise the whole corpus once and reuse the token list for
# both counts.
definition_tokens = ' '.join(sentences_bpe).split()
n_w = len(definition_tokens)

# A token is out-of-vocabulary when the model dictionary has no id for it.
n_oov = sum(1 for tok in definition_tokens if tok not in model.dico.word2id)
print('Number of out-of-vocab words: %s/%s' % (n_oov, n_w))

In [None]:
# Wrap every definition in sentence delimiters and split it into tokens.
# The tokens are derived from `bpe` directly rather than from the previous
# value of `sentences_bpe`: the old form rebound `sentences_bpe` from strings
# to token lists, so re-running this cell failed on `.strip()`. The final
# value of `sentences_bpe` (a list of token lists) is unchanged.
sentences_bpe = [
    ('</s> %s </s>' % d['definitions'].strip()).split()
    for d in bpe
]

bs = len(sentences_bpe)                          # batch size = number of sentences
slen = max([len(sent) for sent in sentences_bpe])  # longest sentence, in tokens

# Build the (slen, bs) id matrix on the CPU, pre-filled with the pad index,
# and fill one column per sentence. Moving it to `device` once at the end
# avoids one host-to-device copy per sentence.
word_ids = torch.empty((slen, bs), dtype=torch.long).fill_(model.params.pad_index)
for i, tokens in enumerate(sentences_bpe):
    sent = torch.LongTensor([model.dico.index(w) for w in tokens])
    word_ids[:len(sent), i] = sent
word_ids = word_ids.to(device)

# True (unpadded) length of every sentence, in the same order as the columns.
lengths = torch.LongTensor([len(sent) for sent in sentences_bpe]).to(device)

In [None]:
# Make sure the model lives on the same device as the input tensors.
model.to(device)

# Score the whole batch in a single forward pass. Gradients are not needed
# here, so disable autograd to avoid retaining activations for every sentence
# in the corpus-sized batch.
# NOTE(review): if the model contains dropout or similar layers, consider
# calling model.eval() first for deterministic scores — confirm with
# src.model.
with torch.no_grad():
    score = model(x=word_ids, lengths=lengths, causal=False)