In [63]:
cd /home/featurize/work/TargetFM/RNA-FM

/home/featurize/work/TargetFM/RNA-FM


In [64]:
!pip install . --user

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Processing /home/featurize/work/TargetFM/RNA-FM
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
Building wheels for collected packages: rna-fm
  Building wheel for rna-fm (setup.py) ... [?25ldone
[?25h  Created wheel for rna-fm: filename=rna_fm-0.1.0-py3-none-any.whl size=28798 sha256=2ca08f9be1fc38e5fd402e3fc584ad688733c31c5c5f09fb21457cdb18d18b29
  Stored in directory: /home/featurize/.cache/pip/wheels/2d/f6/bf/5a73662ac34e4e6d77122ce63144d84a6b927d8c9c9057f669
Successfully built rna-fm
Installing collected packages: rna-fm
  Attempting uninstall: rna-fm
    Found existin

In [65]:
!pip install biopython==1.68 --user

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [66]:
!pip install regex --user

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple


In [67]:
cd /home/featurize/work/TargetFM

/home/featurize/work/TargetFM


# 1 Sequence

In [68]:
import sys
import os

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as F

import pandas as pd
from sklearn import preprocessing
from tqdm import tqdm
import numpy as np
import random
import argparse
import json

import fm
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.Alphabet import generic_rna
import regex

In [69]:
""" 确定mRNA上的CTS片段位置 """
def find_candidate(mirna_sequence, mrna_sequence, seed_match):
    positions = set()

    # 确定seed_match方式
    if seed_match == '10-mer-m6':
        SEED_START = 1
        SEED_END = 10
        SEED_OFFSET = SEED_START - 1
        MIN_MATCH = 6
        TOLERANCE = (SEED_END-SEED_START+1) - MIN_MATCH
    elif seed_match == '10-mer-m7':
        SEED_START = 1
        SEED_END = 10
        SEED_OFFSET = SEED_START - 1
        MIN_MATCH = 7
        TOLERANCE = (SEED_END-SEED_START+1) - MIN_MATCH
    elif seed_match == 'offset-9-mer-m7':
        SEED_START = 2
        SEED_END = 10
        SEED_OFFSET = SEED_START - 1
        MIN_MATCH = 7
        TOLERANCE = (SEED_END-SEED_START+1) - MIN_MATCH
    elif seed_match == 'strict':
        positions = find_strict_candidate(mirna_sequence, mrna_sequence)
        return positions

    else:
        raise ValueError("seed_match expected 'strict', '10-mer-m6', '10-mer-m7', or 'offset-9-mer-m7', got '{}'".format(seed_match))
    
    # 确定mirna上对应seed区域
    seed = mirna_sequence[(SEED_START-1):SEED_END]

    # complement()返回序列的转录序列； rc_seed:seed的配对片段
    rc_seed = str(Seq(seed, generic_rna).complement())

    # 在mrna中找可以与seed匹配的片段； re.finditer(pattern, string, flags=0), Use the finditer() function to match a pattern in a string and return an iterator yielding the Match objects.
    match_iter = regex.finditer("({}){{e<={}}}".format(rc_seed, TOLERANCE), mrna_sequence)

    for match_index in match_iter:
        # positions.add(match_index.start()) # slice-start indicies
        positions.add(match_index.end()+SEED_OFFSET) # slice-stop indicies
    
    # CTS片段的第一个token位置
    positions = list(positions)

    return positions

def find_strict_candidate(mirna_sequence, mrna_sequence):
    positions = set()

    SEED_TYPES = ['8-mer', '7-mer-m8', '7-mer-A1', '6-mer', '6-mer-A1', 'offset-7-mer', 'offset-6-mer']
    for seed_match in SEED_TYPES:
        if seed_match == '8-mer':
            SEED_START = 2
            SEED_END = 8
            SEED_OFFSET = 0
            seed = 'U' + mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == '7-mer-m8':
            SEED_START = 1
            SEED_END = 8
            SEED_OFFSET = 0
            seed = mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == '7-mer-A1':
            SEED_START = 2
            SEED_END = 7
            SEED_OFFSET = 0
            seed = 'U' + mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == '6-mer':
            SEED_START = 2
            SEED_END = 7
            SEED_OFFSET = 1
            seed = mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == '6mer-A1':
            SEED_START = 2
            SEED_END = 6
            SEED_OFFSET = 0
            seed = 'U' + mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == 'offset-7-mer':
            SEED_START = 3
            SEED_END = 9
            SEED_OFFSET = 0
            seed = mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == 'offset-6-mer':
            SEED_START = 3
            SEED_END = 8
            SEED_OFFSET = 0
            seed = mirna_sequence[(SEED_START-1):SEED_END]

        rc_seed = str(Seq(seed, generic_rna).complement())
        match_iter = regex.finditer(rc_seed, mrna_sequence)

        for match_index in match_iter:
            # positions.add(match_index.start()) # slice-start indicies
            positions.add(match_index.end()+SEED_OFFSET) # slice-stop indicies

    positions = list(positions)

    return positions

""" 确定CTS片段 """
def get_candidate(mirna_sequence, mrna_sequence, cts_size, seed_match):
    positions = find_candidate(mirna_sequence, mrna_sequence, seed_match)  # CTS片段首位置（list）

    candidates = []
    for i in positions:
        site_sequence = mrna_sequence[max(0, i-cts_size):i]      # mrna上CTS片段
        rev_site_sequence = site_sequence[::-1]                  # [::-1]序列翻转：从左到右->从右到左
        rc_site_sequence = str(Seq(rev_site_sequence, generic_rna).complement())   # 转录序列
        candidates.append(rev_site_sequence)                     # miRNAs: 5'-ends to 3'-ends,  mRNAs: 3'-ends to 5'-ends
        #candidates.append(rc_site_sequence)

    return candidates, positions

""" 生成pairs """
def make_pair(mirna_sequence, mrna_sequence, cts_size, seed_match):
    candidates, positions = get_candidate(mirna_sequence, mrna_sequence, cts_size, seed_match) # candidates:list of seq

    mirna_querys = []     # mirna seed region
    mrna_targets = []     # mirna CTS
    if len(candidates) == 0:   # mrna中无mirna匹配片段
        return (mirna_querys, mrna_targets, positions)
    else: 
        mirna_sequence = mirna_sequence[0:cts_size]   # mirna的seed region
        for i in range(len(candidates)):
            mirna_querys.append(mirna_sequence)
            mrna_targets.append(candidates[i])        # 一个mirna的seed对应多个mrna的CTS

    return mirna_querys, mrna_targets, positions



In [70]:
""" 读取mirna,mrna序列 """
def read_fasta(mirna_fasta_file, mrna_fasta_file):
    mirna_list = list(SeqIO.parse(mirna_fasta_file, 'fasta'))
    mrna_list = list(SeqIO.parse(mrna_fasta_file, 'fasta'))

    mirna_ids = []
    mirna_seqs = []
    mrna_ids = []
    mrna_seqs = []

    for i in range(len(mirna_list)):
        mirna_ids.append(str(mirna_list[i].id))
        mirna_seqs.append(str(mirna_list[i].seq))

    for i in range(len(mrna_list)):
        mrna_ids.append(str(mrna_list[i].id))
        mrna_seqs.append(str(mrna_list[i].seq))

    return mirna_ids, mirna_seqs, mrna_ids, mrna_seqs

""" 读取gt(label) """
def read_ground_truth(ground_truth_file, header=True, train=False):
    # input format: [MIRNA_ID, MRNA_ID, LABEL]
    if header is True:
        records = pd.read_csv(ground_truth_file, header=0, sep='\t')
    else:
        records = pd.read_csv(ground_truth_file, header=None, sep='\t')

    query_ids = np.asarray(records.iloc[:, 0].values)
    target_ids = np.asarray(records.iloc[:, 1].values)
    if train is True:
        labels = np.asarray(records.iloc[:, 2].values)
    else:
        labels = np.full((len(records),), fill_value=-1)

    return query_ids, target_ids, labels


""" 核苷酸转整型 """
# AUUCAAU -> 1442114
def nucleotide_to_int(nucleotides, max_len):
    dictionary = {'A':1, 'C':2, 'G':3, 'T':4, 'U':4}

    chars = []
    nucleotides = nucleotides.upper()   # nucleotides小写转大写
    for c in nucleotides:
        chars.append(c)

    ints_enc = np.full((max_len,), fill_value=0) # to post-pad inputs; np.full(shape, fill_value)返回一个给定大小和类型并且以指定数字全部填充的新数组
    for i in range(len(chars)):
        try:
            ints_enc[i] = dictionary[chars[i]]
        except KeyError:
            continue
        except IndexError:
            break

    return ints_enc

""" 序列转整型 """
# 所有序列转整型
def sequence_to_int(sequences, max_len):
    import itertools

    if type(sequences) is list:
        seqs_enc = np.asarray([nucleotide_to_int(seq, max_len) for seq in sequences])
    else:
        seqs_enc = np.asarray([nucleotide_to_int(seq, max_len) for seq in sequences])
        seqs_enc = list(itertools.chain(*seqs_enc))
        seqs_enc = np.asarray(seqs_enc)

    return seqs_enc

""" 统一序列长度，两种pad方式 """
def pad_sequences(sequences, max_len=None, padding='pre', fill_value='O'):
    n_samples = len(sequences)    # 样本数：序列个数； sequences:list of sequence

    lengths = []
    for seq in sequences:
        try:
            lengths.append(len(seq))  # 记录每一个序列的长度
        except TypeError:
            raise ValueError("sequences expected a list of iterables, got {}".format(seq))
    if max_len is None:
        max_len = np.max(lengths)  # 确定最大序列长度

    # input_shape = np.asarray(sequences[0]).shape[1:]   # ???
    # padded_shape = (n_samples, max_len) + input_shape
    # padded = np.full(padded_shape, fill_value=fill_value)
    
    # import pdb; pdb.set_trace()

    for i, seq in enumerate(sequences):
        if padding == 'pre':
            if max_len > len(seq):
                sequences[i] = [fill_value]*(max_len - len(seq)) + sequences[i]
            else:
                sequences[i] = sequences[i][:max_len]
        elif padding == 'post':
            if max_len > len(seq):
                sequences[i] = sequences[i] + [fill_value]*(max_len - len(seq))
            else:
                sequences[i] = sequences[i][:max_len]
        else:
            raise ValueError("padding expected 'pre' or 'post', got {}".format(truncating))

    return sequences


""" 对label进行编码 （samples, classes）"""
def to_categorical(labels, n_classes=None):
    labels = np.array(labels, dtype='int').reshape(-1)   

    n_samples = labels.shape[0]
    if not n_classes:
        n_classes = np.max(labels) + 1

    categorical = np.zeros((n_samples, n_classes))
    categorical[np.arange(n_samples), labels] = 1

    return categorical


""" 对miran,mrna,y进行one-hot编码 """
def preprocess_data(x_query_seqs, x_target_seqs, y=None, cts_size=None, pre_padding=False):
    if cts_size is not None:
        max_len = cts_size
    else:
        max_len = max(len(max(x_query_seqs, key=len)), len(max(x_target_seqs, key=len)))
    
    # 将mirna,mran转为整型
    # x_mirna = sequence_to_int(x_query_seqs, max_len)
    # x_mrna = sequence_to_int(x_target_seqs, max_len)
 
    # padding, max取max(mirna,mrna)
    # if pre_padding:
    x_query_seqs = [list(i) for i in x_query_seqs]
    x_target_seqs = [list(i) for i in x_target_seqs]
    x_mirna = pad_sequences(x_query_seqs, max_len, padding='pre')
    x_mrna = pad_sequences(x_target_seqs, max_len, padding='pre')
    
    # 对mrna,mirna进行one-hot编码
    # x_mirna_embd = one_hot_enc(x_mirna)
    # x_mrna_embd = one_hot_enc(x_mrna)
    if y is not None:
        y_embd = to_categorical(y, np.unique(y).size)  

        return x_mirna, x_mrna, y_embd
    else:
        return x_mirna, x_mrna


In [71]:
""" 构造dataset(字典) """
def make_input_pair(mirna_fasta_file, mrna_fasta_file, ground_truth_file, cts_size=30, seed_match='offset-9-mer-m7', header=True, train=True):
    mirna_ids, mirna_seqs, mrna_ids, mrna_seqs = read_fasta(mirna_fasta_file, mrna_fasta_file)           # mirna_ids, mirna_seqs, mrna_ids, mrna_seqs
    query_ids, target_ids, labels = read_ground_truth(ground_truth_file, header=header, train=train)     # query_ids, target_ids, labels

    dataset = {
        'mirna_fasta_file': mirna_fasta_file,
        'mrna_fasta_file': mrna_fasta_file,
        'ground_truth_file': ground_truth_file,
        'query_ids': [],
        'query_seqs': [],
        'target_ids': [],
        'target_seqs': [],
        'target_locs': [],
        'labels': []
    }

    for i in range(len(query_ids)):
        try:
            j = mirna_ids.index(query_ids[i])    # j:mirna index
        except ValueError:       
            continue
        try:
            k = mrna_ids.index(target_ids[i])    # k:mrna index
        except ValueError:
            continue

        query_seqs, target_seqs, locations = make_pair(mirna_seqs[j], mrna_seqs[k], cts_size=cts_size, seed_match=seed_match)

        n_pairs = len(locations)    # 产生的 mirna-mrna匹配对数
        if n_pairs > 0:
            queries = [query_ids[i] for n in range(n_pairs)]
            dataset['query_ids'].extend(queries)
            dataset['query_seqs'].extend(query_seqs)

            targets = [target_ids[i] for n in range(n_pairs)]
            dataset['target_ids'].extend(targets)
            dataset['target_seqs'].extend(target_seqs)
            dataset['target_locs'].extend(locations)

            dataset['labels'].extend([[labels[i]] for p in range(n_pairs)])

    return dataset

""" 无labels """
def make_brute_force_pair(mirna_fasta_file, mrna_fasta_file, cts_size=30, seed_match='offset-9-mer-m7'):
    mirna_ids, mirna_seqs, mrna_ids, mrna_seqs = read_fasta(mirna_fasta_file, mrna_fasta_file)

    dataset = {
        'query_ids': [],
        'query_seqs': [],
        'target_ids': [],
        'target_seqs': [],
        'target_locs': []
    }

    for i in range(len(mirna_ids)):
        for j in range(len(mrna_ids)):
            query_seqs, target_seqs, positions = make_pair(mirna_seqs[i], mrna_seqs[j], cts_size, seed_match)

            n_pairs = len(positions)
            if n_pairs > 0:
                query_ids = [mirna_ids[i] for k in range(n_pairs)]
                dataset['query_ids'].extend(query_ids)
                dataset['query_seqs'].extend(query_seqs)

                target_ids = [mrna_ids[j] for k in range(n_pairs)]
                dataset['target_ids'].extend(target_ids)
                dataset['target_seqs'].extend(target_seqs)
                dataset['target_locs'].extend(positions)

    return dataset

In [72]:
""" 生成负样本 """
def get_negative_pair(mirna_fasta_file, mrna_fasta_file, ground_truth_file=None, cts_size=30, seed_match='offset-9-mer-m7', header=False, predict_mode=True):
    mirna_ids, mirna_seqs, mrna_ids, mrna_seqs = read_fasta(mirna_fasta_file, mrna_fasta_file)   # 读取mrna,mirna文件
  
    dataset = {
        'query_ids': [],
        'target_ids': [],
        'predicts': []
    }

    if ground_truth_file is not None:
        query_ids, target_ids, labels = read_ground_truth(ground_truth_file, header=header)

        for i in range(len(query_ids)):
            try:
                j = mirna_ids.index(query_ids[i])
            except ValueError:
                continue
            try:
                k = mrna_ids.index(target_ids[i])
            except ValueError:
                continue

            query_seqs, target_seqs, locations = make_pair(mirna_seqs[j], mrna_seqs[k], cts_size=cts_size, seed_match=seed_match)

            n_pairs = len(locations)
            if (n_pairs == 0) and (predict_mode is True):
                dataset['query_ids'].append(query_ids[i])
                dataset['target_ids'].append(target_ids[i])
                dataset['predicts'].append(0)
            elif (n_pairs == 0) and (predict_mode is False):
                dataset['query_ids'].append(query_ids[i])
                dataset['target_ids'].append(target_ids[i])
                dataset['predicts'].append(labels[i])
    else:
        for i in range(len(mirna_ids)):
            for j in range(len(mrna_ids)):
                query_seqs, target_seqs, locations = make_pair(mirna_seqs[i], mrna_seqs[j], cts_size=cts_size, seed_match=seed_match)

                n_pairs = len(locations)
                if n_pairs == 0:
                    dataset['query_ids'].append(mirna_ids[i])
                    dataset['target_ids'].append(mrna_ids[j])
                    dataset['predicts'].append(0)

    dataset['target_locs'] = [-1 for i in range(len(dataset['query_ids']))]
    dataset['probabilities'] = [0.0 for i in range(len(dataset['query_ids']))]

    return dataset


""" 结果统计 """
def postprocess_result(dataset, probabilities, predicts, predict_mode=True, output_file=None, cts_size=30, seed_match='offset-9-mer-m7', level='site'):
    neg_pairs = get_negative_pair(dataset['mirna_fasta_file'], dataset['mrna_fasta_file'], dataset['ground_truth_file'], cts_size=cts_size, seed_match=seed_match, predict_mode=predict_mode)   # 负样本对

    # dataset:正样本集  neg_pair:负样本集
    query_ids = np.append(dataset['query_ids'], neg_pairs['query_ids'])         
    target_ids = np.append(dataset['target_ids'], neg_pairs['target_ids'])
    target_locs = np.append(dataset['target_locs'], neg_pairs['target_locs'])
    probabilities = np.append(probabilities, neg_pairs['probabilities'])        # probabilities：正样本训练经模型得到的prob  neg_pairs['probabilities]：构造负样本时设定好的prob=0.0
    predicts = np.append(predicts, neg_pairs['predicts'])                       # predicts:正样本训练经模型得到的预测   neg_pairs['predicts']：构造负样本时设定好的prob=-1

    # output format: [QUERY, TARGET, LOCATION, PROBABILITY]
    records = pd.DataFrame(columns=['MIRNA_ID', 'MRNA_ID', 'LOCATION', 'PROBABILITY'])
    records['MIRNA_ID'] = query_ids
    records['MRNA_ID'] = target_ids
    records['LOCATION'] = np.array(["{},{}".format(max(1, l-cts_size+1), l) if l != -1 else "-1,-1" for l in target_locs])
    records['PROBABILITY'] = probabilities
    if predict_mode is True:                  # 是否在预测
        records['PREDICT'] = predicts
    else:
        records['LABEL'] = predicts

    # site level
    records = records.sort_values(by=['PROBABILITY', 'MIRNA_ID', 'MRNA_ID'], ascending=[False, True, True])  # sort_values()函数原理类似于SQL中的order by，将数据集依照某个字段中的数据进行排序； ascending	是否按指定列的数组升序排列，默认为True，即升序排列
    # gene level
    unique_records = records.sort_values(by=['PROBABILITY', 'MIRNA_ID', 'MRNA_ID'], ascending=[False, True, True]).drop_duplicates(subset=['MIRNA_ID', 'MRNA_ID'], keep='first')

    if level == 'site':
        if output_file is not None:
            records.to_csv(output_file, index=False, sep='\t') 
        return records

    elif level == 'gene':
        if output_file is not None:
            unique_records.to_csv(output_file, index=False, sep='\t')
        return unique_records

    else:
        raise ValueError("level expected 'site' or 'gene', got '{}'".format(mode))

# 2 Datasets

In [73]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, mirna_fasta_file, mrna_fasta_file, ground_truth_file, cts_size=30, seed_match='offset-9-mer-m7', header=True, train=True):
        self.dataset = make_input_pair(mirna_fasta_file, mrna_fasta_file, ground_truth_file, cts_size=cts_size, seed_match=seed_match, header=header, train=train)  # 'query_ids': [], 'query_seqs': [], 'target_ids': [], 'target_seqs': [], 'target_locs': [], 'labels': []
        self.mirna, self.mrna = preprocess_data(self.dataset['query_seqs'], self.dataset['target_seqs'])   # x_mirna_embd, x_mrna_embd, y_embd
        self.labels = np.asarray(self.dataset['labels']).reshape(-1,)
        
        # self.mirna = self.mirna.transpose((0, 2, 1))
        # self.mrna = self.mrna.transpose((0, 2, 1))
        
    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
            
        mirna = self.mirna[index]
        batch_labels, batch_strs, batch_tokens = batch_converter([('1', ''.join(mirna))])
        mirna = batch_tokens[0][1:-1]
        mrna = self.mrna[index]
        batch_labels, batch_strs, batch_tokens = batch_converter([('1', ''.join(mrna))])
        mrna = batch_tokens[0][1:-1]
        label = self.labels[index]
        
        return (mirna, mrna), label
        
    def __len__(self):
        return len(self.labels)


class TrainDataset(torch.utils.data.Dataset):
    def __init__(self, ground_truth_file, cts_size=30):
        self.records = pd.read_csv(ground_truth_file, header=0, sep='\t')
        mirna_seqs = self.records['MIRNA_SEQ'].values.tolist()
        mrna_seqs = self.records['MRNA_SEQ'].values.tolist()
        self.mirna, self.mrna = preprocess_data(mirna_seqs, mrna_seqs, cts_size=cts_size)
        self.labels = self.records['LABEL'].values.astype(int)
        
        # self.mirna = self.mirna.transpose((0, 2, 1))
        # self.mrna = self.mrna.transpose((0, 2, 1))
        # batch_labels, batch_strs, batch_tokens = batch_converter(data)
    
    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        
        mirna = self.mirna[index]
        mrna = self.mrna[index]
        batch_labels, batch_strs, batch_tokens = batch_converter([('1', ''.join(mirna))])
        mirna = batch_tokens[0][1:-1]
        batch_labels, batch_strs, batch_tokens = batch_converter([('1', ''.join(mrna))])
        mrna = batch_tokens[0][1:-1]
        label = self.labels[index]
        
        return (mirna, mrna), label
        
    def __len__(self):
        return len(self.labels)

# 3 Model

## RNA-FM Backbone

In [74]:
mkdir -p /home/featurize/.cache/torch/hub/checkpoints/

In [75]:
cp /home/featurize/work/TargetFM/RNA-FM_pretrained.pth /home/featurize/.cache/torch/hub/checkpoints/

In [76]:
import torch
import fm

# Load RNA-FM model
fm_model, alphabet = fm.pretrained.rna_fm_t12()
batch_converter = alphabet.get_batch_converter()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   
fm_model = fm_model.to(device)
fm_model.eval()  # disables dropout for deterministic results

RNABertModel(
  (embed_tokens): Embedding(25, 640, padding_idx=1)
  (layers): ModuleList(
    (0): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=640, out_features=640, bias=True)
        (v_proj): Linear(in_features=640, out_features=640, bias=True)
        (q_proj): Linear(in_features=640, out_features=640, bias=True)
        (out_proj): Linear(in_features=640, out_features=640, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=640, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=640, bias=True)
      (final_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
    )
    (1): TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=640, out_features=640, bias=True)
        (v_proj): Linear(in_features=640, out_features=640, bias=True)
        (q_proj): Linear(in_features=6

In [77]:
def get_embed(batch_tokens):
    # batch_tokens =batch_tokens.squeeze(0)
    with torch.no_grad():
        batch_tokens = batch_tokens.to(device)
        results = fm_model(batch_tokens, repr_layers=[12])
    token_embeddings = results["representations"][12]
    return token_embeddings

In [78]:
""" 网络超参设置, filters/kernel """
class HyperParam:
    def __init__(self, filters=None, kernels=None, model_json=None):
        self.dictionary = dict()
        self.name_postfix = str()

        if (filters is not None) and (kernels is not None) and (model_json is None):
            for i, (f, k) in enumerate(zip(filters, kernels)):     # get the elements of multiple lists and indexes https://note.nkmk.me/en/python-for-enumerate-zip/
                setattr(self, 'f{}'.format(i+1), f)
                setattr(self, 'k{}'.format(i+1), k)
                self.dictionary.update({'f{}'.format(i+1): f, 'k{}'.format(i+1): k})
            self.len = i+1
                
            for key, value in self.dictionary.items():
                self.name_postfix = "{}_{}-{}".format(self.name_postfix, key, value)
        elif model_json is not None:
            self.dictionary = json.loads(model_json
                                        )
            for i, (key, value) in enumerate(self.dictionary.items()):
                setattr(self, key, value)
                self.name_postfix = "{}_{}-{}".format(self.name_postfix, key, value)
            self.len = (i+1)//2
    
    def __len__(self):
        return self.len


In [79]:
class TargetFM1(nn.Module):
    def __init__(self, hparams=None, hidden_units=30, input_shape=(2, 30), name_prefix="model"):
        super(TargetFM1, self).__init__()

        if hparams is None:
            filters, kernels = [32, 16, 64, 16], [3, 3, 3, 3]
            hparams = HyperParam(filters, kernels)
        self.name = "{}{}".format(name_prefix, hparams.name_postfix)
        self.fc_mi = nn.Linear(640, 32)
        self.fc_mr = nn.Linear(640, 32)
        
        if (isinstance(hparams, HyperParam)) and (len(hparams) == 4):
            self.embd1 = nn.Conv1d(4, hparams.f1, kernel_size=hparams.k1, padding=((hparams.k1 - 1) // 2))
            
            self.conv2 = nn.Conv1d(hparams.f1*2, hparams.f2, kernel_size=hparams.k2)
            self.conv3 = nn.Conv1d(hparams.f2, hparams.f3, kernel_size=hparams.k3)
            self.conv4 = nn.Conv1d(hparams.f3, hparams.f4, kernel_size=hparams.k4)
            
            """ out_features = ((in_length - kernel_size + (2 * padding)) / stride + 1) * out_channels """
            # flat_features = self.forward(torch.randint(1, 5, input_shape).to(device), torch.randint(1, 5, input_shape).to(device), flat_check=False)
            self.fc1 = nn.Linear(384, hidden_units)
            self.fc2 = nn.Linear(hidden_units, 2)
        else:
            raise ValueError("not enough hyperparameters")
    
    def forward(self, x_mirna, x_mrna, flat_check=False):
        mi_out = get_embed(x_mirna)
        mi_out = self.fc_mi(mi_out).transpose(1,2)
        mr_out = get_embed(x_mrna)
        mr_out = self.fc_mr(mr_out).transpose(1,2)
        # import pdb;pdb.set_trace()
        h_mirna = F.relu(mi_out)                   # torch.Size([32, 32, 30])
        # print(h_mirna.shape)
        h_mrna = F.relu(mr_out)                    # torch.Size([32, 32, 30])
        # print(h_mrna.shape)
        h = torch.cat((h_mirna, h_mrna), dim=1)    # torch.Size([32, 64, 30])
        # print(h.shape)
        h = F.relu(self.conv2(h))                  # torch.Size([32, 16, 28])
        # print(h.shape)
        h = F.relu(self.conv3(h))                  # torch.Size([32, 64, 26])
        # print(h.shape)
        h = F.relu(self.conv4(h))                  # torch.Size([32, 16, 24])
        # print(h.shape)
        h = h.view(h.size(0), -1)                  # torch.Size([32, 384])
        # print(h.shape)
        if flat_check:
            return h.size(1)
        h = self.fc1(h)                            # torch.Size([32, 30])
        # print(h.shape)
        # y = F.softmax(self.fc2(h), dim=1)
        y = self.fc2(h)                            # torch.Size([32, 2])
        # print(y.shape)
        
        return y

    def size(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# 4 Train & Inference

## 4.1 Train

In [80]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
from tqdm.auto import tqdm
bar_format = '{desc} |{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}{postfix}]'
from sklearn.utils.class_weight import compute_class_weight
from datetime import datetime

from torch.utils.data import DataLoader

In [81]:
def train_model(mirna_fasta_file, mrna_fasta_file, train_file, model=None, cts_size=30, seed_match='offset-9-mer-m7', level='gene', batch_size=32, epochs=10, save_file=None, device='cpu'):
    """
    if not isinstance(model, deepTarget):
        raise ValueError("'model' expected <nn.Module 'deepTarget'>, got {}".format(type(model)))
    
    print("\n[TRAIN] {}".format(model.name))
    """
    
    if train_file.split('/')[-1] == 'train_set.csv':
        train_set = TrainDataset(train_file)
    else:
        # 实例化
        train_set = Dataset(mirna_fasta_file, mrna_fasta_file, train_file, seed_match=seed_match, header=True, train=True)  # return (mirna, mrna), label
    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
    
    class_weight = torch.Tensor(compute_class_weight('balanced', classes=np.unique(train_set.labels), y=train_set.labels)).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weight)
    optimizer = optim.Adam(model.parameters())
    
    model = model.to(device)
    for epoch in range(epochs):
        epoch_loss, corrects = 0, 0

        with tqdm(train_loader, desc="Epoch {}/{}".format(epoch+1, epochs), bar_format=bar_format) as tqdm_loader:
            for i, ((mirna, mrna), label) in enumerate(tqdm_loader):
                
                mirna, mrna, label = mirna.to(device, dtype=torch.int64), mrna.to(device, dtype=torch.int64), label.to(device)
                
                outputs = model(mirna, mrna)
                loss = criterion(outputs, label)
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                epoch_loss += loss.item() * outputs.size(0)
                corrects += (torch.max(outputs, 1)[1] == label).sum().item()
                
                if (i+1) == len(train_loader):
                    tqdm_loader.set_postfix(dict(loss=(epoch_loss/len(train_set)), acc=(corrects/len(train_set))))
                else:
                    tqdm_loader.set_postfix(loss=loss.item())
    
    if save_file is None:
        time = datetime.now()
        save_file = "{}.pt".format(time.strftime('%Y%m%d_%H%M%S_weights'))
    torch.save(model.state_dict(), save_file)

In [25]:
start_time = datetime.now()
print("\n[START] {}".format(start_time.strftime('%Y-%m-%d @ %H:%M:%S')))
        
mirna_fasta_file = 'Data/mirna.fasta'
mrna_fasta_file  = 'Data/mrna.fasta'
seed_match       = 'offset-9-mer-m7'
level            = 'gene'
train_file       = 'Data/Train/train_set.csv'
weight_file      = 'weights.pt'
batch_size       = 32
epochs           = 10
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')        

model1 = TargetFM1()
# model2 = FMTarget()

torch.save(model1.state_dict(), weight_file)
train_model(mirna_fasta_file, mrna_fasta_file, train_file,
             model=model1,
             seed_match=seed_match, level=level,
             batch_size=batch_size, epochs=epochs,
             save_file=weight_file, device=device)
        
finish_time = datetime.now()
print("\n[FINISH] {} (user time: {})\n".format(finish_time.now().strftime('%Y-%m-%d @ %H:%M:%S'), (finish_time - start_time)))


[START] 2022-09-16 @ 01:16:12


Epoch 1/10 |          | 0/2045 [00:00<?]

Epoch 2/10 |          | 0/2045 [00:00<?]

Epoch 3/10 |          | 0/2045 [00:00<?]

Epoch 4/10 |          | 0/2045 [00:00<?]

Epoch 5/10 |          | 0/2045 [00:00<?]

Epoch 6/10 |          | 0/2045 [00:00<?]

Epoch 7/10 |          | 0/2045 [00:00<?]

Epoch 8/10 |          | 0/2045 [00:00<?]

Epoch 9/10 |          | 0/2045 [00:00<?]

Epoch 10/10 |          | 0/2045 [00:00<?]


[FINISH] 2022-09-16 @ 01:32:02 (user time: 0:15:49.759889)



## 4.2 Inference

In [82]:
def predict_result(mirna_fasta_file, mrna_fasta_file, query_file, model=None, weight_file=None, seed_match='offset-9-mer-m7', level='gene', batch_size=32, output_file=None, device='cpu'):
    """
    if not isinstance(model, deepTarget):
        raise ValueError("'model' expected <nn.Module 'deepTarget'>, got {}".format(type(model)))
    """
    
    if not weight_file.endswith('.pt'):
        raise ValueError("'weight_file' expected '*.pt', got {}".format(weight_file))
    
    model.load_state_dict(torch.load(weight_file))
    
    test_set = Dataset(mirna_fasta_file, mrna_fasta_file, query_file, seed_match=seed_match, header=True, train=True)    
    test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)
    
    y_probs = []
    y_predicts = []
    y_truth = []

    model = model.to(device)
    with torch.no_grad():
        model.eval()
        
        with tqdm(test_loader, bar_format=bar_format) as tqdm_loader:
            for i, ((mirna, mrna), label) in enumerate(tqdm_loader):
                mirna, mrna, label = mirna.to(device, dtype=torch.int64), mrna.to(device, dtype=torch.int64), label.to(device)
                
                outputs = model(mirna, mrna)
                _, predicts = torch.max(outputs.data, 1)
                probabilities = F.softmax(outputs, dim=1)
                
                y_probs.extend(probabilities.cpu().numpy()[:, 1])
                y_predicts.extend(predicts.cpu().numpy())
                y_truth.extend(label.cpu().numpy())

                global correct
                # print(predicts.cpu().numpy())
                # print(label.cpu().numpy())
                correct += (predicts == label).sum().item()
                
        acc = float(correct / len(test_set)) * 100
        print(len(test_set))
        print("acc:", acc, "%")

        if output_file is None:
            time = datetime.now()
            output_file = "{}.csv".format(time.strftime('%Y%m%d_%H%M%S_results'))
        results = postprocess_result(test_set.dataset, y_probs, y_predicts,
                                     seed_match=seed_match, level=level, output_file=output_file)
        
        # print(results)
    return acc

In [83]:
start_time = datetime.now()
print("\n[START] {}".format(start_time.strftime('%Y-%m-%d @ %H:%M:%S')))

mirna_fasta_file = 'Data/mirna.fasta'
mrna_fasta_file  = 'Data/mrna.fasta'
query_file       = 'Data/Test/test_split_'
model = TargetFM1()
weight_file      = 'weights.pt'
seed_match = 'offset-9-mer-m7'
level = 'gene'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  

sum = 0

for i in range(10):
    correct = 0
    query_file1 = query_file + str(i) + '.csv'
    acc = predict_result(mirna_fasta_file, mrna_fasta_file, query_file1,
                             model=model, weight_file=weight_file,
                             seed_match=seed_match, level=level,
                             output_file=None, device=device)
    sum += acc

print("avg = ", sum/10.0, "%")
finish_time = datetime.now()
print("\n[FINISH] {} (user time: {})\n".format(finish_time.now().strftime('%Y-%m-%d @ %H:%M:%S'), (finish_time - start_time)))


[START] 2022-09-16 @ 01:59:03


 |          | 0/156 [00:00<?]

4962
acc: 79.32285368802901 %


 |          | 0/171 [00:00<?]

5457
acc: 84.91845336265347 %


 |          | 0/162 [00:00<?]

5153
acc: 82.8449446924122 %


 |          | 0/155 [00:00<?]

4956
acc: 81.79983857949959 %


 |          | 0/160 [00:00<?]

5101
acc: 81.96432072142717 %


 |          | 0/169 [00:00<?]

5388
acc: 79.28730512249443 %


 |          | 0/160 [00:00<?]

5102
acc: 79.08663269306155 %


 |          | 0/162 [00:00<?]

5167
acc: 82.11728275595122 %


 |          | 0/156 [00:00<?]

4986
acc: 82.49097472924187 %


 |          | 0/162 [00:00<?]

5172
acc: 83.33333333333334 %
avg =  81.71659396781038 %

[FINISH] 2022-09-16 @ 02:01:05 (user time: 0:02:01.979552)



# Py3.9


In [26]:
def metric(TP, FP, TN, FN):
    accuracy = (TP+TN) / (TP+TN+FP+FN)
    sensitivity = TP / (TP+FN)
    specificity = TN / (TN+FP)
    f_measure = 2*TP / (2*TP+FP+FN)
    PPV = TP / (TP+FP)
    NPV = TN / (TN+FN)

    print("accuracy:%.2f" %(accuracy))
    print("sensitivity:%.2f" %(sensitivity))
    print("specifity:%.2f" %(specificity))
    print("F-measure:%.2f" %(F-measure))
    print("PPV:%.2f" %(PPV))
    print("NPV:%.2f" %(NPV))

In [30]:
TP=0
FP=0
TN=0
FN=0

In [42]:
def predict_result2(mirna_fasta_file, mrna_fasta_file, query_file, model=None, weight_file=None, seed_match='offset-9-mer-m7', level='gene', batch_size=32, output_file=None, device='cpu'):
    """
    if not isinstance(model, deepTarget):
        raise ValueError("'model' expected <nn.Module 'deepTarget'>, got {}".format(type(model)))
    """
    
    if not weight_file.endswith('.pt'):
        raise ValueError("'weight_file' expected '*.pt', got {}".format(weight_file))
    
    global TP
    global FP
    global TN
    global FN
    global correct
    
    model.load_state_dict(torch.load(weight_file))
    
    test_set = Dataset(mirna_fasta_file, mrna_fasta_file, query_file, seed_match=seed_match, header=True, train=False)    
    test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)
    
    y_probs = []
    y_predicts = []
    y_truth = []

    model = model.to(device)
    with torch.no_grad():
        model.eval()
        
        with tqdm(test_loader, bar_format=bar_format) as tqdm_loader:
            for i, ((mirna, mrna), label) in enumerate(tqdm_loader):
                mirna, mrna, label = mirna.to(device, dtype=torch.int64), mrna.to(device, dtype=torch.int64), label.to(device)
                
                outputs = model(mirna, mrna)
                _, predicts = torch.max(outputs.data, 1)
                probabilities = F.softmax(outputs, dim=1)
                
                y_probs.extend(probabilities.cpu().numpy()[:, 1])
                y_predicts.extend(predicts.cpu().numpy())
                y_truth.extend(label.cpu().numpy())
                
                correct += (predicts == label).sum().item()

                TP += (label==1 and predicts==1).sum().item()
                FP += (label==1 and predicts==0).sum().item()
                TN += (label==0 and predicts==0).sum().item()
                FN += (label==0 and predicts==1).sum().item()
                print(TP,FP,TN,FN)
                '''
                if(y_truth==1 and y_predicts==1):    TP+=1
                elif(y_truth==1 and y_predicts==0):  FP+=1
                elif(y_truth==0 and y_predicts==0):  TN+=1
                elif(y_truth==0 and y_predicts==1):  FN+=1
                '''
                
            metric(TP, FP, TN, FN)
        
        if output_file is None:
            time = datetime.now()
            output_file = "{}.csv".format(time.strftime('%Y%m%d_%H%M%S_results'))
        results = postprocess_result(test_set.dataset, y_probs, y_predicts,
                                     seed_match=seed_match, level=level, output_file=output_file)
        
        print(results)

In [43]:
start_time = datetime.now()
print("\n[START] {}".format(start_time.strftime('%Y-%m-%d @ %H:%M:%S')))

mirna_fasta_file = 'Data/mirna.fasta'
mrna_fasta_file  = 'Data/mrna.fasta'
query_file       = 'Data/Test/test_split_0.csv'
model = deepTarget()
weight_file      = 'weights.pt'
seed_match = 'offset-9-mer-m7'
level = 'gene'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  

results = predict_result2(mirna_fasta_file, mrna_fasta_file, query_file,
                         model=model, weight_file=weight_file,
                         seed_match=seed_match, level=level,
                         output_file=None, device=device)
        
finish_time = datetime.now()
print("\n[FINISH] {} (user time: {})\n".format(finish_time.now().strftime('%Y-%m-%d @ %H:%M:%S'), (finish_time - start_time)))


[START] 2022-09-14 @ 12:02:33


 |          | 0/156 [00:00<?]

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [47]:
class FMTarget1(nn.Module):
    def __init__(self, hparams=None, hidden_units=30, input_shape=(2, 30), name_prefix="model"):
        super(FMTarget1, self).__init__()

        if hparams is None:
            filters, kernels = [32, 16, 64, 16], [3, 3, 3, 3]
            hparams = HyperParam(filters, kernels)
        self.name = "{}{}".format(name_prefix, hparams.name_postfix)
        self.fc_mi = nn.Linear(640, 32)
        self.fc_mr = nn.Linear(640, 32)
        
        if (isinstance(hparams, HyperParam)) and (len(hparams) == 4):
            self.embd1 = nn.Conv1d(4, hparams.f1, kernel_size=hparams.k1, padding=((hparams.k1 - 1) // 2))
            
            self.conv2 = nn.Conv1d(hparams.f1*2, hparams.f2, kernel_size=hparams.k2)
            self.conv3 = nn.Conv1d(hparams.f2, hparams.f3, kernel_size=hparams.k3)
            self.conv4 = nn.Conv1d(hparams.f3, hparams.f4, kernel_size=hparams.k4)
            
            """ out_features = ((in_length - kernel_size + (2 * padding)) / stride + 1) * out_channels """
            # flat_features = self.forward(torch.randint(1, 5, input_shape).to(device), torch.randint(1, 5, input_shape).to(device), flat_check=False)
            self.fc1 = nn.Linear(1920, hidden_units)
            self.fc2 = nn.Linear(hidden_units, 2)
        else:
            raise ValueError("not enough hyperparameters")
    
    def forward(self, x_mirna, x_mrna, flat_check=False):
        mi_out = get_embed(x_mirna)
        mi_out = self.fc_mi(mi_out).transpose(1,2)
        mr_out = get_embed(x_mrna)
        mr_out = self.fc_mr(mr_out).transpose(1,2)
        # import pdb;pdb.set_trace()
        h_mirna = F.relu(mi_out)                   # torch.Size([32, 32, 30])
        # print(h_mirna.shape)
        h_mrna = F.relu(mr_out)                    # torch.Size([32, 32, 30])
        # print(h_mrna.shape)
        h = torch.cat((h_mirna, h_mrna), dim=1)    # torch.Size([32, 64, 30])
        # print(h.shape)

        h = h.view(h.size(0), -1)                  # torch.Size([32, 1920])
        # print(h.shape)
        if flat_check:
            return h.size(1)
        h = self.fc1(h)                            # torch.Size([32, 30])
        # print(h.shape)
        # y = F.softmax(self.fc2(h), dim=1)
        y = self.fc2(h)                            # torch.Size([32, 2])
        # print(y.shape)
        
        return y

    def size(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [48]:
class FMTarget(nn.Module):
    def __init__(self, hparams=None, hidden_units=30, input_shape=(2, 30), name_prefix="model"):
        super(FMTarget, self).__init__()

        self.fc_mi = nn.Linear(640, 32)
        self.fc_mr = nn.Linear(640, 32)
        """ out_features = ((in_length - kernel_size + (2 * padding)) / stride + 1) * out_channels """
        # flat_features = self.forward(torch.randint(1, 5, input_shape).to(device), torch.randint(1, 5, input_shape).to(device), flat_check=False)
        self.fc1 = nn.Linear(1920, hidden_units)
        self.fc2 = nn.Linear(hidden_units, 2)

    
    def forward(self, x_mirna, x_mrna, flat_check=False):
        mi_out = get_embed(x_mirna)
        mi_out = self.fc_mi(mi_out).transpose(1,2)
        mr_out = get_embed(x_mrna)
        mr_out = self.fc_mr(mr_out).transpose(1,2)
        # import pdb;pdb.set_trace()
        h_mirna = F.relu(mi_out)                   # torch.Size([32, 32, 30])
        # print(h_mirna.shape)
        h_mrna = F.relu(mr_out)                    # torch.Size([32, 32, 30])
        # print(h_mrna.shape)
        h = torch.cat((h_mirna, h_mrna), dim=1)    # torch.Size([32, 64, 30])
        # print(h.shape)

        h = h.view(h.size(0), -1)                  # torch.Size([32, 1920])
        # print(h.shape)
        if flat_check:
            return h.size(1)
        h = self.fc1(h)                            # torch.Size([32, 30])
        # print(h.shape)
        # y = F.softmax(self.fc2(h), dim=1)
        y = self.fc2(h)                            # torch.Size([32, 2])
        # print(y.shape)
        
        return y

    def size(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [None]:
  """
  1. ints_enc[:, 1:5]
  2. site_sequence[::-1]  
  3. reshape(-1)拉成一行； 
  4. reshape(-1, 1)拉成1列 reshape(-1, 2)拉成2列
  5. np.arange(3): array[0,1,2]
  6. categorical[np.arange(n_samples), labels] = 1 :在categorical的前n_samples行按照labels进行索引
  7. np.unique(y)：去除数组中的重复数字，并进行排序之后输出
  8. records.sort_values(by=['PROBABILITY', 'MIRNA_ID', 'MRNA_ID'], ascending=[False, True, True])  # sort_values()函数原理类似于SQL中的order by，将数据集依照某个字段中的数据进行排序； ascending	是否按指定列的数组升序排列，默认为True，即升序排列
  9. records.to_csv(output_file, index=False, sep='\t')  # dt.to_csv('C:/Users/think/Desktop/Result.csv',sep='?')#使用?分隔需要保存的数据，如果不写，默认是,
  10. zip() 函数用于将可迭代的对象作为参数，将对象中对应的元素打包成一个个元组，然后返回由这些元组组成的列表
  11. for i, (f, k) in enumerate(zip(filters, kernels)):     # get the elements of multiple lists and indexes https://note.nkmk.me/en/python-for-enumerate-zip/
  12. getattr(a, 'bar') => 1 # 获取属性 bar 值    /  setattr(a, 'bar', 5) a.bar => 5 # 设置属性 bar 值
  """

In [None]:
!python --version

Python 3.7.13


In [None]:
#install python 3.9
!sudo apt-get update -y
!sudo apt-get install python3.9

#change alternatives
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.7 1
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 2

#check python version
!python --version
#3.9.6

In [None]:
!sudo update-alternatives --config python3

In [None]:
!sudo apt-get install python3.9-distutils && wget https://bootstrap.pypa.io/get-pip.py && python get-pip.py

In [None]:
!pip install --upgrade pip

In [None]:
!python -m pip install biopython

In [None]:
!pip install rna-fm