# Deep learning model to identify target mRNA of microRNA sequences

From a review of currently developed deep learning models in genomics, I realized there is still a long way to go before machine learning techniques are applied to genomics to their full potential. In particular, I was fascinated by the deepTarget model, which was recently proposed as a way to identify microRNA targets with 96% accuracy, so I decided to focus my project on this field. The repository of this work is available on GitHub at the following link: https://github.com/ailab-seoultech/deepTarget.

MicroRNAs (miRNAs), which are small non-coding RNA molecules of about 22 nucleotides, are known to regulate more than 60% of the protein-coding genes of humans and other mammals at the RNA level. As miRNAs control the function of their target messenger RNAs (mRNAs) by regulating the expression of those targets, investigating miRNAs is important for understanding various biological processes, including diseases. Numerous computational tools have been proposed to predict the targets of given miRNAs.

## Problem Statement
Two types of computational problems about miRNAs naturally arise in bioinformatics: miRNA host identification (i.e., locating the genes that encode pre-miRNAs) and miRNA target prediction (i.e., finding the mRNA targets of given miRNAs). This project focuses on target prediction: the aim is to build a model that recognizes whether or not a given mRNA is a target of a given miRNA. The output is a binary classification of miRNA-mRNA pairs into matching and non-matching pairs (i.e., whether or not the mRNA is a target of the miRNA).




## Getting ready
Copying the dataset from the GitHub repository of the paper's authors and importing the right libraries
Let's first install the right versions of the software we need and import the required libraries

In [1]:
! git clone https://github.com/ailab-seoultech/deepTarget.git

fatal: destination path 'deepTarget' already exists and is not an empty directory.


In [2]:
! sh deepTarget/download.sh

Downloading data.tar.gz...
Extracting data.tar.gz...


In [3]:
pip install biopython

Collecting biopython
  Downloading biopython-1.78-cp36-cp36m-manylinux1_x86_64.whl (2.3 MB)
Installing collected packages: biopython
Successfully installed biopython-1.78
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/python3/bin/python -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [4]:
pip install torch

Collecting torch
  Downloading torch-1.7.1-cp36-cp36m-manylinux1_x86_64.whl (776.8 MB)
Collecting dataclasses
  Downloading dataclasses-0.8-py3-none-any.whl (19 kB)
Installing collected packages: dataclasses, torch
Successfully installed dataclasses-0.8 torch-1.7.1
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/python3/bin/python -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [5]:
pip install regex

Collecting regex
  Downloading regex-2020.11.13-cp36-cp36m-manylinux2014_x86_64.whl (723 kB)
Installing collected packages: regex
Successfully installed regex-2020.11.13
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/python3/bin/python -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.


In [6]:
import json  # needed by HyperParam when a model is loaded from a JSON string
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from tqdm.auto import tqdm
bar_format = '{desc} |{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}{postfix}]'
import sklearn
from sklearn.utils.class_weight import compute_class_weight
from datetime import datetime

import argparse
import regex
import Bio
from Bio import SeqIO
from Bio.Seq import Seq


In [7]:
pip install nomkl

ERROR: Could not find a version that satisfies the requirement nomkl
ERROR: No matching distribution found for nomkl
You should consider upgrading via the '/home/ec2-user/anaconda3/envs/python3/bin/python -m pip install --upgrade pip' command.
Note: you may need to restart the kernel to use updated packages.


# Pre-processing
I will now define a set of functions that convert the raw files (FASTA sequences and tab-separated lists of pairs) into inputs that can be fed to the model



## Dataset
The dataset includes:
* two '.fasta' files containing the miRNA and mRNA sequences
* a training set with labelled pairs of miRNA and mRNA sequences
* a test set containing only the miRNA and mRNA identifiers, without labels

Let's first define a basic function to read the '.fasta' files


In [8]:
def read_fasta(mirna_fasta_file, mrna_fasta_file):
    # read the miRNA and mRNA '.fasta' files and return the ids and
    # sequences of both
    mirna_list = list(SeqIO.parse(mirna_fasta_file, 'fasta'))
    mrna_list = list(SeqIO.parse(mrna_fasta_file, 'fasta'))
    
    mirna_ids = []
    mirna_seqs = []
    mrna_ids = []
    mrna_seqs = []
    
    for i in range(len(mirna_list)):
        mirna_ids.append(str(mirna_list[i].id))
        mirna_seqs.append(str(mirna_list[i].seq))
    
    for i in range(len(mrna_list)):
        mrna_ids.append(str(mrna_list[i].id))
        mrna_seqs.append(str(mrna_list[i].seq))
    
    return mirna_ids, mirna_seqs, mrna_ids, mrna_seqs


mirna_fasta_file = "data/mirna.fasta"
mrna_fasta_file  = "data/mrna.fasta"
mirna_ids, mirna_seqs, mrna_ids, mrna_seqs = read_fasta(mirna_fasta_file,mrna_fasta_file)

for i in range(5):
    print(f"mirna_ID : {mirna_ids[i]}. , mirna_SEQUENCE : {mirna_seqs[i]}")

print("\n")

for i in range(5):
    print(f"mrna_ID : {mrna_ids[i]}. , mrna_SEQUENCE : {mrna_seqs[i][:10]}")


mirna_ID : hsa-miR-4777-5p. , mirna_SEQUENCE : UUCUAGAUGAGAGAUAUAUAUA
mirna_ID : hsa-miR-3908. , mirna_SEQUENCE : GAGCAAUGUAGGUAGACUGUUU
mirna_ID : hsa-miR-96-3p. , mirna_SEQUENCE : AAUCAUGUGCAGUGCCAAUAUG
mirna_ID : hsa-miR-3144-5p. , mirna_SEQUENCE : AGGGGACCAAAGAGAUAUAUAG
mirna_ID : hsa-miR-6509-3p. , mirna_SEQUENCE : UUCCACUGCCACUACCUAAUUU


mrna_ID : NM_003629. , mrna_SEQUENCE : AGAGGAAGUG
mrna_ID : NM_001135041. , mrna_SEQUENCE : GCACUCCUUU
mrna_ID : NM_001256461. , mrna_SEQUENCE : AUGUUCUAUA
mrna_ID : NM_005371. , mrna_SEQUENCE : CUGCUUACUC
mrna_ID : NM_175734. , mrna_SEQUENCE : CCAGCAGGCG
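For readers without Biopython, the parsing step can be sketched with the standard library alone. This is a minimal sketch under stated assumptions, not a replacement for `SeqIO.parse`: it assumes plain FASTA (header lines starting with `>`, sequences possibly wrapped over several lines) and, like Biopython's `record.id`, keeps only the first whitespace-separated token of each header. The function name `parse_fasta_text` is hypothetical.

```python
import io


def parse_fasta_text(handle):
    """Return (ids, seqs) from a FASTA text stream (hypothetical helper)."""
    ids, seqs = [], []
    chunks = None  # sequence lines of the record currently being read
    for line in handle:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if chunks is not None:
                seqs.append("".join(chunks))
            header = line[1:].split()
            # keep only the first token of the header as the identifier
            ids.append(header[0] if header else "")
            chunks = []
        elif chunks is not None:
            chunks.append(line)  # wrapped sequence lines are concatenated
    if chunks is not None:
        seqs.append("".join(chunks))
    return ids, seqs


example = io.StringIO(
    ">hsa-miR-3908 some description\nGAGCAAUGUAGG\nUAGACUGUUU\n>NM_003629\nAGAGGAAGUG\n"
)
ids, seqs = parse_fasta_text(example)
print(ids, seqs)
```

The two wrapped lines of the first record are joined back into the full 22-nt miRNA sequence shown in the output above.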


In [9]:
def nucleotide_to_int(nucleotides, max_len):
    # map each base to an integer (A=1, C=2, G=3, T/U=4); 0 is reserved for padding
    dictionary = {'A':1, 'C':2, 'G':3, 'T':4, 'U':4}
    
    chars = []
    nucleotides = nucleotides.upper()
    for c in nucleotides:
        chars.append(c)
    
    ints_enc = np.full((max_len,), fill_value=0) # to post-pad inputs
    for i in range(len(chars)):
        try:
            ints_enc[i] = dictionary[chars[i]]
        except KeyError:
            continue
        except IndexError:
            break
        
    return ints_enc


seq = mrna_seqs[1]
max_len = len(seq)

ints_enc = nucleotide_to_int(seq,max_len)

print(ints_enc)

[3 2 1 2 4 2 2 4 4 4 2 2 2 2 4 3 2 4 3 4 2 2 2 2 4 4 2 3 1 2 2 2 4 2 1 3 2
 2 2 4 2 4 3 3 4 3 2 2 3 2 4 2 4 3 2 2 2 3 1 4 3 2 1 2 1 3 2 2 1 2 2 4 2 1
 3 2 2 1 3 2 2 2 2 2 1 3 3 4 1 3 1 1 1 2 3 4 3 3 3 4 4 1 1 3 2 4 2 4 4 2 2
 4 3 2 2 2 2 3 4 4 2 1 3 2 4 4 2 1 2 4 2 2 2 1 2 2 2 4 4 4 2 1 3 2 3 4 2 2
 4 3 2 2 2 2 4 4 2 1 2 2 4 4 3 1 2 2 2 3 3 3 4 4 2 2 2 2 2 1 2 4 2 2 2 1 4
 4 2 2 2 4 3 3 2 2 4 2 4 3 2 2 1 4 1 1 4 4 4 3 4 4 3 4 4 2 1 1 2 4 3 2 4 2
 2 2 4 2 2 4 4 2 2 4 3 1 3 3 3 3 2 2 4 2 1 3 3 3 2 4 4 3 4 3 3 3 3 3 3 4 1
 3 3 2 4 3 1 3 1 2 2 2 2 1 2 2 1 2 2 1 1 1 3 3 4 4 1 1 3 4 3 1 3 3 4 2 2 2
 2 4 4 3 1 4 4 3 1 3 3 1 2 4 4 2 1 2 2 2 2 4 4 3 1 4 4 1 1 1 3 2 1 1 2 4 4
 2 4 3 2 4 4 2 1 3 4 3 2]
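The encoding rule of `nucleotide_to_int` can be restated in a few lines of plain Python, which makes two details easy to check: sequences longer than `max_len` are truncated, and any symbol outside the dictionary (for example `N`) keeps the value 0, so it becomes indistinguishable from padding. This is a sketch; `encode` and `BASE_TO_INT` are hypothetical names mirroring the notebook's function and dictionary.

```python
BASE_TO_INT = {"A": 1, "C": 2, "G": 3, "T": 4, "U": 4}


def encode(seq, max_len):
    """Post-pad with 0 and truncate at max_len, like nucleotide_to_int."""
    codes = [0] * max_len
    for i, base in enumerate(seq.upper()[:max_len]):
        # unknown symbols such as 'N' stay 0, the same value as padding
        codes[i] = BASE_TO_INT.get(base, 0)
    return codes


print(encode("acgun", 7))  # [1, 2, 3, 4, 0, 0, 0]
print(encode("ACGU", 2))   # [1, 2]
```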


In [10]:
def sequence_to_int(sequences, max_len):
    # translate entire RNA sequences into int
    import itertools
    
    if type(sequences) is list:
        seqs_enc = np.asarray([nucleotide_to_int(seq, max_len) for seq in sequences])
    else:
        seqs_enc = np.asarray([nucleotide_to_int(seq, max_len) for seq in sequences])
        seqs_enc = list(itertools.chain(*seqs_enc))
        seqs_enc = np.asarray(seqs_enc)
    
    return seqs_enc

x=sequence_to_int(mrna_seqs, 10)
print (x[1])

[3 2 1 2 4 2 2 4 4 4]


In [11]:
def pad_sequences(sequences, max_len=None, padding='pre', fill_value=0):
    # pad (or truncate) sequences with fill_value so that they all have the same length
    n_samples = len(sequences)
    
    lengths = []
    for seq in sequences:
        try:
            lengths.append(len(seq))
        except TypeError:
            raise ValueError("sequences expected a list of iterables, got {}".format(seq))
    if max_len is None:
        max_len = np.max(lengths)
    
    input_shape = np.asarray(sequences[0]).shape[1:]
    padded_shape = (n_samples, max_len) + input_shape
    padded = np.full(padded_shape, fill_value=fill_value)
    
    for i, seq in enumerate(sequences):
        if padding == 'pre':
            truncated = seq[-max_len:]
            padded[i, -len(truncated):] = truncated
        elif padding == 'post':
            truncated = seq[:max_len]
            padded[i, :len(truncated)] = truncated
        else:
            raise ValueError("padding expected 'pre' or 'post', got {}".format(padding))
    
    return padded

x=pad_sequences(x, max_len=100)
print (x[1])

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 2 1 2 4 2 2 4 4 4]
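With `padding='pre'`, `pad_sequences` keeps the last `max_len` elements and fills on the left; with `'post'` it keeps the first `max_len` elements and fills on the right. That explains the output above: `x[1]` was already cut to 10 values by `sequence_to_int`, so padding to 100 adds 90 leading zeros. The list-based sketch below (hypothetical name `pad`) reproduces the same rule so both modes can be compared on a tiny input.

```python
def pad(seqs, max_len, padding="pre", fill_value=0):
    """Pad/truncate every sequence to max_len (sketch of pad_sequences)."""
    padded = []
    for seq in seqs:
        seq = list(seq)
        if padding == "pre":
            kept = seq[max(0, len(seq) - max_len):]  # keep the tail
            padded.append([fill_value] * (max_len - len(kept)) + kept)
        elif padding == "post":
            kept = seq[:max_len]  # keep the head
            padded.append(kept + [fill_value] * (max_len - len(kept)))
        else:
            raise ValueError("padding expected 'pre' or 'post', got {}".format(padding))
    return padded


print(pad([[1, 2, 3], [4]], 2, "pre"))   # [[2, 3], [0, 4]]
print(pad([[1, 2, 3], [4]], 2, "post"))  # [[1, 2], [4, 0]]
```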


In [12]:
def one_hot(ints):
    #one hot encoding for nucleotides
    dictionary_k = 5 # 4 bases + index 0, which is used for padding
    ints_len = len(ints)
    ints_enc = np.zeros((ints_len, dictionary_k))
    ints_enc[np.arange(ints_len), [k for k in ints]] = 1
    ints_enc = ints_enc[:, 1:5] # to handle zero-padded values
    ints_enc = ints_enc.tolist()
    
    return (ints_enc)

def one_hot_enc(seqs_enc):
    #one hot encoding for sequences    
    one_hot_encs = []
    
    for i in range(len(seqs_enc)):
        one_hot_encs.append(one_hot(seqs_enc[i]))
    
    one_hot_encs = np.array(one_hot_encs)
    
    return one_hot_encs

b=one_hot(x[1])
print (b[95:])

[[0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]
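The reason `one_hot` builds 5 columns and then keeps only columns 1 to 4 is that code 0 (padding, or an unknown base) should not light up any channel. A minimal pure-Python sketch of the same idea (hypothetical name `one_hot_rows`):

```python
def one_hot_rows(codes, n_symbols=5):
    """One-hot encode integer codes, then drop column 0 (padding)."""
    rows = []
    for code in codes:
        row = [0.0] * n_symbols
        row[code] = 1.0
        rows.append(row[1:])  # padding (0) becomes an all-zero row
    return rows


print(one_hot_rows([1, 0, 4]))
# [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
```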


In [13]:
def to_categorical(labels, n_classes=None):
    # convert integer labels into a one-hot matrix with one column per class
    labels = np.array(labels, dtype='int').reshape(-1)

    n_samples = labels.shape[0]
    if not n_classes:
        n_classes = np.max(labels) + 1

    categorical = np.zeros((n_samples, n_classes))
    categorical[np.arange(n_samples), labels] = 1
    
    return categorical

labels=[0,1,1,1,1,0,2,3,4,0,1,2,3]
labels=to_categorical(labels)
print(labels)

[[1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1.]
 [1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]]


In [14]:
def preprocess_data(x_query_seqs, x_target_seqs, y=None, cts_size=None, pre_padding=False):
    # get one-hot encoded data from miRNA and mRNA sequences
    if cts_size is not None:
        max_len = cts_size
    else:
        max_len = max(len(max(x_query_seqs, key=len)), len(max(x_target_seqs, key=len)))
    
    x_mirna = sequence_to_int(x_query_seqs, max_len)
    x_mrna = sequence_to_int(x_target_seqs, max_len)
    
    if pre_padding:
        x_mirna = pad_sequences(x_mirna, max_len, padding='pre')
        x_mrna = pad_sequences(x_mrna, max_len, padding='pre')
    
    x_mirna_embd = one_hot_enc(x_mirna)
    x_mrna_embd = one_hot_enc(x_mrna)
    if y is not None:
        y_embd = to_categorical(y, np.unique(y).size)
        
        return x_mirna_embd, x_mrna_embd, y_embd
    else:
        return x_mirna_embd, x_mrna_embd

In [None]:
x_mirna_embd, x_mrna_embd = preprocess_data(mirna_seqs, mrna_seqs)

print(x_mirna_embd, x_mrna_embd)


Let's now define how matching can happen.
Following the paper, I use relaxed site patterns, which cover most of the canonical site types (CSTs), non-canonical site types (NSTs), and context-dependent non-canonical site types (CDNSTs), to define the candidate target sites (CTSs). The patterns used are as follows (WC = Watson-Crick):
* 10-mer-m6: six WC pairings from the miRNA nucleotides 1–10
* 10-mer-m7: seven WC pairings from the miRNA nucleotides 1–10
* Offset 9-mer-m7: seven WC pairings from the miRNA nucleotides 2–10
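The three relaxed patterns differ only in the seed window and the minimum number of pairings; `find_candidate` turns each of them into an error budget `TOLERANCE = (SEED_END - SEED_START + 1) - MIN_MATCH`. The sketch below tabulates that budget and, as a simplification, scans the mRNA allowing only substitutions (a Hamming-distance check) against the complemented seed, reporting positions with the same slice-stop convention as `find_candidate`. The names `SITE_PATTERNS`, `tolerance` and `substitution_only_sites` are hypothetical. Note that the notebook itself uses the `regex` module's fuzzy pattern `{e<=k}`, which also counts insertions and deletions as errors, so it can accept sites that this sketch rejects.

```python
# Hypothetical table: pattern -> (first seed position, last seed position, minimum matches), 1-based
SITE_PATTERNS = {
    "10-mer-m6": (1, 10, 6),
    "10-mer-m7": (1, 10, 7),
    "offset-9-mer-m7": (2, 10, 7),
}

# Base-wise RNA complement, analogous to Seq(seed).complement() in find_candidate
RNA_COMPLEMENT = str.maketrans("ACGU", "UGCA")


def tolerance(pattern):
    """Number of allowed errors: window length minus required matches."""
    start, end, min_match = SITE_PATTERNS[pattern]
    return (end - start + 1) - min_match


def substitution_only_sites(mirna, mrna, pattern):
    """Slice-stop positions where the complemented seed matches with <= tolerance mismatches."""
    start, end, _ = SITE_PATTERNS[pattern]
    target = mirna[start - 1:end].translate(RNA_COMPLEMENT)
    tol = tolerance(pattern)
    stops = []
    for i in range(len(mrna) - len(target) + 1):
        window = mrna[i:i + len(target)]
        mismatches = sum(a != b for a, b in zip(window, target))
        if mismatches <= tol:
            # same convention as find_candidate: match end plus SEED_OFFSET (start - 1)
            stops.append(i + len(target) + (start - 1))
    return stops


for name in SITE_PATTERNS:
    print(name, tolerance(name))  # 4, 3 and 2 allowed errors respectively
```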



In [11]:
def find_candidate(mirna_sequence, mrna_sequence, seed_match):
    # find candidate site positions, allowing up to TOLERANCE errors in the seed match
    positions = set()
    
    if seed_match == '10-mer-m6':
        SEED_START = 1
        SEED_END = 10
        SEED_OFFSET = SEED_START - 1
        MIN_MATCH = 6
        TOLERANCE = (SEED_END-SEED_START+1) - MIN_MATCH
    elif seed_match == '10-mer-m7':
        SEED_START = 1
        SEED_END = 10
        SEED_OFFSET = SEED_START - 1
        MIN_MATCH = 7
        TOLERANCE = (SEED_END-SEED_START+1) - MIN_MATCH
    elif seed_match == 'offset-9-mer-m7':
        SEED_START = 2
        SEED_END = 10
        SEED_OFFSET = SEED_START - 1
        MIN_MATCH = 7
        TOLERANCE = (SEED_END-SEED_START+1) - MIN_MATCH
    elif seed_match == 'strict':
        positions = find_strict_candidate(mirna_sequence, mrna_sequence)
        
        return positions
    else:
        raise ValueError("seed_match expected 'strict', '10-mer-m6', '10-mer-m7', or 'offset-9-mer-m7', got '{}'".format(seed_match))
    
    seed = mirna_sequence[(SEED_START-1):SEED_END]
    rc_seed = str(Seq(seed).complement())
    match_iter = regex.finditer("({}){{e<={}}}".format(rc_seed, TOLERANCE), mrna_sequence)
    
    for match_index in match_iter:
        #positions.add(match_index.start()) # slice-start indicies
        positions.add(match_index.end()+SEED_OFFSET) # slice-stop indicies
    
    positions = list(positions)
    
    return positions



In [None]:
positions=find_candidate(mirna_seqs[1], mrna_seqs[1], seed_match = '10-mer-m6')
print(positions)

In [None]:


def find_strict_candidate(mirna_sequence, mrna_sequence):
    # find candidate site positions using exact (strict) seed matches

    positions = set()
    
    SEED_TYPES = ['8-mer', '7-mer-m8', '7-mer-A1', '6-mer', '6-mer-A1', 'offset-7-mer', 'offset-6-mer']
    for seed_match in SEED_TYPES:
        if seed_match == '8-mer':
            SEED_START = 2
            SEED_END = 8
            SEED_OFFSET = 0
            seed = 'U' + mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == '7-mer-m8':
            SEED_START = 1
            SEED_END = 8
            SEED_OFFSET = 0
            seed = mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == '7-mer-A1':
            SEED_START = 2
            SEED_END = 7
            SEED_OFFSET = 0
            seed = 'U' + mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == '6-mer':
            SEED_START = 2
            SEED_END = 7
            SEED_OFFSET = 1
            seed = mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == '6-mer-A1':  # must match the name used in SEED_TYPES
            SEED_START = 2
            SEED_END = 6
            SEED_OFFSET = 0
            seed = 'U' + mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == 'offset-7-mer':
            SEED_START = 3
            SEED_END = 9
            SEED_OFFSET = 0
            seed = mirna_sequence[(SEED_START-1):SEED_END]
        elif seed_match == 'offset-6-mer':
            SEED_START = 3
            SEED_END = 8
            SEED_OFFSET = 0
            seed = mirna_sequence[(SEED_START-1):SEED_END]
        
        rc_seed = str(Seq(seed).complement())
        match_iter = regex.finditer(rc_seed, mrna_sequence)
        
        for match_index in match_iter:
            #positions.add(match_index.start()) # slice-start indicies
            positions.add(match_index.end()+SEED_OFFSET) # slice-stop indicies
    
    positions = list(positions)
    
    return positions


def get_candidate(mirna_sequence, mrna_sequence, cts_size, seed_match):
    #using the find_candidate function we can find actual candidates and positions
    positions = find_candidate(mirna_sequence, mrna_sequence, seed_match)
    
    candidates = []
    for i in positions:
        site_sequence = mrna_sequence[max(0, i-cts_size):i]
        rev_site_sequence = site_sequence[::-1]
        rc_site_sequence = str(Seq(rev_site_sequence).complement())
        candidates.append(rev_site_sequence) # miRNAs: 5'-ends to 3'-ends,  mRNAs: 3'-ends to 5'-ends
        #candidates.append(rc_site_sequence)
    
    return candidates, positions


def make_pair(mirna_sequence, mrna_sequence, cts_size, seed_match):
    #and finally identify mirna_querys and mrna_targets
    candidates, positions = get_candidate(mirna_sequence, mrna_sequence, cts_size, seed_match)
    
    mirna_querys = []
    mrna_targets = []
    if len(candidates) == 0:
        return (mirna_querys, mrna_targets, positions)
    else:
        mirna_sequence = mirna_sequence[0:cts_size]
        for i in range(len(candidates)):
            mirna_querys.append(mirna_sequence)
            mrna_targets.append(candidates[i])
        
    return mirna_querys, mrna_targets, positions
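The slicing in `get_candidate` takes the `cts_size` bases that end at each candidate position and reverses them, so the mRNA window is read 3' to 5' against the miRNA read 5' to 3'. Near the start of the mRNA the window is shorter than `cts_size`; such windows are later post-padded with zeros by `nucleotide_to_int`. A minimal sketch of the slicing (hypothetical name `cut_site`):

```python
def cut_site(mrna, stop, cts_size):
    """The cts_size bases ending at `stop`, reversed (as in get_candidate)."""
    site = mrna[max(0, stop - cts_size):stop]
    return site[::-1]


mrna = "AACCGGUUAC"
print(cut_site(mrna, 8, 5))  # UUGGC
print(cut_site(mrna, 2, 5))  # AA (shorter than cts_size near the 5' end)
```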

We also need some functions to read the training and test files 

In [None]:
def read_ground_truth(ground_truth_file, header=True, train=True):
    # read the training and test files containing pairs of miRNA-mRNA
    # input format: [miRNA_ID, mRNA_ID, LABEL]
    if header is True:
        records = pd.read_csv(ground_truth_file, header=0, sep='\t')
    else:
        records = pd.read_csv(ground_truth_file, header=None, sep='\t')
    
    query_ids = np.asarray(records.iloc[:, 0].values)
    target_ids = np.asarray(records.iloc[:, 1].values)
    if train is True:
        labels = np.asarray(records.iloc[:, 2].values)
    else:
        labels = np.full((len(records),), fill_value=-1)
    
    return query_ids, target_ids, labels


def make_input_pair(mirna_fasta_file, mrna_fasta_file, ground_truth_file, cts_size=30, seed_match='offset-9-mer-m7', header=True, train=True):
    #from sequences, ids and ground truth we generate the dataset
    mirna_ids, mirna_seqs, mrna_ids, mrna_seqs = read_fasta(mirna_fasta_file, mrna_fasta_file)
    query_ids, target_ids, labels = read_ground_truth(ground_truth_file, header=header, train=train)
    
    dataset = {
        'mirna_fasta_file': mirna_fasta_file,
        'mrna_fasta_file': mrna_fasta_file,
        'ground_truth_file': ground_truth_file,
        'query_ids': [],
        'query_seqs': [],
        'target_ids': [],
        'target_seqs': [],
        'target_locs': [],
        'labels': []
    }
    
    for i in range(len(query_ids)):
        try:
            j = mirna_ids.index(query_ids[i])
        except ValueError:
            continue
        try:
            k = mrna_ids.index(target_ids[i])
        except ValueError:
            continue
        
        query_seqs, target_seqs, locations = make_pair(mirna_seqs[j], mrna_seqs[k], cts_size=cts_size, seed_match=seed_match)
        
        n_pairs = len(locations)
        if n_pairs > 0:
            queries = [query_ids[i] for n in range(n_pairs)]
            dataset['query_ids'].extend(queries)
            dataset['query_seqs'].extend(query_seqs)
            
            targets = [target_ids[i] for n in range(n_pairs)]
            dataset['target_ids'].extend(targets)
            dataset['target_seqs'].extend(target_seqs)
            dataset['target_locs'].extend(locations)
            
            dataset['labels'].extend([[labels[i]] for p in range(n_pairs)])
        
    return dataset


def make_brute_force_pair(mirna_fasta_file, mrna_fasta_file, cts_size=30, seed_match='offset-9-mer-m7'):
    mirna_ids, mirna_seqs, mrna_ids, mrna_seqs = read_fasta(mirna_fasta_file, mrna_fasta_file)
    
    dataset = {
        'query_ids': [],
        'query_seqs': [],
        'target_ids': [],
        'target_seqs': [],
        'target_locs': []
    }
    
    for i in range(len(mirna_ids)):
        for j in range(len(mrna_ids)):
            query_seqs, target_seqs, positions = make_pair(mirna_seqs[i], mrna_seqs[j], cts_size, seed_match)
            
            n_pairs = len(positions)
            if n_pairs > 0:
                query_ids = [mirna_ids[i] for k in range(n_pairs)]
                dataset['query_ids'].extend(query_ids)
                dataset['query_seqs'].extend(query_seqs)
                
                target_ids = [mrna_ids[j] for k in range(n_pairs)]
                dataset['target_ids'].extend(target_ids)
                dataset['target_seqs'].extend(target_seqs)
                dataset['target_locs'].extend(positions)
    
    return dataset
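`make_input_pair` turns each labelled miRNA-mRNA pair into one sample per candidate site, and every site inherits the label of its pair; pairs with no candidate site contribute nothing at this stage. The sketch below isolates that expansion with plain tuples. `expand_to_sites` and the `sites` dictionary are hypothetical stand-ins for the `make_pair` results.

```python
def expand_to_sites(pairs, sites):
    """pairs: (mirna_id, mrna_id, label); sites: (mirna_id, mrna_id) -> list of stop positions."""
    rows = []
    for mirna_id, mrna_id, label in pairs:
        for stop in sites.get((mirna_id, mrna_id), []):
            rows.append((mirna_id, mrna_id, stop, label))  # each site copies the pair label
    return rows


pairs = [("m1", "t1", 1), ("m2", "t2", 0)]
sites = {("m1", "t1"): [12, 40], ("m2", "t2"): []}
print(expand_to_sites(pairs, sites))
# [('m1', 't1', 12, 1), ('m1', 't1', 40, 1)]
```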

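Let's now define the helpers that collect miRNA-mRNA pairs without any candidate site (treated as negatives) and that turn the model outputs into a results table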

In [None]:
def get_negative_pair(mirna_fasta_file, mrna_fasta_file, ground_truth_file=None, cts_size=30, seed_match='offset-9-mer-m7', header=False, predict_mode=True):
    #prepare dataset with query, target and predictions
    mirna_ids, mirna_seqs, mrna_ids, mrna_seqs = read_fasta(mirna_fasta_file, mrna_fasta_file)
    
    dataset = {
        'query_ids': [],
        'target_ids': [],
        'predicts': []
    }
    
    if ground_truth_file is not None:
        query_ids, target_ids, labels = read_ground_truth(ground_truth_file, header=header)
        
        for i in range(len(query_ids)):
            try:
                j = mirna_ids.index(query_ids[i])
            except ValueError:
                continue
            try:
                k = mrna_ids.index(target_ids[i])
            except ValueError:
                continue

            query_seqs, target_seqs, locations = make_pair(mirna_seqs[j], mrna_seqs[k], cts_size=cts_size, seed_match=seed_match)

            n_pairs = len(locations)
            if (n_pairs == 0) and (predict_mode is True):
                dataset['query_ids'].append(query_ids[i])
                dataset['target_ids'].append(target_ids[i])
                dataset['predicts'].append(0)
            elif (n_pairs == 0) and (predict_mode is False):
                dataset['query_ids'].append(query_ids[i])
                dataset['target_ids'].append(target_ids[i])
                dataset['predicts'].append(labels[i])
    else:
        for i in range(len(mirna_ids)):
            for j in range(len(mrna_ids)):
                query_seqs, target_seqs, locations = make_pair(mirna_seqs[i], mrna_seqs[j], cts_size=cts_size, seed_match=seed_match)

                n_pairs = len(locations)
                if n_pairs == 0:
                    dataset['query_ids'].append(mirna_ids[i])                
                    dataset['target_ids'].append(mrna_ids[j])
                    dataset['predicts'].append(0)
    
    dataset['target_locs'] = [-1 for i in range(len(dataset['query_ids']))]
    dataset['probabilities'] = [0.0 for i in range(len(dataset['query_ids']))]
    
    return dataset


def postprocess_result(dataset, probabilities, predicts, predict_mode=True, output_file=None, cts_size=30, seed_match='offset-9-mer-m7', level='site'):
    neg_pairs = get_negative_pair(dataset['mirna_fasta_file'], dataset['mrna_fasta_file'], dataset['ground_truth_file'], cts_size=cts_size, seed_match=seed_match, predict_mode=predict_mode)
    
    query_ids = np.append(dataset['query_ids'], neg_pairs['query_ids'])
    target_ids = np.append(dataset['target_ids'], neg_pairs['target_ids'])
    target_locs = np.append(dataset['target_locs'], neg_pairs['target_locs'])
    probabilities = np.append(probabilities, neg_pairs['probabilities'])
    predicts = np.append(predicts, neg_pairs['predicts'])
    
    # output format: [QUERY, TARGET, LOCATION, PROBABILITY]
    records = pd.DataFrame(columns=['MIRNA_ID', 'MRNA_ID', 'LOCATION', 'PROBABILITY'])
    records['MIRNA_ID'] = query_ids
    records['MRNA_ID'] = target_ids
    records['LOCATION'] = np.array(["{},{}".format(max(1, l-cts_size+1), l) if l != -1 else "-1,-1" for l in target_locs])
    records['PROBABILITY'] = probabilities
    if predict_mode is True:
        records['PREDICT'] = predicts
    else:
        records['LABEL'] = predicts
    
    records = records.sort_values(by=['PROBABILITY', 'MIRNA_ID', 'MRNA_ID'], ascending=[False, True, True])
    unique_records = records.sort_values(by=['PROBABILITY', 'MIRNA_ID', 'MRNA_ID'], ascending=[False, True, True]).drop_duplicates(subset=['MIRNA_ID', 'MRNA_ID'], keep='first')
    
    if level == 'site':
        if output_file is not None:
            records.to_csv(output_file, index=False, sep='\t')
        
        return records
    elif level == 'gene':
        if output_file is not None:
            unique_records.to_csv(output_file, index=False, sep='\t')
        
        return unique_records
    else:
        raise ValueError("level expected 'site' or 'gene', got '{}'".format(level))
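`postprocess_result` reports each site as a 1-based `start,stop` window (or `-1,-1` for pairs without a site) and, at `level='gene'`, keeps only the highest-probability site per miRNA-mRNA pair. The sketch below restates both rules with plain tuples instead of a DataFrame; `format_location` and `gene_level` are hypothetical names, and sorting by `(-probability, miRNA, mRNA)` mirrors the `sort_values(..., ascending=[False, True, True])` call.

```python
def format_location(stop, cts_size=30):
    """1-based 'start,stop' window, or '-1,-1' for pairs without a site."""
    if stop == -1:
        return "-1,-1"
    return "{},{}".format(max(1, stop - cts_size + 1), stop)


def gene_level(records):
    """records: (mirna_id, mrna_id, location, probability); keep the best site per pair."""
    best = {}
    for rec in sorted(records, key=lambda r: (-r[3], r[0], r[1])):
        best.setdefault((rec[0], rec[1]), rec)  # first seen = highest probability
    return list(best.values())


records = [("m1", "t1", "1,30", 0.2), ("m1", "t1", "11,40", 0.9), ("m2", "t1", "-1,-1", 0.0)]
print(format_location(40), format_location(10), format_location(-1))  # 11,40 1,10 -1,-1
print(gene_level(records))
```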

Time to set up the dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
    # subclass torch.utils.data.Dataset so that a DataLoader can batch the samples
    def __init__(self, mirna_fasta_file, mrna_fasta_file, ground_truth_file, cts_size=30, seed_match='offset-9-mer-m7', header=True, train=True):
        # build candidate-site pairs and encode them with the functions above
        self.dataset = make_input_pair(mirna_fasta_file, mrna_fasta_file, ground_truth_file, cts_size=cts_size, seed_match=seed_match, header=header, train=train)
        self.mirna, self.mrna = preprocess_data(self.dataset['query_seqs'], self.dataset['target_seqs'])
        self.labels = np.asarray(self.dataset['labels']).reshape(-1,)
        
        self.mirna = self.mirna.transpose((0, 2, 1))
        self.mrna = self.mrna.transpose((0, 2, 1))
        
    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
            
        mirna = self.mirna[index]
        mrna = self.mrna[index]
        label = self.labels[index]
        
        return (mirna, mrna), label
        
    def __len__(self):
        return len(self.labels)


class TrainDataset(torch.utils.data.Dataset):
    # same interface, built from a training file that already contains the sequences
    def __init__(self, ground_truth_file, cts_size=30):
        #initialize with train set (ground_truth)
        self.records = pd.read_csv(ground_truth_file, header=0, sep='\t')
        mirna_seqs = self.records['MIRNA_SEQ'].values.tolist()
        mrna_seqs = self.records['MRNA_SEQ'].values.tolist()
        self.mirna, self.mrna = preprocess_data(mirna_seqs, mrna_seqs, cts_size=cts_size)
        self.labels = self.records['LABEL'].values.astype(int)
        
        self.mirna = self.mirna.transpose((0, 2, 1))
        self.mrna = self.mrna.transpose((0, 2, 1))
    
    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        
        mirna = self.mirna[index]
        mrna = self.mrna[index]
        label = self.labels[index]
        
        return (mirna, mrna), label
        
    def __len__(self):
        return len(self.labels)
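Both dataset classes call `transpose((0, 2, 1))` because `one_hot_enc` produces arrays shaped `(N, L, 4)` (samples, positions, bases), while PyTorch's `nn.Conv1d` expects channels first, `(N, C, L)`. A small NumPy check of that layout change, with an arbitrary batch of 8 windows of length 30:

```python
import numpy as np

n_samples, seq_len = 8, 30
one_hot_batch = np.zeros((n_samples, seq_len, 4))  # layout from one_hot_enc: (N, L, 4)
one_hot_batch[0, 5, 2] = 1.0  # sample 0, position 5, base index 2 ('G')

channels_first = one_hot_batch.transpose((0, 2, 1))  # layout for nn.Conv1d: (N, C, L)
print(channels_first.shape)      # (8, 4, 30)
print(channels_first[0, 2, 5])   # 1.0: the same entry, now indexed as (sample, channel, position)
```

Each item returned by `__getitem__` is therefore a pair of `(4, L)` arrays, which the `DataLoader` stacks into `(batch_size, 4, L)` tensors.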

# Model

In [None]:
class HyperParam:
    def __init__(self, filters=None, kernels=None, model_json=None):
        self.dictionary = dict()
        self.name_postfix = str()

        if (filters is not None) and (kernels is not None) and (model_json is None):
            for i, (f, k) in enumerate(zip(filters, kernels)):
                setattr(self, 'f{}'.format(i+1), f)
                setattr(self, 'k{}'.format(i+1), k)
                self.dictionary.update({'f{}'.format(i+1): f, 'k{}'.format(i+1): k})
            self.len = i+1
                
            for key, value in self.dictionary.items():
                self.name_postfix = "{}_{}-{}".format(self.name_postfix, key, value)
        elif model_json is not None:
            self.dictionary = json.loads(model_json)
            for i, (key, value) in enumerate(self.dictionary.items()):
                setattr(self, key, value)
                self.name_postfix = "{}_{}-{}".format(self.name_postfix, key, value)
            self.len = (i+1)//2
    
    def __len__(self):
        return self.len


class deepTarget(nn.Module):
    def __init__(self, hparams=None, hidden_units=30, input_shape=(1, 4, 30), name_prefix="model"):
        super(deepTarget, self).__init__()
        
        if hparams is None:
            filters, kernels = [32, 16, 64, 16], [3, 3, 3, 3]
            hparams = HyperParam(filters, kernels)
        self.name = "{}{}".format(name_prefix, hparams.name_postfix)
        
        if (isinstance(hparams, HyperParam)) and (len(hparams) == 4):
            self.embd1 = nn.Conv1d(4, hparams.f1, kernel_size=hparams.k1, padding=((hparams.k1 - 1) // 2))
            self.conv2 = nn.Conv1d(hparams.f1*2, hparams.f2, kernel_size=hparams.k2)
            self.conv3 = nn.Conv1d(hparams.f2, hparams.f3, kernel_size=hparams.k3)
            self.conv4 = nn.Conv1d(hparams.f3, hparams.f4, kernel_size=hparams.k4)
            
            """ out_features = ((in_length - kernel_size + (2 * padding)) / stride + 1) * out_channels """
            flat_features = self.forward(torch.rand(input_shape), torch.rand(input_shape), flat_check=True)
            self.fc1 = nn.Linear(flat_features, hidden_units)
            self.fc2 = nn.Linear(hidden_units, 2)
        else:
            raise ValueError("not enough hyperparameters")
    
    def forward(self, x_mirna, x_mrna, flat_check=False):
        h_mirna = F.relu(self.embd1(x_mirna))
        h_mrna = F.relu(self.embd1(x_mrna))
        
        h = torch.cat((h_mirna, h_mrna), dim=1)
        h = F.relu(self.conv2(h))
        h = F.relu(self.conv3(h))
        h = F.relu(self.conv4(h))
        
        h = h.view(h.size(0), -1)
        if flat_check:
            return h.size(1)
        h = self.fc1(h)
        y = self.fc2(h) #y = F.softmax(self.fc2(h), dim=1)
        
        return y
    
    def size(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
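The `flat_check` forward pass in `deepTarget.__init__` only measures how many features reach `fc1`. The same number can be worked out by hand with the `nn.Conv1d` output-length formula: `embd1` keeps the length (padding `(k-1)//2` for odd `k`), the concatenation doubles the channels but not the length, and each unpadded convolution shortens the sequence by `k - 1`. A sketch of that arithmetic with the default hyperparameters (function names are hypothetical):

```python
def conv1d_out_len(length, kernel_size, padding=0, stride=1, dilation=1):
    # Output-length formula documented for torch.nn.Conv1d
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flat_features(seq_len=30, filters=(32, 16, 64, 16), kernels=(3, 3, 3, 3)):
    # embd1: padded, so the length is preserved for odd kernel sizes
    length = conv1d_out_len(seq_len, kernels[0], padding=(kernels[0] - 1) // 2)
    # conv2..conv4: no padding, each one removes kernel_size - 1 positions
    for k in kernels[1:]:
        length = conv1d_out_len(length, k)
    return length * filters[-1]  # channels of conv4 times remaining length


print(flat_features())  # 384 = 24 positions * 16 channels
```

Because `fc1` is sized from `input_shape=(1, 4, 30)`, inputs whose length differs from 30 would not match its input size.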

# Training
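`train_model` weights `nn.CrossEntropyLoss` with `compute_class_weight('balanced', ...)`, so errors on the rarer class cost more. scikit-learn documents the `'balanced'` heuristic as `n_samples / (n_classes * count_of_class)`; the standard-library sketch below (hypothetical name `balanced_class_weights`) applies that formula to a toy label list with three negatives and one positive.

```python
from collections import Counter


def balanced_class_weights(labels):
    # 'balanced' heuristic: n_samples / (n_classes * count_of_class)
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}


print(balanced_class_weights([0, 0, 0, 1]))  # class 0 -> 4/6, class 1 -> 2.0
```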

In [None]:
def train_model(mirna_fasta_file, mrna_fasta_file, train_file, model=None, cts_size=30, seed_match='offset-9-mer-m7', level='gene', batch_size=32, epochs=10, save_file=None, device='cpu'):
    if not isinstance(model, deepTarget):
        raise ValueError("'model' expected <nn.Module 'deepTarget'>, got {}".format(type(model)))
    
    print("\n[TRAIN] {}".format(model.name))
    
    if train_file.split('/')[-1] == 'train_set.csv':
        train_set = TrainDataset(train_file)
    else:
        train_set = Dataset(mirna_fasta_file, mrna_fasta_file, train_file, seed_match=seed_match, header=True, train=True)

    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)


    
    class_weight = torch.Tensor(compute_class_weight('balanced', classes=np.unique(train_set.labels), y=train_set.labels)).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weight)
    optimizer = optim.Adam(model.parameters())
    
    model = model.to(device)
    for epoch in range(epochs):
        epoch_loss, corrects = 0, 0

        with tqdm(train_loader, desc="Epoch {}/{}".format(epoch+1, epochs), bar_format=bar_format) as tqdm_loader:
            for i, ((mirna, mrna), label) in enumerate(tqdm_loader):
                mirna, mrna, label = mirna.to(device, dtype=torch.float), mrna.to(device, dtype=torch.float), label.to(device)
                
                outputs = model(mirna, mrna)
                loss = criterion(outputs, label)
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                epoch_loss += loss.item() * outputs.size(0)
                corrects += (torch.max(outputs, 1)[1] == label).sum().item()
                
                if (i+1) == len(train_loader):
                    tqdm_loader.set_postfix(dict(loss=(epoch_loss/len(train_set)), acc=(corrects/len(train_set))))
                else:
                    tqdm_loader.set_postfix(loss=loss.item())
    
    if save_file is None:
        time = datetime.now()
        save_file = "{}.pt".format(time.strftime('%Y%m%d_%H%M%S_weights'))
    torch.save(model.state_dict(), save_file)
    
    
def predict_result(mirna_fasta_file, mrna_fasta_file, query_file, model=None, weight_file=None, seed_match='offset-9-mer-m7', level='gene', output_file=None, device='cpu'):
    if not isinstance(model, deepTarget):
        raise ValueError("'model' expected <nn.Module 'deepTarget'>, got {}".format(type(model)))
    
    if weight_file is None or not weight_file.endswith('.pt'):
        raise ValueError("'weight_file' expected '*.pt', got {}".format(weight_file))
    
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine
    model.load_state_dict(torch.load(weight_file, map_location=device))
    
    test_set = Dataset(mirna_fasta_file, mrna_fasta_file, query_file, seed_match=seed_match, header=True, train=False)
    
    model = model.to(device)
    with torch.no_grad():
        model.eval()
        
        mirna = torch.from_numpy(test_set.mirna).to(device, dtype=torch.float)
        mrna = torch.from_numpy(test_set.mrna).to(device, dtype=torch.float)
        label = torch.from_numpy(test_set.labels).to(device)
        
        outputs = model(mirna, mrna)
        _, predicts = torch.max(outputs.data, 1)
        probabilities = F.softmax(outputs, dim=1)
        
        y_probs = probabilities.cpu().numpy()[:, 1]
        y_predicts = predicts.cpu().numpy()
        y_truth = label.cpu().numpy()
        
        if output_file is None:
            time = datetime.now()
            output_file = "{}.csv".format(time.strftime('%Y%m%d_%H%M%S_results'))
        results = postprocess_result(test_set.dataset, y_probs, y_predicts,
                                     seed_match=seed_match, level=level, output_file=output_file)
        
        print(results)
        return results
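`train_model` passes the output of scikit-learn's `compute_class_weight('balanced', ...)` to `nn.CrossEntropyLoss`. The 'balanced' heuristic gives each class the weight `n_samples / (n_classes * count_c)`, so rare classes (here, presumably, the positive miRNA-mRNA pairs) count more in the loss. The sketch below re-implements that formula in NumPy to make the arithmetic explicit. `balanced_class_weight` is a hypothetical helper, and the label vector is invented for illustration.

```python
import numpy as np

def balanced_class_weight(labels):
    """Re-implementation of the 'balanced' heuristic:
    weight_c = n_samples / (n_classes * count_c), with classes sorted as np.unique returns them."""
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    weights = len(labels) / (len(classes) * counts)
    return classes, weights

# Hypothetical, imbalanced label vector: 6 negatives (0) and 2 positives (1)
labels = [0, 0, 0, 0, 0, 0, 1, 1]
classes, weights = balanced_class_weight(labels)
# class 0: 8 / (2 * 6) ≈ 0.667, class 1: 8 / (2 * 2) = 2.0
print(dict(zip(classes.tolist(), weights.tolist())))
```

With these weights, every class contributes the same total weight to the loss (6 × 0.667 = 2 × 2.0 = 4). Because the weights come back in sorted class order, position `i` matches class index `i` in `CrossEntropyLoss`, provided the labels are the integers `0..C-1`.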

In [None]:
def foo(mode):
    """Train deepTarget or run predictions ('train' | 'predict') using the paths configured below."""
    mirna_fasta_file = "data/mirna.fasta"
    mrna_fasta_file  = "data/mrna.fasta"
    seed_match       = "offset-9-mer-m7"
    level            = "gene"
    train_file       = "data/train/train_set.csv"
    save_file        = "weights.pt"
    query_file       = "templates/query_set.csv"
    weight_file      = 'model/weights.pt'
    output_file      = "results.csv"
    batch_size       = 32
    epochs           = 10
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    if mode == 'train':

        start_time = datetime.now()
        print("\n[START] {}".format(start_time.strftime('%Y-%m-%d @ %H:%M:%S')))
        
        model = deepTarget()
        train_model(mirna_fasta_file, mrna_fasta_file, train_file,
                    model=model,
                    seed_match=seed_match, level=level,
                    batch_size=batch_size, epochs=epochs,
                    save_file=save_file, device=device)
        
        finish_time = datetime.now()
        print("\n[FINISH] {} (elapsed time: {})\n".format(finish_time.strftime('%Y-%m-%d @ %H:%M:%S'), (finish_time - start_time)))

    elif mode == 'predict':
        if query_file is None:
            raise ValueError("'query_file' expected '*.csv', got '{}'".format(query_file))
            
        start_time = datetime.now()
        print("\n[START] {}".format(start_time.strftime('%Y-%m-%d @ %H:%M:%S')))
        
        model = deepTarget()
        results = predict_result(mirna_fasta_file, mrna_fasta_file, query_file,
                                 model=model, weight_file=weight_file,
                                 seed_match=seed_match, level=level,
                                 output_file=output_file, device=device)
        
        finish_time = datetime.now()
        print("\n[FINISH] {} (elapsed time: {})\n".format(finish_time.strftime('%Y-%m-%d @ %H:%M:%S'), (finish_time - start_time)))
    else:
        raise ValueError("'mode' expected 'train' or 'predict', got '{}'".format(mode))


def main():
    foo("train")
    foo("predict")

import argparse

def parse_arguments(argv=None):
    # Command-line options for running the pipeline as a script. In a notebook, pass an
    # explicit list (e.g. ['--mode', 'train']), because sys.argv there holds Jupyter's own arguments.
    parser = argparse.ArgumentParser()
    
    parser.add_argument('--mode', dest='MODE', type=str, required=True,
                        help="run mode: [train|predict]")
    
    parser.add_argument('--mirna_file', dest='MIRNA_FASTA_FILE', type=str,
                        help="miRNA fasta file (default: data/miRNA.fasta)")
    parser.add_argument('--mrna_file', dest='MRNA_FASTA_FILE', type=str,
                        help="mRNA fasta file (default: data/mRNA.fasta)")
    parser.add_argument('--seed_match', dest='SEED_MATCH', type=str,
                        help="seed match type: [offset-9-mer-m7|10-mer-m7|10-mer-m6] (default: offset-9-mer-m7)")
    parser.add_argument('--level', dest='LEVEL', type=str,
                        help="prediction level: [gene|site] (default: gene)")
    
    parser.add_argument('--train_file', dest='TRAIN_FILE', type=str,
                        help="training file to be used in 'train' mode (sample: data/train_set.csv)")
    parser.add_argument('--save_file', dest='SAVE_FILE', type=str,
                        help="state_dict file to be saved in 'train' mode (default: yyyyMMdd_HHmmss_weights.pt)")
    parser.add_argument('--query_file', dest='QUERY_FILE', type=str,
                        help="query file to be queried in 'predict' mode (sample: templates/query_set.csv)")
    parser.add_argument('--weight_file', dest='WEIGHT_FILE', type=str,
                        help="state_dict file to be loaded in 'predict' mode (default: model/weights.pt)")
    parser.add_argument('--output_file', dest='OUTPUT_FILE', type=str,
                        help="output file to be saved in 'predict' mode (default: yyyyMMdd_HHmmss_results.csv)")
    
    parser.add_argument('--batch_size', dest='BATCH_SIZE', type=int,
                        help="batch size to be used in 'train' mode (default: 32)")
    parser.add_argument('--epochs', dest='EPOCHS', type=int,
                        help="epochs to be used in 'train' mode (default: 10)")

    args = parser.parse_args(argv)
    return args
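`foo` hard-codes its paths and hyperparameters, while `parse_arguments` describes the same settings as command-line flags. The sketch below shows how a reduced version of that interface could supply defaults that match `foo` and be driven from a notebook. It passes an explicit list to `parse_args`, a standard `argparse` feature, instead of reading `sys.argv`. `build_parser` and the chosen subset of options are assumptions for illustration. The `choices` lists are taken from the help strings above.

```python
import argparse

def build_parser():
    # Hypothetical subset of the CLI above, with defaults matching the values hard-coded in foo()
    parser = argparse.ArgumentParser(description="deepTarget train/predict (sketch)")
    parser.add_argument('--mode', choices=['train', 'predict'], required=True)
    parser.add_argument('--seed_match', default='offset-9-mer-m7',
                        choices=['offset-9-mer-m7', '10-mer-m7', '10-mer-m6'])
    parser.add_argument('--level', default='gene', choices=['gene', 'site'])
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=10)
    return parser

# In a notebook, pass an explicit argument list instead of relying on sys.argv
args = build_parser().parse_args(['--mode', 'train', '--epochs', '5'])
print(args.mode, args.epochs, args.batch_size)  # train 5 32
```

Using `choices` makes `argparse` reject an invalid mode or seed-match type before any data is loaded.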

In [None]:

main()

In [None]:
ls

In [None]:
class Pippo():
    def __init__(self,a):
        self.a = a
    def __call__(self):
        return self.a

In [None]:
asd = Pippo(3)

In [None]:
asd()

In [None]:
ls