Skip to content

Commit

Permalink
As @minrk recommends, warning if one or more symbol are not found
Browse files Browse the repository at this point in the history
  • Loading branch information
mgaitan committed Oct 7, 2013
1 parent 348e90f commit 7391345
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 14 deletions.
32 changes: 24 additions & 8 deletions IPython/core/magics/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from IPython.utils.contexts import preserve_keys
from IPython.utils.path import get_py_filename, unquote_filename
from IPython.utils.warn import warn
from IPython.utils.text import get_text_list

#-----------------------------------------------------------------------------
# Magic implementation classes
Expand All @@ -48,6 +49,7 @@ class MacroToEdit(ValueError): pass
(?P<end>\d+)?)?
$""", re.VERBOSE)


def extract_code_ranges(ranges_str):
"""Turn a string of range for %%load into 2-tuples of (start, stop)
ready to use as a slice of the content splitted by lines.
Expand Down Expand Up @@ -80,18 +82,21 @@ def extract_code_ranges(ranges_str):
@skip_doctest
def extract_symbols(code, symbols):
"""
Return a list of code fragments for each symbol parsed from code
For example, suppose code is a string::
Return a tuple (blocks, not_found)
where ``blocks`` is a list of code fragments
for each symbol parsed from code, and ``not_found`` are
symbols not found in the code.
For example::
a = 10
>>> code = '''a = 10
def b(): return 42
class A: pass
class A: pass'''
>>> extract_symbols(code, 'A,b')
["class A: pass", "def b(): return 42"]
(["class A: pass", "def b(): return 42"], [])
"""
try:
py_code = ast.parse(code)
Expand All @@ -115,12 +120,15 @@ class A: pass

# fill a list with chunks of codes for each symbol
blocks = []
not_found = []
for symbol in symbols.split(','):
if symbol in symbols_lines:
start, end = symbols_lines[symbol]
blocks.append('\n'.join(code[start:end]) + '\n')
else:
not_found.append(symbol)

return blocks
return blocks, not_found


class InteractivelyDefined(Exception):
Expand Down Expand Up @@ -289,7 +297,15 @@ def load(self, arg_s):
contents = self.shell.find_user_code(args)

if 's' in opts:
contents = '\n'.join(extract_symbols(contents, opts['s']))
blocks, not_found = extract_symbols(contents, opts['s'])
if len(not_found) == 1:
warn('The symbol `%s` was not found' % not_found[0])
elif len(not_found) > 1:
warn('The symbols %s were not found' % get_text_list(not_found,
wrap_item_with='`')
)

contents = '\n'.join(blocks)

if 'r' in opts:
ranges = opts['r'].replace(',', ' ')
Expand Down
12 changes: 6 additions & 6 deletions IPython/core/tests/test_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@ def test_extract_code_ranges():
def test_extract_symbols():
source = """import foo\na = 10\ndef b():\n return 42\n\n\nclass A: pass\n\n\n"""
symbols_args = ["a", "b", "A", "A,b", "A,a", "z"]
expected = [[],
["def b():\n return 42\n"],
["class A: pass\n"],
["class A: pass\n", "def b():\n return 42\n"],
["class A: pass\n"],
[]]
expected = [([], ['a']),
(["def b():\n return 42\n"], []),
(["class A: pass\n"], []),
(["class A: pass\n", "def b():\n return 42\n"], []),
(["class A: pass\n"], ['a']),
([], ['z'])]
for symbols, exp in zip(symbols_args, expected):
nt.assert_equal(code.extract_symbols(source, symbols), exp)

Expand Down

0 comments on commit 7391345

Please sign in to comment.