In [44]:
from typing import Dict, List
import ast
from ast import ClassDef, ImportFrom, FunctionDef, Return
from glob import glob
import os
from functools import lru_cache

In [54]:
def get_python_code_from_import(
    original_file_path: str,
    import_from: Dict
):
    element = import_from["element"]
    source_path = os.path.join(
        *original_file_path.split(os.sep)[:-element.level],
        f"{element.module}.py"
    )
    
    if not os.path.exists(source_path):
        source_path = os.path.join(
            *original_file_path.split(os.sep)[:-element.level],
            f"{element.module}/__init__.py"
        )
    
    expected_class_name = import_from["name"]
    
    with open(source_path, "r") as f:
        python_code = f.read()
    
    parsed = ast.parse(python_code)
    klasses = get_classes(parsed)
    imports = get_imports(parsed)
    
    desired_klass = None
    
    # We search for the class here
    for klass in klasses:
        if klass.name == expected_class_name:
            desired_klass = klass
            
    # We need to go look for it in the imports
    for import_name, import_from in imports.items():
        # If this is a parent class
        if import_name == expected_class_name:
            # Get the imported path and parsed code.
            return get_python_code_from_import(
                source_path,
                import_from
            )
        
    return (
        original_file_path,
        desired_klass,
        klasses,
        imports
    )
    

def get_class_parent_names(
    original_file_path: str,
    klass: ClassDef,
    klasses: List[ClassDef],
    imports: List[Dict],
    expected_parent: str
) -> List[str]:
    """Return list of parent classes names."""
    # First we search for the parent class in the same file.
    parent_names = [
        base.id
        for base in klass.bases
    ]
    
    if expected_parent in parent_names:
        return parent_names
    
    # If the class does not have parents
    if len(parent_names) == 0:
        return []
    
    # Otherwise we search for the parents in this same directory
    for candidate_parent in klasses:
        # If this is a parent class
        if candidate_parent.name in parent_names:
            # We find the parents of this class            
            parent_names.extend(
                get_class_parent_names(
                    original_file_path,
                    candidate_parent,
                    klasses
                )
            )
            
            if expected_parent in parent_names:
                return parent_names
            
    # Then we check if any of the imports are parents
    for import_name, import_from in imports.items():
        # If this is a parent class
        if import_name in parent_names:
            # Get the imported path and parsed code.
            parent_names.extend(get_class_parent_names(
                *get_python_code_from_import(
                    original_file_path,
                    import_from
                ),
                expected_parent
            ))
            
            if expected_parent in parent_names:
                return parent_names
            
    return parent_names

def get_imports(parsed) -> Dict[str, ImportFrom]:
    """Returns local imports identified in parsed Python code."""
    return {
        getattr(alias, "as_name", alias.name): {
            "element": element,
            "name": alias.name
        }
        for element in parsed.body
        if isinstance(element, ImportFrom)
        for alias in element.names
        if element.level > 0
    }

def get_classes(parsed) -> List[ClassDef]:
    """Returns classes identified in parsed Python code."""
    return [
        element
        for element in parsed.body
        if isinstance(element, ClassDef)
    ]

def find_method_name(klass: ClassDef) -> str:
    """Returns name extracted from class."""
    for function in klass.body:
        if not isinstance(function, FunctionDef) and function.name == "model_name":
            continue
        for function_line in function.body:
            if isinstance(function_line, Return):
                return function_line.value.s
    raise ValueError(
        "Unable to find the method `model_name` in the "
        f"model class {klass.name}."
    )

def build_init(
    path_pattern: str,
    expected_parent: str
):
    paths = glob(path_pattern)
    for path in paths:
        with open(path, "r") as f:
            python_code = f.read()
        parsed = ast.parse(python_code)
        klasses = get_classes(parsed)
        imports = get_imports(parsed)
        for klass in klasses:
            # If this class is an abstract.
            if len(klass.decorator_list) > 0 and klass.decorator_list[0].id == "abstract_class":
                continue
            # If this class has the expected parent.
            if expected_parent in get_class_parent_names(
                path,
                klass,
                klasses,
                imports,
                expected_parent
            ):
                print(klass.name, find_method_name(klass))
        

In [55]:
build_init("embiggen/embedders/ensmallen_embedders/*.py", "AbstractModel")

SkipGramEnsmallen SkipGram
TransEEnsmallen TransE
WeightedSPINE WeightedSPINE
SPINE SPINE
CBOWEnsmallen CBOW


In [14]:
[
    element
    for element in parsed.body
    if isinstance(element, ClassDef)
]

[<_ast.ClassDef at 0x7f8950d9b850>]

In [None]:
[
    element
    for element in parsed.body
    if isinstance(element, ClassDef)
]

In [24]:
[
    element
    for element in parsed.body
    if isinstance(element, ImportFrom)
]

[<_ast.ImportFrom at 0x7f8950d9b750>, <_ast.ImportFrom at 0x7f8950d9b7d0>]

In [20]:
parsed.body[-1].bases[0].id

'Node2Vec'

In [92]:
parsed.body[1].names[0].name

'Optional'

In [27]:
imports = [
    element
    for element in parsed.body
    if isinstance(element, ImportFrom)
]

In [38]:
imports[1].names[0].name

'Node2Vec'