<a href="https://colab.research.google.com/github/cmtorresjimenez/curso_nlp/blob/main/multi_head_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Head Attention Module

In [29]:
import torch
from torch import nn
import torch.nn.functional as F
from torch import Tensor
import math
from typing import Optional

## Attention

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function between the query and the corresponding key.

In Transformers, this attention function is called "Scaled Dot-Product Attention". The input consists of queries and keys of dimension $d_k$, and values of dimension $d_v$. We compute the dot product of the query with all the keys, divide each product by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values.

In practice, we compute the attention function on a set of queries simultaneously, packed together into a matrix $Q$. The keys and values are likewise packed into matrices $K$ and $V$, respectively. We compute the matrix of outputs as follows:

$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$
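
As a quick check of the formula, here is a minimal sketch of scaled dot-product attention on random tensors. The function name and the tensor sizes are illustrative assumptions, not part of the module defined below.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for batched inputs."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (batch, n_queries, n_keys)
    weights = F.softmax(scores, dim=-1)                # each row sums to 1
    return weights @ V, weights

Q = torch.rand(2, 4, 16)   # 4 queries of dimension d_k = 16
K = torch.rand(2, 6, 16)   # 6 keys of dimension d_k = 16
V = torch.rand(2, 6, 32)   # 6 values of dimension d_v = 32
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape, weights.shape)
```

Each query gets one output vector of dimension $d_v$, and one weight per key.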

In [30]:
class MultiHeadedAttention(nn.Module):
    '''
    Multi-headed attention block that lets the model jointly attend to
    information from different representation subspaces.

    Args:
        num_heads (int): number of attention heads per layer
        d_model (int): total dimension of the model
        dropout (float): dropout probability applied to attention_probs. Default: 0.0.
    '''
    def __init__(self, num_heads: int, d_model: int, dropout: float = 0.0):
        super(MultiHeadedAttention, self).__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"The hidden size ({d_model}) is not a multiple of the number of attention "
                f"heads ({num_heads})"
            )
        # Number of features per head; d_v = d_k is assumed
        self.d_k = d_model // num_heads
        self.d_model = d_model
        self.num_heads = num_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(p=dropout)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_heads, self.d_k)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
        output_attentions: Optional[bool] = False
    ):
        '''
        Args:
            query, key, value: the query and the set of key-value pairs that are mapped to an output.
            mask: mask that prevents attention to certain positions.
            output_attentions: whether to also return the attention weight matrix.
        '''
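        if mask is not None:
            # The same mask is applied to every head. A 2D padding mask
            # (batch, seq_len) becomes (batch, 1, 1, seq_len); a 3D mask
            # (batch, seq_len, seq_len) becomes (batch, 1, seq_len, seq_len).
            mask = mask[:, None, None, :] if mask.dim() == 2 else mask.unsqueeze(1)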
        
        query_layer = self.transpose_for_scores(self.query(query)) # (batch, num_heads, seq_len, d_k)
        key_layer = self.transpose_for_scores(self.key(key))
        value_layer = self.transpose_for_scores(self.value(value))

        # Dot product between "query" and "key" gives the raw attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.d_k)

        # Apply the mask
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))

        # Normalize the attention scores into probabilities
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # (batch, seq_len, num_heads, d_k)
        new_context_layer_shape = context_layer.size()[:-2] + (self.d_model,) # (batch, seq_len, d_model)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
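
The reshaping done by `transpose_for_scores`, and its inverse at the end of `forward`, can be checked in isolation. The following is a minimal sketch with made-up sizes. It only illustrates the `view`/`permute` round trip and is not a replacement for the module above.

```python
import torch

batch, seq_len, num_heads, d_k = 2, 5, 4, 8
d_model = num_heads * d_k
x = torch.rand(batch, seq_len, d_model)

# Split the last dimension into heads, then move the head axis next to batch.
heads = x.view(batch, seq_len, num_heads, d_k).permute(0, 2, 1, 3)
print(heads.shape)  # (batch, num_heads, seq_len, d_k)

# Undo it: move the heads back and merge them into d_model again.
merged = heads.permute(0, 2, 1, 3).contiguous().view(batch, seq_len, d_model)
print(torch.equal(merged, x))
```

Head `h` simply sees the slice `x[..., h * d_k:(h + 1) * d_k]` of each projected vector.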

In [31]:
class MultiHeadedAttention(nn.Module):
    '''
    Multi-headed attention block that lets the model jointly attend to
    information from different representation subspaces.

    Args:
        num_heads (int): number of attention heads per layer
        d_model (int): total dimension of the model
        dropout (float): dropout probability applied to attention_probs. Default: 0.0.
    '''
    def __init__(self, num_heads: int, d_model: int, dropout: float = 0.0):
        super(MultiHeadedAttention, self).__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"The hidden size ({d_model}) is not a multiple of the number of attention "
                f"heads ({num_heads})"
            )
        # Number of features per head; d_v = d_k is assumed
        self.d_k = d_model // num_heads
        self.d_model = d_model
        self.num_heads = num_heads

        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(p=dropout)
        self.norm = nn.BatchNorm1d(d_model)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
        output_attentions: Optional[bool] = False
    ):
        '''
        Args:
            query, key, value: the query and the set of key-value pairs that are mapped to an output.
            mask: accepted for interface compatibility; this variant does not apply it.
            output_attentions: whether to also return the attention weight matrix.
        '''
        key_layer = self.key(key)
        value_layer = self.value(value)

        # Dot product between "query" and "key" gives the raw attention scores
        attention_scores = torch.matmul(query, key_layer.transpose(-2, -1))
        attention_scores = attention_scores / math.sqrt(self.d_k)

        # Normalize the attention scores into probabilities
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        #attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer) + query

        context_layer = self.norm(context_layer.transpose(-1, -2)).transpose(-1, -2)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
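
`nn.BatchNorm1d` normalizes over dimension 1, which is why the variant above transposes the `(batch, seq_len, d_model)` tensor before and after the normalization. Here is a minimal sketch of that transposition with illustrative sizes. Note that the original Transformer uses layer normalization, so batch normalization is a choice specific to this notebook.

```python
import torch
from torch import nn

d_model = 16
norm = nn.BatchNorm1d(d_model)            # normalizes over the channel axis (dim 1)
x = torch.rand(4, 10, d_model) * 5 + 3    # (batch, seq_len, d_model)

# Move d_model to dim 1 for BatchNorm1d, then restore the original layout.
y = norm(x.transpose(-1, -2)).transpose(-1, -2)
print(y.shape)

# In training mode each feature is standardized over batch and sequence positions.
feature_mean = y.mean(dim=(0, 1))
feature_var = y.var(dim=(0, 1), unbiased=False)
print(feature_mean.abs().max().item(), feature_var.mean().item())
```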

## Usage example

We instantiate the Multi-Head Attention module.

In [32]:
att = MultiHeadedAttention(8, 768, 0.1)

We create an input with a batch size of 8 and sequences of 300 random elements.

In [33]:
# Prepare the input
x = torch.rand(8, 300, 768)
mask = torch.ones((8, 300))

# Run the module and also return the attention matrix
output = att(query=x, key=x, value=x, mask=mask, output_attentions=True)

Output embeddings

In [34]:
print(output[0].shape)

torch.Size([8, 300, 768])


Attention matrix

In [35]:
print(output[1].shape)

torch.Size([8, 300, 300])
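
The mask in this example is all ones, so nothing is actually masked. As a minimal sketch of what a padding mask does, the toy example below fills the scores of padded key positions with `-inf` before the softmax. The tensors and names are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

scores = torch.rand(2, 4, 4)                 # (batch, n_queries, n_keys)
# 1 = real token, 0 = padding; the second sequence has two padded positions.
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]])

# Broadcast the key mask over the query axis: (batch, 1, n_keys).
masked = scores.masked_fill(mask[:, None, :] == 0, float('-inf'))
probs = F.softmax(masked, dim=-1)
print(probs[1])
```

The padded keys receive a weight of exactly zero, and the remaining weights in each row still sum to one.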


## Self-attention visualization

In [36]:
%%capture
!pip install transformers
!pip install bertviz

In [37]:
from transformers import BertModel, BertTokenizer
from bertviz import head_view

In [41]:
bert = BertModel.from_pretrained('dccuchile/bert-base-spanish-wwm-cased')
bert_embeddings_layer = bert.embeddings
tokenizer = BertTokenizer.from_pretrained('dccuchile/bert-base-spanish-wwm-cased')

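# Use BERT's own head count so the loaded Q/K/V weights are split into
# the same heads they were trained with.
att = MultiHeadedAttention(bert.config.num_attention_heads, 768, 0.1)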

att.query.load_state_dict(bert.encoder.layer[0].attention.self.query.state_dict())
att.key.load_state_dict(bert.encoder.layer[0].attention.self.key.state_dict())
att.value.load_state_dict(bert.encoder.layer[0].attention.self.value.state_dict())

Some weights of the model checkpoint at dccuchile/bert-base-spanish-wwm-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at dccuchile/bert-base-spanish-wwm-cased and are newly initialized: ['bert.pooler.dense.bi

<All keys matched successfully>

In [93]:
text_input = tokenizer(["El perro va caminando sobre el pasto"], return_tensors='pt')

In [91]:
input_ids = text_input['input_ids']
x = bert_embeddings_layer(input_ids)
mask = text_input['attention_mask']

output, attention_scores = att(query=x, key=x, value=x, mask=mask, output_attentions=True)

In [92]:
attention_scores.shape

torch.Size([1, 10, 10])

In [45]:
input_id_list = input_ids.tolist()[0]
tokens = tokenizer.convert_ids_to_tokens(input_id_list)

In [94]:
head_view((attention_scores,)*12, tokens)

ValueError: ignored
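
To our understanding, `head_view` expects one tensor per layer shaped `(batch, num_heads, seq_len, seq_len)` with a batch size of 1, so a 3D tensor such as the `(1, 10, 10)` one above is a likely cause of this `ValueError`. Here is a minimal sketch of the shape adjustment, using a random stand-in for the attention weights. It does not call bertviz itself.

```python
import torch

seq_len, num_layers = 10, 1

# Hypothetical stand-in for attention weights that lack a head axis.
attn_3d = torch.softmax(torch.rand(1, seq_len, seq_len), dim=-1)

# Add the missing head axis: (batch, num_heads=1, seq_len, seq_len).
attn_4d = attn_3d.unsqueeze(1)

# One tensor per layer, collected in a tuple; here a single layer.
attention = (attn_4d,) * num_layers
print(len(attention), attention[0].shape)
# With bertviz installed, this could then be tried as head_view(attention, tokens).
```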

## Counting Letters

In [49]:
import numpy as np
import string

In [87]:
class MultiHeadedAttention(nn.Module):
    '''
    Multi-headed attention block that lets the model jointly attend to
    information from different representation subspaces.

    Args:
        num_heads (int): number of attention heads per layer
        d_model (int): total dimension of the model
        dropout (float): dropout probability applied to attention_probs. Default: 0.0.
    '''
    def __init__(self, num_heads: int, d_model: int, dropout: float = 0.0):
        super(MultiHeadedAttention, self).__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"The hidden size ({d_model}) is not a multiple of the number of attention "
                f"heads ({num_heads})"
            )
        # Number of features per head; d_v = d_k is assumed
        self.d_k = d_model // num_heads
        self.d_model = d_model
        self.num_heads = num_heads

        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(p=dropout)
        self.norm = nn.BatchNorm1d(d_model)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: Optional[Tensor] = None,
        output_attentions: Optional[bool] = False
    ):
        '''
        Args:
            query, key, value: the query and the set of key-value pairs that are mapped to an output.
            mask: accepted for interface compatibility; this variant does not apply it.
            output_attentions: whether to also return the attention weight matrix.
        '''
        # This variant does not split the projections into heads (the class
        # defines no transpose_for_scores helper), so the query is only projected.
        query_layer = self.query(query)  # (batch, num_queries, d_model)
        key_layer = self.key(key)
        value_layer = self.value(value)

        # Dot product between "query" and "key" gives the raw attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-2, -1))
        attention_scores = attention_scores / math.sqrt(self.d_k)

        # Normalize the attention scores into probabilities
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        #attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer) + query

        context_layer = self.norm(context_layer.transpose(-1, -2)).transpose(-1, -2)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs

In [88]:
def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    attn_logits = attn_logits / math.sqrt(d_k)
    if mask is not None:
        attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    values = torch.matmul(attention, v) + q
    return values, attention

class MultiheadAttention(nn.Module):

    def __init__(self, input_dim, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.d_k = embed_dim // num_heads
        self.d_model = embed_dim
        self.num_heads = num_heads

        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        #self.dropout = nn.Dropout(p=dropout)
        self.norm = nn.BatchNorm1d(embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        self._reset_parameters()
    
    def _reset_parameters(self):
        # Original Transformer initialization, see PyTorch documentation
        nn.init.xavier_uniform_(self.o_proj.weight)
        self.o_proj.bias.data.fill_(0)

    def forward(self, query, key, value, mask=None, return_attention=False):
        key_layer = self.key(key)
        value_layer = self.value(value)

        # Determine value outputs
        o, attention = scaled_dot_product(query, key_layer, value_layer, mask=mask)
        #o = self.norm(values.transpose(-1, -2)).transpose(-1, -2)
        #values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
        #values = values.reshape(batch_size, seq_length, embed_dim)
        #o = self.o_proj(values)

        if return_attention:
            return o, attention
        else:
            return o

In [52]:
class CounterModel(nn.Module):
    def __init__(self, num_heads: int = 1, d_model: int = 64, dropout: float = 0.1, emb_dim: int = 64, vocab_size: int = 3, max_len: int = 10):
        super(CounterModel, self).__init__()
        self.vocab_size = vocab_size
        self.embeddings = nn.Embedding(num_embeddings=vocab_size+1, embedding_dim=emb_dim)
        self.query = nn.Linear(in_features=emb_dim, out_features=1).weight
        self.mh_att = MultiheadAttention(d_model, emb_dim, num_heads)
        self.classification_head = nn.Linear(in_features=emb_dim, out_features=max_len+1)
    
    def forward(self, letter_sequence):
        # Use the method argument rather than the global `input_batch`
        batch_size = letter_sequence.shape[0]
        letter_sequence_ids = (letter_sequence == 1).nonzero(as_tuple=True)[-1].reshape((batch_size,-1))
        x = self.embeddings(letter_sequence_ids)
        q_r = self.query.repeat((batch_size,self.vocab_size,1))
        output, attention_scores = self.mh_att(query=q_r, key=x, value=x, mask=None, return_attention=True)
        logits = self.classification_head(output)
        return logits

In [53]:
class CounterModel(nn.Module):
    def __init__(self, num_heads: int = 1, d_model: int = 64, dropout: float = 0.1, emb_dim: int = 64, vocab_size: int = 3, max_len: int = 10):
        super(CounterModel, self).__init__()
        self.vocab_size = vocab_size
        self.embeddings = nn.Embedding(num_embeddings=vocab_size+1, embedding_dim=emb_dim)
        self.query = nn.Linear(in_features=emb_dim, out_features=1).weight
        self.mh_att = MultiHeadedAttention(num_heads, d_model, dropout)
        self.classification_head = nn.Linear(in_features=emb_dim, out_features=max_len+1)
    
    def forward(self, letter_sequence):
        #print(letter_sequence.shape)
        # Use the method argument rather than the global `input_batch`
        batch_size = letter_sequence.shape[0]
        letter_sequence_ids = (letter_sequence == 1).nonzero(as_tuple=True)[-1].reshape((batch_size,-1))
        #cls_tokens = torch.ones((batch_size,1), dtype=torch.long)*(self.vocab_size+1)
        #letter_sequence_ids = torch.concat((cls_tokens, letter_sequence_ids), 1)
        x = self.embeddings(letter_sequence_ids)
        #print(x.shape)
        q_r = self.query.repeat((batch_size,self.vocab_size,1))
        output, attention_scores = self.mh_att(query=q_r, key=x, value=x, mask=None, output_attentions=True)
        #print(output.shape)
        logits = self.classification_head(output)
        #print(logits.shape)
        return logits

In [54]:
class Task(object):

	def __init__(self, max_len=10, vocab_size=3):
		super(Task, self).__init__()
		self.max_len = max_len
		self.vocab_size = vocab_size
		assert self.vocab_size <= 26, "vocab_size needs to be <= 26 since we are using letters to prettify LOL"

	def next_batch(self, batchsize=100):
		x = np.eye(self.vocab_size + 1)[np.random.choice(np.arange(self.vocab_size + 1), [batchsize, self.max_len])]
		y = np.eye(self.max_len + 1)[np.sum(x, axis=1)[:, 1:].astype(np.int32)]
		return x, y

	def prettify(self, samples):
		samples = samples.reshape(-1, self.max_len, self.vocab_size + 1)
		idx = np.expand_dims(np.argmax(samples, axis=2), axis=2)
		dictionary = np.array(list(' ' + string.ascii_uppercase))
		return dictionary[idx]

In [55]:
task = Task(max_len=10, vocab_size=3)

In [56]:
minibatch_x, minibatch_y = task.next_batch(batchsize=2)

In [None]:
print(minibatch_x)

In [58]:
print(minibatch_y)

[[[0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
  [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]]]


In [59]:
minibatch_y.shape

(2, 3, 11)

In [60]:
task.prettify(minibatch_x)

array([[['C'],
        ['A'],
        ['C'],
        ['C'],
        [' '],
        [' '],
        ['B'],
        ['C'],
        [' '],
        ['B']],

       [['A'],
        ['C'],
        ['A'],
        ['A'],
        ['C'],
        [' '],
        ['A'],
        [' '],
        ['A'],
        [' ']]], dtype='<U1')
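
The labels returned by `next_batch` are one-hot encoded counts, one row per letter. The sketch below rebuilds the same encoding and recovers the integer counts with `argmax`. It uses a seeded `default_rng` generator for reproducibility, which is an assumption of this sketch (`Task` itself uses `np.random.choice`).

```python
import numpy as np

max_len, vocab_size, batchsize = 10, 3, 2
rng = np.random.default_rng(0)

# One-hot sequences: index 0 is the blank symbol, 1..vocab_size are letters.
ids = rng.integers(0, vocab_size + 1, size=(batchsize, max_len))
x = np.eye(vocab_size + 1)[ids]                  # (batch, max_len, vocab_size + 1)

# Count each letter (dropping the blank column) and one-hot encode the counts.
counts = x.sum(axis=1)[:, 1:].astype(np.int32)   # (batch, vocab_size)
y = np.eye(max_len + 1)[counts]                  # (batch, vocab_size, max_len + 1)

# argmax over the last axis recovers the integer counts from the labels.
print(y.argmax(axis=-1))
```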

In [61]:
input_batch = torch.tensor(minibatch_x, dtype=torch.long)

In [62]:
input_batch_ids = (input_batch == 1).nonzero(as_tuple=True)[-1].reshape((input_batch.shape[0],-1))

In [63]:
cls_token = torch.ones((2,1), dtype=torch.long)*4

In [64]:
torch.concat((cls_token, input_batch_ids), 1)

tensor([[4, 3, 1, 3, 3, 0, 0, 2, 3, 0, 2],
        [4, 1, 3, 1, 1, 3, 0, 1, 0, 1, 0]])
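
The `nonzero(...)` and `reshape` combination above recovers token ids from one-hot rows. Since each row contains a single 1, `argmax` over the last axis gives the same ids. Here is a minimal sketch with hand-written ids (the values and the id `4` for the prepended token are illustrative):

```python
import torch
import torch.nn.functional as F

ids = torch.tensor([[3, 1, 0, 2], [1, 1, 3, 0]])
one_hot = F.one_hot(ids, num_classes=4)          # (batch, seq_len, vocab_size + 1)

# argmax over the one-hot axis recovers the ids directly.
recovered = one_hot.argmax(dim=-1)

# Prepend a hypothetical [CLS]-style token id (4) to every sequence.
cls_token = torch.full((ids.size(0), 1), 4, dtype=torch.long)
with_cls = torch.cat((cls_token, recovered), dim=1)
print(with_cls)
```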

In [65]:
model = CounterModel()

In [66]:
model(input_batch).shape

torch.Size([2, 3, 11])

In [67]:
q = nn.Linear(in_features=10, out_features=1, bias=False).weight

In [68]:
q.shape

torch.Size([1, 10])

In [69]:
q_r = q.repeat((2,3,1))

In [70]:
q_r.shape

torch.Size([2, 3, 10])

In [71]:
k = torch.rand((2,5,10))

In [72]:
v = torch.rand((2,5,10))

In [73]:
probs = torch.matmul(q_r, k.transpose(-1, -2))

In [74]:
probs.shape

torch.Size([2, 3, 5])

In [75]:
context_layer = torch.matmul(probs, v)

In [76]:
context_layer.shape

torch.Size([2, 3, 10])
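
The exploratory cells above skip the scaling and the softmax. The sketch below adds them and also shows a consequence of repeating one query vector: every repeated row yields the same attention output. Using `expand` instead of `repeat` is our suggestion here, not something the notebook does. It produces the same values without copying memory.

```python
import torch

q = torch.rand(1, 10)                 # a single query vector
q_r = q.expand(2, 3, -1)              # (2, 3, 10), same values as q.repeat((2, 3, 1))
k = torch.rand(2, 5, 10)
v = torch.rand(2, 5, 10)

weights = torch.softmax(q_r @ k.transpose(-1, -2) / 10 ** 0.5, dim=-1)  # (2, 3, 5)
context = weights @ v                                                    # (2, 3, 10)

# Every query row is the same vector, so every output row is the same too.
print(torch.allclose(context[:, 0], context[:, 1]))
```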

## Training

In [78]:
model = CounterModel(num_heads=1, d_model=64, dropout=0.1, emb_dim=64)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
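
One detail worth keeping in mind with this loss: `nn.CrossEntropyLoss` and `F.cross_entropy` treat dimension 1 of the logits as the class dimension. For logits shaped `(batch, letters, counts)`, the class axis therefore has to be flattened or moved first. Here is a minimal sketch with random tensors, where the sizes mirror this task and the variable names are illustrative:

```python
import torch
import torch.nn.functional as F

batch, n_letters, n_classes = 4, 3, 11
logits = torch.randn(batch, n_letters, n_classes)
targets = torch.randint(0, n_classes, (batch, n_letters))   # integer count per letter

# Option 1: flatten (batch, letter) so the input is (N, C) with targets (N,).
loss_flat = F.cross_entropy(logits.reshape(-1, n_classes), targets.reshape(-1))

# Option 2: move the class axis to dim 1, giving (batch, C, n_letters).
loss_moved = F.cross_entropy(logits.transpose(1, 2), targets)

print(loss_flat.item(), loss_moved.item())
```

Both options average the same per-letter losses, so they agree up to floating-point error.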

In [79]:
model.train()

CounterModel(
  (embeddings): Embedding(4, 64)
  (mh_att): MultiHeadedAttention(
    (query): Linear(in_features=64, out_features=64, bias=True)
    (key): Linear(in_features=64, out_features=64, bias=True)
    (value): Linear(in_features=64, out_features=64, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (classification_head): Linear(in_features=64, out_features=11, bias=True)
)

In [80]:
model.embeddings.weight.shape

torch.Size([4, 64])

In [81]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

13643

In [82]:
task = Task(max_len=10, vocab_size=3)
dataset = [task.next_batch(batchsize=128) for i in range(100)]

In [83]:
for step in range(5000):
    # get the inputs; each dataset entry is an (inputs, labels) pair
    i = step % 100
    #input_batch, label_batch = task.next_batch(batchsize=128)
    input_batch, label_batch = dataset[i]
    input_batch = torch.tensor(input_batch, dtype=torch.long)
    label_batch = torch.tensor(label_batch, dtype=torch.float)

    # zero the parameter gradients
    optimizer.zero_grad()

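    # forward + backward + optimize
    outputs = model(input_batch)  # (batch, vocab_size, max_len + 1)
    # F.cross_entropy expects the class dimension at index 1, so merge the
    # (batch, letter) axes and keep the max_len + 1 count classes last.
    loss = F.cross_entropy(outputs.flatten(0, 1), label_batch.flatten(0, 1))
    loss.backward()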
    for name, param in model.named_parameters():
        #if "classification_head" in name:
        #print(param.grad.shape)
        if param.grad is None:
            print(f"NONE: {name}")
            break
        else:
            print(name, param.grad)
    optimizer.step()

    # print statistics
    if step % 100 == 0:
        print(f'step: {step + 1} | loss: {loss.item():.3f}')

print('Finished Training')

Streaming output truncated to the last 5000 lines.

## TF training

In [None]:
!pip install tensorflow-gpu==1.10

In [None]:
!git clone https://github.com/greentfrapp/attention-primer.git

In [None]:
!python attention-primer/1_counting-letters/main.py --train