added region colors to flu, merged multiflu
rneher committed Dec 31, 2016
2 parents 7c9ff8a + 9822cf4 commit 5470539
Showing 6 changed files with 272 additions and 24 deletions.
42 changes: 42 additions & 0 deletions base/fetch_outgroups.py
@@ -0,0 +1,42 @@
from Bio import Entrez, SeqIO
from StringIO import StringIO

Entrez.email = "richard.neher@tuebingen.mpg.de" # Always tell NCBI who you are

outgroups = {#'H3N2':'U26830',
#'H3N2_77':'CY113261.1',
#'H1N1pdm':'AF455680',
#'Vic':'CY018813',
#'Yam':'CY019707',
#'zika':'KX369547.1'
#'vic_pb1':'CY018819',
#'vic_pb2':'CY018820',
#'vic_pa': 'CY018818',
#'vic_ha': 'CY018813',
#'vic_np': 'CY018816',
#'vic_na': 'CY018815',
#'vic_ma': 'CY018814',
#'vic_ns': 'CY018817',
'h1n1pdm_pb1':'AF455728',
'h1n1pdm_pb2':'AF455736',
'h1n1pdm_pa': 'AF455720',
'h1n1pdm_ha': 'AF455680',
'h1n1pdm_np': 'AF455704',
'h1n1pdm_na': 'AF455696',
'h1n1pdm_ma': 'AF455688',
'h1n1pdm_ns': 'AF455712',
'h3n2_pb1':'CY113683',
'h3n2_pb2':'CY113684',
'h3n2_pa': 'CY113682',
'h3n2_ha': 'CY113677',
'h3n2_np': 'CY113680',
'h3n2_na': 'CY113679',
'h3n2_ma': 'CY113678',
'h3n2_ns': 'CY113681',
}

for virus, genbank_id in outgroups.iteritems():
handle = Entrez.efetch(db="nucleotide", id=genbank_id, rettype="gb")
seq = SeqIO.read(StringIO(handle.read()), format='genbank')
SeqIO.write(seq, 'flu/metadata/'+virus+'_outgroup.gb', format='genbank')
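
For reference, a rough Python 3 equivalent of the fetch loop above; the committed script uses Python 2 idioms (the StringIO module, dict.iteritems), and retmode="text" is an assumption added here to request plain-text GenBank records:

from io import StringIO
from Bio import Entrez, SeqIO

Entrez.email = "richard.neher@tuebingen.mpg.de"  # always tell NCBI who you are
outgroups = {'h3n2_ha': 'CY113677'}  # illustrative subset of the dictionary above

for virus, genbank_id in outgroups.items():
    handle = Entrez.efetch(db="nucleotide", id=genbank_id, rettype="gb", retmode="text")
    seq = SeqIO.read(StringIO(handle.read()), format='genbank')
    SeqIO.write(seq, 'flu/metadata/' + virus + '_outgroup.gb', format='genbank')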

4 changes: 2 additions & 2 deletions base/tree.py
@@ -182,7 +182,7 @@ def geo_inference(self, attr):
infer a "mugration" model by pretending each region corresponds to a sequence
state and repurposing the GTR inference and ancestral reconstruction
'''
from treetime.gtr import GTR
from treetime import GTR
# Determine alphabet and store reconstructed ancestral sequences
places = set()
nuc_seqs = {}
@@ -228,7 +228,7 @@ def geo_inference(self, attr):
tmp_use_mutation_length = self.tt.use_mutation_length
self.tt.use_mutation_length=False
self.tt.infer_ancestral_sequences(method='ml', infer_gtr=True,
store_compressed=False, pc=5.0, marginal=True)
store_compressed=False, pc=5.0, marginal=True, normalized_rate=False)

# restore the nucleotide sequence and mutations to maintain expected behavior
self.tt.geogtr = self.tt.gtr
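
The geo_inference docstring above describes the "mugration" approximation: each discrete location is treated as one character of a custom alphabet, every tip gets a length-one pseudo-sequence, and the usual GTR inference plus ancestral reconstruction is reused to infer ancestral locations. A minimal illustrative sketch of that encoding follows; it assumes treetime's GTR.custom accepts a custom alphabet (as in versions contemporary with this commit), and tip_regions / letter_of are hypothetical names:

import numpy as np
from treetime import GTR

tip_regions = {'A/example/1/2015': 'europe', 'A/example/2/2015': 'china', 'A/example/3/2016': 'europe'}
places = sorted(set(tip_regions.values()))
letters = np.array(list('ABCDEFGHIJKLMNOPQRST'[:len(places)]))  # one letter per location
letter_of = {place: letter for place, letter in zip(places, letters)}

# uninformative starting model over the location alphabet; actual rates are
# re-estimated during ancestral reconstruction (infer_gtr=True in the call above)
geo_gtr = GTR.custom(pi=np.ones(len(places))/len(places),
                     W=np.ones((len(places), len(places))),
                     alphabet=letters)

# length-one pseudo-sequences stand in for nucleotide sequences; running
# infer_ancestral_sequences(method='ml', infer_gtr=True, marginal=True, ...) on them
# reconstructs ancestral characters, i.e. ancestral regions
pseudo_seqs = {name: np.array([letter_of[region]]) for name, region in tip_regions.items()}
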
16 changes: 15 additions & 1 deletion ebola/ebola.py
@@ -2,6 +2,7 @@
from collections import defaultdict
import sys
sys.path.append('') # need to import from base
sys.path.append('/home/richard/Projects') # need to import from base
from base.io_util import make_dir, remove_dir, tree_to_json, write_json, myopen
from base.sequences import sequence_set, num_date
from base.tree import tree
@@ -140,8 +141,21 @@
ebola.load()

ebola.clock_filter(n_iqd=10, plot=True)
ebola.annotate_tree(Tc=0.0005, timetree=True, reroot='best', resolve_polytomies=True)
ebola.annotate_tree(Tc=0.002, timetree=True, reroot='best', resolve_polytomies=True)
for geo_attr in geo_attributes:
ebola.tree.geo_inference(geo_attr)
ebola.export(controls = attribute_nesting, geo_attributes = geo_attributes,
color_options=color_options, panels=panels)


# plot an approximate skyline
from matplotlib import pyplot as plt
T = ebola.tree.tt
plt.figure()
skyline = T.merger_model.skyline(n_points = 20, gen = 50/T.date2dist.slope,
to_numdate = T.date2dist.to_numdate)
plt.ticklabel_format(useOffset=False)
plt.plot(skyline.x, skyline.y)
plt.ylabel('effective population size')
plt.xlabel('year')
plt.savefig('ebola_skyline.png')
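
The clock_filter(n_iqd=10, plot=True) call above removes tips whose root-to-tip divergence is inconsistent with a molecular clock: roughly, tips whose residual from a root-to-tip regression exceeds n_iqd interquartile distances are flagged. A self-contained sketch of that criterion on made-up data (the treetime implementation differs in detail and operates directly on the tree):

import numpy as np

def iqd_outliers(dates, divergences, n_iqd=3):
    # root-to-tip regression: divergence ~ slope*date + intercept
    slope, intercept = np.polyfit(dates, divergences, 1)
    residuals = divergences - (slope*dates + intercept)
    iqd = np.percentile(residuals, 75) - np.percentile(residuals, 25)
    return np.abs(residuals) > n_iqd*iqd  # True marks a putative clock outlier

dates = np.linspace(2010, 2018, 60)  # sampling dates in decimal years
div = 2e-3*(dates - 2010)            # divergence accumulating on a perfect clock
div[-1] += 0.05                      # one tip far above the clock expectation
print(np.where(iqd_outliers(dates, div, n_iqd=3))[0])  # -> [59]
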
91 changes: 73 additions & 18 deletions flu/seasonal_flu.py
@@ -10,23 +10,52 @@
regions = ['africa', 'south_asia', 'europe', 'china', 'north_america',
'south_america', 'japan_korea', 'oceania', 'southeast_asia']

region_cmap = [
["africa", "#5097BA"],
["south_america", "#60AA9E"],
["west_asia", "#75B681"],
["oceania", "#8EBC66"],
["europe", "#AABD52"],
["japan_korea", "#C4B945"],
["north_america", "#D9AD3D"],
["southeast_asia", "#E59637"],
["south_asia", "#E67030"],
["china", "#DF4327"]
]

region_groups = {'NA':'north_america',
'AS':['china', 'japan_korea', 'south_asia', 'southeast_asia'],
'OC':'oceania', 'EU':'europe'}

attribute_nesting = {'geographic location':['region', 'country', 'city'],}



color_options = {
"country":{"key":"country", "legendTitle":"Country", "menuItem":"country", "type":"discrete"},
"region":{"key":"region", "legendTitle":"Region", "menuItem":"region", "type":"discrete"},
"region":{"key":"region", "legendTitle":"Region", "menuItem":"region", "type":"discrete", "color_map":region_cmap},
"num_date":{"key":"num_date", "legendTitle":"Sampling date", "menuItem":"date", "type":"continuous"},
"ep":{"key":"ep", "legendTitle":"Epitope Mutations", "menuItem":"epitope mutations", "type":"continuous"},
"ne":{"key":"ne", "legendTitle":"Non-epitope Mutations", "menuItem":"nonepitope mutations", "type":"continuous"},
"rb":{"key":"rb", "legendTitle":"Receptor Binding Mutations", "menuItem":"RBS mutations", "type":"continuous"},
"gt":{"key":"genotype", "legendTitle":"Genotype", "menuItem":"genotype", "type":"discrete"}
"gt":{"key":"genotype", "legendTitle":"Genotype", "menuItem":"genotype", "type":"discrete"},
"cHI":{"key":"cHI", "legendTitle":"Antigenic advance", "menuItem":"Antigenic", "type":"continuous"}
}
panels = ['tree', 'entropy', 'frequencies']

outliers = {
'h3n2':["A/Sari/388/2006", "A/SaoPaulo/36178/2015", "A/Pennsylvania/40/2010", "A/Pennsylvania/14/2010",
"A/Pennsylvania/09/2011", "A/OSAKA/31/2005", "A/Ohio/34/2012", "A/Kenya/170/2011", "A/Kenya/168/2011",
"A/Indiana/21/2013", "A/Indiana/13/2012", "A/Indiana/11/2013", "A/Indiana/08/2012", "A/Indiana/06/2013",
"A/India/6352/2012", "A/HuNan/01/2014", "A/Helsinki/942/2013", "A/Guam/AF2771/2011", "A/Chile/8266/2003",
"A/Busan/15453/2009", "A/Nepal/142/2011", "A/Kenya/155/2011", "A/Guam/AF2771/2011", "A/Michigan/82/2016",
"A/Ohio/27/2016", "A/Ohio/28/2016", "A/Michigan/83/2016", "A/Michigan/84/2016", "A/Jiangsu-Tianning/1707/2013",
"A/HuNan/1/2014", "A/Iran/227/2014", "A/Iran/234/2014", "A/Iran/140/2014"],
'h1n1pdm': [],
'vic':[],
"yam":[]
}


clade_designations = {"h3n2":{
"3c3.a":[('HA1',128,'A'), ('HA1',142,'G'), ('HA1',159,'S')],
@@ -270,44 +299,55 @@ def plot_frequencies(flu, gene, mutation=None, plot_regions=None, all_muts=False
plt.ion()

parser = argparse.ArgumentParser(description='Process virus sequences, build tree, and prepare for web visualization')
parser.add_argument('-v', '--viruses_per_month', type = int, default = 10, help='number of viruses sampled per month')
parser.add_argument('-y', '--years_back', type = str, default = 3, help='number of years back to sample')
parser.add_argument('--resolution', type = str, help ="outfile suffix, can determine -v and -y")
parser.add_argument('-v', '--viruses_per_month_seq', type = int, default = 10, help='number of viruses sampled per month for the sequence/frequency analysis')
parser.add_argument('-w', '--viruses_per_month_tree', type = int, default = 10, help='number of viruses sampled per month for tree building')
parser.add_argument('-r', '--raxml_time_limit', type = float, default = 1.0, help='number of hours raxml is run')
parser.add_argument('-d', '--download', action='store_true', default = False, help='load from database')
parser.add_argument('-t', '--time_interval', nargs=2, help='specify time interval rather than use --years_back')
parser.add_argument('-l', '--lineage', type = str, default = 'h3n2', help='flu lineage to process')
parser.add_argument('--new_auspice', default = False, action="store_true", help='prefix output file names with "flu_" for the new auspice')
parser.add_argument('--load', action='store_true', help = 'recover from file')
parser.add_argument('--no_tree', default=False, action='store_true', help = "don't build a tree")
params = parser.parse_args()

# default values for --viruses_per_month_seq/_tree and --years_back, derived from --resolution
if params.resolution == "2y":
params.viruses_per_month = 15
params.years_back = 2
params.viruses_per_month_tree = 15
params.viruses_per_month_seq = 20
params.years_back = 2
elif params.resolution == "3y":
params.viruses_per_month = 7
params.years_back = 3
params.viruses_per_month_tree = 7
params.viruses_per_month_seq = 20
params.years_back = 3
elif params.resolution == "6y":
params.viruses_per_month = 3
params.years_back = 6
params.viruses_per_month_tree = 3
params.viruses_per_month_seq = 10
params.years_back = 6
elif params.resolution == "12y":
params.viruses_per_month = 2
params.years_back = 12
params.viruses_per_month_tree = 2
params.viruses_per_month_seq = 10
params.years_back = 12

# construct time_interval from years_back
if not params.time_interval:
today_str = "{:%Y-%m-%d}".format(datetime.today())
date_str = "{:%Y-%m-%d}".format(datetime.today() - timedelta(days=365.25 * params.years_back))
params.time_interval = [date_str, today_str]

if params.new_auspice:
fname_prefix = "flu_"+params.lineage
else:
fname_prefix = params.lineage

input_data_path = '../fauna/data/'+params.lineage
if params.resolution:
store_data_path = 'store/'+params.lineage + '_' + params.resolution +'_'
build_data_path = 'build/'+params.lineage + '_' + params.resolution +'_'
store_data_path = 'store/'+ fname_prefix + '_' + params.resolution +'_'
build_data_path = 'build/'+ fname_prefix + '_' + params.resolution +'_'
else:
store_data_path = 'store/'+params.lineage + '_'
build_data_path = 'build/'+params.lineage + '_'
store_data_path = 'store/'+ fname_prefix + '_'
build_data_path = 'build/'+ fname_prefix + '_'
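
# Worked example of the resolution/prefix logic above for a hypothetical invocation
# `--lineage h3n2 --resolution 3y --new_auspice`: viruses_per_month_seq=20,
# viruses_per_month_tree=7, years_back=3, fname_prefix='flu_h3n2',
# store_data_path='store/flu_h3n2_3y_', build_data_path='build/flu_h3n2_3y_'.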

ppy = 12
flu = flu_process(input_data_path = input_data_path, store_data_path = store_data_path,
@@ -330,19 +370,22 @@ def plot_frequencies(flu, gene, mutation=None, plot_regions=None, all_muts=False
flu.seqs.filter(lambda s:
s.attributes['date']>=time_interval[0] and s.attributes['date']<time_interval[1])
flu.seqs.filter(lambda s: len(s.seq)>=900)
flu.seqs.filter(lambda s: s.name not in outliers[params.lineage])

flu.subsample(params.viruses_per_month)
flu.subsample(params.viruses_per_month_seq)
flu.align()
flu.dump()
# first estimate frequencies globally, then region specific
flu.estimate_mutation_frequencies(region="global", pivots=pivots)
for region in region_groups.iteritems():
flu.estimate_mutation_frequencies(region=region)
# for region in region_groups.iteritems():
# flu.estimate_mutation_frequencies(region=region)

if not params.no_tree:
flu.subsample(params.viruses_per_month_tree, repeated=True)
flu.align()
flu.build_tree()
flu.clock_filter(n_iqd=3, plot=True)
flu.tree.tt.debug=True
flu.annotate_tree(Tc=0.005, timetree=True, reroot='best')
flu.tree.geo_inference('region')

@@ -356,3 +399,15 @@ def plot_frequencies(flu, gene, mutation=None, plot_regions=None, all_muts=False
flu.export(extra_attr=['serum'], controls=attribute_nesting,
color_options=color_options, panels=panels)
flu.HI_export()

# plot an approximate skyline
from matplotlib import pyplot as plt
T = flu.tree.tt
plt.figure()
skyline = T.merger_model.skyline(n_points=20, gen = 50/T.date2dist.slope,
to_numdate = T.date2dist.to_numdate)
plt.ticklabel_format(useOffset=False)
plt.plot(skyline.x, skyline.y, lw=2)
plt.ylabel('effective population size')
plt.xlabel('year')
plt.savefig('%s_%s_skyline.png'%(params.lineage, params.resolution))
125 changes: 125 additions & 0 deletions flu/seasonal_flu_all_segments.py
@@ -0,0 +1,125 @@
from __future__ import division, print_function
from collections import defaultdict
import sys
sys.path.append('') # need to import from base
from base.process import process
import numpy as np
from datetime import datetime
from base.io_util import myopen

regions = ['africa', 'south_asia', 'europe', 'china', 'north_america',
'south_america', 'japan_korea', 'oceania', 'southeast_asia']

region_groups = {'NA':'north_america',
'AS':['china', 'japan_korea', 'south_asia', 'southeast_asia'],
'OC':'oceania', 'EU':'europe'}

attribute_nesting = {'geographic location':['region', 'country', 'city'],}


def sampling_category(x):
return (x.attributes['region'],
x.attributes['date'].year,
x.attributes['date'].month)


def sampling_priority(seq):
return len(seq.seq)*0.0001 - 0.01*np.sum([seq.seq.count(nuc) for nuc in 'NRWYMKSHBVD'])
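
# Worked example of the priority score above (illustrative numbers): a 1000 nt sequence
# with 5 ambiguous bases scores 1000*0.0001 - 0.01*5 = 0.05, while a 900 nt sequence with
# no ambiguities scores 0.09, so within a sampling category the shorter but unambiguous
# sequence is preferred.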


if __name__=="__main__":
import argparse
import matplotlib.pyplot as plt
plt.ion()

parser = argparse.ArgumentParser(description='Process virus sequences, build tree, and prepare for web visualization')
parser.add_argument('-v', '--viruses_per_month', type = int, default = 10, help='number of viruses sampled per month')
parser.add_argument('-y', '--resolution', type = str, default = '3y', help='outfile suffix')
parser.add_argument('-r', '--raxml_time_limit', type = float, default = 1.0, help='number of hours raxml is run')
parser.add_argument('-d', '--download', action='store_true', default = False, help='load from database')
parser.add_argument('-t', '--time_interval', nargs=2, default=('2012-01-01', '2016-01-01'),
help='time interval to sample sequences from: provide dates as YYYY-MM-DD')
parser.add_argument('-l', '--lineage', type = str, default = 'h3n2', help='flu lineage to process')
parser.add_argument('--load', action='store_true', help = 'recover from file')
parser.add_argument('--no_tree', default=False, action='store_true', help = "don't build a tree")
params = parser.parse_args()

ppy = 12
time_interval = [datetime.strptime(x, '%Y-%m-%d').date() for x in params.time_interval]
pivots = np.arange(time_interval[0].year+(time_interval[0].month-1)/12.0,
time_interval[1].year+time_interval[1].month/12.0, 1.0/ppy)
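
# Worked example of the pivot grid above: for the default interval 2012-01-01 .. 2016-01-01
# and ppy = 12, this gives monthly pivots in decimal years,
# 2012.0, 2012.083, 2012.167, ..., up to roughly 2016.0.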

# load data from all segments
segment_names = ['pb1', 'pb2', 'pa', 'ha', 'np', 'na', 'm', 'ns']
segments = {}
viruses = defaultdict(list)
for seg in segment_names:
input_data_path = '../fauna/data/'+params.lineage+'_'+seg
if seg=='m':
input_data_path+='p'
store_data_path = 'store/'+params.lineage + '_' + params.resolution + '_' + seg + '_'
build_data_path = 'build/'+params.lineage + '_' + params.resolution + '_' + seg + '_'
flu = process(input_data_path = input_data_path, store_data_path = store_data_path,
build_data_path = build_data_path, reference='flu/metadata/'+params.lineage + '_' + seg +'_outgroup.gb',
proteins=['SigPep', 'HA1', 'HA2'],
method='SLSQP', inertia=np.exp(-1.0/ppy), stiffness=2.*ppy)

flu.load_sequences(fields={0:'strain', 2:'isolate_id', 3:'date', 4:'region',
5:'country', 7:"city", 12:"subtype",13:'lineage'})

print("## loading data for segment %s, found %d number of sequences"%(seg, len(flu.seqs.all_seqs)))
for sequence in flu.seqs.all_seqs:
viruses[sequence].append(seg)

segments[seg] = flu

# determine strains that are complete
complete_strains = filter(lambda x:len(viruses[x])==len(segment_names), viruses.keys())
# filter every segment down to the strains for which all other segments exist
segments['ha'].seqs.filter(lambda s: s.name in complete_strains)
segments['ha'].seqs.filter(lambda s:s.attributes['date']>=time_interval[0] and s.attributes['date']<time_interval[1])
segments['ha'].seqs.subsample(category = sampling_category, priority = sampling_priority, threshold = params.viruses_per_month)
strains_to_use = segments['ha'].seqs.seqs.keys()

# align and build tree
for seg, flu in segments.iteritems():
flu.seqs.filter(lambda s: s.name in strains_to_use)
flu.seqs.filter(lambda s:
s.attributes['date']>=time_interval[0] and s.attributes['date']<time_interval[1])
if seg!='ha':
flu.seqs.seqs = flu.seqs.all_seqs

flu.align()
flu.dump()
flu.build_tree()
flu.annotate_tree(Tc=0.005, timetree=True, reroot='best')
flu.tree.geo_inference('region')

flu.dump()
flu.export(extra_attr=[], controls=attribute_nesting)

# determine the ladder rank of each strain in every segment tree
ladder_ranks = defaultdict(list)
for seg in segment_names:
for leaf in segments[seg].tree.tree.get_terminals():
ladder_ranks[leaf.name].append(leaf.yvalue)


for seg in segment_names:
for leaf in segments[seg].tree.tree.get_terminals():
leaf.attr['ladder_ranks'] = np.array(ladder_ranks[leaf.name])
print(leaf.attr['ladder_ranks'])
leaf.nleafs=1

for seg in segment_names:
for node in segments[seg].tree.tree.get_nonterminals(order='postorder'):
node.nleafs = np.sum([x.nleafs for x in node])
node.attr['ladder_ranks'] = np.sum([x.nleafs*x.attr['ladder_ranks'] for x in node], axis=0)/node.nleafs
for seg in segment_names:
for leaf in segments[seg].tree.tree.find_clades():
leaf.attr['ladder_ranks'] = list(leaf.attr['ladder_ranks'])
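
# Worked example of the leaf-count-weighted average computed in the postorder loop above,
# for a hypothetical node with two children covering 2 and 3 leaves (ranks shown for two
# segments only): children with ranks [10., 12.] and [40., 38.] give
# node ranks = (2*[10, 12] + 3*[40, 38]) / 5 = [28.0, 27.6].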

# dump and export all segments
for seg, flu in segments.iteritems():
flu.dump()
flu.export(extra_attr=[], controls=attribute_nesting)
