In [None]:
import os

from Bio.SeqIO import QualityIO
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.cm as cm

import gzip

from utils import dna_rev_comp, translate_dna2aa

In [None]:
# Root directory for all inputs (reads, FACS files, model outputs) and for every
# figure / cache file written by this notebook.
# NOTE(review): later cells build paths as f"{base_dir}/...", so the trailing slash
# here produces "data//..." paths — presumably harmless, but confirm on your platform.
base_dir = "data/"

In [None]:
# Reference amino-acid sequences that translated reads are compared against
# (called "gt" throughout this notebook).
# presumably the wild-type AcrIIA4 / AcrIIA5 proteins — TODO confirm
gt_a4 = "MNINDLIREIKNKDYTVKLSGTDSNSITQLIIRVNNDGNEYVISESENESIVEKFISAFKNGWNQEYEDEEEFYNDMQTITLKSELN"
gt_a5 = "MAYGKSRYNSYRKRSFNRSNKQRREYAQEMDRLEKAFENLDGWYLSSMKDSAYKDFGKYEIRLSNHSADNKYHDLENGRLIVNIKASKLNFVDIIENKLDKIIEKIDKLDLDKYRFINATNLEHDIKCYYKGFKTKKEVI"

In [None]:
def read_sequences(file_1, file_2, min_overlap=50, max_overlap=250):
    """Merge paired-end reads from two gzipped FASTQ files.

    Both files are read in lock-step. For each pair the second read is
    reverse-complemented and the smallest overlap length in
    ``range(min_overlap, max_overlap)`` is searched for which the tail of the
    first read equals the head of the reverse-complemented second read. The
    merged sequence (first read + non-overlapping remainder of the second) is
    collected; pairs without a matching overlap are dropped.

    Args:
        file_1: path to the gzipped FASTQ file holding the first reads.
        file_2: path to the gzipped FASTQ file holding the second reads.
        min_overlap: smallest overlap length tried (inclusive).
        max_overlap: largest overlap length tried (exclusive).

    Returns:
        list of merged sequences (str), in file order.

    Side effects:
        Prints the zero-based index of the last read pair, or ``None`` when
        the files contain no read pairs.
    """
    sequences = []
    # Pre-bind the loop variable: with empty input files the loop never runs
    # and the final print would otherwise raise NameError.
    total_read = None

    with gzip.open(file_1, "rt") as a_file, gzip.open(file_2, "rt") as b_file:
        a_reader = QualityIO.FastqGeneralIterator(a_file)
        b_reader = QualityIO.FastqGeneralIterator(b_file)
        for total_read, (a, b) in enumerate(zip(a_reader, b_reader)):
            _, a_seq, _ = a
            _, b_seq, _ = b
            # NOTE(review): reads shorter than 250 get a single "G" appended;
            # the rationale is not visible in this notebook — confirm.
            if len(a_seq) < 250:
                a_seq = a_seq + "G"
            if len(b_seq) < 250:
                b_seq = b_seq + "G"
            b_inv = dna_rev_comp(b_seq)
            for expected_overlap in range(min_overlap, max_overlap):
                # An overlap of 0 is accepted unconditionally (plain
                # concatenation); a_seq[-0:] would be the whole read and could
                # not be compared meaningfully.
                if expected_overlap == 0 or a_seq[-expected_overlap:] == b_inv[:expected_overlap]:
                    sequences.append(a_seq + b_inv[expected_overlap:])
                    break
        print(total_read)
    return sequences

def gather_variants(sequences, catch, gt):
    """Translate the gene following ``catch`` in each read and tally variants.

    For every sequence the marker ``catch`` is searched on the sequence itself
    and on its reverse complement; when both contain it, the
    reverse-complement hit is the one that is kept. The ``3 * len(gt)``
    nucleotides after the marker are translated and classified.

    Args:
        sequences: iterable of DNA strings.
        catch: marker string that directly precedes the gene.
        gt: reference peptide.

    Returns:
        tuple ``(dist, translations, multiples, stops, mislengths, wildtype)``:
        ``dist`` is a ``(3 * len(gt), 4)`` array counting nucleotides (columns
        ordered "GATC") at positions inside mutated codons; ``translations``
        maps each accepted peptide (at most one difference to ``gt``) to its
        read count; ``multiples``, ``stops`` and ``mislengths`` list rejected
        peptides; ``wildtype`` counts peptides identical to ``gt``.
    """
    peptide_length = len(gt)
    length = 3 * peptide_length
    catch_length = len(catch)

    dist = np.zeros((length, 4))
    translations = {}
    multiples, stops, mislengths = [], [], []
    wildtype = 0

    for sequence in sequences:
        tr = None
        gene = None
        # Forward strand first, reverse complement second: a later hit
        # overrides an earlier one.
        for strand in (sequence, dna_rev_comp(sequence)):
            if catch not in strand:
                continue
            start = strand.index(catch) + catch_length
            gene = strand[start:start + length]
            tr = translate_dna2aa(gene)

        if tr is None:
            continue
        if "*" in tr:
            stops.append(tr)
            continue
        if len(tr) != peptide_length:
            mislengths.append(tr)
            continue
        if tr == gt:
            wildtype += 1

        # Lengths are equal here, so a pairwise comparison counts substitutions.
        n_substitutions = sum(1 for observed, expected in zip(tr, gt) if observed != expected)
        if n_substitutions > 1:
            multiples.append(tr)
            continue

        for nt_index, nucleotide in enumerate(gene):
            codon = nt_index // 3
            if tr[codon] != gt[codon]:
                dist[nt_index, "GATC".index(nucleotide)] += 1
        translations[tr] = translations.get(tr, 0) + 1

    return dist, translations, multiples, stops, mislengths, wildtype

# Column order of the 20 amino acids in every (position, residue) matrix.
AA_CODE = "ACDEFGHIKLMNPQRSTVWY"
def check_mutants(translations, gt):
    """Turn peptide read counts into a (len(gt), 20) substitution count matrix.

    Args:
        translations: mapping peptide -> read count.
        gt: reference peptide.

    Returns:
        array where entry ``[pos, AA_CODE.index(res)]`` accumulates the reads
        of peptides carrying ``res`` instead of ``gt[pos]``. If ``gt`` itself
        is a key of ``translations``, the reference residue cells are set to 1.
    """
    counts = np.zeros((len(gt), 20))
    for peptide, n_reads in translations.items():
        if peptide == gt:
            # Mark the reference residues with a constant placeholder of 1.
            for position, residue in enumerate(peptide):
                counts[position, AA_CODE.index(residue)] = 1
            continue
        for position, residue in enumerate(peptide):
            if residue != gt[position]:
                counts[position, AA_CODE.index(residue)] += n_reads
    return counts

def process_directory(base, catch, gt, min_overlap=50, max_overlap=250):
    """Process every paired FASTQ file set found in a directory.

    Files whose name contains "R1" or "1_sequence" are treated as first reads,
    files containing "R2" or "2_sequence" as second reads. Both lists are
    sorted and paired positionally; surplus files in the longer list are
    ignored because ``zip`` stops at the shorter one.

    Args:
        base: directory holding the gzipped FASTQ files.
        catch: marker string preceding the gene (passed to ``gather_variants``).
        gt: reference peptide.
        min_overlap: forwarded to ``read_sequences``.
        max_overlap: forwarded to ``read_sequences``.

    Returns:
        tuple ``(sequences, results)``: per file pair, the merged reads and a
        ``(substitution_count_matrix, wildtype_count)`` tuple.
    """
    # List the directory once and reuse the listing for both read directions.
    entries = os.listdir(base)
    fraction_paths_r1 = sorted(
        f"{base}/{path}" for path in entries if "R1" in path or "1_sequence" in path
    )
    fraction_paths_r2 = sorted(
        f"{base}/{path}" for path in entries if "R2" in path or "2_sequence" in path
    )

    sequences = []
    results = []
    for f1, f2 in zip(fraction_paths_r1, fraction_paths_r2):
        seq = read_sequences(f1, f2, min_overlap=min_overlap, max_overlap=max_overlap)
        variants = gather_variants(seq, catch, gt)
        # Keep only the peptide counts (index 1) and the wild-type count (last).
        results.append((check_mutants(variants[1], gt), variants[-1]))
        sequences.append(seq)

    return sequences, results

def normalise_single_run(result):
    """Convert one run's counts into frequencies with a pseudocount of 1.

    Args:
        result: sequence whose first item is the variant count array and whose
            last item is the wild-type count.

    Returns:
        tuple ``(variant_frequencies, wildtype_frequency)``; both are divided
        by the same total (wild type plus all variant cells, each incremented
        by 1).
    """
    wt_count = result[-1] + 1
    variant_counts = result[0] + 1
    denominator = wt_count + variant_counts.sum()
    return variant_counts / denominator, wt_count / denominator

def relative_results(normalised):
    """Normalise frequencies across runs so they sum to 1 over the run axis.

    Args:
        normalised: sequence of ``(variant_frequencies, wildtype_frequency)``
            tuples, one per run (as returned by ``normalise_single_run``).

    Returns:
        tuple ``(variant_stack, wt_stack)``: ``variant_stack`` has the runs on
        axis 0 and every cell sums to 1 across runs; ``wt_stack`` is a 1-D
        array over runs that sums to 1.
    """
    # np.stack needs a real sequence; a lazy ``map`` object is rejected by
    # current NumPy releases.
    variant_stack = np.stack([item[0] for item in normalised])
    wt_stack = np.array([item[-1] for item in normalised])
    variant_stack = variant_stack / variant_stack.sum(axis=0, keepdims=True)
    wt_stack = wt_stack / wt_stack.sum(axis=0, keepdims=True)
    return variant_stack, wt_stack

def normalise_results(results):
    """Normalise every run and derive the across-run relative frequencies.

    Args:
        results: sequence of per-run ``(variant_counts, wildtype_count)`` items.

    Returns:
        tuple ``(normalised, relative)``: the list of per-run outputs of
        ``normalise_single_run`` and the output of ``relative_results`` on it.
    """
    normalised = [normalise_single_run(run) for run in results]
    return normalised, relative_results(normalised)

def plot_contacts(upper, lower, positions, cutoff=None, gt=None):
    """Show, per selected position, the share of each residue found in ``upper``.

    Args:
        upper: per-fraction arrays of the selected fractions; summed
            element-wise with the built-in ``sum``.
        lower: per-fraction arrays of the remaining fractions.
        positions: zero-based sequence positions to display.
        cutoff: optional threshold; when truthy the matrix is binarised to
            ``1.0`` where the ratio exceeds it.
        gt: reference peptide used to locate the wild-type residue of each
            position. When omitted, the notebook-level variable ``gt`` is
            used, which is what the original body did implicitly.

    Raises:
        NameError: if ``gt`` is not passed and no notebook-level ``gt`` exists.
    """
    if gt is None:
        # Backward compatibility: earlier versions read a global called `gt`,
        # which the cells shown in this notebook never define.
        gt = globals().get("gt")
        if gt is None:
            raise NameError("plot_contacts needs a reference sequence: pass gt=...")

    upper_ratio = sum(upper) / (sum(lower) + sum(upper))
    gt_matrix = np.zeros((20, len(positions)))
    for column, pos in enumerate(positions):
        gt_matrix[AA_CODE.index(gt[pos]), column] = 1

    # Wild-type cells are divided by zero here; presumably this is meant to
    # push them out of the colour range — TODO confirm.
    top = upper_ratio[positions].T / (1 - gt_matrix)
    if cutoff:
        top = 1.0 * (top > cutoff)
    cmable = plt.matshow(top)
    ax = plt.gca()
    ax.set_yticks(range(20))
    ax.set_yticklabels("ACDEFGHIKLMNPQRSTVWY")
    ax.set_xticks(range(len(positions)))
    ax.set_xticklabels([pos + 1 for pos in positions])
    plt.colorbar(cmable);

def selected_fraction(variants, wt, select=None):
    """Split per-fraction values into a selected and a remaining group.

    Args:
        variants: indexable per-fraction variant values.
        wt: indexable per-fraction wild-type values.
        select: a list/tuple of fraction indices, or a single index (any
            other value is wrapped into a one-element list).

    Returns:
        tuple ``(lower, upper, wt_lower, wt_upper)`` where "upper" sums the
        selected fractions and "lower" sums all others.
    """
    chosen = select if isinstance(select, (list, tuple)) else [select]
    remaining = [idx for idx in range(len(variants)) if idx not in chosen]

    wt_upper = sum(wt[idx] for idx in chosen)
    wt_lower = sum(wt[idx] for idx in remaining)
    upper = sum(variants[idx] for idx in chosen)
    lower = sum(variants[idx] for idx in remaining)
    return lower, upper, wt_lower, wt_upper

# Load Data

## A4

In [None]:
# Marker sequence that directly precedes the gene in every read.
catch = "aggaattcgcca".upper()
base = f"{base_dir}/embla4ds3/"

# Process the raw reads only when no cached result exists yet.
if not os.path.isfile(f"{base_dir}/savefile_a4_ds3.npy"):
    sequences, results = process_directory(base, catch, gt_a4)
    normalised, relative = normalise_results(results)
    variants_flat = relative[0].reshape(len(results), -1)
    relative_a4_ds3 = relative
    results_a4_ds3 = results

    # The two cached items have different nested shapes. Build the object
    # array explicitly: newer NumPy refuses to create ragged arrays implicitly.
    cache_a4_ds3 = np.empty(2, dtype=object)
    cache_a4_ds3[0] = relative_a4_ds3
    cache_a4_ds3[1] = results_a4_ds3
    np.save(f"{base_dir}/savefile_a4_ds3.npy", cache_a4_ds3, allow_pickle=True)
else:
    # NOTE: allow_pickle=True executes pickled content; only load trusted cache files.
    relative_a4_ds3, results_a4_ds3 = np.load(f"{base_dir}/savefile_a4_ds3.npy", allow_pickle=True)

## A5

In [None]:
# Marker sequence that directly precedes the gene in every read.
catch = "aggaattcgcca".upper()
base = f"{base_dir}/embla5ds3/"

# Process the raw reads only when no cached result exists yet.
if not os.path.isfile(f"{base_dir}/savefile_a5_ds3.npy"):
    # The A5 reads are compared against the A5 reference; the previous code
    # passed `gt`, which is not defined by any cell of this notebook.
    sequences, results = process_directory(base, catch, gt_a5, min_overlap=10)
    normalised, relative = normalise_results(results)
    variants_flat = relative[0].reshape(len(results), -1)
    relative_a5_ds3 = relative
    results_a5_ds3 = results

    # The two cached items have different nested shapes. Build the object
    # array explicitly: newer NumPy refuses to create ragged arrays implicitly.
    cache_a5_ds3 = np.empty(2, dtype=object)
    cache_a5_ds3[0] = relative_a5_ds3
    cache_a5_ds3[1] = results_a5_ds3
    np.save(f"{base_dir}/savefile_a5_ds3.npy", cache_a5_ds3, allow_pickle=True)
else:
    # NOTE: allow_pickle=True executes pickled content; only load trusted cache files.
    relative_a5_ds3, results_a5_ds3 = np.load(f"{base_dir}/savefile_a5_ds3.npy", allow_pickle=True)

### Confidence metrics

In [None]:
def raw_counts(data):
    """Log read fraction of every (position, residue) cell, pooled over runs.

    Args:
        data: sequence of ``(count_matrix, wildtype_count)`` items, one per
            run; the count matrices are 2-D and share one shape.

    Returns:
        2-D array ``log((pooled_count + 1) / total)`` where ``total`` is the
        sum of all matrix cells plus all wild-type counts.
    """
    stacked = np.stack([item[0] for item in data], axis=-1)
    wildtype_reads = sum(item[1] for item in data)
    total = stacked.sum() + wildtype_reads
    pooled = stacked.sum(axis=2)
    return np.log((pooled + 1) / total)

def absolute_counts(data):
    """Sum the per-run count matrices into one pooled count matrix.

    Args:
        data: sequence of ``(count_matrix, wildtype_count)`` items with 2-D
            count matrices of one shared shape.

    Returns:
        2-D array with the element-wise sum over all runs.
    """
    stacked = np.stack([item[0] for item in data], axis=-1)
    return stacked.sum(axis=2)

def read_entropy(data):
    """Sum ``p * log(8 * p)`` over the first axis of ``data[0]``.

    The constant 8 presumably matches the number of sorted fractions, which
    would make this the divergence of each cell's distribution from a uniform
    distribution over 8 fractions — TODO confirm.

    Args:
        data: sequence whose first item is the array of relative frequencies.

    Returns:
        array with the first axis of ``data[0]`` summed out.
    """
    frequencies = data[0]
    weighted_log = frequencies * np.log(frequencies * 8)
    return weighted_log.sum(axis=0)

# Coverage (log read fraction), pooled read counts and per-cell entropy score
# for both proteins.
raw_counts_a4_ds3 = raw_counts(results_a4_ds3)
raw_counts_a5_ds3 = raw_counts(results_a5_ds3)

absolute_counts_a4_ds3 = absolute_counts(results_a4_ds3)
absolute_counts_a5_ds3 = absolute_counts(results_a5_ds3)

entropy_a4_ds3 = read_entropy(relative_a4_ds3)
entropy_a5_ds3 = read_entropy(relative_a5_ds3)


plt.matshow(raw_counts_a4_ds3.T)
plt.matshow(raw_counts_a5_ds3.T)

# Same coverage maps again, this time with a colour bar and saved to disk.
cx = plt.matshow((raw_counts_a4_ds3).T)
plt.colorbar(cx)
plt.savefig(f"{base_dir}/coverage_a4.svg")
cx = plt.matshow((raw_counts_a5_ds3).T)
plt.colorbar(cx)
plt.savefig(f"{base_dir}/coverage_a5.svg")

plt.matshow(entropy_a4_ds3.T)
plt.matshow(entropy_a5_ds3.T)
plt.show()

# Coverage versus negated entropy score per cell. The previous code referenced
# *_ds1 arrays, which no cell of this notebook defines.
plt.scatter(raw_counts_a4_ds3.reshape(-1), -entropy_a4_ds3.reshape(-1))

#### Fraction of variants with at least N reads

In [None]:
# Percentage of (position, residue) cells covered by at least N pooled reads.
# The dict gets its own name so that it no longer overwrites the
# `absolute_counts` function defined earlier in the notebook.
absolute_counts_by_protein = dict(a4=absolute_counts_a4_ds3, a5=absolute_counts_a5_ds3)
for key, pooled_counts in absolute_counts_by_protein.items():
    for N in [1, 5, 10, 50, 100]:
        # The label now states ">=" to match the comparison actually made.
        print(f"{key} >= {N}", 100 * (pooled_counts >= N).sum() / pooled_counts.size, "%")

#### Compare confidence for A4 and A5

In [None]:
from seaborn import violinplot

# Long-format table: one entropy score per (position, residue) cell, labelled
# by protein, for a side-by-side violin plot.
entropy_values_a4 = list(entropy_a4_ds3.reshape(-1))
entropy_values_a5 = list(entropy_a5_ds3.reshape(-1))
data = {
    "value": entropy_values_a4 + entropy_values_a5,
    "name": ["A4"] * len(entropy_values_a4) + ["A5"] * len(entropy_values_a5),
}

violinplot(x="name", y="value", data=data, inner="quartiles")
plt.savefig(f"{base_dir}/entropy_distribution_a4_vs_a5.svg")

#### Top 10 / Bottom 10 confidence by entropy on reads

In [None]:
# Rank all (position, residue) cells by entropy score and plot the fraction
# profiles of the 10 highest-scoring (top row) and 10 lowest-scoring (bottom
# row) cells. The grid shape is taken from the array instead of a literal 87 x 20.
n_positions, n_residues = entropy_a4_ds3.shape
flat_entropy_a4 = entropy_a4_ds3.reshape(-1)
index = sorted(range(n_positions * n_residues), key=lambda x: flat_entropy_a4[x])
fig, ax = plt.subplots(2, 10, figsize=(20, 5))
for i in range(10):
    for row, flat_index in ((0, index[-(i + 1)]), (1, index[i])):
        x, y = np.unravel_index(flat_index, (n_positions, n_residues))
        ax[row, i].plot(relative_a4_ds3[0][:, x, y])
        ax[row, i].set_ylim(0, 1)

#### Top 10 / Bottom 10 confidence by read fraction

In [None]:
# Rank all (position, residue) cells by log read fraction and plot the fraction
# profiles of the 10 best-covered (top row) and 10 least-covered (bottom row)
# cells. The grid shape is taken from the array instead of a literal 87 x 20.
n_positions, n_residues = raw_counts_a4_ds3.shape
flat_raw_counts_a4 = raw_counts_a4_ds3.reshape(-1)
index = sorted(range(n_positions * n_residues), key=lambda x: flat_raw_counts_a4[x])
fig, ax = plt.subplots(2, 10, figsize=(20, 5))
for i in range(10):
    for row, flat_index in ((0, index[-(i + 1)]), (1, index[i])):
        x, y = np.unravel_index(flat_index, (n_positions, n_residues))
        ax[row, i].plot(relative_a4_ds3[0][:, x, y])
        ax[row, i].set_ylim(0, 1)

# Fit FACS data

In [None]:
import FlowCal
import os
import numpy as np
from torch.distributions import MixtureSameFamily, Normal
import torch
import torch.nn as nn
from sklearn.neighbors import KernelDensity

In [None]:
class Constant(torch.nn.Module):
    """Module that ignores input values and returns one learned row per sample.

    The only parameter is ``bias`` of shape ``(1, out_size)``, initialised to
    zeros.
    """

    def __init__(self, out_size):
        super().__init__()
        initial = torch.zeros(1, out_size, requires_grad=True)
        self.bias = torch.nn.Parameter(initial)

    def forward(self, inputs):
        """Return ``bias`` expanded to ``(inputs.size(0), out_size)``."""
        batch_size = inputs.size(0)
        return self.bias.expand(batch_size, -1)

def fit_facs(relative, path, frac=8, start=4, end=12, factor=2, steps=100000, mode="linear", drop=None):
    """Fit a model mapping sequencing fraction profiles to FACS histograms.

    For every FCS file in ``path`` a target histogram over
    ``(end - start) * factor + 1`` bins is built from the log of column 6 of
    the events (values below 1 are discarded, logs are clamped to
    ``[start, end]``). The model input is the fraction profile
    ``relative[0][:, position - 1, residue]`` of the variant encoded in the
    file name. The model output is passed through a softmax and fitted with a
    summed squared error using AdamW.

    File names are expected to carry the variant (e.g. ``G38C``) as the 4th
    ``_``-separated token of the part before the first ``.``.

    Args:
        relative: sequence whose first item is indexed as
            ``[fraction, position, residue]``.
        path: directory with the FCS files. Files are visited in
            ``os.listdir`` order; ``drop`` refers to that order.
        frac: number of input fractions (input width of the linear model).
        start: lower clamp of the log signal.
        end: upper clamp of the log signal.
        factor: number of histogram bins per log unit.
        steps: number of optimisation steps.
        mode: ``"linear"`` for a zero-initialised ``torch.nn.Linear``; any
            other value selects the input-independent ``Constant`` model.
        drop: optional index of one file to leave out of the fit.

    Returns:
        the fitted torch module.
    """
    n_bins = (end - start) * factor + 1
    if mode == "linear":
        linear = torch.nn.Linear(frac, n_bins, bias=True)
        with torch.no_grad():
            linear.weight.zero_()
            linear.bias.zero_()
    else: # mode == "constant"
        linear = Constant(n_bins)
    optimizer = torch.optim.AdamW(linear.parameters(), lr=1e-4, weight_decay=1e-2)

    targets = []
    vals = []
    for p in os.listdir(path):
        kind = p.split(".")[0].split("_")[3]
        changed = kind[-1]
        position = int(kind[1:-1])
        data = FlowCal.io.FCSData(f"{path}/{p}")
        # presumably column 6 is the fluorescence channel of interest — TODO confirm
        target = np.log(data[:, 6][data[:, 6] >= 1])
        # Bin width is 1 / factor log units (this was hard-coded as 2 before,
        # which only matched the default factor).
        binned = (factor * (torch.tensor(target).clamp(start, end) - start)).floor().long()
        target = torch.zeros(n_bins)
        unique, counts = binned.unique(return_counts=True)
        target[unique] = 1.0 * counts
        target = target / target.sum()
        targets.append(target[None])
        changed = AA_CODE.index(changed)
        vals.append(torch.tensor(relative[0][:, position - 1, changed])[None])
    vals = torch.cat(vals, dim=0).float()
    targets = torch.cat(targets, dim=0).float()

    # Optionally leave one variant out (used for leave-one-out validation).
    indices = [idx for idx in range(len(targets))]
    if drop is not None:
        del indices[drop]
    indices = torch.tensor(indices)
    targets = targets[indices]
    vals = vals[indices]

    for _ in range(steps):
        # Reset gradients every step; without this they accumulate over all
        # previous steps and the updates no longer follow the current loss.
        optimizer.zero_grad()
        out = linear(vals).softmax(dim=1)
        val = ((out - targets) ** 2).sum()
        print(float(val), end="\r")
        val.backward()
        optimizer.step()
    return linear

def eval_facs(linears, relative, path, frac=8, start=4, end=12, factor=2, steps=100000, lin=True, drop=None):
    """Evaluate leave-one-out models on the variant each of them left out.

    Targets and inputs are built exactly as in ``fit_facs`` (same file order
    from ``os.listdir``, same binning). Model ``linears[idx]`` is evaluated on
    the ``idx``-th file only.

    Args:
        linears: sequence of fitted modules, one per FCS file, in file order.
        relative: sequence whose first item is indexed as
            ``[fraction, position, residue]``.
        path: directory with the FCS files.
        start: lower clamp of the log signal.
        end: upper clamp of the log signal.
        factor: number of histogram bins per log unit.
        frac, steps, lin, drop: accepted for call compatibility; not used.

    Returns:
        tuple ``(errors, predictions, targets)``: per-model mean squared error
        tensors, detached predicted histograms, and the tensor of all target
        histograms.
    """
    n_bins = (end - start) * factor + 1
    targets = []
    vals = []
    for p in os.listdir(path):
        kind = p.split(".")[0].split("_")[3]
        changed = kind[-1]
        position = int(kind[1:-1])
        data = FlowCal.io.FCSData(f"{path}/{p}")
        # presumably column 6 is the fluorescence channel of interest — TODO confirm
        target = np.log(data[:, 6][data[:, 6] >= 1])
        # Bin width is 1 / factor log units (this was hard-coded as 2 before).
        binned = (factor * (torch.tensor(target).clamp(start, end) - start)).floor().long()
        target = torch.zeros(n_bins)
        unique, counts = binned.unique(return_counts=True)
        target[unique] = 1.0 * counts
        target = target / target.sum()
        targets.append(target[None])
        changed = AA_CODE.index(changed)
        vals.append(torch.tensor(relative[0][:, position - 1, changed])[None])
    vals = torch.cat(vals, dim=0).float()
    targets = torch.cat(targets, dim=0).float()

    errors = []
    predictions = []
    for idx, reg in enumerate(linears):
        # Run each model once and reuse the prediction for error and output.
        prediction = reg(vals)[idx].softmax(dim=0)
        errors.append(((prediction - targets[idx]) ** 2).mean())
        predictions.append(prediction.detach())
    return errors, predictions, targets
        
def predict_facs(linear, data, frac=8):
    """Apply a fitted model to every cell of ``data[0]``.

    Args:
        linear: module mapping ``(n, frac)`` inputs to ``(n, n_bins)`` outputs.
        data: sequence whose first item has the fraction axis first; the two
            remaining axes are kept in the output.
        frac: number of fractions (length of the first axis of ``data[0]``).

    Returns:
        tensor of raw model outputs (no softmax) with the two trailing axes of
        ``data[0]`` first and the model output axis last.
    """
    # Move the fraction axis last so every cell becomes one input row.
    inputs = torch.tensor(data[0]).float().permute(1, 2, 0)
    n_rows, n_cols = inputs.shape[:2]
    flat_pred = linear(inputs.view(-1, frac))
    return flat_pred.view(n_rows, n_cols, flat_pred.size(1))

def predict_wt(linear, data, frac=8):
    """Apply a fitted model to the wild-type fraction profile ``data[1]``.

    Args:
        linear: module mapping ``(n, frac)`` inputs to ``(n, n_bins)`` outputs.
        data: sequence whose second item is the wild-type profile.
        frac: number of fractions.

    Returns:
        1-D tensor with the raw model output (no softmax) for the first row.
    """
    wt_profile = torch.tensor(data[1]).float().unsqueeze(0)
    return linear(wt_profile.view(-1, frac))[0]

In [None]:
# Fit one model per protein on all FACS-measured variants (nothing left out).
linear_a4_ds3 = fit_facs(
    relative_a4_ds3,
    f"{base_dir}/FACS/FACS/A4/",
    mode="linear",
    start=0,
    end=12,
    drop=None,
)
linear_a5_ds3 = fit_facs(
    relative_a5_ds3,
    f"{base_dir}/FACS/FACS/A5_new/",
    mode="linear",
    frac=8,
    start=0,
    end=12,
    drop=None,
)

#### Fit two fractions

In [None]:
# Collapse the fractions into two pooled fractions (indices 0-3 and 4 onwards)
# for both the variant frequencies and the wild-type frequencies.
relative_a4_ds3_f2 = tuple(
    np.concatenate(
        [chunk.sum(axis=0, keepdims=True) for chunk in (part[:4], part[4:])],
        axis=0,
    )
    for part in (relative_a4_ds3[0], relative_a4_ds3[1])
)

In [None]:
# Same fit as for the full A4 data, but on the two pooled fractions (frac=2).
linear_a4_ds3_f2 = fit_facs(
    relative_a4_ds3_f2,
    f"{base_dir}/FACS/FACS/A4/",
    mode="linear",
    frac=2,
    start=0,
    end=12,
    drop=None,
)

## Fit cross-validation

In [None]:
# Leave-one-out fits: model `idx` is trained without the idx-th FACS file.
# The literal 16 presumably equals the number of FACS files per directory — TODO confirm.
linears_a4_ds3 = [
    fit_facs(relative_a4_ds3, f"{base_dir}/FACS/FACS/A4/", mode="linear", start=0, end=12, drop=left_out)
    for left_out in range(16)
]
linears_a5_ds3 = [
    fit_facs(relative_a5_ds3, f"{base_dir}/FACS/FACS/A5_new/", mode="linear", start=0, end=12, drop=left_out)
    for left_out in range(16)
]

#### two fractions

In [None]:
# Leave-one-out fits on the two pooled fractions; model `idx` is trained
# without the idx-th FACS file.
linears_a4_ds3_f2 = [
    fit_facs(
        relative_a4_ds3_f2,
        f"{base_dir}/FACS/FACS/A4/",
        mode="linear", frac=2, start=0, end=12, drop=left_out
    )
    for left_out in range(16)
]

#### save parameters

In [None]:
import torch
# Persist the fitted models, or restore them when a backup already exists.
# NOTE(review): the two-fraction models (*_f2) are neither saved nor restored.
# NOTE(review): torch.load unpickles whole modules; only load trusted files.
backup_path = f"{base_dir}/backup.torch"
if os.path.isfile(backup_path):
    backup = torch.load(backup_path)
    linear_a4_ds3, linear_a5_ds3 = backup["linear_a4_ds3"], backup["linear_a5_ds3"]
    linears_a4_ds3, linears_a5_ds3 = backup["linears_a4_ds3"], backup["linears_a5_ds3"]
else:
    torch.save(
        {
            "linear_a4_ds3": linear_a4_ds3,
            "linear_a5_ds3": linear_a5_ds3,
            "linears_a4_ds3": linears_a4_ds3,
            "linears_a5_ds3": linears_a5_ds3,
        },
        backup_path,
    )

### cross-validation error

In [None]:
# Leave-one-out error, predicted histogram and target histogram per variant.
errors_a4, cpredictions_a4, ctargets_a4 = eval_facs(
    linears_a4_ds3, relative_a4_ds3, f"{base_dir}/FACS/FACS/A4/",
    factor=2, lin=True, start=0, end=12,
)
errors_a4_f2, cpredictions_a4_f2, ctargets_a4_f2 = eval_facs(
    linears_a4_ds3_f2, relative_a4_ds3_f2, f"{base_dir}/FACS/FACS/A4/",
    factor=2, frac=2, lin=True, start=0, end=12,
)
errors_a5, cpredictions_a5, ctargets_a5 = eval_facs(
    linears_a5_ds3, relative_a5_ds3, f"{base_dir}/FACS/FACS/A5_new/",
    factor=2, lin=True, start=0, end=12,
)

### plot cross-validation error for variants

In [None]:
# Variant labels (e.g. "G38C") in the same os.listdir order that fit_facs and
# eval_facs use. Both directories are parsed with the same rule as those
# functions: 4th "_"-separated token of the file name before its first ".".
# The A4 branch previously skipped the extension stripping.
variant_names_a4 = [
    p.split(".")[0].split("_")[3]
    for p in os.listdir(f"{base_dir}/FACS/FACS/A4/")
]
variant_names_a5 = [
    p.split(".")[0].split("_")[3]
    for p in os.listdir(f"{base_dir}/FACS/FACS/A5_new/")
]

In [None]:
# Group the per-dataset results by a short key so the plotting loops can
# iterate over datasets.
variant_names = {"a4": variant_names_a4, "a5": variant_names_a5, "a4_f2": variant_names_a4}
ctargets = {"a4": ctargets_a4, "a5": ctargets_a5, "a4_f2": ctargets_a4_f2}
cpredictions = {"a4": cpredictions_a4, "a5": cpredictions_a5, "a4_f2": cpredictions_a4_f2}
errors = {"a4": errors_a4, "a5": errors_a5, "a4_f2": errors_a4_f2}

# One figure per variant: measured histogram and leave-one-out prediction overlaid.
# The literal 25 equals (end - start) * factor + 1 for start=0, end=12, factor=2.
bin_axis = range(25)
for acrkind in ("a4", "a5", "a4_f2"):
    kind_targets = ctargets[acrkind]
    kind_predictions = cpredictions[acrkind]
    for idx, name in enumerate(variant_names[acrkind]):
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.set_title(name)
        for histogram in (kind_targets[idx], kind_predictions[idx]):
            ax.fill_between(bin_axis, histogram, alpha=0.5)
        ax.set_ylim(0.0, 0.5)
        plt.tight_layout()
        plt.savefig(f"{base_dir}/error_{acrkind}_{name}.svg")

### plot mean cross-validation error across all variants

In [None]:
# Bar chart of the leave-one-out error of every variant, one figure per dataset.
for acrkind in ("a4", "a5", "a4_f2"):
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(list(range(len(errors[acrkind]))), [e.detach() for e in errors[acrkind]])
    ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
    # One tick per label, instead of a fixed 16.
    ax.set_xticks(range(len(variant_names[acrkind])))
    ax.set_xticklabels(variant_names[acrkind])
    ax.set_ylim(0.0, 0.1)
    plt.tight_layout()
    plt.savefig(f"{base_dir}/leave_one_out_{acrkind}.svg")

# Difference between two A5 evaluations. `errors_a5_no0` is not produced by any
# cell of this notebook, so the plot is only drawn when it exists; previously
# this cell raised NameError and labelled the A5 bars with the "a4_f2" names
# left over from the loop variable.
if "errors_a5_no0" in globals():
    fig, ax = plt.subplots(figsize=(10, 5))
    error_differences = [(en - e).detach() for e, en in zip(errors_a5, errors_a5_no0)]
    ax.bar(list(range(len(error_differences))), error_differences)
    ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
    ax.set_xticks(range(len(variant_names["a5"])))
    ax.set_xticklabels(variant_names["a5"])
    # ax.set_ylim(0.0, 0.1)
    plt.tight_layout()
    plt.savefig(f"{base_dir}/leave_one_out_a5_comparison.svg")
else:
    print("Skipping leave_one_out_a5_comparison.svg: errors_a5_no0 is not defined.")


## Predict FACS

### A4

In [None]:
# Predict the FACS histogram of every A4 variant and of the wild type.
model_a4 = linear_a4_ds3.eval()
pf = predict_facs(model_a4, relative_a4_ds3)
# Softmax turns the raw model outputs into per-bin probabilities.
pf_rel = pf.softmax(dim=-1).detach().numpy()
pf_wt = predict_wt(model_a4, relative_a4_ds3).softmax(dim=0).detach().numpy()
# Alternative normalisation by absolute value (not used further in the cells shown here).
pf_esc = (abs(pf) / abs(pf).sum(dim=-1, keepdim=True)).detach().numpy()
pf_wt_a4, pf_rel_a4 = pf_wt, pf_rel

#### Mutation tolerance A4

In [None]:
# Expected log signal per variant: bin index times 0.5 (bin width for
# factor=2), weighted by the predicted bin probabilities.
bin_index = np.arange(25)
mean_log_fluorescence = (pf_rel * bin_index[None, None, :] * 0.5).sum(axis=-1)
mean_log_fluorescence_wt = (pf_wt * bin_index * 0.5).sum(axis=-1)
# Share of cells whose expected signal reaches a given fraction of the wild type.
for threshold in [0.9, 0.95, 1.0]:
    reaches_threshold = mean_log_fluorescence >= mean_log_fluorescence_wt * threshold
    print(threshold, reaches_threshold.sum() / mean_log_fluorescence.size)

#### Correlation with Kd

In [None]:
# correlations
# Kd values per variant (units not stated in this notebook) and the mean
# predicted histogram bin for the same variants.
kd_variants = dict(WT=8.31, G38C=14.8, N25G=8.69, E70T=112, E70D=24.5, Y67K=332000, M77A=76)
# Use the A4 wild-type prediction throughout; the loop previously took the bin
# count from the generic `pf_wt`, which other cells may rebind.
bin_index_a4 = np.arange(len(pf_wt_a4))
mean_bin_variants = dict(WT=(pf_wt_a4 * bin_index_a4).sum())
for name in kd_variants:
    if name not in mean_bin_variants:
        # Variant names use 1-based positions, e.g. "G38C" -> index 37, residue "C".
        pos = int(name[1:-1]) - 1
        val = AA_CODE.index(name[-1])
        pf_variant = pf_rel_a4[pos, val]
        mean_bin_variants[name] = (pf_variant * bin_index_a4).sum()

In [None]:
# Scatter of inverse Kd against mean predicted bin, one labelled point per variant.
sorted_names = sorted(kd_variants)
kd_variants_rv = [1 / kd_variants[name] for name in sorted_names]
mean_bin_variants_v = [mean_bin_variants[name] for name in sorted_names]

fig, ax = plt.subplots()
ax.scatter(mean_bin_variants_v, kd_variants_rv)
for txt, x_value, y_value in zip(sorted_names, mean_bin_variants_v, kd_variants_rv):
    ax.annotate(txt, (x_value, y_value))
ax.set_xlabel("mean predicted bin")
ax.set_ylabel("1 / Kd")
plt.savefig(f"{base_dir}/Kd_correlation_a4.svg")

#### Overview plot

In [None]:
import copy
def dot_heatmap(predictions, predictions_wt, confidence, gt, path, scale=0.25):
    """Draw a dot heat map of predicted activity per (position, residue).

    Dot colour is the expected log signal of the variant (bin index / 2,
    weighted by predicted bin probabilities) on a blue-white-red scale centred
    on the wild-type value. Dot size encodes ``confidence``, rescaled to
    ``[0.01, 1.0]``. Wild-type residues are given the colour value 100 and are
    drawn black through the colormap's "over" colour (this assumes the colour
    range stays below 100).

    Args:
        predictions: array ``[position, residue, bin]`` of bin probabilities.
        predictions_wt: 1-D array of wild-type bin probabilities.
        confidence: array ``[position, residue]``; larger values give larger dots.
        gt: reference peptide (one residue per position).
        path: output file for the figure.
        scale: accepted for call compatibility; the cell size is fixed at 0.25 inch.

    Side effects:
        Prints the wild-type expected value and saves the figure to ``path``.
    """
    bins = np.arange(predictions.shape[-1])
    pf_corr = (predictions - predictions_wt[None, None]) * bins[None, None, :]
    pf_mean = pf_corr.sum(axis=-1) / 2
    wt = (predictions_wt * bins).sum() / 2
    print(wt)
    # Symmetric colour range around the wild-type value.
    half_range = max(-pf_mean.min(), pf_mean.max())
    pmin = -half_range + wt
    pmax = half_range + wt

    cell = 0.25
    width = predictions.shape[0]
    fig, ax = plt.subplots(1, 1, figsize=(width * cell, 20 * cell), sharex=False, sharey=False)

    data_x = []
    data_y = []
    c = []
    s = []
    for idy in range(20):
        for idx in range(width):
            data_x.append(idx)
            data_y.append(-idy)
            if AA_CODE.index(gt[idx]) == idy:
                c.append(100)
                s.append(1)
            else:
                c.append(pf_mean[idx, idy] + wt)
                s.append(confidence[idx, idy] / 2)

    # Compute the extremes once (they were recomputed for every element) and
    # guard against a zero range when all sizes are equal.
    s_min, s_max = min(s), max(s)
    s_range = (s_max - s_min) or 1.0
    s = [0.99 * (value - s_min) / s_range + 0.01 for value in s]

    # Work on a private copy so the shared `cm.bwr` colormap is not modified
    # for later plots (`copy` is imported at the top of this cell).
    cmap = copy.copy(cm.bwr)
    cmap.set_over("black")

    ca = ax.scatter(data_x, data_y, c=c, s=[sv * 100 for sv in s], cmap=cmap, vmin=pmin, vmax=pmax)
    ax.axis("off")
    fig.colorbar(ca, ax=ax)
    fig.savefig(path)

# Overview dot heat map for A4, with dot size taken from the entropy score.
# NOTE(review): the file name says "DS1" although the inputs are the *_ds3
# results — confirm the intended name. `scale` is not used by dot_heatmap.
dot_heatmap(pf_rel_a4, pf_wt_a4, entropy_a4_ds3, gt_a4, f"{base_dir}/large-prediction-dotheatmap-A4-DS1-final.svg", scale=4)

#### amino acid type overview

In [None]:
# position-wise violin plots
# Zero-based position groups of the A4 reference sequence, used to index
# pf_rel_a4 by position in the cell below.
# The structural lists (contact, surface, core, helix, sheet, loop) are
# hand-written; presumably taken from a structure of the protein — TODO confirm source.
cas_positions = [13, 16, 17, 35, 37, 38, 39, 66, 68, 69]
# Every position of the 87-residue sequence that is not a contact position.
non_cas_positions = [idx for idx in range(87) if idx not in cas_positions]
# NOTE(review): positions 17 and 82-86 appear in neither the surface nor the core list.
surface_positions = [0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 51, 52, 53, 55, 56, 57, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 77, 78, 79, 80, 81] # TODO
core_positions = [5, 9, 16, 29, 31, 33, 50, 54, 58, 72, 76]
# Residue classes as one-letter codes; the *_positions lists below are the
# positions whose reference residue belongs to the class.
polar = "RNDCEQHKSTWY"
non_polar = "AGILMFPV"
basic = "RHK"
acidic = "DE"
neutral = "ANCQGILMFPSTWYV"
polar_positions = [idx for idx, c in enumerate(gt_a4) if c in polar] # TODO
nonpolar_positions = [idx for idx, c in enumerate(gt_a4) if c in non_polar]
acidic_positions = [idx for idx, c in enumerate(gt_a4) if c in acidic]
basic_positions = [idx for idx, c in enumerate(gt_a4) if c in basic]
neutral_positions = [idx for idx, c in enumerate(gt_a4) if c in neutral]
# Secondary-structure groups (helix and sheet are merged as "structured" below).
helix_positions = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 50, 51, 52, 53, 54, 55, 56, 57, 58, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80]
sheet_positions = [28, 29, 30, 31, 32, 39, 40, 41, 42, 43]
loop_positions = [0, 1, 13, 14, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 33, 34, 35, 36, 37, 38, 44, 45, 46, 47, 48, 49, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 81]

In [None]:
# For every position group collect, per (position, residue), the predicted
# expected log signal minus the wild-type expected log signal. Bin k
# contributes 0.5 * (k + 1).
position_groups = {
    "spyCas9 contact": cas_positions,
    "non-contact": non_cas_positions,
    "surface": surface_positions,
    "core": core_positions,
    "polar": polar_positions,
    "non-polar": nonpolar_positions,
    "acidic": acidic_positions,
    "basic": basic_positions,
    "neutral": neutral_positions,
    "structured": helix_positions + sheet_positions,
    "loop": loop_positions
}
wt_expected_signal = (pf_wt_a4 * (0.5 * np.arange(25) + 0.5)).sum(axis=0)

names = []
values = []
for name, positions in position_groups.items():
    items = []
    for pos in positions:
        expected_signal = (pf_rel_a4[pos] * 0.5 * (np.arange(25)[None, :] + 1)).sum(axis=1)
        update = list(expected_signal - wt_expected_signal)
        # Keep only residues whose entropy score is non-negative.
        items += [item for idx, item in enumerate(update) if entropy_a4_ds3[pos, idx] >= 0.0]
    names.extend([name] * len(items))
    values.extend(items)

split_positions = dict(name=names, value=values)

In [None]:
from seaborn import violinplot
# Violin plot of the activity change per position group.
fig, ax = plt.subplots(figsize=(20, 5))
violinplot(x="name", y="value", width=1, data=split_positions, ax=ax, inner="quartiles")

# Three lines are styled per group: the middle one solid, the other two
# dotted; all are black.
for line_number, l in enumerate(ax.lines):
    l.set_linestyle('-' if line_number % 3 == 1 else ":")
    l.set_color('black')

plt.savefig(f"{base_dir}/acriia4_activity_distribution.svg")

#### Big overview plot (sequencing)

In [None]:
# One small panel per (position, residue): relative read frequency across the
# sorted fractions. Titles use the A4 reference; the previous code read `gt`,
# which no cell of this notebook defines.
fig, ax = plt.subplots(87, 20, figsize=(40, 180))
for idx in range(87):
    for idy in range(20):
        ax[idx, idy].plot(relative_a4_ds3[0][:, idx, idy])
        ax[idx, idy].set_ylim(0, 0.4)
        ax[idx, idy].set_xticks(range(8))
        ax[idx, idy].set_xticklabels([])
        ax[idx, idy].set_title(f"{idx + 1}: {gt_a4[idx]}{idx + 1}{AA_CODE[idy]}")
plt.tight_layout()
fig.savefig(f"{base_dir}/large-overview-reads-A4-DS3.svg")

#### Big overview plot (predictions)

In [None]:
# One small panel per (position, residue): predicted FACS histogram. Titles use
# the A4 reference; the previous code read `gt`, which no cell of this notebook
# defines.
fig, ax = plt.subplots(87, 20, figsize=(40, 180))
for idx in range(87):
    for idy in range(20):
        ax[idx, idy].plot(pf_rel_a4[idx, idy])
        ax[idx, idy].set_ylim(0, 0.5)
        ax[idx, idy].set_xticklabels([])
        ax[idx, idy].set_title(f"{idx + 1}: {gt_a4[idx]}{idx + 1}{AA_CODE[idy]}")
plt.tight_layout()
fig.savefig(f"{base_dir}/large-overview-prediction-A4.svg")

## Predict KL divergence (trRosetta / Alphafold)

In [None]:
# Write one single-sequence .a3m file for the A4 reference and one for every
# single-residue substitution. The previous code read `gt`, which no cell of
# this notebook defines; the A4 reference is used throughout now.
with open(f"{base_dir}/trRosetta/acrin/a4/a4.a3m", "w") as f:
    f.write(">a4\n")
    f.write(gt_a4 + "\n")
for idx, wt_residue in enumerate(gt_a4):
    for a1 in AA_CODE:
        if a1 == wt_residue:
            # Substituting the reference residue reproduces the reference.
            continue
        seq = gt_a4[:idx] + a1 + gt_a4[idx + 1:]
        # File names carry the zero-based position.
        name = f"a4{wt_residue}{idx}{a1}"
        with open(f"{base_dir}/trRosetta/acrin/a4/{name}.a3m", "w") as f:
            f.write(f">{name}\n")
            f.write(seq + "\n")

In [None]:
AA_CODE = "ACDEFGHIKLMNPQRSTVWY"
# (residue, position) maps of the divergence between each variant's predicted
# distribution and the wild type's; cells without a file stay NaN.
kl_map = torch.full((20, 87), float("nan"))
rkl_map = torch.full((20, 87), float("nan"))
base = f"{base_dir}/alpha-dms/"
wt = torch.tensor(np.load(f"{base_dir}/alpha-dms/DMSjob_wt_model_1.npy"))
log_wt = wt.log_softmax(dim=2)
for fname in os.listdir(base):
    if not fname.endswith(".npy") or "wt" in fname:
        continue
    # The 2nd "_"-separated token encodes the variant; its number is used
    # directly as a zero-based column index.
    key = fname.split("_")[1]
    pos = int(key[1:-1])
    aa = AA_CODE.index(key[-1])
    log_p = torch.tensor(np.load(f"{base}{fname}")).log_softmax(dim=2)
    # KL(variant || wt) and the reverse KL(wt || variant), summed over all entries.
    kl_map[aa, pos] = (log_p.exp() * (log_p - log_wt)).sum(dim=2).sum()
    rkl_map[aa, pos] = (log_wt.exp() * (log_wt - log_p)).sum(dim=2).sum()

# Each CSV line fills one column of the (residue, position) matrix.
with open(f"{base_dir}/trRosetta/a4_trRosetta_kl_divergence.csv") as csv:
    rosetta_kl = np.zeros((20, 87))
    for idx, line in enumerate(csv):
        rosetta_kl[:, idx] = np.array([float(item) for item in line.strip().split(",")])

In [None]:
import copy
def kl_dot_heatmap(kl_map, predictions, predictions_wt, gt, path, scale=0.25):
    """Draw a grey-scale dot heat map of a (residue, position) divergence map.

    All dots have the same size; only the colour varies with ``kl_map``.

    Args:
        kl_map: 2-D map indexed as ``[residue, position]``.
        predictions: array whose first axis length gives the number of
            positions to draw.
        predictions_wt, gt, scale: accepted for call compatibility; not used
            for the drawing.
        path: output file for the figure.
    """
    cell = 0.25
    width = predictions.shape[0]
    fig, ax = plt.subplots(1, 1, figsize=(width * cell, 20 * cell), sharex=False, sharey=False)

    data_x = []
    data_y = []
    c = []
    for idy in range(20):
        for idx in range(width):
            data_x.append(idx)
            data_y.append(-idy)
            c.append(kl_map[idy, idx])

    # Every dot had relative size 0.01 before (all raw sizes were equal), so
    # the marker area is the constant 0.01 * 5000.
    sizes = [0.01 * 5000] * len(c)

    # Work on a private copy so the shared `cm.Greys` colormap is not modified
    # for later plots (`copy` is imported at the top of this cell).
    cmap = copy.copy(cm.Greys)
    cmap.set_over("black")

    ca = ax.scatter(data_x, data_y, c=c, s=sizes, cmap=cmap)
    ax.axis("off")
    fig.colorbar(ca, ax=ax)
    fig.savefig(path)

# Same heat map for both divergence sources; only the map and the file suffix differ.
for source_tag, divergence_map in (("rosetta", rosetta_kl), ("alpha", kl_map)):
    kl_dot_heatmap(
        divergence_map, pf_rel_a4, pf_wt_a4, gt_a4,
        f"{base_dir}/large-prediction-dotheatmap-A4-DS1-{source_tag}.svg",
        scale=4,
    )

#### export mean KL for pymol

In [None]:
def get_mean_colors(x, cmap):
    """Map the per-position mean of a (residue, position) map to RGB colours.

    Args:
        x: 2-D array-like; NaN entries count as 0 in the mean.
        cmap: name of a matplotlib colormap.

    Returns:
        list of RGB tuples, one per column of ``x``, with the column means
        rescaled linearly so the smallest maps to 0 and the largest to 1.
    """
    # `nan` must be passed by keyword: the second positional argument of
    # np.nan_to_num is `copy`, not the replacement value.
    mean = np.nan_to_num(np.array(x), nan=0.0).mean(axis=0)
    cmap = plt.get_cmap(cmap)
    low, high = mean.min(), mean.max()
    return [cmap((v - low) / (high - low))[:3] for v in mean]
print(get_mean_colors(rosetta_kl, "Greys"))
print(get_mean_colors(kl_map, "Greys"))

In [None]:
def get_mean_cbar(x, cmap, path):
    """Save a one-row image of the per-position mean together with its colour bar.

    Args:
        x: 2-D array-like; NaN entries count as 0 in the mean.
        cmap: colormap passed to ``plt.matshow``.
        path: output file for the figure.
    """
    # `nan` must be passed by keyword: the second positional argument of
    # np.nan_to_num is `copy`, not the replacement value.
    mean = np.nan_to_num(np.array(x), nan=0.0).mean(axis=0)
    cx = plt.matshow(mean[None], cmap=cmap)
    plt.colorbar(cx)
    plt.savefig(path)

get_mean_cbar(rosetta_kl, "Greys", f"{base_dir}/cbar_rosetta_kl.svg")
get_mean_cbar(kl_map, "Greys", f"{base_dir}/cbar_alpha_kl.svg")

#### scatter alpha vs rosetta

In [None]:
# Log divergence from the alpha-dms map against the trRosetta map, one point
# per (residue, position) cell.
alpha_log_kl = kl_map.reshape(-1).log()
rosetta_log_kl = np.log(rosetta_kl.reshape(-1))
plt.scatter(alpha_log_kl, rosetta_log_kl)

#### scatter KL vs predicted activity

In [None]:
def kl_scatter(kl_map, predictions, predictions_wt, gt, path, scale=0.25):
    """Scatter divergence against predicted activity, one point per cell.

    The y value is the expected log signal of each (position, residue) cell:
    bin index / 2 weighted by the predicted bin probabilities. The x value is
    the matching ``kl_map[residue, position]`` entry.

    Args:
        kl_map: 2-D map indexed as ``[residue, position]``.
        predictions: array ``[position, residue, bin]`` of bin probabilities.
        predictions_wt: 1-D array of wild-type bin probabilities.
        gt, scale: accepted for call compatibility; not used.
        path: output file for the figure.

    Side effects:
        Prints the wild-type expected value and saves the figure to ``path``.
    """
    bins = np.array(range(predictions.shape[-1]))
    pf_corr = (predictions - predictions_wt[None, None]) * bins[None, None, :]
    pf_mean = pf_corr.sum(axis=-1) / 2
    wt = (predictions_wt * bins).sum() / 2
    print(wt)

    x = 0.25
    width = predictions.shape[0]
    fig, ax = plt.subplots(1, 1, figsize=(width * x, 20 * x), sharex=False, sharey=False)

    c = []
    s = []
    for idy in range(20):
        for idx in range(width):
            c.append(kl_map[idy, idx])
            s.append(pf_mean[idx, idy] + wt)

    ax.scatter(c, s)
    fig.savefig(path)

# Saved under its own name: a later cell writes the position-averaged plot to
# "kl-mean-scatter.svg", which used to overwrite this per-cell figure.
kl_scatter(np.log(kl_map), pf_rel_a4, pf_wt_a4, gt_a4, f"{base_dir}/kl-activity-scatter.svg")

#### mean KL vs mean activity

In [None]:
def kl_scatter(kl_map, predictions, predictions_wt, gt, path, scale=0.25):
    """Scatter the log of the per-position mean KL against the per-position mean shift.

    NOTE(review): this definition reuses the name of the ``kl_scatter`` defined
    in an earlier cell and replaces it once this cell has run.

    Parameters
    ----------
    kl_map : 2-D array-like
        Read column-wise as ``kl_map[:, idx]``; NaN entries count as 0.0 in
        the column mean. The caller's object is left untouched.
    predictions : numpy.ndarray
        Read as ``predictions[idx, idy, bin]``.
        Presumably per-variant bin probabilities — confirm against the caller.
    predictions_wt : numpy.ndarray
        One value per bin for the wild type.
    gt : str
        Not used; kept so existing calls keep working.
    path : str
        Destination handed to ``fig.savefig``.
    scale : float, optional
        Figure size factor: the figure is ``width * scale`` by ``20 * scale``.
    """
    # `nan` has to be given by keyword: positionally, 0.0 binds to `copy`,
    # which asks np.nan_to_num to modify the caller's data in place.
    kl_nonan = np.nan_to_num(kl_map, nan=0.0)

    bin_index = np.arange(predictions.shape[-1])
    # Bin-index-weighted difference to the wild type, summed over bins.
    shift = ((predictions - predictions_wt[None, None]) * bin_index[None, None, :]).sum(axis=-1) / 2

    width = predictions.shape[0]
    # `scale` was previously ignored in favour of a hard-coded 0.25 (its default).
    fig, ax = plt.subplots(1, 1, figsize=(width * scale, 20 * scale), sharex=False, sharey=False)

    log_mean_kl = [np.log(kl_nonan[:, idx].mean(axis=0)) for idx in range(width)]
    mean_shift = [shift[idx, :].mean(axis=0) for idx in range(width)]

    # Draw on the axes created above instead of relying on the implicit current axes.
    ax.scatter(log_mean_kl, mean_shift)
    fig.savefig(path)

# Uses the kl_scatter defined in this cell; kl_map is passed raw because the function takes the log itself.
kl_scatter(kl_map, pf_rel_a4, pf_wt_a4, gt_a4, f"{base_dir}/kl-mean-scatter.svg")

#### KL by amino-acid position

In [None]:
# Residue groups (0-based indices into gt_a4) used by the group-wise KL plots below.
# NOTE(review): indices 17 and 82-86 appear in neither surface/core nor
# helix/sheet/loop — confirm that this is intended.
cas_positions = [13, 16, 17, 35, 37, 38, 39, 66, 68, 69]
# Derived from the sequence length rather than a hard-coded 87.
non_cas_positions = [idx for idx in range(len(gt_a4)) if idx not in cas_positions]
surface_positions = [0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 51, 52, 53, 55, 56, 57, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 73, 74, 75, 77, 78, 79, 80, 81] # TODO
core_positions = [5, 9, 16, 29, 31, 33, 50, 54, 58, 72, 76]
# One-letter amino-acid classes; positions are assigned by the wild-type residue.
polar = "RNDCEQHKSTWY"
non_polar = "AGILMFPV"
basic = "RHK"
acidic = "DE"
neutral = "ANCQGILMFPSTWYV"
polar_positions = [idx for idx, c in enumerate(gt_a4) if c in polar] # TODO
nonpolar_positions = [idx for idx, c in enumerate(gt_a4) if c in non_polar]
acidic_positions = [idx for idx, c in enumerate(gt_a4) if c in acidic]
basic_positions = [idx for idx, c in enumerate(gt_a4) if c in basic]
neutral_positions = [idx for idx, c in enumerate(gt_a4) if c in neutral]
helix_positions = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 50, 51, 52, 53, 54, 55, 56, 57, 58, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80]
sheet_positions = [28, 29, 30, 31, 32, 39, 40, 41, 42, 43]
loop_positions = [0, 1, 13, 14, 15, 16, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 33, 34, 35, 36, 37, 38, 44, 45, 46, 47, 48, 49, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 81]

In [None]:
# Quick check: number of positions whose wild-type residue is in `non_polar`.
len(nonpolar_positions)

In [None]:
def plot_kl_violin(kl):
    """Draw one violin of ``-log(kl)`` per residue group.

    The groups are read from the module-level position lists at call time.

    Parameters
    ----------
    kl : 2-D numpy.ndarray
        Read column-wise as ``kl[:, position]``; only entries ``> 0.0``
        contribute, so zeros and NaNs are dropped before the log.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the violin plot.
    """
    # Imported locally because the notebook imports seaborn only in a later
    # cell; without this a fresh top-to-bottom run fails with NameError here.
    from seaborn import violinplot

    groups = {
        "spyCas9 contact": cas_positions,
        "non-contact": non_cas_positions,
        "surface": surface_positions,
        "core": core_positions,
        "polar": polar_positions,
        "non-polar": nonpolar_positions,
        "acidic": acidic_positions,
        "basic": basic_positions,
        "neutral": neutral_positions,
        "structured": helix_positions + sheet_positions,
        "loop": loop_positions
    }

    names = []
    values = []
    for name, positions in groups.items():
        items = []
        for pos in positions:
            column = kl[:, pos]
            items += list(map(float, -np.log(column[column > 0.0])))
        names.extend([name] * len(items))
        values.extend(items)

    fig, ax = plt.subplots(figsize=(20, 5))
    violinplot(x="name", y="value", width=1, data=dict(name=names, value=values), ax=ax, inner="quartiles")

    # Every second line out of three is drawn solid, the others dotted.
    # Assumes seaborn adds three quartile lines per violin, median in the middle — confirm.
    for line_index, line in enumerate(ax.lines):
        line.set_linestyle('-' if line_index % 3 == 1 else ":")
        line.set_color('black')
    return ax

In [None]:
# `nan=` must be a keyword: positionally, 0.0 binds to `copy`, and np.nan_to_num
# can then overwrite kl_map / rosetta_kl in place, changing what later cells see.
plot_kl_violin(np.nan_to_num(kl_map, nan=0.0))
plt.savefig(f"{base_dir}/acriia4_alpha_kl_distribution.svg")
plot_kl_violin(np.nan_to_num(rosetta_kl, nan=0.0))
plt.savefig(f"{base_dir}/acriia4_tr_kl_distribution.svg")

#### scatter KL vs activity for AA type

In [None]:
# Group label -> list of 0-based residue indices, built from the position lists
# as they are bound when this cell runs. The loop at the end of this cell
# iterates over its keys to produce one scatter plot per group.
pos_dict = {
    "spyCas9 contact": cas_positions,
    "non-contact": non_cas_positions,
    "surface": surface_positions,
    "core": core_positions,
    "polar": polar_positions,
    "non-polar": nonpolar_positions,
    "acidic": acidic_positions,
    "basic": basic_positions,
    "neutral": neutral_positions,
    "structured": helix_positions + sheet_positions,
    "loop": loop_positions
}

def plot_kl_scatter(kl, res, select=None):
    """Scatter ``-log(kl)`` against the bin-weighted mean of ``res`` for residue groups.

    Parameters
    ----------
    kl : 2-D array-like
        Read column-wise as ``kl[:, position]``; only entries ``> 0.0`` are plotted.
    res : numpy.ndarray
        Read as ``res[position, :, bin]``; the last axis is weighted with
        ``0.5 * bin_index`` and summed.
    select : iterable of str, optional
        Keys of the module-level ``pos_dict`` to plot. ``None`` plots every group.

    Returns
    -------
    matplotlib.axes.Axes
        The axes holding the scatter plot; points are coloured by group.
    """
    # Reuse the module-level pos_dict instead of rebuilding an identical copy here.
    groups = pos_dict if select is None else {key: pos_dict[key] for key in select}
    # Number of bins taken from the data rather than a hard-coded 25.
    bin_index = np.arange(res.shape[-1])[None, :]

    names = []
    kl_values = []
    fluorescence = []
    for name, positions in groups.items():
        kl_items = []
        fl_items = []
        for pos in positions:
            keep = kl[:, pos] > 0.0
            kl_items += list(map(float, -np.log(kl[:, pos][keep])))
            fl_items += list((res[pos, :, :] * bin_index * 0.5).sum(axis=-1)[keep])
        names.extend([name] * len(kl_items))
        kl_values.extend(kl_items)
        fluorescence.extend(fl_items)

    # One integer per group, looked up once instead of searching the key list per point.
    group_index = {name: index for index, name in enumerate(groups)}
    fig, ax = plt.subplots(figsize=(20, 5))
    ax.scatter(kl_values, fluorescence, c=[group_index[name] for name in names], cmap="tab10")
    return ax
# One scatter figure per residue group, saved under a file-name-safe group label.
for group_name in pos_dict:
    plot_kl_scatter(kl_map, pf_rel_a4, select=[group_name])
    filepos = group_name.replace(" ", "_").replace("-", "_")
    # The previous line read `f"f"{base_dir}...`, a SyntaxError that kept the whole cell from running.
    plt.savefig(f"{base_dir}/alphafold_scatter_{filepos}.svg")

In [None]:
from seaborn import violinplot, boxplot

# One box per position (1-based label), each built from 20 values of -log(kl_map).
# NOTE(review): `.log()` implies kl_map is a torch tensor here — confirm.
positions = [position + 1 for position in range(87) for _ in range(20)]
values = [
    float(entry)
    for position in range(87)
    for entry in -kl_map[:, position].log()
]
sp = {"name": positions, "value": values}
fig, ax = plt.subplots(figsize=(20, 5))
boxplot(x="name", y="value", width=1, data=sp, ax=ax)

# Every second line out of three is drawn solid, all others dotted; all are black.
for line_index, line in enumerate(ax.lines):
    line.set_linestyle('-' if line_index % 3 == 1 else ":")
    line.set_color('black')

plt.savefig(f"{base_dir}/acriia4_alpha_kl_distribution_per_position.svg")

#### correlation mean fluorescence + cell culture

In [None]:
# Measured value pairs per variant; `kind` below selects which element of each pair is plotted.
cell_culture = {
    "wt": (0.958145856666667, 0.55904086),
    "K18M": (0.82442277, 0.40304539),
    "G21Q": (0.80654383, 0.32298049),
    "S24K": (0.773889983333333, 0.375188933333333),
    "S24P": (0.706929726666667, 0.30560112),
    "N25G": (0.820051383333333, 0.376600463333333),
    "I31Q": (0.83000416, 0.386209923333333),
    "E40I": (0.731740993333333, 0.33765131),
    "E70T": (0.069560613333333, 0.035723183333333),
    "M77A": (0.174492683333333, 0.084132156666667),
}

kind = 0
mutants = [name for name in cell_culture if name != "wt"]
names = ["wt"] + mutants
cc = [cell_culture[name][kind] for name in names]
# Mutation labels read <wild-type residue><1-based position><new residue>;
# the position becomes a 0-based row and the new residue an AA_CODE index.
value = [(pf_wt_a4 * np.arange(25) * 0.5).sum()] + [
    (pf_rel_a4[int(name[1:-1]) - 1, AA_CODE.index(name[-1])] * np.arange(25) * 0.5).sum()
    for name in mutants
]

fig, ax = plt.subplots()
ax.scatter(value, cc)

for label, point in zip(names, zip(value, cc)):
    ax.annotate(label, point)

plt.savefig(f"{base_dir}/acriia4_fluorescence_vs_cell_culture.svg")

#### correlation mean fluorescence to prior work

In [None]:
# Literature values per mutation; the first data set enters as -log(v + 1e-3), the second as log(v).
names_basgall = ["D14A", "D23R", "N36A", "D37A", "G38A", "N39A", "N39R", "E40A", "N48A", "D69A", "D69R", "E70A", "E70R", "E72A", "F73A", "D76A", "M77A"]
drive_activity_basgall = [0.027, 0.074, 0.015, 0.015, 0.084, 0.344, 0.845, 0.020, 0.000, 0.015, 0.113, 0.885, 0.983, 0.000, 0.595, 0.000, 0.432]
activity_basgall = [-np.log(v + 1e-3) for v in drive_activity_basgall]

names_dong = ["N12T", "D14R", "D23R", "N36Y", "G38A", "N39R", "E40R", "D69R", "E70R"]
activity_dong = [np.log(v) for v in [1.067, 0.216, 0.043, 0.062, 0.056, 0.048, 0.038, 0.130, 0.022]]


def _predicted_mean(mutation):
    """Bin-weighted mean of ``pf_rel`` for a label like ``"D14A"`` (1-based position)."""
    # NOTE(review): reads the global `pf_rel`, not `pf_rel_a4` as the neighbouring
    # cells do — confirm both refer to the same predictions at this point.
    row = int(mutation[1:-1]) - 1
    column = AA_CODE.index(mutation[-1])
    return (pf_rel[row, column] * np.arange(25) * 0.5).sum()


pred_log_basgall = [_predicted_mean(name) for name in names_basgall]
pred_log_dong = [_predicted_mean(name) for name in names_dong]

fig, ax = plt.subplots(1, 2, figsize=(15, 5))

ax[0].scatter(pred_log_basgall, activity_basgall)
for label, point in zip(names_basgall, zip(pred_log_basgall, activity_basgall)):
    ax[0].annotate(label, point)

ax[1].scatter(pred_log_dong, activity_dong)
for label, point in zip(names_dong, zip(pred_log_dong, activity_dong)):
    ax[1].annotate(label, point)

plt.savefig(f"{base_dir}/correlation_prior_work.svg")

#### R^2 for log-log fit

In [None]:
# Pearson correlation between prediction and literature value, computed once per data set.
r_basgall = np.corrcoef(np.array(pred_log_basgall), np.array(activity_basgall))[0, 1]
r_dong = np.corrcoef(np.array(pred_log_dong), np.array(activity_dong))[0, 1]
print("R^2 Basgall et al.", r_basgall ** 2)
print("R^2 Dong et al.", r_dong ** 2)
print("R Basgall et al.", r_basgall)
print("R Dong et al.", r_dong)

### A5

In [None]:
# a5 spy
pf = predict_facs(linear_a5_ds3.eval(), relative_a5_ds3)
pf_wt = predict_wt(linear_a5_ds3.eval(), relative_a5_ds3)
pf_wt = pf_wt.softmax(dim=0).detach().numpy()
pf_rel = pf.softmax(dim=-1).detach().numpy()
pf_esc = (abs(pf) / abs(pf).sum(dim=-1, keepdim=True)).detach().numpy()
pf_wt_a5 = pf_wt
pf_rel_a5 = pf_rel

#### Mutation tolerance A5

In [None]:
# Bin-weighted means (weight 0.5 * bin index) for every variant and for the wild type.
bin_index = np.arange(25)
mean_log_fluorescence = (pf_rel * bin_index[None, None, :] * 0.5).sum(axis=-1)
mean_log_fluorescence_wt = (pf_wt * bin_index * 0.5).sum(axis=-1)

# Fraction of all entries reaching at least `threshold` times the wild-type mean.
n_entries = mean_log_fluorescence.size
for threshold in (0.9, 0.95, 1.0):
    reaches_threshold = mean_log_fluorescence >= mean_log_fluorescence_wt * threshold
    print(threshold, reaches_threshold.sum() / n_entries)

#### Overview plot

In [None]:
# Overview figure for the A5 predictions; the path argument points to an SVG under base_dir.
dot_heatmap(pf_rel_a5, pf_wt_a5, entropy_a5_ds3, gt_a5, f"{base_dir}/large-prediction-dotheatmap-A5-DS3.svg", scale=4)

#### overview reads

In [None]:
# Grid of small line plots: one row per A5 position, one column per AA_CODE entry.
fig, ax = plt.subplots(len(gt_a5), 20, figsize=(40, 180))
for idx in range(len(gt_a5)):
    for idy in range(20):
        panel = ax[idx, idy]
        panel.plot(relative_a5_ds3[0][:, idx, idy])
        panel.set_ylim(0, 0.4)
        panel.set_xticks(range(8))
        panel.set_xticklabels([])
        # Title uses the A5 wild-type sequence; the previous `gt[idx]` referred to
        # a name this cell never defines (at best a stale global from another section).
        panel.set_title(f"{idx + 1}: {gt_a5[idx]}{idx + 1}{AA_CODE[idy]}")
plt.tight_layout()
fig.savefig(f"{base_dir}/large-overview-reads-A5-DS3.svg")

#### overview prediction

In [None]:
# Grid of small line plots of the predicted distribution: rows are A5 positions,
# columns are AA_CODE entries.
fig, ax = plt.subplots(len(gt_a5), 20, figsize=(40, 180))
for (idx, idy), panel in np.ndenumerate(ax):
    panel.plot(pf_rel_a5[idx, idy])
    panel.set_ylim(0, 0.5)
    panel.set_xticklabels([])
    panel.set_title(f"{idx + 1}: {gt_a5[idx]}{idx + 1}{AA_CODE[idy]}")
plt.tight_layout()
fig.savefig(f"{base_dir}/large-overview-prediction-A5.svg")

#### amino-acid type overview

In [None]:
# Residue groups (0-based indices into gt_a5) for the A5 group-wise violin plot.
# NOTE(review): these assignments rebind the same module-level names that hold
# the A4 groups earlier in the notebook (surface_positions, core_positions, ...).
# plot_kl_violin reads those names at call time, so re-running an A4 plotting
# cell after this one would pick up the A5 indices.
surface_positions = [0, 1, 2, 4, 5, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 31, 33, 34, 37, 38, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 52, 54, 55, 56, 57, 58, 59, 61, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 81, 82, 85, 86, 87, 89, 92, 95, 96, 99, 100, 102, 103, 104, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 117, 118, 119, 120, 122, 123, 124, 126, 127, 129, 130, 132, 133, 134, 135, 136, 137, 138, 139]
core_positions = [3, 7, 29, 32, 35, 36, 39, 50, 51, 53, 60, 62, 63, 80, 83, 84, 88, 90, 91, 93, 94, 97, 98, 101, 105, 116, 121, 125, 128, 131]
# One-letter amino-acid classes; positions are assigned by the wild-type residue.
polar = "RNDCEQHKSTWY"
non_polar = "AGILMFPV"
basic = "RHK"
acidic = "DE"
neutral = "ANCQGILMFPSTWYV"
polar_positions = [idx for idx, c in enumerate(gt_a5) if c in polar]
nonpolar_positions = [idx for idx, c in enumerate(gt_a5) if c in non_polar]
acidic_positions = [idx for idx, c in enumerate(gt_a5) if c in acidic]
basic_positions = [idx for idx, c in enumerate(gt_a5) if c in basic]
neutral_positions = [idx for idx, c in enumerate(gt_a5) if c in neutral]
helix_positions = [4, 5, 6, 7, 8, 9, 10, 11, 12, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 110, 111, 112, 113, 114, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 135, 136, 137, 138]
sheet_positions = [42, 43, 44, 50, 51, 52, 53, 54, 59, 60, 61, 62, 63]
loop_positions = [0, 1, 2, 3, 13, 14, 15, 16, 17, 18, 39, 40, 41, 45, 46, 47, 48, 49, 55, 56, 57, 58, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 109, 115, 116, 117, 131, 132, 133, 134, 139]
# The first 19 positions, kept as their own group ("idr" in the plot below).
idr_positions = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18]

In [None]:
# Group label -> 0-based A5 positions contributing to that group's violin.
position_groups = {
    "surface": surface_positions,
    "core": core_positions,
    "polar": polar_positions,
    "non-polar": nonpolar_positions,
    "acidic": acidic_positions,
    "basic": basic_positions,
    "neutral": neutral_positions,
    "structured": helix_positions + sheet_positions,
    "loop": loop_positions,
    "idr": idr_positions
}

# Wild-type reference value; independent of the position, so computed once.
wt_reference = (pf_wt_a5 * (0.5 * np.arange(25) + 0.5)).sum(axis=0)
bin_number = np.arange(25)[None, :] + 1

names = []
values = []
for name, positions in position_groups.items():
    items = []
    for pos in positions:
        # Difference between each variant's bin-weighted mean and the wild-type reference.
        difference = (pf_rel_a5[pos] * 0.5 * bin_number).sum(axis=1) - wt_reference
        # Only variants whose entropy_a5_ds3 entry exceeds 0.6 are kept.
        items += [item for idx, item in enumerate(difference) if entropy_a5_ds3[pos, idx] > 0.6]
    names.extend([name] * len(items))
    values.extend(items)

split_positions = dict(name=names, value=values)

In [None]:
from seaborn import violinplot

# One violin per residue group from the split_positions table built in the previous cell.
fig, ax = plt.subplots(figsize=(20, 5))
violinplot(x="name", y="value", width=1, data=split_positions, ax=ax, inner="quartiles")

# Every second line out of three is drawn solid, all others dotted; all are black.
for line_index, line in enumerate(ax.lines):
    line.set_linestyle('-' if line_index % 3 == 1 else ":")
    line.set_color('black')

plt.savefig(f"{base_dir}/acriia5_activity_distribution.svg")