In [None]:
import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from configuration_mamba import MambaConfig
from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import HfApi, ModelFilter
from modeling_mamba import MambaForCausalLM, MambaModel
from peft import LoraConfig, PeftMixedModel, TaskType, get_peft_model
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

In [None]:
from huggingface_hub import HfApi, ModelFilter
from peft import PeftMixedModel


def get_models_by_organization(org_id, tag="mamba"):
    """Return the ids of Hub models owned by ``org_id`` that carry ``tag``.

    Args:
        org_id: Hugging Face user/organization name, e.g. ``"mlsquare"``.
        tag: Hub tag to filter on (defaults to ``"mamba"``, as before).

    Returns:
        List of model ids of the form ``"<org_id>/<name>"``.
    """
    api = HfApi()
    # Let the Hub filter by owner and tag server-side. Passing the tag as a
    # plain string avoids ``ModelFilter``, which recent huggingface_hub
    # releases deprecate.
    hub_models = api.list_models(author=org_id, filter=tag)
    # Match the "<owner>/" prefix exactly: a plain substring test would also
    # accept other owners whose names merely contain org_id, or repos that
    # mention it somewhere in the model name.
    prefix = f"{org_id}/"
    return [info.id for info in hub_models if info.id.startswith(prefix)]


# List the Mamba-tagged checkpoints published under the mlsquare organization.
org_id = "mlsquare"
models = get_models_by_organization(org_id=org_id)
models

In [None]:
# Hub repos holding the adapters evaluated in this notebook.
#  - "small": three repos; model_merge_small loads all of them and activates
#    them together.
#  - "large": one repo; model_merge_large loads only this one.
adapters = dict(
    small=[
        "mlsquare/mamba_130M_small_out_proj",
        "mlsquare/mamba_130M_small_dt_proj",
        "mlsquare/mamba_130M_small_x_proj",
    ],
    large=["mlsquare/mamba_130M_large_x_dt_out_proj"],
)


def compute_loss(model, inputs, return_outputs=False):
    """Causal-LM (next-token) cross-entropy of ``model`` on ``inputs``.

    Args:
        model: callable whose output's first element is the logits tensor,
            indexed as (batch, seq_len, vocab).
        inputs: token-id tensor of shape (batch, seq_len); it is also used
            as the label sequence.
        return_outputs: if True, also return the raw model outputs, matching
            the ``Trainer.compute_loss`` contract ``(loss, outputs)``.
            Previously this flag was accepted but silently ignored.

    Returns:
        Scalar mean loss tensor, or ``(loss, outputs)`` when
        ``return_outputs`` is True.
    """
    outputs = model(inputs)
    lm_logits = outputs[0]
    labels = inputs.to(lm_logits.device)

    # Position t predicts token t+1: drop the last logit and the first label.
    shift_logits = lm_logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    # reshape (not view) copes with the non-contiguous slices above.
    lm_loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )
    return (lm_loss, outputs) if return_outputs else lm_loss


def evaluation(data, model, tokenizer):
    """Average causal-LM loss of ``model`` over the ``"tgt"`` texts in ``data``.

    Args:
        data: iterable of examples (e.g. a ``datasets.Dataset``); each example
            must have a ``"tgt"`` string field.
        model: model passed to ``compute_loss``; it is switched to eval mode
            for the duration and its previous train/eval mode is restored.
        tokenizer: tokenizer whose ``encode(..., return_tensors="pt")`` is
            assumed to return a (1, seq_len) id tensor -- TODO confirm.

    Returns:
        Mean loss (a tensor without an autograd graph) over the examples that
        have at least two tokens.

    Raises:
        ZeroDivisionError: if no example has at least two tokens.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    n_scored = 0
    try:
        # Evaluation only: without no_grad every per-example loss kept its
        # whole autograd graph alive through the running sum, so memory grew
        # with the size of the dataset.
        with torch.no_grad():
            for example in tqdm(data, desc="Evaluating"):
                input_ids = tokenizer.encode(example["tgt"], return_tensors="pt")
                # One token has no next-token target; its loss is NaN and
                # would poison the whole average.
                if input_ids.size(-1) < 2:
                    continue
                total_loss += compute_loss(model, input_ids)
                n_scored += 1
    finally:
        model.train(was_training)

    avg_loss = total_loss / n_scored
    print("LOSS: ", avg_loss)
    return avg_loss


def model_merge_large(adapters, model_path, data, tokenizer):
    """Evaluate the base model with the single "large" adapter attached.

    Loads ``MambaForCausalLM`` from ``model_path``, attaches
    ``adapters["large"][0]`` through ``load_adapter`` and returns whatever
    ``evaluation`` reports for ``data``.
    """
    causal_lm = MambaForCausalLM.from_pretrained(model_path)
    print("model loaded")

    # NOTE(review): only load_adapter is called here (no merge step), so the
    # adapter is presumably attached rather than folded into the base
    # weights -- the message below may overstate it; confirm.
    causal_lm.load_adapter(adapters["large"][0])
    print("adapter merged")

    return evaluation(data, causal_lm, tokenizer)


def model_merge_small(adapters, model_path, data, tokenizer, token=None):
    """Evaluate the base model with every "small" adapter active at once.

    The first repo in ``adapters["small"]`` is loaded as adapter
    ``"default"``; the remaining ones are loaded under the names ``"1"``,
    ``"2"``, ... and all of them are then activated together on a
    ``PeftMixedModel``.

    Args:
        adapters: mapping with a ``"small"`` list of adapter repo ids.
        model_path: Hub id or local path of the base ``MambaForCausalLM``.
        data: evaluation examples, forwarded to ``evaluation``.
        tokenizer: tokenizer forwarded to ``evaluation``.
        token: Hugging Face access token for private repos. Defaults to the
            ``HF_TOKEN`` environment variable; if that is unset too, ``None``
            is passed on to the Hugging Face loaders.

    Returns:
        The average loss returned by ``evaluation``.
    """
    # Never embed credentials in source. The token that used to be hard-coded
    # here is exposed in the notebook's history and should be revoked.
    if token is None:
        token = os.environ.get("HF_TOKEN")

    small_adapters = adapters["small"]

    base_model = MambaForCausalLM.from_pretrained(model_path, token=token)
    print("model loaded")

    peft_model = PeftMixedModel.from_pretrained(
        base_model, small_adapters[0], token=token
    )
    active_names = ["default"]
    # Extra adapters also need the token, otherwise private repos fail here.
    for idx, adapter_id in enumerate(small_adapters[1:], start=1):
        name = str(idx)
        peft_model.load_adapter(adapter_id, adapter_name=name, token=token)
        active_names.append(name)
    peft_model.set_adapter(active_names)
    print("adapter merged")

    return evaluation(data, peft_model, tokenizer)


def create_JSON(value, file_path="results.json"):
    """Write ``value`` to ``file_path`` as 4-space-indented JSON.

    Args:
        value: JSON-serialisable object. Objects exposing ``tolist()``
            (torch tensors, numpy arrays/scalars -- e.g. the loss returned by
            ``evaluation``) are converted to plain Python values first.
        file_path: destination file; overwritten if it already exists.
            The file used to be named after the serialised value itself,
            which produced arbitrary (possibly invalid) file names.

    Returns:
        ``file_path``, for convenience.

    Raises:
        TypeError: if ``value`` contains something JSON cannot encode.
    """

    def _to_jsonable(obj):
        # json only calls this hook for objects it cannot encode natively.
        if hasattr(obj, "tolist"):
            return obj.tolist()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    # Serialise fully before opening the file so an encoding error does not
    # leave a truncated file behind.
    json_data = json.dumps(value, indent=4, default=_to_jsonable)
    with open(file_path, "w", encoding="utf-8") as json_file:
        json_file.write(json_data)
    return file_path


def get_data(data_path, fraction=0.01, seed=None, split="train"):
    """Load a random ``fraction`` of one split of a dataset.

    Args:
        data_path: dataset name or path accepted by ``datasets.load_dataset``.
        fraction: share of rows to keep, in (0, 1]. A non-empty split always
            yields at least one row, even if ``len * fraction`` rounds to 0
            (an empty sample would make ``evaluation`` divide by zero).
        seed: shuffle seed; pass an int for a reproducible subset. ``None``
            keeps the previous unseeded behaviour.
        split: split to sample from (default ``"train"``).

    Returns:
        A ``datasets.Dataset`` containing the sampled rows.

    Raises:
        ValueError: if ``fraction`` is not in (0, 1].
    """
    if not 0 < fraction <= 1:
        raise ValueError(f"fraction must be in (0, 1], got {fraction!r}")
    data = load_dataset(data_path)[split].shuffle(seed=seed)
    n_rows = max(1, int(len(data) * fraction)) if len(data) else 0
    # select() accepts a range directly; no need to materialise a list.
    data = data.select(range(n_rows))
    print("data fetched")
    return data


def load_tokenizer(path):
    """Load and return the ``AutoTokenizer`` found at ``path`` (Hub id or local dir)."""
    loaded = AutoTokenizer.from_pretrained(path)
    return loaded

In [None]:
# Evaluate the base model with the three "small" adapters active together.
# The config is defined here (not only in the config cell further down) so
# this cell works under Restart Kernel -> Run All.
mamba_130M_small = {
    "model_path": "mlsquare/pico_seshu",
    "tokenizer_path": "google/byt5-large",
    "adapter_path": "mlsquare/mamba_130M_large_x_dt",
    "data": "mlsquare/samantar1per_cent_merged_with_train_val",
}
data = get_data(mamba_130M_small["data"])
tokenizer = load_tokenizer(mamba_130M_small["tokenizer_path"])
result = model_merge_small(adapters, mamba_130M_small["model_path"], data, tokenizer)

In [None]:
# Evaluate the base model with the single "large" adapter.
# The config is defined here (not only in the config cell further down) so
# this cell works under Restart Kernel -> Run All.
mamba_130M_large = {
    "model_path": "mlsquare/pico_seshu",
    "tokenizer_path": "google/byt5-large",
    "adapter_path": "mlsquare/mamba_130M_large_x_dt",
    "data": "mlsquare/samantar1per_cent_merged_with_train_val",
}
data = get_data(mamba_130M_large["data"])
tokenizer = load_tokenizer(mamba_130M_large["tokenizer_path"])
# Use the large-adapter path: model_merge_small would load the three "small"
# adapters instead, so this cell duplicated the previous experiment.
result = model_merge_large(adapters, mamba_130M_large["model_path"], data, tokenizer)

In [None]:
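# Adapter repo naming convention: <model>_<PARAMS>_<AdapterComputation>_<target_modules>
# e.g. mamba_130M_small_out_proj -> model=mamba, PARAMS=130M, AdapterComputation=small, target_modules=out_proj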

In [None]:
# Settings for the "small" adapter experiment: base checkpoint, tokenizer,
# adapter repo and evaluation dataset.
# NOTE(review): "adapter_path" names a *large* adapter repo and is not read
# by any code visible here (adapters come from the `adapters` dict);
# confirm the intended value.
mamba_130M_small = dict(
    model_path="mlsquare/pico_seshu",
    tokenizer_path="google/byt5-large",
    adapter_path="mlsquare/mamba_130M_large_x_dt",
    data="mlsquare/samantar1per_cent_merged_with_train_val",
)

In [None]:
# Settings for the "large" adapter experiment: base checkpoint, tokenizer,
# adapter repo and evaluation dataset.
# NOTE(review): these values are identical to mamba_130M_small, and
# "adapter_path" is not read by any code visible here; confirm intent.
mamba_130M_large = dict(
    model_path="mlsquare/pico_seshu",
    tokenizer_path="google/byt5-large",
    adapter_path="mlsquare/mamba_130M_large_x_dt",
    data="mlsquare/samantar1per_cent_merged_with_train_val",
)

In [None]:
from fedem import server

# Repeat the small-adapter evaluation through the `fedem.server` helpers
# (presumably packaged versions of the functions defined above -- confirm).
mamba_130M_small = dict(
    model_path="mlsquare/pico_seshu",
    tokenizer_path="google/byt5-large",
    adapter_path="mlsquare/mamba_130M_large_x_dt",
    data="mlsquare/samantar1per_cent_merged_with_train_val",
)

adapters = dict(
    small=[
        "mlsquare/mamba_130M_small_out_proj",
        "mlsquare/mamba_130M_small_dt_proj",
        "mlsquare/mamba_130M_small_x_proj",
    ],
    large=["mlsquare/mamba_130M_large_x_dt_out_proj"],
)

data = server.get_data(mamba_130M_small["data"])
tokenizer = server.load_tokenizer(mamba_130M_small["tokenizer_path"])
result = server.model_merge_small(
    adapters,
    mamba_130M_small["model_path"],
    data,
    tokenizer,
)