In [10]:
import sys
import os
from functools import partial
import pandas as pd
import nevergrad as ng
from string import Template
from datasets import Dataset
sys.path.append("/home/v-oostapenko/dev/mttl")
sys.path.append("/home/v-oostapenko/dev/mttl/projects/wiki_experts/")
from lora_hub.algorithm import lorahub_learning, get_score
from mttl.datamodule.platypus_module import (
    PlatypusModule,
    PlatypusConfig,
    PlatypusQAModule,
)
from src.graph.module_graph import ModuleGraph

In [11]:

def get_examples_for_learning():
    """
    Return the few-shot examples used to learn how to compose the LoRA modules.

    Each example is a dict with an "input" prompt and the expected "output".
    A new list is built on every call.
    """
    prompt = "Infer the date from context.\n\nQ: Jane is celebrating the last day of Jan 2012. What is the date tomorrow in MM/DD/YYYY?\nOptions:\n(A) 02/02/2012\n(B) 02/15/2012\n(C) 01/25/2012\n(D) 04/22/2012\n(E) 02/01/2012\n(F) 02/11/2012\nA:"
    answer = "(E)"
    example = {"input": prompt, "output": answer}
    return [example]

In [12]:
def load_dataset(example_inputs, example_outputs, tokenizer):
    """
    Build a tokenized dataset from parallel lists of inputs and outputs.

    Args:
        example_inputs: input strings, one per example.
        example_outputs: target strings aligned with `example_inputs`, or
            None, in which case every target is the empty string.
        tokenizer: tokenizer forwarded to `preprocess_function`.

    Returns:
        The result of mapping `preprocess_function` (in batched mode) over a
        `Dataset` built from the input/output pairs.
    """
    if example_outputs is None:
        # No targets supplied: pair every input with an empty target.
        example_outputs = [""] * len(example_inputs)

    records = []
    for idx in range(len(example_inputs)):
        records.append(
            {"input": example_inputs[idx], "output": example_outputs[idx]}
        )

    tokenize_batch = partial(preprocess_function, tokenizer=tokenizer)
    raw_dataset = Dataset.from_pandas(pd.DataFrame(records))
    return raw_dataset.map(
        tokenize_batch,
        batched=True,
        num_proc=1,
        desc="Running tokenizer on dataset",
    )

def preprocess_function(examples, tokenizer):
    """
    Tokenize a batch of examples and attach loss labels.

    Args:
        examples: mapping with an "input" entry and an "output" entry.
        tokenizer: callable tokenizer exposing `pad_token_id`; it is called
            once on the inputs and once on the outputs with identical options.

    Returns:
        The tokenized inputs with an extra "labels" entry holding the
        tokenized outputs, where every pad-token position is set to -100.
    """
    source_texts = examples["input"]
    target_texts = examples["output"]

    # Inputs and targets share exactly the same tokenizer options.
    tokenizer_options = dict(
        max_length=2048,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    encoded_batch = tokenizer(source_texts, **tokenizer_options)
    label_ids = tokenizer(target_texts, **tokenizer_options)["input_ids"]

    # Replace padding positions in the labels with the sentinel value -100.
    pad_positions = label_ids == tokenizer.pad_token_id
    label_ids[pad_positions] = -100

    encoded_batch["labels"] = label_ids
    return encoded_batch

In [40]:

from src.config import ExpertConfig
from huggingface_hub import login
from src.expert_model import MultiExpertModel

# if "HF_TOKEN" in os.environ:
#     login(token=os.environ["HF_TOKEN"])

# Build the expert-model configuration, then the Platypus data module, then the
# multi-expert model that shares the data module's tokenizer.
config = ExpertConfig()
config.model = "EleutherAI/gpt-neo-125m"
config.load_in_8bit = True
config.model_family = "gpt"
# Raises KeyError if MMLU_DATA_DIR is not set in the environment.
config.data_dir = os.environ["MMLU_DATA_DIR"]
config.predict_batch_size = 2
config.max_input_length = 4096
config.max_output_length = 5
# Regex-style selectors for what the "lora" modifier touches and which
# parameters are trainable.
# NOTE(review): exact matching semantics live in the project code — confirm there.
config.modify_layers= "q_proj|v_proj|k_proj"
config.trainable_param_names=".*lora_[ab].*"
config.model_modifier= "lora"
config.modify_modules = ".*"

# Mirror the relevant fields of `config` into the data-module config.
# train_batch_size and validation_portion are not assigned in this cell, so
# whatever values ExpertConfig() provides are passed through.
config_dm = PlatypusConfig(
    model=config.model,
    train_batch_size=config.train_batch_size,
    predict_batch_size=config.predict_batch_size,
    max_input_length=config.max_input_length,
    max_output_length=config.max_output_length,
    validation_portion=config.validation_portion,
    model_family=config.model_family,
    train_on_inputs=False,
    train_on_reverse="platypus"
)
dm = PlatypusModule(config_dm)
model_class = MultiExpertModel
# Every attribute of `config` is forwarded as a keyword argument, together with
# the data module's tokenizer.
# NOTE(review): the result is bound to `module`, while the later optimization
# cells reference `model`; confirm which name is intended.
module = model_class(**vars(config), tokenizer=dm.tokenizer)

Padding side is right
Setting pad_token_id to 0, given that pad_token_id was not detected.
FlashAttention not found, skipping replacing attn with flash attn.


In [49]:
from string import Template
# Module-graph description with two placeholders, $weght1 and $weght2, that are
# filled in later through Template.safe_substitute.
# NOTE(review): both nodes reference the same expert (..._security_studies),
# including the `abstract_algebra` node — confirm this is intended.
# NOTE(review): the placeholders are spelled "weght"; the safe_substitute calls
# in this notebook use the same spelling, so rename both sides together.
s = Template("security_studies -> linear(sordonia/expert_llama2_13b_security_studies:$weght1);\
            abstract_algebra -> linear(sordonia/expert_llama2_13b_security_studies:$weght2);\
            ")

In [52]:
# Fill in the two mixing weights and parse the resulting graph description.
# The substituted text gets its own name so that `s` remains a Template: the
# previous `s = s.safe_substitute(...)` turned `s` into a plain str, and
# re-running the cell then failed because str has no `safe_substitute`.
graph_string = s.safe_substitute(weght1=0.5, weght2=0.5)
graph = ModuleGraph.from_string(graph_string)

In [45]:
from string import Template
# Same two-node graph template as in the earlier cell, followed by parsing the
# graph and creating its modules.
# NOTE(review): unlike the cell that calls safe_substitute first, the Template
# object itself (placeholders unfilled) is passed to ModuleGraph.from_string
# here — TODO confirm from_string accepts a Template rather than a str.
s = Template("security_studies -> linear(sordonia/expert_llama2_13b_security_studies:$weght1);\
            abstract_algebra -> linear(sordonia/expert_llama2_13b_security_studies:$weght2);\
            ")
graph = ModuleGraph.from_string(s)
# `cache` is consumed by the later get_score / optimization cells.
# NOTE(review): presumably a mapping from node name to that module's state
# dict — verify against ModuleGraph.create_modules.
cache = graph.create_modules()

In [11]:
# Third copy of the two-node graph template; here it is only used by the toy
# objective below, which prints the substituted text.
# NOTE(review): this definition is duplicated in several cells — consider
# defining it once.
s = Template("security_studies -> linear(sordonia/expert_llama2_13b_security_studies:$weght1);\
            abstract_algebra -> linear(sordonia/expert_llama2_13b_security_studies:$weght2);\
            ")
def get_score(weights, template: Template):
    """
    Toy objective: print the template filled with the first two weights and
    return the sum of squared distances of the weights from 0.5.

    NOTE(review): this definition shadows the `get_score` imported from
    lora_hub.algorithm at the top of the notebook.

    Args:
        weights: array supporting element-wise arithmetic, with at least two
            elements.
        template: Template containing the $weght1 and $weght2 placeholders.

    Returns:
        sum((weights - 0.5) ** 2)
    """
    first_weight, second_weight = weights[0], weights[1]
    filled = template.safe_substitute(weght1=first_weight, weght2=second_weight)
    print(filled)
    squared_offsets = (weights - 0.5) ** 2
    return sum(squared_offsets)

# Bind the template so the optimizer only has to supply the weight vector.
_get_score = partial(get_score, template=s)

# Two weights, each starting at 0 and bounded to [-1.5, 1.5].
instrum=ng.p.Array(
                init=[0]* 2,
                upper=[1.5] * 2,
                lower=[-1.5] * 2,
            )

# budget=2: the recorded run below evaluated the objective only twice
# (two printed lines), which is far too few to converge.
optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=2)
recommendation = optimizer.minimize(_get_score)
print(recommendation.value)
# The objective is minimal at [0.5, 0.5]; with budget=2 the recorded run
# returned [0.5 0. ] instead. Raise the budget to get close to the optimum.

security_studies -> linear(sordonia/expert_llama2_13b_security_studies:0.0);            abstract_algebra -> linear(sordonia/expert_llama2_13b_security_studies:0.0);            
security_studies -> linear(sordonia/expert_llama2_13b_security_studies:0.5);            abstract_algebra -> linear(sordonia/expert_llama2_13b_security_studies:0.0);            
[0.5 0. ]


In [47]:

# Gradient-free search over one mixing weight per LoRA module, each starting
# at 0 and bounded to [-1.5, 1.5].
# NOTE(review): `number_of_loras` and `max_inference_step` are not defined in
# the visible notebook, and `get_score_partial` is only defined in a later
# cell, so this cell fails on a fresh kernel run top to bottom.
# NOTE(review): the output recorded under this cell (a dict_keys listing) is
# not something this code prints — it looks stale.
instrum = ng.p.Array(
    init=[0] * number_of_loras,
    upper=[1.5] * number_of_loras,
    lower=[-1.5] * number_of_loras,
)
optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=max_inference_step)
print("> Begin to perform gradient-free optimization ...")
recommendation = optimizer.minimize(get_score_partial, verbosity=1)

dict_keys(['security_studies', 'abstract_algebra'])

In [32]:

# construct input list and output list
example_inputs, examples_outputs = [], []
for example in get_examples_for_learning():
    example_inputs.append(example["input"])
    examples_outputs.append(example["output"])

dataset = load_dataset(example_inputs, examples_outputs, dm.tokenizer) 

Running tokenizer on dataset: 100%|██████████| 1/1 [00:00<00:00, 207.84 examples/s]


In [None]:

def get_score(weights, model, cache, example_dataset, batch_size, get_loss, get_regular):
    """
    Score one candidate weight vector for composing the cached LoRA modules.

    The state dicts in `cache` are combined key by key as a weighted sum
    (weights[i] multiplies the i-th module in `cache` order), the combined
    state dict is loaded into `model`, and the returned metric is the loss on
    the example dataset plus a regularization term on the weights.

    NOTE(review): `set_peft_model_state_dict` is not imported anywhere in the
    visible notebook; this definition also replaces the earlier `get_score`.

    Args:
        weights: one coefficient per module in `cache`.
        model: model that receives the combined state dict.
        cache: mapping from module id to that module's state dict; the keys of
            the first module's state dict are used for every module.
        example_dataset: dataset forwarded to `get_loss`.
        batch_size: batch size forwarded to `get_loss`.
        get_loss: callable invoked as get_loss(example_dataset, model, batch_size).
        get_regular: callable invoked as get_regular(weights); the original
            comment describes it as an L1 regularization term.

    Returns:
        get_loss(...) + get_regular(weights)
    """
    module_ids = list(cache.keys())
    # All modules are assumed to share the first module's parameter names.
    param_names = cache[module_ids[0]].keys()

    final_state_dict = {}
    for param_name in param_names:
        # Accumulate left to right, in module order.
        mixed = weights[0] * cache[module_ids[0]][param_name]
        for position in range(1, len(module_ids)):
            mixed = mixed + weights[position] * cache[module_ids[position]][param_name]
        final_state_dict[param_name] = mixed

    # Load the combined adapter weights into the model.
    set_peft_model_state_dict(model, final_state_dict)

    # The metric to minimize: task loss plus the regularization term.
    task_loss = get_loss(example_dataset, model, batch_size)
    penalty = get_regular(weights)
    return task_loss + penalty

In [None]:

# Bind everything except the weight vector, so the optimizer can call the
# objective with the weights alone.
# NOTE(review): `model`, `batch_size`, `get_loss`, `get_regular`,
# `number_of_loras`, `max_inference_step`, `get_final_weights`,
# `lora_module_list` and `set_peft_model_state_dict` are not defined in the
# visible notebook (the model built earlier is bound to `module`), so this
# cell cannot run on a fresh kernel as written.
get_score_partial = partial(get_score, 
                            model=model, 
                            cache=cache,
                            example_dataset=dataset,
                            batch_size=batch_size,
                            get_loss=get_loss, 
                            get_regular=get_regular)
# set up the limit of the weights
# One weight per LoRA module, each starting at 0 and bounded to [-1.5, 1.5].
instrum = ng.p.Array(
    init=[0] * number_of_loras,
    upper=[1.5] * number_of_loras,
    lower=[-1.5] * number_of_loras,
)
optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=max_inference_step)
print("> Begin to perform gradient-free optimization ...")
recommendation = optimizer.minimize(get_score_partial, verbosity=1)
# Turn the recommended weights into a final state dict.
final_lora = get_final_weights(recommendation.value, lora_module_list, cache)
# set the final weights
set_peft_model_state_dict(model, final_lora)
# Rebinds `model` to the result of merge_and_unload().
# NOTE(review): this makes the cell non-idempotent on re-run — confirm the
# merged model still supports a second call.
model = model.merge_and_unload()