In [1]:
import os, pickle, tempfile
import pandas as pd

# Data

In [2]:
# Root of the extra set (features.pkl, news_sites.pkl, pockets/, origcifs/), relative to this notebook.
extra_set_path = "../../training_data/7.Extra_set"

In [3]:
# Per-entry feature tables of the extra set, keyed by PDB id (e.g. '7gqu').
# NOTE: pickle.load can execute arbitrary code -- only load locally produced files.
features_file = f"{extra_set_path}/features.pkl"
with open(features_file, mode="rb") as handle:
    news = pickle.load(handle)

len(news), news

(9,
 {'7gqu':     Residues                                                          \
           pdb label_entity_id label_asym_id label_seq_id auth_asym_id   
  0       7gqu               1             A           12            A   
  1       7gqu               1             A           13            A   
  2       7gqu               1             A           14            A   
  3       7gqu               1             A           15            A   
  4       7gqu               1             A           16            A   
  ..       ...             ...           ...          ...          ...   
  414     7gqu               1             A          426            A   
  415     7gqu               1             A          427            A   
  416     7gqu               1             A          428            A   
  417     7gqu               1             A          429            A   
  418     7gqu               1             A          430            A   
  
                       

In [4]:
# Annotated sites per entry (each site is a dict with "mod" and "site" frames),
# restricted to the entries present in `news`.
with open(f"{extra_set_path}/news_sites.pkl", mode="rb") as handle:
    all_news_sites = pickle.load(handle)
    news_sites = {pdb: sites for pdb, sites in all_news_sites.items() if pdb in news}
del all_news_sites

len(news_sites), news_sites

(9,
 {'7gqu': [{'mod':      label_comp_id label_asym_id label_entity_id label_seq_id  \
    3420           X1L             D               4            .   
    
         pdbx_PDB_ins_code auth_seq_id auth_comp_id auth_asym_id  \
    3420                 ?        1002          X1L            A   
    
         pdbx_PDB_model_num pdbx_label_index pdbx_sifts_xref_db_name  \
    3420                  1             1002                       ?   
    
         pdbx_sifts_xref_db_acc pdbx_sifts_xref_db_num pdbx_sifts_xref_db_res  
    3420                      ?                      ?                      ?  ,
    'site':    label_comp_id label_asym_id label_entity_id label_seq_id pdbx_PDB_ins_code  \
    0            VAL             A               1           55                 ?   
    1            MET             A               1           56                 ?   
    2            ALA             A               1           57                 ?   
    3            THR             A     

In [5]:
# Downstream code assumes exactly one annotated site per entry (view_pockets only
# reads news_sites[pdb][0]) and that every entry in `news` has sites at all.
assert all(len(sites) == 1 for sites in news_sites.values()), "Not all apos have a single annotated site"
entries_without_sites = sorted(set(news) - set(news_sites))
assert not entries_without_sites, f"Entries without annotated sites: {entries_without_sites}"

In [6]:
# Per-model prediction results; e.g. models["model5"]["results"][pdb][pocket]
# holds that pocket's prob/pred/label/overlap values.
models_pickle = "models.pkl"
models = pd.read_pickle(models_pickle)

models

{'model5': {'results': {'7gqu': {'pocket11': {'prob': 0.0006499344599433243,
     'pred': 0,
     'label': 0,
     'max_overlap': 0.16666666666666666,
     'pocket_in_site': 0.16666666666666666,
     'site_in_pocket': 0.06896551724137931},
    'pocket15': {'prob': 1.6904914446058683e-05,
     'pred': 0,
     'label': 0,
     'max_overlap': 0.0,
     'pocket_in_site': 0.0,
     'site_in_pocket': 0.0},
    'pocket7': {'prob': 1.0997085553299257e-07,
     'pred': 0,
     'label': 0,
     'max_overlap': 0.0,
     'pocket_in_site': 0.0,
     'site_in_pocket': 0.0},
    'pocket13': {'prob': 0.0004982067039236426,
     'pred': 0,
     'label': 0,
     'max_overlap': 0.0,
     'pocket_in_site': 0.0,
     'site_in_pocket': 0.0},
    'pocket14': {'prob': 2.0316012523835525e-05,
     'pred': 0,
     'label': 0,
     'max_overlap': 0.0,
     'pocket_in_site': 0.0,
     'site_in_pocket': 0.0},
    'pocket1': {'prob': 0.932378888130188,
     'pred': 1,
     'label': 1,
     'max_overlap': 0.89655172

In [7]:
# Plot colours as lower-case hex. The values match the Okabe-Ito colour-blind-safe
# palette's vermillion, bluish green and blue.
colors = dict(
    orange="#d55e00",
    green="#009e73",
    blue="#0072b2",
)

# Labelling

In [8]:
def get_overlap(one, other):
    """Fraction of the rows (residues) of ``one`` that also appear in ``other``.

    Rows are matched on every column the two frames share (the default inner join
    of ``DataFrame.merge``). ``other`` is de-duplicated on those shared columns first,
    so a residue listed twice in ``other`` cannot be counted twice and push the
    fraction above 1.

    Args:
        one: residues whose coverage is measured.
        other: residues to look them up in.

    Returns:
        A float in [0, 1]. Returns 0.0 for an empty ``one`` instead of raising
        ZeroDivisionError.
    """
    if len(one) == 0:
        return 0.0
    shared_cols = [col for col in one.columns if col in other.columns]
    # merge() still infers the join keys itself (so no shared columns -> MergeError, as before)
    matches = one.merge(other[shared_cols].drop_duplicates())
    return len(matches) / len(one)

def get_overlaps(pdb, pocketd):
    """Overlap fractions between a pocket and the annotated site(s) of ``pdb``.

    Returns a dict with:
      * ``"pocket_in_site"``: share of the pocket's residues found in the site;
      * ``"site_in_pocket"``: share of the site's residues found in the pocket.

    Sites come from ``news_sites[pdb]``. With several sites, the values of the last
    one win (an earlier assertion requires exactly one per entry). An entry with no
    sites yields ``{}``.
    """
    overlaps = {}
    for site in news_sites[pdb]:
        pocket_residues, site_residues = pocketd["residues"], site["site"]
        overlaps["pocket_in_site"] = get_overlap(pocket_residues, site_residues)
        overlaps["site_in_pocket"] = get_overlap(site_residues, pocket_residues)
    return overlaps

def get_label(overlaps, site_in_pocket=None, pocket_in_site=None):
    """Binary label for a pocket from its overlap fractions.

    A threshold left as ``None`` is not checked; when both are given, meeting either
    one is enough.

    Args:
        overlaps: mapping with ``"site_in_pocket"`` and ``"pocket_in_site"`` fractions.
        site_in_pocket: minimum ``overlaps["site_in_pocket"]`` for a positive label.
        pocket_in_site: minimum ``overlaps["pocket_in_site"]`` for a positive label.

    Returns:
        1 if a supplied threshold is met (``>=``), else 0.

    Raises:
        ValueError: neither threshold was given. This used to be an ``assert``,
            which is stripped under ``python -O``.
    """
    if site_in_pocket is None and pocket_in_site is None:
        raise ValueError("get_label needs at least one of site_in_pocket / pocket_in_site")

    if site_in_pocket is None:
        return int( overlaps["pocket_in_site"] >= pocket_in_site )
    if pocket_in_site is None:
        return int( overlaps["site_in_pocket"] >= site_in_pocket )
    return int( overlaps["site_in_pocket"] >= site_in_pocket or overlaps["pocket_in_site"] >= pocket_in_site )

In [9]:
def label_results(resultsd, site_in_pocket=0.65, pocket_in_site=None, prob_key=None):
    """Tabulate and label one external model's pocket predictions.

    Produces one row per (pdb, pocket) in ``resultsd`` (``{pdb: {pocket: pocketd}}``).
    Each row holds:
      * the pocket's overlaps with the annotated site (from ``get_overlaps``) and
        their maximum;
      * the label from ``get_label`` using the given thresholds;
      * ``pocketd[prob_key]`` as a ``"prob"`` column, only when ``prob_key`` is given.

    Rows are sorted by ``max_overlap``, highest first.
    """
    rows = []
    for pdb, pockets in resultsd.items():
        for pocket, pocketd in pockets.items():
            overlaps = get_overlaps(pdb, pocketd)
            row = {"pdb": pdb, "pocket": pocket}
            if prob_key is not None:
                row["prob"] = pocketd[prob_key]
            row["pred"] = pocketd["pred"]
            row["label"] = get_label(overlaps, site_in_pocket, pocket_in_site)
            row["max_overlap"] = max(overlaps.values())
            row.update(overlaps)
            rows.append(row)
    return pd.DataFrame(rows).sort_values("max_overlap", ascending=False)

In [10]:
def label_our_results(resultsd):
    """Tabulate our model's pocket results, which already carry labels and overlaps.

    Produces one row per (pdb, pocket) in ``resultsd`` (``{pdb: {pocket: pocketd}}``).
    ``pred_top1`` is 1 for every pocket whose ``prob`` equals the highest ``prob`` of
    its entry; ties all get 1. Rows are sorted by ``max_overlap``, highest first.
    """
    rows = []
    for pdb, pockets in resultsd.items():
        top_prob = max(entry["prob"] for entry in pockets.values())
        for pocket, pocketd in pockets.items():
            overlaps = {key: pocketd[key] for key in ("pocket_in_site", "site_in_pocket")}
            rows.append({
                "pdb": pdb,
                "pocket": pocket,
                "prob": pocketd["prob"],
                "pred": pocketd["pred"],
                "pred_top1": int(pocketd["prob"] == top_prob),
                "label": pocketd["label"],
                "max_overlap": max(overlaps.values()),
                **overlaps,
            })
    return pd.DataFrame(rows).sort_values("max_overlap", ascending=False)

In [11]:
# Attach a per-pocket "labelled" table to every model. Our own model ("model5")
# already carries labels/overlaps in its results; the other models are labelled
# here from residue overlaps with their own thresholds and probability key.
for model_name, model_info in models.items():
    if model_name == "model5":
        model_info["labelled"] = label_our_results(model_info["results"])
    else:
        model_info["labelled"] = label_results(
            model_info["results"],
            **model_info["labelling"],
            prob_key=model_info["prob_key"],
        )

# Pocket functions

In [12]:
import sys

# Make the project's `utils` package (under training_data/) importable.
# The membership check keeps re-runs of this cell from appending duplicates.
training_data_path = "../../training_data"
if training_data_path not in sys.path:
    sys.path.append(training_data_path)

In [13]:
from utils.utils import Cif, CifFileWriter
from utils.pocket_utils import Pocket

In [14]:
from biotite.structure.io.pdb import PDBFile

def get_pdb_atoms(f):
    """Read PDB file ``f`` with biotite and return its atoms as an all-``str`` DataFrame.

    Columns use mmCIF ``_atom_site`` names (``auth_*``, ``type_symbol``,
    ``Cartn_x/y/z``, ``pdbx_PDB_ins_code``). Coordinates are taken from the first
    entry along ``coord``'s leading axis. That is presumably the first model, since
    ``get_structure()`` is called without ``model=`` (TODO confirm it returns a stack).
    Empty insertion codes are written as ``'?'``.
    """
    structure = PDBFile.read(f).get_structure()
    xyz = structure.coord[0]
    insertion_codes = [code if code else '?' for code in structure.ins_code]
    table = {
        "auth_asym_id": structure.chain_id,
        "auth_seq_id": structure.res_id,
        "auth_comp_id": structure.res_name,
        "auth_atom_id": structure.atom_name,
        "type_symbol": structure.element,
        "Cartn_x": xyz[:, 0],
        "Cartn_y": xyz[:, 1],
        "Cartn_z": xyz[:, 2],
        "pdbx_PDB_ins_code": insertion_codes,
    }
    return pd.DataFrame(table, dtype=str)

In [15]:
def get_our_pocket(pdb, pocket, color):
    """Load pocket ``pocketN`` of our model as pseudo-atoms plus its Mol* surface entry.

    The pocket is the ``STP`` residue with ``label_seq_id == N`` in
    ``{extra_set_path}/pockets/{pdb}/{pdb}_out/{pdb}_out.cif``. Its atoms are tagged
    with the fake entity ``'99'`` that ``view_pockets`` registers for pockets.

    Returns:
        ``{"atoms": DataFrame, "representation": [color_data entry]}``.
    """
    number = pocket.replace('pocket', '')
    out_cif = Cif(pdb, f"{extra_set_path}/pockets/{pdb}/{pdb}_out/{pdb}_out.cif", name=f"{pdb}_out")
    atoms = (
        out_cif.atoms
        .query(f"label_comp_id == 'STP' and label_seq_id == '{number}'")
        .assign(label_entity_id='99')  # new frame rather than writing into the query result
    )
    surface = {
        'entity_id': '99',
        "auth_asym_id": atoms.auth_asym_id.unique().item(),
        'auth_residue_number': int(number),
        'representation': 'molecular-surface',
        'representationColor': colors[color],
    }
    return {"atoms": atoms, "representation": [surface]}

models["model5"]["pocketf"] = get_our_pocket

In [16]:
def get_allositepro_pocket(pdb, pocket, color):
    """Load AllositePro pocket ``pocketN`` as pseudo-atoms plus its Mol* surface entry.

    AllositePro pockets are 0-indexed, while the ``STP`` pseudo-residues in the
    downloaded ``.pdb`` are numbered from 1 and need not line up with the index
    (pocket0 can be residue 2). So the distinct ``STP`` residue numbers are sorted
    and pocket ``N`` is the ``N``-th of them.

    Args:
        pdb: entry id; results are read from
            ``AllositePro/{pdb}/<name>_download/<name>.pdb``.
        pocket: pocket name of the form ``"pocketN"`` (0-based ``N``).
        color: key into ``colors``.

    Returns:
        ``{"atoms": DataFrame, "representation": [color_data entry]}``.

    Raises:
        FileNotFoundError: no ``*_download`` entry exists in ``AllositePro/{pdb}``.
    """
    # Sorted so the choice is deterministic if several downloads are present.
    downloads = sorted(f for f in os.listdir(f"AllositePro/{pdb}") if f.endswith("_download"))
    if not downloads:
        raise FileNotFoundError(f"No '*_download' entry found in AllositePro/{pdb}")
    resultsf = downloads[0]

    pockets_atoms = (
        get_pdb_atoms(f"AllositePro/{pdb}/{resultsf}/{resultsf.replace('_download', '')}.pdb")
        .query("auth_comp_id == 'STP'")
    )
    # get_pdb_atoms returns every column as str, so sort residue numbers as ints.
    # A plain sorted() puts "10" before "2" and, with 10+ pockets, maps the index
    # to the wrong pocket.
    pocket_ids = sorted(pockets_atoms.auth_seq_id.unique(), key=int)
    pocketn = pocket_ids[int(pocket.replace('pocket', ''))]
    pocket_atoms = (
        pockets_atoms
        .query(f"auth_seq_id == '{pocketn}'")
        # .assign builds a new frame instead of writing into a .query() slice of the
        # still-referenced pockets_atoms (which can trigger SettingWithCopyWarning).
        .assign(label_entity_id='99')
    )

    return {
        "atoms": pocket_atoms,
        "representation": [{
            'entity_id': '99',
            "auth_asym_id": pocket_atoms.auth_asym_id.unique().item(),
            'auth_residue_number': int(pocketn),
            'representation': 'molecular-surface',
            'representationColor': colors[color],
        }],
    }

models["allositepro"]["pocketf"] = get_allositepro_pocket

In [17]:
from functools import partial

In [18]:
def get_passer_pocket(pdb, pocket, color, model):
    """Load PASSer pocket ``pocketN`` as pseudo-atoms plus its Mol* surface entry.

    The pocket is the ``STP`` residue numbered ``N`` in
    ``PASSer/{model}/{pdb}/{pdb}_out.pdb``. Its atoms are tagged with the fake
    entity ``'99'``.

    Args:
        model: PASSer variant sub-directory (``"ensemble"``, ``"automl"`` or ``"rank"``).

    Returns:
        ``{"atoms": DataFrame, "representation": [color_data entry]}``.
    """
    number = pocket.replace('pocket', '')
    atoms = get_pdb_atoms(f"PASSer/{model}/{pdb}/{pdb}_out.pdb")
    selected = (
        atoms
        .query(f"auth_comp_id == 'STP' and auth_seq_id == '{number}'")
        .assign(label_entity_id='99')  # new frame rather than writing into the query result
    )
    surface = {
        'entity_id': '99',
        "auth_asym_id": selected.auth_asym_id.unique().item(),
        'auth_residue_number': int(number),
        'representation': 'molecular-surface',
        'representationColor': colors[color],
    }
    return {"atoms": selected, "representation": [surface]}

# One loader per PASSer variant, each reading its own results sub-directory.
for passer_variant in ("ensemble", "automl", "rank"):
    models[f"passer_{passer_variant}"]["pocketf"] = partial(get_passer_pocket, model=passer_variant)

In [19]:
def get_allo_pocket(pdb, pocket, color):
    """Load an ALLO pocket (a set of residues) with one surface entry per residue.

    Reads ``ALLO/{pdb}/pockets/{stem}_{pocket}_res.pdb``, where ``stem`` comes from
    the ``{pdb}pdb*_desc.txt`` file in ``ALLO/{pdb}``. The residues are re-tagged as
    fake entity ``'99'`` on chain ``'ZZZ'``.

    The returned atoms exclude CA, C, O, N and CB. A molecular-surface entry is
    still emitted for every residue number in the file.
    """
    desc_name = next(
        name
        for name in os.listdir(f"ALLO/{pdb}")
        if name.startswith(f"{pdb}pdb") and name.endswith("_desc.txt")
    )
    stem = desc_name.replace("_desc.txt", "")

    residue_atoms = get_pdb_atoms(f"ALLO/{pdb}/pockets/{stem}_{pocket}_res.pdb")
    residue_atoms["label_entity_id"] = '99'
    residue_atoms["auth_asym_id"] = 'ZZZ'

    kept_atoms = residue_atoms.query("auth_atom_id not in ['CA', 'C', 'O', 'N', 'CB']")
    surfaces = [
        {
            'auth_asym_id': "ZZZ",
            "auth_residue_number": int(resnum),
            'representation': 'molecular-surface',
            'representationColor': colors[color],
        }
        for resnum in residue_atoms["auth_seq_id"].unique()
    ]
    return {"atoms": kept_atoms, "representation": surfaces}

models["allo"]["pocketf"] = get_allo_pocket

# View functions

In [20]:
from ipymolstar import PDBeMolstar

In [21]:
# ass_fields_list = ["_pdbx_struct_assembly", "_pdbx_struct_assembly_gen", "_pdbx_struct_oper_list"]

def view_pockets(pdb, pockets: list):
    """Build a PDBeMolstar widget showing ``pdb`` with its site, ligands and ``pockets``.

    The structure (``origcifs/{pdb}_updated.cif.gz``) and the pocket pseudo-atoms are
    merged into one temporary mmCIF. All pockets belong to a fake entity ``'99'``,
    typed ``branched`` (faked as carbohydrate) so its representation can be managed
    separately; ``hide_carbs`` is on.

    Args:
        pdb: entry id; must be a key of ``news_sites``.
        pockets: sequence (iterated twice, so not a one-shot iterator) of dicts as
            returned by the ``get_*_pocket`` helpers, each with ``"atoms"``
            (DataFrame) and ``"representation"`` (list of ``color_data`` entries).

    Returns:
        The configured ``PDBeMolstar`` widget.
    """
    cif = Cif(pdb, f"{extra_set_path}/origcifs/{pdb}_updated.cif.gz")
    site = news_sites[pdb][0]

    # Real entities plus the fake pocket entity '99'.
    entities = pd.concat((
        pd.DataFrame(cif.cif.data["_entity"], dtype=str),
        pd.DataFrame([{"id": "99", "type": "branched", "pdbx_description": "pockets"}]),
    )).fillna(".")

    # Keep the columns present in the structure AND in every pocket, in the structure's
    # column order. A bare set intersection of the pocket columns alone had three
    # problems: its order depended on the hash seed (changing between kernels), it
    # raised TypeError for an empty `pockets`, and it could request columns missing
    # from cif.atoms.
    shared = set(cif.atoms.columns).intersection(
        *(set(pocket_d["atoms"].columns) for pocket_d in pockets)
    )
    columns = [col for col in cif.atoms.columns if col in shared]
    atoms = pd.concat((
        cif.atoms[columns],
        *(pocket_d["atoms"][columns] for pocket_d in pockets),
    ))

    with tempfile.NamedTemporaryFile("w+", suffix=".cif") as f:
        writer = CifFileWriter(f.name)
        writer.write({cif.entry_id.upper(): {
            "_entity": entities.to_dict(orient="list"),
            "_atom_site": atoms.to_dict(orient="list"),
        }})
        combined = Cif(pdb, filename=f.name)
        combined.cif.data  # to cache it while 'f' exists

    nonpolymer_ids = entities.query("type == 'non-polymer'").id.unique()
    modulator_ids = site["mod"].label_entity_id.unique()

    v = PDBeMolstar(
        custom_data={
            'data': combined.cif.text,
            'format': 'cif',
            'binary': False,
        },
        sequence_panel=True,
        assembly_id='',

        hide_polymer=True,
        hide_heteroatoms=True,
        hide_carbs=True,
        hide_water=True,
        hide_non_standard=True,

        color_data={
            "data": [
                # Protein: grey cartoon for every chain that carries site residues
                *(
                    {"auth_asym_id": asym_id, 'representation': 'cartoon', 'representationColor': '#AEAEAE'}
                    for asym_id in site["site"].auth_asym_id.unique().tolist()
                ),
                # Site residues, focused
                *(
                    {'auth_asym_id': r["auth_asym_id"], 'auth_residue_number': int(r["auth_seq_id"]) if r["auth_seq_id"] != "." else "", 'representationColor': colors["green"], 'focus': True}
                    for _, r in site["site"].iterrows()
                ),
                # Modulator
                *(
                    {"entity_id": entity_id, 'representation': "spacefill", }
                    for entity_id in nonpolymer_ids
                    if entity_id in modulator_ids
                ),
                # Other ligands
                *(
                    {"entity_id": entity_id, 'representation': "ball-and-stick", }
                    for entity_id in nonpolymer_ids
                    if entity_id not in modulator_ids
                ),
                # Pockets
                *(rep for pocket in pockets for rep in pocket["representation"]),
            ],
            "nonSelectedColor": None,
            "keepColors": False,
            "keepRepresentations": False,
        },
    )

    return v

In [22]:
def view_top(pdb, model, top=None):
    """Print and render the ``top`` highest-probability pockets of ``model`` for ``pdb``.

    Pockets drawn:
      * the ``top`` pockets by descending ``prob``: green when labelled positive,
        blue otherwise;
      * every other positively labelled pocket (missed by the cut): orange.

    Args:
        pdb: entry id; must be a key of ``models[model]["results"]``.
        model: key into ``models``.
        top: number of pockets to take. Defaults to the number of positively
            labelled pockets for this entry (at least 1).

    Returns:
        The ``PDBeMolstar`` widget built by ``view_pockets``.
    """
    pocketf = models[model]["pocketf"]
    results = models[model]["results"][pdb]
    labelled = models[model]["labelled"].query(f"pdb == '{pdb}'").sort_values("prob", ascending=False)

    if top is None:
        top = labelled["label"].sum() or 1
    # head() slices like [:top] without materialising every row first
    top_pockets = [row for _, row in labelled.head(top).iterrows()]
    top_names = {row["pocket"] for row in top_pockets}

    for row in top_pockets:
        details = {k: v for k, v in results[row["pocket"]].items() if k != "residues"}
        print(row["pocket"], details, "label:", row["label"])

    missed_positives = labelled[(labelled["label"] == 1) & ~labelled["pocket"].isin(top_names)]
    shown = [
        pocketf(pdb, row["pocket"], "green" if row["label"] == 1 else "blue")
        for row in top_pockets
    ] + [
        pocketf(pdb, row["pocket"], "orange")
        for _, row in missed_positives.iterrows()
    ]
    return view_pockets(pdb, tuple(shown))

# Viz

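Manual adjustments made in the Mol* viewer:

- Components settings: turn on *Ignore light*.
- Inspect the ligands and either create a selection + component for the relevant ones, or hide them all.
- Duplicate the modulator's spacefill component, switch the copy to ball-and-stick and hide the spacefill. Set the carbon colour to uniform, RGB ≈ (30, 30, 30).
- For the molecular surfaces, adjust the probe radius and set the opacity to 0.4.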

In [23]:
# PDB ids available in the extra set (candidates for view_top below).
news.keys()

dict_keys(['7gqu', '7yg5', '8aq6', '8f4s', '8jp0', '8qni', '8uk6', '8v81', '9dnm'])

In [29]:
# Models to step through, in this order, in the viewer cell below.
modell = ["model5", "allositepro", "passer_ensemble", "allo"]

In [49]:
# Cursor into `modell`; the viewer cell advances it by one on every run.
i = 0

In [52]:
# Show the next model from `modell` for one entry. Each run advances the cursor `i`
# (reset it in the cell above); the modulo wraps around instead of raising
# IndexError once every model has been shown.
pdb_to_view = '9dnm'
curr_model = modell[i % len(modell)]
print(curr_model)
i += 1
v = view_top(pdb_to_view, curr_model)
v

passer_ensemble
20 {'pred': 1, 'prob/score': 57.40799307823181} label: 0


PDBeMolstar(bg_color='#F7F7F7', color_data={'data': [{'auth_asym_id': 'A', 'representation': 'cartoon', 'repre…

In [43]:
# v = view_top('8uk6', "allo")#, top=20)
# v