In [1]:
from Bio import SeqIO
from pybedtools import BedTool

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats
from itertools import chain

from tqdm import tqdm
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, \
    recall_score

import pickle
import os

%matplotlib inline

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data.sampler import Sampler, BatchSampler
from torch.nn.modules.loss import MSELoss

torch.manual_seed(61406)
torch.__version__

'1.3.1'

In [3]:
label_dict = {
    'A': [1, 0, 0, 0],
    'C': [0, 1, 0, 0],
    'G': [0, 0, 1, 0],
    'T': [0, 0, 0, 1],
}

In [4]:
from torch.utils.data.dataset import Dataset

class G4HIDataset(Dataset):  # G4 & histones intersection
    def __init__(self, pos_filename, neg_filename):
        self.data = [record.seq.upper() for record in SeqIO.parse(pos_filename, 'fasta')
                     if 'N' not in record.seq.upper()]
        pos_size = len(self.data)
        self.data.extend([record.seq.upper() for record in SeqIO.parse(neg_filename, 'fasta')
                     if 'N' not in record.seq.upper()])
        neg_size = len(self.data) - pos_size
        self.labels = torch.unsqueeze(
            torch.FloatTensor(
                torch.cat((torch.ones(pos_size), torch.zeros(neg_size)), dim=0)
            ), 1)
        torch.manual_seed(7642)
        self.indexes = torch.randperm(len(self.data))
        
    def __getitem__(self, index):
        return (torch.FloatTensor([label_dict[bp] for bp in self.data[self.indexes[index]]]),
                self.labels[self.indexes[index]])

    def __len__(self):
        return len(self.data)

In [6]:
hist_names = [fn[2:-4] for fn in os.listdir('histone_modifications/') if fn[:2] == 'i_']

In [7]:
from sklearn.model_selection import train_test_split

def negative_class(train_num, test_num):
    f_tr_out = open(f'train_negative', 'w')
    f_ts_out = open(f'test_negative', 'w')
    for i, seq_record in enumerate(SeqIO.parse('negative_rand.fa', 'fasta')):
        if 'N' not in seq_record.seq.upper():
            if i < train_num:
                SeqIO.write(seq_record, f_tr_out, 'fasta')
            elif i < train_num + test_num:
                SeqIO.write(seq_record, f_ts_out, 'fasta')
            else:
                break
    f_tr_out.close()
    f_ts_out.close()

def make_dataset(hist_name):
    print(f'{hist_name}  (making dataset)')
    
    test_prop = 0.7  # proportion of test data
    # positive class
    data_bed = BedTool(f'histone_modifications/pos_i_{hist_name}.bed')
    train_size = int(len(data_bed) * test_prop)
    train_bed = BedTool(list(data_bed[: train_size]))
    test_bed = BedTool(list(data_bed[train_size:]))
    
    test_bed_iv = test_bed.intersect(b=train_bed, v=True)
    test_size = len(test_bed_iv)
    
    # теперь получаем fasta файлы
    train_seq = train_bed.sequence(fi='hg19.fa')
    test_seq = test_bed_iv.sequence(fi='hg19.fa')
    
    # negative class
    quad_bed = BedTool(f'quad_centered_500.bed')
    quad_bed = quad_bed.intersect(b=train_bed, v=True)
    quad_bed_neg = quad_bed.intersect(b=test_bed_iv, v=True)
    
    quad_train_size = int(len(quad_bed_neg) * test_prop)
    if quad_train_size >= train_size:
        train_neg = list(quad_bed_neg[: train_size])
        test_neg = list(quad_bed_neg[train_size: train_size + test_size])
    else:
        train_neg = list(quad_bed_neg[: quad_train_size])
        test_neg = list(quad_bed_neg[quad_train_size: ])
        with open('negative_rand.bed', 'r') as f:
            neg_rand = ['\t'.join(line.split('\t')[:3]) + '\n' for line in f.readlines()]

        rand_tr_size = train_size - quad_train_size
        train_neg.extend(neg_rand[: rand_tr_size])

        if test_size - len(test_neg) > 0:
            test_neg.extend(neg_rand[rand_tr_size: rand_tr_size + test_size - len(test_neg)])
        else:
            # when (size of test set) << (size of train set) * (test_prop) because of intersection
            test_neg = test_neg[: test_size]
        
    train_bed_neg = BedTool(train_neg)
    test_bed_neg = BedTool(test_neg)
    test_bed_neg_iv = test_bed_neg.intersect(b=train_bed_neg, v=True)

    train_seq_neg = train_bed_neg.sequence(fi='hg19.fa')
    train_seq_neg.save_seqs(f'histone_modifications/{hist_name}_train_neg.fa')
    
    test_seq_neg = test_bed_neg_iv.sequence(fi='hg19.fa')
    test_seq_neg.save_seqs(f'histone_modifications/{hist_name}_test_neg.fa')
    
    train_ratio = round(len(train_seq)/len(train_seq_neg), 4)
    print(f'train sizes: {len(train_seq)}, {len(train_seq_neg)}, {train_ratio}% ratio')
    test_ratio = round(len(test_seq)/len(test_seq_neg), 4)
    print(f'test sizes: {len(test_seq)}, {len(test_seq_neg)}, {test_ratio}% ratio')
    return train_seq.seqfn, train_seq_neg.seqfn, test_seq.seqfn, test_seq_neg.seqfn

## Процедура обучения:

In [5]:
from tqdm import tqdm_notebook

In [60]:
from IPython.display import clear_output

def train_epoch(model, optimizer):
    tr_loss_log = []
    model.train()
    pbar = tqdm_notebook(enumerate(train_loader))
    for i, (data, target) in pbar:
        optimizer.zero_grad()
        output = model(data.transpose(1, 2))
        loss = F.binary_cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        loss = loss.item()
        pbar.set_description(f"train loss: {round(loss, 7)}")
        tr_loss_log.append(loss)
        if i > int(len(train_loader) * 0.8):
            break
    val_loss_log = []
    model.eval()
    for i, (data, target) in pbar:
        output = model(data.transpose(1, 2))
        loss = F.binary_cross_entropy(output, target)
        loss = loss.item()
        val_loss_log.append(loss)
        pbar.set_description(f"val loss: {round(loss, 7)}")
    return tr_loss_log, val_loss_log

def train(model, opt, model_name='model'):
    train_log = []
    val_log = [(0, 1)]

    epoch = 0
    while len(val_log) == 1 or val_log[-2][1] - val_log[-1][1] > 0.0001:
        print(f"Epoch {epoch}")
        train_loss, val_loss = train_epoch(model, opt)
        torch.save(model.state_dict(), f"models/{model_name}_epoch_{epoch}.weights")
        train_log.extend(train_loss)
        val_log.append((int(len(train_loader) * 0.8) * (epoch + 1), np.mean(val_loss)))
        plot_history(train_log, val_log, model_name)
        epoch += 1

def plot_history(train_history, val_history, model_name, title='loss'):
    plt.figure()
    plt.title(f'{title}')
    plt.plot(train_history, label='train', zorder=1)    
    points = np.array(val_history)
    plt.scatter(points[:, 0], points[:, 1], marker='+', s=180, c='orange', label='val', zorder=2)
    plt.xlabel('steps')
    plt.legend(loc='best')
    plt.grid()
    plt.show()
    plt.savefig(f'training plots/{model_name}.png')

## модель

In [61]:
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size()[0], -1)

class CNN_one_l(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(4, 16, 8),
            nn.MaxPool1d(493, stride=493),
            nn.ReLU(),
            Flatten(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        output = self.cnn(x)
        return output

In [62]:
batch_size = 128

In [63]:
net = CNN_one_l()

In [None]:
for hist_name in hist_names:
    train_pos_fn, train_neg_fn, test_pos_fn, test_neg_fn = make_dataset(hist_name)
    ghi_train = G4HIDataset(train_pos_fn, train_neg_fn)
    train_loader = torch.utils.data.DataLoader(dataset=ghi_train, batch_size=batch_size)
    
    net = CNN_one_l()
    opt = torch.optim.Adam(net.parameters(), lr=0.001)
    train(net, opt, f"{hist_name}_CNN_one_layer")

# Baseline-model

In [79]:
'H3K4me3_(@_Wharton_Jelly)' in hist_names

True

In [80]:
train_pos_fn, train_neg_fn, test_pos_fn, test_neg_fn = make_dataset('H3K4me3_(@_Wharton_Jelly)')

H3K4me3_(@_Wharton_Jelly)  (making dataset)
train sizes: 103265, 103265, 1.0% ratio
test sizes: 28773, 28772, 1.0% ratio


In [20]:
ghi_train = G4HIDataset(train_pos_fn, train_neg_fn)
train_loader = torch.utils.data.DataLoader(dataset=ghi_train, batch_size=1)

In [21]:
X_train = []
y_train = []
for i, (x, y) in enumerate(train_loader):
    X_train.append(x.view(-1).int().tolist())
    y_train.extend(y.view(-1).int().tolist())

In [22]:
rf = RandomForestClassifier(random_state=50448)

In [23]:
%time rf.fit(X_train, y_train)

CPU times: user 4min 45s, sys: 1.52 s, total: 4min 46s
Wall time: 4min 46s


RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,
                       criterion='gini', max_depth=None, max_features='auto',
                       max_leaf_nodes=None, max_samples=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, n_estimators=100,
                       n_jobs=None, oob_score=False, random_state=50448,
                       verbose=0, warm_start=False)

In [81]:
ghi_test = G4HIDataset(test_pos_fn, test_neg_fn)
test_loader = torch.utils.data.DataLoader(dataset=ghi_test, batch_size=1)

In [82]:
X_test = []
y_test = []
for i, (x, y) in enumerate(test_loader):
    X_test.append(x.view(-1).int().tolist())
    y_test.extend(y.view(-1).int().tolist())
y_test = np.array(y_test)
y_pred = rf.predict(X_test)

In [104]:
fpr, tpr, thresholds = roc_curve(y_test, y_pred, pos_label=1)
roc_auc = round(metrics.auc(fpr, tpr), 4)
print('AUC ROC:', roc_auc)

precision, recall, thresholds = precision_recall_curve(y_test, y_pred, pos_label=1)
auc_pr = round(metrics.auc(recall, precision), 4)
print(f'auc pr: {auc_pr}')

acc = round(np.sum(y_test == y_pred) / len(y_test), 4)
print(f'accuracy: {acc}')

AUC ROC: 0.8538
auc pr: 0.9023
accuracy: 0.8539


# Визуализация фильтров и проверка качества на тестовой выборке

In [15]:
label_dict_rev = {
    0: 'A',
    1: 'C',
    2: 'G',
    3: 'T',
}

In [None]:
for hist_name in hist_names:
    model_nms = [[int(el[:-8].split('_')[-1]), el] for el in os.listdir('models')
                 if 'CNN_one_layer_epoch' in el and hist_name in el]
    if len(model_nms) >= 2:
        model_nm = sorted(model_nms)[-2][1]
    else:
        continue
    
    train_pos_fn, train_neg_fn, test_pos_fn, test_neg_fn = make_dataset(hist_name)
    ghi_test = G4HIDataset(test_pos_fn, test_neg_fn)
    test_loader = torch.utils.data.DataLoader(dataset=ghi_test, batch_size=batch_size)
    
    net.load_state_dict(torch.load(f"models/{model_nm}"))
    net.eval()
    y_pred = []
    y_test = []
    for batch, target in test_loader:
        with torch.no_grad():
            y_pred.extend(net(batch.transpose(1, 2)))
        y_test.extend(target)
    y_pred_bool = torch.FloatTensor(y_pred) > 0.5
    y_test = torch.FloatTensor(y_test)
    print(f'epoch {sorted(model_nms)[-2][0]}:')
    
    fpr, tpr, thresholds = roc_curve(y_test, y_pred, pos_label=1)
    roc_auc = round(metrics.auc(fpr, tpr), 4)
    print('AUC ROC:', roc_auc)

    precision, recall, thresholds = precision_recall_curve(y_test, y_pred, pos_label=1)
    auc_pr = round(metrics.auc(recall, precision), 4)
    print(f'auc pr: {auc_pr}')

    acc = round(int((y_test == y_pred_bool).sum()) / y_test.shape[0], 4)
    print(f'accuracy: {acc}')

    plt.plot(fpr, tpr, color='darkorange', label=f'ROC AUC={roc_auc}')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.fill_between(fpr, tpr, color=(1.0, 0.5490196078431373, 0.0, 0.1))
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.legend(loc='lower right')
    plt.title(f"{hist_name.replace('@_', '').replace('_', ' ')}")
    plt.grid()
    plt.savefig(f'ROC_curves/{hist_name}.png')
    plt.show()
    
    plt.plot(recall, precision, label=f'PR AUC={auc_pr}')
    plt.fill_between(recall, precision, color=(0,0,1,0.1))
    plt.ylabel('Precision')
    plt.xlabel('Recall')
    plt.legend(loc='best')
    plt.title(f"{hist_name.replace('@_', '').replace('_', ' ')}")
    plt.grid()
    plt.savefig(f'PR_curves/{hist_name}.png')
    plt.show()

    for i, filt in enumerate(net.state_dict()['cnn.0.weight']):
        value = np.transpose(filt).max(axis=1).values.sum() + net.state_dict()['cnn.0.bias'][i]
        s = ''.join([label_dict_rev[int(ind)] for ind in np.transpose(filt).max(axis=1).indices])

        print(i, s, round(float(value * net.state_dict()['cnn.4.weight'][0][i]), 4))
    print()