# Analyze distance to outgroup comparators

Get variables from `snakemake`:

In [None]:
# Unpack the Snakemake rule's inputs, params and outputs into plain variables
# so the rest of the notebook refers to them only by name.
_in, _params, _out = snakemake.input, snakemake.params, snakemake.output

# input files
early_seq_subs_file = _in.early_seq_subs
early_seq_alignment_file = _in.early_seq_alignment
deleted_diffs_file = _in.deleted_diffs
deleted_alignment_file = _in.deleted_consensus
comparator_map_file = _in.comparator_map

# analysis parameters
region_of_interest = _params.region_of_interest
comparators = _params.comparators
min_frac_coverage = _params.min_frac_coverage
samples = _params.samples
aligners = _params.aligners
ref_genome_name = _params.ref_genome_name
ignore_muts_before = _params.ignore_muts_before
ignore_muts_after = _params.ignore_muts_after

# `phylo_*` parameters
last_date = _params.phylo_last_date
muts_to_ignore = _params.phylo_muts_to_ignore
collapse_rare_muts = _params.phylo_collapse_rare_muts
filter_rare_variants = _params.phylo_filter_rare_variants
min_frac_called = _params.phylo_min_frac_called

# output files
alignment_all_fasta = _out.alignment_all_fasta
alignment_region_fasta = _out.alignment_region_fasta
alignment_all_csv = _out.alignment_all_csv
alignment_region_csv = _out.alignment_region_csv
early_seq_count_charts = _out.early_seq_counts
early_seq_deltadist_charts = _out.early_seq_deltadist
early_seq_deltadist_region_charts = _out.early_seq_deltadist_region
deltadist_jitter_charts = _out.deltadist_jitter
deleted_diffs_latex = _out.deleted_diffs_latex

del _in, _params, _out

Import Python packages:

In [None]:
import collections
import itertools
import os
import re

import altair as alt

import altair_saver

import Bio.SeqIO

import numpy

import pandas as pd

# Let Altair embed data frames regardless of their number of rows (it otherwise
# enforces a maximum-row limit); the returned object is not needed.
_ = alt.data_transformers.disable_max_rows()

## Early GISAID sequences
Read early sequence substitutions and comparator map:

In [None]:
# One row per early sequence with its substitutions and metadata; parse the
# collection dates so they can be compared and grouped by week below.
early_seq_subs = pd.read_csv(early_seq_subs_file, na_filter=None)
early_seq_subs['date'] = pd.to_datetime(early_seq_subs['date'])

# Per-site nucleotide identities; columns used below are `site`, `reference`
# and one column per comparator.
comparator_map = pd.read_csv(comparator_map_file)

Annotate whether each strain is from Wuhan, elsewhere in China, or outside China:

In [None]:
assert early_seq_subs['country'].notnull().all()


def location_category(row):
    """Classify a sequence as 'Wuhan', 'other China', or 'outside China'."""
    strain = row['strain']
    # Strain names such as hCoV-19/Beijing/Wuhan_IME-BJ01/2020 contain "Wuhan"
    # but, per most annotations, appear to actually be from Beijing.
    named_wuhan = 'Wuhan' in strain and 'Beijing' not in strain
    if row['location'] == 'Wuhan' or named_wuhan:
        return 'Wuhan'
    if 'China' in row['country']:
        return 'other China'
    return 'outside China'


early_seq_subs['location_category'] = early_seq_subs.apply(location_category, axis=1)

early_seq_subs.groupby('location_category').aggregate(n_seqs=pd.NamedAgg('strain', 'count'))

Plot number of sequences from each location as function of date:

In [None]:
# Count sequences per week: with freq='W-FRI' each row's date is a Friday and
# its count covers the week ending on that date.
location_date_df = (
    early_seq_subs
    [['strain', 'location_category', 'date']]
    # https://stackoverflow.com/a/45281439/4191652
    .groupby(['location_category',
              pd.Grouper(key='date', freq='W-FRI')],
              )
    .aggregate(nseqs=pd.NamedAgg('strain', 'count'))
    .reset_index()
    )

# line chart of weekly counts, one colored line per location category
location_date_chart = (
    alt.Chart(location_date_df)
    .encode(x=alt.X('date:T',
                    axis=alt.Axis(labelAngle=-90,
                                  values=location_date_df['date'].unique(),
                                  format='%b %d',
                                  tickCount=location_date_df['date'].nunique(),
                                  ),
                    title='week in 2019 or 2020',
                    ),
            y=alt.Y('nseqs',
                    title='sequences from prior week',
                    ),
            color=alt.Color('location_category:N',
                            legend=alt.Legend(title='location'),
                            scale=alt.Scale(range=['#E69F00', '#009E73', '#F0E442'])
                            ),
            tooltip=['date',
                     alt.Tooltip('nseqs',
                                 title='number of sequences',
                                 ),
                     alt.Tooltip('location_category',
                                 title='location',
                                 ),
                     ],
            )
    .mark_line(point=True)
    .configure_point(size=65)
    .configure_axis(grid=False)
    .properties(width=275,
                height=150,
                )
    )

for f in early_seq_count_charts:
    print(f"Saving to {f}")
    # os.path.splitext returns a (root, ext) tuple: compare only the extension
    if os.path.splitext(f)[1] == '.html':
        location_date_chart.save(f)
    else:
        altair_saver.save(location_date_chart, f)

location_date_chart

Define a function that computes the total change in Hamming distance from each comparator relative to the reference based on the substitutions:

In [None]:
# Nucleotide identity of the reference, and of each comparator, keyed by site.
site_to_ref = comparator_map.set_index('site')['reference'].to_dict()
comparator_site_to_nt = {comparator: comparator_map.set_index('site')[comparator].to_dict()
                         for comparator in comparators}

# A substitution looks like "C8782T": wildtype nucleotide, site, mutant nucleotide.
_SUBSTITUTION_REGEX = re.compile(r'(?P<wt>[ACGT])(?P<site>\d+)(?P<mut>[ACGT])')


def delta_distance_comparator(subs_str, comparator):
    """Total change in Hamming distance from comparator relative to reference.

    Parameters
    ----------
    subs_str : str
        Comma-separated substitutions relative to the reference (for example
        ``"C8782T,T28144C"``); empty entries are ignored.
    comparator : str
        Name of a comparator, i.e. a key of ``comparator_site_to_nt``.

    Returns
    -------
    int
        The sum over substitutions of -1 when the mutant nucleotide matches the
        comparator and +1 when the reference (wildtype) nucleotide matched the
        comparator. It is 0 when neither matches, and also at sites where the
        comparator is a gap (``-``) or ambiguous (``N``).

    Raises
    ------
    ValueError
        If a substitution cannot be parsed or the comparator identity at a site
        is not one of A, C, G, T, -, N.
    """
    site_to_nt = comparator_site_to_nt[comparator]
    n = 0
    for s in subs_str.split(','):
        if not s:
            continue
        m = _SUBSTITUTION_REGEX.fullmatch(s)
        if not m:
            raise ValueError(f"cannot match {s}")
        wt = m.group('wt')
        site = int(m.group('site'))
        mut = m.group('mut')
        assert site_to_ref[site] == wt
        comp = site_to_nt[site]
        if comp in {'A', 'C', 'G', 'T'}:
            if mut == comp:
                # the mutation moves the sequence onto the comparator's identity
                n -= 1
            elif wt == comp:
                # the mutation moves the sequence off an identity it shared with
                # the comparator; if neither wt nor mut matches, the distance is
                # unchanged
                n += 1
        elif comp not in ['-', 'N']:
            raise ValueError(f"invalid comparator identity {comp}")
    return n

Apply this function to each comparator:

In [None]:
# One "<comparator>_delta_dist" column per comparator, computed from each
# sequence's substitutions relative to the reference.
for comparator in comparators:
    early_seq_subs[f"{comparator}_delta_dist"] = early_seq_subs['substitutions'].apply(
        lambda subs: delta_distance_comparator(subs, comparator)
    )

Make a tidy data frame with these delta distances:

In [None]:
# strains must be unique so (strain, outgroup) identifies a row after melting
assert early_seq_subs['strain'].nunique() == len(early_seq_subs)

# Tidy frame with one row per (strain, outgroup): the substitutions shared with
# that outgroup and the change in distance to it.
delta_dist = (
    early_seq_subs
    .melt(
        id_vars=['strain', 'gisaid_epi_isl', 'date', 'location_category',
                 'frac_called', 'substitutions', 'huanan_market'],
        value_vars=[f"{comparator}_substitutions" for comparator in comparators],
        var_name='outgroup',
        value_name='substitutions_to_outgroup',
    )
    .assign(
        outgroup=lambda df: df['outgroup'].str.replace('_substitutions', ''),
        n_substitutions=lambda df: df['substitutions'].map(
            lambda subs: sum(1 for sub in subs.split(',') if sub)
        ),
    )
    .merge(
        early_seq_subs
        .melt(
            id_vars='strain',
            value_vars=[f"{comparator}_delta_dist" for comparator in comparators],
            var_name='outgroup',
            value_name='delta_distance_to_outgroup',
        )
        .assign(outgroup=lambda df: df['outgroup'].str.replace('_delta_dist', '')),
        on=['strain', 'outgroup'],
        validate='one_to_one',
    )
)

assert len(delta_dist) == len(early_seq_subs) * len(comparators)

Now do the same thing just for the region of interest:

In [None]:
# inclusive bounds of the region of interest
start = region_of_interest['start']
end = region_of_interest['end']


def subs_in_region(subs_str):
    """Keep only the comma-separated substitutions whose site is in [start, end]."""
    in_region = []
    for sub in subs_str.split(','):
        if sub and start <= int(sub[1:-1]) <= end:
            in_region.append(sub)
    return ','.join(in_region)


# Same per-sequence table, restricted to the region of interest: only in-region
# substitutions, and the fraction of sites called within the region.
early_seq_subs_region = early_seq_subs.assign(
    substitutions=early_seq_subs['substitutions'].map(subs_in_region),
    frac_called=early_seq_subs['frac_called_in_region_of_interest'],
)

for comparator in comparators:
    early_seq_subs_region[f"{comparator}_substitutions"] = (
        early_seq_subs_region[f"{comparator}_substitutions"].map(subs_in_region)
    )
    early_seq_subs_region[f"{comparator}_delta_dist"] = (
        early_seq_subs_region['substitutions']
        .apply(lambda subs: delta_distance_comparator(subs, comparator))
    )

# tidy frame with one row per (strain, outgroup), as for the whole genome
delta_dist_region = (
    early_seq_subs_region
    .melt(
        id_vars=['strain', 'gisaid_epi_isl', 'date', 'location_category',
                 'frac_called', 'substitutions', 'huanan_market'],
        value_vars=[f"{comparator}_substitutions" for comparator in comparators],
        var_name='outgroup',
        value_name='substitutions_to_outgroup',
    )
    .assign(
        outgroup=lambda df: df['outgroup'].str.replace('_substitutions', ''),
        n_substitutions=lambda df: df['substitutions'].map(
            lambda subs: sum(1 for sub in subs.split(',') if sub)
        ),
    )
    .merge(
        early_seq_subs_region
        .melt(
            id_vars='strain',
            value_vars=[f"{comparator}_delta_dist" for comparator in comparators],
            var_name='outgroup',
            value_name='delta_distance_to_outgroup',
        )
        .assign(outgroup=lambda df: df['outgroup'].str.replace('_delta_dist', '')),
        on=['strain', 'outgroup'],
        validate='one_to_one',
    )
)

assert len(delta_dist_region) == len(early_seq_subs) * len(comparators)

Get identity at site 28144 and add to data frame of delta distances for region of interest:

In [None]:
# Upper-cased nucleotide at site 28144 (0-based index 28143) of each early
# sequence in the alignment, keyed by record id.
nt_28144 = {
    record.id: str(record.seq[28143]).upper()
    for record in Bio.SeqIO.parse(early_seq_alignment_file, 'fasta')
}

delta_dist_region = delta_dist_region.assign(
    nt_28144=delta_dist_region['strain'].map(nt_28144)
)
assert delta_dist_region['nt_28144'].notnull().all()

# number of strains per location category with each identity at site 28144
display(
    delta_dist_region[['strain', 'location_category', 'nt_28144']]
    .drop_duplicates()
    .groupby(['location_category', 'nt_28144'])
    .agg(n_strains=('strain', 'count'))
)

Function to plot delta distances versus date of isolation:

In [None]:
delta_dist_chart_height = 170

# drop-down choosing which outgroup (comparator) is displayed
outgroup_selection = alt.selection_single(
    name='sequence',
    fields=['outgroup'],
    init={'outgroup': comparators[0]},
    bind=alt.binding_select(options=comparators),
)

# slider (0 to 1) scaling the vertical jitter applied to points
jitter_slider = alt.selection_single(
    name='y_axis_jitter',
    fields=['amount'],
    bind=alt.binding_range(min=0, max=1),
    init={'amount': 0.2},
)

def rand_jitter(n, seed):
    """Return ``n`` reproducible standard-normal jitter values.

    Uses the Box-Muller transform, as described in
    https://www.geeksforgeeks.org/how-to-make-stripplot-with-jitter-in-altair-python/
    Note: this reseeds NumPy's global random number generator with ``seed``.
    """
    numpy.random.seed(seed)
    uniform_radius = numpy.random.rand(n)
    uniform_angle = numpy.random.rand(n)
    radius = numpy.sqrt(-2 * numpy.log(uniform_radius))
    return radius * numpy.cos(2 * numpy.pi * uniform_angle)

def get_delta_distance_plot(df):
    """Plot delta distance to the selected outgroup versus collection date.

    `df` is a tidy frame with one row per strain and outgroup, like
    `delta_dist` / `delta_dist_region`; it needs the columns used in the
    encodings below (e.g. `date`, `outgroup`, `delta_distance_to_outgroup`,
    `location_category`, `huanan_market`).

    Returns an Altair chart faceted by `location_category` (3 columns). Each
    facet layers vertically jittered points with a grey regression line of
    the un-jittered delta distance on date. The outgroup shown is chosen with
    the module-level `outgroup_selection` drop-down, and the jitter is scaled
    by the `jitter_slider`.
    """
    
    # jitter is scaled to the spread of delta distances in this data frame
    y_extent = df['delta_distance_to_outgroup'].max() - df['delta_distance_to_outgroup'].min()
    
    delta_distance_points = (
        alt.Chart(df
                  .assign(jitter_y=y_extent / 20 * rand_jitter(len(df), seed=1))
                  )
        .encode(x=alt.X('date:T',
                        axis=alt.Axis(labelAngle=-90),
                        title='date in 2019 or 2020',
                        ),
                # `y` is computed by transform_calculate below
                y=alt.Y('y:Q',
                        title='relative mutations from outgroup',
                        scale=alt.Scale(nice=False),
                        axis=alt.Axis(tickMinStep=1),
                        ),
                color=alt.Color('location_category:N',
                                scale=alt.Scale(range=['#E69F00', '#009E73', '#F0E442']),
                                legend=None,
                                ),
                shape=alt.Shape('huanan_market:N',
                                legend=alt.Legend(orient='top',
                                                  symbolFillColor='#E69F00',
                                                  symbolStrokeColor='#E69F00',
                                                  offset=0,
                                                  symbolSize=50,
                                                  ),
                                scale=alt.Scale(range=['circle', 'square']),
                                title='from seafood market',
                                ),
                tooltip=['strain',
                         alt.Tooltip('gisaid_epi_isl',
                                     title='GISAID ID'),
                         'date',
                         alt.Tooltip('n_substitutions',
                                     title='number substitutions'),
                         'substitutions',
                         alt.Tooltip('substitutions_to_outgroup',
                                     title='substitutions to outgroup'),
                         alt.Tooltip('frac_called',
                                     title='fraction sites called',
                                     format='.3f'),
                         alt.Tooltip('huanan_market:N',
                                     title='from seafood market'),
                         ],
                )
        .mark_point(filled=True,
                    opacity=0.5,
                    size=30,
                    )
        .transform_filter(outgroup_selection)
        # plotted value = delta distance + per-point jitter * slider amount
        .transform_calculate(
            y='datum.delta_distance_to_outgroup + datum.jitter_y * y_axis_jitter.amount'
            )
        .properties(height=delta_dist_chart_height,
                    width=270,
                    )
        )

    # Built from the points chart so it inherits the outgroup filter; the
    # regression is fit to the un-jittered values, per location category.
    delta_distance_lines = (
        delta_distance_points
        .transform_regression('date', 'delta_distance_to_outgroup',
                              groupby=['location_category'])
        .encode(color=alt.value('#999999'),
                y='delta_distance_to_outgroup')
        .mark_line(opacity=0.3,
                   size=5,
                   point=False,
                   )
        )

    # selections are attached to the layer before faceting by location
    return (
        (delta_distance_points + delta_distance_lines)
        .add_selection(jitter_slider,
                       outgroup_selection,
                       )
        .facet(facet=alt.Facet('location_category:N',
                               title=None,
                               header=alt.Header(labelFontStyle='bold',
                                                 labelPadding=1,
                                                 labelFontSize=12,
                                                 ),
                               ),
               columns=3,
               spacing=2,
               )
        .configure_axis(grid=False)
        )

Make chart for whole genome:

In [None]:
# delta-distance chart using substitutions across the whole genome
delta_distance_all_chart = get_delta_distance_plot(delta_dist)

for f in early_seq_deltadist_charts:
    print(f"Saving to {f}")
    # os.path.splitext returns a (root, ext) tuple: compare only the extension
    if os.path.splitext(f)[1] == '.html':
        delta_distance_all_chart.save(f)
    else:
        altair_saver.save(delta_distance_all_chart, f)

delta_distance_all_chart

Make chart for region of interest:

In [None]:
# delta-distance chart using only substitutions in the region of interest
delta_distance_region_chart = get_delta_distance_plot(delta_dist_region)

for f in early_seq_deltadist_region_charts:
    print(f"Saving to {f}")
    # os.path.splitext returns a (root, ext) tuple: compare only the extension
    if os.path.splitext(f)[1] == '.html':
        delta_distance_region_chart.save(f)
    else:
        altair_saver.save(delta_distance_region_chart, f)

delta_distance_region_chart

## Read the alignment of all early seqs

In [None]:
# all records of the early-sequence alignment, loaded into memory
early_seq_alignment = list(Bio.SeqIO.parse(early_seq_alignment_file, 'fasta'))

aligned_length = len(early_seq_alignment[0])
print(f"Alignment length is {aligned_length}")

# every aligned sequence must have the same length as the first one
assert {len(record) for record in early_seq_alignment} == {aligned_length}

## Deleted sequence set
Get information on substitutions in the deleted sequences, and check that it covers the expected samples and aligners:

In [None]:
deleted_diffs = pd.read_csv(deleted_diffs_file)

deleted_alignment = pd.read_csv(deleted_alignment_file)

# Every expected (sample, aligner) pair has a consensus sequence, and
# differences are only reported for expected pairs.
expect_samples_aligners = set(itertools.product(samples, aligners))
assert set(zip(deleted_alignment['sample'],
               deleted_alignment['aligner'])) == expect_samples_aligners
assert set(zip(deleted_diffs['sample'],
               deleted_diffs['aligner'])) <= expect_samples_aligners

# each consensus must span the full alignment length
assert deleted_alignment['sequence'].map(len).eq(aligned_length).all()

Get the region of interest sequence and fraction sites called in it:

In [None]:
# patient group of each sample, from the sample configuration
patient_groups = {name: config['patient_group'] for name, config in samples.items()}

# Slice of each consensus covering sites start..end (1-based, inclusive) ...
deleted_alignment = deleted_alignment.assign(
    sequence_region=deleted_alignment['sequence'].str[start - 1:end],
)
# ... and the fraction of those sites called as A, C, G, or T.
deleted_alignment = deleted_alignment.assign(
    frac_called_region=deleted_alignment['sequence_region'].map(
        lambda region_seq: sum(nt in 'ACGT' for nt in region_seq) / (end - start + 1)
    ),
    patient_group=deleted_alignment['sample'].map(patient_groups),
)

Get the differences from the reference within the region of interest, and the resulting change in distance to each outgroup:

In [None]:
print(f"Reading diffs from ref from {deleted_diffs_file}, subsetting to sites {start} to {end}")

# Differences of each deleted sample's consensus from the reference at sites in
# the region of interest, melted to one row per (difference, outgroup) with the
# outgroup's nucleotide in `comparator_nt`. `mutation` is e.g. "C8782T";
# `mutation_str` appends the nonzero values of the A/C/G/T columns (presumably
# read counts -- confirm against the diffs file).
deleted_diffs = (
    pd.read_csv(deleted_diffs_file)
    .query('(site >= @start) and (site <= @end)')
    .assign(mutation=lambda x: x['reference'] + x['site'].astype(str) + x['consensus'],
            mutation_str=lambda x: x['mutation'] + '(' +
                                   x.apply(lambda r: ','.join(f"{nt}={r[nt]}" for nt in
                                                               ['A', 'C', 'G', 'T'] if r[nt]),
                                           axis=1) + ')'
            )
    .melt(id_vars=['sample', 'aligner', 'site', 'reference', 'consensus',
                   'mutation', 'mutation_str'],
          value_vars=comparators,
          var_name='outgroup',
          value_name='comparator_nt')
    )

# All in-region substitutions of each (sample, aligner), joined with its
# consensus; pairs with no differences get empty substitution strings.
deleted_diffs_all = (
    deleted_diffs
    [['sample', 'aligner', 'mutation', 'mutation_str']]
    .drop_duplicates()
    .groupby(['sample', 'aligner'], as_index=False)
    .aggregate(substitutions=pd.NamedAgg('mutation', ','.join),
               substitutions_str=pd.NamedAgg('mutation_str', ','.join),
               )
    .merge(deleted_alignment,
           on=['sample', 'aligner'],
           how='outer',
           validate='many_to_one')
    .assign(substitutions=lambda x: x['substitutions'].fillna(''),
            substitutions_str=lambda x: x['substitutions_str'].fillna(''),
            n_substitutions=lambda x: x['substitutions'].map(lambda subs: len([s for s in subs.split(',') if s]))
            )
    )

# Substitutions that mutate to the outgroup's nucleotide, for every
# (sample, aligner, outgroup) combination (empty string when there are none).
deleted_diffs_to_outgroup = (
    deleted_diffs
    .query('comparator_nt == consensus')
    .groupby(['sample', 'aligner', 'outgroup'], as_index=False)
    .aggregate(substitutions_to_outgroup=pd.NamedAgg('mutation', ','.join),
               substitutions_to_outgroup_str=pd.NamedAgg('mutation_str', ','.join)
               )
    .merge(pd.DataFrame(itertools.product(samples, aligners, comparators),
                        columns=['sample', 'aligner', 'outgroup']),
           how='outer',
           on=['sample', 'aligner', 'outgroup'],
           validate='one_to_many',
           )
    .assign(substitutions_to_outgroup=lambda x: x['substitutions_to_outgroup'].fillna(''),
            substitutions_to_outgroup_str=lambda x: x['substitutions_to_outgroup_str'].fillna(''))
    )

deleted_diffs = deleted_diffs_all.merge(deleted_diffs_to_outgroup)

# Delta distance to each outgroup. `DataFrame.append` was removed in pandas
# 2.0, so collect the per-comparator pieces and concatenate them once.
deleted_delta_parts = []
for comparator in comparators:
    part = deleted_diffs.loc[deleted_diffs['outgroup'] == comparator,
                             ['sample', 'aligner', 'outgroup', 'substitutions']]
    deleted_delta_parts.append(
        part.assign(delta_distance_to_outgroup=part['substitutions']
                                                .apply(delta_distance_comparator,
                                                       args=(comparator,)))
        )
deleted_deltas = pd.concat(deleted_delta_parts)

deleted_diffs = deleted_diffs.merge(deleted_deltas)

# consensus nucleotide at site 28144 (0-based index 28143)
deleted_diffs = deleted_diffs.assign(nt_28144=lambda x: x['sequence'].str[28143])

Look at the change in distance to each outgroup relative to the reference.
This is just a scratch chart for inspection:

In [None]:
# drop-down choosing which aligner's results are shown
aligner_selection = alt.selection_single(
    name='read',
    fields=['aligner'],
    bind=alt.binding_select(options=aligners),
    init={'aligner': aligners[0]},
    )

# Scratch scatter of delta distance to the selected outgroup versus the
# fraction of region-of-interest sites called, one point per deleted sample.
# The full sequence columns are dropped before embedding the data, and the
# delta distance gets a small vertical jitter so overlapping points show.
deleted_delta_chart = (
    alt.Chart(deleted_diffs
              .drop(columns=['sequence', 'sequence_region'])
              .assign(delta_distance_to_outgroup=lambda x: x['delta_distance_to_outgroup'] + 0.05 * rand_jitter(len(x), seed=1))
              )
    .encode(x='frac_called_region',
            y='delta_distance_to_outgroup',
            color='patient_group',
            tooltip=['sample',
                     'n_substitutions',
                     'substitutions_str',
                     'substitutions_to_outgroup_str',
                     'frac_called_region',
                     'nt_28144',
                     ]
            )
    .mark_point(filled=True,
                size=50,
                opacity=0.5)
    .add_selection(outgroup_selection,
                   aligner_selection)
    .transform_filter(outgroup_selection)
    .transform_filter(aligner_selection)
    )

deleted_delta_chart

For the rest of the analysis, we filter to just the samples with sufficient coverage:

In [None]:
print(f"Just retaining samples with >={min_frac_coverage} coverage")

# Samples whose region of interest is sufficiently called, ordered by the
# order of the `samples` configuration.
filtered_deleted_diffs = (
    deleted_diffs[deleted_diffs['frac_called_region'] >= min_frac_coverage]
    .assign(sample=lambda df: pd.Categorical(df['sample'], categories=samples, ordered=True))
    .sort_values('sample')
)

# one row per sample, with human-readable column names for the table
filtered_deleted_diffs_display = (
    filtered_deleted_diffs
    .loc[:, ['sample', 'frac_called_region', 'patient_group', 'substitutions_str', 'n_substitutions']]
    .drop_duplicates()
    .assign(substitutions_str=lambda df: df['substitutions_str'].str.replace(',', ', '))
    .rename(columns={
        'frac_called_region': f"fraction sites called ({start}-{end})",
        'patient_group': 'patient group',
        'n_substitutions': 'number of substitutions',
        'substitutions_str': f"substitutions relative to {ref_genome_name}",
    })
)

pd.set_option('display.max_colwidth', 100)
display(filtered_deleted_diffs_display)

print(f"Saving table to {deleted_diffs_latex}")
filtered_deleted_diffs_display.to_latex(deleted_diffs_latex, float_format='%.4f')

## Plot jitters of distance to outgroup
First, make a data frame to plot that combines the deleted data set with the early sequences collected through the end of January.
Here we get the data for the deleted sequences (early outpatients only):

In [None]:
assert len(aligners) == 1, 'code below only works for one aligner, otherwise add aligner selection'

# Early-outpatient samples from the deleted set, with columns renamed to match
# the early-sequence data they are combined with below.
deleted_jitter_df = (
    filtered_deleted_diffs
    .loc[lambda df: df['patient_group'] == 'early outpatient',
         ['sample', 'patient_group', 'n_substitutions', 'substitutions',
          'substitutions_to_outgroup', 'frac_called_region', 'outgroup',
          'delta_distance_to_outgroup']]
    .rename(columns={'sample': 'strain',
                     'frac_called_region': 'frac_called',
                     'patient_group': 'group'})
    .assign(category='deleted PRJNA612766',
            huanan_market=False,
            date='early Wuhan epidemic')
)

deleted_jitter_df.head()

For the early sequence set, just get sequences collected on or before the last date (the end of January):

In [None]:
def assign_date_group(date):
    """Return the date group ('before Jan 15' or 'Jan 15-31') of a collection date.

    Raises
    ------
    ValueError
        If `date` is after `last_date` (or is missing).
    """
    # normalize so comparisons work whether `last_date` is a string or Timestamp
    last = pd.to_datetime(last_date)
    jan_15 = pd.to_datetime('2020-01-15')
    if date <= jan_15:
        return 'before Jan 15'
    if date <= last:
        # the 'Jan 15-31' label is only accurate when the last date is Jan 31
        assert last == pd.to_datetime('2020-01-31')
        return 'Jan 15-31'
    # raise (rather than return) so an out-of-range date cannot become a label
    raise ValueError(f"{date=} out of range")


# Early sequences collected on or before `last_date`, labeled by location
# category and date group.
early_seqs_jitter_df = (
    delta_dist_region
    .query('date <= @last_date')
    .sort_values(['location_category', 'date'])
    .assign(date_group=lambda x: x['date'].map(assign_date_group),
            group=lambda x: x['location_category'] + ' ' + x['date_group'],
            category=lambda x: x['location_category'],
            date=lambda x: x['date'].astype(str),
            )
    [['strain', 'group', 'n_substitutions', 'substitutions',
      'substitutions_to_outgroup', 'frac_called', 'outgroup',
      'delta_distance_to_outgroup', 'category', 'huanan_market',
      'date']]
    )

In [None]:
jitter_df = pd.concat([deleted_jitter_df, early_seqs_jitter_df])

# facet / color order, in order of first appearance; a plain list is the form
# Altair's `sort` parameter expects
groups = jitter_df['group'].unique().tolist()

# random horizontal jitter within each facet and vertical jitter scaled to the
# range of delta distances (the vertical amount is also scaled by the slider)
y_extent = jitter_df['delta_distance_to_outgroup'].max() - jitter_df['delta_distance_to_outgroup'].min() + 1
jitter_df = (
    jitter_df
    .assign(jitter_x=lambda x: 0.21 * rand_jitter(len(x), seed=18),
            jitter_y=lambda x: y_extent / 20 * rand_jitter(len(x), seed=1)
            )
    )

dist_jitter_chart = (
    alt.Chart(jitter_df)
    .encode(column=alt.Column('group',
                              title=None,
                              header=alt.Header(labelAngle=-90,
                                                labelOrient='bottom',
                                                labelAlign='right',
                                                labelPadding=3,
                                                ),
                              sort=groups,
                              ),
            x=alt.X('jitter_x:Q',
                    title=None,
                    axis=alt.Axis(values=[0], ticks=True, grid=False, labels=False),
                    scale=alt.Scale(domain=[-1, 1]),
                    ),
            # `y` is computed by transform_calculate below
            y=alt.Y('y:Q',
                    title='relative mutations from outgroup',
                    scale=alt.Scale(nice=False),
                    axis=alt.Axis(tickMinStep=1),
                    ),
            color=alt.Color('category:N',
                            sort=groups,
                            scale=alt.Scale(range=['#56B4E9', '#E69F00',
                                                   '#009E73', '#F0E442']),
                            legend=alt.Legend(title='sequence set')
                            ),
            tooltip=['strain',
                     'date',
                     alt.Tooltip('n_substitutions',
                                 title='number substitutions'),
                     'substitutions',
                     alt.Tooltip('substitutions_to_outgroup',
                                 title='substitutions to outgroup'),
                     alt.Tooltip('frac_called',
                                 title='fraction sites called',
                                 format='.3f'),
                     alt.Tooltip('huanan_market',
                                 title='from seafood market')
                     ],
            )
    .mark_point(filled=True,
                opacity=0.45,
                size=40,
                )
    .add_selection(outgroup_selection,
                   jitter_slider)
    .transform_filter(outgroup_selection)
    .transform_calculate(
            y='datum.delta_distance_to_outgroup + datum.jitter_y * y_axis_jitter.amount'
            )
    .configure_axis(grid=False)
    .configure_view(stroke=None)
    .configure_facet(spacing=0)
    .properties(height=175,
                width=40)
    )

for f in deltadist_jitter_charts:
    print(f"Saving to {f}")
    # os.path.splitext returns a (root, ext) tuple: compare only the extension
    if os.path.splitext(f)[1] == '.html':
        dist_jitter_chart.save(f)
    else:
        altair_saver.save(dist_jitter_chart, f)

dist_jitter_chart

## Write out alignments
For each alignment, we group sequences that have identical substitutions relative to the reference and collapse each group to a single representative sequence.

First get data frames with the relevant information:

In [None]:
# Early sequences trimmed to sites ignore_muts_before..ignore_muts_after
# (1-based, inclusive); `site_offset` records the first kept site so that
# site - site_offset indexes into the trimmed sequence.
seqs_all_d = {s.id: str(s.seq).upper()[ignore_muts_before - 1: ignore_muts_after]
              for s in Bio.SeqIO.parse(early_seq_alignment_file, 'fasta')}

all_alignment_df = (
    delta_dist
    .query('date <= @last_date')
    .assign(strain_date=lambda x: x['strain'] + ' (' + x['date'].astype(str) + ')')
    [['strain', 'strain_date', 'substitutions', 'frac_called']]
    .assign(sequence=lambda x: x['strain'].map(seqs_all_d))
    .assign(site_offset=ignore_muts_before)
    .query('frac_called >= @min_frac_called')
    .drop_duplicates()
    )

# Early sequences trimmed to the region of interest, plus the in-region
# consensus of each deleted sample that passed the coverage filter.
seqs_region_d = {s.id: str(s.seq).upper()[start - 1: end]
                 for s in Bio.SeqIO.parse(early_seq_alignment_file, 'fasta')}

for tup in filtered_deleted_diffs.itertuples():
    seqs_region_d['early_Wuhan_epidemic/' + tup.sample] = tup.sequence_region

# `DataFrame.append` was removed in pandas 2.0; `pd.concat` stacks the rows the
# same way (original indices kept, columns unioned, no sorting).
region_alignment_df = (
    pd.concat([
        delta_dist_region
        .query('date <= @last_date')
        .assign(strain_date=lambda x: x['strain'] + ' (' + x['date'].astype(str) + ')')
        [['strain', 'strain_date', 'substitutions', 'frac_called']],
        deleted_jitter_df
        [['strain', 'substitutions', 'frac_called']]
        .assign(strain=lambda x: 'early_Wuhan_epidemic/' + x['strain'].astype(str),
                strain_date=lambda x: x['strain']),
        ])
    .assign(sequence=lambda x: x['strain'].map(seqs_region_d))
    .assign(site_offset=start)
    .query('frac_called >= @min_frac_called')
    .drop_duplicates()
    )

Now remove any mutations in the list to ignore or that are rare:

In [None]:
muts_to_ignore = set(muts_to_ignore)
print(f"There are {len(muts_to_ignore)} mutations to ignore:\n{muts_to_ignore}")

# working copies of the two alignment tables, keyed by set name; they are
# filtered and collapsed below without touching the originals
alignment_df = dict(all=all_alignment_df.copy(),
                    region=region_alignment_df.copy())

def filter_seq(seq, subs_to_remove, site_offset):
    """Revert the given substitutions in `seq` back to their wildtype nucleotide.

    Each substitution is a string like "C8782T". Its site minus `site_offset`
    is the 0-based index into `seq`, and the nucleotide there must currently
    be the mutant (checked with an assert).
    """
    seq = list(seq)
    for s in subs_to_remove:
        wt = s[0]
        mut = s[-1]
        site = int(s[1:-1]) - site_offset
        assert seq[site] == mut, f"{seq[site]=}, {wt=}, {s=}"
        seq[site] = wt
    return ''.join(seq)

for desc in ['all', 'region']:
    # substitutions observed in at most `collapse_rare_muts` sequences of this set
    singletons = {s for s, n in 
                  collections.Counter([s for s in
                                       ','.join(alignment_df[desc]['substitutions']).split(',') if s]
                                      ).items()
                  if n <= collapse_rare_muts
                  }
    print(f"For the {desc} set, there are {len(singletons)} mutations to collapse "
          f"because they are found <= {collapse_rare_muts} times")
    # For each sequence, revert ignored or rare substitutions back to wildtype,
    # both in the substitution string and in the sequence itself. The `assign`
    # lambdas run in order, so later ones can use `substitutions_to_remove`;
    # `n_substitutions_removed` records how many substitutions were reverted.
    alignment_df[desc] = (
        alignment_df[desc]
        .assign(substitutions_to_remove=lambda x: x['substitutions'].map(
                        lambda subs: {s for s in subs.split(',')
                                      if s in muts_to_ignore or s in singletons}
                        ),
                n_substitutions_removed=lambda x: x['substitutions_to_remove'].map(len),
                substitutions=lambda x: x.apply(lambda r: ','.join([s for s in r['substitutions'].split(',')
                                                                    if s not in r['substitutions_to_remove']]),
                                                axis=1),
                sequence=lambda x: x.apply(lambda r: filter_seq(r['sequence'],
                                                                   r['substitutions_to_remove'],
                                                                   r['site_offset']),
                                           axis=1),
                )
        .drop(columns='substitutions_to_remove')
        )

Now just get one representative sequence from each set:

In [None]:
for desc in ['all', 'region']:
    print(f"For {desc}, starting with {len(alignment_df[desc])} sequences")
    # all sequences in a set must have the same length to form an alignment
    assert all(alignment_df[desc]['sequence'].map(len) ==
               alignment_df[desc]['sequence'].map(len).values[0])
    # Collapse sequences that share the same (filtered) substitutions into one
    # row. Sorting first makes the representative (first) strain of each group
    # the one with the fewest reverted substitutions, then the highest fraction
    # of sites called. `all_strains` / `all_strains_dates` list every member.
    alignment_df[desc] = (
        alignment_df[desc]
        .sort_values(['n_substitutions_removed', 'frac_called'],
                     ascending=[True, False])
        .groupby('substitutions', as_index=False)
        .aggregate(nstrains=pd.NamedAgg('strain', 'count'),
                   representative_strain=pd.NamedAgg('strain', 'first'),
                   sequence=pd.NamedAgg('sequence', 'first'),
                   all_strains=pd.NamedAgg('strain', ', '.join),
                   all_strains_dates=pd.NamedAgg('strain_date', ', '.join)
                   )
        )
    print(f"After collapsing, have {len(alignment_df[desc])} sequences")

Now filter out rare variants:

In [None]:
print(f"Filtering out variants observed <= {filter_rare_variants} times")
# keep only collapsed variants seen in more than `filter_rare_variants` strains
for desc in ('all', 'region'):
    print(f"For {desc}, starting with {len(alignment_df[desc])} sequences")
    alignment_df[desc] = alignment_df[desc].loc[
        lambda df: df['nstrains'] > filter_rare_variants
    ]
    print(f"After filtering, have {len(alignment_df[desc])} sequences")

Write the alignments:

In [None]:
# Only Bio.SeqIO is imported at the top of the notebook; import the submodules
# used here explicitly instead of relying on Bio.SeqIO importing them.
import Bio.Seq
import Bio.SeqRecord

# For each set, write the per-variant metadata (everything except the sequence)
# to CSV and one FASTA record per variant, named by its representative strain.
for desc, alignment_file, alignment_csv in [
        ('all', alignment_all_fasta, alignment_all_csv),
        ('region', alignment_region_fasta, alignment_region_csv)
        ]:
    df = alignment_df[desc]
    print(f"Writing {len(df)} sequences to {alignment_file} and {alignment_csv}")
    df.drop(columns='sequence').to_csv(alignment_csv, index=False)
    a = [Bio.SeqRecord.SeqRecord(seq=Bio.Seq.Seq(tup.sequence),
                                 id=tup.representative_strain,
                                 name='',
                                 description='')
         for tup in df.itertuples()]
    Bio.SeqIO.write(a, alignment_file, 'fasta')