# Generate a BEAST 2 XML

I would like to develop a template for a BEAST 2 XML into which I can slot additional components. I will start by figuring out how to add estimated (sampled) tip dates to an XML that already works, following this tutorial/information page: https://www.beast2.org/2015/06/09/sampling-tip-dates.html

In [1]:
import sys, subprocess, glob, os, shutil, re, importlib, Bio, csv
from subprocess import call
from Bio import SeqIO
import pandas as pd
import numpy as np
from time import gmtime, strftime

import datetime
from datetime import datetime
from dateutil.parser import parse

In [169]:
# Paths of the working BEAST2 XML to read and of the tip-dated XML to write.
input_xml = (
    "beast-runs/2021-06-10-mascot-3deme-skyline-with-mig-history/"
    "2021-06-17-mascot-3deme-skyline.xml"
)
output_xml = (
    "beast-runs/2021-06-10-mascot-3deme-skyline-with-mig-history/"
    "2021-06-22-mascot-3deme-skyline-tipdates.xml"
)

In [170]:
"""the prior on the tip dates should be set in years, not in years from present"""

def determine_date_bounds(most_recent_tip_date, tip_year, year_end_fraction=0.9959):
    """Return the (upper, lower) uniform-prior bounds for a year-only tip date.

    The prior on tip dates is expressed in calendar (decimal) years, not in
    years before the most recent tip, so ``most_recent_tip_date`` does not
    enter the calculation. It is kept only so existing callers keep working.

    Args:
        most_recent_tip_date: decimal date of the most recent tip (unused).
        tip_year: calendar year of the tip, e.g. ``2015`` or ``2015.0``.
        year_end_fraction: decimal-year fraction used for the latest possible
            date in the year (December 31st). Defaults to 0.9959, the value
            this function always used.

    Returns:
        A ``(upper_bound, lower_bound)`` tuple of strings ready to be placed
        in XML attributes:

        * ``upper_bound``: the most recent possible date (December 31st),
          ``tip_year + year_end_fraction`` rounded to 4 decimals.
        * ``lower_bound``: the furthest-back possible date (January 1st),
          i.e. ``tip_year`` itself.
    """
    # Round so that float addition noise (e.g. 2015.9958999999999) can never
    # leak into the generated XML.
    upper_bound = str(round(tip_year + year_end_fraction, 4))

    lower_bound = str(tip_year)

    return (upper_bound, lower_bound)

In [171]:
def generate_mrca_distribution_block(taxon_id, tree_id, upper_prior_bound_year, lower_prior_bound_year, count):
    """Return an XML <distribution> block holding a uniform tip-date prior.

    The block is an ``MRCAPrior`` with ``tipsonly="true"`` over a one-taxon
    ``TaxonSet``. The taxon set's id is the strain name (``taxon_id`` up to the
    first ``|``), and the prior id is ``<strain>.prior``. Its distribution is
    ``Uniform(lower, upper)`` with id ``Uniform.<count>``.

    Values are formatted directly into the template. Unlike chained
    ``str.replace`` templating, a value containing a placeholder-like token is
    never substituted a second time, and numeric bounds are accepted. Values
    are inserted verbatim (no XML escaping).

    Args:
        taxon_id: full taxon id as it appears in the date trait.
        tree_id: tree identifier, referenced as ``@Tree.t:<tree_id>``.
        upper_prior_bound_year: upper bound of the uniform prior (str or number).
        lower_prior_bound_year: lower bound of the uniform prior (str or number).
        count: numeric suffix for the ``Uniform`` id; it must not clash with
            any existing ``Uniform.<n>`` id in the XML.

    Returns:
        The six-line block joined with newlines (no trailing newline).
    """
    strain_name = taxon_id.split("|")[0]

    lines = [
        f'            <distribution id="{strain_name}.prior" spec="beast.math.distributions.MRCAPrior" tipsonly="true" tree="@Tree.t:{tree_id}">',
        f'                <taxonset id="{strain_name}" spec="TaxonSet">',
        f'                    <taxon id="{taxon_id}" spec="Taxon"/>',
        '                </taxonset>',
        f'                <Uniform id="Uniform.{count}" lower="{lower_prior_bound_year}" name="distr" upper="{upper_prior_bound_year}"/>',
        '            </distribution>',
    ]

    return "\n".join(lines)

In [172]:
def generate_operator(taxon_id, tree_id):
    """Build the TipDatesRandomWalker <operator> line for one taxon's tip date.

    The operator id is ``tipDatesSampler.<strain>``. It points at the taxon set
    ``@<strain>`` on tree ``@Tree.t:<tree_id>``, where ``<strain>`` is
    ``taxon_id`` up to the first ``|``.
    """
    strain = taxon_id.partition("|")[0]

    template = (
        '    <operator id="tipDatesSampler.insert_strain_id" spec="TipDatesRandomWalker" '
        'taxonset="@insert_strain_id" tree="@Tree.t:insert_tree_id" '
        'weight="1.0" windowSize="1.0"/>'
    )

    # Substitute placeholders in a fixed order: strain first, then tree.
    for placeholder, value in (("insert_strain_id", strain), ("insert_tree_id", tree_id)):
        template = template.replace(placeholder, value)

    return template

In [173]:
def generate_logger(taxon_id):
    """Build the <log idref=.../> line that logs one strain's tip-date prior.

    The idref is ``<strain>.prior``, where ``<strain>`` is ``taxon_id`` up to
    the first ``|``.
    """
    template = '        <log idref="insert_strain_name.prior"/>'
    return template.replace("insert_strain_name", taxon_id.partition("|")[0])

In [174]:
def generate_tip_calibration_blocks(input_xml, tree_id, most_recent_tip, start_count=5):
    """Build tip-date calibration pieces for the year-only taxa in a BEAST2 XML.

    Scans ``input_xml`` for the date trait line, i.e. a line containing
    ``traitname="date"``. Each comma-separated entry after ``value=`` whose
    third ``|``-separated field contains ``-XX-XX`` (a year-only date) gets:

    * a uniform tip-date prior block (``generate_mrca_distribution_block``),
    * a tip-date operator line (``generate_operator``),
    * a logger line (``generate_logger``).

    That entry's trait value is also moved from ``=YYYY.0`` to ``=YYYY.5``.

    Every line containing ``Uniform.`` is printed, so you can check that
    ``start_count`` is above every existing ``Uniform.<n>`` id. The new ids
    start at ``start_count + 1``; otherwise they would be duplicates.

    Args:
        input_xml: path of the XML file to read.
        tree_id: tree identifier, used as ``@Tree.t:<tree_id>``.
        most_recent_tip: passed through to ``determine_date_bounds``.
        start_count: offset for the new ``Uniform.<n>`` ids (keyword,
            defaults to 5, the previously hard-coded value).

    Returns:
        ``(new_date_line, date_trait_blocks, operators, loggers)``:
        ``new_date_line`` is the rewritten date trait line (including its
        original line ending); the others are lists of strings.

    Raises:
        ValueError: if no line in ``input_xml`` contains ``traitname="date"``.
    """
    date_trait_blocks = []
    operators = []
    loggers = []
    counter = start_count
    new_date_line = None
    all_dates = []

    with open(input_xml, "r") as infile:
        for line in infile:
            # Other parameters may already use Uniform priors; show them so
            # start_count can be set above every existing id.
            if "Uniform." in line:
                print(line)

            if "traitname=\"date\"" not in line:
                continue

            pieces = line.split("value=")
            start_of_date_line, date_string = pieces[0], pieces[1]

            reformatted_taxa = []
            all_dates = []
            for t in date_string.split(","):
                # Entries are expected to look like name|decimal|YYYY-MM-DD...=value
                # (inferred from the indexing below).
                taxon_id = t.split("=")[0].replace("\"", "")
                long_date = t.split("|")[2]
                all_dates.append(float(t.split("|")[1]))

                if "-XX-XX" not in long_date:
                    reformatted_taxa.append(t)
                    continue

                # Year-only date: emit a prior, operator and logger for this taxon.
                counter += 1
                year_text = long_date.split("-")[0]
                year = float(year_text)

                upper_bound_date, lower_bound_date = determine_date_bounds(most_recent_tip, year)
                date_trait_blocks.append(
                    generate_mrca_distribution_block(taxon_id, tree_id, upper_bound_date, lower_bound_date, counter)
                )
                operators.append(generate_operator(taxon_id, tree_id))
                loggers.append(generate_logger(taxon_id))

                # Move the year-only trait value from YYYY.0 to mid-year YYYY.5.
                reformatted_taxa.append(t.replace("=" + year_text + ".0", "=" + year_text + ".5"))

            new_date_line = start_of_date_line + "value=" + ",".join(reformatted_taxa)

    if new_date_line is None:
        raise ValueError(f'no line containing traitname="date" found in {input_xml!r}')

    print("max_date", max(all_dates))

    return (new_date_line, date_trait_blocks, operators, loggers)

In [175]:
def write_new_xml(input_xml, output_xml, date_blocks, operators, loggers, new_datet_line):
    """Copy ``input_xml`` to ``output_xml`` and splice in the tip-date calibration.

    Each input line is handled as follows:

    * A line containing ``traitname="date"`` is replaced by ``new_datet_line``.
    * After ``<!--  add tip calibration here -->``, the ``date_blocks`` are
      inserted, followed by ``<!--  end of tip calibration -->``.
    * After ``<!-- insert tip calibration operators -->``, the ``operators``
      are inserted, followed by an end-of-operators marker.
    * After ``<!-- insert tip calibration loggers -->``, the ``loggers`` are
      inserted, followed by an end-of-loggers marker.
    * Every other line is copied unchanged.

    Each end marker is followed by a newline so the next input line starts on
    its own line.

    Args:
        input_xml: path of the template XML to read.
        output_xml: path to write; it may be the same as ``input_xml``.
        date_blocks: prior blocks to insert, joined with newlines.
        operators: operator lines to insert, joined with newlines.
        loggers: logger lines to insert, joined with newlines.
        new_datet_line: replacement for the date trait line. The misspelled
            name is kept for backward compatibility with keyword callers.
    """
    # Read the whole input first so writing to the same path cannot truncate
    # it before it has been read.
    with open(input_xml, "r") as infile:
        lines = infile.readlines()

    with open(output_xml, "w") as outfile:
        for line in lines:
            if "traitname=\"date\"" in line:
                # Use the argument, not a same-named global from the notebook.
                outfile.write(new_datet_line)

            elif "<!--  add tip calibration here -->" in line:
                date_block = "\n".join(date_blocks)
                outfile.write(line + "\n" + date_block + "\n" + "<!--  end of tip calibration -->\n")

            elif "<!-- insert tip calibration operators -->" in line:
                operators_block = "\n".join(operators)
                outfile.write(line + operators_block + "\n" + "<!--  end of tip calibration operators -->\n")

            elif "<!-- insert tip calibration loggers -->" in line:
                loggers_block = "\n".join(loggers)
                outfile.write(line + loggers_block + "\n" + "<!--  end of tip calibration loggers-->\n")

            else:
                outfile.write(line)

In [176]:
# Identifier of the tree in the input XML (referenced as "@Tree.t:<tree_id>").
tree_id = (
    "aligned_h5n1_ha-3deme-1per-country-month-host-downsampled-"
    "bad-dates-2021-06-09-with-annotations-2021-06-08"
)
# Decimal date of the most recent sampled tip.
most_recent_tip_date = 2019.227

(new_date_line, date_blocks, operators, loggers) = generate_tip_calibration_blocks(
    input_xml,
    tree_id,
    most_recent_tip_date,
)

                <Uniform id="Uniform.3" name="distr"/>

max_date 2019.227


In [177]:
# Write the calibrated XML: the input file with the date trait line swapped
# and the prior/operator/logger blocks spliced in at their marker comments.
write_new_xml(
    input_xml,
    output_xml,
    date_blocks,
    operators,
    loggers,
    new_date_line,
)