diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 8e040b9734..66667cb1a6 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2161,19 +2161,17 @@ impl<'a, W: Write> Writer<'a, W> { match expressions[expr] { Expression::Literal(literal) => { match literal { - crate::Literal::I32(value) => write!(self.out, "{}", value)?, + // Floats are written using `Debug` instead of `Display` because it always appends the + // decimal part even it's zero which is needed for a valid glsl float constant + crate::Literal::F64(value) => write!(self.out, "{:?}LF", value)?, + crate::Literal::F32(value) => write!(self.out, "{:?}", value)?, // Unsigned integers need a `u` at the end // // While `core` doesn't necessarily need it, it's allowed and since `es` needs it we // always write it as the extra branch wouldn't have any benefit in readability crate::Literal::U32(value) => write!(self.out, "{}u", value)?, - // Floats are written using `Debug` instead of `Display` because it always appends the - // decimal part even it's zero which is needed for a valid glsl float constant - crate::Literal::F32(value) => write!(self.out, "{:?}", value)?, + crate::Literal::I32(value) => write!(self.out, "{}", value)?, crate::Literal::Bool(value) => write!(self.out, "{}", value)?, - crate::Literal::F64(_) => { - return Err(Error::Custom("f64 literal not supported".to_string())); - } } } Expression::Constant(handle) => { @@ -2230,7 +2228,7 @@ impl<'a, W: Write> Writer<'a, W> { | Expression::Constant(_) | Expression::New(_) | Expression::Compose { .. } => { - self.write_possibly_const_expr(expr, &ctx.expressions, |writer, expr| { + self.write_possibly_const_expr(expr, ctx.expressions, |writer, expr| { writer.write_expr(expr, ctx) })?; } diff --git a/src/back/hlsl/help.rs b/src/back/hlsl/help.rs index c1c8edb208..7ad4631315 100644 --- a/src/back/hlsl/help.rs +++ b/src/back/hlsl/help.rs @@ -159,7 +159,6 @@ impl<'a, W: Write> super::Writer<'a, W> { /// pub(super) fn write_wrapped_array_length_function( &mut self, - module: &crate::Module, wal: WrappedArrayLength, ) -> BackendResult { use crate::back::INDENT; @@ -789,19 +788,16 @@ impl<'a, W: Write> super::Writer<'a, W> { expressions: &crate::Arena, ) -> BackendResult { for (handle, _) in expressions.iter() { - match expressions[handle] { - crate::Expression::Compose { ty, .. } => { - match module.types[ty].inner { - crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => { - let constructor = WrappedConstructor { ty }; - if self.wrapped.constructors.insert(constructor) { - self.write_wrapped_constructor_function(module, constructor)?; - } + if let crate::Expression::Compose { ty, .. } = expressions[handle] { + match module.types[ty].inner { + crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => { + let constructor = WrappedConstructor { ty }; + if self.wrapped.constructors.insert(constructor) { + self.write_wrapped_constructor_function(module, constructor)?; } - _ => {} - }; - } - _ => {} + } + _ => {} + }; } } Ok(()) @@ -813,7 +809,7 @@ impl<'a, W: Write> super::Writer<'a, W> { module: &crate::Module, func_ctx: &FunctionCtx, ) -> BackendResult { - self.write_wrapped_compose_functions(module, &func_ctx.expressions)?; + self.write_wrapped_compose_functions(module, func_ctx.expressions)?; for (handle, _) in func_ctx.expressions.iter() { match func_ctx.expressions[handle] { @@ -838,7 +834,7 @@ impl<'a, W: Write> super::Writer<'a, W> { }; if self.wrapped.array_lengths.insert(wal) { - self.write_wrapped_array_length_function(module, wal)?; + self.write_wrapped_array_length_function(wal)?; } } crate::Expression::ImageQuery { image, query } => { diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index f58f2e3820..b2f2a5f0da 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1965,26 +1965,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { match expressions[expr] { Expression::Literal(literal) => match literal { - crate::Literal::I32(value) => write!(self.out, "{}", value)?, - crate::Literal::U32(value) => write!(self.out, "{}u", value)?, // Floats are written using `Debug` instead of `Display` because it always appends the // decimal part even it's zero + crate::Literal::F64(value) => write!(self.out, "{value:?}L")?, crate::Literal::F32(value) => write!(self.out, "{value:?}")?, - // crate::Literal::F32(value) => { - // if value.is_infinite() { - // let sign = if value.is_sign_negative() { "-" } else { "" }; - // write!(self.out, "{}1.#INF", sign)?; - // } else if value.is_nan() { - // write!(self.out, "(0.0/0.0)")?; - // } else { - // let suffix = if value.fract() == 0.0 { ".0" } else { "" }; - // write!(self.out, "{}{}", value, suffix)?; - // } - // } + crate::Literal::U32(value) => write!(self.out, "{}u", value)?, + crate::Literal::I32(value) => write!(self.out, "{}", value)?, crate::Literal::Bool(value) => write!(self.out, "{}", value)?, - crate::Literal::F64(_) => { - return Err(Error::Unimplemented("f64 literal".to_string())); - } }, Expression::Constant(handle) => { let constant = &module.constants[handle]; @@ -2077,7 +2064,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_possibly_const_expression( module, expr, - &func_ctx.expressions, + func_ctx.expressions, |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 16f4b58a99..a6563c4414 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1206,7 +1206,9 @@ impl Writer { { match expressions[expr_handle] { crate::Expression::Literal(literal) => match literal { - crate::Literal::F64(value) => todo!(), + crate::Literal::F64(_) => { + return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64)) + } crate::Literal::F32(value) => { if value.is_infinite() { let sign = if value.is_sign_negative() { "-" } else { "" }; @@ -3222,7 +3224,7 @@ impl Writer { /// Writes all named constants fn write_global_constants(&mut self, module: &crate::Module) -> BackendResult { - let mut constants = module.constants.iter().filter(|&(_, c)| c.name.is_some()); + let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some()); for (handle, constant) in constants { let ty_name = TypeContext { diff --git a/src/back/spv/index.rs b/src/back/spv/index.rs index 95eb6d7d36..913e4d895f 100644 --- a/src/back/spv/index.rs +++ b/src/back/spv/index.rs @@ -179,7 +179,7 @@ impl<'w> BlockContext<'w> { if let Ok(known_index) = self .ir_module .to_ctx() - .to_array_length(index, Some(&self.ir_function.expressions)) + .eval_expr_to_u32(index, Some(&self.ir_function.expressions)) { // Both the index and length are known at compile time. // @@ -241,7 +241,7 @@ impl<'w> BlockContext<'w> { if let Ok(known_index) = self .ir_module .to_ctx() - .to_array_length(index, Some(&self.ir_function.expressions)) + .eval_expr_to_u32(index, Some(&self.ir_function.expressions)) { // Both the index and length are known at compile time. // diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index 3cd04489d6..35524726d4 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -290,11 +290,11 @@ enum LocalType { Sampler, PointerToBindingArray { base: Handle, - size: u64, + size: u32, }, BindingArray { base: Handle, - size: u64, + size: u32, }, AccelerationStructure, RayQuery, diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 2d18dea015..8b9b4cf083 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -936,9 +936,7 @@ impl Writer { } LocalType::BindingArray { base, size } => { let inner_ty = self.get_type_id(LookupType::Handle(base)); - // TODO: add Literal::U64 - let scalar_id = - self.get_constant_scalar(crate::Literal::U32(size.try_into().unwrap())); + let scalar_id = self.get_constant_scalar(crate::Literal::U32(size)); Instruction::type_array(id, inner_ty, scalar_id) } LocalType::PointerToBindingArray { base, size } => { @@ -1635,7 +1633,7 @@ impl Writer { substitute_inner_type_lookup = Some(LookupType::Local(LocalType::PointerToBindingArray { base, - size: remapped_binding_array_size as u64, + size: remapped_binding_array_size, })) } } else { diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 5d986bbc29..e8fb8caf7e 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -1089,15 +1089,15 @@ impl Writer { match expressions[expr] { Expression::Literal(literal) => { match literal { - crate::Literal::I32(value) => write!(self.out, "{}", value)?, - crate::Literal::U32(value) => write!(self.out, "{}u", value)?, // Floats are written using `Debug` instead of `Display` because it always appends the // decimal part even it's zero - crate::Literal::F32(value) => write!(self.out, "{:?}", value)?, - crate::Literal::Bool(value) => write!(self.out, "{}", value)?, crate::Literal::F64(_) => { - return Err(Error::Unimplemented("f64 literal".to_string())); + return Err(Error::Custom("unsupported f64 literal".to_string())); } + crate::Literal::F32(value) => write!(self.out, "{:?}", value)?, + crate::Literal::U32(value) => write!(self.out, "{}u", value)?, + crate::Literal::I32(value) => write!(self.out, "{}", value)?, + crate::Literal::Bool(value) => write!(self.out, "{}", value)?, } } Expression::Constant(handle) => { @@ -1168,7 +1168,7 @@ impl Writer { self.write_possibly_const_expression( module, expr, - &func_ctx.expressions, + func_ctx.expressions, |writer, expr| writer.write_expr(module, expr, func_ctx), )?; } diff --git a/src/front/glsl/context.rs b/src/front/glsl/context.rs index fa97e73736..7c531f61c1 100644 --- a/src/front/glsl/context.rs +++ b/src/front/glsl/context.rs @@ -527,7 +527,7 @@ impl Context { index: frontend .module .to_ctx() - .to_array_length(const_expr, None) + .eval_expr_to_u32(const_expr, None) .ok()?, }, meta, diff --git a/src/front/glsl/offset.rs b/src/front/glsl/offset.rs index 2bc23b015f..21481372b5 100644 --- a/src/front/glsl/offset.rs +++ b/src/front/glsl/offset.rs @@ -16,7 +16,7 @@ use super::{ error::{Error, ErrorKind}, Span, }; -use crate::{proc::Alignment, Arena, Constant, Handle, Type, TypeInner, UniqueArena}; +use crate::{proc::Alignment, Handle, Type, TypeInner, UniqueArena}; /// Struct with information needed for defining a struct member. /// @@ -43,7 +43,6 @@ pub fn calculate_offset( meta: Span, layout: StructLayout, types: &mut UniqueArena, - constants: &Arena, errors: &mut Vec, ) -> TypeAlignSpan { // When using the std430 storage layout, shader storage blocks will be laid out in buffer storage @@ -68,7 +67,7 @@ pub fn calculate_offset( // to rules (1), (2), and (3), and rounded up to the base alignment of a vec4. // TODO: Matrices array TypeInner::Array { base, size, .. } => { - let info = calculate_offset(base, meta, layout, types, constants, errors); + let info = calculate_offset(base, meta, layout, types, errors); let name = types[ty].name.clone(); @@ -133,7 +132,7 @@ pub fn calculate_offset( let name = types[ty].name.clone(); for member in members.iter_mut() { - let info = calculate_offset(member.ty, meta, layout, types, constants, errors); + let info = calculate_offset(member.ty, meta, layout, types, errors); let member_alignment = info.align; span = member_alignment.round_up(span); diff --git a/src/front/glsl/parser.rs b/src/front/glsl/parser.rs index 538242a3c1..5e419cf7b7 100644 --- a/src/front/glsl/parser.rs +++ b/src/front/glsl/parser.rs @@ -9,7 +9,7 @@ use super::{ variables::{GlobalOrConstant, VarDeclaration}, Frontend, Result, }; -use crate::{arena::Handle, proc::ArrayLengthError, Block, Expression, Span, Type}; +use crate::{arena::Handle, proc::U32EvalError, Block, Expression, Span, Type}; use pp_rs::token::{PreprocessorError, Token as PPToken, TokenValue as PPTokenValue}; use std::iter::Peekable; @@ -193,15 +193,15 @@ impl<'source> ParsingContext<'source> { fn parse_uint_constant(&mut self, frontend: &mut Frontend) -> Result<(u32, Span)> { let (const_expr, meta) = self.parse_constant_expression(frontend)?; - let res = frontend.module.to_ctx().to_array_length(const_expr, None); + let res = frontend.module.to_ctx().eval_expr_to_u32(const_expr, None); let int = match res { Ok(value) => Ok(value), - Err(ArrayLengthError::NotPositive) => Err(Error { + Err(U32EvalError::Negative) => Err(Error { kind: ErrorKind::SemanticError("int constant overflows".into()), meta, }), - Err(ArrayLengthError::Invalid) => Err(Error { + Err(U32EvalError::Invalid) => Err(Error { kind: ErrorKind::SemanticError("Expected a uint constant".into()), meta, }), diff --git a/src/front/glsl/parser/declarations.rs b/src/front/glsl/parser/declarations.rs index 5cf3d58c0b..1062630180 100644 --- a/src/front/glsl/parser/declarations.rs +++ b/src/front/glsl/parser/declarations.rs @@ -642,7 +642,6 @@ impl<'source> ParsingContext<'source> { meta, layout, &mut frontend.module.types, - &frontend.module.constants, &mut frontend.errors, ); diff --git a/src/front/glsl/parser/expressions.rs b/src/front/glsl/parser/expressions.rs index ad599d5b01..7e47b2eea7 100644 --- a/src/front/glsl/parser/expressions.rs +++ b/src/front/glsl/parser/expressions.rs @@ -140,8 +140,7 @@ impl<'source> ParsingContext<'source> { let size = u32::try_from(args.len()) .ok() - .map(NonZeroU32::new) - .flatten() + .and_then(NonZeroU32::new) .ok_or(Error { kind: ErrorKind::SemanticError( "There must be at least one argument".into(), diff --git a/src/front/glsl/types.rs b/src/front/glsl/types.rs index 0d62eb29d1..513bc22754 100644 --- a/src/front/glsl/types.rs +++ b/src/front/glsl/types.rs @@ -1,7 +1,5 @@ use super::{ - constants::{ConstantSolver, ExprType}, - context::Context, - Error, ErrorKind, Frontend, Result, Span, + constants::ConstantSolver, context::Context, Error, ErrorKind, Frontend, Result, Span, }; use crate::{ proc::ResolveContext, Bytes, Expression, Handle, ImageClass, ImageDimension, ScalarKind, Type, diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 52d7b72835..0f1a31902b 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -1168,7 +1168,6 @@ impl> Frontend { selections: &[spirv::Word], type_arena: &UniqueArena, expressions: &mut Arena, - constants: &Arena, span: crate::Span, ) -> Result, Error> { let selection = match selections.first() { @@ -1229,7 +1228,6 @@ impl> Frontend { &selections[1..], type_arena, expressions, - constants, span, )?; @@ -1488,7 +1486,7 @@ impl> Frontend { let index_maybe = match *index_expr_data { crate::Expression::Constant(const_handle) => Some( ctx.global_ctx() - .to_array_length( + .eval_expr_to_u32( ctx.const_arena[const_handle].init.unwrap(), None, ) @@ -1755,7 +1753,6 @@ impl> Frontend { let root_type_lookup = self.lookup_type.lookup(root_lexp.type_id)?; let index_lexp = self.lookup_expression.lookup(index_id)?; let index_handle = get_expr_handle!(index_id, index_lexp); - let index_type = self.lookup_type.lookup(index_lexp.type_id)?.handle; let num_components = match ctx.type_arena[root_type_lookup.handle].inner { crate::TypeInner::Vector { size, .. } => size as u32, @@ -1887,7 +1884,6 @@ impl> Frontend { &selections, ctx.type_arena, ctx.expressions, - ctx.const_arena, span, )?; @@ -4358,8 +4354,7 @@ impl> Frontend { let length_const = self.lookup_constant.lookup(length_id)?; let size = resolve_constant(module.to_ctx(), length_const.handle) - .map(|size| NonZeroU32::new(size)) - .flatten() + .and_then(NonZeroU32::new) .ok_or(Error::InvalidArraySize(length_const.handle))?; let decor = self.future_decor.remove(&id).unwrap_or_default(); @@ -5077,7 +5072,6 @@ impl> Frontend { match null::generate_default_built_in( Some(built_in), ty, - &module.types, &mut module.const_expressions, span, ) { @@ -5108,7 +5102,6 @@ impl> Frontend { let handle = null::generate_default_built_in( built_in, member_ty, - &module.types, &mut module.const_expressions, span, )?; diff --git a/src/front/spv/null.rs b/src/front/spv/null.rs index 9d9224419f..477d2d9958 100644 --- a/src/front/spv/null.rs +++ b/src/front/spv/null.rs @@ -1,11 +1,10 @@ use super::Error; -use crate::arena::{Arena, Handle, UniqueArena}; +use crate::arena::{Arena, Handle}; /// Create a default value for an output built-in. pub fn generate_default_built_in( built_in: Option, ty: Handle, - type_arena: &UniqueArena, const_expressions: &mut Arena, span: crate::Span, ) -> Result, Error> { diff --git a/src/front/wgsl/error.rs b/src/front/wgsl/error.rs index 724c3f9cd3..38353148fa 100644 --- a/src/front/wgsl/error.rs +++ b/src/front/wgsl/error.rs @@ -100,8 +100,6 @@ pub enum ExpectedToken<'a> { Identifier, Number, Integer, - /// A compile-time constant expression. - Constant, /// Expected: constant, parenthesized expression, identifier PrimaryExpression, /// Expected: assignment, increment/decrement expression @@ -141,6 +139,7 @@ pub enum InvalidAssignmentType { pub enum Error<'a> { Unexpected(Span, ExpectedToken<'a>), UnexpectedComponents(Span), + UnexpectedOperationInConstContext(Span), BadNumber(Span, NumberError), /// A negative signed integer literal where both signed and unsigned, /// but only non-negative literals are allowed. @@ -227,7 +226,6 @@ pub enum Error<'a> { /// the same identifier as `ident`, above. path: Vec<(Span, Span)>, }, - ConstExprUnsupported(Span), InvalidSwitchValue { uint: bool, span: Span, @@ -272,7 +270,6 @@ impl<'a> Error<'a> { ExpectedToken::Identifier => "identifier".to_string(), ExpectedToken::Number => "32-bit signed integer literal".to_string(), ExpectedToken::Integer => "unsigned/signed integer literal".to_string(), - ExpectedToken::Constant => "compile-time constant".to_string(), ExpectedToken::PrimaryExpression => "expression".to_string(), ExpectedToken::Assignment => "assignment or increment/decrement".to_string(), ExpectedToken::SwitchItem => "switch item ('case' or 'default') or a closing curly bracket to signify the end of the switch statement ('}')".to_string(), @@ -296,6 +293,11 @@ impl<'a> Error<'a> { labels: vec![(bad_span, "unexpected components".into())], notes: vec![], }, + Error::UnexpectedOperationInConstContext(span) => ParseError { + message: "this operation is not supported in a const context".to_string(), + labels: vec![(span, "operation not supported here".into())], + notes: vec![], + }, Error::BadNumber(bad_span, ref err) => ParseError { message: format!("{}: `{}`", err, &source[bad_span],), labels: vec![(bad_span, err.to_string().into())], @@ -624,14 +626,6 @@ impl<'a> Error<'a> { .collect(), notes: vec![], }, - Error::ConstExprUnsupported(span) => ParseError { - message: "this constant expression is not supported".to_string(), - labels: vec![(span, "expression is not supported".into())], - notes: vec![ - "this should be fixed in a future version of Naga".into(), - "https://github.com/gfx-rs/naga/issues/1829".into(), - ], - }, Error::InvalidSwitchValue { uint, span } => ParseError { message: "invalid switch value".to_string(), labels: vec![( diff --git a/src/front/wgsl/lower/construction.rs b/src/front/wgsl/lower/construction.rs index acbd28677c..9c26b661d8 100644 --- a/src/front/wgsl/lower/construction.rs +++ b/src/front/wgsl/lower/construction.rs @@ -544,7 +544,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { /// array's length. /// /// [`Type`]: crate::Type - /// [`ctx.module`]: GlobalContext::module + /// [`ctx.module`]: ExpressionContext::module /// [`Array`]: crate::TypeInner::Array /// [`Constant`]: crate::Constant fn constructor<'out>( diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index ffd0a7b17e..50771b22c3 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -284,10 +284,13 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { } } - fn runtime_expression_ctx(&mut self) -> RuntimeExpressionContext<'_, '_> { + fn runtime_expression_ctx( + &mut self, + span: Span, + ) -> Result, Error<'source>> { match self.expr_type { - ExpressionContextType::Runtime(ref mut ctx) => ctx.reborrow(), - ExpressionContextType::Constant => panic!(), + ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx.reborrow()), + ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)), } } @@ -299,10 +302,10 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { let len = self .module .to_ctx() - .to_array_length(const_expr, None) + .eval_expr_to_u32(const_expr, None) .map_err(|err| match err { - crate::proc::ArrayLengthError::Invalid => Error::ExpectedArraySize(span), - crate::proc::ArrayLengthError::NotPositive => Error::NonPositiveArrayLength(span), + crate::proc::U32EvalError::Invalid => Error::ExpectedArraySize(span), + crate::proc::U32EvalError::Negative => Error::NonPositiveArrayLength(span), })?; NonZeroU32::new(len).ok_or(Error::NonPositiveArrayLength(span)) } @@ -317,14 +320,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { let index = self .module .to_ctx() - .to_array_length(expr, Some(ctx.naga_expressions)) + .eval_expr_to_u32(expr, Some(ctx.naga_expressions)) .map_err(|_| Error::InvalidGatherComponent(span))?; crate::SwizzleComponent::XYZW .get(index as usize) .copied() .ok_or(Error::InvalidGatherComponent(span)) } - ExpressionContextType::Constant => panic!(), + ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)), } } @@ -1232,7 +1235,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let right = ectx.interrupt_emitter(crate::Expression::Literal(literal), Span::UNDEFINED); - let rctx = ectx.runtime_expression_ctx(); + let rctx = ectx.runtime_expression_ctx(stmt.span)?; let left = rctx.naga_expressions.append( crate::Expression::Load { pointer: reference.handle, @@ -1296,7 +1299,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Ok(TypedExpression::non_reference(handle)); } ast::Expression::Ident(ast::IdentExpr::Local(local)) => { - let rctx = ctx.runtime_expression_ctx(); + let rctx = ctx.runtime_expression_ctx(span)?; return Ok(rctx.local_table[&local]); } ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => { @@ -1591,7 +1594,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .collect::, _>>()?; let has_result = ctx.module.functions[function].result.is_some(); - let rctx = ctx.runtime_expression_ctx(); + let rctx = ctx.runtime_expression_ctx(span)?; // we need to always do this before a fn call since all arguments need to be emitted before the fn call rctx.block .extend(rctx.emitter.finish(rctx.naga_expressions)); @@ -1694,7 +1697,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let value = self.expression(args.next()?, ctx.reborrow())?; args.finish()?; - let rctx = ctx.runtime_expression_ctx(); + let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .extend(rctx.emitter.finish(rctx.naga_expressions)); rctx.emitter.start(rctx.naga_expressions); @@ -1794,7 +1797,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }; let result = ctx.interrupt_emitter(expression, span); - let rctx = ctx.runtime_expression_ctx(); + let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::Atomic { pointer, @@ -1811,7 +1814,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { "storageBarrier" => { ctx.prepare_args(arguments, 0, span).finish()?; - let rctx = ctx.runtime_expression_ctx(); + let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(crate::Statement::Barrier(crate::Barrier::STORAGE), span); return Ok(None); @@ -1819,7 +1822,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { "workgroupBarrier" => { ctx.prepare_args(arguments, 0, span).finish()?; - let rctx = ctx.runtime_expression_ctx(); + let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span); return Ok(None); @@ -1842,7 +1845,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; - let rctx = ctx.runtime_expression_ctx(); + let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .extend(rctx.emitter.finish(rctx.naga_expressions)); rctx.emitter.start(rctx.naga_expressions); @@ -1948,7 +1951,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { descriptor, }; - let rctx = ctx.runtime_expression_ctx(); + let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .extend(rctx.emitter.finish(rctx.naga_expressions)); rctx.emitter.start(rctx.naga_expressions); @@ -1964,7 +1967,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let result = ctx .interrupt_emitter(crate::Expression::RayQueryProceedResult, span); let fun = crate::RayQueryFunction::Proceed { result }; - let rctx = ctx.runtime_expression_ctx(); + let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(crate::Statement::RayQuery { query, fun }, span); return Ok(Some(result)); @@ -2050,7 +2053,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }, span, ); - let rctx = ctx.runtime_expression_ctx(); + let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::Atomic { pointer, diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index e6bdbdc1d4..9964ee5eac 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -98,7 +98,7 @@ impl crate::TypeInner { /// For example `vec3`. /// /// Note: The names of a `TypeInner::Struct` is not known. Therefore this method will simply return "struct" for them. - fn to_wgsl<'a, 'g>(&'a self, global_ctx: crate::GlobalCtx<'g>) -> String { + fn to_wgsl(&self, global_ctx: crate::GlobalCtx) -> String { use crate::TypeInner as Ti; // fn get_array_size( diff --git a/src/front/wgsl/parse/ast.rs b/src/front/wgsl/parse/ast.rs index 2a56ac6f80..9b7670e4a2 100644 --- a/src/front/wgsl/parse/ast.rs +++ b/src/front/wgsl/parse/ast.rs @@ -107,7 +107,7 @@ pub struct EntryPoint { } #[cfg(doc)] -use crate::front::wgsl::lower::{ExpressionContext, StatementContext}; +use crate::front::wgsl::lower::{RuntimeExpressionContext, StatementContext}; #[derive(Debug)] pub struct Function<'a> { @@ -132,14 +132,14 @@ pub struct Function<'a> { /// During lowering, [`LocalDecl`] statements add entries to a per-function /// table that maps `Handle` values to their Naga representations, /// accessed via [`StatementContext::local_table`] and - /// [`ExpressionContext::local_table`]. This table is then consulted when + /// [`RuntimeExpressionContext::local_table`]. This table is then consulted when /// lowering subsequent [`Ident`] expressions. /// /// [`LocalDecl`]: StatementKind::LocalDecl /// [`arguments`]: Function::arguments /// [`Ident`]: Expression::Ident /// [`StatementContext::local_table`]: StatementContext::local_table - /// [`ExpressionContext::local_table`]: ExpressionContext::local_table + /// [`RuntimeExpressionContext::local_table`]: RuntimeExpressionContext::local_table pub locals: Arena, pub body: Block<'a>, diff --git a/src/lib.rs b/src/lib.rs index 12e0f481f8..48e1a915fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -223,14 +223,11 @@ pub mod proc; mod span; pub mod valid; -use std::num::NonZeroU32; - pub use crate::arena::{Arena, Handle, Range, UniqueArena}; pub use crate::span::{SourceLocation, Span, SpanContext, WithSpan}; #[cfg(feature = "arbitrary")] use arbitrary::Arbitrary; -use proc::TypeResolution; #[cfg(feature = "deserialize")] use serde::Deserialize; #[cfg(feature = "serialize")] @@ -412,7 +409,7 @@ pub enum ScalarKind { #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ArraySize { - Constant(NonZeroU32), + Constant(std::num::NonZeroU32), // OverrideConstant(Handle), // /// The array size is constant. // /// @@ -787,52 +784,6 @@ pub enum TypeInner { BindingArray { base: Handle, size: ArraySize }, } -impl TypeInner { - pub fn scalar_bool() -> Self { - Self::Scalar { - kind: ScalarKind::Bool, - width: BOOL_WIDTH, - } - } - - pub fn components(&self) -> Option { - Some(match *self { - Self::Vector { size, .. } => size as u32, - Self::Matrix { columns, .. } => columns as u32, - Self::Array { - size: crate::ArraySize::Constant(len), - .. - } => len.get(), - Self::Struct { ref members, .. } => members.len() as u32, - _ => return None, - }) - } - - pub fn component_type(&self, index: usize) -> Option { - Some(match *self { - Self::Vector { kind, width, .. } => { - TypeResolution::Value(TypeInner::Scalar { kind, width }) - } - Self::Matrix { - columns, - rows, - width, - } => TypeResolution::Value(TypeInner::Vector { - size: rows, - kind: ScalarKind::Float, - width, - }), - Self::Array { - base, - size: crate::ArraySize::Constant(_), - .. - } => TypeResolution::Handle(base), - Self::Struct { ref members, .. } => TypeResolution::Handle(members[index].ty), - _ => return None, - }) - } -} - #[derive(Debug, Clone, Copy, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] @@ -845,90 +796,6 @@ pub enum Literal { Bool(bool), } -impl PartialEq for Literal { - fn eq(&self, other: &Self) -> bool { - match (*self, *other) { - (Self::F64(a), Self::F64(b)) => a.to_bits() == b.to_bits(), - (Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(), - (Self::U32(a), Self::U32(b)) => a == b, - (Self::I32(a), Self::I32(b)) => a == b, - (Self::Bool(a), Self::Bool(b)) => a == b, - _ => false, - } - } -} -impl Eq for Literal {} -impl std::hash::Hash for Literal { - fn hash(&self, hasher: &mut H) { - match *self { - Self::F64(v) => { - hasher.write_u8(0); - v.to_bits().hash(hasher); - } - Self::F32(v) => { - hasher.write_u8(1); - v.to_bits().hash(hasher); - } - Self::U32(v) => { - hasher.write_u8(2); - v.hash(hasher); - } - Self::I32(v) => { - hasher.write_u8(3); - v.hash(hasher); - } - Self::Bool(v) => { - hasher.write_u8(4); - v.hash(hasher); - } - } - } -} - -impl Literal { - pub const fn new(value: u8, kind: ScalarKind, width: Bytes) -> Option { - match (value, kind, width) { - (value, ScalarKind::Float, 8) => Some(Self::F64(value as _)), - (value, ScalarKind::Float, 4) => Some(Self::F32(value as _)), - (value, ScalarKind::Uint, 4) => Some(Self::U32(value as _)), - (value, ScalarKind::Sint, 4) => Some(Self::I32(value as _)), - (1, ScalarKind::Bool, 4) => Some(Self::Bool(true)), - (0, ScalarKind::Bool, 4) => Some(Self::Bool(false)), - _ => None, - } - } - - pub const fn zero(kind: ScalarKind, width: Bytes) -> Option { - Self::new(0, kind, width) - } - - pub const fn one(kind: ScalarKind, width: Bytes) -> Option { - Self::new(1, kind, width) - } - - pub const fn width(&self) -> Bytes { - match *self { - Self::F64(_) => 8, - Self::F32(_) | Self::U32(_) | Self::I32(_) => 4, - Self::Bool(_) => 1, - } - } - pub const fn scalar_kind(&self) -> ScalarKind { - match *self { - Self::F64(_) | Self::F32(_) => ScalarKind::Float, - Self::U32(_) => ScalarKind::Uint, - Self::I32(_) => ScalarKind::Sint, - Self::Bool(_) => ScalarKind::Bool, - } - } - pub const fn ty_inner(&self) -> TypeInner { - TypeInner::Scalar { - kind: self.scalar_kind(), - width: self.width(), - } - } -} - // TODO: Rename to Override? #[derive(Debug, PartialEq)] #[cfg_attr(feature = "clone", derive(Clone))] @@ -2044,29 +1911,3 @@ pub struct GlobalCtx<'a> { pub constants: &'a Arena, pub const_expressions: &'a Arena, } - -impl GlobalCtx<'_> { - const fn reborrow(&self) -> GlobalCtx<'_> { - GlobalCtx { - types: self.types, - constants: self.constants, - const_expressions: self.const_expressions, - } - } -} - -impl Module { - fn to_ctx(&self) -> GlobalCtx<'_> { - self.into() - } -} - -impl<'a> From<&'a Module> for GlobalCtx<'a> { - fn from(module: &'a Module) -> Self { - Self { - types: &module.types, - constants: &module.constants, - const_expressions: &module.const_expressions, - } - } -} diff --git a/src/proc/index.rs b/src/proc/index.rs index 37764650a6..6bc19f8887 100644 --- a/src/proc/index.rs +++ b/src/proc/index.rs @@ -327,26 +327,14 @@ pub fn access_needs_check( } impl GuardedIndex { - /// Make A `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible. - /// - /// If the expression is a [`Constant`] whose value is a non-specialized, scalar - /// integer constant that can be converted to a `u32`, do so and return a - /// `GuardedIndex::Known`. Otherwise, return the `GuardedIndex::Expression` - /// unchanged. - /// - /// Return values that are already `Known` unchanged. - /// - /// [`Constant`]: crate::Expression::Constant + /// Make a `GuardedIndex::Known` from a `GuardedIndex::Expression` if possible. fn try_resolve_to_constant(&mut self, function: &crate::Function, module: &crate::Module) { if let GuardedIndex::Expression(expr) = *self { - // TODO: check if this is right - if let crate::Expression::Constant(handle) = function.expressions[expr] { - if let Ok(value) = module - .to_ctx() - .to_array_length(module.constants[handle].init.unwrap(), None) - { - *self = GuardedIndex::Known(value); - } + if let Ok(value) = module + .to_ctx() + .eval_expr_to_u32(expr, Some(&function.expressions)) + { + *self = GuardedIndex::Known(value); } } } @@ -418,9 +406,9 @@ pub enum IndexableLength { } impl crate::ArraySize { - pub fn to_indexable_length( + pub const fn to_indexable_length( self, - module: &crate::Module, + _module: &crate::Module, ) -> Result { Ok(match self { Self::Constant(length) => IndexableLength::Known(length.get()), diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 54faf6fa7a..4d39987768 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -14,8 +14,6 @@ pub use namer::{EntryPointIndex, NameKey, Namer}; pub use terminator::ensure_block_returns; pub use typifier::{ResolveContext, ResolveError, TypeResolution}; -use crate::{ScalarKind, TypeInner}; - impl From for super::ScalarKind { fn from(format: super::StorageFormat) -> Self { use super::{ScalarKind as Sk, StorageFormat as Sf}; @@ -71,6 +69,90 @@ impl super::ScalarKind { } } +impl PartialEq for crate::Literal { + fn eq(&self, other: &Self) -> bool { + match (*self, *other) { + (Self::F64(a), Self::F64(b)) => a.to_bits() == b.to_bits(), + (Self::F32(a), Self::F32(b)) => a.to_bits() == b.to_bits(), + (Self::U32(a), Self::U32(b)) => a == b, + (Self::I32(a), Self::I32(b)) => a == b, + (Self::Bool(a), Self::Bool(b)) => a == b, + _ => false, + } + } +} +impl Eq for crate::Literal {} +impl std::hash::Hash for crate::Literal { + fn hash(&self, hasher: &mut H) { + match *self { + Self::F64(v) => { + hasher.write_u8(0); + v.to_bits().hash(hasher); + } + Self::F32(v) => { + hasher.write_u8(1); + v.to_bits().hash(hasher); + } + Self::U32(v) => { + hasher.write_u8(2); + v.hash(hasher); + } + Self::I32(v) => { + hasher.write_u8(3); + v.hash(hasher); + } + Self::Bool(v) => { + hasher.write_u8(4); + v.hash(hasher); + } + } + } +} + +impl crate::Literal { + pub const fn new(value: u8, kind: crate::ScalarKind, width: crate::Bytes) -> Option { + match (value, kind, width) { + (value, crate::ScalarKind::Float, 8) => Some(Self::F64(value as _)), + (value, crate::ScalarKind::Float, 4) => Some(Self::F32(value as _)), + (value, crate::ScalarKind::Uint, 4) => Some(Self::U32(value as _)), + (value, crate::ScalarKind::Sint, 4) => Some(Self::I32(value as _)), + (1, crate::ScalarKind::Bool, 4) => Some(Self::Bool(true)), + (0, crate::ScalarKind::Bool, 4) => Some(Self::Bool(false)), + _ => None, + } + } + + pub const fn zero(kind: crate::ScalarKind, width: crate::Bytes) -> Option { + Self::new(0, kind, width) + } + + pub const fn one(kind: crate::ScalarKind, width: crate::Bytes) -> Option { + Self::new(1, kind, width) + } + + pub const fn width(&self) -> crate::Bytes { + match *self { + Self::F64(_) => 8, + Self::F32(_) | Self::U32(_) | Self::I32(_) => 4, + Self::Bool(_) => 1, + } + } + pub const fn scalar_kind(&self) -> crate::ScalarKind { + match *self { + Self::F64(_) | Self::F32(_) => crate::ScalarKind::Float, + Self::U32(_) => crate::ScalarKind::Uint, + Self::I32(_) => crate::ScalarKind::Sint, + Self::Bool(_) => crate::ScalarKind::Bool, + } + } + pub const fn ty_inner(&self) -> crate::TypeInner { + crate::TypeInner::Scalar { + kind: self.scalar_kind(), + width: self.width(), + } + } +} + pub const POINTER_SPAN: u32 = 4; impl super::TypeInner { @@ -103,7 +185,7 @@ impl super::TypeInner { } /// Get the size of this type. - pub fn size(&self, gctx: crate::GlobalCtx) -> u32 { + pub fn size(&self, _gctx: crate::GlobalCtx) -> u32 { match *self { Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32, Self::Vector { @@ -204,6 +286,46 @@ impl super::TypeInner { _ => false, } } + + pub const fn scalar_bool() -> Self { + Self::Scalar { + kind: crate::ScalarKind::Bool, + width: crate::BOOL_WIDTH, + } + } + + pub fn components(&self) -> Option { + Some(match *self { + Self::Vector { size, .. } => size as u32, + Self::Matrix { columns, .. } => columns as u32, + Self::Array { + size: crate::ArraySize::Constant(len), + .. + } => len.get(), + Self::Struct { ref members, .. } => members.len() as u32, + _ => return None, + }) + } + + pub fn component_type(&self, index: usize) -> Option { + Some(match *self { + Self::Vector { kind, width, .. } => { + TypeResolution::Value(crate::TypeInner::Scalar { kind, width }) + } + Self::Matrix { rows, width, .. } => TypeResolution::Value(crate::TypeInner::Vector { + size: rows, + kind: crate::ScalarKind::Float, + width, + }), + Self::Array { + base, + size: crate::ArraySize::Constant(_), + .. + } => TypeResolution::Handle(base), + Self::Struct { ref members, .. } => TypeResolution::Handle(members[index].ty), + _ => return None, + }) + } } impl super::AddressSpace { @@ -380,70 +502,6 @@ impl crate::SampleLevel { } } -// struct Evaluator {} - -// impl Evaluator { -// fn eval() -> Result<(), ()> {} -// } - -#[derive(Debug)] -pub(super) enum ArrayLengthError { - Invalid, - NotPositive, -} - -impl crate::GlobalCtx<'_> { - /// Interpret this constant as an array length, and return it as a `u32`. - /// - /// Ignore any specialization available for this constant; return its - /// unspecialized value. - /// - /// If the constant has an inappropriate kind (non-scalar or non-integer) or - /// value (negative, out of range for u32), return `None`. This usually - /// indicates an error, but only the caller has enough information to report - /// the error helpfully: in back ends, it's a validation error, but in front - /// ends, it may indicate ill-formed input (for example, a SPIR-V - /// `OpArrayType` referring to an inappropriate `OpConstant`). So we return - /// `Option` and let the caller sort things out. - pub(crate) fn to_array_length( - &self, - handle: crate::Handle, - arena: Option<&crate::Arena>, - ) -> Result { - fn get( - gctx: crate::GlobalCtx, - handle: crate::Handle, - arena: &crate::Arena, - ) -> Result { - match arena[handle] { - crate::Expression::Literal(crate::Literal::U32(value)) => Ok(value), - crate::Expression::Literal(crate::Literal::I32(value)) => { - value.try_into().map_err(|_| ArrayLengthError::NotPositive) - } - crate::Expression::New(ty) - if matches!( - gctx.types[ty].inner, - TypeInner::Scalar { - kind: ScalarKind::Sint | ScalarKind::Uint, - width: _ - } - ) => - { - Ok(0) - } - _ => Err(ArrayLengthError::Invalid), - } - } - let arena = arena.unwrap_or(&self.const_expressions); - match arena[handle] { - crate::Expression::Constant(c) => { - get(self.reborrow(), self.constants[c].init.unwrap(), arena) - } - _ => get(self.reborrow(), handle, arena), - } - } -} - impl crate::Binding { pub const fn to_built_in(&self) -> Option { match *self { @@ -490,6 +548,74 @@ impl super::ImageClass { } } +impl crate::Module { + pub const fn to_ctx(&self) -> crate::GlobalCtx<'_> { + crate::GlobalCtx { + types: &self.types, + constants: &self.constants, + const_expressions: &self.const_expressions, + } + } +} + +#[derive(Debug)] +pub(super) enum U32EvalError { + Invalid, + Negative, +} + +impl crate::GlobalCtx<'_> { + pub const fn reborrow(&self) -> crate::GlobalCtx<'_> { + crate::GlobalCtx { + types: self.types, + constants: self.constants, + const_expressions: self.const_expressions, + } + } + + /// Try to evaluate this expression and return it as a `u32`. + /// + /// I the expression doesn't reside in `const_expressions`, + /// provide the `arena` where it lives. + pub(crate) fn eval_expr_to_u32( + &self, + handle: crate::Handle, + arena: Option<&crate::Arena>, + ) -> Result { + fn get( + gctx: crate::GlobalCtx, + handle: crate::Handle, + arena: &crate::Arena, + ) -> Result { + match arena[handle] { + crate::Expression::Literal(crate::Literal::U32(value)) => Ok(value), + crate::Expression::Literal(crate::Literal::I32(value)) => { + value.try_into().map_err(|_| U32EvalError::Negative) + } + crate::Expression::New(ty) + if matches!( + gctx.types[ty].inner, + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + width: _ + } + ) => + { + Ok(0) + } + _ => Err(U32EvalError::Invalid), + } + } + let arena = arena.unwrap_or(self.const_expressions); + match arena[handle] { + crate::Expression::Constant(c) => { + get(self.reborrow(), self.constants[c].init.unwrap(), arena) + } + _ => get(self.reborrow(), handle, arena), + } + } +} + #[test] fn test_matrix_size() { let module = crate::Module::default(); diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 1df3d559f4..da4f0da5ea 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -11,7 +11,7 @@ use crate::arena::UniqueArena; use crate::{ arena::Handle, - proc::{IndexableLengthError, ResolveError}, + proc::{IndexableLengthError, ResolveError, U32EvalError}, }; #[derive(Clone, Debug, thiserror::Error)] @@ -25,6 +25,8 @@ pub enum ExpressionError { InvalidBaseType(Handle), #[error("Accessing with index {0:?} can't be done")] InvalidIndexType(Handle), + #[error("Accessing {0:?} via a negative index is invalid")] + NegativeIndex(Handle), #[error("Accessing index {1} is out of {0:?} bounds")] IndexOutOfBounds(Handle, u32), #[error("The expression {0:?} may only be indexed by a constant")] @@ -94,7 +96,7 @@ pub enum ExpressionError { has_ref: bool, }, #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")] - InvalidSampleOffset(crate::ImageDimension, Handle), + InvalidSampleOffset(crate::ImageDimension, Handle), #[error("Depth reference {0:?} is not a scalar float")] InvalidDepthReference(Handle), #[error("Depth sample level can only be Auto or Zero")] @@ -201,32 +203,26 @@ impl super::Validator { // If we know both the length and the index, we can do the // bounds check now. - // if let crate::proc::IndexableLength::Known(known_length) = - // base_type.indexable_length(module)? - // { - // if let E::Constant(k) = function.expressions[index] { - // if let crate::Constant { - // // We must treat specializable constants as unknown. - // specialization: crate::Specialization::None, - // // Non-scalar indices should have been caught above. - // inner: crate::ConstantInner::Scalar { value, .. }, - // .. - // } = module.constants[k] - // { - // match value { - // crate::ScalarValue::Uint(u) if u >= known_length as u64 => { - // return Err(ExpressionError::IndexOutOfBounds(base, value)); - // } - // crate::ScalarValue::Sint(s) - // if s < 0 || s >= known_length as i64 => - // { - // return Err(ExpressionError::IndexOutOfBounds(base, value)); - // } - // _ => (), - // } - // } - // } - // } + if let crate::proc::IndexableLength::Known(known_length) = + base_type.indexable_length(module)? + { + { + match module + .to_ctx() + .eval_expr_to_u32(index, Some(&function.expressions)) + { + Ok(value) => { + if value >= known_length { + return Err(ExpressionError::IndexOutOfBounds(base, value)); + } + } + Err(U32EvalError::Negative) => { + return Err(ExpressionError::NegativeIndex(base)) + } + Err(U32EvalError::Invalid) => {} + } + } + } ShaderStages::all() } @@ -406,29 +402,25 @@ impl super::Validator { } // check constant offset - // TODO - // if let Some(const_handle) = offset { - // let good = match module.constants[const_handle].inner { - // crate::ConstantInner::Scalar { - // width: _, - // value: crate::ScalarValue::Sint(_), - // } => num_components == 1, - // crate::ConstantInner::Scalar { .. } => false, - // crate::ConstantInner::Composite { ty, .. } => { - // match module.types[ty].inner { - // Ti::Vector { - // size, - // kind: Sk::Sint, - // .. - // } => size as u32 == num_components, - // _ => false, - // } - // } - // }; - // if !good { - // return Err(ExpressionError::InvalidSampleOffset(dim, const_handle)); - // } - // } + if let Some(const_expr) = offset { + let good = match module.const_expressions[const_expr] { + crate::Expression::Literal(crate::Literal::I32(_)) => num_components == 1, + crate::Expression::New(ty) | crate::Expression::Compose { ty, .. } => { + match module.types[ty].inner { + Ti::Vector { + size, + kind: Sk::Sint, + .. + } => size as u32 == num_components, + _ => false, + } + } + _ => false, + }; + if !good { + return Err(ExpressionError::InvalidSampleOffset(dim, const_expr)); + } + } // check depth reference type if let Some(expr) = depth_ref { diff --git a/src/valid/function.rs b/src/valid/function.rs index ea25d6bdc8..302268e171 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -894,8 +894,8 @@ impl super::Validator { fn validate_local_var( &self, var: &crate::LocalVariable, - types: &UniqueArena, - constants: &Arena, + _types: &UniqueArena, + _constants: &Arena, ) -> Result<(), LocalVariableError> { log::debug!("var {:?}", var); let type_info = self @@ -909,24 +909,8 @@ impl super::Validator { return Err(LocalVariableError::InvalidType(var.ty)); } - // if let Some(const_handle) = var.init { - // match constants[const_handle].inner { - // crate::ConstantInner::Scalar { width, ref value } => { - // let ty_inner = crate::TypeInner::Scalar { - // width, - // kind: value.scalar_kind(), - // }; - // if types[var.ty].inner != ty_inner { - // return Err(LocalVariableError::InitializerType); - // } - // } - // crate::ConstantInner::Composite { ty, components: _ } => { - // if ty != var.ty { - // return Err(LocalVariableError::InitializerType); - // } - // } - // } - // } + // TODO(teoxoy): check that var.ty and var.init types match + Ok(()) } diff --git a/src/valid/handles.rs b/src/valid/handles.rs index 2c5703b558..a009af9dbd 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -45,10 +45,7 @@ impl super::Validator { // NOTE: Types being first is important. All other forms of validation depend on this. for (this_handle, ty) in types.iter() { - let &crate::Type { - ref name, - ref inner, - } = ty; + let &crate::Type { name: _, ref inner } = ty; let validate_array_size = |size| -> Result, ValidationError> { match size { @@ -103,7 +100,7 @@ impl super::Validator { let validate_const_expr = |handle| Self::validate_constant_expression_handle(handle, const_expressions); - for (this_handle, constant) in constants.iter() { + for (_handle, constant) in constants.iter() { let &crate::Constant { name: _, specialization: _, @@ -237,6 +234,7 @@ impl super::Validator { handle.check_valid_for(functions).map(|_| ()) } + #[allow(clippy::too_many_arguments)] fn validate_expression_handles( (handle, expression): (Handle, &crate::Expression), constants: &Arena, @@ -290,7 +288,7 @@ impl super::Validator { crate::Expression::ImageSample { image, sampler, - gather, + gather: _, coordinate, array_index, offset, diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 77884d6ff3..365a9893d1 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -174,12 +174,8 @@ pub struct Validator { pub enum ConstantError { #[error("The type doesn't match the constant")] InvalidType, - #[error("The component handle {0:?} can not be resolved")] - UnresolvedComponent(Handle), - #[error("The array size handle {0:?} can not be resolved")] - UnresolvedSize(Handle), - #[error(transparent)] - Compose(#[from] ComposeError), + #[error("The type is not constructible")] + NonConstructibleType, } #[derive(Clone, Debug, thiserror::Error)] @@ -301,41 +297,16 @@ impl Validator { &self, handle: Handle, constants: &Arena, - types: &UniqueArena, + _types: &UniqueArena, ) -> Result<(), ConstantError> { let con = &constants[handle]; - // match con.inner { - // crate::ConstantInner::Scalar { width, ref value } => { - // if self.check_width(value.scalar_kind(), width).is_err() { - // return Err(ConstantError::InvalidType); - // } - // } - // crate::ConstantInner::Composite { ty, ref components } => { - // match types[ty].inner { - // crate::TypeInner::Array { - // size: crate::ArraySize::Constant(size_handle), - // .. - // } if handle <= size_handle => { - // return Err(ConstantError::UnresolvedSize(size_handle)); - // } - // _ => {} - // } - // if let Some(&comp) = components.iter().find(|&&comp| handle <= comp) { - // return Err(ConstantError::UnresolvedComponent(comp)); - // } - // compose::validate_compose( - // ty, - // constants, - // types, - // components - // .iter() - // .map(|&component| constants[component].inner.resolve_type()), - // )?; - // } - // } let type_info = &self.types[con.ty.index()]; - // TODO: if !type_info.flags.contains(TypeFlags::) { return Err() } + if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) { + return Err(ConstantError::NonConstructibleType); + } + + // TODO(teoxoy): check that var.ty and var.init types match Ok(()) } @@ -356,21 +327,6 @@ impl Validator { ValidationError::from(e).with_span_handle(handle, &module.types) })?; - #[cfg(feature = "validate")] - if self.flags.contains(ValidationFlags::CONSTANTS) { - for (handle, constant) in module.constants.iter() { - self.validate_constant(handle, &module.constants, &module.types) - .map_err(|source| { - ValidationError::Constant { - handle, - name: constant.name.clone().unwrap_or_default(), - source, - } - .with_span_handle(handle, &module.constants) - })? - } - } - let mut mod_info = ModuleInfo { type_flags: Vec::with_capacity(module.types.len()), functions: Vec::with_capacity(module.functions.len()), @@ -392,6 +348,35 @@ impl Validator { self.types[handle.index()] = ty_info; } + #[cfg(feature = "validate")] + if self.flags.contains(ValidationFlags::CONSTANTS) { + for (handle, constant) in module.constants.iter() { + self.validate_constant(handle, &module.constants, &module.types) + .map_err(|source| { + ValidationError::Constant { + handle, + name: constant.name.clone().unwrap_or_default(), + source, + } + .with_span_handle(handle, &module.constants) + })? + } + + // TODO(teoxoy): validate const_expressions + + // for (handle, constant) in module.const_expressions.iter() { + // self.validate_const_expr(handle, module.to_ctx()) + // .map_err(|source| { + // ValidationError::Constant { + // handle, + // name: constant.name.clone().unwrap_or_default(), + // source, + // } + // .with_span_handle(handle, &module.constants) + // })? + // } + } + #[cfg(feature = "validate")] for (var_handle, var) in module.global_variables.iter() { self.validate_global_var(var, &module.types) diff --git a/src/valid/type.rs b/src/valid/type.rs index 806d2f9e88..067b3acbb1 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -419,7 +419,7 @@ impl super::Validator { }; let type_info_mask = match size { - crate::ArraySize::Constant(len) => { + crate::ArraySize::Constant(_) => { TypeFlags::DATA | TypeFlags::SIZED | TypeFlags::COPY