# How and Why Does Self-Attention Work?

## Stacking Non-Linear Transformations

The prevailing paradigm in machine learning is to stack non-linear transformations: the first is applied to the input data, and each subsequent one to the output of the one before it. Each transformation is called a layer. "Deep" architectures are so called because they have many layers -- for example, the GPT-3 model trained at OpenAI had 96 layers in total.

Many of the key breakthroughs in recent years have focused on resolving the problems that crop up while training deep architectures. I have covered a few of these in previous posts, such as [Batch Normalization](...) and [Dropout](...).

The earliest deep-learning architecture was what is now called a Feed-Forward Network. In its current, mature formulation, each layer consists of a linear operation followed by a non-linear activation function $(\sigma)$ such as ReLU (Rectified Linear Unit).

$$
f(\mathbf{x}) = \sigma(\mathbf{W}\mathbf{x} + \mathbf{b})
$$
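As a minimal sketch of the formula above, here are two stacked layers in NumPy. The weights `W1`, `b1`, `W2`, `b2` and the input `x` are made-up toy values chosen for illustration; a real network would learn the weights during training.

```python
import numpy as np

def relu(z):
    """Rectified Linear Unit: element-wise max(0, z)."""
    return np.maximum(0.0, z)

def feed_forward_layer(x, W, b):
    """One layer: a linear operation followed by a non-linear activation."""
    return relu(W @ x + b)

# Hypothetical toy parameters (not learned), chosen only for illustration
W1 = np.array([[1.0, -1.0],
               [0.5,  0.5],
               [-1.0, 0.0]])
b1 = np.array([0.0, -2.0, 0.0])
W2 = np.array([[2.0, 1.0, 1.0]])
b2 = np.array([0.5])

x = np.array([2.0, 1.0])

h = feed_forward_layer(x, W1, b1)   # first layer acts on the input
y = feed_forward_layer(h, W2, b2)   # second layer acts on the first layer's output
print(h, y)
```

Stacking simply means feeding the output of one layer into the next, which is all "depth" amounts to here.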


<img src="nn.svg" alt="Feed Forward Network" style="width:500px;"/>

Another popular architecture used primarily in computer vision is the Convolutional Neural Network (CNN), whose basic transformation is the convolution followed by an activation function. The idea here is that the model can learn to detect features in an image by performing non-linear transformations on small patches of the image. Then, just like the basic Feed-Forward network, the same operation can be performed on the outputs successively.


<img src="cnn.svg" alt="Convolutional Neural Network" style="width:800px;"/>
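To make the "small patches" idea concrete, here is a rough NumPy sketch of a single convolution followed by ReLU. The 4x4 image and the hand-written edge-detecting kernel are assumptions for illustration; a CNN would learn its kernels, and library implementations are far more efficient than this double loop.

```python
import numpy as np

def conv2d_relu(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1), then apply ReLU.

    Like most deep-learning libraries, this computes a cross-correlation:
    the kernel is not flipped.
    """
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            patch = image[i:i + kh, j:j + kw]    # a small patch of the image
            out[i, j] = np.sum(patch * kernel)   # linear operation on the patch
    return np.maximum(0.0, out)                  # non-linear activation

# Toy image with a vertical edge between columns 1 and 2
image = np.array([[0, 0, 1, 1],
                  [0, 0, 1, 1],
                  [0, 0, 1, 1],
                  [0, 0, 1, 1]], dtype=float)

# Hypothetical hand-written kernel that responds to dark-to-bright vertical edges
kernel = np.array([[-1.0, 1.0],
                   [-1.0, 1.0]])

features = conv2d_relu(image, kernel)
print(features)
```

The output is non-zero only in the column where the patch straddles the edge, which is the sense in which the transformation "detects" a feature.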

I won't go into too much detail on these here.


## A New Transformation

The idea of self-attention developed out of the sequence-to-sequence model architectures which were used primarily for machine translation. Here, the goal was to learn a single representation for an input sentence ("encode"), then use it to generate a translation ("decode").

However, it proved difficult to compress the information needed to translate a long sentence into a single representation.

Bahdanau et al. (2014) introduced the concept of attention -- the model could learn to use parts of the input sentence directly ("attend") while decoding, rather than relying solely on the learnt representation. Parikh et al. (2016) realized that the attention operation itself could be used for NLP tasks such as entailment. Lin et al. (2017) introduced the concept of self-attention to perform a variety of NLP tasks.

In my earlier post on [Nadaraya-Watson Regression](...), we saw how this classic non-parametric technique can be interpreted as an early form of attention. 
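As a quick reminder of that connection, here is a minimal NumPy sketch of Nadaraya-Watson regression with a Gaussian kernel. The function name, the toy data, and the `bandwidth` value are my assumptions for illustration. Read the query points as "queries", the training inputs as "keys", and the training targets as "values": the normalized kernel weights are a softmax over the keys.

```python
import numpy as np

def nadaraya_watson(x_query, x_train, y_train, bandwidth=1.0):
    """Gaussian-kernel Nadaraya-Watson estimate at each query point."""
    # Pairwise squared distances, shape (n_queries, n_train)
    sq_dists = (x_query[:, None] - x_train[None, :]) ** 2
    scores = -sq_dists / (2.0 * bandwidth ** 2)
    # Normalizing the Gaussian kernel is exactly a softmax over the training points
    scores -= scores.max(axis=1, keepdims=True)   # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    # Each prediction is a weighted average of the training targets
    return weights @ y_train, weights

# Toy data, chosen only for illustration
x_train = np.array([0.0, 1.0, 2.0, 3.0])
y_train = np.array([0.0, 1.0, 4.0, 9.0])
x_query = np.array([1.0, 2.5])

y_hat, weights = nadaraya_watson(x_query, x_train, y_train, bandwidth=0.5)
print(weights.round(3))
print(y_hat)
```

Each row of `weights` sums to 1 and puts most of its mass on the training points closest to the query, which is the behaviour attention generalizes by learning the similarity function.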

Vaswani et al. (2017) realized that the self-attention mechanism could be used as the basic non-linear transformation for sequences of variable length. Because positional information is incorporated into the inputs, the model can process all tokens in parallel (check out my post on positional embeddings [here](...)). By avoiding having to explicitly process each token in sequence, it became possible to train much deeper networks.




## So What Exactly Is Self-Attention?
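
As a minimal sketch, here is single-head scaled dot-product self-attention in the formulation of Vaswani et al. (2017). Each token's embedding is projected into a query, a key, and a value; the scaled query-key dot products are passed through a softmax to give attention weights; and each output is the weighted average of the values, $\mathrm{softmax}(QK^\top / \sqrt{d_k})\,V$. The NumPy code, the random projection matrices `W_q`, `W_k`, `W_v`, and the toy sizes are assumptions for illustration only; a trained model would learn the projections.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)   # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention.

    X: (seq_len, d_model) token representations.
    Returns the outputs (seq_len, d_v) and the attention weights (seq_len, seq_len).
    """
    Q = X @ W_q                                # queries
    K = X @ W_k                                # keys
    V = X @ W_v                                # values
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)            # similarity of every token to every other token
    weights = softmax(scores, axis=-1)         # each row sums to 1
    return weights @ V, weights                # weighted average of the values

# Random toy inputs and projections (hypothetical, not learned)
rng = np.random.default_rng(0)
seq_len, d_model, d_k = 5, 8, 4
X = rng.normal(size=(seq_len, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) for _ in range(3))

out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape, weights.shape)  # (5, 4) (5, 5)
```

The "self" in self-attention is that queries, keys, and values are all computed from the same sequence `X`, so every token can draw on every other token in a single step.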



## Let's See If it Works

Let's train a sentiment classifier. For our dataset, we will use the IMDb review dataset with GloVe embeddings. Our baseline model is a simple feed-forward network that uses the average of the word embeddings. The candidate model adds a self-attention layer.
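
To illustrate what "an average of the word embeddings" means, here is a small NumPy sketch with a made-up four-word vocabulary. The `mean_pool` helper and its optional `pad_id` argument are hypothetical additions for illustration; the models in this post take a plain mean over all positions.

```python
import numpy as np

def mean_pool(embeddings, token_ids, pad_id=None):
    """Average token embeddings into one vector per sequence.

    embeddings: (vocab_size, embed_dim) lookup table.
    token_ids:  (batch, seq_len) integer indices.
    pad_id:     if given, positions equal to pad_id are left out of the average.
    """
    vectors = embeddings[token_ids]                        # (batch, seq_len, embed_dim)
    if pad_id is None:
        return vectors.mean(axis=1)
    mask = (token_ids != pad_id)[..., None].astype(vectors.dtype)
    counts = np.maximum(mask.sum(axis=1), 1.0)             # avoid division by zero
    return (vectors * mask).sum(axis=1) / counts

# Hypothetical 4-word vocabulary with 2-dimensional embeddings; index 3 plays the role of padding
toy_embeddings = np.array([[1.0, 0.0],
                           [0.0, 1.0],
                           [2.0, 2.0],
                           [0.0, 0.0]])
toy_ids = np.array([[0, 1, 3, 3]])

plain = mean_pool(toy_embeddings, toy_ids)              # padding positions dilute the average
masked = mean_pool(toy_embeddings, toy_ids, pad_id=3)   # average over real tokens only
print(plain, masked)
```

Either way, the pooled vector discards word order entirely, which is the limitation a self-attention layer is meant to address.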



### Let's start by loading the dataset

In [60]:
from datasets import load_dataset

# Load IMDb dataset from Hugging Face
raw_train_data, raw_test_data = load_dataset("imdb", split=["train", "test"])
print(len(raw_train_data), len(raw_test_data))

25000 25000


Let's look at some reviews

In [61]:
from IPython.display import display, HTML

display(HTML("""
<style>
    .custom-paragraph {
        width: 600px;
        margin: auto;
        line-height: 1.6;
    }
</style>
"""))

In [62]:
from itertools import islice
n_samples = 5
samples = list(islice(raw_train_data, n_samples))
#samples = list(raw_train_data.take(n_samples).cache())
html = ""
for d in samples:
    text = d['text']
    label = "Positive" if d['label'] == 1 else "Negative"
    html += f"""<div class="custom-paragraph">
              <p><b>[{label}]</b> {text}</p>
             </div>"""
display(HTML(html))

Preprocess the data

In [63]:
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter

nltk.download('punkt')

[nltk_data] Downloading package punkt to /Users/vikram/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [64]:
import re
from collections import Counter

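def clean_paragraph(text):
    """Lowercase a review and strip markup and unusual characters."""
    text = text.lower()
    # Replace HTML line breaks with spaces
    text = re.sub(r'<br\s*/?>', ' ', text)
    # Split hyphenated words into separate tokens
    text = text.replace('-', ' ')
    # Put a space after every period so words on either side are tokenized separately
    text = text.replace('.', '. ')
    # Keep only alphanumeric characters, punctuation, and whitespace
    cleaned_text = re.sub(r'[^a-z0-9\s\.,!?;:-]', '', text)
    return cleaned_text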

def process_raw_data(raw_data):
    """Clean and tokenize every review, and count token frequencies."""
    word_counts = Counter()
    processed_data = []
    for i, item in enumerate(raw_data):
        print(f"Processing item {i}", end="\r")
        text = clean_paragraph(item['text'])
        label = item['label']
        # Tokenize into words and punctuation
        tokens = word_tokenize(text)
        word_counts.update(tokens)
        processed_data.append((tokens, label))
    # (token, count) pairs, most frequent first
    sorted_counts = word_counts.most_common()
    return sorted_counts, processed_data


train_word_counts, clean_train_data = process_raw_data(raw_train_data)
test_word_counts, clean_test_data = process_raw_data(raw_test_data)

Processing item 24999

In [65]:
VOCAB_SIZE = 10000
vocab = [ t[0] for t in train_word_counts[:VOCAB_SIZE] ]
vocab += [ '<UNK>', '<PAD>' ]

In [66]:
tok_to_idx = {}
for i, tok in enumerate(vocab): 
    tok_to_idx[tok] = i

In [67]:
import numpy as np

EMBED_DIM = 100

def load_glove_embeddings(filepath):
    embeddings = {}
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            word = values[0]
            vector = np.asarray(values[1:], dtype='float32')
            embeddings[word] = vector
    return embeddings

# Load the 100-dimensional GloVe embeddings
glove_embeddings = load_glove_embeddings('../datasets/glove/glove.6B.100d.txt')

In [136]:
import numpy as np

def cosine_similarity(vec_a, vec_b):
    """
    Compute the cosine similarity between two vectors.
    
    Parameters:
    vec_a (np.ndarray): First input vector.
    vec_b (np.ndarray): Second input vector.
    
    Returns:
    float: Cosine similarity between the two vectors.
    """
    # Ensure the input vectors are 1-dimensional
    vec_a = vec_a.flatten()
    vec_b = vec_b.flatten()
    
    # Compute the dot product of the two vectors
    dot_product = np.dot(vec_a, vec_b)
    
    # Compute the L2 norms (magnitude) of the vectors
    norm_a = np.linalg.norm(vec_a)
    norm_b = np.linalg.norm(vec_b)
    
    # Compute cosine similarity
    if norm_a == 0 or norm_b == 0:
        return 0.0  # Return 0 if either vector is zero-length
    
    return dot_product / (norm_a * norm_b)

# Example usage
vector1 =  glove_embeddings["excellent"]
vector2 =  glove_embeddings["terrible"]
similarity = cosine_similarity(vector1, vector2)
print(f"Cosine Similarity: {similarity}")


Cosine Similarity: 0.31886589527130127


Let's build the embedding matrix

In [68]:
def build_embedding_matrix(vocab, glove_embeddings):
    # Start from zeros: tokens without a GloVe vector keep a zero embedding
    embedding_matrix = np.zeros((len(vocab), EMBED_DIM), dtype='float32')
    for idx, word in enumerate(vocab):
        if word in glove_embeddings:
            embedding_matrix[idx] = glove_embeddings[word]
    return embedding_matrix

# Create the embedding matrix
embedding_matrix = build_embedding_matrix(vocab, glove_embeddings)
embedding_matrix.shape

(10002, 100)

Let's build the dataset

In [69]:
import jax
import jax.numpy as jnp

SEQUENCE_LENGTH = 32

def get_idxs(tokens):
    # Substitute tokens with indices
    idxs = []
    for tok in tokens:
        if tok in tok_to_idx:
            idxs.append(tok_to_idx[tok])
        else:
            idxs.append(tok_to_idx['<UNK>'])

    return idxs


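def split_into_chunks(idxs, chunk_size):
    # Cut the index list into consecutive chunks of at most chunk_size tokens
    chunks = [idxs[i:i+chunk_size] for i in range(0, len(idxs), chunk_size)]
    # Only the last chunk can be short; pad it up to chunk_size
    for c in chunks:
        while len(c) < chunk_size:
            c.append(tok_to_idx['<PAD>'])
    return chunks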


def generate_dataset(data):

    X = []
    Y = []
    for i, x in enumerate(data):
        print(f"Processing {i}", end="\r")
        tokens = x[0]
        label = x[1]
        idxs = get_idxs(tokens)
        chunks = split_into_chunks(idxs, SEQUENCE_LENGTH)
        X += chunks 
        Y += [label]*len(chunks)
    X = jnp.array(X)
    Y = jnp.array(Y)
    return X, Y


X_train, Y_train = generate_dataset(clean_train_data)
print(X_train.shape, Y_train.shape)
X_test, Y_test = generate_dataset(clean_test_data)
print(X_test.shape, Y_test.shape)

(214588, 32) (214588,)
(210023, 32) (210023,)


Let's define the baseline model and the self-attention model

In [142]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax

class EmbeddingModel(nn.Module):
    embed_matrix: jnp.ndarray  # Preloaded pretrained embeddings
    num_classes: int

    @nn.compact
    def __call__(self, x):
        # Embed the input tokens
        embedding_layer = nn.Embed(
            num_embeddings=self.embed_matrix.shape[0],
            features=self.embed_matrix.shape[1],
            embedding_init=nn.initializers.constant(self.embed_matrix)
        )
        x = embedding_layer(x)
        x = jnp.mean(x, axis=1)
        logits = nn.Dense(self.num_classes)(x)

        return logits


class TransformerModel(nn.Module):
    embed_matrix: jnp.ndarray  # Preloaded pretrained embeddings
    num_classes: int

    @nn.compact
    def __call__(self, x):
        # Embed the input tokens
        embedding_layer = nn.Embed(
            num_embeddings=self.embed_matrix.shape[0],
            features=self.embed_matrix.shape[1],
            embedding_init=nn.initializers.constant(self.embed_matrix)
        )
        x = embedding_layer(x)
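        # Project the embeddings into queries, keys, and values
        q = nn.Dense(32)(x)
        k = nn.Dense(32)(x)
        v = nn.Dense(32)(x)

        # Scaled dot-product attention: the softmax over the key axis turns
        # the scores into weights that sum to 1 for each query token
        k = jnp.transpose(k, (0, 2, 1))
        r = jax.nn.softmax(q @ k / jnp.sqrt(32), axis=-1)
        x = r @ v
        x = jnp.mean(x, axis=1)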
        logits = nn.Dense(self.num_classes)(x)

        return logits



In [141]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax


# Define loss function and metrics
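def binary_cross_entropy_loss(logits, labels):
    """Binary cross-entropy loss, averaged over the batch."""
    # The model emits logits of shape (batch, 1) while labels have shape (batch,).
    # Squeeze the logits so the two line up instead of broadcasting to (batch, batch).
    logits = jnp.squeeze(logits, axis=-1)
    labels = labels.astype(jnp.float32)
    loss = optax.sigmoid_binary_cross_entropy(logits, labels)
    return jnp.mean(loss)

def compute_metrics(logits, labels):
    """Compute accuracy and loss metrics."""
    loss = binary_cross_entropy_loss(logits, labels)
    # logit > 0 is equivalent to sigmoid(logit) > 0.5
    predictions = (jnp.squeeze(logits, axis=-1) > 0).astype(labels.dtype)
    accuracy = jnp.mean(predictions == labels)
    return {'loss': loss, 'accuracy': accuracy}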

# Create the training state
def create_train_state(rng, model, dummy_input, learning_rate=5e-2):
    """Initialize the model and optimizer."""
    params = model.init(rng, dummy_input)  # Model parameter initialization
    tx = optax.adam(learning_rate)         # Adam optimizer
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@jax.jit
def train_step(state, batch):
    """Perform a single training step."""
    inputs, labels = batch

    def loss_fn(params):
        logits = state.apply_fn(params, inputs)
        loss = binary_cross_entropy_loss(logits, labels)
        return loss

    # Compute gradients
    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    # Update parameters
    state = state.apply_gradients(grads=grads)

    # Compute metrics
    metrics = compute_metrics(state.apply_fn(state.params, inputs), labels)
    return state, metrics

@jax.jit
def eval_step(state, batch):
    """Evaluate the model on a batch."""
    inputs, labels = batch
    logits = state.apply_fn(state.params, inputs)
    return compute_metrics(logits, labels)


NUM_CLASSES = 1     # Example: Binary classification (e.g., positive/negative sentiment)
SEQUENCE_LENGTH = 32  # Input sequence length (e.g., 32 words)
NUM_EPOCHS = 10
BATCH_SIZE = 64

# Initialize model and state
model = TransformerModel(embedding_matrix, NUM_CLASSES)
rng = jax.random.PRNGKey(42)
rng, init_rng = jax.random.split(rng)
dummy_input = jax.random.randint(rng, (BATCH_SIZE, SEQUENCE_LENGTH), 0, len(vocab))
state = create_train_state(init_rng, model, dummy_input)

# Training loop
for epoch in range(NUM_EPOCHS):
    # Training
    TRAIN_SIZE = X_train.shape[0]
    train_metrics = {'loss': 0, 'accuracy': 0}
    rng, sub_rng = jax.random.split(rng)
    perm = jax.random.permutation(sub_rng, TRAIN_SIZE)
    X_train_perm = X_train[perm]
    Y_train_perm = Y_train[perm]

    for i in range(TRAIN_SIZE // BATCH_SIZE):

        X_batch = X_train_perm[i*BATCH_SIZE:(i+1)*BATCH_SIZE]
        Y_batch = Y_train_perm[i*BATCH_SIZE:(i+1)*BATCH_SIZE]

        state, metrics = train_step(state, (X_batch, Y_batch))
        train_metrics = {k: train_metrics[k] + metrics[k] for k in metrics}

    # Average metrics over the entire epoch
    train_metrics = {k: v / (TRAIN_SIZE // BATCH_SIZE) for k, v in train_metrics.items()}
    print(f"Epoch {epoch + 1}, Train Loss: {train_metrics['loss']:.4f}, Train Accuracy: {train_metrics['accuracy']:.4f}")
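
One detail worth checking in a setup like this: a `Dense(1)` output layer produces logits of shape `(batch, 1)`, while the labels have shape `(batch,)`. Element-wise operations between the two broadcast to `(batch, batch)` rather than comparing one prediction with one label. The sketch below uses NumPy with made-up values to keep it dependency-free; `jax.numpy` follows the same broadcasting rules.

```python
import numpy as np

batch = 4
logits = np.zeros((batch, 1))          # shape produced by a Dense(1) output layer
labels = np.array([0, 1, 1, 0])        # shape (batch,)

# Element-wise comparison broadcasts (batch, 1) against (batch,) to (batch, batch)
mismatched = (logits > 0).astype(np.int32) == labels
print(mismatched.shape)                # (4, 4)

# Dropping the trailing axis compares one prediction with one label
aligned = (logits.squeeze(-1) > 0).astype(np.int32) == labels
print(aligned.shape)                   # (4,)
```

An accuracy computed from the `(batch, batch)` comparison averages over every prediction-label pair, so it tends to hover near the class balance regardless of how well the model is doing.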



### References:

1. Bahdanau et al. (2015), Neural Machine Translation by Jointly Learning to Align and Translate
2. Rocktäschel et al. (2016), Reasoning about Entailment with Neural Attention
3. Parikh et al. (2016), A Decomposable Attention Model for Natural Language Inference
4. Pennington et al. (2014), GloVe: Global Vectors for Word Representation
5. Lin et al. (2017), A Structured Self-Attentive Sentence Embedding
6. GloVe Embeddings