In [1]:
#export
from __future__ import annotations
import sys, ast
from fastcore.docments import docments

def _verify_version():
    "Returns if Python version is < 3.9"
    return sys.version_info.major <= 3 and sys.version_info.minor < 9

def unparse(o:str):
    "Unparses `o` with the correct unparser"
    if o is None: return ''
    if _verify_version():
        import astunparse
        return astunparse.unparse(o).rstrip()
    else:
        return ast.unparse(o)
    
def parse(o:str): 
    "Shortcut for `ast.parse`"
    return ast.parse(o)

In [2]:
source = '''@delegates()
def addition(
    a:(int, float), # The first number to add
    # The second number to add
    b:int = 2,
) -> (int,float): # The sum of a and b
    "Adds two numbers together"
    return a+b'''

In [3]:
#export
def get_annotations(
    parsed_function:ast.FunctionDef # Parsed function
):
    "Extracts type annotations from a single function"
    arg_annos = []
    for anno in parsed_function.args.args:
        arg_annos.append(unparse(anno.annotation))
        anno.annotation = None
    ret_anno = unparse(parsed_function.returns)
    parsed_function.returns = None
    return arg_annos, ret_anno

In [4]:
arg_text = '\n\nParameters\n----------\n'
arg_format = """<arg_name> : <arg_type>
  <arg_documentation>
"""
return_text = "\nReturns\n-------\n"
return_format = """<return_type>
  <return_documentation>
"""
template = {
    "args": {
        "text": arg_text,
        "format": arg_format
    },
    "return": {
        "text": return_text,
        "format": return_format
    }
}

In [181]:
#export
def reformat_function(
    source:str, # Source code
    template:dict, # Template to format code with
):
    "Refactors `source` function"
    docs = docments(source)
    parsed_source = parse(source).body[0]
    annos = get_annotations(parsed_source)
    if isinstance(parsed_source.body[0], ast.Expr):
        docstring = unparse(parsed_source.body[0]).replace("'", '')
        start_lineno = parsed_source.body[0].end_lineno
    else:
        docstring = ""
        start_lineno = parsed_source.body[0].end_lineno - 1
    body = '\n'.join(source.split('\n')[start_lineno:])
    orig = list(docs.keys())
    if "self" in orig: orig.remove("self")
    if "cls" in orig: orig.remove("cls")
    if "return" in orig: orig.remove("return")
    if len(orig) != 0:
        docstring += template["args"]["text"]
        for i, name in enumerate(docs.keys()):
            if name == "self": continue
            elif name != "return":
                doc = docs[name] or ''
                anno = annos[0][i]
                new_str = (template["args"]["format"]
                    .replace("<arg_name>", name)
                    .replace("<arg_documentation>", doc))
                if anno != "":
                    new_str = new_str.replace("<arg_type>", anno)
                else:
                    new_str = new_str.replace(" : <arg_type>", anno)
                docstring += new_str
            else:
                doc = docs[name]
                anno = annos[1]
                docstring += template["return"]["text"]
                docstring += (template["return"]["format"]
                    .replace("<return_type>", anno)
                    .replace("<return_documentation>", doc))
    offset = parsed_source.body[0].col_offset
    docstring = '\n'.join([f'{" "*offset}{line}' for line in docstring.split("\n")])
    docstring = parse(f'"""{docstring.lstrip()}"""')
    if not isinstance(parsed_source.body[0], ast.Expr):
        parsed_source.body.insert(0, docstring)
    else:
        parsed_source.body[0] = docstring
    parsed_source.body = parsed_source.body[:1]
    return f'{unparse(parsed_source)}\n{body}'

In [182]:
source = """@delegates()
def addition(
    a:(int, float), # The first number to add
    # The second number to add
    b:int = 2,
) -> (int,float): # The sum of a and b
    "Adds two numbers together"
    def _inner(): return a+b
    # This is a comment!
    return _inner()"""

In [183]:
s = reformat_function(
    source,
    template
)

In [184]:
print(s)

@delegates()
def addition(a, b=2):
    """Adds two numbers together
    
    Parameters
    ----------
    a : (int, float)
      The first number to add
    b : int
      The second number to add
    
    Returns
    -------
    (int, float)
      The sum of a and b
    """
    def _inner(): return a+b
    # This is a comment!
    return _inner()


In [185]:
source = '''class Arithmetic:
    "A class that can perform basic arithmetic on ops"
    _o = 2
    # Here's a comment
    _b = 5
    _c = 3
    
    class A:
        "My docstring"
        def __init__(
          self, 
          o:int # An integer
        ):
            self.o = o
    
    def __init__(
        self,
        a:int, # The first number to use
        b:(int, float), # The second number to use
    ):
        self.a = a
        self.b = b
        
    @delegates()
    def add(
        self
    ) -> (int,float): # Sum of a and b
        "Adds self.a and self.b"
        return self.a + self.b'''

In [186]:
def reformat_class(
    source:str,
    template
):
    parsed_class = parse(source)
    new_class = ''
    for item in parsed_class.body[0].body[1:]:
        orig_source = '\n'.join(source.split('\n')[item.lineno-1:item.end_lineno])
        if orig_source[0].isspace():
            orig_source = '\n'.join([o[item.col_offset:] for o in orig_source.split('\n')])
        if isinstance(item, ast.FunctionDef):
            refactored = reformat_function(
                orig_source,
                template
            ).split('\n')
            refactored = '\n'.join([f'{" "*item.col_offset}{o}' for o in refactored])
            new_class += f'\n\n{refactored}'
        elif isinstance(item, ast.ClassDef):
            # De-indent inner class
            refactored = reformat_class(orig_source, template).split('\n')
            if '"' in refactored[1] and refactored[2] == '':
                refactored.pop(2)
            refactored = '\n'.join([f'{" "*item.col_offset}{o}' for o in refactored])
            new_class += f'\n\n{refactored}'
        else:
            new_class += f'\n{" "*item.col_offset}{orig_source}'
    parsed_class.body[0].body = parsed_class.body[0].body[:1]
    return f'{unparse(parsed_class)}{new_class}'

In [187]:
parsed_source = reformat_class(source, template)

In [188]:
print(parsed_source)

class Arithmetic:
    """A class that can perform basic arithmetic on ops"""
    _o = 2
    _b = 5
    _c = 3

    class A:
        """My docstring"""
        def __init__(self, o):
            """Parameters
            ----------
            o : int
              An integer
            """
            self.o = o

    def __init__(self, a, b):
        """Parameters
        ----------
        a : int
          The first number to use
        b : (int, float)
          The second number to use
        """
        self.a = a
        self.b = b

    def add(self):
        """Adds self.a and self.b"""
        return self.a + self.b
