In [None]:
# type: ignore
# Load imports

import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from configuration_mamba import MambaConfig
from datasets import Dataset, DatasetDict, load_dataset
from huggingface_hub import HfApi, ModelFilter
from modeling_mamba import MambaForCausalLM, MambaModel
from peft import LoraConfig, PeftMixedModel, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

In [None]:
def load_json(json_path):
    """
    Reads a JSON file and returns the decoded Python object.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` decodes from the file content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # Decode as UTF-8 explicitly instead of relying on the platform's
    # locale-dependent default encoding.
    with open(json_path, "r", encoding="utf-8") as json_file:
        return json.load(json_file)


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs, where ``param`` has ``numel()`` and
            ``requires_grad``.

    The printed percentage is ``0.0`` when the model has no parameters at
    all, instead of raising ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # Guard the ratio: a parameter-free model would otherwise divide by zero.
    percent = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {percent}"
    )


def load_data(data_path, fraction=0.5, seed=None):
    """
    Loads a dataset, shuffles it and keeps the leading part of each split.

    Args:
        data_path: Passed unchanged to ``load_dataset``.
        fraction: Share of each split to keep after shuffling. Defaults to
            ``0.5``, the value that used to be hard-coded.
        seed: Optional shuffle seed. ``None`` (the default) keeps the
            previous unseeded behaviour; pass an int for a reproducible
            subset.

    Returns:
        A ``DatasetDict`` with ``"train"`` and ``"valid"`` entries, each
        holding the first ``int(len(split) * fraction)`` shuffled rows.

    Raises:
        KeyError: If the loaded dataset lacks a ``"train"`` or ``"valid"`` split.
    """
    data = load_dataset(data_path).shuffle(seed=seed)
    return DatasetDict(
        {
            # ``select`` accepts a range directly; no need to materialise a list.
            split: data[split].select(range(int(len(data[split]) * fraction)))
            for split in ("train", "valid")
        }
    )


def load_model_pretrained(config):
    """
    Instantiates a ``MambaForCausalLM`` through ``from_pretrained``.

    Args:
        config: Forwarded unchanged as the only argument of
            ``MambaForCausalLM.from_pretrained``.

    Returns:
        The object returned by ``MambaForCausalLM.from_pretrained``.
    """
    pretrained_model = MambaForCausalLM.from_pretrained(config)
    return pretrained_model


def load_tokenizer(path):
    """
    Loads a tokenizer through ``AutoTokenizer.from_pretrained``.

    Args:
        path: Forwarded unchanged to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer object returned by ``AutoTokenizer.from_pretrained``.
    """
    tokenizer = AutoTokenizer.from_pretrained(path)
    return tokenizer


def make_config(json):
    """
    Builds a ``MambaConfig`` from a mapping of hyper-parameters.

    Args:
        json: Subscriptable mapping that provides every key listed in
            ``field_names`` below. Inside this function the parameter name
            shadows the ``json`` module.

    Returns:
        A ``MambaConfig`` constructed with those values as keyword arguments.
    """
    field_names = (
        "vocab_size",
        "d_model",
        "d_conv",
        "expand",
        "conv_bias",
        "bias",
        "n_layer",
        "dt_rank",
        "pad_vocab_size_multiple",
        "initializer_range",
    )
    # Look the fields up in declaration order and forward them by keyword.
    kwargs = {}
    for name in field_names:
        kwargs[name] = json[name]
    return MambaConfig(**kwargs)


def split_data(data):
    """
    Splits a dataset into a leading 80% train part and the remaining valid part.

    Args:
        data: Object supporting ``len()`` and ``select(indices)``.

    Returns:
        A ``DatasetDict`` with ``"train"`` (rows ``0 .. int(len * 0.8) - 1``)
        and ``"valid"`` (all remaining rows, in order).
    """
    total = len(data)
    cut = int(total * 0.8)

    # Everything before the cut trains; everything from the cut on validates.
    train_indices = [*range(cut)]
    valid_indices = [*range(cut, total)]

    return DatasetDict(
        {
            "train": data.select(train_indices),
            "valid": data.select(valid_indices),
        }
    )


def load_model(config):
    """
    Creates a ``MambaForCausalLM`` from a mapping of hyper-parameters.

    Args:
        config: Mapping handed to ``make_config`` to build the model config.

    Returns:
        ``MambaForCausalLM`` constructed from the resulting config object.
    """
    model_config = make_config(config)
    model = MambaForCausalLM(model_config)
    return model


def load_model_with_LoRA(model, target_modules):
    """
    Wraps ``model`` with LoRA adapters and saves the adapter to disk.

    Args:
        model: Base model handed to ``get_peft_model``.
        target_modules: Forwarded as ``LoraConfig(target_modules=...)``.

    Returns:
        The object returned by ``get_peft_model``.

    Side effects:
        Prints the trainable-parameter summary of the wrapped model and calls
        ``save_pretrained("./wts/adapter")`` on it.
    """
    lora_config = LoraConfig(target_modules=target_modules)
    peft_model = get_peft_model(model, lora_config)

    # Report how many parameters will actually be trained, then persist the adapter.
    peft_model.print_trainable_parameters()
    peft_model.save_pretrained("./wts/adapter")
    return peft_model


def get_checkpoint_model(model_name):
    """
    Looks for a Hub model tagged ``"mamba"`` in the ``mlsquare`` organisation.

    Args:
        model_name: Substring that must occur in the model id.

    Returns:
        The id of the first listed model whose id starts with ``"mlsquare/"``
        and contains ``model_name``, or ``False`` when none matches.
    """
    org_id = "mlsquare"

    def get_models_by_organization(org_id, model_name):
        api = HfApi()
        new_filter = ModelFilter(tags="mamba")
        # Match the organisation as the namespace prefix of the id; a plain
        # substring test would also accept ids of other owners that merely
        # contain the organisation name.
        namespace = f"{org_id}/"
        for info in api.list_models(filter=new_filter):
            model_id = info.modelId
            if model_id.startswith(namespace) and model_name in model_id:
                return model_id
        return False

    return get_models_by_organization(org_id, model_name)


# NOTE(review): this re-defines ``make_config``, which is already defined
# earlier in this cell with the same behaviour; this later definition is the
# one in effect. One of the two copies should be removed.
def make_config(json):
    """
    Creates a ``MambaConfig`` from the hyper-parameters stored in ``json``.

    Args:
        json: Subscriptable mapping holding the ten keys enumerated in
            ``config_keys``. The parameter name shadows the ``json`` module
            within this function.

    Returns:
        The ``MambaConfig`` built from those entries.
    """
    config_keys = [
        "vocab_size",
        "d_model",
        "d_conv",
        "expand",
        "conv_bias",
        "bias",
        "n_layer",
        "dt_rank",
        "pad_vocab_size_multiple",
        "initializer_range",
    ]
    return MambaConfig(**{key: json[key] for key in config_keys})


class MambaTrainer(Trainer):
    """
    ``Trainer`` subclass that computes the next-token cross-entropy loss itself.

    The model is called with ``input_ids`` only and its first output is
    treated as the logits.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Computes the shifted language-modelling loss for one batch.

        Args:
            model: Callable model; ``model(input_ids)[0]`` must be the logits.
            inputs: Mapping with ``"input_ids"`` and, optionally,
                ``"attention_mask"``. It is not modified.
            return_outputs: When true, return ``(loss, outputs)`` as the
                ``Trainer`` evaluation path expects; otherwise only the loss.
            num_items_in_batch: Accepted because recent ``transformers``
                releases pass it to ``compute_loss``; it is not used here.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is set.
        """
        # Read instead of pop so the caller's batch stays intact.
        input_ids = inputs["input_ids"]
        outputs = model(input_ids)
        lm_logits = outputs[0]
        labels = input_ids.to(lm_logits.device)

        # Exclude padded positions from the loss: -100 is the default
        # ``ignore_index`` of ``CrossEntropyLoss``. Without a mask every
        # position counts, exactly as before.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            labels = labels.masked_fill(attention_mask.to(labels.device) == 0, -100)

        # Position t predicts token t + 1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        return (lm_loss, outputs) if return_outputs else lm_loss

In [None]:
class Seshu:
    """
    Config-driven wrapper around LoRA fine-tuning and pretraining of a Mamba model.

    The JSON config supplies ``tokenizer_path`` and ``data`` for both modes.
    ``train_lora`` additionally reads ``model_path``, ``adapter_path`` and
    ``target_modules``. ``pretrain`` reads the hyper-parameters consumed by
    ``load_model`` and, optionally, ``model_path``.
    """

    # Maximum number of tokens per example handed to the tokenizer.
    max_length = 1024

    def __init__(self, config_file, train_args=False):
        """
        Loads the config and the tokenizer.

        Args:
            config_file: Path of the JSON config file.
            train_args: ``TrainingArguments`` to use. Any falsy value selects
                the defaults defined below.
        """
        self.config_data = load_json(config_file)
        if train_args:
            self.train_args = train_args
        else:
            self.train_args = TrainingArguments(
                output_dir="mamba",
                per_device_train_batch_size=1,
                per_device_eval_batch_size=1,
                num_train_epochs=4,
                weight_decay=0.1,
                lr_scheduler_type="cosine",
                learning_rate=5e-4,
                fp16=False,
            )
        self.tokenizer = load_tokenizer(self.config_data["tokenizer_path"])

    def tokenize(self, data):
        """
        Tokenizes the ``"tgt"`` column of a batch.

        Args:
            data: Batch mapping whose ``"tgt"`` entry is passed to the tokenizer.

        Returns:
            ``{"input_ids": [...]}`` holding every tokenized sequence whose
            reported length is non-zero. Texts longer than ``max_length`` are
            truncated and the overflow is requested from the tokenizer.
        """
        outputs = self.tokenizer(
            data["tgt"],
            truncation=True,
            max_length=self.max_length,
            return_overflowing_tokens=True,
            return_length=True,
        )
        input_batch = [
            input_ids
            for length, input_ids in zip(outputs["length"], outputs["input_ids"])
            if length != 0
        ]
        return {"input_ids": input_batch}

    def _run_training(self, model):
        """Tokenizes the configured dataset and trains ``model`` with ``MambaTrainer``."""
        self.tokenizer.pad_token = self.tokenizer.eos_token
        data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        data = load_data(self.config_data["data"])
        tokenized_data = data.map(
            self.tokenize, batched=True, remove_columns=data["train"].column_names
        )
        trainer = MambaTrainer(
            model=model,
            tokenizer=self.tokenizer,
            args=self.train_args,
            data_collator=data_collator,
            train_dataset=tokenized_data["train"],
            eval_dataset=tokenized_data["valid"],
        )
        trainer.train()

    def train_lora(self):
        """
        Fine-tunes with LoRA.

        First tries to load the base model together with the adapter named by
        ``adapter_path``. If any step of that fails, the error is reported and
        a new adapter is created on ``target_modules`` instead.
        """
        model = None
        try:
            candidate = AutoModelForCausalLM.from_pretrained(self.config_data["model_path"])
            candidate.enable_input_require_grads()
            candidate.load_adapter(self.config_data["adapter_path"])
            model = candidate
        except Exception as e:
            # Deliberately best-effort, but say why the adapter was rejected
            # instead of hiding the cause.
            print("Adapter not valid!! creating new.")
            print(f"Reason: {e!r}")
        if model is None:
            model = load_model_pretrained(self.config_data["model_path"])
            model = load_model_with_LoRA(model, self.config_data["target_modules"])
        self._run_training(model)

    #         model.push_to_hub(self.config_data["upload_path"])

    def pretrain(self):
        """
        Pretrains a model.

        When the config names a ``model_path`` and ``get_checkpoint_model``
        finds a matching checkpoint, training resumes from that checkpoint.
        Otherwise a fresh model is built from the config hyper-parameters.
        """
        # ``model_path`` is optional for pretraining configs.
        model_path = self.config_data.get("model_path")
        checkpoint = get_checkpoint_model(model_path) if model_path else False
        if checkpoint:
            # Load the checkpoint that was actually found on the Hub.
            model = load_model_pretrained(checkpoint)
        else:
            model = load_model(self.config_data)
        self._run_training(model)


#         model.push_to_hub(self.config_data["upload_path"])

In [None]:
## Config generation: write the JSON config files read by Seshu

In [None]:
# LoRA fine-tuning configuration. ``json`` is already imported in the first
# cell, so it is not re-imported here.
data = {
    "model_path": "Q-bert/Mamba-130M",
    "tokenizer_path": "Q-bert/Mamba-130M",
    "target_modules": ["out_proj"],
    "adapter_path": "mlsquare/exp-lora-ada-1",
    "data": "mlsquare/samantar1per_cent_merged_with_train_val",
}

# Convert the dictionary to JSON format
json_data = json.dumps(data, indent=4)

# Save the JSON data to a file (explicit UTF-8 so the result does not depend
# on the platform's default encoding)
with open("model_parameters_lora.json", "w", encoding="utf-8") as json_file:
    json_file.write(json_data)

# Read the file back to confirm that it round-trips
with open("model_parameters_lora.json", "r", encoding="utf-8") as json_file:
    loaded_data = json.load(json_file)

# Print the loaded data
print(loaded_data)

In [None]:
# Pretraining configuration: the hyper-parameter keys are the ones read by
# ``make_config``. ``json`` is already imported in the first cell, so it is
# not re-imported here.
data = {
    "vocab_size": 7000,
    "d_model": 1,
    "d_conv": 4,
    "expand": 2,
    "conv_bias": True,
    "bias": False,
    "n_layer": 1,
    "dt_rank": "auto",
    "pad_vocab_size_multiple": 8,
    "initializer_range": 0.02,
    "tokenizer_path": "google/byt5-large",
    "upload_path": "mlsquare/pico_mamba",
    "data": "mlsquare/samantar1per_cent_merged_with_train_val",
}

# Convert the dictionary to JSON format
json_data = json.dumps(data, indent=4)

# Save the JSON data to a file (explicit UTF-8 so the result does not depend
# on the platform's default encoding)
with open("model_parameters.json", "w", encoding="utf-8") as json_file:
    json_file.write(json_data)

# Read the file back to confirm that it round-trips
with open("model_parameters.json", "r", encoding="utf-8") as json_file:
    loaded_data = json.load(json_file)

# Print the loaded data
print(loaded_data)

In [None]:
# Build the training wrapper from the LoRA config file written above and start
# LoRA fine-tuning. Despite its name, `model` is a Seshu instance, not a network.
model = Seshu("model_parameters_lora.json")
model.train_lora()