<a href="https://colab.research.google.com/github/hululuzhu/chinese-ai-writing-share/blob/main/Inference_T5_Finetune_Chinese_Couplet_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inference for models trained with the [T5 Chinese couplet training Colab](https://github.com/hululuzhu/chinese-ai-writing-share/blob/main/Mengzi_T5_Finetune_Chinese_Couplet_V1.ipynb)
- Download my saved models from this [Drive link](https://drive.google.com/drive/folders/1bQb_nrHHLkDYj09P2rrX7PSvHD8h3cTx?usp=sharing)

## Load package and previously trained models

In [1]:
# Quietly install the simplet5 package (pip output suppressed)
!pip install -q simplet5 &> /dev/null

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [13]:
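!mkdir -p my_t5/finetuned
# Copy the saved checkpoint from Drive. Assuming simplet5's 0-indexed epoch numbering,
# "epoch-1" is the second training epoch (labelled "Epoch 2" below).
!cp /content/drive/MyDrive/ML/Models/t5-couplet/simplet5-epoch-1-train-loss-3.3605/* my_t5/finetuned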
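
The checkpoint folder above is hard-coded. If a training run leaves several `simplet5-epoch-*` folders in Drive, a small helper can choose one by parsing the folder names. This is a minimal sketch that assumes every folder follows the `simplet5-epoch-{N}-train-loss-{X}` pattern seen above; `parse_checkpoint_name`, `pick_checkpoint` and `list_checkpoints` are hypothetical helpers, not part of simplet5.

```python
import re
from pathlib import Path

# Matches names like "simplet5-epoch-1-train-loss-3.3605"; any suffix after the loss is ignored.
CKPT_RE = re.compile(r"simplet5-epoch-(\d+)-train-loss-(\d+(?:\.\d+)?)")

def parse_checkpoint_name(name: str):
  """Return (epoch, train_loss) for a checkpoint folder name, or None if it does not match."""
  m = CKPT_RE.match(name)
  if m is None:
    return None
  return int(m.group(1)), float(m.group(2))

def pick_checkpoint(names, by="loss"):
  """Hypothetical helper: pick the lowest-train-loss name (by="loss") or the latest epoch (by="epoch")."""
  parsed = [(n, parse_checkpoint_name(n)) for n in names]
  parsed = [(n, p) for n, p in parsed if p is not None]
  if not parsed:
    raise ValueError("no simplet5 checkpoint folders found")
  if by == "epoch":
    return max(parsed, key=lambda item: item[1][0])[0]
  return min(parsed, key=lambda item: item[1][1])[0]

def list_checkpoints(root: str):
  """List sub-folder names under root, e.g. the Drive t5-couplet folder."""
  return [p.name for p in Path(root).iterdir() if p.is_dir()]
```

For example, `pick_checkpoint(list_checkpoints('/content/drive/MyDrive/ML/Models/t5-couplet'))` would return the lowest-loss folder name, which could then replace the hard-coded path in the `cp` command. Lowest training loss is not necessarily the best checkpoint for generation quality, so `by="epoch"` is offered as an alternative.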

In [14]:
!ls my_t5/finetuned -l

total 967956
-rw------- 1 root root       706 Feb  7 07:05 config.json
-rw------- 1 root root 990438349 Feb  7 07:05 pytorch_model.bin
-rw------- 1 root root      1786 Feb  7 07:05 special_tokens_map.json
-rw------- 1 root root    725135 Feb  7 07:05 spiece.model
-rw------- 1 root root      1961 Feb  7 07:05 tokenizer_config.json
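
Because `load_my_model` takes the tokenizer from the Hugging Face Hub, only `config.json` and `pytorch_model.bin` in this folder are required to load the model. A quick check before loading can fail fast if the copy from Drive was incomplete (the weights file alone is close to 1 GB). A minimal sketch; `missing_files` is a hypothetical helper.

```python
from pathlib import Path

# Files T5ForConditionalGeneration.from_pretrained() needs from the local folder in this
# notebook; spiece.model is not checked because the tokenizer is loaded from the Hub.
MODEL_FILES = ("config.json", "pytorch_model.bin")

def missing_files(model_dir: str, required=MODEL_FILES):
  """Hypothetical helper: return the required file names that are absent or empty."""
  root = Path(model_dir)
  return [name for name in required
          if not (root / name).is_file() or (root / name).stat().st_size == 0]
```

In the notebook this would be used as `assert not missing_files("my_t5/finetuned")` right after the `cp` cell.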


In [15]:
import torch
from simplet5 import SimpleT5
from transformers import T5Tokenizer, T5ForConditionalGeneration

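class MengziSimpleT5(SimpleT5):
  """SimpleT5 wrapper that loads the Mengzi tokenizer and our fine-tuned weights."""

  def load_my_model(self, model_dir: str = "my_t5/finetuned", use_gpu: bool = True):
    # The training notebook fine-tunes Langboat/mengzi-t5-base, so reuse its tokenizer.
    self.tokenizer = T5Tokenizer.from_pretrained("Langboat/mengzi-t5-base")
    self.model = T5ForConditionalGeneration.from_pretrained(model_dir)
    # SimpleT5.predict() moves the input ids to self.device, so set it here and move
    # the model to match (falls back to CPU when no GPU is available).
    self.device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
    self.model = self.model.to(self.device)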

In [16]:
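model = MengziSimpleT5()
model.load_my_model()
# Keep the weights on the same device that SimpleT5.predict() sends the inputs to.
model.model = model.model.to(model.device)

COUPLET_PROMPT = '对联：'  # Task prefix ("couplet:"); should match the prompt used in training
MAX_SEQ_LEN = 32           # Max number of characters of the upper line kept in the request
MAX_OUT_TOKENS = MAX_SEQ_LEN

def predict_now(in_str, model=model):
  """Print the upper line followed by the model's generated lower line(s)."""
  # The slice truncates by characters, not by tokens.
  in_request = f"{COUPLET_PROMPT}{in_str[:MAX_SEQ_LEN]}"
  print(in_str, model.predict(
      in_request,
      max_length=MAX_OUT_TOKENS,
      # Greedy decoding (one beam, top-1 token, no sampling) so results are deterministic.
      num_beams=1,
      top_p=1.0,
      top_k=1,
      do_sample=False))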
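
`predict_now` builds its request by prefixing the task prompt `对联：` ("couplet:") to the upper line, and the slice `in_str[:MAX_SEQ_LEN]` counts characters, not SentencePiece tokens, so a request can be up to 35 characters once the 3-character prompt is added. The sketch below pulls that step into a pure function so it can be checked without loading the model; `build_request` is a hypothetical name, and its defaults are assumed to mirror the constants in the cell above.

```python
def build_request(upper_line: str, prompt: str = "对联：", max_chars: int = 32) -> str:
  """Prefix the task prompt and truncate the upper line by characters (not tokens)."""
  return f"{prompt}{upper_line[:max_chars]}"
```

Keeping prompt construction in one place makes it harder for the inference prompt to drift away from the one the model was fine-tuned with.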

## Run inference
- Note: sampling is turned off (greedy decoding), so the results are deterministic and can be compared across checkpoints.
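
With `num_beams=1`, `top_k=1` and `do_sample=False`, each decoding step simply takes the highest-scoring next token, so the same input always yields the same output. The toy sketch below illustrates that greedy loop on a hand-written score table instead of a real T5 model; `greedy_decode`, `toy_scores` and the table contents are hypothetical and only stand in for the model's next-token scores.

```python
from typing import Callable, Dict, List

def greedy_decode(next_token_scores: Callable[[List[str]], Dict[str, float]],
                  max_len: int, eos: str = "</s>") -> List[str]:
  """Greedy decoding: at every step append the single highest-scoring token."""
  out: List[str] = []
  for _ in range(max_len):
    scores = next_token_scores(out)
    token = max(scores, key=scores.get)  # argmax, i.e. top_k=1 without sampling
    if token == eos:
      break
    out.append(token)
  return out

# Hypothetical score table keyed by the tokens generated so far.
TOY_SCORES = {
    (): {"笑": 0.6, "喜": 0.4},
    ("笑",): {"语": 0.7, "声": 0.3},
    ("笑", "语"): {"</s>": 0.9, "欢": 0.1},
}

def toy_scores(prefix: List[str]) -> Dict[str, float]:
  return TOY_SCORES[tuple(prefix)]
```

`greedy_decode(toy_scores, max_len=10)` returns `["笑", "语"]` on every call; enabling sampling would instead draw tokens in proportion to their scores, so repeated runs could differ.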

In [17]:
print("Epoch 2:\n")
for pre in ['欢天喜地度佳节', '不待鸣钟已汗颜，重来试手竟何艰',
            '当年欲跃龙门去，今日真披马革还', '载歌在谷']:
  predict_now(pre)

Epoch 2:

欢天喜地度佳节 ['笑语欢歌迎新春']
不待鸣钟已汗颜，重来试手竟何艰 ['只缘落日已秋鬓,再起伤心又何妨']
当年欲跃龙门去，今日真披马革还 ['此日重登虎榜来,他年再驾龙舟归']
载歌在谷 ['把酒临风']
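
A couplet's lower line is expected to mirror the upper line's length and punctuation, which gives a cheap automatic sanity check on generated outputs. The sketch below is a rough heuristic, not part of the training or inference code above; it first maps ASCII punctuation to full-width forms because the model emits `,` where the inputs use `，`. `couplet_shape_ok` and its helpers are hypothetical names.

```python
# Map ASCII punctuation the model may emit to the full-width forms used in the inputs.
FULLWIDTH = str.maketrans({",": "，", ";": "；", "?": "？", "!": "！", ":": "："})
PUNCT = set("，；？！：。、")

def normalize(line: str) -> str:
  return line.strip().translate(FULLWIDTH)

def shape(line: str):
  """Return (length, positions of punctuation marks) after normalization."""
  line = normalize(line)
  return len(line), [i for i, ch in enumerate(line) if ch in PUNCT]

def couplet_shape_ok(upper: str, lower: str) -> bool:
  """Heuristic: same character count and punctuation at the same positions."""
  return shape(upper) == shape(lower)
```

Applied to the four pairs above, the check passes for each; it says nothing about tonal pattern or meaning, which still need a human reader.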
