Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Add support for nested interpolations in Override grammar #1594

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 76 additions & 24 deletions hydra/core/override_parser/overrides_visitor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import sys
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

from antlr4 import ParserRuleContext, TerminalNode, Token
from antlr4.error.ErrorListener import ErrorListener
Expand Down Expand Up @@ -75,17 +75,24 @@ def visitPrimitive(
def visitListContainer(
self, ctx: OverrideParser.ListContainerContext
) -> List[ParsedElementType]:
ret: List[ParsedElementType] = []

return (
[]
if ctx.getChildCount() == 2
else list(self.visitSequence(ctx.getChild(1)))
)

def visitSequence(
self, ctx: OverrideParser.SequenceContext
) -> Iterable[ParsedElementType]:
idx = 0
while True:
element = ctx.element(idx)
if element is None:
break
else:
idx = idx + 1
ret.append(self.visitElement(element))
return ret
yield self.visitElement(element)

def visitDictContainer(
self, ctx: OverrideParser.DictContainerContext
Expand Down Expand Up @@ -117,6 +124,8 @@ def visitElement(self, ctx: OverrideParser.ElementContext) -> ParsedElementType:
return self.visitFunction(ctx.function()) # type: ignore
elif ctx.primitive():
return self.visitPrimitive(ctx.primitive())
elif ctx.quotedValue():
return self.visitQuotedValue(ctx.quotedValue())
elif ctx.listContainer():
return self.visitListContainer(ctx.listContainer())
elif ctx.dictContainer():
Expand Down Expand Up @@ -240,6 +249,44 @@ def visitFunction(self, ctx: OverrideParser.FunctionContext) -> Any:
f"{type(e).__name__} while evaluating '{ctx.getText()}': {e}"
) from e

def visitQuotedValue(self, ctx: OverrideParser.QuotedValueContext) -> QuotedString:
children = list(ctx.getChildren())
assert len(children) >= 2

# Single or double quote?
first_quote = children[0].getText()
if first_quote == "'":
quote = Quote.single
else:
assert first_quote == '"'
quote = Quote.double

tokens = []
is_interpolation = False
for child in children[1:-1]: # iterate on child nodes between quotes
if isinstance(child, TerminalNode):
s = child.symbol
if s.type == OverrideLexer.ESC_QUOTE:
# Always un-escape quotes.
tokens.append(s.text[1])
continue
if s.type == OverrideLexer.ESC_INTER:
# OmegaConf processes escaped interpolations as interpolations.
is_interpolation = True
else:
assert isinstance(child, OverrideParser.InterpolationContext)
is_interpolation = True
tokens.append(child.getText())

ret = "".join(tokens)

# If it is an interpolation, then OmegaConf will take care of un-escaping
# the `\\`. But if it is not, then we need to do it here.
if not is_interpolation:
ret = ret.replace("\\\\", "\\")
Copy link
Collaborator Author

@odelalleau odelalleau Apr 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to discuss this point. It looks (and is) a bit clumsy but it's the best I have right now.

Consider the following two override values (each quoted string really ends with 4 backslashes; I'm assuming these are the actual characters after any Python un-escaping):

  1. ["abc\\\\", "def\\\\"]
  2. ["${abc}\\\\", "${def}\\\\"]

In both situations we would like the \\\\ to be turned into \\ (i.e., \\\\ == two consecutive escaped backslashes).

In case (1), if the content of each quoted string is kept unchanged (ex: abc\\\\) then we'll end up with extra unwanted backslashes at the end of the string (since the intent here is to obtain abc\\).

In case (2), if Hydra un-escapes the backslashes, then it will feed the string ${abc}\\ to OmegaConf, which will resolve it to <value_of_abc>\ (it will un-escape \\ into \), so we would end up with only one \ at the end instead of two.

As a result, we need to know whether or not OmegaConf will process the string with its grammar, in order to decide whether or not to un-escape the \\. I guess we could decide to ignore this and let the user figure it out, but this wouldn't seem intuitive.

@omry how do you feel about this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: this is a side effect of the fact that in OmegaConf, setting a value node to abc\\ will use this string unchanged, while ${abc}\\ will be resolved to <value_of_abc>\.
One potential approach to resolve this issue would be the one from omry/omegaconf#621 where \\ is only un-escaped when required.


return QuotedString(text=ret, quote=quote)

def _createPrimitive(
self, ctx: ParserRuleContext
) -> Optional[Union[QuotedString, int, bool, float, str]]:
Expand All @@ -261,33 +308,38 @@ def _createPrimitive(
# Concatenate, while un-escaping as needed.
tokens = []
for i, n in enumerate(ctx.getChildren()):
if n.symbol.type == OverrideLexer.WS and (
if isinstance(n, OverrideParser.InterpolationContext):
tokens.append(n.getText())
elif n.symbol.type == OverrideLexer.WS and (
i < first_idx or i >= last_idx
):
# Skip leading / trailing whitespaces.
continue
tokens.append(
n.symbol.text[1::2] # un-escape by skipping every other char
if n.symbol.type == OverrideLexer.ESC
else n.symbol.text
)
else:
tokens.append(
n.symbol.text[1::2] # un-escape by skipping every other char
if n.symbol.type == OverrideLexer.ESC
else n.symbol.text
)
ret = "".join(tokens)
else:
node = ctx.getChild(first_idx)
if node.symbol.type == OverrideLexer.QUOTED_VALUE:
text = node.getText()
qc = text[0]
text = text[1:-1]
if qc == "'":
quote = Quote.single
text = text.replace("\\'", "'")
elif qc == '"':
quote = Quote.double
text = text.replace('\\"', '"')
else:
assert False
return QuotedString(text=text, quote=quote)
elif node.symbol.type in (OverrideLexer.ID, OverrideLexer.INTERPOLATION):
# if node.symbol.type == OverrideLexer.QUOTED_VALUE:
# text = node.getText()
# qc = text[0]
# text = text[1:-1]
# if qc == "'":
# quote = Quote.single
# text = text.replace("\\'", "'")
# elif qc == '"':
# quote = Quote.double
# text = text.replace('\\"', '"')
# else:
# assert False
# return QuotedString(text=text, quote=quote)
if isinstance(node, OverrideParser.InterpolationContext):
ret = node.getText()
elif node.symbol.type == OverrideLexer.ID:
ret = node.symbol.text
elif node.symbol.type == OverrideLexer.INT:
ret = int(node.symbol.text)
Expand Down
73 changes: 66 additions & 7 deletions hydra/grammar/OverrideLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,17 @@ DOT_PATH: (ID | INT_UNSIGNED) ('.' (ID | INT_UNSIGNED))+;

mode VALUE_MODE;

INTER_OPEN: '${' WS? -> pushMode(INTERPOLATION_MODE);
BRACE_OPEN: '{' WS? -> pushMode(VALUE_MODE); // must keep track of braces to detect end of interpolation
BRACE_CLOSE: WS? '}' -> popMode;
QUOTE_OPEN_SINGLE: '\'' -> pushMode(QUOTED_SINGLE_MODE);
QUOTE_OPEN_DOUBLE: '"' -> pushMode(QUOTED_DOUBLE_MODE);

POPEN: WS? '(' WS?; // whitespaces before to allow `func (x)`
COMMA: WS? ',' WS?;
PCLOSE: WS? ')';
BRACKET_OPEN: '[' WS?;
BRACKET_CLOSE: WS? ']';
BRACE_OPEN: '{' WS?;
BRACE_CLOSE: WS? '}';
VALUE_COLON: WS? ':' WS? -> type(COLON);
VALUE_EQUAL: WS? '=' WS? -> type(EQUAL);

Expand All @@ -58,16 +62,71 @@ BOOL:

NULL: [Nn][Uu][Ll][Ll];

UNQUOTED_CHAR: [/\-\\+.$%*@]; // other characters allowed in unquoted strings
UNQUOTED_CHAR: [/\-\\+.$%*@?]; // other characters allowed in unquoted strings
ID: (CHAR|'_') (CHAR|DIGIT|'_')*;
// Note: when adding more characters to the ESC rule below, also add them to
// the `_ESC` string in `_internal/grammar/utils.py`.
ESC: (ESC_BACKSLASH | '\\(' | '\\)' | '\\[' | '\\]' | '\\{' | '\\}' |
'\\:' | '\\=' | '\\,' | '\\ ' | '\\\t')+;
WS: [ \t]+;

QUOTED_VALUE:
'\'' ('\\\''|.)*? '\'' // Single quotes, can contain escaped single quote : \'
| '"' ('\\"'|.)*? '"' ; // Double quotes, can contain escaped double quote : \"

INTERPOLATION: '${' ~('}')+ '}';
////////////////////////
// INTERPOLATION_MODE //
////////////////////////

mode INTERPOLATION_MODE;

NESTED_INTER_OPEN: INTER_OPEN WS? -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
INTER_COLON: WS? ':' WS? -> type(COLON), mode(VALUE_MODE);
INTER_CLOSE: WS? '}' -> popMode;

DOT: '.';
INTER_BRACKET_OPEN: '[' -> type(BRACKET_OPEN);
INTER_BRACKET_CLOSE: ']' -> type(BRACKET_CLOSE);
INTER_ID: ID -> type(ID);

// Interpolation key, may contain any non special character.
// Note that we can allow '$' because the parser does not support interpolations that
// are only part of a key name, i.e., "${foo${bar}}" is not allowed. As a result, it
// is ok to "consume" all '$' characters within the `INTER_KEY` token.
INTER_KEY: ~[\\{}()[\]:. \t'"]+;


////////////////////////
// QUOTED_SINGLE_MODE //
////////////////////////

mode QUOTED_SINGLE_MODE;

// This mode is very similar to `DEFAULT_MODE` except for the handling of quotes.

QSINGLE_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
MATCHING_QUOTE_CLOSE: '\'' -> popMode;

ESC_QUOTE: '\\\'';
QSINGLE_ESC_BACKSLASH: ESC_BACKSLASH -> type(ESC);
ESC_INTER: '\\${';


SPECIAL_CHAR: [\\$];
ANY_STR: ~['\\$]+;


////////////////////////
// QUOTED_DOUBLE_MODE //
////////////////////////

mode QUOTED_DOUBLE_MODE;

// Same as `QUOTED_SINGLE_MODE` but for double quotes.

QDOUBLE_INTER_OPEN: INTER_OPEN -> type(INTER_OPEN), pushMode(INTERPOLATION_MODE);
QDOUBLE_CLOSE: '"' -> type(MATCHING_QUOTE_CLOSE), popMode;

QDOUBLE_ESC_QUOTE: '\\"' -> type(ESC_QUOTE);
QDOUBLE_ESC_BACKSLASH: ESC_BACKSLASH -> type(ESC);
QDOUBLE_ESC_INTER: ESC_INTER -> type(ESC_INTER);

QDOUBLE_SPECIAL_CHAR: SPECIAL_CHAR -> type(SPECIAL_CHAR);
QDOUBLE_STR: ~["\\$]+ -> type(ANY_STR);
77 changes: 51 additions & 26 deletions hydra/grammar/OverrideParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ override: (
| TILDE key (EQUAL value?)? // ~key | ~key=value
| PLUS PLUS? key EQUAL value? // +key= | +key=value | ++key=value
) EOF;

// Key:
key : packageOrGroup (AT package)?; // key | group@pkg

Expand All @@ -24,8 +23,16 @@ package: ( | ID | DOT_PATH); // db, hydra.launcher, or the e

value: element | simpleChoiceSweep;


// Composite text expression (may contain interpolations).



// Elements.

element:
primitive
| quotedValue
| listContainer
| dictContainer
| function
Expand All @@ -42,38 +49,56 @@ function: ID POPEN (argName? element (COMMA argName? element )* )? PCLOSE;

// Data structures.

listContainer: BRACKET_OPEN // [], [1,2,3], [a,b,[1,2]]
(element(COMMA element)*)?
BRACKET_CLOSE;

listContainer: BRACKET_OPEN sequence? BRACKET_CLOSE; // [], [1,2,3], [a,b,[1,2]]
dictContainer: BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE; // {}, {a:10,b:20}
dictKeyValuePair: dictKey COLON element;
sequence: (element (COMMA element?)*) | (COMMA element?)+;


// Interpolations.

interpolation: interpolationNode | interpolationResolver;

interpolationNode:
INTER_OPEN
DOT* // relative interpolation?
(configKey | BRACKET_OPEN configKey BRACKET_CLOSE) // foo, [foo]
(DOT configKey | BRACKET_OPEN configKey BRACKET_CLOSE)* // .foo, [foo], .foo[bar], [foo].bar[baz]
INTER_CLOSE;
interpolationResolver: INTER_OPEN resolverName COLON sequence? BRACE_CLOSE;
configKey: interpolation | ID | INTER_KEY;
resolverName: (interpolation | ID) (DOT (interpolation | ID))* ; // oc.env, myfunc, ns.${x}, ns1.ns2.f


// Primitive types.

// Ex: "hello world", 'hello ${world}'
quotedValue:
(QUOTE_OPEN_SINGLE | QUOTE_OPEN_DOUBLE)
(interpolation | ESC | ESC_INTER | ESC_QUOTE | SPECIAL_CHAR | ANY_STR)*
MATCHING_QUOTE_CLOSE;

primitive:
QUOTED_VALUE // 'hello world', "hello world"
| ( ID // foo_10
| NULL // null, NULL
| INT // 0, 10, -20, 1_000_000
| FLOAT // 3.14, -20.0, 1e-1, -10e3
| BOOL // true, TrUe, false, False
| INTERPOLATION // ${foo.bar}, ${oc.env:USER,me}
| UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @
| COLON // :
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
| WS // whitespaces
( ID // foo_10
| NULL // null, NULL
| INT // 0, 10, -20, 1_000_000
| FLOAT // 3.14, -20.0, 1e-1, -10e3
| BOOL // true, TrUe, false, False
| UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @
| COLON // :
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
| WS // whitespaces
| interpolation
)+;

// Same as `primitive` except that `COLON` and `INTERPOLATION` are not allowed.
// Same as `primitive` except that `COLON` and interpolations are not allowed.
dictKey:
QUOTED_VALUE // 'hello world', "hello world"
| ( ID // foo_10
| NULL // null, NULL
| INT // 0, 10, -20, 1_000_000
| FLOAT // 3.14, -20.0, 1e-1, -10e3
| BOOL // true, TrUe, false, False
| UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
| WS // whitespaces
( ID // foo_10
| NULL // null, NULL
| INT // 0, 10, -20, 1_000_000
| FLOAT // 3.14, -20.0, 1e-1, -10e3
| BOOL // true, TrUe, false, False
| UNQUOTED_CHAR // /, -, \, +, ., $, %, *, @
| ESC // \\, \(, \), \[, \], \{, \}, \:, \=, \ , \\t, \,
| WS // whitespaces
)+;
Loading