# IBD strain comparisons

In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs so edits to helper modules are picked up.
%load_ext autoreload
%autoreload 2

In [15]:
import numpy
import pandas
import matplotlib.pyplot as plt

In [10]:
# StrainGST calls, one row per (sample, called strain); the species prefix is
# stripped literally (regex=False: the pattern is plain text, not a regex).
straingst_data = pandas.read_csv('ibd_timeseries/straingst/straingst.txt', sep='\t').set_index('sample')
straingst_data['strain'] = straingst_data['strain'].str.replace("Esch_coli_", "", regex=False)

# StrainGST reference -> reference it is collapsed onto. Strains not listed
# here are their own collapsed strain.
collapsed = {
    "NCTC9087": "NCTC122",
    "2014C-3338": "MSHS_133",
    "JJ1886": "MVAST0167",
    "ST540_GCF_000597845.1": "AR_0061"
}

# Label samples TP1, TP2, ... in order of first appearance in straingst.txt
# (NOTE(review): assumes the file lists samples chronologically — confirm).
timepoint = {
    sample: f"TP{i}"
    for i, sample in enumerate(straingst_data.index.unique(), start=1)
}

timepoint_to_sample = {tp: sample for sample, tp in timepoint.items()}

straingst_data['timepoint'] = straingst_data.index.map(timepoint)

# Look up the collapsed name, falling back to the strain itself.
straingst_data['collapsed_strain'] = straingst_data['strain'].map(lambda strain: collapsed.get(strain, strain))
# True when the StrainGST call was re-labelled onto a different reference.
straingst_data['collapsed'] = straingst_data['collapsed_strain'] != straingst_data['strain']
straingst_data = straingst_data.reset_index().set_index(['sample', 'collapsed_strain'])

# Six timepoints where MIDAS didn't detect anything: flag every call in them.
no_strain_midas = ["LS_1_26_2013", "LS_2_8_2013", "LS_2_25_2013", "LS_3_24_2013", "LS_4_7_2013", "LS_4_29_2013"]
straingst_data['midas_undetected'] = False
straingst_data.loc[no_strain_midas, 'midas_undetected'] = True


def _flag_midas_undetected(samples, strain):
    """Set ``midas_undetected`` on the StrainGST calls (sample, strain) for each sample.

    ``strain`` is the collapsed strain name (second level of the index).
    Raises KeyError if any (sample, strain) pair is not a StrainGST call,
    rather than letting ``.loc`` assignment silently add a new row.
    """
    keys = [(sample, strain) for sample in samples]
    missing = [key for key in keys if key not in straingst_data.index]
    if missing:
        raise KeyError(f"No StrainGST call for {missing}")
    straingst_data.loc[keys, 'midas_undetected'] = True


# Another sample with a strain close to 118UI that MIDAS did not detect
_flag_midas_undetected(["LS_12_7_2014"], "118UI")

# Samples with secondary strain close to JJ1886 (indexed under its collapsed name).
# Indexing timepoint_to_sample fails fast on an unknown timepoint label.
jj1886_samples = [timepoint_to_sample[tp] for tp in ["TP11", "TP13", "TP16"]]
_flag_midas_undetected(jj1886_samples, collapsed["JJ1886"])

# Secondary strains close to AR_0061
ar_0061_samples = [timepoint_to_sample[tp] for tp in ["TP23", "TP24"]]
_flag_midas_undetected(ar_0061_samples, "AR_0061")

# StrainGR summaries, indexed by (sample, ref). Keep only rows whose
# reference length exceeds 3e6 and whose coverage exceeds 0.5.
straingr_data = pandas.read_csv('ibd_timeseries/straingr/summary.txt', sep='\t', index_col=0)
is_long_ref = straingr_data['length'] > 3e6
has_min_cov = straingr_data['coverage'] > 0.5
samples_cov_gt05 = (
    straingr_data.loc[is_long_ref & has_min_cov]
    .assign(ref=lambda df: df['ref'].str.replace("Esch_coli_", ""))
    .reset_index()
    .set_index(['sample', 'ref'])
)

# Pairwise StrainGR comparisons, indexed by (sample1, sample2, ref).
compare_data = (
    pandas.read_csv('ibd_timeseries/straingr/compare.summary.chrom.txt', sep='\t')
    .assign(ref=lambda df: df['ref'].str.replace("Esch_coli_", ""))
    .set_index(['sample1', 'sample2', 'ref'])
)

def is_straingst_present(ix):
    """Return True if StrainGST has a call for ``ref`` in both samples.

    ``ix`` is a (sample1, sample2, ref) key of ``compare_data``; ``ref`` is
    matched against the collapsed strain level of ``straingst_data``'s index.
    """
    sample1, sample2, ref = ix
    calls = straingst_data.index
    return all((sample, ref) in calls for sample in (sample1, sample2))


compare_data['straingst_present'] = compare_data.index.map(is_straingst_present)

def straingst_differs(ix):
    """Return True if exactly one of the two samples' StrainGST calls was collapsed.

    ``ix`` is a (sample1, sample2, ref) key of ``compare_data``. Returns False
    when StrainGST did not call ``ref`` in both samples.
    NOTE(review): assumes each (sample, ref) appears once in ``straingst_data``;
    a duplicate would make ``.loc`` return a Series and ``bool()`` raise.
    """
    sample1, sample2, ref = ix

    if not is_straingst_present(ix):
        return False

    # One sample called `ref` directly, the other via a collapsed reference.
    collapsed1 = bool(straingst_data.loc[(sample1, ref), 'collapsed'])
    collapsed2 = bool(straingst_data.loc[(sample2, ref), 'collapsed'])
    return collapsed1 != collapsed2


compare_data['straingst_differs'] = compare_data.index.map(straingst_differs)

def midas_undetected(ix):
    """Return True if either sample's StrainGST call for ``ref`` is flagged MIDAS-undetected.

    ``ix`` is a (sample1, sample2, ref) key of ``compare_data``. Returns False
    when StrainGST did not call ``ref`` in both samples.
    NOTE(review): assumes each (sample, ref) appears once in ``straingst_data``.
    """
    sample1, sample2, ref = ix

    if not is_straingst_present(ix):
        return False

    return (bool(straingst_data.loc[(sample1, ref), 'midas_undetected'])
            or bool(straingst_data.loc[(sample2, ref), 'midas_undetected']))


compare_data['midas_undetected'] = compare_data.index.map(midas_undetected)

def enough_cov(ix):
    """Return True if both samples passed the StrainGR length/coverage filter for ``ref``.

    ``ix`` is a (sample1, sample2, ref) key of ``compare_data``; membership is
    checked against the (sample, ref) index of ``samples_cov_gt05``.
    """
    sample1, sample2, ref = ix
    return (sample1, ref) in samples_cov_gt05.index and (sample2, ref) in samples_cov_gt05.index


compare_data['enough_cov'] = compare_data.index.map(enough_cov)

def orig_ref(ix):
    """Return the reference label to report for a sample pair.

    ``ix`` is a (sample1, sample2, ref) key of ``compare_data``.
    - NaN if StrainGST did not call ``ref`` in both samples;
    - ``ref`` (the collapsed name) if only one of the two calls was collapsed;
    - otherwise the original StrainGST strain called in sample1.
    """
    sample1, _, ref = ix
    if is_straingst_present(ix):
        if straingst_differs(ix):
            return ref
        return straingst_data.loc[(sample1, ref), 'strain']
    return numpy.nan


compare_data['orig_ref'] = compare_data.index.map(orig_ref)

# Timepoint labels for the two samples of each comparison.
compare_data['tp1'] = [timepoint[first] for first, _second, _ref in compare_data.index]
compare_data['tp2'] = [timepoint[second] for _first, second, _ref in compare_data.index]

# Keep pairs where StrainGST called the reference in both samples, then those
# with commonPct > 0.5 and enough StrainGR coverage in both samples.
only_straingst_present = compare_data.loc[compare_data['straingst_present']]
is_callable_pair = (only_straingst_present['commonPct'] > 0.5) & only_straingst_present['enough_cov']
compare_min_05_callable = only_straingst_present.loc[is_callable_pair]

In [12]:
import altair as alt

# Colour per reference name. `orig_ref` can hold either an original StrainGST
# reference or the collapsed name, so both names of each collapsed pair appear.
strain_colors = {
    "2014C-3338": "#519b84",
    "MSHS_133": "#66C2A5",
    "PA45B": "#99821c",
    "MVAST0167": "#ccad25",
    "JJ1886": "#FFD92F",
    "118UI": "#6A8A36",
    "EC-1639": "#A6D854",
    "Santai": "#B75DE2",
    "2011C-3911": "#CA714E",
    "D5": "#FC8D62",
    "NCTC9087": "#485168",
    "NCTC122": "#5A6682",
    "ST540_GCF_000597845.1": "#7a849b",
    "AR_0061": "#9ca3b4",
    "ST1": "#99821c",
    "ST2": "#B75DE2",
    "ST3": "#6A8A36",
    "ST4": "#FC8D62",
    "ST5": "#ccad25",
    "ST6": "#7180A2",
    "ST7": "#CA714E",
    "n.d.": "#bbbbbb"
}

# Pin the colour scale to the references actually plotted, in sorted order.
color_domain = sorted(compare_min_05_callable['orig_ref'].unique())
color_range = [strain_colors.get(name) for name in color_domain]


def _similarity_scatter(data, *, filled=False, x_domain=alt.Undefined,
                        y_domain=alt.Undefined, legends=True):
    """Scatter of gap similarity (x) vs pairwise ACNI (y), one point per sample pair.

    Points are coloured by ``orig_ref``, sized by ``commonPct`` and shaped by
    ``straingst_differs``. ``filled=True`` draws filled markers. ``x_domain`` /
    ``y_domain`` pin the axis ranges (default: data-driven, zero excluded).
    With ``legends=False`` the colour legend is hidden and the shape legend
    keeps Altair's default labels.
    """
    mark_kwargs = {'filled': True} if filled else {}
    color_legend = alt.Undefined if legends else None
    shape_legend = (
        alt.Legend(title="StrainGST calls",
                   labelExpr="datum.value ? 'Different (collapsed)' : 'Identical'")
        if legends else alt.Undefined
    )
    return alt.Chart(data).mark_point(**mark_kwargs).encode(
        alt.X('gapJaccardSim', scale=alt.Scale(zero=False, domain=x_domain), title="Gap Similarity"),
        alt.Y('singleAgreePct', scale=alt.Scale(zero=False, domain=y_domain), title="Pairwise ACNI [%]"),
        color=alt.Color('orig_ref', type='nominal', title="Reference",
                        scale=alt.Scale(domain=color_domain, range=color_range),
                        legend=color_legend),
        size=alt.Size('commonPct', type="quantitative", title="Common Callable [%]",
                      scale=alt.Scale(range=(25, 500))),
        shape=alt.Shape('straingst_differs', type="nominal",
                        scale=alt.Scale(domain=[False, True], range=['circle', 'diamond']),
                        legend=shape_legend),
        tooltip=['tp1', 'tp2', 'ref', 'orig_ref', 'commonPct', 'singleAgreePct',
                 'gapJaccardSim', 'sharedAllelesPct'],
    )


# Overview: all callable pairs (open markers), with pairs involving a
# MIDAS-undetected call overdrawn as filled markers.
compare_midas_undetected = compare_min_05_callable[compare_min_05_callable['midas_undetected']].reset_index()
c1_open = _similarity_scatter(compare_min_05_callable.reset_index())
c1_filled = _similarity_scatter(compare_midas_undetected, filled=True)
c1 = (c1_open + c1_filled).properties(width=300, height=300)

# Zoomed view of the most similar pairs; thresholds double as axis minima.
ZOOM_MIN_GAP_SIM = 0.97
ZOOM_MIN_ACNI = 99.9
zoom_mask = ((compare_min_05_callable['gapJaccardSim'] > ZOOM_MIN_GAP_SIM)
             & (compare_min_05_callable['singleAgreePct'] > ZOOM_MIN_ACNI))
zoomed_in = compare_min_05_callable[zoom_mask].reset_index()
zoom_options = {
    'x_domain': [ZOOM_MIN_GAP_SIM, 1.0],
    'y_domain': [ZOOM_MIN_ACNI, 100],
    'legends': False,
}

c2_open = _similarity_scatter(zoomed_in, **zoom_options).properties(width=300, height=125)
zoomed_filled = zoomed_in[zoomed_in['midas_undetected']]
c2_filled = _similarity_scatter(zoomed_filled, filled=True, **zoom_options).properties(width=300, height=125)
c2 = c2_open + c2_filled


c1 | c2

In [5]:
import json
import skbio
from collections import defaultdict
from IPython.display import display, HTML
from altair import expr, datum

# Reference metadata; provides the contig -> strain mapping used below.
with open('ibd_timeseries/ref.meta.json') as meta_file:
    ref_meta = json.load(meta_file)

# Map each strain (species prefix stripped) to the ID of its contig that is at
# least 3e6 long (presumably the chromosome); shorter contigs are ignored.
chromosomes = {}
with open("ibd_timeseries/ref.fa") as fasta_file:
    for record in skbio.io.read(fasta_file, "fasta"):
        if len(record) >= 3e6:
            contig_id = record.metadata['id']
            strain_name = ref_meta['contig_to_strain'][contig_id].replace("Esch_coli_", "")
            chromosomes[strain_name] = contig_id

def _add_timepoint_labels(df):
    """Add ``tp`` and ``label`` columns to a sample-indexed frame, in place.

    ``tp`` is the integer timepoint number; ``label`` is the "TPn" name with a
    trailing "*" on rows whose boolean ``collapsed`` column is set. Returns ``df``.
    """
    df['tp'] = df.index.map(timepoint).str[2:].astype(int)
    df['label'] = df.index.map(timepoint)
    df.loc[df['collapsed'], 'label'] += "*"
    return df


# For each collapsed strain: plot StrainGR gaps along its chromosome and the
# StrainGR coverage per timepoint, for samples that pass the coverage filter.
for collapsed_ref in straingst_data.index.unique(level=1):
    # Samples where StrainGST called this (collapsed) strain, in timepoint order.
    samples = sorted(straingst_data.loc[(slice(None), collapsed_ref), :].index.unique(level=0),
                     key=lambda s: int(timepoint[s][2:]))
    tps = list(map(timepoint.get, samples))

    incl_samples = []
    sample_gap_dfs = []
    straingr_summaries = []
    for sample in samples:
        if (sample, collapsed_ref) not in samples_cov_gt05.index:
            print(sample, timepoint[sample], collapsed_ref, "not enough cov")
            continue

        incl_samples.append(sample)
        # Whether this sample's StrainGST call was collapsed onto collapsed_ref.
        was_collapsed = bool(straingst_data.loc[(sample, collapsed_ref), 'collapsed'])

        gaps = pandas.read_csv(f'ibd_timeseries/straingr/{sample}.gaps.bed', sep='\t',
                               names=['chromosome', 'start', 'end'], index_col=[0, 1])
        gaps['collapsed'] = was_collapsed
        sample_gap_dfs.append(gaps)

        summary = pandas.read_csv(f'ibd_timeseries/straingr/{sample}.summary.tsv', sep='\t')
        summary['ref'] = summary['ref'].str.replace("Esch_coli_", "")
        summary['collapsed'] = was_collapsed
        straingr_summaries.append(summary.set_index(['ref', 'name']))

    if not sample_gap_dfs:
        continue

    display(HTML(f"<h3>Strain: {collapsed_ref}</h3>"))
    display(HTML("<p>{}</p>".format(", ".join(tps))))

    chrom = chromosomes[collapsed_ref]

    # Gap intervals on this strain's chromosome, one row per gap, indexed by sample.
    gap_df = pandas.concat(sample_gap_dfs, keys=incl_samples).reset_index(level=[1, 2])
    # .copy(): the boolean filter yields a slice, and the column assignments
    # below would otherwise trigger SettingWithCopyWarning.
    gap_df = gap_df[gap_df['chromosome'] == chrom].copy()
    gap_df['size'] = gap_df['end'] - gap_df['start']
    _add_timepoint_labels(gap_df)

    # StrainGR summary rows for this strain's chromosome; rows whose `ref`
    # level is 'TOTAL' are dropped first.
    straingr_df = pandas.concat(straingr_summaries, keys=incl_samples).drop(index='TOTAL', level=1)
    straingr_df = straingr_df.loc[(slice(None), collapsed_ref, chrom), :].copy().reset_index(level=[1, 2])
    _add_timepoint_labels(straingr_df)

    gap_chart = alt.Chart(gap_df.reset_index()).mark_rect().encode(
        x=alt.X('start:Q', axis=alt.Axis(title="Chromosome position", format="~s", grid=False)),
        x2='end:Q',
        y=alt.Y('tp:O', axis=alt.Axis(title="Timepoint", bandPosition=0, grid=True)),
        tooltip=['index', 'tp', 'size']
    ).properties(width=600)

    cov_bar = alt.Chart(straingr_df.reset_index()).mark_bar().encode(
        x=alt.X("coverage:Q", axis=alt.Axis(title="Coverage", labelExpr="datum.value + 'x'")),
        y=alt.Y("tp:O", axis=alt.Axis(title=None, bandPosition=0, grid=True, labels=False, labelExpr="datum.label")),
    )

    # Numeric coverage value printed to the right of each bar.
    cov_text = cov_bar.mark_text(
        align='left',
        baseline='middle',
        dx=3,
    ).encode(
        text='coverage'
    )

    cov_chart = (cov_bar + cov_text).properties(width=100)

    # Shared timepoint axis so gap rows and coverage bars line up.
    concat_chart = alt.hconcat(gap_chart, cov_chart).resolve_scale(
        y='shared'
    ).configure_mark(
        color=strain_colors[collapsed_ref]
    )

    display(concat_chart)
    

LS_3_24_2013 TP8 EC-1639 not enough cov
LS_4_29_2013 TP10 EC-1639 not enough cov


LS_1_26_2013 TP5 118UI not enough cov
LS_3_24_2013 TP8 118UI not enough cov
LS_4_29_2013 TP10 118UI not enough cov


LS_1_26_2013 TP5 NCTC122 not enough cov


## Strain 118UI comparison

In [60]:
# Samples in which StrainGST has a call indexed under 118UI.
samples_with_118ui = straingst_data.loc[(slice(None), '118UI'), :].index.unique(level=0)
# Callable StrainGR comparisons between those samples on the 118UI reference,
# indexed by (tp1, tp2).
strain_118ui = compare_min_05_callable.loc[(samples_with_118ui, samples_with_118ui, '118UI'), :].copy().reset_index().set_index(['tp1', 'tp2'])

# Timepoints occurring in a comparison with commonPct > 0.5, in numeric order.
passing = strain_118ui['commonPct'] > 0.5
samples_pass = set(strain_118ui.loc[passing, 'sample1']) | set(strain_118ui.loc[passing, 'sample2'])
tps = sorted(map(timepoint.get, samples_pass), key=lambda e: int(e[2:]))

# One shared ordinal scale for both axes. Its explicit domain fixes the axis
# order, so no `sort` is needed. The previous EncodingSortField(field='tp1:O')
# named a nonexistent field "tp1:O", because shorthand is not parsed there.
scale = alt.Scale(domain=tps)


def _timepoint_encoding(channel, field, title):
    """Ordinal timepoint channel on the shared ``scale``, labels without the "TP" prefix."""
    return channel(f'{field}:O', scale=scale,
                   axis=alt.Axis(title=title, labelExpr='replace(datum.label, "TP", "")'))


base = alt.Chart(strain_118ui.reset_index()).encode(
    x=_timepoint_encoding(alt.X, 'tp1', "Timepoint 1"),
    y=_timepoint_encoding(alt.Y, 'tp2', "Timepoint 2"),
    color=alt.Color('singleAgreePct', scale=alt.Scale(scheme="Reds"),
                    legend=alt.Legend(title="Pairwise ACNI [%]")),
    tooltip=['singleAgreePct']
)

heatmap_acni = base.mark_rect()
heatmap_acni_sized = base.mark_square(opacity=1.0).encode(size=alt.Size('commonPct:Q', title="Common Callable [%]"))

display(heatmap_acni | heatmap_acni_sized)

# Same comparisons with tp1/tp2 swapped (presumably to fill the opposite
# triangle of the ACNI heatmap).
strain_118ui_upper = (
    strain_118ui.reset_index()
    .rename(columns={'tp1': 'tp2', 'tp2': 'tp1'})
    .sort_values('tp1')
)

heatmap_gaps = alt.Chart(strain_118ui_upper).mark_square(opacity=1.0).encode(
    x=_timepoint_encoding(alt.X, 'tp1', "Timepoint 1"),
    y=_timepoint_encoding(alt.Y, 'tp2', "Timepoint 2"),
    color=alt.Color('gapJaccardSim', scale=alt.Scale(domain=(0.91, 1.0), scheme="Blues"),
                    legend=alt.Legend(title="Gap Similarity")),
    size='commonPct:Q',
)

# NOTE(review): heatmap_gaps is built but never displayed; its only use is the
# commented-out layering below.
# chart = heatmap_acni + heatmap_gaps
# chart.resolve_scale(color='independent')