[Engine] Apply the STS task to bge models (#673)
Zhenzhong1 committed Nov 29, 2023
1 parent 74e92aa commit 0c4c5ed
Showing 13 changed files with 461 additions and 115 deletions.

@@ -1,6 +1,6 @@
Step-by-Step
=======
-This document describes the end-to-end workflow for Huggingface model [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5) and [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) with Neural Engine backend.
+This document describes the end-to-end workflow for the Huggingface models [BAAI/bge-large-en-v1.5](https://huggingface.co/BAAI/bge-large-en-v1.5), [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) and [BAAI/bge-small-en-v1.5](https://huggingface.co/BAAI/bge-small-en-v1.5) with Neural Engine backend.

Here we take the [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5) as an example.

@@ -53,15 +53,15 @@ Neural Engine can parse ONNX model and Neural Engine IR.
We provide three `modes`: `accuracy`, `throughput`, and `latency`. In throughput mode, we use multiple instances with 4 cores per instance, occupying one socket.
You can run fp32 model inference by setting `precision=fp32`, command as follows:
```shell
-bash run_bge.sh --model=BAAI/bge-base-en-v1.5 --dataset=mrpc --precision=fp32 --mode=throughput
+bash run_bge.sh --model=BAAI/bge-base-en-v1.5 --precision=fp32 --mode=throughput
```
By setting `precision=int8` you can get a PTQ int8 model, and by setting `precision=bf16` a bf16 model (a bf16 example follows the int8 command below).
```shell
-bash run_bge.sh --model=BAAI/bge-base-en-v1.5 --dataset=mrpc --precision=int8 --mode=throughput
+bash run_bge.sh --model=BAAI/bge-base-en-v1.5 --precision=int8 --mode=throughput
```
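For instance, a bf16 throughput run uses the same script with only the precision flag changed:
```shell
bash run_bge.sh --model=BAAI/bge-base-en-v1.5 --precision=bf16 --mode=throughput
```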
By setting `precision=dynamic_int8`, you can benchmark the dynamically quantized int8 model.
```shell
-bash run_bge.sh --model=BAAI/bge-base-en-v1.5 --dataset=mrpc --precision=dynamic_int8 --mode=throughput
+bash run_bge.sh --model=BAAI/bge-base-en-v1.5 --precision=dynamic_int8 --mode=throughput
```


@@ -0,0 +1,177 @@
import copy
from typing import cast, List, Dict, Union

import numpy as np
import torch
from mteb import DRESModel
from optimum.onnxruntime import ORTModelForFeatureExtraction
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

class EngineBGEModel(DRESModel):

    def __init__(self,
                 model_name_or_path: str = None,
                 pooling_method: str = 'cls',
                 normalize_embeddings: bool = True,
                 query_instruction_for_retrieval: str = None,
                 batch_size: int = 256,
                 backend: str = 'Engine',
                 **kwargs) -> None:

        ort_model_path = kwargs.get("ort_model_path", None)
        engine_model = kwargs.get("engine_model", None)
        file_name = kwargs.get("file_name", None)
        # `backend` is an explicit keyword argument, so it never reaches
        # **kwargs; read the parameter directly.
        self.backend = backend
        self.pytorch_model = None

        if backend == 'Engine':
            self.engine_model = engine_model.graph
            print('The backend is Neural Engine, evaluate on: ', ort_model_path, file_name)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            # The PyTorch model is loaded as well to read config.hidden_size,
            # which is needed to reshape the engine output.
            self.pytorch_model = AutoModel.from_pretrained(model_name_or_path)
            self.hidden_size = self.pytorch_model.config.hidden_size
        elif backend == 'Pytorch':
            print('The backend is Pytorch, evaluate on: ', model_name_or_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.pytorch_model = AutoModel.from_pretrained(model_name_or_path)
        elif backend == 'Onnxruntime':
            print('The backend is Onnxruntime, evaluate on: ', ort_model_path, file_name)
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
            self.ort_model = ORTModelForFeatureExtraction.from_pretrained(ort_model_path, file_name=file_name)

        self.query_instruction_for_retrieval = query_instruction_for_retrieval
        self.normalize_embeddings = normalize_embeddings
        self.pooling_method = pooling_method
        self.batch_size = batch_size
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        if self.pytorch_model is not None:
            self.pytorch_model = self.pytorch_model.to(self.device)

            num_gpus = torch.cuda.device_count()
            if num_gpus > 1:
                self.pytorch_model = torch.nn.DataParallel(self.pytorch_model)
                self.batch_size = self.batch_size * num_gpus

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        '''
        Used for the retrieval task: if there is an instruction for
        queries, prepend it to each query text.
        '''
        if self.query_instruction_for_retrieval is not None:
            input_texts = ['{}{}'.format(self.query_instruction_for_retrieval, q) for q in queries]
        else:
            input_texts = queries
        return self.encode(input_texts)

    def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> np.ndarray:
        '''
        Used for the retrieval task: encode the corpus documents,
        joining title and text when the documents are dicts.
        '''
        if isinstance(corpus[0], dict):
            input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
        else:
            input_texts = corpus
        return self.encode(input_texts)

    @torch.no_grad()
    def encode(self, sentences: List[str], **kwargs) -> np.ndarray:
        if self.backend == 'Engine':
            ort_all_embeddings = []
            for start_index in tqdm(range(0, len(sentences), self.batch_size),
                                    desc="Batches",
                                    disable=len(sentences) < 256):
                sentences_batch = sentences[start_index:start_index + self.batch_size]
                # Tokenize twice: PyTorch tensors for pooling on `self.device`,
                # NumPy arrays for the Neural Engine inputs.
                inputs = self.tokenizer(
                    sentences_batch,
                    padding=True,
                    truncation=True,
                    return_tensors='pt',
                    max_length=512,
                ).to(self.device)

                ort_inputs = self.tokenizer(sentences_batch,
                                            padding=True,
                                            truncation=True,
                                            max_length=512,
                                            return_tensors="np")

                input_ids = np.ascontiguousarray(ort_inputs['input_ids'])
                token_type_ids = np.ascontiguousarray(ort_inputs['token_type_ids'])
                attention_mask = np.ascontiguousarray(ort_inputs['attention_mask'])

                engine_input = [input_ids, token_type_ids, attention_mask]
                # Deep-copy the result so later batches don't overwrite it
                # (the engine may reuse its output buffers across calls).
                result = copy.deepcopy(self.engine_model.inference(engine_input))
                ort_last_hidden_state = torch.tensor(result['last_hidden_state:0']).reshape(
                    input_ids.shape[0], input_ids.shape[1], self.hidden_size)
                ort_embeddings = self.pooling(ort_last_hidden_state, inputs['attention_mask'])
                if self.normalize_embeddings:
                    ort_embeddings = torch.nn.functional.normalize(ort_embeddings, dim=-1)
                ort_embeddings = cast(torch.Tensor, ort_embeddings)
                ort_all_embeddings.append(ort_embeddings.cpu().numpy())
            return np.concatenate(ort_all_embeddings, axis=0)
        elif self.backend == 'Pytorch':
            self.pytorch_model.eval()
            all_embeddings = []
            for start_index in tqdm(range(0, len(sentences), self.batch_size),
                                    desc="Batches",
                                    disable=len(sentences) < 256):
                sentences_batch = sentences[start_index:start_index + self.batch_size]
                inputs = self.tokenizer(
                    sentences_batch,
                    padding=True,
                    truncation=True,
                    return_tensors='pt',
                    max_length=512,
                ).to(self.device)
                last_hidden_state = self.pytorch_model(**inputs, return_dict=True).last_hidden_state
                embeddings = self.pooling(last_hidden_state, inputs['attention_mask'])
                if self.normalize_embeddings:
                    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
                embeddings = cast(torch.Tensor, embeddings)
                all_embeddings.append(embeddings.cpu().numpy())
            return np.concatenate(all_embeddings, axis=0)

        elif self.backend == 'Onnxruntime':
            ort_all_embeddings = []
            for start_index in tqdm(range(0, len(sentences), self.batch_size),
                                    desc="Batches",
                                    disable=len(sentences) < 256):
                sentences_batch = sentences[start_index:start_index + self.batch_size]
                inputs = self.tokenizer(
                    sentences_batch,
                    padding=True,
                    truncation=True,
                    return_tensors='pt',
                    max_length=512,
                ).to(self.device)

                ort_inputs = self.tokenizer(sentences_batch,
                                            padding=True,
                                            truncation=True,
                                            max_length=512,
                                            return_tensors="np")

                ort_last_hidden_state = torch.tensor(self.ort_model(**ort_inputs).last_hidden_state)
                ort_embeddings = self.pooling(ort_last_hidden_state, inputs['attention_mask'])
                if self.normalize_embeddings:
                    ort_embeddings = torch.nn.functional.normalize(ort_embeddings, dim=-1)
                ort_embeddings = cast(torch.Tensor, ort_embeddings)
                ort_all_embeddings.append(ort_embeddings.cpu().numpy())
            return np.concatenate(ort_all_embeddings, axis=0)

    def pooling(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor = None):
        if self.pooling_method == 'cls':
            # Take the [CLS] token representation.
            return last_hidden_state[:, 0]
        elif self.pooling_method == 'mean':
            # Average over non-padding tokens only.
            s = torch.sum(last_hidden_state * attention_mask.unsqueeze(-1).float(), dim=1)
            d = attention_mask.sum(dim=1, keepdim=True).float()
            return s / d
        raise ValueError(f"Unsupported pooling method: {self.pooling_method}")
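As an illustrative sketch (not part of this commit), the wrapper above plugs straight into mteb's evaluation loop; the task name and output folder below are arbitrary choices:

```python
# Illustrative only: evaluate the wrapper on a single STS task via mteb.
# Assumes mteb==1.1.1 (as pinned in requirements.txt).
from mteb import MTEB

model = EngineBGEModel(model_name_or_path="BAAI/bge-base-en-v1.5",
                       pooling_method='cls',
                       normalize_embeddings=True,
                       backend='Pytorch')  # 'Engine' would also need engine_model=...
evaluation = MTEB(tasks=["STS12"])
evaluation.run(model, output_folder="results/bge-base-en-v1.5")
```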
@@ -20,15 +20,46 @@
from tqdm import tqdm
from datasets import load_metric
from executor_dataloader import DataLoader
from intel_extension_for_transformers.llm.runtime.deprecated.compile import compile, autocast
from intel_extension_for_transformers.llm.runtime.deprecated.compile.graph import Graph
import sys
import os
import logging

common_dir = os.path.join(sys.path[0], "../../../../neural_engine_utils/")
sys.path.append(common_dir)
-from common import (log, DummyDataLoader, compute_performance, Neural_Engine_base)
+from common import (log, DummyDataLoader, compute_performance)


-class Neural_Engine(Neural_Engine_base):


# set log file
def set_log_file(log, log_file):
    file_handler = logging.FileHandler(log_file, 'w')
    formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s', "%Y-%m-%d %H:%M:%S")
    file_handler.setFormatter(formatter)
    log.addHandler(file_handler)


# Compilation pattern config for bge models: switch off the fused
# MultiHeadAttention pattern.
bge_pattern_config = {
    'pattern_switch': {
        'MultiHeadAttention': False,
    }
}


class Neural_Engine_bge:

    def __init__(self, model_path, log_file, cast_type="native"):
        set_log_file(log, log_file)
        # Compile the ONNX model into a Neural Engine graph under the
        # requested cast type (e.g. "native").
        with autocast(cast_type):
            self.graph = compile(model_path, bge_pattern_config)
        self.log_file = log_file

    def accuracy(self, batch_size, seq_len, dataset_name, task_name, data_dir, tokenizer_dir):
        pass

    def performance(self, batch_size, seq_len, iteration, warm_up):
        pass
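As an illustrative sketch (not part of this commit), compiling a bge ONNX export and running one batch through the resulting graph; the ONNX path and sample sentence are placeholders, and the input layout and output key mirror the encode() wrapper earlier in this commit:

```python
# Illustrative only: compile a bge ONNX model and run one batch.
# inference() and 'last_hidden_state:0' match the encode() wrapper above.
import numpy as np
from transformers import AutoTokenizer

engine = Neural_Engine_bge("./model/bge-base-en-v1.5.onnx", "engine.log", cast_type="native")
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
enc = tokenizer(["hello world"], padding=True, truncation=True,
                max_length=512, return_tensors="np")
engine_input = [np.ascontiguousarray(enc["input_ids"]),
                np.ascontiguousarray(enc["token_type_ids"]),
                np.ascontiguousarray(enc["attention_mask"])]
result = engine.graph.inference(engine_input)
hidden = result["last_hidden_state:0"]  # flattened last hidden states
```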


class Neural_Engine(Neural_Engine_bge):

    def accuracy(self, batch_size, seq_len, dataset_name, task_name, data_dir, tokenizer_dir):
        # load dataset
@@ -1,10 +1,11 @@
neural-compressor
transformers==4.34.1
accelerate
datasets >= 1.8.0
sentencepiece != 0.1.92
protobuf
torch==2.1.0
onnx>=1.12
onnx==1.13.1
onnxruntime==1.13.1

mteb==1.1.1
beir
