# Imports

In [None]:
from glob import glob
from pathlib import Path
import shutil
from tqdm import tqdm
import os
import numpy as np
import plotly.express as px

from io import BytesIO
from zipfile import ZipFile
from subprocess import call, STDOUT
from urllib.request import urlopen


import requests
from matplotlib import cm, colors
from IPython.display import Image
import pandas as pd
from rdkit.Chem.PandasTools import AddMoleculeColumnToFrame
from Bio.PDB import PDBParser, PDBIO
import nglview as nv

from plipify.fingerprints import InteractionFingerprint
from plipify.visualization import (
    fingerprint_barplot, fingerprint_heatmap, fingerprint_table, 
    fingerprint_nglview, PymolVisualizer, nglview_color_side_chains_by_frequency,
    fingerprint_writepdb
)
from ipywidgets.embed import embed_minimal_html

from plipify.core import Structure

from html2image import Html2Image
import imgkit
import yaml

In [None]:
import pickle as pkl

## paths

In [None]:
# Collect every PDB file whose name starts with "no_header" from the prepared
# dataset directory; the glob result is materialised as a list of Path objects.
pdbs = list(
    Path(
        "/Users/alexpayne/lilac-mount-point/asap-datasets/plipify_prepped_no_header"
    ).glob("no_header*.pdb")
)

In [None]:
# Directory that receives every file written by this notebook (pickles, CSV, YAML).
output = Path("/Volumes/Rohirrim/local_test/plipify_result_melissa_prepped")
# Create the directory together with any missing parents in a single call
# rather than checking for existence first (which is race-prone); an already
# existing directory is not treated as an error.
output.resolve().mkdir(parents=True, exist_ok=True)

# Load structures

In [None]:
# Parse every PDB file with ligand name "LIG" and keep only the structures
# that have exactly one binding site (zero or several are both rejected).
structures = []
for path in tqdm(pdbs, total=len(pdbs)):
    structure = Structure.from_pdbfile(str(path), ligand_name="LIG")

    if len(structure.binding_sites) == 1:
        structures.append(structure)
    else:
        # Report the rejected file so it can be inspected manually.
        print(
            f"{path.name} contains {len(structure.binding_sites)} binding sites and we want exactly one.")
print(f"Loaded {len(structures)} structures.")

## save structures to pickle file

In [None]:
len(structures)

In [None]:
pkl

In [None]:
# Serialise the list of loaded structures to `structures.pkl` in the output directory.
with (output / "structures.pkl").open('wb') as handle:
    pkl.dump(structures, handle)

# Filter the structures

## get length and gap info

In [None]:
# One row per structure: identifier, sequence length and the sequence itself.
# The inner generator calls `s.sequence()` once per structure and the result
# is reused for both the "length" and the "sequence" column.
lengths = pd.DataFrame(
    [
        (identifier, len(sequence), sequence)
        for identifier, sequence in ((s.identifier, s.sequence()) for s in structures)
    ],
    columns=["identifier", "length", "sequence"],
)
# Gaps are counted as '-' characters in the sequence string.
lengths["gapcount"] = lengths.sequence.str.count('-')
print('Sequence length median and std: ', lengths.length.median(), lengths.length.std())

## remove based on length

In [None]:
# Keep the rows whose length deviates from the median length by at most one
# standard deviation of the length column.
lengths_filtered = lengths[
    (lengths["length"] - lengths["length"].median()).abs() <= lengths["length"].std()
]

In [None]:
# Report how many rows the length filter dropped (row count before minus after).
print(f"{len(lengths) - len(lengths_filtered)} sequences removed due to length")

### plot lengths

In [None]:
px.histogram(lengths.length)

In [None]:
px.histogram(lengths_filtered.length)

## filter sequences with gaps

In [None]:
# Of the length-filtered rows, keep only those whose sequence contains no gaps.
lengths_no_gaps = lengths_filtered[lengths_filtered["gapcount"] == 0]

In [None]:
# Report how many rows the gap filter dropped: rows before the gap filter
# minus rows after it (the reverse order would print a non-positive number).
print(f"{len(lengths_filtered) - len(lengths_no_gaps)} additional sequences filtered due to gaps")

## filter the structure list

In [None]:
# Identifiers that survived both the length and the gap filter. The set is
# built once, so every membership test below is a constant-time lookup instead
# of rebuilding the set for each structure.
kept_identifiers = set(lengths_no_gaps.identifier.tolist())
filtered_structures = [s for s in structures if s.identifier in kept_identifiers]
# Summary: files found -> structures loaded -> structures kept = total dropped.
print(len(pdbs), "->", len(structures), "->", len(filtered_structures), "=", len(pdbs) - len(filtered_structures),
      "structures filtered out")

In [None]:
# Count how many filtered structures carry each name code in their identifier.
# A code counts when it occurs anywhere in the identifier (substring match);
# structures matching none of the keys are tallied under 'other'.
structure_name_type_dict = {'Mpro-P': 0, 'Mpro-x': 0, 'Mpro-z': 0, 'other': 0}
print(structure_name_type_dict.keys())
for s in filtered_structures:
    identified = False
    for name_code in structure_name_type_dict:
        if name_code in s.identifier:
            structure_name_type_dict[name_code] += 1
            identified = True
    if not identified:
        structure_name_type_dict['other'] += 1
print(structure_name_type_dict)

# Calculate residue mapping

In [None]:
# Compute the residue index mapping over the filtered structures. A later cell
# checks that the mapping has the same length as `filtered_structures`.
residue_indices = InteractionFingerprint.calculate_indices_mapping(filtered_structures)

## Save residue mapping

In [None]:
# Serialise the residue index mapping to `residue_indices_367.pkl` in the output directory.
# NOTE(review): the 367 in the file name is presumably a structure count — confirm.
with (output / "residue_indices_367.pkl").open('wb') as handle:
    pkl.dump(residue_indices, handle)

## load residue mapping

In [None]:
# Read the residue index mapping back from disk.
# Unpickling can execute arbitrary code, so only load files this notebook wrote itself.
with (output / "residue_indices_367.pkl").open('rb') as handle:
    loaded_residue_indices = pkl.load(handle)

## check that residue mapping is correct

In [None]:
# The reloaded mapping must have exactly one entry per filtered structure.
# The message reports the same two quantities that the condition compares.
if len(filtered_structures) != len(loaded_residue_indices):
    raise ValueError(
        f"Number of residue indices mappings ({len(loaded_residue_indices)}) "
        f"does not match number of structures ({len(filtered_structures)})"
    )

# Calculate Fingerprints

## calculate individual fingerprints

In [None]:
# Fingerprints in non-DataFrame form (as_dataframe=False), computed over the
# filtered structures with the residue mapping reloaded from disk.
fingerprints = InteractionFingerprint().calculate_fingerprint(
    filtered_structures,
    residue_indices=loaded_residue_indices,
    labeled=True,
    as_dataframe=False,
    remove_non_interacting_residues=True,
    remove_empty_interaction_types=True,
    ensure_same_sequence=False,
)

## calculate dataframe

In [None]:
# Same calculation as the non-DataFrame fingerprints, but requested as a
# DataFrame (as_dataframe=True); `fp` is later exported with to_csv/to_dict.
fp = InteractionFingerprint().calculate_fingerprint(
    filtered_structures,
    residue_indices=loaded_residue_indices,
    labeled=True,
    as_dataframe=True,
    remove_non_interacting_residues=True,
    remove_empty_interaction_types=True,
    ensure_same_sequence=False,
)

## save results

In [None]:
# Serialise the non-DataFrame fingerprints to `fingerprints.pkl` in the output directory.
with (output / "fingerprints.pkl").open('wb') as handle:
    pkl.dump(fingerprints, handle)

In [None]:
# Export the fingerprint table `fp` as CSV into the output directory.
fp.to_csv(output / "plipify_results.csv")

# Get Specific Fingerprint Examples

## get interaction types

In [None]:
# Interaction type names of a default-constructed InteractionFingerprint, held
# as a NumPy array so they can be compared element-wise against a single name.
int_types = np.array(InteractionFingerprint().interaction_types)

In [None]:
int_types

In [None]:
np.where(int_types == 'pication')[0][0]

## transpose the fingerprints so they can be indexed by position

In [None]:
# Transpose: zip(*fingerprints) groups the i-th element of every fingerprint,
# so positions[i] is a tuple holding that element from each fingerprint.
# NOTE(review): assumes every fingerprint has the same length (zip truncates
# to the shortest) — confirm.
positions = list(zip(*fingerprints))

In [None]:
np.shape(positions)

## function to select the structures showing a given interaction at a given residue

In [None]:
def get_structures_from_interaction(
    resn,
    interaction_type,
    positions,
    int_types,
    only_return_interactions=True,
    n_interaction_types=10,
):
    """Return the fingerprint entries for one residue / interaction-type pair.

    ``positions`` is addressed as a flat, residue-major layout: every residue
    owns ``n_interaction_types`` consecutive slots, and the slot inside a
    residue is the index of ``interaction_type`` within ``int_types``.

    Parameters
    ----------
    resn : int
        Residue number, starting at 1.
    interaction_type : str
        Name of the interaction type to look up in ``int_types``.
    positions : sequence
        Indexable by flat position; each item is an iterable of entries that
        expose a numeric ``value`` attribute.
    int_types : array-like of str
        Interaction type names; lists are accepted as well as NumPy arrays.
    only_return_interactions : bool, optional
        If true (default), return a list with only the entries whose
        ``value`` is greater than zero. Otherwise return ``positions[idx]``
        unchanged.
    n_interaction_types : int, optional
        Number of slots each residue occupies in ``positions``. Defaults to
        10, the stride this function has always used.
        NOTE(review): confirm this matches the layout of the fingerprints.

    Returns
    -------
    list or the raw item of ``positions``
        The selected entries (see ``only_return_interactions``).

    Raises
    ------
    ValueError
        If ``resn`` is not positive or ``interaction_type`` is not present in
        ``int_types``.
    IndexError
        If the computed position lies outside ``positions``.
    """
    # Explicit check instead of `assert`: asserts are stripped under `-O`, and
    # a non-positive residue number would silently wrap to a negative index.
    if resn <= 0:
        raise ValueError("resn is the residue number, starting with 1")

    type_matches = np.where(np.asarray(int_types) == interaction_type)[0]
    if type_matches.size == 0:
        raise ValueError(
            f"Unknown interaction type {interaction_type!r}; "
            f"expected one of {list(int_types)}"
        )

    # Flat position: residue block offset plus the slot of the interaction type.
    idx = (resn - 1) * n_interaction_types + int(type_matches[0])
    selected = positions[idx]
    if only_return_interactions:
        selected = [entry for entry in selected if entry.value > 0]
    return selected

## C145 hbond-don

In [None]:
# Fingerprint entries with a positive value for the 'hbond-don' interaction
# type at residue number 145.
structures_of_interest = get_structures_from_interaction(145, 'hbond-don', positions, int_types)

In [None]:
len(structures_of_interest)

In [None]:
# Identifier of the structure behind each selected fingerprint entry, reached
# through the entry's 'residue' label.
fns = [s.label['residue'].structure.identifier for s in structures_of_interest]

### save list with yaml

In [None]:
# Write the list of structure identifiers as YAML into the output directory.
with (output / 'C145_structure_examples.yml').open('w') as file:
    yaml.dump(fns, file)

## Get dictionary of examples

In [None]:
fp.to_dict()

In [None]:
# Every (residue id, interaction type) pair with a positive value in `fp`.
# fp.to_dict() is keyed by interaction type, then by residue id.
examples = [
    (resid, int_type)
    for int_type, data in fp.to_dict().items()
    for resid, value in data.items()
    if value > 0
]

In [None]:
len(examples)

## iterate through examples

In [None]:
# For every (residue id, interaction type) pair, look up the entries showing
# that interaction and write the identifiers of their structures to a YAML
# file named after the pair.
for resid, int_type in tqdm(examples):
    structures_of_interest = get_structures_from_interaction(resid, int_type, positions, int_types)
    fns = [s.label['residue'].structure.identifier for s in structures_of_interest]
    print(resid, int_type, len(fns))
    with (output / f'{resid}_{int_type}_structure_examples.yml').open('w') as file:
        yaml.dump(fns, file)