In [1]:
# from debugger import epic_debugger, epic_debugger_decorator

In [2]:
# debug_always: will run the normal_debug_fn even if there is no exception
# enabled: will completely enable/disable the debugger, useful if you have some kind of global toggle
# do_pdb: activate python debugger after printing error and trace

# exception_fn: function to run when an exception is caught
# normal_debug_fn: function to run when no exception is caught

# for these two functions, any additional kwargs you provide will be fed to

In [3]:
import contextlib
import functools
import pdb
import traceback


# if you want to use as a context manager over certain sections of a function
@contextlib.contextmanager
def epic_debugger(debug_always=False, enabled=True, do_pdb=True, exception_fn=None, normal_debug_fn=None, **debug_kwargs):
    """
    :param debug_always: whether to run a portion of code regardless of whether an exception is raised
    :param enabled: whether to run the debugger at all, so can be easily enabled from an args/config without needing to change code
    :param do_pdb: activate pythons debugger if an exception is raised
    :param exception_fn: function to run if an exception is raised
    :param normal_debug_fn: function to run if no exception is raised
    """
    # default behavior
    try:
        yield

    # if there is an error
    except Exception as e:
        if enabled:
            traceback.print_exc()
            if exception_fn is not None:
                print("*"*10 + " BEGIN EXCEPTION_FN " + "*"*10)
                exception_fn(**debug_kwargs)
                print("*"*10 + " END EXCEPTION_FN " + "*"*10)
            if do_pdb:
                pdb.set_trace()
        raise e

    # if debug_always is enabled
    finally:
        if debug_always and enabled:
            if normal_debug_fn is not None:
                print("*"*10 + " BEGIN DEBUG_FN " + "*"*10)
                normal_debug_fn(**debug_kwargs)
                print("*"*10 + " END DEBUG_FN " + "*"*10)



# if you want to use as a decorator over an entire function
def epic_debugger_decorator(enabled=True, debug_always=False, do_pdb=True, exception_fn=None, normal_debug_fn=None, **debug_kwargs):
    """
    :param enabled: whether to run a portion of code regardless of whether an exception is raised
    :param debug_always:  whether to run the debugger at all, so can be easily enabled from an args/config without needing to change code
    :param do_pdb: activate pythons debugger if an exception is raised
    :param exception_fn: function to run if an exception is raised
    :param normal_debug_fn: function to run if no exception is raised
    """
    def debug_decorator(func):
        # a wrapper to go around your function
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                # run the actual function
                result = func(*args, **kwargs)
                # if debug_always is enabled run this portion
                if debug_always and enabled:
                    if normal_debug_fn is not None:
                        print("*"*10 + " BEGIN DEBUG_FN " + "*"*10)
                        normal_debug_fn(**debug_kwargs)
                        print("*"*10 + " END DEBUG_FN " + "*"*10)

                # return the output of the original function
                return result

            # if there is an error
            except Exception as e:
                if enabled:
                    traceback.print_exc()
                    if exception_fn is not None:
                        print("*"*10 + " BEGIN EXCEPTION_FN " + "*"*10)
                        exception_fn(**debug_kwargs)
                        print("*"*10 + " END EXCEPTION_FN " + "*"*10)
                    if do_pdb:
                        pdb.set_trace()
                # Re-raise the exception after handling
                raise e

        return wrapper
    return debug_decorator

In [45]:
# simple example
def broken_fn(x):
    return x / 0

with epic_debugger(debug_always=False, enabled=True, do_pdb=True, exception_fn=None, normal_debug_fn=None):
    broken_fn(1)



In [8]:
# using the decorator
@epic_debugger_decorator(debug_always=False, enabled=True, do_pdb=True, exception_fn=None, normal_debug_fn=None)
def broken_fn(x):
    return x / 0

broken_fn(1)

In [51]:
def another_broken_fn(a_dict, a_list):
    return a_dict + a_list

def print_named_vars(names=None):
    for name, obj in globals().items():
        if names is not None:
            if name not in names:
                continue
        print(name, obj)

In [52]:
fruits = {"apple": 1, "banana": 2}
numbers = [1, 2, 3]
with epic_debugger(debug_always=True, do_pdb=False, exception_fn=print_named_vars, names=["fruits", "numbers"]):
    another_broken_fn(fruits, numbers)

********** BEGIN EXCEPTION_FN **********
fruits {'apple': 1, 'banana': 2}
numbers [1, 2, 3]
********** END EXCEPTION_FN **********


Traceback (most recent call last):
  File "/tmp/ipykernel_4069491/1114188134.py", line 19, in epic_debugger
    yield
  File "/tmp/ipykernel_4069491/307613978.py", line 4, in <module>
    another_broken_fn(fruits, numbers)
  File "/tmp/ipykernel_4069491/4278361117.py", line 2, in another_broken_fn
    return a_dict + a_list
TypeError: unsupported operand type(s) for +: 'dict' and 'list'


TypeError: unsupported operand type(s) for +: 'dict' and 'list'

In [4]:
import torch

# an example of printing out torch tensor details when we fail to add two tensors of different shapes

def print_tensor_details(tensor, name=None):
    if name is None:
        name = ""
    print(f"{name} Device: {tensor.device}")
    print(f"{name} Type: {tensor.dtype}")
    print(f"{name} Shape: {tensor.shape}")
    print(f"{name} dtype: {tensor.dtype}")
    print(f"{name} Is Nan: {torch.isnan(tensor).any()}")
    print(f"{name} Is Inf: {torch.isinf(tensor).any()}")
    print(f"{name} Min: {torch.min(tensor)}")
    print(f"{name} Max: {torch.max(tensor)}")
    if tensor.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16]:
        print(f"{name} Mean: {torch.mean(tensor)}")
        print(f"{name} Std: {torch.std(tensor)}")
    if hasattr(tensor, "grad") and tensor.grad is not None:
        print_tensor_details(tensor.grad, name=name + " Grad" if name else "Grad")


def print_vars(names=None):
    for name, var in globals().items():

        # only perform debugging for specific variables if we provide that
        if names is not None:
            if name not in names:
                continue

        # if a tensor, print details
        if isinstance(var, torch.Tensor):
            print_tensor_details(var, name=name)
            print("---"*5)

        # if dictionary, tuple, or list, print a map of its hiearchy
        elif isinstance(var, (dict, tuple, list)):
            analyze_hierarchy(var, name)
            print("---"*5)


def recursive(obj, string_to_print, depth, keyname=None):
    string = f"{depth * '- '}"
    if keyname is not None:
        string += f"{keyname}"
    string += f" {type(obj)}"

    if isinstance(obj, (list, tuple)):
        string_to_print += [f"{string}, length:{len(obj)}"]
        for item in obj:
            string_to_print = recursive(item, string_to_print, depth + 1)

    elif isinstance(obj, dict):
        string_to_print += [f"{string}, length:{len(obj)}"]
        for key, value in obj.items():
            string_to_print = recursive(value, string_to_print, depth + 1, keyname=key)

    else:
        string_to_print += [f"{string}"]

    return string_to_print


def analyze_hierarchy(obj, name):
    """
    decomposes a dictionary, list, or tuple into its components and prints attributes
    """
    string_to_print = []
    depth = 0
    string_to_print = recursive(obj, string_to_print, depth, keyname=name)
    string_to_print = "\n".join(string_to_print)
    print(string_to_print)

In [140]:
thing = {"a": 1, "b": 2, "c": [1, 2, 3], "d": {"e": 1, "f": 2, "g": [(4,5,6), 2, 3]}}
tensor1 = torch.arange(4)
tensor2 = torch.arange(5)

def broken_fn(dictionary, a, b):
    print(a + b)

In [141]:
with epic_debugger(debug_always=True, do_pdb=False, exception_fn=print_vars, names=["thing", "tensor1", "tensor2"]):
    broken_fn(thing, tensor1, tensor2)

********** BEGIN EXCEPTION_FN **********
thing <class 'dict'>, length:4
- a <class 'int'>
- b <class 'int'>
- c <class 'list'>, length:3
- -  <class 'int'>
- -  <class 'int'>
- -  <class 'int'>
- d <class 'dict'>, length:3
- - e <class 'int'>
- - f <class 'int'>
- - g <class 'list'>, length:3
- - -  <class 'tuple'>, length:3
- - - -  <class 'int'>
- - - -  <class 'int'>
- - - -  <class 'int'>
- - -  <class 'int'>
- - -  <class 'int'>
---------------
tensor1 Device: cpu
tensor1 Type: torch.int64
tensor1 Shape: torch.Size([4])
tensor1 dtype: torch.int64
tensor1 Is Nan: False
tensor1 Is Inf: False
tensor1 Min: 0
tensor1 Max: 3
---------------
tensor2 Device: cpu
tensor2 Type: torch.int64
tensor2 Shape: torch.Size([5])
tensor2 dtype: torch.int64
tensor2 Is Nan: False
tensor2 Is Inf: False
tensor2 Min: 0
tensor2 Max: 4
---------------
********** END EXCEPTION_FN **********


Traceback (most recent call last):
  File "/tmp/ipykernel_4069491/1114188134.py", line 19, in epic_debugger
    yield
  File "/tmp/ipykernel_4069491/4073057097.py", line 2, in <module>
    broken_fn(thing, tensor1, tensor2)
  File "/tmp/ipykernel_4069491/3653530737.py", line 6, in broken_fn
    print(a + b)
RuntimeError: The size of tensor a (4) must match the size of tensor b (5) at non-singleton dimension 0


RuntimeError: The size of tensor a (4) must match the size of tensor b (5) at non-singleton dimension 0

In [5]:
# for a function that actually works, but may hide silent errors if tensor values are not checked
tensor1 = torch.arange(4) / 0
tensor2 = torch.arange(4)

def add_tensors(a,b):
    return a + b
    
# runs no problem
out = add_tensors(tensor1, tensor2)

In [6]:
# for this case, because no exception happens, we'll want to use normal_debug_fn and always_debug=True
def throw_if_nan():
    for name, var in globals().items():
        # if a tensor, print details
        if isinstance(var, torch.Tensor):
            if torch.isnan(var).any():
                print(f"NaN found in {name}")

In [7]:
@epic_debugger(debug_always=True, do_pdb=False, normal_debug_fn=throw_if_nan)
def add_tensors(a,b):
    return a + b

out = add_tensors(tensor1, tensor2)

********** BEGIN DEBUG_FN **********
NaN found in tensor1
NaN found in out
********** END DEBUG_FN **********
