From 8b695609c67ad2009bef96ac7c38382408176c9e Mon Sep 17 00:00:00 2001 From: Nico Lehmann Date: Fri, 7 Jun 2024 11:31:05 -0700 Subject: [PATCH] Support unification of Aggregate and variable --- crates/flux-middle/src/rty/expr.rs | 15 ++++++++++--- crates/flux-refineck/src/constraint_gen.rs | 21 ++++++++++++++----- crates/flux-refineck/src/type_env/place_ty.rs | 2 +- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/crates/flux-middle/src/rty/expr.rs b/crates/flux-middle/src/rty/expr.rs index 018e22337..5d794b119 100644 --- a/crates/flux-middle/src/rty/expr.rs +++ b/crates/flux-middle/src/rty/expr.rs @@ -178,10 +178,19 @@ pub enum ExprKind { #[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable, Debug)] pub enum AggregateKind { - Tuple, + Tuple(usize), Adt(DefId), } +impl AggregateKind { + pub fn to_proj(self, field: u32) -> FieldProj { + match self { + AggregateKind::Tuple(arity) => FieldProj::Tuple { arity, field }, + AggregateKind::Adt(def_id) => FieldProj::Adt { def_id, field }, + } + } +} + #[derive(Clone, Copy, PartialEq, Eq, Hash, TyEncodable, TyDecodable, Debug)] pub enum FieldProj { Tuple { arity: usize, field: u32 }, @@ -426,7 +435,7 @@ impl Expr { } pub fn tuple(flds: List) -> Expr { - Expr::aggregate(AggregateKind::Tuple, flds) + Expr::aggregate(AggregateKind::Tuple(flds.len()), flds) } pub fn adt(def_id: DefId, flds: List) -> Expr { @@ -945,7 +954,7 @@ mod pretty { w!("({:?}).{:?}", e, ^proj.field_idx()) } } - ExprKind::Aggregate(AggregateKind::Tuple, flds) => { + ExprKind::Aggregate(AggregateKind::Tuple(_), flds) => { if let [e] = &flds[..] { w!("({:?},)", e) } else { diff --git a/crates/flux-refineck/src/constraint_gen.rs b/crates/flux-refineck/src/constraint_gen.rs index a1df6d7ce..0c220a872 100644 --- a/crates/flux-refineck/src/constraint_gen.rs +++ b/crates/flux-refineck/src/constraint_gen.rs @@ -8,9 +8,9 @@ use flux_middle::{ self, evars::{EVarCxId, EVarSol}, fold::TypeFoldable, - AliasTy, BaseTy, BinOp, Binder, Constraint, CoroutineObligPredicate, ESpan, EVarGen, - EarlyBinder, Expr, ExprKind, FnOutput, GenericArg, HoleKind, InferMode, Lambda, Mutability, - Path, PolyFnSig, PolyVariant, PtrKind, Ref, Sort, Ty, TyKind, Var, + AliasTy, BaseTy, Binder, Constraint, CoroutineObligPredicate, ESpan, EVarGen, EarlyBinder, + Expr, ExprKind, FnOutput, GenericArg, HoleKind, InferMode, Lambda, Mutability, Path, + PolyFnSig, PolyVariant, PtrKind, Ref, Sort, Ty, TyKind, Var, }, rustc::mir::{BasicBlock, Place}, }; @@ -688,12 +688,23 @@ impl<'a, 'genv, 'tcx> InferCtxt<'a, 'genv, 'tcx> { match (e1.kind(), e2.kind()) { (ExprKind::Aggregate(kind1, flds1), ExprKind::Aggregate(kind2, flds2)) => { debug_assert_eq!(kind1, kind2); - debug_assert_eq!(flds1.len(), flds2.len()); - for (e1, e2) in iter::zip(flds1, flds2) { self.idx_eq(rcx, e1, e2); } } + (_, ExprKind::Aggregate(kind2, flds2)) => { + for (f, e2) in flds2.iter().enumerate() { + let e1 = e1.proj_and_reduce(kind2.to_proj(f as u32)); + self.idx_eq(rcx, &e1, e2); + } + } + (ExprKind::Aggregate(kind1, flds1), _) => { + self.unify_exprs(e1, e2); + for (f, e1) in flds1.iter().enumerate() { + let e2 = e2.proj_and_reduce(kind1.to_proj(f as u32)); + self.idx_eq(rcx, e1, &e2); + } + } (ExprKind::Abs(p1), ExprKind::Abs(p2)) => { self.abs_eq(rcx, p1, p2); } diff --git a/crates/flux-refineck/src/type_env/place_ty.rs b/crates/flux-refineck/src/type_env/place_ty.rs index a07c1415a..9e75f3060 100644 --- a/crates/flux-refineck/src/type_env/place_ty.rs +++ b/crates/flux-refineck/src/type_env/place_ty.rs @@ -1,4 +1,4 @@ -use std::{clone::Clone, fmt, iter, ops::ControlFlow}; +use std::{clone::Clone, fmt, ops::ControlFlow}; use flux_common::{iter::IterExt, tracked_span_bug}; use flux_middle::{