# 08: Transformerで様々なタスクを学習

このノートブックでは、Transformerを使って**様々なSequence-to-Sequenceタスク**を学習させます。

## タスク一覧
1. **コピータスク**: 入力をそのまま出力
2. **反転タスク**: 入力を逆順に出力
3. **ソートタスク**: 数字を昇順にソート
4. **加算タスク**: 2つの数字を足し算
5. **簡易翻訳タスク**: おもちゃの言語間翻訳

これらのタスクを通じて、Transformerがどのようにパターンを学習するかを理解します。

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from src.transformer import Transformer, count_parameters

# Device selection: try CUDA first, then MPS, and fall back to the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f"Using device: {device}")

# Special token ids shared by every task in this notebook.
PAD_IDX, START_IDX, END_IDX = 0, 1, 2

## 共通の学習関数

In [None]:
def train_model(model, data_generator, num_epochs=200, batch_size=64, lr=0.001, verbose=True):
    """Train ``model`` on freshly generated batches and record per-step metrics.

    Each "epoch" here is a single optimizer step on one new batch obtained
    from ``data_generator``.

    Args:
        model: Model called as ``model(src, tgt_input)`` to produce logits.
        data_generator: Callable taking a batch size and returning
            ``(src, tgt_input, tgt_output)``.
        num_epochs: Number of training steps to run.
        batch_size: Batch size passed to ``data_generator``.
        lr: Learning rate for the Adam optimizer.
        verbose: When true, show a tqdm progress bar with loss and accuracy.

    Returns:
        ``(losses, accuracies)``: two lists with one float per step.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Padding positions contribute nothing to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    losses, accuracies = [], []

    model.train()
    iterator = tqdm(range(num_epochs), desc="Training") if verbose else range(num_epochs)

    for _ in iterator:
        src, tgt_in, tgt_out = data_generator(batch_size)
        src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device)

        logits = model(src, tgt_in)

        # Flatten all positions so every token is one classification example.
        vocab_dim = logits.size(-1)
        loss = criterion(logits.reshape(-1, vocab_dim), tgt_out.reshape(-1))

        # Token-level accuracy measured over non-padding positions only.
        mask = tgt_out != PAD_IDX
        hits = (logits.argmax(dim=-1) == tgt_out) & mask
        correct = hits.sum().float() / mask.sum().float()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        acc_value = correct.item()
        losses.append(loss_value)
        accuracies.append(acc_value)

        if verbose:
            iterator.set_postfix(loss=f"{loss_value:.4f}", acc=f"{acc_value:.4f}")

    return losses, accuracies


def plot_training(losses, accuracies, title):
    """Draw the loss and accuracy curves side by side and show the figure.

    Args:
        losses: Sequence of loss values, one per training step.
        accuracies: Sequence of accuracy values, one per training step.
        title: Prefix used in both subplot titles.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Both panels share the same decoration; only the data and label differ.
    for axis, series, label in ((loss_ax, losses, 'Loss'), (acc_ax, accuracies, 'Accuracy')):
        axis.plot(series)
        axis.set_xlabel('Epoch')
        axis.set_ylabel(label)
        axis.set_title(f'{title} - {label}')
        axis.grid(True, alpha=0.3)

    # Fix the accuracy range so different runs are visually comparable.
    acc_ax.set_ylim(0, 1.05)

    plt.tight_layout()
    plt.show()


def test_model(model, data_generator, num_samples=5, max_len=20):
    """Greedy-decode a fresh batch and compare each output with its target.

    Args:
        model: Model exposing ``greedy_decode(src, max_len=...,
            start_token_id=..., end_token_id=...)``.
        data_generator: Callable taking a batch size and returning
            ``(src, tgt_input, tgt_output)``.
        num_samples: Number of examples requested from ``data_generator``.
        max_len: Maximum length passed to ``greedy_decode``.

    Returns:
        A list of ``(src_tokens, expected, gen_tokens, match)`` tuples, one per
        sample. Special tokens (ids <= ``END_IDX``) are removed from all three
        token lists. ``match`` is true only when the decoded tokens equal the
        expected target; echoing the source no longer counts as a match
        (for the copy task the target already equals the source).
    """
    def strip_special(ids):
        # Keep content tokens only: PAD, START and END all have ids <= END_IDX.
        return [t for t in ids if t > END_IDX]

    model.eval()
    src, _, tgt_out = data_generator(num_samples)
    src = src.to(device)

    with torch.no_grad():
        generated = model.greedy_decode(
            src,
            max_len=max_len,
            start_token_id=START_IDX,
            end_token_id=END_IDX,
        )

    results = []
    for i in range(num_samples):
        src_tokens = strip_special(src[i].tolist())
        expected = strip_special(tgt_out[i].tolist())

        gen_ids = generated[i].tolist()
        # Anything emitted after the first END is not part of the answer.
        # NOTE(review): whether greedy_decode can emit tokens after END depends
        # on its implementation; cutting here is safe either way.
        if END_IDX in gen_ids:
            gen_ids = gen_ids[:gen_ids.index(END_IDX)]
        gen_tokens = strip_special(gen_ids)

        match = expected == gen_tokens
        results.append((src_tokens, expected, gen_tokens, match))

    return results

---
## 1. コピータスク

入力シーケンスをそのまま出力する最も基本的なタスク。

```
入力:  [3, 5, 7, 2, 9]
出力:  [3, 5, 7, 2, 9]
```

In [None]:
def generate_copy_data(batch_size, seq_len=8, vocab_size=20):
    """Build one random batch for the copy task.

    Content tokens are drawn from ``[3, vocab_size)`` so they never collide
    with the PAD/START/END ids.

    Returns:
        ``(src, tgt_input, tgt_output)`` where ``src`` is the raw sequence,
        ``tgt_input`` is START followed by the sequence, and ``tgt_output`` is
        the sequence followed by END.
    """
    payload = torch.randint(3, vocab_size, (batch_size, seq_len))
    start_col = payload.new_full((batch_size, 1), START_IDX)
    end_col = payload.new_full((batch_size, 1), END_IDX)
    return (
        payload.clone(),
        torch.cat((start_col, payload), dim=1),
        torch.cat((payload, end_col), dim=1),
    )

# Show one generated example.
src, tgt_in, tgt_out = generate_copy_data(1)
print(
    "Copy Task Example:",
    f"  Source:        {src[0].tolist()}",
    f"  Target Input:  {tgt_in[0].tolist()}",
    f"  Target Output: {tgt_out[0].tolist()}",
    sep="\n",
)

In [None]:
# Model for the copy task (vocabulary size matches generate_copy_data's default).
copy_model = Transformer(
    src_vocab_size=20,
    tgt_vocab_size=20,
    d_model=64,
    num_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    d_ff=256,
).to(device)

print(f"Parameters: {count_parameters(copy_model):,}")

# Train, then plot the loss/accuracy curves.
losses, accs = train_model(copy_model, generate_copy_data, num_epochs=150)
plot_training(losses, accs, "Copy Task")

In [None]:
# Decode a few fresh samples and print them.
# The expected output is omitted because it equals the input for this task.
print("Copy Task Test Results:")
print("=" * 50)
results = test_model(copy_model, generate_copy_data)
for src_t, exp_t, gen_t, match in results:
    if match:
        mark = "✓"
    else:
        mark = "✗"
    print(f"{mark} Input:    {src_t}", f"  Output:   {gen_t}", "", sep="\n")

---
## 2. 反転タスク

入力シーケンスを逆順に出力するタスク。位置情報の理解が必要。

```
入力:  [3, 5, 7, 2, 9]
出力:  [9, 2, 7, 5, 3]
```

In [None]:
def generate_reverse_data(batch_size, seq_len=8, vocab_size=20):
    """Build one random batch for the sequence-reversal task.

    Content tokens are drawn from ``[3, vocab_size)``. The target is the source
    flipped along the sequence dimension.

    Returns:
        ``(src, tgt_input, tgt_output)`` where ``tgt_input`` is START followed
        by the reversed sequence and ``tgt_output`` is the reversed sequence
        followed by END.
    """
    forward = torch.randint(3, vocab_size, (batch_size, seq_len))
    backward = forward.flip(dims=[1])
    start_col = forward.new_full((batch_size, 1), START_IDX)
    end_col = forward.new_full((batch_size, 1), END_IDX)
    return (
        forward.clone(),
        torch.cat((start_col, backward), dim=1),
        torch.cat((backward, end_col), dim=1),
    )

# Show one generated example.
src, tgt_in, tgt_out = generate_reverse_data(1)
print(
    "Reverse Task Example:",
    f"  Source:        {src[0].tolist()}",
    f"  Target Input:  {tgt_in[0].tolist()}",
    f"  Target Output: {tgt_out[0].tolist()}",
    sep="\n",
)

In [None]:
# Model for the reversal task (same architecture as the copy model).
reverse_model = Transformer(
    src_vocab_size=20, tgt_vocab_size=20,
    d_model=64, num_heads=4,
    num_encoder_layers=2, num_decoder_layers=2,
    d_ff=256,
).to(device)

# Report the model size, as every other task's model cell does.
print(f"Parameters: {count_parameters(reverse_model):,}")

# Train, then plot the loss/accuracy curves.
losses, accs = train_model(reverse_model, generate_reverse_data, num_epochs=200)
plot_training(losses, accs, "Reverse Task")

In [None]:
# Decode a few fresh samples and print input, expected and actual output.
print("Reverse Task Test Results:")
print("=" * 50)
results = test_model(reverse_model, generate_reverse_data)
for src_t, exp_t, gen_t, match in results:
    if match:
        mark = "✓"
    else:
        mark = "✗"
    print(
        f"{mark} Input:    {src_t}",
        f"  Expected: {exp_t}",
        f"  Output:   {gen_t}",
        "",
        sep="\n",
    )

---
## 3. ソートタスク

入力の数字を昇順にソートして出力するタスク。より複雑なパターン認識が必要。

```
入力:  [7, 3, 9, 1, 5]
出力:  [1, 3, 5, 7, 9]
```

In [None]:
def generate_sort_data(batch_size, seq_len=6, vocab_size=15):
    """Build one random batch for the ascending-sort task.

    Each row holds distinct token ids: a random permutation of
    ``vocab_size - 3`` values is truncated to ``seq_len`` and shifted by 3 so
    the ids stay clear of PAD/START/END.

    Returns:
        ``(src, tgt_input, tgt_output)`` where ``tgt_input`` is START followed
        by the sorted row and ``tgt_output`` is the sorted row followed by END.
    """
    rows = [torch.randperm(vocab_size - 3)[:seq_len] + 3 for _ in range(batch_size)]
    unsorted = torch.stack(rows)
    ascending = unsorted.sort(dim=1).values

    start_col = torch.full((batch_size, 1), START_IDX)
    end_col = torch.full((batch_size, 1), END_IDX)
    return (
        unsorted.clone(),
        torch.cat([start_col, ascending], dim=1),
        torch.cat([ascending, end_col], dim=1),
    )

# Show one generated example.
src, tgt_in, tgt_out = generate_sort_data(1)
print(
    "Sort Task Example:",
    f"  Source:        {src[0].tolist()}",
    f"  Target Input:  {tgt_in[0].tolist()}",
    f"  Target Output: {tgt_out[0].tolist()}",
    sep="\n",
)

In [None]:
# Model for the sort task: wider and deeper than the copy/reverse models.
sort_model = Transformer(
    src_vocab_size=15,
    tgt_vocab_size=15,
    d_model=128,
    num_heads=4,
    num_encoder_layers=3,
    num_decoder_layers=3,
    d_ff=512,
).to(device)

print(f"Parameters: {count_parameters(sort_model):,}")

# Sorting is harder, so train for more steps at a lower learning rate.
losses, accs = train_model(
    sort_model,
    generate_sort_data,
    num_epochs=500,
    lr=0.0005,
)
plot_training(losses, accs, "Sort Task")

In [None]:
# Decode a few fresh samples and print input, expected and actual output.
print("Sort Task Test Results:")
print("=" * 50)
results = test_model(sort_model, generate_sort_data)
for src_t, exp_t, gen_t, match in results:
    if match:
        mark = "✓"
    else:
        mark = "✗"
    print(
        f"{mark} Input:    {src_t}",
        f"  Expected: {exp_t}",
        f"  Output:   {gen_t}",
        "",
        sep="\n",
    )

---
## 4. 加算タスク

2つの数字を足し算するタスク。数字の各桁をトークンとして扱います。下の例は概念図で、実際のトークンIDは「数字 + 4」（`+` 記号は 3）となり、出力は桁上がりに備えて入力より1桁多い固定長にゼロ埋めされます（例: 579 → 0579）。

```
入力:  [1, 2, 3, +, 4, 5, 6]  (123 + 456)
出力:  [5, 7, 9]              (579)
```

In [None]:
# Token ids used by the addition task:
#   0: PAD, 1: START, 2: END, 3: plus sign, 4-13: digits 0-9
PLUS_IDX = 3
DIGIT_OFFSET = 4  # digit d is encoded as d + 4, so 0 -> 4 and 9 -> 13


def num_to_tokens(n, min_digits=1):
    """Encode a non-negative integer as a list of digit token ids.

    The decimal representation is left-padded with zeros up to ``min_digits``
    digits, then every digit is shifted by ``DIGIT_OFFSET``.
    """
    zero_padded = str(n).zfill(min_digits)
    return [int(ch) + DIGIT_OFFSET for ch in zero_padded]


def tokens_to_num(tokens):
    """Decode digit token ids back into an integer.

    Ids outside the digit range are ignored. Returns 0 when no digit token is
    present.
    """
    lowest, highest = DIGIT_OFFSET, DIGIT_OFFSET + 9
    digit_text = ''.join(str(t - lowest) for t in tokens if lowest <= t <= highest)
    return int(digit_text) if digit_text else 0

def generate_addition_data(batch_size, max_digits=3, vocab_size=14):
    """Build one random batch of ``a + b`` problems for the addition task.

    Both operands are drawn from ``[0, 10**max_digits - 1]`` and zero-padded to
    ``max_digits`` digits. The sum is zero-padded to ``max_digits + 1`` digits
    so a carry always fits.

    Args:
        batch_size: Number of problems to generate.
        max_digits: Maximum number of digits per operand.
        vocab_size: Not used by this function.

    Returns:
        ``(src, tgt_in, tgt_out)`` where each ``src`` row is
        ``digits(a) + [PLUS_IDX] + digits(b)``, ``tgt_in`` is START followed by
        the sum digits, and ``tgt_out`` is the sum digits followed by END.
    """
    operand_limit = 10 ** max_digits  # exclusive upper bound for a and b

    sources, decoder_inputs, decoder_targets = [], [], []
    for _ in range(batch_size):
        a = torch.randint(0, operand_limit, (1,)).item()
        b = torch.randint(0, operand_limit, (1,)).item()

        question = num_to_tokens(a, max_digits) + [PLUS_IDX] + num_to_tokens(b, max_digits)
        # One extra digit leaves room for a carry out of the top position.
        answer = num_to_tokens(a + b, max_digits + 1)

        sources.append(torch.tensor(question))
        decoder_inputs.append(torch.tensor([START_IDX] + answer))
        decoder_targets.append(torch.tensor(answer + [END_IDX]))

    # Stack the per-sample tensors into batch-first tensors, padding with PAD_IDX.
    src, tgt_in, tgt_out = (
        nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=PAD_IDX)
        for seqs in (sources, decoder_inputs, decoder_targets)
    )
    return src, tgt_in, tgt_out

# Show a few generated examples, decoded back to numbers.
src, tgt_in, tgt_out = generate_addition_data(3, max_digits=2)
print("Addition Task Examples:")
for i in range(3):
    src_t = src[i].tolist()
    # The plus token separates the two operands.
    plus_pos = src_t.index(PLUS_IDX)
    a = tokens_to_num(src_t[:plus_pos])
    b = tokens_to_num(src_t[plus_pos + 1:])
    result_tokens = [t for t in tgt_out[i].tolist() if t not in (END_IDX, PAD_IDX)]
    result = tokens_to_num(result_tokens)
    print(f"  {a} + {b} = {result}")
    print(f"    Tokens: {src_t} -> {result_tokens}")

In [None]:
# Model for the addition task (14 ids: PAD, START, END, plus, and ten digits).
add_model = Transformer(
    src_vocab_size=14,
    tgt_vocab_size=14,
    d_model=128,
    num_heads=4,
    num_encoder_layers=3,
    num_decoder_layers=3,
    d_ff=512,
).to(device)

print(f"Parameters: {count_parameters(add_model):,}")


def gen_add_2digit(batch_size):
    """Generate addition batches restricted to 2-digit operands."""
    return generate_addition_data(batch_size, max_digits=2)


# Train on 2-digit addition, then plot the loss/accuracy curves.
losses, accs = train_model(
    add_model,
    gen_add_2digit,
    num_epochs=500,
    lr=0.0005,
)
plot_training(losses, accs, "Addition Task (2 digits)")

In [None]:
# Evaluate the trained addition model on freshly sampled 2-digit problems.
print("Addition Task Test Results:")
print("=" * 50)

# Single source of truth for the batch size, the loop bound and the denominator.
num_tests = 8

add_model.eval()
src, tgt_in, tgt_out = generate_addition_data(num_tests, max_digits=2)
src = src.to(device)

with torch.no_grad():
    generated = add_model.greedy_decode(src, max_len=6, start_token_id=START_IDX, end_token_id=END_IDX)

correct_count = 0
for i in range(num_tests):
    src_t = src[i].cpu().tolist()
    # Recover the operands from the source tokens on either side of the plus sign.
    plus_pos = src_t.index(PLUS_IDX) if PLUS_IDX in src_t else len(src_t)//2
    a = tokens_to_num(src_t[:plus_pos])
    b = tokens_to_num(src_t[plus_pos+1:])
    expected = a + b

    gen_ids = generated[i].tolist()
    # Anything emitted after the first END is not part of the answer.
    # NOTE(review): whether greedy_decode can emit tokens after END depends on
    # its implementation; cutting here is safe either way.
    if END_IDX in gen_ids:
        gen_ids = gen_ids[:gen_ids.index(END_IDX)]
    gen_tokens = [t for t in gen_ids if t not in (START_IDX, END_IDX, PAD_IDX)]
    predicted = tokens_to_num(gen_tokens)

    match = expected == predicted
    if match:
        correct_count += 1
    mark = "✓" if match else "✗"
    print(f"{mark} {a:2d} + {b:2d} = {predicted:3d}  (expected: {expected})")

print(f"\nAccuracy: {correct_count}/{num_tests} = {correct_count/num_tests:.1%}")

---
## 5. 簡易翻訳タスク

おもちゃの「言語」間での翻訳タスク。

**言語A（数字）→ 言語B（アルファベット相当のトークン）**

```
ルール例:
  3 -> 13 (3+10)
  5 -> 15 (5+10)
  複数トークンの場合は各要素を変換

入力:  [3, 5, 7]
出力:  [13, 15, 17]
```

In [None]:
def generate_translation_data(batch_size, seq_len=6, src_vocab=15, offset=10):
    """Build one random batch for the toy translation task.

    The "translation" rule adds ``offset`` to every source token. Source tokens
    are drawn from ``[3, src_vocab)``.

    Returns:
        ``(src, tgt_input, tgt_output)`` where ``tgt_input`` is START followed
        by the translated sequence and ``tgt_output`` is the translated
        sequence followed by END.
    """
    source = torch.randint(3, src_vocab, (batch_size, seq_len))
    target = source + offset
    start_col = source.new_full((batch_size, 1), START_IDX)
    end_col = source.new_full((batch_size, 1), END_IDX)
    return (
        source.clone(),
        torch.cat((start_col, target), dim=1),
        torch.cat((target, end_col), dim=1),
    )

# Show one generated example.
src, tgt_in, tgt_out = generate_translation_data(1)
print(
    "Translation Task Example:",
    f"  Source (Lang A):  {src[0].tolist()}",
    f"  Target (Lang B):  {[t for t in tgt_out[0].tolist() if t != END_IDX]}",
    f"  Rule: each token + 10",
    sep="\n",
)

In [None]:
# Model for the translation task; source and target vocabularies differ in size.
trans_model = Transformer(
    src_vocab_size=15,   # language A uses ids 3-14
    tgt_vocab_size=25,   # language B uses ids 13-24
    d_model=64,
    num_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    d_ff=256,
).to(device)

print(f"Parameters: {count_parameters(trans_model):,}")

# Train, then plot the loss/accuracy curves.
losses, accs = train_model(trans_model, generate_translation_data, num_epochs=150)
plot_training(losses, accs, "Translation Task")

In [None]:
# Decode a few fresh samples and print source, expected and actual output.
print("Translation Task Test Results:")
print("=" * 50)

trans_model.eval()
src, tgt_in, tgt_out = generate_translation_data(5)
src_device = src.to(device)

with torch.no_grad():
    generated = trans_model.greedy_decode(
        src_device,
        max_len=10,
        start_token_id=START_IDX,
        end_token_id=END_IDX,
    )

for i in range(5):
    src_t = src[i].tolist()
    expected = [t for t in tgt_out[i].tolist() if t != END_IDX]
    gen_t = [t for t in generated[i].tolist() if t not in (START_IDX, END_IDX, PAD_IDX)]

    match = expected == gen_t
    if match:
        mark = "✓"
    else:
        mark = "✗"
    print(
        f"{mark} Source:   {src_t}",
        f"  Expected: {expected}",
        f"  Output:   {gen_t}",
        "",
        sep="\n",
    )

---
## タスク難易度の比較

In [None]:
# Compare the test accuracy of every task.
# Every entry starts at 0.0 and is filled in from real decoding runs below.
# (It is deliberately not seeded from `accs`: that variable holds the training
# history of whichever model was trained last, not of the copy model.)
task_results = {
    'Copy': 0.0,
    'Reverse': 0.0,
    'Sort': 0.0,
    'Addition': 0.0,
    'Translation': 0.0,
}


def quick_test(model, data_gen, n=50):
    """Return the fraction of ``n`` single-sample decodes that ``test_model`` marks as a match."""
    model.eval()
    correct = 0
    for _ in range(n):
        results = test_model(model, data_gen, num_samples=1)
        if results[0][3]:  # the fourth tuple field is the match flag
            correct += 1
    return correct / n


print("Testing each model...")
task_results['Copy'] = quick_test(copy_model, generate_copy_data)
task_results['Reverse'] = quick_test(reverse_model, generate_reverse_data)
task_results['Sort'] = quick_test(sort_model, generate_sort_data)
task_results['Translation'] = quick_test(trans_model, generate_translation_data)

# Addition is scored numerically (decoded sum vs. a + b) instead of token by token.
num_trials = 50  # same sample count as quick_test's default
add_model.eval()
correct = 0
for _ in range(num_trials):
    src, tgt_in, tgt_out = generate_addition_data(1, max_digits=2)
    with torch.no_grad():
        gen = add_model.greedy_decode(src.to(device), max_len=6, start_token_id=START_IDX, end_token_id=END_IDX)
    src_t = src[0].tolist()
    plus_pos = src_t.index(PLUS_IDX) if PLUS_IDX in src_t else len(src_t)//2
    a = tokens_to_num(src_t[:plus_pos])
    b = tokens_to_num(src_t[plus_pos+1:])
    expected = a + b

    gen_ids = gen[0].tolist()
    # Anything emitted after the first END is not part of the answer.
    # NOTE(review): whether greedy_decode can emit tokens after END depends on
    # its implementation; cutting here is safe either way.
    if END_IDX in gen_ids:
        gen_ids = gen_ids[:gen_ids.index(END_IDX)]
    predicted = tokens_to_num([t for t in gen_ids if t not in (START_IDX, END_IDX, PAD_IDX)])
    if expected == predicted:
        correct += 1
task_results['Addition'] = correct / num_trials

print("\nTask Difficulty Comparison (Test Accuracy):")
print("=" * 40)
# List tasks from highest to lowest accuracy with a simple text bar.
for task, acc in sorted(task_results.items(), key=lambda x: -x[1]):
    bar = "█" * int(acc * 20)
    print(f"{task:12s}: {acc:5.1%} {bar}")

In [None]:
# Visualize the per-task test accuracy as a bar chart.
fig, ax = plt.subplots(figsize=(10, 5))

tasks = list(task_results.keys())
# Use a dedicated name so the notebook-global `accs` (training-accuracy history
# returned by train_model) is not overwritten by the bar values.
task_accs = list(task_results.values())
colors = plt.cm.viridis(np.linspace(0.2, 0.8, len(tasks)))

bars = ax.bar(tasks, task_accs, color=colors)
ax.set_ylabel('Test Accuracy')
ax.set_title('Task Difficulty Comparison')
ax.set_ylim(0, 1.1)
# Dashed reference line at 100% accuracy.
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)

# Print each value just above its bar.
for bar, task_acc in zip(bars, task_accs):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
            f'{task_acc:.0%}', ha='center', va='bottom', fontsize=11)

plt.tight_layout()
plt.show()

## まとめ

### タスクごとの特徴

| タスク | 難易度 | 必要な能力 |
|--------|--------|------------|
| **Copy** | 簡単 | 入力をそのまま出力 |
| **Translation** | 簡単 | 単純なルール変換 |
| **Reverse** | 中程度 | 位置情報の理解と逆順生成 |
| **Sort** | 難しい | 全体の比較と順序付け |
| **Addition** | 難しい | 桁上がりの概念、算術演算 |

### 学習のポイント

1. **Transformerの汎用性**: 同じアーキテクチャで多様なタスクを学習可能
2. **タスクの複雑さ**: 単純なパターン変換より、論理的推論が必要なタスクは難しい
3. **学習時間**: 複雑なタスクほど多くのエポックが必要
4. **モデルサイズ**: タスクの複雑さに応じてモデルサイズを調整

### 次のステップ
- 実際の自然言語データ（翻訳、要約など）での学習
- より長いシーケンスへの対応
- Beam Searchによる生成品質の向上