In [1]:
%pip install transformers torch

You should consider upgrading via the '/Users/danilkladnitsky/.pyenv/versions/3.10.4/bin/python -m pip install --upgrade pip' command.[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.


In [14]:
import time
from collections import defaultdict

def compare_models(models, word="喜欢", max_length=60, num_return_sequences=3):
    from transformers import GPT2LMHeadModel, AutoTokenizer
    import torch
    import re

    results = defaultdict(list)
    
    for model_path in models:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = GPT2LMHeadModel.from_pretrained(model_path).eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        prompt = f"请用词语“{word}”造句："
        inputs = tokenizer(prompt, return_tensors="pt").to(device)

        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_length,
                do_sample=True,
                top_k=50,
                top_p=0.9,
                temperature=0.4,
                num_return_sequences=num_return_sequences,
                repetition_penalty=1.2,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        elapsed = time.time() - start_time

        decoded_outputs = []
        keyword_count = 0
        lengths = []

        for output in outputs:
            decoded = tokenizer.decode(output, skip_special_tokens=True)
            sentence = decoded.replace(prompt, "").replace(" ", "").strip()
            sentence = re.split(r"[。！？]", sentence)[0] + "。"
            decoded_outputs.append(sentence)
            if word in sentence:
                keyword_count += 1
            lengths.append(len(sentence))

        results[model_path] = {
            "samples": decoded_outputs,
            "keyword_coverage": keyword_count / num_return_sequences,
            "avg_length": sum(lengths) / len(lengths),
            "inference_time_sec": elapsed,
        }

    return results

In [32]:
from pprint import pprint

models = [
    "../models/grid_search/run_0/checkpoint-312",
    "../models/grid_search/run_1/checkpoint-312",
    "../models/grid_search/run_2/checkpoint-312",
]

comparison = compare_models(models, word="喜欢", num_return_sequences=3)
pprint(comparison)

defaultdict(<class 'list'>,
            {'../models/grid_search/run_0/checkpoint-312': {'avg_length': 15.666666666666666,
                                                            'inference_time_sec': 0.831632137298584,
                                                            'keyword_coverage': 1.0,
                                                            'samples': ['请用词语“喜欢”造句:她是我的。',
                                                                        '请用词语“喜欢”造句:她有谁。',
                                                                        '请用词语“喜欢”造句:你是我的。']},
             '../models/grid_search/run_1/checkpoint-312': {'avg_length': 17.333333333333332,
                                                            'inference_time_sec': 0.8303139209747314,
                                                            'keyword_coverage': 1.0,
                                                            'samples': ['请用词语“喜欢”造句:我高兴了。',
                                          

In [41]:
import torch
import torch.nn.functional as F

def compute_perplexity(model, tokenizer, sentences: list, max_length: int = 512) -> float:
    """
    Compute perplexity of a language model on a list of sentences.

    Args:
        model: A language model (e.g., GPT2LMHeadModel)
        tokenizer: Corresponding tokenizer
        sentences (list): A list of strings to evaluate
        max_length (int): Maximum length to truncate input to
    
    Returns:
        float: Perplexity score (lower is better)
    """
    model.eval()
    device = next(model.parameters()).device

    total_loss = 0.0
    total_tokens = 0

    for sent in sentences:
        inputs = tokenizer(sent, return_tensors="pt", truncation=True, max_length=max_length)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        
        with torch.no_grad():
            output = model(**inputs, labels=inputs["input_ids"])
            loss = output.loss

        num_tokens = inputs["input_ids"].numel()
        total_loss += loss.item() * num_tokens
        total_tokens += num_tokens

    if total_tokens == 0:
        return float("inf")  # avoid division by zero

    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    return perplexity

In [48]:
import json
import torch
from transformers import GPT2LMHeadModel, AutoTokenizer

def compute_perplexity(model, tokenizer, sentences: list, max_length: int = 512) -> float:
    model.eval()
    device = next(model.parameters()).device

    total_loss = 0.0
    total_tokens = 0
    skipped = 0

    for sent in sentences:
        inputs = tokenizer(sent, return_tensors="pt", truncation=True, max_length=max_length)
        if inputs["input_ids"].numel() == 0:
            skipped += 1
            continue

        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            output = model(**inputs, labels=inputs["input_ids"])
            loss = output.loss

        num_tokens = inputs["input_ids"].numel()
        total_loss += loss.item() * num_tokens
        total_tokens += num_tokens

    if total_tokens == 0:
        print("⚠️ No tokens were processed. All inputs may be empty or not tokenizable by this model.")
        return float("nan")

    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    return perplexity
# Paths
models = [
    "../models/grid_search/run_0/checkpoint-312",
    "../models/grid_search/run_1/checkpoint-312",
    "../models/grid_search/run_2/checkpoint-312"
]

hsk_dataset_path_json = "../datasets/hsk1-dataset.json"

# Load evaluation sentences
# Load line-delimited JSON (NDJSON-style)
with open(hsk_dataset_path_json, "r", encoding="utf-8") as f:
    data = [json.loads(line) for line in f if line.strip()]

# Expecting format: [{ "prompt": ..., "completion": ... }]
eval_sentences = [entry["completion"].strip() for entry in data if "completion" in entry]

# Evaluate perplexity per model
for model_path in models:
    print(f"Evaluating: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = GPT2LMHeadModel.from_pretrained(model_path).to("cuda" if torch.cuda.is_available() else "cpu")

    ppl = compute_perplexity(model, tokenizer, eval_sentences)
    print(f"🔍 Perplexity for {model_path}: {ppl:.2f}\n")

Evaluating: ../models/grid_search/run_0/checkpoint-312
🔍 Perplexity for ../models/grid_search/run_0/checkpoint-312: nan

Evaluating: ../models/grid_search/run_1/checkpoint-312
🔍 Perplexity for ../models/grid_search/run_1/checkpoint-312: nan

Evaluating: ../models/grid_search/run_2/checkpoint-312
🔍 Perplexity for ../models/grid_search/run_2/checkpoint-312: nan



In [25]:
import jieba
import re

def calculate_hsk_coverage(sentences: list, target_word: str, hsk_vocab: set) -> dict:
    """
    Check HSK word coverage and target word presence in generated Chinese sentences.
    Uses jieba for tokenization and ignores punctuation.

    Args:
        sentences (list): List of generated Chinese sentences.
        target_word (str): The word that must appear in each sentence.
        hsk_vocab (set): Set of allowed HSK words (characters or full words).
    
    Returns:
        dict: Summary statistics and details per sentence.
    """
    punctuation_pattern = r"[，。！？、,.!?；;：“”\"'（）()【】\[\]《》<>]"

    results = []
    in_vocab_sentences = 0
    target_present_count = 0

    for sentence in sentences:
        sentence = sentence.strip().replace(" ", "")
        sentence_clean = re.sub(punctuation_pattern, "", sentence)

        tokens = list(jieba.cut(sentence_clean))
        in_vocab = all(token in hsk_vocab for token in tokens)
        contains_target = target_word in sentence_clean

        if in_vocab:
            in_vocab_sentences += 1
        if contains_target:
            target_present_count += 1

        results.append({
            "original_sentence": sentence,
            "cleaned_sentence": sentence_clean,
            "tokens": tokens,
            "all_in_vocab": in_vocab,
            "contains_target": contains_target,
            "unknown_tokens": [t for t in tokens if t not in hsk_vocab]
        })

    return {
        "total": len(sentences),
        "target_word": target_word,
        "target_word_coverage": target_present_count / len(sentences),
        "full_vocab_coverage": in_vocab_sentences / len(sentences),
        "details": results
    }

In [42]:
from pprint import pprint

def load_hsk_vocab(filepath: str) -> set:
    """
    Load HSK vocabulary from a file and return it as a set of words.

    Args:
        filepath (str): Path to the file containing HSK words (one per line).
    
    Returns:
        set: Set of unique HSK words or characters.
    """
    with open(filepath, "r", encoding="utf-8") as f:
        vocab = {line.strip() for line in f if line.strip()}
    return vocab

hsk_vocab = load_hsk_vocab("../datasets/vocabulary/hsk1.txt")

print(f"Loaded {len(hsk_vocab)} HSK words.")
print(list(hsk_vocab)[:10])  # Show first 10 entries

sentences = [
    "我非常喜欢吃苹果。",
    "他喜欢看电影。",
    "小明喜欢踢足球。",
    "老师喜欢认真听讲的学生。",
    "她喜欢安静的环境。",
    "我爱吃披萨。"
]

report = calculate_hsk_coverage(sentences, target_word="喜欢", hsk_vocab=hsk_vocab)

pprint(report)

Loaded 151 HSK words.
['妈妈', '叫', '出租车', '前面', '的', '住', '几', '本', '名字', '吗']


TypeError: calculate_hsk_coverage() got an unexpected keyword argument 'target_word'

In [39]:
import jieba
import torch
import re
from transformers import GPT2LMHeadModel, AutoTokenizer

def remove_punctuation(text: str) -> str:
    return re.sub(r"[，。！？、,.!?；;：“”\"'（）()【】\[\]《》<>]", "", text)

def evaluate_models_on_hsk(models: list, hsk_vocab: set, max_length: int = 50) -> list:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    skip_tokens = {"用", "词语", "造句", "请", ":"}
    summaries = []

    for model_path in models:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = GPT2LMHeadModel.from_pretrained(model_path).to(device).eval()

        generated_sentences = []
        hsk_words = list(hsk_vocab)
        target_words = []

        for word in hsk_words:
            prompt = f"请用词语“{word}”造句："
            inputs = tokenizer(prompt, return_tensors="pt").to(device)

            with torch.no_grad():
                output = model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=max_length,
                    do_sample=True,
                    top_k=50,
                    top_p=0.9,
                    temperature=0.7,
                    repetition_penalty=1.2,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )

            decoded = tokenizer.decode(output[0], skip_special_tokens=True)
            sentence = decoded.replace(prompt, "").replace(" ", "").strip()
            sentence = re.split(r"[。！？]", sentence)[0] + "。"  # stop at first major punctuation
            generated_sentences.append(sentence)
            target_words.append(word)

        # Calculate coverage
        total = len(generated_sentences)
        hsk_compliant_count = 0
        target_hit_count = 0
        total_unknown_tokens = 0
        unknown_token_set = set()

        for sentence, target_word in zip(generated_sentences, target_words):
            cleaned = remove_punctuation(sentence).replace(" ", "")
            cleaned = ''.join([ch for ch in cleaned if ch not in skip_tokens])

            tokens = list(jieba.cut(cleaned))
            unknown_tokens = [t for t in tokens if any(c not in hsk_vocab for c in t)]

            # ✅ Check if all characters (from all tokens) are in HSK vocab
            all_chars = [char for token in tokens for char in token]
            if all(char in hsk_vocab for char in all_chars):
                hsk_compliant_count += 1

            if target_word in cleaned:
                target_hit_count += 1

            total_unknown_tokens += len(unknown_tokens)
            unknown_token_set.update(unknown_tokens)

        summaries.append({
            "model_path": model_path,
            "target_word_coverage": target_hit_count / total,
            "full_vocab_coverage": hsk_compliant_count / total,
            "avg_unknown_tokens_per_sentence": total_unknown_tokens / total,
            "unknown_tokens": sorted(list(unknown_token_set))
        })

    return summaries

In [40]:
# Load HSK vocab
from pprint import pprint

hsk_vocab = load_hsk_vocab("../datasets/vocabulary/hsk1.txt")

# Define models
models = [
    "../models/grid_search/run_0/checkpoint-312",
    "../models/grid_search/run_1/checkpoint-312",
    "../models/grid_search/run_2/checkpoint-312"
]

# cropped_set
hsk_vocab = hsk_vocab

# Evaluate
reports = evaluate_models_on_hsk(models, hsk_vocab)

pprint(reports)

[{'avg_unknown_tokens_per_sentence': 3.3311258278145695,
  'full_vocab_coverage': 0.0,
  'model_path': '../models/grid_search/run_0/checkpoint-312',
  'target_word_coverage': 0.9867549668874173,
  'unknown_tokens': ['一点儿',
                     '一起',
                     '上午',
                     '下午',
                     '下雨',
                     '东西',
                     '中午',
                     '中国',
                     '为什么',
                     '人造',
                     '什么',
                     '今天',
                     '儿子',
                     '先生',
                     '再见',
                     '出去',
                     '出租车',
                     '分钟',
                     '前面',
                     '北京',
                     '医生',
                     '医院',
                     '句',
                     '吃饭',
                     '同学',
                     '名字',
                     '后面',
                     '哪儿',
                     '哪里',
                    