In [5]:
#import dependencies
import pickle
import pandas as pd
from rdkit import Chem
import torch
from os import path
import matplotlib.pyplot as plt

#import pkasolver
import pkasolver
from pkasolver.query import QueryModel
from pkasolver.ml_architecture import GINPairV1
from pkasolver.query import draw_pka_map 
from pkasolver.query import calculate_microstate_pka_values, draw_pka_reactions
from IPython.display import display
from IPython.display import HTML

#directory of the installed pkasolver package (base path for loading the trained model)
base_path=path.dirname(pkasolver.__file__)

#import MolGpKa
#from .TestInstall.MolGpKa.src.predict_pka import *

### pKaSolver

In [9]:
def RunPkasolver(x):#takes input of smile object
    """Predict microstate pKa values for one molecule with pkasolver.

    This is a generator that yields exactly one tuple, so callers consume it
    with ``next(RunPkasolver(smiles))`` or a ``for`` loop. Because it is a
    generator, nothing (including the SMILES check) runs until the first item
    is requested.

    Parameters
    ----------
    x : str
        SMILES string of the molecule.

    Yields
    ------
    tuple
        ``(sites, lst, proSmi, depSmi)`` where ``sites`` is the number of
        protonation states returned by ``calculate_microstate_pka_values``,
        ``lst`` holds the pKa of each state rounded to 2 decimals, and
        ``proSmi`` / ``depSmi`` hold the SMILES of the protonated /
        deprotonated molecule of each state, all in the same order.

    Raises
    ------
    ValueError
        If ``x`` cannot be parsed into a molecule.
    """
    mol=Chem.MolFromSmiles(x)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None instead of
        # raising; fail here with a clear message rather than inside pkasolver.
        raise ValueError(f"Could not parse SMILES string: {x!r}")

    protonation_states = calculate_microstate_pka_values(mol) #performs internal calculations and stores as object
    sites=len(protonation_states) #get the number of ionization sites

    lst=[]
    depSmi=[]
    proSmi=[]

    for state in protonation_states:
        depSmi.append(Chem.MolToSmiles(state.deprotonated_mol))
        proSmi.append(Chem.MolToSmiles(state.protonated_mol))
        lst.append(round(state.pka,2)) #get pka values for all sites for a given molecule store in a list
    yield sites, lst,proSmi,depSmi

### MolGpKa

In [2]:
def RunMolGpKa(x):
    """Predict pKa values for one molecule with MolGpKa.

    This is a generator that yields exactly one tuple, so callers consume it
    with ``next(RunMolGpKa(smiles))`` or a ``for`` loop.

    NOTE(review): ``predict`` is not imported anywhere in the visible notebook
    (the MolGpKa import in the setup cell is commented out), so it must be
    brought into scope before this function is called.

    Parameters
    ----------
    x : str
        SMILES string of the molecule.

    Yields
    ------
    tuple
        ``(sites, pkas)`` where ``pkas`` lists the values of the base
        dictionary followed by the values of the acid dictionary returned by
        ``predict``, and ``sites`` is ``len(pkas)``.

    Raises
    ------
    ValueError
        If ``x`` cannot be parsed into a molecule.
    """
    mol = Chem.MolFromSmiles(x)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None instead of raising.
        raise ValueError(f"Could not parse SMILES string: {x!r}")

    # The third return value of predict() is not needed here.
    base_dict, acid_dict, _ = predict(mol)
    pkas = list(base_dict.values()) + list(acid_dict.values())
    sites=len(pkas)
    yield sites, pkas

### Ionization Fraction Plots

In [10]:
def GetMonoPlot(pka_list,minph,maxph,step,proSmi,depSmi):
    """Plot ionization fractions versus pH for a single ionization site.

    For each sampled pH, ``a0 = 1/(1+10**(pH-pKa))`` and ``a1 = 1-a0``.
    The pH grid starts at ``minph`` and advances by ``step`` while the value
    is ``<= maxph``.

    Rounding uses ``pka_dec``, which is not defined in this function and must
    exist as a notebook-level variable before the call.

    Parameters
    ----------
    pka_list : sequence of float
        pKa values; only the first entry is used.
    minph, maxph : float
        Inclusive lower bound and upper limit of the pH grid.
    step : float
        Positive pH increment.
    proSmi, depSmi : list of str
        SMILES of the protonated / deprotonated microstates.

    Returns
    -------
    tuple
        ``(plt, microspecies)``: the ``matplotlib.pyplot`` module the curves
        were drawn on, and a dict mapping fraction index (0 -> a0, 1 -> a1, ...)
        to ``[proSmi[0]] + depSmi``.

    Raises
    ------
    ValueError
        If ``step`` is not positive (the sampling loop would never end).
    """
    if step<=0:
        raise ValueError("step must be positive")

    pka=pka_list[0]
    x=[]
    a0=[]
    a1=[]

    ph=minph

    while ph<=maxph:
        # Sample first, then advance, so the grid starts at minph and never
        # goes past maxph.
        x.append(ph)

        frac=round((1/(1+10**(ph-pka))),pka_dec)
        a0.append(frac)
        a1.append(1-frac) # or round((10**(ph-pka))/(1+10**(ph-pka)),pka_dec)

        ph+=step

    plt.plot(x,a0,color='red',label='a0')
    plt.plot(x,a1, color='green',label='a1')
    plt.legend()
    #make a list of smiles in order
    smiles=[proSmi[0]]+list(depSmi)
    #make a dictionary where the keys are the a_index (ex.0= a0, 1=a1, etc.) and values are the microspecies smiles strings
    microspecies=dict(enumerate(smiles))

    return plt,microspecies
    

In [11]:
def GetDiPlot(pka_list,minph,maxph,step,proSmi,depSmi):
    """Plot ionization fractions versus pH for two ionization sites.

    With ``ka1 = 10**(pH-pKa1)`` and ``ka2 = 10**(pH-pKa2)`` the fractions are
    ``a0 = 1/E``, ``a1 = ka1/E`` and ``a2 = ka1*ka2/E`` where
    ``E = 1 + ka1 + ka1*ka2``. The pH grid starts at ``minph`` and advances by
    ``step`` while the value is ``<= maxph``.

    Rounding uses ``pka_dec``, which is not defined in this function and must
    exist as a notebook-level variable before the call.

    Parameters
    ----------
    pka_list : sequence of float
        pKa values; the first two entries are used.
        NOTE(review): the formula presumably expects them in deprotonation
        order, matching ``depSmi`` -- confirm against the caller.
    minph, maxph : float
        Inclusive lower bound and upper limit of the pH grid.
    step : float
        Positive pH increment.
    proSmi, depSmi : list of str
        SMILES of the protonated / deprotonated microstates.

    Returns
    -------
    tuple
        ``(plt, microspecies)``: the ``matplotlib.pyplot`` module the curves
        were drawn on, and a dict mapping fraction index (0 -> a0, 1 -> a1, ...)
        to ``[proSmi[0]] + depSmi``.

    Raises
    ------
    ValueError
        If ``step`` is not positive (the sampling loop would never end).
    """
    if step<=0:
        raise ValueError("step must be positive")

    pka1=pka_list[0]
    pka2=pka_list[1]

    x=[]
    a0=[]
    a1=[]
    a2=[]

    ph=minph

    while ph<=maxph:
        # Sample first, then advance, so the grid starts at minph and never
        # goes past maxph.
        x.append(ph)

        ka1=10**(ph-pka1)
        ka2=10**(ph-pka2)
        E=1+ka1+(ka1*ka2)

        a0.append(round(1/E,pka_dec))
        a1.append(round(ka1/E,pka_dec))
        a2.append(round((ka1*ka2)/E,pka_dec))

        ph+=step

    plt.plot(x,a0,color='red',label='a0')
    plt.plot(x,a1, color='green',label='a1')
    plt.plot(x,a2,color='blue',label='a2')
    plt.legend()

    #make a list of smiles in order
    smiles=[proSmi[0]]+list(depSmi)
    #make a dictionary where the keys are the a_index (ex.0= a0, 1=a1, etc.) and values are the microspecies smiles strings
    microspecies=dict(enumerate(smiles))

    return plt,microspecies

In [12]:
def GetTriPlot(pka_list,minph,maxph,step,proSmi,depSmi):
    """Plot ionization fractions versus pH for three ionization sites.

    With ``kaN = 10**(pH-pKaN)`` the fractions are ``a0 = 1/D``,
    ``a1 = ka1/D``, ``a2 = ka1*ka2/D`` and ``a3 = ka1*ka2*ka3/D`` where
    ``D = 1 + ka1 + ka1*ka2 + ka1*ka2*ka3``. The pH grid starts at ``minph``
    and advances by ``step`` while the value is ``<= maxph``.

    Rounding uses ``pka_dec``, which is not defined in this function and must
    exist as a notebook-level variable before the call.

    Parameters
    ----------
    pka_list : sequence of float
        pKa values; the first three entries are used.
        NOTE(review): the formula presumably expects them in deprotonation
        order, matching ``depSmi`` -- confirm against the caller.
    minph, maxph : float
        Inclusive lower bound and upper limit of the pH grid.
    step : float
        Positive pH increment.
    proSmi, depSmi : list of str
        SMILES of the protonated / deprotonated microstates.

    Returns
    -------
    tuple
        ``(plt, microspecies)``: the ``matplotlib.pyplot`` module the curves
        were drawn on, and a dict mapping fraction index (0 -> a0, 1 -> a1, ...)
        to ``[proSmi[0]] + depSmi``.

    Raises
    ------
    ValueError
        If ``step`` is not positive (the sampling loop would never end).
    """
    if step<=0:
        raise ValueError("step must be positive")

    pka1=pka_list[0]
    pka2=pka_list[1]
    pka3=pka_list[2]

    x=[]
    a0=[]
    a1=[]
    a2=[]
    a3=[]

    ph=minph

    while ph<=maxph:
        # Sample first, then advance, so the grid starts at minph and never
        # goes past maxph.
        x.append(ph)

        ka1=10**(ph-pka1)
        ka2=10**(ph-pka2)
        ka3=10**(ph-pka3)
        D=1+ka1+(ka1*ka2)+(ka1*ka2*ka3)

        a0.append(round(1/D,pka_dec))
        a1.append(round(ka1/D,pka_dec))
        a2.append(round((ka1*ka2)/D,pka_dec))
        a3.append(round((ka1*ka2*ka3)/D,pka_dec))

        ph+=step

    plt.plot(x,a0,color='red',label='a0')
    plt.plot(x,a1, color='green',label='a1')
    plt.plot(x,a2,color='blue',label='a2')
    plt.plot(x,a3,color='yellow',label='a3')
    # The curve labels are only shown once a legend is drawn.
    plt.legend()

    #make a list of smiles in order
    smiles=[proSmi[0]]+list(depSmi)
    #make a dictionary where the keys are the a_index (ex.0= a0, 1=a1, etc.) and values are the microspecies smiles strings
    microspecies=dict(enumerate(smiles))

    return plt,microspecies

In [13]:
def GetMultiPlot(pka_sites,pka_list,minph,maxph,step,proSmi,depSmi):
    """Plot ionization fractions versus pH for any number of ionization sites.

    General form of the mono/di/tri plots. With ``ka_j = 10**(pH-pKa_j)`` the
    unnormalised term for species ``k`` is the running product
    ``t_k = ka_1*...*ka_k`` (``t_0 = 1``), and its fraction is
    ``a_k = t_k / sum(t_0..t_n)``. One curve per species is drawn, labelled
    ``a0`` .. ``a<pka_sites>``. The pH grid starts at ``minph`` and advances
    by ``step`` while the value is ``<= maxph``.

    Rounding uses ``pka_dec``, which is not defined in this function and must
    exist as a notebook-level variable before the call.

    Parameters
    ----------
    pka_sites : int
        Number of ionization sites; the first ``pka_sites`` entries of
        ``pka_list`` are used.
    pka_list : sequence of float
        pKa values.
        NOTE(review): the formula presumably expects them in deprotonation
        order, matching ``depSmi`` -- confirm against the caller.
    minph, maxph : float
        Inclusive lower bound and upper limit of the pH grid.
    step : float
        Positive pH increment.
    proSmi, depSmi : list of str
        SMILES of the protonated / deprotonated microstates.

    Returns
    -------
    tuple
        ``(plt, microspecies)``: the ``matplotlib.pyplot`` module the curves
        were drawn on, and a dict mapping fraction index (0 -> a0, 1 -> a1, ...)
        to ``[proSmi[0]] + depSmi``.

    Raises
    ------
    ValueError
        If ``step`` is not positive, or ``pka_sites`` is smaller than 1 or
        larger than ``len(pka_list)``.
    """
    if step<=0:
        raise ValueError("step must be positive")
    if pka_sites<1 or pka_sites>len(pka_list):
        raise ValueError("pka_sites must be between 1 and len(pka_list)")

    pkas=list(pka_list[:pka_sites])

    x=[]
    # fractions[k] collects a_k over the whole pH grid (k = 0 .. pka_sites)
    fractions=[[] for _ in range(pka_sites+1)]

    ph=minph

    while ph<=maxph:
        # Sample first, then advance, so the grid starts at minph and never
        # goes past maxph.
        x.append(ph)

        # Running products 1, ka1, ka1*ka2, ... are the numerators; their sum
        # is the shared denominator.
        terms=[1]
        for pka in pkas:
            terms.append(terms[-1]*10**(ph-pka))
        D=sum(terms)

        for k,t in enumerate(terms):
            fractions[k].append(round(t/D,pka_dec))

        ph+=step

    #plot
    for k,y in enumerate(fractions):
        plt.plot(x,y,label='a'+str(k))
    plt.legend()

    #make a list of smiles in order
    smiles=[proSmi[0]]+list(depSmi)
    #make a dictionary where the keys are the a_index (ex.0= a0, 1=a1, etc.) and values are the microspecies smiles strings
    microspecies=dict(enumerate(smiles))

    return plt,microspecies