In [None]:
import os

from Bio.SeqIO import QualityIO
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.cm as cm

import gzip

from utils import dna_rev_comp, translate_dna2aa, convert_phred_byte

In [None]:
# Root data directory: holds library_*/ and fractions_*/ FASTQ folders and facs/ FCS files, and
# receives the cached .npy/.torch files and most output figures written by this notebook.
base_dir = "acrdms_data_v2/"

In [None]:
# Wild-type protein sequences (one-letter amino-acid code) of AcrIIA4 (gt_a4) and AcrIIA5 (gt_a5).
# Variant names used below (e.g. "K12A") use 1-based positions into these strings.
gt_a4 = "MNINDLIREIKNKDYTVKLSGTDSNSITQLIIRVNNDGNEYVISESENESIVEKFISAFKNGWNQEYEDEEEFYNDMQTITLKSELN"
gt_a5 = "MAYGKSRYNSYRKRSFNRSNKQRREYAQEMDRLEKAFENLDGWYLSSMKDSAYKDFGKYEIRLSNHSADNKYHDLENGRLIVNIKASKLNFVDIIENKLDKIIEKIDKLDLDKYRFINATNLEHDIKCYYKGFKTKKEVI"

In [None]:
def read_sequences(file_1, file_2, min_overlap=50, max_overlap=250, quality_threshold=0):
    """Merge paired-end FASTQ reads into single sequences by exact-overlap search.

    For each read pair, the reverse complement of the R2 read is aligned to the end
    of the R1 read. The first overlap length in ``range(min_overlap, max_overlap)``
    for which the R1 suffix equals the reverse-complemented R2 prefix joins the two
    reads. Pairs without such an overlap are dropped.

    Args:
        file_1: path to the gzipped R1 FASTQ file.
        file_2: path to the gzipped R2 FASTQ file, in the same read order as ``file_1``.
        min_overlap: smallest overlap length tried (inclusive); ``0`` means plain
            concatenation is accepted.
        max_overlap: largest overlap length tried (exclusive).
        quality_threshold: a merged read is discarded if any per-base quality value
            (from ``convert_phred_byte``; the minimum of both reads inside the
            overlap) is below this value.

    Returns:
        list of merged sequences (str).
    """
    sequences = []
    total_read = 0
    with gzip.open(file_1, "rt") as a_file, gzip.open(file_2, "rt") as b_file:
        a_reader = QualityIO.FastqGeneralIterator(a_file)
        b_reader = QualityIO.FastqGeneralIterator(b_file)
        for (_, a_seq, a_qual), (_, b_seq, b_qual) in zip(a_reader, b_reader):
            total_read += 1
            a_qual = convert_phred_byte(bytes(a_qual, "utf-8"))
            b_inv_qual = convert_phred_byte(bytes(b_qual[::-1], "utf-8"))
            # Reads shorter than 250 nt get one "G" appended; the quality arrays are NOT
            # extended. NOTE(review): presumably compensates for a trimmed last cycle — confirm.
            if len(a_seq) < 250:
                a_seq = a_seq + "G"
            if len(b_seq) < 250:
                b_seq = b_seq + "G"
            b_inv = dna_rev_comp(b_seq)
            for expected_overlap in range(min_overlap, max_overlap):
                if expected_overlap == 0 or a_seq[-expected_overlap:] == b_inv[:expected_overlap]:
                    # Start of the overlap in a_qual, computed explicitly because
                    # a_qual[:-0] would be empty for a zero-length overlap.
                    split = max(len(a_qual) - expected_overlap, 0)
                    res_qual = np.concatenate((
                        a_qual[:split],
                        np.minimum(a_qual[split:], b_inv_qual[:expected_overlap]),
                        b_inv_qual[expected_overlap:]
                    ), axis=0)
                    if not (res_qual < quality_threshold).any():
                        sequences.append(a_seq + b_inv[expected_overlap:])
                    break
    print("total reads", total_read)
    return sequences

def gather_variants(sequences, catch, gt):
    """Extract and translate the gene that follows an anchor sequence in each read.

    The gene is the ``3 * len(gt)`` nucleotides directly after ``catch``. If the
    anchor is also found on the reverse-complement strand, that hit is used.

    Args:
        sequences: merged reads (str), e.g. from read_sequences.
        catch: anchor nucleotide sequence immediately upstream of the gene.
        gt: wild-type protein sequence.

    Returns:
        tuple ``(dist, translations, multiples, stops, mislengths, wildtype)``:

        - dist: ``(3 * len(gt), 4)`` counts of the bases G, A, T, C at nucleotide
          positions whose codon translates differently from ``gt``. Only reads with
          at most one amino-acid difference contribute.
        - translations: translation -> read count, for translations with at most one
          amino-acid difference, wild type included.
        - multiples: translations with more than one difference.
        - stops: translations with at most one difference that contain "*".
        - mislengths: translations whose length differs from ``len(gt)``.
        - wildtype: number of reads translating exactly to ``gt``.
    """
    translations = {}
    multiples = []
    stops = []
    mislengths = []
    wildtype = 0

    length = 3 * len(gt)
    peptide_length = len(gt)
    catch_length = len(catch)
    # Wild-type-encoding gene sequences -> count; tracked but not returned.
    synonymous = {}

    dist = np.zeros((length, 4))

    for sequence in sequences:
        tr = None
        if catch in sequence:
            index = sequence.index(catch) + catch_length
            gene = sequence[index:index + length]
            tr = translate_dna2aa(gene)
        reverse = dna_rev_comp(sequence)
        if catch in reverse:
            index = reverse.index(catch) + catch_length
            gene = reverse[index:index + length]
            tr = translate_dna2aa(gene)
        if tr is None:
            continue
        if len(tr) != peptide_length:
            mislengths.append(tr)
            continue
        if tr == gt:
            wildtype += 1
            synonymous[gene] = synonymous.get(gene, 0) + 1
        if sum(a != b for a, b in zip(tr, gt)) > 1:
            multiples.append(tr)
            continue
        if "*" in tr:
            # Stop-containing variants are recorded but still counted below.
            stops.append(tr)
        for idx, val in enumerate(gene):
            # Ambiguous base calls (e.g. "N") have no column in dist and are skipped.
            if tr[idx // 3] != gt[idx // 3] and val in "GATC":
                dist[idx, "GATC".index(val)] += 1
        translations[tr] = translations.get(tr, 0) + 1
    return dist, translations, multiples, stops, mislengths, wildtype#, synonymous

AA_CODE = "ACDEFGHIKLMNPQRSTVWY*"
def check_mutants(translations, gt):
    """Build a ``(len(gt), 21)`` read-count matrix of single amino-acid substitutions.

    Columns follow AA_CODE (20 amino acids, then "*" for stop). For every translation
    that differs from ``gt``, its read count is added at the first differing position
    and the substituted residue. A translation equal to ``gt`` sets every wild-type
    cell to 1; this is a flag, not its read count.
    """
    counts = np.zeros((len(gt), 21))
    for peptide, n_reads in translations.items():
        if peptide == gt:
            for position, residue in enumerate(peptide):
                counts[position, AA_CODE.index(residue)] = 1
            continue
        first_diff = next(
            (position for position, residue in enumerate(peptide) if residue != gt[position]),
            None,
        )
        if first_diff is not None:
            counts[first_diff, AA_CODE.index(peptide[first_diff])] += n_reads
    return counts

def process_directory(base, catch, gt, min_overlap=50, max_overlap=250):
    """Merge, translate and tabulate every paired-end FASTQ file pair in ``base``.

    R1 files contain "R1" or "1_sequence" in their name, R2 files "R2" or
    "2_sequence". Both lists are sorted and paired by position.

    Args:
        base: directory with the FASTQ files.
        catch: anchor sequence passed to gather_variants.
        gt: wild-type protein sequence.
        min_overlap, max_overlap: forwarded to read_sequences.

    Returns:
        ``(sequences, results)``. For each file pair, ``sequences`` holds the merged
        reads and ``results`` holds ``(check_mutants matrix, wild-type read count)``.

    Raises:
        ValueError: if the number of R1 files differs from the number of R2 files,
            since pairing by sort order would then mis-pair reads.
    """
    entries = os.listdir(base)
    fraction_paths_r1 = sorted(
        f"{base}/{path}" for path in entries if "R1" in path or "1_sequence" in path
    )
    fraction_paths_r2 = sorted(
        f"{base}/{path}" for path in entries if "R2" in path or "2_sequence" in path
    )
    if len(fraction_paths_r1) != len(fraction_paths_r2):
        raise ValueError(
            f"{base}: found {len(fraction_paths_r1)} R1 files but {len(fraction_paths_r2)} R2 files"
        )

    sequences = []
    results = []
    for f1, f2 in zip(fraction_paths_r1, fraction_paths_r2):
        seq = read_sequences(f1, f2, min_overlap=min_overlap, max_overlap=max_overlap)
        variants = gather_variants(seq, catch, gt)
        # variants[1] is the translation -> count dict, variants[-1] the wild-type count.
        sequences.append(seq)
        results.append((check_mutants(variants[1], gt), variants[-1]))
    return sequences, results

def normalise_single_run(result):
    """Turn the raw counts of one sequencing run into frequencies, using a +1 pseudocount.

    Args:
        result: tuple ``(variant_counts, wildtype_count)`` as produced by process_directory.

    Returns:
        tuple ``(variant_frequencies, wildtype_frequency)``. One is added to every
        variant cell and to the wild-type count before dividing by the grand total.
    """
    wt_count = result[-1] + 1
    variant_counts = result[0] + 1
    grand_total = wt_count + variant_counts.sum()
    return variant_counts / grand_total, wt_count / grand_total

def relative_results(normalised):
    """Express each run's frequencies relative to their sum over all runs.

    Args:
        normalised: iterable of ``(variant_frequencies, wildtype_frequency)`` tuples,
            one per run (e.g. per sorted fraction), as returned by normalise_single_run.

    Returns:
        ``(variant_stack, wt_stack)`` with a leading run axis. Along that axis, every
        variant cell and the wild type sum to 1.
    """
    # Materialise once: the input is traversed twice below.
    normalised = list(normalised)
    # np.stack needs a real sequence; map()/generator input raises TypeError on recent NumPy.
    variant_stack = np.stack([item[0] for item in normalised])
    wt_stack = np.array([item[-1] for item in normalised])
    variant_stack = variant_stack / variant_stack.sum(axis=0, keepdims=True)
    wt_stack = wt_stack / wt_stack.sum(axis=0, keepdims=True)
    return variant_stack, wt_stack

def normalise_results(results):
    """Normalise every run in ``results`` and compute across-run relative frequencies.

    Returns:
        ``(normalised, relative)``: the list of per-run normalise_single_run outputs,
        and relative_results applied to that list.
    """
    normalised = [normalise_single_run(run) for run in results]
    return normalised, relative_results(normalised)

def plot_contacts(upper, lower, positions, cutoff=None, gt=None):
    """Heat-map the share of signal in the ``upper`` selection at chosen positions.

    Args:
        upper, lower: iterables of arrays, summed with Python ``sum``. Rows are
            sequence positions; columns follow AA_CODE (20 or 21 columns).
            NOTE(review): presumably built from selected_fraction outputs — confirm.
        positions: 0-based sequence positions to show; axis labels are 1-based.
        cutoff: if not None, plot the binary map ``ratio > cutoff`` instead of the ratio.
        gt: wild-type sequence whose residues are masked out. Defaults to a global
            ``gt``, which is what the original code read.

    Raises:
        ValueError: if no wild-type sequence is given or available as a global ``gt``.
    """
    if gt is None:
        gt = globals().get("gt")
        if gt is None:
            raise ValueError("plot_contacts needs the wild-type sequence: pass gt=...")
    upper_ratio = sum(upper) / (sum(lower) + sum(upper))
    n_aa = upper_ratio.shape[-1]
    gt_matrix = np.zeros((n_aa, len(positions)))
    for idx, pos in enumerate(positions):
        gt_matrix[AA_CODE.index(gt[pos]), idx] = 1

    # Dividing by (1 - gt_matrix) turns wild-type cells into inf/nan, which matplotlib
    # masks as invalid; suppress the expected divide warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        top = upper_ratio[positions].T / (1 - gt_matrix)
        if cutoff is not None:
            top = 1.0 * (top > cutoff)
    cmable = plt.matshow(top)
    ax = plt.gca()
    ax.set_yticks(range(n_aa))
    ax.set_yticklabels(AA_CODE[:n_aa])
    ax.set_xticks(range(len(positions)))
    ax.set_xticklabels([pos + 1 for pos in positions])
    plt.colorbar(cmable)

def selected_fraction(variants, wt, select=None):
    """Split per-run values into the sum over selected runs and the sum over the rest.

    Args:
        variants: indexable of per-run arrays (e.g. the run axis of relative_results).
        wt: indexable of per-run wild-type values, aligned with ``variants``.
        select: a single run index, or an iterable of run indices (list, tuple,
            range, set, NumPy array). ``None`` selects no run.

    Returns:
        ``(lower, upper, wt_lower, wt_upper)``. ``upper`` and ``wt_upper`` sum the
        selected runs; ``lower`` and ``wt_lower`` sum the others. An empty sum is 0.
    """
    if select is None:
        select = []
    elif isinstance(select, np.ndarray):
        select = np.atleast_1d(select).tolist()
    elif isinstance(select, (list, tuple, range, set, frozenset)):
        select = list(select)
    else:
        select = [select]
    rest = [idx for idx in range(len(variants)) if idx not in select]
    wt_upper = sum(wt[idx] for idx in select)
    wt_lower = sum(wt[idx] for idx in rest)
    upper = sum(variants[idx] for idx in select)
    lower = sum(variants[idx] for idx in rest)
    return lower, upper, wt_lower, wt_upper

# Load Data

### write library count matrices

In [None]:
catch = "aggaattcgcca".upper()
base = f"{base_dir}/library_a4/"
_, replicates_a4 = process_directory(base, catch, gt_a4)

base = f"{base_dir}/library_a5/"
_, replicates_a5 = process_directory(base, catch, gt_a5, min_overlap=10)

def write_library_count_files(path, results, gt):
    """Write a CSV with read counts of the wild type and of every single substitution.

    There is one column per library replicate.

    Args:
        path: output CSV path.
        results: one ``(variant_count_matrix, wildtype_count)`` tuple per replicate,
            as returned by process_directory.
        gt: wild-type protein sequence; variants are named like ``K12A`` (1-based position).
    """
    n_rep = len(results)
    with open(path, "wt") as f:
        header = ",".join(f"num_reads_replicate_{i + 1}" for i in range(n_rep))
        f.write(f"variant,{header}\n")
        # Each replicate column gets that replicate's own wild-type count.
        counts = [str(int(results[i][-1])) for i in range(n_rep)]
        f.write(f"wt,{','.join(counts)}\n")
        for pos in range(len(gt)):
            for aa in range(21):
                if AA_CODE[aa] == gt[pos]:
                    continue
                variant = f"{gt[pos]}{pos + 1}{AA_CODE[aa]}"
                counts = [str(int(results[i][0][pos, aa])) for i in range(n_rep)]
                f.write(f"{variant},{','.join(counts)}\n")

write_library_count_files(f"{base_dir}/count_matrix_library_replicates_a4.csv", replicates_a4, gt_a4)
write_library_count_files(f"{base_dir}/count_matrix_library_replicates_a5.csv", replicates_a5, gt_a5)

## read A4 data

In [None]:
catch = "aggaattcgcca".upper()
base = f"{base_dir}/fractions_a4/"
savefile_a4 = f"{base_dir}/savefile_a4_ds3.npy"

relative_a4_ds3 = []
results_a4_ds3 = []
if not os.path.isfile(savefile_a4):
    # Sorted so replicate order (and the "_rep_N" numbering of later outputs) is reproducible.
    # Entries that are not directories (stray files) are skipped.
    for replicate in sorted(os.listdir(base)):
        full_path = f"{base}/{replicate}"
        if not os.path.isdir(full_path):
            continue
        _, results = process_directory(full_path, catch, gt_a4)
        _, relative = normalise_results(results)
        relative_a4_ds3.append(relative)
        results_a4_ds3.append(results)

    # Both lists are ragged, so build the object array explicitly. Letting np.save convert
    # the tuple raises an "inhomogeneous shape" ValueError on NumPy >= 1.24.
    cache = np.empty(2, dtype=object)
    cache[0] = relative_a4_ds3
    cache[1] = results_a4_ds3
    np.save(savefile_a4, cache)
else:
    # allow_pickle is needed for object arrays; only load cache files this notebook wrote.
    relative_a4_ds3, results_a4_ds3 = np.load(savefile_a4, allow_pickle=True)

## read A5 data

In [None]:
catch = "aggaattcgcca".upper()
base = f"{base_dir}/fractions_a5/"
savefile_a5 = f"{base_dir}/savefile_a5_ds3.npy"

relative_a5_ds3 = []
results_a5_ds3 = []
if not os.path.isfile(savefile_a5):
    # Sorted so replicate order (and the "_rep_N" numbering of later outputs) is reproducible.
    # Entries that are not directories (stray files) are skipped.
    for replicate in sorted(os.listdir(base)):
        full_path = f"{base}/{replicate}/"
        if not os.path.isdir(full_path):
            continue
        _, results = process_directory(full_path, catch, gt_a5, min_overlap=10)
        _, relative = normalise_results(results)
        relative_a5_ds3.append(relative)
        results_a5_ds3.append(results)

    # Both lists are ragged, so build the object array explicitly. Letting np.save convert
    # the tuple raises an "inhomogeneous shape" ValueError on NumPy >= 1.24.
    cache = np.empty(2, dtype=object)
    cache[0] = relative_a5_ds3
    cache[1] = results_a5_ds3
    np.save(savefile_a5, cache)
else:
    # allow_pickle is needed for object arrays; only load cache files this notebook wrote.
    relative_a5_ds3, results_a5_ds3 = np.load(savefile_a5, allow_pickle=True)

## write fraction count matrices

In [None]:
def write_count_files(path, results, gt):
    """Write one CSV per replicate with each variant's read count in every sorted fraction.

    Args:
        path: output path prefix; ``_rep_<n>.csv`` (1-based replicate number) is appended.
        results: per replicate, a list with one ``(variant_count_matrix, wildtype_count)``
            tuple per sorted fraction, as returned by process_directory.
        gt: wild-type protein sequence; variants are named like ``K12A`` (1-based position).
    """
    for idx, replicate in enumerate(results):
        n_frac = len(replicate)
        with open(path + f"_rep_{idx + 1}.csv", "wt") as f:
            header = ",".join(f"reads_in_fraction_{i + 1}" for i in range(n_frac))
            f.write(f"variant,{header}\n")
            # Each fraction column gets that fraction's own wild-type count.
            counts = [str(int(replicate[i][-1])) for i in range(n_frac)]
            f.write(f"wt,{','.join(counts)}\n")
            for pos in range(len(gt)):
                for aa in range(21):
                    if AA_CODE[aa] == gt[pos]:
                        continue
                    variant = f"{gt[pos]}{pos + 1}{AA_CODE[aa]}"
                    counts = [str(int(replicate[i][0][pos, aa])) for i in range(n_frac)]
                    f.write(f"{variant},{','.join(counts)}\n")

write_count_files(f"{base_dir}/count_matrix_a4", results_a4_ds3, gt_a4)
write_count_files(f"{base_dir}/count_matrix_a5", results_a5_ds3, gt_a5)

# Plot coverage and mean fraction

## A4 mode fraction

In [None]:
# Most frequent (mode) sorted fraction per A4 variant, pooled over replicates. Wild-type
# cells show the wild-type mode. Fractions are 0-based (0..3) on the colour bar.
index_a4 = np.array([AA_CODE.index(residue) for residue in gt_a4])
pooled_variants_a4 = sum(rep[0] for rep in relative_a4_ds3) / 3
pooled_wt_a4 = sum(rep[1] for rep in relative_a4_ds3)
argmax_frac = pooled_variants_a4.argmax(axis=0)
argmax_wt = pooled_wt_a4.argmax(axis=0)
argmax_frac[np.arange(len(index_a4)), index_a4] = argmax_wt
fig, ax = plt.subplots(1, 1, figsize=(20, 4))
image = ax.matshow(argmax_frac.T, cmap="Greys_r")
ax.set_xticks([0] + [10 * i - 1 for i in range(1, 9)])
ax.set_xticklabels([1] + [10 * i for i in range(1, 9)])
ax.set_yticks(range(21))
ax.set_yticklabels(AA_CODE)
colorbar = plt.colorbar(image)
colorbar.set_ticks((0, 1, 2, 3))
colorbar.set_ticklabels((0, 1, 2, 3))
a4_dmsplot_dir = f"{base_dir}/outputs/A4_dms_plots/"
os.makedirs(a4_dmsplot_dir, exist_ok=True)
plt.savefig(f"{a4_dmsplot_dir}/mode_fraction.svg")

## A4 mean fraction

In [None]:
# Mean sorted-fraction index (0-based) per A4 variant: the fraction index weighted by the
# variant's relative frequency in each fraction, averaged over replicates.
n_rep_a4 = len(relative_a4_ds3)
frac_index_a4 = np.arange(relative_a4_ds3[0][0].shape[0])
per_rep_frac_a4 = [(rep[0] * frac_index_a4[:, None, None]).sum(axis=0) for rep in relative_a4_ds3]
per_rep_wt_a4 = [(rep[1] * frac_index_a4).sum(axis=0) for rep in relative_a4_ds3]
index_a4 = np.array(list(map(AA_CODE.index, gt_a4)))
wt_cells_a4 = (np.arange(len(index_a4)), index_a4)
# mean fraction for non-wt variants and for wt; wt value shown at all wt cells
mean_frac = sum(per_rep_frac_a4) / n_rep_a4
mean_wt = sum(per_rep_wt_a4) / n_rep_a4
mean_frac[wt_cells_a4] = mean_wt
# symmetric colour range centred on the wt mean fraction
mean_range = abs(mean_frac - mean_wt).max()
# sample standard deviation across replicates (n - 1 normalisation; 0 for a single replicate)
sample_norm_a4 = max(n_rep_a4 - 1, 1)
std_frac = np.sqrt(sum((f - mean_frac) ** 2 for f in per_rep_frac_a4) / sample_norm_a4)
std_wt = np.sqrt(sum((w - mean_wt) ** 2 for w in per_rep_wt_a4) / sample_norm_a4)
std_frac[wt_cells_a4] = std_wt
# plot mean fraction and its standard deviation across replicates
for data, style, title, filename in (
    (mean_frac, dict(cmap="RdBu_r", vmin=mean_wt - mean_range, vmax=mean_wt + mean_range),
     "mean fraction", "mean_fraction.svg"),
    (std_frac, dict(cmap="Greys_r"), "std dev fraction", "std_dev_fraction.svg"),
):
    fig, ax = plt.subplots(1, 1, figsize=(20, 4))
    cx = ax.matshow(data.T, **style)
    ax.set_xticks([i * 10 - 1 if i != 0 else 0 for i in range(9)])
    ax.set_xticklabels([i * 10 if i != 0 else 1 for i in range(9)])
    ax.set_yticks(range(21))
    ax.set_yticklabels(AA_CODE)
    fig.colorbar(cx, ax=ax)
    ax.set_title(title)
    fig.savefig(f"{a4_dmsplot_dir}/{filename}")

## Log number of reads per variant A4

In [None]:
# log10 of the total read count per A4 variant, pooled over replicates and sorted fractions.
# Wild-type cells show the pooled wild-type count; cells with zero reads become -inf.
is_wt_cell_a4 = np.array(list(AA_CODE))[:, None] == np.array(list(gt_a4))[None, :]
log_counts_a4 = np.log10(sum(res[0] for rep in results_a4_ds3 for res in rep).T)
log_wt_a4 = np.log10(sum(res[1] for rep in results_a4_ds3 for res in rep))
target = np.where(is_wt_cell_a4, log_wt_a4, log_counts_a4)
fig, ax = plt.subplots(figsize=(20, 4))
image = ax.matshow(target, cmap="magma")
ax.set_xticks([0] + [10 * i - 1 for i in range(1, 9)])
ax.set_xticklabels([1] + [10 * i for i in range(1, 9)])
ax.set_yticks(range(21))
ax.set_yticklabels(AA_CODE)
plt.colorbar(image)
plt.savefig(f"{a4_dmsplot_dir}/log_reads_per_variant.svg")

### Share of reads in each sorted fraction per A4 variant (heatmap-of-heatmaps)

In [None]:
# Read counts pooled over replicates, stacked over fractions: (n_positions, 21, n_fractions),
# then swapped to (21, n_positions, n_fractions).
target = sum(np.stack([res[0] for res in rep], axis=-1) for rep in results_a4_ds3)
target = np.swapaxes(target, 0, 1)
# Wild-type read count per fraction, pooled over replicates: shape (n_fractions,).
wt_target = sum(np.stack([res[1] for res in rep], axis=-1) for rep in results_a4_ds3)
# Replace wild-type cells (AA_CODE row equals the gt_a4 residue) with the wild-type counts.
target = np.where((np.array(list(gt_a4))[None, :] == np.array(list(AA_CODE))[:, None])[..., None], wt_target, target)
# Split the 4 fractions into a 2x2 tile (0-based fraction 2 * r + c at tile[r, c]) and
# normalise each tile to 1: the share of a variant's reads per fraction, not a log count.
# Cells without reads become 0/0 = nan.
target = target.reshape(*target.shape[:2], 2, 2)
target /= target.sum(axis=(-1, -2), keepdims=True)
# (21, n_positions, r, c) -> (21, r, n_positions, c); the second moveaxis is a no-op.
target = np.moveaxis(target, -2, 1)
target = np.moveaxis(target, -1, 3)
# One image: row 2 * aa + r, column 2 * position + c.
target = target.reshape(21 * 2, 87 * 2)
print(target.shape)
fig, ax = plt.subplots(figsize=(20, 4), dpi=600)
cx = ax.matshow(target, cmap="magma")
# Ticks sit between the two columns/rows of each tile.
ax.set_xticks([(i * 10 - 1) * 2 + 0.5 if i != 0 else 0.5 for i in range(9)])
ax.set_xticklabels([i * 10 if i != 0 else 1 for i in range(9)])
ax.set_yticks([2 * i + 0.5 for i in range(21)])
ax.set_yticklabels(AA_CODE)
# White grid lines separating the 2x2 tiles.
ax.hlines(2 * np.arange(21) - 0.5, -0.5, 87 * 2 - 0.5, color="w")
ax.vlines(2 * np.arange(87) - 0.5, -0.5, 21 * 2 - 0.5, color="w")
#ax.set_xlim((0 - 0.5, 87 * 2 -0.5))
#ax.grid(c='k', ls='-', lw='2')
ax.set_title("AcrIIA4 reads-per-fraction heatmap", fontsize=14)
ax.set_xlabel("sequence position", fontsize=12)
ax.set_ylabel("variant", fontsize=12)

plt.colorbar(cx)
# NOTE(review): the values are per-tile shares, despite "log_reads" in the file name.
plt.savefig(f"{a4_dmsplot_dir}/log_reads_per_variant_fraction.svg")

## A4 overview fraction distribution

In [None]:
from matplotlib.backends.backend_pdf import PdfPages

# Distribution of each A4 variant's reads over the sorted fractions: mean ± sample std over
# replicates. Written as a multi-page PDF with 10 positions x 21 residues per page.
n_rep_a4 = len(relative_a4_ds3)
mean_dist_a4 = sum(rep[0] for rep in relative_a4_ds3) / n_rep_a4
mean_dist_a4_wt = sum(rep[1] for rep in relative_a4_ds3) / n_rep_a4
mean_dist_a4[:, np.arange(len(index_a4)), index_a4] = mean_dist_a4_wt[:, None]
sample_norm_a4 = max(n_rep_a4 - 1, 1)
std_dist_a4 = np.sqrt(sum((rep[0] - mean_dist_a4) ** 2 for rep in relative_a4_ds3) / sample_norm_a4)
std_dist_a4_wt = np.sqrt(sum((rep[1] - mean_dist_a4_wt) ** 2 for rep in relative_a4_ds3) / sample_norm_a4)
std_dist_a4[:, np.arange(len(index_a4)), index_a4] = std_dist_a4_wt[:, None]
n_frac_a4 = mean_dist_a4.shape[0]
n_positions_a4 = len(gt_a4)
overview_dir = "outputs/overview_figures"
os.makedirs(overview_dir, exist_ok=True)
with PdfPages(f"{overview_dir}/overview_fractions_a4_pages.pdf") as pdf:
    for page_start in range(0, n_positions_a4, 10):
        size = min(10, n_positions_a4 - page_start)
        # squeeze=False keeps `ax` 2-D even when the last page has a single row.
        fig, ax = plt.subplots(size, 21, figsize=(21 * 4, size * 4), squeeze=False)
        for row in range(size):
            pos = page_start + row
            ax[row, 0].set_ylabel("fraction of reads")
            for aa in range(21):
                title = "wt" if gt_a4[pos] == AA_CODE[aa] else f"{gt_a4[pos]}{pos + 1}{AA_CODE[aa]}"
                ax[row, aa].set_title(title)
                ax[row, aa].bar(np.arange(n_frac_a4), mean_dist_a4[:, pos, aa], width=0.5, yerr=std_dist_a4[:, pos, aa])
                ax[row, aa].set_ylim(0.0, 1.0)
                ax[row, aa].set_xticks(range(n_frac_a4))
                ax[row, aa].set_xticklabels(range(1, n_frac_a4 + 1))
        for aa in range(21):
            ax[-1, aa].set_xlabel("sorted fraction ID")
        pdf.savefig(fig)
        # Each page is a 10 x 21 grid; close it so pyplot does not keep every page alive.
        plt.close(fig)

## A5 mode fraction

In [None]:
# Most frequent (mode) sorted fraction per A5 variant, pooled over replicates. Wild-type
# cells show the wild-type mode. Fractions are 0-based (0..3) on the colour bar.
index_a5 = np.array([AA_CODE.index(residue) for residue in gt_a5])
argmax_a5 = (sum(rep[0] for rep in relative_a5_ds3) / 3).argmax(axis=0)
argmax_a5_wt = (sum(rep[1] for rep in relative_a5_ds3) / 3).argmax(axis=0)
argmax_a5[np.arange(len(index_a5)), index_a5] = argmax_a5_wt
fig, ax = plt.subplots(1, 1, figsize=(len(gt_a5) / 4, 20 / 4))
image = ax.matshow(argmax_a5.T, cmap="Greys_r")
ax.set_yticks(np.arange(21))
ax.set_yticklabels(AA_CODE)
ax.set_xticks([0] + [10 * i - 1 for i in range(1, 15)])
ax.set_xticklabels([1] + [10 * i for i in range(1, 15)])
colorbar = plt.colorbar(image)
colorbar.set_ticks((0, 1, 2, 3))
colorbar.set_ticklabels((0, 1, 2, 3))
a5_dmsplot_dir = f"{base_dir}/outputs/A5_dms_plots/"
os.makedirs(a5_dmsplot_dir, exist_ok=True)
plt.savefig(f"{a5_dmsplot_dir}/mode_fraction.svg")

## A5 mean fraction

In [None]:
# Mean sorted-fraction index (0-based) per A5 variant: the fraction index weighted by the
# variant's relative frequency in each fraction, averaged over replicates.
n_rep_a5 = len(relative_a5_ds3)
frac_index_a5 = np.arange(relative_a5_ds3[0][0].shape[0])
per_rep_frac_a5 = [(rep[0] * frac_index_a5[:, None, None]).sum(axis=0) for rep in relative_a5_ds3]
per_rep_wt_a5 = [(rep[1] * frac_index_a5).sum(axis=0) for rep in relative_a5_ds3]
wt_cells_a5 = (np.arange(len(index_a5)), index_a5)
# mean fraction for non-wt variants and for wt; wt value shown at all wt cells
mean_frac_a5 = sum(per_rep_frac_a5) / n_rep_a5
mean_wt_a5 = sum(per_rep_wt_a5) / n_rep_a5
mean_frac_a5[wt_cells_a5] = mean_wt_a5
# symmetric colour range centred on the wt mean fraction
mean_range_a5 = abs(mean_frac_a5 - mean_wt_a5).max()
# sample standard deviation across replicates (n - 1 normalisation; 0 for a single replicate)
sample_norm_a5 = max(n_rep_a5 - 1, 1)
std_frac_a5 = np.sqrt(sum((f - mean_frac_a5) ** 2 for f in per_rep_frac_a5) / sample_norm_a5)
std_wt_a5 = np.sqrt(sum((w - mean_wt_a5) ** 2 for w in per_rep_wt_a5) / sample_norm_a5)
std_frac_a5[wt_cells_a5] = std_wt_a5
# plot mean fraction and its standard deviation across replicates
for data, style, title, filename in (
    (mean_frac_a5, dict(cmap="RdBu_r", vmin=mean_wt_a5 - mean_range_a5, vmax=mean_wt_a5 + mean_range_a5),
     "mean fraction", "mean_fraction.svg"),
    (std_frac_a5, dict(cmap="Greys_r"), "std dev fraction", "std_dev_fraction.svg"),
):
    fig, ax = plt.subplots(1, 1, figsize=(30, 4))
    cx = ax.matshow(data.T, **style)
    ax.set_xticks([10 * i - 1 if i != 0 else 0 for i in range(15)])
    ax.set_xticklabels([10 * i if i != 0 else 1 for i in range(15)])
    ax.set_yticks(range(21))
    ax.set_yticklabels(AA_CODE)
    fig.colorbar(cx, ax=ax)
    ax.set_title(title)
    fig.savefig(f"{a5_dmsplot_dir}/{filename}")

## Log number of reads per variant A5

In [None]:
# log10 of the total read count per A5 variant, pooled over replicates and sorted fractions.
# Wild-type cells show the pooled wild-type count; cells with zero reads become -inf.
is_wt_cell_a5 = np.array(list(AA_CODE))[:, None] == np.array(list(gt_a5))[None, :]
log_counts_a5 = np.log10(sum(res[0] for rep in results_a5_ds3 for res in rep).T)
log_wt_a5 = np.log10(sum(res[1] for rep in results_a5_ds3 for res in rep))
target = np.where(is_wt_cell_a5, log_wt_a5, log_counts_a5)
fig, ax = plt.subplots(figsize=(20, 4))
image = ax.matshow(target, cmap="magma")
ax.set_xticks([0] + [10 * i - 1 for i in range(1, 15)])
ax.set_xticklabels([1] + [10 * i for i in range(1, 15)])
ax.set_yticks(range(21))
ax.set_yticklabels(AA_CODE)
plt.colorbar(image)
plt.savefig(f"{a5_dmsplot_dir}/log_reads_per_variant.svg", dpi=600)

### Share of reads in each sorted fraction per A5 variant (heatmap-of-heatmaps)

In [None]:
# Heatmap-of-heatmaps for A5. Each variant cell is a 2x2 tile holding the share of its
# reads, pooled over replicates, in each sorted fraction (0-based fraction 2 * r + c at tile[r, c]).
target = sum(np.stack([res[0] for res in rep], axis=-1) for rep in results_a5_ds3)
target = np.swapaxes(target, 0, 1)
wt_target = sum(np.stack([res[1] for res in rep], axis=-1) for rep in results_a5_ds3)
# Replace wild-type cells with the pooled wild-type counts.
target = np.where((np.array(list(gt_a5))[None, :] == np.array(list(AA_CODE))[:, None])[..., None], wt_target, target)
target = target.reshape(*target.shape[:2], 2, 2)
# Normalise each tile to 1; cells without reads become 0/0 = nan.
target /= target.sum(axis=(-1, -2), keepdims=True)
# (21, n_positions, r, c) -> (21, r, n_positions, c) -> one (21 * 2, n_positions * 2) image.
target = np.moveaxis(target, -2, 1)
target = np.moveaxis(target, -1, 3)
target = target.reshape(21 * 2, (len(gt_a5)) * 2)
print(target.shape)
fig, ax = plt.subplots(figsize=(32, 4), dpi=600)
cx = ax.matshow(target, cmap="magma")
ax.set_xticks([(i * 10 - 1) * 2 + 0.5 if i != 0 else 0.5 for i in range(15)])
ax.set_xticklabels([i * 10 if i != 0 else 1 for i in range(15)])
ax.set_yticks([2 * i + 0.5 for i in range(21)])
ax.set_yticklabels(AA_CODE)
# White grid lines separating the 2x2 tiles.
ax.hlines(2 * np.arange(21) - 0.5, -0.5, len(gt_a5) * 2 - 0.5, color="w")
ax.vlines(2 * np.arange(len(gt_a5)) - 0.5, -0.5, 21 * 2 - 0.5, color="w")
ax.set_title("AcrIIA5 reads-per-fraction heatmap", fontsize=14)
ax.set_xlabel("sequence position", fontsize=12)
ax.set_ylabel("variant", fontsize=12)
plt.colorbar(cx)
# Save with the other A5 plots.
plt.savefig(f"{a5_dmsplot_dir}/log_reads_per_variant_fraction_a5.svg")

## A5 overview fraction distribution

In [None]:
from matplotlib.backends.backend_pdf import PdfPages

# Distribution of each A5 variant's reads over the sorted fractions: mean ± sample std over
# replicates. Written as a multi-page PDF with 10 positions x 21 residues per page.
n_rep_a5 = len(relative_a5_ds3)
mean_dist_a5 = sum(rep[0] for rep in relative_a5_ds3) / n_rep_a5
mean_dist_a5_wt = sum(rep[1] for rep in relative_a5_ds3) / n_rep_a5
mean_dist_a5[:, np.arange(len(index_a5)), index_a5] = mean_dist_a5_wt[:, None]
sample_norm_a5 = max(n_rep_a5 - 1, 1)
std_dist_a5 = np.sqrt(sum((rep[0] - mean_dist_a5) ** 2 for rep in relative_a5_ds3) / sample_norm_a5)
std_dist_a5_wt = np.sqrt(sum((rep[1] - mean_dist_a5_wt) ** 2 for rep in relative_a5_ds3) / sample_norm_a5)
std_dist_a5[:, np.arange(len(index_a5)), index_a5] = std_dist_a5_wt[:, None]
n_frac_a5 = mean_dist_a5.shape[0]
n_positions_a5 = len(gt_a5)
overview_dir = "outputs/overview_figures"
os.makedirs(overview_dir, exist_ok=True)
with PdfPages(f"{overview_dir}/overview_fractions_a5_pages.pdf") as pdf:
    for page_start in range(0, n_positions_a5, 10):
        size = min(10, n_positions_a5 - page_start)
        # squeeze=False keeps `ax` 2-D even when the last page has a single row.
        fig, ax = plt.subplots(size, 21, figsize=(21 * 4, size * 4), squeeze=False)
        for row in range(size):
            pos = page_start + row
            ax[row, 0].set_ylabel("fraction of reads")
            for aa in range(21):
                title = "wt" if gt_a5[pos] == AA_CODE[aa] else f"{gt_a5[pos]}{pos + 1}{AA_CODE[aa]}"
                ax[row, aa].set_title(title)
                ax[row, aa].bar(np.arange(n_frac_a5), mean_dist_a5[:, pos, aa], width=0.5, yerr=std_dist_a5[:, pos, aa])
                ax[row, aa].set_ylim(0.0, 1.0)
                ax[row, aa].set_xticks(range(n_frac_a5))
                ax[row, aa].set_xticklabels(range(1, n_frac_a5 + 1))
        for aa in range(21):
            ax[-1, aa].set_xlabel("sorted fraction ID")
        pdf.savefig(fig)
        # Each page is a 10 x 21 grid; close it so pyplot does not keep every page alive.
        plt.close(fig)

# Fit FACS data

In [None]:
import FlowCal
import os
import numpy as np
from torch.distributions import MixtureSameFamily, Normal
import torch
import torch.nn as nn
from sklearn.neighbors import KernelDensity

In [None]:
def make_dataset(path, relative, start, end, count, replicates=(0,)):
    """Pair measured FACS fluorescence histograms with DMS sorted-fraction profiles.

    Files in ``path`` are processed in sorted name order. Each name must look like
    ``<prefix>_<variant>...`` where ``<variant>`` is e.g. ``K12A`` (wild-type residue,
    1-based position, substituted residue). For each file:

    - take the natural log of the last FCS channel, keeping events with value >= 1;
    - assign every event to the nearest of ``count`` evenly spaced bin centres
      between ``start`` and ``end``;
    - normalise the resulting histogram to sum to 1.

    Args:
        path: directory with the FCS files.
        relative: per DMS replicate, ``(variant_stack, wt_stack)`` from relative_results.
        start, end, count: histogram range and number of bins.
        replicates: DMS replicates paired with every file. Each file contributes
            ``len(replicates)`` consecutive rows.

    Returns:
        ``(vals, targets, names)``:

        - vals: float tensor ``(rows, n_fractions)``.
        - targets: float tensor ``(rows, count)``.
        - names: the variant name of each row.
    """
    # Bin centres are identical for every file.
    bins = start + np.arange(0, count) * (end - start) / count + (end - start) / count / 2
    targets = []
    vals = []
    names = []
    for p in sorted(os.listdir(path)):
        kind = p.split(".")[0].split("_")[1]
        changed = AA_CODE.index(kind[-1])
        position = int(kind[1:-1])
        data = FlowCal.io.FCSData(f"{path}/{p}")
        signal = np.log(data[:, -1][data[:, -1] >= 1])
        nearest, counts = np.unique(np.argmin(abs(signal[:, None] - bins[None, :]), axis=-1), return_counts=True)
        target = torch.zeros(count)
        target[nearest] = torch.tensor(counts).float()
        # NOTE(review): a file without any event >= 1 gives 0/0 = nan here.
        target = target / target.sum()
        targets += len(replicates) * [target[None]]
        names += len(replicates) * [kind]
        vals += [torch.tensor(relative[rep][0][:, position - 1, changed])[None] for rep in replicates]
    vals = torch.cat(vals, dim=0).float()
    targets = torch.cat(targets, dim=0).float()
    return vals, targets, names
    
def fit_facs(relative, path, frac=4, start=4, end=12, count=20, steps=100000, mode="linear", drop=None):
    """Fit a bias-free linear map from sorted-fraction profiles to FACS histograms.

    ``frac`` relative read frequencies are mapped to ``count`` logits. The
    log-softmax of these logits is scored against the measured histogram by
    cross-entropy. Training starts from zero weights and runs full-batch AdamW
    (lr 1e-4) for ``steps`` steps.

    Args:
        relative: per DMS replicate ``(variant_stack, wt_stack)``; replicates 0, 1, 2 are used.
        path: directory with the FCS files (see make_dataset).
        frac, start, end, count: model input size and histogram binning.
        steps: number of optimisation steps.
        mode: unused; kept so existing calls keep working.
        drop: optional row index of the make_dataset output to exclude from training.

    Returns:
        the trained ``torch.nn.Linear``.
    """
    linear = torch.nn.Linear(frac, count, bias=False)
    with torch.no_grad():
        linear.weight.zero_()
    optimizer = torch.optim.AdamW(linear.parameters(), lr=1e-4)
    vals, targets, _ = make_dataset(path, relative, start, end, count, replicates=(0, 1, 2))
    print(vals.shape, targets.shape)
    indices = torch.arange(0, targets.size(0), dtype=torch.long)
    if drop is not None:
        indices = indices[indices != drop]
    targets = targets[indices]
    vals = vals[indices]
    # Loop-invariant training subset: rows in blocks of 6, every 4th block left out.
    # NOTE(review): the purpose of this split, and that it is applied after `drop`
    # shifts the rows, is not evident here — confirm.
    batch = torch.arange(targets.shape[0]) // 6 % 4 != 0
    v = vals.reshape(-1, vals.shape[-1])[batch]
    t = targets.reshape(-1, targets.shape[-1])[batch]
    bin_index = torch.arange(count)
    for idx in range(steps):
        optimizer.zero_grad()
        out = linear(v).log_softmax(dim=1)
        cross_entropy = -(out * t).sum(axis=-1).mean()
        if idx % 500 == 0:
            # Progress metric only; it does not enter the gradient.
            mean_error = abs(((out.exp() - t) * bin_index).sum(axis=-1)).mean()
            print(float(cross_entropy), float(mean_error), end="\r", flush=True)
        cross_entropy.backward()
        optimizer.step()
    return linear

def eval_facs(linears, relative, path, frac=4, start=4, end=12, count=20, steps=100000):
    """Score each cross-validation model on one row of the replicate-0 dataset.

    Model ``linears[idx]`` is evaluated on row ``idx`` of
    ``make_dataset(..., replicates=(0,))``, which is FCS file ``idx`` in sorted order.
    NOTE(review): ``fit_facs(drop=idx)`` removes row ``idx`` of the three-replicate
    dataset, where rows 3f..3f+2 come from file f. The left-out row and the evaluated
    row therefore need not be the same item — confirm the intended pairing.

    Returns:
        ``(errors, predictions, targets)``:

        - errors: per model, the mean absolute difference between predicted and
          measured histogram.
        - predictions: per model, the predicted histogram.
        - targets: all target histograms.

        ``frac`` and ``steps`` are unused.
    """
    vals, targets, _ = make_dataset(path, relative, start, end, count, replicates=(0,))
    errors = []
    predictions = []
    with torch.no_grad():
        for idx, reg in enumerate(linears):
            # Only row idx is needed: evaluate it once, not the full batch twice.
            prediction = reg(vals[idx]).softmax(dim=0)
            errors.append(abs(prediction - targets[idx]).mean())
            predictions.append(prediction)
    return errors, predictions, targets
        
def predict_facs(linear, data, frac=None):
    """Predict FACS histograms for every variant cell of every replicate.

    Args:
        linear: fitted model (e.g. from fit_facs) mapping fraction profiles to logits.
        data: per replicate ``(variant_stack, wt_stack)``, where ``variant_stack`` has
            shape ``(n_fractions, n_positions, n_aa)``.
        frac: number of input fractions. ``None`` (default) infers it from the data;
            the former default of 8 did not match the 4-fraction data used here.

    Returns:
        tensor ``(n_replicates, n_positions, n_aa, n_bins)`` of softmax probabilities.
    """
    results = []
    for rep in data:
        inputs = torch.tensor(rep[0]).float().permute(1, 2, 0)
        n_in = inputs.shape[-1] if frac is None else frac
        pred = linear(inputs.reshape(-1, n_in))
        pred = pred.view(*inputs.shape[:2], pred.size(1))
        results.append(pred.softmax(dim=-1)[None])
    return torch.cat(results, dim=0)

def predict_wt(linear, data, frac=None):
    """Predict the wild-type FACS histogram for every replicate.

    Args:
        linear: fitted model mapping fraction profiles to logits.
        data: per replicate ``(variant_stack, wt_stack)``; ``wt_stack`` is the wild
            type's per-fraction profile.
        frac: number of input fractions. ``None`` (default) infers it from the data.

    Returns:
        tensor ``(n_replicates, n_bins)`` of softmax probabilities.
    """
    results = []
    for rep in data:
        inputs = torch.tensor(rep[1]).float()[None]
        n_in = inputs.shape[-1] if frac is None else frac
        results.append(linear(inputs.reshape(-1, n_in)).softmax(dim=-1))
    return torch.cat(results, dim=0)

In [None]:
# Each fit runs `steps` (default 100000) full-batch optimisation steps. If the backup written by
# the "save parameters" cell exists, reuse its models instead of refitting: that cell would
# replace freshly fitted models with the backup anyway.
backup_path = f"{base_dir}/backup.torch"
if os.path.isfile(backup_path):
    # weights_only=False: the backup pickles whole nn.Linear modules (PyTorch >= 2.6 defaults
    # to weights_only=True). Only load a backup written by this notebook.
    backup = torch.load(backup_path, weights_only=False)
    linear_a4_ds3 = backup["linear_a4_ds3"]
    linear_a5_ds3 = backup["linear_a5_ds3"]
else:
    linear_a4_ds3 = fit_facs(relative_a4_ds3, f"{base_dir}/facs/A4/", mode="linear", start=0, end=14, count=20, drop=None)
    linear_a5_ds3 = fit_facs(relative_a5_ds3, f"{base_dir}/facs/A5/", mode="linear", start=0, end=14, count=20, drop=None)

In [None]:
# Softmax of the A4 model weights over the histogram-bin axis (dim 0 of the (count, frac)
# weight matrix), shown with one row per input fraction.
a4_weights = linear_a4_ds3.weight.detach().softmax(dim=0).numpy()
image = plt.matshow(a4_weights.T)
plt.colorbar(image)
linear_dir = f"{base_dir}/outputs/linear_model/"
os.makedirs(linear_dir, exist_ok=True)
plt.savefig(f"{linear_dir}/a4_weight.svg")

## Fit cross-validation

In [None]:
# 16 cross-validation fits per protein, each with one row index dropped (see fit_facs `drop`).
# Reuse the stored models when the backup from the "save parameters" cell exists.
backup_path = f"{base_dir}/backup.torch"
if os.path.isfile(backup_path):
    # weights_only=False: the backup pickles whole nn.Linear modules. Only load our own file.
    backup = torch.load(backup_path, weights_only=False)
    linears_a4_ds3 = backup["linears_a4_ds3"]
    linears_a5_ds3 = backup["linears_a5_ds3"]
else:
    linears_a4_ds3 = [
        fit_facs(relative_a4_ds3, f"{base_dir}/facs/A4/", mode="linear", start=0, end=14, count=20, drop=idx)
        for idx in range(16)
    ]
    linears_a5_ds3 = [
        fit_facs(relative_a5_ds3, f"{base_dir}/facs/A5/", mode="linear", start=0, end=14, count=20, drop=idx)
        for idx in range(16)
    ]

#### Save (or reload) fitted model parameters

In [None]:
import torch
# Store all fitted models on the first run; on later runs, load them back.
backup_path = f"{base_dir}/backup.torch"
if not os.path.isfile(backup_path):
    torch.save(dict(
        linear_a4_ds3=linear_a4_ds3,
        linear_a5_ds3=linear_a5_ds3,
        linears_a4_ds3=linears_a4_ds3,
        linears_a5_ds3=linears_a5_ds3
    ), backup_path)
else:
    # The backup pickles whole nn.Linear modules, which torch.load rejects under its
    # weights_only=True default (PyTorch >= 2.6). Unpickling runs arbitrary code:
    # only load a backup written by this notebook.
    backup = torch.load(backup_path, weights_only=False)
    linear_a4_ds3 = backup["linear_a4_ds3"]
    linear_a5_ds3 = backup["linear_a5_ds3"]
    linears_a4_ds3 = backup["linears_a4_ds3"]
    linears_a5_ds3 = backup["linears_a5_ds3"]

### cross-validation error

In [None]:
# Score every cross-validation model on row idx of the replicate-0 dataset (see eval_facs).
facs_eval_kwargs = dict(count=20, frac=4, start=0, end=14)
errors_a4, cpredictions_a4, ctargets_a4 = eval_facs(
    linears_a4_ds3, relative_a4_ds3, f"{base_dir}/facs/A4/", **facs_eval_kwargs
)
errors_a5, cpredictions_a5, ctargets_a5 = eval_facs(
    linears_a5_ds3, relative_a5_ds3, f"{base_dir}/facs/A5/", **facs_eval_kwargs
)

### plot cross-validation error for variants

In [None]:
def _rep1_variant_names(directory):
    """Variant names of ``<prefix>_<variant>_<rep>.*`` files with rep "1", in sorted file order."""
    names = []
    for p in sorted(os.listdir(directory)):
        fields = p.split(".")[0].split("_")
        if fields[2] == "1":
            names.append(fields[1])
    return names

# Sorted like make_dataset, so the order matches the rows behind errors/predictions.
# NOTE(review): make_dataset keeps files of every FACS replicate, while these names only
# come from rep "1" files. If a variant has several FACS replicates, the two orderings
# diverge — confirm.
variant_names_a4 = _rep1_variant_names(f"{base_dir}/facs/A4/")
variant_names_a5 = _rep1_variant_names(f"{base_dir}/facs/A5/")

In [None]:
variant_names = dict(a4=variant_names_a4, a5=variant_names_a5)
print(variant_names)
ctargets = dict(a4=ctargets_a4, a5=ctargets_a5)
cpredictions = dict(a4=cpredictions_a4, a5=cpredictions_a5)
errors = dict(a4=errors_a4, a5=errors_a5)
# Histogram bin centres on the log-fluorescence axis (start=0, end=14, count=20, as in eval_facs).
bins = 0.0 + np.arange(0, 20) * 14.0 / 20 + 14.0 / 20 / 2
for acrkind in ("a4", "a5"):
    for idx, name in enumerate(variant_names[acrkind][:16]):
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.set_title(name)
        ax.fill_between(bins, ctargets[acrkind][idx], alpha=0.5, label="measured")
        ax.fill_between(bins, cpredictions[acrkind][idx], alpha=0.5, label="predicted")
        ax.set_ylim(0.0, 0.5)
        ax.set_xlabel("log fluorescence intensity")
        ax.set_ylabel("fraction of events")
        ax.legend()
        plt.tight_layout()
        plt.savefig(f"{base_dir}/error_{acrkind}_{name}.svg")
        # One figure per variant; close it so they do not accumulate.
        plt.close(fig)

### plot mean cross-validation error across all variants

In [None]:
# Mean absolute histogram error of each cross-validation model.
for acrkind in ("a4", "a5"):
    n_models = len(errors[acrkind])
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(list(range(n_models)), [float(e) for e in errors[acrkind]])
    ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
    ax.set_xticks(range(n_models))
    print(variant_names[acrkind])
    ax.set_xticklabels(variant_names[acrkind][:n_models])
    ax.set_ylabel("mean absolute error")
    ax.set_ylim(0.0, 0.1)
    plt.tight_layout()
    plt.savefig(f"{base_dir}/leave_one_out_{acrkind}.svg")
    plt.close(fig)

## Predict FACS

### predict A4 & A5 FACS profiles

In [None]:
# Predicted FACS histograms for every variant cell and for the wild type, per DMS replicate.
pf_rel_a4 = predict_facs(linear_a4_ds3.eval(), relative_a4_ds3, frac=4).detach().numpy()
pf_wt_a4 = predict_wt(linear_a4_ds3.eval(), relative_a4_ds3, frac=4).detach().numpy()
pf_rel_a5 = predict_facs(linear_a5_ds3.eval(), relative_a5_ds3, frac=4).detach().numpy()
pf_wt_a5 = predict_wt(linear_a5_ds3.eval(), relative_a5_ds3, frac=4).detach().numpy()

# Bin centres on the log-fluorescence axis (0..14), used as tick labels by plot_overview.
n_bins = pf_rel_a4.shape[-1]
conversion_factor = (np.arange(n_bins) / n_bins + 1 / n_bins / 2) * 14.0

In [None]:
def plot_overview(path, rel, wt, index, gt):
    """Write a multi-page PDF with the predicted bin distribution of every variant.

    Each page holds up to 10 positions (rows) x 21 amino acids (columns). Each
    panel is a bar chart of the mean predicted probability per fluorescence bin,
    with the standard deviation over the first axis of `rel` as error bars. The
    wild-type residue of each position shows the wild-type prediction instead
    and is titled "wt".

    Args:
        path: output PDF path; its parent directory is created if missing.
        rel: variant predictions indexed [first axis, position, amino acid, bin]
            (NOTE(review): first axis presumably replicates/samples — confirm).
        wt: wild-type predictions indexed [first axis, bin].
        index: per-position index (on the amino-acid axis) of the wild-type residue.
        gt: wild-type sequence, one letter per position.
    """
    from matplotlib.backends.backend_pdf import PdfPages

    rows_per_page = 10
    n_aa = 21
    n_bins = rel.shape[-1]
    n_pos = rel.shape[1]

    # Mean/std over the first axis; each position's wild-type cell is replaced
    # by the wild-type prediction.
    positions = np.arange(len(index))
    mean = rel.mean(axis=0)
    mean[positions, index, :] = wt.mean(axis=0)[None, :]
    std = rel.std(axis=0)
    std[positions, index, :] = wt.std(axis=0)[None, :]

    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    bin_ticks = np.arange(0, n_bins, 5)
    bin_tick_labels = [f"{i:.2f}" for i in conversion_factor[::5]]

    with PdfPages(path) as pdf:
        for page_start in range(0, n_pos, rows_per_page):
            n_rows = min(rows_per_page, n_pos - page_start)
            # squeeze=False keeps `ax` 2-D even when the last page has a single row.
            fig, ax = plt.subplots(n_rows, n_aa, figsize=(n_aa * 4, n_rows * 4), squeeze=False)
            for row in range(n_rows):
                pos = page_start + row
                ax[row, 0].set_ylabel("predicted probability")
                for aa in range(n_aa):
                    panel = ax[row, aa]
                    # The wild-type residue at this position is titled "wt".
                    if gt[pos] == AA_CODE[aa]:
                        panel.set_title("wt")
                    else:
                        panel.set_title(f"{gt[pos]}{pos + 1}{AA_CODE[aa]}")
                    panel.bar(np.arange(n_bins), mean[pos, aa, :], width=0.5, yerr=std[pos, aa, :])
                    panel.set_ylim(0.0, 1.0)
                    panel.set_xticks(bin_ticks)
                    panel.set_xticklabels(bin_tick_labels)
            for aa in range(n_aa):
                ax[-1, aa].set_xlabel("log fluorescence intensity")
            pdf.savefig(fig)
            # Pages are very large (84 x up to 40 inches); free each once it is written.
            plt.close(fig)
plot_overview("outputs/overview_figures/overview_predictions_a4_pages.pdf", pf_rel_a4, pf_wt_a4, index_a4, gt_a4)
plot_overview("outputs/overview_figures/overview_predictions_a5_pages.pdf", pf_rel_a5, pf_wt_a5, index_a5, gt_a5)

In [None]:
# Export mean_pred_mean_a4 / mean_pred_std_a4 as CSV: a "wt" row (taken from the
# wild-type residue at the first position), then one row per non-wild-type
# (position, amino acid) pair named like "M1A" (1-based position).
with open("a4_mean_prediction.csv", "wt") as f:
    f.write(f"variant_name,value,std\nwt,{mean_pred_mean_a4[0, AA_CODE.index(gt_a4[0])]},{mean_pred_std_a4[0, AA_CODE.index(gt_a4[0])]}\n")
    f.writelines(
        f"{gt_a4[pos]}{pos + 1}{AA_CODE[aa]},{mean_pred_mean_a4[pos, aa]},{mean_pred_std_a4[pos, aa]}\n"
        for pos in range(mean_pred_mean_a4.shape[0])
        for aa in range(21)
        if AA_CODE[aa] != gt_a4[pos]
    )

In [None]:
# Export mean_pred_mean_a5 / mean_pred_std_a5 as CSV: a "wt" row (taken from the
# wild-type residue at the first position), then one row per non-wild-type
# (position, amino acid) pair named like "M1A" (1-based position).
with open("a5_mean_prediction.csv", "wt") as f:
    f.write(f"variant_name,value,std\nwt,{mean_pred_mean_a5[0, AA_CODE.index(gt_a5[0])]},{mean_pred_std_a5[0, AA_CODE.index(gt_a5[0])]}\n")
    f.writelines(
        f"{gt_a5[pos]}{pos + 1}{AA_CODE[aa]},{mean_pred_mean_a5[pos, aa]},{mean_pred_std_a5[pos, aa]}\n"
        for pos in range(mean_pred_mean_a5.shape[0])
        for aa in range(21)
        if AA_CODE[aa] != gt_a5[pos]
    )

#### Mutation tolerance A4

In [None]:
# Mutation tolerance: fraction of A4 single mutants (wild-type residues excluded)
# whose predicted mean log fluorescence reaches `threshold` x the wild-type value.
# Wild-type cells are excluded through a boolean mask rather than by writing NaN
# into `pf_rel_a4_mean`, so later cells that reuse it see unmodified values.
pf_rel_a4_mean = (pf_rel_a4.mean(axis=0) * conversion_factor).sum(axis=-1)
pf_wt_a4_mean = (pf_wt_a4.mean(axis=0) * conversion_factor).sum(axis=-1)
wt_mask = np.zeros(pf_rel_a4_mean.shape, dtype=bool)
wt_mask[np.arange(len(gt_a4)), [AA_CODE.index(aa) for aa in gt_a4]] = True
mutant_means_a4 = pf_rel_a4_mean[~wt_mask]
for threshold in [0.9, 0.95, 1.0]:
    print(threshold, (mutant_means_a4 >= pf_wt_a4_mean * threshold).sum() / mutant_means_a4.size)

In [None]:
# Mutation tolerance: fraction of A5 single mutants (wild-type residues excluded)
# whose predicted mean log fluorescence reaches `threshold` x the wild-type value.
# Wild-type cells are excluded through a boolean mask rather than by writing NaN
# into `pf_rel_a5_mean`, so later cells that reuse it see unmodified values.
pf_rel_a5_mean = (pf_rel_a5.mean(axis=0) * conversion_factor).sum(axis=-1)
pf_wt_a5_mean = (pf_wt_a5.mean(axis=0) * conversion_factor).sum(axis=-1)
wt_mask = np.zeros(pf_rel_a5_mean.shape, dtype=bool)
wt_mask[np.arange(len(gt_a5)), [AA_CODE.index(aa) for aa in gt_a5]] = True
mutant_means_a5 = pf_rel_a5_mean[~wt_mask]
for threshold in [0.9, 0.95, 1.0]:
    print(threshold, (mutant_means_a5 >= pf_wt_a5_mean * threshold).sum() / mutant_means_a5.size)

#### Correlation with Kd

In [None]:
# correlations
kd_variants = dict(WT=8.31, G38C=14.8, N25G=8.69, E70T=112, E70D=24.5, Y67K=332000, M77A=76)
mean_bin_variants = dict(WT=pf_wt_a4_mean)
for name in kd_variants:
    if name not in mean_bin_variants:
        pos = int(name[1:-1]) - 1
        val = AA_CODE.index(name[-1])
        mean_bin_variants[name] = pf_rel_a4_mean[pos, val]

In [None]:
# Scatter predicted mean log fluorescence against 1 / Kd, one annotated point per variant.
sorted_names = sorted(kd_variants)
kd_variants_rv = [1 / kd_variants[variant] for variant in sorted_names]
mean_bin_variants_v = [mean_bin_variants[variant] for variant in sorted_names]
fig, ax = plt.subplots()
ax.scatter(mean_bin_variants_v, kd_variants_rv)
for label, x_val, y_val in zip(sorted_names, mean_bin_variants_v, kd_variants_rv):
    ax.annotate(label, (x_val, y_val))
ax.set_xlabel("mean log fluorescence intensity")
ax.set_ylabel("1 / Kd")
fig.savefig(f"{base_dir}/outputs/Kd_correlation_a4.svg")

In [None]:
# waterfall plot
COLORS=["teal", "orange", "purple", "red", "green", "grey", "brown"]
def waterfall(path, rel, wt, sequence, names, ymin=0.0, ymax=9.0):
    fig, ax = plt.subplots(figsize=(8, 2))
    rel = (rel.mean(axis=0) * conversion_factor).sum(axis=-1)
    wt = (wt.mean(axis=0) * conversion_factor).sum(axis=-1)
    wt_mask = np.zeros_like(rel)
    for idx, aa in enumerate(sequence):
        aa = AA_CODE.index(aa)
        wt_mask[idx, aa] = 1
    wt_mask = wt_mask > 0
    rel[wt_mask] = np.nan
    rel[wt_mask.nonzero()[0][0]] = wt
    index = np.argsort(rel.reshape(-1))
    revindex = np.arange(index.shape[0])
    revindex[index] = np.arange(index.shape[0])
    plotrel = rel.reshape(-1)[index]
    for idy, name in enumerate(names):
        if name == "WT":
            pos = 0
            aa = AA_CODE.index(sequence[pos])
            ii = np.ravel_multi_index([pos, aa], rel.shape)
            val = rel.reshape(-1)[ii]
            val2 = rel[pos, aa]
            ax.vlines(revindex[ii], 0.0, val, label=name, color="black")
        else:
            pos = int(name[1:-1]) - 1
            aa = AA_CODE.index(name[-1])
            ii = np.ravel_multi_index([pos, aa], rel.shape)
            val = rel.reshape(-1)[ii]
            val2 = rel[pos, aa]
            ax.vlines(revindex[ii], 0.0, val, label=name, color=COLORS[idy])
    ax.fill_between(np.arange(plotrel.shape[0]), plotrel, np.zeros_like(plotrel), color="#DD8888")
    ax.legend()
    ax.set_ylim(ymin, ymax)
waterfall(None, pf_rel_a4, pf_wt_a4, gt_a4, kd_variants)
plt.savefig(f"{base_dir}/outputs/waterfall_a4.svg", dpi=600)
waterfall(None, pf_rel_a5, pf_wt_a5, gt_a5, ["WT", "G3I", "R13F", "K58M", "D69A", "E76L", "G3W"], ymin=4, ymax=7)
plt.savefig(f"{base_dir}/outputs/waterfall_a5.svg", dpi=600)

#### Overview plot

In [None]:
import copy
def dot_heatmap(predictions, predictions_wt, gt, path, scale=0.25):
    """Dot heatmap of predicted mean log fluorescence per position and amino acid.

    One dot per (position, amino acid) for the first 20 amino-acid rows: x is the
    position, y the amino-acid row (top to bottom). Dot colour is the predicted
    mean log fluorescence on a blue-white-red scale symmetric around the
    wild-type value. Dot area is 90 * (1 - std / max_std) ** 4, so it shrinks as
    the spread over the first axis of `predictions` grows. Wild-type residues get
    the largest dot and the sentinel colour value 100, drawn black when it
    exceeds the colour range.

    Args:
        predictions: variant predictions indexed [first axis, position, amino acid, bin]
            (NOTE(review): first axis presumably replicates/samples — confirm).
        predictions_wt: wild-type predictions indexed [first axis, bin].
        gt: wild-type sequence, one letter per position.
        path: output file passed to `fig.savefig`.
        scale: currently unused; each cell is a fixed 0.25 inch.
    """
    import copy

    # Per first-axis entry: predicted mean log fluorescence of each variant minus wild type.
    delta = ((predictions - predictions_wt[:, None, None]) * conversion_factor).sum(axis=-1)
    pf_mean, pf_std = delta.mean(axis=0), delta.std(axis=0)
    std_max = pf_std.max()
    std_min = 0
    if std_max > std_min:
        std_norm = (pf_std - std_min) / (std_max - std_min)
    else:
        # All entries agree exactly: treat every prediction as fully confident
        # instead of dividing by zero.
        std_norm = np.zeros_like(pf_std)
    print(std_max, std_min)
    conf_norm = 1 - std_norm
    wt = (predictions_wt * conversion_factor).sum(axis=-1).mean(axis=0)
    print(wt)
    # Colour range symmetric around the wild-type value.
    half_range = max(-pf_mean.min(), pf_mean.max())
    pmin = wt - half_range
    pmax = wt + half_range

    cell = 0.25
    n_rows = 20
    width = predictions.shape[1]
    fig, ax = plt.subplots(1, 1, figsize=(width * cell, n_rows * cell), sharex=False, sharey=False)

    # Row-major grid: amino-acid row outer, position inner.
    aa_grid, pos_grid = np.meshgrid(np.arange(n_rows), np.arange(width), indexing="ij")
    wt_rows = np.array([AA_CODE.index(residue) for residue in gt[:width]])
    is_wt = aa_grid == wt_rows[None, :]
    colours = np.where(is_wt, 100, pf_mean[:, :n_rows].T + wt)
    confidence = np.where(is_wt, 1.0, conf_norm[:, :n_rows].T ** 2)
    sizes = 0.9 * confidence ** 2 * 100

    # Copy the colormap so set_over does not mutate matplotlib's shared `bwr`.
    cmap = copy.copy(cm.bwr)
    cmap.set_over("black")
    ca = ax.scatter(pos_grid.ravel(), -aa_grid.ravel(), c=colours.ravel(), s=sizes.ravel(),
                    cmap=cmap, vmin=pmin, vmax=pmax, edgecolors="black", linewidth=0.6)
    ax.axis("off")
    fig.colorbar(ca, ax=ax)
    fig.savefig(path)

def square_heatmap(rel, wt, gt, path, scale=4):
    """Square-cell heatmap of predicted mean log fluorescence (amino acid x position).

    Cell colour is the predicted mean log fluorescence on a blue-white-red scale
    symmetric around the wild-type value. Wild-type cells get the sentinel value
    100, which is drawn in the "over" colour (black) when it exceeds the range.

    Args:
        rel: variant predictions indexed [first axis, position, amino acid, bin].
        wt: wild-type predictions indexed [first axis, bin].
        gt: wild-type sequence, one letter per position.
        path: output file passed to `fig.savefig`.
        scale: currently unused; each cell is a fixed 0.25 inch.
    """
    import copy

    # Mean over the first axis of (variant - wild type) predicted mean log fluorescence.
    pf_mean = ((rel - wt[:, None, None]) * conversion_factor).sum(axis=-1).mean(axis=0)
    wt_value = (wt * conversion_factor).sum(axis=-1).mean(axis=0)
    half_range = max(-pf_mean.min(), pf_mean.max())
    pmin = wt_value - half_range
    pmax = wt_value + half_range
    gt_mask = np.zeros(pf_mean.shape, dtype=bool)
    gt_mask[np.arange(pf_mean.shape[0]), [AA_CODE.index(aa) for aa in gt]] = True
    pf_mean = np.where(gt_mask, 100, pf_mean + wt_value)

    cell = 0.25
    width = rel.shape[1]
    fig, ax = plt.subplots(1, 1, figsize=(width * cell, 20 * cell), sharex=False, sharey=False)
    # Copy the colormap so set_over does not mutate matplotlib's shared `bwr`.
    cmap = copy.copy(cm.bwr)
    cmap.set_over("black")
    cax = ax.matshow(pf_mean.T, cmap=cmap, vmin=pmin, vmax=pmax)
    # Ticks at positions 1, 10, 20, ... (labels are 1-based positions).
    tick_count = len(gt) // 10 + 1
    ax.set_xticks([i * 10 - 1 if i != 0 else 0 for i in range(tick_count)])
    ax.set_xticklabels([i * 10 if i != 0 else 1 for i in range(tick_count)])
    ax.set_yticks(np.arange(21))
    ax.set_yticklabels(AA_CODE)
    fig.colorbar(cax, ax=ax)
    fig.savefig(path)

# Render both heatmap styles for A4.
dot_heatmap(
    predictions=pf_rel_a4,
    predictions_wt=pf_wt_a4,
    gt=gt_a4,
    path=f"{base_dir}/outputs/large-prediction-dotheatmap-A4-final.svg",
    scale=4,
)
square_heatmap(
    rel=pf_rel_a4,
    wt=pf_wt_a4,
    gt=gt_a4,
    path=f"{base_dir}/outputs/large-prediction-squareheatmap-A4-final.svg",
    scale=4,
)

#### amino acid type overview

In [None]:
# position-wise violin plots
cas_positions = [13, 16, 17, 35, 37, 38, 39, 66, 68, 69]
non_cas_positions = [idx for idx in range(87) if idx not in cas_positions]
surface_positions = [0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 51, 52, 53, 55, 56, 57, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 77, 78, 79, 80, 81] # TODO
core_positions = [5, 9, 16, 29, 31, 33, 50, 54, 58, 72, 76]
polar = "RNDCEQHKSTWY"
non_polar = "AGILMFPV"
basic = "RHK"
acidic = "DE"
neutral = "ANCQGILMFPSTWYV"
polar_positions = [idx for idx, c in enumerate(gt_a4) if c in polar] # TODO
nonpolar_positions = [idx for idx, c in enumerate(gt_a4) if c in non_polar]
acidic_positions = [idx for idx, c in enumerate(gt_a4) if c in acidic]
basic_positions = [idx for idx, c in enumerate(gt_a4) if c in basic]
neutral_positions = [idx for idx, c in enumerate(gt_a4) if c in neutral]
helix_positions = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 50, 51, 52, 53, 54, 55, 56, 57, 58, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80]
sheet_positions = [28, 29, 30, 31, 32, 39, 40, 41, 42, 43]
loop_positions = [0, 1, 13, 14, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 33, 34, 35, 36, 37, 38, 44, 45, 46, 47, 48, 49, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 81]

In [None]:
# For each group of A4 positions, collect the predicted change in mean log
# fluorescence (variant minus wild type) over the first 20 amino-acid columns
# at every position of the group, in long format for the violin plot below.
print(pf_rel_a4.shape)
# Per-(position, amino acid) predicted mean log fluorescence, computed once
# rather than once per position.
rel_mean_a4 = (pf_rel_a4.mean(axis=0) * conversion_factor).sum(axis=-1)
# Wild-type mean log fluorescence over all bins (the previous `[:20]` sliced the
# bin axis and only worked because there are exactly 20 bins).
wt_mean_a4 = (pf_wt_a4.mean(axis=0) * conversion_factor).sum(axis=-1)
names = []
values = []
for name, group in {
    "spyCas9 contact": cas_positions,
    "non-contact": non_cas_positions,
    "surface": surface_positions,
    "core": core_positions,
    "structured": helix_positions + sheet_positions,
    "loop": loop_positions,
    "acidic": acidic_positions,
    "basic": basic_positions,
    "neutral": neutral_positions,
    "polar": polar_positions,
    "non-polar": nonpolar_positions,
}.items():
    items = [delta for p in group for delta in (rel_mean_a4[p, :20] - wt_mean_a4)]
    names.extend([name] * len(items))
    values.extend(items)

split_positions = dict(name=names, value=values)

In [None]:
from seaborn import violinplot
fig, ax = plt.subplots(figsize=(20, 5))
violinplot(x="name", y="value", width=1, data=split_positions, ax=ax, inner="quartiles")

# Restyle the inner lines in groups of three: the middle line of each group
# (presumably the median of a violin) solid, the other two dotted, all black.
for line_no, line in enumerate(ax.lines):
    line.set_linestyle("-" if line_no % 3 == 1 else ":")
    line.set_color("black")

fig.savefig(f"{base_dir}/outputs/acriia4_activity_distribution.svg")

#### Correlation of predicted mean fluorescence with cell-culture measurements

In [None]:
cell_culture = dict(
    wt=(0.958145856666667, 0.55904086),
    K18M=(0.82442277, 0.40304539),
    G21Q=(0.80654383, 0.32298049),
    S24K=(0.773889983333333, 0.375188933333333),
    S24P=(0.706929726666667, 0.30560112),
    N25G=(0.820051383333333, 0.376600463333333),
    I31Q=(0.83000416, 0.386209923333333),
    E40I=(0.731740993333333, 0.33765131),
    E70T=(0.069560613333333, 0.035723183333333),
    M77A=(0.174492683333333, 0.084132156666667)
)

# Compare predicted mean log fluorescence with column `kind` of the cell-culture
# tuples for the wild type and each listed mutant.
kind = 0
pf_rel_a4_mean = (pf_rel_a4.mean(axis=0) * conversion_factor).sum(axis=-1)
pf_wt_a4_mean = (pf_wt_a4.mean(axis=0) * conversion_factor).sum(axis=-1)
names = ["wt"] + [variant for variant in cell_culture if variant != "wt"]
cc = [cell_culture[variant][kind] for variant in names]
value = [pf_wt_a4_mean] + [
    pf_rel_a4_mean[int(variant[1:-1]) - 1, AA_CODE.index(variant[-1])]
    for variant in names[1:]
]

fig, ax = plt.subplots()
ax.scatter(value, cc)

for label, xy in zip(names, zip(value, cc)):
    ax.annotate(label, xy)

fig.savefig(f"{base_dir}/outputs/acriia4_fluorescence_vs_cell_culture.svg")

#### Correlation of predicted mean fluorescence with prior work

In [None]:
# Published A4 variant activities. Basgall et al.: negative log of drive activity
# (the 1e-3 offset avoids log(0) for the zero entries); Dong et al.: log activity.
drive_activity_basgall = [0.027, 0.074, 0.015, 0.015, 0.084, 0.344, 0.845, 0.020, 0.000, 0.015, 0.113, 0.885, 0.983, 0.000, 0.595, 0.000, 0.432]
activity_basgall = [-np.log(v + 1e-3) for v in drive_activity_basgall]
names_basgall = ["D14A", "D23R", "N36A", "D37A", "G38A", "N39A", "N39R", "E40A", "N48A", "D69A", "D69R", "E70A", "E70R", "E72A", "F73A", "D76A", "M77A"]
activity_dong = [np.log(v) for v in [1.067, 0.216, 0.043, 0.062, 0.056, 0.048, 0.038, 0.130, 0.022]]
names_dong = ["N12T", "D14R", "D23R", "N36Y", "G38A", "N39R", "E40R", "D69R", "E70R"]

# Predicted mean log fluorescence of each published variant (1-based positions in the names).
pred_log_basgall = [pf_rel_a4_mean[int(n[1:-1]) - 1, AA_CODE.index(n[-1])] for n in names_basgall]
pred_log_dong = [pf_rel_a4_mean[int(n[1:-1]) - 1, AA_CODE.index(n[-1])] for n in names_dong]

fig, ax = plt.subplots(1, 2, figsize=(15, 5))
for panel, preds, measured, labels in (
    (ax[0], pred_log_basgall, activity_basgall, names_basgall),
    (ax[1], pred_log_dong, activity_dong, names_dong),
):
    panel.scatter(preds, measured)
    for label, xy in zip(labels, zip(preds, measured)):
        panel.annotate(label, xy)

fig.savefig(f"{base_dir}/outputs/correlation_prior_work.svg")

#### R^2 for log-log fit

In [None]:
# Pearson correlation between predictions and published activities, reported as R^2 and R.
r_basgall = np.corrcoef(np.array(pred_log_basgall), np.array(activity_basgall))[0, 1]
r_dong = np.corrcoef(np.array(pred_log_dong), np.array(activity_dong))[0, 1]
print("R^2 Basgall et al.", r_basgall ** 2)
print("R^2 Dong et al.", r_dong ** 2)
print("R Basgall et al.", r_basgall)
print("R Dong et al.", r_dong)

### A5

#### Mutation tolerance A5

In [None]:
# Mutation tolerance for A5: fraction of single mutants (wild-type residues
# excluded, matching the A4 definition) whose predicted mean log fluorescence
# reaches `threshold` x the wild-type value. `pf_rel_a5_mean` itself stays
# unmasked because later cells use all of its entries.
pf_rel_a5_mean = (pf_rel_a5.mean(axis=0) * conversion_factor).sum(axis=-1)
pf_wt_a5_mean = (pf_wt_a5.mean(axis=0) * conversion_factor).sum(axis=-1)
wt_mask_a5 = np.zeros(pf_rel_a5_mean.shape, dtype=bool)
wt_mask_a5[np.arange(len(gt_a5)), [AA_CODE.index(aa) for aa in gt_a5]] = True
mutant_means_a5 = pf_rel_a5_mean[~wt_mask_a5]
for threshold in [0.9, 0.95, 1.0]:
    print(threshold, (mutant_means_a5 >= pf_wt_a5_mean * threshold).sum() / mutant_means_a5.size)

#### Overview plot

In [None]:
# Render both heatmap styles for A5.
dot_heatmap(
    predictions=pf_rel_a5,
    predictions_wt=pf_wt_a5,
    gt=gt_a5,
    path=f"{base_dir}/outputs/large-prediction-dotheatmap-A5-final.svg",
    scale=4,
)
square_heatmap(
    rel=pf_rel_a5,
    wt=pf_wt_a5,
    gt=gt_a5,
    path=f"{base_dir}/outputs/large-prediction-squareheatmap-A5-final.svg",
    scale=4,
)

#### amino acid type overview

In [None]:
# position-wise violin plots
surface_positions = [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 31, 33, 34, 37, 38, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 52, 54, 55, 56, 57, 58, 59, 61, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 85, 86, 87, 89, 92, 95, 96, 99, 100, 102, 103, 104, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 117, 118, 119, 120, 122, 123, 124, 126, 127, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139]
core_positions = [3, 7, 29, 32, 35, 36, 39, 50, 51, 53, 60, 62, 63, 80, 83, 84, 88, 90, 91, 93, 94, 97, 98, 101, 105, 116, 121, 125, 128, 131]
polar = "RNDCEQHKSTWY"
non_polar = "AGILMFPV"
basic = "RHK"
acidic = "DE"
neutral = "ANCQGILMFPSTWYV"
polar_positions = [idx for idx, c in enumerate(gt_a5) if c in polar]
nonpolar_positions = [idx for idx, c in enumerate(gt_a5) if c in non_polar]
acidic_positions = [idx for idx, c in enumerate(gt_a5) if c in acidic]
basic_positions = [idx for idx, c in enumerate(gt_a5) if c in basic]
neutral_positions = [idx for idx, c in enumerate(gt_a5) if c in neutral]
helix_positions = [4, 5, 6, 7, 8, 9, 10, 11, 12, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 110, 111, 112, 113, 114, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 135, 136, 137, 138]
sheet_positions = [42, 43, 44, 50, 51, 52, 53, 54, 59, 60, 61, 62, 63]
loop_positions = [0, 1, 2, 3, 13, 14, 15, 16, 17, 18, 39, 40, 41, 45, 46, 47, 48, 49, 55, 56, 57, 58, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 109, 115, 116, 117, 131, 132, 133, 134, 139]
idr_positions = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]
idr_2_positions = list(range(64, 79))

In [None]:
# For each group of A5 positions, collect the predicted change in mean log
# fluorescence (variant minus wild type) over the first 20 amino-acid columns
# at every position of the group, in long format for the violin plot below.
names = []
values = []
for name, group in {
    "surface": surface_positions,
    "core": core_positions,
    "structured": helix_positions + sheet_positions,
    "loop": loop_positions,
    "IDR-NT": idr_positions,
    "IDR-center": idr_2_positions,
    "acidic": acidic_positions,
    "basic": basic_positions,
    "neutral": neutral_positions,
    "polar": polar_positions,
    "non-polar": nonpolar_positions,
}.items():
    group_values = [
        delta
        for position in group
        for delta in pf_rel_a5_mean[position, :20] - pf_wt_a5_mean
    ]
    names += [name] * len(group_values)
    values += group_values

split_positions = dict(name=names, value=values)

In [None]:
from seaborn import violinplot
fig, ax = plt.subplots(figsize=(20, 5))
violinplot(x="name", y="value", width=1, data=split_positions, ax=ax, inner="quartiles")

# Restyle the inner lines in groups of three: the middle line of each group
# (presumably the median of a violin) solid, the other two dotted, all black.
for line_no, line in enumerate(ax.lines):
    line.set_linestyle("-" if line_no % 3 == 1 else ":")
    line.set_color("black")

fig.savefig(f"{base_dir}/outputs/acriia5_activity_distribution.svg")

#### amino acid mutant correlation

In [None]:
# Amino-acid x amino-acid correlation (np.corrcoef over the columns of the
# position x amino-acid matrix of predicted mean log fluorescence), A4 then A5.
for rel_mean, out_path in (
    (pf_rel_a4_mean, f"{base_dir}/outputs/aa_correlation_a4.svg"),
    (pf_rel_a5_mean, f"{base_dir}/outputs/aa_correlation_a5.svg"),
):
    fig, ax = plt.subplots()
    cax = ax.matshow(np.corrcoef(rel_mean.T), cmap="Reds")
    aa_ticks = np.arange(21)
    ax.set_xticks(aa_ticks)
    ax.set_xticklabels(AA_CODE)
    ax.set_yticks(aa_ticks)
    ax.set_yticklabels(AA_CODE)
    fig.colorbar(cax)
    fig.savefig(out_path)