In [None]:
import csv
import os.path
import re
import statistics
import uuid as uuid_gen
from random import sample
from random import seed

import mappy as mp
import numpy as np
import pandas as pd
import pyfastx
import pysam
from Bio import SeqIO
from Bio.Align import MultipleSeqAlignment
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

In [None]:
# Reference genome: every FASTA record of dm6.fa, keyed by record id.
fastafile = 'dm6.fa'
chrms = ['chr2L', 'chr2R', 'chr3L', 'chr3R', 'chrX', 'chr4', 'chrM']  # NOTE(review): not used in this cell
dm6 = {}
total_length = 0  # summed length of ALL records in the file, not only `chrms`
for seq_record in SeqIO.parse(fastafile, 'fasta'):
    chrm, chrom_seq = seq_record.id, seq_record.seq
    dm6[chrm] = chrom_seq
    total_length += len(chrom_seq)

In [None]:
# Base complement table. Lowercase (soft-masked) bases are complemented and keep their case.
nt_complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N': 'N',
                 'a': 't', 't': 'a', 'c': 'g', 'g': 'c', 'n': 'n'}


def get_reverse_complement(seq):
    """Return the reverse complement of a nucleotide sequence as a str.

    Parameters
    ----------
    seq : sequence of single-character bases (str or Bio.Seq)
        Bases from {A, C, G, T, N}, either case.

    Raises
    ------
    KeyError
        If `seq` contains a character not in `nt_complement`.
    """
    # join() builds the result in linear time (repeated `+=` is quadratic).
    return ''.join(nt_complement[nt] for nt in reversed(seq))

In [None]:
nt_alphabet = ['A', 'C', 'G', 'T']
twomers = [first + second for first in nt_alphabet for second in nt_alphabet]


def check_for_low_complexity(string):
    """Flag sequences dominated by one dinucleotide.

    Returns 1 if deleting every non-overlapping copy of some dinucleotide in
    `twomers` leaves fewer than 70% of the original characters, else 0.
    Matching is case-sensitive (uppercase A/C/G/T only).
    """
    cutoff = 0.7 * len(string)
    return int(any(len(string.replace(dimer, '')) < cutoff for dimer in twomers))

In [None]:
# TE subfamily names grouped under the "germline active" label for this analysis.
germline_active = ('hobo BS roo Tabor opus mdg1 FB jockey '
                   'F-element 1360 HeT-A Doc Tirant flea Stalker').split()

In [None]:
# TE consensus library, from
# https://github.com/bergmanlab/drosophila-transposons/blob/master/current/D_mel_transposon_sequence_set.fa
teTELR = '/D_mel_transposon_sequence_set.fa'


def load_ref_te(teTELR):
    """Index a TE consensus FASTA by its classification.

    Record ids are split as ``<subfamily>#<order>/<superfamily>``.

    Parameters
    ----------
    teTELR : str
        Path to the TE consensus FASTA.

    Returns
    -------
    tes : dict
        ``tes[order][superfamily][subfamily]`` -> consensus ``Seq``.
    tes_by_subfamily : dict
        subfamily -> consensus ``Seq`` (a later record with the same subfamily wins).
    ref_te_len : dict
        subfamily -> consensus length in bp.
    tes_superfamily_order : dict
        superfamily -> order of the first record seen with that superfamily.
    """
    tes = {}
    tes_by_subfamily = {}
    tes_superfamily_order = {}

    for seq_record in SeqIO.parse(teTELR, 'fasta'):
        ln = seq_record.id.split('#')
        subfamily = ln[0]
        order = ln[1].split('/')[0]
        superfamily = ln[1].split('/')[1]
        tes.setdefault(order, {}).setdefault(superfamily, {})[subfamily] = seq_record.seq
        tes_superfamily_order.setdefault(superfamily, order)
        tes_by_subfamily[subfamily] = seq_record.seq

    ref_te_len = {}
    for order in tes:
        for superfamily in tes[order]:
            for subfamily in tes[order][superfamily]:
                ref_te_len[subfamily] = len(str(tes[order][superfamily][subfamily]))

    return tes, tes_by_subfamily, ref_te_len, tes_superfamily_order


tes, tes_by_subfamily, ref_te_len, tes_superfamily_order = load_ref_te(teTELR)

# One row per consensus (order, superfamily, subfamily, length). This needs `tes`,
# so it has to run after load_ref_te() above.
tesreformat = [
    [order, superfamily, subfamily, len(str(te_seq))]
    for order, superfamilies in tes.items()
    for superfamily, subfamilies in superfamilies.items()
    for subfamily, te_seq in subfamilies.items()
]
tedf = pd.DataFrame(tesreformat, columns=['order', 'superfamily', 'subfamily', 'len'])

In [None]:
def preprocess_tldr_file(tldr_infile, outfile, sample, ref_te_len):
    """Copy a tldr table, appending sample, split consensus and reference length.

    The header line gets the extra columns Sample, Left_flank, TE, Right_flank and
    Ref_length. Rows whose subfamily (column 7) is 'NA' are dropped. The consensus
    (column 22) is split on the colour markers it contains into left flank, TE body
    and right flank.

    Parameters
    ----------
    tldr_infile, outfile : str
        Input and output paths; a leading '~' is expanded.
    sample : str
        Value written to the Sample column.
    ref_te_len : dict
        subfamily -> reference consensus length. A KeyError is raised for a
        subfamily missing from it.
    """
    # Builtin open() does not expand '~' (pandas.read_csv does); callers pass '~/...'.
    tldr_infile = os.path.expanduser(tldr_infile)
    outfile = os.path.expanduser(outfile)
    with open(tldr_infile, 'r') as indata, open(outfile, 'w') as outdata:
        for i, line in enumerate(indata):
            stripped = line.rstrip()
            if i == 0:
                outdata.write(stripped + '\t' + '\t'.join(
                    ['Sample', 'Left_flank', 'TE', 'Right_flank', 'Ref_length']) + '\n')
                continue
            ln = stripped.split('\t')
            subfamily = ln[6]
            if subfamily == 'NA':
                continue

            # Split once: [left flank + TE, right flank].
            parts = ln[21].split('[0m[91m')
            right_flank = parts[1].replace('[0m', '')
            left_and_te = parts[0].split('[0m[33m')
            left_flank = left_and_te[0].replace('[91m', '')
            te = left_and_te[1].replace('[34m', '').replace('[33m', '')
            outdata.write(stripped + '\t' + '\t'.join(
                [sample, left_flank, te, right_flank, str(ref_te_len[subfamily])]) + '\n')

In [None]:
# Samples in processing order (used when reading per-sample tldr tables).
samples = [
    'PGFP_5d_guts_P2', 'PGFP_5d_heads_P2',
    'PGFP_5d_guts_P3', 'PGFP_5d_heads_P3',
    'PGFP_25d_guts_P1', 'PGFP_25d_guts_P2',
    'PGFP_25d_heads_P1', 'PGFP_25d_heads_P2',
    'PGFP_50d_guts_P1', 'PGFP_50d_heads_P1',
    'PGFP_50d_guts_P2', 'PGFP_50d_heads_P2',
]

# Column order for per-sample fields: each gut sample is followed by its matching head sample.
sample_order = [
    'PGFP_5d_guts_P2', 'PGFP_5d_heads_P2',
    'PGFP_5d_guts_P3', 'PGFP_5d_heads_P3',
    'PGFP_25d_guts_P1', 'PGFP_25d_heads_P1',
    'PGFP_25d_guts_P2', 'PGFP_25d_heads_P2',
    'PGFP_50d_guts_P1', 'PGFP_50d_heads_P1',
    'PGFP_50d_guts_P2', 'PGFP_50d_heads_P2',
]

# Each sample -> the other tissue from the same age/replicate (guts <-> heads).
sample_map = {
    name: (name.replace('_guts_', '_heads_') if '_guts_' in name
           else name.replace('_heads_', '_guts_'))
    for name in sample_order
}

In [None]:
# Reformat each sample's tldr table. Paths are expanded here because the builtin
# open() used by preprocess_tldr_file does not understand '~'.
for sample in samples:
    tldr_infile = os.path.expanduser('~/' + sample + '.dm6.porechop.sorted.flt.table.txt')
    tldr_outfile = os.path.expanduser('~/' + sample + '.dm6.porechop.sorted.flt.table.reformat.txt')
    preprocess_tldr_file(tldr_infile, tldr_outfile, sample, ref_te_len)

In [None]:
# Read every reformatted per-sample table and stack them into one frame.
tldr = {
    sample: pd.read_csv('~/' + sample + '.dm6.porechop.sorted.flt.table.reformat.txt',
                        sep='\t', header=0)
    for sample in samples
}
tldr_not_flt = pd.concat(tldr.values()).reset_index(drop=True)

In [None]:
# Keep calls with a subfamily, >50% unmapped insert cover, >500 bp insert,
# median MAPQ >= 30, and not on chrY.
passes_qc = (
    tldr_not_flt['Subfamily'].notnull()
    & (tldr_not_flt['UnmapCover'] > 0.5)
    & (tldr_not_flt['LengthIns'] > 500)
    & (tldr_not_flt['MedianMapQ'] >= 30)
    & (tldr_not_flt['Chrom'] != 'chrY')
)
tldr = tldr_not_flt[passes_qc].reset_index(drop=True)

In [None]:
# EmptyReads holds '<prefix>|<count>'; keep the count, 0 when there is no '|'.
empty_split = tldr['EmptyReads'].str.split('|', expand=True)
tldr[['tmp', 'EmptyReads']] = empty_split
tldr = tldr.drop('tmp', axis=1)
tldr['EmptyReads'] = tldr['EmptyReads'].fillna(0).astype(int)
tldr['UsedReads'] = tldr['UsedReads'].astype(int)

# Coverage = supporting + empty reads; AF = supporting fraction. Keep sites with coverage >= 10.
tldr['Coverage'] = tldr['UsedReads'] + tldr['EmptyReads']
tldr['AF'] = tldr['UsedReads'] / tldr['Coverage']
tldr = tldr[tldr['Coverage'] >= 10].reset_index(drop=True)

In [None]:
# Cluster the filtered tldr calls into sites. Each cluster gets a random hex key.
# A row joins an existing cluster when a previously seen call with the same
# (Chrom, Strand, Subfamily) has exactly the same Start/End. Otherwise it joins one
# with Start in [start-100, start+99] and End in [end-100, end+99], taking the first
# hit scanning Start, then End, in increasing order. Otherwise it starts a new cluster.
#   row_data[key]    -> list of tldr rows (pandas Series from iterrows) in the cluster
#   supp_data[key]   -> list of (sample, conservation)
#   coords_data[key] -> list of (chrm, strand, start, end, subfamily,
#                                conservation, conservation_val, uuid, key)
#   hash_data[(chrm, strand, start, end, subfamily)] -> cluster key
supp_data = {}
row_data = {}
hash_data = {}
coords_data = {}

for index, row in tldr.iterrows():
    uuid = row['UUID']
    chrm = row['Chrom']
    start = row['Start']
    end = row['End']
    strand = row['Strand']
    subfamily = row['Subfamily']
    
    unmapcov = row['UnmapCover']
    ref_length = row['Ref_length']
    length_ins = row['LengthIns'] 
    # Sample name is the part of SampleReads before the first '.'.
    sample = row['SampleReads'].split('.')[0]
    
    found = 0
    imprecise = 0
    
    # Unmapped inserted length relative to the reference consensus length.
    # >= 0.8 -> 'FL', [0.7, 0.8) -> 'AFL', else 'TR'
    # (presumably full-length / almost-full-length / truncated).
    conservation_val = round(length_ins * unmapcov / ref_length, 3)
        
    if length_ins * unmapcov / ref_length >= 0.8:
            conservation = 'FL'
    elif (length_ins * unmapcov / ref_length >= 0.7) and (length_ins * unmapcov / ref_length < 0.8):
            conservation = 'AFL'
    else:
            conservation = 'TR'
                                    
    if (chrm, strand, start, end, subfamily) in hash_data:
        # Exact coordinate match with an earlier call.
        key = hash_data[(chrm, strand, start, end, subfamily)] 
        row_data[key].append(row)
        supp_data[key].append((sample, conservation))
        coords_data[key].append((chrm, strand, start, end, subfamily,
                                 conservation, conservation_val, uuid, key))
    else:
        # Window search around (start, end): up to 200 x 200 dict lookups per row.
        for i in range(start - 100, start + 100):
            for j in range(end - 100, end + 100):
                if (chrm, strand, i, j, subfamily) in hash_data:
                    key = hash_data[(chrm, strand, i, j, subfamily)]
                    # Register these exact coordinates so later identical calls hit the fast path.
                    hash_data[(chrm, strand, start, end, subfamily)] = key
                    row_data[key].append(row)
                    supp_data[key].append((sample, conservation))
                    coords_data[key].append((chrm, strand, start, end, subfamily,
                                             conservation, conservation_val, uuid, key))
                    found = 1
                    break
            if found == 1:
                break
        if found == 0:
            # No nearby call: start a new cluster.
            key = uuid_gen.uuid4().hex
            hash_data[(chrm, strand, start, end, subfamily)] = key
            row_data[key] = []
            supp_data[key] = []
            coords_data[key] = []
            row_data[key].append(row)
            supp_data[key].append((sample, conservation))
            coords_data[key].append((chrm, strand, start, end, subfamily, 
                                     conservation, conservation_val, uuid, key))

In [None]:
skey = uuid_gen.uuid4().hex
clusters_by_chrom = {}

# Order each cluster's calls by reference start (tuple index 2).
for key, insertions in coords_data.items():
    insertions.sort(key=lambda call: call[2])

# Group clusters by (chromosome, strand, subfamily), taken from each cluster's first call.
for key, insertions in coords_data.items():
    first_call = insertions[0]
    chrm, strand, subfamily = first_call[0], first_call[1], first_call[4]
    clusters_by_chrom.setdefault((chrm, strand, subfamily), []).append(insertions)

# Within a group, order clusters by the start of their first (leftmost) call.
for (chrm, strand, subfamily), group in clusters_by_chrom.items():
    group.sort(key=lambda cluster: cluster[0][2])

In [None]:
# Merge neighbouring clusters of the same (chrom, strand, subfamily) into super-clusters.
# A cluster joins the current super-cluster when its start lies within 300 bp of the
# mean start of the clusters already merged into it; otherwise it opens a new one.
sclusters_merged = {}     # super-cluster key -> list of clusters
wiggly_breakpoints = []   # super-clusters where a merged start is > 200 bp from the updated mean
normal_breakpoints = []   # super-clusters where a merged start is <= 200 bp from the updated mean

for (chrm, strand, subfamily) in clusters_by_chrom:
    skey = None
    prev_cluster_start = []   # starts of the clusters merged into `skey` so far
    for cluster in clusters_by_chrom[(chrm, strand, subfamily)]:
        cluster_start = cluster[0][2]
        # An empty list means no super-cluster is open yet. There is no 0 sentinel,
        # so coordinate 0 is never averaged into the mean.
        if prev_cluster_start and abs(cluster_start - np.mean(prev_cluster_start)) <= 300:
            sclusters_merged[skey].append(cluster)
            prev_cluster_start.append(cluster_start)
            if abs(cluster_start - np.mean(prev_cluster_start)) > 200:
                if skey not in wiggly_breakpoints:
                    wiggly_breakpoints.append(skey)
            elif skey not in normal_breakpoints:
                normal_breakpoints.append(skey)
        else:
            skey = uuid_gen.uuid4().hex
            sclusters_merged[skey] = [cluster]
            prev_cluster_start = [cluster_start]

In [None]:
# Collapse each super-cluster into one row of tldr_merged_list. Per-sample fields are
# ';'-joined in `sample_order`. Coverage, read-ratio and remappability filters are applied.
coverage_threshold = 15       # at least one sample must reach this coverage
coverage_max_threshold = 200  # clusters with any sample above this coverage are dropped

tldr_merged_list = []
skip_coverage = 0
skip_coverage_deep = 0
skip_rr = 0
skip_remappable = 0
index = 0
count_rr_outlier = 0  # NOTE(review): never incremented below; only printed
keys_to_check = []    # super-clusters where some sample contributes more than one tldr row


def _tsd_label(raw_tsd):
    """Return 'noTSD' for a missing (NaN) TSD, 'longTSD' if longer than 50 nt, else the TSD."""
    if (not isinstance(raw_tsd, str)) and np.isnan(raw_tsd):
        return 'noTSD'
    if len(raw_tsd) > 50:
        return 'longTSD'
    return raw_tsd


def _tissue_label(sample_names):
    """Return 'sboth', 'gut', 'head' or 'unknown' from 'guts'/'heads' in the sample names."""
    joined = ''.join(sample_names)
    in_gut = 'guts' in joined
    in_head = 'heads' in joined
    if in_gut and in_head:
        return 'sboth'
    if in_gut:
        return 'gut'
    if in_head:
        return 'head'
    return 'unknown'


for skey in sclusters_merged:
    # Cluster keys and conservation classes of every call merged under skey.
    keys = []
    preserved = []
    for cluster in sclusters_merged[skey]:
        for insertion in cluster:
            preserved.append(insertion[5])
            key = insertion[8]
            if key not in keys:
                keys.append(key)

    # Keep only super-clusters with at least one 'FL' or 'AFL' call.
    if ('FL' not in preserved) and ('AFL' not in preserved):
        continue

    # `cluster_samples` is local to this super-cluster, so the global `samples`
    # list used by other cells is left untouched.
    cluster_samples = []
    starts = []
    ends = []
    telengths = {s: [] for s in sample_order}
    mapqs = {s: [] for s in sample_order}
    uuids = {s: [] for s in sample_order}
    usedreads = {s: [] for s in sample_order}
    spanreads = {s: [] for s in sample_order}
    coverage = {s: [] for s in sample_order}
    TSDs = {s: [] for s in sample_order}
    remappable_flag = {s: [] for s in sample_order}
    filter_flag = {s: [] for s in sample_order}

    for key in keys:
        for row in row_data[key]:
            sample = row['SampleReads'].split('.')[0]
            if sample not in cluster_samples:
                cluster_samples.append(sample)
            # Location/classification fields: the last row seen wins.
            chrm = row['Chrom']
            strand = row['Strand']
            family = row['Family']
            order = tes_superfamily_order[family]
            subfamily = row['Subfamily']
            reflength = row['Ref_length']
            umapcov = row['UnmapCover']
            starts.append(row['Start'])
            ends.append(row['End'])

            uuids[sample].append(row['UUID'])
            telengths[sample].append(str(round(row['LengthIns'] * umapcov)))
            mapqs[sample].append(str(row['MedianMapQ']))
            usedreads[sample].append(row['UsedReads'])
            spanreads[sample].append(row['SpanReads'])
            coverage[sample].append(row['Coverage'])
            TSDs[sample].append(_tsd_label(row['TSD']))
            remappable_flag[sample].append(str(row['Remappable']))
            filter_flag[sample].append(row['Filter'])

    tissue = _tissue_label(cluster_samples)

    # One entry per sample in sample_order; '-' (or '0' for counts) when absent.
    tmp_uuid = []
    tmp_sample = []
    tmp_usedreads = []
    tmp_spanreads = []
    tmp_coverage = []
    tmp_TSD = []
    tmp_remappable = []
    tmp_filter = []
    tmp_telength = []
    tmp_mapq = []
    coverage_filter = 0
    remappable_filter = 0

    for sample in sample_order:
        if sample not in cluster_samples:
            tmp_uuid.append('-')
            tmp_sample.append('-')
            tmp_usedreads.append('0')
            tmp_spanreads.append('0')
            tmp_coverage.append('0')
            tmp_TSD.append('-')
            tmp_remappable.append('-')
            tmp_filter.append('-')
            tmp_telength.append('-')
            tmp_mapq.append('-')
            continue

        tmp_uuid.append('|'.join(uuids[sample]))
        tmp_sample.append(sample)
        tmp_usedreads.append(str(sum(usedreads[sample])))
        if len(usedreads[sample]) > 1 and skey not in keys_to_check:
            keys_to_check.append(skey)
        tmp_spanreads.append(str(sum(spanreads[sample])))
        tmp_coverage.append(str(max(coverage[sample])))
        if max(coverage[sample]) >= coverage_threshold:
            coverage_filter = 1
        tmp_TSD.append('|'.join(TSDs[sample]))
        joined_remap = '|'.join(remappable_flag[sample])
        tmp_remappable.append(joined_remap)
        if 'True' in joined_remap:
            remappable_filter = 1
        tmp_filter.append('|'.join(filter_flag[sample]))
        tmp_telength.append('|'.join(telengths[sample]))
        tmp_mapq.append('|'.join(mapqs[sample]))

    if coverage_filter == 0:
        skip_coverage += 1
        continue

    total_usedreads = sum(int(x) for x in tmp_usedreads)
    total_coverage = sum(int(x) for x in tmp_coverage)
    coverage_deep_filter = int(any(int(x) > coverage_max_threshold for x in tmp_coverage))

    # Per-sample read ratio, only for samples with coverage >= coverage_threshold.
    tmp_rrs = [int(used) / int(cov)
               for used, cov in zip(tmp_usedreads, tmp_coverage)
               if int(cov) >= coverage_threshold]

    rr = total_usedreads / total_coverage
    rr_median = statistics.median(tmp_rrs)

    # Drop clusters whose pooled and median read ratios differ by 3x or more.
    if (len(tmp_rrs) > 1) and ((max(rr, rr_median) / min(rr, rr_median)) >= 3):
        skip_rr += 1
        continue

    rr = rr_median
    nsamples = len(cluster_samples)

    if coverage_deep_filter == 1:
        skip_coverage_deep += 1
        continue

    # Singleton and Rare calls are kept only if some sample's Remappable flag is 'True'.
    if total_usedreads == 1:
        genotype = 'Singleton'
        if remappable_filter == 0:
            skip_remappable += 1
            continue
    elif (rr < 0.1) and (tissue == 'sboth'):
        genotype = 'Rare'
        if remappable_filter == 0:
            skip_remappable += 1
            continue
    elif rr >= 0.1:
        genotype = 'Fixed'
    else:
        genotype = 'Ungenotyped'

    instype = 'TBD'
    inspreserv = 'FL' if 'FL' in preserved else 'AFL'

    new_row = [';'.join(tmp_uuid), ';'.join(tmp_sample),
               chrm, round(np.median(starts)), round(np.median(ends)), strand,
               order, family, subfamily,
               ';'.join(tmp_usedreads), ';'.join(tmp_spanreads), ';'.join(tmp_coverage),
               ';'.join(tmp_TSD), ';'.join(tmp_remappable), ';'.join(tmp_filter),
               ';'.join(tmp_telength), ';'.join(tmp_mapq), reflength,
               nsamples, rr, genotype, instype, tissue, inspreserv,
               np.min(starts), np.max(starts), 0.0, 0.0]
    tldr_merged_list.append(new_row)
    index += 1

print('low coverage', skip_coverage)
print('aberrant ontRR', skip_rr)
print('remappable filter', skip_remappable)
print('very high coverage', skip_coverage_deep)
print('RR outlier ', count_rr_outlier)

In [None]:
# Column names for tldr_merged_list rows. Per-sample fields are ';'-joined in sample_order.
columns = [
    'UUID', 'Sample',
    'Chrom', 'Start', 'End', 'Strand',
    'Order', 'Family', 'Subfamily',
    'UsedReads', 'SpanReads', 'Coverage',
    'TSD', 'Remappable', 'Filter',
    'TELength', 'MapQ', 'RefLength',
    'nSamples',
    'ontRR', 'Genotype', 'Type', 'Tissue', 'Preserved',
    'minRefStart', 'maxRefStart', 'illAF', 'illPF',
]
tldr_merged = pd.DataFrame(data=tldr_merged_list, columns=columns)

In [None]:
ref_te = pd.read_csv('~/PGFP_refTE_dm6.csv', sep='\t', index_col=0)

# Add the illAF/illPF columns that tldr_merged also has. Assigning an empty Series
# aligns on the index, so the columns hold missing values.
for placeholder_col in ('illAF', 'illPF'):
    ref_te[placeholder_col] = pd.Series(dtype='int')

In [None]:
# Remove near-duplicate rows of tldr_merged. For each row, `tmp` holds rows with the
# same Subfamily and Chrom whose [Start - 1000, End + 1000] window contains this row's
# Start (strand is not compared). Rows whose Filter contains 'PASS' are preferred.
# Indices to drop are collected in rows_to_remove and dropped after the loop.
# Rows already marked for removal still take part in later comparisons.
# count_genotype[genotype][subfamily] counts removals, keyed by the genotype and
# subfamily of the outer row, not of the removed row.
rows_to_remove = []
count_genotype = {}

for indx, row in tldr_merged.iterrows():    
    uuid = row['UUID']  # NOTE(review): uuid, end, strand, remap and fltr are unused below
    chrm = row['Chrom']
    start = row['Start']
    end = row['End']
    strand = row['Strand']
    genotype = row['Genotype']
    remap = row['Remappable']
    fltr = row['Filter']
    subfamily = row['Subfamily']
    
    if genotype not in count_genotype:
        count_genotype[genotype] = {}

    tmp = tldr_merged[(tldr_merged['Subfamily'] == subfamily) 
                 & (start >= (tldr_merged['Start'] - 1000)) & (start <= (tldr_merged['End'] + 1000))
                 & (tldr_merged['Chrom'] == chrm) ]
    
    if len(tmp.index) > 1:
        to_keep = []
        to_remove = []
        for indx2, row2 in tmp.iterrows():
            pass_filter = row2['Filter']
            tissue = row2['Tissue']
            
            if 'PASS' in pass_filter:
                to_keep.append([indx2, tissue])
            else:
                to_remove.append([indx2, tissue])
                            
            # NOTE(review): the three decisions below are indented inside the
            # `for indx2` loop. They therefore run after each row of `tmp` is
            # classified, using the partial to_keep/to_remove lists. Confirm whether
            # they were meant to run once after the loop; dedenting would change
            # which rows are removed.
            # Several PASS rows: if any is from both tissues ('sboth'), keep the last
            # such row and drop the other PASS rows; otherwise keep them all.
            if len(to_keep) > 1:
                max_tissue = ''
                good_indx = 0
                for [j, tissue] in to_keep:
                    if tissue == 'sboth':
                        max_tissue = 'sboth'
                        good_indx = j
                
                if max_tissue != 'sboth':
                    pass # keeping all
                else:
                    for [j, tissue] in to_keep:
                        if j!=good_indx:
                            if j not in rows_to_remove:
                                rows_to_remove.append(j)

                                if subfamily not in count_genotype[genotype]:
                                    count_genotype[genotype][subfamily] = 0
                                count_genotype[genotype][subfamily] +=1
            # Exactly one PASS row so far: drop the non-PASS rows seen so far.
            if len(to_keep) == 1:
                for [j, tissue] in to_remove:
                    if j not in rows_to_remove:
                        rows_to_remove.append(j)
                        
                        if subfamily not in count_genotype[genotype]:
                            count_genotype[genotype][subfamily] = 0
                        count_genotype[genotype][subfamily] +=1
            # No PASS row so far: keep the row of the whole `tmp` with the most samples
            # (nSamples idxmax) and drop the other non-PASS rows seen so far.
            if len(to_keep) == 0:
                good_indx = tmp['nSamples'].idxmax()
                for [j, tissue] in to_remove:
                    if j != good_indx:
                        
                        if j not in rows_to_remove:
                            rows_to_remove.append(j)

                            if subfamily not in count_genotype[genotype]:
                                count_genotype[genotype][subfamily] = 0
                            count_genotype[genotype][subfamily] +=1
        
tldr_merged.drop(rows_to_remove, inplace = True)
tldr_merged.reset_index(drop=True, inplace=True)

In [None]:
# Non-reference (tldr) insertions followed by reference TEs.
# DataFrame.append was removed in pandas 2.0; pd.concat is the equivalent
# (outer join on columns, no sorting).
result = pd.concat([tldr_merged, ref_te]).reset_index(drop=True)

In [None]:
# Calls found in only one tissue, grouped by genotype, for the soft-clip re-check below.
# Each entry: [chrom, minRefStart, maxRefStart, strand, 'Order_Family_Subfamily',
#              ';'-joined sample field, UUID].
coords = {'Rare': [], 'Singleton': [], 'Fixed': [], 'Ungenotyped': []}
for index, row in result.iterrows():
    genotype = row['Genotype']
    if genotype not in coords:
        continue

    # Local name: assigning to `samples` here would clobber the global sample list.
    ins_samples = row['Sample']
    in_gut = 'gut' in ins_samples
    in_head = 'head' in ins_samples
    if in_gut == in_head:  # keep only calls present in exactly one of the two tissues
        continue

    name = '_'.join([row['Order'], row['Family'], row['Subfamily']])
    coords[genotype].append([row['Chrom'], row['minRefStart'], row['maxRefStart'],
                             row['Strand'], name, ins_samples, row['UUID']])

In [None]:
def _clipped_record(read, sample, at_start, min_len=100):
    """Return the soft clip at one end of an alignment as a SeqRecord, or None.

    Parameters
    ----------
    read : pysam.AlignedSegment
        A mapped alignment that carries its query sequence.
    sample : str
        Sample name, used as the first field of the record id.
    at_start : bool
        True for the clip at the start of the alignment (first CIGAR operation),
        False for the clip at the end (last CIGAR operation).
    min_len : int
        Clips shorter than this are ignored.

    Only soft clips (CIGAR op 4) are used, because hard-clipped bases are not stored
    in ``query_sequence``. Reverse-strand alignments are reverse-complemented and
    tagged '-', forward ones tagged '+'. The record id is ``sample:read_name:strand``.
    """
    op, clip_len = read.cigartuples[0] if at_start else read.cigartuples[-1]
    if op != 4 or clip_len < min_len:
        return None
    query = read.query_sequence
    # The right clip is the last clip_len bases, including the final base.
    clipped = query[:clip_len] if at_start else query[len(query) - clip_len:]
    if read.is_reverse:
        clipped = str(Seq(clipped).reverse_complement())
        strand_tag = '-'
    else:
        strand_tag = '+'
    return SeqRecord(Seq(clipped), id=':'.join([sample, read.query_name, strand_tag]),
                     name='', description='')


flank = 200      # bp added on each side of [minstart, maxstart] when scanning alignments
max_reads = 200  # alignments inspected per BAM per insertion
min_mapq = 5

# `opath` is defined outside this cell; the TE-mapping cell reads the .fa files from it.
os.makedirs(opath, exist_ok=True)

# Same genotype keys as `coords` (there is no 'Unknown' key).
for genotype in ['Rare', 'Singleton', 'Ungenotyped', 'Fixed']:
    print(genotype)
    for coord in coords[genotype]:
        print(coord)
        chrm, minstart, maxstart, strand, name, ins_samples = coord[:6]
        window_lo = minstart - flank
        window_hi = maxstart + flank
        carriers = set(ins_samples.split(';'))
        # Same path the TE-mapping cell reads back (opath + ... + '.fa').
        outfile = opath + '_'.join([chrm, str(minstart), str(maxstart), name, strand])
        records = []
        # Only samples that do NOT carry the call. Iterate the fixed sample list,
        # because the global `samples` is reassigned by other cells.
        for sample in sample_order:
            if sample in carriers:
                continue
            bamfile = os.path.expanduser('~/' + sample + '.dm6.porechop.sorted.flt.bam')
            with pysam.AlignmentFile(bamfile, 'rb') as bam:
                for n_seen, read in enumerate(bam.fetch(chrm, max(0, window_lo), window_hi)):
                    if n_seen >= max_reads:
                        break
                    if (read.is_supplementary or read.is_unmapped
                            or read.mapping_quality < min_mapq
                            or read.query_sequence is None):
                        continue
                    aligned = read.get_reference_positions()
                    if not aligned:
                        continue
                    # Take a clip only when the alignment end it sits on lies in the window.
                    if window_lo <= aligned[0] <= window_hi:
                        left = _clipped_record(read, sample, at_start=True)
                        if left is not None:
                            records.append(left)
                    if window_lo <= aligned[-1] <= window_hi:
                        right = _clipped_record(read, sample, at_start=False)
                        if right is not None:
                            records.append(right)

        _ = SeqIO.write(records, outfile + '.fa', 'fasta')

In [None]:
# Bookkeeping for re-mapping the extracted soft clips against the TE library.
updated_genotypes = {}
families_updated = {}
families_notupdated = {}
notupdated_coords = {}
count_strand_filter = 0
# coords genotype key -> label; 'Ungenotyped' is relabelled 'Unknown'.
genotype_map = {
    'Rare': 'Rare',
    'Singleton': 'Singleton',
    'Ungenotyped': 'Unknown',
    'Fixed': 'Fixed',
}

# minimap2 aligner (mappy) built from the TE consensus FASTA.
a = mp.Aligner(teTELR)  # load or build index

for genotype in ['Rare', 'Singleton', 'Ungenotyped', 'Fixed']:    
    print(genotype)
    updated_bed = open('updated_coords_' + genotype + '.bed', 'w')
    notupdated_bed = open('notupdated_coords_' + genotype + '.bed', 'w')
    
    notupdated_coords[genotype] = []
    families_updated[genotype] = {}
    families_notupdated[genotype] = {}
    
    single_sample = 0
    total = 0
    total_updated = 0
    
    to_check_upd = []
    to_check_notupd = []
    
    for coord in coords[genotype]:
        chrm = coord[0]
        minstart = coord[1]
        maxstart = coord[2]
        te_strand = coord[3]
        name = coord[4]
        ins_samples = coord[5]
        flank = 200
        uuid = coord[-1]
        samples_to_add = []
        outfile = opath + '_'.join([chrm, str(minstart), str(maxstart), name, te_strand])
        te_query = name.split('_')[2]
        
        total +=1

        with open(apath + '_'.join([chrm, str(minstart), str(maxstart), name, genotype, te_strand]) + '.alg', 'w') as outdata:
            
            for read_name, seq, qual in mp.fastx_read(outfile  + '.fa'): # read a fasta/q sequence
                    for hit in a.map(seq): # traverse alignments
                        strand_filter = -1
                        if hit.is_primary and hit.mapq >= 15:
                            sample = read_name.split(':')[0]
                            if sample in ins_samples:
                                break
                            te_read = hit.ctg.split('#')[0]
                            read_strand = read_name.split(':')[2]
                            hit_strand = hit.strand
                            _ = outdata.write('\t'.join([read_name, te_query, te_strand, read_strand,
                                                 seq[hit.q_st:hit.q_en], hit.ctg, str(hit.r_st), str(hit.r_en), 
                                                 str(hit.mapq), hit.cigar_str, str(hit.strand)]) + '\n')
                        
                            if te_strand == '+':
                                if read_strand == '+':
                                    if hit_strand == 1:
                                        strand_filter = 1
                                    elif hit_strand == -1:
                                        strand_filter = 0
                                elif read_strand == '-':
                                    if hit_strand == 1:
                                        strand_filter = 0
                                    elif hit_strand == -1:
                                        strand_filter = 1
                            
                            elif te_strand == '-':
                                if read_strand == '+':
                                    if hit_strand == 1:
                                        strand_filter = 0
                                    elif hit_strand == -1:
                                        strand_filter = 1
                                elif read_strand == '-':
                                    if hit_strand == 1:
                                        strand_filter = 1
                                    elif hit_strand == -1:
                                        strand_filter = 0
                            
                            if (te_read == te_query) and (sample not in samples_to_add) \
                                and (strand_filter == 1) and (not check_for_low_complexity(seq[hit.q_st:hit.q_en])):
                                samples_to_add.append(sample)
                                
                            if strand_filter == 0:
                                count_strand_filter+=1
                            if strand_filter == -1:
                                print('Error: strand_filter')
                                break
                                                                                             
        if len(samples_to_add) == 0:
            if genotype=='Fixed':
                to_check_notupd.append(uuid)
                
            notupdated_coords[genotype].append(coord)
            _ = notupdated_bed.write('\t'.join([chrm, str(minstart), str(maxstart), '.', te_query, te_strand]) + '\n')  

            if te_query not in families_notupdated[genotype]:
                families_notupdated[genotype][te_query] = 1
            else:
                families_notupdated[genotype][te_query] +=1

            if len(ins_samples.split(';')) == 1:
                single_sample += 1
        else:
            if genotype=='Fixed':
                to_check_upd.append(uuid)
            if uuid in updated_genotypes:
                print('Error: uuid is already in updated_genotypes')
                break
            updated_genotypes[uuid]= 'add_samples:' + ';'.join(samples_to_add) 
            _ = updated_bed.write('\t'.join([chrm, str(minstart), str(maxstart), '.', te_query, te_strand]) + '\n')  
            total_updated +=1
            if te_query not in families_updated[genotype]:
                families_updated[genotype][te_query] = 1
            else:
                families_updated[genotype][te_query] +=1
    print(total, total_updated, total - total_updated, single_sample)

    families_notupdated[genotype]
    families_updated[genotype]
    notupdated_bed.close()
    updated_bed.close()

In [None]:
count_keep = 0
count_remove = 0
count_somatic = 0
count_fail_arm = 0
count_fail_ms = 0
count_fail_other = 0

# Ungenotyped calls that were not rescued above are re-checked against each
# contributing sample's `<uuid>.detail.out`; a call is kept when any sample's
# evaluation yields 'PASS', otherwise it is marked for removal.
for coord in notupdated_coords['Ungenotyped']:
    print(coord)
    uuid_key = coord[-1]
    per_sample_uuids = uuid_key.split(';')
    per_sample_names = coord[-2].split(';')

    # Flatten to parallel lists: '-' marks a sample without the call, and a
    # '|'-joined entry expands to one element per UUID.
    uuid_list = []
    samples_list = []
    for pos, sample_name in enumerate(per_sample_names):
        if sample_name == '-':
            continue
        for one_uuid in per_sample_uuids[pos].split('|'):
            uuid_list.append(one_uuid)
            samples_list.append(sample_name)

    # NOTE(review): `reads` and `families` accumulate across all UUIDs of this
    # coordinate, so each verdict is computed on the pooled evidence seen so
    # far — confirm pooling (rather than per-UUID reset) is intended.
    reads = []
    families = []
    pass_filter = []
    for one_uuid in uuid_list:
        read_count = 0
        with open(one_uuid + '.detail.out') as detail:
            next(detail, None)  # skip header line
            for line in detail:
                fields = line.rstrip().split('\t')
                supported = fields[-3] == 'True' and float(fields[-4]) >= 0.8
                if supported:
                    reads.append(fields[2])
                    families.append(fields[4].split(',')[1])
                read_count += 1

        if uuid_key in updated_genotypes:
            print('WARNING: uuid is already in updated_genotypes')

        n_unique_reads = len(set(reads))
        n_families = len(set(families))
        print(read_count, len(reads), n_unique_reads, n_families, end=' ')

        if n_unique_reads == 1 and n_families == 1:
            verdict = 'FAIL_ambigous_read_mapping'
        elif n_unique_reads >= 1 and n_families > 1:
            verdict = 'FAIL_multiple_subfamilies'
        elif n_unique_reads >= 2 and n_families == 1:
            verdict = 'PASS'
        else:
            verdict = 'FAIL_other'
        pass_filter.append(verdict)

    if 'PASS' in pass_filter:
        updated_genotypes[uuid_key] = 'keep'
        count_keep += 1
        continue

    count_remove += 1
    updated_genotypes[uuid_key] = 'remove'
    if 'FAIL_ambigous_read_mapping' in pass_filter:
        count_fail_arm += 1
    if 'FAIL_multiple_subfamilies' in pass_filter:
        count_fail_ms += 1
    if 'FAIL_other' in pass_filter:
        count_fail_other += 1

print('count_keep', count_keep)
print('count_remove', count_remove)
print('FAIL_ambigous_read_mapping, probably singleton', count_fail_arm)
print('FAIL_multiple_subfamilies', count_fail_ms)
print('FAIL_other', count_fail_other)

In [None]:
def _set_field(frame, header, index, vals, column, value):
    """Set `column` of row `index` to `value` in both `frame` and the row copy `vals`.

    `vals` is the flattened row list (ordered like `header`) that is collected
    into `updt_rows`; `frame` is modified in place via `.at`.
    """
    vals[header.index(column)] = value
    frame.at[index, column] = value


def _tissue_from_samples(samples):
    """Classify a ';'-joined sample string as 'sboth', 'gut', 'head', or None.

    Uses a case-insensitive substring search for 'gut' and 'head' over the
    whole string; returns None when neither occurs.
    """
    lowered = samples.lower()
    has_gut = 'gut' in lowered
    has_head = 'head' in lowered
    if has_gut and has_head:
        return 'sboth'
    if has_gut:
        return 'gut'
    if has_head:
        return 'head'
    return None


rows_to_remove = []
updt_rows = []
count_changed_gt_fixed = 0
count_changed_single_germvar = 0
count_changed_single_noTSD = 0
count_changed_single_longTSD = 0
count_singleton_to_rare = 0
count_singleton_to_ungenotyped = 0
count_ungenotyped_to_rare = 0
header = list(result.columns.values)

# Apply the per-UUID actions collected in `updated_genotypes` and re-derive
# Genotype / Tissue; rows without an action get the default reclassification.
for index, row in result.iterrows():
    uuid = row['UUID']
    genotype = row['Genotype']
    samples = row['Sample']
    tissue = row['Tissue']
    subfamily = row['Subfamily']
    nsamples = int(row['nSamples'])
    vals = row.values.flatten().tolist()
    flag_to_drop = 0

    if uuid in updated_genotypes:
        action = updated_genotypes[uuid]
        if action in ('remove', 'remove_somatic'):
            if index in rows_to_remove:
                print('Warning: duplicated index in rows_to_remove')
            rows_to_remove.append(index)
            flag_to_drop = 1

        elif action == 'rare_to_somatic':
            _set_field(result, header, index, vals, 'Genotype', 'Somatic')

        elif 'add_samples:' in action:
            samples_list = samples.split(';')
            samples_toadd_list = action.split(':')[1].split(';')

            # Sample is a positional list aligned with `sample_order`; fill the
            # requested slots and count only slots that actually change, so
            # nSamples is not inflated by samples that are already present or
            # that are not in `sample_order`.
            n_added = 0
            for i, sample in enumerate(sample_order):
                if sample in samples_toadd_list and samples_list[i] != sample:
                    samples_list[i] = sample
                    n_added += 1

            samples = ';'.join(samples_list)
            _set_field(result, header, index, vals, 'Sample', samples)
            _set_field(result, header, index, vals, 'nSamples', nsamples + n_added)

            new_tissue = _tissue_from_samples(samples)
            if new_tissue is not None:
                _set_field(result, header, index, vals, 'Tissue', new_tissue)
                tissue = new_tissue

            single_tissue = tissue in ('gut', 'head')
            if genotype == 'Fixed':
                if single_tissue:
                    count_changed_gt_fixed += 1
                    _set_field(result, header, index, vals, 'Genotype', 'Ungenotyped')
            elif genotype == 'Singleton':
                if tissue == 'sboth':
                    _set_field(result, header, index, vals, 'Genotype', 'Rare')
                    count_singleton_to_rare += 1
                elif single_tissue:
                    _set_field(result, header, index, vals, 'Genotype', 'Ungenotyped')
                    count_singleton_to_ungenotyped += 1
            elif genotype == 'Ungenotyped':
                # A single-tissue Ungenotyped call simply stays Ungenotyped.
                if tissue == 'sboth':
                    _set_field(result, header, index, vals, 'Genotype', 'Rare')
                    count_ungenotyped_to_rare += 1
        # 'keep' / 'keep_somatic' (and any other action) leave the row unchanged.
    else:
        if (genotype == 'Singleton') and (subfamily in germline_active):
            _set_field(result, header, index, vals, 'Genotype', 'Ungenotyped')
            count_changed_single_germvar += 1

        elif (genotype == 'Singleton') and ('noTSD' in row['TSD']):
            _set_field(result, header, index, vals, 'Genotype', 'Ungenotyped')
            count_changed_single_noTSD += 1

        elif (genotype == 'Singleton') and ('longTSD' in row['TSD']):
            _set_field(result, header, index, vals, 'Genotype', 'Ungenotyped')
            count_changed_single_longTSD += 1

        elif (genotype == 'Fixed') and (tissue in ('gut', 'head')):
            count_changed_gt_fixed += 1
            _set_field(result, header, index, vals, 'Genotype', 'Ungenotyped')

    if not flag_to_drop:
        updt_rows.append(vals)

print('count_changed_gt_fixed=', count_changed_gt_fixed)
print('count_changed_single_germvar=', count_changed_single_germvar)
print('count_changed_single_noTSD=', count_changed_single_noTSD)
print('count_changed_single_longTSD=', count_changed_single_longTSD)
print('count_singleton_to_rare', count_singleton_to_rare)
print('count_singleton_to_ungenotyped', count_singleton_to_ungenotyped)
print('count_ungenotyped_to_rare', count_ungenotyped_to_rare)
print('number of rows to remove=', len(rows_to_remove))

result = result.drop(index=rows_to_remove).reset_index(drop=True)

In [None]:
def _total_used_reads(used_reads):
    """Sum the ';'-separated per-sample read counts stored in a UsedReads cell.

    Empty fields are ignored, so '0;3;0' -> 3 and '0;10' -> 10 (the previous
    strip-all-zeros trick turned '10' into 1 and raised on all-zero strings).
    """
    return sum(int(count) for count in str(used_reads).split(';') if count.strip())


count_singleton_noTSD = 0
count_singleton_longTSD = 0
count_singleton_gv = 0
count_pot_clonal = 0
count_amb_samples = 0
count_other = 0

# Break the remaining Ungenotyped calls down by the reason they could not be genotyped.
for idx, row in result[(result['Genotype'] == 'Ungenotyped')].iterrows():
    tissue = row['Tissue']
    subfamily = row['Subfamily']
    nSamples = int(row['nSamples'])
    tsd = row['TSD']

    if nSamples == 1:
        # Only parsed for single-sample calls, as in the original condition chain.
        used_reads = _total_used_reads(row['UsedReads'])
        if (used_reads == 1) and (subfamily in germline_active):
            count_singleton_gv += 1
        elif (used_reads == 1) and ('noTSD' in tsd):
            count_singleton_noTSD += 1
        elif (used_reads == 1) and ('longTSD' in tsd):
            count_singleton_longTSD += 1
        elif used_reads > 1:
            count_pot_clonal += 1
        else:
            count_other += 1
    elif (nSamples > 1) and (tissue != 'sboth'):
        count_amb_samples += 1
    else:
        count_other += 1

print('count_singleton_noTSD', count_singleton_noTSD)
print('count_singleton_longTSD', count_singleton_longTSD)
print('count_singleton_gv', count_singleton_gv)
print('count_pot_clonal', count_pot_clonal)
print('count_amb_samples', count_amb_samples)
print('count_other', count_other)

In [None]:
# Tab-separated Illumina call table; its first column becomes the row index.
illumina = pd.read_csv('PGFP_Illumina.csv', sep='\t', header=0, index_col=0)

In [None]:
# Initialise illGT as an all-missing column converted with astype(str); the
# matching cell that follows overwrites entries that have a single Illumina match.
result['illGT'] = pd.Series(np.nan, index=result.index).astype(str)

In [None]:
ILL_WINDOW_BP = 100       # max distance between long-read and Illumina Start coordinates
N_ILLUMINA_SAMPLES = 31   # divisor turning nPairs into illPF; presumably the Illumina sample count — TODO confirm

# Transfer Illumina allele frequency / genotype to each call that has exactly
# one Illumina call of the same subfamily nearby on the same chromosome.
for indx, row in result.iterrows():
    chrm = row['Chrom']
    start = row['Start']
    strand = row['Strand']
    subfamily = row['Subfamily']

    # `== np.nan` is always False (NaN never equals anything), so strand-less
    # Illumina calls must be matched with `.isna()`.
    candidates = illumina[(illumina['Chrom'] == chrm) &
                          illumina['Start'].between(start - ILL_WINDOW_BP, start + ILL_WINDOW_BP) &
                          (illumina['Subfamily'] == subfamily) &
                          ((illumina['Strand'] == strand) | illumina['Strand'].isna())
                          ]

    # Zero or several candidates are ambiguous: leave the row's defaults untouched.
    if len(candidates.index) == 1:
        result.at[indx, 'illAF'] = candidates['meanAF'].iloc[0]
        result.at[indx, 'illPF'] = candidates['nPairs'].iloc[0] / N_ILLUMINA_SAMPLES
        result.at[indx, 'illGT'] = candidates['Genotype'].iloc[0]

In [None]:
# Write the final tab-separated table (index included) to the user's home directory.
final_table_path = os.path.expanduser('~/final_table.csv')
result.to_csv(final_table_path, sep='\t')