# Imports

In [None]:
import numpy as np
import pandas as pd
from asapdiscovery.docking import analysis as a
from importlib import reload
reload(a)
from asapdiscovery.data.schema_v2.ligand import Ligand
from asapdiscovery.data.schema_v2.complex import Complex
from asapdiscovery.data.selectors.mcs_selector import MCSSelector
from tqdm import tqdm
from pathlib import Path
import json
from itertools import combinations # only need to do the combinations once since the Tanimoto is symmetric

## Load the data

In [None]:
# Relative path to the input data directory that holds the serialised complexes.
data_path = Path("../..") / "data" / "20240202_fragalysis_p_series_schema"

In [None]:
# Load every JSON file in the data directory into a Complex.
# Path.read_text opens and closes each file itself, so no file handles are left
# open, and sorting the paths makes the ligand order (and therefore the row
# order of everything derived from it) reproducible between runs.
complexes = [
    Complex.from_dict(json.loads(p.read_text()))
    for p in sorted(data_path.glob("*.json"))
]

In [None]:
# One ligand per loaded complex, in the same order as `complexes`.
mols = [
    cpx.ligand
    for cpx in complexes
]

In [None]:
# Every unordered pair of ligands, each pair appearing exactly once.
pairs = [*combinations(mols, 2)]

In [None]:
# Number of unordered ligand pairs produced by combinations(mols, 2).
len(pairs)

# Calculate N-to-N Tanimoto

In [None]:
# Brute-force N x N scoring through the library helper.
# The accumulator is created once, outside both loops, so the scores of every
# reference molecule are kept; resetting it per reference molecule would leave
# only the final row. The result is a flat list in row-major order
# (all scores for mols[0], then all scores for mols[1], ...).
tc_list = []
# A single progress bar over reference molecules instead of one bar per row.
for mol1 in tqdm(mols):
    for mol2 in mols:
        # to_oemol() is called for every pair, exactly as before.
        tc = a.calculate_tanimoto_oe(mol1.to_oemol(), mol2.to_oemol(), "combo")
        tc_list.append(tc)
        

In [None]:
# Score each unordered ligand pair once. `pairs` is the combinations(mols, 2)
# list built earlier; the name previously used here, `permuts`, is never defined
# in the cells above and raised NameError on a top-to-bottom run.
tc = [
    a.calculate_tanimoto_oe(mol1.to_oemol(), mol2.to_oemol(), "combo")
    for mol1, mol2 in tqdm(pairs)
]

In [None]:
def calculate_n_to_n_tanimoto(
    mols: list[Ligand], compute_type: a.TanimotoType = a.TanimotoType.COMBO
) -> list[float]:
    """
    Score every ligand against every ligand with the OpenEye overlap function.

    Each reference molecule is prepared and set up once, then overlapped with
    every query molecule in turn.

    Parameters
    ----------
    mols : list[Ligand]
        Ligands to compare. ``to_oemol()`` is called twice per ligand, once for
        the reference list and once for the query list.
    compute_type : a.TanimotoType
        Which score to read from each overlap result: SHAPE, COLOR or COMBO.

    Returns
    -------
    list[float]
        Flat list of ``len(mols) ** 2`` scores in reference-major order: entry
        ``i * len(mols) + j`` is ``mols[i]`` (reference) against ``mols[j]``.

    Raises
    ------
    ValueError
        If ``compute_type`` is not SHAPE, COLOR or COMBO. Without this check an
        unsupported value would silently produce an empty result.
    """
    from asapdiscovery.data.openeye import oeshape

    # Validate up front, using the same equality comparison as the branches
    # below, so a bad value fails before any overlap work is done.
    supported = (a.TanimotoType.SHAPE, a.TanimotoType.COLOR, a.TanimotoType.COMBO)
    if compute_type not in supported:
        raise ValueError(
            f"Unsupported compute_type {compute_type!r}; expected one of {supported}"
        )

    refmols = [mol.to_oemol() for mol in mols]
    querymols = [mol.to_oemol() for mol in mols]

    tc_list = []
    for refmol in refmols:
        # Prepare reference molecule for calculation
        # With default options this will remove any explicit hydrogens present
        prep = oeshape.OEOverlapPrep()
        prep.Prep(refmol)

        # Get appropriate function to calculate exact shape
        shapeFunc = oeshape.OEOverlapFunc()
        shapeFunc.SetupRef(refmol)

        for fitmol in querymols:
            res = oeshape.OEOverlapResults()
            # The query molecule is prepared with the same prep object as the
            # reference before the overlap is computed.
            prep.Prep(fitmol)
            shapeFunc.Overlap(fitmol, res)

            if compute_type == a.TanimotoType.SHAPE:
                tc_list.append(res.GetShapeTanimoto())
            elif compute_type == a.TanimotoType.COLOR:
                tc_list.append(res.GetColorTanimoto())
            elif compute_type == a.TanimotoType.COMBO:
                tc_list.append(res.GetTanimotoCombo())
    return tc_list

In [None]:
# All-against-all scores as a flat list, using the combo score (the function's default).
tc_list = calculate_n_to_n_tanimoto(mols, compute_type=a.TanimotoType.COMBO)

In [None]:
# Fold the flat score list into a square (n_mols x n_mols) array.
n_mols = len(mols)
matrix = np.asarray(tc_list).reshape((n_mols, n_mols))

In [None]:
# Array form of the ligand list so it can be indexed with arrays of positions.
mol_array = np.asarray(mols)

In [None]:
# For every ligand, collect its (up to) ten highest-scoring *other* ligands.
# Entries are stored in ascending score order, so the best match comes last.
top10_dict = {}
for i, row in enumerate(matrix):
    ranked = np.argsort(row)
    # Remove the ligand's own position by index rather than assuming it sorts
    # last: when another ligand ties with the self-score, the sort order of the
    # tie is arbitrary and slicing off the final element could drop a genuine
    # neighbour while keeping the ligand itself.
    neighbours = ranked[ranked != i][-10:]
    top10_dict[mol_array[i]] = {'mols': mol_array[neighbours],
                                'tcs': row[neighbours]}
    

In [None]:
# Inspect the stored neighbours ('mols') and their scores ('tcs') for the first ligand.
top10_dict[mol_array[0]]

In [None]:
import plotly.express as px

In [None]:
# Heatmap of the pairwise score matrix.
px.imshow(matrix)

In [None]:
# Heatmap of (2 - score), so larger values mean less similar.
# NOTE(review): presumably 2 is the maximum possible combo score — confirm.
px.imshow(2 - matrix)

# Convert to a DataFrame and save

In [None]:
# Long-format table: one row per (Mol1, Mol2) pair, in the same row-major order
# as the flattened score matrix.
names = [m.compound_name for m in mols]
n_mols = len(names)
df = pd.DataFrame(
    {
        # Each name repeated n_mols times in a row: a, a, ..., b, b, ...
        "Mol1": [name for name in names for _ in range(n_mols)],
        # The whole name list repeated n_mols times: a, b, ..., a, b, ...
        "Mol2": names * n_mols,
        "Tanimoto": matrix.ravel(),
    }
)

In [None]:
# Write the table to tanimoto_combo.csv without the DataFrame index column.
df.to_csv("tanimoto_combo.csv", index=False)