In [132]:
import pandas as pd
import numpy as np
import plotly.plotly as py
import plotly.graph_objs as go

In [181]:
# Toy stand-in for the real input: a subset of columns from a dataframe
# built from GenBank sequences, one row per sequence.
country = [
    'United States', 'Chile', 'Spain', 'Japan', 'United States',
    'Spain', 'Spain', 'Spain', 'Spain', 'Spain',
]
subtype = ['A', 'A', 'A', 'B', 'B', 'A', 'A', 'A', 'B', 'A']
year = [2010, 2011, 2010, 2010, 2010, 2010, 2010, 2011, 2011, 2012]

test_df = pd.DataFrame({'country': country, 'subtype': subtype, 'year': year})

# Country centroids give each country a lat/lon, so markers can later be
# jittered to show several subtypes for the same country.
# Source: https://worldmap.harvard.edu/data/geonode:country_centroids_az8
lat_lon = pd.read_csv(
    'country_centroids_az8.csv',
    usecols=['name', 'brk_a3', 'Longitude', 'Latitude'],
).rename(columns={'name': 'country', 'brk_a3': 'country_code'})

# Number of sequences (rows) per country/subtype, overall and per year.
df_count = (
    test_df.groupby(['country', 'subtype'])
    .size()
    .reset_index(name='count')
)
df_count_time = (
    test_df.groupby(['country', 'subtype', 'year'])
    .size()
    .reset_index(name='count')
)

# Attach centroid coordinates to the counts; the left join keeps every
# counted row even when the country name has no match in the centroid file.
df_countries = df_count.merge(lat_lon, how='left', on='country')
df_countries_time = df_count_time.merge(lat_lon, how='left', on='country')

In [182]:
#Jitter points for countries that have multiple subtypes, so markers on map don't overlap
jitter_dict= {'A':1.0, 'B':-1.0}

# Row counts per country, kept with their original values. They are no longer
# used for the jitter decision: in the per-year frame they count
# (subtype, year) rows, so a country with one subtype sampled in several
# years would look like it had "multiple subtypes".
test_group = df_countries.groupby('country').size()
test_group_time = df_countries_time.groupby('country').size()


def jitter_coordinates(df, offsets):
    """Return ``(adj_lon, adj_lat)`` for ``df`` with per-subtype jitter applied.

    Rows whose country has more than one distinct subtype in ``df`` get the
    subtype's offset added to both 'Longitude' and 'Latitude'. All other rows
    keep their original coordinates. The decision is made per country over
    the whole frame, so in a per-year frame a marker keeps the same position
    in every year.

    Parameters
    ----------
    df : DataFrame with 'country', 'subtype', 'Longitude' and 'Latitude' columns.
    offsets : dict mapping subtype -> offset in degrees.

    Raises
    ------
    KeyError
        If a subtype in ``df`` has no entry in ``offsets``.
    """
    shift = df['subtype'].map(offsets)
    unmapped = shift.isna()
    if unmapped.any():
        # Fail loudly rather than letting NaN coordinates hide the marker.
        raise KeyError('no jitter offset defined for subtype(s): {}'.format(
            df.loc[unmapped, 'subtype'].unique().tolist()))

    has_multiple_subtypes = (
        df.groupby('country')['subtype'].transform('nunique') > 1
    )
    adj_lon = (df['Longitude'] + shift).where(has_multiple_subtypes, df['Longitude'])
    adj_lat = (df['Latitude'] + shift).where(has_multiple_subtypes, df['Latitude'])
    return adj_lon, adj_lat


#Without time
df_countries['adj_lon'], df_countries['adj_lat'] = jitter_coordinates(
    df_countries, jitter_dict)

#With data separated by year
df_countries_time['adj_lon'], df_countries_time['adj_lat'] = jitter_coordinates(
    df_countries_time, jitter_dict)


In [146]:
# Marker styling for the static (all years combined) map.
scale_markers = 10
cmap = {'A': 'royalblue', 'B': 'salmon'}
map_list = []

# One scattergeo trace per (country, subtype) row, placed at the jittered
# centroid and sized by the number of sequences.
rows = zip(
    df_countries['country'],
    df_countries['subtype'],
    df_countries['count'],
    df_countries['adj_lat'],
    df_countries['adj_lon'],
)
for country_name, subtype_label, n_seqs, lat, lon in rows:
    marker_style = {
        'size': n_seqs * scale_markers,
        'color': cmap[subtype_label],
        'line': {'width': 0.5, 'color': 'rgb(40,40,40)'},
        'opacity': 0.5,
        'sizemode': 'diameter',
    }
    map_list.append({
        'type': 'scattergeo',
        'lat': [lat],
        'lon': [lon],
        'marker': marker_style,
        'hovertext': 'Subtype ' + subtype_label + ' : ' + str(n_seqs) + ' sequences',
        'name': country_name + ' ' + subtype_label,
        'hoverinfo': 'text+name',
    })

# World map with a light grey land fill; the legend is off because every
# row is its own trace.
geo_style = {
    'scope': 'world',
    'showland': True,
    'landcolor': 'rgb(217, 217, 217)',
    'countrywidth': 1,
}
layout = {
    'title': 'Global distribution of RSV',
    'showlegend': False,
    'geo': geo_style,
}

fig = {'data': map_list, 'layout': layout}
py.iplot(fig)

In [185]:
#With time slider

scale_markers = 10
cmap= {'A':'royalblue','B':'salmon'}

# Slider positions come from the data, so the slider cannot drift out of
# sync with the years actually present. tolist() yields plain Python ints.
year_range = sorted(df_countries_time['year'].unique().tolist())
initial_year = year_range[0] if year_range else None

# One scattergeo trace per (country, subtype, year) row.
map_list = []
rows = zip(
    df_countries_time['country'],
    df_countries_time['subtype'],
    df_countries_time['year'],
    df_countries_time['count'],
    df_countries_time['adj_lat'],
    df_countries_time['adj_lon'],
)
for country_name, subtype_label, row_year, n_seqs, lat, lon in rows:
    map_country = dict(
        type = 'scattergeo',
        lat = [lat],
        lon = [lon],
        # Slider steps only run when the slider is moved, so the first
        # render must already show just the initial year. Without this,
        # every year is drawn on top of each other at load.
        visible = bool(row_year == initial_year),
        marker = dict(
            size = n_seqs*scale_markers,
            color = cmap[subtype_label],
            line = dict(width=0.5, color='rgb(40,40,40)'),
            opacity=0.5,
            sizemode = 'diameter'),
        hovertext = 'Subtype '+subtype_label+' : '+str(n_seqs)+' sequences',
        name = country_name+' '+subtype_label,
        hoverinfo = 'text+name'
    )
    map_list.append(map_country)

# Each step restyles 'visible' for all traces at once: True for traces
# belonging to that step's year, False for the rest. Trace order matches
# row order in df_countries_time.
steps = []
for year in year_range:
    step = dict(
        method = 'restyle',
        label = str(year),
        args = ['visible', [bool(y == year) for y in df_countries_time['year']]])
    steps.append(step)

layout = dict(
        title = 'Global distribution of RSV',
        showlegend = False,
        sliders = [dict(
            active = 0,  # matches the initial per-trace visibility set above
            steps = steps
        )],
        geo = dict(
            scope='world',
            showland = True,
            landcolor = 'rgb(217, 217, 217)',
            countrywidth=1,
        ),
    )

fig = dict(data=map_list, layout=layout)
py.iplot(fig)