Skip to content

Commit

Permalink
[wasm] Add PackedSimd bit shift intrinsics (#83896)
Browse files Browse the repository at this point in the history
* Add bit shifts to PackedSimd class

* [wasm] Add PackedSimd bit shift intrinsics

Example of code generated for this method:

    void Snippet(Vector128<int> vector)
    {
        var v2 = PackedSimd.ShiftLeft (vector, 3);
        var v3 = PackedSimd.ShiftRightArithmetic (vector, 3);
        var v4 = PackedSimd.ShiftRightLogical (vector, 3);

        Print (vector, v2, v3, v4);
    }

Relevant parts:
    (func Wasm_Browser_Bench_Sample_Sample_VectorTask_Add_Snippet_System_Runtime_Intrinsics_Vector128_1_int(param $0 i32, $1 i32, $2 i32))
    ...
     local.tee $3
     i32.const 3
     i32x4.shr.u    [SIMD]
     v128.store offset:96 align:4    [SIMD]
     local.get $0
     local.get $3
     i32.const 3
     i32x4.shr.s    [SIMD]
     v128.store offset:80 align:4    [SIMD]
     local.get $0
     local.get $3
     i32.const 3
     i32x4.shl    [SIMD]
     v128.store offset:64 align:4    [SIMD]
    ...

* Fix build
  • Loading branch information
radekdoulik committed Mar 27, 2023
1 parent 8d5f520 commit beb708f
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,41 @@ public abstract class PackedSimd
public static Vector128<nint> Negate(Vector128<nint> value) { throw new PlatformNotSupportedException(); }
public static Vector128<nuint> Negate(Vector128<nuint> value) { throw new PlatformNotSupportedException(); }

// Bit shifts

// ShiftLeft: one overload per Vector128 element type. This variant of the class has no
// implementation, so every overload throws PlatformNotSupportedException.
public static Vector128<sbyte> ShiftLeft(Vector128<sbyte> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<byte> ShiftLeft(Vector128<byte> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<short> ShiftLeft(Vector128<short> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<ushort> ShiftLeft(Vector128<ushort> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<int> ShiftLeft(Vector128<int> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<uint> ShiftLeft(Vector128<uint> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<long> ShiftLeft(Vector128<long> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<ulong> ShiftLeft(Vector128<ulong> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<nint> ShiftLeft(Vector128<nint> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<nuint> ShiftLeft(Vector128<nuint> value, int count) => throw new PlatformNotSupportedException();

// ShiftRightArithmetic: one overload per Vector128 element type. This variant of the class
// has no implementation, so every overload throws PlatformNotSupportedException.
public static Vector128<sbyte> ShiftRightArithmetic(Vector128<sbyte> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<byte> ShiftRightArithmetic(Vector128<byte> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<short> ShiftRightArithmetic(Vector128<short> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<ushort> ShiftRightArithmetic(Vector128<ushort> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<int> ShiftRightArithmetic(Vector128<int> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<uint> ShiftRightArithmetic(Vector128<uint> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<long> ShiftRightArithmetic(Vector128<long> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<ulong> ShiftRightArithmetic(Vector128<ulong> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<nint> ShiftRightArithmetic(Vector128<nint> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<nuint> ShiftRightArithmetic(Vector128<nuint> value, int count) => throw new PlatformNotSupportedException();

// ShiftRightLogical: one overload per Vector128 element type. This variant of the class
// has no implementation, so every overload throws PlatformNotSupportedException.
public static Vector128<sbyte> ShiftRightLogical(Vector128<sbyte> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<byte> ShiftRightLogical(Vector128<byte> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<short> ShiftRightLogical(Vector128<short> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<ushort> ShiftRightLogical(Vector128<ushort> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<int> ShiftRightLogical(Vector128<int> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<uint> ShiftRightLogical(Vector128<uint> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<long> ShiftRightLogical(Vector128<long> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<ulong> ShiftRightLogical(Vector128<ulong> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<nint> ShiftRightLogical(Vector128<nint> value, int count) => throw new PlatformNotSupportedException();
public static Vector128<nuint> ShiftRightLogical(Vector128<nuint> value, int count) => throw new PlatformNotSupportedException();

public static Vector128<sbyte> And(Vector128<sbyte> left, Vector128<sbyte> right) { throw new PlatformNotSupportedException(); }
public static Vector128<byte> And(Vector128<byte> left, Vector128<byte> right) { throw new PlatformNotSupportedException(); }
public static Vector128<short> And(Vector128<short> left, Vector128<short> right) { throw new PlatformNotSupportedException(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,161 @@ public abstract class PackedSimd
[Intrinsic]
public static Vector128<nuint> Negate(Vector128<nuint> value) => Negate(value);

// Bit shifts

// ShiftLeft: one overload per Vector128 element type. Each summary names the WebAssembly
// SIMD instruction associated with that overload; the lane shape in the instruction name
// (i8x16 / i16x8 / i32x4 / i64x2) follows the element size, and the signed and unsigned
// element types of a given size are documented with the same shl instruction.
// Every body just calls itself under [Intrinsic], so the C# body is a placeholder rather
// than the real implementation (presumably the runtime substitutes the instruction — confirm).
// NOTE(review): nint/nuint are documented as i32x4, which assumes a 32-bit native integer — confirm.
/// <summary>
/// i8x16.shl
/// </summary>
[Intrinsic]
public static Vector128<sbyte> ShiftLeft(Vector128<sbyte> value, int count) => ShiftLeft(value, count);
/// <summary>
/// i8x16.shl
/// </summary>
[Intrinsic]
public static Vector128<byte> ShiftLeft(Vector128<byte> value, int count) => ShiftLeft(value, count);
/// <summary>
/// i16x8.shl
/// </summary>
[Intrinsic]
public static Vector128<short> ShiftLeft(Vector128<short> value, int count) => ShiftLeft(value, count);
/// <summary>
/// i16x8.shl
/// </summary>
[Intrinsic]
public static Vector128<ushort> ShiftLeft(Vector128<ushort> value, int count) => ShiftLeft(value, count);
/// <summary>
/// i32x4.shl
/// </summary>
[Intrinsic]
public static Vector128<int> ShiftLeft(Vector128<int> value, int count) => ShiftLeft(value, count);
/// <summary>
/// i32x4.shl
/// </summary>
[Intrinsic]
public static Vector128<uint> ShiftLeft(Vector128<uint> value, int count) => ShiftLeft(value, count);
/// <summary>
/// i64x2.shl
/// </summary>
[Intrinsic]
public static Vector128<long> ShiftLeft(Vector128<long> value, int count) => ShiftLeft(value, count);
/// <summary>
/// i64x2.shl
/// </summary>
[Intrinsic]
public static Vector128<ulong> ShiftLeft(Vector128<ulong> value, int count) => ShiftLeft(value, count);
/// <summary>
/// i32x4.shl
/// </summary>
[Intrinsic]
public static Vector128<nint> ShiftLeft(Vector128<nint> value, int count) => ShiftLeft(value, count);
/// <summary>
/// i32x4.shl
/// </summary>
[Intrinsic]
public static Vector128<nuint> ShiftLeft(Vector128<nuint> value, int count) => ShiftLeft(value, count);

// ShiftRightArithmetic: one overload per Vector128 element type. Each summary names the
// WebAssembly SIMD instruction associated with that overload; the lane shape
// (i8x16 / i16x8 / i32x4 / i64x2) follows the element size.
// Every overload, including the unsigned element types, is documented with the signed
// shr_s instruction.
// NOTE(review): presumably unsigned lanes are therefore shifted as if they were signed — confirm this is intended.
// Every body just calls itself under [Intrinsic], so the C# body is a placeholder rather
// than the real implementation (presumably the runtime substitutes the instruction — confirm).
// NOTE(review): nint/nuint are documented as i32x4, which assumes a 32-bit native integer — confirm.
/// <summary>
/// i8x16.shr_s
/// </summary>
[Intrinsic]
public static Vector128<sbyte> ShiftRightArithmetic(Vector128<sbyte> value, int count) => ShiftRightArithmetic(value, count);
/// <summary>
/// i8x16.shr_s
/// </summary>
[Intrinsic]
public static Vector128<byte> ShiftRightArithmetic(Vector128<byte> value, int count) => ShiftRightArithmetic(value, count);
/// <summary>
/// i16x8.shr_s
/// </summary>
[Intrinsic]
public static Vector128<short> ShiftRightArithmetic(Vector128<short> value, int count) => ShiftRightArithmetic(value, count);
/// <summary>
/// i16x8.shr_s
/// </summary>
[Intrinsic]
public static Vector128<ushort> ShiftRightArithmetic(Vector128<ushort> value, int count) => ShiftRightArithmetic(value, count);
/// <summary>
/// i32x4.shr_s
/// </summary>
[Intrinsic]
public static Vector128<int> ShiftRightArithmetic(Vector128<int> value, int count) => ShiftRightArithmetic(value, count);
/// <summary>
/// i32x4.shr_s
/// </summary>
[Intrinsic]
public static Vector128<uint> ShiftRightArithmetic(Vector128<uint> value, int count) => ShiftRightArithmetic(value, count);
/// <summary>
/// i64x2.shr_s
/// </summary>
[Intrinsic]
public static Vector128<long> ShiftRightArithmetic(Vector128<long> value, int count) => ShiftRightArithmetic(value, count);
/// <summary>
/// i64x2.shr_s
/// </summary>
[Intrinsic]
public static Vector128<ulong> ShiftRightArithmetic(Vector128<ulong> value, int count) => ShiftRightArithmetic(value, count);
/// <summary>
/// i32x4.shr_s
/// </summary>
[Intrinsic]
public static Vector128<nint> ShiftRightArithmetic(Vector128<nint> value, int count) => ShiftRightArithmetic(value, count);
/// <summary>
/// i32x4.shr_s
/// </summary>
[Intrinsic]
public static Vector128<nuint> ShiftRightArithmetic(Vector128<nuint> value, int count) => ShiftRightArithmetic(value, count);

// ShiftRightLogical: one overload per Vector128 element type. Each summary names the
// WebAssembly SIMD instruction associated with that overload; the lane shape
// (i8x16 / i16x8 / i32x4 / i64x2) follows the element size.
// Every overload, including the signed element types, is documented with the unsigned
// shr_u instruction.
// NOTE(review): presumably signed lanes are therefore shifted as if they were unsigned — confirm this is intended.
// Every body just calls itself under [Intrinsic], so the C# body is a placeholder rather
// than the real implementation (presumably the runtime substitutes the instruction — confirm).
// NOTE(review): nint/nuint are documented as i32x4, which assumes a 32-bit native integer — confirm.
/// <summary>
/// i8x16.shr_u
/// </summary>
[Intrinsic]
public static Vector128<sbyte> ShiftRightLogical(Vector128<sbyte> value, int count) => ShiftRightLogical(value, count);
/// <summary>
/// i8x16.shr_u
/// </summary>
[Intrinsic]
public static Vector128<byte> ShiftRightLogical(Vector128<byte> value, int count) => ShiftRightLogical(value, count);
/// <summary>
/// i16x8.shr_u
/// </summary>
[Intrinsic]
public static Vector128<short> ShiftRightLogical(Vector128<short> value, int count) => ShiftRightLogical(value, count);
/// <summary>
/// i16x8.shr_u
/// </summary>
[Intrinsic]
public static Vector128<ushort> ShiftRightLogical(Vector128<ushort> value, int count) => ShiftRightLogical(value, count);
/// <summary>
/// i32x4.shr_u
/// </summary>
[Intrinsic]
public static Vector128<int> ShiftRightLogical(Vector128<int> value, int count) => ShiftRightLogical(value, count);
/// <summary>
/// i32x4.shr_u
/// </summary>
[Intrinsic]
public static Vector128<uint> ShiftRightLogical(Vector128<uint> value, int count) => ShiftRightLogical(value, count);
/// <summary>
/// i64x2.shr_u
/// </summary>
[Intrinsic]
public static Vector128<long> ShiftRightLogical(Vector128<long> value, int count) => ShiftRightLogical(value, count);
/// <summary>
/// i64x2.shr_u
/// </summary>
[Intrinsic]
public static Vector128<ulong> ShiftRightLogical(Vector128<ulong> value, int count) => ShiftRightLogical(value, count);
/// <summary>
/// i32x4.shr_u
/// </summary>
[Intrinsic]
public static Vector128<nint> ShiftRightLogical(Vector128<nint> value, int count) => ShiftRightLogical(value, count);
/// <summary>
/// i32x4.shr_u
/// </summary>
[Intrinsic]
public static Vector128<nuint> ShiftRightLogical(Vector128<nuint> value, int count) => ShiftRightLogical(value, count);

// Bitwise operations

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5490,6 +5490,36 @@ public abstract partial class PackedSimd
public static Vector128<ulong> Negate(Vector128<ulong> value) { throw null; }
public static Vector128<nint> Negate(Vector128<nint> value) { throw null; }
public static Vector128<nuint> Negate(Vector128<nuint> value) { throw null; }
// ShiftLeft signatures, one per Vector128 element type; the bodies are placeholders that only throw.
public static Vector128<sbyte> ShiftLeft(Vector128<sbyte> value, int count) => throw null;
public static Vector128<byte> ShiftLeft(Vector128<byte> value, int count) => throw null;
public static Vector128<short> ShiftLeft(Vector128<short> value, int count) => throw null;
public static Vector128<ushort> ShiftLeft(Vector128<ushort> value, int count) => throw null;
public static Vector128<int> ShiftLeft(Vector128<int> value, int count) => throw null;
public static Vector128<uint> ShiftLeft(Vector128<uint> value, int count) => throw null;
public static Vector128<long> ShiftLeft(Vector128<long> value, int count) => throw null;
public static Vector128<ulong> ShiftLeft(Vector128<ulong> value, int count) => throw null;
public static Vector128<nint> ShiftLeft(Vector128<nint> value, int count) => throw null;
public static Vector128<nuint> ShiftLeft(Vector128<nuint> value, int count) => throw null;
// ShiftRightArithmetic signatures, one per Vector128 element type; the bodies are placeholders that only throw.
public static Vector128<sbyte> ShiftRightArithmetic(Vector128<sbyte> value, int count) => throw null;
public static Vector128<byte> ShiftRightArithmetic(Vector128<byte> value, int count) => throw null;
public static Vector128<short> ShiftRightArithmetic(Vector128<short> value, int count) => throw null;
public static Vector128<ushort> ShiftRightArithmetic(Vector128<ushort> value, int count) => throw null;
public static Vector128<int> ShiftRightArithmetic(Vector128<int> value, int count) => throw null;
public static Vector128<uint> ShiftRightArithmetic(Vector128<uint> value, int count) => throw null;
public static Vector128<long> ShiftRightArithmetic(Vector128<long> value, int count) => throw null;
public static Vector128<ulong> ShiftRightArithmetic(Vector128<ulong> value, int count) => throw null;
public static Vector128<nint> ShiftRightArithmetic(Vector128<nint> value, int count) => throw null;
public static Vector128<nuint> ShiftRightArithmetic(Vector128<nuint> value, int count) => throw null;
// ShiftRightLogical signatures, one per Vector128 element type; the bodies are placeholders that only throw.
public static Vector128<sbyte> ShiftRightLogical(Vector128<sbyte> value, int count) => throw null;
public static Vector128<byte> ShiftRightLogical(Vector128<byte> value, int count) => throw null;
public static Vector128<short> ShiftRightLogical(Vector128<short> value, int count) => throw null;
public static Vector128<ushort> ShiftRightLogical(Vector128<ushort> value, int count) => throw null;
public static Vector128<int> ShiftRightLogical(Vector128<int> value, int count) => throw null;
public static Vector128<uint> ShiftRightLogical(Vector128<uint> value, int count) => throw null;
public static Vector128<long> ShiftRightLogical(Vector128<long> value, int count) => throw null;
public static Vector128<ulong> ShiftRightLogical(Vector128<ulong> value, int count) => throw null;
public static Vector128<nint> ShiftRightLogical(Vector128<nint> value, int count) => throw null;
public static Vector128<nuint> ShiftRightLogical(Vector128<nuint> value, int count) => throw null;
public static Vector128<sbyte> And(Vector128<sbyte> left, Vector128<sbyte> right) { throw null; }
public static Vector128<byte> And(Vector128<byte> left, Vector128<byte> right) { throw null; }
public static Vector128<short> And(Vector128<short> left, Vector128<short> right) { throw null; }
Expand Down
66 changes: 33 additions & 33 deletions src/mono/mono/mini/mini-llvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -7655,6 +7655,39 @@ MONO_RESTORE_WARNING
values [ins->dreg] = result;
break;
}
case OP_SHL:
case OP_SSHR:
case OP_SSRA:
case OP_USHR:
case OP_USRA: {
	/*
	 * Vector shifts, optionally accumulating into lhs:
	 *   OP_SHL          lhs << rhs
	 *   OP_USHR / SSHR  lhs >> rhs          (logical / arithmetic)
	 *   OP_USRA / SSRA  lhs + (rhs >> arg3) (logical / arithmetic)
	 */
	const gboolean is_logical_right = ins->opcode == OP_USHR || ins->opcode == OP_USRA;
	const gboolean is_arith_right = ins->opcode == OP_SSHR || ins->opcode == OP_SSRA;
	const gboolean accumulates = ins->opcode == OP_USRA || ins->opcode == OP_SSRA;

	/* The accumulating forms keep the accumulator in lhs, so the shifted value and amount move up one slot. */
	LLVMValueRef operand = accumulates ? rhs : lhs;
	/* create_shift_vector yields the amount operand for the LLVM shift (presumably a splat matching operand's type — confirm). */
	LLVMValueRef amount = create_shift_vector (ctx, operand, accumulates ? arg3 : rhs);

	LLVMValueRef shifted;
	if (is_logical_right)
		shifted = LLVMBuildLShr (builder, operand, amount, "");
	else if (is_arith_right)
		shifted = LLVMBuildAShr (builder, operand, amount, "");
	else
		shifted = LLVMBuildShl (builder, operand, amount, "");

	values [ins->dreg] = accumulates ? LLVMBuildAdd (builder, lhs, shifted, "arm64_usra") : shifted;
	break;
}
case OP_SSHLL:
case OP_SSHLL2:
case OP_USHLL:
Expand Down Expand Up @@ -10693,39 +10726,6 @@ MONO_RESTORE_WARNING
values [ins->dreg] = result;
break;
}
case OP_ARM64_SHL:
case OP_ARM64_SSHR:
case OP_ARM64_SSRA:
case OP_ARM64_USHR:
case OP_ARM64_USRA: {
	/*
	 * Vector shifts, optionally accumulating into lhs:
	 *   OP_ARM64_SHL          lhs << rhs
	 *   OP_ARM64_USHR / SSHR  lhs >> rhs          (logical / arithmetic)
	 *   OP_ARM64_USRA / SSRA  lhs + (rhs >> arg3) (logical / arithmetic)
	 */
	const gboolean is_logical_right = ins->opcode == OP_ARM64_USHR || ins->opcode == OP_ARM64_USRA;
	const gboolean is_arith_right = ins->opcode == OP_ARM64_SSHR || ins->opcode == OP_ARM64_SSRA;
	const gboolean accumulates = ins->opcode == OP_ARM64_USRA || ins->opcode == OP_ARM64_SSRA;

	/* The accumulating forms keep the accumulator in lhs, so the shifted value and amount move up one slot. */
	LLVMValueRef operand = accumulates ? rhs : lhs;
	/* create_shift_vector yields the amount operand for the LLVM shift (presumably a splat matching operand's type — confirm). */
	LLVMValueRef amount = create_shift_vector (ctx, operand, accumulates ? arg3 : rhs);

	LLVMValueRef shifted;
	if (is_logical_right)
		shifted = LLVMBuildLShr (builder, operand, amount, "");
	else if (is_arith_right)
		shifted = LLVMBuildAShr (builder, operand, amount, "");
	else
		shifted = LLVMBuildShl (builder, operand, amount, "");

	values [ins->dreg] = accumulates ? LLVMBuildAdd (builder, lhs, shifted, "arm64_usra") : shifted;
	break;
}
case OP_ARM64_SHRN:
case OP_ARM64_SHRN2: {
LLVMValueRef shiftarg = lhs;
Expand Down
11 changes: 5 additions & 6 deletions src/mono/mono/mini/mini-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1616,12 +1616,6 @@ MINI_OP(OP_ARM64_URSHR, "arm64_urshr", XREG, XREG, IREG)
MINI_OP3(OP_ARM64_SRSRA, "arm64_srsra", XREG, XREG, XREG, IREG)
MINI_OP3(OP_ARM64_URSRA, "arm64_ursra", XREG, XREG, XREG, IREG)

/*
 * arm64 vector shifts whose last register kind is IREG.
 * The MINI_OP3 entries (usra/ssra) list one more XREG than the plain shifts.
 * NOTE(review): presumably the register kinds are (dest, src1, src2[, src3]) — confirm against the MINI_OP/MINI_OP3 definitions.
 */
MINI_OP(OP_ARM64_SHL, "arm64_shl", XREG, XREG, IREG)
MINI_OP(OP_ARM64_SSHR, "arm64_sshr", XREG, XREG, IREG)
MINI_OP(OP_ARM64_USHR, "arm64_ushr", XREG, XREG, IREG)
MINI_OP3(OP_ARM64_USRA, "arm64_usra", XREG, XREG, XREG, IREG)
MINI_OP3(OP_ARM64_SSRA, "arm64_ssra", XREG, XREG, XREG, IREG)

/* Narrowing arm64 shifts that aren't decomposed into urshl or srshl. */
MINI_OP(OP_ARM64_XNSHIFT_SCALAR, "arm64_xrshift_scalar", XREG, XREG, IREG)
MINI_OP(OP_ARM64_XNSHIFT, "arm64_xnshift", XREG, XREG, IREG)
Expand Down Expand Up @@ -1769,6 +1763,11 @@ MINI_OP(OP_USHLL, "unsigned_shift_left_long", XREG, XREG, IREG)
MINI_OP(OP_USHLL2, "unsigned_shift_left_long_2", XREG, XREG, IREG)
MINI_OP(OP_SSHLL, "signed_shift_left_long", XREG, XREG, IREG)
MINI_OP(OP_SSHLL2, "signed_shift_left_long_2", XREG, XREG, IREG)
/*
 * Vector shifts whose last register kind is IREG; declared without an architecture prefix.
 * The MINI_OP3 entries (usra/ssra) list one more XREG than the plain shifts.
 * NOTE(review): presumably the register kinds are (dest, src1, src2[, src3]) — confirm against the MINI_OP/MINI_OP3 definitions.
 */
MINI_OP(OP_SHL, "shl", XREG, XREG, IREG)
MINI_OP(OP_SSHR, "sshr", XREG, XREG, IREG)
MINI_OP(OP_USHR, "ushr", XREG, XREG, IREG)
MINI_OP3(OP_USRA, "usra", XREG, XREG, XREG, IREG)
MINI_OP3(OP_SSRA, "ssra", XREG, XREG, XREG, IREG)

#if defined(TARGET_WASM)
MINI_OP(OP_WASM_ONESCOMPLEMENT, "wasm_onescomplement", XREG, XREG, NONE)
Expand Down

0 comments on commit beb708f

Please sign in to comment.