In [1]:
%pip install transformers torch

You should consider upgrading via the '/Users/danilkladnitsky/.pyenv/versions/3.10.4/bin/python -m pip install --upgrade pip' command.[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.


In [10]:
from transformers import GPT2LMHeadModel, AutoTokenizer
import torch
import re

class ChineseSentenceGenerator:
    def __init__(self, model_path: str, device: str = None):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_path)

        self.model.eval()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate(self, word: str, max_length: int = 60, num_return_sequences: int = 1) -> list:
        prompt = f"ËØ∑Áî®ËØçËØ≠‚Äú{word}‚ÄùÈÄ†Âè•Ôºö"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                do_sample=True,
                top_k=50,
                top_p=0.9,
                temperature=0.9,
                num_return_sequences=num_return_sequences,
                repetition_penalty=1.3,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        print(self.tokenizer.pad_token_id, self.tokenizer.eos_token_id)

        results = []
        for output in outputs:
            decoded = self.tokenizer.decode(output, skip_special_tokens=True)
            print(decoded)
            sentence = decoded.replace(prompt, "").replace(" ", "").strip()
            sentence = re.split(r"[„ÄÇÔºÅÔºü]", sentence)[0] + "„ÄÇ"  # stop at first punctuation
            results.append(sentence)

        return results
    
generator = ChineseSentenceGenerator("../models/hsk1-gpt2-jieba")

results = generator.generate("Â≠¶Áîü", max_length=60, num_return_sequences=3)
for i, sent in enumerate(results):
    print(f"üìù Sentence {i+1}: {sent}")

0 None
ËØ∑ Áî® ËØçËØ≠ ‚Äú Â≠¶Áîü ‚Äù ÈÄ†Âè• : ÂæàÂ§ö Â≠¶Áîü Âéª ‰∫Ü „ÄÇ Ëøô Êú¨‰π¶ „ÄÇ „ÄÇ „ÄÇ „ÄÇ ‰∏Ä‰∏™ ÂÜô‰Ωú ÂÆ¢Ê∞î „ÄÇ „ÄÇ ! ‰Ω† „ÄÇ „ÄÇ „ÄÇ ‚Äù ‰∏Ä‰∏™ ? „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ Êàë „ÄÇ ! „ÄÇ „ÄÇ ÁöÑ Èõ® „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ Êú¨ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ ‚Äù „ÄÇ
ËØ∑ Áî® ËØçËØ≠ ‚Äú Â≠¶Áîü ‚Äù ÈÄ†Âè• : ÂæàÂ§ö Â≠¶Áîü ‰π∞ ‰∫Ü Ëøô Êú¨‰π¶ „ÄÇ „ÄÇ „ÄÇ „ÄÇ ! ‰∏Ä‰∏™ „ÄÇ Êúâ Âæà‰ºö „ÄÇ „ÄÇ ‰∫∫Âéª „ÄÇ „ÄÇ „ÄÇ „ÄÇ Âêó ? ? ‚Äù ÊòØ ÁöÑ , ÂØπ Âêó ? , ‰∏çÊòØ ÊòüÊúüÂÖ≠ „ÄÇ „ÄÇ Êàë ÁöÑ . Ë∞¢Ë∞¢ „ÄÇ ‚Äù „ÄÇ „ÄÇ ‚Äù „ÄÇ ‚Äù „ÄÇ „ÄÇ ‚Äù „ÄÇ
ËØ∑ Áî® ËØçËØ≠ ‚Äú Â≠¶Áîü ‚Äù ÈÄ†Âè• : Êúâ‰∏™ Â≠¶Áîü ÊÉ≥ËßÅ ‰Ω† „ÄÇ ‰∫Ü ‰∏Ä‰∏™ ‰∫∫ ! ? „ÄÇ „ÄÇ „ÄÇ ‚Äù ‰∏Ä‰∏™ Âéª ‰∏Ä‰∏ã „ÄÇ „ÄÇ „ÄÇ ÊòØ ÁöÑ , Êàë ‰∏çÊòØ ‰Ω† ÁöÑ ÊúãÂèã „ÄÇ „ÄÇ Êàë Âæà È´òÂÖ¥ „ÄÇ Êù• „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ
üìù Sentence 1: ËØ∑Áî®ËØçËØ≠‚ÄúÂ≠¶Áîü‚ÄùÈÄ†Âè•:ÂæàÂ§öÂ≠¶ÁîüÂéª‰∫Ü„ÄÇ
üìù Sentence 2: ËØ∑Áî®ËØçËØ≠‚ÄúÂ≠¶Áîü‚ÄùÈÄ†Âè•:ÂæàÂ§öÂ≠¶Áîü‰π∞‰∫ÜËøôÊú¨‰π¶„ÄÇ
üìù Sentence 3: ËØ∑Áî®ËØçËØ≠‚ÄúÂ≠¶Áîü‚ÄùÈÄ†Âè•:Êúâ‰∏™Â≠¶ÁîüÊÉ≥