# Annotate and filter substitutions in early sequences

Import Python modules:

In [None]:
import collections

import altair as alt

import pandas as pd

from ruamel.yaml import YAML

Get variables from `snakemake`:

In [None]:
# Input files supplied by the workflow.
comparator_map_file = snakemake.input.comparator_map
subs_csv = snakemake.input.subs_csv
who_china_report_cases_yaml = snakemake.input.who_china_report_cases_yaml
early_seqs_to_exclude_yaml = snakemake.input.early_seqs_to_exclude_yaml

# Output file written at the end of the notebook.
output_csv = snakemake.output.csv

# Names of the comparator columns expected in the comparator map.
comparators = snakemake.params.comparators

# Filtering thresholds.
min_coverage = snakemake.params.min_coverage
max_subs = snakemake.params.max_subs
max_ambiguous = snakemake.params.max_ambiguous
filter_runs = snakemake.params.filter_runs

# Date cutoffs, parsed to timestamps so they can be compared with the
# parsed `date` column of the substitutions table.
max_date = pd.to_datetime(snakemake.params.max_date)
who_china_report_last_date = pd.to_datetime(snakemake.params.who_china_report_last_date)

## Read data
Read comparator map:

In [None]:
# Load the table relating the reference to each comparator.
comparator_map = pd.read_csv(comparator_map_file)

# Every configured comparator must have its own column in the table.
assert set(comparators) <= set(comparator_map.columns)

comparator_map

Get the set of mutations for each comparator, keeping **only** mismatches in which both the reference and the comparator have a valid nucleotide:

In [None]:
# Map each comparator name to the set of its substitutions relative to the
# reference, each written as <reference nt><site><comparator nt>.
comparator_muts = {}
valid_nts = ['A', 'C', 'G', 'T']

# These columns do not depend on the comparator, so compute them once.
reference_nt = comparator_map['reference']
site_str = comparator_map['site'].astype(str)
reference_is_valid = reference_nt.isin(valid_nts)

for comparator in comparators:
    comparator_nt = comparator_map[comparator]
    # Keep only true mismatches where both nucleotides are in `valid_nts`.
    # Boolean masks are used rather than a `query` string built from the
    # comparator name, because such a string fails to parse when the name
    # is not a valid Python identifier.
    is_valid_mismatch = (
        (reference_nt != comparator_nt)
        & reference_is_valid
        & comparator_nt.isin(valid_nts)
    )
    mutation = reference_nt + site_str + comparator_nt
    comparator_muts[comparator] = set(mutation[is_valid_mismatch])
    print(f"{comparator} has {len(comparator_muts[comparator])} mutations")

Read substitutions and annotate which ones are in each comparator:

In [None]:
# Load the per-sequence substitutions; `na_filter=False` keeps empty fields
# as empty strings instead of converting them to NaN.
subs = pd.read_csv(subs_csv, na_filter=False)
subs['date'] = pd.to_datetime(subs['date'])

# Coverage is the fraction of aligned positions that are not gaps.
n_aligned_sites = (
    subs['n_ident_to_ref']
    + subs['n_ambiguous_to_ref']
    + subs['n_subs_to_ref']
    + subs['n_gapped_to_ref']
)
subs['frac_coverage'] = 1 - subs['n_gapped_to_ref'] / n_aligned_sites

subs = subs.sort_values('date').reset_index(drop=True)

# Split each comma-separated substitution string once and reuse the lists
# for every comparator.
subs_lists = subs['substitutions'].str.split(',')

for comparator, muts in comparator_muts.items():
    # Substitutions of each sequence that the comparator also carries.
    shared_subs = subs_lists.map(
        lambda sub_list: [sub for sub in sub_list if sub in muts]
    )
    subs[f"{comparator}_substitutions"] = shared_subs.map(','.join)
    subs[f"{comparator}_n_substitutions"] = shared_subs.map(len)

subs.head()

## Filter sequences
Filter sequences with insufficient coverage:

In [None]:
print(f"Filtering sequences with <{min_coverage} alignment coverage")

# Flag the low-coverage sequences first so that both the summary table and
# the chart can show them before they are removed.
subs = subs.assign(insufficient_coverage=subs['frac_coverage'] < min_coverage)

display(
    subs
    .groupby('insufficient_coverage')
    .aggregate(n_seqs=pd.NamedAgg(column='strain', aggfunc='count'))
)

# The chart is built from the still-unfiltered table.
chart_frac_coverage = (
    alt.Chart(subs)
    .mark_point(filled=True, opacity=0.2)
    .encode(
        x='frac_coverage',
        y='frac_called_in_region_of_interest',
        color='insufficient_coverage',
        tooltip=[
            'strain',
            'gisaid_epi_isl',
            'date',
            'n_subs_to_ref',
            'n_ident_to_ref',
            'substitutions',
        ],
    )
)

# Keep only sequences with sufficient coverage and discard the flag column.
subs = (
    subs
    .loc[~subs['insufficient_coverage']]
    .drop(columns='insufficient_coverage')
)

chart_frac_coverage

Filter sequences with excess substitutions:

In [None]:
print(f"Filtering sequences with >{max_subs} substitutions to reference")

# Flag only counts strictly above `max_subs`, matching the ">" in the message
# printed above: a sequence with exactly `max_subs` substitutions is retained.
subs['excess_subs'] = subs['n_subs_to_ref'] > max_subs

# Number of sequences on each side of the threshold.
display(subs
        .groupby('excess_subs')
        .aggregate(n_seqs=pd.NamedAgg('strain', 'count'))
        )

# Build the chart before filtering so flagged sequences are still plotted.
chart_subs = (
    alt.Chart(subs)
    .encode(x='date',
            y='n_subs_to_ref',
            color='excess_subs',
            tooltip=['strain',
                     'gisaid_epi_isl',
                     'date',
                     'n_subs_to_ref',
                     'n_ident_to_ref',
                     'substitutions'],
            )
    .mark_point(filled=True,
                opacity=0.2,
                )
    )

# Remove the flagged sequences and the temporary flag column.
subs = (subs
        .query('not excess_subs')
        .drop(columns='excess_subs')
        )

chart_subs

Filter by number of ambiguous nucleotides:

In [None]:
print(f"Filtering sequences with >{max_ambiguous} ambiguous nucleotides in alignment to reference")

# Flag only counts strictly above `max_ambiguous`, matching the ">" in the
# message printed above: a sequence with exactly `max_ambiguous` ambiguous
# nucleotides is retained.
subs['excess_ambiguous'] = subs['n_ambiguous_to_ref'] > max_ambiguous

# Number of sequences on each side of the threshold.
display(subs
        .groupby('excess_ambiguous')
        .aggregate(n_seqs=pd.NamedAgg('strain', 'count'))
        )


# Build the chart before filtering so flagged sequences are still plotted.
chart_ambiguous = (
    alt.Chart(subs)
    .encode(x='date',
            y='n_ambiguous_to_ref',
            color='excess_ambiguous',
            tooltip=['strain',
                     'gisaid_epi_isl',
                     'date',
                     'n_subs_to_ref',
                     'n_ident_to_ref',
                     'n_ambiguous_to_ref',
                     'substitutions'],
            )
    .mark_point(filled=True,
                opacity=0.2,
                )
    )

# Remove the flagged sequences and the temporary flag column.
subs = (subs
        .query('not excess_ambiguous')
        .drop(columns='excess_ambiguous')
        )

chart_ambiguous

Filter sequences collected too late (after the maximum allowed date):

In [None]:
print(f"Filtering sequences collected after {max_date}")

# Flag sequences dated strictly after the cutoff.
subs = subs.assign(too_late_date=subs['date'] > max_date)

display(
    subs
    .groupby('too_late_date')
    .aggregate(n_seqs=pd.NamedAgg(column='strain', aggfunc='count'))
)

# Keep sequences on or before the cutoff and discard the flag column.
subs = subs.loc[~subs['too_late_date']].drop(columns='too_late_date')

Filter sequences with excessive mutations in a short run:

In [None]:
# Unpack the run-filter settings: how many mutations, within what span.
run_n_muts, run_span = filter_runs['n_muts'], filter_runs['span']

# Both settings must exceed one for the sliding-window check below.
assert run_span > 1
assert run_n_muts > 1

print(f"Filtering sequences with >= {run_n_muts} mutations within {run_span} nucleotides")

def pass_run_filter(sub_str):
    """Return whether a sequence is free of dense runs of substitutions.

    `sub_str` is a comma-separated string of substitutions, each written as
    one leading character, an integer site, and one trailing character; empty
    entries are ignored. The result is `False` when any `run_n_muts`
    consecutive sites (in sorted order) lie within `run_span` of each other,
    and `True` otherwise.
    """
    positions = sorted(int(sub[1:-1]) for sub in sub_str.split(',') if sub)
    # Pair each site with the one `run_n_muts - 1` places later, i.e. the
    # first and last site of every window of `run_n_muts` consecutive sites.
    for window_start, window_end in zip(positions, positions[run_n_muts - 1:]):
        span = window_end - window_start
        assert span > 0
        if span <= run_span:
            return False
    return True

# Evaluate the run filter on each sequence's substitution string.
subs = subs.assign(pass_run_filter=subs['substitutions'].map(pass_run_filter))

display(
    subs
    .groupby('pass_run_filter')
    .aggregate(n_seqs=pd.NamedAgg(column='strain', aggfunc='count'))
)

# Keep the passing sequences and discard the helper column.
subs = subs.query('pass_run_filter').drop(columns='pass_run_filter')

## Check sequences against WHO-China joint report
Here we make sure that:
 1. We find a sequence for every sample listed in WHO-China report.
 2. When there are clearly multiple sequences from the same patient (based mostly on WHO-China report descriptions), we collapse them to just one sequence to keep.

In [None]:
print(f"Examining early sequences isolated on or before {who_china_report_last_date}\n")

with open(who_china_report_cases_yaml) as f:
    who_china_report_cases = YAML(typ='safe').load(f)

# Only sequences dated on or before the report's last date are compared
# against the report; strain names must be unique among them.
subs_early = subs.query('date <= @who_china_report_last_date')
assert subs_early['strain'].nunique() == len(subs_early)

all_to_keep = []
all_to_collapse = []

for sample, sample_d in who_china_report_cases['pre_2020_seqs'].items():
    # Not used below; parsing here still fails fast on a missing or
    # malformed date entry.
    date = pd.to_datetime(sample_d['date'])

    # Flag early sequences whose strain name contains any of the names listed
    # for this sample. The mask is built from `subs_early` itself so that it
    # shares the index of the frame it selects from, rather than relying on
    # pandas realigning a mask built from the larger `subs` table.
    sample_mask = subs_early['strain'].str.contains(sample_d['strain'][0], regex=False)
    for strain in sample_d['strain'][1:]:
        sample_mask = sample_mask | subs_early['strain'].str.contains(strain, regex=False)
    matching_strains = subs_early.loc[sample_mask, 'strain'].tolist()

    strains_to_keep = []
    strains_to_collapse = []
    if 'collapse_to' in sample_d:
        # Several sequences for one sample: keep those named in `collapse_to`
        # and mark the others for removal.
        for strain in matching_strains:
            if any(keep in strain for keep in sample_d['collapse_to']):
                strains_to_keep.append(strain)
            else:
                strains_to_collapse.append(strain)
    else:
        strains_to_keep = matching_strains

    # Every sample in the report must be represented by at least one sequence.
    if not strains_to_keep:
        raise ValueError(f"no strains matching {sample}:\n{sample_d}")

    print(f"For {sample}\n  Retaining:\n    " + '\n    '.join(strains_to_keep))
    if strains_to_collapse:
        print("  Collapsing:\n    " + '\n    '.join(strains_to_collapse))
    all_to_keep += strains_to_keep
    all_to_collapse += strains_to_collapse

# No strain may be assigned to more than one sample.
assert len(all_to_keep) == len(set(all_to_keep)), collections.Counter(all_to_keep)
assert len(all_to_collapse) == len(set(all_to_collapse))

print(f"\nKeeping the following {len(all_to_keep)} early sequences in WHO-China report:")
display(subs_early.query('strain in @all_to_keep').reset_index(drop=True))

print(f"\nRemoving the following {len(all_to_collapse)} early sequences in WHO-China report "
      'as duplicates of other samples:')
display(subs_early.query('strain in @all_to_collapse').reset_index(drop=True))

not_in_report = subs_early.query('strain not in @all_to_collapse').query('strain not in @all_to_keep')
print(f"\nAlso keeping the following {len(not_in_report)} early sequences not in the report:")
display(not_in_report.reset_index(drop=True))

# Only the collapsed duplicates are removed from the full table.
subs = subs.query('strain not in @all_to_collapse')

## Remove sequences manually specified for exclusion
Sequences identified as problematic in some way by manual inspection:

In [None]:
with open(early_seqs_to_exclude_yaml) as f:
    early_seqs_to_exclude = YAML(typ='safe').load(f)

gisaid_to_exclude = list(early_seqs_to_exclude['gisaid'])
print(f"Removing {len(gisaid_to_exclude)} GISAID IDs specified for manual removal")

# The exclusion list must not repeat an ID ...
assert not any(
    n_listed > 1 for n_listed in collections.Counter(gisaid_to_exclude).values()
)

# ... and every listed ID must still be present in the table.
assert not set(gisaid_to_exclude) - set(subs['gisaid_epi_isl'])

subs = subs.loc[~subs['gisaid_epi_isl'].isin(gisaid_to_exclude)]

## Check for duplicated strains

In [None]:
# Attach to every row the number of entries sharing its strain name, then
# keep only the rows whose strain name occurs more than once.
dup_strains = subs.assign(
    n=subs.groupby('strain')['gisaid_epi_isl'].transform('count')
)
dup_strains = dup_strains.loc[dup_strains['n'] > 1]

if len(dup_strains) > 0:
    raise ValueError('Duplicated strains:\n' + str(dup_strains))

## Write annotated and filtered substitutions

In [None]:
print(f"Writing filtered and annotated substitutions to {output_csv}")

# Write the filtered table without the row index.
subs.to_csv(output_csv, index=False)