In [1]:
import torch
import torch.nn as nn
import time
from thop import profile
import tracemalloc
import gc


In [71]:
class SingleHeadSelfAttention(nn.Module):
    """Single-head self-attention restricted to non-overlapping windows.

    The sequence is split into consecutive chunks of ``window_size`` tokens
    (the last chunk may be shorter) and each token attends only to tokens in
    its own chunk. Queries, keys and values all come from one shared linear
    projection, so they are the same tensor.

    Args:
        embed_size: Feature dimension of the input and output.
        window_size: Number of tokens per attention window; must be >= 1.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """

    def __init__(self, embed_size, window_size):
        super(SingleHeadSelfAttention, self).__init__()
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.embed_size = embed_size
        self.window_size = window_size

        # One shared projection produces queries, keys and values.
        self.proj = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

        # sqrt(d) scaling factor. Registered as a non-persistent buffer so that
        # ``module.to(device)`` / ``.half()`` move it together with the weights
        # (a plain tensor attribute stays where it was created and causes a
        # device mismatch), while keeping it out of ``state_dict``.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.tensor([float(embed_size)], dtype=torch.float32)),
            persistent=False,
        )

    def forward(self, x, mask=None):
        """Apply windowed self-attention.

        Args:
            x: Tensor of shape (N, seq_len, embed_size).
            mask: Optional tensor whose zero entries are masked out. Either
                (N, seq_len) -- a per-key mask applied to every query -- or
                (N, seq_len, seq_len) -- a full query-by-key mask, of which
                only each window's diagonal block is used.

        Returns:
            Tensor of shape (N, seq_len, embed_size).

        Raises:
            ValueError: If ``mask`` is neither 2-D nor 3-D.
        """
        seq_len = x.shape[1]
        if mask is not None and mask.dim() not in (2, 3):
            raise ValueError(
                "mask must have shape (N, seq_len) or (N, seq_len, seq_len), "
                f"got {tuple(mask.shape)}"
            )

        # q, k and v share one projection, so compute it once instead of three times.
        qkv = self.proj(x)

        out = torch.zeros_like(x)
        for start in range(0, seq_len, self.window_size):
            end = min(start + self.window_size, seq_len)
            window = qkv[:, start:end]  # (N, w, embed_size), used as q, k and v

            # Scaled dot-product attention scores, shape (N, w, w).
            energy = torch.bmm(window, window.transpose(1, 2)) / self.scale
            if mask is not None:
                if mask.dim() == 2:
                    # Per-key mask: add a query axis so it broadcasts over queries.
                    mask_window = mask[:, None, start:end]
                else:
                    # Full mask: keep only this window's query x key block.
                    mask_window = mask[:, start:end, start:end]
                # Most negative finite value of the dtype; a hard-coded -1e20
                # does not fit in float16.
                energy = energy.masked_fill(mask_window == 0, torch.finfo(energy.dtype).min)

            attention = torch.softmax(energy, dim=2)
            out[:, start:end] = torch.bmm(attention, window)

        return self.fc_out(out)
    
def profile_self_attention(seq_length, device="cpu", embed_size=128, window_size=64, batch_size=1):
    """Profile one forward pass of ``SingleHeadSelfAttention`` on random input.

    Args:
        seq_length: Number of tokens in the random input sequence.
        device: Device to run on (anything ``torch.device`` accepts).
        embed_size: Feature dimension of the model and input.
        window_size: Attention window length.
        batch_size: Number of sequences in the random input.

    Returns:
        Tuple ``(flops, memory_usage, wall_clock_time)``:
          * flops -- analytic FLOP count of the forward pass (matrix
            multiplications only; one multiply-add counts as 2 FLOPs);
          * memory_usage -- extra memory in MB used by the forward pass: peak
            allocated CUDA memory above the pre-forward baseline (weights and
            input) on CUDA, otherwise the change in process RSS;
          * wall_clock_time -- seconds taken by one forward pass.
    """
    dev = torch.device(device)
    is_cuda = dev.type == "cuda"

    model = SingleHeadSelfAttention(embed_size, window_size).to(dev)
    # Create the input directly on the target device (no temporary CPU copy).
    x = torch.randn(batch_size, seq_length, embed_size, device=dev)

    def compute_flops(seq_len, window_size):
        # Shared q/k/v projection + output projection: 2*d^2 FLOPs per token each.
        # The shared projection is counted once since q, k and v are identical.
        projections = 2 * (2 * seq_len * embed_size**2)
        # Per window of length w, QK^T and attention @ V each cost 2*w^2*d;
        # the last window may be shorter than window_size.
        n_full, remainder = divmod(seq_len, window_size)
        sum_sq_windows = n_full * window_size**2 + remainder**2
        attention = 2 * (2 * sum_sq_windows * embed_size)
        return batch_size * (projections + attention)

    flops = compute_flops(seq_length, window_size)

    # Inference-only profiling: do not build an autograd graph.
    with torch.no_grad():
        # Measure memory usage.
        if is_cuda:
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.synchronize(dev)
            baseline = torch.cuda.memory_allocated(dev)
            torch.cuda.reset_peak_memory_stats(dev)
            _ = model(x)
            memory_usage = (torch.cuda.max_memory_allocated(dev) - baseline) / 1024**2  # MB
        else:
            import psutil  # the RSS-based CPU measurement relies on psutil

            process = psutil.Process()
            initial_memory = process.memory_info().rss / 1024**2  # MB
            _ = model(x)
            final_memory = process.memory_info().rss / 1024**2  # MB
            memory_usage = final_memory - initial_memory

        # Measure time on a second pass (the pass above doubles as a warm-up).
        # CUDA kernels run asynchronously, so synchronize around the timed region.
        if is_cuda:
            torch.cuda.synchronize(dev)
        start_time = time.perf_counter()
        _ = model(x)
        if is_cuda:
            torch.cuda.synchronize(dev)
        wall_clock_time = time.perf_counter() - start_time

    return flops, memory_usage, wall_clock_time



In [None]:
# Profile each input length N times. Fall back to the CPU so the cell also
# runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)  # reproducible random weights and inputs across re-runs

input_lengths = [10, 100, 1000, 10000, 100000, 1000000]
N = 50  # repetitions per input length
flops_list, memory_list, time_list = [], [], []
# Repetitions are the outer loop and lengths the inner loop; the DataFrame
# cell relies on this order (it tiles `input_lengths` N times).
for repetition in range(N):
    for length in input_lengths:
        flops, memory, wall_time = profile_self_attention(length, device=DEVICE)
        flops_list.append(flops)
        memory_list.append(memory)
        time_list.append(wall_time)

In [74]:
import pandas as pd
# One row per (repetition, length) measurement. The profiling loop runs
# repetitions in the outer loop and lengths in the inner loop, so the length
# column is `input_lengths` tiled N times in the same order.
profiling = pd.DataFrame(
    dict(
        input_length=input_lengths * N,
        flops=flops_list,
        mem=memory_list,
        time=time_list,
    )
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# Plotting: one log-log panel per metric. Each input length was profiled N
# times; seaborn's lineplot aggregates the repeats per x value (a mean line
# plus an error band by default).
panels = [
    # (DataFrame column, y-axis label, panel title).
    # "FLOPs" = operation count; "FLOPS" would mean operations per second.
    ("flops", "FLOPs", "Computational Complexity"),
    ("mem", "Memory Usage (MB)", "Memory Usage"),
    ("time", "Wall Clock Time (s)", "Wall Clock Time"),
]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (column, ylabel, title) in zip(axes, panels):
    sns.lineplot(data=profiling, x="input_length", y=column, ax=ax)
    ax.set(xscale="log", yscale="log", xlabel="Input Length", ylabel=ylabel, title=title)

fig.tight_layout()
plt.show()
