Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 128 additions & 16 deletions src/webgpu/shader/execution/expression/call/user/ptr_params.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,34 @@ import { GPUTest } from '../../../../../gpu_test.js';

export const g = makeTestGroup(GPUTest);

function wgslTypeDecl(kind: 'vec4i' | 'array' | 'struct') {
function wgslTypeDecl(
kind: 'vec4i' | 'array' | 'override_array1' | 'override_array2' | 'override_array3' | 'struct'
) {
switch (kind) {
case 'vec4i':
return `
alias T = vec4i;
alias RT = T;
`;
case 'array':
return `
alias T = array<vec4f, 3>;
alias RT = T;
`;
case 'override_array1':
return `
alias T = array<vec4f, over_no_default>;
alias RT = array<vec4f, 3>;
`;
case 'override_array2':
return `
alias T = array<vec4f, over_default>;
alias RT = array<vec4f, 3>;
`;
case 'override_array3':
return `
alias T = array<vec4f, over_expr>;
alias RT = array<vec4f, 3>;
`;
case 'struct':
return `
Expand All @@ -26,15 +45,21 @@ c : i32,
d : u32,
}
alias T = S;
alias RT = T;
`;
}
}

function valuesForType(kind: 'vec4i' | 'array' | 'struct') {
function valuesForType(
kind: 'vec4i' | 'array' | 'override_array1' | 'override_array2' | 'override_array3' | 'struct'
) {
switch (kind) {
case 'vec4i':
return new Uint32Array([1, 2, 3, 4]);
case 'array':
case 'override_array1':
case 'override_array2':
case 'override_array3':
return new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
case 'struct':
return new Uint32Array([1, 2, 3, 4]);
Expand All @@ -46,13 +71,15 @@ function run(
wgsl: string,
inputUsage: 'uniform' | 'storage',
input: Uint32Array | Float32Array,
expected: Uint32Array | Float32Array
expected: Uint32Array | Float32Array,
constants: Record<string, number> = {}
) {
const pipeline = t.device.createComputePipeline({
layout: 'auto',
compute: {
module: t.device.createShaderModule({ code: wgsl }),
entryPoint: 'main',
constants,
},
});

Expand Down Expand Up @@ -91,7 +118,24 @@ g.test('read_full_object')
u
.combine('address_space', ['function', 'private', 'workgroup', 'storage', 'uniform'] as const)
.combine('call_indirection', [0, 1, 2] as const)
.combine('type', ['vec4i', 'array', 'struct'] as const)
.combine('type', [
'vec4i',
'array',
'override_array1',
'override_array2',
'override_array3',
'struct',
] as const)
.filter(t => {
switch (t.type) {
case 'override_array1':
case 'override_array2':
case 'override_array3':
return t.address_space === 'workgroup';
default:
return true;
}
})
)
.fn(t => {
switch (t.params.address_space) {
Expand All @@ -101,6 +145,27 @@ g.test('read_full_object')
t.skipIfLanguageFeatureNotSupported('unrestricted_pointer_parameters');
}

let wg_assign_input = 'W = input;';
let output_assign = 'output = *p;';
if (t.params.address_space === 'workgroup') {
switch (t.params.type) {
case 'override_array1':
case 'override_array2':
case 'override_array3':
wg_assign_input = `
for (var i = 0u; i < 3; i++) {
W[i] = input[i];
}`;
output_assign = `
for (var i = 0u; i < 3; i++) {
output[i] = (*p)[i];
}`;
break;
default:
break;
}
}

const main: string = {
function: `
@compute @workgroup_size(1)
Expand All @@ -121,7 +186,7 @@ fn main() {
var<workgroup> W : T;
@compute @workgroup_size(1)
fn main() {
W = input;
${wg_assign_input}
f0(&W);
}
`,
Expand Down Expand Up @@ -150,18 +215,21 @@ fn f${i}(p : ptr<${t.params.address_space}, T>) {

const inputVar: string =
t.params.address_space === 'uniform'
? `@binding(0) @group(0) var<uniform> input : T;`
: `@binding(0) @group(0) var<storage, read> input : T;`;
? `@binding(0) @group(0) var<uniform> input : RT;`
: `@binding(0) @group(0) var<storage, read> input : RT;`;

const wgsl = `
override over_no_default : u32;
override over_default = 1u;
override over_expr = over_default + over_no_default - 3u;
${wgslTypeDecl(t.params.type)}

${inputVar}

@binding(1) @group(0) var<storage, read_write> output : T;
@binding(1) @group(0) var<storage, read_write> output : RT;

fn f${t.params.call_indirection}(p : ptr<${t.params.address_space}, T>) {
output = *p;
${output_assign}
}

${call_chain}
Expand All @@ -171,7 +239,10 @@ ${main}

const values = valuesForType(t.params.type);

run(t, wgsl, t.params.address_space === 'uniform' ? 'uniform' : 'storage', values, values);
run(t, wgsl, t.params.address_space === 'uniform' ? 'uniform' : 'storage', values, values, {
over_no_default: 3,
over_default: 3,
});
});

g.test('read_ptr_to_member')
Expand Down Expand Up @@ -374,7 +445,24 @@ g.test('write_full_object')
u
.combine('address_space', ['function', 'private', 'workgroup', 'storage'] as const)
.combine('call_indirection', [0, 1, 2] as const)
.combine('type', ['vec4i', 'array', 'struct'] as const)
.combine('type', [
'vec4i',
'array',
'override_array1',
'override_array2',
'override_array3',
'struct',
] as const)
.filter(t => {
switch (t.type) {
case 'override_array1':
case 'override_array2':
case 'override_array3':
return t.address_space === 'workgroup';
default:
return true;
}
})
)
.fn(t => {
switch (t.params.address_space) {
Expand All @@ -383,6 +471,27 @@ g.test('write_full_object')
t.skipIfLanguageFeatureNotSupported('unrestricted_pointer_parameters');
}

let wg_output_assign = 'output = W;';
let assign_from_input = '*p = input;';
if (t.params.address_space === 'workgroup') {
switch (t.params.type) {
case 'override_array1':
case 'override_array2':
case 'override_array3':
wg_output_assign = `
for (var i = 0u; i < 3; i++) {
output[i] = W[i];
}`;
assign_from_input = `
for (var i = 0u; i < 3; i++) {
(*p)[i] = input[i];
}`;
break;
default:
break;
}
}

const ptr =
t.params.address_space === 'storage'
? `ptr<storage, T, read_write>`
Expand Down Expand Up @@ -410,7 +519,7 @@ var<workgroup> W : T;
@compute @workgroup_size(1)
fn main() {
f0(&W);
output = W;
${wg_output_assign}
}
`,
storage: `
Expand All @@ -431,13 +540,16 @@ fn f${i}(p : ${ptr}) {
}

const wgsl = `
override over_no_default : u32;
override over_default = 1u;
override over_expr = over_default + over_no_default - 3u;
${wgslTypeDecl(t.params.type)}

@binding(0) @group(0) var<uniform> input : T;
@binding(1) @group(0) var<storage, read_write> output : T;
@binding(0) @group(0) var<uniform> input : RT;
@binding(1) @group(0) var<storage, read_write> output : RT;

fn f${t.params.call_indirection}(p : ${ptr}) {
*p = input;
${assign_from_input}
}

${call_chain}
Expand All @@ -447,7 +559,7 @@ ${main}

const values = valuesForType(t.params.type);

run(t, wgsl, 'uniform', values, values);
run(t, wgsl, 'uniform', values, values, { over_no_default: 3, over_default: 3 });
});

g.test('write_ptr_to_member')
Expand Down
35 changes: 35 additions & 0 deletions src/webgpu/shader/validation/functions/restrictions.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ struct struct_with_array {
a : array<constructible, 4>
}

override override_no_default : u32;
override override_default = 4u;
override override_expr = override_default + 2;

`;

const kVertexPosCases: Record<string, VertexPosCase> = {
Expand Down Expand Up @@ -278,6 +282,18 @@ const kFunctionParamTypeCases: Record<string, ParamTypeCase> = {
name: `ptr<workgroup, array<atomic<u32>,1>>`,
valid: 'with_unrestricted_pointer_parameters',
},
ptrWorkgroupOverrideNoDefault: {
name: `ptr<workgroup, array<u32, override_no_default>>`,
valid: 'with_unrestricted_pointer_parameters',
},
ptrWorkgroupOverrideWithDefault: {
name: `ptr<workgroup, array<f32, override_default>>`,
valid: 'with_unrestricted_pointer_parameters',
},
ptrWorkgroupOverrideExpr: {
name: `ptr<workgroup, array<vec4f, override_expr>>`,
valid: 'with_unrestricted_pointer_parameters',
},

// Invalid pointers.
invalid_ptr1: { name: `ptr<handle, u32>`, valid: false }, // Can't spell handle address space
Expand Down Expand Up @@ -488,6 +504,21 @@ const kFunctionParamValueCases: Record<string, ParamValueCase> = {
matches: ['ptr12'],
needsUnrestrictedPointerParameters: true,
},
ptrWorkgroupOverrideNoDefault: {
value: `&wg_override_no_default`,
matches: ['ptrWorkgroupOverrideNoDefault'],
needsUnrestrictedPointerParameters: true,
},
ptrWorkgroupOverrideWithDefault: {
value: `&wg_override_default`,
matches: ['ptrWorkgroupOverrideWithDefault'],
needsUnrestrictedPointerParameters: true,
},
ptrWorkgroupOverrideExpr: {
value: `&wg_override_expr`,
matches: ['ptrWorkgroupOverrideExpr'],
needsUnrestrictedPointerParameters: true,
},
};

function parameterMatches(decl: string, matches: string[]): boolean {
Expand Down Expand Up @@ -569,6 +600,10 @@ var<private> g_array5 : array<bool, 4>;
var<private> g_constructible : constructible;
var<private> g_struct_with_array : struct_with_array;

var<workgroup> wg_override_no_default : array<u32, override_no_default>;
var<workgroup> wg_override_default : array<f32, override_default>;
var<workgroup> wg_override_expr : array<vec4f, override_expr>;

fn foo() {
var f_u32 : u32;
var f_i32 : i32;
Expand Down
Loading