Skip to content

Commit

Permalink
Merge branch 'sabercrombie/avx_rc_case' into 'master'
Browse files (browse the repository at this point in the history)
Make AVX reverse_complement implementation preserve case

See merge request machine-learning/dorado!746
  • Loading branch information
StuartAbercrombie committed Dec 1, 2023
2 parents 4a4dd1c + 86e1e0d commit 1c2c6a9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
19 changes: 16 additions & 3 deletions dorado/utils/sequence_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ __attribute__((target("avx2"))) std::string reverse_complement_impl(const std::s
// 'C' & 0xf = 3
// 'T' & 0xf = 4
// 'G' & 0xf = 7
// The lowest 4 bits are the same for upper and lower case, so the lookup also works for
// lower case input, but the looked-up complement is always upper case; the case bit of
// the input has to be reinstated separately after the lookup.
const __m256i kComplementTable =
_mm256_setr_epi8(0, 'T', 0, 'G', 'A', 0, 0, 'C', 0, 0, 0, 0, 0, 0, 0, 0, 0, 'T', 0, 'G',
'A', 0, 0, 'C', 0, 0, 0, 0, 0, 0, 0, 0);
Expand All @@ -67,6 +69,9 @@ __attribute__((target("avx2"))) std::string reverse_complement_impl(const std::s
_mm256_set_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5,
6, 7, 8, 9, 10, 11, 12, 13, 14, 15);

// Mask for upper / lower case bits: if set, the character is lower case.
const __m256i kCaseBitMask = _mm256_set1_epi8(0x20);

// Unroll to AVX register size. Unrolling further would probably help performance.
static constexpr size_t kUnroll = 32;

Expand All @@ -80,8 +85,13 @@ __attribute__((target("avx2"))) std::string reverse_complement_impl(const std::s
// Load template bases.
const __m256i template_bases =
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(template_ptr));
// Look up complement bases.
const __m256i complement_bases = _mm256_shuffle_epi8(kComplementTable, template_bases);
// Extract the bit that signifies upper / lower case.
const __m256i case_bits = _mm256_and_si256(template_bases, kCaseBitMask);
// Look up complement bases as upper case (where the case bit is not set).
const __m256i complement_bases_upper_case =
_mm256_shuffle_epi8(kComplementTable, template_bases);
// Reinstate bits signifying lower case.
const __m256i complement_bases = _mm256_or_si256(complement_bases_upper_case, case_bits);
// Reverse byte order within 16 byte AVX lanes.
const __m256i reversed_lanes = _mm256_shuffle_epi8(complement_bases, kByteReverseTable);
// We store reversed lanes in reverse order to reverse 32 bytes overall.
Expand All @@ -104,7 +114,10 @@ __attribute__((target("avx2"))) std::string reverse_complement_impl(const std::s
// Same steps as in the main loop, but char by char, so there's no
// reversal of byte ordering, and we load/store with scalar instructions.
const __m256i template_base = _mm256_insert_epi8(kZero, *template_ptr--, 0);
const __m256i complement_base = _mm256_shuffle_epi8(kComplementTable, template_base);
const __m256i case_bit = _mm256_and_si256(template_base, kCaseBitMask);
const __m256i complement_base_upper_case =
_mm256_shuffle_epi8(kComplementTable, template_base);
const __m256i complement_base = _mm256_or_si256(complement_base_upper_case, case_bit);
*complement_ptr++ = _mm256_extract_epi8(complement_base, 0);
}

Expand Down
11 changes: 9 additions & 2 deletions tests/SequenceUtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,15 @@ TEST_CASE(TEST_GROUP "reverse_complement") {
std::string rev_comp(len, ' ');
for (int j = 0; j < len; ++j) {
const int base_index = std::rand() % 4;
temp.at(j) = bases.at(base_index);
rev_comp.at(len - 1 - j) = bases.at(3 - base_index);
char temp_base = bases.at(base_index);
char rev_comp_base = bases.at(3 - base_index);
// Randomly switch to lower case.
if (rand() & 1) {
temp_base = static_cast<char>(std::tolower(temp_base));
rev_comp_base = static_cast<char>(std::tolower(rev_comp_base));
}
temp.at(j) = temp_base;
rev_comp.at(len - 1 - j) = rev_comp_base;
}
CHECK(dorado::utils::reverse_complement(temp) == rev_comp);
}
Expand Down

0 comments on commit 1c2c6a9

Please sign in to comment.