<a href="https://colab.research.google.com/github/hululuzhu/chinese-ai-writing-share/blob/main/Inference_T5_Finetune_Chinese_Couplet_and_Poem_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inference for models trained from [T5 chinese couplet colab](https://github.com/hululuzhu/chinese-ai-writing-share/blob/main/Mengzi_T5_Finetune_Chinese_Couplet_V1.ipynb) and [T5 chinese Poem colab](https://github.com/hululuzhu/chinese-ai-writing-share/blob/main/WIP_Mengzi_T5_Finetune_Chinese_Poem_Writing_V1.ipynb)
- Download my saved models at [couplet model link](https://drive.google.com/drive/folders/1bQb_nrHHLkDYj09P2rrX7PSvHD8h3cTx?usp=sharing) and [poem model link](https://drive.google.com/drive/folders/1ZymaSbOcwlslD5tuUIk_9__C2dUJK_UY?usp=sharing)

## Load package and previously trained models

In [1]:
# Quite install simple T5 package
!pip install -q simplet5 &> /dev/null

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
!mkdir -p my_t5/couplet
!mkdir -p my_t5/poem
!cp /content/drive/MyDrive/ML/Models/t5-couplet/simplet5-epoch-2-train-loss-3.126/* my_t5/couplet
!cp /content/drive/MyDrive/ML/Models/t5-poem/simplet5-epoch-0-train-loss-4.6328/* my_t5/poem

In [4]:
import torch
from simplet5 import SimpleT5
from transformers import T5Tokenizer, T5ForConditionalGeneration

class MengziSimpleT5(SimpleT5):
  def __init__(self) -> None:
    super().__init__()
    self.device = torch.device("cuda")

  def load_my_model(self, local_path, use_gpu: bool = True):
    self.tokenizer = T5Tokenizer.from_pretrained("Langboat/mengzi-t5-base")
    self.model = T5ForConditionalGeneration.from_pretrained(local_path)

Global seed set to 42


In [5]:
couplet_model = MengziSimpleT5()
couplet_model.load_my_model(local_path='my_t5/couplet')
couplet_model.model = couplet_model.model.to('cuda')

COUPLET_PROMPOT = '对联：'
MAX_SEQ_LEN = 32
MAX_OUT_TOKENS = MAX_SEQ_LEN

def couplet(in_str, model=couplet_model):
  model.model = model.model.to('cuda')
  in_request = f"{COUPLET_PROMPOT}{in_str[:MAX_SEQ_LEN]}"
  print(f"上： {in_str}\n下：", model.predict(
      in_request,
      max_length=MAX_OUT_TOKENS,
      num_beams=1,
      top_p=1.0,
      top_k=1,
      do_sample=False)[0])

Downloading:   0%|          | 0.00/725k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/659 [00:00<?, ?B/s]

In [8]:
AUTHOR_PROMPT = "模仿："
TITLE_PROMPT = "作诗："
EOS_TOKEN = '</s>'

poem_model = MengziSimpleT5()
poem_model.load_my_model(local_path='my_t5/poem')
poem_model.model = couplet_model.model.to('cuda')
MAX_AUTHOR_CHAR = 4
MAX_TITLE_CHAR = 12
MIN_CONTENT_CHAR = 10
MAX_CONTENT_CHAR = 64

def poem(title_str, opt_author=None, model=poem_model):
  model.model = model.model.to('cuda')
  if opt_author:
    in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR] + EOS_TOKEN + AUTHOR_PROMPT + opt_author[:MAX_AUTHOR_CHAR]
  else:
    in_request = TITLE_PROMPT + title_str[:MAX_TITLE_CHAR]
  print(f"标题： {in_request}\n诗歌：", model.predict(
      in_request,
      max_length=MAX_CONTENT_CHAR,
      num_beams=1,
      top_p=1.0,
      top_k=1,
      do_sample=False)[0])

## Inference now
- Note we turned off sampling to see determistic results for comparison

In [9]:
# epoch 3 after 6 hours, looks good enough
for pre in ['欢天喜地度佳节', '不待鸣钟已汗颜，重来试手竟何艰',
            '当年欲跃龙门去，今日真披马革还', '载歌在谷',
            '北国风光，千里冰封，万里雪飘',
            '独立寒秋，湘江北去，橘子洲头']:
  couplet(pre)

上： 欢天喜地度佳节
下： 笑语欢歌迎新春
上： 不待鸣钟已汗颜，重来试手竟何艰
下： 何须击鼓犹昂首?再起杀心应有功
上： 当年欲跃龙门去，今日真披马革还
下： 今日欲乘虎势来,明朝又见马蹄飞
上： 载歌在谷
下： 对酒当歌
上： 北国风光，千里冰封，万里雪飘
下： 南疆气象,一城春暖?八方客来
上： 独立寒秋，湘江北去，橘子洲头
下： 孤眠冷月,玉笛西来?琵琶指间


In [12]:
# Epoch 1 after 2 hours, looks pretty bad
for title in ['秋思', '百度', "湾区春日之谜"]:
  poem(title)

标题： 作诗：秋思
诗歌： 春思
标题： 作诗：百度
诗歌： 三更
标题： 作诗：湾区春日之谜
诗歌： 江上秋波为鉴
