In [None]:
class DownloadPipe:
    '''Class object containing the download function that will download all pdbs 
    which we need for downstream analysis of a particular uniprot ID'''

    def __init__(self, input_df, work_dir, script_dir, seq_id_cutoff=None, download_type="pdb", logging=True):
        self.work_dir = work_dir # Storage of seq identity useful later for temp selection.
        self.script_dir = script_dir #here we store all scripts
        self.seq_id_cutoff = seq_id_cutoff
        self.download_type = download_type # Download PDB or also mmCIF (currently only PDB)
        self.input_df = input_df
        self.pdbs_to_download = input_df.loc[:, "Target_id"]
        self.seq_id = input_df.loc[:, "Seq_identity"]
        # The bash script location which will download the pdbs. 
        self.download_script = os.path.join(script_dir, "batch_download_modified.sh") #modify for script location
        self.download_tmp = os.path.join(work_dir, "pdb_list.csv") # The location for the temporary file that is required for the download_script as input.
        self.log_dir = os.path.join(work_dir, "log_files")
        # The list of chains that will be used later to fetch correct structures.
        self.chain_seqid_dict = self._setup_download_list()
        self.temp_seqid_dict = {template: seq_id for template, seq_id in zip(self.pdbs_to_download, self.seq_id)}
        # We store also meta info as a json dict
        self.meta_dict = None
        #we store high resolution structures as a list if the user wants to separate based on resolution.
        self.high_resolution = None
        # set a flag that stops redownloading.
        self.already_downloaded = None
        # collect conservation
        self.conservation_df = None
        #filtered structures based on meta resolution
        self.filtered_structures = None
        #store shifts.
        self.shift_dict = None
        self.logging = logging # for report purpose.
            
    def paralellized_download(self):
        '''
        This function is going to call _download_files n times to parallelize download. 
        It is going to pass the function call itself **_download_file**,
        self.download_tmp (the location of the tmp file which is pdb_id comma separated), 
        p (an additional parameter specifying that 
        we want to download pdbs, and self.work_dir(the current work dir)
        '''

        self.already_downloaded = self._check_for_pdbs_present()
        # ThreadPoolExecutor
        print(f"{self.already_downloaded=}")
        if self.already_downloaded == False:
            print("we start downloading now:")

            # now for debugging.

            return 0
            
            with ThreadPoolExecutor() as executor:
                # Submit your tasks to the executor.
                futures_pdb = [executor.submit(self._download_files, self.download_tmp, 'p', self.work_dir)]
                # Optionally, you can use as_completed to wait for and retrieve completed results.
                for future in as_completed(futures_pdb):
                    result = future.result()
            self.already_downloaded = True
        else:
            print("we already have pdbs from the templates downloaded")
    
    def _setup_download_list(self):
        '''Helper function to setup the list of comma-separated pdb
        ids for the download_files function'''
        
        
        if not self.input_df.empty:
            pdbs = self.input_df.loc[:, "Target_id"]
            seq_ids = self.input_df.loc[:, "Seq_identity"]
        else:
            # we cant proceed
            return

        #initialize dict
        chain_seqid_dict = defaultdict(list)

        self.pdbs_to_download = [] # overwrite to set it blank for seq_id filtering.

        original_pdbs = len(set([x[0:4] for x in pdbs])) #for logging purpose. tells us how many pdbs originally were there before cutoff
        for pdb, seq_id in zip(pdbs, seq_ids):
            if float(seq_id) > float(self.seq_id_cutoff):
                pdb_4_digit_id = pdb[:4] # e.g 4CFR
                chain = pdb[-1] # e.g A
                chain_seqid_dict[pdb_4_digit_id].append((chain, seq_id)) #map chains and seq id to pdb id
                # We only want to download pdb files once. 
                # No reason to download a PDB-file 4 times just because we need chain [A, B, C, D]
                self.pdbs_to_download.append(pdb_4_digit_id) # we store it for a later check 
                
        unique_pdbs = chain_seqid_dict.keys() # Keys : PDB-IDs, Vals: Chains, seq_id
        # Create download_files input list
        
        if unique_pdbs:
            with open(self.download_tmp, "w") as pdb_tar:
                pdb_tar.write(",".join(unique_pdbs))
            
            #return dict {key: pdb_id = [(chain, seq_id)]}
            print(f"Before applying cutoff: {original_pdbs} Structures\nAfter applying cutoff: {len(unique_pdbs)} Structures")
            return chain_seqid_dict
        else:
            print(f"No structures available for cutoff {self.seq_id_cutoff}. Try lowering cutoff.")

    
    def _download_files(self, download_tmp, download_type, path)->list:
        """This helper function runs inside paralellized_download 
        and will be used to get the PDB files that we require for downstream analysis."""
        results = []
        # Input for subprocess
        bash_curl_cmd = f"{self.download_script} -f {download_tmp} -o {path} -{download_type}"
        # split into list 
        bash_curl_cmd_rdy = bash_curl_cmd.split()
        
        try:
            # Run subprocess
            result = run(bash_curl_cmd_rdy, stdout=PIPE, stderr=PIPE, universal_newlines=True)
            # Append result output to results
            results.append(result.stdout.split("\n")[:-1])  # Skip the last empty element
        except Exception as e:
            results.append(f"Error downloading: {e}")

        return results    

    def _check_for_pdbs_present(self):
        '''
        Could be good to improve so that if we miss SOME structures we fetch them as well and download ONLY those.
        For those structures we also need seqid per chain and then also update the seqid_chain dict for the whole directory after
        successful download.
        Currently we only check if pdbs are present and if yes we dont download anything further.
        '''
        pdbs_to_retrieve = {f[:4] for f in os.listdir(self.work_dir) if f.endswith(".pdb")}  # Use a set for efficient lookups
        template_codes = {f[:4] for f in self.pdbs_to_download}  # Convert list to set for efficient intersection operation

        print(f"{pdbs_to_retrieve=}, {template_codes=}")
        # Check for any overlap between the two sets
        overlap = pdbs_to_retrieve.intersection(template_codes)
        
        print(f"This is overlap in the directory: {overlap}")
        # Return 1 if there is an overlap, else 0
        return True if overlap else False

    
    def retrieve_meta(self, dict_location=None, human_readable=True)->dict:
        '''
        We also want to store meta information about resolution etc.
        This function takes each pdb file and retrieves the following information:
        - Title
        - Keywords
        - PDBcode
        - Authors
        - Deposition date
        - Technique
        - Resolution
        - R_value : If crystallography else None
        - R_free : If crystallographe else None
        - Classification
        - Organism
        - Expression System
        - Number of amino acids in the asymmetric unit
        - Mass of amino acids in the asymmetric unit (Da)
        - Number of amino acids in the biological unit
        - Mass of amino acids in the biological unit (Da)
        '''
        
        json_file_path = os.path.join(self.log_dir, 'meta_dictionary.json')

        for path in [json_file_path, dict_location]: #check first the supposed location alternatively the user supplied location.
            if path and os.path.exists(path):
                try:
                    with open(path, 'r') as json_fh:
                        self.meta_dict = json.load(json_fh)
                except Exception as e:
                    print(f"Error reading {path}: {e}")
                
        #little helper function to deal with date data
        def _date_encoder(obj):
            if isinstance(obj, date):
                return obj.isoformat()  # Convert date to ISO format

        #grab all PDB files which contain the meta information.
        pdbs_to_retrieve = [f for f in os.listdir(self.work_dir) if f.endswith(".pdb")]
        #here we store info about ALL pdbs.
        meta_dictionary = dict()
        
        for pdbs in pdbs_to_retrieve:
            if len(pdbs) == 8: #lets exclude preprocessed pdbs that are longer or shorter.
                sub_dict = dict()
                pdb_code = pdbs[:4]
                try:
                    fullp = os.path.join(self.work_dir, pdbs)
                    pdb = atomium.open(fullp)
                    sub_dict["title"] = pdb.title
                    sub_dict["key_words"] = pdb.keywords
                    sub_dict["code"] = pdb.code
                    sub_dict["authors"] = pdb.authors
                    #sub_dict["deposition_date"] = pdb.deposition_date.isoformat()  #isoformat because it is a time object
                    sub_dict["technique"] = pdb.technique
                    sub_dict["resolution"] = pdb.resolution
                    sub_dict["r_val"] = pdb.rvalue
                    sub_dict["r_free"] = pdb.rfree
                    sub_dict["classification"] = pdb.classification
                    sub_dict["organism"] = pdb.source_organism
                    sub_dict["expression_system"] = pdb.expression_system
                    sub_dict['number_of_residues_asymmetric_unit'] = len(pdb.model.residues())
                    sub_dict['mass_dalton_asymetric_unit'] = f"{pdb.model.mass:.2f}" 
                    try:
                        assembly = pdb.generate_assembly(1) #build the biological assembly 
                        sub_dict['number_of_residues_biological_unit'] = len(assembly.residues())
                        sub_dict['mass_dalton_biological_unit'] = f"{assembly.mass:.2f}"
                    except Exception as e:
                        print(f"We could not build the assembly for: {pdb_code}")
    
                except Exception as e:
                    print(f"We had an error with file: {pdb_code}")
                # store meta info and return
                meta_dictionary[pdb_code] = sub_dict


        #lets store meta info as json dict
        self.meta_dict = meta_dictionary
        
        # Code block to store meta info as a txt file.
        self._save_meta_dict(self.meta_dict, human_readable=human_readable)


    def _save_meta_dict(self, meta_dictionary, human_readable=True):
        '''Helper function to store meta info as a txt file.'''
        #check if log file dir exists, else make it.
        
        if self.log_dir and not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        #lets store the dict in json to read it in for later useage.
        json_file_path = os.path.join(self.log_dir, 'meta_dictionary.json')
        #convert defaultdict to normal dict.
        
        with open(json_file_path, 'w') as json_fh:
            json.dump(meta_dictionary, json_fh, indent=4, default=str)  # Use default=str to handle non-serializable objects

    
    def conservation(self, uniprot_id):
        '''Compute conservation scores for a UniProt entry and store them.

        The multiple sequence alignment is produced by _mmseq_multi_fasta, scored by
        _get_conservation, kept in self.conservation_df and written to
        <log_dir>/conservation_df.csv. The log directory is created when missing.

        According to the original notes, 3 different types of conservation are obtained:
        - Shannon conservation:
        Shannon entropy.
        Higher values indicate lower conservation and greater variability at the site.

        - Relative conservation:
        Kullback-Leibler divergence.
        Higher values indicate greater conservation and lower variability at the site.

        - Lockless conservation
        Evolutionary conservation parameter defined by Lockless and Ranganathan (1999).
        Higher values indicate greater conservation and lower variability at the site.
        '''
        log_dir_missing = bool(self.log_dir) and not os.path.exists(self.log_dir)
        if log_dir_missing:
            os.makedirs(self.log_dir)

        msa_path = self._mmseq_multi_fasta(uniprot_id=uniprot_id, outdir=self.work_dir)
        # the conservation scores, returned as one table
        self.conservation_df = self._get_conservation(path_to_msa=msa_path)

        self.conservation_df.to_csv(f"{self.log_dir}/conservation_df.csv")
        
    def _mmseq_multi_fasta(self, uniprot_id:str, outdir:str, 
                      sensitivity=7, filter_msa=0,
                     query_id = 0.6):
        """
        Build a multiple sequence alignment for a UniProt entry with mmseqs.

        uniprot_id: The unique uniprot identifier used to fetch the corresponding fasta file that will be used as a template for mmseq2
        outdir: location where result files will be stored. NOTE(review): currently not
                evaluated - every file is placed in self.work_dir; kept for interface compatibility.
        sensitivity: mmseq2 specific parameter that goes from 1-7. The higher the more sensitive the search.
        filter_msa = 0 default. if 1 hits are stricter.
        query_id = 0.6 [0, 1]  the higher the more identity with query is retrieved. 1 means ONLY the query hits while 0 means take everything possible.

        Steps: write the query fasta, ``mmseqs createdb``, ``mmseqs search`` against
        <work_dir>/swiss_DB, ``mmseqs result2msa``, strip the first four lines and the
        last line with sed, then copy the result to <work_dir>/<uniprot_id>.fasta.

        Returns:
            Path of the copied alignment, or None when an exception occurred on the way
            (the exception is printed, not raised).
        """
        # we blast with this fasta as query.
        trgt_fasta_seq = self._get_gene_fasta(uniprot_id)
        # we need to write it out to file.
        query_fasta = f"{self.work_dir}/{uniprot_id}_fasta.fa"
        with open(query_fasta, "w") as fasta_out:
            fasta_out.write(f">{uniprot_id}\n")
            fasta_out.write(trgt_fasta_seq)

        new_location = None
        try:
            # The original code read an undefined name `work_dir` here, which raised a
            # NameError that the except clause below swallowed.
            DB_storage_location = f"{self.work_dir}"
            query_db = f"{DB_storage_location}/query_fastaDB"
            target_db = f"{DB_storage_location}/swiss_DB"
            result_db = f"{DB_storage_location}/result_DB"
            raw_msa = f"{DB_storage_location}/{uniprot_id}_out.fasta"

            # Argument vectors are spelled out so that paths with whitespace survive.
            mmseqs_cmds = [
                # setup the query database based on our input fasta file
                ["mmseqs", "createdb", query_fasta, query_db],
                # search against swiss_DB and generate the result DB (i.e our hits that were found)
                ["mmseqs", "search", query_db, target_db, result_db,
                 f"{DB_storage_location}/tmp", "-s", str(sensitivity)],
                # convert the hits into an MSA
                ["mmseqs", "result2msa", query_db, target_db, result_db, raw_msa,
                 "--msa-format-mode", "3", "--filter-msa", str(filter_msa), "--qid", str(query_id)],
            ]
            for cmd in mmseqs_cmds:
                completed = run(cmd, stdout=PIPE, stderr=PIPE, universal_newlines=True)
                if completed.returncode != 0:
                    # Keep going (best effort, as before) but make the failure visible.
                    print(f"Command failed ({completed.returncode}): {' '.join(cmd)}\n{completed.stderr}")

            # delete the first four lines and the last line.. required.
            msa_file = f"{DB_storage_location}/{uniprot_id}_new_out.fasta"
            with open(msa_file, "w") as new_fasta:
                run(["sed", "-e", "1,4d", "-e", "$d", raw_msa],
                    stdout=new_fasta, stderr=PIPE, universal_newlines=True)

            # transfer the msa file to its final location.
            new_location = f"{self.work_dir}/{uniprot_id}.fasta"
            shutil.copy(msa_file, new_location)

        except Exception as error:
            print(error)
        # we want the path to msa_file for downstream analysis.
        return new_location

    def _get_gene_fasta(self, uniprot_id:str):
        '''
        Helper function to grab the sequence
        based on the Uniprot ID.

        The UniProt search endpoint is queried in fasta format through _get_url and
        the sequence lines of the FIRST record are joined into one string. A search
        can return several records; their headers and sequences must not be glued
        onto the first sequence, so reading stops at the second header line.

        Returns:
            The sequence as a single string ("" when the response holds no fasta record).
        '''
        fields = "sequence"
        URL = f"https://rest.uniprot.org/uniprotkb/search?format=fasta&fields={fields}&query={uniprot_id}"
        resp = self._get_url(URL)

        seq_parts = []
        header_seen = False
        for line in resp.iter_lines(decode_unicode=True):
            if isinstance(line, bytes):
                # iter_lines yields bytes when the response declares no encoding
                line = line.decode("utf-8")
            if line.startswith(">"):
                if header_seen:
                    break  # second record starts here
                header_seen = True
                continue
            if header_seen:
                seq_parts.append(line)
        return "".join(seq_parts)

    def _get_conservation(self, path_to_msa:str):
        '''
        Helper function that runs Canal on an alignment and returns the result of
        ``analysis(method="all")``.

        According to the original notes this yields 3 different types of conservation:

        - Shannon conservation:
        Shannon entropy.
        Higher values indicate lower conservation and greater variability at the site.

        - Relative conservation:
        Kullback-Leibler divergence.
        Higher values indicate greater conservation and lower variability at the site.

        - Lockless conservation
        Evolutionary conservation parameter defined by Lockless and Ranganathan (1999).
        Higher values indicate greater conservation and lower variability at the site.
        '''
        canal_settings = {
            "fastafile": path_to_msa,  # Multiple sequence alignment (MSA) of homologous sequences
            "ref": 0,  # Position of reference sequence in MSA, use first sequence always
            "startcount": 0,  # always 0 here (original note: "our seqs are always from 1 - end")
            "verbose": False,  # no verbosity
        }
        return Canal(**canal_settings).analysis(method="all")

    def _get_url(self, url):
        '''Helper function that uses requests for Downloads.

        Args:
            url: address to fetch with a GET request.

        Returns:
            The response object. A response with an error status is still returned;
            its body is printed so the problem is visible.

        Raises:
            requests.RequestException: when the request itself fails (connection
            problem, timeout, ...). The former bare ``except`` accessed the unbound
            ``response`` in that situation and hid the real error behind an
            UnboundLocalError.
        '''
        try:
            # A timeout keeps a stalled server from blocking the caller forever.
            response = requests.get(url, timeout=60)
        except requests.RequestException as error:
            print(f"Request to {url} failed: {error}")
            raise
        if not response.ok:
            print(response.text)
        return response
    
    def setup_cutoff(self, cutoff=10, apply_filter=False):
        '''If we want to setup a resolution cutoff filter for further downstream analysis, 
        this function helps with it.'''
        # If there is no meta dict we cant proceed and filter based on resolution.
        if self.meta_dict:
            #here we store the pdb codes that we keep
            pdbs_to_keep = []
            #Now lets parse through the whole meta dict and fetch the cutoffs for structures.
            for _, single_pdbs in self.meta_dict.items():
                try:
                    if single_pdbs['resolution'] <= cutoff:
                        pdbs_to_keep.append(single_pdbs['code'].lower()) #normalize to lower in order to have uniform list members.   
                except:
                    # 'technique': 'SOLUTION NMR' check for that.
                    print(f"we allow for now {single_pdbs=} because no resolution! check if NMR")
                    pass

            
            self.filtered_structures = pdbs_to_keep
            #now if we directly want to apply the filter to remove files that dont match our criteria.
            if apply_filter:
                #check for union between files and kept structures.
                pdbs_to_retrieve = [f[:4] for f in os.listdir(self.work_dir) if f.endswith(".pdb")]
                #lets fetch the intersect between the 2 sets which corresponds to the pdbs we want to keep.
                common_pdb = set(pdbs_to_retrieve) & set(pdbs_to_keep) #intersection
                intersect_lst = list(common_pdb)
                self.filtered_structures = intersect_lst
                if self.chain_seqid_dict:
                    #now we need to update the chain_dict as well:
                    filtered_dict = {pdb: v for pdb, v in self.chain_seqid_dict.items() if pdb[:4] in self.filtered_structures}
                    self.filtered_structures = filtered_dict
                    
        else:
            print("We have no meta dict to implement a cutoff")
            #In this case we take all.
            print(f"{self.chain_seqid_dict=}")
            # this also needs to take into account the seq id to be useful.
            pdbs_to_retrieve = [f[:4] for f in os.listdir(self.work_dir) if f.endswith(".pdb") and len(f) == 8] #exclude non original files. Only store pdb + _ + chains.
            
            self.filtered_structures = pdbs_to_retrieve

    def parallel_shift_calculation(self):
        '''Here we compute the shift according to uniprot or authors
        in order to be in line with UNIPROT numbering which is crucial for later renumbering.'''
        
        pdbs_to_retrieve = [f[0:4] for f in os.listdir(self.work_dir) if f.endswith(".pdb")]  
        pdbs_to_retrieve = set(pdbs_to_retrieve) & set(x[:4] for x in self.oligodict.keys()) #here we check the first 4 which is pdb code
        link_path = "https://www.ebi.ac.uk/pdbe/api/mappings/uniprot"
        shift_dict = defaultdict()
        
        with ThreadPoolExecutor() as executor:
            calculate_shift_bound = partial(self._calculate_shift)
            tasks = ((link_path, pdb) for pdb in pdbs_to_retrieve)
            # Map the bound function to the arguments in parallel
            results = executor.map(calculate_shift_bound, tasks)
            for result in results:
                for keys, vals in result.items():
                    shift_dict[keys] = vals
                    
        self.shifts = shift_dict

    def _calculate_shift(self, args):
        '''
        Helper function to compute the shift between UniProt and author numbering.

        Args:
            args: tuple (link_path, pdb) - the base url of the mapping service and a
                  string whose first four characters are used as PDB code.

        Returns:
            dict mapping "<pdb_id>_<chain_id>" to (unp_start - author start). When a
            chain has several mappings the last one wins. The same dict is also stored
            in self.shift_dict.
        '''
        link_path, pdb = args
        payload = self._get_url(f"{link_path}/{pdb[0:4]}").json()

        shift_dict = defaultdict()
        for pdb_id, pdb_info in payload.items():
            for uniprot_info in pdb_info['UniProt'].values():
                for mapping in uniprot_info['mappings']:
                    chain_id = mapping['chain_id']
                    unp_start = mapping['unp_start']
                    unp_end = mapping['unp_end']
                    author_start = mapping['start']['author_residue_number']
                    author_end = mapping['end']['author_residue_number']

                    # Missing author numbers fall back to the UniProt positions (shift 0)
                    author_start = unp_start if author_start is None else author_start
                    author_end = unp_end if author_end is None else author_end

                    shift_start = unp_start - author_start
                    shift_end = unp_end - author_end  # computed but not used further
                    shift_dict[f"{pdb_id}_{chain_id}"] = shift_start

        self.shift_dict = shift_dict
        return shift_dict

    
    def parallel_renumbering(self):
        '''
        Helper function to do parallelized renumbering.
        If already renumbered, don't do it again.
        '''
        if self.renumbered:
            print("You already renumbered your structures based on shift.")
            return  # Exit the function early

        if not self.shifts:
            print("You first need to obtain shifts which will be used as reference in order to start renumbering.\nCall .parallel_shift_calculation() first.")
            return  # Exit the function if no shifts are available

        # At this point, we know renumbering has not been done and shifts are available
        relevant_files = self.chain_seq_dict.keys()
        with ProcessPoolExecutor() as executor:
            # Using partial to create a function with fixed parameters (shift_dict, path)
            renumber_structure_partial = partial(self._renumber_structure, shift_dict=self.shifts, path=self.work_dir)
            # Map the renumbering function to each relevant file in parallel
            executor.map(renumber_structure_partial, relevant_files)

        self.renumbered = True

    
    def _renumber_structure(self, files, shift_dict, path):
        '''Function that is going to apply pdb_shiftres_by_chain.py to each pdb file that is shifted.
        Will apply renumbering to ALL structures if you did not set a cutoff previously and applied filter. 
        If filter applied for resolution will only renumber those structures that are left after filtering.'''
        for keys, vals in shift_dict.items():
            #dont renumber if there is not shift
            if files == keys[0:4] and vals != str(0):
                chain = keys[-1]
                shift = int(vals)
                filepath = f"{self.work_dir}/{files}.pdb"
                # Should we really shift by shift + 1??? or just shift?
                bash_cmd = f"python {self.script_dir}/pdb_shiftres_by_chain.py {filepath} {shift} {chain}"
                bash_cmd_rdy = bash_cmd.split()
            
                with open(f"{filepath}_tmp", "w") as fh_tmp:
                    result = run(bash_cmd_rdy, stdout=fh_tmp, stderr=PIPE, universal_newlines=True)
                    # Now replace the original one with the temp file.
                    os.replace(f"{filepath}_tmp", filepath)