In [1]:
import os
import numpy as np
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import pandas as pd
from Bio import SeqIO
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import cm
from matplotlib.patches import Patch
import mpl_stylesheet
import re
import gc
# Configure plotting via the project-local `mpl_stylesheet` helper: monospace font,
# font size 20, the 'banskt' colour set and 300 dpi.
# NOTE(review): presumably this sets global matplotlib defaults for every later figure — confirm.
mpl_stylesheet.banskt_presentation(fontfamily = 'mono', fontsize = 20, colors = 'banskt', dpi = 300)

In [2]:
# Load the data
import json
import collections

# Build {record id: {'seq': ..., 'disorder': ...}} from two FASTA files.
# Each file contributes one field; records are keyed by their FASTA id.
datadict = collections.defaultdict(dict)
fasta_sources = (
    ("disprot_OK_fullset.fasta", 'seq'),
    ("disprot_OK_fullset_annotations.fasta", 'disorder'),
)
for fasta_path, field in fasta_sources:
    for entry in SeqIO.parse(fasta_path, "fasta"):
        datadict[entry.id][field] = str(entry.seq)

print(f"Loaded {len(datadict)} proteins")

Loaded 2145 proteins


In [3]:
# Residue letter <-> integer code lookup tables.
# The alphabet is the sorted set of distinct letters found in one reference protein (P19793).
# NOTE(review): this presumes that sequence contains exactly the 20 standard residues — confirm.
reference_alphabet = sorted(set(datadict["P19793"]['seq']))
residue_codes = np.arange(20)
AA2num = dict(zip(reference_alphabet, residue_codes))
num2AA = dict(zip(residue_codes, reference_alphabet))
# Two extra symbols are appended after the residue codes.
for extra_code, extra_symbol in ((20, '</s>'), (21, 'X')):
    AA2num[extra_symbol] = extra_code
    num2AA[extra_code] = extra_symbol
AAsize = len(AA2num)

# Attach the "aamask_1" results from loss/<id>.json to each protein;
# proteins without such a file are dropped from datadict.
uniprots = list(datadict.keys())   # snapshot of the keys: datadict is modified inside the loop
delcounter = 0
for unip in uniprots:
    loss_path = f"loss/{unip}.json"
    if not os.path.exists(loss_path):
        del datadict[unip]
        delcounter += 1
        continue
    with open(loss_path) as handle:
        loss_record = json.load(handle)
    datadict[unip]["aamask_1"] = loss_record[unip]["aamask_1"]

print(f"Loaded {len(datadict)} proteins")

Loaded 2145 proteins


In [4]:
## Tally residue occurrences separately for ordered and disordered positions
order_aa_counts = np.zeros(AAsize)
disorder_aa_counts = np.zeros(AAsize)
for unip in datadict:
    residues = datadict[unip]['seq']
    # A position is disordered when its annotation character is anything but "-"
    is_disordered = [flag != "-" for flag in datadict[unip]['disorder']]
    for idx, letter in enumerate(residues):
        # Letters missing from the lookup table are tallied under "X"
        code = AA2num[letter] if letter in AA2num else AA2num["X"]
        bucket = disorder_aa_counts if is_disordered[idx] else order_aa_counts
        bucket[code] += 1

# Dataset-wide totals and axis labels (one label per residue code)
total_aa_counts = order_aa_counts + disorder_aa_counts
total_aa = total_aa_counts.sum()
total_ordered_aa = order_aa_counts.sum()
total_disordered_aa = disorder_aa_counts.sum()
labels = [num2AA[code] for code in range(len(order_aa_counts))]

for total in (total_aa, total_ordered_aa, total_disordered_aa):
    print(total)

827664.0
639673.0
187991.0


In [5]:
# Accumulators for the single-mask ("aamask_1") results.
# Every length-AAsize vector and AAsize x AAsize matrix is indexed with AA2num codes.
match_counts         = np.zeros(AAsize)     # nº of matches
mismatch_counts_aa   = np.zeros(AAsize)     # nº of total mismatches, could be more than 1 mismatch per masked aa
mismatch_in_disorder = np.zeros(AAsize)
mismatch_in_order    = np.zeros(AAsize)
match_in_disorder = np.zeros(AAsize)
match_in_order    = np.zeros(AAsize)

mismatch_mask_counts = np.zeros(AAsize)     # nº of masked positions with a mismatch list, binned by the masked aa (one count per masked embedding)
# Matrices: row = code of mismatch record field [2], column = code of field [1].
# NOTE(review): which of the two fields is the true residue and which the predicted one
# is not visible here — confirm against the code that writes loss/<id>.json.
mismatch_matrix             = np.zeros((AAsize,AAsize))
mismatch_matrix_in_disorder = np.zeros((AAsize,AAsize))
mismatch_matrix_in_order    = np.zeros((AAsize,AAsize))
mismatch_matrix_in_disorder_onsite = np.zeros((AAsize,AAsize))
mismatch_matrix_in_order_onsite    = np.zeros((AAsize,AAsize))
mismatch_posdiff = list()        # distance of the masked aa to the mismatched position (can separate long and short)
mismatch_posdiff_detail = list() # uniprot, masked aa pos, mismatch pos, prot length
n_pos_mismatch = collections.defaultdict(int) # nº of times a position was mismatched (on different masked aa)
n_mismatch_pos = collections.defaultdict(int) # nº of mismatches encountered for a single masked aa
loss_match    = list()
loss_mismatch = list()
loss_mismatch_in_order = list()
loss_mismatch_in_disorder = list()
onsm_offsmm = list() #on-site match but off-site mismatch

# Residue letters missing from AA2num are tallied under 'X', so the per-position
# lookup below cannot raise KeyError.
unknown_code = AA2num["X"]

for unip in datadict:
    this_seq = datadict[unip]['seq']
    L = len(this_seq)
    diso_ix = [flag != "-" for flag in datadict[unip]['disorder']]
    loss = np.array(datadict[unip]['aamask_1']['loss'])
    # One entry per masked position: either True (match) or a list of mismatch
    # records, each indexed as [position, field 1, field 2].
    matches = datadict[unip]['aamask_1']['match']
    seq_pos = np.arange(len(matches))
    for pos,e in enumerate(matches):
        masked_code = AA2num.get(this_seq[pos], unknown_code)
        if e == True:
            match_counts[masked_code] += 1
            loss_match.append([loss[pos], L])
            if diso_ix[pos]: # in disordered region
                match_in_disorder[masked_code] += 1
            else:
                match_in_order[masked_code] += 1
        else:
            # Count this masked position once, in the bin of the masked residue.
            # (A bare `mismatch_mask_counts += 1` would add 1 to every element.)
            mismatch_mask_counts[masked_code] += 1
            loss_mismatch.append([loss[pos], L])
            n_mismatch_pos[f"{unip}_{pos}"] += len(e)
            this_pos_found = False   # becomes True if the masked position itself is among the mismatches
            for mismatch_aa in e:
                mm_pos = mismatch_aa[0]
                mismatch_posdiff.append(mm_pos - pos)
                mismatch_posdiff_detail.append([unip,pos,mm_pos,L])
                row = AA2num[mismatch_aa[2]]
                col = AA2num[mismatch_aa[1]]
                mismatch_counts_aa[row] += 1
                n_pos_mismatch[f"{unip}_{mm_pos}"] += 1
                mismatch_matrix[row,col] += 1
                # Order/disorder is decided by the mismatched position, not the masked one.
                if diso_ix[mm_pos]:
                    mismatch_matrix_in_disorder[row,col] += 1
                    mismatch_in_disorder[row] += 1
                    if pos == mm_pos:
                        mismatch_matrix_in_disorder_onsite[row,col] += 1
                        # NOTE(review): stores loss*L, whereas loss_match/loss_mismatch
                        # store [loss, L] pairs — confirm this is intended.
                        loss_mismatch_in_disorder.append(loss[pos]* L)
                        this_pos_found = True
                else:
                    mismatch_matrix_in_order[row,col] += 1
                    mismatch_in_order[row] += 1
                    if pos == mm_pos:
                        loss_mismatch_in_order.append(loss[pos]* L)
                        mismatch_matrix_in_order_onsite[row,col] += 1
                        this_pos_found = True
            if not this_pos_found:
                ## on-site match but off-site mismatch!!
                # The masked position is counted as a match here.
                # NOTE(review): it has also already been added to mismatch_mask_counts
                # and loss_mismatch above — confirm that double bookkeeping is intended.
                match_counts[masked_code] += 1
                if diso_ix[pos]:
                    match_in_disorder[masked_code] += 1
                else:
                    match_in_order[masked_code] += 1
                onsm_offsmm.append([unip,pos])

In [6]:
# Inspect the [uniprot id, masked position] pairs recorded as on-site match / off-site mismatch.
# NOTE(review): this displays the whole list; consider a slice such as onsm_offsmm[:10] to keep the output short.
onsm_offsmm

[['P03347', 28],
 ['P03347', 74],
 ['P03347', 92],
 ['P03347', 122],
 ['P03347', 136],
 ['P03347', 157],
 ['P03347', 213],
 ['P03347', 278],
 ['P03347', 348],
 ['P03347', 371],
 ['P03347', 373],
 ['P03347', 378],
 ['P03347', 382],
 ['P03347', 387],
 ['P03347', 399],
 ['P03347', 412],
 ['P03347', 439],
 ['P03347', 458],
 ['P03347', 469],
 ['P03347', 475],
 ['P03347', 477],
 ['P03347', 483],
 ['P03347', 497],
 ['P03347', 509],
 ['P03347', 510],
 ['P19793', 0],
 ['P19793', 6],
 ['P19793', 8],
 ['P19793', 11],
 ['P19793', 16],
 ['P19793', 18],
 ['P19793', 21],
 ['P19793', 32],
 ['P19793', 39],
 ['P19793', 42],
 ['P19793', 43],
 ['P19793', 50],
 ['P19793', 86],
 ['P19793', 102],
 ['P19793', 133],
 ['P19793', 232],
 ['P19793', 248],
 ['P19793', 257],
 ['P19793', 270],
 ['P19793', 291],
 ['P19793', 402],
 ['P19793', 435],
 ['P19793', 447],
 ['P19793', 456],
 ['P19793', 457],
 ['P19793', 458],
 ['P19793', 459],
 ['P19793', 460],
 ['P48439', 0],
 ['P48439', 3],
 ['P48439', 5],
 ['P48439', 8],
 

In [7]:
# Per-residue counts of "on-site match / off-site mismatch" positions,
# split by whether the masked position is annotated as disordered.
onsm_offsmm_disordered_match = np.zeros(AAsize)
onsm_offsmm_ordered_match = np.zeros(AAsize)

for unip, pos in onsm_offsmm:
    this_seq = datadict[unip]['seq']
    L = len(this_seq)
    # Any annotation character other than "-" marks a disordered position
    in_disorder = datadict[unip]['disorder'][pos] != "-"
    residue_code = AA2num[this_seq[pos]]
    tally = onsm_offsmm_disordered_match if in_disorder else onsm_offsmm_ordered_match
    tally[residue_code] += 1

In [13]:
# Match rates are normalised by the number of ordered / disordered residues in the
# dataset. An earlier version divided by the total number of matches
# (np.sum(match_counts)), which is wrong because it ignores the imbalance between
# ordered and disordered positions in the dataset.
print(f"Matches in order: {np.sum(match_in_order)/total_ordered_aa}")
print(f"Matches in disorder: {np.sum(match_in_disorder)/total_disordered_aa}")
print(f"Ordered positions: {total_ordered_aa/total_aa}, Disordered positions: {total_disordered_aa/total_aa}")
# The two fractions below are expressed relative to the total number of matches.
print(f"on-site match in order but off-site mismatch : {np.sum(onsm_offsmm_ordered_match)/np.sum(match_counts)}")
print(f"on-site match in disorder but off-site mismatch: {np.sum(onsm_offsmm_disordered_match)/np.sum(match_counts)}")

Matches in order: 0.6733143340425498
Matches in disorder: 0.5367278220765888
Ordered positions: 0.7728655589707901, Disordered positions: 0.22713444102920993
on-site match in order but off-site mismatch : 0.19537585519967043
on-site match in disorder but off-site mismatch: 0.058184615905538174


In [14]:
# Sanity check: the ordered and disordered match tallies add up to the overall match tally.
matches_by_region = match_in_order.sum() + match_in_disorder.sum()
assert matches_by_region == match_counts.sum()

In [15]:
# Totals: all matches, matches in order, matches in disorder,
# and the number of on-site-match / off-site-mismatch positions.
for summary_value in (
    match_counts.sum(),
    match_in_order.sum(),
    match_in_disorder.sum(),
    len(onsm_offsmm),
):
    print(summary_value)

531601.0
430701.0
100900.0
134793


In [None]:
## Paint matches and mismatches in the residue-level UMAP