In [1]:
import torch
from torch.utils import data
import random
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split, StratifiedKFold
from collections import Counter
import pandas as pd
import numpy as np
import scipy
from tqdm import trange
from tqdm import tqdm
from datetime import datetime
import sys
import os
import seaborn as sns
from matplotlib import pyplot as plt
from joblib import Parallel, delayed, dump, load
from matplotlib import pyplot as plt
from sparse_vector.sparse_vector import SparseVector
from scipy.signal import convolve2d, convolve
import time
from torch import nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score
from IPython.display import clear_output

import torch
from transformers import BertModel, BertConfig, PreTrainedTokenizer, BasicTokenizer, BertForTokenClassification
import collections

from transformers import utils
from bertviz import model_view, head_view

from torch.utils.data import DataLoader
import sklearn
from sklearn.metrics import accuracy_score
from torch.nn import CrossEntropyLoss

import gc
from collections import defaultdict

import warnings
warnings.filterwarnings("ignore")

  import pandas.util.testing as tm


In [2]:
from dna_tokenizer import DNATokenizer, seq2kmer

In [3]:
class Dataset(data.Dataset):
    def __init__(self, chroms, features, 
                 dna_source, features_source, 
                 labels_source, intervals, tokenizer):

        self.chroms = chroms
        self.features = features
        self.dna_source = dna_source
        self.features_source = features_source
        self.labels_source = labels_source
        self.intervals = intervals
        self.le = LabelBinarizer().fit(np.array([["A"], ["C"], ["T"], ["G"]]))
        self.configs = {
                        'ZHUNT_AS': {
                                'CG': 0, 'GC': 1, 'CA': 0, 'AC': 1, 
                                'TG': 0, 'GT': 1, 'TA': 1, 'AT': 1, 
                                'CC': 0, 'GG': 0, 'CT': 1, 'TC': 1, 
                                'GA': 1, 'AG': 1, 'AA': 1, 'TT': 1},
                       }
        seqs = (["A", "C", "T", "G"] + 
                ['AC', 'AT', 'AG', 'CT', 'CG', 'GT'] +
                ['AAC', 'ACC', 'AAT', 'ATT', 'AAG', 'AGG', 
                 'CCA', 'CAA', 'CCT', 'CTT', 'CCG', 'CGG', 
                 'TTA', 'TAA', 'TTC', 'TCC', 'TTG', 'TGG', 
                 'GGA', 'GAA', 'GGC', 'GCC', 'GGT', 'GTT'] +
                ['AAAC', 'AAAT', 'AAAG', 'CCCA', 'CCCT', 'CCCG',
                 'TTTA', 'TTTC', 'TTTG', 'GGGA', 'GGGC', 'GGGT'])
        self.tars = np.array([self.le.transform(list(i * 11)[:11]) for i in seqs])[:, ::-1, ::-1]
        # purine-pyrimidine
        self.tars = np.concatenate((self.tars, np.array([self.tars[4] + self.tars[9]])))
        self.tokenizer = tokenizer
        
        
    def __len__(self):
        return len(self.intervals)
    
    def __getitem__(self, index):
        interval = self.intervals[index]
        chrom = interval[0]
        begin = int(interval[1])
        end = int(interval[2])
        ll = list(self.dna_source[chrom][begin:end].upper())
        y = self.labels_source[interval[0]][interval[1]: interval[2]]        
        
        
#         DNA PART
        
        dna_OHE = self.le.transform(ll)[None]
        
        res = pd.DataFrame(convolve(dna_OHE, self.tars)[:, 5:-5, 3].T / 11)
        res = (res.rolling(5, min_periods=1).max().values == 1).astype(int)
        
        
#         ZHUNT PART
        zhunts = []
        for key in self.configs:
            vec = np.array(ll)
            vec = np.vectorize(lambda x:self.configs[key].get(x, 0))(
                                    np.char.add(vec[1:], vec[:-1]))
            zhunts.append(np.concatenate([vec, [0]]))
        
        
        # FEATURES PART
        feature_matr = []
        for feature in self.features:
            source = self.features_source[feature]
            feature_matr.append(source[chrom][begin:end])
        
        # UNION
        if len(feature_matr) > 0:
            X = np.hstack((
                           res,
                           np.array(zhunts).T, 
                           np.array(feature_matr).T/1000)).astype(np.float32)
#             X = (np.array(feature_matr).T/1000).astype(np.float32)
        else:
            X = dna_OHE.astype(np.float32)
        
        #K-mer part
        
        k_mers = seq2kmer(self.dna_source[chrom][begin:end+5].upper(),6)
        encoded_k_mers = self.tokenizer.encode_plus(k_mers, add_special_tokens=False, max_length=512)["input_ids"]

        return torch.Tensor(X), torch.Tensor(y).long(), ll, torch.LongTensor(encoded_k_mers), (chrom, begin, end)

In [12]:
from collections import defaultdict
kmer2pred = defaultdict(int)
kmer2att = defaultdict(float)

device = 1
for FOLD in range(5):
    gc.collect()
    attentions, preds, targets, seqs, bps = [],[],[], [], []
    train_dataset, test_dataset = load(f'/gim/lv01/dumerenkov/zdna_data/datasets/ds_w_seq_hg_fold{FOLD}_kouzine.pkl')
    model = BertForTokenClassification.from_pretrained(f'dnabert_hg_fold_{FOLD}_kouzine', output_attentions=True)
    model.to(device)
    for example in tqdm(test_dataset):
        features, target, seq, input_ids, interval = example
        if target.numpy().sum()>0:
            with torch.no_grad():
                input_ids = input_ids.to(device)
                outputs = model(input_ids.unsqueeze(0))
                
                raw_preds = outputs[-2].detach().to('cpu')
                attention = outputs[-1][-1].detach().to('cpu')
                
                pred = torch.softmax(raw_preds, axis = -1)[0,:,1]

                
            attentions.append(attention)
            preds.append(pred)
            targets.append(target)
            seqs.append(seq)
            
    for attention, pred, target, seq in tqdm(zip(attentions, preds, targets, seqs)):
        kmer = seq2kmer(''.join(seq), 6).split(' ')
        #print(kmer)
        att = attention[0,:,:,:]
    
        for idx in range(512-5):
            if target[idx]>0:
                kmer2pred[kmer[idx]]+=1
            
                for head in range(12):
                    c_att = att[head,idx,:].numpy()
                    for att_idx in range(512-5):
                        kmer2att[kmer[att_idx]]+=c_att[att_idx]
    
    

  0%|                                                                                                                        | 0/17847 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17847/17847 [10:40<00:00, 27.87it/s]
5949it [06:19, 15.68it/s]
  0%|                                                                                                                        | 0/17847 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples

In [13]:
sorted_pred = [t[0] for t in sorted(kmer2pred.items(), key=lambda item: -item[1])]
kmer2att2 = {k:int(kmer2att[k]) for k in kmer2att}
sorted_att  = [t[0] for t in sorted(kmer2att2.items(), key=lambda item: -item[1])]

In [14]:
for idx, kmer in enumerate(sorted_att):
    print(idx+1, kmer, sorted_pred.index(kmer)+1)
    if idx>100:
        break

1 GCGCGC 1
2 GTGTGT 5
3 CGCGCG 2
4 ACACAC 6
5 TGTGTG 3
6 GCGCGG 7
7 CACACA 4
8 CCGCGC 10
9 GGGCGC 11
10 GCGCCC 12
11 GTGCGC 17
12 GGCGCG 9
13 GTGTGC 14
14 GCGCAC 19
15 GCACAC 15
16 GCCCGC 20
17 GCGGGC 16
18 CGCGCC 8
19 GCGTGC 25
20 GCACGC 26
21 CCCGCG 18
22 CGCGGG 13
23 GTGCAC 31
24 CGGGCG 24
25 ACGCGC 35
26 GCGCGT 34
27 CGCCCG 23
28 TGCGCG 22
29 CGCGCA 21
30 CGTGTG 27
31 GCGTGT 39
32 GTGCGT 47
33 CACGCG 28
34 ACACGC 44
35 CGCGTG 29
36 CGTGCG 32
37 ACGCAC 53
38 CACACG 30
39 CGCACG 33
40 GCGCAG 48
41 GCGGGG 46
42 TGTGCG 36
43 CTGCGC 52
44 TGCGTG 37
45 TGCACA 40
46 CGCACA 38
47 GCGCCG 51
48 GCCGCG 50
49 TGTGCA 42
50 CCCCGC 59
51 CGGCGC 55
52 CACGCA 41
53 CACGTG 43
54 GGGCGG 70
55 CGCGGC 45
56 GGGGCG 61
57 AGGCGC 81
58 ACGTGC 79
59 GCGGCG 80
60 CGCCCC 49
61 CCGCCC 71
62 GCGCCT 69
63 GCACGT 73
64 AGCGCG 58
65 ATGCAC 101
66 TGCACG 57
67 GCGTGG 76
68 GTGCAT 112
69 GGCGGG 64
70 GCAGGC 85
71 GCCTGC 84
72 GTGCGG 89
73 CCGCAC 92
74 CGCCGC 65
75 CGTGCA 56
76 GGGCAC 106
77 CCACGC 90
78 GCGCTC 87
7

In [4]:
chroms = [f'chr{i}' for i in list(range(1, 23)) + ['X', 'Y']]
ZDNA = load('ZDNA_hg19_kouzine.pkl')

In [6]:
all_pred = []
all_true = []
for chrom in tqdm(chroms):
    all_pred.append(load(f"/gim/lv01/dumerenkov/zdna_data/new_mod_hg_res_{chrom}_kouzine"))
    all_true.append(ZDNA[chrom][:].astype(int))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [02:30<00:00,  6.29s/it]


In [2]:
chroms = [f'chr{i}' for i in list(range(1, 23)) + ['X', 'Y']]
ZDNA = load('ZDNA_hg19_kouzine.pkl')
black_list = load(f'../data/hg19_zdna/sparse/blacklist_hg19.pkl')

all_pred = []
all_true = []
for chrom in tqdm(chroms):
    true_clean = ZDNA[chrom][:].astype(int)
    iids = np.where(black_list[chrom].data == 1)[0]
    for i, j in zip(black_list[chrom].indices[iids], black_list[chrom].indices[iids + 1]):
        true_clean[i:j] = 0
    all_pred.append(load(f"/gim/lv01/dumerenkov/zdna_data/new_mod_hg_res_{chrom}_kouzine"))
    all_true.append(true_clean)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [03:19<00:00,  8.32s/it]


In [3]:
roc_auc_score(np.concatenate(all_true), np.concatenate(all_pred))

0.9991730742363316

In [4]:
print(sklearn.metrics.classification_report(np.concatenate(all_true), np.concatenate(all_pred)>0.5, digits=4))

              precision    recall  f1-score   support

           0     0.9999    0.9986    0.9993 3094878438
           1     0.1185    0.7303    0.2040    798974

    accuracy                         0.9985 3095677412
   macro avg     0.5592    0.8644    0.6016 3095677412
weighted avg     0.9997    0.9985    0.9991 3095677412

