In [1]:
# Toy training corpus: five short Chinese sentences, one per line, each ending
# in a full-width 。 or ？ terminator.
corpus = """
我今天去学校上课。
我今天去吃饭。
你今天去哪里？
我不想去上课。
你不想去学校吗？
"""

In [7]:

import re

def tokenize_text_to_sentences(text: str):
    """Split raw text into a list of non-empty, whitespace-trimmed sentences.

    Sentence boundaries are the full-width terminators 。？！ and newlines.
    The delimiters themselves are discarded, and pieces that are empty after
    trimming are dropped.
    """
    sentences = []
    for piece in re.split(r"[。？！\n]", text):
        trimmed = piece.strip()
        if trimmed:
            sentences.append(trimmed)
    return sentences

def tokenize_sentence_char_level(sent: str):
    """Return the characters of ``sent`` as a list, with all whitespace removed."""
    compact = re.sub(r"\s+", "", sent)
    return [ch for ch in compact]

def add_sentence_tokens(tokens, n):
    """Wrap a token list with sentence-boundary markers for an n-gram model.

    Prepends ``n - 1`` copies of ``"<s>"`` and appends a single ``"</s>"``.
    A new list is returned; ``tokens`` is not modified.
    """
    start_padding = ["<s>"] * (n - 1)
    end_marker = ["</s>"]
    return start_padding + tokens + end_marker

In [8]:
# Split the corpus into sentences, then show the character-level tokens of each one.
sents = tokenize_text_to_sentences(corpus)
print(sents)
for s in sents:
    print(s, "->", tokenize_sentence_char_level(s))

['我今天去学校上课', '我今天去吃饭', '你今天去哪里', '我不想去上课', '你不想去学校吗']
我今天去学校上课 -> ['我', '今', '天', '去', '学', '校', '上', '课']
我今天去吃饭 -> ['我', '今', '天', '去', '吃', '饭']
你今天去哪里 -> ['你', '今', '天', '去', '哪', '里']
我不想去上课 -> ['我', '不', '想', '去', '上', '课']
你不想去学校吗 -> ['你', '不', '想', '去', '学', '校', '吗']


In [40]:
from collections import defaultdict

def build_ngram_counts(sentences, n: int):
    """Collect character-level n-gram counts from a list of sentences.

    Each sentence is split into characters and wrapped with ``n - 1`` leading
    ``"<s>"`` tokens and one trailing ``"</s>"`` token. Every position after
    the start padding then contributes one (context, next-token) observation,
    where the context is the ``n - 1`` tokens immediately before it.

    Args:
        sentences: iterable of sentence strings.
        n: order of the model (1 = unigram, 2 = bigram, 3 = trigram, ...).

    Returns:
        A pair ``(context_counts, next_counts)`` of ``defaultdict(int)``:
        ``context_counts[context]`` is how often the context tuple was seen,
        and ``next_counts[(context, token)]`` is how often ``token`` followed it.

    Raises:
        ValueError: if ``n`` is less than 1. Without this check the slice
            below would use negative indices and silently count garbage.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    context_counts = defaultdict(int)
    next_counts = defaultdict(int)
    order = n - 1  # number of tokens in each context

    for sent in sentences:
        tokens = tokenize_sentence_char_level(sent)
        tokens = add_sentence_tokens(tokens, n)

        # Start just past the "<s>" padding so every context has `order` tokens.
        for i in range(order, len(tokens)):
            context = tuple(tokens[i - order:i])
            nxt = tokens[i]
            context_counts[context] += 1
            next_counts[(context, nxt)] += 1

    return context_counts, next_counts

In [41]:
n = 3  # trigram
# Build both count tables from the sentence list produced by the earlier cell.
context_counts, next_counts = build_ngram_counts(sents, n)

# Peek at the first five entries of each table (dicts preserve insertion order).
print("context example:", list(context_counts.items())[:5])
print("next example:", list(next_counts.items())[:5])

context example: [(('<s>', '<s>'), 5), (('<s>', '我'), 3), (('我', '今'), 2), (('今', '天'), 3), (('天', '去'), 3)]
next example: [((('<s>', '<s>'), '我'), 3), ((('<s>', '我'), '今'), 2), ((('我', '今'), '天'), 2), ((('今', '天'), '去'), 3), ((('天', '去'), '学'), 1)]


In [32]:
import random

def sample_next_token(context, next_counts):
    """Draw one continuation token for ``context``, weighted by observed counts.

    Scans every ``(context, token) -> count`` entry of ``next_counts`` and keeps
    those whose context equals ``context``. Returns ``None`` when nothing
    matches; otherwise returns one token sampled in proportion to its count.
    """
    matches = [
        (token, count)
        for (ctx, token), count in next_counts.items()
        if ctx == context
    ]
    if not matches:
        return None

    tokens, counts = zip(*matches)
    return random.choices(list(tokens), weights=list(counts), k=1)[0]

In [34]:
def generate_text(n, next_counts, max_len = 50):
    """Generate one string by repeatedly sampling from an n-gram count table.

    Starts from a context of ``n - 1`` ``"<s>"`` tokens and asks
    ``sample_next_token`` for a continuation until it returns ``None`` (no
    continuation recorded for the context), returns ``"</s>"``, or
    ``max_len`` tokens have been produced.

    Args:
        n: order of the model the counts were built with.
        next_counts: mapping ``(context, token) -> count``.
        max_len: upper bound on the number of generated tokens.

    Returns:
        The generated tokens joined into a single string; the boundary
        markers are never included.
    """
    context = ("<s>",) * (n - 1)
    output = []

    for _ in range(max_len):
        nxt = sample_next_token(context, next_counts)
        if nxt is None or nxt == "</s>":
            break

        output.append(nxt)
        # Slide the window only when there is one. For a unigram model the
        # context must stay empty; appending here would make it ``(nxt,)``,
        # which matches no unigram entry and ends generation after one token.
        if n > 1:
            context = context[1:] + (nxt,)

    return "".join(output)

In [35]:
# generate_text takes (n, next_counts[, max_len]); it needs only the next-token
# table. Passing context_counts as well shifts the arguments and makes
# range(max_len) fail on a fresh kernel.
for _ in range(10):
    print(generate_text(n, next_counts))

你今天去学校上课
我今天去哪里
你今天去哪里
我今天去吃饭
我不想去学校上课
我今天去哪里
你不想去上课
我不想去学校上课
我今天去吃饭
我今天去哪里
