### 0. Instruction
This notebook explains how to create the annotation file in the following format:

filename: annotations.npy

shape: (number of sequences, number of region types, sequence length)

        where the default number of region types is five, and the default sequence length is 101
        Indices correspond to each region type as: 0=5'UTR, 1=3'UTR, 2=exon, 3=intron, 4=CDS

As long as your resulting file has the above format and corresponding dev.tsv file, analysis can be conducted without any problems.

### 1. Create fasta files for RBPs from the benchmark file
The files generated using generate_datasets.py are in the tsv format. The code below generates the corresponding files in the fasta format.

In [2]:
import argparse
import os
import re
import sys
sys.path.append('./motif')
from motif_utils import seq2kmer, kmer2seq
import pandas as pd
import numpy as np

In [4]:
# Locations used by this notebook: the benchmark file to read and the
# directory that receives the generated datasets.
BENCHMARK_PATH = "THIS_IS_THE_PATH_TO_YOUR_BENCHMARK_FILE"
OUTPUT_PATH = "./datasets"

parser = argparse.ArgumentParser()

# Mandatory options: input benchmark file and output directory.
parser.add_argument("--path_to_benchmark", type=str, default=None, required=True,
                    help="path to the benchmark file")
parser.add_argument("--path_to_output", type=str, default=None, required=True,
                    help="path to the output directory")

# Optional tuning knobs, each with a default value.
parser.add_argument("--kmer", type=int, default=3, required=False,
                    help="kmer of the output file")
parser.add_argument("--max_num", type=int, default=15000, required=False,
                    help="maximum number of samples to retrieve")
parser.add_argument("--test_ratio", type=float, default=0.2, required=False,
                    help="ratio of test data")
parser.add_argument("--random_seed", type=int, default=0, required=False,
                    help="seed number for random sampling")

# Inside a notebook there is no real command line, so the argument list is
# passed explicitly instead of being read from sys.argv.
args = parser.parse_args(
    args=[
        "--path_to_benchmark", BENCHMARK_PATH,
        "--path_to_output", OUTPUT_PATH,
    ]
)

In [None]:
OUTPUT_FILE = 'original.fasta'

# Patterns shared by every section of the benchmark file.
_HEADER_RE = re.compile(r">chr[0-9XYM]+:[0-9]+\-[0-9]+\([\+\-]\)")
_SEQUENCE_RE = re.compile(r"[AUGCTN]+")


def _write_section(lines, i, handle, filename, value):
    """Copy one benchmark section, starting at ``lines[i]``, into ``handle``.

    Header lines are written with an ``rbp_bound`` label appended; sequence
    lines are written with every ``U`` replaced by ``T``.  Lines that are
    neither (or that are ambiguous) are reported on stdout and skipped.

    Returns the index of the first line that was not consumed: either the
    next line containing ``train_dir`` or ``len(lines)``.
    """
    while True:
        found_header = _HEADER_RE.findall(lines[i])
        found_sequence = _SEQUENCE_RE.findall(lines[i])
        if len(found_header) == 1:
            handle.write(found_header[0] + " rbp_bound: " + str(value) + '\n')
        elif len(found_sequence) == 1:
            handle.write(found_sequence[0].replace('U', 'T') + '\n')
        else:
            print("ERROR: rbp {} line #{}".format(filename, i))
            print(lines[i])
        i += 1
        if i >= len(lines) or 'train_dir' in lines[i]:
            return i


def createfasta(args):
    """Split the benchmark file into one ``original.fasta`` per RBP.

    The benchmark file is read as a series of sections, each introduced by a
    line containing ``train_dir`` and a ``<RBP>.positive`` / ``<RBP>.negative``
    tag.  The records of a section are written to
    ``<output_dir>/<RBP>/original.fasta`` with ``rbp_bound: 1`` for positive
    sections and ``rbp_bound: 0`` for negative ones.  The first section seen
    for an RBP overwrites the file; later sections for that RBP are appended.

    Args:
        args: namespace providing the input file and output directory, either
            as ``input_file`` / ``output_dir`` or as ``path_to_benchmark`` /
            ``path_to_output`` (the names defined by this notebook's parser).

    Returns:
        None.  Results are written to disk; problems are printed to stdout.
    """
    # Accept both attribute spellings so the function works with the parser
    # defined in this notebook as well as with the original attribute names.
    input_file = getattr(args, "input_file", None) or args.path_to_benchmark
    output_dir = getattr(args, "output_dir", None) or args.path_to_output

    with open(input_file) as f:
        lines = f.readlines()

    pattern = '([A-Z0-9]+.negative)|([A-Z0-9]+.positive)'
    rbplist = []
    i = 0
    while i < len(lines):
        line = lines[i]
        if 'train_dir' not in line:
            # A line outside any section: report it and move on.  Without the
            # increment the loop would never terminate.
            print('ERROR')
            i += 1
            continue

        filename = re.sub('[.]', '_', re.search(pattern, line).group())
        value = 1 if 'positive' in filename else 0
        filename = re.search('[A-Z0-9]+', filename).group()
        filepath = os.path.join(output_dir, filename, OUTPUT_FILE)

        # NOTE(review): the target file must already exist, exactly as in the
        # original code; presumably it is created beforehand -- confirm.
        if not os.path.exists(filepath):
            print("ERROR")
            i += 1  # skip this section header instead of looping forever
            continue

        mode = 'a' if filename in rbplist else 'w'
        with open(filepath, mode=mode) as f:
            i = _write_section(lines, i, f, filename, value)
        rbplist.append(filename)
    return

### 2. Create bed files from original.fasta

In [None]:
MASTER_DIR = "./datasets"
FASTA_NAME = "original.fasta"
# Leave empty to write each bed file next to the fasta file of its RBP.
BED_OUTPUT_DIR = ""
BED_NAME = "original.bed"

RBPS = ("AARS",)

for rbp in RBPS:
    fasta_file = os.path.join(MASTER_DIR, rbp, FASTA_NAME)
    # Resolve the output directory per RBP.  Overwriting BED_OUTPUT_DIR here
    # would make every later RBP write into the first RBP's directory.
    bed_dir = BED_OUTPUT_DIR or os.path.join(MASTER_DIR, rbp)
    bed_file = os.path.join(bed_dir, BED_NAME)

    # One "chrN:start-end" entry per fasta header.
    lines_bed = []
    with open(fasta_file, 'r') as f:
        for count, line in enumerate(f):
            if ">chr" not in line:
                continue
            found = re.findall(r">chr[0-9XYM]+:[0-9]+\-[0-9]+", line)
            if len(found) != 1:
                print('ERROR: found {} match at line #{}'.format(len(found), count))
                print(line)
            else:
                # drop the leading ">" of the fasta header
                lines_bed.append(found[0][1:])

    with open(bed_file, 'w') as f:
        f.write("\n".join(lines_bed))

### 3. Convert bed(hg19) to bed(hg38)
3-1. upload the generated bed file to the UCSC liftover page (https://genome.ucsc.edu/cgi-bin/hgLiftOver)

3-2. convert the version from "Feb. 2009 (GRCh37/hg19)" to "Dec. 2013 (GRCh38/hg38)" by setting the parameters as follows:

    Minimum ratio of bases that must remap: 1
    
3-3. submit your bed file

3-4. download the converted file and rename it as "original_hg38.bed"

3-5. browse the failed list of sequences and copy and paste them into the file named "fail.fasta"

### 4. Create original_hg38.fasta
Create fasta file with hg38 versions.

In [None]:
MASTER_DIR = "./datasets"
FASTA_NAME = "original.fasta"
FASTA_OUTPUT_NAME = "original_hg38.fasta"
FASTA_FAIL_NAME = "fail.fasta"
# Leave empty to read each lifted-over bed file from the directory of its RBP.
BED_OUTPUT_DIR = ""
BED_NAME = "original_hg38.bed"

RBPS = ("AARS",)

for rbp in RBPS:
    fasta_file = os.path.join(MASTER_DIR, rbp, FASTA_NAME)
    fasta_out_file = os.path.join(MASTER_DIR, rbp, FASTA_OUTPUT_NAME)
    fasta_fail = os.path.join(MASTER_DIR, rbp, FASTA_FAIL_NAME)
    # Resolve the bed directory per RBP.  Overwriting BED_OUTPUT_DIR here
    # would make every later RBP read the first RBP's bed file.
    bed_dir = BED_OUTPUT_DIR or os.path.join(MASTER_DIR, rbp)
    bed_file = os.path.join(bed_dir, BED_NAME)

    # Converted regions, in the same order as the non-failed fasta records.
    with open(bed_file, 'r') as f:
        lines_bed = f.readlines()

    # Regions that could not be converted; a set gives O(1) membership tests.
    fail_list = set()
    with open(fasta_fail, 'r') as f:
        for line in f:
            if re.match(r"chr[0-9XYM]+:[0-9]+\-[0-9]+", line):
                fail_list.add(line.strip())

    with open(fasta_file, 'r') as f:
        lines_fasta = f.readlines()

    lines_fasta_new = []
    idx_bed = 0
    flg = 0  # 1 while the sequence line of an accepted header is expected
    for line in lines_fasta:
        if ">chr" in line and flg == 0:
            bed_info = re.findall(r"chr[0-9XYM]+:[0-9]+\-[0-9]+", line)[0]
            bind_info = re.findall(r"\([\-\+]\) rbp_bound: [0-1]", line)[0]
            if bed_info in fail_list:
                # Failed conversion: drop the header; flg stays 0, so the
                # sequence line that follows is dropped as well.
                continue
            # Swap the hg19 coordinates for the next hg38 entry while keeping
            # the strand and rbp_bound label of the original header.
            lines_fasta_new.append(">" + lines_bed[idx_bed].strip() + bind_info)
            idx_bed += 1
            flg = 1
        elif flg == 1:
            lines_fasta_new.append(line.strip())
            flg = 0

    with open(fasta_out_file, "w") as f:
        f.write("\n".join(lines_fasta_new))

### 5. Create fasta from dev and original_hg38.fasta
Convert the dev.tsv file in the nontraining_sample_finetune directory to the fasta format.

In [None]:
MASTER_DIR = "./datasets"
SUBDIR1 = "nontraining_sample_finetune"
TSV_FILE = "dev.tsv"
FASTA_OUTPUT_FILE = "dev.fasta"
FASTA_REFERENCE_FILE = "original_hg38.fasta"

RBPS = ("AARS",)
for rbp in RBPS:
    # Make sure the output directory exists (the original called the
    # non-existent os.path.isdier and crashed here).
    os.makedirs(os.path.join(MASTER_DIR, rbp, SUBDIR1), exist_ok=True)
    fasta_output_file = os.path.join(MASTER_DIR, rbp, SUBDIR1, FASTA_OUTPUT_FILE)
    tsv_orig_file = os.path.join(MASTER_DIR, rbp, SUBDIR1, TSV_FILE)
    fasta_ref_file = os.path.join(MASTER_DIR, rbp, FASTA_REFERENCE_FILE)

    df_seqs = pd.read_csv(tsv_orig_file, sep="\t")
    print("len(df_seqs)", len(df_seqs))

    # Load the reference fasta as two parallel lists (header, sequence).
    lines_fasta_header = []
    lines_fasta_seq = []
    with open(fasta_ref_file, 'r') as f:
        flg = 0  # 1 while the sequence line of an accepted header is expected
        for line in f:
            if ">chr" in line:
                if re.findall(r"chr[0-9XYM]+:", line):
                    lines_fasta_header.append(line.strip())
                    flg = 1
                else:
                    flg = 0
            elif re.match("[ATGC]+", line):
                # Keep the sequence only when its header was accepted.
                if flg == 1:
                    lines_fasta_seq.append(line.strip())
                flg = 0
            else:
                print(line)

    assert len(lines_fasta_header) == len(lines_fasta_seq)

    fasta_new = []
    total = len(df_seqs)
    # Report progress roughly every 20%; never 0, which would make the
    # modulo below raise ZeroDivisionError for fewer than five sequences.
    keep_track = max(1, int(total * 0.2))
    for i, kmer_seq in enumerate(df_seqs["sequence"]):
        seq = kmer2seq(kmer_seq)
        # Search from position i onwards, as before, without copying the
        # tail of the list on every iteration.
        try:
            idx = lines_fasta_seq.index(seq, i)
        except ValueError:
            idx = None  # sequence not present in the reference: skipped
        if idx is not None:
            fasta_new.append(lines_fasta_header[idx])
            fasta_new.append(lines_fasta_seq[idx])
        if (i % keep_track == 0 and i != 0) or i == total - 1:
            print("finished converting {}/{} sequences".format(i + 1, total))

    print("len(fasta_new)", len(fasta_new))

    with open(fasta_output_file, 'w') as f:
        f.write("\n".join(fasta_new))

### 6. Download transcript annotations
6-1. download transcript annotations from the Ensembl website (https://www.ensembl.org/biomart/martview/41609dbda8c4191fad2eab93679401f8)

    6-1-1. Choose "Ensembl Genes 103" and "Human genes (GRCh38.p13)"
    6-1-2. In "Filters/REGION", select the chromosome 1
    6-1-3. In "Attributes", select "Features", and then check the following boxes in the same order:
        "Gene stable ID", "Transcript stable ID", "APPRIS annotation",
        "Transcript support level (TSL)", "Transcript length (including UTRs and CDS)"
    6-1-4. Download the result and save it as "chr1_info.txt"
    6-1-5. In "Attributes", select "Structures", and then check the following boxes in the same order:
        "Gene stable ID", "Transcript stable ID", "Transcript start (bp)",  "Transcript end (bp)",
         "5' UTR start",  "5' UTR end",  "3' UTR start",  "3' UTR end",  "Exon region start (bp)",  "Exon region end (bp)",  "Strand"
    6-1-6. Download the result and save it as "chr1_value.txt"
    6-1-7. In "Attributes", select "Sequences", and then check the following boxes in the same order:
        "Unspliced (Transcript)", "Gene stable ID", "Transcript stable ID"
    6-1-8. Download the results and save it as "chr1_seq.txt"
    6-1-9. Repeat 6-1-2 through 6-1-8 for chromosomes 2-22, MT, X, and Y (the code below expects the MT files to use the prefix "chrM", e.g. "chrM_info.txt")

6-2. convert the information into chrOO.csv by running the following command

In [None]:
import pandas as pd
import numpy as np
import os
from Bio import SeqIO
import re

def _appris_rank(appris):
    """Convert one APPRIS annotation into a sortable rank (lower is better).

    ``principalN`` maps to ``N`` and ``alternativeN`` to ``100 + N``; a NaN
    value is returned unchanged.  Anything else raises ``Exception``.
    """
    if isinstance(appris, str):
        match = re.fullmatch(r'(principal|alternative)([0-9]+)', appris)
        if match is None:
            print(appris)
            raise Exception("str type object did not match with APPRIS format")
        rank = int(match.group(2))
        return rank if match.group(1) == 'principal' else 100 + rank
    if np.isnan(appris):
        return appris
    print(appris)
    raise Exception("found exception for APPRIS annotation")


def _tsl_rank_and_version(tsl):
    """Split one TSL annotation into ``(rank, version)``.

    ``tslN`` gives ``('tslN', 0)``; ``tslN (assigned to previous version V)``
    gives ``('tslN', V)``; any string containing ``tslNA`` gives
    ``('tslNA', nan)``; a NaN value gives ``(nan, nan)``.  Anything else
    raises ``Exception``.
    """
    if isinstance(tsl, str):
        if 'tslNA' in tsl:
            return 'tslNA', np.nan
        match = re.fullmatch(
            r'(tsl[0-9]+)(?: \(assigned to previous version ([0-9]+)\))?', tsl)
        if match is None:
            print(tsl)
            raise Exception("str type object did not match with TSL format")
        version = 0 if match.group(2) is None else int(match.group(2))
        return match.group(1), version
    if np.isnan(tsl):
        return tsl, np.nan
    # Report the offending TSL value (the original printed a stale APPRIS
    # variable here).
    print(tsl)
    raise Exception("found exception for TSL")


def load_and_format_chr_info(info_path):
    """Load a ``*_info.txt`` table and add ranking columns to it.

    Args:
        info_path: path to a CSV file containing at least the columns
            ``APPRIS annotation`` and ``Transcript support level (TSL)``.

    Returns:
        The loaded DataFrame with three extra columns: ``APPRIS rank``,
        ``TSL rank`` and ``TSL version``.

    Raises:
        Exception: if an annotation value has an unexpected format.
    """
    df = pd.read_csv(info_path)

    # assign APPRIS rank to each transcript
    df['APPRIS rank'] = [_appris_rank(appris)
                         for appris in np.array(df['APPRIS annotation'])]

    # assign TSL rank and version to each transcript
    tsl_pairs = [_tsl_rank_and_version(tsl)
                 for tsl in np.array(df['Transcript support level (TSL)'])]
    df['TSL rank'] = [rank for rank, _ in tsl_pairs]
    df['TSL version'] = [version for _, version in tsl_pairs]

    return df

In [None]:
ANNOT_DIR = 'DIRECTORY_WHERE_THE_ANNOTATION_FILES_EXIST'

# Sex and mitochondrial chromosomes first, then the autosomes chr1..chr22.
# range(1, 23) is needed to include chr22; the previous upper bound of 22
# stopped at chr21.
CHROMOSOMES = ['chrX', 'chrY', 'chrM']
CHROMOSOMES.extend('chr{}'.format(num) for num in range(1, 23))
print(CHROMOSOMES)

# Column layout of the per-chromosome CSV written below.
COLUMN_NAMES = ["Gene_stable_ID", "Transcript_stable_ID", "Transcript_sequence",
           "Transcript_start_(bp)", "Transcript_end_(bp)",
           "5'_UTR_start", "5'_UTR_end", "3'_UTR_start", "3'_UTR_end",
           "Exon_region_start_(bp)", "Exon_region_end_(bp)", "Strand"]

# For every chromosome: load the downloaded sequence, coordinate ("value") and
# metadata ("info") tables, keep one transcript per gene, and save the combined
# record of that transcript into <ANNOT_DIR>/<chromosome>.csv.
for chromosome in CHROMOSOMES:
    print('processing {} ...'.format(chromosome))
    
    # read seq data
    # Each fasta record id is split on '|' and the sequence is appended, giving
    # rows of (gene id, transcript id, sequence).
    # NOTE(review): this assumes the record id holds exactly
    # "<gene id>|<transcript id>" -- confirm against the downloaded files.
    seq_path = os.path.join(ANNOT_DIR, '{}_seq.txt'.format(chromosome))
    seq_data = []
    for record in SeqIO.parse(seq_path, 'fasta'):
        tmp_data = record.id.split('|')
        tmp_data.append(str(record.seq))
        seq_data.append(tmp_data)
    df_seq = pd.DataFrame(seq_data, columns=['Gene_stable_ID','Transcript_stable_ID','Transcript_sequence'])
    print('\tnumber of transcripts in seq.txt:', len(df_seq))
    
    # read value data
    df_value = pd.read_csv(os.path.join(ANNOT_DIR, '{}_value.txt'.format(chromosome)))
    
    # read info data
    df_info = load_and_format_chr_info(os.path.join(ANNOT_DIR, '{}_info.txt'.format(chromosome)))
    unique_gene_ids = np.unique(np.array(df_info['Gene stable ID']))
    print('\tnumber of transcripts in info.txt:', len(df_info))
    print('\tnumber of unique gene ids in info.txt:', len(unique_gene_ids))
    # One row per gene, with the values appended in the order of COLUMN_NAMES.
    save_info = []
    for gene_id in unique_gene_ids:
        tmp_save = []
        # record gene id
        tmp_save.append(gene_id)
        
        # record the best transcript
        # Sort order: lowest APPRIS rank, then lowest TSL rank, then longest
        # transcript, then highest TSL version; the first row is selected.
        df_tmp = df_info[df_info['Gene stable ID']==gene_id].copy()
        df_tmp = df_tmp.sort_values(['APPRIS rank', 'TSL rank', 'Transcript length (including UTRs and CDS)', 'TSL version'], ascending=[True, True, False, False])
        # Column 1 is read positionally as the transcript id.
        # NOTE(review): relies on the column order of the info file -- confirm.
        transcript_id = df_tmp.iloc[0,1]
        tmp_save.append(transcript_id)
        
        # record sequence
        # .values[0] raises IndexError when the transcript is missing from seq.txt.
        sequence = df_seq.query('Gene_stable_ID=="{}" & Transcript_stable_ID=="{}"'.format(gene_id, transcript_id))['Transcript_sequence'].values[0]
        assert type(sequence)==str
        assert re.fullmatch('[ATGCN]+', sequence)
        tmp_save.append(sequence)
        
        # From here on df_tmp holds the value-table rows of the selected transcript.
        df_tmp = df_value[df_value['Transcript stable ID']==transcript_id]
        # record transcript info
        # The selected rows must agree on a single transcript start and end.
        ts_start = np.unique(df_tmp['Transcript start (bp)'].values).astype(int)
        ts_end = np.unique(df_tmp['Transcript end (bp)'].values).astype(int)
        if not len(ts_start)==1:
            print(gene_id, transcript_id)
        assert len(ts_start)==1
        assert len(ts_end)==1
        tmp_save.append(ts_start[0])
        tmp_save.append(ts_end[0])
        
        # record regiontype info
        # Each column becomes a list of integer positions with NaN entries removed.
        for tmp_column in ["5' UTR start", "5' UTR end", "3' UTR start", "3' UTR end", "Exon region start (bp)", "Exon region end (bp)"]:
            tmp_array = np.array(df_tmp[tmp_column])
            tmp_array = list(tmp_array[~np.isnan(tmp_array)].astype(int))
            tmp_save.append(tmp_array)
        # Every start list must have as many entries as its matching end list
        # (exon, 3'UTR and 5'UTR pairs, checked from the end of tmp_save).
        assert len(tmp_save[-1])==len(tmp_save[-2])
        assert len(tmp_save[-3])==len(tmp_save[-4])
        assert len(tmp_save[-5])==len(tmp_save[-6])

        # record strand info
        # The 'Strand' column must therefore be present in the value file.
        strand = np.unique(df_tmp['Strand'].values).astype(int)
        assert len(strand)==1
        tmp_save.append(strand[0])
        save_info.append(tmp_save)

    df_save = pd.DataFrame(save_info, columns=COLUMN_NAMES)
    save_path = os.path.join(ANNOT_DIR, '{}.csv'.format(chromosome))
    df_save.to_csv(save_path, index=False)
    print('\tsaved into {}'.format(save_path))

### 7. Generate annotation.npy


In [6]:
import os
import pandas as pd
import numpy as np
import argparse
import sys
sys.path.append('./motif')
from motif_utils import seq2kmer, kmer2seq
from Bio import Align
from Bio.Seq import reverse_complement
import argparse
import re

# Inputs are 1-based genomic positions (overlap_start, overlap_end,
# region_start, region_end) plus the sequence covered by the overlap window.
# Output is a pair of 0-based bounds (start, end) within that sequence, meant
# to be used as the slice seq[start:end] (e.g. 0..101 for a 101-nt sequence).
def assign_region(overlap_start, overlap_end, region_start, region_end, seq):
    # The region lies completely outside the window: return (-1, 0), which
    # selects nothing when used as a slice.
    if region_end < overlap_start or region_start > overlap_end:
        return -1, 0

    # Clip the region start to the beginning of the window.
    first = max(region_start - overlap_start, 0)

    # If the region reaches the end of the window (or beyond), the slice runs
    # to the end of the sequence; otherwise it stops right after region_end.
    if overlap_end <= region_end:
        last = len(seq)
    else:
        last = region_end - overlap_start + 1
    return first, last

# this takes 1-based genomic positions (transcript_start, transcript_end, query_start, query_end)
# this returns 1-based genomic positions
def assign_transcript(transcript_start, transcript_end, query_start, query_end):
    """Return the part of the transcript that lies inside the query window.

    Args:
        transcript_start, transcript_end: 1-based genomic span of the transcript.
        query_start, query_end: 1-based genomic span of the query.

    Returns:
        (start, end): the intersection of the two spans, i.e. the larger of
        the two starts and the smaller of the two ends.

    Raises:
        AssertionError: if the two spans do not overlap.
    """
    if transcript_start > query_end or transcript_end < query_start:
        # Raised explicitly (instead of `assert False`) so the check is not
        # stripped when Python runs with -O.
        raise AssertionError("transcript and query do not overlap")
    return max(transcript_start, query_start), min(transcript_end, query_end)

def get_annotation(annot_dir, chromosome, query_start, query_end, strand, query_seq):
    """Annotate every position of a query sequence with transcript region types.

    The transcripts listed in ``<annot_dir>/<chromosome>.csv`` are scanned in
    file order.  Transcripts that do not overlap the query are skipped.  For a
    transcript that fully contains the query, the query is aligned to the
    matching part of the transcript, the regions are assigned, and the scan
    stops.  For a transcript that only partially overlaps the query, the
    aligned part is annotated and the scan continues with the next transcript.

    Args:
        annot_dir: directory holding the per-chromosome CSV files.
        chromosome: chromosome name, used as the CSV file name (without ".csv").
        query_start, query_end: genomic span of the query.
        strand: strand of the query, compared against "+" and "-".
        query_seq: nucleotide sequence of the query.

    Returns:
        numpy array of shape (5, len(query_seq)) holding 0/1 flags; rows are
        0=5'UTR, 1=3'UTR, 2=exon, 3=intron, 4=CDS.
    """
    path = os.path.join(annot_dir, (chromosome + ".csv"))
    df_chr = pd.read_csv(path)
    allow_mismatch = 0.05 #ratio
    # minimum alignment score, as a fraction of the transcript slice length,
    # accepted for a partial overlap
    thresh_for_partial = 0.5
    
    # 5UTR, 3UTR, exon, intron, CDS
    annotations = np.zeros([5, len(query_seq)])
    for i in range(len(df_chr)):
        transcript_start = int(df_chr["Transcript_start_(bp)"][i])
        transcript_end = int(df_chr["Transcript_end_(bp)"][i])
        if query_start > transcript_end or query_end < transcript_start: # no overlap
            pass
        else:
            # retrieve sequences
            transcript_seq = df_chr["Transcript_sequence"][i]
            if df_chr["Strand"][i] == -1:
                transcript_seq = reverse_complement(transcript_seq)
            # For "-" queries both the sequence and the annotation array are
            # reversed while this transcript is processed; the matching
            # reversal at the end of each branch restores them.
            if strand == "-":
                query_seq = reverse_complement(query_seq)
                annotations = np.flip(annotations, axis=1)
            # Cut the transcript sequence down to the part inside the query
            # span, and compute the genomic span of that part.
            tmp_start, tmp_end = assign_region(transcript_start, transcript_end, query_start, query_end, transcript_seq)
            transcript_seq = transcript_seq[tmp_start:tmp_end]
            transcript_sub_start, transcript_sub_end = assign_transcript(transcript_start, transcript_end, query_start, query_end)
            
            # The position columns are parsed by pulling every run of digits
            # out of the stored cell text.
            # retrieve 5'UTR positions
            utr5_starts = re.findall("[0-9]+", df_chr["5'_UTR_start"][i])
            utr5_ends = re.findall("[0-9]+", df_chr["5'_UTR_end"][i])
            if len(utr5_starts) > 0:
                utr5_starts = list(map(int, utr5_starts))
            if len(utr5_ends) > 0:
                utr5_ends = list(map(int, utr5_ends))
            # retrieve 3'UTR positions
            utr3_starts = re.findall("[0-9]+", df_chr["3'_UTR_start"][i])
            utr3_ends = re.findall("[0-9]+", df_chr["3'_UTR_end"][i])
            if len(utr3_starts) > 0:
                utr3_starts = list(map(int, utr3_starts))
            if len(utr3_ends) > 0:
                utr3_ends = list(map(int, utr3_ends))
            # retrieve exon positions
            exon_starts = re.findall("[0-9]+", df_chr["Exon_region_start_(bp)"][i])
            exon_ends = re.findall("[0-9]+", df_chr["Exon_region_end_(bp)"][i])
            if len(exon_starts) > 0:
                exon_starts = list(map(int, exon_starts))
            if len(exon_ends) > 0:
                exon_ends = list(map(int, exon_ends))
            
            if query_start >= transcript_start and query_end <= transcript_end: # perfect overlap
                # detect overlapping region
                overlap_starts = []
                overlap_ends = []
                aligner = Align.PairwiseAligner()
                aligner.gap_score = -10000
                aligner.internal_gap_score = -1
                aligner.mode = "local"
                alignment = aligner.align(transcript_seq, query_seq)[0]
                # accept the alignment when its score exceeds 95% of the query length
                if alignment.score > len(query_seq)*(1-allow_mismatch):
                    match_starts = alignment.aligned[0][0]
                    start_offsets = alignment.aligned[1][0]
                    for match_start, start_offset in zip(match_starts, start_offsets):
                        overlap_start = transcript_sub_start + match_start - start_offset
                        overlap_starts.append(overlap_start)
                else:
                    # retry with the reverse complement of the query
                    query_seq = reverse_complement(query_seq)
                    annotations = np.flip(annotations, axis=1)
                    alignment = aligner.align(transcript_seq, query_seq)[0]
                    if alignment.score > len(query_seq)*(1-allow_mismatch):
                        print("query strand(+/-) is wrong")
                        print("transcriptID: {}, {}, query_start {}, query_end {}, strand {}, query sequence:\n {}".format(\
                                    df_chr["Transcript_stable_ID"][i], chromosome, query_start, query_end, strand, query_seq))
                        match_starts = alignment.aligned[0][0]
                        start_offsets = alignment.aligned[1][0]
                        for match_start, start_offset in zip(match_starts, start_offsets):
                            overlap_start = transcript_sub_start + match_start - start_offset
                            overlap_starts.append(overlap_start)
                    else:
                        print("ERROR: query overlapped, but no match was found")
                        print("transcript #{}, ID: {}, chr {}, query_start {}, query_end {}, strand {}, query sequence:\n {}".format(\
                                    i, df_chr["Transcript_stable_ID"][i], chromosome, query_start, query_end, strand, query_seq))
                        # NOTE(review): the scan stops here without undoing the
                        # reversals applied above -- confirm this is intended.
                        break
                        #assert False
                        
                for overlap_start in overlap_starts:
                    overlap_end = overlap_start + len(query_seq) - 1
                    overlap_ends.append(overlap_end)
                # print("transcriptID: {}, strand: {}\n{}:{}-{}".format(df_chr["Transcript_stable_ID"][i], df_chr["Strand"][i], chromosome, overlap_start, overlap_end))
                # NOTE(review): only the last overlap_start/overlap_end left by
                # the loop above is checked here.
                assert overlap_start >=query_start and overlap_end <=query_end

                # assign 5'UTR(index=0)
                for utr5_start, utr5_end in zip(utr5_starts, utr5_ends):
                    for overlap_start, overlap_end in zip(overlap_starts, overlap_ends):
                        assigned_start, assigned_end = assign_region(overlap_start, overlap_end, utr5_start, utr5_end, query_seq)
                        annotations[0, assigned_start:assigned_end] = 1
                # assign 3'UTR(index=1)
                for utr3_start, utr3_end in zip(utr3_starts, utr3_ends):
                    for overlap_start, overlap_end in zip(overlap_starts, overlap_ends):
                        assigned_start, assigned_end = assign_region(overlap_start, overlap_end, utr3_start, utr3_end, query_seq)
                        annotations[1, assigned_start:assigned_end] = 1
                # assign exon(index=2)
                for exon_start, exon_end in zip(exon_starts, exon_ends):
                    for overlap_start, overlap_end in zip(overlap_starts, overlap_ends):
                        assigned_start, assigned_end = assign_region(overlap_start, overlap_end, exon_start, exon_end, query_seq)
                        annotations[2, assigned_start:assigned_end] = 1

                # assign CDS (index=4) and intron (index=3):
                # CDS = exon position that is in neither UTR;
                # intron = position with no flag set at all
                for pos in range(annotations.shape[1]):
                    if annotations[2,pos]==1 and np.sum(annotations[0:2, pos])==0:
                        annotations[4,pos] = 1
                    if np.sum(annotations[:,pos])==0:
                        annotations[3,pos] = 1
                
                if strand == "-":
                    query_seq = reverse_complement(query_seq)
                    annotations = np.flip(annotations, axis=1)
                
                # the first fully overlapping transcript ends the scan
                break
            else: # partial overlap
                # detect overlapping region
                overlap_start = -1
                overlap_end = -1
                query_match_start = -1
                query_match_end = 0
                aligner = Align.PairwiseAligner()
                aligner.gap_score = -10000
                aligner.mode = "local"
                alignment = aligner.align(transcript_seq, query_seq)[0]
                # accept when the score exceeds half the transcript slice
                # length, or exceeds 10
                if alignment.score > len(transcript_seq)*thresh_for_partial or alignment.score > 10:
                    match_start = alignment.aligned[0][0][0]
                    match_end = alignment.aligned[0][0][1]
                    query_match_start = alignment.aligned[1][0][0]
                    query_match_end = alignment.aligned[1][0][1]
                    overlap_start = transcript_sub_start + match_start - query_match_start
                    overlap_end = transcript_sub_start + match_end - 1
                else:
                    # retry with the reverse complement of the query
                    query_seq = reverse_complement(query_seq)
                    annotations = np.flip(annotations, axis=1)
                    alignment = aligner.align(transcript_seq, query_seq)[0]
                    if alignment.score > len(transcript_seq)*thresh_for_partial or alignment.score > 10:
                        print("query strand(+/-) is wrong")
                        print("transcriptID: {}, {}, query_start {}, query_end {}, strand {}, query sequence:\n {}".format(\
                                    df_chr["Transcript_stable_ID"][i], chromosome, query_start, query_end, strand, query_seq))
                        match_start = alignment.aligned[0][0][0]
                        match_end = alignment.aligned[0][0][1]
                        query_match_start = alignment.aligned[1][0][0]
                        query_match_end = alignment.aligned[1][0][1]
                        overlap_start = transcript_sub_start + match_start - query_match_start
                        overlap_end = transcript_sub_start + match_end - 1
                    else:
                        # Undo the retry reversal for "+" queries; for "-"
                        # queries the two reversals above already cancel out.
                        if strand == "+":
                            query_seq = reverse_complement(query_seq)
                            annotations = np.flip(annotations, axis=1)
                        print("alignment score is too low: score {}, len {}".format(alignment.score, len(transcript_seq)))
                        print("transcriptID: {}, {}, query_start {}, query_end {}, strand {}, query sequence:\n {}".format(\
                                    df_chr["Transcript_stable_ID"][i], chromosome, query_start, query_end, strand, query_seq))
                        continue
                
                # assign 5'UTR(index=0)
                for utr5_start, utr5_end in zip(utr5_starts, utr5_ends):
                    assigned_start, assigned_end = assign_region(overlap_start, overlap_end, utr5_start, utr5_end, query_seq)
                    annotations[0, assigned_start:assigned_end] = 1
                # assign 3'UTR(index=1)
                for utr3_start, utr3_end in zip(utr3_starts, utr3_ends):
                    assigned_start, assigned_end = assign_region(overlap_start, overlap_end, utr3_start, utr3_end, query_seq)
                    annotations[1, assigned_start:assigned_end] = 1
                # assign exon(index=2)
                for exon_start, exon_end in zip(exon_starts, exon_ends):
                    assigned_start, assigned_end = assign_region(overlap_start, overlap_end, exon_start, exon_end, query_seq)
                    annotations[2, assigned_start:assigned_end] = 1

                # assign CDS (index=4) and intron (index=3), restricted to the
                # aligned part of the query
                for pos in range(query_match_start, query_match_end):
                    if annotations[2,pos]==1 and np.sum(annotations[0:2, pos])==0:
                        annotations[4,pos] = 1
                    if np.sum(annotations[:,pos])==0:
                        annotations[3,pos] = 1
                
                if strand == "-":
                    query_seq = reverse_complement(query_seq)
                    annotations = np.flip(annotations, axis=1)

    return annotations

In [None]:
MASTER_DIR = "./datasets"
ANNOT_DIR = "DIRECTORY_WHERE_THE_ANNOTATION_FILES_EXIST"
RBPS = ("AARS",)

# The parser does not depend on the RBP, so it is built once.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--fasta_path",
    default=None,
    type=str,
    required=True,
    help="path to the hg38 fasta file.",
)
parser.add_argument(
    "--annotation_dir",
    default=None,
    type=str,
    required=True,
    help="directory where the reference annotations exist",
)
parser.add_argument(
    "--output_path",
    default=None,
    type=str,
    required=True,
    help="path to the output npy file.",
)

for rbp in RBPS:
    FASTA_PATH = os.path.join(MASTER_DIR, rbp, "nontraining_sample_finetune/dev.fasta")
    # NOTE(review): this writes "annotation.npy", while the introduction of
    # this notebook and the fallback below use "annotations.npy" -- confirm
    # which name the downstream analysis expects.
    OUTPUT_PATH = os.path.join(MASTER_DIR, rbp, "nontraining_sample_finetune/annotation.npy")

    args = parser.parse_args(args=["--fasta_path", FASTA_PATH,
                                   "--annotation_dir", ANNOT_DIR,
                                   "--output_path", OUTPUT_PATH])

    # Paths without the expected extension are treated as directories.
    fasta_file, output_file = args.fasta_path, args.output_path
    if ".fasta" not in fasta_file:
        fasta_file = os.path.join(args.fasta_path, "dev.fasta")
    if ".npy" not in output_file:
        output_file = os.path.join(args.output_path, "annotations.npy")

    # Load the fasta file as two parallel lists (header, sequence).
    lines_fasta_header = []
    lines_fasta_seq = []
    with open(fasta_file, 'r') as f:
        for line in f:
            if ">chr" in line:
                lines_fasta_header.append(line.strip())
            elif re.match("[ATGC]+", line):
                lines_fasta_seq.append(line.strip())
            else:
                print(line)

    assert len(lines_fasta_header) == len(lines_fasta_seq)

    annotations_list = []
    total = len(lines_fasta_header)
    # Report progress roughly every 20%; never 0, which would make the
    # modulo below raise ZeroDivisionError for fewer than five sequences.
    keep_track = max(1, int(total * 0.2))
    print("start annotating")
    for i, line in enumerate(lines_fasta_header):
        chromosome = re.findall(r"chr[0-9XYM]+:", line)
        if not chromosome:
            # Headers on other contigs are skipped, as before.
            continue
        chromosome = chromosome[0][:-1]
        query_start = int(re.findall(r"[0-9]+\-", line)[0][:-1])
        query_end = int(re.findall(r"\-[0-9]+", line)[0][1:])
        strand = re.findall(r"\([\-\+]\)", line)[0][1]
        query_seq = lines_fasta_seq[i]

        annotations_list.append(
            get_annotation(args.annotation_dir, chromosome, query_start, query_end, strand, query_seq))
        if (i % keep_track == 0 and i != 0) or i == total - 1:
            print("finished annotating {}/{} sequences".format(i + 1, total))

    # Stack once at the end: (number of sequences, region types, length).
    # Growing the array with np.concatenate inside the loop is quadratic.
    if annotations_list:
        annotations_all = np.stack(annotations_list, axis=0)
    else:
        annotations_all = np.array([])

    np.save(output_file, annotations_all)