In [1]:
## dataset.py

import json 
import pandas as pd

class Dataset(object):
    """In-memory list of datapoints read from a CSV file.

    Every CSV row becomes one dict (column name -> value). When a row's
    ``answer_choices`` value is missing (pandas reads empty cells as NaN),
    ``None`` or otherwise falsy, the key is removed from that datapoint. This
    lets downstream code use ``"answer_choices" in batch`` to tell
    multiple-choice datapoints apart from generation datapoints. CSVs without
    an ``answer_choices`` column are accepted as well.

    NOTE(review): pandas yields scalar cell values. A non-empty
    ``answer_choices`` therefore arrives here as a string, while the tokenize
    helpers iterate it as a list of choices. Confirm how the CSV encodes
    choices.
    """

    def __init__(
        self,
        dataset_filepath: str,
    ):
        self.dataset = pd.read_csv(dataset_filepath).to_dict('records')
        for dp in self.dataset:
            if self._is_missing(dp.get('answer_choices')):
                # pop() with a default also covers datasets that lack the column entirely
                dp.pop('answer_choices', None)

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for None, NaN, or any falsy value (e.g. empty string)."""
        if value is None:
            return True
        # NaN is the only float that is not equal to itself
        if isinstance(value, float) and value != value:
            return True
        return not value

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]
    



In [2]:
## model/decoder_functions.py

import torch
import torch.nn.functional as F

def _expand_past_key_values(past_key_values, batch_size, num_answer_choices):
    """Repeat every cached key/value tensor once per answer choice of its example.

    Two cache layouts are handled:
      * 4-D tensors whose first axis is the batch (Hugging Face's usual layout,
        e.g. [batch_size, num_heads, seq_len, head_dim]);
      * 3-D tensors whose first axis is batch_size x num_heads
        (used by BLOOM in transformers == 4.22.0).

    Returns:
        A list with one tuple of expanded tensors per layer.

    Raises:
        ValueError: if a cached tensor is neither 3-D nor 4-D.
    """
    expanded_past_key_values = []
    for layer_cache in past_key_values:
        expanded_layer = []
        for key_or_value in layer_cache:
            if key_or_value.dim() == 4:
                expanded_layer.append(
                    torch.repeat_interleave(key_or_value, num_answer_choices, dim=0)
                )
            elif key_or_value.dim() == 3:
                # Un-flatten [batch_size x num_heads, ...] so rows are repeated per
                # example (keeping each example's heads together), then flatten back.
                num_heads = key_or_value.shape[0] // batch_size
                per_example = key_or_value.reshape(
                    (batch_size, num_heads) + tuple(key_or_value.shape[1:])
                )
                repeated = torch.repeat_interleave(per_example, num_answer_choices, dim=0)
                expanded_layer.append(repeated.flatten(0, 1))
            else:
                raise ValueError(
                    f"Invalid cached key or value shape: {tuple(key_or_value.shape)}"
                )
        expanded_past_key_values.append(tuple(expanded_layer))
    return expanded_past_key_values


def decoder_predict_multiple_choice(
    transformer, input_tokenizer, target_tokenizer, batch
):
    """Score every answer choice with a decoder-only model and pick the best one.

    Steps:
      1. Encode the input once.
      2. Repeat its KV cache and attention mask for each of the example's
         answer choices.
      3. Run all choices through the model, conditioned on that cache.

    A choice's score is the sum of the log-probabilities of all its
    non-padded tokens, including the first one.

    Args:
        transformer: causal LM exposing ``device`` and a HF-style forward
            returning ``logits`` and ``past_key_values``.
        input_tokenizer: tokenizer for ``batch["input"]``.
        target_tokenizer: tokenizer for ``batch["answer_choices"]``.
        batch: dict with ``"input"`` (list of str) and ``"answer_choices"``
            (list of lists of str). The number of choices per example is
            derived by integer division, so every example is presumed to have
            the same count.

    Returns:
        (predicted_choice, answer_choice_log_probs):
        ``predicted_choice`` has shape [batch_size] and holds the index of the
        best-scoring choice. ``answer_choice_log_probs`` has shape
        [batch_size, num_answer_choices] and holds each choice's summed
        log-probability.
    """
    tokenized_batch = decoder_tokenize_batch(
        input_tokenizer, target_tokenizer, batch, transformer.device
    )
    input_mask = tokenized_batch["input_mask"]
    answer_choices_ids = tokenized_batch["answer_choices_ids"]
    answer_choices_mask = tokenized_batch["answer_choices_mask"]

    output = transformer(
        input_ids=tokenized_batch["input_ids"],
        attention_mask=input_mask,
        use_cache=True,
    )

    batch_size, max_input_len = input_mask.shape
    # answer_choices is flattened across the batch: choices per example = total / batch size
    num_answer_choices = answer_choices_ids.shape[0] // batch_size

    # The input is identical for every choice of an example, so its mask and cache are repeated.
    expanded_input_mask = torch.repeat_interleave(input_mask, num_answer_choices, dim=0)
    expanded_past_key_values = _expand_past_key_values(
        output.past_key_values, batch_size, num_answer_choices
    )

    # [batch_size x num_choices, max_input_len + max_choice_len]
    # The combined mask tells the model which cached input positions are padding.
    combined_mask = torch.cat([expanded_input_mask, answer_choices_mask], dim=1)

    # No labels are passed; log-probs are computed manually from the logits below.
    transformer_outputs = transformer(
        input_ids=answer_choices_ids,
        attention_mask=combined_mask,
        past_key_values=expanded_past_key_values,
        use_cache=True,
    )

    # The distribution over the FIRST answer-choice token is produced at the last
    # real input position of the first pass, not by the second pass. Without it the
    # first choice token would never be scored. Locate that position from the mask
    # so this works for either padding side.
    positions = torch.arange(max_input_len, device=input_mask.device)
    last_input_position = (input_mask.long() * positions).amax(dim=1)
    batch_index = torch.arange(batch_size, device=input_mask.device)
    # [batch_size x num_choices, 1, vocab_size]
    first_token_logits = torch.repeat_interleave(
        output.logits[batch_index, last_input_position], num_answer_choices, dim=0
    ).unsqueeze(1)

    # Position t of the choice pass predicts choice token t + 1: drop its last position
    # and prepend the first-token logits so logits align with answer_choices_ids.
    # [batch_size x num_choices, max_choice_len, vocab_size]
    answer_choice_ids_logits = torch.cat(
        [first_token_logits, transformer_outputs.logits[:, :-1, :]], dim=1
    ).float()
    max_choice_len = answer_choices_ids.shape[1]
    vocab_size = answer_choice_ids_logits.shape[-1]

    # Per-token log-probabilities, flattened to [batch_size x num_choices x max_choice_len]
    answer_choice_ids_log_probs = -F.cross_entropy(
        answer_choice_ids_logits.reshape(-1, vocab_size),
        answer_choices_ids.reshape(-1),
        reduction="none",
    )

    # [batch_size, num_answer_choices, max_choice_len]
    answer_choice_ids_log_probs = answer_choice_ids_log_probs.reshape(
        -1, num_answer_choices, max_choice_len
    )
    per_choice_mask = answer_choices_mask.reshape(-1, num_answer_choices, max_choice_len)

    # Padded positions are zeroed out before summing.
    answer_choice_log_probs = torch.sum(answer_choice_ids_log_probs * per_choice_mask, dim=2)

    _, predicted_choice = torch.max(answer_choice_log_probs, dim=1)

    return predicted_choice, answer_choice_log_probs



def decoder_generate(
    transformer,
    input_tokenizer,
    target_tokenizer,
    batch,
    max_gen_len
):
    """Greedily generate a continuation for every prompt in ``batch``.

    Args:
        transformer: causal LM exposing ``device`` and ``generate``.
        input_tokenizer: tokenizer for the prompts. It also supplies the
            eos/pad/bos token ids and decodes the output.
        target_tokenizer: forwarded to ``decoder_tokenize_batch``.
        batch: dict with at least an ``"input"`` list of prompt strings.
        max_gen_len: maximum number of new tokens to generate.

    Returns:
        (sequences, texts):
        ``sequences`` holds the full id sequence returned by ``generate``
        (prompt ids included), as nested Python lists.
        ``texts`` decodes only the ids that follow the prompt, with special
        tokens skipped.
    """
    encoded = decoder_tokenize_batch(
        input_tokenizer, target_tokenizer, batch, transformer.device
    )
    prompt_ids = encoded["input_ids"]

    generation_kwargs = dict(
        input_ids=prompt_ids,
        attention_mask=encoded["input_mask"],
        max_new_tokens=max_gen_len,
        eos_token_id=input_tokenizer.eos_token_id,
        pad_token_id=input_tokenizer.pad_token_id,
        bos_token_id=input_tokenizer.bos_token_id,
        do_sample=False,
        return_dict_in_generate=True,
    )
    result = transformer.generate(**generation_kwargs)
    sequences = result["sequences"]

    # The first prompt_len ids of each row are the prompt; decode only what follows.
    prompt_len = prompt_ids.shape[-1]
    new_text = input_tokenizer.batch_decode(
        sequences[:, prompt_len:], skip_special_tokens=True
    )

    return sequences.cpu().numpy().tolist(), new_text

def decoder_tokenize_batch(input_tokenizer, target_tokenizer, batch, device):
    """Tokenize the text fields of a batch for a decoder-only model.

    Only keys present in ``batch`` are tokenized:
      * ``"input"``          -> ``input_tokenizer``
      * ``"answer_choices"`` -> ``target_tokenizer``. The value is a list of
        lists; it is flattened so every choice of every example becomes one row.
      * ``"target"``         -> ``target_tokenizer``

    Args:
        input_tokenizer: tokenizer for prompts.
        target_tokenizer: tokenizer for answer choices and targets.
        batch: dict of lists (one entry per example).
        device: if not None, the resulting tensors are moved to this device.

    Returns:
        dict mapping ``"{key}_ids"`` / ``"{key}_mask"`` to the padded
        ``input_ids`` / ``attention_mask`` tensors for every tokenized key.
    """
    tokenized_batch = {}

    keys_to_tokenize_with_tokenizer = [
        ("input", input_tokenizer),
        ("answer_choices", target_tokenizer),
        ("target", target_tokenizer),
    ]

    for key, tokenizer in keys_to_tokenize_with_tokenizer:
        # All tokenization must stay behind this check. Otherwise an absent key
        # would re-tokenize the previous key's text (or hit an undefined name).
        if key not in batch:
            continue

        if key == "answer_choices":
            # list of lists of choices -> one flat list of choices
            text = [choice for choices in batch[key] for choice in choices]
        else:
            text = batch[key]

        tokenized_dict = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation="longest_first",
        )

        input_ids = tokenized_dict["input_ids"]
        attention_mask = tokenized_dict["attention_mask"]

        if device is not None:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

        tokenized_batch[f"{key}_ids"] = input_ids
        tokenized_batch[f"{key}_mask"] = attention_mask

    return tokenized_batch

In [3]:
## model/encoder_decoder_functions.py

import torch
import torch.nn.functional as F


def compute_loss(transformer, tokenizer, batch):
    """Mean token-level cross-entropy of ``batch["target"]`` given ``batch["input"]``.

    The model's built-in loss is not used. The target ids passed as
    ``labels`` are padded with the tokenizer's pad id, so padded positions
    would be counted. Instead, per-token losses are recomputed from the
    logits and averaged over the positions where ``target_mask`` is 1.

    Args:
        transformer: encoder-decoder model exposing ``device``; its output's
            element 1 is the logits tensor.
        tokenizer: tokenizer used by ``encoder_decoder_tokenize_batch``.
        batch: dict with ``"input"`` and ``"target"`` lists of strings.

    Returns:
        Scalar tensor: summed masked negative log-likelihood / number of
        unmasked target tokens.
    """
    encoded = encoder_decoder_tokenize_batch(tokenizer, batch, transformer.device)
    target_ids = encoded["target_ids"]

    outputs = transformer(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["input_mask"],
        labels=target_ids,
    )

    # [batch_size, max_target_len, vocab_size]
    logits = outputs[1].float()
    num_classes = logits.shape[-1]

    # Per-token negative log-likelihood, flattened to [batch_size x max_target_len]
    token_nll = F.cross_entropy(
        logits.reshape(-1, num_classes),
        target_ids.reshape(-1),
        reduction="none",
    )

    # Padded target positions carry mask 0 and therefore contribute nothing.
    flat_mask = encoded["target_mask"].reshape(-1)
    masked_total = (token_nll * flat_mask).sum()

    return masked_total / flat_mask.sum()

def encoder_decoder_predict_multiple_choice(
    transformer, tokenizer, batch
):
    """Score every answer choice with an encoder-decoder model and pick the best.

    The encoder runs once per example. Its last-layer hidden states and the
    input mask are repeated for each of that example's answer choices. The
    decoder then scores every choice, passed as the label sequence.

    Args:
        transformer: encoder-decoder model exposing ``device`` and
            ``get_encoder()``.
        tokenizer: tokenizer used by ``encoder_decoder_tokenize_batch``.
        batch: dict with ``"input"`` (list of str) and ``"answer_choices"``
            (list of lists of str).

    Returns:
        (predicted_choice, answer_choice_log_probs):
        ``predicted_choice`` has shape [batch_size] and holds the index of the
        best-scoring choice. ``answer_choice_log_probs`` has shape
        [batch_size, num_answer_choices] and holds each choice's summed
        log-probability over its non-padded tokens.
    """
    encoded = encoder_decoder_tokenize_batch(tokenizer, batch, transformer.device)
    input_mask = encoded["input_mask"]
    choice_ids = encoded["answer_choices_ids"]

    encoder_outputs = transformer.get_encoder()(
        encoded["input_ids"],
        input_mask,
    )

    # answer_choices is flattened across the batch, so choices per example = total / batch size.
    choices_per_example = choice_ids.shape[0] // input_mask.shape[0]

    def _repeat_per_choice(tensor):
        return torch.repeat_interleave(tensor, choices_per_example, dim=0)

    # [batch_size x num_choices, max_input_len]
    repeated_mask = _repeat_per_choice(input_mask)
    # Only element 0 of the encoder output (last-layer hidden states) is reused,
    # wrapped in a 1-tuple: [batch_size x num_choices, max_input_len, hidden_dim]
    repeated_encoder_outputs = (_repeat_per_choice(encoder_outputs[0]),)

    # The loss in outputs[0] is not meaningful here: choice ids are padded with the
    # tokenizer's pad id rather than an ignore index, so padding is not excluded.
    # The input mask is used for encoder-decoder cross attention.
    outputs = transformer(
        attention_mask=repeated_mask,
        encoder_outputs=repeated_encoder_outputs,
        labels=choice_ids,
    )

    # [batch_size x num_choices, max_choice_len, vocab_size]
    logits = outputs[1].float()
    _, max_choice_len, num_classes = logits.shape

    # Per-token log-probabilities -> [batch_size, num_answer_choices, max_choice_len]
    token_log_probs = -F.cross_entropy(
        logits.view(-1, num_classes),
        choice_ids.view(-1),
        reduction="none",
    )
    token_log_probs = token_log_probs.reshape(-1, choices_per_example, max_choice_len)

    choice_mask = encoded["answer_choices_mask"].reshape(
        -1, choices_per_example, max_choice_len
    )

    answer_choice_log_probs = (token_log_probs * choice_mask).sum(dim=2)
    _, predicted_choice = answer_choice_log_probs.max(dim=1)

    return predicted_choice, answer_choice_log_probs


def encoder_decoder_generate(
    transformer,
    tokenizer,
    batch,
    max_gen_len
):
    """Greedy generation for an encoder-decoder model.

    Args:
        transformer: encoder-decoder model exposing ``device`` and ``generate``.
        tokenizer: tokenizer providing the special-token ids and decoding.
        batch: dict with at least an ``"input"`` list of strings.
        max_gen_len: maximum number of new tokens to generate.

    Returns:
        (sequences, texts):
        ``sequences`` holds the ids returned by ``generate``, as nested
        Python lists.
        ``texts`` decodes those full sequences with special tokens skipped;
        no prompt slicing is applied, unlike ``decoder_generate``.
    """
    encoded = encoder_decoder_tokenize_batch(tokenizer, batch, transformer.device)

    result = transformer.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["input_mask"],
        max_new_tokens=max_gen_len,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        do_sample=False,
        return_dict_in_generate=True,
    )
    sequences = result["sequences"]

    texts = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    id_lists = sequences.cpu().numpy().tolist()

    return id_lists, texts

def encoder_decoder_tokenize_batch(tokenizer, batch, device):
    """Tokenize the text fields of a batch for an encoder-decoder model.

    The keys ``"input"``, ``"answer_choices"`` and ``"target"`` are tokenized
    when present; absent keys are skipped. ``answer_choices`` holds a list of
    lists and is flattened so each choice of each example becomes one row.

    Side effect: sets ``tokenizer.padding_side = "right"`` on the passed-in
    tokenizer.

    Args:
        tokenizer: HF-style tokenizer callable.
        batch: dict of lists (one entry per example).
        device: if not None, the resulting tensors are moved to this device.

    Returns:
        dict mapping ``"{key}_ids"`` / ``"{key}_mask"`` to the padded
        ``input_ids`` / ``attention_mask`` tensors for every tokenized key.
    """
    # encoder-decoder models pad to the right
    tokenizer.padding_side = "right"

    def _texts_for(key):
        values = batch[key]
        if key == "answer_choices":
            return [choice for choices in values for choice in choices]
        return values

    tokenized_batch = {}
    for key in ("input", "answer_choices", "target"):
        if key not in batch:
            continue

        encoded = tokenizer(
            _texts_for(key),
            return_tensors="pt",
            padding="longest",
            truncation="longest_first",
        )
        ids, mask = encoded["input_ids"], encoded["attention_mask"]

        if device is not None:
            ids, mask = ids.to(device), mask.to(device)

        tokenized_batch[f"{key}_ids"] = ids
        tokenized_batch[f"{key}_mask"] = mask

    return tokenized_batch


In [4]:
## merges/Merges.py
# !pip install peft
import copy
import os

from peft import load_peft_weights, PeftConfig
from safetensors.torch import save_file


from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoTokenizer
)

class Merges(object):
    """Base class for a model-merging method.

    Subclasses set their hyperparameters in ``__init__`` and implement
    ``merge``. The hyperparameters include ``list_models``,
    ``base_model_name``, ``base_model_revision_id``, ``architecture``,
    ``device``, ``max_seq_len`` and ``max_gen_len``.

    This class provides:
      * loading of PEFT checkpoints, the base model and the tokenizers;
      * inference helpers (``predict_multiple_choice`` / ``generate``);
      * ``save_model``.

    The Hugging Face token is read from the ``HF_AUTH_TOKEN`` environment
    variable. When it is unset, ``None`` is passed through.
    """

    def __init__(self, name):
        self.name = name

        self.list_models = None

        self.loaded_models = None
        self.loaded_configs = None
        self.base_model = None
        self.tokenizer = None
        self.input_tokenizer = None
        self.target_tokenizer = None

        self.base_model_name = None
        self.base_model_revision_id = None

        self.max_seq_len = None
        self.max_gen_len = None

        self.device = None
        self.architecture = None

        self.merged_model = None

    def get_name(self):
        return self.name

    def get_model_config(self):
        raise NotImplementedError

    @staticmethod
    def _hf_token():
        """Return the Hugging Face token from ``HF_AUTH_TOKEN``, or None if it is unset."""
        return os.environ.get("HF_AUTH_TOKEN")

    def _auto_model_class(self):
        """Return the ``AutoModel*`` class matching ``self.architecture``.

        Raises:
            NotImplementedError: for an unknown architecture.
        """
        if self.architecture == "encoder_decoder":
            return AutoModelForSeq2SeqLM
        if self.architecture == "decoder":
            return AutoModelForCausalLM
        raise NotImplementedError(f"Architecture not implemented {self.architecture}")

    def _load_base_model(self):
        """Load ``base_model_name`` at ``base_model_revision_id`` onto ``self.device``."""
        model_class = self._auto_model_class()
        self.base_model = model_class.from_pretrained(
            self.base_model_name,
            revision=self.base_model_revision_id,
            token=self._hf_token(),
        ).to(self.device)

    def _load_pretrained_tokenizer(self):
        """Load the base model's tokenizer with this merge's revision and max length."""
        return AutoTokenizer.from_pretrained(
            self.base_model_name,
            revision=self.base_model_revision_id,
            model_max_length=self.max_seq_len,
            legacy=False,
            token=self._hf_token(),
        )

    def _load_tokenizer(self):
        """Create the tokenizer(s) for ``self.architecture`` unless already loaded.

        * encoder_decoder: a single ``self.tokenizer``.
        * decoder: ``self.input_tokenizer`` (left padding) plus a deep copy,
          ``self.target_tokenizer`` (right padding, no BOS, appends EOS).
          Missing pad tokens fall back to EOS.

        Raises:
            NotImplementedError: for an unknown architecture.
        """
        if self.architecture == "encoder_decoder":
            if self.tokenizer is None:
                self.tokenizer = self._load_pretrained_tokenizer()

        elif self.architecture == "decoder":
            if self.input_tokenizer is None or self.target_tokenizer is None:
                self.input_tokenizer = self._load_pretrained_tokenizer()
                self.target_tokenizer = copy.deepcopy(self.input_tokenizer)

                # Use eos_token for pad_token if it doesn't exist. This is ok since the
                # pad tokens will be ignored through the mask
                if self.input_tokenizer.pad_token_id is None:
                    self.input_tokenizer.pad_token_id = self.input_tokenizer.eos_token_id
                if self.target_tokenizer.pad_token_id is None:
                    self.target_tokenizer.pad_token_id = self.target_tokenizer.eos_token_id

                # Prompts are left-padded so each row ends with its last real token
                self.input_tokenizer.padding_side = "left"

                # Answer choices / targets: right padding, no BOS, terminated by EOS
                self.target_tokenizer.padding_side = "right"
                self.target_tokenizer.add_bos_token = False
                self.target_tokenizer.add_eos_token = True
        else:
            raise NotImplementedError(f"Architecture not implemented {self.architecture}")

    def predict_multiple_choice(self, batch):
        """Dispatch multiple-choice scoring to the architecture-specific function."""
        assert self.base_model is not None
        if self.architecture == "encoder_decoder":
            assert self.tokenizer is not None
            return encoder_decoder_predict_multiple_choice(self.base_model, self.tokenizer, batch)
        elif self.architecture == "decoder":
            assert self.input_tokenizer is not None and self.target_tokenizer is not None
            return decoder_predict_multiple_choice(self.base_model, self.input_tokenizer, self.target_tokenizer, batch)
        else:
            raise NotImplementedError(f"Architecture not implemented {self.architecture}")

    def generate(self, batch):
        """Dispatch greedy generation to the architecture-specific function."""
        assert self.base_model is not None
        if self.architecture == "encoder_decoder":
            assert self.tokenizer is not None
            return encoder_decoder_generate(self.base_model, self.tokenizer, batch, self.max_gen_len)
        elif self.architecture == "decoder":
            assert self.input_tokenizer is not None and self.target_tokenizer is not None
            return decoder_generate(self.base_model, self.input_tokenizer, self.target_tokenizer, batch, self.max_gen_len)
        else:
            raise NotImplementedError(f"Architecture not implemented {self.architecture}")

    def _load_huggingface_models_and_configs(self):
        """Load each PEFT adapter in ``list_models`` into ``loaded_models`` / ``loaded_configs``.

        Each element of ``list_models`` is a ``(repo_id, revision_id)`` tuple.
        The adapter weights and the adapter config are both fetched at that
        revision. A warning is printed when an adapter's parameter names differ
        from those of the first adapter.
        """
        assert self.list_models, "List of models must include at least 1 model"

        if self.loaded_models is None:
            self.loaded_models = {}
        if self.loaded_configs is None:
            self.loaded_configs = {}

        token = self._hf_token()
        reference_model_name = self.list_models[0][0]
        parameter_names = None
        for model_name, revision_id in self.list_models:

            peft_model_parameters = load_peft_weights(model_name, revision=revision_id, token=token)
            # Pin the config to the same revision as the weights
            peft_config = PeftConfig.from_pretrained(model_name, revision=revision_id, token=token)

            if parameter_names is None:
                parameter_names = set(peft_model_parameters.keys())

            if parameter_names != set(peft_model_parameters.keys()):
                print(f"WARNING: parameters in {model_name} do not match {reference_model_name}")

            self.loaded_models[model_name] = peft_model_parameters
            self.loaded_configs[model_name] = peft_config

    def merge(
        self,
    ):
        raise NotImplementedError

    def save_model(self, output_dir):
        """Write ``merged_model`` to ``<output_dir>/safetensor.pt`` in safetensors format.

        Raises:
            AssertionError: if nothing has been merged yet.
        """
        assert self.merged_model is not None, "Merged model is empty"
        assert len(self.merged_model) > 0, "Merged model is empty"
        os.makedirs(output_dir, exist_ok=True)
        # safetensors rejects non-contiguous tensors
        tensors = {name: tensor.contiguous() for name, tensor in self.merged_model.items()}
        save_file(tensors, os.path.join(output_dir, "safetensor.pt"))

Collecting peft
  Downloading peft-0.12.0-py3-none-any.whl (296 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m296.4/296.4 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: peft
Successfully installed peft-0.12.0
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m24.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
## merges/LlamaAvg.py

import torch 

from peft import get_peft_model, set_peft_model_state_dict

class LlamaAvg(Merges):
    """Weighted average of Llama-2-7B LoRA adapters, applied to the base model through PEFT.

    The adapter tensors are averaged element-wise with ``parameter_lambdas``.
    An adapter whose LoRA rank is smaller than the first adapter's is padded
    with zeros up to that rank before averaging (see ``_pad_lora_rank``).
    """
    def __init__(self, name):
        super().__init__(name)

        '''
        These values are meant to be modified by the user.
        '''
        # Models to load for the merge. Each element is a (model, revision_id) tuple.
        # We recommend specifying a revision id to ensure the model was not modified after May 31
        self.list_models = [("abcdabcd987/gsm8k-llama2-7b-lora-16", "636b5eb8da724edae406ba69ef90fd06478e6df7"),
                            ("FinGPT/fingpt-forecaster_dow30_llama2-7b_lora", "69f77190315afdb03a889d89bf2a0f932b311617")]

        # Hyperparameters
        self.base_model_name = "meta-llama/Llama-2-7b-hf"
        # We recommend specifying a revision id to ensure the model was not modified after May 31
        self.base_model_revision_id = "01c7f73d771dfac7d292323805ebc428287df4f9"

        # Averaging weight for each entry of list_models (same order)
        self.parameter_lambdas = [0.8, 0.2]

        self.max_seq_len = None
        self.max_gen_len = 64
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Architecture must match base model.
        self.architecture = "decoder"
        '''
        These are variables used later in the code and not intended to be set, but feel free to adapt to your use case.
        '''
        # Loaded models and configs
        self.loaded_models = {}
        self.loaded_configs = {}

        # Merged model parameters
        self.merged_model = {}

    @staticmethod
    def _pad_lora_rank(parameter_name, parameter, target_shape):
        """Zero-pad a lower-rank LoRA factor so that it matches ``target_shape``.

        The rank axis is 0 for ``lora_A`` weights and 1 for ``lora_B`` weights.
        Zeros are prepended along that axis, so the smaller adapter occupies
        the last slots.

        Returns:
            ``parameter`` unchanged if it already has ``target_shape``,
            otherwise the padded tensor.

        Raises:
            ValueError: if the mismatch cannot be resolved by padding the rank
                axis up.
        """
        if parameter.shape == target_shape:
            return parameter

        if "lora_A" in parameter_name:
            rank_dim = 0
        elif "lora_B" in parameter_name:
            rank_dim = 1
        else:
            rank_dim = None

        paddable = (
            rank_dim is not None
            and parameter.dim() == len(target_shape)
            and all(
                parameter.shape[d] == target_shape[d]
                for d in range(parameter.dim())
                if d != rank_dim
            )
            and target_shape[rank_dim] > parameter.shape[rank_dim]
        )
        if not paddable:
            raise ValueError(
                f"Cannot merge {parameter_name}: shape {tuple(parameter.shape)} "
                f"is incompatible with {tuple(target_shape)}"
            )

        pad_shape = list(parameter.shape)
        pad_shape[rank_dim] = target_shape[rank_dim] - parameter.shape[rank_dim]
        zeros = parameter.new_zeros(pad_shape)
        return torch.cat([zeros, parameter], dim=rank_dim)

    # Implement merge function
    def merge(
        self,
    ):
        """Load the adapters, average them, and load the result into the base model.

        Returns:
            The base model, wrapped with PEFT when the first adapter has a
            config, in eval mode.

        Raises:
            ValueError: if the number of ``parameter_lambdas`` differs from the
                number of loaded models, or if adapter shapes cannot be
                reconciled.
        """
        '''
        1) Load HuggingFace checkpoints and configs
        '''
        super()._load_huggingface_models_and_configs()
        '''
        2) Merge checkpoints
        '''
        parameter_lambdas = self.parameter_lambdas

        # Get individual models
        all_models = list(self.loaded_models.values())
        if len(parameter_lambdas) != len(all_models):
            raise ValueError(
                f"Expected one lambda per model, got {len(parameter_lambdas)} "
                f"lambdas for {len(all_models)} models"
            )

        # Parameter names come from the first model; other models may lack some of them
        all_parameter_names = all_models[0].keys()

        for parameter_name in all_parameter_names:
            merged_parameter = None
            for parameter_lambda, model in zip(parameter_lambdas, all_models):
                if parameter_name not in model:
                    # A module this adapter does not target has a zero LoRA update
                    continue
                parameter = model[parameter_name]
                if merged_parameter is None:
                    merged_parameter = torch.clone(parameter) * parameter_lambda
                else:
                    # Adapters may differ in LoRA rank (e.g. 16 vs 8): pad up to the first one's rank
                    parameter = self._pad_lora_rank(parameter_name, parameter, merged_parameter.shape)
                    merged_parameter += parameter * parameter_lambda
            self.merged_model[parameter_name] = merged_parameter

        '''
        3) Load base model and tokenizer
        '''
        self._load_base_model()
        self._load_tokenizer()

        '''
        4) Load merged model into base model
        '''
        # PEFT adapters need the base model wrapped in a Peft wrapper first.
        huggingface_config = list(self.loaded_configs.values())[0]
        if huggingface_config is not None:
            self.base_model = get_peft_model(self.base_model, huggingface_config)
            set_peft_model_state_dict(self.base_model, self.merged_model)

        else:
            # Plain checkpoints: load the merged weights directly (nn.Module has no `load`).
            self.base_model.load_state_dict(self.merged_model)

        # Eval mode (e.g. disables dropout) so results are deterministic across runs.
        self.base_model.eval()

        return self.base_model

In [6]:
## merges/FlanT5Avg.py

import torch 

from peft import get_peft_model, set_peft_model_state_dict

class FlanT5Avg(Merges):
    """Weighted average of Flan-T5-XL LoRA adapters, applied to the base model through PEFT."""
    def __init__(self, name):
        super().__init__(name)

        '''
        These values are meant to be modified by the user.
        '''
        # Models to load for the merge. Each element is a (model, revision_id) tuple.
        self.list_models = [("lorahub/flan_t5_xl-wiki_qa_Is_This_True_", "30a1ee2f857196c1eb996d854548cc19f45ac642"),
                            ("lorahub/flan_t5_xl-kilt_tasks_hotpotqa_complex_question", "27d014366bec1c5333ba2e2fae966b7de3c02df1")]

        # Hyperparameters
        self.base_model_name = "google/flan-t5-xl"
        self.base_model_revision_id = "7d6315df2c2fb742f0f5b556879d730926ca9001"
        # Averaging weight for each entry of list_models (same order)
        self.parameter_lambdas = [0.5, 0.5]
        self.max_seq_len = 512
        # NOTE(review): max_gen_len is not set here and stays None (from Merges.__init__),
        # so generate() receives max_new_tokens=None. Confirm the model's default length limit is intended.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Architecture must match base model.
        self.architecture = "encoder_decoder"

        '''
        These are variables used later in the code and not intended to be set, but feel free to adapt to your use case.
        '''
        # Loaded models and configs
        self.loaded_models = {}
        self.loaded_configs = {}

        # Merged model parameters
        self.merged_model = {}

    # Implement merge function
    def merge(
        self,
    ):
        """Load the adapters, average them, and load the result into the base model.

        Returns:
            The base model, wrapped with PEFT when the first adapter has a
            config, in eval mode.

        Raises:
            ValueError: if the number of ``parameter_lambdas`` differs from the
                number of loaded models.
        """
        '''
        1) Load HuggingFace checkpoints and configs
        '''
        super()._load_huggingface_models_and_configs()

        '''
        2) Merge checkpoints
        '''
        parameter_lambdas = self.parameter_lambdas

        # Get individual models
        all_models = list(self.loaded_models.values())
        if len(parameter_lambdas) != len(all_models):
            raise ValueError(
                f"Expected one lambda per model, got {len(parameter_lambdas)} "
                f"lambdas for {len(all_models)} models"
            )

        # Parameter names come from the first model; other models may lack some of them
        all_parameter_names = all_models[0].keys()

        for parameter_name in all_parameter_names:
            merged_parameter = None
            for parameter_lambda, model in zip(parameter_lambdas, all_models):
                if parameter_name not in model:
                    # A module this adapter does not target has a zero LoRA update
                    continue
                parameter = model[parameter_name]
                if merged_parameter is None:
                    merged_parameter = torch.clone(parameter) * parameter_lambda
                else:
                    merged_parameter += parameter * parameter_lambda
            self.merged_model[parameter_name] = merged_parameter
        '''
        3) Load base model and tokenizer
        '''
        self._load_base_model()
        self._load_tokenizer()

        '''
        4) Load merged model into base model
        '''
        # PEFT adapters need the base model wrapped in a Peft wrapper first.
        huggingface_config = list(self.loaded_configs.values())[0]
        if huggingface_config is not None:
            self.base_model = get_peft_model(self.base_model, huggingface_config)
            set_peft_model_state_dict(self.base_model, self.merged_model)

        else:
            # Plain checkpoints: load the merged weights directly (nn.Module has no `load`).
            self.base_model.load_state_dict(self.merged_model)

        # Eval mode (e.g. disables dropout) so results are deterministic across runs.
        self.base_model.eval()

        return self.base_model

In [None]:
## merges/TinyLlamaAvg.py

import torch 

# from llm_merging.merging.Merges import Merges
from peft import get_peft_model, set_peft_model_state_dict

class TinyLlamaAvg(Merges):
    """Weighted average of full TinyLlama checkpoints (not PEFT adapters).

    With ``is_peft = False`` every entry of ``list_models`` is loaded as a
    complete causal-LM checkpoint. The state dicts are averaged with
    ``parameter_lambdas``, and the result is loaded into the base model with
    ``load_state_dict``. Set ``is_peft = True`` to merge PEFT adapters instead.
    """
    def __init__(self, name):
        super().__init__(name)

        '''
        These values are meant to be modified by the user.
        '''
        # Models to load for the merge. Each element is a (model, revision_id) tuple.
        # We recommend specifying a revision id to ensure the model was not modified after May 31
        self.list_models = [("TinyLlama/TinyLlama_v1.1", "f67f7cf6a907e567552b946699a9b9b45394fc46"),
                            ("TinyLlama/TinyLlama_v1.1_math_code", "36978c95f61ba8078250f04d71b5404fa9733614")]

        # Hyperparameters
        self.base_model_name = "TinyLlama/TinyLlama_v1.1"
        # We recommend specifying a revision id to ensure the model was not modified after May 31
        self.base_model_revision_id = "f67f7cf6a907e567552b946699a9b9b45394fc46"
        # list_models are full checkpoints, not PEFT adapters
        self.is_peft = False
        # Averaging weight for each entry of list_models (same order)
        self.parameter_lambdas = [0.8, 0.2]

        self.max_seq_len = None
        self.max_gen_len = 64
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Architecture must match base model.
        self.architecture = "decoder"
        '''
        These are variables used later in the code and not intended to be set, but feel free to adapt to your use case.
        '''
        # Loaded models and configs
        self.loaded_models = {}
        self.loaded_configs = {}

        # Merged model parameters
        self.merged_model = {}

    def _load_full_models(self):
        """Load each full checkpoint in ``list_models`` and store its ``state_dict()``.

        Fills ``loaded_models[model_name]`` with the state dict and sets
        ``loaded_configs[model_name]`` to None (there is no PEFT config).
        """
        assert len(self.list_models) > 0, "List of models must include at least 1 model"
        token = os.environ.get("HF_AUTH_TOKEN")
        for model_name, revision_id in self.list_models:
            model = AutoModelForCausalLM.from_pretrained(model_name, revision=revision_id, token=token)
            self.loaded_models[model_name] = model.state_dict()
            self.loaded_configs[model_name] = None
            del model

    # Implement merge function
    def merge(
        self,
    ):
        """Load the checkpoints, average them, and load the result into the base model.

        Returns:
            The base model in eval mode (wrapped with PEFT when ``is_peft``).

        Raises:
            ValueError: if the number of ``parameter_lambdas`` differs from the
                number of loaded models.
        """
        '''
        1) Load HuggingFace checkpoints and configs
        '''
        if self.is_peft:
            super()._load_huggingface_models_and_configs()
        else:
            # load_peft_weights cannot read full checkpoints, so load them as state dicts
            self._load_full_models()
        '''
        2) Merge checkpoints
        '''
        parameter_lambdas = self.parameter_lambdas

        # Get individual models
        all_models = list(self.loaded_models.values())
        if len(parameter_lambdas) != len(all_models):
            raise ValueError(
                f"Expected one lambda per model, got {len(parameter_lambdas)} "
                f"lambdas for {len(all_models)} models"
            )

        # Get all the parameter names (uses the first model and assumes all the models have the same parameters)
        all_parameter_names = all_models[0].keys()

        for parameter_name in all_parameter_names:
            merged_parameter = None
            for parameter_lambda, model in zip(parameter_lambdas, all_models):
                parameter = model[parameter_name]
                if merged_parameter is None:
                    merged_parameter = torch.clone(parameter) * parameter_lambda
                else:
                    merged_parameter += parameter * parameter_lambda
            self.merged_model[parameter_name] = merged_parameter

        '''
        3) Load base model and tokenizer
        '''
        self._load_base_model()
        self._load_tokenizer()

        '''
        4) Load merged model into base model
        '''
        huggingface_config = list(self.loaded_configs.values())[0]
        if huggingface_config is not None:
            # PEFT adapters need the base model wrapped in a Peft wrapper first.
            self.base_model = get_peft_model(self.base_model, huggingface_config)
            set_peft_model_state_dict(self.base_model, self.merged_model)
        else:
            self.base_model.load_state_dict(self.merged_model)

        # Eval mode (e.g. disables dropout) so results are deterministic across runs.
        self.base_model.eval()

        return self.base_model

In [8]:
## evaluate.py
# !pip install evaluate

import evaluate 
import json 
import os 
import pandas as pd

from typing import List, Dict, Any

import torch
from tqdm import tqdm
from torch.utils import data


def convert_dict_of_lists_to_list_of_dicts(dict_of_lists: Dict[Any, List]) -> List[Dict]:
    """Transpose a column-oriented dict into a list of row dicts.

    Example: ``{"a": [1, 2], "b": [3, 4]}`` -> ``[{"a": 1, "b": 3}, {"a": 2, "b": 4}]``.

    As with ``zip``, the number of rows equals the length of the shortest
    value. Every value is iterated, so a string value is split into its
    characters.

    Args:
        dict_of_lists: mapping from key to the per-datapoint values.

    Returns:
        One dict per datapoint, with keys in the original insertion order.
    """
    keys = list(dict_of_lists)
    columns = [dict_of_lists[key] for key in keys]
    return [dict(zip(keys, row)) for row in zip(*columns)]

def collate_fn(batch_of_datapoints: List[Dict]) -> Dict[Any, List]:
    """Gather a list of datapoint dicts into one dict of lists.

    This replaces PyTorch's default collate function. Values are kept as-is
    (no tensor stacking), so list-valued fields such as answer choices
    survive intact.

    Args:
        batch_of_datapoints: datapoints as produced by the dataset.

    Returns:
        dict mapping each key seen in any datapoint to the list of its values,
        in datapoint order.
    """
    datapoint_batched: Dict[Any, List] = {}
    for datapoint in batch_of_datapoints:
        for key, value in datapoint.items():
            datapoint_batched.setdefault(key, []).append(value)
    return datapoint_batched


def evaluate_dataset(
    merge_method,
    dataset_filepath: str,
) -> List[Dict]:
    """
    Run ``merge_method`` over every datapoint of a CSV dataset and collect predictions.

    Each batch's ``eval_type`` selects how it is evaluated:
      1) ``"multiple_choice"``: the model scores each answer choice and the
         predicted choice index is stored (as a string) as the prediction;
      2) ``"generation"``: the model generates output text from the input.

    Args:
        merge_method: object exposing ``predict_multiple_choice(batch)`` and
            ``generate(batch)``.
        dataset_filepath: path to the CSV file loaded through ``Dataset``.

    Returns:
        One dict per datapoint holding the datapoint's original fields plus a
        ``"prediction"`` key.

    Raises:
        ValueError: if a batch has an ``eval_type`` other than the two above.
    """
    data_loader = data.DataLoader(
        Dataset(dataset_filepath),
        batch_size=1,
        num_workers=0,
        shuffle=False,
        collate_fn=collate_fn,
    )

    all_batches = []

    with torch.no_grad():
        for batch in tqdm(data_loader):
            # The whole batch is treated as sharing the first datapoint's
            # eval_type (trivially true with batch_size=1).
            eval_type = batch["eval_type"][0]

            if eval_type == "multiple_choice":
                predicted_choice, _answer_choice_scores = merge_method.predict_multiple_choice(batch)
                # Store one string per datapoint. A bare string here would be
                # zipped character-by-character by
                # convert_dict_of_lists_to_list_of_dicts, silently keeping only
                # the first digit of indices >= 10 (e.g. "12" -> "1").
                # Assumes predicted_choice holds one entry per datapoint.
                batch["prediction"] = [
                    str(choice) for choice in predicted_choice.cpu().numpy().tolist()
                ]
            elif eval_type == "generation":
                _generated_ids, generated_txt = merge_method.generate(batch)
                # NOTE(review): generate() is expected to return one text per
                # datapoint; wrap a bare string so it is not split into characters.
                if isinstance(generated_txt, str):
                    generated_txt = [generated_txt]
                batch["prediction"] = generated_txt
            else:
                raise ValueError(
                    f"Unsupported eval_type {eval_type!r}; expected 'multiple_choice' or 'generation'"
                )

            all_batches.extend(convert_dict_of_lists_to_list_of_dicts(batch))

    return all_batches

def evaluate_model(
    merge_method,
    all_dataset_filepaths: List[str],
    submission_path: str = "../submission.csv",
) -> None:
    """
    Save the merged model, predict on every dataset, and write one submission CSV.

    Args:
        merge_method: merge-method object exposing ``get_name()``,
            ``save_model(output_dir)`` and the prediction methods that
            ``evaluate_dataset`` calls.
        all_dataset_filepaths: CSV datasets to predict on.
        submission_path: destination of the combined CSV with columns
            ``id, prediction, dummy_field``. Defaults to the previously
            hard-coded ``"../submission.csv"``.
    """
    output_dir = os.path.join("output", merge_method.get_name())
    # NOTE(review): prediction_dir is created but nothing in this function writes to it.
    prediction_dir = os.path.join(output_dir, "predictions")
    os.makedirs(prediction_dir, exist_ok=True)
    # Save merged model
    merge_method.save_model(output_dir)

    # Gather every dataset's predictions before writing. Writing inside the
    # loop overwrote the submission file each time, keeping only the last dataset.
    prediction_frames = [
        pd.DataFrame(evaluate_dataset(merge_method, dataset_filepath))
        for dataset_filepath in all_dataset_filepaths
    ]
    if not prediction_frames:
        return

    submission_df = pd.concat(prediction_frames, ignore_index=True)
    submission_df["dummy_field"] = 0
    submission_df.to_csv(
        submission_path,
        columns=["id", "prediction", "dummy_field"],
        index=False,
    )

Collecting evaluate
  Downloading evaluate-0.4.2-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: evaluate
Successfully installed evaluate-0.4.2

[notice] A new release of pip is available: 23.0.1 -> 24.2
[notice] To update, run: pip install --upgrade pip


E0000 00:00:1725599853.674774      13 common_lib.cc:798] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:479
E0906 05:17:33.706609309     256 oauth2_credentials.cc:238]            oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:"2024-09-06T05:17:33.706593012+00:00"}


In [None]:
## main.py

def all_merge_handlers():
    """Return the registry that maps each merge-method name to its handler class."""
    return dict(
        llama_avg=LlamaAvg,
        tiny_llama_avg=TinyLlamaAvg,
        flant5_avg=FlanT5Avg,
    )

# Load correct merging method
merging_method = "llama_avg"
dataset_filepaths = ["../data/validation.csv"]
# Only provide an empty placeholder when no token is configured. The previous
# unconditional assignment wiped out any HF_AUTH_TOKEN already exported.
# TODO: export HF_AUTH_TOKEN before starting the kernel rather than hardcoding it here.
os.environ.setdefault("HF_AUTH_TOKEN", "")
loaded_merges = all_merge_handlers()
if merging_method not in loaded_merges:
    raise ValueError(
        f"Unknown merging_method {merging_method!r}; choose one of {sorted(loaded_merges)}"
    )
merge_method = loaded_merges[merging_method](merging_method)

# Call the merge function. The merged model is stored on the merge_method object
merge_method.merge()

# Evaluate the method on the datasets passed in (used for testing)
evaluate_model(
    merge_method,
    dataset_filepaths,
)



Downloading shards: 100%|██████████| 2/2 [01:03<00:00, 31.62s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.82s/it]
  0%|          | 0/807 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
  0%|          | 2/807 [02:00<13:33:32, 60.64s/it]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
  3%|▎         | 26/807 [14:28<8:01:30, 36.99s/it] 