In [1]:
from torchtext import datasets, data
import numpy as np
import os, sys
from time import time

sys.path.append("../1. madex")

from neural_interaction_detection import *
from sampling_and_inference import *
from utils.dna_utils import *

%matplotlib inline

import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

# Prefer the first CUDA GPU, but fall back to the CPU when CUDA is unavailable
# so the notebook does not fail on GPU-less machines when tensors/modules are
# later moved to `device`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Load Model

In [2]:
# Load the model from the checkpoint under utils/pretrained, then place it on
# the device selected in the setup cell.
model = load_dna_model("utils/pretrained/dna_cnn.pt")
model = model.to(device)

## Get DNA Sequence

In [3]:
# Seed NumPy's global RNG before sampling (presumably the generator below
# draws from it — confirm in utils.dna_utils) so the sequence is repeatable.
np.random.seed(42)
seq_instance = generate_random_dna_sequence_with_CACGTG()

# Print the sequence followed by whether it contains the CACGTG motif.
contains_motif = "CACGTG" in seq_instance
print(seq_instance, contains_motif)

GTAGGTAAGCGCACGTGTTGCACTTCCCTTAATCCA True


## Run MADEX

In [4]:
# Describe the instance to explain: the original sequence under "orig" and the
# encoding function under "vectorizer".
data_inst = dict(orig=seq_instance, vectorizer=encode_dna_onehot)

# Build the perturbation dataset for this instance using the loaded model.
# NOTE(review): Xs/Ys are presumably perturbed inputs and the model's outputs
# on them — confirm against generate_perturbation_dataset_dna.
Xs, Ys = generate_perturbation_dataset_dna(
    data_inst,
    model,
    device,
    seed=42,
)

100%|██████████| 60/60 [00:02<00:00, 29.46it/s]


In [5]:
# Run interaction detection on the perturbation dataset and time the call.
t0 = time()
interactions, mlp_loss = detect_interactions(
    Xs,
    Ys,
    weight_samples=False,
    seed=42,
    verbose=False,
    add_linear=False,
)
elapsed_seconds = round(time() - t0, 1)

# Report the returned loss (4 decimals) and the wall-clock time (1 decimal).
print("{} test loss, {} seconds elapsed".format(round(mlp_loss, 4), elapsed_seconds))

0.0046 test loss, 16.0 seconds elapsed


In [6]:
# Print the first ten detected interactions, tagging any whose positions are
# consecutive and whose nucleotides spell exactly CACGTG.
print("interaction ranking", "\n")
for rank, inter in enumerate(interactions[:10]):
    # Each entry is a pair; only its first element (the positions) is used here.
    inter_indices, _ = inter

    # Pair every position with the nucleotide at that position in the sequence.
    inter_verbose = tuple((seq_instance[pos], pos) for pos in inter_indices)
    inter_nucleotides = tuple(nucleotide for nucleotide, _pos in inter_verbose)

    # The contiguity check is only evaluated when the letters already match.
    postfix = (
        "found CACGTG >>"
        if "".join(inter_nucleotides) == "CACGTG" and all(np.diff(inter_indices) == 1)
        else ""
    )

    # Render each member as "<nucleotide>_<position>" and print with a 1-based rank.
    labels = tuple(nucleotide + "_" + str(pos) for nucleotide, pos in inter_verbose)
    print(rank + 1, postfix, labels)



interaction ranking 

1 found CACGTG >> ('C_11', 'A_12', 'C_13', 'G_14', 'T_15', 'G_16')
2  ('A_21', 'C_25')
3  ('C_11', 'A_12', 'C_13', 'G_14', 'T_15', 'G_16', 'A_21')
4  ('C_11', 'A_12', 'C_13', 'G_14', 'T_15', 'G_16', 'T_18')
5  ('A_21', 'C_25', 'C_26')
6  ('A_21', 'T_23', 'C_25', 'C_26')
7  ('A_21', 'T_23', 'C_25', 'C_26', 'T_28')
8  ('A_2', 'C_11', 'A_12', 'C_13', 'G_14', 'T_15', 'G_16', 'T_18')
9  ('A_2', 'A_6', 'C_11', 'A_12', 'C_13', 'G_14', 'T_15', 'G_16', 'T_18')
10  ('A_2', 'A_6', 'C_11', 'A_12', 'C_13', 'G_14', 'T_15', 'G_16', 'T_18', 'C_20')
