In [1]:
# 定义简易EmbeddingModel

import os
import torch
import logging
from typing import Optional
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger("EmbeddingModel")


class EmbeddingModel(torch.nn.Module):
    """Backbone encoder plus optional MRL projection heads that produce sentence embeddings.

    The backbone is loaded with ``AutoModel.from_pretrained``. For every entry of
    ``mrl_dims`` a projection head mapping ``hidden_size -> mrl_dim`` is registered
    as the attribute ``mrl_proj_{mrl_dim}``; its weights are loaded from
    ``{model_path}/mrl_proj_{mrl_dim}.pt`` when that file exists.
    """

    # Attribute / weight-file name pattern of a projection head.
    MRL_NAME = "mrl_proj_{mrl_dim}"

    def __init__(self, model_path, model_kwargs, mrl_dims: Optional[list[int]] = None, mrl_2layer_proj: bool = True):
        """
        Args:
            model_path: checkpoint directory; also searched for ``mrl_proj_*.pt`` files.
            model_kwargs: keyword arguments forwarded to ``AutoModel.from_pretrained``.
            mrl_dims: output dimensions to build projection heads for. ``None``
                (the default) means ``[64]``.
            mrl_2layer_proj: build each head as Linear -> SiLU -> Linear instead
                of a single Linear layer.
        """
        super().__init__()
        model = AutoModel.from_pretrained(model_path, **model_kwargs)
        self.model = model

        # Resolve the default here instead of using a shared mutable default argument.
        if mrl_dims is None:
            mrl_dims = [64]

        hidden_size = model.config.hidden_size
        for mrl_dim in mrl_dims:
            mrl_name = EmbeddingModel.MRL_NAME.format(mrl_dim=mrl_dim)
            if mrl_2layer_proj:
                inner_size = (hidden_size + mrl_dim) // 2
                projection = torch.nn.Sequential(
                    torch.nn.Linear(hidden_size, inner_size),
                    torch.nn.SiLU(),
                    torch.nn.Linear(inner_size, mrl_dim),
                )
            else:
                projection = torch.nn.Linear(hidden_size, mrl_dim)
            # Module.to() converts in place, so the registered head follows the backbone.
            projection.to(device=model.device, dtype=model.dtype)
            setattr(self, mrl_name, projection)

            mrl_weight_path = os.path.join(model_path, f"{mrl_name}.pt")

            if os.path.exists(mrl_weight_path):
                # The file holds a plain state dict, so restrict unpickling to
                # tensors/containers instead of arbitrary Python objects.
                weight = torch.load(mrl_weight_path, map_location=model.device, weights_only=True)
                projection.load_state_dict(weight)
            else:
                logger.warning(f"MRL projection weight for {mrl_name} not found! Use random initialization instead.")

    def get_mrl_proj(self, mrl_dim: Optional[int] = None):
        """Return the projection head for ``mrl_dim``.

        A falsy ``mrl_dim`` (``None``, ``False`` or ``0``) means "no projection"
        and yields an identity function. Otherwise the head registered under
        ``MRL_NAME`` is returned; ``AttributeError`` is raised if no such head
        was built.
        """
        if not mrl_dim:
            return lambda x: x
        return getattr(self, EmbeddingModel.MRL_NAME.format(mrl_dim=mrl_dim))

    def _pool(
        self,
        last_hidden_states: torch.Tensor,
        pooling_method: str = "last",
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        """
        Pool the last_hidden_states along the sequence dimension. Handle packed inputs as well.

        Args:
            last_hidden_states: ``(batch_size, seq_len, d_embed)`` for padded
                inputs, ``(1, all_seq_len, d_embed)`` for packed inputs.
            pooling_method: ``"first"``, ``"last"``, ``"mean"`` or (padded
                inputs only) ``"weighted-mean"``.
            attention_mask: ``(batch_size, seq_len)`` mask of padded inputs.
                ``None`` means every position is a real token.
            position_ids: ``(1, all_seq_len)`` position ids of packed inputs;
                a value of 0 marks the first token of a sequence. Passing it
                selects the packed code path.

        Returns:
            One embedding per sequence, shape ``(num_sequences, d_embed)``.

        Raises:
            NotImplementedError: for an unsupported ``pooling_method``.
        """
        if position_ids is None:
            return self._pool_padded(last_hidden_states, pooling_method, attention_mask)
        return self._pool_packed(last_hidden_states, pooling_method, position_ids)

    @staticmethod
    def _pool_padded(last_hidden_states, pooling_method, attention_mask):
        """Pool padded inputs of shape (batch_size, seq_len, d_embed), honouring the mask."""
        if pooling_method in ("first", "last"):
            if attention_mask is None:
                return last_hidden_states[:, 0 if pooling_method == "first" else -1]
            # Locate the first / last attended token of every row so the result
            # does not depend on the padding side. argmax returns the first
            # maximal index; for "last" the mask is flipped to search from the end.
            # With left padding (or no padding) "last" still resolves to index -1.
            mask = (attention_mask != 0).long()
            if pooling_method == "first":
                token_idx = mask.argmax(dim=1)
            else:
                token_idx = mask.shape[1] - 1 - mask.flip(dims=[1]).argmax(dim=1)
            batch_idx = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
            return last_hidden_states[batch_idx, token_idx.to(last_hidden_states.device)]

        if pooling_method in ("mean", "weighted-mean"):
            if attention_mask is None:
                # No mask given: treat every position as a real token.
                attention_mask = torch.ones(
                    last_hidden_states.shape[:2], dtype=torch.long, device=last_hidden_states.device
                )
            if pooling_method == "mean":
                masked_states = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
                return masked_states.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
            # weighted-mean: the k-th real token of a row gets weight k, padding gets 0.
            position_weights = attention_mask * attention_mask.cumsum(dim=1)
            weighted_sum = torch.sum(last_hidden_states * position_weights.unsqueeze(-1).float(), dim=1)
            weight_total = position_weights.sum(dim=1, keepdim=True).float()
            return weighted_sum / weight_total

        raise NotImplementedError(f"Pooling_method {pooling_method} not implemented!")

    @staticmethod
    def _pool_packed(last_hidden_states, pooling_method, position_ids):
        """Pool packed inputs: (1, all_seq_len, d_embed) states with (1, all_seq_len) position ids."""
        # Drop the leading batch dimension of size 1 -> (all_seq_len, d_embed) / (all_seq_len,)
        last_hidden_states = last_hidden_states[0]
        position_ids = position_ids[0]

        # A position id of 0 marks the first token of each packed sequence.
        sequence_start_pos = position_ids == 0

        if pooling_method == "first":
            return last_hidden_states[sequence_start_pos]
        if pooling_method not in ("last", "mean"):
            raise NotImplementedError(f"Pooling method {pooling_method} is currently not supported for packed inputs!")

        # Each sequence spans [start, end); the end of one sequence is the start
        # of the next, and the final sequence ends at the total length.
        starts = sequence_start_pos.nonzero(as_tuple=True)[0]
        total_length = torch.tensor([len(position_ids)], dtype=starts.dtype, device=starts.device)
        ends = torch.cat([starts[1:], total_length])

        if pooling_method == "last":
            return last_hidden_states[ends - 1]

        # mean: accumulate the hidden states of each sequence into its own row ...
        embedding = torch.zeros_like(last_hidden_states[sequence_start_pos])
        segment_ids = sequence_start_pos.cumsum(-1) - 1
        embedding.index_add_(0, segment_ids, last_hidden_states)
        # ... then divide by the sequence lengths.
        lengths = ends - starts
        return embedding / lengths.unsqueeze(-1)

    def encode(self, inputs, mrl_dim: int = 64, normalize: bool = True, pooling_method: str = "last"):
        """Encode a batch into one embedding per sequence.

        Args:
            inputs: mapping of keyword arguments for the backbone. If it
                contains ``"position_ids"`` the batch is treated as packed;
                otherwise it is treated as padded and ``"attention_mask"``
                (optional) is used for pooling.
            mrl_dim: which projection head to apply; a falsy value skips the projection.
            normalize: L2-normalize each embedding.
            pooling_method: forwarded to ``_pool``.

        Returns:
            Tensor of shape ``(batch_size, d_mrl)`` (``d_embed`` when no
            projection is applied).
        """
        # padded: (batch_size, seq_len, d_embed); packed: (1, all_sequence_length, d_embed)
        hidden_states = self.model(**inputs).last_hidden_state

        # (batch_size, d_embed)
        if "position_ids" in inputs:
            # NOTE: packing case
            embedding = self._pool(hidden_states, pooling_method=pooling_method, position_ids=inputs["position_ids"])
        else:
            # NOTE: padding case; a missing attention mask means "no padding"
            embedding = self._pool(
                hidden_states, pooling_method=pooling_method, attention_mask=inputs.get("attention_mask")
            )

        # NOTE: transform to a given dimension
        # (batch_size, d_mrl)
        embedding = self.get_mrl_proj(mrl_dim)(embedding)

        if normalize:
            embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
        return embedding

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the checkpoint through EmbeddingModel

model_path = "/mnt/bn/search-douyin-rank-yg/all_data_from_lf/peitian_data/data/outputs/qwen_0.5b-listwise-d64_2layer-v4-g2-fullparam-intra/checkpoint-26793"
model_kwargs = dict(
    # use FlashAttention-2
    attn_implementation="flash_attention_2",
    # put the whole model on the GPU
    device_map={"": "cuda"},
    # run in fp16
    torch_dtype=torch.float16,
)
# mrl_dims and mrl_2layer_proj are left at their defaults ([64] and True).
embedding_model = EmbeddingModel(model_path, model_kwargs=model_kwargs)

Some weights of the model checkpoint at /mnt/bn/search-douyin-rank-yg/all_data_from_lf/peitian_data/data/outputs/qwen_0.5b-listwise-d64_2layer-v4-g2-fullparam-intra/checkpoint-26793 were not used when initializing Qwen2Model: ['mrl_proj_64.0.bias', 'mrl_proj_64.0.weight', 'mrl_proj_64.2.bias', 'mrl_proj_64.2.weight']
- This IS expected if you are initializing Qwen2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Qwen2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  weight = torch.load(mrl_weight_path, map_location=model.device)


In [3]:
# Load the same checkpoint through peitian's reference code

import os
import sys
# One-time setup: append ".." to sys.path and move the working directory to the
# parent directory. The membership check keeps both steps from running again
# when the cell is re-executed in the same kernel.
# NOTE(review): presumably the notebook lives one level below the directory that
# contains the `src` package — confirm. ".." is a relative sys.path entry, so
# after the chdir it no longer refers to the directory it was added for.
if ".." not in sys.path:
    sys.path.append("..")
    os.chdir("..")

from src import ModelArgs, get_model_and_tokenizer

# Default model arguments; the checkpoint and MRL settings are overridden below.
args = ModelArgs()

# Reference model built from the same `model_path` and MRL configuration as
# `embedding_model`; the second return value (the tokenizer) is discarded.
model_ref, _ = get_model_and_tokenizer(
    args, 
    model_name_or_path=model_path,
    mrl_dims=[64],
    mrl_2layer_proj=True,
    device="cuda"

    # packing is enabled by default
)

[32m2024-12-03 17:17:07.018[0m | [1mINFO    [0m | [36msrc[0m:[36mget_model_and_tokenizer[0m:[36m36[0m - [1mLoading model and tokenizer from /mnt/bn/search-douyin-rank-yg/all_data_from_lf/peitian_data/data/outputs/qwen_0.5b-listwise-d64_2layer-v4-g2-fullparam-intra/checkpoint-26793...[0m
Some weights of the model checkpoint at /mnt/bn/search-douyin-rank-yg/all_data_from_lf/peitian_data/data/outputs/qwen_0.5b-listwise-d64_2layer-v4-g2-fullparam-intra/checkpoint-26793 were not used when initializing Qwen2Model: ['mrl_proj_64.0.bias', 'mrl_proj_64.0.weight', 'mrl_proj_64.2.bias', 'mrl_proj_64.2.weight']
- This IS expected if you are initializing Qwen2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Qwen2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertFor

In [13]:
# 测试一致性

def apply_template(text, tokenizer, max_length: int = 512, is_query: bool = True, reserved_tokens: int = 25):
    """Truncate ``text`` to fit the token budget and wrap it in the prompt template.

    The text is tokenized with truncation to ``max_length - reserved_tokens``
    tokens, decoded back to a string (special tokens dropped), and inserted
    into the query or document template.

    Args:
        text: raw query or document text.
        tokenizer: object providing ``encode(text, max_length=..., truncation=...)``
            and ``decode(ids, skip_special_tokens=...)``.
        max_length: total token budget, template included.
        is_query: use the query template if true, otherwise the document template.
        reserved_tokens: number of tokens set aside for the template itself
            (25 by default, the value the templates below were sized for).

    Returns:
        The templated prompt string.
    """
    content_budget = max_length - reserved_tokens
    input_ids = tokenizer.encode(text, max_length=content_budget, truncation=True)
    text = tokenizer.decode(input_ids, skip_special_tokens=True)

    if is_query:
        template = "Query: {text}\nUse one word to summarize the query's relevant information. The word is: \""
    else:
        template = "Text: {text}\nUse one word to summarize the text's content. The word is: \""
    return template.format(text=text)

# NOTE: pad on the left so that index -1 directly yields the embedding of the last token
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")

inputs1 = "I love bytedance!" * 1000
inputs1 = apply_template(inputs1, tokenizer, max_length=512)

inputs2 = "I love huawei" * 100
inputs2 = apply_template(inputs2, tokenizer, max_length=512)

device = embedding_model.model.device

# # padding
# inputs = tokenizer([inputs1, inputs2], padding=True, return_tensors="pt").to(device)

# # packing: concatenate all sequences into one row; position ids restart at 0
# # for every sequence, which is how the sequence boundaries are recovered later
input_ids = tokenizer([inputs1, inputs2]).input_ids
position_ids = [list(range(len(x))) for x in input_ids]
inputs = {
    # flatten with a comprehension; sum(list_of_lists, []) copies the list on every step
    "input_ids": torch.tensor([[token for seq in input_ids for token in seq]], device=device),
    "position_ids": torch.tensor([[pos for seq in position_ids for pos in seq]], device=device),
}

print(inputs.keys())

# Inference only: skip building the autograd graph for both forward passes.
with torch.no_grad():
    a = embedding_model.encode(inputs)
    b = model_ref._encode(inputs)
print(a == b)

# Make the consistency check explicit instead of relying on reading the printed tensor.
assert torch.equal(a, b), "EmbeddingModel output differs from the reference model output"

dict_keys(['input_ids', 'attention_mask'])
tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True]], device='cuda:0')


In [None]:
# doc_info => text
import json

def parse_text_from_douyin_doc_info(doc_info):
    """Flatten a JSON-encoded doc_info object into tagged text.

    For each of the fields title, username, music, poi, challenge, ocr and asr
    (in that order) one ``<field>value`` segment followed by a blank line is
    emitted. The value is stripped of surrounding whitespace. A field that is
    missing or JSON ``null`` yields an empty value; a non-string value is
    converted with ``str``.

    Args:
        doc_info: JSON string encoding an object.

    Returns:
        The concatenated tagged text.

    Raises:
        json.JSONDecodeError: if ``doc_info`` is not valid JSON.
    """
    doc = json.loads(doc_info)
    field_names = ("title", "username", "music", "poi", "challenge", "ocr", "asr")

    segments = []
    for field_name in field_names:
        raw_value = doc.get(field_name)
        # null / missing -> empty; anything else is stringified before stripping,
        # so a numeric or null field no longer raises AttributeError on .strip()
        field_value = "" if raw_value is None else str(raw_value).strip()
        segments.append(f"<{field_name}>{field_value}\n\n")
    return "".join(segments)

In [None]:
# Merge the LoRA adapter into the base model and save the merged checkpoint

lora_path = "/mnt/bn/search-douyin-rank-yg/all_data_from_lf/peitian_data/data/outputs/qwen_0.5b-listwise-d64_2layer-v4_0.1-g2-bs64/checkpoint-6304"

# NOTE: lora_unload defaults to True, i.e. get_model_and_tokenizer merges the
# LoRA weights into the model automatically, so it is not passed explicitly.
model, tokenizer = get_model_and_tokenizer(
    args,
    mrl_dims=[64],
    mrl_2layer_proj=True,
    lora=lora_path,
    # Loading the model on CPU requires sdpa/eager attention, with packing disabled.
    attn_impl="sdpa",
    packing=False,
)

# Some CUDA errors may show up here; per the original note they are harmless and can be ignored.
dest = f"{lora_path}-merge_lora"
model.save_pretrained(dest)
tokenizer.save_pretrained(dest)