@@ -401,6 +401,24 @@ def test_none_shape_bool(self, xp: ModuleType):
401401        a  =  a [a ]
402402        xp_assert_equal (isclose (a , b ), xp .asarray ([True , False ]))
403403
404+     @pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "xp=xp" ) 
405+     @pytest .mark .skip_xp_backend (Backend .TORCH , reason = "Array API 2024.12 support" ) 
406+     def  test_python_scalar (self , xp : ModuleType ):
407+         a  =  xp .asarray ([0.0 , 0.1 ], dtype = xp .float32 )
408+         xp_assert_equal (isclose (a , 0.0 ), xp .asarray ([True , False ]))
409+         xp_assert_equal (isclose (0.0 , a ), xp .asarray ([True , False ]))
410+ 
411+         a  =  xp .asarray ([0 , 1 ], dtype = xp .int16 )
412+         xp_assert_equal (isclose (a , 0 ), xp .asarray ([True , False ]))
413+         xp_assert_equal (isclose (0 , a ), xp .asarray ([True , False ]))
414+ 
415+         xp_assert_equal (isclose (0 , 0 , xp = xp ), xp .asarray (True ))
416+         xp_assert_equal (isclose (0 , 1 , xp = xp ), xp .asarray (False ))
417+ 
418+     def  test_all_python_scalars (self ):
419+         with  pytest .raises (TypeError , match = "Unrecognized" ):
420+             isclose (0 , 0 )
421+ 
404422    def  test_xp (self , xp : ModuleType ):
405423        a  =  xp .asarray ([0.0 , 0.0 ])
406424        b  =  xp .asarray ([1e-9 , 1e-4 ])
@@ -413,30 +431,22 @@ def test_basic(self, xp: ModuleType):
413431        # Using 0-dimensional array 
414432        a  =  xp .asarray (1 )
415433        b  =  xp .asarray ([[1 , 2 ], [3 , 4 ]])
416-         k  =  xp .asarray ([[1 , 2 ], [3 , 4 ]])
417-         xp_assert_equal (kron (a , b ), k )
418-         a  =  xp .asarray ([[1 , 2 ], [3 , 4 ]])
419-         b  =  xp .asarray (1 )
420-         xp_assert_equal (kron (a , b ), k )
434+         xp_assert_equal (kron (a , b ), b )
435+         xp_assert_equal (kron (b , a ), b )
421436
422437        # Using 1-dimensional array 
423438        a  =  xp .asarray ([3 ])
424439        b  =  xp .asarray ([[1 , 2 ], [3 , 4 ]])
425440        k  =  xp .asarray ([[3 , 6 ], [9 , 12 ]])
426441        xp_assert_equal (kron (a , b ), k )
427-         a  =  xp .asarray ([[1 , 2 ], [3 , 4 ]])
428-         b  =  xp .asarray ([3 ])
429-         xp_assert_equal (kron (a , b ), k )
442+         xp_assert_equal (kron (b , a ), k )
430443
431444        # Using 3-dimensional array 
432445        a  =  xp .asarray ([[[1 ]], [[2 ]]])
433446        b  =  xp .asarray ([[1 , 2 ], [3 , 4 ]])
434447        k  =  xp .asarray ([[[1 , 2 ], [3 , 4 ]], [[2 , 4 ], [6 , 8 ]]])
435448        xp_assert_equal (kron (a , b ), k )
436-         a  =  xp .asarray ([[1 , 2 ], [3 , 4 ]])
437-         b  =  xp .asarray ([[[1 ]], [[2 ]]])
438-         k  =  xp .asarray ([[[1 , 2 ], [3 , 4 ]], [[2 , 4 ], [6 , 8 ]]])
439-         xp_assert_equal (kron (a , b ), k )
449+         xp_assert_equal (kron (b , a ), k )
440450
441451    def  test_kron_smoke (self , xp : ModuleType ):
442452        a  =  xp .ones ((3 , 3 ))
@@ -474,6 +484,18 @@ def test_kron_shape(
474484        k  =  kron (a , b )
475485        assert  k .shape  ==  expected_shape 
476486
487+     def  test_python_scalar (self , xp : ModuleType ):
488+         a  =  1 
489+         # Test no dtype promotion to xp.asarray(a); use b.dtype 
490+         b  =  xp .asarray ([[1 , 2 ], [3 , 4 ]], dtype = xp .int16 )
491+         xp_assert_equal (kron (a , b ), b )
492+         xp_assert_equal (kron (b , a ), b )
493+         xp_assert_equal (kron (1 , 1 , xp = xp ), xp .asarray (1 ))
494+ 
495+     def  test_all_python_scalars (self ):
496+         with  pytest .raises (TypeError , match = "Unrecognized" ):
497+             kron (1 , 1 )
498+ 
477499    def  test_device (self , xp : ModuleType , device : Device ):
478500        x1  =  xp .asarray ([1 , 2 , 3 ], device = device )
479501        x2  =  xp .asarray ([4 , 5 ], device = device )
@@ -601,6 +623,28 @@ def test_shapes(
601623        actual  =  setdiff1d (x1 , x2 , assume_unique = assume_unique )
602624        xp_assert_equal (actual , xp .empty ((0 ,)))
603625
626+     @pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "xp=xp" ) 
627+     @pytest .mark .parametrize ("assume_unique" , [True , False ]) 
628+     def  test_python_scalar (self , xp : ModuleType , assume_unique : bool ):
629+         # Test no dtype promotion to xp.asarray(x2); use x1.dtype 
630+         x1  =  xp .asarray ([3 , 1 , 2 ], dtype = xp .int16 )
631+         x2  =  3 
632+         actual  =  setdiff1d (x1 , x2 , assume_unique = assume_unique )
633+         xp_assert_equal (actual , xp .asarray ([1 , 2 ], dtype = xp .int16 ))
634+ 
635+         actual  =  setdiff1d (x2 , x1 , assume_unique = assume_unique )
636+         xp_assert_equal (actual , xp .asarray ([], dtype = xp .int16 ))
637+ 
638+         xp_assert_equal (
639+             setdiff1d (0 , 0 , assume_unique = assume_unique , xp = xp ),
640+             xp .asarray ([0 ])[:0 ],  # Default int dtype for backend 
641+         )
642+ 
643+     @pytest .mark .parametrize ("assume_unique" , [True , False ]) 
644+     def  test_all_python_scalars (self , assume_unique : bool ):
645+         with  pytest .raises (TypeError , match = "Unrecognized" ):
646+             setdiff1d (0 , 0 , assume_unique = assume_unique )
647+ 
604648    def  test_device (self , xp : ModuleType , device : Device ):
605649        x1  =  xp .asarray ([3 , 8 , 20 ], device = device )
606650        x2  =  xp .asarray ([2 , 3 , 4 ], device = device )
0 commit comments