
# Shifted Sparse Attention Code Demo

In [1]:
import torch
import torch.nn.functional as F

In [None]:
# B: batch size; N: sequence length or number of tokens; G: group size;
# H: number of attention heads; D: dimension of each attention head

In [4]:
batch_size = 1
seq_len = 8
num_heads = 4
head_dim = 16
group_size = 4

In [6]:
x = torch.randn(batch_size, seq_len, num_heads, head_dim)
assert seq_len % group_size == 0, "Sequence length must be divisible by group size"

In [7]:
# Split heads into two groups
x1, x2 = x.chunk(2, dim=2)

In [9]:
x1.shape, x2.shape

(torch.Size([1, 8, 2, 16]), torch.Size([1, 8, 2, 16]))
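To make the purpose of the two halves visible, here is a minimal sketch, assuming the (batch, sequence, heads, head_dim) layout used above. It rolls the second half along the sequence dimension by half a group and concatenates the halves again. The `positions` tensor is a hypothetical stand-in for `x` whose entries are simply the token indices, and `head_dim` is set to 1 only to keep the printout small.

```python
import torch

batch_size, seq_len, num_heads, head_dim = 1, 8, 4, 1
group_size = 4

# Hypothetical stand-in for x: every entry holds its own token index.
positions = torch.arange(seq_len).view(1, seq_len, 1, 1)
positions = positions.expand(batch_size, seq_len, num_heads, head_dim).clone()

# Split the heads into two halves, as in the cell above.
p1, p2 = positions.chunk(2, dim=2)

# Roll the second half backwards along the sequence dimension (dim=1)
# by half the group size, then put the two halves back together.
p2_shifted = p2.roll(-(group_size // 2), dims=1)
shifted = torch.cat([p1, p2_shifted], dim=2)

first_half_tokens = shifted[0, :, 0, 0].tolist()    # a head from the first half
second_half_tokens = shifted[0, :, -1, 0].tolist()  # a head from the second half
print(first_half_tokens)   # [0, 1, 2, 3, 4, 5, 6, 7]
print(second_half_tokens)  # [2, 3, 4, 5, 6, 7, 0, 1]
```

Heads in the first half keep the original token order, while heads in the second half see the sequence offset by two positions, with the first two tokens wrapping around to the end.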

In [19]:
# Example usage
B, N, H, D = 1, 16, 4, 8  # Batch size, Sequence length, Number of heads, Head dimension
G = 4  # Group size

# Create dummy input tensors
query_states = torch.randn(B, N, H, D)
key_states = torch.randn(B, N, H, D)
value_states = torch.randn(B, N, H, D)

In [20]:
query_states.shape, key_states.shape, value_states.shape

(torch.Size([1, 16, 4, 8]),
 torch.Size([1, 16, 4, 8]),
 torch.Size([1, 16, 4, 8]))

The `shift` function is the core of the S2-Attn mechanism. It does two main things:

1. Shifting: It rolls the second half of the attention heads along the sequence dimension by half the group size. The groups seen by those heads then straddle the boundaries of the groups seen by the first half, which lets information flow between neighboring groups.
2. Reshaping: It folds the sequence into groups of `group_size` tokens and moves the groups into the batch dimension, so that attention can be computed for all groups in parallel.

Once defined, the function is applied to the query, key, and value states so that all three are shifted and reshaped in the same way.

In [22]:
def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
    # qkv has shape (bsz, q_len, num_heads, head_dim) and is modified in place.
    # Shift the second half of the heads (dim=2) by half the group size along
    # the sequence dimension (dim=1). The offset lets information flow between groups.
    qkv[:, :, num_heads // 2:] = qkv[:, :, num_heads // 2:].roll(-(group_size // 2), dims=1)

    # Reshape the tensor for grouped attention:
    # 1. Fold each run of group_size consecutive tokens into the batch dimension
    # 2. Transpose to put the heads dimension before the group dimension,
    #    giving shape (bsz * (q_len // group_size), num_heads, group_size, head_dim)
    qkv = qkv.reshape(bsz * (q_len // group_size), group_size, num_heads, head_dim).transpose(1, 2)
    return qkv

# Apply the shift operation to query, key, and value states
# This prepares all three components for the S2-Attn mechanism
query_states = shift(query_states, B, N, G, H, D)
key_states = shift(key_states, B, N, G, H, D)
value_states = shift(value_states, B, N, G, H, D)

In [23]:
query_states.shape, key_states.shape, value_states.shape

(torch.Size([4, 4, 4, 8]), torch.Size([4, 4, 4, 8]), torch.Size([4, 4, 4, 8]))

The batch dimension has grown by a factor of N // G = 4 (from 1 to 4), and each entry now holds a single group of G = 4 tokens, a quarter of the original sequence length of 16. The layout is (B * (N // G), H, G, D).
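To see which tokens end up in which group, the sketch below repeats the shifting and reshaping steps described above on a hypothetical probe tensor whose entries are the token indices. It assumes a (B, N, H, D) input layout and uses D = 1 to keep the output readable; the names `tokens` and `grouped` are illustrative only.

```python
import torch

B, N, H, D = 1, 16, 4, 1
G = 4

# Hypothetical probe tensor: every entry holds its own token index.
tokens = torch.arange(N).view(1, N, 1, 1).expand(B, N, H, D).clone()

# Roll the second half of the heads by half a group along the sequence.
tokens[:, :, H // 2:] = tokens[:, :, H // 2:].roll(-(G // 2), dims=1)

# Fold runs of G consecutive tokens into the batch dimension,
# then move heads in front of the group dimension: (B * N/G, H, G, D).
grouped = tokens.reshape(B * (N // G), G, H, D).transpose(1, 2)

unshifted_groups = grouped[:, 0, :, 0].tolist()   # head 0 (first half)
shifted_groups = grouped[:, -1, :, 0].tolist()    # head H-1 (second half)
print(unshifted_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
print(shifted_groups)    # [[2, 3, 4, 5], [6, 7, 8, 9], [10, 11, 12, 13], [14, 15, 0, 1]]
```

Each shifted group overlaps two neighboring unshifted groups, which is how tokens in different groups can exchange information once the heads are combined.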

In [25]:
# Simulate attention (in reality, you would perform actual attention here)
# For demonstration, we'll just return the query states
attn_output = query_states
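For reference, the placeholder above could be replaced by a real attention call. The following is a minimal sketch rather than the notebook's method: it assumes PyTorch 2.0 or later (for `F.scaled_dot_product_attention`), uses fresh random tensors `q`, `k`, `v` as hypothetical stand-ins for the grouped states, and applies no attention mask, which a causal language model would normally add.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, N, H, D, G = 1, 16, 4, 8, 4

# Hypothetical grouped inputs with layout (B * N/G, H, G, D).
q = torch.randn(B * (N // G), H, G, D)
k = torch.randn(B * (N // G), H, G, D)
v = torch.randn(B * (N // G), H, G, D)

# Attention is computed independently for every (group, head) pair,
# so each token only attends to the G tokens in its own group.
attn_output = F.scaled_dot_product_attention(q, k, v)

# Reference computation: softmax(Q K^T / sqrt(D)) V.
scores = q @ k.transpose(-2, -1) / math.sqrt(D)
reference = torch.softmax(scores, dim=-1) @ v

print(attn_output.shape)
print(torch.allclose(attn_output, reference, atol=1e-5))
```

Because the groups sit in the batch dimension, the score matrix per head is G x G instead of N x N, which is where the savings of grouped attention come from.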

In [26]:
# Reshape back to original shape
attn_output = attn_output.transpose(1, 2).reshape(B, N, H, D)

Let's break it down:

`attn_output.transpose(1, 2)`:

* This swaps the second and third dimensions.
* The shape changes from (B * (N // G), H, G, D) to (B * (N // G), G, H, D).


`.reshape(B, N, H, D)`:

* This reshapes the tensor back to its original dimensions.
* It merges the first two dimensions, B * (N // G) and G, back into B and N.

This works because the number of elements is unchanged and, after the transpose, the elements are again ordered group by group and token by token, so merging (B * (N // G), G) into (B, N) puts every token back at the position it had before grouping. The step reverses the reshaping done in the `shift` function and returns the grouped attention output to the original shape, which later layers expect. The second half of the heads is still rolled at this point; that is undone in the next step.

In [27]:
attn_output.shape

torch.Size([1, 16, 4, 8])
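The claim that the reshape undoes the grouping can be checked directly. This is a small sketch, assuming the grouping folds runs of G consecutive tokens of a (B, N, H, D) tensor into the batch dimension; the shift is left out so that only the reshaping is tested, and `x`, `grouped`, and `restored` are illustrative names.

```python
import torch

B, N, H, D, G = 1, 16, 4, 8, 4
x = torch.randn(B, N, H, D)

# Forward: fold groups of G consecutive tokens into the batch dimension,
# then move heads in front of the group dimension -> (B * N/G, H, G, D).
grouped = x.reshape(B * (N // G), G, H, D).transpose(1, 2)

# Backward: undo the transpose, then unfold the groups -> (B, N, H, D).
restored = grouped.transpose(1, 2).reshape(B, N, H, D)

# Token 5 is the second token (index 1) of the second group (index 1).
same_token = torch.equal(grouped[1, :, 1], x[0, 5])

print(grouped.shape, restored.shape)
print(same_token, torch.equal(restored, x))
```

Both checks print `True`: no values are changed by the grouping, only their arrangement.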

In [28]:
# Unshift the second half of heads
attn_output[:, :, H // 2:] = attn_output[:, :, H // 2:].roll(G // 2, dims=1)





* `attn_output[:, :, H // 2:]` selects the second half of the attention heads.
* `.roll(G // 2, dims=1)` rolls (circularly shifts) these heads along the sequence dimension (`dim=1`) by half the group size (`G // 2`).

This roll is the inverse of the shift applied in the `shift` function, so every token in the second half of the heads returns to its original position.

In [29]:
attn_output.shape

torch.Size([1, 16, 4, 8])
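As a final sanity check, the whole pipeline can be run with identity "attention": if shifting, grouping, ungrouping, and unshifting are mutually consistent, the input must come back unchanged. The sketch below assumes a (B, N, H, D) layout; `shift_and_group` and `ungroup_and_unshift` are hypothetical helper names introduced only for this check.

```python
import torch

B, N, H, D, G = 1, 16, 4, 8, 4
x = torch.randn(B, N, H, D)

def shift_and_group(t):
    # Hypothetical helper for a (B, N, H, D) layout; works on a copy.
    t = t.clone()
    t[:, :, H // 2:] = t[:, :, H // 2:].roll(-(G // 2), dims=1)
    return t.reshape(B * (N // G), G, H, D).transpose(1, 2)

def ungroup_and_unshift(t):
    # Inverse of shift_and_group: undo the grouping, then roll back.
    t = t.transpose(1, 2).reshape(B, N, H, D)
    t[:, :, H // 2:] = t[:, :, H // 2:].roll(G // 2, dims=1)
    return t

# With identity "attention" the round trip should reproduce the input exactly.
restored = ungroup_and_unshift(shift_and_group(x))
print(torch.equal(restored, x))
```

Writing the shift as `-(G // 2)` rather than `-G // 2` keeps the two rolls exact inverses even for an odd group size, since `-G // 2` rounds toward negative infinity in Python.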