In [1]:
import os

# Anchor the working directory to wherever the kernel started, so that
# re-running this cell does not climb two more directory levels each time
# (which would break the relative datafile path used later).
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, "..", ".."))
# Restrict CUDA to device 2; set here, before `torch` is imported in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tqdm.auto as tqdm
import torch

from modular_splicing.dataset.datafile_object import SpliceAIDatafile

from modular_splicing.data_pipeline.create_dataset import create_datapoints
from modular_splicing.models_for_testing.list import AM

from modular_splicing.motif_names import get_motif_names

In [3]:
# Reload edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [4]:
# Context-length constant: forwarded to `create_datapoints` as `CL_max`, and
# `context_around_allele` offsets into `x` by `CL_max // 2`.
CL_max = 10_000

In [5]:
def load_gene(name):
    """
    Load the model input ``x`` and target ``y`` for one gene.

    Reads ``datafile_train_all.h5`` (relative to the current working
    directory), locates the record whose name equals ``name`` and passes its
    fields to ``create_datapoints``.

    Parameters
    ----------
    name
        Gene name, compared for equality against ``dfile.names``.

    Returns
    -------
    tuple
        ``(x, y, start, end)``: the concatenated ``x`` and ``y`` produced by
        ``create_datapoints``, followed by the gene's entries in
        ``dfile.starts`` and ``dfile.ends``.

    Raises
    ------
    ValueError
        If ``name`` does not match exactly one record in the datafile.
    AssertionError
        If the matched gene is not on the ``-`` strand.
    """
    dfile = SpliceAIDatafile.load("datafile_train_all.h5")
    (matches,) = np.where(dfile.names == name)
    if len(matches) != 1:
        # Explicit message instead of a cryptic unpacking failure.
        raise ValueError(
            f"expected exactly one gene named {name!r} in the datafile, "
            f"found {len(matches)}"
        )
    idx = matches[0]
    strand = dfile.datafile["STRAND"][idx]
    # Downstream, `context_around_allele` computes the allele offset as
    # `end - allele`; presumably that only holds for `-` strand genes.
    assert strand == b"-", f"{name!r} is on strand {strand!r}; only b'-' is supported"
    x, [y] = create_datapoints(
        dfile.datafile["SEQ"][idx],
        strand,
        tx_start=dfile.datafile["TX_START"][idx],
        tx_end=dfile.datafile["TX_END"][idx],
        jn_start=dfile.datafile["JN_START"][idx],
        jn_end=dfile.datafile["JN_END"][idx],
        SL=1_000_000,
        CL_max=CL_max,
    )
    x, y = np.concatenate(x), np.concatenate(y)
    return x, y, dfile.starts[idx], dfile.ends[idx]


def kth_gene(y, k):
    """
    Return the positions of the two splice sites bounding the ``k``-th exon.

    Despite its name, this selects an exon, not a gene. ``y`` is indexed as
    ``y[position, channel]``; every nonzero entry in channels ``1:`` counts
    as a splice site, taken in order along axis 0.

    Sites 0 and 1 bound the first intron and sites 1 and 2 bound the second
    exon, so the ``k``-th exon is bounded by sites ``2 * k - 3`` and
    ``2 * k - 2``.

    Parameters
    ----------
    y
        2-D array of per-position labels.
    k
        1-based exon number; must be at least 2, because the first exon has
        only one site in this numbering.

    Returns
    -------
    numpy.ndarray
        Length-2 array with the positions of the two bounding sites.

    Raises
    ------
    ValueError
        If ``k < 2`` or ``y`` does not contain both sites of exon ``k``.
    """
    if k < 2:
        # 2 * k - 3 would be negative and silently wrap around the array.
        raise ValueError(f"k must be >= 2 to have two bounding sites, got {k}")
    pos, _ = np.where(y[:, 1:])
    first = 2 * k - 3
    sites = pos[first : first + 2]
    if len(sites) != 2:
        raise ValueError(
            f"exon {k} needs sites {first} and {first + 1}, "
            f"but only {len(pos)} sites are present"
        )
    return sites


def context_around_allele(x, y, *, allele, which_exon, cl, sl, s, e):
    """
    Cut a window of ``x`` and ``y`` centred on an allele.

    The allele's index into ``y`` is taken to be ``e - allele``. ``y`` is cut
    to ``sl // 2`` positions on each side of it; ``x`` is cut to the same
    window plus ``cl // 2`` positions on each side, after offsetting by
    ``CL_max // 2``.

    Parameters
    ----------
    x, y
        Full-gene arrays, as returned by ``load_gene``.
    allele
        Coordinate of the allele; subtracted from ``e`` to get its index.
    which_exon
        Exon number forwarded to ``kth_gene``.
    cl
        Total extra context kept around the ``y`` window in ``x``.
    sl
        Total width of the ``y`` window (``2 * (sl // 2)`` positions).
    s
        Unused; accepted so callers can pass both gene bounds.
    e
        Gene end coordinate.

    Returns
    -------
    tuple
        ``(x_window, y_window, exon)`` where ``exon`` holds the exon's site
        positions re-expressed relative to the start of ``y_window``.

    Raises
    ------
    ValueError
        If the requested window does not lie fully inside ``x`` and ``y``.
    """
    extra = sl // 2
    center = e - allele
    y_start, y_end = center - extra, center + extra
    x_start = CL_max // 2 + y_start - cl // 2
    x_end = CL_max // 2 + cl // 2 + y_end
    # Without this check a negative start would wrap around and an overlong
    # end would be truncated, both silently yielding a wrong window.
    if y_start < 0 or x_start < 0 or y_end > len(y) or x_end > len(x):
        raise ValueError(
            f"window around allele {allele} (y[{y_start}:{y_end}], "
            f"x[{x_start}:{x_end}]) falls outside the gene arrays "
            f"(len(y)={len(y)}, len(x)={len(x)})"
        )
    exon = kth_gene(y, which_exon) - y_start
    return x[x_start:x_end], y[y_start:y_end], exon

def case_study_data(*, gene_name, which_exon, allele, cl, sl):
    """
    Load a gene and return the window around one of its alleles.

    Calls ``load_gene(gene_name)`` and forwards its four results, together
    with the remaining keyword arguments, to ``context_around_allele``.
    Returns whatever ``context_around_allele`` returns.
    """
    gene_x, gene_y, gene_start, gene_end = load_gene(gene_name)
    window_options = dict(
        allele=allele,
        which_exon=which_exon,
        sl=sl,
        cl=cl,
        s=gene_start,
        e=gene_end,
    )
    return context_around_allele(x=gene_x, y=gene_y, **window_options)

In [6]:
x, y, exon = case_study_data(
    gene_name="OLR1", which_exon=5, allele=10313134, cl=400, sl=5000
)
# The window is cut symmetrically around the allele, so it sits at the midpoint.
allele_idx = x.shape[0] // 2
x_updated = x.copy()
# G > A ==> C > T on the - strand
x_updated[allele_idx] = [0, 0, 0, 1]
# Guard against a vacuous comparison: if the reference base already equals the
# substituted one, the "mutated" sequence would be identical to the original.
assert (x[allele_idx] != x_updated[allele_idx]).any(), (
    "substitution did not change the base at the allele position"
)

In [7]:
# NOTE(review): consider moving this import to the imports cell at the top.
from modular_splicing.utils.sequence_utils import draw_bases

# Display the 20 bases surrounding the midpoint of the window.
mid = x.shape[0] // 2
draw_bases(x.argmax(-1))[mid - 10 : mid + 10]

'TTTGTGGATCCAACACTAAC'

In [8]:
# Fetch the model under study and move it to the CPU.
# NOTE(review): the argument `1` presumably selects a seed/replicate — confirm.
am = AM.non_binarized_model(1).model.cpu()

In [9]:
def run_on_seq(am, x, exon):
    """
    Run the model on one sequence and read out predictions at two positions.

    Parameters
    ----------
    am
        Callable model, invoked as ``am(batch, collect_intermediates=True)``.
        It must return a mapping with ``"output"`` and
        ``"post_sparse_motifs_only"`` tensors.
    x
        Array for a single sequence, without a batch axis; a batch axis of
        size 1 is added before calling the model.
    exon
        Two positions along the model output's position axis.

    Returns
    -------
    tuple
        ``(probs, motifs)``. ``probs[0]`` is softmax channel 1 at
        ``exon[0]`` and ``probs[1]`` is softmax channel 2 at ``exon[1]``.
        ``motifs`` is batch element 0 of ``"post_sparse_motifs_only"``.
    """
    # Build the size-1 batch from a single ndarray. `torch.tensor([x])` would
    # go through a Python list of arrays, which is slow and makes PyTorch warn.
    batch = torch.tensor(np.asarray(x)[np.newaxis], dtype=torch.float)
    # NOTE(review): the model is used in whatever train/eval mode the caller
    # left it in — confirm it is in eval mode.
    with torch.no_grad():
        res = am(batch, collect_intermediates=True)
        out = res["output"].softmax(-1).numpy()[0]
        motifs = res["post_sparse_motifs_only"].numpy()[0]
    return out[exon, [1, 2]], motifs

In [10]:
# Predictions and motif activations for the unmodified sequence.
out, motifs = run_on_seq(am, x, exon)

In [11]:
# Same readout for the sequence carrying the substitution.
out_mut, motifs_mut = run_on_seq(am, x_updated, exon)

In [12]:
# One line per condition: the two values returned by `run_on_seq`, as percentages.
for label, probs in (("Without mutation:", out), ("With mutation:   ", out_mut)):
    print(label, f"A={probs[0]:.2%}", f"D={probs[1]:.2%}")

Without mutation: A=87.41% D=95.72%
With mutation:    A=82.50% D=96.80%


In [13]:
# Motif names for the "rbns" set, as an array so it can be indexed by motif id.
names = np.array(get_motif_names("rbns"))

In [14]:
# Find every (position, motif) entry whose value differs between the
# reference and mutated runs, then look up the corresponding motif names.
changed_mask = motifs != motifs_mut
diff_pos, diff_mot_id = np.nonzero(changed_mask)
diff_mot_names = names[diff_mot_id]

In [15]:
# Names of the motifs whose values changed (one entry per differing position).
diff_mot_names

array(['SRSF5', 'RBM6', 'ZNF326'], dtype='<U9')