In [1]:
import torch
import torch.nn.functional as F
import time

In [2]:
def attention(X, W_q, W_k, W_v):
    # X: [b, n, d_model]
    # W_q, W_k, W_v: [d_model, d]
    Q = X @ W_q              # [b, n, d]
    K = X @ W_k              # [b, n, d]
    V = X @ W_v              # [b, n, d]
    s = 1 / torch.sqrt(torch.tensor(Q.shape[-1], dtype=torch.float32))
    A = (Q @ K.transpose(-2, -1)) * s   # [b, n, n]
    return F.softmax(A, dim=-1) @ V     # [b, n, d]

In [3]:
def diff_attn(X, W_q, W_k, W_v, lam):
    # X: [b, n, d_model]
    # W_q, W_k: [d_model, 2*d] -> split 후 각각 [b, n, d]
    # W_v: [d_model, 2*d] -> 출력 [b, n, 2*d]
    Q = X @ W_q  # [b, n, 2*d]
    K = X @ W_k  # [b, n, 2*d]
    
    # split: 마지막 차원을 절반으로 분할
    d = Q.shape[-1] // 2
    Q1, Q2 = Q[..., :d], Q[..., d:]
    K1, K2 = K[..., :d], K[..., d:]
    V = X @ W_v             # [b, n, 2*d]
    
    s = 1 / torch.sqrt(torch.tensor(d, dtype=torch.float32))
    A1 = (Q1 @ K1.transpose(-2, -1)) * s  # [b, n, n]
    A2 = (Q2 @ K2.transpose(-2, -1)) * s  # [b, n, n]
    
    # 두 attention score의 softmax를 계산하고 λ 배 조합한 후 V와 곱함
    return (F.softmax(A1, dim=-1) - lam * F.softmax(A2, dim=-1)) @ V  # [b, n, 2*d]

In [8]:
def model_init():
    # 파라미터 설정
    batch_size = 2
    n = 10
    d_model = 32
    d = 16  # standard attention의 출력 차원 (W_q, W_k, W_v: [d_model, d])
    
    return batch_size, n, d_model, d

In [9]:
# 테스트 코드
def test_transformer():
    torch.manual_seed(0)
    batch_size, n, d_model, d = model_init()
    
    # 일반 Transformer용 가중치: [d_model, d]
    W_q_std = torch.randn(d_model, d)
    W_k_std = torch.randn(d_model, d)
    W_v_std = torch.randn(d_model, d)
    
    # 입력 X: [batch_size, n, d_model]
    X = torch.randn(batch_size, n, d_model)
    
    # 일반 Attention 테스트
    start = time.time()
    out_std = attention(X, W_q_std, W_k_std, W_v_std)
    elapsed_std = time.time() - start
    print("일반 Attention 출력 shape:", out_std.shape)
    print("일반 Attention 실행 시간:", elapsed_std)

In [13]:
# 테스트 코드
def test_diff():
    torch.manual_seed(0)
    batch_size, n, d_model, d = model_init()
    
    # Diff Transformer용 가중치:
    # W_q, W_k: [d_model, 2*d] (출력 후 두 부분으로 분할)
    # W_v: [d_model, 2*d]
    W_q_diff = torch.randn(d_model, d)
    W_k_diff = torch.randn(d_model, d)
    W_v_diff = torch.randn(d_model, d)
    
    # 입력 X: [batch_size, n, d_model]
    X = torch.randn(batch_size, n, d_model)
    
    # Diff Attention 테스트 (λ = 0.5 예시)
    lam = 0.5
    start = time.time()
    out_diff = diff_attn(X, W_q_diff, W_k_diff, W_v_diff, lam)
    elapsed_diff = time.time() - start
    print("Diff Attention 출력 shape:", out_diff.shape)
    print("Diff Attention 실행 시간:", elapsed_diff)

In [14]:
test_transformer()

일반 Attention 출력 shape: torch.Size([2, 10, 16])
일반 Attention 실행 시간: 0.0008032321929931641


In [15]:
test_diff()

Diff Attention 출력 shape: torch.Size([2, 10, 16])
Diff Attention 실행 시간: 0.0007767677307128906


In [9]:
# 테스트 코드
def simple_test_transformer():
    torch.manual_seed(0)
    # 파라미터 설정
    batch_size = 2
    n = 10
    d_model = 32
    d = 16  # standard attention의 출력 차원 (W_q, W_k, W_v: [d_model, d])
    
    # 일반 Transformer용 가중치: [d_model, d]
    W_q_std = torch.randn(d_model, d)
    W_k_std = torch.randn(d_model, d)
    W_v_std = torch.randn(d_model, d)
    
    # 입력 X: [batch_size, n, d_model]
    X = torch.randn(batch_size, n, d_model)
    
    # 일반 Attention 테스트
    out_std = attention(X, W_q_std, W_k_std, W_v_std)

In [13]:
# 테스트 코드
def simple_test_diff():
    torch.manual_seed(0)
    # 파라미터 설정
    batch_size = 2
    n = 10
    d_model = 32
    d = 16  # standard attention의 출력 차원 (W_q, W_k, W_v: [d_model, d])
    
    # Diff Transformer용 가중치:
    # W_q, W_k: [d_model, 2*d] (출력 후 두 부분으로 분할)
    # W_v: [d_model, 2*d]
    W_q_diff = torch.randn(d_model, d)
    W_k_diff = torch.randn(d_model, d)
    W_v_diff = torch.randn(d_model, d)
    
    # 입력 X: [batch_size, n, d_model]
    X = torch.randn(batch_size, n, d_model)
    
    # Diff Attention 테스트 (λ = 0.5 예시)
    lam = 0.5
    out_diff = diff_attn(X, W_q_diff, W_k_diff, W_v_diff, lam)

In [11]:
%%timeit

simple_test_transformer()

727 μs ± 3.41 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [14]:
%%timeit

simple_test_diff()

802 μs ± 5.14 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
