In [1]:
import os
import warnings

import numpy as np
from Bio import SeqIO
from utils.dataloader import *
from utils.pesudo_MSA import *
from transformers import ESMForMaskedLM, ESMTokenizer, pipeline
from tqdm.notebook import tqdm as tqdm

def warn(*args, **kwargs):
    pass

warnings.warn = warn

### Dedup multiple fasta files

In [2]:
# Pseudo-MSA shards to merge; they differ only by the trailing shard index.
fasta_files = [
    f"HIS7_data/pseudo_MSA/HIS7_pseudo_MSA_editDist30_{shard}.fasta"
    for shard in range(4)
]
# Individual names kept in case later cells refer to a specific shard.
fasta0, fasta1, fasta2, fasta3 = fasta_files

records_list = [list(SeqIO.parse(fasta_name, "fasta")) for fasta_name in fasta_files]
# Flatten in file order, then record order within each file, so that
# "first occurrence" below is well defined.
records = [record for records_i in records_list for record in records_i]
# Compare plain strings rather than Bio.Seq objects: how Seq hashes/compares
# against str has differed across Biopython releases, while str keys do not.
seqs = [str(record.seq) for record in records]
print("Total number of pseudo MSA sequences before dedup:", len(seqs))
seq_set = set(seqs)
print("Number of nondup sequences:", len(seq_set))

# Keep the first record seen for each distinct sequence; later duplicates
# (possibly with different headers) are dropped.
seen_seqs = set()
output_records = []
for record, seq in zip(records, seqs):
    if seq not in seen_seqs:
        seen_seqs.add(seq)
        output_records.append(record)
assert len(output_records) == len(seq_set)

output_fasta = "HIS7_data/pseudo_MSA/HIS7_editDist30_dedup.fasta"
with open(output_fasta, 'w') as output_handle:
    SeqIO.write(output_records, output_handle, 'fasta')

Total number of pseudo MSA sequences before dedup: 117391
Number of nondup sequences: 114771


### Subsampling from pseudo_MSA

In [5]:
fasta_name = "HIS7_data/pseudo_MSA/HIS7_editDist30_dedup.fasta"
records = list(SeqIO.parse(fasta_name, "fasta"))
print("Begin subsampling, total sequences number:", len(records))
#===============================================================#
# parameters for the subsampled fasta file
args_mutant_per_editDist = 2000; args_editDist = 30
args_seed = 0  # fixes the draw so the subsampled FASTA is reproducible
#===============================================================#


def _parse_header(description):
    """Return ``(edit_distance, score)`` parsed from a pseudo-MSA FASTA header.

    Uses the same field positions as before: the edit distance is the second
    ``_``-separated token of the first `` | `` field, and the score is the value
    after ``=`` in the second field. Any other layout raises IndexError/ValueError.
    """
    fields = description.split(" | ")
    score = float(fields[1].split("=")[1])
    edit_distance = int(fields[0].split("_")[1])
    return edit_distance, score


# Separate mutants by their edit distance in one pass over the records.
# Index k holds edit distance k+1; distances outside 1..args_editDist are dropped.
mutants_by_editDist = [[] for _ in range(args_editDist)]
scores_by_editDist = [[] for _ in range(args_editDist)]
for record in records:
    edit_distance, score = _parse_header(record.description)
    if 1 <= edit_distance <= args_editDist:
        mutants_by_editDist[edit_distance - 1].append(record)
        scores_by_editDist[edit_distance - 1].append(score)

# Output records sampled per edit distance, weighted by score.
rng = np.random.default_rng(args_seed)
output_records = []
for mutants, scores in zip(mutants_by_editDist, scores_by_editDist):
    if not mutants:
        # An empty bucket would make choice() fail (p=[] does not sum to 1).
        continue
    scores = np.asarray(scores, dtype=float)
    # Draw weights are proportional to score. Assumes scores are non-negative
    # with a positive sum -- TODO confirm for this scoring model.
    p = scores / scores.sum()
    select_n_mutants = min(len(mutants), args_mutant_per_editDist)
    # Weighted sampling without replacement.
    weighted_choices = rng.choice(len(mutants), select_n_mutants, replace=False, p=p)
    output_records += [mutants[i] for i in weighted_choices]
#===============================================================#
output_fasta = "HIS7_data/pseudo_MSA/HIS7_editDist30_subsampled.fasta"
with open(output_fasta, 'w') as output_handle:
    SeqIO.write(output_records, output_handle, 'fasta')

Begin subsampling, total sequences number: 114771
