# scaled_dot_product_attention-Taichi

## Taichi版 scaled_dot_product_attention

In [None]:
import taichi as ti
import math
import time
import numpy as np

# Print floats with two fixed decimals and without scientific notation.
np.set_printoptions(suppress=True, formatter={'float': '{:.2f}'.format})

# Flip to True to enable Taichi debug mode and the debug prints in the kernels.
IS_DEBUG = False
# Request a GPU backend.
BACKEND = ti.gpu
ti.init(arch=BACKEND, debug=IS_DEBUG)

# Problem dimensions
batch_size, head_size, sequence_size, embedding_size = 1, 2, 4, 8

# Taichi fields: query, key, value and output all use the same 4-D f32 layout.
_FIELD_SHAPE = (batch_size, head_size, sequence_size, embedding_size)
Q = ti.field(dtype=ti.f32, shape=_FIELD_SHAPE)
K = ti.field(dtype=ti.f32, shape=_FIELD_SHAPE)
V = ti.field(dtype=ti.f32, shape=_FIELD_SHAPE)
out = ti.field(dtype=ti.f32, shape=_FIELD_SHAPE)

@ti.kernel
def init_attention(q:ti.template(), k:ti.template(), v:ti.template(), out:ti.template()):
    """Fill q, k and v with dummy random data and reset out to zero.

    Parameters
    ---------------------
    q,k,v : fields that receive one ti.random() value per element.
            The loop walks the indices of q and reuses them for k and v,
            so the three fields are assumed to share one shape.
    out   : field that is cleared to 0.
    """
    # Write through the parameters rather than the module-level fields so the
    # kernel initialises whichever fields the caller passes in.
    for I in ti.grouped(q):
        q[I] = ti.random()
        k[I] = ti.random()
        v[I] = ti.random()
    out.fill(0)  # reset the output
init_attention(Q, K, V, out)

@ti.func
def max2d(matrix:ti.template()) -> ti.f32:
    """Return the largest element of a 2-D matrix (matrix.n rows, matrix.m columns)."""
    # Seed the running maximum with a real element: a fixed positive seed
    # would be returned unchanged for a matrix with no positive entry.
    max_val = matrix[0, 0]
    for i,j in ti.ndrange(matrix.n, matrix.m):
        max_val = ti.max(max_val, matrix[i,j])
    return max_val

@ti.func
def softmax2d(mat:ti.template()):
    """Apply a softmax along each row of mat, overwriting mat in place.

    For every row s the entries become
    exp(mat[s, j] - max_j mat[s, j]) / sum_j exp(mat[s, j] - max_j mat[s, j]).
    """
    n,m = mat.n, mat.m
    for s in range(n):
        # Subtract the maximum of *this* row before exponentiating. The row's
        # largest entry then maps to exp(0) = 1, so the row sum is never zero
        # even when the row lies far below the rest of the matrix.
        row_max = mat[s, 0]
        for _s in range(m):
            row_max = ti.max(row_max, mat[s, _s])
        sum_exp = 0.0
        for _s in range(m):
            mat[s, _s] = ti.exp(mat[s, _s] - row_max)
            sum_exp += mat[s, _s]  # accumulate the row sum over _s
        for _s in range(m):
            mat[s, _s] = mat[s, _s] / sum_exp  # normalise by the row sum
                

@ti.kernel
def scaled_dotproduct_attention(q: ti.template(), k: ti.template(), v: ti.template(), out: ti.template()):
    """
    Causal scaled dot-product attention.

    For every index I over the leading dimensions, computes
    softmax(mask(q[I] @ k[I]^T / sqrt(e))) @ v[I] and stores it in out[I].
    out is overwritten, so it does not need to be zeroed beforehand and the
    kernel can be re-run on the same fields.

    Parameters
    ---------------------
    q,k,v,out : tensor with shape [... s e]
          s : sequence
          e : embedding
    """
    # Check shapes at compile time: Taichi's runtime `assert` is only active
    # in debug mode, whereas ti.static_assert always runs.
    ti.static_assert(q.shape == k.shape == v.shape == out.shape,
                     "q, k, v and out must have identical shapes")
    # Split the shape into leading dims and the trailing (sequence, embedding) pair.
    upper_dims = ti.static(q.shape[:-2])  # every dimension except the last two
    sequence_size, embedding_size = ti.static(q.shape[-2:])  # last two dimensions

    for I in ti.static(ti.ndrange(*upper_dims)):
        # Score matrix, pre-filled with a large negative sentinel so that
        # entries left untouched by the mask vanish after the softmax.
        mat = ti.Matrix([[-1e9] * sequence_size for _ in ti.ndrange(sequence_size)], ti.f32)

        # Attention scores
        for s, _s in ti.ndrange(sequence_size, sequence_size):
            if s < _s: continue  # causal mask: only key positions _s <= s are scored
            mat[s, _s] = 0.0
            for e in range(embedding_size): mat[s, _s] += q[I, s, e] * k[I, _s, e]
            mat[s, _s] *= (1.0 / ti.sqrt(embedding_size))

        if IS_DEBUG: print("att: ", I, mat)

        # Softmax (rewrites mat in place)
        softmax2d(mat)

        if IS_DEBUG: print("softmax(att): ", I, mat)

        # Output: weighted sum of the value rows. Accumulate in a local and
        # assign once, so the result does not depend on the previous content
        # of out.
        for s, e in ti.ndrange(sequence_size, embedding_size):
            acc = 0.0
            for _s in range(sequence_size):
                acc += mat[s, _s] * v[I, _s, e]
            out[I, s, e] = acc

# Run the Taichi attention kernel on the module-level fields and show the result.
scaled_dotproduct_attention(Q, K, V, out)
ti.sync()  # wait for the launched kernel to finish
time.sleep(0.1)  # NOTE(review): presumably lets buffered kernel prints flush before the result is printed — confirm
print("out: ", out)

[Taichi] version 1.7.0, llvm 15.0.4, commit 2fd24490, linux, python 3.11.7


[I 03/20/24 22:09:31.365 193278] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout
[W 03/20/24 22:09:31.456 193278] [cuda_driver.cpp:load_lib@36] libcuda.so lib not found.
RHI Error: GLFW Error 65543: GLX: Failed to create context: GLXBadFBConfig
[W 03/20/24 22:09:31.823 193278] [opengl_api.cpp:initialize_opengl@205] Can not create OpenGL context
[W 03/20/24 22:09:31.825 193278] [misc.py:adaptive_arch_select@758] Arch=[<Arch.cuda: 3>, <Arch.metal: 4>, <Arch.vulkan: 10>, <Arch.opengl: 5>, <Arch.dx11: 6>, <Arch.dx12: 7>, <Arch.gles: 11>, <Arch.amdgpu: 9>] is not supported, falling back to CPU


[Taichi] Starting on arch=x64
out:  [[[[0.87 0.83 0.24 0.32 0.93 0.92 0.91 0.76]
   [0.65 0.90 0.14 0.28 0.63 0.56 0.53 0.57]
   [0.62 0.61 0.26 0.44 0.58 0.55 0.49 0.64]
   [0.63 0.64 0.19 0.37 0.46 0.54 0.50 0.64]]

  [[0.07 0.89 0.75 0.24 0.70 0.61 0.91 0.82]
   [0.20 0.74 0.67 0.33 0.74 0.67 0.73 0.79]
   [0.37 0.68 0.74 0.53 0.61 0.64 0.79 0.77]
   [0.50 0.61 0.58 0.54 0.55 0.50 0.76 0.61]]]]


### torch版　scaled_dot_product_attention

In [None]:
import torch
# Reference result from PyTorch's built-in attention with the causal mask on
# (set is_causal=False to compare against unmasked attention instead).
out_torch = torch.nn.functional.scaled_dot_product_attention(
    Q.to_torch(),
    K.to_torch(),
    V.to_torch(),
    attn_mask=None,
    dropout_p=0,
    is_causal=True,
)
print("out: ", out_torch.numpy())
# Element-wise difference between the Taichi result and the PyTorch reference.
print("diff: ", out.to_numpy()-out_torch.numpy())

out:  [[[[0.87 0.83 0.24 0.32 0.93 0.92 0.91 0.76]
   [0.65 0.90 0.14 0.28 0.63 0.56 0.53 0.57]
   [0.62 0.61 0.26 0.44 0.58 0.55 0.49 0.64]
   [0.63 0.64 0.19 0.37 0.46 0.54 0.50 0.64]]

  [[0.07 0.89 0.75 0.24 0.70 0.61 0.91 0.82]
   [0.20 0.74 0.67 0.33 0.74 0.67 0.73 0.79]
   [0.37 0.68 0.74 0.53 0.61 0.64 0.79 0.77]
   [0.50 0.61 0.58 0.54 0.55 0.50 0.76 0.61]]]]
diff:  [[[[0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
   [0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
   [0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
   [0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]]

  [[0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
   [0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
   [0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]
   [0.00 0.00 0.00 0.00 -0.00 -0.00 -0.00 -0.00]]]]


### Python版scaled_dot_product_attention

参考　https://github.com/karpathy/nanoGPT/blob/master/model.py#L66

※Causal maskなしで実行しているため、Taichi版（Causal maskあり）とはoutの値が異なる。ただし、各ヘッドの最終行はマスクの有無にかかわらず全位置を参照するため一致する。

In [None]:
import math
from torch.nn import functional as F

def scaled_dotproduct_attention_torch(q,k,v,seq_size, *, is_causal=False):
    """Manual PyTorch implementation of scaled dot-product attention.

    Reference: https://github.com/karpathy/nanoGPT/blob/master/model.py#L66

    Args:
        q, k, v: tensors whose last two dimensions are (sequence, embedding),
            e.g. (B, nh, T, hs).
        seq_size: sequence length T; only used to size the causal mask.
        is_causal: when True, each query position attends only to key
            positions at or before it. Defaults to False (no mask).

    Returns:
        softmax(q @ k^T / sqrt(k.size(-1))) @ v, with the same leading
        dimensions as the inputs. Dropout is not applied.
    """
    # Raw attention scores, scaled by 1/sqrt(embedding size).
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    if IS_DEBUG: print("att: ", att)
    if is_causal:
        # Lower-triangular ones mark the allowed (query, key) pairs; every
        # other position is set to -inf so it gets zero weight after softmax.
        blocked = att.new_ones((seq_size, seq_size)).tril() == 0
        att = att.masked_fill(blocked, float('-inf'))
    att = F.softmax(att, dim=-1)
    if IS_DEBUG: print("softmax(att): ", att)
    # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    return att @ v

# Run the manual (unmasked) implementation on the same inputs and compare it
# with the Taichi result held in `out`.
out_torch = scaled_dotproduct_attention_torch(Q.to_torch(), K.to_torch(), V.to_torch(), sequence_size)
print("out: ", out_torch.numpy())
# Element-wise difference between the Taichi output and the manual torch output.
print("diff: ", out.to_numpy()-out_torch.numpy())

out:  [[[[0.64 0.64 0.19 0.36 0.45 0.55 0.51 0.65]
   [0.65 0.64 0.20 0.37 0.48 0.56 0.52 0.66]
   [0.65 0.63 0.20 0.37 0.48 0.57 0.53 0.66]
   [0.63 0.64 0.19 0.37 0.46 0.54 0.50 0.64]]

  [[0.51 0.60 0.60 0.56 0.55 0.52 0.76 0.63]
   [0.51 0.61 0.62 0.57 0.54 0.52 0.77 0.64]
   [0.52 0.60 0.56 0.54 0.53 0.48 0.76 0.59]
   [0.50 0.61 0.58 0.54 0.55 0.50 0.76 0.61]]]]
diff:  [[[[0.23 0.18 0.05 -0.04 0.48 0.37 0.39 0.11]
   [0.01 0.26 -0.05 -0.09 0.15 -0.00 0.01 -0.08]
   [-0.04 -0.02 0.06 0.07 0.10 -0.02 -0.05 -0.02]
   [0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]]

  [[-0.44 0.28 0.16 -0.32 0.15 0.09 0.15 0.19]
   [-0.31 0.13 0.05 -0.24 0.20 0.14 -0.04 0.14]
   [-0.15 0.08 0.18 -0.02 0.08 0.16 0.03 0.18]
   [0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00]]]]
