# Phylomap - tool for projecting a phylogenetic tree onto a map 

This tool should:

1. Load metadata and get countries coordinates
3. Load phylogentic tree.
4. Draw world map.
5. Draw phylogenetic tree.
6. Associate countries will phylogenetic tree bracnhes
7. Plotly implementation.

In [40]:
import pandas as pd
import requests
from Bio import Phylo
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots


In [43]:
def get_country_coordinates(country):
    print(f"Getting coordinates for {country}")
    url = "{0}{1}{2}".format(
        "http://nominatim.openstreetmap.org/search?country=",
        country,
        "&format=json&polygon=0",
    )
    response = requests.get(url).json()[0]
    coordinates = [response.get(key) for key in ["lat", "lon"]]
    output = [float(i) for i in coordinates]
    return output

def merge_coordinates_with_metadata(metadata):
    coordinates = pd.DataFrame(columns=["Country", "Latitude", "Longitude"])
    countries = set(metadata["Country"])
    for country in countries:
        country_name = country.lower().replace(" ", "+")
        coordinates_list = get_country_coordinates(country_name)
        country_data = pd.DataFrame(
            [[country, coordinates_list[0], coordinates_list[1]]],
            columns=["Country", "Latitude", "Longitude"]
        )
        coordinates = pd.concat([coordinates, country_data], ignore_index=True)
    merged_metadata = metadata.merge(coordinates, on="Country")
    return merged_metadata

def get_x_coordinates(tree):
    xcoords = tree.depths(unit_branch_lengths=True)
    return xcoords

def get_x_coordinates(tree):
    xcoords = tree.depths()
    if not max(xcoords.values()):
        xcoords = tree.depths(unit_branch_lengths=True)
    return xcoords

def get_y_coordinates(tree, dist=1.3):
    maxheight = tree.count_terminals() 
    ycoords = dict((leaf, maxheight - i * dist) for i, leaf in enumerate(reversed(tree.get_terminals())))
    def calc_row(clade):
        for subclade in clade:
            if subclade not in ycoords:
                calc_row(subclade)
        ycoords[clade] = (ycoords[clade.clades[0]] +
                          ycoords[clade.clades[-1]]) / 2

    if tree.root.clades:
        calc_row(tree.root)
    return ycoords

def draw_clade(clade, x_start, line_shapes, line_color='rgb(15,15,15)', line_width=1, x_coords=0, y_coords=0):
    x_curr = x_coords[clade]
    y_curr = y_coords[clade]
    branch_line = get_clade_lines(orientation='horizontal', y_curr=y_curr, x_start=x_start, x_curr=x_curr,
                                  line_color=line_color, line_width=line_width)
    line_shapes.append(branch_line)

    if clade.clades:
        y_top = y_coords[clade.clades[0]]
        y_bot = y_coords[clade.clades[-1]]

        line_shapes.append(get_clade_lines(orientation='vertical', x_curr=x_curr, y_bot=y_bot, y_top=y_top,
                                           line_color=line_color, line_width=line_width))

        # Draw descendants
        for child in clade:
            draw_clade(child, x_curr, line_shapes, x_coords=x_coords, y_coords=y_coords)
def get_clade_lines(orientation='horizontal', y_curr=0, x_start=0, x_curr=0, y_bot=0, y_top=0,
                    line_color='rgb(25,25,25)', line_width=0.5):
    branch_line = dict(type='line',
                       layer='below',
                       line=dict(color=line_color,
                                 width=line_width)
                       )
    if orientation == 'horizontal':
        branch_line.update(x0=x_start,
                           y0=y_curr,
                           x1=x_curr,
                           y1=y_curr)
    elif orientation == 'vertical':
        branch_line.update(x0=x_curr,
                           y0=y_bot,
                           x1=x_curr,
                           y1=y_top)
    else:
        raise ValueError("Line type can be 'horizontal' or 'vertical'")
    return branch_line


def draw_clade(clade, x_start, line_shapes, line_color='rgb(15,15,15)', line_width=1, x_coords=0, y_coords=0):

    x_curr = x_coords[clade]
    y_curr = y_coords[clade]

    branch_line = get_clade_lines(orientation='horizontal', y_curr=y_curr, x_start=x_start, x_curr=x_curr,
                                  line_color=line_color, line_width=line_width)

    line_shapes.append(branch_line)

    if clade.clades:
        y_top = y_coords[clade.clades[0]]
        y_bot = y_coords[clade.clades[-1]]

        line_shapes.append(get_clade_lines(orientation='vertical', x_curr=x_curr, y_bot=y_bot, y_top=y_top,
                                           line_color=line_color, line_width=line_width))

        for child in clade:
            draw_clade(child, x_curr, line_shapes, x_coords=x_coords, y_coords=y_coords)
            
def create_plot(tree, x_coords, y_coords, metadata, title):
    line_shapes = []
    draw_clade(tree.root, 0, line_shapes, line_color='rgb(25,25,25)', line_width=1, x_coords=x_coords, y_coords=y_coords)
    
    X, Y, text, color = [], [], [], []
    color_map = generate_country_color_map(metadata)
    label_legend = set(metadata["Country"].unique())
    color_scale = {country: color_map.get(country, 'rgb(100,100,100)') for country in label_legend}

    for cl in x_coords.keys():
        if cl.is_terminal():
            X.append(x_coords[cl])
            Y.append(y_coords[cl])
            node_id = cl.name
            country = metadata.loc[metadata["ID"] == node_id, "Country"].values[0]
            text.append(f'Country: {country}<br>ID: {node_id}')
            color.append(color_scale[country])

    trace = go.Scatter(
        x=X,
        y=Y,
        mode='markers',
        marker=dict(color=color, size=5),
        text=text,
        hoverinfo='text',
        name='Countries'
    )

    layout = go.Layout(
        title= title,
        paper_bgcolor='rgba(0,0,0,0)',
        xaxis=dict(title='Branch Length'),
        yaxis=dict(
            showline=False,
            zeroline=False,
            showgrid=False,
            showticklabels=False,
            title=''
        ),
        hovermode='closest',
        shapes=line_shapes,
        plot_bgcolor='rgb(250,250,250)',
        legend={'x': 0, 'y': 1}
    )


    fig = go.Figure(data=[trace], layout=layout)
    return fig
    
def generate_country_color_map(metadata):
    unique_countries = metadata["Country"].unique()
    color_map = {}
    available_colors = [
    "rgb(31, 119, 180)", "rgb(255, 127, 14)", "rgb(44, 160, 44)",
    "rgb(214, 39, 40)", "rgb(148, 103, 189)", "rgb(140, 86, 75)",
    "rgb(227, 119, 194)", "rgb(127, 127, 127)", "rgb(188, 189, 34)",
    "rgb(23, 190, 207)", "rgb(240, 228, 66)", "rgb(65, 244, 47)",
    "rgb(502, 102, 152)", "rgb(204, 204, 204)", "rgb(200, 36, 17)",
    "rgb(114, 147, 203)", "rgb(83, 81, 84)", "rgb(147, 160, 61)",
    "rgb(169, 170, 68)", "rgb(193, 190, 70)", "rgb(93, 162, 233)"
    ]

    for i, country in enumerate(unique_countries):
        color_map[country] = available_colors[i % len(available_colors)]

    return color_map

def create_tree(tree, metadata, title = ""):
    x_coords = get_x_coordinates(tree)
    y_coords = get_y_coordinates(tree)
    fig = create_plot(tree, x_coords, y_coords, metadata, title)
    return fig

1. Load metadata and get countries coordinates

{Clade(branch_length=0.01388, name='1700_04112013_AXDK00-0'): 109.0, Clade(branch_length=0.01695, name='3004_24042007_AXDK00-0'): 107.7, Clade(branch_length=0.00017, name='0803_26042014_AXDK00-0'): 106.4, Clade(branch_length=0.00023, name='0803_27042007_AXDK00-0'): 105.1, Clade(branch_length=6e-05, name='0803_01012000_AXDK00-0'): 103.8, Clade(branch_length=0.0006, name='6402_27032017_AXDK00-0'): 102.5, Clade(branch_length=0.00015, name='0703_07082013_AXDK00-0'): 101.2, Clade(branch_length=6e-05, name='8603_03072008_AXDK00-0'): 99.9, Clade(branch_length=0.00015, name='8603_27082014_AXDK00-0'): 98.6, Clade(branch_length=0.00409, name='2103_11042016_AXDK00-0'): 97.3, Clade(branch_length=8e-05, name='1704_07022011_AXDK00-0'): 96.0, Clade(branch_length=0.00358, name='1704_01011999_AXDK00-0'): 94.7, Clade(branch_length=8e-05, name='6203_01012000_AXDK00-0'): 93.4, Clade(branch_length=0.00021, name='6804_08032006_AXDK00-0'): 92.1, Clade(branch_length=0.00081, name='6604_25012007_AXDK00-0'): 90

In [42]:
metadata = pd.read_csv("data/MetaData.csv")
metadata_with_coordinates = merge_coordinates_with_metadata(metadata)

Getting coordinates for spain



The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.



Getting coordinates for france
Getting coordinates for angola
Getting coordinates for italy
Getting coordinates for estonia
Getting coordinates for turkey
Getting coordinates for algeria
Getting coordinates for lithuania
Getting coordinates for poland
Getting coordinates for mexico
Getting coordinates for canada
Getting coordinates for egypt
Getting coordinates for germany
Getting coordinates for denmark
Getting coordinates for united+states
Getting coordinates for libya


2. Load phylogenetic tree
https://github.com/plotly/dash-phylogeny


In [25]:
tree = Phylo.read("data/tree.nwk", "newick")

In [38]:
def create_world_map(metadata, title = ""):
    color_map = generate_country_color_map(metadata)

    fig = go.Figure()

    for country, color in color_map.items():
        country_data = metadata[metadata['Country'] == country]
        ids = '<br>'.join(country_data['ID'].astype(str).tolist())
        hover_text = f"Country: {country}<br>ID(s): {ids}"

        fig.add_trace(go.Choropleth(
            locations=[country],
            locationmode="country names",
            z=[1], 
            colorscale=[[0, color], [1, color]],
            hoverinfo='text',
            text=hover_text,
            colorbar=dict(
                title='',
                tickvals=[],
                ticktext=[]
            ),
            showscale=False  
        ))

    fig.update_geos(
        resolution=110,
        showcoastlines=True,
        coastlinecolor='rgb(255, 255, 255)',
        showland=True,
        landcolor='rgb(217, 217, 217)',
        showocean=True,
        oceancolor='rgb(199, 215, 255)',
        showcountries=True,
        countrycolor='rgb(0,0,0)',
        countrywidth=0.5,
        subunitcolor='rgb(255,255,255)',
        lonaxis=dict(range=[-180, 180]),
        lataxis=dict(range=[-90, 90]),
    )

    fig.update_layout(
        title_text= title
    )

    return fig


In [44]:
def combine_tree_and_world_map(tree_plot, world_map_fig, title = ""):
    # Extract traces and layout from the tree plot
    tree_traces = tree_plot['data']
    tree_layout = tree_plot['layout']

    # Extract traces and layout from the world map plot
    world_map_traces = world_map_fig['data']
    world_map_layout = world_map_fig['layout']

    # Create a subplot with 1 row and 2 columns
    combined_fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=[tree_layout['title']['text'], world_map_layout['title']['text']],
        shared_yaxes=True,
        specs=[[{"type": "scatter"}, {"type": "choropleth"}]]
    )

    # Add world map traces and layout to the first subplot
    for trace in world_map_traces:
        combined_fig.add_trace(trace, row=1, col=2)
    combined_fig.update_layout(world_map_layout)

    # Add tree traces and layout to the second subplot
    for trace in tree_traces:
        combined_fig.add_trace(trace, row=1, col=1)

    combined_fig.update_layout(tree_layout)

    # Update the layout of the combined figure
    combined_fig.update_layout(
        title_text=title,
        showlegend=False,
        paper_bgcolor='white',  # Set background color to white
    )

    # Show the combined figure
    combined_fig.show()

# Example usage
tree_plot = create_tree(tree, metadata, title = "My tree plot")
world_map_fig = create_world_map(metadata_with_coordinates, title="My world map")

combine_tree_and_world_map(tree_plot, world_map_fig)

