# Experimental Counting Optimization

In [67]:
# Input data: the BAM file, the matching fragments file (tab-separated,
# one fragment per line: chr, start, stop, barcode, count) and the AnnData
# object whose obs index supplies the cell barcodes used below.
bamfile = "/mnt/workspace2/jdetlef/data/public_data/sorted_heart_left_ventricle_194.bam"
fragments_file = "/mnt/workspace2/jdetlef/data/public_data/fragments_heart_left_ventricle_194_sorted.bed"
h5ad_file = "/mnt/workspace2/jdetlef/data/public_data/heart_lv_SM-JF1NY.h5ad"

In [68]:
import peakqc.general as general
import peakqc.insertsizes as insertsizes

In [69]:
import pandas as pd
import numpy as np
import gzip
import datetime
from multiprocessing import Manager, Lock, Pool
from tqdm import tqdm
import time


from beartype import beartype
import numpy.typing as npt
from beartype.typing import Any, Optional, Literal

In [70]:
import scanpy as sc

In [71]:
# Load the AnnData object; in this notebook only its obs index (cell names) is used.
adata = sc.read_h5ad(h5ad_file)
adata

In [72]:
# Cell names of the AnnData object as a plain list.
adata_barcodes = adata.obs.index.tolist()

In [73]:
%%time
# split index for barcodes CBs: keep the second '+'-separated field of every obs name
barcodes = [entry.split('+')[1] for entry in adata_barcodes]

In [75]:
# Count insert sizes directly from the BAM file with peakqc (CB tag, 100 kb chunks).
table_from_bam = insertsizes.insertsize_from_bam(bamfile=bamfile,
                        barcodes=barcodes,
                        barcode_tag='CB',
                        chunk_size=100000,
                        regions=None)

In [74]:
# Count insert sizes from the fragments file with peakqc, using 8 worker processes.
table_from_fragments = insertsizes.insertsize_from_fragments(fragments=fragments_file,
                              barcodes=barcodes,
                              n_threads=8)

In [77]:
table_from_fragments

In [88]:
# Persist the fragments-based table to HDF5.
table_from_fragments.to_hdf('count_table_heart_lv.h5',
                            key='df', mode='w')

In [78]:
# Persist the same table as CSV.
# NOTE(review): to_csv writes the per-barcode 'dist' numpy arrays as their text
# representation, so the CSV copy cannot restore them as arrays.
table_from_fragments.to_csv('count_table_heart_lv.csv')

In [82]:
read_table = pd.read_csv('count_table_heart_lv.csv', index_col=0)

In [85]:
# After the CSV round trip this value is a string, not an array.
# NOTE(review): [0] on a barcode-indexed Series falls back to positional access
# (deprecated in recent pandas); .iloc[0] would state the intent explicitly.
read_table['dist'][0]

In [61]:
def store_list_to_file(str_list, file_path):
    """
    Write each element of ``str_list`` to ``file_path``, one element per line.

    The file is created or truncated. Every element, including the last one,
    is followed by a newline character.

    Args:
    str_list (list of str): Strings to write, in order.
    file_path (str): Destination path.
    """
    lines = (f"{entry}\n" for entry in str_list)
    with open(file_path, 'w') as handle:
        handle.writelines(lines)


In [62]:
# Write the barcode list to barcodes.txt, one barcode per line.
store_list_to_file(barcodes, 'barcodes.txt')

In [63]:
def read_list_from_file(file_path):
    """
    Load a text file as a list of strings, one entry per line.

    Leading and trailing whitespace (including the line break) is stripped
    from every line; blank lines become empty strings.

    Args:
    file_path (str): Path of the file to read.

    Returns:
    list of str: The stripped lines, in file order.
    """
    entries = []
    with open(file_path, 'r') as handle:
        for raw_line in handle:
            entries.append(raw_line.strip())
    return entries


In [65]:
# Number of barcodes taken from the AnnData object.
len(barcodes)

In [64]:
# Round-trip check of the file written by store_list_to_file.
read_list_from_file('barcodes.txt')

In [60]:
barcodes

In [11]:
table_from_fragments

In [11]:
# Scratch cells inspecting multiprocessing.Lock objects.
# NOTE: this also binds the notebook-global name `lock`, which the local worker
# functions defined below acquire when they run in this process.
lock = Lock()
type(lock)

In [12]:
another_lock = Lock()

In [16]:
type(lock) == type(Lock())

In [25]:
# NOTE(review): when this ran (In [25], right after the definition in In [24]),
# insertsize_from_fragments returned early with the first 1,000,000-row
# fragments chunk (a debug `return chunk`), which is why the result is used
# as a raw `chunk` in the cells below.
chunk = insertsize_from_fragments(fragments=fragments_file, barcodes=barcodes,
                              n_threads=8)

In [36]:
# Create a lock object
lock_instance = Lock()

# Hand it to peakqc's init_pool_processes (presumably sets that module's
# global `lock`, like the local version defined below)
insertsizes.init_pool_processes(lock_instance)

In [41]:
# Run peakqc's fragment worker in-process on the debug chunk; managed_dict is
# the plain dict created in cell In [40].
insertsizes._count_fragments_worker(chunk, managed_dict=managed_dict)

In [59]:
# First 10 bins of the insertsize distribution of one barcode.
managed_dict['output']['AGGGATAAACCACCGAAGGTCA']['dist'][:10]

In [53]:
# Spot check of that barcode's mean insertsize.
round(managed_dict['output']['AGGGATAAACCACCGAAGGTCA']['mean_insertsize']) == 140

In [40]:
managed_dict = {'output': {}}

In [26]:
# Save the debug chunk to CSV.
chunk.to_csv('example_chunk.csv', index=False)

In [27]:
read_chunk = pd.read_csv('example_chunk.csv')

In [28]:
# Column-wise check that the CSV round trip preserved every value.
(chunk == read_chunk).all()

In [31]:
import os

In [34]:
os.getcwd()

In [24]:
@beartype
def _is_gz_file(filepath: str) -> bool:
    """
    Tell whether a file is gzip-compressed by inspecting its magic bytes.

    Parameters
    ----------
    filepath : str
        Path to the file to inspect.

    Returns
    -------
    bool
        True when the file starts with the gzip signature 0x1f 0x8b.
    """
    gzip_magic = b'\x1f\x8b'
    with open(filepath, 'rb') as handle:
        header = handle.read(len(gzip_magic))
    return header == gzip_magic


@beartype
def init_pool_processes(the_lock: Any) -> None:
    '''
    Initialize each pool worker process with a module-level ``lock``.

    Used as the ``initializer`` of ``multiprocessing.Pool`` (with
    ``initargs=(lock,)``), so that worker functions can refer to the global
    name ``lock`` without receiving it as a task argument.

    Parameters
    ----------
    the_lock : Any
        Lock object to be shared by the processes (a ``multiprocessing.Lock``
        in the callers in this notebook).

    Returns
    -------
    None
    '''
    # Rebind the module-level name; _count_fragments_worker and
    # _count_fragments_from_bam_worker acquire this `lock` before merging results.
    global lock
    lock = the_lock


@beartype
def _check_in_list(element: Any, alist: list[Any] | set[Any]) -> bool:
    """
    Membership test used to filter barcodes.

    Parameters
    ----------
    element : Any
        Value to look up.
    alist : list[Any] | set[Any]
        Container searched for ``element``; a set gives constant-time lookups.

    Returns
    -------
    bool
        True if ``element`` is contained in ``alist``, otherwise False.
    """
    is_member = element in alist
    return is_member


@beartype
def _check_true(element: Any, alist: Optional[list[Any] | set[Any]] = None) -> bool:  # true regardless of input
    """
    Accept every element; used instead of ``_check_in_list`` when no barcode filter is given.

    Parameters
    ----------
    element : Any
        Element to check (ignored).
    alist : Optional[list[Any] | set[Any]], default None
        Container of allowed elements (ignored). Accepts the same container
        types as ``_check_in_list``, so the two can be swapped freely.

    Returns
    -------
    bool
        Always True.
    """

    return True


@beartype
def _custom_callback(error: Exception) -> None:
    """
    Report an exception raised inside a multiprocessing worker.

    Passed as ``error_callback`` to ``Pool.apply_async``. It prints the error
    and flushes stdout, so the error is visible while the pool is still running.

    Parameters
    ----------
    error : Exception
        Exception raised by the worker task.

    Returns
    -------
    None
    """
    message = str(error)
    print(message, flush=True)


@beartype
def insertsize_from_fragments(fragments: str,
                              barcodes: Optional[list[str]] = None,
                              n_threads: int = 8) -> pd.DataFrame:
    """
    Count the insertsizes of fragments in a fragments file and get basic statistics per barcode.

    The file is read in chunks of 1,000,000 rows. Each chunk is counted by a
    worker process (``_count_fragments_worker``), which merges its partial
    counts into a shared ``Manager`` dict while holding a lock.

    Parameters
    ----------
    fragments : str
        Path to the fragments file (tab-separated: chr, start, stop, barcode, count).
        Gzip compression is detected from the file content, not the extension.
    barcodes : list[str], optional
        List of barcodes to count. If None, all barcodes are counted.
    n_threads : int, default 8
        Number of worker processes to use.

    Returns
    -------
    pd.DataFrame
        One row per barcode with the mean insertsize (rounded to 2 decimals),
        the total insertsize count and the insertsize distribution ('dist').
        The frame is empty if no fragment was counted.
    """
    print('Count insertsizes from fragments...')

    # Detect gzip from the magic bytes; pandas' default 'infer' only looks at the extension.
    compression = 'gzip' if _is_gz_file(fragments) else None

    # Prepare function for checking against barcodes list (a set gives O(1) lookups)
    if barcodes is not None:
        barcodes = set(barcodes)
        check_in = _check_in_list
    else:
        check_in = _check_true

    # start timer
    start_time = datetime.datetime.now()

    # The Manager context shuts the manager process down on exit. The merged
    # counts are copied out of it before that happens.
    with Manager() as manager:
        lock = Lock()  # shared with the workers via init_pool_processes
        managed_dict = manager.dict()
        managed_dict['output'] = {}

        pool = Pool(processes=n_threads,
                    initializer=init_pool_processes,
                    initargs=(lock,),
                    maxtasksperchild=48)
        try:
            print('Starting counting fragments...')
            with pd.read_csv(fragments,
                             delimiter='\t',
                             header=None,
                             names=['chr', 'start', 'stop', 'barcode', 'count'],
                             compression=compression,
                             iterator=True,
                             chunksize=1000000) as iterator:
                for chunk in tqdm(iterator, desc="Processing Chunks"):
                    # errors inside a worker are reported through _custom_callback
                    pool.apply_async(_count_fragments_worker,
                                     args=(chunk, barcodes, check_in, managed_dict),
                                     error_callback=_custom_callback)
        except BaseException:
            # stop outstanding work before propagating the error
            pool.terminate()
            raise
        else:
            # no more tasks; let the queued ones finish
            pool.close()
        finally:
            pool.join()

        # Reading a key of the DictProxy returns a local copy of the merged dict.
        count_dict = managed_dict['output']

    # print elapsed time
    end_time = datetime.datetime.now()
    elapsed = end_time - start_time
    print("Done reading file - elapsed time: {0}".format(str(elapsed).split(".")[0]))

    # Convert dict to pandas dataframe
    print("Converting counts to dataframe...")
    table = pd.DataFrame.from_dict(count_dict, orient="index")
    # round mean_insertsize to 2 decimals (the column is absent when nothing was counted)
    if "mean_insertsize" in table.columns:
        table["mean_insertsize"] = table["mean_insertsize"].round(2)

    print("Done getting insertsizes from fragments!")

    return table


def _count_fragments_worker(chunk: pd.DataFrame,
                            barcodes: Optional[list[str] | set[str]] = None,
                            check_in: Any = _check_true,
                            managed_dict: Optional[dict] = None) -> None:
    """
    Worker function for counting fragments.

    Counts the fragments of one chunk into a local dict, then merges that dict
    into ``managed_dict['output']`` while holding the module-level ``lock``
    (installed by ``init_pool_processes``).

    Parameters
    ----------
    chunk : pd.DataFrame
        Chunk of the fragments file with columns chr, start, stop, barcode, count.
    barcodes : list[str] | set[str], optional
        Barcodes to count. If None, all barcodes are counted.
    check_in : Any, default _check_true
        Function ``check_in(barcode, barcodes) -> bool`` deciding whether a barcode is counted.
    managed_dict : dict, optional
        Dict (usually a Manager DictProxy) whose 'output' entry holds the merged
        counts. If None, a fresh local dict is used and the result is discarded.

    Returns
    -------
    None
    """
    # A fresh dict per call: a mutable default would be shared by every call.
    if managed_dict is None:
        managed_dict = {'output': {}}

    # Initialize count_dict
    count_dict = {}
    # Iterate over chunk; itertuples yields (Index, chr, start, stop, barcode, count)
    for row in chunk.itertuples():
        start = int(row[2])
        end = int(row[3])
        barcode = row[4]
        count = int(row[5])
        size = end - start - 9  # length of insertion (-9 due to to shifted cutting of Tn5)

        # Only add fragment if check is true
        if check_in(barcode, barcodes) is True:
            count_dict = _add_fragment(count_dict, barcode, size, count)  # add fragment to count_dict

    # Update managed_dict. `with` releases the lock even if the merge raises;
    # otherwise every other worker would block forever on acquire().
    with lock:
        latest = managed_dict['output']
        managed_dict['output'] = _update_count_dict(latest, count_dict)  # update managed dict


@beartype
def _add_fragment(count_dict: dict[str, dict[str, Any]],
                  barcode: str,
                  size: int,
                  count: int = 1,
                  max_size: int = 1000) -> dict:
    """
    Add ``count`` fragments of insertsize ``size`` to ``count_dict``.

    Parameters
    ----------
    count_dict : dict[str, dict[str, Any]]
        Maps barcode -> {'mean_insertsize', 'insertsize_count', 'dist'}.
        It is updated in place.
    barcode : str
        Barcode of the fragment(s).
    size : int
        Insertsize to add.
    count : int, default 1
        Number of fragments of this size to add. Non-positive counts are ignored.
    max_size : int, default 1000
        Largest insertsize that is recorded. The distribution ('dist') has
        ``max_size + 1`` bins indexed by insertsize.

    Returns
    -------
    dict
        The updated ``count_dict``. Every seen barcode gets an entry, even if
        none of its fragments fell into the recorded range; such entries have
        no 'dist' key.
    """

    # Initialize if barcode is seen for the first time
    if barcode not in count_dict:
        count_dict[barcode] = {"mean_insertsize": 0, "insertsize_count": 0}

    # Skip non-positive sizes and cap the maximum to limit outlier effects.
    # count > 0 also prevents a division by zero below.
    if count > 0 and 0 < size <= max_size:
        stats = count_dict[barcode]
        stats["insertsize_count"] += count
        total_count = stats["insertsize_count"]

        # Weighted running mean: `count` fragments of length `size` arrive at
        # once, so the correction is count * (size - mu) / total_count.
        mu = stats["mean_insertsize"]
        stats["mean_insertsize"] = mu + count * (size - mu) / total_count

        # Save to distribution
        if 'dist' not in stats:  # initialize distribution
            stats['dist'] = np.zeros(max_size + 1)
        stats['dist'][size] += count  # add count to distribution

    return count_dict


@beartype
def _update_count_dict(count_dict_1: dict, count_dict_2: dict) -> dict:
    """
    Merge two per-barcode count dicts into one.

    Both inputs map barcode -> {'mean_insertsize', 'insertsize_count', 'dist'},
    as built by ``_add_fragment``. Counts and distributions are summed. Means
    are combined as a count-weighted average.

    Parameters
    ----------
    count_dict_1 : dict
        First count dict (in the workers: the counts already in the managed dict).
    count_dict_2 : dict
        Second count dict (in the workers: the counts of the current chunk).

    Returns
    -------
    dict
        Merged count dict. If either input is empty, the other input object is
        returned as-is (not copied).
    """
    # Check if count_dict_1 is empty:
    if len(count_dict_1) == 0:
        return count_dict_2
    # Check if count_dict_2 is empty
    if len(count_dict_2) == 0:
        return count_dict_1

    # make Dataframes for computation (one row per barcode)
    # NOTE(review): if no barcode of a dict has a 'dist' entry (every fragment was
    # outside the recorded size range in _add_fragment), that frame has no 'dist'
    # column and the lookups below raise KeyError - confirm this cannot happen.
    df1 = pd.DataFrame(count_dict_1).T
    df2 = pd.DataFrame(count_dict_2).T

    # merge distributions; Series.combine aligns on the union of barcodes and
    # passes NaN for a barcode missing on one side, which _update_dist handles
    combined_dists = df1['dist'].combine(df2['dist'], func=_update_dist)
    # merge counts (outer join; a barcode present on only one side gets 0 on the other)
    merged_counts = pd.merge(df1["insertsize_count"], df2["insertsize_count"], left_index=True, right_index=True,
                             how='outer').fillna(0)
    # sum total counts/barcode
    updated_counts = merged_counts.sum(axis=1)

    # calculate scaling factors (weights of the count-weighted mean);
    # a barcode with 0 counts on both sides gets 0/0 = NaN here
    x_scaling_factor = merged_counts["insertsize_count_x"] / updated_counts
    y_scaling_factor = merged_counts["insertsize_count_y"] / updated_counts

    # merge mean insertsizes
    merged_mean_insertsizes = pd.merge(df1["mean_insertsize"], df2["mean_insertsize"], left_index=True,
                                       right_index=True, how='outer').fillna(0)

    # scale mean insertsizes
    merged_mean_insertsizes["mean_insertsize_x"] = merged_mean_insertsizes["mean_insertsize_x"] * x_scaling_factor
    merged_mean_insertsizes["mean_insertsize_y"] = merged_mean_insertsizes["mean_insertsize_y"] * y_scaling_factor

    # sum the scaled means (sum() skips NaN terms, so 0/0-weighted barcodes end up with mean 0)
    updated_means = merged_mean_insertsizes.sum(axis=1)

    # build the updated dictionary: .T.to_dict() turns it back into barcode -> {column: value}
    updated_dict = pd.DataFrame(
        {'mean_insertsize': updated_means, 'insertsize_count': updated_counts, 'dist': combined_dists}).T.to_dict()

    return updated_dict


@beartype
def _update_dist(dist_1: npt.ArrayLike, dist_2: npt.ArrayLike) -> npt.ArrayLike:
    """
    Merge two insertsize distributions.

    A missing distribution is represented by NaN. pandas passes NaN when a
    barcode is absent from one side of ``Series.combine``, or has no 'dist' entry.

    Parameters
    ----------
    dist_1 : npt.ArrayLike
        Insertsize distribution 1, or NaN if missing.
    dist_2 : npt.ArrayLike
        Insertsize distribution 2, or NaN if missing.

    Returns
    -------
    npt.ArrayLike
        The element-wise sum as an int array if both are present. If only one
        is present, that one cast to int. If both are missing, NaN is returned
        unchanged so the barcode stays marked as "no distribution".
    """
    missing_1 = bool(np.isnan(dist_1).any())
    missing_2 = bool(np.isnan(dist_2).any())

    if missing_1 and missing_2:
        # Nothing to merge; keep the NaN marker instead of calling .astype on a float.
        return dist_1
    if missing_1:
        return np.asarray(dist_2).astype(int)
    if missing_2:
        return np.asarray(dist_1).astype(int)
    # both present: add distributions (asarray makes list inputs add element-wise)
    return (np.asarray(dist_1) + np.asarray(dist_2)).astype(int)

In [None]:
# Run the notebook-local implementation defined in the cell above.
insertsize_from_fragments(fragments=fragments_file,
                              barcodes=barcodes,
                              n_threads = 8)

In [None]:
import sctoolbox.tools as tools

In [None]:
# Run sctoolbox's implementation on the same input for comparison.
tools._insertsize_from_fragments(fragments=fragments_file,
                              barcodes=barcodes)

In [None]:
import peakqc.general as utils
import os
import re

In [None]:
@beartype
def open_bam(file: str,
             mode: str,
             verbosity: Literal[0, 1, 2, 3] = 3, **kwargs: Any) -> "pysam.AlignmentFile":
    """
    Open a bam file with pysam.AlignmentFile at a specific verbosity level.

    The pysam verbosity is changed only while the file is being opened. It is
    restored afterwards, even if opening fails.

    Parameters
    ----------
    file : str
        Path to bam file.
    mode : str
        Mode to open the file in. See pysam.AlignmentFile
    verbosity : Literal[0, 1, 2, 3], default 3
        Set verbosity level. Verbosity level 0 for no messages.
    **kwargs : Any
        Forwarded to pysam.AlignmentFile

    Returns
    -------
    pysam.AlignmentFile
        Object to work on SAM/BAM files.

    Raises
    ------
    Exception
        Whatever pysam.AlignmentFile raises is propagated after the previous
        verbosity has been restored.
    """

    # check then load modules
    utils.check_module("pysam")
    import pysam

    # save verbosity, then set temporary one
    former_verbosity = pysam.get_verbosity()
    pysam.set_verbosity(verbosity)
    try:
        # open file
        return pysam.AlignmentFile(file, mode, **kwargs)
    finally:
        # return to former verbosity, also when opening raised
        pysam.set_verbosity(former_verbosity)

In [None]:
# Single-process insertsize counting from the BAM file, in chunk_size bp windows.
start_run = time.time()

regions = None
bam = bamfile
chunk_size = 100000
n_threads = 10  # unused by this single-process loop
barcode_tag = 'CB'

utils.check_module("pysam")
import pysam

# Prepare function for checking against barcodes list (a set gives O(1) lookups)
if barcodes is not None:
    barcodes = set(barcodes)
    check_in = _check_in_list
else:
    check_in = _check_true

# Open bamfile (index it first if no .bai exists)
print("Opening bam file...")
if not os.path.exists(bam + ".bai"):
    print("Bamfile has no index - trying to index with pysam...")
    pysam.index(bam)

bam_obj = open_bam(bam, "rb", require_index=True)
chromosome_lengths = dict(zip(bam_obj.references, bam_obj.lengths))

# Create chunked genome regions:
print(f"Creating chunks of size {chunk_size}bp...")

if regions is None:
    regions = [f"{chrom}:0-{length}" for chrom, length in chromosome_lengths.items()]
elif isinstance(regions, str):
    regions = [regions]

# Split every region into chunk_size bp windows (the last one clipped to the region end)
regions_split = []
for region in regions:
    chromosome, start, end = re.split("[:-]", region)
    start = int(start)
    end = int(end)
    for chunk_start in range(start, end, chunk_size):
        chunk_end = min(chunk_start + chunk_size, end)
        regions_split.append(f"{chromosome}:{chunk_start}-{chunk_end}")

# start timer
start_time = datetime.datetime.now()

# Count insertsize per chunk
print(f"Counting insertsizes across {len(regions_split)} chunks...")
count_dict = {}
read_count = 0
try:
    for region in tqdm(regions_split):
        chrom, start, end = re.split("[:-]", region)
        start = int(start)
        end = int(end)
        for read in bam_obj.fetch(chrom, start, end):
            # fetch() returns every read *overlapping* the window, so a read crossing
            # a chunk border would be counted twice; count it only in the window where it starts.
            if read.reference_start < start:
                continue
            read_count += 1
            # reads without the barcode tag are assigned "NA"
            barcode = read.get_tag(barcode_tag) if read.has_tag(barcode_tag) else "NA"

            # Add read to dict
            if check_in(barcode, barcodes) is True:
                size = abs(read.template_length) - 9  # length of insertion
                count_dict = _add_fragment(count_dict, barcode, size)
finally:
    bam_obj.close()

# print elapsed time
end_time = datetime.datetime.now()
elapsed = end_time - start_time
print("Done reading file - elapsed time: {0}".format(str(elapsed).split(".")[0]))

# Convert dict to pandas dataframe
print("Converting counts to dataframe...")
table = pd.DataFrame.from_dict(count_dict, orient="index")
# round mean_insertsize to 2 decimals (the column is absent when nothing was counted)
if "mean_insertsize" in table.columns:
    table["mean_insertsize"] = table["mean_insertsize"].round(2)

print("Done getting insertsizes from bam!")

finish_run = time.time()

print(f'Run finished in: {finish_run - start_run}')

In [None]:
@beartype
def _insertsize_from_bam(bam: str,
                         barcode_tag: str = "CB",
                         barcodes: Optional[list[str]] = None,
                         regions: Optional[str | list[str]] = 'chr1:1-2000000',
                         chunk_size: int = 100000) -> pd.DataFrame:
    """
    Get insertsize distributions per barcode from bam file.

    Each read is counted once, in the chunk where its alignment starts.

    Parameters
    ----------
    bam : str
        Path to bam file
    barcode_tag : str, default "CB"
        The read tag representing the barcode. Reads without the tag get the barcode "NA".
    barcodes : Optional[list[str]], default None
        List of barcodes to include in the analysis. If None, all barcodes are included.
    regions : Optional[str | list[str]], default 'chr1:1-2000000'
        Regions ("chrom:start-end") to include in the analysis. If None, whole chromosomes are used.
    chunk_size : int, default 100000
        Size of bp chunks to read from bam file.

    Returns
    -------
    pd.DataFrame
        DataFrame with mean insertsize, insertsize count and distribution per barcode.

    Raises
    ------
    ValueError:
        1. No reads found in bam-file.
        2. If no reads in bam-file overlap with barcodes.
    """

    # Load modules (check_module reports a missing pysam up front)
    utils.check_module("pysam")
    import pysam

    if utils._is_notebook() is True:
        from tqdm import tqdm_notebook as tqdm
    else:
        from tqdm import tqdm

    # Prepare function for checking against barcodes list (a set gives O(1) lookups)
    if barcodes is not None:
        barcodes = set(barcodes)
        check_in = _check_in_list
    else:
        check_in = _check_true

    # Open bamfile (index it first if no .bai exists)
    print("Opening bam file...")
    if not os.path.exists(bam + ".bai"):
        print("Bamfile has no index - trying to index with pysam...")
        pysam.index(bam)

    bam_obj = open_bam(bam, "rb", require_index=True)
    try:
        chromosome_lengths = dict(zip(bam_obj.references, bam_obj.lengths))

        # Create chunked genome regions:
        print(f"Creating chunks of size {chunk_size}bp...")

        if regions is None:
            regions = [f"{chrom}:0-{length}" for chrom, length in chromosome_lengths.items()]
        elif isinstance(regions, str):
            regions = [regions]

        # Split every region into chunk_size bp windows (the last one clipped to the region end)
        regions_split = []
        for region in regions:
            chromosome, start, end = re.split("[:-]", region)
            start = int(start)
            end = int(end)
            for chunk_start in range(start, end, chunk_size):
                chunk_end = min(chunk_start + chunk_size, end)
                regions_split.append(f"{chromosome}:{chunk_start}-{chunk_end}")

        # Count insertsize per chunk
        print(f"Counting insertsizes across {len(regions_split)} chunks...")
        count_dict = {}
        read_count = 0
        pbar = tqdm(total=len(regions_split), desc="Progress: ", unit="chunks")
        for region in regions_split:
            chrom, start, end = re.split("[:-]", region)
            start = int(start)
            end = int(end)
            for read in bam_obj.fetch(chrom, start, end):
                # fetch() also returns reads that only overlap the window; skipping
                # those that start earlier avoids counting border reads twice.
                if read.reference_start < start:
                    continue
                read_count += 1
                barcode = read.get_tag(barcode_tag) if read.has_tag(barcode_tag) else "NA"

                # Add read to dict
                if check_in(barcode, barcodes) is True:
                    size = abs(read.template_length) - 9  # length of insertion
                    count_dict = _add_fragment(count_dict, barcode, size)

            # Update progress
            pbar.update(1)
        pbar.close()  # close progress bar
    finally:
        bam_obj.close()

    if read_count == 0:
        raise ValueError("No reads found in bam-file. Please check the given regions.")
    if len(count_dict) == 0:
        raise ValueError("None of the reads in the bam-file overlap with the given barcodes.")

    # Convert dict to pandas dataframe
    table = pd.DataFrame.from_dict(count_dict, orient="index")
    # round mean_insertsize to 2 decimals
    if "mean_insertsize" in table.columns:
        table["mean_insertsize"] = table["mean_insertsize"].round(2)

    return table

In [None]:
import peakqc.general as utils
import os
import re

In [None]:
@beartype
def open_bam(file: str,
             mode: str,
             verbosity: Literal[0, 1, 2, 3] = 3, **kwargs: Any) -> "pysam.AlignmentFile":
    """
    Open bam file with pysam.AlignmentFile. On a specific verbosity level.

    Parameters
    ----------
    file : str
        Path to bam file.
    mode : str
        Mode to open the file in. See pysam.AlignmentFile
    verbosity : Literal[0, 1, 2, 3], default 3
        Set verbosity level. Verbosity level 0 for no messages.
    **kwargs : Any
        Forwarded to pysam.AlignmentFile

    Returns
    -------
    pysam.AlignmentFile
        Object to work on SAM/BAM files.
    """

    # check then load modules
    utils.check_module("pysam")
    import pysam

    # save verbosity, then set temporary one
    former_verbosity = pysam.get_verbosity()
    pysam.set_verbosity(verbosity)

    # open file
    handle = pysam.AlignmentFile(file, mode, **kwargs)

    # return to former verbosity
    pysam.set_verbosity(former_verbosity)

    return handle

In [None]:
# Multiprocessing variant: the main process reads each chunk from the BAM file
# and reduces its reads to [barcode, size] pairs (prep_reads); worker processes
# count them and merge into a Manager dict.
start_run = time.time()

regions = None
bam = bamfile
chunk_size = 10000000
n_threads = 10
cb_tag = 'CB'

utils.check_module("pysam")
import pysam

# Prepare function for checking against barcodes list (a set gives O(1) lookups)
if barcodes is not None:
    barcodes = set(barcodes)
    check_in = _check_in_list
else:
    check_in = _check_true

# Open bamfile (index it first if no .bai exists)
print("Opening bam file...")
if not os.path.exists(bam + ".bai"):
    print("Bamfile has no index - trying to index with pysam...")
    pysam.index(bam)

bam_obj = open_bam(bam, "rb", require_index=True)
chromosome_lengths = dict(zip(bam_obj.references, bam_obj.lengths))

# Create chunked genome regions:
print(f"Creating chunks of size {chunk_size}bp...")

if regions is None:
    regions = [f"{chrom}:0-{length}" for chrom, length in chromosome_lengths.items()]
elif isinstance(regions, str):
    regions = [regions]

# Split every region into chunk_size bp windows (the last one clipped to the region end)
regions_split = []
for region in regions:
    chromosome, start, end = re.split("[:-]", region)
    start = int(start)
    end = int(end)
    for chunk_start in range(start, end, chunk_size):
        chunk_end = min(chunk_start + chunk_size, end)
        regions_split.append(f"{chromosome}:{chunk_start}-{chunk_end}")

# start timer
start_time = datetime.datetime.now()

# Initialize multiprocessing
m = Manager()  # initialize manager
lock = Lock()  # shared with the workers via init_pool_processes
managed_dict = m.dict()  # initialize managed dict
managed_dict['output'] = {}
# initialize pool
pool = Pool(processes=n_threads,
            initializer=init_pool_processes,
            initargs=(lock,),
            maxtasksperchild=48)
jobs = []
print('Starting counting fragments...')

try:
    for region in tqdm(regions_split):
        chrom, start, end = re.split("[:-]", region)
        start = int(start)
        end = int(end)
        # fetch() also returns reads that only overlap this window; keep only the
        # reads starting inside it so border reads are not counted in two chunks.
        chunk = [prep_reads(read) for read in bam_obj.fetch(chrom, start, end)
                 if read.reference_start >= start]

        job = pool.apply_async(_count_fragments_from_bam_worker,
                               args=(chunk, barcodes, check_in, managed_dict),
                               error_callback=_custom_callback)
        jobs.append(job)

    # close pool and wait for all jobs to finish
    pool.close()
    pool.join()
    # reading the key returns a local copy of the merged counts
    count_dict = managed_dict['output']
finally:
    pool.terminate()  # harmless after join; stops the workers if the loop failed
    bam_obj.close()
    m.shutdown()

# print elapsed time
end_time = datetime.datetime.now()
elapsed = end_time - start_time
print("Done reading file - elapsed time: {0}".format(str(elapsed).split(".")[0]))

# Convert dict to pandas dataframe
print("Converting counts to dataframe...")
table = pd.DataFrame.from_dict(count_dict, orient="index")
# round mean_insertsize to 2 decimals (the column is absent when nothing was counted)
if "mean_insertsize" in table.columns:
    table["mean_insertsize"] = table["mean_insertsize"].round(2)

print("Done getting insertsizes from bam!")

finish_run = time.time()

print(f'Run finished in: {finish_run - start_run}')

In [None]:
import time
# Benchmark start time (set again right before the sctoolbox comparison run).
start = time.time()

In [None]:
def _count_fragments_from_bam_worker(chunk: list,
                                     barcodes: Optional[list[str] | set[str]] = None,
                                     check_in: Any = _check_true,
                                     managed_dict: Optional[dict] = None) -> None:
    """
    Worker function for counting insertsizes of reads pre-extracted from a bam file.

    Counts the chunk into a local dict, then merges it into
    ``managed_dict['output']`` while holding the module-level ``lock``
    (installed by ``init_pool_processes``).

    Parameters
    ----------
    chunk : list
        List of ``[barcode, size]`` pairs, as produced by ``prep_reads``.
    barcodes : list[str] | set[str], optional
        Barcodes to count. If None, all barcodes are counted.
    check_in : Any, default _check_true
        Function ``check_in(barcode, barcodes) -> bool`` deciding whether a barcode is counted.
    managed_dict : dict, optional
        Dict (usually a Manager DictProxy) whose 'output' entry holds the merged
        counts. If None, a fresh local dict is used and the result is discarded.

    Returns
    -------
    None
    """
    # A fresh dict per call: a mutable default would be shared by every call.
    if managed_dict is None:
        managed_dict = {'output': {}}

    # Initialize count_dict
    count_dict = {}
    for barcode, size in chunk:
        if check_in(barcode, barcodes) is True:
            count_dict = _add_fragment(count_dict, barcode, size)  # add fragment to count_dict

    # Update managed_dict. `with` guarantees the lock is released, so other
    # workers cannot block forever on it.
    with lock:
        latest = managed_dict['output']
        try:
            managed_dict['output'] = _update_count_dict(latest, count_dict)  # update managed dict
        except Exception as e:
            # Best effort: report and drop this chunk's counts; the merged result keeps going.
            print(f'Exception: {e}')

In [None]:
def process_reads(pair, count_dict=None):
    """
    Add one ``[barcode, size]`` pair to a per-barcode count dict.

    Whether the barcode is counted is decided by the notebook-level
    ``check_in`` function and ``barcodes`` collection.

    Parameters
    ----------
    pair : sequence
        ``[barcode, size]``, as produced by ``prep_reads``.
    count_dict : dict, optional
        Count dict to update in place. A new dict is created when None.

    Returns
    -------
    dict
        The updated count dict. The original version bound ``count_dict`` as a
        local before assignment, so it raised UnboundLocalError and never
        returned the counts.
    """
    if count_dict is None:
        count_dict = {}

    barcode = pair[0]
    size = pair[1]

    if check_in(barcode, barcodes) is True:
        count_dict = _add_fragment(count_dict, barcode, size)  # add fragment to count_dict

    return count_dict
    

In [None]:
def prep_reads(read, barcode_tag=None):
    """
    Reduce a pysam read to the ``[barcode, size]`` pair sent to the workers.

    Parameters
    ----------
    read : pysam.AlignedSegment
        Read to convert. Only ``has_tag``, ``get_tag`` and ``template_length`` are used.
    barcode_tag : str, optional
        Tag holding the cell barcode. Defaults to the notebook-level ``cb_tag``.

    Returns
    -------
    list
        ``[barcode, size]``. The barcode is "NA" when the read lacks the tag,
        matching the single-process counting loop, instead of raising and
        aborting the chunk. The size is ``template_length - 9``.
    """
    tag = cb_tag if barcode_tag is None else barcode_tag
    barcode = read.get_tag(tag) if read.has_tag(tag) else "NA"
    # Signed template_length: the mate with a negative TLEN gives a negative
    # size, which _add_fragment ignores (it only records sizes > 0).
    size = read.template_length - 9

    return [barcode, size]

In [None]:
import sctoolbox.tools as sctools

In [None]:
# Time sctoolbox's BAM-based counting on the same data ("original implementation").
# `barcodes` may already be a set here (the BAM cells above convert it with
# set()), hence the list() call.
start = time.time()
count_table = sctools._insertsize_from_bam(bam=bamfile,
                         barcode_tag="CB",
                         barcodes=list(barcodes),
                         regions=None,
                         chunk_size= 100000)

stop = time.time()

print(f'original implementation: {stop-start}')

In [None]:
# Columns 0-49 (presumably insertsize bins) of one barcode in sctoolbox's result.
count_table.loc['AAATCCGCATAAATGCTACGGG'][np.arange(0,50)]

In [None]:
# The same barcode in `table`, the result of the most recently run local counting cell.
table.loc['AAATCCGCATAAATGCTACGGG']

In [None]:
table

In [None]:
# individual imports
import episcanpy as epi
import pandas as pd
import numpy as np
import gzip
import datetime
from multiprocessing import Manager, Lock, Pool
from tqdm import tqdm

from beartype import beartype
from beartype.typing import Any, Optional

@beartype
def _is_gz_file(filepath: str) -> bool:
    """
    Report whether ``filepath`` holds gzip-compressed data.

    Only the first two bytes are read and compared against the gzip
    signature 0x1f 0x8b; the file extension is not consulted.

    Parameters
    ----------
    filepath : str
        Path to the file to inspect.

    Returns
    -------
    bool
        True if the file starts with the gzip signature.
    """
    signature = b'\x1f\x8b'
    with open(filepath, 'rb') as stream:
        leading_bytes = stream.read(2)
    return leading_bytes == signature

class MPFragmentCounter():
    """
    """

    def __init__(self):
        """Init class variables."""
        pass

    def init_pool_processes(self, the_lock):
        '''
        Initialize each process with a global variable lock.
        '''
        global lock
        lock = the_lock

    def _check_in_list(self, element: Any, alist: list[Any] | set[Any]) -> bool:
        """
        Check if element is in list.

        TODO Do we need this function?

        Parameters
        ----------
        element : Any
            Element that is checked for.
        alist : list[Any] | set[Any]
            List or set in which the element is searched for.

        Returns
        -------
        bool
            True if element is in list else False
        """

        return element in alist

    def _check_true(element: Any, alist: Optional[list[Any]] = None) -> bool:  # true regardless of input
        """
        Return True regardless of input

        Parameters
        ----------
        element : Any
            Element that is checked for.
        alist: Optional[list[Any]]
            List or set in which the element is searched for.

        Returns
        -------
        bool
            True if element is in list else False
        """

        return True

    def custom_callback(self, error):
    	print(error, flush=True)
        

    def insertsize_from_fragments(self, fragments: str,
                                  barcodes: Optional[list[str]] = None,
                                  n_threads: int = 8) -> pd.DataFrame:

        print('Count insertsizes from fragments...')
        # Open fragments file
        if _is_gz_file(fragments):
            f = gzip.open(fragments, "rt")
        else:
            f = open(fragments, "r")

        # Prepare function for checking against barcodes list
        if barcodes is not None:
            barcodes = set(barcodes)
            check_in = self._check_in_list
        else:
            check_in = self._check_true

        iterator = pd.read_csv(fragments,
                               delimiter='\t',
                               header=None,
                               names=['chr', 'start', 'stop', 'barcode', 'count'],
                               iterator=True,
                               chunksize=5000000)

        # start timer
        start_time = datetime.datetime.now()

        # Initialize multiprocessing
        m = Manager()
        lock = Lock()
        managed_dict = m.dict()
        managed_dict['output'] = {}
        pool = Pool(processes=n_threads, initializer=self.init_pool_processes, initargs=(lock,), maxtasksperchild=48)
        jobs = []
        print('Starting counting fragments...')
        # split fragments into chunks
        for chunk in tqdm(iterator, desc="Processing Chunks"):
            # apply async job wit callback function
            job = pool.apply_async(self._count_fragments_worker, args=(chunk, barcodes, check_in, managed_dict), error_callback=self.custom_callback)
            jobs.append(job)
        
        # close pool
        pool.close()
        # wait for all jobs to finish
        pool.join()
        # reset settings
        count_dict = managed_dict['output']

        # Close file and print elapsed time
        end_time = datetime.datetime.now()
        f.close()

        elapsed = end_time - start_time
        print("Done reading file - elapsed time: {0}".format(str(elapsed).split(".")[0]))

        # Convert dict to pandas dataframe
        print("Converting counts to dataframe...")
        table = pd.DataFrame.from_dict(count_dict, orient="index")
        #table = table[["insertsize_count", "mean_insertsize"] + sorted(table.columns[2:])]
        table["mean_insertsize"] = table["mean_insertsize"].round(2)

        print("Done getting insertsizes from fragments!")

        return table

    def _count_fragments_worker(self, chunk, barcodes, check_in, managed_dict):
        """
        Worker function for counting fragments.
        Parameters
        ----------
        chunk
        barcodes
        check_in
        managed_dict

        Returns
        -------

        """

        # Initialize count_dict
        count_dict = {}
        for row in chunk.itertuples():
            start = int(row[2])
            end = int(row[3])
            barcode = row[4]
            count = int(row[5])
            size = end - start - 9  # length of insertion (-9 due to to shifted cutting of Tn5)

            # Only add fragment if check is true
            if check_in(barcode, barcodes) is True:
                count_dict = self._add_fragment(count_dict, barcode, size, count)

        lock.acquire()
        latest = managed_dict['output']
        managed_dict['output'] = self._update_count_dict(latest, count_dict)
        lock.release()

    def _add_fragment(self, count_dict: dict[str, int],
                      barcode: str,
                      size: int,
                      count: int = 1,
                      max_size=1000):
        """
        Add fragment of size 'size' to count_dict.

        Parameters
        ----------
        count_dict : dict[str, int]
            Dictionary containing the counts per insertsize.
        barcode : str
            Barcode of the read.
        size : int
            Insertsize to add to count_dict.
        count : int, default 1
            Number of reads to add to count_dict.
        """

        # Initialize if barcode is seen for the first time
        if barcode not in count_dict:
            count_dict[barcode] = {"mean_insertsize": 0, "insertsize_count": 0}

        # Add read to dict
        if size > 0 and size <= max_size:  # do not save negative insertsize, and set a cap on the maximum insertsize to limit outlier effects

            count_dict[barcode]["insertsize_count"] += count

            # Update mean
            mu = count_dict[barcode]["mean_insertsize"]
            total_count = count_dict[barcode]["insertsize_count"]
            diff = (size - mu) / total_count
            count_dict[barcode]["mean_insertsize"] = mu + diff

            # Save to distribution
            if 'dist' not in count_dict[barcode]:  # first time size is seen
                count_dict[barcode]['dist'] = np.zeros(max_size+1)
            count_dict[barcode]['dist'][size] += count

        return count_dict

    def _update_count_dict(self, count_dict_1, count_dict_2):
        """
        updates
        """
        # Check if count_dict_1 is empty:
        if len(count_dict_1) == 0:
            return count_dict_2

        # make Dataframes for computation
        df1 = pd.DataFrame(count_dict_1).T
        df2 = pd.DataFrame(count_dict_2).T

        # merge distributions
        combined_dists = df1['dist'].combine(df2['dist'], func=self._update_dist)
        # merge counts
        merged_counts = pd.merge(df1["insertsize_count"], df2["insertsize_count"], left_index=True, right_index=True,
                                 how='outer').fillna(0)
        # sum total counts/barcode
        updated_counts = merged_counts.sum(axis=1)

        # calculate scaling factors
        x_scaling_factor = merged_counts["insertsize_count_x"] / updated_counts
        y_scaling_factor = merged_counts["insertsize_count_y"] / updated_counts

        # merge mean insertsizes
        merged_mean_insertsizes = pd.merge(df1["mean_insertsize"], df2["mean_insertsize"], left_index=True,
                                           right_index=True, how='outer').fillna(0)

        # scale mean insertsizes
        merged_mean_insertsizes["mean_insertsize_x"] = merged_mean_insertsizes["mean_insertsize_x"] * x_scaling_factor
        merged_mean_insertsizes["mean_insertsize_y"] = merged_mean_insertsizes["mean_insertsize_y"] * y_scaling_factor

        # sum the scaled means
        updated_means = merged_mean_insertsizes.sum(axis=1)

        # build the updated dictionary
        updated_dict = pd.DataFrame(
            {'mean_insertsize': updated_means, 'insertsize_count': updated_counts, 'dist': combined_dists}).T.to_dict()

        return updated_dict


    def _update_dist(self, dist_1, dist_2):
        """Updates the Insertsize Distributions"""
        if not np.isnan(dist_1).any() and not np.isnan(dist_2).any():
            updated_dist = dist_1 + dist_2
            return updated_dist.astype(int)
        elif np.isnan(dist_1).any():
            return dist_2.astype(int)
        elif np.isnan(dist_2).any():
            return dist_1.astype(int)

In [None]:
def get_dist_df(dist):
    """
    Expand the per-barcode ``'dist'`` arrays of a count table into a wide DataFrame.

    Parameters
    ----------
    dist : pd.DataFrame
        Table with a ``'dist'`` column holding one array-like per barcode (row).

    Returns
    -------
    pd.DataFrame
        One row per barcode (index converted to ``str``) and one column per array
        position.
    """
    per_barcode = {
        str(barcode): dict(enumerate(row['dist']))
        for barcode, row in dist.iterrows()
    }
    return pd.DataFrame(per_barcode).T

In [None]:
%%time
adata_barcodes = adata.obs.index.tolist()
# Split each index entry at '+' and keep the second part as the cell barcode (CB)
barcodes = []
for entry in adata_barcodes:
    barcode = entry.split('+')[1]
    barcodes.append(barcode)

In [None]:
%%time
# Count insert sizes with the multiprocessing counter
counter = MPFragmentCounter()
table_mp = counter.insertsize_from_fragments(fragments_file, barcodes, n_threads=10)
print(table_mp)

In [None]:
# Total of the 'dist' array for one barcode
table_mp.loc['AAATCCGCATAAACGTCCCGTT']['dist'].sum()

In [None]:
%%time
# NOTE(review): `tools` is not imported in the visible cells (sctoolbox.tools was imported as `sctools`) — verify
table_sctoolbox = tools._insertsize_from_fragments(fragments_file, barcodes)
print(table_sctoolbox)

In [None]:
table_sctoolbox.loc['AAATCCGCATAAACGTCCCGTT']

In [None]:
# Sum over the integer-labelled (per-size) columns of one barcode
table_sctoolbox.loc['AAATCCGCATAAACGTCCCGTT'][[c for c in table_sctoolbox.columns if isinstance(c, int)]].sum()

In [None]:
table_sctoolbox.loc['AAATCCGCATAAACGTCCCGTT'][[c for c in table_sctoolbox.columns if isinstance(c, int)]][0:50]

In [None]:
# Keep only the per-size columns so both tables have the same layout
table_sctoolbox = table_sctoolbox[[c for c in table_sctoolbox.columns if isinstance(c, int)]]

In [None]:
# Expand the 'dist' arrays of the multiprocessing result into per-size columns
table_mp = get_dist_df(table_mp)

In [None]:
table_sctoolbox.shape

In [None]:
table_mp.shape

In [None]:
sorted_table_mp = table_mp.sort_index()

In [None]:
sorted_table_mp

In [None]:
sorted_table_sctoolbox = table_sctoolbox.sort_index()

In [None]:
sorted_table_sctoolbox

In [None]:
# DataFrame.equals also requires matching dtypes, not just matching values
sorted_table_mp.equals(sorted_table_sctoolbox)

In [None]:
# NOTE(review): DataFrame `==` needs identically-labelled frames (same order); this raises if sort_index reordered rows — confirm intent
sorted_table_mp == table_mp

In [None]:
def count_lines(filename):
    """
    Return the number of lines in a text file.

    Parameters
    ----------
    filename : str
        Path to the file.

    Returns
    -------
    int
        Number of lines yielded when iterating over the file in text mode.
    """
    total = 0
    with open(filename, 'r') as handle:
        for _ in handle:
            total += 1
    return total

In [None]:
%%time
# Count the lines of the fragments file
number_of_lines = count_lines(fragments_file)
print(f"Total number of lines: {number_of_lines}")

In [None]:
#small_fragments = '/mnt/workspace2/jdetlef/data/public_data/cropped_heart_fragments.bed'

In [None]:
# Small local fragments file used for quick tests in the cells below
small_fragments = '/home/jan/Workspace/bio_data/small_fragments.bed'

In [None]:

class MPFragmentCounter():
    """
    """

    def __init__(self):
        """Init class variables."""
        
        self.m = Manager()
        self.d = self.m.dict()
        self.d['output'] = {}
        self.lock = Lock()


        
    def _check_in_list(element: Any, alist: list[Any] | set[Any]) -> bool:
        """
        Check if element is in list.

        TODO Do we need this function?

        Parameters
        ----------
        element : Any
            Element that is checked for.
        alist : list[Any] | set[Any]
            List or set in which the element is searched for.

        Returns
        -------
        bool
            True if element is in list else False
        """

        return element in alist


    
    def _check_true(element: Any, alist: Optional[list[Any]] = None) -> bool:  # true regardless of input
        """
        Return True regardless of input

        Parameters
        ----------
        element : Any
            Element that is checked for.
        alist: Optional[list[Any]]
            List or set in which the element is searched for.

        Returns
        -------
        bool
            True if element is in list else False
        """

        return True

    
    def insertsize_from_fragments(self, fragments: str,
                                  barcodes: Optional[list[str]] = None,
                                  n_threads: int = 8) -> pd.DataFrame:
        # Open fragments file
        if _is_gz_file(fragments):
            f = gzip.open(fragments, "rt")
        else:
            f = open(fragments, "r")

        # Prepare function for checking against barcodes list
        if barcodes is not None:
            barcodes = set(barcodes)
            check_in = self._check_in_list
        else:
            check_in = self._check_true

        iterator = pd.read_csv(fragments,
                               delimiter='\t',
                               header=None,
                               names=['chr', 'start', 'stop', 'barcode', 'count'],
                               iterator=True,
                               chunksize=1000)

        # start timer
        start_time = datetime.datetime.now()

        pool = Pool(n_threads, maxtasksperchild=48)
        jobs = []
        # split fragments into chunks
        for chunk in iterator:
            # apply async job wit callback function
            job = pool.apply_async(self._count_fragments_worker, args=(chunk, barcodes, check_in))
            jobs.append(job)
        # monitor progress
        # utils.monitor_jobs(jobs, description="Progress")
        # close pool
        pool.close()
        # wait for all jobs to finish
        pool.join()
        # reset settings
        count_dict = self.d
        print('what is going on')
        print(count_dict)
        # Fill missing sizes with 0
        max_fragment_size = 1001

        for barcode in count_dict:
            for size in range(max_fragment_size):
                if size not in count_dict[barcode]:
                    count_dict[barcode][size] = 0

        # Close file and print elapsed time
        end_time = datetime.datetime.now()
        f.close()

        elapsed = end_time - start_time
        print("Done reading file - elapsed time: {0}".format(str(elapsed).split(".")[0]))

        # Convert dict to pandas dataframe
        print("Converting counts to dataframe...")
        table = pd.DataFrame.from_dict(count_dict, orient="index")
        table = table[["insertsize_count", "mean_insertsize"] + sorted(table.columns[2:])]
        table["mean_insertsize"] = table["mean_insertsize"].round(2)

        print("Done getting insertsizes from fragments!")

        return table

    
    def _count_fragments_worker(self, chunk, barcodes, check_in):
        
        count_dict = {}
        
        for i in range(len(chunk)):
            row = chunk.iloc[i]
            start = int(row['start'])
            end = int(row['stop'])
            barcode = row['barcode']
            count = int(row['count'])
            size = end - start - 9  # length of insertion (-9 due to to shifted cutting of Tn5)

            # Only add fragment if check is true
            if check_in(barcode, barcodes) is True:
                count_dict = self._add_fragment(count_dict, barcode, size, count)
                
        with self.lock:
            self.d['output'] = update_count_dict(self.d['output'], count_dict)


    def _add_fragment(count_dict: dict[str, int],
                      barcode: str,
                      size: int,
                      count: int = 1):
        """
        Add fragment of size 'size' to count_dict.

        Parameters
        ----------
        count_dict : dict[str, int]
            Dictionary containing the counts per insertsize.
        barcode : str
            Barcode of the read.
        size : int
            Insertsize to add to count_dict.
        count : int, default 1
            Number of reads to add to count_dict.
        """

        # Initialize if barcode is seen for the first time
        if barcode not in count_dict:
            count_dict[barcode] = {"mean_insertsize": 0, "insertsize_count": 0}

        # Add read to dict
        if size >= 0 and size <= 1000:  # do not save negative insertsize, and set a cap on the maximum insertsize to limit outlier effects

            count_dict[barcode]["insertsize_count"] += count

            # Update mean
            mu = count_dict[barcode]["mean_insertsize"]
            total_count = count_dict[barcode]["insertsize_count"]
            diff = (size - mu) / total_count
            count_dict[barcode]["mean_insertsize"] = mu + diff

            # Save to distribution
            if size not in count_dict[barcode]:  # first time size is seen
                count_dict[barcode][size] = 0
            count_dict[barcode][size] += count
            
        return count_dict
    

    def _log_result(self, result: Any) -> None:
        """Log results from mp_counter."""

        if self.merged_dict:
            self.merged_dict = dict(Counter(self.merged_dict) + Counter(result))
            # print('merging')
        else:
            self.merged_dict = result

In [None]:
 mpc = MPFragmentCounter()

In [None]:
%%time
# Count insert sizes of the small test file with the counter instance
counts = mpc.insertsize_from_fragments(small_fragments, barcodes)

In [None]:
some_dict = {}

In [None]:
some_dict['another'] = {'test': 'Hallo'}

In [None]:
some_dict['another']

In [None]:
count_dict={}

In [None]:
# Two hand-made count dicts (one shared barcode set, one unique barcode each) to test merging
count_dict_1={}
count_dict_1['ACGTT'] = {"mean_insertsize": 10, "insertsize_count": 5, 'dist': np.array([0,1,0,2,1,1,0])}
count_dict_1['GTCCT'] = {"mean_insertsize": 10, "insertsize_count": 20, 'dist': np.array([0,0,0,1,2,2,1])}
count_dict_1['GCGCG'] = {"mean_insertsize": 10, "insertsize_count": 20, 'dist': np.array([0,0,0,1,2,2,1])}

count_dict_2={}
count_dict_2['ACGTT'] = {"mean_insertsize": 20, "insertsize_count": 20, 'dist': np.array([2,1,1,0,1,1,0])}
count_dict_2['GTCCT'] = {"mean_insertsize": 20, "insertsize_count": 5, 'dist': np.array([1,0,2,2,1,1,0])}
count_dict_2['TTTAA'] = {"mean_insertsize": 20, "insertsize_count": 5, 'dist': np.array([1,0,2,2,1,1,0])}

In [None]:
# make Dataframes for computation
df1 = pd.DataFrame(count_dict_1).T
df2 = pd.DataFrame(count_dict_2).T

# merge distributions (update_dist is defined in a later cell)
combined_dists = df1['dist'].combine(df2['dist'], func=update_dist)

In [None]:
# Outer merge of the counts; barcodes missing on one side get 0
merged_counts = pd.merge(df1["insertsize_count"], df2["insertsize_count"], left_index=True, right_index=True, how='outer').fillna(0)


In [None]:
merged_counts

In [None]:
    # merge counts
    merged_counts = pd.merge(df1["insertsize_count"], df2["insertsize_count"], left_index=True, right_index=True, how='outer').fillna(0)
    # sum total counts/barcode
    updated_counts = merged_counts.sum(axis=1)

In [None]:
# Wrap the merged distributions in a DataFrame for inspection
df_dists= pd.DataFrame({'combined_dists' : combined_dists})

In [None]:
df_counts = pd.DataFrame({'insertsize_counts' : updated_counts})

In [None]:
# Replace df_counts with a hand-written table
df_counts = pd.DataFrame({'insertsize_counts': {'TTTAA':20, 'ACGTT':25, 'GCGCG': 25, 'GTCCT': 5}})

In [None]:
df_counts

In [None]:
some_dict = {}

In [None]:
len(some_dict)

In [None]:
def update_count_dict(count_dict_1, count_dict_2):
    """
    Merge two per-barcode insert size count dictionaries.

    Each dictionary maps a barcode to a dict with ``"mean_insertsize"``,
    ``"insertsize_count"`` and ``"dist"`` (``"dist"`` may be missing for barcodes
    without a counted fragment). Neither input is modified.

    Parameters
    ----------
    count_dict_1, count_dict_2 : dict
        Dictionaries to merge.

    Returns
    -------
    dict
        ``count_dict_2`` if ``count_dict_1`` is empty, ``count_dict_1`` if
        ``count_dict_2`` is empty. Otherwise a new dict in which counts and
        distributions are summed and means are weighted by count.
    """
    # Check if count_dict_1 is empty:
    if len(count_dict_1) == 0:
        return count_dict_2
    # An empty second dict previously crashed the DataFrame merge (no 'dist' column)
    if len(count_dict_2) == 0:
        return count_dict_1

    def _sum_dists(dist_1, dist_2):
        """Sum two distributions, treating a missing (None) one as absent."""
        if dist_1 is None:
            return dist_2
        if dist_2 is None:
            return dist_1
        return dist_1 + dist_2

    updated_dict = {}
    for barcode in {**count_dict_1, **count_dict_2}:
        entry_1 = count_dict_1.get(barcode)
        entry_2 = count_dict_2.get(barcode)
        if entry_1 is None or entry_2 is None:
            # Barcode only present on one side: take a shallow copy as-is
            updated_dict[barcode] = dict(entry_2 if entry_1 is None else entry_1)
            continue

        n_1 = entry_1["insertsize_count"]
        n_2 = entry_2["insertsize_count"]
        total = n_1 + n_2
        # Means are combined weighted by their counts
        mean = ((entry_1["mean_insertsize"] * n_1 + entry_2["mean_insertsize"] * n_2) / total
                if total else 0)

        merged_entry = {'mean_insertsize': mean, 'insertsize_count': total}
        dist = _sum_dists(entry_1.get('dist'), entry_2.get('dist'))
        if dist is not None:
            merged_entry['dist'] = dist
        updated_dict[barcode] = merged_entry

    return updated_dict


def update_dist(dist_1, dist_2):
    """
    Combine two insert size distributions.

    A distribution containing NaN (e.g. the NaN pandas inserts for a missing
    barcode) counts as missing. In that case the other distribution is returned
    unchanged; otherwise both are summed element-wise.
    """
    if np.isnan(dist_1).any():
        return dist_2
    if np.isnan(dist_2).any():
        return dist_1
    return dist_1 + dist_2

In [None]:
pd.DataFrame({'mean_insertsizes': updated_means, 'insertsize_counts' : updated_counts})

In [None]:
np.array([1,3,21,0]) / 10

In [None]:
merged_insertsizes = pd.merge(df1["insertsize_count"], df2["insertsize_count"], left_index=True, right_index=True)
merged

In [None]:
x_scaling_factor = merged_insertsizes["insertsize_count_x"] / merged_insertsizes.sum(axis=1)
y_scaling_factor = merged_insertsizes["insertsize_count_y"] / merged_insertsizes.sum(axis=1)

In [None]:
merged_mean_insertsizes = pd.merge(df1["mean_insertsize"], df2["mean_insertsize"], left_index=True, right_index=True)
merged_mean_insertsizes

In [None]:
merged_mean_insertsizes["mean_insertsize_x"] = merged_mean_insertsizes["mean_insertsize_x"] * x_scaling_factor
merged_mean_insertsizes["mean_insertsize_y"] = merged_mean_insertsizes["mean_insertsize_y"] * y_scaling_factor

In [None]:
merged_mean_insertsizes.sum(axis=1)

In [None]:
merged_mean_insertsizes * 

In [None]:
import pandas as pd

# Create two example DataFrames
# NOTE: this overwrites the df1/df2 fixtures built in the cells above
df1 = pd.DataFrame({'Werte1': [1, 2, 3]}, index=['a', 'b', 'c'])
df2 = pd.DataFrame({'Werte1': [4, 5, 6]}, index=['a', 'b', 'c'])

# Merge the DataFrames on their index
merged_df = pd.merge(df1, df2, left_index=True, right_index=True)

# Sum the values row-wise
summed_df = merged_df.sum(axis=1)

print(summed_df)


In [None]:
merged_df

In [None]:
%%time
# NOTE(review): `tools` is not imported in the visible cells (sctoolbox.tools was imported as `sctools`) — verify
count_table = tools._insertsize_from_fragments(small_fragments, barcodes)

In [None]:
@beartype
def _insertsize_from_fragments(fragments: str,
                               barcodes: Optional[list[str]] = None) -> pd.DataFrame:
    """
    Get fragment insertsize distributions per barcode from fragments file.

    Parameters
    ----------
    fragments : str
        Path to fragments.bed(.gz) file.
    barcodes : Optional[list[str]], default None
        Only collect fragment sizes for the barcodes in barcodes

    Returns
    -------
    pd.DataFrame
        DataFrame with insertsize distributions per barcode. Columns are
        ``insertsize_count``, ``mean_insertsize`` and one column per insert size
        0-1000. Empty (with these columns) if no fragment was counted.
    """

    # Prepare function for checking against barcodes list
    # NOTE(review): free functions `_check_in_list` / `_check_true` are not defined in the visible cells — verify
    if barcodes is not None:
        barcodes = set(barcodes)
        check_in = _check_in_list
    else:
        check_in = _check_true

    max_fragment_size = 1001  # number of size columns (sizes 0..1000)

    # Read fragments file and add to dict
    print("Counting fragment lengths from fragments file...")
    start_time = datetime.datetime.now()
    count_dict = {}
    # Use the file-level helper; `with` closes the file even if a line fails to parse
    opener = gzip.open if _is_gz_file(fragments) else open
    with opener(fragments, "rt") as f:
        for line in f:
            columns = line.rstrip().split("\t")
            start = int(columns[1])
            end = int(columns[2])
            barcode = columns[3]
            count = int(columns[4])
            size = end - start - 9  # length of insertion (-9 due to to shifted cutting of Tn5)

            # Only add fragment if check is true
            if check_in(barcode, barcodes) is True:
                count_dict = _add_fragment(count_dict, barcode, size, count)

    # Fill missing sizes with 0
    for barcode_counts in count_dict.values():
        for size in range(max_fragment_size):
            barcode_counts.setdefault(size, 0)

    # Print elapsed time
    end_time = datetime.datetime.now()
    elapsed = end_time - start_time
    print("Done reading file - elapsed time: {0}".format(str(elapsed).split(".")[0]))

    # Convert dict to pandas dataframe
    print("Converting counts to dataframe...")
    if not count_dict:
        # Column selection below would raise KeyError on an empty frame
        print("Done getting insertsizes from fragments!")
        return pd.DataFrame(columns=["insertsize_count", "mean_insertsize"] + list(range(max_fragment_size)))

    table = pd.DataFrame.from_dict(count_dict, orient="index")
    table = table[["insertsize_count", "mean_insertsize"] + sorted(table.columns[2:])]
    table["mean_insertsize"] = table["mean_insertsize"].round(2)

    print("Done getting insertsizes from fragments!")

    return table

In [None]:
@beartype
def _add_fragment(count_dict: dict[str, dict[Any, Any]],
                  barcode: str,
                  size: int,
                  count: int = 1) -> dict[str, dict[Any, Any]]:
    """
    Add fragment of size 'size' to count_dict.

    The annotation was ``dict[str, int]``, but the values are per-barcode dicts.
    Recent beartype releases also check mapping values, so the old hint rejects
    valid input.

    Parameters
    ----------
    count_dict : dict[str, dict[Any, Any]]
        Counts per barcode. Each value holds ``"mean_insertsize"``,
        ``"insertsize_count"`` and one integer key per observed insert size.
        It is updated in place.
    barcode : str
        Barcode of the read.
    size : int
        Insertsize to add to count_dict. Sizes outside 0-1000 are ignored, but
        the barcode is still registered.
    count : int, default 1
        Number of reads to add to count_dict.

    Returns
    -------
    dict[str, dict[Any, Any]]
        Updated count_dict
    """

    # Initialize if barcode is seen for the first time
    if barcode not in count_dict:
        count_dict[barcode] = {"mean_insertsize": 0, "insertsize_count": 0}

    # do not save negative insertsize, and set a cap on the maximum insertsize to limit outlier effects
    if 0 <= size <= 1000:
        barcode_counts = count_dict[barcode]
        barcode_counts["insertsize_count"] += count
        total_count = barcode_counts["insertsize_count"]

        # Running mean weighted by `count` (all `count` reads have this size)
        if total_count:
            mu = barcode_counts["mean_insertsize"]
            barcode_counts["mean_insertsize"] = mu + count * (size - mu) / total_count

        # Save to distribution
        barcode_counts[size] = barcode_counts.get(size, 0) + count

    return count_dict

# HELPERS

In [None]:
@beartype
def _is_gz_file(filepath: str) -> bool:
    """
    Tell whether a file is gzip-compressed by looking at its magic number.

    Parameters
    ----------
    filepath : str
        Path to the file to inspect.

    Returns
    -------
    bool
        True when the first two bytes are the gzip magic ``1f 8b``.
    """
    gzip_magic = b'\x1f\x8b'
    with open(filepath, 'rb') as handle:
        leading_bytes = handle.read(len(gzip_magic))
    return leading_bytes == gzip_magic

In [None]:
@beartype
def gunzip_file(f_in: str, f_out: str) -> None:
    """
    Decompress a gzip file.

    Parameters
    ----------
    f_in : str
        Path to compressed input file.
    f_out : str
        Destination to decompressed output file (overwritten if it exists).
    """
    # Local import: shutil is not among the notebook's visible imports (NameError before)
    import shutil

    with gzip.open(f_in, 'rb') as h_in, open(f_out, 'wb') as h_out:
        shutil.copyfileobj(h_in, h_out)

In [None]:
# Read the fragments file lazily in chunks of 100000 rows for timing experiments
iterator = pd.read_csv(fragments_file,
                       delimiter='\t',
                       header=None,
                       names=['chr', 'start', 'stop', 'barcode', 'count'],
                       iterator=True,
                       chunksize=100000)
updated = {}

In [None]:
chunk = next(iterator)

In [None]:
chunk

In [None]:
%%time
# Variant 1: itertuples loop over one chunk, then merge into `updated`
# NOTE(review): a free function `_check_true` is not defined in the visible cells (only as methods) — verify
check_in = _check_true

count_dict = {}

# row[0] is the index, so row[1..5] are chr, start, stop, barcode, count
for row in chunk.itertuples():
    start = int(row[2])
    end = int(row[3])
    barcode = row[4]
    count = int(row[5])
    size = end - start - 9  # length of insertion (-9 due to to shifted cutting of Tn5)

    # Only add fragment if check is true
    if check_in(barcode, barcodes) is True:
        count_dict = _add_fragment(count_dict, barcode, size, count)
        

updated = update_count_dict(updated, count_dict)

In [None]:
def wrap_add_fragments(row, count_dict):
    """
    Add one fragments-file row to ``count_dict`` (for ``DataFrame.apply(axis=1)``).

    Relies on the notebook-level ``check_in``, ``barcodes`` and ``_add_fragment``.

    Parameters
    ----------
    row : pd.Series
        One row with the values chr, start, stop, barcode, count, in that order.
    count_dict : dict
        Per-barcode counts; updated in place.

    Returns
    -------
    dict
        The updated count_dict (the original discarded it in an unused local).
    """
    # Positional access via .iloc: `row[1]` on a labelled Series is deprecated in pandas
    values = row.iloc
    start = int(values[1])
    end = int(values[2])
    barcode = str(values[3])
    count = int(values[4])
    size = end - start - 9  # length of insertion (-9 due to to shifted cutting of Tn5)

    if check_in(barcode, barcodes) is True:
        count_dict = _add_fragment(count_dict, barcode, size, count)
    return count_dict

In [None]:
count_dict = {}

In [None]:
%%time
# Variant 2: row-wise DataFrame.apply; wrap_add_fragments fills count_dict in place
_ = chunk.apply(lambda row: wrap_add_fragments(row, count_dict), axis=1)

In [None]:
count_dict

In [None]:
%%time
# Variant 3: positional .iloc loop (same counting logic as the itertuples cell)
check_in = _check_true

count_dict = {}

for i in range(len(chunk)):
    row = chunk.iloc[i]
    start = int(row['start'])
    end = int(row['stop'])
    barcode = row['barcode']
    count = int(row['count'])
    size = end - start - 9  # length of insertion (-9 due to to shifted cutting of Tn5)

    # Only add fragment if check is true
    if check_in(barcode, barcodes) is True:
        count_dict = _add_fragment(count_dict, barcode, size, count)
        

# NOTE: `updated` is not reset here, so running this after the itertuples cell merges the same chunk twice
updated = update_count_dict(updated, count_dict)

In [None]:
len(count_dict)

In [None]:
len(updated)

In [None]:
pd.DataFrame(updated)

In [None]:
updated = {}

In [None]:
pd.DataFrame(count_dict)

In [None]:
# One row per barcode after transposing
df = pd.DataFrame(count_dict).T

In [None]:
df['dist']

In [None]:
    def _count_fragments_worker(self, chunk, barcodes, check_in):
        """
        Count the fragments of one chunk and merge them into ``self.d['output']``.

        Parameters
        ----------
        chunk : pd.DataFrame
            Rows of the fragments file with the columns start, stop, barcode, count.
        barcodes
            Barcode collection passed on to ``check_in``.
        check_in : callable
            ``check_in(barcode, barcodes)`` decides whether a fragment is counted.
        """

        count_dict = {}

        # Column-wise zip avoids building a Series per row (much faster than .iloc)
        for start, end, barcode, count in zip(chunk['start'], chunk['stop'],
                                              chunk['barcode'], chunk['count']):
            size = int(end) - int(start) - 9  # length of insertion (-9 due to to shifted cutting of Tn5)

            # Only add fragment if check is true
            if check_in(barcode, barcodes) is True:
                count_dict = self._add_fragment(count_dict, barcode, size, int(count))

        if not count_dict:
            return  # nothing matched in this chunk

        with self.lock:
            # Write through the 'output' key: rebinding self.d would replace the
            # shared proxy with a local dict instead of updating the shared result
            self.d['output'] = update_count_dict(self.d['output'], count_dict)
            
    def _check_true(element: Any, alist: Optional[list[Any]] = None) -> bool:  # true regardless of input

        return True

    def _add_fragment(count_dict: dict[str, int],
                      barcode: str,
                      size: int,
                      count: int = 1,
                      max_size=1000):

        # Initialize if barcode is seen for the first time
        if barcode not in count_dict:
            count_dict[barcode] = {"mean_insertsize": 0, "insertsize_count": 0}

        # Add read to dict
        if size >= 0 and size <= max_size:  # do not save negative insertsize, and set a cap on the maximum insertsize to limit outlier effects

            count_dict[barcode]["insertsize_count"] += count

            # Update mean
            mu = count_dict[barcode]["mean_insertsize"]
            total_count = count_dict[barcode]["insertsize_count"]
            diff = (size - mu) / total_count
            count_dict[barcode]["mean_insertsize"] = mu + diff

            # Save to distribution
            if size not in count_dict[barcode]:  # first time size is seen
                sizes = np.arange(0,max_size+1)
                count_dict[barcode]['dist'] = np.zeros(max_size)
            count_dict[barcode]['dist'][size] += count
            
        return count_dict

In [None]:
# Sanity check: numpy adds two length-1001 arrays element-wise
np.arange(0,1001) + np.arange(0,1001)

In [None]:
import numpy as np

In [None]:
count_dict[']