In [140]:
import torch
import torch.nn as nn
import torch.nn.functional as fn
from torch.autograd import Variable
import numpy as np

In [141]:
from tqdm import tqdm_notebook

In [237]:
# Multi-Component word embeddings

class MultiComp(nn.Module):
    def __init__(self, word_embeddings_to_comp, n_comp=3):
        super(MultiComp, self).__init__()
        self.n_comp = n_comp
        n_words, embedding_size = word_embeddings_to_comp.shape
        # expand to duplicate embeddings n_comp times
        expanded_word_embeddings = np.expand_dims(word_embeddings_to_comp, 1)
        comp_embeddings = np.tile(expanded_word_embeddings, [1, n_comp, 1])
        
        # add some noise to components (1/3 of its standard deviations)
        stds = word_embeddings_to_comp.std(axis=1, keepdims=True).reshape([n_words, 1, 1])
        comp_embeddings += np.random.randn(n_words, n_comp, embedding_size)  * stds  / 10
        comp_embeddings = comp_embeddings.astype(np.float32)
        
        # create variable to use autograd
        self.words_comps = nn.Parameter(torch.from_numpy(comp_embeddings))
        
        # weight matrices for attention (times 2 because concat context and comp)
        weights = np.random.randn(n_words, embedding_size * 2, 1)
        # xavier
        weights = weights * np.sqrt(2 / (embedding_size + n_comp))
        weights = weights.astype(np.float32)
        self.att_w = nn.Parameter(torch.from_numpy(weights))
        self.att_b = nn.Parameter(torch.zeros(n_words, n_comp, 1))
                
    def forward(self, context_embeddigs, word_n):
        # pick word_embeddings and linear layer weigts
        w_comps = self.words_comps[word_n]
        att_w = self.att_w[word_n]
        att_b = self.att_b[word_n]
        # sum the context across words dim 
        cont_sum = torch.mean(context_embeddigs, 0, keepdim=True)
        cont_sum_repeated = cont_sum.repeat(self.n_comp, 1)
        att_input = torch.cat([cont_sum_repeated, w_comps], dim=1)
        att = torch.matmul(att_input, att_w) + att_b
        att = fn.softmax(att, 0)
        comps_sum = torch.sum(w_comps * att, 0)
        dot_prod = torch.matmul(comps_sum, cont_sum.squeeze())/(torch.norm(comps_sum, p=1) * torch.norm(cont_sum, p=1))
        return dot_prod, att

# Test simple

In [238]:
w_emb = np.random.randn(3, 100)
net = MultiComp(w_emb)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)


n_samples = 10
n_context = 20 # number of words in the context
emb_dim = 100
for sample in np.random.randn(n_samples, n_context, emb_dim):
    # Prepare sample with Variable wrap
    sample = Variable(torch.from_numpy(sample.astype(np.float32)))
    net.zero_grad()
    dot_prod, att = net.forward(sample, 0)
    loss = -dot_prod
    loss.backward()
    opt.step()
    print(loss)

Variable containing:
1.00000e-05 *
  4.7743
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-04 *
  4.3998
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-03 *
 -1.5495
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-04 *
 -2.4649
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-03 *
 -1.7173
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-04 *
  4.7251
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-04 *
 -1.9560
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-03 *
 -2.3268
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-04 *
  9.6727
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-03 *
  1.9331
[torch.FloatTensor of size 1]



# Create dataset

In [7]:
ls -lh ../data/my_data/

total 13G
-rw-r--r-- 1 fogside fogside  8,1G янв 21 18:08 [0m[00mbig_one_file.txt[0m
-rw-r--r-- 1 fogside fogside   40M янв 21 18:22 [00mdict.txt[0m
drwxr-xr-x 3 fogside fogside  4,0K янв 21 18:05 [01;34mlibru[0m/
-rw-r--r-- 1 fogside fogside  3,0M янв 17 17:42 [00mmain_contexts_and_test.txt[0m
-rw-r--r-- 1 fogside fogside  623M янв 17 17:43 [00mmain_wiki_and_contexts.txt[0m
-rw-r--r-- 1 fogside fogside  620M янв 14 19:34 [00mmain_words_wiki_normalized_no_punct.txt[0m
-rw-r--r-- 1 fogside fogside  732M янв 14 18:56 [00mmain_words_wiki.txt[0m
-rw-r--r-- 1 fogside fogside 1019M окт 20 00:10 [00mruwiki_00.txt[0m
-rw-r--r-- 1 fogside fogside  1,1G янв 13 15:06 [00mruwiki_tokenized.txt[0m
drwxrwxr-x 4 fogside fogside  4,0K янв 19 18:03 [01;34mНКРЯ[0m/


In [8]:
ls -lh ../models/

total 3,0G
-rw-r--r-- 1 fogside fogside 1,3G дек  8 17:42 [0m[00mfast_text_model.bin[0m
-rw-r--r-- 1 fogside fogside 587M дек  8 17:42 [00mfast_text_model.vec[0m
-rw-r--r-- 1 fogside fogside 923M янв 22 04:51 [00mmodel_big_one.bin[0m
-rw-r--r-- 1 fogside fogside 171M янв 22 04:51 [00mmodel_big_one.vec[0m


Variable containing:
 125.4748
[torch.FloatTensor of size 1]

In [10]:
from pymystem3 import Mystem
stemmer = Mystem()

In [21]:
def get_all_indexes(lst, word):
    res = []
    i = 0
    while(True):
        try:
            i = lst.index(word, i)
            res.append(i)
            i+=1
        except:
            break
    return res

In [26]:
from tqdm import tqdm

In [58]:
def make_dataset(word, window):
    N = 1669868
    w = stemmer.lemmatize(word)[0]
    counter = 0

    with open("../data/my_data/big_one_file.txt", 'r') as bigf,\
    open("../data/my_data/{}_out.txt".format(word), 'a') as fout:
        for i in tqdm(range(N)):
            line = bigf.readline().split()
            if w in line:
                idxs = get_all_indexes(line, w)
                for i in idxs:
                    counter+=1
                    # each line is a group of neighbour words with length = 3*window
                    start = max(0, i-1-window) # if 0 is max then all before main word will be selected
                    fout.write(" ".join(line[start:i-1])+" "+" ".join(line[i:i+window])+'\n')
    return counter

In [60]:
make_dataset(word='замок', window=10)

100%|██████████| 1669868/1669868 [01:11<00:00, 23263.62it/s] 


111462

In [9]:
from gensim.models import KeyedVectors
wv = KeyedVectors.load_word2vec_format("../models/model_big_one.vec", binary=False)

In [62]:
with open("../data/my_data/{}_out.txt".format('замок'), 'r') as f:
    lines = f.readlines()

In [136]:
def generate_batch(lines, context_max_len):
    for line in lines:
        line = line.split()
        embedd = []
        for i, w in enumerate(line[:context_max_len]):
            try:
                embedd.append(wv[w])
            except KeyError:
                continue
        yield np.array(embedd)

In [267]:
batch_gen = generate_batch(lines, context_max_len=20)
w_emb = wv['замок'].reshape((1,100))
net = MultiComp(w_emb, 5)
opt = torch.optim.Adam(net.parameters(), lr=1e-4)


n_samples = len(lines)
n_context = 20 # number of words in the context
emb_dim = 100
pbar = tqdm_notebook(batch_gen, total = n_samples)
atts = list()
for n, sample in enumerate(pbar):
    # Prepare sample with Variable wrap
    sample = Variable(torch.from_numpy(sample.astype(np.float32)))
    net.zero_grad()
    dot_prod, att = net.forward(sample, 0)
    atts.append(att.data.numpy())
    loss = -dot_prod 
    loss.backward()
    opt.step()
    if n % 100 == 99:
        pbar.set_description("loss {:.3f}".format(float(loss.data.numpy())))

KeyboardInterrupt: 

In [260]:
atts = np.array(atts)

In [266]:
atts[-100:]

array([[[3.2565333e-07],
        [3.5325311e-07],
        [9.2793636e-07],
        [3.3731990e-07],
        [9.9999803e-01]],

       [[3.2582886e-07],
        [3.5345360e-07],
        [9.2848694e-07],
        [3.3750106e-07],
        [9.9999803e-01]],

       [[3.2596904e-07],
        [3.5360364e-07],
        [9.2888547e-07],
        [3.3764400e-07],
        [9.9999803e-01]],

       [[3.2609154e-07],
        [3.5373654e-07],
        [9.2932498e-07],
        [3.3778022e-07],
        [9.9999803e-01]],

       [[3.2602375e-07],
        [3.5367412e-07],
        [9.2918140e-07],
        [3.3773449e-07],
        [9.9999803e-01]],

       [[3.2581207e-07],
        [3.5344888e-07],
        [9.2853651e-07],
        [3.3752229e-07],
        [9.9999803e-01]],

       [[3.2538449e-07],
        [3.5298774e-07],
        [9.2705977e-07],
        [3.3707613e-07],
        [9.9999803e-01]],

       [[3.2515527e-07],
        [3.5273197e-07],
        [9.2621582e-07],
        [3.3681425e-07],
        [9.

In [256]:
wv.similar_by_vector(net.words_comps[0].data.numpy()[1], 302)

[('либуша', 0.49330535531044006),
 ('мечькин', 0.48986151814460754),
 ('альхен', 0.48868250846862793),
 ('бонтон', 0.4864254891872406),
 ('бертольди', 0.48547130823135376),
 ('тришатов', 0.4854549765586853),
 ('маллинер', 0.4826701879501343),
 ('полчасика', 0.4812876284122467),
 ('часок', 0.4810723066329956),
 ('еспер', 0.4767700433731079),
 ('зеленуда', 0.4758117198944092),
 ('часик', 0.47476765513420105),
 ('мамзель', 0.4705950617790222),
 ('пользительно', 0.46866458654403687),
 ('даровщинка', 0.46850210428237915),
 ('марихен', 0.4684235453605652),
 ('польди', 0.46806660294532776),
 ('пересаливать', 0.46789810061454773),
 ('алексевна', 0.4678904414176941),
 ('пунтило', 0.4661265015602112),
 ('крыжовенный', 0.46360862255096436),
 ('мараскин', 0.46271851658821106),
 ('фреди', 0.4613215923309326),
 ('максинька', 0.46104151010513306),
 ('дюбона', 0.4606403708457947),
 ('вотренный', 0.45866337418556213),
 ('ленивица', 0.45859938859939575),
 ('подразнить', 0.4584534168243408),
 ('покушать'

In [247]:
att

Variable containing:
 8.6015e-09
 1.0000e+00
 6.8154e-09
 7.6624e-09
 1.0283e-08
[torch.FloatTensor of size 5x1]