In [1]:
import torch
import torch.nn as nn
import math
# ---- basic configs --------------------------------------------------------
batch = 1
hidden = 768
sequence_length = 2048
new_token_length = 1
num_head = 12
head_dim = hidden // num_head  # 64
num_layer = 12
intermediate_size = 3072
assert hidden % num_head == 0, "hidden must be divisible by num_head"

device = "cpu"
dtype = torch.float32

# Seed the global RNG so the random weights below are reproducible across
# kernel restarts (otherwise every run draws fresh values).
SEED = 0
torch.manual_seed(SEED)

# Q/K/V/out projections: weights are [out, in]; biases are [out, 1] column
# vectors so they can be added directly to column-vector activations.
W_q = torch.randn(hidden, hidden, dtype=dtype, device=device)
b_q = torch.randn(hidden, 1,      dtype=dtype, device=device)
W_k = torch.randn(hidden, hidden, dtype=dtype, device=device)
b_k = torch.randn(hidden, 1,      dtype=dtype, device=device)
W_v = torch.randn(hidden, hidden, dtype=dtype, device=device)
b_v = torch.randn(hidden, 1,      dtype=dtype, device=device)
W_o = torch.randn(hidden, hidden, dtype=dtype, device=device)
b_o = torch.randn(hidden, 1,      dtype=dtype, device=device)

# MLP (fc1: hidden -> intermediate, fc2: intermediate -> hidden) + biases
W_fc1 = torch.randn(intermediate_size, hidden, dtype=dtype, device=device)
b_fc1 = torch.randn(intermediate_size, 1,      dtype=dtype, device=device)
W_fc2 = torch.randn(hidden, intermediate_size, dtype=dtype, device=device)
b_fc2 = torch.randn(hidden, 1,                 dtype=dtype, device=device)

# Position embedding, one row per position: [max_pos, hidden]
max_pos = sequence_length
pos_emb = torch.randn(max_pos, hidden, dtype=dtype, device=device)

# LayerNorms (ln1 before attention, ln2 before the MLP). Built with the config
# dtype/device so they still match the activations if `dtype` is changed.
ln1 = nn.LayerNorm(hidden, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype)
ln2 = nn.LayerNorm(hidden, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype)

# Attention score scaling factor 1/sqrt(head_dim)
scale = 1.0 / math.sqrt(head_dim)

In [2]:
# Token activations shared by both implementations: [1, sequence_length, hidden].
# Created with the config dtype/device so they match the weights and LayerNorms.
common = torch.randn(1, sequence_length, hidden, dtype=dtype, device=device)

In [3]:
import torch
import torch.nn.functional as F
import math


# KV cache, time-major: row t holds the key/value vector written at step t of
# the decode loop. Both start as all-zero [sequence_length, hidden] buffers.
K_cache, V_cache = (
    torch.zeros(sequence_length, hidden, dtype=dtype, device=device)
    for _ in range(2)
)

# Per-step layer outputs collected by the decode loop.
final_output = []

def right_mm_colvec(W: torch.Tensor, x_col: torch.Tensor) -> torch.Tensor:
    """Return ``W @ x_col`` computed as a row-vector product.

    The column vector is transposed into a row, multiplied on the right by
    ``W`` transposed, and the result transposed back: ``(x_col^T @ W^T)^T``.

    Args:
        W: weight matrix of shape [out, in].
        x_col: column vector of shape [in, 1].

    Returns:
        Column vector of shape [out, 1].
    """
    x_row = x_col.T                       # [1, in]
    y_row = torch.matmul(x_row, W.T)      # [1, out]
    return y_row.T                        # [out, 1]


# Token-by-token decoder layer with a KV cache.
# Wrapped in torch.no_grad(): this is inference only. Without it, every step
# records an autograd graph (the LayerNorm parameters require grad), and the
# in-place KV-cache writes get chained into that graph across all steps.
with torch.no_grad():
    for t in range(sequence_length):
        # input = token activation + position embedding, as a column vector
        x_input = common[:, t, :].T + pos_emb[t].view(-1, 1)  # [hidden, 1]

        # --- Self-attention (pre-norm) ---
        # Pre-attention LayerNorm (ln1 is nn.LayerNorm, not RMSNorm).
        x1 = ln1(x_input.view(1, 1, hidden)).view(hidden, 1)

        # Q/K/V projection + bias
        q = right_mm_colvec(W_q, x1) + b_q  # [hidden, 1]
        k = right_mm_colvec(W_k, x1) + b_k
        v = right_mm_colvec(W_v, x1) + b_v

        # Store this step's K and V as row t of the cache.
        K_cache[t, :] = k.view(-1)
        V_cache[t, :] = v.view(-1)

        # Split into heads: q_h [H, 1, D]; k_all / v_all [H, t+1, D].
        q_h = q.view(num_head, head_dim, 1).transpose(1, 2)
        k_all = K_cache[:t + 1, :].view(t + 1, num_head, head_dim).transpose(0, 1)
        v_all = V_cache[:t + 1, :].view(t + 1, num_head, head_dim).transpose(0, 1)

        # Causality comes from reading only cache rows 0..t, so no mask is needed.
        attn_scores = torch.matmul(q_h, k_all.transpose(-2, -1)) * scale  # [H, 1, t+1]
        attn_probs = F.softmax(attn_scores, dim=-1)

        context_h = torch.matmul(attn_probs, v_all)  # [H, 1, t+1] @ [H, t+1, D] = [H, 1, D]
        context = context_h.transpose(1, 2).contiguous().view(hidden, 1)  # [hidden, 1]

        # Output projection + bias, then residual connection.
        attn_out = right_mm_colvec(W_o, context) + b_o  # [hidden, 1]
        x_attn = x_input + attn_out

        # --- MLP (pre-norm) ---
        x2 = ln2(x_attn.view(1, 1, hidden)).view(hidden, 1)
        hidden_mid = F.gelu(right_mm_colvec(W_fc1, x2) + b_fc1)  # [intermediate_size, 1]
        mlp_out = right_mm_colvec(W_fc2, hidden_mid) + b_fc2     # [hidden, 1]

        # Residual connection; keep this step's layer output.
        x_out = x_attn + mlp_out
        final_output.append(x_out)

print(len(final_output))

2048


In [4]:
# Stack the per-step [hidden, 1] columns into [T, hidden, 1] and drop the
# trailing singleton axis -> [T, hidden]; then add a leading batch axis.
stacked = torch.stack(final_output, dim=0).squeeze(-1)
verilog_out = stacked[None]

print(verilog_out.shape)  # torch.Size([1, 2048, 768])


torch.Size([1, 2048, 768])


In [5]:
class OPTAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and optional KV cache.

    Args:
        num_attention_heads: number of heads H.
        head_dim: per-head width D.
        hidden_size: model width C; must equal H * D.
        dropout_prob: probability stored on ``self.dropout``. Dropout is not
            applied in ``forward``, so the module is deterministic.

    Raises:
        ValueError: if ``num_attention_heads * head_dim != hidden_size``.
    """

    def __init__(
        self,
        num_attention_heads: int = 12,
        head_dim: int = 64,
        hidden_size: int = 768,
        dropout_prob: float = 0.1,
    ):
        super().__init__()
        if num_attention_heads * head_dim != hidden_size:
            raise ValueError(
                f"num_attention_heads * head_dim ({num_attention_heads} * {head_dim}) "
                f"must equal hidden_size ({hidden_size})"
            )
        self.num_heads = num_attention_heads
        self.head_dim = head_dim
        self.embed_dim = hidden_size
        self.scale = 1 / math.sqrt(self.head_dim)

        # A single Linear produces [q | k | v] concatenated on the last dim.
        self.qkv_proj = nn.Linear(self.embed_dim, 3 * self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.dropout = nn.Dropout(dropout_prob)  # kept for compatibility; not used in forward

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        past_key: torch.Tensor = None,
        past_value: torch.Tensor = None,
    ):
        """Attend from the T new positions in ``x`` to cached + new positions.

        Args:
            x: input of shape [B, T, C].
            attention_mask: optional additive mask, broadcastable to
                [B, H, T, T_past + T] (e.g. 0 / -inf for causal masking).
            past_key, past_value: optional cached K/V of shape
                [B, H, T_past, D]; concatenated before the new K/V.

        Returns:
            ``(out, k, v)``: ``out`` is [B, T, C]; ``k`` and ``v`` are the
            updated caches of shape [B, H, T_past + T, D].
        """
        B, T, C = x.size()
        qkv = self.qkv_proj(x).view(B, T, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)

        q = q.transpose(1, 2)  # [B, heads, T, head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Prepend cached K/V along the time axis.
        if past_key is not None:
            k = torch.cat([past_key, k], dim=-2)
            v = torch.cat([past_value, v], dim=-2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask
        attn_probs = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_probs, v)
        context = context.transpose(1, 2).contiguous().view(B, T, C)
        out = self.out_proj(context)

        return out, k, v  # return updated cache


class OPTMLP(nn.Module):
    """Feed-forward block: fc1 -> GELU -> fc2.

    Args:
        hidden_size: input/output width.
        intermediate_size: width of the hidden (fc1 output) layer.
        dropout_prob: probability stored on ``self.dropout``. Dropout is not
            applied in ``forward``, so the module is deterministic.
    """

    def __init__(
        self,
        hidden_size: int = 768,
        intermediate_size: int = 3072,
        dropout_prob: float = 0.1,
    ):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout_prob)  # kept for compatibility; not used in forward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1, GELU, fc2 over the last dimension of ``x``."""
        return self.fc2(self.act(self.fc1(x)))

# Reference modules in eval mode, so dropout can never fire during the
# comparison (eval() returns the module itself).
self_attn = OPTAttention().eval()
self_mlp = OPTMLP().eval()

In [6]:
# Load the notebook's weights into the reference modules.
# Move the modules to the config dtype/device first: copy_ casts into the
# destination parameter's dtype, so the parameters must already match.
self_attn.to(device=device, dtype=dtype)
self_mlp.to(device=device, dtype=dtype)

# copy_ under no_grad (instead of `.data = ...`) checks shapes and does not
# alias the config tensors, so later edits to W_* / b_* cannot leak in.
with torch.no_grad():
    # qkv_proj outputs [q | k | v], so its weight/bias are row-wise concatenations.
    self_attn.qkv_proj.weight.copy_(torch.cat([W_q, W_k, W_v], dim=0))
    self_attn.qkv_proj.bias.copy_(torch.cat([b_q, b_k, b_v], dim=0).view(-1))

    self_attn.out_proj.weight.copy_(W_o)
    self_attn.out_proj.bias.copy_(b_o.view(-1))

    self_mlp.fc1.weight.copy_(W_fc1)
    self_mlp.fc1.bias.copy_(b_fc1.view(-1))
    self_mlp.fc2.weight.copy_(W_fc2)
    self_mlp.fc2.bias.copy_(b_fc2.view(-1))

In [7]:
# Batched PyTorch reference for the same layer: all positions at once with an
# explicit additive causal mask (0 on/below the diagonal, -inf above).
# T follows the config instead of a hard-coded 2048, so it always matches
# `common` and `pos_emb`. no_grad avoids retaining the T x T attention graph.
T = sequence_length
with torch.no_grad():
    x = common + pos_emb.unsqueeze(0)
    causal_bool = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
    attention_mask = torch.zeros(1, 1, T, T, device=x.device, dtype=x.dtype)
    attention_mask = attention_mask.masked_fill(causal_bool, float("-inf"))

    # Self-attention (pre-norm) + residual
    residual = x
    x = ln1(x)
    x, new_key, new_value = self_attn(x, attention_mask, None, None)
    x = x + residual

    # MLP (pre-norm) + residual
    residual = x
    x = ln2(x)
    x = self_mlp(x)
    x = x + residual
pytorch_out = x

In [8]:
# Element-wise absolute error between the step-by-step model and the reference.
diff = torch.abs(verilog_out - pytorch_out)

summary = (
    ("max abs diff:", diff.max()),
    ("mean abs diff:", diff.mean()),
)
for label, stat in summary:
    print(label, stat.item())

max abs diff: 0.978271484375
mean abs diff: 0.005575356539338827


In [9]:
import torch

# Element-wise absolute / relative difference between the two implementations.
abs_diff = (verilog_out - pytorch_out).abs()
rel_diff = abs_diff / (pytorch_out.abs() + 1e-12)

# Find the coordinate where the maximum relative difference occurs.
max_rel_val, max_rel_idx = torch.max(rel_diff.view(-1), dim=0)
max_rel_coords = torch.unravel_index(max_rel_idx, rel_diff.shape)

# Values at that coordinate.
v_val = verilog_out[max_rel_coords].item()
p_val = pytorch_out[max_rel_coords].item()
abs_val = abs_diff[max_rel_coords].item()

print("Max relative diff:", max_rel_val.item())
print("Index (B, T, C):", tuple(c.item() for c in max_rel_coords))
print("verilog_out value:", v_val)
print("pytorch_out value:", p_val)
print("abs diff:", abs_val)

# The element-wise relative error blows up wherever the reference is close to
# zero (tiny |pytorch_out| in the denominator), so its maximum mostly points at
# near-zero entries rather than at real disagreements. Normalizing the worst
# absolute error by the overall output magnitude gives a better-conditioned figure.
ref_scale = pytorch_out.abs().max().clamp_min(1e-12)
print("max abs diff / max |pytorch_out|:", (abs_diff.max() / ref_scale).item())


Max relative diff: 18.45945930480957
Index (B, T, C): (0, 1928, 470)
verilog_out value: 0.0394287109375
pytorch_out value: -0.00225830078125
abs diff: 0.04168701171875
