In [None]:
import json
import requests
import pandas as pd
import numpy as np
import mdtraj as md
import nglview as nv
import os

In [None]:
# Service endpoints used throughout this notebook.

# AlphaFind similarity-search API (the host name points at the staging deployment)
api = 'https://api.stage.alphafind-ted.dyn.cloud.e-infra.cz/search'

# TED REST API and its file download root
tedapi = 'https://ted.cathdb.info/api/v1'
tedfiles = f'{tedapi}/files'

# AlphaFold Protein Structure Database and its file download root
af = 'https://alphafold.ebi.ac.uk'
affiles = f'{af}/files'

In [None]:
# Scratch directory for downloaded structures: $TMPDIR when set, otherwise /tmp
tmp = os.getenv('TMPDIR', '/tmp')
tmp

In [None]:
# Maximum number of AlphaFind hits requested per query domain
domain_limit = 20

In [None]:
# Sample query: an AlphaFold DB model id whose second '-'-separated field
# is the UniProt accession (that is how get_domains reads it)
query = 'AF-A0A7L0KP91-F1-model_v4'

In [None]:
# retrieve domains from TED database (just the residue intervals, for the time being)
def get_domains(qry, timeout=60):
    """Fetch the TED domain boundaries of an AlphaFold DB model.

    Parameters
    ----------
    qry : str
        AlphaFold DB model id such as 'AF-A0A7L0KP91-F1-model_v4'; the
        second '-'-separated field is used as the UniProt accession.
    timeout : float, optional
        Seconds to wait for the TED server before giving up (default 60).

    Returns
    -------
    dict
        Maps the last '_'-separated part of each ``ted_id`` (e.g. 'TED01')
        to a list of ``[start, end]`` string pairs, one per segment of the
        domain's ``chopping`` string (segments separated by '_', bounds by '-').

    Raises
    ------
    requests.HTTPError
        If the TED API answers with an error status.
    """
    uniprot = qry.split('-')[1]
    response = requests.get(f'{tedapi}/uniprot/summary/{uniprot}', timeout=timeout)
    # fail loudly instead of trying to parse an error page as JSON
    response.raise_for_status()
    return {
        d['ted_id'].split('_')[-1]:
        [segment.split('-') for segment in d['chopping'].split('_')]
        for d in response.json()['data']
    }

In [None]:
# TED domain boundaries of the query, keyed by domain suffix ('TED01', ...)
qchops = get_domains(qry=query)

In [None]:
# number of TED domains in the query (one AlphaFind search per domain below)
qdomains = len(qchops.keys())
qdomains

In [None]:
# Query AlphaFind once per query domain (TED01 .. TEDnn).
# The server works asynchronously: re-run this cell until every query domain
# returns close to {domain_limit} hits (not all or mostly zeros).
def _alphafind_hits(domain, timeout=60):
    """Return AlphaFind's 'results' list for query domain number `domain`.

    Raises requests.HTTPError when the server answers with an error status,
    instead of failing later on a non-JSON body or a missing 'results' key.
    """
    response = requests.get(
        api,
        # let requests encode the query string instead of pasting it by hand
        params={
            'query': f'{query}_TED{domain:02d}',
            'limit': domain_limit,
            'superposition': True,
        },
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()['results']

bag = [_alphafind_hits(domain) for domain in range(1, qdomains + 1)]
[len(b) for b in bag]

In [None]:
# Have you really run the previous cell several times?
# Getting fewer than {qdomains * domain_limit} hits is fine if you know what's going on; relax the condition then.
assert sum(len(b) for b in bag) == qdomains * domain_limit, (
    f'AlphaFind returned {[len(b) for b in bag]} hits per query domain, '
    f'expected {domain_limit} each; re-run the previous cell or relax this check'
)

In [None]:
# Arrange the hits into a dict keyed by target structure. For every target,
# map each matched target domain ('TEDnn') to the query domain it matched
# plus the server-reported scores and superposition transform.
# If one target domain is hit by several query domains, keep the pair with the
# highest TM-score instead of whichever query domain happened to come last.
tdom = {}
for i, hits in enumerate(bag):
    # bag[i] holds the hits of query domain TED{i+1:02d} (see the AlphaFind cell)
    query_domain = f'TED{i+1:02d}'
    for hit in hits:
        # object_id is '<target structure>_<target domain>'
        target, domain = hit['object_id'].rsplit('_', 1)
        match = {
            'query_domain': query_domain,
            'tm_score': hit['tm_score'],
            'rmsd': hit['rmsd'],
            'translation_vector': np.array(hit['translation_vector']),
            'rotation_matrix': np.array(hit['rotation_matrix']),
        }
        domains = tdom.setdefault(target, {})
        if domain not in domains or match['tm_score'] > domains[domain]['tm_score']:
            domains[domain] = match

In [None]:
# Fetch TED domain boundaries for every target that had a hit; the number of
# domains per target is the normaliser of the target-side score below.
tchops = dict((t, get_domains(t)) for t in tdom)
tdomains = {t: len(chops) for t, chops in tchops.items()}

In [None]:
# Target-side bag-of-domains score: each matched domain pair contributes
# 1 / (1 + rmsd^2) (computed from the server's rmsd, not its tm_score), and the
# sum is scaled by 1 / (number of TED domains in the target), so it reflects
# how much of the target is shared with the query.
bag_tm_t = {
    t: 1. / tdomains[t] * sum(1. / (1. + pair['rmsd'] ** 2) for pair in doms.values())
    for t, doms in tdom.items()
}

In [None]:
# Query-side bag-of-domains score: same per-pair contribution 1 / (1 + rmsd^2),
# but scaled by 1 / (number of TED domains in the query), so it reflects how
# much of the query is shared with each target.
bag_tm_q = {
    t: 1. / qdomains * sum(1. / (1. + pair['rmsd'] ** 2) for pair in doms.values())
    for t, doms in tdom.items()
}

In [None]:
# One row per target structure, best query-side score first.
result = (
    pd.DataFrame(
        [
            {
                'target': t,
                'target tm score': bag_tm_t[t],
                'query tm score': bag_tm_q[t],
                'target domains #': tdomains[t],
                'common domains #': len(tdom[t]),
                # 'TEDqq-TEDtt' pairs: query domain, then matched target domain
                'domain pairs': ','.join(f'{m["query_domain"]}-{d}' for d, m in tdom[t].items()),
            }
            for t in tdom
        ],
        columns=['target', 'target tm score', 'query tm score',
                 'target domains #', 'common domains #', 'domain pairs'],
    )
    .set_index('target')
    .sort_values('query tm score', ascending=False)
)
result

In [None]:
# From the table above, pick the target structure and the target domain to align the query to.
target = (
    'AF-A0A850XH80-F1-model_v4',  # target structure (a row index of the table)
    'TED06',                      # target domain to superpose
)
tdom[target[0]][target[1]]

In [None]:
# Download the whole query structure from AlphaFold DB into the scratch directory.
qpdb = f'{tmp}/{query}.pdb'
r = requests.get(f'{affiles}/{query}.pdb', timeout=60)
# Without this check an HTTP error page would be saved as a .pdb file and only
# surface later as an obscure parse error in md.load_pdb.
r.raise_for_status()
with open(qpdb, 'wb') as p:
    p.write(r.content)

In [None]:
# Download the chosen target structure from AlphaFold DB as well.
tpdb = f'{tmp}/{target[0]}.pdb'
r = requests.get(f'{affiles}/{target[0]}.pdb', timeout=60)
# Without this check an HTTP error page would be saved as a .pdb file and only
# surface later as an obscure parse error in md.load_pdb.
r.raise_for_status()
with open(tpdb, 'wb') as p:
    p.write(r.content)

In [None]:
# Superposition (rotation matrix + translation vector) reported by AlphaFind
# for the chosen (target structure, target domain) pair.
rot, trans = (
    tdom[target[0]][target[1]][key]
    for key in ('rotation_matrix', 'translation_vector')
)
rot, trans

In [None]:
# Parse the downloaded query and target PDB files with mdtraj.
qt, tt = md.load_pdb(qpdb), md.load_pdb(tpdb)

In [None]:
# Apply the AlphaFind superposition to the target coordinates.
# mdtraj stores coordinates in nm while the server's translation is in Angstrom,
# hence the division by 10 (the rotation matrix is unit-free).
# NOTE(review): the `xyz @ rot - trans` convention is kept from the original
# notebook; confirm it against the AlphaFind API documentation.
# The target is re-read from disk so that re-running this cell does not apply
# the transform a second time to already-moved coordinates.
tt = md.load_pdb(tpdb)
tt.xyz = tt.xyz @ rot - trans / 10.

In [None]:
# Enjoy: the query is drawn in red shades and the target in green shades;
# the aligned domain of each is highlighted in the saturated colour.

def _chop_selection(chop):
    """Join TED [start, end] residue intervals into a selection string 'a-b, c-d'."""
    return ', '.join('-'.join(interval) for interval in chop)

v = nv.NGLWidget()
qc = v.add_component(qt)
tc = v.add_component(tt)

# query: whole chain pale, matched query domain in red
qc.clear()
qc.add_representation('ribbon', color='#ffe0e0')
qd = tdom[target[0]][target[1]]['query_domain']
c = qchops[qd]
qc.add_representation('ribbon', color='#ff0000', selection=_chop_selection(c))

# target: whole chain pale, chosen target domain in green
tc.clear()
tc.add_representation('ribbon', color='#e0ffe0')
c = tchops[target[0]][target[1]]
tc.add_representation('ribbon', color='#00ff00', selection=_chop_selection(c))
v