In [1]:
%pip install transformers torch

You should consider upgrading via the '/Users/danilkladnitsky/.pyenv/versions/3.10.4/bin/python -m pip install --upgrade pip' command.[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.


In [30]:
from transformers import GPT2LMHeadModel, AutoTokenizer
import torch
import re

class ChineseSentenceGenerator:
    def __init__(self, model_path: str, device: str = None):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = GPT2LMHeadModel.from_pretrained(model_path)

        self.model.eval()
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def generate(self, word: str, max_length: int = 60, num_return_sequences: int = 1) -> list:
        prompt = f"ËØ∑Áî®ËØçËØ≠‚Äú{word}‚ÄùÈÄ†Âè•Ôºö"
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                do_sample=True,
                top_k=50,
                top_p=0.9,
                temperature=0.4,
                num_return_sequences=num_return_sequences,
                repetition_penalty=1.2,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )

        print(self.tokenizer.pad_token_id, self.tokenizer.eos_token_id)

        results = []
        for output in outputs:
            decoded = self.tokenizer.decode(output, skip_special_tokens=True)
            print(decoded)
            sentence = decoded.replace(prompt, "").replace(" ", "").strip()
            sentence = re.split(r"[„ÄÇÔºÅÔºü]", sentence)[0] + "„ÄÇ"  # stop at first punctuation
            results.append(sentence)

        return results
    
generator = ChineseSentenceGenerator("../models/hsk1-gpt2-jieba")

results = generator.generate("‰π¶", max_length=60, num_return_sequences=3)
for i, sent in enumerate(results):
    print(f"üìù Sentence {i+1}: {sent}")

0 None
ËØ∑ Áî® ËØçËØ≠ ‚Äú ‰π¶ ‚Äù ÈÄ†Âè• : Êàë ÁöÑ ‰π¶ ‰∏∫‰ªÄ‰πà ‰∏ç Âú® ÂõûÊù• „ÄÇ ‰Ω†‰ª¨ ‰∫Ü „ÄÇ „ÄÇ ‚Äù ‚Äú ! ‚Äù ‚Äú ‰Ω† Âæà Â•Ω „ÄÇ „ÄÇ Âë¢ ? „ÄÇ , ÊòéÂ§© ‰∏ÄÊ†∑ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ
ËØ∑ Áî® ËØçËØ≠ ‚Äú ‰π¶ ‚Äù ÈÄ†Âè• : Êàë ÁöÑ ‰π¶ Êúâ ‰∏Ä‰∫õ ÊúãÂèã „ÄÇ „ÄÇ ‰Ω†‰ª¨ „ÄÇ „ÄÇ „ÄÇ Âêó ? ? ÂØπ , Êàë Âú® ‰∏Ä‰∏™ ÊúãÂèãÂÆ∂ „ÄÇ „ÄÇ ‚Äù „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ ÂØπ „ÄÇ „ÄÇ „ÄÇ „ÄÇ
ËØ∑ Áî® ËØçËØ≠ ‚Äú ‰π¶ ‚Äù ÈÄ†Âè• : Êàë ÁöÑ ‰π¶ ‰∏∫‰ªÄ‰πà ‰∏ç Âú® ÂõûÊù• „ÄÇ „ÄÇ ? „ÄÇ „ÄÇ ‚Äù ‚Äú ‰Ω† ÊòØ ‰ªñ‰ª¨ ÁöÑ „ÄÇ Âêó ? ‚Äù ÊòØ Êàë ÁöÑ „ÄÇ , ‰Ω† ÁöÑ „ÄÇ „ÄÇ ‚Äù „ÄÇ ‰Ω†‰ª¨ ËÆ§ËØÜ „ÄÇ „ÄÇ „ÄÇ „ÄÇ ‚Äù „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ „ÄÇ ‚Äù Êàë „ÄÇ „ÄÇ „ÄÇ ‚Äù
üìù Sentence 1: ËØ∑Áî®ËØçËØ≠‚Äú‰π¶‚ÄùÈÄ†Âè•:ÊàëÁöÑ‰π¶‰∏∫‰ªÄ‰πà‰∏çÂú®ÂõûÊù•„ÄÇ
üìù Sentence 2: ËØ∑Áî®ËØçËØ≠‚Äú‰π¶‚ÄùÈÄ†Âè•:ÊàëÁöÑ‰π¶Êúâ‰∏Ä‰∫õÊúãÂèã„ÄÇ
üìù Sentence 3: ËØ∑Áî®ËØçËØ≠‚Äú‰π¶‚ÄùÈÄ†Âè•:ÊàëÁöÑ‰π¶‰∏∫‰ªÄ‰πà‰∏çÂú®ÂõûÊ