@@ -289,222 +289,6 @@ func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting(%lhs : tenso
289289
290290// -----
291291
292- // ragged_dot mode 1: [b,m,k], [g,b,k,n], [g] -> [b,m,n]
293- func.func @ragged_dot_non_contracting (%lhs : tensor <2 x11 x5 xf32 >, %rhs : tensor <3 x2 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <2 x11 x7 xf32 > {
294- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
295- ragged_dot_dimension_numbers = #chlo.ragged_dot <
296- lhs_batching_dimensions = [0 ],
297- rhs_batching_dimensions = [1 ],
298- lhs_contracting_dimensions = [2 ],
299- rhs_contracting_dimensions = [2 ],
300- lhs_ragged_dimensions = [1 ],
301- rhs_group_dimensions = [0 ]
302- >,
303- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
304- } : (tensor <2 x11 x5 xf32 >, tensor <3 x2 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <2 x11 x7 xf32 >
305- func.return %0 : tensor <2 x11 x7 xf32 >
306- }
307-
308- // -----
309-
310- // ragged_dot mode 2: [m,k], [k,n], [g] -> [g,m,n]
311- func.func @ragged_dot_contracting (%lhs : tensor <2 x11 x5 xf32 >, %rhs : tensor <2 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <3 x2 x11 x7 xf32 > {
312- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
313- ragged_dot_dimension_numbers = #chlo.ragged_dot <
314- lhs_batching_dimensions = [0 ],
315- rhs_batching_dimensions = [0 ],
316- lhs_contracting_dimensions = [2 ],
317- rhs_contracting_dimensions = [1 ],
318- lhs_ragged_dimensions = [2 ],
319- rhs_group_dimensions = []
320- >,
321- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
322- } : (tensor <2 x11 x5 xf32 >, tensor <2 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <3 x2 x11 x7 xf32 >
323- func.return %0 : tensor <3 x2 x11 x7 xf32 >
324- }
325-
326- // -----
327-
328- // ragged_dot mode 3: [b,m,k], [b,k,n], [g] -> [b,m,n]
329- func.func @ragged_dot_batch (%lhs : tensor <3 x11 x5 xf32 >, %rhs : tensor <3 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <3 x11 x7 xf32 > {
330- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
331- ragged_dot_dimension_numbers = #chlo.ragged_dot <
332- lhs_batching_dimensions = [0 ],
333- rhs_batching_dimensions = [0 ],
334- lhs_contracting_dimensions = [2 ],
335- rhs_contracting_dimensions = [1 ],
336- lhs_ragged_dimensions = [0 ],
337- rhs_group_dimensions = []
338- >,
339- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
340- } : (tensor <3 x11 x5 xf32 >, tensor <3 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <3 x11 x7 xf32 >
341- func.return %0 : tensor <3 x11 x7 xf32 >
342- }
343-
344- // -----
345-
346- func.func @ragged_dot_incompatible_contracting_dims (%lhs : tensor <11 x5 xf32 >, %rhs : tensor <3 x2 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <11 x7 xf32 > {
347- // @expected-error@+1 {{contracting dimension sizes must match}}
348- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
349- ragged_dot_dimension_numbers = #chlo.ragged_dot <
350- lhs_batching_dimensions = [],
351- rhs_batching_dimensions = [],
352- lhs_contracting_dimensions = [1 ],
353- rhs_contracting_dimensions = [1 ],
354- lhs_ragged_dimensions = [0 ],
355- rhs_group_dimensions = [0 ]
356- >,
357- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
358- } : (tensor <11 x5 xf32 >, tensor <3 x2 x7 xf32 >, tensor <3 xi64 >) -> tensor <11 x7 xf32 >
359- func.return %0 : tensor <11 x7 xf32 >
360- }
361-
362- // -----
363-
364- func.func @ragged_dot_group_sizes_incorrect_rank (%lhs : tensor <11 x5 xf32 >, %rhs : tensor <3 x5 x7 xf32 >, %group_sizes : tensor <3 x2 xi64 >) -> tensor <11 x7 xf32 > {
365- // @expected-error@+1 {{expected rank of group_sizes of ragged dot to be 1, got 2}}
366- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
367- ragged_dot_dimension_numbers = #chlo.ragged_dot <
368- lhs_batching_dimensions = [],
369- rhs_batching_dimensions = [],
370- lhs_contracting_dimensions = [1 ],
371- rhs_contracting_dimensions = [1 ],
372- lhs_ragged_dimensions = [0 ],
373- rhs_group_dimensions = [0 ]
374- >,
375- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
376- } : (tensor <11 x5 xf32 >, tensor <3 x5 x7 xf32 >, tensor <3 x2 xi64 >) -> tensor <11 x7 xf32 >
377- func.return %0 : tensor <11 x7 xf32 >
378- }
379-
380- // -----
381-
382- func.func @ragged_dot_group_sizes_incorrect_shape (%lhs : tensor <11 x5 xf32 >, %rhs : tensor <3 x5 x7 xf32 >, %group_sizes : tensor <2 xi64 >) -> tensor <11 x7 xf32 > {
383- // @expected-error@+1 {{group_sizes is expected to have shape=[3], got [2]}}
384- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
385- ragged_dot_dimension_numbers = #chlo.ragged_dot <
386- lhs_batching_dimensions = [],
387- rhs_batching_dimensions = [],
388- lhs_contracting_dimensions = [1 ],
389- rhs_contracting_dimensions = [1 ],
390- lhs_ragged_dimensions = [0 ],
391- rhs_group_dimensions = [0 ]
392- >,
393- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
394- } : (tensor <11 x5 xf32 >, tensor <3 x5 x7 xf32 >, tensor <2 xi64 >) -> tensor <11 x7 xf32 >
395- func.return %0 : tensor <11 x7 xf32 >
396- }
397-
398- // -----
399-
400- func.func @ragged_dot_incorrect_number_of_lhs_ragged_dimensions (%lhs : tensor <11 x5 xf32 >, %rhs : tensor <3 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <11 x7 xf32 > {
401- // @expected-error@+1 {{There must be exactly one ragged dimension in the lhs}}
402- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
403- ragged_dot_dimension_numbers = #chlo.ragged_dot <
404- lhs_batching_dimensions = [],
405- rhs_batching_dimensions = [],
406- lhs_contracting_dimensions = [1 ],
407- rhs_contracting_dimensions = [1 ],
408- lhs_ragged_dimensions = [0 , 1 ],
409- rhs_group_dimensions = [0 ]
410- >,
411- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
412- } : (tensor <11 x5 xf32 >, tensor <3 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <11 x7 xf32 >
413- func.return %0 : tensor <11 x7 xf32 >
414- }
415-
416- // -----
417-
418- func.func @ragged_dot_rhs_group_dim_is_batch (%lhs : tensor <3 x11 x5 xf32 >, %rhs : tensor <3 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <3 x11 x7 xf32 > {
419- // @expected-error@+1 {{has duplicated dimension from rhs_group_dimensions and rhs_batching_dimensions: 0}}
420- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
421- ragged_dot_dimension_numbers = #chlo.ragged_dot <
422- lhs_batching_dimensions = [0 ],
423- rhs_batching_dimensions = [0 ],
424- lhs_contracting_dimensions = [2 ],
425- rhs_contracting_dimensions = [1 ],
426- lhs_ragged_dimensions = [1 ],
427- rhs_group_dimensions = [0 ]
428- >,
429- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
430- } : (tensor <3 x11 x5 xf32 >, tensor <3 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <3 x11 x7 xf32 >
431- func.return %0 : tensor <3 x11 x7 xf32 >
432- }
433-
434- // -----
435-
436- func.func @ragged_dot_rhs_group_dim_is_contracting (%lhs : tensor <11 x3 xf32 >, %rhs : tensor <3 x3 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <11 x7 xf32 > {
437- // @expected-error@+1 {{has duplicated dimension from rhs_group_dimensions and rhs_contracting_dimensions: 1}}
438- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
439- ragged_dot_dimension_numbers = #chlo.ragged_dot <
440- lhs_batching_dimensions = [],
441- rhs_batching_dimensions = [],
442- lhs_contracting_dimensions = [1 ],
443- rhs_contracting_dimensions = [1 ],
444- lhs_ragged_dimensions = [0 ],
445- rhs_group_dimensions = [1 ]
446- >,
447- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
448- } : (tensor <11 x3 xf32 >, tensor <3 x3 x7 xf32 >, tensor <3 xi64 >) -> tensor <11 x7 xf32 >
449- func.return %0 : tensor <11 x7 xf32 >
450- }
451-
452- // -----
453-
454- func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_batch (%lhs : tensor <2 x11 x5 xf32 >, %rhs : tensor <3 x2 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <2 x11 x7 xf32 > {
455- // @expected-error@+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}}
456- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
457- ragged_dot_dimension_numbers = #chlo.ragged_dot <
458- lhs_batching_dimensions = [0 ],
459- rhs_batching_dimensions = [1 ],
460- lhs_contracting_dimensions = [2 ],
461- rhs_contracting_dimensions = [2 ],
462- lhs_ragged_dimensions = [0 ],
463- rhs_group_dimensions = [0 ]
464- >,
465- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
466- } : (tensor <2 x11 x5 xf32 >, tensor <3 x2 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <2 x11 x7 xf32 >
467- func.return %0 : tensor <2 x11 x7 xf32 >
468- }
469-
470- // -----
471-
472- func.func @ragged_dot_nonzero_rhs_group_dims_for_ragged_contracting (%lhs : tensor <11 x5 xf32 >, %rhs : tensor <3 x5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <11 x7 xf32 > {
473- // @expected-error@+1 {{There must be zero group dimensions in the rhs when the ragged dimension is batch or contracting}}
474- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
475- ragged_dot_dimension_numbers = #chlo.ragged_dot <
476- lhs_batching_dimensions = [],
477- rhs_batching_dimensions = [],
478- lhs_contracting_dimensions = [1 ],
479- rhs_contracting_dimensions = [1 ],
480- lhs_ragged_dimensions = [1 ],
481- rhs_group_dimensions = [0 ]
482- >,
483- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
484- } : (tensor <11 x5 xf32 >, tensor <3 x5 x7 xf32 >, tensor <3 xi64 >) -> tensor <11 x7 xf32 >
485- func.return %0 : tensor <11 x7 xf32 >
486- }
487-
488- // -----
489-
490- func.func @ragged_dot_zero_rhs_group_dims_for_ragged_noncontracting (%lhs : tensor <11 x5 xf32 >, %rhs : tensor <5 x7 xf32 >, %group_sizes : tensor <3 xi64 >) -> tensor <11 x7 xf32 > {
491- // @expected-error@+1 {{There must be exactly one group dimension in the rhs when the lhs ragged dimension is non-contracting}}
492- %0 = " chlo.ragged_dot" (%lhs , %rhs , %group_sizes ) {
493- ragged_dot_dimension_numbers = #chlo.ragged_dot <
494- lhs_batching_dimensions = [],
495- rhs_batching_dimensions = [],
496- lhs_contracting_dimensions = [1 ],
497- rhs_contracting_dimensions = [0 ],
498- lhs_ragged_dimensions = [0 ],
499- rhs_group_dimensions = []
500- >,
501- precision_config = [#chlo <precision DEFAULT >, #chlo <precision DEFAULT >]
502- } : (tensor <11 x5 xf32 >, tensor <5 x7 xf32 >, tensor <3 xi64 >) -> tensor <11 x7 xf32 >
503- func.return %0 : tensor <11 x7 xf32 >
504- }
505-
506- // -----
507-
508292func.func @top_k (%arg0 : tensor <f32 >) {
509293 // expected-error @+2 {{failed to infer returned types}}
510294 // @expected-error @+1{{operand's rank must be at least 1}}
0 commit comments