In [1]:
import numpy as np
from torch import nn
import torch
import sys
sys.path.append("../")
from libs.layers import linear_attention
from libs.layers import causal_linear_attn
from libs.layers import attention
from torch.nn.init import xavier_uniform_, constant_, xavier_normal_
import copy
import gc
torch.cuda.set_device(4)
from torch.nn.parameter import Parameter
import time
# import sys
# sys.executable

Package not found.


In [2]:
class original_SimpleAttention(nn.Module):
    """Multi-head attention wrapper: project Q/K/V, split heads, call an attention kernel.

    Each of query, key and value goes through its own ``nn.Linear(d_model, d_model)``
    and is reshaped to ``(bsz, n_head, -1, d_k)`` with ``d_k = d_model // n_head``.
    Two of the three tensors are optionally normalized per head, positional
    coordinates are optionally concatenated in front of every head's features,
    and the result is handed to ``linear_attention``, ``causal_linear_attn`` or
    ``attention`` (all imported from ``libs.layers``) depending on
    ``attention_type``.

    Args:
        n_head: number of heads; must divide ``d_model``.
        d_model: feature width of the query/key/value inputs.
        pos_dim: width of the positional tensor concatenated to each head. When
            ``> 0`` an output layer ``fc`` mapping ``d_model + n_head * pos_dim``
            to ``d_model`` is created.
        attention_type: ``'linear'``, ``'galerkin'`` and ``'global'`` select
            ``linear_attention``; ``'causal'`` selects ``causal_linear_attn``;
            anything else (default ``'fourier'``) selects ``attention``.
        dropout: probability of the ``nn.Dropout`` passed to the attention function.
        xavier_init: gain for the Xavier-uniform init of the three projections;
            the custom init is skipped when this is ``<= 0``.
        diagonal_weight: multiple of the identity added to each projection weight
            after the Xavier init (skipped when ``<= 0``).
        symmetric_init: if true, each projection weight ``W`` becomes ``W + W.T``.
        norm: whether to create and apply per-head normalization.
        norm_type: ``'layer'`` or ``'instance'``.
        eps: epsilon forwarded to the normalization layers.
        debug: stored on the instance; not otherwise used by this class.
    """

    # attention_type values that are routed to ``linear_attention``; for these
    # the key and value are normalized, otherwise the key and query are.
    _LINEAR_TYPES = ('linear', 'galerkin', 'global')

    def __init__(self, n_head, d_model,
                 pos_dim: int = 1,
                 attention_type='fourier',
                 dropout=0.1,
                 xavier_init=1e-4,
                 diagonal_weight=1e-2,
                 symmetric_init=False,
                 norm=False,
                 norm_type='layer',
                 eps=1e-5,
                 debug=False
                 ):
        super(original_SimpleAttention, self).__init__()
        assert d_model % n_head == 0
        self.attention_type = attention_type
        self.d_k = d_model // n_head
        self.n_head = n_head
        self.pos_dim = pos_dim
        # Independent projections for query, key and value, in that order.
        self.linears = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(3)])
        self.xavier_init = xavier_init
        self.diagonal_weight = diagonal_weight
        self.symmetric_init = symmetric_init
        if self.xavier_init > 0:
            self._reset_parameters()
        self.add_norm = norm
        self.norm_type = norm_type
        if norm:
            self._get_norm(eps=eps)

        if pos_dim > 0:
            self.fc = nn.Linear(d_model + n_head*pos_dim, d_model)

        self.attn_weight = None
        self.dropout = nn.Dropout(dropout)
        self.debug = debug

    def forward(self, query, key, value, pos=None, mask=None, weight=None):
        """Run the attention block.

        Args:
            query, key, value: tensors whose first axis is the batch and whose
                last axis has width ``d_model``.
            pos: optional positional tensor with last axis ``pos_dim``; it is
                broadcast to every head and concatenated in front of the
                per-head features of query, key and value.
            mask: optional mask; a head axis is inserted at dim 1 before it is
                forwarded to the attention function. Required (asserted) for
                ``attention_type == 'causal'``.
            weight: optional factor multiplied into query and key before the
                linear projections.

        Returns:
            ``(att_output, attn_weight)`` where ``att_output`` has the heads
            merged back into the last axis (and passed through ``fc`` when
            positions were concatenated) and ``attn_weight`` is the second
            value returned by the attention function.
        """
        if mask is not None:
            mask = mask.unsqueeze(1)

        bsz = query.size(0)
        if weight is not None:
            query, key = weight*query, weight*key

        # (bsz, n, d_model) -> (bsz, n_head, n, d_k)
        query, key, value = [
            layer(x).view(bsz, -1, self.n_head, self.d_k).transpose(1, 2)
            for layer, x in zip(self.linears, (query, key, value))]

        if self.add_norm:
            query, key, value = self._apply_norm(query, key, value)

        if pos is not None and self.pos_dim > 0:
            assert pos.size(-1) == self.pos_dim
            pos = pos.unsqueeze(1)
            pos = pos.repeat([1, self.n_head, 1, 1])
            query, key, value = [torch.cat([pos, x], dim=-1)
                                 for x in (query, key, value)]

        x, self.attn_weight = self._attend(query, key, value, mask)

        out_dim = self.n_head * self.d_k if pos is None else self.n_head * \
            (self.d_k + self.pos_dim)
        att_output = x.transpose(1, 2).contiguous().view(bsz, -1, out_dim)

        if pos is not None and self.pos_dim > 0:
            att_output = self.fc(att_output)

        return att_output, self.attn_weight

    def _attend(self, query, key, value, mask):
        """Call the attention function selected by ``attention_type``."""
        if self.attention_type in self._LINEAR_TYPES:
            return linear_attention(query, key, value,
                                    mask=mask,
                                    attention_type=self.attention_type,
                                    dropout=self.dropout)
        if self.attention_type == 'causal':
            assert mask is not None
            return causal_linear_attn(query, key, value,
                                      mask=mask,
                                      dropout=self.dropout)
        return attention(query, key, value,
                         mask=mask,
                         attention_type=self.attention_type,
                         dropout=self.dropout)

    def _apply_norm(self, query, key, value):
        """Normalize key+value (linear-type attention) or key+query (otherwise), per head."""
        channels_first = self.norm_type == 'instance'
        key = self._norm_per_head(self.norm_K, key, channels_first)
        if self.attention_type in self._LINEAR_TYPES:
            value = self._norm_per_head(self.norm_V, value, channels_first)
        else:
            # The normalized *query* is kept here; the previous code overwrote
            # it with a transposed copy of ``value`` in the instance-norm case.
            query = self._norm_per_head(self.norm_Q, query, channels_first)
        return query, key, value

    @staticmethod
    def _norm_per_head(norms, x, channels_first):
        """Apply ``norms[i]`` to head ``i`` of ``x`` and restack along the head axis.

        With ``channels_first`` the last two axes are swapped around the norm so
        that the feature axis sits where ``nn.InstanceNorm1d`` expects channels.
        """
        if channels_first:
            x = x.transpose(-2, -1)
        x = torch.stack(
            [norm(x[:, i, ...]) for i, norm in enumerate(norms)], dim=1)
        if channels_first:
            x = x.transpose(-2, -1)
        return x

    def _reset_parameters(self):
        """Re-initialize the Q/K/V projections: scaled Xavier weights, zero biases."""
        for param in self.linears.parameters():
            if param.ndim > 1:
                xavier_uniform_(param, gain=self.xavier_init)
                if self.diagonal_weight > 0.0:
                    param.data += self.diagonal_weight * \
                        torch.eye(param.size(-1), dtype=torch.float)
                if self.symmetric_init:
                    # Clone the transpose so the in-place add never reads from
                    # memory it is writing to.
                    param.data += param.data.T.clone()
                    # param.data /= 2.0
            else:
                constant_(param, 0)

    def _get_norm(self, eps):
        """Create the per-head norm lists: ``norm_K`` plus ``norm_V`` or ``norm_Q``.

        Nothing is created for a ``norm_type`` other than ``'instance'`` or
        ``'layer'``.
        """
        if self.norm_type == 'instance':
            def make_norms():
                return self._get_instancenorm(self.d_k, self.n_head,
                                              eps=eps, affine=True)
        elif self.norm_type == 'layer':
            def make_norms():
                return self._get_layernorm(self.d_k, self.n_head, eps=eps)
        else:
            return

        self.norm_K = make_norms()
        if self.attention_type in self._LINEAR_TYPES:
            self.norm_V = make_norms()
        else:
            self.norm_Q = make_norms()

    @staticmethod
    def _get_layernorm(normalized_dim, n_head, **kwargs):
        """Return ``n_head`` independent ``nn.LayerNorm(normalized_dim)`` modules."""
        return nn.ModuleList(
            [nn.LayerNorm(normalized_dim, **kwargs) for _ in range(n_head)])

    @staticmethod
    def _get_instancenorm(normalized_dim, n_head, **kwargs):
        """Return ``n_head`` independent ``nn.InstanceNorm1d(normalized_dim)`` modules."""
        return nn.ModuleList(
            [nn.InstanceNorm1d(normalized_dim, **kwargs) for _ in range(n_head)])
    

In [3]:
class lin_SimpleAttention(nn.Module):
    """Multi-head attention whose keys and values are projected along the sequence axis.

    Query, key and value each go through their own ``nn.Linear(d_model, d_model)``
    and are reshaped to ``(bsz, n_head, -1, d_k)`` with ``d_k = d_model // n_head``.
    After the optional per-head normalization and the optional positional
    concatenation, the sequence axis of the key and of the value is contracted
    with the learned matrices ``E_weight`` / ``F_weight`` of shape
    ``(seq_len, k)`` and the biases ``E_bias`` / ``F_bias`` of shape ``(k, 1)``
    are added, so both end up with ``k`` rows. The result is handed to
    ``linear_attention``, ``causal_linear_attn`` or ``attention`` (imported
    from ``libs.layers``) depending on ``attention_type``.

    Args:
        n_head: number of heads; must divide ``d_model``.
        d_model: feature width of the query/key/value inputs.
        k: number of rows the key and value are projected to.
        pos_dim: width of the positional tensor concatenated to each head. When
            ``> 0`` an output layer ``fc`` mapping ``d_model + n_head * pos_dim``
            to ``d_model`` is created.
        attention_type: ``'linear'``, ``'galerkin'`` and ``'global'`` select
            ``linear_attention``; ``'causal'`` selects ``causal_linear_attn``;
            anything else (default ``'fourier'``) selects ``attention``.
        dropout: probability of the ``nn.Dropout`` passed to the attention function.
        xavier_init: gain for the Xavier-uniform init of the three projections;
            the custom init is skipped when this is ``<= 0``.
        diagonal_weight: multiple of the identity added to each projection weight
            after the Xavier init (skipped when ``<= 0``).
        symmetric_init: if true, each projection weight ``W`` becomes ``W + W.T``.
        norm: whether to create and apply per-head normalization.
        norm_type: ``'layer'`` or ``'instance'``.
        eps: epsilon forwarded to the normalization layers.
        debug: stored on the instance; not otherwise used by this class.
        seq_len: sequence length of the key/value inputs; it fixes the first
            dimension of ``E_weight`` and ``F_weight``.
    """

    # attention_type values that are routed to ``linear_attention``; for these
    # the key and value are normalized, otherwise the key and query are.
    _LINEAR_TYPES = ('linear', 'galerkin', 'global')

    def __init__(self, n_head, d_model,k,
                 pos_dim: int = 1,
                 attention_type='fourier',
                 dropout=0.1,
                 xavier_init=1e-4,
                 diagonal_weight=1e-2,
                 symmetric_init=False,
                 norm=False,
                 norm_type='layer',
                 eps=1e-5,
                 debug=False,
                 seq_len = 2048,
                 ):
        super(lin_SimpleAttention, self).__init__()
        assert d_model % n_head == 0
        self.attention_type = attention_type
        self.d_k = d_model // n_head
        self.n_head = n_head
        self.pos_dim = pos_dim
        # Independent projections for query, key and value, in that order.
        self.linears = nn.ModuleList(
            [nn.Linear(d_model, d_model) for _ in range(3)])
        self.xavier_init = xavier_init
        self.diagonal_weight = diagonal_weight
        self.symmetric_init = symmetric_init
        if self.xavier_init > 0:
            self._reset_parameters()
        self.add_norm = norm
        self.norm_type = norm_type
        if norm:
            self._get_norm(eps=eps)

        if pos_dim > 0:
            self.fc = nn.Linear(d_model + n_head*pos_dim, d_model)

        self.attn_weight = None
        self.dropout = nn.Dropout(dropout)
        self.debug = debug

        # Sequence-axis projections for the key (E) and the value (F).
        self.E_weight = Parameter(torch.empty(seq_len, k))
        self.F_weight = Parameter(torch.empty(seq_len, k))
        self.E_bias = Parameter(torch.empty(k, 1))
        self.F_bias = Parameter(torch.empty(k, 1))
        xavier_normal_(self.E_weight)
        xavier_normal_(self.F_weight)
        xavier_normal_(self.E_bias)
        xavier_normal_(self.F_bias)

    def forward(self, query, key, value, pos=None, mask=None, weight=None):
        """Run the attention block.

        Args:
            query, key, value: tensors whose first axis is the batch and whose
                last axis has width ``d_model``; the sequence length of key and
                value must equal ``seq_len`` for the E/F contraction.
            pos: optional positional tensor with last axis ``pos_dim``; it is
                broadcast to every head and concatenated in front of the
                per-head features of query, key and value.
            mask: optional mask; a head axis is inserted at dim 1 before it is
                forwarded to the attention function. Required (asserted) for
                ``attention_type == 'causal'``.
            weight: optional factor multiplied into query and key before the
                linear projections.

        Returns:
            ``(att_output, attn_weight)`` where ``att_output`` has the heads
            merged back into the last axis (and passed through ``fc`` when
            positions were concatenated) and ``attn_weight`` is the second
            value returned by the attention function.
        """
        if mask is not None:
            mask = mask.unsqueeze(1)

        bsz = query.size(0)
        if weight is not None:
            query, key = weight*query, weight*key

        # (bsz, n, d_model) -> (bsz, n_head, n, d_k)
        query, key, value = [
            layer(x).view(bsz, -1, self.n_head, self.d_k).transpose(1, 2)
            for layer, x in zip(self.linears, (query, key, value))]

        if self.add_norm:
            query, key, value = self._apply_norm(query, key, value)

        if pos is not None and self.pos_dim > 0:
            assert pos.size(-1) == self.pos_dim
            pos = pos.unsqueeze(1)
            pos = pos.repeat([1, self.n_head, 1, 1])
            # (bsz, n_head, n, d_k) -> (bsz, n_head, n, pos_dim + d_k)
            query, key, value = [torch.cat([pos, x], dim=-1)
                                 for x in (query, key, value)]

        # Contract the sequence axis: (bsz, n_head, seq_len, f) -> (bsz, n_head, k, f)
        key = torch.einsum("ijkl,kz->ijzl", key, self.E_weight) + self.E_bias
        value = torch.einsum("ijkl,kz->ijzl", value, self.F_weight) + self.F_bias

        x, self.attn_weight = self._attend(query, key, value, mask)

        out_dim = self.n_head * self.d_k if pos is None else self.n_head * \
            (self.d_k + self.pos_dim)
        att_output = x.transpose(1, 2).contiguous().view(bsz, -1, out_dim)

        if pos is not None and self.pos_dim > 0:
            att_output = self.fc(att_output)

        return att_output, self.attn_weight

    def _attend(self, query, key, value, mask):
        """Call the attention function selected by ``attention_type``."""
        if self.attention_type in self._LINEAR_TYPES:
            return linear_attention(query, key, value,
                                    mask=mask,
                                    attention_type=self.attention_type,
                                    dropout=self.dropout)
        if self.attention_type == 'causal':
            assert mask is not None
            return causal_linear_attn(query, key, value,
                                      mask=mask,
                                      dropout=self.dropout)
        return attention(query, key, value,
                         mask=mask,
                         attention_type=self.attention_type,
                         dropout=self.dropout)

    def _apply_norm(self, query, key, value):
        """Normalize key+value (linear-type attention) or key+query (otherwise), per head."""
        channels_first = self.norm_type == 'instance'
        key = self._norm_per_head(self.norm_K, key, channels_first)
        if self.attention_type in self._LINEAR_TYPES:
            value = self._norm_per_head(self.norm_V, value, channels_first)
        else:
            # The normalized *query* is kept here; the previous code overwrote
            # it with a transposed copy of ``value`` in the instance-norm case.
            query = self._norm_per_head(self.norm_Q, query, channels_first)
        return query, key, value

    @staticmethod
    def _norm_per_head(norms, x, channels_first):
        """Apply ``norms[i]`` to head ``i`` of ``x`` and restack along the head axis.

        With ``channels_first`` the last two axes are swapped around the norm so
        that the feature axis sits where ``nn.InstanceNorm1d`` expects channels.
        """
        if channels_first:
            x = x.transpose(-2, -1)
        x = torch.stack(
            [norm(x[:, i, ...]) for i, norm in enumerate(norms)], dim=1)
        if channels_first:
            x = x.transpose(-2, -1)
        return x

    def _reset_parameters(self):
        """Re-initialize the Q/K/V projections: scaled Xavier weights, zero biases."""
        for param in self.linears.parameters():
            if param.ndim > 1:
                xavier_uniform_(param, gain=self.xavier_init)
                if self.diagonal_weight > 0.0:
                    param.data += self.diagonal_weight * \
                        torch.eye(param.size(-1), dtype=torch.float)
                if self.symmetric_init:
                    # Clone the transpose so the in-place add never reads from
                    # memory it is writing to.
                    param.data += param.data.T.clone()
                    # param.data /= 2.0
            else:
                constant_(param, 0)

    def _get_norm(self, eps):
        """Create the per-head norm lists: ``norm_K`` plus ``norm_V`` or ``norm_Q``.

        Nothing is created for a ``norm_type`` other than ``'instance'`` or
        ``'layer'``.
        """
        if self.norm_type == 'instance':
            def make_norms():
                return self._get_instancenorm(self.d_k, self.n_head,
                                              eps=eps, affine=True)
        elif self.norm_type == 'layer':
            def make_norms():
                return self._get_layernorm(self.d_k, self.n_head, eps=eps)
        else:
            return

        self.norm_K = make_norms()
        if self.attention_type in self._LINEAR_TYPES:
            self.norm_V = make_norms()
        else:
            self.norm_Q = make_norms()

    @staticmethod
    def _get_layernorm(normalized_dim, n_head, **kwargs):
        """Return ``n_head`` independent ``nn.LayerNorm(normalized_dim)`` modules."""
        return nn.ModuleList(
            [nn.LayerNorm(normalized_dim, **kwargs) for _ in range(n_head)])

    @staticmethod
    def _get_instancenorm(normalized_dim, n_head, **kwargs):
        """Return ``n_head`` independent ``nn.InstanceNorm1d(normalized_dim)`` modules."""
        return nn.ModuleList(
            [nn.InstanceNorm1d(normalized_dim, **kwargs) for _ in range(n_head)])

In [4]:
class svd_SimpleAttention(nn.Module):
    """Multi-head attention without Q/K/V input projections.

    The raw query, key and value are only split into heads of width
    ``d_k = d_model // n_head``. After the optional positional concatenation
    the query goes through one shared ``nn.Linear(d_k + pos_dim, d_k + pos_dim)``,
    while the sequence axis of the key and of the value is contracted with the
    learned matrices ``E_weight`` / ``F_weight`` of shape ``(seq_len, k)`` and
    the biases ``E_bias`` / ``F_bias`` of shape ``(k, 1)`` are added. Optional
    per-head normalization follows, and the result is handed to
    ``linear_attention``, ``causal_linear_attn`` or ``attention`` (imported
    from ``libs.layers``) depending on ``attention_type``.

    Because ``self.linear`` and the norm layers are sized ``d_k + pos_dim``,
    ``forward`` needs ``pos`` whenever ``pos_dim > 0``.

    Args:
        n_head: number of heads; must divide ``d_model``.
        d_model: feature width of the query/key/value inputs.
        k: number of rows the key and value are projected to.
        pos_dim: width of the positional tensor concatenated to each head. When
            ``> 0`` an output layer ``fc`` mapping ``d_model + n_head * pos_dim``
            to ``d_model`` is created.
        attention_type: ``'linear'``, ``'galerkin'`` and ``'global'`` select
            ``linear_attention``; ``'causal'`` selects ``causal_linear_attn``;
            anything else (default ``'fourier'``) selects ``attention``.
        dropout: probability of the ``nn.Dropout`` passed to the attention function.
        xavier_init, diagonal_weight, symmetric_init: stored on the instance and
            read by ``_reset_parameters``, which ``__init__`` does not call.
        norm: whether to create and apply per-head normalization.
        norm_type: ``'layer'`` or ``'instance'``.
        eps: epsilon forwarded to the normalization layers.
        debug: stored on the instance; not otherwise used by this class.
        seq_len: sequence length of the key/value inputs; it fixes the first
            dimension of ``E_weight`` and ``F_weight``.
    """

    # attention_type values that are routed to ``linear_attention``; for these
    # the key and value are normalized, otherwise the key and query are.
    _LINEAR_TYPES = ('linear', 'galerkin', 'global')

    def __init__(self, n_head, d_model,k=2048,
                 pos_dim: int = 1,
                 attention_type='fourier',
                 dropout=0.1,
                 xavier_init=1e-4,
                 diagonal_weight=1e-2,
                 symmetric_init=False,
                 norm=False,
                 norm_type='layer',
                 eps=1e-5,
                 debug=False,
                 seq_len = 2048,
                 ):
        super(svd_SimpleAttention, self).__init__()
        assert d_model % n_head == 0
        self.attention_type = attention_type
        self.d_k = d_model // n_head
        self.n_head = n_head
        self.pos_dim = pos_dim

        self.xavier_init = xavier_init
        self.diagonal_weight = diagonal_weight
        self.symmetric_init = symmetric_init

        self.add_norm = norm
        self.norm_type = norm_type
        if norm:
            self._get_norm(eps=eps)

        if pos_dim > 0:
            self.fc = nn.Linear(d_model + n_head*pos_dim, d_model)

        self.attn_weight = None
        self.dropout = nn.Dropout(dropout)
        self.debug = debug

        # Sequence-axis projections for the key (E) and the value (F).
        self.E_weight = Parameter(torch.empty(seq_len, k))
        self.F_weight = Parameter(torch.empty(seq_len, k))
        self.E_bias = Parameter(torch.empty(k, 1))
        self.F_bias = Parameter(torch.empty(k, 1))
        xavier_normal_(self.E_weight)
        xavier_normal_(self.F_weight)
        xavier_normal_(self.E_bias)
        xavier_normal_(self.F_bias)
        # Single feature-axis projection, shared by all heads, applied to the query.
        self.linear = nn.Linear(self.d_k + pos_dim, self.d_k + pos_dim)

    def forward(self, query, key, value, pos=None, mask=None, weight=None):
        """Run the attention block.

        Args:
            query, key, value: tensors whose first axis is the batch and whose
                last axis has width ``d_model``; the sequence length of key and
                value must equal ``seq_len`` for the E/F contraction.
            pos: positional tensor with last axis ``pos_dim``; it is broadcast
                to every head and concatenated in front of the per-head
                features of query, key and value.
            mask: optional mask; a head axis is inserted at dim 1 before it is
                forwarded to the attention function. Required (asserted) for
                ``attention_type == 'causal'``.
            weight: optional factor multiplied into query and key.

        Returns:
            ``(att_output, attn_weight)`` where ``att_output`` has the heads
            merged back into the last axis (and passed through ``fc`` when
            positions were concatenated) and ``attn_weight`` is the second
            value returned by the attention function.
        """
        if mask is not None:
            mask = mask.unsqueeze(1)

        bsz = query.size(0)
        if weight is not None:
            query, key = weight*query, weight*key

        # (bsz, n, d_model) -> (bsz, n_head, n, d_k); no learned input projection.
        query, key, value = [
            x.view(bsz, -1, self.n_head, self.d_k).transpose(1, 2)
            for x in (query, key, value)]

        if pos is not None and self.pos_dim > 0:
            assert pos.size(-1) == self.pos_dim
            pos = pos.unsqueeze(1)
            pos = pos.repeat([1, self.n_head, 1, 1])
            query, key, value = [torch.cat([pos, x], dim=-1)
                                 for x in (query, key, value)]

        query = self.linear(query)
        # Contract the sequence axis: (bsz, n_head, seq_len, f) -> (bsz, n_head, k, f)
        key = torch.einsum("ijkl,kz->ijzl", key, self.E_weight) + self.E_bias
        value = torch.einsum("ijkl,kz->ijzl", value, self.F_weight) + self.F_bias

        if self.add_norm:
            query, key, value = self._apply_norm(query, key, value)

        x, self.attn_weight = self._attend(query, key, value, mask)

        out_dim = self.n_head * self.d_k if pos is None else self.n_head * \
            (self.d_k + self.pos_dim)
        att_output = x.transpose(1, 2).contiguous().view(bsz, -1, out_dim)

        if pos is not None and self.pos_dim > 0:
            att_output = self.fc(att_output)

        return att_output, self.attn_weight

    def _attend(self, query, key, value, mask):
        """Call the attention function selected by ``attention_type``."""
        if self.attention_type in self._LINEAR_TYPES:
            return linear_attention(query, key, value,
                                    mask=mask,
                                    attention_type=self.attention_type,
                                    dropout=self.dropout)
        if self.attention_type == 'causal':
            assert mask is not None
            return causal_linear_attn(query, key, value,
                                      mask=mask,
                                      dropout=self.dropout)
        return attention(query, key, value,
                         mask=mask,
                         attention_type=self.attention_type,
                         dropout=self.dropout)

    def _apply_norm(self, query, key, value):
        """Normalize key+value (linear-type attention) or key+query (otherwise), per head."""
        channels_first = self.norm_type == 'instance'
        key = self._norm_per_head(self.norm_K, key, channels_first)
        if self.attention_type in self._LINEAR_TYPES:
            value = self._norm_per_head(self.norm_V, value, channels_first)
        else:
            # The normalized *query* is kept here; the previous code overwrote
            # it with a transposed copy of ``value`` in the instance-norm case.
            query = self._norm_per_head(self.norm_Q, query, channels_first)
        return query, key, value

    @staticmethod
    def _norm_per_head(norms, x, channels_first):
        """Apply ``norms[i]`` to head ``i`` of ``x`` and restack along the head axis.

        With ``channels_first`` the last two axes are swapped around the norm so
        that the feature axis sits where ``nn.InstanceNorm1d`` expects channels.
        """
        if channels_first:
            x = x.transpose(-2, -1)
        x = torch.stack(
            [norm(x[:, i, ...]) for i, norm in enumerate(norms)], dim=1)
        if channels_first:
            x = x.transpose(-2, -1)
        return x

    def _reset_parameters(self):
        """Re-initialize ``self.linear``: scaled Xavier weight, zero bias.

        Not called by ``__init__``. It previously iterated ``self.linears``,
        an attribute this class never defines, so any call raised
        ``AttributeError``.
        """
        for param in self.linear.parameters():
            if param.ndim > 1:
                xavier_uniform_(param, gain=self.xavier_init)
                if self.diagonal_weight > 0.0:
                    param.data += self.diagonal_weight * \
                        torch.eye(param.size(-1), dtype=torch.float)
                if self.symmetric_init:
                    # Clone the transpose so the in-place add never reads from
                    # memory it is writing to.
                    param.data += param.data.T.clone()
                    # param.data /= 2.0
            else:
                constant_(param, 0)

    def _get_norm(self, eps):
        """Create the per-head norm lists: ``norm_K`` plus ``norm_V`` or ``norm_Q``.

        Both norm types are sized ``d_k + pos_dim`` because ``forward``
        normalizes after the positional concatenation. Nothing is created for
        a ``norm_type`` other than ``'instance'`` or ``'layer'``.
        """
        feature_dim = self.d_k + self.pos_dim
        if self.norm_type == 'instance':
            def make_norms():
                return self._get_instancenorm(feature_dim, self.n_head,
                                              eps=eps, affine=True)
        elif self.norm_type == 'layer':
            def make_norms():
                return self._get_layernorm(feature_dim, self.n_head, eps=eps)
        else:
            return

        self.norm_K = make_norms()
        if self.attention_type in self._LINEAR_TYPES:
            self.norm_V = make_norms()
        else:
            self.norm_Q = make_norms()

    @staticmethod
    def _get_layernorm(normalized_dim, n_head, **kwargs):
        """Return ``n_head`` independent ``nn.LayerNorm(normalized_dim)`` modules."""
        return nn.ModuleList(
            [nn.LayerNorm(normalized_dim, **kwargs) for _ in range(n_head)])

    @staticmethod
    def _get_instancenorm(normalized_dim, n_head, **kwargs):
        """Return ``n_head`` independent ``nn.InstanceNorm1d(normalized_dim)`` modules."""
        return nn.ModuleList(
            [nn.InstanceNorm1d(normalized_dim, **kwargs) for _ in range(n_head)])

In [5]:
# Sizes shared by the profiling and timing cells below.
d, d_1 = 288, 1   # passed as d_model and pos_dim when the models are built
k = 5             # reassigned by the `for k in range(3, 10)` timing loops
N = 4096          # sequence length of the dummy inputs
n_head = 4        # number of attention heads

In [14]:
from thop import profile
from thop import clever_format
from ptflops import get_model_complexity_info

In [15]:
# Random CPU tensors shaped like the model inputs: batch of 1, N positions.
# The generator is consumed in order, so Q, K and V are drawn in that sequence.
Q, K, V = (torch.randn(1, N, d) for _ in range(3))
pos = torch.randn(1, N, d_1)

In [30]:
def input_constructor(x):
    """Turn a tuple of shapes into the keyword arguments of a forward pass.

    Passed as ``input_constructor`` to ``get_model_complexity_info``.

    Args:
        x: sequence of shapes; the first three are used for ``query``, ``key``
            and ``value`` and the last one for ``pos``.

    Returns:
        dict with keys ``"query"``, ``"key"``, ``"value"`` (standard-normal
        tensors) and ``"pos"`` (uniform on ``[0, 1)``), in that order.
    """
    shapes = {"query": x[0], "key": x[1], "value": x[2]}
    inputs = {name: torch.randn(shape) for name, shape in shapes.items()}
    inputs["pos"] = torch.rand(x[-1])
    return inputs

In [42]:
# Parameter and MAC count of the baseline attention block via ptflops.
attention_type = "fourier"
model = original_SimpleAttention(
    n_head=n_head, d_model=d, pos_dim=d_1, attention_type=attention_type,
)
# Three (1, N, d) shapes for query/key/value followed by the (1, N, d_1) pos shape.
flops, params = get_model_complexity_info(
    model,
    ((1, N, d),) * 3 + ((1, N, d_1),),
    input_constructor=input_constructor,
    as_strings=True,
    print_per_layer_stat=True,
)

original_SimpleAttention(
  334.08 k, 100.000% Params, 1.36 GMac, 100.000% MACs, 
  (linears): ModuleList(
    249.7 k, 74.741% Params, 1.02 GMac, 74.740% MACs, 
    (0): Linear(83.23 k, 24.914% Params, 339.74 MMac, 24.913% MACs, in_features=288, out_features=288, bias=True)
    (1): Linear(83.23 k, 24.914% Params, 339.74 MMac, 24.913% MACs, in_features=288, out_features=288, bias=True)
    (2): Linear(83.23 k, 24.914% Params, 339.74 MMac, 24.913% MACs, in_features=288, out_features=288, bias=True)
  )
  (fc): Linear(84.38 k, 25.259% Params, 344.46 MMac, 25.260% MACs, in_features=292, out_features=288, bias=True)
  (dropout): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.1, inplace=False)
)


In [None]:
# Benchmark: mean forward time of original_SimpleAttention with "fourier" attention.
attention_type = "fourier"
model = original_SimpleAttention(n_head=n_head, d_model=d, pos_dim=d_1,
                                 attention_type=attention_type)
device = torch.device('cuda')
model.to(device)
dummy_input = torch.randn(1, N, d, dtype=torch.float).to(device)
dummy_pos = torch.randn(1, N, d_1, dtype=torch.float).to(device)
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
repetitions = 300
timings = np.zeros((repetitions, 1))
with torch.no_grad():
    # GPU warm-up: run the dummy example a few times, untimed. It runs under
    # no_grad like the measured loop so no autograd graph is built.
    for _ in range(10):
        _ = model(dummy_input, dummy_input, dummy_input, dummy_pos)
    # Let the warm-up finish before the first timed event is recorded.
    torch.cuda.synchronize()
    # Measure performance.
    for rep in range(repetitions):
        starter.record()
        _ = model(dummy_input, dummy_input, dummy_input, dummy_pos)
        ender.record()
        # Wait for the GPU so that both events have completed.
        torch.cuda.synchronize()
        curr_time = starter.elapsed_time(ender)
        timings[rep] = curr_time
mean_syn = timings.mean()
std_syn = timings.std()
print(f"{attention_type}:", mean_syn)
del model

In [None]:
# Benchmark: mean forward time of original_SimpleAttention with "galerkin" attention.
attention_type = "galerkin"
model = original_SimpleAttention(n_head=n_head, d_model=d, pos_dim=d_1,
                                 attention_type=attention_type)
device = torch.device('cuda')
model.to(device)
dummy_input = torch.randn(1, N, d, dtype=torch.float).to(device)
dummy_pos = torch.randn(1, N, d_1, dtype=torch.float).to(device)
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
repetitions = 300
timings = np.zeros((repetitions, 1))
with torch.no_grad():
    # GPU warm-up: run the dummy example a few times, untimed. It runs under
    # no_grad like the measured loop so no autograd graph is built.
    for _ in range(10):
        _ = model(dummy_input, dummy_input, dummy_input, dummy_pos)
    # Let the warm-up finish before the first timed event is recorded.
    torch.cuda.synchronize()
    # Measure performance.
    for rep in range(repetitions):
        starter.record()
        _ = model(dummy_input, dummy_input, dummy_input, dummy_pos)
        ender.record()
        # Wait for the GPU so that both events have completed.
        torch.cuda.synchronize()
        curr_time = starter.elapsed_time(ender)
        timings[rep] = curr_time
mean_syn = timings.mean()
std_syn = timings.std()
print(f"{attention_type}:", mean_syn)
del model

In [None]:
# Benchmark: mean forward time of lin_SimpleAttention ("fourier") for k = 3..9.
attention_type = "fourier"
for k in range(3, 10):
    # Build the model from `attention_type` so the printed label always
    # matches what was actually timed.
    model = lin_SimpleAttention(n_head=n_head, d_model=d, k=k, pos_dim=d_1,
                                seq_len=N, attention_type=attention_type)
    device = torch.device('cuda')
    model.to(device)
    dummy_input = torch.randn(1, N, d, dtype=torch.float).to(device)
    dummy_pos = torch.randn(1, N, d_1, dtype=torch.float).to(device)
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    repetitions = 300
    timings = np.zeros((repetitions, 1))
    with torch.no_grad():
        # GPU warm-up: run the dummy example a few times, untimed. It runs
        # under no_grad like the measured loop so no autograd graph is built.
        for _ in range(10):
            _ = model(dummy_input, dummy_input, dummy_input, dummy_pos)
        # Let the warm-up finish before the first timed event is recorded.
        torch.cuda.synchronize()
        # Measure performance.
        for rep in range(repetitions):
            starter.record()
            _ = model(dummy_input, dummy_input, dummy_input, dummy_pos)
            ender.record()
            # Wait for the GPU so that both events have completed.
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)
            timings[rep] = curr_time
    mean_syn = timings.mean()
    print(f"lin {k} {attention_type}:", mean_syn)
    del model

In [None]:
# Benchmark: mean forward time of lin_SimpleAttention ("galerkin") for k = 3..9.
attention_type = "galerkin"
for k in range(3, 10):
    model = lin_SimpleAttention(n_head=n_head, d_model=d, k=k, pos_dim=d_1,
                                seq_len=N, attention_type=attention_type)
    device = torch.device('cuda')
    model.to(device)
    dummy_input = torch.randn(1, N, d, dtype=torch.float).to(device)
    dummy_pos = torch.randn(1, N, d_1, dtype=torch.float).to(device)
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    repetitions = 300
    timings = np.zeros((repetitions, 1))
    with torch.no_grad():
        # GPU warm-up: run the dummy example a few times, untimed. It runs
        # under no_grad like the measured loop so no autograd graph is built.
        for _ in range(10):
            _ = model(dummy_input, dummy_input, dummy_input, dummy_pos)
        # Let the warm-up finish before the first timed event is recorded.
        torch.cuda.synchronize()
        # Measure performance.
        for rep in range(repetitions):
            starter.record()
            _ = model(dummy_input, dummy_input, dummy_input, dummy_pos)
            ender.record()
            # Wait for the GPU so that both events have completed.
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)
            timings[rep] = curr_time
    mean_syn = timings.mean()
    print(f"lin {k} {attention_type}:", mean_syn)
    del model

In [None]:
# Benchmark: mean forward time of svd_SimpleAttention ("fourier") for k = 3..9.
attention_type = "fourier"
for k in range(3, 10):
    model = svd_SimpleAttention(n_head=n_head, d_model=d, k=k, pos_dim=d_1,
                                seq_len=N, attention_type=attention_type)
    device = torch.device('cuda')
    model.to(device)
    dummy_input = torch.randn(1, N, d, dtype=torch.float).to(device)
    dummy_pos = torch.randn(1, N, d_1, dtype=torch.float).to(device)
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    repetitions = 300
    timings = np.zeros((repetitions, 1))
    with torch.no_grad():
        # GPU warm-up: run the dummy example a few times, untimed. It runs
        # under no_grad like the measured loop so no autograd graph is built.
        for _ in range(10):
            _ = model(dummy_input, dummy_input, dummy_input, dummy_pos)
        # Let the warm-up finish before the first timed event is recorded.
        torch.cuda.synchronize()
        # Measure performance.
        for rep in range(repetitions):
            starter.record()
            _ = model(dummy_input, dummy_input, dummy_input, dummy_pos)
            ender.record()
            # Wait for the GPU so that both events have completed.
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)
            timings[rep] = curr_time
    mean_syn = timings.mean()
    print(f"svd {k} {attention_type}:", mean_syn)
    del model

In [None]:
# Benchmark: mean forward time of svd_SimpleAttention ("galerkin") for k = 3..9.
attention_type = "galerkin"
for k in range(3, 10):
    model = svd_SimpleAttention(n_head=n_head, d_model=d, k=k, pos_dim=d_1,
                                seq_len=N, attention_type=attention_type)
    device = torch.device('cuda')
    model.to(device)
    dummy_input = torch.randn(1, N, d, dtype=torch.float).to(device)
    dummy_pos = torch.randn(1, N, d_1, dtype=torch.float).to(device)
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
    repetitions = 300
    timings = np.zeros((repetitions, 1))
    with torch.no_grad():
        # GPU warm-up: run the dummy example a few times, untimed. It runs
        # under no_grad like the measured loop so no autograd graph is built.
        for _ in range(10):
            _ = model(dummy_input, dummy_input, dummy_input, dummy_pos)
        # Let the warm-up finish before the first timed event is recorded.
        torch.cuda.synchronize()
        # Measure performance.
        for rep in range(repetitions):
            starter.record()
            _ = model(dummy_input, dummy_input, dummy_input, dummy_pos)
            ender.record()
            # Wait for the GPU so that both events have completed.
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)
            timings[rep] = curr_time
    mean_syn = timings.mean()
    print(f"svd {k} {attention_type}:", mean_syn)
    del model