### List of Imports

# In [8]:
import os
import shutil
import subprocess
from collections import defaultdict

import numpy as np
import pandas as pd
import requests
#

# Class object that we utilize as framework

# In [9]:
class MutationalClusterer:
    # The 20 standard amino-acid residue names. In surr_atoms only atoms of these
    # residues are used as search centres (water, ions and ligands are skipped).
    _AMINO_ACIDS = frozenset({
        "VAL", "ALA", "GLY", "TRP", "ARG", "LYS", "LEU", "ILE", "ASP", "ASN",
        "GLN", "GLU", "PRO", "TYR", "PHE", "SER", "THR", "CYS", "MET", "HIS",
    })

    def __init__(self, work_dir, logging=True):
        '''
        General File structure should be something like:
        https://tree.nathanfriend.io/?s=(%27opGHs!(%27fancy8~fullPath!fJse~trailingSlash8~rootDot8)~K(%27K%27work_dir5C0QD64Swissprot.DB4OthA_files_for_Mlookup_in_d6.DB--IntAmediates4F_protein_mulGple_seq2.Ji4summary_staGsGcs29quJity2_df93-Results4*Mdf9*MplotsL*4*additHJ4*5Pymol_output-Example4screenshotL35PDBsQ4D6_parsA_s7%204NeighbFhood_clustAA.py43-3-%27)~vAsiH!%271%27)*%20%20-5*0HsAvaGH2_Jignment3...%2F4-*5%5Cn*6atabase7cripts8!true9.csv4AerFourGtiHonJalKsFce!L.png4Mc0_Q-S7-%01QMLKJHGFA987654320-*
        .
        └── work_dir/
            ├── Conservation/
            │   ├── Scripts
            │   ├── Database/
            │   │   ├── Swissprot.DB
            │   │   └── Other_files_for_conservation_lookup_in_database.DB
            │   ├── Intermediates/
            │   │   ├── our_protein_multiple_seq_alignment.ali
            │   │   ├── summary_statistics_alignment.csv
            │   │   ├── quality_alignment_df.csv
            │   │   └── .../
            │   └── Results/
            │       ├── conservation_df.csv
            │       ├── conservation_plots.png
            │       └── additional
            ├── Pymol_output/
            │   └── Example/
            │       ├── screenshot.png
            │       └── .../
            └── PDBs/
                ├── Scripts/
                │   ├── Database_parser_scripts 
                │   ├── Neighbourhood_clusterer.py
                │   └── .../
                └── .../

        Parameters
        ----------
        work_dir : directory in which intermediate files (query FASTA, mmseqs
            databases, MSA files) are written and from which ``swiss_DB`` is read.
        logging : where result tables such as ``conservation_df.csv`` go.
            ``True`` means ``<work_dir>/logs``. A path (str / PathLike) means
            that directory. ``False`` / ``None`` means result CSVs are not written.
        '''
        self.work_dir = work_dir  # the directory we want to work in
        # Previously the raw bool was stored here, so results were written to
        # a literal "True/" directory. Resolve it to a real path (or None).
        if isinstance(logging, (str, os.PathLike)):
            self.log_dir = os.fspath(logging)
        elif logging:
            self.log_dir = os.path.join(work_dir, "logs")
        else:
            self.log_dir = None

    def _visualize_clusters_pymol(self, pdb:str) -> None:
        """
        Utility function to visualize cluster in pymol.

        Not implemented yet; currently a no-op.
        """
        #from pymol import cmd
        pass

    def _plot_clusters(self, pdb:str) -> object:
        """
        Helper function to plot some statistics or quick interactive plots to investigate clustering.
        Mostly thought about pyplot or plotly interactive plots i.e alignment where we can see the conservation etc.
        https://plotly.com/python/alignment-chart/

        Not implemented yet; currently returns None.
        """

    def _compute_neighbours(self, pdb:str) -> pd.DataFrame:
        # TODO: not implemented yet; currently returns None despite the annotation.
        pass

    def surr_atoms(self, inpath, protname, cutoff=8, outpath=None):
        """Find, for every amino-acid residue, the residues within ``cutoff`` Å.

        A neighbour search is run from every atom of every standard amino-acid
        residue. The hits are merged per residue: unique residue numbers, kept
        in first-seen order, with the residue itself excluded. The result is
        written to ``<outpath>/<protname>`` as lines of the form
        ``"VAL1,2 3 4 \\n"`` and is also returned.

        Parameters
        ----------
        inpath : path of the structure file handed to ``parser.get_structure``.
        protname : structure id passed to the parser; also the output file name.
        cutoff : search radius in Å (default 8).
        outpath : output directory. Defaults to the current working directory
            at call time.

        Returns
        -------
        dict mapping ``"<RESNAME><resnum>"`` to a list of neighbouring residue numbers.

        Notes
        -----
        * Relies on module-level ``parser``, ``Selection`` and ``NeighborSearch``
          that are not defined in this class. They are presumably Biopython's
          ``Bio.PDB`` parser / helpers; confirm where the notebook creates them.
        * Keys and neighbour ids are not chain-qualified. In multi-chain
          structures, residues with the same number share one entry.
        * Neighbours are NOT filtered to amino acids, so water and ligand residue
          numbers can appear in the lists. Only the search centres are filtered.
        """
        if outpath is None:
            # Evaluated per call; a default of os.getcwd() in the signature was
            # frozen at class-definition time.
            outpath = os.getcwd()

        with open(inpath, "r") as pdbfile:
            structure = parser.get_structure(protname, pdbfile)

        # All atoms of the structure ("A" = atom level), used both as search
        # centres and as the search space.
        atom_list = Selection.unfold_entities(structure, "A")
        ns = NeighborSearch(atom_list)

        # residue key -> insertion-ordered set (dict keys) of neighbour residue numbers
        neighbours = defaultdict(dict)
        for atom in atom_list:
            residue = atom.get_parent()
            resname = residue.get_resname()
            if resname not in self._AMINO_ACIDS:
                continue
            # Touch the key first so residues without neighbours still get an entry.
            found = neighbours[f"{resname}{residue.get_id()[1]}"]
            # "R" level: return the residues owning the atoms within cutoff.
            for hit in ns.search(atom.get_coord(), cutoff, "R"):
                if hit is not residue:  # exclude the residue itself
                    # get_id() -> (hetero flag, residue number, insertion code)
                    found[hit.get_id()[1]] = None

        res_dict = {key: list(ids) for key, ids in neighbours.items()}

        # Write once, after all residues are processed.
        with open(os.path.join(outpath, protname), "w") as fh_out:
            for res_key, neighbour_ids in res_dict.items():
                fh_out.write(res_key + "," + "".join(f"{rid} " for rid in neighbour_ids) + "\n")
        return res_dict

    def conservation(self, uniprot_id):
        '''Gets 3 different types of Conservation:
        - Shannon conservation: 
        Shannon entropy. 
        Higher values indicate lower conservation and greater variability at the site.
        
        - Relative conservation:
        Kullback-Leibler divergence.
        Higher values indicate greater conservation and lower variability at the site.
        
        - Lockless conservation
        Evolutionary conservation parameter defined by Lockless and Ranganathan (1999). 
        Higher values indicate greater conservation and lower variability at the site.

        The resulting table is stored on ``self.conservation_df``. It is written
        to ``<log_dir>/conservation_df.csv`` when a log directory is configured,
        and it is returned.

        Raises RuntimeError if the MSA could not be generated.
        '''
        if self.log_dir:
            os.makedirs(self.log_dir, exist_ok=True)

        mmseq_fasta_result = self._mmseq_multi_fasta(uniprot_id=uniprot_id, outdir=self.work_dir)
        if mmseq_fasta_result is None:
            # _mmseq_multi_fasta reports the failure itself and returns None.
            raise RuntimeError(f"MSA generation failed for {uniprot_id}; see the messages above.")
        # get 3 different conservation scores in a pandas df.
        conserv_df = self._get_conservation(path_to_msa=mmseq_fasta_result)
        self.conservation_df = conserv_df

        if self.log_dir:
            conserv_df.to_csv(os.path.join(self.log_dir, "conservation_df.csv"))
        return conserv_df

    def _mmseq_multi_fasta(self, uniprot_id:str, outdir:str,
                      sensitivity=7, filter_msa=0,
                     query_id = 0.6):
        """Build an MSA for ``uniprot_id`` with MMseqs2 against ``<work_dir>/swiss_DB``.

        uniprot_id: The unique uniprot identifier used to fetch the corresponding fasta file that will be used as a template for mmseq2
        outdir: directory where the final ``<uniprot_id>.fasta`` MSA is copied (created if missing).
        sensitivity: mmseq2 specific parameter that goes from 1-7. The higher the more sensitive the search.
        filter_msa = 0 default. if 1 hits are stricter.
        query_id = 0.6 [0, 1]  the higher the more identity with query is retrieved. 1 means ONLY the query hits while 0 means take everything possible.

        Intermediate files (query FASTA, query/result DBs, raw MSA) are written
        to ``self.work_dir``. Returns the path of the final MSA. Returns None if
        an mmseqs step fails or a file operation fails; the error is printed.
        """
        # We search with this sequence as the query.
        trgt_fasta_seq = self._get_gene_fasta(uniprot_id)
        query_fasta = os.path.join(self.work_dir, f"{uniprot_id}_fasta.fa")
        with open(query_fasta, "w") as fasta_out:
            fasta_out.write(f">{uniprot_id}\n{trgt_fasta_seq}\n")

        # The pre-built target database (swiss_DB) is expected in work_dir.
        db_dir = self.work_dir
        query_db = os.path.join(db_dir, "query_fastaDB")
        target_db = os.path.join(db_dir, "swiss_DB")
        result_db = os.path.join(db_dir, "result_DB")
        raw_msa = os.path.join(db_dir, f"{uniprot_id}_out.fasta")
        msa_file = os.path.join(db_dir, f"{uniprot_id}_new_out.fasta")

        # Argument lists (not str.split) so paths containing spaces stay intact.
        commands = [
            # 1. query database from our FASTA
            ["mmseqs", "createdb", query_fasta, query_db],
            # 2. search against swiss_DB -> result_DB (our hits)
            ["mmseqs", "search", query_db, target_db, result_db,
             os.path.join(db_dir, "tmp"), "-s", str(sensitivity)],
            # 3. turn hits into an MSA
            ["mmseqs", "result2msa", query_db, target_db, result_db, raw_msa,
             "--msa-format-mode", "3", "--filter-msa", str(filter_msa),
             "--qid", str(query_id)],
        ]
        try:
            for cmd in commands:
                subprocess.run(cmd, capture_output=True, text=True, check=True)
            with open(raw_msa) as src:
                kept = self._trim_msa_lines(src.readlines())
            with open(msa_file, "w") as dst:
                dst.writelines(kept)
            os.makedirs(outdir, exist_ok=True)
            new_location = os.path.join(outdir, f"{uniprot_id}.fasta")
            shutil.copy(msa_file, new_location)
        except subprocess.CalledProcessError as error:
            print(f"{' '.join(map(str, error.cmd))} failed (exit {error.returncode}):\n{error.stderr}")
            return None
        except OSError as error:  # e.g. mmseqs not installed, missing files
            print(error)
            return None
        # Path to the MSA file for downstream analysis.
        return new_location

    @staticmethod
    def _trim_msa_lines(lines):
        """Drop the first four lines and the last line of the raw MSA.

        Equivalent to ``sed -e 1,4d -e '$d'``. The dropped lines are presumably
        header/trailer lines emitted by ``result2msa --msa-format-mode 3``;
        confirm against the mmseqs output.
        """
        return lines[4:-1]

    def _get_gene_fasta(self, uniprot_id:str):
        '''
        Helper function to grab the sequence 
        based on the Uniprot ID

        Only the first FASTA record of the UniProt search result is used.
        Raises RuntimeError if the request fails.
        '''
        fields = "sequence"
        url = f"https://rest.uniprot.org/uniprotkb/search?format=fasta&fields={fields}&query={uniprot_id}"
        resp = self._get_url(url)
        if not resp.ok:
            raise RuntimeError(f"UniProt request for {uniprot_id} failed with HTTP {resp.status_code}")
        return self._first_fasta_sequence(resp.iter_lines(decode_unicode=True))

    @staticmethod
    def _first_fasta_sequence(lines):
        """Return the concatenated sequence of the first record in FASTA ``lines``.

        Stops at the second header, so a search returning several entries does
        not leak the other entries' headers or sequences into the result.
        """
        parts = []
        seen_header = False
        for line in lines:
            line = line.strip()
            if line.startswith(">"):
                if seen_header:
                    break
                seen_header = True
            elif line:
                parts.append(line)
        return "".join(parts)

    def _get_conservation(self, path_to_msa:str):    
        '''
        Helper function to compute 3 different types of conservation.
        
        - Shannon conservation: 
        Shannon entropy. 
        Higher values indicate lower conservation and greater variability at the site.
        
        - Relative conservation:
        Kullback-Leibler divergence.
        Higher values indicate greater conservation and lower variability at the site.
        
        - Lockless conservation
        Evolutionary conservation parameter defined by Lockless and Ranganathan (1999). 
        Higher values indicate greater conservation and lower variability at the site.

        NOTE(review): ``Canal`` is not imported in this class; presumably
        ``pycanal.Canal``. Confirm where the notebook imports it.
        '''
        canal = Canal(fastafile=path_to_msa, # Multiple sequence alignment (MSA) of homologous sequences
          ref=0, # Position of reference sequence in MSA, use first sequence always
          startcount=0, # Always 0 because our seqs are always from 1 - end
          verbose=False) # no verbosity 
    
        result_cons = canal.analysis(method="all")
        return result_cons

    def _get_url(self, url, timeout=30):
        '''Helper function that uses requests for Downloads.

        Prints the response body when the status is not OK and returns the
        response either way. Network errors (connection failures, timeouts)
        propagate as ``requests`` exceptions instead of being masked.
        '''
        response = requests.get(url, timeout=timeout)
        if not response.ok:
            print(response.text)
        return response