# In [None]:
from Bio.PDB import *
from Bio import PDB
from torch import nn
import numpy as np
from ase import Atoms, Atom
import dask.dataframe as dd
from ordered_set import OrderedSet
import os
from collections import Counter
from collections import defaultdict
from itertools import chain, product
import sys

# Root of the local pKaSchNet checkout, and the path of the PYPKA pKa CSV inside it.
local_folder = "/Users/jessihoernschemeyer/pKaSchNet"
pkPDB_CSV = local_folder + "/pkas.csv"
def read_database(path):
    """Load the PYPKA ``;``-separated pKa CSV into a sorted pandas DataFrame.

    Columns are renamed to match the pKAD naming used elsewhere
    (``PDB ID``, ``Res ID``, ``Res Name``, ``pKa``, ``Chain``). Rows are sorted
    by PDB ID, then residue number. The index is reset, so the previous index
    becomes an ``index`` column.

    Args:
        path: path of the CSV file.

    Returns:
        pandas.DataFrame with the renamed, sorted rows.
    """
    # Residue numbers routinely exceed 255 (and may be negative in PDB files),
    # so a uint8 column would overflow; int32 holds every PDB residue number.
    dtypes = {
        "idcode": "category",
        "residue_number": "int32",
        "residue_name": "category",
        "chain": "category",
        "pk": "float32",
    }
    renames = {  # match the column names of the pKAD data frame
        "idcode": "PDB ID",
        "residue_number": "Res ID",
        "residue_name": "Res Name",
        "pk": "pKa",
        "chain": "Chain",
    }
    ddf = dd.read_csv(path, delimiter=";", na_filter=False, dtype=dtypes)
    ddf = ddf.rename(columns=renames)
    # Materialise first and sort with pandas, which handles multi-column
    # sorting (with a per-column ``ascending`` list) natively.
    df = ddf.compute()
    df = df.sort_values(["PDB ID", "Res ID"], ascending=[True, True])
    return df.reset_index()

def check_atoms_protein(structure, struc_atoms):
    """Screen a structure for metals / hetero sulfur and strip nearby salt ions.

    Internal function. Atoms are classified by their element symbol:

    * any MG, MN, FE, CO, NI, CU or ZN atom rejects the structure;
    * any sulfur in a hetero residue rejects the structure;
    * a CA, CL, K or NA atom in a hetero residue that lies within 3 Å of the
      geometric center of a GLU/HIS/ASP/ARG/TYR/CYS/LYS residue is detached
      from its parent residue (``structure`` is modified in place).

    Args:
        structure: Biopython structure to screen.
        struc_atoms: iterable of its atoms, e.g. ``structure.get_atoms()``.

    Returns:
        ``structure`` if it passed the screen, otherwise ``0``.
    """
    metals = {"MG", "MN", "FE", "CO", "NI", "CU", "ZN"}
    salts = {"CA", "CL", "K", "NA"}
    titratable = {"GLU", "HIS", "ASP", "ARG", "TYR", "CYS", "LYS"}
    titratable_centers = None  # computed at the first salt ion, then reused

    # Materialise the atoms first: detaching salt ions while a generator walks
    # the same residues would silently skip atoms.
    for atom in list(struc_atoms):
        # Classify by element, not atom name: names need not equal the element
        # (e.g. "FE1"), and "CA" is also the alpha-carbon name.
        element = (atom.element or "").strip().upper()
        if element in metals:
            return 0

        if atom.get_parent().get_id()[0] == " ":
            continue  # standard residue: only hetero residues are screened further

        if element == "S":
            print(f"{atom.get_full_id()[2:]}, hetero sulfur. pdb skipped ")
            return 0

        if element in salts:
            if titratable_centers is None:
                titratable_centers = [
                    res.center_of_mass(geometric=True)
                    for res in structure.get_residues()
                    if res.get_resname() in titratable
                ]
            coord = atom.get_coord()
            # Salt ion sitting on a titratable residue: remove the ion (once).
            if any(np.linalg.norm(center - coord) < 3 for center in titratable_centers):
                atom.get_parent().detach_child(atom.get_id())

    return structure

def atoms_to_structure(cutout, filename):
    """Write a cutout (Biopython atoms) to ``cuts/{filename}.pdb``.

    Atoms are regrouped into chains and residues with the ids they had in the
    source structure. GLU, ASP and HIS are renamed to their protonated Amber
    names (GLH, ASH, HIP) so tleap builds them protonated. Copies of the atoms
    are written, so the source structure is left untouched.

    Args:
        cutout: iterable of Biopython ``Atom`` objects (NOT ASE atoms).
        filename: output file stem; also used as the structure id.
    """
    rename = {"GLU": "GLH", "ASP": "ASH", "HIS": "HIP"}

    structure = Structure.Structure(filename)
    model = Model.Model(0)
    structure.add(model)
    chains = {}

    def residue_key(atom):
        """(chain id, residue number, insertion code) of the atom's residue."""
        res = atom.get_parent()
        return res.get_full_id()[2], res.get_id()[1], res.get_id()[2]

    # Write residues in sequence order (tleap connects consecutive residues);
    # sorted() is stable, so atom order within a residue is preserved.
    for atom in sorted(cutout, key=residue_key):
        res = atom.get_parent()
        res_id, chain_id = res.get_id(), res.get_full_id()[2]
        resname = rename.get(res.get_resname(), res.get_resname())

        chain = chains.get(chain_id)
        if chain is None:
            chain = Chain.Chain(chain_id)
            chains[chain_id] = chain
            model.add(chain)

        if chain.has_id(res_id):
            residue = chain[res_id]
        else:
            residue = Residue.Residue(res_id, resname, "")
            chain.add(residue)

        # Residue.add() re-parents the atom; add a detached copy so the source
        # structure (shared by overlapping cutouts) is not corrupted.
        residue.add(atom.copy())

    os.makedirs("cuts", exist_ok=True)
    io = PDBIO()
    io.set_structure(structure)
    io.save(f"cuts/{filename}.pdb")

def generate_cutout_around_protonatable_site(residue, distance_cutoff, ns, counter, resname):
    """Collect the atoms within ``distance_cutoff`` of one residue's titratable site.

    Residue-wise resolution: ``ns`` is a NeighborSearch over the entire protein
    and ``residue`` is a single PYPKA data point.

    Args:
        residue: Biopython Residue holding the titratable site.
        distance_cutoff: search radius (same units as the coordinates).
        ns: ``Bio.PDB.NeighborSearch`` built over the whole protein.
        counter: data-point id stored in every returned tuple.
        resname: site code -- ``0`` for NTR, ``1`` for CTR, otherwise the
            first letter of the residue name (``"G"``, ``"A"``, ``"C"``,
            ``"L"``, ``"T"``, ``"H"``).

    Returns:
        A list with at most one element:

        * ``(counter, center, label, atoms)`` for every site except histidine;
        * ``[(counter + .1, ...), (counter + .2, ...)]`` for histidine, one
          tuple per ring nitrogen (NE2, then ND1).

        ``label`` is ``"N"`` (NTR), ``"X"`` (CTR) or ``resname``. For histidine
        and the side-chain sites, ``"D"`` is appended for each disordered site
        atom. The list is empty if the code has no titratable site or a
        required site atom is missing.
    """
    two_atom_sites = {"G": ("OE1", "OE2"), "A": ("OD1", "OD2")}  # carboxylates
    one_atom_sites = {"C": "SG", "L": "NZ", "T": "OH"}

    def search(center):
        return ns.search(center, distance_cutoff, "A")

    try:
        if resname == 0:  # NTR: backbone nitrogen
            center = residue["N"].get_coord()
            return [(counter, center, "N", search(center))]

        if resname == 1:  # CTR: terminal oxygen, else the carbonyl carbon
            site = residue["OXT"] if residue.has_id("OXT") else residue["C"]
            center = site.get_coord()
            return [(counter, center, "X", search(center))]

        if resname in two_atom_sites:
            first, second = (residue[name] for name in two_atom_sites[resname])
            if first.is_disordered():
                center, label = first.get_coord(), resname + "D"
            elif second.is_disordered():
                center, label = second.get_coord(), resname + "D"
            else:  # midpoint of the two oxygens
                center, label = (first.get_coord() + second.get_coord()) / 2.0, resname
            return [(counter, center, label, search(center))]

        if resname in one_atom_sites:
            site = residue[one_atom_sites[resname]]
            label = resname + "D" if site.is_disordered() else resname
            center = site.get_coord()
            return [(counter, center, label, search(center))]

        if resname == "H":  # one site per ring nitrogen, ids counter+.1 / counter+.2
            ne2, nd1 = residue["NE2"], residue["ND1"]
            label = resname
            if ne2.is_disordered():
                label += "D"
            if nd1.is_disordered():
                label += "D"
            center1, center2 = ne2.get_coord(), nd1.get_coord()
            return [[
                (counter + .1, center1, label, search(center1)),
                (counter + .2, center2, label, search(center2)),
            ]]
    except KeyError:
        # A required site atom is missing (e.g. an unresolved side chain); the
        # caller treats an empty result as "skip this data point".
        return []

    return []  # code without a titratable site

def merge_or_not_cutouts(cutouts_apdb, distance_cutoff):  # TODO: reduce dtypes #PDB WISE!
    """Merge each site's cutout with its nearest neighbouring site, or keep it solo.

    Args:
        cutouts_apdb: every site of one PDB, in data-point order. Each element
            is a tuple ``(id, center, label, atoms)`` or, for histidine, a list
            of two such tuples that count as two consecutive sites.
        distance_cutoff: two sites are neighbours when their centers are
            closer than this.

    Returns:
        ``(cutouts, merged_partner_indices)``. ``cutouts`` has one entry per
        (flattened) site, so ``len(cutouts)`` equals the number of sites:

        * ``(atoms, pair_label, (center_i, center_j))`` -- merged with the
          nearest neighbour ``j``. ``atoms`` is the de-duplicated union and
          ``pair_label`` is ``label_i + label_j``.
        * ``None`` -- this pair was already emitted by an earlier site.
        * ``(atoms, center)`` -- no other site within ``distance_cutoff``.

        ``merged_partner_indices`` holds ``j`` for each merged entry, in order.
        Ties between equally close neighbours go to the lower index.
    """
    sites = []
    for site in cutouts_apdb:
        if isinstance(site, tuple):
            sites.append(site)
        else:  # histidine: one site per ring nitrogen
            sites.extend((site[0], site[1]))

    cutouts, merged_partners, done_pairs = [], [], set()
    n = len(sites)
    if n == 0:
        return cutouts, merged_partners

    centers = [s[1] for s in sites]
    labels = [s[2] for s in sites]
    cuts = [s[3] for s in sites]

    coords = np.asarray(centers, dtype=np.float64).reshape(n, -1)
    distances = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    np.fill_diagonal(distances, np.inf)  # a site is not its own neighbour
    distances[distances >= distance_cutoff] = np.inf

    for i in range(n):
        j = int(np.argmin(distances[i]))
        if not np.isfinite(distances[i, j]):  # solo cutout
            cutouts.append((cuts[i], centers[i]))
            continue
        pair = frozenset((i, j))  # order-independent: (i, j) == (j, i)
        if pair in done_pairs:
            cutouts.append(None)
            continue
        done_pairs.add(pair)
        merged_partners.append(j)
        cutouts.append(
            (list(set(cuts[i]) | set(cuts[j])), labels[i] + labels[j], (centers[i], centers[j]))
        )

    return cutouts, merged_partners

def get_cutout(dask_df, distance_cutoff, pdb_indices=range(19, 20)):  # "PARENT" FUNCTION
    """Build, merge and save the titratable-site cutouts of the selected PDBs.

    For each selected PDB id (``pdbs[i]`` for ``i`` in ``pdb_indices``):

    1. Download the structure from RCSB with Biopython (into ``PDB/``).
    2. Screen it with ``check_atoms_protein``. Structures with metals or hetero
       sulfur are skipped entirely.
    3. For every residue of that PDB listed in ``dask_df``, generate a cutout of
       radius ``distance_cutoff`` around its titratable site.
    4. Merge nearby sites with ``merge_or_not_cutouts``.
    5. Write every resulting cutout with ``atoms_to_structure``.

    A data point is skipped (not fatal) when any of these holds:

    * its residue is missing from the structure;
    * the PYPKA residue name disagrees with the PDB residue name;
    * no cutout can be generated for the site.

    Args:
        dask_df: pandas DataFrame from ``read_database`` (the whole PYPKA
            database).
        distance_cutoff: cutout radius, also used as the merge distance.
        pdb_indices: indices into the module-level ``pdbs`` list to process
            (defaults to the single index 19).

    Returns:
        ``(all_fnames, all_cuts, all_centers)``. Each has one list per PDB that
        passed the screen, holding, respectively:

        * the saved file names;
        * the raw per-site cutouts;
        * the center(s) of each saved cutout (a tuple of two for merged ones).
    """
    all_fnames, all_cuts, all_centers = [], [], []
    pdb_list = PDBList()

    for i in pdb_indices:
        pdbname = pdbs[i]
        path = pdb_list.retrieve_pdb_file(pdbname.lower(), obsolete=False, pdir="PDB", file_format="pdb")
        raw_structure = pdb_parser.get_structure("", path)
        structure = check_atoms_protein(raw_structure, raw_structure.get_atoms())
        if not structure:  # undesirables present: skip all PYPKA entries of this PDB
            continue

        ns = PDB.NeighborSearch(list(structure.get_atoms()))  # whole protein
        # Only the residues of this PDB that have a PYPKA entry.
        pdb_df = dask_df[dask_df["PDB ID"] == pdbname]

        cutouts_apdb, fnames, counter = [], [], 0
        for chain_id, res_number, pypka_resname in zip(
            pdb_df["Chain"], pdb_df["Res ID"], pdb_df["Res Name"]
        ):
            res_id = int(res_number)
            try:
                residue = structure[0][chain_id][res_id]
            except KeyError:
                print(f"{pdbname}: residue {chain_id}{res_id} not found, skipped")
                continue

            pdb_resname = residue.get_resname()
            terminus_resname = None
            if pypka_resname == "NTR":
                code, pypka_resname, terminus_resname = 0, "N", pdb_resname
            elif pypka_resname == "CTR":
                code, pypka_resname, terminus_resname = 1, "X", pdb_resname  # carboxyl
            elif pypka_resname == pdb_resname:
                code = pypka_resname[0]
            else:
                continue  # PYPKA / PDB residue-name mismatch

            site_cutouts = generate_cutout_around_protonatable_site(
                residue, distance_cutoff, ns, counter, code
            )
            if not site_cutouts:
                continue  # no titratable site or missing site atom
            cutouts_apdb.extend(site_cutouts)

            stem = f"{pdbname}{chain_id}{res_id}_{pypka_resname}"
            if terminus_resname is not None:
                fnames.append(f"{stem}{counter}-{terminus_resname}")
            elif code == "H":  # one name per ring-nitrogen site
                fnames.append(f"{stem}{counter + .1}")
                fnames.append(f"{stem}{counter + .2}")
            else:
                fnames.append(f"{stem}{counter}")
            counter += 1

        # os.remove(f"{local_folder}/PDB/pdb{pdbname}.ent")  # TODO
        new_fnames, centers_apdb = [], []
        if cutouts_apdb:
            merged_and_solos, partner_indices = merge_or_not_cutouts(cutouts_apdb, distance_cutoff)
            partners = iter(partner_indices)
            for cut, fname in zip(merged_and_solos, fnames):
                if not cut:
                    continue  # pair already written under its partner's name
                if len(cut) == 3:  # merged: (atoms, pair_label, (center_i, center_j))
                    merged_name = f"{fname}_{fnames[next(partners)]}_{cut[1]}"
                    new_fnames.append(merged_name)
                    centers_apdb.append(cut[2])
                    atoms_to_structure(cut[0], merged_name)
                else:  # solo: (atoms, center)
                    new_fnames.append(fname)
                    centers_apdb.append(cut[1])
                    atoms_to_structure(cut[0], fname)

        all_fnames.append(new_fnames)
        all_cuts.append(cutouts_apdb)
        all_centers.append(centers_apdb)

    return all_fnames, all_cuts, all_centers



def amber(
    input_pdb,
    *,
    script_path="ascript.py",
    cuts_dir="/Users/jessihoernschemeyer/pKaSchNet/cuts",
    prot_dir="/Users/jessihoernschemeyer/pKaSchNet/prot",
    amino_lib="/Users/jessihoernschemeyer/miniconda3/envs/cfcnn/dat/leap/lib/amino19.lib",
):
    """Write a tleap input script that protonates one saved cutout.

    The script loads ff14SB and TIP3P, loads ``amino_lib``, reads
    ``{cuts_dir}/{input_pdb}.pdb`` and saves the result as
    ``{prot_dir}/{input_pdb}.pdb``.

    Args:
        input_pdb: file stem of the cutout (without ``.pdb``).
        script_path: where the tleap script is written.
        cuts_dir: directory of the input cutouts.
        prot_dir: directory for the protonated output.
        amino_lib: Amber amino-acid library loaded with ``loadOff``.
    """
    commands = [
        "source leaprc.protein.ff14SB",
        "source leaprc.water.tip3p",
        f'loadOff "{amino_lib}"',
        f'mol = loadpdb "{cuts_dir}/{input_pdb}.pdb"',
        f'savepdb mol "{prot_dir}/{input_pdb}.pdb"',
        "",
        "quit",
    ]
    with open(script_path, "w", encoding="utf-8") as file:
        file.write("\n".join(commands) + "\n")


# Site code -> residue name in the protonated cutouts (atoms_to_structure renames
# GLU/ASP/HIS to GLH/ASH/HIP); "X" and "N" are the C- and N-terminus codes.
protonatable_sites2 = {
    "G": "GLH",
    "C": "CYS",
    "L": "LYS",
    "A": "ASH",
    "H": "HIP",
    "T": "TYR",
    "X": "CTR",
    "N": "NTR",
}

# Site code -> name of the titratable hydrogen removed when deprotonating.
protonatable_sites = {
    "G": "HE2",
    "C": "HG",
    "L": "HZ1",
    "A": "HD2",
    "H": ("HD1", "HE2"),  # both histidine ring hydrogens; this needs to be fixed
    "T": "HH",
}

def deprotonate_singles(cut, res):  # single cutouts only, NOT histidine pairs
    """Remove one titratable proton from a single (non-merged) cutout, in place.

    ``cut`` must be a Biopython entity that provides ``get_residues()`` (e.g. a
    structure parsed from a protonated cutout). ``res`` selects what to remove:

    * a site code of ``protonatable_sites`` other than ``"H"`` (``"G"``,
      ``"C"``, ``"L"``, ``"A"``, ``"T"``): the hydrogen
      ``protonatable_sites[res]`` is detached from the first residue named
      ``protonatable_sites2[res]`` that has it;
    * a histidine ring-hydrogen name (``"HD1"`` or ``"HE2"``): that hydrogen
      is detached from every HIP residue that has it;
    * a terminus code followed by a residue name (``"XPHE"``, ``"NARG"``):
      handled by ``deprotonate_terminus_single``.

    Returns:
        ``cut`` (modified in place).

    Raises:
        ValueError: if ``res`` matches none of the forms above.
    """
    if res in protonatable_sites["H"]:  # histidine: res IS the atom to delete
        for residue in cut.get_residues():
            if residue.get_resname() == "HIP" and residue.has_id(res):
                residue.detach_child(res)
        return cut

    if res in protonatable_sites and res != "H":
        atom_to_delete = protonatable_sites[res]
        target = protonatable_sites2[res]
        for residue in cut.get_residues():
            if residue.get_resname() == target and residue.has_id(atom_to_delete):
                residue.detach_child(atom_to_delete)
                break
        return cut

    if len(res) > 1 and res[0] in ("X", "N"):  # CTR / NTR, e.g. "XPHE"
        deprotonate_terminus_single(cut, res)
        return cut

    raise ValueError(
        f"cannot deprotonate {res!r}: expected a site code, HD1/HE2 or a terminus code"
    )

def deprotonate_terminus_single(cut, res):
    """Deprotonate the N- or C-terminus named by ``res`` in ``cut``, in place.

    Args:
        cut: Biopython entity providing ``get_residues()``.
        res: terminus code followed by the residue name, e.g. ``"XPHE"``
            (CTR) or ``"NARG"`` (NTR).

    Returns:
        ``cut``. Only the first residue named ``res[1:]`` is modified; ``cut``
        is returned unchanged if no residue has that name.
    """
    ter, resi = res[0], res[1:]
    for residue in cut.get_residues():
        if residue.get_resname() != resi:
            continue
        # NOTE(review): the atoms detached here are heavy atoms (OXT / C / N),
        # not hydrogens -- confirm this is the intended "deprotonation".
        if ter == "X":  # CTR: prefer OXT, fall back to the carbonyl carbon
            if residue.has_id("OXT"):
                residue.detach_child("OXT")
            elif residue.has_id("C"):
                residue.detach_child("C")
        elif residue.has_id("N"):  # NTR
            residue.detach_child("N")
        return cut
    return cut

def dp_GleichRes2Mal(cut, res):  # two residues of the same type; no termini
    """Deprotonate each of two same-type residues of a merged cutout.

    Works on copies; ``cut`` itself is not modified. ``cut`` must be a
    Biopython entity providing ``get_residues()`` and ``copy()``. Residues A
    and B are the first two residues (in ``get_residues()`` order) named
    ``protonatable_sites2[res]`` that carry the titratable hydrogen(s).

    Args:
        cut: protonated merged cutout.
        res: site code shared by both residues (``"G"``, ``"A"``, ``"C"``,
            ``"L"``, ``"T"`` or ``"H"``).

    Returns:
        * non-histidine codes: ``(Adp, Bdp)`` -- copies with only residue A,
          resp. only residue B, deprotonated;
        * ``"H"``: a list of 16 ``(cut_copy, label)`` pairs, one per
          combination of the states HIP (both ring H kept), HIE (HD1 removed),
          HID (HE2 removed) and HIS (both removed) for residues A and B.
          ``label`` is the two state names concatenated, e.g. ``"HIEHIP"``.

    Raises:
        ValueError: if fewer than two matching residues are found.
    """
    atom_to_delete = protonatable_sites[res]
    residue_name = protonatable_sites2[res]
    hydrogens = atom_to_delete if isinstance(atom_to_delete, tuple) else (atom_to_delete,)

    # (chain id, residue id) keys survive cut.copy(); residue objects do not.
    keys = [
        (residue.get_parent().get_id(), residue.get_id())
        for residue in cut.get_residues()
        if residue.get_resname() == residue_name
        and any(residue.has_id(h) for h in hydrogens)
    ][:2]
    if len(keys) < 2:
        raise ValueError(
            f"expected two {residue_name} residues carrying {hydrogens}, found {len(keys)}"
        )
    key_a, key_b = keys

    def copy_without(removals):
        """Copy ``cut`` and detach the atom names ``removals[key]`` per residue key."""
        new_cut = cut.copy()
        for residue in new_cut.get_residues():
            key = (residue.get_parent().get_id(), residue.get_id())
            for name in removals.get(key, ()):
                if residue.has_id(name):
                    residue.detach_child(name)
        return new_cut

    if isinstance(atom_to_delete, tuple):  # histidine
        delta, epsilon = atom_to_delete  # ("HD1", "HE2")
        states = {
            "HIP": (),
            "HIE": (delta,),
            "HID": (epsilon,),
            "HIS": (delta, epsilon),
        }
        return [
            (copy_without({key_a: states[a], key_b: states[b]}), a + b)
            for a, b in product(states, repeat=2)
        ]

    return (
        copy_without({key_a: (atom_to_delete,)}),
        copy_without({key_b: (atom_to_delete,)}),
    )

def recut_and_deprotonate(fnames_apdb, centers_apdb, distance_cutoff): #the centers come in #fnames after protonation
    """Re-cut one PDB's protonated cutouts and save their deprotonated variants.

    For each ``(fname, center)`` pair:

    1. Parse the protonated cutout
       ``/Users/jessihoernschemeyer/pKaSchNet/prot/{fname}.pdb`` with the
       module-level ``pdb_parser``.
    2. Build a NeighborSearch over its atoms.
    3. Cut out the sphere(s) of radius ``distance_cutoff`` around the stored
       center(s) again.
    4. Pass the fully protonated cut and each deprotonated variant to ``Save``,
       under ``fname`` plus a suffix (``~{res}d``, ``~d``, ``~HIE``, ``~HID``,
       ``~HIS``, ``~D``, ...).

    Args:
        fnames_apdb (list[str]): file names of one PDB's cutouts, e.g.
            ``['199lA11_GLU3', '199lA10_ASP2_199lA161_TYR37_AT',
            '199lA70_ASP22_199lA31_HIS11.2_AH']``. A merged name is
            ``{siteA}_{siteB}_{pair label}``. The first two characters of the
            pair label are read as the site codes of residue A and residue B.
        centers_apdb (list): titratable-site center for each file name -- a
            tuple of two centers for a merged cutout, one coordinate otherwise.
        distance_cutoff: radius of the re-cut sphere(s).

    NOTE(review): this function cannot run as written:

    * ``Save`` and ``dp_HH`` are not defined in the visible code.
    * ``if resA or resB in ('X', 'N')`` parses as ``resA or (resB in ...)``.
      It is therefore true for every merged cut, so the later ``elif``
      branches are unreachable.
    * ``singles_res`` is assigned only in the merged branch but is read in the
      single branch. That gives a NameError, or a stale value.
    * ``str.dplit`` does not exist (``split`` intended?).
    * ``res1``/``res2`` are lists from ``split("-")``, so ``resA+res1`` raises
      TypeError. Also, only one of the two is assigned on each path.
    * Several calls use ``deprotonate_singles(res, cut)``, although that
      function's signature is ``(cut, res)``.
    * ``cut`` is unassigned (or stale) in the ``resA == resB`` and histidine
      branches.
    * Where ``cut`` is assigned, it is a set/list of atoms, but the
      deprotonation helpers call ``get_residues()`` on their cut.
    * The ``resA == 'H'`` branch saves ``resAd_HIE``/``resAd_HID``/
      ``resAd_HIS``, but only ``resBd_*`` exist there.
    * ``Save(HID, fname + '~HIS')`` saves HID under the HIS name.
    * The single-histidine branch removes ``"OD1"``/``"OE2"``. Those are not
      the histidine hydrogens (``HD1``/``HE2``).
    * The single terminus branch passes only the residue name (``L[1]``,
      e.g. ``"PHE"``). ``deprotonate_terminus_single`` instead expects a
      terminus code plus the residue name (``"XPHE"``).
    """
    for fname, center in zip(fnames_apdb, centers_apdb):
        struct = pdb_parser.get_structure("",  f'/Users/jessihoernschemeyer/pKaSchNet/prot/{fname}.pdb')
        ns = PDB.NeighborSearch(list(struct.get_atoms())) #set up ns , entire protein

        #############MERGEDD#####
        # A tuple of two centers marks a merged cutout.
        if type(center)==tuple:
            f=fname.split("_")
            singles_res, key=f[1],f[4] #22X-PHE, AT, XASP
            resA, resB = key[0],key[1] #A,T

            #NTR/CTR (cannot combine NTR and CTR.) MERGED
            # NOTE(review): always true -- see the docstring.
            if resA or resB in ('X', 'N'): #need to deal wih his in here???
                if resA in ('X', 'N'):
                    res1 = f[1].split("-")
                    cut = set(ns.search(center[0], distance_cutoff, "A")) | set(ns.search(center[1], distance_cutoff, "A")) #the merged cut
                    Save(cut, fname) #fully protonated

                else: #elif resB in ('X', 'N'):
                    res2 = f[3].split("-")
                    cut = set(ns.search(center[0], distance_cutoff, "A")) | set(ns.search(center[1], distance_cutoff, "A")) #the merged cut
                    Save(cut, fname) #fully protonated

                resAdp = deprotonate_singles(resA+res1, cut.copy()) #xphe
                Save(resAdp, fname + f'~{resA}d') #resA deprot, B prot

                resBdp = deprotonate_singles(resB+res2, cut)
                Save(resBdp, fname + f"~{resB}d") #A prot, resB deprot

                deprotonated = set(resAdp) | set(resBdp)  #fully deprotonated #not making a new cut but combining what is done
                Save(deprotonated, fname + "~d")

            elif resA == resB: #no ntr and ctr will get here!
                if resB == 'H' or resA=='H':  #double H
                    # NOTE(review): dp_HH is not defined in the visible code.
                    his_cuts = dp_HH("H", cut)
                    for cut in his_cuts:
                        Save(cut[0], fname + f'{cut[1]}')

                else: #normal AA, GG..
                  resAdp, resBdp = dp_GleichRes2Mal(cut, resA)   ##AA, GG... #res A is res B
                Save(resAdp, fname + f'~{resA}d') #resA deprot, B prot
                Save(resBdp, fname + f"~{resB}d") #A prot, resB deprot
                Save(set(resBdp) | set(resAdp), fname + '~d')

            elif resB == 'H' or resA=='H': #histidine
                    if resA == 'H': #1 his #HA
                        HIE = deprotonate_singles("HD1", cut.copy())
                        Save(HIE, fname + '~HIE') #A, eps prot with B prot
                        HID = deprotonate_singles("HE2", cut.copy())
                        Save(HID, fname + '~HID') #A, delta protonated
                        HIS = deprotonate_singles("HD1", HID.copy())
                        Save(HIS, fname + '~HIS') #A fully deprotonated

                        resBdp = deprotonate_singles(resB, cut.copy())

                        #remaining combos with deprotonated other nonhis and nonter residue
                        # NOTE(review): resAd_* below are not defined in this branch (resBd_* are).
                        resBd_HIE = set(resBdp) | set(HIE)
                        Save(resAd_HIE, fname + f'~{resB}d+HIE')
                        resBd_HID = set(resBdp) | set(HID)
                        Save(resAd_HID, fname + f'~{resB}d+HID')
                        resBd_HIS = set(resBdp) | set(HIS)
                        Save(resAd_HIS, fname + f'~{resB}d+HIS')

                    else: #ResB = H. then this means that the second res is the his. AH
                        HIE = deprotonate_singles("HD1", cut.copy())
                        Save(HIE, fname + '~HIE') #HIE with A =HIP
                        HID = deprotonate_singles("HE2", cut.copy())
                        Save(HID, fname + '~HID')
                        HIS = deprotonate_singles("HD1", HID.copy())
                        # NOTE(review): saves HID (not HIS) under the '~HIS' name.
                        Save(HID, fname + '~HIS') #fully deprotonated

                        resAdp = deprotonate_singles(resA, cut.copy())

                        #make combos with sets
                        resAd_HIE = set(resAdp) | set(HIE) #after resA is already deprotonated
                        Save(resAd_HIE, fname + f'~{resA}d+HIE')
                        resAd_HID = set(resAdp) | set(HID)
                        Save(resAd_HID, fname + f'~{resA}d+HID')
                        resAd_HIS = set(resAdp) | set(HIS)
                        Save(resAd_HIS, fname + f'~{resA}d+HIS')

            else:  #normal merged
                    cut = set(ns.search(center[0], distance_cutoff, "A")) | set(ns.search(center[1], distance_cutoff, "A")) #the merged cut
                    Save(cut, fname) #fully protonated, both

                    resAdp = deprotonate_singles(resA, cut.copy()) #resA = "A", "G". resAdp = a cut
                    Save(resAdp, fname + f'~{resA}d') #resA deprot, B prot

                    resBdp = deprotonate_singles(resB, cut.copy())
                    Save(resBdp, fname + f"~{resB}d") #A prot, resB deprot

                    deprotonated = set(resAdp) | set(resBdp)  #fully deprotonated #not making a new cut but combining what is done
                    Save(deprotonated, fname + "~d")

        else: #single
            cut = ns.search(center, distance_cutoff, "A")
            Save(cut, fname) #fully protonated

            # NOTE(review): singles_res is only assigned in the merged branch above.
            L = singles_res.split("-")
            if len(L)==2: #TER
                 deprotonated = deprotonate_singles(cut, L[1]) #L[1] = "PHE"
                 Save(deprotonated, fname + "~D")

            else:
                # NOTE(review): ``dplit`` is not a str method.
                if len(L[0].dplit(".")) == 2: #HIS
                    HIP=cut.copy()

                    HIE = deprotonate_singles(cut, "OD1")
                    Save(HIE, fname + "HIE")

                    HIS = deprotonate_singles(HIE.copy(), "OE2")
                    Save(HIS, fname + "HIS")

                    Save(HIS.union(HIP - HIE), fname + "HID")

                else: #regular
                    deprotonated = deprotonate_singles(cut, fname.split("_")[1][0]) #[A]SP22, [X]25PHE.. #not his
                    Save(deprotonated, fname + "~D")
            
# Build the cutouts of the selected PDB(s), then protonate every saved cutout with tleap.
import subprocess  # replaces the IPython-only ``!tleap`` shell magic

dask_df = read_database(pkPDB_CSV)  # pkPDB_CSV already includes local_folder
pdb_parser, pdbs = PDB.PDBParser(), list(OrderedSet(list(dask_df["PDB ID"])))

fs, all_cuts, all_centers = get_cutout(dask_df, 5)

for pdb_fnames in fs:  # fs is empty when every selected PDB was skipped
    for f in pdb_fnames:
        amber(f)
        # amber() writes its tleap script to the relative path "ascript.py".
        result = subprocess.run(["tleap", "-s", "-f", "ascript.py"])
        if result.returncode != 0:
            print(f"tleap failed for {f} (exit code {result.returncode})")

