# Mod enzyme discovery

In [1]:
import os
import sys
import re
import gzip
import json
import pandas as pd
import requests
from bs4 import BeautifulSoup

class COG_OBJECT():
    def __init__(self, cog_source_path):
        ''' initialize the files from NCBI cog ftp site'''
        self.cog_source_path = cog_source_path
        for i in os.listdir(cog_source_path):
            if i.startswith('cog-20'):
               # self.taxa = i.split('.')[0]+'.tax.csv' # this file doesn't have useful info
                self.cog = i.split('.')[0]+'.cog.csv'
                self.cog_prod_g = i.split('.')[0]+'.def.tab'
                self.cog_org = i.split('.')[0]+'.org.csv'
                self.cog_seq = i.split('.')[0]+'.fa.gz'

        self.known_files = [self.cog,
                            self.cog_prod_g, self.cog_org, self.cog_seq]

    def flatten_json(self, y):
        ''' flatten json to make it easier to work with '''
        out = {}

        def flatten(x, name=''):
            if type(x) is dict:
                for a in x:
                    flatten(x[a], name + a + '_')
            elif type(x) is list:
                i = 0
                for a in x:
                    flatten(a, name + str(i) + '_')
                    i += 1
            else:
                out[name[:-1]] = x
        flatten(y)
        self.flat_cog_json = out

    def build_cog_json(self, flatten=False, genomes=False):
        ''' organize each file and its contents into a json, if genomes is true, then reutrn a dict with genomes by clade, if flatten is true, then flatten the json '''

        self.cog_json = {}
        self.clade_assemblies = {}
        for i in self.known_files:
            self.cog_json[i] = {'features': []}

            # read in file type by its extension
            if i.endswith('.gz'):
                with gzip.open(self.cog_source_path+i, 'rt') as f:

                    # read in like a fasta file
                    for line in f:
                        if line.startswith('>'):
                            gene_id = line.split()[0][1:]

                            # if not a NoneType then add string in between [] to json
                            if re.search('\[.*\]', line):
                                taxid = re.search('\[(.*?)\]', line).group(1)

                        # while not a new gene_id, add sequence to gene_id
                        while not line.startswith('>'):
                            self.cog_json[i]['features'].append(
                                {'genbank_prot_ID': gene_id, 'taxid': taxid, 'sequence': line.strip()+line.strip()})
                            break

            if i.endswith('.csv'):
                with open(self.cog_source_path+i, 'r') as f:
                    for line in f:
                        line = line.strip().split(',')

                        # add cog, gene, genbank ID, length, and assembly to json
                        if i == 'cog-20.cog.csv':
                            self.cog_json[i]['features'].append({'cog':  line[6], 'gene_id': line[0], 'genbank_prot_ID': line[2].replace(
                                '.', '_'), 'gene_length': line[3], 'footprint_coords': line[12], 'assembly': line[1]})

                        # add assembly, taxid, and phylum to json
                        if i == 'cog-20.org.csv':
                            self.cog_json[i]['features'].append(
                                {'assembly': line[0], 'taxid': line[1], 'phylum': line[3]})

                            # add assembly to clade_assemblies dict
                            if line[3] not in self.clade_assemblies:
                                self.clade_assemblies[line[3]] = []
                            self.clade_assemblies[line[3]].append(line[0])

            if i.endswith('.tab'):
                with open(self.cog_source_path+i, 'r', encoding='cp1252') as f:
                    for line in f:
                        line = line.strip().split('\t')
                        
                        # add cog, cog type, and product to json
                        self.cog_json[i]['features'].append(
                            {'cog': line[0], 'group': line[1], 'product': ' '.join(line[2:])})
        if flatten:
            self.flatten_json(self.cog_json)

        if genomes:
            return self.clade_assemblies

    def write_json_to_file(self, flatten=True):
        ''' write the flat json to a file '''
        if flatten:
            self.cog_json = self.flat_cog_json
            f_name = 'cog-20-flat.json'
        else:
            f_name = 'cog-20.json'
        with open(self.cog_source_path+f_name, 'w') as f:
            json.dump(self.cog_json, f)  # indent=4)

    def json_to_hdf(self):
        ''' read the flat json file from directory and convert to hdf format'''
        with pd.HDFStore(self.cog_source_path+'cog.h5') as store:
            with open(self.cog_source_path+'cog-20-flat.json', 'r') as json_file:
                for i, line in enumerate(json_file):
                    try:
                        flat_data = self.flatten_json(json.loads(line))
                        df = pd.DataFrame.from_dict([flat_data])
                        store.append('observations', df)
                    except:
                        print('error on line {}'.format(i))
                        continue

class GENOME_BATCH():
    ''' Download the GenBank files of a batch of assemblies from the NCBI genomes FTP site.

    assembly_list maps a clade name to the list of assembly accessions in it
    (the dict returned by COG_OBJECT.build_cog_json(genomes=True)); files are stored
    under outpath/<clade>/<assembly>/.
    '''

    # seconds to wait on the server before giving up on a request
    REQUEST_TIMEOUT = 60

    def __init__(self, assembly_list, outpath):
        self.assembly_list = assembly_list
        self.outpath = outpath

    def _list_hrefs(self, url):
        ''' return the href of every link on the directory listing page at url '''
        r = requests.get(url, timeout=self.REQUEST_TIMEOUT)
        r.raise_for_status()
        soup = BeautifulSoup(r.text, 'html.parser')
        # anchors without an href attribute yield None and are dropped
        return [a.get('href') for a in soup.find_all('a') if a.get('href')]

    def get_NCBI_files(self, assembly, assembly_dir):
        ''' get NCBI files for assembly

        Downloads every .gz file of the assembly's GenBank (GCA) directory into
        assembly_dir, skipping files that already exist.  Network and file errors are
        reported and the assembly is skipped, so one bad assembly does not stop a batch.
        '''

        # split assembly up into thirds
        assembly_num = assembly.split('.')[0].split('_')[1]
        url = 'https://ftp.ncbi.nlm.nih.gov/genomes/all/GCA/{}/{}/{}/'.format(assembly_num[0:3], assembly_num[3:6], assembly_num[6:9])

        genbank_version = 'GCA_'+assembly.split('_')[1]

        try:
            # get the directory listing in url and find the child directory
            child_dir = next((href for href in self._list_hrefs(url)
                              if genbank_version in href), None)
            if child_dir is None:
                print('error on assembly {}'.format(assembly), 'it\'s probably becuase a child directory doesn\'t exist')
                return

            # download the content from the child directory
            for href in self._list_hrefs(url+child_dir):
                if not href.endswith('.gz'):
                    continue
                # keep downloads inside assembly_dir whatever the link looks like
                target = os.path.join(assembly_dir, os.path.basename(href))
                # check if file already exists
                if os.path.exists(target):
                    continue
                # fetch first, then write: a failed download must not leave an empty
                # file behind that would be skipped on the next run
                payload = requests.get(url+child_dir+href, timeout=self.REQUEST_TIMEOUT)
                payload.raise_for_status()
                with open(target, 'wb') as f:
                    f.write(payload.content)
        except (requests.RequestException, OSError) as err:
            print('error on assembly {}: {}'.format(assembly, err))

    def make_genome_dirctory(self):
        ''' for every assembly in assembly list make a directory in outpath and get NCBI files '''
        for clade in self.assembly_list:
            clade_dir = os.path.join(self.outpath, clade, '')
            os.makedirs(clade_dir, exist_ok=True)
            for assembly in self.assembly_list[clade]:
                assembly_dir = os.path.join(clade_dir, assembly, '')
                os.makedirs(assembly_dir, exist_ok=True)
                self.get_NCBI_files(assembly, assembly_dir)
                



# Entry point of the first notebook cell: parse the COG release files and keep the
# clade -> assemblies mapping for the (currently disabled) genome download.
if __name__ == '__main__':
    # directory holding the COG release files; COG_OBJECT joins file names onto this
    # string directly, so the trailing '/' is required
    cog_source_path = '/projects/lowelab/users/jsleavit/data/cogs/COG_ftp_files/'
    # genomes=True makes build_cog_json return self.clade_assemblies
    clade_assemblies = COG_OBJECT(cog_source_path).build_cog_json(genomes=True)
    
    # the string literal below is a no-op statement used to park the download call
    '''
    # GENOME_BATCH failed after 40 min but probably doesn't need to be executed again since we have the genomes of interest
    #GENOME_BATCH(clade_assemblies, '/projects/lowelab/users/jsleavit/data/cogs/cog_genomes/').make_genome_dirctory()
    '''
    # disabled: rebuild the json, write it to disk and convert it to hdf
    # cog = COG_OBJECT(cog_source_path)
    # cog.build_cog_json(flatten=True) # if want to convert to hdf, set flatten=True
    # cog.write_json_to_file(flatten=True)
    # cog.json_to_hdf() # doesn't appear to work...


In [8]:
from Bio.Align.Applications import MafftCommandline
#from StringIO import StringIO
from Bio import AlignIO
import os
import sys
import re
import gzip
import json
import pandas as pd

class COG_STATS():
    ''' Count COG members per genome and compare them with hmmsearch hits.

    Works on the nested json written by COG_OBJECT (keys 'cog-20.cog.csv',
    'cog-20.org.csv', 'cog-20.def.tab', 'cog-20.fa.gz', each holding a 'features' list)
    and on a directory of downloaded genomes laid out as <genbank_genome_path>/<clade>/<assembly>/.
    The methods communicate through attributes (self.cog, self.clade, self.cog_seqs,
    self.alignment_path, self.hmm_path, self.hmmsearch_path, ...), so they are meant to be
    called in pipeline order.
    '''

    def __init__(self, cog_json_path, genbank_genome_path):
        self.cog_json_path = cog_json_path
        self.genbank_genome_path = genbank_genome_path

    @staticmethod
    def _short_assembly(assembly):
        ''' numeric part of an assembly accession, e.g. 'GCA_000007185.1' -> '000007185' '''
        return assembly.split('.')[0].split('_')[1]

    @staticmethod
    def _run_tool(args, stdout=None):
        ''' run an external program without a shell

        Returns the exit code, or None when the program could not be started
        (for example because it is not installed).
        '''
        import subprocess
        try:
            return subprocess.run(args, stdout=stdout, check=False).returncode
        except OSError as err:
            print('could not run {}: {}'.format(args[0], err))
            return None

    def read_json(self):
        ''' brings in the cog json from the COG_OBJECT class '''
        with open(self.cog_json_path, 'r') as f:
            self.cog_json = json.load(f)

    def count_json(self, outpath):
        ''' count the total number of genes for a given cog in each genome

        Writes <outpath>cog-20_count.json with one entry per (phylum, species, cog, clan).
        The assembly and cog lookup tables are built once up front instead of being
        rebuilt for every gene.

        Raises KeyError when a gene refers to an assembly or cog that is not described
        in the org / def tables.
        '''
        phylum_of = {}   # short assembly -> phylum of its first org entry
        taxid_of = {}    # (phylum, short assembly) -> taxid
        for org in self.cog_json['cog-20.org.csv']['features']:
            try:
                short = self._short_assembly(org['assembly'])
            except IndexError:
                continue  # not an assembly accession (e.g. a header row)
            phylum_of.setdefault(short, org['phylum'])
            taxid_of[(org['phylum'], short)] = org['taxid']

        clan_of = {entry['cog']: entry['group']
                   for entry in self.cog_json['cog-20.def.tab']['features']}

        total_cog_counts = {}
        for gene in self.cog_json['cog-20.cog.csv']['features']:
            short = self._short_assembly(gene['assembly'])
            phylum = phylum_of.get(short)
            species = taxid_of[(phylum, short)]
            key = (phylum, species, gene['cog'], clan_of[gene['cog']])
            total_cog_counts[key] = total_cog_counts.get(key, 0) + 1

        # leave the lookup attributes populated, as the helper-based version did
        self.cog_clan = clan_of
        self.assembly_taxid = {short: taxid for (_, short), taxid in taxid_of.items()}

        cog_count_json = {'features': []}
        for (phylum, species, cog, clan), cog_count in total_cog_counts.items():
            cog_count_json['features'].append(
                {'species': species, 'phylum': phylum, 'cog': cog, 'clan': clan,
                 'gene_count': cog_count})

        with open(os.path.join(outpath, 'cog-20_count.json'), 'w') as f:
            json.dump(cog_count_json, f)

    def get_clade_species(self, clade):
        ''' helper function: store the 'taxid' value of every org entry of a clade in self.clade_cogs '''
        self.clade = clade
        self.clade_cogs = [org['taxid']
                           for org in self.cog_json['cog-20.org.csv']['features']
                           if org['phylum'] == clade]

    def get_cog_gene_ids(self, cog):
        ''' helper function get all the gene ids for a given cog '''
        self.cog = cog
        self.cog_gene_ids = [gene['genbank_prot_ID']
                             for gene in self.cog_json['cog-20.cog.csv']['features']
                             if gene['cog'] == cog]

    def get_cog_taxid_clan(self, cog, clade):
        ''' helper function get all the taxids by assembly for a cog and its clan

        Rebuilds self.assembly_taxid (short assembly -> taxid, for the clade only) and
        self.cog_clan ({cog: group}, empty when the cog is unknown).
        '''

        self.cog_clan = {}
        self.assembly_taxid = {}

        # cog-20.org.csv has the assembly info with taxid
        for i in self.cog_json['cog-20.org.csv']['features']:
            if i['phylum'] == clade:
                self.assembly_taxid[self._short_assembly(i['assembly'])] = i['taxid']

        # cog-20.def.tab has the clan info
        for i in self.cog_json['cog-20.def.tab']['features']:
            if i['cog'] == cog:
                self.cog_clan[cog] = i['group']

    def get_phylum(self, assembly):
        ''' helper function given the assembly get the respective phylum (None when unknown) '''
        wanted = self._short_assembly(assembly)
        for i in self.cog_json['cog-20.org.csv']['features']:
            if self._short_assembly(i['assembly']) == wanted:
                return i['phylum']

    def get_cog_seqs(self, clade=None, cog=None):
        ''' get the sequences for a given cog

        Keeps the fasta entries whose protein id belongs to the cog and whose organism
        name (first two words, joined by '_') occurs in one of the clade's org names.
        The result is stored in self.cog_seqs as {protein id: sequence}.
        '''

        self.get_cog_gene_ids(cog)
        self.get_clade_species(clade)

        # set membership instead of scanning the id list for every fasta entry
        wanted_ids = set(self.cog_gene_ids)

        cog_seqs = {}
        for feature in self.cog_json['cog-20.fa.gz']['features']:
            if feature['genbank_prot_ID'] not in wanted_ids:
                continue
            organism = feature['taxid']
            if not organism:
                continue  # header carried no [organism] annotation
            # check if the tax id is in any element of the clade list
            binomial = '_'.join(organism.split(' ')[0:2])
            if any(binomial in name for name in self.clade_cogs):
                cog_seqs[feature['genbank_prot_ID']] = feature['sequence']

        # print brief summary of sequences for the cog
        print('cog {} has {} sequences in clade {}'.format(cog, len(cog_seqs), self.clade))
        self.cog_seqs = cog_seqs

    def align_cog_seqs(self, outpath):
        ''' take in the dictionary of geneID and seq aka cog_seqs and align them

        Aligns self.cog_seqs with mafft and writes the alignment in stockholm format to
        <outpath>alignment_<cog>_<clade>.sto.

        Returns True when an alignment was written and False when there were no
        sequences to align.
        '''
        import io
        import tempfile

        # initialize alignment path for building hmm profile function later
        self.alignment_path = outpath

        # if outpath doesn't exist, create it
        if not os.path.exists(outpath):
            os.makedirs(outpath)

        if not self.cog_seqs:
            print('no sequences for {} in {}, nothing to align'.format(self.cog, self.clade))
            return False

        # a private temp file: a shared 'temp.fa' opened in append mode would pick up
        # leftovers from an earlier, interrupted run
        fd, fasta_path = tempfile.mkstemp(suffix='.fa')
        try:
            with os.fdopen(fd, 'w') as f:
                for gene_id, seq in self.cog_seqs.items():
                    f.write('>{}\n{}\n'.format(gene_id, seq))

            # align with mafft
            mafft_cline = MafftCommandline(input=fasta_path)
            stdout, stderr = mafft_cline()
        finally:
            os.remove(fasta_path)

        # mafft prints the aligned fasta on stdout; parse it straight from memory
        alignment = AlignIO.read(io.StringIO(stdout), 'fasta')
        AlignIO.write(alignment,
                      os.path.join(outpath, 'alignment_{}_{}.sto'.format(self.cog, self.clade)),
                      'stockholm')
        print('alignment for {} in {} written to {}'.format(self.cog, self.clade, outpath))
        return True

    def build_hmm_profile(self,outpath):
        ''' build hmm profile from stockholm alignment

        Runs hmmbuild on every .sto file in self.alignment_path; the profiles go to outpath
        and the hmmbuild output of all builds to <outpath>hmm_build_<cog>_<clade>.log.
        '''

        # initialize hmm profile path for hmmsearch function later
        self.hmm_path = outpath

        # if outpath doesn't exist, create it
        if not os.path.exists(outpath):
            os.makedirs(outpath)

        log_path = os.path.join(outpath, 'hmm_build_{}_{}.log'.format(self.cog, self.clade))
        with open(log_path, 'w') as log:
            for aln in sorted(os.listdir(self.alignment_path)):
                if not aln.endswith('.sto'):
                    continue

                # build hmm profile
                hmm_file = os.path.join(outpath, aln[:-len('.sto')] + '.hmm')
                status = self._run_tool(
                    ['hmmbuild', hmm_file, os.path.join(self.alignment_path, aln)],
                    stdout=log)
                if status != 0:
                    print('hmmbuild failed for {}'.format(aln))

    def remove_illegal_characters(self, genome_path):
        ''' remove '-' from the sequences because hmmsearch can't handle it

        Reads the gzipped fasta at genome_path and writes an uncompressed copy next to it
        (same name without the trailing '.gz').  Header lines are copied unchanged.
        Returns the path of the uncompressed copy.
        '''
        if genome_path.endswith('.gz'):
            clean_path = genome_path[:-len('.gz')]
        else:
            # never write over the file that is being read
            clean_path = genome_path + '.clean'

        with gzip.open(genome_path, 'rt') as f:
            # write back to file without the '-' character
            with open(clean_path, 'w') as g:
                for line in f:
                    if line.startswith('>'):
                        g.write(line)
                    else:
                        g.write(line.replace('-',''))
        return clean_path

    def search_genome_cds(self, outpath,retrieve_only=False):
        ''' search genome cds with hmm profile

        Runs hmmsearch with the profile of self.cog / self.clade (looked up in
        self.hmm_path) against the cds file of every genome of the clade and writes
        <outpath>results_<cog>_<clade>_<assembly>.hmmsearch.out.  Existing result files
        are not recomputed.

        retrieve_only: only record outpath in self.hmmsearch_path, run nothing.
        '''

        # initialize hmmsearch path for parsing stats later
        self.hmmsearch_path = outpath
        if retrieve_only:
            return

        # if outpath doesn't exist, create it
        if not os.path.exists(outpath):
            os.makedirs(outpath)

        # get all the hmm profiles: clade -> cog -> path
        # (profile files are named alignment_<cog>_<clade>.hmm)
        hmm_files = {}
        for hmm in os.listdir(self.hmm_path):
            if not hmm.endswith('.hmm'):
                continue
            parts = hmm[:-len('.hmm')].split('_', 2)
            if len(parts) < 3:
                continue
            hmm_files.setdefault(parts[2], {})[parts[1]] = os.path.join(self.hmm_path, hmm)

        profile = hmm_files.get(self.clade, {}).get(self.cog)
        if profile is None:
            print('no hmm profile for {} in {} found in {}'.format(self.cog, self.clade, self.hmm_path))
            return

        # get all the genome cds
        clade_dir = os.path.join(self.genbank_genome_path, self.clade)
        genomes = sorted(name for name in os.listdir(clade_dir)
                         if os.path.isdir(os.path.join(clade_dir, name)))
        genome_cds = {}
        for genome in genomes:
            genome_dir = os.path.join(clade_dir, genome)
            for file in os.listdir(genome_dir):
                if file.endswith('_cds.faa.gz'):
                    genome_cds[genome] = self.remove_illegal_characters(
                        os.path.join(genome_dir, file))

        n_genomes = []
        failed = []
        for genome in genomes:
            cds = genome_cds.get(genome)
            if cds is None:
                # fall back to the cds stored under the GCA_ spelling of the accession
                parts = genome.split('_')
                if len(parts) > 1:
                    cds = genome_cds.get('GCA_'+parts[1])
            if cds is None:
                print('no genome cds for {}'.format(genome))
                n_genomes.append(genome)
                continue

            result = os.path.join(outpath, 'results_{}_{}_{}.hmmsearch.out'.format(
                self.cog, self.clade, '_'.join(genome.split('.'))))

            # only search if results file doesn't exist
            if os.path.exists(result):
                continue

            status = self._run_tool(['hmmsearch', '-o', result, profile, cds])
            if status == 0:
                print('hmmsearch of {}  {}  in {} written to {}'.format(self.clade, self.cog, genome, outpath))
            else:
                failed.append(genome)
                # a partial result would be mistaken for a finished one on the next run
                if os.path.exists(result):
                    os.remove(result)

        print('{} genomes have no cds file'.format(len(n_genomes)), '{}'.format(n_genomes) if n_genomes != [] else '')
        print('{} hmmsearch failed'.format(len(failed)), '{}'.format(failed) if failed else '')

    def count_my_hmmsearch(self):
        ''' parses the hmm search file to get counts of cogs

        Reads every *.hmmsearch.out file in <self.hmmsearch_path>tblout_results/ (named
        results_<cog>_<clade>_<assembly>.hmmsearch.out) and counts its non-empty lines
        that do not start with '#'.  The counts are stored in self.hmmsearch_results.

        Raises KeyError when an assembly or cog taken from a file name is not in the cog json.
        '''

        results_dir = os.path.join(self.hmmsearch_path, 'tblout_results')

        # count number of hits per genome
        hmmsearch_results = {'features':[]}
        loaded = None  # (cog, clade) the lookup attributes were last built for
        for hmmsearch in sorted(os.listdir(results_dir)):
            if not hmmsearch.endswith('.hmmsearch.out'):
                continue
            parts = hmmsearch.split('.')[0].split('_')
            assembly = parts[-2] # might want this later to access the species name
            cog, clade = parts[1], parts[2]
            if loaded != (cog, clade):
                self.get_cog_taxid_clan(cog, clade)
                loaded = (cog, clade)

            with open(os.path.join(results_dir, hmmsearch), 'r') as f:
                # each entry is on a line without a # in front
                count = sum(1 for line in f if line.strip() and not line.startswith('#'))

            hmmsearch_results['features'].append(
                {'species': self.assembly_taxid[assembly], 'phylum': clade, 'cog': cog,
                 'clan': self.cog_clan[cog], 'gene_count': count})
        self.hmmsearch_results = hmmsearch_results

    def compare_cog_counts(self, cog_count_path, cog='COG0590'):
        ''' compare my hmmsearch results to the COG database

        cog_count_path: the json written by count_json
        cog:            the COG to compare (default 'COG0590')

        For every species whose hmmsearch count is at least the COG database count, two
        entries are recorded: the hmmsearch count labelled 'HOM<number>' and the database
        count labelled with the cog itself.  The result is written to
        <self.hmmsearch_path>comparison_hmmsearch_NCBI_COG_<cog>_no_zero.json and returned.
        '''
        # load the cog count json
        with open(cog_count_path, 'r') as f:
            cog_count = json.load(f)

        hom = 'HOM'+cog[3:] if cog.startswith('COG') else 'HOM_'+cog

        sub_cog_count = {i['species']: i['gene_count']
                         for i in cog_count['features'] if i['cog'] == cog}

        # add the cog counts to comparision dict
        comparison = {'features':[]}
        for i in self.hmmsearch_results['features']:
            if i['cog'] != cog or i['species'] not in sub_cog_count:
                continue
            if i['gene_count'] >= sub_cog_count[i['species']]:
                comparison['features'].append({'species': i['species'], 'phylum': i['phylum'], 'cog': hom, 'clan': i['clan'], 'gene_count': i['gene_count']})
                comparison['features'].append({'species': i['species'], 'phylum': i['phylum'], 'cog': cog, 'clan': i['clan'], 'gene_count': sub_cog_count[i['species']]})

        out_name = 'comparison_hmmsearch_NCBI_COG_{}_no_zero.json'.format(cog)
        with open(os.path.join(self.hmmsearch_path, out_name), 'w') as f:
            json.dump(comparison, f)
        return comparison
        


# Entry point of the second notebook cell: load the cog json and compare the counts of
# an already finished hmmsearch run with the COG database counts.
if __name__ == '__main__':
    # json written by COG_OBJECT.write_json_to_file(flatten=False)
    cog_object_path = '/projects/lowelab/users/jsleavit/data/cogs/COG_ftp_files/cog-20.json'
    # root of the downloaded genomes, one sub directory per clade
    genbank_genome_path = '/projects/lowelab/users/jsleavit/data/cogs/cog_genomes/'
    # json written by COG_STATS.count_json
    cog_counts_path = '/projects/lowelab/users/jsleavit/data/cogs/COG_ftp_files/cog-20_count.json'
    CS = COG_STATS(cog_object_path, genbank_genome_path)
    CS.read_json()
    #CS.count_json('/projects/lowelab/users/jsleavit/data/cogs/COG_ftp_files/') probably don't ever need to run this unless adding something to the json

    ''' run this chunk for doing initial hmmsearch, if hasn't been done before '''
    # CS.get_cog_seqs(clade='EURYARCHAEOTA', cog = 'COG0590')
    # CS.align_cog_seqs('/projects/lowelab/users/jsleavit/data/cogs/cog_alignments/')
    # CS.build_hmm_profile('/projects/lowelab/users/jsleavit/data/cogs/cog_hmms/')
    # CS.search_genome_cds('/projects/lowelab/users/jsleavit/data/cogs/cog_hmmsearch/')

    ''' run this chunk if the hmmsearch already has been done '''
    # retrieve_only=True only records the result directory on CS, nothing is searched
    CS.search_genome_cds('/projects/lowelab/users/jsleavit/data/cogs/cog_hmmsearch/', retrieve_only=True)
    CS.count_my_hmmsearch()
    CS.compare_cog_counts(cog_counts_path)
  


    ''' needs work for multiple clades '''
    # NOTE(review): the disabled loop below calls CS.get_cog_sequences, but the method
    # COG_STATS defines is get_cog_seqs; it also assigns `statue` and then tests `status`.
    # clades = ['EURYARCHAEOTA', 'THAUMARCHAEOTA', 'CRENARCHAEOTA', 'OTHER\ ARCHAEA']
    # #clades = [ 'CRENARCHAEOTA', 'OTHER\ ARCHAEA']
    # statue = True
    # for clade in clades:
    #     CS.get_cog_sequences(clade=clade, cog = 'COG0590')
    #     #CS.get_cog_sequences(clade='EURYARCHAEOTA', cog='COG0590')
    #     status=CS.align_cog_seqs(outpath='/projects/lowelab/users/jsleavit/data/cogs/cog_alignments/')
    #     if status:
    #         CS.build_hmm_profile(outpath='/projects/lowelab/users/jsleavit/data/cogs/cog_hmms/')
    #         CS.search_genome_cds(outpath='/projects/lowelab/users/jsleavit/data/cogs/cog_hmmsearch/')
    #     else:
    #         print(clade+ 'doesn\'t have seed for alignment and hmmsearch')

In [None]:
from ete3 import Tree


''' Collect the top hits from each branch  and use hmmalign to align them '''
class FastAreader():
    ''' Read fasta records from a file, or from stdin when no file name is given. '''

    def __init__(self, fname=''):
        ''' contructor: saves attribute fname '''
        self.fname = fname

    def doOpen(self):
        ''' return the handle to read from: sys.stdin when fname is '', else the opened file '''
        if self.fname == '':
            return sys.stdin
        return open(self.fname)

    def readFasta(self):
        ''' yield a (header, sequence) tuple for every record

        header is the header line without the leading '>' and trailing whitespace;
        sequence is the upper-cased sequence with all whitespace removed.  Anything in
        front of the first header line is ignored, and input without any header line
        (including empty input) yields nothing.
        '''
        header = None
        chunks = []

        with self.doOpen() as fileH:
            for line in fileH:
                if line.startswith('>'):
                    # a new record starts: hand out the one collected so far
                    if header is not None:
                        yield header, ''.join(chunks)
                    header = line[1:].rstrip()
                    chunks = []
                elif header is not None:
                    chunks.append(''.join(line.rstrip().split()).upper())

        if header is not None:
            yield header, ''.join(chunks)

class HMM_STATS():
    ''' Collect the top hmmsearch hit of every genome and align the hits with hmmalign.

    The methods communicate through attributes and are meant to be called in this order:
    collect_top_hits -> collect_sequences_from_genome -> align_sequences ->
    collect_tree_genera -> subset_top_hits_by_genera -> hmm_align_genera.
    '''

    # suffix of the concatenated fasta files this class writes next to the per-gene files
    _COMBINED_SUFFIX = '_top_hits.faa'

    def __init__(self, hmmsearch_path, hmm_path, hmmalign_path, cog_genome_path,clade,tree_path):
        ''' store the paths and load the lookup tables and the tree

        cog_genome_path must hold assembly_to_species.json and
        species_taxonomic_assignment.json next to the clade directories.
        '''
        self.hmmsearch_path = hmmsearch_path
        self.hmm_path = hmm_path
        self.hmmalign_path = hmmalign_path
        self.cog_genome_path = cog_genome_path+clade+'/'

        self.clade = clade

        # read in assembly to genome dictionary
        with open(os.path.join(cog_genome_path, 'assembly_to_species.json'), 'r') as f:
            self.assembly_to_genome = json.load(f)

        # read in taxonomic assignment dictionary
        with open(os.path.join(cog_genome_path, 'species_taxonomic_assignment.json'), 'r') as f:
            self.taxa_assignment = json.load(f)

        # read in tree
        self.tree = Tree(tree_path)

    @classmethod
    def _concatenate_fastas(cls, directory, combined_path):
        ''' write all per-gene .faa files of directory into combined_path

        Files ending in '_top_hits.faa' are combined files written by earlier calls
        (including combined_path itself) and are left out, otherwise the output would be
        read back into itself.  Returns the names of the files that were concatenated.
        '''
        members = sorted(name for name in os.listdir(directory)
                         if name.endswith('.faa') and not name.endswith(cls._COMBINED_SUFFIX))
        with open(combined_path, 'w') as combined:
            for name in members:
                with open(os.path.join(directory, name), 'r') as single:
                    for line in single:
                        combined.write(line)
        return members

    @staticmethod
    def _run_hmmalign(hmm_file, fasta_file, sto_file):
        ''' align fasta_file to hmm_file with hmmalign; True when hmmalign exited with 0 '''
        import subprocess
        command = ['hmmalign', '--trim', '--amino', '--informat', 'FASTA',
                   '-o', sto_file, hmm_file, fasta_file]
        try:
            return subprocess.run(command, check=False).returncode == 0
        except OSError as err:
            print('could not run hmmalign: {}'.format(err))
            return False

    def _clade_profiles(self, clade):
        ''' paths of the .hmm files in self.hmm_path whose name ends in _<clade>.hmm '''
        return [os.path.join(self.hmm_path, hmm)
                for hmm in sorted(os.listdir(self.hmm_path))
                if hmm.endswith('.hmm') and hmm.split('.')[0].split('_')[-1] == clade]

    def collect_top_hits(self):
        ''' collect the gene name of the top hit from each branch

        For every *.hmmsearch.out file of the clade the first token of the first line
        that is neither empty nor a '#' comment is stored in self.top_hit_genes, keyed by
        the assembly number taken from the file name.
        '''
        top_hit_genes = {}
        for hmm in sorted(os.listdir(self.hmmsearch_path)):
            if self.clade not in hmm or not hmm.endswith('.hmmsearch.out'):
                continue
            parts = hmm.split('.')[0].split('_')
            if len(parts) < 2:
                continue  # name does not carry an assembly number
            assembly = parts[-2]
            with open(os.path.join(self.hmmsearch_path, hmm), 'r') as f:
                for line in f:
                    # if line doesn't start with #, then it is a hit
                    if line.startswith('#') or not line.strip():
                        continue
                    top_hit_genes.setdefault(assembly, []).append(line.split()[0])
                    break  # only the first hit of the file is wanted
        self.top_hit_genes = top_hit_genes

    def collect_sequences_from_genome(self):
        ''' use the top hit gene dict to extract the genes

        Every top hit found in an assembly's *_cds.faa file is written to
        <hmmalign_path><species>_<locus tag>.faa.  The protein id is used when the header
        does not carry exactly one locus tag, and the gene id when it carries neither.
        '''

        # if hmm_align_path doesn't exist, make it
        os.makedirs(self.hmmalign_path, exist_ok=True)

        for assembly in sorted(os.listdir(self.cog_genome_path)):
            parts = assembly.split('.')[0].split('_')
            if len(parts) < 2:
                continue  # not an assembly directory
            short_name_ass = parts[1]
            wanted = self.top_hit_genes.get(short_name_ass)
            if not wanted:
                continue
            species = self.assembly_to_genome.get(short_name_ass)
            if species is None:
                print('no species name for assembly {}, skipping'.format(assembly))
                continue
            assembly_dir = os.path.join(self.cog_genome_path, assembly)
            if not os.path.isdir(assembly_dir):
                continue

            for file in os.listdir(assembly_dir):
                if not file.endswith('_cds.faa'):
                    continue
                FA = FastAreader(os.path.join(assembly_dir, file))
                for head, seq in FA.readFasta():
                    fields = head.split()
                    if not fields or fields[0] not in wanted:
                        continue
                    # get the locus tag from head
                    locus_tag = [x.split('=')[1].strip('[').strip(']')
                                 for x in fields if 'locus_tag' in x and '=' in x]
                    protein_id = [x.split('=')[1].strip('[').strip(']')
                                  for x in fields if 'protein_id' in x and '=' in x]
                    if len(locus_tag) == 1:
                        label = locus_tag[0]
                    elif protein_id:
                        label = protein_id[0]
                    else:
                        label = fields[0]
                    name = species+'_'+label
                    with open(os.path.join(self.hmmalign_path, name+'.faa'), 'w') as f:
                        f.write('>'+name+'\n'+seq+'\n')

    def align_sequences(self, phylum):
        ''' use the hmmalign to align the sequences

        Concatenates the per-gene fasta files of hmmalign_path into <phylum>_top_hits.faa
        and aligns that file to every profile of the phylum, writing <phylum>_top_hits.sto.
        '''

        # if hmm_align_path doesn't exist, make it
        os.makedirs(self.hmmalign_path, exist_ok=True)

        # concatenate all the sequences into one file
        combined = os.path.join(self.hmmalign_path, phylum+self._COMBINED_SUFFIX)
        self._concatenate_fastas(self.hmmalign_path, combined)

        # only the hmm files for the phylum we are looking at
        for hmm_file in self._clade_profiles(phylum):
            sto_file = os.path.join(self.hmmalign_path, phylum+'_top_hits.sto')
            hmm = os.path.basename(hmm_file)
            if self._run_hmmalign(hmm_file, combined, sto_file):
                print('hmm profile: '+hmm+' aligned with '+phylum+' sequences')
            else:
                print('hmmalign failed for hmm profile: '+hmm)

    def collect_tree_genera(self):
        ''' collect the genera for the r_tree

        Maps the species of every tree leaf (first two '_' separated words of the leaf
        name) to its genus, taken from the third-last entry of its taxonomic assignment.
        A leaf without genus inherits the genus of the previous leaf; it is left out when
        no previous genus is known yet.
        '''

        species2genus = {}
        prev_genus = None
        for i in self.tree.iter_leaves():
            leaf_name = ' '.join(i.name.split('_')[0:2])
            if leaf_name not in self.taxa_assignment:
                continue
            genus = self.taxa_assignment[leaf_name][-3]

            if genus is None:
                # use previous iteration genus if genus is None ( this is safe because we are iterating through the tree in topographical order of relationships)
                genus = prev_genus
            if genus is None:
                continue

            species2genus[leaf_name] = genus
            prev_genus = genus

        self.species2genus = species2genus

    def subset_top_hits_by_genera(self):
        ''' use the species2genus dict to pull species by genera from the fasta file

        Every record of <clade>_top_hits.faa whose species has a genus is written to
        <hmmalign_path><genus>/<header>.faa; the genera seen are stored in self.genera_found.
        '''

        # read in the fasta file
        genera_found = set()
        FA = FastAreader(os.path.join(self.hmmalign_path, self.clade+self._COMBINED_SUFFIX))
        for head, seq in FA.readFasta():
            species = ' '.join(head.split('_')[0:2])
            if species not in self.species2genus:
                continue
            genus = self.species2genus[species]
            genera_found.add(genus)
            # open a directory for the genus if it doesn't exist and write a new fasta file for the genus
            genus_dir = os.path.join(self.hmmalign_path, genus)
            os.makedirs(genus_dir, exist_ok=True)
            with open(os.path.join(genus_dir, head+'.faa'), 'w') as f:
                f.write('>'+head+'\n'+seq+'\n')

        self.genera_found = genera_found

    def hmm_align_genera(self):
        ''' if there is more than one sequence use the hmm align with the respective phylum

        Only the per-gene fasta files count as sequences; files written by an earlier
        call (<genus>_top_hits.faa / .sto) are ignored, so the method can be re-run.
        '''

        for genus in sorted(self.genera_found):
            genus_dir = os.path.join(self.hmmalign_path, genus)
            members = [name for name in os.listdir(genus_dir)
                       if name.endswith('.faa') and not name.endswith(self._COMBINED_SUFFIX)]
            # if there is more than one sequence in the genus
            if len(members) <= 1:
                continue

            # concatenate all the sequences into one file
            combined = os.path.join(genus_dir, genus+self._COMBINED_SUFFIX)
            self._concatenate_fastas(genus_dir, combined)

            # (this is still using the phylum level profile)
            for hmm_file in self._clade_profiles(self.clade):
                sto_file = os.path.join(genus_dir, genus+'_top_hits.sto')
                if not self._run_hmmalign(hmm_file, combined, sto_file):
                    print('hmmalign failed for genus '+genus)
    

# Entry point of the third notebook cell: gather the top hmmsearch hit of every
# EURYARCHAEOTA genome, align all hits, then align them again per genus.
if __name__ == '__main__':
    hmmsearch_path = '/projects/lowelab/users/jsleavit/data/cogs/cog_hmmsearch/tblout_results/' # has the hmmsearch results
    hmm_path = '/projects/lowelab/users/jsleavit/data/cogs/cog_hmms/' # has the hmm profiles for each cog
    hmmalign_path = '/projects/lowelab/users/jsleavit/data/cogs/cog_hmmalign/' # will have the hmmalign results
    cog_genome_path = '/projects/lowelab/users/jsleavit/data/cogs/cog_genomes/' # has the genomes for each species
    # newick tree passed to ete3.Tree by HMM_STATS.__init__
    tree_path = '/projects/lowelab/users/jsleavit/data/rna_db_out/archaea/a_16S_23S_combined_names_with_cogs.sina.aligned.fa.final_tree.nw'

    HS = HMM_STATS(hmmsearch_path, hmm_path, hmmalign_path, cog_genome_path, 'EURYARCHAEOTA', tree_path)


    ''' this chunk collects the top hits from hmmsearch results, finds the hit in the organisms CDS genome, and aligns them all together '''
    # each step reads attributes set by the previous one, so the order matters
    HS.collect_top_hits()
    HS.collect_sequences_from_genome()
    HS.align_sequences('EURYARCHAEOTA')


    ''' subset the top hits and  group specific genera based on their phylogeny and then align them individually '''

    # need to read in the tree
    # then categorize the species in the tree from the species_taxonomic_assignment.json in cog_genoms folder
    HS.collect_tree_genera()
    HS.subset_top_hits_by_genera()
    HS.hmm_align_genera()




    # NOTE(review): HMM_STATS defines no get_top_hits / align_top_hits methods; the two
    # disabled calls below look like leftovers from an earlier version.
    # HS.get_top_hits()
    # HS.align_top_hits()