# 寻找最近邻 embedding

我的构想是，拿到红楼梦里所有词汇的embedding，然后看哪个词离我们感兴趣的词（比如：林黛玉）最近

In [1]:
import re
import collections

from transformers import BertModel, BertTokenizer
import numpy as np
import torch
import jieba

jieba.load_userdict('./data/user_dict.txt')

Building prefix dict from the default dictionary ...
Loading model from cache /var/folders/0v/110wmd1964s9xk3hg_ty7hnh0000gn/T/jieba.cache
Loading model cost 0.253 seconds.
Prefix dict has been built successfully.


In [2]:
CN_BERT_PATH = './data/bert-base-chinese'
CN_BOOK_PATH = './data/红楼梦.txt'
CN_STOP_WORDS = './data/cn_stopwords.txt'
MIN_FREQ = 100

## 1. 分词

首先做文本预处理，对《红楼梦》做分词。

In [3]:
# 加载停用词
def load_stop_words(stop_words_path):
    with open(stop_words_path, 'r') as f:
        stop_words = f.read()
    return stop_words.split('\n')

In [4]:
# 对《红楼梦》做分词
def preprocess(book_path, cn_stop_words_path):
    with open(book_path, 'r') as f:
        content = f.read()
        
        # 删除 \n \u3000 \u3000
        pattern = re.compile(r'(\n|\u3000|\u3000)', re.IGNORECASE)
        text = pattern.sub('', content)
        
        # 加载中文停用词
        cn_stop_words = load_stop_words(cn_stop_words_path)
        
        # 计算分词
        corpus = [w for w in jieba.lcut(text)
                  if w not in cn_stop_words and len(w) > 1]

        return corpus

In [5]:
corpus = preprocess(book_path=CN_BOOK_PATH,
                    cn_stop_words_path=CN_STOP_WORDS)
len(corpus)

198712

In [6]:
# 词频统计
ctr = collections.Counter(corpus)

# 高频词
ctr.most_common()[:5]

[('宝玉', 3762), ('一个', 1356), ('贾母', 1323), ('凤姐', 1202), ('王夫人', 1020)]

In [7]:
# 低频词
ctr.most_common()[-5:]  

[('抄者', 1), ('游戏笔墨', 1), ('陶情适性', 1), ('曾题', 1), ('更进一竿', 1)]

In [8]:
# 过滤出现次数过低的词
n_corpus = [k for k, v in ctr.items() if v > MIN_FREQ]
len(n_corpus)

225

In [9]:
# 看一下最长的词汇
max_length = max([len(e) for e in n_corpus])
print('max_length:', max_length)

[e for e in n_corpus if len(e) == max_length][:3]

max_length: 4


['下回分解']

## 2. 批量计算 embedding

In [10]:
tokenizer = BertTokenizer.from_pretrained(CN_BERT_PATH)
model = BertModel.from_pretrained(CN_BERT_PATH)

Some weights of the model checkpoint at ./data/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
# 批量计算 embeddings
def get_embeddings(corpus):
    encoded_inputs = tokenizer(corpus,
                               padding='max_length',
                               truncation=True,
                               return_tensors='pt')

    with torch.no_grad():
        outputs = model(**encoded_inputs)
        embeddings = outputs.last_hidden_state.mean(dim=1)
        # embeddings = outputs.last_hidden_state[:, 0, :]

    return embeddings

In [12]:
embeddings = get_embeddings(corpus=n_corpus[:3])
embeddings.shape

torch.Size([3, 768])

In [13]:
embeddings

tensor([[-0.8313,  0.6034,  0.0404,  ...,  0.2570,  0.2097,  0.1184],
        [-0.8494,  0.4742, -0.7889,  ...,  0.5198,  0.0041,  0.0336],
        [-0.7053,  0.1953, -0.2742,  ...,  0.4208,  0.0719, -0.1333]])

## 3. 计算每个词的 embedding

现在，我们可以计算每个词的 embedding 了。

并且把 分词 和 embedding 存成如下形式：

```python
[
    ['甄士隐'， [-0.9199, ... , -0.2456,  0.0457]],
    ['梦幻'， [-0.7423, ... , -0.3301, -0.1229]],
    ...
]
```

In [14]:
word_embeddings = get_embeddings(corpus=n_corpus)
len(word_embeddings)

225

In [15]:
corpus2embeddings = [[k, v] for k, v in zip(n_corpus, [np.array(e) for e in word_embeddings])]
len(corpus2embeddings)

225

## 4. 计算我们关心词汇的近邻 embedding

In [16]:
def nearest_embedding(embedding, embeddings):
    squared_distances = np.sum((embeddings - embedding) ** 2, axis=1)
    nearest_idx = np.argmin(squared_distances)

    return embeddings[nearest_idx], nearest_idx

def nearest_word(word, c2e):
    word_list = [e[0] for e in c2e]
    if word not in word_list:
        raise Exception(f'{word} not in word list.')

    embedding = [v for k, v in c2e if k == word][0]
    embeddings = [v for k, v in c2e if k != word]
    _, idx = nearest_embedding(embedding, embeddings)

    return word_list[idx]

In [17]:
nearest_word(word='黛玉',
             c2e=corpus2embeddings)

'孩子'

In [18]:
nearest_word(word='凤姐',
             c2e=corpus2embeddings)

'湘云'

## 5. 整合成一个类

整合以上功能，写成一个类

In [19]:
class Corpus:

    def __init__(self, model_path, book_path, stopwords_path, min_freq):
        self.model_path = model_path
        self.book_path = book_path
        self.stopwords_path = stopwords_path
        self.min_freq = min_freq

        # 加载模型
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertModel.from_pretrained(model_path)

    def cut(self):
        """分词"""
        corpus = preprocess(self.book_path,
                            self.stopwords_path)
        ctr = collections.Counter(corpus)
        return [k for k, v in ctr.items() if v > self.min_freq]

    def get_embeddings(self, corpus):
        """批量计算 embeddings"""
        encoded_inputs = self.tokenizer(corpus,
                                        padding='max_length',
                                        truncation=True,
                                        return_tensors='pt')

        with torch.no_grad():
            outputs = self.model(**encoded_inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)

        return embeddings
    
    @staticmethod
    def corpus2embeddings(corpus, embeddings):
        """存储 词汇 与 embedding 的映射关系"""
        np_embeddings = [np.array(e) for e in embeddings]
        return [[k, v] for k, v in zip(corpus, np_embeddings)]

    @staticmethod
    def nearest_embedding(embedding, embeddings):
        """计算近邻 embedding"""
        squared_distances = np.sum((embeddings - embedding) ** 2, axis=1)
        nearest_idx = np.argmin(squared_distances)

        return embeddings[nearest_idx], nearest_idx

    def nearest_word(self, word, c2e):
        """计算近邻 词汇"""
        word_list = [e[0] for e in c2e]
        if word not in word_list:
            raise Exception(f'{word} not in word list.')

        embedding = [v for k, v in c2e if k == word][0]
        embeddings = [v for k, v in c2e if k != word]
        _, idx = self.nearest_embedding(embedding, embeddings)
        
        return word_list[idx]

    def test(self):
        corpus = self.cut()
        embeddings = self.get_embeddings(corpus)
        c2e = self.corpus2embeddings(corpus, embeddings)

        word = '宝玉'
        nearest_word = self.nearest_word(word, c2e)
        print('nearest_word:', nearest_word)

In [20]:
c = Corpus(model_path=CN_BERT_PATH,
           book_path=CN_BOOK_PATH,
           stopwords_path=CN_STOP_WORDS,
           min_freq=MIN_FREQ)
c.test()

Some weights of the model checkpoint at ./data/bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


nearest_word: 林之孝
