In [1]:
from config import *
from utils import *

import os
import sys
import regex
import copy
import numpy as np
import collections
import multiprocessing
import pickle

import numpy as np
import scipy

# Suppress pandas future warning, which messes tqdm
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pandas as pd

from tqdm.notebook import tqdm

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

inline_rc = dict(mpl.rcParams)
import subprocess

import time
os.environ['KMP_WARNINGS'] = 'off'

from mmsplice.vcf_dataloader import SplicingVCFDataloader
from mmsplice import MMSplice, predict_all_table, predict_save
from mmsplice.utils import max_varEff

The examples.directory rcparam was deprecated in Matplotlib 3.0 and will be removed in 3.2. In the future, examples will be found relative to the 'datapath' directory.
Using TensorFlow backend.


In [2]:
# Load the inputs for this notebook. The loaders and the path constants are
# star-imported (from `config` / `utils`), so the structure of the returned
# objects is not visible here.
# The pre-/post-Cas maps are what the prediction cells below pass to
# predict_MMSpliceScores.
indel_splice_precas_count_map = load_bc_seq(INDEL_SPLICE_PRECAS_COUNT_MAP)
indel_splice_postcas_count_map = load_bc_seq(INDEL_SPLICE_POSTCAS_COUNT_MAP)
# NOTE(review): the two gt_* maps are not referenced by the other cells visible
# in this part of the notebook -- confirm whether they are still needed.
gt_splice_count_map = load_bc_seq(GT_SPLICE_COUNT_MAP)
gt_precas_splice_count_map = load_bc_seq(GT_PRECAS_SPLICE_COUNT_MAP)
# Indexed by the pairs passed to save_mmsplice_predict_gt_indel_dist.
predicted_gt_indel_dist_map = load_var(PREDICTED_GT_INDEL_DIST_MAP)

# MMSplice Predict
MMSplice requires three inputs: a reference genome FASTA file, a genome annotation file in the standard GTF format, and a Variant Call Format (VCF) file. To represent a query repair genotype in these formats, we treated our lib-SA sequence with the corresponding WT target as the reference genome, and the repair genotype as a variant relative to the WT. That is, the FASTA file contained our lib-SA sequence with the corresponding WT target sequence, the GTF file contained the exonic annotations, and the VCF file described the indel of the WT target that resulted in our repair genotype.

In [3]:
# Fixed sequence pieces of the reporter construct. construct_genome assembles
# each reference sequence as TARGET_PREFIX + <target> + TARGET_SUFFIX_BC_PREFIX.
# NOTE(review): the exon / intron / barcode roles are read off the constant
# names only -- confirm against the construct design.
EXON_A = 'CAAGATCCGCCACAACATCGAG'
INTRON_TARGETSTART = 'GTAAGTTATCACCTTCGTGGCTACAGAGTTTCCTTATTTGTCTCTGTTGCCGGCTTATATGGACAAGCATATCACAGCCATTTATCGGAGCGCCTCCGTACACGCTATTATCGGACGCCTCGCGAGATCAATACGATTACCAGCTGCCCTCGTCGAC'
TARGETEND_EXON_B = 'TGATTACACATATAGACACGCGAGCAGCCATCTTTTATAGAATGGGTAGAACCCGTCCTAAGGACTCAGATTGAGCATCGTTTGCTTCTCGAGTACTACCTGGTACAGATGTCTCTTCAAACAG'
EXON_C_BCSTART = 'GACGGCAGCGTGCAGCTCGCC'
BCEND_EXON_C = 'GACCACTACCAGCAGAACACCCC'

# Everything placed before the target sequence.
TARGET_PREFIX = EXON_A + INTRON_TARGETSTART
# Everything placed after the target sequence.
TARGET_SUFFIX_BC_PREFIX = TARGETEND_EXON_B + EXON_C_BCSTART
# Alias of BCEND_EXON_C; not appended by construct_genome.
BC_SUFFIX = BCEND_EXON_C


def construct_genome(indel_splice_count_map, outputfn=''):
    """Write one FASTA record per indel record to ``outputfn``.

    Records are taken from ``indel_iterator(indel_splice_count_map)``. The
    record name is the enumeration index, and the sequence is
    ``TARGET_PREFIX + target + TARGET_SUFFIX_BC_PREFIX``, where ``target`` is
    looked up from the record's ``'gid'`` via ``exp_gid_tid_map`` and
    ``exp_tid_target_map``.

    Args:
        indel_splice_count_map: Collection accepted by ``indel_iterator``; each
            yielded count map must provide a ``'gid'`` entry.
        outputfn: Path of the FASTA file to (over)write.
    """
    with open(outputfn, 'w') as fasta_file:
        for record_idx, (_, count_map) in enumerate(indel_iterator(indel_splice_count_map)):
            tid = exp_gid_tid_map[count_map['gid']]
            sequence = TARGET_PREFIX + exp_tid_target_map[tid] + TARGET_SUFFIX_BC_PREFIX
            # Header line and sequence line emitted together.
            fasta_file.write('>{0}\n{1}\n'.format(record_idx, sequence))
        
            
def construct_gtf(indel_splice_count_map, outputfn=''):
    """Write a GTF annotation with one artificial gene per indel record.

    Records are taken from ``indel_iterator(indel_splice_count_map)`` and the
    enumeration index is used as the sequence name. For every record seven
    rows are written: a gene row, then two transcripts (``ABC`` and ``AC``)
    that each span the whole sequence and contain two exons:

    * ``ABC``: exon A, plus an exon from the splice index inside the target to
      the end of the sequence.
    * ``AC``: exon A, plus an exon starting right after ``TARGETEND_EXON_B``
      and running to the end of the sequence.

    Args:
        indel_splice_count_map: Collection accepted by ``indel_iterator``; each
            yielded count map must provide a ``'gid'`` entry.
        outputfn: Path of the GTF file to (over)write.
    """
    # 0-based offset into the target at which the BC exon is annotated to start.
    spliceidx = 37

    def gtf_row(seqname, feature, start, end, attributes):
        # Source is always 'artificial', strand '+', score and frame '.'.
        return '\t'.join([
            seqname, 'artificial', feature, str(start), str(end), '.', '+', '.', attributes
        ])

    with open(outputfn, 'w') as gtf_file:
        for idx, (_, count_map) in enumerate(indel_iterator(indel_splice_count_map)):
            target = exp_tid_target_map[exp_gid_tid_map[count_map['gid']]]
            prefix_len = len(TARGET_PREFIX)
            genome_len = prefix_len + len(target) + len(TARGET_SUFFIX_BC_PREFIX)
            seqname = str(idx)

            # (transcript suffix, 1-based start of its second exon, exon attribute template)
            transcripts = (
                (
                    'ABC',
                    prefix_len + spliceidx + 1,
                    'gene_id "g{0}"; transcript_id "t{0}.{1}"; gene_name "g{0}"; exon_id "e{0}BC";\n',
                ),
                (
                    'AC',
                    prefix_len + len(target) + len(TARGETEND_EXON_B) + 1,
                    'gene_id "g{0}"; transcript_id "t{0}.{1}"; gene_name "g{0}"; exon_id "e{0}C";\n',
                ),
            )

            rows = [
                gtf_row(
                    seqname, 'gene', 1, genome_len,
                    'gene_id "g{0}"; transcript_id ""; gene_name "g{0}";\n'.format(idx),
                )
            ]
            for suffix, second_exon_start, second_exon_attrs in transcripts:
                rows.append(gtf_row(
                    seqname, 'transcript', 1, genome_len,
                    'gene_id "g{0}"; transcript_id "t{0}.{1}"; gene_name "g{0}";\n'.format(idx, suffix),
                ))
                # Exon A is shared by both transcripts.
                rows.append(gtf_row(
                    seqname, 'exon', 1, len(EXON_A),
                    'gene_id "g{0}"; transcript_id "t{0}.{1}"; gene_name "g{0}"; exon_id "e{0}A";\n'.format(idx, suffix),
                ))
                rows.append(gtf_row(
                    seqname, 'exon', second_exon_start, genome_len,
                    second_exon_attrs.format(idx, suffix),
                ))

            # Each attribute string already ends with its own newline.
            gtf_file.write(''.join(rows))

            
def construct_vcf(indel_splice_count_map, outputfn=''):
    """Write a VCF with one variant per supported indel record.

    Records are taken from ``indel_iterator(indel_splice_count_map)``. The
    enumeration index of a record is used as its contig name; a ``##contig``
    header is written for every record, including those whose indel type is
    not supported and therefore get no variant row.

    Supported indels (``indel[1]`` is the signature):

    * ``'N'``: written as a REF == ALT row at the designed splice index.
    * a signature in ``DELETION_SIGNATURES``: tuple layout
      ``(_, signature, deletion_size, genotype_pos, cutsite)``.
    * a signature in ``INSERTION_SIGNATURES`` with ``indel[2] == 1``: tuple
      layout ``(_, signature, 1, inserted_base, cutsite)``.

    Anything else is skipped.

    REF/ALT are sliced from the target of the record being written. Data rows
    carry exactly the eight columns named in the header (no trailing tab).

    Args:
        indel_splice_count_map: Collection accepted by ``indel_iterator``; each
            yielded count map must provide a ``'gid'`` entry.
        outputfn: Path of the VCF file to (over)write.
    """
    spliceidx = 37  # Designed splice idx
    records = list(indel_iterator(indel_splice_count_map))

    def target_for(splice_count_map):
        # Target sequence that belongs to this record's guide id.
        return exp_tid_target_map[exp_gid_tid_map[splice_count_map['gid']]]

    with open(outputfn, 'w') as f:
        f.write('##fileformat=VCFv4.0\n')
        for i, (_, splice_count_map) in enumerate(records):
            genome_len = (
                len(TARGET_PREFIX) + len(target_for(splice_count_map)) + len(TARGET_SUFFIX_BC_PREFIX)
            )
            f.write('##contig=<ID={0},length={1}>\n'.format(i, genome_len))

        f.write('\t'.join(['#CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO']) + '\n')
        for i, (indel, splice_count_map) in enumerate(records):
            # Resolve the target per record: every contig has its own target, so
            # a value left over from the header loop would give wrong REF/ALT.
            target = target_for(splice_count_map)

            # POS is 1-based on the contig, i.e. offset by len(TARGET_PREFIX).
            if indel[1] == 'N':
                # Exon B acceptor
                pos = len(TARGET_PREFIX) + spliceidx + 1
                ref = target[spliceidx]
                alt = target[spliceidx]
            elif indel[1] in DELETION_SIGNATURES:
                _, _, deletion_size, genotype_pos, cutsite = indel
                anchor = cutsite + genotype_pos - deletion_size - 1  # 0-based index of the anchor base
                pos = len(TARGET_PREFIX) + anchor + 1
                # NOTE(review): this REF slice covers deletion_size + 2 bases and ALT
                # is one base, so the row removes deletion_size + 1 bases. Kept as in
                # the original; confirm against the (genotype_pos, cutsite) convention.
                ref = target[anchor:(cutsite + genotype_pos + 1)]
                alt = target[anchor]
            elif indel[1] in INSERTION_SIGNATURES and indel[2] == 1:
                _, _, _, inserted_base, cutsite = indel
                pos = len(TARGET_PREFIX) + cutsite
                ref = target[cutsite - 1]
                alt = target[cutsite - 1] + inserted_base
            else:
                # Unsupported indel type: contig header only, no variant row.
                continue

            f.write('\t'.join([str(i), str(pos), '.', ref, alt, '.', '.', '.']) + '\n')
            
            
def predict_MMSpliceScores(indel_splice_count_map, outfn='MESC_reporter', progress=False):
    """Build the FASTA/GTF/VCF inputs under ``TEMP_DIR`` and run MMSplice on them.

    The three files are named ``outfn`` plus ``.fa`` / ``.gtf`` / ``.vcf`` and
    are overwritten on every call.

    Args:
        indel_splice_count_map: Passed unchanged to ``construct_genome``,
            ``construct_gtf`` and ``construct_vcf``.
        outfn: Base file name (without extension) for the three input files.
        progress: When true, print the elapsed seconds of each of the three
            stages (dataloader creation, model creation, prediction). It is
            not forwarded to ``predict_all_table``, which always receives
            ``progress=False``.

    Returns:
        The value returned by ``predict_all_table`` (callers in this notebook
        call ``.to_csv`` on it).
    """
    def report_elapsed(since):
        # Timing output is opt-in via `progress`.
        if progress:
            print(time.perf_counter() - since)

    fasta, gtf, vcf = [
        os.path.join(TEMP_DIR, outfn + extension) for extension in ('.fa', '.gtf', '.vcf')
    ]
    for write_input, path in (
        (construct_genome, fasta),
        (construct_gtf, gtf),
        (construct_vcf, vcf),
    ):
        write_input(indel_splice_count_map, path)

    tic = time.perf_counter()
    dl = SplicingVCFDataloader(gtf, fasta, vcf)
    report_elapsed(tic)

    tic = time.perf_counter()
    model = MMSplice()
    report_elapsed(tic)

    tic = time.perf_counter()
    predictions = predict_all_table(
        model, dl, progress=False, pathogenicity=False, splicing_efficiency=False
    )
    report_elapsed(tic)
    return predictions

In [4]:
# Run MMSplice on indel_splice_precas_count_map only when the cached CSV is
# missing, then save the result without the index column.
# NOTE: when the CSV already exists this cell defines nothing;
# `mmsplice_precas_df` is not loaded back from disk.
if not os.path.exists(MMSPLICE_PRECAS_DF_PATH):
    mmsplice_precas_df = predict_MMSpliceScores(indel_splice_precas_count_map, 'MESC_reporter_precas', True)
    mmsplice_precas_df.to_csv(MMSPLICE_PRECAS_DF_PATH, index=False)

In [5]:
# Run MMSplice on indel_splice_postcas_count_map only when the cached CSV is
# missing, then save the result without the index column.
# NOTE: when the CSV already exists this cell defines nothing;
# `mmsplice_postcas_df` is not loaded back from disk.
if not os.path.exists(MMSPLICE_POSTCAS_DF_PATH):
    mmsplice_postcas_df = predict_MMSpliceScores(indel_splice_postcas_count_map, 'MESC_reporter_postcas', True)
    mmsplice_postcas_df.to_csv(MMSPLICE_POSTCAS_DF_PATH, index=False)

In [6]:
def save_mmsplice_predict_gt_indel_dist(pairs):
    """Run MMSplice on the predicted indel distribution of every pair and save it.

    For each pair, an indel count map is built from
    ``predicted_gt_indel_dist_map[pair]``:

    * deletions of size 1..``MAX_INDEL_LEN`` (read from key ``-size``) become
      ``('', 'DS', size, genotype_pos, cutsite)``;
    * 1-bp insertions (read from key ``1``) of A/G/T/C become
      ``('', 'IS', 1, base, cutsite)``.

    Only entries with a frequency greater than zero are kept; each stores the
    guide id under ``'gid'`` and the frequency under ``'indelphifreq'``. The
    predictions are written to ``<MMSPLICE_GT_DF_DIR>/<gid>.csv``.

    Every pair reuses the ``'MESC_agg_postcas'`` base name, so the intermediate
    input files are overwritten from one pair to the next. Unlike the
    pre-/post-Cas cells, ``to_csv`` is called here without ``index=False``.

    Args:
        pairs: Iterable of pairs; ``pair[0]`` indexes ``exp_grna_gid_map``, the
            unpacked pair is passed to ``get_cutsite`` and the pair itself
            indexes ``predicted_gt_indel_dist_map``.
    """
    for pair in tqdm(pairs):
        gid = exp_grna_gid_map[pair[0]][0]
        cutsite = get_cutsite(*pair)
        distribution = predicted_gt_indel_dist_map[pair]

        # Candidates in output order: deletions by increasing size, then insertions.
        candidates = [
            (('', 'DS', deletion_size, genotype_pos, cutsite),
             distribution[-deletion_size][genotype_pos])
            for deletion_size in range(1, MAX_INDEL_LEN + 1)
            for genotype_pos in distribution[-deletion_size]
        ]
        candidates.extend(
            (('', 'IS', 1, base, cutsite), distribution[1][base]) for base in 'AGTC'
        )

        indel_splice_count_map = collections.defaultdict(lambda: collections.defaultdict(int))
        for indel, frequency in candidates:
            if frequency > 0:
                entry = indel_splice_count_map[indel]
                entry['gid'] = gid
                entry['indelphifreq'] = frequency

        mmsplice_df = predict_MMSpliceScores(indel_splice_count_map, 'MESC_agg_postcas')
        csv_path = os.path.join(MMSPLICE_GT_DF_DIR, str(gid) + '.csv')
        mmsplice_df.to_csv(csv_path)

In [7]:
# '626.csv' is the sentinel for "per-guide predictions already generated".
# NOTE(review): presumably 626 is the last gid the loop writes -- confirm. If
# that file is missing, every pair is recomputed, including pairs whose CSV
# already exists from an interrupted run.
if not os.path.exists(os.path.join(MMSPLICE_GT_DF_DIR, '626.csv')):
    lib_pairs = [gid_to_gt(gid) for gid in exp_gid_tid_map]
    save_mmsplice_predict_gt_indel_dist(lib_pairs)