In [None]:
!huggingface-cli login --token $TOKEN$ --add-to-git-credential 

In [1]:
cd mamba-hf/src

/home/yashc/elephant/mamba-hf/src


In [21]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer,TrainingArguments
from huggingface_hub import HfApi, ModelFilter
import torch
from datasets import load_dataset, Dataset, DatasetDict
import json
from configuration_mamba import MambaConfig
from modeling_mamba import MambaModel, MambaForCausalLM
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, TaskType, get_peft_model, PeftMixedModel
from transformers import DataCollatorForLanguageModeling, AutoModelForCausalLM
from tqdm import tqdm


In [50]:
def compute_loss(model, inputs, return_outputs=False): 
    lm_logits = model(inputs)[0]
    labels = inputs.to(lm_logits.device)

    shift_logits = lm_logits[:, :-1, :].contiguous()
    labels = labels[:, 1:].contiguous()
    loss_fct = torch.nn.CrossEntropyLoss()
    lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
    return lm_loss


def evaluation(data, model, tokenizer, batch_size=32, max_length = 1024):
    num_samples = len(data)
    num_batches = (num_samples + batch_size - 1) // batch_size
    total_loss = 0
    
    with torch.no_grad():
        model.eval()
        for batch_idx in tqdm(range(num_batches), desc="Evaluating"):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, num_samples)
            batch_data = data['tgt'][start_idx:end_idx]
            inputs = [tokenizer.encode(datum, return_tensors="pt",truncation=True, padding='max_length', max_length=max_length) for datum in batch_data]
            input_ids = torch.cat(inputs, dim=0)
            
            loss = compute_loss(model, input_ids)
            total_loss += loss.item() * (end_idx - start_idx)
            
    avg_loss = total_loss / num_samples
    return avg_loss

def model_merge(adapters, model_path, data, tokenizer):
    base_model = MambaForCausalLM.from_pretrained(model_path)
    print("model loaded")
    ls_count = 0
    names = ["default"]
    peft_model = PeftMixedModel.from_pretrained(base_model, adapters[ls_count])
    ls_count += 1
    while ls_count < len(adapters):
        peft_model.load_adapter(adapters[ls_count], adapter_name=str(ls_count))
        names.append(str(ls_count))
        ls_count += 1
        
    peft_model.set_adapter(names)
    peft_model = peft_model.merge_and_unload()
    print("adapter merged")
    
    result = evaluation(data, peft_model, tokenizer)
    return result
    
def create_JSON(value):
    json_data = json.dumps(value, indent=4)
    with open(f"{value}", "w") as json_file:
        json_file.write(json_data)
        
def get_data(data_path, fraction = 0.01):
    data = load_dataset(data_path)['train'].shuffle()
    data = data.select(list(range(int(len(data) * fraction))))
    print("data fetched")
    return data

In [51]:
def load_json(json_path):
    with open(json_path, "r") as json_file:
        loaded_data = json.load(json_file)
    return loaded_data

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

def load_data(data_path):
    data = load_dataset(data_path).shuffle()
    return DatasetDict({"train":  data["train"].select(list(range(int(len(data["train"]) * 0.5)))), 
                        "valid":  data["valid"].select(list(range(int(len(data["valid"]) * 0.5))))} )
def load_model(config):
    config = make_config(config)
    return MambaForCausalLM(config)

def load_model_pretrained(config):
    return MambaForCausalLM.from_pretrained(config)

def load_tokenizer(path):
    return AutoTokenizer.from_pretrained(path)

def make_config(json):
    config = MambaConfig(
    vocab_size = json["vocab_size"],
    d_model = json["d_model"],
    d_conv = json["d_conv"],
    expand = json["expand"],
    conv_bias = json["conv_bias"],
    bias = json["bias"],
    n_layer = json["n_layer"],
    dt_rank = json["dt_rank"],
    pad_vocab_size_multiple = json["pad_vocab_size_multiple"],
    initializer_range = json["initializer_range"],
    )
    return config

# def split_data(data):
#     train_size = int(len(data) * 0.8)
#     valid_size = len(data) - train_size 

#     ds_train = data.select(list(range(train_size)))
#     ds_valid = data.select(list(range(train_size, train_size + valid_size)))

#     return DatasetDict({"train": ds_train, "valid": ds_valid})



# def load_model_with_LoRA(model, target_modules):
#     config = LoraConfig(
#     target_modules = target_modules)
#     m1 = get_peft_model(model, config)
#     m1.print_trainable_parameters()
#     m1.save_pretrained("./wts/adapter")
#     return m1

def get_checkpoint_model(model_name):
    def get_models_by_organization(org_id, model_name):
        api = HfApi()
        new_filter = ModelFilter(tags="mamba")
        models = api.list_models(filter=new_filter)

        models_list = []
        for i in models:
            if (org_id in i.modelId):
                print(i)
                if((model_name in i.modelId)):
                    return i.modelId
        return False


    org_id = "mlsquare"
    return get_models_by_organization(org_id, model_name)

def make_config(json):
    config = MambaConfig(
    vocab_size = json["vocab_size"],
    d_model = json["d_model"],
    d_conv = json["d_conv"],
    expand = json["expand"],
    conv_bias = json["conv_bias"],
    bias = json["bias"],
    n_layer = json["n_layer"],
    dt_rank = json["dt_rank"],
    pad_vocab_size_multiple = json["pad_vocab_size_multiple"],
    initializer_range = json["initializer_range"],
    )
    return config

class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):        
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids)[0]
        labels = input_ids.to(lm_logits.device)
        
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
        return lm_loss

In [52]:
class Seshu:
    def __init__(self,adapters, config_file, train_args = False):
        self.adapters = load_json(adapters)
        self.config_data = load_json(config_file)
        if train_args:
            self.train_args = train_args
        else:
            self.train_args = TrainingArguments(
                                output_dir="mamba",
                                per_device_train_batch_size=1,
                                per_device_eval_batch_size=1,
                                num_train_epochs=4,
                                weight_decay=0.1,
                                lr_scheduler_type="cosine",
                                learning_rate=5e-4,
                                fp16=False,
                            )
        self.tokenizer = load_tokenizer(self.config_data["tokenizer_path"])
        
    def tokenize(self, data):
        outputs = self.tokenizer(
            data["tgt"],
            truncation=True,
            max_length=1024,
            return_overflowing_tokens=True,
            return_length=True,
        )
        input_batch = []
        for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
            if length != 0:
                input_batch.append(input_ids)
        return {"input_ids": input_batch}
        
        
    def pretrain(self):
#         model_config = make_config(self.config_data)
        if get_checkpoint_model(self.config_data["upload_path"]):
            model = load_model_pretrained(self.config_data["upload_path"])
        else:
            model = load_model(self.config_data)

        self.tokenizer.pad_token = self.tokenizer.eos_token
        data_collator = DataCollatorForLanguageModeling(self.tokenizer, mlm=False)
        data = load_data(self.config_data["data"])
        tokenized_data = data.map(self.tokenize, batched=True, remove_columns=data["train"].column_names)
        trainer = MambaTrainer( model=model, tokenizer=self.tokenizer, args=self.train_args, data_collator=data_collator,
                                train_dataset=tokenized_data["train"], eval_dataset=tokenized_data["valid"])
        trainer.train()

    def model_merge_eval(self, model_path, type_config = "small", data = "mlsquare/SERVER_samantar_mixed_val"):
        adapters = self.adapters[type_config]
        data = get_data(data)
        tokenizer = self.tokenizer
        result = model_merge(adapters, model_path, data, tokenizer)
        return result

In [53]:
model = Seshu("adapters.json", "model_parameters_lora.json")
model.model_merge_eval( "mlsquare/pico_seshu_test", type_config = "large", data = "mlsquare/SERVER_samantar_mixed_val")
# model.train_lora()

data fetched
model loaded
adapter merged


Evaluating: 100%|█████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  1.52it/s]


7.901081181779692

In [None]:
##CONFIG GENERATION

In [16]:
import json
data = {
    "model_path" : "mlsquare/pico_seshu_test",
    "tokenizer_path": "google/byt5-large",
    "adapter_path" : "mlsquare/mamba_pico_small_dt_proj",
    "data" : "mlsquare/CLIENT_samantar_mixed_train_val",
    "target_modules" : ["model.layers.3.dt_proj"]
}

# Convert the dictionary to JSON format
json_data = json.dumps(data, indent=4)

# Save the JSON data to a file
with open("model_parameters_lora.json", "w") as json_file:
    json_file.write(json_data)
    
with open("model_parameters_lora.json", "r") as json_file:
    loaded_data = json.load(json_file)

# Print the loaded data
print(loaded_data)

{'model_path': 'mlsquare/pico_seshu_test', 'tokenizer_path': 'google/byt5-large', 'adapter_path': 'mlsquare/mamba_pico_small_dt_proj', 'data': 'mlsquare/CLIENT_samantar_mixed_train_val', 'target_modules': ['model.layers.3.dt_proj']}


In [7]:
import json
data = {
    "vocab_size":20000,
    "d_state":16,
    "d_model":2560,
    "d_conv":4,
    "expand":2,
    "conv_bias":True,
    "bias":False,
    "n_layer":64,
    "pad_vocab_size_multiple":8,
    "dt_rank": "auto",
    "initializer_range":0.02,
    "tokenizer_path": "google/byt5-large",
    "upload_path" : "mlsquare/130M_Seshu",
    "data" : "mlsquare/CLIENT_samantar_mixed_train_val"
}

# Convert the dictionary to JSON format
json_data = json.dumps(data, indent=4)

# Save the JSON data to a file
with open("model_parameters.json", "w") as json_file:
    json_file.write(json_data)
    
with open("model_parameters.json", "r") as json_file:
    loaded_data = json.load(json_file)

# Print the loaded data
print(loaded_data)

{'vocab_size': 20000, 'd_state': 16, 'd_model': 2560, 'd_conv': 4, 'expand': 2, 'conv_bias': True, 'bias': False, 'n_layer': 64, 'pad_vocab_size_multiple': 8, 'dt_rank': 'auto', 'initializer_range': 0.02, 'tokenizer_path': 'google/byt5-large', 'upload_path': 'mlsquare/130M_Seshu', 'data': 'mlsquare/CLIENT_samantar_mixed_train_val'}


In [45]:
data = {  "small" : ["mlsquare/mamba_pico_small_out_proj", "mlsquare/mamba_pico_small_dt_proj", "mlsquare/mamba_pico_small_x_proj"],
           "large" : ["mlsquare/mamba_pico_large_x_dt_out_proj"]
                }

json_data = json.dumps(data, indent=4)

# Save the JSON data to a file
with open("adapters.json", "w") as json_file:
    json_file.write(json_data)
    
with open("adapters.json", "r") as json_file:
    loaded_data = json.load(json_file)

# Print the loaded data
print(loaded_data)

{'small': ['mlsquare/mamba_pico_small_out_proj', 'mlsquare/mamba_pico_small_dt_proj', 'mlsquare/mamba_pico_small_x_proj'], 'large': ['mlsquare/mamba_pico_large_x_dt_out_proj']}
