In [None]:
!pip install -qU bertviz

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/157.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m157.5/157.5 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/139.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.3/139.3 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.0/14.0 MB[0m [31m83.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m57.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.3/85.3 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from __future__ import annotations

import json
import random
import string
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np


SEED = 42
# Seed every RNG source used below, in order: Python, NumPy, PyTorch CPU, then CUDA if present.
_rng_seeders = [random.seed, np.random.seed, torch.manual_seed]
if torch.cuda.is_available():
    _rng_seeders.append(torch.cuda.manual_seed_all)
for _seeder in _rng_seeders:
    _seeder(SEED)
del _rng_seeders, _seeder

In [None]:
# Animal names (mammals, birds, fish, invertebrates) used as training words.
# Kept as one whitespace-separated string for compactness; ``split()`` yields the
# same ordered list, including the repeated "canary" and "caterpillar" entries.
ANIMALS = """
cat caracal capybara canary cavy caiman cacomistle caribou cassowary caterpillar
dog dalmatian dachshund doberman duck dingo lion tiger leopard cheetah
puma jaguar lynx ocelot serval bobcat cougar panther wolf fox
jackal coyote hyena bear polar grizzly slothbear panda koala kangaroo
wallaby opossum wombat tasmanian rabbit hare mouse rat gerbil hamster
guinea squirrel chipmunk beaver porcupine hedgehog shrew mole bat armadillo
antelope gazelle impala gnu eland springbok deer moose elk reindeer
stag doe fawn buffalo bison yak zebu cow bull ox
calf sheep lamb ram goat kid ibex chamois camel dromedary
llama alpaca vicuna horse mare stallion colt foal donkey mule
zebra rhinoceros hippopotamus pig boar hog swine babirusa tapir elephant
mammoth mastodon dugong manatee whale dolphin porpoise seal seaotter walrus
otter weasel ferret marten ermine badger skunk wolverine mongoose meerkat
civet genet fossa bearcat platypus echidna pangolin aardvark aardwolf okapi
giraffe monkey baboon mandrill macaque langur gibbon gorilla chimpanzee bonobo
orangutan lemur tarsier loris ayeaye sloth anteater tamandua kitten puppy
duckling gosling cygnet eagle hawk falcon osprey vulture buzzard kite
owl barnowl tawnyowl screechowl snowyowl parrot macaw cockatoo budgerigar lovebird
lorikeet conure parakeet kingfisher woodpecker toucan hornbill cuckoo cuckooshrike nightjar
swift hummingbird swallow martin wren warbler thrush blackbird starling mockingbird
finch canary sparrow bunting lark pipit wagtail robin chat wheatear
dipper nuthatch treecreeper tit chickadee jay magpie crow raven rook
jackdaw chough shrike oriole drongo bulbul mina weaver whydah waxbill
munia manakin cotinga antbird ovenbird woodcreeper flycatcher tyrant pewee kingbird
boatbill motmot tody jacamar puffbird barbet toucanet ani turaco hoatzin
bustard crane heron egret bittern stork ibis spoonbill flamingo swan
goose teal wigeon shoveler pintail scaup pochard canvasback redhead goldeneye
merganser eider scoter shelduck woodduck mandarin mallard gadwall grebe loon
penguin albatross petrel shearwater prion stormpetrel fulmar gannet booby cormorant
shag anhinga frigatebird tropicbird pelican darter gull tern skimmer auk
murre puffin guillemot razorbill dovekie murrelet kiwi emu rhea ostrich
tinamou rail crake gallinule coot limpkin buttonquail plover lapwing dotterel
killdeer oystercatcher avocet stilt phalarope jacana sandpiper snipe curlew godwit
dowitcher stint ruff turnstone knot pratincole courser skua jaeger eel
salmon trout carp catfish cobia cod coelacanth flounder goby grouper
guppy haddock hake halibut koi mackerel minnow perch pike pollock
sardine shad smelt snapper sole sturgeon tilapia tuna wahoo zander
anchovy barracuda bass blenny bluegill bonito bream butterfish capelin char
clownfish drum grunion herring killifish lamprey lionfish loach molly mudskipper
needlefish parrotfish pompano scad sculpin seahorse shark skate sprat sucker
sunfish surgeonfish tarpon tetra trevally triggerfish wrasse tang abalone barnacle
clam cockle conch crab crawfish krill limpet lobster mussel nautilus
oyster periwinkle prawn scallop shrimp snail squid octopus urchin worm
beetle butterfly caterpillar dragonfly earwig firefly flea grasshopper ladybug mantis
moth termite tick wasp weevil aphid ant bee bug cricket
damselfly fly gnat hornet mayfly mosquito silverfish spider centipede millipede
scorpion copepod isopod amphipod woodlouse horseshoecrab arachnid mite tarantula fruitfly
""".split()

In [None]:
# Fruit, vegetable, nut and herb names used as training words.
# Kept as one whitespace-separated string; ``split()`` yields the same ordered list,
# including its many repeated entries (e.g. "kohlrabi" and "zucchini" appear 3 times).
FRUITS_VEGGIES = """
apple apricot avocado artichoke banana bilberry blackberry blueberry boysenberry breadfruit
cantaloupe casaba carambola cherimoya cherry cloudberry coconut cranberry currant date
elderberry fig gooseberry grape grapefruit guava honeydew jackfruit jambul jujube
kiwi kumquat lemon lime loquat lychee mandarin mango mangosteen melon
mulberry nectarine olive orange papaya passionfruit peach pear persimmon pineapple
plum pomegranate pomelo quince raspberry redcurrant salak satsuma starfruit strawberry
tamarillo tamarind tangelo ugli watermelon yuzu zucchini carrot cabbage cauliflower
cassava celery chard chicory collard corn cress cucumber daikon edamame
eggplant endive fennel garlic ginger horseradish jicama kale kohlrabi leek
lettuce okra onion parsnip pea pepper potato pumpkin radish rutabaga
shallot spinach squash sweetcorn sweetpotato tomato turnip wasabi yam macadamia
pecan cashew hazelnut walnut almond brazilnut chestnut pistachio pine acorn
watercress caper cardoon canna caraway carob camu camote canistel canola
capers carissa catjang cavendish cayenne celeriac chayote cilantro clementine cornsalad
courgette currant cushaw dandelion dill durian endive escarole fiddlehead frisee
gourd jostaberry kohlrabi lablab luffa malanga mangetout mungbean navybean nopale
onionchive parsley parsnip pattypan peasnap persimmon pigeonpea plantain pluot pomegranate
prune pumpkin radicchio rambutan rapini rocket rutabaga salsify sapote scallion
shallot snowpea sorrel soybean spelt squash tamarind tangelo tatsoi tomatillo
tuber turnip waterchestnut watermelon waxgourd yambean yautia yuca ziziphus zucchini
acerola ackee ambarella arugula asparagus azuki bamboo basil bean beet
bellpepper betel bokchoy broccoli broccolini brusselsprout burdock butternut calabash calamansi
canarymelon cantaloupe capuli carambola carrot cassava cauliflower celery chamomile cherry
chickpea chicory chives cilantro citrus collards coriander courgette cranberry cress
cucumber currant daikon dandelion dates dragonfruit durian eggplant elderberry endive
fennel feijoa fig fiddlehead garbanzo garlic ginger gooseberry grape grapefruit
guava habanero honeydew horseradish iceberg jackfruit jalapeno jicama jostaberry jujube
kabocha kale kiwi kohlrabi kumquat leek lemon lentil lettuce licorice
lime lingonberry loquat luffa lychee maca mandarin mango mangosteen marrow
melon mungbean mustard nectarine okra olive onion orange oregano papaya
parsley parsnip passionfruit pea peach peanut pear pecan pepper persimmon
pineapple pistachio plum pomegranate pomelo potato pumpkin quince radish rambutan
rapini raspberry redcurrant rhubarb rocket rutabaga salsify sapote scallion shallot
snappea sorrel soybean spinach spelt squash starfruit strawberry sweetcorn sweetpotato
tamarillo tamarind tangelo tatsoi tomatillo tomato turnip ugli watercress watermelon
waxgourd yam yuca ziziphus zucchini
""".split()

In [None]:
# 1. Dataset words: animal names + fruit/vegetable names.
# Both source lists repeat entries (e.g. "canary", "kohlrabi") and they also share
# words ("kiwi", "mandarin"). De-duplicate while keeping first-occurrence order so a
# word can never end up in both the train and the test split (test-set leakage).
NAMES = list(dict.fromkeys(ANIMALS + FRUITS_VEGGIES))

# 2. Character vocabulary: special tokens first, then every character used in NAMES.
ALL_CHARS = sorted(set("".join(NAMES)))
SPECIAL_TOKENS = ["<PAD>", "<BOS>", "<EOS>"]
ALL_TOKENS = SPECIAL_TOKENS + ALL_CHARS
VOCAB_SIZE = len(ALL_TOKENS)
CHAR2IDX = {ch: i for i, ch in enumerate(ALL_TOKENS)}
IDX2CHAR = {i: ch for ch, i in CHAR2IDX.items()}
PAD_IDX = CHAR2IDX["<PAD>"]
BOS_IDX = CHAR2IDX["<BOS>"]
EOS_IDX = CHAR2IDX["<EOS>"]

In [None]:
def encode_word(word, max_len):
    """Encode ``word`` as ``[<BOS>, chars..., <EOS>]`` right-padded with ``<PAD>``.

    Args:
        word: the string to encode; every character must be in ``CHAR2IDX``.
        max_len: total length of the returned list.

    Returns:
        A list of exactly ``max_len`` token ids.

    Raises:
        ValueError: if ``len(word) + 2 > max_len``. Returning an over-long list instead
            would break batching and positional indexing further downstream.
        KeyError: if ``word`` contains a character outside the vocabulary.
    """
    tokens = [BOS_IDX, *(CHAR2IDX[c] for c in word), EOS_IDX]
    if len(tokens) > max_len:
        raise ValueError(
            f"word {word!r} needs {len(tokens)} tokens (with BOS/EOS) but max_len={max_len}"
        )
    return tokens + [PAD_IDX] * (max_len - len(tokens))

def decode_tokens(tokens):
    """Turn a sequence of token ids back into a string.

    Decoding stops at the first ``<EOS>``. Ids outside the vocabulary (>= its size)
    and the special tokens (``<PAD>``, ``<BOS>``, ...) are skipped.
    """
    vocab_len = len(IDX2CHAR)
    pieces = []
    for idx in tokens:
        if idx == EOS_IDX:
            break
        if idx < vocab_len:
            ch = IDX2CHAR[idx]
            if ch not in SPECIAL_TOKENS:
                pieces.append(ch)
    return "".join(pieces)

# 3. PyTorch Dataset for next-character prediction.
class NameDataset(Dataset):
    """Dataset of encoded words for teacher-forced next-token training.

    Every word is encoded once with ``encode_word(word, max_len)``. Item ``i`` is the
    pair ``(input, target)`` of LongTensors, where the target is the input shifted
    left by one position (``tokens[:-1]`` and ``tokens[1:]``).
    """

    def __init__(self, words, max_len):
        self.max_len = max_len
        self.data = [encode_word(word, max_len) for word in words]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        ids = self.data[idx]
        inputs = torch.tensor(ids[:-1], dtype=torch.long)
        targets = torch.tensor(ids[1:], dtype=torch.long)
        return inputs, targets

In [None]:
# 4. Simple sinusoidal positional encoding.
class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to a batch of embeddings.

    ``pe[pos, 2i] = sin(pos * w_i)`` and ``pe[pos, 2i + 1] = cos(pos * w_i)`` with
    ``w_i = 10000 ** (-2i / d_model)``. Odd ``d_model`` is supported: the last column
    is a sine and there is one fewer cosine column.

    Args:
        d_model: embedding width.
        max_len: number of positions precomputed (longest supported sequence).
    """

    def __init__(self, d_model, max_len):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-np.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # ceil(d_model/2) frequencies but only floor(d_model/2) odd columns: trim for odd d_model.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer, not a parameter: moves with .to(device) but is never trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add encodings for the first ``x.size(1)`` positions.

        Expects ``x`` shaped (batch, seq_len, d_model), as produced by an
        ``nn.Embedding`` on (batch, seq_len) ids, with ``seq_len <= max_len``.
        """
        return x + self.pe[:x.size(1)]

# 5. MiniFormer: a one-layer, single-head, decoder-only Transformer.
class MiniFormer(nn.Module):
    """Minimal causal language model over token ids.

    Pipeline: token embedding -> + positional encoding -> single-head causal
    self-attention (residual + LayerNorm) -> two-layer ReLU FFN (residual +
    LayerNorm) -> linear projection to ``vocab_size`` logits.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        d_model: hidden width used by every sub-layer.
        max_len: number of positions in the positional-encoding table.
    """

    def __init__(self, vocab_size, d_model=32, max_len=16):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        # Single-head attention projections: query, key, value, output.
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.attn_out = nn.Linear(d_model, d_model)
        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.max_len = max_len
        # Filled by forward(..., return_attn=True) with a NumPy copy, for visualization.
        self.attn_weights = None

    def forward(self, x, return_attn=False):
        """Return next-token logits (batch, seq_len, vocab_size) for ids ``x`` (batch, seq_len).

        With ``return_attn=True`` also return the (batch, seq_len, seq_len) attention
        weights, and store a detached NumPy copy in ``self.attn_weights``.
        """
        hidden = self.pos_enc(self.embed(x))
        query, key, value = self.q_linear(hidden), self.k_linear(hidden), self.v_linear(hidden)
        scores = (query @ key.transpose(-2, -1)) / np.sqrt(self.d_model)
        # Causal mask, True strictly above the diagonal: position i cannot see j > i.
        n_q, n_k = scores.shape[-2:]
        blocked = torch.ones(n_q, n_k, device=x.device).triu(diagonal=1).bool()
        attn = scores.masked_fill(blocked, float('-inf')).softmax(dim=-1)
        hidden = self.ln1(hidden + self.attn_out(attn @ value))
        hidden = self.ln2(hidden + self.ffn(hidden))
        logits = self.fc_out(hidden)
        if not return_attn:
            return logits
        self.attn_weights = attn.detach().cpu().numpy()
        return logits, attn

In [None]:
# 6. Training loop (MiniFormer: model(x) -> logits).
def train(model, loader, optimizer, criterion, device):
    """Run one training epoch and return the mean of the per-batch losses.

    Args:
        model: module mapping ids (batch, seq_len) to logits (batch, seq_len, vocab).
        loader: iterable of ``(x, y)`` id tensors with identical shapes.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking flattened logits (N, vocab) and targets (N,).
        device: device the batches are moved to.

    Returns:
        Average of per-batch losses (each batch weighted equally), or ``nan`` if
        ``loader`` yielded no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        # Vocabulary size is taken from the logits themselves rather than a global;
        # reshape also copes with non-contiguous outputs where view() would fail.
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    return total_loss / n_batches if n_batches else float("nan")

def evaluate(model, loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``loader`` without updating it.

    Args:
        model: module mapping ids (batch, seq_len) to logits (batch, seq_len, vocab).
        loader: iterable of ``(x, y)`` id tensors with identical shapes.
        criterion: loss taking flattened logits (N, vocab) and targets (N,).
        device: device the batches are moved to.

    Returns:
        Average of per-batch losses (each batch weighted equally), or ``nan`` if
        ``loader`` yielded no batches.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            logits = model(x)
            loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            total_loss += loss.item()
            n_batches += 1
    return total_loss / n_batches if n_batches else float("nan")

# 7. Save one attention map for visualization (simple JSON, close to what BertViz consumes).
def save_attention(attn_matrix, input_tokens, filename="attn_weights.json"):
    """Write an attention matrix and its token labels to a JSON file.

    Args:
        attn_matrix: attention weights, presumably (seq_len, seq_len); anything
            exposing ``.tolist()`` (NumPy array or torch tensor) works.
        input_tokens: token strings labelling the rows/columns.
        filename: output path; overwritten if it already exists.
    """
    payload = dict(tokens=input_tokens, attentions=attn_matrix.tolist())
    with open(filename, "w") as fh:
        json.dump(payload, fh, indent=2)

In [None]:
%%time

# 設定
max_word_len = max(len(w) for w in NAMES) + 2 # BOS, EOS
batch_size = 16
d_model = 32
#n_epochs = 30
n_epochs = 100
device = "cuda" if torch.cuda.is_available() else "cpu"

# データ分割
random.seed(42)
random.shuffle(NAMES)
split = int(len(NAMES) * 0.8)
train_words = NAMES[:split]
test_words = NAMES[split:]

train_ds = NameDataset(train_words, max_word_len)
test_ds = NameDataset(test_words, max_word_len)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

# モデル
model = MiniFormer(VOCAB_SIZE, d_model, max_word_len).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# 学習
for epoch in range(1, n_epochs+1):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    test_loss = evaluate(model, test_loader, criterion, device)
    if epoch % 5 == 0:
        print(f"Epoch {epoch:2d}: train loss={train_loss:.4f}, test loss={test_loss:.4f}")

Epoch  5: train loss=2.5863, test loss=2.5995
Epoch 10: train loss=2.4199, test loss=2.5167
Epoch 15: train loss=2.3452, test loss=2.4715
Epoch 20: train loss=2.2984, test loss=2.4463
Epoch 25: train loss=2.2558, test loss=2.4402
Epoch 30: train loss=2.2233, test loss=2.4276
Epoch 35: train loss=2.1984, test loss=2.4278
Epoch 40: train loss=2.1648, test loss=2.4174
Epoch 45: train loss=2.1425, test loss=2.4286
Epoch 50: train loss=2.1084, test loss=2.4406
Epoch 55: train loss=2.0990, test loss=2.4260
Epoch 60: train loss=2.0721, test loss=2.4203
Epoch 65: train loss=2.0585, test loss=2.4426
Epoch 70: train loss=2.0401, test loss=2.4364
Epoch 75: train loss=2.0262, test loss=2.4351
Epoch 80: train loss=2.0046, test loss=2.4198
Epoch 85: train loss=1.9933, test loss=2.4377
Epoch 90: train loss=1.9835, test loss=2.4383
Epoch 95: train loss=1.9606, test loss=2.4368
Epoch 100: train loss=1.9568, test loss=2.4564
CPU times: user 21.4 s, sys: 765 ms, total: 22.2 s
Wall time: 24.3 s


In [None]:
sample_word = "flycatcher"
# Same input layout as NameDataset: encoded word minus its last position.
x = torch.tensor(
    [encode_word(sample_word, max_word_len)[:-1]], dtype=torch.long
).to(device)

model.eval()
with torch.no_grad():
    logits, attn = model(x, return_attn=True)
pred_indices = logits.argmax(dim=-1)[0].cpu().numpy()
input_tokens = [IDX2CHAR[idx] for idx in x[0].cpu().numpy()]

print(f"Input: {sample_word}")
print(f"Predicted: {decode_tokens(pred_indices)}")
# Dump the attention map of the single batch item so the next cell can render it.
save_attention(attn[0], input_tokens, filename="attn_weights.json")
print("Saved attention weights to attn_weights.json. You can visualize with BertViz or any custom tool.")

Input: flycatcher
Predicted: clo
Saved attention weights to attn_weights.json. You can visualize with BertViz or any custom tool.


In [None]:
import json
import torch
from bertviz import head_view

# Load the attention dump written by save_attention().
with open('attn_weights.json') as f:
    data = json.load(f)

tokens = data['tokens']      # token strings
attn = data['attentions']    # nested list, (seq_len, seq_len)

# head_view expects one tensor per layer shaped (batch, heads, seq_len, seq_len);
# MiniFormer has a single layer with a single head, so add two leading axes.
attn_tensor = torch.tensor(attn)[None, None]

head_view(attention=[attn_tensor], tokens=tokens)

<IPython.core.display.Javascript object>

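## Multi-head, multi-layer model with PyTorch's built-in Transformer layers (`nn.TransformerEncoder` / `nn.TransformerDecoderLayer`)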

In [None]:
# Re-seed every RNG so this section does not depend on how much randomness the
# MiniFormer section above consumed.
SEED = 42
_rng_seeders = [random.seed, np.random.seed, torch.manual_seed]
if torch.cuda.is_available():
    _rng_seeders.append(torch.cuda.manual_seed_all)
for _seeder in _rng_seeders:
    _seeder(SEED)
del _rng_seeders, _seeder


# 2. Transformer decoder that can also return self- and cross-attention weights.


def _sinusoidal_pe(max_len, d_model):
    """Return a (1, max_len, d_model) sinusoidal positional-encoding table.

    Even columns hold sines, odd columns cosines. Odd d_model is supported
    (one fewer cosine column than sine columns).
    """
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-np.log(10000.0) / d_model)
    )
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
    return pe.unsqueeze(0)


def _future_mask(q_len, k_len, device=None):
    """Boolean (q_len, k_len) mask, True where key index > query index (i.e. blocked)."""
    return torch.ones(q_len, k_len, dtype=torch.bool, device=device).triu(diagonal=1)


class SimpleTransformerDecoder(nn.Module):
    """Stack of ``nn.TransformerDecoderLayer`` applied by hand so attention weights can be returned.

    Each layer runs in post-norm order (the ``TransformerDecoderLayer`` default):
    causal self-attention -> add & norm1 -> cross-attention -> add & norm2 ->
    feed-forward -> add & norm3, reusing the layer's own submodules.

    Args:
        d_model, max_len, nhead, num_layers: model sizes.
        vocab_size: rows of the target embedding. ``None`` falls back to the
            notebook-global ``VOCAB_SIZE`` (the previous behaviour).
    """

    def __init__(self, d_model=64, max_len=16, nhead=4, num_layers=2, vocab_size=None):
        super().__init__()
        if vocab_size is None:
            vocab_size = VOCAB_SIZE
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Fixed table: a buffer follows .to(device) and is saved in state_dict under the
        # same key, but is not handed to the optimizer as a (frozen) parameter.
        self.register_buffer("pos_enc", self._init_pe(max_len, d_model))
        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward=128, batch_first=True)
            for _ in range(num_layers)
        ])
        self.max_len = max_len
        self.d_model = d_model
        self.num_layers = num_layers
        self.nhead = nhead

    def _init_pe(self, max_len, d_model):
        return _sinusoidal_pe(max_len, d_model)

    def forward(self, tgt, memory, tgt_key_padding_mask=None, memory_key_padding_mask=None,
                return_attn=False, memory_mask=None):
        """Decode target ids ``tgt`` (batch, tgt_len) against encoder output ``memory``.

        Args:
            tgt: target token ids, (batch, tgt_len).
            memory: encoder output, (batch, src_len, d_model).
            tgt_key_padding_mask: (batch, tgt_len), True marks padding.
            memory_key_padding_mask: (batch, src_len), True marks padding.
            return_attn: also return per-layer attention weights (moved to CPU).
            memory_mask: optional (tgt_len, src_len) cross-attention mask, True = blocked.

        Returns:
            ``output`` (batch, tgt_len, d_model); with ``return_attn`` also two lists with
            one entry per layer: self-attention (batch, nhead, tgt_len, tgt_len) and
            cross-attention (batch, nhead, tgt_len, src_len) weights.
        """
        seq_len = tgt.size(1)
        output = self.embedding(tgt) + self.pos_enc[:, :seq_len, :]
        # Causal self-attention; bool mask so it matches the bool padding masks.
        tgt_mask = _future_mask(seq_len, seq_len, device=tgt.device)
        self_attn_weights_layers = []
        cross_attn_weights_layers = []
        for layer in self.layers:
            sa_out, self_attn_weights = layer.self_attn(
                output, output, output,
                attn_mask=tgt_mask,
                key_padding_mask=tgt_key_padding_mask,
                need_weights=True,
                average_attn_weights=False,
            )
            output = layer.norm1(output + layer.dropout1(sa_out))
            ca_out, cross_attn_weights = layer.multihead_attn(
                output, memory, memory,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask,
                need_weights=True,
                average_attn_weights=False,
            )
            output = layer.norm2(output + layer.dropout2(ca_out))
            ff_out = layer.linear2(layer.dropout(layer.activation(layer.linear1(output))))
            output = layer.norm3(output + layer.dropout3(ff_out))
            if return_attn:
                self_attn_weights_layers.append(self_attn_weights.detach().cpu())
                cross_attn_weights_layers.append(cross_attn_weights.detach().cpu())
        if return_attn:
            return output, self_attn_weights_layers, cross_attn_weights_layers
        return output


# 3. Full Transformer: encoder + the custom decoder above.
class FullTransformer(nn.Module):
    """Encoder-decoder Transformer producing per-position vocabulary logits.

    In this notebook the same sequence is fed as ``src`` and ``tgt`` to train a
    next-character model. With a bidirectional encoder and unmasked cross-attention,
    decoder position t could read token t+1 (its own target) from the encoder memory.
    ``causal_memory=True`` (the default) therefore makes the encoder causal and
    restricts cross-attention to memory positions <= t. Pass ``causal_memory=False``
    for genuine sequence-to-sequence use, where the source may be seen in full.

    Args:
        vocab_size: number of token ids (shared by encoder/decoder embeddings and output).
        d_model, max_len, nhead, num_layers: model sizes.
    """

    def __init__(self, vocab_size, d_model=64, max_len=16, nhead=4, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.register_buffer("pos_enc", self._init_pe(max_len, d_model))
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=128, batch_first=True),
            num_layers=num_layers,
        )
        # Pass vocab_size through so the decoder embedding matches this model's vocabulary.
        self.decoder = SimpleTransformerDecoder(d_model, max_len, nhead, num_layers, vocab_size=vocab_size)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.max_len = max_len
        self.d_model = d_model

    def _init_pe(self, max_len, d_model):
        return _sinusoidal_pe(max_len, d_model)

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None,
                return_attn=False, causal_memory=True):
        """Return logits (batch, tgt_len, vocab_size) for ``src`` (batch, src_len) and ``tgt`` (batch, tgt_len).

        With ``return_attn`` also return the decoder's per-layer self- and
        cross-attention weight lists.
        """
        src_len, tgt_len = src.size(1), tgt.size(1)
        src_emb = self.embedding(src) + self.pos_enc[:, :src_len, :]
        if causal_memory:
            enc_mask = _future_mask(src_len, src_len, device=src.device)
            memory_mask = _future_mask(tgt_len, src_len, device=src.device)
        else:
            enc_mask = memory_mask = None
        memory = self.encoder(src_emb, mask=enc_mask, src_key_padding_mask=src_key_padding_mask)
        decoded = self.decoder(
            tgt, memory,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_key_padding_mask,
            return_attn=return_attn,
            memory_mask=memory_mask,
        )
        if return_attn:
            dec_out, self_attn_layers, cross_attn_layers = decoded
            return self.fc_out(dec_out), self_attn_layers, cross_attn_layers
        return self.fc_out(decoded)

In [None]:
# 4. Training loop for FullTransformer: the same id sequence is fed as encoder input
#    and decoder input. NOTE(review): this redefines `train` from the MiniFormer
#    section; cells executed after this one use this version.
def train(model, loader, optimizer, criterion, device, pad_idx=None):
    """Run one training epoch and return the mean of the per-batch losses.

    Args:
        model: called as ``model(x, x, src_key_padding_mask=..., tgt_key_padding_mask=...)``
            and expected to return logits (batch, seq_len, vocab).
        loader: iterable of ``(x, y)`` id tensors with identical shapes.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking flattened logits (N, vocab) and targets (N,).
        device: device the batches are moved to.
        pad_idx: padding id used to build the key-padding masks; defaults to ``PAD_IDX``.

    Returns:
        Average of per-batch losses (each batch weighted equally), or ``nan`` if
        ``loader`` yielded no batches.
    """
    pad = PAD_IDX if pad_idx is None else pad_idx
    model.train()
    total_loss = 0.0
    n_batches = 0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        pad_mask = x == pad  # same mask for encoder and decoder, since both see x
        optimizer.zero_grad()
        logits = model(x, x, src_key_padding_mask=pad_mask, tgt_key_padding_mask=pad_mask)
        loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    return total_loss / n_batches if n_batches else float("nan")

def evaluate(model, loader, criterion, device, pad_idx=None):
    """Return the mean per-batch loss of an encoder-decoder ``model`` without updating it.

    The same ids ``x`` are fed as encoder and decoder input, with key-padding masks
    ``x == pad_idx``. NOTE(review): this redefines `evaluate` from the MiniFormer section.

    Args:
        model: called as ``model(x, x, src_key_padding_mask=..., tgt_key_padding_mask=...)``.
        loader: iterable of ``(x, y)`` id tensors with identical shapes.
        criterion: loss taking flattened logits (N, vocab) and targets (N,).
        device: device the batches are moved to.
        pad_idx: padding id used to build the masks; defaults to ``PAD_IDX``.

    Returns:
        Average of per-batch losses (each batch weighted equally), or ``nan`` if
        ``loader`` yielded no batches.
    """
    pad = PAD_IDX if pad_idx is None else pad_idx
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            pad_mask = x == pad
            logits = model(x, x, src_key_padding_mask=pad_mask, tgt_key_padding_mask=pad_mask)
            loss = criterion(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            total_loss += loss.item()
            n_batches += 1
    return total_loss / n_batches if n_batches else float("nan")

def save_attention_bertviz(attn_layers, tokens, filename="self_attn_bertviz.json"):
    """Dump multi-layer, multi-head attention weights plus tokens to JSON.

    Args:
        attn_layers: nested sequence ``[num_layers][nhead]`` of 2-D weight matrices
            (torch tensors are moved to CPU; NumPy arrays are used as-is).
        tokens: token strings labelling the matrices' rows/columns.
        filename: output path; overwritten if it already exists.

    The JSON has keys ``"all"`` (num_layers x nhead x seq x seq nested lists) and
    ``"tokens"``; a confirmation message is printed afterwards.
    """
    def _to_nested_list(head):
        if isinstance(head, torch.Tensor):
            head = head.cpu().numpy()
        return head.tolist()

    payload = {
        "all": [[_to_nested_list(head) for head in layer] for layer in attn_layers],
        "tokens": tokens,
    }
    with open(filename, "w") as fh:
        json.dump(payload, fh, indent=2)
    print(f"Saved self-attention to {filename} (bertviz format, with tokens)")

In [None]:
%%time

max_word_len = max(len(w) for w in NAMES) + 2
batch_size = 16
d_model = 64
n_epochs = 100
device = "cuda" if torch.cuda.is_available() else "cpu"

random.shuffle(NAMES)
split = int(len(NAMES) * 0.8)
train_words = NAMES[:split]
test_words = NAMES[split:]

train_ds = NameDataset(train_words, max_word_len)
test_ds = NameDataset(test_words, max_word_len)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

model = FullTransformer(VOCAB_SIZE, d_model, max_word_len, nhead=4, num_layers=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

for epoch in range(1, n_epochs + 1):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    test_loss = evaluate(model, test_loader, criterion, device)
    if epoch % 5 == 0:
        print(f"Epoch {epoch:2d}: train loss={train_loss:.4f}, test loss={test_loss:.4f}")

# ---- 正解系列（teacher forcing）でセルフアテンション保存 ----
sample_word = "flycatcher"
x = torch.tensor([encode_word(sample_word, max_word_len)[:-1]], dtype=torch.long).to(device)
tgt = torch.tensor([encode_word(sample_word, max_word_len)[:-1]], dtype=torch.long).to(device)  # 正解系列
model.eval()
with torch.no_grad():
    logits, self_attn_layers, cross_attn_layers = model(
        x, tgt,
        src_key_padding_mask=(x == PAD_IDX),
        tgt_key_padding_mask=(tgt == PAD_IDX),
        return_attn=True
    )
    # self_attn_layers: [num_layers][batch, nhead, tgt_len, tgt_len]
    attn_layers = [
        [self_attn_layers[l][0, h].cpu().numpy() for h in range(self_attn_layers[l].shape[1])]
        for l in range(len(self_attn_layers))
    ]
    tokens = [IDX2CHAR[idx] for idx in tgt[0].cpu().numpy()]
    save_attention_bertviz(attn_layers, tokens, filename="self_attn_bertviz.json")
    print("Saved self-attention weights to self_attn_bertviz.json. You can visualize with BertViz.")

  output = torch._nested_tensor_from_mask(


Epoch  5: train loss=0.4274, test loss=0.3032
Epoch 10: train loss=0.1081, test loss=0.0877
Epoch 15: train loss=0.0651, test loss=0.0935
Epoch 20: train loss=0.0320, test loss=0.0330
Epoch 25: train loss=0.0292, test loss=0.0297
Epoch 30: train loss=0.0149, test loss=0.0267
Epoch 35: train loss=0.0275, test loss=0.0224
Epoch 40: train loss=0.0177, test loss=0.0463
Epoch 45: train loss=0.0081, test loss=0.0313
Epoch 50: train loss=0.0247, test loss=0.0218
Epoch 55: train loss=0.0088, test loss=0.0149
Epoch 60: train loss=0.0454, test loss=0.0698
Epoch 65: train loss=0.0185, test loss=0.0278
Epoch 70: train loss=0.0152, test loss=0.0160
Epoch 75: train loss=0.0108, test loss=0.0421
Epoch 80: train loss=0.0263, test loss=0.0218
Epoch 85: train loss=0.0042, test loss=0.0340
Epoch 90: train loss=0.0121, test loss=0.0033
Epoch 95: train loss=0.0133, test loss=0.0253
Epoch 100: train loss=0.0101, test loss=0.0208
Saved self-attention to self_attn_bertviz.json (bertviz format, with tokens)
Sa

In [None]:
import numpy as np
from bertviz import head_view

with open('self_attn_bertviz.json') as f:
    data = json.load(f)
tokens = data['tokens']
attn_all = np.array(data['all'])  # (num_layers, num_heads, seq_len, seq_len)

# head_view takes one tensor per layer, shaped (batch=1, heads, seq_len, seq_len).
attention = [torch.tensor(layer_attn)[None] for layer_attn in attn_all]

print(f"attn_all.shape: {attn_all.shape}, tokens: {tokens}")
head_view(attention=attention, tokens=tokens)

attn_all.shape: (2, 4, 14, 14), tokens: ['<BOS>', 'f', 'l', 'y', 'c', 'a', 't', 'c', 'h', 'e', 'r', '<EOS>', '<PAD>', '<PAD>']


<IPython.core.display.Javascript object>