# How and Why Does Self Attention Work?

## Stacking Non-Linear Transformations

The prevailing paradigm in deep learning is to apply the same kind of non-linear transformation repeatedly: first to the input data, then to each successive output. Each transformation is called a layer. "Deep" architectures are so called because they have many layers; for example, the GPT-3 model trained at OpenAI had 96 layers in total.

Many of the key breakthroughs in recent years have focused on resolving the problems that crop up while training deep architectures. I have covered a few of these in previous posts, such as [Batch Normalization](...) and [Dropout](...).

The earliest deep-learning architecture was what is now called a Feed-Forward Network. In its current mature formulation, each layer consists of a linear operation followed by a non-linear activation function $(\sigma)$ such as ReLU (Rectified Linear Unit).

$$
f(\mathbf{x}) = \sigma(\mathbf{W}\mathbf{x} + \mathbf{b})
$$
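
To make the stacking concrete, here is a minimal NumPy sketch of the layer above applied several times in a row. The layer widths in `sizes`, the random weights, and the helper names (`relu`, `feed_forward_layer`, `deep_network`) are illustrative assumptions, not taken from any particular model.

```python
import numpy as np

def relu(z):
    # Rectified Linear Unit: negative entries are set to zero
    return np.maximum(z, 0.0)

def feed_forward_layer(x, W, b):
    # One layer: a linear map followed by the non-linear activation
    return relu(W @ x + b)

def deep_network(x, layers):
    # Each layer consumes the output of the previous one
    for W, b in layers:
        x = feed_forward_layer(x, W, b)
    return x

rng = np.random.default_rng(0)
sizes = [4, 8, 8, 3]  # hypothetical layer widths: input 4, two hidden layers of 8, output 3
layers = [
    (rng.normal(size=(n_out, n_in)), np.zeros(n_out))
    for n_in, n_out in zip(sizes[:-1], sizes[1:])
]

x = rng.normal(size=4)
y = deep_network(x, layers)
print(y.shape)  # (3,)
```

Each layer has its own weights, but the form of the transformation is identical, which is what makes it easy to stack.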


<img src="nn.svg" alt="Feed Forward Network" style="width:500px;"/>

Another popular architecture used primarily in computer vision is the Convolutional Neural Network (CNN), whose basic transformation is the convolution followed by an activation function. The idea here is that the model can learn to detect features in an image by performing non-linear transformations on small patches of the image. Then, just like the basic Feed-Forward network, the same operation can be performed on the outputs successively.


<img src="cnn.svg" alt="Convolutional Neural Network" style="width:800px;"/>

I won't go into too much detail on these here.


## A New Transformation

The idea of self-attention developed out of the sequence-to-sequence model architectures which were used primarily for machine translation. Here, the goal was to learn a single representation for an input sentence ("encode"), then use it to generate a translation ("decode").

However, it proved difficult to compress the information needed to translate a long sentence into a single representation.

Bahdanau et al. (2015) introduced the concept of attention: rather than rely solely on the learnt representation, the model could learn to use parts of the input sentence directly ("attend" to them) while decoding. Parikh et al. (2016) realized that the attention operation itself could be used for NLP tasks such as entailment. Lin et al. (2017) introduced the concept of self-attention to perform a variety of NLP tasks.

In my earlier post on [Nadaraya-Watson Regression](...), we saw how this classic non-parametric technique can be interpreted as an early form of attention. 
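
As a quick reminder of that connection, here is a minimal sketch of Nadaraya-Watson regression with a Gaussian kernel, written so that the attention reading is visible: the query is the point we want a prediction for, the keys are the training inputs, and the values are the training targets. The function names, the `bandwidth` parameter, and the toy data are assumptions made for illustration.

```python
import numpy as np

def softmax(scores):
    # Subtract the max for numerical stability; the result sums to 1
    scores = scores - scores.max()
    w = np.exp(scores)
    return w / w.sum()

def nadaraya_watson(query, keys, values, bandwidth=1.0):
    # Gaussian-kernel scores: keys closer to the query get larger scores
    scores = -((query - keys) ** 2) / (2.0 * bandwidth ** 2)
    # Normalizing the kernel values is exactly a softmax over the scores
    weights = softmax(scores)
    # The prediction is a weighted average of the values
    return weights @ values, weights

keys = np.array([0.0, 1.0, 2.0])      # training inputs
values = np.array([0.0, 10.0, 20.0])  # training targets
prediction, weights = nadaraya_watson(1.0, keys, values)
print(prediction, weights)
```

The weights play the role of attention weights: they are non-negative, sum to one, and concentrate on the keys most similar to the query.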

Vaswani et al. (2017) realized that the self-attention mechanism could be used as the basic non-linear transformation for sequences of variable length. Because self-attention has no recurrence, all tokens can be processed in parallel, with word order supplied separately through positional information (check out my post on positional embeddings [here](...)). By avoiding having to explicitly process each token in sequence, it became possible to train much deeper networks.




## So What Exactly is Self Attention?
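
As a starting point, here is a minimal NumPy sketch of single-head scaled dot-product self-attention, the form used by Vaswani et al. (2017): $\text{softmax}(QK^\top / \sqrt{d_k})V$, where the queries, keys, and values are all linear projections of the same input sequence. The sizes, the random projection matrices, and the function names are illustrative assumptions; a real layer would learn `W_q`, `W_k`, and `W_v` and usually adds multiple heads, masking, and an output projection.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    # X has shape (seq_len, d_model); every token yields a query, key, and value
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    # Similarity of every token's query with every token's key, scaled by sqrt(d_k)
    scores = Q @ K.T / np.sqrt(K.shape[-1])   # (seq_len, seq_len)
    weights = softmax(scores, axis=-1)        # each row sums to 1
    # Each output token is a weighted average of all value vectors
    return weights @ V, weights

rng = np.random.default_rng(0)
seq_len, d_model, d_head = 5, 8, 4  # hypothetical sizes
X = rng.normal(size=(seq_len, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))

out, weights = self_attention(X, W_q, W_k, W_v)
print(out.shape, weights.shape)  # (5, 4) (5, 5)
```

Nothing in the function depends on `seq_len`, so the same weights apply to sequences of any length.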



## Let's See If it Works

Let's train a sentiment classifier. For our dataset, we will use the IMDb review dataset with GloVe embeddings. Our baseline model is a simple feed-forward network that uses an average of the word embeddings. The candidate model will use an additional self-attention layer.



### Let's start by loading the dataset

In [1]:
from datasets import load_dataset

# Load IMDb dataset from Hugging Face
raw_train_data, raw_test_data = load_dataset("imdb", split=["train", "test"])




Let's look at some reviews

In [2]:
from IPython.display import display, HTML

display(HTML("""
<style>
    .custom-paragraph {
        width: 600px;
        margin: auto;
        line-height: 1.6;
    }
</style>
"""))

In [4]:
from itertools import islice
n_samples = 5
samples = list(islice(raw_train_data, n_samples))
#samples = list(raw_train_data.take(n_samples).cache())
html = ""
for d in samples:
    text = d['text']
    label = "Positive" if d['label'] == 1 else "Negative"
    html += f"""<div class="custom-paragraph">
              <p><b>[{label}]</b> {text}</p>
             </div>"""
display(HTML(html))

Preprocess the data

In [5]:
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter

nltk.download('punkt')

[nltk_data] Downloading package punkt to /Users/vikram/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [6]:
import re

def clean_paragraph(text):
    # Lowercase first so the patterns below only need to match lowercase text.
    text = text.lower()
    # Replace HTML line breaks (<br>, <br/>, <br />) with a space.
    text = re.sub(r'<br\s*/?>', ' ', text)
    # Split hyphenated words and make sure every period is followed by a space.
    text = text.replace('-', ' ')
    text = text.replace('.', '. ')
    # Keep only lowercase letters, digits, whitespace, and basic punctuation.
    cleaned_text = re.sub(r'[^a-z0-9\s.,!?;:]', '', text)
    return cleaned_text

def process_raw_data(raw_data):
    """Clean and tokenize every review; return word frequencies and (tokens, label) pairs."""
    word_counts = Counter()
    processed_data = []
    for i, item in enumerate(raw_data):
        print(f"Processing item {i}", end="\r")
        # Clean the raw review text
        text = clean_paragraph(item['text'])
        label = item['label']
        # Tokenize into words and update the running word frequencies
        tokens = word_tokenize(text)
        word_counts.update(tokens)
        processed_data.append((tokens, label))
    # List of (word, count) pairs, most frequent first
    sorted_counts = word_counts.most_common()
    return sorted_counts, processed_data


word_counts, processed_data = process_raw_data(raw_train_data)

Processing item 24999

In [7]:
# Keep the 9,998 most frequent words; the two special tokens bring the total to 10,000
VOCAB_SIZE = 9998
vocab = [ t[0] for t in word_counts[:VOCAB_SIZE] ]
vocab += [ '<UNK>', '<PAD>' ]
len(vocab)


10000
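
Before the reviews can be fed to a model, each token has to be mapped to its integer index in the vocabulary. Here is a minimal sketch of that step, assuming out-of-vocabulary words map to `<UNK>` and reviews are truncated or right-padded with `<PAD>` to a fixed length. The helper names `build_token_index` and `encode`, the `max_len` parameter, and the toy vocabulary are hypothetical and not part of the notebook.

```python
def build_token_index(vocab):
    # Map each token to its position in the vocabulary list
    return {token: idx for idx, token in enumerate(vocab)}

def encode(tokens, token_to_idx, max_len):
    unk, pad = token_to_idx['<UNK>'], token_to_idx['<PAD>']
    # Truncate to max_len and replace unknown tokens with <UNK>
    ids = [token_to_idx.get(t, unk) for t in tokens[:max_len]]
    # Right-pad short sequences with <PAD> so every row has the same length
    return ids + [pad] * (max_len - len(ids))

# Toy vocabulary laid out like the real one: words first, special tokens last
toy_vocab = ['the', 'movie', 'was', 'great', '<UNK>', '<PAD>']
token_to_idx = build_token_index(toy_vocab)

print(encode(['the', 'movie', 'was', 'awful'], token_to_idx, max_len=6))
# [0, 1, 2, 4, 5, 5]
```

With the real `vocab`, the same `encode` call would be applied to the token list of every `(tokens, label)` pair in `processed_data`.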

In [14]:
import numpy as np

EMBED_DIM = 100

def load_glove_embeddings(filepath):
    embeddings = {}
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            word = values[0]
            vector = np.asarray(values[1:], dtype='float32')
            embeddings[word] = vector
    return embeddings

# Load the 100-dimensional GloVe embeddings
glove_embeddings = load_glove_embeddings('../datasets/glove/glove.6B.100d.txt')

Let's build the embedding matrix

In [15]:
def build_embedding_matrix(vocab, glove_embeddings):
    embedding_matrix = np.random.normal(size=(len(vocab), EMBED_DIM)).astype('float32')
    for idx, word in enumerate(vocab):
        if word in glove_embeddings:
            embedding_matrix[idx] = glove_embeddings[word]
    return embedding_matrix

# Create the embedding matrix
embedding_matrix = build_embedding_matrix(vocab, glove_embeddings)
embedding_matrix.shape

(10000, 100)
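
With an embedding matrix in hand, the "average of the word embeddings" used by the baseline can be sketched in a few lines of NumPy. This version skips `<PAD>` positions when averaging, which is an assumption on my part: a plain mean over the sequence axis would include them. The function name `masked_mean_pool` and the toy matrix are illustrative only.

```python
import numpy as np

def masked_mean_pool(token_ids, embedding_matrix, pad_idx):
    # Look up one embedding per token: (batch, seq_len, embed_dim)
    embeddings = embedding_matrix[token_ids]
    # 1 for real tokens, 0 for padding: (batch, seq_len, 1)
    mask = (token_ids != pad_idx)[..., None]
    summed = (embeddings * mask).sum(axis=1)
    # Number of real tokens per sequence; clamp to 1 to avoid dividing by zero
    counts = np.maximum(mask.sum(axis=1), 1)
    return summed / counts

# Toy embedding matrix with 3 tokens and 2 dimensions; index 2 plays the role of <PAD>
toy_matrix = np.array([[1.0, 1.0], [3.0, 5.0], [0.0, 0.0]])
token_ids = np.array([[0, 1, 2], [1, 2, 2]])

pooled = masked_mean_pool(token_ids, toy_matrix, pad_idx=2)
print(pooled)  # rows: [2., 3.] and [3., 5.]
```

The result has one fixed-size vector per review, regardless of review length, which is what lets a single dense layer sit on top.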

Let's define the baseline model

In [22]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
import optax

class EmbeddingModel(nn.Module):
    vocab_size: int
    embed_dim: int
    num_classes: int

    @nn.compact
    def __call__(self, x):
        # Embed the input tokens
        embeddings = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim)(x)

        # Compute the mean of embeddings along the sequence dimension (global average pooling)
        pooled = jnp.mean(embeddings, axis=1)

        # A single dense layer for classification
        logits = nn.Dense(self.num_classes)(pooled)
        return logits

# Initialize model and input data
vocab_size = 5000   # Example: 5000 unique tokens
embed_dim = 128      # Size of embedding vectors
num_classes = 2      # Example: Binary classification (e.g., positive/negative sentiment)
sequence_length = 32  # Input sequence length (e.g., 32 words)

model = EmbeddingModel(vocab_size, embed_dim, num_classes)

# Random input: Batch of 8 sequences with length 32 (word indices in range [0, vocab_size-1])
rng = jax.random.PRNGKey(0)
dummy_input = jax.random.randint(rng, (8, sequence_length), 0, vocab_size)

# Initialize model parameters
variables = model.init(rng, dummy_input)

# Print the initialized parameters
print(variables)


{'params': {'Embed_0': {'embedding': Array([[-0.1462095 ,  0.06522508, -0.11902562, ..., -0.00575609,
         0.04394421,  0.00789563],
       ...,
       [ 0.03298381, -0.10392094,  0.1450322 , ...,  0.11992919,
        -0.08287467,  0.14703886]], dtype=float32)}, 'Dense_0': {'kernel': Array([[-0.09500722,  0.01477639],
       [ 0.01441283,  0.19017147],
       ...
(output truncated)

In [21]:
import jax

# Example input batch of token indices (batch_size=2, sequence_length=3)
input_data = jnp.array([[0, 1, 2], [1, 3, 4]])

# Instantiate the baseline model defined above, sized to our vocabulary and GloVe dimension
model = EmbeddingModel(vocab_size=len(vocab), embed_dim=EMBED_DIM, num_classes=2)

# Flax models are functional: initialize the parameters, then pass them to apply
params = model.init(jax.random.PRNGKey(42), input_data)
output = model.apply(params, input_data)

print(output.shape)  # (2, 2): one logit per class for each sequence

In [None]:
list(glove_embeddings.keys())[1]

','


### References:

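1. Bahdanau et al. (2015). Neural Machine Translation by Jointly Learning to Align and Translate.
2. Rocktäschel et al. (2016). Reasoning about Entailment with Neural Attention.
3. Parikh et al. (2016). A Decomposable Attention Model for Natural Language Inference.
4. Pennington et al. (2014). GloVe: Global Vectors for Word Representation.
5. Lin et al. (2017). A Structured Self-Attentive Sentence Embedding.
6. Vaswani et al. (2017). Attention Is All You Need.
7. GloVe pre-trained embeddings (glove.6B.100d).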