# Imports

In [None]:
from glob import glob
from pathlib import Path
import shutil
from tqdm import tqdm
import os
import numpy as np
import plotly.express as px

from io import BytesIO
from zipfile import ZipFile
from subprocess import call, STDOUT
from urllib.request import urlopen
from collections import Counter


import requests
from matplotlib import cm, colors
from IPython.display import Image
import pandas as pd
from rdkit.Chem.PandasTools import AddMoleculeColumnToFrame
from Bio.PDB import PDBParser, PDBIO
import nglview as nv

from plipify.fingerprints import InteractionFingerprint, _LabeledValue
from plipify.visualization import (
    fingerprint_barplot, fingerprint_heatmap, fingerprint_table, 
    fingerprint_nglview, PymolVisualizer, nglview_color_side_chains_by_frequency,
    fingerprint_writepdb
)
from ipywidgets.embed import embed_minimal_html

from plipify import core

from html2image import Html2Image
import imgkit
import yaml

from importlib import reload

In [None]:
import pickle as pkl

## paths

In [None]:
# Input complex for the single-structure exploration below (machine-specific absolute path).
pdb = Path("/Users/alexpayne/lilac-mount-point/asap-datasets/plipify_prepped_no_header/no_header_Mpro-x11001_0A_bound_chainA_wH.pdb")
# Pickled residue indices; later indexed as loaded_residue_indices[0].values() and passed as
# `indices` to calculate_fingerprint_one_structure — presumably dicts of get_residue_by kwargs, TODO confirm.
residue_indices_path = "/Volumes/Rohirrim/local_test/plipify_result_melissa_prepped/residue_indices_367.pkl"

# Load Structures

## add custom structure object

In [None]:
# Re-import plipify.core so edits to the library source take effect without restarting the kernel.
reload(core)
pdb = Path("/Users/alexpayne/lilac-mount-point/asap-datasets/plipify_prepped_no_header/no_header_Mpro-x11001_0A_bound_chainA_wH.pdb")
# Parse the complex; ligand_name="LIG" names the ligand residue to analyse (TODO confirm how core uses it).
structure = core.Structure.from_pdbfile(str(pdb), ligand_name="LIG")

In [None]:
# Take the raw dict of the first interaction in the first binding site of the
# structure parsed above, so is_backbone_interaction (which calls dict.get) can be
# tried on it. Binding a Structure object to this name (instead of a dict) made
# that call fail, and also re-parsed the same PDB file a second time.
# Raises StopIteration if the first binding site has no interactions.
interaction_dict = next(
    interaction.interaction
    for interactions in structure.binding_sites[0].interactions.values()
    for interaction in interactions
)

In [None]:
# Display `interaction_dict` for inspection.
interaction_dict

In [None]:
# Atom names counted as protein backbone by add_sidechain_boolean below.
# NOTE(review): add_sidechain_boolean compares these with the `.type` attribute of
# structure._pdbcomplex.atoms[...]; confirm that attribute holds PDB atom names
# (e.g. 'CA') rather than atom types, otherwise the backbone/sidechain split is wrong.
BACKBONE_ATOM_NAMES = ['N', 'H', 'CA', 'HA', 'C', 'O']

In [None]:
def is_backbone_interaction(interaction_dict):
    """Return True unless the interaction dict is flagged as a sidechain interaction.

    A truthy ``'SIDECHAIN'`` value gives False. A falsy or missing value gives True,
    so interactions that carry no flag at all are reported as backbone.
    NOTE(review): a string flag such as ``'False'`` is truthy and would count as sidechain.
    """
    sidechain_flag = interaction_dict.get('SIDECHAIN', False)
    return not sidechain_flag

In [None]:
# Classify the sample interaction (True = backbone according to is_backbone_interaction).
is_backbone_interaction(interaction_dict)

In [None]:
# Look up residue 144 for inspection.
# NOTE(review): calculate_fingerprint_one_structure calls get_residue_by with keyword
# arguments such as seq_index=/chain=; confirm a positional 144 is taken as the sequence index.
residue = structure.get_residue_by(144)

In [None]:
# Show the residue's interactions. A bare `residue.` (unfinished attribute access)
# is a SyntaxError and stops the cell from running at all.
residue.interactions

## load residue indices

In [None]:
# Unpickle the saved residue indices. pickle can execute arbitrary code while
# loading, so only point residue_indices_path at files you trust.
loaded_residue_indices = pkl.loads(Path(residue_indices_path).read_bytes())

## structure analysis

In [None]:
# Inspect the binding sites found in the structure (the first one is used further below).
structure.binding_sites

In [None]:
def calculate_fingerprint_one_structure(structure, indices, interaction_types, labeled=False):
    """
    Calculate the interaction fingerprint for a single structure.

    Parameters
    ----------
    structure : core.Structure
        Structure object built from a PDB file.
    indices : iterable of dict
        Each dict contains kwargs that match ``Structure.get_residue_by`` so it
        can return a Residue object. For example: ``{"seq_index": 1, "chain": "A"}``.
        Any iterable is accepted (e.g. ``dict.values()``); it is consumed once.
    interaction_types : iterable of str
        Interaction types counted for every residue. Their order fixes the order
        of the entries within each residue's slice of the fingerprint.
    labeled : bool, optional
        If True, wrap each count in a ``_LabeledValue`` whose label holds the
        residue and the interaction type.

    Returns
    -------
    numpy.ndarray or list
        ``len(indices) * len(interaction_types)`` counts in residue-major order:
        a numpy array when ``labeled`` is False, otherwise a list of
        ``_LabeledValue``. Residues that ``get_residue_by`` does not return
        contribute zero counts.
    """
    # Materialise both inputs: interaction_types is iterated once per residue, so a
    # one-shot iterator would otherwise only fill the first residue's slice.
    indices = list(indices)
    interaction_types = list(interaction_types)
    empty_counter = Counter()
    fingerprint = []
    for index_kwargs in indices:
        residue = structure.get_residue_by(**index_kwargs)
        if residue:
            counter = residue.count_interactions()
        else:
            # Gap: the residue is absent, so every count is zero. A placeholder
            # residue is only needed for labels, so build it only in that case.
            # (The bare name ProteinResidue was undefined in this notebook.)
            # NOTE(review): assumes plipify.core exposes ProteinResidue(name, seq_index, chain).
            counter = empty_counter
            if labeled:
                residue = core.ProteinResidue("GAP", 0, None)
        for interaction in interaction_types:
            if labeled:
                label = {"residue": residue, "type": interaction}
                fingerprint.append(_LabeledValue(counter[interaction], label=label))
            else:
                fingerprint.append(counter[interaction])
    if labeled:
        return fingerprint
    return np.asarray(fingerprint)

In [None]:
# Explore residue 144: print its interaction counts and how many interactions carry
# each 'SIDECHAIN' value, then build its fingerprint slice the same way
# calculate_fingerprint_one_structure does.
residue = structure.get_residue_by(144)
interaction_types = InteractionFingerprint().interaction_types
labeled = True
fingerprint = []
if residue:
    counter = residue.count_interactions()
    interaction_location = [interaction.interaction.get('SIDECHAIN', None) for interaction in residue.interactions]
    print(counter)
    print(Counter(interaction_location))
else:
    # Residue not found: count zero interactions rather than reusing a stale
    # (or undefined) `counter` left over from another cell.
    counter = Counter()
for interaction in interaction_types:
    if labeled:
        label = {"residue": residue, "type": interaction}
        n_interactions = _LabeledValue(counter[interaction], label=label)
    else:
        n_interactions = counter[interaction]
    fingerprint.append(n_interactions)

In [None]:
def count_interactions(residue):
    """
    Count a residue's interactions by interaction type and sidechain flag.

    Parameters
    ----------
    residue
        Object exposing ``interactions``. Each interaction has a ``shorthand``
        (its interaction type) and an ``interaction`` dict that must already hold
        a ``'SIDECHAIN'`` entry (e.g. filled in by add_sidechain_boolean).

    Returns
    -------
    collections.Counter
        Maps ``(shorthand, sidechain_flag)`` tuples to counts, where
        ``sidechain_flag`` is the interaction dict's ``'SIDECHAIN'`` value.

    Raises
    ------
    NotImplementedError
        If an interaction has no ``'SIDECHAIN'`` entry.
    """
    # Build the (type, location) pairs locally. The previous version zipped with
    # `interaction_location`, which was undefined here (or silently picked up an
    # unrelated global from another cell).
    pairs = []
    for interaction in residue.interactions:
        if 'SIDECHAIN' not in interaction.interaction:
            raise NotImplementedError(
                f"interaction {interaction.shorthand!r} has no 'SIDECHAIN' flag"
            )
        pairs.append((interaction.shorthand, interaction.interaction['SIDECHAIN']))
    return Counter(pairs)

In [None]:
# Inline version of the count_interactions loop for the current `residue`: record
# every interaction's shorthand and split them into sidechain / backbone lists,
# which the display cells further below show.
# NOTE: this rebinds the global `interaction_types` (previously the
# InteractionFingerprint types) and leaves `interaction` / `sidechain` bound to
# the last interaction processed.
interaction_types = []
bb_interaction_types = []
sc_interaction_types = []
for interaction in residue.interactions:
    interaction_types.append(interaction.shorthand)
    # Every interaction is expected to carry a 'SIDECHAIN' entry by now; stop if one doesn't.
    if interaction.interaction.get('SIDECHAIN', 'not_found') == 'not_found':
        raise NotImplementedError
    else:
        sidechain = interaction.interaction.get('SIDECHAIN', False)
    if sidechain:
        sc_interaction_types.append(interaction.shorthand)
    else:
        bb_interaction_types.append(interaction.shorthand)
    

# Test the sidechain/backbone classification on different examples

In [None]:
def add_sidechain_boolean(structure, interaction_dict):
    """
    Make sure ``interaction_dict`` holds a ``'SIDECHAIN'`` flag, then return it.

    An existing ``'SIDECHAIN'`` entry is left untouched. Otherwise the flag is
    derived from the protein atom(s) the interaction references:

    * ``'PROT_IDX_LIST'`` (comma-separated atom indices): True as soon as one
      listed atom is outside BACKBONE_ATOM_NAMES, False if none is.
    * ``'PROTCARBONIDX'`` (single atom index): True if that atom is outside
      BACKBONE_ATOM_NAMES.

    The dict is modified in place and the same object is returned.

    Raises
    ------
    NotImplementedError
        If none of ``'SIDECHAIN'``, ``'PROT_IDX_LIST'`` or ``'PROTCARBONIDX'`` is present.
    """
    missing = 'not_found'
    if interaction_dict.get("SIDECHAIN", missing) != missing:
        return interaction_dict

    # NOTE(review): `.type` is compared against PDB atom names; confirm the atoms
    # in structure._pdbcomplex.atoms expose names (e.g. 'CA') via `.type`.
    prot_idx_list = interaction_dict.get('PROT_IDX_LIST', missing)
    if prot_idx_list != missing:
        # Any sidechain atom makes the whole interaction a sidechain one, so stop at
        # the first; each backbone atom seen before that records False.
        for token in prot_idx_list.split(','):
            in_sidechain = structure._pdbcomplex.atoms[int(token)].type not in BACKBONE_ATOM_NAMES
            interaction_dict['SIDECHAIN'] = in_sidechain
            if in_sidechain:
                break
        return interaction_dict

    carbon_idx = interaction_dict.get("PROTCARBONIDX", missing)
    if carbon_idx != missing:
        carbon_atom = structure._pdbcomplex.atoms[int(carbon_idx)]
        interaction_dict['SIDECHAIN'] = carbon_atom.type not in BACKBONE_ATOM_NAMES
        return interaction_dict

    raise NotImplementedError

## hbond-don and hbond-acc

In [None]:
# Reload plipify.core and switch to a second complex (Mpro-x2659), which the next cell
# uses to try add_sidechain_boolean. This rebinds `pdb` and `structure`, so later
# cells operate on this complex.
reload(core)
pdb = Path("/Users/alexpayne/lilac-mount-point/asap-datasets/plipify_prepped_no_header/no_header_Mpro-x2659_0A_bound_chainA_wH.pdb")
structure = core.Structure.from_pdbfile(str(pdb), ligand_name="LIG")

In [None]:
# For every interaction type in the first binding site, print each raw interaction
# dict, then the result of add_sidechain_boolean on it. That function writes
# 'SIDECHAIN' into the dict in place, so the dicts stay updated afterwards.
# The loop leaves `interaction` bound to the last interaction, which later cells inspect.
binding_site = structure.binding_sites[0]
for int_type, interactions in binding_site.interactions.items():
    print("\n", int_type, "\n")
    for interaction in interactions:
        print(interaction.interaction)
        print(add_sidechain_boolean(structure, interaction.interaction))

In [None]:
structure._pdbcomplex.atoms[1347]

In [None]:
 binding_site.interactions

In [None]:
# List every atom in the residue that contains atom 2220, printing each atom's
# type, its index, and the index of its residue.
for atom in structure._pdbcomplex.atoms[2220].residue.atoms:
    print(atom.type, atom.idx, atom.residue.idx)

In [None]:
# Shorthands of all interactions of `residue`, as collected by the inline loop above.
interaction_types

In [None]:
# Shorthands of interactions with a falsy 'SIDECHAIN' flag (backbone).
bb_interaction_types

In [None]:
# Shorthands of interactions with a truthy 'SIDECHAIN' flag (sidechain).
sc_interaction_types

In [None]:
# Call count_interactions on the residue explored above.
count_interactions(residue)

In [None]:
# 'SIDECHAIN' flag of whatever `interaction` was last bound to in an earlier cell.
interaction.interaction.get('SIDECHAIN')

In [None]:
# The full raw dict of that interaction.
interaction.interaction

In [None]:
# Pick the first interaction of `residue` for a closer look.
interaction = residue.interactions[0]

In [None]:
interaction.interaction

In [None]:
# Print the shorthand (interaction type) of each of the residue's interactions.
for interaction in residue.interactions:
    print(interaction.shorthand)

In [None]:
# Labeled fingerprint slice for residue 144, built in the exploration cell above.
fingerprint

In [None]:
# Labeled fingerprint of the current `structure` over the residues listed in the
# first entry of the loaded residue indices.
# NOTE(review): `structure` was rebound to the Mpro-x2659 complex earlier; confirm
# these residue indices are meant to apply to it.
fp = calculate_fingerprint_one_structure(structure, loaded_residue_indices[0].values(), InteractionFingerprint().interaction_types, labeled=True)

In [None]:
fp