In [159]:
from typing import Dict, Union, Literal, Tuple, Generator, List, Any
from pydantic import BaseModel
import tokenize
from io import StringIO
import inspect
import pandas as pd
from crimson.data_class import DataClassList
import inspect
import ast
from typing import Dict, Tuple, Any

class TokenInfo(BaseModel):
    type: str
    value: Any
    start_line: int
    start_column: int
    end_line: int
    end_column: int

class TokenInfos(DataClassList):
    data: List[TokenInfo]

class _Tokenize:
    @classmethod
    def _get_tokens(cls, function_block: str):
        tokens = tokenize.generate_tokens(StringIO(function_block).readline)
        return tokens

    @classmethod
    def get_token_info(cls, function_block: str) -> TokenInfos:
        tokens = cls._get_tokens(function_block)
        token_infos = []
        for toknum, tokval, start, end, _ in tokens:
            token_infos.append(
                TokenInfo(
                    type=tokenize.tok_name[toknum],
                    value=tokval,
                    start_line=start[0] - 1,
                    start_column=start[1],
                    end_line=end[0] - 1,
                    end_column=end[1]
                    )
            )
        return TokenInfos(data=token_infos)

In [160]:
class _GetArgs:
    @classmethod
    def _get_tokens(cls, token_info: List[TokenInfo]):

        pass

In [161]:
from example import example_func

In [162]:
code = inspect.getsource(example_func)

In [163]:
code

'def example_func(arg1: int = 1, arg2: Dict[str, A] = {"hi": A()}) -> int:\n    more_line = 1\n    return more_line\n'

In [164]:
token_infos = _Tokenize.get_token_info(code)

In [165]:
list_token_info = token_infos()

In [166]:
def reconstruct_code(list_token_info:List[TokenInfo]) -> str:
    lines = {}

    for token_info in list_token_info:
        token_value = token_info.value
        start_line = token_info.start_line
        start_col = token_info.start_column

        if start_line not in lines:
            lines[start_line] = [" " * start_col]
        else:
            lines[start_line].append(" " * (start_col - len("".join(lines[start_line]))))

        lines[start_line].append(token_value)
        
    # Join the lines together
    code_lines = []
    for line_num in sorted(lines.keys()):
        code_lines.append("".join(lines[line_num]))
        
    return "".join(code_lines)

In [167]:
print(reconstruct_code(list_token_info[5: 10]))

                       int = 1, arg2


In [168]:
def get_arg_info(arg_block):
    arg_info = {
        'name': arg_block[0]
    }

In [169]:
token_infos.get_dataframe()

Unnamed: 0,type,value,start_line,start_column,end_line,end_column
0,NAME,def,0,0,0,3
1,NAME,example_func,0,4,0,16
2,OP,(,0,16,0,17
3,NAME,arg1,0,17,0,21
4,OP,:,0,21,0,22
5,NAME,int,0,23,0,26
6,OP,=,0,27,0,28
7,NUMBER,1,0,29,0,30
8,OP,",",0,30,0,31
9,NAME,arg2,0,32,0,36


In [170]:
def get_arg_block_indexes(token_infos):
    list_token_info = token_infos()

    for i, token_info in enumerate(list_token_info):
        if all([token_info.type == 'OP', token_info.value == '(']) :
            start_index = i + 1
            break
        
    for i, token_info in enumerate(list_token_info):
        if all([token_info.type == 'OP', token_info.value == ')']) :
            end_index = i
            break
    
    return start_index, end_index

In [171]:
def split_arg_splitter_indexes(token_infos):
    start_index, end_index = get_arg_block_indexes(token_infos)
    
    list_token_info = token_infos()

    splitter = []
    for i in range(start_index, end_index):
        token_info = list_token_info[i]
        if all([token_info.type == 'OP', token_info.value == ',']) :
            splitter.append(i)
    
    return splitter

In [172]:
get_arg_block_indexes(token_infos)

(3, 23)

In [173]:
print(reconstruct_code(list_token_info[3: 26]))

                 arg1: int = 1, arg2: Dict[str, A] = {"hi": A()})


In [174]:
split_arg_splitter_indexes(token_infos)

[8, 14]

In [184]:
from typing import Optional

class A:
    pass  # Placeholder class for demonstration

def example_func(arg1: int = 1, arg2: Dict[str, A] = {"hi": A()}) -> Tuple[int, A, Optional[Dict[str, int]]]:
    more_line = 1
    another_out = A()
    complex_out = None
    return more_line, another_out, complex_out

In [185]:
out = example_func()

In [186]:
class OutProps(BaseModel):
    data:Tuple[int, A, Optional[Dict[str, int]]]

    class Config:
        arbitrary_types_allowed = True

In [187]:
OutProps(data=out)

OutProps(data=(1, <__main__.A object at 0x7feb17be5f30>, None))

In [188]:
class _GetArgInfoHolder:
    @classmethod
    def extract_return_info(cls, source: str) -> Dict[str, str]:
        tree = ast.parse(source)
        func_def = tree.body[0]

        return_type = None
        if func_def.returns:
            return_type = ast.unparse(func_def.returns)
            if return_type.startswith("Tuple"):
                # Extract return variable names and their types
                return_expr = func_def.body[-1]
                if isinstance(return_expr, ast.Return):
                    if isinstance(return_expr.value, ast.Tuple):
                        elements = return_expr.value.elts
                        if isinstance(func_def.returns, ast.Subscript):
                            types = func_def.returns.slice
                            if isinstance(types, ast.Tuple):
                                return_elements = types.elts
                            else:
                                return_elements = [types]
                        else:
                            return_elements = []

                        return_info = []
                        for name, typ in zip(elements, return_elements):
                            if isinstance(name, ast.Name):
                                var_name = name.id
                            else:
                                var_name = ast.unparse(name)
                            
                            var_type = ast.unparse(typ)
                            return_info.append({
                                "name": var_name,
                                "type_hint": var_type
                            })
                        return return_info

        return None
    
    @classmethod
    def extract_arg_info(cls, source: str)->Dict[str, str]:
        tree = ast.parse(source)
        func_def = tree.body[0]
        
        arg_info_list = []
        total_args = len(func_def.args.args)
        total_defaults = len(func_def.args.defaults)
        
        for i, arg in enumerate(func_def.args.args):
            name = arg.arg
            if arg.annotation:
                type_hint = ast.unparse(arg.annotation)
            else:
                type_hint = None
            
            default_value = None
            if i >= total_args - total_defaults:
                default_index = i - (total_args - total_defaults)
                default_value = func_def.args.defaults[default_index]
                default_value = ast.unparse(default_value)
            
            arg_info_list.append({
                "name": name,
                "type_hint": type_hint,
                "default": default_value
            })
        
        return arg_info_list
    
    @classmethod
    def extract_return_type(cls, source: str)->str:
        tree = ast.parse(source)
        func_def = tree.body[0]

        if func_def.returns:
            return_type = ast.unparse(func_def.returns)
        else:
            return_type = None

        return return_type

In [189]:
source = inspect.getsource(example_func)

In [190]:
_GetArgInfoHolder.extract_arg_info(source)

[{'name': 'arg1', 'type_hint': 'int', 'default': '1'},
 {'name': 'arg2', 'type_hint': 'Dict[str, A]', 'default': "{'hi': A()}"}]

In [191]:
_GetArgInfoHolder.extract_return_info(source)

[{'name': 'more_line', 'type_hint': 'int'},
 {'name': 'another_out', 'type_hint': 'A'},
 {'name': 'complex_out', 'type_hint': 'Optional[Dict[str, int]]'}]

In [192]:
_GetArgInfoHolder.extract_return_type(source)

'Tuple[int, A, Optional[Dict[str, int]]]'