From 460ae20ccd6b3df2538f8c0cc65dd569784497fe Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Wed, 22 Apr 2026 21:13:26 -0400 Subject: [PATCH] Add tests for override-sized array pointer parameters Fixes #4629 * Validation and execution tests --- .../expression/call/user/ptr_params.spec.ts | 144 ++++++++++++++++-- .../validation/functions/restrictions.spec.ts | 35 +++++ 2 files changed, 163 insertions(+), 16 deletions(-) diff --git a/src/webgpu/shader/execution/expression/call/user/ptr_params.spec.ts b/src/webgpu/shader/execution/expression/call/user/ptr_params.spec.ts index 2071446c0644..ecab678ec42b 100644 --- a/src/webgpu/shader/execution/expression/call/user/ptr_params.spec.ts +++ b/src/webgpu/shader/execution/expression/call/user/ptr_params.spec.ts @@ -7,15 +7,34 @@ import { GPUTest } from '../../../../../gpu_test.js'; export const g = makeTestGroup(GPUTest); -function wgslTypeDecl(kind: 'vec4i' | 'array' | 'struct') { +function wgslTypeDecl( + kind: 'vec4i' | 'array' | 'override_array1' | 'override_array2' | 'override_array3' | 'struct' +) { switch (kind) { case 'vec4i': return ` alias T = vec4i; +alias RT = T; `; case 'array': return ` alias T = array; +alias RT = T; +`; + case 'override_array1': + return ` +alias T = array; +alias RT = array; +`; + case 'override_array2': + return ` +alias T = array; +alias RT = array; +`; + case 'override_array3': + return ` +alias T = array; +alias RT = array; `; case 'struct': return ` @@ -26,15 +45,21 @@ c : i32, d : u32, } alias T = S; +alias RT = T; `; } } -function valuesForType(kind: 'vec4i' | 'array' | 'struct') { +function valuesForType( + kind: 'vec4i' | 'array' | 'override_array1' | 'override_array2' | 'override_array3' | 'struct' +) { switch (kind) { case 'vec4i': return new Uint32Array([1, 2, 3, 4]); case 'array': + case 'override_array1': + case 'override_array2': + case 'override_array3': return new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); case 'struct': return new Uint32Array([1, 2, 3, 4]); @@ -46,13 +71,15 @@ function run( wgsl: string, inputUsage: 'uniform' | 'storage', input: Uint32Array | Float32Array, - expected: Uint32Array | Float32Array + expected: Uint32Array | Float32Array, + constants: Record = {} ) { const pipeline = t.device.createComputePipeline({ layout: 'auto', compute: { module: t.device.createShaderModule({ code: wgsl }), entryPoint: 'main', + constants, }, }); @@ -91,7 +118,24 @@ g.test('read_full_object') u .combine('address_space', ['function', 'private', 'workgroup', 'storage', 'uniform'] as const) .combine('call_indirection', [0, 1, 2] as const) - .combine('type', ['vec4i', 'array', 'struct'] as const) + .combine('type', [ + 'vec4i', + 'array', + 'override_array1', + 'override_array2', + 'override_array3', + 'struct', + ] as const) + .filter(t => { + switch (t.type) { + case 'override_array1': + case 'override_array2': + case 'override_array3': + return t.address_space === 'workgroup'; + default: + return true; + } + }) ) .fn(t => { switch (t.params.address_space) { @@ -101,6 +145,27 @@ g.test('read_full_object') t.skipIfLanguageFeatureNotSupported('unrestricted_pointer_parameters'); } + let wg_assign_input = 'W = input;'; + let output_assign = 'output = *p;'; + if (t.params.address_space === 'workgroup') { + switch (t.params.type) { + case 'override_array1': + case 'override_array2': + case 'override_array3': + wg_assign_input = ` +for (var i = 0u; i < 3; i++) { + W[i] = input[i]; +}`; + output_assign = ` +for (var i = 0u; i < 3; i++) { + output[i] = (*p)[i]; +}`; + break; + default: + break; + } + } + const main: string = { function: ` @compute @workgroup_size(1) @@ -121,7 +186,7 @@ fn main() { var W : T; @compute @workgroup_size(1) fn main() { - W = input; + ${wg_assign_input} f0(&W); } `, @@ -150,18 +215,21 @@ fn f${i}(p : ptr<${t.params.address_space}, T>) { const inputVar: string = t.params.address_space === 'uniform' - ? `@binding(0) @group(0) var input : T;` - : `@binding(0) @group(0) var input : T;`; + ? `@binding(0) @group(0) var input : RT;` + : `@binding(0) @group(0) var input : RT;`; const wgsl = ` +override over_no_default : u32; +override over_default = 1u; +override over_expr = over_default + over_no_default - 3u; ${wgslTypeDecl(t.params.type)} ${inputVar} -@binding(1) @group(0) var output : T; +@binding(1) @group(0) var output : RT; fn f${t.params.call_indirection}(p : ptr<${t.params.address_space}, T>) { - output = *p; + ${output_assign} } ${call_chain} @@ -171,7 +239,10 @@ ${main} const values = valuesForType(t.params.type); - run(t, wgsl, t.params.address_space === 'uniform' ? 'uniform' : 'storage', values, values); + run(t, wgsl, t.params.address_space === 'uniform' ? 'uniform' : 'storage', values, values, { + over_no_default: 3, + over_default: 3, + }); }); g.test('read_ptr_to_member') @@ -374,7 +445,24 @@ g.test('write_full_object') u .combine('address_space', ['function', 'private', 'workgroup', 'storage'] as const) .combine('call_indirection', [0, 1, 2] as const) - .combine('type', ['vec4i', 'array', 'struct'] as const) + .combine('type', [ + 'vec4i', + 'array', + 'override_array1', + 'override_array2', + 'override_array3', + 'struct', + ] as const) + .filter(t => { + switch (t.type) { + case 'override_array1': + case 'override_array2': + case 'override_array3': + return t.address_space === 'workgroup'; + default: + return true; + } + }) ) .fn(t => { switch (t.params.address_space) { @@ -383,6 +471,27 @@ g.test('write_full_object') t.skipIfLanguageFeatureNotSupported('unrestricted_pointer_parameters'); } + let wg_output_assign = 'output = W;'; + let assign_from_input = '*p = input;'; + if (t.params.address_space === 'workgroup') { + switch (t.params.type) { + case 'override_array1': + case 'override_array2': + case 'override_array3': + wg_output_assign = ` +for (var i = 0u; i < 3; i++) { + output[i] = W[i]; +}`; + assign_from_input = ` +for (var i = 0u; i < 3; i++) { + (*p)[i] = input[i]; +}`; + break; + default: + break; + } + } + const ptr = t.params.address_space === 'storage' ? `ptr` @@ -410,7 +519,7 @@ var W : T; @compute @workgroup_size(1) fn main() { f0(&W); - output = W; + ${wg_output_assign} } `, storage: ` @@ -431,13 +540,16 @@ fn f${i}(p : ${ptr}) { } const wgsl = ` +override over_no_default : u32; +override over_default = 1u; +override over_expr = over_default + over_no_default - 3u; ${wgslTypeDecl(t.params.type)} -@binding(0) @group(0) var input : T; -@binding(1) @group(0) var output : T; +@binding(0) @group(0) var input : RT; +@binding(1) @group(0) var output : RT; fn f${t.params.call_indirection}(p : ${ptr}) { - *p = input; + ${assign_from_input} } ${call_chain} @@ -447,7 +559,7 @@ ${main} const values = valuesForType(t.params.type); - run(t, wgsl, 'uniform', values, values); + run(t, wgsl, 'uniform', values, values, { over_no_default: 3, over_default: 3 }); }); g.test('write_ptr_to_member') diff --git a/src/webgpu/shader/validation/functions/restrictions.spec.ts b/src/webgpu/shader/validation/functions/restrictions.spec.ts index 1a16a8c8a60f..73c3a376765e 100644 --- a/src/webgpu/shader/validation/functions/restrictions.spec.ts +++ b/src/webgpu/shader/validation/functions/restrictions.spec.ts @@ -34,6 +34,10 @@ struct struct_with_array { a : array } +override override_no_default : u32; +override override_default = 4u; +override override_expr = override_default + 2; + `; const kVertexPosCases: Record = { @@ -278,6 +282,18 @@ const kFunctionParamTypeCases: Record = { name: `ptr,1>>`, valid: 'with_unrestricted_pointer_parameters', }, + ptrWorkgroupOverrideNoDefault: { + name: `ptr>`, + valid: 'with_unrestricted_pointer_parameters', + }, + ptrWorkgroupOverrideWithDefault: { + name: `ptr>`, + valid: 'with_unrestricted_pointer_parameters', + }, + ptrWorkgroupOverrideExpr: { + name: `ptr>`, + valid: 'with_unrestricted_pointer_parameters', + }, // Invalid pointers. invalid_ptr1: { name: `ptr`, valid: false }, // Can't spell handle address space @@ -488,6 +504,21 @@ const kFunctionParamValueCases: Record = { matches: ['ptr12'], needsUnrestrictedPointerParameters: true, }, + ptrWorkgroupOverrideNoDefault: { + value: `&wg_override_no_default`, + matches: ['ptrWorkgroupOverrideNoDefault'], + needsUnrestrictedPointerParameters: true, + }, + ptrWorkgroupOverrideWithDefault: { + value: `&wg_override_default`, + matches: ['ptrWorkgroupOverrideWithDefault'], + needsUnrestrictedPointerParameters: true, + }, + ptrWorkgroupOverrideExpr: { + value: `&wg_override_expr`, + matches: ['ptrWorkgroupOverrideExpr'], + needsUnrestrictedPointerParameters: true, + }, }; function parameterMatches(decl: string, matches: string[]): boolean { @@ -569,6 +600,10 @@ var g_array5 : array; var g_constructible : constructible; var g_struct_with_array : struct_with_array; +var wg_override_no_default : array; +var wg_override_default : array; +var wg_override_expr : array; + fn foo() { var f_u32 : u32; var f_i32 : i32;