In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention, create_nested_block_mask

def generate_alibi_bias(H: int):
    """Returns an alibi bias score_mod given the number of heads H
    Args:
        H: number of heads
    Returns:
        alibi_bias: alibi bias score_mod
    """

    def alibi_mod(score, b, h, q_idx, kv_idx):
        scale = torch.exp2(-((h + 1) * 8.0 / H))
        bias = (q_idx - kv_idx) * scale
        return score + bias

    return alibi_mod


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


class MultiHeadAttentionALIBI(nn.Module):
    """
    Computes multi-head attention with ALIBI. Supports nested tensors.

    Args:
        dim: Size of embedding dim for query, key and value
        num_head (int): Number of heads
        dropout (float, optional): Dropout probability. Default: 0.0
        bias (bool, optional): Whether to add bias to input projection. Default: True
    """

    def __init__(
        self,
        dim: int,
        num_head: int,
        dropout: float = 0.0,
        bias=True,
        device='cpu'
    ):
        super().__init__()
        self.dim = dim
        self.num_head = num_head
        self.dropout = dropout
        assert dim % num_head == 0, "Dim is not divisible by number of heads"
        self.dim_head = dim // num_head
        self.packed_proj = nn.Linear(dim, dim * 3, bias=bias, device=device)
        self.out_proj = nn.Linear(dim, dim, bias=bias, device=device)
        self.bias = bias
        self.score_mode = generate_alibi_bias(num_head)


    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        is_causal=False,
    ) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
            2. Split heads and prepare for flex attn
            3. Run flex attn
            4. Apply output projection
        Args:
            N = batch size
            query (torch.Tensor): query of shape (``N``, ``L*``, ``dim``)
            key (torch.Tensor): key of shape (``N``, ``L*``, ``dim``)
            value (torch.Tensor): value of shape (``N``, ``L*``, ``dim``)
            is_causal (bool, optional): Whether to apply causal mask. Default: False
        Returns:
            attn_output (torch.Tensor): output of shape (N, L*, dim)
        """
        N = query.size(0)
        # Step 1. Apply input projection
        if query is key and key is value:
            result = self.packed_proj(query)
            query, key, value = torch.chunk(result, 3, dim=-1)
        else:
            q_weight, k_weight, v_weight = torch.chunk(
                self.packed_proj.weight, 3, dim=0
            )
            if self.bias:
                q_bias, k_bias, v_bias = torch.chunk(
                    self.packed_proj.bias, 3, dim=0
                )
            else:
                q_bias, k_bias, v_bias = None, None, None
            query, key, value = (
                F.linear(query, q_weight, q_bias),
                F.linear(key, k_weight, k_bias),
                F.linear(value, v_weight, v_bias)
            )

        # Step 2. Split heads and prepare for flex attn
        # reshape query, key, value to separate by head
        # (N, L*, dim) -> (N, L*, num_head, dim_head) -> (N, num_head, L*, dim_head)
        query = query.unflatten(-1, [self.num_head, self.dim_head]).transpose(1, 2).detach().requires_grad_()
        key = key.unflatten(-1, [self.num_head, self.dim_head]).transpose(1, 2).detach().requires_grad_()
        value = value.unflatten(-1, [self.num_head, self.dim_head]).transpose(1, 2).detach().requires_grad_()


        block_mask = None
        if (is_causal):
          block_mask = create_nested_block_mask(causal_mask, N, self.num_head, query, key, _compile=True)
        # Step 3. Run flex attn
        # (N, num_head, L*, dim_head)
        attn_output = flex_attention(query, key, value, score_mod=self.score_mode, block_mask=block_mask)
        # (N, num_head, L*, dim_head) -> (N, L*, num_head, dim_head) -> (N, L*, dim)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

        # Step 4. Apply output projection
        # (N, L*, dim) -> (N, L*, dim)
        attn_output = self.out_proj(attn_output)

        return attn_output

In [2]:
class FFN(nn.Module):
    def __init__(
        self,
        dim,
        hidden_dim, dropout=0.1, device='cpu'
    ):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.fc1 = nn.Linear(self.dim, self.hidden_dim, bias = True, device=device)
        self.dropout1 = nn.Dropout(p=dropout)
        self.fc2 = nn.Linear(self.hidden_dim, self.dim, bias = True, device=device)
        self.dropout2 = nn.Dropout(p=dropout)

        # for potential speed up
        # Pre-normalize the weights (can help with training stability)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)


    def forward(self, input):
        return self.dropout2(self.fc2(self.dropout1(F.relu(self.fc1(input)))))

In [3]:
class EncoderLayer(nn.Module):
    def __init__(
        self, dim, hidden_dim,
        num_head, dropout=0.1,
        bias=True, device='cpu'
    ):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim


        # attention sublayer
        self.self_attention = MultiHeadAttentionALIBI(
            dim = dim,
            num_head = num_head,
            dropout = dropout,
            bias = bias,
            device=device
        )

        # FFN sublayer
        self.ffn = FFN(
            dim = dim,
            hidden_dim = hidden_dim,
            dropout = dropout,
            device=device
        )

        # Dropout layer
        self.dropout = nn.Dropout(p=dropout)

        # layer-normalization layer
        self.LayerNorm_att = nn.LayerNorm(self.dim, device=device)
        self.LayerNorm_ffn = nn.LayerNorm(self.dim, device=device)


    def forward(self, x):
        att = self.self_attention(x, x, x)
        att = self.dropout(att)
        att_normalized = self.LayerNorm_att(x + att)

        ffn_sublayer = self.ffn(att_normalized)
        ffn_normalized = self.LayerNorm_ffn(att_normalized + ffn_sublayer)

        return ffn_normalized

class Encoder(nn.Module):
    def __init__(self, dim, hidden_dim, num_head, num_layers, dropout=0.1, device='cpu'):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(dim, hidden_dim, num_head, dropout, device=device)
            for _ in range(num_layers)
        ])

    def forward(self, src):
        for layer in self.layers:
            src = layer(src)
        return src

In [4]:
class DecoderLayer(nn.Module):

    def __init__(
        self, dim, hidden_dim,
        num_head, dropout=0.1,
        bias=True, device='cpu'
    ):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim


        self.self_attention = MultiHeadAttentionALIBI(
            dim = dim,
            num_head = num_head,
            dropout = dropout,
            bias = bias,
            device=device
        )
        self.cross_attention = MultiHeadAttentionALIBI(
            dim = dim,
            num_head = num_head,
            dropout = dropout,
            bias = bias,
            device=device
        )

        # FFN sublayer
        self.ffn = FFN(
            dim = dim,
            hidden_dim = hidden_dim,
            dropout=dropout,
            device=device
        )


        # Dropout layer
        self.dropout_self_att = nn.Dropout(p=dropout)
        self.dropout_cross_att = nn.Dropout(p=dropout)

        # layer-normalization layer
        self.LayerNorm_self_att = nn.LayerNorm(self.dim, device=device)
        self.LayerNorm_cross_att = nn.LayerNorm(self.dim, device=device)
        self.LayerNorm_ffn = nn.LayerNorm(self.dim, device=device)

    def forward(self, tgt, memory):

        #calculate self attnention
        sa = self.self_attention(tgt, tgt, tgt, is_causal=True)
        sa = self.dropout_self_att(sa)
        sa_norm = self.LayerNorm_self_att(tgt + sa)

        #calculate cross attnention
        ca = self.cross_attention(sa_norm, memory, memory)
        ca = self.dropout_cross_att(ca)
        ca_norm = self.LayerNorm_cross_att(ca+sa_norm)

        #calculate ffn
        res = self.ffn(ca_norm)
        res_norm = self.LayerNorm_ffn(res+ca_norm)
        return res_norm


class Decoder(nn.Module):
    def __init__(self, dim, hidden_dim, num_head, num_layers, dropout=0.1, device='cpu'):
        super().__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(dim, hidden_dim, num_head, dropout, device=device)
            for _ in range(num_layers)
        ])

    def forward(self, tgt, memory):
        for layer in self.layers:
            tgt = layer(tgt, memory)
        return tgt


In [5]:
class EncoderDecoder(nn.Module):
    def __init__(self, dim, hidden_dim, num_head, num_layers, dropout=0.1, device='cpu'):
        super().__init__()
        self.enc = Encoder(dim, hidden_dim, num_head, num_layers, dropout, device=device)
        self.dec = Decoder(dim, hidden_dim, num_head, num_layers, dropout, device=device)
    def forward(self, src, tgt):
        memory = self.enc(src)
        output = self.dec(tgt, memory)
        return output

In [6]:
class Transformer(nn.Module):
    def __init__(self, dim, hidden_dim, num_head, num_layers, src_voc, tgt_vocab, src_emb, tgt_emb, dropout=0.1, device='cpu'):
        super().__init__()
        self.src_embed=nn.Embedding(src_voc, src_emb, device=device)
        self.tgt_embed=nn.Embedding(tgt_vocab, tgt_emb, device=device)
        self.src_proj = nn.Linear(src_emb, dim, device=device)
        self.tgt_proj = nn.Linear(tgt_emb, dim, device=device)
        self.enc = Encoder(dim, hidden_dim, num_head, num_layers, dropout, device=device)
        self.dec = Decoder(dim, hidden_dim, num_head, num_layers, dropout, device=device)
    def forward(self, src, tgt):
        src = self.src_embed(src)
        src = self.src_proj(src)
        memory = self.enc(src)
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_proj(tgt)
        output = self.dec(tgt, memory)
        return output

In [7]:
import numpy as np


def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)


# Generate a batch of semi-realistic data using Zipf distribution for sentence lengths
# in the form of nested tensors with the jagged layout.
def gen_batch(N, vocab_size, device, dtype=torch.long):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The batch items each have shape (B, S*, D)
    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    batch = torch.nested.nested_tensor(
        [torch.randint(0, vocab_size, (l,), dtype=dtype, device=device) for l in sentence_lengths],
        layout=torch.jagged, device=device
    )

    return batch, sentence_lengths

import math
import timeit


def benchmark(func, *args, **kwargs):
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin), torch.cuda.max_memory_allocated()

In [8]:
torch.cuda.empty_cache()

In [11]:

class TransformerPad(nn.Module):
    def __init__(self, dim, hidden_dim, num_head, num_layers, src_voc, tgt_vocab, src_emb, tgt_emb, dropout=0.1, device='cpu'):
        super().__init__()
        self.src_embed=nn.Embedding(src_voc, src_emb, device=device)
        self.tgt_embed=nn.Embedding(tgt_vocab, tgt_emb, device=device)
        self.src_proj = nn.Linear(src_emb, dim, device=device)
        self.tgt_proj = nn.Linear(tgt_emb, dim, device=device)
        self.trans = nn.Transformer(d_model=dim, nhead=num_head, num_encoder_layers=num_layers, num_decoder_layers=num_layers, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True, device=device)
    def forward(self, src, tgt):
        src = self.src_embed(src)
        src = self.src_proj(src)
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_proj(tgt)
        output = self.trans(src, tgt)
        return output

b_size = 64
src_vocab = 1754
src_embed = 1024
tgt_vocab = 1632
tgt_embed = 1024
dim = 512
hidden_dim = 2048
num_head = 8
num_layers = 6
dropout = 0.1
bias = True
device = "cuda"
torch.device(device)
torch.manual_seed(6)
src, src_sentence_lengths  = gen_batch(b_size, src_vocab, device)
tgt, tgt_sentence_lengths = gen_batch(b_size, tgt_vocab, device)
Ssrc = src_sentence_lengths.max().item()
Stgt = tgt_sentence_lengths.max().item()

print(
    f"Total sequence length in nested src {src_sentence_lengths.sum().item()}, max sequence length {Ssrc}"
)
print(
    f"Total sequence length in nested tg {tgt_sentence_lengths.sum().item()}, max sequence length {Stgt}"
)
padded_src, padded_tgt = (
    t.to_padded_tensor(0) for t in (src, tgt)
)

torch.manual_seed(6)
trans_njt = Transformer(
    dim, hidden_dim, num_head, num_layers, src_vocab, tgt_vocab, src_embed, tgt_embed, dropout=dropout, device=device
)
torch.manual_seed(6)
trans = TransformerPad(
    dim, hidden_dim, num_head, num_layers, src_vocab, tgt_vocab, src_embed, tgt_embed, dropout=dropout, device=device
)

# benchmark
nested_result, nested_time, nested_peak_memory = benchmark(
    trans_njt, src, tgt
)
padded_nested_result = nested_result.to_padded_tensor(0.0)

# benchmark
padded_result, padded_time, padded_peak_memory = benchmark(
    trans,
    padded_src,
    padded_tgt
)

print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB")
print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB")
print(
    f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB"
)

Total sequence length in nested src 1579, max sequence length 107
Total sequence length in nested tg 1174, max sequence length 102
padded_time=0.18868, padded_peak_memory=7.01 GB
nested_time=1.31168, nested_peak_memory=4.05 GB
Nested peak memory reduction 2.96 GB
