
Commit

wt[80] is redundant.
trcrsired committed Nov 6, 2022
1 parent 147d7c3 commit 7a50adf
Showing 3 changed files with 322 additions and 116 deletions.
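In broad strokes, the diff below does two things: the scalar-count shift operators for simd_vector in generic_operations.h gain dedicated SSE2/AVX/AVX2/AVX-512 paths (and now take a plain unsigned shift count instead of a templated std::integral one), and sha512_simd16.h drops the 80-entry wt scratch array in favour of two 2-word buffers. The new operators stay usable in constant expressions by guarding the intrinsic path behind a constant-evaluation check. A minimal sketch of that guard, in standard C++ and independent of fast_io (the intrinsics themselves are omitted; requires C++20, or C++23 for if !consteval):

#include <cstdint>

// Sketch of the guard used by the new operator<< / operator>> overloads below:
// vendor intrinsics are not constexpr, so the intrinsic path must be skipped
// whenever the call is constant-evaluated.
inline constexpr std::uint32_t shift_left_sketch(std::uint32_t v, unsigned i) noexcept
{
#if __cpp_if_consteval >= 202106L
    if !consteval
#else
    if (!__builtin_is_constant_evaluated())
#endif
    {
        // Runtime-only branch: in the real operators this is where the
        // _mm_slli_epi32 / _mm256_slli_epi32 / _mm512_slli_epi32 calls live.
        return v << i;
    }
    return v << i; // constant-evaluation fallback (a per-lane lambda in the real code)
}

static_assert(shift_left_sketch(1u, 4u) == 16u);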
196 changes: 192 additions & 4 deletions include/fast_io_core_impl/simd/generic_operations.h
@@ -1048,18 +1048,206 @@ inline constexpr simd_vector<T,N> operator>>(simd_vector<T,N> const& a,simd_vect
});
}

template<typename T,std::size_t N,std::integral I>
inline constexpr simd_vector<T,N> operator<<(simd_vector<T,N> const& a,I i) noexcept
template<typename T,std::size_t N>
inline constexpr simd_vector<T,N> operator<<(simd_vector<T,N> const& a,unsigned i) noexcept
{
#if defined(__x86_64__) || defined(_M_X64)
#if __cpp_if_consteval >= 202106L
if !consteval
#else
if (!__builtin_is_constant_evaluated())
#endif
{
using vec_type = simd_vector<T,N>;
if constexpr(2<=sizeof(T)&&sizeof(T)<=8)
{
if constexpr(sizeof(vec_type)==16)
{
__m128i amm = __builtin_bit_cast(__m128i,a);
if constexpr(sizeof(T)==2)
{
amm = _mm_slli_epi16(amm,i);
}
else if constexpr(sizeof(T)==4)
{
amm = _mm_slli_epi32(amm,i);
}
else if constexpr(sizeof(T)==8)
{
amm = _mm_slli_epi64(amm,i);
}
return __builtin_bit_cast(vec_type,amm);
}
else if constexpr(sizeof(vec_type)==32)
{
if constexpr(::fast_io::details::cpu_flags::avx2_supported)
{
__m256i amm = __builtin_bit_cast(__m256i,a);
if constexpr(sizeof(T)==2)
{
amm = _mm256_slli_epi16(amm,i);
}
else if constexpr(sizeof(T)==4)
{
amm = _mm256_slli_epi32(amm,i);
}
else if constexpr(sizeof(T)==8)
{
amm = _mm256_slli_epi64(amm,i);
}
return __builtin_bit_cast(vec_type,amm);
}
else if constexpr(::fast_io::details::cpu_flags::avx_supported)
{
__m256i amm = __builtin_bit_cast(__m256i,a);
__m128i alow = _mm256_castsi256_si128(amm);
__m128i ahigh = _mm256_extractf128_si256(amm,1);
if constexpr(sizeof(T)==2)
{
alow = _mm_slli_epi16(alow,i);
ahigh = _mm_slli_epi16(ahigh,i);
}
else if constexpr(sizeof(T)==4)
{
alow = _mm_slli_epi32(alow,i);
ahigh = _mm_slli_epi32(ahigh,i);
}
else if constexpr(sizeof(T)==8)
{
alow = _mm_slli_epi64(alow,i);
ahigh = _mm_slli_epi64(ahigh,i);
}
__m256i res = _mm256_castsi128_si256(alow);
res = _mm256_insertf128_si256(res,ahigh,1);
return __builtin_bit_cast(vec_type,res);
}
}
else if constexpr(sizeof(vec_type)==64)
{
if constexpr(::fast_io::details::cpu_flags::avx512bw_supported)
{
__m512i amm = __builtin_bit_cast(__m512i,a);
if constexpr(sizeof(T)==2)
{
amm = _mm512_slli_epi16(amm,i);
}
else if constexpr(sizeof(T)==4)
{
amm = _mm512_slli_epi32(amm,i);
}
else if constexpr(sizeof(T)==8)
{
amm = _mm512_slli_epi64(amm,i);
}
return __builtin_bit_cast(vec_type,amm);
}
}
}
}
#endif
return ::fast_io::details::generic_simd_self_create_op_impl(a,[i](T va)
{
return va<<i;
});
}

template<typename T,std::size_t N,std::integral I>
inline constexpr simd_vector<T,N> operator>>(simd_vector<T,N> const& a,I i) noexcept
template<typename T,std::size_t N>
inline constexpr simd_vector<T,N> operator>>(simd_vector<T,N> const& a,unsigned i) noexcept
{
#if defined(__x86_64__) || defined(_M_X64)
#if __cpp_if_consteval >= 202106L
if !consteval
#else
if (!__builtin_is_constant_evaluated())
#endif
{
using vec_type = simd_vector<T,N>;
if constexpr(2<=sizeof(T)&&sizeof(T)<=8)
{
if constexpr(sizeof(vec_type)==16)
{
__m128i amm = __builtin_bit_cast(__m128i,a);
if constexpr(sizeof(T)==2)
{
amm = _mm_srli_epi16(amm,i);
}
else if constexpr(sizeof(T)==4)
{
amm = _mm_srli_epi32(amm,i);
}
else if constexpr(sizeof(T)==8)
{
amm = _mm_srli_epi64(amm,i);
}
return __builtin_bit_cast(vec_type,amm);
}
else if constexpr(sizeof(vec_type)==32)
{
if constexpr(::fast_io::details::cpu_flags::avx2_supported)
{
__m256i amm = __builtin_bit_cast(__m256i,a);
if constexpr(sizeof(T)==2)
{
amm = _mm256_srli_epi16(amm,i);
}
else if constexpr(sizeof(T)==4)
{
amm = _mm256_srli_epi32(amm,i);
}
else if constexpr(sizeof(T)==8)
{
amm = _mm256_srli_epi64(amm,i);
}
return __builtin_bit_cast(vec_type,amm);
}
else if constexpr(::fast_io::details::cpu_flags::avx_supported)
{
__m256i amm = __builtin_bit_cast(__m256i,a);
__m128i alow = _mm256_castsi256_si128(amm);
__m128i ahigh = _mm256_extractf128_si256(amm,1);
if constexpr(sizeof(T)==2)
{
alow = _mm_srli_epi16(alow,i);
ahigh = _mm_srli_epi16(ahigh,i);
}
else if constexpr(sizeof(T)==4)
{
alow = _mm_srli_epi32(alow,i);
ahigh = _mm_srli_epi32(ahigh,i);
}
else if constexpr(sizeof(T)==8)
{
alow = _mm_srli_epi64(alow,i);
ahigh = _mm_srli_epi64(ahigh,i);
}
__m256i res = _mm256_castsi128_si256(alow);
res = _mm256_insertf128_si256(res,ahigh,1);
return __builtin_bit_cast(vec_type,res);
}
}
else if constexpr(sizeof(vec_type)==64)
{
if constexpr(::fast_io::details::cpu_flags::avx512bw_supported)
{
__m512i amm = __builtin_bit_cast(__m512i,a);
if constexpr(sizeof(T)==2)
{
amm = _mm512_srli_epi16(amm,i);
}
else if constexpr(sizeof(T)==4)
{
amm = _mm512_srli_epi32(amm,i);
}
else if constexpr(sizeof(T)==8)
{
amm = _mm512_srli_epi64(amm,i);
}
return __builtin_bit_cast(vec_type,amm);
}
}
}
}
#endif
return ::fast_io::details::generic_simd_self_create_op_impl(a,[i](T va)
{
return va>>i;
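Before the second file: a hedged usage sketch of the reworked shift overloads above. The lane count, element type, and the fast_io.h include are illustrative assumptions, not taken from the commit; load() and store() are used the same way the sha512 code below uses them, and the shift count is now a plain unsigned.

#include <cstdint>
#include <fast_io.h> // assumed umbrella header; use whichever header exposes simd_vector in your setup

int main()
{
    // Eight 32-bit lanes form a 256-bit vector, so at runtime the AVX2 path
    // (or the split-into-two-128-bit-halves AVX fallback) above is taken.
    std::uint_least32_t in[8]{1, 2, 4, 8, 16, 32, 64, 128};
    std::uint_least32_t out[8];

    ::fast_io::intrinsics::simd_vector<std::uint_least32_t, 8> v;
    v.load(in);
    auto shifted = (v << 3u) >> 1u; // per-lane logical shifts by a scalar count
    shifted.store(out);             // out == {4, 8, 16, 32, 64, 128, 256, 512}
}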
117 changes: 58 additions & 59 deletions include/fast_io_crypto/hash/sha512_simd16.h
@@ -16,7 +16,7 @@ inline void sha512_simd16_byte_swap_message_2rounds(::fast_io::intrinsics::simd_
}
s1.store(w+round);
s0.wrap_add_assign(s1);
s0.store(wt+round);
s0.store(wt);
}

#if __has_cpp_attribute(__gnu__::__always_inline__)
@@ -52,7 +52,7 @@ inline void sha512_simd16_compute_message_2rounds(
s1.store(w+round);
s0.load(::fast_io::details::sha512::K512+round);
s0.wrap_add_assign(s1);
s0.store(wt+round);
s0.store(wt);
}

#if __has_cpp_attribute(__gnu__::__flatten__)
@@ -65,8 +65,8 @@ inline void sha512_runtime_routine(std::uint_least64_t* __restrict state,std::by

simd_vector<std::uint_least64_t,2> simd;

std::uint_least64_t wt[80];
std::uint_least64_t w[80];
std::uint_least64_t wt0[2],wt1[2];
std::uint_least64_t a{state[0]};
std::uint_least64_t b{state[1]};
std::uint_least64_t c{state[2]};
@@ -78,73 +78,72 @@ inline void sha512_runtime_routine(std::uint_least64_t* __restrict state,std::by

for(;blocks_start!=blocks_last;blocks_start+=128)
{
sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt,0);
sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt,2);
sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt0,0);
sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt1,2);
std::uint_least64_t bpc{b^c};
sha512_scalar_round(wt[0],a,b,d,e,f,g,h,bpc);
sha512_scalar_round(wt[1],h,a,c,d,e,f,g,bpc);
sha512_scalar_round(wt0[0],a,b,d,e,f,g,h,bpc);
sha512_scalar_round(wt0[1],h,a,c,d,e,f,g,bpc);

sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt,4);
sha512_scalar_round(wt[2],g,h,b,c,d,e,f,bpc);
sha512_scalar_round(wt[3],f,g,a,b,c,d,e,bpc);
sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt0,4);
sha512_scalar_round(wt1[0],g,h,b,c,d,e,f,bpc);
sha512_scalar_round(wt1[1],f,g,a,b,c,d,e,bpc);

sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt,6);
sha512_scalar_round(wt[4],e,f,h,a,b,c,d,bpc);
sha512_scalar_round(wt[5],d,e,g,h,a,b,c,bpc);
sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt1,6);
sha512_scalar_round(wt0[0],e,f,h,a,b,c,d,bpc);
sha512_scalar_round(wt0[1],d,e,g,h,a,b,c,bpc);

sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt,8);
sha512_scalar_round(wt[6],c,d,f,g,h,a,b,bpc);
sha512_scalar_round(wt[7],b,c,e,f,g,h,a,bpc);
sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt0,8);
sha512_scalar_round(wt1[0],c,d,f,g,h,a,b,bpc);
sha512_scalar_round(wt1[1],b,c,e,f,g,h,a,bpc);

sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt,10);
sha512_scalar_round(wt[8],a,b,d,e,f,g,h,bpc);
sha512_scalar_round(wt[9],h,a,c,d,e,f,g,bpc);
sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt1,10);
sha512_scalar_round(wt0[0],a,b,d,e,f,g,h,bpc);
sha512_scalar_round(wt0[1],h,a,c,d,e,f,g,bpc);

sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt,12);
sha512_scalar_round(wt[10],g,h,b,c,d,e,f,bpc);
sha512_scalar_round(wt[11],f,g,a,b,c,d,e,bpc);
sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt0,12);
sha512_scalar_round(wt1[0],g,h,b,c,d,e,f,bpc);
sha512_scalar_round(wt1[1],f,g,a,b,c,d,e,bpc);

sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt,14);
sha512_scalar_round(wt[12],e,f,h,a,b,c,d,bpc);
sha512_scalar_round(wt[13],d,e,g,h,a,b,c,bpc);
sha512_simd16_byte_swap_message_2rounds(simd,blocks_start,w,wt1,14);
sha512_scalar_round(wt0[0],e,f,h,a,b,c,d,bpc);
sha512_scalar_round(wt0[1],d,e,g,h,a,b,c,bpc);

for(std::uint_fast8_t i{14};i!=78;i+=16)
{
std::uint_least64_t const* const p{wt+i};
sha512_simd16_compute_message_2rounds(simd,w,wt,i+2);
sha512_scalar_round(*p,c,d,f,g,h,a,b,bpc);
sha512_scalar_round(p[1],b,c,e,f,g,h,a,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt,i+4);
sha512_scalar_round(p[2],a,b,d,e,f,g,h,bpc);
sha512_scalar_round(p[3],h,a,c,d,e,f,g,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt,i+6);
sha512_scalar_round(p[4],g,h,b,c,d,e,f,bpc);
sha512_scalar_round(p[5],f,g,a,b,c,d,e,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt,i+8);
sha512_scalar_round(p[6],e,f,h,a,b,c,d,bpc);
sha512_scalar_round(p[7],d,e,g,h,a,b,c,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt,i+10);
sha512_scalar_round(p[8],c,d,f,g,h,a,b,bpc);
sha512_scalar_round(p[9],b,c,e,f,g,h,a,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt,i+12);
sha512_scalar_round(p[10],a,b,d,e,f,g,h,bpc);
sha512_scalar_round(p[11],h,a,c,d,e,f,g,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt,i+14);
sha512_scalar_round(p[12],g,h,b,c,d,e,f,bpc);
sha512_scalar_round(p[13],f,g,a,b,c,d,e,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt,i+16);
sha512_scalar_round(p[14],e,f,h,a,b,c,d,bpc);
sha512_scalar_round(p[15],d,e,g,h,a,b,c,bpc);
sha512_simd16_compute_message_2rounds(simd,w,wt0,i+2);
sha512_scalar_round(wt1[0],c,d,f,g,h,a,b,bpc);
sha512_scalar_round(wt1[1],b,c,e,f,g,h,a,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt1,i+4);
sha512_scalar_round(wt0[0],a,b,d,e,f,g,h,bpc);
sha512_scalar_round(wt0[1],h,a,c,d,e,f,g,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt0,i+6);
sha512_scalar_round(wt1[0],g,h,b,c,d,e,f,bpc);
sha512_scalar_round(wt1[1],f,g,a,b,c,d,e,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt1,i+8);
sha512_scalar_round(wt0[0],e,f,h,a,b,c,d,bpc);
sha512_scalar_round(wt0[1],d,e,g,h,a,b,c,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt0,i+10);
sha512_scalar_round(wt1[0],c,d,f,g,h,a,b,bpc);
sha512_scalar_round(wt1[1],b,c,e,f,g,h,a,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt1,i+12);
sha512_scalar_round(wt0[0],a,b,d,e,f,g,h,bpc);
sha512_scalar_round(wt0[1],h,a,c,d,e,f,g,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt0,i+14);
sha512_scalar_round(wt1[0],g,h,b,c,d,e,f,bpc);
sha512_scalar_round(wt1[1],f,g,a,b,c,d,e,bpc);

sha512_simd16_compute_message_2rounds(simd,w,wt1,i+16);
sha512_scalar_round(wt0[0],e,f,h,a,b,c,d,bpc);
sha512_scalar_round(wt0[1],d,e,g,h,a,b,c,bpc);
}
sha512_scalar_round(wt[78],c,d,f,g,h,a,b,bpc);
sha512_scalar_round(wt[79],b,c,e,f,g,h,a,bpc);
sha512_scalar_round(wt1[0],c,d,f,g,h,a,b,bpc);
sha512_scalar_round(wt1[1],b,c,e,f,g,h,a,bpc);
a=(*state+=a);
b=(state[1]+=b);
c=(state[2]+=c);
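For the sha512_simd16.h hunks above: each wt value (message word plus round constant) is consumed by exactly the two scalar rounds that follow its computation, so the full 80-entry wt array really is redundant and two 2-word buffers refilled in alternation suffice; that is what the wt0/wt1 arguments do. A simplified, self-contained sketch of the buffer reuse follows (hypothetical helper names; the real code additionally overlaps the SIMD schedule for the next pair with the scalar rounds of the previous pair):

#include <cstddef>
#include <cstdint>

// Hypothetical stand-ins that only model the data flow, not real SHA-512 math.
inline std::uint_least64_t message_plus_k(std::size_t t) noexcept
{
    return static_cast<std::uint_least64_t>(t); // pretend this is w[t] + K512[t]
}
inline void scalar_round(std::uint_least64_t /*wt*/) noexcept {} // pretend this updates a..h

inline void compress_sketch() noexcept
{
    // Before: std::uint_least64_t wt[80];  // filled once, each slot read once
    // After:  two 2-word buffers, refilled just before the rounds that read them
    std::uint_least64_t wt0[2], wt1[2];

    for (std::size_t t{}; t != 80; t += 4)
    {
        wt0[0] = message_plus_k(t);     wt0[1] = message_plus_k(t + 1);
        scalar_round(wt0[0]);           scalar_round(wt0[1]);

        wt1[0] = message_plus_k(t + 2); wt1[1] = message_plus_k(t + 3);
        scalar_round(wt1[0]);           scalar_round(wt1[1]);
    }
}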

0 comments on commit 7a50adf
