In [1]:
# !sudo apt install libopenmpi-dev -y
# !pip3 install mpi4py --user
# !pip3 install deepspeed==0.12.3 --user

In [2]:
# !pip3 install accelerate transformers -U --user

In [3]:
!pip3 freeze

absl-py==2.0.0
accelerate==0.25.0
aiofiles==23.2.1
aiohttp==3.8.5
aiohttp-cors==0.7.0
aiorwlock==1.3.0
aiosignal==1.3.1
altair==5.1.2
anyio==3.7.1
appdirs==1.4.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
asttokens==2.2.1
async-timeout==4.0.3
attributedict==0.3.0
attrs==23.1.0
autoawq @ https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.6/autoawq-0.1.6+cu118-cp310-cp310-linux_x86_64.whl
azure-core==1.29.5
azure-identity==1.15.0
azure-storage-blob==12.18.3
azure-storage-file-datalake==12.13.2
backcall==0.2.0
bcrypt==4.0.1
beautifulsoup4==4.12.2
bitsandbytes==0.41.0
bleach==6.0.0
blessed==1.20.0
blessings==1.7
boto3==1.28.78
botocore==1.31.78
Brotli==1.1.0
cachetools==5.3.2
causal-conv1d==1.0.0
certifi==2022.12.7
cffi==1.15.1
chardet==5.2.0
charset-normalizer==2.1.1
circuitbreaker==1.4.0
click==8.1.7
cmake==3.27.7
codecov==2.1.13
colorama==0.4.6
coloredlogs==15.0.1
colorful==0.5.5
colour-runner==0.1.1
comm==0.1.4
contourpy=

In [4]:
!nvidia-smi

Tue Dec  5 16:04:56 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000001:00:00.0 Off |                    0 |
| N/A   39C    P0    63W / 300W |  13456MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [5]:
import os
os.environ["WANDB_DISABLED"] = "true"

In [6]:
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial

from collections import namedtuple

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import PretrainedConfig, PreTrainedModel

from mamba_ssm.modules.mamba_simple import Mamba, Block
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None):
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )
        return hidden_states

class MambaLMHeadModel(PreTrainedModel, GenerationMixin):

    def __init__(
        self,
        config,
        initializer_cfg=None,
        pad_vocab_size_multiple: int = 1,
        device=None,
        dtype=None,
        **backbone_kwargs,
    ) -> None:
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(config=config)
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            initializer_cfg=initializer_cfg,
            **backbone_kwargs,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()
        # _tied_weights_keys = ['lm_head.weight']

    def tie_weights(self):
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, labels = None):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        
        loss = None
        if labels is not None:
            logits = lm_logits
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            print(loss, shift_logits, shift_logits.dtype, shift_labels, shift_labels.dtype)
            return (loss,)
            
        else:
            CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
            return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config = load_config_hf(pretrained_model_name)
        model = cls(**config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
        return model

In [7]:
# !wget https://huggingface.co/state-spaces/mamba-1.4b/raw/main/config.json -O config-1.4b.json

In [8]:
import json

with open('config-1.4b.json') as fopen:
    config = json.load(fopen)
    config['hidden_size'] = config['d_model']

In [9]:
config = PretrainedConfig(**{**config, 'vocab_size': 32000})
config

PretrainedConfig {
  "d_model": 2048,
  "fused_add_norm": true,
  "hidden_size": 2048,
  "n_layer": 48,
  "pad_vocab_size_multiple": 8,
  "residual_in_fp32": true,
  "rms_norm": true,
  "ssm_cfg": {},
  "transformers_version": "4.35.2",
  "vocab_size": 32000
}

In [10]:
model = MambaLMHeadModel(config)

In [11]:
model

MambaLMHeadModel(
  (backbone): MixerModel(
    (embedding): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-47): 48 x Block(
        (mixer): Mamba(
          (in_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
          (act): SiLU()
          (x_proj): Linear(in_features=4096, out_features=160, bias=False)
          (dt_proj): Linear(in_features=128, out_features=4096, bias=True)
          (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
        )
        (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      )
    )
    (norm_f): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)

In [12]:
from streaming import LocalDataset
import numpy as np
from streaming.base.format.mds.encodings import Encoding, _encodings

class UInt16(Encoding):
    def encode(self, obj) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes):
        return np.frombuffer(data, np.uint16)

_encodings['uint16'] = UInt16

In [13]:
# !git lfs clone https://huggingface.co/datasets/malaysia-ai/mosaic-instructions

In [14]:
class DatasetFixed(torch.utils.data.Dataset):
    def __init__(self, local):
        self.dataset = LocalDataset(local=local)

    def __getitem__(self, idx):
        print(idx)
        data = self.dataset[idx]
        data['labels'] = data['input_ids'].copy()

        data.pop('token_type_ids', None)
        for k in data.keys():
            data[k] = data[k].astype(np.int64)
        return data

    def __len__(self):
        return len(self.dataset)

train_dataset = DatasetFixed(local='mosaic-instructions')

In [15]:
from transformers import TrainingArguments, Trainer, default_data_collator

output_dir = 'test-1.4b'

deepspeed = {
    "comms_logger": {
        "enabled": True,
        "debug": True
    },
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "bf16": {
        "enabled": "auto"
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto",
        }
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": True
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": True
        },
        "overlap_comm": True,
        "contiguous_gradients": True,
        "sub_group_size": 1e8,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e8,
        "stage3_max_reuse_distance": 1e8,
        "stage3_gather_16bit_weights_on_model_save": True
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": False
}

training_args = TrainingArguments(
    output_dir,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    logging_steps=1,
    save_strategy='steps',
    save_steps=5,
    num_train_epochs=3,
    learning_rate=1e-4,
    weight_decay=0,
    warmup_steps=1000,
    bf16=True,
    fp16=False,
    gradient_checkpointing=False,
    deepspeed=deepspeed,
    save_total_limit=5,
    log_level='debug',
    max_steps=100,
    save_safetensors=True
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [16]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=default_data_collator,
)

max_steps is given, it will override any value given in num_train_epochs
Using auto half precision backend


In [17]:
model

MambaLMHeadModel(
  (backbone): MixerModel(
    (embedding): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-47): 48 x Block(
        (mixer): Mamba(
          (in_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
          (act): SiLU()
          (x_proj): Linear(in_features=4096, out_features=160, bias=False)
          (dt_proj): Linear(in_features=128, out_features=4096, bias=True)
          (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
        )
        (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      )
    )
    (norm_f): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)

In [18]:
# trainer.train()
# already saved 2 checkpoints, now want to test to load

In [None]:
trainer.train(resume_from_checkpoint = 'test-1.4b/checkpoint-5')

Currently training with a batch size of: 2


[2023-12-05 16:05:11,549] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-12-05 16:05:13,847] [INFO] [logging.py:96:log_dist] [Rank -1] DeepSpeed info: version=0.12.3, git-hash=unknown, git-branch=unknown
[2023-12-05 16:05:13,848] [INFO] [comm.py:637:init_distributed] cdb=None
[2023-12-05 16:05:13,848] [INFO] [comm.py:652:init_distributed] Not using the DeepSpeed or dist launchers, attempting to detect MPI environment...
[2023-12-05 16:05:14,243] [INFO] [comm.py:702:mpi_discovery] Discovered MPI settings of world_rank=0, local_rank=0, world_size=1, master_addr=10.208.0.238, master_port=29500
[2023-12-05 16:05:14,244] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2023-12-05 16:05:16,230] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False


Using /home/ubuntu/.cache/torch_extensions/py310_cu118 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py310_cu118/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module cpu_adam...


ninja: no work to do.
Time to load cpu_adam op: 2.2710254192352295 seconds
[2023-12-05 16:05:20,706] [INFO] [logging.py:96:log_dist] [Rank 0] Using DeepSpeed Optimizer param name adamw as basic optimizer
[2023-12-05 16:05:20,706] [INFO] [logging.py:96:log_dist] [Rank 0] Removing param_group that has no 'params' in the basic Optimizer
[2023-12-05 16:05:20,735] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Basic Optimizer = DeepSpeedCPUAdam
[2023-12-05 16:05:20,735] [INFO] [utils.py:56:is_zero_supported_optimizer] Checking ZeRO support for optimizer=DeepSpeedCPUAdam type=<class 'deepspeed.ops.adam.cpu_adam.DeepSpeedCPUAdam'>
[2023-12-05 16:05:20,735] [INFO] [logging.py:96:log_dist] [Rank 0] Creating fp16 ZeRO stage 3 optimizer, MiCS is enabled False, Hierarchical params gather False
[2023-12-05 16:05:20,736] [INFO] [logging.py:96:log_dist] [Rank 0] Creating torch.bfloat16 ZeRO stage 3 optimizer
[2023-12-05 16:05:20,856] [INFO] [utils.py:802:see_memory_usage] Stage 3 initialize begin

Attempting to resume from test-130m/checkpoint-5


[2023-12-05 16:05:29,590] [INFO] [torch_checkpoint_engine.py:27:load] [Torch] Loading checkpoint from test-130m/checkpoint-5/global_step5/zero_pp_rank_0_mp_rank_00_model_states.pt...
[2023-12-05 16:05:29,601] [INFO] [torch_checkpoint_engine.py:29:load] [Torch] Loaded checkpoint from test-130m/checkpoint-5/global_step5/zero_pp_rank_0_mp_rank_00_model_states.pt.
[2023-12-05 16:05:29,602] [INFO] [torch_checkpoint_engine.py:27:load] [Torch] Loading checkpoint from test-130m/checkpoint-5/global_step5/zero_pp_rank_0_mp_rank_00_model_states.pt...
[2023-12-05 16:05:29,611] [INFO] [torch_checkpoint_engine.py:29:load] [Torch] Loaded checkpoint from test-130m/checkpoint-5/global_step5/zero_pp_rank_0_mp_rank_00_model_states.pt.
[2023-12-05 16:05:29,621] [INFO] [torch_checkpoint_engine.py:27:load] [Torch] Loading checkpoint from test-130m/checkpoint-5/global_step5/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt...
[2023-12-05 16:05:32,035] [INFO] [torch_checkpoint_engine.py:29:load] [Torch] Loaded c

***** Running training *****
  Num examples = 385,224
  Num Epochs = 1
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 2
  Gradient Accumulation steps = 1
  Total optimization steps = 100
  Number of trainable parameters = 1,334,841,344
  Continuing training from checkpoint, will skip to saved global_step
  Continuing training from epoch 0
  Continuing training from global step 5
  Will skip the first 0 epochs then the first 5 batches in the first epoch.


104242
261743
235058
178647
Adam Optimizer #0 is created with AVX2 arithmetic capability.
Config: alpha=0.000100, betas=(0.900000, 0.999000), weight_decay=0.000000, adam_w=1
tensor(7.7500, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<NllLossBackward0>) tensor([[-1.2109,  1.5234, -1.1094,  ..., -0.7695, -1.7109,  0.7188],
        [-1.0000,  1.5156, -0.6523,  ..., -0.8203, -1.6875,  0.7656],
        [-1.1797,  0.9805, -0.6641,  ..., -0.5586, -1.0234,  0.7148],
        ...,
        [-1.1016,  1.3828, -0.5547,  ..., -0.6758, -1.5547,  1.3438],
        [-0.3125,  2.0625, -0.6211,  ..., -1.0938, -2.0312,  0.5898],
        [-0.7656,  1.6641, -0.2109,  ..., -0.8594, -1.8359,  1.2578]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<ViewBackward0>) torch.bfloat16 tensor([   77,   201,    66,  ...,  4833,   521, 23351], device='cuda:0') torch.int64


Step,Training Loss
6,7.75
7,7.5625
8,7.6875
9,8.5


225275
290107
tensor(7.5625, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<NllLossBackward0>) tensor([[-1.0391,  1.9844, -0.9297,  ..., -1.6094, -1.9062,  1.2188],
        [-1.1328,  1.6016, -0.9844,  ..., -0.8711, -1.8125,  0.9297],
        [-0.5273,  1.7031, -0.9492,  ..., -1.3047, -1.6328,  0.7617],
        ...,
        [-0.7695,  1.7656, -1.0859,  ..., -1.0234, -2.4375,  1.2656],
        [-0.4004,  1.9688, -1.2422,  ..., -0.9844, -2.1094,  1.3281],
        [-0.5586,  2.1719, -0.7969,  ..., -0.6133, -2.1094,  0.8359]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<ViewBackward0>) torch.bfloat16 tensor([  267, 23724,  1206,  ...,   650, 29570,   628], device='cuda:0') torch.int64
105315
358768
tensor(7.6875, device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<NllLossBackward0>) tensor([[-0.6523,  2.2500, -0.9883,  ..., -0.6523, -1.5625,  0.8555],
        [-0.6836,  2.0781, -1.3047,  ..., -0.7031, -2.0938,  0.8594],
        [-0.5273,  1.7031, -0.8633,  ..., -1.250

Saving model checkpoint to test-130m/checkpoint-10
Configuration saved in test-130m/checkpoint-10/config.json
Model weights saved in test-130m/checkpoint-10/pytorch_model.bin


[2023-12-05 16:06:05,475] [INFO] [logging.py:96:log_dist] [Rank 0] [Torch] Checkpoint global_step10 is about to be saved!
[2023-12-05 16:06:05,484] [INFO] [logging.py:96:log_dist] [Rank 0] Saving model checkpoint: test-130m/checkpoint-10/global_step10/zero_pp_rank_0_mp_rank_00_model_states.pt
[2023-12-05 16:06:05,484] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving test-130m/checkpoint-10/global_step10/zero_pp_rank_0_mp_rank_00_model_states.pt...
[2023-12-05 16:06:05,504] [INFO] [torch_checkpoint_engine.py:23:save] [Torch] Saved test-130m/checkpoint-10/global_step10/zero_pp_rank_0_mp_rank_00_model_states.pt.
[2023-12-05 16:06:05,506] [INFO] [torch_checkpoint_engine.py:21:save] [Torch] Saving test-130m/checkpoint-10/global_step10/bf16_zero_pp_rank_0_mp_rank_00_optim_states.pt...




In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint_path = '/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/checkpoints/checkpoint-80925'

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, ignore_mismatched_sizes=True)

question = "Do elephants help forests?"
inputs = tokenizer(question, return_tensors="pt")
outputs = model.generate(**inputs, max_length=100, do_sample=True)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Response:", response)

The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.


RuntimeError: Error(s) in loading state_dict for MambaForCausalLM:
	size mismatch for backbone.embeddings.weight: copying a param with shape torch.Size([50280, 2048]) from checkpoint, the shape in current model is torch.Size([50280, 768]).
	You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

In [3]:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(checkpoint_path)
print(config)

MambaConfig {
  "_name_or_path": "/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/checkpoints/checkpoint-80925",
  "bos_token_id": 0,
  "conv_kernel": 4,
  "eos_token_id": 0,
  "expand": 2,
  "hidden_act": "silu",
  "hidden_size": 768,
  "initializer_range": 0.1,
  "intermediate_size": 1536,
  "layer_norm_epsilon": 1e-05,
  "model_type": "mamba",
  "num_hidden_layers": 32,
  "pad_token_id": 0,
  "rescale_prenorm_residual": false,
  "residual_in_fp32": true,
  "state_size": 16,
  "time_step_floor": 0.0001,
  "time_step_init_scheme": "random",
  "time_step_max": 0.1,
  "time_step_min": 0.001,
  "time_step_rank": 48,
  "time_step_scale": 1.0,
  "transformers_version": "4.46.1",
  "use_bias": false,
  "use_cache": true,
  "use_conv_bias": true,
  "use_mambapy": false,
  "vocab_size": 50280
}



In [None]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
from transformers import logging
logging.set_verbosity_error()
# from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

checkpoint_path = '/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/checkpoints/checkpoint-3236'

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = MambaForCausalLM.from_pretrained(checkpoint_path)
input_ids = tokenizer("Do elephants help forests?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=40)
print(tokenizer.batch_decode(out))

In [None]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
from transformers import logging
logging.set_verbosity_error()
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

checkpoint_path = '/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/checkpoints/checkpoint-3236'

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = MambaLMHeadModel.from_pretrained(checkpoint_path)
input_ids = tokenizer("Do elephants help forests?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=40)
print(tokenizer.batch_decode(out))

['Do elephants help forests?\n\n? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?']


In [None]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
from transformers import logging
logging.set_verbosity_error()
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

checkpoint_path = '/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/checkpoints/checkpoint-19416'

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = MambaLMHeadModel.from_pretrained(checkpoint_path)
input_ids = tokenizer("Do elephants help forests?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=40)
print(tokenizer.batch_decode(out))

['Do elephants help forests?\n\nElephants help forests forests more more by by smaller destroying damage plants stomp , , , increases more more carbon , , ,']


In [None]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

checkpoint_path = '/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/checkpoints/checkpoint-80925'

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = MambaLMHeadModel.from_pretrained(checkpoint_path)

input_ids = tokenizer("Do elephants help forests?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=40)
print(tokenizer.batch_decode(out))

['Do elephants help forests?\n\nElephants help forests store more carbon by destroying smaller plants By Sam Wong Naturally boosting the forest’s carbon-carrying capacity IAN REDMOND / naturepl.com Elephants do a lot of damage to plants as they stomp around the jungle, but, counterintuitively, this activity increases the biomass of the forest, letting it store more carbon']


In [None]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

checkpoint_path = '/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/checkpoints/checkpoint-80925'

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = MambaLMHeadModel.from_pretrained(checkpoint_path)
input_ids = tokenizer("What is a group of crocodiles called?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=40)
print(tokenizer.batch_decode(out))

['What is a group of crocodiles called?\n\nA What eats What is a group of crocodiles called? November 3, 2020 October 30, 2020 by Mustapha Bunu A group of crocodiles is a bask or float. Bask is commonly used for groups found on land whereas float is the popular term for groups found in water. Crocodiles are very social animals, in fact, they are the most social of all reptilian species and can be found congregated in huge numbers when basking or feeding. Groups of crocodiles are called a bask because thatâs literally what they do, they bask in shorelines or on top of trees, under the sun. In case youâre wondering what basking means, itâs to lie exposed to warmth and light, typically from the sun, for relaxation and pleasure â According to Dictionary.com.']


In [None]:
model


MambaLMHeadModel(
  (backbone): MixerModel(
    (embedding): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-47): 48 x Block(
        (mixer): Mamba(
          (in_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
          (act): SiLU()
          (x_proj): Linear(in_features=4096, out_features=160, bias=False)
          (dt_proj): Linear(in_features=128, out_features=4096, bias=True)
          (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
        )
        (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      )
    )
    (norm_f): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)



In [None]:
model = MambaForCausalLM.from_pretrained(checkpoint_path)

In [48]:
model

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(50280, 2048)
    (layers): ModuleList(
      (0-47): 48 x MambaBlock(
        (norm): MambaRMSNorm(2048, eps=1e-05)
        (mixer): MambaMixer(
          (conv1d): Conv1d(4096, 4096, kernel_size=(4,), stride=(1,), padding=(3,), groups=4096)
          (act): SiLU()
          (in_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (x_proj): Linear(in_features=4096, out_features=160, bias=False)
          (dt_proj): Linear(in_features=128, out_features=4096, bias=True)
          (out_proj): Linear(in_features=4096, out_features=2048, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm(2048, eps=1e-05)
  )
  (lm_head): Linear(in_features=2048, out_features=50280, bias=False)
)

In [34]:
vocab_size = 32000
embedding_dim = 2048
embedding_params = vocab_size * embedding_dim

in_proj_input_dim = 2048
in_proj_output_dim = 8192
in_proj_params = in_proj_input_dim * in_proj_output_dim

conv1d_out_channels = 4096
conv1d_kernel_size = 4
conv1d_params = conv1d_out_channels * conv1d_kernel_size

x_proj_input_dim = 4096
x_proj_output_dim = 160
x_proj_params = x_proj_input_dim * x_proj_output_dim

dt_proj_input_dim = 128
dt_proj_output_dim = 4096
dt_proj_params = dt_proj_input_dim * dt_proj_output_dim + dt_proj_output_dim  

out_proj_input_dim = 4096
out_proj_output_dim = 2048
out_proj_params = out_proj_input_dim * out_proj_output_dim

layernorm_dim = 2048
norm_params = layernorm_dim * 2  

block_params = in_proj_params + conv1d_params + x_proj_params + dt_proj_params + out_proj_params + norm_params

num_blocks = 48
total_block_params = block_params * num_blocks

final_layernorm_params = layernorm_dim * 2

lm_head_output_dim = vocab_size 
lm_head_params = embedding_dim * lm_head_output_dim

total_params = embedding_params + total_block_params + final_layernorm_params + lm_head_params

total_params

1396838400

In [5]:
from transformers import MambaConfig

config_path = "/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/checkpoints/checkpoint-3236/config.json"

config = MambaConfig.from_pretrained(config_path)
config

MambaConfig {
  "architectures": [
    "MambaForCausalLM"
  ],
  "bos_token_id": 0,
  "conv_kernel": 4,
  "d_model": 2048,
  "eos_token_id": 0,
  "expand": 2,
  "fused_add_norm": true,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.1,
  "intermediate_size": 4096,
  "layer_norm_epsilon": 1e-05,
  "model_type": "mamba",
  "n_layer": 48,
  "num_hidden_layers": 48,
  "pad_token_id": 0,
  "pad_vocab_size_multiple": 8,
  "rescale_prenorm_residual": false,
  "residual_in_fp32": true,
  "rms_norm": true,
  "ssm_cfg": {},
  "state_size": 16,
  "time_step_floor": 0.0001,
  "time_step_init_scheme": "random",
  "time_step_max": 0.1,
  "time_step_min": 0.001,
  "time_step_rank": 128,
  "time_step_scale": 1.0,
  "torch_dtype": "float32",
  "transformers_version": "4.46.1",
  "use_bias": false,
  "use_cache": true,
  "use_conv_bias": true,
  "use_mambapy": false,
  "vocab_size": 50280
}

In [None]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
from transformers import logging
logging.set_verbosity_error()
# from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

checkpoint_path = '/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/retrain_checkpoints/checkpoint-62046'

from transformers import MambaConfig

config_path = "/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/checkpoints/checkpoint-3236/config.json"

config = MambaConfig.from_pretrained(config_path)

tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = MambaForCausalLM.from_pretrained(checkpoint_path, config=config)
input_ids = tokenizer("Do elephants help forests?", return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=40)
print(tokenizer.batch_decode(out))