# Transformer をゼロから実装

Google Colab で GPU を使って学習します。

**GPU を有効にする:** ランタイム → ランタイムのタイプを変更 → GPU

In [None]:
# Check which PyTorch build is installed and whether a CUDA GPU is visible.
import torch

print(f"PyTorch: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()  # query once, reuse for both the report and the branch
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    # Device index 0 is the first visible CUDA device.
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Multi-Head Self-Attention の実装

In [None]:
import torch
import torch.nn as nn
import math


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input:

    * Q (query): what a position is looking for.
    * K (key):   what a position has to offer.
    * V (value): the information that is actually mixed into the output.

        Attention(Q, K, V) = softmax(Q K^T / sqrt(head_dim)) V

    Args:
        embed_size: Width of the input and output features.
        num_heads: Number of attention heads; must be positive and divide
            ``embed_size`` evenly.

    Raises:
        ValueError: If ``embed_size`` cannot be split evenly across
            ``num_heads`` heads.
    """

    def __init__(self, embed_size: int, num_heads: int):
        super().__init__()
        # Raise instead of assert so the check is not stripped by `python -O`.
        if num_heads <= 0 or embed_size % num_heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        self.out = nn.Linear(embed_size, embed_size)

    def _split_heads(self, projected, batch_size: int, seq_len: int):
        """Reshape (batch, seq, embed) into (batch, heads, seq, head_dim)."""
        return projected.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, embed_size)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq_len, seq_len)``. Positions where it
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_size)``.
        """
        batch_size, seq_len, _ = x.shape

        Q = self._split_heads(self.query(x), batch_size, seq_len)
        K = self._split_heads(self.key(x), batch_size, seq_len)
        V = self._split_heads(self.value(x), batch_size, seq_len)

        # (batch, heads, seq, seq); scaling keeps the logits' variance
        # independent of head_dim.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Fill with the most negative finite value rather than -inf: a
            # query row whose keys are ALL masked would otherwise become
            # softmax(-inf, ..., -inf) = NaN and poison the whole batch.
            # Rows with at least one visible key get exactly the same
            # weights as before (masked entries still underflow to 0).
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)

        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, embed).
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_size)
        return self.out(output)

## 2. Transformer モデル

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence of embeddings.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / embed_size)``. The table
    is precomputed once for ``max_len`` positions and stored as a buffer, so
    it moves with the module across devices but is not a trainable parameter.

    Args:
        embed_size: Feature width of the embeddings (odd widths are supported).
        max_len: Number of positions to precompute; inputs passed to
            ``forward`` must not be longer than this.
    """

    def __init__(self, embed_size: int, max_len: int = 5000):
        super().__init__()

        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
        )
        # One angle per (position, frequency) pair: (max_len, ceil(embed_size / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, embed_size)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_size there is one fewer odd column than even
        # columns, so drop the last frequency for the cosine half.
        pe[:, 1::2] = torch.cos(angles[:, : embed_size // 2])

        # Leading axis of size 1 lets the table broadcast over the batch.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, embed_size)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len, :]


class FeedForward(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) applied to the last dimension.

    Args:
        embed_size: Input and output feature width.
        hidden_size: Width of the inner layer. A falsy value (``None`` or 0)
            selects ``4 * embed_size``.
    """

    def __init__(self, embed_size: int, hidden_size: int = None):
        super().__init__()
        if not hidden_size:
            hidden_size = embed_size * 4

        expand = nn.Linear(embed_size, hidden_size)
        contract = nn.Linear(hidden_size, embed_size)
        self.net = nn.Sequential(expand, nn.GELU(), contract)

    def forward(self, x):
        """Return the MLP output; the shape of ``x`` is preserved."""
        transformed = self.net(x)
        return transformed


class TransformerBlock(nn.Module):
    """One Transformer layer: Attention -> Add & Norm -> FeedForward -> Add & Norm.

    Each sub-layer output passes through dropout, is added back to its input
    (residual connection) and the sum is layer-normalised.

    Args:
        embed_size: Feature width of the layer's input and output.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to both sub-layer outputs.
    """

    def __init__(self, embed_size: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_size, num_heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.ff = FeedForward(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        # A single Dropout module is shared by both residual branches.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask=None):
        """Run the layer on ``x``; ``mask`` is forwarded to the attention."""
        attended = self.dropout(self.attention(x, mask))
        after_attention = self.norm1(x + attended)

        transformed = self.dropout(self.ff(after_attention))
        return self.norm2(after_attention + transformed)


class Transformer(nn.Module):
    """GPT-style Transformer for text generation.

    Token ids are embedded, combined with positional encodings, passed
    through a stack of ``TransformerBlock`` layers and a final LayerNorm,
    then projected to one logit per vocabulary entry.

    Args:
        vocab_size: Number of distinct tokens.
        embed_size: Feature width used throughout the model.
        num_heads: Attention heads per block.
        num_layers: Number of stacked blocks.
        max_len: Number of positions the positional encoding is built for.
        dropout: Dropout probability for the embeddings and every block.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 128,
        num_heads: int = 4,
        num_layers: int = 4,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embed_size = embed_size
        self.token_embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoding = PositionalEncoding(embed_size, max_len)
        self.blocks = nn.ModuleList(
            TransformerBlock(embed_size, num_heads, dropout)
            for _ in range(num_layers)
        )
        self.ln_final = nn.LayerNorm(embed_size)
        self.output = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Map token ids ``x`` to next-token logits.

        Args:
            x: Integer tensor of token ids; this notebook passes
                ``(batch, seq_len)``.
            mask: Optional attention mask handed unchanged to every block.
                No mask is built here, so attention is unrestricted when it
                is ``None``.

        Returns:
            Logits whose last dimension has size ``vocab_size``.
        """
        hidden = self.dropout(self.pos_encoding(self.token_embedding(x)))
        for layer in self.blocks:
            hidden = layer(hidden, mask)
        return self.output(self.ln_final(hidden))

    @staticmethod
    def create_causal_mask(seq_len: int, device):
        """Build a ``(1, 1, seq_len, seq_len)`` lower-triangular mask of ones.

        Entry ``[i, j]`` is 1 when ``j <= i`` and 0 otherwise, so a position
        can only attend to itself and to earlier positions.
        """
        lower_triangle = torch.ones(seq_len, seq_len, device=device).tril()
        # Two leading singleton axes broadcast over batch and heads.
        return lower_triangle[None, None, :, :]

## 3. データセット

In [None]:
from torch.utils.data import Dataset, DataLoader


class TextDataset(Dataset):
    """Character-level next-token dataset built from a single string.

    The vocabulary is the sorted set of characters in ``text``. Sample ``i``
    is the pair ``(x, y)`` where ``x`` holds ``seq_length`` consecutive
    character ids starting at ``i`` and ``y`` is the same window shifted one
    character to the right (the prediction target for every position).

    Args:
        text: Training corpus.
        seq_length: Number of characters per sample.
    """

    def __init__(self, text: str, seq_length: int = 128):
        self.seq_length = seq_length
        self.chars = sorted(set(text))
        self.vocab_size = len(self.chars)
        self.char_to_idx = {ch: i for i, ch in enumerate(self.chars)}
        self.idx_to_char = dict(enumerate(self.chars))
        self.data = self.encode(text)

    def __len__(self):
        # Each sample needs seq_length + 1 characters; clamp at 0 because
        # len() rejects negative values when the text is too short.
        return max(0, len(self.data) - self.seq_length)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at ``idx``.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is out of range. Slicing alone never
                raises, so without this check plain iteration over the
                dataset would never terminate and out-of-range indices
                would return truncated, mismatched pairs.
        """
        size = len(self)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} is out of range for {size} samples")
        if idx < 0:
            idx += size

        x = self.data[idx : idx + self.seq_length]
        y = self.data[idx + 1 : idx + self.seq_length + 1]
        return x, y

    def encode(self, text: str) -> torch.Tensor:
        """Convert ``text`` to a 1-D ``long`` tensor of character ids.

        Raises:
            KeyError: If ``text`` contains a character outside the vocabulary.
        """
        # Explicit dtype: torch.tensor([]) would otherwise default to float.
        return torch.tensor([self.char_to_idx[ch] for ch in text], dtype=torch.long)

    def decode(self, indices: torch.Tensor) -> str:
        """Convert a 1-D tensor (or any iterable of ints) of ids to a string."""
        return "".join(self.idx_to_char[int(i)] for i in indices)

## 4. 学習データの準備

Tiny Shakespeare(シェイクスピア作品のテキストコーパス)をダウンロードして読み込みます。

In [None]:
import urllib.request

# Download the Tiny Shakespeare corpus (from the karpathy/char-rnn repository)
# into the working directory.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
urllib.request.urlretrieve(url, "shakespeare.txt")

# Decode explicitly as UTF-8: without `encoding`, open() falls back to the
# platform's locale encoding, so the same file could decode differently (or
# fail) depending on where the notebook runs.
with open("shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

print(f"テキスト長: {len(text):,} 文字")
print("\n--- 冒頭 ---")
print(text[:500])

## 5. 学習

In [None]:
# Hyperparameters
BATCH_SIZE = 64
SEQ_LENGTH = 128
EMBED_SIZE = 128
NUM_HEADS = 4
NUM_LAYERS = 4
EPOCHS = 10
LR = 3e-4

# Device: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"デバイス: {device}")

# Data: character-level samples of SEQ_LENGTH characters, reshuffled each epoch.
dataset = TextDataset(text, seq_length=SEQ_LENGTH)
dataloader = DataLoader(dataset, shuffle=True, batch_size=BATCH_SIZE)
print(f"語彙数: {dataset.vocab_size}")
print(f"バッチ数: {len(dataloader)}/エポック")

# Model: built on the CPU first, then moved to the selected device.
model = Transformer(dataset.vocab_size, EMBED_SIZE, NUM_HEADS, NUM_LAYERS)
model = model.to(device)

# Total number of scalar entries across all parameter tensors.
num_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f"パラメータ数: {num_params:,}")

In [None]:
# Training loop: next-character prediction with cross-entropy and AdamW.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

# The causal mask depends only on the sequence length, so build it once and
# rebuild it only if a batch arrives with a different length.
mask = None

for epoch in range(EPOCHS):
    # generate() calls model.eval(); without this, re-running the cell after
    # generating text would silently train with dropout disabled.
    model.train()
    total_loss = 0.0
    for batch_idx, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)
        if mask is None or mask.size(-1) != x.size(1):
            mask = Transformer.create_causal_mask(x.size(1), device)

        logits = model(x, mask)
        # Flatten (batch, seq, vocab) -> (batch * seq, vocab) so every
        # position counts as one classification example.
        loss = criterion(logits.view(-1, dataset.vocab_size), y.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if batch_idx % 100 == 0:
            # "\r" rewrites the same console line instead of scrolling.
            print(f"\rEpoch {epoch+1}: {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}", end="")

    avg_loss = total_loss / len(dataloader)
    print(f"\rEpoch {epoch+1}/{EPOCHS}, Loss: {avg_loss:.4f}")

## 6. 文章生成

In [None]:
def generate(model, dataset, start_text="The ", length=500, temperature=0.8, max_context=512):
    """Sample ``length`` characters from ``model``, continuing ``start_text``.

    Args:
        model: Trained ``Transformer``; called as ``model(tokens, mask)``.
        dataset: Object providing ``encode`` and ``idx_to_char`` (for example
            the ``TextDataset`` the model was trained on).
        start_text: Non-empty prompt made of characters from the vocabulary.
        length: Number of characters to generate.
        temperature: Softmax temperature. Lower values make sampling more
            conservative; 0 selects the most likely character every step.
        max_context: Maximum number of most recent tokens fed to the model.
            It must not exceed the ``max_len`` the model was built with
            (512 by default).

    Returns:
        ``start_text`` followed by the generated characters.

    Raises:
        ValueError: If ``start_text`` is empty, ``temperature`` is negative
            or ``max_context`` is smaller than 1.
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if max_context < 1:
        raise ValueError(f"max_context must be >= 1, got {max_context}")

    # Follow the model's own device instead of relying on a global variable.
    model_device = next(model.parameters()).device

    # Trim the prompt up front so even the first forward pass fits the context.
    tokens = dataset.encode(start_text).unsqueeze(0).to(model_device)
    tokens = tokens[:, -max_context:]
    pieces = [start_text]

    was_training = model.training
    model.eval()  # disable dropout while sampling
    try:
        with torch.no_grad():
            for _ in range(length):
                mask = Transformer.create_causal_mask(tokens.size(1), model_device)
                logits = model(tokens, mask)
                next_logits = logits[0, -1, :]  # prediction for the last position
                if temperature == 0:
                    # Greedy decoding; dividing by zero would produce NaNs.
                    next_token = torch.argmax(next_logits, dim=0, keepdim=True)
                else:
                    probs = torch.softmax(next_logits / temperature, dim=0)
                    next_token = torch.multinomial(probs, 1)
                pieces.append(dataset.idx_to_char[next_token.item()])
                # Append the new token and keep only the latest max_context.
                tokens = torch.cat([tokens, next_token.unsqueeze(0)], dim=1)[:, -max_context:]
    finally:
        # Restore the caller's train/eval mode so generating text between
        # training runs does not leave dropout switched off.
        model.train(was_training)

    return "".join(pieces)


# Generation smoke test: continue a speaker tag taken from the corpus.
first_citizen_sample = generate(model, dataset, "First Citizen:", length=500)
print(first_citizen_sample)

In [None]:
# Try a different starting prompt.
romeo_sample = generate(model, dataset, "ROMEO:", length=500)
print(romeo_sample)