In [1]:
!pip install -qU bertviz

In [2]:
from __future__ import annotations

import json
import random
import string
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np


SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

In [3]:
ANIMALS = [
  "cat",
  "caracal",
  "capybara",
  "canary",
  "cavy",
  "caiman",
  "cacomistle",
  "caribou",
  "cassowary",
  "caterpillar",
  "dog",
  "dalmatian",
  "dachshund",
  "doberman",
  "duck",
  "dingo",
  "lion",
  "tiger",
  "leopard",
  "cheetah",
  "puma",
  "jaguar",
  "lynx",
  "ocelot",
  "serval",
  "bobcat",
  "cougar",
  "panther",
  "wolf",
  "fox",
  "jackal",
  "coyote",
  "hyena",
  "bear",
  "polar",
  "grizzly",
  "slothbear",
  "panda",
  "koala",
  "kangaroo",
  "wallaby",
  "opossum",
  "wombat",
  "tasmanian",
  "rabbit",
  "hare",
  "mouse",
  "rat",
  "gerbil",
  "hamster",
  "guinea",
  "squirrel",
  "chipmunk",
  "beaver",
  "porcupine",
  "hedgehog",
  "shrew",
  "mole",
  "bat",
  "armadillo",
  "antelope",
  "gazelle",
  "impala",
  "gnu",
  "eland",
  "springbok",
  "deer",
  "moose",
  "elk",
  "reindeer",
  "stag",
  "doe",
  "fawn",
  "buffalo",
  "bison",
  "yak",
  "zebu",
  "cow",
  "bull",
  "ox",
  "calf",
  "sheep",
  "lamb",
  "ram",
  "goat",
  "kid",
  "ibex",
  "chamois",
  "camel",
  "dromedary",
  "llama",
  "alpaca",
  "vicuna",
  "horse",
  "mare",
  "stallion",
  "colt",
  "foal",
  "donkey",
  "mule",
  "zebra",
  "rhinoceros",
  "hippopotamus",
  "pig",
  "boar",
  "hog",
  "swine",
  "babirusa",
  "tapir",
  "elephant",
  "mammoth",
  "mastodon",
  "dugong",
  "manatee",
  "whale",
  "dolphin",
  "porpoise",
  "seal",
  "seaotter",
  "walrus",
  "otter",
  "weasel",
  "ferret",
  "marten",
  "ermine",
  "badger",
  "skunk",
  "wolverine",
  "mongoose",
  "meerkat",
  "civet",
  "genet",
  "fossa",
  "bearcat",
  "platypus",
  "echidna",
  "pangolin",
  "aardvark",
  "aardwolf",
  "okapi",
  "giraffe",
  "monkey",
  "baboon",
  "mandrill",
  "macaque",
  "langur",
  "gibbon",
  "gorilla",
  "chimpanzee",
  "bonobo",
  "orangutan",
  "lemur",
  "tarsier",
  "loris",
  "ayeaye",
  "sloth",
  "anteater",
  "tamandua",
  "kitten",
  "puppy",
  "duckling",
  "gosling",
  "cygnet",
  "eagle",
  "hawk",
  "falcon",
  "osprey",
  "vulture",
  "buzzard",
  "kite",
  "owl",
  "barnowl",
  "tawnyowl",
  "screechowl",
  "snowyowl",
  "parrot",
  "macaw",
  "cockatoo",
  "budgerigar",
  "lovebird",
  "lorikeet",
  "conure",
  "parakeet",
  "kingfisher",
  "woodpecker",
  "toucan",
  "hornbill",
  "cuckoo",
  "cuckooshrike",
  "nightjar",
  "swift",
  "hummingbird",
  "swallow",
  "martin",
  "wren",
  "warbler",
  "thrush",
  "blackbird",
  "starling",
  "mockingbird",
  "finch",
  "canary",
  "sparrow",
  "bunting",
  "lark",
  "pipit",
  "wagtail",
  "robin",
  "chat",
  "wheatear",
  "dipper",
  "nuthatch",
  "treecreeper",
  "tit",
  "chickadee",
  "jay",
  "magpie",
  "crow",
  "raven",
  "rook",
  "jackdaw",
  "chough",
  "shrike",
  "oriole",
  "drongo",
  "bulbul",
  "mina",
  "weaver",
  "whydah",
  "waxbill",
  "munia",
  "manakin",
  "cotinga",
  "antbird",
  "ovenbird",
  "woodcreeper",
  "flycatcher",
  "tyrant",
  "pewee",
  "kingbird",
  "boatbill",
  "motmot",
  "tody",
  "jacamar",
  "puffbird",
  "barbet",
  "toucanet",
  "ani",
  "turaco",
  "hoatzin",
  "bustard",
  "crane",
  "heron",
  "egret",
  "bittern",
  "stork",
  "ibis",
  "spoonbill",
  "flamingo",
  "swan",
  "goose",
  "teal",
  "wigeon",
  "shoveler",
  "pintail",
  "scaup",
  "pochard",
  "canvasback",
  "redhead",
  "goldeneye",
  "merganser",
  "eider",
  "scoter",
  "shelduck",
  "woodduck",
  "mandarin",
  "mallard",
  "gadwall",
  "grebe",
  "loon",
  "penguin",
  "albatross",
  "petrel",
  "shearwater",
  "prion",
  "stormpetrel",
  "fulmar",
  "gannet",
  "booby",
  "cormorant",
  "shag",
  "anhinga",
  "frigatebird",
  "tropicbird",
  "pelican",
  "darter",
  "gull",
  "tern",
  "skimmer",
  "auk",
  "murre",
  "puffin",
  "guillemot",
  "razorbill",
  "dovekie",
  "murrelet",
  "kiwi",
  "emu",
  "rhea",
  "ostrich",
  "tinamou",
  "rail",
  "crake",
  "gallinule",
  "coot",
  "limpkin",
  "buttonquail",
  "plover",
  "lapwing",
  "dotterel",
  "killdeer",
  "oystercatcher",
  "avocet",
  "stilt",
  "phalarope",
  "jacana",
  "sandpiper",
  "snipe",
  "curlew",
  "godwit",
  "dowitcher",
  "stint",
  "ruff",
  "turnstone",
  "knot",
  "pratincole",
  "courser",
  "skua",
  "jaeger",
  "eel",
  "salmon",
  "trout",
  "carp",
  "catfish",
  "cobia",
  "cod",
  "coelacanth",
  "flounder",
  "goby",
  "grouper",
  "guppy",
  "haddock",
  "hake",
  "halibut",
  "koi",
  "mackerel",
  "minnow",
  "perch",
  "pike",
  "pollock",
  "sardine",
  "shad",
  "smelt",
  "snapper",
  "sole",
  "sturgeon",
  "tilapia",
  "tuna",
  "wahoo",
  "zander",
  "anchovy",
  "barracuda",
  "bass",
  "blenny",
  "bluegill",
  "bonito",
  "bream",
  "butterfish",
  "capelin",
  "char",
  "clownfish",
  "drum",
  "grunion",
  "herring",
  "killifish",
  "lamprey",
  "lionfish",
  "loach",
  "molly",
  "mudskipper",
  "needlefish",
  "parrotfish",
  "pompano",
  "scad",
  "sculpin",
  "seahorse",
  "shark",
  "skate",
  "sprat",
  "sucker",
  "sunfish",
  "surgeonfish",
  "tarpon",
  "tetra",
  "trevally",
  "triggerfish",
  "wrasse",
  "tang",
  "abalone",
  "barnacle",
  "clam",
  "cockle",
  "conch",
  "crab",
  "crawfish",
  "krill",
  "limpet",
  "lobster",
  "mussel",
  "nautilus",
  "oyster",
  "periwinkle",
  "prawn",
  "scallop",
  "shrimp",
  "snail",
  "squid",
  "octopus",
  "urchin",
  "worm",
  "beetle",
  "butterfly",
  "caterpillar",
  "dragonfly",
  "earwig",
  "firefly",
  "flea",
  "grasshopper",
  "ladybug",
  "mantis",
  "moth",
  "termite",
  "tick",
  "wasp",
  "weevil",
  "aphid",
  "ant",
  "bee",
  "bug",
  "cricket",
  "damselfly",
  "fly",
  "gnat",
  "hornet",
  "mayfly",
  "mosquito",
  "silverfish",
  "spider",
  "centipede",
  "millipede",
  "scorpion",
  "copepod",
  "isopod",
  "amphipod",
  "woodlouse",
  "horseshoecrab",
  "arachnid",
  "mite",
  "tarantula",
  "fruitfly"
]

In [4]:
FRUITS_VEGGIES = [
  "apple",
  "apricot",
  "avocado",
  "artichoke",
  "banana",
  "bilberry",
  "blackberry",
  "blueberry",
  "boysenberry",
  "breadfruit",
  "cantaloupe",
  "casaba",
  "carambola",
  "cherimoya",
  "cherry",
  "cloudberry",
  "coconut",
  "cranberry",
  "currant",
  "date",
  "elderberry",
  "fig",
  "gooseberry",
  "grape",
  "grapefruit",
  "guava",
  "honeydew",
  "jackfruit",
  "jambul",
  "jujube",
  "kiwi",
  "kumquat",
  "lemon",
  "lime",
  "loquat",
  "lychee",
  "mandarin",
  "mango",
  "mangosteen",
  "melon",
  "mulberry",
  "nectarine",
  "olive",
  "orange",
  "papaya",
  "passionfruit",
  "peach",
  "pear",
  "persimmon",
  "pineapple",
  "plum",
  "pomegranate",
  "pomelo",
  "quince",
  "raspberry",
  "redcurrant",
  "salak",
  "satsuma",
  "starfruit",
  "strawberry",
  "tamarillo",
  "tamarind",
  "tangelo",
  "ugli",
  "watermelon",
  "yuzu",
  "zucchini",
  "carrot",
  "cabbage",
  "cauliflower",
  "cassava",
  "celery",
  "chard",
  "chicory",
  "collard",
  "corn",
  "cress",
  "cucumber",
  "daikon",
  "edamame",
  "eggplant",
  "endive",
  "fennel",
  "garlic",
  "ginger",
  "horseradish",
  "jicama",
  "kale",
  "kohlrabi",
  "leek",
  "lettuce",
  "okra",
  "onion",
  "parsnip",
  "pea",
  "pepper",
  "potato",
  "pumpkin",
  "radish",
  "rutabaga",
  "shallot",
  "spinach",
  "squash",
  "sweetcorn",
  "sweetpotato",
  "tomato",
  "turnip",
  "wasabi",
  "yam",
  "macadamia",
  "pecan",
  "cashew",
  "hazelnut",
  "walnut",
  "almond",
  "brazilnut",
  "chestnut",
  "pistachio",
  "pine",
  "acorn",
  "watercress",
  "caper",
  "cardoon",
  "canna",
  "caraway",
  "carob",
  "camu",
  "camote",
  "canistel",
  "canola",
  "capers",
  "carissa",
  "catjang",
  "cavendish",
  "cayenne",
  "celeriac",
  "chayote",
  "cilantro",
  "clementine",
  "cornsalad",
  "courgette",
  "currant",
  "cushaw",
  "dandelion",
  "dill",
  "durian",
  "endive",
  "escarole",
  "fiddlehead",
  "frisee",
  "gourd",
  "jostaberry",
  "kohlrabi",
  "lablab",
  "luffa",
  "malanga",
  "mangetout",
  "mungbean",
  "navybean",
  "nopale",
  "onionchive",
  "parsley",
  "parsnip",
  "pattypan",
  "peasnap",
  "persimmon",
  "pigeonpea",
  "plantain",
  "pluot",
  "pomegranate",
  "prune",
  "pumpkin",
  "radicchio",
  "rambutan",
  "rapini",
  "rocket",
  "rutabaga",
  "salsify",
  "sapote",
  "scallion",
  "shallot",
  "snowpea",
  "sorrel",
  "soybean",
  "spelt",
  "squash",
  "tamarind",
  "tangelo",
  "tatsoi",
  "tomatillo",
  "tuber",
  "turnip",
  "waterchestnut",
  "watermelon",
  "waxgourd",
  "yambean",
  "yautia",
  "yuca",
  "ziziphus",
  "zucchini",
  "acerola",
  "ackee",
  "ambarella",
  "arugula",
  "asparagus",
  "azuki",
  "bamboo",
  "basil",
  "bean",
  "beet",
  "bellpepper",
  "betel",
  "bokchoy",
  "broccoli",
  "broccolini",
  "brusselsprout",
  "burdock",
  "butternut",
  "calabash",
  "calamansi",
  "canarymelon",
  "cantaloupe",
  "capuli",
  "carambola",
  "carrot",
  "cassava",
  "cauliflower",
  "celery",
  "chamomile",
  "cherry",
  "chickpea",
  "chicory",
  "chives",
  "cilantro",
  "citrus",
  "collards",
  "coriander",
  "courgette",
  "cranberry",
  "cress",
  "cucumber",
  "currant",
  "daikon",
  "dandelion",
  "dates",
  "dragonfruit",
  "durian",
  "eggplant",
  "elderberry",
  "endive",
  "fennel",
  "feijoa",
  "fig",
  "fiddlehead",
  "garbanzo",
  "garlic",
  "ginger",
  "gooseberry",
  "grape",
  "grapefruit",
  "guava",
  "habanero",
  "honeydew",
  "horseradish",
  "iceberg",
  "jackfruit",
  "jalapeno",
  "jicama",
  "jostaberry",
  "jujube",
  "kabocha",
  "kale",
  "kiwi",
  "kohlrabi",
  "kumquat",
  "leek",
  "lemon",
  "lentil",
  "lettuce",
  "licorice",
  "lime",
  "lingonberry",
  "loquat",
  "luffa",
  "lychee",
  "maca",
  "mandarin",
  "mango",
  "mangosteen",
  "marrow",
  "melon",
  "mungbean",
  "mustard",
  "nectarine",
  "okra",
  "olive",
  "onion",
  "orange",
  "oregano",
  "papaya",
  "parsley",
  "parsnip",
  "passionfruit",
  "pea",
  "peach",
  "peanut",
  "pear",
  "pecan",
  "pepper",
  "persimmon",
  "pineapple",
  "pistachio",
  "plum",
  "pomegranate",
  "pomelo",
  "potato",
  "pumpkin",
  "quince",
  "radish",
  "rambutan",
  "rapini",
  "raspberry",
  "redcurrant",
  "rhubarb",
  "rocket",
  "rutabaga",
  "salsify",
  "sapote",
  "scallion",
  "shallot",
  "snappea",
  "sorrel",
  "soybean",
  "spinach",
  "spelt",
  "squash",
  "starfruit",
  "strawberry",
  "sweetcorn",
  "sweetpotato",
  "tamarillo",
  "tamarind",
  "tangelo",
  "tatsoi",
  "tomatillo",
  "tomato",
  "turnip",
  "ugli",
  "watercress",
  "watermelon",
  "waxgourd",
  "yam",
  "yuca",
  "ziziphus",
  "zucchini"
]

In [5]:
NAMES = ANIMALS + FRUITS_VEGGIES

# ----- トークン化 -----
ALL_CHARS = sorted(set("".join(NAMES)))
SPECIAL_TOKENS = ["<PAD>", "<BOS>", "<EOS>"]
ALL_TOKENS = SPECIAL_TOKENS + ALL_CHARS
VOCAB_SIZE = len(ALL_TOKENS)
CHAR2IDX = {ch: i for i, ch in enumerate(ALL_TOKENS)}
IDX2CHAR = {i: ch for ch, i in CHAR2IDX.items()}
PAD_IDX = CHAR2IDX["<PAD>"]
BOS_IDX = CHAR2IDX["<BOS>"]
EOS_IDX = CHAR2IDX["<EOS>"]

In [6]:
def encode_word(word, max_len):
    tokens = [BOS_IDX] + [CHAR2IDX[c] for c in word] + [EOS_IDX]
    tokens += [PAD_IDX] * (max_len - len(tokens))
    return tokens

def decode_tokens(tokens):
    chars = []
    for idx in tokens:
        if idx == EOS_IDX:
            break
        if idx >= len(IDX2CHAR):
            continue
        ch = IDX2CHAR[idx]
        if ch not in SPECIAL_TOKENS:
            chars.append(ch)
    return "".join(chars)

# ----- データセットクラス -----
class NameDataset(Dataset):
    def __init__(self, words, max_len):
        self.max_len = max_len
        self.data = []
        for w in words:
            tokens = encode_word(w, max_len)
            self.data.append(tokens)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        tokens = self.data[idx]
        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)
        return x, y

In [7]:
# ----- 位置エンコーディング -----
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model > 1:
            pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    def forward(self, x):
        return x + self.pe[:x.size(1)]

# ----- MiniGPT2本体（Pre-LN） -----
class MiniGPT2(nn.Module):
    def __init__(self, vocab_size, d_model=32, max_len=16):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.q_linear = nn.Linear(d_model, d_model, bias=False)
        self.k_linear = nn.Linear(d_model, d_model, bias=False)
        self.v_linear = nn.Linear(d_model, d_model, bias=False)
        self.attn_out = nn.Linear(d_model, d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model, bias=False),
            nn.ReLU(),
            nn.Linear(d_model, d_model, bias=False)
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.max_len = max_len
        self.attn_weights = None

    def forward(self, x, return_attn=False):
        emb = self.embed(x)
        emb = self.pos_enc(emb)
        # Pre-LN構造：LN→Attention→Add
        attn_in = self.ln1(emb)
        Q = self.q_linear(attn_in)
        K = self.k_linear(attn_in)
        V = self.v_linear(attn_in)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_model)
        mask = torch.triu(torch.ones(scores.size(-2), scores.size(-1)), diagonal=1).bool().to(x.device)
        scores = scores.masked_fill(mask, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        attn_out = torch.matmul(attn, V)
        attn_out = self.attn_out(attn_out)
        x1 = emb + attn_out
        # Pre-LN→FFN→Add
        ffn_in = self.ln2(x1)
        x2 = x1 + self.ffn(ffn_in)
        logits = torch.matmul(x2, self.embed.weight.t())
        if return_attn:
            self.attn_weights = attn.detach().cpu().numpy()
            return logits, attn
        return logits

    def generate(self, start_tokens, eos_idx, pad_idx, max_gen=None, temperature=1.0):
        self.eval()
        max_gen = max_gen or self.max_len
        tokens = start_tokens.tolist()
        for _ in range(max_gen - len(tokens)):
            inp = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(next(self.parameters()).device)
            logits = self.forward(inp)
            next_token_logits = logits[0, len(tokens)-1] / temperature
            probs = torch.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            if next_token == eos_idx:
                break
            tokens.append(next_token)
        while len(tokens) < self.max_len:
            tokens.append(pad_idx)
        return tokens

In [8]:
# ----- 訓練・評価 -----
def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits.view(-1, VOCAB_SIZE), y.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = criterion(logits.view(-1, VOCAB_SIZE), y.view(-1))
            total_loss += loss.item()
    return total_loss / len(loader)

In [9]:
%%time

# ----- 設定 -----
max_word_len = max(len(w) for w in NAMES) + 2 # BOS, EOS
batch_size = 16
d_model = 32
n_epochs = 100
device = "cuda" if torch.cuda.is_available() else "cpu"

# ----- データ分割 -----
random.seed(42)
random.shuffle(NAMES)
split = int(len(NAMES) * 0.8)
train_words = NAMES[:split]
test_words = NAMES[split:]

train_ds = NameDataset(train_words, max_word_len)
test_ds = NameDataset(test_words, max_word_len)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

# ----- モデル -----
model = MiniGPT2(VOCAB_SIZE, d_model, max_word_len).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# ----- 学習ループ -----
for epoch in range(1, n_epochs+1):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    test_loss = evaluate(model, test_loader, criterion, device)
    if epoch % 5 == 0:
        print(f"Epoch {epoch:2d}: train loss={train_loss:.4f}, test loss={test_loss:.4f}")

Epoch  5: train loss=3.2510, test loss=3.1007
Epoch 10: train loss=2.6080, test loss=2.7307
Epoch 15: train loss=2.5008, test loss=2.5772
Epoch 20: train loss=2.4259, test loss=2.5658
Epoch 25: train loss=2.3869, test loss=2.5829
Epoch 30: train loss=2.3695, test loss=2.5618
Epoch 35: train loss=2.3564, test loss=2.5222
Epoch 40: train loss=2.3264, test loss=2.5279
Epoch 45: train loss=2.3214, test loss=2.4982
Epoch 50: train loss=2.2958, test loss=2.5043
Epoch 55: train loss=2.2881, test loss=2.5362
Epoch 60: train loss=2.2765, test loss=2.5286
Epoch 65: train loss=2.2799, test loss=2.5346
Epoch 70: train loss=2.2746, test loss=2.5118
Epoch 75: train loss=2.2542, test loss=2.5262
Epoch 80: train loss=2.2468, test loss=2.5146
Epoch 85: train loss=2.2388, test loss=2.4905
Epoch 90: train loss=2.2196, test loss=2.5354
Epoch 95: train loss=2.2125, test loss=2.4927
Epoch 100: train loss=2.1978, test loss=2.5093
CPU times: user 15.6 s, sys: 511 ms, total: 16.2 s
Wall time: 16.4 s


In [10]:
# ----- 生成例 -----
def sample_generate(prompt, model, max_word_len, temperature=1.0):
    start_tokens = [BOS_IDX] + [CHAR2IDX[c] for c in prompt]
    start_tokens = torch.tensor(start_tokens, dtype=torch.long)
    out_tokens = model.generate(start_tokens, eos_idx=EOS_IDX, pad_idx=PAD_IDX, max_gen=max_word_len, temperature=temperature)
    return decode_tokens(out_tokens[1:])  # BOSを除く

# ↓ 例： "ca" で始まる動物名っぽいものを生成
print("生成例:", sample_generate("ca", model, max_word_len, temperature=0.8))

生成例: categot


In [11]:
print("生成例:", sample_generate("do", model, max_word_len, temperature=0.8))

生成例: docketep


In [12]:
print("生成例:", sample_generate("app", model, max_word_len, temperature=0.8))

生成例: appincho


In [13]:
print("生成例:", sample_generate("sh", model, max_word_len, temperature=0.8))

生成例: she


In [14]:
print("生成例:", sample_generate("ca", model, max_word_len, temperature=0.8))

生成例: capalouis


In [15]:
print("生成例:", sample_generate("ca", model, max_word_len, temperature=0.8))

生成例: cargpen
