In [1]:
import torch
import torch.nn as nn
import time

from collections import defaultdict

In [2]:
class LinearNorm(nn.Linear):
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features)
        nn.init.xavier_uniform_(self.weight)

In [3]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads, attn_ratio=2):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        inner_dim = head_dim * num_heads * 3
        self.qkv = LinearNorm(dim, inner_dim)

        self.proj = nn.Sequential(
            nn.Hardswish(),
            LinearNorm(dim, dim)
        )

    def forward(self, x, measure_time=False):
        start_time = time.time()

        B, N, C = x.shape
        print(f"Input shape: {x.shape}")

        time_records = {}

        # QKV 연산
        t1 = time.time()
        qkv = self.qkv(x)
        t2 = time.time()
        time_records["QKV computation"] = (t2 - t1) * 1000

        # Reshaping 및 분할
        t3 = time.time()
        qkv = qkv.view(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        t4 = time.time()
        time_records["Reshape & split"] = (t4 - t3) * 1000

        # 어텐션 스코어 계산
        t5 = time.time()
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        t6 = time.time()
        time_records["Attention computation"] = (t6 - t5) * 1000

        # 어텐션 적용 후 값 계산
        t9 = time.time()
        x = (attn @ v)
        t10 = time.time()
        time_records["Attention output computation"] = (t10 - t9) * 1000

        # 차원 변환 및 최종 투영
        t11 = time.time()
        x = x.transpose(1, 2).reshape(B, N, C)
        final_output = self.proj(x)
        t12 = time.time()
        time_records["Final projection"] = (t12 - t11) * 1000

        total_time = time.time() - start_time
        total_time_ms = total_time * 1000
        # print(f"Total forward pass time: {total_time_ms:.3f} ms\n")

        # 테이블 출력
        # print(f"{'Stage':<35}{'Time (ms)':<15}{'Percentage (%)'}")
        # print("=" * 65)
        # for stage, time_ms in time_records.items():
        #     percentage = (time_ms / total_time_ms) * 100
        #     print(f"{stage:<35}{time_ms:<15.3f}{percentage:.2f}%")

        return final_output, time_records if measure_time else final_output

In [4]:
class SampledAttention(nn.Module):
    def __init__(self, dim, num_heads, attn_ratio=2, sample_ratio=0.5):
        super(SampledAttention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        inner_dim = head_dim * num_heads * 3
        self.qkv = LinearNorm(dim, inner_dim)

        self.proj = nn.Sequential(
            nn.Hardswish(),
            LinearNorm(dim, dim)
        )

        self.sample_ratio = sample_ratio

    def forward(self, x, measure_time=False):
        start_time = time.time()

        B, N, C = x.shape
        print(f"Input shape: {x.shape}")

        time_records = {}

        # QKV 연산
        t1 = time.time()
        qkv = self.qkv(x)
        t2 = time.time()
        time_records["QKV computation"] = (t2 - t1) * 1000

        # Reshaping 및 분할
        t3 = time.time()
        qkv = qkv.view(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        t4 = time.time()
        time_records["Reshape & split"] = (t4 - t3) * 1000

        # 샘플링 인덱스 선택
        t5 = time.time()
        num_samples = int(N * self.sample_ratio)
        prob = torch.ones(N, device=x.device) / N
        sampled_indices = torch.multinomial(prob, num_samples, replacement=False)
        t6 = time.time()
        time_records["Sampling"] = (t6 - t5) * 1000


        # 샘플링된 Q, K, V 추출
        t7 = time.time()
        sampled_q = q[:, :, sampled_indices, :]
        sampled_k = k[:, :, sampled_indices, :]
        sampled_v = v[:, :, sampled_indices, :]
        t8 = time.time()
        time_records["Sampled QKV extraction"] = (t8 - t7) * 1000

        # 어텐션 연산
        t9 = time.time()
        attn_sampled = (sampled_q @ sampled_k.transpose(-2, -1)) * self.scale
        attn_sampled = attn_sampled.softmax(dim=-1)
        t10 = time.time()
        time_records["Attention computation"] = (t10 - t9) * 1000

        # 어텐션을 적용한 출력 생성
        t11 = time.time()
        output_sampled = (attn_sampled @ sampled_v)
        t12 = time.time()
        time_records["Output sampled computation"] = (t12 - t11) * 1000

        # 전체 출력 생성
        t13 = time.time()
        output = torch.zeros(B, self.num_heads, N, v.size(-1), device=x.device)
        output[:, :, sampled_indices, :] = output_sampled
        t14 = time.time()
        time_records["Output reconstruction"] = (t14 - t13) * 1000

        # 차원 변환 및 최종 투영
        t15 = time.time()
        output = output.transpose(1, 2).reshape(B, N, C)
        final_output = self.proj(output)
        t16 = time.time()
        time_records["Final projection"] = (t16 - t15) * 1000

        total_time = time.time() - start_time
        total_time_ms = total_time * 1000
        # print(f"Total forward pass time: {total_time_ms:.3f} ms\n")

        # # 테이블 출력
        # print(f"{'Stage':<30}{'Time (ms)':<15}{'Percentage (%)'}")
        # print("=" * 55)
        # for stage, time_ms in time_records.items():
        #     percentage = (time_ms / total_time_ms) * 100
        #     print(f"{stage:<30}{time_ms:<15.3f}{percentage:.2f}%")

        return final_output, time_records if measure_time else final_output

In [5]:
B, N, C = 32, 128, 256  # Batch, Sequence Length, Embedding Dim
num_heads = 8
num_runs = 5  # 실행 횟수

def measure_execution_time(model, x, device):
    stage_times = defaultdict(list)

    # stage_times = {  # 각 단계별 시간 저장용
    #     "QKV computation": [],
    #     "Reshape & split": [],
    #     "Attention score computation": [],
    #     "Softmax application": [],
    #     "Attention output computation": [],
    #     "Final projection": [],
    #     "Total forward pass": []
    # }

    for i in range(num_runs):
        if device == "cuda":
            torch.cuda.synchronize()

        start_time = time.time()
        _, time_records = model(x, measure_time=True)  # 실행 및 시간 측정
        if device == "cuda":
            torch.cuda.synchronize()

        total_time = (time.time() - start_time) * 1000  # ms 변환
        stage_times["Total forward pass"].append(total_time)

        if i > 0:  # 첫 실행 제외
            for stage, t in time_records.items():
                stage_times[stage].append(t)

    # 평균 계산
    avg_stage_times = {stage: sum(times) / len(times) for stage, times in stage_times.items()}

    print(f"\nAverage Execution Time on {device.upper()} (excluding first run)")
    print("=" * 65)
    print(f"{'Stage':<35}{'Time (ms)':<15}")
    print("=" * 65)
    total_avg_time = avg_stage_times["Total forward pass"]
    for stage, avg_time in avg_stage_times.items():
        percentage = (avg_time / total_avg_time) * 100
        print(f"{stage:<35}{avg_time:<15.3f}")

    print("||" * 65)
    print(f"{'Total':<35}{total_avg_time:<15.3f}")
    attn = avg_stage_times["Attention computation"]
    print(f"{'attn':<35}{attn:<15.3f}")
    attn_percentage = (avg_stage_times["Attention computation"] / total_avg_time) * 100
    print(f"{'attn percent':<35}{attn_percentage:<15.3f}")
    print("||" * 65)

    return avg_stage_times

In [7]:
if torch.cuda.is_available():
    print("\nRunning on CUDA...")
    x_cuda = torch.randn(B, N, C).cuda()
    attn_cuda = Attention(dim=C, num_heads=num_heads).cuda()
    measure_execution_time(attn_cuda, x_cuda, "cuda")

# CPU 실행
print("\nRunning on CPU...")
x_cpu = torch.randn(B, N, C)
attn_cpu = Attention(dim=C, num_heads=num_heads)
measure_execution_time(attn_cpu, x_cpu, "cpu")


Running on CUDA...
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])

Average Execution Time on CUDA (excluding first run)
Stage                              Time (ms)      
Total forward pass                 0.855          
QKV computation                    0.263          
Reshape & split                    0.037          
Attention computation              0.163          
Attention output computation       0.057          
Final projection                   0.130          
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Total                              0.855          
attn                               0.163          
attn percent                       19.077         
|||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||

{'Total forward pass': 10.85062026977539,
 'QKV computation': 2.878546714782715,
 'Reshape & split': 0.06395578384399414,
 'Attention computation': 5.649149417877197,
 'Attention output computation': 0.8988380432128906,
 'Final projection': 1.431286334991455}

In [10]:
if torch.cuda.is_available():
    print("\nRunning on CUDA...")
    x_cuda = torch.randn(B, N, C).cuda()
    attn_cuda = SampledAttention(dim=C, num_heads=num_heads).cuda()
    measure_execution_time(attn_cuda, x_cuda, "cuda")

# CPU 실행
print("\nRunning on CPU...")
x_cpu = torch.randn(B, N, C)
attn_cpu = SampledAttention(dim=C, num_heads=num_heads)
measure_execution_time(attn_cpu, x_cpu, "cpu")


Running on CUDA...
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])

Average Execution Time on CUDA (excluding first run)
Stage                              Time (ms)      
Total forward pass                 1.530          
QKV computation                    0.286          
Reshape & split                    0.051          
Sampling                           0.304          
Sampled QKV extraction             0.143          
Attention computation              0.148          
Output sampled computation         0.058          
Output reconstruction              0.090          
Final projection                   0.172          
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Total                              1.530          
attn                              

{'Total forward pass': 10.745048522949219,
 'QKV computation': 4.091382026672363,
 'Reshape & split': 0.052988529205322266,
 'Sampling': 0.20056962966918945,
 'Sampled QKV extraction': 0.9617805480957031,
 'Attention computation': 1.6039609909057617,
 'Output sampled computation': 0.2637505531311035,
 'Output reconstruction': 0.44864416122436523,
 'Final projection': 2.724885940551758}

In [11]:
class SampledAttention_v2(nn.Module):
    def __init__(self, dim, num_heads, attn_ratio=2, sample_ratio=0.5):
        super(SampledAttention_v2, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        inner_dim = head_dim * num_heads * 3
        self.qkv = LinearNorm(dim, inner_dim)

        self.proj = nn.Sequential(
            nn.Hardswish(),
            LinearNorm(dim, dim)
        )

        self.sample_ratio = sample_ratio

    def forward(self, x, measure_time=False):
        start_time = time.time()

        B, N, C = x.shape
        print(f"Input shape: {x.shape}")

        time_records = {}

        # QKV 연산
        t1 = time.time()
        qkv = self.qkv(x)
        t2 = time.time()
        time_records["QKV computation"] = (t2 - t1) * 1000

        # Reshaping 및 분할
        t3 = time.time()
        qkv = qkv.view(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        t4 = time.time()
        time_records["Reshape & split"] = (t4 - t3) * 1000

        # 마스크 생성
        t5 = time.time()
        mask = torch.arange(N, device=x.device) % 2
        mask = mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1)  # (1, 1, N, 1)
        mask = mask.expand(B, self.num_heads, N, q.shape[-1])
        t6 = time.time()
        time_records["Mask generation"] = (t6 - t5) * 1000

        # 마스크 적용
        t7 = time.time()
        masked_q = q * mask
        masked_k = k * mask
        t8 = time.time()
        time_records["Mask application"] = (t8 - t7) * 1000

        # 홀수 인덱스 선택
        t9 = time.time()
        selected_indices = torch.arange(N, device=x.device) % 2 == 1
        t10 = time.time()
        time_records["Index selection"] = (t10 - t9) * 1000

        # 샘플링된 Q, K 추출
        t11 = time.time()
        masked_q = masked_q[:, :, selected_indices, :]
        masked_k = masked_k[:, :, selected_indices, :]
        t12 = time.time()
        time_records["Sampled QK extraction"] = (t12 - t11) * 1000

        # 어텐션 연산
        t13 = time.time()
        masked_attn = (masked_q @ masked_k.transpose(-2, -1)) * self.scale
        masked_attn = masked_attn.softmax(dim=-1)
        t14 = time.time()
        time_records["Attention computation"] = (t14 - t13) * 1000

        # 어텐션을 적용한 출력 생성
        t15 = time.time()
        restored_attn = torch.zeros(B, self.num_heads, N, N, device=x.device)
        restored_attn[:, :, selected_indices, :][:, :, :, selected_indices] = masked_attn
        t16 = time.time()
        time_records["Attention restoration"] = (t16 - t15) * 1000

        # 최종 출력 생성
        t17 = time.time()
        x = (restored_attn @ v).transpose(1, 2).reshape(B, N, C)
        final_output = self.proj(x)
        t18 = time.time()
        time_records["Final projection"] = (t18 - t17) * 1000

        total_time = time.time() - start_time
        total_time_ms = total_time * 1000
        # print(f"Total forward pass time: {total_time_ms:.3f} ms\n")

        # # 테이블 출력
        # print(f"{'Stage':<30}{'Time (ms)':<15}{'Percentage (%)'}")
        # print("=" * 55)
        # for stage, time_ms in time_records.items():
        #     percentage = (time_ms / total_time_ms) * 100
        #     print(f"{stage:<30}{time_ms:<15.3f}{percentage:.2f}%")

        return final_output, time_records if measure_time else final_output

In [15]:
if torch.cuda.is_available():
    print("\nRunning on CUDA...")
    x_cuda = torch.randn(B, N, C).cuda()
    attn_cuda = SampledAttention_v2(dim=C, num_heads=num_heads).cuda()
    measure_execution_time(attn_cuda, x_cuda, "cuda")

# CPU 실행
print("\nRunning on CPU...")
x_cpu = torch.randn(B, N, C)
attn_cpu = SampledAttention_v2(dim=C, num_heads=num_heads)
measure_execution_time(attn_cpu, x_cpu, "cpu")


Running on CUDA...
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])

Average Execution Time on CUDA (excluding first run)
Stage                              Time (ms)      
Total forward pass                 1.561          
QKV computation                    0.275          
Reshape & split                    0.042          
Mask generation                    0.079          
Mask application                   0.035          
Index selection                    0.049          
Sampled QK extraction              0.190          
Attention computation              0.118          
Attention restoration              0.238          
Final projection                   0.229          
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Total                             

{'Total forward pass': 15.823698043823242,
 'QKV computation': 4.1435956954956055,
 'Reshape & split': 0.07158517837524414,
 'Mask generation': 0.10752677917480469,
 'Mask application': 1.328110694885254,
 'Index selection': 0.06943941116333008,
 'Sampled QK extraction': 0.6388425827026367,
 'Attention computation': 1.239478588104248,
 'Attention restoration': 3.7297606468200684,
 'Final projection': 3.448605537414551}

In [18]:
class SampledAttention_v3(nn.Module):
    def __init__(self, dim, num_heads, attn_ratio=2, sample_ratio=0.5):
        super(SampledAttention_v3, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        inner_dim = head_dim * num_heads * 3
        self.qkv = LinearNorm(dim, inner_dim)

        self.proj = nn.Sequential(
            nn.Hardswish(),
            LinearNorm(dim, dim)
        )

        self.sample_ratio = sample_ratio

    def forward(self, x, measure_time=False):
        start_time = time.time()

        B, N, C = x.shape
        print(f"Input shape: {x.shape}")

        time_records = {}

        # QKV 연산
        t1 = time.time()
        qkv = self.qkv(x)
        t2 = time.time()
        time_records["QKV computation"] = (t2 - t1) * 1000

        # Reshaping 및 분할
        t3 = time.time()
        qkv = qkv.view(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        t4 = time.time()
        time_records["Reshape & split"] = (t4 - t3) * 1000

        # 마스크 생성
        t5 = time.time()
        mask = torch.arange(N, device=x.device) % 2
        mask = (torch.arange(N, device=x.device) % 2).bool()
        mask = mask.view(1, 1, N, 1)
        t6 = time.time()
        time_records["Mask generation"] = (t6 - t5) * 1000

        # 마스크 적용
        t7 = time.time()
        #masked_q = q * mask
        #masked_k = k * mask
        q *= mask
        k *= mask
        t8 = time.time()
        time_records["Mask application"] = (t8 - t7) * 1000

        # 홀수 인덱스 선택
        t9 = time.time()
        odd_indices = torch.arange(N, device=x.device)[mask.view(-1)]
        t10 = time.time()
        time_records["Index selection"] = (t10 - t9) * 1000

        # 샘플링된 Q, K 추출
        t11 = time.time()
        masked_q = q[:, :, odd_indices, :]
        masked_k = k[:, :, odd_indices, :]
        t12 = time.time()
        time_records["Sampled QK extraction"] = (t12 - t11) * 1000

        # 어텐션 연산
        t13 = time.time()
        masked_attn = (masked_q @ masked_k.transpose(-2, -1)) * self.scale
        masked_attn = masked_attn.softmax(dim=-1)
        t14 = time.time()
        time_records["Attention computation"] = (t14 - t13) * 1000

        # 어텐션을 적용한 출력 생성
        t15 = time.time()
        restored_attn = torch.zeros(B, self.num_heads, N, N, device=x.device)
        restored_attn[:, :, odd_indices, :][:, :, :, odd_indices] = masked_attn
        t16 = time.time()
        time_records["Attention restoration"] = (t16 - t15) * 1000

        # 최종 출력 생성
        t17 = time.time()
        x = (restored_attn @ v).transpose(1, 2).reshape(B, N, C)
        final_output = self.proj(x)
        t18 = time.time()
        time_records["Final projection"] = (t18 - t17) * 1000

        total_time = time.time() - start_time
        total_time_ms = total_time * 1000
        # print(f"Total forward pass time: {total_time_ms:.3f} ms\n")

        # # 테이블 출력
        # print(f"{'Stage':<30}{'Time (ms)':<15}{'Percentage (%)'}")
        # print("=" * 55)
        # for stage, time_ms in time_records.items():
        #     percentage = (time_ms / total_time_ms) * 100
        #     print(f"{stage:<30}{time_ms:<15.3f}{percentage:.2f}%")

        return final_output, time_records if measure_time else final_output

In [19]:
if torch.cuda.is_available():
    print("\nRunning on CUDA...")
    x_cuda = torch.randn(B, N, C).cuda()
    attn_cuda = SampledAttention_v3(dim=C, num_heads=num_heads).cuda()
    measure_execution_time(attn_cuda, x_cuda, "cuda")

# CPU 실행
print("\nRunning on CPU...")
x_cpu = torch.randn(B, N, C)
attn_cpu = SampledAttention_v3(dim=C, num_heads=num_heads)
measure_execution_time(attn_cpu, x_cpu, "cpu")


Running on CUDA...
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])
Input shape: torch.Size([32, 128, 256])

Average Execution Time on CUDA (excluding first run)
Stage                              Time (ms)      
Total forward pass                 1.514          
QKV computation                    0.245          
Reshape & split                    0.091          
Mask generation                    0.114          
Mask application                   0.045          
Index selection                    0.107          
Sampled QK extraction              0.073          
Attention computation              0.115          
Attention restoration              0.100          
Final projection                   0.210          
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Total                             

{'Total forward pass': 15.934228897094727,
 'QKV computation': 4.9803853034973145,
 'Reshape & split': 0.06449222564697266,
 'Mask generation': 0.12040138244628906,
 'Mask application': 0.36388635635375977,
 'Index selection': 0.1131296157836914,
 'Sampled QK extraction': 0.7218122482299805,
 'Attention computation': 1.4966130256652832,
 'Attention restoration': 3.7601590156555176,
 'Final projection': 4.047036170959473}