In [139]:
import torch
import torch.nn as nn
import torch.nn.functional as F
#from utils.attention_toolkit import *

## requires torch version >= 2.1.0

def count_parameters(model):
    """Return how many scalar parameters of *model* are trainable.

    Only parameters with ``requires_grad=True`` are counted; frozen ones are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total



## customized layer normalization
class MyLayerNorm(nn.Module):
    """Layer normalisation along a single axis with a learnable per-feature affine.

    Unlike ``nn.LayerNorm`` this divides by the (unbiased) standard deviation plus
    ``eps`` rather than by ``sqrt(var + eps)``.

    Args:
        features: size of the normalised axis; ``gamma`` and ``beta`` have this length.
        dim: axis to normalise over, ``-2`` (default) or ``-1``.
        eps: constant added to the standard deviation for numerical stability.
    """

    def __init__(self, features, dim = -2, eps=1e-5):
        super().__init__()
        assert dim in (-2, -1), 'dim must be -1 or -2'
        self.gamma = nn.Parameter(torch.ones(features))   # scale, initialised to 1
        self.beta = nn.Parameter(torch.zeros(features))   # shift, initialised to 0
        self.eps = eps
        self.dim = dim

    def forward(self, x):
        centred = x - x.mean(dim=self.dim, keepdim=True)
        spread = x.std(dim=self.dim, keepdim=True) + self.eps
        normalised = centred / spread
        scale, shift = self.gamma, self.beta
        if self.dim == -2:
            # the affine parameters index the second-to-last axis, so give them a
            # trailing singleton axis to broadcast across the last one
            scale = scale.unsqueeze(-1)
            shift = shift.unsqueeze(-1)
        return scale * normalised + shift




class CNN_Block_mynorm(nn.Module):
    """Residual 1-D convolution block.

    Three bias-free convolutions (each followed by ``activation``) are added to a
    1x1-convolution shortcut, normalised with ``MyLayerNorm`` over the channel axis,
    max-pooled by ``pool`` and passed through dropout (p=0.2).

    Args:
        n_input: number of input channels.
        n_output: number of output channels.
        kernel_size: kernel width of the three main convolutions (padding is
            ``kernel_size // 2``, which keeps the length for odd kernels).
        pool: max-pool window and stride.
        activation: element-wise non-linearity applied after each main convolution.
    """

    def __init__(self, n_input, n_output, kernel_size=3, pool=2, activation = F.relu):
        super().__init__()
        pad = kernel_size // 2
        # construction order kept identical so parameter init / state_dict order match
        self.skip = nn.Conv1d(n_input, n_output, kernel_size=1)
        self.norm = MyLayerNorm(n_output)
        self.c1 = nn.Conv1d(n_input, n_output, kernel_size=kernel_size, padding=pad, bias=False)
        self.c2 = nn.Conv1d(n_output, n_output, kernel_size=kernel_size, padding=pad, bias=False)
        self.c3 = nn.Conv1d(n_output, n_output, kernel_size=kernel_size, padding=pad, bias=False)
        self.max_pool = nn.MaxPool1d(kernel_size=pool, stride=pool)
        self.drop = nn.Dropout(0.2)
        self.activation = activation

    def forward(self, x):
        shortcut = self.skip(x)
        hidden = x
        for conv in (self.c1, self.c2, self.c3):
            hidden = self.activation(conv(hidden))
        return self.drop(self.max_pool(self.norm(hidden + shortcut)))



class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to query/key tensors.

    Adapted from the Hugging Face LLaMA implementation:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py

    cos/sin tables of shape ``(seq_len, dim)`` are cached as non-persistent buffers.
    They are rebuilt when an input lives on a different device than the cache or is
    longer than the cached length.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # inverse frequencies base ** (-2i / dim), one per rotated pair
        exponents = torch.arange(0, dim, 2).float().to(device) / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

        # Build eagerly so that `torch.jit.trace` sees the cached buffers.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq.to(device))
        # Duplicated (not interleaved) halves, matching the rotate-half convention
        # used in `_rotate_half`; differs from the paper's layout but is equivalent.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", table.sin().to(dtype), persistent=False)

    def _rotate_half(self, x):
        """Map the two halves ``(a, b)`` of the last axis to ``(-b, a)``."""
        half = x.shape[-1] // 2
        lower, upper = x[..., :half], x[..., half:]
        return torch.cat((-upper, lower), dim=-1)

    def forward(self, q, k, seq_len_dim = -2):
        """Apply rotary position embedding to ``q`` and ``k``.

        Args:
            q: query tensor, ``[..., seq_len, head_dim]``.
            k: key tensor, ``[..., seq_len, head_dim]``.
            seq_len_dim: dimension holding the sequence length. The cached tables
                have shape ``(seq_len, head_dim)`` and are broadcast against the two
                trailing axes, so this should be ``-2``.

        Returns:
            Tuple ``(q_embed, k_embed)`` with the same shapes as the inputs.
        """
        seq_len = q.shape[seq_len_dim]
        wrong_device = q.device != self.cos_cached.device
        if wrong_device or seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(
                seq_len=max(seq_len, self.max_seq_len_cached),
                device=q.device,
                dtype=q.dtype,
            )

        cos, sin = self.cos_cached[:seq_len], self.sin_cached[:seq_len]

        def rotate(t):
            # relies on broadcasting (seq_len, dim) over the leading axes of t
            return (t * cos) + (self._rotate_half(t) * sin)

        return rotate(q), rotate(k)




class MHA(nn.Module):
    """Multi-head self-attention with rotary position embeddings (RoPE).

    Args:
        d_input: feature size of the input ``x`` (its last dimension).
        d_output: feature size produced by the output projection ``WO``.
        d_k: per-head feature size; the concatenated head width is ``d_k * n_head``.
        n_head: number of attention heads.
        drop: dropout probability applied to the final output (not to the
            attention weights).
        activation: element-wise function applied to Q, K, V after projection and
            to the output after ``WO``; pass ``None`` to disable.
        bias: whether the Q/K/V/O linear layers use a bias term.
        keep_attention_score: if True, each forward pass stores detached V and the
            per-head attention output O in ``keep_V`` / ``keep_O`` so attention
            score matrices can be reconstructed afterwards.
    """

    def __init__(self, d_input, d_output, d_k, n_head, drop = 0.1, activation= F.relu, bias = False, keep_attention_score = False):
        super().__init__()
        d_model = d_k * n_head
        self.WQ = nn.Linear(d_input, d_model, bias = bias)
        self.WK = nn.Linear(d_input, d_model, bias = bias)
        self.WV = nn.Linear(d_input, d_model, bias = bias)
        self.WO = nn.Linear(d_model, d_output, bias = bias)
        self.drop = nn.Dropout(drop)
        self.RoEmb = RotaryEmbedding(d_k)

        self.d_model = d_model
        self.d_k = d_k
        self.n_head = n_head
        self.activation = activation
        # Softmax temperature in torch.nn.functional.scaled_dot_product_attention
        # notation. That API documents `scale` as a Python float, so keep it a float
        # rather than a 0-d tensor.
        self.scale = 1 / d_k ** 0.5
        self.keep_attention_score = keep_attention_score
        self.keep_V = None
        self.keep_O = None

    def forward(self, x):
        """Attend over ``x`` of shape ``(batch, seq, d_input)``; return ``(batch, seq, d_output)``."""
        batch_size, d_seq, _ = x.shape

        Q, K, V = self.WQ(x), self.WK(x), self.WV(x)
        if self.activation is not None:
            Q = self.activation(Q)
            K = self.activation(K)
            V = self.activation(V)

        # split heads: (batch, seq, n_head * d_k) -> (batch, n_head, seq, d_k)
        Q = Q.reshape(batch_size, d_seq, self.n_head, self.d_k).permute(0, 2, 1, 3)
        K = K.reshape(batch_size, d_seq, self.n_head, self.d_k).permute(0, 2, 1, 3)
        V = V.reshape(batch_size, d_seq, self.n_head, self.d_k).permute(0, 2, 1, 3)

        # rotary position embedding on queries and keys (sequence axis is -2)
        Q, K = self.RoEmb(Q, K)

        # PyTorch's fused attention, equivalent to softmax(Q @ K^T * scale) @ V but
        # more memory efficient and able to pick a hardware-specific kernel.
        O = F.scaled_dot_product_attention(Q, K, V, scale=self.scale)

        # keep V and O so attention score matrices can be reconstructed later
        if self.keep_attention_score:
            self.keep_V = V.detach()
            self.keep_O = O.detach()

        # merge heads: (batch, n_head, seq, d_k) -> (batch, seq, n_head * d_k)
        O = O.transpose(1, 2).reshape(batch_size, d_seq, self.d_model)
        O = self.WO(O)
        if self.activation is not None:
            O = self.activation(O)

        return self.drop(O)





class TransformerBlock(nn.Module):
    """Post-norm transformer block.

    Self-attention, then dropout, a residual add and LayerNorm. A two-layer ReLU
    feed-forward network follows, again with dropout, a residual add and LayerNorm.

    Args:
        ff_dim: hidden width of the feed-forward network.
        d_embed_in: feature size of the block input.
        d_embed_out: feature size of the block output.
        d_k, n_head, keep_attention_score: forwarded to :class:`MHA`.
        drop: dropout probability used inside MHA and after attention / FFN.
    """

    def __init__(self, ff_dim = 2048, d_embed_in = 128, d_embed_out = 64, d_k = 64, n_head = 12, drop = 0.1, keep_attention_score = False):
        super().__init__()
        self.atten = MHA(d_input = d_embed_in, d_output = d_embed_out, d_k = d_k, n_head = n_head, drop = drop, keep_attention_score = keep_attention_score)
        self.drop1 = nn.Dropout(drop)
        self.drop2 = nn.Dropout(drop)
        self.norm1 = nn.LayerNorm(d_embed_out)
        self.norm2 = nn.LayerNorm(d_embed_out)
        self.ffn = nn.Sequential(
            nn.Linear(d_embed_out, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, d_embed_out),
            nn.ReLU(),
        )
        # The first residual adds the block input to the attention output. That only
        # works when d_embed_in == d_embed_out, and the defaults (128 vs 64) do not
        # satisfy it. In that case the input is projected. Otherwise a parameter-free
        # identity is used, so equal-width blocks keep their original state_dict.
        if d_embed_in == d_embed_out:
            self.res_proj = nn.Identity()
        else:
            self.res_proj = nn.Linear(d_embed_in, d_embed_out, bias=False)

    def forward(self, input):
        """Map ``(batch, seq, d_embed_in)`` to ``(batch, seq, d_embed_out)``."""
        attended = self.drop1(self.atten(input))
        x1 = self.norm1(self.res_proj(input) + attended)
        x = self.drop2(self.ffn(x1))
        return self.norm2(x + x1)


_FLOAT32_INFO = torch.finfo(torch.float32)
MIN_FLOAT32 = _FLOAT32_INFO.min  # most negative finite float32 value
MAX_FLOAT32 = _FLOAT32_INFO.max  # largest finite float32 value

class Model(nn.Module):
    """Conv tower (4 residual CNN blocks), 4 transformer blocks and an MLP head.

    The input is expected as ``(batch, 4, seq_len)``. Each conv block uses an odd
    kernel with ``kernel_size // 2`` padding (length preserved) and then max-pools
    by its pool factor. The sequence fed to the head therefore has
    ``seq_len // 2 // 2 // 5 // 5`` positions of 256 features: ``10 * 256`` for
    the default ``seq_len=1000``.

    Args:
        n_output: number of output logits.
        ff_dim, d_k, n_head, trans_drop, keep_attention_score: forwarded to every
            TransformerBlock.
        seq_len: input length used to size the first fully-connected layer. The
            default of 1000 reproduces the former hard-coded ``10 * conv_layers[-1]``.

    Raises:
        ValueError: if ``seq_len`` pools down to zero positions.
    """

    def __init__(self, n_output, ff_dim = 2048, d_k = 64, n_head = 12, trans_drop = 0.1, keep_attention_score = False, seq_len = 1000):
        super().__init__()
        kernel_size = 5
        conv_layers = [32, 64, 128, 256]
        conv_pools = [2, 2, 5, 5]

        blocks = []
        in_channels = 4
        out_len = seq_len
        for out_channels, pool in zip(conv_layers, conv_pools):
            blocks.append(CNN_Block_mynorm(in_channels, out_channels, kernel_size, pool))
            in_channels = out_channels
            out_len //= pool  # MaxPool1d(kernel=pool, stride=pool) floors the length
        if out_len < 1:
            raise ValueError(f"seq_len={seq_len} is too short for pooling factors {conv_pools}")
        self.conv_tower = nn.Sequential(*blocks)

        d_embed = conv_layers[-1]

        def make_block():
            return TransformerBlock(ff_dim=ff_dim, d_embed_in=d_embed, d_embed_out=d_embed,
                                    d_k=d_k, n_head=n_head, drop=trans_drop,
                                    keep_attention_score=keep_attention_score)

        # four named attributes (not a ModuleList) so state_dict keys stay trans1..trans4
        self.trans1 = make_block()
        self.trans2 = make_block()
        self.trans3 = make_block()
        self.trans4 = make_block()

        self.drop = nn.Dropout(0.5)

        self.fc = nn.Sequential(
            nn.Linear(out_len * d_embed, 128),
            nn.ReLU(),
        )

        self.classifier = nn.Linear(128, n_output)

    def forward(self, x, training=None, mask=None, **kwargs):
        """Return logits of shape ``(batch, n_output)``.

        ``training``, ``mask`` and extra keyword arguments are accepted for call-site
        compatibility but ignored. Use ``model.train()`` / ``model.eval()`` to toggle
        dropout.
        """
        x = self.conv_tower(x)   # (batch, channels, length)
        x = x.transpose(1, 2)    # (batch, length, channels) for attention
        for block in (self.trans1, self.trans2, self.trans3, self.trans4):
            x = block(x)
        x = x.reshape(x.shape[0], -1)
        x = self.drop(x)
        x = self.fc(x)
        return self.classifier(x)


## test code: 
## Model(95)(torch.rand(2,4,1000)).shape
        # self.conv_pool_drop_1 = nn.Sequential(
        #     nn.Conv1d(in_channels=4, out_channels=n_kernel, kernel_size=25, stride=1),
        #     nn.ReLU(),
        #     nn.MaxPool1d(kernel_size=12, stride=12),
        #     nn.Dropout(0.2),
        #     )



In [37]:
# Smoke test: run Model(95) on a random (batch=2, channels=4, length=1000) input and
# inspect the output shape; the recorded result below is (2, 95).
# NOTE(review): execution count 37 is lower than that of the cell defining Model (139),
# so this recorded output may come from an earlier definition -- re-run to confirm.
Model(95)(torch.rand(2,4,1000)).shape

torch.Size([2, 95])

In [165]:
class MyLayerNorm(nn.Module):
    """Layer normalisation along a single axis with a learnable per-feature affine.

    Unlike ``nn.LayerNorm`` this divides by the (unbiased) standard deviation plus
    ``eps`` rather than by ``sqrt(var + eps)``.

    Args:
        features: size of the normalised axis; ``gamma`` and ``beta`` have this length.
        dim: axis to normalise over, ``-2`` (default) or ``-1``.
        eps: constant added to the standard deviation for numerical stability.
    """

    def __init__(self, features, dim = -2, eps=1e-5):
        super().__init__()
        assert dim in (-2, -1), 'dim must be -1 or -2'
        self.gamma = nn.Parameter(torch.ones(features))   # scale, initialised to 1
        self.beta = nn.Parameter(torch.zeros(features))   # shift, initialised to 0
        self.eps = eps
        self.dim = dim

    def forward(self, x):
        centred = x - x.mean(dim=self.dim, keepdim=True)
        spread = x.std(dim=self.dim, keepdim=True) + self.eps
        normalised = centred / spread
        scale, shift = self.gamma, self.beta
        if self.dim == -2:
            # the affine parameters index the second-to-last axis, so give them a
            # trailing singleton axis to broadcast across the last one
            scale = scale.unsqueeze(-1)
            shift = shift.unsqueeze(-1)
        return scale * normalised + shift


class CNN_Block(nn.Module):
    """Two-layer 1-D convolution block.

    Each of two convolutions (with bias) is followed by ``activation`` and a
    ``MyLayerNorm`` over the channel axis. The result is max-pooled by ``pool`` and
    passed through dropout (p=0.2).

    Args:
        n_input: number of input channels.
        n_output: number of output channels.
        kernel_size: convolution kernel width (padding ``kernel_size // 2``).
        pool: max-pool window and stride.
        activation: element-wise non-linearity applied after each convolution.
    """

    def __init__(self, n_input, n_output, kernel_size=25, pool=2, activation = F.gelu):
        super().__init__()
        pad = kernel_size // 2
        # construction order kept identical so parameter init / state_dict order match
        self.c1 = nn.Conv1d(n_input, n_output, kernel_size=kernel_size, padding=pad)
        self.c2 = nn.Conv1d(n_output, n_output, kernel_size=kernel_size, padding=pad)
        self.norm1 = MyLayerNorm(n_output)
        self.norm2 = MyLayerNorm(n_output)
        self.max_pool = nn.MaxPool1d(kernel_size=pool, stride=pool)
        self.drop = nn.Dropout(0.2)
        self.activation = activation

    def forward(self, x):
        for conv, norm in ((self.c1, self.norm1), (self.c2, self.norm2)):
            x = norm(self.activation(conv(x)))
        return self.drop(self.max_pool(x))


class Model(nn.Module):
    """Conv tower (3 CNN blocks), one transformer block and an MLP head with clamped logits.

    The input is expected as ``(batch, 4, seq_len)``. Each convolution uses an odd
    kernel with ``kernel_size // 2`` padding (length preserved), and each block then
    max-pools by its pool factor. The sequence fed to the head therefore has
    ``seq_len // 5 // 5 // 1`` positions of 640 features: ``408 * 640`` for the
    default ``seq_len=10200``.

    Args:
        n_output: number of output logits.
        ff_dim, d_k, n_head, trans_drop, keep_attention_score: forwarded to the
            TransformerBlock.
        seq_len: input length used to size the first fully-connected layer. The
            default of 10200 reproduces the former hard-coded ``408 * conv_layers[-1]``.

    Raises:
        ValueError: if ``seq_len`` pools down to zero positions.
    """

    def __init__(self, n_output, ff_dim = 2048, d_k = 64, n_head = 12, trans_drop = 0.1, keep_attention_score = False, seq_len = 10200):
        super().__init__()
        kernel_size = [25, 9, 9]
        conv_layers = [64, 320, 640]
        conv_pools = [5, 5, 1]

        blocks = []
        in_channels = 4
        out_len = seq_len
        for out_channels, pool, k in zip(conv_layers, conv_pools, kernel_size):
            blocks.append(CNN_Block(in_channels, out_channels, k, pool))
            in_channels = out_channels
            out_len //= pool  # MaxPool1d(kernel=pool, stride=pool) floors the length
        if out_len < 1:
            raise ValueError(f"seq_len={seq_len} is too short for pooling factors {conv_pools}")
        self.conv_tower = nn.Sequential(*blocks)

        self.trans1 = TransformerBlock(ff_dim=ff_dim, d_embed_in=conv_layers[-1],
                                       d_embed_out=conv_layers[-1], d_k=d_k, n_head=n_head,
                                       drop=trans_drop,
                                       keep_attention_score=keep_attention_score)

        self.drop = nn.Dropout(0.5)

        self.fc = nn.Sequential(
            nn.Linear(out_len * conv_layers[-1], 128),
            nn.GELU(),
        )

        self.classifier = nn.Linear(128, n_output)

    def forward(self, x, training=None, mask=None, **kwargs):
        """Return logits of shape ``(batch, n_output)``, clamped to the finite float32 range.

        ``training``, ``mask`` and extra keyword arguments are accepted for call-site
        compatibility but ignored. Use ``model.train()`` / ``model.eval()`` to toggle
        dropout.
        """
        x = self.conv_tower(x)   # (batch, channels, length)
        x = x.transpose(1, 2)    # (batch, length, channels) for attention
        x = self.trans1(x)
        x = x.reshape(x.shape[0], -1)
        x = self.drop(x)
        x = self.fc(x)
        x = self.classifier(x)
        # map +/-inf logits to the finite float32 extremes
        return torch.clamp(x, min=MIN_FLOAT32, max=MAX_FLOAT32)



In [166]:
# Forward pass of the Model defined in cell 165 on a random (2, 4, 10200) input.
# That Model's flatten size assumes length 10200 (10200 / 5 / 5 / 1 = 408 positions).
a = Model(919)(torch.rand(2,4,10200))

In [167]:
# Number of trainable parameters of Model(919).
count_parameters(Model(919))

44885271

In [168]:
# Output shape of the forward pass from cell 166: one 919-dim logit vector per example.
a.shape

torch.Size([2, 919])

In [76]:
# NOTE(review): the cells from here to the end are standalone scratch arithmetic on
# literals only; they use no names from the model code above. Consider removing them
# before committing.
2*97920

195840

In [66]:
96960-97920

-960

In [117]:
15/132* 10.5

1.1931818181818181

In [119]:
0.3* 5

1.5

In [120]:
0.5/19 *5

0.13157894736842105

In [122]:
1/15*5

0.3333333333333333

In [125]:
# Sum of the results of cells 122, 120, 117 and 119.
0.3333333333333333 + 0.13157894736842105 + 1.1931818181818181 + 1.5

3.1580940988835726

In [124]:
3.15/63.5


0.049606299212598425