In [None]:
from xlang.xl_ast import (
    VariableType,
    GlobalScope,
    VariableTypeEnum,
    PrimitiveType,
    VariableDeclaration,
    VariableDefinition,
    BaseExpression,
    FunctionCall,
    VariableAccess,
    Constant,
    ConstantType,
    OperatorExpression,
    VariableAssign,
)
from typing import Optional
from xlang.validation_pass import typeify
from xlang.parser import Parser

In [None]:
class ScopeStack:
    def def_variable(self, name: str, variable_type: VariableType):
        # check if not already  existing
        # add to dict
        pass

    def get_variable_type(self, name: str) -> Optional[VariableType]:
        pass

    def push_scope(self):
        pass

    def pop_scope(self):
        pass


def typeify(base_type: VariableType, global_scope: GlobalScope):
    if base_type.variable_type == VariableTypeEnum.ARRAY:
        assert base_type.array_type.variable_type == VariableTypeEnum.UNKNOWN
        base_type.array_type = get_type_from_string(
            global_scope, base_type.array_type.type_name
        )
    elif base_type.variable_type == VariableTypeEnum.UNKNOWN:
        base_type = get_type_from_string(global_scope, base_type.type_name)
    else:
        raise Exception("Unhandled type in struct validation pass")
    return base_type


def primitive(primitive_type: PrimitiveType) -> VariableType:
    return VariableType(VariableTypeEnum.PRIMITIVE, primitive_type=primitive_type)


def get_type_from_string(global_scope: GlobalScope, type_name: str) -> VariableType:
    # int is just an alias for i64
    if type_name == "int":
        type_name = "i64"
    if type_name == "i64":
        return primitive(PrimitiveType.I64)
    elif type_name == "i32":
        return primitive(PrimitiveType.I64)
    elif type_name == "i16":
        return primitive(PrimitiveType.I64)
    elif type_name == "i8":
        return primitive(PrimitiveType.I64)
    elif type_name == "i8":
        return primitive(PrimitiveType.I64)
    elif type_name == "u64":
        return primitive(PrimitiveType.I64)
    elif type_name == "u32":
        return primitive(PrimitiveType.I64)
    elif type_name == "u16":
        return primitive(PrimitiveType.I64)
    elif type_name == "u8":
        return primitive(PrimitiveType.I64)
    elif type_name == "float":
        return primitive(PrimitiveType.FLOAT)
    elif type_name == "string":
        return primitive(PrimitiveType.STRING)
    elif type_name in global_scope.structs:
        return VariableType(VariableTypeEnum.STRUCT, type_name=type_name)
    else:
        raise Exception(f"Unknown type: {type_name}")

In [None]:
def is_type_compatible(variable_type_a: VariableType, variable_type_b: VariableType) -> bool:
    if variable_type_a.variable_type != variable_type_b.variable_type:
        return False
    if variable_type_a.variable_type == VariableTypeEnum.ARRAY:
        return is_type_compatible(variable_type_a.array_type, variable_type_b.array_type)
    elif variable_type_a.variable_type == VariableTypeEnum.STRUCT:
        return variable_type_a.type_name == variable_type_b.type_name
    elif variable_type_a.variable_type == VariableTypeEnum.PRIMITIVE:
        # todo i32 is also compatible to i64, i32 is compatible with u64 etc.
        return variable_type_a.primitive_type == variable_type_b.primitive_type
    else:
        raise Exception("Unhandled type in is_type_compatible")

class Typeifier:
    def __init__(self, global_scope):
        self.global_scope = global_scope
        self.scope_stack = ScopeStack()

    def statements(self, statements):
        for statement in statements:
            self.statement(statement)

    def statement(self, statement):
        if isinstance(statement, VariableDeclaration):
            statement.variable_type = typeify(statement.variable_type, self.global_scope)
            self.scope_stack.def_variable(statement.name, statement.variable_type)
        elif isinstance(statement, VariableDefinition):
            statement.variable_type = typeify(statement.variable_type, self.global_scope)
            self.scope_stack.def_variable(statement.name, statement.variable_type)
            value_type = self.expression(statement.value)
            if not is_type_compatible(statement.variable_type, value_type):
                raise Exception("Incompatible value type")
        elif isinstance(statement, VariableAssign):
            value_type = self.expression(statement.value)
            variable_type = self.scope_stack.get_variable_type(statement.name)
            if not variable_type:
                raise Exception(f"Unknown variable {statement.name}")
            if not is_type_compatible(statement.variable_type, variable_type):
                raise Exception("Incompatible value type")
        else:
            raise Exception("Unhandled statement")
    
    def expression(self, expression: BaseExpression):
        if isinstance(expression, FunctionCall):
            # check if function actually exists
            if not expression.function_name in self.global_scope.functions:
                raise Exception(f"Unknown function called: {expression.function_name}")
            function = self.global_scope.functions[expression.function_name]
            
            # check correct count of params given in call
            if len(expression.params) != len(function.function_params):
                raise Exception(f"function {expression.function_name} takes {len(function.function_params)} params,"
                                f"{len(expression.params)} given")
            
            # evaluate parameters and check if type matches
            for (param, param_type) in zip(expression.params, function.function_params):
                expression_type = self.expression(param)
                if not is_type_compatible(expression_type, param_type):
                    raise Exception("Invalid function parameter type")
            expression.type = function.return_type

        elif isinstance(expression, VariableAccess):
            if expression.variable_access is not None:
                raise NotImplementedError("Recursive variable access not implemented")
            if expression.array_access is not None:
                raise NotImplementedError("Array access not implemented")
            
            variable_type = self.scope_stack.get_variable_type(expression.variable_name)
            if not variable_type:
                raise Exception(f"Unknown variable {expression.variable_name}")
            expression.type = variable_type

        elif isinstance(expression, Constant):
            if expression.constant_type == ConstantType.STRING:
                expression.type = VariableType(VariableTypeEnum.PRIMITIVE, primitive_type=PrimitiveType.STRING)
            elif expression.constant_type == ConstantType.FLOAT:
                expression.type = VariableType(VariableTypeEnum.PRIMITIVE, primitive_type=PrimitiveType.FLOAT)
            elif expression.constant_type == ConstantType.INTEGER:
                # TODO: use finer graded primitive type
                expression.type = VariableType(VariableTypeEnum.PRIMITIVE, primitive_type=PrimitiveType.I64)
            else:
                raise Exception("Internal compiler error: Unknown constant type")

        elif isinstance(expression, OperatorExpression):
            operand1_type = self.expression(operand1)
            operand2_type = self.expression(operand2)
            if not is_type_compatible(operand1_type, operand2_type):
                raise Exception("Incompatible type in operator expressions")
            expression.type = operand1_type
        else:
            raise Exception("Unknown expression")
        return expression.type


def validation_pass(global_scope: GlobalScope):
    for struct in global_scope.structs.values():
        for member in struct.members:
            member.param_type = typeify(member.param_type, global_scope)

    for function in global_scope.functions.values():
        if function.return_type:
            function.return_type = typeify(function.return_type, global_scope)
        for parameter in function.function_params:
            parameter.param_type = typeify(parameter.param_type, global_scope)

    for function in global_scope.functions.values():
        typeifier = Typeifier(global_scope)
        typeifier.statements(function.statements)

parser = Parser()
ast: GlobalScope = parser.parse(
    """
    main() {
        a: int = 5;
    }
    """
)
validation_pass(ast)
ast.functions['main']

In [None]:
"""
    struct MyStruct {
        a: int,
    }

    intfunc(): int {
        return 5;
    }

    param_func(a: int, b: u16, c: [int], d: MyStruct): [MyStruct] {}
"""