## Basic Setup

Run the cells below for the basic setup of this notebook.

In [None]:
try:
    from google.colab import drive
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
    print('No colab environment, assuming local setup.')

if IN_COLAB:
    drive.mount('/content/drive')

    # TODO: Enter the foldername in your Drive where you have saved the unzipped
    # tutorials folder, e.g. 'alphafold-decoded/tutorials'
    FOLDERNAME = None
    assert FOLDERNAME is not None, "[!] Enter the foldername."

    # Now that we've mounted your Drive, this ensures that
    # the Python interpreter of the Colab VM can load
    # python files from within it.
    import sys
    sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))
    %cd /content/drive/My\ Drive/$FOLDERNAME

    print('Connected COLAB to Google Drive.')

import os
    
base_folder = 'attention'
control_folder = f'{base_folder}/control_values'

assert os.path.isdir(control_folder), 'Folder "control_values" not found, make sure that FOLDERNAME is set correctly.' if IN_COLAB else 'Folder "control_values" not found, make sure that your root folder is set correctly.'

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import math
import torch
import os

In [3]:
# You might need to change this path, depending on your directory structure and current working directory
control_folder = 'attention/control_values'
assert os.path.isdir(control_folder), 'Folder for control values not found. Please check your path.'

# Attention

Attention is the mechanism behind most of the biggest breakthroughs in machine learning in recent years. Google published the original Transformer paper under the title 'Attention Is All You Need', and so far it has lived up to that title.

In this Notebook, we will implement the following attention mechanisms:

- MultiHeadAttention
- Gated MultiHeadAttention
- Global Gated MultiHeadAttention

These modules will do the heavy lifting for the Evoformer, the first part of AlphaFold's architecture. The rest of the Evoformer will mostly be about stacking the layers correctly. All of them will be implemented in the class `MultiHeadAttention`.

To get started, head over to `mha.py` and implement the `__init__` method and `prepare_qkv`. Don't worry about the `is_global` parameter for now; treat it as if it were set to False. `prepare_qkv` rearranges the tensors so that the different heads are split up and the attention dimension is moved to a fixed position.
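
As a rough illustration of what "splitting up the heads" and "moving the attention dimension" can look like, here is a minimal sketch. The helper name `split_heads` and the target layout `(..., N_head, q, c)` are assumptions made for this example; the layout expected by `mha.py` and the control values may differ, so follow the instructions in the file for your actual implementation.

```python
import torch

def split_heads(x: torch.Tensor, n_head: int, attn_dim: int) -> torch.Tensor:
    """Hypothetical helper: (..., q, ..., n_head*c) -> (..., n_head, q, c).

    The output layout is an assumption for illustration only.
    """
    # Move the attention dimension to the second-to-last position.
    x = x.movedim(attn_dim, -2)
    # Split the channel dimension into (n_head, c).
    x = x.reshape(x.shape[:-1] + (n_head, -1))
    # Swap the head and attention dimensions: (..., q, n_head, c) -> (..., n_head, q, c).
    return x.transpose(-2, -3)

# Example: 6 positions along the attention dimension, 4 heads with 10 channels each.
x = torch.randn(3, 5, 6, 8, 4 * 10)
print(split_heads(x, n_head=4, attn_dim=-3).shape)  # torch.Size([3, 5, 8, 4, 6, 10])
```

With the attention dimension in a fixed position, the later matrix multiplications can treat every leading dimension as a batch dimension.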

Run the following code cell to check your implementation.

In [4]:
from attention.mha import MultiHeadAttention

c_in = 8
c = 10
N_head = 4
attn_dim = -3

mha = MultiHeadAttention(c_in, c, N_head, attn_dim=attn_dim, gated=True)

param_shapes = { name: param.shape for name, param in mha.named_parameters()}

expected_shapes = {
    'linear_q.weight': (40, 8),
    'linear_k.weight': (40, 8), 
    'linear_v.weight': (40, 8), 
    'linear_o.weight': (8, 40), 
    'linear_o.bias': (8,),
    'linear_g.weight': (40, 8), 
    'linear_g.bias': (40,),
}

assert param_shapes.keys() == expected_shapes.keys()
assert param_shapes.items() == expected_shapes.items()
    

In [5]:
qkv_shape = (3, 5, 4, 8) + (N_head*c,)
q = torch.linspace(-4, 4, steps=math.prod(qkv_shape)).reshape(qkv_shape)
k = torch.linspace(-3, 3, steps=math.prod(qkv_shape)).reshape(qkv_shape)
v = torch.linspace(-2, 2, steps=math.prod(qkv_shape)).reshape(qkv_shape)

q_prep, k_prep, v_prep = mha.prepare_qkv(q, k, v)

expected_q = torch.load(f'{control_folder}/prepped_q_local.pt')
expected_k = torch.load(f'{control_folder}/prepped_k_local.pt')
expected_v = torch.load(f'{control_folder}/prepped_v_local.pt')

assert torch.allclose(q_prep, expected_q)
assert torch.allclose(k_prep, expected_k)
assert torch.allclose(v_prep, expected_v)

Next, implement the forward pass through the MultiHeadAttention module. Again, don't worry about global attention for now. The method contains step-by-step instructions for the implementation.
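
For orientation, the core of such a forward pass is usually scaled dot-product attention. The sketch below is not the solution for `mha.py`; it assumes the inputs are already prepared in a `(..., N_head, q, c)` layout, and the function name `attention_core` is made up for this example. It leaves out the linear embeddings, the gating, and the output projection.

```python
import math
import torch

def attention_core(q, k, v, bias=None):
    """Scaled dot-product attention on tensors of shape (..., n_head, n, c).

    Hypothetical helper for illustration; assumes heads are already split.
    """
    c = q.shape[-1]
    # Similarity of every query with every key, scaled by sqrt(c).
    scores = q @ k.transpose(-1, -2) / math.sqrt(c)
    if bias is not None:
        # The bias is broadcast against the leading batch dimensions.
        scores = scores + bias
    # Normalize over the key dimension.
    weights = torch.softmax(scores, dim=-1)
    # Weighted average of the value vectors for each query.
    return weights @ v

q = torch.randn(3, 4, 5, 10)
k = torch.randn(3, 4, 5, 10)
v = torch.randn(3, 4, 5, 10)
bias = torch.zeros(4, 5, 5)
print(attention_core(q, k, v, bias=bias).shape)  # torch.Size([3, 4, 5, 10])
```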

After you're done, check your implementation with the following cell:

In [6]:
c_in = 8
c = 10
N_head = 4
attn_dim = -3
query_dim = 4

inp_shape = (3, 5, query_dim, 8, c_in)
inp = torch.linspace(-4, 4, math.prod(inp_shape)).reshape(inp_shape)

# Check for ungated attention

mha_ungated = MultiHeadAttention(c_in, c, N_head, attn_dim=attn_dim, gated=False)

with torch.no_grad():
    for key, param in mha_ungated.named_parameters():
        param.copy_(torch.linspace(-4, 4, param.numel(), device=param.device).reshape(param.shape))

    out_ungated = mha_ungated(inp)

expected_result = torch.load(f'{control_folder}/forward_nogate_nonglobal.pt')

assert torch.allclose(out_ungated, expected_result)

# Check for gated attention

mha_gated = MultiHeadAttention(c_in, c, N_head, attn_dim=attn_dim, gated=True)

with torch.no_grad():
    for key, param in mha_gated.named_parameters():
        param.copy_(torch.linspace(-4, 4, param.numel(), device=param.device).reshape(param.shape))

    out_gated = mha_gated(inp)

expected_result = torch.load(f'{control_folder}/forward_gated_nonglobal.pt')

assert torch.allclose(out_gated, expected_result)

# Check for gated attention with bias

bias = torch.linspace(-1, 1, N_head*query_dim**2).reshape(N_head, query_dim, query_dim)
with torch.no_grad():
    out_with_bias = mha_gated(inp, bias=bias)

expected_result = torch.load(f'{control_folder}/forward_gated_bias_nonglobal.pt')

assert torch.allclose(out_with_bias, expected_result)




Finally, we will implement the global self-attention mechanism. It is used in the ExtraMSA stack in AlphaFold to cope with the large number of sequences.

Global self-attention has two major differences:
- For the key and value embeddings, only one head is used
- The query vectors will be averaged over the query dimension, so that only one query vector will be used for the attention mechanism

Thinking back to the attention mechanism, the number of query vectors determines the number of outputs of the layer, so the global attention variant would reduce the number of outputs. However, AlphaFold only uses gated global attention, and the number of outputs is restored when broadcasting the weighted value vectors against the gate embedding.
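
The shape bookkeeping described above can be sketched as follows. This is an illustration under assumptions, not the tutorial's implementation: the function name `global_attention_sketch` and the `(N_head, q, c)` layout are chosen for this example, and the gate is assumed to be a tensor of values in (0, 1), as a sigmoid would produce.

```python
import torch

def global_attention_sketch(q, k, v, gate):
    """q, gate: (n_head, n, c); k, v: (1, n, c), one head shared by all query heads."""
    c = q.shape[-1]
    # Average the queries over the query dimension: (n_head, 1, c).
    q_avg = q.mean(dim=-2, keepdim=True)
    # One row of attention weights per head: (n_head, 1, n).
    weights = torch.softmax(q_avg @ k.transpose(-1, -2) / c ** 0.5, dim=-1)
    # One weighted value vector per head: (n_head, 1, c).
    out = weights @ v
    # Broadcasting against the gate restores the query dimension: (n_head, n, c).
    return gate * out

n_head, n, c = 4, 6, 10
q = torch.randn(n_head, n, c)
k = torch.randn(1, n, c)
v = torch.randn(1, n, c)
gate = torch.sigmoid(torch.randn(n_head, n, c))
print(global_attention_sketch(q, k, v, gate).shape)  # torch.Size([4, 6, 10])
```

Before gating there is only a single output vector per head; the per-position gate is what makes the outputs differ along the query dimension.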

Implement the method `prepare_qkv_global`. Also, modify the `__init__` method so that the key and value embeddings use only one head when `is_global` is set, and modify the `forward` method so that `prepare_qkv_global` is called instead of `prepare_qkv` if `is_global` is set. You won't need any other modifications to `forward`, but it is worth reading through the function carefully to see why that is the case.

Test your code with the following cells.

In [7]:
c_in = 8
c = 10
N_head = 4
attn_dim = -3
query_dim = 4

inp_shape = (3, 5, query_dim, 8, c_in)
inp = torch.linspace(-4, 4, math.prod(inp_shape)).reshape(inp_shape)

mha_global = MultiHeadAttention(c_in, c, N_head, attn_dim=attn_dim, gated=True, is_global=True)

# Test for prepare_qkv_global

qkv_shape = (3, 5, 4, 8) + (N_head*c,)
q = torch.linspace(-4, 4, steps=math.prod(qkv_shape)).reshape(qkv_shape)
k = torch.linspace(-3, 3, steps=math.prod(qkv_shape)).reshape(qkv_shape)
v = torch.linspace(-2, 2, steps=math.prod(qkv_shape)).reshape(qkv_shape)

with torch.no_grad():
    q_prep, k_prep, v_prep = mha_global.prepare_qkv_global(q, k, v)


expected_q = torch.load(f'{control_folder}/prepped_q_global.pt')
expected_k = torch.load(f'{control_folder}/prepped_k_global.pt')
expected_v = torch.load(f'{control_folder}/prepped_v_global.pt')

assert torch.allclose(q_prep, expected_q)
assert torch.allclose(k_prep, expected_k)
assert torch.allclose(v_prep, expected_v)

In [8]:
with torch.no_grad():
    for key, param in mha_global.named_parameters():
        param.copy_(torch.linspace(-4, 4, param.numel(), device=param.device).reshape(param.shape))

    out = mha_global(inp)

expected_result = torch.load(f'{control_folder}/forward_global.pt')

assert torch.allclose(expected_result, out)

## Conclusion

With this chapter, we are done with the introductory material. In the next chapter, we will implement the input feature extractor, the module that builds the numeric input tensors for the model from the raw MSA text file.

If you want to learn more about attention, you can check out the later assignments from CS231n (the Computer Vision course from Stanford we suggested in the last chapter) or the [Annotated Transformer](http://nlp.seas.harvard.edu/annotated-transformer/), an online Jupyter Notebook that explains the Transformer Architecture, which powers modern LLMs like ChatGPT.