# Process test dataset for mutation masking

In [None]:
import os
import sys

from tqdm.notebook import tqdm

import pandas as pd
import numpy as np

import abutils

from abutils.utils.codons import codon_lookup

## load and filter sequences

In [None]:
# Input: the AIRR-annotated test set. Point this at the real location of the data.
data_path = './data/lc-coherence_test-unique_annotated.csv'

In [None]:
# Load the annotated sequence records from `data_path`.
# NOTE(review): presumably one record per chain (heavy or light), linked by a shared
# `pair_id` column that the next cell groups on -- confirm against the CSV.
seqs = abutils.io.read_csv(data_path)

In [None]:
# Group the loaded chain records into pairs using their `pair_id` annotation.
# NOTE(review): pairs lacking a heavy or light chain may be included here
# (their `.heavy` / `.light` would then be None) -- confirm for this dataset.
pairs = abutils.pair.assign_pairs(seqs, id_key='pair_id')
len(pairs)  # shown as the cell output: number of pairs assembled

We keep only pairs with more than two amino-acid mutations (`v_mutation_count_aa > 2`) in both the heavy and light chains, so that each chain has mutated residues to mask:

In [None]:
# Keep only pairs in which BOTH chains carry more than MUTATION_THRESHOLD_AA
# amino-acid mutations (per the `v_mutation_count_aa` annotation), so each chain
# has residues to mask. Pairs missing either chain are skipped explicitly:
# subscripting a None chain would otherwise raise TypeError.
MUTATION_THRESHOLD_AA = 2  # strict ">" threshold applied to each chain

mutated_pairs = [
    p for p in pairs
    if p.heavy is not None
    and p.light is not None
    and p.heavy['v_mutation_count_aa'] > MUTATION_THRESHOLD_AA
    and p.light['v_mutation_count_aa'] > MUTATION_THRESHOLD_AA
]
len(mutated_pairs)

We also drop pairs whose alignments contain indels (`-` gap characters), because those positions can't be masked or predicted: the model vocabulary contains only amino acids and has no gap token:

In [None]:
# Drop pairs with indels: any '-' in the (nucleotide) sequence or germline alignment
# of either chain. The model vocabulary is amino acids only, with no gap token, so
# gapped positions can't be masked or predicted.
# NOTE(review): only '-' is checked; IMGT-style '.' gaps, if ever present, are not.
mutated_pairs = [
    p for p in mutated_pairs
    if not any(
        '-' in s
        for s in (
            p.heavy['sequence_alignment'],
            p.light['sequence_alignment'],
            p.heavy['germline_alignment'],
            p.light['germline_alignment'],
        )
    )
]
len(mutated_pairs)

## build masked and reverted datasets

This builds the datasets needed for both the paired and the unpaired models.

In [None]:
def translate(nt, lookup=None):
    """Translate a nucleotide string into amino acids, codon by codon from position 0.

    Translation always starts at the first character, so `nt` is assumed to be in
    frame (NOTE(review): confirm the alignments passed in start on a codon
    boundary). A trailing partial codon (1-2 leftover nucleotides) is ignored.

    Parameters
    ----------
    nt : str
        Nucleotide sequence.
    lookup : mapping, optional
        Codon -> amino-acid table. Defaults to abutils' ``codon_lookup``.

    Returns
    -------
    str
        The translated amino-acid sequence.

    Raises
    ------
    KeyError
        If a codon is absent from the lookup table; the message names the codon
        and its nucleotide offset. Whether gap- or ambiguity-containing codons are
        covered depends on the table used.
    """
    table = codon_lookup if lookup is None else lookup
    n_full = len(nt) - len(nt) % 3  # stop before any trailing partial codon
    aa = []
    for i in range(0, n_full, 3):
        codon = nt[i:i + 3]
        try:
            aa.append(table[codon])
        except KeyError:
            raise KeyError(
                f'untranslatable codon {codon!r} at nucleotide offset {i}'
            ) from None
    return ''.join(aa)

In [None]:
MASK_TOKEN = '<mask>'  # written in place of every residue that differs from germline
SEP_TOKEN = '</s>'     # separates heavy from light chain in the paired-model strings


def _mask_mutations(seq, germ, mask_token=MASK_TOKEN):
    """Return `seq` with each residue that differs from `germ` replaced by `mask_token`.

    Both arguments are amino-acid strings compared position by position.

    Raises
    ------
    ValueError
        If the two strings differ in length. Without this check `zip` would silently
        truncate, and the masked input would no longer line up with its label.
    """
    if len(seq) != len(germ):
        raise ValueError(
            f'sequence/germline length mismatch ({len(seq)} vs {len(germ)} residues)'
        )
    return ''.join(s if s == g else mask_token for s, g in zip(seq, germ))


# Every list below receives exactly one entry per pair, appended in the same loop
# iteration, so index i of each list refers to the same pair (pair_ids[i]).
pair_ids = []
# unpaired model: labels and masked inputs, one chain per entry
hseqs = []
lseqs = []
hmasked = []
lmasked = []
# paired model: heavy + SEP_TOKEN + light
paired_seqs = []
hmask_lmutated = []   # masked heavy, observed (mutated) light
hmask_lreverted = []  # masked heavy, germline-reverted light
lmask_hmutated = []   # observed (mutated) heavy, masked light
lmask_hreverted = []  # germline-reverted heavy, masked light

for p in tqdm(mutated_pairs):
    # heavy chain: observed and germline translations, then mask the differences
    hseq = translate(p.heavy['sequence_alignment'])
    hgerm = translate(p.heavy['germline_alignment'])
    hmask = _mask_mutations(hseq, hgerm)

    # light chain
    lseq = translate(p.light['sequence_alignment'])
    lgerm = translate(p.light['germline_alignment'])
    lmask = _mask_mutations(lseq, lgerm)

    # ids
    pair_ids.append(p.name)

    # for unpaired model
    hseqs.append(hseq)
    lseqs.append(lseq)
    hmasked.append(hmask)
    lmasked.append(lmask)

    # for paired models
    paired_seqs.append(hseq + SEP_TOKEN + lseq)
    hmask_lmutated.append(hmask + SEP_TOKEN + lseq)
    hmask_lreverted.append(hmask + SEP_TOKEN + lgerm)
    lmask_hmutated.append(hseq + SEP_TOKEN + lmask)
    lmask_hreverted.append(hgerm + SEP_TOKEN + lmask)

## save files

In [None]:
# Pair ids, one per line (newline-separated, no trailing newline).
with open('./data/pair_ids.txt', 'w') as f:
    print('\n'.join(pair_ids), end='', file=f)

In [None]:
# Unpaired-model inputs: one masked chain per line, heavy and light in separate files.
for out_path, rows in (
    ('./data/heavy-masked.txt', hmasked),
    ('./data/light-masked.txt', lmasked),
):
    with open(out_path, 'w') as f:
        f.write('\n'.join(rows))

In [None]:
# Unpaired-model labels: the unmasked translated chain, one per line.
for out_path, rows in (
    ('./data/heavy_labels.txt', hseqs),
    ('./data/light_labels.txt', lseqs),
):
    with open(out_path, 'w') as f:
        f.write('\n'.join(rows))

In [None]:
# Paired-model inputs with the heavy chain masked; the light chain is either the
# observed (mutated) sequence or its germline translation.
for out_path, rows in (
    ('./data/heavy-masked_light-mutated.txt', hmask_lmutated),
    ('./data/heavy-masked_light-reverted.txt', hmask_lreverted),
):
    with open(out_path, 'w') as f:
        f.write('\n'.join(rows))

In [None]:
# Paired-model inputs with the light chain masked; the heavy chain is either the
# observed (mutated) sequence or its germline translation.
for out_path, rows in (
    ('./data/light-masked_heavy-mutated.txt', lmask_hmutated),
    ('./data/light-masked_heavy-reverted.txt', lmask_hreverted),
):
    with open(out_path, 'w') as f:
        f.write('\n'.join(rows))

In [None]:
# Paired-model labels: unmasked heavy and light chains joined by a separator token,
# one pair per line.
with open('./data/paired_labels.txt', 'w') as f:
    print('\n'.join(paired_seqs), end='', file=f)