In [1]:
# autoreload

%load_ext autoreload
%autoreload 2

In [1]:
2

2

In [2]:
from mamba_ssm import Mamba2
import torch


x = torch.randn(1, 4096, 512).cuda()

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=512,  # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
    layer_idx=0,
).to("cuda")
y = model(x)
assert y.shape == x.shape



In [2]:
import torch
import torch.nn.functional as F

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.utils.generation import InferenceParams



from einops import rearrange, repeat

seqlens = [170, 65, 100]
genlen = 20
total_seqlen = sum(seqlens)
device = "cuda"
dtype = torch.float16

config = MambaConfig(
    d_model=1024,
    n_layer=4,
    vocab_size=32000,
    ssm_cfg=dict(layer="Mamba2"),
    rms_norm=True,
    residual_in_fp32=True,
    fused_add_norm=True,
    pad_vocab_size_multiple=16,
)
torch.manual_seed(2357)
model = MambaLMHeadModel(config, device=device, dtype=dtype)
xs = [torch.randint(0, 1000, (1, seqlen), device=device, dtype=torch.long) for seqlen in seqlens]

# Reference 1: Forward pass with seq_idx
x = torch.cat(xs, dim=1)
seq_idx = torch.cat([torch.full((ids.shape[1],), i, dtype=torch.int32, device=device)
                        for i, ids in enumerate(xs)], dim=0).unsqueeze(0)
cu_seqlens = F.pad(torch.tensor(seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
out_ref = model(x, seq_idx=seq_idx).logits
# Only take the last @genlen logits of each sequence
out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1]
                        for i in range(len(seqlens))], dim=0)
"""
# Reference 2: Generate the last @genlen tokens of each sequence in a for loop
out_loop = []
for input_ids in xs:
    out = model.generate(
        input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True,
        return_dict_in_generate=True, cg=True, teacher_outputs=input_ids,
    ).scores
    out_loop.append(torch.stack(out, dim=1))
out_loop = torch.cat(out_loop, dim=0)
print(f"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}")"""

'\n# Reference 2: Generate the last @genlen tokens of each sequence in a for loop\nout_loop = []\nfor input_ids in xs:\n    out = model.generate(\n        input_ids=input_ids[:, :-genlen], max_length=input_ids.shape[1], output_scores=True,\n        return_dict_in_generate=True, cg=True, teacher_outputs=input_ids,\n    ).scores\n    out_loop.append(torch.stack(out, dim=1))\nout_loop = torch.cat(out_loop, dim=0)\nprint(f"Max diff between ref1 and ref2: {(out_loop - out_ref).abs().max()}")'

In [4]:
out_ref

tensor([[[-0.2952,  1.2607, -0.2241,  ...,  0.5239, -1.0312,  0.4465],
         [-1.6006,  0.3271, -0.1626,  ...,  0.7397,  0.6523, -1.1016],
         [-0.3184, -0.0527,  0.2690,  ..., -0.1461, -0.0345, -0.3540],
         ...,
         [-0.9595, -0.3396,  0.9458,  ...,  0.0075, -0.5146, -0.8008],
         [ 0.6274,  0.6328, -0.2942,  ...,  0.2147, -1.2061,  0.4158],
         [ 0.2052, -0.8013, -0.1104,  ...,  0.2040, -0.0773, -0.1528]],

        [[ 0.5171,  0.8818,  0.0328,  ...,  0.3943, -0.6523, -0.5820],
         [ 0.2727, -0.5337, -0.0599,  ..., -0.1683,  1.1680,  0.4231],
         [ 0.9858,  1.4727,  0.0798,  ...,  0.3040, -0.3718,  0.2317],
         ...,
         [ 1.2686, -1.0332, -1.0049,  ..., -1.0918,  0.3596, -0.0473],
         [-0.6675, -0.6196,  0.0687,  ..., -1.1299, -0.7573, -0.2954],
         [-0.3403,  0.0311,  2.2285,  ...,  0.3616,  0.4629, -0.3442]],

        [[-0.9526,  0.0261,  0.4924,  ...,  0.8110, -0.7744, -0.0202],
         [-0.5474,  0.5635,  0.6396,  ..., -0

In [50]:
x.shape, seq_idx.shape

(torch.Size([1, 335]), torch.Size([1, 335]))

In [45]:
x.shape, seq_idx.shape

(torch.Size([1, 335]), torch.Size([1, 335]))

In [20]:
input_ids = torch.cat([ids[:, :-genlen] for ids in xs], dim=1)
prompt_seqlens = [seqlen - genlen for seqlen in seqlens]
cu_seqlens = F.pad(torch.tensor(prompt_seqlens, device=device, dtype=torch.int32).cumsum(dim=0), (1, 0))
seq_idx = torch.cat([torch.full((seqlen,), i, dtype=torch.int32, device=device)
                      for i, seqlen in enumerate(prompt_seqlens)], dim=0).unsqueeze(0)
inference_params = InferenceParams(max_seqlen=2048, max_batch_size=len(seqlens))

scores, sequences = [], []
# Both seq_idx and cu_seqlens must be passed in for varlen generation
logits = model(input_ids, inference_params=inference_params, seq_idx=seq_idx, cu_seqlens=cu_seqlens).logits

print(logits.shape)


torch.Size([1, 275, 32000])


In [21]:
rearrange(logits[0, cu_seqlens[1:] - 1], "b d -> b 1 d").shape


torch.Size([3, 1, 32000])

In [None]:
logits = rearrange(logits[0, cu_seqlens[1:] - 1], "b d -> b 1 d")
scores.append(logits)
# In practice we should sample. In this case we take from the teacher_output for testing
sampled_tokens = rearrange(
    torch.stack([ids[0, -genlen] for ids in xs], dim=0), "b -> b 1"
)
# sequences.append(sampled_tokens)
# for i in range(1, genlen):
#     inference_params.seqlen_offset += 1
#     logits = model(
#         sampled_tokens, inference_params=inference_params, num_last_tokens=1
#     ).logits
#     scores.append(logits)
#     # In practice we should sample. In this case we take from the teacher_output for testing

#     print(logits.shape)
#     sampled_tokens = rearrange(
#         torch.stack([ids[0, -genlen + i] for ids in xs], dim=0), "b -> b 1"
#     )
#     sequences.append(sampled_tokens)
# out_varlen = torch.cat(scores, dim=1)

# out_varlen.shape

In [1]:
from mt.ds import build_dataset

ds = build_dataset("iwslt17", "de", "en", is_encoder_decoder=False)
tokenizer = ds.get_tokenizer()
tokenizer.padding_side = "right"
tdl, vdl,_ = ds.get_dataloaders(train_batch_size=2, val_batch_size=2, tokenizer=tokenizer)
tdl

  warn(


<torch.utils.data.dataloader.DataLoader at 0x7fbb4578b8e0>

In [27]:
def pack_2d(tokens, cu_seqlens):
    """
    pack function: convert tokens to packed_tokens (batch_size=1)

    Args:
    tokens (torch.Tensor): Input tensor of shape (batch_size, max_seq_len)
    cu_seqlens (torch.Tensor): Cumulative sequence lengths tensor

    Returns:
    torch.Tensor: Packed tokens of shape (total_tokens,)
    """
    batch_size, max_seq_len = tokens.shape
    seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1]

    # Create a mask for valid tokens
    indices_2d = (
        torch.arange(max_seq_len, device=tokens.device)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    mask_2d = indices_2d < seq_len_list.unsqueeze(1)

    print(mask_2d)
    # Apply the mask and flatten the result
    packed_tokens = tokens[mask_2d]

    return packed_tokens

In [3]:
def unpack(packed_hidden_states, cu_seqlens):
    batch_size = cu_seqlens.shape[0] - 1
    seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

    packed_hidden_states = packed_hidden_states.squeeze(0)

    ori_indices = (
        torch.arange(seq_len, device=cu_seqlens.device)
        .unsqueeze(0)
        .expand((batch_size, seq_len))
    )

    ori_indices = (ori_indices + cu_seqlens[:-1].unsqueeze(1)) % (
        len(packed_hidden_states)
    )

    return packed_hidden_states[ori_indices]


In [80]:
for i,b in enumerate(tdl):
    batch = {k: v.to(device) for k,v in b.items()}

    ids, labels = (
        batch["input_ids"][:, :-1].contiguous(),
        batch["input_ids"][:, 1:].contiguous(),
    )
    attention_mask = batch["attention_mask"][:, :-1]
    batch_size = attention_mask.shape[0]

    seqlens = attention_mask.sum(dim=1, dtype=torch.int32)
    seq_idx = torch.cat(
        [
            torch.full((seqlen,), i, dtype=torch.int32, device=ids.device)
            for i, seqlen in enumerate(seqlens)
        ],
        dim=0,
    ).unsqueeze(0)
    cu_seqlens = torch.zeros(
        batch_size + 1, dtype=torch.int32, device=attention_mask.device
    )
    cu_seqlens[1:] = seqlens.cumsum(0)

    packed_ids = ids[attention_mask.bool()].unsqueeze(0)
    print(cu_seqlens)
    print(packed_ids.shape)
    print(seq_idx)
    # break
    lm_logits = model.forward(
        input_ids=packed_ids,
        seq_idx=seq_idx,
    ).logits


    print(lm_logits.shape)
    unpacked = unpack(lm_logits, cu_seqlens)
    print(unpacked.shape)

    sep_mask = (ids == tokenizer.sep_token_id).cumsum(dim=1) > 0
    labels[~sep_mask] = tokenizer.pad_token_id

    loss = F.cross_entropy(
        unpacked.view(-1, lm_logits.size(-1)),
        labels.view(-1),
        ignore_index=tokenizer.pad_token_id,
    )
    print(loss)

    break

tensor([  0,  88, 145], device='cuda:0', dtype=torch.int32)
torch.Size([1, 145])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1]], device='cuda:0', dtype=torch.int32)
torch.Size([1, 145, 32000])
torch.Size([2, 88, 32000])
tensor(10.6094, device='cuda:0', dtype=torch.float16,
       grad_fn=<NllLossBackward0>)


In [50]:
ids.view(-1).shape

torch.Size([84])

In [39]:
batch_size, max_seq_len = ids.shape
seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1]

# Create a mask for valid tokens
indices_2d = (
    torch.arange(max_seq_len, device=ids.device)
    .unsqueeze(0)
    .expand(batch_size, -1)
)
mask_2d = indices_2d < seq_len_list.unsqueeze(1)


tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True]], device='cuda:0')

In [66]:
# mask_2d.sum(dim=1, dtype=torch.int32), attention_mask.sum(dim=1, dtype=torch.int32)


attention_mask = attention_mask.contiguous()


torch.Size([2, 42, 42])

In [74]:
ids[attention_mask.bool()]

torch.Size([57])

In [69]:
ids.view(-1)

tensor([  565,   360,   440,   457,   741,  1130, 21312,    15,   304,  1894,
          558,  2373,   411,  6463, 22246,    18,    27,    22,  9801,   360,
           17,     2,   820,   354,   265,  2636, 26607,  9128,    15,  9128,
          336,  5039,  3693,   280, 22246, 29471,    27,    22,  2373,   411,
         6463,    17,  2981,   348,    16, 15060,  1540,  4914, 14023,    17,
            2,  1979,  7783,   979, 30988,    17,     1,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0], device='cuda:0')

In [71]:
ids.view(-1)[attention_mask.bool()]

IndexError: too many indices for tensor of dimension 1

In [56]:
cu_seqlens

tensor([ 0, 42, 57], device='cuda:0', dtype=torch.int32)

In [4]:
for i,b in enumerate(vdl):
    batch = {k: v.to(device) for k,v in b.items()}
    input_ids, labels = batch["input_ids"], batch["labels"]

    batch_size, seq_len = input_ids.shape
    max_length = 20
    attention_mask = batch["attention_mask"]

    done = torch.tensor([False] * batch_size).to(input_ids.device)

    inference_params = InferenceParams(
            max_seqlen=max_length + seq_len,
            max_batch_size=batch_size,)
    seqlens = attention_mask.sum(dim=1, dtype=torch.int32)
    seq_idx = torch.cat(
        [
            torch.full((seqlen,), i, dtype=torch.int32, device=input_ids.device)
            for i, seqlen in enumerate(seqlens)
        ],
        dim=0,
    ).unsqueeze(0)
    cu_seqlens = torch.zeros(
        batch_size + 1, dtype=torch.int32, device=attention_mask.device
    )
    cu_seqlens[1:] = seqlens.cumsum(0)

    packed_ids = input_ids[attention_mask.bool()].unsqueeze(0)
    lm_logits = model.forward(
        input_ids=packed_ids,
        seq_idx=seq_idx,
        cu_seqlens=cu_seqlens,
        inference_params=inference_params,
    ).logits
    unpacked = unpack(lm_logits, cu_seqlens)


    print(unpacked.shape, cu_seqlens)
    # next_tokens = torch.argmax(unpacked[:, cu_seqlens[1:]-1, :], dim=-1, keepdim=True)

    break

    for i in range(1, max_length):
        out = model.forward(
            input_ids=next_tokens,
            inference_params=inference_params,
        )
        next_tokens = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
        input_ids = torch.cat((input_ids, next_tokens), dim=-1)
        is_eos = next_tokens == tokenizer.eos_token_id
        done = done | is_eos.squeeze(-1)
        if done.all():
            break

    print(input_ids.shape)
    break

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


torch.Size([2, 47, 32000]) tensor([ 0, 47, 73], device='cuda:0', dtype=torch.int32)


In [12]:
lm_logits[:,cu_seqlens[1:]-1].view(batch_size,1, -1).shape

torch.Size([2, 1, 32000])

In [14]:
out_ref = torch.cat([out_ref[:, cu_seqlens[i + 1] - genlen - 1:cu_seqlens[i + 1] - 1]
                      for i in range(len(seqlens))], dim=0)

In [15]:
out_ref.shape

torch.Size([3, 20, 1024])

In [53]:
def unpack(packed_hidden_states, cu_seqlens):
    batch_size = cu_seqlens.shape[0] - 1
    seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

    packed_hidden_states = packed_hidden_states.squeeze(0)

    ori_indices = (
        torch.arange(seq_len, device=cu_seqlens.device)
        .unsqueeze(0)
        .expand((batch_size, seq_len))
    )

    ori_indices = (ori_indices + cu_seqlens[:-1].unsqueeze(1)) % (
        len(packed_hidden_states)
    )

    return packed_hidden_states[ori_indices]


"""
pack function: convert hidden_states to packed_hidden_states (batch_size=1)
"""


def pack(hidden_states, cu_seqlens):
    batch_size, seq_len, hidden_dim = hidden_states.shape
    seq_len_list = cu_seqlens[1:] - cu_seqlens[:-1]
    seq_len_list_3d = seq_len_list.unsqueeze(1).unsqueeze(2)
    indices_3d = (
        torch.arange(seq_len, device=hidden_states.device)
        .unsqueeze(0)
        .unsqueeze(2)
        .repeat(batch_size, 1, hidden_dim)
    )
    mask_3d = indices_3d < seq_len_list_3d
    packed_hidden_states = hidden_states[mask_3d].view(-1, hidden_dim)
    return packed_hidden_states

In [54]:
lens = [5, 8, 13]
max_seq_len = max(lens)
seqs = [torch.randn(lens[i], 512) for i in range(3)]
seqs = torch.cat(seqs, dim=0).cuda()
lens.insert(0, 0)
cu_seqlens = torch.cumsum(torch.tensor(lens), dim=0).cuda()

In [55]:
cu_seqlens, lens

(tensor([ 0,  5, 13, 26], device='cuda:0'), [0, 5, 8, 13])

In [56]:
seqs.shape

torch.Size([26, 512])

In [59]:
inf_params.key_value_memory_dict[0][0].shape

torch.Size([3, 1152, 4])

In [58]:

from mamba_ssm.utils.generation import InferenceParams

inf_params = InferenceParams(max_seqlen=512, max_batch_size=1)
out = model.forward(
    seqs.unsqueeze(0),
    # seqlen=max_seq_len,
    cu_seqlens=cu_seqlens,
    inference_params=inf_params,
)
out.shape

torch.Size([1, 26, 512])

In [50]:
inf_params.key_value_memory_dict[0][0].shape

torch.Size([4, 1152, 4])

In [104]:
unpack(out, cu_seqlens).shape

torch.Size([4, 8, 512])

In [65]:
out = model.forward(u=seqs, seqlen=13, cu_seqlens=cu_seqlens) 
out.shape

torch.Size([26, 512])

In [3]:
# autoreload

%load_ext autoreload
%autoreload 2

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.utils.generation import InferenceParams

import torch

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:


vocab_size = 32000
d_model = 512
n_layer = 12
rms_norm = True
fused_add_norm = True
use_fast_path = False
dropout = 0.1
device = None


cfg = MambaConfig(
    vocab_size=vocab_size,
    d_model=d_model,
    n_layer=n_layer,
    rms_norm=rms_norm,
    fused_add_norm=fused_add_norm,
    use_fast_path=use_fast_path,
    ssm_cfg={"layer": "Mamba2", 
             
             "dropout": dropout
             },
)

model = MambaLMHeadModel(
    device=device,
    config=cfg,
).cuda()

print(f"nr of params; {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

nr of params; 37025856


In [9]:
model.backbone.embedding.weight.device

device(type='cpu')

In [18]:
lens = [0, 16, 16, 16, 20]
torch.cumsum(torch.tensor(lens), dim=0)

tensor([ 0, 16, 32, 48, 68])

In [4]:
# batch = torch.randn(4, 32, 512).cuda()
batch = torch.randint(0, 32000, (4, 20)).cuda()
lens = [0, 16, 16, 16, 20]
cu_seqlens = torch.cumsum(torch.tensor(lens), dim=0).cuda()
max_seqlen = max(lens)

from mamba_ssm.utils.generation import InferenceParams
inf_params = InferenceParams(max_seqlen=512, max_batch_size=1)
out = model.forward(
    input_ids=batch,
    # seqlen=max_seqlen,
    cu_seqlens=cu_seqlens,
    inference_params=inf_params,
)
out

CausalLMOutput(logits=tensor([[[-0.4535, -0.1903,  0.1940,  ..., -0.2765,  0.3405,  0.3049],
         [-0.1870,  0.6263, -0.1692,  ...,  0.1917, -0.3314, -0.7155],
         [ 0.2659,  0.2048, -0.0832,  ..., -0.2611, -0.6467, -0.5289],
         ...,
         [-0.7333,  0.0273,  0.1092,  ..., -0.2086,  0.0383,  0.5336],
         [ 0.3112, -0.3334,  0.1502,  ..., -0.3022,  0.1446,  0.5136],
         [ 0.2224,  0.5957,  0.2837,  ..., -1.2423,  0.3974,  0.1422]],

        [[-0.4915,  0.5788, -0.0796,  ..., -0.1596,  0.0389,  0.2409],
         [-0.7328,  0.0601,  0.0737,  ..., -0.2258,  0.0514,  0.4765],
         [ 0.3146, -0.3397,  0.1267,  ..., -0.3161,  0.1515,  0.5114],
         ...,
         [-0.2585, -0.1724, -0.2640,  ..., -0.3327, -0.0336,  0.0461],
         [-0.7257,  0.0241,  0.7068,  ..., -0.5366,  1.1969, -0.5112],
         [ 0.0372,  0.0156, -0.3576,  ..., -0.4546, -0.2252,  0.7202]],

        [[-0.5095,  0.2670,  0.2471,  ..., -0.9624,  0.4414,  0.2995],
         [-0.2678, -0.1

In [63]:
out.logits.shape

torch.Size([4, 20, 32000])

In [15]:
seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
batch_size = len(seq_lengths)
batch_indices = torch.arange(batch_size, device=cu_seqlens.device)
last_token_logits = out.logits[batch_indices, seq_lengths - 1]
# last_token_logits.shape
next_tokens = torch.argmax(last_token_logits, dim=-1, keepdim=True)

In [39]:
last_token_logits.shape

torch.Size([4, 32000])

In [9]:
batch.shape

torch.Size([4, 20])

In [10]:
batch[batch_indices, seq_lengths-1]

tensor([ 5948, 29903,  4018, 29128], device='cuda:0')

In [20]:
torch.cat((batch, next_tokens), dim=1)    

tensor([[ 1443, 27456, 20415,  3087,  6465,  8025, 25972,  2599,   551, 27948,
         17773, 20349, 27143,   428,  9126,  5948,  3558,  8070, 18483, 17954,
         24459],
        [20300, 10746, 28850, 23706, 12792, 10966,  1608, 19539,  1391, 30376,
          2586,  4297, 31050, 17172, 16162, 29903, 11419, 16579,  9919, 13000,
          1777],
        [16963, 27322, 29404, 31576,  3638, 14424,  4787, 10924,   956,  4131,
          6685, 14701,  9244, 22546, 22311,  4018,  8396,  2757, 28358, 28719,
         28254],
        [23937, 29644,  3785, 29132,  9224, 24593, 17456,   906, 21355, 14527,
          8706, 16524,  7036,    20, 11992,  7826, 23393, 17528,  9246, 29128,
         18180]], device='cuda:0')

In [31]:
inf_params.seqlen_offset +=1
reout = model.forward(input_ids=next_tokens, inference_params=inf_params) 
reout

CausalLMOutput(logits=tensor([[[-0.1875, -1.3006,  0.3607,  ...,  0.4449, -0.9811,  1.2702]],

        [[ 0.3344,  0.2296, -0.4904,  ...,  0.1901,  0.4278, -0.2156]],

        [[ 0.7879, -0.0104,  0.2826,  ..., -0.3183, -0.3512,  0.2976]],

        [[-1.1600,  0.0934, -0.2397,  ..., -0.6869, -0.2997,  0.2996]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>))

In [32]:
reout.logits.shape

torch.Size([4, 1, 32000])

In [1]:
# autoreload

%load_ext autoreload
%autoreload 2
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.utils.generation import InferenceParams

import torch



In [2]:
from mt.ds import build_dataset

ds = build_dataset(
    name="iwslt17",
    source="de",
    target="en",
    is_encoder_decoder=False,
)
tokenizer = ds.get_tokenizer()
tokenizer.padding_side = "right"
tdl, vdl, _ = ds.get_dataloaders(
    tokenizer=tokenizer, train_batch_size=2, val_batch_size=2
)

  warn(


In [3]:
from models.factory import build_model
import torch
model_cls = build_model(
    task="mt",
    name="mamba2",
)
model_cls

Please install tensorboardX: pip install tensorboardX


models.mamba2.mt_wrapper.Mamba2MT

In [4]:
mt = model_cls(dropout=0, tokenizer=tokenizer, vocab_size=32000, precision='bf16-mixed', **model_cls.configs["default"]) 
mt = mt.to("cuda").to(torch.bfloat16)
print("nr of params: ", sum(p.numel() for p in mt.parameters() if p.requires_grad))
device = "cuda"

nr of params:  78308544


In [15]:
for i,b in enumerate(tdl):
    b = {k: v.to(device) for k,v in b.items()}
    out = mt.training_step(b, i)  
    break

/home/hugo/.pyenv/versions/3.10.4/envs/ctxeff/lib/python3.10/site-packages/pytorch_lightning/core/module.py:436: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


In [12]:
b['input_ids']

tensor([[  569,   770,   440,  2427, 25064,    15,   675,   683,   621,  9037,
          7452,    15,   304,   356,   427,  4119,   528,    17,     2,  1150,
           372,   980,   342,   480,   295,  3570,    15,   618,   570,   372,
           265,  1877,  1017,   336,   342,   447,   538,   480,   295,  3570,
            17,     1],
        [  500, 14460,   706, 19765,    17,     2,   296,  1523,   307, 19765,
            17,     1,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0]], device='cuda:0')

In [7]:
for i,b in enumerate(vdl):
    b = {k: v.to(device) for k, v in b.items()}
    print(b["input_ids"].shape)
    out = mt.validation_step(b, i)
    break

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


torch.Size([2, 47])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272


/home/hugo/.pyenv/versions/3.10.4/envs/ctxeff/lib/python3.10/site-packages/pytorch_lightning/core/module.py:436: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


In [24]:
b['input_ids']

tensor([[15548,   777,   830,   392,   594,  3407, 22629,  4699,    15,   548,
           355, 13407, 10964,   873,    15,   434,   304,   612,   488,   863,
           348,  2358, 20699,    15,   304,   531, 23442,  1510,  1908,  1261,
           304, 24216,   349, 11637, 13846,  4035,  1033,    15,   548,  2369,
          1904,  1283,  2789,  7134,   360,    17,     2],
        [  696,   525, 19376,   427,  4031,  2725,   304, 14380, 22632,  1050,
         12002, 13093,   541,    15,   587,   396,   427,   304,   337,  9032,
           558,  4501,   272,  2746,    17,     2,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]], device='cuda:0')

In [12]:
for i,b in enumerate(vdl):

    b = {k: v.to(device) for k, v in b.items()}

    input_ids, labels = b["input_ids"], b["labels"]
    batch_size, seq_len = input_ids.shape
    max_length = 512
    attention_mask = b["attention_mask"] if mt.use_padding else None

    cu_seqlens = torch.zeros(
        batch_size + 1, dtype=torch.int32, device=attention_mask.device
    )
    cu_seqlens[1:] = attention_mask.sum(dim=1, dtype=torch.int32).cumsum(0)

    cache = mt.model.allocate_inference_cache(
        batch_size=batch_size,
        max_seqlen=max_length + seq_len,
        dtype=mt.precision,
    )
    inference_params = InferenceParams(
        max_seqlen=max_length + seq_len,
        max_batch_size=batch_size,
        key_value_memory_dict=cache,
    )
    done = torch.tensor([False] * batch_size).to(input_ids.device)

    for idx in range(max_length):
        print(idx)
        if idx > 0:
            last_tokens = input_ids[:, -1:]  # (B, 1)

        outputs = mt.model.forward(
            input_ids=input_ids if idx == 0 else last_tokens,
            cu_seqlens=cu_seqlens if idx == 0 else None,
            inference_params=inference_params,
        ).logits


        if idx == 0:
            seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            batch_indices = torch.arange(outputs.shape[0], device=outputs.device)
            next_token_logits = outputs[batch_indices, seq_lengths - 1]
        else:
            next_token_logits = outputs.squeeze(1)

        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        input_ids = torch.cat((input_ids, next_token), dim=-1)
        inference_params.seqlen_offset += 1

        is_eos = next_token == mt.tokenizer.eos_token_id
        done = done | is_eos.squeeze(-1)
        if done.all():
            break

    # Create a cumulative sum mask where positions after EOS become True
    eos_token_id = mt.tokenizer.eos_token_id
    eos_mask = (input_ids == eos_token_id).cumsum(dim=1) > 0
    input_ids[eos_mask] = mt.tokenizer.pad_token_id

    # mask source sentence
    source_mask = (input_ids == mt.tokenizer.sep_token_id).cumsum(dim=1) == 0
    input_ids[source_mask] = mt.tokenizer.pad_token_id

    tpreds = mt.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    tlabels = mt.tokenizer.batch_decode(labels, skip_special_tokens=True)
    bleu_score = mt.bleu.compute(predictions=tpreds, references=tlabels)["score"]

    break

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
27

In [11]:
next_token_logits

tensor([[ 0.3711,  0.2197, -0.4512,  ..., -0.6328, -0.5195,  0.1167],
        [ 0.1260, -0.2188, -0.3457,  ...,  0.1221, -0.2266, -0.2451]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<SqueezeBackward1>)

In [11]:
seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
outputs.shape, seq_lengths

"""(torch.Size([2, 47, 32000]),
 tensor([47, 26], device='cuda:0', dtype=torch.int32))"""

(torch.Size([2, 47, 32000]),
 tensor([47, 26], device='cuda:0', dtype=torch.int32))

In [20]:
outputs[:, seq_lengths - 1].shape

"""torch.Size([2, 2, 32000])
"""

'torch.Size([2, 2, 32000])\n'

In [21]:
batch_indices = torch.arange(outputs.shape[0], device=outputs.device)
last_token_indices = seq_lengths - 1
result = outputs[batch_indices, last_token_indices]

In [25]:
result.shape

torch.Size([2, 32000])

In [28]:
result.argmax(dim=-1, keepdim=True).shape

torch.Size([2, 1])