Skip to content

Commit

Permalink
[NeuralChat] Fix ner nightly ut bug (#822)
Browse files Browse the repository at this point in the history
  • Loading branch information
letonghan committed Nov 30, 2023
1 parent 775e6fa commit 9e5a6b3
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 144 deletions.
3 changes: 0 additions & 3 deletions intel_extension_for_transformers/neural_chat/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,6 @@ def build_chatbot(config: PipelineConfig=None):
elif plugin_name == "ner":
from .pipeline.plugins.ner.ner import NamedEntityRecognition
plugins[plugin_name]['class'] = NamedEntityRecognition
elif plugin_name == "ner_int":
from .pipeline.plugins.ner.ner_int import NamedEntityRecognitionINT
plugins[plugin_name]['class'] = NamedEntityRecognitionINT
elif plugin_name == "face_animation": # pragma: no cover
from .pipeline.plugins.video.face_animation.sadtalker import SadTalker
plugins[plugin_name]['class'] = SadTalker
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from threading import Thread
import re
import time
import torch
import spacy
from transformers import (
TextIteratorStreamer,
)
from .utils.utils import (
enforce_stop_tokens,
get_current_time
)
from .utils.process_text import process_time, process_entities
from intel_extension_for_transformers.neural_chat.prompts import PromptTemplate


class NamedEntityRecognition():
Expand Down

This file was deleted.

1 change: 0 additions & 1 deletion intel_extension_for_transformers/neural_chat/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def reset_plugins(self):
"cache": {"enable": False, "class": None, "args": {}, "instance": None},
"safety_checker": {"enable": False, "class": None, "args": {}, "instance": None},
"ner": {"enable": False, "class": None, "args": {}, "instance": None},
"ner_int": {"enable": False, "class": None, "args": {}, "instance": None},
"face_animation": {"enable": False, "class": None, "args": {}, "instance": None}
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def setUp(self):

def tearDown(self) -> None:
for filename in os.getcwd():
import re
if re.match(r'ne_.*_fp32.bin', filename) or re.match(r'ne_.*_q.bin', filename):
file_path = os.path.join(os.getcwd(), filename)
try:
Expand All @@ -37,55 +38,11 @@ def tearDown(self) -> None:
print(f"Error deleting file {filename}: {str(e)}")
return super().tearDown()

def test_fp32(self):
def test_ner(self):
os.system('python -m spacy download en_core_web_lg')
ner_obj = NamedEntityRecognition(model_path="/tf_dataset2/models/nlp_toolkit/mpt-7b")
ner_obj = NamedEntityRecognition()
query = "Show me photos taken in Shanghai."
result = ner_obj.inference(query=query)
_result = {
'period': [],
'time': [],
'location': ['Shanghai'],
'name': [],
'organization': []
}
self.assertEqual(result, _result)

def test_bf16(self):
os.system('python -m spacy download en_core_web_lg')
ner_obj = NamedEntityRecognition(model_path="/tf_dataset2/models/nlp_toolkit/mpt-7b", bf16=True)
query = "Show me photos taken in Shanghai."
result = ner_obj.inference(query=query)
_result = {
'period': [],
'time': [],
'location': ['Shanghai'],
'name': [],
'organization': []
}
self.assertEqual(result, _result)

def test_int8(self):
os.system('python -m spacy download en_core_web_lg')
ner_obj = NamedEntityRecognitionINT(model_path="/tf_dataset2/models/nlp_toolkit/mpt-7b")
query = "Show me photos taken in Shanghai."
result = ner_obj.inference(query=query, threads=8)
_result = {
'period': [],
'time': [],
'location': ['Shanghai'],
'name': [],
'organization': []
}
self.assertEqual(result, _result)

def test_int4(self):
os.system('python -m spacy download en_core_web_lg')
ner_obj = NamedEntityRecognitionINT(model_path="/tf_dataset2/models/nlp_toolkit/mpt-7b",
compute_dtype="int8",
weight_dtype="int4")
query = "Show me photos taken in Shanghai."
result = ner_obj.inference(query=query, threads=8)
result = ner_obj.ner_inference(query)
_result = {
'period': [],
'time': [],
Expand Down

0 comments on commit 9e5a6b3

Please sign in to comment.