In [1]:
import torch
import torch.nn as nn
import time

from collections import defaultdict

In [2]:
class FocusedLinearAttention(nn.Module):
    def __init__(self, dim, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., focusing_factor=3, kernel_size=5):
        super(FocusedLinearAttention, self).__init__()

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.focusing_factor = focusing_factor
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size, groups=head_dim, padding=kernel_size // 2)
        self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None, measure_time=False):
        start_time = time.time()

        B, N, C = x.shape
        print(f"Input shape: {x.shape}")
        H = W = int(N ** 0.5)
        assert H * W == N, f"Input does not correspond to a square grid. Got N={N}, H={H}, W={W}"
        time_records = {}


        # 1. QKV 연산
        t1 = time.time()
        qkv = self.qkv(x).reshape(B, N, 3, C).permute(2, 0, 1, 3)
        q, k, v = qkv.unbind(0)
        t2 = time.time()
        time_records["QKV computation"] = (t2 - t1) * 1000

        # Dynamically generate positional encoding
        t3 = time.time()
        positional_encoding = self.generate_positional_encoding(N, C, H, W, x.device)
        t4 = time.time()
        time_records["Positional encoding generation"] = (t4 - t3) * 1000

        # Add positional encoding to k
        t5 = time.time()
        k = k + positional_encoding
        t6 = time.time()
        time_records["Positional encoding addition"] = (t6 - t5) * 1000

        # Apply kernel function and focusing mechanism
        t7 = time.time()
        kernel_function = nn.ReLU()
        q = kernel_function(q) + 1e-6
        k = kernel_function(k) + 1e-6
        t8 = time.time()
        time_records["Kernel function application"] = (t8 - t7) * 1000

        # Apply scaling using Softplus
        t9 = time.time()
        scale = nn.Softplus()(self.scale)
        q = (q / scale) ** self.focusing_factor
        k = (k / scale) ** self.focusing_factor
        t10 = time.time()
        time_records["QK Scaling and focusing"] = (t10 - t9) * 1000

        # Normalize q and k
        t11 = time.time()
        q = q / q.norm(dim=-1, keepdim=True)
        k = k / k.norm(dim=-1, keepdim=True)
        t12 = time.time()
        time_records["QK Normalization"] = (t12 - t11) * 1000

        # Multi-head attention reshaping
        t13 = time.time()
        q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        t14 = time.time()
        time_records["QK Reshaping"] = (t14 - t13) * 1000

        # Compute attention
        t15 = time.time()
        attn = (q @ k.transpose(-2, -1)) * (N ** -0.5)
        attn = self.softmax(attn)
        t16 = time.time()
        time_records["QK Attention computation"] = (t16 - t15) * 1000

        # Apply attention to v
        t17 = time.time()
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        t18 = time.time()
        time_records["Apply Attention to V"] = (t18 - t17) * 1000

        # Post-attention processing with depthwise convolution
        t19 = time.time()
        v = v.reshape(B * self.num_heads, H, W, -1).permute(0, 3, 1, 2)
        x = x + self.dwc(v).reshape(B, C, N).permute(0, 2, 1)
        t20 = time.time()
        time_records["Reshape V and apply DWC"] = (t20 - t19) * 1000

        # Final projection and dropout
        t21 = time.time()
        x = self.proj(x)
        final_output = self.proj_drop(x)
        t22 = time.time()
        time_records["Final projection and dropout"] = (t22 - t21) * 1000

        total_time = time.time() - start_time
        total_time_ms = total_time * 1000

        return final_output, time_records if measure_time else final_output

    def generate_positional_encoding(self, N, C, H, W, device):
        # Directly create a normalized positional encoding grid
        grid = torch.linspace(0, 1, H, device=device).view(H, 1).expand(H, W)
        positional_encoding = torch.stack((grid, grid.T), dim=-1).reshape(1, N, 2)  # Shape: (1, N, 2)

        # Expand positional encoding to match the embedding dimension
        positional_encoding = positional_encoding.repeat(1, 1, C // 2)  # Shape: (1, N, C)

        return positional_encoding


In [3]:
B, N, C = 32, 144, 256  # Batch, Sequence Length, Embedding Dim
num_heads = 8
num_runs = 5  # 실행 횟수

def measure_execution_time(model, x, device):
    stage_times = defaultdict(list)

    for i in range(num_runs):
        if device == "cuda":
            torch.cuda.synchronize()

        start_time = time.time()
        _, time_records = model(x, measure_time=True)  # 실행 및 시간 측정
        if device == "cuda":
            torch.cuda.synchronize()

        total_time = (time.time() - start_time) * 1000  # ms 변환
        stage_times["Total forward pass"].append(total_time)

        if i > 0:  # 첫 실행 제외
            for stage, t in time_records.items():
                stage_times[stage].append(t)

    # 평균 계산
    avg_stage_times = {stage: sum(times) / len(times) for stage, times in stage_times.items()}

    print(f"\nAverage Execution Time on {device.upper()} (excluding first run)")
    print("=" * 65)
    print(f"{'Stage':<35}{'Time (ms)':<15}")
    print("=" * 65)
    total_avg_time = avg_stage_times["Total forward pass"]
    for stage, avg_time in avg_stage_times.items():
        percentage = (avg_time / total_avg_time) * 100
        print(f"{stage:<35}{avg_time:<15.3f}")

    print("||" * 65)
    print(f"{'Total':<35}{total_avg_time:<15.3f}")
    qk_attn = avg_stage_times["QK Attention computation"]
    print(f"{'QK attn':<35}{qk_attn:<15.3f}")
    attn_percentage = (avg_stage_times["QK Attention computation"] / total_avg_time) * 100
    print(f"{'QK attn percent':<35}{attn_percentage:<15.3f}")
    v_attn = avg_stage_times["Apply Attention to V"]
    print(f"{'V attn':<35}{v_attn:<15.3f}")
    v_percentage = (avg_stage_times["Apply Attention to V"] / total_avg_time) * 100
    print(f"{'V attn percent':<35}{v_percentage:<15.3f}")
    print("||" * 65)

    return avg_stage_times

In [5]:
if torch.cuda.is_available():
    print("\nRunning on CUDA...")
    x_cuda = torch.randn(B, N, C).cuda()
    attn_cuda = FocusedLinearAttention(dim=C, num_heads=num_heads).cuda()
    measure_execution_time(attn_cuda, x_cuda, "cuda")

# CPU 실행
print("\nRunning on CPU...")
x_cpu = torch.randn(B, N, C)
attn_cpu = FocusedLinearAttention(dim=C, num_heads=num_heads)
measure_execution_time(attn_cpu, x_cpu, "cpu")


Running on CUDA...
Input shape: torch.Size([32, 144, 256])
Input shape: torch.Size([32, 144, 256])
Input shape: torch.Size([32, 144, 256])
Input shape: torch.Size([32, 144, 256])
Input shape: torch.Size([32, 144, 256])

Average Execution Time on CUDA (excluding first run)
Stage                              Time (ms)      
Total forward pass                 1.851          
QKV computation                    0.334          
Positional encoding generation     0.141          
Positional encoding addition       0.023          
Kernel function application        0.139          
QK Scaling and focusing            0.132          
QK Normalization                   0.109          
QK Reshaping                       0.056          
QK Attention computation           0.197          
Apply Attention to V               0.113          
Reshape V and apply DWC            0.215          
Final projection and dropout       0.096          
|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||

{'Total forward pass': 27.66885757446289,
 'QKV computation': 5.03617525100708,
 'Positional encoding generation': 0.28526782989501953,
 'Positional encoding addition': 0.27567148208618164,
 'Kernel function application': 1.3124346733093262,
 'QK Scaling and focusing': 1.4196038246154785,
 'QK Normalization': 1.01393461227417,
 'QK Reshaping': 0.09119510650634766,
 'QK Attention computation': 10.399460792541504,
 'Apply Attention to V': 1.7408132553100586,
 'Reshape V and apply DWC': 2.0082592964172363,
 'Final projection and dropout': 1.624763011932373}