In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
class LinearAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.block_size = config.block_size
        
        self.relu = nn.ReLU()
    
    def forward(self, x, kv_cache=None, use_cache=False, output_attentions=False):
        B, T, C = x.size()
        qkv = self.c_attn(x)
        q, k, v = qkv.split(C, dim=2)

        q = q.view(B, T, self.n_head, self.head_dim).permute(0, 2, 1, 3)  # (B, n_head, T, head_dim)
        k = k.view(B, T, self.n_head, self.head_dim).permute(0, 2, 3, 1)  # (B, n_head, head_dim, T)
        v = v.view(B, T, self.n_head, self.head_dim).permute(0, 2, 1, 3)  # (B, n_head, T, head_dim)
        
        q = self.relu(q)
        k = self.relu(k)
        
        # numerator
        kv = torch.matmul(k, v)  # (B, n_head, head_dim, head_dim)
        qkv_weighted_sum = torch.matmul(q, kv)  # (B, n_head, T, head_dim)
        
        # denominator
        k_sum = torch.sum(k, dim=-1, keepdim=True)  # (B, n_head, head_dim, 1)
        qk_sum = torch.matmul(q, k_sum)  # (B, n_head, T, 1)

        y = qkv_weighted_sum / (qk_sum + 1e-6)  # (B, n_head, T, head_dim)
        y = y.permute(0, 2, 1, 3).contiguous().view(B, T, C)  # (B, T, C)

        att_weights, updated_kv_cache = None, None
        return y, att_weights, updated_kv_cache
class Config:
    block_size = 1024
    n_embd = 768
    n_head = 12
config = Config()
sequence_length = 1024
hidden_dim = config.n_embd

x = torch.randn(1, sequence_length, hidden_dim) # Batch size of 1
attention_layer = LinearAttention(config)
output = attention_layer(x)

output

tensor([[[-0.0281,  0.0536,  0.0027,  ..., -0.0210, -0.0369, -0.0400],
         [-0.0391,  0.0666,  0.0153,  ..., -0.0124, -0.0393, -0.0374],
         [-0.0464,  0.0673, -0.0016,  ..., -0.0143, -0.0330, -0.0392],
         ...,
         [-0.0420,  0.0692,  0.0044,  ..., -0.0266, -0.0289, -0.0365],
         [-0.0330,  0.0754,  0.0259,  ..., -0.0192, -0.0349, -0.0476],
         [-0.0312,  0.0792,  0.0156,  ..., -0.0177, -0.0410, -0.0381]]],
       grad_fn=<ViewBackward0>)

In [22]:
class MixFFN(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        # self.hidden_dim = 2 * config.n_embd
        # self.inverted_conv = nn.Conv1d(config.n_embd, self.hidden_dim * 2, kernel_size=1)
        self.depthwise_conv = nn.Conv1d(
            config.n_embd,
            config.n_embd,
            kernel_size=3,
            padding=1,
            groups=config.n_embd,
        )
        self.act = nn.Sigmoid()
        self.pointwise_conv = nn.Conv1d(config.n_embd //2, config.n_embd, kernel_size=1)

    def forward(self, x):
        x = x.transpose(1, 2)
        # x = self.inverted_conv(x)
        x = self.depthwise_conv(x)

        # Gating 
        x, gate = torch.chunk(x, 2, dim=1)  # Split into (x, gate)
        gate = self.act(gate)
        x = x * gate
        
        x = self.pointwise_conv(x) 
        x = x.transpose(1, 2)
        return x

class Config:
    block_size = 1024
    n_embd = 768
    n_head = 12
config = Config()
sequence_length = 1024
hidden_dim = config.n_embd

x = torch.randn(1, sequence_length, hidden_dim) # Batch size of 1
attention_layer = MixFFN(config)
output = attention_layer(x)

output.shape

torch.Size([1, 1024, 768])

In [6]:
import torch
from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer, AutoModelForCausalLM
from bertviz import head_view

# Step 1: Configure model (Choose between "T5-small" or "gemma-2b-instruct")
model_name = "T5-small"  # Change this to "gemma-2b" if needed

def get_tokenizer_and_text_encoder(name="T5", device="cuda"):
    text_encoder_dict = {
        "T5": "DeepFloyd/t5-v1_1-xxl",
        "T5-small": "google/t5-v1_1-small",
        "T5-base": "google/t5-v1_1-base",
        "T5-large": "google/t5-v1_1-large",
        "T5-xl": "google/t5-v1_1-xl",
        "T5-xxl": "google/t5-v1_1-xxl",
        "gemma-2b": "google/gemma-2b",
        "gemma-2b-it": "google/gemma-2b-it",
        "gemma-2-2b": "google/gemma-2-2b",
        "gemma-2-2b-it": "google/gemma-2-2b-it",
        "gemma-2-9b": "google/gemma-2-9b",
        "gemma-2-9b-it": "google/gemma-2-9b-it",
        "Qwen2-0.5B-Instruct": "Qwen/Qwen2-0.5B-Instruct",
        "Qwen2-1.5B-Instruct": "Qwen/Qwen2-1.5B-Instruct",
    }
    assert name in list(text_encoder_dict.keys()), f"not support this text encoder: {name}"
    
    if "T5" in name:
        tokenizer = T5Tokenizer.from_pretrained(text_encoder_dict[name])
        text_encoder = T5EncoderModel.from_pretrained(text_encoder_dict[name], torch_dtype=torch.float16, output_attentions=True).to(device)
    elif "gemma" in name or "Qwen" in name:
        tokenizer = AutoTokenizer.from_pretrained(text_encoder_dict[name])
        tokenizer.padding_side = "right"
        text_encoder = (
            AutoModelForCausalLM.from_pretrained(text_encoder_dict[name], torch_dtype=torch.bfloat16, output_attentions=True)
            .get_decoder()
            .to(device)
        )
    else:
        print("error load text encoder")
        exit()

    return tokenizer, text_encoder

# Step 2: Get tokenizer and encoder
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=model_name, device=device)

print("start 01" , model_name)

# Step 3: Print model layers
print("\nModel layers:\n")
for name, param in text_encoder.named_parameters():
    print(name)

print("start 02")
# Step 4: Input prompt and encode text
input_text = "The quick brown fox jumps over the lazy dog."
inputs = tokenizer(input_text, return_tensors="pt").to(device)

print("start 0")

outputs = text_encoder(**inputs)

print("start 1")

# Step 5: Visualize the layers with BertViz
# You need a list of layer outputs, so we modify the encoder slightly to return all layers
class ModifiedTextEncoder(torch.nn.Module):
    def __init__(self, original_encoder):
        super(ModifiedTextEncoder, self).__init__()
        self.original_encoder = original_encoder

    def forward(self, input_ids, attention_mask):
        # Use the encoder and return hidden states from all layers
        output = self.original_encoder(input_ids=input_ids, attention_mask=attention_mask, output_attentions=True, return_dict=True)
        return output.attentions

print("start 2")
modified_encoder = ModifiedTextEncoder(text_encoder)

print("start 3")
all_attentions = modified_encoder(**inputs)

print("start 4")
# Convert all_attentions to a list if it's not already
if isinstance(all_attentions, tuple):
    all_attentions = list(all_attentions)

print("start 5")
# Extract tokens for visualization
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
print("start 6")
# Visualize the first attention layer with BertViz
#head_view(all_attentions, tokens, tokenizer)

print("start 7")
# Note: To run this code, you must have BertViz installed and a Jupyter notebook or a supported browser environment
# You can install BertViz using the following command:
# !pip install bertviz


start 01 T5-small

Model layers:

shared.weight
encoder.block.0.layer.0.SelfAttention.q.weight
encoder.block.0.layer.0.SelfAttention.k.weight
encoder.block.0.layer.0.SelfAttention.v.weight
encoder.block.0.layer.0.SelfAttention.o.weight
encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight
encoder.block.0.layer.0.layer_norm.weight
encoder.block.0.layer.1.DenseReluDense.wi_0.weight
encoder.block.0.layer.1.DenseReluDense.wi_1.weight
encoder.block.0.layer.1.DenseReluDense.wo.weight
encoder.block.0.layer.1.layer_norm.weight
encoder.block.1.layer.0.SelfAttention.q.weight
encoder.block.1.layer.0.SelfAttention.k.weight
encoder.block.1.layer.0.SelfAttention.v.weight
encoder.block.1.layer.0.SelfAttention.o.weight
encoder.block.1.layer.0.layer_norm.weight
encoder.block.1.layer.1.DenseReluDense.wi_0.weight
encoder.block.1.layer.1.DenseReluDense.wi_1.weight
encoder.block.1.layer.1.DenseReluDense.wo.weight
encoder.block.1.layer.1.layer_norm.weight
encoder.block.2.layer.0.SelfAttention