In [1]:
import baltic as bt
import pandas as pd
import json
import os
import matplotlib as mpl
from matplotlib import pyplot as plt
import requests
from io import StringIO as sio
from matplotlib.patches import Patch
import matplotlib.ticker as ticker
import itertools
import re
import sys
import subprocess
from Bio import Phylo

module_dir = "/Users/monclalab1/Documents/scripts/"
sys.path.append(module_dir)

from fasta_editing import fasta_to_df, fasta_writer

list_of_genes = ["ha", "pb2","pb1","na","np","pa","ns","mp"] 

In [2]:
''' 
1. gets rid of the index stuff and turns tree 1 into Tree tree1 
2. replaces commas with "-" (where the reassorted segments are inferred)
3. remove the single quotation marks around NODE_####
4. replaces ? with _ (where there is an undetermined reassortment event)
5. removes parantheses so that phylo.bio can add in node names 
    - must be removed again using add_back_para() so baltic can read it in
    
'''

def treesort_QC(qc_input, qc_output):
    
    with open(qc_input, 'r') as file:
        nexus = file.read()

    # removing commas between segments
    modified_nexus = re.sub(r'&rea="([^"]+)"', lambda match: f'&rea="{match.group(1).replace(",", "-")}"', nexus)
    # removing quotation marks around node names
    modified_nexus = re.sub(r"'(NODE_([0-9]{7}))'", r'\1', modified_nexus)
    # replacing ? with _ so baltic can read it in
    modified_nexus = modified_nexus.replace('?', '_')
    # removing parenthese around reassorting segments augur traits and phylo bio works on it
    modified_nexus = re.sub(r'\((\d+)\)', r'^\1', modified_nexus)
    # removing excess info before tree starts (try just turning into nwk file?)
    modified_nexus = re.sub(r"(?<=BEGIN TREES;)(\s+TREE 1 = \[.*?\])", r"\n    Tree tree1 = ", modified_nexus, flags=re.DOTALL)

    with open(qc_output, 'w') as output_file:
        output_file.write(modified_nexus)

In [3]:
# using bio phylo package instead of baltic because it can generate tree files
# dont really need to do this now that treesort updated code to name the nodes
def name_nodes(naming_input, naming_output):

    tree = Phylo.read(naming_input, "nexus")

    # making sure we don't have duplicate node names
    node_names = {node.name for node in tree.find_clades() if node.name is not None}

    counter = 0

    for node in tree.find_clades():

        # setting the node names as keys becuase right now they are stored 
        # as the value. this ignores leaves since they have names
        if node.confidence is not None and node.name is None:
            node.name = node.confidence

        # naming the unnamed bifurcated nodes and setting them as keys
        if node.name is None:
            counter += 1
            potential_node_name = f"NODE_{str(counter).rjust(7, '0')}"
            while potential_node_name in node_names:
                counter += 1
                potential_node_name = f"NODE_{str(counter).rjust(7, '0')}"

            node.name = potential_node_name
            # print(potential_node_name, node.comment, node.branch_length)

        # treetime doesn't like quotation marks around leaf names
        else:
            node.name = node.name.replace("'","")

    # set branch_length_only to true otherwise it throws errors because 
    # "confidence" is not a float (its the node name)
    Phylo.write(tree, naming_output, "newick", branch_length_only=True)

    # phylo read adds in back splashes when it reads/writes the tree
    # removing that here
    with open(naming_output, "r") as f:
        new_tree = f.read()

        cleaned = (
            new_tree
            .replace("\\'", "")
            .replace("\\[", "")
            .replace("\\]", "")
        )

    with open(naming_output, "w") as f:
        f.write(cleaned)
        
    return(naming_output)

In [4]:
def add_back_para(para_input, para_output):

    with open(para_input, 'r') as file:
        nexus = file.read()

    added = re.sub(r'\^(\d+)', r'(\1)', nexus)

    with open(para_output, 'w') as output_file:
        output_file.write(added)
        
    return(para_output)

In [5]:
def reassortment_counter(counter_input, counter_output):
    
    mytree = bt.loadNewick(counter_input, absoluteTime= False, verbose=False)

    rea_dict = {}
    segments = ["PB2", "PB1", "PA", "HA", "NP", "NA", "MP", "NS"]

    for k in mytree.Objects:
        if k.traits["is_reassorted"]:
            # Extract and clean reassorted segments
            raw_segments = k.traits["rea"]
            segment_names = [seg.split("(")[0] for seg in raw_segments.split("-")]
            reassorted_segments = f"{len(segment_names)} ({', '.join(segment_names)})"

            if k.is_node():
                rea_dict[k.traits.get("label")] = {
                    "Reassorted": "True",
                    "Reassorted Segments": reassorted_segments
                }
            elif k.is_leaf():
                rea_dict[k.name] = {
                    "Reassorted": "True",
                    "Reassorted Segments": reassorted_segments
                }
        else:
            key = (k.traits["label"] if k.is_node() else k.name)
            rea_dict[key] = {"Reassorted": "False"}

    branch_dict = {}
    for k in mytree.Objects:
        if k.traits["is_reassorted"]:
            # Use the cleaned reassorted segments from `rea_dict`
            if k.is_node():
                branch_dict[k.traits.get("label")] = {
                    "labels": {'Reassorted Segments': rea_dict[k.traits.get("label")]['Reassorted Segments']}
                }
            elif k.is_leaf():
                branch_dict[k.name] = {
                    "labels": {'Reassorted Segments': rea_dict[k.name]['Reassorted Segments']}
                }


    out_dict = {'nodes': rea_dict, 'branches': branch_dict}

    with open(counter_output, 'w') as f:
        json.dump(out_dict, f)


In [6]:
def wrapper_func(qc_input, qc_output, naming_output, para_output, counter_output):
    
    treesort_QC(qc_input, qc_output)
    name_nodes(qc_output, naming_output)
    add_back_para(naming_output, para_output)
    reassortment_counter(para_output, counter_output)

In [9]:
ref = "mp"
for i in range(1,4):
    
    wrapper_func(f"trees_{i}/annotated.tree", 
                 f"trees_{i}/modified_no-para_treesort.tre",
                 f"trees_{i}/named_h3nx_{ref}.nwk", 
                 f"trees_{i}/for_plotting.nwk", 
                 f"trees_{i}/h3nx_{ref}_rea.json"
                )
    
nextstrain_cmd = ["nextstrain", "build", "."]
subprocess.run(nextstrain_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)

CalledProcessError: Command '['nextstrain', 'build', '.']' returned non-zero exit status 1.

In [10]:
nextstrain_cmd = ["nextstrain", "build", "."]
subprocess.run(nextstrain_cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT)

CompletedProcess(args=['nextstrain', 'build', '.'], returncode=0)