In [1]:
import json
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
import altair as alt
import pandas as pd
from Bio import SeqIO
import csv
import os
import subprocess
from altair_saver import save
import numpy as np

In [2]:
# Tab-separated Grantham matrix; the 'FIRST' column labels each row and every
# other column header names the residue that the cell's score pairs it with.
grantham_matrix_fname = "grantham_with_gaps.tsv"
with open(grantham_matrix_fname) as tsvfile:
    grantham_rows = list(csv.DictReader(tsvfile, dialect='excel-tab'))

# Flatten the matrix into (row residue, column residue, score) triplets,
# keeping only the cells whose integer score is non-zero.
grantham_score_triplets = []
for matrix_row in grantham_rows:
    row_residue = matrix_row['FIRST']
    for column_residue, cell in matrix_row.items():
        if column_residue == 'FIRST':
            continue
        distance = int(cell)
        if distance != 0:
            grantham_score_triplets.append((row_residue, column_residue, distance))

In [3]:
def is_radical(mut_from, mut_to, allow_x=False):
    """
    Tell whether the substitution mut_from -> mut_to counts as radical.

    The score is looked up in the module-level ``grantham_score_triplets``
    list; a score of 70 or more is reported as radical. The lookup ignores
    the order of the two residues.

    mut_from, mut_to: single residue letters, matched case-sensitively
        against the triplets.
    allow_x: when true, a substitution involving 'x'/'X' on either side is
        reported as radical without consulting the matrix. When false, 'x'
        gets no special treatment and goes through the normal lookup.

    Raises Exception when the pair does not match exactly one triplet.
    """
    # The parentheses matter: without them `a and b or c` parses as
    # `(a and b) or c`, which let an 'x' target bypass allow_x=False.
    if allow_x and (mut_from.lower() == 'x' or mut_to.lower() == 'x'):
        return True
    # Compare against the two residue slots only, never the score slot.
    score_matches = [t for t in grantham_score_triplets
                     if mut_from in t[:2] and mut_to in t[:2]]
    if len(score_matches) != 1:
        raise Exception("couldn't find score for {} and {} in grantham matrix or found multiple.".format(mut_from, mut_to))
    return score_matches[0][2] >= 70

def get_region(position, inclusive_region_bounds):
    """
    Classify a position as 'CDR' or 'Framework'.

    position: value compared against the region bounds.
    inclusive_region_bounds: None, or a mapping of region name to a dict
        with inclusive "5p" (lower) and "3p" (upper) bounds.

    Returns 'CDR' when the position lies inside any of the regions and
    'Framework' otherwise (including when no bounds are given).
    """
    if inclusive_region_bounds is None:
        return 'Framework'
    # Every region is reported as plain 'CDR'; returning the region name
    # instead would differentiate between them.
    inside_any_region = any(
        bounds["5p"] <= position <= bounds["3p"]
        for bounds in inclusive_region_bounds.values()
    )
    return 'CDR' if inside_any_region else 'Framework'

def find_naive(seqrecords, naive_name):
    """
    Pick the naive sequence record out of a list of records.

    seqrecords: indexable sequence of objects exposing an ``id`` attribute.
    naive_name: id of the naive record, or None to use the first record.

    Raises Exception unless exactly one record carries ``naive_name``.
    """
    if naive_name is None:
        return seqrecords[0]
    matching = [candidate for candidate in seqrecords if candidate.id == naive_name]
    if len(matching) == 1:
        return matching[0]
    raise Exception('there must be exactly 1 sequence with name {}'.format(naive_name))
    
def create_alignment(lineage_fname, inclusive_region_bounds=None, is_aa=False, json_fname=None, translate=False, naive_name=None):
    """
    Compare every sequence of a FASTA file to the naive sequence and dump
    one record per (sequence, position) to a JSON file.

    lineage_fname: path of the FASTA file to read.
    inclusive_region_bounds: passed to get_region to label each position.
    is_aa: when true, real mutations are also scored with is_radical.
    json_fname: output path; when falsy, the part of lineage_fname before
        the first '.fasta' plus '.json' is used.
    translate: when true, each record's sequence is replaced by its
        ``translate()`` result before comparison.
    naive_name: id of the naive record; None means the first record.

    Returns (list of sequence ids in file order, path of the JSON written).
    """
    seqrecords = []
    for parsed in SeqIO.parse(lineage_fname, "fasta"):
        if translate:
            parsed.seq = parsed.seq.translate()
        seqrecords.append(parsed)

    naive = find_naive(seqrecords, naive_name)
    naive_seq = str(naive.seq)

    all_mutations = []
    sortorder = []
    for iseq, seqrecord in enumerate(seqrecords):
        seq_id = seqrecord.id
        sortorder.append(seq_id)
        if naive_name is not None:
            is_naive = seq_id == naive_name
        else:
            is_naive = iseq == 0
        # the naive gets its own type so it shows up in the viz
        record_type = 'naive' if is_naive else 'lineage_member'
        # hack: leading spaces keep same-named sequences distinct
        padded_id = ' ' * iseq + seq_id

        for ipos, residue in enumerate(str(seqrecord.seq)):
            region = get_region(ipos, inclusive_region_bounds)
            naive_residue = naive_seq[ipos]
            # a real mutation deviates from the naive and is not a gap
            deviates = residue != naive_residue and residue.lower() != '-'
            radical = is_radical(naive[ipos], residue) if deviates and is_aa else False
            # a record is emitted for every position so that sequences
            # without mutations can still be drawn as a line
            all_mutations.append({'type': record_type,
                                  'index': iseq,
                                  'seq_id': padded_id,
                                  'position': ipos,
                                  'mut_from': naive_residue,
                                  'mut_to': residue,
                                  'region': region,
                                  'framework': region == 'Framework',
                                  'radical': radical,
                                  'real_mut': deviates})

    json_fname = json_fname or (lineage_fname.split('.fasta')[0] + '.json')
    with open(json_fname, 'w') as outfile:
        json.dump(all_mutations, outfile)
    return sortorder, json_fname

In [8]:
def write_plot(chart, chart_path, scale_factor=3, png=True):
    '''
    Save the altair chart under the path prefix ``chart_path``.

    Always writes ``<chart_path>_vl.json`` and ``<chart_path>.svg``; writes
    ``<chart_path>.png`` only when ``png`` is true.

    chart: chart object handed to altair_saver's ``save``.
    chart_path: output path without extension.
    scale_factor: kept for backward compatibility; it is currently not
        forwarded to the saver.
    png: set to False to skip the PNG export.

    TODO replace this with https://github.com/altair-viz/altair_saver#nodejs
    '''
    save(chart, chart_path + '_vl.json')
    if png:
        # the flag used to be ignored and the PNG was always written
        save(chart, chart_path + '.png')
    save(chart, chart_path + '.svg')

        
def make_one_tick_plot(plot_d, plot_save_dir, save_plot=True, display_plot=True, combined=False):
    """
    Build the tick plot for one plot description.

    plot_d: dict describing the plot. Keys read here:
        "lineage_alignment_fasta", "inclusive_region_bounds", "is_aa",
        "translate", "title", optionally "json_fname", and "chain"
        (only consulted when ``combined`` is true).
    plot_save_dir: directory for the saved files; created when missing.
    save_plot: when true, write the chart through write_plot.
    display_plot: when true, display the chart.
    combined: when true, the y axis labels/title/orientation depend on
        plot_d["chain"] so that a heavy and a light chart can sit side by side.

    Returns (layered chart, length of the longest seq_id label).
    """
    # makedirs also creates missing parent directories and tolerates an
    # already existing directory.
    os.makedirs(plot_save_dir, exist_ok=True)
    # Keyword arguments on purpose: the fourth positional parameter of
    # create_alignment is json_fname, so passing "translate" positionally
    # used to land in the wrong slot.
    sortorder, json_fname = create_alignment(plot_d["lineage_alignment_fasta"],
                                             inclusive_region_bounds=plot_d["inclusive_region_bounds"],
                                             is_aa=plot_d["is_aa"],
                                             json_fname=plot_d.get("json_fname"),
                                             translate=plot_d["translate"])
    lineage_df = pd.read_json(json_fname)
    ticks = alt.Chart(lineage_df).mark_tick(thickness=2, opacity=1).encode(
        x=alt.X(
            'position:O',
            title="Amino acid position",
            axis=alt.Axis(grid=False, labelOverlap="parity", labelSeparation=10, labelFontSize=12)
        ),
        y=alt.Y(
            'seq_id:N',
            title="sequence",
            axis=alt.Axis(grid=True,
                          labels= not combined or plot_d["chain"] == "heavy",
                          title= "mAb" if not combined or plot_d["chain"] == "heavy" else None,
                          orient= "right" if combined and plot_d["chain"] == "light" else "left" ),
            sort=sortorder #for now this is the order they come in, which is the same as doing sort=None, but in case we want more advanced sorting later I'm using a variable
        ),
        # COLOR:
        # The colour scale is driven by the "real_mut" field:
        # True is drawn black, False is transparent.
        color=alt.Color(scale=alt.Scale(domain=[True, False], range=["black", "transparent"]), #changed to all black for now
                        field='real_mut',
                        type='nominal',
                        legend=None)
    ).properties(title=plot_d["title"])

    #hack to make tick marks for every position in the CDRs since making a large rectangle apparently is quite hard with an ordinal scale
    #region_marks = pd.DataFrame([{"x": v} for bounds in plot_d["inclusive_region_bounds"].values() for v in range(bounds["5p"], bounds["3p"]+1)])
    all_region_marks = lineage_df.copy().drop_duplicates("position")
    frs_region_marks = all_region_marks.loc[all_region_marks['framework']]
    frs = alt.Chart(frs_region_marks).mark_rect(opacity=0.25, color="#169cf5").encode(
        x='position:O'
    )
    cdrs_region_marks = all_region_marks.loc[all_region_marks['framework']==False]
    cdrs = alt.Chart(cdrs_region_marks).mark_rect(opacity=0.25, color="#f5e216").encode(
        x='position:O'
    )
    full_chart = frs + cdrs + ticks
    if save_plot:
        # .properties() returns a new chart instead of modifying in place, so
        # the result has to be kept. It is applied to the saved chart only;
        # the displayed/returned chart stays as before.
        chart_to_save = full_chart.properties(background = "white", width={"step":3})
        write_plot(chart_to_save, os.path.join(plot_save_dir, plot_d["title"].replace(' ', '_')))
    if display_plot:
        full_chart.display()
    return full_chart, max(len(seqid) for seqid in lineage_df["seq_id"])
    
def make_many_tick_plots(plot_summary_file, plot_save_dir):
    """
    Make one tick plot per entry of a JSON plot summary file.

    plot_summary_file: path of a JSON file holding a list of plot descriptions.
    plot_save_dir: directory handed on to make_one_tick_plot for every entry.
    """
    with open(plot_summary_file) as summary_handle:
        descriptions = json.load(summary_handle)
    for description in descriptions:
        make_one_tick_plot(description, plot_save_dir)

        
def make_many_combined_tick_plots(plot_summary_file, plot_save_dir):
    """
    Make one side-by-side heavy/light tick plot per lineage.

    plot_summary_file: path of a JSON file holding a list of plot
        descriptions; each one must carry "lineage" and "chain" keys, and
        every lineage needs both a "heavy" and a "light" entry.
    plot_save_dir: directory the combined plots are written to.
    """
    with open(plot_summary_file) as summary_handle:
        descriptions = json.load(summary_handle)

    # lineage name -> chain name -> {"plot": chart, "offset": longest label length}
    charts_by_lineage = {}
    for description in descriptions:
        chains = charts_by_lineage.setdefault(description["lineage"], {})
        chart, label_length = make_one_tick_plot(description, plot_save_dir,
                                                 save_plot=False, display_plot=False, combined=True)
        chains[description["chain"]] = {"plot": chart, "offset": label_length}

    for lineage_name, chains in charts_by_lineage.items():
        heavy_plot = chains["heavy"]["plot"]
        light_plot = chains["light"]["plot"]
        # shift the title right by an amount derived from the heavy chart's longest label
        combined_title = alt.TitleParams(text=lineage_name + " Lineage",
                                         align="center",
                                         anchor="middle",
                                         fontSize=16,
                                         dx=chains["heavy"]["offset"]*2.5 + 20)
        combined_chart = alt.hconcat(heavy_plot,
                                     light_plot,
                                     title=combined_title,
                                     spacing=-5,
                                     bounds="full",
                                     background = "white")
        combined_chart.display()
        output_prefix = os.path.join(plot_save_dir, lineage_name.replace(' ', '_') + ".combined_plot")
        write_plot(combined_chart, output_prefix)

In [12]:
'''
EXAMPLE from plots_summary.json

[
    {
        "title":  "lineage 1",
        "lineage_alignment_fasta" : "072v2-Vh-minimal translation.fasta",
        "is_aa" : true,
        "inclusive_region_bounds" : {
                                        "CDR3": {"5p": 98, "3p": 121},
                                        "CDR2": {"5p": 40, "3p": 60}
                                    },
        "json_fname": null,
        "translate": false
    },
    {
        "title":  "lineage 2",
        "lineage_alignment_fasta" : "072v2-Vh-minimal translation.fasta",
        "is_aa" : true,
        "inclusive_region_bounds" : {
                                        "CDR3": {"5p": 108, "3p": 121},
                                        "CDR2": {"5p": 40, "3p": 60}
                                    },
        "json_fname": null,
        "translate": false
    }
]


1. Pass the path of a file whose contents look like the above example, plus an output directory, to the function below (both in quotes), like so:
make_many_tick_plots("plots_summary.json", "output")
2. Click the fastforward looking button at the top to create the plots
3. Click on the three dots next to each plot to save an image with the desired format and name
'''

# Build, save and display one tick plot per entry of the summary file; files go to "output".
make_many_tick_plots("example_input/plots_summary.json", "output")
# Build the side-by-side heavy/light plot of every lineage. This call reads the
# "lineage" and "chain" keys of each entry, which the example above does not show.
make_many_combined_tick_plots("example_input/plots_summary.json", "output")