[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mitiau/DNABERT-Z/blob/main/ZDNA-prediction.ipynb)

# Install dependencies and define helper functions

In [1]:
# Disabled setup cell: the shell commands are wrapped in a string literal, so
# none of them run; evaluating the cell only echoes the text as its output.
# Remove the surrounding triple quotes to actually install the packages.
'''!pip install transformers
!pip install biopython
!pip install transformers_interpret
!pip install numpy==1.25.0
!pip install shap
!python --version
!pip list'''

'!pip install transformers\n!pip install biopython\n!pip install transformers_interpret\n!pip install numpy==1.25.0\n!pip install shap\n!python --version\n!pip list'

In [2]:
import torch
from torch import nn
import transformers
from transformers import BertTokenizer, BertForTokenClassification
import numpy as np
from Bio import SeqIO
from io import StringIO, BytesIO
#from google.colab import drive, files
from tqdm import tqdm
import pickle
import scipy
from scipy import ndimage

In [3]:
def seq2kmer(seq, k):
    """Return every overlapping window of ``k`` items taken from ``seq``.

    The windows are returned as a list, in order of their start position.
    When ``seq`` is shorter than ``k`` no window fits and ``seq`` itself is
    returned unchanged (note: not wrapped in a list).
    """
    n_windows = len(seq) - k + 1
    if n_windows <= 0:
        # Too short for even one window: hand the input back as-is.
        return seq

    return [seq[start:start + k] for start in range(n_windows)]

def split_seq(seq, length = 512, pad = 16):
    """Cut ``seq`` into overlapping windows of at most ``length`` items.

    Consecutive windows start ``length - pad`` items apart, so neighbouring
    full-size windows share ``pad`` items.

    Args:
        seq: Any sliceable sequence (string, list, ...).
        length: Maximum number of items per window.
        pad: Number of items shared by two consecutive windows.

    Returns:
        List of slices of ``seq``; empty when ``seq`` is empty.

    Raises:
        ValueError: If ``length`` is not greater than ``pad`` (the window
            start would never advance).
    """
    step = length - pad
    if step <= 0:
        raise ValueError("length must be greater than pad")

    res = []
    for st in range(0, len(seq), step):
        # Use the requested window size (previously hard-coded to 512).
        end = min(st + length, len(seq))
        res.append(seq[st:end])
        if end == len(seq):
            # This window already reaches the end of the input. A further
            # window would lie entirely inside the overlap and, being shorter
            # than ``pad``, would make overlap-trimming stitching drop items.
            break
    return res

def stitch_np_seq(np_seqs, pad = 16):
    """Concatenate overlapping 1-D arrays into one array.

    Before each array is appended, the last ``pad`` items accumulated so far
    are discarded, so the overlap region is taken from the later array.

    Args:
        np_seqs: Iterable of 1-D numpy arrays, in order.
        pad: Number of trailing items to drop before each append. ``0`` means
            plain concatenation.

    Returns:
        The stitched array. It starts from an empty float array, so the
        result is empty when ``np_seqs`` is empty.

    Raises:
        ValueError: If ``pad`` is negative.
    """
    if pad < 0:
        raise ValueError("pad must be non-negative")

    res = np.array([])
    for seq in np_seqs:
        if pad:
            # ``res[:-0]`` would be empty and wipe everything collected so
            # far, so only trim when there is a real overlap to remove.
            res = res[:-pad]
        res = np.concatenate([res, seq])
    return res

In [4]:
# Disabled setup cell: the clone commands are inside a string literal and do
# not run; the cell only echoes the text. Remove the triple quotes to fetch
# the two repositories.
'''!git clone https://github.com/Nazar1997/Sparse_vector.git
!git clone https://github.com/vladislareon/z_dna'''

'!git clone https://github.com/Nazar1997/Sparse_vector.git\n!git clone https://github.com/vladislareon/z_dna'

# Select model and parameters

In [5]:
# User-tunable settings (the `#@param` annotations render as Colab form fields).
# `model` holds the checkpoint name at this point.
# NOTE: a later cell rebinds `model` to the loaded network, so this cell must
# run before that one.
model = 'HG kouzine' #@param ["HG chipseq", "HG kouzine", "MM chipseq", "MM kouzine"]
# Minimum class-1 probability for a position to count as predicted positive.
model_confidence_threshold = 0.5 #@param {type:"number"}
# Length threshold applied to predicted regions by the (disabled) prediction cell.
minimum_sequence_length = 10 #@param {type:"integer"}

In [6]:

# Download id of each checkpoint, keyed by model name.
# NOTE(review): the parameter form offers "MM chipseq", but the only matching
# entry here is named "MM curax". Whether both refer to the same checkpoint
# cannot be told from this notebook, so the unmatched name is rejected with a
# clear error instead of silently leaving `model_id` undefined.
MODEL_IDS = {
    'HG chipseq': '1VAsp8I904y_J0PUhAQqpSlCn1IqfG0FB',
    'HG kouzine': '1dAeAt5Gu2cadwDhbc7OnenUgDLHlUvkx',
    'MM curax': '1W6GEgHNoitlB-xXJbLJ_jDW4BF35W1Sd',
    'MM kouzine': '1dXpQFmheClKXIEoqcZ7kgCwx6hzVCv3H',
}

if model not in MODEL_IDS:
    raise ValueError(
        f"Unknown model {model!r}; expected one of {sorted(MODEL_IDS)}"
    )
model_id = MODEL_IDS[model]


In [7]:
# Disabled setup cell: both blocks below are string literals, so no download
# or file move happens. Only the last literal is echoed as the cell output.
# First block: download commands for the selected checkpoint and four more files.
'''!gdown $model_id
!gdown 10sF8Ywktd96HqAL0CwvlZZUUGj05CGk5
!gdown 16bT7HDv71aRwyh3gBUbKwign1mtyLD2d
!gdown 1EE9goZ2JRSD8UTx501q71lGCk-CK3kqG
!gdown 1gZZdtAoDnDiLQqjQfGyuwt268Pe5sXW0'''


# Second block: commands that gather the model files into the 6-new-12w-0 directory.
'''!mkdir 6-new-12w-0
!mv pytorch_model.bin 6-new-12w-0/
!mv config.json 6-new-12w-0/
!mv special_tokens_map.json 6-new-12w-0/
!mv tokenizer_config.json 6-new-12w-0/
!mv vocab.txt 6-new-12w-0/'''

'!mkdir 6-new-12w-0\n!mv pytorch_model.bin 6-new-12w-0/\n!mv config.json 6-new-12w-0/\n!mv special_tokens_map.json 6-new-12w-0/\n!mv tokenizer_config.json 6-new-12w-0/\n!mv vocab.txt 6-new-12w-0/'

In [8]:
from transformers import AutoConfig
from typing import Optional
class ParallelBert(BertForTokenClassification):
    """Token-classification BERT whose forward pass runs through ``nn.DataParallel``.

    The instance is initialised from the checkpoint's config, and a second
    copy of the network, loaded with its weights and wrapped in
    ``nn.DataParallel``, is stored in ``self.model``. ``forward`` delegates
    to that wrapped copy.

    Args:
        model_dir: Directory holding the checkpoint and its config. Defaults
            to the path the notebook originally hard-coded.
    """

    def __init__(self, model_dir="/6-new-12w-0/"):
        config = AutoConfig.from_pretrained(model_dir)
        # The parent constructor stores the config on the instance.
        super().__init__(config)

        self.model = nn.DataParallel(BertForTokenClassification.from_pretrained(model_dir))

    def forward(self, inp, attn_mask: Optional[torch.Tensor] = None):
        """Run the wrapped model.

        Args:
            inp: Input passed as the first positional argument of the model.
            attn_mask: Optional mask passed as the second positional argument.
                Previously the *default value* was the typing object
                ``Optional[torch.Tensor]`` itself, which was forwarded to the
                model whenever the caller omitted the mask.

        Returns:
            Whatever the wrapped model returns.
        """
        return self.model(inp, attn_mask)

In [9]:
# Tokenizer and fine-tuned token classifier are both read from the same local
# checkpoint directory.
tokenizer = BertTokenizer.from_pretrained('/home/arulybin/6-new-12w-0/')
# From here on `model` is the network (not the model-name string chosen
# earlier); it is moved to the GPU in the same expression.
model = BertForTokenClassification.from_pretrained('/home/arulybin/6-new-12w-0/').cuda()

# Predict and save raw outputs

In [10]:
# Disabled cell: the whole FASTA prediction loop is a string literal, so none
# of it executes; the cell only echoes the text.
# Notes for anyone re-enabling it:
#  - `uploaded` is not defined anywhere in the visible part of this notebook.
#  - the `break` right after the first `print(input_ids...)` makes the rest of
#    the inner loop body unreachable, so `preds` would stay empty.
'''
from transformers_interpret import MultiLabelClassificationExplainer
out = []
ans = []
mce = MultiLabelClassificationExplainer(model, tokenizer)
for key in uploaded.keys():
    print(key)
    out.append(key)
    result_dict = {}
    for seq_record in SeqIO.parse(StringIO(BytesIO(uploaded[key]).read().decode('UTF-8')), 'fasta'):
        kmer_seq = seq2kmer(str(seq_record.seq).upper(), 6)
        seq_pieces = split_seq(kmer_seq)
        #print(seq_record.name)
        out.append(seq_record.name)
        with torch.no_grad():
            preds = []
            for seq_piece in seq_pieces:
                input_ids = torch.LongTensor(tokenizer.encode(' '.join(seq_piece), add_special_tokens=False))
                print(input_ids.cuda().unsqueeze(0))
                break
                outputs = torch.softmax(model(input_ids.cuda().unsqueeze(0))[-1],axis = -1)[0,:,1]
                preds.append(outputs.cpu().numpy())
                print(mce(' '.join(seq_piece)))
        result_dict[seq_record.name] = stitch_np_seq(preds)

        labeled, max_label = scipy.ndimage.label(result_dict[seq_record.name]>model_confidence_threshold)
        ans.append(np.any(labeled))
        #print('  start     end')
        out.append('  start     end')
        for label in range(1, max_label+1):
            candidate = np.where(labeled == label)[0]
            candidate_length = candidate.shape[0]
            if candidate_length>minimum_sequence_length:
                #print('{:8}'.format(candidate[0]), '{:8}'.format(candidate[-1]))
                out.append('{:8}'.format(candidate[0]) + '{:8}'.format(candidate[-1]))

    with open(key + '.preds.pkl',"wb") as fh:
      pickle.dump(result_dict, fh)
    print()

with open('text_predictions.txt',"w") as fh:
    for item in out:
        fh.write("%s\n" % item)
'''

'\nfrom transformers_interpret import MultiLabelClassificationExplainer\nout = []\nans = []\nmce = MultiLabelClassificationExplainer(model, tokenizer)\nfor key in uploaded.keys():\n    print(key)\n    out.append(key)\n    result_dict = {}\n    for seq_record in SeqIO.parse(StringIO(BytesIO(uploaded[key]).read().decode(\'UTF-8\')), \'fasta\'):\n        kmer_seq = seq2kmer(str(seq_record.seq).upper(), 6)\n        seq_pieces = split_seq(kmer_seq)\n        #print(seq_record.name)\n        out.append(seq_record.name)\n        with torch.no_grad():\n            preds = []\n            for seq_piece in seq_pieces:\n                input_ids = torch.LongTensor(tokenizer.encode(\' \'.join(seq_piece), add_special_tokens=False))\n                print(input_ids.cuda().unsqueeze(0))\n                break\n                outputs = torch.softmax(model(input_ids.cuda().unsqueeze(0))[-1],axis = -1)[0,:,1]\n                preds.append(outputs.cpu().numpy())\n                print(mce(\' \'.join(seq_

In [11]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score, f1_score
from IPython.display import clear_output

from joblib import load
from tqdm import trange
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import StratifiedKFold
from torch.utils import data

# Label tracks, indexed by chromosome name in later cells (e.g. ZDNA['chr1']).
# NOTE(review): joblib's `load` unpickles the file, so it should only be used
# on a trusted file; the absolute path is specific to the author's machine.
ZDNA = load("/home/arulybin/ZDNA_cousine.pkl")

In [12]:
class Dataset(data.Dataset):
    """Map-style dataset returning ``(DNA string, label slice)`` per interval.

    Args:
        chroms: Chromosome names; stored but not used by the class itself.
        dna_source: Mapping ``chrom -> sequence string``.
        labels_source: Mapping ``chrom -> sliceable label track``.
        intervals: Sequence of ``(chrom, begin, end)`` records. ``begin`` and
            ``end`` may be ints or numeric strings; they are converted with
            ``int``.
        lrp_feat: Feature-column selection. Only an empty selection is
            supported because samples are raw strings (see ``__getitem__``).
        kmer_size: The DNA window is extended by ``kmer_size - 1`` characters
            past ``end`` (the original code hard-coded ``+ 5``) — presumably
            so that k-mer tokenisation yields one token per label; confirm.
    """

    def __init__(self, chroms,
                 dna_source,
                 labels_source, intervals, lrp_feat=None, kmer_size=6):
        self.chroms = chroms
        self.dna_source = dna_source
        self.labels_source = labels_source
        self.intervals = intervals
        # Kept for compatibility with code that reads ``.le``; no longer used
        # when building items.
        self.le = LabelBinarizer().fit(np.array([["A"], ["C"], ["T"], ["G"]]))
        # ``None`` replaces the former mutable ``[]`` default; the attribute is
        # still an empty list when nothing is passed.
        self.lrp_feat = [] if lrp_feat is None else lrp_feat
        self.kmer_size = kmer_size

    def __len__(self):
        """Number of intervals."""
        return len(self.intervals)

    def __getitem__(self, index):
        """Return ``(X, y)`` for interval ``index``.

        ``X`` is the upper-cased DNA string for ``[begin, end + kmer_size - 1)``
        and ``y`` is the label slice for ``[begin, end)``.

        Raises:
            TypeError: If ``lrp_feat`` is non-empty. Column selection cannot
                be applied to a string sample (the original code failed with
                a ``TypeError`` at the same point).
        """
        interval = self.intervals[index]
        chrom = interval[0]
        begin = int(interval[1])
        end = int(interval[2])

        if len(self.lrp_feat) > 0:
            raise TypeError(
                "lrp_feat selection is not supported: samples are DNA strings, "
                "not feature matrices"
            )

        # The unused one-hot encoding of the window was dropped; it was
        # computed on every access and then discarded.
        X = self.dna_source[chrom][begin:end + self.kmer_size - 1].upper()
        # Slice with the converted integer bounds so that string-typed
        # interval records work for the labels as well as for the DNA.
        y = self.labels_source[chrom][begin:end]

        return (X, y)

In [13]:
# Window size (in positions) of every candidate interval.
width = 100

np.random.seed(10)

chrom_names = [f'chr{i}' for i in [*range(1, 23), 'X', 'Y', 'M']]
# ints_in: windows containing at least one positive label; ints_out: the rest.
ints_in, ints_out = [], []

for chrm in chrom_names:
    track = ZDNA[chrm]
    # Non-overlapping windows; the range stops `width` before the end of the track.
    for st in trange(0, track.shape - width, width):
        stop = min(st + width, track.shape)
        target = ints_in if track[st:stop].any() else ints_out
        target.append([chrm, st, stop])

ints_in = np.array(ints_in)
# Keep a random subset of the negative windows, three per positive window.
chosen = np.random.choice(range(len(ints_out)), size=len(ints_in) * 3, replace=False)
ints_out = np.array(ints_out)[chosen]

100%|██████████| 2489564/2489564 [00:34<00:00, 72681.40it/s]
100%|██████████| 2421935/2421935 [00:32<00:00, 74253.51it/s]
100%|██████████| 1982955/1982955 [00:26<00:00, 75236.28it/s]
100%|██████████| 1902145/1902145 [00:25<00:00, 75974.82it/s]
100%|██████████| 1815382/1815382 [00:23<00:00, 75641.65it/s]
100%|██████████| 1708059/1708059 [00:23<00:00, 73702.39it/s]
100%|██████████| 1593459/1593459 [00:20<00:00, 78133.87it/s]
100%|██████████| 1451386/1451386 [00:20<00:00, 70240.04it/s]
100%|██████████| 1383947/1383947 [00:17<00:00, 79610.24it/s]
100%|██████████| 1337974/1337974 [00:17<00:00, 78193.62it/s]
100%|██████████| 1350866/1350866 [00:19<00:00, 70167.16it/s]
100%|██████████| 1332753/1332753 [00:16<00:00, 79237.60it/s]
100%|██████████| 1143643/1143643 [00:14<00:00, 78210.36it/s]
100%|██████████| 1070437/1070437 [00:13<00:00, 80370.16it/s]
100%|██████████| 1019911/1019911 [00:16<00:00, 62514.64it/s]
100%|██████████| 903383/903383 [00:11<00:00, 76534.17it/s]
100%|██████████| 832574/83

In [14]:
np.random.seed(42)
# Positive windows first, then the sampled negatives; coordinates back to int.
equalized = [
    [chrom, int(start), int(stop)]
    for chrom, start, stop in np.vstack((ints_in, ints_out))
]

# Stratification key: "<flag>_<chromosome>", where flag is 1 for the first 400 rows.
# NOTE(review): 400 is a fixed constant rather than len(ints_in) — confirm it is
# intended; changing it would change which intervals land in the test fold.
strata = [f"{int(i < 400)}_{elem[0]}" for i, elem in enumerate(equalized)]

# First fold of the default StratifiedKFold split.
train_inds, test_inds = next(StratifiedKFold().split(equalized, strata))

train_intervals = [equalized[i] for i in train_inds]
test_intervals = [equalized[i] for i in test_inds]

In [15]:
def chrom_reader(chrom, dna_dir='/home/arulybin/z_dna/hg38_dna/'):
    """Load and concatenate all stored pieces of one chromosome.

    Files in ``dna_dir`` whose name contains ``"<chrom>_"`` are loaded in
    sorted (lexicographic) name order and joined into one string.

    Args:
        chrom: Chromosome name, e.g. ``"chr1"``.
        dna_dir: Directory holding the pieces. It is used both for listing
            and for loading; previously the files were listed from this
            absolute path but loaded from the relative ``z_dna/hg38_dna/``,
            which only worked from one specific working directory.

    Returns:
        The concatenated sequence string (empty if no file matches).
    """
    # NOTE(review): lexicographic order is only correct if the piece names
    # sort in genomic order — confirm against the file naming scheme.
    files = sorted(name for name in os.listdir(dna_dir) if f"{chrom}_" in name)
    return ''.join(load(os.path.join(dna_dir, name)) for name in files)
# Full sequence of every chromosome in `chrom_names`, keyed by chromosome name.
DNA = {chrom:chrom_reader(chrom) for chrom in tqdm(chrom_names)}

  0%|          | 0/25 [00:00<?, ?it/s]

In [16]:
np.random.seed(42)
# Loader settings; not consumed anywhere in the visible cells.
params = dict(batch_size=1, num_workers=4, shuffle=True)

# Train and test datasets share the sequence and label sources and differ
# only in their interval lists.
train_dataset = Dataset(
    chroms=chrom_names,
    dna_source=DNA,
    labels_source=ZDNA,
    intervals=train_intervals,
    lrp_feat=[],
)

test_dataset = Dataset(
    chroms=chrom_names,
    dna_source=DNA,
    labels_source=ZDNA,
    intervals=test_intervals,
    lrp_feat=[],
)

len(test_dataset)

36161

In [None]:
from transformers_interpret import TokenClassificationExplainer
from IPython.display import clear_output
from time import time

# Explainer built on the loaded model and tokenizer; called once per sequence below.
mce = TokenClassificationExplainer(model, tokenizer)

# Confusion-matrix counts accumulated over every evaluated position.
TP, FP, TN, FN = 0, 0, 0, 0
# Upper bound on the number of test sequences processed before the loop stops.
num_seqs = 100000

# res_dict: summed score per key; counter: number of contributions per key.
res_dict = {}
counter = {}
ts = time()  # start time for the progress message
for i in range(len(test_dataset)):
    sequence, labels = test_dataset[i]
    # Overlapping 6-mers of the sequence.
    seq_piece = seq2kmer(sequence, 6) 
    
    with torch.no_grad():
        input_ids = torch.LongTensor(tokenizer.encode(seq_piece, add_special_tokens=False))
        # Softmax over the last axis, class index 1, thresholded into 0/1 predictions on the CPU.
        outputs = (torch.softmax(model(input_ids.cuda().unsqueeze(0))[-1],axis = -1)[0,:,1] > model_confidence_threshold).to(torch.int32).cpu()
        
        # Positions predicted 1 whose label is also 1.
        tp_res = ((outputs == 1) & (labels == 1))
        # Used as a set: the k-mer strings found at true-positive positions.
        tp_vals = {}
        for res_i, res in enumerate(tp_res):
            if res:
                tp_vals[seq_piece[res_i]] = 1
                
        # NOTE(review): the explainer runs inside torch.no_grad(); presumably it
        # manages gradient tracking itself — confirm.
        interpret_res = mce(" ".join(seq_piece), ignored_labels=["LABEL_0"])
        # Only result keys that match a true-positive k-mer contribute.
        # NOTE(review): assumes each entry's "attribution_scores" iterates as
        # (name, score) pairs — confirm against the installed transformers_interpret.
        for k in interpret_res.keys():
            if k in tp_vals:
                for interpr_pair in interpret_res[k]["attribution_scores"]:
                    res_dict[interpr_pair[0]] = res_dict.get(interpr_pair[0], 0) + interpr_pair[1]
                    counter[interpr_pair[0]] = counter.get(interpr_pair[0], 0) + 1
        
        TP += ((outputs == 1) & (labels == 1)).sum().item()
        FP += ((outputs == 1) & (labels == 0)).sum().item()
        TN += ((outputs == 0) & (labels == 0)).sum().item()
        FN += ((outputs == 0) & (labels == 1)).sum().item()
    
        # `% 1` is always 0, so the progress line is refreshed on every iteration.
        if (i + 1) % 1 == 0:
            clear_output(wait=True)
            print(f"Обработано {i + 1} последовательностей за {time() - ts} секунд")
        num_seqs -= 1
        if num_seqs <= 0:
            break

# Position-level metrics from the accumulated confusion counts.
# A metric whose denominator is zero (no predicted positives, no labelled
# positives, or precision + recall == 0) is reported as 0.0 instead of raising
# ZeroDivisionError.
precision = TP/(TP + FP) if (TP + FP) > 0 else 0.0
recall = TP/(TP + FN) if (TP + FN) > 0 else 0.0
F1_score = 2*precision*recall/(precision + recall) if (precision + recall) > 0 else 0.0

# Turn each summed score into a mean by dividing by its contribution count.
# NOTE: this rewrites res_dict in place, so re-running the cell divides again.
for key in res_dict:
    res_dict[key] /= counter[key]

print(f"precision: {precision}, recall: {recall}, F_score: {F1_score}")


Обработано 4340 последовательностей за 14265.589334964752 секунд


In [None]:
# Disabled cell: a SHAP-based variant of the evaluation loop, kept as a string
# literal so that none of it executes.
'''import shap
from time import time
from IPython.display import clear_output

def f(x):
    out_tensors = []
    with torch.no_grad():
        for seq in x:
            input_ids = torch.LongTensor(tokenizer.encode(seq, add_special_tokens=False)).cuda()
            outputs = torch.softmax(model(input_ids.cuda().unsqueeze(0))[-1],axis = -1)[:, :, -1].cuda()
            out_tensors.append(outputs)
    out = torch.cat(out_tensors, dim=0)
    return out

TP, FP, TN, FN = 0, 0, 0, 0
num_seqs = 15000

spec_tokens = ["[CLS]", "[SEP]"]
res_dict = {}
counter = {}
ts = time()
explainer = shap.Explainer(f, tokenizer)

for i in range(len(test_dataset)):
    sequence, labels = test_dataset[i]
    seq_piece = seq2kmer(sequence, 6) 
    
    with torch.no_grad():
        input_ids = torch.LongTensor(tokenizer.encode(seq_piece, add_special_tokens=False))
        outputs = (torch.softmax(model(input_ids.cuda().unsqueeze(0))[-1],axis = -1)[0,:,1] > model_confidence_threshold).to(torch.int32).cpu()
        shap_values = explainer([" ".join(seq_piece)])
        tp_res = ((outputs == 1) & (labels == 1))
        
        tp_shap_values_sum = shap_values.values.squeeze(0)[:, tp_res].sum(axis = 1)
        data_tokens = shap_values.data[0]
        spec_counter = 0
        for j, el in enumerate(data_tokens):
            if len(el) > 0 and el[0] == " ":
                data_tokens[j] = el[1:]
            elif len(el) == 0:
                data_tokens[j] = spec_tokens[spec_counter]
                spec_counter += 1
        
        for j, seq in enumerate(data_tokens):
            res_dict[seq] = res_dict.get(seq, 0) + tp_shap_values_sum[j]
            counter[seq] = counter.get(seq, 0) + tp_res.sum().item()
        
        TP += ((outputs == 1) & (labels == 1)).sum().item()
        FP += ((outputs == 1) & (labels == 0)).sum().item()
        TN += ((outputs == 0) & (labels == 0)).sum().item()
        FN += ((outputs == 0) & (labels == 1)).sum().item()
    
        if (i + 1) % 1 == 0:
            clear_output(wait=True)
            print(f"Обработано {i + 1} последовательностей за {time() - ts} секунд")
            
        num_seqs -= 1
        if num_seqs <= 0:
            break

precision = TP/(TP + FP)
recall = TP/(TP + FN)
F1_score = 2*precision*recall/(precision + recall)

for key in res_dict:
    res_dict[key] /= counter[key]

print(f"precision: {precision}, recall: {recall}, F_score: {F1_score}")'''

In [None]:
# One row per key of res_dict, with its score in a single "impact" column.
df_interpret_res = pd.DataFrame.from_dict(res_dict, orient="index")
df_interpret_res = df_interpret_res.rename(columns={0: "impact"})

# Highest impact first, then written to CSV.
df_interpret_res_sorted = df_interpret_res.sort_values(by="impact", ascending=False)
df_interpret_res_sorted.to_csv("integrated_gradients_interpret_res.csv")