<a href="https://colab.research.google.com/github/hululuzhu/chinese-poem-search/blob/main/Chinese_Poem_Search_based_on_GuwenBERT_and_scaNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# An end-to-end Colab that searches for the most relevant Chinese Tang/Song dynasty poems given a query string

Sample

```
    输入：你看那长江的水从天上来
      1: 李白 [鼓吹曲辞将进酒]
      君不见黄河之水天上来，奔流到海不复回．．．
      2: 马之纯 [新亭其二]
      新亭见说在山头，看见江河衮衮流．．．
      3: 释善果 [偈其五]
      苏州有，常州有，吸尽西江只一口．．．


    输入：忆长安
      1: 宋祁 [农阁]
      ...看云记巫峡，望日省长安。...
      2: 徐凝 [寄白司马]
      ...争遣江州白司马，五年风景忆长安。
      3: 崔涂 [春晚怀进士韦澹]
      ...二年春怅望，不似在长安。
```

Some notes:

*   Based on [GuwenBERT](https://huggingface.co/ethanyt/guwenbert-base) (a [RoBERTa](https://arxiv.org/abs/1907.11692) language model pre-trained on Classical Chinese), [HF Transformers](https://github.com/huggingface/transformers) (model inference), and [Google ScaNN](https://github.com/google-research/google-research/tree/master/scann) (approximate nearest neighbor search).
*   Poems are fetched from the [chinese-poetry GitHub project](https://github.com/chinese-poetry/chinese-poetry) and split into sentence pieces.
*   The text is converted to simplified Chinese using the [chinese-converter package](https://github.com/zachary822/chinese-converter); skip this step if you prefer traditional Chinese.
*   The last hidden layer's output is used as the embedding to balance quality against memory constraints. The [literature](http://jalammar.github.io/illustrated-bert/) recommends concatenating the last 4 layers, but that needs more memory than this setup can afford.
*   This Colab runs successfully on a high-RAM GPU Colab instance (paid tier, Tesla P100 GPU with 16 GB of GPU RAM, 2-core CPU with 24 GB of RAM). If you run into OOM errors, reducing the `SAMPLE_SIZE` value will help, or consider upgrading to the paid Colab tier ($9.99 per month with an awesome GPU!).
*   The Colab takes about 1 hour to fully load, transform, and initialize; after that, inference should take <100 ms per call.



## Verify GPU is used



*   Embedding lookup is pretty slow on CPU
*   HuggingFace AutoModel does not work easily with TPU



In [25]:
!nvidia-smi -L

GPU 0: Tesla P100-PCIE-16GB (UUID: GPU-72982998-4d8c-0b67-3619-ed640a491dfa)


## Imports

In [2]:
import json
import urllib.request
!pip install -q "tqdm>=4.36.1" > /tmp/na
from tqdm.notebook import tqdm
!pip install chinese-converter > /tmp/na
import chinese_converter
import pickle
import os
import pandas as pd
import numpy as np
import gc
import sys
import re
!pip install -q transformers
from transformers import AutoTokenizer, AutoModel
import torch
!pip install -q scann
import scann
from google.colab import drive

## Fetch data from GitHub

In [3]:
# https://github.com/chinese-poetry/chinese-poetry
POEM_CONTENT = {
    'tang': {
        'total': 58,
        'pattern': "https://raw.githubusercontent.com/chinese-poetry/chinese-poetry/master/json/poet.tang.{0}.json"
    },
    'song': {
        'total': 255,
        'pattern': "https://raw.githubusercontent.com/chinese-poetry/chinese-poetry/master/json/poet.song.{0}.json"
    }
}

def get_poems(is_test=True, verbose=True):
  df_list = []
  for dynasty in POEM_CONTENT:
    size = 3 if is_test else POEM_CONTENT[dynasty]['total']
    pbar = tqdm(total=size, desc="Dynasty " + dynasty)
    for i in range(size):
      url = POEM_CONTENT[dynasty]['pattern'].format(i * 1000)
      if verbose:
        print(f"download {url} now")
      df_list.append(pd.read_json(url))
      pbar.update(1)
  return pd.concat(df_list)
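
To make the shard naming concrete, here is a minimal sketch that only builds the URLs the loop above would request, without downloading anything. `shard_urls` and `PATTERN` are hypothetical names introduced for illustration; the sketch assumes, as `get_poems` does, that each file is named by the offset of its first poem in steps of 1000.

```python
# Hypothetical helper (not part of the notebook): list the shard URLs
# implied by the URL pattern and the `i * 1000` offset used in get_poems.
PATTERN = ("https://raw.githubusercontent.com/chinese-poetry/chinese-poetry/"
           "master/json/poet.{dynasty}.{offset}.json")

def shard_urls(dynasty, total):
    """Return one URL per shard; shard i starts at poem offset i * 1000."""
    return [PATTERN.format(dynasty=dynasty, offset=i * 1000) for i in range(total)]

tang_urls = shard_urls("tang", 58)
print(tang_urls[0])   # first shard
print(tang_urls[-1])  # last shard
```

Printing the first and last URL is a cheap way to check the pattern before starting the full download.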

In [4]:
df = get_poems(is_test=False, verbose=False)


## Convert to simplified Chinese and clean up

In [5]:
df['concat_paragraphs'] = [''.join(map(str, l)) for l in df['paragraphs']]

In [6]:
def convert_schinese(tchinese):
  return chinese_converter.to_simplified(tchinese)

In [7]:
df['s_content'] = df.apply(lambda row: convert_schinese(''.join(row.concat_paragraphs)), axis=1)
df['s_title'] = df.apply(lambda row: convert_schinese(''.join(row.title)), axis=1)
df['s_author'] = df.apply(lambda row: convert_schinese(''.join(row.author)), axis=1)

In [8]:
my_df = df[['s_content', 's_title', 's_author']].copy()

In [9]:
SPLIT_STR = '。|！|？|\t|\n|\r'
converted_sents = []
pbar = tqdm(total=len(df), desc="break up into sentences")
for idx, row in df.iterrows():
  res = re.split(SPLIT_STR, row['s_content'])
  for s in res:
    if s.strip() != "":
      converted_sents.append({
          "s_content": row['s_content'],
          'piece': s,
          's_author': row['s_author'],
          's_title': row['s_title'],
      })
  pbar.update(1)

sent_pd = pd.DataFrame(converted_sents)
sent_pd['piece_len'] = sent_pd.piece.str.len()

HBox(children=(FloatProgress(value=0.0, description='break up into sentences', max=311860.0, style=ProgressSty…
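
As a small illustration of the splitting step, the sketch below applies the same `SPLIT_STR` pattern to a single well-known poem and then applies the length thresholds the notebook sets in the following cells. `split_pieces` is a hypothetical helper and the sample poem is only an example input. Because the comma `，` is not a delimiter, each couplet stays together as one piece.

```python
import re

SPLIT_STR = '。|！|？|\t|\n|\r'  # same delimiters as the cell above
MIN_LEN, MAX_LEN = 8, 30        # thresholds used later in the notebook

def split_pieces(content):
    """Split on sentence-ending punctuation and drop empty fragments."""
    return [s for s in re.split(SPLIT_STR, content) if s.strip() != ""]

poem = "床前明月光，疑是地上霜。举头望明月，低头思故乡。"
pieces = split_pieces(poem)
# Keep only pieces whose length falls inside the 'valid' range.
kept = [p for p in pieces if MIN_LEN <= len(p) <= MAX_LEN]
print(pieces)
print(kept)
```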

In [10]:
# Empirically set thresholds of 'valid' poem sentence pieces
MAX_SENTENCE_LENGTH = 30
MIN_SENTENCE_LENGTH = 8

In [11]:
clean_sent_pd = sent_pd[sent_pd.piece_len <= MAX_SENTENCE_LENGTH].copy()
clean_sent_pd = clean_sent_pd[clean_sent_pd.piece_len >= MIN_SENTENCE_LENGTH].copy()

In [12]:
my_df = clean_sent_pd.copy()
len(my_df)

1354444

In [13]:
# omit bad chars
OMIT_CHARS = "()（）[]［］●⿰〔〕〖〗［］Ｂ=/…「」x{}《》、”：0123456789○『』"

def trim_author_fn(row):
  return row.s_author[:4]

def trim_title_fn(row):
  trimed_title = row.s_title[:12].replace(" ", "").replace("(", "").replace(")", "")
  return trimed_title

def trim_piece_fn(row, feature):
  trimed_content = row[feature]
  for token in OMIT_CHARS:
    trimed_content = trimed_content.replace(token, "")
  return trimed_content

# Trim the size
my_df['s_author'] = my_df.apply(trim_author_fn, axis=1)
my_df['s_title'] = my_df.apply(trim_title_fn, axis=1)
my_df['s_content'] = my_df.apply(lambda r : trim_piece_fn(r, 's_content'), axis=1)
my_df['piece'] = my_df.apply(lambda r : trim_piece_fn(r, 'piece'), axis=1)

print("before filter:", len(my_df))
my_df = my_df[my_df.s_author != '']
my_df = my_df[my_df.s_title != '']
my_df = my_df[my_df.s_content != '']
my_df = my_df[my_df.piece != '']
my_df = my_df[~my_df.s_content.str.contains("□")] # unrecognized chars, ignore
print("after filter:", len(my_df))

before filter: 1354444
after filter: 1340301
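
The character-by-character `replace` loop in `trim_piece_fn` can also be written with `str.translate`, which removes all unwanted characters in one pass. This is a sketch of an alternative, not what the notebook runs; `DELETE_TABLE` and `strip_omitted` are hypothetical names, and the sample string is made up for illustration.

```python
# Same character set as OMIT_CHARS in the cell above.
OMIT_CHARS = "()（）[]［］●⿰〔〕〖〗［］Ｂ=/…「」x{}《》、”：0123456789○『』"

# str.maketrans with a third argument maps every listed character to None,
# i.e. deletes it when the table is passed to str.translate.
DELETE_TABLE = str.maketrans("", "", OMIT_CHARS)

def strip_omitted(text):
    """Remove every character listed in OMIT_CHARS from text."""
    return text.translate(DELETE_TABLE)

sample = "《将进酒》（其一）：君不见黄河之水天上来"
print(strip_omitted(sample))
```

Note that `OMIT_CHARS` also contains the Latin letter `x` and the ASCII digits, so those are stripped as well.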


## Load GuwenBERT with HuggingFace AutoModel

In [14]:
# See https://github.com/ethan-yt/guwenbert/blob/main/README_EN.md
tokenizer = AutoTokenizer.from_pretrained("ethanyt/guwenbert-base")
model = AutoModel.from_pretrained("ethanyt/guwenbert-base", output_hidden_states=True)
gpu_model = model.to('cuda')

Some weights of the model checkpoint at ethanyt/guwenbert-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.decoder.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Fetch all sentence embeddings by calling guwenBERT model

In [15]:
# Change to a smaller number such as 100k if you see OOM (or Colab just restarts)
SAMPLE_SIZE = len(my_df) # 100_000

In [16]:
# Concatenate the last n hidden layers and average over tokens, so each
# string gets a single embedding of size 768 * n.
# The literature says concatenating the last 4 layers is best; the last 1 is a good start here.
LAST_N = 1
last_layers = [-i for i in range(1, LAST_N + 1)]  # -1, -2, ....

def get_bert_embedding(in_strs):
  inputs = tokenizer(in_strs, padding=True, truncation=True,
      return_tensors="pt", max_length=MAX_SENTENCE_LENGTH+2).to('cuda:0')
  with torch.no_grad(): # absolutely required to avoid OOM
    outputs = gpu_model(**inputs)
  last_n_layers = [outputs['hidden_states'][i] for i in last_layers]
  cat_hidden_states = torch.cat(tuple(last_n_layers), dim=-1)
  cat_sentence_embedding = torch.mean(cat_hidden_states, dim=1).squeeze()
  return cat_sentence_embedding
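
One caveat with the plain mean over `dim=1`: because the tokenizer is called with `padding=True`, shorter strings in a batch are padded, and the hidden states at padded positions are averaged in too. A common alternative is to weight the average by the attention mask. The sketch below shows the idea with NumPy on toy data; it is not what the notebook does, and `masked_mean` is a hypothetical helper.

```python
import numpy as np

def masked_mean(hidden, mask):
    """Average token vectors per string, ignoring padded positions.

    hidden: (batch, tokens, dim) array; mask: (batch, tokens), 1 = real token.
    """
    weights = mask[:, :, np.newaxis].astype(hidden.dtype)
    summed = (hidden * weights).sum(axis=1)
    counts = np.maximum(weights.sum(axis=1), 1.0)  # avoid division by zero
    return summed / counts

# One string with two real tokens and one padded position.
hidden = np.array([[[1.0, 3.0], [3.0, 5.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
pooled = masked_mean(hidden, mask)
print(pooled)               # average of the two real tokens only
print(hidden.mean(axis=1))  # plain mean, pulled toward the padded vector
```

In the notebook's setting the mask would come from the tokenizer output (`inputs['attention_mask']`), assuming the same pooling were ported to PyTorch.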

In [17]:
# Fetch embeddings for 1.3M sentences

# About 15 mins to fetch embeddings for 1.3M poem sentence pieces,
# so roughly 1300k/1200 = 1.1k record inferences per second (Tesla P100).
# CPU is too slow, and TPU is not supported.
batch_size = 2048 # ~max batch size for gpu-ram=16G
all_pieces = list(my_df.piece.values)
all_pieces = all_pieces[:SAMPLE_SIZE]

# Collect per-batch arrays and concatenate once at the end; calling
# np.append inside the loop would copy the whole array on every batch.
embedding_chunks = []
pbar = tqdm(total=len(all_pieces), desc="Fetch embeddings")
for start_idx in range(0, len(all_pieces), batch_size):
  end_idx = min(start_idx + batch_size, len(all_pieces))
  embeddings = get_bert_embedding(all_pieces[start_idx:end_idx])
  # reshape keeps the batch dimension even if squeeze() dropped it
  # for a final batch that holds a single sentence
  embedding_chunks.append(
      embeddings.detach().cpu().numpy().reshape(end_idx - start_idx, -1))
  pbar.update(end_idx - start_idx)
  # Additional safeguard to avoid OOM
  gc.collect()
  torch.cuda.empty_cache()
out_embedding_np = np.concatenate(embedding_chunks, axis=0)


In [18]:
assert out_embedding_np.shape == (len(all_pieces), 768 * LAST_N)

## [optional] Persist embeddings for reuse to avoid recomputing them

In [19]:
# drive.mount('/content/gdrive')
# !mkdir -p /content/gdrive/MyDrive/ML/Data/ch_poem_search_sent_embed_20210727
# my_df.to_pickle('/content/gdrive/MyDrive/ML/Data/ch_poem_search_sent_embed_20210727/mydf.pickle')
# np.save('/content/gdrive/MyDrive/ML/Data/ch_poem_search_sent_embed_20210727/embedding.npy', out_embedding_np)
# !ls -l /content/gdrive/MyDrive/ML/Data/ch_poem_search_sent_embed_20210727

# To reload the embeddings and dataframe from the saved files
# (same mount point as above):
# my_df = pd.read_pickle('/content/gdrive/MyDrive/ML/Data/ch_poem_search_sent_embed_20210727/mydf.pickle')
# out_embedding_np = np.load('/content/gdrive/MyDrive/ML/Data/ch_poem_search_sent_embed_20210727/embedding.npy')

## ScaNN for fast nearest neighbor search

In [20]:
def create_searcher(normalized_embeddings, # size matters: 1/2 size results in 2/3 time
                    num_leaves=1000, # 1/2 size = 2/3 time, recommend ~sqrt(size)
                    num_leaves_to_search=100, # does not affect build time
                    reorder_size=20, # does not affect build time
                    min_partition_size=50, # does not affect build time
                    training_iterations=12, # slightly affects build time
                    neighbor_size=10, # does not affect build time
                    search_func="dot_product"): # seemed the best option of those tried
  """Creates a ScaNN searcher. Params were only briefly tested; comments may be wrong."""
  search_builder = scann.scann_ops_pybind.builder(
      normalized_embeddings, neighbor_size, search_func)
  search_builder = search_builder.tree(
      num_leaves=num_leaves,
      num_leaves_to_search=num_leaves_to_search,
      training_sample_size=normalized_embeddings.shape[0],
      min_partition_size=min_partition_size,
      training_iterations=training_iterations)
  search_builder = search_builder.score_ah(
      2, anisotropic_quantization_threshold=0.2)
  search_builder = search_builder.reorder(reorder_size)
  searcher = search_builder.build()
  return searcher
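
The `~sqrt(size)` recommendation in the `num_leaves` comment can be checked with a line of arithmetic. The sketch below uses the dataset size reported after filtering; `suggested_num_leaves` is a hypothetical helper that encodes only this rule of thumb, not a tuned value.

```python
import math

def suggested_num_leaves(num_points):
    """Rule of thumb from the comment above: about sqrt(dataset size)."""
    return round(math.sqrt(num_points))

suggested = suggested_num_leaves(1_340_301)
print(suggested)
```

The default of `num_leaves=1000` is therefore in the same range as the rule of thumb for this dataset.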

In [22]:
%%time
print("Est time: ", 2 * (out_embedding_np.shape[0] // 1000), "secs")
normalized_embedding_np = out_embedding_np / np.linalg.norm(out_embedding_np, axis=1)[:, np.newaxis]
scann_searcher = create_searcher(normalized_embedding_np)

Est time:  2680 secs
CPU times: user 29min 25s, sys: 11 s, total: 29min 36s
Wall time: 7min 55s
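
Because the embeddings are L2-normalized before indexing, the `dot_product` score equals cosine similarity. As a sanity check on a small sample, the approximate results can be compared with an exact brute-force search. The sketch below shows such a search with NumPy on toy vectors; `l2_normalize` and `exact_top_k` are hypothetical helpers and the data is made up.

```python
import numpy as np

def l2_normalize(x):
    """Scale each row to unit length, as done before building the searcher."""
    return x / np.linalg.norm(x, axis=1)[:, np.newaxis]

def exact_top_k(query, normalized_db, k):
    """Exact nearest neighbors by dot product on unit vectors (cosine similarity)."""
    q = query / np.linalg.norm(query)
    scores = normalized_db @ q
    top = np.argsort(-scores)[:k]  # indices of the k highest scores
    return top, scores[top]

db = l2_normalize(np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]))
idx, scores = exact_top_k(np.array([2.0, 1.0]), db, k=2)
print(idx, scores)
```

Normalizing the query only rescales the scores; it does not change the ranking.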


## Helper method for inference test

In [64]:
sents = my_df.piece.values
authors = my_df.s_author.values
titles = my_df.s_title.values
contents = my_df.s_content.values

def query_poem_now(input_str,
                   scann_searcher=scann_searcher,
                   size=10):
  """Fetch embedding, ask scann the nearest neighbors and format."""
  inp_embedding = get_bert_embedding(input_str)
  neighbors, distances = scann_searcher.search(
      inp_embedding.detach().cpu().numpy(), final_num_neighbors=size)
  print(f"输入：\x1b[34m{input_str}\x1b[0m")
  for rank, (n, d) in enumerate(zip(neighbors, distances), start=1):
    content = contents[n]
    # If the matched piece starts far into the poem, drop the leading text
    piece_pos = content.find(sents[n])
    if piece_pos > 40:
      content = "..." + content[piece_pos:]
    # Highlight the matched piece in red
    content = content.replace(sents[n], f"[\x1b[31m{sents[n]}\x1b[0m]")
    if len(content) > 80:
      content = content[:80] + "..."
    print(f"{rank}: {authors[n]} [{titles[n]}]\n{content}")
  print()
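
The display logic in `query_poem_now` can be tried in isolation. The sketch below reproduces it without the ANSI color codes so the result is easy to read and test; `format_snippet` is a hypothetical helper, and its defaults mirror the constants 40 and 80 used above.

```python
def format_snippet(content, piece, max_len=80, lead_limit=40):
    """Bracket the matched piece and trim long poems for display."""
    pos = content.find(piece)
    if pos > lead_limit:
        # The match is deep inside the poem: skip the text before it.
        content = "..." + content[pos:]
    content = content.replace(piece, f"[{piece}]")
    if len(content) > max_len:
        content = content[:max_len] + "..."
    return content

poem = "三条九陌花时节，万户千车看牡丹。争遣江州白司马，五年风景忆长安。"
snippet = format_snippet(poem, "争遣江州白司马，五年风景忆长安")
print(snippet)
```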

## Inference

In [81]:
%%time
for query in ["你看那长江的水从天上来", "忆长安", "月有圆缺", "九歌"]:
  query_poem_now(query, size=3)

输入：你看那长江的水从天上来
1: 李白 [鼓吹曲辞将进酒]
[君不见黄河之水天上来，奔流到海不复回]。君不见高堂明镜悲白发，朝如青丝暮成雪。人生得意须尽欢，莫使金尊空对月。天生我材必有用，千金散尽还复来...
2: 马之纯 [新亭其二]
[新亭见说在山头，看见江河衮衮流]。何事后人轻变改，不教遗址且存留。怜他一代称贤相，说此诸人似楚囚。若使有人来访旧，一番人见一番羞。
3: 释善果 [偈其五]
[苏州有，常州有，吸尽西江只一口]。百八数珠数不尽，须知天长与地久。腾今焕古作嘉祥，一一面南看北斗。

输入：忆长安
1: 宋祁 [农阁]
菌阁俯江干，西南蜀塞宽。[看云记巫峡，望日省长安]。钿崒峰头碧，霞皴荔子丹。比来秋物好，谁伴数凭栏。
2: 徐凝 [寄白司马]
三条九陌花时节，万户千车看牡丹。[争遣江州白司马，五年风景忆长安]。
3: 崔涂 [春晚怀进士韦澹]
故里花应尽，江楼夢尚残。半生吟欲过，一命达何难。特立珪无玷，相思草有兰。[二年春怅望，不似在长安]。

输入：月有圆缺
1: 王义山 [赠心月相士]
...[嗟彼天上月，有圆阙阴晴]。惟有心月月，天者常清明。持此以鉴人，妍媸奚所遁。此月不在天，月在尔方寸。
2: 释智愚 [偈颂二十五首其八]
[一年有十二箇月，每月一度团圆，其余尽是缺]。中间晦明出没，太半有不见者。惟有今宵，分外皎洁。无物堪比伦，教我如何说。
3: 崔萱 [古意]
灼灼叶中花，夏萎春又芳。[明明天上月，蟾缺圆复光]。未如君子情，朝违夕已忘。玉帐枕犹暖，纨扇思何长。愿因西南风，吹上玳瑁牀。娇眠锦衾裏，展转双...

输入：九歌
1: 欧阳修 [江上弹琴]
...[咏歌文王雅，怨刺离骚经]。二典意澹薄，三盘语丁宁。琴声虽可状，琴意谁可听。
2: 范成大 [浮湘行]
...[九歌凄悲不可听，愿赓楚调归和平]。
3: 宋庠 [屈原其二]
[司命湘君各有情，九歌愁苦荐新声]。如何不救沉江祸，枉解堂中许目成。

CPU times: user 51 ms, sys: 1.0