In [1]:
## dataset.py

import json 
import pandas as pd

class Dataset(object):
    """In-memory list of evaluation records read from a CSV file.

    Every CSV row becomes one ``dict`` keyed by column name. When a row's
    ``answer_choices`` cell is empty or falsy the key is removed from that
    record, so a record either carries usable answer choices or has no
    ``answer_choices`` key at all.

    Args:
        dataset_filepath: path of the CSV file to load.
    """

    def __init__(
        self,
        dataset_filepath: str,
    ):
        self.dataset = pd.read_csv(dataset_filepath).to_dict('records')
        for dp in self.dataset:
            # The column is optional: a CSV without it must not raise KeyError.
            if 'answer_choices' not in dp:
                continue
            answer_choices = dp['answer_choices']
            # pandas reads an empty cell as NaN, and NaN is the only value that
            # is not equal to itself, hence the `x != x` test.
            if not answer_choices or answer_choices != answer_choices:
                del dp['answer_choices']

    def __len__(self):
        """Return the number of records."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the record (a ``dict``) stored at position ``idx``."""
        return self.dataset[idx]
    



In [2]:
## model/decoder_functions.py

import torch
import torch.nn.functional as F

def decoder_predict_multiple_choice(
    transformer, input_tokenizer, target_tokenizer, batch
):
    """Score every answer choice with a decoder-only model and pick the best one.

    The input is run through the model once and its key/value cache is kept.
    The cache and the input mask are then repeated once per answer choice so
    that all choices of an example are scored against the same cached input.
    The score of a choice is the sum of the log probabilities of its tokens
    (padding positions are masked out).

    Args:
        transformer: causal LM; called with ``input_ids``/``attention_mask``/
            ``past_key_values``/``use_cache`` and expected to expose
            ``.device``, ``.past_key_values`` and ``.logits`` on its output.
        input_tokenizer: tokenizer applied to ``batch["input"]``.
        target_tokenizer: tokenizer applied to ``batch["answer_choices"]``.
        batch: dict with at least ``"input"`` and ``"answer_choices"``.

    Returns:
        ``(predicted_choice, answer_choice_log_probs)`` where
        ``answer_choice_log_probs`` is ``[batch_size, num_answer_choices]`` and
        ``predicted_choice`` holds the argmax index per example.

    Raises:
        ValueError: if a cached key/value tensor is neither 3-D nor 4-D.
    """
    tokenized_batch = decoder_tokenize_batch(input_tokenizer, target_tokenizer, batch, transformer.device)
    output = transformer(
        input_ids=tokenized_batch["input_ids"],
        attention_mask=tokenized_batch["input_mask"],
        use_cache=True,
    )
    past_key_values = output.past_key_values

    # answer_choices_ids is the flattened list of all choices in the batch, so
    # dividing by the batch size gives the number of choices per example.
    num_answer_choices = (
        tokenized_batch["answer_choices_ids"].shape[0]
        // tokenized_batch["input_mask"].shape[0]
    )

    # Expand the input_mask and past_key_values since these are the same and
    # can be repeated for the different answer choices within an example.
    batch_size = tokenized_batch["input_mask"].shape[0]
    expanded_input_mask = torch.repeat_interleave(tokenized_batch["input_mask"], num_answer_choices, dim=0)

    expanded_past_key_values = []
    for pastKeyValues_perLayer in past_key_values:
        list_broadcast_pastKeyValues_perLayer = []
        for key_or_value in pastKeyValues_perLayer:
            # 4-D keys/values with the batch as the leading dimension
            # (the standard Hugging Face layout): repeat along dim 0.
            if len(key_or_value.shape) == 4:
                list_broadcast_pastKeyValues_perLayer.append(
                    torch.repeat_interleave(key_or_value, num_answer_choices, dim=0)
                )
            # 3-D keys/values whose leading dimension is batch_size x num_heads.
            # This is what is used for BLOOM in transformers == 4.22.0.
            # Un-flatten the batch, repeat it, then flatten back.
            elif len(key_or_value.shape) == 3:
                num_heads = key_or_value.shape[0] // batch_size
                flatten_keyOrValue = key_or_value.reshape(
                    ((batch_size, num_heads) + key_or_value.shape[1:])
                )
                broadcast_flatten_keyOrValue = torch.repeat_interleave(
                    flatten_keyOrValue, num_answer_choices, dim=0
                )
                list_broadcast_pastKeyValues_perLayer.append(
                    broadcast_flatten_keyOrValue.flatten(0, 1)
                )
            else:
                raise ValueError(
                    f"Invalid cached key or value shape: {tuple(key_or_value.shape)}"
                )

        expanded_past_key_values.append(
            tuple(list_broadcast_pastKeyValues_perLayer)
        )

    # Combine the input mask and choice mask so the model knows which cached input representations
    # are padded when conditioning on the cached input representations.
    # [batch_size x num_choices, max_input_len + max_choice_len]
    combined_mask = torch.cat(
        [expanded_input_mask, tokenized_batch["answer_choices_mask"]], dim=1
    )

    # WARNING: any loss returned by the model here is not valid, since the
    # answer choice ids are padded with a real token id and the loss would not
    # be ignored for the pad tokens. Only the logits are used below.
    transformer_outputs = transformer(
        input_ids=tokenized_batch["answer_choices_ids"],
        attention_mask=combined_mask,
        past_key_values=expanded_past_key_values,
        use_cache=True,
    )

    # We use the logits for all choices to compute the log probs per example.
    # [batch_size x num_choices, max_choice_len, vocab_size]
    answer_choice_ids_logits = transformer_outputs.logits.float()

    # Shift the ids, masks, and logits to handle predicting the next token for
    # the decoder: the logits at position t score the token at position t + 1.
    shifted_answer_choice_ids_logits = answer_choice_ids_logits[..., :-1, :].contiguous()
    shifted_answer_choice_ids = tokenized_batch["answer_choices_ids"][
        ..., 1:
    ].contiguous()
    shifted_answer_choice_masks = tokenized_batch["answer_choices_mask"][
        ..., 1:
    ].contiguous()

    shifted_answer_choices_max_len = shifted_answer_choice_ids_logits.shape[1]
    vocab_size = shifted_answer_choice_ids_logits.shape[-1]

    # Log probability of each choice token given the logits; cross entropy is
    # the negative log probability, hence the sign flip.
    # [batch_size x num_choices x (max_choice_len - 1)]
    shifted_answer_choice_ids_log_probs = -F.cross_entropy(
        shifted_answer_choice_ids_logits.view(-1, vocab_size),
        shifted_answer_choice_ids.view(-1),
        reduction="none",
    )

    # [batch_size, num_answer_choices, max_choice_len - 1]
    shifted_answer_choice_ids_log_probs = shifted_answer_choice_ids_log_probs.reshape(
        -1, num_answer_choices, shifted_answer_choices_max_len
    )

    shifted_answer_choices_mask = shifted_answer_choice_masks.reshape(
        -1, num_answer_choices, shifted_answer_choices_max_len
    )

    # Masked sum over tokens gives one score per answer choice.
    answer_choice_log_probs = torch.sum(shifted_answer_choice_ids_log_probs * shifted_answer_choices_mask, dim=2)

    _, predicted_choice = torch.max(answer_choice_log_probs, dim=1)

    return predicted_choice, answer_choice_log_probs



def decoder_generate(
    transformer, 
    input_tokenizer,
    target_tokenizer,
    batch,
    max_gen_len
):
    """Greedily generate a continuation for each input with a decoder-only model.

    Args:
        transformer: model exposing ``.device`` and ``.generate``.
        input_tokenizer: tokenizer used for the prompt and for decoding.
        target_tokenizer: forwarded to ``decoder_tokenize_batch``.
        batch: dict holding the text to tokenize (``"input"`` is required).
        max_gen_len: passed to ``generate`` as ``max_new_tokens``.

    Returns:
        ``(sequences, generated_txt)``: the full generated id sequences
        (prompt included) as nested Python lists, and the decoded text of the
        newly generated part only.
    """
    tokens = decoder_tokenize_batch(input_tokenizer, target_tokenizer, batch, transformer.device)
    prompt_ids = tokens["input_ids"]

    generation_kwargs = dict(
        input_ids=prompt_ids,
        attention_mask=tokens["input_mask"],
        max_new_tokens=max_gen_len,
        eos_token_id=input_tokenizer.eos_token_id,
        pad_token_id=input_tokenizer.pad_token_id,
        bos_token_id=input_tokenizer.bos_token_id,
        do_sample=False,
        return_dict_in_generate=True,
    )
    sequences = transformer.generate(**generation_kwargs)["sequences"]

    # The returned sequences start with the prompt; slice it off so that only
    # the newly generated ids are decoded.
    prompt_len = prompt_ids.shape[-1]
    continuation_ids = sequences[:, prompt_len:]
    generated_txt = input_tokenizer.batch_decode(
        continuation_ids, skip_special_tokens=True
    )

    return sequences.cpu().numpy().tolist(), generated_txt

def decoder_tokenize_batch(input_tokenizer, target_tokenizer, batch, device):
    """Tokenize the text fields of a batch for a decoder-only model.

    ``"input"`` is tokenized with ``input_tokenizer``; ``"answer_choices"`` and
    ``"target"`` are tokenized with ``target_tokenizer``. Only keys that are
    present in ``batch`` are tokenized; absent keys produce no output entries.

    Args:
        input_tokenizer: callable tokenizer used for ``"input"``.
        target_tokenizer: callable tokenizer used for ``"answer_choices"`` and
            ``"target"``.
        batch: dict mapping field name to a list of strings (a list of lists
            of strings for ``"answer_choices"``).
        device: device to move the tensors to, or ``None`` to leave them as
            returned by the tokenizer.

    Returns:
        dict with ``"<key>_ids"`` and ``"<key>_mask"`` tensors for every
        tokenized key.
    """
    tokenized_batch = {}

    keys_to_tokenize_with_tokenizer = [("input", input_tokenizer), ("answer_choices", target_tokenizer), ("target", target_tokenizer)]

    for key, tokenizer in keys_to_tokenize_with_tokenizer:
        # Skip absent keys entirely. Tokenizing outside this guard would reuse
        # the text of the previous key (or fail if there was none).
        if key not in batch:
            continue

        # The values of the batch are normally a list of text. The exception is
        # answer_choices, which is a list of lists; flatten it to a single list
        # before passing it to the tokenizer.
        if key == "answer_choices":
            text = [choice for choices in batch[key] for choice in choices]
        else:
            text = batch[key]

        tokenized_dict = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation="longest_first",
        )

        input_ids = tokenized_dict["input_ids"]
        attention_mask = tokenized_dict["attention_mask"]

        if device is not None:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

        tokenized_batch[f"{key}_ids"] = input_ids
        tokenized_batch[f"{key}_mask"] = attention_mask
    return tokenized_batch

In [3]:
## model/encoder_decoder_functions.py

import torch
import torch.nn.functional as F


def compute_loss(transformer, tokenizer, batch):
    """Return the mean negative log-likelihood of the target tokens.

    The per-token cross entropy is computed from the model logits, multiplied
    by the target attention mask, summed, and divided by the mask sum.

    Args:
        transformer: encoder-decoder model; called with ``labels`` and indexed
            at position 1 for the logits.
        tokenizer: tokenizer handed to ``encoder_decoder_tokenize_batch``.
        batch: dict that must contain ``"input"`` and ``"target"``.

    Returns:
        Scalar tensor with the masked average loss.
    """
    tokens = encoder_decoder_tokenize_batch(tokenizer, batch, transformer.device)

    outputs = transformer(
        input_ids=tokens["input_ids"],
        attention_mask=tokens["input_mask"],
        labels=tokens["target_ids"],
    )

    # [batch_size, max_target_len, vocab_size]
    logits = outputs[1].float()
    flat_logits = logits.reshape(-1, logits.shape[-1])
    flat_targets = tokens["target_ids"].reshape(-1)

    # One negative log probability per target position: [batch_size x max_target_len]
    per_token_nll = F.cross_entropy(flat_logits, flat_targets, reduction="none")

    # Multiplying by the mask removes the contribution of masked-out positions.
    flat_mask = tokens["target_mask"].reshape(-1)
    masked_total = torch.sum(per_token_nll * flat_mask)

    return masked_total / torch.sum(flat_mask)

def encoder_decoder_predict_multiple_choice(
    transformer, tokenizer, batch
):
    """Score every answer choice with an encoder-decoder model and pick the best.

    The encoder runs once per example; its output and the input mask are
    repeated once per answer choice, and the decoder then scores each choice.
    A choice's score is the masked sum of its token log probabilities.

    Args:
        transformer: encoder-decoder model exposing ``.device`` and
            ``.get_encoder()``; its output is indexed at position 1 for logits.
        tokenizer: tokenizer handed to ``encoder_decoder_tokenize_batch``.
        batch: dict with at least ``"input"`` and ``"answer_choices"``.

    Returns:
        ``(predicted_choice, answer_choice_log_probs)`` where the scores are
        ``[batch_size, num_answer_choices]`` and ``predicted_choice`` is the
        index of the highest score per example.
    """
    tokens = encoder_decoder_tokenize_batch(tokenizer, batch, transformer.device)

    encoder = transformer.get_encoder()
    encoder_outputs = encoder(tokens["input_ids"], tokens["input_mask"])

    choice_ids = tokens["answer_choices_ids"]
    input_mask = tokens["input_mask"]

    # The choices are flattened across the batch, so the per-example count is
    # the total number of choices divided by the batch size.
    num_answer_choices = choice_ids.shape[0] // input_mask.shape[0]

    # Repeat the mask and the encoder hidden states (first element of the
    # encoder output) so each choice gets its own copy.
    # [batch_size x num_choices, max_input_len]
    repeated_input_mask = torch.repeat_interleave(input_mask, num_answer_choices, dim=0)
    repeated_encoder_outputs = (
        torch.repeat_interleave(encoder_outputs[0], num_answer_choices, dim=0),
    )

    # WARNING: the loss at index 0 of the output is not valid, because the
    # choice ids use the tokenizer pad token rather than -100, so padding is
    # not ignored there. The mask is passed for the cross attention.
    outputs = transformer(
        attention_mask=repeated_input_mask,
        encoder_outputs=repeated_encoder_outputs,
        labels=choice_ids,
    )

    # [batch_size x num_choices, max_choice_len, vocab_size]
    logits = outputs[1].float()
    max_choice_len = logits.shape[1]
    vocab_size = logits.shape[-1]

    # Cross entropy is the negative log probability, hence the sign flip.
    token_log_probs = -F.cross_entropy(
        logits.view(-1, vocab_size),
        choice_ids.view(-1),
        reduction="none",
    )

    # Regroup the flat values as [batch_size, num_answer_choices, max_choice_len].
    grouped_shape = (-1, num_answer_choices, max_choice_len)
    token_log_probs = token_log_probs.reshape(grouped_shape)
    choice_mask = tokens["answer_choices_mask"].reshape(grouped_shape)

    answer_choice_log_probs = torch.sum(token_log_probs * choice_mask, dim=2)
    predicted_choice = torch.max(answer_choice_log_probs, dim=1).indices

    return predicted_choice, answer_choice_log_probs


def encoder_decoder_generate(
    transformer, 
    tokenizer,
    batch,
    max_gen_len
):
    """Greedily generate text for each input with an encoder-decoder model.

    Args:
        transformer: model exposing ``.device`` and ``.generate``.
        tokenizer: tokenizer used for both encoding and decoding.
        batch: dict holding the text to tokenize (``"input"`` is required).
        max_gen_len: passed to ``generate`` as ``max_new_tokens``.

    Returns:
        ``(sequences, generated_txt)``: the generated id sequences as nested
        Python lists and their decoded text with special tokens removed.
    """
    tokens = encoder_decoder_tokenize_batch(tokenizer, batch, transformer.device)

    generation_kwargs = dict(
        input_ids=tokens["input_ids"],
        attention_mask=tokens["input_mask"],
        max_new_tokens=max_gen_len,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        do_sample=False,
        return_dict_in_generate=True,
    )
    sequences = transformer.generate(**generation_kwargs)["sequences"]

    generated_txt = tokenizer.batch_decode(sequences, skip_special_tokens=True)

    return sequences.cpu().numpy().tolist(), generated_txt

def encoder_decoder_tokenize_batch(tokenizer, batch, device):
    """Tokenize the text fields of a batch for an encoder-decoder model.

    The keys ``"input"``, ``"answer_choices"`` and ``"target"`` are tokenized
    when present in ``batch``. As a side effect the tokenizer's
    ``padding_side`` is set to ``"right"``.

    Args:
        tokenizer: callable tokenizer; mutated as described above.
        batch: dict mapping field name to a list of strings (a list of lists
            of strings for ``"answer_choices"``).
        device: device to move the tensors to, or ``None`` to leave them as
            returned by the tokenizer.

    Returns:
        dict with ``"<key>_ids"`` and ``"<key>_mask"`` tensors for every
        tokenized key.
    """
    # encoder decoder models pad to the right
    tokenizer.padding_side = "right"

    tokenized_batch = {}
    for key in ("input", "answer_choices", "target"):
        if key not in batch:
            continue

        values = batch[key]
        # answer_choices is a list of lists; every other field is already a
        # flat list of text, so only answer_choices needs flattening.
        if key == "answer_choices":
            text = [choice for choices in values for choice in choices]
        else:
            text = values

        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation="longest_first",
        )
        ids, mask = encoded["input_ids"], encoded["attention_mask"]

        if device is not None:
            ids = ids.to(device)
            mask = mask.to(device)

        tokenized_batch[f"{key}_ids"] = ids
        tokenized_batch[f"{key}_mask"] = mask

    return tokenized_batch


In [4]:
## merges/Merges.py
! pip install peft
import copy
import os

from peft import load_peft_weights, PeftConfig
from safetensors.torch import save_file


from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForCausalLM,
    AutoTokenizer
)

class Merges(object):
    """Base class for a model-merging method.

    A subclass fills in the configuration attributes declared in ``__init__``
    (models to merge, base model, architecture, device, ...) and implements
    ``merge``. This class provides the shared plumbing: loading the base model,
    tokenizer(s) and PEFT checkpoints from the Hugging Face Hub, dispatching
    prediction/generation by architecture, and saving the merged parameters.

    Hub access reads the token from the ``HF_AUTH_TOKEN`` environment variable.
    """

    def __init__(self, name):
        self.name = name

        # Iterated as (model_name, revision_id) pairs when loading checkpoints.
        self.list_models = None

        # model_name -> PEFT parameters / PEFT config, filled while loading.
        self.loaded_models = None
        self.loaded_configs = None
        self.base_model = None
        # `tokenizer` is used for "encoder_decoder"; the input/target pair is
        # used for "decoder".
        self.tokenizer = None
        self.input_tokenizer = None
        self.target_tokenizer = None

        self.base_model_name = None
        self.base_model_revision_id = None

        self.max_seq_len = None
        self.max_gen_len = None

        self.device = None
        # Either "encoder_decoder" or "decoder".
        self.architecture = None

        self.merged_model = None

    def get_name(self):
        """Return the name given at construction."""
        return self.name

    def get_model_config(self):
        """Not implemented in the base class."""
        raise NotImplementedError

    def _load_base_model(self):
        """Load the base model for ``self.architecture`` and move it to ``self.device``.

        Raises:
            NotImplementedError: for an unknown architecture.
        """
        if self.architecture == "encoder_decoder":
            self.base_model =  AutoModelForSeq2SeqLM.from_pretrained(self.base_model_name, revision=self.base_model_revision_id, token=os.environ["HF_AUTH_TOKEN"]).to(self.device)
        elif self.architecture == "decoder":
            self.base_model =  AutoModelForCausalLM.from_pretrained(self.base_model_name, revision=self.base_model_revision_id, token=os.environ["HF_AUTH_TOKEN"]).to(self.device)
        else:
            raise NotImplementedError(f"Architecture not implemented {self.architecture}")
        

    def _load_tokenizer(self):
        """Load the tokenizer(s) for ``self.architecture`` unless already loaded.

        Raises:
            NotImplementedError: for an unknown architecture.
        """

        if self.architecture == "encoder_decoder":
            if self.tokenizer is None:

                self.tokenizer = AutoTokenizer.from_pretrained(
                    self.base_model_name,
                    revision=self.base_model_revision_id,
                    model_max_length=self.max_seq_len,
                    legacy=False,
                    token=os.environ["HF_AUTH_TOKEN"]
                )

        elif self.architecture == "decoder":
            if self.input_tokenizer is None or self.target_tokenizer is None:
                    
                self.input_tokenizer = AutoTokenizer.from_pretrained(
                    self.base_model_name,
                    revision=self.base_model_revision_id,
                    model_max_length=self.max_seq_len,
                    legacy=False,
                    token=os.environ["HF_AUTH_TOKEN"]
                )
                # Independent copy so the two tokenizers can be configured
                # differently below.
                self.target_tokenizer = copy.deepcopy(self.input_tokenizer)

                # Use eos_token for pad_token if it doesn't exist. This is ok since the
                # pad tokens will be ignored through the mask
                if self.input_tokenizer.pad_token_id is None:
                    self.input_tokenizer.pad_token_id = self.input_tokenizer.eos_token_id
                if self.target_tokenizer.pad_token_id is None:
                    self.target_tokenizer.pad_token_id = self.target_tokenizer.eos_token_id

                # Inputs are padded on the left.
                self.input_tokenizer.padding_side = "left"

                # Targets are padded on the right, with EOS and without BOS.
                self.target_tokenizer.padding_side = "right"
                self.target_tokenizer.add_bos_token = False
                self.target_tokenizer.add_eos_token = True
        else:
            raise NotImplementedError(f"Architecture not implemented {self.architecture}")

    def predict_multiple_choice(self, batch):
        """Score the answer choices of ``batch`` with the loaded model.

        Returns:
            Whatever the architecture-specific scoring function returns: the
            predicted choice indices and the per-choice scores.
        """
        assert self.base_model is not None
        if self.architecture == "encoder_decoder":
            assert self.tokenizer is not None
            return encoder_decoder_predict_multiple_choice(self.base_model, self.tokenizer, batch)
        elif self.architecture == "decoder":
            assert self.input_tokenizer is not None and self.target_tokenizer is not None
            return decoder_predict_multiple_choice(self.base_model, self.input_tokenizer, self.target_tokenizer, batch)
        else:
            raise NotImplementedError(f"Architecture not implemented {self.architecture}")
    
    def generate(self, batch):
        """Generate text for ``batch`` with the loaded model.

        Returns:
            Whatever the architecture-specific generation function returns: the
            generated ids and the decoded text.
        """
        assert self.base_model is not None
        if self.architecture == "encoder_decoder":
            assert self.tokenizer is not None
            return encoder_decoder_generate(self.base_model, self.tokenizer, batch, self.max_gen_len)
        elif self.architecture == "decoder":
            assert self.input_tokenizer is not None and self.target_tokenizer is not None
            return decoder_generate(self.base_model, self.input_tokenizer, self.target_tokenizer, batch, self.max_gen_len)
        else:
            raise NotImplementedError(f"Architecture not implemented {self.architecture}")

    def _load_huggingface_models_and_configs(self):
        """Download the PEFT weights and config of every model in ``list_models``.

        Results are stored in ``self.loaded_models`` and ``self.loaded_configs``
        keyed by model name. A warning is printed when a checkpoint's parameter
        names differ from those of the first checkpoint.
        """
        assert self.list_models, "List of models must include at least 1 model"

        # The base class leaves these as None; make sure they can be filled.
        if self.loaded_models is None:
            self.loaded_models = {}
        if self.loaded_configs is None:
            self.loaded_configs = {}

        token = os.environ["HF_AUTH_TOKEN"]
        parameter_names = None
        for model_name, revision_id in self.list_models:

            peft_model_parameters = load_peft_weights(model_name, revision=revision_id, token=token)
            # Fetch the config from the same pinned revision, with the same
            # credentials, as the weights it describes.
            peft_config = PeftConfig.from_pretrained(model_name, revision=revision_id, token=token)

            if parameter_names is None:
                parameter_names = set(peft_model_parameters.keys())

            if parameter_names != set(peft_model_parameters.keys()):
                print(f"WARNING: parameters in {model_name} do not match {self.list_models[0][0]}")

            self.loaded_models[model_name] = peft_model_parameters 
            self.loaded_configs[model_name] = peft_config

    def merge(
        self,
    ):
        """Perform the merge; must be implemented by subclasses."""
        raise NotImplementedError
    
    def save_model(self, output_dir):
        """Serialize ``self.merged_model`` to ``<output_dir>/safetensor.pt``.

        ``output_dir`` is created if it does not exist. Despite the file name,
        the file is written with ``torch.save`` and is therefore a regular
        PyTorch checkpoint rather than a safetensors file.
        """
        assert self.merged_model is not None, "Merged model is empty"
        assert len(self.merged_model) > 0, "Merged model is empty"
        os.makedirs(output_dir, exist_ok=True)
        torch.save(self.merged_model, os.path.join(output_dir, "safetensor.pt"))



In [5]:
import torch
from typing import List, Tuple, Dict, Any
from peft import get_peft_model, set_peft_model_state_dict

class FlanT5GeoMed(Merges):
    """Merge LoRA adapters into Flan-T5-XL using their geometric median.

    For every base-model weight whose name contains ``.q.`` or ``.v.``, the
    per-adapter update ``lora_B @ lora_A`` is computed for each loaded adapter,
    the geometric median of those updates is taken, and the result is added to
    the base weight. All other weights are left untouched.
    """

    def __init__(self, name: str):  
        super().__init__(name)

        # (repository, pinned revision) of every adapter that takes part in the merge.
        self.list_models: List[Tuple[str, str]] = [
            # Text Classification
            ("lorahub/flan_t5_xl-dbpedia_14_given_list_what_category_does_the_paragraph_belong_to", "883db61b41a3a9e8716f5391d782f653fd9d693b"),
            ("lorahub/flan_t5_xl-wiki_qa_Topic_Prediction_Question_Only", "cc024699f37aee24e72cd28a596dbf3451a93484"),

            # Question Answering
            ("lorahub/flan_t5_xl-anli_r2", "ea7872e79fddc6e9df57b88c429bdb283b414bea"),
            ("lorahub/flan_t5_xl-web_questions_question_answer", "37701f6f673974308517151387182f42271a2eab"),
            ("lorahub/flan_t5_xl-duorc_SelfRC_question_answering", "b56b5b0b72a0a4b90b120833ff466aa7ef85dd84"),
            ("lorahub/flan_t5_xl-adversarial_qa_dbert_question_context_answer", "a935c63c0c7deaca77f437efd3425192a88dd90e"),

            # Text Generation
            ("lorahub/flan_t5_xl-gem_e2e_nlg", "04e25c5739d151e42916b262cb0ee900aa854816"),

            # Text2Text Generation
            ("lorahub/flan_t5_xl-wiki_hop_original_explain_relation", "d6bdec80c60d55db0b7125f8ca0d02871ab3ab34"),
            ("lorahub/flan_t5_xl-duorc_SelfRC_title_generation", "17653e0c744bb1453f93b816d1eb140d991be6a4"),

            # Sentence Similarity
            ("lorahub/flan_t5_xl-glue_mrpc", "292a6f0c2dec34a9faa143b37dc734eee14c860a"),
            ("lorahub/flan_t5_xl-glue_cola", "7fef5d273d145e26b07762b43abcbaa83874dc23"),

            # Comprehension and Understanding Tasks
            ("lorahub/flan_t5_xl-wiki_bio_comprehension", "9d06f885dbbbe69327203b299193873ea281522c"),
            ("lorahub/flan_t5_xl-wiki_bio_key_content", "f98ee1718a9ce23446671023a60fb05a57f5e9d3"),
            ("lorahub/flan_t5_xl-wiki_bio_guess_person", "e8998f9f0fad7aef94408c4741e7fbe2ff11f79d"),
            ("lorahub/flan_t5_xl-wiki_bio_who", "c081565f0d3e3aa251fa9d44fc6678d70cc9e20f"),

            # Search and Retrieval
            ("lorahub/flan_t5_xl-wiki_qa_found_on_google", "cb5c59ee688f22e0314968e2a0c1bee692e66c27"),

            # Natural Language Generation
            ("lorahub/flan_t5_xl-gem_web_nlg_en", "8043f44956456dffb6cc5e07bc59bffdf618ac97"),

            # Paraphrasing and Extraction
            ("lorahub/flan_t5_xl-duorc_ParaphraseRC_extract_answer", "c008dacf47c7836a0bcd2d4c47cd27923d2cda1e"),
            ("lorahub/flan_t5_xl-duorc_SelfRC_extract_answer", "377a71b7c71099688c836d7417eb9cfc0c33f6b5"),

            # Process Understanding
            ("lorahub/flan_t5_xl-wiqa_what_might_be_the_last_step_of_the_process", "fea37d25cf4eb8d81a85fc3296e7781fc8ea10db"),
        ]
        
        # Hyperparameters
        self.base_model_name: str = "google/flan-t5-xl" #base model
        self.base_model_revision_id: str = "7d6315df2c2fb742f0f5b556879d730926ca9001"
        self.is_peft: bool = True
        self.max_seq_len: int = 512
        # NOTE(review): max_gen_len keeps the base-class default (None), so
        # generation is called with max_new_tokens=None - confirm this is intended.
        self.device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
        self.architecture: str = "encoder_decoder"

        self.loaded_models: Dict[str, Dict[str, torch.Tensor]] = {}
        self.loaded_configs: Dict[str, Any] = {}
        self.merged_model: Dict[str, torch.Tensor] = {}

    def merge(self):
        """Load everything, merge the adapters into the base model and return it.

        Returns:
            The base model, in eval mode, with the merged updates applied.

        Raises:
            ValueError: if no adapter checkpoint was loaded.
        """
        # Load HuggingFace checkpoints and configs
        super()._load_huggingface_models_and_configs()

        # Fail before the (expensive) base model load if there is nothing to merge.
        if not self.loaded_models:
            raise ValueError("No models loaded for merging.")

        # Load base model and tokenizer
        self._load_base_model()
        self._load_tokenizer()

        base_model_params = self.base_model.state_dict()
        all_models = list(self.loaded_models.values())

        # The updates are applied in place; no gradient tracking is needed.
        with torch.no_grad():
            merged_model_dict = self.process_models(base_model_params, all_models)

        self.merged_model = merged_model_dict
            
        assert len(self.merged_model) > 0, "Merged model is empty"

        self.base_model.load_state_dict(self.merged_model, strict=False)  # Load parameters from the merged model dict
        self.base_model.eval()  # Set to evaluation mode
        return self.base_model

    def process_models(self, base_model_params: Dict[str, torch.Tensor], all_models: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """Add the geometric-median LoRA update to every ``.q.``/``.v.`` weight.

        Args:
            base_model_params: state dict of the base model; modified in place.
            all_models: one dict of adapter tensors per adapter.

        Returns:
            ``base_model_params`` (the same object) after the updates.
        """
        for param_name, param in base_model_params.items():  
            # Only the query and value projections are merged; everything else
            # (including ".k." and ".o.") is skipped.
            if not any(marker in param_name for marker in (".q.", ".v.")):
                continue

            task_vector, original_shape = self.process_task_vectors(param_name, param, all_models)
            if not task_vector:
                continue

            merged_delta = self.compute_geometric_median(task_vector).view(original_shape)
            # Align device and dtype with the weight being updated before the
            # in-place addition.
            base_model_params[param_name] += merged_delta.to(device=param.device, dtype=param.dtype)

        return base_model_params

    @staticmethod
    def _projection_prefix(param_name: str):
        """Return ``param_name`` up to and including its ``.q.``/``.v.`` marker, or None."""
        # ".v." is checked first so it wins if both markers were ever present.
        for marker in (".v.", ".q."):
            index = param_name.find(marker)
            if index != -1:
                return param_name[:index + len(marker)]
        return None

    @staticmethod
    def _belongs_to_module(prefix: str, adapter_param_name: str) -> bool:
        """Tell whether an adapter tensor name refers to the module named by ``prefix``."""
        # Require a name-component boundary in front of the prefix so that a
        # prefix can only match the module it was derived from.
        return adapter_param_name.startswith(prefix) or ("." + prefix) in adapter_param_name

    def process_task_vectors(self, param_name: str, param: torch.Tensor, all_models: List[Dict[str, torch.Tensor]]) -> Tuple[List[torch.Tensor], torch.Size]:
        """Build the flattened ``lora_B @ lora_A`` update of each adapter for one weight.

        An adapter contributes only if both its ``lora_A.`` and ``lora_B.``
        tensors for the very same projection module are found.

        Args:
            param_name: name of the base-model weight (contains ``.q.`` or ``.v.``).
            param: the base-model weight (not used by this method).
            all_models: one dict of adapter tensors per adapter.

        Returns:
            ``(task_vector, original_shape)``: the list of flattened updates and
            the shape of an un-flattened update (``None`` if the list is empty).
        """
        task_vector = []
        original_shape = None

        prefix = self._projection_prefix(param_name)
        if prefix is None:
            return task_vector, original_shape

        for model in all_models:
            vector_A, vector_B = None, None

            for fine_tuned_param_name, fine_tuned_param in model.items():
                # Match on the full module prefix only. Matching on the bare
                # ".q."/".v." marker would pick up tensors of other layers.
                if not self._belongs_to_module(prefix, fine_tuned_param_name):
                    continue
                if "lora_A." in fine_tuned_param_name:
                    vector_A = fine_tuned_param
                elif "lora_B." in fine_tuned_param_name:
                    vector_B = fine_tuned_param

            if vector_A is not None and vector_B is not None:
                # NOTE(review): the LoRA scaling factor from the adapter config
                # (lora_alpha / r) is not applied to this product - confirm
                # whether that is intended.
                result = torch.matmul(vector_B, vector_A)
                original_shape = result.shape
                task_vector.append(result.reshape(-1))

        return task_vector, original_shape

    def compute_geometric_median(self, task_vectors: List[torch.Tensor], eps: float = 1e-8, max_iter: int = 300) -> torch.Tensor:
        """Approximate the geometric median of ``task_vectors`` iteratively.

        Starting from the mean, each iteration replaces the estimate by the
        average of the vectors weighted by the inverse of their distance to the
        current estimate, until the estimate moves by less than ``eps`` or
        ``max_iter`` iterations have run.

        Args:
            task_vectors: non-empty list of 1-D tensors of equal length.
            eps: convergence threshold on the norm of the update step.
            max_iter: maximum number of iterations.

        Returns:
            1-D tensor on ``self.device`` holding the estimated median.

        Raises:
            ValueError: if ``task_vectors`` is empty.
        """
        if not task_vectors:
            raise ValueError("No task vectors provided for geometric median computation.")

        # One row per task vector, all on the same device as the estimate.
        points = torch.stack(task_vectors).to(self.device)
        median = points.mean(dim=0)

        for _ in range(max_iter):
            # Clamp to avoid division by zero when the estimate coincides with
            # one of the points.
            distances = torch.norm(points - median, dim=1).clamp_min(1e-10)
            weights = 1.0 / distances

            new_median = (weights.unsqueeze(1) * points).sum(dim=0) / weights.sum()
            shift = torch.norm(new_median - median)
            median = new_median

            if shift < eps:
                break

        return median


In [6]:
## evaluate.py
! pip install evaluate

import evaluate 
import json 
import os 
import pandas as pd

from typing import List, Dict, Any

import torch
from tqdm import tqdm
from torch.utils import data


def convert_dict_of_lists_to_list_of_dicts(dict_of_lists: Dict[Any, List]) -> List[Dict]:
    """Turn a column-oriented dict into a list of row dicts.

    The i-th output dict maps every key to the i-th element of that key's
    value. Values are walked in parallel, so the output is as long as the
    shortest value.

    Args:
        dict_of_lists: mapping from key to an iterable of per-row values.

    Returns:
        One dict per row, with the same keys as ``dict_of_lists``.
    """
    keys = list(dict_of_lists)
    return [dict(zip(keys, row)) for row in zip(*dict_of_lists.values())]

def collate_fn(batch_of_datapoints: List[Dict]) -> Dict[Any, List]:
    """Merge a list of datapoints into a single batched datapoint.

    Meant as a replacement for PyTorch's default collate function: values are
    simply gathered into Python lists, so values that are themselves lists are
    kept as they are.

    Args:
        batch_of_datapoints: the datapoints (dicts) that make up the batch.

    Returns:
        Dict mapping each key to the list of values found under that key, in
        datapoint order. A key missing from some datapoints gets a shorter list.
    """
    batched: Dict[Any, List] = {}
    for datapoint in batch_of_datapoints:
        for key, value in datapoint.items():
            # Start a new list the first time a key is seen.
            batched.setdefault(key, []).append(value)
    return batched


def evaluate_dataset(
    merge_method,
    dataset_filepath: str,
) -> List[Dict]:
    """Run the merged model over one dataset and collect its predictions.

    Each batch is handled according to the ``eval_type`` of its first
    datapoint:

    1) ``"multiple_choice"``: the model scores each choice and the index of
       the highest-scoring one (as a string) is the prediction.
    2) ``"generation"``: the generated text is the prediction.

    Args:
        merge_method: object providing ``predict_multiple_choice(batch)`` and
            ``generate(batch)``.
        dataset_filepath: path of the CSV file to evaluate on.

    Returns:
        One dict per datapoint: the original fields plus ``"prediction"``.
    """
    data_loader = data.DataLoader(
        Dataset(dataset_filepath),
        batch_size=1,
        num_workers=0,
        shuffle=False,
        collate_fn=collate_fn
    )

    all_batches = []

    with torch.no_grad():
        for batch in tqdm(data_loader):
            eval_type = batch["eval_type"][0]
            
            if eval_type == "multiple_choice":
                predicted_choice, _ = merge_method.predict_multiple_choice(batch)
                # Keep one entry per datapoint. A bare string here would be
                # split into characters when the batch is turned back into
                # per-datapoint dicts, truncating indices of 10 or more.
                batch["prediction"] = [
                    str(choice) for choice in predicted_choice.cpu().numpy().tolist()
                ]
            else:
                assert eval_type == "generation", f"Unknown eval_type: {eval_type}"
                _, generated_txt = merge_method.generate(batch)
                batch["prediction"] = generated_txt 

            all_batches.extend(convert_dict_of_lists_to_list_of_dicts(batch))

    return all_batches

def evaluate_model(
    merge_method,
    all_dataset_filepaths: List[str],
) -> None:   
    """Save the merged model, then predict on every dataset and write the results.

    Under ``output/<merge name>/`` the merged model is saved, and the full
    predictions of each dataset are written to
    ``predictions/<dataset file name>.csv``. In addition,
    ``../submission.csv`` is written with the columns ``id``, ``prediction``
    and ``dummy_field``; it is rewritten for every dataset, so it ends up
    holding the predictions of the last dataset in the list.

    Args:
        merge_method: merge object providing ``get_name``, ``save_model`` and
            the prediction methods used by ``evaluate_dataset``.
        all_dataset_filepaths: paths of the CSV datasets to evaluate on.
    """
    output_dir = os.path.join("output", merge_method.get_name())
    prediction_dir = os.path.join(output_dir, "predictions")
    os.makedirs(prediction_dir, exist_ok=True)
    # Save merged model 
    merge_method.save_model(output_dir)

    for dataset_filepath in all_dataset_filepaths:
        dataset_predictions = evaluate_dataset(merge_method, dataset_filepath)
        dp_df = pd.DataFrame(dataset_predictions)

        # Keep a per-dataset copy so that results of earlier datasets survive
        # the rewrite of the submission file below.
        dataset_name = os.path.splitext(os.path.basename(dataset_filepath))[0]
        dp_df.to_csv(os.path.join(prediction_dir, f"{dataset_name}.csv"), index=False)

        dp_df["dummy_field"] = 0
        dp_df.to_csv("../submission.csv", columns=["id", "prediction", "dummy_field"], index=False)



In [7]:
## main.py

def all_merge_handlers():
    """Return the registry of available merge methods.

    Returns:
        Dict mapping a merge method's name to the class implementing it.
    """
    return {"flant5_geomed": FlanT5GeoMed}

# Load correct merging method 
merging_method = "flant5_geomed"
# Keep a token that is already set in the environment instead of overwriting
# it; only fall back to the placeholder when the variable is missing.
os.environ.setdefault("HF_AUTH_TOKEN", "")  # TO DO - Enter your HF auth token
loaded_merges = all_merge_handlers()
merge_method = loaded_merges[merging_method](merging_method)

# Call the merge function. The merged model is stored under merging_method object 
merge_method.merge()





Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

T5ForConditionalGeneration(
  (shared): Embedding(32128, 2048)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 2048)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=2048, out_features=2048, bias=False)
              (k): Linear(in_features=2048, out_features=2048, bias=False)
              (v): Linear(in_features=2048, out_features=2048, bias=False)
              (o): Linear(in_features=2048, out_features=2048, bias=False)
              (relative_attention_bias): Embedding(32, 32)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=2048, out_features=5120, bias=False)
              (wi_1): Linear(in_features=2048, out_features=5120, bias=False)
       

In [8]:

# Path(s) of the validation dataset(s); replace with the real location.
# NOTE(review): "..data/validation.csv" looks like it may be meant to be
# "../data/validation.csv" - confirm against where the data actually lives.
dataset_filepaths = ["..data/validation.csv"]

# Evaluate the merged model on the datasets listed above (used for testing).
evaluate_model(merge_method, dataset_filepaths)


100%|██████████| 200/200 [01:31<00:00,  2.18it/s]
