## (1) Load Model

In [None]:
from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# Checkpoint to load. Available options:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = "state-spaces/mamba-370m"

# The tokenizer is loaded from the GPT-NeoX-20B repository, not from the
# Mamba checkpoint selected above.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = Mamba.from_pretrained(pretrained_model_name)

In [None]:
def print_trainable_parameters(model):
    """
    Print how many of the model's parameters are trainable.

    Walks ``model.named_parameters()``, totals the element count of every
    parameter and, separately, of the parameters whose ``requires_grad`` flag
    is set, then prints both totals together with the trainable percentage.

    The report is written to stdout; the function itself returns ``None``.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``numel()`` and ``requires_grad``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # A model without any parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )


# print_trainable_parameters() prints its own report and returns None, so
# nesting it inside print() emitted the report followed by "plain None".
# Print the label first, then let the helper print the numbers.
print("plain")
print_trainable_parameters(model)

## (2) Generate Text

In [None]:
import torch
import torch.nn.functional as F


def generate(
    model,
    tokenizer,
    prompt: str,
    n_tokens_to_gen: int = 50,
    sample: bool = True,
    top_k: int = 40,
):
    """
    Extend ``prompt`` one token at a time and return the decoded text.

    On every step the whole sequence generated so far is passed through
    ``model``; the output at the last position is turned into a probability
    distribution with softmax, optionally restricted to the ``top_k`` most
    likely tokens, and the next token is either sampled from it or taken as
    its argmax.

    Args:
        model: Callable module; ``model(input_ids)`` must return a tensor
            indexable as ``[:, -1]`` to give per-token scores for the next
            position. It is switched to eval mode.
        tokenizer: Called as ``tokenizer(prompt, return_tensors="pt")`` and
            expected to expose ``.input_ids`` on the result, plus a
            ``decode(list_of_ids)`` method.
        prompt: Text to continue.
        n_tokens_to_gen: Number of tokens to append.
        sample: If true, draw the next token with ``torch.multinomial``;
            otherwise pick the most probable token.
        top_k: Keep only the ``top_k`` most probable tokens before choosing.
            ``None`` disables the filter. Values larger than the vocabulary
            are clamped to the vocabulary size.

    Returns:
        The decoded first sequence of the batch (prompt plus generated tokens).
    """
    model.eval()

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # Put the prompt on the same device as the model's weights so a model that
    # was moved to an accelerator does not fail with a device mismatch.
    if hasattr(model, "parameters"):
        first_param = next(iter(model.parameters()), None)
        if first_param is not None:
            input_ids = input_ids.to(first_param.device)

    for _ in range(n_tokens_to_gen):
        with torch.no_grad():
            next_token_logits = model(input_ids)[:, -1]

        probs = F.softmax(next_token_logits, dim=-1)

        if top_k is not None:
            # torch.topk raises when k is larger than the last dimension.
            k = min(top_k, probs.shape[-1])
            values, _ = torch.topk(probs, k=k)
            # Zero everything below the k-th largest probability, then
            # renormalize so the remaining entries sum to one.
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(dim=1, keepdim=True)

        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]

        input_ids = torch.cat([input_ids, next_indices], dim=1)

    # Only the first sequence of the batch is returned.
    return tokenizer.decode(input_ids[0].tolist())

In [None]:
# List every submodule name together with its class, e.g. to pick LoRA targets.
module_types = [(name, type(module)) for name, module in model.named_modules()]
print(module_types)

In [None]:
from peft import LoraConfig, TaskType

# Module names handed to LoRA as adapter targets.
target_modules = ["layers.3.mixer.x_proj"]

config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=target_modules,
)

In [None]:
from peft import inject_adapter_in_model

# Inject the LoRA adapter described by `config` into the loaded model.
lora_model = inject_adapter_in_model(peft_config=config, model=model)

In [None]:
# print_trainable_parameters() prints its own report and returns None, so
# nesting it inside print() emitted the report followed by "plain None".
# The label also said "plain" although this is the LoRA-injected model.
print("lora")
print_trainable_parameters(lora_model)

In [None]:
# Sample a continuation from the LoRA-injected model.
completion = generate(lora_model, tokenizer, "Mamba is the")
print(completion)

In [None]:
# Show which class the object returned by inject_adapter_in_model() has.
lora_model_type = type(lora_model)
print(lora_model_type)

In [None]:
from peft import get_peft_model_state_dict

# Collect the adapter's state dict from the injected model and display it.
peft_state_dict = get_peft_model_state_dict(model=lora_model)
print(peft_state_dict)

In [None]:
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# NOTE(review): `lora_adapter` and `base_model` were used without being
# defined anywhere in the visible notebook. The values below are placeholders
# so the cell can run; adjust them to the intended locations.
lora_adapter = "mamba-lora-adapter"  # directory the adapter is written to
# NOTE(review): presumably the same checkpoint as loaded earlier; confirm it
# is in a format AutoModelForCausalLM can load.
base_model = pretrained_model_name
merged_model_dir = "mamba-lora-merged"  # directory for the merged weights

# NOTE(review): `lora_model` comes from inject_adapter_in_model(); confirm the
# object actually exposes save_pretrained() and accepts these keyword
# arguments.
lora_model.save_pretrained(lora_adapter, save_adapter=True, save_config=True)

# Fall back to CPU when no CUDA device is present instead of failing.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_to_merge = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(base_model).to(device),
    lora_adapter,
)

merged_model = model_to_merge.merge_and_unload()
# save_pretrained() expects an output directory; the original passed the
# model object itself.
merged_model.save_pretrained(merged_model_dir)