In [57]:
import collections
import math
import random
import torch
import torch.utils.data as Data
import torch.nn.functional as F

## 数据处理过程

In [2]:
with open('ptb.train.txt', 'r') as f:
    lines = f.readlines()
    # st表示一行句子
    raw_dataset = [st.split() for st in lines]

In [3]:
len(raw_dataset) # 42068行

42068

In [4]:
raw_dataset[0] # 1行为一句

['aer',
 'banknote',
 'berlitz',
 'calloway',
 'centrust',
 'cluett',
 'fromstein',
 'gitano',
 'guterman',
 'hydro-quebec',
 'ipo',
 'kia',
 'memotec',
 'mlx',
 'nahb',
 'punts',
 'rake',
 'regatta',
 'rubens',
 'sim',
 'snack-food',
 'ssangyong',
 'swapo',
 'wachter']

In [5]:
# tk是token的缩写
counter = collections.Counter([tk for st in raw_dataset for tk in st]) # 统计单词的频率
counter = dict(filter(lambda m: m[1] >= 5, counter.items())) # 去掉频率低于5的单词

In [6]:
idx_to_token = [tk for tk, _ in counter.items()]
token_to_idx = {tk: idx for idx, tk in enumerate(idx_to_token)}
dataset = [[token_to_idx[tk] for tk in st if tk in token_to_idx]
           for st in raw_dataset] # 通过token_to_idx将文章翻译为id
num_tokens = sum([len(st) for st in dataset])
num_tokens # 文章所含单词数量

887100

In [7]:
len(idx_to_token) # 文章使用的不同单词数

9858

In [8]:
def discard(idx):
    """二次采样"""
    return random.uniform(0, 1) < 1 - math.sqrt(
         1e-4 / counter[idx_to_token[idx]] * num_tokens)

In [9]:
subsampled_dataset = [[tk for tk in st if not discard(tk)] for st in dataset]
'# tokens: %d' % sum([len(st) for st in subsampled_dataset]) # 二次采样后的总字数

'# tokens: 376023'

In [10]:
def compare_counts(token):
    return '# %s: before=%d, after=%d' % (token, sum(
        [st.count(token_to_idx[token]) for st in dataset]), sum(
        [st.count(token_to_idx[token]) for st in subsampled_dataset]))

compare_counts('the') # 可以看出高频词去除较多

'# the: before=50770, after=2181'

In [11]:
compare_counts('join') # 可以看出低频词去除不多

'# join: before=45, after=45'

In [12]:
def get_centers_and_contexts(dataset_, max_window_size):
    """中心词与contexts词提取"""
    centers, contexts = [], []
    for st in dataset_:
        if len(st) < 2:  # 每个句子至少要有2个词才可能组成一对“中心词-背景词”
            continue
        centers += st
        for center_i in range(len(st)):
            window_size = random.randint(1, max_window_size) # 随机窗口大小,不超过最大窗口数
            indices = list(range(max(0, center_i - window_size),
                                 min(len(st), center_i + 1 + window_size)))
            indices.remove(center_i)  # 将中心词排除在背景词之外
            contexts.append([st[idx] for idx in indices])
    return centers, contexts

In [13]:
tiny_dataset = [list(range(17)), list(range(17, 30))]
print('dataset:', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 5)): # 最大窗口为5
    print('center', center, 'has contexts', context)

dataset: [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], [17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]]
center 0 has contexts [1, 2, 3]
center 1 has contexts [0, 2, 3, 4, 5]
center 2 has contexts [0, 1, 3, 4, 5]
center 3 has contexts [0, 1, 2, 4, 5, 6, 7]
center 4 has contexts [0, 1, 2, 3, 5, 6, 7, 8, 9]
center 5 has contexts [1, 2, 3, 4, 6, 7, 8, 9]
center 6 has contexts [1, 2, 3, 4, 5, 7, 8, 9, 10, 11]
center 7 has contexts [2, 3, 4, 5, 6, 8, 9, 10, 11, 12]
center 8 has contexts [3, 4, 5, 6, 7, 9, 10, 11, 12, 13]
center 9 has contexts [7, 8, 10, 11]
center 10 has contexts [9, 11]
center 11 has contexts [6, 7, 8, 9, 10, 12, 13, 14, 15, 16]
center 12 has contexts [8, 9, 10, 11, 13, 14, 15, 16]
center 13 has contexts [8, 9, 10, 11, 12, 14, 15, 16]
center 14 has contexts [9, 10, 11, 12, 13, 15, 16]
center 15 has contexts [13, 14, 16]
center 16 has contexts [12, 13, 14, 15]
center 17 has contexts [18, 19]
center 18 has contexts [17, 19, 20]
center 19 has contexts [17, 

In [14]:
all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)

In [15]:
all_centers # 所有的中心词

[0,
 1,
 3,
 4,
 6,
 9,
 11,
 12,
 13,
 1,
 16,
 18,
 19,
 20,
 22,
 4,
 24,
 16,
 25,
 26,
 27,
 28,
 30,
 11,
 32,
 34,
 35,
 36,
 37,
 41,
 42,
 43,
 44,
 45,
 46,
 49,
 50,
 10,
 52,
 53,
 55,
 3,
 57,
 58,
 36,
 59,
 15,
 60,
 61,
 62,
 64,
 65,
 66,
 67,
 71,
 72,
 57,
 77,
 78,
 41,
 80,
 81,
 82,
 42,
 43,
 86,
 87,
 54,
 90,
 91,
 92,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 4,
 104,
 107,
 3,
 56,
 109,
 110,
 111,
 17,
 36,
 112,
 114,
 115,
 36,
 118,
 121,
 122,
 57,
 124,
 51,
 125,
 128,
 17,
 41,
 80,
 117,
 130,
 127,
 132,
 133,
 136,
 137,
 138,
 139,
 48,
 140,
 141,
 142,
 143,
 57,
 48,
 140,
 146,
 147,
 148,
 149,
 150,
 36,
 38,
 152,
 153,
 154,
 43,
 157,
 158,
 159,
 160,
 161,
 144,
 41,
 80,
 62,
 43,
 50,
 165,
 123,
 166,
 167,
 7,
 168,
 169,
 54,
 170,
 171,
 176,
 177,
 48,
 180,
 49,
 181,
 182,
 48,
 184,
 57,
 186,
 187,
 188,
 190,
 177,
 141,
 47,
 182,
 48,
 49,
 50,
 51,
 191,
 192,
 155,
 193,
 194,
 196,
 113,
 36,
 51,
 124,
 198,
 199,
 73,
 2

In [16]:
print(len(all_centers), len(set(all_centers)))

375090 9857


In [17]:
all_contexts # 中心词对应的contexts词

[[1, 3, 4],
 [0, 3, 4, 6],
 [0, 1, 4, 6, 9, 11, 12],
 [1, 3, 6, 9],
 [0, 1, 3, 4, 9, 11, 12, 13],
 [6, 11],
 [4, 6, 9, 12, 13],
 [6, 9, 11, 13],
 [11, 12],
 [16, 18, 19],
 [1, 18, 19, 20],
 [1, 16, 19, 20],
 [16, 18, 20],
 [19],
 [4, 24],
 [22, 24],
 [22, 4, 16, 25],
 [4, 24, 25, 26],
 [22, 4, 24, 16, 26, 27, 28, 30],
 [4, 24, 16, 25, 27, 28, 30, 11],
 [24, 16, 25, 26, 28, 30, 11, 32],
 [25, 26, 27, 30, 11, 32],
 [26, 27, 28, 11, 32, 34],
 [28, 30, 32, 34],
 [11, 34],
 [32],
 [36, 37, 41, 42],
 [35, 37],
 [36, 41],
 [35, 36, 37, 42, 43, 44, 45],
 [36, 37, 41, 43, 44, 45],
 [41, 42, 44, 45],
 [43, 45],
 [44, 46],
 [42, 43, 44, 45, 49, 50, 10, 52],
 [42, 43, 44, 45, 46, 50, 10, 52, 53, 55],
 [46, 49, 10, 52],
 [49, 50, 52, 53],
 [10, 53],
 [46, 49, 50, 10, 52, 55, 3, 57, 58],
 [50, 10, 52, 53, 3, 57, 58],
 [10, 52, 53, 55, 57, 58],
 [3, 58],
 [53, 55, 3, 57],
 [59, 15, 60],
 [36, 15, 60, 61, 62],
 [36, 59, 60, 61, 62, 64, 65],
 [59, 15, 61, 62],
 [36, 59, 15, 60, 62, 64, 65, 66],
 [36, 5

In [18]:
def get_negatives(all_contexts_, sampling_weights_, K):
    """负采样"""
    all_negatives_, neg_candidates, i = [], [], 0 # neg_candidates为备选负样本的词
    population = list(range(len(sampling_weights)))
    for contexts in all_contexts_:
        negatives = []
        while len(negatives) < len(contexts) * K:
            if i == len(neg_candidates):
                i, neg_candidates = 0, random.choices(population, sampling_weights_, k=int(1e5)) # 根据概率sampling_weights从population中选择int(1e5)个词
            neg, i = neg_candidates[i], i + 1
            if neg not in set(contexts): # 噪音词不能是背景词
                negatives.append(neg)
        all_negatives_.append(negatives)
    return all_negatives_

sampling_weights = [counter[w]**0.75 for w in idx_to_token] # 借用了二次采样的思想
all_negatives = get_negatives(all_contexts, sampling_weights, 5)

In [19]:
all_negatives # 噪声词

[[4459,
  2,
  9365,
  1941,
  1282,
  9290,
  5,
  1352,
  3611,
  3938,
  5285,
  927,
  6128,
  735,
  1064],
 [9267,
  161,
  75,
  2704,
  5002,
  7033,
  1169,
  79,
  1824,
  3638,
  2225,
  1539,
  5134,
  3604,
  3284,
  1014,
  435,
  1406,
  211,
  105],
 [5172,
  2831,
  1478,
  2627,
  293,
  762,
  2440,
  1018,
  3159,
  3536,
  23,
  1168,
  73,
  9741,
  4816,
  406,
  7819,
  44,
  1312,
  135,
  312,
  2467,
  1450,
  7,
  7,
  129,
  4375,
  1390,
  2665,
  2406,
  860,
  2702,
  906,
  4057,
  4981],
 [17,
  650,
  1668,
  243,
  246,
  88,
  4326,
  1855,
  448,
  103,
  3361,
  265,
  1571,
  4956,
  156,
  362,
  290,
  17,
  1287,
  1775],
 [3174,
  291,
  4903,
  476,
  103,
  6438,
  1829,
  391,
  2487,
  8593,
  1576,
  7135,
  3482,
  207,
  4063,
  621,
  4790,
  4977,
  3821,
  749,
  53,
  685,
  1980,
  839,
  7311,
  919,
  2531,
  83,
  4377,
  2083,
  3988,
  1076,
  1658,
  2402,
  436,
  849,
  1379,
  185,
  204,
  2230],
 [3868, 63, 2797, 961, 4

In [20]:
def batchify(text):
    """对数据进行处理,使之长度一致,并通过masks,labels加以区分"""
    max_len = max(len(c) + len(n) for _, c, n in text)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center_, context_, negative_ in text:
        cur_len = len(context_) + len(negative_)
        centers += [center_]
        contexts_negatives += [context_ + negative_ + [0] * (max_len - cur_len)]
        masks += [[1] * cur_len + [0] * (max_len - cur_len)] # 元素为填充时,masks中的值取值为0,否则取值为1
        labels += [[1] * len(context_) + [0] * (max_len - len(context_))] # 背景词取值为1,其他词(噪音词和填充)取值为0
    return (torch.tensor(centers).view(-1, 1), torch.tensor(contexts_negatives),
            torch.tensor(masks), torch.tensor(labels))

In [21]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, centers, contexts, negatives):
        assert len(centers) == len(contexts) == len(negatives)
        self.centers = centers
        self.contexts = contexts
        self.negatives = negatives

    def __getitem__(self, index):
        return self.centers[index], self.contexts[index], self.negatives[index]

    def __len__(self):
        return len(self.centers)

In [22]:
batch_size = 512

dataset1 = MyDataset(all_centers,
                    all_contexts,
                    all_negatives)

data_iter = Data.DataLoader(dataset1, batch_size, shuffle=True,
                            collate_fn=batchify)
for batch in data_iter:
    for name, data in zip(['centers', 'contexts_negatives', 'masks',
                           'labels'], batch):
        print(name, 'shape:', data.shape) # 最终的数据格式
    break

centers shape: torch.Size([512, 1])
contexts_negatives shape: torch.Size([512, 60])
masks shape: torch.Size([512, 60])
labels shape: torch.Size([512, 60])


In [23]:
from SkipGram_train import train

In [24]:
embed_size, embed_dimension = len(idx_to_token), 100
last_net = train(data_iter, embed_size, embed_dimension, 0.01, 3)

train on cpu
epoch 1, loss 0.47, time 98.90s
epoch 2, loss 0.41, time 97.65s
epoch 3, loss 0.36, time 98.81s


In [26]:
last_net.u_embeddings

Embedding(9858, 100, sparse=True)

In [29]:
last_net.u_embeddings.weight.data.shape

torch.Size([9858, 100])

In [60]:
def get_similar_tokens(query_token, k, net):
    """计算词与词之间的余弦相似度"""
    W = net.weight.data
    X = torch.unsqueeze(W[token_to_idx[query_token]], 0).repeat_interleave(embed_size, dim=0) # 通过扩充,使W,X形状相同
    cos = F.cosine_similarity(W, X) # 计算W,X的余弦相似度
    _, topk = torch.topk(cos, k=k+1)
    topk = topk.cpu().numpy()
    for i in topk[1:]:  # 排除输入词
        print('cosine sim=%.3f: %s' % (cos[i], (idx_to_token[i])))

get_similar_tokens('chip', 10, last_net.u_embeddings)


cosine sim=0.547: microprocessor
cosine sim=0.495: occasion
cosine sim=0.457: chips
cosine sim=0.457: kodak
cosine sim=0.455: intel
cosine sim=0.454: grains
cosine sim=0.438: risc
cosine sim=0.434: microprocessors
cosine sim=0.430: tricky
cosine sim=0.429: hewlett-packard
