Skip to content

Commit

Permalink
roundkey扩展的常量时间运行,避免条件判断 #3
Browse files Browse the repository at this point in the history
  • Loading branch information
emmansun committed Jun 12, 2023
1 parent 3e163bd commit 5343658
Show file tree
Hide file tree
Showing 3 changed files with 300 additions and 279 deletions.
180 changes: 95 additions & 85 deletions _asm/bs_amd64_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -1577,78 +1577,83 @@ func xorRoundKey128() {
x3 := Mem{Base: Load(Param("x3"), GP64())}
out := Mem{Base: Load(Param("out"), GP64())}

tmp1, tmp2, tmp3, tmp := XMM(), XMM(), XMM(), XMM()
ret := XMM()
one := XMM()
PCMPEQB(one, one)

y := GP32()

count := GP64()
XORQ(count, count)
MOVQ(x, tmp2)

Comment("Handle first byte")
MOVL(U32(0x01000000), y)
MOVQ(y, tmp1)
VMOVDQU(tmp1, tmp3)
Label("rk_loop_1")
MOVOU(x1.Idx(count, 1), ret)
PXOR(x2.Idx(count, 1), ret)
PXOR(x3.Idx(count, 1), ret)
TESTL(x, y)
JZ(LabelRef("rk_loop_1_c"))
PXOR(one, ret)
Label("rk_loop_1_c")
MOVOU(ret, out.Idx(count, 1))
ROLL(U8(1), y)
VMOVDQU(x1.Idx(count, 1), ret)
VPXOR(x2.Idx(count, 1), ret, ret)
VPXOR(x3.Idx(count, 1), ret, ret)
VPAND(tmp1, tmp2, tmp)
VPCMPEQD(tmp1, tmp, tmp)
VPBROADCASTD(tmp, one)
VPXOR(one, ret, ret)
VMOVDQU(ret, out.Idx(count, 1))
VPSLLD(Imm(1), tmp1, tmp1)
ADDQ(U8(16), count)
CMPQ(count, U32(128))
JL(LabelRef("rk_loop_1"))

Comment("Handle second byte")
MOVL(U32(0x00010000), y)
VPSRLD(Imm(8), tmp3, tmp1)
Label("rk_loop_2")
MOVOU(x1.Idx(count, 1), ret)
PXOR(x2.Idx(count, 1), ret)
PXOR(x3.Idx(count, 1), ret)
TESTL(x, y)
JZ(LabelRef("rk_loop_2_c"))
PXOR(one, ret)
Label("rk_loop_2_c")
MOVOU(ret, out.Idx(count, 1))
ROLL(U8(1), y)
VMOVDQU(x1.Idx(count, 1), ret)
VPXOR(x2.Idx(count, 1), ret, ret)
VPXOR(x3.Idx(count, 1), ret, ret)
VPAND(tmp1, tmp2, tmp)
VPCMPEQD(tmp1, tmp, tmp)
VPBROADCASTD(tmp, one)
VPXOR(one, ret, ret)
VMOVDQU(ret, out.Idx(count, 1))
VPSLLD(Imm(1), tmp1, tmp1)
ADDQ(U8(16), count)
CMPQ(count, U32(256))
JL(LabelRef("rk_loop_2"))

Comment("Handle third byte")
MOVL(U32(0x00000100), y)
VPSRLD(Imm(16), tmp3, tmp1)
Label("rk_loop_3")
MOVOU(x1.Idx(count, 1), ret)
PXOR(x2.Idx(count, 1), ret)
PXOR(x3.Idx(count, 1), ret)
TESTL(x, y)
JZ(LabelRef("rk_loop_3_c"))
PXOR(one, ret)
Label("rk_loop_3_c")
MOVOU(ret, out.Idx(count, 1))
ROLL(U8(1), y)
VMOVDQU(x1.Idx(count, 1), ret)
VPXOR(x2.Idx(count, 1), ret, ret)
VPXOR(x3.Idx(count, 1), ret, ret)
VPAND(tmp1, tmp2, tmp)
VPCMPEQD(tmp1, tmp, tmp)
VPBROADCASTD(tmp, one)
VPXOR(one, ret, ret)
VMOVDQU(ret, out.Idx(count, 1))
VPSLLD(Imm(1), tmp1, tmp1)
ADDQ(U8(16), count)
CMPQ(count, U32(384))
JL(LabelRef("rk_loop_3"))

Comment("Handle last byte")
MOVL(U32(0x00000001), y)
VPSRLD(Imm(24), tmp3, tmp1)
Label("rk_loop_4")
MOVOU(x1.Idx(count, 1), ret)
PXOR(x2.Idx(count, 1), ret)
PXOR(x3.Idx(count, 1), ret)
TESTL(x, y)
JZ(LabelRef("rk_loop_4_c"))
PXOR(one, ret)
Label("rk_loop_4_c")
MOVOU(ret, out.Idx(count, 1))
ROLL(U8(1), y)
VMOVDQU(x1.Idx(count, 1), ret)
VPXOR(x2.Idx(count, 1), ret, ret)
VPXOR(x3.Idx(count, 1), ret, ret)
VPAND(tmp1, tmp2, tmp)
VPCMPEQD(tmp1, tmp, tmp)
VPBROADCASTD(tmp, one)
VPXOR(one, ret, ret)
VMOVDQU(ret, out.Idx(count, 1))
VPSLLD(Imm(1), tmp1, tmp1)
ADDQ(U8(16), count)
CMPQ(count, U32(512))
JL(LabelRef("rk_loop_4"))

VZEROUPPER()
RET()
}

Expand Down Expand Up @@ -3211,78 +3216,83 @@ func xorRoundKey256avx2() {
x3 := Mem{Base: Load(Param("x3"), GP64())}
out := Mem{Base: Load(Param("out"), GP64())}

tmp1, tmp2, tmp3, tmp := XMM(), XMM(), XMM(), XMM()
ret := YMM()
one := YMM()
VPCMPEQB(one, one, one)

y := GP32()

count := GP64()
XORQ(count, count)
MOVQ(x, tmp2)

Comment("Handle first byte")
MOVL(U32(0x01000000), y)
MOVQ(y, tmp1)
VMOVDQU(tmp1, tmp3)
Label("rk_loop_1")
VMOVDQU(x1.Idx(count, 1), ret)
VPXOR(x2.Idx(count, 1), ret, ret)
VPXOR(x3.Idx(count, 1), ret, ret)
TESTL(x, y)
JZ(LabelRef("rk_loop_1_c"))
VPAND(tmp1, tmp2, tmp)
VPCMPEQD(tmp1, tmp, tmp)
VPBROADCASTD(tmp, one)
VPXOR(one, ret, ret)
Label("rk_loop_1_c")
VMOVDQU(ret, out.Idx(count, 1))
ROLL(U8(1), y)
VPSLLD(Imm(1), tmp1, tmp1)
ADDQ(U8(32), count)
CMPQ(count, U32(256))
JL(LabelRef("rk_loop_1"))

Comment("Handle second byte")
MOVL(U32(0x00010000), y)
VPSRLD(Imm(8), tmp3, tmp1)
Label("rk_loop_2")
VMOVDQU(x1.Idx(count, 1), ret)
VPXOR(x2.Idx(count, 1), ret, ret)
VPXOR(x3.Idx(count, 1), ret, ret)
TESTL(x, y)
JZ(LabelRef("rk_loop_2_c"))
VPAND(tmp1, tmp2, tmp)
VPCMPEQD(tmp1, tmp, tmp)
VPBROADCASTD(tmp, one)
VPXOR(one, ret, ret)
Label("rk_loop_2_c")
VMOVDQU(ret, out.Idx(count, 1))
ROLL(U8(1), y)
VPSLLD(Imm(1), tmp1, tmp1)
ADDQ(U8(32), count)
CMPQ(count, U32(512))
JL(LabelRef("rk_loop_2"))

Comment("Handle third byte")
MOVL(U32(0x00000100), y)
VPSRLD(Imm(16), tmp3, tmp1)
Label("rk_loop_3")
VMOVDQU(x1.Idx(count, 1), ret)
VPXOR(x2.Idx(count, 1), ret, ret)
VPXOR(x3.Idx(count, 1), ret, ret)
TESTL(x, y)
JZ(LabelRef("rk_loop_3_c"))
VPAND(tmp1, tmp2, tmp)
VPCMPEQD(tmp1, tmp, tmp)
VPBROADCASTD(tmp, one)
VPXOR(one, ret, ret)
Label("rk_loop_3_c")
VMOVDQU(ret, out.Idx(count, 1))
ROLL(U8(1), y)
VPSLLD(Imm(1), tmp1, tmp1)
ADDQ(U8(32), count)
CMPQ(count, U32(768))
JL(LabelRef("rk_loop_3"))

Comment("Handle last byte")
MOVL(U32(0x00000001), y)
VPSRLD(Imm(24), tmp3, tmp1)
Label("rk_loop_4")
VMOVDQU(x1.Idx(count, 1), ret)
VPXOR(x2.Idx(count, 1), ret, ret)
VPXOR(x3.Idx(count, 1), ret, ret)
TESTL(x, y)
JZ(LabelRef("rk_loop_4_c"))
VPAND(tmp1, tmp2, tmp)
VPCMPEQD(tmp1, tmp, tmp)
VPBROADCASTD(tmp, one)
VPXOR(one, ret, ret)
Label("rk_loop_4_c")
VMOVDQU(ret, out.Idx(count, 1))
ROLL(U8(1), y)
VPSLLD(Imm(1), tmp1, tmp1)
ADDQ(U8(32), count)
CMPQ(count, U32(1024))
JL(LabelRef("rk_loop_4"))

VZEROUPPER()
RET()
}

Expand Down Expand Up @@ -4105,75 +4115,75 @@ func xorRoundKey64() {
x3 := Mem{Base: Load(Param("x3"), GP64())}
out := Mem{Base: Load(Param("out"), GP64())}

ret := GP64()
ret, nret := GP64(), GP64()

y := GP32()

count := GP64()
XORQ(count, count)
Comment("Handle first byte")
MOVL(U32(0x01000000), y)
Label("rk_loop_1")
Label("rk_b1")
MOVQ(x1.Idx(count, 1), ret)
XORQ(x2.Idx(count, 1), ret)
XORQ(x3.Idx(count, 1), ret)
MOVQ(ret, nret)
NOTQ(nret)
TESTL(x, y)
JZ(LabelRef("rk_loop_1_c"))
NOTQ(ret)
Label("rk_loop_1_c")
MOVQ(ret, out.Idx(count, 1))
CMOVQEQ(ret, nret)
MOVQ(nret, out.Idx(count, 1))
ROLL(U8(1), y)
ADDQ(U8(8), count)
CMPQ(count, U32(64))
JL(LabelRef("rk_loop_1"))
JL(LabelRef("rk_b1"))

Comment("Handle second byte")
MOVL(U32(0x00010000), y)
Label("rk_loop_2")
Label("rk_b2")
MOVQ(x1.Idx(count, 1), ret)
XORQ(x2.Idx(count, 1), ret)
XORQ(x3.Idx(count, 1), ret)
MOVQ(ret, nret)
NOTQ(nret)
TESTL(x, y)
JZ(LabelRef("rk_loop_2_c"))
NOTQ(ret)
Label("rk_loop_2_c")
MOVQ(ret, out.Idx(count, 1))
CMOVQEQ(ret, nret)
MOVQ(nret, out.Idx(count, 1))
ROLL(U8(1), y)
ADDQ(U8(8), count)
CMPQ(count, U32(128))
JL(LabelRef("rk_loop_2"))
JL(LabelRef("rk_b2"))

Comment("Handle third byte")
MOVL(U32(0x00000100), y)
Label("rk_loop_3")
Label("rk_b3")
MOVQ(x1.Idx(count, 1), ret)
XORQ(x2.Idx(count, 1), ret)
XORQ(x3.Idx(count, 1), ret)
MOVQ(ret, nret)
NOTQ(nret)
TESTL(x, y)
JZ(LabelRef("rk_loop_3_c"))
NOTQ(ret)
Label("rk_loop_3_c")
MOVQ(ret, out.Idx(count, 1))
CMOVQEQ(ret, nret)
MOVQ(nret, out.Idx(count, 1))
ROLL(U8(1), y)
ADDQ(U8(8), count)
CMPQ(count, U32(192))
JL(LabelRef("rk_loop_3"))
JL(LabelRef("rk_b3"))

Comment("Handle last byte")
MOVL(U32(0x00000001), y)
Label("rk_loop_4")
Label("rk_b4")
MOVQ(x1.Idx(count, 1), ret)
XORQ(x2.Idx(count, 1), ret)
XORQ(x3.Idx(count, 1), ret)
MOVQ(ret, nret)
NOTQ(nret)
TESTL(x, y)
JZ(LabelRef("rk_loop_4_c"))
NOTQ(ret)
Label("rk_loop_4_c")
MOVQ(ret, out.Idx(count, 1))
CMOVQEQ(ret, nret)
MOVQ(nret, out.Idx(count, 1))
ROLL(U8(1), y)
ADDQ(U8(8), count)
CMPQ(count, U32(256))
JL(LabelRef("rk_loop_4"))
JL(LabelRef("rk_b4"))

RET()
}
Expand Down
18 changes: 17 additions & 1 deletion bs128_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,33 @@ func BenchmarkL128(b *testing.B) {
}
}

func BenchmarkXorRK(b *testing.B) {
// TestXorRK128 feeds three newUint32x128-encoded words and a key into
// BS128.xorRK and checks that the output buffer equals the
// newUint32x128 encoding of key ^ w0 ^ w1 ^ w2.
func TestXorRK128(t *testing.T) {
	// Untyped constants so they convert exactly as the literals would
	// at each use site.
	const (
		w0 = 0xe0e7eef5
		w1 = 0xc0c7ced5
		w2 = 0xa0a7aeb5
	)
	key := uint32(0xa3b1bac6)

	in0 := newUint32x128(w0)
	in1 := newUint32x128(w1)
	in2 := newUint32x128(w2)
	got := make([]byte, 32*BS128.bytes())

	BS128.xorRK(key, got, in0, in1, in2)

	want := newUint32x128(key ^ w0 ^ w1 ^ w2)
	if bytes.Equal(want, got) {
		return
	}
	t.Fatalf("unexpected xorRK result %x, %x", got, want)
}

// BenchmarkXorRK128 times BS128.xorRK over zero-filled input buffers,
// each 32*BS128.bytes() long, and reports allocations per operation.
func BenchmarkXorRK128(b *testing.B) {
	const key = uint32(0xa3b1bac6)

	var (
		in0 = make([]byte, 32*BS128.bytes())
		in1 = make([]byte, 32*BS128.bytes())
		in2 = make([]byte, 32*BS128.bytes())
		dst = make([]byte, 32*BS128.bytes())
	)

	b.ReportAllocs()
	// Exclude the buffer setup above from the measurement.
	b.ResetTimer()
	for n := b.N; n > 0; n-- {
		BS128.xorRK(key, dst, in0, in1, in2)
	}
}


func BenchmarkXor32(b *testing.B) {
b0 := make([]byte, 32*BS128.bytes())
b1 := make([]byte, 32*BS128.bytes())
Expand Down
Loading

0 comments on commit 5343658

Please sign in to comment.