
Understanding how to define key, query and value for the cross attention calculation #119

neuronphysics opened this issue Dec 18, 2022 · 0 comments
Hello,

I am having trouble understanding how to use this library to implement cross-attention.

For instance, if the tensor x = torch.rand(100, 14, 64) is the key, the tensor y = torch.rand(100, 11, 64) is the value, and the tensor z = torch.rand(100, 14, 1) is the query, how can I use TransformerDecoderBuilder to compute the cross-attention for this example?
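
To make the shapes concrete, this is the setup I have in mind (the role labels in the comments are just my reading of the dimensions):

import torch

x = torch.rand(100, 14, 64)  # key:   [batch=100, key_len=14,   d=64]
y = torch.rand(100, 11, 64)  # value: [batch=100, value_len=11, d=64]
z = torch.rand(100, 14, 1)   # query: [batch=100, query_len=14, d=1]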

Here is how I built my encoder and decoder classes:

import math
import torch
import torch.nn as nn
import fast_transformers
from fast_transformers.builders import TransformerEncoderBuilder, TransformerDecoderBuilder
from collections import OrderedDict


class PositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model, dropout_prob=0.0, series_dimensions=1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_prob)
        self.d_model = d_model
        self.max_len = max_len
        self.series_dimensions = series_dimensions
        
        if self.series_dimensions == 1:
            if d_model % 2 != 0:
                raise ValueError("Cannot use sin/cos positional encoding with "
                                 "odd dim (got dim={:d})".format(d_model))
            pe = torch.zeros(self.max_len, d_model).float()
            pe.requires_grad = False
            position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
        elif self.series_dimensions > 1:
            if d_model % 4 != 0:
                raise ValueError("Cannot use 2D sin/cos positional encoding with "
                                 "a dimension not divisible by 4 (got dim={:d})".format(d_model))
            height = self.series_dimensions
            width = self.max_len
            pe = torch.zeros(d_model, height, width).float()
            pe.requires_grad = False
            # Each spatial dimension uses half of d_model
            d_model = int(d_model / 2)
            div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
            pos_w = torch.arange(0., width).unsqueeze(1)
            pos_h = torch.arange(0., height).unsqueeze(1)
            pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
            pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
            pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
            pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
            pe = pe.view(2*d_model, height * width, -1).squeeze(-1) # Flattening it back to 1D series
            pe = pe.transpose(0, 1)
            
        pe = pe.unsqueeze(0) # Extending it by an extra leading dim for the batches
        self.register_buffer('pe', pe)

    # Expecting a flattened (1D) series
    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class LinearTransformerCausalEncoder(torch.nn.Module):
    def __init__(self, input_features, output_features, hidden_dim, sequence_length, 
                 attention_type='causal-linear', n_layers=2, n_heads=4,
                 dropout=0.1, softmax_temp=None, activation_fn="gelu",
                 attention_dropout=0.1,
                ):
        super(LinearTransformerCausalEncoder, self).__init__()
        self.d_model = hidden_dim * n_heads
        self.pos_embedding = PositionalEncoding(
            max_len=sequence_length,
            d_model=self.d_model,  # hidden_dim * n_heads
        )
        self.value_embedding = nn.Linear(
            input_features,
            self.d_model
        )
        self.builder_dict = OrderedDict({
            "attention_type": attention_type,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "feed_forward_dimensions": self.d_model*2,
            "query_dimensions": hidden_dim,
            "value_dimensions": hidden_dim,
            "dropout": dropout,
            "softmax_temp": softmax_temp,
            "activation" : activation_fn,
            "attention_dropout": attention_dropout,
        })
        self.transformer = TransformerEncoderBuilder.from_dictionary(
            self.builder_dict,
            strict=True
        ).get()
        hidden_size = n_heads * hidden_dim
        self.predictor = torch.nn.Linear(
            hidden_size,
            output_features
        )
    def forward(self, x):
        # x: [batch_size, input_features, sequence_length]
        x = x.permute(0, 2, 1)
        x = self.value_embedding(x)  # x: [batch_size, sequence_length, n_heads * hidden_dim]
        x = self.pos_embedding(x)    # x: [batch_size, sequence_length, n_heads * hidden_dim]
        # triangular_mask: [sequence_length, sequence_length]
        triangular_mask = fast_transformers.masking.TriangularCausalMask(x.size(1), device=x.device)
        y_hat = self.transformer(x, attn_mask=triangular_mask)  # [batch_size, sequence_length, n_heads * hidden_dim]
        y_hat = self.predictor(y_hat)                           # [batch_size, sequence_length, output_features]
        return y_hat.permute(0, 2, 1)                           # [batch_size, output_features, sequence_length]

class LinearTransformerCausalDecoder(torch.nn.Module):
    def __init__(self, output_features, hidden_dim, sequence_length, 
                 attention_type='causal-linear', n_layers=2, n_heads=4,
                 d_query=32, dropout=0.1, softmax_temp=None,activation_fn="gelu",
                 attention_dropout=0.1,):
        super(LinearTransformerCausalDecoder, self).__init__()
        self.d_model = hidden_dim * n_heads
        self.pos_embedding = PositionalEncoding(
            max_len=sequence_length,
            d_model=self.d_model,  # hidden_dim * n_heads
        )
    
        self.value_embedding = torch.nn.Linear(
            output_features,
            self.d_model
        )
        self.builder_dict = OrderedDict({
            "cross_attention_type":attention_type,
            "self_attention_type":attention_type,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "feed_forward_dimensions": self.d_model*2,
            "query_dimensions": hidden_dim,
            "value_dimensions": hidden_dim,
            "dropout": dropout,
            "softmax_temp": softmax_temp,
            "activation" : activation_fn,
            "attention_dropout": attention_dropout,
        })
        self.transformer = TransformerDecoderBuilder.from_dictionary(
            self.builder_dict,
            strict=True
        ).get()
        hidden_size = n_heads * hidden_dim
        
        self.predictor = torch.nn.Linear(
            hidden_size,
            output_features
        )
    def forward(self, target, memory, len_mask=None):
        x = target.permute(0, 2, 1)  # x: [batch_size, sequence_length, output_features]
        x = self.value_embedding(x)  # x: [batch_size, sequence_length, n_heads * hidden_dim]
        x = self.pos_embedding(x)    # x: [batch_size, sequence_length, n_heads * hidden_dim]
        # triangular_mask: [sequence_length, sequence_length]
        triangular_mask = fast_transformers.masking.TriangularCausalMask(x.size(1), device=x.device)
        y_hat = self.transformer(x, memory, x_mask=triangular_mask, x_length_mask=len_mask)
        y_hat = self.predictor(y_hat)     # [batch_size, sequence_length, output_features]
        return y_hat.permute(0, 2, 1)     # [batch_size, output_features, sequence_length]

I have difficulty understanding how I can use LinearTransformerCausalDecoder to compute the cross-attention. I would appreciate it if anyone could clarify this for the example key, query, and value above. Thanks.
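
For reference, this is my current best guess at using TransformerDecoderBuilder directly. The sizes, the head configuration, and the choice of a non-causal "linear" cross-attention are my own assumptions for illustration; in particular, I am assuming that the first argument of the decoder provides the queries while memory provides both the keys and the values, which is exactly the part I would like confirmed or corrected:

import torch
from fast_transformers.builders import TransformerDecoderBuilder
from fast_transformers.masking import TriangularCausalMask

n_heads, head_dim = 4, 16  # model dimension = n_heads * head_dim = 64
decoder = TransformerDecoderBuilder.from_kwargs(
    self_attention_type="causal-linear",
    cross_attention_type="linear",  # non-causal cross-attention over the memory (my assumption)
    n_layers=2,
    n_heads=n_heads,
    query_dimensions=head_dim,
    value_dimensions=head_dim,
    feed_forward_dimensions=128,
).get()

queries = torch.rand(100, 14, 64)  # query-side sequence: [batch, tgt_len, d_model]
memory = torch.rand(100, 11, 64)   # key/value sequence:  [batch, src_len, d_model]
causal_mask = TriangularCausalMask(queries.size(1))

out = decoder(queries, memory, causal_mask)  # third argument masks the causal self-attention
print(out.shape)  # expected: torch.Size([100, 14, 64])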
