In [1]:
import os
import sys
from pathlib import Path
sys.path.append(os.path.join(Path().resolve(), '..'))
import numpy as np
from datasets import ptb
import dezero.layers as L
import dezero.functions as F
from dezero.optimizers import Adam
from dezero import Variable, Model
from dezero import DataLoader
import matplotlib.pyplot as plt
from utils import *
from data import PTB
from tqdm import tqdm
from datasets.ptb import load_data
import dezero

# Hyperparameters. hidden_size presumably has to match the embedding width stored
# in ptb_cbow.npz (the model below is built with it before the weights are loaded).
# batch_size / max_eopch / eps are carried over from training and are not
# referenced by the analysis cells below.
window_size, hidden_size = 5, 100
batch_size, max_eopch = 100, 10  # (sic) "eopch": name kept so existing references still work
eps = 1e-8

# The dataset supplies the training vocabulary and corpus; the shuffled loader is
# a leftover from training and is not iterated in this notebook.
trainset = PTB(window_size=window_size)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
corpus = trainset.corpus
vocab_size = len(trainset.word_to_id)

class CBOW(Model):
    """Continuous bag-of-words model with separate input- and output-side embeddings.

    In this notebook only ``embed_in`` is read (its first parameter is used as the
    word-vector matrix); ``embed_out`` and ``sampler`` are presumably kept so the
    architecture matches what was trained and saved — TODO confirm against the
    training notebook.

    Args:
        vocab_size: number of word IDs, i.e. rows of each embedding table.
        hidden_dim: embedding dimensionality.
        sampler_corpus: word-ID corpus handed to ``UnigramSampler``. When omitted,
            the module-level ``corpus`` is used (the original behaviour). Passing it
            explicitly removes the hidden dependency on notebook state.
    """

    def __init__(self, vocab_size, hidden_dim=10, sampler_corpus=None):
        super().__init__()
        self.embed_in = L.EmbedID(vocab_size, hidden_dim)
        self.embed_out = L.EmbedID(vocab_size, hidden_dim)
        if sampler_corpus is None:
            # Backward-compatible fallback: the original always read the global `corpus`.
            sampler_corpus = corpus
        self.sampler = UnigramSampler(sampler_corpus)

    def forward(self, x):
        """Sum the input-side embeddings of the context IDs in ``x`` over axis 1.

        Assumes ``x`` is (batch, context) word IDs, giving (batch, hidden_dim) —
        TODO confirm against EmbedID's output shape.
        """
        return self.embed_in(x).sum(axis=1)

    def embed(self, y):
        """Look up output-side embeddings for the word IDs ``y``."""
        return self.embed_out(y)


model = CBOW(vocab_size, hidden_size)
model.load_weights('ptb_cbow.npz')

# The word vectors are the first (and presumably only) parameter of the
# input-side embedding; take it directly rather than via a for/break loop.
word_vecs = next(iter(model.embed_in.params())).data

corpus, w2i, i2w = load_data('train')

# The embeddings were trained against trainset.word_to_id, but every lookup below
# goes through load_data's w2i. Check the two agree, and that the loaded matrix
# has one row per training word, so a vocabulary mismatch fails loudly instead of
# silently pairing words with the wrong vectors.
assert word_vecs.shape[0] == vocab_size, (
    f"ptb_cbow.npz has {word_vecs.shape[0]} rows, expected vocab_size={vocab_size}")
assert all(trainset.word_to_id.get(w) == i for w, i in w2i.items()), (
    "load_data('train') word IDs differ from the vocabulary the model was trained on")

In [2]:
# # To use the pre-trained parameters from the textbook's sample code instead (only unpickle files you trust):
# import pickle
# with open('cbow_params.pkl', 'rb') as f:
#     data = pickle.load(f)

# i2w = data['id_to_word']
# w2i = data['word_to_id']
# word_vecs = data['word_vecs']

In [3]:
querys = 'you year car toyota'.split()  # query words for the nearest-neighbour probe

In [4]:
def word_similality(w1, w2, w2i, i2w, word_vecs):
    """Return the cosine similarity between the embedding vectors of two words.

    Args:
        w1, w2: words to compare; each is looked up in ``w2i`` (presumably a dict,
            so an unknown word raises ``KeyError``).
        w2i: mapping from word to its index along the first axis of ``word_vecs``.
        i2w: not used here; accepted so the argument list mirrors the
            ``(w2i, i2w, word_vecs)`` convention of ``most_similar``/``analogy``.
        word_vecs: embedding matrix indexed by word ID.

    Returns:
        The value ``cos_similality`` computes for the two vectors.
    """
    # Look up w1 first, then w2 (same order as before).
    first_vec, second_vec = (word_vecs[w2i[word]] for word in (w1, w2))
    return cos_similality(first_vec, second_vec)

In [5]:
word_similality(w1='toyota', w2='mazda', w2i=w2i, i2w=i2w, word_vecs=word_vecs)  # similarity of two carmakers

0.29477829334851646

In [6]:
# Print the 10 highest-scoring neighbours of each query word. most_similar prints
# its ranking itself (a for-loop cell has no display value); the scores look like
# cosine similarities — presumably, confirm in utils.most_similar.
for query in querys:
    most_similar(query, w2i, i2w, word_vecs, top = 10)


 [query] you
we        : 0.74113
i         : 0.69585
they      : 0.68333
your      : 0.47319
he        : 0.42945
meals     : 0.41782
pasadena  : 0.39748
she       : 0.39077
scott     : 0.38869
analysts  : 0.38582

 [query] year
month     : 0.84547
week      : 0.84042
summer    : 0.59059
spring    : 0.55247
decade    : 0.54155
day       : 0.52436
piece     : 0.48625
time      : 0.47846
years     : 0.44805
thing     : 0.44672

 [query] car
cars      : 0.42419
auto      : 0.41969
across-the-board: 0.39437
letter    : 0.38786
rank      : 0.38754
diamond   : 0.38159
industrials: 0.38104
denying   : 0.37294
cd        : 0.37269
adjusting : 0.37182

 [query] toyota
westinghouse: 0.44242
loyal     : 0.41437
homes     : 0.40086
broadcasting: 0.39378
rivals    : 0.37883
esb       : 0.37727
grid      : 0.36735
slashed   : 0.36282
soar      : 0.35967
carat     : 0.35959


In [7]:
# Analogy probe "king:man = queen:?" (the hoped-for answer is "woman").
# analogy() prints its own ranked candidates. The printed scores exceed 1, so they
# are presumably unnormalised dot products rather than cosines — check utils.analogy.
# In the saved run the hoped-for answer does not appear in the top 10.
analogy('king', 'man', 'queen', w2i, i2w, word_vecs, top = 10)


[analogy] king:man = queen:?
 crackdown: 4.488529468804121
 recital: 3.8430350805558104
 honduras: 3.784316218826454
 worm: 3.7298828831012605
 nonsense: 3.72230297685637
 editorial-page: 3.6959209671899513
 prosecutorial: 3.5527883765871513
 birds: 3.4896121164097624
 brushed: 3.4335280281204605
 heroes: 3.411219536672392


In [8]:
# Tense analogy "take:took = go:?" (the hoped-for answer is "went").
# In the saved run "went" does not appear among the 10 printed candidates.
analogy('take', 'took', 'go', w2i, i2w, word_vecs, top = 10)


[analogy] take:took = go:?
 pricings: 3.9008207009172287
 recalled: 3.6740769961302897
 summoned: 3.5325145002719305
 marched: 3.4863272358378175
 daffynition: 3.4397122134079043
 naked: 3.2857247589495024
 dominates: 3.2813056524275788
 berlitz: 3.268546680106323
 regard: 3.188953639198097
 dies: 3.170695775335239


In [9]:
# Plural analogy "car:cars = child:?" (the hoped-for answer is "children").
# In the saved run "children" does not appear among the 10 printed candidates.
analogy('car', 'cars', 'child', w2i, i2w, word_vecs, top = 10)


[analogy] car:cars = child:?
 recital: 4.272431678515458
 struggles: 3.937740357708284
 belts: 3.687268434990096
 intentionally: 3.6112456075460524
 photographs: 3.574466829437898
 queen: 3.5142402952209117
 sr.: 3.4302826600119474
 walks: 3.4057440374337418
 palmer: 3.365731852564985
 borders: 3.361399037777532
