In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from quant import quant, dequant

# Hugging Face Hub id of the checkpoint whose attention layers are replayed below.
model_name = 'facebook/opt-6.7b'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.30it/s]


In [22]:
import torch
import torch.nn as nn
# NOTE(review): `a` is an unseeded random tensor that is not used by any other
# visible cell; kept only in case something outside this view relies on it.
a = torch.randn(4096, 4096)
# Attention geometry is read from the loaded model's config instead of being
# hard-coded, so the notebook stays correct if `model_name` is changed.
# (The values previously hard-coded here were num_heads=32, head_dim=128.)
num_heads = model.config.num_attention_heads
head_dim = model.config.hidden_size // num_heads
assert head_dim * num_heads == model.config.hidden_size, \
    'hidden_size must be divisible by num_attention_heads'
# Query scaling factor 1/sqrt(head_dim), applied to q_proj output before QK^T.
scaling = head_dim**-0.5

In [12]:
# Saved inputs to each decoder layer's self-attention, collected in a prior run.
# NOTE(review): absolute, machine-specific path — adjust for your environment.
ATTN_ACTIVATION_PATH = '/raid/jwjeong/results/attn_activation.pt'
# map_location puts the tensors on the model's device, because they are fed
# directly into the model's q/k projections below.
# torch.load unpickles the file: only load files from a trusted source.
collected_act = torch.load(ATTN_ACTIVATION_PATH, map_location=model.device)

In [24]:
def _shape(tensor: torch.Tensor, seq_len: int, bsz: int):
    """Split the last dimension into attention heads and move heads before sequence.

    Reshapes ``tensor`` (expected to hold bsz * seq_len * num_heads * head_dim
    elements, e.g. shape (bsz, seq_len, num_heads * head_dim)) to
    (bsz, num_heads, seq_len, head_dim) and returns a contiguous copy.
    ``seq_len`` may be -1 to let ``view`` infer it. Uses the module-level
    ``num_heads`` and ``head_dim``.
    """
    per_head = tensor.view(bsz, seq_len, num_heads, head_dim)
    heads_first = per_head.transpose(1, 2)
    # contiguous() so later .view() calls on the result are valid.
    return heads_first.contiguous()

def _causal_attention_mask(bsz: int, tgt_len: int, src_len: int,
                           dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    """Build an additive causal attention mask.

    Returns a tensor of shape (bsz, 1, tgt_len, src_len) that is 0 where key
    position k <= query position q and ``torch.finfo(dtype).min`` where k > q,
    so adding it to attention logits blocks attention to later positions.
    """
    # True strictly above the diagonal, i.e. exactly the (q, k) pairs with q < k.
    future = torch.ones(tgt_len, src_len, dtype=torch.bool, device=device).triu(diagonal=1)
    mask = torch.zeros(bsz, 1, tgt_len, src_len, dtype=dtype, device=device)
    # The (tgt_len, src_len) boolean mask broadcasts over the batch and head dims,
    # so every batch element is masked.
    return mask.masked_fill(future, torch.finfo(dtype).min)


# Per-layer softmax(QK^T) attention probabilities, keyed like `collected_act`;
# each value is a list with one (bsz * num_heads, tgt_len, src_len) tensor per sample.
soft_max_results = {}

# NOTE(review): the enumerate position is used as the decoder layer index, so
# this assumes `collected_act` yields its keys in layer order — confirm against
# the code that produced attn_activation.pt.
for layer, key in enumerate(collected_act):
    cur_act = collected_act[key]
    soft_max_results[key] = []
    # The projections depend only on the layer, so look them up once per layer.
    self_attn = model.model.decoder.layers[layer].self_attn
    q_proj_test = self_attn.q_proj
    k_proj_test = self_attn.k_proj
    for act in cur_act:
        bsz, tgt_len, _ = act.size()
        with torch.no_grad():
            # Scaled queries and keys, split into heads and flattened to
            # (bsz * num_heads, seq_len, head_dim) for the batched matmul.
            query_states = q_proj_test(act) * scaling
            key_states = _shape(k_proj_test(act), -1, bsz)
            proj_shape = (bsz * num_heads, -1, head_dim)
            query_states = _shape(query_states, tgt_len, bsz).view(*proj_shape)
            key_states = key_states.view(*proj_shape)
            src_len = key_states.size(1)

            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

            # Causal mask in the scores' own dtype/device, applied to every batch element.
            attention_mask = _causal_attention_mask(
                bsz, tgt_len, src_len, attn_weights.dtype, attn_weights.device
            )
            attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len) + attention_mask
            # Clamp at finfo.min so masked entries that overflowed to -inf stay finite.
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )
            attn_weights = attn_weights.view(bsz * num_heads, tgt_len, src_len)
            # For fp16 scores, run the softmax in float32 and cast back.
            if attn_weights.dtype == torch.float16:
                attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
            else:
                attn_weights = nn.functional.softmax(attn_weights, dim=-1)

            soft_max_results[key].append(attn_weights)

In [26]:
# Persist the recomputed attention probabilities, relative to the working directory.
SOFTMAX_RESULTS_PATH = 'softmax_results.pt'
torch.save(soft_max_results, SOFTMAX_RESULTS_PATH)