Skip to content

Commit

Permalink
Fix escaping in quoted values
Browse files Browse the repository at this point in the history
  • Loading branch information
odelalleau committed May 3, 2021
1 parent 4f89e38 commit cc7a6db
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 27 deletions.
51 changes: 38 additions & 13 deletions hydra/core/override_parser/overrides_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,42 @@ def visitFunction(self, ctx: OverrideParser.FunctionContext) -> Any:
f"{type(e).__name__} while evaluating '{ctx.getText()}': {e}"
) from e

def visitQuotedValue(self, ctx: OverrideParser.QuotedValueContext) -> QuotedString:
children = list(ctx.getChildren())
assert len(children) >= 2

# Identity quote type.
first_quote = children[0].getText()
if first_quote == "'":
quote = Quote.single
else:
assert first_quote == '"'
quote = Quote.double

# Inspect string content.
tokens = []
is_interpolation = False
for child in children[1:-1]:
assert isinstance(child, TerminalNode)
symbol = child.symbol
text = symbol.text
if symbol.type == OverrideLexer.ESC_QUOTE:
# Always un-escape quotes.
text = text[1]
elif symbol.type == OverrideLexer.INTERPOLATION:
is_interpolation = True
tokens.append(text)

# Contactenate string fragments.
ret = "".join(tokens)

# If it is an interpolation, then OmegaConf will take care of un-escaping
# the `\\`. But if it is not, then we need to do it here.
if not is_interpolation:
ret = ret.replace("\\\\", "\\")

return QuotedString(text=ret, quote=quote, esc_backslash=not is_interpolation)

def _createPrimitive(
self, ctx: ParserRuleContext
) -> Optional[Union[QuotedString, int, bool, float, str]]:
Expand Down Expand Up @@ -274,19 +310,8 @@ def _createPrimitive(
ret = "".join(tokens)
else:
node = ctx.getChild(first_idx)
if node.symbol.type == OverrideLexer.QUOTED_VALUE:
text = node.getText()
qc = text[0]
text = text[1:-1]
if qc == "'":
quote = Quote.single
text = text.replace("\\'", "'")
elif qc == '"':
quote = Quote.double
text = text.replace('\\"', '"')
else:
assert False
return QuotedString(text=text, quote=quote)
if isinstance(node, OverrideParser.QuotedValueContext):
return self.visitQuotedValue(node)
elif node.symbol.type in (OverrideLexer.ID, OverrideLexer.INTERPOLATION):
ret = node.symbol.text
elif node.symbol.type == OverrideLexer.INT:
Expand Down
9 changes: 6 additions & 3 deletions hydra/core/override_parser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,19 @@ class Quote(Enum):
@dataclass(frozen=True)
class QuotedString:
text: str

quote: Quote
esc_backslash: bool = True

def with_quotes(self) -> str:
text = self.text
if self.esc_backslash:
text = text.replace("\\", "\\\\")
if self.quote == Quote.single:
q = "'"
text = self.text.replace("'", "\\'")
text = text.replace("'", "\\'")
elif self.quote == Quote.double:
q = '"'
text = self.text.replace('"', '\\"')
text = text.replace('"', '\\"')
else:
assert False
return f"{q}{text}{q}"
Expand Down
41 changes: 37 additions & 4 deletions hydra/grammar/OverrideLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ DOT_PATH: (ID | INT_UNSIGNED) ('.' (ID | INT_UNSIGNED))+;

mode VALUE_MODE;

QUOTE_OPEN_SINGLE: '\'' -> pushMode(QUOTED_SINGLE_MODE);
QUOTE_OPEN_DOUBLE: '"' -> pushMode(QUOTED_DOUBLE_MODE);

POPEN: WS? '(' WS?; // whitespaces before to allow `func (x)`
COMMA: WS? ',' WS?;
PCLOSE: WS? ')';
Expand Down Expand Up @@ -66,8 +69,38 @@ ESC: (ESC_BACKSLASH | '\\(' | '\\)' | '\\[' | '\\]' | '\\{' | '\\}' |
'\\:' | '\\=' | '\\,' | '\\ ' | '\\\t')+;
WS: [ \t]+;

QUOTED_VALUE:
'\'' ('\\\''|.)*? '\'' // Single quotes, can contain escaped single quote : /'
| '"' ('\\"'|.)*? '"' ; // Double quotes, can contain escaped double quote : /"

INTERPOLATION: '${' ~('}')+ '}';


////////////////////////
// QUOTED_SINGLE_MODE //
////////////////////////

mode QUOTED_SINGLE_MODE;

MATCHING_QUOTE_CLOSE: '\'' -> popMode;

ESC_QUOTE: '\\\'';
QSINGLE_ESC_BACKSLASH: ESC_BACKSLASH -> type(ESC);

QSINGLE_INTERPOLATION: INTERPOLATION -> type(INTERPOLATION);
SPECIAL_CHAR: [\\$];
ANY_STR: ~['\\$]+;
////////////////////////
// QUOTED_DOUBLE_MODE //
////////////////////////
mode QUOTED_DOUBLE_MODE;
// Same as `QUOTED_SINGLE_MODE` but for double quotes.
QDOUBLE_CLOSE: '"' -> type(MATCHING_QUOTE_CLOSE), popMode;
QDOUBLE_ESC_QUOTE: '\\"' -> type(ESC_QUOTE);
QDOUBLE_ESC_BACKSLASH: ESC_BACKSLASH -> type(ESC);
QDOUBLE_INTERPOLATION: INTERPOLATION -> type(INTERPOLATION);
QDOUBLE_SPECIAL_CHAR: SPECIAL_CHAR -> type(SPECIAL_CHAR);
QDOUBLE_STR: ~["\\$]+ -> type(ANY_STR);
8 changes: 7 additions & 1 deletion hydra/grammar/OverrideParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,14 @@ dictKeyValuePair: dictKey COLON element;

// Primitive types.

// Ex: "hello world", 'hello ${world}'
quotedValue:
(QUOTE_OPEN_SINGLE | QUOTE_OPEN_DOUBLE)
(INTERPOLATION | ESC | ESC_QUOTE | SPECIAL_CHAR | ANY_STR)*
MATCHING_QUOTE_CLOSE;

primitive:
QUOTED_VALUE // 'hello world', "hello world"
quotedValue // 'hello world', "hello world"
| ( ID // foo_10
| NULL // null, NULL
| INT // 0, 10, -20, 1_000_000
Expand Down
133 changes: 127 additions & 6 deletions tests/test_overrides_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,22 @@ def test_value(value: str, expected: Any) -> None:
param("[[a]]", [["a"]], id="list:nested_list"),
param("[[[a]]]", [[["a"]]], id="list:double_nested_list"),
param("[1,[a]]", [1, ["a"]], id="list:simple_and_list_elements"),
param(
r"['a\\', 'b\\']",
[
QuotedString(text="a\\", quote=Quote.single),
QuotedString(text="b\\", quote=Quote.single),
],
id="list:str_trailing_backslash_single",
),
param(
r'["a\\", "b\\"]',
[
QuotedString(text="a\\", quote=Quote.double),
QuotedString(text="b\\", quote=Quote.double),
],
id="list:str_trailing_backslash_double",
),
],
)
def test_list_container(value: str, expected: Any) -> None:
Expand Down Expand Up @@ -301,6 +317,22 @@ def test_shuffle_sequence(value: str, expected: Any) -> None:
},
id="dict_mixed_keys",
),
param(
r"{a: 'a\\', b: 'b\\'}",
{
"a": QuotedString(text="a\\", quote=Quote.single),
"b": QuotedString(text="b\\", quote=Quote.single),
},
id="dict_str_trailing_backslash_single",
),
param(
r'{a: "a\\", b: "b\\"}',
{
"a": QuotedString(text="a\\", quote=Quote.double),
"b": QuotedString(text="b\\", quote=Quote.double),
},
id="dict_str_trailing_backslash_double",
),
],
)
def test_dict_container(value: str, expected: Any) -> None:
Expand Down Expand Up @@ -432,21 +464,23 @@ def test_interval_sweep(value: str, expected: Any) -> None:
param(
"override",
"key=[1,2,3]'",
raises(HydraException, match=re.escape("token recognition error at: '''")),
raises(
HydraException, match=re.escape("extraneous input ''' expecting <EOF>")
),
id="error:left_overs",
),
param(
"dictContainer",
"{'0a': 0, \"1b\": 1}",
raises(HydraException, match=re.escape("mismatched input ''0a''")),
raises(HydraException, match=re.escape("mismatched input '''")),
id="error:dict_quoted_key_dictContainer",
),
param(
"override",
"key={' abc ': 0}",
raises(
HydraException,
match=re.escape("no viable alternative at input '{' abc ''"),
match=re.escape("no viable alternative at input '{''"),
),
id="error:dict_quoted_key_override_single",
),
Expand All @@ -455,7 +489,7 @@ def test_interval_sweep(value: str, expected: Any) -> None:
'key={" abc ": 0}',
raises(
HydraException,
match=re.escape("""no viable alternative at input '{" abc "'"""),
match=re.escape("""no viable alternative at input '{"'"""),
),
id="error:dict_quoted_key_override_double",
),
Expand Down Expand Up @@ -567,15 +601,41 @@ def test_key(value: str, expected: Any) -> None:
param("false", False, id="primitive:bool"),
# quoted string
param(
"'foo \\'bar'",
r"'foo \'bar'",
QuotedString(text="foo 'bar", quote=Quote.single),
id="value:escape_single_quote",
),
param(
'"foo \\"bar"',
r'"foo \"bar"',
QuotedString(text='foo "bar', quote=Quote.double),
id="value:escape_double_quote",
),
param(
r"'foo \\\'bar'",
QuotedString(text=r"foo \'bar", quote=Quote.single),
id="value:escape_single_quote_x3",
),
param(
r'"foo \\\"bar"',
QuotedString(text=r"foo \"bar", quote=Quote.double),
id="value:escape_double_quote_x3",
),
param(
r"'foo\\bar'",
QuotedString(text=r"foo\bar", quote=Quote.single),
id="value:escape_backslash",
),
param(
r"'foo\\\\bar'",
QuotedString(text=r"foo\\bar", quote=Quote.single),
id="value:escape_backslash_x4",
),
param(
r"'foo bar\\'",
# Note: raw strings do not allow trailing \, adding a space and stripping it.
QuotedString(text=r" foo bar\ ".strip(), quote=Quote.single),
id="value:escape_backslash_trailing",
),
param(
"'\t []{},=+~'",
QuotedString(text="\t []{},=+~", quote=Quote.single),
Expand Down Expand Up @@ -649,6 +709,41 @@ def test_key(value: str, expected: Any) -> None:
QuotedString(text="false", quote=Quote.single),
id="value:bool:quoted",
),
param(
"'a ${b}'",
QuotedString(text="a ${b}", quote=Quote.single, esc_backslash=False),
id="value:interpolation:quoted",
),
param(
r"'a \${b}'",
QuotedString(text=r"a \${b}", quote=Quote.single, esc_backslash=False),
id="value:esc_interpolation:quoted",
),
param(
r"'a \\${b}'",
QuotedString(text=r"a \\${b}", quote=Quote.single, esc_backslash=False),
id="value:backslash_and_interpolation:quoted",
),
param(
r"'a \'${b}\''",
QuotedString(text=r"a '${b}'", quote=Quote.single, esc_backslash=False),
id="value:quotes_and_interpolation:quoted",
),
param(
r"'a \'\${b}\''",
QuotedString(text=r"a '\${b}'", quote=Quote.single, esc_backslash=False),
id="value:quotes_and_esc_interpolation:quoted",
),
param(
r"'a \'\\${b}\''",
QuotedString(text=r"a '\\${b}'", quote=Quote.single, esc_backslash=False),
id="value:quotes_backslash_and_interpolation:quoted",
),
param(
r"'a \\\'${b}\\\''",
QuotedString(text=r"a \\'${b}\\'", quote=Quote.single, esc_backslash=False),
id="value:backaslash_quotes_and_interpolation:quoted",
),
# interpolations:
param("${a}", "${a}", id="primitive:interpolation"),
param("${a.b.c}", "${a.b.c}", id="primitive:interpolation"),
Expand All @@ -665,6 +760,32 @@ def test_primitive(value: str, expected: Any) -> None:
assert eq(ret, expected)


@mark.parametrize(
("value", "expected", "with_quotes"),
[
param(
r"'foo\bar'",
QuotedString(text=r"foo\bar", quote=Quote.single),
r"'foo\\bar'",
id="value:one_backslash_single",
),
param(
r'"foo\bar"',
QuotedString(text=r"foo\bar", quote=Quote.double),
r'"foo\\bar"',
id="value:one_backslash_double",
),
],
)
def test_with_quotes_one_backslash(value: str, expected: Any, with_quotes: str) -> None:
# This test's objective is to test the case where a quoted string contains a single
# (i.e., non-escaped) backslash. This case can't be included in `test_primitive()`
# because the backslash is escaped by `with_quotes()` => value != ret.with_quotes()
ret = parse_rule(value, "primitive")
assert eq(ret, expected)
assert ret.with_quotes() == with_quotes


@mark.parametrize(
"prefix,override_type",
[
Expand Down

0 comments on commit cc7a6db

Please sign in to comment.