### Build the Model
Build a small, randomly initialized Mamba model that is compatible with the Transformers library (only the tokenizer is loaded from a pretrained checkpoint).

In [None]:
from configuration_mamba import MambaConfig
from modeling_mamba import MambaForCausalLM
from transformers import AutoTokenizer

# A deliberately tiny, single-layer Mamba configuration so that every weight
# can be printed and inspected by hand.
mamba_kwargs = dict(
    vocab_size=10,
    d_state=4,
    d_model=6,
    d_conv=4,
    expand=2,
    conv_bias=True,
    bias=False,
    n_layer=1,
)
config = MambaConfig(**mamba_kwargs)
model = MambaForCausalLM(config)
print(model.config)

In [None]:
tokenizer = AutoTokenizer.from_pretrained("Q-bert/Mamba-130M")
text = "Hi"
input_ids = tokenizer.encode(text, return_tensors="pt")
# Every prompt token of the first (only) sequence is overwritten with id 0,
# presumably because the toy model's vocabulary (vocab_size=10) cannot embed
# the real tokenizer's ids.
input_ids[0, :] = 0
output = model.generate(
    input_ids,
    max_length=20,
    num_beams=5,
    no_repeat_ngram_size=2,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)

In [None]:
# Show the raw token-id tensor returned by model.generate in the previous cell (before decoding).
print(output)

In [None]:
# Pair every submodule's qualified name with its class (handy when choosing
# module names to target with LoRA).
module_types = [(name, type(module)) for name, module in model.named_modules()]
print(module_types)

In [None]:
# Sanity check: the concrete class of the model object built above.
print(type(model))


def print_trainable_parameters(model):
    """
    Print, and return, the number of trainable parameters in the model.

    Args:
        model: any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        A ``(trainable_params, all_params)`` tuple, so callers get the counts
        instead of ``None``.
    """
    trainable_params = 0
    all_param = 0
    for param in model.parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A module without parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )
    return trainable_params, all_param


# print_trainable_parameters prints its own summary line, so print the label
# first rather than printing the function's return value (which was None).
print("plain")
print_trainable_parameters(model)

In [None]:
# Names of every tensor in the model's state dict.
plist = model.state_dict().keys()
for key in plist:
    print(key)

In [None]:
# set every tensor in the model's state dict to zero
import torch


def zero_init(model):
    """
    Replace every entry of ``model``'s state dict with zeros and return ``model``.

    The model is updated in place through ``load_state_dict``; the returned
    object is the same instance that was passed in.
    """
    state = model.state_dict()
    zeroed = {name: torch.zeros_like(tensor) for name, tensor in state.items()}
    # Update the original state-dict object (rather than loading `zeroed`
    # directly) so any metadata attached to it is still passed to load_state_dict.
    state.update(zeroed)
    model.load_state_dict(state)
    return model


s = "model.layers.0.in_proj.weight"
print("before", model.state_dict()[s])

model = zero_init(model)

# Re-list the tensor names after zeroing, then show the same tensor again.
zeroed_state = model.state_dict()
plist = zeroed_state.keys()
for key in plist:
    print(key)
print("after", zeroed_state[s])

### Add LoRA adapters
1. Identify a particular layer in the Mamba model and add a LoRA adapter there.
2. For now, only a single layer is adapted, just to verify that the code works.


In [None]:
from peft import LoraConfig, TaskType

# Adapt only the x_proj projection of layer 0, keeping the experiment small
# enough to verify that the LoRA wiring works.
target_modules = ["model.layers.0.x_proj"]

config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=target_modules,
)

In [None]:
from peft import get_peft_model

# Wrap the (zero-initialized) model with the LoRA adapter described by `config`;
# `model` is rebound to the returned PEFT model, whose parameter summary is printed.
model = get_peft_model(model, config)
model.print_trainable_parameters()

In [None]:
# Save the PEFT model to ./wts; this directory is reloaded below through
# PeftConfig.from_pretrained / PeftModel.from_pretrained.
model.save_pretrained("wts")

### Merge the Adapter into the Model
Merge the adapter back into the model, so that the merged model has exactly the same architecture as the base model,
with only the weights modified.

In [None]:
from peft import PeftConfig, PeftModel

import copy

adapter_path = "./wts/"
adapter_config = PeftConfig.from_pretrained(adapter_path)

# Rebuild the same tiny architecture. `config` was rebound to the LoraConfig
# earlier, so the Mamba config has to be constructed again here.
config = MambaConfig(
    vocab_size=10,
    d_state=4,
    d_model=6,
    d_conv=4,
    expand=2,
    conv_bias=True,
    bias=False,
    n_layer=1,
)

# `model` is intentionally NOT rebound here. It still refers to the PEFT-wrapped
# model returned by get_peft_model, which a later cell reports as "lora mamba".
base_model = MambaForCausalLM(config)

# PEFT modifies the model it wraps in place, and merge_and_unload merges into it.
# Wrapping a deep copy keeps `base_model` a genuine pre-LoRA reference for the
# before/after comparison and for the saved base checkpoint.
adapted_model = PeftModel.from_pretrained(copy.deepcopy(base_model), adapter_path)

In [None]:
# Fold the LoRA weights into the wrapped model and strip the PEFT wrappers;
# `m` is the resulting merged model.
m = adapted_model.merge_and_unload()

In [None]:
s = "model.layers.0.x_proj.weight"
# NOTE(review): if `base_model` itself (not a copy) was handed to
# PeftModel.from_pretrained, merge_and_unload may have merged into it in place,
# in which case this "before" print already shows merged weights. Confirm.
print("before LoRA", base_model.state_dict()[s])

merged_state = m.state_dict()
plist = merged_state.keys()
for key in plist:
    print(key)
print("after LoRA", merged_state[s])

In [None]:
text = "Hi"

# Same prompt handling as the first generation cell: every token of the first
# sequence becomes id 0, presumably because the toy vocabulary (vocab_size=10)
# cannot embed the real tokenizer's ids.
input_ids = tokenizer.encode(text, return_tensors="pt")
input_ids[0, :] = 0

output = m.generate(
    input_ids,
    max_length=20,
    num_beams=5,
    no_repeat_ngram_size=2,
)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)

In [None]:
# print_trainable_parameters prints its own summary, so print each label on its
# own line first instead of printing the function's return value.
# NOTE(review): "lora mamba" is meant to be the PEFT-wrapped model from
# get_peft_model. Confirm `model` has not been rebound to a plain model since.
for label, counted_model in (
    ("base mamba", base_model),
    ("lora mamba", model),
    ("merged mamba", m),
):
    print(label)
    print_trainable_parameters(counted_model)

In [None]:
# `from_pt` is a `from_pretrained` loading option. `save_pretrained` only forwards
# extra kwargs to its Hub-push path, so passing it here had no effect and it is dropped.
m.save_pretrained("./mbins")

In [None]:
import torch

# Pickles the entire module object (not just its state_dict), so loading this
# file later requires the same model class to be importable.
merged_path = "./mbins/merged_mamba.pt"
torch.save(m, merged_path)

In [None]:
# Pickle the whole base model object next to the merged one.
# NOTE(review): if `base_model` itself was wrapped by PeftModel.from_pretrained,
# merging may have modified it in place. Confirm this file holds pre-LoRA weights.
torch.save(base_model, "./mbins/base_mamba.pt")

In [None]:
# Notebook display: state-dict keys of the PEFT-wrapped model.
adapted_model.state_dict().keys()

In [None]:
# Notebook display: the adapter configuration loaded from ./wts/.
adapter_config

In [None]:
# Dump every (name, parameter) pair of the adapted model.
for name, param in adapted_model.named_parameters():
    print((name, param))