In [None]:
!pip install transformers==4.31.0
!pip install bitsandbytes==0.41.0
!pip install accelerate==0.21.0
!pip install sentencepiece==0.1.99
!pip install peft==0.4.0

# Load the base Llama-2 13B model

*Note: you need a valid Hugging Face access token (with access to the gated Llama-2 weights) in the `access_token` variable.*

In [None]:
import os
from getpass import getpass

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

model_path = "meta-llama/Llama-2-13b-hf"
# The Llama-2 repos are gated. Read the token from the HF_TOKEN environment variable (or
# prompt for it) so it is never hard-coded in, or saved with, the notebook.
access_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")

tokenizer = LlamaTokenizer.from_pretrained(model_path, token=access_token)
# Pad with token *id* 0 (`<unk>` in the Llama vocabulary). Assigning `pad_token = 0` would
# instead register the literal token string "0", i.e. pad with the digit-zero token.
tokenizer.pad_token_id = 0
# Left padding keeps every prompt adjacent to the tokens that generate() appends.
tokenizer.padding_side = "left"

model = LlamaForCausalLM.from_pretrained(
    model_path,
    token=access_token,
    device_map={"": 0},  # load the whole model on GPU 0
    torch_dtype=torch.bfloat16,
    # 4-bit NF4 weights with double quantization, bfloat16 compute. 4-bit loading is
    # requested here only: passing load_in_4bit alongside quantization_config is redundant.
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    ),
)

# Patch peft

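The following code patches the `forward` method of peft's LoRA `Linear` layer and of its 8-bit and 4-bit quantized variants. After patching, every LoRA adapter loaded into a layer contributes to that layer's output, scaled by a separate weight for each element of the batch.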

It also adds two methods to `LlamaForCausalLM`:

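*   **load_multiple_loras**: loads several LoRA adapters at once, given a dict that maps each adapter name to its Hugging Face Hub repo id (or local path).
*   **set_lora_activations**: sets the weight of each LoRA adapter for every element of the batch, so different inputs in the same batch can use different adapter mixes.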



In [None]:
import peft
import torch.nn.functional as F

def Linear_forward(self, x: torch.Tensor):
    """Patched ``peft.tuners.lora.Linear.forward`` with per-batch-element LoRA weights.

    Every adapter loaded into this layer contributes ``scaling * B(A(dropout(x)))``,
    multiplied by its weight for each batch element. Weights are read from
    ``self.lora_activation`` (set by ``set_lora_activations``): one row per batch
    element, one column per adapter. Columns are matched to adapters by name through
    ``self.lora_adapter_names`` when that attribute is set. Otherwise the i-th column
    goes to the i-th key of ``self.r`` (this layer's own adapter order).

    Args:
        x: layer input. Its first dimension is treated as the batch dimension (one
            ``lora_activation`` row per element; a single row broadcasts to all).

    Returns:
        The linear output plus the weighted adapter outputs, cast back to ``x.dtype``.
        If the active adapter is not in this layer, the plain linear output is returned
        unchanged. If adapters are disabled or merged, no per-sample weighting is applied.

    Raises:
        AttributeError: if the adapter path runs before ``lora_activation`` was set.
    """
    previous_dtype = x.dtype
    if self.active_adapter not in self.lora_A.keys():
        return F.linear(x, peft.utils.transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

    active_r = self.r[self.active_adapter]
    if self.disable_adapters and active_r > 0 and self.merged:
        # Adapters are off: restore the base weights before using them.
        self.unmerge()

    result = F.linear(x, peft.utils.transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

    # A merged layer already holds its LoRA update inside self.weight, so per-sample
    # weighting is only possible on the unmerged path.
    if not self.disable_adapters and active_r > 0 and not self.merged:
        activation = getattr(self, "lora_activation", None)
        if activation is None:
            raise AttributeError(
                "LoRA activations are not set: call model.set_lora_activations(...) first"
            )
        # (batch, n_adapters) float32 weights, created directly on the input's device.
        weights = torch.as_tensor(activation, dtype=torch.float32, device=x.device)
        # View one weight column as (batch, 1, ..., 1) so it broadcasts over the rest of x.
        per_sample = (-1,) + (1,) * (x.dim() - 1)
        names = getattr(self, "lora_adapter_names", None) or list(self.r.keys())

        outputs = []
        for adpt in self.r.keys():
            if self.r[adpt] <= 0:
                # A rank-0 adapter contributes nothing (and may have no lora_A/lora_B entry).
                continue
            _x = x.to(self.lora_A[adpt].weight.dtype)
            out = self.lora_B[adpt](self.lora_A[adpt](self.lora_dropout[adpt](_x))) * self.scaling[adpt]
            # names.index raises ValueError if this layer holds an adapter missing from the names.
            outputs.append(out * weights[:, names.index(adpt)].view(per_sample))
        if outputs:
            result += torch.sum(torch.stack(outputs), 0)

    return result.to(previous_dtype)

def Linear8bitLt_forward(self, x: torch.Tensor):
    """Patched ``peft.tuners.lora.Linear8bitLt.forward`` with per-batch-element LoRA weights.

    Runs the base 8-bit layer, then adds every adapter's output
    ``scaling * B(A(dropout(x)))`` multiplied by its weight for each batch element.
    Weights come from ``self.lora_activation`` (set by ``set_lora_activations``): one row
    per batch element, one column per adapter. Columns are matched to adapters by name
    through ``self.lora_adapter_names`` when set, otherwise by position in ``self.r``.

    Args:
        x: layer input; its first dimension is treated as the batch dimension.

    Returns:
        The base output. The weighted adapter outputs are added unless adapters are
        disabled, the active adapter is not in this layer, or its rank is 0.

    Raises:
        AttributeError: if the adapter path runs before ``lora_activation`` was set.
    """
    result = super(peft.tuners.lora.Linear8bitLt, self).forward(x)

    if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
        return result
    if self.r[self.active_adapter] <= 0:
        return result

    activation = getattr(self, "lora_activation", None)
    if activation is None:
        raise AttributeError(
            "LoRA activations are not set: call model.set_lora_activations(...) first"
        )
    # (batch, n_adapters) float32 weights, created directly on the input's device.
    weights = torch.as_tensor(activation, dtype=torch.float32, device=x.device)
    # View one weight column as (batch, 1, ..., 1) so it broadcasts over the rest of x.
    per_sample = (-1,) + (1,) * (x.dim() - 1)
    names = getattr(self, "lora_adapter_names", None) or list(self.r.keys())

    autocast = torch.is_autocast_enabled()
    if not autocast and x.dtype != torch.float32:
        # Outside autocast, run the adapters from a float32 copy of the input.
        x = x.float()

    outputs = []
    for adpt in self.r.keys():
        if self.r[adpt] <= 0:
            # A rank-0 adapter contributes nothing (and may have no lora_A/lora_B entry).
            continue
        _x = x.to(self.lora_A[adpt].weight.dtype)
        out = self.lora_B[adpt](self.lora_A[adpt](self.lora_dropout[adpt](_x)))
        if not autocast:
            # Only cast back outside autocast. Under autocast the output keeps the dtype
            # autocast chose, as in peft's unpatched forward.
            out = out.to(result.dtype)
        out = out * self.scaling[adpt]
        outputs.append(out * weights[:, names.index(adpt)].view(per_sample))
    if outputs:
        result += torch.sum(torch.stack(outputs), 0)
    return result

def Linear4bit_forward(self, x: torch.Tensor):
    """Patched ``peft.tuners.lora.Linear4bit.forward`` with per-batch-element LoRA weights.

    Runs the base 4-bit layer, then adds every adapter's output
    ``scaling * B(A(dropout(x)))`` multiplied by its weight for each batch element.
    Weights come from ``self.lora_activation`` (set by ``set_lora_activations``): one row
    per batch element, one column per adapter. Columns are matched to adapters by name
    through ``self.lora_adapter_names`` when set, otherwise by position in ``self.r``.

    Args:
        x: layer input; its first dimension is treated as the batch dimension.

    Returns:
        The base output. The weighted adapter outputs are added (on a copy) unless
        adapters are disabled, the active adapter is not in this layer, or its rank is 0.

    Raises:
        AttributeError: if the adapter path runs before ``lora_activation`` was set.
    """
    result = super(peft.tuners.lora.Linear4bit, self).forward(x)

    if self.disable_adapters or self.active_adapter not in self.lora_A.keys():
        return result
    if self.r[self.active_adapter] <= 0:
        return result

    activation = getattr(self, "lora_activation", None)
    if activation is None:
        raise AttributeError(
            "LoRA activations are not set: call model.set_lora_activations(...) first"
        )
    # (batch, n_adapters) float32 weights, created directly on the input's device.
    weights = torch.as_tensor(activation, dtype=torch.float32, device=x.device)
    # View one weight column as (batch, 1, ..., 1) so it broadcasts over the rest of x.
    per_sample = (-1,) + (1,) * (x.dim() - 1)
    names = getattr(self, "lora_adapter_names", None) or list(self.r.keys())

    # Add into a copy of the base output rather than modifying it in place.
    result = result.clone()
    autocast = torch.is_autocast_enabled()

    outputs = []
    for adpt in self.r.keys():
        if self.r[adpt] <= 0:
            # A rank-0 adapter contributes nothing (and may have no lora_A/lora_B entry).
            continue
        if autocast:
            # Autocast picks the matmul dtype: feed x unchanged and keep the output dtype,
            # as peft's unpatched forward does.
            out = self.lora_B[adpt](self.lora_A[adpt](self.lora_dropout[adpt](x)))
        else:
            _x = x.to(self.lora_A[adpt].weight.dtype)
            out = self.lora_B[adpt](self.lora_A[adpt](self.lora_dropout[adpt](_x))).to(result.dtype)
        out = out * self.scaling[adpt]
        outputs.append(out * weights[:, names.index(adpt)].view(per_sample))
    if outputs:
        result += torch.sum(torch.stack(outputs), 0)
    return result

def load_multiple_loras(self, lora_dict):
    """Monkey patch method for LlamaForCausalLM: load several LoRA adapters into this model.

    Each adapter is loaded with ``peft.PeftModel.from_pretrained(self, ...)``. The
    returned ``PeftModel`` wrappers are discarded: the notebook relies on the LoRA
    layers being injected into ``self`` in place. The adapter names are appended, in
    load order and without duplicates, to ``self.lora_adapter_names``. The patched LoRA
    layers can then match the columns of ``lora_activation`` to adapters by name.

    Args:
      lora_dict (dict): adapter name -> Hugging Face Hub repo id (or local path).
          Example : {'guanaco':'Mikael110/llama-2-13b-guanaco-qlora', 'dolphin':'dfurman/llama-2-13b-dolphin-peft'}

    Returns: None
    """
    loaded = list(getattr(self, "lora_adapter_names", None) or [])
    for name, path in lora_dict.items():
        peft.PeftModel.from_pretrained(
            self,
            path,
            adapter_name=name,
            force_download=False
        )
        if name not in loaded:
            loaded.append(name)
        # Update after every load so the recorded order reflects any partial progress.
        self.lora_adapter_names = loaded

def set_lora_activations(self, lora_activation, adapter_names=None):
    """Monkey patch method for LlamaForCausalLM: set the per-batch-element LoRA weights.

    Stores ``lora_activation`` (and the column -> adapter names) on every
    ``peft.tuners.lora.LoraLayer`` found anywhere in the model. The patched LoRA
    forwards read them there.

    Args:
      lora_activation (array): rectangular 2-dimensional array
        - first dimension being the batch element
        - second dimension being the weight per lora adapter
        Example for 3 lora layers, batch of two : [[1st_lora_adapter_1st_element, 2nd_lora_adapter_1st_element, 3rd_lora_adapter_1st_element],
                                                   [1st_lora_adapter_2nd_element, 2nd_lora_adapter_2nd_element, 3rd_lora_adapter_2nd_element]]
      adapter_names (list of str, optional): adapter name for each column of
        ``lora_activation``. Defaults to ``self.lora_adapter_names`` when that attribute
        exists. When no names are available, the layers match columns to adapters by
        position.

    Raises:
      ValueError: if ``lora_activation`` is empty or ragged, or if its width does not
        match the number of adapter names.

    Returns: None
    """
    widths = {len(row) for row in lora_activation}
    if not widths:
        raise ValueError("lora_activation must contain at least one row")
    if len(widths) != 1:
        raise ValueError("all rows of lora_activation must have the same length")

    if adapter_names is None:
        adapter_names = getattr(self, "lora_adapter_names", None)
    if adapter_names is not None:
        adapter_names = list(adapter_names)
        (width,) = widths
        if width != len(adapter_names):
            raise ValueError(
                f"lora_activation has {width} columns but {len(adapter_names)} adapter names were given"
            )

    # Scan the whole model, not just self.model, so LoRA layers outside the decoder stack
    # (if an adapter targets one) also receive the activations.
    for module in self.modules():
        if isinstance(module, peft.tuners.lora.LoraLayer):
            module.lora_activation = lora_activation
            module.lora_adapter_names = adapter_names

# Route peft's LoRA layers (plain, 8-bit and 4-bit) through the patched forward functions,
# and expose the two helper methods on LlamaForCausalLM. Re-running this cell is harmless.
setattr(peft.tuners.lora.Linear, "forward", Linear_forward)
setattr(peft.tuners.lora.Linear8bitLt, "forward", Linear8bitLt_forward)
setattr(peft.tuners.lora.Linear4bit, "forward", Linear4bit_forward)
setattr(LlamaForCausalLM, "set_lora_activations", set_lora_activations)
setattr(LlamaForCausalLM, "load_multiple_loras", load_multiple_loras)

In [None]:
# The columns of `lora_activation` in the inference cells follow this load order:
# guanaco, dolphin, codealpaca.
model.load_multiple_loras(lora_dict={
    'guanaco': 'Mikael110/llama-2-13b-guanaco-qlora',
    'dolphin': 'dfurman/llama-2-13b-dolphin-peft',
    'codealpaca': 'layoric/llama-2-13b-code-alpaca-qlora',
})

# Inference

The code below shows how to use the patched model to run batch inference with a different LoRA weighting for each input.

Llama-2 seems to struggle with padded inputs. You will get better results when generating from a single prompt (either batch size 1, or the same prompt repeated across the batch).

In [None]:
# Code prompt
code_prompt = '### Instruction:\nWrite python code to call a REST Webservice at path /user and POST a following json data name : Mick, lastname : jaeger, login : bigboy, password : password123.\n\
Start by explaining step by step how you will proceed, then write the code, commenting each part.\n\n\
### Response:\n'
# Orca prompt
orca_prompt = 'You are an AI assistant that helps people find information. User will you give you a question. Your task is to answer as faithfully as you can. While answering think step-bystep and justify your answer.\n\
### Instruction: John was a terrible writer. To practice, his teacher suggest that he consider people he knows and do what? Options: - write novels - advertising firm - write letter - write notes - write poems Let\'s think now! Step-by-step reasoning:\n\n\
### Response:'
# Guanaco prompt
guanaco_prompt = '### Instruction:\nWhat weights more, 1kg of feathers or 1kg of lead?\n\
Reason step by step and give your answer.\n\n\
### Response:\n'

# The prompts differ in length, so they get left-padded. Passing the attention mask keeps
# the model from attending to the pad tokens. Inputs go to the model's device.
inputs = tokenizer([guanaco_prompt, orca_prompt, code_prompt], return_tensors="pt", padding=True).to(model.device)
# One row per prompt, one column per adapter (guanaco, dolphin, codealpaca):
# each prompt is routed to exactly one adapter.
lora_activation = [[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1]]
model.set_lora_activations(lora_activation=lora_activation)

generation_output = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    pad_token_id=tokenizer.pad_token_id,
    max_new_tokens=128,
    temperature=0.01,
    repetition_penalty=1.1
)

for i, sequence in enumerate(generation_output):
    if i:
        print('----')
    print(tokenizer.decode(sequence, skip_special_tokens=True))

In [None]:
# Guanaco prompt
guanaco_prompt = '### Instruction:\nWhat weights more, 1kg of feathers or 1kg of lead?\n\
Reason step by step and give your answer.\n\n\
### Response:\n'

# The same prompt twice: no padding is needed, so both rows see identical inputs and only
# the adapter mix differs. Row 0 uses guanaco alone; row 1 adds dolphin at weight 0.75.
inputs = tokenizer([guanaco_prompt, guanaco_prompt], return_tensors="pt").to(model.device)
lora_activation = [[1, 0, 0],
                   [1, 0.75, 0]]
model.set_lora_activations(lora_activation=lora_activation)

generation_output = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    pad_token_id=tokenizer.pad_token_id,
    max_new_tokens=64,
    temperature=0.01,
    repetition_penalty=1.1
)
for i, sequence in enumerate(generation_output):
    if i:
        print('----')
    print(tokenizer.decode(sequence, skip_special_tokens=True))