In [1]:
import sys
sys.path.insert(0, '/private/groups/brookslab/gabai/tools/seqUtils/src/')
import time
import numpy as np
from seqUtil import *
from bamUtil import *
from nanoUtil import *
from nntUtil import *
from modPredict import *
import matplotlib.pyplot as plt
import seaborn as sns

Device type:  cpu


In [2]:
# Candidate loci, each written as 'chrom:start-end'.
nuc_regions = {
    'PHO5': 'chrII:429000-435000',
    'CLN2': 'chrXVI:66000-67550',
    'HMR': 'chrIII:290000-299000',
}

# Network architectures selectable by name.
models = {'resnet1D': resnet1D}

# Selections used by the rest of the notebook.
myregion = nuc_regions['PHO5']
mymodel = models['resnet1D']
myweights = '/private/groups/brookslab/gabai/tools/seqUtils/src/nanopore_classification/best_models/addseq_resnet1d.pt'

# Break 'chrom:start-end' into the chromosome name and integer bounds.
reg = myregion.split(':')
chrom = reg[0]
_bounds = reg[1].split('-')
pStart = int(_bounds[0])
pEnd = int(_bounds[1])

In [3]:
# Reference genome (FASTA).
genome = '/private/groups/brookslab/gabai/projects/Add-seq/data/ref/sacCer3.fa'

# Chromatin sample: read mapping and nanopolish eventalign output.
chrom_bam = '/private/groups/brookslab/gabai/projects/Add-seq/data/chrom/mapping/all_read.bam'
chrom_evt = '/private/groups/brookslab/gabai/projects/Add-seq/data/chrom/eventalign/all_read.eventalign.txt'

# Positive control ('unique.500'): read mapping and eventalign output.
pos_bam = '/private/groups/brookslab/gabai/projects/Add-seq/data/ctrl/mapping/unique.500.pass.sorted.bam'
pos_evt = '/private/groups/brookslab/gabai/projects/Add-seq/data/ctrl/eventalign/unique.500.eventalign.tsv'

# Negative control ('unique.0'): read mapping and eventalign output.
neg_bam = '/private/groups/brookslab/gabai/projects/Add-seq/data/ctrl/mapping/unique.0.pass.sorted.bam'
neg_evt = '/private/groups/brookslab/gabai/projects/Add-seq/data/ctrl/eventalign/unique.0.eventalign.tsv'

In [112]:
def _format_sigalign_record(read, chrom, eventStart, sequence, sigList, sigLenList, print_sequence):
    '''Format one aggregated read as a tab-separated, newline-terminated sigAlign row.'''
    fields = [read, chrom, eventStart]
    if print_sequence:
        fields.append(sequence)
    fields.append(','.join(str(i) for i in sigList))
    fields.append(','.join(str(i) for i in sigLenList))
    return '\t'.join(str(f) for f in fields) + '\n'


def parseEventAlign(eventAlign = '', outfile = '', readname = '', chr_region = '', print_sequence = False, n_rname = 0, header = True):
    '''
    Read a nanopolish eventalign file and, for every read listed in readname,
    aggregate its signals and the cumulative number of signals after each
    one-base movement along the reference.

    Columns read from each tab-separated line: [0] chromosome, [1] reference
    position, [2] reference kmer, [3] read name, [last] comma-separated signals.

    input:
        eventAlign: nanopolish eventalign output file.
        readname: A list containing readnames.
        chr_region: chromosome name that the region of interest falls in;
                    lines on any other chromosome are skipped.
    optional:
        print_sequence: if True, the reconstructed sequence is included in outfile
                        (skipped reference positions are written as 'D').
        n_rname: number of readnames that can be skipped in the readname list (default: 0).
                 Parsing stops after a read has been completed and at most this many
                 requested reads are still unseen, because searching the whole
                 eventalign file for every readname takes longer.
        header: if True, the first line of eventAlign is discarded as a header.
    output:
        outfile: signalAlign.tsv with one row per read:
                 readname\tchrom\teventStart(reference)[\tsequence]\tsigList\tsigLenList
                 If outfile is empty, the file is parsed but nothing is written.

    E.g.    read1  ACGTGGCTGA
            events ACGTG
                    CGTGG
                     GTGGC
                      TGGCT
                       GGCTG
                        GCTGA
            sigLen  23
                     45
                      61
                       78
                        101
    '''
    readname = set(readname)
    # Open the output only when a path was supplied; every write and the final
    # close are guarded on this handle so an empty outfile no longer raises.
    outf = open(outfile, 'w') if outfile else None
    read = ''
    sequence = ''
    c = 0

    try:
        with open(eventAlign, 'r') as inFile:
            if header:
                # Discard the column-header row.
                inFile.readline()
            for line in inFile:
                c += 1
                if c%10000000 == 0:
                    print(c/1000000, ' M lines have passed.')
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline) instead of
                    # failing on a missing read-name column.
                    continue
                line = line.strip().split('\t')
                thisread = line[3]
                thischrom = line[0]

                if thischrom != chr_region:
                    continue

                if thisread != read:
                    if sequence:
                        # A tracked read just ended: emit it and reset the state.
                        if outf is not None:
                            outf.write(_format_sigalign_record(
                                read, chrom, eventStart, sequence,
                                sigList, sigLenList, print_sequence))
                        if len(readname) <= n_rname:
                            # Enough requested reads were found; clear sequence so
                            # the record is not written a second time after the loop.
                            sequence = ''
                            break
                        read = ''
                        sequence = ''

                    if thisread not in readname:
                        continue
                    # Each requested read is processed once, at its first occurrence.
                    readname.remove(thisread)
                    print(len(readname), ' reads left in readname list')

                    # start new read here
                    read = thisread
                    chrom = thischrom
                    eventStart = line[1]
                    start = line[1]
                    kmer = line[2]

                    # signals are stored in the last column and are separated by comma
                    sigList = [float(i) for i in line[-1].split(',')]
                    sigLen = len(sigList)
                    sigLenList = [sigLen]
                    sequence = kmer
                # next event within the same read
                else:
                    signals = [float(i) for i in line[-1].split(',')]
                    sigList.extend(signals)
                    # sigLen is the running total of signals seen for this read
                    sigLen += len(signals)

                    # If different kmer
                    if (line[1], line[2]) != (start, kmer):
                        deletion = int(line[1]) - int(start) - 1
                        # Reference positions missing from the eventalign file are
                        # padded with 'D' and repeat the previous cumulative count,
                        # so sigLenList keeps one entry per reference position.
                        if deletion > 0:
                            sequence += deletion*'D'
                            sigLenList.extend([sigLenList[-1]] * deletion)
                        start = line[1]
                        kmer = line[2]
                        sequence += kmer[-1]
                        sigLenList.append(sigLen)
                    # If same kmer
                    else:
                        # Update the number of signals matched to previous kmer
                        sigLenList[-1] = sigLen

            # Emit the read that was still open when the file ended.
            if sequence and outf is not None:
                outf.write(_format_sigalign_record(
                    read, chrom, eventStart, sequence,
                    sigList, sigLenList, print_sequence))
    finally:
        if outf is not None:
            outf.close()

In [113]:
# Single-read eventalign file and the sigAlign file to produce from it.
inFile ='/private/groups/brookslab/gabai/projects/Add-seq/data/chrom/eventalign/dcc687c3-af48-4c9c-9a6b-06df1112187d.eventalign.txt'
outFile = '/private/groups/brookslab/gabai/projects/Add-seq/data/chrom/eventalign/dcc687c3-af48-4c9c-9a6b-06df1112187d.sigAlign.txt'
# header=False: the first line of this file is parsed as data rather than skipped.
# Only lines whose chromosome column equals 'chrII' are considered.
parseEventAlign(eventAlign = inFile, outfile = outFile, 
                readname = ['dcc687c3-af48-4c9c-9a6b-06df1112187d'], 
                chr_region = 'chrII', header = False)

0  reads left in readname list


In [114]:
# sigAlign file for the PHO5 region under the control eventalign directory.
# NOTE(review): a later cell reassigns sigAlign_output to the single-read
# sigAlign file before any visible use of this value — confirm this cell is still needed.
sigAlign_output = '/private/groups/brookslab/gabai/projects/Add-seq/data/ctrl/eventalign/231012_PHO5_chrom_chrII:429000-435000siganlAlign.tsv'

In [115]:
# Collect reads from the chromatin BAM that overlap the selected region.
# The result is later indexed with 'ref' and with read IDs by assign_scores;
# presumably a dict built by the project helper getAlignedReads — TODO confirm its exact layout.
alignment = getAlignedReads(sam = chrom_bam, region = myregion, genome=genome, print_name=False)

In [141]:
def assign_scores(readID, sigList, siglenList, sigLenList_init, modbase, alignemnt,
                  weights, model, device,
                  tune = False, method = 'median', kmerWindow=80, signalWindow=400):
    '''
    Slide along the reference sequence, score the signal window that starts at
    each position with the neural network, and assign that score to every
    modifiable base inside the following kmer window.

    input:
        readID: key of this read in alignemnt; only read when tune is True.
        sigList: signal values of the read (each convertible to float).
        siglenList: cumulative signal counts per reference position of the read
                    (each convertible to int); used as slice bounds into sigList.
        sigLenList_init: index offset into siglenList such that
                         siglenList[sigLenList_init + pos] is the signal count
                         before reference position pos and the next entry is the
                         count after it.
        modbase: base passed to basePos to locate modifiable positions.
        alignemnt: mapping with the reference sequence under 'ref' and, per read,
                   an entry whose element [1] is handed to model_scores as strand.
        weights, model, device: forwarded to nntPredict.
    optional:
        tune: if True, each score is adjusted by model_scores per modifiable position.
        method: aggregation method forwarded to aggregate_scors.
        kmerWindow: number of reference bases a single prediction is assigned to.
        signalWindow: number of signals taken beyond the end of the current position.
    output:
        dict mapping every position returned by basePos to the aggregated score
        of all predictions whose window covered it.
    '''
    # Local import: the notebook's visible import cell does not import bisect.
    import bisect

    refSeq = alignemnt['ref']
    # Positions of modbase, relative to the reference.
    # assumes basePos returns them in ascending order (required by bisect) — TODO confirm
    modPositions = basePos(refSeq, base = modbase)
    modScores = {i:[] for i in modPositions}

    # Loop-invariant values, computed and reported once.
    offset = int(sigLenList_init)
    print('sigLenList_init:', sigLenList_init)
    n_entries = len(siglenList)
    # Looked up only when tuning so untuned runs do not need readID in alignemnt.
    strand = alignemnt[readID][1] if tune else None

    for pos in range(len(refSeq)):
        if pos % 500 ==0:
            print('Predicting at position:', pos)

        # 1. Fetch the reference sequence covered by the kmer window (reported below).
        seq = refSeq[pos:pos+kmerWindow]
        # 2. Fetch signals with signal window size
        idx_start = offset + pos
        idx_end = idx_start + 1
        if idx_end < 0:
            # Position lies before the first aligned event of the read; a negative
            # index would silently wrap around to the end of siglenList.
            print('No signal captured at position: ', pos)
            continue
        if idx_end >= n_entries:
            # Ran past the last aligned event of the read.
            print('Reached the end of the signal.')
            break
        start = 0 if idx_start < 0 else int(siglenList[idx_start])
        end = int(siglenList[idx_end])
        # If no signals aligned to this position. E.g. chrII 429016 is missing in eventalign output.
        if start == end:
            print('No signal captured at position: ', pos)
            continue
        signals = [float(s) for s in sigList[start:end+signalWindow]]
        # 3. Get predicted probability score from machine learning model
        prob = nntPredict(signals,device = device, model = model, weights_path = weights)
        print('sequence length: ', len(seq), ', signal length: ', len(signals), ', prob: ', prob)
        # NOTE(review): start != end here, so this length is only reached when the
        # slice was truncated by the end of sigList; kept as the stop condition,
        # but tied to signalWindow instead of a hard-coded 400 — confirm intent.
        if len(signals) == signalWindow:
            print(start, end)
            break
        # 4. Assign predicted scores to each modPosition
        # modifiable positions [1,3,4,5,7,10,15,16,21,40]
        # kmer position is 2: [2:2+22]
        # modifiable position within kmer window [3,4,5,7,10,15,16,21]
        modbase_left = bisect.bisect_left(modPositions, pos)
        modbase_right = bisect.bisect_right(modPositions, pos+kmerWindow)
        modbase_count = modbase_right - modbase_left

        for p in range(modbase_left, modbase_right):
            modPosition = modPositions[p]
            score = prob
            # 4.1 Tune the window score based on the position of the base and base content.
            # Always start from the untuned window score so adjustments do not
            # compound from one modifiable position to the next.
            if tune:
                score = model_scores(prob, modPosition, pos, modbase_count, strand)
            modScores[modPosition].append(score)

    for mod in modScores:
        modScores[mod] = aggregate_scors(modScores[mod], method = method)
    return modScores

In [132]:
# Point at the single-read sigAlign file written by the parseEventAlign call above;
# this replaces the PHO5 control path assigned to the same name in an earlier cell.
sigAlign_output = '/private/groups/brookslab/gabai/projects/Add-seq/data/chrom/eventalign/dcc687c3-af48-4c9c-9a6b-06df1112187d.sigAlign.txt'

In [133]:
# Load every read from the sigAlign file into memory, keyed by read ID.
# Each value is [signals, cumulative signal counts, index offset of the region start].
thisread = {}
for readID, eventStart, sigList, siglenList in parseSigAlign(sigAlign=sigAlign_output):
    # NOTE(review): start_time is not read anywhere in the visible cells — confirm it is still needed.
    start_time = time.time()
    # Offset into siglenList of the entry just before the region start (pStart),
    # measured from the read's first event position.
    sigLenList_init = pStart-eventStart-1
    thisread[readID] = [sigList, siglenList, sigLenList_init]

In [None]:
# Score the single test read across the selected region on CPU, without tuning,
# aggregating overlapping window scores per modifiable 'A' with the median.
read = 'dcc687c3-af48-4c9c-9a6b-06df1112187d'
thismodscores = assign_scores(readID=read, sigList=thisread[read][0], siglenList=thisread[read][1], 
                              sigLenList_init=thisread[read][2], 
                              modbase='A', alignemnt=alignment, weights=myweights, model=mymodel, device='cpu', 
                              tune = False, method = 'median', kmerWindow=80, signalWindow=400)

Predicting at position: 0
sigLenList_init: 22929
sequence length:  80 , signal length:  403 , prob:  0.3807646731535594
sigLenList_init: 22929
sequence length:  80 , signal length:  407 , prob:  0.39741207872118267
sigLenList_init: 22929
sequence length:  80 , signal length:  412 , prob:  0.3993050654729207
sigLenList_init: 22929
sequence length:  80 , signal length:  407 , prob:  0.39528991494859966
sigLenList_init: 22929
sequence length:  80 , signal length:  403 , prob:  0.39726178844769794
sigLenList_init: 22929
sequence length:  80 , signal length:  411 , prob:  0.4116072492165999
sigLenList_init: 22929
sequence length:  80 , signal length:  407 , prob:  0.4167107045650482
sigLenList_init: 22929
sequence length:  80 , signal length:  404 , prob:  0.411438912153244
sigLenList_init: 22929
sequence length:  80 , signal length:  404 , prob:  0.40501508116722107
sigLenList_init: 22929
sequence length:  80 , signal length:  412 , prob:  0.39113952467838925
sigLenList_init: 22929
sequenc

sequence length:  80 , signal length:  413 , prob:  0.4689975426747249
sigLenList_init: 22929
sequence length:  80 , signal length:  405 , prob:  0.4645633339881897
sigLenList_init: 22929
sequence length:  80 , signal length:  405 , prob:  0.44993193745613097
sigLenList_init: 22929
sequence length:  80 , signal length:  416 , prob:  0.46073607355356216
sigLenList_init: 22929
sequence length:  80 , signal length:  408 , prob:  0.46128929033875465
sigLenList_init: 22929
sequence length:  80 , signal length:  408 , prob:  0.4689956270158291
sigLenList_init: 22929
sequence length:  80 , signal length:  415 , prob:  0.4548739989598592
sigLenList_init: 22929
sequence length:  80 , signal length:  412 , prob:  0.4443569133679072
sigLenList_init: 22929
sequence length:  80 , signal length:  410 , prob:  0.4703119874000549
sigLenList_init: 22929
sequence length:  80 , signal length:  406 , prob:  0.4757387389739354
sigLenList_init: 22929
sequence length:  80 , signal length:  403 , prob:  0.488

In [144]:
# Scratch check: reference coordinate 1171 bases past the region start.
pStart+1171

430171

In [127]:
# Scratch check: five consecutive cumulative signal counts for this read, starting at index 22945.
thisread[read][1][22945:22945+5]

['225876', '225876', '225880', '225883', '225888']