In [1]:
import sys
sys.path.insert(0, '../rna-editing-downstream')

In [2]:
import os
import re
import subprocess
import uuid

In [3]:
import bam_utils

In [4]:
# Input BAM and annotated VAF table (relative paths under ../tests/data).
INPUT_BAM_FP = '../tests/data/C3N-00435.chr1.filtered.bam'
INPUT_ANNOTATED_VAF = '../tests/data/C3N-00435.chr1.tsv'
# Reference genome FASTA. The default is a site-specific absolute path; set the
# REFERENCE_FASTA environment variable to use a different copy without editing this cell.
REFERENCE_FASTA = os.environ.get(
    'REFERENCE_FASTA',
    '/gscmnt/gc2686/rna_editing/data/reference/gdc/GRCh38.d1.vd1.fa',
)

In [5]:
# Matches plain numbers such as '12', '0.25', '.5', '-3.' or '1e-05'.
_PLAIN_NUMBER_RE = re.compile(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?')


def _float_if_numeric(val):
    """Return float(val) when val is written as a plain number, else val unchanged."""
    return float(val) if _PLAIN_NUMBER_RE.fullmatch(val) else val


def get_position_to_data_dict(input_rna_editing_vaf, has_header=False):
    """Parse an annotated RNA-editing VAF table into a dict keyed by position.

    Every non-blank line is tab-separated: chrom, pos, ref_base, then eight
    per-sample fields for each of 'dna.a', 'dna.t', 'rna.a', 'rna.t' (in that
    order), then ten annotation columns.

    Args:
        input_rna_editing_vaf: path of the tab-separated file to read.
        has_header: when True the first line of the file is skipped.

    Returns:
        dict mapping (chrom, pos) -- both kept as strings -- to a dict holding
        'ref_base', one sub-dict per input type and one entry per annotation.
        'depth' is an int. Any other value written as a plain number becomes a
        float; everything else (for example '.') stays a string.

    Raises:
        IndexError: if a non-blank line has fewer columns than expected.
        ValueError: if a depth column is not an integer.
    """
    input_types = ['dna.a', 'dna.t', 'rna.a', 'rna.t']
    fields = ['depth', 'ref_vaf', 'minor_vaf', 'a_vaf', 'c_vaf', 'g_vaf', 't_vaf', 'n_vaf']
    annotations = ['primary_transcript', 'gene', 'strand', 'region', 'non_verbose_region', 'info',
                   'repeat_name', 'repeat_class', 'repeat_family', 'blat_percent_passing']
    # Annotation columns start after the three leading columns and all per-sample fields.
    annotation_offset = 3 + (len(fields) * len(input_types))

    position_to_data_dict = {}

    # The context manager closes the file even if a malformed line raises.
    with open(input_rna_editing_vaf) as f:
        if has_header:
            f.readline()

        for line in f:
            stripped = line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            pieces = stripped.split('\t')
            chrom, pos, ref_base = pieces[0], pieces[1], pieces[2]

            data_dict = {'ref_base': ref_base}
            for i, input_type in enumerate(input_types):
                sample_values = {}
                for j, field in enumerate(fields):
                    val = pieces[3 + (len(fields) * i) + j]
                    if field == 'depth':
                        val = int(val)
                    else:
                        # str.isdecimal() is only true for all-digit strings, so it
                        # never converted fractions such as '0.25'; match the full
                        # numeric spelling instead.
                        val = _float_if_numeric(val)
                    sample_values[field] = val
                data_dict[input_type] = sample_values

            for i, annotation in enumerate(annotations):
                # All-digit annotation values were already turned into floats
                # (the former int branch could not be reached), so numeric
                # annotations stay floats here.
                data_dict[annotation] = _float_if_numeric(pieces[annotation_offset + i])

            position_to_data_dict[(chrom, pos)] = data_dict

    return position_to_data_dict

In [6]:
# Parse the annotated VAF table; has_header=True skips the file's first line.
position_to_data_dict = get_position_to_data_dict(INPUT_ANNOTATED_VAF, has_header=True)

In [7]:
def write_positions_bed(chrom_pos_tups, output_fp):
    """Write one tab-separated 'chrom, pos, pos' line per position to output_fp.

    Args:
        chrom_pos_tups: iterable of (chrom, pos) pairs.
        output_fp: path of the file to create or overwrite.
    """
    # The context manager closes the file even if a write or an unpack fails.
    with open(output_fp, 'w') as f:
        for chrom, pos in chrom_pos_tups:
            f.write(f'{chrom}\t{pos}\t{pos}\n')

def write_regions_file(chrom_start_stop_tups, output_fp):
    """Write one 'chrom:start-stop' region string per line to output_fp.

    Args:
        chrom_start_stop_tups: iterable of (chrom, start, stop) triples.
        output_fp: path of the file to create or overwrite.
    """
    # The context manager closes the file even if a write or an unpack fails.
    with open(output_fp, 'w') as f:
        for chrom, start, stop in chrom_start_stop_tups:
            f.write(f'{chrom}:{start}-{stop}\n')

In [8]:
# Sorted (chrom, pos) keys. pos is still a string here, so the order within a
# chromosome is lexicographic rather than numeric.
chrom_pos_tups = sorted(set(position_to_data_dict.keys()))
# A random UUID in the file name gives each run its own positions file.
# NOTE(review): no visible cell removes this temp file afterwards.
u_id = str(uuid.uuid4())
temp_positions_fp = f'temp.positions.{u_id}.bed'
write_positions_bed(chrom_pos_tups, temp_positions_fp)

In [9]:
# Read tuples from the BAM, selected with the positions file; later cells unpack
# each one as (chrom, pos, cigar, seq). The exact selection rule lives in bam_utils.
read_tups = bam_utils.get_chrom_start_cigar_seq_read_tups(INPUT_BAM_FP, temp_positions_fp)

In [10]:
# Sanity check: number of annotated positions and number of read tuples collected.
len(chrom_pos_tups), len(read_tups)

(236, 13079)

In [11]:
# Fresh UUID so the regions file gets its own name (u_id is reused from the positions cell).
u_id = str(uuid.uuid4())
temp_regions_fp = f'temp.regions.{u_id}.txt'
# One (chrom, start, stop) triple per read tuple; the two coordinates come from
# bam_utils.get_covering_reference_coords.
chrom_start_stop_tups = [(chrom, *bam_utils.get_covering_reference_coords(int(pos), cigar, seq))
                        for chrom, pos, cigar, seq in read_tups]
write_regions_file(chrom_start_stop_tups, temp_regions_fp)

In [12]:
def get_read_to_reference_seq(read_tups, regions_fp, reference_fasta):
    """Map each read tuple to the reference sequence fetched for its region.

    Runs `samtools faidx -r regions_fp reference_fasta`, parses the FASTA output
    with bam_utils.get_reads_to_sequences_from_fasta_stream, and looks up every
    read under the key 'chrom:start-end'.

    Args:
        read_tups: iterable of (chrom, pos, cigar, seq) tuples.
        regions_fp: path of a regions file with one 'chrom:start-stop' per line.
        reference_fasta: path of the reference FASTA handed to samtools.

    Returns:
        dict mapping (chrom, pos, cigar, seq) to the looked-up reference sequence.
        Duplicate read tuples collapse into a single entry.

    Raises:
        subprocess.CalledProcessError: if samtools exits with a non-zero status.
        KeyError: if a read's region is missing from the parsed samtools output.
    """
    tool_args = ['samtools', 'faidx', '-r', regions_fp, reference_fasta]
    output = subprocess.check_output(tool_args).decode('utf-8')
    region_to_reference_seq = bam_utils.get_reads_to_sequences_from_fasta_stream(output)

    read_tups_to_reference_seqs = {}
    for chrom, pos, cigar, seq in read_tups:
        key = (chrom, pos, cigar, seq)
        if key in read_tups_to_reference_seqs:
            # Identical tuples resolve to the same region; skip the recomputation.
            continue
        # Build the lookup key from the same (start, end) pair that the regions
        # file is written from, rather than assuming the covering start equals pos.
        start, end = bam_utils.get_covering_reference_coords(int(pos), cigar, seq)
        read_tups_to_reference_seqs[key] = region_to_reference_seq[f'{chrom}:{start}-{end}']

    return read_tups_to_reference_seqs

In [13]:
# Fetch the reference sequence for every read tuple (runs samtools faidx once).
read_tups_to_reference_seq = get_read_to_reference_seq(read_tups, temp_regions_fp, REFERENCE_FASTA)

In [14]:
# Number of distinct read tuples; duplicates in read_tups share one dict key.
len(read_tups_to_reference_seq)

6241

In [15]:
# Register every read, together with its reference sequence, in a ReadCollection
# built from the annotated positions.
# NOTE(review): pos is passed to put_read as a string here, while get_reads is
# later called with int(pos) -- confirm ReadCollection expects that.
rc = bam_utils.ReadCollection(chrom_pos_tups)
for chrom, pos, cigar, seq in read_tups:
    rc.put_read(chrom, pos, cigar, seq,
                reference_sequence=read_tups_to_reference_seq[(chrom, pos, cigar, seq)])

In [25]:
def get_average_base_changes_on_changed_reads(chrom, pos, data_dict, read_collection_reads):
    """Average bam_utils.count_mismatches over reads that fail bam_utils.is_match at pos.

    Args:
        chrom: accepted for a uniform signature; not read.
        pos: site position, converted with int().
        data_dict: accepted for a uniform signature; not read.
        read_collection_reads: iterable of (chrom, start, cigar, seq, read_dict)
            tuples, where read_dict carries 'reference_sequence'.

    Returns:
        The mean mismatch count of the selected reads, or '.' when none is selected.
    """
    mismatch_counts = [
        bam_utils.count_mismatches(cigar, seq, info['reference_sequence'])
        for _, start, cigar, seq, info in read_collection_reads
        if not bam_utils.is_match(int(start), int(pos), cigar, seq, info['reference_sequence'])
    ]

    if not mismatch_counts:
        return '.'
    return sum(mismatch_counts) / len(mismatch_counts)

def get_average_base_changes_on_all_reads(chrom, pos, data_dict, read_collection_reads):
    """Average bam_utils.count_mismatches over every supplied read.

    Args:
        chrom, pos, data_dict: accepted for a uniform signature; not read.
        read_collection_reads: iterable of (chrom, start, cigar, seq, read_dict)
            tuples, where read_dict carries 'reference_sequence'.

    Returns:
        The mean mismatch count, or '.' when there are no reads.
    """
    mismatch_counts = [
        bam_utils.count_mismatches(cigar, seq, info['reference_sequence'])
        for _, _, cigar, seq, info in read_collection_reads
    ]

    if not mismatch_counts:
        return '.'
    return sum(mismatch_counts) / len(mismatch_counts)

def get_average_rna_editing_base_changes_on_changed_reads(chrom, pos, data_dict, read_collection_reads):
    """Average bam_utils.count_valid_rna_editing_mismatches over reads that fail is_match at pos.

    Args:
        chrom: accepted for a uniform signature; not read.
        pos: site position, converted with int().
        data_dict: site record; only its 'strand' entry is read.
        read_collection_reads: iterable of (chrom, start, cigar, seq, read_dict)
            tuples, where read_dict carries 'reference_sequence'.

    Returns:
        The mean count over the selected reads. Returns '.' when no read is
        selected, or when there is at least one read and the strand is neither
        '+' nor '-'.
    """
    editing_counts = []
    for _, start, cigar, seq, info in read_collection_reads:
        reference = info['reference_sequence']
        strand = data_dict['strand']

        # Without a usable strand the editing count is undefined for this site.
        if strand not in ('+', '-'):
            return '.'

        if bam_utils.is_match(int(start), int(pos), cigar, seq, reference):
            continue
        editing_counts.append(
            bam_utils.count_valid_rna_editing_mismatches(cigar, seq, reference, strand))

    return sum(editing_counts) / len(editing_counts) if editing_counts else '.'

def get_average_rna_editing_base_changes_on_all_reads(chrom, pos, data_dict, read_collection_reads):
    """Average bam_utils.count_valid_rna_editing_mismatches over every supplied read.

    Args:
        chrom, pos: accepted for a uniform signature; not read.
        data_dict: site record; only its 'strand' entry is read.
        read_collection_reads: iterable of (chrom, start, cigar, seq, read_dict)
            tuples, where read_dict carries 'reference_sequence'.

    Returns:
        The mean count over all reads. Returns '.' when there are no reads, or
        when there is at least one read and the strand is neither '+' nor '-'.
    """
    editing_counts = []
    for _, _, cigar, seq, info in read_collection_reads:
        reference = info['reference_sequence']
        strand = data_dict['strand']

        # Without a usable strand the editing count is undefined for this site.
        if strand not in ('+', '-'):
            return '.'

        editing_counts.append(
            bam_utils.count_valid_rna_editing_mismatches(cigar, seq, reference, strand))

    return sum(editing_counts) / len(editing_counts) if editing_counts else '.'

def get_percent_altered_reads_valid_editing_changes(chrom, pos, data_dict, read_collection_reads):
    """Fraction of reads failing is_match at pos that pass bam_utils.is_valid_rna_editing_site.

    Args:
        chrom: accepted for a uniform signature; not read.
        pos: site position, converted with int().
        data_dict: site record; only its 'strand' entry is read.
        read_collection_reads: iterable of (chrom, start, cigar, seq, read_dict)
            tuples, where read_dict carries 'reference_sequence'.

    Returns:
        valid / altered as a fraction between 0 and 1 (not multiplied by 100).
        Returns '.' when no read fails is_match, or when there is at least one
        read and the strand is neither '+' nor '-'.
    """
    valid_reads = 0
    altered_reads = 0
    for _, start, cigar, seq, info in read_collection_reads:
        reference = info['reference_sequence']
        strand = data_dict['strand']

        if strand not in ('+', '-'):
            return '.'

        if bam_utils.is_match(int(start), int(pos), cigar, seq, reference):
            continue

        if bam_utils.is_valid_rna_editing_site(int(start), int(pos),
                                               cigar, seq, reference, strand):
            valid_reads += 1
        altered_reads += 1

    return valid_reads / altered_reads if altered_reads else '.'


def get_percent_reads_valid_editing_changes(chrom, pos, data_dict, read_collection_reads):
    """Ratio of summed valid-editing mismatch counts to summed mismatch counts over all reads.

    Args:
        chrom, pos: accepted for a uniform signature; not read.
        data_dict: site record; only its 'strand' entry is read.
        read_collection_reads: iterable of (chrom, start, cigar, seq, read_dict)
            tuples, where read_dict carries 'reference_sequence'.

    Returns:
        editing mismatches / all mismatches as a fraction (not multiplied by 100).
        Returns '.' when the mismatch total is zero, or when there is at least
        one read and the strand is neither '+' nor '-'.
    """
    editing_changes = 0
    all_changes = 0
    for _, _, cigar, seq, info in read_collection_reads:
        reference = info['reference_sequence']
        strand = data_dict['strand']

        if strand not in ('+', '-'):
            return '.'

        editing_changes += bam_utils.count_valid_rna_editing_mismatches(
            cigar, seq, reference, strand)
        all_changes += bam_utils.count_mismatches(cigar, seq, reference)

    return editing_changes / all_changes if all_changes else '.'

def get_annotations(read_collection, position_to_data_dict):
    """Compute the six read-level summary values for every annotated position.

    Args:
        read_collection: object whose get_reads(chrom, int_pos) returns the reads
            handed to each metric function.
        position_to_data_dict: dict mapping (chrom, pos) to that site's data dict.

    Returns:
        dict mapping (chrom, pos) to a dict with the keys
        'AVG_CHANGES_PER_ALTERED_READ', 'AVG_CHANGES_PER_READ',
        'AVG_EDITING_CHANGES_PER_ALTERED_READ', 'AVG_EDITING_CHANGES_PER_READ',
        'ALTERED_SITE_%_VALID_EDITING' and 'ALL_CHANGES_%_VALID_EDITING'.
    """
    position_to_annotations = {}
    for (chrom, pos), data_dict in position_to_data_dict.items():
        # The same reads object is shared by all six metrics for this site.
        reads = read_collection.get_reads(chrom, int(pos))
        site_args = (chrom, pos, data_dict, reads)

        position_to_annotations[(chrom, pos)] = {
            'AVG_CHANGES_PER_ALTERED_READ':
                get_average_base_changes_on_changed_reads(*site_args),
            'AVG_CHANGES_PER_READ':
                get_average_base_changes_on_all_reads(*site_args),
            'AVG_EDITING_CHANGES_PER_ALTERED_READ':
                get_average_rna_editing_base_changes_on_changed_reads(*site_args),
            'AVG_EDITING_CHANGES_PER_READ':
                get_average_rna_editing_base_changes_on_all_reads(*site_args),
            'ALTERED_SITE_%_VALID_EDITING':
                get_percent_altered_reads_valid_editing_changes(*site_args),
            'ALL_CHANGES_%_VALID_EDITING':
                get_percent_reads_valid_editing_changes(*site_args),
        }

    return position_to_annotations
    

In [26]:
# Compute the per-position read-level summary values for every annotated site.
annotations = get_annotations(rc, position_to_data_dict)

In [27]:
# Dump every (chrom, pos) key with its annotation dict, one per line.
for k, v in annotations.items():
    print(k, v)

('chr1', '100393745') {'AVG_CHANGES_PER_ALTERED_READ': 1.8, 'AVG_CHANGES_PER_READ': 0.45, 'AVG_EDITING_CHANGES_PER_ALTERED_READ': 1.6, 'AVG_EDITING_CHANGES_PER_READ': 0.4, 'ALTERED_SITE_%_VALID_EDITING': 1.0, 'ALL_CHANGES_%_VALID_EDITING': 0.8888888888888888}
('chr1', '110381504') {'AVG_CHANGES_PER_ALTERED_READ': 2.8, 'AVG_CHANGES_PER_READ': 2.0, 'AVG_EDITING_CHANGES_PER_ALTERED_READ': 0.0, 'AVG_EDITING_CHANGES_PER_READ': 0.0, 'ALTERED_SITE_%_VALID_EDITING': 0.0, 'ALL_CHANGES_%_VALID_EDITING': 0.0}
('chr1', '11083033') {'AVG_CHANGES_PER_ALTERED_READ': 1.5, 'AVG_CHANGES_PER_READ': 0.43478260869565216, 'AVG_EDITING_CHANGES_PER_ALTERED_READ': 1.5, 'AVG_EDITING_CHANGES_PER_READ': 0.43478260869565216, 'ALTERED_SITE_%_VALID_EDITING': 1.0, 'ALL_CHANGES_%_VALID_EDITING': 1.0}
('chr1', '114397580') {'AVG_CHANGES_PER_ALTERED_READ': 9.5, 'AVG_CHANGES_PER_READ': 1.4482758620689655, 'AVG_EDITING_CHANGES_PER_ALTERED_READ': 0.0, 'AVG_EDITING_CHANGES_PER_READ': 0.034482758620689655, 'ALTERED_SITE_%_VA

In [None]:
# Scratch: inspect the reads stored for a single position.
rc.get_reads('chr1', 147178361)

In [None]:
# Scratch: list every (chrom, pos, cigar, seq) key; this prints the whole key set.
read_tups_to_reference_seq.keys()

In [None]:
# Scratch: look up the reference sequence stored for one specific 75M read.
read_tups_to_reference_seq[('chr1', '1483003', '75M',
                            'GAAAAAAATGGCCAGGCGGTAGTGGCTCAGGCCTGTAATCCCAGCATTTTCGGAGGCGGAGGTGGGCGGATCGCG')]

In [None]:
# Scratch: length of a hand-pasted sequence string.
len('GAAAAAAATGGCCAGGCGGTAGTGGCTCAGGCCTGTAATCCCAGCATTTTCGGAGGCGGAGGTGGGCGGATCACGA')

In [None]:
# Scratch: how many reads are stored for chr1:100393745.
len(rc.get_reads('chr1', 100393745))

In [None]:
# Scratch: show the reads stored for chr1:100393745.
rc.get_reads('chr1', 100393745)

In [None]:
# Scratch: computed summary values for chr1:100393745 (keys use the string position).
annotations[('chr1', '100393745')]

In [None]:
# Scratch: parsed VAF/annotation record for the same site.
position_to_data_dict[('chr1', '100393745')]

In [None]:
def count_mismatches(s1, s2):
    """Count positions at which s1 and s2 hold different elements.

    The sequences are compared pairwise with zip(), so anything beyond the
    length of the shorter one is ignored.
    """
    return sum(1 for left, right in zip(s1, s2) if left != right)

In [None]:
# NOTE(review): no earlier visible cell defines `a` or `b`, and the only visible
# assignment to `a` (in a later cell) is a str, which has no .difference method.
# Looks like leftover scratch that would fail under Restart & Run All -- confirm.
a.difference(b)

In [None]:
# Sample string used by the split experiments in the cells below.
a = '50S20=20S'

In [None]:
# str.split() only accepts a plain-string separator, so handing it a compiled
# pattern raises TypeError; split through the pattern object instead.
re.compile(r'M|X|=|N|D|I|S|H|P').split(a)

In [None]:
# Split `a` on the listed letters/symbols, leaving the digit runs between them
# (plus a trailing empty string when `a` ends with one of the separators).
re.split(re.compile(r'M|X|=|N|D|I|S|H|P'), a)

In [None]:
# Split `a` on runs of digits, leaving the characters between them
# (plus a leading empty string when `a` starts with a digit).
re.split(re.compile(r'[0-9]+'), a)