In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Test both implementations
d_hidden = 768
num_heads = 8
seq_length = 197
batch_size = 4

# Create input
x = torch.randn(batch_size, seq_length, d_hidden)
print("Input shape:", x.shape)


Input shape: torch.Size([4, 197, 768])


In [3]:

class MultiHeadConvNNAttention(nn.Module):
    def __init__(self, 
                 d_hidden, 
                 num_heads, 
                 attention_dropout,
                 K, 
                 sampling_type, 
                 num_samples, 
                 sample_padding, 
                 magnitude_type, 
                 seq_length=197, 
                 coordinate_encoding=False
                 ):
        
        super(MultiHeadConvNNAttention, self).__init__()
        assert d_hidden % num_heads == 0, "d_hidden must be divisible by num_heads"

        # Core Parameters
        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.d_k = d_hidden // num_heads

        # ConvNN Parameters
        self.K = K
        self.seq_length = seq_length

        # 3 types of sampling: all, random, spatial
        self.sampling_type = sampling_type
        self.num_samples = int(num_samples) 
        self.sample_padding = int(sample_padding) if sampling_type == 'spatial' else 0    

        # Similarity Metric 
        self.magnitude_type = magnitude_type
        self.maximum = True if self.magnitude_type == 'cosine' else False

        # Coordinate Encoding (optional) 
        self.coordinate_encoding = coordinate_encoding
        self.coordinate_cache = {}
        
        # Linear projections for query, key, value
        self.W_q = nn.Linear(d_hidden, d_hidden)
        self.W_k = nn.Linear(d_hidden, d_hidden)
        self.W_v = nn.Linear(d_hidden, d_hidden)
        self.W_o = nn.Linear(d_hidden, d_hidden)   
        self.dropout = nn.Dropout(attention_dropout)


        self.in_channels = (d_hidden // num_heads) + 1 if coordinate_encoding else d_hidden // num_heads
        self.out_channels = (d_hidden // num_heads) 
        
        self.conv = nn.Conv1d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.K,
            stride=self.K,
            padding=0,
        )

        # Utility Variables 
        self.INF = 1e5 
        self.NEG_INF = -1e5
        
    def split_head(self, x): 
        batch_size, seq_length, d_hidden = x.size()
        self.batch_size = batch_size
        # self.seq_length = seq_length
        return x.contiguous().view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2) # (B, num_heads, seq_length, d_k)
        
    def combine_heads(self, x): 
        
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_hidden) 
    
    def batch_split(self, x): 
        x = x.reshape(self.batch_size, -1, self.d_k, self.seq_length)
        return x.permute(0, 1, 3, 2).contiguous()
        
    def batch_combine(self, x): 
        batch_size, _, seq_length, d_k = x.size()
        x = x.permute(0, 1, 3, 2).contiguous() 
        return x.view(-1, self.d_k, seq_length)
        
    def forward(self, x):
        # Note: x shape: (B, seq_length, d_hidden)
        print("Input shape:", x.shape)
        # 1. Splithead & Batch Combine
        k = self.batch_combine(self.split_head(self.W_k(x)))
        v = self.batch_combine(self.split_head(self.W_v(x)))

        print("k shape after split and combine:", k.shape)

        # 3. Add Coordinate Encoding 
        k = self._add_coordinate_encoding(k) if self.coordinate_encoding else k
        v = self._add_coordinate_encoding(v) if self.coordinate_encoding else v

        print("k shape after encoding:", k.shape)

        
        if self.sampling_type == 'all': # All Samples
            q = self.batch_combine(self.split_head(self.W_q(x)))
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q
            print("q shape: ", q.shape)

            similarity_matrix = self._calculate_cosine_matrix(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix(k, q, sqrt=True)
            print("similarity_marix: ", similarity_matrix.shape)
            prime = self._prime(v, similarity_matrix, self.K, self.maximum)
            print("prime: ", prime.shape)

        elif self.sampling_type == 'random': # Random Samples
            rand_idx = torch.randperm(x.shape[1], device=x.device)[:self.num_samples]
            x_sample = x[:, rand_idx, :]
            print("x sample shape: ", x_sample.shape)
            
            q = self.batch_combine(self.split_head(self.W_q(x_sample)))
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q
            print("q shape: ", q.shape)

            similarity_matrix = self._calculate_cosine_matrix_N(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix_N(k, q, sqrt=True)

            print("similarity_marix: ", similarity_matrix.shape)
            range_idx = torch.arange(len(rand_idx), device=q.device)
            similarity_matrix[:, rand_idx, range_idx] = self.INF if self.magnitude_type == 'euclidean' else self.NEG_INF

            prime = self._prime_N(v, similarity_matrix, self.K, rand_idx, self.maximum)

            print("prime: ", prime.shape)
        elif self.sampling_type == 'spatial': # Spatial Samples
            spat_idx = torch.linspace(0 + self.sample_padding, x.shape[1] - self.sample_padding - 1, self.num_samples, device=x.device).long()
            x_sample = x[:, spat_idx, :]
            print("x sample shape: ", x_sample.shape)
            q = self.batch_combine(self.split_head(self.W_q(x_sample)))
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q
            print("q shape: ", q.shape)

            similarity_matrix = self._calculate_cosine_matrix_N(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix_N(k, q, sqrt=True)

            print("similarity_marix: ", similarity_matrix.shape)
            range_idx = torch.arange(len(spat_idx), device=q.device)
            similarity_matrix[:, spat_idx, range_idx] = self.INF if self.magnitude_type == 'euclidean' else self.NEG_INF

            prime = self._prime_N(v, similarity_matrix, self.K, spat_idx, self.maximum)

            print("prime: ", prime.shape)

        else: 
            raise ValueError("Invalid sampling_type. Must be one of ['all', 'random', 'spatial']")

        x = self.conv(prime)  

        x = self.dropout(x)
        x = x.permute(0, 2, 1) 
        x = self.W_o(self.combine_heads(self.batch_split(x)))
        return x       

    def _calculate_euclidean_matrix(self, K, Q, sqrt=False):
        k_norm_squared = torch.sum(K**2, dim=1, keepdim=True)
        q_norm_squared = torch.sum(Q**2, dim=1, keepdim=True)
        dot_product = torch.bmm(K.transpose(1, 2), Q)

        dist_matrix = k_norm_squared.transpose(1, 2) + q_norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        torch.diagonal(dist_matrix, dim1=1, dim2=2).fill_(-0.1)  # Fill diagonal with -0.1 to avoid self-selection
        return dist_matrix 

    def _calculate_euclidean_matrix_N(self, K, Q, sqrt=False):
        k_norm_squared = torch.sum(K**2, dim=1, keepdim=True)
        q_norm_squared = torch.sum(Q**2, dim=1, keepdim=True)
        dot_product = torch.bmm(K.transpose(1, 2), Q)

        dist_matrix = k_norm_squared.transpose(1, 2) + q_norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        return dist_matrix 

    def _calculate_cosine_matrix(self, K, Q):
        k_norm = F.normalize(K, p=2, dim=1)
        q_norm = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.matmul(k_norm.transpose(1, 2), q_norm)
        torch.diagonal(similarity_matrix, dim1=1, dim2=2).fill_(1.1)  # Fill diagonal with 1.1 to self-select
        return similarity_matrix

    def _calculate_cosine_matrix_N(self, K, Q):
        norm_k = F.normalize(K, p=2, dim=1)
        norm_q = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.matmul(norm_k.transpose(1, 2), norm_q)
        return similarity_matrix

    def _prime(self, v, qk, K, maximum):
        b, c, t = v.shape
        topk_values, topk_indices = torch.topk(qk, k=K, dim=2, largest=maximum)
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)

        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        prime = topk_values_exp * prime 

        prime = prime.view(b, c, -1)

        return prime

    def _prime_N(self, v, qk, K, rand_idx, maximum):
        b, c, t = v.shape
        topk_values, topk_indices = torch.topk(qk, k=K-1, dim=2, largest=maximum)
        tk = topk_indices.shape[-1]
        assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

        # Map sample indicies back to original matrix positions 
        mapped_tensor = rand_idx[topk_indices]
        token_indices = torch.arange(t, device=v.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=-1)
        topk_indices_exp = final_indices.unsqueeze(1).expand(b, c, t, K)

        # Expand topk values to match the shape of indices
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K-1)
        ones = torch.ones((b, c, t, 1), device=v.device)
        topk_values_exp = torch.cat((ones, topk_values_exp), dim=-1)

        # Gather matrix values and apply similarity weighting 
        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()    
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        prime = topk_values_exp * prime

        prime = prime.view(b, c, -1)
        return prime
    
    def _add_coordinate_encoding(self, x):
        b, c, t = x.shape 
        cache_key = f"{b}_{t}_{x.device}"
        if cache_key in self.coordinate_cache: 
            expanded_coords = self.coordinate_cache[cache_key]
        else: 
            coords_vec = torch.linspace(start=-1, end=1, steps=t, device=x.device).unsqueeze(0).expand(b, -1) 
            expanded_coords = coords_vec.unsqueeze(1).expand(b, -1, -1) 
            self.coordinate_cache[cache_key] = expanded_coords

        x_with_coords = torch.cat([x, expanded_coords], dim=1) 
        return x_with_coords

In [4]:
convnn= MultiHeadConvNNAttention(
    d_hidden=d_hidden,
    num_heads=8,
    attention_dropout=0.1,
    K=4,
    sampling_type='all',  # 'all', 'random', 'spatial'
    num_samples=-1, 
    sample_padding=0, 
    magnitude_type='euclidean',  # 'euclidean' or 'cosine'
    coordinate_encoding=True
)
ex = convnn(x)
print("Output shape:", ex.shape)  # Expected: (B, seq_length, d

Input shape: torch.Size([4, 197, 768])
k shape after split and combine: torch.Size([32, 96, 197])
k shape after encoding: torch.Size([32, 97, 197])
q shape:  torch.Size([32, 97, 197])
similarity_marix:  torch.Size([32, 197, 197])
prime:  torch.Size([32, 97, 788])
Output shape: torch.Size([4, 197, 768])


In [5]:
convnn= MultiHeadConvNNAttention(
    d_hidden=d_hidden,
    num_heads=8,
    attention_dropout=0.1,
    K=4,
    sampling_type='random',  # 'all', 'random', 'spatial'
    num_samples=30, 
    sample_padding=0, 
    magnitude_type='euclidean',  # 'euclidean' or 'cosine'
    coordinate_encoding=True
)
ex = convnn(x)
print("Output shape:", ex.shape)  # Expected: (B, seq_length, d

Input shape: torch.Size([4, 197, 768])
k shape after split and combine: torch.Size([32, 96, 197])
k shape after encoding: torch.Size([32, 97, 197])
x sample shape:  torch.Size([4, 30, 768])
q shape:  torch.Size([32, 97, 30])
similarity_marix:  torch.Size([32, 197, 30])
prime:  torch.Size([32, 97, 788])
Output shape: torch.Size([4, 197, 768])


In [6]:
convnn= MultiHeadConvNNAttention(
    d_hidden=d_hidden,
    num_heads=8,
    attention_dropout=0.1,
    K=4,
    sampling_type='spatial',  # 'all', 'random', 'spatial'
    num_samples=30, 
    sample_padding=0, 
    magnitude_type='euclidean',  # 'euclidean' or 'cosine'
    coordinate_encoding=True
)
ex = convnn(x)
print("Output shape:", ex.shape)  # Expected: (B, seq_length, d

Input shape: torch.Size([4, 197, 768])
k shape after split and combine: torch.Size([32, 96, 197])
k shape after encoding: torch.Size([32, 97, 197])
x sample shape:  torch.Size([4, 30, 768])
q shape:  torch.Size([32, 97, 30])
similarity_marix:  torch.Size([32, 197, 30])
prime:  torch.Size([32, 97, 788])
Output shape: torch.Size([4, 197, 768])


## Remove batch split, batch combine

In [7]:

class MultiHeadConvNNAttention_NoBatchSplit(nn.Module):
    def __init__(self, 
                 d_hidden, 
                 num_heads, 
                 attention_dropout,
                 K, 
                 sampling_type, 
                 num_samples, 
                 sample_padding, 
                 magnitude_type, 
                 seq_length=197, 
                 coordinate_encoding=False
                 ):
        
        super(MultiHeadConvNNAttention_NoBatchSplit, self).__init__()
        assert d_hidden % num_heads == 0, "d_hidden must be divisible by num_heads"

        # Core Parameters
        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.d_k = d_hidden // num_heads

        # ConvNN Parameters
        self.K = K
        self.seq_length = seq_length

        # 3 types of sampling: all, random, spatial
        self.sampling_type = sampling_type
        self.num_samples = int(num_samples) 
        self.sample_padding = int(sample_padding) if sampling_type == 'spatial' else 0    

        # Similarity Metric 
        self.magnitude_type = magnitude_type
        self.maximum = True if self.magnitude_type == 'cosine' else False

        # Coordinate Encoding (optional) 
        self.coordinate_encoding = coordinate_encoding
        self.coordinate_cache = {}
        
        # Linear projections for query, key, value
        self.W_q = nn.Linear(d_hidden, d_hidden)
        self.W_k = nn.Linear(d_hidden, d_hidden)
        self.W_v = nn.Linear(d_hidden, d_hidden)
        self.W_o = nn.Linear(d_hidden, d_hidden)   
        self.dropout = nn.Dropout(attention_dropout)

        # self.in_channels = (d_hidden // num_heads) + 1 if coordinate_encoding else d_hidden // num_heads
        # self.out_channels = (d_hidden // num_heads) 
        self.in_channels = d_hidden + 1 if coordinate_encoding else d_hidden
        self.out_channels = d_hidden
        
        self.conv = nn.Conv1d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.K,
            stride=self.K,
            padding=0,
        )

        # Utility Variables 
        self.INF = 1e5 
        self.NEG_INF = -1e5
        
    def split_head(self, x): 
        batch_size, seq_length, d_hidden = x.size()
        self.batch_size = batch_size
        # self.seq_length = seq_length
        return x.contiguous().view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2) # (B, num_heads, seq_length, d_k)
        
    def combine_heads(self, x): 
        
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_hidden) 
    
    def batch_split(self, x): 
        x = x.reshape(self.batch_size, -1, self.d_k, self.seq_length)
        return x.permute(0, 1, 3, 2).contiguous()
        
    def batch_combine(self, x): 
        batch_size, _, seq_length, d_k = x.size()
        x = x.permute(0, 1, 3, 2).contiguous() 
        return x.view(-1, self.d_k, seq_length)
        
    def forward(self, x):
        # Note: x shape: (B, seq_length, d_hidden)

        # 1. Splithead & Batch Combine
        k = self.W_k(x) 
        v = self.W_v(x) 

        print("k shape after linear:", k.shape)
        print("v shape after linear:", v.shape)
        k = k.transpose(1, 2)   
        v = v.transpose(1, 2)
        print("k shape after transpose:", k.shape)
        print("v shape after transpose:", v.shape)

        # k = self.batch_combine(self.split_head(k))
        # v = self.batch_combine(self.split_head(v))


        # 2. Add Coordinate Encoding 
        k = self._add_coordinate_encoding(k) if self.coordinate_encoding else k
        v = self._add_coordinate_encoding(v) if self.coordinate_encoding else v


        # 3. Sampling & Similarity Calculation
        if self.sampling_type == 'all': # All Samples
            # q = self.batch_combine(self.split_head(self.W_q(x)))
            q = self.W_q(x)
            q = q.transpose(1, 2)
            
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q

            similarity_matrix = self._calculate_cosine_matrix(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix(k, q, sqrt=True)
            prime = self._prime(v, similarity_matrix, self.K, self.maximum)

        elif self.sampling_type == 'random': # Random Samples
            rand_idx = torch.randperm(x.shape[1], device=x.device)[:self.num_samples]
            x_sample = x[:, rand_idx, :]            
            q = self.batch_combine(self.split_head(self.W_q(x_sample)))
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q

            similarity_matrix = self._calculate_cosine_matrix_N(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix_N(k, q, sqrt=True)
            range_idx = torch.arange(len(rand_idx), device=q.device)
            similarity_matrix[:, rand_idx, range_idx] = self.INF if self.magnitude_type == 'euclidean' else self.NEG_INF
            prime = self._prime_N(v, similarity_matrix, self.K, rand_idx, self.maximum)

        elif self.sampling_type == 'spatial': # Spatial Samples
            spat_idx = torch.linspace(0 + self.sample_padding, x.shape[1] - self.sample_padding - 1, self.num_samples, device=x.device).long()
            x_sample = x[:, spat_idx, :]
            q = self.batch_combine(self.split_head(self.W_q(x_sample)))
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q

            similarity_matrix = self._calculate_cosine_matrix_N(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix_N(k, q, sqrt=True)
            range_idx = torch.arange(len(spat_idx), device=q.device)
            similarity_matrix[:, spat_idx, range_idx] = self.INF if self.magnitude_type == 'euclidean' else self.NEG_INF
            prime = self._prime_N(v, similarity_matrix, self.K, spat_idx, self.maximum)
            
        else: 
            raise ValueError("Invalid sampling_type. Must be one of ['all', 'random', 'spatial']")

        # 4. Conv1d Layer
        x = self.conv(prime)  

        # 5. Dropout + Reshape (B, seq_length, d_hidden)
        x = self.dropout(x)
        x = x.permute(0, 2, 1) 

        # 6. Final Linear Projection
        x = self.W_o(x)
        return x       

    def _calculate_euclidean_matrix(self, K, Q, sqrt=False):
        k_norm_squared = torch.sum(K**2, dim=1, keepdim=True)
        q_norm_squared = torch.sum(Q**2, dim=1, keepdim=True)
        dot_product = torch.bmm(K.transpose(1, 2), Q)

        dist_matrix = k_norm_squared.transpose(1, 2) + q_norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        torch.diagonal(dist_matrix, dim1=1, dim2=2).fill_(-0.1)  # Fill diagonal with -0.1 to avoid self-selection
        return dist_matrix 

    def _calculate_euclidean_matrix_N(self, K, Q, sqrt=False):
        k_norm_squared = torch.sum(K**2, dim=1, keepdim=True)
        q_norm_squared = torch.sum(Q**2, dim=1, keepdim=True)
        dot_product = torch.bmm(K.transpose(1, 2), Q)

        dist_matrix = k_norm_squared.transpose(1, 2) + q_norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        return dist_matrix 

    def _calculate_cosine_matrix(self, K, Q):
        k_norm = F.normalize(K, p=2, dim=1)
        q_norm = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.matmul(k_norm.transpose(1, 2), q_norm)
        torch.diagonal(similarity_matrix, dim1=1, dim2=2).fill_(1.1)  # Fill diagonal with 1.1 to self-select
        return similarity_matrix

    def _calculate_cosine_matrix_N(self, K, Q):
        norm_k = F.normalize(K, p=2, dim=1)
        norm_q = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.matmul(norm_k.transpose(1, 2), norm_q)
        return similarity_matrix

    def _prime(self, v, qk, K, maximum):
        b, c, t = v.shape
        topk_values, topk_indices = torch.topk(qk, k=K, dim=2, largest=maximum)
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)

        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        prime = topk_values_exp * prime 

        prime = prime.view(b, c, -1)

        return prime

    def _prime_N(self, v, qk, K, rand_idx, maximum):
        b, c, t = v.shape
        topk_values, topk_indices = torch.topk(qk, k=K-1, dim=2, largest=maximum)
        tk = topk_indices.shape[-1]
        assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

        # Map sample indicies back to original matrix positions 
        mapped_tensor = rand_idx[topk_indices]
        token_indices = torch.arange(t, device=v.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=-1)
        topk_indices_exp = final_indices.unsqueeze(1).expand(b, c, t, K)

        # Expand topk values to match the shape of indices
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K-1)
        ones = torch.ones((b, c, t, 1), device=v.device)
        topk_values_exp = torch.cat((ones, topk_values_exp), dim=-1)

        # Gather matrix values and apply similarity weighting 
        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()    
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        prime = topk_values_exp * prime

        prime = prime.view(b, c, -1)
        return prime
    
    def _add_coordinate_encoding(self, x):
        b, c, t = x.shape 
        cache_key = f"{b}_{t}_{x.device}"
        if cache_key in self.coordinate_cache: 
            expanded_coords = self.coordinate_cache[cache_key]
        else: 
            coords_vec = torch.linspace(start=-1, end=1, steps=t, device=x.device).unsqueeze(0).expand(b, -1) 
            expanded_coords = coords_vec.unsqueeze(1).expand(b, -1, -1) 
            self.coordinate_cache[cache_key] = expanded_coords

        x_with_coords = torch.cat([x, expanded_coords], dim=1) 
        return x_with_coords 

In [8]:
# Test both implementations
d_hidden = 768
num_heads = 8
seq_length = 197
batch_size = 4

# Create input
x = torch.randn(batch_size, seq_length, d_hidden)
print("Input shape:", x.shape)

Input shape: torch.Size([4, 197, 768])


In [9]:
convnn= MultiHeadConvNNAttention_NoBatchSplit(
    d_hidden=d_hidden,
    num_heads=8,
    attention_dropout=0.1,
    K=4,
    sampling_type='all',  # 'all', 'random', 'spatial'
    num_samples=-1, 
    sample_padding=0, 
    magnitude_type='euclidean',  # 'euclidean' or 'cosine'
    coordinate_encoding=False
)
ex = convnn(x)
print("Output shape:", ex.shape)  # Expected: (B, seq_length, d

k shape after linear: torch.Size([4, 197, 768])
v shape after linear: torch.Size([4, 197, 768])
k shape after transpose: torch.Size([4, 768, 197])
v shape after transpose: torch.Size([4, 768, 197])
Output shape: torch.Size([4, 197, 768])


## Oct 7, 2025
### Figure out a way to add number of heads to the ConvNN Attention layer similar to the MultiHeadAttention layer 

In [10]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np

#### 1. MultiheadAttention

In [11]:

"""Multi-Head Layers for Transformer Encoder"""
class MultiHeadAttention(nn.Module): 
    def __init__(self, d_hidden, num_heads, attention_dropout):
        super(MultiHeadAttention, self).__init__()
        assert d_hidden % num_heads == 0, "d_hidden must be divisible by num_heads"
        
        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.d_k = d_hidden // num_heads # dimension of each head
        self.dropout = nn.Dropout(attention_dropout)
        
        self.W_q = nn.Linear(d_hidden, d_hidden)
        self.W_k = nn.Linear(d_hidden, d_hidden)
        self.W_v = nn.Linear(d_hidden, d_hidden)
        self.W_o = nn.Linear(d_hidden, d_hidden)        
    
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        print()
        print("[Inside scaled_dot_product_attention]")
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        print("attn_scores shape:", attn_scores.shape)
        
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        attn_probs = self.dropout(torch.softmax(attn_scores, dim=-1))
        print("attn_probs shape:", attn_probs.shape)
        print("V shape:", V.shape)
        output = torch.matmul(attn_probs, V)
        print("output shape:", output.shape)
        return output, attn_probs
    
    def split_head(self, x): 
        batch_size, seq_length, d_hidden = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2) # (B, num_heads, seq_length, d_k)
        
    def combine_heads(self, x): 
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_hidden) 
    
    def forward(self, x, mask=None):
        print("[Inside MultiHeadAttention forward]")
        q = self.split_head(self.W_q(x)) # (B, num_heads, seq_length, d_k)
        k = self.split_head(self.W_k(x))
        v = self.split_head(self.W_v(x))
        print("original x shape:", x.shape)
        print("kqv shape:", k.shape, q.shape, v.shape)
        
        attn_output, _ = self.scaled_dot_product_attention(q, k, v, mask) # (B, num_heads, seq_length, d_k)
        output = self.W_o(self.combine_heads(attn_output)) # (B, seq_length, d_hidden)
        return output


In [12]:
# Test both implementations
d_hidden = 768
num_heads = 3
seq_length = 197
batch_size = 4

# Create input
x = torch.randn(batch_size, seq_length, d_hidden)
print("Input shape:", x.shape)
print("-"*50)

attention = MultiHeadAttention(d_hidden=d_hidden, num_heads=num_heads, attention_dropout=0.1)
out = attention(x)
print("-"*50)
print("Output shape:", out.shape)  # Expected: (B, seq_length, d_hidden)


Input shape: torch.Size([4, 197, 768])
--------------------------------------------------
[Inside MultiHeadAttention forward]
original x shape: torch.Size([4, 197, 768])
kqv shape: torch.Size([4, 3, 197, 256]) torch.Size([4, 3, 197, 256]) torch.Size([4, 3, 197, 256])

[Inside scaled_dot_product_attention]
attn_scores shape: torch.Size([4, 3, 197, 197])
attn_probs shape: torch.Size([4, 3, 197, 197])
V shape: torch.Size([4, 3, 197, 256])
output shape: torch.Size([4, 3, 197, 256])
--------------------------------------------------
Output shape: torch.Size([4, 197, 768])


#### 2. Original ConvNN Attention

In [13]:
# Working implementation 
class MultiHeadConvNNAttention(nn.Module):
    def __init__(self, 
                 d_hidden, 
                 num_heads, 
                 attention_dropout,
                 K, 
                 sampling_type, 
                 num_samples, 
                 sample_padding, 
                 magnitude_type, 
                 seq_length=197, 
                 coordinate_encoding=False
                 ):
        
        super(MultiHeadConvNNAttention, self).__init__()
        assert d_hidden % num_heads == 0, "d_hidden must be divisible by num_heads"

        # Core Parameters
        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.d_k = d_hidden // num_heads

        # ConvNN Parameters
        self.K = K
        self.seq_length = seq_length

        # 3 types of sampling: all, random, spatial
        self.sampling_type = sampling_type
        self.num_samples = int(num_samples) 
        self.sample_padding = int(sample_padding) if sampling_type == 'spatial' else 0    

        # Similarity Metric 
        self.magnitude_type = magnitude_type
        self.maximum = True if self.magnitude_type == 'cosine' else False

        # Coordinate Encoding (optional) 
        self.coordinate_encoding = coordinate_encoding
        self.coordinate_cache = {}
        
        # Linear projections for query, key, value
        self.W_q = nn.Linear(d_hidden, d_hidden)
        self.W_k = nn.Linear(d_hidden, d_hidden)
        self.W_v = nn.Linear(d_hidden, d_hidden)
        self.W_o = nn.Linear(d_hidden, d_hidden)   
        self.dropout = nn.Dropout(attention_dropout)

        self.in_channels = (d_hidden // num_heads) + 1 if coordinate_encoding else (d_hidden // num_heads)
        self.out_channels = (d_hidden // num_heads) 
        
        self.conv = nn.Conv1d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.K,
            stride=self.K,
            padding=0,
        )

        # Utility Variables 
        self.INF = 1.1
        self.NEG_INF = -0.1 
        
    def split_head(self, x): 
        batch_size, seq_length, d_hidden = x.size()
        self.batch_size = batch_size
        # self.seq_length = seq_length
        return x.contiguous().view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2) # (B, num_heads, seq_length, d_k)
        
    def combine_heads(self, x): 
        
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_hidden) 
    
    def batch_split(self, x): 
        x = x.reshape(self.batch_size, -1, self.d_k, self.seq_length)
        return x.permute(0, 1, 3, 2).contiguous()
        
    def batch_combine(self, x): 
        batch_size, _, seq_length, d_k = x.size()
        x = x.permute(0, 1, 3, 2).contiguous() 
        return x.view(-1, self.d_k, seq_length)
        
    def forward(self, x):
        # Note: x shape: (B, seq_length, d_hidden)
        # 1. Splithead & Batch Combine
        k = self.batch_combine(self.split_head(self.W_k(x)))
        v = self.batch_combine(self.split_head(self.W_v(x)))
        
        # k = self.batch_combine(self.split_head(x))
        # v = self.batch_combine(self.split_head(x))

        # 2. Add Coordinate Encoding 
        k = self._add_coordinate_encoding(k) if self.coordinate_encoding else k
        v = self._add_coordinate_encoding(v) if self.coordinate_encoding else v


        # 3. Sampling & Similarity Calculation
        if self.sampling_type == 'all': # All Samples
            q = self.batch_combine(self.split_head(self.W_q(x)))
            # q = self.batch_combine(self.split_head(x))
            
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q

            similarity_matrix = self._calculate_cosine_matrix(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix(k, q, sqrt=True)

            # similarity_matrix = torch.softmax(similarity_matrix, dim=-1)
            
            prime = self._prime(v, similarity_matrix, self.K, self.maximum)
            # prime = self._prime_temperature(v, similarity_matrix, self.K, self.maximum, temperature=1) ## New Prime with Temperature Scaling

        elif self.sampling_type == 'random': # Random Samples
            rand_idx = torch.randperm(x.shape[1], device=x.device)[:self.num_samples]
            x_sample = x[:, rand_idx, :]            
            q = self.batch_combine(self.split_head(self.W_q(x_sample)))
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q

            similarity_matrix = self._calculate_cosine_matrix_N(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix_N(k, q, sqrt=True)
            range_idx = torch.arange(len(rand_idx), device=q.device)
            similarity_matrix[:, rand_idx, range_idx] = self.INF if self.magnitude_type == 'euclidean' else self.NEG_INF

            # similarity_matrix = torch.softmax(similarity_matrix, dim=-1)

            prime = self._prime_N(v, similarity_matrix, self.K, rand_idx, self.maximum)

        elif self.sampling_type == 'spatial': # Spatial Samples
            spat_idx = torch.linspace(0 + self.sample_padding, x.shape[1] - self.sample_padding - 1, self.num_samples, device=x.device).long()
            x_sample = x[:, spat_idx, :]
            q = self.batch_combine(self.split_head(self.W_q(x_sample)))
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q

            similarity_matrix = self._calculate_cosine_matrix_N(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix_N(k, q, sqrt=True)
            range_idx = torch.arange(len(spat_idx), device=q.device)
            similarity_matrix[:, spat_idx, range_idx] = self.INF if self.magnitude_type == 'euclidean' else self.NEG_INF

            # similarity_matrix = torch.softmax(similarity_matrix, dim=-1)

            prime = self._prime_N(v, similarity_matrix, self.K, spat_idx, self.maximum)
            
        else: 
            raise ValueError("Invalid sampling_type. Must be one of ['all', 'random', 'spatial']")

        # 4. Conv1d Layer
        x = self.conv(prime)  

        # 5. Dropout + Reshape (B, seq_length, d_hidden)
        x = self.dropout(x)
        x = x.permute(0, 2, 1) 

        # 6. Final Linear Projection
        x = self.W_o(self.combine_heads(self.batch_split(x)))
        return x       

    def _calculate_euclidean_matrix(self, K, Q, sqrt=False):
        k_norm_squared = torch.sum(K**2, dim=1, keepdim=True)
        q_norm_squared = torch.sum(Q**2, dim=1, keepdim=True)
        dot_product = torch.bmm(K.transpose(1, 2), Q)

        dist_matrix = k_norm_squared.transpose(1, 2) + q_norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        torch.diagonal(dist_matrix, dim1=1, dim2=2).fill_(-0.1) 
        return dist_matrix 

    def _calculate_euclidean_matrix_N(self, K, Q, sqrt=False):
        k_norm_squared = torch.sum(K**2, dim=1, keepdim=True)
        q_norm_squared = torch.sum(Q**2, dim=1, keepdim=True)
        dot_product = torch.bmm(K.transpose(1, 2), Q)

        dist_matrix = k_norm_squared.transpose(1, 2) + q_norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        return dist_matrix 

    def _calculate_cosine_matrix(self, K, Q):
        k_norm = F.normalize(K, p=2, dim=1)
        q_norm = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.matmul(k_norm.transpose(1, 2), q_norm)
        torch.diagonal(similarity_matrix, dim1=1, dim2=2).fill_(1.1)  # Fill diagonal with 1.1 to self-select
        return similarity_matrix

    def _calculate_cosine_matrix_N(self, K, Q):
        norm_k = F.normalize(K, p=2, dim=1)
        norm_q = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.matmul(norm_k.transpose(1, 2), norm_q)
        similarity_matrix = torch.softmax(similarity_matrix, dim=-1)
        return similarity_matrix

    def _prime(self, v, qk, K, maximum):
        b, c, t = v.shape
        topk_values, topk_indices = torch.topk(qk, k=K, dim=2, largest=maximum)
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)

        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        prime = topk_values_exp * prime 

        prime = prime.view(b, c, -1)

        return prime

    def _prime_temperature(self, v, qk, K, maximum, temperature=1.0):
        b, c, t = v.shape

        # Get top-k values and indices
        topk_values, topk_indices = torch.topk(qk, k=K, dim=2, largest=maximum)

        # Normalize the top-k values to create attention weights
        if maximum:  # Cosine similarity
            topk_weights = F.softmax(topk_values / temperature, dim=-1)
        else:  # Euclidean distance
            topk_weights = F.softmax(-topk_values / temperature, dim=-1)

        # Expand for gathering
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
        topk_weights_exp = topk_weights.unsqueeze(1).expand(b, c, t, K)

        # Gather and weight
        v_expanded = v.unsqueeze(-1).expand(b, c, t, K)
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        prime = prime * topk_weights_exp  # Now using normalized weights

        return prime.view(b, c, -1)

    def _prime_N(self, v, qk, K, rand_idx, maximum):
        b, c, t = v.shape
        topk_values, topk_indices = torch.topk(qk, k=K-1, dim=2, largest=maximum)
        tk = topk_indices.shape[-1]
        assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

        # Map sample indicies back to original matrix positions 
        mapped_tensor = rand_idx[topk_indices]
        token_indices = torch.arange(t, device=v.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=-1)
        topk_indices_exp = final_indices.unsqueeze(1).expand(b, c, t, K)

        # Expand topk values to match the shape of indices
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K-1)
        ones = torch.ones((b, c, t, 1), device=v.device)
        topk_values_exp = torch.cat((ones, topk_values_exp), dim=-1)

        # Gather matrix values and apply similarity weighting 
        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()    
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        prime = topk_values_exp * prime

        prime = prime.view(b, c, -1)
        return prime
    
    def _add_coordinate_encoding(self, x):
        b, c, t = x.shape 
        cache_key = f"{b}_{t}_{x.device}"
        if cache_key in self.coordinate_cache: 
            expanded_coords = self.coordinate_cache[cache_key]
        else: 
            coords_vec = torch.linspace(start=-1, end=1, steps=t, device=x.device).unsqueeze(0).expand(b, -1) 
            expanded_coords = coords_vec.unsqueeze(1).expand(b, -1, -1) 
            self.coordinate_cache[cache_key] = expanded_coords

        x_with_coords = torch.cat([x, expanded_coords], dim=1) 
        return x_with_coords 


In [14]:
# Test both implementations
d_hidden = 768
num_heads = 3
seq_length = 197
batch_size = 4

# Create input
x = torch.randn(batch_size, seq_length, d_hidden)
print("Input shape:", x.shape)
print("-"*50)

convnn_params = {
    "K": 9,
    "sampling_type": 'all',  # 'all', 'random', 'spatial'
    "num_samples": -1, 
    "sample_padding": 0, 
    "magnitude_type": 'cosine',  # 'euclidean' or 'cosine'
    "coordinate_encoding": False
}

ConvNN = MultiHeadConvNNAttention(d_hidden=d_hidden, num_heads=num_heads, attention_dropout=0.1, **convnn_params)
out = ConvNN(x)
print("-"*50)
print("Output shape:", out.shape)  # Expected: (B, seq_length, d_hidden)


Input shape: torch.Size([4, 197, 768])
--------------------------------------------------
--------------------------------------------------
Output shape: torch.Size([4, 197, 768])


#### Modified ConvNN Attention with change to number of heads 

In [15]:
# Working implementation 
class MultiHeadConvNNAttention_Modified(nn.Module):
    def __init__(self, 
                 d_hidden, 
                 num_heads, 
                 attention_dropout,
                 K, 
                 sampling_type, 
                 num_samples, 
                 sample_padding, 
                 magnitude_type, 
                 seq_length=197, 
                 coordinate_encoding=False
                 ):
        
        super(MultiHeadConvNNAttention_Modified, self).__init__()
        assert d_hidden % num_heads == 0, "d_hidden must be divisible by num_heads"

        # Core Parameters
        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.d_k = d_hidden // num_heads

        # ConvNN Parameters
        self.K = K
        self.seq_length = seq_length

        # 3 types of sampling: all, random, spatial
        self.sampling_type = sampling_type
        self.num_samples = int(num_samples) 
        self.sample_padding = int(sample_padding) if sampling_type == 'spatial' else 0    

        # Similarity Metric 
        self.magnitude_type = magnitude_type
        self.maximum = True if self.magnitude_type == 'cosine' else False

        # Coordinate Encoding (optional) 
        self.coordinate_encoding = coordinate_encoding
        self.coordinate_cache = {}

        # Change out_features of V projection 
        self.v_out_features = self.d_k // self.K
        print(self.v_out_features)
        
        # Linear projections for query, key, value
        self.W_q = nn.Linear(d_hidden, d_hidden)
        self.W_k = nn.Linear(d_hidden, d_hidden)
        self.W_v = nn.Linear(d_hidden, self.v_out_features)
        self.W_o = nn.Linear(d_hidden, d_hidden)   
        self.dropout = nn.Dropout(attention_dropout)

        # self.in_channels = (d_hidden // num_heads) + 1 if coordinate_encoding else (d_hidden // num_heads)
        self.in_channels = (self.v_out_features) + 1 if coordinate_encoding else (self.v_out_features)
        self.out_channels = (d_hidden) 
        
        self.conv = nn.Conv1d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.K,
            stride=self.K,
            padding=0,
        )

        # Utility Variables 
        self.INF = 1.1
        self.NEG_INF = -0.1 
        
    def forward(self, x):
        # Note: x shape: (B, seq_length, d_hidden)
        # 1. Splithead & Batch Combine
        k = self.W_k(x)
        v = self.W_v(x)

        print("k shape after W_k:", k.shape)
        print("v shape after W_v:", v.shape)

        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        print("k shape after transpose:", k.shape)
        print("v shape after transpose:", v.shape)


        # 2. Add Coordinate Encoding 
        k = self._add_coordinate_encoding(k) if self.coordinate_encoding else k
        v = self._add_coordinate_encoding(v) if self.coordinate_encoding else v


        # 3. Sampling & Similarity Calculation
        if self.sampling_type == 'all': # All Samples
            q = self.W_q(x)
            print("q shape after W_q:", q.shape)
            q = q.transpose(1, 2)
            print("q shape after transpose:", q.shape)

            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q

            similarity_matrix = self._calculate_cosine_matrix(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix(k, q, sqrt=True)
            print()
            print("similarity_matrix shape:", similarity_matrix.shape)

            # similarity_matrix = torch.softmax(similarity_matrix, dim=-1)
            
            prime = self._prime(v, similarity_matrix, self.K, self.maximum)
            print("prime shape after _prime:", prime.shape)
            # prime = self._prime_temperature(v, similarity_matrix, self.K, self.maximum, temperature=1) ## New Prime with Temperature Scaling

        elif self.sampling_type == 'random': # Random Samples
            rand_idx = torch.randperm(x.shape[1], device=x.device)[:self.num_samples]
            x_sample = x[:, rand_idx, :]            
            q = self.batch_combine(self.split_head(self.W_q(x_sample)))
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q

            similarity_matrix = self._calculate_cosine_matrix_N(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix_N(k, q, sqrt=True)
            range_idx = torch.arange(len(rand_idx), device=q.device)
            similarity_matrix[:, rand_idx, range_idx] = self.INF if self.magnitude_type == 'euclidean' else self.NEG_INF

            # similarity_matrix = torch.softmax(similarity_matrix, dim=-1)

            prime = self._prime_N(v, similarity_matrix, self.K, rand_idx, self.maximum)

        elif self.sampling_type == 'spatial': # Spatial Samples
            spat_idx = torch.linspace(0 + self.sample_padding, x.shape[1] - self.sample_padding - 1, self.num_samples, device=x.device).long()
            x_sample = x[:, spat_idx, :]
            q = self.batch_combine(self.split_head(self.W_q(x_sample)))
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q

            similarity_matrix = self._calculate_cosine_matrix_N(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix_N(k, q, sqrt=True)
            range_idx = torch.arange(len(spat_idx), device=q.device)
            similarity_matrix[:, spat_idx, range_idx] = self.INF if self.magnitude_type == 'euclidean' else self.NEG_INF

            # similarity_matrix = torch.softmax(similarity_matrix, dim=-1)

            prime = self._prime_N(v, similarity_matrix, self.K, spat_idx, self.maximum)
            
        else: 
            raise ValueError("Invalid sampling_type. Must be one of ['all', 'random', 'spatial']")
        print()
        # 4. Conv1d Layer
        x = self.conv(prime)  
        print("x shape after conv:", x.shape)

        # 5. Dropout + Reshape (B, seq_length, d_hidden)
        x = self.dropout(x)
        x = x.permute(0, 2, 1) 
        print("x shape after permute:", x.shape)

        # 6. Final Linear Projection
        
        x = self.W_o(x)
        print("x shape after W_o:", x.shape)
        print()
        return x       

    def _calculate_euclidean_matrix(self, K, Q, sqrt=False):
        k_norm_squared = torch.sum(K**2, dim=1, keepdim=True)
        q_norm_squared = torch.sum(Q**2, dim=1, keepdim=True)
        dot_product = torch.bmm(K.transpose(1, 2), Q)

        dist_matrix = k_norm_squared.transpose(1, 2) + q_norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        torch.diagonal(dist_matrix, dim1=1, dim2=2).fill_(-0.1) 
        return dist_matrix 

    def _calculate_euclidean_matrix_N(self, K, Q, sqrt=False):
        k_norm_squared = torch.sum(K**2, dim=1, keepdim=True)
        q_norm_squared = torch.sum(Q**2, dim=1, keepdim=True)
        dot_product = torch.bmm(K.transpose(1, 2), Q)

        dist_matrix = k_norm_squared.transpose(1, 2) + q_norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        return dist_matrix 

    def _calculate_cosine_matrix(self, K, Q):
        k_norm = F.normalize(K, p=2, dim=1)
        q_norm = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.matmul(k_norm.transpose(1, 2), q_norm)
        torch.diagonal(similarity_matrix, dim1=1, dim2=2).fill_(1.1)  # Fill diagonal with 1.1 to self-select
        return similarity_matrix

    def _calculate_cosine_matrix_N(self, K, Q):
        norm_k = F.normalize(K, p=2, dim=1)
        norm_q = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.matmul(norm_k.transpose(1, 2), norm_q)
        similarity_matrix = torch.softmax(similarity_matrix, dim=-1)
        return similarity_matrix

    def _prime(self, v, qk, K, maximum):
        print("[Inside _prime]")
        b, c, t = v.shape
    
        print("v shape:", v.shape)

        topk_values, topk_indices = torch.topk(qk, k=K, dim=2, largest=maximum)
        print("topk_values shape:", topk_values.shape)
        print("topk_indices shape:", topk_indices.shape)
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)

        print("topk_indices_exp shape:", topk_indices_exp.shape)
        print("topk_values_exp shape:", topk_values_exp.shape)

        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()
        print("v_expanded shape:", v_expanded.shape)
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        print("prime shape after gather:", prime.shape)
        prime = topk_values_exp * prime 

        prime = prime.view(b, c, -1)

        return prime

    def _prime_temperature(self, v, qk, K, maximum, temperature=1.0):
        b, c, t = v.shape

        # Get top-k values and indices
        topk_values, topk_indices = torch.topk(qk, k=K, dim=2, largest=maximum)

        # Normalize the top-k values to create attention weights
        if maximum:  # Cosine similarity
            topk_weights = F.softmax(topk_values / temperature, dim=-1)
        else:  # Euclidean distance
            topk_weights = F.softmax(-topk_values / temperature, dim=-1)

        # Expand for gathering
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
        topk_weights_exp = topk_weights.unsqueeze(1).expand(b, c, t, K)

        # Gather and weight
        v_expanded = v.unsqueeze(-1).expand(b, c, t, K)
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        prime = prime * topk_weights_exp  # Now using normalized weights

        return prime.view(b, c, -1)

    def _prime_N(self, v, qk, K, rand_idx, maximum):
        b, c, t = v.shape
        topk_values, topk_indices = torch.topk(qk, k=K-1, dim=2, largest=maximum)
        tk = topk_indices.shape[-1]
        assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

        # Map sample indicies back to original matrix positions 
        mapped_tensor = rand_idx[topk_indices]
        token_indices = torch.arange(t, device=v.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=-1)
        topk_indices_exp = final_indices.unsqueeze(1).expand(b, c, t, K)

        # Expand topk values to match the shape of indices
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K-1)
        ones = torch.ones((b, c, t, 1), device=v.device)
        topk_values_exp = torch.cat((ones, topk_values_exp), dim=-1)

        # Gather matrix values and apply similarity weighting 
        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()    
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        prime = topk_values_exp * prime

        prime = prime.view(b, c, -1)
        return prime
    
    def _add_coordinate_encoding(self, x):
        b, c, t = x.shape 
        cache_key = f"{b}_{t}_{x.device}"
        if cache_key in self.coordinate_cache: 
            expanded_coords = self.coordinate_cache[cache_key]
        else: 
            coords_vec = torch.linspace(start=-1, end=1, steps=t, device=x.device).unsqueeze(0).expand(b, -1) 
            expanded_coords = coords_vec.unsqueeze(1).expand(b, -1, -1) 
            self.coordinate_cache[cache_key] = expanded_coords

        x_with_coords = torch.cat([x, expanded_coords], dim=1) 
        return x_with_coords 


In [16]:
# Test both implementations
d_hidden = 768
num_heads = 3
seq_length = 197
batch_size = 4

# Create input
x = torch.randn(batch_size, seq_length, d_hidden)
print("Input shape:", x.shape)
print("-"*50)

convnn_params = {
    "K": 9,
    "sampling_type": 'all',  # 'all', 'random', 'spatial'
    "num_samples": -1, 
    "sample_padding": 0, 
    "magnitude_type": 'cosine',  # 'euclidean' or 'cosine'
    "coordinate_encoding": False
}

ConvNN = MultiHeadConvNNAttention_Modified(d_hidden=d_hidden, num_heads=num_heads, attention_dropout=0.1, **convnn_params)
out = ConvNN(x)
print("-"*50)
print("Output shape:", out.shape)  # Expected: (B, seq_length, d_hidden)


Input shape: torch.Size([4, 197, 768])
--------------------------------------------------
28
k shape after W_k: torch.Size([4, 197, 768])
v shape after W_v: torch.Size([4, 197, 28])
k shape after transpose: torch.Size([4, 768, 197])
v shape after transpose: torch.Size([4, 28, 197])
q shape after W_q: torch.Size([4, 197, 768])
q shape after transpose: torch.Size([4, 768, 197])

similarity_matrix shape: torch.Size([4, 197, 197])
[Inside _prime]
v shape: torch.Size([4, 28, 197])
topk_values shape: torch.Size([4, 197, 9])
topk_indices shape: torch.Size([4, 197, 9])
topk_indices_exp shape: torch.Size([4, 28, 197, 9])
topk_values_exp shape: torch.Size([4, 28, 197, 9])
v_expanded shape: torch.Size([4, 28, 197, 9])
prime shape after gather: torch.Size([4, 28, 197, 9])
prime shape after _prime: torch.Size([4, 28, 1773])

x shape after conv: torch.Size([4, 768, 197])
x shape after permute: torch.Size([4, 197, 768])
x shape after W_o: torch.Size([4, 197, 768])

-----------------------------------

# Oct 11 Make ConvNN Attention same as MultiheadAttention 

In [17]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np
from einops import rearrange

### i. MultiheadAttention

In [18]:

"""Multi-Head Layers for Transformer Encoder"""
class MultiHeadAttention(nn.Module): 
    def __init__(self, d_hidden, num_heads, attention_dropout):
        super(MultiHeadAttention, self).__init__()
        assert d_hidden % num_heads == 0, "d_hidden must be divisible by num_heads"
        
        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.d_k = d_hidden // num_heads # dimension of each head
        self.dropout = nn.Dropout(attention_dropout)
        
        self.W_q = nn.Linear(d_hidden, d_hidden, bias=False)
        self.W_k = nn.Linear(d_hidden, d_hidden, bias=False)
        self.W_v = nn.Linear(d_hidden, d_hidden, bias=False)
        self.W_o = nn.Linear(d_hidden, d_hidden, bias=False)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.d_k)
        print("\n[Scaled Dot Product Attention]")
        print("attn_scores shape:", attn_scores.shape)
        print("attn_scores:", attn_scores)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        
        attn_probs = self.dropout(torch.softmax(attn_scores, dim=-1))
        print("attn_probs shape:", attn_probs.shape)
        print("attn_probs:", attn_probs)
        print("v shape: ", V.shape)
        print("v: ", V)
        output = torch.matmul(attn_probs, V)
        print("output shape:", output.shape)
        print("output:", output)
        return output, attn_probs
    
    def split_head(self, x): 
        batch_size, seq_length, d_hidden = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2) # (B, num_heads, seq_length, d_k)
        
    def combine_heads(self, x): 
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_hidden) 
    
    def forward(self, x, mask=None):
        q = self.split_head(self.W_q(x)) # (B, num_heads, seq_length, d_k)
        k = self.split_head(self.W_k(x))
        v = self.split_head(self.W_v(x))
        
        attn_output, _ = self.scaled_dot_product_attention(q, k, v, mask) # (B, num_heads, seq_length, d_k)
        output = self.W_o(self.combine_heads(attn_output)) # (B, seq_length, d_hidden)
        return output


In [19]:
d_hidden = 8 
num_heads = 1
seq_length = 6
batch_size = 1
dropout = 0.0 

x = torch.Tensor(
    [
        [
            [ 1,  2,  3,  4,  5,  6,  7,  8],
            [ 9, 10, 11, 12, 13, 14, 15, 16],
            [17, 18, 19, 20, 21, 22, 23, 24],
            [25, 26, 27, 28, 29, 30, 31, 32],
            [33, 34, 35, 36, 37, 38, 39, 40],
            [41, 42, 43, 44, 45, 46, 47, 48],
        ]
    ]
).to(torch.float32)

# x = torch.randint(0, 10, (batch_size, seq_length, d_hidden)).to(torch.float32)
# x = torch.randn(batch_size, seq_length, d_hidden).to(torch.float32)
print(x)

print(x.shape)

attention = MultiHeadAttention(d_hidden=d_hidden, num_heads=num_heads, attention_dropout=dropout)
attention.W_k.weight.data.fill_(1.0)
attention.W_q.weight.data.fill_(1.0)
attention.W_v.weight.data.fill_(1.0)
attention.W_o.weight.data.fill_(1.0)
out = attention(x)
print()


tensor([[[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.],
         [ 9., 10., 11., 12., 13., 14., 15., 16.],
         [17., 18., 19., 20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29., 30., 31., 32.],
         [33., 34., 35., 36., 37., 38., 39., 40.],
         [41., 42., 43., 44., 45., 46., 47., 48.]]])
torch.Size([1, 6, 8])

[Scaled Dot Product Attention]
attn_scores shape: torch.Size([1, 1, 6, 6])
attn_scores: tensor([[[[  3665.6416,  10182.3379,  16699.0332,  23215.7305,  29732.4258,
            36249.1211],
          [ 10182.3379,  28284.2715,  46386.2070,  64488.1406,  82590.0703,
           100692.0078],
          [ 16699.0332,  46386.2070,  76073.3750, 105760.5469, 135447.7188,
           165134.8906],
          [ 23215.7305,  64488.1406, 105760.5469, 147032.9531, 188305.3750,
           229577.7812],
          [ 29732.4258,  82590.0703, 135447.7188, 188305.3750, 241163.0156,
           294020.6562],
          [ 36249.1211, 100692.0078, 165134.8906, 229577.7812, 294020.6562,


### ii. ConvNN Attention 

In [20]:

class MultiHeadConvNNAttention(nn.Module):
    def __init__(self, 
                 d_hidden, 
                 num_heads, 
                 attention_dropout,
                 K, 
                 sampling_type, 
                 num_samples, 
                 sample_padding, 
                 magnitude_type, 
                 seq_length=197, 
                 coordinate_encoding=False
                 ):
        
        super(MultiHeadConvNNAttention, self).__init__()
        assert d_hidden % num_heads == 0, "d_hidden must be divisible by num_heads"

        # Core Parameters
        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.d_k = d_hidden // num_heads

        # ConvNN Parameters
        self.K = K
        self.seq_length = seq_length

        # 3 types of sampling: all, random, spatial
        self.sampling_type = sampling_type
        self.num_samples = int(num_samples) 
        self.sample_padding = int(sample_padding) if sampling_type == 'spatial' else 0    

        # Similarity Metric 
        self.magnitude_type = magnitude_type
        self.maximum = True if self.magnitude_type == 'cosine' else False

        # Coordinate Encoding (optional) 
        self.coordinate_encoding = coordinate_encoding
        self.coordinate_cache = {}
        
        # Linear projections for query, key, value
        self.W_q = nn.Linear(d_hidden, d_hidden, bias=False)
        self.W_k = nn.Linear(d_hidden, d_hidden, bias=False)
        self.W_v = nn.Linear(d_hidden, d_hidden, bias=False)
        self.W_o = nn.Linear(d_hidden, d_hidden, bias=False)
        self.dropout = nn.Dropout(attention_dropout)

        self.in_channels = (d_hidden // num_heads) + 1 if coordinate_encoding else (d_hidden // num_heads)
        self.out_channels = (d_hidden // num_heads) 
        
        self.conv = nn.Conv1d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.K,
            stride=self.K,
            padding=0,
            bias = False, 
            groups=self.in_channels
        )

        # Utility Variables 
        self.INF = 1.1
        self.NEG_INF = -0.1 
        
    def split_head(self, x): 
        batch_size, seq_length, d_hidden = x.size()
        self.batch_size = batch_size
        # self.seq_length = seq_length
        return x.contiguous().view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2) # (B, num_heads, seq_length, d_k)
        
    def combine_heads(self, x): 
        
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_hidden) 
    
    def batch_split(self, x): 
        x = x.reshape(self.batch_size, -1, self.d_k, self.seq_length)
        return x.permute(0, 1, 3, 2).contiguous()
        
    def batch_combine(self, x): 
        batch_size, _, seq_length, d_k = x.size()
        x = x.permute(0, 1, 3, 2).contiguous() 
        return x.view(-1, self.d_k, seq_length)
        
    def forward(self, x):
        # Note: x shape: (B, seq_length, d_hidden)
        # 1. Splithead & Batch Combine
        k = self.batch_combine(self.split_head(self.W_k(x)))
        v = self.batch_combine(self.split_head(self.W_v(x)))
        # v = self.batch_combine(self.split_head(x))
        
        # 3. Sampling & Similarity Calculation
        if self.sampling_type == 'all': # All Samples
            q = self.batch_combine(self.split_head(self.W_q(x)))

            similarity_matrix = self._calculate_attention_matrix(k, q)

            similarity_matrix = torch.softmax(similarity_matrix, dim=-1)
            
            prime = self._prime(v, similarity_matrix, self.K, self.maximum)

        # 4. Conv1d Layer
        x = self.conv(prime)  

        print("\n[After Convolution]")
        print("output size: ", x.shape) 
        print(x)

        # 5. Dropout + Reshape (B, seq_length, d_hidden)
        x = self.dropout(x)
        x = x.permute(0, 2, 1) 

        # 6. Final Linear Projection
        x = self.W_o(self.combine_heads(self.batch_split(x)))
        return x       
    def _calculate_attention_matrix(self, K, Q):
        attn_score = torch.matmul(K.transpose(1, 2), Q) / np.sqrt(self.d_k)
        return attn_score

    def _calculate_euclidean_matrix(self, K, Q, sqrt=False):
        k_norm_squared = torch.sum(K**2, dim=1, keepdim=True)
        q_norm_squared = torch.sum(Q**2, dim=1, keepdim=True)
        dot_product = torch.bmm(K.transpose(1, 2), Q)

        dist_matrix = k_norm_squared.transpose(1, 2) + q_norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        torch.diagonal(dist_matrix, dim1=1, dim2=2).fill_(-0.1) 
        return dist_matrix 

    def _calculate_euclidean_matrix_N(self, K, Q, sqrt=False):
        k_norm_squared = torch.sum(K**2, dim=1, keepdim=True)
        q_norm_squared = torch.sum(Q**2, dim=1, keepdim=True)
        dot_product = torch.bmm(K.transpose(1, 2), Q)

        dist_matrix = k_norm_squared.transpose(1, 2) + q_norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        return dist_matrix 

    def _calculate_cosine_matrix(self, K, Q):
        k_norm = F.normalize(K, p=2, dim=1)
        q_norm = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.matmul(k_norm.transpose(1, 2), q_norm)
        torch.diagonal(similarity_matrix, dim1=1, dim2=2).fill_(1.1)  # Fill diagonal with 1.1 to self-select
        return similarity_matrix

    def _calculate_cosine_matrix_N(self, K, Q):
        norm_k = F.normalize(K, p=2, dim=1)
        norm_q = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.matmul(norm_k.transpose(1, 2), norm_q)
        similarity_matrix = torch.softmax(similarity_matrix, dim=-1)
        return similarity_matrix

    def _prime(self, v, qk, K, maximum):
        print("Before _prime:")
        print("v shape:", v.shape)
        print("similarity_matrix shape:", qk.shape)
        print("similarity_matrix:", qk)
        b, c, t = v.shape
        
        print("Inside Prime")   
        print("v shape:", v.shape)
        print("K:", K)
        topk_values, topk_indices = torch.topk(qk, k=K, dim=2, largest=maximum)
        print("topk_indices shape:", topk_indices.shape)
        print("topk_indices:", topk_indices, "\n")
        print("topk_values shape:", topk_values.shape)
        print("topk_values:", topk_values)

        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)

        topk_values_exp = torch.softmax(topk_values_exp, dim=-1)
        print("topk_values_exp shape:", topk_values_exp.shape)
        print("topk_values_exp:", topk_values_exp)

        # #### SOFTMAX ON TOP-K VALUES ####
        # topk_values_exp = torch.softmax(topk_values_exp, dim=-1)        
        # # print(topk_values_exp.shape, topk_indices_exp.shape)

        print("\n\n")
        print("v unsqueeze shape: ", v.unsqueeze(1).shape)
        print("v unsqueeze(1): ", v.unsqueeze(1))
        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        print("prime shape after gather:", prime.shape)
        print("prime after gather:", prime)
        print("topk_values_exp shape:", topk_values_exp.shape)
        print("topk_values_exp:", topk_values_exp)
        prime = topk_values_exp * prime 
        
        print("prime shape after weighting:", prime.shape)

        prime = prime.view(b, c, -1)
        print("prime shape after view:", prime.shape)
        print("prime after view:", prime)

        return prime


In [21]:
d_hidden = 8 
num_heads = 1
seq_length = 6
batch_size = 1
dropout = 0.0 

# x = torch.Tensor(
#     [
#         [
#             [ 1,  2,  3,  4,  5,  6,  7,  8],
#             [ 9, 10, 11, 12, 13, 14, 15, 16],
#             [17, 18, 19, 20, 21, 22, 23, 24],
#             [25, 26, 27, 28, 29, 30, 31, 32],
#             [33, 34, 35, 36, 37, 38, 39, 40],
#             [41, 42, 43, 44, 45, 46, 47, 48],
#         ]
#     ]
# ).to(torch.float32)

print(x.shape)

convnn = MultiHeadConvNNAttention(d_hidden=d_hidden, 
                                  num_heads=num_heads, 
                                  attention_dropout=dropout, 
                                  K=6, 
                                  sampling_type='all',
                                  num_samples=-1,
                                  sample_padding=0,
                                  magnitude_type='cosine',
                                  coordinate_encoding=False,
                                  seq_length=6
                                  )

convnn.W_k.weight.data.fill_(1.0)
convnn.W_q.weight.data.fill_(1.0)
convnn.W_v.weight.data.fill_(1.0)
convnn.W_o.weight.data.fill_(1.0)
convnn.conv.weight.data.fill_(1.0)
out = convnn(x)




torch.Size([1, 6, 8])
Before _prime:
v shape: torch.Size([1, 8, 6])
similarity_matrix shape: torch.Size([1, 6, 6])
similarity_matrix: tensor([[[0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 1.]]], grad_fn=<SoftmaxBackward0>)
Inside Prime
v shape: torch.Size([1, 8, 6])
K: 6
topk_indices shape: torch.Size([1, 6, 6])
topk_indices: tensor([[[5, 1, 2, 3, 4, 0],
         [5, 1, 2, 3, 4, 0],
         [5, 1, 2, 3, 4, 0],
         [5, 1, 2, 3, 4, 0],
         [5, 1, 2, 3, 4, 0],
         [5, 1, 2, 3, 4, 0]]]) 

topk_values shape: torch.Size([1, 6, 6])
topk_values: tensor([[[1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.]]], grad_fn=<TopkBackward0>)
topk_values_exp shape: torch.Size([1, 8, 6, 6])
topk

In [22]:

# difference: convnn out - attention = 4.1027