In [None]:
import os
import glob
import re
import numpy as np
import pandas as pd
import string
import math

from cell_cycle_gating import manual_gating as mg
from cell_cycle_gating import dead_cell_filter_ldrint as dcf_int

import patchworklib as pw
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
from matplotlib.gridspec import GridSpec

import time
import multiprocessing

import plotnine
from plotnine import *
import seaborn as sns

from scipy.signal import find_peaks
from cell_cycle_gating import findpeaks as fp
from pomegranate.gmm import GeneralMixtureModel
from pomegranate.distributions import *
import torch

In [None]:
## Get the names of wells on a 384-well plate (16 rows A-P x 24 columns 01-24).
## "valid" wells exclude the outer two rows and two columns on every edge (rows C-N, columns 03-22).
## input:
##    var: one of "all_wells", "valid_wells", "all_well_rows", "all_well_cols",
##         "valid_well_rows", or "valid_well_cols"
## output: the requested info, e.g. ["A01", "A02", ... , "P24"] for "all_wells".
##    The *_rows values are strings of row letters (e.g. "ABCDEFGHIJKLMNOP");
##    the *_cols values are lists of zero-padded column numbers (e.g. ["01", ..., "24"]).
##    An unknown `var` raises KeyError.
def get_well_names(var):
    rows = string.ascii_uppercase[:16]
    cols = ['%02d' % n for n in range(1, 25)]
    inner_rows = rows[2:14]
    inner_cols = cols[2:22]
    lookup = {
        'all_wells': sorted(r + c for r in rows for c in cols),
        'valid_wells': sorted(r + c for r in inner_rows for c in inner_cols),
        'all_well_rows': rows,
        'all_well_cols': cols,
        'valid_well_rows': inner_rows,
        'valid_well_cols': inner_cols,
    }
    return lookup[var]

## Publish the mapping from experiment date (YYYY-MM-DD) to its local data sub-folder
## as a module-level global (named `name`, default 'folder_dict'). Returns None.
def define_folder_dict(name='folder_dict'):
    date_folder_pairs = (
        ('2020-11-17', 'rep3'),
        ('2021-02-19', 'rep4'),
        ('2021-02-26', 'rep5'),
        ('2021-03-02', 'rep6'),
        ('2021-04-06', 'rep7'),
        ('2021-04-23', 'rep8'),
        ('2021-05-18', 'rep9'),
        ('2021-05-21', 'rep10'),
        ('2021-06-11', 'rep11'),
        ('2021-07-27', '2_rep1/210727-combo-rep1'),
        ('2021-07-30', '2_rep2/210730_combo_rep2'),
        ('2021-08-06', '2_rep3/210806_combo_rep3'),
        ('2021-10-05', 'redo_rep1_and_2'),
        ('2021-10-15', 'redo_rep1_and_2/redo_rep2'),
        ('2021-10-29', 'redo_rep3'),
    )
    globals()[name] = dict(date_folder_pairs)
    return None

def get_base_dir():
    """Return the first existing mount of the LINCS-combinations analysis folder.

    Checks the WSL/Linux mount first, then the macOS mount (unix path
    conventions only -- would need re-writing for Windows).

    Raises:
        Exception: if neither location exists.
    """
    candidates = ("/mnt/y/lsp-analysis/LINCS-combinations",
                  "/Volumes/hits/lsp-analysis/LINCS-combinations")
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    raise Exception("Base directory not found -- need to mount research.files and supply its path, e.g. '/mnt/y/lsp-analysis/LINCS-combinations'")

def get_plates_to_regate():
    """Load the tables of plates and wells flagged for re-gating.

    Reads the time-zero and end-time-control csv files from
    <base_dir>/re_gating/plates_to_regate/ and adds a 'time' column to each
    ("time_zero" or "end_time_control").

    Returns:
        (df_plates, df_wells): concatenated plate and well tables, time-zero rows
        first. Row indices from the individual csv files are kept, so they repeat.
    """
    regate_dir = os.path.join(get_base_dir(), "re_gating", "plates_to_regate")
    sources = (
        ("time_zero_regate_plates.csv", "time_zero_regate_wells.csv", "time_zero"),
        ("ctrl_end_time_regate_plates.csv", "ctrl_end_time_regate_wells.csv", "end_time_control"),
    )
    plate_frames = []
    well_frames = []
    for plates_csv, wells_csv, time_label in sources:
        plates = pd.read_csv(os.path.join(regate_dir, plates_csv))
        wells = pd.read_csv(os.path.join(regate_dir, wells_csv))
        wells['time'] = time_label
        plates['time'] = time_label
        plate_frames.append(plates)
        well_frames.append(wells)
    return (pd.concat(plate_frames), pd.concat(well_frames))

## Get a list of plate barcodes for a given date
## input:
##    date: e.g. '2021-10-15'
## output: names of the sub-directories of that date's data folder that start with
##    '<YYMMDD>_combo', e.g. ['211015_combo_173', ...] (in os.listdir order)
def get_barcodes(date):
    prefix = date_format_switch(date) + "_combo"
    main_dir = get_data_dir(date = date)
    barcodes = []
    for entry in os.listdir(main_dir):
        is_dir = os.path.isdir(os.path.join(main_dir, entry))
        ### match date at the start of the sub-directory name
        if is_dir and re.match(prefix, entry):
            barcodes.append(entry)
    return barcodes

## Convert a date string from YYYY-MM-DD to YYMMDD by slicing fixed positions
## (no validation of the input format).
## input:
##    date: e.g. '2021-02-19'
## output: e.g. '210219'
def date_format_switch(date):
    year, month, day = date[2:4], date[5:7], date[8:10]
    return "".join((year, month, day))

## Get the date in YYYY-MM-DD format from the leading YYMMDD of a plate barcode
## (assumes a 21st-century year).
## input:
##    barcode: '210406_combo_71'
## output: e.g. '2021-04-06'
def date_from_barcode(barcode):
    return f"20{barcode[0:2]}-{barcode[2:4]}-{barcode[4:6]}"

## Get the well-level data directory for a given date or barcode
## input:
##    barcode: a plate barcode, e.g. '210406_combo_71'
##    date: a date in YYYY-MM-DD format, e.g. '2021-04-06' (derived from the barcode if omitted)
##    base_dir: the full path of the data folder; if it does not exist, the macOS mount
##              "/Volumes/hits/lsp-analysis/LINCS-combinations/" is used instead (not checked)
## output:
##    barcode given: <base_dir>/<date folder>/<barcode>
##    only date given: <base_dir>/<date folder>/ (all data for that date)
##    neither given: base_dir
## raises KeyError if the date is not in the global folder_dict (defined on first use)
def get_data_dir(barcode=None, date=None, base_dir = "/mnt/y/lsp-analysis/LINCS-combinations/"):
    ### unix folder conventions only -- would need re-writing for Windows
    if not os.path.exists(base_dir):
        base_dir = "/Volumes/hits/lsp-analysis/LINCS-combinations/"
    if barcode is None and date is None:
        return base_dir
    if date is None:
        date = date_from_barcode(barcode)
    if 'folder_dict' not in globals():
        define_folder_dict('folder_dict')
    plate_dir = '' if barcode is None else barcode
    return os.path.join(base_dir, folder_dict[date], plate_dir)

## Get the filename for well-level intensities for a given barcode and well
## input:
##    barcode: a plate barcode, e.g. '210406_combo_71'
##    well: a well of interest, e.g. 'D06'
## output:
##    full path/filename of the well-level data, '<data_dir>/<barcode>.result.<well>[test].csv'
## raises:
##    FileNotFoundError if that file does not exist
def get_well_file(barcode, well):
    data_dir = get_data_dir(barcode)
    ### example file style: '210406_combo_71.result.D06[test].csv'
    well_file = os.path.join(data_dir, barcode + ".result." + well + "[test].csv")
    ### fail with an explicit error rather than falling through to return an unassigned name
    if not os.path.exists(well_file):
        raise FileNotFoundError("well csv not found! " + well_file)
    return(well_file)

## Read the well-level data for a given well and barcode
## input:
##    barcode: a plate barcode, e.g. '210406_combo_71'
##    well: a well of interest, e.g. 'D06'
## output:
##    a pandas dataframe of the cell-level csv, columns as named in the file
##    (see rename_df_columns for the standardized names)
def read_well_data(barcode, well):
    return pd.read_csv(get_well_file(barcode, well))

## Re-name columns of well-level dataframe for LDR, DNA, EDU, etc.
## input:
##    df: original data frame read from a well-level csv
##    silent: if True, suppress informational messages (warnings about duplicate
##            DNAcontent/HoechstINT columns are always printed)
##    hoechst_as_dna: if True, prefer a HoechstINT column over a DNAcontent column as 'dna'
## output: data frame with re-named columns: 'Well Name' -> 'well', 'ldrint' -> 'ldr',
##    a DNAcontent/HoechstINT column -> 'dna', EdUrawINT -> 'edu_raw', EdUbackground -> 'edu_bg',
##    plus 'edu' = edu_raw - edu_bg when both EdU columns exist.
##    If 'Well Name' or 'ldrint' is missing, the input data frame is returned unchanged.
##    Within each (DD-bckgrnd)/(DDD-bckgrnd) pair the (DDD-bckgrnd) column wins.
def rename_df_columns(df, silent=True, hoechst_as_dna=False):
    col_dict = {}
    ### required columns: without these nothing is renamed
    for required, new_name in (('Well Name', 'well'), ('ldrint', 'ldr')):
        if required not in df.columns:
            ### the column dump is diagnostic output too, so it honours `silent`
            if not silent:
                print(df.columns)
                print(required + ' column not found')
            return(df)
        if not silent: print("'" + required + "' column found -- re-naming as '" + new_name + "'")
        col_dict[required] = new_name
    ### check for DNAcontent/Hoechst: search one pair, then the other (order set by hoechst_as_dna)
    dna_pair = ('Cell: DNAcontent (DD-bckgrnd)', 'Cell: DNAcontent (DDD-bckgrnd)')
    hoechst_pair = ('Cell: HoechstINT (DD-bckgrnd)', 'Cell: HoechstINT (DDD-bckgrnd)')
    search_order = [('HoechstINT', hoechst_pair), ('DNAcontent', dna_pair)]
    if not hoechst_as_dna:
        search_order.reverse()
    dna_col = None
    for label, (col_dd, col_ddd) in search_order:
        has_dd = col_dd in df.columns
        has_ddd = col_ddd in df.columns
        if has_dd and has_ddd:
            print('Warning: Two ' + label + ' columns -- using ' + "'" + col_ddd + "'")
            dna_col = col_ddd
        elif has_ddd:
            dna_col = col_ddd
        elif has_dd:
            dna_col = col_dd
        if dna_col is not None:
            break
    if dna_col is None:
        if not silent: print(df.columns)
        if not silent: print('DNA column not found')
    else:
        if not silent: print("'"+dna_col+"'"+" column found -- re-naming as 'dna'")
        col_dict[dna_col] = 'dna'
    ### check for EdU raw intensity and background ((DDD-bckgrnd) preferred)
    for stem, new_name in (('Cell: EdUrawINT', 'edu_raw'), ('Cell: EdUbackground', 'edu_bg')):
        for suffix in (' (DDD-bckgrnd)', ' (DD-bckgrnd)'):
            edu_col = stem + suffix
            if edu_col in df.columns:
                if not silent: print("'" + edu_col + "' column found -- re-naming as '" + new_name + "'")
                col_dict[edu_col] = new_name
                break
    ### re-name data-frame columns (returns a new frame; the input is not modified)
    df = df.rename(columns=col_dict)
    if 'edu_raw' in df.columns and 'edu_bg' in df.columns:
        df['edu'] = df.edu_raw - df.edu_bg
    return(df)

def read_and_rename_well_data(barcode, well, silent = False, hoechst_as_dna=False):
    """Read one well's csv (read_well_data) and standardize its columns (rename_df_columns)."""
    raw = read_well_data(barcode, well)
    return rename_df_columns(raw, silent=silent, hoechst_as_dna=hoechst_as_dna)


## Get all wells with data for a given plate
## input:
##    barcode: a plate barcode, e.g. '210406_combo_71'
## output:
##    sorted list of wells whose '<barcode>.result.<well>[test].csv' file exists
##    in the plate directory, e.g. ['C03', 'C04', ... , 'N22']
def get_all_wells(barcode):
    data_dir = get_data_dir(barcode)
    candidates = get_well_names("all_wells")
    present = [
        w for w in candidates
        if os.path.exists(os.path.join(data_dir, barcode + ".result." + w + "[test].csv"))
    ]
    return sorted(present)

### maybe not necessary for LDR intensity data?
#def read_all_wells(barcode):
#    wells = get_all_wells(barcode)
#    df_list = [read_well_data(barcode, well) for well in wells]
#    return(df_list)

def get_ldr_cutoff(barcode, well, peak_loc = 1.2, silent=False, hoechst_as_dna=False):
    """Return the LDR cutoff for one well: element [1] of the gates from dcf_int.get_ldrgates.

    peak_loc is passed straight to dcf_int.get_ldrgates (1.2 is the default used throughout).
    """
    well_data = read_and_rename_well_data(barcode, well, silent=silent, hoechst_as_dna=hoechst_as_dna)
    gates, _limits = dcf_int.get_ldrgates(ldrint=well_data['ldr'], peak_loc=peak_loc)
    return gates[1]

def get_ldr_cutoff_many(barcode, wells, peak_loc = 1.2, silent=True, hoechst_as_dna=False):
    """Return get_ldr_cutoff for each well in `wells`, in the same order."""
    cutoffs = []
    for well in wells:
        cutoffs.append(get_ldr_cutoff(barcode, well, peak_loc=peak_loc, silent=silent,
                                      hoechst_as_dna=hoechst_as_dna))
    return cutoffs

### note: the default folder is hard-coded for WSL on one desktop -- pass `folder` to read from elsewhere
def load_well_metadata(name = 'meta', folder=None, file='single_timepoint_cleaned_from_raw_2023-06-08.parquet'):
    """Read the well-metadata parquet file and store it as a module-level global named `name`.

    Returns None; callers access the data frame through the global (default `meta`).
    """
    source_dir = "/mnt/c/Users/NC168/git/LINCS_combos/data/cleaned/" if folder is None else folder
    globals()[name] = pd.read_parquet(os.path.join(source_dir, file))
    
def get_wells(barcode, cell_line):
    """Return the sorted list of wells for one cell line on one plate.

    Uses the global `meta` data frame (loaded with load_well_metadata on first use).
    """
    if 'meta' not in globals(): load_well_metadata()
    ### boolean masks instead of a string-built query, so names containing quotes can't break the filter
    mask = (meta['cell_line'] == cell_line) & (meta['barcode'] == barcode)
    return sorted(meta.loc[mask, 'well'])

def get_cell_lines_on_plate(barcode):
    """Return the unique cell lines (array, order of first appearance in `meta`) on one plate.

    Uses the global `meta` data frame (loaded with load_well_metadata on first use).
    """
    if 'meta' not in globals(): load_well_metadata()
    ### boolean mask instead of a string-built query (safe for barcodes containing quotes)
    return meta.loc[meta['barcode'] == barcode, 'cell_line'].unique()

def get_ldr_cutoffs_plate(barcode, peak_loc = 1.2, silent = True):
    """Return LDR cutoffs for every well of every cell line on `barcode` (one row per well)."""
    return get_ldr_cutoffs_cell_line_and_barcode(barcode, get_cell_lines_on_plate(barcode),
                                                 peak_loc=peak_loc, silent=silent)

def get_ldr_cutoffs_cell_line_and_barcode(barcode, cell_lines, peak_loc=1.2, silent = True, hoechst_as_dna=False):
    """Return a data frame of LDR cutoffs for the given cell lines on one plate.

    Columns: 'well', 'ldr_cutoff', 'barcode', 'cell_line' (one row per well).
    silent=False prints each cell line as it is processed. hoechst_as_dna is
    forwarded to the per-well reader (use the HoechstINT column as DNA).
    """
    df_list = []
    for cell_line in cell_lines:
        if not silent: print(cell_line)
        wells = get_wells(barcode, cell_line)
        ldrs = get_ldr_cutoff_many(barcode, wells, peak_loc=peak_loc, hoechst_as_dna=hoechst_as_dna)
        df_list.append(pd.DataFrame(data={'well': wells, 'ldr_cutoff': ldrs,
                                          'barcode': barcode, 'cell_line': cell_line}))
    return(pd.concat(df_list))

## Plot the LDR gating for one well's renamed cell-level data frame (needs 'ldr' and 'dna').
## input:
##    peak_loc: passed to dcf_int.get_ldrgates; the cutoff drawn is ldr_gates[1]
##    scatter: True -> log10 LDR vs. log10 DNA scatter (mg.plot_ldr_dna_scatter);
##             False -> log10 LDR plot via mg.ldr_gating (nbins = 20)
##    silent: accepted for interface compatibility; not used
##    show_fig, fig, outer, i, title, add_ldr_line: passed to mg.plot_ldr_dna_scatter
##    x_lims: tuple of x limits for the plot; y_lims: tuple of y limits for the plot
## output: the figure returned by the mg plotting function
def plot_ldr(df, peak_loc = 1.2, scatter = True, silent = True, show_fig = True, 
             fig=None, outer=None, i=None, title = "", x_lims=None, y_lims=None, add_ldr_line = None):
    ldr_gates, _ldr_lims = dcf_int.get_ldrgates(ldrint = df['ldr'], peak_loc=peak_loc)
    ldr_cutoff = ldr_gates[1]
    ### log10 needs positive values: for plotting only, replace non-positive LDR/DNA
    ### with the smallest positive value in that column (on a copy; caller's df untouched)
    plot_df = df.copy()
    floor_ldr = np.min(plot_df.query("ldr > 0").ldr)
    floor_dna = np.min(plot_df.query("dna > 0").dna)
    plot_df['ldr'] = plot_df['ldr'].where(plot_df['ldr'] > 0, floor_ldr)
    plot_df['dna'] = plot_df['dna'].where(plot_df['dna'] > 0, floor_dna)
    if not scatter:
        return mg.ldr_gating(np.log10(plot_df.ldr), ldr_cutoff, nbins = 20)
    return mg.plot_ldr_dna_scatter(np.log10(plot_df.dna), np.log10(plot_df.ldr), ldr_cutoff,
                                   dna_gates=None, plot_ldr_log10=True, is_ldrint=True,
                                   show_fig=show_fig, fig=fig, outer=outer, i=i,
                                   title=title, x_lims=x_lims, y_lims=y_lims, add_ldr_line=add_ldr_line)

def plot_ldr_well(barcode, well, peak_loc = 1.2, scatter = True, silent = True, 
                  show_fig = True, fig=None, outer=None, i=None, title="", x_lims=None, 
                  y_lims=None, hoechst_as_dna=False, add_ldr_line=None):
    """Read and rename one well's data, then plot its LDR gating with plot_ldr.

    All plotting arguments are forwarded unchanged; returns plot_ldr's figure.
    """
    well_data = read_and_rename_well_data(barcode, well, silent, hoechst_as_dna=hoechst_as_dna)
    plot_kwargs = dict(peak_loc=peak_loc, scatter=scatter, silent=silent, show_fig=show_fig,
                       fig=fig, outer=outer, i=i, title=title, x_lims=x_lims, y_lims=y_lims,
                       add_ldr_line=add_ldr_line)
    return plot_ldr(well_data, **plot_kwargs)

def plot_ldr_many(barcode, wells, peak_loc = 1.2, scatter = False, silent = True, hoechst_as_dna=False):
    """Plot the LDR gating of each well in turn (figures are not collected or returned)."""
    for current_well in wells:
        well_data = read_and_rename_well_data(barcode, current_well, silent, hoechst_as_dna=hoechst_as_dna)
        plot_ldr(well_data, peak_loc = peak_loc, scatter = scatter)

def plot_ldr_pdf(barcode, wells, peak_loc = 1.2, figname = "test_ldr.pdf", scatter = True, 
                 silent = True, show_fig = True, hoechst_as_dna=False):
    """Write one LDR-gating figure per well to a multi-page pdf and return the figures.

    Cells with ldr <= 0 are dropped before plotting (log10 scale). Each page is
    titled with the well and its treatments/concentrations from the metadata.
    The pdf is closed even if plotting a well raises.
    """
    fig_list = []
    with PdfPages(figname) as pdf_pages:
        for well in wells:
            well_meta = get_well_meta(barcode, well)
            trt1 = list(well_meta.agent1)[0]
            trt2 = list(well_meta.agent2)[0]
            conc1 = list(well_meta.concentration1_chr)[0]
            conc2 = list(well_meta.concentration2_chr)[0]
            df = read_and_rename_well_data(barcode, well, silent, hoechst_as_dna=hoechst_as_dna)
            ldr_gates, ldr_lims = dcf_int.get_ldrgates(ldrint = df['ldr'], peak_loc=peak_loc)
            ldr_cutoff = ldr_gates[1]
            df = df.query("ldr > 0")
            if scatter:
                fig = mg.plot_ldr_dna_scatter(np.log10(df.dna), np.log10(df.ldr), ldr_cutoff, dna_gates=None, 
                                              plot_ldr_log10=True, is_ldrint=True, show_fig = show_fig)
            else:
                fig = mg.ldr_gating(np.log10(df.ldr), ldr_cutoff, nbins = 20)
            fig_title = str(trt1) + ": "+ str(conc1) + " uM, " + str(trt2) + ": " + str(conc2) + " uM"
            fig.suptitle(well + "\n" + fig_title, fontsize=12)
            fig_list.append(fig)
            ### close the pyplot window; the Figure object itself can still be saved
            plt.close()
            pdf_pages.savefig(fig)
    return(fig_list)

def test_regate(barcode, cell_line, peak_loc = 1.2, figname = "test_figure", scatter = True, silent = True, test = True, 
                show_fig = False, hoechst_as_dna=False):
    """Try a peak_loc on all wells of one plate/cell line and write the results under temp_regating/.

    Writes temp_regating/pdf/<figname>.pdf (via plot_ldr_pdf) and
    temp_regating/csv/<figname>.csv (dcf_int.get_counts_df for every well).
    The `test` argument is accepted but not used.

    Returns:
        (df_out, plot_list): the concatenated counts and the list of figures.
    """
    csv_dir = os.path.join('temp_regating', 'csv')
    pdf_dir = os.path.join('temp_regating', 'pdf')
    for out_dir in (csv_dir, pdf_dir):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
    wells = get_wells(barcode, cell_line)
    csv_file = os.path.join(csv_dir, figname+'.csv')
    pdf_file = os.path.join(pdf_dir, figname+'.pdf')
    plot_list = plot_ldr_pdf(barcode, wells, peak_loc, figname=pdf_file, scatter=scatter, silent=silent, 
                             show_fig=show_fig, hoechst_as_dna=hoechst_as_dna)
    print('figures written to: ' + pdf_file)
    counts = [
        dcf_int.get_counts_df(df=read_and_rename_well_data(barcode, w, silent, hoechst_as_dna=hoechst_as_dna),
                              barcode=barcode, well=w, peak_loc=peak_loc)
        for w in wells
    ]
    df_out = pd.concat(counts)
    df_out.to_csv(csv_file)
    return (df_out, plot_list)

## Plot LDR vs. DNA for every well of one cell line on one plate, and save the counts.
## input:
##    barcode, cell_line: plate and cell line whose wells (from the metadata) are plotted
##    peak_loc: passed to dcf_int.get_counts_df and the per-well LDR gating
##    scatter, silent, hoechst_as_dna: forwarded to the per-well reading/plotting helpers
##    figname: file-name stem; defaults to '<barcode>_<cell_line>_peak_loc_<peak_loc>'
##    output_dir: folder for the outputs (created if missing)
## output files:
##    all_wells_<figname>.csv          dcf_int.get_counts_df for every well
##    all_wells_scatter_<figname>.pdf  one panel per well, 6 per page, shared x/y limits
## returns: [df2, fig_list] -- the counts data frame and the per-well figure objects
def plot_wells_ldr(barcode, cell_line, peak_loc=1.2, scatter = True, silent=True,
                   figname = None, output_dir="default_gating", hoechst_as_dna=False):
    if figname is None: figname=barcode+'_'+cell_line+'_'+'peak_loc_'+str(peak_loc)
    if 'meta' not in globals(): load_well_metadata()
    os.makedirs(output_dir, exist_ok=True)
    wells = get_wells(barcode, cell_line)
    df_list = []
    df_full_list = []
    for well in wells:
        df = read_and_rename_well_data(barcode, well, silent, hoechst_as_dna=hoechst_as_dna)
        df_list.append(dcf_int.get_counts_df(df=df, barcode=barcode, well=well, peak_loc = peak_loc))
        df_full_list.append(df)
    df2 = pd.concat(df_list)
    df_full = pd.concat(df_full_list)
    ### shared axis limits (log10 of positive values) so all panels are comparable
    y_log = np.log10(df_full.query("ldr > 0").ldr)
    y_lims = (min(y_log), max(y_log))
    print(y_lims)
    x_log = np.log10(df_full.query("dna > 0").dna)
    x_lims = (min(x_log)-0.2, max(x_log)+0.2)
    ### save counts data frame to csv
    df2.to_csv(os.path.join(output_dir, "all_wells_" + figname + ".csv"))
    ### plot all wells, plots_per_page panels per page
    fig_list = []
    pdf_full = os.path.join(output_dir, "all_wells_scatter_" + figname + ".pdf")
    ### the loop walks `wells`, so bound it by the number of wells
    nb_plots = len(wells)
    plots_per_page = 6
    ### context manager closes the pdf even if a well fails to plot
    with PdfPages(pdf_full) as pdf_pages:
        for i, well in enumerate(wells):
            if i % plots_per_page == 0:
                fig = plt.figure(figsize=(8.5, 11))
                outer = GridSpec(3, 2, wspace=0.2, hspace=0.5)
            ### figure title from the well metadata
            well_meta = get_well_meta(barcode, well)
            trt1 = list(well_meta.agent1)[0]
            trt2 = list(well_meta.agent2)[0]
            conc1 = list(well_meta.concentration1_chr)[0]
            conc2 = list(well_meta.concentration2_chr)[0]
            fig_title = str(trt1) + ": "+ str(conc1) + " uM, " + str(trt2) + ": " + str(conc2) + " uM"
            fig_title = well+", peak_loc = "+str(peak_loc)+"\n"+fig_title
            ### draw into panel (i mod plots_per_page) of the current page
            fig_tmp = plot_ldr_well(barcode, well, peak_loc = peak_loc, scatter = scatter,
                                    silent = silent, show_fig = False, fig = fig, outer = outer,
                                    i = i % plots_per_page, title = fig_title, x_lims=x_lims,
                                    y_lims=y_lims, hoechst_as_dna=hoechst_as_dna)
            fig_list.append(fig_tmp)
            ### flush the page when it is full or this is the last well
            if (i + 1) % plots_per_page == 0 or (i + 1) == nb_plots:
                plt.tight_layout()
                pdf_pages.savefig()
                plt.close('all')
    return([df2, fig_list])

def get_well_meta(barcode, well):
    """Return the rows of the global `meta` data frame for one well of one plate.

    `meta` is loaded with load_well_metadata on first use, like the other accessors.
    """
    if 'meta' not in globals(): load_well_metadata()
    ### boolean masks instead of a string-built query, so values containing quotes can't break the filter
    return(meta[(meta['barcode'] == barcode) & (meta['well'] == well)])

def plot_ldr_cutoff_change(barcode, cell_line, peak_loc, scatter = True, silent=True, 
                           default_peak_loc = 1.2, figname = None,
                           output_dir="temp_regating",
                           hoechst_as_dna=False):
    """Compare LDR gating at a new `peak_loc` with `default_peak_loc` for one plate/cell line.

    Counts are computed per well with dcf_int.get_counts_df at both peak_loc values.
    Measured columns get '_orig' (default) and '_new' suffixes and are joined with
    the well metadata. Written to `output_dir` (created if missing):
      <figname>_summary.pdf        new vs. original ldr_cutoff, dead-cell count, cell count
      <figname>_wells_changed.pdf  LDR vs. DNA panels (default | new), 3 wells per page,
                                   for wells whose cutoff changed
      <figname>_all_wells.csv      the joined table for all wells
      <figname>_wells_changed.csv  the rows whose cutoff changed

    Returns:
        [df2, df2_sub, gg, fig_list_orig, fig_list_new]: all-wells table,
        changed-wells table (index reset), patchworklib summary figure, and the
        per-well figures for the default and the new peak_loc.
    """
    if figname is None: figname=barcode+'_'+cell_line+'_'+'peak_loc_'+str(peak_loc)
    if 'meta' not in globals(): load_well_metadata()
    os.makedirs(output_dir, exist_ok=True)
    wells = get_wells(barcode, cell_line)
    df_list_orig = []
    df_list_new = []
    for well in wells:
        df = read_and_rename_well_data(barcode, well, silent, hoechst_as_dna=hoechst_as_dna)
        df_list_new.append(dcf_int.get_counts_df(df=df, barcode=barcode, well=well, peak_loc = peak_loc))
        df_list_orig.append(dcf_int.get_counts_df(df=df, barcode=barcode, well=well, peak_loc = default_peak_loc))
    df_orig = pd.concat(df_list_orig)
    df_new = pd.concat(df_list_new)
    ### add suffixes to measured columns in each data frame, then join on barcode/well
    key_cols = ['barcode', 'well']
    df_orig2 = df_orig.rename(columns={c: c+'_orig' for c in df_orig.columns if c not in key_cols})
    df_new2 = df_new.rename(columns={c: c+'_new' for c in df_new.columns if c not in key_cols})
    df2 = df_orig2.merge(df_new2, on = key_cols, how = 'inner')
    meta_select = meta[['barcode', 'cell_line', 'well', 'agent1', 'concentration1_chr', 'agent2', 'concentration2_chr', 'timepoint']]
    df2 = df2.merge(meta_select, on = key_cols, how = 'left')
    ### most useful columns first, everything else after
    cols = ['barcode', 'well', 'cell_count__dead_orig', 'cell_count__dead_new', 'cell_count_orig', 'cell_count_new', 
            'ldr_cutoff_orig', 'ldr_cutoff_new', 'cell_line','agent1', 'concentration1_chr', 'agent2', 
            'concentration2_chr', 'timepoint']
    last_cols = [ x for x in df2.columns if x not in cols ]
    cols.extend(last_cols)
    ### explicit copy: the column selection must be an independent frame before adding 'label'
    df2 = df2[cols].copy()
    changed = df2['ldr_cutoff_orig'] != df2['ldr_cutoff_new']
    ### label only the wells whose cutoff changed (empty string otherwise)
    df2['label'] = df2['well'].where(changed, "")
    df2_sub = df2[changed].reset_index(drop=True)
    ### plot new vs. old cutoffs, live/dead counts
    gg1 = pw.load_ggplot(ggplot(df2, aes(x = "ldr_cutoff_orig", y = "ldr_cutoff_new")) +
        geom_point(alpha = 0.5) +
        geom_label(aes(label="label"), alpha = 0.5, nudge_x = 0.05, nudge_y = 0.05), figsize=(3,3))
    gg2 = pw.load_ggplot(ggplot(df2, aes(x = "cell_count__dead_orig", y = "cell_count__dead_new")) +
        geom_point(alpha = 0.5) +
        geom_label(aes(label="label"), alpha = 0.5, nudge_x = 20, nudge_y = 5), figsize=(3,3))
    gg3 = pw.load_ggplot(ggplot(df2, aes(x = "cell_count_orig", y = "cell_count_new")) +
        geom_point(alpha = 0.5) +
        geom_label(aes(label="label"), alpha = 0.5, nudge_x = 100, nudge_y = 100), figsize=(3,3))
    gg = (gg1|gg2|gg3)
    gg.savefig(os.path.join(output_dir, figname + "_summary.pdf"))
    ### plot wells that changed: one row per well, default peak_loc left, new peak_loc right
    fig_list_orig = []
    fig_list_new = []
    pdf2_full = os.path.join(output_dir, figname + "_wells_changed.pdf")
    nb_rows = len(df2_sub.well)
    rows_per_page = 3
    ### context manager closes the pdf even if a well fails to plot
    with PdfPages(pdf2_full) as pdf_pages:
        for i, well in enumerate(df2_sub.well):
            if i % rows_per_page == 0:
                fig = plt.figure(figsize=(8.5, 11))
                outer = GridSpec(3, 2, wspace=0.2, hspace=0.5)
            ### get well metadata for the titles
            well_meta = get_well_meta(barcode, well)
            trt1 = list(well_meta.agent1)[0]
            trt2 = list(well_meta.agent2)[0]
            conc1 = list(well_meta.concentration1_chr)[0]
            conc2 = list(well_meta.concentration2_chr)[0]
            fig_title = str(trt1) + ": "+ str(conc1) + " uM, " + str(trt2) + ": " + str(conc2) + " uM"
            fig_title_orig = well+", peak_loc = "+str(default_peak_loc)+" (default)"+"\n"+fig_title
            fig_title_new = well+", peak_loc = "+str(peak_loc)+"\n"+fig_title
            i_page = i % rows_per_page
            ### make figures for new and old peak_loc values
            fig_orig = plot_ldr_well(barcode, well, peak_loc = default_peak_loc, scatter = scatter, 
                                     silent = silent, show_fig = False, fig = fig, outer = outer, i = 2*i_page,
                                     title = fig_title_orig, hoechst_as_dna=hoechst_as_dna)
            fig_new = plot_ldr_well(barcode, well, peak_loc = peak_loc, scatter = scatter, 
                                    silent = silent, show_fig = False, fig = fig, outer = outer, i = 2*i_page+1,
                                    title = fig_title_new, hoechst_as_dna=hoechst_as_dna)
            fig_list_orig.append(fig_orig)
            fig_list_new.append(fig_new)
            ### flush the page when it is full or this is the last changed well
            if (i + 1) % rows_per_page == 0 or (i + 1) == nb_rows:
                plt.tight_layout()
                pdf_pages.savefig()
                plt.close('all')
    ### write cell counts for all wells, and for only the wells where the cutoff changed
    df2.to_csv(os.path.join(output_dir, figname + "_all_wells.csv"))
    df2_sub.to_csv(os.path.join(output_dir, figname + "_wells_changed.csv"))
    return([df2, df2_sub, gg, fig_list_orig, fig_list_new])

### plot LDR cutoffs for flagged vs. unflagged wells
## input:
##    barcode, cell_line: plate and cell line to analyse
##    well_df: table of flagged wells with at least 'barcode', 'cell_line' and 'well' columns
##             (e.g. df_wells from get_plates_to_regate); only rows for this plate/cell line are used
##    figname: file-name stem; defaults to '<barcode>_<cell_line>_peak_loc_<peak_loc>'
##    peak_loc: passed to dcf_int.get_counts_df
##    output_dir: folder for the outputs (created if missing)
##    write_pdf: if True, save the boxplot as boxplot_ldr_cutoff_<figname>.pdf
## output files: flagged_wells_<figname>.csv (the filtered well_df) and the optional boxplot
## returns: (df, gg) -- per-well counts with a 'flagged' column, and the plotnine boxplot
def plot_flagged_wells_ldr(barcode, cell_line, well_df, figname = None, peak_loc = 1.2, output_dir="default_gating", write_pdf=True,
                          hoechst_as_dna=False):
    if figname is None: figname=barcode+'_'+cell_line+'_'+'peak_loc_'+str(peak_loc)
    ### boolean masks instead of a string-built query, so names containing quotes can't break the filter
    well_df = well_df[(well_df['cell_line'] == cell_line) & (well_df['barcode'] == barcode)]
    wells = get_wells(barcode, cell_line)
    df_list = []
    for well in wells:
        df = read_and_rename_well_data(barcode, well, silent=True, hoechst_as_dna=hoechst_as_dna)
        df_list.append(dcf_int.get_counts_df(df=df, barcode=barcode, well=well, peak_loc = peak_loc))
    df = pd.concat(df_list)
    ### build the flagged-well set once for O(1) membership tests
    flagged_wells = set(well_df['well'])
    df['flagged'] = ["flagged" if x in flagged_wells else "not_flagged" for x in df.well]
    gg = ggplot(df, aes(x = 'flagged', y = 'ldr_cutoff')) + geom_boxplot() + geom_jitter()

    os.makedirs(output_dir, exist_ok=True)
    csv_full = os.path.join(output_dir, "flagged_wells_" + figname + ".csv")
    pdf_full = os.path.join(output_dir, "boxplot_ldr_cutoff_" + figname + ".pdf")
    if write_pdf: gg.save(pdf_full, format = "pdf", width = 2.5, height = 3)
    well_df.to_csv(csv_full)
    return(df, gg)
    
def plot_problem_plate(barcode, cell_line, peak_loc=1.2, df_wells=None, scatter=True, silent=True, output_dir="default_gating",
                      hoechst_as_dna=False):
    """Write the gating diagnostics for one problem plate/cell line to <output_dir>/<cell_line>_<barcode>/.

    Always writes the all-wells scatterplots and counts (plot_wells_ldr). If
    `df_wells` (flagged wells) is given, also writes the flagged vs. unflagged
    LDR-cutoff boxplot (plot_flagged_wells_ldr) computed at the same `peak_loc`.
    """
    final_dir = os.path.join(output_dir, cell_line + "_" + barcode)
    os.makedirs(final_dir, exist_ok=True)
    ### plot ldr vs. dna scatterplots for all wells
    plot_wells_ldr(barcode, cell_line, peak_loc=peak_loc, scatter = scatter, silent=silent,
                   figname = None, output_dir=final_dir, hoechst_as_dna=hoechst_as_dna)
    ### plot ldr cutoffs for flagged vs. unflagged wells, using the same peak_loc as above
    if df_wells is not None:
        plot_flagged_wells_ldr(barcode, cell_line, df_wells, peak_loc=peak_loc, output_dir = final_dir,
                               write_pdf=True, hoechst_as_dna=hoechst_as_dna)

def plot_all_problem_plates(peak_loc=1.2, scatter=True, silent=True, output_dir = "default_gating", hoechst_as_dna=False):
    """Run plot_problem_plate for every plate/cell line listed by get_plates_to_regate().

    The full flagged-well table is passed through so each plate's flagged wells
    can be compared with its unflagged wells.
    """
    df_plates, df_wells = get_plates_to_regate()
    n_plates = len(df_plates.barcode)
    print("plotting LDR for " + str(n_plates) + " cell lines/plates")
    ### walk the two columns in lockstep instead of re-listing them on every iteration
    for i, (barcode, cell_line) in enumerate(zip(df_plates.barcode, df_plates.cell_line)):
        print("Plate " + str(i) + ": "+ cell_line + " " + barcode)
        plot_problem_plate(barcode, cell_line, peak_loc=peak_loc, df_wells=df_wells, scatter=scatter,
                           silent=silent, output_dir=output_dir, hoechst_as_dna=hoechst_as_dna)
        

### Todo:
## 1) def plot_problem_plate(): 
        ## steps:
        ## 1) call plot_wells_ldr to create plots of all wells for a problem plate-- write to pdf
        ## 2) load problem well metadata, create box plot of LDR cutoffs for "flagged wells" vs. "unflagged wells" -- print to pdf
        ## 3) save both (plus the csv already written by plot_wells_ldr) to "default_gating/"
                               
## 2) loop over all bad plates/cell lines and plot LDR gating with default options
    ## for each, look at box plots and well-level scatter plots and pick a cutoff in-between the "flagged" and "non-flagged" ldr cutoffs.
    ## put new cutoffs manually into a dict/dataframe

## 3) call plot_ldr_cutoff_change with the new peak_loc values (it takes a peak_loc, not an LDR cutoff) to write new plots and data to files

## Define the re-gating settings (barcode, cell_line, peak_loc, hoechst_as_dna) for each problem plate
def define_regating_df(name = 'regate_df'):
    """Build the table of plates that need LDR regating and store it as a global.

    Each row holds a plate ``barcode``, the ``cell_line`` on it, the
    ``peak_loc`` to use when gating, and whether the Hoechst column should be
    used as the DNA channel (``hoechst_as_dna``). The DataFrame (columns in
    that order) is assigned to this module's globals under ``name``; nothing
    is returned.
    """
    columns = ['barcode', 'cell_line', 'peak_loc', 'hoechst_as_dna']
    rows = [
        ### Time zero plates
        ## note: plate 62 gating is fine with default peak_loc=1.2 on all but one well -- not sure why it was bad before
        ('210406_combo_62', 'SUM1315', 1.2, False),
        ('210423_combo_78', 'SUM185PE', 2, False),
        ### End-time plates
        ## HCC1937 plates: using the Hoechst column as DNA fixes the dead count
        ### note (plate 51): wells I06 and J07 -- almost 500 dead cells dead sub-g1, only ~50 LDR positive -- dna gating issue, not LDR gating?
        ('210226_combo_51', 'HCC1937', 1.2, True),
        ### note (plate 52): only well I11 -- almost 500 dead cells dead sub-g1, only ~50 LDR positive -- dna gating issue, not LDR gating?
        ### note (plate 52): well I18 -- high subg1 dead cells
        ('210226_combo_52', 'HCC1937', 1.2, True),
        ('210226_combo_53', 'HCC1937', 1.2, True),
        ('210226_combo_54', 'HCC1937', 1.2, True),
        ('210226_combo_55', 'HCC1937', 1.2, True),
        ('210226_combo_56', 'HCC1937', 1.2, True),
        ('210226_combo_57', 'HCC1937', 1.2, True),
        ('210302_combo_59', 'HCC1937', 1.2, True),
        ('210302_combo_60', 'HCC1937', 1.2, True),
        ('210302_combo_61', 'HCC1937', 1.2, True),
        ## SUM1315 plates 69-77: regating w/ default values gives low dead count
        ('210406_combo_69', 'SUM1315', 1.2, False),
        ('210406_combo_70', 'SUM1315', 1.2, False),
        ('210406_combo_71', 'SUM1315', 1.2, False),
        # plate 72 -- E11 is the only control well that looks bad after regating -- will be solved by trimmed mean
        ### notes: a few non-control wells look wrong -- E05, possibly E07, F05, F19, etc.
        ### notes: a wide range of LDR cutoffs -- from 2 to 4 -- good cutoff looks like around 3 to 3.25
        ('210406_combo_72', 'SUM1315', 1.2, False),
        ('210406_combo_73', 'SUM1315', 1.2, False),
        ('210406_combo_74', 'SUM1315', 1.2, False),
        ### plate 75: E11 is the only control well that looks bad -- will be solved by trimmed mean
        ('210406_combo_75', 'SUM1315', 1.2, False),
        ('210406_combo_76', 'SUM1315', 1.2, False),
        ('210406_combo_77', 'SUM1315', 1.2, False),
        ## SUM1315 plates gated with peak_loc=2.75
        ('211005_combo_158', 'SUM1315', 2.75, False),
        ('211005_combo_160', 'SUM1315', 2.75, False),
        ('211005_combo_161', 'SUM1315', 2.75, False),
        ('211005_combo_162', 'SUM1315', 2.75, False),
        ('211005_combo_163', 'SUM1315', 2.75, False),
        ('211005_combo_164', 'SUM1315', 2.75, False),
        ('211005_combo_165', 'SUM1315', 2.75, False),
        ('211005_combo_166', 'SUM1315', 2.75, False),
        ('211015_combo_168', 'SUM1315', 2.75, False),
        ('211015_combo_169', 'SUM1315', 2.75, False),
        ('211015_combo_170', 'SUM1315', 2.75, False),
        ('211015_combo_171', 'SUM1315', 2.75, False),
        ('211015_combo_172', 'SUM1315', 2.75, False),
        ('211015_combo_173', 'SUM1315', 2.75, False),
        ('211015_combo_174', 'SUM1315', 2.75, False),
        ('211015_combo_175', 'SUM1315', 2.75, False),
        ('211015_combo_176', 'SUM1315', 2.75, False),
    ]
    globals()[name] = pd.DataFrame(rows, columns=columns)

def gate_well(barcode, well, peak_loc=1.2, silent=False, hoechst_as_dna=False):
    """Run the dead-cell gating on one well and return its counts table.

    The well data come from ``read_and_rename_well_data`` (optionally using the
    Hoechst column as DNA) and are passed to ``dcf_int.get_counts_df`` with the
    given ``peak_loc``.
    """
    return dcf_int.get_counts_df(
        df=read_and_rename_well_data(barcode, well, silent, hoechst_as_dna=hoechst_as_dna),
        barcode=barcode,
        well=well,
        peak_loc=peak_loc,
    )

def regate_wells(silent=True):
    """Regate every well of every plate listed in the global ``regate_df``.

    For each row of ``regate_df`` (prints the row index as progress), the
    wells of that barcode/cell line come from ``get_wells`` and each is gated
    with ``gate_well`` using the row's ``peak_loc`` and ``hoechst_as_dna``.
    Well metadata are loaded on demand. Returns all per-well results
    concatenated into one DataFrame.
    """
    per_plate = []
    for row_idx in range(regate_df.shape[0]):
        print(row_idx)
        barcode, cell_line, peak_loc, hoechst_as_dna = (
            regate_df[col][row_idx] for col in ('barcode', 'cell_line', 'peak_loc', 'hoechst_as_dna')
        )
        if 'meta' not in globals():
            load_well_metadata()
        plate_results = [
            gate_well(barcode, well, peak_loc=peak_loc, silent=silent, hoechst_as_dna=hoechst_as_dna)
            for well in get_wells(barcode, cell_line)
        ]
        per_plate.append(pd.concat(plate_results))
    return pd.concat(per_plate)

def get_ldr_cutoffs_all(peak_loc = 1.2):
    """Collect per-well LDR cutoffs for every plate of every experiment date.

    Dates are the keys of ``folder_dict`` (created with ``define_folder_dict``
    if missing); plates per date come from ``get_barcodes``. Each plate's
    table from ``get_ldr_cutoffs_plate`` is concatenated, first per date and
    then overall. Prints each date and plate as progress.
    """
    if 'folder_dict' not in globals():
        define_folder_dict('folder_dict')

    def _cutoffs_for_date(day):
        print(day)
        frames = []
        for plate_barcode in get_barcodes(day):
            print(plate_barcode)
            frames.append(get_ldr_cutoffs_plate(plate_barcode, peak_loc=peak_loc))
        return pd.concat(frames)

    return pd.concat([_cutoffs_for_date(day) for day in folder_dict])

def get_ldr_cutoffs_fast(peak_loc=1.2):
    """Compute the LDR cutoff for every well listed in the global ``meta`` table.

    Parameters
    ----------
    peak_loc : float
        Forwarded to ``get_ldr_cutoff``. (It used to be ignored, with 1.2
        hard-coded in the call.)

    Returns
    -------
    list
        One cutoff per row of ``meta``, in row order. Progress is printed
        every 1000 wells.
    """
    if 'meta' not in globals(): load_well_metadata()
    cutoffs = []
    for i, (barcode, well) in enumerate(zip(meta.barcode, meta.well)):
        if i % 1000 == 0: print(i)
        cutoffs.append(get_ldr_cutoff(barcode, well, peak_loc=peak_loc, silent=True))
    return cutoffs

def get_ldr_cutoff_i(i, peak_loc=1.2):
    """Return the LDR cutoff for row ``i`` of the global ``meta`` table.

    Defined at module level so it can serve as the ``Pool.map`` worker in
    ``get_ldr_cutoffs_parallel``. ``peak_loc`` is forwarded to
    ``get_ldr_cutoff``. (It used to be ignored, with 1.2 hard-coded.)
    """
    return get_ldr_cutoff(meta.barcode[i], meta.well[i], peak_loc=peak_loc, silent=True)

def get_ldr_cutoffs_parallel(peak_loc=1.2, nproc = 10, batch = 1000):
    """Compute LDR cutoffs for all wells in the global ``meta`` using a process pool.

    Parameters
    ----------
    peak_loc : float
        Forwarded to every worker call. (It used to be dropped.)
    nproc : int
        Number of worker processes.
    batch : int
        Rows per batch. The batch index and its wall time (seconds) are
        printed after each batch.

    Returns
    -------
    list
        Cutoffs in ``meta`` row order.

    A single pool is reused for all batches and is always shut down by the
    ``with`` block. Previously a new pool was created per batch and never
    closed.
    """
    import functools
    if 'meta' not in globals(): load_well_metadata()
    n_total = meta.shape[0]
    if n_total == 0:
        return []
    n_batches = int(np.ceil(n_total / batch))
    # partial of a module-level function is picklable, unlike a lambda
    worker = functools.partial(get_ldr_cutoff_i, peak_loc=peak_loc)
    cutoffs = []
    with multiprocessing.Pool(nproc) as pool:
        for i in range(n_batches):
            print(i)
            tic = time.time()
            start_batch = i * batch
            end_batch = min(start_batch + batch, n_total)
            cutoffs.extend(pool.map(worker, range(start_batch, end_batch)))
            toc = time.time()
            print(str(toc - tic))
    return cutoffs

def get_counts_well(barcode, well, peak_loc=1.2, manual_ldr_cutoff=None, plot=True):
    """Gate one well and optionally plot its LDR distribution.

    Returns the counts table from ``dcf_int.get_counts_df``. When ``plot`` is
    true, ``plot_ldr`` is also called with the same ``peak_loc`` and with the
    manual cutoff as an extra reference line. The gating step and the plot
    each receive their own copy of the well data.
    """
    well_data = read_and_rename_well_data(barcode, well)
    counts = dcf_int.get_counts_df(
        df=well_data.copy(),
        barcode=barcode,
        well=well,
        peak_loc=peak_loc,
        manual_ldr_cutoff=manual_ldr_cutoff,
    )
    if plot:
        plot_ldr(well_data.copy(), peak_loc=peak_loc, add_ldr_line=manual_ldr_cutoff)
    return counts

def kde_plot_wells(barcode, wells, title = "", add_legend=False, smoothing=1):
    """Overlay per-well KDEs of log10(LDR intensity) on one axes and return the figure.

    Negative LDR values are set to NaN before the log transform.

    Parameters
    ----------
    barcode : str
    wells : iterable of str
    title : str
        Axes title.
    add_legend : bool
        Label each curve with the well name and its treatment, taken from
        ``get_well_meta``, and draw a legend.
    smoothing : float
        Passed to seaborn as ``bw_adjust`` for every well. (It used to be
        ignored when ``add_legend`` was False.)

    Returns
    -------
    matplotlib Figure
        Closed in pyplot so it is not auto-displayed; the caller can still
        save or show it.
    """
    fig, ax = plt.subplots()
    for well in wells:
        df = read_and_rename_well_data(barcode, well, silent = True)
        ldrint = df['ldr'].copy()
        ldrint[ldrint < 0] = float('nan')
        logint = np.log10(ldrint)
        kde_kwargs = {'ax': ax, 'alpha': 0.25, 'bw_adjust': smoothing}
        if add_legend:
            meta_sub = get_well_meta(barcode=barcode, well=well)
            trt = (list(meta_sub.agent1)[0] + " " + list(meta_sub.concentration1_chr)[0] + "; "
                   + list(meta_sub.agent2)[0] + " " + list(meta_sub.concentration2_chr)[0])
            kde_kwargs['label'] = well + " " + trt
        sns.kdeplot(logint, **kde_kwargs)
    ax.set_title(title)
    if add_legend:
        ax.legend()
    plt.close(fig)
    return fig

def get_kde_plot_data_well(barcode, well,smoothing=1):
    """Return the (x, y) arrays of seaborn's KDE of log10(LDR) for one well.

    Negative LDR values are set to NaN before the log transform. ``smoothing``
    is passed to seaborn as ``bw_adjust`` (it used to be ignored). The
    temporary figure used to compute the curve is closed before returning.
    """
    fig, ax = plt.subplots()
    df = read_and_rename_well_data(barcode, well, silent = True)
    ldrint = df['ldr'].copy()
    ldrint[ldrint < 0] = float('nan')
    logint = np.log10(ldrint)
    line = sns.kdeplot(logint, ax=ax, color = "grey", alpha=0.5, bw_adjust=smoothing).get_lines()[0]
    x, y = line.get_data()
    plt.close(fig)
    return(x,y)

def get_logldrint(barcode, well):
    """Return the log10 LDR intensities of one well as a plain list.

    Negative intensities become NaN before the log is taken; the well's
    DataFrame itself is not modified.
    """
    ldr = read_and_rename_well_data(barcode, well, silent = True)['ldr']
    non_negative = ldr.where(ldr >= 0)  # negatives -> NaN, everything else kept
    return list(np.log10(non_negative))

def kde_plot_cell_line(barcode, cell_line, n_wells=None, well_start=0, title = "", add_legend=False, smoothing=1):
    """KDE-overlay the LDR distributions of one cell line's wells on a plate.

    Prints the full well list. If ``n_wells`` is given, only that many wells
    starting at index ``well_start`` are drawn. The title defaults to
    "<cell_line> <barcode>". Returns the figure from ``kde_plot_wells``.
    """
    wells = get_wells(barcode, cell_line)
    print(wells)
    selected = wells if n_wells is None else wells[well_start:well_start + n_wells]
    plot_title = title if title != "" else cell_line + " " + barcode
    return kde_plot_wells(barcode, selected, title=plot_title, add_legend=add_legend, smoothing=smoothing)

def kde_plot_avg(barcode, cell_line, smoothing=1):
    """Plot one KDE of log10(LDR) pooled over all wells of a cell line on a plate.

    ``smoothing`` is passed to seaborn as ``bw_adjust``. seaborn's ``kdeplot``
    has no ``smoothing`` argument, and the old call forwarded it to matplotlib
    as an unknown artist property. Returns the figure, closed in pyplot so it
    is not auto-displayed.
    """
    fig, ax = plt.subplots()
    logint_all = []
    for well in get_wells(barcode, cell_line):
        logint_all.extend(get_logldrint(barcode, well))
    sns.kdeplot(logint_all, ax=ax, alpha=0.25, bw_adjust=smoothing).set_title(barcode + " " + cell_line)
    plt.close(fig)
    return(fig)

def kde_plot_all_avg(cell_line, n_barcodes=None, smoothing=1):
    """Overlay, for each plate of a cell line, one KDE of log10(LDR) pooled over its wells.

    Parameters
    ----------
    cell_line : str
    n_barcodes : int or None
        If given, only the first ``n_barcodes`` plates are drawn.
    smoothing : float
        Passed to seaborn as ``bw_adjust``. seaborn's ``kdeplot`` has no
        ``smoothing`` argument; the old call passed it through to matplotlib.

    Prints the number of plates. Returns the figure, closed in pyplot so it is
    not auto-displayed.
    """
    barcodes = get_all_plates_for_cell_line(cell_line)
    if n_barcodes is not None: barcodes = barcodes[0:n_barcodes]
    print(len(barcodes))
    fig, ax = plt.subplots()
    for barcode in barcodes:
        logint_all = []
        for well in get_wells(barcode, cell_line):
            logint_all.extend(get_logldrint(barcode, well))
        sns.kdeplot(logint_all, ax=ax, alpha=0.25, bw_adjust=smoothing)
    ax.set_title(cell_line)
    plt.close(fig)
    return(fig)

def kde_plot_plate(barcode, n_wells = None, output_dir="", filename="test_kde.pdf",
                   add_ldr_line=False, add_median_ldr_line=False, smoothing=1):
    """Save a PDF of per-well log10(LDR) KDEs, one panel per cell line on a plate.

    Parameters
    ----------
    barcode : str
    n_wells : int or None
        If given, only the first ``n_wells`` wells of each cell line are drawn.
    output_dir, filename : str
        The PDF is written to ``os.path.join(output_dir, filename)``.
    add_ldr_line : bool
        Mark each well's ``dcf_int.get_ldrgates`` cutoff with a short red
        dashed line.
    add_median_ldr_line : bool
        Draw the median of that panel's per-well cutoffs in orange. (It used
        to accumulate cutoffs across all earlier panels.)
    smoothing : float
        Passed to seaborn as ``bw_adjust`` (seaborn's kdeplot has no
        ``smoothing`` argument).

    Returns None. Nothing is written when the plate has no cell lines. The
    page keeps its usual 3x2 grid but grows if a plate has more than six cell
    lines. Previously fewer than four cell lines raised IndexError and more
    than six were silently dropped.
    """
    cell_lines = get_cell_lines_on_plate(barcode)
    pdf_full = os.path.join(output_dir, filename)
    nb_plots = len(cell_lines)
    if nb_plots != 6: print(barcode + ": " + str(nb_plots) + " cell lines")
    if nb_plots == 0: return(None)
    ncols = 2
    nrows = max(3, int(np.ceil(nb_plots / ncols)))
    xmin, xmax = -2, 6
    x_ticks = np.arange(np.ceil(xmin), np.floor(xmax) + 1)
    need_cutoffs = add_ldr_line or add_median_ldr_line
    fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(9, 4*nrows),
                            layout="constrained", sharex="all", squeeze=False)
    for i, cell_line in enumerate(cell_lines):
        ax = axs[i // ncols, i % ncols]
        wells = get_wells(barcode, cell_line)
        if n_wells is not None: wells = wells[0:n_wells]
        ldr_cutoffs = []
        for well in wells:
            logint = get_logldrint(barcode, well)
            sns.kdeplot(logint, ax=ax, alpha=0.25, bw_adjust=smoothing).set_title(cell_line)
            if not need_cutoffs:
                continue  # get_ldrgates is only needed for the optional lines
            ldr_gates = dcf_int.get_ldrgates(np.array([10**x for x in logint]))[0]
            ldr_cutoff = ldr_gates[1]
            ldr_cutoffs.append(ldr_cutoff)
            if add_ldr_line:
                ax.axvline(ldr_cutoff, ymin=0, ymax=0.1, color="red", linestyle="--")
        # act on this axes explicitly; pyplot's xlim/xticks target the "current" axes
        ax.set_xlim((xmin, xmax))
        ax.set_xticks(x_ticks)
        ax.tick_params(labelbottom=True)
        if add_median_ldr_line and ldr_cutoffs:
            ax.axvline(x=np.median(ldr_cutoffs), ymin=0, ymax=1, color="orange")
    fig.suptitle(barcode)
    fig.savefig(pdf_full)
    plt.close(fig)  # avoid accumulating open figures when called in a loop
    return(None)

def kde_plot_all_plates(smoothing=1):
    """Write ``density_plots/<plate>.pdf`` via ``kde_plot_plate`` for every plate.

    Dates are the keys of ``folder_dict`` (created with ``define_folder_dict``
    if missing). Plates per date come from ``get_barcodes`` and are visited in
    sorted order. PDFs that already exist are not regenerated. Returns None.
    """
    if 'folder_dict' not in globals():
        define_folder_dict('folder_dict')
    out_dir = "density_plots"
    for day in folder_dict.keys():
        print(day)
        day_plates = get_barcodes(day)
        day_plates.sort()
        for plate in day_plates:
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            pdf_name = plate + ".pdf"
            if not os.path.exists(os.path.join(out_dir, pdf_name)):
                kde_plot_plate(plate, output_dir=out_dir, filename=pdf_name, smoothing=smoothing)
                print(plate)
                continue
            print(plate)
            print("pdf already written")
    return None

def kde_plot_all_plates_median(smoothing=1):
    """Write ``density_plots_median_ldr/<plate>_median.pdf`` for every plate.

    Same traversal as ``kde_plot_all_plates`` (dates from ``folder_dict``,
    plates from ``get_barcodes`` in sorted order, existing PDFs skipped), but
    each plot is made with ``add_median_ldr_line=True``. Returns None.
    """
    if 'folder_dict' not in globals():
        define_folder_dict('folder_dict')
    out_dir = "density_plots_median_ldr"
    for day in folder_dict.keys():
        print(day)
        day_plates = get_barcodes(day)
        day_plates.sort()
        for plate in day_plates:
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            pdf_name = plate + "_median.pdf"
            if not os.path.exists(os.path.join(out_dir, pdf_name)):
                kde_plot_plate(plate, output_dir=out_dir, filename=pdf_name,
                               add_median_ldr_line=True, smoothing=smoothing)
                print(plate)
                continue
            print(plate)
            print("pdf already written")
    return None

def get_all_plates_for_cell_line(cell_line):
    """Return the unique plate barcodes on which ``cell_line`` appears in ``meta``.

    Uses a boolean mask instead of a ``DataFrame.query`` string built by
    concatenation. Cell-line names containing quotes would otherwise break
    (or inject into) the query expression.

    Returns
    -------
    numpy.ndarray
        Barcodes in order of first appearance.
    """
    if 'meta' not in globals(): load_well_metadata('meta')
    return meta.loc[meta['cell_line'] == cell_line, 'barcode'].unique()

def kde_plot_all_plates_cell_line(cell_line, n_wells = None, n_barcodes=None, output_dir="",
                                  filename="test_kde_cell_line.pdf", smoothing=1):
    """Write a multi-page PDF of per-well log10(LDR) KDEs, one panel per plate of a cell line.

    Plates come from ``get_all_plates_for_cell_line``, in sorted order, six
    per page (3 rows x 2 columns).

    Parameters
    ----------
    cell_line : str
    n_wells : int or None
        If given, only the first ``n_wells`` wells of each plate are drawn.
    n_barcodes : int or None
        If given, only the first ``n_barcodes`` plates are drawn. This is
        capped at the number available; it used to raise IndexError when
        larger.
    output_dir, filename : str
        The PDF is written to ``os.path.join(output_dir, filename)``.
    smoothing : float
        Passed to seaborn as ``bw_adjust`` (seaborn's kdeplot has no
        ``smoothing`` argument).

    Returns None.
    """
    barcodes = sorted(get_all_plates_for_cell_line(cell_line))
    nb_plots = len(barcodes)
    if n_barcodes is not None: nb_plots = min(n_barcodes, nb_plots)
    ncols, nrows = 2, 3
    per_page = ncols * nrows
    xmin, xmax = -2, 6
    x_ticks = np.arange(np.ceil(xmin), np.floor(xmax) + 1)
    pdf_full = os.path.join(output_dir, filename)
    # the context manager finalises the PDF even if plotting fails part-way
    with PdfPages(pdf_full) as pdf_pages:
        for i in range(nb_plots):
            barcode = barcodes[i]
            row, col = divmod(i % per_page, ncols)
            print(str(row) + ", " + str(col))
            if i % per_page == 0:
                fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(9, 12),
                                        layout="constrained", sharex="all")
                fig.suptitle(cell_line)
            wells = get_wells(barcode, cell_line)
            if n_wells is not None: wells = wells[0:n_wells]
            ax = axs[row, col]
            for well in wells:
                logint = get_logldrint(barcode, well)
                sns.kdeplot(logint, ax=ax, alpha=0.25, bw_adjust=smoothing).set_title(barcode)
            # set limits on this axes directly: the old pyplot-level calls ran after
            # plt.close() and targeted a stray, newly created figure
            ax.set_xlim((xmin, xmax))
            ax.set_xticks(x_ticks)
            ax.tick_params(labelbottom=True)
            if (i + 1) % per_page == 0 or (i + 1) == nb_plots:
                # constrained layout is already active, so no tight_layout() here
                pdf_pages.savefig(fig)
                plt.close(fig)
    return(None)

def get_mixture_cutoff(model, n_iter=20, prob_tol=0.0001, update_tol=0.0001, silent=True):
    """Find the x where a fitted 2-component mixture is 50/50 between its components.

    Bisects between the two component means on the posterior probability
    (``model.predict_proba``) of the component with the smaller mean.

    Parameters
    ----------
    model : fitted two-component mixture
        Must expose ``distributions[k].means.item()`` and
        ``predict_proba([[x]])`` (a pomegranate ``GeneralMixtureModel`` in
        this file).
    n_iter : int
        Maximum number of bisection steps.
    prob_tol : float
        Stop only once |0.5 - P(lower component)| at the last probe is <= this ...
    update_tol : float
        ... and the cutoff moved by at most this much in the last step. The
        old code compared a *signed* step, so every upward step counted as
        converged and it could stop far from the crossing.
    silent : bool
        If False, print per-iteration diagnostics.

    Returns
    -------
    float
        The estimated cutoff.
    """
    mean1 = model.distributions[0].means.item()
    mean2 = model.distributions[1].means.item()
    # column of predict_proba belonging to the component with the smaller mean
    low_idx = 0 if mean1 < mean2 else 1
    lo, hi = sorted((mean1, mean2))
    cutoff = (lo + hi) / 2
    error = 0.5
    diff = 1
    i = 0
    while (abs(error) > prob_tol or abs(diff) > update_tol) and i < n_iter:
        i += 1
        if not silent:
            print(i)
            print(cutoff)
        last_cutoff = cutoff if i > 1 else lo
        prob_both = model.predict_proba([[cutoff]])
        prob = prob_both[0][low_idx].item()
        error = 0.5 - prob
        # still mostly "lower component" here -> crossing lies above the probe
        if prob > 0.5:
            lo = cutoff
        else:
            hi = cutoff
        cutoff = (lo + hi) / 2
        diff = last_cutoff - cutoff
        if not silent:
            print(prob_both)
            print(prob)
    return cutoff

def get_peaks_ldr(x, y, smoothing=1, first_peak_min=0.5,
                 min_prominence=0, min_peak_height=0.02, min_peak_distance=1,
                 single_peak_cutoff=3, silent=True):
    """Pick the live-cell and dead-cell peaks from a KDE of log10 LDR intensity.

    Parameters
    ----------
    x, y : numpy.ndarray
        The KDE curve, e.g. the line data returned by ``sns.kdeplot``: ``x``
        is log10 LDR intensity and ``y`` the density.
    smoothing : float
        Unused; kept for call compatibility.
    first_peak_min : float
        Peaks at ``x <= first_peak_min`` are discarded.
    min_prominence, min_peak_height : float
        A peak must exceed both thresholds to be kept.
    min_peak_distance : float
        Minimum x-distance between the tallest ("main") peak and the
        secondary peak. The secondary peak is searched above a main peak that
        lies below ``single_peak_cutoff``, and below it otherwise.
    single_peak_cutoff : float
        A lone peak below this value is reported as the live peak
        (``peak1``), otherwise as the dead peak (``peak2``).
    silent : bool
        If False, print which case was hit.

    Returns
    -------
    dict
        - ``peak1`` / ``peak2``: x of the live / dead peak (NaN if absent).
        - ``peak1_height`` / ``peak2_height``: their densities (NaN if absent).
        - ``shelf``: "none", or "peak1" / "peak2" when that entry is a shelf
          found by ``find_shelf`` rather than a true peak.
        - ``peak_props``: scipy peak properties of the retained true peaks.
        - ``peak_props_orig`` / ``peak_ldrs_orig``: all peaks before filtering.

        If no peak survives filtering, a warning is printed and all peak
        values are NaN. This case used to raise IndexError.
    """
    nan = float('nan')
    x = x.copy()
    y = y.copy()
    peak_locs, peak_props = find_peaks(y.copy(), height=0, prominence=0)
    peak_ldrs = x[peak_locs]
    peak_props_orig = peak_props.copy()
    peak_ldrs_orig = peak_ldrs.copy()

    def _subset(ldrs, props, indexes):
        # keep only the peaks at `indexes`, in that order
        return ([ldrs[ind] for ind in indexes],
                {key: np.array([val[ind] for ind in indexes]) for key, val in props.items()})

    def _result(peak1, peak2, height1, height2, shelf, props):
        return {'peak1': peak1, 'peak2': peak2, 'peak1_height': height1, 'peak2_height': height2,
                'shelf': shelf, 'peak_props': props, 'peak_props_orig': peak_props_orig,
                'peak_ldrs_orig': peak_ldrs_orig}

    ### get rid of invalid peaks
    keep = [i for i in range(len(peak_ldrs))
            if peak_ldrs[i] > first_peak_min
            and peak_props['prominences'][i] > min_prominence
            and peak_props['peak_heights'][i] > min_peak_height]
    peak_ldrs, peak_props = _subset(peak_ldrs, peak_props, keep)
    if len(peak_ldrs) == 0:
        print("Warning! No valid peaks")
        return _result(nan, nan, nan, nan, "none", peak_props)

    ### main peak = tallest; secondary = next-tallest peak far enough away on the expected side
    peak_order = np.argsort(-np.array(peak_props['peak_heights']))
    main = peak_order[0]
    secondary = None
    for cand in peak_order[1:]:
        if peak_ldrs[main] < single_peak_cutoff:
            far_enough = peak_ldrs[cand] - peak_ldrs[main] > min_peak_distance
        else:
            far_enough = peak_ldrs[main] - peak_ldrs[cand] > min_peak_distance
        if far_enough:
            secondary = cand
            break

    if secondary is not None:
        ### report both peaks in order of their ldr intensity
        order = [main, secondary] if peak_ldrs[main] < peak_ldrs[secondary] else [secondary, main]
        peak_ldrs, peak_props = _subset(peak_ldrs, peak_props, order)
        return _result(peak_ldrs[0], peak_ldrs[1],
                       peak_props['peak_heights'][0], peak_props['peak_heights'][1],
                       "none", peak_props)

    ### no secondary peak: look for a shelf next to the main peak
    if not silent: print("search for shelf")
    shelf_dict = find_shelf(x.copy(), y.copy(), main_peak_ldr=peak_ldrs[main],
                            min_peak_height=min_peak_height, first_peak_min=first_peak_min)
    peak_ldrs, peak_props = _subset(peak_ldrs, peak_props, [main])
    main_ldr = peak_ldrs[0]
    main_height = peak_props['peak_heights'][0]
    if shelf_dict is not None:
        shelf_ldr = shelf_dict['ldr']
        shelf_height = shelf_dict['height']
        if shelf_ldr < main_ldr:
            return _result(shelf_ldr, main_ldr, shelf_height, main_height, "peak1", peak_props)
        return _result(main_ldr, shelf_ldr, main_height, shelf_height, "peak2", peak_props)

    ### single peak: live if below single_peak_cutoff, dead otherwise
    if main_ldr < single_peak_cutoff:
        if not silent: print("only live cell peak found")
        return _result(main_ldr, nan, main_height, nan, "none", peak_props)
    if not silent: print("only dead cell peak found")
    return _result(nan, main_ldr, nan, main_height, "none", peak_props)

def get_ldr_cutoff_mixture(logint, peak_ldrs, show=True, mean_tol=0.4, silent=False):
    """Estimate the live/dead LDR cutoff with a 2-component Normal mixture.

    The components start at the two peak positions ``peak_ldrs`` (live,
    dead) and are fitted with pomegranate's ``GeneralMixtureModel`` using
    per-point priors:

    - points below ``peak_ldrs[0] + 0.5`` are assigned to the live component;
    - remaining points above ``peak_ldrs[1] - 0.5`` are assigned to the dead
      component;
    - everything else gets 0.5/0.5.

    The cutoff is where the fitted posterior is 0.5 (``get_mixture_cutoff``).

    Parameters
    ----------
    logint : array-like
        log10 LDR intensities.
    peak_ldrs : sequence of two floats
        Live and dead peak positions.
    show : bool
        Plot the histogram, both components, the mixture and the cutoff. A
        plotting failure is reported but no longer discards the cutoff.
    mean_tol : float
        Maximum allowed drift of each fitted mean from its starting peak.
    silent : bool
        Suppress failure messages.

    Returns
    -------
    float
        The cutoff, or NaN if:
        - fitting or cutoff search raised;
        - a fitted mean moved ``mean_tol`` or more from its peak;
        - the cutoff is not strictly between the two peaks.
    """
    peak_ldrs = list(peak_ldrs)
    values = np.array(logint).reshape(-1, 1)
    X = torch.tensor(values).float()
    try:
        d1 = Normal(means=[torch.tensor(peak_ldrs[0])], frozen=False)
        d2 = Normal(means=[torch.tensor(peak_ldrs[1])], frozen=False)
        x_col = X[:, 0].numpy()
        live = x_col < peak_ldrs[0] + 0.5
        dead = ~live & (x_col > peak_ldrs[1] - 0.5)
        priors = np.full((len(x_col), 2), 0.5)
        priors[live] = (1, 0)
        priors[dead] = (0, 1)
        model = GeneralMixtureModel([d1, d2], verbose=False, frozen=False, tol=0.001,
                                    max_iter=100, inertia=0.9).fit(X, priors=priors)
    except Exception:
        if not silent: print("mixture model failed")
        return(float('nan'))
    try:
        ldr_cutoff = get_mixture_cutoff(model, silent=True)
    except Exception:
        if not silent: print("get mixture cutoff failed")
        return(float('nan'))
    if show:
        try:
            x = np.arange(np.min(values), np.max(values), 0.1)
            y1 = model.distributions[0].probability(x.reshape(-1, 1))
            y2 = model.distributions[1].probability(x.reshape(-1, 1))
            y3 = model.probability(x.reshape(-1, 1))
            plt.figure(figsize=(6, 3))
            plt.hist(X[:, 0], density=True, bins=30)
            plt.plot(x, y1, color="green", label="Normal1")
            plt.axvline(peak_ldrs[0], color="green", label="live peak")
            plt.plot(x, y2, color="red", label="Normal2")
            plt.axvline(peak_ldrs[1], color="red", label="dead peak")
            plt.plot(x, y3, color="purple", label="Mixture")
            plt.axvline(ldr_cutoff, color="orange", label="LDR cutoff")
            plt.legend(loc=(1.05, 0.4))
            plt.tight_layout()
            plt.show()
        except Exception:
            # a display problem should not throw away a valid fit
            if not silent: print("plotting mixture model failed")
    mean1 = model.distributions[0].means.item()
    mean2 = model.distributions[1].means.item()
    check1 = abs(mean1 - peak_ldrs[0]) < mean_tol
    check2 = abs(mean2 - peak_ldrs[1]) < mean_tol
    check3 = peak_ldrs[0] < ldr_cutoff < peak_ldrs[1]
    if check1 and check2 and check3:
        return(ldr_cutoff)
    if not silent: print("mixture model fitting failed")
    return(float('nan'))

def get_ldr_cutoff_valley(x,y, peak_ldrs, silent=False):
    """Return the x position of the deepest KDE valley between the two peaks.

    Parameters
    ----------
    x, y : array-like
        KDE curve (x = log10 LDR, y = density).
    peak_ldrs : sequence of two floats
        Live and dead peak positions. Only points with
        ``peak_ldrs[0] < x < peak_ldrs[1]`` are searched.
    silent : bool
        Suppress the failure message.

    Returns
    -------
    float
        The LDR (x) value of the lowest local minimum of ``y`` strictly
        between the peaks, or NaN if there is none. The old code filtered the
        density values ``y`` against the LDR peak positions and returned a
        density rather than an LDR value.
    """
    if len(peak_ldrs) < 2:
        if not silent: print("valley method failed")
        return(float('nan'))
    x_arr = np.asarray(x, dtype=float)
    y_arr = np.asarray(y, dtype=float)
    ### restrict to the x range between the two peaks
    between = (x_arr > peak_ldrs[0]) & (x_arr < peak_ldrs[1])
    x_sub = x_arr[between]
    y_sub = y_arr[between]
    ### valleys of y are peaks of -y
    valley_locs, _ = find_peaks(-y_sub, height=float('-Inf'), prominence=0)
    if len(valley_locs) == 0:
        if not silent: print("valley method failed")
        return(float('nan'))
    deepest = valley_locs[np.argmin(y_sub[valley_locs])]
    return(float(x_sub[deepest]))

def get_ldrgates_new(ldrint, smoothing=1, show=True, first_peak_min=0.5,
                     min_prominence=0, min_peak_height=0.02, min_peak_distance=0.5,
                     single_peak_cutoff=3,
                     mixture_backup_method="valley", silent=True,
                    return_peaks_only=False):
    """Find the live/dead LDR cutoff for one well from raw LDR intensities.

    Steps:

    1. Keep positive intensities and take finite log10 values.
    2. Compute a seaborn KDE (``bw_adjust=smoothing``) and pick the live/dead
       peaks with ``get_peaks_ldr``.
    3. Pick the cutoff:
       - one peak: the 99th percentile if it is the live peak
         (< ``single_peak_cutoff``), else the 1st percentile;
       - two peaks: the mixture-model cutoff, else the valley cutoff, else
         the midpoint of the peaks.

    ``mixture_backup_method`` is currently unused (the fallback order is fixed).
    With ``show`` the KDE figure is left open for display, and the mixture fit
    is plotted.

    Returns
    -------
    dict
        If ``return_peaks_only``, the ``get_peaks_ldr`` dict. Otherwise that
        dict plus 'method_used' ("mixture", "valley", "middle" or "none"),
        'ldr_cutoff', 'ldr_gates' (``[-inf, cutoff]``), 'ldr_cutoff_mixture',
        'ldr_cutoff_valley' and 'ldr_cutoff_middle'. When no valid peak is
        found, the cutoffs are NaN; this used to raise UnboundLocalError.
    """
    nan = float('nan')
    ldrint = ldrint[ldrint > 0]
    logint = np.log10(ldrint)
    logint = logint[np.isfinite(logint)].copy()
    # compute the KDE on a private, immediately-closed figure; without an explicit
    # ax seaborn would draw on (and we would then close) the caller's current figure
    kde_fig, kde_ax = plt.subplots()
    x, y = sns.kdeplot(logint, ax=kde_ax, bw_adjust=smoothing).get_lines()[0].get_data()
    plt.close(kde_fig)
    alive_dead_peaks = get_peaks_ldr(x.copy(), y.copy(), smoothing=smoothing, first_peak_min=first_peak_min,
                                     min_prominence=min_prominence, min_peak_height=min_peak_height,
                                     min_peak_distance=min_peak_distance, single_peak_cutoff=single_peak_cutoff,
                                     silent=silent)
    if return_peaks_only: return(alive_dead_peaks)
    peak_ldrs = [p for p in (alive_dead_peaks['peak1'], alive_dead_peaks['peak2']) if not np.isnan(p)]
    if show:
        # left open (as before) so it can be displayed, e.g. inline in a notebook
        show_fig, show_ax = plt.subplots()
        sns.kdeplot(logint, ax=show_ax, bw_adjust=smoothing)
    method_used = "none"
    ldr_cutoff = ldr_cutoff_mixture = ldr_cutoff_valley = ldr_cutoff_middle = nan
    if len(peak_ldrs) == 1:
        if peak_ldrs[0] < single_peak_cutoff:
            ### assume almost all cells are alive
            if not silent: print("one peak. almost all cells alive?")
            ldr_cutoff = np.quantile(logint, 0.99)
        else:
            if not silent: print("one peak. almost all cells dead?")
            ldr_cutoff = np.quantile(logint, 0.01)
        ldr_cutoff_mixture = ldr_cutoff_valley = ldr_cutoff_middle = ldr_cutoff
    elif len(peak_ldrs) > 1:
        ldr_cutoff_mixture = get_ldr_cutoff_mixture(logint.copy(), peak_ldrs, show=show, silent=silent)
        ldr_cutoff_valley = get_ldr_cutoff_valley(x, y, peak_ldrs, silent=silent)
        ldr_cutoff_middle = (peak_ldrs[0] + peak_ldrs[1]) / 2
        if not np.isnan(ldr_cutoff_mixture):
            method_used = "mixture"
            ldr_cutoff = ldr_cutoff_mixture
        elif not np.isnan(ldr_cutoff_valley):
            method_used = "valley"
            ldr_cutoff = ldr_cutoff_valley
        else:
            method_used = "middle"
            ldr_cutoff = ldr_cutoff_middle
    out = alive_dead_peaks
    out['method_used'] = method_used
    out['ldr_cutoff'] = ldr_cutoff
    out['ldr_gates'] = np.array([-np.inf, ldr_cutoff])
    out['ldr_cutoff_mixture'] = ldr_cutoff_mixture
    out['ldr_cutoff_valley'] = ldr_cutoff_valley
    out['ldr_cutoff_middle'] = ldr_cutoff_middle
    if not silent: print(out)
    return(out)

def get_ldrgates_new_well(barcode, well, smoothing=1.1, show=True, first_peak_min=0.5,
                          min_prominence=0, min_peak_height=0.02, min_peak_distance=0.5,
                          single_peak_cutoff=3,
                          silent=True, return_peaks_only=False):
    """Compute LDR gates for a single well.

    Reads the well with read_and_rename_well_data (silently), takes a copy of
    its 'ldr' column and hands it to get_ldrgates_new together with the
    peak-detection settings.

    Returns:
        Whatever get_ldrgates_new returns for that intensity series.

    NOTE(review): `first_peak_min` is accepted but never forwarded, so it has
    no effect here -- confirm whether get_ldrgates_new should receive it.
    """
    well_df = read_and_rename_well_data(barcode, well, silent=True)
    gate_kwargs = dict(
        smoothing=smoothing,
        show=show,
        silent=silent,
        min_prominence=min_prominence,
        min_peak_height=min_peak_height,
        min_peak_distance=min_peak_distance,
        single_peak_cutoff=single_peak_cutoff,
        return_peaks_only=return_peaks_only,
    )
    return get_ldrgates_new(well_df['ldr'].copy(), **gate_kwargs)


def get_peaks_cell_line(barcode, cell_line, method="mixture", smoothing=1, df_only=True, return_peaks_only=False):
    """Collect per-well LDR peak/cutoff summaries for one cell line on one plate.

    For every well returned by get_wells(barcode, cell_line), run
    get_ldrgates_new_well (show=True, silent=True) and keep one row with the
    selected fields. Each well name is printed as a progress indicator.

    Args:
        barcode: plate barcode.
        cell_line: cell line used to select the wells.
        method: unused; kept only so existing callers keep working.
        smoothing: forwarded to get_ldrgates_new_well.
        df_only: if True return only the DataFrame, otherwise also the
            per-well 'peak_props' list.
        return_peaks_only: forwarded; also selects the reduced column set.

    Returns:
        A DataFrame with one row per well (plus the 'index' column left by
        reset_index), or {'df': DataFrame, 'peak_props': list} when df_only is
        False. If no wells are found the DataFrame is empty but keeps the same
        columns instead of pd.concat raising ValueError.
    """
    # The column set depends only on return_peaks_only, so decide it once.
    if return_peaks_only:
        cols = ['barcode', 'cell_line', 'well', 'peak1', 'peak2',
                'peak1_height', 'peak2_height', 'shelf']
    else:
        cols = ['barcode', 'cell_line', 'well', 'ldr_cutoff', 'peak1', 'peak2',
                'peak1_height', 'peak2_height', 'shelf', 'method_used',
                'ldr_cutoff_mixture', 'ldr_cutoff_valley', 'ldr_cutoff_middle']
    peak_list = []
    df_list = []
    for well in get_wells(barcode, cell_line):
        print(well)
        pdict = get_ldrgates_new_well(barcode, well, show=True, return_peaks_only=return_peaks_only,
                                      smoothing=smoothing, silent=True)
        pdict['barcode'] = barcode
        pdict['cell_line'] = cell_line
        pdict['well'] = well
        row = {k: v for k, v in pdict.items() if k in cols}
        # Indexing with `cols` orders the columns and still raises KeyError if
        # get_ldrgates_new did not report one of the expected fields.
        df_list.append(pd.DataFrame(data=row, index=[0])[cols])
        # .get(): peak_props are only needed when df_only is False, so a
        # missing entry should not abort the whole collection.
        peak_list.append(pdict.get('peak_props'))
    if df_list:
        df_full = pd.concat(df_list).reset_index()
    else:
        df_full = pd.DataFrame(columns=cols).reset_index()
    if df_only:
        return df_full
    ### return the peak lists as well
    return {'df': df_full, 'peak_props': peak_list}

def get_peaks_barcode(barcodes, smoothing=1, df_only=True, return_peaks_only=False):
    """Collect per-well LDR peak/cutoff summaries for every cell line on the given plates.

    Args:
        barcodes: iterable of plate barcodes. A single barcode string is also
            accepted and treated as a one-element list (otherwise the loop
            would iterate over the string's characters).
        smoothing: forwarded to get_ldrgates_new_well.
        df_only: if True return only the DataFrame, otherwise also the
            per-well 'peak_props' list.
        return_peaks_only: forwarded; also selects the reduced column set.

    Returns:
        A DataFrame with one row per well (plus the 'index' column left by
        reset_index, which get_all_peaks drops before writing), or
        {'df': DataFrame, 'peak_props': list} when df_only is False. When no
        wells are found the DataFrame is empty but keeps the same columns
        instead of pd.concat raising ValueError.
    """
    if isinstance(barcodes, str):
        barcodes = [barcodes]
    # The column set depends only on return_peaks_only, so decide it once.
    if return_peaks_only:
        cols = ['barcode', 'cell_line', 'well', 'peak1', 'peak2',
                'peak1_height', 'peak2_height', 'shelf']
    else:
        cols = ['barcode', 'cell_line', 'well', 'ldr_cutoff', 'peak1', 'peak2',
                'peak1_height', 'peak2_height', 'shelf', 'method_used',
                'ldr_cutoff_mixture', 'ldr_cutoff_valley', 'ldr_cutoff_middle']
    peak_list = []
    df_list = []
    for barcode in barcodes:
        for cell_line in get_cell_lines_on_plate(barcode):
            for well in get_wells(barcode, cell_line):
                pdict = get_ldrgates_new_well(barcode, well, show=False, return_peaks_only=return_peaks_only,
                                              smoothing=smoothing, silent=True)
                pdict['barcode'] = barcode
                pdict['cell_line'] = cell_line
                pdict['well'] = well
                row = {k: v for k, v in pdict.items() if k in cols}
                # Indexing with `cols` orders the columns and still raises
                # KeyError if an expected field was not reported.
                df_list.append(pd.DataFrame(data=row, index=[0])[cols])
                # .get(): peak_props are only needed when df_only is False.
                peak_list.append(pdict.get('peak_props'))
    if df_list:
        df_full = pd.concat(df_list).reset_index()
    else:
        df_full = pd.DataFrame(columns=cols).reset_index()
    if df_only:
        return df_full
    ### return the peak lists as well
    return {'df': df_full, 'peak_props': peak_list}

def get_all_peaks(smoothing=1, out_file="peak_locations_new.csv"):
    """Compute peak summaries for every plate not yet recorded in `out_file`.

    Barcodes are taken from the global `meta` table. Barcodes already present
    in `out_file` are skipped, so the function can be re-run to resume an
    interrupted job. Each plate's rows (without the 'index' column) are
    appended to `out_file` as soon as that plate is finished; the CSV header is
    written only when the file is created by this call.

    Args:
        smoothing: forwarded to get_peaks_barcode.
        out_file: CSV file to create or append to.

    Returns:
        The DataFrame for the last plate processed, or None if every plate
        was already done. (Previously this raised UnboundLocalError when
        `out_file` did not exist yet, or when there was nothing to do.)
    """
    # NOTE(review): presumably this is what binds the global `meta` read
    # below -- confirm.
    load_well_metadata('meta')
    barcodes = meta['barcode'].unique()
    write_header = not os.path.exists(out_file)
    done_barcodes = set()
    if not write_header:
        done_barcodes = set(pd.read_csv(out_file)['barcode'])
    barcodes = sorted(bc for bc in barcodes if bc not in done_barcodes)
    res = None
    for barcode in barcodes:
        print(barcode)
        res = get_peaks_barcode([barcode], smoothing=smoothing)
        # mode='a' also creates the file on the first write
        res.drop(columns=['index']).to_csv(out_file, mode='a', header=write_header, index=False)
        write_header = False
    return res

## Steps of the shelf search (as implemented in find_shelf_old below):
# 1) identify the ends of the flat left and right shelves of the density
# 2) if the main peak < single_peak_cutoff (default 3), assume it's the live peak; otherwise assume it's the dead peak
# 3) keep only the x values that are:
#        a) in between those two shelves
#        b) at least min_peak_distance (default 1) away from the main peak
#        c) to the right of the peak (if main peak < 3), or to the left of the peak (if main peak >= 3)
# 4) among those, identify the point whose slope (first derivative) is closest to zero
# 5) report that point as the second peak as long as its slope is close enough to 0 (|slope| < tol)

### input x and y from the sns.kdeplot function
def find_side_shelves(x, y, tol=0.03):
    """Locate where the flat tails ("side shelves") of a density curve end.

    Args:
        x: evenly spaced grid (e.g. from sns.kdeplot); the spacing is taken
            from the first step.
        y: density values on that grid.
        tol: a point belongs to a shelf while |dy/dx| <= tol.

    Returns:
        {'end_left': x of the first point with |dy/dx| > tol,
         'end_right': x of the last point with |dy/dx| > tol}.
        If no point is steeper than tol (an essentially flat curve), there is
        no shelf boundary to report and the full range (x[0], x[-1]) is
        returned instead of raising UnboundLocalError.
    """
    x = np.asarray(x)
    dx = np.diff(x)[0]
    grad = np.gradient(np.asarray(y), dx)  ## first derivative
    steep = np.flatnonzero(np.abs(grad) > tol)
    if steep.size == 0:
        return {'end_left': x[0], 'end_right': x[-1]}
    return {'end_left': x[steep[0]], 'end_right': x[steep[-1]]}

def find_shelf_old(x, y, main_peak_ldr, min_peak_distance=1, single_peak_cutoff=3, tol=0.03):
    """Older shelf finder: the flattest point of the density beside the main peak.

    Keeps the x values strictly between the two side shelves found by
    find_side_shelves. It then searches only on the far side of the main peak:
    to the right when main_peak_ldr < single_peak_cutoff (main peak assumed
    alive), otherwise to the left (main peak assumed dead). The search starts
    min_peak_distance away from the main peak.

    Args:
        x, y: evenly spaced KDE grid and density values.
        main_peak_ldr: x position of the main peak.
        min_peak_distance: minimum distance from the main peak.
        single_peak_cutoff: alive/dead decision threshold for the main peak.
        tol: slope tolerance for both the side shelves and the result.

    Returns:
        The x value with the smallest |dy/dx| in that region if that slope is
        < tol, otherwise None. Also None when no grid points remain after
        filtering (previously np.argmin raised ValueError).
    """
    sides = find_side_shelves(x, y, tol)
    x = np.asarray(x)
    dx = np.diff(x)[0]
    grad = np.gradient(y, dx)  ## first derivative
    # Boolean masks replace the former list(x).index(val) lookups, which were
    # O(n^2) and picked the wrong index if x contained repeated values.
    keep = (x > sides['end_left']) & (x < sides['end_right'])
    if main_peak_ldr < single_peak_cutoff:  ## assume it's the alive peak
        keep &= x > main_peak_ldr + min_peak_distance
    else:  ### assume it's the dead peak
        keep &= x < main_peak_ldr - min_peak_distance
    if not keep.any():
        return None
    x_kept = x[keep]
    grad_kept = grad[keep]
    ind_shelf = np.argmin(np.abs(grad_kept))
    if np.abs(grad_kept[ind_shelf]) < tol:
        return x_kept[ind_shelf]
    return None

def find_shelf(x, y, main_peak_ldr, min_peak_distance=1, single_peak_cutoff=3, tol=0.03, 
               min_peak_height=0.02, first_peak_min=0.5, max_slope=0.3):
    """Look for a "shelf" beside the main LDR peak, to be treated as a second peak.

    The search region is on the far side of the main peak:
    - to its right if main_peak_ldr < single_peak_cutoff (main peak assumed
      alive), up to the right side shelf from find_side_shelves(x, y, tol);
    - otherwise to its left (main peak assumed dead), down to the left side
      shelf.
    Points closer than min_peak_distance to the main peak are excluded. Inside
    the region, local extrema of the first derivative are found: maxima of
    dy/dx on the right, minima on the left. The extremum whose slope is
    closest to zero is the candidate.

    Args:
        x, y: evenly spaced KDE grid and density (e.g. from sns.kdeplot).
        main_peak_ldr: x position of the main peak.
        min_peak_distance: minimum distance of the shelf from the main peak.
        single_peak_cutoff: alive/dead decision threshold for the main peak.
        tol: slope tolerance passed to find_side_shelves.
        min_peak_height: the density at the shelf must exceed this.
        first_peak_min: the shelf's x must exceed this.
        max_slope: |dy/dx| at the shelf must be below this.

    Returns:
        {'ldr': x, 'height': y, 'slope': dy/dx} at the shelf, or None if there
        is no candidate or it fails one of the checks.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    sides = find_side_shelves(x, y, tol)
    dx = np.diff(x)[0]
    grad = np.gradient(y, dx)  ## first derivative
    # Boolean masks replace the former list(x).index(val) lookups (O(n^2)).
    if main_peak_ldr < single_peak_cutoff:  ## assume it's the alive peak: search right
        keep = (x > main_peak_ldr + min_peak_distance) & (x < sides['end_right'])
        signed_grad = grad[keep]
    else:  ### assume it's the dead peak: search left, looking for valleys of the slope
        keep = (x < main_peak_ldr - min_peak_distance) & (x > sides['end_left'])
        signed_grad = -grad[keep]
    x, y, grad = x[keep], y[keep], grad[keep]
    # height=-inf and prominence=0 filter nothing; they only make find_peaks
    # report every local maximum together with its height.
    peak_locs, peak_props = find_peaks(signed_grad, height=float('-Inf'), prominence=0)
    if len(peak_locs) == 0:
        return None
    ### choose the extremum whose slope is closest to zero
    shelf_loc = peak_locs[np.argmin(np.abs(peak_props['peak_heights']))]
    ldr_shelf = x[shelf_loc]  ### x coordinate (ldr intensity)
    shelf_height = y[shelf_loc]
    shelf_slope = grad[shelf_loc]
    if ldr_shelf > first_peak_min and shelf_height > min_peak_height and abs(shelf_slope) < max_slope:
        return {'ldr': ldr_shelf, 'height': shelf_height, 'slope': shelf_slope}
    return None

def plot_density_derivative(barcode, well, smoothing=1):
    """Plot the first derivative of a well's log-LDR KDE and find its local minima.

    Args:
        barcode, well: identify the well; values come from get_logldrint.
        smoothing: bw_adjust passed to sns.kdeplot (default 1, as before).

    Returns:
        dict with 'peak_locs' and 'peak_props' from find_peaks(-grad), i.e. the
        local minima of the derivative. It also contains 'x' and 'grad', the KDE
        grid and its derivative; these were computed before but discarded.
    """
    logint = get_logldrint(barcode, well)
    fig, ax = plt.subplots()
    x, y = sns.kdeplot(logint, ax=ax, bw_adjust=smoothing).get_lines()[0].get_data()
    plt.close(fig)  # the KDE figure is only used to obtain x, y
    dx = np.diff(x)[0]
    grad = np.gradient(y, dx)  ## first derivative
    print(ggplot() + geom_line(aes(x=x, y=grad)))
    # height=-inf / prominence=0: report every local minimum with its height
    peak_locs, peak_props = find_peaks(-grad, height=float('-Inf'), prominence=0)
    return {'peak_locs': peak_locs, 'peak_props': peak_props, 'x': x, 'grad': grad}

# Load shared notebook state: well metadata, the regating table and the
# date -> folder map.
# NOTE(review): each helper is given the global name it should bind (get_all_peaks
# reads the global `meta`); presumably they assign into globals() -- confirm.
load_well_metadata('meta')
define_regating_df('regate_df')
define_folder_dict('folder_dict')

In [None]:
### Run (or resume) peak detection for every plate; get_all_peaks appends the
### results to peak_locations_new.csv.
res = get_all_peaks()

In [None]:
# NOTE(review): get_all_peaks returns the per-well DataFrame built by
# get_peaks_barcode, whose columns do not include 'df_single'. This lookup
# looks left over from an older return format and fails as written.
res['df_single']

In [None]:
kde_plot_all_plates()

In [None]:
kde_plot_all_plates_median()

In [None]:
kde_plot_all_plates_cell_line("SUM185PE", n_wells=None)

In [None]:
### Inspect LDR gating for one well; only the last `well` assignment ("C04") is used.
barcode = "210406_combo_62"
#cell_line="HS578T"
well = "D15" ### detected w/o smoothing, shelf w/ smoothing=1.1
well= "C04"
#well = "F21"
get_ldrgates_new_well(barcode, well, silent=False, smoothing=1)
logint=get_logldrint(barcode, well)
#kde_plot_wells(barcode, [well], smoothing=1)
#test = get_peaks_cell_line(barcode, cell_line)
#test1 = get_peaks_barcode([barcode], smoothing=1)
#test2 = get_peaks_barcode([barcode], smoothing=1.1)

In [None]:
### Drop NaNs from the log-LDR values of the well above and sort them ascending
### (input for the moving-SD scan in this cell).
logint = [x for x in logint if not math.isnan(x)]
logint = sorted(logint)
### Centred rolling standard deviation (np.std, i.e. population SD).
### x: the input sequence (anything sliceable, e.g. a list)
### n: half-width of the window -- each window holds x[i-n .. i+n], i.e. n values
###    on either side of the reference value x[i] (n=2 means two on either side)
### Returns a DataFrame with the window centre index 'i' and its SD 'sd'; centres
### closer than n to either end of x are skipped.
def moving_sd(x, n):
    centers = list(range(n, len(x) - n))
    windows = (x[c - n:c + n + 1] for c in centers)
    return pd.DataFrame(data={'i': centers, 'sd': [np.std(w) for w in windows]})
### Moving SD (10 values on each side) of the sorted log-LDR values.
df=moving_sd(logint, 10)
ggplot(df)+geom_point(aes(x='i',y='sd'))

### KDE of the same well plus its first and second derivatives.
### x,y and x1,y1 come from identical calls, so mixing them below is harmless.
x1,y1 = get_kde_plot_data_well(barcode, well,smoothing=1)
x,y=get_kde_plot_data_well(barcode, well,smoothing=1)
dx = np.diff(x1)[0]
grad = np.gradient(y1, dx)
grad2 = np.gradient(grad, dx)
gg1 =ggplot()+geom_line(aes(x=x1, y=y)) 
gg2 = ggplot()+geom_line(aes(x=x1, y=grad)) #+ coord_cartesian(ylim=[-0.05,0.05])
#ggplot()+geom_point(aes(x=x1, y=grad2))
print(gg1)
print(gg2)

In [None]:
### Shelf search on the KDE above; main_peak_ldr=1 is below the default
### single_peak_cutoff (3), so the main peak is treated as the live one.
find_shelf(x,y, main_peak_ldr=1,tol=0.03)

In [None]:
### Only the last barcode/well pair (210806_combo_146 / E12) is used.
barcode = "210406_combo_62"
#cell_line="HS578T"
well = "D15" ### detected w/o smoothing, shelf w/ smoothing=1.1
#well= "C04"
#well = "F21"
barcode = "210406_combo_71"
well = "C06"
barcode="210806_combo_146"
well = "E12"

get_ldrgates_new_well(barcode, well, smoothing=1, return_peaks_only=False, show=True, silent=False)
plot_density_derivative(barcode, well)
#test = get_ldrgates_new_well(barcode, well, smoothing=1, return_peaks_only=True, show=False, silent=True)
#get_ldrgates_new_well(barcode, well, silent=False, smoothing=1, method = "valley")

In [None]:
barcode = "210302_combo_61"
well = "L09"
get_ldrgates_new_well(barcode, well, smoothing=1, return_peaks_only=False, show=True, silent=False)

In [None]:
# NOTE(review): get_ldrgates_new_well has no `method` parameter, so passing
# method = "mixture" raises TypeError as written.
barcode="201117_combo_19"
well="M06"
get_ldrgates_new_well(barcode, well, smoothing=1, method = "mixture", return_peaks_only=False, show=True, silent=False)

In [None]:
### Compare peak positions found with smoothing=1 (test1) and smoothing=1.1 (test2).
# NOTE(review): test1/test2 are only assigned in commented-out lines of an
# earlier cell. The get_peaks_* functions return a DataFrame (or
# {'df', 'peak_props'}), not 'df_single'/'df_double' entries, so this cell
# appears to target an older API.
#test2['df_double'].sort_values('peak2')
df1 = pd.concat([test1['df_single'], test1['df_double']]).sort_values('well').reset_index()
df2 = pd.concat([test2['df_single'], test2['df_double']]).sort_values('well').reset_index()
df1['peak1_smooth'] = df2['peak1']
df1['peak2_smooth'] = df2['peak2']
df1 = df1.assign(peak1_diff = lambda x: x['peak1_smooth'] - x['peak1'],
                peak2_diff = lambda x: x['peak2_smooth'] - x['peak2'])
df_new = df1.query('peak2_smooth.isna() and peak2.notna()')
df_new

In [None]:
# NOTE(review): same stale 'df_single' key as the cell above.
test1['df_single']

In [None]:
### Per-cell-line shift of the second peak between the two smoothing levels.
ggplot(df1, aes(x='cell_line', y = 'peak2_diff')) + geom_boxplot() + geom_point()

In [None]:
# NOTE(review): `method="mixture"` is not a parameter of get_ldrgates_new_well
# and raises TypeError as written.
barcode="210226_combo_51"
#well="K08"
well = "K10"
#well = "K15"
#plot_ldr_well(barcode=barcode, well=well)
#get_ldrgates_new_well(barcode, well, smoothing=1, show=True, method="valley")
model = get_ldrgates_new_well(barcode, well, smoothing=1, show=True, method="mixture", silent=True)
test = get_ldrgates_new_well(barcode, well, show=False, return_peaks_only=True)

In [None]:
type([])

In [None]:
### Load one mostly-alive (K10) and one mostly-dead (K15) well for comparison.
barcode="210226_combo_51"
#well="K08"
well = "K10" ## most alive
df_alive = read_and_rename_well_data(barcode, well)
well = "K15" ## most dead
df_dead = read_and_rename_well_data(barcode, well)

In [None]:
# NOTE(review): `tmp` is only assigned in a later cell.
tmp

In [None]:
### Only the last barcode/cell_line pair (210406_combo_71 / HS578T) takes effect;
### the earlier lines record plates that were inspected before.
#barcode = "201117_combo_19"
#cell_line = "HCC70"
barcode = "210406_combo_62"
cell_line="HS578T"
barcode="201117_combo_12"
cell_line="HCC1806"
cell_line="HCC1187"
barcode = '210406_combo_62'
cell_line = 'SUM1315'
barcode = '210423_combo_78'
cell_line = 'SUM185PE'
barcode="210406_combo_71"
cell_line="HS578T"
tmp = get_peaks_cell_line(barcode, cell_line)
#tmp = get_peaks_barcode([barcode])
#tmp_valley = get_peaks_cell_line(barcode, cell_line, method="valley")
#tmp = get_peaks_barcode(barcodes)

In [None]:
tmp

In [None]:
# NOTE(review): tmp_valley is only assigned in a commented-out line above.
tmp_valley.iloc[1:10]

In [None]:
# NOTE(review): tmp_orig is not assigned anywhere in this part of the notebook.
print(tmp_orig.iloc[0:10])

In [None]:
# NOTE(review): other calls pass a list (`get_peaks_barcode([barcode])`);
# prefer that form here rather than a bare string.
barcode = "201117_combo_19"
cell_line = "HCC70"
#tmp = get_peaks_cell_line(barcode, cell_line)
tmp = get_peaks_barcode(barcode)

In [None]:
# NOTE(review): get_peaks_barcode returns a DataFrame (or {'df', 'peak_props'}),
# so 'df_single' / 'df_double' look left over from an older return format.
#tmp['df_double'].sort_values('peak2', ascending=False).iloc[1:10]
tmp['df_single']
#df_double_old = tmp['df_double']

In [None]:
### Peak-detection edge cases on plate 201117_combo_19 (see the per-well notes).
# NOTE(review): `method="mixture"` is not a parameter of get_ldrgates_new_well
# and raises TypeError as written.
barcode='201117_combo_19'
well='D12' ### misses second peak
#well='C19' ### gets second peak w/o min_prominence argument
#well='K13' ### tiny peak (prominence 0.015), but catches it
#well='C05'
#barcode='201117_combo_20'
#well ='K19' ### very tiny peak (prominence 0.0102), but catches it
#well = 'C03' ### misses second peak
#well = 'C07'
get_ldrgates_new_well(barcode, well, smoothing=1, show=True, method="mixture", silent=False, 
                      min_prominence=0, min_peak_height=0.01)

In [None]:
# NOTE(review): same stale 'df_double' / 'df_single' keys as above.
ggplot(tmp['df_double'], aes(x='cell_line', y='peak1')) + geom_boxplot() + geom_jitter()
#ggplot(tmp['df_single'], aes(x='cell_line', y='peak1')) + geom_boxplot() + geom_jitter()

In [None]:
#plot_wells_ldr(barcode="211029_combo_185", cell_line = "SUM1315")
#kde_plot_plate(plate, output_dir = "", filename = file, add_median_ldr_line=True)
kde_plot_cell_line(barcode="211029_combo_185", cell_line = "SUM1315")

In [None]:
#plot_wells_ldr(barcode="211029_combo_185", cell_line = "HCC1395")
plot_wells_ldr(barcode="210226_combo_51", cell_line = "HCC38")

In [None]:
barcode = "201117_combo_19"
cell_line = "HCC70"
kde_plot_cell_line(barcode, cell_line)
#plot_wells_ldr(barcode, cell_line)

In [None]:
barcode = "211005_combo_161"
cell_line = "SUM1315"
#plot_wells_ldr(barcode, cell_line)
kde_plot_cell_line(barcode, cell_line)

In [None]:
barcode = "211005_combo_162"
cell_line = "SUM149"
plot_wells_ldr(barcode, cell_line)
kde_plot_cell_line(barcode, cell_line)

In [None]:
barcode = "201117_combo_28"
cell_line = "SUM149"
kde_plot_cell_line(barcode, cell_line, n_wells=5, well_start=25, add_legend=True)

In [None]:
barcode = "201117_combo_13"
cell_line = "HCC70"
kde_plot_cell_line(barcode, cell_line, n_wells=5, well_start=35, add_legend=True)

In [None]:
barcode = "201117_combo_32"
cell_line = "HS578T"
kde_plot_cell_line(barcode, cell_line, n_wells=5, well_start=20, add_legend=True)

In [None]:
### meta_sub holds the metadata of the last well plotted (C05).
barcode = "201117_combo_32"
well = "C03"
plot_ldr_well(barcode, well)
well = "C05"
plot_ldr_well(barcode, well)
meta_sub = get_well_meta(barcode, well)

In [None]:
list(meta_sub.agent1)[0]

In [None]:
# The trailing `test = ...` assignment presumably just stops Jupyter from
# echoing the last expression.
barcode = "201117_combo_13"
well = "M05"
plot_ldr_well(barcode, well)
well = "M07"
plot_ldr_well(barcode, well)
well = "N14"
plot_ldr_well(barcode, well)
test = "dsafasfd"

In [None]:
barcode = "201117_combo_28"
well = "G06"
plot_ldr_well(barcode, well)
well = "G09"
plot_ldr_well(barcode, well)
well = "G08"
plot_ldr_well(barcode, well)
well = "G04"
plot_ldr_well(barcode, well)
test = "dsafasfd"

In [None]:
barcode = "201117_combo_32"
kde_plot_plate(barcode, n_wells=5, filename="test_lines.pdf", add_median_ldr_line=True)

In [None]:
os.path.exists("density_plots/201117_combo_12.pdf")

In [None]:
# NOTE(review): uses whatever `cell_line` an earlier cell left behind.
kde_plot_all_avg(cell_line, n_barcodes=None)

In [None]:
cell_line="SUM185PE"
#test = get_all_plates_for_cell_line(cell_line)
kde_plot_all_avg(cell_line, n_barcodes=None)

In [None]:
# NOTE(review): the assignment this seems meant for is commented out above;
# `test` is whatever was bound last (e.g. the placeholder string).
len(test)

In [None]:
### Only the last barcode ("210423_combo_78") is plotted.
barcode = "201117_combo_14"
wells = ["C20", "D06"]
cell_line = "HCC1187"
#kde_plot_wells(barcode, wells, title = "testing")
#kde_plot_cell_line(barcode, cell_line, n_wells=3)

barcode = "210406_combo_62"
barcode = "210423_combo_78"
kde_plot_plate(barcode, n_wells = 5, filename="test_new4.pdf")


In [None]:
barcode = "201117_combo_14"
cell_line = "HCC1187"
kde_plot_avg(barcode, cell_line)

In [None]:
### Attach an LDR cutoff to every metadata row.
# NOTE(review): assumes get_ldr_cutoffs_parallel returns one value per row of
# `meta`, in the same order -- confirm.
# NOTE(review): `meta_ldr = meta` is an alias, not a copy, so this also adds
# 'ldr_cutoff' to the global `meta`; use meta.copy() if that is unintended.
cutoffs = get_ldr_cutoffs_parallel()
meta_ldr = meta
meta_ldr['ldr_cutoff'] = cutoffs
#meta_ldr.to_csv("meta_with_ldr.csv")

In [None]:
### find possibly mis-gated wells
### Flag wells whose cutoff is more than 3 SD above/below the median cutoff of
### their (cell_line, barcode) group. df_tmp is not used below.
df = pd.read_csv('meta_with_ldr.csv')
df_tmp = df.query('cell_line == "BT20"')
df_join = df.groupby(['cell_line', 'barcode']).agg({'ldr_cutoff': ['median', np.std]})
df_join = df_join['ldr_cutoff']
df_join['high_cutoff'] = df_join['median'] + 3*df_join['std']
df_join['low_cutoff'] = df_join['median'] - 3*df_join['std']
df2 = df.merge(df_join, on = ['cell_line', 'barcode'], how = 'left')
df_high = df2.query('ldr_cutoff > high_cutoff')
df_low = df2.query('ldr_cutoff < low_cutoff')
df_high = df_high.reset_index()
df_low = df_low.reset_index()

In [None]:
### Compare automatic vs. manual (cutoff=3) gating for one hand-picked bad well (row i).
# NOTE(review): the left merge with df_high leaves 'median' NaN for wells that
# are not in df_high. Also, df2 here overwrites the merged table from the
# previous cell.
#barcode = "211029_combo_180"
#cell_line = "SUM1315"
#peak_loc = 1.2
#well = "G06"

bad_well_df = pd.read_csv("mis_gated.csv")
bad_well_df = bad_well_df.merge(df_high, on = ['barcode', 'well', 'cell_line'], how = 'left')
bad_well_df
i=4
barcode = bad_well_df.barcode[i]
cell_line = bad_well_df.cell_line[i]
well = bad_well_df.well[i]
med_ldr = bad_well_df['median'][i]
df1 = get_counts_well(barcode, well, manual_ldr_cutoff=None, plot = True)
print(df1)
df2 = get_counts_well(barcode, well, manual_ldr_cutoff=3, plot = True)
print(df2)

In [None]:
df_low.iloc[50:55]

In [None]:
### Load the raw data of one low-cutoff well (row 53 of df_low).
barcode = list(df_low.barcode)[53]
cell_line = list(df_low.cell_line)[53]
well = list(df_low.well)[53]
print(barcode)
print(cell_line)
print(well)
df_tmp = read_and_rename_well_data(barcode, well)
df_tmp

In [None]:
### Write the possibly mis-gated wells (cutoff outside median +/- 3 SD) to a
### multi-page PDF, six LDR plots per page.
### Choose which set to review: "low" (df_low) or "high" (df_high). The title now
### shows the matching cutoff; previously it always showed the high cutoff,
### even for df_low.
which = "low"
df = df_low if which == "low" else df_high
cutoff_col = which + "_cutoff"  # 'low_cutoff' / 'high_cutoff' from the merge above
output_dir = ""
pdf = "possible_misgated_" + which + "_3sigma.pdf"

pdf_full = os.path.join(output_dir, pdf)
nb_plots = len(df.well)
plots_per_page = 6
# The context manager closes the PDF even if saving a page raises.
with PdfPages(pdf_full) as pdf_pages:
    for i in range(nb_plots):
        row = df.iloc[i]
        barcode = row.barcode
        cell_line = row.cell_line
        well = row.well
        median_ldr = row['median']
        ldr_limit = row[cutoff_col]
        fig_title = (well + " " + barcode + " " + cell_line + "\n"
                     + str(row.agent1) + " " + str(row.concentration1_chr) + " uM "
                     + str(row.agent2) + " " + str(row.concentration2_chr) + " uM " + "\n"
                     + "Median LDR: " + str(round(median_ldr, 2))
                     + " LDR cutoff " + which + ": " + str(round(ldr_limit, 2)))
        i_page = i % plots_per_page
        if i_page == 0:
            fig = plt.figure(figsize=(8.5, 11))
            outer = GridSpec(3, 2, wspace=0.2, hspace=0.5)
        try:
            plot_ldr_well(barcode, well, add_ldr_line=median_ldr, fig=fig, outer=outer, i=i_page,
                          silent=True, show_fig=False, scatter=True, title=fig_title)
        except Exception as err:
            # Keep going so one bad well does not abort the report, but say why it failed.
            print("failed plotting: " + well + " " + barcode + " " + cell_line + " (" + repr(err) + ")")
        if (i + 1) % plots_per_page == 0 or (i + 1) == nb_plots:
            plt.tight_layout()
            pdf_pages.savefig()
            plt.close('all')

In [None]:
# Shows `df` as left by the previous cell.
df

In [None]:
#ggplot(meta_ldr, aes(x = 'date', y = 'ldr_cutoff')) + geom_boxplot() + coord_flip() + facet_wrap('cell_line', scales = "free_y") +  theme(figure_size = (10, 10))

In [None]:
#ggplot(meta_ldr.query('cell_line == "BT20"'), aes(x = 'barcode', y = 'ldr_cutoff')) + geom_boxplot() + geom_jitter() + coord_flip() + theme(figure_size = (5, 15))

In [None]:
plot_all_problem_plates()

In [None]:
df_new_gating = regate_wells()
df_new_gating.to_csv('regating_counts_07_10_2023.csv')

In [None]:
df1 = get_ldr_cutoffs_plate("211029_combo_180", peak_loc = 1.2)

In [None]:
ggplot(df1, aes(x = 'cell_line', y = 'ldr_cutoff')) + geom_boxplot() + geom_jitter()

In [None]:
### Valleys of the log-LDR density for one well (find_peaks on -y returns the
### local minima of y), first on the raw data and then after remapping
### non-positive values.
# NOTE: seaborn is already imported at the top of the notebook; the repeated
# `import seaborn as sns` lines below are redundant.
barcode = "211029_combo_180"
cell_line = "SUM1315"
well = "G06"
#plot_wells_ldr(barcode, cell_line, output_dir = "")
#plot_ldr_well(barcode, well)

######## plot peaks w/ normal algorithm
df = read_and_rename_well_data(barcode, well, silent=True)
ldrint = df['ldr']
ldrint = ldrint[ldrint > 0]
logint = np.log10(ldrint)

import seaborn as sns
fig, ax = plt.subplots()
x, y = sns.kdeplot(logint, ax=ax).get_lines()[0].get_data()

peak_locs, _ = find_peaks(-y)
#print(peak_locs)
cc = x[peak_locs]
print(cc)

###### plot peaks w/ negative ldr values mapped to minimum positive ldr
df = read_and_rename_well_data(barcode, well, silent=True)

df_pos1 = df.query("ldr > 0")
min_ldr = np.min(df_pos1.ldr)
df_pos2 = df.query("dna > 0")
min_dna = np.min(df_pos2.dna)
# Every value that is not > 0 (including NaN) is replaced by the minimum
# positive value, so the `ldrint > 0` filter below is effectively a no-op.
df['ldr'] = [x if x>0 else min_ldr for x in df.ldr]
df['dna'] = [x if x>0 else min_dna for x in df.dna]

ldrint = df['ldr']
ldrint = ldrint[ldrint > 0]
logint = np.log10(ldrint)

import seaborn as sns
fig, ax = plt.subplots()
x, y = sns.kdeplot(logint, ax=ax).get_lines()[0].get_data()

peak_locs, _ = find_peaks(-y)
#print(peak_locs)
cc = x[peak_locs]
print(cc)

In [None]:
### Compare KDE valley detection on one well at three smoothing levels (bw_adjust).
### The first valley above `peak_loc` is drawn as the LDR cutoff; if there is no
### such valley, the 99th percentile of log-LDR is used instead.
barcode = "211029_combo_180"
cell_line = "SUM1315"
peak_loc = 1.2
### candidate wells tried before: "G06", "C19"
well = "C20"


######## plot peaks w/ smoothing
df = read_and_rename_well_data(barcode, well, silent=True)
ldrint = df['ldr'][df['ldr'] > 0]
logint = np.log10(ldrint)

import seaborn as sns
for bw in (1, 1.5, 2):
    fig, ax = plt.subplots()
    x, y = sns.kdeplot(logint, ax=ax, bw_adjust=bw).get_lines()[0].get_data()
    ### valleys of the density are the peaks of -y
    peak_locs, test = find_peaks(-y, prominence=0, width=0)
    cc = x[peak_locs]
    valleys_above = cc[cc > peak_loc]
    ldr_cutoff = valleys_above[0] if len(valleys_above) > 0 else np.quantile(logint, 0.99)
    plt.axvline(x=ldr_cutoff, ls = "--", color = "red")
    print(cc)
    print(test)
    print(ldr_cutoff)

In [None]:
### Dead-cell counts for one well using the intensity-based dead_cell_filter_ldrint.
barcode = "211029_combo_180"
cell_line = "SUM1315"
peak_loc = 1.2
well = "C19"
df_test = read_and_rename_well_data(barcode, well)
test = plot_ldr(df_test, peak_loc = peak_loc)
dcf_int.get_counts_df(df=df_test, barcode=barcode, well=well, peak_loc = peak_loc)

In [None]:
### Density valleys after remapping non-positive values, requiring prominence >= 0.2.
barcode = "211029_combo_180"
cell_line = "SUM1315"
well = "G06"

###### plot peaks w/ negative ldr values mapped to minimum positive ldr
df = read_and_rename_well_data(barcode, well, silent=True)

df_pos1 = df.query("ldr > 0")
min_ldr = np.min(df_pos1.ldr)
df_pos2 = df.query("dna > 0")
min_dna = np.min(df_pos2.dna)
df['ldr'] = [x if x>0 else min_ldr for x in df.ldr]
df['dna'] = [x if x>0 else min_dna for x in df.dna]

ldrint = df['ldr']
ldrint = ldrint[ldrint > 0]
logint = np.log10(ldrint)

import seaborn as sns
fig, ax = plt.subplots()
x, y = sns.kdeplot(logint, ax=ax).get_lines()[0].get_data()

peak_locs, _ = find_peaks(-y, prominence=0.2)
#print(peak_locs)
cc = x[peak_locs]
print(cc)

In [None]:
### Density valleys with prominence >= 0.1 and width >= 30; find_peaks measures
### width in grid samples, not LDR units. Only well "C07" is used.
barcode = "211029_combo_180"
well = "G06" ### example of bad gating
well = "C07" ### example of good separation, large peak

######## plot peaks w/ peak_prominence
df = read_and_rename_well_data(barcode, well, silent=True)
ldrint = df['ldr']
ldrint = ldrint[ldrint > 0]
logint = np.log10(ldrint)

import seaborn as sns
fig, ax = plt.subplots()
x, y = sns.kdeplot(logint, ax=ax).get_lines()[0].get_data()

peak_locs, test = find_peaks(-y, prominence=0.1, width = 30)
#print(peak_locs)
cc = x[peak_locs]
print(cc)

In [None]:
test

In [None]:
### Same well (barcode/well from the cell above) with default find_peaks
### settings, plus a plot of -y to see the valleys as peaks.
######## plot peaks w/ normal algorithm
df = read_and_rename_well_data(barcode, well, silent=True)
ldrint = df['ldr']
ldrint = ldrint[ldrint > 0]
logint = np.log10(ldrint)

import seaborn as sns
fig, ax = plt.subplots()
x, y = sns.kdeplot(logint, ax=ax).get_lines()[0].get_data()

peak_locs, _ = find_peaks(-y)
#print(peak_locs)
cc = x[peak_locs]
print(cc)

fig,ax = plt.subplots()
# plot the data
ax.plot(x,-y)

In [None]:
#### test viewing a few plates, looking for badly gated wells
### Only the last pair (211029_combo_180 / SUM1315) is actually plotted.
barcode = "211029_combo_180"
cell_line = "BT20"
#plot_wells_ldr(barcode, cell_line, output_dir = "")

barcode = "211029_combo_182"
cell_line = "SUM159"
#plot_wells_ldr(barcode, cell_line, output_dir = "")


barcode = "211029_combo_180"
cell_line = "SUM1315"
plot_wells_ldr(barcode, cell_line, output_dir = "")

#well = "D12"
#plot_ldr_well(barcode, well)
#df = read_and_rename_well_data(barcode, well)
#df

In [None]:
### Print the row labels of `df` whose LDR value is <= 0.
# NOTE(review): df.ldr[i] is label-based, so this assumes `df` has a default
# 0..n-1 index -- confirm for the frame currently bound to `df`.
for i in range(df.shape[0]):
    if df.ldr[i]<=0:
        print(i)

In [None]:
#meta_sub = meta.filter(meta.columns[0:9])
#df_new = df_new_gating.merge(meta_sub, on = ['barcode', 'well'], how = 'left')
#df_new.sort_values(by=['cell_count__dead'])

In [None]:
#df_new.query('agent1 == ""').sort_values(by=['cell_count__dead'])

In [None]:
### Log of candidate problem wells on HCC1937 and SUM1315 plates.
### Each assignment overwrites the previous one, so only the final values
### (211005_combo_166 / SUM1315 / H07) are actually plotted and counted below.
#barcode = '210226_combo_51'
#cell_line = 'HCC1937'
#well = "I06"
#well = "J07"

#barcode = '210226_combo_52'
#cell_line = 'HCC1937'
#well = "I11"

#barcode = '210226_combo_53'
#cell_line = 'HCC1937'
#well = "I18"

barcode = '210226_combo_54'
cell_line = 'HCC1937'
well = "I03"
well = "I11"
well = "J04"
well = "J07"

barcode = '210226_combo_55'
cell_line = 'HCC1937'
well = "I15"

barcode = '210226_combo_56'
cell_line = 'HCC1937'
well = "I11"
#well = "I14"
#well = "I18"
#well = "I21"

barcode = '210226_combo_57'
cell_line = 'HCC1937'
well = "I03"
well = "I06"
well = "I11"
well = "J04"
well = "J07"

barcode = '210302_combo_59'
cell_line = 'HCC1937'
well = "I15"
#well = "J20"

barcode = '210302_combo_60'
cell_line = 'HCC1937'
well = "I11"
well = "I18"

barcode = '210302_combo_61'
cell_line = 'HCC1937'
well = "I06"
#well = "J07"

barcode = '210406_combo_69'
cell_line = 'SUM1315'
well = "E15"
well = "F20"

barcode = '210406_combo_70'
cell_line = 'SUM1315'
well = "E18"
#well = "E11"

barcode = '210406_combo_71'
cell_line = 'SUM1315'
#well = "E06"
well = "F07"

barcode = '210406_combo_72'
cell_line = 'SUM1315'
well = "E11"
#well = "E12"
#well = "E14"
#well = "E15"
#well = "F20"

barcode = '210406_combo_73'
cell_line = 'SUM1315'
well = "E11"
#well = "E14"
#well = "E18"
#well = "E21"
#well = "F05"

barcode = '210406_combo_74'
cell_line = 'SUM1315'
well = "E03"
well = "E06"
well = "E11"
well = "F04"
well = "F07"

barcode = '210406_combo_75'
cell_line = 'SUM1315'
well = "E11"
#well = "E12"
#well = "E14"
#well = "E15"
#well = "F20"

barcode = '210406_combo_76'
cell_line = 'SUM1315'
well = "E11"
well = "E14"
well = "E18"
well = "E21"
well = "F05"

barcode = '210406_combo_77'
cell_line = 'SUM1315'
well = "E03"
#well = "E06"
#well = "E11"
#well = "F04"
#well = "F07"

barcode = '211005_combo_158'
cell_line = 'SUM1315'
well = "G11"
well = "G12"
well = "G14"

barcode = '211005_combo_160'
cell_line = 'SUM1315'
well = "G11"
well = "G12"
well = "G14"

barcode = '211005_combo_164'
cell_line = 'SUM1315'
well = "G11"
#well = "G14"
#well = "H20"

barcode = '211005_combo_165'
cell_line = 'SUM1315'
well = "G11"
#well = "G14"
#well = "H20"

barcode = '211005_combo_166'
cell_line = 'SUM1315'
well = "G11"
well = "H04"
well = "H07"

### using DNAcontent column as dna -- same as before
df_test = read_and_rename_well_data(barcode, well)
### using HoechstINT column as dna
### note: using Hoechst gives lower dead count -- only 12 dead_subg1 vs. ~500 for DNAContent
#df_test = read_and_rename_well_data(barcode, well, hoechst_as_dna=True)
peak_loc = 1.2
test = plot_ldr(df_test, peak_loc = peak_loc)
dcf_int.get_counts_df(df=df_test, barcode=barcode, well=well, peak_loc = peak_loc)

In [None]:
### testing wells that legitimately do have high dead counts
### (only the last `well` assignment, "K11", is used)

barcode = "210423_combo_84"
well = "K21" ### Alp. + Tram -- cutoff at 1.5 ?? -- similar for K19
well = "K11" ### ctrl well -- low dead cells -- cutoff at ~3.25

### using DNAcontent column as dna -- same as before
df_test = read_and_rename_well_data(barcode, well)
### using HoechstINT column as dna
### note: using Hoechst gives lower dead count -- only 12 dead_subg1 vs. ~500 for DNAContent
#df_test = read_and_rename_well_data(barcode, well, hoechst_as_dna=True)
peak_loc = 1.2
test = plot_ldr(df_test, peak_loc = peak_loc)
dcf_int.get_counts_df(df=df_test, barcode=barcode, well=well, peak_loc = peak_loc)

In [None]:
### Plot one SUM1315 well and compute its counts and LDR gates with peak_loc=1.2.
barcode = '210406_combo_62'
cell_line = 'SUM1315'
well = "E03"

#barcode = "210423_combo_78"
#well = 'K04'
#cell_line = "SUM185PE"

#test = plot_problem_plate(barcode, cell_line)
plot_ldr_well(barcode, well, peak_loc = 1.2)
df_test = read_and_rename_well_data(barcode, well)
df_tmp = dcf_int.get_counts_df(df=df_test, barcode=barcode, well=well, peak_loc = 1.2)
#plot_wells_ldr(barcode, cell_line)
#df_tmp

# NOTE(review): [0] presumably selects the gate array from get_ldrgates'
# return value -- confirm against dead_cell_filter_ldrint.
dcf_int.get_ldrgates(df_test.ldr, peak_loc=1.2)[0]
#df_tmp