Skip to content

Commit

Permalink
[naga] Introduce HandleVec.
Browse files Browse the repository at this point in the history
Introduce a new type, `HandleVec<T, U>`, which is basically a
`Vec<U>`, except that it's indexed by values of type `Handle<T>`. This
gives us a more strictly typed way to build tables of data parallel to
some other `Arena`.

Change `naga::back::pipeline_constants` to use `HandleVec` instead of
`Vec`. This removes many calls to `Handle::index`, and makes the types
more informative.

In `naga::back::spv`, change `Writer` and `BlockContext` to use
`HandleVec` instead of `Vec` for various handle-indexed tables.
  • Loading branch information
jimblandy authored and teoxoy committed Jun 21, 2024
1 parent d6c4d5c commit 9b5035c
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 52 deletions.
98 changes: 98 additions & 0 deletions naga/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -784,3 +784,101 @@ where
arbitrary::size_hint::and(depth_hint, (0, None))
}
}

/// A [`Vec`] indexed by [`Handle`]s.
///
/// A `HandleVec<T, U>` is a [`Vec<U>`] indexed by values of type `Handle<T>`,
/// rather than `usize`.
///
/// Rather than a `push` method, `HandleVec` has an [`insert`] method, analogous
/// to [`HashMap::insert`], that requires you to provide the handle at which the
/// new value should appear. However, since `HandleVec` only supports insertion
/// at the end, the given handle's index must be equal to the the `HandleVec`'s
/// current length; otherwise, the insertion will panic.
///
/// [`insert`]: HandleVec::insert
/// [`HashMap::insert`]: std::collections::HashMap::insert
pub(crate) struct HandleVec<T, U> {
inner: Vec<U>,
as_keys: PhantomData<T>,
}

impl<T, U> Default for HandleVec<T, U> {
fn default() -> Self {
Self {
inner: vec![],
as_keys: PhantomData,
}
}
}

#[allow(dead_code)]
impl<T, U> HandleVec<T, U> {
pub(crate) const fn new() -> Self {
Self {
inner: vec![],
as_keys: PhantomData,
}
}

pub(crate) fn with_capacity(capacity: usize) -> Self {
Self {
inner: Vec::with_capacity(capacity),
as_keys: PhantomData,
}
}

pub(crate) fn len(&self) -> usize {
self.inner.len()
}

/// Insert a mapping from `handle` to `value`.
///
/// Unlike a [`HashMap`], a `HandleVec` can only have new entries inserted at
/// the end, like [`Vec::push`]. So the index of `handle` must equal
/// [`self.len()`].
///
/// [`HashMap]: std::collections::HashMap
/// [`self.len()`]: HandleVec::len
pub(crate) fn insert(&mut self, handle: Handle<T>, value: U) {
assert_eq!(handle.index(), self.inner.len());
self.inner.push(value);
}

pub(crate) fn get(&self, handle: Handle<T>) -> Option<&U> {
self.inner.get(handle.index())
}

pub(crate) fn clear(&mut self) {
self.inner.clear()
}

pub(crate) fn resize(&mut self, len: usize, fill: U)
where
U: Clone,
{
self.inner.resize(len, fill);
}

pub(crate) fn iter(&self) -> impl Iterator<Item = &U> {
self.inner.iter()
}

pub(crate) fn iter_mut(&mut self) -> impl Iterator<Item = &mut U> {
self.inner.iter_mut()
}
}

impl<T, U> ops::Index<Handle<T>> for HandleVec<T, U> {
type Output = U;

fn index(&self, handle: Handle<T>) -> &Self::Output {
&self.inner[handle.index()]
}
}

impl<T, U> ops::IndexMut<Handle<T>> for HandleVec<T, U> {
fn index_mut(&mut self, handle: Handle<T>) -> &mut Self::Output {
&mut self.inner[handle.index()]
}
}
50 changes: 24 additions & 26 deletions naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::PipelineConstants;
use crate::{
arena::HandleVec,
proc::{ConstantEvaluator, ConstantEvaluatorError, Emitter},
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
Arena, Block, Constant, Expression, Function, Handle, Literal, Module, Override, Range, Scalar,
Expand Down Expand Up @@ -49,11 +50,11 @@ pub fn process_overrides<'a>(

// A map from override handles to the handles of the constants
// we've replaced them with.
let mut override_map = Vec::with_capacity(module.overrides.len());
let mut override_map = HandleVec::with_capacity(module.overrides.len());

// A map from `module`'s original global expression handles to
// handles in the new, simplified global expression arena.
let mut adjusted_global_expressions = Vec::with_capacity(module.global_expressions.len());
let mut adjusted_global_expressions = HandleVec::with_capacity(module.global_expressions.len());

// The set of constants whose initializer handles we've already
// updated to refer to the newly built global expression arena.
Expand Down Expand Up @@ -105,7 +106,7 @@ pub fn process_overrides<'a>(
for (old_h, expr, span) in module.global_expressions.drain() {
let mut expr = match expr {
Expression::Override(h) => {
let c_h = if let Some(new_h) = override_map.get(h.index()) {
let c_h = if let Some(new_h) = override_map.get(h) {
*new_h
} else {
let mut new_h = None;
Expand All @@ -131,7 +132,7 @@ pub fn process_overrides<'a>(
Expression::Constant(c_h) => {
if adjusted_constant_initializers.insert(c_h) {
let init = &mut module.constants[c_h].init;
*init = adjusted_global_expressions[init.index()];
*init = adjusted_global_expressions[*init];
}
expr
}
Expand All @@ -144,8 +145,7 @@ pub fn process_overrides<'a>(
);
adjust_expr(&adjusted_global_expressions, &mut expr);
let h = evaluator.try_eval_and_append(expr, span)?;
debug_assert_eq!(old_h.index(), adjusted_global_expressions.len());
adjusted_global_expressions.push(h);
adjusted_global_expressions.insert(old_h, h);
}

// Finish processing any overrides we didn't visit in the loop above.
Expand All @@ -169,12 +169,12 @@ pub fn process_overrides<'a>(
.iter_mut()
.filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
{
c.init = adjusted_global_expressions[c.init.index()];
c.init = adjusted_global_expressions[c.init];
}

for (_, v) in module.global_variables.iter_mut() {
if let Some(ref mut init) = v.init {
*init = adjusted_global_expressions[init.index()];
*init = adjusted_global_expressions[*init];
}
}

Expand Down Expand Up @@ -206,8 +206,8 @@ fn process_override(
(old_h, override_, span): (Handle<Override>, Override, Span),
pipeline_constants: &PipelineConstants,
module: &mut Module,
override_map: &mut Vec<Handle<Constant>>,
adjusted_global_expressions: &[Handle<Expression>],
override_map: &mut HandleVec<Override, Handle<Constant>>,
adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
global_expression_kind_tracker: &mut crate::proc::ExpressionKindTracker,
) -> Result<Handle<Constant>, PipelineConstantError> {
Expand All @@ -234,7 +234,7 @@ fn process_override(
global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
expr
} else if let Some(init) = override_.init {
adjusted_global_expressions[init.index()]
adjusted_global_expressions[init]
} else {
return Err(PipelineConstantError::MissingValue(key.to_string()));
};
Expand All @@ -246,8 +246,7 @@ fn process_override(
init,
};
let h = module.constants.append(constant, span);
debug_assert_eq!(old_h.index(), override_map.len());
override_map.push(h);
override_map.insert(old_h, h);
adjusted_constant_initializers.insert(h);
Ok(h)
}
Expand All @@ -259,16 +258,16 @@ fn process_override(
/// Replace any expressions whose values are now known with their fully
/// evaluated form.
///
/// If `h` is a `Handle<Override>`, then `override_map[h.index()]` is the
/// If `h` is a `Handle<Override>`, then `override_map[h]` is the
/// `Handle<Constant>` for the override's final value.
fn process_function(
module: &mut Module,
override_map: &[Handle<Constant>],
override_map: &HandleVec<Override, Handle<Constant>>,
function: &mut Function,
) -> Result<(), ConstantEvaluatorError> {
// A map from original local expression handles to
// handles in the new, local expression arena.
let mut adjusted_local_expressions = Vec::with_capacity(function.expressions.len());
let mut adjusted_local_expressions = HandleVec::with_capacity(function.expressions.len());

let mut local_expression_kind_tracker = crate::proc::ExpressionKindTracker::new();

Expand All @@ -294,12 +293,11 @@ fn process_function(

for (old_h, mut expr, span) in expressions.drain() {
if let Expression::Override(h) = expr {
expr = Expression::Constant(override_map[h.index()]);
expr = Expression::Constant(override_map[h]);
}
adjust_expr(&adjusted_local_expressions, &mut expr);
let h = evaluator.try_eval_and_append(expr, span)?;
debug_assert_eq!(old_h.index(), adjusted_local_expressions.len());
adjusted_local_expressions.push(h);
adjusted_local_expressions.insert(old_h, h);
}

adjust_block(&adjusted_local_expressions, &mut function.body);
Expand All @@ -309,7 +307,7 @@ fn process_function(
// Update local expression initializers.
for (_, local) in function.local_variables.iter_mut() {
if let &mut Some(ref mut init) = &mut local.init {
*init = adjusted_local_expressions[init.index()];
*init = adjusted_local_expressions[*init];
}
}

Expand All @@ -319,17 +317,17 @@ fn process_function(
for (expr_h, name) in named_expressions {
function
.named_expressions
.insert(adjusted_local_expressions[expr_h.index()], name);
.insert(adjusted_local_expressions[expr_h], name);
}

Ok(())
}

/// Replace every expression handle in `expr` with its counterpart
/// given by `new_pos`.
fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
fn adjust_expr(new_pos: &HandleVec<Expression, Handle<Expression>>, expr: &mut Expression) {
let adjust = |expr: &mut Handle<Expression>| {
*expr = new_pos[expr.index()];
*expr = new_pos[*expr];
};
match *expr {
Expression::Compose {
Expand Down Expand Up @@ -532,17 +530,17 @@ fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {

/// Replace every expression handle in `block` with its counterpart
/// given by `new_pos`.
fn adjust_block(new_pos: &[Handle<Expression>], block: &mut Block) {
fn adjust_block(new_pos: &HandleVec<Expression, Handle<Expression>>, block: &mut Block) {
for stmt in block.iter_mut() {
adjust_stmt(new_pos, stmt);
}
}

/// Replace every expression handle in `stmt` with its counterpart
/// given by `new_pos`.
fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
fn adjust_stmt(new_pos: &HandleVec<Expression, Handle<Expression>>, stmt: &mut Statement) {
let adjust = |expr: &mut Handle<Expression>| {
*expr = new_pos[expr.index()];
*expr = new_pos[*expr];
};
match *stmt {
Statement::Emit(ref mut range) => {
Expand Down
8 changes: 4 additions & 4 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ impl<'w> BlockContext<'w> {

// The chain rule: if this `Access...`'s `base` operand was
// previously omitted, then omit this one, too.
_ => self.cached.ids[expr_handle.index()] == 0,
_ => self.cached.ids[expr_handle] == 0,
}
}

Expand All @@ -237,7 +237,7 @@ impl<'w> BlockContext<'w> {
crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
crate::Expression::Constant(handle) => {
let init = self.ir_module.constants[handle].init;
self.writer.constant_ids[init.index()]
self.writer.constant_ids[init]
}
crate::Expression::Override(_) => return Err(Error::Override),
crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
Expand Down Expand Up @@ -430,7 +430,7 @@ impl<'w> BlockContext<'w> {
}
}
crate::Expression::GlobalVariable(handle) => {
self.writer.global_variables[handle.index()].access_id
self.writer.global_variables[handle].access_id
}
crate::Expression::Swizzle {
size,
Expand Down Expand Up @@ -1830,7 +1830,7 @@ impl<'w> BlockContext<'w> {
base
}
crate::Expression::GlobalVariable(handle) => {
let gv = &self.writer.global_variables[handle.index()];
let gv = &self.writer.global_variables[handle];
break gv.access_id;
}
crate::Expression::LocalVariable(variable) => {
Expand Down
4 changes: 2 additions & 2 deletions naga/src/back/spv/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ impl<'w> BlockContext<'w> {
pub(super) fn get_handle_id(&mut self, expr_handle: Handle<crate::Expression>) -> Word {
let id = match self.ir_function.expressions[expr_handle] {
crate::Expression::GlobalVariable(handle) => {
self.writer.global_variables[handle.index()].handle_id
self.writer.global_variables[handle].handle_id
}
crate::Expression::FunctionArgument(i) => {
self.function.parameters[i as usize].handle_id
Expand Down Expand Up @@ -974,7 +974,7 @@ impl<'w> BlockContext<'w> {
};

if let Some(offset_const) = offset {
let offset_id = self.writer.constant_ids[offset_const.index()];
let offset_id = self.writer.constant_ids[offset_const];
main_instruction.add_operand(offset_id);
}

Expand Down
2 changes: 1 addition & 1 deletion naga/src/back/spv/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl<'w> BlockContext<'w> {
_ => return Err(Error::Validation("array length expression case-4")),
};

let gvar = self.writer.global_variables[global_handle.index()].clone();
let gvar = self.writer.global_variables[global_handle].clone();
let global = &self.ir_module.global_variables[global_handle];
let (last_member_index, gvar_id) = match opt_last_member_index {
Some(index) => (index, gvar.access_id),
Expand Down
12 changes: 6 additions & 6 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ mod writer;

pub use spirv::Capability;

use crate::arena::Handle;
use crate::arena::{Handle, HandleVec};
use crate::proc::{BoundsCheckPolicies, TypeResolution};

use spirv::Word;
Expand Down Expand Up @@ -420,7 +420,7 @@ enum Dimension {
/// [emit]: index.html#expression-evaluation-time-and-scope
#[derive(Default)]
struct CachedExpressions {
ids: Vec<Word>,
ids: HandleVec<crate::Expression, Word>,
}
impl CachedExpressions {
fn reset(&mut self, length: usize) {
Expand All @@ -431,7 +431,7 @@ impl CachedExpressions {
impl ops::Index<Handle<crate::Expression>> for CachedExpressions {
type Output = Word;
fn index(&self, h: Handle<crate::Expression>) -> &Word {
let id = &self.ids[h.index()];
let id = &self.ids[h];
if *id == 0 {
unreachable!("Expression {:?} is not cached!", h);
}
Expand All @@ -440,7 +440,7 @@ impl ops::Index<Handle<crate::Expression>> for CachedExpressions {
}
impl ops::IndexMut<Handle<crate::Expression>> for CachedExpressions {
fn index_mut(&mut self, h: Handle<crate::Expression>) -> &mut Word {
let id = &mut self.ids[h.index()];
let id = &mut self.ids[h];
if *id != 0 {
unreachable!("Expression {:?} is already cached!", h);
}
Expand Down Expand Up @@ -662,9 +662,9 @@ pub struct Writer {
lookup_function: crate::FastHashMap<Handle<crate::Function>, Word>,
lookup_function_type: crate::FastHashMap<LookupFunctionType, Word>,
/// Indexed by const-expression handle indexes
constant_ids: Vec<Word>,
constant_ids: HandleVec<crate::Expression, Word>,
cached_constants: crate::FastHashMap<CachedConstant, Word>,
global_variables: Vec<GlobalVariable>,
global_variables: HandleVec<crate::GlobalVariable, GlobalVariable>,
binding_map: BindingMap,

// Cached expressions are only meaningful within a BlockContext, but we
Expand Down
Loading

0 comments on commit 9b5035c

Please sign in to comment.