<a href="https://colab.research.google.com/github/hululuzhu/chinese-ai-writing-share/blob/main/training/t5_finetune/2023_T5_Finetune_Chinese_Poem_Writing_V1_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# T5 写诗
- 设计：Pretrained T5 + “写诗 prompt” fine-tuning
  - 对比我的 [transformer training from scratch](https://github.com/hululuzhu/chinese-ai-writing-share/blob/main/%E4%B8%AD%E6%96%87%E5%86%99%E8%AF%97Transformer_Source_Code_Share_V1.ipynb)
  - 想要加入作者作为可选输入
    - 每个文章分两次输入，一次作者名字，一次“None”名字（通用）
- 数据：[诗歌github](https://github.com/chinese-poetry/chinese-poetry)
- 相关内容
  - [Huggingface](https://huggingface.co/)
  - Langboat (澜舟科技) Chinese [Mengzi T5 pretrained model](https://huggingface.co/Langboat/mengzi-t5-base) and [paper](https://arxiv.org/pdf/2110.06696.pdf)
  - [SimpleT5 by Shivanandroy](https://github.com/Shivanandroy/simpleT5) (on top of pytorch and pytorch lightning) and [his awesome medium article](https://medium.com/geekculture/simplet5-train-t5-models-in-just-3-lines-of-code-by-shivanand-roy-2021-354df5ae46ba)
- 进度
  - 02/2023: shortened the source text (prompt) to make it more concise
    - Keep only poems whose lines all have the same length (5 or 7 characters)
  - New model to be trained by 03/2023

## Load Data

In [None]:
!nvidia-smi

In [None]:
IS_TEST_FLOW = False  #@param {type: "boolean"}

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import json
import urllib.request
import pandas as pd
!pip install -q "tqdm>=4.36.1" > /tmp/na
from tqdm.notebook import tqdm
!pip install -q chinese-converter > /tmp/na
import chinese_converter  # 繁体到简体需要
import pickle
import os
import pandas as pd
import numpy as np

In [None]:
# https://github.com/chinese-poetry/chinese-poetry
# Per-dynasty download settings read by get_poems():
#   total   - number of JSON shards fetched for the dynasty
#   pattern - shard URL template; `{0}` is filled with the shard offset
#             (shard index * 1000, i.e. 0, 1000, 2000, ...)
POEM_CONTENT = {
    'tang': {
        'total': 58,
        'pattern': "https://raw.githubusercontent.com/chinese-poetry/chinese-poetry/master/json/poet.tang.{0}.json"
    },
    'song': {
        'total': 255,
        'pattern': "https://raw.githubusercontent.com/chinese-poetry/chinese-poetry/master/json/poet.song.{0}.json"
    }
}

def get_poems(is_test=True, verbose=True):
  """Download the poem JSON shards listed in POEM_CONTENT as one DataFrame.

  Args:
    is_test: when true, fetch only the first 3 shards of each dynasty
      instead of the full `total`.
    verbose: when true, print every URL before it is downloaded.

  Returns:
    The concatenation of all downloaded shards. The index is not reset,
    so it is whatever `pd.concat` produces from the per-shard frames.
  """
  df_list = []
  for dynasty, spec in POEM_CONTENT.items():
    size = 3 if is_test else spec['total']
    # The context manager closes the progress bar even when a download
    # raises, so a failed run does not leave a dangling bar behind.
    with tqdm(total=size, desc="Dynasty " + dynasty) as pbar:
      for i in range(size):
        url = spec['pattern'].format(i * 1000)
        if verbose:
          print(f"download {url} now")
        df_list.append(pd.read_json(url))
        pbar.update(1)
  return pd.concat(df_list)

In [None]:
df = get_poems(is_test=IS_TEST_FLOW, verbose=False)
# `paragraphs` holds a sequence of lines per poem; flatten each into one string.
df['concat_paragraphs'] = [
    ''.join(str(line) for line in lines) for lines in df['paragraphs']]
df = df[['author', 'title', 'concat_paragraphs']]

def convert_schinese(tchinese):
  """Return `tchinese` converted to simplified Chinese."""
  return chinese_converter.to_simplified(tchinese)

# (new simplified column, source column), in the order the columns are added.
for simplified_col, source_col in (
    ('s_content', 'concat_paragraphs'),
    ('s_title', 'title'),
    ('s_author', 'author'),
):
  df[simplified_col] = [
      convert_schinese(''.join(value)) for value in df[source_col]]

my_df = df
print("my_df size", len(my_df))

In [None]:
# Length limits, in characters, applied when trimming/filtering poems.
MAX_AUTHOR_CHAR = 4
MAX_TITLE_CHAR = 12
MIN_CONTENT_CHAR = 20
MAX_CONTENT_CHAR = 64
# Every character in this string is deleted from titles and contents.
BAD_TOKENS = " ()[]《》（）□{}abcdefgxyz一"
# Deletion table built once, so all bad characters are removed in a single
# pass instead of one str.replace() call per character.
_BAD_TOKEN_TABLE = str.maketrans("", "", BAD_TOKENS)

def trim_author_fn(row):
  """Return the first MAX_AUTHOR_CHAR characters of `row.s_author`."""
  return row.s_author[:MAX_AUTHOR_CHAR]

def trim_title_fn(row):
  """Return `row.s_title` cut to MAX_TITLE_CHAR, then stripped of BAD_TOKENS.

  The cut happens before the stripping, so the result can be shorter than
  MAX_TITLE_CHAR (and may be empty).
  """
  return row.s_title[:MAX_TITLE_CHAR].translate(_BAD_TOKEN_TABLE)

def trim_content_fn(row):
  """Return `row.s_content` trimmed so that it ends on a full stop.

  The content is cut to MAX_CONTENT_CHAR, stripped of BAD_TOKENS, and then
  truncated right after its last "。" so the model never sees a partial
  final sentence. Returns "" when no "。" remains.
  """
  trimmed = row.s_content[:MAX_CONTENT_CHAR].translate(_BAD_TOKEN_TABLE)
  # rfind() yields -1 when there is no full stop, making the slice empty.
  last_period = trimmed.rfind("。")
  return trimmed[:last_period + 1]

# Add the trimmed author/title/content columns. The trim helpers only read
# each row, so apply() runs on my_df directly: copying the frame handed to
# apply() has no effect on the assignment target and only cost three full
# copies of the data.
my_df['s_author_trim'] = my_df.apply(trim_author_fn, axis=1)
my_df['s_title_trim'] = my_df.apply(trim_title_fn, axis=1)
my_df['s_content_trim'] = my_df.apply(trim_content_fn, axis=1)

print("my_df size", len(my_df))

In [None]:
# Drop rows that are unusable for training:
#   - the trimmed title is empty,
#   - the trimmed content has at most MIN_CONTENT_CHAR characters,
#   - the content or author is the placeholder '无正文'.
empty_title_mask = my_df['s_title_trim'].str.len().eq(0)
too_short_cotent_mask = my_df['s_content_trim'].str.len().le(MIN_CONTENT_CHAR)
invalid_mask = my_df[['s_content_trim', 's_author_trim']].eq('无正文').any(axis=1)
too_short_mask = empty_title_mask | too_short_cotent_mask | invalid_mask

# Keep the surviving rows and only the three trimmed columns.
my_df = my_df.loc[
    ~too_short_mask, ['s_author_trim', 's_title_trim', 's_content_trim']]
print("my_df size", len(my_df))

In [None]:
import re

# Punctuation that separates the clauses (lines) of a poem.
CLAUSE_SPLIT_RE = re.compile(r'，|。|？')
# Only poems whose clauses are all 5 or all 7 characters long are kept.
ALLOWED_CLAUSE_LENS = (5, 7)

def is_aligned_poem(content):
  """Return True when every non-blank clause of `content` has the same
  length and that length is in ALLOWED_CLAUSE_LENS.

  A poem without any non-blank clause is rejected rather than raising.
  """
  clause_lens = {
      len(clause)
      for clause in CLAUSE_SPLIT_RE.split(content)
      if clause.strip() != ''
  }
  # Exactly one distinct length <=> max(lens) == min(lens) on a non-empty list.
  return len(clause_lens) == 1 and next(iter(clause_lens)) in ALLOWED_CLAUSE_LENS

# Positional boolean mask, so the selection does not depend on index labels.
aligned_mask = my_df['s_content_trim'].map(is_aligned_poem).to_numpy(dtype=bool)
# Rebuild with a fresh 0..n-1 index and only the three trimmed columns.
my_df = my_df.loc[
    aligned_mask, ['s_author_trim', 's_title_trim', 's_content_trim']
].reset_index(drop=True)
print("left", len(my_df))

In [None]:
my_df.sample(100)

In [None]:
# Prompt prefixes placed before the author name and the title, and the
# separator inserted between the title part and the author part.
AUTHOR_PROMPT = "模仿："
TITLE_PROMPT = "作诗："
EOS_TOKEN = '</s>'
def build_dataset_df(df, include_author=True):
  """Build a two-column (source_text, target_text) frame from `df`.

  Args:
    df: frame with 's_title_trim', 's_author_trim' and 's_content_trim'.
    include_author: when true the source is
      TITLE_PROMPT + title + EOS_TOKEN + AUTHOR_PROMPT + author;
      otherwise it is just TITLE_PROMPT + title.

  Returns:
    A new DataFrame sharing `df`'s index; `df` itself is not modified.
  """
  prompt = TITLE_PROMPT + df['s_title_trim']
  if include_author:
    prompt = prompt + EOS_TOKEN + AUTHOR_PROMPT + df['s_author_trim']
  with_text = df.assign(source_text=prompt, target_text=df['s_content_trim'])
  return with_text[['source_text', 'target_text']]

In [None]:
# Dataset variant whose prompt contains both the title and the author;
# the slice previews rows 100-104.
df_author_title_content = build_dataset_df(my_df, True)
df_author_title_content[100:105]

In [None]:
# Dataset variant whose prompt contains only the title;
# the slice previews rows 100-104.
df_title_content = build_dataset_df(my_df, False)
df_title_content[100:105]

In [None]:
# Stack both variants, so every poem appears once with and once without
# the author in its prompt.
merged_df = pd.concat([df_author_title_content, df_title_content])

In [None]:
# Shuffle: sampling with frac=1. returns every row in random order.
merged_df = merged_df.sample(frac=1.)
merged_df

## Modeling

In [None]:
# Quiet install simple T5 package
!pip install -q simplet5 &> /dev/null

In [None]:
import torch
from simplet5 import SimpleT5
from transformers import T5Tokenizer, T5ForConditionalGeneration

In [None]:
torch.cuda.empty_cache() 

In [None]:
class MengziSimpleT5(SimpleT5):
  """SimpleT5 subclass that loads the Langboat Mengzi T5 checkpoint."""

  # Hugging Face model id shared by the tokenizer and the model weights.
  PRETRAINED_NAME = "Langboat/mengzi-t5-base"

  def __init__(self) -> None:
    super().__init__()
    self.device = torch.device("cuda")

  def load_my_model(self, use_gpu: bool = True):
    """Load the pretrained tokenizer and model and place the model on a device.

    Args:
      use_gpu: when true, use CUDA if it is available; otherwise (or when
        false) the model is kept on the CPU. Previously this flag was
        accepted but ignored.
    """
    self.tokenizer = T5Tokenizer.from_pretrained(self.PRETRAINED_NAME)
    self.model = T5ForConditionalGeneration.from_pretrained(self.PRETRAINED_NAME)
    if use_gpu and not torch.cuda.is_available():
      import warnings
      warnings.warn("use_gpu=True but CUDA is not available; using CPU instead")
    want_cuda = use_gpu and torch.cuda.is_available()
    # Keep self.device and the model's actual placement in agreement.
    self.device = torch.device("cuda" if want_cuda else "cpu")
    self.model = self.model.to(self.device)

In [None]:
# Create the wrapper, load the pretrained tokenizer and weights, then move
# the model to the GPU.
model = MengziSimpleT5()
model.load_my_model()
model.model = model.model.to('cuda')

In [None]:
# Inspect how the tokenizer encodes a sample containing an inline </s>.
model.tokenizer("桥形通汉上，峰势接云危。</s>烟霞交隐映，花鸟自参差。")

In [None]:
# Decode a list of token ids back to text (presumably the ids printed for the
# tokenizer sample in the previous cell — verify when re-running).
model.tokenizer.decode([1012, 955, 406, 921, 23, 3, 1440, 2180, 799, 355, 4008, 4, 1, 1448, 4152, 690, 3934, 4990, 3, 17544, 178, 2572, 769, 4, 1])

In [None]:
from sklearn.model_selection import train_test_split
# Shuffle all rows, then hold out 2% of them for evaluation.
# NOTE(review): neither sample() nor train_test_split() is given a random
# seed, so the split differs between runs.
merged_df = merged_df.sample(frac=1) # Shuffle
train_df, eval_df = train_test_split(merged_df, test_size=0.02)

In [None]:
print("train", len(train_df), "eval", len(eval_df))

In [None]:
# Fine-tune for 3 epochs with batch size 128; `outputdir` points at a folder
# on the mounted Google Drive.
# The source budget is counted in characters: title prompt + title + 1 for
# the inline EOS_TOKEN + author prompt + author. The target budget is the
# maximum content length in characters.
# NOTE(review): both budgets assume about one token per character and leave
# no slot for the trailing </s> (the final id 1 in the decode example earlier
# in this notebook) — confirm that maximum-length titles/poems are not
# truncated.
model.train(train_df=train_df,
            eval_df=eval_df, 
            source_max_token_len=(len(TITLE_PROMPT) + MAX_TITLE_CHAR +  1 + len(AUTHOR_PROMPT) + MAX_AUTHOR_CHAR),
            target_max_token_len=MAX_CONTENT_CHAR, 
            batch_size=128,
            max_epochs=3,
            use_gpu=True,
            outputdir="/content/drive/MyDrive/ML/Models/t5-poem-v2")