- take input sequences and make blast database
- create pairwise combos by index

Alternatively:
- take pairwise sequences
- make blast database with unique seqs

Now have blast database and pairwise sequences (need to assign indices)

- Run Blast via iterator?
- Generate feature set for model
- Send to model

In [1]:
import os
import shutil
import pandas as pd
import numpy as np
import tempfile

from typing import Collection

from itertools import combinations
import io

from Bio.Blast.Applications import NcbimakeblastdbCommandline, NcbiblastpCommandline
from Bio.Blast import NCBIXML
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO

# https://biopython.org/docs/1.75/api/Bio.pairwise2.html
from Bio import pairwise2
from Bio.Align import substitution_matrices



In [2]:
s50k = pd.read_csv('../learn2therm_sample_50k_exploration.csv')

In [3]:
s50k.head()

Unnamed: 0.1,Unnamed: 0,local_gap_compressed_percent_id,scaled_local_query_percent_id,scaled_local_symmetric_percent_id,query_align_len,query_align_cov,subject_align_len,subject_align_cov,bit_score,thermo_index,...,bit_score_16s,m_ogt,t_ogt,ogt_difference,m_protein_seq,t_protein_seq,m_protein_desc,t_protein_desc,m_protein_len,t_protein_len
0,0,0.287582,0.217822,0.215686,160,0.792079,152,0.737864,131,875,...,1153.0,27.5,50.0,22.5,MAESGTSRRADHLVPVPGPDAEPPAVADELLRAVGRGDEQAFGRLY...,MPSQITESERIELAERFERDALPLLDQLYSAALRMTRNPADAEDLV...,ECF RNA polymerase sigma factor SigK,sigma-70 family RNA polymerase sigma factor,206,202
1,1,0.319635,0.295359,0.297872,218,0.919831,226,0.969957,282,11324,...,1014.0,25.0,54.0,29.0,MARIALVDDDRNILTSVSMTLEAEGFEVETYNDGQSALDAFNKRMP...,MRVLLVEDDPNTSRSIEMMLTHANLNVYATDMGEEGIDLAKLYDYD...,response regulator transcription factor,response regulator transcription factor,233,237
2,2,0.279621,0.234127,0.218924,211,0.837302,210,0.731707,96,875,...,1138.0,28.0,50.0,22.0,MKDTVVFVTGAARGIGAHTARLAVARGARVALVGLEPHLLADLAAE...,MTPEQIFSGQTAIVTGGASGIGAATVEHIARRGGRVFSVDLSYDSP...,SDR family oxidoreductase,SDR family oxidoreductase,287,252
3,3,0.327273,0.200743,0.214712,166,0.6171,163,0.696581,175,875,...,1077.0,28.0,50.0,22.0,MTSGLWERVLDGVWVTIQLLVLSALLATAVSFVVGIARTHRLWIVR...,MAMSRRKRGQLARGIQYAILVIVVVVLALLADWGKIGKAFFDWEAA...,ectoine/hydroxyectoine ABC transporter permeas...,amino acid ABC transporter permease,234,269
4,4,0.33871,0.318182,0.287671,60,0.909091,71,0.8875,61,9827,...,991.0,30.0,50.0,20.0,MIISLRRGLRFIRFIVFFAALVYLFYHVLDLFNGWISPVDQYQMPT...,MKRMVWRTLKVFIIFIACTLLFYFGLRFMHLEYEQFHRYEPPEGPA...,YqzK family protein,YqzK family protein,80,66


In [28]:
s50k.sample(1).to_csv('s1.csv')

In [5]:
class FAFSA_paired:
    
    def __init__(self, seqs):
        
        self.seqs = seqs
        
    # Get it into fasta form
        
    def pair(self):
        
        return list(combinations(self.seqs, 2))
        
class FAFSA_single:
    
    def __init__(self, seqs):
        
        self.seqs = seqs

In [6]:
class BlastMetrics:
    """Handles computation of metrics for each alignment in a blast record.

    The HSP with the largest average sequence coverage is used for local metrics.
    
    Parameters
    ----------
    blast_record : the record containing all hits for a query
    """
    def __init__(self, seqs):
        self.seqs = seqs
        
        logger.debug(f"Query {self.qid} with {len(self.record.alignments)} alignments.")

    def id_hsp_best_cov(self, alignment):
        """Determine HSP with the most average coverage of both sequences.
        
        Returns
        -------
        Index of HSP with max average seq coverage
        Max average coverage
        """
        scores = []
        for hsp in alignment.hsps:
            scores.append(
                ((hsp.query_end +1 - hsp.query_start)/self.record.query_length + (hsp.sbjct_end +1 - hsp.sbjct_start)/alignment.length)/2)
        return np.argmax(scores), max(scores)

    def compute_metric(self, metric_name: str):
        """Compute the metric with specified name for each alignment"""
        if not hasattr(self, metric_name):
            raise ValueError(f"No metric found with name : {metric_name}")
        else:
            metric = getattr(self, metric_name)
        
        logger.debug(f"Computing metric `{metric_name}` for all alignments in query {self.qid}")

        outputs = []
        for alignment in self.record.alignments:
            hsp_id, _ = self.id_hsp_best_cov(alignment)
            hsp = alignment.hsps[hsp_id]
            outputs.append((self.qid, alignment.hit_id.split('|')[-1], metric(alignment, hsp)))
        return pd.DataFrame(data=outputs, columns=['query_id', 'subject_id', metric_name])

    @staticmethod
    def raw_gap_excluding_percent_id(n_matches, n_gaps, n_columns):
        """Percent matches in sequence, excluding gaps.
        
        Parameters
        ----------
        n_matches : int, number of matches in match columns
        n_gaps : number of gaps in match columns
        n_columns : total number of alignment match columns
        """
        return n_matches / (n_columns - n_gaps)

    @staticmethod
    def raw_gap_including_percent_id(n_matches, n_columns):
        """Percent matches in sequence, including gaps gaps.
        
        Parameters
        ----------
        n_matches : int, number of matches in match columns
        n_columns : total number of alignment match columns
        """
        return n_matches / (n_columns)

    @staticmethod
    def raw_gap_compressed_percent_id(n_matches, n_gaps, n_columns, n_compressed_gaps):
        """Percent matches in sequence, including but compressing gaps.
        
        Parameters
        ----------
        n_matches : int, number of matches in match columns
        n_gaps : number of gaps in match columns
        n_columns : total number of alignment match columns
        n_compressed_gaps : number of compressed gaps in match columns
        """
        return n_matches / (n_columns - n_gaps + n_compressed_gaps)

    def local_gap_compressed_percent_id(self, alignment, hsp):
        """Percent matches in match sequence, including but compressing gaps.
        
        The largest local HSP score is used
        """
        n_matches = hsp.identities
        n_gaps = hsp.gaps
        n_columns = len(hsp.query)
        n_compressed_gaps = len(re.findall('-+', hsp.query))+len(re.findall('-+', hsp.sbjct))
        return self.raw_gap_compressed_percent_id(n_matches, n_gaps, n_columns, n_compressed_gaps)

    def scaled_local_query_percent_id(self, alignment, hsp):
        """Percent matches in query sequence based on best HSP."""
        return hsp.identities/self.record.query_length

    def scaled_local_symmetric_percent_id(self, alignment, hsp):
        """Percent matches compared to average seq length of query and subject based on best HSP"""
        return 2*hsp.identities/(self.record.query_length + alignment.length)

    def local_E_value(self, alignment, hsp):
        """E value of HSP with most identities."""
        return hsp.expect

    def query_align_start(self, alignment, hsp):
        """Start index of alignment in query."""
        return hsp.query_start

    def query_align_end(self, alignment, hsp):
        """End index of alignment in query."""
        return hsp.query_end
    
    def subject_align_end(self, alignment, hsp):
        """End index of alignment in subject."""
        return hsp.sbjct_end

    def subject_align_start(self, alignment, hsp):
        """Start index of alignment in subject."""
        return hsp.sbjct_start

    def query_align_len(self, alignment, hsp):
        """Length of AA on query string taken up by alignment"""
        return int(hsp.query_end +1 - hsp.query_start)

    def query_align_cov(self, alignment, hsp):
        """Fraction of AA on query string taken up by alignment"""
        return (hsp.query_end +1 - hsp.query_start)/self.record.query_length
    
    def subject_align_len(self, alignment, hsp):
        """Length of AA on query string taken up by alignment"""
        return int(hsp.sbjct_end +1 - hsp.sbjct_start)

    def subject_align_cov(self, alignment, hsp):
        """Fraction of AA on query string taken up by alignment"""
        return (hsp.sbjct_end +1 - hsp.sbjct_start)/alignment.length
    
    def bit_score(self, alignment, hsp):
        return hsp.score
    
class BlastFiles:
    """Temporary files for use with BLAST CLI.
    
    Blast expects two input FASTA and produces an XML. The FASTA are redundant to CSV
    we already have. These files are created for the context and removed after completion.

    Parameters
    ----------
    query_iterator : iterator of (seq id, sequence)
        sequences to be used as query
    subject_iterator : iterator of (seq id, sequence)
        sequences to be used as the "database"

    Returns
    -------
    query_filename : str, name of fasta file with query sequences
    subject_filename : str, name of fasta file with subject sequences
    output_filename : str, name of output file for blast to save reuslts, will be deleted out of context
    """
    def __init__(self, query_iterator, subject_iterator, dbtype: str = 'prot', dev_sample_num: int = None):
        # we have to create the temporary fasta files
        logger.debug("Creating temporary files to deposit blast inputs and outputs.")
        os.makedirs('./tmp/', exist_ok=True)
        query_temp = tempfile.NamedTemporaryFile('w', delete=False, dir='./tmp/')
        logger.debug(f"query file: {query_temp.name}")
        if dev_sample_num is not None:
            logger.debug(f"Using only max {dev_sample_num} sequences from query and subject")
        self.qt = query_temp.name
        n = 0
        for id_, seq in query_iterator:
            if seq == 'None' or seq is None:
                continue
            query_temp.write(f">{id_}\n{seq}\n")
            n +=1
            if dev_sample_num is not None:
                if n >= dev_sample_num:
                    break
        query_temp.close()
        logger.debug(f"added {n} sequences to query file")

        # folder for subject DB after we make a fasta
        subject_folder = tempfile.mkdtemp(dir='./tmp/')
        self.st = subject_folder
        subject_fasta_file = subject_folder+'/subs.fasta'
        self.subject_fasta_file = subject_fasta_file
        n = 0
        file = open(subject_fasta_file, 'w')
        for id_, seq in subject_iterator:
            if seq == 'None' or seq is None:
                continue
            file.write(f">{id_}\n{seq}\n")
            n +=1
            if dev_sample_num is not None:
                if n >= dev_sample_num:
                    break
        file.close()
        logger.debug(f"added {n} sequences to db file")

        # create db for it
        NcbimakeblastdbCommandline(dbtype=dbtype, input_file=subject_fasta_file, parse_seqids=True)()
        logger.debug(f"created database")
        # create the output xml file
        out_temp = tempfile.NamedTemporaryFile('w', delete=False, dir='./tmp')
        self.ot = out_temp.name

    def __enter__(self):
        return self.qt, self.subject_fasta_file, self.ot

    def __exit__(self, type, value, traceback):
        logger.debug("Removing temporary files used by blast")
        os.remove(self.qt)
        shutil.rmtree(self.st)
        os.remove(self.ot)

In [7]:
p = FAFSA_input(test)

NameError: name 'FAFSA_input' is not defined

In [8]:
p.pair()

NameError: name 'p' is not defined

In [9]:
class BlastFiles:
    """Temporary files for use with BLAST CLI.
    
    Blast expects two input FASTA and produces an XML. The FASTA are redundant to CSV
    we already have. These files are created for the context and removed after completion.

    Parameters
    ----------
    query_iterator : iterator of (seq id, sequence)
        sequences to be used as query
    subject_iterator : iterator of (seq id, sequence)
        sequences to be used as the "database"

    Returns
    -------
    query_filename : str, name of fasta file with query sequences
    subject_filename : str, name of fasta file with subject sequences
    output_filename : str, name of output file for blast to save reuslts, will be deleted out of context
    """
    def __init__(self, query_iterator, subject_iterator, dbtype: str = 'prot', dev_sample_num: int = None):
        # we have to create the temporary fasta files
        logger.debug("Creating temporary files to deposit blast inputs and outputs.")
        os.makedirs('./tmp/', exist_ok=True)
        query_temp = tempfile.NamedTemporaryFile('w', delete=False, dir='./tmp/')
        logger.debug(f"query file: {query_temp.name}")
        if dev_sample_num is not None:
            logger.debug(f"Using only max {dev_sample_num} sequences from query and subject")
        self.qt = query_temp.name
        n = 0
        for id_, seq in query_iterator:
            if seq == 'None' or seq is None:
                continue
            query_temp.write(f">{id_}\n{seq}\n")
            n +=1
            if dev_sample_num is not None:
                if n >= dev_sample_num:
                    break
        query_temp.close()
        logger.debug(f"added {n} sequences to query file")

        # folder for subject DB after we make a fasta
        subject_folder = tempfile.mkdtemp(dir='./tmp/')
        self.st = subject_folder
        subject_fasta_file = subject_folder+'/subs.fasta'
        self.subject_fasta_file = subject_fasta_file
        n = 0
        file = open(subject_fasta_file, 'w')
        for id_, seq in subject_iterator:
            if seq == 'None' or seq is None:
                continue
            file.write(f">{id_}\n{seq}\n")
            n +=1
            if dev_sample_num is not None:
                if n >= dev_sample_num:
                    break
        file.close()
        logger.debug(f"added {n} sequences to db file")

        # create db for it
        NcbimakeblastdbCommandline(dbtype=dbtype, input_file=subject_fasta_file, parse_seqids=True)()
        logger.debug(f"created database")
        # create the output xml file
        out_temp = tempfile.NamedTemporaryFile('w', delete=False, dir='./tmp')
        self.ot = out_temp.name

    def __enter__(self):
        return self.qt, self.subject_fasta_file, self.ot

    def __exit__(self, type, value, traceback):
        logger.debug("Removing temporary files used by blast")
        os.remove(self.qt)
        shutil.rmtree(self.st)
        os.remove(self.ot)

In [10]:
def csv_id_seq_iterator(csv_filepath: str, seq_col: str, index_col: str=None, id_filter: Collection = None, chunksize: int = 512, max_seq_length: int=None, **kwargs):
    """Returns a one by one iterator of seq ids and sequences to avoid OOM.
    
    Parameters
    ----------
    csv_filepath : str
        path to file containing data
    seq_col : str
        name of column containing sequences
    index_col: str, default None
        which column name is associated with the index, otherwise the 0th column will be used
    id_filter : Collection, Optional
        If given, only return sequences with the provided indexes
    chunksize : int, default 512
        Number of sequences that will be stored in memory at once.
    max_seq_length : int, default None
        Maximum length of sequence to return
    **kwargs passed to pandas read csv 
    """
    # first get the row numbers to consider
    if id_filter is not None:
        if index_col is not None:
            # get column positions to figure out which full col to load into memory
            columns = pd.read_csv(csv_filepath, nrows=1, **kwargs).columns
            index_col_position = np.argwhere(columns==index_col)[0][0]
            # load only that column
            row_indexes = pd.read_csv(csv_filepath, usecols=[index_col_position], **kwargs)
            row_indexes = pd.Series(row_indexes.set_index(index_col, drop=True).index)
        else:
            row_indexes = pd.read_csv(csv_filepath, usecols=[0]).index # just take the first column becuase we only need the indexes
            row_indexes = pd.Series(index=row_indexes, data=row_indexes)
        row_indexes_to_keep_mask = row_indexes.isin(id_filter)
        skiprows = lambda row_num: False if row_num==0 else not row_indexes_to_keep_mask.loc[row_num-1]
        logger.debug(f"{row_indexes_to_keep_mask.sum()} viable sequences in in file to iterate")
        seq_index_iterator = iter(list(row_indexes[row_indexes_to_keep_mask].values))
    else:
        skiprows=None

    for i, df_chunk in enumerate(pd.read_csv(csv_filepath, chunksize=chunksize, skiprows=skiprows, dtype=str, **kwargs)):
        if index_col is not None:
            df_chunk = df_chunk.set_index(index_col, drop=True)
        chunk = df_chunk[seq_col]
        logger.debug(f'Iterating chunk {i} seq in {csv_filepath}')
        for id_, seq in chunk.items():
            # skip long sequences
            if max_seq_length and len(seq) > max_seq_length:
                continue
            # in the case that there were no id filters, the id in the chunk corresponds to the correct sequence id
            # but if there was a filter, many rows were skipped and the indexes got jumbled, so we have to recapitulate
            # the correct seq index
            if id_filter is not None:
                yield next(seq_index_iterator), seq
            else:
                yield id_, seq


In [39]:
iter1 = csv_id_seq_iterator('s1.csv', seq_col='m_protein_seq', index_col='meso_index')
iter2 = csv_id_seq_iterator('s1.csv', seq_col='t_protein_seq', index_col='thermo_index')

In [40]:
import logging
logger = logging.getLogger(__name__)

In [41]:
with BlastFiles(iter1, iter2) as (qname, sname, oname):
    NcbiblastpCommandline(query=qname, db=sname, outfmt=5, out=oname, word_size=9, evalue=10000, max_target_seqs=100000)()
    f = open(oname, 'r')
    blast_result_records = NCBIXML.parse(f)
    record1 = next(blast_result_records)
    record2 = next(blast_result_records)
    
    metrics = BlastMetrics(record2)
    print(record2.query_length)

    for m in [
        'local_gap_compressed_percent_id',
        'scaled_local_query_percent_id',
        'scaled_local_symmetric_percent_id',
        'local_E_value',
        'query_align_start',
        'query_align_end',
        'subject_align_end',
        'subject_align_start',
        'query_align_len',
        'query_align_cov',
        'subject_align_len',
        'subject_align_cov',
        'bit_score'
    ]: 
        print(metrics.compute_metric(m))

ApplicationError: Non-zero return code 1 from 'blastp -out /home/ryfran/ValidProt/notebooks/c0-c2_exploration_plotting_sampling/tmp/tmptfofx1ll -outfmt 5 -query /home/ryfran/ValidProt/notebooks/c0-c2_exploration_plotting_sampling/tmp/tmpa6vofs8b -db ./tmp/tmp54nrgctw/subs.fasta -evalue 10000 -word_size 9 -max_target_seqs 100000', message 'BLAST query/options error: Word-size must be less than 8 for a tblastn, blastp or blastx search'