Skip to content
This repository has been archived by the owner on Jun 24, 2020. It is now read-only.

Commit

Permalink
chore: recommit gazette refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Philippe Cote-Boucher committed Jun 18, 2019
1 parent 8121844 commit c15da0b
Showing 1 changed file with 42 additions and 63 deletions.
105 changes: 42 additions & 63 deletions rasa_addons/nlu/components/gazette.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,6 @@

from .fuzzy_matcher import process


def _find_matches(query, gazette, mode="ratio", limit=5):
    """Run the fuzzy matcher for ``query`` against every table in ``gazette``.

    Returns a dict with the same keys as ``gazette``; each value is whatever
    ``process.extract`` returns for that key's candidates when called with
    ``limit`` and with ``mode`` passed as the scorer.
    """
    return {
        label: process.extract(query, candidates, limit=limit, scorer=mode)
        for label, candidates in gazette.items()
    }


def _find_entity_config(entity, config):
    """Return the entry of ``config["entities"]`` that configures ``entity``.

    An entry matches when its ``"name"`` equals ``entity["entity"]``. The
    first match wins; ``None`` is returned when nothing matches or when
    ``config`` has no ``"entities"`` key.
    """
    configured = config.get("entities", [])
    return next(
        (spec for spec in configured if entity["entity"] == spec["name"]),
        None,
    )


class Gazette(Component):
name = "Gazette"

Expand All @@ -36,28 +20,33 @@ class Gazette(Component):
"entities": [],
}

def __init__(self, component_config=None, gazette=None):
# type: (RasaNLUModelConfig, Dict) -> None
def __init__(self,
component_config: Text = None,
gazette: Optional[Dict] = None) -> None:

super(Gazette, self).__init__(component_config)
self.gazette = gazette if gazette else {}
if gazette: self._load_config()
self.limit = self.component_config.get("max_num_suggestions")
self.entities = self.component_config.get("entities", [])

def process(self, message, **kwargs):
# type: (Message, **Any) -> None
def process(self, message: Message, **kwargs: Any) -> None:

self._load_config()
entities = message.get("entities", [])
limit = self.component_config.get("max_num_suggestions")

new_entities = []

for entity in entities:
config = _find_entity_config(entity, self.component_config)
config = self._find_entity(entity, self.entities)
if config is None or not isinstance(entity["value"], str):
new_entities.append(entity)
continue

matches = process.extract(entity["value"], self.gazette.get(entity["entity"], []), limit=limit,
scorer=config["mode"])
matches = process.extract(
entity["value"],
self.gazette.get(entity["entity"], []),
limit=self.limit,
scorer=config["mode"]
)
primary, score = matches[0] if len(matches) else (None, None)

if primary is not None and score > config["min_score"]:
Expand All @@ -67,61 +56,51 @@ def process(self, message, **kwargs):

message.set("entities", new_entities)

def train(self, training_data, config, **kwargs):
    # type: (TrainingData, RasaNLUModelConfig, **Any) -> None
    """Feed the gazette entries of ``training_data`` to the gazette loader.

    ``config`` and ``kwargs`` are accepted but not used here; the loader's
    return value is ignored.
    """
    load_entries = self._load_gazette_list
    load_entries(training_data.gazette)

def persist(self, file_name: Text, model_dir: Text) -> Optional[Dict[Text, Any]]:
    """Write the gazette tables to ``<model_dir>/<file_name>.json``.

    Returns a dict whose ``"component"`` entry is the full path that was
    written.
    """
    # An unset or empty gazette is serialised as an empty JSON object.
    payload = self.gazette or {}

    from rasa_nlu.utils import write_json_to_file

    target_path = os.path.join(model_dir, "{}.json".format(file_name))
    write_json_to_file(target_path, payload, separators=(',', ': '))

    # NOTE(review): the path is reported under "component"; confirm this is
    # the metadata key the loading side actually reads.
    return {"component": target_path}
def train(
    self, training_data: TrainingData, cfg: RasaNLUModelConfig, **kwargs: Any
) -> None:
    """Rebuild ``self.gazette`` from the gazette entries of ``training_data``.

    ``cfg`` and ``kwargs`` are accepted but not used here.
    """
    build_lookup = self._load_gazette_list
    self.gazette = build_lookup(training_data.gazette)

@classmethod
def load(cls,
meta: Dict[Text, Any],
model_dir: Optional[Text] = None,
model_metadata: Optional[Metadata] = None,
component_meta: Dict[Text, Any],
model_dir: Text = None,
model_metadata: Metadata = None,
cached_component: Optional['Gazette'] = None,
**kwargs: Any
) -> 'Gazette':
from rasa.nlu.utils import read_json_file

file_name = meta.get("file")
if not file_name:
gazette = None
return cls(meta, gazette)

gazette_file = os.path.join(model_dir, file_name)
if os.path.isfile(gazette_file):
gazette = rasa.utils.io.read_json_file(gazette_file)
td = read_json_file(os.path.join(model_dir, "training_data.json"))
if "gazette" in td["rasa_nlu_data"]:
gazette = cls._load_gazette_list(td["rasa_nlu_data"]["gazette"])
else:
gazette = None
warnings.warn(
"Failed to load gazette file from '{}'".format(gazette_file)
)
return cls(meta, gazette)
warnings.warn("Could not find Gazette in persisted training data file.")

def _load_gazette_list(self, gazette):
# type: (Dict) -> None
return Gazette(component_meta, gazette)

@staticmethod
def _load_gazette_list(gazette: Optional[Dict]) -> Dict:
    """Turn the gazette section of the training data into a lookup table.

    ``gazette`` is iterated as a sequence of entries, each providing a
    ``"value"`` (the table name) and a ``"gazette"`` (the table itself).
    NOTE(review): the ``Optional[Dict]`` parameter annotation is kept from
    the original, although the code iterates entries rather than a single
    dict -- confirm and tighten.

    Returns a new dict mapping each entry's ``"value"`` to its ``"gazette"``;
    when a name occurs more than once the last entry wins. ``None`` or an
    empty input yields an empty dict.

    Raises ``KeyError`` if an entry lacks ``"value"`` or ``"gazette"``.
    """
    # The parameter is Optional: a missing gazette section means "no tables"
    # instead of a TypeError from iterating None.
    if not gazette:
        return {}

    # This is a static method, so it must not touch instance state: the
    # result is returned and the caller decides where to store it.
    return {item["value"]: item["gazette"] for item in gazette}

@staticmethod
def _find_entity(entity, entities):
    """Return the first item of ``entities`` configured for ``entity``.

    An item matches when its ``"name"`` equals ``entity["entity"]``;
    ``None`` is returned when no item matches.
    """
    return next(
        (candidate for candidate in entities
         if entity["entity"] == candidate["name"]),
        None,
    )

def _load_config(self):
entities = []
for rep in self.component_config.get("entities", []):
assert "name" in rep, "Must provide the entity name for the gazette entity configuration: {}".format(rep)
assert rep["name"] in self.gazette, "Could not find entity name {0} in gazette {1}".format(rep["name"],
self.gazette)
assert rep["name"] in self.gazette, "Could not find entity name {0} in gazette {1}".format(rep["name"], self.gazette)

supported_properties = ["mode", "min_score"]
defaults = ["ratio", 80]
Expand All @@ -136,4 +115,4 @@ def _load_config(self):

entities.append(new_element)

self.component_config["entities"] = entities
self.component_config["entities"] = entities

1 comment on commit c15da0b

@pheel
Copy link
Contributor

@pheel pheel commented on c15da0b Jun 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cf. 40313f3

Please sign in to comment.