In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadConvNNAttentionEnhanced(nn.Module):
    """Multi-head ConvNN attention with head-specific sampling and coordinate encoding.

    For every token the value vectors of its nearest neighbours (found by
    comparing keys against queries) are laid out side by side and reduced back
    to one vector per token by a ``Conv1d`` with ``kernel_size == stride == K``
    that is shared by all heads.

    * ``sampling_type == 'all'``: every token is a neighbour candidate and the
      ``K`` best candidates are gathered.
    * ``sampling_type in ('random', 'spatial')``: only ``num_samples`` query
      positions are candidates; each token keeps itself plus its ``K - 1``
      best candidates.

    Args:
        d_hidden: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        attention_dropout: Dropout probability applied after the shared conv.
        K: Neighbourhood size (also the conv kernel size and stride).
        sampling_type: ``'all'``, ``'random'`` or ``'spatial'``.
        num_samples: Number of sampled candidates (``-1`` is stored as ``'all'``).
        sample_padding: Margin kept free at both ends for spatial sampling.
        magnitude_type: ``'similarity'`` (largest wins) or ``'distance'``
            (smallest wins).
        seq_length: Nominal sequence length (kept for API compatibility).
        coordinate_encoding: Append one positional channel to q, k and v.
        diverse_sampling: Draw a different candidate set for every head.
    """

    def __init__(self, d_hidden, num_heads, attention_dropout, K, sampling_type,
                 num_samples, sample_padding, magnitude_type, seq_length=197,
                 coordinate_encoding=False, diverse_sampling=True):
        super(MultiHeadConvNNAttentionEnhanced, self).__init__()
        assert d_hidden % num_heads == 0, "d_hidden must be divisible by num_heads"

        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.d_k = d_hidden // num_heads

        self.seq_length = seq_length
        self.K = K
        self.sampling_type = sampling_type
        self.num_samples = num_samples if num_samples != -1 else 'all'
        self.sample_padding = sample_padding if sampling_type == 'spatial' else 0
        self.magnitude_type = magnitude_type
        # topk(largest=...) flag: similarities are maximised, distances minimised.
        self.maximum = self.magnitude_type == 'similarity'
        self.diverse_sampling = diverse_sampling  # Enable head-specific sampling

        # Coordinate Encoding
        self.coordinate_encoding = coordinate_encoding
        self.coordinate_cache = {}  # kept for compatibility; not read by this class

        # Linear projections
        self.W_q = nn.Linear(d_hidden, d_hidden)
        self.W_k = nn.Linear(d_hidden, d_hidden)
        self.W_v = nn.Linear(d_hidden, d_hidden)
        self.W_o = nn.Linear(d_hidden, d_hidden)
        self.dropout = nn.Dropout(attention_dropout)

        # One extra channel carries the coordinate when encoding is enabled.
        self.in_channels = self.d_k + 1 if coordinate_encoding else self.d_k
        self.out_channels = self.d_k + 1 if coordinate_encoding else self.d_k
        self.kernel_size = K
        self.stride = K

        # Shared Conv across heads
        self.conv = nn.Conv1d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=0,
        )

        # Pointwise conv that maps the (d_k + 1) channels back to d_k.
        if coordinate_encoding:
            self.pointwise_conv = nn.Conv1d(
                in_channels=self.in_channels,
                out_channels=self.out_channels - 1,
                kernel_size=1
            )

        # Head mixing layer for cross-head interaction
        self.head_mixing = nn.Conv1d(d_hidden, d_hidden, 1, groups=1)
        self.mix_weight = nn.Parameter(torch.tensor([0.1]))  # Learnable mixing weight

    def get_head_specific_samples(self, seq_len, num_heads, device):
        """Return a ``(num_heads, num_samples)`` tensor of candidate positions.

        Raises:
            ValueError: If ``sampling_type`` is neither ``'random'`` nor ``'spatial'``.
        """
        if self.sampling_type == 'random':
            # Independent permutation per head.
            per_head = [
                torch.randperm(seq_len, device=device)[:self.num_samples]
                for _ in range(num_heads)
            ]
            return torch.stack(per_head)  # (H, num_samples)

        if self.sampling_type == 'spatial':
            end = seq_len - self.sample_padding - 1
            per_head = []
            for h in range(num_heads):
                # Stagger the starting point of the evenly spaced grid per head.
                offset = h * (seq_len // (num_heads * 2))
                start = (offset + self.sample_padding) % seq_len
                if start >= end:
                    # Degenerate range: fall back to the unshifted grid.
                    start = self.sample_padding
                spat_idx = torch.linspace(start, end, self.num_samples, device=device).long()
                per_head.append(torch.clamp(spat_idx, 0, seq_len - 1))  # Safety clamp
            return torch.stack(per_head)  # (H, num_samples)

        raise ValueError(f"Invalid sampling_type: {self.sampling_type}")

    def split_head(self, x):
        """Reshape ``(B, L, D)`` to ``(B, H, L, d_k)`` and remember ``B`` on ``self``."""
        batch_size, seq_length, d_hidden = x.size()
        # batch_split and the coordinate encoder read this cached batch size.
        self.batch_size = batch_size
        return x.contiguous().view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Reshape ``(B, H, L, d_k)`` back to ``(B, L, D)``."""
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_hidden)

    def batch_split(self, x):
        """Inverse of :meth:`batch_combine`: ``(B*H, d_k, L)`` -> ``(B, H, L, d_k)``.

        The length is read from ``x`` so inputs whose length differs from the
        constructor's ``seq_length`` are handled as well.
        """
        x = x.reshape(self.batch_size, -1, self.d_k, x.size(-1))
        return x.permute(0, 1, 3, 2).contiguous()

    def batch_combine(self, x):
        """Fold heads into the batch axis: ``(B, H, L, d_k)`` -> ``(B*H, d_k, L)``."""
        batch_size, _, seq_length, d_k = x.size()
        x = x.permute(0, 1, 3, 2).contiguous()
        return x.view(-1, self.d_k, seq_length)

    def _add_coordinate_encoding_multihead(self, x, head_offset=0):
        """Append one head-specific coordinate channel to ``x`` of shape ``(b, c, t)``.

        Rows of ``x`` are assumed to cycle through heads (row ``i`` belongs to
        head ``i % (b // batch_size)``), which is the layout produced by
        :meth:`batch_combine`. ``head_offset`` is added to that index, so a
        tensor holding a single head (``b == batch_size``) can be tagged with
        its real head number.

        Returns:
            Tensor of shape ``(b, c + 1, t)``.
        """
        b, _, t = x.shape
        heads_in_batch = max(b // self.batch_size, 1)

        per_head = {}
        rows = []
        for i in range(b):
            head_idx = (i % heads_in_batch) + head_offset
            if head_idx not in per_head:
                frac = head_idx / self.num_heads
                scale = 1.0 + 0.2 * frac        # Varying scales
                phase = 2 * torch.pi * frac     # Phase shift
                coords = torch.linspace(-scale, scale, t, device=x.device)
                # Sinusoidal variation per head on top of the linear ramp.
                per_head[head_idx] = coords + 0.1 * torch.sin(coords * torch.pi + phase)
            rows.append(per_head[head_idx])

        coords_tensor = torch.stack(rows).unsqueeze(1)  # (b, 1, t)
        return torch.cat((x, coords_tensor), dim=1)

    def _magnitude(self, k, q, sampled):
        """Dispatch to the distance or similarity matrix for keys ``k`` and queries ``q``."""
        if self.magnitude_type == 'distance':
            fn = self._calculate_distance_matrix_N if sampled else self._calculate_distance_matrix
            return fn(k, q, sqrt=True)
        fn = self._calculate_similarity_matrix_N if sampled else self._calculate_similarity_matrix
        return fn(k, q)

    def forward(self, x):
        """Apply ConvNN attention to ``x`` of shape ``(B, L, d_hidden)``.

        Returns:
            Tensor of shape ``(B, L, d_hidden)``.

        Raises:
            ValueError: If ``sampling_type`` is not one of the supported modes.
        """
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)

        if self.sampling_type == 'all':
            # Heads are folded into the batch axis: (B*H, d_k, L).
            q = self.batch_combine(self.split_head(q))
            k = self.batch_combine(self.split_head(k))
            v = self.batch_combine(self.split_head(v))

            if self.coordinate_encoding:
                q = self._add_coordinate_encoding_multihead(q)
                k = self._add_coordinate_encoding_multihead(k)
                v = self._add_coordinate_encoding_multihead(v)

            matrix_magnitude = self._magnitude(k, q, sampled=False)
            prime = self._prime(v, matrix_magnitude, self.K, self.maximum)

        elif self.sampling_type in ('random', 'spatial'):
            seq_len = x.shape[1]

            if self.diverse_sampling:
                sample_indices = self.get_head_specific_samples(
                    seq_len, self.num_heads, x.device
                )  # (H, num_samples)
            else:
                # Same candidate set for all heads.
                if self.sampling_type == 'random':
                    base_idx = torch.randperm(seq_len, device=x.device)[:self.num_samples]
                else:  # spatial
                    base_idx = torch.linspace(
                        self.sample_padding,
                        seq_len - self.sample_padding - 1,
                        self.num_samples,
                        device=x.device
                    ).long()
                sample_indices = base_idx.unsqueeze(0).repeat(self.num_heads, 1)

            q_heads = self.split_head(q)  # (B, H, L, D/H)
            k_heads = self.split_head(k)  # (B, H, L, D/H)
            v_heads = self.split_head(v)  # (B, H, L, D/H)

            # A candidate must never pick itself: it is already slot 0 of the
            # neighbourhood built by _prime_N.
            excluded = float('inf') if self.magnitude_type == 'distance' else float('-inf')

            prime_list = []
            for h in range(self.num_heads):
                idx = sample_indices[h]

                q_h = q_heads[:, h, :, :].transpose(1, 2)  # (B, D/H, L)
                k_h = k_heads[:, h, :, :].transpose(1, 2)  # (B, D/H, L)
                v_h = v_heads[:, h, :, :].transpose(1, 2)  # (B, D/H, L)

                if self.coordinate_encoding:
                    q_h = self._add_coordinate_encoding_multihead(q_h, head_offset=h)
                    k_h = self._add_coordinate_encoding_multihead(k_h, head_offset=h)
                    v_h = self._add_coordinate_encoding_multihead(v_h, head_offset=h)

                # Sample AFTER encoding so a sampled query carries the
                # coordinate of its real position, matching the keys.
                q_h = q_h[:, :, idx]  # (B, C, num_samples)

                matrix_magnitude = self._magnitude(k_h, q_h, sampled=True)  # (B, L, num_samples)

                range_idx = torch.arange(len(idx), device=x.device)
                matrix_magnitude[:, idx, range_idx] = excluded

                prime_list.append(
                    self._prime_N(v_h, matrix_magnitude, self.K, idx, self.maximum)
                )

            # Stack batch-major (row = b * H + h) so the layout matches
            # batch_combine / batch_split; concatenating along dim 0 would be
            # head-major and mix batch elements on the way back.
            prime = torch.stack(prime_list, dim=1).flatten(0, 1)  # (B*H, C, L*K)

        else:
            raise ValueError(f"Invalid sampling_type: {self.sampling_type}")

        # kernel_size == stride == K collapses every K-neighbourhood to one token.
        x = self.conv(prime)
        x = self.dropout(x)

        # Remove coordinate channel if needed
        if self.coordinate_encoding:
            x = self.pointwise_conv(x)

        # x is (B*H, d_k, L), exactly what batch_split expects. Permuting it
        # first would make the reshape reinterpret (L, d_k) memory as (d_k, L).
        x = self.combine_heads(self.batch_split(x))  # (B, L, D)

        # Mix information across heads with a learnable residual weight.
        x_mixed = self.head_mixing(x.transpose(1, 2)).transpose(1, 2)
        x = x + self.mix_weight * x_mixed

        return self.W_o(x)

    def _calculate_similarity_matrix(self, K, Q):
        """Non-negative cosine similarity between key and query columns.

        ``K`` is ``(b, c, Lk)`` and ``Q`` is ``(b, c, Lq)``; entry ``[b, i, j]``
        compares key ``i`` with query ``j``. Negative values are clamped to 0.
        """
        k_norm = F.normalize(K, p=2, dim=1)
        q_norm = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.bmm(k_norm.transpose(2, 1), q_norm)
        return torch.clamp(similarity_matrix, min=0)

    def _calculate_similarity_matrix_N(self, K, Q):
        """Sampled-query variant; the computation is identical to the full one."""
        return self._calculate_similarity_matrix(K, Q)

    def _calculate_distance_matrix(self, K, Q, sqrt=False):
        """(Squared) Euclidean distance between key and query columns.

        Entry ``[b, i, j]`` is ``|K_i|^2 + |Q_j|^2 - 2 K_i.Q_j``: the key norm
        broadcasts along rows and the query norm along columns, which matters
        because K and Q come from different projections.
        """
        k_sq = torch.sum(K ** 2, dim=1).unsqueeze(2)  # (b, Lk, 1)
        q_sq = torch.sum(Q ** 2, dim=1).unsqueeze(1)  # (b, 1, Lq)
        dot_product = torch.bmm(K.transpose(2, 1), Q)  # (b, Lk, Lq)
        # Clamp guards against tiny negatives from floating-point cancellation.
        dist_matrix = torch.clamp(k_sq + q_sq - 2 * dot_product, min=0)
        return torch.sqrt(dist_matrix) if sqrt else dist_matrix

    def _calculate_distance_matrix_N(self, K, Q, sqrt=False):
        """Sampled-query variant; the computation is identical to the full one."""
        return self._calculate_distance_matrix(K, Q, sqrt=sqrt)

    def _prime(self, v, qk, K, maximum):
        """Gather the ``K`` best neighbours of every token: ``(b, c, t)`` -> ``(b, c, t*K)``."""
        b, c, t = v.shape
        _, topk_indices = torch.topk(qk, k=K, dim=-1, largest=maximum)
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
        v_expanded = v.unsqueeze(-1).expand(b, c, t, K)
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        return prime.reshape(b, c, -1)

    def _prime_N(self, v, qk, K, rand_idx, maximum):
        """Gather each token plus its ``K - 1`` best sampled neighbours.

        ``qk`` is ``(b, t, num_samples)``; top-k positions index into the sample
        set and are mapped back to token positions through ``rand_idx``.
        """
        b, c, t = v.shape
        _, topk_indices = torch.topk(qk, k=K - 1, dim=2, largest=maximum)
        tk = topk_indices.shape[-1]
        assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."
        mapped_tensor = rand_idx[topk_indices]
        # Slot 0 of every neighbourhood is the token itself.
        token_indices = torch.arange(t, device=v.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=2)
        indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)
        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(v_expanded, dim=2, index=indices_expanded)
        return prime.reshape(b, c, -1)

In [3]:
class MultiHeadConvNNAttentionSeparate(nn.Module):
    """Multi-head ConvNN attention where every head owns its ConvNN parameters.

    Each head gathers, for every token, the value vectors of its nearest
    neighbours (keys compared against queries) and reduces them with its own
    ``Conv1d`` (``kernel_size == stride == K``).

    Args:
        d_hidden: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        attention_dropout: Dropout probability applied to the combined heads.
        K: Neighbourhood size (also the conv kernel size and stride).
        sampling_type: ``'all'``, ``'random'`` or ``'spatial'``.
        num_samples: Number of sampled candidates (``-1`` is stored as ``'all'``).
        sample_padding: Margin kept free at both ends for spatial sampling.
        magnitude_type: ``'similarity'`` (largest wins) or ``'distance'``
            (smallest wins).
        seq_length: Nominal sequence length (stored, not otherwise used here).
        coordinate_encoding: Append one positional channel to q, k and v.
    """

    def __init__(self, d_hidden, num_heads, attention_dropout, K, sampling_type,
                 num_samples, sample_padding, magnitude_type, seq_length=197,
                 coordinate_encoding=False):
        super(MultiHeadConvNNAttentionSeparate, self).__init__()
        assert d_hidden % num_heads == 0, "d_hidden must be divisible by num_heads"

        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.d_k = d_hidden // num_heads

        self.seq_length = seq_length
        self.K = K
        self.sampling_type = sampling_type
        self.num_samples = num_samples if num_samples != -1 else 'all'
        self.sample_padding = sample_padding if sampling_type == 'spatial' else 0
        self.magnitude_type = magnitude_type
        # topk(largest=...) flag: similarities are maximised, distances minimised.
        self.maximum = self.magnitude_type == 'similarity'

        # Coordinate Encoding
        self.coordinate_encoding = coordinate_encoding

        # Linear projections
        self.W_q = nn.Linear(d_hidden, d_hidden)
        self.W_k = nn.Linear(d_hidden, d_hidden)
        self.W_v = nn.Linear(d_hidden, d_hidden)
        self.W_o = nn.Linear(d_hidden, d_hidden)
        self.dropout = nn.Dropout(attention_dropout)

        # One extra channel carries the coordinate when encoding is enabled.
        self.in_channels = self.d_k + 1 if coordinate_encoding else self.d_k
        self.out_channels = self.d_k + 1 if coordinate_encoding else self.d_k

        # Separate ConvNN for each head
        self.conv_heads = nn.ModuleList([
            nn.Conv1d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=K,
                stride=K,
                padding=0
            ) for _ in range(num_heads)
        ])

        # Separate pointwise conv for each head, mapping d_k + 1 channels back to d_k.
        if coordinate_encoding:
            self.pointwise_heads = nn.ModuleList([
                nn.Conv1d(
                    in_channels=self.out_channels,
                    out_channels=self.d_k,
                    kernel_size=1
                ) for _ in range(num_heads)
            ])

        # Learnable per-head offsets; they scale the coordinate channel.
        if sampling_type in ['random', 'spatial']:
            self.sampling_offsets = nn.Parameter(torch.randn(num_heads))

    def split_head(self, x):
        """Reshape ``(B, L, D)`` to ``(B, H, L, d_k)``."""
        batch_size, seq_length, d_hidden = x.size()
        return x.contiguous().view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        """Reshape ``(B, H, L, d_k)`` back to ``(B, L, D)``."""
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_hidden)

    def _add_coordinate_encoding_per_head(self, x, head_idx):
        """Append one head-specific coordinate channel to ``x`` of shape ``(b, c, t)``.

        The ramp spans ``[-scale, scale]`` with
        ``scale = 1 + 0.3 * sigmoid(sampling_offsets[head_idx])`` (or ``1`` when
        the module has no offsets). The scale stays a tensor so gradients reach
        ``sampling_offsets``; converting it to a Python float would detach it.

        Returns:
            Tensor of shape ``(b, c + 1, t)``.
        """
        b, _, t = x.shape

        coords = torch.linspace(-1.0, 1.0, t, device=x.device)
        if hasattr(self, 'sampling_offsets'):
            coords = coords * (1.0 + 0.3 * torch.sigmoid(self.sampling_offsets[head_idx]))

        # Even heads keep the linear ramp; odd heads squash it with tanh.
        if head_idx % 2 == 1:
            coords = torch.tanh(coords * (1 + head_idx * 0.1))

        coords = coords.view(1, 1, t).expand(b, 1, -1)
        return torch.cat((x, coords), dim=1)

    def process_head(self, q_h, k_h, v_h, head_idx):
        """Run one head through neighbour selection and its own convolutions.

        Args:
            q_h, k_h, v_h: Per-head tensors of shape ``(B, d_k, L)``.
            head_idx: Index of the head (selects conv weights and sampling pattern).

        Returns:
            Tensor of shape ``(B, d_k, L)``.

        Raises:
            ValueError: If ``sampling_type`` is not one of the supported modes.
        """
        if self.coordinate_encoding:
            q_h = self._add_coordinate_encoding_per_head(q_h, head_idx)
            k_h = self._add_coordinate_encoding_per_head(k_h, head_idx)
            v_h = self._add_coordinate_encoding_per_head(v_h, head_idx)

        if self.sampling_type == 'all':
            if self.magnitude_type == 'distance':
                matrix_magnitude = self._calculate_distance_matrix(k_h, q_h, sqrt=True)
            else:
                matrix_magnitude = self._calculate_similarity_matrix(k_h, q_h)

            prime = self._prime(v_h, matrix_magnitude, self.K, self.maximum)

        elif self.sampling_type in ('random', 'spatial'):
            seq_len = k_h.shape[2]

            if self.sampling_type == 'random':
                # Fixed, head-specific sample set drawn from a private
                # generator: reseeding the global RNG here would also reset
                # dropout, data shuffling and weight init elsewhere.
                gen = torch.Generator()
                gen.manual_seed(head_idx * 1000)
                rand_idx = torch.randperm(seq_len, generator=gen)[:self.num_samples]
                rand_idx = rand_idx.to(k_h.device)
            else:  # spatial
                # Evenly spaced grid whose start is staggered per head.
                offset = head_idx * (seq_len // (self.num_heads * 2))
                start = max(self.sample_padding, offset % seq_len)
                end = seq_len - self.sample_padding - 1
                rand_idx = torch.linspace(start, end, self.num_samples, device=k_h.device).long()
            q_sample = q_h[:, :, rand_idx]

            if self.magnitude_type == 'distance':
                matrix_magnitude = self._calculate_distance_matrix_N(k_h, q_sample, sqrt=True)
            else:
                matrix_magnitude = self._calculate_similarity_matrix_N(k_h, q_sample)

            # A candidate must never pick itself: it is already slot 0 of the
            # neighbourhood built by _prime_N.
            range_idx = torch.arange(len(rand_idx), device=k_h.device)
            inf_val = float('inf') if self.magnitude_type == 'distance' else float('-inf')
            matrix_magnitude[:, rand_idx, range_idx] = inf_val

            prime = self._prime_N(v_h, matrix_magnitude, self.K, rand_idx, self.maximum)

        else:
            raise ValueError(f"Invalid sampling_type: {self.sampling_type}")

        # kernel_size == stride == K collapses every K-neighbourhood to one token.
        x = self.conv_heads[head_idx](prime)

        # Remove coordinate channel if needed
        if self.coordinate_encoding:
            x = self.pointwise_heads[head_idx](x)

        return x

    def forward(self, x):
        """Apply per-head ConvNN attention to ``x`` of shape ``(B, L, d_hidden)``.

        Returns:
            Tensor of shape ``(B, L, d_hidden)``.
        """
        q_heads = self.split_head(self.W_q(x))  # (B, H, L, D/H)
        k_heads = self.split_head(self.W_k(x))  # (B, H, L, D/H)
        v_heads = self.split_head(self.W_v(x))  # (B, H, L, D/H)

        processed_heads = []
        for h in range(self.num_heads):
            # Channels-first layout for Conv1d: (B, D/H, L).
            q_h = q_heads[:, h, :, :].transpose(1, 2)
            k_h = k_heads[:, h, :, :].transpose(1, 2)
            v_h = v_heads[:, h, :, :].transpose(1, 2)

            head_output = self.process_head(q_h, k_h, v_h, h)  # (B, D/H, L')
            processed_heads.append(head_output.transpose(1, 2))  # (B, L', D/H)

        x = torch.stack(processed_heads, dim=1)  # (B, H, L', D/H)
        x = self.combine_heads(x)  # (B, L', D)

        x = self.dropout(x)
        return self.W_o(x)

    def _calculate_similarity_matrix(self, K, Q):
        """Non-negative cosine similarity between key and query columns.

        ``K`` is ``(b, c, Lk)`` and ``Q`` is ``(b, c, Lq)``; entry ``[b, i, j]``
        compares key ``i`` with query ``j``. Negative values are clamped to 0.
        """
        k_norm = F.normalize(K, p=2, dim=1)
        q_norm = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.bmm(k_norm.transpose(2, 1), q_norm)
        return torch.clamp(similarity_matrix, min=0)

    def _calculate_similarity_matrix_N(self, K, Q):
        """Sampled-query variant; the computation is identical to the full one."""
        return self._calculate_similarity_matrix(K, Q)

    def _calculate_distance_matrix(self, K, Q, sqrt=False):
        """(Squared) Euclidean distance between key and query columns.

        Entry ``[b, i, j]`` is ``|K_i|^2 + |Q_j|^2 - 2 K_i.Q_j``: the key norm
        broadcasts along rows and the query norm along columns, which matters
        because K and Q come from different projections.
        """
        k_sq = torch.sum(K ** 2, dim=1).unsqueeze(2)  # (b, Lk, 1)
        q_sq = torch.sum(Q ** 2, dim=1).unsqueeze(1)  # (b, 1, Lq)
        dot_product = torch.bmm(K.transpose(2, 1), Q)  # (b, Lk, Lq)
        # Clamp guards against tiny negatives from floating-point cancellation.
        dist_matrix = torch.clamp(k_sq + q_sq - 2 * dot_product, min=0)
        return torch.sqrt(dist_matrix) if sqrt else dist_matrix

    def _calculate_distance_matrix_N(self, K, Q, sqrt=False):
        """Sampled-query variant; the computation is identical to the full one."""
        return self._calculate_distance_matrix(K, Q, sqrt=sqrt)

    def _prime(self, v, qk, K, maximum):
        """Gather the ``K`` best neighbours of every token: ``(b, c, t)`` -> ``(b, c, t*K)``."""
        b, c, t = v.shape
        _, topk_indices = torch.topk(qk, k=K, dim=-1, largest=maximum)
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
        v_expanded = v.unsqueeze(-1).expand(b, c, t, K)
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        return prime.reshape(b, c, -1)

    def _prime_N(self, v, qk, K, rand_idx, maximum):
        """Gather each token plus its ``K - 1`` best sampled neighbours.

        ``qk`` is ``(b, t, num_samples)``; top-k positions index into the sample
        set and are mapped back to token positions through ``rand_idx``.
        """
        b, c, t = v.shape
        _, topk_indices = torch.topk(qk, k=K - 1, dim=2, largest=maximum)
        mapped_tensor = rand_idx[topk_indices]
        # Slot 0 of every neighbourhood is the token itself.
        token_indices = torch.arange(t, device=v.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=2)
        indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)
        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(v_expanded, dim=2, index=indices_expanded)
        return prime.reshape(b, c, -1)

In [4]:
# Smoke test: push one random batch through the reference layer and both
# variants defined above, then report the output shapes.
d_hidden = 768
num_heads = 8
seq_length = 197
batch_size = 4

from layers import MultiHeadConvNNAttention

# Random input of shape (batch, tokens, features)
x = torch.randn(batch_size, seq_length, d_hidden)

# Keyword arguments common to all three constructors
shared_config = dict(
    d_hidden=d_hidden,
    num_heads=num_heads,
    attention_dropout=0.1,
    K=4,
    sampling_type='spatial',
    num_samples=49,
    sample_padding=0,
    magnitude_type='similarity',
    seq_length=seq_length,
    coordinate_encoding=True,
)

# Reference implementation imported from layers
original_attention = MultiHeadConvNNAttention(**shared_config)

# Shared conv, head-specific sampling switched on
enhanced_attention = MultiHeadConvNNAttentionEnhanced(
    **shared_config,
    diverse_sampling=True,
)

# One ConvNN per head
separate_attention = MultiHeadConvNNAttentionSeparate(**shared_config)

# Forward passes, in the same order the modules were built
out_original = original_attention(x)
out_enhanced = enhanced_attention(x)
out_separate = separate_attention(x)

print(f"Original output shape: {out_original.shape}")
print(f"Enhanced output shape: {out_enhanced.shape}")
print(f"Separate output shape: {out_separate.shape}")

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 3 for tensor number 1 in the list.

In [None]:
"""Hard nearest neighbor selection -> consider using soft selection with temperature"""

def _prime_soft(self, v, qk, K, maximum, temperature=1.0):
    b, c, t = v.shape
    
    # Get top-k values and indices
    topk_values, topk_indices = torch.topk(qk, k=K, dim=-1, largest=maximum)
    
    # Apply softmax to create soft weights
    if maximum:  # similarity
        weights = F.softmax(topk_values / temperature, dim=-1)
    else:  # distance
        weights = F.softmax(-topk_values / temperature, dim=-1)
    
    # Gather and weight
    topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
    v_gathered = torch.gather(v.unsqueeze(-1).expand(b, c, t, K), 
                             dim=2, index=topk_indices_exp)
    
    # Apply soft weights
    v_weighted = v_gathered * weights.unsqueeze(1)
    prime = v_weighted.reshape(b, c, -1)
    return prime

In [None]:
"""Lack of Regularization"""
# Sketch, not runnable as written: the `...` in the signature and the
# `# ...` placeholders stand for code that is not shown in this cell.
# It illustrates three additions: an optional LayerNorm, a residual
# connection, and an explicit L2 penalty on the conv weights.
class MultiHeadConvNNAttention(nn.Module):
    def __init__(self, ..., weight_decay=1e-4, use_layer_norm=True):
        # ...
        
        # Add layer normalization
        # (the attribute only exists when use_layer_norm is true; forward
        # checks for it with hasattr). d_hidden presumably comes from the
        # elided arguments - TODO confirm.
        if use_layer_norm:
            self.layer_norm = nn.LayerNorm(d_hidden)
            
        # Add L2 regularization to conv weights
        # (only the coefficient is stored here; the penalty itself is
        # computed in get_regularization_loss)
        self.weight_decay = weight_decay
        
    def forward(self, x):
        # Add residual connection and layer norm
        # Keep a handle on the input before x is overwritten below.
        residual = x
        
        # ... your processing ...
        
        # Add residual and normalize
        # NOTE(review): without a layer_norm attribute this returns x as is,
        # i.e. the residual is NOT added in that branch - confirm intent.
        x = self.layer_norm(x + residual) if hasattr(self, 'layer_norm') else x
        return x
        
    def get_regularization_loss(self):
        """Call this during training"""
        # Returns weight_decay * sum(conv.weight ** 2) for self.conv; the
        # value starts as the int 0 and becomes a tensor after the first add.
        reg_loss = 0
        for conv in [self.conv]:
            reg_loss += self.weight_decay * torch.sum(conv.weight ** 2)
        return reg_loss

In [None]:
"""Add noise to the distance/similarity matrices to prevent overfitting to exact patterns"""
# Sketch of a forward() fragment: the `# ...` placeholder stands for code that
# is not shown and that must define matrix_magnitude before it is used below.
# As written the fragment ends without a return statement.
def forward(self, x, training=True):
    # ...
    
    # Noise is applied only when BOTH the call-site flag and the module's
    # own self.training flag are true.
    if training and self.training:
        # Add noise to magnitude matrix
        # randn_like draws standard-normal samples with the matrix's shape;
        # the factor 0.1 scales them.
        noise = torch.randn_like(matrix_magnitude) * 0.1
        matrix_magnitude = matrix_magnitude + noise

In [None]:
# Sketch, not runnable as written: the `...` in the signature and the
# `# ...` placeholders stand for code that is not shown in this cell.
# It outlines a forward pass that uses the q/k/v/output projections, a
# learnable temperature, and a residual connection followed by LayerNorm.
class MultiHeadConvNNAttention(nn.Module):
    def __init__(self, ...):
        # ... existing init ...
        
        # ADD: Temperature for soft selection
        # (learnable parameter holding a single value initialised to 1)
        self.temperature = nn.Parameter(torch.ones(1))
        
        # ADD: Layer norm
        # d_hidden presumably comes from the elided arguments - TODO confirm.
        self.layer_norm = nn.LayerNorm(d_hidden)
        
    def forward(self, x):
        residual = x  # Save for residual connection
        
        # CRITICAL: Use the linear projections!
        q = self.W_q(x)
        k = self.W_k(x)
        v = self.W_v(x)
        
        # ... rest of your ConvNN processing ...
        
        # Use soft selection instead of hard topk
        # (implement _prime_soft as shown above)
        
        # CRITICAL: Use W_o projection
        # NOTE(review): batch_split is not shown in this cell - confirm it
        # expects the permuted layout passed here.
        x = self.W_o(self.combine_heads(self.batch_split(x.permute(0, 2, 1))))
        
        # Add residual and layer norm
        x = self.layer_norm(x + residual)
        
        return x