In [1]:
import torch
from torch.utils import data
import random
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split, StratifiedKFold
from collections import Counter
import pandas as pd
import numpy as np
import scipy
from tqdm import trange
from tqdm import tqdm
from datetime import datetime
import sys
import os
import seaborn as sns
from matplotlib import pyplot as plt
from joblib import Parallel, delayed, dump, load
from matplotlib import pyplot as plt
from sparse_vector.sparse_vector import SparseVector
from scipy.signal import convolve2d, convolve
import time
from torch import nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score
from IPython.display import clear_output

import torch
from transformers import BertModel, BertConfig, PreTrainedTokenizer, BasicTokenizer, BertForTokenClassification
import collections

from transformers import utils
from bertviz import model_view, head_view

from torch.utils.data import DataLoader
import sklearn
from sklearn.metrics import accuracy_score
from torch.nn import CrossEntropyLoss

import warnings
warnings.filterwarnings("ignore")

  import pandas.util.testing as tm


In [2]:
from dna_tokenizer import DNATokenizer, seq2kmer

In [3]:
class Dataset(data.Dataset):
    """Map-style dataset that turns genomic intervals into model inputs.

    For the interval at a given index the dataset returns hand-crafted
    per-base features, the per-base labels, the raw bases, the k-mer token
    ids of the window and the interval coordinates (see ``__getitem__``).

    Parameters
    ----------
    chroms : stored as-is; not read by the methods of this class.
    features : iterable of keys into ``features_source``.
    dna_source : mapping ``chrom -> sequence`` supporting slicing and ``.upper()``.
    features_source : mapping ``feature -> {chrom -> sliceable track}``.
    labels_source : mapping ``chrom -> sliceable per-base labels``.
    intervals : indexable collection of ``(chrom, begin, end)`` records;
        ``begin`` and ``end`` are coerced with ``int()``.
    tokenizer : object exposing ``encode_plus`` (a ``DNATokenizer`` in this
        notebook) used to turn the k-mer string into token ids.
    """

    # Length of the k-mers handed to the tokenizer.
    KMER_SIZE = 6
    # Upper bound on the number of tokens kept per interval.
    MAX_TOKENS = 512
    # Length of every repeat kernel used for motif matching.
    KERNEL_LEN = 11

    def __init__(self, chroms, features,
                 dna_source, features_source,
                 labels_source, intervals, tokenizer):

        self.chroms = chroms
        self.features = features
        self.dna_source = dna_source
        self.features_source = features_source
        self.labels_source = labels_source
        self.intervals = intervals
        self.tokenizer = tokenizer

        # One-hot encoder over the four bases.
        self.le = LabelBinarizer().fit(np.array([["A"], ["C"], ["T"], ["G"]]))

        # Dinucleotide -> 0/1 lookup tables, one per table name.
        # NOTE(review): the table name suggests Z-Hunt anti/syn states - confirm.
        self.configs = {
            'ZHUNT_AS': {
                'CG': 0, 'GC': 1, 'CA': 0, 'AC': 1,
                'TG': 0, 'GT': 1, 'TA': 1, 'AT': 1,
                'CC': 0, 'GG': 0, 'CT': 1, 'TC': 1,
                'GA': 1, 'AG': 1, 'AA': 1, 'TT': 1},
        }

        # Repeat units of length 1-4; each is tiled and cut to KERNEL_LEN
        # bases to form one matching kernel.
        repeat_units = (
            ["A", "C", "T", "G"] +
            ['AC', 'AT', 'AG', 'CT', 'CG', 'GT'] +
            ['AAC', 'ACC', 'AAT', 'ATT', 'AAG', 'AGG',
             'CCA', 'CAA', 'CCT', 'CTT', 'CCG', 'CGG',
             'TTA', 'TAA', 'TTC', 'TCC', 'TTG', 'TGG',
             'GGA', 'GAA', 'GGC', 'GCC', 'GGT', 'GTT'] +
            ['AAAC', 'AAAT', 'AAAG', 'CCCA', 'CCCT', 'CCCG',
             'TTTA', 'TTTC', 'TTTG', 'GGGA', 'GGGC', 'GGGT'])
        kernels = np.array([
            self.le.transform(list(unit * self.KERNEL_LEN)[:self.KERNEL_LEN])
            for unit in repeat_units
        ])
        # Flip both the position and the channel axis so that the convolution
        # in __getitem__ acts as a cross-correlation with the un-flipped kernel.
        self.tars = kernels[:, ::-1, ::-1]
        # purine-pyrimidine: extra kernel built as the sum of the 'AC' and
        # 'GT' kernels (entries 4 and 9 of repeat_units).
        self.tars = np.concatenate((self.tars, np.array([self.tars[4] + self.tars[9]])))

    def __len__(self):
        """Return the number of intervals."""
        return len(self.intervals)

    def __getitem__(self, index):
        """Build the sample for ``self.intervals[index]``.

        Returns
        -------
        tuple
            ``(X, y, bases, token_ids, (chrom, begin, end))`` where

            * ``X`` is a float tensor: repeat-motif flags, dinucleotide table
              values and the feature tracks divided by 1000 stacked
              column-wise when ``self.features`` is non-empty, otherwise the
              one-hot encoded bases;
            * ``y`` is the label slice as a long tensor;
            * ``bases`` is the upper-cased window as a list of characters;
            * ``token_ids`` is a long tensor of tokenizer ids for the k-mers
              of the window (at most ``MAX_TOKENS``).
        """
        interval = self.intervals[index]
        chrom = interval[0]
        begin = int(interval[1])
        end = int(interval[2])

        ll = list(self.dna_source[chrom][begin:end].upper())
        # Use the coerced coordinates here too, so labels are always sliced
        # with the same integers as the sequence and the feature tracks.
        y = self.labels_source[chrom][begin:end]

        # DNA PART
        # One-hot bases with a leading axis so they can be convolved against
        # the whole kernel stack at once.
        dna_OHE = self.le.transform(ll)[None]

        # The [5:-5] crop drops the half-kernel padding of the 'full'
        # convolution, channel 3 is the offset at which the four base channels
        # line up, and dividing by the kernel length makes a perfect match 1.
        pad = self.KERNEL_LEN // 2
        match = convolve(dna_OHE, self.tars)[:, pad:-pad, 3].T / self.KERNEL_LEN
        res = pd.DataFrame(match)
        # Flag a position when any of the last 5 positions is a perfect match.
        res = (res.rolling(5, min_periods=1).max().values == 1).astype(int)

        # ZHUNT PART
        zhunts = []
        for key in self.configs:
            table = self.configs[key]
            vec = np.array(ll)
            # NOTE(review): the lookup key is built as next base + current
            # base (vec[1:] + vec[:-1]), i.e. the dinucleotide read in
            # reverse - confirm this ordering is intended.
            vec = np.vectorize(lambda x: table.get(x, 0))(
                np.char.add(vec[1:], vec[:-1]))
            # The last base has no successor; pad with 0 to keep the length.
            zhunts.append(np.concatenate([vec, [0]]))

        # FEATURES PART
        feature_matr = []
        for feature in self.features:
            source = self.features_source[feature]
            feature_matr.append(source[chrom][begin:end])

        # UNION
        if len(feature_matr) > 0:
            X = np.hstack((
                res,
                np.array(zhunts).T,
                np.array(feature_matr).T / 1000)).astype(np.float32)
        else:
            X = dna_OHE.astype(np.float32)

        # K-mer part
        # Read KMER_SIZE - 1 extra bases so that every position of the window
        # starts a complete k-mer.
        k_mers = seq2kmer(
            self.dna_source[chrom][begin:end + self.KMER_SIZE - 1].upper(),
            self.KMER_SIZE)
        # truncation=True makes the max_length cut-off explicit; without it
        # the tokenizer warns and falls back to the same strategy.
        encoded_k_mers = self.tokenizer.encode_plus(
            k_mers,
            add_special_tokens=False,
            max_length=self.MAX_TOKENS,
            truncation=True)["input_ids"]

        return (torch.Tensor(X), torch.Tensor(y).long(), ll,
                torch.LongTensor(encoded_k_mers), (chrom, begin, end))

In [45]:
# Per-example results pooled over all folds. Only test windows that contain at
# least one positive label are kept. `bps` is created here but not filled in
# this cell.
attentions, preds, targets, seqs, bps = [], [], [], [], []

for FOLD in range(5):
    # Each pickle holds the (train, test) datasets of one fold; the matching
    # fine-tuned checkpoint is loaded with attention outputs enabled.
    train_dataset, test_dataset = load(f'ds_w_seq_hg_fold{FOLD}.pkl')
    model = BertForTokenClassification.from_pretrained(f'dnabert_hg_fold_{FOLD}', output_attentions=True)

    for example in tqdm(test_dataset):
        features, target, seq, input_ids, interval = example

        # Skip windows without any positive position.
        if not (target.numpy().sum() > 0):
            continue

        with torch.no_grad():
            # Add a batch axis of size one for the forward pass.
            outputs = model(input_ids.unsqueeze(0))

            # Second-to-last output: logits -> probability of class 1 per token.
            pred = torch.softmax(outputs[-2], axis=-1)[0, :, 1]
            # Last output: the attention tensors, kept as returned by the model.
            attention = outputs[-1]

        attentions.append(attention)
        preds.append(pred)
        targets.append(target)
        seqs.append(seq)
    
    

  0%|                                                                                                                         | 0/2642 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2642/2642 [04:22<00:00, 10.07it/s]
  0%|                                                                                                                         | 0/2642 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting

In [46]:
from collections import defaultdict

# kmer2pred: how many times each k-mer starts at a position labelled positive.
# kmer2att:  total last-layer attention each k-mer receives from positive
#            positions, summed over all heads.
kmer2pred = defaultdict(int)
kmer2att = defaultdict(float)

for attention, pred, target, seq in tqdm(zip(attentions, preds, targets, seqs),
                                         total=len(seqs)):
    kmer = seq2kmer(''.join(seq), 6).split(' ')

    # Last layer, single batch element -> (heads, query position, key position).
    att = attention[-1][0, :, :, :].numpy()

    # Only positions that start a complete k-mer are usable (512 - 5 for a
    # full-length window); also stay inside the label and attention ranges.
    n_positions = min(len(kmer), len(target), att.shape[-1])
    positive_idx = np.flatnonzero(target[:n_positions].numpy() > 0)
    if positive_idx.size == 0:
        continue

    for idx in positive_idx:
        kmer2pred[kmer[idx]] += 1

    # Attention paid by the positive query positions to every key position,
    # reduced over heads and queries in one step instead of nested Python
    # loops. float64 keeps the running totals accurate.
    received = att[:, positive_idx, :n_positions].sum(axis=(0, 1), dtype=np.float64)
    for att_idx, weight in enumerate(received):
        kmer2att[kmer[att_idx]] += float(weight)


629it [05:27,  1.92it/s]


In [47]:
# k-mers ordered by their count in kmer2pred, largest first.
sorted_pred = sorted(kmer2pred, key=kmer2pred.get, reverse=True)
# Attention totals are truncated to integers before ranking, so k-mers whose
# totals share an integer part tie and keep their insertion order.
kmer2att2 = {kmer_key: int(total) for kmer_key, total in kmer2att.items()}
sorted_att = sorted(kmer2att2, key=kmer2att2.get, reverse=True)

In [48]:
# Print, for the top attention-ranked k-mers: attention rank, k-mer, and the
# rank of the same k-mer in sorted_pred.
# A dict gives each rank in O(1) and lets a k-mer that is absent from
# sorted_pred print '-' instead of raising ValueError from list.index().
pred_rank = {ranked_kmer: rank for rank, ranked_kmer in enumerate(sorted_pred, start=1)}

# 102 rows: the same number the former `if idx > 100: break` loop printed.
for idx, kmer in enumerate(sorted_att[:102]):
    print(idx + 1, kmer, pred_rank.get(kmer, '-'))

1 GCGCGC 5
2 CGCGCG 6
3 TGTGTG 1
4 GTGTGT 2
5 ACACAC 4
6 CACACA 3
7 GCACAC 7
8 GTGCGC 19
9 TGCGCG 12
10 GCGCAC 17
11 GTGTGC 10
12 CGCGCA 14
13 CGCACA 8
14 GCGCGT 26
15 ACGCGC 27
16 CGCGTG 32
17 TGTGCG 18
18 ACACGC 15
19 GCGTGC 41
20 CACACG 11
21 CACGCG 31
22 TGCGTG 35
23 CGTGTG 25
24 GCACGC 39
25 GCGTGT 28
26 GGCGGC 23
27 ACGCAC 24
28 TTTTTT 22
29 CACGCA 21
30 AAAAAA 9
31 GCGGCG 20
32 GCCGCC 29
33 GAGAGA 13
34 GTGCGT 36
35 CGTGCG 44
36 AGAGAG 16
37 CTCTCT 46
38 CGCCGC 33
39 CGGCGG 34
40 CCTCCC 42
41 TCTCTC 37
42 CGCACG 51
43 CCGCCC 30
44 GGCGGG 43
45 GGGCGG 45
46 GGGAGG 40
47 GCCCGC 78
48 CTCCTC 70
49 CTCCCC 60
50 TCCTCC 76
51 GCGGGG 50
52 GCCTCC 53
53 CCCGCC 38
54 CGCCCC 52
55 CCCTCC 59
56 GGCCGC 56
57 CGGCGC 72
58 GCGCCC 63
59 GCGGCC 55
60 CCCCGC 58
61 GGCTGC 61
62 CCTCCT 57
63 GGGGAG 85
64 CCGCCG 48
65 GCCCCC 107
66 GCCCGG 49
67 CCCGGC 47
68 GCCCCG 66
69 CGGGGC 79
70 GGGGCG 84
71 GCGCGG 75
72 GCCGCG 90
73 CGCGGC 65
74 GGCGCG 87
75 GAGGAG 108
76 GGGCGC 95
77 GGAGGG 86
78 GGCCGG 64
79

In [22]:
# Chromosome names chr1..chr22 followed by chrX and chrY.
chroms = [f'chr{label}' for label in [*range(1, 23), 'X', 'Y']]
# Pickled objects loaded with joblib; ZDNA is indexed by chromosome name below.
ZDNA = load('../data/hg19_zdna/sparse/ZDNA_2016.pkl')
black_list = load('../data/hg19_zdna/sparse/blacklist_hg19.pkl')

In [23]:
# One entry per chromosome, in the order of `chroms`: the stored predictions
# and the ZDNA values cast to int.
all_pred, all_true = [], []
for chrom in tqdm(chroms):
    chrom_pred = load(f"/gim/lv01/dumerenkov/zdna_data/new_mod_hg_res_{chrom}")
    chrom_true = ZDNA[chrom][:].astype(int)
    all_pred.append(chrom_pred)
    all_true.append(chrom_true)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [02:50<00:00,  7.12s/it]


In [26]:
# ROC AUC over all chromosomes concatenated into one vector.
roc_auc_score(
    y_true=np.concatenate(all_true),
    y_score=np.concatenate(all_pred),
)

0.9501730035376861

In [27]:
# Per-class precision/recall/F1 with predictions thresholded at 0.5.
print(
    sklearn.metrics.classification_report(
        y_true=np.concatenate(all_true),
        y_pred=np.concatenate(all_pred) > 0.5,
        digits=4,
    )
)

              precision    recall  f1-score   support

           0     1.0000    0.9980    0.9990 3095541102
           1     0.0105    0.4816    0.0206    136310

    accuracy                         0.9980 3095677412
   macro avg     0.5052    0.7398    0.5098 3095677412
weighted avg     0.9999    0.9980    0.9989 3095677412

