Skip to content

Commit

Permalink
Merge pull request #60 from DoubleAix/master
Browse files Browse the repository at this point in the history
update jieba_tokenizer.py
  • Loading branch information
crownpku committed Aug 10, 2018
2 parents b95c34c + df9161a commit 130f167
Showing 1 changed file with 36 additions and 6 deletions.
42 changes: 36 additions & 6 deletions rasa_nlu/tokenizers/jieba_tokenizer.py
Expand Up @@ -16,6 +16,11 @@

import os
import glob
import shutil

DEFAULT_DICT_FILE_NAME = "jieba_default_dict"
USER_DICTS_FOLDER_NAME = "jieba_user_dicts/"
USER_DICT_FILE_NAME = USER_DICTS_FOLDER_NAME + "user_dict.txt"

class JiebaTokenizer(Tokenizer, Component):

Expand Down Expand Up @@ -46,7 +51,7 @@ def create(cls, cfg):
component_conf = cfg.for_component(cls.name, cls.defaults)
tokenizer = cls.init_jieba(tokenizer, component_conf)

return JiebaTokenizer(component_conf, tokenizer)
return cls(component_conf, tokenizer)

@classmethod
def load(cls,
Expand All @@ -60,9 +65,16 @@ def load(cls,
import jieba as tokenizer

component_meta = model_metadata.for_component(cls.name)

if component_meta.get("default_dict"):
path_default_dict = os.path.join(model_dir, component_meta.get("default_dict"))
component_meta["default_dict"] = path_default_dict
if component_meta.get("user_dicts"):
path_user_dicts = os.path.join(model_dir, component_meta.get("user_dicts"))
component_meta["user_dicts"] = path_user_dicts
tokenizer = cls.init_jieba(tokenizer, component_meta)

return JiebaTokenizer(component_meta, tokenizer)
return cls(component_meta, tokenizer)

@classmethod
def required_packages(cls):
Expand Down Expand Up @@ -140,8 +152,26 @@ def set_user_dicts(tokenizer, path_user_dicts):

def persist(self, model_dir):
# type: (Text) -> Dict[Text, Any]
return_dict = {}

return {
"user_dicts": self.component_config.get("user_dicts"),
"default_dict": self.component_config.get("default_dict")
}
if self.component_config.get("default_dict"):
des_path_default_dict = os.path.join(model_dir, DEFAULT_DICT_FILE_NAME)
if os.path.isfile(self.component_config.get("default_dict")):
shutil.copy2(self.component_config.get("default_dict"), des_path_default_dict)
return_dict.update({"default_dict": DEFAULT_DICT_FILE_NAME})

if self.component_config.get("user_dicts"):
des_path_user_dicts = os.path.join(model_dir, USER_DICTS_FOLDER_NAME)
os.mkdir(des_path_user_dicts)
if os.path.isdir(self.component_config.get("user_dicts")):
parse_pattern = "{}/*"
path_user_dicts = glob.glob(parse_pattern.format(self.component_config.get("user_dicts")))
for path_user_dict in path_user_dicts:
shutil.copy2(path_user_dict, des_path_user_dicts)
return_dict.update({"user_dicts": USER_DICTS_FOLDER_NAME})
else:
des_path_user_dict = os.path.join(model_dir, USER_DICT_FILE_NAME)
shutil.copy2(self.component_config.get("user_dicts"), des_path_user_dict)
return_dict.update({"user_dicts": USER_DICT_FILE_NAME})

return return_dict

0 comments on commit 130f167

Please sign in to comment.