In [None]:
import math
import os
import random
import torch
from d2l import torch as d2l

In [None]:
# Register the PTB archive under the key 'ptb' in d2l's data hub as a
# (url, hash-string) pair. The second element is presumably the SHA-1 checksum
# d2l uses to validate the download — TODO confirm against d2l.download.
d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',
                       '319d85e578af0cdc590547f26231e4e31cdf1e42')


#@save
def read_ptb():
    """Load the PTB training split as a list of tokenized lines.

    Downloads/extracts the 'ptb' archive through ``d2l.download_extract`` and
    reads ``ptb.train.txt`` from the extracted directory.

    Returns:
        list[list[str]]: one whitespace-split token list per line of the file.
        The text is split on newline characters, so a file ending with a
        newline produces a final empty list.
    """
    data_dir = d2l.download_extract('ptb')
    # Read the training set. The encoding is given explicitly so the result
    # does not depend on the platform's default locale encoding.
    with open(os.path.join(data_dir, 'ptb.train.txt'), encoding='utf-8') as f:
        raw_text = f.read()
    return [line.split() for line in raw_text.split('\n')]


# Tokenized training lines; the f-string below is the cell's displayed output.
sentences = read_ptb()
f'# sentences数: {len(sentences)}'

In [None]:
# Preview the first ten tokenized lines.
sentences[:10]

In [None]:
# NLP comes from the project-local NNUtils package (not visible in this file).
from NNUtils import NLP

# Build the vocabulary from the tokenized lines.
# NOTE(review): min_freq=10 presumably drops tokens seen fewer than 10 times —
# confirm against NLP.Vocab.
vocab = NLP.Vocab(tokens=sentences, min_freq=10)
print('词典数: ', len(vocab))

## 下采样
- **目的**
    - 过滤 'the'、'a' 等高频但信息量低的词
    - 缩小语料库规模，加速训练
- **公式**
    - 每个词元 $w_i$ 被丢弃的概率如下，其中 $f(w_i)$ 是 $w_i$ 的出现次数与语料库总词数之比，$t$ 是超参数（下文代码中取 $10^{-4}$）：

$$
P\left(w_{i}\right)=\max \left(1-\sqrt{\frac{t}{f\left(w_{i}\right)}}, 0\right)
$$

In [None]:
#@save
def subsample(sentences, vocab):
    """Randomly drop high-frequency tokens from a tokenized corpus.

    Args:
        sentences (list): list of token lists.
        vocab: vocabulary object; ``vocab[token]`` is compared against
            ``vocab.unk`` to detect unknown tokens.

    Returns:
        tuple: ``(subsampled_sentences, counter)``. ``counter`` holds the token
        counts computed after unknown tokens are removed but BEFORE the random
        dropping, i.e. it does not reflect the subsampled corpus.
    """
    # Strip every token that the vocabulary maps to the unknown index.
    known_sentences = []
    for line in sentences:
        known_sentences.append(
            [token for token in line if vocab[token] != vocab.unk])

    counter = NLP.count_corpus(known_sentences)
    num_tokens = sum(counter.values())

    def keep(token):
        # Keep with probability sqrt(t / f(w)), t = 1e-4 and
        # f(w) = counter[token] / num_tokens.
        draw = random.uniform(0, 1)
        threshold = math.sqrt(1e-4 / counter[token] * num_tokens)
        return draw < threshold

    subsampled = []
    for line in known_sentences:
        subsampled.append([token for token in line if keep(token)])
    return (subsampled, counter)


# `counter` is the token count computed inside subsample() before the random
# dropping step (see the function body), not a count of sentences_subsampled.
sentences_subsampled, counter = subsample(sentences, vocab)

In [None]:
import pandas as pd

# Token frequencies recorded by the vocabulary (original corpus).
orig_counter_dict = {token: freq for token, freq in vocab.token_freqs}
# `counter` returned by subsample() is computed before any token is dropped,
# so it cannot show the effect of subsampling. Recount on the subsampled
# sentences to get the real "after" frequencies.
subsample_counter_dict = dict(NLP.count_corpus(sentences_subsampled))

# One row per token; a token absent from one side gets a count of 0.
df_counter = pd.DataFrame(
    {'orig': orig_counter_dict, 'subsample': subsample_counter_dict}
).fillna(0)

In [None]:

from plotly import graph_objs as go
import plotly.express as px

# Bar chart with one bar per row of df_counter (x = index, y = 'orig' column).
fig = px.bar(df_counter, x=df_counter.index, y='orig')
fig.show()

In [None]:
def compare_counts(token):
    """Format how often `token` occurs before and after subsampling.

    Reads the notebook-level `sentences` and `sentences_subsampled` lists and
    returns the formatted message as a string.
    """
    before = sum(line.count(token) for line in sentences)
    after = sum(line.count(token) for line in sentences_subsampled)
    return (f'"{token}"的数量：'
            f'之前={before}, '
            f'之后={after}')


# Only the last expression of a cell is displayed, so a bare call on the first
# line would be silently discarded. Print both results explicitly.
print(compare_counts('the'))
print(compare_counts('join'))

## 将下采样后的词元映射为词表索引（idx）

In [None]:
# Index the vocabulary with each whole token list.
# NOTE(review): presumably vocab[list_of_tokens] returns the matching list of
# integer indices — confirm against NLP.Vocab.__getitem__.
corpus = [vocab[line] for line in sentences_subsampled]
corpus[:10]

In [None]:
#@save
def get_centers_and_contexts(corpus, max_window_size):
    """Build skip-gram center words and their context-word lists.

    Args:
        corpus (list): list of lines, each a list of items (token indices).
        max_window_size (int): upper bound (inclusive) of the per-center
            window radius, drawn uniformly from 1..max_window_size.

    Returns:
        tuple: ``(centers, contexts)`` where ``centers`` is a flat list of every
        item from lines with at least two items, and ``contexts[k]`` is the
        list of neighbours of ``centers[k]`` inside its sampled window.
    """
    centers, contexts = [], []
    for line in corpus:
        line_len = len(line)
        # A center/context pair needs at least two items in the line.
        if line_len < 2:
            continue
        centers.extend(line)
        for pos in range(line_len):
            radius = random.randint(1, max_window_size)
            left = max(0, pos - radius)
            right = min(line_len, pos + 1 + radius)
            # Everything in [left, right) except the center position itself.
            contexts.append(
                [line[j] for j in range(left, right) if j != pos])
    return centers, contexts

In [None]:
# Toy corpus with two "sentences" (7 and 3 items) to show the sampled windows.
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('数据集', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('中心词', center, '的上下文词是', context)

In [None]:

all_centers, all_contexts = get_centers_and_contexts(corpus, 5)

# Report the total number of center/context pairs. Previously this f-string was
# wrapped in a bare triple-quoted string, so it was never evaluated or shown.
print(f'# “中心词-上下文词对”的数量: {sum([len(contexts) for contexts in all_contexts])}')

# Preview only a few context lists; displaying the full list would flood the
# notebook output.
all_contexts[:5]

## 负采样

In [None]:
# Demo of weighted sampling with replacement: 10 draws from the population,
# each element drawn in proportion to its weight (weight 0 is never drawn).
t = random.choices(population=[1, 2, 3, 4, 5], weights=[0, 7, 3, 99, 66], k=10)
print(t)

In [None]:
class RandomGenerator:
    """Draw integers from {1, 2, ..., n} according to n sampling weights.

    Draws are served from a cache of 1000 pre-sampled candidates that is
    refilled with a single ``random.choices`` call whenever it runs out.
    """

    def __init__(self, sampling_weights):
        n = len(sampling_weights)
        self.sampling_weights = sampling_weights
        self.population = list(range(1, n + 1))
        # Cache of pre-drawn values and the read position inside it.
        self.candidates = []
        self.i = 0

    def draw(self):
        """Return the next cached sample, refilling the cache when exhausted."""
        cache_exhausted = self.i == len(self.candidates)
        if cache_exhausted:
            self.candidates = random.choices(
                self.population, self.sampling_weights, k=1000)
            self.i = 0
        sample = self.candidates[self.i]
        self.i += 1
        return sample


In [None]:
# Three weights -> values are drawn from {1, 2, 3}; show ten draws.
generator = RandomGenerator([2, 3, 4])
[generator.draw() for _ in range(10)]

In [None]:
def get_negatives(all_contexts, vocab, counter, K):
    """Sample noise (negative) words for every context list.

    Args:
        all_contexts (list): list of context-index lists, one per center word.
        vocab: vocabulary providing ``len()`` and ``to_tokens(index)``.
        counter: mapping from token to its count.
        K (int): number of noise words drawn per context word.

    Returns:
        list: for each entry of ``all_contexts``, a list of
        ``len(contexts) * K`` sampled indices, none of which occurs in that
        context list.
    """
    # Weights for vocabulary indices 1..len(vocab)-1 (count ** 0.75); index 0
    # is skipped. RandomGenerator draws from 1..n, so a drawn value is directly
    # a vocabulary index.
    sampling_weights = []
    for idx in range(1, len(vocab)):
        sampling_weights.append(counter[vocab.to_tokens(idx)] ** 0.75)

    generator = RandomGenerator(sampling_weights)
    all_negatives = []
    for contexts in all_contexts:
        wanted = len(contexts) * K
        negatives = []
        while len(negatives) < wanted:
            candidate = generator.draw()
            # A noise word must not be one of the context words.
            if candidate in contexts:
                continue
            negatives.append(candidate)
        all_negatives.append(negatives)
    return all_negatives


all_negatives = get_negatives(all_contexts, vocab, counter, 5)
# Preview only the first few entries; there is one list per center word, so
# displaying the whole structure would flood the notebook output.
all_negatives[:5]

In [None]:
#@save
def batchify(data):
    """Collate skip-gram examples with negative samples into padded tensors.

    Args:
        data: iterable of ``(center, context, negative)`` triples, where
            ``context`` and ``negative`` are lists of indices.

    Returns:
        tuple of four tensors, each with one row per example:
            centers: shape (batch, 1).
            contexts_negatives: context + negative indices, right-padded
                with 0 to the longest row in the batch.
            masks: 1 for real (context or negative) positions, 0 for padding.
            labels: 1 for context positions, 0 for negatives and padding.
    """
    max_len = max(len(c) + len(n) for _, c, n in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        cur_len = len(context) + len(negative)
        pad_len = max_len - cur_len
        centers.append(center)
        contexts_negatives.append(context + negative + [0] * pad_len)
        # These two rows must be built per example. They were previously
        # outside the loop, so only the last example got a mask/label row.
        masks.append([1] * cur_len + [0] * pad_len)
        labels.append([1] * len(context) + [0] * (max_len - len(context)))

    return (torch.tensor(centers).reshape((-1, 1)), torch.tensor(
        contexts_negatives), torch.tensor(masks), torch.tensor(labels))

In [None]:
# Two toy (center, context, negative) examples with different lengths to show
# the padding produced by batchify.
x_1 = (1, [2, 2], [3, 3, 3, 3])
x_2 = (1, [2, 2, 2], [3, 3])
batch = batchify((x_1, x_2))

# Names of the four tensors, in the order batchify returns them.
names = ['centers', 'contexts_negatives', 'masks', 'labels']
for name, data in zip(names, batch):
    print(name, '=', data)

In [None]:
#@save
def load_data_ptb(batch_size, max_window_size, num_noise_words):
    """Download PTB and build the skip-gram training examples in memory.

    Args:
        batch_size (int): kept for interface compatibility; not used by this
            function, which returns raw example lists rather than a loader.
        max_window_size (int): forwarded to ``get_centers_and_contexts``.
        num_noise_words (int): forwarded to ``get_negatives`` as ``K``.

    Returns:
        tuple: ``(all_centers, all_contexts, all_negatives)``.
    """
    sentences = read_ptb()
    # NOTE(review): this uses d2l.Vocab while the exploratory cells build the
    # vocabulary with NLP.Vocab — confirm which one is intended.
    vocab = d2l.Vocab(sentences, min_freq=10)
    subsampled, counter = subsample(sentences, vocab)
    corpus = [vocab[line] for line in subsampled]
    all_centers, all_contexts = get_centers_and_contexts(
        corpus, max_window_size)
    all_negatives = get_negatives(
        all_contexts, vocab, counter, num_noise_words)
    return all_centers, all_contexts, all_negatives


class PTBDataset(torch.utils.data.Dataset):
    """Map-style dataset over parallel center / context / negative lists."""

    def __init__(self, centers, contexts, negatives):
        # All three sequences must have the same number of examples.
        lengths = {len(centers), len(contexts), len(negatives)}
        assert len(lengths) == 1
        self.centers = centers
        self.contexts = contexts
        self.negatives = negatives

    def __getitem__(self, index):
        """Return the ``(center, context, negative)`` triple at `index`."""
        center = self.centers[index]
        context = self.contexts[index]
        negative = self.negatives[index]
        return (center, context, negative)

    def __len__(self):
        """Number of examples, i.e. the number of center words."""
        return len(self.centers)


In [None]:
batch_size = 32

all_centers, all_contexts, all_negatives = load_data_ptb(
    batch_size=batch_size, max_window_size=5, num_noise_words=5)

dataset = PTBDataset(all_centers, all_contexts, all_negatives)

# batchify pads each mini-batch; the worker count comes from d2l's helper
# instead of a hard-coded number.
data_iter = torch.utils.data.DataLoader(
    dataset, batch_size, shuffle=True,
    collate_fn=batchify, num_workers=d2l.get_dataloader_workers())


In [None]:
# Print the shape of every tensor in the first mini-batch only.
# `names` is the list of tensor names defined in the batchify demo cell.
for batch in data_iter:
    for name, data in zip(names, batch):
        # No break here: an inner break would stop after the first tensor.
        print(name, 'shape:', data.shape)
    break