In [1]:
import torch
from torch.utils import data
import random
from collections import Counter
import numpy as np
import pandas as pd
import scipy
from tqdm import tqdm
import sys
import os
from joblib import dump, load
from sparse_vector.sparse_vector import SparseVector
import time
from sklearn.metrics import roc_auc_score, f1_score
from transformers import BertModel, BertConfig, PreTrainedTokenizer, BasicTokenizer, BertForTokenClassification
import collections
from transformers import utils
from torch.utils.data import DataLoader
import sklearn
from sklearn.metrics import accuracy_score
from collections import defaultdict
from dna_tokenizer import DNATokenizer, seq2kmer
import logging
logging.disable(logging.WARNING)

2023-04-22 10:18:50.411262: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-04-22 10:18:50.598367: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-04-22 10:18:51.905472: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-04-22 10:18:51.905623: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such 

In [2]:
class PredDataset(data.Dataset):
    def __init__(self, chroms, dna_source, intervals, tokenizer):

        self.chroms = chroms
        self.intervals = intervals
        self.tokenizer = tokenizer
        self.dna_source = dna_source
        
        
    def __len__(self):
        return len(self.intervals)
    
    def __getitem__(self, index):
        interval = self.intervals[index]
        chrom = interval[0]
        begin = interval[1]
        end = interval[2]

        k_mers = seq2kmer(self.dna_source[chrom][begin:end+5].upper(),6)
        encoded_k_mers = self.tokenizer.encode_plus(k_mers, add_special_tokens=False, max_length=512)["input_ids"]

        return torch.LongTensor(encoded_k_mers), (chrom, begin, end)

In [3]:
width = 128
pad = 192
k_mer_pad = 5

def final_prediction(chrom, DNA, models, device):
    
    intervals = []
    ends = []
    
    
    prediction = np.zeros(len(DNA[chrom]), dtype=np.float32)
    
    
    for st in range(0, len(DNA[chrom]) - 512, width):
        interval = [st, min(st + 512, len(DNA[chrom]))]
        intervals.append([chrom, interval[0], interval[1]])
        
    pred_dataset = PredDataset(chroms, DNA, intervals, 
                               DNATokenizer.from_pretrained('6-new-12w-0/', add_special_tokens=False))

    params = {'batch_size':32, 'num_workers':5, 'shuffle':False}
    load_predict = data.DataLoader(pred_dataset, **params)

    
    for model_i, model in enumerate(models):
    
        model.to(device)
        with torch.no_grad():
            for input_ids, intervals in tqdm(load_predict):
                input_ids = input_ids.to(device)
                outputs = torch.softmax(model(input_ids = input_ids)['logits'],axis = -1).cpu().numpy()[:,:,1]
                for ind, interval in enumerate(zip(intervals[0], intervals[1], intervals[2])): 
                    if interval[1] == 0:
                        prediction[interval[1]:interval[2]] = outputs[ind]
                    else:    
                        prediction[interval[1]+pad:interval[2]] = outputs[ind, pad:]
                    
        dump(prediction, f'/gim/lv01/dumerenkov/MM9_G4/MM9_kouzine_{model_i}_{chrom}', 3)

In [23]:
chroms = [f'chr{i}' for i in list(range(1, 20)) + ['X', 'Y']]
DNA = {chrom:load(f'../data/mm9_dna/sparse/{chrom}.pkl') for chrom in tqdm(chroms)}

G4_kouzine = {}
for chrom in DNA:
    G4_kouzine[chrom] = np.zeros(len(DNA[chrom]), dtype = bool)
    
    
with open("actB_ssDNA_enriched_Quadruplex.bed")as f:
    for idx, line in enumerate(f):
        if idx>0:
            chrom, start, end, _ , _ , _ = line.strip().split()
            if chrom in G4_kouzine:
                G4_kouzine[chrom][int(start):int(end)] = 1
                
G4 = G4_kouzine


  0%|                                                                                                                                                                           | 0/21 [00:00<?, ?it/s][A
  5%|███████▊                                                                                                                                                           | 1/21 [00:00<00:05,  3.79it/s][A
 10%|███████████████▌                                                                                                                                                   | 2/21 [00:00<00:04,  3.87it/s][A
 14%|███████████████████████▎                                                                                                                                           | 3/21 [00:00<00:04,  4.17it/s][A
 19%|███████████████████████████████                                                                                                                                    | 4/21 [00:00<00:04

In [5]:
models = []

for MODEL_NUMBER in range(5):
    dir_to_pretrained_model = f'dnabert_mm_fold_{MODEL_NUMBER}_kouzine_g4'
    model = BertForTokenClassification.from_pretrained(dir_to_pretrained_model)
    model.eval()
    models.append(model)

In [None]:
for chrom in chroms:
    print(f"BEGIN CHROM {chrom}")
    final_prediction(chrom, DNA, models, device = 1)

BEGIN CHROM chr1


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48144/48144 [4:00:16<00:00,  3.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48144/48144 [3:59:48<00:00,  3.35it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48144/48144 [3:59:43<00:00,  3.35it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48144/48144 [3:59:32<00:00,  3.35it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48144/48144 [3:59:31<00:00,  3.35it/s]


BEGIN CHROM chr2


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 44372/44372 [3:40:51<00:00,  3.35it/s]
 12%|█████████████████▉                                                                                                                                         | 5146/44372 [25:34<3:15:09,  3.35it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 48%|████████████████████████████████████████████████████████████████████████▌                                                                               | 21174/44372 [1:45:19<1:55:22,  3.35it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sendin

In [9]:
equalized, divisions = load('mm_divisions_kouzine_g4.pkl')

In [12]:
com_len = sum([len(DNA[chrom]) for chrom in chroms])
sums = []

for chrom in tqdm(chroms):
    loc_sum = []
    for model_num in range(5):
        vec = load(f'/gim/lv01/dumerenkov/MM9_G4/MM9_kouzine_{model_num}_{chrom}')
        loc_sum.append(vec.sum())
    sums.append(loc_sum)

multipliers = np.array(sums).sum(axis=0) / com_len

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [05:33<00:00, 15.89s/it]


In [13]:
for chrom in tqdm(chroms):
    vecs = np.array([load(f'/gim/lv01/dumerenkov/MM9_G4/MM9_kouzine_{model_num}_{chrom}') 
                     for model_num in range(5)])
    res_vec = (vecs / multipliers[:, None]) * multipliers.mean()
    mean_vec = res_vec.mean(axis=0)
    
    test_ints = []
    for MODEL_NUMBER in range(5):
        train_inds, test_inds = divisions[MODEL_NUMBER]
        train_intervals, test_intervals = [equalized[i] for i in train_inds], [equalized[i] for i in test_inds]
        test_ints.extend([(MODEL_NUMBER, inter) for inter in test_intervals if inter[0] == chrom])
    
    for model_num, inters in test_ints:
        mean_vec[inters[1]: inters[2]] = res_vec[model_num, inters[1]: inters[2]]
    dump(mean_vec, f'/gim/lv01/dumerenkov/MM9_G4/MM9_kouzine_res_{chrom}', 3)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [20:28<00:00, 58.50s/it]


In [15]:
all_pred = []
all_true = []
for chrom in tqdm(chroms):
    true_clean = G4[chrom][:].astype(int)
    all_pred.append(load(f'/gim/lv01/dumerenkov/MM9_G4/MM9_kouzine_res_{chrom}'))
    all_true.append(true_clean)
    
print(roc_auc_score(np.concatenate(all_true), np.concatenate(all_pred)))
print(sklearn.metrics.classification_report(np.concatenate(all_true), np.concatenate(all_pred)>0.5, digits=4))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [03:11<00:00,  9.14s/it]


0.9988107344940106
              precision    recall  f1-score   support

           0     1.0000    0.9971    0.9985 2654251788
           1     0.0682    0.8727    0.1265    643430

    accuracy                         0.9971 2654895218
   macro avg     0.5341    0.9349    0.5625 2654895218
weighted avg     0.9997    0.9971    0.9983 2654895218



In [16]:
from tabulate import tabulate
def to_fwf(df, fname):
    content = tabulate(df.values.tolist(), tablefmt="plain")
    open(fname, "w").write(content)

pd.DataFrame.to_fwf = to_fwf

In [24]:
pchroms, starts, ends = [], [], []
model_confidence_threshold = 0.25
min_length = 6

chroms = [f'chr{i}' for i in list(range(1, 20)) + ['X', 'Y']]

for chrom in tqdm(chroms):
    pred = load(f'/gim/lv01/dumerenkov/MM9_G4/MM9_kouzine_res_{chrom}')
    labeled, max_label = scipy.ndimage.label(pred>model_confidence_threshold)
    for idx in range(1,max_label+1):
        where = np.where(labeled == idx)[0]
        start = where[0]
        end = where[-1] + 1
        
        if end-start>min_length:
            pchroms.append(chrom)
            starts.append(start)
            ends.append(end)
pd.DataFrame(list(zip(pchroms, starts, ends))).to_fwf(f'beds/MM9_thr_{model_confidence_threshold}_minlen_{min_length}.bed')


  0%|                                                                                                                                                                           | 0/21 [00:00<?, ?it/s][A
  5%|███████▍                                                                                                                                                    | 1/21 [2:33:20<51:06:56, 9200.83s/it][A
 10%|██████████████▊                                                                                                                                             | 2/21 [5:25:41<52:05:51, 9871.14s/it][A
 14%|██████████████████████▎                                                                                                                                     | 3/21 [6:58:16<39:30:09, 7900.54s/it][A
 19%|█████████████████████████████▋                                                                                                                              | 4/21 [8:45:04<34:31:31, 