In [1]:
import pandas as pd
import os
import re
import sys
import subprocess
from datetime import datetime
import json
import baltic as bt
import caffeine

# Directory holding local helper modules (fasta_editing). Override it with the
# FASTA_EDITING_DIR environment variable when running on another machine.
module_dir = os.environ.get("FASTA_EDITING_DIR", "/Users/monclalab1/Documents/scripts/")
# Guard so re-running this cell does not keep appending the same entry to sys.path.
if module_dir not in sys.path:
    sys.path.append(module_dir)

from fasta_editing import fasta_to_df, fasta_writer

In [2]:
#treetime needs a dates.csv file to root the tree
#the date column needs to be in decimal format

def create_date_csv(strain_date_map, output_dir):
    
    file_name = f"{output_dir}/strain_dates.csv"
    
    with open(file_name, 'w') as date_file:
        date_file.write("name, date\n")
        
        for strain, date_str in strain_date_map.items():

            date = datetime.strptime(date_str, '%Y-%m-%d')
            
            # convert to decimal year because thats what treetime uses
            dec_date = date.year + ((date.month - 1) * 30 + date.day) / 365.0
        
            date_file.write(f"{strain}, {dec_date:.2f}\n")
    
    return file_name


In [3]:
# Leaf labels in the input trees are expected to be influenza A strain names
# ("A/..."). The pattern grabs such a label up to the ':' that starts its branch
# length. NOTE(review): labels not starting with "A" are never relabelled.
_STRAIN_LABEL_PATTERN = re.compile(r'\bA[^:]+:')


def _append_dates_to_newick(nwk_text, strain_date_map):
    """Return ``nwk_text`` with each matched leaf label ``strain`` renamed to ``strain|date``.

    Labels matched by ``_STRAIN_LABEL_PATTERN`` but absent from ``strain_date_map``
    are left unchanged.
    """
    def replace_strain(match):
        strain = match.group(0)[:-1]  # drop the trailing ':' captured by the pattern
        if strain in strain_date_map:
            return f"{strain}|{strain_date_map[strain]}:"
        return match.group(0)

    return _STRAIN_LABEL_PATTERN.sub(replace_strain, nwk_text)


def _write_descriptor(descriptor_entries, descriptor_path):
    """Write TreeSort's descriptor CSV: one ``label,alignment,tree`` line per segment."""
    with open(descriptor_path, 'w') as descriptor_file:
        for row in descriptor_entries:
            descriptor_file.write(','.join(row) + '\n')


def treesort_prep(list_of_genes, home_folder, align_path, metadata_path,
                  newick_path, output_dir, ref_gene, reroot=False,
                  descriptor_path="descriptor.csv"):
    """Prepare per-segment alignments, metadata and trees for TreeSort and write its descriptor.

    For each gene:
      1. Read ``metadata_h3nx_{gene}.tsv`` and replace unknown date parts ("XX") with "01".
      2. Relabel alignment headers as ``strain|date`` and write the alignment to
         ``{output_dir}/alignments/``.
      3. Relabel the metadata ``strain`` column the same way and write it as CSV to
         ``{output_dir}/metadata/``.
      4. Append ``|date`` to the matching leaf labels of the Newick tree and write it to
         ``{output_dir}/div_tree/``.
      5. For ``ref_gene`` with ``reroot=True``, reroot that tree with ``treetime clock``.

    Parameters
    ----------
    list_of_genes : iterable of str
        Segment names used in the input/output file names (e.g. "ha", "na").
    home_folder, align_path, metadata_path, newick_path : str
        Inputs are read from ``{home_folder}/{align_path}/...`` and the analogous paths.
    output_dir : str
        Root of the prepared outputs; its sub-directories are created if missing.
    ref_gene : str
        Gene marked with a leading "*" in the descriptor (TreeSort's reference segment).
    reroot : bool, default False
        Reroot the reference gene's tree with TreeTime. The descriptor then lists
        ``{output_dir}/div_tree/rerooted.newick`` for that gene.
    descriptor_path : str, default "descriptor.csv"
        Where the descriptor CSV is written (relative to the current working directory).

    Raises
    ------
    ValueError
        If an alignment header has no matching strain, or no date, in the metadata.
    subprocess.CalledProcessError
        If ``treetime`` exits with an error.
    """
    align_dir = f"{output_dir}/alignments"
    metadata_dir = f"{output_dir}/metadata"
    tree_dir = f"{output_dir}/div_tree"
    # Create the output layout once, up front; exist_ok keeps re-runs idempotent.
    for path in (align_dir, metadata_dir, tree_dir):
        os.makedirs(path, exist_ok=True)

    descriptor_entries = []

    for gene in list_of_genes:
        align_file = f"{home_folder}/{align_path}/pruned_alignments_h3nx_{gene}.fasta"
        metadata_file = f"{home_folder}/{metadata_path}/metadata_h3nx_{gene}.tsv"
        nwk_file = f"{home_folder}/{newick_path}/h3nx_{gene}.nwk"
        output_align = f"{align_dir}/h3nx_{gene}.fasta"
        output_metadata = f"{metadata_dir}/h3nx_{gene}.csv"
        output_nwk = f"{tree_dir}/h3nx_{gene}.nwk"

        # metadata: unknown month/day ("XX") are set to "01"
        metadata = pd.read_csv(metadata_file, sep='\t')
        metadata["date"] = metadata["date"].str.replace('XX', '01')

        # alignment: headers are matched (without '>') against metadata strains
        align = fasta_to_df(align_file)
        align.header = align.header.str.replace(">", "")

        merged = pd.merge(align, metadata[['strain', 'date']],
                          left_on='header', right_on='strain', how='left')
        # Fail loudly (instead of a TypeError from joining NaN) when a sequence has no metadata.
        unmatched = merged.loc[merged['strain'].isna() | merged['date'].isna(), 'header']
        if not unmatched.empty:
            raise ValueError(
                f"{gene}: {len(unmatched)} alignment header(s) have no strain/date in "
                f"{metadata_file}: {', '.join(unmatched.astype(str).head(10))}"
            )
        labels = merged['strain'] + '|' + merged['date']
        merged["header"] = ">" + labels
        fasta_writer(f"{align_dir}/", f"h3nx_{gene}.fasta", merged)

        metadata["strain"] = metadata[['strain', 'date']].apply('|'.join, axis=1)
        metadata.to_csv(output_metadata, index=False)

        # strain -> date, used to relabel the tree leaves
        strain_date_map = dict(zip(merged['strain'], merged['date']))
        # "strain|date" -> date, keyed by the new tip labels (input for TreeTime)
        aln_date_map = dict(zip(labels, merged['date']))

        # tree: append "|date" to matching leaf labels
        with open(nwk_file, 'r') as file:
            nwk_data = file.read()
        with open(output_nwk, 'w') as file:
            file.write(_append_dates_to_newick(nwk_data, strain_date_map))

        is_ref = gene == ref_gene
        gene_label = f"*{gene.upper()}" if is_ref else gene.upper()
        descriptor_tree = output_nwk

        # rerooting of the reference tree with TreeTime, if requested
        if is_ref and reroot:
            dates_csv = create_date_csv(aln_date_map, output_dir)
            subprocess.run([
                "treetime", "clock",
                "--tree", output_nwk,
                "--dates", dates_csv,
                "--aln", output_align,
                "--outdir", tree_dir
            ], check=True)
            # `treetime clock` writes its rerooted tree as rerooted.newick in --outdir
            descriptor_tree = f"{tree_dir}/rerooted.newick"

        descriptor_entries.append([gene_label, output_align, descriptor_tree])

    _write_descriptor(descriptor_entries, descriptor_path)

In [4]:
def wrap_prep(list_of_genes, home_folder, align_path, metadata_path,
              newick_path, output_dir, ref_gene,
              reroot=False, treesort=True, mincut=False, n_replicates=5):
    """Prepare TreeSort inputs with ``treesort_prep``, then optionally run TreeSort repeatedly.

    Parameters
    ----------
    list_of_genes, home_folder, align_path, metadata_path, newick_path, output_dir, ref_gene
        Passed straight through to ``treesort_prep``.
    reroot : bool, default False
        Forwarded to ``treesort_prep``: reroot the reference tree with TreeTime.
    treesort : bool, default True
        If False, only prepare the inputs.
    mincut : bool, default False
        Add ``-m mincut`` to the TreeSort command line.
    n_replicates : int, default 5
        Number of TreeSort runs. Run ``i`` writes ``trees_{i}/annotated.tree``
        relative to the current working directory.

    Raises
    ------
    subprocess.CalledProcessError
        If a TreeSort run fails. Its captured stderr is on the exception.
    """
    # Forward the caller's `reroot`. It used to be hard-coded to False, so
    # reroot=True had no effect.
    treesort_prep(list_of_genes, home_folder, align_path, metadata_path,
                  newick_path, output_dir, ref_gene, reroot=reroot)

    if not treesort:
        return

    method_args = ["-m", "mincut"] if mincut else []

    # Keep the machine awake for the TreeSort runs, and always release it afterwards.
    caffeine.on(display=False)
    try:
        for i in range(n_replicates):
            run_dir = f"trees_{i}"
            os.makedirs(run_dir, exist_ok=True)
            treesort_cmd = [
                "treesort",
                "-i", "descriptor.csv",
                "-o", f"{run_dir}/annotated.tree",
                *method_args,
            ]
            subprocess.run(treesort_cmd, check=True, capture_output=True, text=True)
    finally:
        caffeine.off()

In [None]:
# All eight influenza segments; NA is the TreeSort reference segment.
list_of_genes = ["ha", "pb2", "pb1", "na", "np", "pa", "ns", "mp"]
wrap_prep(
    list_of_genes,
    home_folder="preprepped",
    align_path="aln",
    metadata_path="metadata",
    newick_path="trees",
    output_dir="prepped",
    ref_gene="na",
    mincut=True,
)