In [1]:
import os
import numpy as np
import pandas as pd
import umap
import time
from Bio import SeqIO
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.patches import Patch
import mpl_stylesheet
mpl_stylesheet.banskt_presentation(fontfamily = 'mono', fontsize = 20, colors = 'banskt', dpi = 300)

target_uniprots = ["P37840", "P04637", "P02686", "P07305", "O00488", "Q9NYB9", "P06401", "Q16186", "S6B291", "P23441"]


In [1]:
# import bio_embeddings
from bio_embeddings.embed import ProtTransT5XLU50Embedder, ESM1bEmbedder # ProtTransBertBFDEmbedder #, 
from Bio import SeqIO
import numpy as np

halft5_dir = "./models/half_prottrans_t5_xl_u50/"

embedding_data = {}
embedding_data['halft5'] = { 'dir': halft5_dir }
sel_embedding = 'halft5'

  @numba.jit()
  @numba.jit()
  @numba.jit()
  @numba.jit()


In [2]:
embedder = ProtTransT5XLU50Embedder(model_directory=embedding_data[sel_embedding]['dir'], half_model=True)

You are using the legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This means that tokens that come after special tokens will not be properly handled. We recommend you to read the related pull request available at https://github.com/huggingface/transformers/pull/24565


In [14]:
### Load sequences and annotations for disprot

def get_sequences(fastadir=None, fastafile=None):
    sequences = []
    if fastadir is None and fastafile is None:
        print("No fasta dir or file")
        raise
    if fastadir is not None and fastafile is not None:
        print("Choose one, fasta dir or multi fasta filr")
        raise
    # check for a directory with individual fasta files
    # or a multi fasta file
    if fastadir is not None:
        fastafiles = os.listdir(fastadir)
        for f in fastafiles:
            counter = 0
            for record in SeqIO.parse(os.path.join(fastadir, f), "fasta"):
                sequences.append(record)
                counter += 1
                if counter > 1:
                    print("More than one fasta record?", f)
                    raise
    elif fastafile is not None:
        for record in SeqIO.parse(fastafile, "fasta"):
            sequences.append(record)
    return sequences

# Disprot
# embeddir="/data/franco/disorder_flavours/testset/" #"/data/franco/datasets/prot_embedding_weights/disprot/halft5/"
fasta_dir = "./disprot/fasta/" #"/data/franco/datasets/disprot/fasta/"
fasta_files = [f"{u}.fasta" for u in target_uniprots]
counter = 0

# find annotation for each sequence
annotfile = "./disprot/DisProt_release_2022_06_reformat_annot.fasta" #/data/franco/datasets/disprot/
annotdir  = None
annots    = get_sequences(fastadir=annotdir, fastafile=annotfile)
sequences = get_sequences(fastadir=fasta_dir, fastafile=None)

In [41]:
### Load more detailed annotation for disprot

import json
disprot_json_file = "./disprot/DisProt_release_2022_06_with_ambiguous_evidences.json"
with open(disprot_json_file) as infmt:
    json_dict = json.load(infmt)

json_dict['data'][1]['regions'][0]
print(json_dict['data'][1]["acc"])

P49913


In [49]:
### Subset target test proteins
disprot_datadict = dict()
for i in range(len(json_dict['data'])):
    if json_dict['data'][i]["acc"] in target_uniprots:
        disprot_datadict[json_dict['data'][i]["acc"]] = json_dict['data'][i]

In [20]:
annot_dict = dict()
for record in annots:
    if "|" in record.name:
        name = record.name.split("|")[1].strip()
    else:
        name = record.name.split()[0].strip()
        if name == "":
            print("Name is empty",record.name)
    annot_dict[name] = str(record.seq)

In [31]:
### Make sure seq and annotations length match, and embed test proteins

msequences  = list()
mannots     = list()
uniprots    = list()
embeddings  = list()
counter = 0
for s in sequences:
    if "|" in s.name:
        uniprot_id = s.name.split("|")[1].strip()
    else:
        uniprot_id = s.name.split()[0].strip()
    if uniprot_id in target_uniprots:
        print(uniprot_id)
        aa_sequence = str(s.seq).upper()
        ## get the embedding
        emb = embedder.embed(aa_sequence)
        if len(aa_sequence) == len(annot_dict[uniprot_id]):
            embeddings.append(emb)
            uniprots.append(uniprot_id)
            msequences.append(aa_sequence)
            mannots.append(annot_dict[uniprot_id])
            counter += 1
        else:
            print("Embedding length and annot do not match")
            print(len(annot_dict[uniprot_id]), uniprot_id, len(aa_sequence))
print(f"Loaded {counter} proteins")

S6B291
P04637
P23441
Q16186
P06401
O00488
P02686
P37840
Q9NYB9
P07305
Loaded 10 proteins


In [52]:
for e in embeddings:
    print(e.shape)

(466, 1024)
(393, 1024)
(372, 1024)
(407, 1024)
(933, 1024)
(134, 1024)
(304, 1024)
(140, 1024)
(513, 1024)
(194, 1024)


In [36]:
### Get disorder content
def get_contents(annotations):
    disorder_contents = list()
    for da in annotations:
        contents = [ x != "-" for x in da]
        DC = np.sum(contents) / len(contents)
        disorder_contents.append(DC)
    return disorder_contents

dc = get_contents(mannots)

In [37]:
dc

[0.48497854077253216,
 0.3816793893129771,
 0.41935483870967744,
 0.48402948402948404,
 0.6334405144694534,
 0.8656716417910447,
 0.5625,
 1.0,
 0.6978557504873294,
 0.5773195876288659]

In [77]:
### Load monomer test proteins

# Monomers
fasta_file = "./disprot/monomers.fasta"
annotfile = "./disprot/monomers_annot.fasta"
annotdir  = None
mono_rannots    = get_sequences(fastadir=None, fastafile=annotfile)
mono_rsequences = get_sequences(fastadir=None, fastafile=fasta_file)

In [84]:
mono_seqs = list()
mono_annots = list()
mono_embeds = list()
mono_ids = list()

for i in range(len(mono_rsequences)):
    aa_seq = str(mono_rsequences[i].seq).upper()
    annot_seq = str(mono_rannots[i].seq)
    if len(aa_seq) == len(annot_seq):
        mono_ids.append(mono_rsequences[i].id)
        mono_seqs.append(aa_seq)
        mono_annots.append(annot_seq)
        emb = embedder.embed(aa_seq)
        mono_embeds.append(emb)
        if i>9:
            break

In [85]:
mono_ids

['1AE9A',
 '1AH7A',
 '1AHOA',
 '1AOCA',
 '1AOLA',
 '1AQZA',
 '1ATGA',
 '1ATZA',
 '1AYOB',
 '1AZOA',
 '1B9WA']

In [82]:
for e in mono_embeds:
    print(e.shape)

(179, 1024)
(245, 1024)
(64, 1024)
(175, 1024)
(228, 1024)
(149, 1024)
(231, 1024)
(189, 1024)
(130, 1024)
(232, 1024)
(95, 1024)


In [108]:
## Get PDBs and calculate Secondary structure for each

import requests
## Info about DSSP output
# https://pdb-redo.eu/dssp/about


for pdbidchain in mono_ids:
    pdbid = pdbidchain[:4]
    chain = pdbidchain[4]
    print(pdbid, chain)
    pdbfile = os.path.join("disprot","pdbs",pdbid+".pdb")
    res = requests.get(f"https://files.rcsb.org/download/{pdbid}.pdb")
    with open(pdbfile, 'w') as outfmt:
        outfmt.write(res.content.decode())

    dsspfile = os.path.join("disprot","pdbs",pdbid+".dssp")
    os.system(f"mkdssp --output-format dssp {pdbfile} {dsspfile}")

1AE9 A
1AH7 A
1AHO A
1AOC A
1AOL A
1AQZ A
1ATG A
1ATZ A
1AYO B
1AZO A
1B9W A


In [360]:
def offset_loop(this_annot, counter, offset):
    dlen = 1
    while this_annot[counter+offset+1] == "D":
        offset += 1
        dlen += 1
    ss_list = ["-" for r in range(dlen)]
    return offset, ss_list

def parseDSSP(lines, targ_chain, this_seq, this_annot, debug=False):
    flagstart = False
    ss_seq = list()
    aa_seq = list()
    offset = 0
    targ_chain = "A"
    counter = 0
    for line in lines:
        if flagstart:
            if line == "":
                break
            resnum = line[:5]
            pdbresnum = line[5:10]
            chain = line[11]
            resname = line[13]
            ss = line[16]
            end = line[14]
            if debug:
                print(f"--> counter:{counter}, offset:{offset}",resnum, pdbresnum, chain, resname, ss, "||", line[1:17])
            if end == "*":
                break
            if resname != "!":
                if chain == targ_chain:
                    if resname == this_seq[counter+offset]:
                        if ss == " ":
                            ss = "-"
                        ss_seq.append(ss)
                        aa_seq.append(resname)
                        if debug:
                            print(f"---> {resname} === {this_seq[counter+offset]}")
                    else:
                        if this_annot[counter+offset] == "D":
                            if debug:
                                print(f" {resname} != {this_seq[counter+offset]} -->OFFSET LOOP: counter:{counter}, offset:{offset}")
                            #ss_seq.append("-")
                            if this_annot[counter+offset+1] == "D" and len(this_annot) >= (counter+offset+1):
                                offset, ss_list = offset_loop(this_annot, counter, offset)
                                if debug:
                                    print(f"END OFFSET LOOP: counter:{counter}, offset:{offset}")
                                for s in ss_list:
                                    ss_seq.append(s)
                                    aa_seq.append("-")
                                if counter == 0:
                                    offset += 1
                                    if resname == this_seq[counter+offset]:
                                        if ss == " ":
                                            ss = "-"
                                        ss_seq.append(ss)
                                        aa_seq.append(resname)
                        else:
                            ss = "-"
                            ss_seq.append(ss)
                            aa_seq.append(resname)
                            if debug:
                                print(f"MISMATCH?? {resname} ?? {this_seq[counter+offset]}")
                else:
                    print("PANIC")
                    raise
            else:
                if debug:
                    print(f"resname == '!' --> counter:{counter}, offset:{offset} '-' {this_seq[counter+offset]}")
                #ss_seq.append("-")
                if this_annot[counter+offset+1] == "D" and len(this_annot) >= (counter+offset+1):
                    offset, ss_list = offset_loop(this_annot, counter, offset)
                    for s in ss_list:
                        ss_seq.append(s)
                        aa_seq.append("-")
                else:
                    print("PANIC")
                    raise           
            counter += 1
        if line.strip().startswith("#"):
            flagstart = True
            
    diff = len(ss_seq) - len(this_seq)
    if diff > 0:
        print("DSSP is longer than seq?")
    if diff < 0:
        print(f"DSSP is missing some residues? missing:{diff}")
        for i in range(np.abs(diff)):
            ss_seq.append("-")
    return ss_seq, aa_seq

In [363]:
## Parse secondary structure
# make sure sequences matches
mono_dssps = list()
for i, pdbidchain in enumerate(mono_ids):
    pdbid = pdbidchain[:4]
    dsspfile = os.path.join("disprot","pdbs",pdbid+".dssp")
    with open(dsspfile) as infmt:
        lines = infmt.readlines()
    print(pdbid, pdbidchain[4])
    mono_ss_seq, _aa_seq = parseDSSP(lines, pdbidchain[4], mono_seqs[i], mono_annots[i])
    mono_dssps.append("".join(mono_ss_seq))

1AE9 A
1AH7 A
1AHO A
1AOC A
1AOL A
DSSP is missing some residues? missing:-1
1AQZ A
1ATG A
1ATZ A
DSSP is missing some residues? missing:-2
1AYO B
1AZO A
1B9W A
DSSP is missing some residues? missing:-4


In [365]:
for i in range(11):
    #print(mono_seqs[i])
    print(f"### {mono_ids[i]}")
    #print(mono_annots[i])
    #print(mono_dssps[i])
    #print(mono_seqs[i])
    print("seqlen:", len(mono_seqs[i]),"dssp_len:", len(mono_dssps[i]))

### 1AE9A
seqlen: 179 dssp_len: 179
### 1AH7A
seqlen: 245 dssp_len: 245
### 1AHOA
seqlen: 64 dssp_len: 64
### 1AOCA
seqlen: 175 dssp_len: 175
### 1AOLA
seqlen: 228 dssp_len: 228
### 1AQZA
seqlen: 149 dssp_len: 149
### 1ATGA
seqlen: 231 dssp_len: 231
### 1ATZA
seqlen: 189 dssp_len: 189
### 1AYOB
seqlen: 130 dssp_len: 130
### 1AZOA
seqlen: 232 dssp_len: 232
### 1B9WA
seqlen: 95 dssp_len: 95


In [367]:
i = 9
pdbidchain = mono_ids[i]
pdbid = pdbidchain[:4]
print(pdbidchain)
dsspfile = os.path.join("disprot","pdbs",pdbid+".dssp")
with open(dsspfile) as infmt:
    lines = infmt.readlines()
ss_seq, _aa_seq = parseDSSP(lines, pdbidchain[4], mono_seqs[i], mono_annots[i], debug=True)

#>1AOLA 228
#QVYNITWEVTNGDRETVWAISGNHPLWTWWPVLTPDLCMLALSGPPHWGLEYQAPYSSPPGPPCCSGSSGSSAGCSRDCDEPLTSLTPRCNTAWNRLKLDQVTHKSSEGFYVCPGSHRPREAKSCGGPDSFYCASWGCETTGRVYWKPSSSWDYITVDNNLTTSQAVQVCKDNKWCNPLAIQFTNAGKQVTSWTTGHYWGLRLYVSGRDPGLTFGIRLRYQNLGPRVP
print(f"### {mono_ids[i]}")
print(mono_annots[i])
print("".join(mono_dssps[i]))
print("seqlen:", len(mono_seqs[i]),"dssp_len:", len("".join(mono_dssps[i])))

1AZOA
--> counter:0, offset:0     1     4 A P   ||    1    4 A P   
 P != G -->OFFSET LOOP: counter:0, offset:0
END OFFSET LOOP: counter:0, offset:5
--> counter:1, offset:6     2     5 A R   ||    2    5 A R   
---> R === R
--> counter:2, offset:6     3     6 A P   ||    3    6 A P   
---> P === P
--> counter:3, offset:6     4     7 A L   ||    4    7 A L   
---> L === L
--> counter:4, offset:6     5     8 A L S ||    5    8 A L  S
---> L === L
--> counter:5, offset:6     6     9 A S S ||    6    9 A S  S
---> S === S
--> counter:6, offset:6     7    10 A P   ||    7   10 A P   
---> P === P
--> counter:7, offset:6     8    11 A P   ||    8   11 A P   
---> P === P
--> counter:8, offset:6     9    12 A E S ||    9   12 A E  S
---> E === E
--> counter:9, offset:6    10    13 A T S ||   10   13 A T  S
---> T === T
--> counter:10, offset:6    11    14 A E H ||   11   14 A E  H
---> E === E
--> counter:11, offset:6    12    15 A E H ||   12   15 A E  H
---> E === E
--> counter:12, offset:6

In [352]:
import json
ssannots_list = list()
for i in range(len(mono_ids)):
    annots = list()
    pdbidchain = mono_ids[i]
    pdbid = pdbidchain[:4]
    chain = pdbidchain[4]
    print(pdbid, chain)
    res = requests.get(f"https://data.rcsb.org/rest/v1/core/polymer_entity_instance/{pdbid.upper()}/{chain.upper()}")
    datadict = json.loads(res.content.decode())
    
    for e in datadict["rcsb_polymer_instance_feature"]:
        if 'provenance_source' in e:
            if e['provenance_source'] == "PROMOTIF":
                ss_type  = e['type']
                for ee in e['feature_positions']:
                    ss_begin = ee["beg_seq_id"]
                    ss_end   = ee["end_seq_id"]
                    print(ss_type, ss_begin, ss_end)
                    annots.append([ss_type, ss_begin, ss_end])
            if e['provenance_source'] == "PDB":
                if e["type"] == "UNOBSERVED_RESIDUE_XYZ":
                    for ee in e['feature_positions']:
                        miss_begin = ee["beg_seq_id"]
                        miss_end   = ee["end_seq_id"]
                        print("MISSING", miss_begin, miss_end)
                        annots.append(["MISSING", miss_begin, miss_end])
    ssannots_list.append(annots)

1AE9 A
SHEET 48 49
SHEET 52 56
SHEET 63 67
SHEET 173 176
SHEET 70 72
SHEET 77 79
SHEET 167 168
SHEET 173 176
HELIX_P 6 18
HELIX_P 22 33
HELIX_P 37 42
HELIX_P 79 89
HELIX_P 106 120
HELIX_P 133 145
HELIX_P 148 155
UNASSIGNED_SEC_STRUCT 1 5
UNASSIGNED_SEC_STRUCT 19 21
UNASSIGNED_SEC_STRUCT 34 36
UNASSIGNED_SEC_STRUCT 43 47
UNASSIGNED_SEC_STRUCT 50 51
UNASSIGNED_SEC_STRUCT 57 62
UNASSIGNED_SEC_STRUCT 68 69
UNASSIGNED_SEC_STRUCT 73 76
UNASSIGNED_SEC_STRUCT 90 105
UNASSIGNED_SEC_STRUCT 121 132
UNASSIGNED_SEC_STRUCT 146 147
UNASSIGNED_SEC_STRUCT 156 166
UNASSIGNED_SEC_STRUCT 169 172
UNASSIGNED_SEC_STRUCT 177 179
MISSING 159 165
1AH7 A
HELIX_P 14 27
HELIX_P 35 52
HELIX_P 86 103
HELIX_P 106 122
HELIX_P 126 130
HELIX_P 141 148
HELIX_P 173 186
HELIX_P 193 204
HELIX_P 206 242
UNASSIGNED_SEC_STRUCT 1 13
UNASSIGNED_SEC_STRUCT 28 34
UNASSIGNED_SEC_STRUCT 53 85
UNASSIGNED_SEC_STRUCT 104 105
UNASSIGNED_SEC_STRUCT 123 125
UNASSIGNED_SEC_STRUCT 131 140
UNASSIGNED_SEC_STRUCT 149 172
UNASSIGNED_SEC_STRUCT 