In [None]:
# Fetch the mergekit-mamba sources; git places them in ./mergekit-mamba.
!git clone https://github.com/mlsquare/mergekit-mamba.git

In [None]:
pip install .

In [None]:
import json

import torch
from huggingface_hub import HfApi, ModelFilter
from peft import PeftMixedModel


def get_models_by_organization(org_id):
    """Return the ids of Hub models tagged ``mamba`` that are owned by ``org_id``.

    Args:
        org_id: Hub user or organization name, e.g. ``"mlsquare"``.

    Returns:
        list[str]: ``modelId`` of every matching model, in the order the Hub
        listing produced them.
    """
    api = HfApi()
    mamba_filter = ModelFilter(tags="mamba")
    # Model ids look like "<owner>/<name>". Compare the owner prefix rather
    # than doing a substring test, which would also accept other owners'
    # repositories that merely mention org_id somewhere in their id.
    owner_prefix = f"{org_id}/"
    return [
        info.modelId
        for info in api.list_models(filter=mamba_filter)
        if info.modelId.startswith(owner_prefix)
    ]


# Keep the Hub listing under its own name: the next cell binds `models` to the
# adapter mapping, which would otherwise silently replace this list.
org_id = "mlsquare"
org_models = get_models_by_organization(org_id)
org_models

In [None]:
# Adapter repositories grouped by kind: three "small" adapters and a single
# "large" one. model_merge_small / model_merge_large index into these lists.
# NOTE(review): these ids use underscores while the adapter_path entries in the
# config cells use hyphens — confirm which spelling exists on the Hub.
models = dict(
    small=[
        "mlsquare/mamba_130M_small_out_proj",
        "mlsquare/mamba_130M_small_d_proj",
        "mlsquare/mamba_130M_small_x_proj",
    ],
    large=["mlsquare/mamba_130M_large_x_d_out_proj"],
)


def compute_loss(model, inputs, return_outputs=False):
    """Compute the next-token cross-entropy loss of ``model`` on one batch.

    Args:
        model: Callable whose result is indexed with ``[0]`` to obtain the
            logits. The logits are sliced as a 3-D tensor.
        inputs: Mapping holding an ``"input_ids"`` tensor (sliced as 2-D).
            The mapping is only read; it is left untouched for the caller.
        return_outputs: Accepted for signature compatibility; currently
            ignored — the loss alone is always returned.

    Returns:
        The loss tensor produced by ``torch.nn.CrossEntropyLoss``.
    """
    # Read instead of pop so the caller's batch can be reused afterwards.
    input_ids = inputs["input_ids"]
    lm_logits = model(input_ids)[0]
    labels = input_ids.to(lm_logits.device)

    # Position t predicts token t + 1: drop the last logit and the first label.
    shift_logits = lm_logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    loss_fct = torch.nn.CrossEntropyLoss()
    return loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
    )


def evaluation(data_path, model, split=None):
    """Average the next-token loss of ``model`` over a tokenized dataset.

    Args:
        data_path: Dataset name or path forwarded to ``load_dataset``.
        model: Model handed to ``compute_loss`` for every example.
        split: Optional split forwarded to ``load_dataset``. With the default
            (``None``) every split that ``load_dataset`` returns is evaluated.

    Returns:
        float: mean per-example loss (also printed).

    Raises:
        ValueError: if the dataset yields no examples.
    """
    data = load_dataset(data_path, split=split).shuffle()
    # Without a split, load_dataset returns a dict of splits rather than one
    # dataset; normalise both cases to a list of datasets.
    splits = list(data.values()) if isinstance(data, dict) else [data]

    # Build inputs on the model's device when the model exposes one.
    device = getattr(model, "device", None)

    total_loss = 0.0
    n_examples = 0
    # Pure evaluation: no gradients, and accumulate plain floats so no
    # autograd graph is kept alive across examples.
    with torch.no_grad():
        for split_data in splits:
            tokenized_data = split_data.map(
                tokenize, batched=True, remove_columns=split_data.column_names
            )
            for row in tokenized_data:
                # NOTE(review): assumes `tokenize` emits an "input_ids" column,
                # the key compute_loss reads — confirm against tokenize.
                # Wrapping in a list adds the leading batch dimension that
                # compute_loss slices on.
                batch = {"input_ids": torch.tensor([row["input_ids"]], device=device)}
                total_loss += compute_loss(model, batch).item()
                n_examples += 1

    if n_examples == 0:
        raise ValueError(f"no examples to evaluate in dataset {data_path!r}")

    mean_loss = total_loss / n_examples
    print(mean_loss)
    return mean_loss


def model_merge_large(adapters, model_path, data_path):
    """Evaluate the base model with the single "large" adapter attached.

    Args:
        adapters: Mapping whose ``"large"`` entry is a list; its first element
            is the adapter to load.
        model_path: Base model name or path for ``AutoModelForCausalLM``.
        data_path: Dataset name or path forwarded to ``evaluation``.

    Returns:
        Whatever ``evaluation`` returns for the adapted model.
    """
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model.load_adapter(adapters["large"][0])
    # Return the score so callers can record or compare it.
    return evaluation(data_path, model)


def model_merge_small(adapters, model_path, data_path):
    """Evaluate the base model with all "small" adapters active together.

    Args:
        adapters: Mapping whose ``"small"`` entry is a non-empty list of
            adapter ids.
        model_path: Base model name or path for ``AutoModelForCausalLM``.
        data_path: Dataset name or path forwarded to ``evaluation``.

    Returns:
        Whatever ``evaluation`` returns for the combined model.
    """
    base_model = AutoModelForCausalLM.from_pretrained(model_path)

    first_adapter, *other_adapters = adapters["small"]
    # The adapter passed to from_pretrained is registered as "default".
    peft_model = PeftMixedModel.from_pretrained(base_model, first_adapter)
    active_names = ["default"]
    for index, adapter_id in enumerate(other_adapters, start=1):
        adapter_name = str(index)
        peft_model.load_adapter(adapter_id, adapter_name=adapter_name)
        active_names.append(adapter_name)

    # Activate exactly the adapters that were loaded; naming one that was
    # never loaded makes set_adapter fail.
    peft_model.set_adapter(active_names)
    # Evaluate the adapter-wrapped model, not the bare base model.
    return evaluation(data_path, peft_model)


def create_JSON(value, path=None):
    """Serialize ``value`` as indented JSON and write it to a file.

    Args:
        value: Any JSON-serializable object.
        path: Destination file. When omitted, the file is named ``str(value)``
            as before; that is only a usable file name for simple values, so
            pass ``path`` explicitly when writing dicts or lists.

    Raises:
        TypeError: if ``value`` is not JSON-serializable.
        OSError: if the destination cannot be opened for writing.
    """
    # Serialize first so a failure does not leave an empty file behind.
    json_data = json.dumps(value, indent=4)
    target = f"{value}" if path is None else path
    with open(target, "w", encoding="utf-8") as json_file:
        json_file.write(json_data)

In [None]:
# Adapter naming convention: <model>-<PARAMS>-<AdapterComputation>-<target_modules>
# (for example: mamba-130M-small-out_proj)

In [None]:
# Settings for the "small" adapter that targets only out_proj.
mamba_130M_small_out_proj = dict(
    model_path="Q-bert/Mamba-130M",
    tokenizer_path="Q-bert/Mamba-130M",
    target_modules=["out_proj"],
    adapter_path="mlsquare/mamba-130M-small-out_proj",
    data="mlsquare/samantar1per_cent_merged_with_train_val",
)

In [None]:
# Settings for the "small" adapter that targets only d_proj.
# NOTE(review): adapter_path follows the
# <model>-<PARAMS>-<AdapterComputation>-<target_modules> pattern instead of
# repeating the out_proj adapter — confirm the repo id exists on the Hub.
mamba_130M_small_d_proj = {
    "model_path": "Q-bert/Mamba-130M",
    "tokenizer_path": "Q-bert/Mamba-130M",
    "target_modules": ["d_proj"],
    "adapter_path": "mlsquare/mamba-130M-small-d_proj",
    "data": "mlsquare/samantar1per_cent_merged_with_train_val",
}

In [None]:
# Settings for the "small" adapter that targets only x_proj.
# NOTE(review): adapter_path follows the
# <model>-<PARAMS>-<AdapterComputation>-<target_modules> pattern instead of
# repeating the out_proj adapter — confirm the repo id exists on the Hub.
mamba_130M_small_x_proj = {
    "model_path": "Q-bert/Mamba-130M",
    "tokenizer_path": "Q-bert/Mamba-130M",
    "target_modules": ["x_proj"],
    "adapter_path": "mlsquare/mamba-130M-small-x_proj",
    "data": "mlsquare/samantar1per_cent_merged_with_train_val",
}

In [None]:
# Settings for the "large" adapter that targets x_proj, d_proj and out_proj.
# NOTE(review): adapter_path follows the
# <model>-<PARAMS>-<AdapterComputation>-<target_modules> pattern instead of
# repeating the small out_proj adapter — confirm the repo id exists on the Hub.
mamba_130M_large_x_d_out_proj = {
    "model_path": "Q-bert/Mamba-130M",
    "tokenizer_path": "Q-bert/Mamba-130M",
    "target_modules": ["x_proj", "d_proj", "out_proj"],
    "adapter_path": "mlsquare/mamba-130M-large-x_d_out_proj",
    "data": "mlsquare/samantar1per_cent_merged_with_train_val",
}