# Transformer Reference Notebook

이 노트북은 기본 Transformer 구조를 단계별로 다시 정리한 참고용 버전입니다. 기존 실습 노트북과 나란히 비교하면서 구성 요소와 학습 루프를 살펴볼 수 있습니다.

## 진행 순서
1. 필수 라이브러리와 설정 값 정의
2. SentencePiece 기반 토크나이저 & 데이터 파이프라인
3. 마스킹, 포지셔널 인코딩, 어텐션 블록
4. Encoder / Decoder / 전체 Transformer 구성
5. 학습 루프와 생성(Greedy Decoding) 예시

In [26]:
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple

import pandas as pd
import sentencepiece as spm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

SEED = 42
torch.manual_seed(SEED)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


### 설정 값 (Config)
필요한 경로 및 하이퍼파라미터를 한 곳에서 관리합니다. SentencePiece 학습 시 사용할 텍스트 파일도 미리 지정합니다.

In [29]:
@dataclass
class Config:
    dataset_csv: str = "/mnt/nas/jayden_code/Codeit_Practice/Part3_01_자연어처리_실습/train.csv"
    sp_input_txt: str = "/mnt/nas/jayden_code/Codeit_Practice/Part3_02_TransFormer_실습/train.txt"
    sp_model_prefix: str = "Part3_02_TransFormer_실습/model/spm_reference"
    vocab_size: int = 2000
    character_coverage: float = 0.9995
    max_length: int = 40
    num_layers: int = 2
    d_model: int = 128
    num_heads: int = 4
    d_ff: int = 256
    dropout: float = 0.1
    pad_id: int = 3
    pad_piece: str = "<pad>"
    bos_id: int = 1
    eos_id: int = 2

CFG = Config()
Path(CFG.sp_model_prefix).parent.mkdir(parents=True, exist_ok=True)
print(CFG)

Config(dataset_csv='/mnt/nas/jayden_code/Codeit_Practice/Part3_01_자연어처리_실습/train.csv', sp_input_txt='/mnt/nas/jayden_code/Codeit_Practice/Part3_02_TransFormer_실습/train.txt', sp_model_prefix='Part3_02_TransFormer_실습/model/spm_reference', vocab_size=2000, character_coverage=0.9995, max_length=40, num_layers=2, d_model=128, num_heads=4, d_ff=256, dropout=0.1, pad_id=3, pad_piece='<pad>', bos_id=1, eos_id=2)


## SentencePiece & 데이터 로딩
SentencePiece 모델이 없다면 `train_sentencepiece`를 먼저 실행합니다. 이미 학습된 모델이 있다면 바로 불러오면 됩니다.

In [31]:
def train_sentencepiece(cfg: Config):
    sp_input = Path(cfg.sp_input_txt)
    if not sp_input.exists():
        df = pd.read_csv(cfg.dataset_csv)
        with sp_input.open('w', encoding='utf-8') as f:
            for text in df['HS01']:
                f.write(str(text).strip() + '\n')
    spm.SentencePieceTrainer.Train(
        input=str(sp_input),
        model_prefix=cfg.sp_model_prefix,
        vocab_size=cfg.vocab_size,
        character_coverage=cfg.character_coverage,
        model_type='unigram',
        bos_id=cfg.bos_id,
        eos_id=cfg.eos_id,
        pad_id=cfg.pad_id,
        pad_piece=cfg.pad_piece
    )


def load_sentencepiece(cfg: Config):
    model_path = Path(cfg.sp_model_prefix + '.model')
    if not model_path.exists():
        print('SentencePiece 모델이 없어 새로 학습합니다...')
        train_sentencepiece(cfg)
    sp = spm.SentencePieceProcessor()
    sp.load(str(model_path))
    return sp


sp = load_sentencepiece(CFG)
print('vocab size:', sp.get_piece_size())

vocab size: 2000


### 토큰화 & 데이터셋
문장을 `[BOS] + 토큰 + [EOS]` 형태로 만들고, 최대 길이에 맞게 패딩합니다.

In [32]:
def encode_sentence(sp, text: str, cfg: Config) -> torch.Tensor:
    tokens = sp.encode(text, out_type=int)
    tokens = tokens[: cfg.max_length - 2]
    tokens = [cfg.bos_id] + tokens + [cfg.eos_id]
    pad_len = cfg.max_length - len(tokens)
    if pad_len > 0:
        tokens.extend([cfg.pad_id] * pad_len)
    else:
        tokens = tokens[: cfg.max_length]
    return torch.tensor(tokens, dtype=torch.long)


class CounselingDataset(Dataset):
    def __init__(self, cfg: Config, sp):
        df = pd.read_csv(cfg.dataset_csv)
        self.pairs = list(zip(df['HS01'].astype(str), df['SS01'].astype(str)))
        self.sp = sp
        self.cfg = cfg

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        encoder_input = encode_sentence(self.sp, src, self.cfg)
        decoder_input = encode_sentence(self.sp, tgt, self.cfg)
        decoder_target = torch.roll(decoder_input, shifts=-1, dims=0)
        decoder_target[-1] = self.cfg.pad_id
        return encoder_input, decoder_input, decoder_target


dataset = CounselingDataset(CFG, sp)
print('dataset size:', len(dataset))

dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
encoder_sample, decoder_sample, target_sample = next(iter(dataloader))
print(encoder_sample.shape, decoder_sample.shape, target_sample.shape)

dataset size: 51628
torch.Size([8, 40]) torch.Size([8, 40]) torch.Size([8, 40])


## 마스킹 & 포지셔널 인코딩

In [33]:
def create_padding_mask(seq: torch.Tensor, pad_id: int) -> torch.Tensor:
    return (seq == pad_id).unsqueeze(1).unsqueeze(2)


def create_look_ahead_mask(size: int, device=None) -> torch.Tensor:
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    mask = torch.triu(torch.ones((size, size), device=device), diagonal=1).bool()
    return mask


def combine_masks(tgt: torch.Tensor, pad_mask: torch.Tensor) -> torch.Tensor:
    seq_len = tgt.size(1)
    look_ahead = create_look_ahead_mask(seq_len, tgt.device)
    look_ahead = look_ahead.unsqueeze(0).unsqueeze(0)
    return pad_mask | look_ahead

In [34]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:, : x.size(1)]

## 어텐션 블록

In [35]:
def scaled_dot_product_attention(q, k, v, mask=None):
    dk = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dk)
    if mask is not None:
        scores = scores.masked_fill(mask == 1, float('-inf'))
    attn = torch.softmax(scores, dim=-1)
    output = torch.matmul(attn, v)
    return output, attn

In [36]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.depth = d_model // num_heads

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.dense = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        bsz, seq_len, d_model = x.shape
        x = x.view(bsz, seq_len, self.num_heads, self.depth)
        return x.transpose(1, 2)

    def forward(self, v, k, q, mask):
        q = self.split_heads(self.wq(q))
        k = self.split_heads(self.wk(k))
        v = self.split_heads(self.wv(v))

        if mask is not None and mask.dim() == 3:
            mask_ = mask.unsqueeze(1)
        else:
            mask_ = mask

        scaled_attention, attn = scaled_dot_product_attention(q, k, v, mask_)
        scaled_attention = scaled_attention.transpose(1, 2).contiguous()
        concat_attention = scaled_attention.view(scaled_attention.size(0), -1, self.num_heads * self.depth)
        return self.dense(concat_attention), attn

In [37]:
class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.dropout(x)
        return self.linear2(x)

## Encoder / Decoder 레이어

In [38]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_out, _ = self.mha(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_out))
        ffn_out = self.ffn(x)
        return self.norm2(x + self.dropout2(ffn_out))

In [39]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, look_ahead_mask, padding_mask):
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
        x = self.norm1(x + self.dropout1(attn1))

        attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, x, padding_mask)
        x = self.norm2(x + self.dropout2(attn2))

        ffn_out = self.ffn(x)
        x = self.norm3(x + self.dropout3(ffn_out))
        return x, attn_weights_block1, attn_weights_block2

## Encoder / Decoder / Transformer

In [40]:
class Encoder(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.embedding = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_encoding = PositionalEncoding(cfg.d_model)
        self.layers = nn.ModuleList([
            EncoderLayer(cfg.d_model, cfg.num_heads, cfg.d_ff, cfg.dropout)
            for _ in range(cfg.num_layers)
        ])
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x, mask):
        x = self.embedding(x) * math.sqrt(x.size(-1))
        x = self.pos_encoding(x)
        x = self.dropout(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x

In [41]:
class Decoder(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.embedding = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_encoding = PositionalEncoding(cfg.d_model)
        self.layers = nn.ModuleList([
            DecoderLayer(cfg.d_model, cfg.num_heads, cfg.d_ff, cfg.dropout)
            for _ in range(cfg.num_layers)
        ])
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x, enc_output, look_ahead_mask, padding_mask):
        x = self.embedding(x) * math.sqrt(x.size(-1))
        x = self.pos_encoding(x)
        x = self.dropout(x)
        attn_weights = {}
        for idx, layer in enumerate(self.layers):
            x, block1, block2 = layer(x, enc_output, look_ahead_mask, padding_mask)
            attn_weights[f'decoder_layer{idx+1}_block1'] = block1
            attn_weights[f'decoder_layer{idx+1}_block2'] = block2
        return x, attn_weights

In [42]:
class Transformer(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.encoder = Encoder(cfg)
        self.decoder = Decoder(cfg)
        self.final_layer = nn.Linear(cfg.d_model, cfg.vocab_size)

    def forward(self, enc_inp, dec_inp, enc_padding_mask, look_ahead_mask, dec_padding_mask):
        enc_output = self.encoder(enc_inp, enc_padding_mask)
        dec_output, attention_weights = self.decoder(dec_inp, enc_output, look_ahead_mask, dec_padding_mask)
        logits = self.final_layer(dec_output)
        return logits, attention_weights

    @torch.no_grad()
    def greedy_decode(self, src, max_len):
        self.eval()
        enc_padding_mask = create_padding_mask(src, self.cfg.pad_id)
        enc_output = self.encoder(src, enc_padding_mask)

        dec_input = torch.full((src.size(0), 1), self.cfg.bos_id, dtype=torch.long, device=src.device)
        for _ in range(max_len - 1):
            dec_padding_mask = create_padding_mask(dec_input, self.cfg.pad_id)
            combined_mask = combine_masks(dec_input, dec_padding_mask)
            predictions, _ = self.decoder(dec_input, enc_output, combined_mask, enc_padding_mask)
            logits = self.final_layer(predictions[:, -1:, :])
            next_token = torch.argmax(logits, dim=-1)
            dec_input = torch.cat([dec_input, next_token], dim=1)
            if (next_token == self.cfg.eos_id).all():
                break
        return dec_input

## 학습 루프 Helper

In [43]:
def build_model(cfg: Config) -> Transformer:
    model = Transformer(cfg)
    return model.to(device)


def loss_function(predictions, targets, pad_id):
    predictions = predictions.view(-1, predictions.size(-1))
    targets = targets.view(-1)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
    return criterion(predictions, targets)


def train_one_epoch(model, dataloader, optimizer, cfg: Config):
    model.train()
    total_loss = 0
    for enc_inp, dec_inp, dec_target in dataloader:
        enc_inp, dec_inp, dec_target = enc_inp.to(device), dec_inp.to(device), dec_target.to(device)
        enc_padding_mask = create_padding_mask(enc_inp, cfg.pad_id)
        dec_padding_mask = create_padding_mask(dec_inp, cfg.pad_id)
        combined_mask = combine_masks(dec_inp, dec_padding_mask)

        logits, _ = model(enc_inp, dec_inp, enc_padding_mask, combined_mask, enc_padding_mask)
        loss = loss_function(logits, dec_target, cfg.pad_id)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

## 디코딩 예시
학습된 체크포인트가 있다면 불러온 뒤 `greedy_decode`로 간단히 결과를 볼 수 있습니다. 아래 예시는 랜덤 초기화 모델이므로 의미 있는 문장을 만들지는 않습니다.

In [44]:
model = build_model(CFG)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, betas=(0.9, 0.98), eps=1e-9)

enc_inp, dec_inp, dec_target = next(iter(dataloader))
enc_inp, dec_inp = enc_inp.to(device), dec_inp.to(device)
enc_mask = create_padding_mask(enc_inp, CFG.pad_id)
dec_mask = create_padding_mask(dec_inp, CFG.pad_id)
combined = combine_masks(dec_inp, dec_mask)
logits, attn = model(enc_inp, dec_inp, enc_mask, combined, enc_mask)
print('logits shape:', logits.shape)

sample_src = enc_inp[:1]
output_tokens = model.greedy_decode(sample_src, CFG.max_length)
print('generated token ids:', output_tokens.squeeze(0).tolist())

logits shape: torch.Size([8, 40, 2000])
generated token ids: [1, 1026, 1572, 750, 1106, 621, 1785, 1617, 1898, 22, 1728, 1975, 1208, 52, 519, 951, 1672, 923, 1254, 161, 1203, 814, 28, 1844, 1264, 77, 369, 175, 1059, 969, 1037, 1297, 351, 1036, 1078, 1718, 1413, 1782, 1161, 34]
