## Loading data

In [1]:
import pysam

In [2]:
# Open the coordinate-sorted BAM for reading ("rb" = read, binary/BAM).
# NOTE(review): region queries made later (fetch() inside ReadDepthAlgorithm) presumably
# need a .bai index next to this file — confirm it exists.
sequence_reads = pysam.AlignmentFile("SRR_final_sorted.bam", "rb")

In [3]:
# Read count reported by pysam's AlignmentFile.count() over the whole file.
# NOTE(review): which records are included (e.g. unplaced unmapped reads) depends on
# count()'s defaults — check the pysam docs before quoting this as "all reads".
sequence_reads.count() # how many reads are in the file

4656238

In [5]:
# Reference sequence names (the SN field of every @SQ header line), in header order.
bamfile_contigs = [sq_line['SN'] for sq_line in sequence_reads.header.to_dict()['SQ']]

In [6]:
# Number of reference contigs declared in the BAM header.
len(bamfile_contigs)

3366

In [7]:
# Take the first alignment of the file as an example (head(1) yields a one-read iterator).
example_read = next(sequence_reads.head(1))

In [8]:
# Tab-separated summary of the example read: name, flag, reference id, position, MAPQ,
# CIGAR, mate fields, sequence, base qualities and tags (see the output below).
print(example_read)

SRR590764.636676	99	#0	10354	9	43M1D33M	#0	10374	89	CCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCC	array('B', [31, 30, 28, 31, 29, 32, 34, 32, 32, 31, 30, 32, 34, 31, 31, 34, 18, 28, 32, 31, 34, 33, 28, 31, 30, 32, 32, 34, 19, 27, 27, 20, 32, 32, 31, 27, 32, 20, 27, 31, 32, 30, 30, 33, 9, 18, 33, 11, 32, 27, 23, 22, 31, 26, 14, 22, 22, 24, 26, 28, 31, 31, 22, 28, 16, 29, 32, 33, 28, 32, 2, 2, 2, 2, 2, 2])	[('MC', '69M7S'), ('BD', 'MMLNONNOLNNMMNLNNLMMLNMLLMKMMLLMLNNMMNLLNNMMNLNNMMNLNNMMNLNNMMNLNONNOMNNMMMM'), ('MD', '43^C33'), ('RG', 'SRR'), ('BI', 'PPNQPNPOMQPNOOMPONOOMPOOOOMPONOPMQPOOPNNQPOPPNQPOPPNQPOPPNQPOPPNQPPQPNQQPPPP'), ('NM', 1), ('MQ', 9), ('AS', 69), ('XS', 76)]


## Algorithm

In [9]:
import numpy as np
# Seed NumPy's global RNG: the CBS permutation tests below shuffle with np.random,
# so segmentation results depend on this state.
# NOTE(review): CBS runs inside a multiprocessing.Pool; whether the workers inherit this
# seeded state depends on the process start method (fork vs spawn) — confirm.
np.random.seed(12)

In [10]:
def write_VCF(filename, reference, variants):
    """Write detected structural variants as VCF records.

    ``write_header`` is called first to create the file and its header; the records are
    then appended to ``f"{filename}.vcf"``, contig by contig in the order of ``variants``,
    with a running ID (DEL0000000001, DUP0000000002, ...).

    Parameters:
    filename := str, output path without extension
    reference := str, reference name written to the ##reference header line
    variants := dict mapping contig name -> per-contig dict with "contig_len" and the
                parallel sequences "read_qualities", "bin_bounds", "refs" and "variants"
                ("bin_bounds" are 0-based half-open (start, end) pairs, "variants" holds
                "deletion" or "duplication")

    When a contig has no REF bases (its "refs" is empty because the reference sequence was
    unavailable) "N" is written as REF instead of silently dropping all of its variants.
    """
    write_header(filename, reference, list(variants.keys()), [v["contig_len"] for v in variants.values()])
    record_id = 1
    with open(f"{filename}.vcf", "a") as f:
        for key, contig in variants.items():
            sv_types = contig["variants"]
            refs = contig["refs"]
            if len(refs) == 0:
                # zip() would otherwise stop immediately and lose every record of this contig
                refs = ["N"] * len(sv_types)
            for q, b, r, v in zip(contig["read_qualities"], contig["bin_bounds"], refs, sv_types):
                # POS is 1-based, so the 0-based start is shifted by one; the exclusive 0-based
                # end equals the 1-based inclusive END.
                if v == "deletion":
                    f.write(f"{key}\t{b[0]+1}\tDEL{str(record_id).zfill(10)}\t{r.upper()}\t<DEL>\t100\tPASS\tSVTYPE=DEL;SVMETHOD=depthRead;END={b[1]};MAPQ={q}\n")
                else:
                    f.write(f"{key}\t{b[0]+1}\tDUP{str(record_id).zfill(10)}\t{r.upper()}\t<DUP>\t100\tPASS\tSVTYPE=DUP;SVMETHOD=depthRead;END={b[1]};MAPQ={q};SVLEN={b[1]-b[0]}\n")
                record_id += 1
            
from datetime import datetime

def write_header(filename, reference, contigs, contig_lengths):
    """Create ``f"{filename}.vcf"`` (overwriting it) and write the VCFv4.2 header.

    The path uses the same lowercase ``.vcf`` extension that ``write_VCF`` appends records
    to; with a different case the header and records would land in two different files on
    case-sensitive filesystems.

    Parameters:
    filename := str, output path without extension
    reference := str, value of the ##reference line
    contigs := list of contig names, one ##contig line each
    contig_lengths := list of lengths, parallel to ``contigs``
    """
    with open(f"{filename}.vcf", "w") as f:
        f.write('##fileformat=VCFv4.2\n')
        f.write(f'##fileDate={datetime.today().strftime("%Y%m%d")}\n')
        f.write(f'##reference={reference}\n')
        f.write('##ALT=<ID=DUP,Description="Duplication">\n')
        f.write('##ALT=<ID=DEL,Description="Deletion">\n')
        f.write('##FILTER=<ID=PASS,Description="All filters passed">\n')
        f.write('##INFO=<ID=SVTYPE,Number=1,Type=String,Description="Type of structural variant">\n')
        f.write('##INFO=<ID=SVMETHOD,Number=1,Type=String,Description="Type of approach used to detect SV">\n')
        f.write('##INFO=<ID=END,Number=1,Type=Integer,Description="1-based end position of the structural variant">\n')
        # Medians of mapping qualities can be fractional and are written as floats (e.g. 60.0).
        f.write('##INFO=<ID=MAPQ,Number=1,Type=Float,Description="Median mapping quality of reads">\n')
        f.write('##INFO=<ID=SVLEN,Number=1,Type=Integer,Description="Duplication length for SVTYPE=DUP.">\n')
        for contig, ln in zip(contigs, contig_lengths):
            f.write(f'##contig=<ID={contig},length={ln}>\n')
        f.write('#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\n')

In [11]:
class CBS:
    """Circular Binary Segmentation (CBS) of a 1-D signal, e.g. the per-bin read counts of a contig.

    Based on instructions and implementation from this blog: https://jeremy9959.net/Blog/cbs-fixed/

    Candidate segments and breakpoints are accepted through permutation tests that shuffle the
    data with NumPy's global RNG (``np.random.shuffle``), so results depend on the current
    ``np.random`` state.

    Parameters:
    shuffles := int, number of permutations per significance test
    p := float, significance level used when accepting a candidate segment
    validation_p := float, significance level used when re-validating breakpoints
    min_size := int, minimum length (in samples) of a piece that is split off on its own
    """

    def __init__(self, shuffles=100, p=0.05, validation_p=0.01, min_size=5):
        self.shuffles = shuffles
        self.p = p
        self.validation_p = validation_p
        self.min_size = min_size

    @staticmethod
    def t_statistic(x0, x1):
        """Squared Welch-style t statistic for the difference in means of ``x0`` and ``x1``.

        Uses ``np.std`` with its default ddof=0. ``1e-10`` is added to the denominator so two
        constant samples do not divide by zero. Callers pass non-empty samples; with an empty
        sample its std term is NaN and so is the result.
        """
        denom = np.sqrt((np.std(x0)**2)/len(x0) + (np.std(x1)**2)/len(x1)) + 1e-10
        mu0 = np.mean(x0) if len(x0) > 0 else 0
        mu1 = np.mean(x1) if len(x1) > 0 else 0
        return ((mu0 - mu1)/denom)**2

    def _best_interval(self, x):
        """Find the "inside" interval whose mean differs most from the rest of ``x``.

        Returns ``(t, start, end)``: ``x[start:end]`` is the candidate interval and ``t`` its
        CBS statistic ``D**2 * n / (k * (n - k))``, where ``D`` is the interval's sum of
        mean-centred values and ``k`` its length. Inputs with no usable split (fewer than two
        samples, or constant data) yield ``(0.0, 0, len(x))``, i.e. the whole input.
        """
        x = np.asarray(x, dtype=float)
        n = len(x)
        if n < 2:
            return 0.0, 0, n
        # Prefix sums with a leading 0 so that sum(centred[a:b]) == y[b] - y[a] for all
        # 0 <= a <= b <= n; the extreme prefix sums therefore delimit x[i0:i1] exactly.
        y = np.concatenate(([0.0], np.cumsum(x - x.mean())))
        e0, e1 = int(np.argmin(y)), int(np.argmax(y))
        i0, i1 = min(e0, e1), max(e0, e1)
        k = i1 - i0
        if k == 0 or k == n:
            return 0.0, 0, n
        t = (y[i1] - y[i0])**2 * n / (k * (n - k))
        return t, i0, i1

    def _main_alg(self, x):
        """Find the best candidate interval of ``x`` and test its significance.

        Returns ``(significant, t, start, end)``. Interval edges closer than ``min_size`` to
        either end of ``x`` are snapped to that end. An interval covering all of ``x`` is
        reported as not significant without running the permutation test.
        """
        initial_t, initial_start, initial_end = self._best_interval(x)
        if initial_end - initial_start == len(x):
            return False, initial_t, initial_start, initial_end
        if initial_start < self.min_size:
            initial_start = 0
        if len(x) - initial_end < self.min_size:
            initial_end = len(x)
        interval_significance = self._check_if_interval_significant(x, initial_t)
        return interval_significance, initial_t, initial_start, initial_end

    def _check_if_interval_significant(self, x, initial_threshold, seg=None):
        """Permutation test of ``initial_threshold`` against shuffled copies of ``x``.

        With ``seg=None`` each shuffle is scored with ``_best_interval`` at level ``self.p``;
        otherwise with ``t_statistic`` split at index ``seg`` at level ``self.validation_p``.
        Returns False as soon as more than ``shuffles * level`` shuffles score at least
        ``initial_threshold``, True otherwise. ``x`` itself is not modified.
        """
        shuffled = np.array(x, dtype=float)  # private copy: np.random.shuffle works in place
        level = self.p if seg is None else self.validation_p
        thresh_tolerance = self.shuffles * level
        exceed = 0
        for _ in range(self.shuffles):
            np.random.shuffle(shuffled)
            if seg is None:
                t = self._best_interval(shuffled)[0]
            else:
                t = CBS.t_statistic(shuffled[:seg], shuffled[seg:])
            if t >= initial_threshold:
                exceed += 1
            if exceed > thresh_tolerance:
                return False
        return True

    def _recurrent_segment(self, x, start, end, segments):
        """Split ``x[start:end]`` into segments and append them to ``segments`` as ``(start, end)``.

        Pieces are processed depth-first, left to right (the same order as a recursive split),
        using an explicit stack so long contigs cannot exceed Python's recursion limit.
        Returns ``segments``.
        """
        stack = [(start, end)]
        while stack:
            lo, hi = stack.pop()
            significant, _, s, e = self._main_alg(x[lo:hi])
            if not significant or e - s < self.min_size or e - s == hi - lo:
                segments.append((lo, hi))
                continue
            pieces = []
            if s > 0:
                pieces.append((lo, lo + s))
            if e - s > 0:
                pieces.append((lo + s, lo + e))
            if lo + e < hi:
                pieces.append((lo + e, hi))
            # push right-to-left so the leftmost piece is popped (and fully split) first
            stack.extend(reversed(pieces))
        return segments

    def validate(self, x, L):
        """Re-test every internal breakpoint of the segmentation ``L`` and drop weak ones.

        ``L`` is an ordered list of ``(start, end)`` segments covering ``x``. Each breakpoint
        is tested inside the window spanning from the last *accepted* breakpoint to the next
        candidate. Returns the accepted boundaries, always starting with 0 and ending with
        ``len(x)``.
        """
        x = np.asarray(x, dtype=float)
        bounds = [segment[0] for segment in L] + [len(x)]
        accepted = [0]
        left = 0  # index into bounds of the last accepted breakpoint
        for idx in range(1, len(bounds) - 1):
            window = x[bounds[left]:bounds[idx + 1]]
            split = bounds[idx] - bounds[left]
            t = CBS.t_statistic(window[:split], window[split:])
            if self._check_if_interval_significant(window, t, split):
                accepted.append(bounds[idx])
                # the next window must start at this breakpoint, not at a rejected one
                left = idx
        accepted.append(bounds[-1])
        return accepted

    def run_algorithm(self, counts, contig):
        """Segment ``counts`` for ``contig`` and return ``{contig: boundaries}``.

        ``boundaries`` is the validated list of segment edges (see ``validate``). Inputs with
        at most one value return ``{contig: []}``.
        """
        print(f"Segmentation started for contig: {contig}\n")
        seg = {contig: []}
        if len(counts) <= 1:
            return seg
        # one float array up front: avoids repeated list->array conversions in every shuffle
        counts = np.asarray(counts, dtype=float)
        self._recurrent_segment(counts, 0, len(counts), seg[contig])
        print(f"Starting breakpoint validation for contig {contig}...\n")
        seg[contig] = self.validate(counts, seg[contig])
        return seg
    

    

    

In [12]:
from tqdm.notebook import tqdm
import statistics
import multiprocessing as mp
from Bio import SeqIO

class ReadDepthAlgorithm:
    """Read-depth structural-variant caller over the contigs of an alignment file.

    ``run_algorithm`` counts reads in fixed-size bins, segments each contig's bin counts with
    CBS, merges bins into segments, derives thresholds when none were given, and keeps only
    segments whose value is above ``dup_threshold`` ("duplication") or at or below
    ``del_threshold`` ("deletion"). Per-contig results are stored in ``self.bins``.
    """

    def __init__(self, bin_size, file, reference, output_filename=None, average=True, dup_threshold=None, del_threshold=None,
                 check_validity = False, cbs=None, num_cpus=None, **kwargs):
        """
        Parameters:
        bin_size := int, required, the size of a bin in reference bases
        file := opened alignment file, required, reads for variant detection; must provide
                fetch() and get_reference_length() (the notebook passes a pysam.AlignmentFile)
        reference := string, required, .fasta file name with reference genome sequences
        output_filename := string or None, output path without extension; if given,
                run_algorithm writes the result with write_VCF
        average := boolean, if True a segment's value is its mean reads per bin, otherwise
                the sum of reads over its bins
        dup_threshold := float or None, the threshold for duplications
                (None -> 95th percentile of all segment values)
        del_threshold := float or None, the threshold for deletions
                (None -> 5th percentile of all segment values)
        check_validity := bool, whether to check the validity of the reads (if True, unmapped and duplicated reads will not be considered for the algorithm)
        cbs := CBS or None, CBS algorithm instance; if None one is built as CBS(**kwargs)
        num_cpus := int or None, the number of cpus to use during parallelization of CBS algorithm, if None uses all available
        """
        self.bin_size = bin_size
        self.dup_threshold = dup_threshold
        self.del_threshold = del_threshold
        self.bins = dict()
        self.check_validity = check_validity
        self.average = average
        self.file = file
        self.reference = reference
        # Loaded eagerly; released after binning and reloaded on demand by _divide_reads_into_bins.
        self.reference_dict = self._create_reference_dict()
        self.output_filename = output_filename
        self.num_cpus = mp.cpu_count() if num_cpus is None else num_cpus
        self.cbs = CBS(**kwargs) if cbs is None else cbs

    def _create_reference_dict(self):
        # Map every FASTA record id to its sequence; the file handle is closed afterwards.
        with open(self.reference) as handle:
            return {record.id: record.seq for record in SeqIO.parse(handle, 'fasta')}

    @staticmethod
    def get_read_start_len(read):
        return read.reference_start, read.reference_length

    @staticmethod
    def read_in_bin(bin_start, bin_end, read):
        """Return True if ``read`` is assigned to the half-open bin ``[bin_start, bin_end)``.

        The test is equivalent to the read's midpoint (start + length / 2) lying in the bin,
        so for reads no longer than a bin at least half of the read lies inside it. The right
        edge is exclusive, so a read centred exactly on a bin boundary is counted in one bin
        only (the right-hand one). Reads without a reference length are never assigned.
        """
        read_start, read_len = ReadDepthAlgorithm.get_read_start_len(read)
        if read_len is None:
            return False
        read_end = read_start + read_len
        return bin_start - read_start <= read_len/2 and read_end - bin_end < read_len/2

    def valid_read(self, read):
        # check_validity=False: every read counts (no mapping/duplicate filtering).
        # check_validity=True: only mapped, non-duplicate reads count; both mates of a pair
        # are counted (there is no is_read1 filter).
        if not self.check_validity:
            return True
        return not read.is_unmapped and not read.is_duplicate

    def _set_thresholds(self):
        # Fill in missing thresholds from the 95th / 5th percentiles of all segment values.
        all_bin_nums = [value for contig_bins in self.bins.values() for value in contig_bins['num_reads']]
        if len(all_bin_nums) == 0:
            # np.quantile raises on empty input; with no values there is nothing to filter.
            return
        if self.dup_threshold is None:
            self.dup_threshold = np.quantile(all_bin_nums, 0.95)
        if self.del_threshold is None:
            self.del_threshold = np.quantile(all_bin_nums, 0.05)

    def _handle_segmented_vals(self, result):
        # Store one CBS result ({contig: boundaries}) under self.bins[contig]["segments"].
        contig = next(iter(result))
        self.bins[contig]["segments"] = result[contig]

    def _run_segmentation(self, contig_list):
        # Run CBS on every contig's bin counts in a multiprocessing pool.
        print("SEGMENTATION PROCESS...")
        args = [(self.bins[contig]["num_reads"], contig) for contig in contig_list]
        with mp.Pool(self.num_cpus) as p:
            for result in p.starmap(self.cbs.run_algorithm, args):
                self._handle_segmented_vals(result)

    def _divide_reads_into_bins(self, contig_list):
        # Compute per-bin read counts for all contigs, then drop the reference sequences
        # (self.reference_dict = None) to free memory.
        print("BINNING PROCESS...")
        if self.reference_dict is None:
            self.reference_dict = self._create_reference_dict()
        for contig in tqdm(contig_list):
            self._divide_reads_into_bins_one_contig(contig)
        self.reference_dict = None

    def _divide_reads_into_bins_one_contig(self, contig):
        # Count valid reads per bin for one contig and record, per bin, the reads' mapping
        # qualities, the 0-based half-open bounds and the reference base at the bin start.
        # Only full bins are created: a trailing stretch shorter than bin_size is not binned.
        contig_len = self.file.get_reference_length(contig)
        num_bins = contig_len // self.bin_size
        contig_bins = {
            "contig_len": contig_len,
            "num_reads": [],
            "read_qualities": [],
            "bin_bounds": [],
            "refs": [],
        }
        self.bins[contig] = contig_bins
        # Contigs absent from the reference FASTA get no REF bases ("refs" stays empty).
        seq = self.reference_dict[contig] if contig in self.reference_dict else None

        for i in tqdm(range(1, num_bins+1), leave=False):
            bin_start = (i - 1) * self.bin_size
            bin_end = i * self.bin_size
            num_reads = 0
            read_qualities = []
            for read in self.file.fetch(contig, bin_start, bin_end):
                if ReadDepthAlgorithm.read_in_bin(bin_start, bin_end, read) and self.valid_read(read):
                    num_reads += 1
                    read_qualities.append(read.mapping_quality)
            contig_bins["num_reads"].append(num_reads)
            contig_bins["read_qualities"].append(read_qualities)
            contig_bins["bin_bounds"].append((bin_start, bin_end))
            if seq is not None:
                contig_bins["refs"].append(seq[bin_start])

    def get_bins(self):
        # equivalent to get_SVs()
        return self.bins

    def _filter_binned_variances(self, contig_list):
        # Apply the threshold filtering to all the contigs.
        print("FILTERING BINS...")
        for contig in contig_list:
            self._filter_binned_variances_one_contig(contig)

    def _filter_binned_variances_one_contig(self, contig):
        # Keep only the segments of one contig that are flagged as duplications or deletions,
        # and record their types under "variants". Every per-segment field is filtered, also
        # when nothing passes, so all fields stay parallel to "variants".
        contig_bins = self.bins[contig]
        variants = []
        idxes = []
        for i, count in enumerate(contig_bins["num_reads"]):
            if count > self.dup_threshold:
                idxes.append(i)
                variants.append("duplication")
            if count <= self.del_threshold:
                idxes.append(i)
                variants.append("deletion")
        contig_bins["variants"] = variants
        keep = np.asarray(idxes, dtype=int)
        for key in ("num_reads", "read_qualities", "bin_bounds"):
            contig_bins[key] = np.asarray(contig_bins[key])[keep]
        if len(contig_bins["refs"]) > 0:
            contig_bins["refs"] = np.asarray(contig_bins["refs"])[keep]
        del contig_bins["segments"]

    def get_SVs(self):
        # returns the found SVs
        return self.bins

    @staticmethod
    def _median_quality(qualities):
        # Median mapping quality as a float; 0.0 when there are no reads.
        if len(qualities) == 0:
            return 0.0
        return float(statistics.median(qualities))

    def _consolidate_variants(self, contig_bins):
        # Merge the bins of one contig into the segments found by CBS. Afterwards
        # num_reads / bin_bounds / read_qualities / refs hold one entry per segment:
        # the summed (or, with average=True, mean per-bin) read count, the segment's
        # 0-based half-open bounds, the median mapping quality and the REF base at the
        # segment start.
        bins = contig_bins["num_reads"]
        segments = contig_bins["segments"]
        qualities = contig_bins["read_qualities"]
        refs = contig_bins.get("refs", [])
        if len(bins) == 0 or len(segments) == 0:
            # No segmentation available: keep the bins, reduce qualities to per-bin medians.
            contig_bins["read_qualities"] = [self._median_quality(q) for q in qualities]
            return
        new_counts = []
        new_bounds = []
        new_qualities = []
        new_refs = []
        for seg_start, seg_end in zip(segments[:-1], segments[1:]):
            total = np.sum(bins[seg_start:seg_end])
            if self.average:
                new_counts.append(round(total/(seg_end - seg_start), 1))
            else:
                new_counts.append(total)
            new_bounds.append((seg_start*self.bin_size, seg_end*self.bin_size))
            new_qualities.append(self._median_quality(np.concatenate(qualities[seg_start:seg_end])))
            if len(refs) > 0:
                # the REF base must follow the segments, not the original bins
                new_refs.append(refs[seg_start])
        contig_bins["num_reads"] = new_counts
        contig_bins["bin_bounds"] = new_bounds
        contig_bins["read_qualities"] = new_qualities
        if len(refs) > 0:
            contig_bins["refs"] = new_refs

    # MAIN FUNCTION FOR RUNNING ALGORITHM
    def run_algorithm(self, contig_list):
        # Binning -> CBS segmentation -> consolidation -> thresholds -> filtering, then
        # optionally write the remaining variants to a VCF.
        self._divide_reads_into_bins(contig_list)
        self._run_segmentation(contig_list)
        print("CONSOLIDATING BINS...")
        for contig in contig_list:
            self._consolidate_variants(self.bins[contig])
        self._set_thresholds()
        self._filter_binned_variances(contig_list)
        if self.output_filename:
            write_VCF(self.output_filename, self.reference, self.bins)
                    



In [13]:
# Reference FASTA: provides the REF base of each output record and the ##reference header value.
REFERENCE_FILE='GRCh38_full_analysis_set_plus_decoy_hla.fa'

In [14]:
# 1000-base bins over the open BAM. average=False makes each segment's value the summed read
# count of its bins rather than the per-bin mean. No output_filename is given, so
# run_algorithm does not write a VCF.
rda = ReadDepthAlgorithm(1000, file=sequence_reads, reference=REFERENCE_FILE, average=False)

In [15]:
# Binning -> CBS segmentation -> consolidation -> thresholding on every contig of the BAM header.
# NOTE(review): the saved output below shows a single contig (chr1, "0/1"), while
# bamfile_contigs has 3366 entries — the output seems to come from a different run; re-run
# to refresh it.
rda.run_algorithm(bamfile_contigs)

BINNING PROCESS...


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/248956 [00:00<?, ?it/s]

SEGMENTATION PROCESS...
Segmentation started for contig: chr1

Starting breakpoint validation for contig chr1...

CONSOLIDATING BINS...
FILTERING BINS...


In [17]:
def count_variants(bins): # for counting the number of detected variants
    """Print how many deletions and duplications were detected over all contigs in ``bins``.

    ``bins`` maps contig -> dict whose "variants" entry lists "deletion"/"duplication" labels.
    """
    totals = dict.fromkeys(("deletion", "duplication"), 0)
    for contig_bins in bins.values():
        for sv_type in contig_bins["variants"]:
            totals[sv_type] += 1
    print(f'Number of deletions={totals["deletion"]}, number of duplications={totals["duplication"]}')

def count_variants_smaller(bins, num_contigs=24): # for counting the number of detected variants on only the first num_contigs contigs
    """Print deletion/duplication counts over the first ``num_contigs`` contigs of ``bins``.

    The default of 24 is meant to cover chr1-chr22, chrX and chrY, assuming ``bins`` is
    ordered like the BAM header (as when built from bamfile_contigs) — confirm for other inputs.
    """
    count_dict = {
        "deletion": 0,
        "duplication": 0
    }
    for i, contig in enumerate(bins):
        if i >= num_contigs:
            break
        for var in bins[contig]["variants"]:
            count_dict[var] += 1
    print(f'Number of deletions={count_dict["deletion"]}, number of duplications={count_dict["duplication"]}')

def average_segment_length(bins): # for computing average segment length
    """Print the mean length (end - start) over every "bin_bounds" entry of every contig in ``bins``."""
    lengths = [bounds[1] - bounds[0] for contig_bins in bins.values() for bounds in contig_bins["bin_bounds"]]
    print(f"Average segment length={np.mean(lengths)}")