# Deep learning model to identify target mRNA of microRNA sequences

After reviewing the deep learning models currently being developed for genomics, I realized there is still a long way to go before the full range of machine learning techniques is applied to genomics to its full potential. In particular, I became fascinated by the deepTarget model, which was recently proposed as a way to identify microRNA targets with 96% accuracy, so I decided to focus my project on this field. The repository of this work is available on GitHub at the following link: https://github.com/ailab-seoultech/deepTarget.

MicroRNAs (miRNAs) are small non-coding RNA molecules of about 22 nucleotides that are known to regulate more than 60% of the protein-coding genes of humans and other mammals at the RNA level. Because miRNAs control the function of their target messenger RNAs (mRNAs) by regulating the targets' expression, investigating miRNAs is important for understanding various biological processes, including diseases. Numerous computational tools have been proposed to predict the targets of given miRNAs.

## Problem Statement
Two types of computational problems about miRNAs therefore arise naturally in bioinformatics: miRNA host identification (i.e., locating the genes that encode pre-miRNAs) and miRNA target prediction (i.e., finding the mRNA targets of given miRNAs). This project focuses on the target prediction problem, aiming to build a model that recognizes whether or not a given mRNA is a target of a given miRNA. The output is a binary classification of miRNA-mRNA pairs as matching (the mRNA is a target of the miRNA) or not matching.


## Dataset
I will use the dataset, which includes experimental negative data, released with the 2020 paper ‘Deep Learning-Based microRNA Target Prediction Using Experimental Negative Data’, whose authors used it to train a deepTarget model.

The dataset is available here: https://github.com/ailab-seoultech/deepTarget 

The dataset is already split into ‘train’ and ‘test’ sets with 65,425 and 10,930 data points respectively. The training set contains labelled miRNA-mRNA sequence pairs with the following columns:
- MIRNA_ID: identifier of the miRNA sequence (note this is not a unique identifier of the miRNA-mRNA pair because a miRNA can be listed multiple times, paired with multiple mRNA targets)
- ENSEMBL_GENE and GENE_SYMBOL: Ensembl ID and symbol of the gene encoding the mRNA. These enable gene-level predictions of miRNA-mRNA pairs
- MIRNA_SEQ: string with the nucleotide sequence of the miRNA (e.g. ‘UGAGGU…’)
- MRNA_SEQ: string with the nucleotide sequence of the mRNA site (e.g. ‘GCCTGC…’)
- LABEL: 1 if the mRNA is a target of the miRNA, 0 otherwise

The dataset also includes 2 FASTA files listing all the miRNA and mRNA IDs and their corresponding nucleotide sequences.

The dataset was generated from public data obtained from several sources. Its creators balanced it by randomly sampling positive and negative pairs [3].

## Setup
Let's first install the required software and import the needed libraries.

In [3]:
!pip install biopython
!pip install torch
!pip install regex
# nomkl is distributed through conda only (conda install nomkl); pip cannot install it


In [4]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from tqdm.auto import tqdm
bar_format = '{desc} |{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}{postfix}]'
import sklearn
from sklearn.utils.class_weight import compute_class_weight
from datetime import datetime

import argparse
import json  # needed by HyperParam (json.loads)
import regex
import Bio
from Bio import SeqIO
from Bio.Seq import Seq

## Exploration
In the same directory as this Jupyter notebook, I placed the 'data' folder downloaded from the paper's GitHub repository. It includes the FASTA files, a training set and 10 test sets.
### FASTA files
The .fasta files contain the IDs and sequences of all mRNAs and miRNAs.

In [5]:
mirna_records = list(SeqIO.parse("data/mirna.fasta", 'fasta'))  # parse the file only once
for record in mirna_records[:5]:
    print(record.id)
    print(record.seq)

hsa-miR-4777-5p
UUCUAGAUGAGAGAUAUAUAUA
hsa-miR-3908
GAGCAAUGUAGGUAGACUGUUU
hsa-miR-96-3p
AAUCAUGUGCAGUGCCAAUAUG
hsa-miR-3144-5p
AGGGGACCAAAGAGAUAUAUAG
hsa-miR-6509-3p
UUCCACUGCCACUACCUAAUUU


### Training set
The training data contain miRNA-mRNA pairs labelled according to whether or not they match.
The dataset is balanced: half of the pairs have label 0 and half have label 1.

In [6]:
df=pd.read_csv("data/train/train_set.csv", header=0, sep='\t')
df.head()

| Unnamed: 0 | MIRNA_ID | ENSEMBL_GENE | GENE_SYMBOL | MIRNA_SEQ | MRNA_SEQ | LABEL |
|---|---|---|---|---|---|---|
| 0 | hsa-let-7a-5p | ENSG00000114573 | ATP6V1A | UGAGGUAGUAGGUUGUAUAGUU | GCCTGCTATTGAGGAAAGGTATTCTTCTATACAACTTGTT | 1 |
| 1 | hsa-let-7a-5p | ENSG00000104497 | SNX16 | UGAGGUAGUAGGUUGUAUAGUU | ATTTGTTTAGTTTCCACGTAATCTTTATCTCTACCTAGAT | 1 |
| 2 | hsa-let-7a-5p | ENSG00000141682 | PMAIP1 | UGAGGUAGUAGGUUGUAUAGUU | GCACATTGTATATGATTCGGTTTATACATATTACCTTGTT | 1 |
| 3 | hsa-let-7a-5p | ENSG00000174010 | KLHL15 | UGAGGUAGUAGGUUGUAUAGUU | GAAGTTAGACACCTTTCTGCTAGACAACTTTGTGCCACTC | 1 |
| 4 | hsa-let-7b-5p | ENSG00000130402 | ACTN4 | UGAGGUAGUAGGUUGUGUGGUU | GGCCCTCATCTTCGACAACAAGCACACCAACTATACCATG | 1 |


In [7]:
df.hist(column='LABEL')

array([[<AxesSubplot:title={'center':'LABEL'}>]], dtype=object)

## Pre-processing
Below is a set of functions that will help me read and encode the .fasta files.

In [8]:
def read_fasta(mirna_fasta_file, mrna_fasta_file):
    # read the '.fasta' files and return the ids and sequences of
    # the miRNAs and mRNAs
    mirna_list = list(SeqIO.parse(mirna_fasta_file, 'fasta'))
    mrna_list = list(SeqIO.parse(mrna_fasta_file, 'fasta'))
    
    mirna_ids = []
    mirna_seqs = []
    mrna_ids = []
    mrna_seqs = []
    
    for i in range(len(mirna_list)):
        mirna_ids.append(str(mirna_list[i].id))
        mirna_seqs.append(str(mirna_list[i].seq))
    
    for i in range(len(mrna_list)):
        mrna_ids.append(str(mrna_list[i].id))
        mrna_seqs.append(str(mrna_list[i].seq))
    
    return mirna_ids, mirna_seqs, mrna_ids, mrna_seqs


mirna_fasta_file = "data/mirna.fasta"
mrna_fasta_file  = "data/mrna.fasta"
mirna_ids, mirna_seqs, mrna_ids, mrna_seqs = read_fasta(mirna_fasta_file,mrna_fasta_file)

for i in range(5):
    print(f"mirna_ID : {mirna_ids[i]}. , mirna_SEQUENCE : {mirna_seqs[i]}")

print("\n")

for i in range(5):
    print(f"mrna_ID : {mrna_ids[i]}. , mrna_SEQUENCE : {mrna_seqs[i][:40]}")


mirna_ID : hsa-miR-4777-5p. , mirna_SEQUENCE : UUCUAGAUGAGAGAUAUAUAUA
mirna_ID : hsa-miR-3908. , mirna_SEQUENCE : GAGCAAUGUAGGUAGACUGUUU
mirna_ID : hsa-miR-96-3p. , mirna_SEQUENCE : AAUCAUGUGCAGUGCCAAUAUG
mirna_ID : hsa-miR-3144-5p. , mirna_SEQUENCE : AGGGGACCAAAGAGAUAUAUAG
mirna_ID : hsa-miR-6509-3p. , mirna_SEQUENCE : UUCCACUGCCACUACCUAAUUU


mrna_ID : NM_003629. , mrna_SEQUENCE : AGAGGAAGUGGGAAGAGAGGUGGUUCUCUGGCAUUUUUUU
mrna_ID : NM_001135041. , mrna_SEQUENCE : GCACUCCUUUCCCCUGCUGUCCCCUUCGACCCUCAGCCCU
mrna_ID : NM_001256461. , mrna_SEQUENCE : AUGUUCUAUACAGUGGACAGCCCUCCAGAAUGGUACUUCA
mrna_ID : NM_005371. , mrna_SEQUENCE : CUGCUUACUCUACCUUAGCUGGACCUCGUCUCCCAGGGAU
mrna_ID : NM_175734. , mrna_SEQUENCE : CCAGCAGGCGGAUGUGGGGUGUGGGGCAGGGCAUGGAGGG


In [9]:
def nucleotide_to_int(nucleotides, max_len):
    # map each base to an integer (0 is reserved for padding; T and U share a code)
    dictionary = {'A':1, 'C':2, 'G':3, 'T':4, 'U':4}
    
    chars = []
    nucleotides = nucleotides.upper()
    for c in nucleotides:
        chars.append(c)
    
    ints_enc = np.full((max_len,), fill_value=0) # to post-pad inputs
    for i in range(len(chars)):
        try:
            ints_enc[i] = dictionary[chars[i]]
        except KeyError:
            continue
        except IndexError:
            break
        
    return ints_enc


seq = mrna_seqs[1]
max_len = len(seq)

ints_enc = nucleotide_to_int(seq,max_len)

print(ints_enc)

[3 2 1 2 4 2 2 4 4 4 2 2 2 2 4 3 2 4 3 4 2 2 2 2 4 4 2 3 1 2 2 2 4 2 1 3 2
 2 2 4 2 4 3 3 4 3 2 2 3 2 4 2 4 3 2 2 2 3 1 4 3 2 1 2 1 3 2 2 1 2 2 4 2 1
 3 2 2 1 3 2 2 2 2 2 1 3 3 4 1 3 1 1 1 2 3 4 3 3 3 4 4 1 1 3 2 4 2 4 4 2 2
 4 3 2 2 2 2 3 4 4 2 1 3 2 4 4 2 1 2 4 2 2 2 1 2 2 2 4 4 4 2 1 3 2 3 4 2 2
 4 3 2 2 2 2 4 4 2 1 2 2 4 4 3 1 2 2 2 3 3 3 4 4 2 2 2 2 2 1 2 4 2 2 2 1 4
 4 2 2 2 4 3 3 2 2 4 2 4 3 2 2 1 4 1 1 4 4 4 3 4 4 3 4 4 2 1 1 2 4 3 2 4 2
 2 2 4 2 2 4 4 2 2 4 3 1 3 3 3 3 2 2 4 2 1 3 3 3 2 4 4 3 4 3 3 3 3 3 3 4 1
 3 3 2 4 3 1 3 1 2 2 2 2 1 2 2 1 2 2 1 1 1 3 3 4 4 1 1 3 4 3 1 3 3 4 2 2 2
 2 4 4 3 1 4 4 3 1 3 3 1 2 4 4 2 1 2 2 2 2 4 4 3 1 4 4 1 1 1 3 2 1 1 2 4 4
 2 4 3 2 4 4 2 1 3 4 3 2]


In [10]:
def sequence_to_int(sequences, max_len):
    # translate entire RNA sequences into int
    import itertools
    
    if type(sequences) is list:
        seqs_enc = np.asarray([nucleotide_to_int(seq, max_len) for seq in sequences])
    else:
        seqs_enc = np.asarray([nucleotide_to_int(seq, max_len) for seq in sequences])
        seqs_enc = list(itertools.chain(*seqs_enc))
        seqs_enc = np.asarray(seqs_enc)
    
    return seqs_enc

x=sequence_to_int(mrna_seqs, 10)
print (x[1])

[3 2 1 2 4 2 2 4 4 4]


In [11]:
def pad_sequences(sequences, max_len=None, padding='pre', fill_value=0):
    # pad sequences with fill_value so that they all have the same length
    n_samples = len(sequences)
    
    lengths = []
    for seq in sequences:
        try:
            lengths.append(len(seq))
        except TypeError:
            raise ValueError("sequences expected a list of iterables, got {}".format(seq))
    if max_len is None:
        max_len = np.max(lengths)
    
    input_shape = np.asarray(sequences[0]).shape[1:]
    padded_shape = (n_samples, max_len) + input_shape
    padded = np.full(padded_shape, fill_value=fill_value)
    
    for i, seq in enumerate(sequences):
        if padding == 'pre':
            truncated = seq[-max_len:]
            padded[i, -len(truncated):] = truncated
        elif padding == 'post':
            truncated = seq[:max_len]
            padded[i, :len(truncated)] = truncated
        else:
            raise ValueError("padding expected 'pre' or 'post', got {}".format(padding))
    
    return padded

x=pad_sequences(x, max_len=100)
print (x[1])

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 2 1 2 4 2 2 4 4 4]


In [12]:
def one_hot(ints):
    #one hot encoding for nucleotides
    dictionary_k = 5 # 4 nucleotide codes plus index 0 reserved for padding
    ints_len = len(ints)
    ints_enc = np.zeros((ints_len, dictionary_k))
    ints_enc[np.arange(ints_len), [k for k in ints]] = 1
    ints_enc = ints_enc[:, 1:5] # to handle zero-padded values
    ints_enc = ints_enc.tolist()
    
    return (ints_enc)

def one_hot_enc(seqs_enc):
    #one hot encoding for sequences    
    one_hot_encs = []
    
    for i in range(len(seqs_enc)):
        one_hot_encs.append(one_hot(seqs_enc[i]))
    
    one_hot_encs = np.array(one_hot_encs)
    
    return one_hot_encs

b=one_hot(x[1])
print (b[95:])

[[0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]

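`one_hot` builds its matrix position by position. Assuming the integer codes produced by `nucleotide_to_int` (0 = padding, 1 = A, 2 = C, 3 = G, 4 = T/U), the same encoding can be sketched in one vectorized NumPy step. `one_hot_vectorized` is a hypothetical name; unlike `one_hot`, it returns an array rather than a list and also accepts a 2-D batch.

```python
import numpy as np


def one_hot_vectorized(ints: np.ndarray) -> np.ndarray:
    """One-hot encode integer-coded bases (0 = padding, 1..4 = A, C, G, T/U).

    Indexing the 5x5 identity matrix picks one row per position; dropping
    column 0 turns padding positions into all-zero rows, as in `one_hot`.
    """
    ints = np.asarray(ints, dtype=int)
    return np.eye(5)[ints][..., 1:]


encoded = one_hot_vectorized(np.array([3, 2, 1, 4, 0]))  # G, C, A, T/U, padding
print(encoded)
```

Keeping 0 as a dedicated padding code is what lets padded positions contribute nothing to the convolutions later on.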

In [13]:
def to_categorical(labels, n_classes=None):
    # one-hot encode integer class labels into an (n_samples, n_classes) matrix
    labels = np.array(labels, dtype='int').reshape(-1)

    n_samples = labels.shape[0]
    if not n_classes:
        n_classes = np.max(labels) + 1

    categorical = np.zeros((n_samples, n_classes))
    categorical[np.arange(n_samples), labels] = 1
    
    return categorical

labels=[0,1,1,1,1,0,2,3,4,0,1,2,3]
labels=to_categorical(labels)
print(labels)

[[1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]]


In [14]:
def preprocess_data(x_query_seqs, x_target_seqs, y=None, cts_size=None, pre_padding=False):
    # encode miRNA and mRNA sequences as one-hot arrays (and the labels, if given)
    if cts_size is not None:
        max_len = cts_size
    else:
        max_len = max(len(max(x_query_seqs, key=len)), len(max(x_target_seqs, key=len)))
    
    x_mirna = sequence_to_int(x_query_seqs, max_len)
    x_mrna = sequence_to_int(x_target_seqs, max_len)
    
    if pre_padding:
        x_mirna = pad_sequences(x_mirna, max_len, padding='pre')
        x_mrna = pad_sequences(x_mrna, max_len, padding='pre')
    
    x_mirna_embd = one_hot_enc(x_mirna)
    x_mrna_embd = one_hot_enc(x_mrna)
    if y is not None:
        y_embd = to_categorical(y, np.unique(y).size)
        
        return x_mirna_embd, x_mrna_embd, y_embd
    else:
        return x_mirna_embd, x_mrna_embd

In [15]:
x_mirna_embd, x_mrna_embd=preprocess_data(mirna_seqs[:10], mrna_seqs[:10])

print (x_mirna_embd, x_mrna_embd)

[[[0. 0. 0. 1.]
  [0. 0. 0. 1.]
  [0. 1. 0. 0.]
  ...
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 1. 0.]
  [1. 0. 0. 0.]
  [0. 0. 1. 0.]
  ...
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[1. 0. 0. 0.]
  [1. 0. 0. 0.]
  [0. 0. 0. 1.]
  ...
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 ...

 [[0. 1. 0. 0.]
  [1. 0. 0. 0.]
  [1. 0. 0. 0.]
  ...
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[0. 0. 1. 0.]
  [0. 0. 1. 0.]
  [0. 0. 1. 0.]
  ...
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[1. 0. 0. 0.]
  [1. 0. 0. 0.]
  [1. 0. 0. 0.]
  ...
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]] [[[1. 0. 0. 0.]
  [0. 0. 1. 0.]
  [1. 0. 0. 0.]
  ...
  [0. 0. 0. 1.]
  [1. 0. 0. 0.]
  [1. 0. 0. 0.]]

 [[0. 0. 1. 0.]
  [0. 1. 0. 0.]
  [1. 0. 0. 0.]
  ...
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[1. 0. 0. 0.]
  [0. 0. 0. 1.]
  [0. 0. 1. 0.]
  ...
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 ...

 [[0. 0. 1. 0.]
  [1. 0. 0. 0.]
  [0. 0. 1. 0.]
  ..

Let's now define how a match can happen.
Following the paper, I use relaxed site patterns, which cover most of the canonical site types (CSTs), non-canonical site types (NSTs), and context-dependent non-canonical site types (CDNSTs), to define the candidate target sites (CTSs). Here WC stands for Watson-Crick base pairing. The patterns used are as follows:
* 10-mer-m6: six WC pairings from the miRNA nucleotides 1–10
* 10-mer-m7: seven WC pairings from the miRNA nucleotides 1–10
* Offset 9-mer-m7: seven WC pairings from the miRNA nucleotides 2–10
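
To illustrate what these relaxed patterns mean, the sketch below slides a window along the mRNA and counts the positions at which it equals the complement of the chosen miRNA seed region, keeping windows with at least the minimum number of pairings. It is a simplified, substitution-only approximation: the notebook's `find_candidate` uses the `regex` module's fuzzy pattern `{e<=n}`, which also tolerates insertions and deletions. Like `find_candidate`, it compares against the non-reversed complement and reports slice-stop positions shifted by `SEED_START - 1`. The names `SITE_PATTERNS` and `relaxed_sites` are hypothetical.

```python
# Complement rules for the sketch (RNA alphabet; T is treated as U).
COMPLEMENT = {"A": "U", "U": "A", "G": "C", "C": "G"}

# Hypothetical table: pattern name -> (seed_start, seed_end, min_match),
# with 1-based, inclusive miRNA positions as in find_candidate.
SITE_PATTERNS = {
    "10-mer-m6": (1, 10, 6),
    "10-mer-m7": (1, 10, 7),
    "offset-9-mer-m7": (2, 10, 7),
}


def complement(seq):
    """Base-wise complement (not reversed); unknown bases become 'N'."""
    return "".join(COMPLEMENT.get(b, "N") for b in seq.upper().replace("T", "U"))


def relaxed_sites(mirna, mrna, pattern):
    """Return (slice_stop, n_pairings) for windows with enough WC pairings.

    Substitution-only: a pairing is counted wherever the window equals
    the complemented seed at the same position.
    """
    start, end, min_match = SITE_PATTERNS[pattern]
    target = complement(mirna[start - 1:end])
    mrna = mrna.upper().replace("T", "U")
    offset = start - 1  # same shift as SEED_OFFSET in find_candidate
    hits = []
    for i in range(len(mrna) - len(target) + 1):
        window = mrna[i:i + len(target)]
        pairings = sum(a == b for a, b in zip(window, target))
        if pairings >= min_match:
            hits.append((i + len(target) + offset, pairings))
    return hits


print(relaxed_sites("AAAAAAAAAAGG", "GGG" + "U" * 10 + "GGG", "10-mer-m7"))
```

Allowing up to four mismatches (10-mer-m6) instead of three (10-mer-m7) is what makes the first pattern the most permissive of the three.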



In [16]:
def find_candidate(mirna_sequence, mrna_sequence, seed_match):
    positions = set()
    
    if seed_match == '10-mer-m6':
        SEED_START = 1
        SEED_END = 10
        SEED_OFFSET = SEED_START - 1
        MIN_MATCH = 6
        TOLERANCE = (SEED_END-SEED_START+1) - MIN_MATCH
    elif seed_match == '10-mer-m7':
        SEED_START = 1
        SEED_END = 10
        SEED_OFFSET = SEED_START - 1
        MIN_MATCH = 7
        TOLERANCE = (SEED_END-SEED_START+1) - MIN_MATCH
    elif seed_match == 'offset-9-mer-m7':
        SEED_START = 2
        SEED_END = 10
        SEED_OFFSET = SEED_START - 1
        MIN_MATCH = 7
        TOLERANCE = (SEED_END-SEED_START+1) - MIN_MATCH
    elif seed_match == 'strict':
        positions = find_strict_candidate(mirna_sequence, mrna_sequence)
        
        return positions
    else:
        raise ValueError("seed_match expected 'strict', '10-mer-m6', '10-mer-m7', or 'offset-9-mer-m7', got '{}'".format(seed_match))
    
    seed = mirna_sequence[(SEED_START-1):SEED_END]
    rc_seed = str(Seq(seed).complement())
    match_iter = regex.finditer("({}){{e<={}}}".format(rc_seed, TOLERANCE), mrna_sequence)
    
    for match_index in match_iter:
        #positions.add(match_index.start()) # slice-start indices
        positions.add(match_index.end()+SEED_OFFSET) # slice-stop indices
    
    positions = list(positions)
    
    return positions



In [17]:
positions=find_candidate(mirna_seqs[1], mrna_seqs[1], seed_match = '10-mer-m6')
print(positions)

[64, 161, 130, 185, 225, 322, 232, 202, 13, 302, 143, 112, 53, 342, 216, 25, 123]


In [18]:
def find_strict_candidate(mirna_sequence, mrna_sequence):
    #find position of matches without tolerance

    positions = set()
    
    SEED_TYPES = ['8-mer', '7-mer-m8', '7-mer-A1', '6-mer', '6-mer-A1', 'offset-7-mer', 'offset-6-mer']
    for seed_match in SEED_TYPES:
        if seed_match == '8-mer':
            SEED_START = 2
            SEED_END = 8
            SEED_OFFSET = 0
            seed = 'U' + mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == '7-mer-m8':
            SEED_START = 1
            SEED_END = 8
            SEED_OFFSET = 0
            seed = mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == '7-mer-A1':
            SEED_START = 2
            SEED_END = 7
            SEED_OFFSET = 0
            seed = 'U' + mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == '6-mer':
            SEED_START = 2
            SEED_END = 7
            SEED_OFFSET = 1
            seed = mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == '6-mer-A1':
            SEED_START = 2
            SEED_END = 6
            SEED_OFFSET = 0
            seed = 'U' + mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == 'offset-7-mer':
            SEED_START = 3
            SEED_END = 9
            SEED_OFFSET = 0
            seed = mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == 'offset-6-mer':
            SEED_START = 3
            SEED_END = 8
            SEED_OFFSET = 0
            seed = mirna_sequence[(SEED_START-1):SEED_END]
        
        rc_seed = str(Seq(seed).complement())
        match_iter = regex.finditer(rc_seed, mrna_sequence)
        
        for match_index in match_iter:
            #positions.add(match_index.start()) # slice-start indices
            positions.add(match_index.end()+SEED_OFFSET) # slice-stop indices
    
    positions = list(positions)
    
    return positions

In [19]:
positions=find_strict_candidate(mirna_seqs[6], mrna_seqs[1])
print(positions)

[209]


In [20]:
def get_candidate(mirna_sequence, mrna_sequence, cts_size, seed_match):
    #using the find_candidate function we can find actual candidates and positions
    positions = find_candidate(mirna_sequence, mrna_sequence, seed_match)
    
    candidates = []
    for i in positions:
        site_sequence = mrna_sequence[max(0, i-cts_size):i]
        rev_site_sequence = site_sequence[::-1]
        rc_site_sequence = str(Seq(rev_site_sequence).complement())
        candidates.append(rev_site_sequence) # miRNAs: 5'-ends to 3'-ends,  mRNAs: 3'-ends to 5'-ends
        #candidates.append(rc_site_sequence)
    
    return candidates, positions

In [21]:
candidates, positions=get_candidate(mirna_seqs[1], mrna_seqs[1], cts_size=1, seed_match = '10-mer-m6')
for i in range(len(candidates)):
    print(candidates[i], positions [i])

C 64
U 161
U 130
U 185
U 225
U 322
U 232
U 202
C 13
U 302
G 143
U 112
U 53
G 342
A 216
U 25
G 123


In [22]:
def make_pair(mirna_sequence, mrna_sequence, cts_size, seed_match):
    #and finally identify mirna_querys and mrna_targets
    candidates, positions = get_candidate(mirna_sequence, mrna_sequence, cts_size, seed_match)
    
    mirna_querys = []
    mrna_targets = []
    if len(candidates) == 0:
        return (mirna_querys, mrna_targets, positions)
    else:
        mirna_sequence = mirna_sequence[0:cts_size]
        for i in range(len(candidates)):
            mirna_querys.append(mirna_sequence)
            mrna_targets.append(candidates[i])
        
    return mirna_querys, mrna_targets, positions

In [23]:
mirna_querys, mrna_targets, positions = make_pair(mirna_seqs[1], mrna_seqs[1], cts_size=1, seed_match = '10-mer-m6')
for i in range(len(mirna_querys)):
    print(mirna_querys[i], mrna_targets[i], positions [i])

G C 64
G U 161
G U 130
G U 185
G U 225
G U 322
G U 232
G U 202
G C 13
G U 302
G G 143
G U 112
G U 53
G G 342
G A 216
G U 25
G G 123


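We also need a function to read the training and test files. `read_ground_truth` assumes three tab-separated columns, [miRNA_ID, mRNA_ID, LABEL], and selects them by position. `train_set.csv` has a different layout, so on that file the function returns Ensembl gene IDs as targets and gene symbols in place of labels (see the output below). Since these Ensembl IDs do not match the RefSeq (NM_...) IDs in `mrna.fasta`, the pairing functions that follow find no matching pairs in this file, which is why the intermediate results printed below are empty.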

In [24]:
def read_ground_truth(ground_truth_file, header=True, train=True):
    # read the training and test files containing miRNA-mRNA pairs
    # input format: [miRNA_ID, mRNA_ID, LABEL]
    if header is True:
        records = pd.read_csv(ground_truth_file, header=0, sep='\t')
    else:
        records = pd.read_csv(ground_truth_file, header=None, sep='\t')
    
    query_ids = np.asarray(records.iloc[:, 0].values)
    target_ids = np.asarray(records.iloc[:, 1].values)
    if train is True:
        labels = np.asarray(records.iloc[:, 2].values)
    else:
        labels = np.full((len(records),), fill_value=-1)
    
    return query_ids, target_ids, labels

ground_truth_file = "data/train/train_set.csv"
query_ids, target_ids, labels=read_ground_truth(ground_truth_file)
for i in range(10):
    print(query_ids[i], target_ids[i], labels[i])

hsa-let-7a-5p ENSG00000114573 ATP6V1A
hsa-let-7a-5p ENSG00000104497 SNX16
hsa-let-7a-5p ENSG00000141682 PMAIP1
hsa-let-7a-5p ENSG00000174010 KLHL15
hsa-let-7b-5p ENSG00000130402 ACTN4
hsa-let-7c-5p ENSG00000130402 ACTN4
hsa-let-7a-5p ENSG00000130402 ACTN4
hsa-let-7d ENSG00000130402 ACTN4
hsa-let-7b-5p ENSG00000105202 FBL
hsa-let-7c-5p ENSG00000105202 FBL

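Because `read_ground_truth` picks columns by position, on `train_set.csv` it returns gene symbols instead of labels, as the output above shows. A minimal header-aware alternative, sketched with the standard `csv` module, selects the columns by name; `read_pairs` and its default column names (taken from the Dataset section) are hypothetical. It does not solve the ID mismatch: the targets are still Ensembl gene IDs, while `mrna.fasta` uses RefSeq (NM_...) IDs, so matching them would need an ID mapping that is not part of this notebook.

```python
import csv
import io


def read_pairs(handle, query_col="MIRNA_ID", target_col="ENSEMBL_GENE", label_col="LABEL"):
    """Read (query, target, label) triples by column name from a tab-separated file."""
    reader = csv.DictReader(handle, delimiter="\t")
    triples = []
    for row in reader:
        # -1 marks unlabeled rows, mirroring read_ground_truth(train=False)
        label = int(row[label_col]) if label_col in row else -1
        triples.append((row[query_col], row[target_col], label))
    return triples


# Toy excerpt using the column names and the first record shown earlier
sample = io.StringIO(
    "MIRNA_ID\tENSEMBL_GENE\tGENE_SYMBOL\tLABEL\n"
    "hsa-let-7a-5p\tENSG00000114573\tATP6V1A\t1\n"
)
print(read_pairs(sample))
```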

In [25]:
def make_input_pair(mirna_fasta_file, mrna_fasta_file, ground_truth_file, cts_size=30, seed_match='offset-9-mer-m7', header=True, train=True):
    #from sequences, ids and ground truth we generate the dataset
    mirna_ids, mirna_seqs, mrna_ids, mrna_seqs = read_fasta(mirna_fasta_file, mrna_fasta_file)
    query_ids, target_ids, labels = read_ground_truth(ground_truth_file, header=header, train=train)
    
    dataset = {
        'mirna_fasta_file': mirna_fasta_file,
        'mrna_fasta_file': mrna_fasta_file,
        'ground_truth_file': ground_truth_file,
        'query_ids': [],
        'query_seqs': [],
        'target_ids': [],
        'target_seqs': [],
        'target_locs': [],
        'labels': []
    }
    
    for i in range(len(query_ids)):
        try:
            j = mirna_ids.index(query_ids[i])
        except ValueError:
            continue
        try:
            k = mrna_ids.index(target_ids[i])
        except ValueError:
            continue
        
        query_seqs, target_seqs, locations = make_pair(mirna_seqs[j], mrna_seqs[k], cts_size=cts_size, seed_match=seed_match)
        
        n_pairs = len(locations)
        if n_pairs > 0:
            queries = [query_ids[i] for n in range(n_pairs)]
            dataset['query_ids'].extend(queries)
            dataset['query_seqs'].extend(query_seqs)
            
            targets = [target_ids[i] for n in range(n_pairs)]
            dataset['target_ids'].extend(targets)
            dataset['target_seqs'].extend(target_seqs)
            dataset['target_locs'].extend(locations)
            
            dataset['labels'].extend([[labels[i]] for p in range(n_pairs)])
        
    return dataset

In [26]:
dataset=make_input_pair(mirna_fasta_file, mrna_fasta_file, ground_truth_file, seed_match='offset-9-mer-m7')

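Next, `get_negative_pair` collects the miRNA-mRNA pairs for which no candidate target site is found: in predict mode they are assigned a prediction of 0, otherwise they keep their ground-truth label. `postprocess_result` then merges these pairs with the site-level pairs and reports the results at the site or gene level.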

In [27]:
def get_negative_pair(mirna_fasta_file, mrna_fasta_file, ground_truth_file=None, cts_size=30, seed_match='offset-9-mer-m7', header=False, predict_mode=True):
    mirna_ids, mirna_seqs, mrna_ids, mrna_seqs = read_fasta(mirna_fasta_file, mrna_fasta_file)
    
    dataset = {
        'query_ids': [],
        'target_ids': [],
        'predicts': []
    }
    
    if ground_truth_file is not None:
        query_ids, target_ids, labels = read_ground_truth(ground_truth_file, header=header)
        
        for i in range(len(query_ids)):
            try:
                j = mirna_ids.index(query_ids[i])
            except ValueError:
                continue
            try:
                k = mrna_ids.index(target_ids[i])
            except ValueError:
                continue

            query_seqs, target_seqs, locations = make_pair(mirna_seqs[j], mrna_seqs[k], cts_size=cts_size, seed_match=seed_match)

            n_pairs = len(locations)
            if (n_pairs == 0) and (predict_mode is True):
                dataset['query_ids'].append(query_ids[i])
                dataset['target_ids'].append(target_ids[i])
                dataset['predicts'].append(0)
            elif (n_pairs == 0) and (predict_mode is False):
                dataset['query_ids'].append(query_ids[i])
                dataset['target_ids'].append(target_ids[i])
                dataset['predicts'].append(labels[i])
    else:
        for i in range(len(mirna_ids)):
            for j in range(len(mrna_ids)):
                query_seqs, target_seqs, locations = make_pair(mirna_seqs[i], mrna_seqs[j], cts_size=cts_size, seed_match=seed_match)

                n_pairs = len(locations)
                if n_pairs == 0:
                    dataset['query_ids'].append(mirna_ids[i])                
                    dataset['target_ids'].append(mrna_ids[j])
                    dataset['predicts'].append(0)
    
    dataset['target_locs'] = [-1 for i in range(len(dataset['query_ids']))]
    dataset['probabilities'] = [0.0 for i in range(len(dataset['query_ids']))]
    
    return dataset

In [28]:
dataset_2=get_negative_pair(mirna_fasta_file, mrna_fasta_file, ground_truth_file)

In [29]:
dataset_2

{'query_ids': [],
 'target_ids': [],
 'predicts': [],
 'target_locs': [],
 'probabilities': []}

In [30]:
def postprocess_result(dataset, probabilities, predicts, predict_mode=True, output_file=None, cts_size=30, seed_match='offset-9-mer-m7', level='site'):
    neg_pairs = get_negative_pair(dataset['mirna_fasta_file'], dataset['mrna_fasta_file'], dataset['ground_truth_file'], cts_size=cts_size, seed_match=seed_match, predict_mode=predict_mode)
    
    query_ids = np.append(dataset['query_ids'], neg_pairs['query_ids'])
    target_ids = np.append(dataset['target_ids'], neg_pairs['target_ids'])
    target_locs = np.append(dataset['target_locs'], neg_pairs['target_locs'])
    probabilities = np.append(probabilities, neg_pairs['probabilities'])
    predicts = np.append(predicts, neg_pairs['predicts'])
    
    # output format: [QUERY, TARGET, LOCATION, PROBABILITY]
    records = pd.DataFrame(columns=['MIRNA_ID', 'MRNA_ID', 'LOCATION', 'PROBABILITY'])
    records['MIRNA_ID'] = query_ids
    records['MRNA_ID'] = target_ids
    records['LOCATION'] = np.array(["{},{}".format(max(1, l-cts_size+1), l) if l != -1 else "-1,-1" for l in target_locs])
    records['PROBABILITY'] = probabilities
    if predict_mode is True:
        records['PREDICT'] = predicts
    else:
        records['LABEL'] = predicts
    
    #records = records.sort_values(by=['PROBABILITY', 'MIRNA_ID', 'MRNA_ID'], ascending=[False, True, True])
    unique_records = records.drop_duplicates(subset=['MIRNA_ID', 'MRNA_ID'], keep='first')
    
    if level == 'site':
        if output_file is not None:
            records.to_csv(output_file, index=False, sep='\t')
        
        return records
    elif level == 'gene':
        if output_file is not None:
            unique_records.to_csv(output_file, index=False, sep='\t')
        
        return unique_records
    else:
        raise ValueError("level expected 'site' or 'gene', got '{}'".format(level))


In [31]:
records=postprocess_result(dataset, probabilities=None, predicts=None)

In [32]:
records

| Unnamed: 0 | MIRNA_ID | MRNA_ID | LOCATION | PROBABILITY | PREDICT |
|---|---|---|---|---|---|
| 0 | | | | | |


Time to set up the PyTorch datasets.
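
`preprocess_data` returns one-hot arrays of shape (N, L, 4), while `torch.nn.Conv1d` expects inputs of shape (batch, channels, length); this is why both dataset classes below call `transpose((0, 2, 1))`. A small NumPy sketch with a hypothetical batch shows the effect.

```python
import numpy as np

# Hypothetical batch: N=2 pairs, CTS length L=30, one-hot over (A, C, G, T/U)
x = np.zeros((2, 30, 4))
x[0, 0, 2] = 1.0  # first position of the first sequence is a G

x_cf = x.transpose((0, 2, 1))  # (N, L, 4) -> (N, 4, L), the layout nn.Conv1d expects
print(x.shape, "->", x_cf.shape)
```

After the transpose, each of the 4 nucleotides is an input channel and the convolution slides along the sequence positions.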

In [33]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, mirna_fasta_file, mrna_fasta_file, ground_truth_file, cts_size=30, seed_match='offset-9-mer-m7', header=True, train=True):
        self.dataset = make_input_pair(mirna_fasta_file, mrna_fasta_file, ground_truth_file, cts_size=cts_size, seed_match=seed_match, header=header, train=train)
        self.mirna, self.mrna = preprocess_data(self.dataset['query_seqs'], self.dataset['target_seqs'])
        self.labels = np.asarray(self.dataset['labels']).reshape(-1,)
        
        self.mirna = self.mirna.transpose((0, 2, 1))
        self.mrna = self.mrna.transpose((0, 2, 1))
        
    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
            
        mirna = self.mirna[index]
        mrna = self.mrna[index]
        label = self.labels[index]
        
        return (mirna, mrna), label
        
    def __len__(self):
        return len(self.labels)


class TrainDataset(torch.utils.data.Dataset):
    def __init__(self, ground_truth_file, cts_size=30):
        self.records = pd.read_csv(ground_truth_file, header=0, sep='\t')
        mirna_seqs = self.records['MIRNA_SEQ'].values.tolist()
        mrna_seqs = self.records['MRNA_SEQ'].values.tolist()
        self.mirna, self.mrna = preprocess_data(mirna_seqs, mrna_seqs, cts_size=cts_size)
        self.labels = self.records['LABEL'].values.astype(int)
        
        self.mirna = self.mirna.transpose((0, 2, 1))
        self.mrna = self.mrna.transpose((0, 2, 1))
    
    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        
        mirna = self.mirna[index]
        mrna = self.mrna[index]
        label = self.labels[index]
        
        return (mirna, mrna), label
        
    def __len__(self):
        return len(self.labels)

## Model design
I'll now build the neural network model and define how it will be trained.

The proposed approach uses one-dimensional convolutional neural networks (CNNs) based on a sequence-to-sequence interaction learning framework.

The one-hot encoded inputs $x_{mi}$ (miRNA) and $x_{m}$ (mRNA) share the same L1 layer. The two L1 outputs are concatenated along the channel dimension into the sequence representation $h$, which is fed to the consecutive layers L2 to L4. The feature maps from the L4 layer are flattened and fed to the L5 layer, and the L6 layer outputs the logits for the given miRNA-mRNA pair.
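
As a worked check of the layer sizes, assuming the default hyperparameters of `deepTarget` below (filters 32, 16, 64, 16; kernel size 3; 30 hidden units) and a CTS length of 30, the sketch applies the output-length formula from the `torch.nn.Conv1d` documentation. The padded L1 layer keeps the length at 30, and each unpadded convolution in L2 to L4 removes 2 positions (28, 26, 24), so L5 receives 16 * 24 = 384 flattened features. The helper names are hypothetical.

```python
def conv1d_out_len(length, kernel, padding=0, stride=1, dilation=1):
    """Output length of a 1-D convolution, per the torch.nn.Conv1d docs."""
    return (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def deeptarget_sizes(cts_size=30, filters=(32, 16, 64, 16), kernels=(3, 3, 3, 3), hidden_units=30):
    """Return (final length, flattened features, trainable parameters)."""
    f1, f2, f3, f4 = filters
    k1, k2, k3, k4 = kernels
    # L1 (embd1): shared by miRNA and mRNA; padding (k-1)//2 keeps the length for odd k
    length = conv1d_out_len(cts_size, k1, padding=(k1 - 1) // 2)
    params = 4 * f1 * k1 + f1  # counted once because the layer is shared
    # L2-L4: the two L1 outputs are concatenated, so L2 sees 2 * f1 channels
    for in_ch, out_ch, k in ((2 * f1, f2, k2), (f2, f3, k3), (f3, f4, k4)):
        length = conv1d_out_len(length, k)
        params += in_ch * out_ch * k + out_ch
    flat = f4 * length
    params += flat * hidden_units + hidden_units  # L5 (fc1)
    params += hidden_units * 2 + 2  # L6 (fc2): two logits
    return length, flat, params


print(deeptarget_sizes())
```

This prints `(24, 384, 21340)`; the parameter total should agree with `model.size()`, since PyTorch also counts the shared L1 weights only once.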

For training, I will minimize the weighted cross-entropy loss using the Adam optimizer.
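
A sketch of the weighted loss, assuming integer class labels 0 and 1. The class weights follow the 'balanced' rule documented for scikit-learn's `compute_class_weight` (n_samples / (n_classes * class_count)), and the loss mirrors what `nn.CrossEntropyLoss(weight=...)` computes with its default `reduction='mean'`: a weighted average of the per-sample negative log-softmax, normalised by the sum of the weights. The helper names are hypothetical, and NumPy stands in for PyTorch so that the arithmetic is visible.

```python
import numpy as np


def balanced_class_weights(y):
    """Weights n_samples / (n_classes * count) for each class in y."""
    classes, counts = np.unique(y, return_counts=True)
    return classes, len(y) / (len(classes) * counts)


def weighted_cross_entropy(logits, y, weights):
    """Weighted mean of -log softmax(logits)[y], normalised by the sum of weights.

    Assumes labels are 0..K-1 so that weights[y] looks up each sample's class weight.
    """
    logits = np.asarray(logits, dtype=float)
    y = np.asarray(y, dtype=int)
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(y)), y]
    w = np.asarray(weights, dtype=float)[y]
    return float((w * nll).sum() / w.sum())


y_demo = np.array([0, 0, 0, 1])
classes, weights = balanced_class_weights(y_demo)
print(classes, weights)
print(weighted_cross_entropy([[2.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.0]], y_demo, weights))
```

Because the training set is balanced, both weights come out close to 1, so the weighting mainly matters when the classes are unbalanced.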

<img src='deepTarget_model_layout.png' style='width:800px' align='left' />



In [34]:
class HyperParam:
    def __init__(self, filters=None, kernels=None, model_json=None):
        self.dictionary = dict()
        self.name_postfix = str()

        if (filters is not None) and (kernels is not None) and (model_json is None):
            for i, (f, k) in enumerate(zip(filters, kernels)):
                setattr(self, 'f{}'.format(i+1), f)
                setattr(self, 'k{}'.format(i+1), k)
                self.dictionary.update({'f{}'.format(i+1): f, 'k{}'.format(i+1): k})
            self.len = i+1
                
            for key, value in self.dictionary.items():
                self.name_postfix = "{}_{}-{}".format(self.name_postfix, key, value)
        elif model_json is not None:
            self.dictionary = json.loads(model_json)
            for i, (key, value) in enumerate(self.dictionary.items()):
                setattr(self, key, value)
                self.name_postfix = "{}_{}-{}".format(self.name_postfix, key, value)
            self.len = (i+1)//2
    
    def __len__(self):
        return self.len


class deepTarget(nn.Module):
    # define the architecture of the convolutional model
    def __init__(self, hparams=None, hidden_units=30, input_shape=(1, 4, 30), name_prefix="model"):
        super(deepTarget, self).__init__()
        
        if hparams is None:
            filters, kernels = [32, 16, 64, 16], [3, 3, 3, 3]
            hparams = HyperParam(filters, kernels)
        self.name = "{}{}".format(name_prefix, hparams.name_postfix)
        
        if (isinstance(hparams, HyperParam)) and (len(hparams) == 4):
            self.embd1 = nn.Conv1d(4, hparams.f1, kernel_size=hparams.k1, padding=((hparams.k1 - 1) // 2))
            self.conv2 = nn.Conv1d(hparams.f1*2, hparams.f2, kernel_size=hparams.k2)
            self.conv3 = nn.Conv1d(hparams.f2, hparams.f3, kernel_size=hparams.k3)
            self.conv4 = nn.Conv1d(hparams.f3, hparams.f4, kernel_size=hparams.k4)
            
            # Conv1d output length = floor((in_length + 2*padding - kernel_size) / stride) + 1;
            # a dummy forward pass measures the flattened size (out_length * out_channels)
            flat_features = self.forward(torch.rand(input_shape), torch.rand(input_shape), flat_check=True)
            self.fc1 = nn.Linear(flat_features, hidden_units)
            self.fc2 = nn.Linear(hidden_units, 2)
        else:
            raise ValueError("not enough hyperparameters")
    
    def forward(self, x_mirna, x_mrna, flat_check=False):
        h_mirna = F.relu(self.embd1(x_mirna))
        h_mrna = F.relu(self.embd1(x_mrna))
        
        h = torch.cat((h_mirna, h_mrna), dim=1)
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))
        
        h = h.view(h.size(0), -1)
        if flat_check:
            return h.size(1)
        h = self.fc1(h)
        y = self.fc2(h)  # raw logits; nn.CrossEntropyLoss applies log-softmax internally
        
        return y
    
    def size(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
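The comment in `__init__` gives the output-length rule that `flat_check=True` evaluates numerically. As a sanity check, here is a sketch that does the same arithmetic by hand for the default hyperparameters (filters 32/16/64/16, kernel size 3, input length 30); the helper names are mine, not part of the project. Under these assumptions each sequence stays at length 30 after `embd1` (padding 1), then loses 2 positions at each of the three unpadded convolutions, ending at 24, so `fc1` receives 16 * 24 = 384 features. The parameter total should match what `deepTarget().size()` reports.

```python
def conv1d_out_len(length, kernel_size, padding=0, stride=1):
    # floor((L + 2p - k) / s) + 1, the standard Conv1d output-length rule
    return (length + 2 * padding - kernel_size) // stride + 1


def deeptarget_shapes(filters=(32, 16, 64, 16), kernels=(3, 3, 3, 3),
                      length=30, hidden_units=30):
    """Return (flattened features entering fc1, trainable parameter count)."""
    f1, f2, f3, f4 = filters
    k1, k2, k3, k4 = kernels
    L = conv1d_out_len(length, k1, padding=(k1 - 1) // 2)  # embd1, shared by both inputs
    L = conv1d_out_len(L, k2)  # conv2 on the concatenated 2*f1 channels
    L = conv1d_out_len(L, k3)  # conv3
    L = conv1d_out_len(L, k4)  # conv4
    flat = f4 * L
    params = (
        (4 * f1 * k1 + f1)                       # embd1 weights + bias
        + (2 * f1 * f2 * k2 + f2)                # conv2
        + (f2 * f3 * k3 + f3)                    # conv3
        + (f3 * f4 * k4 + f4)                    # conv4
        + (flat * hidden_units + hidden_units)   # fc1
        + (hidden_units * 2 + 2)                 # fc2 (two output classes)
    )
    return flat, params


flat_features, n_params = deeptarget_shapes()
```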

In [35]:
def train_model(mirna_fasta_file, mrna_fasta_file, train_file, model=None, cts_size=30, seed_match='offset-9-mer-m7', level='gene', batch_size=32, epochs=10, save_file=None, device='cpu'):
    if not isinstance(model, deepTarget):
        raise ValueError("'model' expected <nn.Module 'deepTarget'>, got {}".format(type(model)))
    
    print("\n[TRAIN] {}".format(model.name))
    
    if train_file.split('/')[-1] == 'train_set.csv':
        train_set = TrainDataset(train_file)
    else:
        train_set = Dataset(mirna_fasta_file, mrna_fasta_file, train_file, seed_match=seed_match, header=True, train=True)
    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
    
    class_weight = torch.Tensor(compute_class_weight('balanced', classes=np.unique(train_set.labels), y=train_set.labels)).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weight)
    optimizer = optim.Adam(model.parameters())
    
    model = model.to(device)
    for epoch in range(epochs):
        epoch_loss, corrects = 0, 0

        with tqdm(train_loader, desc="Epoch {}/{}".format(epoch+1, epochs), bar_format=bar_format) as tqdm_loader:
            for i, ((mirna, mrna), label) in enumerate(tqdm_loader):
                mirna, mrna, label = mirna.to(device, dtype=torch.float), mrna.to(device, dtype=torch.float), label.to(device)
                
                outputs = model(mirna, mrna)
                loss = criterion(outputs, label)
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                epoch_loss += loss.item() * outputs.size(0)
                corrects += (torch.max(outputs, 1)[1] == label).sum().item()
                
                if (i+1) == len(train_loader):
                    tqdm_loader.set_postfix(dict(loss=(epoch_loss/len(train_set)), acc=(corrects/len(train_set))))
                else:
                    tqdm_loader.set_postfix(loss=loss.item())
    
    if save_file is None:
        time = datetime.now()
        save_file = "{}.pt".format(time.strftime('%Y%m%d_%H%M%S_weights'))
    torch.save(model.state_dict(), save_file)
    
    
def predict_result(mirna_fasta_file, mrna_fasta_file, query_file, model=None, weight_file=None, seed_match='offset-9-mer-m7', level='gene', output_file=None, device='cpu'):
    if not isinstance(model, deepTarget):
        raise ValueError("'model' expected <nn.Module 'deepTarget'>, got {}".format(type(model)))
    
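    if weight_file is None or not weight_file.endswith('.pt'):
        raise ValueError("'weight_file' expected '*.pt', got {}".format(weight_file))
    
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine
    model.load_state_dict(torch.load(weight_file, map_location=device))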
    
    test_set = Dataset(mirna_fasta_file, mrna_fasta_file, query_file, seed_match=seed_match, header=True, train=False)
    
    model = model.to(device)
    with torch.no_grad():
        model.eval()
        
        mirna = torch.from_numpy(test_set.mirna).to(device, dtype=torch.float)
        mrna = torch.from_numpy(test_set.mrna).to(device, dtype=torch.float)
        label = torch.from_numpy(test_set.labels).to(device)
        
        outputs = model(mirna, mrna)
        _, predicts = torch.max(outputs.data, 1)
        probabilities = F.softmax(outputs, dim=1)
        
        y_probs = probabilities.cpu().numpy()[:, 1]
        y_predicts = predicts.cpu().numpy()
        y_truth = label.cpu().numpy()
        
        if output_file is None:
            time = datetime.now()
            output_file = "{}.csv".format(time.strftime('%Y%m%d_%H%M%S_results'))
        results = postprocess_result(test_set.dataset, y_probs, y_predicts,
                                     seed_match=seed_match, level=level, output_file=output_file)
        
        print(results)
        return results
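In `predict_result`, the PREDICT column comes from `torch.max(outputs, 1)` (the argmax of the logits) while PROBABILITY is the softmax probability of class 1. With two classes these agree: softmax is monotonic, so the argmax picks class 1 exactly when its probability exceeds 0.5. The NumPy sketch below (illustrative logits, not model outputs) mirrors that relationship.

```python
import numpy as np


def softmax(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


logits = np.array([[2.0, -1.0], [0.3, 0.7], [-0.5, -0.5]])
probs = softmax(logits)[:, 1]                   # like the PROBABILITY column: P(target)
predicts = logits.argmax(axis=1)                # like the PREDICT column
from_threshold = (probs > 0.5).astype(int)      # same decisions via a 0.5 cut-off
```

Because the two views coincide, a different decision threshold could be applied to the PROBABILITY column afterwards without rerunning the model.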

# Training
Training the model on the training set.

I trained the model with several batch sizes and numbers of epochs: 3 epochs turned out to be enough for this problem while avoiding overfitting, and increasing the batch size to 64 gave better accuracy. I also tested different seed-match rules, and 10-mer-m6 gave the best miRNA-mRNA target predictions.
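Since the seed-match setting turned out to matter, it helps to make concrete what a rule like `10-mer-m6` might check. The project's own seed-matching code is not reproduced in this notebook, so this sketch rests on an assumption: that `n-mer-mK` means at least K Watson-Crick pairs between the first n nucleotides of the miRNA and an n-nt mRNA site, with no G:U wobble pairs and no offset. The function names and the example sequences are hypothetical.

```python
# Hypothetical sketch of an "n-mer-mK" seed rule; the project's actual code may
# define the window, offsets and wobble pairing differently.
WATSON_CRICK = {("A", "U"), ("U", "A"), ("G", "C"), ("C", "G")}


def count_seed_pairs(mirna, site):
    """Count Watson-Crick pairs between the first len(site) nt of the miRNA
    (5'->3') and an mRNA site (5'->3'); the two strands pair antiparallel."""
    seed = mirna[:len(site)]
    # reversing the site lines up miRNA position i with its partner nucleotide
    return sum((m, s) in WATSON_CRICK for m, s in zip(seed, reversed(site)))


def passes_seed_rule(mirna, site, n=10, min_pairs=6):
    """Assumed reading of '10-mer-m6': at least 6 pairs within a 10-nt window."""
    return len(site) == n and count_seed_pairs(mirna, site) >= min_pairs


example_mirna = "UGAGGUAGUAGGUUGUAUGGUU"   # example miRNA sequence (5'->3')
perfect_site = "UACUACCUCA"                # reverse complement of its first 10 nt
```

Under that reading, a larger K makes the candidate-site filter stricter.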

In [124]:
mirna_fasta_file = "data/mirna.fasta"
mrna_fasta_file  = "data/mrna.fasta"
level            = "gene"
train_file       = "data/train/train_set.csv"
save_file        = "weights.pt"
seed_match       = "10-mer-m6"
batch_size       = 64
epochs           = 3        
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

start_time = datetime.now()
print("\n[START] {}".format(start_time.strftime('%Y-%m-%d @ %H:%M:%S')))

model = deepTarget()
train_model(mirna_fasta_file, mrna_fasta_file, train_file,
            model=model,
            seed_match=seed_match, level=level,
            batch_size=batch_size, epochs=epochs,
            save_file=save_file, device=device)

finish_time = datetime.now()
print("\n[FINISH] {} (user time: {})\n".format(finish_time.strftime('%Y-%m-%d @ %H:%M:%S'), (finish_time - start_time)))

    


[START] 2021-03-09 @ 18:44:47

[TRAIN] model_f1-32_k1-3_f2-16_k2-3_f3-64_k3-3_f4-16_k4-3


Epoch 1/3: 1023 batches
Epoch 2/3: 1023 batches
Epoch 3/3: 1023 batches



[FINISH] 2021-03-09 @ 18:45:31 (user time: 0:00:43.989682)



# Predict
Using the provided test set, I feed the model unseen data to check its predictions.

In [125]:
query_file       = "data/test/test_split_2.csv"
weight_file      = 'weights.pt'
output_file      = "results.csv"

start_time = datetime.now()
print("\n[START] {}".format(start_time.strftime('%Y-%m-%d @ %H:%M:%S')))

model = deepTarget()
results = predict_result(mirna_fasta_file, mrna_fasta_file, query_file,
                         model=model, weight_file=weight_file,
                         seed_match=seed_match, level=level,
                         output_file=output_file, device=device)

finish_time = datetime.now()
print("\n[FINISH] {} (user time: {})\n".format(finish_time.strftime('%Y-%m-%d @ %H:%M:%S'), (finish_time - start_time)))


[START] 2021-03-09 @ 18:45:31
              MIRNA_ID       MRNA_ID   LOCATION  PROBABILITY  PREDICT
0        hsa-let-7c-5p     NM_005373  1511,1540     0.960640        1
100    hsa-miR-106b-3p  NM_001099285    484,513     0.362340        0
113        hsa-miR-107     NM_002970    131,160     0.997281        1
126     hsa-miR-10a-5p  NM_001105541    227,256     0.192497        0
154     hsa-miR-10a-5p  NM_001166687    356,385     0.551413        1
...                ...           ...        ...          ...      ...
60205   hsa-miR-214-3p     NM_018254      -1,-1     0.000000        0
60206     hsa-miR-4523  NM_001164391      -1,-1     0.000000        0
60207  hsa-miR-4690-5p     NM_012265      -1,-1     0.000000        0
60208   hsa-miR-497-5p     NM_004662      -1,-1     0.000000        0
60209      hsa-miR-922     NM_006923      -1,-1     0.000000        0

[803 rows x 5 columns]

[FINISH] 2021-03-09 @ 18:46:01 (user time: 0:00:30.092920)



# Evaluate
Since the dataset is balanced, I evaluate the model against the benchmark with common binary-classification metrics, where TP, FP, FN and TN are the numbers of true positives, false positives, false negatives and true negatives, respectively:

- accuracy = (TP + TN)/(TP + TN + FP + FN)
- sensitivity = TP/(TP + FN)
- specificity = TN/(TN + FP)
- F-measure = 2TP/(2TP + FP + FN)
- positive predictive value (PPV) = TP/(TP + FP)
- negative predictive value (NPV) = TN/(TN + FN)

Finally, I benchmark these scores against the paper.
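The formulas above can be collected into one small helper so that every test split is scored the same way. This is a sketch (the name `binary_metrics` is mine); plugging in the `test_split_2` counts reported further down (307, 329, 88, 79) reproduces the printed accuracy of about 0.792.

```python
def binary_metrics(tp, tn, fp, fn):
    """Return the evaluation metrics used in this notebook from confusion-matrix counts."""
    return {
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
        "sensitivity": tp / (tp + fn),        # recall on true targets
        "specificity": tn / (tn + fp),        # recall on non-targets
        "F-measure": 2 * tp / (2 * tp + fp + fn),
        "PPV": tp / (tp + fp),                # precision
        "NPV": tn / (tn + fn),
    }


split2 = binary_metrics(307, 329, 88, 79)
```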

In [126]:
#Creating pandas dataframes for the predictions and the actual test set
predictions=pd.read_csv("results.csv", header=0, sep='\t')
test_set=pd.read_csv("data/test/test_split_2.csv", header=None, sep='\t')

In [127]:
# adding a 'miRNA+mRNA' key to both DataFrames so they can be matched
predictions['key'] = predictions['MIRNA_ID'] + predictions['MRNA_ID']

test_set['key'] = test_set[0] + test_set[1]
test_set.index = test_set['key'].values
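Concatenating the two IDs into one string works here because every mRNA ID starts with `NM_`, so the boundary is unambiguous, but pandas can also join on the two ID columns directly. A minimal sketch with made-up toy rows (the `toy_*` frames are illustrative, not project data):

```python
import pandas as pd

# Toy stand-ins for results.csv and a header-less test split (values are made up).
toy_predictions = pd.DataFrame({
    "MIRNA_ID": ["miR-a", "miR-a", "miR-b"],
    "MRNA_ID": ["NM_1", "NM_2", "NM_1"],
    "PREDICT": [1, 0, 1],
})
toy_test_set = pd.DataFrame([["miR-a", "NM_1", 1],
                             ["miR-a", "NM_2", 1],
                             ["miR-b", "NM_1", 0]])

# Name the header-less columns, then join on both IDs instead of a string key.
truth = toy_test_set.rename(columns={0: "MIRNA_ID", 1: "MRNA_ID", 2: "labels"})
merged = toy_predictions.merge(truth, on=["MIRNA_ID", "MRNA_ID"], how="left",
                               validate="many_to_one")  # fails if a pair repeats in the test split
```

`how="left"` keeps every prediction row in its original order, and `validate` turns a silently duplicated test pair into an explicit error.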

In [128]:
predictions

| | MIRNA_ID | MRNA_ID | LOCATION | PROBABILITY | PREDICT | key |
|---|---|---|---|---|---|---|
| 0 | hsa-let-7c-5p | NM_005373 | 1511,1540 | 0.960640 | 1 | hsa-let-7c-5pNM_005373 |
| 1 | hsa-miR-106b-3p | NM_001099285 | 484,513 | 0.362340 | 0 | hsa-miR-106b-3pNM_001099285 |
| 2 | hsa-miR-107 | NM_002970 | 131,160 | 0.997281 | 1 | hsa-miR-107NM_002970 |
| 3 | hsa-miR-10a-5p | NM_001105541 | 227,256 | 0.192497 | 0 | hsa-miR-10a-5pNM_001105541 |
| 4 | hsa-miR-10a-5p | NM_001166687 | 356,385 | 0.551413 | 1 | hsa-miR-10a-5pNM_001166687 |
| ... | ... | ... | ... | ... | ... | ... |
| 798 | hsa-miR-214-3p | NM_018254 | -1,-1 | 0.000000 | 0 | hsa-miR-214-3pNM_018254 |
| 799 | hsa-miR-4523 | NM_001164391 | -1,-1 | 0.000000 | 0 | hsa-miR-4523NM_001164391 |
| 800 | hsa-miR-4690-5p | NM_012265 | -1,-1 | 0.000000 | 0 | hsa-miR-4690-5pNM_012265 |
| 801 | hsa-miR-497-5p | NM_004662 | -1,-1 | 0.000000 | 0 | hsa-miR-497-5pNM_004662 |


In [129]:
test_set

| | 0 | 1 | 2 | key |
|---|---|---|---|---|
| hsa-let-7a-5pNM_005373 | hsa-let-7a-5p | NM_005373 | 0 | hsa-let-7a-5pNM_005373 |
| hsa-let-7b-5pNM_001291429 | hsa-let-7b-5p | NM_001291429 | 0 | hsa-let-7b-5pNM_001291429 |
| hsa-let-7c-5pNM_005373 | hsa-let-7c-5p | NM_005373 | 0 | hsa-let-7c-5pNM_005373 |
| hsa-miR-1-3pNM_001292001 | hsa-miR-1-3p | NM_001292001 | 0 | hsa-miR-1-3pNM_001292001 |
| hsa-miR-1-3pNM_001354579 | hsa-miR-1-3p | NM_001354579 | 0 | hsa-miR-1-3pNM_001354579 |
| ... | ... | ... | ... | ... |
| hsa-miR-98-5pNM_001144999 | hsa-miR-98-5p | NM_001144999 | 1 | hsa-miR-98-5pNM_001144999 |
| hsa-miR-98-5pNM_001256310 | hsa-miR-98-5p | NM_001256310 | 1 | hsa-miR-98-5pNM_001256310 |
| hsa-miR-98-5pNM_001284388 | hsa-miR-98-5p | NM_001284388 | 1 | hsa-miR-98-5pNM_001284388 |
| hsa-miR-98-5pNM_006662 | hsa-miR-98-5p | NM_006662 | 1 | hsa-miR-98-5pNM_006662 |


In [130]:
# adding the actual label (column 2 of the test split) to the predictions
predictions['labels'] = predictions['key'].map(test_set[2])

In [131]:
# labelling each prediction as TP, TN, FN or FP
pred, truth = predictions['PREDICT'], predictions['labels']
predictions['result'] = np.select(
    [(pred == 1) & (truth == 1), (pred == 0) & (truth == 0), truth == 1],
    ['TP', 'TN', 'FN'],
    default='FP')

In [132]:
#a look at the final dataset
predictions

| | MIRNA_ID | MRNA_ID | LOCATION | PROBABILITY | PREDICT | key | labels | result |
|---|---|---|---|---|---|---|---|---|
| 0 | hsa-let-7c-5p | NM_005373 | 1511,1540 | 0.960640 | 1 | hsa-let-7c-5pNM_005373 | 0 | FP |
| 1 | hsa-miR-106b-3p | NM_001099285 | 484,513 | 0.362340 | 0 | hsa-miR-106b-3pNM_001099285 | 0 | TN |
| 2 | hsa-miR-107 | NM_002970 | 131,160 | 0.997281 | 1 | hsa-miR-107NM_002970 | 0 | FP |
| 3 | hsa-miR-10a-5p | NM_001105541 | 227,256 | 0.192497 | 0 | hsa-miR-10a-5pNM_001105541 | 0 | TN |
| 4 | hsa-miR-10a-5p | NM_001166687 | 356,385 | 0.551413 | 1 | hsa-miR-10a-5pNM_001166687 | 0 | FP |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 798 | hsa-miR-214-3p | NM_018254 | -1,-1 | 0.000000 | 0 | hsa-miR-214-3pNM_018254 | 1 | FN |
| 799 | hsa-miR-4523 | NM_001164391 | -1,-1 | 0.000000 | 0 | hsa-miR-4523NM_001164391 | 1 | FN |
| 800 | hsa-miR-4690-5p | NM_012265 | -1,-1 | 0.000000 | 0 | hsa-miR-4690-5pNM_012265 | 1 | FN |
| 801 | hsa-miR-497-5p | NM_004662 | -1,-1 | 0.000000 | 0 | hsa-miR-497-5pNM_004662 | 1 | FN |


To compute the metrics, let's first count the true positives, true negatives, false positives and false negatives.

In [133]:
counts = predictions['result'].value_counts()
# .get(..., 0) avoids a KeyError if one outcome never occurs
TP, TN, FP, FN = (counts.get(k, 0) for k in ['TP', 'TN', 'FP', 'FN'])
print('true positives={}, true negatives={}, false positives={}, false negatives={}'.format(TP, TN, FP, FN))

true positives=307, true negatives=329, false positives=88, false negatives=79


We can now calculate the key metrics:

In [134]:
print('accuracy = ',(TP + TN)/(TP + TN + FP + FN))
print('sensitivity =', TP/(TP + FN))
print('specificity =', TN/(TN + FP))
print('F-measure =', 2*TP/(2*TP + FP + FN))
print('PPV =', TP/(TP + FP))
print('NPV =', TN/(TN + FN))

accuracy =  0.7920298879202988
sensitivity = 0.7953367875647669
specificity = 0.7889688249400479
F-measure = 0.7861715749039693
PPV = 0.7772151898734178
NPV = 0.8063725490196079


Let's now run the same process on all 10 test sets to build the final score matrix and benchmark against the paper.

In [135]:
test_queries = ["data/test/test_split_{}.csv".format(i) for i in range(10)]

In [136]:
matrix = pd.DataFrame()
matrix['set'] = ['accuracy', 'sensitivity', 'specificity', 'F-measure', 'PPV', 'NPV']
# one loop variable per purpose (the original reused `i` for the file and the row index)
for x, query_file in enumerate(test_queries, start=1):
    output_file = "results.csv"
    model = deepTarget()
    predict_result(mirna_fasta_file, mrna_fasta_file, query_file,
                   model=model, weight_file=weight_file,
                   seed_match=seed_match, level=level,
                   output_file=output_file, device=device)
    predictions = pd.read_csv(output_file, header=0, sep='\t')
    test_set = pd.read_csv(query_file, header=None, sep='\t')

    # same 'miRNA+mRNA' key as above, used to look up each pair's true label
    predictions['key'] = predictions['MIRNA_ID'] + predictions['MRNA_ID']
    test_set.index = (test_set[0] + test_set[1]).values
    predictions['labels'] = predictions['key'].map(test_set[2])

    # confusion-matrix counts
    pred, truth = predictions['PREDICT'], predictions['labels']
    TP = int(((pred == 1) & (truth == 1)).sum())
    TN = int(((pred == 0) & (truth == 0)).sum())
    FP = int(((pred == 1) & (truth == 0)).sum())
    FN = int(((pred == 0) & (truth == 1)).sum())
    matrix['set{}'.format(x)] = [(TP + TN)/(TP + TN + FP + FN),
                                TP/(TP + FN),
                                TN/(TN + FP),
                                2*TP/(2*TP + FP + FN),
                                TP/(TP + FP),
                                TN/(TN + FN)]
    print(x)
matrix = matrix.transpose()
matrix.columns = matrix.iloc[0]
score_matrix = matrix.drop('set').astype(float)

              MIRNA_ID       MRNA_ID   LOCATION  PROBABILITY  PREDICT
0        hsa-let-7c-5p     NM_005373  1511,1540     0.960640        1
100    hsa-miR-106b-3p  NM_001099285    484,513     0.362340        0
113        hsa-miR-107     NM_002970    131,160     0.997281        1
126     hsa-miR-10a-5p  NM_001105541    227,256     0.192497        0
154     hsa-miR-10a-5p  NM_001166687    356,385     0.551413        1
...                ...           ...        ...          ...      ...
59645   hsa-miR-497-5p  NM_001242825      -1,-1     0.000000        0
59646  hsa-miR-548d-3p  NM_001144755      -1,-1     0.000000        0
59647  hsa-miR-548d-3p     NM_021127      -1,-1     0.000000        0
59648  hsa-miR-548d-3p     NM_024524      -1,-1     0.000000        0
59649      hsa-miR-922  NM_001244604      -1,-1     0.000000        0

[796 rows x 5 columns]
1
              MIRNA_ID       MRNA_ID   LOCATION  PROBABILITY  PREDICT
0        hsa-let-7c-5p     NM_005373  1511,1540     0.960640    

Now let's average each metric over the 10 test sets:

In [137]:
# calculating the average of each metric over the 10 test sets
averages = score_matrix.astype(float).mean().to_frame('average').T
# DataFrame.append was removed in pandas 2.0; pd.concat is the replacement
pd.concat([score_matrix, averages])

| set | accuracy | sensitivity | specificity | F-measure | PPV | NPV |
|---|---|---|---|---|---|---|
| set1 | 0.792714 | 0.796834 | 0.788969 | 0.785436 | 0.774359 | 0.810345 |
| set2 | 0.80661 | 0.825 | 0.788969 | 0.806846 | 0.789474 | 0.824561 |
| set3 | 0.79203 | 0.795337 | 0.788969 | 0.786172 | 0.777215 | 0.806373 |
| set4 | 0.789668 | 0.790404 | 0.788969 | 0.785445 | 0.780549 | 0.798544 |
| set5 | 0.792593 | 0.796438 | 0.788969 | 0.788413 | 0.780549 | 0.804401 |
| set6 | 0.783042 | 0.776623 | 0.788969 | 0.774611 | 0.77261 | 0.792771 |
| set7 | 0.79703 | 0.805627 | 0.788969 | 0.793451 | 0.781638 | 0.812346 |
| set8 | 0.789474 | 0.790026 | 0.788969 | 0.781818 | 0.773779 | 0.804401 |
| set9 | 0.79375 | 0.798956 | 0.788969 | 0.787645 | 0.77665 | 0.810345 |
| set10 | 0.801471 | 0.814536 | 0.788969 | 0.800493 | 0.786925 | 0.816377 |
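The average row produced by the previous cell did not survive in the output above. As a quick check, the accuracy column can be averaged directly from the ten printed values: they come to about 0.794, with a sample standard deviation of roughly 0.0066 across splits. The snippet below only re-uses the numbers shown in the table.

```python
from statistics import mean, stdev

# Accuracy column of the score matrix above (set1 ... set10).
accuracy = [0.792714, 0.80661, 0.79203, 0.789668, 0.792593,
            0.783042, 0.79703, 0.789474, 0.79375, 0.801471]

mean_accuracy = mean(accuracy)
spread = stdev(accuracy)  # sample standard deviation across the 10 splits
print(round(mean_accuracy, 4), round(spread, 4))
```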


Let's compare with the performance reported in the paper:
<img src='evaluation_metrics_deepTarget.png' style='width:800px' align='left'/>

To conclude, this model achieves slightly higher accuracy than the deepTarget paper while reducing the computational effort and time required for training: it reaches over 79% accuracy in just 3 epochs, compared with the 10 epochs used in the paper.