Skip to content

Commit

Permalink
Don't crash when bit-shifting uints (#643)
Browse files Browse the repository at this point in the history
* add stuff to generic_const test

* don't die on ExposeProvidence casts

* support other binops
  • Loading branch information
ranjitjhala committed Jun 29, 2024
1 parent a44d9fc commit bcf99cd
Show file tree
Hide file tree
Showing 16 changed files with 277 additions and 33 deletions.
17 changes: 15 additions & 2 deletions crates/flux-middle/src/rty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use rustc_data_structures::unord::UnordMap;
use rustc_hir::def_id::DefId;
use rustc_index::{newtype_index, IndexSlice};
use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable};
use rustc_middle::ty::{ParamConst, TyCtxt};
use rustc_middle::ty::{ParamConst, TyCtxt, ValTree};
pub use rustc_middle::{
mir::Mutability,
ty::{AdtFlags, ClosureKind, FloatTy, IntTy, OutlivesPredicate, ParamTy, ScalarInt, UintTy},
Expand Down Expand Up @@ -1047,7 +1047,20 @@ impl GenericArg {
ty::GenericArg::from(ctor.as_ref().skip_binder().to_rustc(tcx))
}
GenericArg::Lifetime(re) => ty::GenericArg::from(re.to_rustc(tcx)),
GenericArg::Const(_) => todo!(),
GenericArg::Const(c) => {
let ty = &c.ty;
let ty = ty.to_rustc(tcx);
let kind = match c.kind {
ConstKind::Param(param_const) => {
rustc_middle::ty::ConstKind::Param(param_const)
}
ConstKind::Value(scalar_int) => {
rustc_middle::ty::ConstKind::Value(ValTree::Leaf(scalar_int))
}
};
let c = rustc_middle::ty::Const::new(tcx, kind, ty);
ty::GenericArg::from(c)
}
}
}
}
Expand Down
11 changes: 9 additions & 2 deletions crates/flux-middle/src/rty/projections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use rustc_trait_selection::traits::SelectionContext;

use super::{
fold::{FallibleTypeFolder, TypeFoldable, TypeSuperFoldable},
AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Expr, ExprKind, GenericArg,
ProjectionPredicate, RefineArgs, Region, SubsetTy, Ty, TyKind,
AliasKind, AliasReft, AliasTy, BaseTy, Binder, Clause, ClauseKind, Const, Expr, ExprKind,
GenericArg, ProjectionPredicate, RefineArgs, Region, SubsetTy, Ty, TyKind,
};
use crate::{
global_env::GlobalEnv,
Expand Down Expand Up @@ -297,6 +297,7 @@ impl TVarSubst {
(GenericArg::Base(a), GenericArg::Base(b)) => {
self.btys(a.as_bty_skipping_binder(), b.as_bty_skipping_binder());
}
(GenericArg::Const(a), GenericArg::Const(b)) => self.consts(a, b),
_ => {}
}
}
Expand Down Expand Up @@ -354,6 +355,12 @@ impl TVarSubst {
}
}

fn consts(&mut self, a: &Const, b: &Const) {
if let super::ConstKind::Param(param_const) = a.kind {
self.insert_generic_arg(param_const.index, GenericArg::Const(b.clone()));
}
}

fn insert_generic_arg(&mut self, idx: u32, arg: GenericArg) {
if self.args[idx as usize].replace(arg).is_some() {
bug!("duplicate insert");
Expand Down
10 changes: 7 additions & 3 deletions crates/flux-middle/src/rustc/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,12 @@ impl<'sess, 'tcx> LoweringCtxt<'_, 'sess, 'tcx> {
let ty = lower_ty(self.tcx, *ty)?;
Ok(Rvalue::Cast(kind, op, ty))
}
rustc_mir::Rvalue::Repeat(_, _)
| rustc_mir::Rvalue::ThreadLocalRef(_)
rustc_mir::Rvalue::Repeat(op, c) => {
let op = self.lower_operand(op)?;
let c = lower_const(self.tcx, *c)?;
Ok(Rvalue::Repeat(op, c))
}
rustc_mir::Rvalue::ThreadLocalRef(_)
| rustc_mir::Rvalue::AddressOf(_, _)
| rustc_mir::Rvalue::NullaryOp(_, _)
| rustc_mir::Rvalue::CopyForDeref(_)
Expand Down Expand Up @@ -531,14 +535,14 @@ impl<'sess, 'tcx> LoweringCtxt<'_, 'sess, 'tcx> {
rustc_mir::BinOp::Rem => Ok(BinOp::Rem),
rustc_mir::BinOp::BitAnd => Ok(BinOp::BitAnd),
rustc_mir::BinOp::BitOr => Ok(BinOp::BitOr),
rustc_mir::BinOp::BitXor => Ok(BinOp::BitXor),
rustc_mir::BinOp::Shl => Ok(BinOp::Shl),
rustc_mir::BinOp::Shr => Ok(BinOp::Shr),
rustc_mir::BinOp::AddUnchecked
| rustc_mir::BinOp::SubUnchecked
| rustc_mir::BinOp::MulUnchecked
| rustc_mir::BinOp::ShlUnchecked
| rustc_mir::BinOp::ShrUnchecked
| rustc_mir::BinOp::BitXor
| rustc_mir::BinOp::Cmp
| rustc_mir::BinOp::Offset => {
Err(UnsupportedReason::new(format!("unsupported binary op `{bin_op:?}`")))
Expand Down
5 changes: 4 additions & 1 deletion crates/flux-middle/src/rustc/mir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub use rustc_middle::{
use rustc_span::{Span, Symbol};
pub use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};

use super::ty::{GenericArg, GenericArgs, Region, Ty, TyKind};
use super::ty::{Const, GenericArg, GenericArgs, Region, Ty, TyKind};
use crate::{
global_env::GlobalEnv, intern::List, pretty::def_id_to_string, queries::QueryResult,
rustc::ty::region_to_string,
Expand Down Expand Up @@ -192,6 +192,7 @@ pub enum Rvalue {
Discriminant(Place),
Len(Place),
Cast(CastKind, Operand, Ty),
Repeat(Operand, Const),
}

pub enum BorrowKind {
Expand Down Expand Up @@ -239,6 +240,7 @@ pub enum BinOp {
Rem,
BitAnd,
BitOr,
BitXor,
Shl,
Shr,
}
Expand Down Expand Up @@ -686,6 +688,7 @@ impl fmt::Debug for Rvalue {
}
Rvalue::Len(place) => write!(f, "Len({place:?})"),
Rvalue::Cast(kind, op, ty) => write!(f, "{op:?} as {ty:?} [{kind:?}]"),
Rvalue::Repeat(op, c) => write!(f, "[{op:?}; {c:?}]"),
}
}
}
Expand Down
72 changes: 71 additions & 1 deletion crates/flux-middle/src/rustc/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use itertools::Itertools;
use rustc_hir::def_id::DefId;
use rustc_index::{IndexSlice, IndexVec};
use rustc_macros::{TyDecodable, TyEncodable};
use rustc_middle::ty::{AdtFlags, ParamConst, TyCtxt};
use rustc_middle::ty::{self as rustc_ty, AdtFlags, ParamConst, TyCtxt};
pub use rustc_middle::{
mir::Mutability,
ty::{
Expand Down Expand Up @@ -209,6 +209,22 @@ pub struct Const {
pub ty: Ty,
}

impl Const {
pub fn to_rustc<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_ty::Const<'tcx> {
let ty = &self.ty;
let ty = ty.to_rustc(tcx);
let kind = match self.kind {
ConstKind::Param(param_const) => {
let param_const = ParamConst { name: param_const.name, index: param_const.index };
rustc_ty::ConstKind::Param(param_const)
}
ConstKind::Value(scalar_int) => {
rustc_ty::ConstKind::Value(rustc_middle::ty::ValTree::Leaf(scalar_int))
}
};
rustc_ty::Const::new(tcx, kind, ty)
}
}
#[derive(Clone, PartialEq, Eq, Hash, TyEncodable, TyDecodable)]
pub enum ConstKind {
Param(ParamConst),
Expand Down Expand Up @@ -386,6 +402,15 @@ impl GenericArg {
bug!("expected `GenericArg::Const`, found {:?}", self)
}
}

fn to_rustc<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::GenericArg<'tcx> {
use rustc_middle::ty;
match self {
GenericArg::Ty(ty) => ty::GenericArg::from(ty.to_rustc(tcx)),
GenericArg::Lifetime(re) => ty::GenericArg::from(re.to_rustc(tcx)),
GenericArg::Const(c) => ty::GenericArg::from(c.to_rustc(tcx)),
}
}
}

impl CoroutineArgs {
Expand Down Expand Up @@ -490,6 +515,10 @@ impl AdtDef {
assert!(self.is_struct() || self.is_union());
self.variant(FIRST_VARIANT)
}

pub fn to_rustc<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::AdtDef<'tcx> {
tcx.adt_def(self.did())
}
}

impl AdtDefData {
Expand Down Expand Up @@ -642,6 +671,47 @@ impl Ty {
pub fn is_box(&self) -> bool {
matches!(self.kind(), TyKind::Adt(adt, ..) if adt.is_box())
}

pub fn to_rustc<'tcx>(&self, tcx: TyCtxt<'tcx>) -> rustc_middle::ty::Ty<'tcx> {
let kind = match self.kind() {
TyKind::Bool => rustc_ty::TyKind::Bool,
TyKind::Str => rustc_ty::TyKind::Str,
TyKind::Char => rustc_ty::TyKind::Char,
TyKind::Never => rustc_ty::TyKind::Never,
TyKind::Float(float_ty) => rustc_ty::TyKind::Float(*float_ty),
TyKind::Int(int_ty) => rustc_ty::TyKind::Int(*int_ty),
TyKind::Uint(uint_ty) => rustc_ty::TyKind::Uint(*uint_ty),
TyKind::Adt(adt_def, args) => {
let adt_def = adt_def.to_rustc(tcx);
let args = tcx.mk_args_from_iter(args.iter().map(|arg| arg.to_rustc(tcx)));
rustc_ty::TyKind::Adt(adt_def, args)
}
TyKind::Array(ty, len) => {
let ty = ty.to_rustc(tcx);
let len = len.to_rustc(tcx);
rustc_ty::TyKind::Array(ty, len)
}
TyKind::Param(pty) => {
let pty = rustc_ty::ParamTy::new(pty.index, pty.name);
rustc_ty::TyKind::Param(pty)
}
TyKind::Ref(re, ty, mutbl) => {
rustc_ty::TyKind::Ref(re.to_rustc(tcx), ty.to_rustc(tcx), *mutbl)
}
TyKind::Tuple(tys) => {
let ts = tys.iter().map(|ty| ty.to_rustc(tcx)).collect_vec();
rustc_ty::TyKind::Tuple(tcx.mk_type_list(&ts))
}
TyKind::Slice(ty) => rustc_ty::TyKind::Slice(ty.to_rustc(tcx)),
TyKind::RawPtr(ty, mutbl) => rustc_ty::TyKind::RawPtr(ty.to_rustc(tcx), *mutbl),
TyKind::FnPtr(_) => todo!(),
TyKind::Closure(_, _) => todo!(),
TyKind::Coroutine(_, _) => todo!(),
TyKind::CoroutineWitness(_, _) => todo!(),
TyKind::Alias(_, _) => todo!(),
};
rustc_ty::Ty::new(tcx, kind)
}
}

impl_internable!(TyS, AdtDefData);
Expand Down
56 changes: 42 additions & 14 deletions crates/flux-refineck/src/checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ fn const_params(
) -> Result<List<(ParamConst, Sort)>> {
let generics = genv.generics_of(def_id).with_span(span)?;
let mut res = vec![];
for generic_param in &generics.params {
for i in 0..generics.count() {
let generic_param = generics.param_at(i, genv).with_span(span)?;
if let GenericParamDefKind::Const { .. } = generic_param.kind
&& let Some(local_def_id) = generic_param.def_id.as_local()
&& let Some(sort) = genv.sort_of_generic_param(local_def_id).with_span(span)?
Expand All @@ -216,8 +217,7 @@ impl<'ck, 'genv, 'tcx> Checker<'ck, 'genv, 'tcx, RefineMode> {
let fn_sig = genv.fn_sig(def_id).with_span(span)?;

let mut kvars = fixpoint_encoding::KVarStore::new();
let const_params = const_params(genv, def_id, span)?;
let mut refine_tree = RefineTree::new(const_params);
let mut refine_tree = RefineTree::new(const_params(genv, def_id, span)?);
let bb_envs = bb_env_shapes.into_bb_envs(&mut kvars);

dbg::refine_mode_span!(genv.tcx(), def_id, bb_envs).in_scope(|| {
Expand Down Expand Up @@ -833,6 +833,10 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
let from = self.check_operand(rcx, env, stmt_span, op)?;
self.check_cast(*kind, &from, to)
}
Rvalue::Repeat(operand, c) => {
let ty = self.check_operand(rcx, env, stmt_span, operand)?;
Ok(Ty::array(ty, c.clone()))
}
}
}

Expand Down Expand Up @@ -907,23 +911,42 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
| mir::BinOp::Div
| mir::BinOp::BitAnd
| mir::BinOp::BitOr
| mir::BinOp::BitXor
| mir::BinOp::Shl
| mir::BinOp::Shr
| mir::BinOp::Rem => Ok(Ty::float(*float_ty1)),
}
}
(TyKind::Indexed(bty1, idx1), TyKind::Indexed(bty2, idx2)) => {
let sig = sigs::get_bin_op_sig(bin_op, bty1, bty2, self.check_overflow());
let (e1, e2) = (idx1.clone(), idx2.clone());
if let sigs::Pre::Some(reason, constr) = &sig.pre {
self.constr_gen(rcx, source_span).check_pred(
rcx,
&constr([e1.clone(), e2.clone()]),
*reason,
);
match sigs::get_bin_op_sig(bin_op, bty1, bty2, self.check_overflow()) {
Some(sig) => {
let (e1, e2) = (idx1.clone(), idx2.clone());
if let sigs::Pre::Some(reason, constr) = &sig.pre {
self.constr_gen(rcx, source_span).check_pred(
rcx,
&constr([e1.clone(), e2.clone()]),
*reason,
);
}

Ok(sig.out.to_ty([e1, e2]))
}
None => {
match bin_op {
mir::BinOp::Eq
| mir::BinOp::Ne
| mir::BinOp::Gt
| mir::BinOp::Ge
| mir::BinOp::Lt
| mir::BinOp::Le => Ok(Ty::bool()),
_ => {
tracked_span_bug!(
"No sig for binop : `{bin_op:?}` with `{ty1:?}` and `{ty2:?}`"
)
}
}
}
}

Ok(sig.out.to_ty([e1, e2]))
}
_ => tracked_span_bug!("incompatible types: `{ty1:?}` `{ty2:?}`"),
}
Expand All @@ -941,7 +964,12 @@ impl<'ck, 'genv, 'tcx, M: Mode> Checker<'ck, 'genv, 'tcx, M> {
match ty.kind() {
Float!(float_ty) => Ok(Ty::float(*float_ty)),
TyKind::Indexed(bty, idx) => {
let sig = sigs::get_un_op_sig(un_op, bty, self.check_overflow());
let sig = if let Some(sig) = sigs::get_un_op_sig(un_op, bty, self.check_overflow())
{
sig
} else {
tracked_span_bug!("No sig for unop : `{un_op:?}` with `{ty:?}`")
};
let e = idx.clone();
if let sigs::Pre::Some(reason, constr) = &sig.pre {
self.constr_gen(rcx, source_span).check_pred(
Expand Down
5 changes: 5 additions & 0 deletions crates/flux-refineck/src/constraint_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,11 @@ impl<'a, 'genv, 'tcx> InferCtxt<'a, 'genv, 'tcx> {
(ctor1.to_ty(), ctor2.to_ty())
}
(GenericArg::Lifetime(_), GenericArg::Lifetime(_)) => return Ok(()),
(GenericArg::Const(c1), GenericArg::Const(c2)) => {
debug_assert_eq!(c1, c2);
return Ok(());
}

_ => tracked_span_bug!("incompatible generic args: `{arg1:?}` `{arg2:?}"),
};
match variance {
Expand Down
3 changes: 3 additions & 0 deletions crates/flux-refineck/src/ghost_statements/fold_unfold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,9 @@ impl<'a, 'genv, 'tcx, M: Mode> FoldUnfoldAnalysis<'a, 'genv, 'tcx, M> {
M::projection(self, env, discr, ProjKind::Other)?;
self.discriminants.insert(place.clone(), discr.clone());
}
Rvalue::Repeat(op, _) => {
self.operand(op, env)?;
}
}
M::projection(self, env, place, ProjKind::Other)?;
}
Expand Down
9 changes: 7 additions & 2 deletions crates/flux-refineck/src/refine_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,12 +904,17 @@ mod pretty {
define_scoped!(cx, f);
write!(
f,
"[{}]",
"[bindings = {}, reftgenerics = {}]",
self.bindings
.iter_enumerated()
.format_with(", ", |(name, sort), f| {
f(&format_args_cx!("{:?}: {:?}", ^name, sort))
})
}),
self.reftgenerics
.iter()
.format_with(", ", |(param_const, sort), f| {
f(&format_args_cx!("{:?}: {:?}", ^param_const, sort))
}),
)
}
}
Expand Down
Loading

0 comments on commit bcf99cd

Please sign in to comment.