From 2845d78ecb572b654c4b9150c08a1fc3721d6a5a Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Wed, 13 Sep 2023 22:02:59 +0900 Subject: [PATCH 01/35] wip --- crates/hir/src/lib.rs | 4 +- crates/hir_ty/src/checker.rs | 63 +- crates/hir_ty/src/inference.rs | 398 +------ crates/hir_ty/src/inference/environment.rs | 265 +++++ crates/hir_ty/src/inference/type_scheme.rs | 54 + crates/hir_ty/src/inference/type_unifier.rs | 89 ++ crates/hir_ty/src/inference/types.rs | 89 ++ crates/hir_ty/src/lib.rs | 1100 +++++++++++-------- 8 files changed, 1248 insertions(+), 814 deletions(-) create mode 100644 crates/hir_ty/src/inference/environment.rs create mode 100644 crates/hir_ty/src/inference/type_scheme.rs create mode 100644 crates/hir_ty/src/inference/type_unifier.rs create mode 100644 crates/hir_ty/src/inference/types.rs diff --git a/crates/hir/src/lib.rs b/crates/hir/src/lib.rs index 8b0d7648..a8a5f067 100644 --- a/crates/hir/src/lib.rs +++ b/crates/hir/src/lib.rs @@ -40,7 +40,7 @@ pub use db::{HirMasterDatabase, Jar}; pub use input::{FixtureDatabase, NailFile, SourceDatabase, SourceDatabaseTrait}; pub use item::{Function, Item, Module, ModuleKind, Param, Type, UseItem}; use name_resolver::resolve_symbols; -pub use name_resolver::{ResolutionMap, ResolutionStatus}; +pub use name_resolver::{ModuleScopeOrigin, ResolutionMap, ResolutionStatus}; pub use testing::TestingDatabase; /// ビルド対象全体を表します。 @@ -90,6 +90,8 @@ impl Pod { } /// ファイルの登録順の昇順でHIR構築結果を返します。 + /// + /// ルートファイルは含まれません。 pub fn get_hir_files_order_registration_asc(&self) -> Vec<(NailFile, &HirFile)> { let mut lower_results = vec![]; for file in &self.registration_order { diff --git a/crates/hir_ty/src/checker.rs b/crates/hir_ty/src/checker.rs index 0924314b..ac7a72b0 100644 --- a/crates/hir_ty/src/checker.rs +++ b/crates/hir_ty/src/checker.rs @@ -1,8 +1,6 @@ use std::collections::HashMap; -use la_arena::Idx; - -use crate::inference::{InferenceBodyResult, InferenceResult, ResolvedType, Signature}; +use crate::inference::{InferenceBodyResult, InferenceResult, Monotype, Signature}; pub fn check_type_pods( db: &dyn hir::HirMasterDatabase, @@ -27,7 +25,7 @@ pub fn check_type_pods( #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum TypeCheckError { /// 型を解決できない - UnresolvedType { + UnMonotype { /// 対象の式 expr: hir::ExprId, }, @@ -36,40 +34,40 @@ pub enum TypeCheckError { /// 期待される型の式 expected_expr: hir::ExprId, /// 期待される型 - expected_ty: ResolvedType, + expected_ty: Monotype, /// 実際の式 found_expr: hir::ExprId, /// 実際の型 - found_ty: ResolvedType, + found_ty: Monotype, }, /// Ifの条件式の型が一致しない MismatchedTypeIfCondition { /// 期待される型 - expected_ty: ResolvedType, + expected_ty: Monotype, /// 実際の式 found_expr: hir::ExprId, /// 実際の型 - found_ty: ResolvedType, + found_ty: Monotype, }, /// Ifのthenブランチとelseブランチの型が一致しない MismatchedTypeElseBranch { /// 期待される型 - expected_ty: ResolvedType, + expected_ty: Monotype, /// 実際の式 found_expr: hir::ExprId, /// 実際の型 - found_ty: ResolvedType, + found_ty: Monotype, }, /// 関数呼び出しの引数の数が一致しない MismaatchedSignature { /// 期待される型 - expected_ty: ResolvedType, + expected_ty: Monotype, /// 呼び出そうとしている関数のシグネチャ signature: Signature, /// 実際の式 found_expr: hir::ExprId, /// 実際の型 - found_ty: ResolvedType, + found_ty: Monotype, }, /// 関数の戻り値の型と実際の戻り値の型が異なる /// @@ -78,11 +76,11 @@ pub enum TypeCheckError { /// - 関数ボディの最後の式の型 MismatchedReturnType { /// 期待される型 - expected_ty: ResolvedType, + expected_ty: Monotype, /// 実際の式 found_expr: Option, /// 実際の型 - found_ty: ResolvedType, + found_ty: Monotype, }, } @@ -123,8 +121,8 @@ impl<'a> FunctionTypeChecker<'a> { } } - fn signature_by_function(&self, function: hir::Function) -> Idx { - self.infer_result.signature_by_function[&function] + fn signature_by_function(&self, function: hir::Function) -> &Signature { + &self.infer_result.signature_by_function[&function] } fn check(mut self) -> Vec { @@ -144,11 +142,10 @@ impl<'a> FunctionTypeChecker<'a> { self.check_expr(tail); let signature = self.signature_by_function(self.function); - let signature = &self.infer_result.signatures[signature]; - let tail_ty = self.current_inference().type_by_expr[&tail]; + let tail_ty = self.current_inference().type_by_expr[&tail].clone(); if tail_ty != signature.return_type { self.errors.push(TypeCheckError::MismatchedReturnType { - expected_ty: signature.return_type, + expected_ty: signature.return_type.clone(), found_expr: Some(tail), found_ty: tail_ty, }); @@ -178,7 +175,7 @@ impl<'a> FunctionTypeChecker<'a> { let lhs_ty = self.type_by_expr(*lhs); let rhs_ty = self.type_by_expr(*rhs); match (lhs_ty, rhs_ty) { - (ResolvedType::Unknown, ResolvedType::Unknown) => (), + (Monotype::Unknown, Monotype::Unknown) => (), (lhs_ty, rhs_ty) => { if lhs_ty != rhs_ty { self.errors.push(TypeCheckError::MismatchedTypes { @@ -214,7 +211,7 @@ impl<'a> FunctionTypeChecker<'a> { } hir::ResolutionStatus::Resolved { path: _, item } => match item { hir::Item::Function(function) => { - self.signature_by_function(function) + self.signature_by_function(function).clone() } hir::Item::Module(_) | hir::Item::UseItem(_) => unimplemented!(), }, @@ -222,13 +219,12 @@ impl<'a> FunctionTypeChecker<'a> { } }; - let signature = &self.infer_result.signatures[signature]; for (i, param_ty) in signature.params.iter().enumerate() { let arg = args[i]; let arg_ty = self.type_by_expr(arg); - if *param_ty != arg_ty { + if param_ty != &arg_ty { self.errors.push(TypeCheckError::MismaatchedSignature { - expected_ty: *param_ty, + expected_ty: param_ty.clone(), signature: signature.clone(), found_expr: arg, found_ty: arg_ty, @@ -242,7 +238,7 @@ impl<'a> FunctionTypeChecker<'a> { else_branch, } => { let condition_ty = self.type_by_expr(*condition); - let expected_condition_ty = ResolvedType::Bool; + let expected_condition_ty = Monotype::Bool; if condition_ty != expected_condition_ty { self.errors.push(TypeCheckError::MismatchedTypeIfCondition { expected_ty: expected_condition_ty, @@ -263,7 +259,7 @@ impl<'a> FunctionTypeChecker<'a> { }); } } else { - let else_branch_ty = ResolvedType::Unit; + let else_branch_ty = Monotype::Unit; if then_branch_ty != else_branch_ty { self.errors.push(TypeCheckError::MismatchedTypeElseBranch { expected_ty: else_branch_ty, @@ -277,14 +273,13 @@ impl<'a> FunctionTypeChecker<'a> { let return_value_ty = if let Some(value) = value { self.type_by_expr(*value) } else { - ResolvedType::Unit + Monotype::Unit }; let signature = self.signature_by_function(self.function); - let signature = &self.infer_result.signatures[signature]; if return_value_ty != signature.return_type { self.errors.push(TypeCheckError::MismatchedReturnType { - expected_ty: signature.return_type, + expected_ty: signature.return_type.clone(), found_expr: *value, found_ty: return_value_ty, }); @@ -294,10 +289,10 @@ impl<'a> FunctionTypeChecker<'a> { } } - fn type_by_expr(&mut self, expr: hir::ExprId) -> ResolvedType { - let ty = self.current_inference().type_by_expr[&expr]; - if ty == ResolvedType::Unknown { - self.errors.push(TypeCheckError::UnresolvedType { expr }); + fn type_by_expr(&mut self, expr: hir::ExprId) -> Monotype { + let ty = self.current_inference().type_by_expr[&expr].clone(); + if ty == Monotype::Unknown { + self.errors.push(TypeCheckError::UnMonotype { expr }); } ty @@ -306,7 +301,7 @@ impl<'a> FunctionTypeChecker<'a> { /// 現在の関数の推論結果を取得する fn current_inference(&self) -> &InferenceBodyResult { self.infer_result - .inference_by_body + .inference_body_result_by_function .get(&self.function) .unwrap() } diff --git a/crates/hir_ty/src/inference.rs b/crates/hir_ty/src/inference.rs index 6121d083..68862a07 100644 --- a/crates/hir_ty/src/inference.rs +++ b/crates/hir_ty/src/inference.rs @@ -1,370 +1,66 @@ +mod environment; +mod type_scheme; +mod type_unifier; +mod types; + use std::collections::HashMap; -use la_arena::{Arena, Idx}; +use environment::{Environment, InferBody}; +pub use environment::{InferenceBodyResult, InferenceError, InferenceResult, Signature}; +pub use type_scheme::TypeScheme; +pub use types::Monotype; pub fn infer_pods(db: &dyn hir::HirMasterDatabase, pods: &hir::Pods) -> InferenceResult { - // TODO: 全てのPodを走査する - let pod = &pods.root_pod; - - // 依存関係を気にしなくていいようにシグネチャを先に解決しておく - let mut signatures = Arena::::new(); - let mut signature_by_function = HashMap::>::new(); - - let functions = pod.all_functions(db); - for (hir_file, function) in functions.clone() { - let mut params = vec![]; - for param in function.params(db) { - let ty = infer_ty(¶m.data(hir_file.db(db)).ty); - params.push(ty); - } - - let signature = Signature { - params, - return_type: infer_ty(&function.return_type(db)), - }; - let signature_idx = signatures.alloc(signature); - signature_by_function.insert(function, signature_idx); + let mut signature_by_function = HashMap::::new(); + for (hir_file, function) in pods.root_pod.all_functions(db) { + let signature = lower_signature(db, hir_file, function); + signature_by_function.insert(function, signature); } - // 各関数内を型推論する - let mut inference_by_body = HashMap::::new(); - for (hir_file, function) in functions { - let type_inferencer = TypeInferencer::new(db, hir_file, &pods.resolution_map, function); - let inference_result = type_inferencer.infer(); - inference_by_body.insert(function, inference_result); + let mut body_result_by_function = HashMap::::new(); + for (hir_file, function) in pods.root_pod.all_functions(db) { + let env = Environment::new(); + let infer_body = InferBody::new(db, hir_file, function, &signature_by_function, env); + let infer_body_result = infer_body.infer_body(); + + body_result_by_function.insert(function, infer_body_result); } InferenceResult { - inference_by_body, - signatures, signature_by_function, + inference_body_result_by_function: body_result_by_function, } } -fn infer_ty(ty: &hir::Type) -> ResolvedType { - match ty { - hir::Type::Unknown => ResolvedType::Unknown, - hir::Type::Integer => ResolvedType::Integer, - hir::Type::String => ResolvedType::String, - hir::Type::Char => ResolvedType::Char, - hir::Type::Boolean => ResolvedType::Bool, - hir::Type::Unit => ResolvedType::Unit, - } -} - -/// 型推論の結果 -/// -/// 関数の引数と戻り値は必ず確定しているため、それより広い範囲で型推論を行う必要はありません。 -/// そのため、関数単位に結果を持ちます。 -#[derive(Debug)] -pub struct InferenceBodyResult { - /// 関数内の式に対応する型 - pub type_by_expr: HashMap, - /// 型推論中に発生したエラー - pub errors: Vec, -} - -/// 型推論の結果 -#[derive(Debug)] -pub struct InferenceResult { - /// 関数に対応する型推論結果 - pub inference_by_body: HashMap, - /// 関数シグネチャ一覧 - pub signatures: Arena, - /// 関数に対応するシグネチャ - pub signature_by_function: HashMap>, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -struct InferenceContext { - type_by_expr: HashMap, - type_by_param: HashMap, - signatures: Arena, - signature_by_function: HashMap>, - errors: Vec, -} - -impl InferenceContext { - fn new() -> Self { - Self { - type_by_expr: HashMap::new(), - type_by_param: HashMap::new(), - signatures: Arena::new(), - signature_by_function: HashMap::new(), - errors: Vec::new(), - } - } -} - -/// 型推論中に発生したエラー -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum InferenceError {} - -/// 解決後の型 -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum ResolvedType { - /// 未解決の型 - Unknown, - /// 数値型 - Integer, - /// 文字列型 - String, - /// 文字型 - Char, - /// 真偽値型 - Bool, - /// 単一値型 - Unit, - /// 値を取り得ないことを表す型 - /// - /// 例えば、必ず`panic`を起こす関数の型は`Never`です。 - Never, - /// 関数型 - #[allow(dead_code)] - Function(Idx), -} - -/// 関数シグネチャ -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct Signature { - /// パラメータの型一覧 - /// - /// パラメータの順番に対応しています。 - pub params: Vec, - /// 戻り値の型 - pub return_type: ResolvedType, -} - -/// 型推論器 -struct TypeInferencer<'a> { - db: &'a dyn hir::HirMasterDatabase, +fn lower_signature( + db: &dyn hir::HirMasterDatabase, hir_file: hir::HirFile, - resolution_map: &'a hir::ResolutionMap, - ctx: InferenceContext, function: hir::Function, -} -impl<'a> TypeInferencer<'a> { - fn new( - db: &'a dyn hir::HirMasterDatabase, - hir_file: hir::HirFile, - symbol_table: &'a hir::ResolutionMap, - function: hir::Function, - ) -> Self { - Self { - db, - hir_file, - resolution_map: symbol_table, - ctx: InferenceContext::new(), - function, - } - } - - fn infer(mut self) -> InferenceBodyResult { - let body_ast_id = self.function.ast(self.db).body().unwrap(); - let body = self - .hir_file - .db(self.db) - .function_body_by_ast_block(body_ast_id) - .unwrap(); - match body { - hir::Expr::Block(block) => self.infer_block(block), - _ => unreachable!(), - }; - - InferenceBodyResult { - type_by_expr: self.ctx.type_by_expr, - errors: self.ctx.errors, - } - } - - fn infer_stmts(&mut self, stmts: &[hir::Stmt]) { - for stmt in stmts { - match stmt { - hir::Stmt::ExprStmt { expr, .. } => { - let ty = self.infer_expr_id(*expr); - self.ctx.type_by_expr.insert(*expr, ty); - } - hir::Stmt::VariableDef { value, .. } => { - let ty = self.infer_expr_id(*value); - self.ctx.type_by_expr.insert(*value, ty); - } - hir::Stmt::Item { .. } => (), - } - } - } - - fn infer_expr(&mut self, expr: &hir::Expr) -> ResolvedType { - match expr { - hir::Expr::Symbol(symbol) => match symbol { - hir::Symbol::Local { expr, .. } => self.infer_expr_id(*expr), - hir::Symbol::Param { param, .. } => { - infer_ty(¶m.data(self.hir_file.db(self.db)).ty) - } - // TODO: supports function, name resolution - hir::Symbol::Missing { path: _ } => ResolvedType::Unknown, - }, - hir::Expr::Literal(literal) => match literal { - hir::Literal::Integer(_) => ResolvedType::Integer, - hir::Literal::String(_) => ResolvedType::String, - hir::Literal::Char(_) => ResolvedType::Char, - hir::Literal::Bool(_) => ResolvedType::Bool, - }, - hir::Expr::Binary { op, lhs, rhs } => { - // TODO: supports string equal - let lhs_ty = self.infer_expr_id(*lhs); - let rhs_ty = self.infer_expr_id(*rhs); - - match op { - ast::BinaryOp::Add(_) - | ast::BinaryOp::Sub(_) - | ast::BinaryOp::Mul(_) - | ast::BinaryOp::Div(_) => { - if rhs_ty == lhs_ty && (matches!(rhs_ty, ResolvedType::Integer)) { - return rhs_ty; - } - } - ast::BinaryOp::Equal(_) => { - if rhs_ty == lhs_ty - && (matches!(rhs_ty, ResolvedType::Integer | ResolvedType::Bool)) - { - return ResolvedType::Bool; - } - } - ast::BinaryOp::GreaterThan(_) | ast::BinaryOp::LessThan(_) => { - if rhs_ty == lhs_ty && (matches!(rhs_ty, ResolvedType::Integer)) { - return ResolvedType::Bool; - } - } - } - - match (lhs_ty, rhs_ty) { - (ty, ResolvedType::Unknown) => { - self.ctx.type_by_expr.insert(*rhs, ty); - ty - } - (ResolvedType::Unknown, ty) => { - self.ctx.type_by_expr.insert(*lhs, ty); - ty - } - (_, _) => ResolvedType::Unknown, - } - } - hir::Expr::Unary { op, expr } => { - let expr_ty = self.infer_expr_id(*expr); - match op { - ast::UnaryOp::Neg(_) => { - if expr_ty == ResolvedType::Integer { - return ResolvedType::Integer; - } - } - ast::UnaryOp::Not(_) => { - if expr_ty == ResolvedType::Bool { - return ResolvedType::Bool; - } - } - } - - ResolvedType::Unknown - } - hir::Expr::Call { callee, args } => match callee { - hir::Symbol::Missing { path } => { - let Some(resolution_status) = self.resolution_map.item_by_symbol(path) else { return ResolvedType::Unknown; }; - - let resolved_item = match resolution_status { - hir::ResolutionStatus::Resolved { path: _, item } => item, - hir::ResolutionStatus::Unresolved | hir::ResolutionStatus::Error => { - return ResolvedType::Unknown; - } - }; - - match resolved_item { - hir::Item::Function(function) => { - for (i, arg) in args.iter().enumerate() { - let param = function.params(self.db)[i]; - - let arg_ty = self.infer_expr_id(*arg); - let param_ty = infer_ty(¶m.data(self.hir_file.db(self.db)).ty); - - if arg_ty == param_ty { - continue; - } - - match (arg_ty, param_ty) { - (ResolvedType::Unknown, ResolvedType::Unknown) => (), - (ResolvedType::Unknown, ty) => { - self.ctx.type_by_expr.insert(*arg, ty); - } - (_, _) => (), - } - } - - infer_ty(&function.return_type(self.db)) - } - hir::Item::Module(_) => unimplemented!(), - hir::Item::UseItem(_) => unimplemented!(), - } - } - hir::Symbol::Local { .. } | hir::Symbol::Param { .. } => unimplemented!(), - }, - hir::Expr::Block(block) => self.infer_block(block), - hir::Expr::If { - condition, - then_branch, - else_branch, - } => { - self.infer_expr_id(*condition); - let then_branch_ty = self.infer_expr_id(*then_branch); - let else_branch_ty = if let Some(else_branch) = else_branch { - self.infer_expr_id(*else_branch) - } else { - ResolvedType::Unit - }; - - match (then_branch_ty, else_branch_ty) { - (ResolvedType::Unknown, ResolvedType::Unknown) => ResolvedType::Unknown, - (ResolvedType::Unknown, ty) => { - self.ctx.type_by_expr.insert(*then_branch, ty); - ty - } - (ty, ResolvedType::Unknown) => { - self.ctx.type_by_expr.insert(else_branch.unwrap(), ty); - ty - } - (ty_a, ty_b) if ty_a == ty_b => ty_a, - (_, _) => ResolvedType::Unknown, - } - } - hir::Expr::Return { value } => { - if let Some(value) = value { - self.infer_expr_id(*value); - } - ResolvedType::Never - } - hir::Expr::Missing => ResolvedType::Unknown, - } - } - - fn infer_block(&mut self, block: &hir::Block) -> ResolvedType { - self.infer_stmts(&block.stmts); - if let Some(tail) = block.tail { - self.infer_expr_id(tail) - } else { - ResolvedType::Unit - } - } - - fn infer_expr_id(&mut self, expr_id: hir::ExprId) -> ResolvedType { - if let Some(ty) = self.lookup_type(expr_id) { - return ty; - } - - let ty = self.infer_expr(expr_id.lookup(self.hir_file.db(self.db))); - self.ctx.type_by_expr.insert(expr_id, ty); - - ty +) -> Signature { + let params = function + .params(db) + .iter() + .map(|param| { + let param_data = param.data(hir_file.db(db)); + lower_type(¶m_data.ty) + }) + .collect::>(); + + let return_type = lower_type(&function.return_type(db)); + + Signature { + params, + return_type, } +} - fn lookup_type(&self, expr_id: hir::ExprId) -> Option { - self.ctx.type_by_expr.get(&expr_id).copied() +fn lower_type(ty: &hir::Type) -> Monotype { + match ty { + hir::Type::Integer => Monotype::Integer, + hir::Type::String => Monotype::String, + hir::Type::Char => Monotype::Char, + hir::Type::Boolean => Monotype::Bool, + hir::Type::Unit => Monotype::Unit, + hir::Type::Unknown => Monotype::Unknown, } } diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs new file mode 100644 index 00000000..98691b2b --- /dev/null +++ b/crates/hir_ty/src/inference/environment.rs @@ -0,0 +1,265 @@ +use std::{ + collections::{HashMap, HashSet}, + ops::Sub, +}; + +use super::{type_scheme::TypeScheme, type_unifier::TypeUnifier, types::Monotype}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Signature { + pub params: Vec, + pub return_type: Monotype, +} + +pub(crate) struct InferBody<'a> { + db: &'a dyn hir::HirMasterDatabase, + hir_file: hir::HirFile, + function: hir::Function, + signature: &'a Signature, + + unifier: TypeUnifier, + cxt: Context, + + env_stack: Vec, + signature_by_function: &'a HashMap, + type_by_expr: HashMap, +} +impl<'a> InferBody<'a> { + pub(crate) fn new( + db: &'a dyn hir::HirMasterDatabase, + hir_file: hir::HirFile, + function: hir::Function, + signature_by_function: &'a HashMap, + env: Environment, + ) -> Self { + InferBody { + db, + hir_file, + function, + signature: signature_by_function.get(&function).unwrap(), + + unifier: TypeUnifier::new(), + cxt: Context::default(), + env_stack: vec![env], + signature_by_function, + type_by_expr: HashMap::new(), + } + } + + pub(crate) fn infer_body(mut self) -> InferenceBodyResult { + let hir::Expr::Block(body) = self.hir_file.function_body_by_function(self.db, self.function).unwrap() else { panic!("Should be Block.") }; + for stmt in &body.stmts { + self.infer_stmt(stmt); + } + + let ty = if let Some(tail) = &body.tail { + self.infer_expr(*tail) + } else { + Monotype::Unit + }; + dbg!(&self.signature.return_type); + self.unifier.unify(&ty, &self.signature.return_type); + + InferenceBodyResult { + type_by_expr: self.type_by_expr, + errors: self.unifier.errors, + } + } + + fn infer_type(&mut self, ty: &hir::Type) -> Monotype { + match ty { + hir::Type::Integer => Monotype::Integer, + hir::Type::String => Monotype::String, + hir::Type::Char => Monotype::Char, + hir::Type::Boolean => Monotype::Bool, + hir::Type::Unit => Monotype::Unit, + hir::Type::Unknown => Monotype::Unknown, + } + } + + fn infer_stmt(&mut self, stmt: &hir::Stmt) { + match stmt { + hir::Stmt::VariableDef { name, value } => { + let ty = self.infer_expr(*value); + let ty_scheme = TypeScheme::new(ty); + self.mut_current_scope().bindings.insert(*name, ty_scheme); + } + hir::Stmt::ExprStmt { + expr, + has_semicolon: _, + } => { + self.infer_expr(*expr); + } + hir::Stmt::Item { .. } => (), + } + } + + fn infer_expr(&mut self, expr_id: hir::ExprId) -> Monotype { + let expr = expr_id.lookup(self.hir_file.db(self.db)); + let ty = match expr { + hir::Expr::Literal(literal) => match literal { + hir::Literal::Integer(_) => Monotype::Integer, + hir::Literal::String(_) => Monotype::String, + hir::Literal::Char(_) => Monotype::Char, + hir::Literal::Bool(_) => Monotype::Bool, + }, + hir::Expr::Missing => Monotype::Unknown, + hir::Expr::Symbol(symbol) => match symbol { + hir::Symbol::Param { name: _, param } => { + let param = param.data(self.hir_file.db(self.db)); + self.infer_type(¶m.ty) + } + hir::Symbol::Local { name, expr: _ } => { + let ty_scheme = self.current_scope().bindings.get(name).cloned(); + if let Some(ty_scheme) = ty_scheme { + ty_scheme.instantiate(&mut self.cxt) + } else { + panic!("Unbound variable {symbol:?}"); + } + } + hir::Symbol::Missing { path } => unimplemented!(), + }, + hir::Expr::Call { callee, args } => todo!(), + hir::Expr::Binary { op, lhs, rhs } => match op { + ast::BinaryOp::Add(_) + | ast::BinaryOp::Sub(_) + | ast::BinaryOp::Mul(_) + | ast::BinaryOp::Div(_) => { + let lhs_ty = self.infer_expr(*lhs); + let rhs_ty = self.infer_expr(*rhs); + self.unifier.unify(&Monotype::Integer, &lhs_ty); + self.unifier.unify(&Monotype::Integer, &rhs_ty); + + Monotype::Integer + } + ast::BinaryOp::Equal(_) + | ast::BinaryOp::GreaterThan(_) + | ast::BinaryOp::LessThan(_) => { + let lhs_ty = self.infer_expr(*lhs); + let rhs_ty = self.infer_expr(*rhs); + self.unifier.unify(&lhs_ty, &rhs_ty); + + Monotype::Bool + } + }, + hir::Expr::Unary { op, expr } => match op { + ast::UnaryOp::Neg(_) => { + let expr_ty = self.infer_expr(*expr); + self.unifier.unify(&Monotype::Integer, &expr_ty); + + Monotype::Integer + } + ast::UnaryOp::Not(_) => { + let expr_ty = self.infer_expr(*expr); + self.unifier.unify(&Monotype::Bool, &expr_ty); + + Monotype::Bool + } + }, + hir::Expr::Block(block) => { + self.entry_scope(); + + for stmt in &block.stmts { + self.infer_stmt(stmt); + } + + let ty = if let Some(tail) = &block.tail { + self.infer_expr(*tail) + } else { + Monotype::Unit + }; + + self.exit_scope(); + + ty + } + hir::Expr::If { + condition, + then_branch, + else_branch, + } => todo!(), + hir::Expr::Return { value } => todo!(), + }; + + self.type_by_expr.insert(expr_id, ty.clone()); + + ty + } + + fn entry_scope(&mut self) { + let env = self.env_stack.last().unwrap().with(); + self.env_stack.push(env); + } + + fn exit_scope(&mut self) { + self.env_stack.pop(); + } + + fn mut_current_scope(&mut self) -> &mut Environment { + self.env_stack.last_mut().unwrap() + } + + fn current_scope(&self) -> &Environment { + self.env_stack.last().unwrap() + } +} + +#[derive(Debug)] +pub struct InferenceResult { + pub signature_by_function: HashMap, + pub inference_body_result_by_function: HashMap, +} + +#[derive(Debug)] +pub struct InferenceBodyResult { + pub type_by_expr: HashMap, + pub errors: Vec, +} +#[derive(Debug)] +pub enum InferenceError { + TypeMismatch { + expected: Monotype, + actual: Monotype, + }, +} + +#[derive(Default)] +pub struct Environment { + bindings: HashMap, +} + +#[derive(Default)] +pub struct Context { + pub gen_counter: u32, +} + +impl Environment { + pub fn new() -> Self { + Environment { + bindings: HashMap::new(), + } + } + + fn free_variables(&self) -> HashSet { + let mut union = HashSet::::new(); + for type_scheme in self.bindings.values() { + union.extend(type_scheme.free_variables()); + } + + union + } + + fn with(&self) -> Environment { + let mut copy = HashMap::::new(); + copy.extend(self.bindings.clone()); + + Environment { bindings: copy } + } + + fn generalize(&self, ty: &Monotype) -> TypeScheme { + TypeScheme { + variables: ty.free_variables().sub(&self.free_variables()), + ty: ty.clone(), + } + } +} diff --git a/crates/hir_ty/src/inference/type_scheme.rs b/crates/hir_ty/src/inference/type_scheme.rs new file mode 100644 index 00000000..65c9ae6f --- /dev/null +++ b/crates/hir_ty/src/inference/type_scheme.rs @@ -0,0 +1,54 @@ +use std::{ + collections::{HashMap, HashSet}, + iter::FromIterator, +}; + +use super::{environment::Context, types::Monotype}; + +#[derive(Clone)] +pub struct TypeScheme { + pub variables: HashSet, + pub ty: Monotype, +} + +impl TypeScheme { + pub fn new(ty: Monotype) -> TypeScheme { + TypeScheme { + variables: HashSet::new(), + ty, + } + } + + pub fn free_variables(&self) -> HashSet { + self.ty + .free_variables() + .into_iter() + .filter(|var| !self.variables.contains(var)) + .collect() + } + + /// 具体的な型を生成する + pub fn instantiate(&self, cxt: &mut Context) -> Monotype { + let new_vars = self + .variables + .iter() + .map(|v| (*v, Monotype::gen_variable(cxt))); + + let replacement = TypeSubstitution { + replacements: HashMap::from_iter(new_vars), + }; + + self.ty.apply(&replacement) + } +} + +#[derive(Default)] +pub struct TypeSubstitution { + pub replacements: HashMap, +} + +impl TypeSubstitution { + pub fn lookup(&self, id: u32) -> Option { + self.replacements.get(&id).cloned() + } +} diff --git a/crates/hir_ty/src/inference/type_unifier.rs b/crates/hir_ty/src/inference/type_unifier.rs new file mode 100644 index 00000000..e2f2672d --- /dev/null +++ b/crates/hir_ty/src/inference/type_unifier.rs @@ -0,0 +1,89 @@ +use std::collections::HashMap; + +use super::{environment::InferenceError, types::Monotype}; + +#[derive(Default, Debug)] +pub struct TypeUnifier { + pub(crate) nodes: HashMap, + pub(crate) errors: Vec, +} + +impl TypeUnifier { + pub fn new() -> Self { + Default::default() + } + + pub fn find(&mut self, ty: &Monotype) -> Monotype { + let node = self.nodes.get(ty); + if let Some(node) = node { + node.topmost_parent().value + } else { + self.nodes.insert(ty.clone(), Node::new(ty.clone())); + ty.clone() + } + } + + pub fn unify(&mut self, a: &Monotype, b: &Monotype) { + let a_rep = self.find(a); + let b_rep = self.find(b); + + if a_rep == b_rep { + return; + } + + match (&a_rep, &b_rep) { + ( + Monotype::Function { + from: a_from, + to: a_to, + }, + Monotype::Function { + from: b_from, + to: b_to, + }, + ) => { + self.unify(a_from, b_from); + self.unify(a_to, b_to); + } + (Monotype::Variable(_), b_rep) => self.unify_var(&a_rep, b_rep), + (a_rep, Monotype::Variable(_)) => self.unify_var(&b_rep, a_rep), + (_, _) => { + self.errors.push(InferenceError::TypeMismatch { + expected: a_rep, + actual: b_rep, + }); + } + } + } + + fn unify_var(&mut self, type_var: &Monotype, term: &Monotype) { + assert!(matches!(type_var, Monotype::Variable(_))); + + let value = Some(Box::new(self.nodes.get(term).unwrap().clone())); + let node = self.nodes.get_mut(type_var); + node.unwrap().parent = value; + } +} + +#[derive(Debug, Clone)] +pub struct Node { + value: Monotype, + parent: Option>, +} + +impl Node { + fn new(ty: Monotype) -> Self { + Node { + value: ty, + parent: None, + } + } + + fn topmost_parent(&self) -> Self { + if let Some(node) = &self.parent { + node.topmost_parent() + } else { + self.clone() + } + } +} diff --git a/crates/hir_ty/src/inference/types.rs b/crates/hir_ty/src/inference/types.rs new file mode 100644 index 00000000..9929fea4 --- /dev/null +++ b/crates/hir_ty/src/inference/types.rs @@ -0,0 +1,89 @@ +use std::{collections::HashSet, fmt}; + +use super::{environment::Context, type_scheme::TypeSubstitution}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum Monotype { + Integer, + Bool, + Unit, + Char, + String, + Variable(u32), + Function { + from: Box, + to: Box, + }, + Never, + Unknown, +} + +impl fmt::Display for Monotype { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self { + Monotype::Integer => write!(f, "integer"), + Monotype::Bool => write!(f, "bool"), + Monotype::Char => write!(f, "char"), + Monotype::String => write!(f, "string"), + Monotype::Unit => write!(f, "()"), + Monotype::Never => write!(f, "!"), + Monotype::Unknown => write!(f, "unknown"), + Monotype::Variable(id) => write!(f, "{}", id), + Monotype::Function { from, to } => { + if let Monotype::Function { from, .. } = from.as_ref() { + write!(f, "({}) -> {}", from.to_string(), to.to_string()) + } else { + write!(f, "{} -> {}", from.to_string(), to.to_string()) + } + } + } + } +} + +impl Monotype { + pub fn gen_variable(cxt: &mut Context) -> Self { + let monotype = Self::Variable(cxt.gen_counter); + cxt.gen_counter += 1; + monotype + } + + pub fn free_variables(&self) -> HashSet { + match self { + Monotype::Variable(id) => { + let mut set = HashSet::new(); + set.insert(*id); + + set + } + Monotype::Function { from, to } => from + .free_variables() + .union(&to.free_variables()) + .cloned() + .collect(), + _ => Default::default(), + } + } + + pub fn apply(&self, subst: &TypeSubstitution) -> Monotype { + match self { + Monotype::Integer + | Monotype::Bool + | Monotype::Unit + | Monotype::Char + | Monotype::String + | Monotype::Never + | Monotype::Unknown => self.clone(), + Monotype::Variable(id) => { + if let Some(ty) = subst.lookup(*id) { + ty + } else { + self.clone() + } + } + Monotype::Function { from, to } => Monotype::Function { + from: Box::new(from.apply(subst)), + to: Box::new(to.apply(subst)), + }, + } + } +} diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index cf94ad4f..291f4f38 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -19,9 +19,8 @@ mod checker; mod inference; pub use checker::{TypeCheckError, TypeCheckResult}; -pub use inference::{ - InferenceBodyResult, InferenceError, InferenceResult, ResolvedType, Signature, -}; +use inference::Signature; +pub use inference::{InferenceBodyResult, InferenceResult}; /// HIRを元にTypedHIRを構築します。 pub fn lower_pods(db: &dyn hir::HirMasterDatabase, pods: &hir::Pods) -> TyLowerResult { @@ -45,26 +44,24 @@ pub struct TyLowerResult { impl TyLowerResult { /// 指定した関数の型を取得します。 pub fn signature_by_function(&self, function_id: hir::Function) -> &Signature { - let signature_idx = self.inference_result.signature_by_function[&function_id]; - &self.inference_result.signatures[signature_idx] + &self.inference_result.signature_by_function[&function_id] } } #[cfg(test)] mod tests { use expect_test::{expect, Expect}; - use hir::{Name, Path, Symbol, TestingDatabase}; - use super::*; + use crate::{inference::infer_pods, InferenceResult}; fn check_pod_start_with_root_file(fixture: &str, expect: Expect) { - let db = TestingDatabase::default(); + let db = hir::TestingDatabase::default(); let mut source_db = hir::FixtureDatabase::new(&db, fixture); let pods = hir::parse_pods(&db, "/main.nail", &mut source_db); - let ty_result = lower_pods(&db, &pods); + let inference_result = infer_pods(&db, &pods); - expect.assert_eq(&debug_file(&db, &ty_result, &pods.root_pod.root_hir_file)); + expect.assert_eq(&TestingDebug::new(&db, &pods, &inference_result).debug()); } fn check_in_root_file(fixture: &str, expect: Expect) { @@ -74,260 +71,446 @@ mod tests { check_pod_start_with_root_file(&fixture, expect); } - fn debug_file( - db: &dyn hir::HirMasterDatabase, - ty_lower_result: &TyLowerResult, - hir_file: &hir::HirFile, - ) -> String { - let mut msg = "".to_string(); - - let TyLowerResult { - type_check_result, - inference_result, - } = ty_lower_result; - - for (_, signature) in inference_result.signatures.iter() { - let params = signature - .params + fn indent(nesting: usize) -> String { + " ".repeat(nesting) + } + + struct TestingDebug<'a> { + db: &'a dyn hir::HirMasterDatabase, + pods: &'a hir::Pods, + inference_result: &'a InferenceResult, + } + impl<'a> TestingDebug<'a> { + fn new( + db: &'a dyn hir::HirMasterDatabase, + pods: &'a hir::Pods, + inference_result: &'a InferenceResult, + ) -> Self { + TestingDebug { + db, + pods, + inference_result, + } + } + + fn debug(&self) -> String { + let mut msg = "".to_string(); + + msg.push_str(&self.debug_hir_file(self.pods.root_pod.root_hir_file)); + + for (_nail_file, hir_file) in self.pods.root_pod.get_hir_files_order_registration_asc() + { + msg.push_str(&self.debug_hir_file(*hir_file)); + msg.push('\n'); + } + + msg.push_str("---\n"); + for (_hir_file, function) in self.pods.root_pod.all_functions(self.db) { + let inference_body_result = self + .inference_result + .inference_body_result_by_function + .get(&function) + .unwrap(); + + for error in &inference_body_result.errors { + match error { + crate::inference::InferenceError::TypeMismatch { expected, actual } => { + msg.push_str(&format!( + "error: expected {expected}, actual: {actual}\n" + )); + } + } + } + } + + msg + } + + fn debug_hir_file(&self, hir_file: hir::HirFile) -> String { + let mut msg = format!( + "//- {}\n", + hir_file.file(self.db).file_path(self.db).to_str().unwrap() + ); + + for item in hir_file.top_level_items(self.db) { + msg.push_str(&self.debug_item(hir_file, *item, 0)); + } + + msg + } + + fn debug_function( + &self, + hir_file: hir::HirFile, + function: hir::Function, + nesting: usize, + ) -> String { + let body_expr = hir_file + .db(self.db) + .function_body_by_ast_block(function.ast(self.db).body().unwrap()) + .unwrap(); + + let name = function.name(self.db).text(self.db); + let params = function + .params(self.db) .iter() - .map(debug_type) - .collect::>() + .map(|param| { + let name = if let Some(name) = param.data(hir_file.db(self.db)).name { + name.text(self.db) + } else { + "" + }; + let ty = match param.data(hir_file.db(self.db)).ty { + hir::Type::Integer => "int", + hir::Type::String => "string", + hir::Type::Char => "char", + hir::Type::Boolean => "bool", + hir::Type::Unit => "()", + hir::Type::Unknown => "", + }; + format!("{name}: {ty}") + }) + .collect::>() .join(", "); - msg.push_str(&format!( - "fn({params}) -> {}\n", - debug_type(&signature.return_type) - )); + let return_type = match &function.return_type(self.db) { + hir::Type::Integer => "int", + hir::Type::String => "string", + hir::Type::Char => "char", + hir::Type::Boolean => "bool", + hir::Type::Unit => "()", + hir::Type::Unknown => "", + }; + + let scope_origin = hir::ModuleScopeOrigin::Function { origin: function }; + + let hir::Expr::Block(block) = body_expr else { panic!("Should be Block") }; + + let mut body = "{\n".to_string(); + for stmt in &block.stmts { + body.push_str(&self.debug_stmt( + hir_file, + function, + scope_origin, + stmt, + nesting + 1, + )); + } + if let Some(tail) = block.tail { + let indent = indent(nesting + 1); + body.push_str(&format!( + "{indent}expr:{}\n", + self.debug_expr(hir_file, function, scope_origin, tail, nesting + 1) + )); + body.push_str(&format!( + "{indent}{}\n", + self.debug_type_line(hir_file, function, scope_origin, tail) + )); + } + body.push_str(&format!("{}}}", indent(nesting))); + + let is_entry_point = hir_file.entry_point(self.db) == Some(function); + format!( + "{}fn {}{name}({params}) -> {return_type} {body}\n", + indent(nesting), + if is_entry_point { "entry:" } else { "" } + ) } - msg.push_str("---\n"); + fn debug_module( + &self, + hir_file: hir::HirFile, + module: hir::Module, + nesting: usize, + ) -> String { + let _scope_origin = hir::ModuleScopeOrigin::Module { origin: module }; + + let curr_indent = indent(nesting); + + let module_name = module.name(self.db).text(self.db); + + match module.kind(self.db) { + hir::ModuleKind::Inline { items } => { + let mut module_str = "".to_string(); + module_str.push_str(&format!("{curr_indent}mod {module_name} {{\n")); + for (i, item) in items.iter().enumerate() { + module_str.push_str(&self.debug_item(hir_file, *item, nesting + 1)); + if i == items.len() - 1 { + continue; + } + + module_str.push('\n'); + } + module_str.push_str(&format!("{curr_indent}}}\n")); + module_str + } + hir::ModuleKind::Outline => { + format!("{curr_indent}mod {module_name};\n") + } + } + } - for function in hir_file.functions(db) { - let inference_body_result = inference_result.inference_by_body.get(function).unwrap(); + fn debug_use_item(&self, use_item: hir::UseItem) -> String { + let path_name = self.debug_path(use_item.path(self.db)); + let item_name = use_item.name(self.db).text(self.db); - let mut indexes = inference_body_result - .type_by_expr - .keys() - .collect::>(); - indexes.sort(); - for expr_id in indexes { - let expr = debug_hir_expr(db, expr_id, hir_file); - msg.push_str(&format!( - "`{}`: {}\n", + format!("{path_name}::{item_name}") + } + + fn debug_item(&self, hir_file: hir::HirFile, item: hir::Item, nesting: usize) -> String { + match item { + hir::Item::Function(function) => self.debug_function(hir_file, function, nesting), + hir::Item::Module(module) => self.debug_module(hir_file, module, nesting), + hir::Item::UseItem(use_item) => self.debug_use_item(use_item), + } + } + + fn debug_stmt( + &self, + hir_file: hir::HirFile, + function: hir::Function, + scope_origin: hir::ModuleScopeOrigin, + stmt: &hir::Stmt, + nesting: usize, + ) -> String { + match stmt { + hir::Stmt::VariableDef { name, value } => { + let indent = indent(nesting); + let name = name.text(self.db); + let expr_str = + self.debug_expr(hir_file, function, scope_origin, *value, nesting); + let mut stmt_str = format!("{indent}let {name} = {expr_str}\n"); + + let type_line = self.debug_type_line(hir_file, function, scope_origin, *value); + stmt_str.push_str(&format!("{indent}{type_line}\n")); + + stmt_str + } + hir::Stmt::ExprStmt { expr, - debug_type(&inference_body_result.type_by_expr[expr_id]) - )); + has_semicolon, + } => { + let indent = indent(nesting); + let expr_str = + self.debug_expr(hir_file, function, scope_origin, *expr, nesting); + let type_line = self.debug_type_line(hir_file, function, scope_origin, *expr); + let maybe_semicolon = if *has_semicolon { ";" } else { "" }; + format!("{indent}{expr_str}{maybe_semicolon}\n{indent}{type_line}\n") + } + hir::Stmt::Item { item } => self.debug_item(hir_file, *item, nesting), } } - msg.push_str("---\n"); + fn debug_type_line( + &self, + hir_file: hir::HirFile, + function: hir::Function, + scope_origin: hir::ModuleScopeOrigin, + expr_id: hir::ExprId, + ) -> String { + let ty = self + .inference_result + .inference_body_result_by_function + .get(&function) + .unwrap() + .type_by_expr + .get(&expr_id) + .unwrap(); + let expr_str = match expr_id.lookup(hir_file.db(self.db)) { + hir::Expr::Symbol(_) + | hir::Expr::Binary { .. } + | hir::Expr::Literal(_) + | hir::Expr::Unary { .. } + | hir::Expr::Call { .. } + | hir::Expr::If { .. } + | hir::Expr::Return { .. } + | hir::Expr::Missing => { + self.debug_expr(hir_file, function, scope_origin, expr_id, 0) + } + hir::Expr::Block(_) => "{ .. }".to_string(), + }; - for function in hir_file.functions(db) { - let type_check_errors = type_check_result.errors_by_function.get(function).unwrap(); + format!("// {expr_str}: {ty}") + } - for error in type_check_errors { - match error { - TypeCheckError::UnresolvedType { expr } => { - msg.push_str(&format!( - "error: `{}` is unresolved type.\n", - debug_hir_expr(db, expr, hir_file), + fn debug_expr( + &self, + hir_file: hir::HirFile, + function: hir::Function, + scope_origin: hir::ModuleScopeOrigin, + expr_id: hir::ExprId, + nesting: usize, + ) -> String { + match expr_id.lookup(hir_file.db(self.db)) { + hir::Expr::Symbol(symbol) => { + self.debug_symbol(hir_file, function, scope_origin, symbol, nesting) + } + hir::Expr::Literal(literal) => match literal { + hir::Literal::Bool(b) => b.to_string(), + hir::Literal::Char(c) => format!("'{c}'"), + hir::Literal::String(s) => format!("\"{s}\""), + hir::Literal::Integer(i) => i.to_string(), + }, + hir::Expr::Binary { op, lhs, rhs } => { + let op = match op { + ast::BinaryOp::Add(_) => "+", + ast::BinaryOp::Sub(_) => "-", + ast::BinaryOp::Mul(_) => "*", + ast::BinaryOp::Div(_) => "/", + ast::BinaryOp::Equal(_) => "==", + ast::BinaryOp::GreaterThan(_) => ">", + ast::BinaryOp::LessThan(_) => "<", + }; + let lhs_str = self.debug_expr(hir_file, function, scope_origin, *lhs, nesting); + let rhs_str = self.debug_expr(hir_file, function, scope_origin, *rhs, nesting); + format!("{lhs_str} {op} {rhs_str}") + } + hir::Expr::Unary { op, expr } => { + let op = match op { + ast::UnaryOp::Neg(_) => "-", + ast::UnaryOp::Not(_) => "!", + }; + let expr_str = + self.debug_expr(hir_file, function, scope_origin, *expr, nesting); + format!("{op}{expr_str}") + } + hir::Expr::Call { callee, args } => { + let callee = + self.debug_symbol(hir_file, function, scope_origin, callee, nesting); + let args = args + .iter() + .map(|arg| self.debug_expr(hir_file, function, scope_origin, *arg, nesting)) + .collect::>() + .join(", "); + + format!("{callee}({args})") + } + hir::Expr::Block(block) => { + let scope_origin = hir::ModuleScopeOrigin::Block { origin: expr_id }; + + let mut msg = "{\n".to_string(); + for stmt in &block.stmts { + msg.push_str(&self.debug_stmt( + hir_file, + function, + scope_origin, + stmt, + nesting + 1, )); } - TypeCheckError::MismatchedTypes { - expected_expr, - expected_ty, - found_expr, - found_ty, - } => { + if let Some(tail) = block.tail { + let indent = indent(nesting + 1); msg.push_str(&format!( - "error: expected {}, found {} by `{}` and `{}`\n", - debug_type(expected_ty), - debug_type(found_ty), - debug_hir_expr(db, expected_expr, hir_file), - debug_hir_expr(db, found_expr, hir_file) + "{indent}expr:{}\n", + self.debug_expr(hir_file, function, scope_origin, tail, nesting + 1) )); - } - TypeCheckError::MismaatchedSignature { - expected_ty, - found_expr, - found_ty, - .. - } => msg.push_str(&format!( - "error: expected {}, found {} by `{}`\n", - debug_type(expected_ty), - debug_type(found_ty), - debug_hir_expr(db, found_expr, hir_file) - )), - TypeCheckError::MismatchedTypeIfCondition { - expected_ty, - found_expr, - found_ty, - } => { msg.push_str(&format!( - "error: expected {}, found {} by `{}`\n", - debug_type(expected_ty), - debug_type(found_ty), - debug_hir_expr(db, found_expr, hir_file) + "{indent}{}\n", + self.debug_type_line(hir_file, function, scope_origin, tail) )); } - TypeCheckError::MismatchedTypeElseBranch { - expected_ty, - found_expr, - found_ty, - } => { - msg.push_str(&format!( - "error: expected {}, found {} by `{}`\n", - debug_type(expected_ty), - debug_type(found_ty), - debug_hir_expr(db, found_expr, hir_file) + msg.push_str(&format!("{}}}", indent(nesting))); + + msg + } + hir::Expr::If { + condition, + then_branch, + else_branch, + } => { + let mut msg = "if ".to_string(); + msg.push_str(&self.debug_expr( + hir_file, + function, + scope_origin, + *condition, + nesting, + )); + msg.push(' '); + msg.push_str(&self.debug_expr( + hir_file, + function, + scope_origin, + *then_branch, + nesting, + )); + + if let Some(else_branch) = else_branch { + msg.push_str(" else "); + msg.push_str(&self.debug_expr( + hir_file, + function, + scope_origin, + *else_branch, + nesting, )); } - TypeCheckError::MismatchedReturnType { - expected_ty, - found_expr, - found_ty, - } => { + + msg + } + hir::Expr::Return { value } => { + let mut msg = "return".to_string(); + if let Some(value) = value { msg.push_str(&format!( - "error: expected {}, found {}", - debug_type(expected_ty), - debug_type(found_ty) + " {}", + &self.debug_expr(hir_file, function, scope_origin, *value, nesting,) )); - if let Some(found_expr) = found_expr { - msg.push_str(&format!( - " by `{}`", - debug_hir_expr(db, found_expr, hir_file) - )); - } - msg.push('\n'); } + + msg } + hir::Expr::Missing => "".to_string(), } } - msg - } - - fn debug_hir_expr( - db: &dyn hir::HirMasterDatabase, - expr_id: &hir::ExprId, - hir_file: &hir::HirFile, - ) -> String { - let expr = expr_id.lookup(hir_file.db(db)); - match expr { - hir::Expr::Symbol(symbol) => match symbol { - hir::Symbol::Param { name, .. } => debug_name(db, *name), - hir::Symbol::Local { name, .. } => debug_name(db, *name), - hir::Symbol::Missing { path, .. } => debug_path(db, &path.path(db)), - }, - hir::Expr::Missing => "".to_string(), - hir::Expr::Unary { op, expr } => { - let op = match op { - ast::UnaryOp::Neg(_) => "-".to_string(), - ast::UnaryOp::Not(_) => "!".to_string(), - }; - let expr = debug_hir_expr(db, expr, hir_file); - format!("{op}{expr}") - } - hir::Expr::Binary { op, lhs, rhs } => { - let op = match op { - ast::BinaryOp::Add(_) => "+", - ast::BinaryOp::Sub(_) => "-", - ast::BinaryOp::Mul(_) => "*", - ast::BinaryOp::Div(_) => "/", - ast::BinaryOp::Equal(_) => "==", - ast::BinaryOp::GreaterThan(_) => ">", - ast::BinaryOp::LessThan(_) => "<", - } - .to_string(); - let lhs = debug_hir_expr(db, lhs, hir_file); - let rhs = debug_hir_expr(db, rhs, hir_file); - - format!("{lhs} {op} {rhs}") - } - hir::Expr::Block(block) => { - if let Some(tail) = block.tail { - format!("{{ .., {} }}", debug_hir_expr(db, &tail, hir_file)) - } else { - "{{ .. }}".to_string() + fn debug_symbol( + &self, + _hir_file: hir::HirFile, + _function: hir::Function, + _scope_origin: hir::ModuleScopeOrigin, + symbol: &hir::Symbol, + _nesting: usize, + ) -> String { + match &symbol { + hir::Symbol::Local { name, expr: _ } => name.text(self.db).to_string(), + hir::Symbol::Param { name, .. } => { + let name = name.text(self.db); + format!("param:{name}") } - } - hir::Expr::Call { callee, args } => { - let name = debug_symbol(db, callee); - let args = args - .iter() - .map(|id| debug_hir_expr(db, id, hir_file)) - .collect::>() - .join(", "); - - format!("{name}({args})") - } - hir::Expr::Literal(literal) => match literal { - hir::Literal::Bool(b) => b.to_string(), - hir::Literal::Char(c) => format!("'{c}'"), - hir::Literal::Integer(i) => i.to_string(), - hir::Literal::String(s) => format!("\"{s}\""), - }, - hir::Expr::If { - condition, - then_branch, - else_branch, - } => { - let mut if_expr = format!( - "if {} {}", - debug_hir_expr(db, condition, hir_file), - debug_hir_expr(db, then_branch, hir_file) - ); - if let Some(else_branch) = else_branch { - if_expr.push_str(&format!( - " else {}", - debug_hir_expr(db, else_branch, hir_file) - )); - } - - if_expr - } - hir::Expr::Return { value } => { - let mut msg = "return".to_string(); - if let Some(value) = value { - msg.push_str(&format!(" {}", debug_hir_expr(db, value, hir_file))); + hir::Symbol::Missing { path } => { + let resolving_status = self.pods.resolution_map.item_by_symbol(path).unwrap(); + self.debug_resolution_status(resolving_status) } - - msg } } - } - fn debug_symbol(db: &dyn hir::HirMasterDatabase, symbol: &Symbol) -> String { - match symbol { - hir::Symbol::Param { name, .. } => debug_name(db, *name), - hir::Symbol::Local { name, .. } => debug_name(db, *name), - hir::Symbol::Missing { path, .. } => debug_path(db, &path.path(db)), + fn debug_resolution_status(&self, resolution_status: hir::ResolutionStatus) -> String { + match resolution_status { + hir::ResolutionStatus::Unresolved => "".to_string(), + hir::ResolutionStatus::Error => "".to_string(), + hir::ResolutionStatus::Resolved { path, item } => { + let path = self.debug_path(&path); + match item { + hir::Item::Function(_) => { + format!("fn:{path}") + } + hir::Item::Module(_) => { + format!("mod:{path}") + } + hir::Item::UseItem(_) => { + unreachable!() + } + } + } + } } - } - fn debug_name(db: &dyn hir::HirMasterDatabase, name: Name) -> String { - name.text(db).to_string() - } - - fn debug_path(db: &dyn hir::HirMasterDatabase, path: &Path) -> String { - path.segments(db) - .iter() - .map(|segment| segment.text(db).to_string()) - .collect::>() - .join("::") - } - - fn debug_type(ty: &ResolvedType) -> String { - match ty { - ResolvedType::Unknown => "unknown", - ResolvedType::Integer => "int", - ResolvedType::String => "string", - ResolvedType::Char => "char", - ResolvedType::Bool => "bool", - ResolvedType::Unit => "()", - ResolvedType::Never => "!", - ResolvedType::Function(_) => "fn", + fn debug_path(&self, path: &hir::Path) -> String { + path.segments(self.db) + .iter() + .map(|segment| segment.text(self.db).to_string()) + .collect::>() + .join("::") } - .to_string() } #[test] @@ -392,9 +575,11 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int + //- /main.nail + fn entry:main() -> () { + 10; + // 10: integer + } --- "#]], ); @@ -409,9 +594,11 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `"aaa"`: string + //- /main.nail + fn entry:main() -> () { + "aaa"; + // "aaa": string + } --- "#]], ); @@ -426,9 +613,11 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `'a'`: char + //- /main.nail + fn entry:main() -> () { + 'a'; + // 'a': char + } --- "#]], ); @@ -443,9 +632,11 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `true`: bool + //- /main.nail + fn entry:main() -> () { + true; + // true: bool + } --- "#]], ); @@ -457,9 +648,11 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `false`: bool + //- /main.nail + fn entry:main() -> () { + false; + // false: bool + } --- "#]], ); @@ -474,9 +667,11 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `true`: bool + //- /main.nail + fn entry:main() -> () { + let a = true + // true: bool + } --- "#]], ) @@ -494,12 +689,17 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `true`: bool - `10`: int - `"aa"`: string - `'a'`: char + //- /main.nail + fn entry:main() -> () { + let a = true + // true: bool + let b = 10 + // 10: integer + let c = "aa" + // "aa": string + let d = 'a' + // 'a': char + } --- "#]], ) @@ -535,82 +735,72 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int - `20`: int - `10 + 20`: int - `"aaa"`: string - `"bbb"`: string - `"aaa" + "bbb"`: unknown - `10`: int - `"aaa"`: string - `10 + "aaa"`: unknown - `'a'`: char - `'a'`: char - `'a' + 'a'`: unknown - `10`: int - `'a'`: char - `10 + 'a'`: unknown - `10`: int - `'a'`: char - `10 < 'a'`: unknown - `10`: int - `'a'`: char - `10 > 'a'`: unknown - `true`: bool - `true`: bool - `true + true`: unknown - `true`: bool - `true`: bool - `true - true`: unknown - `true`: bool - `true`: bool - `true * true`: unknown - `true`: bool - `true`: bool - `true / true`: unknown - `true`: bool - `true`: bool - `true == true`: bool - `true`: bool - `false`: bool - `true < false`: unknown - `true`: bool - `false`: bool - `true > false`: unknown - `10`: int - `true`: bool - `10 + true`: unknown - `10`: int - `"aaa"`: string - `10`: int - `10 + "aaa"`: int - `10 + 10 + "aaa"`: int - `10`: int - `20`: int - `10 - 20`: int - `10`: int - `20`: int - `10 * 20`: int - `10`: int - `20`: int - `10 / 20`: int - `10`: int - `20`: int - `10 == 20`: bool - `10`: int - `20`: int - `10 < 20`: bool - `10`: int - `20`: int - `10 > 20`: bool + //- /main.nail + fn entry:main() -> () { + 10 + 20 + // 10 + 20: integer + "aaa" + "bbb" + // "aaa" + "bbb": integer + 10 + "aaa" + // 10 + "aaa": integer + 'a' + 'a' + // 'a' + 'a': integer + 10 + 'a' + // 10 + 'a': integer + 10 < 'a' + // 10 < 'a': bool + 10 > 'a' + // 10 > 'a': bool + true + true + // true + true: integer + true - true + // true - true: integer + true * true + // true * true: integer + true / true + // true / true: integer + true == true + // true == true: bool + true < false + // true < false: bool + true > false + // true > false: bool + 10 + true + // 10 + true: integer + 10 + 10 + "aaa" + // 10 + 10 + "aaa": integer + 10 - 20 + // 10 - 20: integer + 10 * 20 + // 10 * 20: integer + 10 / 20 + // 10 / 20: integer + 10 == 20 + // 10 == 20: bool + 10 < 20 + // 10 < 20: bool + 10 > 20; + // 10 > 20: bool + } --- - error: expected int, found string by `10` and `"aaa"` - error: expected int, found char by `10` and `'a'` - error: expected int, found char by `10` and `'a'` - error: expected int, found char by `10` and `'a'` - error: expected int, found bool by `10` and `true` + error: expected integer, actual: string + error: expected integer, actual: string + error: expected integer, actual: string + error: expected integer, actual: char + error: expected integer, actual: char + error: expected integer, actual: char + error: expected integer, actual: char + error: expected integer, actual: char + error: expected integer, actual: bool + error: expected integer, actual: bool + error: expected integer, actual: bool + error: expected integer, actual: bool + error: expected integer, actual: bool + error: expected integer, actual: bool + error: expected integer, actual: bool + error: expected integer, actual: bool + error: expected integer, actual: bool + error: expected integer, actual: string "#]], ); @@ -621,18 +811,14 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int - `"aaa"`: string - `10`: int - `"aaa"`: string - `10 + "aaa"`: unknown - `10 + "aaa"`: unknown - `10 + "aaa" + 10 + "aaa"`: unknown + //- /main.nail + fn entry:main() -> () { + 10 + "aaa" + 10 + "aaa"; + // 10 + "aaa" + 10 + "aaa": integer + } --- - error: `10 + "aaa"` is unresolved type. - error: `10 + "aaa"` is unresolved type. + error: expected integer, actual: string + error: expected integer, actual: string "#]], ); } @@ -654,25 +840,32 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int - `-10`: int - `"aaa"`: string - `-"aaa"`: unknown - `'a'`: char - `-'a'`: unknown - `true`: bool - `-true`: unknown - `10`: int - `!10`: unknown - `"aaa"`: string - `!"aaa"`: unknown - `'a'`: char - `!'a'`: unknown - `true`: bool - `!true`: bool + //- /main.nail + fn entry:main() -> () { + let a = -10 + // -10: integer + let b = -"aaa" + // -"aaa": integer + let c = -'a' + // -'a': integer + let d = -true + // -true: integer + let e = !10 + // !10: bool + let f = !"aaa" + // !"aaa": bool + let g = !'a' + // !'a': bool + let h = !true + // !true: bool + } --- + error: expected integer, actual: string + error: expected integer, actual: char + error: expected integer, actual: bool + error: expected bool, actual: integer + error: expected bool, actual: string + error: expected bool, actual: char "#]], ) } @@ -688,17 +881,15 @@ mod tests { } "#, expect![[r#" - fn() -> bool - --- - `true`: bool - `!true`: bool - `false`: bool - `!false`: bool - `a`: bool - `b`: bool - `!a`: bool - `!b`: bool - `!a == !b`: bool + //- /main.nail + fn entry:main() -> bool { + let a = !true + // !true: bool + let b = !false + // !false: bool + expr:!a == !b + // !a == !b: bool + } --- "#]], ); @@ -714,11 +905,13 @@ mod tests { } "#, expect![[r#" - fn() -> int - --- - `10`: int - `-10`: int - `a`: int + //- /main.nail + fn entry:main() -> int { + let a = -10 + // -10: integer + expr:a + // a: integer + } --- "#]], ) @@ -735,10 +928,14 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int - `{ .., 10 }`: int + //- /main.nail + fn entry:main() -> () { + { + expr:10 + // 10: integer + }; + // { .. }: integer + } --- "#]], ); @@ -755,12 +952,19 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int - `"aaa"`: string - `{ .., "aaa" }`: string - `{ .., { .., "aaa" } }`: string + //- /main.nail + fn entry:main() -> () { + { + expr:{ + 10 + // 10: integer + expr:"aaa" + // "aaa": string + } + // { .. }: string + }; + // { .. }: string + } --- "#]], ); @@ -777,15 +981,20 @@ mod tests { } "#, expect![[r#" - fn() -> int - --- - `10`: int - `20`: int - `a`: int - `c`: int - `a + c`: int - `{ .., a + c }`: int - `b`: int + //- /main.nail + fn entry:main() -> int { + let a = 10 + // 10: integer + let b = { + let c = 20 + // 20: integer + expr:a + c + // a + c: integer + } + // { .. }: integer + expr:b + // b: integer + } --- "#]], ); @@ -800,11 +1009,13 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int + //- /main.nail + fn aaa() -> () { + expr:10 + // 10: integer + } --- - error: expected (), found int by `10` + error: expected integer, actual: () "#]], ); @@ -815,9 +1026,11 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int + //- /main.nail + fn aaa() -> () { + 10; + // 10: integer + } --- "#]], ); @@ -834,12 +1047,19 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int - `{ .., 10 }`: int - `20`: int - `{{ .. }}`: () + //- /main.nail + fn aaa() -> () { + { + expr:10 + // 10: integer + }; + // { .. }: integer + { + 20; + // 20: integer + }; + // { .. }: () + } --- "#]], ); @@ -858,13 +1078,19 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int - `{ .., 10 }`: int - `{ .., { .., 10 } }`: int + //- /main.nail + fn aaa() -> () { + expr:{ + expr:{ + expr:10 + // 10: integer + } + // { .. }: integer + } + // { .. }: integer + } --- - error: expected (), found int by `{ .., { .., 10 } }` + error: expected integer, actual: () "#]], ); @@ -879,11 +1105,17 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int - `{{ .. }}`: () - `{ .., {{ .. }} }`: () + //- /main.nail + fn aaa() -> () { + expr:{ + expr:{ + 10; + // 10: integer + } + // { .. }: () + } + // { .. }: () + } --- "#]], ); @@ -899,11 +1131,17 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int - `{ .., 10 }`: int - `{{ .. }}`: () + //- /main.nail + fn aaa() -> () { + expr:{ + { + expr:10 + // 10: integer + }; + // { .. }: integer + } + // { .. }: () + } --- "#]], ); @@ -919,10 +1157,13 @@ mod tests { } "#, expect![[r#" - fn() -> int - --- - `10`: int - `a`: int + //- /main.nail + fn aaa() -> int { + let a = 10 + // 10: integer + expr:a + // a: integer + } --- "#]], ); @@ -938,10 +1179,13 @@ mod tests { } "#, expect![[r#" - fn(int, string) -> () - --- - `x`: int - `y`: string + //- /main.nail + fn aaa(x: int, y: string) -> () { + let a = param:x + // param:x: integer + let b = param:y + // param:y: string + } --- "#]], ); From 2d8155b301888f225c1b2844a5b46ae8c56446c9 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Thu, 14 Sep 2023 12:50:10 +0900 Subject: [PATCH 02/35] wip --- crates/hir_ty/src/inference.rs | 2 +- crates/hir_ty/src/inference/environment.rs | 132 ++++- crates/hir_ty/src/inference/type_unifier.rs | 23 +- crates/hir_ty/src/inference/types.rs | 51 +- crates/hir_ty/src/lib.rs | 539 +++++++++++--------- 5 files changed, 468 insertions(+), 279 deletions(-) diff --git a/crates/hir_ty/src/inference.rs b/crates/hir_ty/src/inference.rs index 68862a07..bbb7b485 100644 --- a/crates/hir_ty/src/inference.rs +++ b/crates/hir_ty/src/inference.rs @@ -20,7 +20,7 @@ pub fn infer_pods(db: &dyn hir::HirMasterDatabase, pods: &hir::Pods) -> Inferenc let mut body_result_by_function = HashMap::::new(); for (hir_file, function) in pods.root_pod.all_functions(db) { let env = Environment::new(); - let infer_body = InferBody::new(db, hir_file, function, &signature_by_function, env); + let infer_body = InferBody::new(db, pods, hir_file, function, &signature_by_function, env); let infer_body_result = infer_body.infer_body(); body_result_by_function.insert(function, infer_body_result); diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index 98691b2b..491b0807 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -13,6 +13,7 @@ pub struct Signature { pub(crate) struct InferBody<'a> { db: &'a dyn hir::HirMasterDatabase, + pods: &'a hir::Pods, hir_file: hir::HirFile, function: hir::Function, signature: &'a Signature, @@ -27,6 +28,7 @@ pub(crate) struct InferBody<'a> { impl<'a> InferBody<'a> { pub(crate) fn new( db: &'a dyn hir::HirMasterDatabase, + pods: &'a hir::Pods, hir_file: hir::HirFile, function: hir::Function, signature_by_function: &'a HashMap, @@ -34,6 +36,7 @@ impl<'a> InferBody<'a> { ) -> Self { InferBody { db, + pods, hir_file, function, signature: signature_by_function.get(&function).unwrap(), @@ -94,6 +97,55 @@ impl<'a> InferBody<'a> { } } + fn infer_symbol(&mut self, symbol: &hir::Symbol) -> Monotype { + match symbol { + hir::Symbol::Param { name: _, param } => { + let param = param.data(self.hir_file.db(self.db)); + self.infer_type(¶m.ty) + } + hir::Symbol::Local { name, expr: _ } => { + let ty_scheme = self.current_scope().bindings.get(name).cloned(); + if let Some(ty_scheme) = ty_scheme { + ty_scheme.instantiate(&mut self.cxt) + } else { + panic!("Unbound variable {symbol:?}"); + } + } + hir::Symbol::Missing { path } => { + let item = self.pods.resolution_map.item_by_symbol(path).unwrap(); + match item { + hir::ResolutionStatus::Unresolved | hir::ResolutionStatus::Error => { + // 解決できないエラーを追加 + Monotype::Unknown + } + hir::ResolutionStatus::Resolved { path, item } => { + match item { + hir::Item::Function(function) => { + let signature = self.signature_by_function.get(&function); + if let Some(signature) = signature { + Monotype::Function { + args: signature.params.clone(), + return_type: Box::new(signature.return_type.clone()), + } + } else { + unreachable!("Function signature should be resolved.") + } + } + hir::Item::Module(_) => { + // モジュールを型推論使用としているエラーを追加 + Monotype::Unknown + } + hir::Item::UseItem(_) => { + // 使用宣言を型推論使用としているエラーを追加 + Monotype::Unknown + } + } + } + } + } + } + } + fn infer_expr(&mut self, expr_id: hir::ExprId) -> Monotype { let expr = expr_id.lookup(self.hir_file.db(self.db)); let ty = match expr { @@ -104,22 +156,43 @@ impl<'a> InferBody<'a> { hir::Literal::Bool(_) => Monotype::Bool, }, hir::Expr::Missing => Monotype::Unknown, - hir::Expr::Symbol(symbol) => match symbol { - hir::Symbol::Param { name: _, param } => { - let param = param.data(self.hir_file.db(self.db)); - self.infer_type(¶m.ty) - } - hir::Symbol::Local { name, expr: _ } => { - let ty_scheme = self.current_scope().bindings.get(name).cloned(); - if let Some(ty_scheme) = ty_scheme { - ty_scheme.instantiate(&mut self.cxt) - } else { - panic!("Unbound variable {symbol:?}"); + hir::Expr::Symbol(symbol) => self.infer_symbol(symbol), + hir::Expr::Call { + callee, + args: call_args, + } => { + let ty = self.infer_symbol(callee); + match ty { + Monotype::Integer + | Monotype::Bool + | Monotype::Unit + | Monotype::Char + | Monotype::String + | Monotype::Never + | Monotype::Unknown + | Monotype::Variable(_) => { + // TODO: 関数ではないものを呼び出そうとしているエラーを追加 + Monotype::Unknown + } + Monotype::Function { args, return_type } => { + let call_args_ty = call_args + .iter() + .map(|arg| self.infer_expr(*arg)) + .collect::>(); + + if call_args_ty.len() != args.len() { + // TODO: 引数の数が異なるエラーを追加 + Monotype::Unknown + } else { + for (call_arg, arg) in call_args_ty.iter().zip(args.iter()) { + self.unifier.unify(call_arg, arg); + } + + *return_type + } } } - hir::Symbol::Missing { path } => unimplemented!(), - }, - hir::Expr::Call { callee, args } => todo!(), + } hir::Expr::Binary { op, lhs, rhs } => match op { ast::BinaryOp::Add(_) | ast::BinaryOp::Sub(_) @@ -166,6 +239,7 @@ impl<'a> InferBody<'a> { let ty = if let Some(tail) = &block.tail { self.infer_expr(*tail) } else { + // 最後の式がない場合は Unit として扱う Monotype::Unit }; @@ -177,8 +251,34 @@ impl<'a> InferBody<'a> { condition, then_branch, else_branch, - } => todo!(), - hir::Expr::Return { value } => todo!(), + } => { + let condition_ty = self.infer_expr(*condition); + self.unifier.unify(&Monotype::Bool, &condition_ty); + + let then_ty = self.infer_expr(*then_branch); + if let Some(else_branch) = else_branch { + let else_ty = self.infer_expr(*else_branch); + self.unifier.unify(&then_ty, &else_ty); + } else { + // elseブランチがない場合は Unit として扱う + self.unifier.unify(&then_ty, &Monotype::Unit); + } + + then_ty + } + hir::Expr::Return { value } => { + if let Some(return_value) = value { + let ty = self.infer_expr(*return_value); + self.unifier.unify(&ty, &self.signature.return_type); + } else { + // 何も指定しない場合は Unit を返すものとして扱う + self.unifier + .unify(&Monotype::Unit, &self.signature.return_type); + } + + // return自体の戻り値は Never として扱う + Monotype::Never + } }; self.type_by_expr.insert(expr_id, ty.clone()); diff --git a/crates/hir_ty/src/inference/type_unifier.rs b/crates/hir_ty/src/inference/type_unifier.rs index e2f2672d..cbc114d2 100644 --- a/crates/hir_ty/src/inference/type_unifier.rs +++ b/crates/hir_ty/src/inference/type_unifier.rs @@ -34,16 +34,27 @@ impl TypeUnifier { match (&a_rep, &b_rep) { ( Monotype::Function { - from: a_from, - to: a_to, + args: a_args, + return_type: a_return_type, }, Monotype::Function { - from: b_from, - to: b_to, + args: b_args, + return_type: b_return_type, }, ) => { - self.unify(a_from, b_from); - self.unify(a_to, b_to); + if a_args.len() != b_args.len() { + self.errors.push(InferenceError::TypeMismatch { + expected: a_rep, + actual: b_rep, + }); + return; + } + + for (a_arg, b_arg) in a_args.iter().zip(b_args.iter()) { + self.unify(a_arg, b_arg); + } + + self.unify(a_return_type, b_return_type); } (Monotype::Variable(_), b_rep) => self.unify_var(&a_rep, b_rep), (a_rep, Monotype::Variable(_)) => self.unify_var(&b_rep, a_rep), diff --git a/crates/hir_ty/src/inference/types.rs b/crates/hir_ty/src/inference/types.rs index 9929fea4..a7bc6cba 100644 --- a/crates/hir_ty/src/inference/types.rs +++ b/crates/hir_ty/src/inference/types.rs @@ -11,8 +11,8 @@ pub enum Monotype { String, Variable(u32), Function { - from: Box, - to: Box, + args: Vec, + return_type: Box, }, Never, Unknown, @@ -21,7 +21,7 @@ pub enum Monotype { impl fmt::Display for Monotype { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { - Monotype::Integer => write!(f, "integer"), + Monotype::Integer => write!(f, "int"), Monotype::Bool => write!(f, "bool"), Monotype::Char => write!(f, "char"), Monotype::String => write!(f, "string"), @@ -29,12 +29,20 @@ impl fmt::Display for Monotype { Monotype::Never => write!(f, "!"), Monotype::Unknown => write!(f, "unknown"), Monotype::Variable(id) => write!(f, "{}", id), - Monotype::Function { from, to } => { - if let Monotype::Function { from, .. } = from.as_ref() { - write!(f, "({}) -> {}", from.to_string(), to.to_string()) - } else { - write!(f, "{} -> {}", from.to_string(), to.to_string()) - } + Monotype::Function { + args: from, + return_type: to, + } => { + write!( + f, + "({}) -> {}", + from.iter() + .map(|ty| ty.to_string()) + .collect::>() + .join(", ") + .to_string(), + to.to_string() + ) } } } @@ -55,11 +63,17 @@ impl Monotype { set } - Monotype::Function { from, to } => from - .free_variables() - .union(&to.free_variables()) - .cloned() - .collect(), + Monotype::Function { + args: from, + return_type: to, + } => { + let mut set = HashSet::new(); + for arg in from.iter() { + set.extend(arg.free_variables()); + } + set.extend(to.free_variables()); + set + } _ => Default::default(), } } @@ -80,9 +94,12 @@ impl Monotype { self.clone() } } - Monotype::Function { from, to } => Monotype::Function { - from: Box::new(from.apply(subst)), - to: Box::new(to.apply(subst)), + Monotype::Function { + args: from, + return_type: to, + } => Monotype::Function { + args: from.iter().map(|arg| arg.apply(subst)).collect::>(), + return_type: Box::new(to.apply(subst)), }, } } diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 291f4f38..de99eb6f 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -321,11 +321,23 @@ mod tests { | hir::Expr::Literal(_) | hir::Expr::Unary { .. } | hir::Expr::Call { .. } - | hir::Expr::If { .. } | hir::Expr::Return { .. } | hir::Expr::Missing => { self.debug_expr(hir_file, function, scope_origin, expr_id, 0) } + hir::Expr::If { + condition, + then_branch: _, + else_branch, + } => { + let condition_str = + self.debug_expr(hir_file, function, scope_origin, *condition, 0); + let mut if_str = format!("if {condition_str} {{ .. }}"); + if else_branch.is_some() { + if_str.push_str(" else { .. }"); + } + if_str + } hir::Expr::Block(_) => "{ .. }".to_string(), }; @@ -533,34 +545,27 @@ mod tests { } "#, expect![[r#" - fn(int) -> int - fn() -> int - --- - `x`: int - `0`: int - `x == 0`: bool - `0`: int - `{ .., 0 }`: int - `x`: int - `1`: int - `x == 1`: bool - `1`: int - `{ .., 1 }`: int - `x`: int - `1`: int - `x - 1`: int - `x`: int - `2`: int - `x - 2`: int - `fibonacci(x - 1)`: int - `fibonacci(x - 2)`: int - `fibonacci(x - 1) + fibonacci(x - 2)`: int - `{ .., fibonacci(x - 1) + fibonacci(x - 2) }`: int - `if x == 1 { .., 1 } else { .., fibonacci(x - 1) + fibonacci(x - 2) }`: int - `{ .., if x == 1 { .., 1 } else { .., fibonacci(x - 1) + fibonacci(x - 2) } }`: int - `if x == 0 { .., 0 } else { .., if x == 1 { .., 1 } else { .., fibonacci(x - 1) + fibonacci(x - 2) } }`: int - `15`: int - `fibonacci(15)`: int + //- /main.nail + fn fibonacci(x: int) -> int { + expr:if param:x == 0 { + expr:0 + // 0: int + } else { + expr:if param:x == 1 { + expr:1 + // 1: int + } else { + expr:fn:fibonacci(param:x - 1) + fn:fibonacci(param:x - 2) + // fn:fibonacci(param:x - 1) + fn:fibonacci(param:x - 2): int + } + // if param:x == 1 { .. } else { .. }: int + } + // if param:x == 0 { .. } else { .. }: int + } + fn entry:main() -> int { + expr:fn:fibonacci(15) + // fn:fibonacci(15): int + } --- "#]], ); @@ -578,7 +583,7 @@ mod tests { //- /main.nail fn entry:main() -> () { 10; - // 10: integer + // 10: int } --- "#]], @@ -694,7 +699,7 @@ mod tests { let a = true // true: bool let b = 10 - // 10: integer + // 10: int let c = "aa" // "aa": string let d = 'a' @@ -738,27 +743,27 @@ mod tests { //- /main.nail fn entry:main() -> () { 10 + 20 - // 10 + 20: integer + // 10 + 20: int "aaa" + "bbb" - // "aaa" + "bbb": integer + // "aaa" + "bbb": int 10 + "aaa" - // 10 + "aaa": integer + // 10 + "aaa": int 'a' + 'a' - // 'a' + 'a': integer + // 'a' + 'a': int 10 + 'a' - // 10 + 'a': integer + // 10 + 'a': int 10 < 'a' // 10 < 'a': bool 10 > 'a' // 10 > 'a': bool true + true - // true + true: integer + // true + true: int true - true - // true - true: integer + // true - true: int true * true - // true * true: integer + // true * true: int true / true - // true / true: integer + // true / true: int true == true // true == true: bool true < false @@ -766,15 +771,15 @@ mod tests { true > false // true > false: bool 10 + true - // 10 + true: integer + // 10 + true: int 10 + 10 + "aaa" - // 10 + 10 + "aaa": integer + // 10 + 10 + "aaa": int 10 - 20 - // 10 - 20: integer + // 10 - 20: int 10 * 20 - // 10 * 20: integer + // 10 * 20: int 10 / 20 - // 10 / 20: integer + // 10 / 20: int 10 == 20 // 10 == 20: bool 10 < 20 @@ -783,24 +788,24 @@ mod tests { // 10 > 20: bool } --- - error: expected integer, actual: string - error: expected integer, actual: string - error: expected integer, actual: string - error: expected integer, actual: char - error: expected integer, actual: char - error: expected integer, actual: char - error: expected integer, actual: char - error: expected integer, actual: char - error: expected integer, actual: bool - error: expected integer, actual: bool - error: expected integer, actual: bool - error: expected integer, actual: bool - error: expected integer, actual: bool - error: expected integer, actual: bool - error: expected integer, actual: bool - error: expected integer, actual: bool - error: expected integer, actual: bool - error: expected integer, actual: string + error: expected int, actual: string + error: expected int, actual: string + error: expected int, actual: string + error: expected int, actual: char + error: expected int, actual: char + error: expected int, actual: char + error: expected int, actual: char + error: expected int, actual: char + error: expected int, actual: bool + error: expected int, actual: bool + error: expected int, actual: bool + error: expected int, actual: bool + error: expected int, actual: bool + error: expected int, actual: bool + error: expected int, actual: bool + error: expected int, actual: bool + error: expected int, actual: bool + error: expected int, actual: string "#]], ); @@ -814,11 +819,11 @@ mod tests { //- /main.nail fn entry:main() -> () { 10 + "aaa" + 10 + "aaa"; - // 10 + "aaa" + 10 + "aaa": integer + // 10 + "aaa" + 10 + "aaa": int } --- - error: expected integer, actual: string - error: expected integer, actual: string + error: expected int, actual: string + error: expected int, actual: string "#]], ); } @@ -843,13 +848,13 @@ mod tests { //- /main.nail fn entry:main() -> () { let a = -10 - // -10: integer + // -10: int let b = -"aaa" - // -"aaa": integer + // -"aaa": int let c = -'a' - // -'a': integer + // -'a': int let d = -true - // -true: integer + // -true: int let e = !10 // !10: bool let f = !"aaa" @@ -860,10 +865,10 @@ mod tests { // !true: bool } --- - error: expected integer, actual: string - error: expected integer, actual: char - error: expected integer, actual: bool - error: expected bool, actual: integer + error: expected int, actual: string + error: expected int, actual: char + error: expected int, actual: bool + error: expected bool, actual: int error: expected bool, actual: string error: expected bool, actual: char "#]], @@ -908,9 +913,9 @@ mod tests { //- /main.nail fn entry:main() -> int { let a = -10 - // -10: integer + // -10: int expr:a - // a: integer + // a: int } --- "#]], @@ -932,9 +937,9 @@ mod tests { fn entry:main() -> () { { expr:10 - // 10: integer + // 10: int }; - // { .. }: integer + // { .. }: int } --- "#]], @@ -957,7 +962,7 @@ mod tests { { expr:{ 10 - // 10: integer + // 10: int expr:"aaa" // "aaa": string } @@ -984,16 +989,16 @@ mod tests { //- /main.nail fn entry:main() -> int { let a = 10 - // 10: integer + // 10: int let b = { let c = 20 - // 20: integer + // 20: int expr:a + c - // a + c: integer + // a + c: int } - // { .. }: integer + // { .. }: int expr:b - // b: integer + // b: int } --- "#]], @@ -1012,10 +1017,10 @@ mod tests { //- /main.nail fn aaa() -> () { expr:10 - // 10: integer + // 10: int } --- - error: expected integer, actual: () + error: expected int, actual: () "#]], ); @@ -1029,7 +1034,7 @@ mod tests { //- /main.nail fn aaa() -> () { 10; - // 10: integer + // 10: int } --- "#]], @@ -1051,12 +1056,12 @@ mod tests { fn aaa() -> () { { expr:10 - // 10: integer + // 10: int }; - // { .. }: integer + // { .. }: int { 20; - // 20: integer + // 20: int }; // { .. }: () } @@ -1083,14 +1088,14 @@ mod tests { expr:{ expr:{ expr:10 - // 10: integer + // 10: int } - // { .. }: integer + // { .. }: int } - // { .. }: integer + // { .. }: int } --- - error: expected integer, actual: () + error: expected int, actual: () "#]], ); @@ -1110,7 +1115,7 @@ mod tests { expr:{ expr:{ 10; - // 10: integer + // 10: int } // { .. }: () } @@ -1136,9 +1141,9 @@ mod tests { expr:{ { expr:10 - // 10: integer + // 10: int }; - // { .. }: integer + // { .. }: int } // { .. }: () } @@ -1160,9 +1165,9 @@ mod tests { //- /main.nail fn aaa() -> int { let a = 10 - // 10: integer + // 10: int expr:a - // a: integer + // a: int } --- "#]], @@ -1182,7 +1187,7 @@ mod tests { //- /main.nail fn aaa(x: int, y: string) -> () { let a = param:x - // param:x: integer + // param:x: int let b = param:y // param:y: string } @@ -1204,20 +1209,19 @@ mod tests { } "#, expect![[r#" - fn() -> () - fn(bool, string) -> int - --- - `true`: bool - `"aaa"`: string - `aaa(true, "aaa")`: int - `res`: int - `30`: int - `res + 30`: int - `10`: int - `20`: int - `10 + 20`: int + //- /main.nail + fn entry:main() -> () { + fn aaa(x: bool, y: string) -> int { + expr:10 + 20 + // 10 + 20: int + } + let res = fn:aaa(true, "aaa") + // fn:aaa(true, "aaa"): int + expr:res + 30 + // res + 30: int + } --- - error: expected (), found int by `res + 30` + error: expected int, actual: () "#]], ); } @@ -1234,18 +1238,18 @@ mod tests { } "#, expect![[r#" - fn() -> () - fn(bool, string) -> int - --- - `"aaa"`: string - `true`: bool - `aaa("aaa", true)`: int - `10`: int - `20`: int - `10 + 20`: int + //- /main.nail + fn entry:main() -> () { + fn aaa(x: bool, y: string) -> int { + expr:10 + 20 + // 10 + 20: int + } + fn:aaa("aaa", true); + // fn:aaa("aaa", true): int + } --- - error: expected bool, found string by `"aaa"` - error: expected string, found bool by `true` + error: expected string, actual: bool + error: expected bool, actual: string "#]], ); } @@ -1263,14 +1267,17 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `true`: bool - `10`: int - `{ .., 10 }`: int - `20`: int - `{ .., 20 }`: int - `if true { .., 10 } else { .., 20 }`: int + //- /main.nail + fn entry:main() -> () { + if true { + expr:10 + // 10: int + } else { + expr:20 + // 20: int + }; + // if true { .. } else { .. }: int + } --- "#]], ); @@ -1287,15 +1294,16 @@ mod tests { } "#, expect![[r#" - fn() -> int - --- - `true`: bool - `10`: int - `{ .., 10 }`: int - `if true { .., 10 }`: unknown + //- /main.nail + fn entry:main() -> int { + expr:if true { + expr:10 + // 10: int + } + // if true { .. }: int + } --- - error: expected (), found int by `{ .., 10 }` - error: expected int, found unknown by `if true { .., 10 }` + error: expected int, actual: () "#]], ); } @@ -1311,12 +1319,13 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `true`: bool - `{{ .. }}`: () - `{{ .. }}`: () - `if true {{ .. }} else {{ .. }}`: () + //- /main.nail + fn entry:main() -> () { + expr:if true { + } else { + } + // if true { .. } else { .. }: () + } --- "#]], ); @@ -1329,11 +1338,12 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `true`: bool - `{{ .. }}`: () - `if true {{ .. }}`: () + //- /main.nail + fn entry:main() -> () { + expr:if true { + } + // if true { .. }: () + } --- "#]], ); @@ -1352,17 +1362,20 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `true`: bool - `10`: int - `{ .., 10 }`: int - `"aaa"`: string - `{ .., "aaa" }`: string - `if true { .., 10 } else { .., "aaa" }`: unknown + //- /main.nail + fn entry:main() -> () { + expr:if true { + expr:10 + // 10: int + } else { + expr:"aaa" + // "aaa": string + } + // if true { .. } else { .. }: int + } --- - error: expected int, found string by `{ .., 10 }` and `{ .., "aaa" }` - error: expected (), found unknown by `if true { .., 10 } else { .., "aaa" }` + error: expected int, actual: string + error: expected int, actual: () "#]], ); } @@ -1380,17 +1393,20 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `10`: int - `"aaa"`: string - `{ .., "aaa" }`: string - `"aaa"`: string - `{ .., "aaa" }`: string - `if 10 { .., "aaa" } else { .., "aaa" }`: string + //- /main.nail + fn entry:main() -> () { + expr:if 10 { + expr:"aaa" + // "aaa": string + } else { + expr:"aaa" + // "aaa": string + } + // if 10 { .. } else { .. }: string + } --- - error: expected bool, found int by `10` - error: expected (), found string by `if 10 { .., "aaa" } else { .., "aaa" }` + error: expected bool, actual: int + error: expected string, actual: () "#]], ); } @@ -1411,18 +1427,21 @@ mod tests { } "#, expect![[r#" - fn() -> int - --- - `true`: bool - `10`: int - `return 10`: ! - `{{ .. }}`: () - `true`: bool - `{ .., true }`: bool - `if true {{ .. }} else { .., true }`: unknown - `20`: int + //- /main.nail + fn entry:main() -> int { + let value = if true { + return 10; + // return 10: ! + } else { + expr:true + // true: bool + } + // if true { .. } else { .. }: () + expr:20 + // 20: int + } --- - error: expected (), found bool by `{{ .. }}` and `{ .., true }` + error: expected (), actual: bool "#]], ); @@ -1440,18 +1459,21 @@ mod tests { } "#, expect![[r#" - fn() -> int - --- - `true`: bool - `true`: bool - `{ .., true }`: bool - `10`: int - `return 10`: ! - `{{ .. }}`: () - `if true { .., true } else {{ .. }}`: unknown - `20`: int + //- /main.nail + fn entry:main() -> int { + let value = if true { + expr:true + // true: bool + } else { + return 10; + // return 10: ! + } + // if true { .. } else { .. }: bool + expr:20 + // 20: int + } --- - error: expected bool, found () by `{ .., true }` and `{{ .. }}` + error: expected bool, actual: () "#]], ); @@ -1469,17 +1491,19 @@ mod tests { } "#, expect![[r#" - fn() -> int - --- - `true`: bool - `10`: int - `return 10`: ! - `{{ .. }}`: () - `20`: int - `return 20`: ! - `{{ .. }}`: () - `if true {{ .. }} else {{ .. }}`: () - `30`: int + //- /main.nail + fn entry:main() -> int { + let value = if true { + return 10; + // return 10: ! + } else { + return 20; + // return 20: ! + } + // if true { .. } else { .. }: () + expr:30 + // 30: int + } --- "#]], ); @@ -1494,11 +1518,13 @@ mod tests { } "#, expect![[r#" - fn() -> () - --- - `return`: ! + //- /main.nail + fn entry:main() -> () { + expr:return + // return: ! + } --- - error: expected (), found ! by `return` + error: expected !, actual: () "#]], ); @@ -1509,12 +1535,13 @@ mod tests { } "#, expect![[r#" - fn() -> int - --- - `10`: int - `return 10`: ! + //- /main.nail + fn entry:main() -> int { + expr:return 10 + // return 10: ! + } --- - error: expected int, found ! by `return 10` + error: expected !, actual: int "#]], ); } @@ -1528,12 +1555,14 @@ mod tests { } "#, expect![[r#" - fn() -> int - --- - `return`: ! + //- /main.nail + fn entry:main() -> int { + expr:return + // return: ! + } --- - error: expected int, found () - error: expected int, found ! by `return` + error: expected (), actual: int + error: expected !, actual: int "#]], ); @@ -1544,13 +1573,14 @@ mod tests { } "#, expect![[r#" - fn() -> int - --- - `"aaa"`: string - `return "aaa"`: ! + //- /main.nail + fn entry:main() -> int { + expr:return "aaa" + // return "aaa": ! + } --- - error: expected int, found string by `"aaa"` - error: expected int, found ! by `return "aaa"` + error: expected string, actual: int + error: expected !, actual: int "#]], ); } @@ -1581,15 +1611,30 @@ mod tests { } "#, expect![[r#" - fn() -> () - fn() -> bool - fn() -> string - fn() -> int - --- - `return`: ! - `true`: bool - `"aaa"`: string - `30`: int + //- /main.nail + fn entry:main() -> () { + return; + // return: ! + } + mod module_aaa { + mod module_bbb { + fn function_aaa() -> bool { + mod module_ccc { + fn function_bbb() -> string { + expr:"aaa" + // "aaa": string + } + } + expr:true + // true: bool + } + } + + fn function_ccc() -> int { + expr:30 + // 30: int + } + } --- "#]], ); @@ -1619,13 +1664,29 @@ mod tests { } "#, expect![[r#" - fn() -> unknown - fn() -> unknown - fn() -> unknown - --- - `mod_aaa::fn_aaa()`: unknown - `mod_aaa::mod_bbb::fn_bbb()`: unknown + //- /main.nail + mod mod_aaa; + fn entry:main() -> { + fn:mod_aaa::fn_aaa(); + // fn:mod_aaa::fn_aaa(): unknown + expr:fn:mod_aaa::mod_bbb::fn_bbb() + // fn:mod_aaa::mod_bbb::fn_bbb(): unknown + } + //- /mod_aaa.nail + mod mod_bbb; + fn fn_aaa() -> { + expr:fn:mod_bbb::fn_bbb() + // fn:mod_bbb::fn_bbb(): unknown + } + + //- /mod_aaa/mod_bbb.nail + fn fn_bbb() -> { + expr:10 + // 10: int + } + --- + error: expected int, actual: unknown "#]], ); } From fd23129cfb0a4b2ee6c1690e28c36893064da639 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Thu, 14 Sep 2023 12:56:08 +0900 Subject: [PATCH 03/35] wip --- crates/hir_ty/src/lib.rs | 431 ++++++++++++--------------------------- 1 file changed, 135 insertions(+), 296 deletions(-) diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index de99eb6f..d58116e6 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -198,13 +198,10 @@ mod tests { if let Some(tail) = block.tail { let indent = indent(nesting + 1); body.push_str(&format!( - "{indent}expr:{}\n", + "{indent}expr:{}", self.debug_expr(hir_file, function, scope_origin, tail, nesting + 1) )); - body.push_str(&format!( - "{indent}{}\n", - self.debug_type_line(hir_file, function, scope_origin, tail) - )); + body.push_str(&format!(" {}\n", self.debug_type_line(function, tail))); } body.push_str(&format!("{}}}", indent(nesting))); @@ -278,10 +275,10 @@ mod tests { let name = name.text(self.db); let expr_str = self.debug_expr(hir_file, function, scope_origin, *value, nesting); - let mut stmt_str = format!("{indent}let {name} = {expr_str}\n"); + let mut stmt_str = format!("{indent}let {name} = {expr_str};"); - let type_line = self.debug_type_line(hir_file, function, scope_origin, *value); - stmt_str.push_str(&format!("{indent}{type_line}\n")); + let type_line = self.debug_type_line(function, *value); + stmt_str.push_str(&format!(" {type_line}\n")); stmt_str } @@ -292,21 +289,15 @@ mod tests { let indent = indent(nesting); let expr_str = self.debug_expr(hir_file, function, scope_origin, *expr, nesting); - let type_line = self.debug_type_line(hir_file, function, scope_origin, *expr); + let type_line = self.debug_type_line(function, *expr); let maybe_semicolon = if *has_semicolon { ";" } else { "" }; - format!("{indent}{expr_str}{maybe_semicolon}\n{indent}{type_line}\n") + format!("{indent}{expr_str}{maybe_semicolon} {type_line}\n") } hir::Stmt::Item { item } => self.debug_item(hir_file, *item, nesting), } } - fn debug_type_line( - &self, - hir_file: hir::HirFile, - function: hir::Function, - scope_origin: hir::ModuleScopeOrigin, - expr_id: hir::ExprId, - ) -> String { + fn debug_type_line(&self, function: hir::Function, expr_id: hir::ExprId) -> String { let ty = self .inference_result .inference_body_result_by_function @@ -315,33 +306,8 @@ mod tests { .type_by_expr .get(&expr_id) .unwrap(); - let expr_str = match expr_id.lookup(hir_file.db(self.db)) { - hir::Expr::Symbol(_) - | hir::Expr::Binary { .. } - | hir::Expr::Literal(_) - | hir::Expr::Unary { .. } - | hir::Expr::Call { .. } - | hir::Expr::Return { .. } - | hir::Expr::Missing => { - self.debug_expr(hir_file, function, scope_origin, expr_id, 0) - } - hir::Expr::If { - condition, - then_branch: _, - else_branch, - } => { - let condition_str = - self.debug_expr(hir_file, function, scope_origin, *condition, 0); - let mut if_str = format!("if {condition_str} {{ .. }}"); - if else_branch.is_some() { - if_str.push_str(" else { .. }"); - } - if_str - } - hir::Expr::Block(_) => "{ .. }".to_string(), - }; - format!("// {expr_str}: {ty}") + format!("//: {ty}") } fn debug_expr( @@ -412,13 +378,10 @@ mod tests { if let Some(tail) = block.tail { let indent = indent(nesting + 1); msg.push_str(&format!( - "{indent}expr:{}\n", + "{indent}expr:{}", self.debug_expr(hir_file, function, scope_origin, tail, nesting + 1) )); - msg.push_str(&format!( - "{indent}{}\n", - self.debug_type_line(hir_file, function, scope_origin, tail) - )); + msg.push_str(&format!(" {}\n", self.debug_type_line(function, tail))); } msg.push_str(&format!("{}}}", indent(nesting))); @@ -548,23 +511,17 @@ mod tests { //- /main.nail fn fibonacci(x: int) -> int { expr:if param:x == 0 { - expr:0 - // 0: int + expr:0 //: int } else { expr:if param:x == 1 { - expr:1 - // 1: int + expr:1 //: int } else { - expr:fn:fibonacci(param:x - 1) + fn:fibonacci(param:x - 2) - // fn:fibonacci(param:x - 1) + fn:fibonacci(param:x - 2): int - } - // if param:x == 1 { .. } else { .. }: int - } - // if param:x == 0 { .. } else { .. }: int + expr:fn:fibonacci(param:x - 1) + fn:fibonacci(param:x - 2) //: int + } //: int + } //: int } fn entry:main() -> int { - expr:fn:fibonacci(15) - // fn:fibonacci(15): int + expr:fn:fibonacci(15) //: int } --- "#]], @@ -582,8 +539,7 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> () { - 10; - // 10: int + 10; //: int } --- "#]], @@ -601,8 +557,7 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> () { - "aaa"; - // "aaa": string + "aaa"; //: string } --- "#]], @@ -620,8 +575,7 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> () { - 'a'; - // 'a': char + 'a'; //: char } --- "#]], @@ -639,8 +593,7 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> () { - true; - // true: bool + true; //: bool } --- "#]], @@ -655,8 +608,7 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> () { - false; - // false: bool + false; //: bool } --- "#]], @@ -674,8 +626,7 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> () { - let a = true - // true: bool + let a = true; //: bool } --- "#]], @@ -696,14 +647,10 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> () { - let a = true - // true: bool - let b = 10 - // 10: int - let c = "aa" - // "aa": string - let d = 'a' - // 'a': char + let a = true; //: bool + let b = 10; //: int + let c = "aa"; //: string + let d = 'a'; //: char } --- "#]], @@ -742,50 +689,28 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> () { - 10 + 20 - // 10 + 20: int - "aaa" + "bbb" - // "aaa" + "bbb": int - 10 + "aaa" - // 10 + "aaa": int - 'a' + 'a' - // 'a' + 'a': int - 10 + 'a' - // 10 + 'a': int - 10 < 'a' - // 10 < 'a': bool - 10 > 'a' - // 10 > 'a': bool - true + true - // true + true: int - true - true - // true - true: int - true * true - // true * true: int - true / true - // true / true: int - true == true - // true == true: bool - true < false - // true < false: bool - true > false - // true > false: bool - 10 + true - // 10 + true: int - 10 + 10 + "aaa" - // 10 + 10 + "aaa": int - 10 - 20 - // 10 - 20: int - 10 * 20 - // 10 * 20: int - 10 / 20 - // 10 / 20: int - 10 == 20 - // 10 == 20: bool - 10 < 20 - // 10 < 20: bool - 10 > 20; - // 10 > 20: bool + 10 + 20 //: int + "aaa" + "bbb" //: int + 10 + "aaa" //: int + 'a' + 'a' //: int + 10 + 'a' //: int + 10 < 'a' //: bool + 10 > 'a' //: bool + true + true //: int + true - true //: int + true * true //: int + true / true //: int + true == true //: bool + true < false //: bool + true > false //: bool + 10 + true //: int + 10 + 10 + "aaa" //: int + 10 - 20 //: int + 10 * 20 //: int + 10 / 20 //: int + 10 == 20 //: bool + 10 < 20 //: bool + 10 > 20; //: bool } --- error: expected int, actual: string @@ -818,8 +743,7 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> () { - 10 + "aaa" + 10 + "aaa"; - // 10 + "aaa" + 10 + "aaa": int + 10 + "aaa" + 10 + "aaa"; //: int } --- error: expected int, actual: string @@ -847,22 +771,14 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> () { - let a = -10 - // -10: int - let b = -"aaa" - // -"aaa": int - let c = -'a' - // -'a': int - let d = -true - // -true: int - let e = !10 - // !10: bool - let f = !"aaa" - // !"aaa": bool - let g = !'a' - // !'a': bool - let h = !true - // !true: bool + let a = -10; //: int + let b = -"aaa"; //: int + let c = -'a'; //: int + let d = -true; //: int + let e = !10; //: bool + let f = !"aaa"; //: bool + let g = !'a'; //: bool + let h = !true; //: bool } --- error: expected int, actual: string @@ -888,12 +804,9 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> bool { - let a = !true - // !true: bool - let b = !false - // !false: bool - expr:!a == !b - // !a == !b: bool + let a = !true; //: bool + let b = !false; //: bool + expr:!a == !b //: bool } --- "#]], @@ -912,10 +825,8 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> int { - let a = -10 - // -10: int - expr:a - // a: int + let a = -10; //: int + expr:a //: int } --- "#]], @@ -936,10 +847,8 @@ mod tests { //- /main.nail fn entry:main() -> () { { - expr:10 - // 10: int - }; - // { .. }: int + expr:10 //: int + }; //: int } --- "#]], @@ -961,14 +870,10 @@ mod tests { fn entry:main() -> () { { expr:{ - 10 - // 10: int - expr:"aaa" - // "aaa": string - } - // { .. }: string - }; - // { .. }: string + 10 //: int + expr:"aaa" //: string + } //: string + }; //: string } --- "#]], @@ -988,17 +893,12 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> int { - let a = 10 - // 10: int + let a = 10; //: int let b = { - let c = 20 - // 20: int - expr:a + c - // a + c: int - } - // { .. }: int - expr:b - // b: int + let c = 20; //: int + expr:a + c //: int + }; //: int + expr:b //: int } --- "#]], @@ -1016,8 +916,7 @@ mod tests { expect![[r#" //- /main.nail fn aaa() -> () { - expr:10 - // 10: int + expr:10 //: int } --- error: expected int, actual: () @@ -1033,8 +932,7 @@ mod tests { expect![[r#" //- /main.nail fn aaa() -> () { - 10; - // 10: int + 10; //: int } --- "#]], @@ -1055,15 +953,11 @@ mod tests { //- /main.nail fn aaa() -> () { { - expr:10 - // 10: int - }; - // { .. }: int + expr:10 //: int + }; //: int { - 20; - // 20: int - }; - // { .. }: () + 20; //: int + }; //: () } --- "#]], @@ -1087,12 +981,9 @@ mod tests { fn aaa() -> () { expr:{ expr:{ - expr:10 - // 10: int - } - // { .. }: int - } - // { .. }: int + expr:10 //: int + } //: int + } //: int } --- error: expected int, actual: () @@ -1114,12 +1005,9 @@ mod tests { fn aaa() -> () { expr:{ expr:{ - 10; - // 10: int - } - // { .. }: () - } - // { .. }: () + 10; //: int + } //: () + } //: () } --- "#]], @@ -1140,12 +1028,9 @@ mod tests { fn aaa() -> () { expr:{ { - expr:10 - // 10: int - }; - // { .. }: int - } - // { .. }: () + expr:10 //: int + }; //: int + } //: () } --- "#]], @@ -1164,10 +1049,8 @@ mod tests { expect![[r#" //- /main.nail fn aaa() -> int { - let a = 10 - // 10: int - expr:a - // a: int + let a = 10; //: int + expr:a //: int } --- "#]], @@ -1186,10 +1069,8 @@ mod tests { expect![[r#" //- /main.nail fn aaa(x: int, y: string) -> () { - let a = param:x - // param:x: int - let b = param:y - // param:y: string + let a = param:x; //: int + let b = param:y; //: string } --- "#]], @@ -1212,13 +1093,10 @@ mod tests { //- /main.nail fn entry:main() -> () { fn aaa(x: bool, y: string) -> int { - expr:10 + 20 - // 10 + 20: int + expr:10 + 20 //: int } - let res = fn:aaa(true, "aaa") - // fn:aaa(true, "aaa"): int - expr:res + 30 - // res + 30: int + let res = fn:aaa(true, "aaa"); //: int + expr:res + 30 //: int } --- error: expected int, actual: () @@ -1241,11 +1119,9 @@ mod tests { //- /main.nail fn entry:main() -> () { fn aaa(x: bool, y: string) -> int { - expr:10 + 20 - // 10 + 20: int + expr:10 + 20 //: int } - fn:aaa("aaa", true); - // fn:aaa("aaa", true): int + fn:aaa("aaa", true); //: int } --- error: expected string, actual: bool @@ -1270,13 +1146,10 @@ mod tests { //- /main.nail fn entry:main() -> () { if true { - expr:10 - // 10: int + expr:10 //: int } else { - expr:20 - // 20: int - }; - // if true { .. } else { .. }: int + expr:20 //: int + }; //: int } --- "#]], @@ -1297,10 +1170,8 @@ mod tests { //- /main.nail fn entry:main() -> int { expr:if true { - expr:10 - // 10: int - } - // if true { .. }: int + expr:10 //: int + } //: int } --- error: expected int, actual: () @@ -1323,8 +1194,7 @@ mod tests { fn entry:main() -> () { expr:if true { } else { - } - // if true { .. } else { .. }: () + } //: () } --- "#]], @@ -1341,8 +1211,7 @@ mod tests { //- /main.nail fn entry:main() -> () { expr:if true { - } - // if true { .. }: () + } //: () } --- "#]], @@ -1365,13 +1234,10 @@ mod tests { //- /main.nail fn entry:main() -> () { expr:if true { - expr:10 - // 10: int + expr:10 //: int } else { - expr:"aaa" - // "aaa": string - } - // if true { .. } else { .. }: int + expr:"aaa" //: string + } //: int } --- error: expected int, actual: string @@ -1396,13 +1262,10 @@ mod tests { //- /main.nail fn entry:main() -> () { expr:if 10 { - expr:"aaa" - // "aaa": string + expr:"aaa" //: string } else { - expr:"aaa" - // "aaa": string - } - // if 10 { .. } else { .. }: string + expr:"aaa" //: string + } //: string } --- error: expected bool, actual: int @@ -1430,15 +1293,11 @@ mod tests { //- /main.nail fn entry:main() -> int { let value = if true { - return 10; - // return 10: ! + return 10; //: ! } else { - expr:true - // true: bool - } - // if true { .. } else { .. }: () - expr:20 - // 20: int + expr:true //: bool + }; //: () + expr:20 //: int } --- error: expected (), actual: bool @@ -1462,15 +1321,11 @@ mod tests { //- /main.nail fn entry:main() -> int { let value = if true { - expr:true - // true: bool + expr:true //: bool } else { - return 10; - // return 10: ! - } - // if true { .. } else { .. }: bool - expr:20 - // 20: int + return 10; //: ! + }; //: bool + expr:20 //: int } --- error: expected bool, actual: () @@ -1494,15 +1349,11 @@ mod tests { //- /main.nail fn entry:main() -> int { let value = if true { - return 10; - // return 10: ! + return 10; //: ! } else { - return 20; - // return 20: ! - } - // if true { .. } else { .. }: () - expr:30 - // 30: int + return 20; //: ! + }; //: () + expr:30 //: int } --- "#]], @@ -1520,8 +1371,7 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> () { - expr:return - // return: ! + expr:return //: ! } --- error: expected !, actual: () @@ -1537,8 +1387,7 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> int { - expr:return 10 - // return 10: ! + expr:return 10 //: ! } --- error: expected !, actual: int @@ -1557,8 +1406,7 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> int { - expr:return - // return: ! + expr:return //: ! } --- error: expected (), actual: int @@ -1575,8 +1423,7 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> int { - expr:return "aaa" - // return "aaa": ! + expr:return "aaa" //: ! } --- error: expected string, actual: int @@ -1613,26 +1460,22 @@ mod tests { expect![[r#" //- /main.nail fn entry:main() -> () { - return; - // return: ! + return; //: ! } mod module_aaa { mod module_bbb { fn function_aaa() -> bool { mod module_ccc { fn function_bbb() -> string { - expr:"aaa" - // "aaa": string + expr:"aaa" //: string } } - expr:true - // true: bool + expr:true //: bool } } fn function_ccc() -> int { - expr:30 - // 30: int + expr:30 //: int } } --- @@ -1667,22 +1510,18 @@ mod tests { //- /main.nail mod mod_aaa; fn entry:main() -> { - fn:mod_aaa::fn_aaa(); - // fn:mod_aaa::fn_aaa(): unknown - expr:fn:mod_aaa::mod_bbb::fn_bbb() - // fn:mod_aaa::mod_bbb::fn_bbb(): unknown + fn:mod_aaa::fn_aaa(); //: unknown + expr:fn:mod_aaa::mod_bbb::fn_bbb() //: unknown } //- /mod_aaa.nail mod mod_bbb; fn fn_aaa() -> { - expr:fn:mod_bbb::fn_bbb() - // fn:mod_bbb::fn_bbb(): unknown + expr:fn:mod_bbb::fn_bbb() //: unknown } //- /mod_aaa/mod_bbb.nail fn fn_bbb() -> { - expr:10 - // 10: int + expr:10 //: int } --- From 94b4c92ed6fd3283c538e1cfeb793ecdf22bb011 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Thu, 14 Sep 2023 12:56:36 +0900 Subject: [PATCH 04/35] fix --- crates/hir_ty/src/lib.rs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index d58116e6..6d53cc31 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -1490,42 +1490,41 @@ mod tests { //- /main.nail mod mod_aaa; - fn main() -> integer { + fn main() -> int { mod_aaa::fn_aaa(); mod_aaa::mod_bbb::fn_bbb() } //- /mod_aaa.nail mod mod_bbb; - fn fn_aaa() -> integer { + fn fn_aaa() -> int { mod_bbb::fn_bbb() } //- /mod_aaa/mod_bbb.nail - fn fn_bbb() -> integer { + fn fn_bbb() -> int { 10 } "#, expect![[r#" //- /main.nail mod mod_aaa; - fn entry:main() -> { - fn:mod_aaa::fn_aaa(); //: unknown - expr:fn:mod_aaa::mod_bbb::fn_bbb() //: unknown + fn entry:main() -> int { + fn:mod_aaa::fn_aaa(); //: int + expr:fn:mod_aaa::mod_bbb::fn_bbb() //: int } //- /mod_aaa.nail mod mod_bbb; - fn fn_aaa() -> { - expr:fn:mod_bbb::fn_bbb() //: unknown + fn fn_aaa() -> int { + expr:fn:mod_bbb::fn_bbb() //: int } //- /mod_aaa/mod_bbb.nail - fn fn_bbb() -> { + fn fn_bbb() -> int { expr:10 //: int } --- - error: expected int, actual: unknown "#]], ); } From 8b22acceeea387adc43759342da262cc015142df Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Thu, 14 Sep 2023 19:06:18 +0900 Subject: [PATCH 05/35] wip --- crates/hir_ty/src/lib.rs | 42 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 6d53cc31..0c6ec132 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -97,6 +97,7 @@ mod tests { let mut msg = "".to_string(); msg.push_str(&self.debug_hir_file(self.pods.root_pod.root_hir_file)); + msg.push('\n'); for (_nail_file, hir_file) in self.pods.root_pod.get_hir_files_order_registration_asc() { @@ -523,6 +524,7 @@ mod tests { fn entry:main() -> int { expr:fn:fibonacci(15) //: int } + --- "#]], ); @@ -541,6 +543,7 @@ mod tests { fn entry:main() -> () { 10; //: int } + --- "#]], ); @@ -559,6 +562,7 @@ mod tests { fn entry:main() -> () { "aaa"; //: string } + --- "#]], ); @@ -577,6 +581,7 @@ mod tests { fn entry:main() -> () { 'a'; //: char } + --- "#]], ); @@ -595,6 +600,7 @@ mod tests { fn entry:main() -> () { true; //: bool } + --- "#]], ); @@ -610,6 +616,7 @@ mod tests { fn entry:main() -> () { false; //: bool } + --- "#]], ); @@ -628,6 +635,7 @@ mod tests { fn entry:main() -> () { let a = true; //: bool } + --- "#]], ) @@ -652,6 +660,7 @@ mod tests { let c = "aa"; //: string let d = 'a'; //: char } + --- "#]], ) @@ -712,6 +721,7 @@ mod tests { 10 < 20 //: bool 10 > 20; //: bool } + --- error: expected int, actual: string error: expected int, actual: string @@ -745,6 +755,7 @@ mod tests { fn entry:main() -> () { 10 + "aaa" + 10 + "aaa"; //: int } + --- error: expected int, actual: string error: expected int, actual: string @@ -780,6 +791,7 @@ mod tests { let g = !'a'; //: bool let h = !true; //: bool } + --- error: expected int, actual: string error: expected int, actual: char @@ -808,6 +820,7 @@ mod tests { let b = !false; //: bool expr:!a == !b //: bool } + --- "#]], ); @@ -828,6 +841,7 @@ mod tests { let a = -10; //: int expr:a //: int } + --- "#]], ) @@ -850,6 +864,7 @@ mod tests { expr:10 //: int }; //: int } + --- "#]], ); @@ -875,6 +890,7 @@ mod tests { } //: string }; //: string } + --- "#]], ); @@ -900,6 +916,7 @@ mod tests { }; //: int expr:b //: int } + --- "#]], ); @@ -918,6 +935,7 @@ mod tests { fn aaa() -> () { expr:10 //: int } + --- error: expected int, actual: () "#]], @@ -934,6 +952,7 @@ mod tests { fn aaa() -> () { 10; //: int } + --- "#]], ); @@ -959,6 +978,7 @@ mod tests { 20; //: int }; //: () } + --- "#]], ); @@ -985,6 +1005,7 @@ mod tests { } //: int } //: int } + --- error: expected int, actual: () "#]], @@ -1009,6 +1030,7 @@ mod tests { } //: () } //: () } + --- "#]], ); @@ -1032,6 +1054,7 @@ mod tests { }; //: int } //: () } + --- "#]], ); @@ -1052,6 +1075,7 @@ mod tests { let a = 10; //: int expr:a //: int } + --- "#]], ); @@ -1072,6 +1096,7 @@ mod tests { let a = param:x; //: int let b = param:y; //: string } + --- "#]], ); @@ -1098,6 +1123,7 @@ mod tests { let res = fn:aaa(true, "aaa"); //: int expr:res + 30 //: int } + --- error: expected int, actual: () "#]], @@ -1123,6 +1149,7 @@ mod tests { } fn:aaa("aaa", true); //: int } + --- error: expected string, actual: bool error: expected bool, actual: string @@ -1151,6 +1178,7 @@ mod tests { expr:20 //: int }; //: int } + --- "#]], ); @@ -1173,6 +1201,7 @@ mod tests { expr:10 //: int } //: int } + --- error: expected int, actual: () "#]], @@ -1196,6 +1225,7 @@ mod tests { } else { } //: () } + --- "#]], ); @@ -1213,6 +1243,7 @@ mod tests { expr:if true { } //: () } + --- "#]], ); @@ -1239,6 +1270,7 @@ mod tests { expr:"aaa" //: string } //: int } + --- error: expected int, actual: string error: expected int, actual: () @@ -1267,6 +1299,7 @@ mod tests { expr:"aaa" //: string } //: string } + --- error: expected bool, actual: int error: expected string, actual: () @@ -1299,6 +1332,7 @@ mod tests { }; //: () expr:20 //: int } + --- error: expected (), actual: bool "#]], @@ -1327,6 +1361,7 @@ mod tests { }; //: bool expr:20 //: int } + --- error: expected bool, actual: () "#]], @@ -1355,6 +1390,7 @@ mod tests { }; //: () expr:30 //: int } + --- "#]], ); @@ -1373,6 +1409,7 @@ mod tests { fn entry:main() -> () { expr:return //: ! } + --- error: expected !, actual: () "#]], @@ -1389,6 +1426,7 @@ mod tests { fn entry:main() -> int { expr:return 10 //: ! } + --- error: expected !, actual: int "#]], @@ -1408,6 +1446,7 @@ mod tests { fn entry:main() -> int { expr:return //: ! } + --- error: expected (), actual: int error: expected !, actual: int @@ -1425,6 +1464,7 @@ mod tests { fn entry:main() -> int { expr:return "aaa" //: ! } + --- error: expected string, actual: int error: expected !, actual: int @@ -1478,6 +1518,7 @@ mod tests { expr:30 //: int } } + --- "#]], ); @@ -1513,6 +1554,7 @@ mod tests { fn:mod_aaa::fn_aaa(); //: int expr:fn:mod_aaa::mod_bbb::fn_bbb() //: int } + //- /mod_aaa.nail mod mod_bbb; fn fn_aaa() -> int { From 6ca03f3ca09651ffb9b761279d1c60d992289699 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Thu, 14 Sep 2023 19:07:18 +0900 Subject: [PATCH 06/35] wip --- crates/hir_ty/src/lib.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 0c6ec132..c688bdf0 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -1,4 +1,4 @@ -//! HIRに型付けを行います。 +//! HIRに型付けを行います。型推論はHindley-Milner型推論ベースで行います。 //! Typed HIRと呼びます。 //! //! 以下のステップで行います。 @@ -9,7 +9,6 @@ //! AST -----> HIR -------------------------------> MIR -----> LLVM IR //! \-----> TypedHIR(このcrate) ---/ //! -//! 現時点の型推論は簡易なもので、Hindley-Milner型推論ベースに変更する予定です。 #![feature(trait_upcasting)] // #[salsa::tracked]で生成される関数にドキュメントコメントが作成されないため警告が出てしまうため許可します。 From eddb247b9848c22ec95b298c106d709bd2f560ac Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Thu, 14 Sep 2023 19:16:41 +0900 Subject: [PATCH 07/35] wip --- crates/hir_ty/src/inference/type_unifier.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/hir_ty/src/inference/type_unifier.rs b/crates/hir_ty/src/inference/type_unifier.rs index cbc114d2..d352f8a3 100644 --- a/crates/hir_ty/src/inference/type_unifier.rs +++ b/crates/hir_ty/src/inference/type_unifier.rs @@ -42,6 +42,8 @@ impl TypeUnifier { return_type: b_return_type, }, ) => { + self.unify(a_return_type, b_return_type); + if a_args.len() != b_args.len() { self.errors.push(InferenceError::TypeMismatch { expected: a_rep, @@ -53,8 +55,6 @@ impl TypeUnifier { for (a_arg, b_arg) in a_args.iter().zip(b_args.iter()) { self.unify(a_arg, b_arg); } - - self.unify(a_return_type, b_return_type); } (Monotype::Variable(_), b_rep) => self.unify_var(&a_rep, b_rep), (a_rep, Monotype::Variable(_)) => self.unify_var(&b_rep, a_rep), From 36d6e5bbc06fcce5e2b53c745cfd4280664f5160 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Fri, 15 Sep 2023 21:48:19 +0900 Subject: [PATCH 08/35] wip --- crates/hir_ty/src/inference.rs | 4 +- crates/hir_ty/src/inference/environment.rs | 152 +++++++-- crates/hir_ty/src/inference/error.rs | 106 ++++++ crates/hir_ty/src/inference/type_unifier.rs | 196 +++++++++-- crates/hir_ty/src/inference/types.rs | 45 ++- crates/hir_ty/src/lib.rs | 353 ++++++++++++++++---- 6 files changed, 705 insertions(+), 151 deletions(-) create mode 100644 crates/hir_ty/src/inference/error.rs diff --git a/crates/hir_ty/src/inference.rs b/crates/hir_ty/src/inference.rs index bbb7b485..772a05a8 100644 --- a/crates/hir_ty/src/inference.rs +++ b/crates/hir_ty/src/inference.rs @@ -1,4 +1,5 @@ mod environment; +mod error; mod type_scheme; mod type_unifier; mod types; @@ -6,7 +7,8 @@ mod types; use std::collections::HashMap; use environment::{Environment, InferBody}; -pub use environment::{InferenceBodyResult, InferenceError, InferenceResult, Signature}; +pub use environment::{InferenceBodyResult, InferenceResult, Signature}; +pub use error::InferenceError; pub use type_scheme::TypeScheme; pub use types::Monotype; diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index 491b0807..0f2fb12b 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -3,7 +3,12 @@ use std::{ ops::Sub, }; -use super::{type_scheme::TypeScheme, type_unifier::TypeUnifier, types::Monotype}; +use super::{ + error::InferenceError, + type_scheme::TypeScheme, + type_unifier::{TypeUnifier, UnifyPurpose}, + types::Monotype, +}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Signature { @@ -55,13 +60,27 @@ impl<'a> InferBody<'a> { self.infer_stmt(stmt); } - let ty = if let Some(tail) = &body.tail { - self.infer_expr(*tail) + if let Some(tail) = &body.tail { + let ty = self.infer_expr(*tail); + self.unifier.unify( + &ty, + &self.signature.return_type, + &UnifyPurpose::SelfReturnType { + expected_signature: self.signature.clone(), + found_expr: Some(*tail), + }, + ); } else { - Monotype::Unit + let ty = Monotype::Unit; + self.unifier.unify( + &ty, + &self.signature.return_type, + &UnifyPurpose::SelfReturnType { + expected_signature: self.signature.clone(), + found_expr: None, + }, + ); }; - dbg!(&self.signature.return_type); - self.unifier.unify(&ty, &self.signature.return_type); InferenceBodyResult { type_by_expr: self.type_by_expr, @@ -123,10 +142,7 @@ impl<'a> InferBody<'a> { hir::Item::Function(function) => { let signature = self.signature_by_function.get(&function); if let Some(signature) = signature { - Monotype::Function { - args: signature.params.clone(), - return_type: Box::new(signature.return_type.clone()), - } + Monotype::Function(signature.clone().into()) } else { unreachable!("Function signature should be resolved.") } @@ -174,21 +190,30 @@ impl<'a> InferBody<'a> { // TODO: 関数ではないものを呼び出そうとしているエラーを追加 Monotype::Unknown } - Monotype::Function { args, return_type } => { + Monotype::Function(signature) => { let call_args_ty = call_args .iter() .map(|arg| self.infer_expr(*arg)) .collect::>(); - if call_args_ty.len() != args.len() { + if call_args_ty.len() != signature.params.len() { // TODO: 引数の数が異なるエラーを追加 Monotype::Unknown } else { - for (call_arg, arg) in call_args_ty.iter().zip(args.iter()) { - self.unifier.unify(call_arg, arg); + for ((call_arg, call_arg_ty), arg) in + call_args.iter().zip(call_args_ty).zip(&signature.params) + { + self.unifier.unify( + &call_arg_ty, + &arg, + &UnifyPurpose::CallArg { + found_arg: *call_arg, + expected_signature: signature.as_ref().clone(), + }, + ); } - *return_type + signature.return_type.clone() } } } @@ -200,8 +225,22 @@ impl<'a> InferBody<'a> { | ast::BinaryOp::Div(_) => { let lhs_ty = self.infer_expr(*lhs); let rhs_ty = self.infer_expr(*rhs); - self.unifier.unify(&Monotype::Integer, &lhs_ty); - self.unifier.unify(&Monotype::Integer, &rhs_ty); + self.unifier.unify( + &Monotype::Integer, + &lhs_ty, + &UnifyPurpose::BinaryInteger { + found_expr: *lhs, + op: op.clone(), + }, + ); + self.unifier.unify( + &Monotype::Integer, + &rhs_ty, + &UnifyPurpose::BinaryInteger { + found_expr: *rhs, + op: op.clone(), + }, + ); Monotype::Integer } @@ -210,7 +249,15 @@ impl<'a> InferBody<'a> { | ast::BinaryOp::LessThan(_) => { let lhs_ty = self.infer_expr(*lhs); let rhs_ty = self.infer_expr(*rhs); - self.unifier.unify(&lhs_ty, &rhs_ty); + self.unifier.unify( + &lhs_ty, + &rhs_ty, + &UnifyPurpose::BinaryCompare { + expected_expr: *lhs, + found_expr: *rhs, + op: op.clone(), + }, + ); Monotype::Bool } @@ -218,13 +265,27 @@ impl<'a> InferBody<'a> { hir::Expr::Unary { op, expr } => match op { ast::UnaryOp::Neg(_) => { let expr_ty = self.infer_expr(*expr); - self.unifier.unify(&Monotype::Integer, &expr_ty); + self.unifier.unify( + &Monotype::Integer, + &expr_ty, + &UnifyPurpose::Unary { + found_expr: *expr, + op: op.clone(), + }, + ); Monotype::Integer } ast::UnaryOp::Not(_) => { let expr_ty = self.infer_expr(*expr); - self.unifier.unify(&Monotype::Bool, &expr_ty); + self.unifier.unify( + &Monotype::Bool, + &expr_ty, + &UnifyPurpose::Unary { + found_expr: *expr, + op: op.clone(), + }, + ); Monotype::Bool } @@ -253,15 +314,34 @@ impl<'a> InferBody<'a> { else_branch, } => { let condition_ty = self.infer_expr(*condition); - self.unifier.unify(&Monotype::Bool, &condition_ty); + self.unifier.unify( + &Monotype::Bool, + &condition_ty, + &UnifyPurpose::IfCondition { + found_expr: *condition, + }, + ); let then_ty = self.infer_expr(*then_branch); if let Some(else_branch) = else_branch { let else_ty = self.infer_expr(*else_branch); - self.unifier.unify(&then_ty, &else_ty); + self.unifier.unify( + &then_ty, + &else_ty, + &UnifyPurpose::IfThenElseBranch { + expected_expr: *then_branch, + found_expr: *else_branch, + }, + ); } else { // elseブランチがない場合は Unit として扱う - self.unifier.unify(&then_ty, &Monotype::Unit); + self.unifier.unify( + &then_ty, + &Monotype::Unit, + &UnifyPurpose::IfThenOnlyBranch { + found_expr: *then_branch, + }, + ); } then_ty @@ -269,11 +349,24 @@ impl<'a> InferBody<'a> { hir::Expr::Return { value } => { if let Some(return_value) = value { let ty = self.infer_expr(*return_value); - self.unifier.unify(&ty, &self.signature.return_type); + self.unifier.unify( + &ty, + &self.signature.return_type, + &UnifyPurpose::ReturnValue { + expected_signature: self.signature.clone(), + found_expr: Some(*return_value), + }, + ); } else { // 何も指定しない場合は Unit を返すものとして扱う - self.unifier - .unify(&Monotype::Unit, &self.signature.return_type); + self.unifier.unify( + &Monotype::Unit, + &self.signature.return_type, + &UnifyPurpose::ReturnValue { + expected_signature: self.signature.clone(), + found_expr: None, + }, + ); } // return自体の戻り値は Never として扱う @@ -315,13 +408,6 @@ pub struct InferenceBodyResult { pub type_by_expr: HashMap, pub errors: Vec, } -#[derive(Debug)] -pub enum InferenceError { - TypeMismatch { - expected: Monotype, - actual: Monotype, - }, -} #[derive(Default)] pub struct Environment { diff --git a/crates/hir_ty/src/inference/error.rs b/crates/hir_ty/src/inference/error.rs new file mode 100644 index 00000000..7f0f157e --- /dev/null +++ b/crates/hir_ty/src/inference/error.rs @@ -0,0 +1,106 @@ +use super::{Monotype, Signature}; + +/// 型チェックのエラー +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum InferenceError { + /// 型を解決できない + Unresolved { + /// 対象の式 + expr: hir::ExprId, + }, + /// 一致するべき型が一致しない + MismatchedTypes { + /// 期待される型 + expected_ty: Monotype, + /// 実際の型 + found_ty: Monotype, + + /// 期待される式 + expected_expr: hir::ExprId, + /// 実際の式 + found_expr: hir::ExprId, + }, + /// Ifの条件式の型が一致しない + MismatchedTypeIfCondition { + /// 期待される型 + expected_ty: Monotype, + /// 実際の式 + found_expr: hir::ExprId, + /// 実際の型 + found_ty: Monotype, + }, + /// Ifのthenブランチとelseブランチの型が一致しない + MismatchedTypeElseBranch { + /// ifブランチの型 + then_branch_ty: Monotype, + /// ifブランチの式 + then_branch: hir::ExprId, + /// elseブランチの型 + else_branch_ty: Monotype, + /// elseブランチの式 + else_branch: hir::ExprId, + }, + /// Ifのthenブランチのみの型が一致しない + MismatchedTypeOnlyIfBranch { + /// thenブランチの型 + then_branch_ty: Monotype, + /// thenブランチの式 + then_branch: hir::ExprId, + }, + /// 関数呼び出しの引数の数が一致しない + MismaatchedSignature { + /// 期待される型 + expected_ty: Monotype, + /// 呼び出そうとしている関数のシグネチャ + signature: Signature, + /// 実際の式 + found_expr: hir::ExprId, + /// 実際の型 + found_ty: Monotype, + }, + MismatchedBinaryInteger { + /// 期待される型 + expected_ty: Monotype, + /// 実際の式 + found_expr: hir::ExprId, + /// 実際の型 + found_ty: Monotype, + /// 演算子 + op: ast::BinaryOp, + }, + MismatchedBinaryCompare { + /// 期待される型 + expected_ty: Monotype, + /// 期待される型を持つ式 + expected_expr: hir::ExprId, + /// 実際の式 + found_expr: hir::ExprId, + /// 実際の型 + found_ty: Monotype, + /// 演算子 + op: ast::BinaryOp, + }, + MismatchedUnary { + /// 期待される型 + expected_ty: Monotype, + /// 実際の式 + found_expr: hir::ExprId, + /// 実際の型 + found_ty: Monotype, + /// 演算子 + op: ast::UnaryOp, + }, + /// 関数の戻り値の型と実際の戻り値の型が異なる + /// + /// 以下のいずれかが関数の戻り値の型と一致しない場合に発生する + /// - `return`に指定した式の型 + /// - 関数ボディの最後の式の型 + MismatchedTypeReturnValue { + /// 期待される型 + expected_signature: Signature, + /// 実際の型 + found_ty: Monotype, + /// 実際の式 + found_expr: Option, + }, +} diff --git a/crates/hir_ty/src/inference/type_unifier.rs b/crates/hir_ty/src/inference/type_unifier.rs index d352f8a3..396f589c 100644 --- a/crates/hir_ty/src/inference/type_unifier.rs +++ b/crates/hir_ty/src/inference/type_unifier.rs @@ -1,11 +1,158 @@ use std::collections::HashMap; -use super::{environment::InferenceError, types::Monotype}; +use super::{error::InferenceError, types::Monotype, Signature}; #[derive(Default, Debug)] -pub struct TypeUnifier { +pub(crate) struct TypeUnifier { pub(crate) nodes: HashMap, - pub(crate) errors: Vec, + pub(crate) errors: Vec, +} + +/// 型推論の目的 +pub(crate) enum UnifyPurpose { + Expr { + /// 期待する型を持つ式 + expected_expr: hir::ExprId, + /// 実際に得られた型を持つ式 + found_expr: hir::ExprId, + }, + CallArg { + /// 関数呼び出し対象のシグネチャ + expected_signature: Signature, + /// 実際に得られた型を持つ式 + found_arg: hir::ExprId, + }, + SelfReturnType { + expected_signature: Signature, + found_expr: Option, + }, + BinaryInteger { + /// 実際に得られた型を持つ式 + found_expr: hir::ExprId, + /// 演算子 + op: ast::BinaryOp, + }, + BinaryCompare { + /// 期待する型を持つ式 + expected_expr: hir::ExprId, + /// 実際に得られた型を持つ式 + found_expr: hir::ExprId, + /// 演算子 + op: ast::BinaryOp, + }, + Unary { + /// 実際に得られた型を持つ式 + found_expr: hir::ExprId, + /// 演算子 + op: ast::UnaryOp, + }, + IfCondition { + /// 実際に得られた型を持つ式 + found_expr: hir::ExprId, + }, + IfThenElseBranch { + /// 期待する型を持つ式 + expected_expr: hir::ExprId, + /// 実際に得られた型を持つ式 + found_expr: hir::ExprId, + }, + IfThenOnlyBranch { + /// 実際に得られた型を持つ式 + found_expr: hir::ExprId, + }, + ReturnValue { + /// 期待する戻り値の型を持つ関数シグネチャ + expected_signature: Signature, + /// 実際に得られた型を持つ式 + found_expr: Option, + }, +} + +/// 型の不一致を表すエラーを生成します。 +fn build_unify_error_from_unify_purpose( + expected_ty: Monotype, + found_ty: Monotype, + purpose: &UnifyPurpose, +) -> InferenceError { + match purpose { + UnifyPurpose::Expr { + expected_expr, + found_expr, + } => InferenceError::MismatchedTypes { + expected_ty, + found_ty, + expected_expr: *expected_expr, + found_expr: *found_expr, + }, + UnifyPurpose::CallArg { + found_arg, + expected_signature, + } => InferenceError::MismaatchedSignature { + expected_ty, + found_ty, + signature: expected_signature.clone(), + found_expr: *found_arg, + }, + UnifyPurpose::SelfReturnType { + expected_signature, + found_expr, + } => InferenceError::MismatchedTypeReturnValue { + expected_signature: expected_signature.clone(), + found_ty, + found_expr: *found_expr, + }, + UnifyPurpose::BinaryInteger { found_expr, op } => InferenceError::MismatchedBinaryInteger { + expected_ty, + found_ty, + found_expr: *found_expr, + op: op.clone(), + }, + UnifyPurpose::BinaryCompare { + expected_expr, + found_expr, + op, + } => InferenceError::MismatchedBinaryCompare { + expected_ty, + found_ty, + expected_expr: *expected_expr, + found_expr: *found_expr, + op: op.clone(), + }, + UnifyPurpose::Unary { found_expr, op } => InferenceError::MismatchedUnary { + expected_ty, + found_ty, + found_expr: *found_expr, + op: op.clone(), + }, + UnifyPurpose::IfCondition { found_expr } => InferenceError::MismatchedTypeIfCondition { + expected_ty, + found_ty, + found_expr: *found_expr, + }, + UnifyPurpose::IfThenElseBranch { + expected_expr, + found_expr, + } => InferenceError::MismatchedTypeElseBranch { + then_branch_ty: expected_ty, + then_branch: *expected_expr, + else_branch_ty: found_ty, + else_branch: *found_expr, + }, + UnifyPurpose::IfThenOnlyBranch { found_expr } => { + InferenceError::MismatchedTypeOnlyIfBranch { + then_branch_ty: found_ty, + then_branch: *found_expr, + } + } + UnifyPurpose::ReturnValue { + expected_signature, + found_expr, + } => InferenceError::MismatchedTypeReturnValue { + found_ty, + found_expr: *found_expr, + expected_signature: expected_signature.clone(), + }, + } } impl TypeUnifier { @@ -23,7 +170,7 @@ impl TypeUnifier { } } - pub fn unify(&mut self, a: &Monotype, b: &Monotype) { + pub fn unify(&mut self, a: &Monotype, b: &Monotype, purpose: &UnifyPurpose) { let a_rep = self.find(a); let b_rep = self.find(b); @@ -32,37 +179,24 @@ impl TypeUnifier { } match (&a_rep, &b_rep) { - ( - Monotype::Function { - args: a_args, - return_type: a_return_type, - }, - Monotype::Function { - args: b_args, - return_type: b_return_type, - }, - ) => { - self.unify(a_return_type, b_return_type); - - if a_args.len() != b_args.len() { - self.errors.push(InferenceError::TypeMismatch { - expected: a_rep, - actual: b_rep, - }); - return; - } - - for (a_arg, b_arg) in a_args.iter().zip(b_args.iter()) { - self.unify(a_arg, b_arg); - } + (Monotype::Function(_a_signature), Monotype::Function(_b_signature)) => { + unreachable!(); + // self.unify(&a_signature.return_type, &b_signature.return_type, purpose); + + // if a_signature.params.len() != b_signature.params.len() { + // // エラーを追加する + // return; + // } + + // for (a_arg, b_arg) in a_signature.params.iter().zip(b_signature.params.iter()) { + // self.unify(a_arg, b_arg, purpose); + // } } (Monotype::Variable(_), b_rep) => self.unify_var(&a_rep, b_rep), (a_rep, Monotype::Variable(_)) => self.unify_var(&b_rep, a_rep), (_, _) => { - self.errors.push(InferenceError::TypeMismatch { - expected: a_rep, - actual: b_rep, - }); + self.errors + .push(build_unify_error_from_unify_purpose(a_rep, b_rep, purpose)); } } } diff --git a/crates/hir_ty/src/inference/types.rs b/crates/hir_ty/src/inference/types.rs index a7bc6cba..636a38b2 100644 --- a/crates/hir_ty/src/inference/types.rs +++ b/crates/hir_ty/src/inference/types.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, fmt}; -use super::{environment::Context, type_scheme::TypeSubstitution}; +use super::{environment::Context, type_scheme::TypeSubstitution, Signature}; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Monotype { @@ -10,10 +10,7 @@ pub enum Monotype { Char, String, Variable(u32), - Function { - args: Vec, - return_type: Box, - }, + Function(Box), Never, Unknown, } @@ -29,19 +26,18 @@ impl fmt::Display for Monotype { Monotype::Never => write!(f, "!"), Monotype::Unknown => write!(f, "unknown"), Monotype::Variable(id) => write!(f, "{}", id), - Monotype::Function { - args: from, - return_type: to, - } => { + Monotype::Function(signature) => { write!( f, "({}) -> {}", - from.iter() + signature + .params + .iter() .map(|ty| ty.to_string()) .collect::>() .join(", ") .to_string(), - to.to_string() + signature.return_type.to_string() ) } } @@ -63,15 +59,12 @@ impl Monotype { set } - Monotype::Function { - args: from, - return_type: to, - } => { + Monotype::Function(signature) => { let mut set = HashSet::new(); - for arg in from.iter() { + for arg in signature.params.iter() { set.extend(arg.free_variables()); } - set.extend(to.free_variables()); + set.extend(signature.return_type.free_variables()); set } _ => Default::default(), @@ -94,13 +87,17 @@ impl Monotype { self.clone() } } - Monotype::Function { - args: from, - return_type: to, - } => Monotype::Function { - args: from.iter().map(|arg| arg.apply(subst)).collect::>(), - return_type: Box::new(to.apply(subst)), - }, + Monotype::Function(signagure) => Monotype::Function( + Signature { + params: signagure + .params + .iter() + .map(|arg| arg.apply(subst)) + .collect::>(), + return_type: signagure.return_type.apply(subst), + } + .into(), + ), } } } diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index c688bdf0..f0275e4f 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -51,7 +51,10 @@ impl TyLowerResult { mod tests { use expect_test::{expect, Expect}; - use crate::{inference::infer_pods, InferenceResult}; + use crate::{ + inference::{infer_pods, InferenceError, Signature}, + InferenceResult, + }; fn check_pod_start_with_root_file(fixture: &str, expect: Expect) { let db = hir::TestingDatabase::default(); @@ -105,7 +108,7 @@ mod tests { } msg.push_str("---\n"); - for (_hir_file, function) in self.pods.root_pod.all_functions(self.db) { + for (hir_file, function) in self.pods.root_pod.all_functions(self.db) { let inference_body_result = self .inference_result .inference_body_result_by_function @@ -114,18 +117,163 @@ mod tests { for error in &inference_body_result.errors { match error { - crate::inference::InferenceError::TypeMismatch { expected, actual } => { - msg.push_str(&format!( - "error: expected {expected}, actual: {actual}\n" + InferenceError::Unresolved { expr } => todo!(), + InferenceError::MismatchedTypes { + expected_ty, + found_ty, + expected_expr, + found_expr, + } => { + msg.push_str( + &format!( + "error MismatchedTypes: expected_ty: {}, found_ty: {}, expected_expr: {}, found_expr: {}", + expected_ty, + found_ty, + self.debug_simplify_expr(hir_file, *expected_expr), + self.debug_simplify_expr(hir_file, *found_expr), )); } + InferenceError::MismatchedTypeIfCondition { + expected_ty, + found_expr, + found_ty, + } => { + msg.push_str( + &format!( + "error MismatchedTypeIfCondition: expected_ty: {}, found_expr: {}, found_ty: {}", + expected_ty, + self.debug_simplify_expr(hir_file, *found_expr), + found_ty, + )); + } + InferenceError::MismatchedTypeElseBranch { + then_branch_ty, + then_branch, + else_branch_ty, + else_branch, + } => { + msg.push_str( + &format!( + "error MismatchedTypeElseBranch: then_branch_ty: {}, then_branch: {}, else_branch_ty: {}, else_branch: {}", + then_branch_ty, + self.debug_simplify_expr(hir_file, *then_branch), + else_branch_ty, + self.debug_simplify_expr(hir_file, *else_branch), + )); + } + InferenceError::MismatchedTypeOnlyIfBranch { + then_branch_ty, + then_branch, + } => { + msg.push_str( + &format!( + "error MismatchedTypeOnlyIfBranch: then_branch_ty: {}, then_branch: {}", + then_branch_ty, + self.debug_simplify_expr(hir_file, *then_branch), + )); + } + InferenceError::MismaatchedSignature { + expected_ty, + signature, + found_expr, + found_ty, + } => { + msg.push_str( + &format!( + "error MismaatchedSignature: expected_ty: {}, signature: {}, found_expr: {}, found_ty: {}", + expected_ty, + self.debug_signature(signature), + self.debug_simplify_expr(hir_file, *found_expr), + found_ty, + )); + } + InferenceError::MismatchedBinaryInteger { + expected_ty, + found_expr, + found_ty, + op, + } => { + msg.push_str( + &format!( + "error MismatchedBinaryInteger: expected_ty: {}, found_expr: {}, found_ty: {}, op: {}", + expected_ty, + self.debug_simplify_expr(hir_file, *found_expr), + found_ty, + self.debug_binary_op(op), + )); + } + InferenceError::MismatchedBinaryCompare { + expected_ty, + expected_expr, + found_expr, + found_ty, + op, + } => { + msg.push_str( + &format!( + "error MismatchedBinaryCompare: expected_ty: {}, expected_expr: {}, found_expr: {}, found_ty: {}, op: {}", + expected_ty, + self.debug_simplify_expr(hir_file, *expected_expr), + self.debug_simplify_expr(hir_file, *found_expr), + found_ty, + self.debug_binary_op(op), + )); + } + InferenceError::MismatchedUnary { + expected_ty, + found_expr, + found_ty, + op, + } => { + msg.push_str( + &format!( + "error MismatchedUnary: expected_ty: {}, found_expr: {}, found_ty: {}, op: {}", + expected_ty, + self.debug_simplify_expr(hir_file, *found_expr), + found_ty, + self.debug_unary_op(op), + )); + } + InferenceError::MismatchedTypeReturnValue { + expected_signature, + found_ty, + found_expr, + } => { + if let Some(found_expr) = found_expr { + msg.push_str( + &format!( + "error MismatchedReturnType: expected_ty: {}, found_expr: {}, found_ty: {}", + expected_signature.return_type, + self.debug_simplify_expr(hir_file, *found_expr), + found_ty, + )); + } else { + msg.push_str(&format!( + "error MismatchedReturnType: expected_ty: {}, found_ty: {}", + expected_signature.return_type, found_ty, + )); + } + } } + msg.push('\n'); } } msg } + fn debug_signature(&self, signature: &Signature) -> String { + let params = signature + .params + .iter() + .map(|param| param.to_string()) + .collect::>() + .join(", "); + let return_type = signature.return_type.to_string(); + + format!("({params}) -> {return_type}") + } + fn debug_hir_file(&self, hir_file: hir::HirFile) -> String { let mut msg = format!( "//- {}\n", @@ -319,9 +467,7 @@ mod tests { nesting: usize, ) -> String { match expr_id.lookup(hir_file.db(self.db)) { - hir::Expr::Symbol(symbol) => { - self.debug_symbol(hir_file, function, scope_origin, symbol, nesting) - } + hir::Expr::Symbol(symbol) => self.debug_symbol(symbol), hir::Expr::Literal(literal) => match literal { hir::Literal::Bool(b) => b.to_string(), hir::Literal::Char(c) => format!("'{c}'"), @@ -352,8 +498,7 @@ mod tests { format!("{op}{expr_str}") } hir::Expr::Call { callee, args } => { - let callee = - self.debug_symbol(hir_file, function, scope_origin, callee, nesting); + let callee = self.debug_symbol(callee); let args = args .iter() .map(|arg| self.debug_expr(hir_file, function, scope_origin, *arg, nesting)) @@ -437,14 +582,77 @@ mod tests { } } - fn debug_symbol( - &self, - _hir_file: hir::HirFile, - _function: hir::Function, - _scope_origin: hir::ModuleScopeOrigin, - symbol: &hir::Symbol, - _nesting: usize, - ) -> String { + fn debug_simplify_expr(&self, hir_file: hir::HirFile, expr_id: hir::ExprId) -> String { + match expr_id.lookup(hir_file.db(self.db)) { + hir::Expr::Symbol(symbol) => self.debug_symbol(symbol), + hir::Expr::Literal(literal) => match literal { + hir::Literal::Bool(b) => b.to_string(), + hir::Literal::Char(c) => format!("'{c}'"), + hir::Literal::String(s) => format!("\"{s}\""), + hir::Literal::Integer(i) => i.to_string(), + }, + hir::Expr::Binary { op, lhs, rhs } => { + let op = match op { + ast::BinaryOp::Add(_) => "+", + ast::BinaryOp::Sub(_) => "-", + ast::BinaryOp::Mul(_) => "*", + ast::BinaryOp::Div(_) => "/", + ast::BinaryOp::Equal(_) => "==", + ast::BinaryOp::GreaterThan(_) => ">", + ast::BinaryOp::LessThan(_) => "<", + }; + let lhs_str = self.debug_simplify_expr(hir_file, *lhs); + let rhs_str = self.debug_simplify_expr(hir_file, *rhs); + format!("{lhs_str} {op} {rhs_str}") + } + hir::Expr::Unary { op, expr } => { + let op = match op { + ast::UnaryOp::Neg(_) => "-", + ast::UnaryOp::Not(_) => "!", + }; + let expr_str = self.debug_simplify_expr(hir_file, *expr); + format!("{op}{expr_str}") + } + hir::Expr::Call { callee, args } => { + let callee = self.debug_symbol(callee); + let args = args + .iter() + .map(|arg| self.debug_simplify_expr(hir_file, *arg)) + .collect::>() + .join(", "); + + format!("{callee}({args})") + } + hir::Expr::Block(block) => "{{ .. }}".to_string(), + hir::Expr::If { + condition, + then_branch, + else_branch, + } => { + let cond_str = self.debug_simplify_expr(hir_file, *condition); + let mut msg = format!("if {cond_str} {{ .. }}"); + if else_branch.is_some() { + msg.push_str(" else { .. }"); + } + + msg + } + hir::Expr::Return { value } => { + let mut msg = "return".to_string(); + if let Some(value) = value { + msg.push_str(&format!( + " {}", + &self.debug_simplify_expr(hir_file, *value,) + )); + } + + msg + } + hir::Expr::Missing => "".to_string(), + } + } + + fn debug_symbol(&self, symbol: &hir::Symbol) -> String { match &symbol { hir::Symbol::Local { name, expr: _ } => name.text(self.db).to_string(), hir::Symbol::Param { name, .. } => { @@ -458,6 +666,27 @@ mod tests { } } + fn debug_binary_op(&self, op: &ast::BinaryOp) -> String { + match op { + ast::BinaryOp::Add(_) => "+", + ast::BinaryOp::Sub(_) => "-", + ast::BinaryOp::Mul(_) => "*", + ast::BinaryOp::Div(_) => "/", + ast::BinaryOp::Equal(_) => "==", + ast::BinaryOp::GreaterThan(_) => ">", + ast::BinaryOp::LessThan(_) => "<", + } + .to_string() + } + + fn debug_unary_op(&self, op: &ast::UnaryOp) -> String { + match op { + ast::UnaryOp::Neg(_) => "-", + ast::UnaryOp::Not(_) => "!", + } + .to_string() + } + fn debug_resolution_status(&self, resolution_status: hir::ResolutionStatus) -> String { match resolution_status { hir::ResolutionStatus::Unresolved => "".to_string(), @@ -722,24 +951,24 @@ mod tests { } --- - error: expected int, actual: string - error: expected int, actual: string - error: expected int, actual: string - error: expected int, actual: char - error: expected int, actual: char - error: expected int, actual: char - error: expected int, actual: char - error: expected int, actual: char - error: expected int, actual: bool - error: expected int, actual: bool - error: expected int, actual: bool - error: expected int, actual: bool - error: expected int, actual: bool - error: expected int, actual: bool - error: expected int, actual: bool - error: expected int, actual: bool - error: expected int, actual: bool - error: expected int, actual: string + error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + + error MismatchedBinaryInteger: expected_ty: int, found_expr: "bbb", found_ty: string, op: + + error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + + error MismatchedBinaryInteger: expected_ty: int, found_expr: 'a', found_ty: char, op: + + error MismatchedBinaryInteger: expected_ty: int, found_expr: 'a', found_ty: char, op: + + error MismatchedBinaryInteger: expected_ty: int, found_expr: 'a', found_ty: char, op: + + error MismatchedBinaryCompare: expected_ty: int, expected_expr: 10, found_expr: 'a', found_ty: char, op: < + error MismatchedBinaryCompare: expected_ty: int, expected_expr: 10, found_expr: 'a', found_ty: char, op: > + error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: + + error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: + + error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: - + error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: - + error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: * + error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: * + error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: / + error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: / + error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: + + error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + "#]], ); @@ -756,8 +985,8 @@ mod tests { } --- - error: expected int, actual: string - error: expected int, actual: string + error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + + error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + "#]], ); } @@ -792,12 +1021,12 @@ mod tests { } --- - error: expected int, actual: string - error: expected int, actual: char - error: expected int, actual: bool - error: expected bool, actual: int - error: expected bool, actual: string - error: expected bool, actual: char + error MismatchedUnary: expected_ty: int, found_expr: "aaa", found_ty: string, op: - + error MismatchedUnary: expected_ty: int, found_expr: 'a', found_ty: char, op: - + error MismatchedUnary: expected_ty: int, found_expr: true, found_ty: bool, op: - + error MismatchedUnary: expected_ty: bool, found_expr: 10, found_ty: int, op: ! + error MismatchedUnary: expected_ty: bool, found_expr: "aaa", found_ty: string, op: ! + error MismatchedUnary: expected_ty: bool, found_expr: 'a', found_ty: char, op: ! "#]], ) } @@ -936,7 +1165,7 @@ mod tests { } --- - error: expected int, actual: () + error MismatchedReturnType: expected_ty: (), found_expr: 10, found_ty: () "#]], ); @@ -1006,7 +1235,7 @@ mod tests { } --- - error: expected int, actual: () + error MismatchedReturnType: expected_ty: (), found_expr: {{ .. }}, found_ty: () "#]], ); @@ -1124,7 +1353,7 @@ mod tests { } --- - error: expected int, actual: () + error MismatchedReturnType: expected_ty: (), found_expr: res + 30, found_ty: () "#]], ); } @@ -1150,8 +1379,8 @@ mod tests { } --- - error: expected string, actual: bool - error: expected bool, actual: string + error MismaatchedSignature: expected_ty: string, signature: (bool, string) -> int, found_expr: "aaa", found_ty: bool + error MismaatchedSignature: expected_ty: bool, signature: (bool, string) -> int, found_expr: true, found_ty: string "#]], ); } @@ -1202,7 +1431,7 @@ mod tests { } --- - error: expected int, actual: () + error MismatchedTypeOnlyIfBranch: then_branch_ty: (), then_branch: {{ .. }} "#]], ); } @@ -1271,8 +1500,8 @@ mod tests { } --- - error: expected int, actual: string - error: expected int, actual: () + error MismatchedTypeElseBranch: then_branch_ty: int, then_branch: {{ .. }}, else_branch_ty: string, else_branch: {{ .. }} + error MismatchedReturnType: expected_ty: (), found_expr: if true { .. } else { .. }, found_ty: () "#]], ); } @@ -1300,8 +1529,8 @@ mod tests { } --- - error: expected bool, actual: int - error: expected string, actual: () + error MismatchedTypeIfCondition: expected_ty: bool, found_expr: 10, found_ty: int + error MismatchedReturnType: expected_ty: (), found_expr: if 10 { .. } else { .. }, found_ty: () "#]], ); } @@ -1333,7 +1562,7 @@ mod tests { } --- - error: expected (), actual: bool + error MismatchedTypeElseBranch: then_branch_ty: (), then_branch: {{ .. }}, else_branch_ty: bool, else_branch: {{ .. }} "#]], ); @@ -1362,7 +1591,7 @@ mod tests { } --- - error: expected bool, actual: () + error MismatchedTypeElseBranch: then_branch_ty: bool, then_branch: {{ .. }}, else_branch_ty: (), else_branch: {{ .. }} "#]], ); @@ -1410,7 +1639,7 @@ mod tests { } --- - error: expected !, actual: () + error MismatchedReturnType: expected_ty: (), found_expr: return, found_ty: () "#]], ); @@ -1427,7 +1656,7 @@ mod tests { } --- - error: expected !, actual: int + error MismatchedReturnType: expected_ty: int, found_expr: return 10, found_ty: int "#]], ); } @@ -1447,8 +1676,8 @@ mod tests { } --- - error: expected (), actual: int - error: expected !, actual: int + error MismatchedReturnType: expected_ty: int, found_ty: int + error MismatchedReturnType: expected_ty: int, found_expr: return, found_ty: int "#]], ); @@ -1465,8 +1694,8 @@ mod tests { } --- - error: expected string, actual: int - error: expected !, actual: int + error MismatchedReturnType: expected_ty: int, found_expr: "aaa", found_ty: int + error MismatchedReturnType: expected_ty: int, found_expr: return "aaa", found_ty: int "#]], ); } From 478f23c5dda754daa3990a7337668f752fa08c5d Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Fri, 15 Sep 2023 21:59:56 +0900 Subject: [PATCH 09/35] wip --- crates/hir_ty/src/lib.rs | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index f0275e4f..11a245b6 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -623,16 +623,33 @@ mod tests { format!("{callee}({args})") } - hir::Expr::Block(block) => "{{ .. }}".to_string(), + hir::Expr::Block(block) => { + let mut msg = "{ tail:".to_string(); + if let Some(tail) = block.tail { + msg.push_str(&self.debug_simplify_expr(hir_file, tail)); + } else { + msg.push_str("none"); + } + msg.push_str(" }"); + + msg + } hir::Expr::If { condition, then_branch, else_branch, } => { let cond_str = self.debug_simplify_expr(hir_file, *condition); - let mut msg = format!("if {cond_str} {{ .. }}"); - if else_branch.is_some() { - msg.push_str(" else { .. }"); + let mut msg = format!( + "if {cond_str} {}", + self.debug_simplify_expr(hir_file, *then_branch) + ); + + if let Some(else_branch) = else_branch { + msg.push_str(&format!( + " else {}", + self.debug_simplify_expr(hir_file, *else_branch) + )); } msg @@ -1235,7 +1252,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_expr: {{ .. }}, found_ty: () + error MismatchedReturnType: expected_ty: (), found_expr: { tail:{ tail:10 } }, found_ty: () "#]], ); @@ -1431,7 +1448,7 @@ mod tests { } --- - error MismatchedTypeOnlyIfBranch: then_branch_ty: (), then_branch: {{ .. }} + error MismatchedTypeOnlyIfBranch: then_branch_ty: (), then_branch: { tail:10 } "#]], ); } @@ -1500,8 +1517,8 @@ mod tests { } --- - error MismatchedTypeElseBranch: then_branch_ty: int, then_branch: {{ .. }}, else_branch_ty: string, else_branch: {{ .. }} - error MismatchedReturnType: expected_ty: (), found_expr: if true { .. } else { .. }, found_ty: () + error MismatchedTypeElseBranch: then_branch_ty: int, then_branch: { tail:10 }, else_branch_ty: string, else_branch: { tail:"aaa" } + error MismatchedReturnType: expected_ty: (), found_expr: if true { tail:10 } else { tail:"aaa" }, found_ty: () "#]], ); } @@ -1530,7 +1547,7 @@ mod tests { --- error MismatchedTypeIfCondition: expected_ty: bool, found_expr: 10, found_ty: int - error MismatchedReturnType: expected_ty: (), found_expr: if 10 { .. } else { .. }, found_ty: () + error MismatchedReturnType: expected_ty: (), found_expr: if 10 { tail:"aaa" } else { tail:"aaa" }, found_ty: () "#]], ); } @@ -1562,7 +1579,7 @@ mod tests { } --- - error MismatchedTypeElseBranch: then_branch_ty: (), then_branch: {{ .. }}, else_branch_ty: bool, else_branch: {{ .. }} + error MismatchedTypeElseBranch: then_branch_ty: (), then_branch: { tail:none }, else_branch_ty: bool, else_branch: { tail:true } "#]], ); @@ -1591,7 +1608,7 @@ mod tests { } --- - error MismatchedTypeElseBranch: then_branch_ty: bool, then_branch: {{ .. }}, else_branch_ty: (), else_branch: {{ .. }} + error MismatchedTypeElseBranch: then_branch_ty: bool, then_branch: { tail:true }, else_branch_ty: (), else_branch: { tail:none } "#]], ); From 2629252f5cf1ea508cb789985869d8bd7718e369 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Fri, 15 Sep 2023 22:01:53 +0900 Subject: [PATCH 10/35] wip --- crates/hir_ty/src/lib.rs | 69 +++++++++------------------------------- 1 file changed, 15 insertions(+), 54 deletions(-) diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 11a245b6..1d605a32 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -117,7 +117,7 @@ mod tests { for error in &inference_body_result.errors { match error { - InferenceError::Unresolved { expr } => todo!(), + InferenceError::Unresolved { expr: _ } => todo!(), InferenceError::MismatchedTypes { expected_ty, found_ty, @@ -329,25 +329,17 @@ mod tests { hir::Type::Unknown => "", }; - let scope_origin = hir::ModuleScopeOrigin::Function { origin: function }; - let hir::Expr::Block(block) = body_expr else { panic!("Should be Block") }; let mut body = "{\n".to_string(); for stmt in &block.stmts { - body.push_str(&self.debug_stmt( - hir_file, - function, - scope_origin, - stmt, - nesting + 1, - )); + body.push_str(&self.debug_stmt(hir_file, function, stmt, nesting + 1)); } if let Some(tail) = block.tail { let indent = indent(nesting + 1); body.push_str(&format!( "{indent}expr:{}", - self.debug_expr(hir_file, function, scope_origin, tail, nesting + 1) + self.debug_expr(hir_file, function, tail, nesting + 1) )); body.push_str(&format!(" {}\n", self.debug_type_line(function, tail))); } @@ -413,7 +405,6 @@ mod tests { &self, hir_file: hir::HirFile, function: hir::Function, - scope_origin: hir::ModuleScopeOrigin, stmt: &hir::Stmt, nesting: usize, ) -> String { @@ -421,8 +412,7 @@ mod tests { hir::Stmt::VariableDef { name, value } => { let indent = indent(nesting); let name = name.text(self.db); - let expr_str = - self.debug_expr(hir_file, function, scope_origin, *value, nesting); + let expr_str = self.debug_expr(hir_file, function, *value, nesting); let mut stmt_str = format!("{indent}let {name} = {expr_str};"); let type_line = self.debug_type_line(function, *value); @@ -435,8 +425,7 @@ mod tests { has_semicolon, } => { let indent = indent(nesting); - let expr_str = - self.debug_expr(hir_file, function, scope_origin, *expr, nesting); + let expr_str = self.debug_expr(hir_file, function, *expr, nesting); let type_line = self.debug_type_line(function, *expr); let maybe_semicolon = if *has_semicolon { ";" } else { "" }; format!("{indent}{expr_str}{maybe_semicolon} {type_line}\n") @@ -462,7 +451,6 @@ mod tests { &self, hir_file: hir::HirFile, function: hir::Function, - scope_origin: hir::ModuleScopeOrigin, expr_id: hir::ExprId, nesting: usize, ) -> String { @@ -484,8 +472,8 @@ mod tests { ast::BinaryOp::GreaterThan(_) => ">", ast::BinaryOp::LessThan(_) => "<", }; - let lhs_str = self.debug_expr(hir_file, function, scope_origin, *lhs, nesting); - let rhs_str = self.debug_expr(hir_file, function, scope_origin, *rhs, nesting); + let lhs_str = self.debug_expr(hir_file, function, *lhs, nesting); + let rhs_str = self.debug_expr(hir_file, function, *rhs, nesting); format!("{lhs_str} {op} {rhs_str}") } hir::Expr::Unary { op, expr } => { @@ -493,38 +481,29 @@ mod tests { ast::UnaryOp::Neg(_) => "-", ast::UnaryOp::Not(_) => "!", }; - let expr_str = - self.debug_expr(hir_file, function, scope_origin, *expr, nesting); + let expr_str = self.debug_expr(hir_file, function, *expr, nesting); format!("{op}{expr_str}") } hir::Expr::Call { callee, args } => { let callee = self.debug_symbol(callee); let args = args .iter() - .map(|arg| self.debug_expr(hir_file, function, scope_origin, *arg, nesting)) + .map(|arg| self.debug_expr(hir_file, function, *arg, nesting)) .collect::>() .join(", "); format!("{callee}({args})") } hir::Expr::Block(block) => { - let scope_origin = hir::ModuleScopeOrigin::Block { origin: expr_id }; - let mut msg = "{\n".to_string(); for stmt in &block.stmts { - msg.push_str(&self.debug_stmt( - hir_file, - function, - scope_origin, - stmt, - nesting + 1, - )); + msg.push_str(&self.debug_stmt(hir_file, function, stmt, nesting + 1)); } if let Some(tail) = block.tail { let indent = indent(nesting + 1); msg.push_str(&format!( "{indent}expr:{}", - self.debug_expr(hir_file, function, scope_origin, tail, nesting + 1) + self.debug_expr(hir_file, function, tail, nesting + 1) )); msg.push_str(&format!(" {}\n", self.debug_type_line(function, tail))); } @@ -538,31 +517,13 @@ mod tests { else_branch, } => { let mut msg = "if ".to_string(); - msg.push_str(&self.debug_expr( - hir_file, - function, - scope_origin, - *condition, - nesting, - )); + msg.push_str(&self.debug_expr(hir_file, function, *condition, nesting)); msg.push(' '); - msg.push_str(&self.debug_expr( - hir_file, - function, - scope_origin, - *then_branch, - nesting, - )); + msg.push_str(&self.debug_expr(hir_file, function, *then_branch, nesting)); if let Some(else_branch) = else_branch { msg.push_str(" else "); - msg.push_str(&self.debug_expr( - hir_file, - function, - scope_origin, - *else_branch, - nesting, - )); + msg.push_str(&self.debug_expr(hir_file, function, *else_branch, nesting)); } msg @@ -572,7 +533,7 @@ mod tests { if let Some(value) = value { msg.push_str(&format!( " {}", - &self.debug_expr(hir_file, function, scope_origin, *value, nesting,) + &self.debug_expr(hir_file, function, *value, nesting,) )); } From 083b1734061be284856bf47e3830c7305761dda5 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Fri, 15 Sep 2023 22:05:31 +0900 Subject: [PATCH 11/35] wip --- crates/hir_ty/src/inference/environment.rs | 6 ++++-- crates/hir_ty/src/inference/type_scheme.rs | 1 + crates/hir_ty/src/inference/type_unifier.rs | 15 --------------- crates/hir_ty/src/inference/types.rs | 3 +-- 4 files changed, 6 insertions(+), 19 deletions(-) diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index 0f2fb12b..cab03387 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -137,7 +137,7 @@ impl<'a> InferBody<'a> { // 解決できないエラーを追加 Monotype::Unknown } - hir::ResolutionStatus::Resolved { path, item } => { + hir::ResolutionStatus::Resolved { path: _, item } => { match item { hir::Item::Function(function) => { let signature = self.signature_by_function.get(&function); @@ -205,7 +205,7 @@ impl<'a> InferBody<'a> { { self.unifier.unify( &call_arg_ty, - &arg, + arg, &UnifyPurpose::CallArg { found_arg: *call_arg, expected_signature: signature.as_ref().clone(), @@ -426,6 +426,7 @@ impl Environment { } } + #[allow(dead_code)] fn free_variables(&self) -> HashSet { let mut union = HashSet::::new(); for type_scheme in self.bindings.values() { @@ -442,6 +443,7 @@ impl Environment { Environment { bindings: copy } } + #[allow(dead_code)] fn generalize(&self, ty: &Monotype) -> TypeScheme { TypeScheme { variables: ty.free_variables().sub(&self.free_variables()), diff --git a/crates/hir_ty/src/inference/type_scheme.rs b/crates/hir_ty/src/inference/type_scheme.rs index 65c9ae6f..2e0d7880 100644 --- a/crates/hir_ty/src/inference/type_scheme.rs +++ b/crates/hir_ty/src/inference/type_scheme.rs @@ -19,6 +19,7 @@ impl TypeScheme { } } + #[allow(dead_code)] pub fn free_variables(&self) -> HashSet { self.ty .free_variables() diff --git a/crates/hir_ty/src/inference/type_unifier.rs b/crates/hir_ty/src/inference/type_unifier.rs index 396f589c..ded6f280 100644 --- a/crates/hir_ty/src/inference/type_unifier.rs +++ b/crates/hir_ty/src/inference/type_unifier.rs @@ -10,12 +10,6 @@ pub(crate) struct TypeUnifier { /// 型推論の目的 pub(crate) enum UnifyPurpose { - Expr { - /// 期待する型を持つ式 - expected_expr: hir::ExprId, - /// 実際に得られた型を持つ式 - found_expr: hir::ExprId, - }, CallArg { /// 関数呼び出し対象のシグネチャ expected_signature: Signature, @@ -75,15 +69,6 @@ fn build_unify_error_from_unify_purpose( purpose: &UnifyPurpose, ) -> InferenceError { match purpose { - UnifyPurpose::Expr { - expected_expr, - found_expr, - } => InferenceError::MismatchedTypes { - expected_ty, - found_ty, - expected_expr: *expected_expr, - found_expr: *found_expr, - }, UnifyPurpose::CallArg { found_arg, expected_signature, diff --git a/crates/hir_ty/src/inference/types.rs b/crates/hir_ty/src/inference/types.rs index 636a38b2..fc368242 100644 --- a/crates/hir_ty/src/inference/types.rs +++ b/crates/hir_ty/src/inference/types.rs @@ -35,8 +35,7 @@ impl fmt::Display for Monotype { .iter() .map(|ty| ty.to_string()) .collect::>() - .join(", ") - .to_string(), + .join(", "), signature.return_type.to_string() ) } From eda5c279835e2cd28b3d495b5d288b1b4b0132a0 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Fri, 15 Sep 2023 23:35:32 +0900 Subject: [PATCH 12/35] wip --- crates/hir_ty/src/checker.rs | 195 +++++------------------------------ crates/hir_ty/src/lib.rs | 63 ++++++++--- 2 files changed, 75 insertions(+), 183 deletions(-) diff --git a/crates/hir_ty/src/checker.rs b/crates/hir_ty/src/checker.rs index ac7a72b0..6f238592 100644 --- a/crates/hir_ty/src/checker.rs +++ b/crates/hir_ty/src/checker.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::inference::{InferenceBodyResult, InferenceResult, Monotype, Signature}; +use crate::inference::{InferenceBodyResult, InferenceResult, Monotype}; pub fn check_type_pods( db: &dyn hir::HirMasterDatabase, @@ -12,8 +12,7 @@ pub fn check_type_pods( let mut errors_by_function = HashMap::new(); for (hir_file, function) in pod.all_functions(db) { - let type_checker = - FunctionTypeChecker::new(db, &pods.resolution_map, infer_result, hir_file, function); + let type_checker = FunctionTypeChecker::new(db, infer_result, hir_file, function); let type_errors = type_checker.check(); errors_by_function.insert(function, type_errors); } @@ -25,63 +24,10 @@ pub fn check_type_pods( #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum TypeCheckError { /// 型を解決できない - UnMonotype { + Unresolved { /// 対象の式 expr: hir::ExprId, }, - /// 一致するべき型が一致しない - MismatchedTypes { - /// 期待される型の式 - expected_expr: hir::ExprId, - /// 期待される型 - expected_ty: Monotype, - /// 実際の式 - found_expr: hir::ExprId, - /// 実際の型 - found_ty: Monotype, - }, - /// Ifの条件式の型が一致しない - MismatchedTypeIfCondition { - /// 期待される型 - expected_ty: Monotype, - /// 実際の式 - found_expr: hir::ExprId, - /// 実際の型 - found_ty: Monotype, - }, - /// Ifのthenブランチとelseブランチの型が一致しない - MismatchedTypeElseBranch { - /// 期待される型 - expected_ty: Monotype, - /// 実際の式 - found_expr: hir::ExprId, - /// 実際の型 - found_ty: Monotype, - }, - /// 関数呼び出しの引数の数が一致しない - MismaatchedSignature { - /// 期待される型 - expected_ty: Monotype, - /// 呼び出そうとしている関数のシグネチャ - signature: Signature, - /// 実際の式 - found_expr: hir::ExprId, - /// 実際の型 - found_ty: Monotype, - }, - /// 関数の戻り値の型と実際の戻り値の型が異なる - /// - /// 以下のいずれかが関数の戻り値の型と一致しない場合に発生する - /// - `return`に指定した式の型 - /// - 関数ボディの最後の式の型 - MismatchedReturnType { - /// 期待される型 - expected_ty: Monotype, - /// 実際の式 - found_expr: Option, - /// 実際の型 - found_ty: Monotype, - }, } /// 型チェックの結果 @@ -93,7 +39,6 @@ pub struct TypeCheckResult { struct FunctionTypeChecker<'a> { db: &'a dyn hir::HirMasterDatabase, - resolution_map: &'a hir::ResolutionMap, infer_result: &'a InferenceResult, hir_file: hir::HirFile, @@ -105,7 +50,6 @@ struct FunctionTypeChecker<'a> { impl<'a> FunctionTypeChecker<'a> { fn new( db: &'a dyn hir::HirMasterDatabase, - resolution_map: &'a hir::ResolutionMap, infer_result: &'a InferenceResult, hir_file: hir::HirFile, @@ -113,7 +57,6 @@ impl<'a> FunctionTypeChecker<'a> { ) -> Self { Self { db, - resolution_map, infer_result, hir_file, function, @@ -121,10 +64,6 @@ impl<'a> FunctionTypeChecker<'a> { } } - fn signature_by_function(&self, function: hir::Function) -> &Signature { - &self.infer_result.signature_by_function[&function] - } - fn check(mut self) -> Vec { let block_ast_id = self.function.ast(self.db).body().unwrap(); let body = self @@ -140,17 +79,6 @@ impl<'a> FunctionTypeChecker<'a> { } if let Some(tail) = block.tail { self.check_expr(tail); - - let signature = self.signature_by_function(self.function); - let tail_ty = self.current_inference().type_by_expr[&tail].clone(); - if tail_ty != signature.return_type { - self.errors.push(TypeCheckError::MismatchedReturnType { - expected_ty: signature.return_type.clone(), - found_expr: Some(tail), - found_ty: tail_ty, - }); - } - } else { } } _ => unreachable!(), @@ -168,28 +96,20 @@ impl<'a> FunctionTypeChecker<'a> { } fn check_expr(&mut self, expr: hir::ExprId) { + let ty = self.current_inference().type_by_expr[&expr].clone(); + if ty == Monotype::Unknown { + self.errors.push(TypeCheckError::Unresolved { expr }); + } + let expr = expr.lookup(self.hir_file.db(self.db)); match expr { hir::Expr::Symbol(_) => (), hir::Expr::Binary { lhs, rhs, .. } => { - let lhs_ty = self.type_by_expr(*lhs); - let rhs_ty = self.type_by_expr(*rhs); - match (lhs_ty, rhs_ty) { - (Monotype::Unknown, Monotype::Unknown) => (), - (lhs_ty, rhs_ty) => { - if lhs_ty != rhs_ty { - self.errors.push(TypeCheckError::MismatchedTypes { - expected_expr: *lhs, - expected_ty: lhs_ty, - found_expr: *rhs, - found_ty: rhs_ty, - }); - } - } - } + self.check_expr(*lhs); + self.check_expr(*rhs); } hir::Expr::Unary { expr, .. } => { - self.type_by_expr(*expr); + self.check_expr(*expr); } hir::Expr::Literal(_) => (), hir::Expr::Block(block) => { @@ -201,101 +121,36 @@ impl<'a> FunctionTypeChecker<'a> { } } hir::Expr::Call { callee, args } => { - let signature = match callee { - hir::Symbol::Local { .. } | hir::Symbol::Param { .. } => return, - hir::Symbol::Missing { path } => { - let Some(item) = self.resolution_map.item_by_symbol(path) else { return; }; - match item { - hir::ResolutionStatus::Unresolved | hir::ResolutionStatus::Error => { - return - } - hir::ResolutionStatus::Resolved { path: _, item } => match item { - hir::Item::Function(function) => { - self.signature_by_function(function).clone() - } - hir::Item::Module(_) | hir::Item::UseItem(_) => unimplemented!(), - }, - } - } - }; + for arg in args { + self.check_expr(*arg); + } - for (i, param_ty) in signature.params.iter().enumerate() { - let arg = args[i]; - let arg_ty = self.type_by_expr(arg); - if param_ty != &arg_ty { - self.errors.push(TypeCheckError::MismaatchedSignature { - expected_ty: param_ty.clone(), - signature: signature.clone(), - found_expr: arg, - found_ty: arg_ty, - }); + match callee { + hir::Symbol::Local { expr, .. } => { + self.check_expr(*expr); } - } + hir::Symbol::Param { .. } | hir::Symbol::Missing { .. } => (), + }; } hir::Expr::If { condition, then_branch, else_branch, } => { - let condition_ty = self.type_by_expr(*condition); - let expected_condition_ty = Monotype::Bool; - if condition_ty != expected_condition_ty { - self.errors.push(TypeCheckError::MismatchedTypeIfCondition { - expected_ty: expected_condition_ty, - found_expr: *condition, - found_ty: condition_ty, - }); - } + self.check_expr(*condition); - let then_branch_ty = self.type_by_expr(*then_branch); + self.check_expr(*then_branch); if let Some(else_branch) = else_branch { - let else_branch_ty = self.type_by_expr(*else_branch); - if then_branch_ty != else_branch_ty { - self.errors.push(TypeCheckError::MismatchedTypes { - expected_expr: *then_branch, - expected_ty: then_branch_ty, - found_expr: *else_branch, - found_ty: else_branch_ty, - }); - } - } else { - let else_branch_ty = Monotype::Unit; - if then_branch_ty != else_branch_ty { - self.errors.push(TypeCheckError::MismatchedTypeElseBranch { - expected_ty: else_branch_ty, - found_expr: *then_branch, - found_ty: then_branch_ty, - }); - } + self.check_expr(*else_branch); } } hir::Expr::Return { value } => { - let return_value_ty = if let Some(value) = value { - self.type_by_expr(*value) - } else { - Monotype::Unit - }; - - let signature = self.signature_by_function(self.function); - if return_value_ty != signature.return_type { - self.errors.push(TypeCheckError::MismatchedReturnType { - expected_ty: signature.return_type.clone(), - found_expr: *value, - found_ty: return_value_ty, - }); + if let Some(value) = value { + self.check_expr(*value); } } hir::Expr::Missing => (), - } - } - - fn type_by_expr(&mut self, expr: hir::ExprId) -> Monotype { - let ty = self.current_inference().type_by_expr[&expr].clone(); - if ty == Monotype::Unknown { - self.errors.push(TypeCheckError::UnMonotype { expr }); - } - - ty + }; } /// 現在の関数の推論結果を取得する diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 1d605a32..317ad25f 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -36,15 +36,33 @@ pub fn lower_pods(db: &dyn hir::HirMasterDatabase, pods: &hir::Pods) -> TyLowerR #[derive(Debug)] pub struct TyLowerResult { /// 型推論の結果 - pub inference_result: InferenceResult, + inference_result: InferenceResult, /// 型チェックの結果 - pub type_check_result: TypeCheckResult, + type_check_result: TypeCheckResult, } impl TyLowerResult { /// 指定した関数の型を取得します。 pub fn signature_by_function(&self, function_id: hir::Function) -> &Signature { &self.inference_result.signature_by_function[&function_id] } + + /// 指定した関数の型推論結果を取得します。 + pub fn inference_body_by_function( + &self, + function: hir::Function, + ) -> Option<&InferenceBodyResult> { + self.inference_result + .inference_body_result_by_function + .get(&function) + } + + /// 指定した関数の型チェック結果を取得します。 + pub fn type_check_errors_by_function( + &self, + function: hir::Function, + ) -> Option<&Vec> { + self.type_check_result.errors_by_function.get(&function) + } } #[cfg(test)] @@ -52,8 +70,8 @@ mod tests { use expect_test::{expect, Expect}; use crate::{ - inference::{infer_pods, InferenceError, Signature}, - InferenceResult, + inference::{InferenceError, Signature}, + lower_pods, TyLowerResult, TypeCheckError, }; fn check_pod_start_with_root_file(fixture: &str, expect: Expect) { @@ -61,9 +79,9 @@ mod tests { let mut source_db = hir::FixtureDatabase::new(&db, fixture); let pods = hir::parse_pods(&db, "/main.nail", &mut source_db); - let inference_result = infer_pods(&db, &pods); + let ty_lower_result = lower_pods(&db, &pods); - expect.assert_eq(&TestingDebug::new(&db, &pods, &inference_result).debug()); + expect.assert_eq(&TestingDebug::new(&db, &pods, &ty_lower_result).debug()); } fn check_in_root_file(fixture: &str, expect: Expect) { @@ -80,18 +98,18 @@ mod tests { struct TestingDebug<'a> { db: &'a dyn hir::HirMasterDatabase, pods: &'a hir::Pods, - inference_result: &'a InferenceResult, + ty_lower_result: &'a TyLowerResult, } impl<'a> TestingDebug<'a> { fn new( db: &'a dyn hir::HirMasterDatabase, pods: &'a hir::Pods, - inference_result: &'a InferenceResult, + ty_lower_result: &'a TyLowerResult, ) -> Self { - TestingDebug { + Self { db, pods, - inference_result, + ty_lower_result, } } @@ -110,6 +128,7 @@ mod tests { msg.push_str("---\n"); for (hir_file, function) in self.pods.root_pod.all_functions(self.db) { let inference_body_result = self + .ty_lower_result .inference_result .inference_body_result_by_function .get(&function) @@ -259,6 +278,25 @@ mod tests { } } + msg.push_str("---\n"); + for (hir_file, function) in self.pods.root_pod.all_functions(self.db) { + let type_check_errors = self + .ty_lower_result + .type_check_errors_by_function(function) + .unwrap(); + for error in type_check_errors { + match error { + TypeCheckError::Unresolved { expr } => { + msg.push_str(&format!( + "error Type is unknown: expr: {}", + self.debug_simplify_expr(hir_file, *expr), + )); + } + } + } + msg.push('\n'); + } + msg } @@ -436,9 +474,8 @@ mod tests { fn debug_type_line(&self, function: hir::Function, expr_id: hir::ExprId) -> String { let ty = self - .inference_result - .inference_body_result_by_function - .get(&function) + .ty_lower_result + .inference_body_by_function(function) .unwrap() .type_by_expr .get(&expr_id) From 3f82d5ccbfa8ae48ab3f42e889f8d7025e47ba85 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Fri, 15 Sep 2023 23:36:45 +0900 Subject: [PATCH 13/35] wip --- crates/hir_ty/src/inference/error.rs | 5 ----- crates/hir_ty/src/lib.rs | 1 - 2 files changed, 6 deletions(-) diff --git a/crates/hir_ty/src/inference/error.rs b/crates/hir_ty/src/inference/error.rs index 7f0f157e..0a431bde 100644 --- a/crates/hir_ty/src/inference/error.rs +++ b/crates/hir_ty/src/inference/error.rs @@ -3,11 +3,6 @@ use super::{Monotype, Signature}; /// 型チェックのエラー #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum InferenceError { - /// 型を解決できない - Unresolved { - /// 対象の式 - expr: hir::ExprId, - }, /// 一致するべき型が一致しない MismatchedTypes { /// 期待される型 diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 317ad25f..9528a706 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -136,7 +136,6 @@ mod tests { for error in &inference_body_result.errors { match error { - InferenceError::Unresolved { expr: _ } => todo!(), InferenceError::MismatchedTypes { expected_ty, found_ty, From 4175312a8d0086617610243c0c65dc24bf75979d Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 00:57:00 +0900 Subject: [PATCH 14/35] wip --- crates/base_db/src/lib.rs | 2 +- crates/hir_ty/src/db.rs | 15 +++ crates/hir_ty/src/inference.rs | 14 +-- crates/hir_ty/src/inference/environment.rs | 52 +++++---- crates/hir_ty/src/inference/type_scheme.rs | 7 +- crates/hir_ty/src/inference/type_unifier.rs | 6 +- crates/hir_ty/src/inference/types.rs | 54 ++------- crates/hir_ty/src/lib.rs | 122 +++++++++++++++----- crates/hir_ty/src/testing.rs | 45 ++++++++ 9 files changed, 203 insertions(+), 114 deletions(-) create mode 100644 crates/hir_ty/src/db.rs create mode 100644 crates/hir_ty/src/testing.rs diff --git a/crates/base_db/src/lib.rs b/crates/base_db/src/lib.rs index 97134516..86ecd845 100644 --- a/crates/base_db/src/lib.rs +++ b/crates/base_db/src/lib.rs @@ -1,5 +1,5 @@ #[derive(Default)] -#[salsa::db(hir::Jar)] +#[salsa::db(hir::Jar, hir_ty::Jar)] pub struct SalsaDatabase { storage: salsa::Storage, } diff --git a/crates/hir_ty/src/db.rs b/crates/hir_ty/src/db.rs new file mode 100644 index 00000000..587e536f --- /dev/null +++ b/crates/hir_ty/src/db.rs @@ -0,0 +1,15 @@ +use crate::inference; + +/// HIR-tyの全体のデータベースです。 +/// +/// ここに`salsa`データを定義します。 +#[salsa::jar(db = HirTyMasterDatabase)] +pub struct Jar(crate::Signature, inference::lower_signature); + +/// [Jar]用のDBトレイトです。 +pub trait HirTyMasterDatabase: salsa::DbWithJar + hir::HirMasterDatabase {} + +impl HirTyMasterDatabase for DB where + DB: ?Sized + salsa::DbWithJar + salsa::DbWithJar +{ +} diff --git a/crates/hir_ty/src/inference.rs b/crates/hir_ty/src/inference.rs index 772a05a8..ee70e53e 100644 --- a/crates/hir_ty/src/inference.rs +++ b/crates/hir_ty/src/inference.rs @@ -12,7 +12,9 @@ pub use error::InferenceError; pub use type_scheme::TypeScheme; pub use types::Monotype; -pub fn infer_pods(db: &dyn hir::HirMasterDatabase, pods: &hir::Pods) -> InferenceResult { +use crate::HirTyMasterDatabase; + +pub fn infer_pods(db: &dyn HirTyMasterDatabase, pods: &hir::Pods) -> InferenceResult { let mut signature_by_function = HashMap::::new(); for (hir_file, function) in pods.root_pod.all_functions(db) { let signature = lower_signature(db, hir_file, function); @@ -34,8 +36,9 @@ pub fn infer_pods(db: &dyn hir::HirMasterDatabase, pods: &hir::Pods) -> Inferenc } } -fn lower_signature( - db: &dyn hir::HirMasterDatabase, +#[salsa::tracked] +pub(crate) fn lower_signature( + db: &dyn HirTyMasterDatabase, hir_file: hir::HirFile, function: hir::Function, ) -> Signature { @@ -50,10 +53,7 @@ fn lower_signature( let return_type = lower_type(&function.return_type(db)); - Signature { - params, - return_type, - } + Signature::new(db, params, return_type) } fn lower_type(ty: &hir::Type) -> Monotype { diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index cab03387..a379b071 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -9,19 +9,21 @@ use super::{ type_unifier::{TypeUnifier, UnifyPurpose}, types::Monotype, }; +use crate::HirTyMasterDatabase; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[salsa::tracked] pub struct Signature { + #[return_ref] pub params: Vec, pub return_type: Monotype, } pub(crate) struct InferBody<'a> { - db: &'a dyn hir::HirMasterDatabase, + db: &'a dyn HirTyMasterDatabase, pods: &'a hir::Pods, hir_file: hir::HirFile, function: hir::Function, - signature: &'a Signature, + signature: Signature, unifier: TypeUnifier, cxt: Context, @@ -32,7 +34,7 @@ pub(crate) struct InferBody<'a> { } impl<'a> InferBody<'a> { pub(crate) fn new( - db: &'a dyn hir::HirMasterDatabase, + db: &'a dyn HirTyMasterDatabase, pods: &'a hir::Pods, hir_file: hir::HirFile, function: hir::Function, @@ -44,7 +46,7 @@ impl<'a> InferBody<'a> { pods, hir_file, function, - signature: signature_by_function.get(&function).unwrap(), + signature: *signature_by_function.get(&function).unwrap(), unifier: TypeUnifier::new(), cxt: Context::default(), @@ -64,9 +66,9 @@ impl<'a> InferBody<'a> { let ty = self.infer_expr(*tail); self.unifier.unify( &ty, - &self.signature.return_type, + &self.signature.return_type(self.db), &UnifyPurpose::SelfReturnType { - expected_signature: self.signature.clone(), + expected_signature: self.signature, found_expr: Some(*tail), }, ); @@ -74,9 +76,9 @@ impl<'a> InferBody<'a> { let ty = Monotype::Unit; self.unifier.unify( &ty, - &self.signature.return_type, + &self.signature.return_type(self.db), &UnifyPurpose::SelfReturnType { - expected_signature: self.signature.clone(), + expected_signature: self.signature, found_expr: None, }, ); @@ -125,7 +127,7 @@ impl<'a> InferBody<'a> { hir::Symbol::Local { name, expr: _ } => { let ty_scheme = self.current_scope().bindings.get(name).cloned(); if let Some(ty_scheme) = ty_scheme { - ty_scheme.instantiate(&mut self.cxt) + ty_scheme.instantiate(self.db, &mut self.cxt) } else { panic!("Unbound variable {symbol:?}"); } @@ -142,7 +144,7 @@ impl<'a> InferBody<'a> { hir::Item::Function(function) => { let signature = self.signature_by_function.get(&function); if let Some(signature) = signature { - Monotype::Function(signature.clone().into()) + Monotype::Function(*signature) } else { unreachable!("Function signature should be resolved.") } @@ -196,24 +198,26 @@ impl<'a> InferBody<'a> { .map(|arg| self.infer_expr(*arg)) .collect::>(); - if call_args_ty.len() != signature.params.len() { + if call_args_ty.len() != signature.params(self.db).len() { // TODO: 引数の数が異なるエラーを追加 Monotype::Unknown } else { - for ((call_arg, call_arg_ty), arg) in - call_args.iter().zip(call_args_ty).zip(&signature.params) + for ((call_arg, call_arg_ty), arg) in call_args + .iter() + .zip(call_args_ty) + .zip(signature.params(self.db)) { self.unifier.unify( &call_arg_ty, arg, &UnifyPurpose::CallArg { found_arg: *call_arg, - expected_signature: signature.as_ref().clone(), + expected_signature: signature, }, ); } - signature.return_type.clone() + signature.return_type(self.db) } } } @@ -351,9 +355,9 @@ impl<'a> InferBody<'a> { let ty = self.infer_expr(*return_value); self.unifier.unify( &ty, - &self.signature.return_type, + &self.signature.return_type(self.db), &UnifyPurpose::ReturnValue { - expected_signature: self.signature.clone(), + expected_signature: self.signature, found_expr: Some(*return_value), }, ); @@ -361,9 +365,9 @@ impl<'a> InferBody<'a> { // 何も指定しない場合は Unit を返すものとして扱う self.unifier.unify( &Monotype::Unit, - &self.signature.return_type, + &self.signature.return_type(self.db), &UnifyPurpose::ReturnValue { - expected_signature: self.signature.clone(), + expected_signature: self.signature, found_expr: None, }, ); @@ -427,10 +431,10 @@ impl Environment { } #[allow(dead_code)] - fn free_variables(&self) -> HashSet { + fn free_variables(&self, db: &dyn HirTyMasterDatabase) -> HashSet { let mut union = HashSet::::new(); for type_scheme in self.bindings.values() { - union.extend(type_scheme.free_variables()); + union.extend(type_scheme.free_variables(db)); } union @@ -444,9 +448,9 @@ impl Environment { } #[allow(dead_code)] - fn generalize(&self, ty: &Monotype) -> TypeScheme { + fn generalize(&self, ty: &Monotype, db: &dyn HirTyMasterDatabase) -> TypeScheme { TypeScheme { - variables: ty.free_variables().sub(&self.free_variables()), + variables: ty.free_variables(db).sub(&self.free_variables(db)), ty: ty.clone(), } } diff --git a/crates/hir_ty/src/inference/type_scheme.rs b/crates/hir_ty/src/inference/type_scheme.rs index 2e0d7880..2f0fb855 100644 --- a/crates/hir_ty/src/inference/type_scheme.rs +++ b/crates/hir_ty/src/inference/type_scheme.rs @@ -4,6 +4,7 @@ use std::{ }; use super::{environment::Context, types::Monotype}; +use crate::HirTyMasterDatabase; #[derive(Clone)] pub struct TypeScheme { @@ -20,16 +21,16 @@ impl TypeScheme { } #[allow(dead_code)] - pub fn free_variables(&self) -> HashSet { + pub fn free_variables(&self, db: &dyn HirTyMasterDatabase) -> HashSet { self.ty - .free_variables() + .free_variables(db) .into_iter() .filter(|var| !self.variables.contains(var)) .collect() } /// 具体的な型を生成する - pub fn instantiate(&self, cxt: &mut Context) -> Monotype { + pub fn instantiate(&self, db: &dyn HirTyMasterDatabase, cxt: &mut Context) -> Monotype { let new_vars = self .variables .iter() diff --git a/crates/hir_ty/src/inference/type_unifier.rs b/crates/hir_ty/src/inference/type_unifier.rs index ded6f280..997418f0 100644 --- a/crates/hir_ty/src/inference/type_unifier.rs +++ b/crates/hir_ty/src/inference/type_unifier.rs @@ -75,14 +75,14 @@ fn build_unify_error_from_unify_purpose( } => InferenceError::MismaatchedSignature { expected_ty, found_ty, - signature: expected_signature.clone(), + signature: *expected_signature, found_expr: *found_arg, }, UnifyPurpose::SelfReturnType { expected_signature, found_expr, } => InferenceError::MismatchedTypeReturnValue { - expected_signature: expected_signature.clone(), + expected_signature: *expected_signature, found_ty, found_expr: *found_expr, }, @@ -135,7 +135,7 @@ fn build_unify_error_from_unify_purpose( } => InferenceError::MismatchedTypeReturnValue { found_ty, found_expr: *found_expr, - expected_signature: expected_signature.clone(), + expected_signature: *expected_signature, }, } } diff --git a/crates/hir_ty/src/inference/types.rs b/crates/hir_ty/src/inference/types.rs index fc368242..aa92d7e3 100644 --- a/crates/hir_ty/src/inference/types.rs +++ b/crates/hir_ty/src/inference/types.rs @@ -1,6 +1,7 @@ -use std::{collections::HashSet, fmt}; +use std::collections::HashSet; use super::{environment::Context, type_scheme::TypeSubstitution, Signature}; +use crate::HirTyMasterDatabase; #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Monotype { @@ -10,39 +11,11 @@ pub enum Monotype { Char, String, Variable(u32), - Function(Box), + Function(Signature), Never, Unknown, } -impl fmt::Display for Monotype { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { - match self { - Monotype::Integer => write!(f, "int"), - Monotype::Bool => write!(f, "bool"), - Monotype::Char => write!(f, "char"), - Monotype::String => write!(f, "string"), - Monotype::Unit => write!(f, "()"), - Monotype::Never => write!(f, "!"), - Monotype::Unknown => write!(f, "unknown"), - Monotype::Variable(id) => write!(f, "{}", id), - Monotype::Function(signature) => { - write!( - f, - "({}) -> {}", - signature - .params - .iter() - .map(|ty| ty.to_string()) - .collect::>() - .join(", "), - signature.return_type.to_string() - ) - } - } - } -} - impl Monotype { pub fn gen_variable(cxt: &mut Context) -> Self { let monotype = Self::Variable(cxt.gen_counter); @@ -50,7 +23,7 @@ impl Monotype { monotype } - pub fn free_variables(&self) -> HashSet { + pub fn free_variables(&self, db: &dyn HirTyMasterDatabase) -> HashSet { match self { Monotype::Variable(id) => { let mut set = HashSet::new(); @@ -60,10 +33,10 @@ impl Monotype { } Monotype::Function(signature) => { let mut set = HashSet::new(); - for arg in signature.params.iter() { - set.extend(arg.free_variables()); + for arg in signature.params(db).iter() { + set.extend(arg.free_variables(db)); } - set.extend(signature.return_type.free_variables()); + set.extend(signature.return_type(db).free_variables(db)); set } _ => Default::default(), @@ -86,17 +59,8 @@ impl Monotype { self.clone() } } - Monotype::Function(signagure) => Monotype::Function( - Signature { - params: signagure - .params - .iter() - .map(|arg| arg.apply(subst)) - .collect::>(), - return_type: signagure.return_type.apply(subst), - } - .into(), - ), + // 関数シグネチャに自由変数を持たないので何もしない + Monotype::Function(_signagure) => self.clone(), } } } diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 9528a706..8c1d19f9 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -15,14 +15,16 @@ // #![warn(missing_docs)] mod checker; +mod db; mod inference; +mod testing; pub use checker::{TypeCheckError, TypeCheckResult}; -use inference::Signature; -pub use inference::{InferenceBodyResult, InferenceResult}; +pub use db::{HirTyMasterDatabase, Jar}; +pub use inference::{InferenceBodyResult, InferenceResult, Signature}; /// HIRを元にTypedHIRを構築します。 -pub fn lower_pods(db: &dyn hir::HirMasterDatabase, pods: &hir::Pods) -> TyLowerResult { +pub fn lower_pods(db: &dyn HirTyMasterDatabase, pods: &hir::Pods) -> TyLowerResult { let inference_result = inference::infer_pods(db, pods); let type_check_result = checker::check_type_pods(db, pods, &inference_result); @@ -70,12 +72,14 @@ mod tests { use expect_test::{expect, Expect}; use crate::{ - inference::{InferenceError, Signature}, - lower_pods, TyLowerResult, TypeCheckError, + inference::{InferenceError, Monotype, Signature}, + lower_pods, + testing::TestingDatabase, + HirTyMasterDatabase, TyLowerResult, TypeCheckError, }; fn check_pod_start_with_root_file(fixture: &str, expect: Expect) { - let db = hir::TestingDatabase::default(); + let db = TestingDatabase::default(); let mut source_db = hir::FixtureDatabase::new(&db, fixture); let pods = hir::parse_pods(&db, "/main.nail", &mut source_db); @@ -96,13 +100,13 @@ mod tests { } struct TestingDebug<'a> { - db: &'a dyn hir::HirMasterDatabase, + db: &'a dyn HirTyMasterDatabase, pods: &'a hir::Pods, ty_lower_result: &'a TyLowerResult, } impl<'a> TestingDebug<'a> { fn new( - db: &'a dyn hir::HirMasterDatabase, + db: &'a dyn HirTyMasterDatabase, pods: &'a hir::Pods, ty_lower_result: &'a TyLowerResult, ) -> Self { @@ -145,8 +149,8 @@ mod tests { msg.push_str( &format!( "error MismatchedTypes: expected_ty: {}, found_ty: {}, expected_expr: {}, found_expr: {}", - expected_ty, - found_ty, + self.debug_monotype(expected_ty), + self.debug_monotype(found_ty), self.debug_simplify_expr(hir_file, *expected_expr), self.debug_simplify_expr(hir_file, *found_expr), )); @@ -159,9 +163,9 @@ mod tests { msg.push_str( &format!( "error MismatchedTypeIfCondition: expected_ty: {}, found_expr: {}, found_ty: {}", - expected_ty, + self.debug_monotype(expected_ty), self.debug_simplify_expr(hir_file, *found_expr), - found_ty, + self.debug_monotype(found_ty), )); } InferenceError::MismatchedTypeElseBranch { @@ -173,9 +177,9 @@ mod tests { msg.push_str( &format!( "error MismatchedTypeElseBranch: then_branch_ty: {}, then_branch: {}, else_branch_ty: {}, else_branch: {}", - then_branch_ty, + self.debug_monotype(then_branch_ty), self.debug_simplify_expr(hir_file, *then_branch), - else_branch_ty, + self.debug_monotype(else_branch_ty), self.debug_simplify_expr(hir_file, *else_branch), )); } @@ -186,7 +190,7 @@ mod tests { msg.push_str( &format!( "error MismatchedTypeOnlyIfBranch: then_branch_ty: {}, then_branch: {}", - then_branch_ty, + self.debug_monotype(then_branch_ty), self.debug_simplify_expr(hir_file, *then_branch), )); } @@ -199,10 +203,10 @@ mod tests { msg.push_str( &format!( "error MismaatchedSignature: expected_ty: {}, signature: {}, found_expr: {}, found_ty: {}", - expected_ty, + self.debug_monotype(expected_ty), self.debug_signature(signature), self.debug_simplify_expr(hir_file, *found_expr), - found_ty, + self.debug_monotype(found_ty), )); } InferenceError::MismatchedBinaryInteger { @@ -214,9 +218,9 @@ mod tests { msg.push_str( &format!( "error MismatchedBinaryInteger: expected_ty: {}, found_expr: {}, found_ty: {}, op: {}", - expected_ty, + self.debug_monotype(expected_ty), self.debug_simplify_expr(hir_file, *found_expr), - found_ty, + self.debug_monotype(found_ty), self.debug_binary_op(op), )); } @@ -230,10 +234,10 @@ mod tests { msg.push_str( &format!( "error MismatchedBinaryCompare: expected_ty: {}, expected_expr: {}, found_expr: {}, found_ty: {}, op: {}", - expected_ty, + self.debug_monotype(expected_ty), self.debug_simplify_expr(hir_file, *expected_expr), self.debug_simplify_expr(hir_file, *found_expr), - found_ty, + self.debug_monotype(found_ty), self.debug_binary_op(op), )); } @@ -246,9 +250,9 @@ mod tests { msg.push_str( &format!( "error MismatchedUnary: expected_ty: {}, found_expr: {}, found_ty: {}, op: {}", - expected_ty, + self.debug_monotype(expected_ty), self.debug_simplify_expr(hir_file, *found_expr), - found_ty, + self.debug_monotype(found_ty), self.debug_unary_op(op), )); } @@ -261,14 +265,15 @@ mod tests { msg.push_str( &format!( "error MismatchedReturnType: expected_ty: {}, found_expr: {}, found_ty: {}", - expected_signature.return_type, + self.debug_monotype(&expected_signature.return_type(self.db)), self.debug_simplify_expr(hir_file, *found_expr), - found_ty, + self.debug_monotype(found_ty), )); } else { msg.push_str(&format!( "error MismatchedReturnType: expected_ty: {}, found_ty: {}", - expected_signature.return_type, found_ty, + self.debug_monotype(&expected_signature.return_type(self.db)), + self.debug_monotype(found_ty), )); } } @@ -292,8 +297,8 @@ mod tests { )); } } + msg.push('\n'); } - msg.push('\n'); } msg @@ -301,12 +306,12 @@ mod tests { fn debug_signature(&self, signature: &Signature) -> String { let params = signature - .params + .params(self.db) .iter() - .map(|param| param.to_string()) + .map(|param| self.debug_monotype(param)) .collect::>() .join(", "); - let return_type = signature.return_type.to_string(); + let return_type = self.debug_monotype(&signature.return_type(self.db)); format!("({params}) -> {return_type}") } @@ -480,7 +485,21 @@ mod tests { .get(&expr_id) .unwrap(); - format!("//: {ty}") + format!("//: {}", self.debug_monotype(ty)) + } + + fn debug_monotype(&self, monotype: &Monotype) -> String { + match monotype { + Monotype::Integer => "int".to_string(), + Monotype::Bool => "bool".to_string(), + Monotype::Unit => "()".to_string(), + Monotype::Char => "char".to_string(), + Monotype::String => "string".to_string(), + Monotype::Variable(id) => format!("${}", id), + Monotype::Function(signature) => self.debug_signature(signature), + Monotype::Never => "!".to_string(), + Monotype::Unknown => "".to_string(), + } } fn debug_expr( @@ -768,6 +787,7 @@ mod tests { } --- + --- "#]], ); } @@ -787,6 +807,7 @@ mod tests { } --- + --- "#]], ); } @@ -806,6 +827,7 @@ mod tests { } --- + --- "#]], ); } @@ -825,6 +847,7 @@ mod tests { } --- + --- "#]], ); } @@ -844,6 +867,7 @@ mod tests { } --- + --- "#]], ); @@ -860,6 +884,7 @@ mod tests { } --- + --- "#]], ); } @@ -879,6 +904,7 @@ mod tests { } --- + --- "#]], ) } @@ -904,6 +930,7 @@ mod tests { } --- + --- "#]], ) } @@ -983,6 +1010,7 @@ mod tests { error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: / error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: + error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + + --- "#]], ); @@ -1001,6 +1029,7 @@ mod tests { --- error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + + --- "#]], ); } @@ -1041,6 +1070,7 @@ mod tests { error MismatchedUnary: expected_ty: bool, found_expr: 10, found_ty: int, op: ! error MismatchedUnary: expected_ty: bool, found_expr: "aaa", found_ty: string, op: ! error MismatchedUnary: expected_ty: bool, found_expr: 'a', found_ty: char, op: ! + --- "#]], ) } @@ -1064,6 +1094,7 @@ mod tests { } --- + --- "#]], ); } @@ -1085,6 +1116,7 @@ mod tests { } --- + --- "#]], ) } @@ -1108,6 +1140,7 @@ mod tests { } --- + --- "#]], ); @@ -1134,6 +1167,7 @@ mod tests { } --- + --- "#]], ); @@ -1160,6 +1194,7 @@ mod tests { } --- + --- "#]], ); } @@ -1180,6 +1215,7 @@ mod tests { --- error MismatchedReturnType: expected_ty: (), found_expr: 10, found_ty: () + --- "#]], ); @@ -1196,6 +1232,7 @@ mod tests { } --- + --- "#]], ); @@ -1222,6 +1259,7 @@ mod tests { } --- + --- "#]], ); } @@ -1250,6 +1288,7 @@ mod tests { --- error MismatchedReturnType: expected_ty: (), found_expr: { tail:{ tail:10 } }, found_ty: () + --- "#]], ); @@ -1274,6 +1313,7 @@ mod tests { } --- + --- "#]], ); @@ -1298,6 +1338,7 @@ mod tests { } --- + --- "#]], ); } @@ -1319,6 +1360,7 @@ mod tests { } --- + --- "#]], ); } @@ -1340,6 +1382,7 @@ mod tests { } --- + --- "#]], ); } @@ -1368,6 +1411,7 @@ mod tests { --- error MismatchedReturnType: expected_ty: (), found_expr: res + 30, found_ty: () + --- "#]], ); } @@ -1395,6 +1439,7 @@ mod tests { --- error MismaatchedSignature: expected_ty: string, signature: (bool, string) -> int, found_expr: "aaa", found_ty: bool error MismaatchedSignature: expected_ty: bool, signature: (bool, string) -> int, found_expr: true, found_ty: string + --- "#]], ); } @@ -1422,6 +1467,7 @@ mod tests { } --- + --- "#]], ); } @@ -1446,6 +1492,7 @@ mod tests { --- error MismatchedTypeOnlyIfBranch: then_branch_ty: (), then_branch: { tail:10 } + --- "#]], ); } @@ -1469,6 +1516,7 @@ mod tests { } --- + --- "#]], ); @@ -1487,6 +1535,7 @@ mod tests { } --- + --- "#]], ); } @@ -1516,6 +1565,7 @@ mod tests { --- error MismatchedTypeElseBranch: then_branch_ty: int, then_branch: { tail:10 }, else_branch_ty: string, else_branch: { tail:"aaa" } error MismatchedReturnType: expected_ty: (), found_expr: if true { tail:10 } else { tail:"aaa" }, found_ty: () + --- "#]], ); } @@ -1545,6 +1595,7 @@ mod tests { --- error MismatchedTypeIfCondition: expected_ty: bool, found_expr: 10, found_ty: int error MismatchedReturnType: expected_ty: (), found_expr: if 10 { tail:"aaa" } else { tail:"aaa" }, found_ty: () + --- "#]], ); } @@ -1577,6 +1628,7 @@ mod tests { --- error MismatchedTypeElseBranch: then_branch_ty: (), then_branch: { tail:none }, else_branch_ty: bool, else_branch: { tail:true } + --- "#]], ); @@ -1606,6 +1658,7 @@ mod tests { --- error MismatchedTypeElseBranch: then_branch_ty: bool, then_branch: { tail:true }, else_branch_ty: (), else_branch: { tail:none } + --- "#]], ); @@ -1634,6 +1687,7 @@ mod tests { } --- + --- "#]], ); } @@ -1654,6 +1708,7 @@ mod tests { --- error MismatchedReturnType: expected_ty: (), found_expr: return, found_ty: () + --- "#]], ); @@ -1671,6 +1726,7 @@ mod tests { --- error MismatchedReturnType: expected_ty: int, found_expr: return 10, found_ty: int + --- "#]], ); } @@ -1692,6 +1748,7 @@ mod tests { --- error MismatchedReturnType: expected_ty: int, found_ty: int error MismatchedReturnType: expected_ty: int, found_expr: return, found_ty: int + --- "#]], ); @@ -1710,6 +1767,7 @@ mod tests { --- error MismatchedReturnType: expected_ty: int, found_expr: "aaa", found_ty: int error MismatchedReturnType: expected_ty: int, found_expr: return "aaa", found_ty: int + --- "#]], ); } @@ -1762,6 +1820,7 @@ mod tests { } --- + --- "#]], ); } @@ -1809,6 +1868,7 @@ mod tests { } --- + --- "#]], ); } diff --git a/crates/hir_ty/src/testing.rs b/crates/hir_ty/src/testing.rs new file mode 100644 index 00000000..4170752c --- /dev/null +++ b/crates/hir_ty/src/testing.rs @@ -0,0 +1,45 @@ +use std::sync::{Arc, Mutex}; + +use salsa::DebugWithDb; + +/// SalsaのDB +#[derive(Default)] +#[salsa::db(hir::Jar, crate::Jar)] +pub struct TestingDatabase { + storage: salsa::Storage, + + /// テスト用のログ + logs: Option>>>, +} + +impl TestingDatabase { + #[cfg(test)] + #[allow(dead_code)] + pub fn enable_logging(self) -> Self { + assert!(self.logs.is_none()); + Self { + storage: self.storage, + logs: Some(Default::default()), + } + } + + #[cfg(test)] + #[allow(dead_code)] + pub fn take_logs(&mut self) -> Vec { + if let Some(logs) = &self.logs { + std::mem::take(&mut *logs.lock().unwrap()) + } else { + panic!("logs not enabled"); + } + } +} + +impl salsa::Database for TestingDatabase { + fn salsa_event(&self, event: salsa::Event) { + if let Some(logs) = &self.logs { + logs.lock() + .unwrap() + .push(format!("Event: {:?}", event.debug(self))); + } + } +} From 4e2e7f56dd5f7cca4d8bea9ce3993e90b711d364 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 10:31:57 +0900 Subject: [PATCH 15/35] wip --- crates/hir_ty/src/lib.rs | 120 +++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 8c1d19f9..399e25a1 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -162,10 +162,10 @@ mod tests { } => { msg.push_str( &format!( - "error MismatchedTypeIfCondition: expected_ty: {}, found_expr: {}, found_ty: {}", + "error MismatchedTypeIfCondition: expected_ty: {}, found_ty: {}, found_expr: {}", self.debug_monotype(expected_ty), - self.debug_simplify_expr(hir_file, *found_expr), self.debug_monotype(found_ty), + self.debug_simplify_expr(hir_file, *found_expr), )); } InferenceError::MismatchedTypeElseBranch { @@ -176,10 +176,10 @@ mod tests { } => { msg.push_str( &format!( - "error MismatchedTypeElseBranch: then_branch_ty: {}, then_branch: {}, else_branch_ty: {}, else_branch: {}", + "error MismatchedTypeElseBranch: then_branch_ty: {}, else_branch_ty: {}, then_branch: {}, else_branch: {}", self.debug_monotype(then_branch_ty), - self.debug_simplify_expr(hir_file, *then_branch), self.debug_monotype(else_branch_ty), + self.debug_simplify_expr(hir_file, *then_branch), self.debug_simplify_expr(hir_file, *else_branch), )); } @@ -202,11 +202,11 @@ mod tests { } => { msg.push_str( &format!( - "error MismaatchedSignature: expected_ty: {}, signature: {}, found_expr: {}, found_ty: {}", + "error MismaatchedSignature: expected_ty: {}, found_ty: {}, found_expr: {}, signature: {}", self.debug_monotype(expected_ty), - self.debug_signature(signature), - self.debug_simplify_expr(hir_file, *found_expr), self.debug_monotype(found_ty), + self.debug_simplify_expr(hir_file, *found_expr), + self.debug_signature(signature), )); } InferenceError::MismatchedBinaryInteger { @@ -217,11 +217,11 @@ mod tests { } => { msg.push_str( &format!( - "error MismatchedBinaryInteger: expected_ty: {}, found_expr: {}, found_ty: {}, op: {}", + "error MismatchedBinaryInteger: op: {}, expected_ty: {}, found_ty: {}, found_expr: {}", + self.debug_binary_op(op), self.debug_monotype(expected_ty), - self.debug_simplify_expr(hir_file, *found_expr), self.debug_monotype(found_ty), - self.debug_binary_op(op), + self.debug_simplify_expr(hir_file, *found_expr), )); } InferenceError::MismatchedBinaryCompare { @@ -233,12 +233,12 @@ mod tests { } => { msg.push_str( &format!( - "error MismatchedBinaryCompare: expected_ty: {}, expected_expr: {}, found_expr: {}, found_ty: {}, op: {}", + "error MismatchedBinaryCompare: op: {}, expected_ty: {}, found_ty: {}, expected_expr: {}, found_expr: {}", + self.debug_binary_op(op), self.debug_monotype(expected_ty), + self.debug_monotype(found_ty), self.debug_simplify_expr(hir_file, *expected_expr), self.debug_simplify_expr(hir_file, *found_expr), - self.debug_monotype(found_ty), - self.debug_binary_op(op), )); } InferenceError::MismatchedUnary { @@ -249,11 +249,11 @@ mod tests { } => { msg.push_str( &format!( - "error MismatchedUnary: expected_ty: {}, found_expr: {}, found_ty: {}, op: {}", + "error MismatchedUnary: op: {}, expected_ty: {}, found_ty: {}, found_expr: {}", + self.debug_unary_op(op), self.debug_monotype(expected_ty), - self.debug_simplify_expr(hir_file, *found_expr), self.debug_monotype(found_ty), - self.debug_unary_op(op), + self.debug_simplify_expr(hir_file, *found_expr), )); } InferenceError::MismatchedTypeReturnValue { @@ -264,10 +264,10 @@ mod tests { if let Some(found_expr) = found_expr { msg.push_str( &format!( - "error MismatchedReturnType: expected_ty: {}, found_expr: {}, found_ty: {}", + "error MismatchedReturnType: expected_ty: {}, found_ty: {}, found_expr: {}", self.debug_monotype(&expected_signature.return_type(self.db)), - self.debug_simplify_expr(hir_file, *found_expr), self.debug_monotype(found_ty), + self.debug_simplify_expr(hir_file, *found_expr), )); } else { msg.push_str(&format!( @@ -992,24 +992,24 @@ mod tests { } --- - error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + - error MismatchedBinaryInteger: expected_ty: int, found_expr: "bbb", found_ty: string, op: + - error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + - error MismatchedBinaryInteger: expected_ty: int, found_expr: 'a', found_ty: char, op: + - error MismatchedBinaryInteger: expected_ty: int, found_expr: 'a', found_ty: char, op: + - error MismatchedBinaryInteger: expected_ty: int, found_expr: 'a', found_ty: char, op: + - error MismatchedBinaryCompare: expected_ty: int, expected_expr: 10, found_expr: 'a', found_ty: char, op: < - error MismatchedBinaryCompare: expected_ty: int, expected_expr: 10, found_expr: 'a', found_ty: char, op: > - error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: + - error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: + - error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: - - error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: - - error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: * - error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: * - error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: / - error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: / - error MismatchedBinaryInteger: expected_ty: int, found_expr: true, found_ty: bool, op: + - error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: "aaa" + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: "bbb" + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: "aaa" + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: char, found_expr: 'a' + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: char, found_expr: 'a' + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: char, found_expr: 'a' + error MismatchedBinaryCompare: op: <, expected_ty: int, found_ty: char, expected_expr: 10, found_expr: 'a' + error MismatchedBinaryCompare: op: >, expected_ty: int, found_ty: char, expected_expr: 10, found_expr: 'a' + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: bool, found_expr: true + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: bool, found_expr: true + error MismatchedBinaryInteger: op: -, expected_ty: int, found_ty: bool, found_expr: true + error MismatchedBinaryInteger: op: -, expected_ty: int, found_ty: bool, found_expr: true + error MismatchedBinaryInteger: op: *, expected_ty: int, found_ty: bool, found_expr: true + error MismatchedBinaryInteger: op: *, expected_ty: int, found_ty: bool, found_expr: true + error MismatchedBinaryInteger: op: /, expected_ty: int, found_ty: bool, found_expr: true + error MismatchedBinaryInteger: op: /, expected_ty: int, found_ty: bool, found_expr: true + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: bool, found_expr: true + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: "aaa" --- "#]], ); @@ -1027,8 +1027,8 @@ mod tests { } --- - error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + - error MismatchedBinaryInteger: expected_ty: int, found_expr: "aaa", found_ty: string, op: + + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: "aaa" + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: "aaa" --- "#]], ); @@ -1064,12 +1064,12 @@ mod tests { } --- - error MismatchedUnary: expected_ty: int, found_expr: "aaa", found_ty: string, op: - - error MismatchedUnary: expected_ty: int, found_expr: 'a', found_ty: char, op: - - error MismatchedUnary: expected_ty: int, found_expr: true, found_ty: bool, op: - - error MismatchedUnary: expected_ty: bool, found_expr: 10, found_ty: int, op: ! - error MismatchedUnary: expected_ty: bool, found_expr: "aaa", found_ty: string, op: ! - error MismatchedUnary: expected_ty: bool, found_expr: 'a', found_ty: char, op: ! + error MismatchedUnary: op: -, expected_ty: int, found_ty: string, found_expr: "aaa" + error MismatchedUnary: op: -, expected_ty: int, found_ty: char, found_expr: 'a' + error MismatchedUnary: op: -, expected_ty: int, found_ty: bool, found_expr: true + error MismatchedUnary: op: !, expected_ty: bool, found_ty: int, found_expr: 10 + error MismatchedUnary: op: !, expected_ty: bool, found_ty: string, found_expr: "aaa" + error MismatchedUnary: op: !, expected_ty: bool, found_ty: char, found_expr: 'a' --- "#]], ) @@ -1214,7 +1214,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_expr: 10, found_ty: () + error MismatchedReturnType: expected_ty: (), found_ty: (), found_expr: 10 --- "#]], ); @@ -1287,7 +1287,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_expr: { tail:{ tail:10 } }, found_ty: () + error MismatchedReturnType: expected_ty: (), found_ty: (), found_expr: { tail:{ tail:10 } } --- "#]], ); @@ -1410,7 +1410,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_expr: res + 30, found_ty: () + error MismatchedReturnType: expected_ty: (), found_ty: (), found_expr: res + 30 --- "#]], ); @@ -1437,8 +1437,8 @@ mod tests { } --- - error MismaatchedSignature: expected_ty: string, signature: (bool, string) -> int, found_expr: "aaa", found_ty: bool - error MismaatchedSignature: expected_ty: bool, signature: (bool, string) -> int, found_expr: true, found_ty: string + error MismaatchedSignature: expected_ty: string, found_ty: bool, found_expr: "aaa", signature: (bool, string) -> int + error MismaatchedSignature: expected_ty: bool, found_ty: string, found_expr: true, signature: (bool, string) -> int --- "#]], ); @@ -1563,8 +1563,8 @@ mod tests { } --- - error MismatchedTypeElseBranch: then_branch_ty: int, then_branch: { tail:10 }, else_branch_ty: string, else_branch: { tail:"aaa" } - error MismatchedReturnType: expected_ty: (), found_expr: if true { tail:10 } else { tail:"aaa" }, found_ty: () + error MismatchedTypeElseBranch: then_branch_ty: int, else_branch_ty: string, then_branch: { tail:10 }, else_branch: { tail:"aaa" } + error MismatchedReturnType: expected_ty: (), found_ty: (), found_expr: if true { tail:10 } else { tail:"aaa" } --- "#]], ); @@ -1593,8 +1593,8 @@ mod tests { } --- - error MismatchedTypeIfCondition: expected_ty: bool, found_expr: 10, found_ty: int - error MismatchedReturnType: expected_ty: (), found_expr: if 10 { tail:"aaa" } else { tail:"aaa" }, found_ty: () + error MismatchedTypeIfCondition: expected_ty: bool, found_ty: int, found_expr: 10 + error MismatchedReturnType: expected_ty: (), found_ty: (), found_expr: if 10 { tail:"aaa" } else { tail:"aaa" } --- "#]], ); @@ -1627,7 +1627,7 @@ mod tests { } --- - error MismatchedTypeElseBranch: then_branch_ty: (), then_branch: { tail:none }, else_branch_ty: bool, else_branch: { tail:true } + error MismatchedTypeElseBranch: then_branch_ty: (), else_branch_ty: bool, then_branch: { tail:none }, else_branch: { tail:true } --- "#]], ); @@ -1657,7 +1657,7 @@ mod tests { } --- - error MismatchedTypeElseBranch: then_branch_ty: bool, then_branch: { tail:true }, else_branch_ty: (), else_branch: { tail:none } + error MismatchedTypeElseBranch: then_branch_ty: bool, else_branch_ty: (), then_branch: { tail:true }, else_branch: { tail:none } --- "#]], ); @@ -1707,7 +1707,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_expr: return, found_ty: () + error MismatchedReturnType: expected_ty: (), found_ty: (), found_expr: return --- "#]], ); @@ -1725,7 +1725,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: int, found_expr: return 10, found_ty: int + error MismatchedReturnType: expected_ty: int, found_ty: int, found_expr: return 10 --- "#]], ); @@ -1747,7 +1747,7 @@ mod tests { --- error MismatchedReturnType: expected_ty: int, found_ty: int - error MismatchedReturnType: expected_ty: int, found_expr: return, found_ty: int + error MismatchedReturnType: expected_ty: int, found_ty: int, found_expr: return --- "#]], ); @@ -1765,8 +1765,8 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: int, found_expr: "aaa", found_ty: int - error MismatchedReturnType: expected_ty: int, found_expr: return "aaa", found_ty: int + error MismatchedReturnType: expected_ty: int, found_ty: int, found_expr: "aaa" + error MismatchedReturnType: expected_ty: int, found_ty: int, found_expr: return "aaa" --- "#]], ); From ee539a7cdd87a7a1d45e811e121a0518e6edcc31 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 10:32:27 +0900 Subject: [PATCH 16/35] wip --- crates/hir_ty/src/inference/environment.rs | 2 +- crates/hir_ty/src/inference/type_scheme.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index a379b071..66f3c0a5 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -127,7 +127,7 @@ impl<'a> InferBody<'a> { hir::Symbol::Local { name, expr: _ } => { let ty_scheme = self.current_scope().bindings.get(name).cloned(); if let Some(ty_scheme) = ty_scheme { - ty_scheme.instantiate(self.db, &mut self.cxt) + ty_scheme.instantiate(&mut self.cxt) } else { panic!("Unbound variable {symbol:?}"); } diff --git a/crates/hir_ty/src/inference/type_scheme.rs b/crates/hir_ty/src/inference/type_scheme.rs index 2f0fb855..90fc77ea 100644 --- a/crates/hir_ty/src/inference/type_scheme.rs +++ b/crates/hir_ty/src/inference/type_scheme.rs @@ -30,7 +30,7 @@ impl TypeScheme { } /// 具体的な型を生成する - pub fn instantiate(&self, db: &dyn HirTyMasterDatabase, cxt: &mut Context) -> Monotype { + pub fn instantiate(&self, cxt: &mut Context) -> Monotype { let new_vars = self .variables .iter() From ab6b0a9caa85bfe32dc0459bfc972e9540c25539 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 13:20:31 +0900 Subject: [PATCH 17/35] wip --- crates/hir_ty/src/inference/environment.rs | 38 +++---- crates/hir_ty/src/inference/error.rs | 33 +++--- crates/hir_ty/src/inference/type_unifier.rs | 117 ++++++++++---------- crates/hir_ty/src/lib.rs | 64 +++++------ 4 files changed, 125 insertions(+), 127 deletions(-) diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index 66f3c0a5..d4e3e681 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -65,21 +65,21 @@ impl<'a> InferBody<'a> { if let Some(tail) = &body.tail { let ty = self.infer_expr(*tail); self.unifier.unify( - &ty, &self.signature.return_type(self.db), - &UnifyPurpose::SelfReturnType { + &ty, + &UnifyPurpose::ReturnValue { expected_signature: self.signature, - found_expr: Some(*tail), + found_return_expr: Some(*tail), }, ); } else { let ty = Monotype::Unit; self.unifier.unify( - &ty, &self.signature.return_type(self.db), - &UnifyPurpose::SelfReturnType { + &ty, + &UnifyPurpose::ReturnValue { expected_signature: self.signature, - found_expr: None, + found_return_expr: None, }, ); }; @@ -202,17 +202,17 @@ impl<'a> InferBody<'a> { // TODO: 引数の数が異なるエラーを追加 Monotype::Unknown } else { - for ((call_arg, call_arg_ty), arg) in call_args + for ((call_arg, call_arg_ty), signature_arg_ty) in call_args .iter() .zip(call_args_ty) .zip(signature.params(self.db)) { self.unifier.unify( + signature_arg_ty, &call_arg_ty, - arg, &UnifyPurpose::CallArg { found_arg: *call_arg, - expected_signature: signature, + callee_signature: signature, }, ); } @@ -257,8 +257,8 @@ impl<'a> InferBody<'a> { &lhs_ty, &rhs_ty, &UnifyPurpose::BinaryCompare { - expected_expr: *lhs, - found_expr: *rhs, + found_compare_from_expr: *lhs, + found_compare_to_expr: *rhs, op: op.clone(), }, ); @@ -322,7 +322,7 @@ impl<'a> InferBody<'a> { &Monotype::Bool, &condition_ty, &UnifyPurpose::IfCondition { - found_expr: *condition, + found_condition_expr: *condition, }, ); @@ -333,17 +333,17 @@ impl<'a> InferBody<'a> { &then_ty, &else_ty, &UnifyPurpose::IfThenElseBranch { - expected_expr: *then_branch, - found_expr: *else_branch, + found_then_branch_expr: *then_branch, + found_else_branch_expr: *else_branch, }, ); } else { // elseブランチがない場合は Unit として扱う self.unifier.unify( - &then_ty, &Monotype::Unit, + &then_ty, &UnifyPurpose::IfThenOnlyBranch { - found_expr: *then_branch, + found_then_branch_expr: *then_branch, }, ); } @@ -354,11 +354,11 @@ impl<'a> InferBody<'a> { if let Some(return_value) = value { let ty = self.infer_expr(*return_value); self.unifier.unify( - &ty, &self.signature.return_type(self.db), + &ty, &UnifyPurpose::ReturnValue { expected_signature: self.signature, - found_expr: Some(*return_value), + found_return_expr: Some(*return_value), }, ); } else { @@ -368,7 +368,7 @@ impl<'a> InferBody<'a> { &self.signature.return_type(self.db), &UnifyPurpose::ReturnValue { expected_signature: self.signature, - found_expr: None, + found_return_expr: None, }, ); } diff --git a/crates/hir_ty/src/inference/error.rs b/crates/hir_ty/src/inference/error.rs index 0a431bde..beb68922 100644 --- a/crates/hir_ty/src/inference/error.rs +++ b/crates/hir_ty/src/inference/error.rs @@ -7,22 +7,21 @@ pub enum InferenceError { MismatchedTypes { /// 期待される型 expected_ty: Monotype, - /// 実際の型 - found_ty: Monotype, - /// 期待される式 expected_expr: hir::ExprId, + /// 実際の型 + found_ty: Monotype, /// 実際の式 found_expr: hir::ExprId, }, /// Ifの条件式の型が一致しない MismatchedTypeIfCondition { /// 期待される型 - expected_ty: Monotype, - /// 実際の式 - found_expr: hir::ExprId, + expected_condition_bool_ty: Monotype, /// 実際の型 - found_ty: Monotype, + found_condition_ty: Monotype, + /// 実際の式 + found_condition_expr: hir::ExprId, }, /// Ifのthenブランチとelseブランチの型が一致しない MismatchedTypeElseBranch { @@ -41,6 +40,8 @@ pub enum InferenceError { then_branch_ty: Monotype, /// thenブランチの式 then_branch: hir::ExprId, + /// elseブランチの型 + else_branch_unit_ty: Monotype, }, /// 関数呼び出しの引数の数が一致しない MismaatchedSignature { @@ -55,7 +56,7 @@ pub enum InferenceError { }, MismatchedBinaryInteger { /// 期待される型 - expected_ty: Monotype, + expected_int_ty: Monotype, /// 実際の式 found_expr: hir::ExprId, /// 実際の型 @@ -65,23 +66,23 @@ pub enum InferenceError { }, MismatchedBinaryCompare { /// 期待される型 - expected_ty: Monotype, + compare_from_ty: Monotype, /// 期待される型を持つ式 - expected_expr: hir::ExprId, - /// 実際の式 - found_expr: hir::ExprId, + compare_from_expr: hir::ExprId, /// 実際の型 - found_ty: Monotype, + compare_to_ty: Monotype, + /// 実際の式 + compare_to_expr: hir::ExprId, /// 演算子 op: ast::BinaryOp, }, MismatchedUnary { /// 期待される型 expected_ty: Monotype, - /// 実際の式 - found_expr: hir::ExprId, /// 実際の型 found_ty: Monotype, + /// 実際の式 + found_expr: hir::ExprId, /// 演算子 op: ast::UnaryOp, }, @@ -96,6 +97,6 @@ pub enum InferenceError { /// 実際の型 found_ty: Monotype, /// 実際の式 - found_expr: Option, + found_return_expr: Option, }, } diff --git a/crates/hir_ty/src/inference/type_unifier.rs b/crates/hir_ty/src/inference/type_unifier.rs index 997418f0..45d68857 100644 --- a/crates/hir_ty/src/inference/type_unifier.rs +++ b/crates/hir_ty/src/inference/type_unifier.rs @@ -9,56 +9,56 @@ pub(crate) struct TypeUnifier { } /// 型推論の目的 +/// +/// エラーを具体的に収集するために使用されます。 pub(crate) enum UnifyPurpose { CallArg { /// 関数呼び出し対象のシグネチャ - expected_signature: Signature, - /// 実際に得られた型を持つ式 + callee_signature: Signature, + /// 引数の式 found_arg: hir::ExprId, }, - SelfReturnType { - expected_signature: Signature, - found_expr: Option, - }, BinaryInteger { - /// 実際に得られた型を持つ式 + /// 数値演算子の対象式 found_expr: hir::ExprId, /// 演算子 op: ast::BinaryOp, }, BinaryCompare { - /// 期待する型を持つ式 - expected_expr: hir::ExprId, - /// 実際に得られた型を持つ式 - found_expr: hir::ExprId, + /// 比較元の式 + found_compare_from_expr: hir::ExprId, + /// 比較先の式 + found_compare_to_expr: hir::ExprId, /// 演算子 op: ast::BinaryOp, }, Unary { - /// 実際に得られた型を持つ式 + /// 単行演算子の対象式 found_expr: hir::ExprId, /// 演算子 op: ast::UnaryOp, }, IfCondition { - /// 実際に得られた型を持つ式 - found_expr: hir::ExprId, + /// If条件の式 + found_condition_expr: hir::ExprId, }, IfThenElseBranch { - /// 期待する型を持つ式 - expected_expr: hir::ExprId, - /// 実際に得られた型を持つ式 - found_expr: hir::ExprId, + /// Thenブランチの式 + found_then_branch_expr: hir::ExprId, + /// Elseブランチの式 + found_else_branch_expr: hir::ExprId, }, IfThenOnlyBranch { - /// 実際に得られた型を持つ式 - found_expr: hir::ExprId, + /// Thenブランチの式 + found_then_branch_expr: hir::ExprId, }, ReturnValue { - /// 期待する戻り値の型を持つ関数シグネチャ + /// 期待する関数のシグネチャ expected_signature: Signature, - /// 実際に得られた型を持つ式 - found_expr: Option, + /// 戻り値の式 + /// + /// Noneの場合は`return;`のように指定なしを表す + found_return_expr: Option, }, } @@ -71,36 +71,28 @@ fn build_unify_error_from_unify_purpose( match purpose { UnifyPurpose::CallArg { found_arg, - expected_signature, + callee_signature: expected_signature, } => InferenceError::MismaatchedSignature { expected_ty, found_ty, signature: *expected_signature, found_expr: *found_arg, }, - UnifyPurpose::SelfReturnType { - expected_signature, - found_expr, - } => InferenceError::MismatchedTypeReturnValue { - expected_signature: *expected_signature, - found_ty, - found_expr: *found_expr, - }, UnifyPurpose::BinaryInteger { found_expr, op } => InferenceError::MismatchedBinaryInteger { - expected_ty, + expected_int_ty: expected_ty, found_ty, found_expr: *found_expr, op: op.clone(), }, UnifyPurpose::BinaryCompare { - expected_expr, - found_expr, + found_compare_from_expr, + found_compare_to_expr, op, } => InferenceError::MismatchedBinaryCompare { - expected_ty, - found_ty, - expected_expr: *expected_expr, - found_expr: *found_expr, + compare_from_ty: expected_ty, + compare_to_ty: found_ty, + compare_from_expr: *found_compare_from_expr, + compare_to_expr: *found_compare_to_expr, op: op.clone(), }, UnifyPurpose::Unary { found_expr, op } => InferenceError::MismatchedUnary { @@ -109,32 +101,35 @@ fn build_unify_error_from_unify_purpose( found_expr: *found_expr, op: op.clone(), }, - UnifyPurpose::IfCondition { found_expr } => InferenceError::MismatchedTypeIfCondition { - expected_ty, - found_ty, - found_expr: *found_expr, + UnifyPurpose::IfCondition { + found_condition_expr, + } => InferenceError::MismatchedTypeIfCondition { + expected_condition_bool_ty: expected_ty, + found_condition_ty: found_ty, + found_condition_expr: *found_condition_expr, }, UnifyPurpose::IfThenElseBranch { - expected_expr, - found_expr, + found_then_branch_expr, + found_else_branch_expr, } => InferenceError::MismatchedTypeElseBranch { then_branch_ty: expected_ty, - then_branch: *expected_expr, + then_branch: *found_then_branch_expr, else_branch_ty: found_ty, - else_branch: *found_expr, + else_branch: *found_else_branch_expr, + }, + UnifyPurpose::IfThenOnlyBranch { + found_then_branch_expr, + } => InferenceError::MismatchedTypeOnlyIfBranch { + then_branch_ty: found_ty, + then_branch: *found_then_branch_expr, + else_branch_unit_ty: expected_ty, }, - UnifyPurpose::IfThenOnlyBranch { found_expr } => { - InferenceError::MismatchedTypeOnlyIfBranch { - then_branch_ty: found_ty, - then_branch: *found_expr, - } - } UnifyPurpose::ReturnValue { expected_signature, - found_expr, + found_return_expr, } => InferenceError::MismatchedTypeReturnValue { found_ty, - found_expr: *found_expr, + found_return_expr: *found_return_expr, expected_signature: *expected_signature, }, } @@ -155,9 +150,9 @@ impl TypeUnifier { } } - pub fn unify(&mut self, a: &Monotype, b: &Monotype, purpose: &UnifyPurpose) { - let a_rep = self.find(a); - let b_rep = self.find(b); + pub fn unify(&mut self, a_expected: &Monotype, b_actual: &Monotype, purpose: &UnifyPurpose) { + let a_rep = self.find(a_expected); + let b_rep = self.find(b_actual); if a_rep == b_rep { return; @@ -177,8 +172,8 @@ impl TypeUnifier { // self.unify(a_arg, b_arg, purpose); // } } - (Monotype::Variable(_), b_rep) => self.unify_var(&a_rep, b_rep), - (a_rep, Monotype::Variable(_)) => self.unify_var(&b_rep, a_rep), + (Monotype::Variable(_), b_rep) => self.unify_var(b_rep, &a_rep), + (a_rep, Monotype::Variable(_)) => self.unify_var(a_rep, &b_rep), (_, _) => { self.errors .push(build_unify_error_from_unify_purpose(a_rep, b_rep, purpose)); @@ -186,7 +181,7 @@ impl TypeUnifier { } } - fn unify_var(&mut self, type_var: &Monotype, term: &Monotype) { + fn unify_var(&mut self, term: &Monotype, type_var: &Monotype) { assert!(matches!(type_var, Monotype::Variable(_))); let value = Some(Box::new(self.nodes.get(term).unwrap().clone())); diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 399e25a1..adca3436 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -156,16 +156,16 @@ mod tests { )); } InferenceError::MismatchedTypeIfCondition { - expected_ty, - found_expr, - found_ty, + expected_condition_bool_ty, + found_condition_expr, + found_condition_ty, } => { msg.push_str( &format!( "error MismatchedTypeIfCondition: expected_ty: {}, found_ty: {}, found_expr: {}", - self.debug_monotype(expected_ty), - self.debug_monotype(found_ty), - self.debug_simplify_expr(hir_file, *found_expr), + self.debug_monotype(expected_condition_bool_ty), + self.debug_monotype(found_condition_ty), + self.debug_simplify_expr(hir_file, *found_condition_expr), )); } InferenceError::MismatchedTypeElseBranch { @@ -186,11 +186,13 @@ mod tests { InferenceError::MismatchedTypeOnlyIfBranch { then_branch_ty, then_branch, + else_branch_unit_ty, } => { msg.push_str( &format!( - "error MismatchedTypeOnlyIfBranch: then_branch_ty: {}, then_branch: {}", + "error MismatchedTypeOnlyIfBranch: then_branch_ty: {}, else_branch_ty: {}, then_branch: {}", self.debug_monotype(then_branch_ty), + self.debug_monotype(else_branch_unit_ty), self.debug_simplify_expr(hir_file, *then_branch), )); } @@ -210,7 +212,7 @@ mod tests { )); } InferenceError::MismatchedBinaryInteger { - expected_ty, + expected_int_ty, found_expr, found_ty, op, @@ -219,26 +221,26 @@ mod tests { &format!( "error MismatchedBinaryInteger: op: {}, expected_ty: {}, found_ty: {}, found_expr: {}", self.debug_binary_op(op), - self.debug_monotype(expected_ty), + self.debug_monotype(expected_int_ty), self.debug_monotype(found_ty), self.debug_simplify_expr(hir_file, *found_expr), )); } InferenceError::MismatchedBinaryCompare { - expected_ty, - expected_expr, - found_expr, - found_ty, + compare_from_ty, + compare_from_expr, + compare_to_expr, + compare_to_ty, op, } => { msg.push_str( &format!( "error MismatchedBinaryCompare: op: {}, expected_ty: {}, found_ty: {}, expected_expr: {}, found_expr: {}", self.debug_binary_op(op), - self.debug_monotype(expected_ty), - self.debug_monotype(found_ty), - self.debug_simplify_expr(hir_file, *expected_expr), - self.debug_simplify_expr(hir_file, *found_expr), + self.debug_monotype(compare_from_ty), + self.debug_monotype(compare_to_ty), + self.debug_simplify_expr(hir_file, *compare_from_expr), + self.debug_simplify_expr(hir_file, *compare_to_expr), )); } InferenceError::MismatchedUnary { @@ -259,7 +261,7 @@ mod tests { InferenceError::MismatchedTypeReturnValue { expected_signature, found_ty, - found_expr, + found_return_expr: found_expr, } => { if let Some(found_expr) = found_expr { msg.push_str( @@ -1214,7 +1216,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_ty: (), found_expr: 10 + error MismatchedReturnType: expected_ty: (), found_ty: int, found_expr: 10 --- "#]], ); @@ -1287,7 +1289,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_ty: (), found_expr: { tail:{ tail:10 } } + error MismatchedReturnType: expected_ty: (), found_ty: int, found_expr: { tail:{ tail:10 } } --- "#]], ); @@ -1410,7 +1412,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_ty: (), found_expr: res + 30 + error MismatchedReturnType: expected_ty: (), found_ty: int, found_expr: res + 30 --- "#]], ); @@ -1437,8 +1439,8 @@ mod tests { } --- - error MismaatchedSignature: expected_ty: string, found_ty: bool, found_expr: "aaa", signature: (bool, string) -> int - error MismaatchedSignature: expected_ty: bool, found_ty: string, found_expr: true, signature: (bool, string) -> int + error MismaatchedSignature: expected_ty: bool, found_ty: string, found_expr: "aaa", signature: (bool, string) -> int + error MismaatchedSignature: expected_ty: string, found_ty: bool, found_expr: true, signature: (bool, string) -> int --- "#]], ); @@ -1491,7 +1493,7 @@ mod tests { } --- - error MismatchedTypeOnlyIfBranch: then_branch_ty: (), then_branch: { tail:10 } + error MismatchedTypeOnlyIfBranch: then_branch_ty: int, else_branch_ty: (), then_branch: { tail:10 } --- "#]], ); @@ -1564,7 +1566,7 @@ mod tests { --- error MismatchedTypeElseBranch: then_branch_ty: int, else_branch_ty: string, then_branch: { tail:10 }, else_branch: { tail:"aaa" } - error MismatchedReturnType: expected_ty: (), found_ty: (), found_expr: if true { tail:10 } else { tail:"aaa" } + error MismatchedReturnType: expected_ty: (), found_ty: int, found_expr: if true { tail:10 } else { tail:"aaa" } --- "#]], ); @@ -1594,7 +1596,7 @@ mod tests { --- error MismatchedTypeIfCondition: expected_ty: bool, found_ty: int, found_expr: 10 - error MismatchedReturnType: expected_ty: (), found_ty: (), found_expr: if 10 { tail:"aaa" } else { tail:"aaa" } + error MismatchedReturnType: expected_ty: (), found_ty: string, found_expr: if 10 { tail:"aaa" } else { tail:"aaa" } --- "#]], ); @@ -1707,7 +1709,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_ty: (), found_expr: return + error MismatchedReturnType: expected_ty: (), found_ty: !, found_expr: return --- "#]], ); @@ -1725,7 +1727,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: int, found_ty: int, found_expr: return 10 + error MismatchedReturnType: expected_ty: int, found_ty: !, found_expr: return 10 --- "#]], ); @@ -1747,7 +1749,7 @@ mod tests { --- error MismatchedReturnType: expected_ty: int, found_ty: int - error MismatchedReturnType: expected_ty: int, found_ty: int, found_expr: return + error MismatchedReturnType: expected_ty: int, found_ty: !, found_expr: return --- "#]], ); @@ -1765,8 +1767,8 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: int, found_ty: int, found_expr: "aaa" - error MismatchedReturnType: expected_ty: int, found_ty: int, found_expr: return "aaa" + error MismatchedReturnType: expected_ty: int, found_ty: string, found_expr: "aaa" + error MismatchedReturnType: expected_ty: int, found_ty: !, found_expr: return "aaa" --- "#]], ); From 0e1c2f48615caae4aef53d0b1317f6eb86c6e374 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 13:49:31 +0900 Subject: [PATCH 18/35] wip --- crates/hir_ty/src/inference/environment.rs | 2 +- crates/hir_ty/src/lib.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index d4e3e681..8b0b6a9c 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -364,8 +364,8 @@ impl<'a> InferBody<'a> { } else { // 何も指定しない場合は Unit を返すものとして扱う self.unifier.unify( - &Monotype::Unit, &self.signature.return_type(self.db), + &Monotype::Unit, &UnifyPurpose::ReturnValue { expected_signature: self.signature, found_return_expr: None, diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index adca3436..fd0b48cd 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -1748,8 +1748,8 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: int, found_ty: int error MismatchedReturnType: expected_ty: int, found_ty: !, found_expr: return + error MismatchedReturnType: expected_ty: int, found_ty: () --- "#]], ); From b3dc777a3c92c11be0afb8e8213ed274743a684c Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 13:49:51 +0900 Subject: [PATCH 19/35] wip --- crates/hir_ty/src/inference/type_unifier.rs | 3 +++ crates/hir_ty/src/lib.rs | 10 +++------- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/crates/hir_ty/src/inference/type_unifier.rs b/crates/hir_ty/src/inference/type_unifier.rs index 45d68857..8b279914 100644 --- a/crates/hir_ty/src/inference/type_unifier.rs +++ b/crates/hir_ty/src/inference/type_unifier.rs @@ -174,6 +174,9 @@ impl TypeUnifier { } (Monotype::Variable(_), b_rep) => self.unify_var(b_rep, &a_rep), (a_rep, Monotype::Variable(_)) => self.unify_var(a_rep, &b_rep), + // 実際の型がNever型の場合は、期待する型がなんであれ到達しないので型チェックする必要がない + // 期待する型がNever型の場合は、、値が渡ってはいけないので型チェックするべき。途中でreturnした場合など改善の余地があると思われる。 + (_, Monotype::Never) => (), (_, _) => { self.errors .push(build_unify_error_from_unify_purpose(a_rep, b_rep, purpose)); diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index fd0b48cd..bbfe5fca 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -261,15 +261,15 @@ mod tests { InferenceError::MismatchedTypeReturnValue { expected_signature, found_ty, - found_return_expr: found_expr, + found_return_expr, } => { - if let Some(found_expr) = found_expr { + if let Some(found_return_expr) = found_return_expr { msg.push_str( &format!( "error MismatchedReturnType: expected_ty: {}, found_ty: {}, found_expr: {}", self.debug_monotype(&expected_signature.return_type(self.db)), self.debug_monotype(found_ty), - self.debug_simplify_expr(hir_file, *found_expr), + self.debug_simplify_expr(hir_file, *found_return_expr), )); } else { msg.push_str(&format!( @@ -1709,7 +1709,6 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_ty: !, found_expr: return --- "#]], ); @@ -1727,7 +1726,6 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: int, found_ty: !, found_expr: return 10 --- "#]], ); @@ -1748,7 +1746,6 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: int, found_ty: !, found_expr: return error MismatchedReturnType: expected_ty: int, found_ty: () --- "#]], @@ -1768,7 +1765,6 @@ mod tests { --- error MismatchedReturnType: expected_ty: int, found_ty: string, found_expr: "aaa" - error MismatchedReturnType: expected_ty: int, found_ty: !, found_expr: return "aaa" --- "#]], ); From 4cc3b4e2794a6efd8bf1058602afd3ae79a9b81d Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 13:57:35 +0900 Subject: [PATCH 20/35] wip --- crates/hir_ty/src/inference/environment.rs | 6 ++++-- crates/hir_ty/src/inference/error.rs | 18 ++++++++++-------- crates/hir_ty/src/inference/type_unifier.rs | 12 ++++++++---- crates/hir_ty/src/lib.rs | 16 +++++++++------- 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index 8b0b6a9c..9dd8a086 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -202,8 +202,9 @@ impl<'a> InferBody<'a> { // TODO: 引数の数が異なるエラーを追加 Monotype::Unknown } else { - for ((call_arg, call_arg_ty), signature_arg_ty) in call_args + for (((arg_pos, call_arg), call_arg_ty), signature_arg_ty) in call_args .iter() + .enumerate() .zip(call_args_ty) .zip(signature.params(self.db)) { @@ -211,8 +212,9 @@ impl<'a> InferBody<'a> { signature_arg_ty, &call_arg_ty, &UnifyPurpose::CallArg { - found_arg: *call_arg, + found_arg_expr: *call_arg, callee_signature: signature, + arg_pos, }, ); } diff --git a/crates/hir_ty/src/inference/error.rs b/crates/hir_ty/src/inference/error.rs index beb68922..658d1b82 100644 --- a/crates/hir_ty/src/inference/error.rs +++ b/crates/hir_ty/src/inference/error.rs @@ -43,16 +43,18 @@ pub enum InferenceError { /// elseブランチの型 else_branch_unit_ty: Monotype, }, - /// 関数呼び出しの引数の数が一致しない - MismaatchedSignature { - /// 期待される型 + /// 関数呼び出しの型が一致しない + MismaatchedTypeCallArg { + /// 期待される引数の型 expected_ty: Monotype, - /// 呼び出そうとしている関数のシグネチャ - signature: Signature, - /// 実際の式 - found_expr: hir::ExprId, - /// 実際の型 + /// 実際の引数の型 found_ty: Monotype, + /// 期待される引数を持つ関数シグネチャ + expected_signature: Signature, + /// 実際の引数式 + found_expr: hir::ExprId, + /// 実際の引数の位置(0-indexed) + arg_pos: usize, }, MismatchedBinaryInteger { /// 期待される型 diff --git a/crates/hir_ty/src/inference/type_unifier.rs b/crates/hir_ty/src/inference/type_unifier.rs index 8b279914..efe51e19 100644 --- a/crates/hir_ty/src/inference/type_unifier.rs +++ b/crates/hir_ty/src/inference/type_unifier.rs @@ -16,7 +16,9 @@ pub(crate) enum UnifyPurpose { /// 関数呼び出し対象のシグネチャ callee_signature: Signature, /// 引数の式 - found_arg: hir::ExprId, + found_arg_expr: hir::ExprId, + /// 引数の位置 + arg_pos: usize, }, BinaryInteger { /// 数値演算子の対象式 @@ -70,13 +72,15 @@ fn build_unify_error_from_unify_purpose( ) -> InferenceError { match purpose { UnifyPurpose::CallArg { - found_arg, + found_arg_expr: found_arg, callee_signature: expected_signature, - } => InferenceError::MismaatchedSignature { + arg_pos, + } => InferenceError::MismaatchedTypeCallArg { expected_ty, found_ty, - signature: *expected_signature, + expected_signature: *expected_signature, found_expr: *found_arg, + arg_pos: *arg_pos, }, UnifyPurpose::BinaryInteger { found_expr, op } => InferenceError::MismatchedBinaryInteger { expected_int_ty: expected_ty, diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index bbfe5fca..3802600e 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -196,19 +196,21 @@ mod tests { self.debug_simplify_expr(hir_file, *then_branch), )); } - InferenceError::MismaatchedSignature { + InferenceError::MismaatchedTypeCallArg { expected_ty, - signature, - found_expr, found_ty, + expected_signature, + found_expr, + arg_pos, } => { msg.push_str( &format!( - "error MismaatchedSignature: expected_ty: {}, found_ty: {}, found_expr: {}, signature: {}", + "error MismaatchedSignature: expected_ty: {}, found_ty: {}, found_expr: {}, signature: {}, arg_pos: {}", self.debug_monotype(expected_ty), self.debug_monotype(found_ty), self.debug_simplify_expr(hir_file, *found_expr), - self.debug_signature(signature), + self.debug_signature(expected_signature), + arg_pos )); } InferenceError::MismatchedBinaryInteger { @@ -1439,8 +1441,8 @@ mod tests { } --- - error MismaatchedSignature: expected_ty: bool, found_ty: string, found_expr: "aaa", signature: (bool, string) -> int - error MismaatchedSignature: expected_ty: string, found_ty: bool, found_expr: true, signature: (bool, string) -> int + error MismaatchedSignature: expected_ty: bool, found_ty: string, found_expr: "aaa", signature: (bool, string) -> int, arg_pos: 0 + error MismaatchedSignature: expected_ty: string, found_ty: bool, found_expr: true, signature: (bool, string) -> int, arg_pos: 1 --- "#]], ); From b9b39abdbe00b548cba5ed9740aa40c0c8935c4e Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 14:08:56 +0900 Subject: [PATCH 21/35] wip --- crates/hir_ty/src/lib.rs | 7 ++++--- crates/mir/src/body.rs | 25 ++++++++++++------------- crates/mir/src/lib.rs | 31 ++++++++++++++++--------------- 3 files changed, 32 insertions(+), 31 deletions(-) diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 3802600e..052344e6 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -21,7 +21,8 @@ mod testing; pub use checker::{TypeCheckError, TypeCheckResult}; pub use db::{HirTyMasterDatabase, Jar}; -pub use inference::{InferenceBodyResult, InferenceResult, Signature}; +pub use inference::{InferenceBodyResult, InferenceError, InferenceResult, Monotype, Signature}; +pub use testing::TestingDatabase; /// HIRを元にTypedHIRを構築します。 pub fn lower_pods(db: &dyn HirTyMasterDatabase, pods: &hir::Pods) -> TyLowerResult { @@ -44,8 +45,8 @@ pub struct TyLowerResult { } impl TyLowerResult { /// 指定した関数の型を取得します。 - pub fn signature_by_function(&self, function_id: hir::Function) -> &Signature { - &self.inference_result.signature_by_function[&function_id] + pub fn signature_by_function(&self, function_id: hir::Function) -> Signature { + self.inference_result.signature_by_function[&function_id] } /// 指定した関数の型推論結果を取得します。 diff --git a/crates/mir/src/body.rs b/crates/mir/src/body.rs index c9b82ec2..2da9f863 100644 --- a/crates/mir/src/body.rs +++ b/crates/mir/src/body.rs @@ -8,7 +8,7 @@ use crate::{ }; pub(crate) struct FunctionLower<'a> { - db: &'a dyn hir::HirMasterDatabase, + db: &'a dyn hir_ty::HirTyMasterDatabase, resolution_map: &'a hir::ResolutionMap, hir_ty_result: &'a hir_ty::TyLowerResult, function_by_hir_function: &'a HashMap, @@ -31,7 +31,7 @@ pub(crate) struct FunctionLower<'a> { impl<'a> FunctionLower<'a> { pub(crate) fn new( - db: &'a dyn hir::HirMasterDatabase, + db: &'a dyn hir_ty::HirTyMasterDatabase, hir_file: hir::HirFile, resolution_map: &'a hir::ResolutionMap, hir_ty_result: &'a hir_ty::TyLowerResult, @@ -39,9 +39,9 @@ impl<'a> FunctionLower<'a> { function: hir::Function, ) -> Self { let mut locals = Arena::new(); - let signature = &hir_ty_result.signature_by_function(function); + let signature = hir_ty_result.signature_by_function(function); let return_local = locals.alloc(Local { - ty: signature.return_type, + ty: signature.return_type(db), idx: 0, }); @@ -83,21 +83,19 @@ impl<'a> FunctionLower<'a> { fn get_inference_by_function(&self, function: hir::Function) -> &hir_ty::InferenceBodyResult { self.hir_ty_result - .inference_result - .inference_by_body - .get(&function) + .inference_body_by_function(function) .unwrap() } fn alloc_local(&mut self, expr: hir::ExprId) -> Idx { - let ty = self.get_inference_by_function(self.function).type_by_expr[&expr]; - let local_idx = self.alloc_local_by_ty(ty); + let ty = &self.get_inference_by_function(self.function).type_by_expr[&expr]; + let local_idx = self.alloc_local_by_ty(ty.clone()); self.local_by_hir.insert(expr, local_idx); local_idx } - fn alloc_local_by_ty(&mut self, ty: hir_ty::ResolvedType) -> Idx { + fn alloc_local_by_ty(&mut self, ty: hir_ty::Monotype) -> Idx { let local = Local { ty, idx: self.local_idx, @@ -461,7 +459,8 @@ impl<'a> FunctionLower<'a> { let function_id = self.function_by_hir_function[&function]; let signature = self.hir_ty_result.signature_by_function(function); - let called_local = self.alloc_local_by_ty(signature.return_type); + let called_local = + self.alloc_local_by_ty(signature.return_type(self.db)); let dest_place = Place::Local(called_local); let target_bb = self.alloc_standard_bb(); @@ -525,10 +524,10 @@ impl<'a> FunctionLower<'a> { .function .params(self.db) .iter() - .zip(signature.params.iter()) + .zip(signature.params(self.db).iter()) { let param_idx = self.params.alloc(Param { - ty: *param_ty, + ty: param_ty.clone(), idx: self.local_idx, pos: param .data(self.hir_file.db(self.db)) diff --git a/crates/mir/src/lib.rs b/crates/mir/src/lib.rs index d4b111fd..053b69d6 100755 --- a/crates/mir/src/lib.rs +++ b/crates/mir/src/lib.rs @@ -12,18 +12,19 @@ //! 例えば、`if`式は、`then`ブロックと`else`ブロックを持つ基本ブロックに変換されます。 //! この構造は、LLVM IRの基本ブロックと制御フローに対応しており、LLVM IRへの変換が容易になります。 #![warn(missing_docs)] +#![feature(trait_upcasting)] mod body; use std::collections::HashMap; use body::FunctionLower; -use hir_ty::ResolvedType; +use hir_ty::Monotype; use la_arena::{Arena, Idx}; /// HIRとTyped HIRからMIRを構築する pub fn lower_pods( - db: &dyn hir::HirMasterDatabase, + db: &dyn hir_ty::HirTyMasterDatabase, pods: &hir::Pods, hir_ty_result: &hir_ty::TyLowerResult, ) -> LowerResult { @@ -181,7 +182,7 @@ pub struct Signature { #[derive(Debug)] pub struct Param { /// パラメータの型 - pub ty: ResolvedType, + pub ty: Monotype, /// パラメータのローカル変数のインデックス pub idx: u64, /// パラメータの位置 @@ -194,7 +195,7 @@ pub type ParamIdx = Idx; #[derive(Debug)] pub struct Local { /// ローカル変数の型 - pub ty: ResolvedType, + pub ty: Monotype, /// ローカル変数のインデックス pub idx: u64, } @@ -449,12 +450,11 @@ pub enum BasicBlockKind { #[cfg(test)] mod tests { use expect_test::{expect, Expect}; - use hir_ty::ResolvedType; use crate::lower_pods; fn check_pod_start_with_root_file(fixture: &str, expect: Expect) { - let db = hir::TestingDatabase::default(); + let db = hir_ty::TestingDatabase::default(); let mut source_db = hir::FixtureDatabase::new(&db, fixture); let pods = hir::parse_pods(&db, "/main.nail", &mut source_db); @@ -687,16 +687,17 @@ mod tests { .join(", ") } - fn debug_ty(ty: &ResolvedType) -> String { + fn debug_ty(ty: &hir_ty::Monotype) -> String { match ty { - ResolvedType::Unknown => "unknown", - ResolvedType::Integer => "int", - ResolvedType::String => "string", - ResolvedType::Char => "char", - ResolvedType::Bool => "bool", - ResolvedType::Unit => "()", - ResolvedType::Never => "!", - ResolvedType::Function(_) => todo!(), + hir_ty::Monotype::Unit => "()", + hir_ty::Monotype::Integer => "int", + hir_ty::Monotype::Bool => "bool", + hir_ty::Monotype::Char => "char", + hir_ty::Monotype::String => "string", + hir_ty::Monotype::Never => "!", + hir_ty::Monotype::Unknown => "unknown", + hir_ty::Monotype::Variable(_) => unreachable!(), + hir_ty::Monotype::Function(_) => todo!(), } .to_string() } From edd5b95087bb44a5989dcb5ea00d05cd539edc20 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 14:10:50 +0900 Subject: [PATCH 22/35] wip --- crates/codegen_llvm/src/body.rs | 17 +++++++++-------- crates/codegen_llvm/src/lib.rs | 20 ++++++++++---------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/crates/codegen_llvm/src/body.rs b/crates/codegen_llvm/src/body.rs index 1abc1332..dc551962 100644 --- a/crates/codegen_llvm/src/body.rs +++ b/crates/codegen_llvm/src/body.rs @@ -41,14 +41,15 @@ impl<'a, 'ctx> BodyCodegen<'a, 'ctx> { // register alloc locals. for (idx, local) in self.body.locals.iter() { let ty: BasicTypeEnum = match local.ty { - hir_ty::ResolvedType::Unit => self.codegen.unit_type().into(), - hir_ty::ResolvedType::Integer => self.codegen.integer_type().into(), - hir_ty::ResolvedType::String => self.codegen.string_type().into(), - hir_ty::ResolvedType::Bool => self.codegen.bool_type().into(), - hir_ty::ResolvedType::Char => todo!(), - hir_ty::ResolvedType::Never => todo!(), - hir_ty::ResolvedType::Function(_) => todo!(), - hir_ty::ResolvedType::Unknown => unreachable!(), + hir_ty::Monotype::Unit => self.codegen.unit_type().into(), + hir_ty::Monotype::Integer => self.codegen.integer_type().into(), + hir_ty::Monotype::String => self.codegen.string_type().into(), + hir_ty::Monotype::Bool => self.codegen.bool_type().into(), + hir_ty::Monotype::Char => todo!(), + hir_ty::Monotype::Never => todo!(), + hir_ty::Monotype::Function(_) => todo!(), + hir_ty::Monotype::Variable(_) => unreachable!(""), + hir_ty::Monotype::Unknown => unreachable!(), }; let local_ptr = self diff --git a/crates/codegen_llvm/src/lib.rs b/crates/codegen_llvm/src/lib.rs index a413ca16..e7a88c7c 100644 --- a/crates/codegen_llvm/src/lib.rs +++ b/crates/codegen_llvm/src/lib.rs @@ -153,16 +153,16 @@ impl<'a, 'ctx> Codegen<'a, 'ctx> { body.params .iter() .map(|(_, param)| match param.ty { - hir_ty::ResolvedType::Integer => { + hir_ty::Monotype::Integer => { BasicMetadataTypeEnum::IntType(self.context.i64_type()) } - hir_ty::ResolvedType::String => self + hir_ty::Monotype::String => self .context .i8_type() .vec_type(1) .ptr_type(AddressSpace::default()) .into(), - hir_ty::ResolvedType::Bool => self.context.bool_type().into(), + hir_ty::Monotype::Bool => self.context.bool_type().into(), _ => unimplemented!(), }) .collect::>() @@ -197,22 +197,22 @@ impl<'a, 'ctx> Codegen<'a, 'ctx> { fn gen_function_signatures(&mut self, db: &dyn hir::HirMasterDatabase) { for (idx, body) in self.mir_result.ref_bodies() { let params = self.body_to_params(body); - let return_type = body.locals[body.return_local].ty; + let return_type = body.locals[body.return_local].ty.clone(); let fn_ty: FunctionType<'ctx> = match return_type { - hir_ty::ResolvedType::Unit => { + hir_ty::Monotype::Unit => { let ty = self.context.struct_type(&[], false); ty.fn_type(¶ms, false) } - hir_ty::ResolvedType::Integer => { + hir_ty::Monotype::Integer => { let ty = self.context.i64_type(); ty.fn_type(¶ms, false) } - hir_ty::ResolvedType::String => { + hir_ty::Monotype::String => { let ty = self.string_type(); ty.fn_type(¶ms, false) } - hir_ty::ResolvedType::Bool => { + hir_ty::Monotype::Bool => { let ty = self.context.bool_type(); ty.fn_type(¶ms, false) } @@ -256,12 +256,12 @@ mod tests { fn lower( fixture: &str, ) -> ( - hir::TestingDatabase, + hir_ty::TestingDatabase, hir::Pods, hir_ty::TyLowerResult, mir::LowerResult, ) { - let db = hir::TestingDatabase::default(); + let db = hir_ty::TestingDatabase::default(); let mut source_db = FixtureDatabase::new(&db, fixture); let pods = hir::parse_pods(&db, "/main.nail", &mut source_db); From ac92250c15f6292ffe9171dcce294dce4e56937b Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 14:42:41 +0900 Subject: [PATCH 23/35] wip --- crates/hir_ty/src/checker.rs | 4 ++-- crates/hir_ty/src/lib.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/hir_ty/src/checker.rs b/crates/hir_ty/src/checker.rs index 6f238592..294d22fd 100644 --- a/crates/hir_ty/src/checker.rs +++ b/crates/hir_ty/src/checker.rs @@ -24,7 +24,7 @@ pub fn check_type_pods( #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum TypeCheckError { /// 型を解決できない - Unresolved { + UnresolvedType { /// 対象の式 expr: hir::ExprId, }, @@ -98,7 +98,7 @@ impl<'a> FunctionTypeChecker<'a> { fn check_expr(&mut self, expr: hir::ExprId) { let ty = self.current_inference().type_by_expr[&expr].clone(); if ty == Monotype::Unknown { - self.errors.push(TypeCheckError::Unresolved { expr }); + self.errors.push(TypeCheckError::UnresolvedType { expr }); } let expr = expr.lookup(self.hir_file.db(self.db)); diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 052344e6..a851fb1a 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -295,7 +295,7 @@ mod tests { .unwrap(); for error in type_check_errors { match error { - TypeCheckError::Unresolved { expr } => { + TypeCheckError::UnresolvedType { expr } => { msg.push_str(&format!( "error Type is unknown: expr: {}", self.debug_simplify_expr(hir_file, *expr), From 9e6f0713da069783420918ab11334b245ae60287 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 17:51:22 +0900 Subject: [PATCH 24/35] wip --- crates/hir_ty/src/inference.rs | 1 + crates/hir_ty/src/inference/environment.rs | 34 +++++++++++++++++---- crates/hir_ty/src/inference/error.rs | 3 ++ crates/hir_ty/src/inference/type_scheme.rs | 10 ++++++ crates/hir_ty/src/inference/type_unifier.rs | 12 ++++++-- crates/hir_ty/src/inference/types.rs | 7 +++-- 6 files changed, 55 insertions(+), 12 deletions(-) diff --git a/crates/hir_ty/src/inference.rs b/crates/hir_ty/src/inference.rs index ee70e53e..bf780da0 100644 --- a/crates/hir_ty/src/inference.rs +++ b/crates/hir_ty/src/inference.rs @@ -14,6 +14,7 @@ pub use types::Monotype; use crate::HirTyMasterDatabase; +/// Pod全体の型推論を行う pub fn infer_pods(db: &dyn HirTyMasterDatabase, pods: &hir::Pods) -> InferenceResult { let mut signature_by_function = HashMap::::new(); for (hir_file, function) in pods.root_pod.all_functions(db) { diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index 9dd8a086..ba230937 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -11,13 +11,17 @@ use super::{ }; use crate::HirTyMasterDatabase; +/// 関数シグネチャ #[salsa::tracked] pub struct Signature { + /// 関数のパラメータ #[return_ref] pub params: Vec, + /// 関数の戻り値 pub return_type: Monotype, } +/// 関数内を型推論します。 pub(crate) struct InferBody<'a> { db: &'a dyn HirTyMasterDatabase, pods: &'a hir::Pods, @@ -28,8 +32,14 @@ pub(crate) struct InferBody<'a> { unifier: TypeUnifier, cxt: Context, + /// 環境のスタック + /// + /// スコープを入れ子にするために使用しています。 + /// スコープに入る時にpushし、スコープから抜ける時はpopします。 env_stack: Vec, signature_by_function: &'a HashMap, + + /// 推論結果の式の型を記録するためのマップ type_by_expr: HashMap, } impl<'a> InferBody<'a> { @@ -106,7 +116,7 @@ impl<'a> InferBody<'a> { hir::Stmt::VariableDef { name, value } => { let ty = self.infer_expr(*value); let ty_scheme = TypeScheme::new(ty); - self.mut_current_scope().bindings.insert(*name, ty_scheme); + self.mut_current_scope().insert(*name, ty_scheme); } hir::Stmt::ExprStmt { expr, @@ -125,7 +135,7 @@ impl<'a> InferBody<'a> { self.infer_type(¶m.ty) } hir::Symbol::Local { name, expr: _ } => { - let ty_scheme = self.current_scope().bindings.get(name).cloned(); + let ty_scheme = self.current_scope().get(name).cloned(); if let Some(ty_scheme) = ty_scheme { ty_scheme.instantiate(&mut self.cxt) } else { @@ -403,18 +413,21 @@ impl<'a> InferBody<'a> { } } +/// Pod全体の型推論結果 #[derive(Debug)] pub struct InferenceResult { pub signature_by_function: HashMap, pub inference_body_result_by_function: HashMap, } +/// 関数内の型推論結果 #[derive(Debug)] pub struct InferenceBodyResult { pub type_by_expr: HashMap, pub errors: Vec, } +/// Hindley-Milner型システムにおける型環境 #[derive(Default)] pub struct Environment { bindings: HashMap, @@ -444,16 +457,25 @@ impl Environment { fn with(&self) -> Environment { let mut copy = HashMap::::new(); + // FIXME: clone かつサイズが不定なので遅いかも。 copy.extend(self.bindings.clone()); Environment { bindings: copy } } + fn get(&self, name: &hir::Name) -> Option<&TypeScheme> { + self.bindings.get(name) + } + + fn insert(&mut self, name: hir::Name, ty_scheme: TypeScheme) { + self.bindings.insert(name, ty_scheme); + } + #[allow(dead_code)] fn generalize(&self, ty: &Monotype, db: &dyn HirTyMasterDatabase) -> TypeScheme { - TypeScheme { - variables: ty.free_variables(db).sub(&self.free_variables(db)), - ty: ty.clone(), - } + TypeScheme::new_with_variables( + ty.clone(), + ty.free_variables(db).sub(&self.free_variables(db)), + ) } } diff --git a/crates/hir_ty/src/inference/error.rs b/crates/hir_ty/src/inference/error.rs index 658d1b82..edf132d9 100644 --- a/crates/hir_ty/src/inference/error.rs +++ b/crates/hir_ty/src/inference/error.rs @@ -56,6 +56,7 @@ pub enum InferenceError { /// 実際の引数の位置(0-indexed) arg_pos: usize, }, + /// 数値をとる二項演算子の型が一致しない MismatchedBinaryInteger { /// 期待される型 expected_int_ty: Monotype, @@ -66,6 +67,7 @@ pub enum InferenceError { /// 演算子 op: ast::BinaryOp, }, + /// 比較演算子の型が一致しない MismatchedBinaryCompare { /// 期待される型 compare_from_ty: Monotype, @@ -78,6 +80,7 @@ pub enum InferenceError { /// 演算子 op: ast::BinaryOp, }, + /// 単項演算子の型が一致しない MismatchedUnary { /// 期待される型 expected_ty: Monotype, diff --git a/crates/hir_ty/src/inference/type_scheme.rs b/crates/hir_ty/src/inference/type_scheme.rs index 90fc77ea..4433a594 100644 --- a/crates/hir_ty/src/inference/type_scheme.rs +++ b/crates/hir_ty/src/inference/type_scheme.rs @@ -6,9 +6,15 @@ use std::{ use super::{environment::Context, types::Monotype}; use crate::HirTyMasterDatabase; +/// Hindley-Milner型推論における型スキーマ +/// +/// 型変数を持つことができる型のテンプレートのようなものです。 +/// 現時点(2023-09-16)では型変数をサポートしていないため使用していません。 #[derive(Clone)] pub struct TypeScheme { + /// `ty`が持つ型変数の集合 pub variables: HashSet, + /// 型 pub ty: Monotype, } @@ -20,6 +26,10 @@ impl TypeScheme { } } + pub fn new_with_variables(ty: Monotype, variables: HashSet) -> TypeScheme { + TypeScheme { variables, ty } + } + #[allow(dead_code)] pub fn free_variables(&self, db: &dyn HirTyMasterDatabase) -> HashSet { self.ty diff --git a/crates/hir_ty/src/inference/type_unifier.rs b/crates/hir_ty/src/inference/type_unifier.rs index efe51e19..2e1eedfd 100644 --- a/crates/hir_ty/src/inference/type_unifier.rs +++ b/crates/hir_ty/src/inference/type_unifier.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use super::{error::InferenceError, types::Monotype, Signature}; +/// 型の統一を行うための構造体 #[derive(Default, Debug)] pub(crate) struct TypeUnifier { pub(crate) nodes: HashMap, @@ -140,11 +141,11 @@ fn build_unify_error_from_unify_purpose( } impl TypeUnifier { - pub fn new() -> Self { + pub(crate) fn new() -> Self { Default::default() } - pub fn find(&mut self, ty: &Monotype) -> Monotype { + pub(crate) fn find(&mut self, ty: &Monotype) -> Monotype { let node = self.nodes.get(ty); if let Some(node) = node { node.topmost_parent().value @@ -154,7 +155,12 @@ impl TypeUnifier { } } - pub fn unify(&mut self, a_expected: &Monotype, b_actual: &Monotype, purpose: &UnifyPurpose) { + pub(crate) fn unify( + &mut self, + a_expected: &Monotype, + b_actual: &Monotype, + purpose: &UnifyPurpose, + ) { let a_rep = self.find(a_expected); let b_rep = self.find(b_actual); diff --git a/crates/hir_ty/src/inference/types.rs b/crates/hir_ty/src/inference/types.rs index aa92d7e3..ef29582d 100644 --- a/crates/hir_ty/src/inference/types.rs +++ b/crates/hir_ty/src/inference/types.rs @@ -3,6 +3,7 @@ use std::collections::HashSet; use super::{environment::Context, type_scheme::TypeSubstitution, Signature}; use crate::HirTyMasterDatabase; +/// 単一の型 #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Monotype { Integer, @@ -17,13 +18,13 @@ pub enum Monotype { } impl Monotype { - pub fn gen_variable(cxt: &mut Context) -> Self { + pub(crate) fn gen_variable(cxt: &mut Context) -> Self { let monotype = Self::Variable(cxt.gen_counter); cxt.gen_counter += 1; monotype } - pub fn free_variables(&self, db: &dyn HirTyMasterDatabase) -> HashSet { + pub(crate) fn free_variables(&self, db: &dyn HirTyMasterDatabase) -> HashSet { match self { Monotype::Variable(id) => { let mut set = HashSet::new(); @@ -43,7 +44,7 @@ impl Monotype { } } - pub fn apply(&self, subst: &TypeSubstitution) -> Monotype { + pub(crate) fn apply(&self, subst: &TypeSubstitution) -> Monotype { match self { Monotype::Integer | Monotype::Bool From 0281a431ccc8186d011dea62d90b4ed781ce926f Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 19:04:02 +0900 Subject: [PATCH 25/35] wip --- crates/hir_ty/src/inference/environment.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index ba230937..66a3900d 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -113,10 +113,10 @@ impl<'a> InferBody<'a> { fn infer_stmt(&mut self, stmt: &hir::Stmt) { match stmt { - hir::Stmt::VariableDef { name, value } => { + hir::Stmt::VariableDef { name: _, value } => { let ty = self.infer_expr(*value); let ty_scheme = TypeScheme::new(ty); - self.mut_current_scope().insert(*name, ty_scheme); + self.mut_current_scope().insert(*value, ty_scheme); } hir::Stmt::ExprStmt { expr, @@ -134,8 +134,8 @@ impl<'a> InferBody<'a> { let param = param.data(self.hir_file.db(self.db)); self.infer_type(¶m.ty) } - hir::Symbol::Local { name, expr: _ } => { - let ty_scheme = self.current_scope().get(name).cloned(); + hir::Symbol::Local { name: _, expr } => { + let ty_scheme = self.current_scope().get(expr).cloned(); if let Some(ty_scheme) = ty_scheme { ty_scheme.instantiate(&mut self.cxt) } else { @@ -430,7 +430,7 @@ pub struct InferenceBodyResult { /// Hindley-Milner型システムにおける型環境 #[derive(Default)] pub struct Environment { - bindings: HashMap, + bindings: HashMap, } #[derive(Default)] @@ -456,19 +456,19 @@ impl Environment { } fn with(&self) -> Environment { - let mut copy = HashMap::::new(); + let mut copy = HashMap::new(); // FIXME: clone かつサイズが不定なので遅いかも。 copy.extend(self.bindings.clone()); Environment { bindings: copy } } - fn get(&self, name: &hir::Name) -> Option<&TypeScheme> { - self.bindings.get(name) + fn get(&self, expr: &hir::ExprId) -> Option<&TypeScheme> { + self.bindings.get(expr) } - fn insert(&mut self, name: hir::Name, ty_scheme: TypeScheme) { - self.bindings.insert(name, ty_scheme); + fn insert(&mut self, expr: hir::ExprId, ty_scheme: TypeScheme) { + self.bindings.insert(expr, ty_scheme); } #[allow(dead_code)] From e33ecdc7e65eb1c2eae922e57b8aad7b116af8bb Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 19:30:27 +0900 Subject: [PATCH 26/35] wip --- crates/hir_ty/src/inference/environment.rs | 24 +++++++++++++++++++--- crates/hir_ty/src/inference/type_scheme.rs | 15 ++++++++------ crates/hir_ty/src/inference/types.rs | 14 +++++++------ 3 files changed, 38 insertions(+), 15 deletions(-) diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index 66a3900d..aa9aaba0 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -433,9 +433,27 @@ pub struct Environment { bindings: HashMap, } +/// 型変数のID +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct VariableId(u32); +impl std::fmt::Display for VariableId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "_{}", self.0) + } +} + #[derive(Default)] pub struct Context { - pub gen_counter: u32, + gen_counter: u32, +} +impl Context { + /// IDを生成します。IDは生成時に自動的にインクリメントされます。 + pub(crate) fn gen_id(&mut self) -> VariableId { + let id = self.gen_counter; + self.gen_counter += 1; + + VariableId(id) + } } impl Environment { @@ -446,8 +464,8 @@ impl Environment { } #[allow(dead_code)] - fn free_variables(&self, db: &dyn HirTyMasterDatabase) -> HashSet { - let mut union = HashSet::::new(); + fn free_variables(&self, db: &dyn HirTyMasterDatabase) -> HashSet { + let mut union = HashSet::new(); for type_scheme in self.bindings.values() { union.extend(type_scheme.free_variables(db)); } diff --git a/crates/hir_ty/src/inference/type_scheme.rs b/crates/hir_ty/src/inference/type_scheme.rs index 4433a594..f77cc63e 100644 --- a/crates/hir_ty/src/inference/type_scheme.rs +++ b/crates/hir_ty/src/inference/type_scheme.rs @@ -3,7 +3,10 @@ use std::{ iter::FromIterator, }; -use super::{environment::Context, types::Monotype}; +use super::{ + environment::{Context, VariableId}, + types::Monotype, +}; use crate::HirTyMasterDatabase; /// Hindley-Milner型推論における型スキーマ @@ -13,7 +16,7 @@ use crate::HirTyMasterDatabase; #[derive(Clone)] pub struct TypeScheme { /// `ty`が持つ型変数の集合 - pub variables: HashSet, + pub variables: HashSet, /// 型 pub ty: Monotype, } @@ -26,12 +29,12 @@ impl TypeScheme { } } - pub fn new_with_variables(ty: Monotype, variables: HashSet) -> TypeScheme { + pub fn new_with_variables(ty: Monotype, variables: HashSet) -> TypeScheme { TypeScheme { variables, ty } } #[allow(dead_code)] - pub fn free_variables(&self, db: &dyn HirTyMasterDatabase) -> HashSet { + pub fn free_variables(&self, db: &dyn HirTyMasterDatabase) -> HashSet { self.ty .free_variables(db) .into_iter() @@ -56,11 +59,11 @@ impl TypeScheme { #[derive(Default)] pub struct TypeSubstitution { - pub replacements: HashMap, + pub replacements: HashMap, } impl TypeSubstitution { - pub fn lookup(&self, id: u32) -> Option { + pub fn lookup(&self, id: VariableId) -> Option { self.replacements.get(&id).cloned() } } diff --git a/crates/hir_ty/src/inference/types.rs b/crates/hir_ty/src/inference/types.rs index ef29582d..daf2f321 100644 --- a/crates/hir_ty/src/inference/types.rs +++ b/crates/hir_ty/src/inference/types.rs @@ -1,6 +1,10 @@ use std::collections::HashSet; -use super::{environment::Context, type_scheme::TypeSubstitution, Signature}; +use super::{ + environment::{Context, VariableId}, + type_scheme::TypeSubstitution, + Signature, +}; use crate::HirTyMasterDatabase; /// 単一の型 @@ -11,7 +15,7 @@ pub enum Monotype { Unit, Char, String, - Variable(u32), + Variable(VariableId), Function(Signature), Never, Unknown, @@ -19,12 +23,10 @@ pub enum Monotype { impl Monotype { pub(crate) fn gen_variable(cxt: &mut Context) -> Self { - let monotype = Self::Variable(cxt.gen_counter); - cxt.gen_counter += 1; - monotype + Monotype::Variable(cxt.gen_id()) } - pub(crate) fn free_variables(&self, db: &dyn HirTyMasterDatabase) -> HashSet { + pub(crate) fn free_variables(&self, db: &dyn HirTyMasterDatabase) -> HashSet { match self { Monotype::Variable(id) => { let mut set = HashSet::new(); From 14f1387b681df69adda6213673bf58da4fb82e54 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 19:34:54 +0900 Subject: [PATCH 27/35] wip --- crates/hir_ty/src/inference/environment.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index aa9aaba0..8e38c719 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -438,7 +438,7 @@ pub struct Environment { pub struct VariableId(u32); impl std::fmt::Display for VariableId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "_{}", self.0) + write!(f, "{}", self.0) } } From 518da8f1437568d07a43c5c1c370dd09aa97b670 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 19:48:05 +0900 Subject: [PATCH 28/35] wip --- crates/hir_ty/src/inference/environment.rs | 17 +++++++++++++---- crates/hir_ty/src/lib.rs | 7 +++++-- crates/mir/src/body.rs | 15 +++++++++++---- 3 files changed, 29 insertions(+), 10 deletions(-) diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index 8e38c719..872575ef 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -416,15 +416,24 @@ impl<'a> InferBody<'a> { /// Pod全体の型推論結果 #[derive(Debug)] pub struct InferenceResult { - pub signature_by_function: HashMap, - pub inference_body_result_by_function: HashMap, + pub(crate) signature_by_function: HashMap, + pub(crate) inference_body_result_by_function: HashMap, } /// 関数内の型推論結果 #[derive(Debug)] pub struct InferenceBodyResult { - pub type_by_expr: HashMap, - pub errors: Vec, + pub(crate) type_by_expr: HashMap, + pub(crate) errors: Vec, +} +impl InferenceBodyResult { + pub fn type_by_expr(&self, expr: hir::ExprId) -> Option<&Monotype> { + self.type_by_expr.get(&expr) + } + + pub fn errors(&self) -> &Vec { + &self.errors + } } /// Hindley-Milner型システムにおける型環境 diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index a851fb1a..756c8a46 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -45,8 +45,11 @@ pub struct TyLowerResult { } impl TyLowerResult { /// 指定した関数の型を取得します。 - pub fn signature_by_function(&self, function_id: hir::Function) -> Signature { - self.inference_result.signature_by_function[&function_id] + pub fn signature_by_function(&self, function_id: hir::Function) -> Option { + self.inference_result + .signature_by_function + .get(&function_id) + .copied() } /// 指定した関数の型推論結果を取得します。 diff --git a/crates/mir/src/body.rs b/crates/mir/src/body.rs index 2da9f863..6e1443ea 100644 --- a/crates/mir/src/body.rs +++ b/crates/mir/src/body.rs @@ -39,7 +39,7 @@ impl<'a> FunctionLower<'a> { function: hir::Function, ) -> Self { let mut locals = Arena::new(); - let signature = hir_ty_result.signature_by_function(function); + let signature = hir_ty_result.signature_by_function(function).unwrap(); let return_local = locals.alloc(Local { ty: signature.return_type(db), idx: 0, @@ -88,7 +88,10 @@ impl<'a> FunctionLower<'a> { } fn alloc_local(&mut self, expr: hir::ExprId) -> Idx { - let ty = &self.get_inference_by_function(self.function).type_by_expr[&expr]; + let ty = self + .get_inference_by_function(self.function) + .type_by_expr(expr) + .unwrap(); let local_idx = self.alloc_local_by_ty(ty.clone()); self.local_by_hir.insert(expr, local_idx); @@ -458,7 +461,8 @@ impl<'a> FunctionLower<'a> { hir::Item::Function(function) => { let function_id = self.function_by_hir_function[&function]; - let signature = self.hir_ty_result.signature_by_function(function); + let signature = + self.hir_ty_result.signature_by_function(function).unwrap(); let called_local = self.alloc_local_by_ty(signature.return_type(self.db)); let dest_place = Place::Local(called_local); @@ -519,7 +523,10 @@ impl<'a> FunctionLower<'a> { } pub(crate) fn lower(mut self) -> Body { - let signature = self.hir_ty_result.signature_by_function(self.function); + let signature = self + .hir_ty_result + .signature_by_function(self.function) + .unwrap(); for (param, param_ty) in self .function .params(self.db) From d41ec07fcda6d413ba42aece7cea628f89045348 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 20:18:56 +0900 Subject: [PATCH 29/35] wip --- crates/hir_ty/src/checker.rs | 4 +- crates/hir_ty/src/inference/environment.rs | 21 ++++-- crates/hir_ty/src/inference/error.rs | 14 ++++ crates/hir_ty/src/inference/type_unifier.rs | 7 ++ crates/hir_ty/src/lib.rs | 77 +++++++++++++++++++++ 5 files changed, 114 insertions(+), 9 deletions(-) diff --git a/crates/hir_ty/src/checker.rs b/crates/hir_ty/src/checker.rs index 294d22fd..7a05f486 100644 --- a/crates/hir_ty/src/checker.rs +++ b/crates/hir_ty/src/checker.rs @@ -96,8 +96,8 @@ impl<'a> FunctionTypeChecker<'a> { } fn check_expr(&mut self, expr: hir::ExprId) { - let ty = self.current_inference().type_by_expr[&expr].clone(); - if ty == Monotype::Unknown { + let Some(ty) = self.current_inference().type_by_expr(expr) else { return }; + if ty == &Monotype::Unknown { self.errors.push(TypeCheckError::UnresolvedType { expr }); } diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index 872575ef..0880783c 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -189,8 +189,8 @@ impl<'a> InferBody<'a> { callee, args: call_args, } => { - let ty = self.infer_symbol(callee); - match ty { + let callee_ty = self.infer_symbol(callee); + match callee_ty { Monotype::Integer | Monotype::Bool | Monotype::Unit @@ -199,7 +199,10 @@ impl<'a> InferBody<'a> { | Monotype::Never | Monotype::Unknown | Monotype::Variable(_) => { - // TODO: 関数ではないものを呼び出そうとしているエラーを追加 + self.unifier.add_error(InferenceError::NotCallable { + found_callee_ty: callee_ty, + found_callee_symbol: callee.clone(), + }); Monotype::Unknown } Monotype::Function(signature) => { @@ -209,8 +212,11 @@ impl<'a> InferBody<'a> { .collect::>(); if call_args_ty.len() != signature.params(self.db).len() { - // TODO: 引数の数が異なるエラーを追加 - Monotype::Unknown + self.unifier + .add_error(InferenceError::MismatchedCallArgCount { + expected_callee_arg_count: signature.params(self.db).len(), + found_arg_count: call_args_ty.len(), + }); } else { for (((arg_pos, call_arg), call_arg_ty), signature_arg_ty) in call_args .iter() @@ -228,9 +234,10 @@ impl<'a> InferBody<'a> { }, ); } - - signature.return_type(self.db) } + + // 引数の数が異なったとしても、関数の戻り値は返す。 + signature.return_type(self.db) } } } diff --git a/crates/hir_ty/src/inference/error.rs b/crates/hir_ty/src/inference/error.rs index edf132d9..f57401f6 100644 --- a/crates/hir_ty/src/inference/error.rs +++ b/crates/hir_ty/src/inference/error.rs @@ -104,4 +104,18 @@ pub enum InferenceError { /// 実際の式 found_return_expr: Option, }, + /// 呼び出しの引数の数が一致しない + MismatchedCallArgCount { + /// 期待される引数の数 + expected_callee_arg_count: usize, + /// 実際の引数の数 + found_arg_count: usize, + }, + /// 呼び出そうとしている対象が関数ではない + NotCallable { + /// 呼び出し対象の型 + found_callee_ty: Monotype, + /// 呼び出し対象のシンボル + found_callee_symbol: hir::Symbol, + }, } diff --git a/crates/hir_ty/src/inference/type_unifier.rs b/crates/hir_ty/src/inference/type_unifier.rs index 2e1eedfd..c8b734a5 100644 --- a/crates/hir_ty/src/inference/type_unifier.rs +++ b/crates/hir_ty/src/inference/type_unifier.rs @@ -155,6 +155,13 @@ impl TypeUnifier { } } + /// エラーを追加します。 + /// + /// 引数の数の違いや呼び出し対象の型が異なる等のUnify外を収集するために使用されます。 + pub(crate) fn add_error(&mut self, error: InferenceError) { + self.errors.push(error); + } + pub(crate) fn unify( &mut self, a_expected: &Monotype, diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 756c8a46..9be78ce9 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -285,6 +285,26 @@ mod tests { )); } } + InferenceError::MismatchedCallArgCount { + expected_callee_arg_count, + found_arg_count, + } => { + msg.push_str(&format!( + "error MismatchedCallArgCount: expected_arg_count: {}, found_arg_count: {}", + expected_callee_arg_count, + found_arg_count, + )); + } + InferenceError::NotCallable { + found_callee_ty, + found_callee_symbol, + } => { + msg.push_str(&format!( + "error NotCallable: found_callee_ty: {}, found_callee_symbol: {}", + self.debug_monotype(found_callee_ty), + self.debug_symbol(found_callee_symbol), + )); + } } msg.push('\n'); } @@ -1876,4 +1896,61 @@ mod tests { "#]], ); } + + #[test] + fn call_callable_expr() { + check_in_root_file( + r#" + fn main() { + let a = 10; + a(); + } + "#, + expect![[r#" + //- /main.nail + fn entry:main() -> () { + let a = 10; //: int + a(); //: + } + + --- + error NotCallable: found_callee_ty: int, found_callee_symbol: a + --- + error Type is unknown: expr: a() + "#]], + ); + } + + #[test] + fn mismatched_arg_len() { + check_pod_start_with_root_file( + r#" + //- /main.nail + fn main() { + callee(10); + callee(10, "a"); + callee(10, "a", 30); + } + fn callee(a: int, b: string) -> int { + 10 + } + "#, + expect![[r#" + //- /main.nail + fn entry:main() -> () { + fn:callee(10); //: int + fn:callee(10, "a"); //: int + fn:callee(10, "a", 30); //: int + } + fn callee(a: int, b: string) -> int { + expr:10 //: int + } + + --- + error MismatchedCallArgCount: expected_arg_count: 2, found_arg_count: 1 + error MismatchedCallArgCount: expected_arg_count: 2, found_arg_count: 3 + --- + "#]], + ); + } } From 7ede097b74a382599a4b9f9cfd53613fa7bad94f Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 21:07:15 +0900 Subject: [PATCH 30/35] wip --- crates/hir_ty/src/inference/environment.rs | 56 ++++++----- crates/hir_ty/src/inference/error.rs | 5 + crates/hir_ty/src/lib.rs | 104 ++++++++++++++++++++- 3 files changed, 137 insertions(+), 28 deletions(-) diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index 0880783c..ef7393f0 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -143,37 +143,43 @@ impl<'a> InferBody<'a> { } } hir::Symbol::Missing { path } => { - let item = self.pods.resolution_map.item_by_symbol(path).unwrap(); - match item { - hir::ResolutionStatus::Unresolved | hir::ResolutionStatus::Error => { - // 解決できないエラーを追加 - Monotype::Unknown - } - hir::ResolutionStatus::Resolved { path: _, item } => { - match item { - hir::Item::Function(function) => { - let signature = self.signature_by_function.get(&function); - if let Some(signature) = signature { - Monotype::Function(*signature) - } else { - unreachable!("Function signature should be resolved.") - } - } - hir::Item::Module(_) => { - // モジュールを型推論使用としているエラーを追加 - Monotype::Unknown - } - hir::Item::UseItem(_) => { - // 使用宣言を型推論使用としているエラーを追加 - Monotype::Unknown - } + let resolution_status = self.pods.resolution_map.item_by_symbol(path).unwrap(); + match self.resolve_resolution_status(resolution_status) { + Some(item) => match item { + hir::Item::Function(function) => { + let signature = self.signature_by_function.get(&function).unwrap(); + Monotype::Function(*signature) } - } + hir::Item::Module(module) => { + self.unifier.add_error(InferenceError::ModuleAsExpr { + found_module: module, + }); + Monotype::Unknown + } + hir::Item::UseItem(_) => unreachable!("UseItem should be resolved."), + }, + None => Monotype::Unknown, } } } } + fn resolve_resolution_status( + &self, + resolution_status: hir::ResolutionStatus, + ) -> Option { + match resolution_status { + hir::ResolutionStatus::Unresolved | hir::ResolutionStatus::Error => None, + hir::ResolutionStatus::Resolved { path: _, item } => match item { + hir::Item::Function(_) | hir::Item::Module(_) => Some(item), + hir::Item::UseItem(use_item) => { + let resolution_status = self.pods.resolution_map.item_by_use_item(&use_item)?; + self.resolve_resolution_status(resolution_status) + } + }, + } + } + fn infer_expr(&mut self, expr_id: hir::ExprId) -> Monotype { let expr = expr_id.lookup(self.hir_file.db(self.db)); let ty = match expr { diff --git a/crates/hir_ty/src/inference/error.rs b/crates/hir_ty/src/inference/error.rs index f57401f6..4a92f2ba 100644 --- a/crates/hir_ty/src/inference/error.rs +++ b/crates/hir_ty/src/inference/error.rs @@ -118,4 +118,9 @@ pub enum InferenceError { /// 呼び出し対象のシンボル found_callee_symbol: hir::Symbol, }, + /// モジュールが式として型推論されようとしている + ModuleAsExpr { + /// 実際のモジュール + found_module: hir::Module, + }, } diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 9be78ce9..aaa9c998 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -305,6 +305,12 @@ mod tests { self.debug_symbol(found_callee_symbol), )); } + InferenceError::ModuleAsExpr { found_module } => { + msg.push_str(&format!( + "error ModuleAsExpr: found_module: {}", + found_module.name(self.db).text(self.db) + )); + } } msg.push('\n'); } @@ -460,7 +466,11 @@ mod tests { let path_name = self.debug_path(use_item.path(self.db)); let item_name = use_item.name(self.db).text(self.db); - format!("{path_name}::{item_name}") + if path_name.is_empty() { + format!("use {item_name};\n") + } else { + format!("use {path_name}::{item_name};\n") + } } fn debug_item(&self, hir_file: hir::HirFile, item: hir::Item, nesting: usize) -> String { @@ -761,8 +771,17 @@ mod tests { hir::Item::Module(_) => { format!("mod:{path}") } - hir::Item::UseItem(_) => { - unreachable!() + hir::Item::UseItem(use_item) => { + let item = self + .pods + .resolution_map + .item_by_use_item(&use_item) + .unwrap(); + format!( + "{}::{}", + self.debug_resolution_status(item), + use_item.name(self.db).text(self.db) + ) } } } @@ -1953,4 +1972,83 @@ mod tests { "#]], ); } + + #[test] + fn symbol_to_module() { + check_pod_start_with_root_file( + r#" + //- /main.nail + fn main() { + aaa; + aaa::bbb; + } + + mod aaa { + mod bbb { + } + } + "#, + expect![[r#" + //- /main.nail + fn entry:main() -> () { + mod:aaa; //: + mod:aaa::bbb; //: + } + mod aaa { + mod bbb { + } + } + + --- + error ModuleAsExpr: found_module: aaa + error ModuleAsExpr: found_module: bbb + --- + error Type is unknown: expr: mod:aaa + error Type is unknown: expr: mod:aaa::bbb + "#]], + ); + } + + /// unimplements hir + #[test] + fn symbol_to_use_item() { + check_pod_start_with_root_file( + r#" + //- /main.nail + mod aaa; + use aaa:bbb; + fn main() -> int { + bbb() + } + + //- /aaa.nail + mod aaa { + fn bbb() -> int { + 10 + } + } + "#, + expect![[r#" + //- /main.nail + mod aaa; + use aaa; + fn entry:main() -> int { + expr:() //: + } + + //- /aaa.nail + mod aaa { + fn bbb() -> int { + expr:10 //: int + } + } + + --- + error NotCallable: found_callee_ty: , found_callee_symbol: + error MismatchedReturnType: expected_ty: int, found_ty: , found_expr: () + --- + error Type is unknown: expr: () + "#]], + ); + } } From 7617ff3f682c59f1b6ce80b3e2cb9c5e382e24f3 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 21:44:26 +0900 Subject: [PATCH 31/35] wip --- crates/hir_ty/src/lib.rs | 98 ++++++++++++++++++++-------------------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index aaa9c998..1fa251ca 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -152,7 +152,7 @@ mod tests { } => { msg.push_str( &format!( - "error MismatchedTypes: expected_ty: {}, found_ty: {}, expected_expr: {}, found_expr: {}", + "error MismatchedTypes: expected_ty: {}, found_ty: {}, expected_expr: `{}`, found_expr: `{}`", self.debug_monotype(expected_ty), self.debug_monotype(found_ty), self.debug_simplify_expr(hir_file, *expected_expr), @@ -166,7 +166,7 @@ mod tests { } => { msg.push_str( &format!( - "error MismatchedTypeIfCondition: expected_ty: {}, found_ty: {}, found_expr: {}", + "error MismatchedTypeIfCondition: expected_ty: {}, found_ty: {}, found_expr: `{}`", self.debug_monotype(expected_condition_bool_ty), self.debug_monotype(found_condition_ty), self.debug_simplify_expr(hir_file, *found_condition_expr), @@ -180,7 +180,7 @@ mod tests { } => { msg.push_str( &format!( - "error MismatchedTypeElseBranch: then_branch_ty: {}, else_branch_ty: {}, then_branch: {}, else_branch: {}", + "error MismatchedTypeElseBranch: then_branch_ty: {}, else_branch_ty: {}, then_branch: `{}`, else_branch: `{}`", self.debug_monotype(then_branch_ty), self.debug_monotype(else_branch_ty), self.debug_simplify_expr(hir_file, *then_branch), @@ -194,7 +194,7 @@ mod tests { } => { msg.push_str( &format!( - "error MismatchedTypeOnlyIfBranch: then_branch_ty: {}, else_branch_ty: {}, then_branch: {}", + "error MismatchedTypeOnlyIfBranch: then_branch_ty: {}, else_branch_ty: {}, then_branch: `{}`", self.debug_monotype(then_branch_ty), self.debug_monotype(else_branch_unit_ty), self.debug_simplify_expr(hir_file, *then_branch), @@ -209,7 +209,7 @@ mod tests { } => { msg.push_str( &format!( - "error MismaatchedSignature: expected_ty: {}, found_ty: {}, found_expr: {}, signature: {}, arg_pos: {}", + "error MismaatchedSignature: expected_ty: {}, found_ty: {}, found_expr: `{}`, signature: {}, arg_pos: {}", self.debug_monotype(expected_ty), self.debug_monotype(found_ty), self.debug_simplify_expr(hir_file, *found_expr), @@ -225,7 +225,7 @@ mod tests { } => { msg.push_str( &format!( - "error MismatchedBinaryInteger: op: {}, expected_ty: {}, found_ty: {}, found_expr: {}", + "error MismatchedBinaryInteger: op: {}, expected_ty: {}, found_ty: {}, found_expr: `{}`", self.debug_binary_op(op), self.debug_monotype(expected_int_ty), self.debug_monotype(found_ty), @@ -241,7 +241,7 @@ mod tests { } => { msg.push_str( &format!( - "error MismatchedBinaryCompare: op: {}, expected_ty: {}, found_ty: {}, expected_expr: {}, found_expr: {}", + "error MismatchedBinaryCompare: op: {}, expected_ty: {}, found_ty: {}, expected_expr: `{}`, found_expr: `{}`", self.debug_binary_op(op), self.debug_monotype(compare_from_ty), self.debug_monotype(compare_to_ty), @@ -257,7 +257,7 @@ mod tests { } => { msg.push_str( &format!( - "error MismatchedUnary: op: {}, expected_ty: {}, found_ty: {}, found_expr: {}", + "error MismatchedUnary: op: {}, expected_ty: {}, found_ty: {}, found_expr: `{}`", self.debug_unary_op(op), self.debug_monotype(expected_ty), self.debug_monotype(found_ty), @@ -272,7 +272,7 @@ mod tests { if let Some(found_return_expr) = found_return_expr { msg.push_str( &format!( - "error MismatchedReturnType: expected_ty: {}, found_ty: {}, found_expr: {}", + "error MismatchedReturnType: expected_ty: {}, found_ty: {}, found_expr: `{}`", self.debug_monotype(&expected_signature.return_type(self.db)), self.debug_monotype(found_ty), self.debug_simplify_expr(hir_file, *found_return_expr), @@ -1039,24 +1039,24 @@ mod tests { } --- - error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: "aaa" - error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: "bbb" - error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: "aaa" - error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: char, found_expr: 'a' - error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: char, found_expr: 'a' - error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: char, found_expr: 'a' - error MismatchedBinaryCompare: op: <, expected_ty: int, found_ty: char, expected_expr: 10, found_expr: 'a' - error MismatchedBinaryCompare: op: >, expected_ty: int, found_ty: char, expected_expr: 10, found_expr: 'a' - error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: bool, found_expr: true - error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: bool, found_expr: true - error MismatchedBinaryInteger: op: -, expected_ty: int, found_ty: bool, found_expr: true - error MismatchedBinaryInteger: op: -, expected_ty: int, found_ty: bool, found_expr: true - error MismatchedBinaryInteger: op: *, expected_ty: int, found_ty: bool, found_expr: true - error MismatchedBinaryInteger: op: *, expected_ty: int, found_ty: bool, found_expr: true - error MismatchedBinaryInteger: op: /, expected_ty: int, found_ty: bool, found_expr: true - error MismatchedBinaryInteger: op: /, expected_ty: int, found_ty: bool, found_expr: true - error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: bool, found_expr: true - error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: "aaa" + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: `"aaa"` + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: `"bbb"` + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: `"aaa"` + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: char, found_expr: `'a'` + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: char, found_expr: `'a'` + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: char, found_expr: `'a'` + error MismatchedBinaryCompare: op: <, expected_ty: int, found_ty: char, expected_expr: `10`, found_expr: `'a'` + error MismatchedBinaryCompare: op: >, expected_ty: int, found_ty: char, expected_expr: `10`, found_expr: `'a'` + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: bool, found_expr: `true` + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: bool, found_expr: `true` + error MismatchedBinaryInteger: op: -, expected_ty: int, found_ty: bool, found_expr: `true` + error MismatchedBinaryInteger: op: -, expected_ty: int, found_ty: bool, found_expr: `true` + error MismatchedBinaryInteger: op: *, expected_ty: int, found_ty: bool, found_expr: `true` + error MismatchedBinaryInteger: op: *, expected_ty: int, found_ty: bool, found_expr: `true` + error MismatchedBinaryInteger: op: /, expected_ty: int, found_ty: bool, found_expr: `true` + error MismatchedBinaryInteger: op: /, expected_ty: int, found_ty: bool, found_expr: `true` + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: bool, found_expr: `true` + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: `"aaa"` --- "#]], ); @@ -1074,8 +1074,8 @@ mod tests { } --- - error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: "aaa" - error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: "aaa" + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: `"aaa"` + error MismatchedBinaryInteger: op: +, expected_ty: int, found_ty: string, found_expr: `"aaa"` --- "#]], ); @@ -1111,12 +1111,12 @@ mod tests { } --- - error MismatchedUnary: op: -, expected_ty: int, found_ty: string, found_expr: "aaa" - error MismatchedUnary: op: -, expected_ty: int, found_ty: char, found_expr: 'a' - error MismatchedUnary: op: -, expected_ty: int, found_ty: bool, found_expr: true - error MismatchedUnary: op: !, expected_ty: bool, found_ty: int, found_expr: 10 - error MismatchedUnary: op: !, expected_ty: bool, found_ty: string, found_expr: "aaa" - error MismatchedUnary: op: !, expected_ty: bool, found_ty: char, found_expr: 'a' + error MismatchedUnary: op: -, expected_ty: int, found_ty: string, found_expr: `"aaa"` + error MismatchedUnary: op: -, expected_ty: int, found_ty: char, found_expr: `'a'` + error MismatchedUnary: op: -, expected_ty: int, found_ty: bool, found_expr: `true` + error MismatchedUnary: op: !, expected_ty: bool, found_ty: int, found_expr: `10` + error MismatchedUnary: op: !, expected_ty: bool, found_ty: string, found_expr: `"aaa"` + error MismatchedUnary: op: !, expected_ty: bool, found_ty: char, found_expr: `'a'` --- "#]], ) @@ -1261,7 +1261,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_ty: int, found_expr: 10 + error MismatchedReturnType: expected_ty: (), found_ty: int, found_expr: `10` --- "#]], ); @@ -1334,7 +1334,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_ty: int, found_expr: { tail:{ tail:10 } } + error MismatchedReturnType: expected_ty: (), found_ty: int, found_expr: `{ tail:{ tail:10 } }` --- "#]], ); @@ -1457,7 +1457,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: (), found_ty: int, found_expr: res + 30 + error MismatchedReturnType: expected_ty: (), found_ty: int, found_expr: `res + 30` --- "#]], ); @@ -1484,8 +1484,8 @@ mod tests { } --- - error MismaatchedSignature: expected_ty: bool, found_ty: string, found_expr: "aaa", signature: (bool, string) -> int, arg_pos: 0 - error MismaatchedSignature: expected_ty: string, found_ty: bool, found_expr: true, signature: (bool, string) -> int, arg_pos: 1 + error MismaatchedSignature: expected_ty: bool, found_ty: string, found_expr: `"aaa"`, signature: (bool, string) -> int, arg_pos: 0 + error MismaatchedSignature: expected_ty: string, found_ty: bool, found_expr: `true`, signature: (bool, string) -> int, arg_pos: 1 --- "#]], ); @@ -1538,7 +1538,7 @@ mod tests { } --- - error MismatchedTypeOnlyIfBranch: then_branch_ty: int, else_branch_ty: (), then_branch: { tail:10 } + error MismatchedTypeOnlyIfBranch: then_branch_ty: int, else_branch_ty: (), then_branch: `{ tail:10 }` --- "#]], ); @@ -1610,8 +1610,8 @@ mod tests { } --- - error MismatchedTypeElseBranch: then_branch_ty: int, else_branch_ty: string, then_branch: { tail:10 }, else_branch: { tail:"aaa" } - error MismatchedReturnType: expected_ty: (), found_ty: int, found_expr: if true { tail:10 } else { tail:"aaa" } + error MismatchedTypeElseBranch: then_branch_ty: int, else_branch_ty: string, then_branch: `{ tail:10 }`, else_branch: `{ tail:"aaa" }` + error MismatchedReturnType: expected_ty: (), found_ty: int, found_expr: `if true { tail:10 } else { tail:"aaa" }` --- "#]], ); @@ -1640,8 +1640,8 @@ mod tests { } --- - error MismatchedTypeIfCondition: expected_ty: bool, found_ty: int, found_expr: 10 - error MismatchedReturnType: expected_ty: (), found_ty: string, found_expr: if 10 { tail:"aaa" } else { tail:"aaa" } + error MismatchedTypeIfCondition: expected_ty: bool, found_ty: int, found_expr: `10` + error MismatchedReturnType: expected_ty: (), found_ty: string, found_expr: `if 10 { tail:"aaa" } else { tail:"aaa" }` --- "#]], ); @@ -1674,7 +1674,7 @@ mod tests { } --- - error MismatchedTypeElseBranch: then_branch_ty: (), else_branch_ty: bool, then_branch: { tail:none }, else_branch: { tail:true } + error MismatchedTypeElseBranch: then_branch_ty: (), else_branch_ty: bool, then_branch: `{ tail:none }`, else_branch: `{ tail:true }` --- "#]], ); @@ -1704,7 +1704,7 @@ mod tests { } --- - error MismatchedTypeElseBranch: then_branch_ty: bool, else_branch_ty: (), then_branch: { tail:true }, else_branch: { tail:none } + error MismatchedTypeElseBranch: then_branch_ty: bool, else_branch_ty: (), then_branch: `{ tail:true }`, else_branch: `{ tail:none }` --- "#]], ); @@ -1809,7 +1809,7 @@ mod tests { } --- - error MismatchedReturnType: expected_ty: int, found_ty: string, found_expr: "aaa" + error MismatchedReturnType: expected_ty: int, found_ty: string, found_expr: `"aaa"` --- "#]], ); @@ -2045,7 +2045,7 @@ mod tests { --- error NotCallable: found_callee_ty: , found_callee_symbol: - error MismatchedReturnType: expected_ty: int, found_ty: , found_expr: () + error MismatchedReturnType: expected_ty: int, found_ty: , found_expr: `()` --- error Type is unknown: expr: () "#]], From 5d3560767f643c5cdad1276c6d916b4c307f0d07 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sat, 16 Sep 2023 23:40:01 +0900 Subject: [PATCH 32/35] wip --- crates/codegen_llvm/src/body.rs | 2 +- crates/hir_ty/src/inference/environment.rs | 63 ++++++++++++++++----- crates/hir_ty/src/inference/type_unifier.rs | 4 +- crates/hir_ty/src/lib.rs | 31 ++++++++-- crates/mir/src/lib.rs | 7 ++- 5 files changed, 81 insertions(+), 26 deletions(-) diff --git a/crates/codegen_llvm/src/body.rs b/crates/codegen_llvm/src/body.rs index dc551962..cfdfcff1 100644 --- a/crates/codegen_llvm/src/body.rs +++ b/crates/codegen_llvm/src/body.rs @@ -46,7 +46,7 @@ impl<'a, 'ctx> BodyCodegen<'a, 'ctx> { hir_ty::Monotype::String => self.codegen.string_type().into(), hir_ty::Monotype::Bool => self.codegen.bool_type().into(), hir_ty::Monotype::Char => todo!(), - hir_ty::Monotype::Never => todo!(), + hir_ty::Monotype::Never => self.codegen.unit_type().into(), hir_ty::Monotype::Function(_) => todo!(), hir_ty::Monotype::Variable(_) => unreachable!(""), hir_ty::Monotype::Unknown => unreachable!(), diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index ef7393f0..78f33b5f 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -68,12 +68,17 @@ impl<'a> InferBody<'a> { pub(crate) fn infer_body(mut self) -> InferenceBodyResult { let hir::Expr::Block(body) = self.hir_file.function_body_by_function(self.db, self.function).unwrap() else { panic!("Should be Block.") }; + let mut has_never = false; for stmt in &body.stmts { - self.infer_stmt(stmt); + let is_never = self.infer_stmt(stmt); + if is_never { + has_never = true; + } } if let Some(tail) = &body.tail { let ty = self.infer_expr(*tail); + let ty = if has_never { Monotype::Never } else { ty }; self.unifier.unify( &self.signature.return_type(self.db), &ty, @@ -83,7 +88,11 @@ impl<'a> InferBody<'a> { }, ); } else { - let ty = Monotype::Unit; + let ty = if has_never { + Monotype::Never + } else { + Monotype::Unit + }; self.unifier.unify( &self.signature.return_type(self.db), &ty, @@ -111,21 +120,22 @@ impl<'a> InferBody<'a> { } } - fn infer_stmt(&mut self, stmt: &hir::Stmt) { - match stmt { + fn infer_stmt(&mut self, stmt: &hir::Stmt) -> bool { + let ty = match stmt { hir::Stmt::VariableDef { name: _, value } => { let ty = self.infer_expr(*value); - let ty_scheme = TypeScheme::new(ty); + let ty_scheme = TypeScheme::new(ty.clone()); self.mut_current_scope().insert(*value, ty_scheme); + ty } hir::Stmt::ExprStmt { expr, has_semicolon: _, - } => { - self.infer_expr(*expr); - } - hir::Stmt::Item { .. } => (), - } + } => self.infer_expr(*expr), + hir::Stmt::Item { .. } => return false, + }; + + ty == Monotype::Never } fn infer_symbol(&mut self, symbol: &hir::Symbol) -> Monotype { @@ -322,8 +332,12 @@ impl<'a> InferBody<'a> { hir::Expr::Block(block) => { self.entry_scope(); + let mut has_never = false; for stmt in &block.stmts { - self.infer_stmt(stmt); + let is_never = self.infer_stmt(stmt); + if is_never { + has_never = true; + } } let ty = if let Some(tail) = &block.tail { @@ -332,10 +346,17 @@ impl<'a> InferBody<'a> { // 最後の式がない場合は Unit として扱う Monotype::Unit }; + if ty == Monotype::Never { + has_never = true; + } self.exit_scope(); - ty + if has_never { + Monotype::Never + } else { + ty + } } hir::Expr::If { condition, @@ -352,16 +373,18 @@ impl<'a> InferBody<'a> { ); let then_ty = self.infer_expr(*then_branch); + let mut else_ty: Option = None; if let Some(else_branch) = else_branch { - let else_ty = self.infer_expr(*else_branch); + let ty = self.infer_expr(*else_branch); self.unifier.unify( &then_ty, - &else_ty, + &ty, &UnifyPurpose::IfThenElseBranch { found_then_branch_expr: *then_branch, found_else_branch_expr: *else_branch, }, ); + else_ty = Some(ty); } else { // elseブランチがない場合は Unit として扱う self.unifier.unify( @@ -373,7 +396,17 @@ impl<'a> InferBody<'a> { ); } - then_ty + if let Some(else_ty) = else_ty { + if then_ty == Monotype::Never { + else_ty + } else { + then_ty + } + } else if then_ty == Monotype::Never { + Monotype::Never + } else { + then_ty + } } hir::Expr::Return { value } => { if let Some(return_value) = value { diff --git a/crates/hir_ty/src/inference/type_unifier.rs b/crates/hir_ty/src/inference/type_unifier.rs index c8b734a5..e7f18639 100644 --- a/crates/hir_ty/src/inference/type_unifier.rs +++ b/crates/hir_ty/src/inference/type_unifier.rs @@ -176,6 +176,7 @@ impl TypeUnifier { } match (&a_rep, &b_rep) { + (Monotype::Never, _) | (_, Monotype::Never) => (), (Monotype::Function(_a_signature), Monotype::Function(_b_signature)) => { unreachable!(); // self.unify(&a_signature.return_type, &b_signature.return_type, purpose); @@ -191,9 +192,6 @@ impl TypeUnifier { } (Monotype::Variable(_), b_rep) => self.unify_var(b_rep, &a_rep), (a_rep, Monotype::Variable(_)) => self.unify_var(a_rep, &b_rep), - // 実際の型がNever型の場合は、期待する型がなんであれ到達しないので型チェックする必要がない - // 期待する型がNever型の場合は、、値が渡ってはいけないので型チェックするべき。途中でreturnした場合など改善の余地があると思われる。 - (_, Monotype::Never) => (), (_, _) => { self.errors .push(build_unify_error_from_unify_purpose(a_rep, b_rep, purpose)); diff --git a/crates/hir_ty/src/lib.rs b/crates/hir_ty/src/lib.rs index 1fa251ca..56a6bcf4 100644 --- a/crates/hir_ty/src/lib.rs +++ b/crates/hir_ty/src/lib.rs @@ -1669,12 +1669,36 @@ mod tests { return 10; //: ! } else { expr:true //: bool - }; //: () + }; //: bool expr:20 //: int } --- - error MismatchedTypeElseBranch: then_branch_ty: (), else_branch_ty: bool, then_branch: `{ tail:none }`, else_branch: `{ tail:true }` + --- + "#]], + ); + + check_in_root_file( + r#" + fn main() -> int { + if true { + return 10; + } else { + 20 + } + } + "#, + expect![[r#" + //- /main.nail + fn entry:main() -> int { + expr:if true { + return 10; //: ! + } else { + expr:20 //: int + } //: int + } + + --- --- "#]], ); @@ -1704,7 +1728,6 @@ mod tests { } --- - error MismatchedTypeElseBranch: then_branch_ty: bool, else_branch_ty: (), then_branch: `{ tail:true }`, else_branch: `{ tail:none }` --- "#]], ); @@ -1729,7 +1752,7 @@ mod tests { return 10; //: ! } else { return 20; //: ! - }; //: () + }; //: ! expr:30 //: int } diff --git a/crates/mir/src/lib.rs b/crates/mir/src/lib.rs index 053b69d6..360bd806 100755 --- a/crates/mir/src/lib.rs +++ b/crates/mir/src/lib.rs @@ -1328,7 +1328,7 @@ mod tests { expect![[r#" fn t_pod::main() -> bool { let _0: bool - let _1: bool + let _1: ! let _2: bool entry: { @@ -1357,7 +1357,7 @@ mod tests { expect![[r#" fn t_pod::main() -> int { let _0: int - let _1: () + let _1: ! entry: { _0 = const 10 @@ -1458,6 +1458,7 @@ mod tests { ); } + // TODO: _2は!ではなくintが正しい #[test] fn return_value_when_true_in_switch() { check_in_root_file( @@ -1474,7 +1475,7 @@ mod tests { fn t_pod::main() -> int { let _0: int let _1: bool - let _2: () + let _2: ! entry: { _1 = const true From 58a0b556eb54102156e4babb9539cbfcc6d53be7 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sun, 17 Sep 2023 00:14:09 +0900 Subject: [PATCH 33/35] fix --- crates/mir/src/body.rs | 7 ++++--- crates/mir/src/lib.rs | 7 +++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/crates/mir/src/body.rs b/crates/mir/src/body.rs index 6e1443ea..e1e363ee 100644 --- a/crates/mir/src/body.rs +++ b/crates/mir/src/body.rs @@ -292,8 +292,9 @@ impl<'a> FunctionLower<'a> { self.current_bb = Some(switch_bb.else_bb_idx); match else_branch { - Some(else_block) => { - let else_block = match else_block.lookup(self.hir_file.db(self.db)) { + Some(else_block_expr) => { + let else_block = match else_block_expr.lookup(self.hir_file.db(self.db)) + { hir::Expr::Block(block) => block, _ => unreachable!(), }; @@ -319,7 +320,7 @@ impl<'a> FunctionLower<'a> { None => { let idxes = self .alloc_dest_bb_and_result_local( - *then_branch, + *else_block_expr, ); dest_bb_and_result_local_idx = Some(idxes); diff --git a/crates/mir/src/lib.rs b/crates/mir/src/lib.rs index 360bd806..87fdf13d 100755 --- a/crates/mir/src/lib.rs +++ b/crates/mir/src/lib.rs @@ -1413,7 +1413,7 @@ mod tests { } #[test] - fn return_value_when_false_in_switch() { + fn return_value_in_else_branch() { check_in_root_file( r#" fn main() -> int { @@ -1458,9 +1458,8 @@ mod tests { ); } - // TODO: _2は!ではなくintが正しい #[test] - fn return_value_when_true_in_switch() { + fn return_value_in_then_branch() { check_in_root_file( r#" fn main() -> int { @@ -1475,7 +1474,7 @@ mod tests { fn t_pod::main() -> int { let _0: int let _1: bool - let _2: ! + let _2: int entry: { _1 = const true From 158401166ad121ba8a9abe8bc800f7c176681523 Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sun, 17 Sep 2023 00:16:01 +0900 Subject: [PATCH 34/35] wip --- crates/codegen_llvm/src/body.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/codegen_llvm/src/body.rs b/crates/codegen_llvm/src/body.rs index cfdfcff1..3f2087e0 100644 --- a/crates/codegen_llvm/src/body.rs +++ b/crates/codegen_llvm/src/body.rs @@ -46,7 +46,8 @@ impl<'a, 'ctx> BodyCodegen<'a, 'ctx> { hir_ty::Monotype::String => self.codegen.string_type().into(), hir_ty::Monotype::Bool => self.codegen.bool_type().into(), hir_ty::Monotype::Char => todo!(), - hir_ty::Monotype::Never => self.codegen.unit_type().into(), + // Never用にUnitを確保した方がいいかも? + hir_ty::Monotype::Never => continue, hir_ty::Monotype::Function(_) => todo!(), hir_ty::Monotype::Variable(_) => unreachable!(""), hir_ty::Monotype::Unknown => unreachable!(), From 3a83135b80d337597195d2a3cd7d24a7b312083d Mon Sep 17 00:00:00 2001 From: ktanaka101 Date: Sun, 17 Sep 2023 01:08:45 +0900 Subject: [PATCH 35/35] wip --- crates/hir_ty/src/inference.rs | 478 +++++++++++++++++++- crates/hir_ty/src/inference/environment.rs | 484 +-------------------- 2 files changed, 480 insertions(+), 482 deletions(-) diff --git a/crates/hir_ty/src/inference.rs b/crates/hir_ty/src/inference.rs index bf780da0..2c9fe27f 100644 --- a/crates/hir_ty/src/inference.rs +++ b/crates/hir_ty/src/inference.rs @@ -6,12 +6,15 @@ mod types; use std::collections::HashMap; -use environment::{Environment, InferBody}; -pub use environment::{InferenceBodyResult, InferenceResult, Signature}; +use environment::Environment; pub use error::InferenceError; pub use type_scheme::TypeScheme; pub use types::Monotype; +use self::{ + environment::Context, + type_unifier::{TypeUnifier, UnifyPurpose}, +}; use crate::HirTyMasterDatabase; /// Pod全体の型推論を行う @@ -67,3 +70,474 @@ fn lower_type(ty: &hir::Type) -> Monotype { hir::Type::Unknown => Monotype::Unknown, } } + +/// 関数シグネチャ +#[salsa::tracked] +pub struct Signature { + /// 関数のパラメータ + #[return_ref] + pub params: Vec, + /// 関数の戻り値 + pub return_type: Monotype, +} + +/// 関数内を型推論します。 +pub(crate) struct InferBody<'a> { + db: &'a dyn HirTyMasterDatabase, + pods: &'a hir::Pods, + hir_file: hir::HirFile, + function: hir::Function, + signature: Signature, + + unifier: TypeUnifier, + cxt: Context, + + /// 環境のスタック + /// + /// スコープを入れ子にするために使用しています。 + /// スコープに入る時にpushし、スコープから抜ける時はpopします。 + env_stack: Vec, + signature_by_function: &'a HashMap, + + /// 推論結果の式の型を記録するためのマップ + type_by_expr: HashMap, +} +impl<'a> InferBody<'a> { + pub(crate) fn new( + db: &'a dyn HirTyMasterDatabase, + pods: &'a hir::Pods, + hir_file: hir::HirFile, + function: hir::Function, + signature_by_function: &'a HashMap, + env: Environment, + ) -> Self { + InferBody { + db, + pods, + hir_file, + function, + signature: *signature_by_function.get(&function).unwrap(), + + unifier: TypeUnifier::new(), + cxt: Context::default(), + env_stack: vec![env], + signature_by_function, + type_by_expr: HashMap::new(), + } + } + + pub(crate) fn infer_body(mut self) -> InferenceBodyResult { + let hir::Expr::Block(body) = self.hir_file.function_body_by_function(self.db, self.function).unwrap() else { panic!("Should be Block.") }; + let mut has_never = false; + for stmt in &body.stmts { + let is_never = self.infer_stmt(stmt); + if is_never { + has_never = true; + } + } + + if let Some(tail) = &body.tail { + let ty = self.infer_expr(*tail); + let ty = if has_never { Monotype::Never } else { ty }; + self.unifier.unify( + &self.signature.return_type(self.db), + &ty, + &UnifyPurpose::ReturnValue { + expected_signature: self.signature, + found_return_expr: Some(*tail), + }, + ); + } else { + let ty = if has_never { + Monotype::Never + } else { + Monotype::Unit + }; + self.unifier.unify( + &self.signature.return_type(self.db), + &ty, + &UnifyPurpose::ReturnValue { + expected_signature: self.signature, + found_return_expr: None, + }, + ); + }; + + InferenceBodyResult { + type_by_expr: self.type_by_expr, + errors: self.unifier.errors, + } + } + + fn infer_type(&mut self, ty: &hir::Type) -> Monotype { + match ty { + hir::Type::Integer => Monotype::Integer, + hir::Type::String => Monotype::String, + hir::Type::Char => Monotype::Char, + hir::Type::Boolean => Monotype::Bool, + hir::Type::Unit => Monotype::Unit, + hir::Type::Unknown => Monotype::Unknown, + } + } + + fn infer_stmt(&mut self, stmt: &hir::Stmt) -> bool { + let ty = match stmt { + hir::Stmt::VariableDef { name: _, value } => { + let ty = self.infer_expr(*value); + let ty_scheme = TypeScheme::new(ty.clone()); + self.mut_current_scope().insert(*value, ty_scheme); + ty + } + hir::Stmt::ExprStmt { + expr, + has_semicolon: _, + } => self.infer_expr(*expr), + hir::Stmt::Item { .. } => return false, + }; + + ty == Monotype::Never + } + + fn infer_symbol(&mut self, symbol: &hir::Symbol) -> Monotype { + match symbol { + hir::Symbol::Param { name: _, param } => { + let param = param.data(self.hir_file.db(self.db)); + self.infer_type(¶m.ty) + } + hir::Symbol::Local { name: _, expr } => { + let ty_scheme = self.current_scope().get(expr).cloned(); + if let Some(ty_scheme) = ty_scheme { + ty_scheme.instantiate(&mut self.cxt) + } else { + panic!("Unbound variable {symbol:?}"); + } + } + hir::Symbol::Missing { path } => { + let resolution_status = self.pods.resolution_map.item_by_symbol(path).unwrap(); + match self.resolve_resolution_status(resolution_status) { + Some(item) => match item { + hir::Item::Function(function) => { + let signature = self.signature_by_function.get(&function).unwrap(); + Monotype::Function(*signature) + } + hir::Item::Module(module) => { + self.unifier.add_error(InferenceError::ModuleAsExpr { + found_module: module, + }); + Monotype::Unknown + } + hir::Item::UseItem(_) => unreachable!("UseItem should be resolved."), + }, + None => Monotype::Unknown, + } + } + } + } + + fn resolve_resolution_status( + &self, + resolution_status: hir::ResolutionStatus, + ) -> Option { + match resolution_status { + hir::ResolutionStatus::Unresolved | hir::ResolutionStatus::Error => None, + hir::ResolutionStatus::Resolved { path: _, item } => match item { + hir::Item::Function(_) | hir::Item::Module(_) => Some(item), + hir::Item::UseItem(use_item) => { + let resolution_status = self.pods.resolution_map.item_by_use_item(&use_item)?; + self.resolve_resolution_status(resolution_status) + } + }, + } + } + + fn infer_expr(&mut self, expr_id: hir::ExprId) -> Monotype { + let expr = expr_id.lookup(self.hir_file.db(self.db)); + let ty = match expr { + hir::Expr::Literal(literal) => match literal { + hir::Literal::Integer(_) => Monotype::Integer, + hir::Literal::String(_) => Monotype::String, + hir::Literal::Char(_) => Monotype::Char, + hir::Literal::Bool(_) => Monotype::Bool, + }, + hir::Expr::Missing => Monotype::Unknown, + hir::Expr::Symbol(symbol) => self.infer_symbol(symbol), + hir::Expr::Call { + callee, + args: call_args, + } => { + let callee_ty = self.infer_symbol(callee); + match callee_ty { + Monotype::Integer + | Monotype::Bool + | Monotype::Unit + | Monotype::Char + | Monotype::String + | Monotype::Never + | Monotype::Unknown + | Monotype::Variable(_) => { + self.unifier.add_error(InferenceError::NotCallable { + found_callee_ty: callee_ty, + found_callee_symbol: callee.clone(), + }); + Monotype::Unknown + } + Monotype::Function(signature) => { + let call_args_ty = call_args + .iter() + .map(|arg| self.infer_expr(*arg)) + .collect::>(); + + if call_args_ty.len() != signature.params(self.db).len() { + self.unifier + .add_error(InferenceError::MismatchedCallArgCount { + expected_callee_arg_count: signature.params(self.db).len(), + found_arg_count: call_args_ty.len(), + }); + } else { + for (((arg_pos, call_arg), call_arg_ty), signature_arg_ty) in call_args + .iter() + .enumerate() + .zip(call_args_ty) + .zip(signature.params(self.db)) + { + self.unifier.unify( + signature_arg_ty, + &call_arg_ty, + &UnifyPurpose::CallArg { + found_arg_expr: *call_arg, + callee_signature: signature, + arg_pos, + }, + ); + } + } + + // 引数の数が異なったとしても、関数の戻り値は返す。 + signature.return_type(self.db) + } + } + } + hir::Expr::Binary { op, lhs, rhs } => match op { + ast::BinaryOp::Add(_) + | ast::BinaryOp::Sub(_) + | ast::BinaryOp::Mul(_) + | ast::BinaryOp::Div(_) => { + let lhs_ty = self.infer_expr(*lhs); + let rhs_ty = self.infer_expr(*rhs); + self.unifier.unify( + &Monotype::Integer, + &lhs_ty, + &UnifyPurpose::BinaryInteger { + found_expr: *lhs, + op: op.clone(), + }, + ); + self.unifier.unify( + &Monotype::Integer, + &rhs_ty, + &UnifyPurpose::BinaryInteger { + found_expr: *rhs, + op: op.clone(), + }, + ); + + Monotype::Integer + } + ast::BinaryOp::Equal(_) + | ast::BinaryOp::GreaterThan(_) + | ast::BinaryOp::LessThan(_) => { + let lhs_ty = self.infer_expr(*lhs); + let rhs_ty = self.infer_expr(*rhs); + self.unifier.unify( + &lhs_ty, + &rhs_ty, + &UnifyPurpose::BinaryCompare { + found_compare_from_expr: *lhs, + found_compare_to_expr: *rhs, + op: op.clone(), + }, + ); + + Monotype::Bool + } + }, + hir::Expr::Unary { op, expr } => match op { + ast::UnaryOp::Neg(_) => { + let expr_ty = self.infer_expr(*expr); + self.unifier.unify( + &Monotype::Integer, + &expr_ty, + &UnifyPurpose::Unary { + found_expr: *expr, + op: op.clone(), + }, + ); + + Monotype::Integer + } + ast::UnaryOp::Not(_) => { + let expr_ty = self.infer_expr(*expr); + self.unifier.unify( + &Monotype::Bool, + &expr_ty, + &UnifyPurpose::Unary { + found_expr: *expr, + op: op.clone(), + }, + ); + + Monotype::Bool + } + }, + hir::Expr::Block(block) => { + self.entry_scope(); + + let mut has_never = false; + for stmt in &block.stmts { + let is_never = self.infer_stmt(stmt); + if is_never { + has_never = true; + } + } + + let ty = if let Some(tail) = &block.tail { + self.infer_expr(*tail) + } else { + // 最後の式がない場合は Unit として扱う + Monotype::Unit + }; + if ty == Monotype::Never { + has_never = true; + } + + self.exit_scope(); + + if has_never { + Monotype::Never + } else { + ty + } + } + hir::Expr::If { + condition, + then_branch, + else_branch, + } => { + let condition_ty = self.infer_expr(*condition); + self.unifier.unify( + &Monotype::Bool, + &condition_ty, + &UnifyPurpose::IfCondition { + found_condition_expr: *condition, + }, + ); + + let then_ty = self.infer_expr(*then_branch); + let mut else_ty: Option = None; + if let Some(else_branch) = else_branch { + let ty = self.infer_expr(*else_branch); + self.unifier.unify( + &then_ty, + &ty, + &UnifyPurpose::IfThenElseBranch { + found_then_branch_expr: *then_branch, + found_else_branch_expr: *else_branch, + }, + ); + else_ty = Some(ty); + } else { + // elseブランチがない場合は Unit として扱う + self.unifier.unify( + &Monotype::Unit, + &then_ty, + &UnifyPurpose::IfThenOnlyBranch { + found_then_branch_expr: *then_branch, + }, + ); + } + + if let Some(else_ty) = else_ty { + if then_ty == Monotype::Never { + else_ty + } else { + then_ty + } + } else if then_ty == Monotype::Never { + Monotype::Never + } else { + then_ty + } + } + hir::Expr::Return { value } => { + if let Some(return_value) = value { + let ty = self.infer_expr(*return_value); + self.unifier.unify( + &self.signature.return_type(self.db), + &ty, + &UnifyPurpose::ReturnValue { + expected_signature: self.signature, + found_return_expr: Some(*return_value), + }, + ); + } else { + // 何も指定しない場合は Unit を返すものとして扱う + self.unifier.unify( + &self.signature.return_type(self.db), + &Monotype::Unit, + &UnifyPurpose::ReturnValue { + expected_signature: self.signature, + found_return_expr: None, + }, + ); + } + + // return自体の戻り値は Never として扱う + Monotype::Never + } + }; + + self.type_by_expr.insert(expr_id, ty.clone()); + + ty + } + + fn entry_scope(&mut self) { + let env = self.env_stack.last().unwrap().with(); + self.env_stack.push(env); + } + + fn exit_scope(&mut self) { + self.env_stack.pop(); + } + + fn mut_current_scope(&mut self) -> &mut Environment { + self.env_stack.last_mut().unwrap() + } + + fn current_scope(&self) -> &Environment { + self.env_stack.last().unwrap() + } +} + +/// Pod全体の型推論結果 +#[derive(Debug)] +pub struct InferenceResult { + pub(crate) signature_by_function: HashMap, + pub(crate) inference_body_result_by_function: HashMap, +} + +/// 関数内の型推論結果 +#[derive(Debug)] +pub struct InferenceBodyResult { + pub(crate) type_by_expr: HashMap, + pub(crate) errors: Vec, +} +impl InferenceBodyResult { + pub fn type_by_expr(&self, expr: hir::ExprId) -> Option<&Monotype> { + self.type_by_expr.get(&expr) + } + + pub fn errors(&self) -> &Vec { + &self.errors + } +} diff --git a/crates/hir_ty/src/inference/environment.rs b/crates/hir_ty/src/inference/environment.rs index 78f33b5f..20aa286c 100644 --- a/crates/hir_ty/src/inference/environment.rs +++ b/crates/hir_ty/src/inference/environment.rs @@ -3,485 +3,9 @@ use std::{ ops::Sub, }; -use super::{ - error::InferenceError, - type_scheme::TypeScheme, - type_unifier::{TypeUnifier, UnifyPurpose}, - types::Monotype, -}; +use super::{type_scheme::TypeScheme, types::Monotype}; use crate::HirTyMasterDatabase; -/// 関数シグネチャ -#[salsa::tracked] -pub struct Signature { - /// 関数のパラメータ - #[return_ref] - pub params: Vec, - /// 関数の戻り値 - pub return_type: Monotype, -} - -/// 関数内を型推論します。 -pub(crate) struct InferBody<'a> { - db: &'a dyn HirTyMasterDatabase, - pods: &'a hir::Pods, - hir_file: hir::HirFile, - function: hir::Function, - signature: Signature, - - unifier: TypeUnifier, - cxt: Context, - - /// 環境のスタック - /// - /// スコープを入れ子にするために使用しています。 - /// スコープに入る時にpushし、スコープから抜ける時はpopします。 - env_stack: Vec, - signature_by_function: &'a HashMap, - - /// 推論結果の式の型を記録するためのマップ - type_by_expr: HashMap, -} -impl<'a> InferBody<'a> { - pub(crate) fn new( - db: &'a dyn HirTyMasterDatabase, - pods: &'a hir::Pods, - hir_file: hir::HirFile, - function: hir::Function, - signature_by_function: &'a HashMap, - env: Environment, - ) -> Self { - InferBody { - db, - pods, - hir_file, - function, - signature: *signature_by_function.get(&function).unwrap(), - - unifier: TypeUnifier::new(), - cxt: Context::default(), - env_stack: vec![env], - signature_by_function, - type_by_expr: HashMap::new(), - } - } - - pub(crate) fn infer_body(mut self) -> InferenceBodyResult { - let hir::Expr::Block(body) = self.hir_file.function_body_by_function(self.db, self.function).unwrap() else { panic!("Should be Block.") }; - let mut has_never = false; - for stmt in &body.stmts { - let is_never = self.infer_stmt(stmt); - if is_never { - has_never = true; - } - } - - if let Some(tail) = &body.tail { - let ty = self.infer_expr(*tail); - let ty = if has_never { Monotype::Never } else { ty }; - self.unifier.unify( - &self.signature.return_type(self.db), - &ty, - &UnifyPurpose::ReturnValue { - expected_signature: self.signature, - found_return_expr: Some(*tail), - }, - ); - } else { - let ty = if has_never { - Monotype::Never - } else { - Monotype::Unit - }; - self.unifier.unify( - &self.signature.return_type(self.db), - &ty, - &UnifyPurpose::ReturnValue { - expected_signature: self.signature, - found_return_expr: None, - }, - ); - }; - - InferenceBodyResult { - type_by_expr: self.type_by_expr, - errors: self.unifier.errors, - } - } - - fn infer_type(&mut self, ty: &hir::Type) -> Monotype { - match ty { - hir::Type::Integer => Monotype::Integer, - hir::Type::String => Monotype::String, - hir::Type::Char => Monotype::Char, - hir::Type::Boolean => Monotype::Bool, - hir::Type::Unit => Monotype::Unit, - hir::Type::Unknown => Monotype::Unknown, - } - } - - fn infer_stmt(&mut self, stmt: &hir::Stmt) -> bool { - let ty = match stmt { - hir::Stmt::VariableDef { name: _, value } => { - let ty = self.infer_expr(*value); - let ty_scheme = TypeScheme::new(ty.clone()); - self.mut_current_scope().insert(*value, ty_scheme); - ty - } - hir::Stmt::ExprStmt { - expr, - has_semicolon: _, - } => self.infer_expr(*expr), - hir::Stmt::Item { .. } => return false, - }; - - ty == Monotype::Never - } - - fn infer_symbol(&mut self, symbol: &hir::Symbol) -> Monotype { - match symbol { - hir::Symbol::Param { name: _, param } => { - let param = param.data(self.hir_file.db(self.db)); - self.infer_type(¶m.ty) - } - hir::Symbol::Local { name: _, expr } => { - let ty_scheme = self.current_scope().get(expr).cloned(); - if let Some(ty_scheme) = ty_scheme { - ty_scheme.instantiate(&mut self.cxt) - } else { - panic!("Unbound variable {symbol:?}"); - } - } - hir::Symbol::Missing { path } => { - let resolution_status = self.pods.resolution_map.item_by_symbol(path).unwrap(); - match self.resolve_resolution_status(resolution_status) { - Some(item) => match item { - hir::Item::Function(function) => { - let signature = self.signature_by_function.get(&function).unwrap(); - Monotype::Function(*signature) - } - hir::Item::Module(module) => { - self.unifier.add_error(InferenceError::ModuleAsExpr { - found_module: module, - }); - Monotype::Unknown - } - hir::Item::UseItem(_) => unreachable!("UseItem should be resolved."), - }, - None => Monotype::Unknown, - } - } - } - } - - fn resolve_resolution_status( - &self, - resolution_status: hir::ResolutionStatus, - ) -> Option { - match resolution_status { - hir::ResolutionStatus::Unresolved | hir::ResolutionStatus::Error => None, - hir::ResolutionStatus::Resolved { path: _, item } => match item { - hir::Item::Function(_) | hir::Item::Module(_) => Some(item), - hir::Item::UseItem(use_item) => { - let resolution_status = self.pods.resolution_map.item_by_use_item(&use_item)?; - self.resolve_resolution_status(resolution_status) - } - }, - } - } - - fn infer_expr(&mut self, expr_id: hir::ExprId) -> Monotype { - let expr = expr_id.lookup(self.hir_file.db(self.db)); - let ty = match expr { - hir::Expr::Literal(literal) => match literal { - hir::Literal::Integer(_) => Monotype::Integer, - hir::Literal::String(_) => Monotype::String, - hir::Literal::Char(_) => Monotype::Char, - hir::Literal::Bool(_) => Monotype::Bool, - }, - hir::Expr::Missing => Monotype::Unknown, - hir::Expr::Symbol(symbol) => self.infer_symbol(symbol), - hir::Expr::Call { - callee, - args: call_args, - } => { - let callee_ty = self.infer_symbol(callee); - match callee_ty { - Monotype::Integer - | Monotype::Bool - | Monotype::Unit - | Monotype::Char - | Monotype::String - | Monotype::Never - | Monotype::Unknown - | Monotype::Variable(_) => { - self.unifier.add_error(InferenceError::NotCallable { - found_callee_ty: callee_ty, - found_callee_symbol: callee.clone(), - }); - Monotype::Unknown - } - Monotype::Function(signature) => { - let call_args_ty = call_args - .iter() - .map(|arg| self.infer_expr(*arg)) - .collect::>(); - - if call_args_ty.len() != signature.params(self.db).len() { - self.unifier - .add_error(InferenceError::MismatchedCallArgCount { - expected_callee_arg_count: signature.params(self.db).len(), - found_arg_count: call_args_ty.len(), - }); - } else { - for (((arg_pos, call_arg), call_arg_ty), signature_arg_ty) in call_args - .iter() - .enumerate() - .zip(call_args_ty) - .zip(signature.params(self.db)) - { - self.unifier.unify( - signature_arg_ty, - &call_arg_ty, - &UnifyPurpose::CallArg { - found_arg_expr: *call_arg, - callee_signature: signature, - arg_pos, - }, - ); - } - } - - // 引数の数が異なったとしても、関数の戻り値は返す。 - signature.return_type(self.db) - } - } - } - hir::Expr::Binary { op, lhs, rhs } => match op { - ast::BinaryOp::Add(_) - | ast::BinaryOp::Sub(_) - | ast::BinaryOp::Mul(_) - | ast::BinaryOp::Div(_) => { - let lhs_ty = self.infer_expr(*lhs); - let rhs_ty = self.infer_expr(*rhs); - self.unifier.unify( - &Monotype::Integer, - &lhs_ty, - &UnifyPurpose::BinaryInteger { - found_expr: *lhs, - op: op.clone(), - }, - ); - self.unifier.unify( - &Monotype::Integer, - &rhs_ty, - &UnifyPurpose::BinaryInteger { - found_expr: *rhs, - op: op.clone(), - }, - ); - - Monotype::Integer - } - ast::BinaryOp::Equal(_) - | ast::BinaryOp::GreaterThan(_) - | ast::BinaryOp::LessThan(_) => { - let lhs_ty = self.infer_expr(*lhs); - let rhs_ty = self.infer_expr(*rhs); - self.unifier.unify( - &lhs_ty, - &rhs_ty, - &UnifyPurpose::BinaryCompare { - found_compare_from_expr: *lhs, - found_compare_to_expr: *rhs, - op: op.clone(), - }, - ); - - Monotype::Bool - } - }, - hir::Expr::Unary { op, expr } => match op { - ast::UnaryOp::Neg(_) => { - let expr_ty = self.infer_expr(*expr); - self.unifier.unify( - &Monotype::Integer, - &expr_ty, - &UnifyPurpose::Unary { - found_expr: *expr, - op: op.clone(), - }, - ); - - Monotype::Integer - } - ast::UnaryOp::Not(_) => { - let expr_ty = self.infer_expr(*expr); - self.unifier.unify( - &Monotype::Bool, - &expr_ty, - &UnifyPurpose::Unary { - found_expr: *expr, - op: op.clone(), - }, - ); - - Monotype::Bool - } - }, - hir::Expr::Block(block) => { - self.entry_scope(); - - let mut has_never = false; - for stmt in &block.stmts { - let is_never = self.infer_stmt(stmt); - if is_never { - has_never = true; - } - } - - let ty = if let Some(tail) = &block.tail { - self.infer_expr(*tail) - } else { - // 最後の式がない場合は Unit として扱う - Monotype::Unit - }; - if ty == Monotype::Never { - has_never = true; - } - - self.exit_scope(); - - if has_never { - Monotype::Never - } else { - ty - } - } - hir::Expr::If { - condition, - then_branch, - else_branch, - } => { - let condition_ty = self.infer_expr(*condition); - self.unifier.unify( - &Monotype::Bool, - &condition_ty, - &UnifyPurpose::IfCondition { - found_condition_expr: *condition, - }, - ); - - let then_ty = self.infer_expr(*then_branch); - let mut else_ty: Option = None; - if let Some(else_branch) = else_branch { - let ty = self.infer_expr(*else_branch); - self.unifier.unify( - &then_ty, - &ty, - &UnifyPurpose::IfThenElseBranch { - found_then_branch_expr: *then_branch, - found_else_branch_expr: *else_branch, - }, - ); - else_ty = Some(ty); - } else { - // elseブランチがない場合は Unit として扱う - self.unifier.unify( - &Monotype::Unit, - &then_ty, - &UnifyPurpose::IfThenOnlyBranch { - found_then_branch_expr: *then_branch, - }, - ); - } - - if let Some(else_ty) = else_ty { - if then_ty == Monotype::Never { - else_ty - } else { - then_ty - } - } else if then_ty == Monotype::Never { - Monotype::Never - } else { - then_ty - } - } - hir::Expr::Return { value } => { - if let Some(return_value) = value { - let ty = self.infer_expr(*return_value); - self.unifier.unify( - &self.signature.return_type(self.db), - &ty, - &UnifyPurpose::ReturnValue { - expected_signature: self.signature, - found_return_expr: Some(*return_value), - }, - ); - } else { - // 何も指定しない場合は Unit を返すものとして扱う - self.unifier.unify( - &self.signature.return_type(self.db), - &Monotype::Unit, - &UnifyPurpose::ReturnValue { - expected_signature: self.signature, - found_return_expr: None, - }, - ); - } - - // return自体の戻り値は Never として扱う - Monotype::Never - } - }; - - self.type_by_expr.insert(expr_id, ty.clone()); - - ty - } - - fn entry_scope(&mut self) { - let env = self.env_stack.last().unwrap().with(); - self.env_stack.push(env); - } - - fn exit_scope(&mut self) { - self.env_stack.pop(); - } - - fn mut_current_scope(&mut self) -> &mut Environment { - self.env_stack.last_mut().unwrap() - } - - fn current_scope(&self) -> &Environment { - self.env_stack.last().unwrap() - } -} - -/// Pod全体の型推論結果 -#[derive(Debug)] -pub struct InferenceResult { - pub(crate) signature_by_function: HashMap, - pub(crate) inference_body_result_by_function: HashMap, -} - -/// 関数内の型推論結果 -#[derive(Debug)] -pub struct InferenceBodyResult { - pub(crate) type_by_expr: HashMap, - pub(crate) errors: Vec, -} -impl InferenceBodyResult { - pub fn type_by_expr(&self, expr: hir::ExprId) -> Option<&Monotype> { - self.type_by_expr.get(&expr) - } - - pub fn errors(&self) -> &Vec { - &self.errors - } -} - /// Hindley-Milner型システムにおける型環境 #[derive(Default)] pub struct Environment { @@ -528,7 +52,7 @@ impl Environment { union } - fn with(&self) -> Environment { + pub(crate) fn with(&self) -> Environment { let mut copy = HashMap::new(); // FIXME: clone かつサイズが不定なので遅いかも。 copy.extend(self.bindings.clone()); @@ -536,11 +60,11 @@ impl Environment { Environment { bindings: copy } } - fn get(&self, expr: &hir::ExprId) -> Option<&TypeScheme> { + pub(crate) fn get(&self, expr: &hir::ExprId) -> Option<&TypeScheme> { self.bindings.get(expr) } - fn insert(&mut self, expr: hir::ExprId, ty_scheme: TypeScheme) { + pub(crate) fn insert(&mut self, expr: hir::ExprId, ty_scheme: TypeScheme) { self.bindings.insert(expr, ty_scheme); }