In [1]:
import numpy as np
import torch, pdb
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
from itertools import ifilter
from IPython.core.debugger import set_trace
from random import randint

In [2]:
BATCH_SIZE = 200000
class Word2Vec(nn.Module):
    def __init__(self, vocab_size, hid_dim, pretrained=None):
        super(Word2Vec, self).__init__()
        self.hid_dim = hid_dim
        
        #these are by intent to learn separate embedding matrices, we return word_emb
        self.word_emb = nn.Embedding(vocab_size, hid_dim)
        if pretrained is not None:
            self.word_emb.weight.data.copy_(pretrained)
        self.context_emb = nn.Embedding(vocab_size, hid_dim)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, wrd, cntxt):
        wrd_vec = self.word_emb(wrd) # N * 1 * D
        cntxt_vec = self.context_emb(cntxt) # N * 5 * D
        res = torch.bmm(wrd_vec, cntxt_vec.view(BATCH_SIZE, self.hid_dim, -1))
        res = self.sigmoid(res) # N * 1 * 5
        res = res.squeeze(1) # for each mini-batch we have a probability score for the 5 contexts
        return res
        
def process_data(min_freq, flName=None, lines=None):
    """
    flName: file to read which contain raw data
    min_freq: required minimum frequency to be considered in vocabulary
    returns: vocab, vocab_size, word2index, index2word
    """
    vocab, index2word, word2index = {}, {}, {}
    if flName is not None:
        with open(flName) as fp:
            lines = fp.readlines()
            
    for line in lines:
        wrds = line.split(" ") #a very basic tokenizer that only splits by space and no stemming or cleaning
        for w in wrds:
            if w not in vocab:
                vocab[w] = 1
            else:
                vocab[w] += 1

    for wrd in vocab:
        if vocab[wrd] >= min_freq:
            index2word[len(index2word)] = wrd
            word2index[wrd] = len(index2word) - 1
        else:
            vocab[wrd] = 0
    vocab_size = len(index2word)
    return vocab, vocab_size, word2index, index2word

def negative_sampling_tbl(vocab, vocab_size, idx2word):
    total_cn = 0
    for wrd in vocab:
        total_cn += pow(vocab[wrd],0.75)
        
    tbl_size, wrd_idx = int(1e8), 0
    table = torch.LongTensor(tbl_size) # defaults to a column vector with only 1 dimension
    wrd_prob = pow(vocab[idx2word[wrd_idx]], 0.75)/total_cn
    
    for i in range(0, tbl_size):
        table[i] = wrd_idx
        if i/tbl_size > wrd_prob:
            wrd_idx += 1
            wrd_prob += pow(vocab[idx2word[wrd_idx]], 0.75)/total_cn
        if wrd_idx >= vocab_size:
            wrd_idx -= 1
    
    return table
        
# return the sample context the first one being the true word and other being negative
def sample_context(table, neg_cn, cntxt):
    cntxts, i = [], 0
    cntxts.append(cntxt)
    while i < neg_cn:
        ind = randint(0, len(table) - 1)
        neg_ctx = table[ind]
        if neg_ctx != cntxt:
            cntxts.append(neg_ctx)
            i += 1
    return cntxts


def train_pair(wrd_idx, cntxts, labels, mdl, criterion, optimizer):
    """
        wrd_idx: is the input word which is predicting the context
        cntxts: contains 1 positive word idx's and remaining negative words idx's forming the context
    """
    preds = mdl(wrd_idx, cntxts)
    loss = criterion(preds, labels)    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    return loss.data[0]
    
def train_model(mdl, lines, table=None, neg_exmpl=10, win_size=5):
    print('Training..')
    if table is None:
        table = negative_sampling_tbl(vocab, vocab_size, index2word)
    print('Data processing complete..')
    # by default it will give a float tensor, the first one is the positive, remaining are negative
    labels = Variable(torch.zeros(BATCH_SIZE, 1 + neg_exmpl))
    labels[:, 0] = 1
    
    if torch.cuda.is_available():
        labels = labels.cuda()
        table = table.cuda()
        mdl.cuda()

    criterion = nn.BCELoss()
    optimizer = torch.optim.SGD(mdl.parameters(), lr=0.025)
    
    track_loss, batch_wrd_idx, batch_cntxts = [], [], []
    for k, l in enumerate(lines):
        l = l.strip()
        wrds = l.split(" ")        
        #print("line:", k)
        for i, wrd in enumerate(wrds):
            #print("wrd:", i)
            if wrd in word2index:
                # target word which is predicting its context                
                wrd_idx = word2index[wrd]
                
                # for training word i searching for other words j in vicinity as its context
                # for ppdb its singleton, for story its multiple                
                for j in range(max(0, i - win_size), min(len(wrds), i + win_size)):
                    #print("cntxt:", j)
                    cntxt_wrd = l[j] #it's guarenteed to be a valid index
                    if j != i and cntxt_wrd in word2index:
                        cntxt_idx = word2index[cntxt_wrd]
                        cntxts = sample_context(table, neg_exmpl, cntxt_idx)
                        
                        batch_wrd_idx.append(wrd_idx)
                        batch_cntxts.append(cntxts)
                        
                        if len(batch_wrd_idx) == BATCH_SIZE:
                            print("line:", i, j, k)
                            var_wrd_idx = Variable(torch.LongTensor(batch_wrd_idx)).unsqueeze(1)
                            var_cntxts = Variable(torch.LongTensor(batch_cntxts))

                            if torch.cuda.is_available():
                                var_wrd_idx = var_wrd_idx.cuda()
                                var_cntxts = var_cntxts.cuda()
                            
                            lossval = train_pair(var_wrd_idx, var_cntxts, labels, mdl, criterion, optimizer)
                            print('loss:', lossval)
                            track_loss.append(lossval)
                            batch_wrd_idx[:], batch_cntxts[:] = [], []
                            
    return sum(track_loss)/len(track_loss)

In [3]:
def get_sim(wrd, k, mat, word2index):
    if wrd not in word2index:
        return None
    vec = mat[word2index[wrd], :].unsqueeze(1)
    othrs = torch.mm(mat, vec)
    othrs, ind = torch.sort(othrs, 0, descending=True)
    topk = ind[:k]
    for i in range(topk.size()[0]):
        print(index2word[topk[i][0]])    

In [4]:
def get_score(wrd1, wrd2, mat):
    if wrd1 not in word2index or wrd2 not in word2index:
        return 0.0
    vec1 = mat[word2index[wrd1]]
    vec2 = mat[word2index[wrd2]]
    return torch.dot(vec2, vec1)

In [5]:
with open("ppdb-2.0-xl-lexical", "r") as fp:
    lines = fp.readlines()

In [6]:
pairs = []
for l in lines:
    dt = l.split("|||")
    score = float(dt[3].split(" ")[1].split("=")[1])
    if score < 3.7:
        continue
    wrd1, wrd2 = dt[1], dt[2]
    wrd1, wrd2 = wrd1.strip(), wrd2.strip()
    if ".pdf" not in wrd1:
        pairs.append(wrd1 + " " + wrd2)

In [7]:
vocab, vocab_size, word2index, index2word = process_data(1, lines=pairs)

In [8]:
len(vocab)

65256

In [9]:
vocab_size

65256

In [None]:
mdl = Word2Vec(vocab_size, 300)
mdl.load_state_dict(torch.load('./mdl_preglove_300d.pth'))
w2vmat = torch.nn.functional.normalize(mdl.word_emb.weight.data.cpu())

In [39]:
mdl.word_emb.weight.data.cpu()[word2index['the'], :]  - pretrained_weight[word2index['the'], :]


1.00000e-02 *
 -0.3246
 -0.2606
  0.0572
  0.1338
  0.2410
 -0.1656
  0.0540
  0.1766
  0.3145
  0.4376
  0.0207
 -0.0702
 -0.1342
  0.0599
 -0.0723
 -0.5322
  0.3899
  0.4913
 -0.2826
 -0.4947
  0.2850
  0.1287
 -0.3109
  0.4558
 -0.3521
 -0.1082
  0.1826
 -0.1000
 -0.4710
 -0.7047
  0.2641
  0.3973
  0.3069
 -1.0578
  0.1623
  0.4239
 -0.3341
  0.1033
 -0.0524
 -0.4932
  0.2178
  0.2346
 -0.1333
 -0.0506
 -0.1475
  0.3552
  0.4031
 -0.6859
  0.1984
  0.3102
 -0.3771
  0.1117
  0.1921
 -0.5297
  0.0224
  0.4189
 -0.1848
  0.0406
 -0.6146
  0.2082
  0.3086
 -0.5745
 -0.1329
  0.0960
  0.0523
 -0.3201
  0.2758
 -0.2357
 -0.0274
  0.3173
  0.1876
  0.1278
 -0.0896
  0.3192
  0.5792
 -0.2303
 -0.5949
  0.2732
  0.5696
 -0.3904
  0.3003
  0.0063
 -0.2350
  0.0146
  0.3733
 -0.0701
  0.1105
  0.2877
  0.2795
  0.0150
 -0.9621
  0.1570
  0.2720
 -0.3763
  0.1755
  0.1365
 -0.3675
  0.0715
  0.2293
 -0.1106
  0.2469
 -0.3127
  0.5633
  0.2982
 -1.0295
 -0.2881
  0.3117
 -0.0708
  0.0244
  0.

torch.Size([300])

In [23]:
get_sim('old', 10)

old
man
whose
woman
boy
father
another
home
a
who


In [24]:
get_sim('hi', 10)

hi
ho
ai
tu
se
na
yo
ti
brasil
ya


In [25]:
get_sim('sleep', 10)

sleep
breathing
sleeping
pain
dying
awake
breath
patient
sick
waking


In [27]:
get_sim('young', 10)

young
keenly
blimp
draconian
koruna
keep
coexistence
2003-2006
bst
lowlife


In [28]:
get_sim('sleep', 10)

sleep
squeamish
rayon
elites
deselect
tucked
partenaire
snakes
co-occurring
india


In [29]:
get_sim('hi', 10)

hi
foreign-currency
min
24.6
hijacks
stadiums
outset
zia
toil
my-


In [30]:
get_sim('old', 10)

old
netanyahu
anti-poverty
bias-free
persecutes
marvin
bap
spanner
cancellable
benevolent
