<a href="https://colab.research.google.com/github/hululuzhu/chinese-ai-writing-share/blob/main/Inference_T5_Finetune_Chinese_Couplet_V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inference with models trained in the [T5 Chinese couplet training Colab](https://github.com/hululuzhu/chinese-ai-writing-share/blob/main/Mengzi_T5_Finetune_Chinese_Couplet_V1.ipynb)

## Load packages and the previously trained model

In [1]:
# Quietly install the simplet5 package
!pip install -q simplet5 &> /dev/null

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Copy the first-epoch checkpoint from Google Drive (the directory name counts epochs from 0)
!mkdir -p my_t5/finetune_epoch1
!cp /content/drive/MyDrive/ML/Models/t5-couplet/simplet5-epoch-0-train-loss-3.8253/* my_t5/finetune_epoch1
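The checkpoint directory name encodes the epoch and the training loss, so when several checkpoints are saved, a small parser can pick one programmatically. The sketch below assumes the `simplet5-epoch-<n>-train-loss-<loss>` pattern seen in the path above (with an optional `-val-loss-<loss>` suffix), treats the epoch number as 0-based as this notebook does when it copies `epoch-0` into `finetune_epoch1`, and uses hypothetical helper names `parse_checkpoint_name` and `best_checkpoint`.

```python
import re

# Assumed naming pattern, based on the directory copied above:
# "simplet5-epoch-<epoch>-train-loss-<loss>", optionally followed by "-val-loss-<loss>"
CKPT_PATTERN = re.compile(
    r"simplet5-epoch-(?P<epoch>\d+)-train-loss-(?P<train>[\d.]+)"
    r"(?:-val-loss-(?P<val>[\d.]+))?$"
)

def parse_checkpoint_name(name):
    """Hypothetical helper: return (1-based epoch, train loss), or None if no match."""
    m = CKPT_PATTERN.search(name)
    if m is None:
        return None
    # The epoch in the name is 0-based, so epoch-0 is the first epoch
    return int(m.group("epoch")) + 1, float(m.group("train"))

def best_checkpoint(names):
    """Hypothetical helper: pick the checkpoint name with the lowest training loss."""
    parsed = [(parse_checkpoint_name(n), n) for n in names]
    parsed = [(p, n) for p, n in parsed if p is not None]
    if not parsed:
        return None
    return min(parsed, key=lambda pn: pn[0][1])[1]

print(parse_checkpoint_name("simplet5-epoch-0-train-loss-3.8253"))  # (1, 3.8253)
```

Training loss is only a rough guide; the generated couplets themselves are the better basis for choosing a checkpoint.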

In [4]:
!ls -l my_t5/finetune_epoch1

total 967956
-rw------- 1 root root       706 Feb  7 04:29 config.json
-rw------- 1 root root 990438349 Feb  7 04:29 pytorch_model.bin
-rw------- 1 root root      1786 Feb  7 04:29 special_tokens_map.json
-rw------- 1 root root    725135 Feb  7 04:29 spiece.model
-rw------- 1 root root      1961 Feb  7 04:29 tokenizer_config.json
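Before loading, it can help to confirm that the copy from Drive actually completed. This is a minimal sketch: `missing_checkpoint_files` is a hypothetical helper, and the expected file list is simply the one in the `ls` output above, so other library versions may write a different set of files.

```python
from pathlib import Path

# File names taken from the `ls` output above; other transformers/simplet5
# versions may save different files (e.g. model.safetensors)
EXPECTED_FILES = (
    "config.json",
    "pytorch_model.bin",
    "special_tokens_map.json",
    "spiece.model",
    "tokenizer_config.json",
)

def missing_checkpoint_files(ckpt_dir, expected=EXPECTED_FILES):
    """Return the expected files that are absent or empty in ckpt_dir."""
    ckpt = Path(ckpt_dir)
    return [
        name for name in expected
        if not (ckpt / name).is_file() or (ckpt / name).stat().st_size == 0
    ]

missing = missing_checkpoint_files("my_t5/finetune_epoch1")
if missing:
    print("Checkpoint incomplete, missing:", missing)
```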


In [7]:
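import torch
from simplet5 import SimpleT5
from transformers import T5Tokenizer, T5ForConditionalGeneration

class MengziSimpleT5(SimpleT5):
  """SimpleT5 wrapper that pairs the Mengzi tokenizer with our fine-tuned weights."""
  def __init__(self) -> None:
    super().__init__()
    self.device = torch.device("cuda")

  def load_my_model(self, use_gpu: bool = True):
    # Tokenizer from the base Mengzi T5 model on the Hugging Face Hub
    self.tokenizer = T5Tokenizer.from_pretrained("Langboat/mengzi-t5-base")
    # Fine-tuned weights from the checkpoint copied out of Google Drive
    self.model = T5ForConditionalGeneration.from_pretrained("my_t5/finetune_epoch1")
    # Honor use_gpu (previously ignored) and fall back to CPU when no GPU is present
    self.device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
    self.model = self.model.to(self.device)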

In [16]:
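model = MengziSimpleT5()
model.load_my_model()
model.model = model.model.to(model.device)

COUPLET_PROMPT = '对联：'  # task prefix ("couplet:") prepended to every input
MAX_SEQ_LEN = 32           # inputs are truncated to this many characters
MAX_OUT_TOKENS = MAX_SEQ_LEN

def predict_now(in_str, model=model):
  """Print the first line followed by the model's generated second line."""
  model.model = model.model.to(model.device)
  in_request = f"{COUPLET_PROMPT}{in_str[:MAX_SEQ_LEN]}"
  # Cap the output length at the (truncated) input's character count,
  # since both lines of a couplet should have the same length
  out_len = min(MAX_OUT_TOKENS, len(in_request) - len(COUPLET_PROMPT))
  print(in_str, model.predict(
      in_request,
      max_length=out_len,
      num_beams=1,     # no beam search
      top_p=1.0,       # top_p and top_k have no effect when do_sample=False
      top_k=1,
      do_sample=False))  # greedy decoding for deterministic output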
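The preprocessing inside `predict_now` can be isolated in a pure-Python helper and checked without a GPU. `build_request` below is hypothetical and mirrors that cell: it truncates the input to `MAX_SEQ_LEN` characters, prepends the `对联：` ("couplet:") prefix, and caps the output budget at the truncated input's character count. Two caveats apply: `max_length` counts tokens rather than characters, and for encoder-decoder models Hugging Face `generate` counts the decoder start token toward `max_length`. That may be why the 4-character input 载歌在谷 in the results comes back with only 3 characters. The optional `reserve_start_token` flag sketches one possible adjustment; it is not what produced the outputs in this notebook.

```python
COUPLET_PROMPT = '对联：'  # task prefix, "couplet:"
MAX_SEQ_LEN = 32
MAX_OUT_TOKENS = MAX_SEQ_LEN

def build_request(in_str, reserve_start_token=False):
    """Hypothetical helper mirroring predict_now's preprocessing.

    Returns (request_text, max_length). With reserve_start_token=True the
    budget grows by one to leave room for the decoder start token that
    `generate` counts toward max_length in encoder-decoder models.
    """
    first_line = in_str[:MAX_SEQ_LEN]          # truncate long inputs
    request = f"{COUPLET_PROMPT}{first_line}"  # prepend the task prefix
    max_length = min(MAX_OUT_TOKENS, len(first_line))
    if reserve_start_token:
        max_length += 1
    return request, max_length

print(build_request('载歌在谷'))  # ('对联：载歌在谷', 4)
```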

## Run inference
- Sampling is turned off (greedy decoding with `do_sample=False` and `num_beams=1`), so the results are deterministic and can be compared directly.
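With `do_sample=False` and `num_beams=1`, decoding is greedy: at each step the highest-probability token is chosen, so the same input always yields the same output and `top_k`/`top_p` play no role. The toy sketch below contrasts a greedy step with a sampling step over a made-up next-token distribution; the probabilities are illustrative only and do not come from the model.

```python
import random

# Made-up next-token probabilities, for illustration only (not from the model)
NEXT_TOKEN_PROBS = {"春": 0.5, "年": 0.3, "节": 0.2}

def greedy_pick(probs):
    """Greedy decoding step: always take the most probable token."""
    return max(probs, key=probs.get)

def sample_pick(probs, rng):
    """Sampling step: draw a token in proportion to its probability."""
    tokens = list(probs)
    return rng.choices(tokens, weights=[probs[t] for t in tokens], k=1)[0]

# Repeat each strategy 20 times and collect the distinct tokens produced
greedy_runs = {greedy_pick(NEXT_TOKEN_PROBS) for _ in range(20)}
rng = random.Random(0)
sampled_runs = {sample_pick(NEXT_TOKEN_PROBS, rng) for _ in range(20)}

print(greedy_runs)        # {'春'}
print(len(sampled_runs))  # more than one distinct token
```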

In [20]:
print("Epoch 1:\n")
for pre in ['欢天喜地度佳节', '不待鸣钟已汗颜，重来试手竟何艰',
            '当年欲跃龙门去，今日真披马革还', '载歌在谷']:
  predict_now(pre)

Epoch 1:

欢天喜地度佳节 ['喜地欢天迎新春']
不待鸣钟已汗颜，重来试手竟何艰 ['但凭杯酒长精神,一醉忘年犹未归']
当年欲跃龙门去，今日真披马革还 ['此日将登虎榜来,他年再驾马蹄归']
载歌在谷 ['有酒于']
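A couplet's second line should mirror the first clause by clause, so one quick sanity check on these results is to compare clause lengths. The sketch below is a hypothetical checker: it splits on both full-width and ASCII punctuation (the outputs use `,` where the inputs use `，`) and checks lengths only, not tone or parallelism. Applied to the four pairs above, it flags 载歌在谷 → 有酒于 as the only length mismatch.

```python
import re

# Clause separators: full-width and ASCII commas, semicolons, and periods
CLAUSE_SPLIT = re.compile(r"[，,；;。.]")

def clause_lengths(line):
    """Character lengths of the non-empty clauses in a couplet line."""
    return [len(c) for c in CLAUSE_SPLIT.split(line.strip()) if c]

def lengths_match(first, second):
    """True if both lines have the same number of clauses with equal lengths."""
    return clause_lengths(first) == clause_lengths(second)

# Input/output pairs copied from the Epoch 1 results above
results = [
    ('欢天喜地度佳节', '喜地欢天迎新春'),
    ('不待鸣钟已汗颜，重来试手竟何艰', '但凭杯酒长精神,一醉忘年犹未归'),
    ('当年欲跃龙门去，今日真披马革还', '此日将登虎榜来,他年再驾马蹄归'),
    ('载歌在谷', '有酒于'),
]
for first, second in results:
    print(first, second, "OK" if lengths_match(first, second) else "length mismatch")
```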
