In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
import matplotlib

In [3]:
import logging

# Create the 'Model' logger; DEBUG and above is written to ABEdeepoff.log.
logger = logging.getLogger('Model')
logger.setLevel(logging.DEBUG)
# Loggers are process-global: without this guard, every re-run of the cell would attach one
# more FileHandler and each record would be written to the log file several times.
if not any(isinstance(handler, logging.FileHandler) for handler in logger.handlers):
    # File handler which records debug messages as well.
    fh = logging.FileHandler('ABEdeepoff.log')
    fh.setLevel(logging.DEBUG)
    # Timestamp, logger name, level and message on every line.
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    fh.setFormatter(formatter)
    logger.addHandler(fh)

In [4]:
import pkbar
import torch
import math
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
import torch.optim as optim
import os
from Bio import pairwise2
from Bio.pairwise2 import format_alignment
from sklearn.model_selection import GroupShuffleSplit, GroupKFold, train_test_split

from torch.utils.data import Dataset, DataLoader
from utils import EarlyStopping
import random



In [5]:
# Reproducibility: seed Python's, NumPy's and PyTorch's (CPU + current CUDA device) RNGs.
SEED = 1356
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(SEED)

# Use the first GPU when one is visible, otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'

# Display floats with 6 fixed decimals in both torch and pandas output.
torch.set_printoptions(precision=6, sci_mode=False)
pd.set_option('display.float_format', lambda value: '%.6f' % value)

In [6]:
def safe_dir(folder_name=None):
    """Make sure ``folder_name`` exists below the current working directory and return its path.

    Parameters
    ----------
    folder_name : str or None
        Directory name relative to ``os.getcwd()``; may contain sub-directories.
        ``None`` means the working directory itself (previously ``None`` raised a
        ``TypeError`` inside ``os.path.join``).

    Returns
    -------
    str
        Absolute path of the existing directory.
    """
    if folder_name is None:
        return os.getcwd()
    folder_path = os.path.join(os.getcwd(), folder_name)
    try:
        # makedirs also creates missing parents (os.mkdir failed for nested names), and
        # catching FileExistsError avoids the race of checking existence first.
        os.makedirs(folder_path)
    except FileExistsError:
        print(f"Folder {folder_path} already exists")
    else:
        print(f"Folder {folder_path} created successfully")
    return folder_path


# Integer code for each alignment symbol; 0 is the padding value used by generate_batch.
_ALIGN_ENCODING = {'<pad>': 0, 'A': 1, 'C': 2, 'G': 3, 'T': 4, '-': 5}


def do_encoding(source, target, match=1, mismatch=-1, gap_open=-3, gap_extend=-2):
    '''Globally align ``source`` to ``target`` and integer-encode both aligned strings.

    Parameters
    ----------
    source, target : str
        Upper-case nucleotide strings (A/C/G/T).
    match, mismatch, gap_open, gap_extend : int
        Scores passed to ``pairwise2.align.globalms``; defaults are the values used so far.

    Returns
    -------
    tuple
        ``(source, target, seq1, seq2)``: the unchanged inputs plus the encoded aligned
        source and target (lists of ints, gaps encoded as 5).

    Raises
    ------
    KeyError
        If an aligned string contains a symbol other than A, C, G, T or '-'.
    '''
    alignments = pairwise2.align.globalms(source, target, match, mismatch, gap_open, gap_extend)
    best = alignments[0]
    # Fields 0 and 1 of a pairwise2 alignment are the gapped sequences. For a global alignment
    # they are exactly the first and third lines of format_alignment(), but reading them
    # directly does not depend on that function's display layout.
    aligned_src, aligned_tgt = best[0], best[1]
    seq1 = [_ALIGN_ENCODING[nuc] for nuc in aligned_src]
    seq2 = [_ALIGN_ENCODING[nuc] for nuc in aligned_tgt]
    return source, target, seq1, seq2


class gRNADataset(Dataset):
    """Map-style dataset of aligned, integer-encoded (source, target) pairs with a label.

    ``df`` must provide the columns ``source``, ``target``, ``efficiency`` and ``type``.
    Each item is ``(source, target, y, seq1, seq2, seq_len, otype)`` where ``y`` is a
    FloatTensor of the efficiency, ``seq1``/``seq2`` are LongTensors from ``do_encoding``,
    and ``seq_len`` is the length of ``seq1``.
    """

    def __init__(self, df):
        # Work on a private copy. Previously the encoded columns were written into, and the
        # index of, the caller's frame was reset in place: a hidden side effect that also
        # raised SettingWithCopyWarning when ``df`` was a ``.loc`` slice.
        df = df.copy()
        df[['source', 'target', 'seq1', 'seq2']] = df.apply(
            lambda x: do_encoding(x['source'], x['target']), axis=1, result_type='expand')
        # 0..n-1 labels so the Series below can be indexed by dataset position.
        df.reset_index(drop=True, inplace=True)
        self.source = df['source']
        self.target = df['target']
        self.efficiency = df['efficiency'].values
        self.seq1 = df['seq1'].values
        self.seq2 = df['seq2'].values
        self.otype = df['type']

    def __len__(self):
        return len(self.source)

    def __getitem__(self, index):
        """Return ``(source, target, y, seq1, seq2, seq_len, otype)`` for row ``index``."""
        source = self.source[index]
        target = self.target[index]
        y = torch.FloatTensor(np.array(self.efficiency[index]))
        seq1 = torch.LongTensor(self.seq1[index])
        seq2 = torch.LongTensor(self.seq2[index])
        seq_len = seq1.shape[0]
        otype = self.otype[index]
        return source, target, y, seq1, seq2, seq_len, otype


def generate_batch(batch):
    """Collate gRNADataset samples into one padded, time-major batch.

    Samples are stably sorted by ``seq_len`` in descending order, so the longest sequence
    comes first and ties keep their incoming order.

    Returns
    -------
    tuple
        ``(sources, targets, efficiencies, seq1_batch, seq2_batch, seq_lens, otypes)``:
        ``efficiencies`` is a FloatTensor, ``seq1_batch``/``seq2_batch`` are
        ``(max_len, batch)`` tensors right-padded with 0, and ``seq_lens`` is a LongTensor.
    """
    # Sample layout: (source, target, y, seq1, seq2, seq_len, otype); seq_len sits at index -2.
    ordered = sorted(batch, key=lambda sample: sample[-2], reverse=True)

    sources = [sample[0] for sample in ordered]
    targets = [sample[1] for sample in ordered]
    efficiencies = [sample[2] for sample in ordered]
    seq_lens = [sample[5] for sample in ordered]
    otypes = [sample[6] for sample in ordered]

    # batch_first=False -> (max_len, batch); 0 is the padding code, so no extra masking is needed.
    seq1_batch = rnn_utils.pad_sequence([sample[3] for sample in ordered],
                                        batch_first=False, padding_value=0)
    seq2_batch = rnn_utils.pad_sequence([sample[4] for sample in ordered],
                                        batch_first=False, padding_value=0)

    return (sources, targets,
            torch.FloatTensor(efficiencies),
            seq1_batch, seq2_batch,
            torch.LongTensor(seq_lens), otypes)

In [7]:
### Model
class Encoder(nn.Module):
    """Bidirectional-LSTM regressor over a pair of aligned, integer-encoded sequences.

    Parameters
    ----------
    input_dim : int
        Vocabulary size of the integer encoding.
    emb_dim : int
        Embedding size; the embeddings of the two sequences are summed position-wise.
    hid_dim : int
        LSTM hidden size per direction.
    n_layers : int
        Number of stacked LSTM layers.
    dropout : float
        Dropout probability applied to the summed embeddings, the attention query and the
        hidden fully-connected layer.

    ``forward(seq_1, seq_2)`` takes time-major LongTensors of shape (seq_len, batch) (the
    LSTM uses the default ``batch_first=False``) and returns an unbounded score of shape
    (batch, 1); no output activation is applied here. The last attention weights are kept
    in ``self.att_score``.
    """

    def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
        super(Encoder, self).__init__()

        self.embedding_1 = nn.Embedding(input_dim, emb_dim)
        # NOTE(review): `_weight` is a private argument; embedding_2 is built from embedding_1's
        # weight tensor, but whether the two stay tied during training depends on how this
        # PyTorch version wraps it - confirm.
        self.embedding_2 = nn.Embedding(input_dim, emb_dim, _weight=self.embedding_1.weight)
        self.rnn = nn.LSTM(input_size=emb_dim, hidden_size=hid_dim, num_layers=n_layers,
                           bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        # Input = attention context + final hidden state + max-pool, each 2 * hid_dim wide.
        self.fc_feat_1 = nn.Linear(6 * hid_dim, 3 * hid_dim)
        self.fc_out = nn.Linear(3 * hid_dim, 1)
        self.att_score = None

    def attention_net(self, x, query, mask=None):
        """Scaled dot-product attention of ``query`` over ``x``, summed over query positions.

        Parameters
        ----------
        x : Tensor (batch, seq_len, dim)
            Keys and values.
        query : Tensor (batch, seq_len, dim)
        mask : Tensor (batch, seq_len), optional
            Bool/0-1 tensor, True for real tokens. When given, padded key positions receive
            zero attention weight and padded query rows are left out of the sum. Every row
            needs at least one real token. (Previously this argument was silently ignored.)

        Returns
        -------
        context : Tensor (batch, dim)
        alpha_n : Tensor (batch, seq_len, seq_len)
            Softmax attention weights.
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, x.transpose(1, 2)) / math.sqrt(d_k)
        if mask is not None:
            valid = mask.bool()
            scores = scores.masked_fill(~valid.unsqueeze(1), float('-inf'))
        alpha_n = F.softmax(scores, dim=-1)
        weighted = torch.matmul(alpha_n, x)
        if mask is not None:
            weighted = weighted * valid.unsqueeze(-1).to(weighted.dtype)
        context = weighted.sum(1)
        return context, alpha_n

    def forward(self, seq_1, seq_2):
        """Score aligned sequence pairs; see the class docstring for shapes.

        NOTE(review): padded positions (code 0) are not masked - the LSTM, the max-pool and
        the attention all see them. attention_net accepts a mask if this is to be changed.
        """
        emb_1 = self.embedding_1(seq_1)
        emb_2 = self.embedding_2(seq_2)
        emb_comb = self.dropout(emb_1 + emb_2)
        out, (hid_, _) = self.rnn(emb_comb)
        # Last layer's final forward and backward hidden states -> (batch, 2 * hid_dim).
        hidden = torch.cat((hid_[-2, :, :], hid_[-1, :, :]), dim=1)

        out = out.permute(1, 0, 2)  # (batch, seq_len, 2 * hid_dim)
        max_pool, _ = torch.max(out, 1)

        query = self.dropout(out)
        # Self-attention over the LSTM outputs.
        attn_output, alpha_n = self.attention_net(out, query)
        self.att_score = alpha_n

        # (batch, 6 * hid_dim)
        x = torch.cat([attn_output, hidden, max_pool], dim=1)
        x = self.dropout(F.relu(self.fc_feat_1(x)))
        fc_out = self.fc_out(x)
        return fc_out

In [8]:
# Endogenous validation set. Rows whose type is 'Y' provide the reference efficiency
# (`oneff`), joined back on the shared `source` and `Group` columns; the label becomes
# efficiency / oneff.
df_external = (
    pd.read_table('data/ABE_Off_endo.txt', skip_blank_lines=True)
    .dropna()
    .assign(
        source=lambda d: d.seq1.str.upper().str.strip(),
        target=lambda d: d.seq2.str.upper().str.strip().str.replace('-', ''),
        target_len=lambda d: d.target.apply(len),
        efficiency=lambda d: d.y,
        type=lambda d: d['off_type'],
    )
)
df_on = df_external.loc[df_external['type'] == 'Y'].assign(oneff=lambda d: d.efficiency)
df_external = (
    df_external
    .merge(df_on[['source', 'Group', 'oneff']])
    .assign(**{'off/on': lambda d: d.efficiency / d.oneff})
    .assign(efficiency=lambda d: d['off/on'])
    .drop_duplicates()
    .reset_index(drop=True)
)
df_external_dataset = gRNADataset(df_external)
exter_iter = DataLoader(df_external_dataset, batch_size=5, shuffle=False, collate_fn=generate_batch)

In [9]:
def train_model(train_iter, valid_iter, test_iter, patience, k, version=None):
    """Train the global ``model`` on one cross-validation fold with staged learning rates.

    Uses the notebook globals ``model``, ``criterion_mse``, ``device``, ``logger``,
    ``exter_iter`` and ``df_external``. It rebinds ``my_optim`` and appends
    ``[checkpoint_path, metrics]`` entries to ``lst_testing`` / ``lst_endo``.

    Each epoch:
    - runs one sample-weighted MSE pass over ``train_iter``;
    - runs one unweighted MSE pass over ``valid_iter``;
    - whenever the validation loss improves, saves the weights to
      ``{version}/ABE_RNN_{k}_{loss}.pt`` and evaluates them with ``get_test`` and
      ``get_endo``.

    Each time early stopping fires, the best checkpoint is reloaded and training continues
    with the next learning rate from ``lr_dict``. The fold ends on the 4th trigger or after
    ``N_EPOCHS`` epochs.

    Args:
        train_iter, valid_iter, test_iter: loaders yielding ``generate_batch`` tuples.
        patience: patience passed to ``EarlyStopping``.
        k: fold index, used in checkpoint file names.
        version: directory name (passed to ``safe_dir``) for the checkpoints.
    """
    global my_optim
    global lr_dict
    global lst_testing
    global lst_endo
    N_EPOCHS = 250
    CLIP = 1  # max global gradient norm
    best_valid_loss = float('inf')
    train_losses = []  # per-batch training losses of the current epoch
    valid_losses = []  # per-batch validation losses of the current epoch
    avg_train_losses = []  # mean training loss of every finished epoch
    avg_valid_losses = []  # mean validation loss of every finished epoch

    path = safe_dir(version)
    # NOTE(review): EarlyStopping comes from the local `utils` module (not shown); presumably
    # it sets `.early_stop` after `patience` epochs without improvement.
    early_stopping = EarlyStopping(patience=patience, verbose=True, path=f'{path}/ABE_RNN_{k}_checkpoint.pt')
    early_num = 0  # number of early-stopping triggers so far in this fold
    for epoch in range(N_EPOCHS):

        train_per_epoch = len(train_iter)
        kbar = pkbar.Kbar(target=train_per_epoch, epoch=epoch, num_epochs=N_EPOCHS, width=8, always_stateful=False)

        # ---- training pass ----
        model.train()
        epoch_loss_ = 0

        for i, batch in enumerate(train_iter):
            seq1 = batch[3].to(device)
            seq2 = batch[4].to(device)
            y = batch[2].unsqueeze(1).to(device)
            # Per-sample weight 1 / (share of the sample's type in this batch + 0.08), so
            # types that are rare within the batch get larger weights.
            type_counts = pd.Series(batch[-1]).value_counts()
            wgt_dict = (1 / (type_counts / type_counts.sum() + 0.08)).to_dict()
            # Shape (batch, 1) to match the (batch, 1) element-wise loss. With shape (batch,)
            # the product broadcast to (batch, batch) and .mean() reduced to
            # mean(loss) * mean(weight), i.e. no per-sample weighting at all.
            wgts = torch.FloatTensor([wgt_dict[x] for x in batch[-1]]).unsqueeze(1).to(device)

            my_optim.zero_grad()
            outputs = model(seq1, seq2)

            # Regression on a sigmoid scaled to [0, 100].
            # NOTE(review): labels above 100 (off/on ratio > 1) are unreachable - confirm intended.
            y_hat = torch.sigmoid(outputs) * 100
            loss = (criterion_mse(y, y_hat) * wgts).mean()

            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)

            my_optim.step()

            epoch_loss_ += loss.item()
            train_losses.append(loss.item())

            kbar.update(i, values=[("lr", my_optim.defaults['lr']),
                ("loss", epoch_loss_ / (i + 1))])

        # ---- validation pass (unweighted MSE) ----
        model.eval()
        epoch_loss_ = 0
        with torch.no_grad():

            for i, batch in enumerate(valid_iter):

                seq1 = batch[3].to(device)
                seq2 = batch[4].to(device)
                y = batch[2].unsqueeze(1).to(device)
                outputs = model(seq1, seq2)

                y_hat = torch.sigmoid(outputs) * 100
                loss = (criterion_mse(y, y_hat)).mean()

                valid_losses.append(loss.item())
                epoch_loss_ += loss.item()

            kbar.add(1, values=[("val_loss", epoch_loss_ / len(valid_iter))])

        valid_loss = epoch_loss_ / len(valid_iter)
        if valid_loss < best_valid_loss:
            # New best epoch: checkpoint it, then evaluate on the test fold and the
            # external endogenous set.
            best_valid_loss = valid_loss
            p = f'{path}/ABE_RNN_{k}_{round(valid_loss, 6)}.pt'
            torch.save(model.state_dict(), p)
            logger.info(f'----Testing------,{p}')
            r = get_test(test_iter, model)
            lst_testing.append([p,r])
            logger.info(f'----Endo------')
            r = get_endo(exter_iter, model, df_external)
            lst_endo.append([p,r])

        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []

        # NOTE(review): if no epoch has improved yet (e.g. the first validation loss is NaN),
        # `p` is unbound here and this raises NameError.
        early_stopping(valid_loss, model, p)

        if early_stopping.early_stop:
            early_num = early_num + 1
            if early_num == 4:
                # Last stage: every learning rate has been tried, stop this fold.
                print("Early stopping with best_valid_loss:", epoch, best_valid_loss, ". Going next...")
                model.load_state_dict(torch.load(p))
                my_optim = optim.Adam(model.parameters(),lr=lr_dict[0])
                early_stopping = EarlyStopping(patience=patience, verbose=True,
                    path=f'{path}/ABE_RNN_{k}_checkpoint.pt')
                break

            # Otherwise roll back to the best checkpoint and continue with the next learning rate.
            print("Change learning rate..")
            model.load_state_dict(torch.load(p))
            # NOTE(review): each fold starts with optim.AdamW, but restarts use optim.Adam
            # (no decoupled weight decay) - confirm this switch is intended.
            my_optim = optim.Adam(model.parameters(),lr=lr_dict[early_num])
            early_stopping = EarlyStopping(patience=patience, verbose=True, best_score = -best_valid_loss,
                path=f'{path}/ABE_RNN_{k}_checkpoint.pt')
                
            
def get_pred(iter_, model):
    """Run ``model`` over ``iter_`` and collect labels and predictions in one DataFrame.

    Args:
        iter_: iterable of ``generate_batch`` tuples (e.g. a DataLoader).
        model: module called as ``model(seq1, seq2)``; put into eval mode here.

    Returns:
        DataFrame with columns ``source``, ``target``, ``offtype``, ``y`` (label / 100) and
        ``y_pred`` (sigmoid of the model output), one row per sample in iteration order.
        The index is 0..n-1; previously every batch restarted at 0, giving duplicate labels.
        If ``iter_`` yields nothing, an empty frame with these columns is returned instead
        of ``pd.concat`` raising.
    """
    columns = ['source', 'target', 'offtype', 'y', 'y_pred']
    model.eval()
    frames = []
    with torch.no_grad():
        for batch in iter_:
            seq1 = batch[3].to(device)
            seq2 = batch[4].to(device)
            scores = model(seq1, seq2)
            frames.append(pd.DataFrame({
                'source': batch[0],
                'target': batch[1],
                'offtype': batch[-1],
                # Divide by 100 so labels are on the same scale as the sigmoid output.
                'y': list(batch[2].view(-1).cpu().numpy() / 100),
                'y_pred': list(torch.sigmoid(scores).view(-1).cpu().numpy()),
            }))
    if not frames:
        return pd.DataFrame(columns=columns)
    return pd.concat(frames, ignore_index=True)


def init_weights(m):
    """Re-initialise every parameter of ``m`` in place, dispatching on its name.

    - names containing ``'rnn.weight_'`` (LSTM matrices of an ``rnn`` submodule): orthogonal
    - other names containing ``'weight'``: normal(0, 0.01)
    - everything else (biases): 0

    The middle test used to be a plain ``if``, so the orthogonal init of the LSTM weights
    was immediately overwritten by the normal init.

    With ``model.apply(init_weights)`` this runs on every submodule and finally on the root.
    Only on the root do the LSTM parameter names carry the ``rnn.`` prefix, and since the
    root is visited last, its orthogonal init is the one that sticks.
    """
    for name, param in m.named_parameters():
        if 'rnn.weight_' in name:
            nn.init.orthogonal_(param.data)
        elif 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)

In [10]:
df = pd.read_csv('data/ABEdeepoff.txt', sep='\t')
df['efficiency'] = df['efficiency'] * 100

# Give every distinct `source` a 1-based id (in order of first appearance) in a new `group`
# column. Rows end up ordered group by group, and that order is what the seeded shuffle sees.
df = pd.concat(
    [rows.assign(group=group_id)
     for group_id, (_, rows) in enumerate(df.groupby('source', sort=False), start=1)],
    ignore_index=True,
)
df_src = df.copy().sample(frac=1, random_state=SEED)

In [11]:
# Rows of type 'Y' supply the reference efficiency `oneff` of their source. The inner merge on
# `source` drops sources without such a row and returns a fresh 0..n-1 index.
df_on = df_src.loc[df_src['type'] == 'Y'].assign(oneff=lambda d: d.efficiency)
df_src = (
    df_src
    .merge(df_on[['source', 'oneff']])
    .assign(**{'off/on': lambda d: d.efficiency / d.oneff})
    .assign(efficiency=lambda d: d['off/on'] * 100)
)

In [12]:
def get_test(iter_data, model):
    """Log and return Spearman correlations between labels and predictions on a loader.

    Returns:
        ``(all_corr, lst_r)``: ``all_corr`` is the Spearman rho of ``y`` vs ``y_pred`` over
        all rows; ``lst_r`` holds one ``'offtype,rho'`` string per off-target type.
    """
    df_ = get_pred(iter_data, model)
    # Correlate only the two numeric columns and read the (y, y_pred) cell by label.
    # corr() on the whole frame fails on pandas >= 2.0 (string columns), and `.y[-1]` relied
    # on the deprecated positional fallback of Series.__getitem__.
    all_corr = df_[['y', 'y_pred']].corr(method='spearman').loc['y', 'y_pred']
    logger.info(f"all_corr:{all_corr}")
    lst_r = []
    for k, df_grp in df_.groupby('offtype'):
        r = df_grp[['y', 'y_pred']].corr(method='spearman').loc['y', 'y_pred']
        lst_r.append(f'{k},{r}')
        logger.info(f"{k}:{r}")
    return all_corr, lst_r


def get_endo(iter_data, model,df_external):
    df_ = get_pred(iter_data,model)
    df_['group'] = df_external['Group'].values

    grp = df_.groupby('group')
    all_corr = df_.corr(method='spearman').y[1]
    logger.info(f"all_corr:{all_corr}")
    lst_r = []
    for k in grp.groups.keys():
        df_grp = grp.get_group(k)
        r = df_grp.corr(method='spearman').y[1]
        lst_r.append(f'{k},{r}')
        logger.info(f"{k}:{r}")
    return lst_r

In [13]:
# Encoder hyper-parameters; the vocabulary is the 6 alignment codes (pad, A, C, G, T, gap).
ENC_INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM = 6, 256, 512
N_LAYERS, ENC_DROPOUT = 2, 0.5
model = Encoder(ENC_INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM, N_LAYERS, ENC_DROPOUT).to(device)

# [checkpoint_path, metrics] pairs appended by train_model.
lst_testing, lst_endo = [], []
# Learning rate per early-stopping stage (key = number of early stops so far).
lr_dict = {0: 1e-3, 1: 1e-4, 2: 1e-5, 3: 5e-6}
# Element-wise squared error; reduction is done by the training loop.
criterion_mse = nn.MSELoss(reduction='none')

In [None]:
# 10-fold cross-validation split by `group` (one group per source), so no source appears in
# both the training and the test fold.
kf = GroupKFold(n_splits=10)
splits = kf.split(df_src, groups=df_src['group'])

for i, (train_idx, test_idx) in enumerate(splits):
    # split() yields positional indices, so use iloc; this stays correct even if df_src's
    # index is not a 0..n-1 RangeIndex.
    data_train = df_src.iloc[train_idx].reset_index(drop=True)
    data_test = df_src.iloc[test_idx].reset_index(drop=True)
    test_dataset = gRNADataset(data_test)
    test_iter = DataLoader(test_dataset, batch_size=128, shuffle=False, collate_fn=generate_batch)

    # Hold out 10% of the training *groups* for validation. The former row-wise
    # train_test_split placed rows of the same source in both train and validation, leaking
    # information into the early-stopping signal.
    gss = GroupShuffleSplit(n_splits=1, train_size=0.9, random_state=SEED)
    fit_idx, valid_idx = next(gss.split(data_train, groups=data_train['group']))
    data_valid = data_train.iloc[valid_idx]
    data_train = data_train.iloc[fit_idx]

    train_dataset = gRNADataset(data_train)
    train_iter = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=generate_batch)

    valid_dataset = gRNADataset(data_valid)
    valid_iter = DataLoader(valid_dataset, batch_size=32, shuffle=True, collate_fn=generate_batch)

    # Fresh optimizer and re-initialised weights for every fold (the model object is shared).
    my_optim = optim.AdamW(model.parameters(), lr=lr_dict[0])
    model.apply(init_weights)
    train_model(train_iter, valid_iter, test_iter, 5, i, version='ABEdeepoff_0525')