From 1559845ee252577c5191c9ce43f3bf6f6ef98239 Mon Sep 17 00:00:00 2001 From: dak2 Date: Thu, 12 Mar 2026 23:14:08 +0900 Subject: [PATCH] Add complete multi-assignment type inference Support splat targets (*rest) as Array[T], nil-fill when LHS exceeds RHS, right-side targets after splat, and basic single-expression array decomposition. This resolves the v0.1.8 TODOs for multi-assignment handling except for method return value decomposition which requires graph lazy resolution. Co-Authored-By: Claude Opus 4.6 --- core/src/analyzer/assignments.rs | 397 +++++++++++++++++++++++++++++-- test/multi_assign_test.rb | 83 +++++++ 2 files changed, 455 insertions(+), 25 deletions(-) create mode 100644 test/multi_assign_test.rb diff --git a/core/src/analyzer/assignments.rs b/core/src/analyzer/assignments.rs index 5f82848..7449a85 100644 --- a/core/src/analyzer/assignments.rs +++ b/core/src/analyzer/assignments.rs @@ -1,17 +1,81 @@ //! Multiple Assignment Handlers - Processing Ruby multiple assignment //! -//! v0.1.8 scope: Only RHS as ArrayNode (multiple literal values) is supported. -//! TODO: Support RHS as single expression (array decomposition) -//! TODO: Support splat target (*rest) as Array type -//! TODO: Support RHS as method return value decomposition -//! TODO: When LHS is longer than RHS, register trailing targets as NilClass +//! Supports: ArrayNode RHS with 1:1 mapping, LHS > RHS nil fill, +//! splat targets (*rest) as Array type, and basic single-expression RHS decomposition. +//! TODO: Support RHS as method return value decomposition (requires graph lazy resolution) use crate::env::{GlobalEnv, LocalEnv}; use crate::graph::{ChangeSet, VertexId}; +use crate::types::Type; use super::bytes_to_name; use super::variables::install_local_var_write; +/// Install an RHS node and assign it to a named local variable. +/// Falls back to an untyped vertex when `install_node` returns `None`. +fn install_target( + genv: &mut GlobalEnv, + lenv: &mut LocalEnv, + changes: &mut ChangeSet, + source: &str, + var_name: String, + rhs_node: &ruby_prism::Node, +) -> VertexId { + if let Some(rv) = super::install::install_node(genv, lenv, changes, source, rhs_node) { + install_local_var_write(genv, lenv, changes, var_name, rv) + } else { + let vtx = genv.new_vertex(); + lenv.new_var(var_name, vtx); + vtx + } +} + +/// Assign `Type::Nil` to a named local variable. +fn install_nil_target( + genv: &mut GlobalEnv, + lenv: &mut LocalEnv, + changes: &mut ChangeSet, + var_name: String, +) -> VertexId { + let nil_src = genv.new_source(Type::Nil); + install_local_var_write(genv, lenv, changes, var_name, nil_src) +} + +/// Extract the splat target variable name from a `node.rest()` result. +fn splat_var_name(rest_node: &ruby_prism::Node) -> Option { + let splat = rest_node.as_splat_node()?; + let expr = splat.expression()?; + let target = expr.as_local_variable_target_node()?; + Some(bytes_to_name(target.name().as_slice())) +} + +/// Collect unique types from a slice of elements, returning element type for Array[T]. +/// Only nodes that resolve to a Source (fixed type) contribute; Vertex-type nodes +/// are excluded (known limitation — requires lazy resolution for method return values). +fn collect_element_type( + genv: &mut GlobalEnv, + lenv: &mut LocalEnv, + changes: &mut ChangeSet, + source: &str, + elements: &[ruby_prism::Node], +) -> Type { + let mut types: Vec = Vec::new(); + for elem in elements { + if let Some(vtx) = super::install::install_node(genv, lenv, changes, source, elem) { + if let Some(src) = genv.get_source(vtx) { + if !types.contains(&src.ty) { + types.push(src.ty.clone()); + } + } + } + } + if types.is_empty() { + Type::Bot + } else { + Type::union_of(types) + } +} + /// Process multiple assignment node (e.g., `a, b = 1, "hello"`) pub(crate) fn process_multi_write_node( genv: &mut GlobalEnv, @@ -24,27 +88,129 @@ pub(crate) fn process_multi_write_node( let mut last_vtx = None; if let Some(array_node) = value.as_array_node() { - for (target, rhs_elem) in node.lefts().iter().zip(array_node.elements().iter()) { + let lefts = node.lefts(); + let elements: Vec<_> = array_node.elements().iter().collect(); + let total = elements.len(); + let lefts_count = lefts.len(); + let rights = node.rights(); + let rights_count = rights.len(); + + // Phase 1: Left targets — assign from start, nil for missing RHS + for (i, target) in lefts.iter().enumerate() { if let Some(target_node) = target.as_local_variable_target_node() { let var_name = bytes_to_name(target_node.name().as_slice()); - let rhs_vtx = - super::install::install_node(genv, lenv, changes, source, &rhs_elem); - if let Some(rv) = rhs_vtx { - last_vtx = Some(install_local_var_write(genv, lenv, changes, var_name, rv)); + if i < total { + last_vtx = Some(install_target( + genv, + lenv, + changes, + source, + var_name, + &elements[i], + )); } else { - let var_vtx = genv.new_vertex(); - lenv.new_var(var_name, var_vtx); - last_vtx = Some(var_vtx); + last_vtx = Some(install_nil_target(genv, lenv, changes, var_name)); + } + } + } + + // Phase 2: Splat target (*rest) — collect middle elements (after lefts, before rights) into Array[T] + if let Some(rest_node) = node.rest() { + if let Some(var_name) = splat_var_name(&rest_node) { + let splat_start = lefts_count; + let splat_end = total.saturating_sub(rights_count); + let splat_elements = if splat_start < splat_end { + &elements[splat_start..splat_end] + } else { + &elements[0..0] + }; + let element_type = collect_element_type( + genv, + lenv, + changes, + source, + splat_elements, + ); + let array_src = genv.new_source(Type::array_of(element_type)); + last_vtx = Some(install_local_var_write( + genv, lenv, changes, var_name, array_src, + )); + } + } + + // Phase 3: Right targets — assigned from end of elements, nil if overlapping with lefts + for (i, target) in rights.iter().enumerate() { + if let Some(target_node) = target.as_local_variable_target_node() { + let var_name = bytes_to_name(target_node.name().as_slice()); + let signed_idx = total as isize - rights_count as isize + i as isize; + if signed_idx >= lefts_count as isize && (signed_idx as usize) < total { + last_vtx = Some(install_target( + genv, + lenv, + changes, + source, + var_name, + &elements[signed_idx as usize], + )); + } else { + last_vtx = Some(install_nil_target(genv, lenv, changes, var_name)); } } } } else { - for target in node.lefts().iter() { + // RHS is a single expression (not comma-separated) + let rhs_vtx = super::install::install_node(genv, lenv, changes, source, &value); + + let rhs_type = rhs_vtx + .and_then(|vtx| genv.get_source(vtx)) + .map(|src| src.ty.clone()); + + // If RHS is Array[T], each target gets T; otherwise first target gets RHS, rest get nil + let element_type = rhs_type + .as_ref() + .and_then(|ty| ty.type_args()) + .and_then(|args| args.first().cloned()); + + for (i, target) in node.lefts().iter().enumerate() { if let Some(target_node) = target.as_local_variable_target_node() { let var_name = bytes_to_name(target_node.name().as_slice()); - let var_vtx = genv.new_vertex(); - lenv.new_var(var_name, var_vtx); - last_vtx = Some(var_vtx); + if let Some(ref elem_ty) = element_type { + let src = genv.new_source(elem_ty.clone()); + last_vtx = Some(install_local_var_write(genv, lenv, changes, var_name, src)); + } else if i == 0 { + if let Some(rv) = rhs_vtx { + last_vtx = Some(install_local_var_write(genv, lenv, changes, var_name, rv)); + } else { + let vtx = genv.new_vertex(); + lenv.new_var(var_name, vtx); + last_vtx = Some(vtx); + } + } else if rhs_type.is_some() { + last_vtx = Some(install_nil_target(genv, lenv, changes, var_name)); + } else { + let vtx = genv.new_vertex(); + lenv.new_var(var_name, vtx); + last_vtx = Some(vtx); + } + } + } + + // Splat in single-expression RHS + if let Some(rest_node) = node.rest() { + if let Some(var_name) = splat_var_name(&rest_node) { + let elem_ty = element_type.unwrap_or(Type::Bot); + let array_src = genv.new_source(Type::array_of(elem_ty)); + last_vtx = Some(install_local_var_write( + genv, lenv, changes, var_name, array_src, + )); + } + } + + // Right targets in single-expression RHS → nil + for target in node.rights().iter() { + if let Some(target_node) = target.as_local_variable_target_node() { + let var_name = bytes_to_name(target_node.name().as_slice()); + last_vtx = Some(install_nil_target(genv, lenv, changes, var_name)); } } } @@ -129,15 +295,16 @@ x = a #[test] fn test_multi_write_lhs_longer_than_rhs() { let source = "a, b, c = 1, 2"; - let (_, lenv) = analyze(source); + let (genv, lenv) = analyze(source); - assert!(lenv.get_var("a").is_some(), "a should be registered"); - assert!(lenv.get_var("b").is_some(), "b should be registered"); - // KNOWN LIMITATION (v0.1.8): In Ruby, c receives nil, but zip skips it here - assert!( - lenv.get_var("c").is_none(), - "c should not be registered (zip skips)" - ); + let a_vtx = lenv.get_var("a").expect("a should be registered"); + assert_eq!(get_type_show(&genv, a_vtx), "Integer"); + + let b_vtx = lenv.get_var("b").expect("b should be registered"); + assert_eq!(get_type_show(&genv, b_vtx), "Integer"); + + let c_vtx = lenv.get_var("c").expect("c should be registered with nil"); + assert_eq!(get_type_show(&genv, c_vtx), "nil"); } #[test] @@ -149,4 +316,184 @@ x = a assert!(lenv.get_var("a").is_some(), "a should be registered"); assert!(lenv.get_var("b").is_some(), "b should be registered"); } + + #[test] + fn test_multi_write_splat_basic() { + let source = "first, *rest = 1, 2, 3"; + let (genv, lenv) = analyze(source); + + let first_vtx = lenv.get_var("first").expect("first should be registered"); + assert_eq!(get_type_show(&genv, first_vtx), "Integer"); + + let rest_vtx = lenv.get_var("rest").expect("rest should be registered"); + assert_eq!(get_type_show(&genv, rest_vtx), "Array[Integer]"); + } + + #[test] + fn test_multi_write_splat_mixed_types() { + let source = r#"first, *rest = 1, "hello", :sym"#; + let (genv, lenv) = analyze(source); + + let first_vtx = lenv.get_var("first").expect("first should be registered"); + assert_eq!(get_type_show(&genv, first_vtx), "Integer"); + + let rest_vtx = lenv.get_var("rest").expect("rest should be registered"); + let type_str = get_type_show(&genv, rest_vtx); + assert!( + type_str.contains("Array"), + "should be Array type: {}", + type_str + ); + assert!( + type_str.contains("String"), + "should contain String: {}", + type_str + ); + assert!( + type_str.contains("Symbol"), + "should contain Symbol: {}", + type_str + ); + } + + #[test] + fn test_multi_write_splat_empty() { + let source = "first, *rest = 1"; + let (genv, lenv) = analyze(source); + + let first_vtx = lenv.get_var("first").expect("first should be registered"); + assert_eq!(get_type_show(&genv, first_vtx), "Integer"); + + let rest_vtx = lenv.get_var("rest").expect("rest should be registered"); + assert_eq!(get_type_show(&genv, rest_vtx), "Array[untyped]"); + } + + #[test] + fn test_multi_write_splat_with_rights() { + let source = "first, *rest, last = 1, 2, 3, 4"; + let (genv, lenv) = analyze(source); + + let first_vtx = lenv.get_var("first").expect("first should be registered"); + assert_eq!(get_type_show(&genv, first_vtx), "Integer"); + + let rest_vtx = lenv.get_var("rest").expect("rest should be registered"); + assert_eq!(get_type_show(&genv, rest_vtx), "Array[Integer]"); + + let last_vtx = lenv.get_var("last").expect("last should be registered"); + assert_eq!(get_type_show(&genv, last_vtx), "Integer"); + } + + #[test] + fn test_multi_write_splat_only() { + let source = "*all = 1, 2, 3"; + let (genv, lenv) = analyze(source); + + let all_vtx = lenv.get_var("all").expect("all should be registered"); + assert_eq!(get_type_show(&genv, all_vtx), "Array[Integer]"); + } + + #[test] + fn test_multi_write_splat_rights_no_lefts() { + let source = "*rest, last = 1, 2, 3"; + let (genv, lenv) = analyze(source); + + let rest_vtx = lenv.get_var("rest").expect("rest should be registered"); + assert_eq!(get_type_show(&genv, rest_vtx), "Array[Integer]"); + + let last_vtx = lenv.get_var("last").expect("last should be registered"); + assert_eq!(get_type_show(&genv, last_vtx), "Integer"); + } + + #[test] + fn test_multi_write_array_literal_rhs() { + // Explicit array literal RHS is decomposed element-by-element (same as comma-separated form) + let source = r#"a, b = [1, "hi"]"#; + let (genv, lenv) = analyze(source); + + let a_vtx = lenv.get_var("a").expect("a should be registered"); + assert_eq!(get_type_show(&genv, a_vtx), "Integer"); + + let b_vtx = lenv.get_var("b").expect("b should be registered"); + assert_eq!(get_type_show(&genv, b_vtx), "String"); + } + + #[test] + fn test_multi_write_splat_lefts_exceed_rhs() { + // Edge case: more left targets than RHS elements with splat + let source = "a, b, c, *rest = 1, 2"; + let (genv, lenv) = analyze(source); + + let a_vtx = lenv.get_var("a").expect("a should be registered"); + assert_eq!(get_type_show(&genv, a_vtx), "Integer"); + + let b_vtx = lenv.get_var("b").expect("b should be registered"); + assert_eq!(get_type_show(&genv, b_vtx), "Integer"); + + let c_vtx = lenv.get_var("c").expect("c should be registered"); + assert_eq!(get_type_show(&genv, c_vtx), "nil"); + + let rest_vtx = lenv.get_var("rest").expect("rest should be registered"); + assert_eq!(get_type_show(&genv, rest_vtx), "Array[untyped]"); + } + + #[test] + fn test_multi_write_splat_with_rights_insufficient_rhs() { + // Edge case: lefts + rights > total elements, splat between them + let source = "a, *rest, z = 1"; + let (genv, lenv) = analyze(source); + + let a_vtx = lenv.get_var("a").expect("a should be registered"); + assert_eq!(get_type_show(&genv, a_vtx), "Integer"); + + let rest_vtx = lenv.get_var("rest").expect("rest should be registered"); + assert_eq!(get_type_show(&genv, rest_vtx), "Array[untyped]"); + + let z_vtx = lenv.get_var("z").expect("z should be registered"); + assert_eq!(get_type_show(&genv, z_vtx), "nil"); + } + + #[test] + fn test_multi_write_rights_exceed_rhs() { + // Edge case: more right targets than available elements + let source = r#"*rest, x, y, z = "a", 1"#; + let (genv, lenv) = analyze(source); + + let rest_vtx = lenv.get_var("rest").expect("rest should be registered"); + assert_eq!(get_type_show(&genv, rest_vtx), "Array[untyped]"); + + let x_vtx = lenv.get_var("x").expect("x should be registered"); + assert_eq!(get_type_show(&genv, x_vtx), "nil"); + + let y_vtx = lenv.get_var("y").expect("y should be registered"); + assert_eq!(get_type_show(&genv, y_vtx), "String"); + + let z_vtx = lenv.get_var("z").expect("z should be registered"); + assert_eq!(get_type_show(&genv, z_vtx), "Integer"); + } + + #[test] + fn test_multi_write_scalar_rhs() { + // Single non-array expression: first target gets value, rest get nil + let source = "a, b = 42"; + let (genv, lenv) = analyze(source); + + let a_vtx = lenv.get_var("a").expect("a should be registered"); + assert_eq!(get_type_show(&genv, a_vtx), "Integer"); + + let b_vtx = lenv.get_var("b").expect("b should be registered"); + assert_eq!(get_type_show(&genv, b_vtx), "nil"); + } + + #[test] + fn test_multi_write_rhs_longer_than_lhs() { + // Extra RHS elements are silently discarded + let source = "a, b = 1, 2, 3, 4"; + let (genv, lenv) = analyze(source); + + let a_vtx = lenv.get_var("a").expect("a should be registered"); + assert_eq!(get_type_show(&genv, a_vtx), "Integer"); + + let b_vtx = lenv.get_var("b").expect("b should be registered"); + assert_eq!(get_type_show(&genv, b_vtx), "Integer"); + } } diff --git a/test/multi_assign_test.rb b/test/multi_assign_test.rb new file mode 100644 index 0000000..a86c578 --- /dev/null +++ b/test/multi_assign_test.rb @@ -0,0 +1,83 @@ +# frozen_string_literal: true + +require 'test_helper' + +class MultiAssignTest < Minitest::Test + include CLITestHelper + + # ============================================ + # Type Inference + # ============================================ + + def test_basic_multi_assign + types = infer('a, b = 1, "hello"') + assert_equal "Integer", types["a"] + assert_equal "String", types["b"] + end + + def test_lhs_longer_than_rhs + types = infer("a, b, c = 1, 2") + assert_equal "Integer", types["a"] + assert_equal "Integer", types["b"] + assert_equal "nil", types["c"] + end + + def test_splat_basic + types = infer("first, *rest = 1, 2, 3") + assert_equal "Integer", types["first"] + assert_equal "Array[Integer]", types["rest"] + end + + def test_splat_with_rights + types = infer("first, *rest, last = 1, 2, 3, 4") + assert_equal "Integer", types["first"] + assert_equal "Array[Integer]", types["rest"] + assert_equal "Integer", types["last"] + end + + def test_splat_rights_no_lefts + types = infer("*rest, last = 1, 2, 3") + assert_equal "Array[Integer]", types["rest"] + assert_equal "Integer", types["last"] + end + + def test_splat_empty + types = infer("first, *rest = 1") + assert_equal "Integer", types["first"] + assert_equal "Array[untyped]", types["rest"] + end + + def test_splat_lefts_exceed_rhs + types = infer("a, b, c, *rest = 1, 2") + assert_equal "Integer", types["a"] + assert_equal "Integer", types["b"] + assert_equal "nil", types["c"] + assert_equal "Array[untyped]", types["rest"] + end + + def test_splat_with_rights_insufficient_rhs + types = infer("a, *rest, z = 1") + assert_equal "Integer", types["a"] + assert_equal "Array[untyped]", types["rest"] + assert_equal "nil", types["z"] + end + + def test_scalar_rhs + types = infer("a, b = 42") + assert_equal "Integer", types["a"] + assert_equal "nil", types["b"] + end + + # ============================================ + # Error Detection + # ============================================ + + def test_multi_assign_type_error + source = <<~RUBY + a, b = 1, 2 + a.upcase + RUBY + + assert_check_error(source, method_name: 'upcase', receiver_type: 'Integer') + end +end