# Overview

The purpose of this notebook is to perform preliminary sample quality control (QC) in order to identify and/or remove outlier samples in preparation for the batched steps of the GATK-SV pipeline.

This notebook allows the user to analyze sample QC metrics produced by the *EvidenceQC* workflow, such as median coverage and whole genome dosage (WGD). After visualizing the distributions of these metrics, the user can select outlier cutoffs to filter or flag outlier samples. At the end, a file of sample QC metrics for passing samples is produced, which can be used for batching.

**Suggested VM Specifications**:
* Application Configuration: Default
* CPUs: 2
* Memory: 13 GB
* Persistent Disk: 100 GB

**Prerequisites**: GatherSampleEvidence, EvidenceQC.

**Next Steps**: Batching, TrainGCNV.

**Legend**:
<div class="alert alert-block alert-info"> <b>Blue Boxes for One-Time Runs</b>: Uncomment and run the code cell directly below just once, then re-comment the cell to skip it next time. These cells typically save intermediate outputs locally so that the notebook can be paused and picked back up without a significant delay.</div>
<div class="alert alert-block alert-success"> <b>Green Boxes for User Inputs</b>: Edit the inputs provided in the code cell directly below. The inputs that are editable are defined in all capitals, and their descriptions can be found in the section headers. </div>

**Execution Tips**:
* The first time you start this notebook (one time only), you will need to uncomment and run the package installation cell under *Imports*.
* Once the packages are installed, to quickly run all the cells containing helper functions, constants, and imports, skip to the first cell of *Data Ingestion*, click "Cell" in the toolbar at the top of the notebook, and select "Run All Above." Then, starting from *Data Ingestion*, proceed step-by-step through the notebook.
* The keyboard shortcut to run a cell is `Shift`+`Return`.
* The keyboard shortcut to comment or uncomment an entire cell is `Command`+`/` or `Control`+`/`.

# Imports
This section defines all imports required by this notebook.

<div class="alert alert-block alert-info">Uncomment and run once. It is not necessary to reinstall these packages each time you restart your cloud environment.</div>    

In [None]:
# ! pip install upsetplot

In [None]:
# Package imports
import os
import io
import re
import logging
import shutil
import subprocess
import zipfile
from collections import defaultdict
from logging import INFO

# Aliased imports
import pandas as pd
import numpy as np
import seaborn as sns
import firecloud.api as fapi
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib_inline.backend_inline

# Plotting imports
from upsetplot import UpSet
from upsetplot import from_memberships
from IPython.display import IFrame
from matplotlib.lines import Line2D
from matplotlib.collections import LineCollection
import matplotlib.ticker
from PIL import Image

# Plotting settings 
default_dpi = plt.rcParams['figure.dpi']  
plt.rcParams['figure.dpi'] = 200
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Logger settings
logger = logging.getLogger()
logger.setLevel(INFO)

# Constants
This section declares constants that are used throughout the notebook.

In [None]:
# Workspace-level
TLD_PATH = 'evidence_qc'
PROJECT = os.environ['GOOGLE_PROJECT']
WORKSPACE = os.environ['WORKSPACE_NAME']
WS_BUCKET = os.environ['WORKSPACE_BUCKET']
NAMESPACE = os.environ['WORKSPACE_NAMESPACE']

# Autosomal copy number
cols_autosome = [
    'sample_id', 'chr1_CopyNumber', 'chr2_CopyNumber', 
    'chr3_CopyNumber', 'chr4_CopyNumber', 'chr5_CopyNumber', 
    'chr6_CopyNumber', 'chr7_CopyNumber', 'chr8_CopyNumber', 
    'chr9_CopyNumber', 'chr10_CopyNumber', 'chr11_CopyNumber', 
    'chr12_CopyNumber', 'chr13_CopyNumber', 'chr14_CopyNumber', 
    'chr15_CopyNumber', 'chr16_CopyNumber', 'chr17_CopyNumber', 
    'chr18_CopyNumber', 'chr19_CopyNumber', 'chr20_CopyNumber', 
    'chr21_CopyNumber', 'chr22_CopyNumber'
]

# PED file validation
ID_TYPE_SAMPLE = "sample"
ID_TYPE_FAMILY = "family"
ID_TYPE_PARENT = "parent"
FIELD_NUMBER_ID = 1
FIELD_NUMBER_SEX = 4
ILLEGAL_ID_SUBSTRINGS = ["chr", "name", "DEL", "DUP", "CPX", "CHROM"]

# Sex assignment to numerical entry in PED file
SEX_CODES = {
    "MALE": 1, "FEMALE": 2, "MOSAIC": 0, "TURNER": 0, 
    "TRIPLE X": 0, "KLINEFELTER": 0, "JACOBS": 0, "OTHER": 0
}
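The `SEX_CODES` table collapses every assignment other than male and female to 0, the PED code for unknown/other. A minimal sketch of applying it to per-sample assignments follows. The helper `sex_to_ped_code`, the case normalization, and the fallback to `OTHER` for unrecognized labels are illustrative assumptions, not part of this notebook.

```python
# Copy of the notebook's SEX_CODES mapping
SEX_CODES = {
    "MALE": 1, "FEMALE": 2, "MOSAIC": 0, "TURNER": 0,
    "TRIPLE X": 0, "KLINEFELTER": 0, "JACOBS": 0, "OTHER": 0
}

def sex_to_ped_code(assignment):
    """Map a sex assignment label to its PED code (hypothetical helper).

    Labels are upper-cased before lookup; unrecognized labels fall back to
    the "OTHER" code (0), which is an assumption made for this sketch.
    """
    return SEX_CODES.get(assignment.strip().upper(), SEX_CODES["OTHER"])

# Hypothetical per-sample assignments
assignments = {"sample_A": "MALE", "sample_B": "female", "sample_C": "TURNER", "sample_D": "unknown"}
ped_codes = {sample: sex_to_ped_code(label) for sample, label in assignments.items()}
print(ped_codes)  # {'sample_A': 1, 'sample_B': 2, 'sample_C': 0, 'sample_D': 0}
```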

# Helper Functions
This section defines helper functions used throughout the notebook.

## File System Functions

In [None]:
def copy_local_file_to_gs(ws_bucket, file_path, print_path=True):
    """
    Copies a file saved locally to Google Cloud Storage
    
    Args:
        ws_bucket (str): The bucket in Google Cloud Storage where the file should be saved.
        file_path (str): The path where the file should be saved.
        print_path (bool): Determines whether to print the Google Cloud Storage path of the file created.
    
    Returns:
        None.
    """
    gcs_path = f"{ws_bucket}/{file_path}"
    
    try:
        subprocess.run(
            ["gsutil", "cp", "-r", file_path, gcs_path],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True
        )
    except subprocess.CalledProcessError as e:
        raise Exception(f"Error copying file to GCS: {e}")
    
    if (print_path):
        print(f"GCS file path: {gcs_path}")

In [None]:
def generate_file_path(tld_path, file_type, file_name):
    """
    Calculate the file path of a file to save in either the file system or Google Cloud Storage.

    Args:
        tld_path (str): Top-level directory path.
        file_type (str): Enables generation of the specific sub-directory which a file should live in.
        file_name (str): File name to chain at the end of the path.

    Returns:
        str: Path to file as it should be saved, per file system outline.
    """
    if 'outlier' in file_type:
        full_path = os.path.join(tld_path, 'raw_caller_outliers', file_type, file_name)
    else:
        full_path = os.path.join(tld_path, file_type, file_name)
    
    return full_path

In [None]:
def save_df(ws_bucket, file_path, df, print_path=True, header=True):
    """
    Save a dataframe as a TSV at the specified file path, creating directories if they don't exist,
    and copy it to Google Cloud Storage.
    
    Args:
        ws_bucket (str): The bucket in Google Cloud Storage where the file should be saved.
        file_path (str): The path where the dataframe should be saved.
        df (pandas.DataFrame): The dataframe to save.
        print_path (bool): Determines whether to print the Google Cloud Storage path of the file created.
        header (bool): Determines whether to write the column names as a header row.
    
    Returns:
        None.
    """
    dir_path = os.path.dirname(file_path)
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)
    
    df.to_csv(file_path, sep='\t', index=False, header=header)
    copy_local_file_to_gs(ws_bucket, file_path, print_path)
    
    print(f"Dataframe saved to: {file_path}")

In [None]:
def save_figure(ws_bucket, file_path):
    """
    Saves the current figure to the specified file path, creating directories if they don't exist,
    and copies it to Google Cloud Storage.
    
    Args:
        ws_bucket (str): The bucket in Google Cloud Storage where the file should be saved.
        file_path (str): The path where the figure should be saved.

    Returns:
        None.
    """
    dir_path = os.path.dirname(file_path)
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)
    
    plt.savefig(file_path)
    copy_local_file_to_gs(ws_bucket, file_path)
    
    print(f"Figure saved to: {file_path}")

## Processing Functions

In [None]:
def get_batch_for_sample_id(sample_id):
    """
    Retrieve the batch associated with a given sample ID, using the global `sample_set_tbl`.

    Args:
        sample_id (str): The sample ID to get the batch for.
        
    Returns:
        str: Batch ID corresponding to the searched sample, or None if the sample is not found in any batch.
    """
    for batch, samples in zip(sample_set_tbl['entity:sample_set_id'].values, sample_set_tbl.samples.values):
        if sample_id in samples:
            return batch

In [None]:
# Dictionary to store samples and their associated reasons for removal
samples_to_remove = {}  
    
def remove_samples(sample_ids, filter_type):
    """
    Function to remove samples based on given standard sample IDs and reasons.

    Args:
        sample_ids (list or iterable): A list of standard sample IDs to be removed.
        filter_type (str): The type of filter that caused a sample to be removed.
        
    Returns:
        None.
    """
    # Reset the filter if it has already been applied
    samples_to_remove[filter_type] = set()
    
    # Add each standard sample ID to exclusions from the given filter
    for sample_id in sample_ids:
        samples_to_remove[filter_type].add(sample_id)

In [None]:
def invert_removed_samples(samples_to_remove):
    """
    Inverts the key-value orientation of a sample dictionary.

    Args:
        samples_to_remove (dict): A dictionary with reasons as keys and sets of samples as values.

    Returns:
        dict: A dictionary with samples as keys and sets of reasons as values.
    """    
    inv_samples_to_remove = {}
    
    for reason, samples in samples_to_remove.items():
        for sample in samples:
            if sample not in inv_samples_to_remove:
                inv_samples_to_remove[sample] = set()
            inv_samples_to_remove[sample].add(reason)
        
    return inv_samples_to_remove
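`create_upset_plot` expects a DataFrame with a `sample_id` column and a comma-separated `filters_applied` column. The cell that builds this DataFrame is not part of this section, so the sketch below assumes one row per failing sample, with that sample's reasons sorted and joined by commas; it only illustrates how the inverted dictionary maps onto that shape.

```python
import pandas as pd

def invert_removed_samples(samples_to_remove):
    """Invert {reason: {samples}} into {sample: {reasons}} (same logic as the cell above)."""
    inverted = {}
    for reason, samples in samples_to_remove.items():
        for sample in samples:
            inverted.setdefault(sample, set()).add(reason)
    return inverted

# Toy result of calling remove_samples() for two filters
samples_to_remove = {
    "median_coverage": {"S1", "S2"},
    "wgd_score": {"S2", "S3"},
}
inverted = invert_removed_samples(samples_to_remove)

# Assumed layout: one row per failing sample, reasons sorted and comma-joined
ids = sorted(inverted)
filter_df = pd.DataFrame({
    "sample_id": ids,
    "filters_applied": [",".join(sorted(inverted[s])) for s in ids],
})
print(filter_df)
```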

In [None]:
def filter_and_save_metadata(samples_qc_table, samples_to_remove, file_path):
    """
    Filter a metadata table based on a list of samples to remove, and then save the filtered table.

    Args:
        samples_qc_table (pandas.DataFrame): DataFrame containing sample metadata.
        samples_to_remove (dict): Dictionary with samples to be removed as keys.
        file_path (str): File path to save metadata at.

    Returns:
        str: Full path to the saved file.
    """
    # The keys of 'samples_to_remove' are the unique sample IDs to drop
    final_list_to_remove = list(samples_to_remove.keys())
    
    # Select relevant columns from the 'samples_qc_table' DataFrame
    all_meta = samples_qc_table[[
        'sample_id', 'mean_insert_size', 'wgd_score', 
        'median_coverage', 'chrX_CopyNumber_rounded'
    ]]

    # Remove rows corresponding to filtered samples
    pass_meta = all_meta[~all_meta.sample_id.isin(final_list_to_remove)]
    
    # Save metadata file
    save_df(WS_BUCKET, file_path, pass_meta, print_path=True)
    
    return file_path

## Validation Functions

In [None]:
def validate_numeric_input(*args, log=True):
    """
    Validates user input to check whether it is numeric or not. 
    
    Args:
        *args (numeric): Any number of positional arguments to validate.
        log (bool): Determines whether to print a confirmation message when all inputs are valid.
    
    Returns:
        None.
    """
    for arg in args:
        if not isinstance(arg, (int, float)):
            raise Exception('Value input must be numeric.')
    
    if log:
        print("Inputs are valid - please proceed to the next cell.")

In [None]:
def validate_table_filter(table, filter_type, log=True):
    """
    Validates that a QC metric column exists in a given table and contains at least one non-null value. 
    
    Args:
        table (pd.DataFrame): Sample table to check.
        filter_type (str): Name of the QC metric column to check.
        log (bool): Currently unused.
    
    Returns:
        None.
    """
    if (filter_type not in table):
        raise Exception(f"'{filter_type}' is not a valid column in the provided dataframe - skip this QC step.")
        
    if (table[filter_type].notna().sum() <= 0 or table[filter_type].empty):
        raise Exception(f"'{filter_type}' has no non-null values in the provided dataframe - skip this QC step.")

In [None]:
def validate_qc_inputs(table, filter_type, line_deviations=None, line_styles=None, method=None, 
                       lower_cutoff=None, upper_cutoff=None, mad_cutoff=None, caller=None, 
                       caller_type=None, log_scale=None):
    """
    Validates user inputs for QC tests.
    
    Args:
        table (pd.DataFrame): Sample table.
        filter_type (str): QC filter being checked.
        line_deviations (list): List of line deviations as input by user.
        line_styles (list): List of line styles as input by user.
        method (str): Method used for filtering as input by user.
        lower_cutoff (numeric): Lower cutoff threshold as input by user.
        upper_cutoff (numeric): Upper cutoff threshold as input by user.
        mad_cutoff (numeric): MAD cutoff threshold as input by user.
        caller (str): Type of caller used as input by user.
        caller_type (str): Determines whether to look for high or low outliers as input by user.
        log_scale (bool): Determines whether to log-scale the plot as input by user.
    
    Returns:
        None.
    """
    validate_table_filter(table, filter_type, log=False)
        
    if log_scale:
        if (not isinstance(log_scale, bool)):
            raise Exception('LOG_SCALE must be a boolean value.')
    
    if line_deviations:
        if (not isinstance(line_deviations, list)):
            raise Exception('LINE_DEVIATIONS must be a list.')
        
        for ld in line_deviations:
            validate_numeric_input(ld, log=False)
    
    if line_styles:
        if (not isinstance(line_styles, list)):
            raise Exception('LINE_STYLES must be a list of strings.')
            
        for ls in line_styles:
            if (ls not in ['solid', 'dotted', 'dashed', 'dashdot']):
                raise Exception(f'The value {ls} is invalid - it must be one of "solid", "dotted", "dashed" or "dashdot".')
    
    if line_deviations and line_styles:
        if (len(line_deviations) != len(line_styles)):
            raise Exception('The number of cutoffs provided should match the number of line styles input.')
        
    if method:
        if (method not in ['MAD', 'hard']):
            raise Exception('The value for the filter method must be one of "MAD" or "hard".')
        
        if (method == 'hard'):
            if (lower_cutoff is not None and not isinstance(lower_cutoff, (int, float))):
                raise Exception('Given that the chosen method is "hard", the value for the lower cutoff should be numeric.')

            if (upper_cutoff is not None and not isinstance(upper_cutoff, (int, float))):
                raise Exception('Given that the chosen method is "hard", the value for the upper cutoff should be numeric.')
                
            if (lower_cutoff is None):
                print('[WARNING] Setting LOWER_CUTOFF to None results in no lower cutoff being applied.')
                
            if (upper_cutoff is None):
                print('[WARNING] Setting UPPER_CUTOFF to None results in no upper cutoff being applied.')
            
        if (method == 'MAD'):
            if (mad_cutoff is not None and not isinstance(mad_cutoff, (int, float))):
                raise Exception('The value for the MAD cutoff should be numeric.')
                
            if (mad_cutoff is None):
                print('[WARNING] Setting MAD_CUTOFF to None results in no MAD cutoff being applied.')

    if (caller and caller not in ['overall', 'manta', 'melt', 'scramble', 'wham']):
        raise Exception(f'The value {caller} for category is invalid - it must be one of "overall", "manta", "melt", "scramble" or "wham".')

    if (caller_type and caller_type not in ['high', 'low']):
        raise Exception(f'The value {caller_type} for caller type is invalid - it must be one of "high" or "low".')

    print("Inputs are valid - please proceed to the next cell.")

In [None]:
def validate_unique_samples(samples_df):
    """
    Checks that every sample ID in the table appears only once. 
    
    Args:
        samples_df (pandas.DataFrame): Sample table with a 'sample_id' column.
    
    Returns:
        None.
    """
    id_counts = samples_df['sample_id'].value_counts()
    duplicates_dict = id_counts[id_counts > 1].to_dict()

    if (len(duplicates_dict) > 0):
        print(f"{len(duplicates_dict)} duplicate samples exist in the dataset.")
        for sample_id, count in duplicates_dict.items():
            print(f"Sample ID: {sample_id}, Count: {count}")
        raise Exception("QC requires unique samples - please resolve duplicates before proceeding.")

    print("No duplicates found - please proceed to the next cell.")

In [None]:
def validate_id(identifier, id_type, source_file):
    """
    Validates sample IDs provided based on a source file of samples.
    
    Args:
        identifier (str): ID for a given sample.
        id_type (str): Type of ID provided.
        source_file (str): File that contains all samples.
    
    Returns:
        None.
    """
    # Check for empty IDs (pandas reads empty PED fields as NaN)
    if pd.isna(identifier) or identifier == "":
        raise ValueError(f"Empty {id_type} ID in {source_file}.")

    # Check all characters are alphanumeric or underscore
    if not re.fullmatch(r'\w+', identifier):
        raise ValueError(f"Invalid {id_type} ID in {source_file}: '{identifier}'. " + 
                         "IDs should only contain alphanumeric and underscore characters.")

    # Check for all-numeric IDs, besides maternal & paternal ID (can be 0) and all-numeric family IDs
    if id_type != ID_TYPE_FAMILY and not (id_type == ID_TYPE_PARENT and identifier == "0") and identifier.isdigit():
        raise ValueError(f"Invalid {id_type} ID in {source_file}: {identifier}. " +
                         "IDs should not contain only numeric characters.")

    # Check for illegal substrings
    for sub in ILLEGAL_ID_SUBSTRINGS:
        if sub in identifier:
            raise ValueError(f"Invalid {id_type} ID in {source_file}: {identifier}. " +
                             f"IDs cannot contain the following substrings: {', '.join(ILLEGAL_ID_SUBSTRINGS)}.")
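The rules that `validate_id` enforces for sample IDs can be summarized as a small predicate. This is a sketch: `is_valid_sample_id` is a hypothetical helper, and it uses `re.fullmatch`, which, unlike `re.match` with `^\w+$`, also rejects an ID with a trailing newline (in Python, `$` matches just before a final newline). Note that the substring check is case-sensitive, so `Chr1_sample` passes while `chr1_sample` does not.

```python
import re

ILLEGAL_ID_SUBSTRINGS = ["chr", "name", "DEL", "DUP", "CPX", "CHROM"]

def is_valid_sample_id(identifier):
    """Return True if a sample ID passes the checks validate_id() applies to sample IDs."""
    if not identifier:  # empty string or None
        return False
    if not re.fullmatch(r'\w+', identifier):  # only alphanumeric and underscore characters
        return False
    if identifier.isdigit():  # sample IDs may not be all-numeric
        return False
    # Case-sensitive substring check, as in validate_id()
    return not any(sub in identifier for sub in ILLEGAL_ID_SUBSTRINGS)

candidates = ["NA12878", "sample_01", "12345", "sample-01",
              "chr1_sample", "Chr1_sample", "sample_DUP", "sample_01\n"]
results = {c: is_valid_sample_id(c) for c in candidates}
for candidate, ok in results.items():
    print(repr(candidate), ok)
```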

In [None]:
def validate_ped(ped_file, samples):
    """
    Validates structure and data within PED file based on series of samples provided.
    Works with both local and GCS files.
    
    Args:
        ped_file (str): Path to PED file (local or GCS path starting with 'gs://').
        samples (set): Set of sample IDs to validate against.
    
    Returns:
        None
    """
    seen_sex_1 = False
    seen_sex_2 = False
    samples_found = set()

    # Read PED file without imposing column names, so that the column count can be checked
    try:
        df = pd.read_table(ped_file, dtype=str, header=None, comment='#')
    except Exception as e:
        raise ValueError(f"Error reading PED file: {str(e)}")
    
    # Ensure column count, then assign column names
    if len(df.columns) != 6:
        raise ValueError("PED file must have 6 columns: Family_ID, Sample_ID, " +
                         "Paternal_ID, Maternal_ID, Sex, Phenotype.")
    df.columns = ['Family_ID', 'Sample_ID', 'Paternal_ID', 'Maternal_ID', 'Sex', 'Phenotype']

    # Iteratively validate each row
    for _, row in df.iterrows():
        # Validate ID
        for identifier, id_type in zip(row.iloc[:FIELD_NUMBER_SEX],
                                       [ID_TYPE_FAMILY, ID_TYPE_SAMPLE, ID_TYPE_PARENT, ID_TYPE_PARENT]):
            validate_id(identifier, id_type, "PED file")

        # Assign main information to variables
        sample_id = row['Sample_ID']
        sex = int(row['Sex'])

        # Check for appearance of each sex
        if sex == 1:
            seen_sex_1 = True
        elif sex == 2:
            seen_sex_2 = True
        elif sex != 0:
            raise ValueError(f"Sample {sample_id} has an invalid value for sex: {sex}. " +
                             "PED file must use the following values for sex: " + 
                             "1 for Male, 2 for Female, 0 for Unknown/Other.")

        # Verify no duplications
        if sample_id in samples_found:
            raise ValueError(f"Duplicate entries for sample {sample_id}.")
        elif sample_id in samples:
            samples_found.add(sample_id)

    # Check if all samples in the sample list are present in PED file
    if len(samples_found) < len(samples):
        missing = samples - samples_found
        raise ValueError(f"PED file is missing sample(s): {','.join(missing)}.")

    # Raise error unless at least one sample of each sex (1 and 2) is found
    if not (seen_sex_2 and seen_sex_1):
        raise ValueError("PED file must contain at least one sample of each sex (1 and 2). " +
                         "PED file must use the following values for sex: " + 
                         "1 for Male, 2 for Female, 0 for Unknown/Other.")
    
    print("PED file is valid - please proceed to the next cell.")
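For reference, a minimal PED file that satisfies `validate_ped` has six tab-separated columns, uses `0` for a missing parent, and contains at least one sample coded 1 and one coded 2. The trio below is hypothetical; the sketch parses it with the `pd.read_table` options used above (`dtype=str`, `header=None`, `comment='#'`) and checks the column count before naming the columns.

```python
import io

import pandas as pd

PED_COLUMNS = ['Family_ID', 'Sample_ID', 'Paternal_ID', 'Maternal_ID', 'Sex', 'Phenotype']

# Hypothetical trio: tab-separated, '0' marks a missing parent, sex coded 1 (male) / 2 (female)
ped_text = (
    "# Family_ID Sample_ID Paternal_ID Maternal_ID Sex Phenotype\n"
    "fam_1\tchild_1\tdad_1\tmom_1\t1\t0\n"
    "fam_1\tdad_1\t0\t0\t1\t0\n"
    "fam_1\tmom_1\t0\t0\t2\t0\n"
)

# Comment lines starting with '#' are skipped entirely
df = pd.read_table(io.StringIO(ped_text), dtype=str, header=None, comment='#')
if df.shape[1] != len(PED_COLUMNS):
    raise ValueError("PED file must have 6 columns.")
df.columns = PED_COLUMNS

# Both sexes must appear at least once
sexes = set(df['Sex'].astype(int))
print(df[['Sample_ID', 'Sex']].to_dict('records'))
print("Both sexes present:", {1, 2} <= sexes)
```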

## QC Functions

### General Functions

In [None]:
def calculate_mad_filter_bounds(filter_type_data, filter_type_cut_off, filter_type):
    """
    Calculate the lower and upper bounds for filtering based on the Median Absolute Deviation (MAD) method.

    Args:
        filter_type_data (numpy.array): Data for which MAD filter bounds will be calculated.
        filter_type_cut_off (float): The cutoff value for the filter. It determines how many MADs away 
                                     from the median to set the filter bounds.
        filter_type (str): The type of the filter or data for which MAD bounds are calculated.

    Returns:
        tuple: Contains the MAD value, median value, lower filter bound, and upper filter bound.
    """
    # Check if the cut_off value is less than 0
    if filter_type_cut_off < 0:
        raise Exception("Invalid cutoff - please ensure that the cutoff is greater than or equal to 0.")
    
    # Calculate the median and the Median Absolute Deviation (MAD) for the filter type data
    median_filter_type = np.median(filter_type_data)
    mad_filter_type = np.median(np.absolute(filter_type_data - median_filter_type))

    # Calculate the lower and upper filter bounds based on the MAD and the cutoff value
    filter_type_lower_cff = median_filter_type - float(filter_type_cut_off) * mad_filter_type
    filter_type_upper_cff = median_filter_type + float(filter_type_cut_off) * mad_filter_type

    # Return the MAD value, median value, and filter bounds as a tuple
    return mad_filter_type, median_filter_type, filter_type_lower_cff, filter_type_upper_cff
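As a worked example of these bounds: for coverage values 30, 31, 32, 33 and 60, the median is 32 and the absolute deviations are 2, 1, 0, 1 and 28, so the MAD is 1. With a cutoff of 3 the bounds are 29 and 35, and only the sample at 60 falls outside them. The sketch below reproduces this with a trimmed copy of the function. Note that the MAD here is unscaled (no 1.4826 consistency factor), so a cutoff of k MADs is narrower than k standard deviations for normally distributed data.

```python
import numpy as np

def calculate_mad_filter_bounds(data, cutoff):
    """Median +/- cutoff * MAD, using the raw (unscaled) MAD as in the notebook."""
    if cutoff < 0:
        raise ValueError("cutoff must be >= 0")
    median = np.median(data)
    mad = np.median(np.abs(data - median))
    return mad, median, median - cutoff * mad, median + cutoff * mad

# Toy median coverage values with one obvious outlier
coverage = np.array([30.0, 31.0, 32.0, 33.0, 60.0])
mad, median, lower, upper = calculate_mad_filter_bounds(coverage, cutoff=3)
print(mad, median, lower, upper)  # 1.0 32.0 29.0 35.0

outliers = coverage[(coverage < lower) | (coverage > upper)]
print(outliers)  # [60.]
```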

In [None]:
def plot_histogram(filter_type_data, filter_type, filter_name, cutoffs=None, line_styles=None, log_scale=False, **kwargs):
    """
    Plots a histogram for a given data array and saves the plot as an image.

    Args:
        filter_type_data (np.array): The data to plot the histogram for.
        filter_type (str): The QC metric column being plotted (e.g., "median_coverage", "wgd_score").
        filter_name (str): The name of the metric being plotted.
        cutoffs (list): List of cutoff values for which to plot vertical lines.
        line_styles (list): List of line styles corresponding to each cutoff.
        log_scale (bool): Defines whether to log-scale the plot.
        **kwargs: Additional plotting parameters for matplotlib.

    Returns:
        str: Path to saved file containing plot.
    """    
    # Validation
    if len(filter_type_data) == 0:
        print(f"The {filter_type} data is empty. No plot will be generated.")
        return None

    # Calculate the median and MAD
    median_filter_type = np.median(filter_type_data)
    mad_filter_type = np.median(np.absolute(filter_type_data - median_filter_type))

    # Print summary values
    print(f'Median: {median_filter_type:.5f}')
    print(f'MAD: {mad_filter_type:.5f}')
    print(f'Minimum: {min(filter_type_data):.5f}')
    print(f'Maximum: {max(filter_type_data):.5f}')
    print()

    # Create new figure
    fig = plt.figure(figsize=(8, 6))
    fig.patch.set_facecolor('white')

    # Plot a histogram
    bins = kwargs.pop('bins', 50)
    plt.hist(filter_type_data, edgecolor='black', bins=bins, **kwargs)
    if log_scale:
        plt.gca().set_yscale('log')
        plt.gca().yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())
        plt.gca().yaxis.get_major_formatter().set_scientific(False)
        plt.gca().yaxis.set_minor_formatter(matplotlib.ticker.NullFormatter())
    plt.gca().set_facecolor('white')
    plt.grid(False)
    plt.gca().yaxis.grid(True, zorder=0, color='lightgrey')

    # Set the title and axes
    plt.title(f"{filter_name} - {len(filter_type_data)} Samples (Filtered)", fontsize=16)
    plt.xlabel(f"{filter_name}", fontsize=12)
    plt.ylabel("Sample Count (Log Scale)" if log_scale else "Sample Count", fontsize=12)

    # Additional plots
    plt.axvline(x=np.median(filter_type_data), color='black', linestyle='--')
    legend_handles = [Line2D([0], [0], alpha=1, color='black', linestyle='--', label='Median')]
    if cutoffs is not None and line_styles is not None:
        for cutoff, line_style in zip(cutoffs, line_styles):
            plt.axvline(x=median_filter_type + cutoff * mad_filter_type, linestyle=line_style, color='grey', label=f'Median + {cutoff}*MAD')
            plt.axvline(x=median_filter_type - cutoff * mad_filter_type, linestyle=line_style, color='grey', label=f'Median - {cutoff}*MAD')
            legend_handles.append(Line2D([0], [0], alpha=1, linestyle=line_style, color='grey', label=f'Median +- {cutoff}*MAD'))
    plt.legend(handles=legend_handles, framealpha=1, labelspacing=0.5)

    # Save and close figure
    file_name = f"filtered_histogram_{len(filter_type_data)}_samples.png"
    if cutoffs is not None:
        file_name = f"analysis_histogram_{len(filter_type_data)}_samples.png"
    file_path = generate_file_path(TLD_PATH, filter_type, file_name)
    save_figure(WS_BUCKET, file_path)
    plt.close()

    return file_path

In [None]:
def run_analysis(table, filter_type, cutoffs=None, line_styles=None, read_length=None, log_scale=False, display=True, **kwargs):
    """
    Analyze a QC metric by plotting its distribution, optionally with median +/- cutoff*MAD reference lines.

    Args:
        table (pandas.DataFrame): Pandas dataframe from which data is extracted.
        filter_type (str): Type of the filter being plotted.
        cutoffs (list): List of MAD multiples at which to plot vertical lines around the median.
        line_styles (list): List of line styles corresponding to each cutoff.
        read_length (numeric): Read length, from which coverage coefficient can be calculated.
        log_scale (bool): Defines whether to log-scale the histogram produced.
        display (bool): Determines whether to display the saved histogram inline.
        **kwargs: Additional plotting parameters to be passed to the plotting function.

    Returns:
        None.
    """
    # Get data for specific QC metric
    filter_data = table[filter_type].values

    # Multiply median_coverage by coefficient for median coverage metric
    if filter_type == 'median_coverage':
        if read_length is None or not isinstance(read_length, (int, float)):
            raise Exception("A numeric value must be provided for READ_LENGTH.")
        filter_data = filter_data * (read_length / 100)

    # Treat NaN values as zero outliers for outlier metrics
    if 'outlier' in filter_type:
        filter_data = np.nan_to_num(filter_data, nan=0.0)

    # Plot figure
    filter_name = ' '.join(word.capitalize() for word in filter_type.split('_'))
    filter_name = filter_name.replace('Wgd', 'WGD').replace('Nondiploid', 'Non-Diploid')
    file_png = plot_histogram(filter_data, filter_type, filter_name, log_scale=log_scale, cutoffs=cutoffs, line_styles=line_styles, **kwargs)

    # Display plot (plot_histogram returns None when there is no data to plot)
    if display and file_png:
        plt.figure(figsize=(8, 8))
        img = mpimg.imread(file_png)
        plt.imshow(img)
        plt.axis('off')
        plt.show()
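The preprocessing inside `run_analysis` can be isolated as below: median coverage is scaled by `read_length / 100` (1.5 for 150 bp reads), and NaN values in outlier metrics become 0. This is a sketch; `prepare_metric` is a hypothetical helper, and `manta_high_outlier` is a hypothetical column name used only to trigger the outlier branch.

```python
import numpy as np

def prepare_metric(values, filter_type, read_length=None):
    """Mirror the preprocessing in run_analysis() for a single metric (illustrative sketch)."""
    data = np.asarray(values, dtype=float)
    if filter_type == 'median_coverage':
        if read_length is None or not isinstance(read_length, (int, float)):
            raise ValueError("A numeric value must be provided for READ_LENGTH.")
        data = data * (read_length / 100)  # coverage coefficient
    if 'outlier' in filter_type:
        data = np.nan_to_num(data, nan=0.0)  # NaN treated as zero outlier calls
    return data

print(prepare_metric([20, 30, 40], 'median_coverage', read_length=150))  # [30. 45. 60.]
print(prepare_metric([np.nan, 3, 0], 'manta_high_outlier'))              # [0. 3. 0.]
```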


In [None]:
def run_filtering(table, filter_type, filter_method, lower_cutoff=float('-inf'), upper_cutoff=float('inf'), mad_cutoff=None, read_length=None, log_scale=False, display=True, **kwargs):
    """
    Filter samples on a QC metric, keeping those whose values fall within the lower and upper limits.

    Args:
        table (pandas.DataFrame): A DataFrame containing the data to be filtered.
        filter_type (str): The type of the filter or data column to be filtered.
        filter_method (str): The filter method to use, either "hard" or "MAD".
        lower_cutoff (numeric): The lower cutoff for the "hard" method; None applies no lower cutoff.
        upper_cutoff (numeric): The upper cutoff for the "hard" method; None applies no upper cutoff.
        mad_cutoff (numeric): The number of MADs from the median for the "MAD" method; None applies no cutoff.
        read_length (numeric): Read length, from which coverage coefficient can be calculated.
        log_scale (bool): Defines whether to log-scale the histogram produced.
        display (bool): Determines whether to display the saved histogram inline.
        **kwargs: Additional plotting parameters to be passed to the plotting function.

    Returns:
        None.
    """
    # Set defaults if None passed (a cutoff of 0 is kept as a valid bound)
    if lower_cutoff is None:
        lower_cutoff = float('-inf')
        
    if upper_cutoff is None:
        upper_cutoff = float('inf')
        
    if mad_cutoff is None:
        mad_cutoff = float('inf')
    
    # Validation
    if filter_method == 'hard' and lower_cutoff > upper_cutoff:
        raise Exception("Invalid cutoff - please ensure that the lower cutoff is less than or equal to the upper cutoff.")
    
    # Capture specified data being filtered on
    filter_data = table[filter_type].values
    
    # Multiply median_coverage by coefficient for median coverage metric
    if (filter_type == 'median_coverage'):
        if (read_length is None or not isinstance(read_length, (int, float))):
            raise Exception("A numeric value must be provided for READ_LENGTH.")
        filter_data = filter_data * (read_length / 100)
    
    # Calculate filter bounds if MAD selected as cutoff method
    if (filter_method == 'MAD'):
        filter_mad, filter_median, lower_cutoff, upper_cutoff = calculate_mad_filter_bounds(filter_data, mad_cutoff, filter_type)

    # Filter the data based on the lower and upper limits
    filter_pass = table[((filter_data >= lower_cutoff) & (filter_data <= upper_cutoff))]
    failed_sample_ids = table[((filter_data < lower_cutoff) | (filter_data > upper_cutoff))].sample_id.values
        
    # Log intermediate outputs
    print(f'Passing samples: {len(filter_pass)}')
    print(f'Failing samples: {len(failed_sample_ids)}')
    print()
    
    # Remove failed samples
    remove_samples(failed_sample_ids, filter_type)

    # Validation
    if (len(filter_pass) <= 0):
        print('No samples passed this QC step, so a figure cannot be displayed.')
        return
    
    # Create figure name
    filter_name = ' '.join(word.capitalize() for word in filter_type.split('_'))
    filter_name = filter_name.replace('Wgd', 'WGD')
    filter_name = filter_name.replace('Nondiploid', 'Non-Diploid')
    
    # Create plot
    filter_pass_data = filter_pass[filter_type]
    if (filter_type == 'median_coverage'):
        filter_pass_data = filter_pass_data * (read_length / 100)
    file_png = plot_histogram(filter_pass_data, filter_type, filter_name, log_scale=log_scale, **kwargs)
    
    # Display plot (plot_histogram returns None when there is no data to plot)
    if display and file_png:
        plt.figure(figsize=(8, 8))
        img = mpimg.imread(file_png)
        plt.imshow(img)
        plt.axis('off')
        plt.show()
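The pass/fail split in `run_filtering` reduces to two boolean masks. The sketch below uses a hypothetical helper, `split_pass_fail`, and toy `wgd_score` values; it treats `None` as an unbounded side and keeps 0 as a real bound. As in `run_filtering`, a NaN value lands in neither list because every comparison with NaN is false.

```python
import numpy as np
import pandas as pd

def split_pass_fail(table, column, lower=None, upper=None):
    """Return (passing IDs, failing IDs) for one QC column; None leaves that side unbounded."""
    lower = float('-inf') if lower is None else lower
    upper = float('inf') if upper is None else upper
    values = table[column].to_numpy(dtype=float)
    passed = (values >= lower) & (values <= upper)
    failed = (values < lower) | (values > upper)  # NaN is neither passed nor failed
    return list(table.loc[passed, 'sample_id']), list(table.loc[failed, 'sample_id'])

table = pd.DataFrame({
    'sample_id': ['S1', 'S2', 'S3', 'S4', 'S5', 'S6'],
    'wgd_score': [-0.25, -0.05, 0.0, 0.04, 0.30, np.nan],
})

# Hard cutoffs on both sides
print(split_pass_fail(table, 'wgd_score', lower=-0.2, upper=0.2))
# (['S2', 'S3', 'S4'], ['S1', 'S5'])

# An upper cutoff of 0 is honored rather than treated as "no cutoff"
print(split_pass_fail(table, 'wgd_score', upper=0.0))
# (['S1', 'S2', 'S3'], ['S4', 'S5'])
```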

In [None]:
def create_upset_plot(filter_df, num_samples_to_be_removed):
    """
    Create an upset plot based on the provided DataFrame and save it as a .png file.

    Args:
        filter_df (pandas.DataFrame): A DataFrame containing the filtered samples, with columns 
                                      'sample_id' and 'filters_applied'.
        num_samples_to_be_removed (int): The total number of samples to be removed.

    Returns:
        None.
    """
    # Catch failing case - no samples to drop
    if (len(filter_df) == 0):
        print("UpSet plot cannot be generated - no samples are dropped.")
        return
    
    # Group by 'filters_applied' and count the samples in each filter combination
    grouped_new = filter_df.groupby("filters_applied").count()
    
    # Plot an UpSet plot, or a bar plot if there is only one filter combination
    if len(grouped_new) == 1:
        print("UpSet plot cannot be generated with a single unique filter combination - generating bar plot instead.")
        
        # Count failing samples per individual filter (split comma-separated combinations first)
        filter_counts = filter_df['filters_applied'].str.split(',').explode().value_counts()
        filter_counts.plot(kind='bar')
        
        # Set the title and axes
        plt.title(f"Bar Plot - {num_samples_to_be_removed} Failing Samples")
        plt.ylabel("Count")
        plt.xticks(rotation=45)
        
        # Set file name
        file_name = f"single_filter_plot_{str(num_samples_to_be_removed)}_filtered_samples.png"
    else:
        # Create a new DataFrame with an additional 'count' column
        filter_df_new = filter_df.assign(count=filter_df["filters_applied"].map(grouped_new['sample_id']))

        # Create the upset plot
        removed_by_filter = from_memberships(filter_df_new.filters_applied.str.split(','), data=filter_df_new)
        upset = UpSet(removed_by_filter, subset_size='auto', show_percentages=True, show_counts='%d', element_size=60, orientation='horizontal')
        upset.plot()

        # Set the title and margins
        plt.suptitle(f"UpSet Plot - {num_samples_to_be_removed} Failing Samples")
        plt.margins(x=0.1, y=0.25)
        
        # Set file name
        file_name = f"upset_plot_{str(num_samples_to_be_removed)}_filtered_samples.png"
  
    # Save the output file
    file_path = generate_file_path(TLD_PATH, 'filtering', file_name)
    save_figure(WS_BUCKET, file_path)

### Autosomal Copy Number Functions

In [None]:
def process_ploidy_data(sample_set_tbl, column_name, dest_path):
    """
    Localize files, extract tarballs, and create directories for ploidy data.
    
    Args:
        sample_set_tbl (pandas.DataFrame): The sample set table.
        column_name (str): The name of the column containing file paths.
        dest_path (str): The destination path for localized files.
    
    Returns:
        list: Paths to the created ploidy directories.
    """
    os.makedirs(dest_path, exist_ok=True)
    ploidy_dirs = []
    
    for batch, file in zip(sample_set_tbl['entity:sample_set_id'], sample_set_tbl[column_name]):
        # Localize file
        local_file = os.path.join(dest_path, os.path.basename(file))
        subprocess.run(["gsutil", "cp", file, local_file], check=True)
        
        # Create a fresh subdirectory (relative to the working directory) for this batch
        subdir = f"{batch}_ploidy"
        if os.path.exists(subdir):
            shutil.rmtree(subdir)
        os.mkdir(subdir)
        
        # Extract contents of tarball into new directory; check=True raises an error if tar fails
        subprocess.run(["tar", "-xf", local_file, "-C", subdir], check=True)
        ploidy_dirs.append(subdir)
    
    return ploidy_dirs

In [None]:
def find_samples_outside_threshold(cn_data, upper_cutoff=None, lower_cutoff=None):
    """
    Function to find samples with copy numbers outside the specified thresholds for each chromosome.

    Args:
        cn_data (pandas.DataFrame): Contains the copy number data.
        upper_cutoff (float): The upper threshold value. 
        lower_cutoff (float): The lower threshold value.

    Returns:
        A two-element tuple containing:
            1. pandas.Series: Contains the count of samples outside the threshold(s) for each chromosome.
            2. list: Contains the sample IDs with at least one chromosome outside the threshold(s).
    """
    chr_columns = [col for col in cn_data.columns if 'chr' in col.lower()]
    
    # Validation
    if upper_cutoff is None and lower_cutoff is None:
        raise ValueError("At least one of 'upper_cutoff' or 'lower_cutoff' must be provided.")
    
    # Create boolean mask for values outside threshold(s)
    if upper_cutoff is not None and lower_cutoff is not None:
        mask = (cn_data[chr_columns] > upper_cutoff) | (cn_data[chr_columns] < lower_cutoff)
    elif upper_cutoff is not None:
        mask = cn_data[chr_columns] > upper_cutoff
    else:
        mask = cn_data[chr_columns] < lower_cutoff
    
    # Create data structures to return
    cn_filter_pairs = mask.sum()
    cn_filter_ids = cn_data.loc[mask.any(axis=1), cn_data.columns[0]].tolist()
    
    # Print summary statistics
    print(f"Sample-chromosome pairs outside threshold(s): {mask.sum().sum()}")
    print(f"Samples outside threshold(s): {len(cn_filter_ids)}")
    
    return cn_filter_pairs, cn_filter_ids

In [None]:
def display_aneuploidy_outside_threshold(cn_data, chr_number, upper_cutoff=None, lower_cutoff=None, display=False):
    """
    Display aneuploidy for batches with samples outside the estimated copy number threshold for a given chromosome.

    Args:
        cn_data (pandas.DataFrame): Contains the copy number data.
        chr_number (int): The chromosome number to plot for.
        upper_cutoff (float): The upper threshold value.
        lower_cutoff (float): The lower threshold value.
        display (bool): Defines whether to display the plot directly.

    Returns:
        dict: Maps batch ID to the image file path of its plot, or None if `display` is True or no samples fail.
    """
    # Validation
    if upper_cutoff is None and lower_cutoff is None:
        raise ValueError("At least one of 'upper_cutoff' or 'lower_cutoff' must be provided.")
    
    # Find sample IDs with copy number outside the threshold(s) for the given chromosome
    chromosome = f"chr{chr_number}_CopyNumber"
    values = cn_data[chromosome]
    if upper_cutoff is not None and lower_cutoff is not None:
        outside = (values > upper_cutoff) | (values < lower_cutoff)
    elif upper_cutoff is not None:
        outside = values > upper_cutoff
    else:
        outside = values < lower_cutoff
    failed_aneu = cn_data.loc[outside, 'sample_id'].values
    
    # Exit if no samples fall outside the threshold(s)
    if len(failed_aneu) == 0:
        print("No samples failed this threshold - please modify the cutoff accordingly.")
        return
    
    # Group the failing samples by batch
    failed_batches = defaultdict(list)
    for sample_id in failed_aneu:
        batch_id = get_batch_for_sample_id(sample_id)
        failed_batches[batch_id].append(sample_id)
        
    # Build the path to each failing batch's per-bin copy number plot for this chromosome
    image_paths = {}
    for batch_id in failed_batches:
        image_file_path = f"{batch_id}_ploidy/ploidy_est/estimated_CN_per_bin.all_samples.chr{chr_number}.png"
        image_paths[batch_id] = image_file_path
    
    if (display):
        for batch_id, image_path in image_paths.items():
            img = Image.open(image_path)
            plt.figure(figsize=(12, 5))
            plt.imshow(img)
            plt.title(f'Batch: {batch_id} - {len(failed_batches[batch_id])} Failing Sample(s)')
            plt.axis('off')
            plt.show()
        return
    return image_paths

In [None]:
def plot_copy_number_per_autosome(cn_data, display=False):
    """
    Plot the copy number per autosome for samples.

    Args:
        cn_data (pandas.DataFrame): Contains the copy number data.
        display (bool): Defines whether to display the plot directly.

    Returns:
        str: Path to saved file containing plot.
    """
    # Extract contig names by removing "_CopyNumber" from the autosome column names (columns 1 to 22)
    contigs = [x.replace("_CopyNumber", "") for x in cn_data.columns[1:23]]
    num_samples = len(cn_data['sample_id'])

    # Create a box plot of the copy number data for the autosome columns (1 to 22)
    plt.figure(figsize=(10, 3.5))
    plt.boxplot(
        cn_data.iloc[:, 1:23],
        sym='.',
        labels=contigs,
        whis=6,
        patch_artist=True,
        showfliers=True,
        medianprops=dict(color="xkcd:steel blue")
    )
    plt.tick_params(axis='both', which='major', labelsize=12)
    plt.tick_params(axis='x', rotation=45)
    plt.title(f"Copy Number Per Autosome - {num_samples} Samples", fontsize=15)
    plt.grid(True, zorder=0)
    
    # Save the plot as an image
    file_name = f"cn_per_autosome_{num_samples}_samples.png"
    file_path = generate_file_path(TLD_PATH, 'autosomal_copy_number', file_name)
    save_figure(WS_BUCKET, file_path)
    
    # Display the plot if requested; otherwise close it to free memory and return the file path
    if display:
        plt.show()
        return
    plt.close()
    return file_path

### Sex Analysis Functions

In [None]:
def compute_sex_assignments(samples_qc_table, lower_cutoff_chrX=1.2, upper_cutoff_chrX=1.7, 
                            lower_cutoff_chrY=0.1, upper_cutoff_chrY=0.8):
    """
    Update sex assignments based on copy number data and sex assignment information from EvidenceQC.

    Args:
        samples_qc_table (pandas.DataFrame): Contains the columns 'sample_id', 'sex_assignment' (the EvidenceQC
                                             assignment), 'chrX_CopyNumber', and 'chrY_CopyNumber'.
        lower_cutoff_chrX (numeric): Lower chrX copy number cutoff for mosaic assignments.
        upper_cutoff_chrX (numeric): Upper chrX copy number cutoff for mosaic assignments.
        lower_cutoff_chrY (numeric): Lower chrY copy number cutoff for mosaic assignments.
        upper_cutoff_chrY (numeric): Upper chrY copy number cutoff for mosaic assignments.

    Returns:
        dict: Maps sample IDs to updated sex assignments.
    """
    # Validate
    for cutoff in [lower_cutoff_chrX, upper_cutoff_chrX, lower_cutoff_chrY, upper_cutoff_chrY]:
        validate_numeric_input(cutoff, log=False)
    
    # Start from the EvidenceQC sex assignments, keyed by sample ID
    updated_sex = dict(zip(samples_qc_table['sample_id'], samples_qc_table['sex_assignment']))

    # Use copy number metrics to determine sex assignments
    for sample, cnY, cnX in zip(samples_qc_table['sample_id'], samples_qc_table.chrY_CopyNumber, 
                                samples_qc_table.chrX_CopyNumber):
        if updated_sex[sample] in ("MALE", "TURNER") and (cnY > lower_cutoff_chrY and cnY < upper_cutoff_chrY):
            updated_sex[sample] = 'MALE'
        elif cnY <= lower_cutoff_chrY and cnX > lower_cutoff_chrX and cnX < upper_cutoff_chrX:
            updated_sex[sample] = "MOSAIC"
    
    # Return the updated sex assignments
    return updated_sex

In [None]:
def process_reference_ped(file_path, table):
    """
    Loads and formats a reference sex assignment dataframe from a specified PED file.

    Args:
        file_path (str): The path to the file in Google Cloud Storage containing the reference PED file.
        table (pandas.DataFrame): Contains sample information for comparison.

    Returns:
        pandas.DataFrame: Reference PED entries for the samples in `table`.
    """
    # Load the PED file; dtype=str keeps IDs and codes as strings, and lines starting with '#' are skipped
    reference_ped = pd.read_table(file_path, dtype=str, header=None, comment="#", names=[
        "Family_ID", "Sample_ID", "Paternal_ID","Maternal_ID", "Sex","Phenotype"
    ])
    
    # Keep only samples present in `table`, since the PED file may cover more samples than the selected batches
    reference_ped = reference_ped[reference_ped['Sample_ID'].isin(table['sample_id'])]
    
    return reference_ped

In [None]:
def create_ped(sex_assignments, reference_ped, file_path):
    """
    Create a PED file based on the sex assignments provided.

    Args:
        sex_assignments (dict): Maps sample IDs to computed sex assignments.
        reference_ped (pandas.DataFrame): Represents the reference PED file.
        file_path (str): File path to the output PED file.

    Returns:
        None.
    """
    # Ensure valid sex assignments
    valid_sex_assignments = set(SEX_CODES.keys())
    invalid_assignments = set(sex_assignments.values()) - valid_sex_assignments
    if invalid_assignments:
        raise ValueError(f"Invalid sex assignments: {', '.join(invalid_assignments)}")
    
    # Create directories in file path if necessary
    dir_path = os.path.dirname(file_path)
    os.makedirs(dir_path, exist_ok=True)
        
    # Account for case where reference PED file does not exist
    if reference_ped.empty:
        reference_ped = pd.DataFrame({'Sample_ID': list(sex_assignments.keys())})
        reference_ped["Family_ID"] = reference_ped["Sample_ID"]
        reference_ped["Paternal_ID"] = "0"
        reference_ped["Maternal_ID"] = "0"
        reference_ped["Sex"] = "0"
        reference_ped["Phenotype"] = "0"
    
    # Assign sexes based on sex assignment dictionary (copy to avoid modifying a slice of reference_ped)
    computed_ped = reference_ped[['Family_ID', 'Sample_ID', 'Paternal_ID', 'Maternal_ID', 'Phenotype']].copy()
    computed_ped['Sex'] = computed_ped['Sample_ID'].map(sex_assignments).map(SEX_CODES)
    
    # Order columns per the PED convention (family ID first) and keep only samples with an assignment
    computed_ped = computed_ped[['Family_ID', 'Sample_ID', 'Paternal_ID', 'Maternal_ID', 'Sex', 'Phenotype']]
    computed_ped = computed_ped[computed_ped['Sample_ID'].isin(list(sex_assignments.keys()))]
    
    # Save dataframe
    save_df(WS_BUCKET, file_path, computed_ped, print_path=True, header=False)
    
    # Update cohort_ped_file workspace attribute
    workspace_ped_path = os.path.join(WS_BUCKET, file_path)
    attrs = [fapi._attr_set("cohort_ped_file", workspace_ped_path)]
    r = fapi.update_workspace_attributes(NAMESPACE, WORKSPACE, attrs)
    if r.status_code != 200:
        raise Exception(f"Unable to update workspace attributes: {r}.")

In [None]:
def create_ped_differences(reference_ped, computed_sex, file_path):
    """
    Create a PED file containing the reference PED entries for samples whose computed sex
    assignment differs from the sex recorded in the reference PED file. Note that this adds a
    'Computed_Sex' column to `reference_ped` in place.
    
    Args:
        reference_ped (pandas.DataFrame): Contains the information from the reference PED file.
        computed_sex (dict): Maps sample IDs to their computed sexes.
        file_path (str): The file path to the output PED file. 
    
    Returns:
        pandas.DataFrame: Contains the samples whose sex assignment was changed by the QC process.
    """
    # Create column for computed sex values
    computed_sex = {k: SEX_CODES.get(v, 0) for k, v in computed_sex.items()}
    reference_ped['Sex'] = reference_ped['Sex'].astype(int)
    reference_ped['Computed_Sex'] = reference_ped['Sample_ID'].map(computed_sex)
    
    # Keep only samples with differing sex assignments
    differences = reference_ped[reference_ped['Sex'] != reference_ped['Computed_Sex']]
    differences = differences[['Family_ID', 'Sample_ID', 'Paternal_ID', 'Maternal_ID', 'Sex', 'Phenotype']]
    
    # Save differences file
    differences = differences.reset_index(drop=True)
    save_df(WS_BUCKET, file_path, differences)
    
    return differences

In [None]:
def plot_sex_chromosome_ploidy(all_cn, sex_for_fig, display=False):
    """
    Plot chrX versus chrY copy number for each sample, colored by sex assignment.

    Args:
        all_cn (pandas.DataFrame): Data to be plotted. It should contain columns 'sample_id', 'chrX_CopyNumber' and 'chrY_CopyNumber'.
        sex_for_fig (dict): Maps sample ID to its sex category. Every sample in `all_cn['sample_id']` must be a key.
        display (bool): Defines whether to display the plot directly.

    Returns:
        str: Path to saved file containing plot, or None if `display` is True.
    """
    # Dictionary for colors to be used for different sex categories
    color_dict = {"MALE": 'deepskyblue', "FEMALE": 'tab:pink', "TURNER": 'tab:red', 
                  "TRIPLE X": 'darkviolet', "KLINEFELTER": 'darkorange', "JACOBS": 'g', 
                  "MOSAIC": 'aquamarine', "OTHER": 'maroon'}

    # Dictionary for chromosomal representation for different sex categories
    xy_rep = {"MALE": " (XY)", "FEMALE": " (XX)", "TURNER": " (X)", "TRIPLE X": " (XXX)", 
              "KLINEFELTER": " (XXY)", "JACOBS": " (XYY)", "MOSAIC": "", "OTHER": ""}

    # Creating a list of colors corresponding to sex categories for each sample in the data
    color_sex_list = [color_dict[sex_for_fig[sample]] for sample in all_cn['sample_id']]

    # List for the x and y axis ticks
    cn_list = [0,1,2,3]
    
    # Creating the figure and setting its size
    fig = plt.figure(figsize=(7,7))
    
    # Adding a grid to the figure
    plt.grid(True, zorder=0)
    
    # Adding a scatter plot to the figure with data points colored according to sex category
    plt.scatter(all_cn.chrX_CopyNumber, all_cn.chrY_CopyNumber, alpha=1, s=10,
               c=color_sex_list, zorder=2)
    
    # Setting the x and y axis limits
    plt.xlim([-0.1,3.15])
    plt.ylim([-0.1,3.1])
    
    # Setting the x and y axis ticks
    plt.xticks(cn_list)
    plt.yticks(cn_list)

    # Adding text to the plot representing different combinations of X and Y chromosomes
    for nx in cn_list:
        for ny in cn_list:
            if nx + ny >=6:
                continue
            plt.text(nx, ny, nx*"X" + ny*"Y", color='lightgray', ha='center', va='center', zorder=1, weight="bold")
    
    # Adding a legend to the plot
    plt.legend(handles=[Line2D(
        [0],[0], alpha=1, marker='o', color='w', 
        markerfacecolor=color_dict[label], 
        label=label + xy_rep[label]
    ) for label in color_dict.keys()], framealpha=1, labelspacing=0.5)
    
    # Adding a title and labels to the plot
    plt.title(f"Sex Chromosome Ploidy - {len(sex_for_fig)} Samples", fontsize=15)
    plt.xlabel("chrX Copy Number", fontsize=13)
    plt.ylabel("chrY Copy Number", fontsize=13)
        
    # Save the plot as an image    
    file_name = f"sex_chromosome_ploidy_{len(sex_for_fig)}_samples.png"
    file_path = generate_file_path(TLD_PATH, 'sex_analysis', file_name)
    save_figure(WS_BUCKET, file_path)
    
    # Show figure if display is True; otherwise close it to free memory and return the file path
    if display:
        plt.show()
        return
    plt.close()
    return file_path

# Data Ingestion
This section fetches the sample QC data.

## Step 1. Sample Sets
This step loads the sample_set data table to find the QC data file paths. You can exclude unnecessary sample_sets at this stage if needed.

<div class="alert alert-block alert-info">Uncomment and run once. Once this step has run, if you need to load the table again, use the following cell.</div>

In [None]:
# sample_set_response = fapi.get_entities_tsv(
#     NAMESPACE, WORKSPACE, "sample_set", 
#     attrs=["ploidy_plots", "qc_table"], model="flexible"
# )

# with open('sample_set.zip', 'wb') as f:
#     f.write(sample_set_response.content)
    
# with zipfile.ZipFile('sample_set.zip', 'r') as zip_ref:
#     # Extract sample set data
#     with zip_ref.open('sample_set_entity.tsv') as file:
#         tsv_file = io.StringIO(file.read().decode('utf-8'))
#         sample_set_tbl = pd.read_csv(tsv_file, sep='\t')
#         sample_set_tbl = sample_set_tbl[sample_set_tbl['ploidy_plots'].notnull() & sample_set_tbl['qc_table'].notnull()]
#         sample_set_tbl = sample_set_tbl.reset_index(drop=True)
    
#     # Extract sample membership data
#     with zip_ref.open('sample_set_membership.tsv') as membership_file:
#         membership_tsv = io.StringIO(membership_file.read().decode('utf-8'))
#         membership_df = pd.read_csv(membership_tsv, sep='\t')
    
#     # Add list of samples to corresponding sample set
#     sample_groups = membership_df.groupby('membership:sample_set_id')['sample'].unique().apply(list)
#     sample_set_tbl['samples'] = sample_set_tbl['entity:sample_set_id'].map(sample_groups)
#     sample_set_tbl['samples'] = sample_set_tbl['samples'].apply(lambda x: x if isinstance(x, list) else [])

# file_path = generate_file_path(TLD_PATH, 'artifacts', 'sample_sets_qc.tsv')
# save_df(WS_BUCKET, file_path, sample_set_tbl)

In [None]:
# Use this block to reload the data saved in the previous cell
file_path = generate_file_path(TLD_PATH, 'artifacts', 'sample_sets_qc.tsv')

sample_set_tbl = pd.read_table(file_path, sep='\t', dtype={'sample_id': str})

sample_set_tbl

<div class="alert alert-block alert-success"><b>Optional</b>: If you wish to only include a subset of batches out of the ones listed above, you can filter <tt>sample_set_tbl</tt> to only include them using the cell below. Ensure that all batches you expect are included.</div>

In [None]:
# Optionally filter sample_set_tbl to a subset of batches (the next cell saves the result).
# For example, uncomment and edit the line below to keep only batches whose IDs contain a given substring:
# sample_set_tbl = sample_set_tbl[sample_set_tbl['entity:sample_set_id'].str.contains('BATCH_SUBSTRING')]

In [None]:
file_path = generate_file_path(TLD_PATH, 'artifacts', 'sample_sets_qc.tsv')
save_df(WS_BUCKET, file_path, sample_set_tbl)

# Output batch information
print(f"Sample Set DataFrame Dimensions: {sample_set_tbl.shape}")
print(f"Batch Count: {len(sample_set_tbl)}\n")
print("Batches:")
print(*list(sample_set_tbl['entity:sample_set_id'].values), sep='\n')

sample_set_tbl

## Step 2. Samples
This step aggregates sample QC data across *sample_sets*.
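
Conceptually, this aggregation stacks each batch's QC table into one table and then confirms that no sample ID appears more than once across batches. The notebook's `validate_unique_samples` helper is defined elsewhere, so its exact behavior is not shown here. The sketch below uses only the standard library and assumes each table has a `sample_id` header column. Its hypothetical `aggregate_qc_tables` function reads in-memory TSV strings instead of the GCS paths in `sample_set_tbl['qc_table']`:

```python
import csv
import io
from collections import Counter


def aggregate_qc_tables(tsv_texts):
    """Stack per-batch QC tables (TSV text with a header row) into one list of rows.

    Returns the combined rows and a sorted list of sample IDs that occur more than once.
    """
    rows = []
    for text in tsv_texts:
        # DictReader keys each row by the table's header line
        rows.extend(csv.DictReader(io.StringIO(text), delimiter="\t"))
    counts = Counter(row["sample_id"] for row in rows)
    duplicates = sorted(sample_id for sample_id, n in counts.items() if n > 1)
    return rows, duplicates


# Toy stand-ins for two batches' qc_table files
batch_a = "sample_id\tmedian_coverage\nS1\t30.5\nS2\t28.1\n"
batch_b = "sample_id\tmedian_coverage\nS3\t33.0\nS1\t30.5\n"

rows, duplicates = aggregate_qc_tables([batch_a, batch_b])
print(len(rows), duplicates)  # 4 ['S1']
```

A sample that shows up in two batches, like `S1` here, would be counted twice in every downstream distribution. Resolve any duplicates before filtering.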

<div class="alert alert-block alert-info">Uncomment and run once. Once this step has run, if you need to load the table again, use the next cell.</div>

In [None]:
# samples_qc_table = pd.concat([pd.read_csv(f, sep='\t') for f in sample_set_tbl['qc_table']], ignore_index = True)

# print(f"Sample DataFrame Dimensions: {samples_qc_table.shape}")
# print(f"Sample Count: {len(samples_qc_table)}\n")

# validate_unique_samples(samples_qc_table)

# file_path = generate_file_path(TLD_PATH, 'artifacts', 'samples_qc.tsv')
# save_df(WS_BUCKET, file_path, samples_qc_table)

In [None]:
# Use this block to reload the data saved in the previous cell
file_path = generate_file_path(TLD_PATH, 'artifacts', 'samples_qc.tsv')

samples_qc_table = pd.read_table(file_path, sep='\t', dtype={'sample_id': str})

samples_qc_table

## Step 3. Ploidy Data
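This step localizes each batch's ploidy tarball (from the `ploidy_plots` column of `sample_set_tbl`) and extracts it into a `<batch>_ploidy/` directory. These files contain binwise copy number estimates and per-chromosome copy number plots, which are used later for copy number analysis.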
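
`process_ploidy_data` calls the external `tar` command. The same extraction can also be done in pure Python with the standard-library `tarfile` module. The sketch below assumes each tarball unpacks to a top-level `ploidy_est/` folder, as the paths used elsewhere in this notebook imply. `extract_ploidy_tarball` is a hypothetical helper, and the toy tarball stands in for a real `ploidy_plots` file:

```python
import os
import shutil
import tarfile
import tempfile


def extract_ploidy_tarball(tar_path, batch, dest_root):
    """Extract a batch's ploidy tarball into <dest_root>/<batch>_ploidy, replacing any old copy."""
    subdir = os.path.join(dest_root, f"{batch}_ploidy")
    if os.path.exists(subdir):
        shutil.rmtree(subdir)
    os.makedirs(subdir)
    with tarfile.open(tar_path) as tar:  # compression (e.g. gzip) is detected automatically
        if hasattr(tarfile, "data_filter"):
            # Newer Python versions: refuse absolute paths and links that escape subdir
            tar.extractall(subdir, filter="data")
        else:
            tar.extractall(subdir)
    return subdir


with tempfile.TemporaryDirectory() as tmp:
    # Build a toy tarball with a top-level ploidy_est/ folder
    src = os.path.join(tmp, "src", "ploidy_est")
    os.makedirs(src)
    with open(os.path.join(src, "placeholder.txt"), "w") as fh:
        fh.write("toy content\n")
    tar_path = os.path.join(tmp, "batch1_ploidy.tar.gz")
    with tarfile.open(tar_path, "w:gz") as tar:
        tar.add(src, arcname="ploidy_est")

    out_dir = extract_ploidy_tarball(tar_path, "batch1", tmp)
    contents = sorted(os.listdir(out_dir))
    has_file = os.path.isfile(os.path.join(out_dir, "ploidy_est", "placeholder.txt"))

print(contents, has_file)  # ['ploidy_est'] True
```

On Python versions that support extraction filters, the `filter="data"` branch blocks archive members that would write outside the target directory. This matters when a tarball's origin is not fully trusted.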

<div class="alert alert-block alert-info">Uncomment and run once. Once this step has run, if you need to reload the list of ploidy directories, use the next cell.</div>

In [None]:
# dir_path = os.path.join(TLD_PATH, "ploidy")
# ploidy_dirs = process_ploidy_data(sample_set_tbl, 'ploidy_plots', dir_path)

# # Write the directory names to a file
# file_path = os.path.join(TLD_PATH, "ploidy", "ploidy_dirs.list")
# with open(file_path, 'w') as dirs_file:
#     for ploidy_dir in ploidy_dirs:
#         dirs_file.write(ploidy_dir + '\n')

# # Get binwise copy number files
# binwise_cn_files = [os.path.join(ploidy_dir, "ploidy_est", "binwise_estimated_copy_numbers.bed.gz") for ploidy_dir in ploidy_dirs]

In [None]:
# Use this block to reload the data saved in the previous cell
file_path = generate_file_path(TLD_PATH, "ploidy", "ploidy_dirs.list")

with open(file_path) as dirs_file:
    ploidy_dirs = [line.strip() for line in dirs_file]

binwise_cn_files = [ploidy_dir + "/ploidy_est/binwise_estimated_copy_numbers.bed.gz" for ploidy_dir in ploidy_dirs]

# QC
This section involves analyzing and filtering samples based on a series of sample quality metrics. 

**Usage**:

1. Each step should be executed by first using the **Analysis** section, which plots figures to assist with determining appropriate filtering cutoff values. This will involve modifying the following parameters:
    - `LOG_SCALE`: Determines whether to log-scale the resulting plot.
    - `LINE_DEVIATIONS`: A list of integers that defines the cutoff lines to draw on each histogram plot. This is a list of multipliers *x* for the median absolute deviation (MAD) such that lines are drawn at Median &pm; x * MAD. It has been initialized with default values, but set this to `[]` or `None` if you don't wish to include any cutoffs.
    - `LINE_STYLES`: A list of strings that defines the line styles that each cutoff line should use in each histogram plot displayed. It has been initialized with default values, but set this to `[]` or `None` if you don't wish to include any cutoffs.
    - `**kwargs`: Additional arguments passed in to the plotting function (e.g. `bins = 50`).

2. After this is done and suitable cutoffs have been established, the **Filtering** section enables dropping samples using these cutoffs. This will involve modifying the following parameters:
    - `LOG_SCALE`: Determines whether to log-scale the resulting plot.
    - `METHOD`: Must be set to one of `'hard'` or `'MAD'`. It is initialized to the method used for the [All of Us (AoU) CDRv7 off-cycle SV dataset](https://support.researchallofus.org/hc/en-us/articles/27496716922900-All-of-Us-Short-Read-Structural-Variant-Quality-Report).
        - If set to `'hard'`, the filtering will be based on strict thresholds set by `LOWER_CUTOFF` and `UPPER_CUTOFF`. 
            - `LOWER_CUTOFF`: The lower threshold when `METHOD = 'hard'`; any sample whose value for the metric is less than this will be dropped. It is initialized to `None`, which corresponds to a lower cutoff of negative infinity. Set it to a numeric value when using `METHOD = 'hard'`, unless a lower cutoff is not relevant for the metric, in which case leave it as `None`.
            - `UPPER_CUTOFF`: The upper threshold when `METHOD = 'hard'`; any sample whose value for the metric is greater than this will be dropped. It is initialized to `None`, which corresponds to an upper cutoff of infinity. Set it to a numeric value when using `METHOD = 'hard'`, unless an upper cutoff is not relevant for the metric, in which case leave it as `None`.
        - If set to `'MAD'`, the filtering will be based on MAD-based thresholds set by `MAD_CUTOFF`.
            - `MAD_CUTOFF`: The MAD threshold when `METHOD = 'MAD'`; any sample whose value for the metric is more than this many MADs from the median will be dropped. It is initialized to `None`, which corresponds to an infinite cutoff, but should be set to a numeric value if `METHOD = 'MAD'`.
    - `**kwargs`: Additional arguments passed in to the plotting function (e.g. `bins = 50`).
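
The `'hard'` and `'MAD'` rules above, along with the Median ± x * MAD lines drawn by `LINE_DEVIATIONS`, can be sketched as follows. This is not the notebook's `run_filtering` implementation. `failing_samples`, `mad_bounds`, and `mad` are hypothetical helpers, and the sketch makes three assumptions:

* MAD is unscaled (no 1.4826 consistency factor).
* Cutoffs are strict inequalities.
* A `None` cutoff behaves like ±infinity.

```python
from statistics import median


def mad(values):
    """Unscaled median absolute deviation: median(|x - median(x)|)."""
    center = median(values)
    return median(abs(v - center) for v in values)


def mad_bounds(values, k):
    """Return (median - k * MAD, median + k * MAD)."""
    center, spread = median(values), mad(values)
    return center - k * spread, center + k * spread


def failing_samples(metric, method, lower_cutoff=None, upper_cutoff=None, mad_cutoff=None):
    """Return sorted IDs of samples outside the cutoffs; `metric` maps sample ID -> value."""
    if method == "hard":
        # An unset cutoff behaves like -inf / +inf, so it never drops anything
        low = float("-inf") if lower_cutoff is None else lower_cutoff
        high = float("inf") if upper_cutoff is None else upper_cutoff
    elif method == "MAD":
        if mad_cutoff is None:
            return []  # an unset MAD cutoff is treated as infinitely wide
        low, high = mad_bounds(list(metric.values()), mad_cutoff)
    else:
        raise ValueError("method must be 'hard' or 'MAD'")
    return sorted(s for s, v in metric.items() if v < low or v > high)


coverage = {"S1": 30, "S2": 32, "S3": 31, "S4": 29, "S5": 12}
print(failing_samples(coverage, "hard", lower_cutoff=30))  # ['S4', 'S5']
print(failing_samples(coverage, "MAD", mad_cutoff=6))      # ['S5']
# Cutoff lines for LINE_DEVIATIONS = [2, 4, 6]: median 30, MAD 1
print([mad_bounds(list(coverage.values()), k) for k in (2, 4, 6)])  # [(28, 32), (26, 34), (24, 36)]
```

The toy data shows the trade-off between the two methods. A hard cutoff of 30 also drops the borderline sample `S4`, while a 6-MAD cutoff adapts to the spread of the cohort and drops only the clear outlier `S5`.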
    
    
**Outputs**:

Outputs generated in the sections below are saved to the virtual machine's file system (within the analyses edit directory) and to Google Cloud Storage (in the bucket attached to your workspace) under the top-level directory `evidence_qc/`. The directory structure within `evidence_qc/` is the same in both locations:
- `median_coverage/`: Contains outputs generated within the **Median Coverage** section.
- `mean_insert_size/`: Contains outputs generated within the **Mean Insert Size** section.
- `wgd_score/`: Contains outputs generated within the **WGD Score** section.
- `nondiploid_bins/`: Contains outputs generated within the **Non-Diploid Bins** section.
- `contamination/`: Contains outputs generated within the **Contamination** section.
- `raw_caller_outliers/`: Contains outputs generated within the **Raw Caller Outliers** section.
- `autosomal_copy_number/`: Contains outputs generated within the **Autosomal Copy Number** section.
- `sex_analysis/`: Contains outputs generated within the **Sex Analysis** section.
- `filtering/`: Contains outputs generated within the **Sample Filtering** section.
- `artifacts/`: Contains temporary artifacts generated throughout the notebook.

    
**Additional Notes**:

- For guidance, cutoffs from the [AoU CDRv7 off-cycle SV dataset](https://support.researchallofus.org/hc/en-us/articles/27496716922900-All-of-Us-Short-Read-Structural-Variant-Quality-Report) will be noted in each section, but users should exercise their judgment to select cutoffs appropriate for their data.
- Some steps dynamically set the value of `filter_type`, which defines the column name within the metadata table that corresponds to the metric that a given step is based on. The code for this should not have to be changed in any cell.
- If you wish to see the samples that have been filtered out at any particular point, run the **Sample Filtering** section - you can always come back and apply more filters at any point.

## Median Coverage
The median coverage measures the typical sequencing depth in a sample. We recommend a minimum coverage of about 30x. Coverage is also important for batching in order to set appropriate genotyping cutoffs among homogeneous samples in a batch.

<div class="alert alert-block alert-success">The read length is used to convert the median binned coverage to median base coverage for easier interpretation. The parameter <tt>READ_LENGTH</tt> is set to 151 by default, which is typical for Illumina data; please edit this input if your data has a different read length.</div>
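
The plotting code in this notebook converts binned coverage by multiplying it by `read_length / 100`. One reading of that factor is that each binned value is a read count per 100-bp bin and each read contributes `read_length` bases. Under that reading, base coverage = reads × read length / bin width. This interpretation is an assumption, not something the notebook states. The hypothetical `binned_to_base_coverage` function below sketches the arithmetic:

```python
def binned_to_base_coverage(median_binned_coverage, read_length=151, bin_size=100):
    """Convert a median per-bin read count into approximate per-base coverage.

    Assumes each read contributes `read_length` bases to a `bin_size`-bp bin, which
    reproduces the `read_length / 100` factor used in this notebook when bin_size=100.
    """
    return median_binned_coverage * read_length / bin_size


print(binned_to_base_coverage(20))  # 30.2
# Binned value needed to reach ~30x with 151-bp reads
print(round(30 * 100 / 151, 2))    # 19.87
```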

In [None]:
READ_LENGTH = 151

validate_numeric_input(READ_LENGTH)

### Analysis

<div class="alert alert-block alert-success">Input the analysis parameters.</div>

In [None]:
LOG_SCALE = False # Boolean value that defines whether to log-scale the plot
LINE_DEVIATIONS = None  # List of integers that defines the MAD cutoff lines to draw on each histogram plot
LINE_STYLES = None  # List of strings that defines the line styles of each MAD cutoff line passed above

validate_qc_inputs(samples_qc_table, 'median_coverage', line_deviations=LINE_DEVIATIONS, 
                   line_styles=LINE_STYLES, log_scale=LOG_SCALE)

In [None]:
# Example use of **kwargs to pass in parameter `bins` to the histogram plotting function
run_analysis(samples_qc_table, 'median_coverage', LINE_DEVIATIONS, LINE_STYLES, READ_LENGTH, 
             log_scale=LOG_SCALE, bins=50)

### Filtering

<div class="alert alert-block alert-success">Input the filtering parameters. For the AoU SV data, the lower cutoff was 30.</div>

In [None]:
LOG_SCALE = False # Boolean value that defines whether to log-scale the plot
METHOD = 'hard' # String value that defines the cutoff method to use - either 'MAD' or 'hard'

LOWER_CUTOFF = None # Numeric value that defines the lower threshold if METHOD = 'hard'
UPPER_CUTOFF = None # Numeric value that defines the upper threshold if METHOD = 'hard'
MAD_CUTOFF = None # Numeric value that defines the MAD deviation threshold if METHOD = 'MAD'

validate_qc_inputs(samples_qc_table, 'median_coverage', method=METHOD, lower_cutoff=LOWER_CUTOFF, 
                   upper_cutoff=UPPER_CUTOFF, mad_cutoff=MAD_CUTOFF, log_scale=LOG_SCALE)

In [None]:
# Example use of **kwargs to pass in parameter `bins` to the histogram plotting function
run_filtering(samples_qc_table, 'median_coverage', METHOD, lower_cutoff=LOWER_CUTOFF, upper_cutoff=UPPER_CUTOFF, 
              mad_cutoff=MAD_CUTOFF, read_length=READ_LENGTH, log_scale=LOG_SCALE, bins=50)

## Mean Insert Size
The insert size is the length of the DNA fragment being sequenced, excluding adapters. We estimate insert size for each read pair as the distance between the mapped positions of the mates. The mean insert size is the average insert size for the reads within a sample; refer to the [Picard documentation](https://broadinstitute.github.io/picard/picard-metric-definitions.html#InsertSizeMetrics) for details of how outliers are excluded. Insert size is important for SV calling because discordant read pairs are one type of SV evidence in short-read data. This metric is particularly useful for batching, to ensure that paired-end evidence cutoffs for genotyping are well calibrated within each batch.
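
The per-pair estimate described above can be illustrated with toy coordinates. The sketch makes three assumptions:

* Both mates map to the same contig.
* Positions are 0-based and half-open.
* The insert size is the span from the leftmost mapped base of either mate to the rightmost one.

It deliberately omits the outlier exclusion that Picard applies before averaging. `insert_size` is a hypothetical helper, not part of the pipeline:

```python
from statistics import mean


def insert_size(start1, end1, start2, end2):
    """Span of a read pair on one contig: leftmost mapped base of either mate to the
    rightmost one, using 0-based, half-open coordinates."""
    return max(end1, end2) - min(start1, start2)


# Toy pairs of 151-bp mates: (mate1 start, mate1 end, mate2 start, mate2 end)
pairs = [
    (1000, 1151, 1300, 1451),
    (5000, 5151, 5250, 5401),
    (9450, 9601, 9000, 9151),  # mate order does not matter
]
sizes = [insert_size(*p) for p in pairs]
print(sizes)        # [451, 401, 601]
print(mean(sizes))  # mean of the toy sizes, with no outlier trimming
```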

Samples processed with Scramble instead of MELT will have missing values for mean insert size. If mean insert size is not available for your data, you can skip this step.

### Analysis

<div class="alert alert-block alert-success">Input the analysis parameters.</div>

In [None]:
LOG_SCALE = False # Boolean value that defines whether to log-scale the plot
LINE_DEVIATIONS = None  # List of integers that defines the MAD cutoff lines to draw on each histogram plot
LINE_STYLES = None  # List of strings that defines the line styles of each MAD cutoff line passed above

validate_qc_inputs(samples_qc_table, 'mean_insert_size', line_deviations=LINE_DEVIATIONS, 
                   line_styles=LINE_STYLES, log_scale=LOG_SCALE)

In [None]:
run_analysis(samples_qc_table, 'mean_insert_size', LINE_DEVIATIONS, LINE_STYLES, log_scale=LOG_SCALE)

### Filtering

<div class="alert alert-block alert-success">Input the filtering parameters. For the AoU SV data, the lower cutoff was 320 and the upper cutoff was 700.</div>

In [None]:
LOG_SCALE = False # Boolean value that defines whether to log-scale the plot
METHOD = 'hard' # String value that defines the cutoff method to use - either 'MAD' or 'hard'

LOWER_CUTOFF = None # Numeric value that defines the lower threshold if METHOD = 'hard'
UPPER_CUTOFF = None # Numeric value that defines the upper threshold if METHOD = 'hard'
MAD_CUTOFF = None # Numeric value that defines the MAD deviation threshold if METHOD = 'MAD'

validate_qc_inputs(samples_qc_table, 'mean_insert_size', method=METHOD, lower_cutoff=LOWER_CUTOFF, 
                   upper_cutoff=UPPER_CUTOFF, mad_cutoff=MAD_CUTOFF, log_scale=LOG_SCALE)

In [None]:
run_filtering(samples_qc_table, 'mean_insert_size', METHOD, lower_cutoff=LOWER_CUTOFF, 
              upper_cutoff=UPPER_CUTOFF, mad_cutoff=MAD_CUTOFF, log_scale=LOG_SCALE)

## WGD Score
The whole genome dosage (WGD), or dosage bias, score quantifies the non-uniformity of sequencing coverage within a sample ([Collins et al., Nature 2020](https://www.nature.com/articles/s41586-020-2287-8)). The distribution of WGD scores should be centered just below 0 for PCR-free genomes and just above 0 for PCR+ genomes. A WGD score of greater magnitude indicates more variable coverage, which can result in poor read depth-based CNV calling, so we recommend removing samples that are outliers for WGD. WGD is also useful for creating batches with similar coverage biases, which improves batch-level model training for read depth-based CNV calling.

### Analysis

<div class="alert alert-block alert-success">Input the analysis parameters.</div>

In [None]:
LOG_SCALE = False # Boolean value that defines whether to log-scale the plot
LINE_DEVIATIONS = [2, 4, 6]  # List of integers that defines the MAD cutoff lines to draw on each histogram plot
LINE_STYLES = ['solid', 'dashed', 'dashdot']  # List of strings that defines the line styles of each MAD cutoff line passed above

validate_qc_inputs(samples_qc_table, 'wgd_score', line_deviations=LINE_DEVIATIONS, 
                   line_styles=LINE_STYLES, log_scale=LOG_SCALE)

In [None]:
run_analysis(samples_qc_table, 'wgd_score', LINE_DEVIATIONS, LINE_STYLES, log_scale=LOG_SCALE)

### Filtering

<div class="alert alert-block alert-success">Input the filtering parameters. For the AoU SV data, the MAD cutoff was 6, which represented a range of approximately [-0.162, 0.136].</div>
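With `METHOD = 'MAD'`, the cutoff is measured in median absolute deviations from the median. A MAD cutoff of k therefore keeps samples within median ± k × MAD. The sketch below converts a MAD cutoff into that range. The function name is hypothetical. The notebook does not state whether its helpers scale the MAD by the normal-consistency constant 1.4826, so the `scale` argument is an assumption.

```python
import numpy as np


def mad_range(values, mad_cutoff, scale=1.0):
    """Return (lower, upper) = median -/+ mad_cutoff * scale * MAD, ignoring NaNs."""
    values = np.asarray(values, dtype=float)
    median = np.nanmedian(values)
    mad = scale * np.nanmedian(np.abs(values - median))
    return median - mad_cutoff * mad, median + mad_cutoff * mad


# Example usage (hypothetical):
# lower, upper = mad_range(samples_qc_table['wgd_score'], 6)
```

Because the range is symmetric around the median, the AoU range quoted above (approximately [-0.162, 0.136]) is consistent with a median near -0.013 and 6 × MAD near 0.149.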

In [None]:
LOG_SCALE = False # Boolean value that defines whether to log-scale the plot
METHOD = 'MAD' # String value that defines the cutoff method to use - either 'MAD' or 'hard'

LOWER_CUTOFF = None # Numeric value that defines the lower threshold if METHOD = 'hard'
UPPER_CUTOFF = None # Numeric value that defines the upper threshold if METHOD = 'hard'
MAD_CUTOFF = None # Numeric value that defines the MAD deviation threshold if METHOD = 'MAD'

validate_qc_inputs(samples_qc_table, 'wgd_score', method=METHOD, lower_cutoff=LOWER_CUTOFF, 
                   upper_cutoff=UPPER_CUTOFF, mad_cutoff=MAD_CUTOFF, log_scale=LOG_SCALE)

In [None]:
run_filtering(samples_qc_table, 'wgd_score', METHOD, lower_cutoff=LOWER_CUTOFF, 
              upper_cutoff=UPPER_CUTOFF, mad_cutoff=MAD_CUTOFF, log_scale=LOG_SCALE)

## Non-Diploid Bins
The number of non-diploid bins for each sample was calculated by counting the number of 1 Mb bins whose sequencing coverage significantly deviated from the expectation for that sample. Similar to WGD, samples with a very high number of non-diploid bins may perform poorly in read depth-based CNV calling.
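The statistical test that EvidenceQC applies is not described here. The hypothetical sketch below only illustrates the general idea of counting non-diploid bins. It assumes that the sample's median bin coverage corresponds to two copies, converts each bin's coverage to an estimated copy number, and counts the bins whose estimate differs from 2 by more than a fixed tolerance. Both the normalization and the tolerance are simplifications and do not replace a significance test.

```python
import numpy as np


def count_nondiploid_bins(bin_coverage, tolerance=0.5):
    """Count bins whose estimated copy number differs from 2 by more than `tolerance`.

    Hypothetical sketch: assumes the median bin coverage reflects two copies,
    which is not necessarily how EvidenceQC models the expected coverage.
    """
    coverage = np.asarray(bin_coverage, dtype=float)
    expected = np.median(coverage)
    copy_number = 2.0 * coverage / expected
    return int(np.sum(np.abs(copy_number - 2.0) > tolerance))


# Bins at 45x, 15x and 60x deviate from the 30x median (copy numbers 3, 1 and 4)
print(count_nondiploid_bins([30, 30, 30, 30, 45, 15, 60]))
```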

### Analysis

<div class="alert alert-block alert-success">Input the analysis parameters.</div>

In [None]:
LOG_SCALE = False # Boolean value that defines whether to log-scale the plot
LINE_DEVIATIONS = None  # List of integers that defines the MAD cutoff lines to draw on each histogram plot
LINE_STYLES = None  # List of strings that defines the line styles of each MAD cutoff line passed above

validate_qc_inputs(samples_qc_table, 'nondiploid_bins', line_deviations=LINE_DEVIATIONS, 
                   line_styles=LINE_STYLES, log_scale=LOG_SCALE)

In [None]:
run_analysis(samples_qc_table, 'nondiploid_bins', LINE_DEVIATIONS, LINE_STYLES, log_scale=LOG_SCALE)

### Filtering

<div class="alert alert-block alert-success">Input the filtering parameters. For the AoU SV data, the upper cutoff was 500.</div>

In [None]:
LOG_SCALE = False # Boolean value that defines whether to log-scale the plot
METHOD = 'hard' # String value that defines the cutoff method to use - either 'MAD' or 'hard'

UPPER_CUTOFF = None # Numeric value that defines the upper threshold if METHOD = 'hard'
MAD_CUTOFF = None # Numeric value that defines the MAD deviation threshold if METHOD = 'MAD'

validate_qc_inputs(samples_qc_table, 'nondiploid_bins', method=METHOD, upper_cutoff=UPPER_CUTOFF, 
                   mad_cutoff=MAD_CUTOFF, log_scale=LOG_SCALE)

In [None]:
run_filtering(samples_qc_table, 'nondiploid_bins', METHOD, upper_cutoff=UPPER_CUTOFF, 
              mad_cutoff=MAD_CUTOFF, log_scale=LOG_SCALE)

## Contamination
If you wish to analyze any additional metrics, you may add them to the metrics table and do so here. One additional useful metric, shown here, is cross-sample contamination: the fraction of reads in the sample that come from another individual. High rates of contamination can cause artifacts in genotyping. If you do not have access to contamination metrics, you may skip this section.

### Analysis

<div class="alert alert-block alert-success">Input the analysis parameters.</div>

In [None]:
LOG_SCALE = True # Boolean value that defines whether to log-scale the plot
LINE_DEVIATIONS = None  # List of integers that defines the MAD cutoff lines to draw on each histogram plot
LINE_STYLES = None # List of strings that defines the line styles of each MAD cutoff line passed above

validate_qc_inputs(samples_qc_table, 'contamination', line_deviations=LINE_DEVIATIONS, 
                   line_styles=LINE_STYLES, log_scale=LOG_SCALE)

In [None]:
run_analysis(samples_qc_table, 'contamination', LINE_DEVIATIONS, LINE_STYLES, log_scale=LOG_SCALE)

### Filtering

<div class="alert alert-block alert-success">Input the filtering parameters. For the AoU SV data, the upper cutoff was 0.01.</div>

In [None]:
LOG_SCALE = True # Boolean value that defines whether to log-scale the plot
METHOD = 'hard' # String value that defines the cutoff method to use - either 'MAD' or 'hard'

UPPER_CUTOFF = None # Numeric value that defines the upper threshold if METHOD = 'hard'
MAD_CUTOFF = None # Numeric value that defines the MAD deviation threshold if METHOD = 'MAD'


validate_qc_inputs(samples_qc_table, 'contamination', method=METHOD, upper_cutoff=UPPER_CUTOFF, 
                   mad_cutoff=MAD_CUTOFF, log_scale=LOG_SCALE)

In [None]:
run_filtering(samples_qc_table, 'contamination', METHOD, upper_cutoff=UPPER_CUTOFF, 
              mad_cutoff=MAD_CUTOFF, log_scale=LOG_SCALE)

## Raw Caller Outliers
This series of metrics looks for samples with an abnormally high or low number of raw SV calls from the three initial algorithms: Manta, Wham, and Scramble (or MELT). Higher-than-typical SV counts may indicate technical artifacts, while extremely low SV counts may indicate that an algorithm failed to complete. The values represent the number of times the sample was an outlier for SV counts across categories defined by algorithm, SV type, and chromosome.

**Note**: 
The sections below use two additional parameters that have not yet been covered.
- `CALLER`: The caller for which to analyze results. This must be one of `['overall', 'manta', 'melt', 'scramble', 'wham']`, where 'overall' corresponds to the sum of outlier occurrences across the individual callers.
- `TYPE`: The type of outliers for which to analyze results. This must be one of `['high', 'low']`, where 'high' indicates the number of cases in which the sample had more SVs than typical, while 'low' indicates the number of cases in which the sample had fewer SVs than typical.

We recommend checking the overall high and low outliers (i.e. `CALLER = 'overall'` and `TYPE = 'high'/'low'`), but you may also examine results for individual algorithms.
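The 'overall' counts are the sum of outlier occurrences across the individual callers. They can therefore be reproduced from the per-caller columns that follow the notebook's `{CALLER}_{TYPE}_outlier` naming. This sketch assumes that those per-caller columns are present in `samples_qc_table`. Which callers appear depends on how EvidenceQC was run, so the default caller list and the function name are assumptions.

```python
import pandas as pd


def overall_outlier_counts(qc_table, outlier_type, callers=('manta', 'melt', 'scramble', 'wham')):
    """Sum the per-caller outlier counts for one direction ('high' or 'low').

    Callers without a matching column are skipped; missing counts count as 0.
    """
    if outlier_type not in ('high', 'low'):
        raise ValueError("outlier_type must be 'high' or 'low'")
    columns = [f"{caller}_{outlier_type}_outlier" for caller in callers
               if f"{caller}_{outlier_type}_outlier" in qc_table.columns]
    return qc_table[columns].fillna(0).sum(axis=1)


# Example usage (hypothetical): compare with the existing overall column
# overall_outlier_counts(samples_qc_table, 'high')
```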

### Analysis

<div class="alert alert-block alert-success">Input the analysis parameters.</div>

In [None]:
LOG_SCALE = False # Boolean value that defines whether to log-scale the plot
LINE_DEVIATIONS = None  # List of integers that defines the MAD cutoff lines to draw on each histogram plot
LINE_STYLES = None  # List of strings that defines the line styles of each MAD cutoff line passed above

CALLER = 'overall' # String value that defines the caller - one of 'overall', 'manta', 'melt', 'scramble' or 'wham'
TYPE = 'high' # String value that defines the outlier direction - either 'high' or 'low'

validate_qc_inputs(samples_qc_table, f"{CALLER}_{TYPE}_outlier", line_deviations=LINE_DEVIATIONS, 
                   line_styles=LINE_STYLES, caller=CALLER, caller_type=TYPE, log_scale=LOG_SCALE)

In [None]:
run_analysis(samples_qc_table, f"{CALLER}_{TYPE}_outlier", LINE_DEVIATIONS, LINE_STYLES, log_scale=LOG_SCALE)

### Filtering

<div class="alert alert-block alert-success">Input the filtering parameters. For the AoU SV data, the upper cutoff was 30 across all callers, i.e., with CALLER = 'overall' and TYPE = 'high'.</div>

In [None]:
LOG_SCALE = False # Boolean value that defines whether to log-scale the plot
METHOD = 'hard' # String value that defines the cutoff method to use - either 'MAD' or 'hard'

CALLER = 'overall' # String value that defines the caller - one of 'overall', 'manta', 'melt', 'scramble' or 'wham'
TYPE = 'high' # String value that defines the outlier direction - either 'high' or 'low'

UPPER_CUTOFF = None # Numeric value that defines the upper threshold if METHOD = 'hard'
MAD_CUTOFF = None # Numeric value that defines the MAD deviation threshold if METHOD = 'MAD'

validate_qc_inputs(samples_qc_table, f"{CALLER}_{TYPE}_outlier", method=METHOD, upper_cutoff=UPPER_CUTOFF, 
                   mad_cutoff=MAD_CUTOFF, caller=CALLER, caller_type=TYPE, log_scale=LOG_SCALE)

In [None]:
run_filtering(samples_qc_table, f"{CALLER}_{TYPE}_outlier", METHOD, upper_cutoff=UPPER_CUTOFF, 
              mad_cutoff=MAD_CUTOFF, log_scale=LOG_SCALE)

## Autosomal Copy Number
Samples that are outliers for normalized copy number on one or more autosomes should be filtered to preserve the quality of SV calls for other samples. Explore the plots of copy ratio across chromosomes and binned copy number per chromosome to set an appropriate cutoff. These plots can be used to identify potential germline or mosaic aneuploidies as well.

**Note**: The parameter `CHR_NUMBER` in this step is a numeric value that selects the chromosome whose analysis is displayed.

### Visualize All Autosomes

In [None]:
# Display copy ratio across autosomes for all samples
plot_copy_number_per_autosome(samples_qc_table[cols_autosome], display=True)

### Below Threshold

#### Analysis

<div class="alert alert-block alert-success">Adjust the cutoff parameter to get the counts of samples with likely copy number aberrations on each autosome. The copy ratio is defined as the ratio of observed to expected depth per chromosome copy, so a copy ratio of 2.0 corresponds to a diploid chromosome. A recommended range is 1.2-1.8, but the choice of cutoff will depend on dataset quality.</div>
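The threshold check works as an element-wise comparison on a table of per-chromosome copy ratios. It reports every (sample, chromosome) pair outside the cutoff and flags a sample if any autosome falls outside it. The sketch below returns pairs and IDs, like `find_samples_outside_threshold`, but it is a hypothetical reimplementation. It assumes a wide table indexed by sample ID with one column per autosome, which may not match the layout of `samples_qc_table[cols_autosome]`.

```python
import pandas as pd


def samples_outside_threshold(copy_ratios, lower_cutoff=None, upper_cutoff=None):
    """Return ((sample, chromosome) pairs, sorted sample IDs) outside the cutoffs.

    `copy_ratios` is a DataFrame indexed by sample ID with one column per
    autosome; values near 2.0 correspond to two copies.
    """
    outside = pd.DataFrame(False, index=copy_ratios.index, columns=copy_ratios.columns)
    if lower_cutoff is not None:
        outside |= copy_ratios < lower_cutoff
    if upper_cutoff is not None:
        outside |= copy_ratios > upper_cutoff
    # Row and column positions of every flagged cell, in row-major order
    rows, cols = outside.to_numpy().nonzero()
    pairs = [(copy_ratios.index[r], copy_ratios.columns[c]) for r, c in zip(rows, cols)]
    sample_ids = sorted({sample for sample, _ in pairs})
    return pairs, sample_ids
```

A sample with a single chromosome below the cutoff is enough to flag it. This matches the "at least one chromosome" rule used in the filtering steps.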

In [None]:
LOWER_CUTOFF = 1.8

validate_numeric_input(LOWER_CUTOFF)

In [None]:
cn_filter_pairs, cn_filter_ids = find_samples_outside_threshold(samples_qc_table[cols_autosome], 
                                                                lower_cutoff=LOWER_CUTOFF)

cn_filter_pairs

<div class="alert alert-block alert-success">Input the analysis parameters to plot the aneuploidies of samples that fall below the threshold for a specific chromosome.</div>

In [None]:
CHR_NUMBER = 1
LOWER_CUTOFF = 1.8

validate_numeric_input(CHR_NUMBER, LOWER_CUTOFF)

In [None]:
# Display copy number plots for batches with samples that don't meet threshold
display_aneuploidy_outside_threshold(samples_qc_table[cols_autosome], CHR_NUMBER, lower_cutoff=LOWER_CUTOFF, 
                                     display=True)

#### Filtering

<div class="alert alert-block alert-success">Input the filtering parameters to filter out the samples whose estimated copy number falls below the threshold for at least one chromosome. For the AoU SV data, the lower cutoff was 1.8.</div>

In [None]:
LOWER_CUTOFF = None

validate_numeric_input(LOWER_CUTOFF)

In [None]:
cn_filter_pairs, cn_filter_ids = find_samples_outside_threshold(samples_qc_table[cols_autosome], 
                                                                lower_cutoff=LOWER_CUTOFF)

remove_samples(cn_filter_ids, 'low_copy_number_outlier')

### Above Threshold

#### Analysis

<div class="alert alert-block alert-success">Input the analysis parameters to get the counts of samples with a copy number exceeding the threshold on each autosome. A recommended range is 2.2-2.8, but the choice of cutoff will depend on dataset quality.</div>

In [None]:
UPPER_CUTOFF = 2.3

validate_numeric_input(UPPER_CUTOFF)

In [None]:
cn_filter_pairs, cn_filter_ids = find_samples_outside_threshold(samples_qc_table[cols_autosome], upper_cutoff=UPPER_CUTOFF)

cn_filter_pairs

<div class="alert alert-block alert-success">Input the analysis parameters to plot the aneuploidies of samples that exceed the threshold for a specific chromosome.</div>

In [None]:
CHR_NUMBER = 1
UPPER_CUTOFF = 2.3

validate_numeric_input(CHR_NUMBER, UPPER_CUTOFF)

In [None]:
# Display copy number plots for batches with samples that don't meet threshold
display_aneuploidy_outside_threshold(samples_qc_table[cols_autosome], CHR_NUMBER, upper_cutoff=UPPER_CUTOFF, display=True)

#### Filtering

<div class="alert alert-block alert-success">Input the filtering parameters to filter out the samples whose estimated copy number exceeds the threshold for at least one chromosome. For the AoU SV data, the upper cutoff was 2.3.</div>

In [None]:
UPPER_CUTOFF = None

validate_numeric_input(UPPER_CUTOFF)

In [None]:
cn_filter_pairs, cn_filter_ids = find_samples_outside_threshold(samples_qc_table[cols_autosome], upper_cutoff=UPPER_CUTOFF)

remove_samples(cn_filter_ids, 'high_copy_number_outlier')

## Sex Analysis
In this section, you can examine the sex chromosome ploidy and computed sex labels. You can generate a PED file with the computed sex information, or if you already have a PED file, you can compare and update the sex information as needed. Importantly, samples with sex chromosome ploidy other than XX or XY should have sex set to 0 in the PED file.

<div class="alert alert-block alert-success">Input cutoffs for determining mosaic sex assignments. These cutoffs update the computed sex to account for mosaic loss of allosomes. Males with mosaic loss of chromosome Y between the upper and lower chrY cutoffs will be set to MALE. Females with mosaic loss of chromosome X will be set to MOSAIC and SV calls will not be made on allosomes. Please adjust the cutoffs as needed based on the plot of sex chromosome ploidy below.</div>
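The rules above can be read as a small decision procedure on estimated chrX and chrY copy numbers. The sketch below is one hypothetical encoding of them and does not reproduce the logic of `compute_sex_assignments`. It applies the following rules:

- One X with chrY between the chrY cutoffs is MALE (mosaic loss of Y).
- chrX between the chrX cutoffs with no Y is MOSAIC.
- One X and one Y is MALE.
- Two X and no Y is FEMALE.
- Anything else is OTHER.

The sketch then maps labels to the standard PED sex codes (1 = male, 2 = female, 0 = unknown), so that samples with ploidy other than XX or XY get 0, as recommended above. Rounding to the nearest integer copy number and coding MOSAIC as 0 are assumptions. The default cutoffs match the values in the cell below.

```python
def assign_sex(chrx_cn, chry_cn, lower_cutoff_chrx=1.2, upper_cutoff_chrx=1.7,
               lower_cutoff_chry=0.1, upper_cutoff_chry=0.8):
    """Return 'MALE', 'FEMALE', 'MOSAIC' or 'OTHER' from chrX/chrY copy numbers.

    Hypothetical sketch of the rules described above, not the notebook's helper.
    """
    x_ploidy, y_ploidy = round(chrx_cn), round(chry_cn)
    # One X with a partially lost Y: mosaic loss of Y, still called MALE
    if x_ploidy == 1 and lower_cutoff_chry <= chry_cn <= upper_cutoff_chry:
        return 'MALE'
    # Reduced X with no Y: mosaic loss of X in a female
    if lower_cutoff_chrx <= chrx_cn <= upper_cutoff_chrx and chry_cn < lower_cutoff_chry:
        return 'MOSAIC'
    if (x_ploidy, y_ploidy) == (1, 1):
        return 'MALE'
    if (x_ploidy, y_ploidy) == (2, 0):
        return 'FEMALE'
    # Any other ploidy, e.g. X0, XXY, XYY or XXX
    return 'OTHER'


def ped_sex_code(label):
    """Standard PED sex code: 1 = male, 2 = female, 0 = unknown/other."""
    return {'MALE': 1, 'FEMALE': 2}.get(label, 0)
```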

In [None]:
LOWER_CUTOFF_CHRX = 1.2
UPPER_CUTOFF_CHRX = 1.7
LOWER_CUTOFF_CHRY = 0.1
UPPER_CUTOFF_CHRY = 0.8

validate_numeric_input(LOWER_CUTOFF_CHRX, UPPER_CUTOFF_CHRX, LOWER_CUTOFF_CHRY, UPPER_CUTOFF_CHRY)

In [None]:
# Create dictionary of sex assignments
updated_sex = compute_sex_assignments(samples_qc_table,
                                      lower_cutoff_chrX=LOWER_CUTOFF_CHRX,
                                      upper_cutoff_chrX=UPPER_CUTOFF_CHRX,
                                      lower_cutoff_chrY=LOWER_CUTOFF_CHRY,
                                      upper_cutoff_chrY=UPPER_CUTOFF_CHRY)

# Calculate sex counts and frequencies
sex_assignments = pd.Series(updated_sex, name='SEX')
sex_ped_counts = sex_assignments.value_counts()
sex_ped_counts_normalized = sex_assignments.value_counts(normalize=True)

# Display sex counts and frequencies
print(f"Count of each computed sex:\n{sex_ped_counts}\n")
print(f"Proportion of each computed sex:\n{sex_ped_counts_normalized}\n")

plot_sex_chromosome_ploidy(samples_qc_table, updated_sex, display=True)

<div class="alert alert-block alert-success"><b>If available</b>: If you have a PED file for your cohort, provide the Google Cloud Storage path to it via <tt>REFERENCE_ASSIGNMENTS_PED</tt>.</div>

In [None]:
REFERENCE_ASSIGNMENTS_PED = None

if REFERENCE_ASSIGNMENTS_PED:
    validate_ped(REFERENCE_ASSIGNMENTS_PED, set(samples_qc_table['sample_id']))
else:  
    print("Reference PED file is not provided - skipping this step.")

In [None]:
reference_ped = pd.DataFrame()

if REFERENCE_ASSIGNMENTS_PED:
    reference_ped = process_reference_ped(REFERENCE_ASSIGNMENTS_PED, samples_qc_table)
else:
    print("Reference PED file is not provided - skipping this step.")
    
reference_ped

In [None]:
# Create and display differences files using reference and computed sex assignments
differences_ped = pd.DataFrame()

if REFERENCE_ASSIGNMENTS_PED:
    file_path = generate_file_path(TLD_PATH, 'sex_analysis', "differing_sex_assignments.tsv")
    differences_ped = create_ped_differences(reference_ped, updated_sex, file_path)
else:
    print("Reference PED file is not provided - skipping this step.")
    
differences_ped

In [None]:
# Save and display generated PED file, and update 'cohort_ped' workspace variable to this
file_path = generate_file_path(TLD_PATH, 'sex_analysis', 'sample_qc.ped')
create_ped(updated_sex, reference_ped, file_path)

created_ped = pd.read_table(file_path, names=['Sample_ID', 'Family_ID', 'Paternal_ID', 
                                               'Maternal_ID', 'Sex', 'Phenotype'])
created_ped

# Sample Filtering
This section filters samples based on the results of the series of QC steps completed above, creating metadata files for both passing and failed samples.

**Note**: To see which samples are currently set to be filtered at any point in the notebook's execution, run all cells in this section.
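The summary cells in this section rely on inverting the record of applied filters. Each call to `remove_samples` files sample IDs under a filter name, whereas the output tables need, for each sample, the list of filters it failed. The sketch below shows that inversion. It assumes that `samples_to_remove` maps each filter name to an iterable of sample IDs, a structure inferred from how the notebook uses its inverse rather than confirmed. The function name is hypothetical.

```python
from collections import defaultdict


def invert_filters(samples_to_remove):
    """Map each sample ID to the sorted list of filters it failed."""
    by_sample = defaultdict(set)
    for filter_name, sample_ids in samples_to_remove.items():
        for sample_id in sample_ids:
            by_sample[sample_id].add(filter_name)
    return {sample_id: sorted(filters) for sample_id, filters in by_sample.items()}


inverted = invert_filters({'wgd_score': ['s1', 's2'], 'contamination': ['s2']})
print(','.join(inverted['s2']))  # prints "contamination,wgd_score"
```

A sample that fails several filters is counted once among the failing samples, which is why the summary counts use the keys of the inverted mapping.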

In [None]:
# Print summary filtering statistics
inv_samples_to_remove = invert_removed_samples(samples_to_remove)
sample_ids_to_remove = list(inv_samples_to_remove.keys())
sample_ids = samples_qc_table['sample_id'].tolist()

print(f'Total samples: {len(sample_ids)}')
print(f'Passing samples: {len(sample_ids) - len(sample_ids_to_remove)}')
print(f'Failing samples: {len(sample_ids_to_remove)}')

In [None]:
# Save and display table of filtered samples
filter_df = pd.DataFrame({
    'sample_id': list(inv_samples_to_remove.keys()),
    'filters_applied': [','.join(filters) for filters in inv_samples_to_remove.values()]
})

file_path = generate_file_path(TLD_PATH, 'filtering', 'filtered_samples.tsv')
save_df(WS_BUCKET, file_path, filter_df)

filter_df = pd.read_table(file_path, sep='\t')
filter_df

In [None]:
# Save and display upset plot of filtered samples
create_upset_plot(filter_df, len(inv_samples_to_remove))

In [None]:
# Save and display metadata table of passing samples
file_path = generate_file_path(TLD_PATH, 'filtering', 'passing_samples_metadata.tsv')
filter_and_save_metadata(samples_qc_table, inv_samples_to_remove, file_path)

tsv = pd.read_csv(file_path, sep='\t')
tsv