In [None]:
import json

import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
from matplotlib.patches import Rectangle

# Load configuration

In [None]:
# TODO: retrieve these bits of information dynamically
# One entry per amplicon:
#   (comma-separated mutation labels,
#    position in bp of each of those mutations, in the same order,
#    (start, end) window in bp used as the x-range of the amplicon's panels)
# Keeping the three facts on one line stops hand-edited parallel dicts from
# drifting apart; the three lookup dicts below are derived from this table.
AMPLICON_TABLE = {
    '72_UK': ('21765-21770Δ,21991-21993Δ', [21765, 21991], (21760, 22000)),
    '78_UK': ('23604A,23709T', [23604, 23709], (23600, 23720)),
    '92_UK': ('27972T,28048T,28111G', [27972, 28048, 28111], (27970, 28120)),
    '93_UK': ('28111G,28280-28280->CTA', [28111, 28280], (28110, 28290)),
    '76_SA': ('23012A,23063T', [23012, 23063], (23010, 23070)),
    '77_EU': ('23403G', [23403], (23400, 23410)),
}

amplicon_mutations = {
    name: labels.split(',') for name, (labels, _, _) in AMPLICON_TABLE.items()
}
amplicon_mutation_positions = {
    name: list(positions) for name, (_, positions, _) in AMPLICON_TABLE.items()
}
amplicon_ranges = {
    name: window for name, (_, _, window) in AMPLICON_TABLE.items()
}

# Fail fast if the hand-written configuration is inconsistent.
for amp_name, amp_positions in amplicon_mutation_positions.items():
    amp_start, amp_end = amplicon_ranges[amp_name]
    assert len(amp_positions) == len(amplicon_mutations[amp_name]), (
        f'{amp_name}: number of mutations and positions differ'
    )
    assert all(amp_start <= p <= amp_end for p in amp_positions), (
        f'{amp_name}: a mutation position lies outside {amp_start}-{amp_end}'
    )

# Load data

In [None]:
# Adapt DATA_PATH to point at your own results file.
DATA_PATH = 'data.json'

# JSON text is UTF-8 by specification; pass the encoding explicitly so the
# platform's locale encoding (e.g. cp1252 on Windows) is never used to decode it.
with open(DATA_PATH, encoding='utf-8') as fd:
    data = json.load(fd)

In [None]:
# Convert the per-sample/per-amplicon JSON counts into a flat table with one row
# per item to plot: 'sample', 'mutation', 'position', 'frequency', 'coverage'.
tmp_mutations = []
for sample, sample_data in data.items():
    for amplicon, amplicon_data in sample_data.items():
        sites = amplicon_data['sites']
        muts = amplicon_data['muts']

        # The largest key of `sites` is taken to correspond to all mutations of
        # the amplicon combined.
        # NOTE(review): if `sites` is a JSON object its keys are strings, so
        # `max` compares them lexicographically (e.g. '9' > '10') — confirm the
        # key format makes this safe.
        max_entry = max(sites)
        if max_entry not in muts:
            print(f'Skipping incomplete field (sample={sample!r}, amplicon={amplicon!r})')
            continue

        positions = amplicon_mutation_positions[amplicon]
        coverage = sites[max_entry]
        tmp_mutations.append(
            {
                'sample': sample,
                'mutation': ','.join(amplicon_mutations[amplicon]),
                # Several mutations -> comma-joined string (plotted as one
                # spanning box); a single mutation keeps its int position.
                'position': (
                    ','.join(str(p) for p in positions)
                    if len(positions) > 1
                    else positions[0]
                ),
                # Zero coverage would divide by zero; report frequency 0 instead,
                # like the zero-coverage placeholder rows added below.
                'frequency': muts[max_entry] / coverage if coverage else 0,
                'coverage': coverage,
            }
        )

# Add a zero-frequency, zero-coverage row for every individual mutation of every
# configured amplicon in every sample. Single int positions are drawn by the
# plotting cell as marker lines, presumably so each mutation site is shown even
# without per-mutation data.
for amplicon in amplicon_ranges:
    mutation_list = amplicon_mutations[amplicon]
    position_list = amplicon_mutation_positions[amplicon]
    assert len(mutation_list) == len(position_list), (
        f'{amplicon}: number of mutations and positions differ'
    )

    for mut, pos in zip(mutation_list, position_list):
        for sample in data:
            tmp_mutations.append(
                {
                    'sample': sample,
                    'mutation': mut,
                    'position': pos,
                    'frequency': 0,
                    'coverage': 0,
                }
            )

df_mutations = pd.DataFrame(tmp_mutations)
df_mutations.head()

In [None]:
# Tabulate the configured amplicons (one row per amplicon: its name and the
# start/end of its plotted window in bp), keeping the configuration order.
tmp_amplicons = [
    {'name': amp_name, 'start': amp_start, 'end': amp_end}
    for amp_name, (amp_start, amp_end) in amplicon_ranges.items()
]

df_amplicons = pd.DataFrame(tmp_amplicons)
df_amplicons.head()

In [None]:
# Dummy mutation data: uncomment to try the plot without data.json (overrides df_mutations).
# df_mutations = pd.DataFrame({
#     'sample': ['sample_A', 'sample_A', 'sample_A', 'sample_B'],
#     'mutation': ['T12G,C16A', 'T12G', 'C16A', 'T61G'],
#     'position': ['12,16', 12, 16, 61],
#     'frequency': [.1, .7, .25, .9],
#     'coverage': [200, 234, 250, 30]
# })
# df_mutations.head()

In [None]:
# Dummy amplicon data: uncomment together with the dummy mutations above (overrides df_amplicons).
# df_amplicons = pd.DataFrame({
#     'name': ['first_amp', 'second_amp'],
#     'start': [10, 50],
#     'end': [20, 80]
# })
# df_amplicons.head()

# Visualization

## Helper functions

In [None]:
def plot_filled_rectangle(xy, width, height, color, fill_fraction, ax):
    """Draw an outlined box whose left part is filled proportionally to `fill_fraction`.

    Parameters
    ----------
    xy : (float, float)
        Lower-left corner of the box, in `ax` data coordinates.
    width, height : float
        Size of the box in data coordinates.
    color : matplotlib colour
        Colour used for both the outline and the fill.
    fill_fraction : float
        Fraction of `width`, within [0, 1], filled starting from the left edge.
    ax : matplotlib.axes.Axes
        Axes the two patches are added to.

    Raises
    ------
    ValueError
        If `fill_fraction` is outside [0, 1] or is NaN.
    """
    # Validate before touching the axes so an invalid value does not leave an
    # outline-only box behind; unlike `assert`, this also survives `python -O`.
    if not 0 <= fill_fraction <= 1:
        raise ValueError(f'fill_fraction must be within [0, 1], got {fill_fraction!r}')

    # Outline with a high zorder so the fill can never hide it.
    ax.add_patch(
        Rectangle(xy, width, height, edgecolor=color, linewidth=3, fill=False, zorder=100)
    )
    # Filled part, drawn beneath the outline.
    ax.add_patch(
        Rectangle(
            xy, width * fill_fraction, height, facecolor=color, edgecolor='none', zorder=10
        )
    )

In [None]:
def plot_mutation(
    mutation,
    position,
    frequency,
    coverage,
    cov_map,
    ax,
    box_width=2,
    box_bottom=0.5,
    box_height=0.2,
):
    """Mark a single mutation with a vertical line topped by a frequency box.

    A black line runs from the bottom of `ax` up to `box_bottom` (axes fraction)
    at `position`. On top of it sits a box `box_width` bp wide, centred on
    `position`, filled to `frequency` and coloured with `cov_map(coverage)`.

    Parameters
    ----------
    mutation : str
        Mutation label; not used for drawing.
    position : int
        Position in bp (x data coordinate).
    frequency : float
        Fill fraction of the box, within [0, 1].
    coverage : number
        Value passed to `cov_map` to choose the box colour.
    cov_map : callable
        Maps a coverage value to a matplotlib colour.
    ax : matplotlib.axes.Axes
        Target axes.
    box_width : float
        Box width in x data units (bp).
    box_bottom, box_height : float
        Vertical placement and size of the box, as fractions of the axes height.
    """
    ax.axvline(position, 0, box_bottom, color='black')

    # Maps (x in data, y in axes fraction) to pure data coordinates, which is
    # what plot_filled_rectangle expects.
    axes_y_to_data = (
        transforms.blended_transform_factory(ax.transData, ax.transAxes)
        + ax.transData.inverted()
    )
    # Transform both corners and take differences: a height is the data distance
    # between two axes fractions, not the data y-coordinate of a single fraction
    # (the two only coincide while the y-axis starts at 0).
    (x0, y0), (x1, y1) = axes_y_to_data.transform(
        [
            (position - box_width / 2, box_bottom),
            (position + box_width / 2, box_bottom + box_height),
        ]
    )
    plot_filled_rectangle(
        (x0, y0),
        x1 - x0,
        y1 - y0,
        color=cov_map(coverage),
        fill_fraction=frequency,
        ax=ax,
    )

In [None]:
def plot_mutation_list(
    mutation,
    position,
    frequency,
    coverage,
    cov_map,
    ax,
    padding=1,
    box_bottom=0.8,
    box_height=0.2,
):
    """Draw one box spanning a group of mutations, filled to their joint frequency.

    The box runs horizontally from `min(positions) - padding` to
    `max(positions) + padding` (bp). Vertically it runs from `box_bottom` to
    `box_bottom + box_height` (axes fractions). It is coloured with
    `cov_map(coverage)`.

    Parameters
    ----------
    mutation : str
        Comma-separated mutation labels; not used for drawing.
    position : str or int
        Comma-separated positions in bp (e.g. '27972,28048,28111'); a single
        int position is also accepted.
    frequency : float
        Fill fraction of the box, within [0, 1].
    coverage : number
        Value passed to `cov_map` to choose the box colour.
    cov_map : callable
        Maps a coverage value to a matplotlib colour.
    ax : matplotlib.axes.Axes
        Target axes.
    padding : float
        Extra width in bp added on each side of the mutated span.
    box_bottom, box_height : float
        Vertical placement and size of the box, as fractions of the axes height.
    """
    pos_list = [int(s) for s in str(position).split(',')]
    start, end = min(pos_list), max(pos_list)

    # Maps (x in data, y in axes fraction) to pure data coordinates.
    axes_y_to_data = (
        transforms.blended_transform_factory(ax.transData, ax.transAxes)
        + ax.transData.inverted()
    )
    # Transform both corners so the height is a true data distance rather than
    # the data y-coordinate of a single axes fraction.
    (x0, y0), (x1, y1) = axes_y_to_data.transform(
        [
            (start - padding, box_bottom),
            (end + padding, box_bottom + box_height),
        ]
    )
    plot_filled_rectangle(
        (x0, y0),
        x1 - x0,
        y1 - y0,
        color=cov_map(coverage),
        fill_fraction=frequency,
        ax=ax,
    )

## Main plot

In [None]:
sample_num = df_mutations['sample'].nunique()
amplicon_num = df_amplicons.shape[0]

# squeeze=False always returns a 2-D (samples x amplicons) array of Axes, so the
# single-sample, single-amplicon and 1x1 layouts need no special-casing.
fig, ax_grid = plt.subplots(
    nrows=sample_num,
    ncols=amplicon_num,
    figsize=(4 * amplicon_num, 3 * sample_num),
    squeeze=False,
)

cmap = mpl.cm.viridis
norm = mpl.colors.Normalize(vmin=0, vmax=df_mutations['coverage'].max())


def cov_map(coverage):
    """Map a coverage value to an RGBA colour on the shared coverage colour scale."""
    return cmap(norm(coverage))


# Rows are samples (in `groupby` order), columns are amplicons.
for i, ((sample, sample_group), ax_list) in enumerate(
    zip(df_mutations.groupby('sample'), ax_grid)
):
    for j, (amplicon, ax) in enumerate(zip(df_amplicons.itertuples(), ax_list)):
        # Pin both limits before drawing: the plot helpers convert axes
        # fractions into data coordinates using the current limits.
        ax.set_xlim(amplicon.start, amplicon.end)
        ax.set_ylim(0, 1)

        for row in sample_group.itertuples():
            if isinstance(row.position, (int, np.integer)):
                # single mutation; skip it unless it lies in the current amplicon
                if not amplicon.start <= row.position <= amplicon.end:
                    continue
                plot_mutation(
                    row.mutation, row.position, row.frequency, row.coverage, cov_map, ax
                )
            else:
                pos_list = [int(e) for e in row.position.split(',')]
                if min(pos_list) < amplicon.start or max(pos_list) > amplicon.end:
                    # mutations are not completely contained in current amplicon, skipping
                    continue
                plot_mutation_list(
                    row.mutation, row.position, row.frequency, row.coverage, cov_map, ax
                )

        ax.tick_params(axis='y', which='both', left=False, labelleft=False)
        sns.despine(ax=ax, top=True, right=True, left=True, bottom=False)

        if i == 0:
            # first row (first sample): title each column with its amplicon
            ax.set_title(amplicon.name)
        if i == sample_num - 1:
            # last row: only the bottom panels get an x-axis label
            ax.set_xlabel('bp')
        if j == 0:
            # first column (first amplicon): label each row with its sample
            ax.set_ylabel(sample)


# Colour bar placed just right of the subplot grid (x = 1 in figure fraction).
ax_cb = fig.add_axes([1, 0.4, 0.03, 0.2])
cb = fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax_cb)
cb.set_label('Coverage')

# The colour bar lies outside the [0, 1] figure box; bbox_inches='tight'
# enlarges the saved area so it is not cropped out of the PDF.
fig.savefig('visualization.pdf', bbox_inches='tight')