In [17]:
import torch
from torch import nn

# 랜덤 시드 고정
torch.manual_seed(0)
B, T, d_x, d_h, L = 2, 5, 3, 4, 1
x = torch.randn(B, T, d_x)
h0 = torch.randn(1, B, d_h)
c0 = torch.randn(1, B, d_h)

In [None]:
# LSTM
lstm = nn.LSTM(input_size=d_x, hidden_size=d_h, num_layers=L,
               bias=True, batch_first=True, bidirectional=False)
# 입력: x of shape [batch, seq_len, d_x]
# 초기 hidden/cell: (h0, c0) each of shape [L, batch, d_h]
out, (hn, cn) = lstm(x, (h0, c0))
# out: [batch, seq_len, d_h];  hn, cn: [L, batch, d_h]

# GRU
gru = nn.GRU(input_size=d_x, hidden_size=d_h, num_layers=L,
             bias=True, batch_first=True, bidirectional=False)
# out: [batch, seq_len, d_h];  hn: [L, batch, d_h]


In [19]:
class MyLSTMCell(nn.Module):
    def __init__(self, d_x, d_h):
        super().__init__()
        self.W = nn.Linear(d_x + d_h, 4 * d_h, bias=False)

    def forward(self, x_t, hc_prev):
        h_prev, c_prev = hc_prev
        concat = torch.cat([x_t, h_prev], dim=-1)  # [B, d_x + d_h]
        f,i,ctilde,o = self.W(concat).chunk(4, dim=-1)
        f = torch.sigmoid(f); i = torch.sigmoid(i)
        ctilde = torch.tanh(ctilde); o = torch.sigmoid(o)
        c_t = f * c_prev + i * ctilde
        h_t = o * torch.tanh(c_t)
        return h_t, c_t

class MyLSTM(nn.Module):
    def __init__(self, d_x, d_h):
        super().__init__()
        self.cell = MyLSTMCell(d_x, d_h)

    def forward(self, x, hc0):
        # x: [B, T, d_x], hc0: ([B,d_h],[B,d_h])
        h, c = hc0
        outputs = []
        for t in range(x.size(1)):
            h, c = self.cell(x[:,t], (h, c))
            outputs.append(h.unsqueeze(1))
        return torch.cat(outputs, dim=1), (h, c)


In [20]:
mylstm = MyLSTM(d_x, d_h)

lstm = nn.LSTM(d_x, d_h, num_layers=1, bias=False, batch_first=True)
lstm.weight_ih_l0.data[:] = torch.cat([mylstm.cell.W.weight[:d_h],
                                       mylstm.cell.W.weight[d_h:2*d_h],
                                       mylstm.cell.W.weight[2*d_h:3*d_h],
                                       mylstm.cell.W.weight[3*d_h:]], dim=0)
lstm.weight_hh_l0.data[:] = torch.zeros_like(lstm.weight_hh_l0)  # myLSTMCell은 hh weight 없음

out1, (hn1, cn1) = mylstm(x, (h0.squeeze(0), c0.squeeze(0)))
out2, (hn2, cn2) = lstm(x, (h0, c0))

print("Max abs diff:", (out1 - out2).abs().max().item())

RuntimeError: The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 1.  Target sizes: [16, 3].  Tensor sizes: [16, 7]

In [1]:
import torch

def heinsen_associative_scan_log(log_coeffs, log_values):
    # log_coeffs: [B, T, D], log_values: [B, T+1, D]
    # 1) 누적합 계산 후 앞에 0 패딩
    a_star = log_coeffs.cumsum(dim=1)                              # [B, T, D]
    zero = torch.zeros_like(a_star[:, :1, :])                      # [B, 1, D]
    a_star_padded = torch.cat([zero, a_star], dim=1)               # [B, T+1, D]

    # 2) Hillis–Steele prefix-scan의 로그-누적
    log_h0_plus_b_star = (log_values - a_star_padded).logcumsumexp(dim=1)

    # 3) 최종 결과 복원
    log_h = a_star_padded + log_h0_plus_b_star
    return log_h.exp()



In [2]:
import triton
import triton.language as tl

@triton.jit
def row_scan_kernel(
    x_ptr, y_ptr, length_ptr, stride_ptr,
    BLOCK: tl.constexpr
):
    row_id = tl.program_id(0)
    length = tl.load(length_ptr)
    stride = tl.load(stride_ptr)
    offs = tl.arange(0, BLOCK)

    # 입력/출력 포인터
    base_ptr = x_ptr + row_id * stride
    mask = offs < length
    data = tl.load(base_ptr + offs, mask=mask)

    # Hillis–Steele
    step = 1
    while step < BLOCK:
        prev = tl.load(base_ptr + offs - step, mask=offs >= step)
        data = tl.where(offs >= step, data + prev, data)
        step *= 2

    tl.store(y_ptr + row_id * stride + offs, data, mask=mask)

In [3]:
def triton_scan(log_values: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    B, Lp1, D = log_values.shape
    flat = log_values.permute(0,2,1).reshape(-1, Lp1).contiguous()
    out_flat = torch.empty_like(flat)

    length_ptr = torch.tensor([Lp1], device=flat.device, dtype=torch.int32)
    stride_ptr = torch.tensor([Lp1], device=flat.device, dtype=torch.int32)

    grid = (flat.shape[0],)
    row_scan_kernel[grid](
        flat, out_flat, length_ptr, stride_ptr,
        BLOCK=block_size, num_warps=4
    )

    return out_flat.view(B, D, Lp1).permute(0,2,1).contiguous()

In [8]:
from torch.utils.benchmark import Timer

# 데이터 준비
device = 'cuda'
B, T, D = 1, 4, 8
log_coeffs = torch.randn(B, T, D, device=device).abs()
log_values = torch.randn(B, T+1, D, device=device).abs()

# Warmup
torch_one = heinsen_associative_scan_log(log_coeffs, log_values)
triton_one = triton_scan(log_values)

torch_one.shape, triton_one.shape

(torch.Size([1, 5, 8]), torch.Size([1, 5, 8]))

In [10]:
torch_one[0][0]

tensor([3.2234, 2.7486, 1.6426, 3.3325, 1.6021, 3.9480, 2.9908, 1.9364],
       device='cuda:0')

In [11]:
triton_one[0][0]

tensor([1.1704, 1.0111, 0.4963, 1.2037, 0.4713, 1.3732, 1.0955, 0.6608],
       device='cuda:0')

In [12]:
# 타이머 설정
t1 = Timer(
    stmt="heinsen_associative_scan_log(log_coeffs, log_values)",
    globals=globals()
)
t2 = Timer(
    stmt="triton_scan(log_values)",
    globals=globals()
)

# 벤치마크 (50회)
result1 = t1.timeit(1000)
result2 = t2.timeit(1000)
result1 = result1.mean
result2 = result2.mean 
print(f"Heinsen scan: {result1:.4f} ms")
print(f"Triton scan: {result2:.4f} ms")
print(f"Speedup: {result1/result2:.2f}x")

Heinsen scan: 0.0001 ms
Triton scan: 0.0001 ms
Speedup: 0.64x


In [13]:
# 큰 데이터에 대해 테스트
B, T, D = 16, 512, 256
log_coeffs = torch.randn(B, T, D, device=device).abs()
log_values = torch.randn(B, T+1, D, device=device).abs()

# Warmup
_ = heinsen_associative_scan_log(log_coeffs, log_values)
_ = triton_scan(log_values)

# 타이머 설정
t1 = Timer(
    stmt="heinsen_associative_scan_log(log_coeffs, log_values)",
    globals=globals()
)
t2 = Timer(
    stmt="triton_scan(log_values)",
    globals=globals()
)

# 벤치마크 (50회)
result1 = t1.timeit(1000)
result2 = t2.timeit(1000)
result1 = result1.mean
result2 = result2.mean 
print(f"Heinsen scan: {result1:.4f} ms")
print(f"Triton scan: {result2:.4f} ms")
print(f"Speedup: {result1/result2:.2f}x")

Heinsen scan: 0.0005 ms
Triton scan: 0.0002 ms
Speedup: 3.50x


In [14]:
B, T, D = 16, 1024, 256
log_coeffs = torch.randn(B, T, D, device=device).abs()
log_values = torch.randn(B, T+1, D, device=device).abs()

# Warmup
_ = heinsen_associative_scan_log(log_coeffs, log_values)
_ = triton_scan(log_values)

# 타이머 설정
t1 = Timer(
    stmt="heinsen_associative_scan_log(log_coeffs, log_values)",
    globals=globals()
)
t2 = Timer(
    stmt="triton_scan(log_values)",
    globals=globals()
)

# 벤치마크 (50회)
result1 = t1.timeit(1000)
result2 = t2.timeit(1000)
result1 = result1.mean
result2 = result2.mean 
print(f"Heinsen scan: {result1:.4f} ms")
print(f"Triton scan: {result2:.4f} ms")
print(f"Speedup: {result1/result2:.2f}x")

Heinsen scan: 0.0013 ms
Triton scan: 0.0002 ms
Speedup: 6.20x
