In [25]:
import javalang
from javalang.tree import ClassDeclaration, Declaration
from typing import List, Union

In [26]:
# Java source file to analyse; a relative path, so it is resolved against the
# current working directory of the kernel.
JAVA_FILE_PATH = "../InnerClass.java"
# Parse the whole file once; `tree` is the javalang parse result that the
# class-discovery cells below operate on.
with open(JAVA_FILE_PATH, "r") as java_file:
    tree = javalang.parse.parse(java_file.read())

In [27]:
class JavaClass:
    """A class found in a Java source file, together with its nested classes.

    Instances form a tree: every entry of ``inner_classes`` has its
    ``parent_class`` set to the instance that owns it.
    """

    def __init__(
        self,
        name: str,
        position: int,
        end_position: int,
        parent_class: "JavaClass" = None,
        inner_classes: List["JavaClass"] = None,
    ) -> None:
        """Create a class node and adopt the given nested classes.

        Args:
            name: Name of the class.
            position: Where the class starts. Stored and printed as given, so
                any object with a useful ``str`` works (this notebook passes
                the ``position`` of javalang declarations).
            end_position: Where the class ends. ``-1`` means "unknown"; the
                ``end_position`` property then defers to the parent class.
            parent_class: Enclosing class, or ``None`` for a top-level class.
            inner_classes: Nested classes. Only a non-empty ``list`` is
                adopted; every element gets ``parent_class`` set to ``self``.
                Any other value leaves ``inner_classes`` empty.
        """
        self.name = name
        self.parent_class = parent_class
        self.position = position
        self._end_position = end_position
        self.inner_classes = []
        if inner_classes and isinstance(inner_classes, list):
            # Re-parent the children so they can resolve an unknown end
            # position through this instance.
            for inner_class in inner_classes:
                inner_class.parent_class = self
            self.inner_classes = inner_classes

    def __str__(self) -> str:
        return f"{self.name} @ {self.position} <> {self.end_position}"

    @property
    def end_position(self):
        """End position of the class, inherited from the parent when unknown.

        Returns the stored end position unless it is the ``-1`` sentinel and a
        parent exists, in which case the parent's ``end_position`` is returned.
        A parentless class with the sentinel returns ``-1`` instead of failing
        on the missing parent.
        """
        if self._end_position == -1 and self.parent_class is not None:
            return self.parent_class.end_position
        return self._end_position

In [28]:
def get_loc(java_file_path):
    """Return the number of newline-separated segments in the given file.

    The result is the number of newline characters plus one, so a file whose
    last line ends with a newline counts one extra, empty segment and an empty
    file yields 1.
    """
    with open(java_file_path, "r") as source_file:
        newline_count = source_file.read().count('\n')
    return newline_count + 1
def get_position(declaration: Union[Declaration, List[Declaration]]):
    """Return the ``position`` of a declaration or of the head of a list.

    A list is unwrapped through its first element, repeatedly if the lists are
    nested, until a non-list object is reached; that object's ``position``
    attribute is returned. An empty list raises ``IndexError``.
    """
    node = declaration
    while isinstance(node, list):
        node = node[0]
    return node.position
def parse_class_declaration(class_declaration: ClassDeclaration, next_position=-1) -> JavaClass:
    """Build a ``JavaClass`` for a class declaration and its nested classes.

    ``next_position`` is stored as the end position of the new class and is
    also handed to ``parse_class_children`` as the end position that nested
    classes fall back to. A declaration with an empty or missing body gets no
    nested classes.
    """
    body = class_declaration.body
    nested = list(parse_class_children(body, parent_next_position=next_position)) if body else []
    return JavaClass(
        class_declaration.name,
        class_declaration.position,
        next_position,
        inner_classes=nested,
    )
def parse_class_children(children: List[Declaration], parent_next_position) -> List[JavaClass]:
    """Convert the class declarations among ``children`` into ``JavaClass`` objects.

    A nested class is considered to end where the element that follows it in
    ``children`` starts. When nothing follows it, or the following element has
    no resolvable position, it ends at ``parent_next_position``.

    Args:
        children: Members of a class body; consecutive elements are treated
            as neighbouring declarations.
        parent_next_position: End position of the enclosing class, used as the
            fallback end position.

    Returns:
        One ``JavaClass`` per ``ClassDeclaration`` in ``children``, in the
        same order. Every other kind of member is skipped.
    """
    nested_classes = []
    for index, declaration in enumerate(children):
        if not isinstance(declaration, ClassDeclaration):
            continue
        next_position = parent_next_position
        if index + 1 < len(children):
            try:
                next_position = get_position(children[index + 1])
            except IndexError:
                # get_position raises IndexError when the sibling is an empty
                # list; fall back to the enclosing class's end position.
                next_position = parent_next_position
        # Parsing happens outside the try block so that an error raised while
        # parsing the nested class is never mistaken for "no next sibling".
        nested_classes.append(parse_class_declaration(declaration, next_position))
    return nested_classes
def find_classes(tree, java_file_path=None):
    """Collect the top-level classes of a parsed Java file.

    Each class ends where the next top-level type declaration starts. The last
    declaration has no successor, so its end is the line count of the source
    file as reported by ``get_loc``.

    Args:
        tree: Parse result exposing the top-level declarations as ``types``.
        java_file_path: Path of the file ``tree`` was parsed from, used only
            to compute the end of the last class. Defaults to the notebook's
            ``JAVA_FILE_PATH`` when omitted.

    Returns:
        A list with one ``JavaClass`` per top-level ``ClassDeclaration``, in
        order; other kinds of type declaration are skipped.

    Note:
        The end position of the last class is an ``int`` (a line count),
        whereas the other end positions are whatever ``get_position`` returns.
    """
    if java_file_path is None:
        java_file_path = JAVA_FILE_PATH
    type_declarations = tree.types
    classes = []
    for index, declaration in enumerate(type_declarations):
        if not isinstance(declaration, ClassDeclaration):
            continue
        if index + 1 < len(type_declarations):
            next_statement_position = get_position(type_declarations[index + 1])
        else:
            # Reached the end of the file: the last position is the line count.
            next_statement_position = get_loc(java_file_path)
        classes.append(parse_class_declaration(declaration, next_position=next_statement_position))
    return classes


In [29]:
# Show the value find_classes uses as the end marker of the last top-level declaration.
get_loc(JAVA_FILE_PATH)

23

In [30]:
def walk_classes(java_class: JavaClass):
    """Print a class and, recursively, all of its nested classes.

    A list is also accepted: each element is walked in turn and nothing is
    printed for the list itself. For a single class, the class is printed
    first and its ``inner_classes`` are then walked in order.
    """
    if isinstance(java_class, list):
        for entry in java_class:
            walk_classes(entry)
        return
    print(java_class)
    for nested in java_class.inner_classes or ():
        walk_classes(nested)

In [34]:
# Collect the top-level classes of the parsed file, each carrying its nested classes.
# NOTE(review): this cell's execution counter (34) is out of sequence with the
# cells around it; confirm the notebook reproduces after Restart & Run All.
classes = find_classes(tree)

In [32]:
# Print every class, each followed by its nested classes (depth-first).
walk_classes(classes)

OuterClass @ Position(line=3, column=8) <> Position(line=21, column=9)
InnerClass @ Position(line=10, column=5) <> Position(line=21, column=9)
InnerMostClass @ Position(line=13, column=9) <> Position(line=17, column=16)
AnotherInnerClass @ Position(line=17, column=16) <> Position(line=21, column=9)
AnotherClass @ Position(line=21, column=9) <> 23
