Skip to content
This repository has been archived by the owner on Apr 28, 2021. It is now read-only.

Commit

Permalink
Merge pull request #26 from Amirali-Shirkh/feature/url-rewriting-for-…
Browse files Browse the repository at this point in the history
…images

feat: support for image url text replacements in BotfrontTemplatedNaturalLanguageGenerator and GraphQLNaturalLanguageGenerator
  • Loading branch information
pheel committed Jul 22, 2020
2 parents e5dac90 + ef3792c commit d16e75f
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 3 deletions.
6 changes: 5 additions & 1 deletion rasa_addons/core/nlg/bftemplate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import logging
from collections import defaultdict

from rasa_addons.core.nlg.nlg_helper import rewrite_url
from rasa.core.trackers import DialogueStateTracker
from typing import Text, Any, Dict, Optional, List

Expand All @@ -14,6 +14,8 @@
class BotfrontTemplatedNaturalLanguageGenerator(NaturalLanguageGenerator):
def __init__(self, **kwargs) -> None:
domain = kwargs.get("domain")
templated_endpoint = kwargs.get("endpoint_config")
self.url_substitution_pattern = templated_endpoint.kwargs.get('url_substitutions') or []
self.templates = domain.templates if domain else []

def _templates_for_utter_action(self, utter_action, output_channel, **kwargs):
Expand Down Expand Up @@ -85,11 +87,13 @@ async def generate(
fallback_language=fallback_language,
)
if "language" in message: del message["language"]
rewrite_url(message, self.url_substitution_pattern)
metadata = message.pop("metadata", {}) or {}
for key in metadata: message[key] = metadata[key]

return message


def generate_from_slots(
self,
template_name: Text,
Expand Down
10 changes: 10 additions & 0 deletions rasa_addons/core/nlg/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from enum import Enum

class NlgEnum(Enum):
"""A class to represent constants for enum values for nlg classes."""

IMAGE = "image"
IMAGE_URL = "image_url"
ELEMENTS = "elements"
PATTERN = "pattern"
REPLACEMENT = "replacement"
5 changes: 3 additions & 2 deletions rasa_addons/core/nlg/graphql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from typing import Text, Any, Dict, Optional, List

from rasa_addons.core.nlg.nlg_helper import rewrite_url
from rasa.core.constants import DEFAULT_REQUEST_TIMEOUT
from rasa.core.nlg.generator import NaturalLanguageGenerator
from rasa.core.trackers import DialogueStateTracker, EventVerbosity
Expand Down Expand Up @@ -100,13 +100,13 @@ def nlg_request_format(
"channel": {"name": output_channel},
}


class GraphQLNaturalLanguageGenerator(NaturalLanguageGenerator):
"""Like Rasa's CallbackNLG, but queries Botfront's GraphQL endpoint"""

def __init__(self, **kwargs) -> None:
endpoint_config = kwargs.get("endpoint_config")
self.nlg_endpoint = endpoint_config
self.url_substitution_pattern = endpoint_config.kwargs.get('url_substitutions') or []

async def generate(
self,
Expand Down Expand Up @@ -152,6 +152,7 @@ async def generate(
", ".join([e.get("message") for e in response.get("errors")])
)
response = response.get("data", {}).get("getResponse", {})
rewrite_url(response, self.url_substitution_pattern)
if "customText" in response:
response["text"] = response.pop("customText")
if "customImage" in response:
Expand Down
23 changes: 23 additions & 0 deletions rasa_addons/core/nlg/nlg_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import re
from rasa_addons.core.nlg.constants import NlgEnum

def rewrite_url(message: dict, url_substitution_pattern: list):
"""Rewrite image url with the pattern found in endpoint."""

if url_substitution_pattern:
if NlgEnum.IMAGE.value in message.keys():
substitute(message, NlgEnum.IMAGE.value, url_substitution_pattern)
elif NlgEnum.ELEMENTS.value in message.keys():
for element in message[NlgEnum.ELEMENTS.value]:
substitute(element, NlgEnum.IMAGE_URL.value, url_substitution_pattern)

def substitute(message: dict, key: str, url_substitution_pattern: list):
"""Substitute rewritten url."""

url = message[key]
for item in url_substitution_pattern:
substitute = re.sub(item.get(NlgEnum.PATTERN.value), item.get(NlgEnum.REPLACEMENT.value), message[key])
if substitute != url:
message[key] = substitute
return
return

0 comments on commit d16e75f

Please sign in to comment.