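This notebook converts BLAST alignment results into the format used for benchmarking alignment comparisons in `pairwise-domain-benchmark.ipynb`. It extracts the aligned residue pairs from each BLAST hit, scores them against the Pfam domain annotations, and writes per-pair true/false-positive counts.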

In [None]:
from functools import reduce
from itertools import islice

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from Bio import SeqIO
%matplotlib inline

In [None]:
# NOTE: the paths below are relative to this notebook's location; adjust them
# if your local checkout is laid out differently.

data_dir = '../data/domains'
dom_file = f'{data_dir}/swissprot-pfam-domains.csv'
results_dir = '../results/alignments'

# Protein id -> sequence string.
seqs = SeqIO.parse('../data/combined.fasta', format='fasta')
seqdict = {s.id: str(s.seq) for s in seqs}

# Raw string: '\s' inside a normal string literal is an invalid escape
# sequence (SyntaxWarning on Python 3.12+).
pairs = pd.read_table(f'{data_dir}/domain_pairs.txt', header=None, sep=r'\s+')

# Domain annotations (the file's own header row is skipped).
df = pd.read_csv(dom_file, header=None, skiprows=1)
df.columns = ['protein', 'domain', 'source',
              'domain_id', 'start', 'end']
# Vectorised column arithmetic instead of a row-wise apply.
df['length'] = df['end'] - df['start']
# Protein id -> that protein's rows of `df`.
domdict = dict(list(df.groupby('protein')))


# Uncomment below to read in a different file
#blast = f'{results_dir}/blast_domain_alignments.txt'
blast = f'{results_dir}/blast/swissprot_alignments.txt'

blast_df = pd.read_table(blast, header=None)

In [None]:
# Name the 12 tab-separated fields read from the BLAST output file.
# NOTE(review): hit end ('he') is listed before hit start ('hs'); confirm this
# matches the column order written by the BLAST export.
blast_df.columns = (
    'cur.id hit.id i '
    'qs qe he hs '
    'query_s hit_s aln_s '
    'bitscore evalue'
).split()

In [None]:
# Discard every row that has a missing (NaN) value in any field.
blast_df = blast_df.dropna(axis='index', how='any')

In [None]:
# Helper functions to parse blast output
def select_f(x):
    """Keep an alignment column unless its third element is a space.

    ``x[2]`` is expected to be the alignment midline character for the
    column. Returns ``x`` unchanged when that character is not ``' '``,
    otherwise ``None``.
    """
    return None if x[2] == ' ' else x
    
def parse_hit(q_start: int, q_end: int, h_start: int, h_end: int,
              qseq: str, hseq: str, mseq: str,
              offset: bool = False) -> pd.DataFrame:
    """Extract aligned residue pairs from one BLAST alignment.

    An alignment column is kept when its midline character (``mseq``) is
    not a space (in BLAST midlines a space presumably marks a mismatch or
    a gap).

    Parameters
    ----------
    q_start, h_start : int
        Alignment start coordinates of the query and hit. Used only when
        ``offset=True``.
    q_end, h_end : int
        Alignment end coordinates. Currently unused; accepted so callers
        can pass the full coordinate set.
    qseq, hseq : str
        Gapped query / hit strings, with ``'-'`` marking a gap.
    mseq : str
        Alignment midline, one character per alignment column.
    offset : bool, default False
        False (default): positions count non-gap residues from 1 at the
        first alignment column, i.e. they are relative to the alignment
        start, not to the full sequence.
        True: positions are shifted to ``start + count - 1``, giving
        positions in the full sequence assuming 1-based start coordinates.
        NOTE(review): confirm that the values passed as ``h_start`` really
        are hit start coordinates before enabling this.

    Returns
    -------
    pd.DataFrame
        Columns ``source`` (query position) and ``target`` (hit position),
        int64, one row per kept column (possibly zero rows).
    """
    # Only columns present in all three strings are considered (as zip() would).
    n = min(len(qseq), len(hseq), len(mseq))
    q_residue = np.array([c != '-' for c in qseq[:n]], dtype=bool)
    h_residue = np.array([c != '-' for c in hseq[:n]], dtype=bool)
    keep = np.array([c != ' ' for c in mseq[:n]], dtype=bool)

    # Running count of non-gap residues = position of each column's residue.
    source = np.cumsum(q_residue).astype(np.int64)[keep]
    target = np.cumsum(h_residue).astype(np.int64)[keep]
    if offset:
        source = source + int(q_start) - 1
        target = target + int(h_start) - 1
    # Explicit int64 even when empty, so concatenating many hits does not
    # upcast the columns to object.
    return pd.DataFrame({'source': source, 'target': target})

def interval_f(x, y):
    """Return the values of ``np.arange(x, y)`` (half-open ``[x, y)``) as a set."""
    return {*np.arange(x, y)}


def blast_hits(name, group):
    """Collect the aligned residue pairs for one ``(cur.id, hit.id, i)`` group.

    Parameters
    ----------
    name : tuple
        Group key ``(prot_x, prot_y, i)`` as yielded by
        ``blast_df.groupby(['cur.id', 'hit.id', 'i'])``; ``i`` is unused here.
    group : pd.DataFrame
        The rows of ``blast_df`` for that key (presumably one row per HSP).

    Returns
    -------
    tuple
        ``(prot_x, prot_y, edges)`` where ``edges`` concatenates the
        ``source``/``target`` frames returned by ``parse_hit`` for each row.
    """
    prot_x, prot_y, _ = name
    # Read the rows of *this* group (not a notebook-level variable) so each
    # pair is scored from its own alignment.
    frames = [
        parse_hit(qs, qe, hs, he, query_s, hit_s, aln_s)
        for qs, qe, hs, he, query_s, hit_s, aln_s in zip(
            group['qs'], group['qe'], group['hs'], group['he'],
            group['query_s'], group['hit_s'], group['aln_s'],
        )
    ]
    if not frames:
        # pd.concat([]) raises; return an empty edge table instead.
        return prot_x, prot_y, pd.DataFrame(columns=['source', 'target'])
    return prot_x, prot_y, pd.concat(frames, axis=0)


def domain_table(seq, dom):
    """
    Build a per-residue domain-membership table.

    Parameters
    ----------
    seq : str
        Sequence; only its length is used (via ``str(seq)``).
    dom : pd.DataFrame
        Domain table with ``domain``, ``start`` and ``end`` columns.

    Returns
    -------
    pd.DataFrame
        Boolean frame indexed by residue position ``0 .. len(seq) - 1`` with
        one column per distinct domain name (in order of first appearance).
        A cell is True when the position lies strictly between that domain's
        ``start`` and ``end``.

    Notes
    -----
    Only the first row of a repeated domain name is used; later copies of
    the same domain are ignored. NOTE(review): confirm that ignoring repeats
    is intended, and that excluding both ``start`` and ``end`` matches the
    coordinate convention of the domain file.
    """
    positions = np.arange(len(str(seq)))
    # First occurrence of each domain name, in original row order.
    first = dom.drop_duplicates(subset='domain', keep='first')
    # One vectorised strict-interval mask per domain instead of a Python
    # call per residue and a repeated .loc lookup per domain.
    membership = {
        name: (positions > start) & (positions < end)
        for name, start, end in zip(first['domain'], first['start'], first['end'])
    }
    return pd.DataFrame(membership)

def domain_score(edges, seq1, dom1, seq2, dom2):
    """
    Count per-domain true/false positive residue pairs for one alignment.

    Parameters
    ----------
    edges : pd.DataFrame
        Aligned residue pairs with columns ``source`` (position in
        sequence 1) and ``target`` (position in sequence 2).
    seq1 : str
        Sequence 1.
    dom1 : pd.DataFrame
        Domain table for sequence 1 (``domain``, ``start``, ``end`` columns).
    seq2 : str
        Sequence 2.
    dom2 : pd.DataFrame
        Domain table for sequence 2.

    Returns
    -------
    res : pd.DataFrame
        One row per domain name shared by both tables (sorted by name) with
        columns ``tp`` (pairs where both residues lie in that domain) and
        ``fp`` (pairs where exactly one residue does).
    """
    df1 = domain_table(seq1, dom1)
    df2 = domain_table(seq2, dom2)
    # Attach domain membership for both ends of each edge with two chained
    # merges. Merging two per-side frames on (source, target) instead would
    # pair every duplicate edge with every other copy (k copies -> k*k rows).
    # Edges whose positions are missing from a table are dropped (inner join).
    resdf = (
        edges
        .merge(df1, left_on='source', right_index=True)
        .merge(df2, left_on='target', right_index=True, suffixes=('_x', '_y'))
    )
    # Sorted so the output row order is deterministic across runs.
    cols = sorted(set(dom1.domain.values) & set(dom2.domain.values))

    tps, fps = [], []
    for col in cols:
        in_x = resdf[col + '_x'].values
        in_y = resdf[col + '_y'].values
        tps.append(np.sum(np.logical_and(in_x, in_y)))
        fps.append(np.sum(np.logical_xor(in_x, in_y)))

    return pd.DataFrame({'tp': tps, 'fp': fps}, index=cols)

def is_interval(start, end, x):
    """Return True when ``x`` lies strictly between ``start`` and ``end`` (both excluded)."""
    return bool(start < x < end)



Group the BLAST alignments by query ID (`cur.id`), hit ID (`hit.id`) and `i`.

In [None]:
# One group per (query id, hit id, i) combination.
blast_groups = blast_df.groupby(by=['cur.id', 'hit.id', 'i'])

In [None]:
# (query id, hit id, residue-pair edges) for every group.
res = [blast_hits(key, rows) for key, rows in blast_groups]

In [None]:
def score_group(group):
    """Score one ``(prot_x, prot_y, edges)`` tuple.

    Looks up both proteins' domain tables in ``domdict`` and sequences in
    ``seqdict``, runs ``domain_score`` and sums its per-domain counts.

    Returns
    -------
    tuple
        ``(prot_x, prot_y, tp, fp)``.
    """
    query_id, hit_id, edges = group
    dom_q, dom_h = domdict[query_id], domdict[hit_id]
    seq_q, seq_h = seqdict[query_id], seqdict[hit_id]
    scores = domain_score(edges, seq_q, dom_q, seq_h, dom_h)
    return query_id, hit_id, scores['tp'].sum(), scores['fp'].sum()

In [None]:
# (query id, hit id, tp, fp) for every BLAST group.
stats = [score_group(hit) for hit in res]

In [None]:
# Quick look at the first scored pairs. `blast_stats` is only created in a
# later cell, so build the preview from `stats` to keep Restart & Run All working.
pd.DataFrame(stats).head()

In [None]:
# Columns: 0 = query id, 1 = hit id, 2 = true positives, 3 = false positives.
blast_stats = pd.DataFrame.from_records(stats)

In [None]:
# Write one tab-separated line per (query, hit, tp, fp), without header or
# index, under the results folder configured in `results_dir`.
blast_stats.to_csv(f'{results_dir}/blast/blast_alignment_scores.txt',
                   sep='\t', header=False, index=False)

In [None]:
def max_score(prot_x, prot_y, domains=None):
    """Upper bound on matched domain residues between two proteins.

    For every domain name present in both proteins' tables, take the largest
    ``length`` recorded for it in each protein, then the smaller of those
    two. The result is the sum over all shared domains.

    Parameters
    ----------
    prot_x, prot_y : str
        Protein identifiers.
    domains : dict, optional
        Mapping protein id -> domain table with ``domain`` and ``length``
        columns. Defaults to the notebook-level ``domdict``.

    Returns
    -------
    numpy integer
        Sum of the per-domain minima (0 when no domain is shared).
    """
    if domains is None:
        domains = domdict
    # Only `length` is needed; grouping just that column avoids taking the
    # max of unrelated (string) columns.
    len_x = domains[prot_x].groupby('domain')['length'].max()
    len_y = domains[prot_y].groupby('domain')['length'].max()
    # .loc needs a list here: pandas >= 2.0 raises TypeError for set indexers.
    shared = sorted(set(len_x.index) & set(len_y.index))
    return np.sum(np.minimum(len_x.loc[shared].to_numpy(),
                             len_y.loc[shared].to_numpy()))

In [None]:
# Per-pair upper bound on matched domain residues (columns 0/1 are the ids).
blast_stats.apply(lambda row: max_score(row.iloc[0], row.iloc[1]), axis='columns')

In [None]:
# Keep the first (query id, hit id, edges) tuple from `res` for inspection.
r= res[0]

In [None]:
# Display `r`, the result tuple most recently assigned to it.
r

In [None]:
# `qseq` is not defined in the visible notebook, so read the first aligned
# query string directly from the BLAST table.
blast_df['query_s'].values[0]

In [None]:
# `x` is not defined at this point; show the query start of the first BLAST row.
int(blast_df['qs'].iloc[0])

In [None]:
# Number of (query id, hit id, i) groups; the GroupBy object's own repr only
# shows its type.
blast_groups.ngroups

In [None]:
# Take the third (key, rows) group for inspection without materialising every
# group into a list first.
r, g = next(islice(blast_groups, 2, None))

In [None]:
# Display the rows of `blast_df` belonging to the group selected above.
g

In [None]:
# How many rows carry each value of the `i` column.
blast_df['i'].value_counts()

In [None]:
# `x` is not defined in the visible notebook; inspect the aligned query
# string of the first row of the group `g` selected above.
g['query_s'].values[0]

In [None]:
# `x` is not defined in the visible notebook; inspect the aligned hit
# string of the first row of the group `g` selected above.
g['hit_s'].values[0]

In [None]:
# `x` is not defined in the visible notebook; inspect the alignment
# midline of the first row of the group `g` selected above.
g['aln_s'].values[0]