Skip to content

Commit

Permalink
Merge pull request #18 from gregoil/add_bubble_on_all_levels
Browse files Browse the repository at this point in the history
Added bubble to all levels
  • Loading branch information
osherdp committed Apr 11, 2019
2 parents f890e1e + ec4ecda commit 6b0219a
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 81 deletions.
221 changes: 144 additions & 77 deletions ipdbugger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@
"""
from __future__ import print_function
from __future__ import absolute_import
# pylint: disable=no-member
# pylint: disable=no-member,not-callable
# pylint: disable=protected-access,bare-except
# pylint: disable=missing-docstring,too-many-locals,too-many-branches
import re
import ast
import sys
import types
import inspect
import functools
import traceback
from itertools import chain

import colorama
from termcolor import colored
Expand All @@ -37,6 +35,9 @@
colorama.init()


IS_PYTHON_3 = sys.version_info > (3, 0)


class IPDBugger(TerminalPdb):
"""Debugger class, adds functionality to the normal pdb."""

Expand Down Expand Up @@ -107,65 +108,130 @@ def start_debugging():
class ErrorsCatchTransformer(ast.NodeTransformer):
"""Surround each statement with a try/except block to catch errors."""

def __init__(self, ignore_exceptions=(), catch_exception=None):
if sys.version_info > (3, 0): # pragma: no cover
start_debug_cmd = ast.Expr(
value=ast.Call(
ast.Name("start_debugging", ast.Load()),
[],
[],
)
)

else: # pragma: no cover
start_debug_cmd = ast.Expr(
value=ast.Call(ast.Name("start_debugging", ast.Load()),
[], [], None, None))

catch_exception_node = None
def __init__(self, ignore_exceptions=(), catch_exception=None, depth=0):
self.depth = depth
self.catch_exception = None
self.ignore_exceptions = None

if ignore_exceptions is not None:
self.ignore_exceptions = [exception_class.__name__
for exception_class in ignore_exceptions]

if catch_exception is not None:
catch_exception_node = ast.Name(catch_exception.__name__,
ast.Load())
self.catch_exception = catch_exception.__name__

@property
def ast_try_except(self):
return ast.Try if IS_PYTHON_3 else ast.TryExcept

def wrap_with_try(self, node):
"""Wrap an ast node in a 'try' node to enter debug on exception."""
handlers = []

if self.ignore_exceptions is None:
handlers.append(ast.ExceptHandler(type=None,
name=None,
body=[ast.Raise()]))

else:
ignores_nodes = [ast.Name(exception_class, ast.Load())
for exception_class in self.ignore_exceptions]

handlers.append(ast.ExceptHandler(type=ast.Tuple(ignores_nodes,
ast.Load()),
name=None,
body=[ast.Raise()]))

if self.catch_exception not in self.ignore_exceptions:
call_extra_parameters = [] if IS_PYTHON_3 else [None, None]
start_debug_cmd = ast.Expr(
value=ast.Call(ast.Name("start_debugging", ast.Load()),
[], [], *call_extra_parameters))

self.exception_handlers = [ast.ExceptHandler(type=catch_exception_node,
name=None,
body=[start_debug_cmd])]
catch_exception_type = None
if self.catch_exception is not None:
catch_exception_type = ast.Name(self.catch_exception,
ast.Load())

for exception_class in ignore_exceptions:
ignore_exception_node = ast.Name(exception_class.__name__,
ast.Load())
handlers.append(ast.ExceptHandler(type=catch_exception_type,
name=None,
body=[start_debug_cmd]))

self.exception_handlers.insert(
0,
ast.ExceptHandler(type=ignore_exception_node,
name=None,
body=[ast.Raise()]))
try_except_extra_params = {"finalbody": []} if IS_PYTHON_3 else {}

new_node = self.ast_try_except(orelse=[], body=[node],
handlers=handlers,
**try_except_extra_params)

return ast.copy_location(new_node, node)

def try_except_handler(self, node):
"""Handler for try except statement to ignore excepted exceptions."""
# List all excepted handlers
excepted = [ast.ExceptHandler(type=handler.type,
name=None,
body=[ast.Raise()])
for handler in node.handlers]

new_exception_handlers = []
for except_handler in chain(excepted, self.exception_handlers):
new_exception_handlers.append(except_handler)

# Default 'except:' must be last
if except_handler.type is None:
# List all excepted exception's names
excepted_types = []
for handler in node.handlers:
if handler.type is None:
excepted_types = None
break

if isinstance(handler.type, ast.Tuple):
excepted_types.extends([exception_type.id for exception_type
in handler.type.elts])

else:
excepted_types.append(handler.type.id)

new_exception_list = self.ignore_exceptions

if self.ignore_exceptions is not None:
if excepted_types is None:
new_exception_list = None
else:
new_exception_list = list(set(excepted_types +
self.ignore_exceptions))

# Set the new ignore list, and save the old one
old_exception_handlers, self.exception_handlers = \
self.exception_handlers, new_exception_handlers
old_exception_handlers, self.ignore_exceptions = \
self.ignore_exceptions, new_exception_list

# Run recursively on all sub nodes with the new ignore list
node.body = [self.visit(node_item) for node_item in node.body]

# Revert changes from ignore list
self.exception_handlers = old_exception_handlers
self.ignore_exceptions = old_exception_handlers

# pylint: disable=invalid-name
def visit_Call(self, node):
"""Propagate 'debug' wrapper into inner function calls if needed.
Args:
node (ast.AST): node statement to surround.
"""
if self.depth == 0:
return node

if self.ignore_exceptions is None:
ignore_exceptions = ast.Name("None", ast.Load())

else:
exception_names = [ast.Name(exception, ast.Load())
for exception in self.ignore_exceptions]
ignore_exceptions = ast.List(exception_names, ast.Load())

catch_exception_type = self.catch_exception if self.catch_exception \
else "None"

catch_exception = ast.Name(catch_exception_type, ast.Load())
depth = ast.Num(self.depth - 1 if self.depth > 0 else -1)

debug_node_name = ast.Name("debug", ast.Load())
call_extra_parameters = [] if IS_PYTHON_3 else [None, None]
node.func = ast.Call(debug_node_name,
[node.func, ignore_exceptions,
catch_exception, depth],
[], *call_extra_parameters)

return node

def generic_visit(self, node):
"""Surround node statement with a try/except block to catch errors.
Expand All @@ -179,25 +245,13 @@ def generic_visit(self, node):
if (isinstance(node, ast.stmt) and
not isinstance(node, ast.FunctionDef)):

is_python_3 = sys.version_info > (3, 0)
ast_try_except = ast.Try if is_python_3 else ast.TryExcept
try_except_extra_params = {"finalbody": []} if is_python_3 else {}

new_node = ast_try_except(
orelse=[],
body=[node],
handlers=self.exception_handlers,
**try_except_extra_params)
new_node = self.wrap_with_try(node)

# handling try except statement
if isinstance(node, ast_try_except):
if isinstance(node, self.ast_try_except):
self.try_except_handler(node)
ast.copy_location(new_node, node)
return new_node

# Set new node location as old node
ast.copy_location(new_node, node)

# Run recursively on all sub nodes
super(ErrorsCatchTransformer, self).generic_visit(node)

Expand All @@ -207,7 +261,26 @@ def generic_visit(self, node):
return super(ErrorsCatchTransformer, self).generic_visit(node)


def debug(victim, ignore_exceptions=(), catch_exception=None):
def get_last_lineno(node):
"""Recursively find the last line number of the ast node."""
max_lineno = 0

if hasattr(node, "lineno"):
max_lineno = node.lineno

for _, field in ast.iter_fields(node):
if isinstance(field, list):
for value in field:
if isinstance(value, ast.AST):
max_lineno = max(max_lineno, get_last_lineno(value))

elif isinstance(field, ast.AST):
max_lineno = max(max_lineno, get_last_lineno(field))

return max_lineno


def debug(victim, ignore_exceptions=(), catch_exception=None, depth=0):
"""A decorator function to catch exceptions and enter debug mode.
Args:
Expand All @@ -216,6 +289,7 @@ def debug(victim, ignore_exceptions=(), catch_exception=None):
ignore_exceptions (list): list of classes of exceptions not to catch.
catch_exception (type): class of exception to catch and debug.
default is None, meaning catch all exceptions.
depth (number): how many levels of inner function calls to propagate.
Returns:
object. wrapped class or function.
Expand All @@ -230,7 +304,8 @@ def debug(victim, ignore_exceptions=(), catch_exception=None):

_transformer = ErrorsCatchTransformer(
ignore_exceptions=ignore_exceptions,
catch_exception=catch_exception)
catch_exception=catch_exception,
depth=depth)

try:
# Try to get the source code of the wrapped object.
Expand All @@ -241,11 +316,7 @@ def debug(victim, ignore_exceptions=(), catch_exception=None):
except IOError:
# Worst-case scenario we can only catch errors at a granularity
# of the whole function
@functools.wraps(victim)
def wrapper(*args, **kw):
return victim(*args, **kw)

return wrapper
return victim

else:
# If we have access to the source, we can silence errors on a
Expand All @@ -257,7 +328,8 @@ def wrapper(*args, **kw):
tree = _transformer.visit(old_code_tree)

import_debug_cmd = ast.ImportFrom(
__name__, [ast.alias("start_debugging", None)], 0)
__name__, [ast.alias("start_debugging", None),
ast.alias("debug", None)], 0)

# Add import to the debugger as first command
tree.body[0].body.insert(0, import_debug_cmd)
Expand All @@ -280,17 +352,12 @@ def wrapper(*args, **kw):
# Delete the debugger decorator of the function
del tree.body[0].decorator_list[:]

# Index of the function (first original command in it)
first_command_index = 1 + len(ignore_exceptions)
if catch_exception is not None:
first_command_index += 1

# Add pass at the end (to enable debugging the last command)
pass_cmd = ast.Pass()
func_body = tree.body[0].body
pass_cmd.lineno = func_body[-1].lineno + 1 # Next of the last line
pass_cmd.col_offset = func_body[first_command_index].col_offset
func_body.insert(len(func_body), pass_cmd)
pass_cmd.lineno = get_last_lineno(func_body[-1]) + 1
pass_cmd.col_offset = func_body[-1].col_offset
func_body.append(pass_cmd)

# Fix missing line numbers and column offsets before compiling
for node in ast.walk(tree):
Expand Down
Loading

0 comments on commit 6b0219a

Please sign in to comment.