In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Literal

def process_df(df):
    """Collapse per-mutation rows into tract intervals for every recombination event.

    ``df`` needs the columns ``Sample``, ``RecEv_ID``, ``Region`` (signed position
    relative to the DSB axis), ``Reference`` and ``event``. For each
    (Sample, RecEv_ID) pair:

    * a single mutation -> one interval between the axis (0) and that position;
    * several positions, all negative -> one interval ``[first, 0]``;
    * several positions, all positive -> one interval ``[0, last]``;
    * otherwise the positions are walked in sorted order and a new interval is
      started whenever ``Reference`` changes. The boundary between two intervals
      is the midpoint between the last position of one reference run and the
      first position of the next.

    Returns a new DataFrame with columns ``Sample``, ``RecEv_ID``,
    ``Region_Start``, ``Region_End``, ``Reference``, ``'All_Positions: '`` (list
    of positions belonging to the interval; the key keeps its historical spelling
    because plot_tracks reads it) and ``Event``.
    """

    def _interval(sample, rec_ev_id, start, end, reference, positions, event):
        # One output row; kept in one place so every branch emits the same schema.
        return {
            'Sample': sample,
            'RecEv_ID': rec_ev_id,
            'Region_Start': start,
            'Region_End': end,
            'Reference': reference,
            'All_Positions: ': positions,
            'Event': event,
        }

    # Rows are collected first and turned into a DataFrame once at the end;
    # concatenating inside the loop is quadratic in the number of intervals.
    rows = []

    for sample in df['Sample'].unique():
        sample_df = df[df['Sample'] == sample]

        for rec_ev_id in sample_df['RecEv_ID'].unique():
            rec_ev_id_df = sample_df[sample_df['RecEv_ID'] == rec_ev_id]
            event = rec_ev_id_df['event'].iloc[0]

            # A single mutation: interval from the axis to that position.
            if len(rec_ev_id_df) == 1:
                region = rec_ev_id_df['Region'].iloc[0]
                start, end = (region, 0) if region < 0 else (0, region)
                rows.append(_interval(sample, rec_ev_id, start, end,
                                      rec_ev_id_df['Reference'].iloc[0],
                                      rec_ev_id_df['Region'].tolist(), event))
                continue

            rec_ev_id_df = rec_ev_id_df.sort_values(by='Region').reset_index(drop=True)
            regions = rec_ev_id_df['Region'].tolist()
            references = rec_ev_id_df['Reference'].tolist()
            first, last = regions[0], regions[-1]

            if first < 0 and last < 0:
                # Entirely left of the axis: one interval from the first position to 0.
                rows.append(_interval(sample, rec_ev_id, first, 0,
                                      references[0], regions, event))
            elif first > 0 and last > 0:
                # Entirely right of the axis: one interval from 0 to the last position.
                rows.append(_interval(sample, rec_ev_id, 0, last,
                                      references[0], regions, event))
            else:
                # Spans (or touches) the axis: split wherever the Reference changes.
                region_start = first if first < 0 else 0
                region_end = 0  # last position seen in the current reference run
                reference = references[0]
                positions = []

                for region, ref in zip(regions, references):
                    if ref != reference:
                        # Close the current run halfway to the first position of the new one.
                        region_end = (region_end + region) / 2
                        rows.append(_interval(sample, rec_ev_id, region_start, region_end,
                                              reference, positions, event))

                        region_start = region_end
                        reference = ref
                        positions = [region]
                        # The new run's last seen position is this row; without this the
                        # next midpoint would be taken from the previous midpoint.
                        region_end = region
                    else:
                        region_end = region
                        positions.append(region)

                # The final run extends to the last (largest) position.
                rows.append(_interval(sample, rec_ev_id, region_start, last,
                                      reference, positions, event))

    return pd.DataFrame(rows)

def add_genotype_column(df):
    """Add a ``Genotype`` column to ``df`` in place and return ``df``.

    Each ``Sample`` value is looked up in ``genotype_dict["final_set"]`` from the
    project module ``sample_genotypes_final`` via ``Series.map`` (for a dict
    mapping, samples that are not listed become NaN).
    """
    from sample_genotypes_final import genotype_dict as genotype_sets

    sample_to_genotype = genotype_sets["final_set"]
    df["Genotype"] = df["Sample"].map(sample_to_genotype)
    return df

def rank_tracks(df):
    """Rank recombination events by total tract length for stacked plotting.

    Adds ``Length`` (absolute interval length) and ``Total_Length`` (sum of
    ``Length`` over all rows of the same Sample/RecEv_ID) to ``df`` in place.
    Returns a new, re-indexed frame in which the event with the largest total
    length comes first. All rows of one event share the same dense, 1-based
    ``rank``, which is copied into ``Index`` (the y coordinate used by
    plot_tracks).
    """
    event_keys = ['Sample', 'RecEv_ID']

    df['Length'] = (df['Region_End'] - df['Region_Start']).abs()
    df['Total_Length'] = df.groupby(event_keys)['Length'].transform('sum')

    # Longest events first; ties broken by Sample, then RecEv_ID.
    ordering = ['Total_Length'] + event_keys
    ranked = df.sort_values(by=ordering, ascending=[False, True, True])

    # sort=False numbers groups in their current (already ranked) order,
    # giving gap-free ranks shared by every row of an event.
    ranked['rank'] = ranked.groupby(ordering, sort=False).ngroup() + 1
    ranked['Index'] = ranked['rank']

    return ranked.sort_values(by='Index').copy().reset_index(drop=True)

def plot_tracks(df, save_name="plot.png", show: bool = False, genotype_label: str = ""):
    """Draw every tract as a horizontal segment, one row per ``Index`` value, and save it.

    Parameters
    ----------
    df : DataFrame
        Needs ``Region_Start``, ``Region_End``, ``Reference``, ``Index`` and
        ``'All_Positions: '`` (iterable of mutation positions) columns, as
        produced by process_df followed by rank_tracks.
    save_name : str
        Output path passed to ``savefig`` (300 dpi).
    show : bool
        Also display the figure before it is closed.
    genotype_label : str
        Italic text placed in the upper-left corner.
    """
    fig, ax = plt.subplots(figsize=(5, 5))  # adjust figure size as needed

    # Colour by mutation type; anything else is drawn grey.
    reference_colors = {'C->T': 'tab:red', 'G->A': 'tab:blue'}

    for _, row in df.iterrows():
        color = reference_colors.get(row['Reference'], 'tab:gray')
        y = row['Index']

        # Segment from Region_Start to Region_End at height Index.
        ax.plot(
            [row['Region_Start'], row['Region_End']],
            [y, y],
            color=color,
            linewidth=0.8,
            alpha=0.75
        )

        # All mutation positions of the tract as dots. One call per track rather
        # than one per position keeps the number of artists proportional to rows.
        positions = list(row['All_Positions: '])
        if positions:
            ax.plot(
                positions,
                [y] * len(positions),
                color=color,
                marker='o',
                markersize=1.5,
                linestyle='none',
                alpha=0.75)

    ax.yaxis.set_visible(False)
    ax.yaxis.set_ticks([])
    ax.tick_params(axis='x', labelsize=16)
    ax.axvline(x=0, color='black', linestyle='--', linewidth=0.5)

    # Legend entries via empty proxy lines (the tract artists themselves are unlabeled).
    ax.plot([], [], color='tab:red', label='C→T tracts')
    ax.plot([], [], color='tab:blue', label='G→A tracts')
    ax.legend(fontsize=11)

    ax.text(0.02, 0.98, genotype_label, transform=ax.transAxes, fontsize=15, verticalalignment='top', fontstyle='italic')
    ax.set_xlabel('Distance (kb) from DSB axis', fontsize=14, fontweight='bold')

    fig.tight_layout()
    fig.savefig(save_name, dpi=300)

    if show:
        plt.show()
    plt.close(fig)

def plot_magnitude_separate(
        magnitude_CT,
        magnitude_GA,
        save_name="plot_magnitude_both",
        genotype_label: str = "",
        show: bool = False
        ):
    """Overlay the C→T and G→A coverage profiles on one axis and save the figure.

    The x axis is labelled in kb assuming index 10000 is the DSB axis and one
    index step is 1 bp (as in the arrays built by build_df_magnitude). The y
    axis is hidden. ``genotype_label``, when non-empty, is written in italics
    in the upper-left corner.
    """
    fig, ax = plt.subplots(figsize=(6, 4.5))

    # G→A first, then C→T, each as a line with a light fill underneath.
    curves = (
        (magnitude_GA, 'G→A tracts', 'tab:blue'),
        (magnitude_CT, 'C→T tracts', 'tab:red'),
    )
    for values, curve_label, curve_color in curves:
        ax.plot(values, label=curve_label, color=curve_color)
        ax.fill_between(range(len(values)), values, color=curve_color, alpha=0.1)

    # Dashed marker at the DSB axis, with ticks relabelled from bp index to kb.
    ax.axvline(x=10000, color='black', linestyle='--', linewidth=0.5)
    ax.set_xticks([0, 5000, 10000, 15000, 20000])
    ax.set_xticklabels([-10, -5, 0, 5, 10])

    ax.yaxis.set_visible(False)
    ax.yaxis.set_ticks([])
    ax.set_xlabel('Distance (kb) from DSB axis', fontsize=14, fontweight='bold')
    ax.tick_params(axis='x', labelsize=16)
    ax.legend(fontsize=12)

    has_label = genotype_label != ""
    if has_label:
        ax.text(0.02, 0.98, genotype_label, transform=ax.transAxes, fontsize=15, verticalalignment='top', fontstyle='italic')

    plt.tight_layout()
    plt.savefig(f'{save_name}', dpi=300)
    if show:
        plt.show()
    plt.close()

def plot_magnitude_sum(
        magnitude_sum,
        save_name="plot_magnitude_sum",
        genotype_label: str = "",
        show: bool = False
        ):
    """Plot the combined (C→T + mirrored G→A) coverage profile and save the figure.

    The x ticks assume the 20001-point layout used by build_df_magnitude:
    index 10000 is the DSB axis, index 0 is -10 kb and index 20000 is +10 kb.
    ``genotype_label``, when non-empty, is written in italics near the
    upper-left corner.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(magnitude_sum, label='Sum', color='tab:green')

    # Vertical dashed line at the DSB axis.
    ax.axvline(x=10000, color='black', linestyle='--', linewidth=0.5)

    # The last index of a 20001-point array is 20000 (= +10 kb); a tick at 20001
    # put the "10" label one bp past the data and disagreed with plot_magnitude_separate.
    ax.set_xticks([0, 5000, 10000, 15000, 20000])
    ax.set_xticklabels([-10, -5, 0, 5, 10])

    ax.yaxis.set_visible(False)
    ax.yaxis.set_ticks([])
    ax.set_xlabel('Distance (kb) from DSB axis', fontsize=14, fontweight='bold')
    ax.tick_params(axis='x', labelsize=16)
    ax.legend(fontsize=12)

    if genotype_label != "":
        ax.text(0.05, 0.98, genotype_label, transform=ax.transAxes, fontsize=15, verticalalignment='top', fontstyle='italic')

    plt.tight_layout()
    plt.savefig(f'{save_name}', dpi=300)
    if show:
        plt.show()
    plt.close(fig)

def build_df_magnitude(df, df_counts: pd.DataFrame, genotype_label: str, mutation_set: str) -> tuple:
    """Build per-bp coverage profiles of C→T and G→A tracts around the DSB axis.

    Each row of ``df`` (``Region_Start``/``Region_End`` in kb, ``Reference``
    'C->T' or 'G->A') adds 1 to every 1-bp bin it covers in a 20001-bin array.
    Bin 10000 is the axis (position 0), so the arrays span -10 kb .. +10 kb.
    Intervals reaching outside that window are clipped to it. Rows with any
    other ``Reference`` are reported and ignored.

    Every 100 bp, raw counts are appended to ``df_counts``:
    ``Resection = CT[i] + GA[mirror(i)]`` and ``BIR = CT[mirror(i)] + GA[i]``,
    where ``mirror(i) = 20000 - i``. Each row also carries ``Position`` (bp from
    the axis), ``Genotype`` (= ``genotype_label``) and ``Mutation_Set``.

    Each profile is then normalized to sum 1 (an all-zero profile stays zero
    instead of turning into NaN) and smoothed with a 500-bin moving average.

    Returns
    -------
    tuple
        ``(magnitude_CT, magnitude_GA, magnitude_sum, df_counts)``, where
        ``magnitude_sum`` is ``magnitude_CT`` plus the mirrored ``magnitude_GA``.
    """
    half_window_bp = 10000
    n_bins = 2 * half_window_bp + 1

    magnitude_CT = np.zeros(n_bins)
    magnitude_GA = np.zeros(n_bins)

    for _, row in df.iterrows():
        # kb -> bp. round() instead of int() truncation: e.g. 1.005 * 1000 is
        # 1004.9999999999999 in floating point and would otherwise lose a bp.
        start = int(round(row['Region_Start'] * 1000)) + half_window_bp
        end = int(round(row['Region_End'] * 1000)) + half_window_bp

        # Clip to the array so out-of-window values cannot wrap via negative indices.
        lo = max(start, 0)
        hi = min(end, n_bins - 1) + 1
        if hi <= lo:
            continue

        if row['Reference'] == 'C->T':
            magnitude_CT[lo:hi] += 1
        elif row['Reference'] == 'G->A':
            magnitude_GA[lo:hi] += 1
        else:
            print("Error in the Reference column")

    # Resection and BIR counts every 100 bp (raw, before normalization).
    records = []
    for i in range(0, n_bins, 100):
        records.append({
            'Position': i - half_window_bp,
            'Resection': magnitude_CT[i] + magnitude_GA[-1 - i],
            'BIR': magnitude_CT[-1 - i] + magnitude_GA[i],
            'Genotype': genotype_label,
            'Mutation_Set': mutation_set,
        })
    df_counts = pd.concat([df_counts, pd.DataFrame(records)], ignore_index=True)

    def _normalize(profile):
        # Scale to unit sum; leave an empty profile at zero rather than 0/0 = NaN.
        total = np.sum(profile)
        return profile / total if total > 0 else profile

    smoothing_kernel = np.ones(500) / 500
    magnitude_CT = np.convolve(_normalize(magnitude_CT), smoothing_kernel, mode='same')
    magnitude_GA = np.convolve(_normalize(magnitude_GA), smoothing_kernel, mode='same')

    # Mirror G→A around the axis so both tract types point the same way, then add.
    magnitude_sum = magnitude_CT + np.flip(magnitude_GA)

    return magnitude_CT, magnitude_GA, magnitude_sum, df_counts

def conduct_chi2_test_by_position(df_counts, print_results: bool = False, save_name: str = "chi2_test_results_100bp_inc.csv", control: str = "ung1∆", alpha: float = 0.05) -> pd.DataFrame:
    """Compare each genotype's Resection/BIR counts with a control genotype, position by position.

    For every non-control genotype and every ``Position`` present for both it
    and ``control``, the first matching rows form the 2x2 table
    ``[[Resection_control, BIR_control], [Resection_test, BIR_test]]``. That
    table is tested with ``scipy.stats.chi2_contingency`` (default settings).
    Tables with a zero cell are not tested; their Chi2/p are recorded as 'NA'.

    Side effects:
      * ``df_counts['is_significant']`` is (re)set in place: True where
        ``p <= alpha``, False everywhere else.
      * Results for positions ``<= 0`` are written, with ``Position`` negated,
        as a tab-separated UTF-16 file to ``save_name``.

    Parameters
    ----------
    control : str
        Genotype used as the reference row of every table (default "ung1∆").
    alpha : float
        Significance threshold applied to p (default 0.05).

    Returns
    -------
    DataFrame
        ``df_counts`` (the same object, with ``is_significant`` filled in).
    """
    from scipy import stats

    result_columns = ['Genotype_Test', 'Genotype_Control', 'Position', 'Chi2', 'p', 'observed_test (Resection, BIR)', 'observed_control (Resection, BIR)']

    # Default every row to non-significant; only tested positions with p <= alpha flip to True.
    df_counts['is_significant'] = False

    # Loop-invariant: the control subset and the set of positions.
    control_df = df_counts[df_counts['Genotype'] == control]
    positions = df_counts['Position'].unique()

    records = []
    for genotype in df_counts['Genotype'].unique():
        if genotype == control:
            continue

        genotype_mask = df_counts['Genotype'] == genotype
        genotype_df = df_counts[genotype_mask]

        for position in positions:
            other_df = genotype_df[genotype_df['Position'] == position]
            control_position = control_df[control_df['Position'] == position]
            if other_df.empty or control_position.empty:
                continue

            resection_control = control_position['Resection'].iloc[0]
            BIR_control = control_position['BIR'].iloc[0]
            resection_other = other_df['Resection'].iloc[0]
            BIR_other = other_df['BIR'].iloc[0]

            observed_test = (resection_other, BIR_other)
            observed_control = (resection_control, BIR_control)

            # A zero cell makes the table degenerate for the test; record NA instead.
            if any(v == 0 for v in (resection_control, BIR_control, resection_other, BIR_other)):
                records.append([genotype, control, position, 'NA', 'NA', observed_test, observed_control])
                if print_results:
                    print(f"Genotype: {genotype} Position: {position:6}, Observed: ({resection_control:5}, {BIR_control:5}) vs. ({resection_other:5}, {BIR_other:5}), Chi2: nan, p: nan")
                continue

            chi2, p, dof, expected = stats.chi2_contingency([[resection_control, BIR_control], [resection_other, BIR_other]])
            records.append([genotype, control, position, chi2, p, observed_test, observed_control])
            if print_results:
                print(f"Genotype: {genotype} Position: {position:6}, Observed: ({resection_control:5}, {BIR_control:5}) vs. ({resection_other:5}, {BIR_other:5}), Chi2: {chi2:.5}, p: {p:.3e}")
            if p <= alpha:
                df_counts.loc[genotype_mask & (df_counts['Position'] == position), 'is_significant'] = True

    test_results_df = pd.DataFrame(records, columns=result_columns)

    # Save only positions <= 0, reported as positive distances from the axis.
    test_results_df_save = test_results_df[test_results_df['Position'] <= 0].copy()
    test_results_df_save['Position'] = test_results_df_save['Position'] * -1
    test_results_df_save.to_csv(save_name, sep='\t', index=False, encoding='utf-16')

    return df_counts

def normalize_arrays(array_set, normalization_genotype: str = 'ung1∆', log2_fc: bool = True) -> list:
    """Area-normalize each profile and optionally express it as a smoothed log2 fold change.

    Parameters
    ----------
    array_set : list of [array, name] pairs
        Profiles to normalize; ``name`` identifies each entry.
    normalization_genotype : str
        Name of the entry used as the reference. IndexError is raised if no
        entry has this name.
    log2_fc : bool
        If True, each area-normalized profile is divided element-wise by the
        area-normalized reference, log2-transformed and smoothed with a
        1000-point moving average (``np.convolve(..., mode='same')``). If
        False, only the area normalization (sum == 1) is applied.

    Returns
    -------
    list of [array, name] pairs in the same order as ``array_set``.
    """
    reference_matches = [entry[0] for entry in array_set if entry[1] == normalization_genotype]
    reference = reference_matches[0]
    reference = reference / np.sum(reference)

    smoothing_kernel = np.ones(1000) / 1000

    results = []
    for entry in array_set:
        name = entry[1]
        scaled = entry[0] / np.sum(entry[0])

        if log2_fc:
            # NOTE(review): bins that are zero give log2(0) = -inf or 0/0 = NaN,
            # which the moving average then spreads to neighbouring bins.
            fold_change = scaled / reference
            scaled = np.convolve(np.log2(fold_change), smoothing_kernel, mode='same')

        results.append([scaled, name])

    return results

def find_sig_intervals(df_genotype, array_range: tuple = (0, 20001)):
    """Split rows into consecutive runs of significant / non-significant positions.

    Rows are scanned in their current order. Each maximal run of rows sharing
    the same truthiness of ``is_significant`` becomes one
    ``(first Index, last Index)`` tuple. ``array_range`` is accepted for
    compatibility but is not used.

    Returns
    -------
    tuple of two lists
        ``(sig_intervals, not_sig_intervals)``, each a list of ``(start, end)``
        tuples in row order.
    """
    sig_intervals = []
    not_sig_intervals = []

    run_flag = None          # truthiness of the run being built (None = no run yet)
    run_start = run_end = None

    for flag, idx in zip(df_genotype['is_significant'], df_genotype['Index']):
        flag = bool(flag)
        if flag is run_flag:
            # Same kind as the current run: just extend it.
            run_end = idx
            continue
        if run_flag is not None:
            (sig_intervals if run_flag else not_sig_intervals).append((run_start, run_end))
        run_flag, run_start, run_end = flag, idx, idx

    # Flush the final run.
    if run_flag is not None:
        (sig_intervals if run_flag else not_sig_intervals).append((run_start, run_end))

    return sig_intervals, not_sig_intervals

def plot_resection_vs_bir(normalized_arrays, df_counts_test, sample_set, log2_fc, legend_dict):
    """Plot normalized profiles per genotype, with significant stretches drawn solid.

    Parameters
    ----------
    normalized_arrays : list of [array, name] pairs
        Output of normalize_arrays; ``name`` must be a key of ``legend_dict``.
        Entries whose name contains 'ung1∆NAT' are skipped.
    df_counts_test : DataFrame
        Chi-square results with ``Genotype``, ``Mutation_Set``, ``Index`` and
        ``is_significant`` columns. Rows are matched on the display genotype
        (``legend_dict[name]``) and ``sample_set``.
    sample_set : str
        Mutation set being plotted; also used in the output file name.
    log2_fc : bool
        Whether ``normalized_arrays`` hold log2 fold changes (y label
        'log2(FC)...') or area-normalized profiles (y label 'Normalized Area').
    legend_dict : dict
        Array name -> display genotype; also used to relabel the legend.

    The reference genotype 'ung1∆' is drawn dashed: as the constant 0 line in
    log2-FC mode (log2 of a ratio with itself), and as its own profile in
    normalized-area mode. The figure is saved to
    ``plots/{sample_set}_area_normalized_FC_{log2_fc}`` and shown.
    """
    sns.set_context("poster")
    fig, ax = plt.subplots(figsize=(8.5, 6))

    # Colour per display genotype.
    palette_dict = {
        'ung1∆': 'tab:blue',
        'ung1∆NAT': 'tab:green',
        'ung1∆exo1-nd': 'tab:red',
        'ung1∆pol32∆': 'tab:orange',
        'ung1∆exo1-ndpol32∆': 'tab:purple',
        'sgs1∆C': 'tab:brown',
        'ung1∆exo1-ndsgs1∆C': 'tab:pink'
    }

    for array in normalized_arrays:
        values, array_name = array[0], array[1]
        if 'ung1∆NAT' in array_name:
            continue

        genotype = legend_dict[array_name]
        color = palette_dict[genotype]

        df_sig = df_counts_test[(df_counts_test['Genotype'] == genotype)
                                & (df_counts_test['Mutation_Set'] == sample_set)]
        sig_list, non_sig_list = find_sig_intervals(df_sig)

        # Legend entry only; the curve itself is drawn in pieces below.
        ax.plot([], [], label=array_name, linestyle='-', color=color)

        if genotype == 'ung1∆':
            x = range(len(values))
            if log2_fc:
                # The reference against itself is log2(1) = 0 everywhere.
                ax.plot(x, np.zeros(len(values)), color=color, linestyle='--', linewidth=2)
            else:
                # In area-normalized mode the reference has its own profile; a flat
                # 0 line would hide it.
                ax.plot(x, values, color=color, linestyle='--', linewidth=2)
            continue

        for start, end in sig_list:
            ax.plot(range(start, end), values[start:end], linestyle='-', color=color)
        for start, end in non_sig_list:
            ax.plot(range(start, end), values[start:end], linestyle=':', color=color, alpha=0.5, linewidth=2)

    # Relabel the x axis from bp index to kb relative to the axis (index 10000).
    ax.set_xticks([0, 5000, 10000, 15000, 20000])
    ax.set_xticklabels([-10, -5, 0, 5, 10])
    ax.axvline(x=10000, color='black', linestyle='--', linewidth=1)

    if log2_fc:
        ax.set_ylabel('log2(FC) vs. $\\it{ung1∆}$')
    else:
        ax.set_ylabel('Normalized Area')
        ax.yaxis.set_ticks([])

    # Shade the two halves of the window differently.
    ax.axvspan(0, 10000, color='tab:green', alpha=0.1)
    ax.axvspan(10000, 20000, color='tab:pink', alpha=0.1)

    # No padding on the x axis.
    ax.margins(x=0)

    handles, labels = ax.get_legend_handles_labels()
    labels = [legend_dict[label] for label in labels]
    ax.legend(handles, labels, prop={'style': 'italic', 'size': 12})
    plt.tight_layout()
    plt.savefig(f'plots/{sample_set}_area_normalized_FC_{log2_fc}', dpi=300)
    plt.show()

def plot_proportion_resection_or_bir(array_set, df_counts_test, sample_set, legend_dict, proportion_of: Literal['resection', 'BIR'] = "resection"):
    """Plot, per genotype, the fraction of signal on the resection or BIR side of the axis.

    Each profile in ``array_set`` ([array, name] pairs, name a key of
    ``legend_dict``) is folded at index 10000. ``array[:10001]`` reversed is the
    resection side and ``array[10000:]`` is the BIR side, both starting at the
    axis. ``proportion_of`` selects which side is divided by their sum. Entries
    whose name contains 'ung1∆NAT' are skipped.

    The reference 'ung1∆' is drawn dashed. Other genotypes are drawn dotted, with
    significant stretches (from ``df_counts_test``, whose ``Index`` is
    ``Position + 10000``) overdrawn solid on the positive side. The figure is
    saved to ``plots/{sample_set}_{proportion_of}_proportion`` and shown.

    An invalid ``proportion_of`` prints an error and returns None before any
    figure is created.
    """
    # Validate first so an invalid argument does not leave an empty figure open.
    if proportion_of not in ('resection', 'BIR'):
        print("Error in the proportion_of argument")
        return

    sns.set_context("poster")
    fig, ax = plt.subplots(figsize=(8.5, 6))

    # Colour per display genotype.
    palette_dict = {
        'ung1∆': 'tab:blue',
        'ung1∆NAT': 'tab:green',
        'ung1∆exo1-nd': 'tab:red',
        'ung1∆pol32∆': 'tab:orange',
        'ung1∆exo1-ndpol32∆': 'tab:purple',
        'sgs1∆C': 'tab:brown',
        'ung1∆exo1-ndsgs1∆C': 'tab:pink'
        }

    for array_pack in array_set:
        array = array_pack[0]
        array_name = array_pack[1]

        # Array names carry a "_<mutation set>" suffix, so an exact comparison with
        # 'ung1∆NAT' never matched; use the same substring test as plot_resection_vs_bir.
        if 'ung1∆NAT' in array_name:
            continue

        array_resection = np.flip(array[:10001])
        array_bir = array[10000:]
        numerator = array_resection if proportion_of == 'resection' else array_bir
        array_to_plot = numerator / (array_resection + array_bir)

        genotype = legend_dict[array_name]
        color = palette_dict[genotype]

        # Legend entry only.
        ax.plot([], [], label=genotype, linestyle='-', color=color)

        if genotype == 'ung1∆':
            ax.plot(array_to_plot, linestyle='--', color=color)
            continue

        ax.plot(array_to_plot, linestyle=':', color=color, alpha=0.5, linewidth=2)

        df_sig = df_counts_test[(df_counts_test['Genotype'] == genotype)
                                & (df_counts_test['Mutation_Set'] == sample_set)]
        sig_list, non_sig_list = find_sig_intervals(df_sig)

        for start, end in sig_list:
            # Index -> distance from the axis; only the positive side is plotted.
            start, end = start - 10000, end - 10000
            if start < 0 and end < 0:
                continue
            start = max(start, 0)
            ax.plot(range(start, end), array_to_plot[start:end], linestyle='-', color=color)

    ax.set_ylabel(f'Proportion of {proportion_of} tracts', fontsize=18, fontweight='bold')
    ax.set_xticks([0, 2000, 4000, 6000, 8000, 10000])
    ax.set_xticklabels([0, 2, 4, 6, 8, 10])
    ax.set_ylim(0, None)

    # No padding on the x axis.
    ax.margins(x=0)

    ax.set_xlabel('Distance (kb) from DSB axis', fontsize=18, fontweight='bold')
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles, labels, prop={'style': 'italic', 'size': 12})
    plt.tight_layout()
    plt.savefig(f'plots/{sample_set}_{proportion_of}_proportion', dpi=300)
    plt.show()

def transform_tracks(df_t: pd.DataFrame) -> pd.DataFrame:
    """Fold tract intervals onto the non-negative side of the DSB axis.

    For each row of ``df_t``:
      * an interval crossing 0 (start < 0 < end) becomes two rows:
        ``[0, |start|]`` (the mirrored negative part) followed by ``[0, end]``;
      * an interval with both ends <= 0 is mirrored to ``[min(|s|,|e|), max(|s|,|e|)]``;
      * any other interval is kept unchanged.

    All other columns are copied unchanged. NOTE(review): this includes
    ``'All_Positions: '`` when present, so position lists are not mirrored or
    split together with the interval bounds; confirm whether that is intended.

    Returns a new DataFrame with a fresh 0..n-1 index.
    """
    # Rows are collected and the frame is built once (concat in the loop is quadratic).
    rows = []

    for _, row in df_t.iterrows():
        record = row.to_dict()
        start, end = record["Region_Start"], record["Region_End"]

        if start < 0 and end > 0:
            rows.append({**record, "Region_Start": 0, "Region_End": abs(start)})
            rows.append({**record, "Region_Start": 0})
        elif start <= 0 and end <= 0:
            low, high = sorted((abs(start), abs(end)))
            rows.append({**record, "Region_Start": low, "Region_End": high})
        else:
            rows.append(record)

    return pd.DataFrame(rows)

# Sample IDs excluded from every analysis below (filtered with
# `~df['Sample'].isin(rtg_list)`). Per the inline notes these are samples
# suspected to be RTG / non-sporulated, or spo13 spore dissections of
# uncertain status. Every entry is exactly 8 characters, presumably a
# truncated form of the full sample name; verify against the 'Sample'
# column of the input CSVs.
rtg_list = [
    "JK_92_CA",
    "JK2_48_G",
    "JK2_55_C",
    "JK2_66_T",
    "JK2_97_C",
    "JK2_99_G",
    "JK2_100_",

    "JK3_E3_A",
    "JK3_E11_",

    "JK3_P1_C",
    "JK3_P3_C",
    "JK3_P4_C",
    "JK3_P6_G",
    "JK3_P7_G",
    "JK3_P8_G",
    "JK3_P10_",
    "JK3_P11_",
    "JK3_P13_",
    "JK3_P15_",
    "JK3_P17_",
    "JK3_P18_",
    "JK3_P19_",
    "JK3_P20_",
    "JK3_P21_",
    "JK3_P22_",
    "JK3_P23_",
    "JK3_P24_",
    "JK3_P25_",
    "JK3_P26_",
    "JK3_P27_",
    "JK3_P28_",
    "JK3_P32_",
    "JK3_P33_",
    "JK3_P34_",

    "JK5_41-3",
    "JK5_42-3",
    "JK5_43-3",
    "JK5_44-3",
    "JK5_45-3",

    "JMT41_CK", #this is a dissection of spo13 spores, real rtg?
    "JMT42_CK", #this is a dissection of spo13 spores, real rtg?
    "JMT43_CK", #this is a dissection of spo13 spores, real rtg?
    "JMT44_CK", #this is a dissection of spo13 spores, real rtg?
    "JMT45_CK", #this is a dissection of spo13 spores, real rtg?
    "JMT46_CK", #this is a dissection of spo13 spores, real rtg?
    "JMT47_CK", #this is a dissection of spo13 spores, real rtg?
    "JMT48_CK", #this is a dissection of spo13 spores, real rtg?

    "SD3_CKDN" ,#seems like RTG/non-sporulated
    "A4T2_CKD" ,#seems like RTG/non-sporulated
    "A4T3_CKD" ,#seems like RTG/non-sporulated
    "A4T4_CKD" ,#seems like RTG/non-sporulated

    "A4B1_CKD" ,#seems like RTG/non-sporulated
    "A4B2_CKD" ,#seems like RTG/non-sporulated
    "A4B3_CKD" ,#seems like RTG/non-sporulated
    "A4B4_CKD" ,#seems like RTG/non-sporulated
    
    "JMTD_4_C",

    "JT1_CKDN",
    "JT3_CKDN",
    "JT5_CKDN",
    "JT10_CKD",
    "JT29_CKD",

    "JT40_CKD",
    "JT45_CKD",
    "JT48_CKD",
    "JT52_CKD",
    "JT56_CKD",
    "JT58_CKD",
    "JT59_CKD",

    "Mal_JK15",
    "Mal_JK18",
    "Mal_JK20",
    "Mal_JK34",

    "Sgs-10_R",
    "Sgs-13_R",
    "Sgs-21_R",
    "Sgs-2_R1",

    "SgsEx-26",
    "SgsEx-36",
    "SgsEx-40",
    "SgsEx-5_"

    ]

In [None]:
sns.set_context("poster")

arrays = []
types = ['clustered', 'scattered'] #, 'all'
genotypes_raw =['ung1∆', 'ung1∆NAT', 'exo1-nd', 'pol32∆', 'exo1-ndpol32∆', 'sgs1∆C', 'exo1-ndsgs1∆C']
# Display label for each raw genotype (same order as genotypes_raw). These labels
# are stored in df_counts['Genotype']. The later plotting cell filters df_counts
# with its own legend labels, and the plotting functions colour by palette_dict;
# both spell the sgs1 mutant 'sgs1∆C', so it must be spelled the same way here.
# Otherwise its significance rows never match and no significant stretches are drawn.
genotypes_final=['ung1∆', 'ung1∆NAT', 'ung1∆exo1-nd', 'ung1∆pol32∆', 'ung1∆exo1-ndpol32∆', 'sgs1∆C', 'ung1∆exo1-ndsgs1∆C']

# raw genotype -> display label (loop-invariant, so built once)
genotype_dict = dict(zip(genotypes_raw, genotypes_final))

df_counts = pd.DataFrame(columns=['Position', 'Resection', 'BIR', 'Genotype'])

# Every figure below is written under plots/; create it so savefig cannot fail on a fresh checkout.
from pathlib import Path
Path('plots').mkdir(exist_ok=True)

for dtype in types:

    df_path = f'relative_mutations_GC_CO_{dtype}_new.csv'

    df = pd.read_csv(df_path, sep='\t', encoding='utf-16')

    # keep mutations within ±10 (kb, per the plot axis labels) of the DSB axis
    df = df[(df['Region'] >= -10) & (df['Region'] <= 10)]

    df = process_df(df)
    df = add_genotype_column(df)

    # make sure the dataframe only contains the samples that are not in the rtg_list
    df = df[~df['Sample'].isin(rtg_list)]

    # dropna(): samples without a genotype mapping would otherwise produce a NaN
    # "genotype" that matches no rows and then fails the genotype_dict lookup.
    for genotype in df['Genotype'].dropna().unique():
        print(f"PLOTTING GENOTYPE: {genotype} {dtype}")
        genotype_label = genotype_dict[genotype]
        genotype_df = df[df['Genotype'] == genotype].copy()

        # stacked horizontal tract plot
        df_t = rank_tracks(genotype_df)
        plot_tracks(df_t, save_name=f"plots/{genotype}_{dtype}.png", genotype_label=genotype_label)

        # coverage (magnitude) arrays; also appends per-position counts to df_counts
        magnitude_CT, magnitude_GA, magnitude_sum, df_counts = build_df_magnitude(
            df=genotype_df,
            df_counts=df_counts,
            genotype_label=genotype_label,
            mutation_set=dtype)
        arrays.append([magnitude_sum, f"{genotype}_{dtype}"])

        # magnitude plots
        plot_magnitude_separate(magnitude_CT, magnitude_GA, save_name=f"plots/{genotype}_{dtype}_magnitude.png", genotype_label=genotype_label)
        plot_magnitude_sum(magnitude_sum, save_name=f"plots/{genotype}_{dtype}_magnitude_sum.png", genotype_label=genotype_label)

# Split profiles by mutation set using the "_<set>" suffix appended above
# (a bare substring test could match genotype names as well).
arrays_clustered = [array for array in arrays if array[1].endswith('_clustered')]
arrays_scattered = [array for array in arrays if array[1].endswith('_scattered')]
arrays_all = [array for array in arrays if array[1].endswith('_all')]


In [None]:
# Mutation set to analyse: 'clustered', 'scattered' (or 'all').
sample_set = 'scattered'
log2_fc = True
array_set = {'clustered': arrays_clustered, 'scattered': arrays_scattered}.get(sample_set, arrays_all)

# Per-position chi-square test for the selected mutation set (results also written to a CSV).
df_counts_test = conduct_chi2_test_by_position(
    df_counts[df_counts['Mutation_Set'] == sample_set].copy(),
    print_results=False,
    save_name=f"chi2_test_results_{sample_set}_100bp_inc.csv",
)
# Index = Position + 10000, i.e. the position as an index into the profile arrays.
df_counts_test['Index'] = df_counts_test['Position'] + 10000

normalized_arrays = normalize_arrays(array_set, normalization_genotype=f'ung1∆_{sample_set}', log2_fc=log2_fc)

# Legend: array name (raw genotype + "_<set>") -> display genotype.
genotypes_raw = ['ung1∆', 'ung1∆NAT', 'exo1-nd', 'pol32∆', 'exo1-ndpol32∆', 'sgs1∆C', 'exo1-ndsgs1∆C']
genotypes_final = ['ung1∆', 'ung1∆NAT', 'ung1∆exo1-nd', 'ung1∆pol32∆', 'ung1∆exo1-ndpol32∆', 'sgs1∆C', 'ung1∆exo1-ndsgs1∆C']
legend_dict = {f'{raw}_{sample_set}': final for raw, final in zip(genotypes_raw, genotypes_final)}

plot_resection_vs_bir(normalized_arrays, df_counts_test, sample_set, log2_fc, legend_dict)
plot_proportion_resection_or_bir(array_set, df_counts_test, sample_set, legend_dict, proportion_of='BIR')



Calculating the average amount of mutagenic ssDNA (total tract length) per sample

In [None]:
arrays = []
types = ['clustered', 'scattered'] #, 'all'
genotypes_raw =['ung1∆', 'ung1∆NAT', 'exo1-nd', 'pol32∆', 'exo1-ndpol32∆', 'sgs1∆C', 'exo1-ndsgs1∆C']
genotypes_final=['ung1∆', 'ung1∆NAT', 'ung1∆exo1-nd', 'ung1∆pol32∆', 'ung1∆exo1-ndpol32∆', 'sgs1∆C', 'ung1∆exo1-ndsgs1∆C']

# Only this genotype is summarised.
genotype_for_analysis = 'ung1∆'

df_counts = pd.DataFrame(columns=['Position', 'Resection', 'BIR', 'Genotype'])

for dtype in types:

    print(f"PROCESSING: {dtype}")

    # NOTE(review): the plotting cell reads relative_mutations_GC_CO_{dtype}_new.csv
    # and keeps ±10; this cell reads the file without "_new" and keeps ±7.5.
    # Confirm both choices are intended.
    df_path = f'relative_mutations_GC_CO_{dtype}.csv'

    df = pd.read_csv(df_path, sep='\t', encoding='utf-16')

    df = df[(df['Region'] >= -7.5) & (df['Region'] <= 7.5)]

    df = process_df(df)

    df = add_genotype_column(df)

    df = df[df['Genotype'] == genotype_for_analysis].copy()

    # Drop rtg_list samples. Take an explicit copy because a Length column is
    # added below and should not be written into a filtered slice.
    df = df[~df['Sample'].isin(rtg_list)].copy()

    print(df)

    # number of distinct samples (NaN counted like unique() would)
    unique_samples = df['Sample'].nunique(dropna=False)

    print(f"Number of unique samples: {unique_samples}")

    genotype_dict = dict(zip(genotypes_raw, genotypes_final))

    if unique_samples == 0:
        # Nothing to average; skip instead of printing nan per-sample lengths.
        print(f"No {genotype_for_analysis} samples left for {dtype}; skipping length summary.")
        continue

    # length of each tract interval
    df['Length'] = df['Region_End'] - df['Region_Start']

    total_length = df['Length'].sum()
    total_length_per_sample = total_length / unique_samples

    # split by event type (gene conversion vs crossover, per the 'GC'/'CO' labels)
    df_just_gc = df[df['Event'] == 'GC'].copy()
    df_just_co = df[df['Event'] == 'CO'].copy()

    total_length_gc = df_just_gc['Length'].sum()
    total_length_co = df_just_co['Length'].sum()

    print(f"Total length: {total_length}")
    print(f"Length per sample: {total_length_per_sample}")
    print(f"Total length GC: {total_length_gc}. Per sample: {total_length_gc / unique_samples}")
    print(f"Total length CO: {total_length_co}. Per sample: {total_length_co / unique_samples}")

