In [1]:
import torch.nn as nn
import torch

In [4]:
# Small linear layer used to try out the weight initialisers: 2 inputs -> 4 outputs.
linear = nn.Linear(in_features=2, out_features=4)

In [10]:
# Display the layer's weight Parameter.
# NOTE(review): this cell has execution count In [10] while the zeroing cell below is
# In [9], so the all-zero values recorded here presumably come from re-running this
# cell after the zeroing -- confirm with Restart & Run All.
linear.weight

Parameter containing:
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], requires_grad=True)

In [9]:
# Fill the weight with zeros; the call's return value is what the cell displays.
nn.init.zeros_(linear.weight)

Parameter containing:
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], requires_grad=True)

In [16]:
class MOELoraLayer(nn.Module):
    """Mixture-of-experts LoRA adapter with a softmax router.

    Every expert is a low-rank pair ``B_i(A_i(x))``. A bias-free linear router
    produces one weight per expert (softmax over the expert axis) and the layer
    returns ``x + sum_i w_i * B_i(A_i(x))``.

    Args:
        dim: size of the last dimension of the input (and of the output).
        r: rank of the low-rank bottleneck.
        expert_num: number of experts.
        hydra: when True a single down-projection ``lora_A`` is shared by all
            experts and only the ``lora_B`` up-projections are per-expert.
    """

    def __init__(self, dim, r, expert_num, hydra=False):
        super().__init__()
        self.expert_num = expert_num
        self.hydra = hydra  # hydra lora: shared A, per-expert B

        self.router = nn.Linear(dim, expert_num, bias=False)

        if hydra:
            self.lora_A = nn.Linear(dim, r, bias=False)
        else:
            self.lora_A = nn.ModuleList(
                [nn.Linear(dim, r, bias=False) for _ in range(expert_num)]
            )

        self.lora_B = nn.ModuleList(
            [nn.Linear(r, dim, bias=False) for _ in range(expert_num)]
        )

        # Zero-initialise every B so the layer starts out as the identity.
        for linear in self.lora_B:
            nn.init.zeros_(linear.weight)

    def forward(self, x: torch.Tensor):
        """Return ``x`` plus the router-weighted sum of all expert updates.

        Args:
            x: tensor whose last dimension is ``dim``; any number of leading
                dimensions is accepted.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        # Softmax is evaluated in float32 and cast back to the input dtype.
        route_weight = nn.functional.softmax(self.router(x), dim=-1, dtype=torch.float32).to(x.dtype)

        # In hydra mode the down-projection is shared, so compute it once.
        shared_down = self.lora_A(x) if self.hydra else None

        # try lora_alpha
        # Accumulate the expert updates separately: every expert must see the
        # ORIGINAL input, otherwise the result depends on the expert order.
        delta = torch.zeros_like(x)
        for i in range(self.expert_num):
            down = shared_down if self.hydra else self.lora_A[i](x)
            # `[..., i:i + 1]` keeps a trailing axis for broadcasting and works
            # for inputs of any rank.
            delta = delta + route_weight[..., i:i + 1] * self.lora_B[i](down)
        return x + delta

In [17]:
# Tiny non-hydra instance for inspection: hidden size 10, rank 2, five experts.
moelora = MOELoraLayer(dim=10, r=2, expert_num=5)

In [29]:
# List the parameter names the layer registers (router, then each expert's A and B).
moelora.state_dict().keys()

odict_keys(['router.weight', 'lora_A.0.weight', 'lora_A.1.weight', 'lora_A.2.weight', 'lora_A.3.weight', 'lora_A.4.weight', 'lora_B.0.weight', 'lora_B.1.weight', 'lora_B.2.weight', 'lora_B.3.weight', 'lora_B.4.weight'])

In [24]:
# Fixed seed so the mixture experiments below give the same numbers on every run.
torch.manual_seed(0)

x = torch.randn(3,4)
# `a` is a down-projection shared by both experts; `a1`/`a2` are per-expert ones.
a = torch.randn(4,2)
a1= torch.randn(4,2)
a2= torch.randn(4,2)

# Per-expert up-projections.
b1 = torch.randn(2,4)
b2= torch.randn(2,4)

relu = nn.ReLU()

In [20]:
# Mix two experts by scaling each expert's output: 0.4 for (a1, b1) and 0.6 for (a2, b2).
(relu(x @ a1) @ b1)*0.4 + (relu(x @ a2) @ b2)*0.6

tensor([[ 0.0131,  0.8114, -0.1674,  0.1921],
        [ 0.3173,  1.5902,  1.2141,  0.6184],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])

In [21]:
# Same mixture, but with the weights folded into the B matrices instead of the outputs.
relu(x @ a1) @ (b1*0.4) + relu(x @ a2) @ (b2*0.6)

tensor([[ 0.0131,  0.8114, -0.1674,  0.1921],
        [ 0.3173,  1.5902,  1.2141,  0.6184],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])

In [25]:
# Shared down-projection `a` with per-expert B matrices, weighted 0.4 / 0.6 and summed.
relu(x @ a) @ ( b1*0.4) + relu(x @ a) @ (b2*0.6)

tensor([[-0.5762, -0.2077, -0.6032, -0.1846],
        [-0.3817, -0.1376, -0.3996, -0.1223],
        [ 0.0645, -0.0288,  0.0293, -0.0455]])

In [26]:
# With a shared `a`, the weighted B matrices are merged into one matrix before the matmul.
relu(x @ a) @ ( b1*0.4 + b2*0.6)

tensor([[-0.5762, -0.2077, -0.6032, -0.1846],
        [-0.3817, -0.1376, -0.3996, -0.1223],
        [ 0.0645, -0.0288,  0.0293, -0.0455]])

In [5]:
a = [1,2,3,4]
# torch.tensor with an explicit dtype replaces the legacy torch.FloatTensor constructor.
ta = torch.tensor(a, dtype=torch.float32)

In [8]:
# Check the element type of the tensor built from the Python list.
ta.dtype

torch.float32

In [12]:
# Pass `dim` explicitly: omitting it relies on a deprecated implicit choice and emits a warning.
nn.functional.softmax(ta, dim=0)*4

  nn.functional.softmax(ta)*4


tensor([0.1282, 0.3486, 0.9475, 2.5757])

In [14]:
# Prepend two singleton axes via None-indexing and show the resulting shape.
ta[None, None, :].shape

torch.Size([1, 1, 4])

# Math dataset: token-length statistics

In [2]:
from LLaMA3_lora_bias.llama import Tokenizer
# Directory holding the pretrained checkpoint; note the trailing slash.
model_path = '/home2/caojie/pretrain_models/Meta-Llama-3-8B/'
# The joined path therefore contains a doubled slash before the file name.
tokenizer_file = f"{model_path}/tokenizer.model"
tokenizer = Tokenizer(model_path=tokenizer_file)

In [17]:
import json

# Active dataset: math_14k. To analyse the commonsense set instead, point the path at
# '/home2/caojie/projects/LLM-Adapters/ft-training_set/commonsense_170k.json'.
# The encoding is given explicitly so the read does not depend on the platform default.
with open('/home2/caojie/projects/LLM-Adapters/ft-training_set/math_14k.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

In [14]:
# Peek at one record to see the available fields.
# NOTE(review): the recorded output is a true/false question and this cell's execution
# count (In [14]) precedes the loading cell's (In [17]), so it was presumably produced
# while the commonsense file was loaded rather than math_14k -- confirm by re-running.
data[1]

{'instruction': 'Please answer the following question with true or false, question: do good samaritan laws protect those who help at an accident?\n\nAnswer format: true/false',
 'input': '',
 'output': 'the correct answer is true',
 'answer': 'true'}

In [18]:
# Token counts per sample. `input_lens` holds the length of instruction + output
# concatenated; `output_lens` holds the length of the output alone. No BOS/EOS
# tokens are added.
input_lens=[]
output_lens=[]
# The loop variable is named `sample` so it does not overwrite the tensor `x`
# defined in an earlier cell.
for sample in data:
    input_lens.append(len(tokenizer.encode(sample['instruction']+sample['output'], bos=False, eos=False)))
    output_lens.append(len(tokenizer.encode(sample['output'], bos=False, eos=False)))
# Distinct labels so the two averages can be told apart in the output.
print(f'average tokens (instruction+output):{sum(input_lens)/len(input_lens)}')
print(f'average tokens (output only):{sum(output_lens)/len(output_lens)}')

average tokens:184.21255656921198
average tokens:127.19854895481646


In [3]:
import json
import re
import copy
import os
def extract_answer_number(dataset, sentence: str) -> float:
    """Return the last number that appears in ``sentence``.

    Thousands separators (commas) are stripped before matching, so ``"1,234"``
    is read as ``1234``.

    Args:
        dataset: dataset name (case-insensitive); must be one of multiarith,
            addsub, singleeq, gsm8k or svamp.
        sentence: generated text to scan. ``None`` or an empty string is
            treated as "no answer".

    Returns:
        The last number found as a float, or ``float('inf')`` when the text
        contains no number.

    Raises:
        NotImplementedError: if ``dataset`` is not a supported numeric dataset.
    """
    dataset = dataset.lower()
    if dataset not in ["multiarith", "addsub", "singleeq", "gsm8k", "svamp"]:
        raise NotImplementedError(' not support dataset: {}'.format(dataset))

    # A missing generation counts as "no number found" instead of crashing.
    if not sentence:
        return float('inf')

    sentence = sentence.replace(',', '')
    pred = re.findall(r'-?\d+\.?\d*', sentence)
    if not pred:
        return float('inf')
    # Every match of the pattern is a valid float literal, so no fallback is needed.
    return float(pred[-1])


def extract_answer_letter(sentence: str) -> str:
    """Return the first capital letter A-E found in ``sentence``.

    Args:
        sentence: generated text to scan. ``None`` or an empty string is
            treated as "no answer".

    Returns:
        The first of ``A``/``B``/``C``/``D``/``E`` that occurs in the text, or
        ``''`` when none is present.
    """
    if not sentence:
        return ''
    # NOTE(review): the pattern also matches capitals inside words (e.g. the
    # "A" of "Answer"), which can mis-read the choice. It is kept unchanged so
    # scores stay comparable with earlier runs -- consider r'\b[A-E]\b'.
    match = re.search(r'A|B|C|D|E', sentence.strip())
    return match.group(0) if match else ''

# Score the saved generations of every test set and write a per-record check file
# (`<dataset>_predict_check.jsonl`) next to the prediction file.
test_dataset_l="AddSub AQuA gsm8k MultiArith SingleEq SVAMP"
for dataset in test_dataset_l.split():
    save_path=f"/home2/caojie/outputs/LLaMA3-1_lora_bias/math_14k/b32_epoch3_warme1_lorar8_loraQ,K,V,O,FFN_UP_blr6e-3_maxseq300_flashatt2False_bf16True_/{dataset}_predict_mingen120.jsonl"
    with open(save_path, 'r', encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        data_l = [json.loads(one) for one in f if one.strip()]
    total = len(data_l)
    correct = 0
    miss = 0.001  # absolute tolerance when comparing numeric answers

    check_path = os.path.join(os.path.dirname(save_path), f'{dataset}_predict_check.jsonl')
    # Open the check file once per dataset and overwrite it: appending made every
    # re-run of this cell duplicate all records.
    with open(check_path, 'w', encoding='utf-8') as check_f:
        # `record` (not `data`) so the dataset loaded in an earlier cell is not clobbered.
        for record in data_l:
            label = record.get('answer')
            flag = False
            if dataset.lower() in ['aqua']:
                # Multiple-choice set: compare the extracted letter with the label.
                predict = extract_answer_letter(record.get('generate'))
                if label == predict:
                    correct += 1
                    flag = True
            else:
                # Numeric sets: compare within the `miss` tolerance.
                if isinstance(label, str):
                    label = float(label)
                predict = extract_answer_number(dataset, record.get('generate'))
                if abs(label - predict) <= miss:
                    correct += 1
                    flag = True

            new_data = {**record, 'pred': predict, 'flag': flag}
            check_f.write(json.dumps(new_data, ensure_ascii=False) + '\n')

    # Guard against an empty prediction file.
    accuracy = correct / total if total else 0.0
    print(f'{dataset}: accuracy {correct}  {accuracy}')

            

AddSub: accuracy 356  0.9012658227848102
AQuA: accuracy 92  0.36220472440944884
gsm8k: accuracy 1024  0.7763457164518575
MultiArith: accuracy 591  0.985
SingleEq: accuracy 491  0.9665354330708661
SVAMP: accuracy 810  0.81
