In [18]:
import functools
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

# Timing Decorator Function

In [19]:
def timing_decorator(func):
    """Wrap ``func`` so that every call prints how long it took.

    Args:
        func: Any callable.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func``, prints the elapsed time in seconds (4 decimal places), and
        returns ``func``'s result unchanged. The wrapper keeps ``func``'s
        ``__name__`` and docstring. Nothing is printed if ``func`` raises.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high resolution; time.time() can jump
        # when the system clock is adjusted, which would skew the measurement.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"행렬 연산 {func.__name__} 실행 시간: {elapsed:.4f}초")
        return result
    return wrapper


# Self-Attention (Scaled Dot-Product Attention)

In [40]:
class selfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected into query, key and value features of width
    ``attention_dim``; the output is the attention-weighted sum of the values.
    """

    def __init__(self, embed_dim, attention_dim):
        super().__init__()
        # Width of the incoming features and of the projected Q/K/V features.
        self.embed_size = embed_dim
        self.attention_dim = attention_dim

        # Learned linear maps that produce Q, K and V from the same input.
        self.W_q = nn.Linear(embed_dim, attention_dim)
        self.W_k = nn.Linear(embed_dim, attention_dim)
        self.W_v = nn.Linear(embed_dim, attention_dim)

    # The decorator prints how long each forward pass takes.
    @timing_decorator
    def forward(self, x):
        """Return the attention output for ``x``.

        ``x`` has ``embed_dim`` features on its last axis; the result has
        ``attention_dim`` features on its last axis.
        """
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)

        # Query/key dot products. Only the last two axes of the keys are
        # swapped, so any leading (batch) axes are left untouched.
        raw_scores = torch.matmul(queries, keys.transpose(-2, -1))
        # Scale by sqrt(attention_dim) before normalising.
        scaled_scores = raw_scores / math.sqrt(self.attention_dim)
        # Softmax along the last axis: each row of weights sums to 1.
        weights = scaled_scores.softmax(dim=-1)

        # Mix the value vectors with the attention weights.
        return torch.matmul(weights, values)


# claude 3.7 최적화 버전

In [76]:
class selfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a fused QKV projection.

    Compared with three separate Q/K/V layers, this variant computes all three
    projections with one linear layer, supports an optional mask, applies
    dropout to the attention weights and projects the result back to
    ``embed_dim``.

    Args:
        embed_dim: Number of features on the last axis of the input.
        attention_dim: Number of features of each of Q, K and V.
        dropout_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, attention_dim, dropout_rate=0.1):
        super().__init__()
        self.embed_size = embed_dim
        self.attention_dim = attention_dim
        # Scores are divided by sqrt(attention_dim); computed once here.
        self.scale = math.sqrt(attention_dim)

        # One linear layer yields Q, K and V side by side on the last axis;
        # forward() splits the result into three equal chunks.
        self.qkv_proj = nn.Linear(embed_dim, 3 * attention_dim)

        # Maps the attended values back to the input width.
        self.out_proj = nn.Linear(attention_dim, embed_dim)

        # Dropout on the attention weights.
        self.dropout = nn.Dropout(dropout_rate)

    @timing_decorator
    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor with ``embed_dim`` features on its last axis and the
                sequence on the second-to-last axis. Leading batch axes are
                optional; no fixed rank is required.
            mask: Optional tensor broadcastable to the score matrix. Positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor with the same leading axes as ``x`` and ``embed_dim``
            features on its last axis.
        """
        # Fused projection, then split into Q, K, V along the feature axis.
        Q, K, V = self.qkv_proj(x).chunk(3, dim=-1)

        attention_score = (Q @ K.transpose(-2, -1)) / self.scale

        if mask is not None:
            # Use the most negative finite value of the score dtype instead of
            # a literal such as -1e9, which cannot be represented in float16.
            fill_value = torch.finfo(attention_score.dtype).min
            attention_score = attention_score.masked_fill(mask == 0, fill_value)

        # Normalise over the key axis, then apply dropout to the weights.
        attention_weights = self.dropout(F.softmax(attention_score, dim=-1))

        # Weighted sum of the values, projected back to embed_dim.
        output = attention_weights @ V
        return self.out_proj(output)

In [35]:
# Smoke test: build a single attention head (10000 input features, 100
# attention features) and run it on a random 2-D tensor of shape
# (10000, 10000), i.e. an input without a batch axis.
model = selfAttention(embed_dim=10000, attention_dim=100)
# Last expression of the cell, so the resulting tensor is displayed.
model(torch.randn(10000, 10000))

tensor([[ 0.0151, -0.0001, -0.0090,  ...,  0.0043, -0.0032,  0.0113],
        [ 0.0144, -0.0036, -0.0035,  ...,  0.0020,  0.0011,  0.0138],
        [ 0.0213, -0.0009, -0.0066,  ...,  0.0021, -0.0019,  0.0137],
        ...,
        [ 0.0176, -0.0011, -0.0092,  ...,  0.0004, -0.0031,  0.0173],
        [ 0.0164, -0.0026, -0.0069,  ...,  0.0043, -0.0064,  0.0152],
        [ 0.0153, -0.0026,  0.0012,  ...,  0.0024, -0.0047,  0.0076]],
       grad_fn=<MmBackward0>)

# Multi-Head Attention

In [41]:
class multiHeadAttention(nn.Module):
    """Multi-head attention built from independent ``selfAttention`` heads.

    Every head receives the full input and is constructed with
    ``attention_dim = embed_dim // head_num``; the head outputs are
    concatenated along the last axis.

    Args:
        embed_dim: Number of features on the last axis of the input.
        head_num: Number of attention heads; must be positive and divide
            ``embed_dim`` evenly.

    Raises:
        ValueError: If ``head_num`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, head_num):
        super().__init__()

        if head_num <= 0:
            raise ValueError(f"head_num must be positive, got {head_num}")
        if embed_dim % head_num != 0:
            # Floor division would silently drop features, so the per-head
            # widths would no longer add up to embed_dim.
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by head_num ({head_num})"
            )

        self.embed_dim = embed_dim
        self.attention_dim = embed_dim // head_num
        self.head_num = head_num

        # ModuleList registers every head so its parameters are tracked.
        self.heads = nn.ModuleList(
            [selfAttention(self.embed_dim, self.attention_dim) for _ in range(head_num)]
        )

    @timing_decorator
    def forward(self, x):
        """Run every head on ``x`` and concatenate the results on the last axis.

        NOTE(review): the concatenated width equals ``embed_dim`` only if each
        head returns ``attention_dim`` features - confirm against the
        ``selfAttention`` definition that is active when this runs.
        """
        return torch.cat([head(x) for head in self.heads], dim=-1)
        


In [42]:
# Smoke test: 4 heads over 10000 input features, so each head is built with
# attention_dim = 10000 // 4 = 2500. The input is a random 2-D tensor of
# shape (10000, 10000).
model = multiHeadAttention(embed_dim=10000, head_num=4)
# Last expression of the cell, so the resulting tensor is displayed; the
# timing decorator prints one line per head plus one for the whole module.
model(torch.randn(10000, 10000))

행렬 연산 forward 실행 시간: 1.7030초
행렬 연산 forward 실행 시간: 1.8105초
행렬 연산 forward 실행 시간: 1.8116초
행렬 연산 forward 실행 시간: 1.8113초
행렬 연산 forward 실행 시간: 7.1626초


tensor([[ 0.0019, -0.0148, -0.0056,  ..., -0.0120, -0.0061, -0.0060],
        [ 0.0140, -0.0135, -0.0082,  ..., -0.0125, -0.0089, -0.0052],
        [ 0.0053, -0.0112, -0.0054,  ..., -0.0130, -0.0123, -0.0062],
        ...,
        [ 0.0062, -0.0128, -0.0038,  ..., -0.0123, -0.0078, -0.0055],
        [ 0.0053, -0.0093, -0.0071,  ..., -0.0128, -0.0146, -0.0077],
        [ 0.0093, -0.0145, -0.0035,  ..., -0.0134, -0.0059, -0.0104]],
       grad_fn=<CatBackward0>)

In [44]:
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Linear.

    Args:
        embed_dim: Number of input and output features.
        hidden_dim: Width of the intermediate layer. Defaults to ``embed_dim``
            when omitted, which gives two ``embed_dim x embed_dim`` layers.
    """

    def __init__(self, embed_dim, hidden_dim=None):
        super().__init__()
        if hidden_dim is None:
            # Keep the square layout unless a hidden width is requested.
            hidden_dim = embed_dim
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x):
        """Apply both layers to the last axis of ``x``; the shape is preserved."""
        return self.fc2(F.relu(self.fc1(x)))

# Smoke test: feed-forward block over 100 features applied to a random
# tensor of shape (100, 100).
model = FeedForward(embed_dim=100)
# Last expression of the cell, so the resulting tensor is displayed.
model(torch.randn(100, 100))


tensor([[ 0.3624,  0.3854,  0.0270,  ...,  0.7548, -0.0454,  0.1277],
        [-0.0234,  0.0858, -0.2067,  ...,  0.1077,  0.1831,  0.1361],
        [ 0.1950,  0.2922, -0.0998,  ...,  0.0696,  0.0856,  0.0088],
        ...,
        [-0.0776,  0.4409,  0.2175,  ..., -0.1747, -0.0894,  0.0300],
        [ 0.0319,  0.3808, -0.3594,  ..., -0.1441,  0.4355, -0.0837],
        [ 0.1455,  0.0110,  0.0431,  ...,  0.0055,  0.1723, -0.0053]],
       grad_fn=<AddmmBackward0>)

In [72]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding, built with explicit Python loops.

    Even columns hold ``sin(pos / 10000 ** (i / embed_dim))`` and the odd
    column that follows holds the cosine of the same angle. The table is
    filled element by element, so construction cost grows with
    ``max_len * embed_dim``.

    Args:
        embed_dim: Number of features per position.
        max_len: Number of positions to precompute.
    """

    @timing_decorator
    def __init__(self, embed_dim, max_len=5000):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_len = max_len

        pe = torch.zeros(max_len, embed_dim)
        for pos in range(max_len):
            for i in range(0, embed_dim, 2):
                # A sin/cos pair shares one angle; compute it once per pair.
                angle = pos / (10000 ** (i / embed_dim))
                pe[pos, i] = math.sin(angle)
                if i + 1 < embed_dim:
                    pe[pos, i + 1] = math.cos(angle)

        # Register as a buffer so the table follows .to(device) and .cuda().
        # persistent=False keeps it out of state_dict, so saved checkpoints
        # have the same keys as when this was a plain attribute.
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        """Add the encoding for the first ``seq_len`` positions to ``x``.

        ``x`` is ``(..., seq_len, embed_dim)``. The sequence length is read
        from the second-to-last axis, so both batched ``(batch, seq, dim)`` and
        unbatched ``(seq, dim)`` inputs are sliced along the right axis.
        """
        return x + self.pe[:x.size(-2), :]
    

In [73]:
# Smoke test for the loop-based PositionalEncoding: 1000 features and the
# default max_len. The random input is square (1000 x 1000), so both of its
# axes have length embed_dim.
embed_dim = 1000
model = PositionalEncoding(embed_dim=embed_dim)
# Last expression of the cell, so the resulting tensor is displayed; the
# constructor above prints its own build time through the timing decorator.
model(torch.randn(embed_dim, embed_dim))

행렬 연산 __init__ 실행 시간: 12.9882초


tensor([[ 1.5417,  1.8187,  0.8034,  ..., -0.1625,  0.4761,  0.5411],
        [ 1.9665,  1.3288,  2.1589,  ..., -0.7545, -0.5838,  0.7737],
        [ 1.1568,  1.0134,  0.8806,  ..., -0.2044, -0.1583,  1.6796],
        ...,
        [-1.3900,  0.1110, -0.7185,  ...,  1.5557,  0.6896,  0.6653],
        [ 1.1296,  1.4143, -0.0995,  ...,  2.5482, -0.4045, -0.0067],
        [-1.0297,  2.8238, -0.1389,  ...,  1.6958,  0.6954,  0.0750]])

In [74]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding, built with vectorized tensor ops.

    Args:
        embed_dim: Number of features per position (even or odd).
        max_len: Number of positions to precompute.
    """

    @timing_decorator
    def __init__(self, embed_dim, max_len=5000):
        super().__init__()

        # Column vector of positions: shape (max_len, 1).
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # One frequency per sin/cos pair, i.e. 1 / 10000 ** (i / embed_dim)
        # for even i, computed through exp/log.
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)
        )
        # Broadcasts to (max_len, ceil(embed_dim / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, embed_dim)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_dim there is one fewer odd column than even
        # columns, so trim the angles to match; this is a no-op when
        # embed_dim is even.
        pe[:, 1::2] = torch.cos(angles[:, : embed_dim // 2])

        # Stored as a buffer with a leading axis of size 1: it is saved with
        # the module and broadcasts over the batch axis in forward().
        self.register_buffer('pe', pe.unsqueeze(0))

        # Kept for inspection and debugging.
        self.embed_dim = embed_dim

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions to ``x``.

        Expects ``x`` shaped ``(batch_size, seq_len, embed_dim)``.
        """
        return x + self.pe[:, :x.size(1), :]

In [75]:
# Smoke test for the vectorized PositionalEncoding: 1000 features and the
# default max_len, applied to a random 2-D tensor of shape (1000, 1000).
embed_dim = 1000
model = PositionalEncoding(embed_dim=embed_dim)
# Last expression of the cell, so the resulting tensor is displayed. The
# stored table has a leading axis of size 1, so the sum broadcasts to 3-D.
model(torch.randn(embed_dim, embed_dim))

행렬 연산 __init__ 실행 시간: 0.0059초


tensor([[[-0.7073,  0.8616,  0.9822,  ..., -0.7980,  1.0234,  2.5579],
         [ 0.7515,  2.0104,  1.0040,  ...,  2.6168,  0.8642,  0.1621],
         [ 0.0559,  0.9550, -0.1554,  ...,  1.8842, -0.9640,  1.5610],
         ...,
         [-1.6885, -0.9095, -1.2558,  ...,  0.6043,  0.0119,  0.6931],
         [ 1.2091,  0.1611, -0.5896,  ...,  1.2426,  0.7322,  1.8975],
         [ 0.8219,  1.6510,  1.0216,  ...,  1.0777,  1.2984, -0.1192]]])