# How and Why Does Self-Attention Work?

## Stacking Non-Linear Transformations

The prevailing paradigm in deep learning is to repeatedly apply the same kind of non-linear transformation: first to the input data, and then successively to its own outputs. Each transformation is called a layer. "Deep" architectures are so called because they have many layers; for example, the largest GPT-3 model trained at OpenAI has 96 layers.

Many of the key breakthroughs in recent years have focused on resolving the problems that crop up while training deep architectures. I have covered a few of these in previous posts, such as [Batch Normalization](...) and [Dropout](...).

The earliest deep-learning architecture was what is now called a Feed-Forward Network, which, in its current mature formulation, consists of a linear operation followed by a non-linear activation function $(\sigma)$ such as the ReLU (Rectified Linear Unit):

$$
f(\mathbf{x}) = \sigma(\mathbf{W}\mathbf{x} + \mathbf{b})
$$
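To make the stacking concrete, here is a minimal sketch in plain JAX (no Flax) that applies the layer equation above, with ReLU as $\sigma$, three times in a row. The layer sizes and the scaled random initialization are illustrative assumptions, not taken from any particular model.

```python
import jax
import jax.numpy as jnp

def dense_layer(params, x):
    """One feed-forward layer: relu(W x + b)."""
    W, b = params
    return jax.nn.relu(W @ x + b)

def init_stack(key, sizes):
    """Random (W, b) pairs for layers mapping sizes[i] -> sizes[i + 1]."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        # Scale by sqrt(2 / n_in) so activations keep a similar size through ReLU layers
        W = jax.random.normal(subkey, (n_out, n_in)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params.append((W, b))
    return params

def forward(params, x):
    # Apply the same kind of transformation repeatedly,
    # each layer consuming the previous layer's output.
    for layer in params:
        x = dense_layer(layer, x)
    return x

params = init_stack(jax.random.PRNGKey(0), [4, 8, 8, 2])
y = forward(params, jnp.ones(4))
print(y.shape)  # (2,)
```

In practice the last layer usually omits the activation so it can output unrestricted scores; it is kept here only to mirror the equation.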


<img src="nn.svg" alt="Feed Forward Network" style="width:500px;"/>

Another popular architecture, used primarily in computer vision, is the Convolutional Neural Network (CNN), whose basic transformation is a convolution followed by an activation function. The idea is that the model can learn to detect features in an image by applying the same small non-linear transformation to every small patch of the image. Then, just as in the basic Feed-Forward Network, the same operation can be applied successively to the outputs.
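A rough sketch of that basic transformation, assuming a single-channel image, stride 1 and no padding: slide a small kernel over each patch, take a dot product, and apply ReLU. The hand-picked kernel below (a vertical-edge detector) is purely illustrative; in a CNN the kernel values are learned.

```python
import jax
import jax.numpy as jnp

def conv2d_relu(image, kernel, bias=0.0):
    """Dot product of `kernel` with every patch of `image` (stride 1, no padding), then ReLU.
    Like most deep-learning libraries, this is technically cross-correlation:
    the kernel is not flipped."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    rows = []
    for i in range(out_h):
        row = [jnp.sum(image[i:i + kh, j:j + kw] * kernel) for j in range(out_w)]
        rows.append(jnp.stack(row))
    return jax.nn.relu(jnp.stack(rows) + bias)

# An image that is dark on the left and bright on the right
image = jnp.array([[0., 0., 1., 1.]] * 4)
kernel = jnp.array([[-1., 1.],
                    [-1., 1.]])
feature_map = conv2d_relu(image, kernel)
print(feature_map)  # every row is [0, 2, 0]: the filter fires only at the edge
```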


<img src="cnn.svg" alt="Convolutional Neural Network" style="width:800px;"/>

I won't go into too much detail on these here.


## A New Transformation

The idea of self-attention grew out of the sequence-to-sequence architectures that were used primarily for machine translation. There, the goal was to learn a single representation of an input sentence ("encode") and then use it to generate a translation ("decode").

However, it proved difficult to compress all the information needed to translate a long sentence into a single fixed-length representation.

Bahdanau et al. (2014) introduced the concept of attention: while decoding, the model could learn to look directly at relevant parts of the input sentence ("attend") rather than rely solely on the learned representation. Parikh et al. (2016) realized that the attention operation itself could be used for NLP tasks such as entailment. Lin et al. (2017) applied self-attention to learn sentence embeddings for a variety of NLP tasks.

In my earlier post on [Nadaraya-Watson Regression](...), we saw how this classic non-parametric technique can be interpreted as an early form of attention. 
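As a reminder of that connection, here is a minimal sketch assuming a Gaussian kernel with a hypothetical `bandwidth` parameter: each training input acts as a key, each training target as a value, and the point we want a prediction for is the query. Normalizing the Gaussian kernel weights is exactly a softmax over negative scaled squared distances, since the kernel's normalizing constant cancels.

```python
import jax
import jax.numpy as jnp

def nadaraya_watson(x_query, x_train, y_train, bandwidth=1.0):
    """Gaussian-kernel Nadaraya-Watson regression written as attention.
    Queries: x_query, keys: x_train, values: y_train."""
    # Attention scores: negative scaled squared distance between each query and each key
    scores = -((x_query[:, None] - x_train[None, :]) ** 2) / (2 * bandwidth ** 2)
    # Softmax over the keys gives kernel weights that sum to 1 for each query
    weights = jax.nn.softmax(scores, axis=-1)
    # Prediction: weighted average of the values
    return weights @ y_train, weights

x_train = jnp.linspace(0.0, 5.0, 50)
y_train = jnp.sin(x_train)
preds, weights = nadaraya_watson(jnp.array([1.0, 2.5]), x_train, y_train, bandwidth=0.3)
```

A smaller bandwidth makes the attention weights sharper, so each prediction relies on fewer nearby training points.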

Vaswani et al. (2017) realized that self-attention could serve as the basic non-linear transformation for sequences of variable length. Because self-attention processes all tokens in parallel, word order has to be supplied separately as positional information (check out my post on positional embeddings [here](...)). By avoiding having to process each token in sequence, it became possible to train much deeper networks.
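Here is a minimal sketch of single-head scaled dot-product self-attention (the form used by Vaswani et al.) in plain JAX, with random projection matrices standing in for learned ones. It also illustrates why positional information is needed: shuffling the input tokens merely shuffles the outputs in the same way, so the layer on its own is blind to word order.

```python
import jax
import jax.numpy as jnp

def self_attention(x, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention over a (seq_len, d) input."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / jnp.sqrt(k.shape[-1])   # (seq_len, seq_len)
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v                         # every token is a weighted mix of all tokens

key = jax.random.PRNGKey(0)
kx, kq, kk, kv = jax.random.split(key, 4)
seq_len, d = 5, 8
x = jax.random.normal(kx, (seq_len, d))
# Random stand-ins for learned projection matrices
Wq, Wk, Wv = (jax.random.normal(k_, (d, d)) / jnp.sqrt(d) for k_ in (kq, kk, kv))

out = self_attention(x, Wq, Wk, Wv)

# Shuffle the tokens: the outputs come back shuffled in exactly the same way
perm = jnp.array([4, 2, 0, 3, 1])
out_perm = self_attention(x[perm], Wq, Wk, Wv)
print(jnp.allclose(out[perm], out_perm, atol=1e-5))  # True
```

Note that all positions are computed with a handful of matrix multiplications rather than a loop over tokens, which is what makes the computation parallel.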




## So What Exactly Is Self-Attention?



## Let's See If It Works

Let's train a sentiment classifier. For our dataset, we will use the IMDb movie review dataset with GloVe embeddings. Our baseline model is a simple feed-forward network that takes the average of the word embeddings as input. The candidate model adds a self-attention layer.
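The two models might be sketched in plain JAX as follows. This is an illustrative sketch rather than a finished implementation: the padding id (`PAD_ID = 0`), the hidden size, the single attention head and the residual connection are all assumptions. Both models ignore padded positions when averaging, and the candidate additionally prevents tokens from attending to padding.

```python
import jax
import jax.numpy as jnp

PAD_ID = 0  # hypothetical padding index for this sketch

def masked_mean(h, mask):
    """Average token vectors h (batch, seq, d), ignoring padded positions."""
    mask = mask[..., None].astype(h.dtype)
    return (h * mask).sum(axis=1) / jnp.maximum(mask.sum(axis=1), 1.0)

def classifier_head(params, pooled):
    hidden = jax.nn.relu(pooled @ params["W1"] + params["b1"])
    return hidden @ params["w2"] + params["b2"]        # one logit per review

def baseline_logits(params, token_ids):
    """Baseline: average the word embeddings, then a small feed-forward classifier."""
    mask = token_ids != PAD_ID
    h = params["embed"][token_ids]                     # (batch, seq, d)
    return classifier_head(params, masked_mean(h, mask))

def attention_logits(params, token_ids):
    """Candidate: one self-attention layer over the embeddings before averaging."""
    mask = token_ids != PAD_ID
    h = params["embed"][token_ids]
    q, k, v = h @ params["Wq"], h @ params["Wk"], h @ params["Wv"]
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(k.shape[-1])
    scores = jnp.where(mask[:, None, :], scores, -1e9)  # never attend to padding
    h = h + jax.nn.softmax(scores, axis=-1) @ v          # residual connection (assumption)
    return classifier_head(params, masked_mean(h, mask))

def init_params(key, vocab_size, d=100, hidden=64):
    ks = jax.random.split(key, 6)

    def normal(k, shape, fan_in):
        return jax.random.normal(k, shape) / jnp.sqrt(fan_in)

    return {
        "embed": normal(ks[0], (vocab_size, d), 1.0),  # would hold GloVe vectors in practice
        "Wq": normal(ks[1], (d, d), d), "Wk": normal(ks[2], (d, d), d), "Wv": normal(ks[3], (d, d), d),
        "W1": normal(ks[4], (d, hidden), d), "b1": jnp.zeros(hidden),
        "w2": normal(ks[5], (hidden,), hidden), "b2": jnp.zeros(()),
    }

params = init_params(jax.random.PRNGKey(0), vocab_size=6)
batch = jnp.array([[1, 2, 3, 0], [4, 5, 0, 0]])  # 0 = padding
print(baseline_logits(params, batch).shape, attention_logits(params, batch).shape)  # (2,) (2,)
```

Because padding is masked in both models, appending more padding to a batch leaves the logits unchanged.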



### Let's start by loading the dataset

```python
import tensorflow_datasets as tfds
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state

# Load the IMDb reviews dataset from TensorFlow Datasets as (text, label) pairs
def load_imdb_data(as_numpy=True):
    train_data, test_data = tfds.load('imdb_reviews', split=['train', 'test'], as_supervised=True)

    if as_numpy:
        train_data = tfds.as_numpy(train_data)
        test_data = tfds.as_numpy(test_data)

    return train_data, test_data


# Keep tf.data.Dataset objects so that .take() can be used below
train_data, test_data = load_imdb_data(as_numpy=False)
```

Let's look at some reviews:

```python
from IPython.display import display, HTML

# Narrow, centered paragraphs make the long reviews easier to read
display(HTML("""
<style>
    .custom-paragraph {
        width: 600px;
        margin: auto;
        line-height: 1.6;
    }
</style>
"""))
```

```python
n_samples = 5
# take() is enough to peek at a few examples; no .cache() is needed
samples = list(train_data.take(n_samples))
html = ""
for text, label in samples:
    text = text.numpy().decode('utf-8')
    label = "Positive" if int(label.numpy()) == 1 else "Negative"
    html += f"""<div class="custom-paragraph">
              <p><b>[{label}]</b> {text}</p>
             </div>"""
display(HTML(html))
```



### Now, let's load the GloVe embeddings



```python
import numpy as np

def load_glove_embeddings(filepath):
    """Parse a GloVe text file: each line is a word followed by its vector components."""
    embeddings = {}
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            word = values[0]
            vector = np.asarray(values[1:], dtype='float32')
            embeddings[word] = vector
    return embeddings

# Load the 100-dimensional GloVe embeddings
glove_embeddings = load_glove_embeddings('datasets/glove.6B.100d.txt')
```

Next, let's build the embedding matrix. Words without a GloVe vector keep a random initialization:

```python
def build_embedding_matrix(vocab, glove_embeddings, embedding_dim=100):
    # Start from random vectors, then overwrite rows for words that GloVe covers
    embedding_matrix = np.random.normal(size=(len(vocab), embedding_dim)).astype('float32')
    for word, idx in vocab.items():
        if word in glove_embeddings:
            embedding_matrix[idx] = glove_embeddings[word]
    return embedding_matrix

# Example (toy) vocabulary mapping words to row indices
vocab = {'the': 0, 'dog': 1, 'ran': 2, 'fast': 3, '<PAD>': 4, '<UNK>': 5}

# Create the embedding matrix
embedding_dim = 100
embedding_matrix = build_embedding_matrix(vocab, glove_embeddings, embedding_dim)
```
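The vocabulary above is a toy example. For the real reviews, one hypothetical way to build a vocabulary and turn texts into padded integer arrays is sketched below; the regex tokenizer, `max_size` and `max_len` are assumptions. It lowercases text because the glove.6B vectors are uncased, and strips the `<br />` tags that appear in IMDb reviews.

```python
import re
from collections import Counter

import numpy as np

def tokenize(text):
    """Lowercase, drop <br /> tags, split into word and punctuation tokens (a simple assumed tokenizer)."""
    text = re.sub(r"<br\s*/?>", " ", text.lower())
    return re.findall(r"[a-z0-9']+|[^\sa-z0-9']", text)

def build_vocab(texts, max_size=20000):
    """Reserve <PAD>=0 and <UNK>=1, then add the most frequent tokens."""
    counts = Counter(tok for t in texts for tok in tokenize(t))
    vocab = {"<PAD>": 0, "<UNK>": 1}
    for word, _ in counts.most_common(max_size - len(vocab)):
        vocab[word] = len(vocab)
    return vocab

def encode(texts, vocab, max_len=200):
    """Map texts to a (len(texts), max_len) int array, truncating and padding with <PAD>."""
    ids = np.full((len(texts), max_len), vocab["<PAD>"], dtype=np.int32)
    for i, t in enumerate(texts):
        toks = [vocab.get(tok, vocab["<UNK>"]) for tok in tokenize(t)][:max_len]
        ids[i, :len(toks)] = toks
    return ids

texts = ["The dog ran fast.", "The cat sat."]
vocab = build_vocab(texts)
ids = encode(["The dog sat on the mat."], vocab, max_len=8)
print(ids)
```

The resulting `vocab` can be passed to `build_embedding_matrix` unchanged; unseen words such as "on" and "mat" map to `<UNK>`.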


Let's define the baseline model, starting with an embedding layer initialized from the GloVe matrix:

```python
from flax import linen as nn
import jax.numpy as jnp
from flax.core.frozen_dict import freeze

class MyModel(nn.Module):
    vocab_size: int
    embedding_dim: int

    @nn.compact
    def __call__(self, x):
        # Flax calls initializers as init_fn(key, shape, dtype), so the
        # initializer must accept the PRNG key even though it ignores it.
        embedding_layer = nn.Embed(
            num_embeddings=self.vocab_size,
            features=self.embedding_dim,
            embedding_init=lambda key, shape, dtype=jnp.float32: jnp.asarray(embedding_matrix, dtype=dtype)
        )
        embedded = embedding_layer(x)  # (batch, seq_len, embedding_dim)
        return embedded
```

```python
import jax

# Example input batch (batch_size=2, sequence_length=3); index 4 is '<PAD>'
input_data = jnp.array([[0, 1, 2], [1, 3, 4]])

# Initialize and apply the model
model = MyModel(vocab_size=len(vocab), embedding_dim=embedding_dim)

# Flax models are functional, so parameters are initialized explicitly and passed to apply()
params = model.init(jax.random.PRNGKey(42), input_data)
output = model.apply(freeze(params), input_data)

print(output.shape)  # (2, 3, 100)
```
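Training either model comes down to minimizing binary cross-entropy between the predicted logit and the 0/1 label. A minimal sketch with `jax.value_and_grad` and plain gradient descent follows; the linear classifier, the learning rate and the synthetic features (standing in for pooled review embeddings) are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp

def bce_loss(logits, labels):
    """Binary cross-entropy on raw logits, using log_sigmoid for numerical stability."""
    return -jnp.mean(labels * jax.nn.log_sigmoid(logits)
                     + (1 - labels) * jax.nn.log_sigmoid(-logits))

def predict(params, features):
    # Stand-in classifier: a single linear layer over pooled features
    return features @ params["w"] + params["b"]

def loss_fn(params, features, labels):
    return bce_loss(predict(params, features), labels)

@jax.jit
def sgd_step(params, features, labels, lr=0.1):
    loss, grads = jax.value_and_grad(loss_fn)(params, features, labels)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

# Synthetic stand-ins for pooled embeddings: positives shifted one way, negatives the other
key = jax.random.PRNGKey(0)
labels = (jnp.arange(64) % 2).astype(jnp.float32)
features = jax.random.normal(key, (64, 100)) + (labels[:, None] * 2 - 1) * 0.5

params = {"w": jnp.zeros(100), "b": jnp.zeros(())}
losses = []
for _ in range(50):
    params, loss = sgd_step(params, features, labels)
    losses.append(float(loss))
print(losses[0], losses[-1])
```

In a real run, an optimizer library such as Optax would typically replace the hand-written update, and the features would come from one of the models above.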




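### References

1. Bahdanau et al. (2014). Neural Machine Translation by Jointly Learning to Align and Translate.
2. Rocktäschel et al. (2016). Reasoning about Entailment with Neural Attention.
3. Parikh et al. (2016). A Decomposable Attention Model for Natural Language Inference.
4. Pennington et al. (2014). GloVe: Global Vectors for Word Representation.
5. Lin et al. (2017). A Structured Self-Attentive Sentence Embedding.
6. GloVe pre-trained embeddings (glove.6B).
7. Vaswani et al. (2017). Attention Is All You Need.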