<a href="https://colab.research.google.com/github/hululuzhu/chinese-poem-search/blob/main/Chinese_Poem_Search_based_on_GuwenBERT_and_ScaNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# An end-to-end Colab to search for the most relevant Chinese Tang/Song dynasty poems given a query string

Sample

```
    输入：你看那长江的水从天上来
      1: 李白 [鼓吹曲辞将进酒]
      君不见黄河之水天上来，奔流到海不复回...
      2: 马之纯 [新亭其二]
      新亭见说在山头，看见江河衮衮流...
      3: 释善果 [偈其五]
      苏州有，常州有，吸尽西江只一口...


    输入：忆长安
      1: 宋祁 [农阁]
      ...看云记巫峡，望日省长安。...
      2: 徐凝 [寄白司马]
      ...争遣江州白司马，五年风景忆长安。
      3: 崔涂 [春晚怀进士韦澹]
      ...二年春怅望，不似在长安。
```

Some notes:

*   Based on [guwenBERT](https://huggingface.co/ethanyt/guwenbert-base) (a [RoBERTa](https://arxiv.org/abs/1907.11692) language model pre-trained on classical Chinese), [HF Transformers](https://github.com/huggingface/transformers) (model inference), and [Google ScaNN](https://github.com/google-research/google-research/tree/master/scann) (approximate nearest neighbor search)
*   Poems are fetched from the [chinese-poetry GitHub project](https://github.com/chinese-poetry/chinese-poetry) and split into sentence pieces
*   Text is converted to simplified Chinese with the [chinese-converter package](https://github.com/zachary822/chinese-converter); skip this step if you prefer traditional Chinese
*   Each piece is embedded as the token-averaged hidden states of the last layer, a trade-off between quality and memory; the [literature](http://jalammar.github.io/illustrated-bert/) recommends concatenating the last 4 layers, but that does not fit in memory here
*   The colab runs successfully on a high-RAM GPU Colab instance (paid tier: Tesla P100 GPU with 16G GPU RAM, 2-core CPU with 24G RAM). If you run into OOM issues, reduce `SAMPLE_SIZE`, or consider upgrading to a paid Colab tier ($9.99 per month with an awesome GPU!!).
*   The colab takes about 1 hour to fully load, transform, and initialize; after that, inference should take <100ms per call.
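The retrieval idea in these notes can be sketched without any libraries: L2-normalize each sentence embedding, then rank candidates by dot product, which for unit vectors equals cosine similarity (this is why the ScaNN searcher below is built on normalized embeddings with `dot_product`). This sketch does an exact brute-force scan over made-up 3-dimensional vectors; ScaNN approximates the same ranking so that it scales to ~1.3M 768-dimensional vectors. The function names are illustrative, not from the notebook.

```python
import math
from typing import List, Sequence, Tuple


def l2_normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (a zero vector is returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def brute_force_search(query: Sequence[float],
                       database: Sequence[Sequence[float]],
                       k: int = 3) -> List[Tuple[int, float]]:
    """Return (row index, cosine similarity) pairs for the k most similar rows."""
    q = l2_normalize(query)
    scores = []
    for idx, row in enumerate(database):
        r = l2_normalize(row)
        # Dot product of two unit vectors == cosine similarity
        scores.append((idx, sum(a * b for a, b in zip(q, r))))
    scores.sort(key=lambda pair: pair[1], reverse=True)
    return scores[:k]


# Toy 3-dimensional "embeddings" standing in for 768-dimensional BERT vectors
db = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.7, 0.7, 0.0]]
print(brute_force_search([1.0, 0.1, 0.0], db, k=2))
```

A brute-force scan is linear in the number of pieces per query, which is what the approximate tree search in ScaNN avoids.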



## Verify GPU is used



*   Embedding inference is very slow on CPU
*   Hugging Face AutoModel does not run easily on TPU



In [None]:
!nvidia-smi -L

GPU 0: Tesla P100-PCIE-16GB (UUID: GPU-68fe9d9b-6ec5-1df8-3117-cd0e19d71493)


## Imports

In [None]:
# Install dependencies first (quietly), then import
!pip install -q "tqdm>=4.36.1" > /tmp/na
!pip install -q chinese-converter > /tmp/na
!pip install -q transformers
!pip install -q scann
import gc
import json
import os
import pickle
import re
import sys
import urllib.request

import chinese_converter
import numpy as np
import pandas as pd
import scann
import tensorflow as tf
import torch
from google.colab import drive
from tensorboard.plugins import projector
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModel
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


## Fetch data from GitHub

In [None]:
# https://github.com/chinese-poetry/chinese-poetry, last update 04/18/2023
POEM_CONTENT = {
    'tang': {
        'total': 58,
        'pattern': "https://raw.githubusercontent.com/chinese-poetry/chinese-poetry/master/%E5%85%A8%E5%94%90%E8%AF%97/poet.tang.{0}.json"
    },
    'song': {
        'total': 255,
        'pattern': "https://raw.githubusercontent.com/chinese-poetry/chinese-poetry/master/%E5%85%A8%E5%94%90%E8%AF%97/poet.song.{0}.json"
    }
}

def get_poems(is_test=True, verbose=True):
  df_list = []
  for dynasty in POEM_CONTENT:
    size = 3 if is_test else POEM_CONTENT[dynasty]['total']
    pbar = tqdm(total=size, desc="Dynasty " + dynasty)
    for i in range(size):
      url = POEM_CONTENT[dynasty]['pattern'].format(i * 1000)
      if verbose:
        print(f"download {url} now")
      df_list.append(pd.read_json(url))
      pbar.update(1)
  return pd.concat(df_list)
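Each dynasty's poems are sharded across numbered JSON files whose names encode a starting offset in steps of 1000 (`poet.tang.0.json`, `poet.tang.1000.json`, ...), which is why `get_poems` formats the pattern with `i * 1000`. Below is a hypothetical helper, not part of the notebook, that lists the shard URLs in the same order without downloading anything; the `example.invalid` pattern is a placeholder.

```python
from typing import Dict, List


# Hypothetical helper: mirrors the URL logic in get_poems() without downloading.
def shard_urls(poem_content: Dict[str, dict], is_test: bool = False,
               step: int = 1000) -> List[str]:
    """List shard URLs in the order get_poems() would fetch them."""
    urls = []
    for dynasty, cfg in poem_content.items():
        # In test mode, get_poems() only fetches the first 3 shards per dynasty
        size = 3 if is_test else cfg['total']
        urls.extend(cfg['pattern'].format(i * step) for i in range(size))
    return urls


# Placeholder config with the same shape as POEM_CONTENT
demo_config = {
    'tang': {'total': 58, 'pattern': "https://example.invalid/poet.tang.{0}.json"},
}
urls = shard_urls(demo_config)
print(len(urls), urls[0], urls[-1])
```

Called with the notebook's `POEM_CONTENT`, this would list 58 + 255 = 313 URLs, which is a cheap way to sanity-check `total` before a long download.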

In [None]:
df = get_poems(is_test=False, verbose=False)


## Convert to simplified Chinese and clean up

In [None]:
df['concat_paragraphs'] = [''.join(map(str, l)) for l in df['paragraphs']]

In [None]:
def convert_schinese(tchinese):
  return chinese_converter.to_simplified(tchinese)

In [None]:
# concat_paragraphs, title and author are already strings, so map the converter over each column
df['s_content'] = df['concat_paragraphs'].map(convert_schinese)
df['s_title'] = df['title'].map(convert_schinese)
df['s_author'] = df['author'].map(convert_schinese)

In [None]:
my_df = df[['s_content', 's_title', 's_author']].copy()

In [None]:
SPLIT_STR = '。|！|？|\t|\n|\r'
converted_sents = []
pbar = tqdm(total=len(df), desc="break up into sentences")
for idx, row in df.iterrows():
  res = re.split(SPLIT_STR, row['s_content'])
  for s in res:
    if s.strip() != "":
      converted_sents.append({
          "s_content": row['s_content'],
          'piece': s,
          's_author': row['s_author'],
          's_title': row['s_title'],
      })
  pbar.update(1)

sent_pd = pd.DataFrame(converted_sents)
sent_pd['piece_len'] = sent_pd.piece.str.len()


In [None]:
# Empirically chosen length bounds (in characters) for a 'valid' poem sentence piece
MAX_SENTENCE_LENGTH = 30
MIN_SENTENCE_LENGTH = 8

In [None]:
# Series.between is inclusive on both ends, matching the original <= / >= filters
clean_sent_pd = sent_pd[sent_pd.piece_len.between(MIN_SENTENCE_LENGTH, MAX_SENTENCE_LENGTH)].copy()
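To make "sentence piece" concrete: `SPLIT_STR` breaks only on full stops, exclamation and question marks, and line breaks, so a piece usually keeps a whole couplet joined by its comma. A standalone sketch that applies the split and the length window from `MIN_SENTENCE_LENGTH` / `MAX_SENTENCE_LENGTH` in one function (the notebook does these in two separate DataFrame steps); the excerpt is from 将进酒.

```python
import re

SPLIT_STR = '。|！|？|\t|\n|\r'  # same pattern as the notebook
MIN_SENTENCE_LENGTH, MAX_SENTENCE_LENGTH = 8, 30  # same bounds as the notebook


def split_pieces(content: str) -> list:
    """Split a poem into non-blank pieces and keep those within the length window."""
    pieces = [s for s in re.split(SPLIT_STR, content) if s.strip() != ""]
    return [p for p in pieces
            if MIN_SENTENCE_LENGTH <= len(p) <= MAX_SENTENCE_LENGTH]


poem = "君不见黄河之水天上来，奔流到海不复回。君不见高堂明镜悲白发，朝如青丝暮成雪。"
for piece in split_pieces(poem):
    print(len(piece), piece)
```

Each couplet here is 18 characters including the comma, comfortably inside the 8 to 30 window, while a 3-character fragment such as "天上来" would be dropped.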

In [None]:
my_df = clean_sent_pd.copy()
len(my_df)

1354444

In [None]:
# Characters to strip from content and pieces (brackets, some punctuation, digits, stray glyphs)
OMIT_CHARS = "()（）[]［］●⿰〔〕〖〗［］Ｂ=/…「」x{}《》、”：0123456789○『』"

def trim_author_fn(row):
  return row.s_author[:4]

def trim_title_fn(row):
  trimmed_title = row.s_title[:12].replace(" ", "").replace("(", "").replace(")", "")
  return trimmed_title

def trim_piece_fn(row, feature):
  trimmed_content = row[feature]
  for token in OMIT_CHARS:
    trimmed_content = trimmed_content.replace(token, "")
  return trimmed_content

# Cap author (4 chars) and title (12 chars) length, and strip OMIT_CHARS from content and pieces
my_df['s_author'] = my_df.apply(trim_author_fn, axis=1)
my_df['s_title'] = my_df.apply(trim_title_fn, axis=1)
my_df['s_content'] = my_df.apply(lambda r : trim_piece_fn(r, 's_content'), axis=1)
my_df['piece'] = my_df.apply(lambda r : trim_piece_fn(r, 'piece'), axis=1)

print("before filter:", len(my_df))
my_df = my_df[my_df.s_author != '']
my_df = my_df[my_df.s_title != '']
my_df = my_df[my_df.s_content != '']
my_df = my_df[my_df.piece != '']
my_df = my_df[~my_df.s_content.str.contains("□")] # unrecognized chars, ignore
print("after filter:", len(my_df))

before filter: 1354444
after filter: 1340301
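`trim_piece_fn` calls `str.replace` once per character in `OMIT_CHARS`. An equivalent single pass uses a deletion table built with `str.maketrans`; this is an alternative sketch rather than what the notebook runs, and the sample string is made up.

```python
# Same character list as the notebook's OMIT_CHARS
OMIT_CHARS = "()（）[]［］●⿰〔〕〖〗［］Ｂ=/…「」x{}《》、”：0123456789○『』"

# The third argument to str.maketrans lists characters that map to None (deleted)
DELETE_TABLE = str.maketrans('', '', OMIT_CHARS)


def strip_omit_chars(text: str) -> str:
    """Delete every character listed in OMIT_CHARS in one pass."""
    return text.translate(DELETE_TABLE)


print(strip_omit_chars("「君不见」（黄河）之水123"))
```

With Pandas, the same table could be applied column-wise (for example via `Series.map(strip_omit_chars)`), avoiding a row-wise `apply`.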


## Load GuwenBERT with Hugging Face AutoModel

In [None]:
# See https://github.com/ethan-yt/guwenbert/blob/main/README_EN.md
tokenizer = AutoTokenizer.from_pretrained("ethanyt/guwenbert-base")
model = AutoModel.from_pretrained("ethanyt/guwenbert-base", output_hidden_states=True)
gpu_model = model.to('cuda')


Some weights of the model checkpoint at ethanyt/guwenbert-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.decoder.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Compute embeddings for all sentence pieces with GuwenBERT

In [None]:
# Lower this (e.g. to 100_000) if you hit OOM or the Colab runtime restarts
SAMPLE_SIZE = len(my_df) # 100_000

In [None]:
# Concatenate the hidden states of the last n layers along the feature axis,
# then average over tokens, so each string maps to a 768 * n vector.
# The literature suggests concatenating the last 4 layers; the last 1 is a good start here.
LAST_N = 1
last_layers = [-i for i in range(1, LAST_N + 1)]  # -1, -2, ....

def get_bert_embedding(in_strs):
  inputs = tokenizer(in_strs, padding=True, truncation=True,
      return_tensors="pt", max_length=MAX_SENTENCE_LENGTH+2).to('cuda:0')
  with torch.no_grad(): # absolutely required to avoid OOM
    outputs = gpu_model(**inputs)
  last_n_layers = [outputs['hidden_states'][i] for i in last_layers]
  cat_hidden_states = torch.cat(tuple(last_n_layers), dim=-1)
  # Average over real tokens only: mask out padded positions so an embedding
  # does not depend on the other strings in its batch
  mask = inputs['attention_mask'].unsqueeze(-1).to(cat_hidden_states.dtype)
  summed = (cat_hidden_states * mask).sum(dim=1)
  cat_sentence_embedding = (summed / mask.sum(dim=1).clamp(min=1)).squeeze()
  return cat_sentence_embedding
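Why the attention mask matters when averaging over tokens: with `padding=True`, shorter strings in a batch are padded up to the longest one, and a plain mean over the token axis would also fold in the hidden states at padded positions. A sentence's embedding would then depend on its batch-mates and differ from its unpadded, query-time embedding. A toy, dependency-free sketch of masked mean pooling; the vectors are made up, while real hidden states are 768-dimensional per layer.

```python
from typing import List, Sequence


def masked_mean(hidden: Sequence[Sequence[float]],
                mask: Sequence[int]) -> List[float]:
    """Average token vectors where mask == 1, ignoring padded positions."""
    dim = len(hidden[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(hidden, mask):
        if keep:
            total = [t + v for t, v in zip(total, vec)]
            count += 1
    # Guard against a row that is all padding
    return [t / max(count, 1) for t in total]


# Two real tokens followed by one padding position
hidden_states = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]
attention_mask = [1, 1, 0]
print(masked_mean(hidden_states, attention_mask))  # [2.0, 3.0]
print(masked_mean(hidden_states, [1, 1, 1]))       # plain mean, padding included
```

The masked result only sees the two real tokens; the plain mean is pulled toward the padding vector.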

In [None]:
# Fetch embeddings for ~1.3M poem sentence pieces
# Takes roughly 15-20 mins on a Tesla P100, i.e. about 1.1k-1.4k inferences per second
# (1300k / 1200 s ≈ 1.1k/s; 1300k / 900 s ≈ 1.4k/s). CPU is too slow and TPU is not supported.
batch_size = 2048 # ~max batch size for 16G of GPU RAM
all_pieces = list(my_df.piece.values)
all_pieces = all_pieces[:SAMPLE_SIZE]

# Collect per-batch arrays and concatenate once at the end: calling np.append
# inside the loop would copy the whole growing array on every batch
embedding_chunks = []
pbar = tqdm(total=len(all_pieces), desc="Fetch embeddings")
for start_idx in range(0, len(all_pieces), batch_size):
  end_idx = min(start_idx + batch_size, len(all_pieces))
  embeddings = get_bert_embedding(all_pieces[start_idx:end_idx])
  # atleast_2d: squeeze() flattens a final batch of size 1 to shape (768 * LAST_N,)
  embedding_chunks.append(np.atleast_2d(embeddings.detach().cpu().numpy()))
  pbar.update(end_idx - start_idx)
  # Additional safeguard to avoid OOM
  gc.collect()
  torch.cuda.empty_cache()
pbar.close()
out_embedding_np = np.concatenate(embedding_chunks, axis=0)


In [None]:
assert out_embedding_np.shape == (len(all_pieces), 768 * LAST_N)

## [Optional] Persist embeddings for reuse, to skip recomputing them

In [None]:
# drive.mount('/content/gdrive')
# !mkdir -p /content/gdrive/MyDrive/ML/Data/ch_poem_search_sent_embed_20210727
# my_df.to_pickle('/content/gdrive/MyDrive/ML/Data/ch_poem_search_sent_embed_20210727/mydf.pickle')
# np.save('/content/gdrive/MyDrive/ML/Data/ch_poem_search_sent_embed_20210727/embedding.npy', out_embedding_np)
# !ls -l /content/gdrive/MyDrive/ML/Data/ch_poem_search_sent_embed_20210727

# To reload the dataframe and embeddings from a previous save (mount Drive first, as above)
# my_df = pd.read_pickle('/content/gdrive/MyDrive/ML/Data/ch_poem_search_sent_embed_20210727/mydf.pickle')
# out_embedding_np = np.load('/content/gdrive/MyDrive/ML/Data/ch_poem_search_sent_embed_20210727/embedding.npy')

## Visualize 10k samples with TensorBoard

In [None]:
VIZ_SAMPLE = 10000
LOG_DIR = '/logs/embedding_projection/'  # TensorBoard log dir
# Sample from the embedded rows only: out_embedding_np covers just the first
# SAMPLE_SIZE rows of my_df, which can be fewer than len(my_df)
num_embedded = out_embedding_np.shape[0]
random_selected_mask = np.zeros(num_embedded, dtype=bool)
random_selected_mask[:min(VIZ_SAMPLE, num_embedded)] = True
np.random.shuffle(random_selected_mask)
embedding_samples = out_embedding_np[random_selected_mask]
label_samples = my_df.piece.values[:num_embedded][random_selected_mask]
assert embedding_samples.shape[0] == len(label_samples)

In [None]:
# see hack from https://github.com/tensorflow/tensorboard/issues/2471#issuecomment-580423961
def register_embedding(embedding_tensor_name, meta_data_fname, log_dir):
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = embedding_tensor_name
    embedding.metadata_path = meta_data_fname
    projector.visualize_embeddings(log_dir, config)

def save_labels_tsv(labels, filepath, log_dir):
    with open(os.path.join(log_dir, filepath), 'w') as f:
        for label in labels:
            f.write('{}\n'.format(label))

os.makedirs(LOG_DIR, exist_ok=True)
META_DATA_FNAME = 'meta.tsv'  # Labels will be stored here
EMBEDDINGS_TENSOR_NAME = 'embeddings'
EMBEDDINGS_FPATH = os.path.join(LOG_DIR, EMBEDDINGS_TENSOR_NAME + '.ckpt')
STEP = 0

register_embedding(EMBEDDINGS_TENSOR_NAME, META_DATA_FNAME, LOG_DIR)
save_labels_tsv(label_samples, META_DATA_FNAME, LOG_DIR)

tensor_embeddings = tf.Variable(embedding_samples, name=EMBEDDINGS_TENSOR_NAME)
saver = tf.compat.v1.train.Saver([tensor_embeddings])  # Must pass list or dict
saver.save(sess=None, global_step=STEP, save_path=EMBEDDINGS_FPATH)



'/logs/embedding_projection/embeddings.ckpt-0'

In [None]:
# Once TensorBoard loads, open the Projector tab; UMAP seemed to give the best results
%tensorboard --logdir {LOG_DIR}

## ScaNN for fast nearest neighbor search

In [None]:
def create_searcher(normalized_embeddings, # database size matters: half the rows took ~2/3 of the build time
                    num_leaves=1000, # half the leaves took ~2/3 of the build time; ~sqrt(#rows) is recommended
                    num_leaves_to_search=100, # does not affect build time
                    reorder_size=20, # does not affect build time
                    min_partition_size=50, # does not affect build time
                    training_iterations=12, # slightly affects build time
                    neighbor_size=10, # does not affect build time
                    search_func="dot_product"): # worked best among the options tried
  """Creates a ScaNN searcher. Params were only briefly tested; comments may be wrong."""
  search_builder = scann.scann_ops_pybind.builder(
      normalized_embeddings, neighbor_size, search_func)
  search_builder = search_builder.tree(
      num_leaves=num_leaves,
      num_leaves_to_search=num_leaves_to_search,
      training_sample_size=normalized_embeddings.shape[0],
      min_partition_size=min_partition_size,
      training_iterations=training_iterations)
  search_builder = search_builder.score_ah(
      2, anisotropic_quantization_threshold=0.2)
  search_builder = search_builder.reorder(reorder_size)
  searcher = search_builder.build()
  return searcher
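The `num_leaves` / `num_leaves_to_search` pair controls a partition-then-search step: the database is split into clusters ("leaves"), and each query only scores points in its closest leaves. The sketch below illustrates just that idea in plain Python, using fixed, hand-picked centroids instead of trained ones and leaving out ScaNN's anisotropic quantization (`score_ah`) and reordering, so it is a simplified stand-in, not ScaNN's algorithm. All names are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def build_partitions(database: Sequence[Sequence[float]],
                     centroids: Sequence[Sequence[float]]) -> Dict[int, List[int]]:
    """Assign each database row to the centroid with the highest dot product."""
    leaves: Dict[int, List[int]] = {c: [] for c in range(len(centroids))}
    for idx, row in enumerate(database):
        best = max(range(len(centroids)), key=lambda c: dot(row, centroids[c]))
        leaves[best].append(idx)
    return leaves


def partitioned_search(query, database, centroids, leaves,
                       num_leaves_to_search: int, k: int) -> List[Tuple[int, float]]:
    """Score only rows inside the num_leaves_to_search closest leaves."""
    ranked_leaves = sorted(range(len(centroids)),
                           key=lambda c: dot(query, centroids[c]), reverse=True)
    candidates = [i for c in ranked_leaves[:num_leaves_to_search] for i in leaves[c]]
    scored = sorted(((i, dot(query, database[i])) for i in candidates),
                    key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy unit vectors and two hand-picked centroids
db = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0], [0.6, 0.8]]
centroids = [[1.0, 0.0], [0.0, 1.0]]
leaves = build_partitions(db, centroids)
query = [0.6, 0.8]
print(partitioned_search(query, db, centroids, leaves, num_leaves_to_search=1, k=2))
print(partitioned_search(query, db, centroids, leaves, num_leaves_to_search=2, k=2))
```

Searching one leaf misses row 1 (the true second-best match, which sits in the other partition); searching both leaves recovers the exact top 2. That recall-versus-speed trade-off is what raising `num_leaves_to_search` buys at query time.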

In [None]:
%%time
# Rough heuristic of ~2 s per 1k rows; the measured wall time below was far shorter
print("Est time: ", 2 * (out_embedding_np.shape[0] // 1000), "secs")
normalized_embedding_np = out_embedding_np / np.linalg.norm(out_embedding_np, axis=1)[:, np.newaxis]
scann_searcher = create_searcher(normalized_embedding_np)

Est time:  2680 secs
CPU times: user 31min 15s, sys: 11.3 s, total: 31min 26s
Wall time: 8min 17s


## Helper method for inference

In [None]:
sents = my_df.piece.values
authors = my_df.s_author.values
titles = my_df.s_title.values
contents = my_df.s_content.values

def query_poem_now(input_str,
                   scann_searcher=scann_searcher,
                   size=10):
  """Fetch the query embedding, ask ScaNN for the nearest neighbors, and print them."""
  inp_embedding = get_bert_embedding(input_str)
  # The query is not normalized: scaling it multiplies every dot-product score
  # by the same factor, so the ranking is unaffected
  neighbors, distances = scann_searcher.search(
      inp_embedding.detach().cpu().numpy(), final_num_neighbors=size)
  print(f"输入：\x1b[34m{input_str}\x1b[0m")  # 输入 = "input", shown in blue
  for rank, n in enumerate(neighbors, start=1):
    content = contents[n]
    pos = content.find(sents[n])
    if pos > 40:  # matched piece starts deep in the poem: elide the prefix
      content = "..." + content[pos:]
    # Highlight the matched piece in red
    content = content.replace(sents[n], f"[\x1b[31m{sents[n]}\x1b[0m]")
    if len(content) > 80:
      content = content[:80] + "..."
    print(f"{rank}: {authors[n]} [{titles[n]}]\n{content}")
  print()

## Inference

In [None]:
%%time
for query in ["你看那长江的水从天上来", "忆长安", "月有圆缺", "九歌"]:
  query_poem_now(query, size=3)

输入：你看那长江的水从天上来
1: 李白 [鼓吹曲辞将进酒]
[君不见黄河之水天上来，奔流到海不复回]。君不见高堂明镜悲白发，朝如青丝暮成雪。人生得意须尽欢，莫使金尊空对月。天生我材必有用，千金散尽还复来...
2: 马之纯 [新亭其二]
[新亭见说在山头，看见江河衮衮流]。何事后人轻变改，不教遗址且存留。怜他一代称贤相，说此诸人似楚囚。若使有人来访旧，一番人见一番羞。
3: 释善果 [偈其五]
[苏州有，常州有，吸尽西江只一口]。百八数珠数不尽，须知天长与地久。腾今焕古作嘉祥，一一面南看北斗。

输入：忆长安
1: 宋祁 [农阁]
菌阁俯江干，西南蜀塞宽。[看云记巫峡，望日省长安]。钿崒峰头碧，霞皴荔子丹。比来秋物好，谁伴数凭栏。
2: 裴说 [咏鹦鹉]
常贵西山鸟，衔恩在玉堂。语传明主意，衣拂美人香。缓步寻珠网，高飞上画梁。[长安频道乐，何日从君王]。
3: 徐凝 [寄白司马]
三条九陌花时节，万户千车看牡丹。[争遣江州白司马，五年风景忆长安]。

输入：月有圆缺
1: 王义山 [赠心月相士]
...[嗟彼天上月，有圆阙阴晴]。惟有心月月，天者常清明。持此以鉴人，妍媸奚所遁。此月不在天，月在尔方寸。
2: 崔萱 [古意]
灼灼叶中花，夏萎春又芳。[明明天上月，蟾缺圆复光]。未如君子情，朝违夕已忘。玉帐枕犹暖，纨扇思何长。愿因西南风，吹上玳瑁牀。娇眠锦衾裏，展转双...
3: 李曾伯 [题冯司法水月书堂其一]
[水能清亦能浊，月有满亦有亏]。清浊满亏区别，冯君书以知之。

输入：九歌
1: 范成大 [浮湘行]
...[九歌凄悲不可听，愿赓楚调归和平]。
2: 宋庠 [屈原其二]
[司命湘君各有情，九歌愁苦荐新声]。如何不救沉江祸，枉解堂中许目成。
3: 不详 [郊庙歌辞中宗祀昊天乐章]
[九成爰奏，三献式陈]。钦承景福，恭托明禋。

CPU times: user 54.7 ms, sys: 36 µs, total: 54.7 ms
Wall time: 57.9 ms


In [None]:
%%time
for query in ['今日看到美女，惊呆了！', '想骂人', '我要打人',
              '祝贺朋友结婚', '做人要开心', '载歌在谷', '自由而无用', ]:
  query_poem_now(query, size=5)

输入：今日看到美女，惊呆了！
1: 释广闻 [行素长老请自赞]
无狮子教儿之诀，无老猫上树之机。世无所容其拙，人或见谓之慈。知之已熟，画出何爲。[后三十年扬在𡏖𡒁堆头，便令无光怪发现，未免起傍观按剑之疑]。
2: 王大烈 [绝句]
弧矢重悬旧礼仪，郎君又産玉麟儿。[便烦着眼从头看，的似徐卿二子奇]。
3: 陆游 [春雨绝句六首其三]
天公似欲败蚕麰，雨冒南山暮不收。[呆女癡儿那念此，贪看科斗满清沟]。
4: 范纯仁 [安之家庭甘结实三首其]
遮藏霜雪免摧残，绿实初垂未满栏。[想见主人珍惜意，一回出户一回看]。
5: 释宗琏 [赞月庵]
...[只因会春园裏失却眼睛，从此恶名滔滔流遍天下]。高挂虚堂兮如师子全威，一任百怪千妖暗中惊讶。灯禅灯禅，第一不得容易与伊点化。

输入：想骂人
1: 释道济 [与张提点共饮席间作]
每日终朝醉似泥，未尝一日不昏迷。[细君发怒将言骂，道是人间吃酒儿]。莫要管，你休癡，人生能有几多时。杜康曾唱莲花落，刘伶好饮舞啰哩，陶渊明赏菊...
2: 唐士耻 [两溪]
...[叱石而兴真戯耳，可羡人家好弟兄]。西则三洞足仙灵，奇哉石穴声铿铿。瀑泉喷写资照耀，俨如冰蚕之所成。夕阳斜带九华度，何媿芙蓉峥且嵘。鼎湖...
3: 杨万里 [晨炊翫鞭亭二首其二]
[问着无声是阿兄，坐看家贼只吞声]。戮尸大放经纶手，长柄判将锡茂弘。
4: 陈师道 [赠知命]
黑头居士元方弟，不肯作公称法嗣。[外人怪笑那得知，他日灵山亲授记]。学诗初学杜少陵，学书不学王右军。黄尘扶杖笑邻女，白衫骑馿惊市人。静中作业此...
5: 赵必𤩪 [赠相士桂月岩]
昔有道人张岩电，口不言钱似王衍。[却言姓钱不姓张，滑稽玩世名犹香]。昔有岩电今月岩，月眼神舌和天谈。江湖剩结贵人知，虽不爱钱犹爱诗。

输入：我要打人
1: 释师观 [颂古十七首其六]
[张打油，李打油，不打浑身只打头]。今朝有酒今朝醉，明日愁来明日愁。
2: 释师体 [偈颂十八首其一○]
[一拳也是打爷来，未有输赢莫放开]。割舍拍盲穷性命，觜喎鼻塌见全材。
