## Single Antibody DMS

In [None]:
#requirements

In [None]:
#Loading necessary libraries
import os
import random
from Bio import SeqIO
from IPython.display import display, Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import dmslogo
import seaborn as sns

In [None]:
!pip install -U altair_viewer
!pip install altair
!pip install selenium
!pip install webdriver_manager

In [None]:
#initial loading and editing of counts table

In [None]:
import pandas as pd

file_path = r'/Users/lucaschlotheuber/Desktop/ETH/summary_DMS_cleaned.xlsx'

# Read the full DMS summary table (all columns) for a first look at its layout
df = pd.read_excel(file_path)

# Show the first 5 rows with column titles as a rich table instead of plain-text print output
display(df.head())


In [None]:
#Reference

In [None]:
fasta_file = r'/Users/lucaschlotheuber/Desktop/ETH/RBD201_DMS1.fa'

# Read the first (and expected only) record of the reference FASTA file.
# The explicit `with` closes the handle even though only one record is consumed,
# and an empty file raises instead of leaving `wuhan_sequence` undefined or stale
# from a previous kernel run.
with open(fasta_file) as handle:
    reference_record = next(SeqIO.parse(handle, "fasta"), None)
if reference_record is None:
    raise ValueError(f"No FASTA records found in {fasta_file}")
wuhan_sequence = str(reference_record.seq)


In [None]:
#Loading the data
#Here we use the modified DMS file containing additional identifiers such as immunization, condition and immunization status
file_path = r'/Users/lucaschlotheuber/Desktop/ETH/summary_DMS_cleaned.xlsx'

df_total = pd.read_excel(file_path, usecols=["DMS_RBD_AS_position", "Spike_AS_Position", "Count_of_Base", "Amino_Acid", "Type_of_Mutation", "Enrichment_Ratio", "barcode", "immunization", "condition", "Total_Reads"])

#In the original calculation SARS-CoV-2 Spike positions were annotated incorrectly (+5 AA)
df_total["Spike_AS_Position"] = df_total["Spike_AS_Position"] - 5

#Remove datapoints with "inf" (~5000 in the original data) or missing enrichment ratios, and rows without an amino acid.
#dropna() only removes NaN, so +/-inf is converted to NaN first.
df_total["Enrichment_Ratio"] = df_total["Enrichment_Ratio"].replace([np.inf, -np.inf], np.nan)
df_total = df_total.dropna(subset=['Enrichment_Ratio','Amino_Acid'])

#Discard rows with 1,000 or fewer total reads
df_total = df_total[df_total["Total_Reads"] > 1000]

#Append the Wuhan reference sequence as one pseudo-row per residue
immunization = "Wuhan_Sequence"
barcode = "Wuhan_Barcode"

data_wuhan = []
for position, amino_acid in enumerate(wuhan_sequence, start=1):
    data_wuhan.append({
        'DMS_RBD_AS_position': position,
        'Spike_AS_Position': position + 330,  # RBD position 1 -> Spike position 331
        'Amino_Acid': amino_acid,
        'immunization': immunization,
        'barcode': barcode,
        'Enrichment_Ratio': 1,  # Assuming an enrichment ratio of 1 for simplicity
    })
# NOTE: these reference rows carry no Type_of_Mutation / Total_Reads (NaN after the concat)

## Create a DataFrame
df_wuhan = pd.DataFrame(data_wuhan)

df_total = pd.concat([df_total, df_wuhan], ignore_index=True)


In [None]:
#Initial averaging and grouping

In [None]:
# Mean and spread of Count_of_Base / Enrichment_Ratio per mutation type (SYNOM vs NON-SYNOM).
# groupby drops rows whose Type_of_Mutation is missing, so they belong to no group.
by_mutation_type = df_total.groupby('Type_of_Mutation')
average_counts = by_mutation_type['Count_of_Base'].mean()
average_enrichment = by_mutation_type['Enrichment_Ratio'].mean()

# 'N/A' is printed whenever a mutation type is absent from the data
for mutation_type in ('SYNOM', 'NON-SYNOM'):
    print(f"Average Count_of_Base for {mutation_type} group:", average_counts.get(mutation_type, 'N/A'))

for mutation_type in ('SYNOM', 'NON-SYNOM'):
    print(f"Average Enrichment_Ratio for {mutation_type} group:", average_enrichment.get(mutation_type, 'N/A'))

# Same Count_of_Base means, split by immunization group (missing labels are skipped)
immunization_groups = df_total['immunization'].unique()
for group in immunization_groups:
    if pd.isna(group):
        continue
    group_df = df_total.loc[df_total['immunization'] == group]
    group_avg_counts = group_df.groupby('Type_of_Mutation')['Count_of_Base'].mean()
    synom_avg = group_avg_counts.get('SYNOM', 'N/A')
    non_synom_avg = group_avg_counts.get('NON-SYNOM', 'N/A')
    print(f"\nAverage Count_of_Base for {group} - SYNOM group:", synom_avg)
    print(f"Average Count_of_Base for {group} - NON-SYNOM group:", non_synom_avg)

# Standard deviation of Count_of_Base per mutation type
std_devs = by_mutation_type['Count_of_Base'].std()
for mutation_type in ('SYNOM', 'NON-SYNOM'):
    print(f"Standard Deviation of Count_of_Base for {mutation_type} group:", std_devs.get(mutation_type, 'N/A'))


The following code generates line plots for each (grouped) condition (neutralizing Ab, polyclonal Ab, "Library", and cells from mice immunized with either WT RBD or mutant RBD). Single-cell data are aggregated per condition.

In [None]:
#Raw data without separating into binding and escape fractions

In [None]:
#CODE FOR GENERATING LINEPLOTS WITH ENTIRE SEQUENCE, AND WITH HIGHLIGHTS
#For individual droplet analysis, loop/aggregate over 'barcode' instead of 'immunization'
#(NOTE(review): the original comment read "swap out barcode for barcode" - confirm intent)

#Non-synonymous mutations only; the first 33 positions are removed due to bad read quality
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

#Spike positions highlighted in the line plots. A set of ints (rather than a one-shot `map`
#iterator of strings) can be reused and is compared directly with the numeric position column.
#The commented-out full-range variant excluded the positions that showed high enrichment ratios
#in the library droplets (pos 33, 72, 81 and 151).
sites_to_show = set(
    #[(i+331) for i in range(30, 200) if i not in [33, 72, 81, 151]] +
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] + #RBD-ACE2 interface according to article
    list(range(394, 414)) + #R21 peptide sequence with high affinity
    list(range(484, 505)) #R13 peptide sequence with high affinity
)
df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show)
)

output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

#show_site depends only on the position, so keep exactly one flag per position.
#Merging the full df_filtered_agg (one row per AA x barcode x immunization) would
#duplicate every position row many times over.
site_flags = df_filtered_agg[['Spike_AS_Position', 'show_site']].drop_duplicates('Spike_AS_Position')

for immunization in df_filtered_agg['immunization'].unique():
    print(immunization)
    #Boolean mask instead of a string-built query, so names containing quotes cannot break it
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Aggregate the data to ensure unique Spike_AS_Position values
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'sum',
    })

    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    #Contiguous x-axis from this immunization's first position to the last filtered position
    #overall; positions without data get NaN enrichment
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(range(df_filtered_im['Spike_AS_Position'].min(), df_filtered['Spike_AS_Position'].max() + 1)).reset_index()

    # Attach the one-per-position show_site flag; positions without a flag are not highlighted
    df_filtered_im = df_filtered_im.merge(site_flags, on='Spike_AS_Position', how='left')
    df_filtered_im['show_site'] = df_filtered_im['show_site'].eq(True)
    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(20))  # Print first few rows to check if show_site is properly set

    fig, ax = dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Enrichment_Ratio",
        title=immunization,
        xlabel="Spike AA Position",
        ylabel="Enrichment Ratio",
        show_col="show_site"
    )

    # Save the figure
    #file_path = os.path.join(r"C:\Users\au649453\OneDrive - Aarhus universitet\PhD\Luca\DMS_plots\Enriched and targeted positions", f"{immunization}_lineplots.png")
    #plt.savefig(file_path, dpi = 300, bbox_inches = 'tight')
    #plt.close(fig)
    #plt.show()


In [None]:
#Initial heatmaps (raw data)

In [None]:
import os
import numpy as np
import pandas as pd
from Bio import SeqIO
import altair as alt

# Load FASTA sequence (Wuhan reference): first record only. The explicit handle is closed
# after reading, and an empty file raises instead of leaving a stale sequence behind.
fasta_file = r'/Users/lucaschlotheuber/Desktop/ETH/RBD201_DMS1.fa'
with open(fasta_file) as handle:
    reference_record = next(SeqIO.parse(handle, "fasta"), None)
if reference_record is None:
    raise ValueError(f"No FASTA records found in {fasta_file}")
wuhan_sequence = str(reference_record.seq)

# Load and clean the Excel data
file_path = r'/Users/lucaschlotheuber/Desktop/ETH/summary_DMS_cleaned.xlsx'
df_total = pd.read_excel(file_path, usecols=[
    "DMS_RBD_AS_position", "Spike_AS_Position", "Count_of_Base",
    "Amino_Acid", "Type_of_Mutation", "Enrichment_Ratio",
    "barcode", "immunization", "condition", "Total_Reads"
])
df_total["Spike_AS_Position"] -= 5  # Adjust for 336 -> 331

# Clean up: remove NaN/inf ratios, low reads, and stop codons.
# dropna() does not catch +/-inf, so those are converted to NaN first.
df_total["Enrichment_Ratio"] = df_total["Enrichment_Ratio"].replace([np.inf, -np.inf], np.nan)
df_total = df_total.dropna(subset=['Enrichment_Ratio','Amino_Acid'])
df_total = df_total[df_total["Total_Reads"] > 1000]
df_total = df_total[df_total["Amino_Acid"] != '*']  # Exclude stop codons

# Add Wuhan reference line (assume enrichment ratio of 1)
immunization = "Wuhan_Sequence"
barcode = "Wuhan_Barcode"
data_wuhan = [{
    'DMS_RBD_AS_position': pos,
    'Spike_AS_Position': pos + 330,
    'Amino_Acid': aa,
    'immunization': immunization,
    'barcode': barcode,
    'Enrichment_Ratio': 1,
} for pos, aa in enumerate(wuhan_sequence, start=1) if aa != '*']
df_wuhan = pd.DataFrame(data_wuhan)
df_total = pd.concat([df_total, df_wuhan], ignore_index=True)

# Keep non-synonymous mutations and good regions only (Spike position > 364).
# .copy() makes an independent frame, so adding a column below does not raise
# SettingWithCopyWarning.
df_filtered = df_total[
    (df_total['Spike_AS_Position'] > 364) &
    (df_total['Type_of_Mutation'] == 'NON-SYNOM')
].copy()

# Mutation label (optional, not plotted here)
df_filtered['mutation'] = df_filtered['Amino_Acid'] + df_filtered['Spike_AS_Position'].astype(str)

# Mean enrichment ratio per (position, mutated AA, immunization) across barcodes
df_heatmap = df_filtered.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False
).agg({'Enrichment_Ratio': 'mean'})

# log2 is only defined for positive ratios; copy before adding the log2 column
df_heatmap = df_heatmap[df_heatmap['Enrichment_Ratio'] > 0].copy()
df_heatmap['log2_enrichment'] = np.log2(df_heatmap['Enrichment_Ratio']).round(3)


# Output directory
output_dir = "heatmap_output"
os.makedirs(output_dir, exist_ok=True)

# Plot heatmap for each immunization
for immun in df_heatmap['immunization'].unique():
    df_im = df_heatmap[df_heatmap['immunization'] == immun].copy()

    # Tick labels for every position (present in the data) divisible by 5
    tick_vals = [x for x in sorted(df_im['Spike_AS_Position'].unique()) if x % 5 == 0]

    # Create Altair heatmap
    heatmap = alt.Chart(df_im).mark_rect().encode(
        x=alt.X('Spike_AS_Position:O', title='Spike AA Position',
                axis=alt.Axis(values=tick_vals)),
        y=alt.Y('Amino_Acid:N', title='Mutated AA'),
        color=alt.Color('log2_enrichment:Q',
                        # Diverging scheme centred on log2 = 0 (ratio 1 = no change),
                        # like the symmetric colour limits used by the seaborn heatmaps
                        scale=alt.Scale(scheme='redblue', domainMid=0),
                        title='log₂ Enrichment'),

        # Tooltip fields must be actual columns of df_im
        tooltip=[
            'Spike_AS_Position:O',
            'Amino_Acid:N',
            alt.Tooltip('log2_enrichment:Q', title='Log2 Binding Ratio')
        ]
    ).properties(
        title=f"Mutational Scanning Heatmap - {immun}",
        width=800,
        height=400
    ).configure_axis(
        labelFontSize=12,
        titleFontSize=14
    ).configure_title(
        fontSize=18,
        anchor='start'
    )

    # Save HTML file
    output_file = os.path.join(output_dir, f"{immun}_heatmap.html")
    heatmap.save(output_file)

    print(f"Saved heatmap for {immun} to {output_file}")


In [None]:
import os
import pandas as pd
import numpy as np
from Bio import SeqIO
import matplotlib.pyplot as plt
import seaborn as sns

# Load FASTA sequence (Wuhan reference): first record only; an empty file raises
fasta_file = r'/Users/lucaschlotheuber/Desktop/ETH/RBD201_DMS1.fa'
with open(fasta_file) as handle:
    reference_record = next(SeqIO.parse(handle, "fasta"), None)
if reference_record is None:
    raise ValueError(f"No FASTA records found in {fasta_file}")
wuhan_sequence = str(reference_record.seq)

# Load and clean the Excel data
file_path = r'/Users/lucaschlotheuber/Desktop/ETH/summary_DMS_cleaned.xlsx'
df_total = pd.read_excel(file_path, usecols=[
    "DMS_RBD_AS_position", "Spike_AS_Position", "Count_of_Base",
    "Amino_Acid", "Type_of_Mutation", "Enrichment_Ratio",
    "barcode", "immunization", "condition", "Total_Reads"
])
df_total["Spike_AS_Position"] -= 5  # Adjust 336 -> 331

# Clean up (dropna() does not remove +/-inf, so convert those to NaN first)
df_total["Enrichment_Ratio"] = df_total["Enrichment_Ratio"].replace([np.inf, -np.inf], np.nan)
df_total = df_total.dropna(subset=['Enrichment_Ratio', 'Amino_Acid'])
df_total = df_total[df_total["Total_Reads"] > 1000]
df_total = df_total[df_total["Amino_Acid"] != '*']  # Exclude stop codons

# Add Wuhan reference
immunization = "Wuhan_Sequence"
barcode = "Wuhan_Barcode"
data_wuhan = [{
    'DMS_RBD_AS_position': pos,
    'Spike_AS_Position': pos + 330,
    'Amino_Acid': aa,
    'immunization': immunization,
    'barcode': barcode,
    'Enrichment_Ratio': 1,
} for pos, aa in enumerate(wuhan_sequence, start=1) if aa != '*']
df_wuhan = pd.DataFrame(data_wuhan)
df_total = pd.concat([df_total, df_wuhan], ignore_index=True)

# Filter
df_filtered = df_total[
    (df_total['Spike_AS_Position'] > 364) &
    (df_total['Type_of_Mutation'] == 'NON-SYNOM')
]

# Aggregate
df_heatmap = df_filtered.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False
).agg({'Enrichment_Ratio': 'mean'})

# Filter out zeros or negatives (log2 undefined); copy so new columns can be added safely
df_heatmap = df_heatmap[df_heatmap['Enrichment_Ratio'] > 0].copy()

# Fold-change magnitude: ratios in (0, 1) are inverted, so a 4-fold depletion and a 4-fold
# enrichment both give log2 = 2. The direction is therefore NOT encoded in log2_enrichment.
df_heatmap['Enrichment_Ratio_transformed'] = df_heatmap['Enrichment_Ratio'].where(
    df_heatmap['Enrichment_Ratio'] >= 1, 1 / df_heatmap['Enrichment_Ratio']
)
df_heatmap['log2_enrichment'] = np.log2(df_heatmap['Enrichment_Ratio_transformed']).round(3)

# Split by the ORIGINAL enrichment ratio: (0, 1] = lost/escape fraction, (1, max] = binding fraction.
# Splitting on log2_enrichment <= 0 would only catch ratios of ~1, because the inversion
# above makes every log2 value non-negative.
df_low = df_heatmap[df_heatmap['Enrichment_Ratio'] <= 1].copy()
df_high = df_heatmap[df_heatmap['Enrichment_Ratio'] > 1].copy()

# Output directory
output_dir = "heatmap_output_log2_split"
os.makedirs(output_dir, exist_ok=True)

# Function to generate and save log2-based heatmaps as PNG
def generate_heatmap(df_subset, enrichment_range_label, out_dir=None):
    """Save one log2-enrichment heatmap PNG per immunization in ``df_subset``.

    Parameters
    ----------
    df_subset : pandas.DataFrame
        Needs 'immunization', 'Amino_Acid', 'Spike_AS_Position' and 'log2_enrichment'
        columns; (Amino_Acid, Spike_AS_Position) must be unique within each
        immunization (DataFrame.pivot raises otherwise).
    enrichment_range_label : str
        Used in the plot title and (spaces -> underscores) in the file name.
    out_dir : str, optional
        Destination directory; defaults to the notebook-level ``output_dir``.
    """
    if out_dir is None:
        out_dir = output_dir

    # Make an empty subset visible instead of silently producing no files
    if df_subset.empty:
        print(f"No data for log2 {enrichment_range_label} heatmaps; nothing saved")
        return

    for immun in df_subset['immunization'].unique():
        df_im = df_subset[df_subset['immunization'] == immun]

        # Amino acids as rows, Spike positions as columns
        pivot_df = df_im.pivot(index="Amino_Acid", columns="Spike_AS_Position", values="log2_enrichment")

        # Explicit figure/axes so exactly this figure is closed after saving
        fig, ax = plt.subplots(figsize=(10, 4))
        sns.heatmap(pivot_df, cmap='RdBu_r', vmin=-3, vmax=3, cbar_kws={'label': 'log2 Enrichment'},
                    annot=False, linewidths=0.5, linecolor='gray', ax=ax)

        # Set labels and title
        ax.set_title(f"log₂ {enrichment_range_label} Heatmap - {immun}", fontsize=18)
        ax.set_xlabel('Spike AA Position', fontsize=14)
        ax.set_ylabel('Mutated AA', fontsize=14)
        ax.tick_params(axis='y', labelsize=8)

        # Save as PNG
        output_file_png = os.path.join(out_dir, f"{immun}_heatmap_log2_{enrichment_range_label.replace(' ', '_')}.png")
        fig.tight_layout()
        fig.savefig(output_file_png)
        plt.close(fig)  # do not accumulate open figures across the loop
        print(f"Saved log2 {enrichment_range_label} heatmap for {immun} to {output_file_png}")


# Generate both sets
generate_heatmap(df_low, "0_to_1")
generate_heatmap(df_high, "1_to_max")


In [None]:
import os
import pandas as pd
import numpy as np
from Bio import SeqIO
import matplotlib.pyplot as plt
import seaborn as sns

# Load FASTA sequence (Wuhan reference): first record only; an empty file raises
fasta_file = r'/Users/lucaschlotheuber/Desktop/ETH/RBD201_DMS1.fa'
with open(fasta_file) as handle:
    reference_record = next(SeqIO.parse(handle, "fasta"), None)
if reference_record is None:
    raise ValueError(f"No FASTA records found in {fasta_file}")
wuhan_sequence = str(reference_record.seq)

# Load and clean the Excel data
file_path = r'/Users/lucaschlotheuber/Desktop/ETH/summary_DMS_cleaned.xlsx'
df_total = pd.read_excel(file_path, usecols=[
    "DMS_RBD_AS_position", "Spike_AS_Position", "Count_of_Base",
    "Amino_Acid", "Type_of_Mutation", "Enrichment_Ratio",
    "barcode", "immunization", "condition", "Total_Reads"
])
df_total["Spike_AS_Position"] -= 5  # Adjust 336 -> 331

# Clean up (dropna() does not remove +/-inf, so convert those to NaN first)
df_total["Enrichment_Ratio"] = df_total["Enrichment_Ratio"].replace([np.inf, -np.inf], np.nan)
df_total = df_total.dropna(subset=['Enrichment_Ratio', 'Amino_Acid'])
# NOTE(review): read cutoff here is 100, versus 1000 in the other heatmap cells - confirm intended
df_total = df_total[df_total["Total_Reads"] > 100]
df_total = df_total[df_total["Amino_Acid"] != '*']  # Exclude stop codons

# Add Wuhan reference
immunization = "Wuhan_Sequence"
barcode = "Wuhan_Barcode"
data_wuhan = [{
    'DMS_RBD_AS_position': pos,
    'Spike_AS_Position': pos + 330,
    'Amino_Acid': aa,
    'immunization': immunization,
    'barcode': barcode,
    'Enrichment_Ratio': 1,
} for pos, aa in enumerate(wuhan_sequence, start=1) if aa != '*']
df_wuhan = pd.DataFrame(data_wuhan)
df_total = pd.concat([df_total, df_wuhan], ignore_index=True)

# Filter: positions > 364 only. Unlike the other heatmap cells, synonymous mutations
# and the Wuhan reference rows are NOT excluded here.
df_filtered = df_total[
    (df_total['Spike_AS_Position'] > 364)
]

# Aggregate per barcode (droplet) as well, so the plotting step can average across barcodes
df_heatmap = df_filtered.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False
).agg({'Enrichment_Ratio': 'mean'})

# Filter out zeros or negatives (log2 undefined); copy so new columns can be added safely
df_heatmap = df_heatmap[df_heatmap['Enrichment_Ratio'] > 0].copy()

# Invert values between 0 and 1, leave others unchanged: the resulting log2 values are the
# fold-change magnitude and are all >= 0 (direction of change is not encoded).
df_heatmap['Enrichment_Ratio_transformed'] = df_heatmap['Enrichment_Ratio'].where(
    df_heatmap['Enrichment_Ratio'] >= 1, 1 / df_heatmap['Enrichment_Ratio']
)

# Calculate log2 enrichment from transformed values
df_heatmap['log2_enrichment'] = np.log2(df_heatmap['Enrichment_Ratio_transformed']).round(3)

# Output directory
output_dir = "Heatmap_Combined"
os.makedirs(output_dir, exist_ok=True)

def generate_heatmap_per_immunization(df_subset, out_dir=None):
    """Save one combined log2-enrichment heatmap PNG per immunization.

    For each immunization, log2_enrichment is averaged over all barcodes per
    (Spike position, amino acid). Rows follow a fixed amino-acid order; amino acids
    outside that order, and rows without any data, are dropped.

    Parameters
    ----------
    df_subset : pandas.DataFrame
        Needs 'immunization', 'Spike_AS_Position', 'Amino_Acid' and
        'log2_enrichment' columns.
    out_dir : str, optional
        Destination directory; defaults to the notebook-level ``output_dir``.
    """
    if out_dir is None:
        out_dir = output_dir

    aa_order = ['R','K','H','D','E','Q','N','S','T','Y','W','F','A','I','L','M','V','G','P','C']

    for immun in df_subset['immunization'].unique():
        df_im = df_subset[df_subset['immunization'] == immun]

        # Average across all barcodes
        df_grouped = df_im.groupby(
            ['Spike_AS_Position', 'Amino_Acid'], as_index=False
        ).agg({'log2_enrichment': 'mean'})

        pivot_df = df_grouped.pivot(index="Amino_Acid", columns="Spike_AS_Position", values="log2_enrichment")

        # Reindex rows to match AA order and drop rows with all NaNs
        pivot_df = pivot_df.reindex(aa_order).dropna(how='all')

        # Once all-NaN rows are gone, an empty frame is the only "no data" case
        if pivot_df.empty:
            print(f"Skipped empty heatmap for immunization {immun}")
            continue

        # Explicit figure/axes so exactly this figure is closed after saving
        fig, ax = plt.subplots(figsize=(12, 4))
        sns.heatmap(
            pivot_df,
            cmap='RdBu_r',
            vmin=-6, vmax=6,
            cbar_kws={'label': 'log₂ Enrichment'},
            annot=False,
            linewidths=0,
            linecolor='gray',
            ax=ax
        )

        ax.set_title(f"log₂ Enrichment Heatmap - {immun}", fontsize=18)
        ax.set_xlabel('Spike AA Position', fontsize=14)
        ax.set_ylabel('Mutated AA', fontsize=14)
        ax.tick_params(axis='y', labelsize=8)

        output_file_png = os.path.join(
            out_dir,
            f"immun_{immun}_heatmap_log2_combined.png"
        )
        fig.tight_layout()
        fig.savefig(output_file_png)
        plt.close(fig)  # do not accumulate open figures across the loop
        print(f"Saved combined log2 heatmap for immunization {immun} to {output_file_png}")

# Generate single heatmap per sample
generate_heatmap_per_immunization(df_heatmap)



In [None]:
#Total (df_total) is the complete dataframe, while df_filtered refers to its filtered subset and df_filtered_agg / df_heatmap
#to aggregated dataframes; aggregated dataframes are pooled/grouped using the mean, median or sum of a category
#(such as the antibody binding ratio at a certain position across all barcodes).


In [None]:
# Column overview of the complete dataframe (rich display instead of print)
df_total.columns

In [None]:
# Sum the enrichment ratios per position / amino acid / droplet barcode / immunization
df_filtered_agg = (
    df_filtered
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False)['Enrichment_Ratio']
    .sum()
)


In [None]:
# Column overview of the aggregated dataframe (rich display instead of print)
df_filtered_agg.columns


In [None]:
#Enrichment ratios (ER) are calculated from mutation frequencies and can be split into a binding and an escape fraction:
#The binding fraction includes all ER > 1, as the mutation frequency in the droplet is higher than in the library
#The escape (or lost) fraction includes all ER < 1, as the mutation frequency in the droplet is lower than in the library

In [None]:
# To estimate the effect across a local epitope, rolling windows can be applied. Too much smoothing will distort the data

In [None]:
import os
import pandas as pd
import numpy as np
from scipy import interpolate

# Rolling window size, counted in rows of the per-position table (one row per Spike
# position with data). A window of 1 means NO smoothing.
ROLLING_WINDOW = 1
ENRICHMENT_THRESHOLD = 50  # Adjust based on data distribution
Y_AXIS_MAX = 350  # Fixed upper y-limit applied to every plot in this cell

# Drop the first 33 RBD positions (low read quality) and keep only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Aggregate enrichment ratio by position
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Define target sites (a reusable set of ints instead of a one-shot map iterator of strings)
sites_to_show = set(
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] + #RBD-ACE2 interface
    list(range(394,414)) + #R21 peptide sequence
    list(range(484, 505)) #R13 peptide sequence
)

df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show)
)

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

for immunization in df_filtered_agg['immunization'].unique():
    print(immunization)
    # Boolean mask instead of a string-built query (robust to quotes in names)
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Aggregate enrichment ratios at each position
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'sum',
    })

    # Apply rolling mean for smoothing
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()

    # Handle missing data by interpolating missing values (linear interpolation)
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

    # Flag positions above the threshold (a comparison with NaN yields False)
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD

    # Save CSV; the window size is part of the name so this cell does not overwrite the
    # "{immunization}_data.csv" files written by the raw line-plot cell
    csv_file_path = os.path.join(output_dir, f"{immunization}_data_window{ROLLING_WINDOW}.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex for visualization (contiguous x-axis up to the last filtered position overall)
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(), df_filtered['Spike_AS_Position'].max() + 1)
    ).reset_index()

    # Positions added by the reindex have no flag -> not highlighted
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].eq(True)

    # Plot with highlighted clusters
    fig, ax = dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title=f"{immunization} (Smoothed)",
        xlabel="Spike AA Position",
        ylabel="AB binding fraction \n Er",
        show_col="High_Enrichment"  # Highlight high enrichment clusters
    )

    ax.set_xlim(390, df_filtered_im['Spike_AS_Position'].max())

    # Data-driven limit suggestion from the visible x-range (reported only; the plot uses Y_AXIS_MAX)
    x_min, x_max = ax.get_xlim()
    filtered_data = df_filtered_im[(df_filtered_im['Spike_AS_Position'] >= x_min) & (df_filtered_im['Spike_AS_Position'] <= x_max)]
    y_min = 0  # Always start y-axis at 0
    y_max = filtered_data['Smoothed_Enrichment'].max() + 30

    ax.set_ylim(y_min, Y_AXIS_MAX)

    # Report the limits actually applied
    print(f"Setting y-axis limits: min={y_min}, max={Y_AXIS_MAX} (visible data max + 30 = {y_max})")

    # Save the plot as a PNG (window size in the name so other cells do not overwrite it)
    png_file_path = os.path.join(output_dir, f"{immunization}_plot_window{ROLLING_WINDOW}.png")
    fig.savefig(png_file_path, dpi=300, bbox_inches="tight")

    print(f"Saved plot to {png_file_path}")

In [None]:
#Smoothing with AA region = 10 & high enrichers (> 50 fold)

In [None]:
import os
import pandas as pd
import numpy as np
from scipy import interpolate

# Rolling window size, counted in rows of the per-position table (one row per Spike position with data)
ROLLING_WINDOW = 10  # Increased window for more smoothing (you can adjust this further)
ENRICHMENT_THRESHOLD = 50  # Adjust based on data distribution

# Drop the first 33 RBD positions (low read quality) and keep only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Aggregate enrichment ratio by position
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Define target sites (a reusable set of ints instead of a one-shot map iterator of strings)
sites_to_show = set(
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] + #RBD-ACE2 interface
    list(range(394,414)) + #R21 peptide sequence
    list(range(484, 505)) #R13 peptide sequence
)

df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show)
)

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

for immunization in df_filtered_agg['immunization'].unique():
    print(immunization)
    # Boolean mask instead of a string-built query (robust to quotes in names)
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Aggregate enrichment ratios at each position
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'sum',
    })

    # Apply rolling mean for smoothing
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()

    # Handle missing data by interpolating missing values (linear interpolation)
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

    # Flag positions above the threshold (a comparison with NaN yields False)
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD

    # Save CSV; the window size is part of the name so this cell does not overwrite the
    # same-named files of the other line-plot cells
    csv_file_path = os.path.join(output_dir, f"{immunization}_data_window{ROLLING_WINDOW}.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex for visualization (ensures continuous x-axis)
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(), df_filtered['Spike_AS_Position'].max() + 1)
    ).reset_index()

    # Interpolate missing values to ensure continuous line
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

    # Fill remaining NaNs with 0 to avoid gaps in the plot
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].fillna(0)

    # Ensure data is sorted before plotting
    df_filtered_im = df_filtered_im.sort_values(by='Spike_AS_Position')

    # Positions added by the reindex have no flag -> not highlighted
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].eq(True)

    # Plot with highlighted clusters
    fig, ax = dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title=f"{immunization} (Smoothed)",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding",
        show_col="High_Enrichment"  # Highlight high enrichment clusters
    )

    ax.set_xlim(390, df_filtered_im['Spike_AS_Position'].max())

    # Get the current visible x-range (the range of 'Spike_AS_Position' currently shown on the plot)
    x_min, x_max = ax.get_xlim()

    # Filter data based on the visible x-range
    filtered_data = df_filtered_im[(df_filtered_im['Spike_AS_Position'] >= x_min) & (df_filtered_im['Spike_AS_Position'] <= x_max)]

    # y-axis from 0 to the visible data maximum plus a margin of 30
    y_min = 0
    y_max = filtered_data['Smoothed_Enrichment'].max() + 30

    ax.set_ylim(y_min, y_max)

    print(f"Setting y-axis limits for visible data: min={y_min}, max={y_max}")

    # Save the plot as a PNG (window size in the name so other cells do not overwrite it)
    png_file_path = os.path.join(output_dir, f"{immunization}_plot_window{ROLLING_WINDOW}.png")
    fig.savefig(png_file_path, dpi=300, bbox_inches="tight")

    print(f"Saved plot to {png_file_path}")

    

## Analysis of antibody repertoires
Line plots include all droplets sequenced from a specific, sampled repertoire and their variation across droplets
(e.g. droplets generated with an anti-SARS-CoV-2 neutralizing clone).
The following graphs compare the line profiles of the complete, sampled repertoires of different conditions.

In [None]:
# Library_ctrl refers to a second library not used in binding experiments of these repertoires

In [None]:
#Filtering and aggregation can be included (optional)
#Enrichment threshold to filter out the escape fraction
#Sum aggregation to plot the total antibody ER per position (polyreactive repertoires often show a higher summed ER compared to
#repertoires with only specific mutations enriched)
# > 331 ensures that only reads/nts with high quality (Phred score) and inside the DMS library are included
# exclude synonymous mutations

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate

# Rolling window size, counted in rows of the per-position table (one row per Spike position with data)
ROLLING_WINDOW = 15
ENRICHMENT_THRESHOLD = 50  # not used in this cell

# Drop the first 33 RBD positions (low read quality) and keep only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Aggregate enrichment ratio by position
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create figure and axis for combined plot
fig, ax = plt.subplots(figsize=(10, 6))

# Loop through immunizations and plot each one
for immunization in df_filtered_agg['immunization'].unique():
    print(f"Processing: {immunization}")

    # Boolean mask instead of a string-built query (robust to quotes in names)
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Aggregate enrichment ratios at each position
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({'Enrichment_Ratio': 'sum'})

    # Apply rolling mean for smoothing
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1
    ).mean()

    # Handle missing data by interpolating missing values
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both'
    )

    # Save CSV; the window size is part of the name so the same-named files of the
    # other line-plot cells are not overwritten
    csv_file_path = os.path.join(output_dir, f"{immunization}_data_window{ROLLING_WINDOW}.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Ensure data is sorted before plotting
    df_filtered_im = df_filtered_im.sort_values(by='Spike_AS_Position')

    # Plot on the shared axis
    ax.plot(df_filtered_im['Spike_AS_Position'], df_filtered_im['Smoothed_Enrichment'], label=immunization, linewidth=2)

# Formatting the plot. The x-range uses the filtered data of ALL immunizations, not
# whichever immunization happened to be processed last in the loop.
ax.set_xlim(390, df_filtered['Spike_AS_Position'].max())
ax.set_ylim(0, 450)
ax.set_xlabel("Spike AA Position")
ax.set_ylabel("Antibody Repertoire \n Binding")
ax.set_title("Smoothed Enrichment Across Immunizations")
ax.legend(title="Immunization")

# Save the combined plot
png_file_path = os.path.join(output_dir, "combined_plot.png")
fig.savefig(png_file_path, dpi=300, bbox_inches="tight")
plt.show()

print(f"Saved combined plot to {png_file_path}")


In [None]:
# Setup for per-droplet (barcode) statistics of the summed enrichment ratio per Spike position.
# NOTE(review): the loop, Kruskal-Wallis test and plotting that consume the names defined
# here continue beyond the visible section - confirm usage there.
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import kruskal

# Define constants
ROLLING_WINDOW = 15
ENRICHMENT_THRESHOLD = 50  # Unused in this version but can be included for threshold filtering
POSITION_FOR_TESTING = 450  # Position to run Kruskal-Wallis test

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
# ("low-quality" here is position-based: the first 33 RBD positions, Spike position <= 364,
# are dropped; no read-count filter is applied in this cell)
df_filtered = df_total[
    (df_total['Spike_AS_Position'] > 33 + 331) &
    (df_total['Type_of_Mutation'] == 'NON-SYNOM')
]

# Group by Position, Immunization, Barcode: sums Enrichment_Ratio over all mutated amino
# acids, giving one value per (position, immunization, droplet barcode)
df_per_pos = df_filtered.groupby(
    ['Spike_AS_Position', 'immunization', 'barcode'],
    as_index=False
)['Enrichment_Ratio'].sum()

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create figure and axis for combined plot
fig, ax = plt.subplots(figsize=(10, 6))

# For statistical test (presumably filled per immunization with values at
# POSITION_FOR_TESTING further down - TODO confirm)
position_test_data = {}

# Loop through immunizations and compute stats
for immunization in df_per_pos['immunization'].unique():
    print(f"Processing: {immunization}")
    
    df_im = df_per_pos[df_per_pos['immunization'] == immunization]

    # Group by Spike position: mean, std, count
    stats_df = df_im.groupby('Spike_AS_Position').agg(
        Mean_Enrichment=('Enrichment_Ratio', 'mean'),
        Std_Enrichment=('Enrichment_Ratio', 'std'),
        Count=('Enrichment_Ratio', 'count')
    ).reset_index()

    # Smooth the curves
    stats_df['Smoothed_Mean'] = stats_df['Mean_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1
    ).mean()

    stats_df['Smoothed_Std'] = stats_df['Std_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1
    ).mean()

    # Interpolate missing values
    stats_df['Smoothed_Mean'] = stats_df['Smoothed_Mean'].interpolate(method='linear', limit_direction='both')
    stats_df['Smoothed_Std'] = stats_df['Smoothed_Std'].interpolate(method='linear', limit_direction='both')

    # Save to CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_stats.csv")
    stats_df.to_csv(csv_file_path, index=False)

    # Plot line + shaded area
    ax.plot(stats_df['Spike_AS_Position'], stats_df['Smoothed_Mean'], label=immunization, linewidth=2)
    ax.fill_between(
        stats_df['Spike_AS_Position'],
        stats_df['Smoothed_Mean'] - stats_df['Smoothed_Std'],
        stats_df['Smoothed_Mean'] + stats_df['Smoothed_Std'],
        alpha=0.3
    )

    # Store data for statistical testing
    if POSITION_FOR_TESTING in stats_df['Spike_AS_Position'].values:
        val = df_im[df_im['Spike_AS_Position'] == POSITION_FOR_TESTING]['Enrichment_Ratio'].values
        if len(val) > 0:
            position_test_data[immunization] = val

# Formatting plot
ax.set_xlim(390, df_per_pos['Spike_AS_Position'].max())
ax.set_ylim(0, 220)
ax.set_xlabel("Spike AA Position")
ax.set_ylabel("Antibody Repertoire \n Binding")
ax.set_title("Smoothed Enrichment Across Immunizations")
ax.legend(title="Immunization")

# Save plot
png_file_path = os.path.join(output_dir, "combined_plot.png")
plt.savefig(png_file_path, dpi=300, bbox_inches="tight")
plt.show()
print(f"Saved combined plot to {png_file_path}")

# Optional: Kruskal-Wallis test at specific position
if len(position_test_data) > 1:
    stat, p = kruskal(*position_test_data.values())
    print(f"Kruskal-Wallis test at position {POSITION_FOR_TESTING}: H={stat:.3f}, p={p:.3e}")
else:
    print(f"Not enough data at position {POSITION_FOR_TESTING} for statistical testing.")


In [None]:
# To compare regions across antibody repertoires, Spike positions can be grouped into fixed-size clusters

In [None]:
from itertools import combinations
from collections import defaultdict

# Width (in Spike positions) of each cluster
CLUSTER_SIZE = 20

# Half-open clusters [start, start + CLUSTER_SIZE) tiling the observed positions.
# The range stop is max_pos + 1 so the maximum position always falls inside a cluster
# (with stop=max_pos it was dropped whenever (max_pos - min_pos) % CLUSTER_SIZE == 0).
min_pos = int(df_per_pos['Spike_AS_Position'].min())
max_pos = int(df_per_pos['Spike_AS_Position'].max())
clusters = [(start, start + CLUSTER_SIZE) for start in range(min_pos, max_pos + 1, CLUSTER_SIZE)]

print("\n--- Pairwise Kruskal-Wallis Tests Across Immunizations in Each Cluster ---\n")

# Prepare output directory for cluster results
cluster_results_dir = os.path.join(output_dir, "cluster_stats")
os.makedirs(cluster_results_dir, exist_ok=True)

# One dict per pairwise test; turned into a DataFrame once at the end
results_list = []

for start, end in clusters:
    cluster_name = f"{start}-{end}"
    cluster_data = df_per_pos[
        (df_per_pos['Spike_AS_Position'] >= start) &
        (df_per_pos['Spike_AS_Position'] < end)
    ]

    if cluster_data.empty:
        continue

    # Per-(position, barcode) enrichment values of each immunization inside this cluster
    cluster_grouped = defaultdict(list)
    for immunization in cluster_data['immunization'].unique():
        values = cluster_data.loc[cluster_data['immunization'] == immunization, 'Enrichment_Ratio'].values
        if len(values) > 0:
            cluster_grouped[immunization] = values

    # Two-group Kruskal-Wallis test for every pair of immunizations
    for (im1, val1), (im2, val2) in combinations(cluster_grouped.items(), 2):
        stat, p_val = kruskal(val1, val2)
        print(f"Cluster {cluster_name} | {im1} vs {im2}: H={stat:.3f}, p={p_val:.3e}")
        results_list.append({
            'Cluster': cluster_name,
            'Immunization_1': im1,
            'Immunization_2': im2,
            'H_statistic': stat,
            'p_value': p_val
        })

# Explicit columns keep the CSV well-formed even if no test was run
results_df = pd.DataFrame(
    results_list,
    columns=['Cluster', 'Immunization_1', 'Immunization_2', 'H_statistic', 'p_value']
)
# Many pairwise tests are run; add a Bonferroni-adjusted p-value (number of tests = rows in this table)
results_df['p_value_bonferroni'] = (results_df['p_value'] * len(results_df)).clip(upper=1.0)

results_csv_path = os.path.join(cluster_results_dir, "pairwise_kruskal_results.csv")
results_df.to_csv(results_csv_path, index=False)
print(f"\nAll pairwise test results saved to: {results_csv_path}")


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import kruskal

# Define constants
ROLLING_WINDOW = 15  # positions per centered rolling window
ENRICHMENT_THRESHOLD = 50  # Unused in this cell; kept for optional threshold filtering
POSITION_FOR_TESTING = 450  # Spike position at which the across-immunization Kruskal-Wallis test is run
CLUSTER_SIZE = 20  # width (in Spike positions) of each cluster compared against the control
CONTROL_IMMUNIZATION = "Library_ctrl"
Y_MAX = 220  # upper y-limit of the figure; cluster annotations are kept below it
ANNOTATION_STEP = 12  # vertical spacing (data units) between stacked annotations in one cluster

# Short names so several annotations fit inside one cluster; unknown names are shown in full
SHORT_LABELS = {
    "Polyclonal_Ab": "P",
    "Neutralizing_Ab": "N",
    "wildtype_RBD": "WT",
    "Mutant_RBD": "MUT",
}

# Keep non-synonymous mutations at Spike positions > 364 (33 + 331)
df_filtered = df_total[
    (df_total['Spike_AS_Position'] > 33 + 331) &
    (df_total['Type_of_Mutation'] == 'NON-SYNOM')
]

# One value per (position, immunization, barcode): enrichment summed over amino acids
df_per_pos = df_filtered.groupby(
    ['Spike_AS_Position', 'immunization', 'barcode'],
    as_index=False
)['Enrichment_Ratio'].sum()

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create figure and axis for combined plot
fig, ax = plt.subplots(figsize=(10, 6))

# Per-barcode values at POSITION_FOR_TESTING, keyed by immunization
position_test_data = {}

for immunization in df_per_pos['immunization'].unique():
    print(f"Processing: {immunization}")

    df_im = df_per_pos[df_per_pos['immunization'] == immunization]

    # Mean / std / count across barcodes at each position (std is NaN where only one barcode exists)
    stats_df = df_im.groupby('Spike_AS_Position').agg(
        Mean_Enrichment=('Enrichment_Ratio', 'mean'),
        Std_Enrichment=('Enrichment_Ratio', 'std'),
        Count=('Enrichment_Ratio', 'count')
    ).reset_index()

    # Centered rolling mean; min_periods=1 keeps the curves defined at the edges
    stats_df['Smoothed_Mean'] = stats_df['Mean_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1
    ).mean()
    stats_df['Smoothed_Std'] = stats_df['Std_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1
    ).mean()

    # Fill any remaining gaps by linear interpolation
    stats_df['Smoothed_Mean'] = stats_df['Smoothed_Mean'].interpolate(method='linear', limit_direction='both')
    stats_df['Smoothed_Std'] = stats_df['Smoothed_Std'].interpolate(method='linear', limit_direction='both')

    # Save to CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_stats.csv")
    stats_df.to_csv(csv_file_path, index=False)

    # Mean line plus a +/- 1 SD band drawn explicitly in the same colour as the line
    (line,) = ax.plot(stats_df['Spike_AS_Position'], stats_df['Smoothed_Mean'], label=immunization, linewidth=2)
    ax.fill_between(
        stats_df['Spike_AS_Position'],
        stats_df['Smoothed_Mean'] - stats_df['Smoothed_Std'],
        stats_df['Smoothed_Mean'] + stats_df['Smoothed_Std'],
        color=line.get_color(),
        alpha=0.3
    )

    # Keep the raw per-barcode values at the test position for the Kruskal-Wallis test below
    values_at_position = df_im.loc[df_im['Spike_AS_Position'] == POSITION_FOR_TESTING, 'Enrichment_Ratio'].values
    if len(values_at_position) > 0:
        position_test_data[immunization] = values_at_position

# --- Cluster-based annotations and significance testing ---

def get_significance_stars(p):
    """Map a p-value to a star annotation.

    Thresholds: "****" p < 1e-5, "***" p < 1e-4, "**" p < 1e-3, "*" p < 0.01, else "n.s.".
    NOTE(review): these cut-offs are one decade stricter than the common
    0.05 / 0.01 / 0.001 / 0.0001 convention — confirm this is intended.
    """
    if p < 1e-5:
        return "****"
    elif p < 1e-4:
        return "***"
    elif p < 1e-3:
        return "**"
    elif p < 0.01:
        return "*"
    else:
        return "n.s."

# Half-open clusters [start, start + CLUSTER_SIZE) starting at position 390. The range stop is
# max_pos + 1 so the maximum position is never dropped when (max_pos - 390) % CLUSTER_SIZE == 0.
min_pos = 390
max_pos = int(df_per_pos['Spike_AS_Position'].max())
clusters = [(start, start + CLUSTER_SIZE) for start in range(min_pos, max_pos + 1, CLUSTER_SIZE)]

# Fix the y-range BEFORE annotating: annotation heights are clamped to it. Previously the clamp
# read the autoscaled limit, which could put labels outside the final (0, Y_MAX) axes.
ax.set_ylim(0, Y_MAX)

for start, end in clusters:
    cluster_mid = (start + end) / 2
    ax.axvline(x=start, color='gray', linestyle='--', linewidth=0.5)

    cluster_data = df_per_pos[
        (df_per_pos['Spike_AS_Position'] >= start) &
        (df_per_pos['Spike_AS_Position'] < end)
    ]
    if cluster_data.empty:
        continue

    # Per-(position, barcode) values of each immunization inside this cluster
    immunization_values = {}
    for immunization in cluster_data['immunization'].unique():
        values = cluster_data.loc[cluster_data['immunization'] == immunization, 'Enrichment_Ratio'].values
        if len(values) > 0:
            immunization_values[immunization] = values

    # Each immunization is only compared against the control, so skip clusters without it
    if CONTROL_IMMUNIZATION not in immunization_values:
        continue
    control_vals = immunization_values[CONTROL_IMMUNIZATION]
    compared = [im for im in immunization_values if im != CONTROL_IMMUNIZATION]

    # Stack one labelled annotation per immunization upward from just above the cluster's highest
    # value (previously all were drawn at the same point and overlapped), shifted down if needed
    # so the whole stack stays inside the axes
    y_base = min(cluster_data['Enrichment_Ratio'].max() + 10, Y_MAX - ANNOTATION_STEP * len(compared))

    for idx, immunization in enumerate(compared):
        stat, p_val = kruskal(control_vals, immunization_values[immunization])
        stars = get_significance_stars(p_val)
        label = SHORT_LABELS.get(immunization, immunization)

        ax.text(
            cluster_mid, y_base + ANNOTATION_STEP * idx, f"{label}: {stars}",
            ha='center', va='bottom', fontsize=8,
            bbox=dict(boxstyle="round,pad=0.2", edgecolor='none', facecolor='white', alpha=0.6)
        )

# --- Final plot formatting and save ---
ax.set_xlim(390, df_per_pos['Spike_AS_Position'].max())
ax.set_xlabel("Spike AA Position")
ax.set_ylabel("AB Repertoire (all droplets) \n")
ax.set_title("Smoothed Enrichment Across Immunizations")
ax.legend(title="Immunization")

# Distinct file name: several cells used "combined_plot.png" and overwrote each other on Run All
png_file_path = os.path.join(output_dir, "combined_plot_vs_control_clusters.png")
plt.savefig(png_file_path, dpi=300, bbox_inches="tight")
plt.show()
print(f"Saved combined plot to {png_file_path}")

# Optional: Kruskal-Wallis test across immunizations at the chosen position
if len(position_test_data) > 1:
    stat, p = kruskal(*position_test_data.values())
    print(f"Kruskal-Wallis test at position {POSITION_FOR_TESTING}: H={stat:.3f}, p={p:.3e}")
else:
    print(f"Not enough data at position {POSITION_FOR_TESTING} for statistical testing.")


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate

# Define rolling window size for smoothing
ROLLING_WINDOW = 15  # positions per centered rolling window
ENRICHMENT_THRESHOLD = 50  # Unused in this cell

# Keep non-synonymous mutations at Spike positions > 364 (33 + 331)
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Sum enrichment per (position, amino acid, barcode, immunization)
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create figure and axis for combined plot
fig, ax = plt.subplots(figsize=(10, 6))

# The library control is not plotted here; selecting the plotted rows up front also gives the x-range
df_plotted = df_filtered_agg[df_filtered_agg['immunization'] != "Library_ctrl"]

for immunization in df_plotted['immunization'].unique():
    print(f"Processing: {immunization}")

    df_filtered_im = df_plotted[df_plotted['immunization'] == immunization]

    # Sum enrichment ratios over amino acids and barcodes at each position
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({'Enrichment_Ratio': 'sum'})

    # Centered rolling mean for smoothing; min_periods=1 keeps the curve defined at the edges
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1
    ).mean()

    # Fill any remaining gaps by linear interpolation
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both'
    )

    # Save CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Ensure data is sorted before plotting
    df_filtered_im = df_filtered_im.sort_values(by='Spike_AS_Position')

    ax.plot(df_filtered_im['Spike_AS_Position'], df_filtered_im['Smoothed_Enrichment'], label=immunization, linewidth=2)

# Upper x-limit from ALL plotted immunizations (previously taken from the last loop iteration only)
ax.set_xlim(390, df_plotted['Spike_AS_Position'].max())
ax.set_ylim(0, 450)
ax.set_xlabel("Spike AA Position")
ax.set_ylabel("AB Repertoire (all droplets) \n ")
ax.set_title("Smoothed Enrichment Across Immunizations")
ax.legend(title="Immunization")

# Distinct file name: several cells used "combined_plot.png" and overwrote each other on Run All
png_file_path = os.path.join(output_dir, "combined_plot_without_library_ctrl.png")
plt.savefig(png_file_path, dpi=300, bbox_inches="tight")
plt.show()

print(f"Saved combined plot to {png_file_path}")


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import kruskal

# Constants
ROLLING_WINDOW = 10  # positions per centered rolling window
ENRICHMENT_THRESHOLD = 1  # Unused in this cell
CONTROL_IMMUNIZATION = "Library_ctrl"
CLUSTER_SIZE = 20  # width (in Spike positions) of each compared cluster
MIN_POS = 390  # first cluster start and left x-limit
ANNOTATION_TOP = 390  # y (data units) of the first annotation in each cluster
ANNOTATION_STEP = 10  # downward spacing between annotations within a cluster

# Label abbreviations (also fixes the order in which groups are annotated/printed)
IMMUNIZATION_LABELS = {
    "Polyclonal_Ab": "P",
    "Neutralizing_Ab": "N",
    "wildtype_RBD": "WT",
    "Mutant_RBD": "MUT"
}

# Color map
color_map = {
    "Polyclonal_Ab": "red",
    "Neutralizing_Ab": "orange",
    "Mutant_RBD": "blue",
    "wildtype_RBD": "green"
}

# Keep non-synonymous mutations at Spike positions > 364
df_filtered = df_total[
    (df_total['Spike_AS_Position'] > 364) &
    (df_total['Type_of_Mutation'] == 'NON-SYNOM')
]
# One value per (position, amino acid, barcode, immunization)
df_filtered_agg = df_filtered.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
)['Enrichment_Ratio'].sum()

# Output dir
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Initialize plot
fig, ax = plt.subplots(figsize=(10, 6))

# Per-position curves (raw sum + smoothed) of every non-control immunization
smoothed_data = {}

for immunization in df_filtered_agg['immunization'].unique():
    if immunization == CONTROL_IMMUNIZATION:
        continue

    print(f"Processing: {immunization}")
    color = color_map.get(immunization, "black")

    df_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization]
    # Sum over amino acids and barcodes -> one value per position
    df_im_grouped = df_im.groupby('Spike_AS_Position', as_index=False)['Enrichment_Ratio'].sum()
    df_im_grouped['Smoothed_Enrichment'] = df_im_grouped['Enrichment_Ratio'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean().interpolate(method='linear', limit_direction='both')
    df_im_grouped = df_im_grouped.sort_values(by='Spike_AS_Position')

    df_im_grouped.to_csv(os.path.join(output_dir, f"{immunization}_data.csv"), index=False)

    ax.plot(df_im_grouped['Spike_AS_Position'], df_im_grouped['Smoothed_Enrichment'],
            label=immunization, color=color, linewidth=2)

    smoothed_data[immunization] = df_im_grouped

# Control aggregated the same way: one summed value per position
df_ctrl = df_filtered_agg[df_filtered_agg['immunization'] == CONTROL_IMMUNIZATION]
df_ctrl_grouped = df_ctrl.groupby('Spike_AS_Position', as_index=False)['Enrichment_Ratio'].sum()

# Half-open clusters [start, start + CLUSTER_SIZE); stop is max_pos + 1 so the maximum position
# is never dropped when (max_pos - MIN_POS) % CLUSTER_SIZE == 0
max_pos = int(df_filtered_agg['Spike_AS_Position'].max())
clusters = [(start, start + CLUSTER_SIZE) for start in range(MIN_POS, max_pos + 1, CLUSTER_SIZE)]

def get_significance_stars(p):
    """Map a p-value to a star annotation.

    Thresholds: "****" p < 1e-5, "***" p < 1e-4, "**" p < 1e-3, "*" p < 0.01, else "n.s.".
    NOTE(review): one decade stricter than the common 0.05 / 0.01 / 0.001 / 0.0001
    convention — confirm this is intended.
    """
    if p < 1e-5: return "****"
    elif p < 1e-4: return "***"
    elif p < 1e-3: return "**"
    elif p < 0.01: return "*"
    else: return "n.s."

def _cluster_values(df_pos, start, end):
    """Return the 'Enrichment_Ratio' values of df_pos for Spike positions in [start, end)."""
    in_cluster = (df_pos['Spike_AS_Position'] >= start) & (df_pos['Spike_AS_Position'] < end)
    return df_pos.loc[in_cluster, 'Enrichment_Ratio'].values

# Statistical output
print("\n=== Statistical Comparison Summary ===")
print("Cluster\tGroup\tMean\tStd\tN\tvs Library Mean\tLibrary Std\tLibrary N\tp-value\tStars")

for start, end in clusters:
    cluster_mid = (start + end) / 2
    ax.axvline(x=start, color='gray', linestyle='--', linewidth=0.5)

    ctrl_vals = _cluster_values(df_ctrl_grouped, start, end)
    if len(ctrl_vals) == 0:
        continue

    annotation_idx = 0
    for immunization, label in IMMUNIZATION_LABELS.items():
        if immunization not in smoothed_data:
            continue

        # Per-position sums, i.e. the same aggregation level as ctrl_vals. The previous version
        # compared raw per-(position, AA, barcode) rows against per-position sums of the control.
        group_vals = _cluster_values(smoothed_data[immunization], start, end)
        if len(group_vals) == 0:
            continue

        # Stats
        stat, p_val = kruskal(ctrl_vals, group_vals)
        stars = get_significance_stars(p_val)
        y_pos = ANNOTATION_TOP - ANNOTATION_STEP * annotation_idx

        # Text annotation
        ax.text(cluster_mid, y_pos, f"{label}: {stars}",
                ha='center', va='bottom', fontsize=8,
                bbox=dict(boxstyle="round,pad=0.2", edgecolor='none', facecolor='white', alpha=0.6))

        # Print comparison info
        print(f"{start}-{end}\t{label}\t"
              f"{np.mean(group_vals):.1f}\t{np.std(group_vals):.1f}\t{len(group_vals)}\t"
              f"{np.mean(ctrl_vals):.1f}\t{np.std(ctrl_vals):.1f}\t{len(ctrl_vals)}\t"
              f"{p_val:.2e}\t{stars}")

        annotation_idx += 1

# Plot format
ax.set_xlim(MIN_POS, max_pos)
ax.set_ylim(0, 600)
ax.set_xlabel("Spike AA Position")
ax.set_ylabel("Er Ratio for (total) Ab repertoire \n Binding fraction")
ax.set_title("Smoothed Enrichment Across Immunizations")

# Legend outside
ax.legend(title="Immunization", loc='upper left', bbox_to_anchor=(1.02, 1), borderaxespad=0.)

# Saved under output_dir with a unique name (the absolute Desktop path was shared with the
# log2 plot cell, which overwrote this figure)
png_file_path = os.path.join(output_dir, "combined_plot_vs_library_ctrl.png")
plt.savefig(png_file_path, dpi=300, bbox_inches="tight")
plt.show()

print(f"\n✅ Saved plot to: {png_file_path}")


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate

# Define rolling window size for smoothing
ROLLING_WINDOW = 10  # positions per centered rolling window
ENRICHMENT_THRESHOLD = 1  # Unused in this cell
LOG2_FLOOR = 1e-5  # smoothed values are clipped to this before log2 to avoid -inf

# Keep non-synonymous mutations at Spike positions > 364 (33 + 331)
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Sum enrichment per (position, amino acid, barcode, immunization)
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create figure and axis for combined plot
fig, ax = plt.subplots(figsize=(10, 6))

# Define color map for immunization types
color_map = {
    "Polyclonal_Ab": "red", 
    "Neutralizing_Ab": "orange", 
    "Mutant_RBD": "blue", 
    "wildtype_RBD": "green"
}

# The library control is not plotted; selecting the plotted rows up front also gives the x-range
df_plotted = df_filtered_agg[df_filtered_agg['immunization'] != "Library_ctrl"]

for immunization in df_plotted['immunization'].unique():
    # Default to black if the immunization type is not in the colour map
    color = color_map.get(immunization, "black")

    print(f"Processing: {immunization}")

    df_filtered_im = df_plotted[df_plotted['immunization'] == immunization]

    # Sum enrichment ratios over amino acids and barcodes at each position
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({'Enrichment_Ratio': 'sum'})

    # Centered rolling mean for smoothing; min_periods=1 keeps the curve defined at the edges
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1
    ).mean()

    # Fill any remaining gaps by linear interpolation
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both'
    )

    # The CSV holds the linear-scale smoothed values; log2 is applied afterwards, for plotting only
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Ensure data is sorted before plotting
    df_filtered_im = df_filtered_im.sort_values(by='Spike_AS_Position')

    # Clip non-positive values so log2 stays finite, then plot the log2-transformed curve (line only)
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].clip(lower=LOG2_FLOOR)
    df_filtered_im['Log2_Smoothed_Enrichment'] = np.log2(df_filtered_im['Smoothed_Enrichment'])
    ax.plot(df_filtered_im['Spike_AS_Position'], df_filtered_im['Log2_Smoothed_Enrichment'], label=immunization, color=color, linewidth=2)

# Upper x-limit from ALL plotted immunizations (previously taken from the last loop iteration only)
ax.set_xlim(390, df_plotted['Spike_AS_Position'].max())
ax.set_ylim(2.4, 9)
ax.set_xlabel("Spike AA Position")
ax.set_ylabel("Antibody Repertoire \n Binding (log2)")
ax.set_title("Smoothed Enrichment Across Immunizations")
ax.legend(title="Immunization")

# Saved under output_dir with a unique name (the absolute Desktop path was shared with the
# previous cell, so one figure overwrote the other)
png_file_path = os.path.join(output_dir, "combined_log2_plot.png")
plt.savefig(png_file_path, dpi=300, bbox_inches="tight")
plt.show()

print(f"Saved combined plot to {png_file_path}")


In [None]:
# Inverting enrichment ratios (escape = 1 / enrichment) can help visualize escape fractions

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate

# Define rolling window size for smoothing
ROLLING_WINDOW = 15  # positions per centered rolling window
ENRICHMENT_THRESHOLD = 50  # Unused in this cell

# Keep non-synonymous mutations at Spike positions > 364 (33 + 331)
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Sum enrichment per (position, amino acid, barcode, immunization)
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Escape = 1 / Enrichment_Ratio (vectorized). Rows with zero enrichment keep the value 0 instead of inf.
# NOTE(review): zero enrichment arguably means maximal escape — confirm that mapping it to 0 is intended.
enrichment = df_filtered_agg['Enrichment_Ratio']
df_filtered_agg['Escape_Ratio'] = (1 / enrichment).where(enrichment != 0, enrichment)

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create figure and axis for combined plot
fig, ax = plt.subplots(figsize=(10, 6))

# Define color map for immunization types
color_map = {
    "Polyclonal_Ab": "red", 
    "Neutralizing_Ab": "orange", 
    "Mutant_RBD": "blue", 
    "wildtype_RBD": "green"
}

# The library control is not plotted; selecting the plotted rows up front also gives the x-range
df_plotted = df_filtered_agg[df_filtered_agg['immunization'] != "Library_ctrl"]

for immunization in df_plotted['immunization'].unique():
    # Default to black if the immunization type is not in the colour map
    color = color_map.get(immunization, "black")

    print(f"Processing: {immunization}")

    df_filtered_im = df_plotted[df_plotted['immunization'] == immunization]

    # MEAN escape per position (the enrichment plots use sums)
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({'Escape_Ratio': 'mean'})

    # Centered rolling mean for smoothing; min_periods=1 keeps the curve defined at the edges
    df_filtered_im['Smoothed_Escape'] = df_filtered_im['Escape_Ratio'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1
    ).mean()

    # Fill any remaining gaps by linear interpolation
    df_filtered_im['Smoothed_Escape'] = df_filtered_im['Smoothed_Escape'].interpolate(
        method='linear', limit_direction='both'
    )

    # Save CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_escape_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Ensure data is sorted before plotting
    df_filtered_im = df_filtered_im.sort_values(by='Spike_AS_Position')

    # Line plus the area under it in the immunization's colour
    ax.plot(df_filtered_im['Spike_AS_Position'], df_filtered_im['Smoothed_Escape'], label=immunization, color=color, linewidth=2)
    ax.fill_between(df_filtered_im['Spike_AS_Position'], 0, df_filtered_im['Smoothed_Escape'], color=color, alpha=0.3)

# Upper x-limit from ALL plotted immunizations (previously taken from the last loop iteration only)
ax.set_xlim(390, df_plotted['Spike_AS_Position'].max())
ax.set_ylim(0, 15)  # Adjust this based on your data range for escape
ax.set_xlabel("Spike AA Position")
ax.set_ylabel("Antibody Escape")
ax.set_title("Smoothed Escape Across Immunizations")
ax.legend(title="Immunization")

# Save the combined plot
png_file_path = os.path.join(output_dir, "combined_escape_plot_with_area_colored.png")
plt.savefig(png_file_path, dpi=300, bbox_inches="tight")
plt.show()

print(f"Saved combined escape plot to {png_file_path}")


In [None]:
# Single cell/antibody plots can also be generated and compared

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate
from matplotlib.ticker import MultipleLocator
%matplotlib inline


# Define rolling window size for smoothing
ROLLING_WINDOW = 10  # Increased window for more smoothing (you can adjust this further)
ENRICHMENT_THRESHOLD = 0.8  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Aggregate enrichment ratio by position and barcode
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Define target sites
sites_to_show = map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface according to article
    list(range(394,414)) +  # R21 peptide sequence with high affinity
    list(range(484, 505))  # R13 peptide sequence with high affinity
)
print(f"Sites to show: {sites_to_show}")
print(df_filtered_agg.columns)  
df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

print(df_filtered_agg[['Spike_AS_Position', 'site_label', 'show_site']].head(20))

for immunization in df_filtered_agg['immunization'].unique():
    print(immunization)
    df_filtered_im = df_filtered_agg.query(f'immunization == "{immunization}"')
    print(df_filtered_im.head())  # Check if data exists before plotting
    print(df_filtered_im.columns)
    
    fig, ax = plt.subplots(figsize=(6, 4))
    print("Unique barcodes:", df_filtered_im['barcode'].unique())  # Check if barcode values are correct


        # Aggregate enrichment ratios at each position (across all barcodes)
    df_grouped = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'sum'
    })

    # Apply rolling mean for smoothing
    df_grouped['Smoothed_Enrichment'] = df_grouped['Enrichment_Ratio'].rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()

    # Handle missing data by interpolating
    df_grouped['Smoothed_Enrichment'] = df_grouped['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

    # Plot the smoothed, summed enrichment
    ax.plot(df_grouped['Spike_AS_Position'], df_grouped['Smoothed_Enrichment'], label=f'{immunization} Total', color='black', lw=2)


# Formatting the shared plot
ax.set_title("Antibody Repertoire Binding Across Immunizations")
ax.set_xlabel("Spike AA Position")
ax.set_ylabel("Antibody Repertoire \n binding")
ax.legend()
ax.xaxis.set_minor_locator(MultipleLocator(5))
ax.set_xlim(350, df_filtered_agg['Spike_AS_Position'].max())
ax.set_ylim(0, 400)

# Save single combined plot
plt.tight_layout()
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop/", "Combined_Immunizations_Plot.png")
fig.savefig(plot_file_path, format='png')
plt.show()

In [None]:
# Per barcode plotting

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

# Enable inline plotting in Jupyter
%matplotlib inline

# Define rolling window size for smoothing
ROLLING_WINDOW = 10  # Adjust for more smoothing
ENRICHMENT_THRESHOLD = 50  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Aggregate enrichment ratio by position and barcode
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Define target sites
sites_to_show = map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394,414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
)
df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory
output_dir = "barcode_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Iterate over each barcode
for barcode in df_filtered_agg['barcode'].unique():
    print(f"Processing barcode: {barcode}")
    df_filtered_bc = df_filtered_agg[df_filtered_agg['barcode'] == barcode]
    
    fig, ax = plt.subplots(figsize=(6, 4))
    
    for immunization in df_filtered_bc['immunization'].unique():
        df_filtered_im = df_filtered_bc[df_filtered_bc['immunization'] == immunization]

        # Aggregate enrichment ratios at each position
        df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
            'Enrichment_Ratio': 'sum',
        })

        # Apply rolling mean for smoothing
        df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()

        # Handle missing data by interpolating missing values
        df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

        # Plot each immunization as a separate line
        ax.plot(df_filtered_im['Spike_AS_Position'], df_filtered_im['Smoothed_Enrichment'], label=f'Immunization {immunization}', alpha=0.7)

    # Formatting plot
    ax.set_title(f"Barcode {barcode}")
    ax.set_xlabel("Spike AA Position")
    ax.set_ylabel("Antibody Repertoire \n Binding")

    ax.xaxis.set_minor_locator(MultipleLocator(5))
    ax.set_xlim(350, df_filtered_bc['Spike_AS_Position'].max())

    # Save plot
    plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop/immunization_csv_files", f"Barcode_{barcode}_plot.png")
    fig.savefig(plot_file_path, format='png')
    
    # Show the plot inline in Jupyter
    plt.show()

    plt.close(fig)


In [None]:
import os
import pandas as pd
import numpy as np
from scipy import interpolate

# Smoothing / thresholding parameters for this cell
ROLLING_WINDOW = 5  # Rolling-mean window in positions; larger = smoother
ENRICHMENT_THRESHOLD = 0.8  # Smoothed enrichment above this is flagged as "high"

# Keep non-synonymous mutations at Spike positions > 364 (33 + 331).
# Read-count filtering is applied upstream, when df_total is built.
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Sum enrichment ratios per (position, amino acid, barcode, immunization)
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Target sites as strings. Materialised as a list: a bare map() is a one-shot
# iterator and silently matches nothing once it has been consumed.
sites_to_show = list(map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
))

df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory (relative to the notebook's working directory)
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

for immunization in df_filtered_agg['immunization'].unique():
    print(immunization)
    # Boolean mask instead of an f-string .query(), which breaks on quotes in the label
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Sum enrichment ratios over amino acids and barcodes at each position
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'sum',
    })

    # Centered rolling mean; min_periods=1 keeps the edges defined
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()

    # Fill any remaining gaps by linear interpolation
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

    # Flag positions whose smoothed enrichment exceeds the threshold
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD

    # Save CSV.
    # NOTE(review): other cells write the same file name into the same directory.
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex onto every integer position so gaps become explicit NaN rows.
    # The upper bound is the max position over ALL immunizations (df_filtered).
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(), df_filtered['Spike_AS_Position'].max() + 1)
    ).reset_index()

    # Rows added by the reindex are NaN; treat them as "not high"
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    # Plot with highlighted clusters
    fig, ax = dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title=f"{immunization} (Smoothed)",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding",
        show_col="High_Enrichment"  # Highlight high enrichment clusters
    )

    # Zoom in on positions >= 420
    ax.set_xlim(420, df_filtered_im['Spike_AS_Position'].max())

    # Fit the y-axis to the data inside the visible x-range
    x_min, x_max = ax.get_xlim()
    filtered_data = df_filtered_im[(df_filtered_im['Spike_AS_Position'] >= x_min) & (df_filtered_im['Spike_AS_Position'] <= x_max)]

    y_min = 0  # Always start y-axis at 0
    y_max = filtered_data['Smoothed_Enrichment'].max()  # NaN if nothing is visible

    # set_ylim raises on NaN limits and warns on identical ones: only apply a valid range
    if pd.notna(y_max) and y_max > y_min:
        ax.set_ylim(y_min, y_max)
        print(f"Setting y-axis limits for visible data: min={y_min}, max={y_max}")
    else:
        print(f"Keeping default y-axis limits (no positive visible data, max={y_max})")

    plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline


# Smoothing / thresholding parameters for this cell
ROLLING_WINDOW = 20  # Rolling-mean window in positions; adjust to refine the smoothing
ENRICHMENT_THRESHOLD = 0.8  # Smoothed enrichment above this is flagged as "high"

# Keep non-synonymous mutations at Spike positions > 364 (33 + 331).
# Read-count filtering is applied upstream, when df_total is built.
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Sum enrichment ratios per (position, amino acid, barcode, immunization)
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Target sites as strings, materialised as a list so they can be reused safely
sites_to_show = list(map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
))

df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory (relative to the notebook's working directory)
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

for immunization in df_filtered_agg['immunization'].unique():
    print(immunization)
    # Boolean mask instead of an f-string .query(), which breaks on quotes in the label
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Sum enrichment ratios over amino acids and barcodes at each position
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'sum',
    })

    # Centered rolling mean; min_periods=1 keeps the edges defined
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()

    # Fill any remaining gaps by linear interpolation
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

    # Flag positions whose smoothed enrichment exceeds the threshold
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD

    # Save CSV.
    # NOTE(review): other cells write the same file name into the same directory.
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex onto every integer position in this immunization's own range,
    # so gaps become explicit NaN rows (done after interpolation on purpose)
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(), df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()

    # Rows added by the reindex are NaN; treat them as "not high"
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    # Plot with highlighted clusters
    fig, ax = dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title=f"{immunization} (Smoothed)",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding",
        show_col="High_Enrichment"  # Highlight high enrichment clusters
    )

    # Zoom in on positions >= 420
    ax.set_xlim(420, df_filtered_im['Spike_AS_Position'].max())

    # Fit the y-axis to the data inside the visible x-range
    x_min, x_max = ax.get_xlim()
    filtered_data = df_filtered_im[(df_filtered_im['Spike_AS_Position'] >= x_min) & (df_filtered_im['Spike_AS_Position'] <= x_max)]

    y_min = 0  # Always start y-axis at 0
    y_max = filtered_data['Smoothed_Enrichment'].max()  # NaN if nothing is visible

    # set_ylim raises on NaN limits and warns on identical ones: only apply a valid range
    if pd.notna(y_max) and y_max > y_min:
        ax.set_ylim(y_min, y_max)
        print(f"Setting y-axis limits for visible data: min={y_min}, max={y_max}")
    else:
        print(f"Keeping default y-axis limits (no positive visible data, max={y_max})")

    plt.show()


In [None]:
# Unseparated data: one plot per immunization with all barcodes overlaid

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

# Enable inline plotting in Jupyter
%matplotlib inline

# Smoothing parameters for this cell
ROLLING_WINDOW = 10  # Rolling-mean window in positions; larger = smoother
ENRICHMENT_THRESHOLD = 50  # NOTE(review): not used in this cell

# Keep non-synonymous mutations at Spike positions > 364 (33 + 331).
# Read-count filtering is applied upstream, when df_total is built.
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Sum enrichment ratios per (position, amino acid, barcode, immunization)
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Target sites as strings, materialised as a list so they can be reused safely
sites_to_show = list(map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
))
df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory (relative to the notebook's working directory)
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Absolute destination of the overlay PNGs. fig.savefig() does not create
# missing directories, so create it up front instead of failing inside the loop.
PLOT_DIR = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(PLOT_DIR, exist_ok=True)

# One figure per immunization condition, one smoothed line per barcode
for immunization in df_filtered_agg['immunization'].unique():
    print(f"Processing immunization: {immunization}")
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    fig, ax = plt.subplots(figsize=(8, 6))

    for barcode in df_filtered_im['barcode'].unique():
        df_filtered_barcode = df_filtered_im[df_filtered_im['barcode'] == barcode]

        # Sum enrichment ratios over amino acids at each position
        df_filtered_barcode = df_filtered_barcode.groupby('Spike_AS_Position', as_index=False).agg({
            'Enrichment_Ratio': 'sum',
        })

        # Centered rolling mean; min_periods=1 keeps the edges defined
        df_filtered_barcode['Smoothed_Enrichment'] = df_filtered_barcode['Enrichment_Ratio'].rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()

        # Fill any remaining gaps by linear interpolation
        df_filtered_barcode['Smoothed_Enrichment'] = df_filtered_barcode['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

        ax.plot(df_filtered_barcode['Spike_AS_Position'], df_filtered_barcode['Smoothed_Enrichment'],
                label=f'Barcode {barcode}', alpha=0.7)

    # Formatting plot
    ax.set_title(f"{immunization}")
    ax.set_xlabel("Spike AA Position")
    ax.set_ylabel("Antibody Repertoire \n Binding")
    ax.xaxis.set_minor_locator(MultipleLocator(5))
    ax.set_xlim(350, df_filtered_im['Spike_AS_Position'].max())

    # set_xlim() switched x-autoscaling off, so this only refits the y-axis
    ax.relim()
    ax.autoscale_view()

    # Save plot
    plot_file_path = os.path.join(PLOT_DIR, f"{immunization}_overlay_plot.png")
    fig.savefig(plot_file_path, format='png')

    # Show the plot inline, then free the figure
    plt.show()
    plt.close(fig)


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

pd.set_option('future.no_silent_downcasting', True)

# Smoothing / thresholding parameters
ROLLING_WINDOW = 20  # Rolling-mean window in positions; adjust to refine the smoothing
ENRICHMENT_THRESHOLD = 1  # Smoothed enrichment above this is flagged as "high"

# Keep non-synonymous mutations at Spike positions > 364 (33 + 331)
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Sum enrichment ratios per (position, amino acid, barcode, immunization)
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Target sites as strings (a list, so it can be reused any number of times)
sites_to_show = list(map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
))

df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory for the per-immunization CSVs
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# All immunizations are drawn onto this single figure
fig, ax = plt.subplots(figsize=(10, 6))

# Line colour per immunization (unknown groups fall back to black)
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'darkblue',
    'Mutant_RBD': '#004c4c'
}

for immunization in df_filtered_agg['immunization'].unique():
    if immunization == 'Library_ctrl':  # library control is not plotted
        continue

    print(immunization)
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Distinct barcodes and amino acids contributing at each position
    df_counts = df_filtered_im.groupby('Spike_AS_Position').agg({
        'barcode': pd.Series.nunique,
        'Amino_Acid': pd.Series.nunique
    }).rename(columns={'barcode': 'num_barcodes', 'Amino_Acid': 'num_amino_acids'})

    # Summed enrichment at each position
    df_sum = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'sum',
    })

    # Merge and normalise.
    # NOTE(review): this divides by the barcode count but *multiplies* by the
    # amino-acid count. If a per-barcode, per-amino-acid mean was intended, the
    # last factor should be a division. Confirm.
    df_filtered_im = df_sum.merge(df_counts, left_on='Spike_AS_Position', right_index=True)
    df_filtered_im['Enrichment_Ratio'] = (
        df_filtered_im['Enrichment_Ratio'] /
        df_filtered_im['num_barcodes']
    ) * df_filtered_im['num_amino_acids']

    # Fill any NaN ratios from neighbouring positions. .bfill()/.ffill() replace
    # fillna(method=...), which is deprecated since pandas 2.1. They are no-ops
    # when nothing is missing.
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # Centered rolling mean; min_periods=1 keeps the edges defined
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()

    # Fill any remaining gaps by linear interpolation
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

    # Flag positions whose smoothed enrichment exceeds the threshold
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD

    if immunization == 'Neutralizing_Ab':
        df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].bfill().ffill()

    # Save CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex onto every integer position so gaps become explicit NaN rows
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(), df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()

    # Rows added by the reindex are NaN; treat them as "not high"
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    # Recompute the target-site flag for the reindexed rows
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    line_color = color_map.get(immunization, 'black')
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",  # figure title is set once after the loop
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding",
        show_col="show_site",  # underline the target sites
        ax=ax,  # every immunization shares the same axes
        linewidth=2.5,
        color=line_color
    )
    # Invisible proxy line so the legend gets one entry per immunization
    ax.plot([], [], color=line_color, label=immunization)

    # Thick black bar at y=0 under each target site
    highlight_sites = df_filtered_im[df_filtered_im['show_site']]
    for _, site_data in highlight_sites.iterrows():
        ax.hlines(
            y=0,
            xmin=site_data['Spike_AS_Position'] - 0.5,
            xmax=site_data['Spike_AS_Position'] + 0.5,
            color='black',
            linestyle='-',
            linewidth=10
        )

# Fixed y-range for this figure
ax.set_ylim(0, 400)

# Axis titles and labels, set once all lines are drawn
ax.set_title('Smoothed Enrichment Ratios for Different Immunizations', fontsize=14)
ax.set_xlabel('Spike AA Position', fontsize=16)
ax.set_ylabel('Antibody Repertoire - Enrichment \n (PerCell Polyreactivity)', fontsize=16)

# Legend in a fixed order (antibody groups, then RBD groups), listing only the
# groups that were actually plotted, so a missing group no longer raises ValueError.
handles, labels = ax.get_legend_handles_labels()
handle_by_label = {}
for handle, label in zip(handles, labels):
    handle_by_label.setdefault(label, handle)  # first handle wins for duplicate labels
group_1_labels = ['Polyclonal_Ab', 'Neutralizing_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']
labels = [label for label in group_1_labels + group_2_labels if label in handle_by_label]
handles = [handle_by_label[label] for label in labels]

ax.legend(handles, labels, title="Single Droplet Repertoire", loc='upper center', fontsize=11, frameon=False, handlelength=2, handleheight=1, title_fontsize=11, markerscale=8)

# 20 evenly spaced x ticks across the position range, labelling every other one
xticks = np.linspace(df_filtered_agg['Spike_AS_Position'].min(), df_filtered_agg['Spike_AS_Position'].max(), 20).astype(int)
ax.set_xticks(xticks)
ax.set_xticklabels([str(x) if i % 2 == 0 else '' for i, x in enumerate(xticks)])

ax.yaxis.set_major_locator(MaxNLocator(integer=True, prune='lower', nbins=8))

# The file is named for the whole figure rather than after the last-iterated
# immunization, since it contains every group.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", "all_immunizations_Line3_plot.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')



In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

pd.set_option('future.no_silent_downcasting', True)

# Smoothing / thresholding parameters
ROLLING_WINDOW = 20  # Rolling-mean window in positions
ENRICHMENT_THRESHOLD = 1  # Smoothed enrichment above this is flagged as "high"

# Keep non-synonymous mutations at Spike positions > 364 (33 + 331)
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Sum enrichment ratios per (position, amino acid, barcode, immunization)
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Target sites as strings: RBD-ACE2 interface, R21 peptide, R13 peptide
sites_to_show = list(map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +
    list(range(394, 414)) +
    list(range(484, 505))
))

df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory for the per-immunization CSVs
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# All immunizations are drawn onto this single figure
fig, ax = plt.subplots(figsize=(10, 6))

# Line colour per immunization (unknown groups fall back to black)
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'darkblue',
    'Mutant_RBD': '#004c4c'
}

for immunization in df_filtered_agg['immunization'].unique():
    if immunization == 'Library_ctrl':  # library control is not plotted
        continue

    print(immunization)
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Distinct barcodes and amino acids contributing at each position
    df_counts = df_filtered_im.groupby('Spike_AS_Position').agg({
        'barcode': pd.Series.nunique,
        'Amino_Acid': pd.Series.nunique
    }).rename(columns={'barcode': 'num_barcodes', 'Amino_Acid': 'num_amino_acids'})

    # Summed enrichment at each position
    df_sum = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'sum',
    })

    # Merge and normalise.
    # NOTE(review): this divides by the barcode count but *multiplies* by the
    # amino-acid count. Confirm that this is the intended normalisation.
    df_filtered_im = df_sum.merge(df_counts, left_on='Spike_AS_Position', right_index=True)
    df_filtered_im['Enrichment_Ratio'] = (
        df_filtered_im['Enrichment_Ratio'] /
        df_filtered_im['num_barcodes']
    ) * df_filtered_im['num_amino_acids']

    # Fill any NaN ratios from neighbours (.bfill()/.ffill() replace the
    # deprecated fillna(method=...))
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # Centered rolling mean, then linear interpolation of any remaining gaps
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD

    if immunization == 'Neutralizing_Ab':
        df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].bfill().ffill()

    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex onto every integer position so gaps become explicit NaN rows
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(), df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()

    # Rows added by the reindex are NaN; treat them as "not high"
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    line_color = color_map.get(immunization, 'black')
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding",
        show_col="show_site",
        ax=ax,
        linewidth=2.5,
        color=line_color
    )
    # Invisible proxy line so the legend gets one entry per immunization
    ax.plot([], [], color=line_color, label=immunization)

    # Thick black bar at y=0 under each target site
    highlight_sites = df_filtered_im[df_filtered_im['show_site']]
    for _, site_data in highlight_sites.iterrows():
        ax.hlines(
            y=0,
            xmin=site_data['Spike_AS_Position'] - 0.5,
            xmax=site_data['Spike_AS_Position'] + 0.5,
            color='black',
            linestyle='-',
            linewidth=10
        )

ax.set_ylim(0, 400)
ax.set_title('Smoothed Enrichment Ratios for Different Immunizations', fontsize=14)
ax.set_xlabel('Spike AA Position', fontsize=16)
ax.set_ylabel('Antibody Repertoire - Enrichment \n (PerCell Polyreactivity)', fontsize=16)

# Legend in a fixed order, listing only the groups that were actually plotted
handles, labels = ax.get_legend_handles_labels()
handle_by_label = {}
for handle, label in zip(handles, labels):
    handle_by_label.setdefault(label, handle)  # first handle wins for duplicate labels
group_1_labels = ['Polyclonal_Ab', 'Neutralizing_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']
labels = [label for label in group_1_labels + group_2_labels if label in handle_by_label]
handles = [handle_by_label[label] for label in labels]

ax.legend(handles, labels, title="Single Droplet Repertoire", loc='upper center', fontsize=11, frameon=False, handlelength=2, handleheight=1, title_fontsize=11, markerscale=8)

# Full x-axis labels shown
xticks = np.linspace(df_filtered_agg['Spike_AS_Position'].min(), df_filtered_agg['Spike_AS_Position'].max(), 20).astype(int)
ax.set_xticks(xticks)
ax.set_xticklabels([str(x) for x in xticks])

ax.yaxis.set_major_locator(MaxNLocator(integer=True, prune='lower', nbins=8))

# The file is named for the whole figure, not after the last-iterated
# immunization. It gets its own name so other figures are not overwritten.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", "all_immunizations_Line3_plot_full_xticks.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')


In [None]:
# Range of the raw, per-row Enrichment_Ratio values that survived filtering
raw_enrichment = df_filtered['Enrichment_Ratio']
print("Raw filtered enrichment min/max:", raw_enrichment.min(), raw_enrichment.max())

# Range after summing every df_filtered_agg row that shares a Spike position
agg = df_filtered_agg.groupby('Spike_AS_Position')['Enrichment_Ratio'].sum()
print("Aggregated enrichment min/max:", agg.min(), agg.max())


### Binding vs. escape fractions

To analyse the binding fraction (RBD variants, e.g. Y501, that increase in frequency after antibody binding) separately from the escape fraction (RBD variants that disappear from the pool because they escape binding), split the enrichment-ratio (ER) values at ER = 1 and plot the two groups separately.

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

pd.set_option('future.no_silent_downcasting', True)

# Parameters
ROLLING_WINDOW = 15
ENRICHMENT_THRESHOLD = 0  # NOTE(review): not used in this cell

# Keep non-synonymous mutations at Spike positions > 364 (33 + 331)
df_filtered = df_total[
    (df_total['Spike_AS_Position'] > 33 + 331) &
    (df_total['Type_of_Mutation'] == 'NON-SYNOM')
]

# Sites to highlight (strings, compared against stringified positions)
sites_to_show = list(map(str,
    [455,456,472,473,484,485,486,490,496,499] +
    list(range(394,414)) + list(range(484,505))
))
df_filtered = df_filtered.assign(
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Largest per-position median of the binding values (ER > 1) across groups.
# NOTE(review): global_binding_max is not used below; binding is scaled by a
# fixed log10(2000) instead. Confirm which normalisation is intended.
binding_meds = []
for imm in df_filtered['immunization'].unique():
    if imm == 'Library_ctrl':
        continue
    sub = df_filtered[df_filtered['immunization'] == imm]
    med_series = sub[sub['Enrichment_Ratio'] > 1] \
                   .groupby('Spike_AS_Position')['Enrichment_Ratio'] \
                   .median()
    if not med_series.empty:
        binding_meds.append(med_series.max())
global_binding_max = max(binding_meds) if binding_meds else 2.0


def transform_val(y):
    """Map a (smoothed) enrichment ratio onto the plot's y scale.

    y < 1   -> y - 1, so escape values in [0, 1) land in [-1, 0)
    y == 1  -> 0
    y > 1   -> log10(y) / log10(2000), so binding reaches 1 at y == 2000
    """
    if y < 1:
        return -(1 - y)
    elif y == 1:
        return 0
    else:
        return np.log10(y) / np.log10(2000)


# Set up plot
fig, ax = plt.subplots(figsize=(10, 6))
color_map = {
    'Polyclonal_Ab':'darkorange',
    'Neutralizing_Ab':'red',
    'wildtype_RBD':'darkblue',
    'Mutant_RBD':'#004c4c'
}

for immunization in df_filtered['immunization'].unique():
    if immunization == 'Library_ctrl':
        continue

    sub = df_filtered[df_filtered['immunization'] == immunization]

    # 1) median per position, split at ER = 1 into escape (<= 1) and binding (> 1)
    esc = sub[sub['Enrichment_Ratio'] <= 1] \
        .groupby('Spike_AS_Position')['Enrichment_Ratio'] \
        .median()
    bind = sub[sub['Enrichment_Ratio'] > 1] \
        .groupby('Spike_AS_Position')['Enrichment_Ratio'] \
        .median()

    if esc.empty and bind.empty:
        continue

    # 2) smooth with a centered rolling mean
    esc = esc.rolling(ROLLING_WINDOW, center=True, min_periods=1).mean()
    bind = bind.rolling(ROLLING_WINDOW, center=True, min_periods=1).mean()

    # 3) transform onto the shared y scale
    esc_t  = esc.apply(transform_val)
    bind_t = bind.apply(transform_val)

    # 4) determine x range
    positions = sorted(set(esc_t.index).union(bind_t.index))
    pos_min, pos_max = positions[0], positions[-1]
    idx = np.arange(pos_min, pos_max + 1)

    # 5) assemble df_plot. Positions without data are filled with 0.
    # NOTE(review): these zeros are plotted as points on the y=0 line.
    df_plot = pd.DataFrame({'Spike_AS_Position': idx}).set_index('Spike_AS_Position')
    df_plot['Escape']  = esc_t.reindex(idx, fill_value=0)
    df_plot['Binding'] = bind_t.reindex(idx, fill_value=0)
    df_plot = df_plot.reset_index()

    # 6) save CSV
    df_plot.to_csv(os.path.join(output_dir, f"{immunization}_median_plot_data.csv"),
                   index=False)

    # 7) plot as dots. An explicit fallback colour keeps the escape and binding
    # points of one group the same colour; with None, matplotlib would take two
    # different colours from its cycle.
    point_color = color_map.get(immunization, 'black')
    ax.scatter(df_plot['Spike_AS_Position'], df_plot['Escape'],
               color=point_color, s=10, label=f"{immunization} escape", alpha=0.8)

    ax.scatter(df_plot['Spike_AS_Position'], df_plot['Binding'],
               color=point_color, s=10, label=f"{immunization} binding", marker='x', alpha=0.8)

    # 8) thick black bar at y=-1 under each target site
    for pos in df_plot.loc[
        df_plot['Spike_AS_Position'].astype(str).isin(sites_to_show),
        'Spike_AS_Position'
    ]:
        ax.hlines(y=-1, xmin=pos-0.5, xmax=pos+0.5,
                  color='black', linewidth=10)

# Final formatting
ax.set_ylim(-1, 0.5)
ax.axhline(0, color='gray', linestyle='--', linewidth=1)
ax.set_title('', fontsize=14)
ax.set_xlabel('Spike AA Position', fontsize=16)
ax.set_ylabel('Antibody Binding', fontsize=16)

handles, labels = ax.get_legend_handles_labels()
unique = dict(zip(labels, handles))  # deduplicate
ax.legend(unique.values(), unique.keys(),
          loc='center left', bbox_to_anchor=(1.01, 0.5),
          frameon=False, fontsize=10)

# X-ticks: 20 evenly spaced, all labelled
xt = np.linspace(df_filtered['Spike_AS_Position'].min(),
                 df_filtered['Spike_AS_Position'].max(), 20).astype(int)
ax.set_xticks(xt)
ax.set_xticklabels([str(x) for x in xt])

ax.yaxis.set_major_locator(MaxNLocator(nbins=9, prune='both'))
fig.tight_layout()
fig.savefig(os.path.join(r"/Users/lucaschlotheuber/Desktop", "escape_binding_plot.png"),
            dpi=300)
plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

pd.set_option('future.no_silent_downcasting', True)

# Plot parameters (read as globals further down this cell)
ENRICHMENT_THRESHOLD = 0
ROLLING_WINDOW = 15

# Row selection: Spike positions beyond 33 + 331, non-synonymous mutations only
in_window = df_total['Spike_AS_Position'] > 33 + 331
is_nonsyn = df_total['Type_of_Mutation'] == 'NON-SYNOM'
df_filtered = df_total[in_window & is_nonsyn]

# Sites to highlight, as strings so they compare against stringified positions
ace2_interface_sites = [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]
r21_peptide_sites = list(range(394, 414))
r13_peptide_sites = list(range(484, 505))
sites_to_show = [str(site) for site in ace2_interface_sites + r21_peptide_sites + r13_peptide_sites]

position_strings = df_filtered["Spike_AS_Position"].astype(str)
df_filtered = df_filtered.assign(show_site=position_strings.isin(sites_to_show))

# Directory for the CSV exports, relative to the working directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Define plotting function
def plot_by_subset(sub_df, title, filename):
    # Precompute global max of the POSITION‐MEDIAN binding values (>1)
    binding_meds = []
    for imm in sub_df['immunization'].unique():
        if imm == 'Library_ctrl': continue
        sub = sub_df[sub_df['immunization']==imm]
        med_series = sub[sub['Enrichment_Ratio']>1] \
                       .groupby('Spike_AS_Position')['Enrichment_Ratio'] \
                       .median()
        if not med_series.empty:
            binding_meds.append(med_series.max())
    global_binding_max = max(binding_meds) if binding_meds else 2.0

    # Set up plot
    fig, ax = plt.subplots(figsize=(10, 6))
    color_map = {
        'Polyclonal_Ab':'darkorange',
        'Neutralizing_Ab':'red',
        'wildtype_RBD':'darkblue',
        'Mutant_RBD':'#004c4c'
    }

    for immunization in sub_df['immunization'].unique():
        if immunization == 'Library_ctrl':
            continue

        sub = sub_df[sub_df['immunization']==immunization]

        # 1) median per position
        esc = sub[sub['Enrichment_Ratio']<=1] \
            .groupby('Spike_AS_Position')['Enrichment_Ratio'] \
            .median()
        bind = sub[sub['Enrichment_Ratio']>1] \
            .groupby('Spike_AS_Position')['Enrichment_Ratio'] \
            .median()

        if esc.empty and bind.empty:
            continue

        # 2) smooth
        esc = esc.rolling(ROLLING_WINDOW, center=True, min_periods=1).mean()
        bind = bind.rolling(ROLLING_WINDOW, center=True, min_periods=1).mean()

        # 3) transform
        def transform_val(y):
            if y < 1:
                return -(1 - y)
            elif y == 1:
                return 0
            else:
                return np.log10(y) / np.log10(2000)

        esc_t  = esc.apply(transform_val)
        bind_t = bind.apply(transform_val)

        # 4) determine x range
        positions = sorted(set(esc_t.index).union(bind_t.index))
        pos_min, pos_max = positions[0], positions[-1]
        idx = np.arange(pos_min, pos_max + 1)

        # 5) assemble df_plot
        df_plot = pd.DataFrame({'Spike_AS_Position': idx}).set_index('Spike_AS_Position')
        df_plot['Escape']  = esc_t.reindex(idx, fill_value=0)
        df_plot['Binding'] = bind_t.reindex(idx, fill_value=0)
        df_plot = df_plot.reset_index()

        # 6) save CSV
        df_plot.to_csv(os.path.join(output_dir, f"{immunization}_{filename}_median_plot_data.csv"),
                       index=False)

        # 7) plot
        ax.scatter(df_plot['Spike_AS_Position'], df_plot['Escape'],
                   color=color_map.get(immunization), s=10, label=f"{immunization} escape", alpha=0.8)
        ax.scatter(df_plot['Spike_AS_Position'], df_plot['Binding'],
                   color=color_map.get(immunization), s=10, label=f"{immunization} binding", marker='x', alpha=0.8)

        # 8) highlight
        for pos in df_plot.loc[
            df_plot['Spike_AS_Position'].astype(str).isin(sites_to_show),
            'Spike_AS_Position'
        ]:
            ax.hlines(y=-1, xmin=pos-0.5, xmax=pos+0.5, color='black', linewidth=10)

    # Final formatting
    ax.set_ylim(0, 0.5)
    ax.axhline(0, color='gray', linestyle='--', linewidth=1)
    plt.title(title, fontsize=14)
    ax.set_xlabel('Spike AA Position', fontsize=16)
    ax.set_ylabel('Median mAB binding ratio', fontsize=16)

    handles, labels = ax.get_legend_handles_labels()
    unique = dict(zip(labels, handles))
    ax.legend(unique.values(), unique.keys(),
              loc='center left', bbox_to_anchor=(1.01, 0.5),
              frameon=False, fontsize=10)

    xt = np.linspace(sub_df['Spike_AS_Position'].min(),
                     sub_df['Spike_AS_Position'].max(), 20).astype(int)
    ax.set_xticks(xt)
    ax.set_xticklabels([str(x) for x in xt])
    ax.yaxis.set_major_locator(MaxNLocator(nbins=9, prune='both'))
    fig.tight_layout()
    fig.savefig(os.path.join(r"/Users/lucaschlotheuber/Desktop", filename + ".png"), dpi=300)
    plt.show()


# --- Call plot function twice: one figure per comparison group ---
# file stem -> (figure title, immunizations included)
comparison_groups = {
    "escape_binding_polyclonal_neutralizing": ("Calibration", ['Polyclonal_Ab', 'Neutralizing_Ab']),
    "escape_binding_wildtype_mutant": ("Immunization", ['wildtype_RBD', 'Mutant_RBD']),
}
for stem, (plot_title, members) in comparison_groups.items():
    plot_by_subset(
        df_filtered[df_filtered['immunization'].isin(members)],
        title=plot_title,
        filename=stem,
    )


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

# Opt in to pandas' future "no silent downcasting" behaviour.
pd.set_option('future.no_silent_downcasting', True)

# --- Parameters ---
ROLLING_WINDOW = 15       # centred rolling-mean window (number of consecutive per-position values)
ENRICHMENT_THRESHOLD = 0  # not referenced in this cell

# --- Keep non-synonymous variants beyond Spike position 364 (= 33 + 331) ---
df_filtered = df_total.loc[
    lambda d: (d['Spike_AS_Position'] > 33 + 331)
    & (d['Type_of_Mutation'] == 'NON-SYNOM')
]

# --- Sites marked in the plots; stored as strings and matched via astype(str) ---
sites_to_show = [
    str(site)
    for site in [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]
    + list(range(394, 414))
    + list(range(484, 505))
]
df_filtered = df_filtered.assign(
    show_site=df_filtered['Spike_AS_Position'].astype(str).isin(sites_to_show)
)

# --- Destination of the per-immunization CSV exports ---
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Define plotting function
def plot_by_subset(sub_df, title, filename):
    # Precompute global max of the POSITION‐MEDIAN binding values (>1)
    binding_meds = []
    for imm in sub_df['immunization'].unique():
        if imm == 'Library_ctrl': continue
        sub = sub_df[sub_df['immunization']==imm]
        med_series = sub[sub['Enrichment_Ratio']>1] \
                       .groupby('Spike_AS_Position')['Enrichment_Ratio'] \
                       .median()
        if not med_series.empty:
            binding_meds.append(med_series.max())
    global_binding_max = max(binding_meds) if binding_meds else 2.0

    # Set up plot
    fig, ax = plt.subplots(figsize=(10, 6))
    color_map = {
        'Polyclonal_Ab':'darkorange',
        'Neutralizing_Ab':'red',
        'wildtype_RBD':'darkblue',
        'Mutant_RBD':'#004c4c'
    }

    for immunization in sub_df['immunization'].unique():
        if immunization == 'Library_ctrl':
            continue

        sub = sub_df[sub_df['immunization']==immunization]

        # 1) median per position
        esc = sub[sub['Enrichment_Ratio']<=1] \
            .groupby('Spike_AS_Position')['Enrichment_Ratio'] \
            .median()
        bind = sub[sub['Enrichment_Ratio']>1] \
            .groupby('Spike_AS_Position')['Enrichment_Ratio'] \
            .median()

        if esc.empty and bind.empty:
            continue

        # 2) smooth
        esc = esc.rolling(ROLLING_WINDOW, center=True, min_periods=1).mean()
        bind = bind.rolling(ROLLING_WINDOW, center=True, min_periods=1).mean()

        # 3) transform
        def transform_val(y):
            if y < 1:
                return -(1 - y)
            elif y == 1:
                return 0
            else:
                return np.log10(y) / np.log10(2000)

        esc_t  = esc.apply(transform_val)
        bind_t = bind.apply(transform_val)

        # ——— INSERTION: multiply by count of variants at each position ———
        counts = sub.groupby('Spike_AS_Position')['Enrichment_Ratio'].count()
        esc_t   = esc_t * counts.reindex(esc_t.index, fill_value=0)
        bind_t  = bind_t * counts.reindex(bind_t.index, fill_value=0)
        # ————————————————————————————————————————————————————————————————

        # 4) determine x range
        positions = sorted(set(esc_t.index).union(bind_t.index))
        pos_min, pos_max = positions[0], positions[-1]
        idx = np.arange(pos_min, pos_max + 1)

        # 5) assemble df_plot
        df_plot = pd.DataFrame({'Spike_AS_Position': idx}).set_index('Spike_AS_Position')
        df_plot['Escape']  = esc_t.reindex(idx, fill_value=0)
        df_plot['Binding'] = bind_t.reindex(idx, fill_value=0)
        df_plot = df_plot.reset_index()

        # 6) save CSV
        df_plot.to_csv(os.path.join(output_dir, f"{immunization}_{filename}_median_plot_data.csv"),
                       index=False)

        # 7) plot
        ax.scatter(df_plot['Spike_AS_Position'], df_plot['Escape'],
                   color=color_map.get(immunization), s=10, label=f"{immunization} escape", alpha=0.8)
        ax.scatter(df_plot['Spike_AS_Position'], df_plot['Binding'],
                   color=color_map.get(immunization), s=10, label=f"{immunization} binding", marker='x', alpha=0.8)

        # 8) highlight
        for pos in df_plot.loc[
            df_plot['Spike_AS_Position'].astype(str).isin(sites_to_show),
            'Spike_AS_Position'
        ]:
            ax.hlines(y=-1, xmin=pos-0.5, xmax=pos+0.5, color='black', linewidth=10)

    # Final formatting
    ax.set_ylim(-90, 25)
    ax.axhline(0, color='gray', linestyle='--', linewidth=1)
    plt.title(title, fontsize=14)
    ax.set_xlabel('Spike AA Position', fontsize=16)
    ax.set_ylabel('Median mAB binding ratio', fontsize=16)

    handles, labels = ax.get_legend_handles_labels()
    unique = dict(zip(labels, handles))
    ax.legend(unique.values(), unique.keys(),
              loc='center left', bbox_to_anchor=(1.01, 0.5),
              frameon=False, fontsize=10)

    xt = np.linspace(sub_df['Spike_AS_Position'].min(),
                     sub_df['Spike_AS_Position'].max(), 20).astype(int)
    ax.set_xticks(xt)
    ax.set_xticklabels([str(x) for x in xt])
    ax.yaxis.set_major_locator(MaxNLocator(nbins=9, prune='both'))
    fig.tight_layout()
    fig.savefig(os.path.join(r"/Users/lucaschlotheuber/Desktop", filename + ".png"), dpi=300)
    plt.show()


# --- Call plot function twice ---
# The file stems differ from the median escape/binding cell above. Otherwise this
# cell's PNGs and per-immunization CSVs would overwrite that cell's outputs.

# 1) Polyclonal vs Neutralizing
plot_by_subset(
    df_filtered[df_filtered['immunization'].isin(['Polyclonal_Ab', 'Neutralizing_Ab'])],
    title="Calibration",
    filename="escape_binding_count_weighted_polyclonal_neutralizing"
)

# 2) Wildtype vs Mutant
plot_by_subset(
    df_filtered[df_filtered['immunization'].isin(['wildtype_RBD', 'Mutant_RBD'])],
    title="Immunization",
    filename="escape_binding_count_weighted_wildtype_mutant"
)


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

# Opt in to pandas' future "no silent downcasting" behaviour.
pd.set_option('future.no_silent_downcasting', True)

# --- Parameters ---
ROLLING_WINDOW = 15       # centred rolling-mean window (number of consecutive per-position values)
ENRICHMENT_THRESHOLD = 0  # not referenced in this cell

# --- Keep non-synonymous variants beyond Spike position 364 (= 33 + 331) ---
df_filtered = df_total.loc[
    lambda d: (d['Spike_AS_Position'] > 33 + 331)
    & (d['Type_of_Mutation'] == 'NON-SYNOM')
]

# --- Sites marked in the plots; stored as strings and matched via astype(str) ---
sites_to_show = [
    str(site)
    for site in [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]
    + list(range(394, 414))
    + list(range(484, 505))
]
df_filtered = df_filtered.assign(
    show_site=df_filtered['Spike_AS_Position'].astype(str).isin(sites_to_show)
)

# --- Destination of the per-immunization CSV exports ---
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Define plotting function
def plot_by_subset(sub_df, title, filename):
    # Precompute global max of the POSITION‐MEDIAN binding values (>1)
    binding_meds = []
    for imm in sub_df['immunization'].unique():
        if imm == 'Library_ctrl': continue
        sub = sub_df[sub_df['immunization']==imm]
        med_series = sub[sub['Enrichment_Ratio']>1] \
                       .groupby('Spike_AS_Position')['Enrichment_Ratio'] \
                       .median()
        if not med_series.empty:
            binding_meds.append(med_series.max())
    global_binding_max = max(binding_meds) if binding_meds else 2.0

    # Set up plot
    fig, ax = plt.subplots(figsize=(10, 6))
    color_map = {
        'Polyclonal_Ab':'darkorange',
        'Neutralizing_Ab':'red',
        'wildtype_RBD':'darkblue',
        'Mutant_RBD':'#004c4c'
    }

    for immunization in sub_df['immunization'].unique():
        if immunization == 'Library_ctrl':
            continue

        sub = sub_df[sub_df['immunization']==immunization]

        # 1) median per position
        esc = sub[sub['Enrichment_Ratio']<=1] \
            .groupby('Spike_AS_Position')['Enrichment_Ratio'] \
            .median()
        bind = sub[sub['Enrichment_Ratio']>1] \
            .groupby('Spike_AS_Position')['Enrichment_Ratio'] \
            .median()

        if esc.empty and bind.empty:
            continue

        # 2) smooth
        esc = esc.rolling(ROLLING_WINDOW, center=True, min_periods=1).mean()
        bind = bind.rolling(ROLLING_WINDOW, center=True, min_periods=1).mean()

        # 3) transform
        def transform_val(y):
            if y < 1:
                return -(1 - y)
            elif y == 1:
                return 0
            else:
                return np.log10(y) / np.log10(2000)

        esc_t  = esc.apply(transform_val)
        bind_t = bind.apply(transform_val)

        # ——— INSERTION: multiply by count of variants at each position ———
        counts = sub.groupby('Spike_AS_Position')['Enrichment_Ratio'].count()
        esc_t   = esc_t * counts.reindex(esc_t.index, fill_value=0)
        bind_t  = bind_t * counts.reindex(bind_t.index, fill_value=0)
        # ————————————————————————————————————————————————————————————————

        # 4) determine x range
        positions = sorted(set(esc_t.index).union(bind_t.index))
        pos_min, pos_max = positions[0], positions[-1]
        idx = np.arange(pos_min, pos_max + 1)

        # 5) assemble df_plot
        df_plot = pd.DataFrame({'Spike_AS_Position': idx}).set_index('Spike_AS_Position')
        df_plot['Escape']  = esc_t.reindex(idx, fill_value=0)
        df_plot['Binding'] = bind_t.reindex(idx, fill_value=0)
        df_plot = df_plot.reset_index()

        # 6) save CSV
        df_plot.to_csv(os.path.join(output_dir, f"{immunization}_{filename}_median_plot_data.csv"),
                       index=False)

        # 7) plot
        ax.scatter(df_plot['Spike_AS_Position'], df_plot['Escape'],
                   color=color_map.get(immunization), s=10, label=f"{immunization} escape", alpha=0.8)
        ax.scatter(df_plot['Spike_AS_Position'], df_plot['Binding'],
                   color=color_map.get(immunization), s=10, label=f"{immunization} binding", marker='x', alpha=0.8)

        # 8) highlight
        for pos in df_plot.loc[
            df_plot['Spike_AS_Position'].astype(str).isin(sites_to_show),
            'Spike_AS_Position'
        ]:
            ax.hlines(y=-1, xmin=pos-0.5, xmax=pos+0.5, color='black', linewidth=10)

    # Final formatting
    # Final formatting - Dynamic y-limits
    all_y_vals = pd.concat([df_plot['Escape'], df_plot['Binding']])
    y_min, y_max = all_y_vals.min(), all_y_vals.max()
    y_range = y_max - y_min
    buffer = y_range * 0.1 if y_range > 0 else 1.0
    ax.set_ylim(y_min - buffer, y_max + buffer)

    ax.axhline(0, color='gray', linestyle='--', linewidth=1)
    plt.title(title, fontsize=14)
    ax.set_xlabel('Spike AA Position', fontsize=16)
    ax.set_ylabel('Median mAB binding ratio', fontsize=16)

    handles, labels = ax.get_legend_handles_labels()
    unique = dict(zip(labels, handles))
    ax.legend(unique.values(), unique.keys(),
              loc='center left', bbox_to_anchor=(1.01, 0.5),
              frameon=False, fontsize=10)

    xt = np.linspace(sub_df['Spike_AS_Position'].min(),
                     sub_df['Spike_AS_Position'].max(), 20).astype(int)
    ax.set_xticks(xt)
    ax.set_xticklabels([str(x) for x in xt])
    ax.yaxis.set_major_locator(MaxNLocator(nbins=9, prune='both'))
    fig.tight_layout()
    fig.savefig(os.path.join(r"/Users/lucaschlotheuber/Desktop", filename + ".png"), dpi=300)
    plt.show()


# --- One escape/binding figure per barcode ---
for barcode in df_filtered['barcode'].unique():
    sub_df = df_filtered.loc[df_filtered['barcode'].eq(barcode)]
    plot_by_subset(sub_df,
                   title=f"Escape and Binding - Barcode {barcode}",
                   filename=f"escape_binding_barcode_{barcode}")


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

# Opt in to pandas' future "no silent downcasting" behaviour.
pd.set_option('future.no_silent_downcasting', True)

# --- Parameters ---
ROLLING_WINDOW = 15       # centred rolling-mean window (number of consecutive per-position values)
ENRICHMENT_THRESHOLD = 0  # not referenced in this cell

# --- Keep non-synonymous variants beyond Spike position 364 (= 33 + 331) ---
df_filtered = df_total.loc[
    lambda d: (d['Spike_AS_Position'] > 33 + 331)
    & (d['Type_of_Mutation'] == 'NON-SYNOM')
]

# --- Sites marked in the plots; stored as strings and matched via astype(str) ---
sites_to_show = [
    str(site)
    for site in [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]
    + list(range(394, 414))
    + list(range(484, 505))
]
df_filtered = df_filtered.assign(
    show_site=df_filtered['Spike_AS_Position'].astype(str).isin(sites_to_show)
)

# --- Destination of the per-immunization CSV exports ---
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Define plotting function
def plot_by_subset(sub_df, title, filename):
    """Scatter-plot a smoothed per-position polyreactivity score per immunization.

    The score at a Spike position is the median Enrichment_Ratio of its
    non-synonymous variants multiplied by how many such variants were measured.
    The product is smoothed with a centred rolling mean over ``ROLLING_WINDOW``
    values. Each immunization except ``'Library_ctrl'`` gets its own series; its
    data is written to ``output_dir`` as CSV, and positions listed in
    ``sites_to_show`` are marked with a black bar on y=0. The figure is saved to
    the Desktop as ``<filename>.png``.
    """
    palette = {
        'Polyclonal_Ab': 'darkorange',
        'Neutralizing_Ab': 'red',
        'wildtype_RBD': 'darkblue',
        'Mutant_RBD': '#004c4c',
    }
    fig, ax = plt.subplots(figsize=(10, 6))

    def _score(rows):
        # median x count per position, then smoothed along the position axis
        per_position = rows.groupby('Spike_AS_Position')['Enrichment_Ratio']
        raw = per_position.median() * per_position.count()
        return raw.rolling(ROLLING_WINDOW, center=True, min_periods=1).mean()

    for immunization in sub_df['immunization'].unique():
        if immunization == 'Library_ctrl':
            continue
        imm_rows = sub_df[sub_df['immunization'] == immunization]
        ns_rows = imm_rows[imm_rows['Type_of_Mutation'] == 'NON-SYNOM']
        if ns_rows.empty:
            continue

        score = _score(ns_rows)
        # Contiguous positions; those without data are filled with 0.
        span = np.arange(score.index.min(), score.index.max() + 1)
        df_plot = pd.DataFrame({
            'Spike_AS_Position': span,
            'Polyreactivity': score.reindex(span, fill_value=0).to_numpy(),
        })

        df_plot.to_csv(os.path.join(output_dir, f"{immunization}_{filename}_median_plot_data.csv"),
                       index=False)

        ax.scatter(df_plot['Spike_AS_Position'], df_plot['Polyreactivity'],
                   color=palette.get(immunization), s=10,
                   label=f"{immunization} polyreactivity", alpha=0.8)

        flagged = df_plot['Spike_AS_Position'].astype(str).isin(sites_to_show)
        for pos in df_plot.loc[flagged, 'Spike_AS_Position']:
            ax.hlines(y=0, xmin=pos - 0.5, xmax=pos + 0.5, color='black', linewidth=10)

    # Axis decoration
    ax.axhline(0, color='gray', linestyle='--', linewidth=1)
    plt.title(title, fontsize=14)
    ax.set_xlabel('Spike AA Position', fontsize=16)
    ax.set_ylabel('Polyreactivity Score\n(median × count)', fontsize=16)

    # Legend: one entry per label (last handle wins), outside the axes on the right
    handles, labels = ax.get_legend_handles_labels()
    dedup = {lab: handle for handle, lab in zip(handles, labels)}
    ax.legend(dedup.values(), dedup.keys(),
              loc='center left', bbox_to_anchor=(1.01, 0.5),
              frameon=False, fontsize=10)

    first_pos = sub_df['Spike_AS_Position'].min()
    last_pos = sub_df['Spike_AS_Position'].max()
    ticks = np.linspace(first_pos, last_pos, 20).astype(int)
    ax.set_xticks(ticks)
    ax.set_xticklabels([str(t) for t in ticks])
    ax.yaxis.set_major_locator(MaxNLocator(nbins=9, prune='both'))
    fig.tight_layout()
    fig.savefig(os.path.join(r"/Users/lucaschlotheuber/Desktop", filename + ".png"), dpi=300)
    plt.show()


# --- Call plot function twice ---
# "polyreactivity_" file stems keep this cell's PNGs and per-immunization CSVs
# from overwriting the "escape_binding_*" outputs of the earlier cells.

# 1) Polyclonal vs Neutralizing
plot_by_subset(
    df_filtered[df_filtered['immunization'].isin(['Polyclonal_Ab', 'Neutralizing_Ab'])],
    title="Calibration",
    filename="polyreactivity_polyclonal_neutralizing"
)

# 2) Wildtype vs Mutant
plot_by_subset(
    df_filtered[df_filtered['immunization'].isin(['wildtype_RBD', 'Mutant_RBD'])],
    title="Immunization",
    filename="polyreactivity_wildtype_mutant"
)


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

# Opt in to pandas' future "no silent downcasting" behaviour.
pd.set_option('future.no_silent_downcasting', True)

# --- Parameters ---
ROLLING_WINDOW = 15       # centred rolling-mean window (number of consecutive per-position values)
ENRICHMENT_THRESHOLD = 0  # not referenced in this cell

# --- Keep non-synonymous variants beyond Spike position 364 (= 33 + 331) ---
df_filtered = df_total.loc[
    lambda d: (d['Spike_AS_Position'] > 33 + 331)
    & (d['Type_of_Mutation'] == 'NON-SYNOM')
]

# --- Sites marked in the plots; stored as strings and matched via astype(str) ---
sites_to_show = [
    str(site)
    for site in [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]
    + list(range(394, 414))
    + list(range(484, 505))
]
df_filtered = df_filtered.assign(
    show_site=df_filtered['Spike_AS_Position'].astype(str).isin(sites_to_show)
)

# --- Destination of the per-immunization CSV exports ---
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Define plotting function
def plot_by_subset(sub_df, title, filename):
    """Scatter-plot polyreactivity from separately smoothed median and count.

    For every immunization except ``'Library_ctrl'``, the non-synonymous
    variants are grouped by Spike position. Both the median Enrichment_Ratio and
    the variant count are smoothed with a centred rolling mean over
    ``ROLLING_WINDOW`` values, and their product is plotted as the score. Data
    is written to ``output_dir`` as CSV, positions in ``sites_to_show`` get a
    black bar on y=0, and the figure is saved to the Desktop as
    ``<filename>.png``.
    """
    palette = {
        'Polyclonal_Ab': 'darkorange',
        'Neutralizing_Ab': 'red',
        'wildtype_RBD': 'darkblue',
        'Mutant_RBD': '#004c4c',
    }
    fig, ax = plt.subplots(figsize=(10, 6))

    def _smooth(series):
        return series.rolling(ROLLING_WINDOW, center=True, min_periods=1).mean()

    for immunization in sub_df['immunization'].unique():
        if immunization == 'Library_ctrl':
            continue
        imm_rows = sub_df[sub_df['immunization'] == immunization]
        ns_rows = imm_rows[imm_rows['Type_of_Mutation'] == 'NON-SYNOM']
        if ns_rows.empty:
            continue

        per_position = ns_rows.groupby('Spike_AS_Position')['Enrichment_Ratio']
        # Smooth each factor first, then multiply.
        score = _smooth(per_position.median()) * _smooth(per_position.count())

        # Contiguous positions; those without data are filled with 0.
        span = np.arange(score.index.min(), score.index.max() + 1)
        df_plot = pd.DataFrame({
            'Spike_AS_Position': span,
            'Polyreactivity': score.reindex(span, fill_value=0).to_numpy(),
        })

        df_plot.to_csv(os.path.join(output_dir, f"{immunization}_{filename}_median_plot_data.csv"),
                       index=False)

        ax.scatter(df_plot['Spike_AS_Position'], df_plot['Polyreactivity'],
                   color=palette.get(immunization), s=10,
                   label=f"{immunization} polyreactivity", alpha=0.8)

        flagged = df_plot['Spike_AS_Position'].astype(str).isin(sites_to_show)
        for pos in df_plot.loc[flagged, 'Spike_AS_Position']:
            ax.hlines(y=0, xmin=pos - 0.5, xmax=pos + 0.5, color='black', linewidth=10)

    # Axis decoration
    ax.axhline(0, color='gray', linestyle='--', linewidth=1)
    plt.title(title, fontsize=14)
    ax.set_xlabel('Spike AA Position', fontsize=16)
    ax.set_ylabel('Polyreactivity Score\n(median × count)', fontsize=16)

    # Legend: one entry per label (last handle wins), outside the axes on the right
    handles, labels = ax.get_legend_handles_labels()
    dedup = {lab: handle for handle, lab in zip(handles, labels)}
    ax.legend(dedup.values(), dedup.keys(),
              loc='center left', bbox_to_anchor=(1.01, 0.5),
              frameon=False, fontsize=10)

    first_pos = sub_df['Spike_AS_Position'].min()
    last_pos = sub_df['Spike_AS_Position'].max()
    ticks = np.linspace(first_pos, last_pos, 20).astype(int)
    ax.set_xticks(ticks)
    ax.set_xticklabels([str(t) for t in ticks])
    ax.yaxis.set_major_locator(MaxNLocator(nbins=9, prune='both'))
    fig.tight_layout()
    fig.savefig(os.path.join(r"/Users/lucaschlotheuber/Desktop", filename + ".png"), dpi=300)
    plt.show()


# --- Call plot function twice ---
# Distinct file stems keep this cell's PNGs and per-immunization CSVs from
# overwriting the outputs of the earlier escape/binding and polyreactivity cells.

# 1) Polyclonal vs Neutralizing
plot_by_subset(
    df_filtered[df_filtered['immunization'].isin(['Polyclonal_Ab', 'Neutralizing_Ab'])],
    title="Calibration",
    filename="polyreactivity_smoothed_polyclonal_neutralizing"
)

# 2) Wildtype vs Mutant
plot_by_subset(
    df_filtered[df_filtered['immunization'].isin(['wildtype_RBD', 'Mutant_RBD'])],
    title="Immunization",
    filename="polyreactivity_smoothed_wildtype_mutant"
)


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

# Opt in to pandas' future "no silent downcasting" behaviour.
pd.set_option('future.no_silent_downcasting', True)

# --- Parameters ---
ROLLING_WINDOW = 15       # centred rolling-mean window (number of consecutive per-position values)
ENRICHMENT_THRESHOLD = 0  # not referenced in this cell

# --- Keep non-synonymous variants beyond Spike position 364 (= 33 + 331) ---
df_filtered = df_total.loc[
    lambda d: (d['Spike_AS_Position'] > 33 + 331)
    & (d['Type_of_Mutation'] == 'NON-SYNOM')
]

# --- Sites marked in the plots; stored as strings and matched via astype(str) ---
sites_to_show = [
    str(site)
    for site in [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]
    + list(range(394, 414))
    + list(range(484, 505))
]
df_filtered = df_filtered.assign(
    show_site=df_filtered['Spike_AS_Position'].astype(str).isin(sites_to_show)
)

# --- Destination of the per-immunization CSV exports ---
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Largest position-wise median binding ratio (> 1) over all immunizations except
# the library control; it normalises binding values below (fallback 2.0).
binding_meds = []
for imm, imm_rows in df_filtered.groupby('immunization', sort=False):
    if imm == 'Library_ctrl':
        continue
    med_series = (imm_rows.loc[imm_rows['Enrichment_Ratio'] > 1]
                  .groupby('Spike_AS_Position')['Enrichment_Ratio']
                  .median())
    if len(med_series):
        binding_meds.append(med_series.max())
global_binding_max = max(binding_meds, default=2.0)

# Prepare plotting: one bar slot per immunization (Library_ctrl excluded)
immunizations = [imm for imm in df_filtered['immunization'].unique() if imm != 'Library_ctrl']
n_imms = len(immunizations)
# The bars of one position share 0.8 x-units; max() avoids a ZeroDivisionError
# when no immunization is left to plot.
bar_width = 0.8 / max(n_imms, 1)

fig, ax = plt.subplots(figsize=(12, 6))
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'darkblue',
    'Mutant_RBD': '#004c4c'
}


def transform_val(y):
    """Map one smoothed median ratio onto the plot scale.

    ``y < 1`` -> ``-(1 - y)``; ``y == 1`` -> 0; ``y > 1`` ->
    ``(y - 1) / (global_binding_max - 1)``. As a result, the largest
    position-wise median binding ratio (``global_binding_max``, computed in the
    previous step) maps to 1.0.
    """
    if y < 1:
        return -(1 - y)
    if y == 1:
        return 0
    return (y - 1) / (global_binding_max - 1)


# Loop per immunization; each gets its own x-offset so its bars sit side by side
for i, immunization in enumerate(immunizations):
    sub = df_filtered[df_filtered['immunization'] == immunization]

    # Per-position medians of escape (<= 1) and binding (> 1) variants, smoothed
    esc = (sub[sub['Enrichment_Ratio'] <= 1]
           .groupby('Spike_AS_Position')['Enrichment_Ratio']
           .median()
           .rolling(ROLLING_WINDOW, center=True, min_periods=1).mean())
    bind = (sub[sub['Enrichment_Ratio'] > 1]
            .groupby('Spike_AS_Position')['Enrichment_Ratio']
            .median()
            .rolling(ROLLING_WINDOW, center=True, min_periods=1).mean())

    if esc.empty and bind.empty:
        continue

    # Transform THIS immunization's values; esc_t / bind_t must be computed here
    # (they are not assigned anywhere else in this cell).
    esc_t = esc.apply(transform_val)
    bind_t = bind.apply(transform_val)

    # x-range
    positions = sorted(set(esc_t.index).union(bind_t.index))
    pos_min, pos_max = positions[0], positions[-1]
    idx = np.arange(pos_min, pos_max + 1)

    # assemble df_plot (positions without data are plotted at 0)
    df_plot = pd.DataFrame({'Spike_AS_Position': idx}).set_index('Spike_AS_Position')
    df_plot['Escape'] = esc_t.reindex(idx, fill_value=0)
    df_plot['Binding'] = bind_t.reindex(idx, fill_value=0)
    df_plot = df_plot.reset_index()

    # save CSV
    df_plot.to_csv(os.path.join(output_dir,
                      f"{immunization}_median_plot_data.csv"),
                   index=False)

    # bar positions: centre the group of n_imms bars on each integer position
    offset = (i - n_imms / 2) * bar_width + bar_width / 2
    x_vals = df_plot['Spike_AS_Position'] + offset
    # Unknown immunizations fall back to gray instead of raising KeyError.
    imm_color = color_map.get(immunization, 'gray')

    # bars: escape solid, binding faded
    ax.bar(x_vals,
           df_plot['Escape'],
           width=bar_width,
           color=imm_color,
           alpha=0.8,
           label=f"{immunization} escape")
    ax.bar(x_vals,
           df_plot['Binding'],
           width=bar_width,
           color=imm_color,
           alpha=0.4,
           label=f"{immunization} binding")

    # overlay lines connecting the ends of bars
    ax.plot(x_vals, df_plot['Escape'], color=imm_color, linewidth=1.5)
    ax.plot(x_vals, df_plot['Binding'], color=imm_color, linewidth=1.5, linestyle='--')


# formatting
ax.set_ylim(-1, 0.4)
ax.axhline(0, color='gray', linestyle='--', linewidth=1)
plt.title('Escape (solid) vs Binding (faded) Bar Plot', fontsize=14)
ax.set_xlabel('Spike AA Position', fontsize=16)
ax.set_ylabel('Transformed Enrichment', fontsize=16)

# legend outside right
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, title='Immunization',
          loc='center left', bbox_to_anchor=(1, 0.5),
          frameon=False, fontsize=10)

# x-ticks
xt = np.linspace(df_filtered['Spike_AS_Position'].min(),
                 df_filtered['Spike_AS_Position'].max(), 20).astype(int)
ax.set_xticks(xt)
ax.set_xticklabels([str(x) for x in xt])

ax.yaxis.set_major_locator(MaxNLocator(nbins=9, prune='both'))
fig.tight_layout()
fig.subplots_adjust(right=0.8)  # make room for legend
fig.savefig(os.path.join(
    r"/Users/lucaschlotheuber/Desktop", "escape_binding_barplot.png"),
    dpi=300
)
plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.ticker import MaxNLocator
%matplotlib inline

# Opt in to pandas' future "no silent downcasting" behaviour.
pd.set_option('future.no_silent_downcasting', True)

# --- Parameters ---
ROLLING_WINDOW = 5        # centred rolling-mean window (number of consecutive per-position values)
ENRICHMENT_THRESHOLD = 0  # not referenced in this cell

# --- Keep non-synonymous variants beyond Spike position 364 (= 33 + 331) ---
df_filtered = df_total.loc[
    lambda d: (d['Spike_AS_Position'] > 33 + 331)
    & (d['Type_of_Mutation'] == 'NON-SYNOM')
]

# --- Sites marked in the plots; stored as strings and matched via astype(str) ---
sites_to_show = [
    str(site)
    for site in [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]
    + list(range(394, 414))
    + list(range(484, 505))
]
df_filtered = df_filtered.assign(
    show_site=df_filtered['Spike_AS_Position'].astype(str).isin(sites_to_show)
)

# --- Destination of the per-immunization CSV exports ---
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Largest position-wise median binding ratio (> 1) over all immunizations except
# the library control; it normalises binding values below (fallback 2.0).
binding_meds = []
for imm, imm_rows in df_filtered.groupby('immunization', sort=False):
    if imm == 'Library_ctrl':
        continue
    med_series = (imm_rows.loc[imm_rows['Enrichment_Ratio'] > 1]
                  .groupby('Spike_AS_Position')['Enrichment_Ratio']
                  .median())
    if len(med_series):
        binding_meds.append(med_series.max())
global_binding_max = max(binding_meds, default=2.0)

fig, ax = plt.subplots(figsize=(10, 6))

base_colors = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'darkblue',
    'Mutant_RBD': '#004c4c'
}


def transform_val(y):
    """Map one smoothed median ratio onto the plot scale.

    ``y < 1`` -> ``-(1 - y)``; ``y == 1`` -> 0; ``y > 1`` ->
    ``(y - 1) / (global_binding_max - 1)``. As a result, the largest
    position-wise median binding ratio (``global_binding_max``, computed in the
    previous step) maps to 1.0.
    """
    if y < 1:
        return -(1 - y)
    if y == 1:
        return 0
    return (y - 1) / (global_binding_max - 1)


highlight_positions = set()
for immunization in df_filtered['immunization'].unique():
    if immunization == 'Library_ctrl':
        continue

    sub = df_filtered[df_filtered['immunization'] == immunization]
    esc = (sub[sub['Enrichment_Ratio'] <= 1]
           .groupby('Spike_AS_Position')['Enrichment_Ratio']
           .median()
           .rolling(ROLLING_WINDOW, center=True, min_periods=1).mean())
    bind = (sub[sub['Enrichment_Ratio'] > 1]
            .groupby('Spike_AS_Position')['Enrichment_Ratio']
            .median()
            .rolling(ROLLING_WINDOW, center=True, min_periods=1).mean())
    # Nothing to plot for this immunization (positions[0] below would raise IndexError).
    if esc.empty and bind.empty:
        continue

    esc_t = esc.apply(transform_val)
    bind_t = bind.apply(transform_val)

    # x-range
    positions = sorted(set(esc_t.index).union(bind_t.index))
    idx = np.arange(positions[0], positions[-1] + 1)

    # one row per position; positions without data are plotted at 0
    df_plot = pd.DataFrame({'Spike_AS_Position': idx}).set_index('Spike_AS_Position')
    df_plot['Escape'] = esc_t.reindex(idx, fill_value=0)
    df_plot['Binding'] = bind_t.reindex(idx, fill_value=0)
    df_plot = df_plot.reset_index()

    # save CSV; the name differs from the bar-plot cell's
    # "<immunization>_median_plot_data.csv" so this cell does not overwrite it.
    df_plot.to_csv(os.path.join(output_dir,
                    f"{immunization}_median_lineplot_data.csv"), index=False)

    # colours: escape drawn in a darker tone of the base colour. Unknown
    # immunizations fall back to gray instead of raising KeyError.
    base = colors.to_rgb(base_colors.get(immunization, 'gray'))
    dark_escape = tuple(c * 0.5 for c in base)
    light_binding = base

    # plot
    ax.plot(df_plot['Spike_AS_Position'], df_plot['Escape'],
            color=dark_escape, linestyle='-', linewidth=1,
            label=f"{immunization} escape")
    ax.plot(df_plot['Spike_AS_Position'], df_plot['Binding'],
            color=light_binding, linewidth=1,
            label=f"{immunization} binding")

    # remember highlighted sites; drawn once after the loop
    highlight_positions.update(pos for pos in idx if str(pos) in sites_to_show)

# highlight: one black mark per site at y=-1 (the lower y-limit)
for pos in sorted(highlight_positions):
    ax.hlines(y=-1, xmin=pos - 0.5, xmax=pos + 0.5,
              color='black', linewidth=10)

# formatting
ax.set_ylim(-1, 1)
ax.axhline(0, color='gray', linestyle='--', linewidth=1)
ax.set_title('Escape (darker) vs Binding per Immunization', fontsize=14)
ax.set_xlabel('Spike AA Position', fontsize=16)
ax.set_ylabel('Transformed Enrichment', fontsize=16)

handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, title='Immunization',
          loc='upper center', ncol=2, frameon=False, fontsize=10)

xt = np.linspace(df_filtered['Spike_AS_Position'].min(),
                 df_filtered['Spike_AS_Position'].max(), 20).astype(int)
ax.set_xticks(xt)
ax.set_xticklabels([str(x) for x in xt])

ax.yaxis.set_major_locator(MaxNLocator(nbins=9, prune='both'))
fig.tight_layout()
# Distinct PNG name: "escape_binding_plot.png" is already written by an earlier cell.
fig.savefig(os.path.join(r"/Users/lucaschlotheuber/Desktop", "escape_binding_lineplot.png"), dpi=300)
plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
from Bio import SeqIO
import altair as alt

# Load FASTA sequence (Wuhan reference). Only the first record is used; `record`
# stays bound to it after the loop.
fasta_file = r'/Users/lucaschlotheuber/Desktop/ETH/RBD201_DMS1.fa'
for record in SeqIO.parse(fasta_file, "fasta"):
    wuhan_sequence = str(record.seq)
    break

# Load and clean the Excel data (only the listed columns are read)
file_path = r'/Users/lucaschlotheuber/Desktop/ETH/summary_DMS_cleaned.xlsx'
df_total = pd.read_excel(file_path, usecols=[
    "DMS_RBD_AS_position", "Spike_AS_Position", "Count_of_Base",
    "Amino_Acid", "Type_of_Mutation", "Enrichment_Ratio",
    "barcode", "immunization", "condition", "Total_Reads"
])
df_total["Spike_AS_Position"] -= 5  # Adjust 336 -> 331

# Clean up: drop rows without a ratio or amino acid, low-coverage rows
# (Total_Reads <= 500) and stop codons
df_total = df_total.dropna(subset=['Enrichment_Ratio', 'Amino_Acid'])
df_total = df_total[df_total["Total_Reads"] > 500]
df_total = df_total[df_total["Amino_Acid"] != '*']  # Exclude stop codons

# Add Wuhan reference: one row per residue of the reference sequence, with
# Enrichment_Ratio fixed to 1. Spike position = DMS position + 330 (presumably
# matching the -5 corrected offset above — confirm).
# NOTE(review): these rows have no Type_of_Mutation, so the NON-SYNOM filter below
# always removes them; they never reach df_heatmap.
immunization = "Wuhan_Sequence"
barcode = "Wuhan_Barcode"
data_wuhan = [{
    'DMS_RBD_AS_position': pos,
    'Spike_AS_Position': pos + 330,
    'Amino_Acid': aa,
    'immunization': immunization,
    'barcode': barcode,
    'Enrichment_Ratio': 1,
} for pos, aa in enumerate(wuhan_sequence, start=1) if aa != '*']
df_wuhan = pd.DataFrame(data_wuhan)
df_total = pd.concat([df_total, df_wuhan], ignore_index=True)

# Filter: positions after 364, non-synonymous mutations only
df_filtered = df_total[
    (df_total['Spike_AS_Position'] > 364) &
    (df_total['Type_of_Mutation'] == 'NON-SYNOM')
]

# Aggregate: mean enrichment ratio per (position, amino acid, immunization),
# i.e. averaged over barcodes
df_heatmap = df_filtered.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False
).agg({'Enrichment_Ratio': 'mean'})

# Filter out zeros or negatives (log2 is undefined for them). .copy() so the
# column assignments below act on an independent frame, not a slice.
df_heatmap = df_heatmap[df_heatmap['Enrichment_Ratio'] > 0].copy()

# Fold-change magnitude: ratios in (0, 1) are inverted (0.25 -> 4), ratios >= 1 are
# kept. log2 of this is therefore always >= 0 and encodes the size of the effect,
# not its direction.
df_heatmap['Enrichment_Ratio_transformed'] = np.where(
    df_heatmap['Enrichment_Ratio'] < 1,
    1 / df_heatmap['Enrichment_Ratio'],
    df_heatmap['Enrichment_Ratio'],
)

# log2 of the transformed values, rounded to 3 decimals (shown in the tooltip)
df_heatmap['log2_enrichment'] = np.log2(df_heatmap['Enrichment_Ratio_transformed']).round(3)

# Split on the RAW ratio. Splitting on the sign of log2_enrichment does not work:
# after the inversion above it is never negative, so the "0_to_1" subset only ever
# held ratios of (almost exactly) 1.
df_low = df_heatmap[df_heatmap['Enrichment_Ratio'] <= 1].copy()   # 0 < ratio <= 1
df_high = df_heatmap[df_heatmap['Enrichment_Ratio'] > 1].copy()   # ratio > 1

# Output directory
output_dir = "heatmap_output_log2_split"
os.makedirs(output_dir, exist_ok=True)

# Function to generate and save log2-based heatmap
def generate_heatmap(df_subset, enrichment_range_label):
    """Write one Altair heatmap (HTML file) per immunization in ``df_subset``.

    Each chart puts the Spike position on x (ticks only on multiples of 5) and the
    mutated amino acid on y. Cells are coloured by ``Enrichment_Ratio`` (red-blue
    scheme, domain 0.1-12), and the tooltip shows ``log2_enrichment``. Files go into
    the module-level ``output_dir``.

    Parameters
    ----------
    df_subset : pandas.DataFrame
        Needs 'immunization', 'Spike_AS_Position', 'Amino_Acid',
        'Enrichment_Ratio' and 'log2_enrichment' columns.
    enrichment_range_label : str
        Used in the chart title and, with spaces replaced by underscores, in the
        output file name.
    """
    file_label = enrichment_range_label.replace(' ', '_')

    for immun in df_subset['immunization'].unique():
        df_im = df_subset.loc[df_subset['immunization'] == immun].copy()

        # Axis ticks only at positions divisible by 5
        tick_vals = [p for p in sorted(df_im['Spike_AS_Position'].unique()) if p % 5 == 0]

        x_encoding = alt.X('Spike_AS_Position:O', title='Spike AA Position',
                           axis=alt.Axis(values=tick_vals))
        y_encoding = alt.Y('Amino_Acid:N', title='Mutated AA')
        color_encoding = alt.Color('Enrichment_Ratio:Q',
                                   scale=alt.Scale(scheme='redblue', domain=[0.1, 12]),
                                   title='Enrichment Ratio')
        tooltip_fields = ['Spike_AS_Position:O', 'Amino_Acid:N', 'log2_enrichment:Q']

        heatmap = (
            alt.Chart(df_im)
            .mark_rect()
            .encode(x=x_encoding, y=y_encoding, color=color_encoding, tooltip=tooltip_fields)
            .properties(title=f"log₂ {enrichment_range_label} Heatmap - {immun}",
                        width=800, height=400)
            .configure_axis(labelFontSize=12, titleFontSize=14)
            .configure_title(fontSize=18, anchor='start')
        )

        output_file = os.path.join(output_dir, f"{immun}_heatmap_log2_{file_label}.html")
        heatmap.save(output_file)
        print(f"Saved log2 {enrichment_range_label} heatmap for {immun} to {output_file}")

# Render both halves of the split: ratios up to 1 first, then ratios above 1.
generate_heatmap(df_subset=df_low, enrichment_range_label="0_to_1")
generate_heatmap(df_subset=df_high, enrichment_range_label="1_to_max")


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

pd.set_option('future.no_silent_downcasting', True)

def transform_y(y, threshold=1, max_val=400):
    """
    Piecewise linear transform used to compress the y-axis:
    - y in [0, threshold] maps to [0, 0.5]
    - y in [threshold, max_val] maps to [0.5, 1]

    Values outside [0, max_val] are extrapolated linearly along their segment.
    NaN inputs stay NaN. Previously the output was allocated with np.empty_like,
    so NaN entries matched neither mask and kept uninitialised memory.

    Parameters
    ----------
    y : scalar or array-like of numbers
    threshold : float
        Breakpoint mapped to 0.5; must be non-zero.
    max_val : float
        Value mapped to 1.0; must differ from ``threshold``.

    Returns
    -------
    numpy.ndarray of float, same shape as ``y`` (0-d for a scalar input).
    """
    y = np.asarray(y, dtype=float)
    y_new = np.full(y.shape, np.nan)  # NaN stays NaN: it matches neither mask below
    mask_low = y <= threshold
    y_new[mask_low] = 0.5 * (y[mask_low] / threshold)
    mask_high = y > threshold
    y_new[mask_high] = 0.5 + 0.5 * ((y[mask_high] - threshold) / (max_val - threshold))
    return y_new

fig, ax = plt.subplots(figsize=(10, 6))

# NOTE(review): df_filtered_agg, ENRICHMENT_THRESHOLD, color_map and sites_to_show are
# read but not defined in this cell. In the visible notebook they are assigned by the
# cells that FOLLOW, and output_dir is whatever the kernel last set. Move those
# definitions above this cell so Restart & Run All reproduces this figure.
for immunization in df_filtered_agg['immunization'].unique():
    if immunization == 'Library_ctrl':
        continue  # library control is not plotted

    # Per-position SUM of enrichment ratios over amino acids and barcodes.
    # A boolean mask replaces DataFrame.query, so the group name is never parsed as
    # part of an expression string.
    df_filtered_im = (
        df_filtered_agg[df_filtered_agg['immunization'] == immunization]
        .groupby('Spike_AS_Position', as_index=False)
        .agg({'Enrichment_Ratio': 'sum'})
    )

    # Fill missing enrichment values (backward, then forward).
    # .bfill()/.ffill() replace fillna(method=...), deprecated since pandas 2.1.
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # No smoothing here: flag raw per-position sums above the threshold.
    df_filtered_im['High_Enrichment'] = df_filtered_im['Enrichment_Ratio'] > ENRICHMENT_THRESHOLD
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    # Compress the y-range for plotting: [0, 1] -> [0, 0.5], (1, 400] -> (0.5, 1].
    df_filtered_im['Transformed_Enrichment'] = transform_y(df_filtered_im['Enrichment_Ratio'], threshold=1, max_val=400)

    # Save the per-immunization table (observed positions only, before reindexing).
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex to every integer position in the observed range; missing positions become NaN.
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(), df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()

    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    # Mark positions listed in sites_to_show (compared as strings).
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Transformed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding",
        show_col="show_site",
        ax=ax,
        linewidth=1.5,
        color=color_map.get(immunization, 'black')
    )
    # Empty line whose only purpose is to create a legend entry in this colour.
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    # Thick black bar at y=0 under each highlighted site.
    highlight_sites = df_filtered_im[df_filtered_im['show_site']]
    for _, site_data in highlight_sites.iterrows():
        ax.hlines(
            y=0,
            xmin=site_data['Spike_AS_Position'] - 0.5,
            xmax=site_data['Spike_AS_Position'] + 0.5,
            color='black',
            linestyle='-',
            linewidth=10
        )

# Title and axis labels. The y-axis ticks and limits are configured further down,
# once the real enrichment range (YMIN/YMAX) is known. The fixed 0-400 tick setup
# that used to sit here was always overridden by that step, and it labelled
# values <= 1 with a spurious "%". It has been removed together with its duplicate
# y=1 reference line, which is drawn again at the same height below.
plt.title('Enrichment Ratios for Different Immunizations', fontsize=14)
plt.xlabel('Spike AA Position', fontsize=16)
plt.ylabel('Antibody Repertoire - Enrichment \n (Sum Polyreactivity)', fontsize=16)

# Legend in a fixed group order. Only groups that were actually drawn are kept, and
# labels are filtered together with their handles. Before, handles were filtered but
# labels were not, so a missing group shifted the remaining labels onto the wrong lines.
handles, labels = ax.get_legend_handles_labels()
group_1_labels = ['Polyclonal_Ab', 'Neutralizing_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']
legend_labels = [label for label in group_1_labels + group_2_labels if label in labels]
handles = [handles[labels.index(label)] for label in legend_labels]
labels = legend_labels
plt.legend(handles, labels, title="Single Droplet Repertoire", loc='upper center', fontsize=11, frameon=False, handlelength=2, handleheight=1, title_fontsize=11, markerscale=8)

# 20 evenly spaced x ticks across the data range, labelling every other one.
xticks = np.linspace(df_filtered_agg['Spike_AS_Position'].min(), df_filtered_agg['Spike_AS_Position'].max(), 20).astype(int)
ax.set_xticks(xticks)
ax.set_xticklabels([str(x) if i % 2 == 0 else '' for i, x in enumerate(xticks)])

# 1) Real range of the per-position summed enrichment over all plotted immunizations
#    (same aggregation as the plotting loop: sum per position, then bfill/ffill).
all_vals = []
for imm in df_filtered_agg['immunization'].unique():
    if imm == "Library_ctrl":
        continue
    imm_vals = (
        df_filtered_agg[df_filtered_agg['immunization'] == imm]
        .groupby("Spike_AS_Position", as_index=False)["Enrichment_Ratio"]
        .sum()
        # .bfill()/.ffill() replace fillna(method=...), deprecated since pandas 2.1
        .bfill().ffill()["Enrichment_Ratio"]
        .values
    )
    all_vals.append(imm_vals)
all_vals = np.concatenate(all_vals)
YMIN, YMAX = all_vals.min(), all_vals.max()

# 2) Recompute and reset all y-ticks and limits using the real range
#    so that 0→1→YMAX maps to 0→0.5→1
yticks_original = [YMIN, 0.2, 0.5, 1, 50, 100, 200, YMAX]
yticks_transformed = transform_y(yticks_original, threshold=1, max_val=YMAX)

ax.set_yticks(yticks_transformed)
ax.set_yticklabels(
    [f"{v:.1f}" if v <= 1 else f"{int(v)}" for v in yticks_original]
)

# 3) Set the exact transformed limits
ax.set_ylim(transform_y(YMIN, threshold=1, max_val=YMAX),
            transform_y(YMAX, threshold=1, max_val=YMAX))

# 4) Draw the "x-axis" at y=1 (transformed)
ax.axhline(
    y=transform_y(1, threshold=1, max_val=YMAX),
    color="grey", linestyle="--", linewidth=1
)

# 5) Save under this figure's own file name. Previously the cell saved to whatever
#    `plot_file_path` was left in the kernel (in the visible notebook it is only
#    assigned by later cells), so it could overwrite another figure's PNG.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", "summed_Immunization_Enrichment_plot.png")
fig.tight_layout()
fig.savefig(plot_file_path, format="png")
plt.show()

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator

# You need to have dmslogo installed and imported properly
import dmslogo.line

# Opt in to pandas' future behaviour (no silent dtype downcasting).
pd.set_option('future.no_silent_downcasting', True)

# Analysis constants
ROLLING_WINDOW = 20        # width of the centred rolling mean used for smoothing
ENRICHMENT_THRESHOLD = 1   # smoothed values above this are flagged as high enrichment
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Non-synonymous mutations at positions after 364.
# NOTE(review): df_total comes from an earlier cell.
df_filtered = df_total.loc[
    df_total['Spike_AS_Position'].gt(364)
    & df_total['Type_of_Mutation'].eq('NON-SYNOM')
]

# Positions to highlight, kept as strings for string-based matching.
sites_to_show = [
    str(site)
    for site in (
        [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]  # RBD-ACE2 interface
        + list(range(394, 414))                              # R21 peptide sequence
        + list(range(484, 505))                              # R13 peptide sequence
    )
]

# Mean ratio per (position, amino acid, barcode, immunization), plus a "<AA>_<pos>"
# label and a flag for the highlighted sites.
df_filtered_agg = (
    df_filtered
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False)
    .agg({'Enrichment_Ratio': 'mean'})
    .assign(
        site_label=lambda frame: frame["Amino_Acid"] + "_" + frame["Spike_AS_Position"].astype(str),
        show_site=lambda frame: frame["Spike_AS_Position"].astype(str).isin(sites_to_show),
    )
)

# Create figure and axes for combined plot
fig, ax = plt.subplots(figsize=(10, 6))

# Color mapping for immunizations
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'darkblue',
    'Mutant_RBD': '#004c4c'
}

# Groups left out of this figure: the library control and both antibody references.
skipped_immunizations = {'Library_ctrl', 'Polyclonal_Ab', 'Neutralizing_Ab'}

for immunization in df_filtered_agg['immunization'].unique():
    if immunization in skipped_immunizations:
        continue

    print(f"Processing immunization: {immunization}")

    # Per-position MEAN enrichment ratio over amino acids and barcodes.
    df_im = (
        df_filtered_agg[df_filtered_agg['immunization'] == immunization]
        .groupby('Spike_AS_Position', as_index=False)
        .agg({'Enrichment_Ratio': 'mean'})
    )

    # Fill any NaNs (backward, then forward); .bfill()/.ffill() replace the
    # fillna(method=...) form deprecated since pandas 2.1.
    df_im['Enrichment_Ratio'] = df_im['Enrichment_Ratio'].bfill().ffill()

    # Centred rolling mean over ROLLING_WINDOW consecutive rows (observed positions),
    # then interpolate any remaining gaps.
    df_im['Smoothed_Enrichment'] = df_im['Enrichment_Ratio'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_im['Smoothed_Enrichment'] = df_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

    # Show summary stats after smoothing
    print(f"  Smoothed_Enrichment range: min = {df_im['Smoothed_Enrichment'].min():.2f}, "
          f"max = {df_im['Smoothed_Enrichment'].max():.2f}, "
          f"mean = {df_im['Smoothed_Enrichment'].mean():.2f}, "
          f"95th percentile = {np.percentile(df_im['Smoothed_Enrichment'].dropna(), 95):.2f}")

    # Identify high enrichment positions
    df_im['High_Enrichment'] = df_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_im['High_Enrichment'] = df_im['High_Enrichment'].fillna(False).astype(bool)
    # (A 'Neutralizing_Ab'-only fill step used to follow here. It was unreachable
    # because that group is skipped above, so it has been removed.)

    # Save CSV for this immunization (observed positions only)
    csv_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_im.to_csv(csv_path, index=False)

    # Reindex to every integer position in the observed range; gaps become NaN.
    full_range = range(df_im['Spike_AS_Position'].min(), df_im['Spike_AS_Position'].max() + 1)
    df_im = df_im.set_index('Spike_AS_Position').reindex(full_range).reset_index()

    # Re-assign show_site based on sites_to_show (string comparison)
    df_im = df_im.assign(
        show_site=lambda x: x['Spike_AS_Position'].astype(str).isin(sites_to_show)
    )
    df_im['High_Enrichment'] = df_im['High_Enrichment'].fillna(False).astype(bool)

    # Plot using dmslogo.line.draw_line on the shared axis
    dmslogo.line.draw_line(
        df_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",  # no individual title
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding",
        show_col="show_site",
        ax=ax,
        linewidth=2.5,
        color=color_map.get(immunization, 'black')
    )

    # Add empty plot to generate legend entry
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    # Highlight sites on the x-axis with thick horizontal lines at y=0
    highlight_sites = df_im[df_im['show_site']]
    for _, site_row in highlight_sites.iterrows():
        ax.hlines(
            y=0,
            xmin=site_row['Spike_AS_Position'] - 0.5,
            xmax=site_row['Spike_AS_Position'] + 0.5,
            color='black',
            linestyle='-',
            linewidth=10
        )

# Fixed y-range for the smoothed mean enrichment
ax.set_ylim(0, 40)

# Titles and labels
plt.title('Smoothed Enrichment Ratios for Different Immunizations', fontsize=14)
plt.xlabel('Spike AA Position', fontsize=16)
plt.ylabel('Antibody Repertoire - Enrichment \n (Average Polyreactivity)', fontsize=16)

# Legend: only the RBD-immunization groups are drawn in this figure. Labels are
# filtered together with their handles. Before, labels were not filtered, so a
# missing group shifted the remaining labels onto the wrong lines.
handles, labels = ax.get_legend_handles_labels()
group_2_labels = [label for label in ['wildtype_RBD', 'Mutant_RBD'] if label in labels]
handles = [handles[labels.index(label)] for label in group_2_labels]
labels = group_2_labels

plt.legend(
    handles, labels,
    title="Single Droplet Repertoire",
    loc='upper center',
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    title_fontsize=11,
    markerscale=8
)

# Set x-axis ticks: 20 ticks across the range, label every other one
xticks = np.linspace(df_filtered_agg['Spike_AS_Position'].min(), df_filtered_agg['Spike_AS_Position'].max(), 20).astype(int)
ax.set_xticks(xticks)
ax.set_xticklabels([str(x) if i % 2 == 0 else '' for i, x in enumerate(xticks)])

# Set y-axis major locator for better integer ticks
ax.yaxis.set_major_locator(MaxNLocator(integer=True, prune='lower', nbins=8))

# Save figure
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", "combined_Immunization_Enrichment_plot.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

# Show plot
plt.show()


In [None]:
# Describe the last per-position frame left behind by the plotting loop above.
# NOTE(review): `immunization` keeps the loop variable's final value, which can be a
# group the loop skipped, while `df_im` belongs to the last group actually
# processed, so the header may name the wrong immunization.
print(f"Summary statistics for {immunization}:")
print(df_im.loc[:, ['Enrichment_Ratio', 'Smoothed_Enrichment']].describe())


In [None]:
# Pool the rolling-mean (ROLLING_WINDOW) per-position enrichment of every group
# shown in the RBD figure into one flat list, then plot its distribution.
all_smoothed = []

for immunization in df_filtered_agg['immunization'].unique():
    if immunization not in ['Library_ctrl', 'Polyclonal_Ab', 'Neutralizing_Ab']:
        df_im = (
            df_filtered_agg
            .query(f'immunization == "{immunization}"')
            .groupby('Spike_AS_Position', as_index=False)
            .agg({'Enrichment_Ratio': 'mean'})
        )
        df_im['Smoothed_Enrichment'] = (
            df_im['Enrichment_Ratio']
            .rolling(window=ROLLING_WINDOW, center=True, min_periods=1)
            .mean()
        )
        all_smoothed += df_im['Smoothed_Enrichment'].dropna().tolist()

# Histogram of the pooled values
plt.figure(figsize=(8, 4))
plt.hist(all_smoothed, bins=50, color='skyblue', edgecolor='black')
plt.title("Distribution of Smoothed Enrichment Values")
plt.xlabel("Smoothed Enrichment")
plt.ylabel("Frequency")
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 20  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 1  # Adjust based on data distribution

# Keep positions after 364 (33 + 331) with an enrichment ratio > 1.
# NOTE(review): unlike the previous cell, no Type_of_Mutation == 'NON-SYNOM' filter
# is applied here (the old comment claimed one). Confirm this is intended.
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 1)]  # Only include Enrichment_Ratio > 1

# Per (position, barcode, immunization): SUM of the enrichment ratios over amino acids
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Define target sites (stored as strings)
sites_to_show = list(map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
))

# show_site must compare string forms. sites_to_show holds strings, so the previous
# integer-vs-string isin() never matched and flagged no site at all.
df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create a single figure shared by all immunizations
fig, ax = plt.subplots(figsize=(10, 6))

# Define color mapping
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Plot every immunization except the library control on the same axes
for immunization in df_filtered_agg['immunization'].unique():
    if immunization == 'Library_ctrl':  # Skip 'Library_ctrl'
        continue

    print(immunization)
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Sum the per-barcode values at each position and divide by the number of distinct
    # barcodes in this immunization. Barcodes with no row at a position therefore
    # contribute 0.
    num_barcodes = df_filtered_im['barcode'].nunique()
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'sum',
    })
    df_filtered_im['Enrichment_Ratio'] /= num_barcodes  # Normalize
    # Backward, then forward fill (a no-op when nothing is missing). .bfill()/.ffill()
    # replace fillna(method=...), deprecated since pandas 2.1.
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # Centred rolling mean over ROLLING_WINDOW consecutive rows (observed positions)
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()

    # Handle missing data by interpolating missing values (linear interpolation)
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

    # Identify clusters of high enrichment
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    if immunization == 'Neutralizing_Ab':
        df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].bfill().ffill()

    # Save CSV (observed positions only)
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex to every integer position in the observed range; gaps become NaN
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(), df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()

    # Ensure High_Enrichment is still boolean after reindexing
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    # Mark positions listed in sites_to_show (string comparison)
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    # Draw this immunization's smoothed curve on the shared axes
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",  # Remove individual titles for each plot
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding",
        show_col="show_site",
        ax=ax,  # Pass the same axes object for all plots
        linewidth=2.5,
        color=color_map.get(immunization, 'black')  # Default to black
    )
    # Empty line whose only purpose is to create a legend entry in this colour
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    # Thick black bar at y=0 under each highlighted site
    highlight_sites = df_filtered_im[df_filtered_im['show_site']]
    for _, site_data in highlight_sites.iterrows():
        ax.hlines(
            y=0,
            xmin=site_data['Spike_AS_Position'] - 0.5,
            xmax=site_data['Spike_AS_Position'] + 0.5,
            color='black',
            linestyle='-',
            linewidth=10
        )

# Fixed y-range for the normalised, smoothed enrichment
ax.set_ylim(0, 25)

# After all lines are drawn, adjust plot settings. The plotted value is a per-barcode
# mean (sum / barcode count) smoothed by a rolling mean, not a median, so the label
# now says "Mean".
plt.title('Smoothed Enrichment Ratios for Different Immunizations', fontsize=14)
plt.xlabel('Spike AA Position', fontsize=16)
plt.ylabel('Antibody Repertoire - Enrichment \n (Mean per-cell Polyreactivity)', fontsize=16)

# Legend: only the RBD-immunization groups, handles and labels filtered together.
# Groups that were not drawn are dropped instead of raising ValueError.
handles, labels = ax.get_legend_handles_labels()
group_2_labels = [label for label in ['wildtype_RBD', 'Mutant_RBD'] if label in labels]
handles = [handles[labels.index(label)] for label in group_2_labels]
labels = group_2_labels

# Add the legend to the top right
plt.legend(handles, labels, title="Single Droplet Repertoire", loc='upper right', fontsize=11, frameon=False, handlelength=2, handleheight=1, title_fontsize=11, markerscale=8)

# 20 evenly spaced x ticks across the range, labelling every other one
xticks = np.linspace(df_filtered_agg['Spike_AS_Position'].min(), df_filtered_agg['Spike_AS_Position'].max(), 20).astype(int)
ax.set_xticks(xticks)
ax.set_xticklabels([str(x) if i % 2 == 0 else '' for i, x in enumerate(xticks)])

ax.yaxis.set_major_locator(MaxNLocator(integer=True, prune='lower', nbins=8))

# This figure combines several immunizations, so it gets a fixed file name. It used to
# be named after `immunization`, i.e. whichever group the loop visited last
# (possibly the skipped 'Library_ctrl').
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", "combined_Line2_plot.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

# Display the combined plot
plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 15  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 1  # Adjust based on data distribution

# Keep positions after 364 (33 + 331) with an enrichment ratio > 1.
# NOTE(review): no Type_of_Mutation filter is applied here. Confirm this is intended.
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 1)]

# Per (position, barcode, immunization): SUM of the enrichment ratios over amino acids
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Define target sites (stored as strings)
sites_to_show = list(map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
))

# show_site must compare string forms. sites_to_show holds strings, so the previous
# integer-vs-string isin() never matched and flagged no site at all.
df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create a single figure
fig, ax = plt.subplots(figsize=(8, 6))

# Define color mapping
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Reference-antibody groups drawn semi-transparently. These are defined once here
# instead of being re-created on every loop iteration.
faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
faint_alpha = 0.3
faint_linewidth = 1.5  # NOTE(review): not used below; draw_line always gets linewidth=2

# Groups left out of this figure
skipped_immunizations = {'Library_ctrl', 'Neutralizing_Ab'}

for immunization in df_filtered_agg['immunization'].unique():
    if immunization in skipped_immunizations:
        continue

    print(immunization)

    # Per-position MEAN over barcodes of the per-barcode summed enrichment.
    # (A barcode count was previously computed here but never used.)
    df_filtered_im = (
        df_filtered_agg[df_filtered_agg['immunization'] == immunization]
        .groupby('Spike_AS_Position', as_index=False)
        .agg({'Enrichment_Ratio': 'mean'})
    )
    # .bfill()/.ffill() replace fillna(method=...), deprecated since pandas 2.1
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # Centred rolling mean, then interpolate any remaining gaps in both directions
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both')

    # Identify high enrichment (on the linear scale)
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    # (A 'Neutralizing_Ab'-only fill step used to follow here. It was unreachable
    # because that group is skipped above, so it has been removed.)

    # Save CSV (linear scale, observed positions only)
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex to every integer position in the observed range; gaps become NaN
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(),
              df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()

    # log10 for plotting; zero, negative and missing values become NaN
    df_filtered_im['Smoothed_Enrichment'] = np.log10(
        df_filtered_im['Smoothed_Enrichment'].where(df_filtered_im['Smoothed_Enrichment'] > 0)
    )

    # Mark sites to highlight (string comparison)
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    # Plot
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        show_col="show_site",
        ax=ax,
        linewidth=2,
        color=color_map.get(immunization, 'black')
    )
    # Assumes draw_line adds the data curve as the last Line2D on the axes
    if immunization in faint_immunizations:
        ax.lines[-1].set_alpha(faint_alpha)
    # Empty line whose only purpose is to create a legend entry in this colour
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    # Highlight sites with a thick black bar near the bottom of the log axis
    highlight_sites = df_filtered_im[df_filtered_im['show_site']]
    for _, site_data in highlight_sites.iterrows():
        ax.hlines(
            y=0.76,  # set near bottom
            xmin=site_data['Spike_AS_Position'] - 0.5,
            xmax=site_data['Spike_AS_Position'] + 0.5,
            color='black',
            linestyle='-',
            linewidth=10
        )

# Y-axis limits (log10 scale). The top used to come from df_filtered_im, i.e. only the
# last immunization in the loop, which could clip the other curves. It is now taken
# from every line drawn on the axes.
def _max_plotted_y(axes):
    """Return the largest finite y value over ``axes.lines``, or None if there is none.

    Assumes the data curves are Line2D objects in ``axes.lines``; the plotting loop
    already relies on this via ``ax.lines[-1]``. The empty legend-proxy lines have
    no data and are ignored.
    """
    maxima = []
    for line in axes.lines:
        ydata = np.asarray(line.get_ydata(), dtype=float)
        finite = ydata[np.isfinite(ydata)]
        if finite.size:
            maxima.append(finite.max())
    return max(maxima) if maxima else None


_y_top = _max_plotted_y(ax)
if _y_top is None:  # nothing drawn: fall back to the previous behaviour
    _y_top = np.nanmax(df_filtered_im['Smoothed_Enrichment'])
ax.set_ylim(bottom=0.7, top=_y_top + 0.3)

plt.title('Smoothed Enrichment Ratios for Different Immunizations', fontsize=14)
plt.xlabel('Spike AA Position', fontsize=16)
plt.ylabel('Single-cell Polyreactivity\n Log10 average binding ratio', fontsize=16)

from matplotlib.patches import Patch

# Create a square patch for the legend entry
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')

# Legend: fixed group order with readable names. Groups that were not drawn are
# dropped (instead of raising ValueError from labels.index), and the epitope patch
# is appended last.
label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
handles, labels = ax.get_legend_handles_labels()
legend_groups = [group for group in ['Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD'] if group in labels]
handles = [handles[labels.index(group)] for group in legend_groups] + [epitope_patch]
labels = [label_map[group] for group in legend_groups] + ['SARS-CoV-2 Spike antibody epitopes']

plt.legend(handles, labels, title="Immunization",
           loc='upper right', fontsize=11, frameon=False,
           handlelength=2, handleheight=1, title_fontsize=11, markerscale=1)

# X-axis ticks every 10 positions
xtick_start = int(np.floor(df_filtered_agg['Spike_AS_Position'].min() / 10.0) * 10)
xtick_end = int(np.ceil(df_filtered_agg['Spike_AS_Position'].max() / 10.0) * 10)
xticks = np.arange(xtick_start, xtick_end + 1, 10)
ax.set_xticks(xticks)
ax.set_xticklabels([str(x) for x in xticks], rotation=0)

ax.yaxis.set_major_locator(MaxNLocator(integer=True, prune='lower', nbins=8))

# This figure combines several immunizations, so it gets a fixed file name. It used to
# be named after `immunization`, i.e. whichever group the loop visited last.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", "combined_Line2_plot_final.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, MultipleLocator
from matplotlib import font_manager
%matplotlib inline

pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 15  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 1  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[
    (df_total['Spike_AS_Position'] > 33 + 331) &
    (df_total['Enrichment_Ratio'] >= 1) &
    (df_total['Amino_Acid'] != "*")
].copy()

# Define sites to show (as strings for consistent matching later)
sites_to_show = list(map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
))

# Aggregate filtered data by position and immunization before plotting
df_filtered_agg = df_filtered.groupby(
    ['immunization', 'Spike_AS_Position', 'barcode'], as_index=False
).agg({
    'Enrichment_Ratio': 'mean'
})

# Add useful columns for plotting
df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create figure and axis for plotting
fig, ax = plt.subplots(figsize=(8, 6))

# Define color mapping
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Immunizations to fade in plot
faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
faint_alpha = 0.3

# Loop through each immunization excluding 'Library_ctrl' and 'Neutralizing_Ab' (skip Neutralizing_Ab twice? Possibly intentional)
for immunization in df_filtered_agg['immunization'].unique():
    if immunization in ['Library_ctrl', 'Neutralizing_Ab']:
        continue

    print(immunization)
    df_im = df_filtered_agg.query(f'immunization == "{immunization}"')

    # Group by position to get median enrichment ratio
    df_im_pos = df_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'mean'
    })

    # Fill NaNs forward/backward
    df_im_pos['Enrichment_Ratio'] = df_im_pos['Enrichment_Ratio'].fillna(method='bfill').fillna(method='ffill')

    # Safe log2 transform, ignoring zero or negative values
    df_im_pos['Log2_Enrichment'] = df_im_pos['Enrichment_Ratio'].apply(lambda x: np.log2(x) if x > 0 else np.nan)
    print(df_im_pos[['Spike_AS_Position', 'Enrichment_Ratio']].head(20))

    # Smooth the log2 enrichment with rolling mean
    df_im_pos['Smoothed_Enrichment'] = df_im_pos['Log2_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_im_pos['Smoothed_Enrichment'] = df_im_pos['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')
    print(df_im_pos[['Spike_AS_Position', 'Smoothed_Enrichment']].head(20))

    # Flag positions with high enrichment
    df_im_pos['High_Enrichment'] = df_im_pos['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_im_pos['High_Enrichment'] = df_im_pos['High_Enrichment'].fillna(False).astype(bool)

    # Save CSV for this immunization
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_im_pos.to_csv(csv_file_path, index=False)

    # Reindex to fill missing positions (to ensure continuous x-axis)
    df_im_pos = df_im_pos.set_index('Spike_AS_Position').reindex(
        range(df_im_pos['Spike_AS_Position'].min(), df_im_pos['Spike_AS_Position'].max() + 1)
    ).reset_index()

    # Fill missing 'High_Enrichment' and assign 'show_site' again for reindexed dataframe
    df_im_pos['High_Enrichment'] = df_im_pos['High_Enrichment'].fillna(False).astype(bool)
    df_im_pos['show_site'] = df_im_pos['Spike_AS_Position'].astype(str).isin(sites_to_show)

    print(df_im_pos[['Spike_AS_Position', 'show_site']].head(10))

    # Calculate rolling min and max for shaded range plot
    df_im_pos['Smoothed_Enrichment_min'] = df_im_pos['Smoothed_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).min()
    df_im_pos['Smoothed_Enrichment_max'] = df_im_pos['Smoothed_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).max()

    # Plot shaded area
    ax.fill_between(
        df_im_pos['Spike_AS_Position'],
        df_im_pos['Smoothed_Enrichment_min'],
        df_im_pos['Smoothed_Enrichment_max'],
        color=color_map.get(immunization, 'black'),
        alpha=0.1
    )

    # Plot line (using dmslogo.line.draw_line - assuming dmslogo is imported and available)
    dmslogo.line.draw_line(
        df_im_pos,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        show_col="show_site",
        ax=ax,
        linewidth=2,
        color=color_map.get(immunization, 'black')
    )

    # Apply transparency if immunization in faint list
    if immunization in faint_immunizations:
        ax.lines[-1].set_alpha(faint_alpha)

    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    # Print smoothed enrichment every 20th position
    print(f"\nSmoothed Enrichment values for immunization '{immunization}':")
    positions_to_print = range(df_im_pos['Spike_AS_Position'].min(),
                               df_im_pos['Spike_AS_Position'].max() + 1, 20)
    for pos in positions_to_print:
        val_row = df_im_pos[df_im_pos['Spike_AS_Position'] == pos]
        if not val_row.empty:
            val = val_row['Smoothed_Enrichment'].values[0]
            print(f"Position {pos}: {val:.3f}")
        else:
            print(f"Position {pos}: (no data)")

    # Highlight specific sites
    highlight_sites = df_im_pos[df_im_pos['show_site']]
    for _, site_data in highlight_sites.iterrows():
        ax.hlines(
            y=0.2,  # set near bottom of plot
            xmin=site_data['Spike_AS_Position'] - 0.5,
            xmax=site_data['Spike_AS_Position'] + 0.5,
            color='black',
            linestyle='-',
            linewidth=10
        )

# Set y-axis limits
ax.set_ylim(bottom=0.7, top=np.nanmax(df_im_pos['Smoothed_Enrichment']+1.1))

plt.title('', fontsize=14)
plt.xlabel('Spike AA Position', fontsize=16)
plt.ylabel('Polyreactivity\n Log2 AB Variant binding', fontsize=16)

from matplotlib.patches import Patch

# Create legend patch for epitopes
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')

handles, labels = ax.get_legend_handles_labels()
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

group_1_handles = [handles[labels.index(label)] for label in group_1_labels if label in labels]
group_2_handles = [handles[labels.index(label)] for label in group_2_labels if label in labels]

handles = group_1_handles + group_2_handles + [epitope_patch]

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
labels = [label_map[label] for label in group_1_labels + group_2_labels if label in label_map] + ['SARS-CoV-2 Spike antibody epitopes']

plt.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='upper left',
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# X-axis ticks setup
major_locator = MultipleLocator(10)
minor_locator = MultipleLocator(2)

ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{int(x)}"))
ax.xaxis.set_minor_locator(minor_locator)

ax.tick_params(axis='x', which='major', length=7, width=1.2)
ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')

ax.tick_params(axis='y', which='major', length=7, width=1.2)
ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')
ax.yaxis.set_major_locator(MaxNLocator(prune='lower', nbins=15))
ax.yaxis.set_minor_locator(MultipleLocator(0.1))

# Save plot
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_Line2_plot_final.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, MultipleLocator
from matplotlib import font_manager
%matplotlib inline

pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 15  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 1  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[
    (df_total['Spike_AS_Position'] > 33 + 331) &
    (df_total['Enrichment_Ratio'] >= 1) &
    (df_total['Amino_Acid'] != "*")
].copy()

# Define sites to show (as strings for consistent matching later)
sites_to_show = list(map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
))

# Aggregate filtered data by position and immunization before plotting
df_filtered_agg = df_filtered.groupby(
    ['immunization', 'Spike_AS_Position', 'barcode'], as_index=False
).agg({
    'Enrichment_Ratio': 'median'
})

# Add useful columns for plotting
df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create figure and axis for plotting
fig, ax = plt.subplots(figsize=(8, 6))

# Define color mapping
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Immunizations to fade in plot
faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
faint_alpha = 0.3

# Loop through each immunization excluding 'Library_ctrl' and 'Neutralizing_Ab' (skip Neutralizing_Ab twice? Possibly intentional)
for immunization in df_filtered_agg['immunization'].unique():
    if immunization in ['Library_ctrl', 'Neutralizing_Ab']:
        continue

    print(immunization)
    df_im = df_filtered_agg.query(f'immunization == "{immunization}"')

    # Group by position to get median enrichment ratio
    df_im_pos = df_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'mean'
    })

    # Fill NaNs forward/backward
    df_im_pos['Enrichment_Ratio'] = df_im_pos['Enrichment_Ratio'].fillna(method='bfill').fillna(method='ffill')

    # Safe log2 transform, ignoring zero or negative values
    df_im_pos['Log2_Enrichment'] = df_im_pos['Enrichment_Ratio'].apply(lambda x: np.log2(x) if x > 0 else np.nan)
    print(df_im_pos[['Spike_AS_Position', 'Enrichment_Ratio']].head(20))

    # Smooth the log2 enrichment with rolling mean
    df_im_pos['Smoothed_Enrichment'] = df_im_pos['Log2_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_im_pos['Smoothed_Enrichment'] = df_im_pos['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')
    print(df_im_pos[['Spike_AS_Position', 'Smoothed_Enrichment']].head(20))

    # Flag positions with high enrichment
    df_im_pos['High_Enrichment'] = df_im_pos['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_im_pos['High_Enrichment'] = df_im_pos['High_Enrichment'].fillna(False).astype(bool)

    # Save CSV for this immunization
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_im_pos.to_csv(csv_file_path, index=False)

    # Reindex to fill missing positions (to ensure continuous x-axis)
    df_im_pos = df_im_pos.set_index('Spike_AS_Position').reindex(
        range(df_im_pos['Spike_AS_Position'].min(), df_im_pos['Spike_AS_Position'].max() + 1)
    ).reset_index()

    # Fill missing 'High_Enrichment' and assign 'show_site' again for reindexed dataframe
    df_im_pos['High_Enrichment'] = df_im_pos['High_Enrichment'].fillna(False).astype(bool)
    df_im_pos['show_site'] = df_im_pos['Spike_AS_Position'].astype(str).isin(sites_to_show)

    print(df_im_pos[['Spike_AS_Position', 'show_site']].head(10))

    # Calculate rolling min and max for shaded range plot
    df_im_pos['Smoothed_Enrichment_min'] = df_im_pos['Smoothed_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).min()
    df_im_pos['Smoothed_Enrichment_max'] = df_im_pos['Smoothed_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).max()

    # Plot shaded area
    ax.fill_between(
        df_im_pos['Spike_AS_Position'],
        df_im_pos['Smoothed_Enrichment_min'],
        df_im_pos['Smoothed_Enrichment_max'],
        color=color_map.get(immunization, 'black'),
        alpha=0.1
    )

    # Plot line (using dmslogo.line.draw_line - assuming dmslogo is imported and available)
    dmslogo.line.draw_line(
        df_im_pos,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        show_col="show_site",
        ax=ax,
        linewidth=2,
        color=color_map.get(immunization, 'black')
    )

    # Apply transparency if immunization in faint list
    if immunization in faint_immunizations:
        ax.lines[-1].set_alpha(faint_alpha)

    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    # Print smoothed enrichment every 20th position
    print(f"\nSmoothed Enrichment values for immunization '{immunization}':")
    positions_to_print = range(df_im_pos['Spike_AS_Position'].min(),
                               df_im_pos['Spike_AS_Position'].max() + 1, 20)
    for pos in positions_to_print:
        val_row = df_im_pos[df_im_pos['Spike_AS_Position'] == pos]
        if not val_row.empty:
            val = val_row['Smoothed_Enrichment'].values[0]
            print(f"Position {pos}: {val:.3f}")
        else:
            print(f"Position {pos}: (no data)")

    # Highlight specific sites
    highlight_sites = df_im_pos[df_im_pos['show_site']]
    for _, site_data in highlight_sites.iterrows():
        ax.hlines(
            y=0.2,  # set near bottom of plot
            xmin=site_data['Spike_AS_Position'] - 0.5,
            xmax=site_data['Spike_AS_Position'] + 0.5,
            color='black',
            linestyle='-',
            linewidth=10
        )

# Set y-axis limits
ax.set_ylim(bottom=0.7, top=np.nanmax(df_im_pos['Smoothed_Enrichment']+1.1))

plt.title('', fontsize=14)
plt.xlabel('Spike AA Position', fontsize=16)
plt.ylabel('Polyreactivity\n Log2 AB Variant binding', fontsize=16)

from matplotlib.patches import Patch

# Create legend patch for epitopes
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')

handles, labels = ax.get_legend_handles_labels()
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

group_1_handles = [handles[labels.index(label)] for label in group_1_labels if label in labels]
group_2_handles = [handles[labels.index(label)] for label in group_2_labels if label in labels]

handles = group_1_handles + group_2_handles + [epitope_patch]

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
labels = [label_map[label] for label in group_1_labels + group_2_labels if label in label_map] + ['SARS-CoV-2 Spike antibody epitopes']

plt.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='upper left',
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# X-axis ticks setup
major_locator = MultipleLocator(10)
minor_locator = MultipleLocator(2)

ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{int(x)}"))
ax.xaxis.set_minor_locator(minor_locator)

ax.tick_params(axis='x', which='major', length=7, width=1.2)
ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')

ax.tick_params(axis='y', which='major', length=7, width=1.2)
ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')
ax.yaxis.set_major_locator(MaxNLocator(prune='lower', nbins=15))
ax.yaxis.set_minor_locator(MultipleLocator(0.1))

# Save plot
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_Line2_plot_final.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
%matplotlib inline
import os
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))


pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 15  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 1  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 1)]

# Aggregate enrichment ratio by position
df_filtered_agg = df_total[df_total['Amino_Acid'] != "*"].copy()
df_filtered_agg['Enrichment_Ratio'] = df_aa_agg['Enrichment_Ratio'].clip(lower=1e-3)

# Define target sites
sites_to_show = map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
)
sites_to_show = list(sites_to_show)

df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show)
)

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create a single figure
fig, ax = plt.subplots(figsize=(8, 6))

# Define color mapping
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Loop through each immunization and plot on the same axes, excluding 'Library_ctrl'
for immunization in df_filtered_agg['immunization'].unique():
    if immunization == 'Library_ctrl':
        continue

    if immunization == 'Neutralizing_Ab':
        continue
    faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
    faint_alpha = 0.3
    faint_linewidth = 1.5

    #if immunization == 'Polyclonal_Ab':  # Skip 'Library_ctrl'
     #   continue
    #if immunization == 'Neutralizing_Ab':  # Skip 'Library_ctrl'
     #   continue

    print(immunization)
    df_filtered_im = df_filtered_agg.query(f'immunization == "{immunization}"')

    # Normalize by number of barcodes
    num_barcodes = df_filtered_agg[df_filtered_agg['immunization'] == immunization]['barcode'].nunique()

    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'median'
    })
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].fillna(method='bfill').fillna(method='ffill')

    # Safe log transform (ignore or remove zero/negative values)
    df_filtered_im['Log2_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(
        lambda x: np.log2(x) if x > 0 else np.nan
    )
    
    # Apply smoothing on the log2 transformed values
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Log2_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both')

    # Identify high enrichment
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    if immunization == 'Neutralizing_Ab':
        df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].fillna(method='bfill').fillna(method='ffill')

    # Save CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(),
              df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()



    # Mark sites to highlight
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))

        # Rolling min and max for range area
    # Fixed (aligned min/max with smoothed mean)
    df_filtered_im['Smoothed_Enrichment_min'] = df_filtered_im['Smoothed_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).min()
    df_filtered_im['Smoothed_Enrichment_max'] = df_filtered_im['Smoothed_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).max()


    # Plot shaded area for range
    ax.fill_between(
        df_filtered_im['Spike_AS_Position'],
        df_filtered_im['Smoothed_Enrichment_min'],
        df_filtered_im['Smoothed_Enrichment_max'],
        color=color_map.get(immunization, 'black'),
        alpha=0.1
    )

    # Plot
    
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        show_col="show_site",
        ax=ax,
        linewidth=2,
        color=color_map.get(immunization, 'black')
    )
    if immunization in faint_immunizations:
        ax.lines[-1].set_alpha(faint_alpha)
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    # Highlight sites
    highlight_sites = df_filtered_im[df_filtered_im['show_site']]
    for _, site_data in highlight_sites.iterrows():
        ax.hlines(
            y=0.23,  # set near bottom
            xmin=site_data['Spike_AS_Position'] - 0.5,
            xmax=site_data['Spike_AS_Position'] + 0.5,
            color='black',
            linestyle='-',
            linewidth=10
        )

# Y-axis limit (adjust if needed for log scale)
ax.set_ylim(bottom=2.8,
            top=np.nanmax(df_filtered_im['Smoothed_Enrichment']) + 0.5)

plt.title('', fontsize=14)
plt.xlabel('Spike AA Position', fontsize=16)
plt.ylabel('Polyreactivity\n Log2 AB Variant binding', fontsize=16)

from matplotlib.patches import Patch

# Create a square patch for the legend entry
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')


handles, labels = ax.get_legend_handles_labels()
# Legend (subset only to certain labels)
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

group_1_handles = [handles[labels.index(label)] for label in group_1_labels]
group_2_handles = [handles[labels.index(label)] for label in group_2_labels]

# Combine handles and labels with a custom order
handles = group_1_handles + group_2_handles + [epitope_patch]

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
labels = [label_map[label] for label in group_1_labels + group_2_labels] + ['SARS-CoV-2 Spike antibody epitopes']

plt.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='upper left',
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# X-axis ticks
# Set x-ticks every 10 positions
# Major ticks every 10 positions
major_locator = MultipleLocator(10)
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{int(x)}"))

# Minor ticks every 2 positions (sub-ticks without labels)
minor_locator = MultipleLocator(2)
ax.xaxis.set_minor_locator(minor_locator)

# Enable minor ticks
ax.tick_params(axis='x', which='major', length=7, width=1.2)
ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')

ax.tick_params(axis='y', which='major', length=7, width=1.2)
ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

# Optionally rotate x labels
plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')
ax.yaxis.set_major_locator(MaxNLocator(prune='lower', nbins=15))  # no integer=True
ax.yaxis.set_minor_locator(MultipleLocator(0.1))  

# Save
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_Line2_plot_final.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
%matplotlib inline
import os
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))


pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 15  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 1  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 0)]

# Aggregate enrichment ratio by position
df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Define target sites
sites_to_show = map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
)
sites_to_show = list(sites_to_show)

barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

df_filtered_agg = df_filtered.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show)
)


# Aggregate filtered data by position and immunization before plotting
df_filtered_agg = df_filtered_agg.groupby(
    ['immunization', 'Spike_AS_Position'], as_index=False
).agg({
    'Enrichment_Ratio': 'median'
})

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create a single figure
fig, ax = plt.subplots(figsize=(8, 6))

# Define color mapping
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Loop through each immunization and plot on the same axes, excluding 'Library_ctrl'
for immunization in df_filtered_agg['immunization'].unique():
    immunization_data = {}
    if immunization == 'Library_ctrl':
        continue

    if immunization == 'Neutralizing_Ab':
        continue
    faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
    faint_alpha = 0.3
    faint_linewidth = 1.5

    #if immunization == 'Polyclonal_Ab':  # Skip 'Library_ctrl'
     #   continue
    #if immunization == 'Neutralizing_Ab':  # Skip 'Library_ctrl'
     #   continue

    print(immunization)
    df_filtered_im = df_filtered_agg.query(f'immunization == "{immunization}"')

    # Normalize by number of barcodes
    num_barcodes = barcode_counts.get(immunization, 0)
    #num_barcodes = df_filtered_agg[df_filtered_agg['immunization'] == immunization]['barcode'].nunique()

    print(f"Spike_AS_Position sample before aggregation:\n{df_filtered_im['Spike_AS_Position'].head()}")
    print(f"\nRows for immunization {immunization}: {df_filtered_im.shape[0]}")
    print(df_filtered_im[['Spike_AS_Position', 'Enrichment_Ratio']].head(10))
    
    #df_filtered_im_agg = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
    #'Enrichment_Ratio': 'median'
    #})
   
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].fillna(method='bfill').fillna(method='ffill')

    # Safe log transform (ignore or remove zero/negative values)
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(
        lambda x: np.log10(x) if x > 0 else np.nan
    )
    
    # Apply smoothing on the log2 transformed values
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Log10_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both')

    # Identify high enrichment
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    if immunization == 'Neutralizing_Ab':
        df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

        #df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].fillna(method='bfill').fillna(method='ffill')

    # Save CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(),
              df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()
    
    # Mark sites to highlight
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))

        # Rolling min and max for range area
    # Fixed (aligned min/max with smoothed mean)
    df_filtered_im['Smoothed_Enrichment_min'] = df_filtered_im['Smoothed_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).min()
    df_filtered_im['Smoothed_Enrichment_max'] = df_filtered_im['Smoothed_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).max()


    # Plot shaded area for range
    ax.fill_between(
        df_filtered_im['Spike_AS_Position'],
        df_filtered_im['Smoothed_Enrichment_min'],
        df_filtered_im['Smoothed_Enrichment_max'],
        color=color_map.get(immunization, 'black'),
        alpha=0.1
    )

    # Plot
    
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        show_col="show_site",
        ax=ax,
        linewidth=2,
        color=color_map.get(immunization, 'black')
    )
    if immunization in faint_immunizations:
        ax.lines[-1].set_alpha(faint_alpha)
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    # Highlight sites
    #highlight_sites = df_filtered_im[df_filtered_im['show_site']]
    #for _, site_data in highlight_sites.iterrows():
     #   ax.hlines(
     #       y=0,  # set near bottom
     #       xmin=site_data['Spike_AS_Position'] - 0.1,
     #       xmax=site_data['Spike_AS_Position'] + 0.1,
     #       color='black',
     #       linestyle='-',
     #       linewidth=6
     #   )
    print(f"\n[INFO] Immunization: {immunization}")
    print(f"→ Positions: {df_filtered_im['Spike_AS_Position'].min()} to {df_filtered_im['Spike_AS_Position'].max()}")
    print(f"→ Mean Enrichment: {df_filtered_im['Enrichment_Ratio'].mean():.3f}")
    print(f"→ Mean Log10 Enrichment: {df_filtered_im['Log10_Enrichment'].mean():.3f}")
    print(f"→ Mean Smoothed Log10 Enrichment: {df_filtered_im['Smoothed_Enrichment'].mean():.3f}")

    if immunization == 'wildtype_RBD':
        print("\n[DEBUG] Wildtype enrichment values (360–540):")
        print(df_filtered_im[df_filtered_im['Spike_AS_Position'].between(360, 540)][
            ['Spike_AS_Position', 'Enrichment_Ratio', 'Log10_Enrichment', 'Smoothed_Enrichment']
        ])


    immunization_data[immunization] = df_filtered_im.copy()
    
# --- Axis limits and labels ---------------------------------------------------
# NOTE(review): the upper y-limit comes from `df_filtered_im`, i.e. only the frame
# left over from the last loop iteration, not the maximum over all plotted
# immunizations -- confirm that curve is always the tallest one.
ax.set_ylim(bottom=-0.15,
            top=np.nanmax(df_filtered_im['Smoothed_Enrichment']) + 0.1)

plt.title('', fontsize=14)
plt.xlabel('Spike AA Position', fontsize=16)
plt.ylabel('Log10 Antibody binding (Mean)', fontsize=18)

from matplotlib.patches import Patch

# Square black patch used as the legend entry for the epitope annotation.
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')

# --- Legend: a fixed subset of immunizations, in a fixed order -----------------
handles, labels = ax.get_legend_handles_labels()
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

# labels.index() raises ValueError if one of these immunizations was not plotted.
group_1_handles = [handles[labels.index(label)] for label in group_1_labels]
group_2_handles = [handles[labels.index(label)] for label in group_2_labels]

handles = group_1_handles + group_2_handles + [epitope_patch]

# Display names for the legend entries
label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
labels = [label_map[label] for label in group_1_labels + group_2_labels] + ['SARS-CoV-2 Spike antibody epitopes']

plt.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='upper left',
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# --- X-axis: labelled major ticks every 10 positions, unlabelled minor every 2 --
major_locator = MultipleLocator(10)
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{int(x)}"))

minor_locator = MultipleLocator(2)
ax.xaxis.set_minor_locator(minor_locator)

ax.tick_params(axis='x', which='major', length=7, width=1.2)
ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')
ax.tick_params(axis='y', which='major', length=7, width=1.2)
ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')

# --- Y-axis: major (labelled) and minor ticks every 0.05 ------------------------
# Labels need two decimals: with one decimal, neighbouring ticks such as 0.05
# and 0.10 would both print as "0.1".
ax.yaxis.set_major_locator(MultipleLocator(0.05))
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.2f}"))
ax.yaxis.set_minor_locator(MultipleLocator(0.05))

# --- Save ---------------------------------------------------------------------
# The file-name prefix is whichever immunization the plotting loop visited last.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_Line2_plot_final_log10.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

plt.show()


### Calculation and plotting with Median

In [None]:
#Plots for publication

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
%matplotlib inline
# MultipleLocator is used by the tick formatting at the end of this cell but is
# not among this cell's imports above; import it so the cell runs on a fresh kernel.
from matplotlib.ticker import MultipleLocator

print("Current working dir:", os.getcwd())
# `output_dir` is only (re)assigned further down in this cell, so at this point it
# exists only if an earlier cell defined it -- list its contents only when available.
if 'output_dir' in globals() and os.path.isdir(output_dir):
    print("Files in output_dir:", os.listdir(output_dir))

def symmetric_log10(x, epsilon=1e-8, max_cap=1e8):
    """Signed log10 with the magnitude clamped to [epsilon, max_cap].

    Computes ``sign(x) * log10(clip(|x|, epsilon, max_cap))``. Zero maps to
    zero; magnitudes outside [epsilon, max_cap] are clamped before the log, so
    the result is always finite for finite input. Works on scalars and numpy
    arrays.

    Note: this is not a "symlog" -- for 0 < |x| < 1 the log term is negative,
    so e.g. 0.1 -> -1 and -0.1 -> +1.
    """
    magnitude = np.clip(np.abs(x), epsilon, max_cap)
    direction = np.sign(x)
    return direction * np.log10(magnitude)



pd.set_option('future.no_silent_downcasting', True)

# Smoothing window (number of consecutive rows/positions) and the threshold used
# to flag high smoothed log10 enrichment in the loop below.
ROLLING_WINDOW = 15  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 1  # Adjust based on data distribution

# Keep Spike positions above 364 (= 33 + 331) with a strictly positive
# enrichment ratio, then drop stop-codon ("*") variants.
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 0)]
df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Sites to highlight, stored as strings (they are compared against str-cast
# positions): RBD-ACE2 interface residues, R21 peptide (394-413) and
# R13 peptide (484-504).
sites_to_show = [
    str(site)
    for site in (
        [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]
        + list(range(394, 414))
        + list(range(484, 505))
    )
]

# Number of distinct barcodes per immunization (not applied as a normalisation
# in this cell).
barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

# Median enrichment ratio per (immunization, position). No helper columns are
# added here: the aggregation would drop them, and `show_site` is computed per
# immunization after reindexing in the plotting loop.
df_filtered_agg = df_filtered.groupby(
    ['immunization', 'Spike_AS_Position'], as_index=False
).agg({
    'Enrichment_Ratio': 'median'
})

# Output directory for the per-immunization CSV files
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Single figure shared by all immunizations
fig, ax = plt.subplots(figsize=(8, 6))

# Line / band colour per immunization
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Immunizations drawn with reduced alpha, and that alpha. Defined once, before the
# loop, because they do not depend on the loop variable.
faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
faint_alpha = 0.3
faint_linewidth = 1.5  # currently not applied

# Per-immunization plotting frames. Created once, outside the loop, so that every
# plotted immunization is kept rather than only the last one.
immunization_data = {}

# Plot each immunization on the shared axes, skipping the library control and
# the neutralizing antibody.
for immunization in df_filtered_agg['immunization'].unique():
    if immunization in ('Library_ctrl', 'Neutralizing_Ab'):
        continue

    print(immunization)
    # Explicit copy: columns are added below, which on a filtered view would
    # raise SettingWithCopyWarning.
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization].copy()

    # Distinct barcode count for this immunization (not applied as a normalisation).
    num_barcodes = barcode_counts.get(immunization, 0)

    print(f"Spike_AS_Position sample before aggregation:\n{df_filtered_im['Spike_AS_Position'].head()}")
    print(f"\nRows for immunization {immunization}: {df_filtered_im.shape[0]}")
    print(df_filtered_im[['Spike_AS_Position', 'Enrichment_Ratio']].head(10))

    # Fill any missing medians from neighbouring rows (bfill/ffill replace the
    # deprecated fillna(method=...)).
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # Log10 transform (inputs are > 0 after filtering).
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(symmetric_log10)

    # Centred rolling mean of the log10 values over ROLLING_WINDOW rows (one row
    # per observed position; gaps are not filled yet), then linear interpolation
    # of any remaining NaNs.
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Log10_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both')

    # Flag positions whose smoothed log10 enrichment exceeds the threshold.
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    # Save the per-position table (before reindexing to the full position range).
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex to every integer position between min and max; positions without
    # data become NaN rows.
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(),
              df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()

    # Gap rows are not high-enrichment. Sites are matched as strings because
    # sites_to_show holds strings.
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))

    # Shaded band: rolling mean +/- rolling std of the log10 values on the
    # reindexed frame (NaN gap rows are skipped inside each window).
    # NOTE(review): this Rolling_Mean is computed after reindexing, whereas the
    # plotted line (Smoothed_Enrichment) was computed before it, so band and line
    # can be offset near gaps -- confirm this is intended.
    df_filtered_im['Rolling_Mean'] = df_filtered_im['Log10_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_filtered_im['Rolling_Std'] = df_filtered_im['Log10_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).std()

    lower_bound = df_filtered_im['Rolling_Mean'] - df_filtered_im['Rolling_Std']
    upper_bound = df_filtered_im['Rolling_Mean'] + df_filtered_im['Rolling_Std']

    ax.fill_between(
        df_filtered_im['Spike_AS_Position'],
        lower_bound,
        upper_bound,
        color=color_map.get(immunization, 'black'),
        alpha=0.1
    )

    # Smoothed line (site highlighting is off: show_col is not passed).
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        ax=ax,
        linewidth=2,
        color=color_map.get(immunization, 'black')
    )
    # Assumes draw_line's line is the most recently added Line2D on `ax`.
    if immunization in faint_immunizations:
        ax.lines[-1].set_alpha(faint_alpha)
    # Empty proxy line that carries the legend label.
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    print(f"\n[INFO] Immunization: {immunization}")
    print(f"→ Positions: {df_filtered_im['Spike_AS_Position'].min()} to {df_filtered_im['Spike_AS_Position'].max()}")
    print(f"→ Mean Enrichment: {df_filtered_im['Enrichment_Ratio'].mean():.3f}")
    print(f"→ Mean Log10 Enrichment: {df_filtered_im['Log10_Enrichment'].mean():.3f}")
    print(f"→ Mean Smoothed Log10 Enrichment: {df_filtered_im['Smoothed_Enrichment'].mean():.3f}")

    if immunization == 'wildtype_RBD':
        print("\n[DEBUG] Wildtype enrichment values (360–540):")
        print(df_filtered_im[df_filtered_im['Spike_AS_Position'].between(360, 540)][
            ['Spike_AS_Position', 'Enrichment_Ratio', 'Log10_Enrichment', 'Smoothed_Enrichment']
        ])

    immunization_data[immunization] = df_filtered_im.copy()
    
# Fixed y-range for the median-based log10 curves (adjust if the data change).
ax.set_ylim(bottom=-0.3,
            top=0.1)

plt.title('', fontsize=14)
plt.xlabel('Spike AA Position', fontsize=16)
plt.ylabel('Log10 Antibody binding (Median)', fontsize=18)

from matplotlib.patches import Patch

# Square black patch used as the legend entry for the epitope annotation.
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')

# Legend handles/labels for a fixed subset of immunizations, in a fixed order.
# labels.index() raises ValueError if one of them was not plotted.
handles, labels = ax.get_legend_handles_labels()
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

group_1_handles = [handles[labels.index(label)] for label in group_1_labels]
group_2_handles = [handles[labels.index(label)] for label in group_2_labels]

handles = group_1_handles + group_2_handles + [epitope_patch]

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
labels = [label_map[label] for label in group_1_labels + group_2_labels] + ['SARS-CoV-2 Spike antibody epitopes']

# Legend drawing is currently disabled; uncomment to show it.
#plt.legend(
#    handles, labels,
#    title="Immunization",
#    title_fontproperties=font_manager.FontProperties(weight='bold'),
#    loc='lower left',
#    fontsize=11,
#    frameon=False,
#    handlelength=2,
#    handleheight=1,
#    markerscale=1
#)

# X-axis: labelled major ticks every 10 positions, unlabelled minor ticks every
# 2; the x-range is cut at position 510.
major_locator = MultipleLocator(10)
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{int(x)}"))

minor_locator = MultipleLocator(2)
ax.xaxis.set_minor_locator(minor_locator)
ax.set_xlim(right=510)

ax.tick_params(axis='x', which='major', length=7, width=1.2)
ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')
ax.tick_params(axis='y', which='major', length=7, width=1.2)
ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')

# Y-axis: labelled major ticks every 0.05, minor every 0.01. Two decimals are
# needed -- with one decimal, 0.05 and 0.10 (or -0.05 and -0.10) print identically.
ax.yaxis.set_major_locator(MultipleLocator(0.05))
ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.2f}"))
ax.yaxis.set_minor_locator(MultipleLocator(0.01))

# Save under a median-specific name: the mean-based plot cell writes
# "<immunization>_Line2_plot_final_log10.png" to the same folder and would
# otherwise be overwritten. The prefix is the immunization the loop visited last.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_Line2_plot_final_log10_median.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

plt.show()


In [None]:
#DMSpub Split by axis

In [None]:
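# Variance (shaded area): ±1 rolling standard deviation of the per-position median log10 enrichment, computed separately for the positive (top) and negative (bottom) panels over a centred window of 10 positions and drawn around the smoothed line.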

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, FuncFormatter
from matplotlib import gridspec, font_manager
from matplotlib.patches import Patch
%matplotlib inline

# -----------------------------
# Utility function
# -----------------------------
def symmetric_log10(x, epsilon=1e-8, max_cap=1e8):
    """Signed log10 with the magnitude clamped to [epsilon, max_cap].

    Computes ``sign(x) * log10(clip(|x|, epsilon, max_cap))``. Zero maps to
    zero; magnitudes outside [epsilon, max_cap] are clamped before the log, so
    the result is always finite for finite input. Works on scalars and numpy
    arrays.

    Note: this is not a "symlog" -- for 0 < |x| < 1 the log term is negative,
    so e.g. 0.1 -> -1 and -0.1 -> +1.
    """
    magnitude = np.clip(np.abs(x), epsilon, max_cap)
    direction = np.sign(x)
    return direction * np.log10(magnitude)

# -----------------------------
# Parameters
# -----------------------------
ROLLING_WINDOW = 10        # rows per centred rolling window
ENRICHMENT_THRESHOLD = 0   # not read in this cell
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)  # no-op if the directory already exists

# -----------------------------
# Filter dataset
# -----------------------------
# Positions above 364 (= 33 + 331), strictly positive enrichment, and no
# stop-codon ("*") variants -- combined into a single boolean mask.
df_filtered = df_total.loc[
    (df_total['Spike_AS_Position'] > 33 + 331)
    & (df_total['Enrichment_Ratio'] > 0)
    & (df_total['Amino_Acid'] != "*")
].copy()

# Highlight sites as strings: RBD-ACE2 interface, R21 peptide (394-413),
# R13 peptide (484-504).
sites_to_show = [str(site) for site in
                 [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]
                 + list(range(394, 414))
                 + list(range(484, 505))]

# Distinct barcodes per immunization
barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

# Median enrichment ratio per (immunization, position)
df_filtered_agg = (
    df_filtered
    .groupby(['immunization', 'Spike_AS_Position'], as_index=False)
    .agg({'Enrichment_Ratio': 'median'})
)

# -----------------------------
# Color map
# -----------------------------
color_map = dict(
    Polyclonal_Ab='darkorange',
    Neutralizing_Ab='red',
    wildtype_RBD='green',
    Mutant_RBD='darkblue',
)
faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
faint_alpha = 0.3

# -----------------------------
# Create figure with split axes: top panel for positive, bottom for negative
# log10 enrichment, sharing one x-axis.
# -----------------------------
fig = plt.figure(figsize=(10, 6))
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1], hspace=0.15)
ax_top = fig.add_subplot(gs[0])
ax_bottom = fig.add_subplot(gs[1], sharex=ax_top)
plt.setp(ax_top.get_xticklabels(), visible=False)

# Running extremes of the smoothed curves across immunizations
top_y_max_global, bottom_y_min_global = -np.inf, np.inf

# -----------------------------
# Loop over immunizations
# -----------------------------
for immunization in df_filtered_agg['immunization'].unique():
    if immunization in ['Library_ctrl', 'Neutralizing_Ab']:
        continue

    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization].copy()

    # Fill gaps in the per-position median, then log10-transform
    # (bfill/ffill avoid the deprecated fillna(method=...)).
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(symmetric_log10)

    # Split positive and negative log10 values BEFORE smoothing/interpolation so
    # each panel is smoothed only from values of its own sign.
    df_top = df_filtered_im[df_filtered_im['Log10_Enrichment'] > 0].copy()
    df_bottom = df_filtered_im[df_filtered_im['Log10_Enrichment'] < 0].copy()
    split_dfs = {'top': df_top, 'bottom': df_bottom}

    # Every integer position between this immunization's min and max position
    full_range = range(df_filtered_im['Spike_AS_Position'].min(),
                       df_filtered_im['Spike_AS_Position'].max() + 1)
    for key in split_dfs:
        df = split_dfs[key]
        if df.empty:
            continue

        # Reindex to the full range: positions of the other sign (or absent from
        # the data) become NaN rows.
        df = df.set_index('Spike_AS_Position').reindex(full_range)

        # Ensure numeric before rolling/interpolation (the non-numeric
        # 'immunization' column is coerced to NaN).
        df = df.apply(pd.to_numeric, errors='coerce')

        # Centred rolling mean of the log10 values (NaN rows are skipped), then
        # linear interpolation over windows that contained no values.
        df['Smoothed_Enrichment'] = df['Log10_Enrichment'].rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1
        ).mean().interpolate(method='linear', limit_direction='both')

        df = df.dropna(subset=['Smoothed_Enrichment'])

        # Rolling standard deviation of the unsmoothed log10 values for the
        # +/-1 SD band; windows with fewer than 2 values give NaN, drawn as 0.
        df['Rolling_Std'] = df['Log10_Enrichment'].rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1
        ).std().fillna(0)

        df = df.reset_index()
        split_dfs[key] = df

    # --- Track extreme smoothed values across immunizations (informational; the
    # y-limits set after the loop are fixed values) ---
    if not split_dfs['top'].empty:
        top_y_max_global = max(top_y_max_global, split_dfs['top']['Smoothed_Enrichment'].max())
    if not split_dfs['bottom'].empty:
        bottom_y_min_global = min(bottom_y_min_global, split_dfs['bottom']['Smoothed_Enrichment'].min())

    # Save the unsplit per-position table.
    # NOTE(review): other cells in this notebook write to the same
    # "<immunization>_data.csv" path with different columns; the last cell run wins.
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # -----------------------------
    # Plot split
    # -----------------------------
    # One colour for band, line and legend proxy, with the same fallback for all
    # three (an unmapped immunization previously got an invisible line but a
    # black band and legend entry).
    line_color = color_map.get(immunization, 'black')
    for key, df in split_dfs.items():
        if df.empty:
            continue

        ax_current = ax_top if key == 'top' else ax_bottom

        # Shaded region: smoothed enrichment +/- 1 rolling SD
        ax_current.fill_between(
            df['Spike_AS_Position'],
            df['Smoothed_Enrichment'] - df['Rolling_Std'],
            df['Smoothed_Enrichment'] + df['Rolling_Std'],
            color=line_color,
            alpha=0.1
        )

        # Smoothed line
        dmslogo.line.draw_line(
            df,
            x_col="Spike_AS_Position",
            height_col="Smoothed_Enrichment",
            ax=ax_current,
            linewidth=2,
            color=line_color
        )

        # Assumes draw_line's line is the most recently added Line2D on the axes.
        if immunization in faint_immunizations:
            ax_current.lines[-1].set_alpha(faint_alpha)

        # Empty proxy line that carries the legend label
        ax_current.plot([], [], color=line_color, label=immunization)

# -----------------------------
# Axis formatting
# -----------------------------
ax_top.set_ylabel(" ", fontsize=12)
ax_bottom.set_ylabel(" ", fontsize=12)
ax_bottom.set_xlabel("Spike AA Position", fontsize=12)
ax_bottom.axhline(y=0, color='black', linewidth=2)

# Fixed y-ranges for the positive (top) and negative (bottom) panels
ax_top.set_ylim(bottom=0, top=0.08)
ax_bottom.set_ylim(bottom=-0.28, top=0)
ax_bottom.set_xlim(right=510)

# X-axis ticks (shared by both panels): major every 10, minor every 2
ax_bottom.xaxis.set_major_locator(MultipleLocator(10))
ax_bottom.xaxis.set_minor_locator(MultipleLocator(2))
plt.setp(ax_bottom.xaxis.get_majorticklabels(), rotation=0, ha='center')
# Hide x tick labels on the top panel via tick_params, which also applies to
# ticks created after the locator changes.
ax_top.tick_params(axis='x', labelbottom=False)

# Y-axis ticks with separate spacing per panel
ax_top.yaxis.set_major_locator(MultipleLocator(0.05))   # Top axis major ticks
ax_top.yaxis.set_minor_locator(MultipleLocator(0.01))   # Top axis minor ticks

ax_bottom.yaxis.set_major_locator(MultipleLocator(0.1)) # Bottom axis major ticks
ax_bottom.yaxis.set_minor_locator(MultipleLocator(0.05)) # Bottom axis minor ticks

# Tick parameters
ax_top.tick_params(axis='y', which='major', length=7, width=1.2)
ax_top.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

ax_bottom.tick_params(axis='y', which='major', length=7, width=1.2)
ax_bottom.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

# Bottom-panel y labels: blank at 0 (so zero is not labelled twice where the
# panels meet), otherwise rounded to 2 decimals. A FuncFormatter derives each
# label from its tick value; set_yticklabels() would pin a fixed list to tick
# positions by index, which matplotlib warns about without a FixedLocator.
ax_bottom.yaxis.set_major_formatter(
    FuncFormatter(lambda t, _: "" if abs(t) < 1e-9 else str(round(t, 2)))
)


# -----------------------------
# Legend
# -----------------------------
# Square black patch for the epitope legend entry. Defined in this cell so the
# code below does not depend on a variable left over from an earlier cell
# (NameError on a fresh kernel).
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')
handles, labels = ax_top.get_legend_handles_labels()
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']
# labels.index() raises ValueError if an immunization has no entry on the top panel.
group_1_handles = [handles[labels.index(label)] for label in group_1_labels]
group_2_handles = [handles[labels.index(label)] for label in group_2_labels]
handles = group_1_handles + group_2_handles + [epitope_patch]
label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
labels = [label_map[label] for label in group_1_labels + group_2_labels] + ['SARS-CoV-2 Spike antibody epitopes']
# Legend drawing is currently disabled; uncomment to show it.
#ax_top.legend(handles=handles, labels=labels, fontsize=10, frameon=False)

# -----------------------------
# Save figure
# -----------------------------
plot_file_path = os.path.join("/Users/lucaschlotheuber/Desktop", "split_top_bottom_plot.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')
plt.show()


In [None]:
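# Variance (shaded area): 95% Student-t confidence interval of the log10 enrichment at each position, computed across all raw measurements of the immunization, smoothed with a centred rolling median over ROLLING_WINDOW positions, and drawn around the smoothed line.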

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, FuncFormatter
from matplotlib import gridspec, font_manager
from matplotlib.patches import Patch
%matplotlib inline

# -----------------------------
# Utility function
# -----------------------------
def symmetric_log10(x, epsilon=1e-8, max_cap=1e8):
    """Signed log10 with the magnitude clamped to [epsilon, max_cap].

    Computes ``sign(x) * log10(clip(|x|, epsilon, max_cap))``. Zero maps to
    zero; magnitudes outside [epsilon, max_cap] are clamped before the log, so
    the result is always finite for finite input. Works on scalars and numpy
    arrays.

    Note: this is not a "symlog" -- for 0 < |x| < 1 the log term is negative,
    so e.g. 0.1 -> -1 and -0.1 -> +1.
    """
    magnitude = np.clip(np.abs(x), epsilon, max_cap)
    direction = np.sign(x)
    return direction * np.log10(magnitude)

# -----------------------------
# Parameters
# -----------------------------
ROLLING_WINDOW = 10        # rows per centred rolling window
ENRICHMENT_THRESHOLD = 0   # not read in this cell
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)  # no-op if the directory already exists

# -----------------------------
# Filter dataset
# -----------------------------
# Positions above 364 (= 33 + 331), strictly positive enrichment, and no
# stop-codon ("*") variants -- combined into a single boolean mask.
df_filtered = df_total.loc[
    (df_total['Spike_AS_Position'] > 33 + 331)
    & (df_total['Enrichment_Ratio'] > 0)
    & (df_total['Amino_Acid'] != "*")
].copy()

# Highlight sites as strings: RBD-ACE2 interface, R21 peptide (394-413),
# R13 peptide (484-504).
sites_to_show = [str(site) for site in
                 [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]
                 + list(range(394, 414))
                 + list(range(484, 505))]

# Distinct barcodes per immunization
barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

# Median enrichment ratio per (immunization, position)
df_filtered_agg = (
    df_filtered
    .groupby(['immunization', 'Spike_AS_Position'], as_index=False)
    .agg({'Enrichment_Ratio': 'median'})
)

# -----------------------------
# Color map
# -----------------------------
color_map = dict(
    Polyclonal_Ab='darkorange',
    Neutralizing_Ab='red',
    wildtype_RBD='green',
    Mutant_RBD='darkblue',
)
faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
faint_alpha = 0.6

# -----------------------------
# Create figure with split axes: top panel for positive, bottom for negative
# log10 enrichment, sharing one x-axis.
# -----------------------------
fig = plt.figure(figsize=(10, 6))
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1], hspace=0.15)
ax_top = fig.add_subplot(gs[0])
ax_bottom = fig.add_subplot(gs[1], sharex=ax_top)
plt.setp(ax_top.get_xticklabels(), visible=False)

# Running extremes of the smoothed curves across immunizations
top_y_max_global, bottom_y_min_global = -np.inf, np.inf

# Student-t quantiles for the confidence interval below. scipy.stats is not among
# this cell's imports, so without this line the cell only ran if an earlier cell
# happened to import it.
from scipy import stats

# -----------------------------
# Loop over immunizations
# -----------------------------
for immunization in df_filtered_agg['immunization'].unique():
    if immunization in ['Library_ctrl', 'Neutralizing_Ab']:
        continue

    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization].copy()

    # Fill gaps in the per-position median, then log10-transform
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(symmetric_log10)

    # Split positive and negative BEFORE smoothing/interpolation
    df_top = df_filtered_im[df_filtered_im['Log10_Enrichment'] > 0].copy()
    df_bottom = df_filtered_im[df_filtered_im['Log10_Enrichment'] < 0].copy()
    split_dfs = {'top': df_top, 'bottom': df_bottom}

    # 95% CI half-width per position: t(0.975, n-1) * SEM of the log10 enrichment
    # over all raw rows of this immunization at that position (dof floored at 1;
    # a position with a single row has no std and yields NaN here).
    # NOTE(review): the interval uses raw rows of both signs, while each panel's
    # line uses only positions of its own sign -- confirm this is intended.
    df_im_raw = df_filtered[df_filtered['immunization'] == immunization].copy()
    df_im_raw['Log10_Enrichment'] = df_im_raw['Enrichment_Ratio'].apply(symmetric_log10)
    grouped = df_im_raw.groupby('Spike_AS_Position')['Log10_Enrichment']
    std_per_pos = grouped.std()
    count_per_pos = grouped.count()
    sem_per_pos = std_per_pos / np.sqrt(count_per_pos)
    dof = (count_per_pos - 1).clip(lower=1)
    t_critical = pd.Series(stats.t.ppf(0.975, dof), index=dof.index)
    margin_of_error = t_critical * sem_per_pos

    # Rolling median of the half-width over ROLLING_WINDOW observed positions;
    # remaining NaNs are filled from neighbours (bfill/ffill replace the
    # deprecated fillna(method=...)).
    smoothed_margin = margin_of_error.rolling(window=ROLLING_WINDOW, center=True, min_periods=1).median()
    smoothed_margin = smoothed_margin.bfill().ffill()

    # Reindex split DataFrames to the full Spike_AS_Position range
    full_range = range(df_filtered_im['Spike_AS_Position'].min(),
                       df_filtered_im['Spike_AS_Position'].max() + 1)
    for key in split_dfs:
        df = split_dfs[key]
        if df.empty:
            continue

        # Positions of the other sign (or absent) become NaN rows; the
        # non-numeric 'immunization' column is coerced to NaN.
        df = df.set_index('Spike_AS_Position').reindex(full_range)
        df = df.apply(pd.to_numeric, errors='coerce')

        # Centred rolling mean (NaN rows skipped), then interpolate empty windows
        df['Smoothed_Enrichment'] = df['Log10_Enrichment'].rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1
        ).mean().interpolate(method='linear', limit_direction='both')

        # CI half-width aligned to the full position range
        df['Smoothed_CI_Margin'] = smoothed_margin.reindex(df.index).bfill().ffill()

        df = df.dropna(subset=['Smoothed_Enrichment'])
        df = df.reset_index()
        split_dfs[key] = df

    # --- Track extreme smoothed values across immunizations (informational; the
    # y-limits set after the loop are fixed values) ---
    if not split_dfs['top'].empty:
        top_y_max_global = max(top_y_max_global, split_dfs['top']['Smoothed_Enrichment'].max())
    if not split_dfs['bottom'].empty:
        bottom_y_min_global = min(bottom_y_min_global, split_dfs['bottom']['Smoothed_Enrichment'].min())

    # Save the unsplit per-position table.
    # NOTE(review): other cells write to the same "<immunization>_data.csv" path
    # with different columns; the last cell run wins.
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # -----------------------------
    # Plot split
    # -----------------------------
    # One colour for band, line and legend proxy, with the same fallback for all
    # three (an unmapped immunization previously got an invisible line but a
    # black band and legend entry).
    line_color = color_map.get(immunization, 'black')
    for key, df in split_dfs.items():
        if df.empty:
            continue

        ax_current = ax_top if key == 'top' else ax_bottom

        # --- 95% CI shading around the smoothed line ---
        ax_current.fill_between(
            df['Spike_AS_Position'],
            df['Smoothed_Enrichment'] - df['Smoothed_CI_Margin'],
            df['Smoothed_Enrichment'] + df['Smoothed_CI_Margin'],
            color=line_color,
            alpha=0.1
        )

        # Smoothed line
        dmslogo.line.draw_line(
            df,
            x_col="Spike_AS_Position",
            height_col="Smoothed_Enrichment",
            ax=ax_current,
            linewidth=2,
            color=line_color
        )

        # Assumes draw_line's line is the most recently added Line2D on the axes.
        if immunization in faint_immunizations:
            ax_current.lines[-1].set_alpha(faint_alpha)

        # Empty proxy line that carries the legend label
        ax_current.plot([], [], color=line_color, label=immunization)


# -----------------------------
# Axis formatting
# -----------------------------
ax_top.set_ylabel(" ", fontsize=12)
ax_bottom.set_ylabel(" ", fontsize=12)
ax_top.set_xlabel("")
ax_bottom.set_xlabel("Spike AA Position", fontsize=15)
ax_bottom.axhline(y=0, color='black', linewidth=2)

# Fixed y-ranges for the positive (top) and negative (bottom) panels
ax_top.set_ylim(bottom=0, top=0.15)
ax_bottom.set_ylim(bottom=-0.3, top=0)
ax_bottom.set_xlim(right=510)

# X-axis (shared by both panels): labelled major ticks every 10 positions,
# minor ticks every 5. A FuncFormatter derives each label from its tick value;
# set_xticklabels() would freeze a label list by tick index, which matplotlib
# warns about without a FixedLocator.
ax_bottom.xaxis.set_major_locator(MultipleLocator(10))
ax_bottom.xaxis.set_major_formatter(FuncFormatter(lambda x, _: str(int(x))))
ax_bottom.xaxis.set_minor_locator(MultipleLocator(5))
ax_bottom.tick_params(axis='x', labelsize=10, labelrotation=0)

# X tick labels only on the bottom panel; tick marks on both
ax_top.tick_params(axis='x', labelbottom=False, length=7)
ax_bottom.tick_params(axis='x', labelbottom=True, length=7)

# Y-axis ticks with separate spacing per panel
ax_top.yaxis.set_major_locator(MultipleLocator(0.05))   # Top axis major ticks
ax_top.yaxis.set_minor_locator(MultipleLocator(0.01))   # Top axis minor ticks

ax_bottom.yaxis.set_major_locator(MultipleLocator(0.05)) # Bottom axis major ticks
ax_bottom.yaxis.set_minor_locator(MultipleLocator(0.01)) # Bottom axis minor ticks

# Tick parameters
ax_top.tick_params(axis='y', which='major', length=7, width=1.2)
ax_top.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

ax_bottom.tick_params(axis='y', which='major', length=7, width=1.2)
ax_bottom.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

# Bottom-panel y labels: blank at 0 (so zero is not labelled twice where the
# panels meet), otherwise rounded to 2 decimals -- computed per tick value.
ax_bottom.yaxis.set_major_formatter(
    FuncFormatter(lambda t, _: "" if abs(t) < 1e-9 else str(round(t, 2)))
)

fig.canvas.draw()

# -----------------------------
# Legend
# -----------------------------
# Square black patch for the epitope legend entry. Defined in this cell so the
# code below does not depend on a variable left over from an earlier cell
# (NameError on a fresh kernel).
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')
handles, labels = ax_top.get_legend_handles_labels()
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']
# labels.index() raises ValueError if an immunization has no entry on the top panel.
group_1_handles = [handles[labels.index(label)] for label in group_1_labels]
group_2_handles = [handles[labels.index(label)] for label in group_2_labels]
handles = group_1_handles + group_2_handles + [epitope_patch]
label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
labels = [label_map[label] for label in group_1_labels + group_2_labels] + ['SARS-CoV-2 Spike antibody epitopes']
# Legend drawing is currently disabled; uncomment to show it.
#ax_top.legend(handles=handles, labels=labels, fontsize=10, frameon=False)

# -----------------------------
# Save figure
# -----------------------------
plot_file_path = os.path.join("/Users/lucaschlotheuber/Desktop", "AllABsplit_top_bottom_plot.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')
plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, FuncFormatter
from matplotlib import gridspec, font_manager
from matplotlib.patches import Patch
%matplotlib inline

# -----------------------------
# Utility function
# -----------------------------
def symmetric_log10(x, epsilon=1e-8, max_cap=1e8):
    """Signed log10 of the magnitude of ``x``.

    ``|x|`` is clipped to ``[epsilon, max_cap]`` before taking log10, and the
    result is multiplied by ``sign(x)``. Zero therefore maps to (signed) zero and
    NaN propagates as NaN. Works element-wise on scalars and numpy arrays.
    """
    magnitude = np.abs(x)
    bounded = np.clip(magnitude, epsilon, max_cap)
    return np.log10(bounded) * np.sign(x)

# -----------------------------
# Parameters
# -----------------------------
ROLLING_WINDOW = 15          # size of the centred smoothing window
ENRICHMENT_THRESHOLD = 1
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# -----------------------------
# Filter dataset
# -----------------------------
# Keep Spike positions above 33 + 331 with a positive enrichment ratio, and drop stop codons.
position_ok = df_total['Spike_AS_Position'] > 33 + 331
enrichment_ok = df_total['Enrichment_Ratio'] > 0
not_stop_codon = df_total['Amino_Acid'] != "*"
df_filtered = df_total.loc[position_ok & enrichment_ok & not_stop_codon].copy()

# Target sites, stored as strings
sites_to_show = [
    str(position)
    for position in [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]
    + list(range(394, 414))
    + list(range(484, 505))
]

# Distinct barcodes per immunization
barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

# Median enrichment ratio per (immunization, position)
df_filtered_agg = (
    df_filtered
    .groupby(['immunization', 'Spike_AS_Position'], as_index=False)
    .agg({'Enrichment_Ratio': 'median'})
)

# -----------------------------
# Color map
# -----------------------------
color_map = dict(
    Polyclonal_Ab='darkorange',
    Neutralizing_Ab='red',
    wildtype_RBD='green',
    Mutant_RBD='darkblue',
)
# Immunizations whose smoothed line is drawn semi-transparent
faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
faint_alpha = 0.6

# -----------------------------
# Create figure with split axes (positive values on top, negative below)
# -----------------------------
fig = plt.figure(figsize=(10, 6))
gs = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[1, 1], hspace=0.15)
ax_top = fig.add_subplot(gs[0])
ax_bottom = fig.add_subplot(gs[1], sharex=ax_top)
plt.setp(ax_top.get_xticklabels(), visible=False)

# Running extremes of the smoothed curves across immunizations
top_y_max_global, bottom_y_min_global = -np.inf, np.inf

# -----------------------------
# Loop over immunizations
# -----------------------------
# scipy.stats supplies the t critical value below. Import it here so the cell
# does not depend on an import made elsewhere in the notebook.
from scipy import stats

for immunization in df_filtered_agg['immunization'].unique():
    # The library control and the neutralizing Ab are not shown in this figure.
    if immunization in ['Library_ctrl', 'Neutralizing_Ab']:
        continue

    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization].copy()

    # Fill missing ratios from neighbouring rows, then take the signed log10.
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(symmetric_log10)

    # Split positive and negative BEFORE smoothing/interpolation
    df_top = df_filtered_im[df_filtered_im['Log10_Enrichment'] > 0].copy()
    df_bottom = df_filtered_im[df_filtered_im['Log10_Enrichment'] < 0].copy()
    split_dfs = {'top': df_top, 'bottom': df_bottom}

    # 95% CI half-width per position, computed from the unaggregated rows.
    df_im_raw = df_filtered[df_filtered['immunization'] == immunization].copy()
    df_im_raw['Log10_Enrichment'] = df_im_raw['Enrichment_Ratio'].apply(symmetric_log10)
    grouped = df_im_raw.groupby('Spike_AS_Position')['Log10_Enrichment']
    std_per_pos = grouped.std()
    count_per_pos = grouped.count()
    sem_per_pos = std_per_pos / np.sqrt(count_per_pos)
    dof = (count_per_pos - 1).clip(lower=1)   # t.ppf needs >= 1 degree of freedom
    t_critical = pd.Series(stats.t.ppf(0.975, dof), index=dof.index)
    margin_of_error = t_critical * sem_per_pos

    # Smooth the margin across positions. .bfill()/.ffill() replace the
    # deprecated fillna(method=...).
    smoothed_margin = (
        margin_of_error.rolling(window=ROLLING_WINDOW, center=True, min_periods=1).median()
        .bfill().ffill()
    )

    # Reindex the split frames to every integer position between min and max.
    # int() keeps range() valid if the positions are stored as floats.
    full_range = range(int(df_filtered_im['Spike_AS_Position'].min()),
                       int(df_filtered_im['Spike_AS_Position'].max()) + 1)
    # `df_split`, not `df`: at notebook top level a loop variable named `df` would
    # overwrite the global raw data frame `df`.
    for key, df_split in split_dfs.items():
        if df_split.empty:
            continue

        df_split = df_split.set_index('Spike_AS_Position').reindex(full_range)
        df_split = df_split.apply(pd.to_numeric, errors='coerce')

        # Smooth enrichment, then fill the reindexed gaps linearly.
        df_split['Smoothed_Enrichment'] = df_split['Log10_Enrichment'].rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1
        ).mean().interpolate(method='linear', limit_direction='both')

        # CI margin aligned to the same positions (currently not drawn).
        df_split['Smoothed_CI_Margin'] = smoothed_margin.reindex(df_split.index).bfill().ffill()

        df_split = df_split.dropna(subset=['Smoothed_Enrichment'])
        split_dfs[key] = df_split.reset_index()

    # --- Update global y-limits using the smoothed values ---
    if not split_dfs['top'].empty:
        top_y_max_global = max(top_y_max_global, split_dfs['top']['Smoothed_Enrichment'].max())
    if not split_dfs['bottom'].empty:
        bottom_y_min_global = min(bottom_y_min_global, split_dfs['bottom']['Smoothed_Enrichment'].min())

    # Save CSV of the aggregated (non-reindexed) per-position table.
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # -----------------------------
    # Plot split
    # -----------------------------
    trace_color = color_map.get(immunization, 'black')
    for key, df_split in split_dfs.items():
        if df_split.empty:
            continue

        ax_current = ax_top if key == 'top' else ax_bottom

        # Rolling min / max of the raw log10 values for the shaded band.
        rolling_log10 = df_split['Log10_Enrichment'].rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1
        )
        df_split['Rolling_Min'] = rolling_log10.min()
        df_split['Rolling_Max'] = rolling_log10.max()

        # --- Shaded min-max area ---
        ax_current.fill_between(
            df_split['Spike_AS_Position'],
            df_split['Rolling_Min'],
            df_split['Rolling_Max'],
            color=trace_color,
            alpha=0.1
        )

        # Smoothed line (same default colour as the band, for consistency).
        dmslogo.line.draw_line(
            df_split,
            x_col="Spike_AS_Position",
            height_col="Smoothed_Enrichment",
            ax=ax_current,
            linewidth=2,
            color=trace_color
        )

        if immunization in faint_immunizations:
            ax_current.lines[-1].set_alpha(faint_alpha)

    # One invisible legend proxy per immunization on the top axes, where the
    # legend handles are read from. Added even when this immunization has no
    # positive values, so the legend lookup cannot fail.
    ax_top.plot([], [], color=trace_color, label=immunization)

# -----------------------------
# Axis formatting
# -----------------------------
ax_top.set_ylabel(" ", fontsize=12)
ax_bottom.set_ylabel(" ", fontsize=12)
ax_bottom.set_xlabel("Spike AA Position", fontsize=12)
ax_bottom.axhline(y=0, color='black', linewidth=2)

ax_top.set_ylim(bottom=0, top=0.2)
ax_bottom.set_ylim(bottom=-0.4, top=0)
ax_bottom.set_xlim(right=510)

# X-axis ticks: majors every 10 positions, minors every 2.
ax_bottom.xaxis.set_major_locator(MultipleLocator(10))
ax_bottom.xaxis.set_minor_locator(MultipleLocator(2))
plt.setp(ax_bottom.xaxis.get_majorticklabels(), rotation=0, ha='center')

# Y-axis ticks: same spacing and tick styling on both halves.
for panel in (ax_top, ax_bottom):
    panel.yaxis.set_major_locator(MultipleLocator(0.05))
    panel.yaxis.set_minor_locator(MultipleLocator(0.01))
    panel.tick_params(axis='y', which='major', length=7, width=1.2)
    panel.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

# Bottom y labels: blank at 0 (already labelled on the top half), rounded elsewhere.
# A FuncFormatter stays correct if the limits change, unlike set_yticklabels on
# a snapshot of get_yticks(), which also triggers the FixedFormatter warning.
ax_bottom.yaxis.set_major_formatter(
    FuncFormatter(lambda t, _: "" if t == 0 else str(round(t, 2)))
)

# -----------------------------
# Legend
# -----------------------------
# Proxy patch for the epitope annotation. Its definition here had been commented
# out, so the cell only worked if another cell had already created `epitope_patch`.
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')

group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']
label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}

# Legend entries in a fixed order. Labels with no trace on the top axes are
# skipped instead of raising ValueError from list.index.
drawn_handles, drawn_labels = ax_top.get_legend_handles_labels()
legend_keys = [key for key in group_1_labels + group_2_labels if key in drawn_labels]
handles = [drawn_handles[drawn_labels.index(key)] for key in legend_keys] + [epitope_patch]
labels = [label_map[key] for key in legend_keys] + ['SARS-CoV-2 Spike antibody epitopes']
# Legend is currently not drawn for this figure; uncomment to add it.
#ax_top.legend(handles=handles, labels=labels, fontsize=10, frameon=False)

# -----------------------------
# Save figure
# -----------------------------
plot_file_path = os.path.join("/Users/lucaschlotheuber/Desktop", "split_top_bottom_plot.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')
plt.show()


In [None]:
# Calculation and plotting with the mean (shaded range: rolling mean ± 1 SD)

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
%matplotlib inline
import os
# Define and create the CSV output directory before listing it. It is otherwise
# only defined further down in this cell, so a fresh kernel raised NameError here.
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))

def symmetric_log10(x, epsilon=1e-8, max_cap=1e8):
    """Signed log10 of the magnitude of ``x``.

    ``|x|`` is clipped to ``[epsilon, max_cap]`` before taking log10, and the
    result is multiplied by ``sign(x)``. Zero maps to (signed) zero and NaN
    propagates. Works element-wise on scalars and numpy arrays.
    """
    magnitude = np.abs(x)
    bounded = np.clip(magnitude, epsilon, max_cap)
    return np.log10(bounded) * np.sign(x)

pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 15  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 1  # Adjust based on data distribution

# Keep Spike positions above 33 + 331 (= 364) with a positive enrichment ratio,
# then drop stop codons ("*").
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 0)]
df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Target sites as strings (matched later against Spike_AS_Position.astype(str))
sites_to_show = [
    str(site)
    for site in (
        [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]  # RBD-ACE2 interface
        + list(range(394, 414))  # R21 peptide sequence
        + list(range(484, 505))  # R13 peptide sequence
    )
]

# Number of distinct barcodes per immunization
barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

# Mean enrichment ratio per (immunization, position). The former
# `.assign(site_label=..., show_site=...)` step was removed. Both columns were
# discarded by this aggregation, and `show_site` compared integer positions with
# string site labels, so it was always False. show_site is computed correctly
# (via astype(str)) after reindexing in the plotting loop.
df_filtered_agg = df_filtered.groupby(
    ['immunization', 'Spike_AS_Position'], as_index=False
).agg({
    'Enrichment_Ratio': 'mean'
})

# Directory for the per-immunization CSV exports
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Line / band colour per immunization group
color_map = dict(
    Polyclonal_Ab='darkorange',
    Neutralizing_Ab='red',
    wildtype_RBD='green',
    Mutant_RBD='darkblue',
)

# A single shared axes for every immunization
fig, ax = plt.subplots(figsize=(8, 6))

# Plot every immunization on the same axes, excluding 'Library_ctrl' and
# 'Neutralizing_Ab'.
#
# Per-immunization frames are collected here. This dict used to be re-created
# inside the loop, so only the last immunization was kept.
immunization_data = {}

# Styling for de-emphasised traces. These are loop-invariant, so they are defined once.
faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
faint_alpha = 0.3
faint_linewidth = 1.5  # currently unused

for immunization in df_filtered_agg['immunization'].unique():
    if immunization in ('Library_ctrl', 'Neutralizing_Ab'):
        continue

    print(immunization)
    # Boolean mask + .copy(): gives an independent frame, so the column
    # assignments below do not raise SettingWithCopyWarning on a .query() view.
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization].copy()

    # Distinct barcodes for this immunization (currently unused).
    num_barcodes = barcode_counts.get(immunization, 0)

    print(f"Spike_AS_Position sample before aggregation:\n{df_filtered_im['Spike_AS_Position'].head()}")
    print(f"\nRows for immunization {immunization}: {df_filtered_im.shape[0]}")
    print(df_filtered_im[['Spike_AS_Position', 'Enrichment_Ratio']].head(10))

    # Fill missing ratios from neighbouring rows. .bfill()/.ffill() replace the
    # deprecated fillna(method=...).
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # Signed, clipped log10 of the mean enrichment ratio
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(symmetric_log10)

    # Smooth the log10 values.
    # NOTE(review): this runs before the reindex below, so the window spans rows
    # (positions present for this immunization). The shaded band further down is
    # computed after the reindex, over consecutive positions. The two only
    # coincide when no positions are missing.
    df_filtered_im['Smoothed_Enrichment'] = (
        df_filtered_im['Log10_Enrichment']
        .rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()
        .interpolate(method='linear', limit_direction='both')
    )

    # Flag positions whose smoothed log10 enrichment exceeds the threshold
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD

    # Save the per-position table (before reindexing) as CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex to every integer position between min and max. Missing positions
    # become NaN rows. int() keeps range() valid for float positions.
    first_pos = int(df_filtered_im['Spike_AS_Position'].min())
    last_pos = int(df_filtered_im['Spike_AS_Position'].max())
    df_filtered_im = (
        df_filtered_im.set_index('Spike_AS_Position')
        .reindex(range(first_pos, last_pos + 1))
        .reset_index()
    )

    # Reindexed rows are treated as not highlighted; mark the target sites
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))

    # Shaded band: rolling mean +/- one rolling standard deviation of log10 enrichment
    rolling_log10 = df_filtered_im['Log10_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1)
    df_filtered_im['Rolling_Mean'] = rolling_log10.mean()
    df_filtered_im['Rolling_Std'] = rolling_log10.std()
    lower_bound = df_filtered_im['Rolling_Mean'] - df_filtered_im['Rolling_Std']
    upper_bound = df_filtered_im['Rolling_Mean'] + df_filtered_im['Rolling_Std']

    ax.fill_between(
        df_filtered_im['Spike_AS_Position'],
        lower_bound,
        upper_bound,
        color=color_map.get(immunization, 'black'),
        alpha=0.1
    )

    # Smoothed line
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        ax=ax,
        linewidth=2,
        color=color_map.get(immunization, 'black')
    )
    if immunization in faint_immunizations:
        ax.lines[-1].set_alpha(faint_alpha)
    # Invisible proxy line that carries the legend label
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    print(f"\n[INFO] Immunization: {immunization}")
    print(f"→ Positions: {df_filtered_im['Spike_AS_Position'].min()} to {df_filtered_im['Spike_AS_Position'].max()}")
    print(f"→ Mean Enrichment: {df_filtered_im['Enrichment_Ratio'].mean():.3f}")
    print(f"→ Mean Log10 Enrichment: {df_filtered_im['Log10_Enrichment'].mean():.3f}")
    print(f"→ Mean Smoothed Log10 Enrichment: {df_filtered_im['Smoothed_Enrichment'].mean():.3f}")

    if immunization == 'wildtype_RBD':
        print("\n[DEBUG] Wildtype enrichment values (360–540):")
        print(df_filtered_im[df_filtered_im['Spike_AS_Position'].between(360, 540)][
            ['Spike_AS_Position', 'Enrichment_Ratio', 'Log10_Enrichment', 'Smoothed_Enrichment']
        ])

    immunization_data[immunization] = df_filtered_im.copy()
    
# Y range for the log10 enrichment curves
ax.set_ylim(bottom=-0.5, top=0.6)

# Title and axis labels (override the labels passed to dmslogo)
ax.set_title('', fontsize=14)
ax.set_xlabel('Spike AA Position', fontsize=16)
ax.set_ylabel('Log10 Antibody binding (Mean)', fontsize=18)

from matplotlib.patches import Patch

# Square proxy patch representing the epitope annotation in the legend
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

# Legend proxies in a fixed order: pAb first, then the two ASC groups, then the patch
drawn_handles, drawn_labels = ax.get_legend_handles_labels()
group_1_handles = [drawn_handles[drawn_labels.index(name)] for name in group_1_labels]
group_2_handles = [drawn_handles[drawn_labels.index(name)] for name in group_2_labels]
handles = [*group_1_handles, *group_2_handles, epitope_patch]
labels = [label_map[name] for name in (*group_1_labels, *group_2_labels)]
labels.append('SARS-CoV-2 Spike antibody epitopes')

ax.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='lower left',
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# This cell only imports MaxNLocator. Import MultipleLocator here so the cell
# does not depend on an import made in an earlier cell.
from matplotlib.ticker import MultipleLocator, FuncFormatter

# X axis: integer-labelled majors every 10 positions, unlabelled minors every 2
ax.xaxis.set_major_locator(MultipleLocator(10))
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{int(x)}"))
ax.xaxis.set_minor_locator(MultipleLocator(2))

# Tick styling (same for x and y)
ax.tick_params(axis='both', which='major', length=7, width=1.2)
ax.tick_params(axis='both', which='minor', length=4, width=0.8, direction='out')

plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')

# Y axis: one-decimal labelled majors every 0.1, unlabelled minors every 0.05
ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{y:.1f}"))
ax.yaxis.set_minor_locator(MultipleLocator(0.05))

# Save. The file name uses the value `immunization` held after the loop finished.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_Line2_plot_final_log10.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

plt.show()


In [None]:
# Mean with a 95% confidence interval (instead of ± 1 SD) as the shaded range

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
%matplotlib inline
import os
# Define and create the CSV output directory before listing it. It is otherwise
# only defined further down in this cell, so a fresh kernel raised NameError here.
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))

def symmetric_log10(x, epsilon=1e-8, max_cap=1e8):
    """Signed log10 of the magnitude of ``x``.

    ``|x|`` is clipped to ``[epsilon, max_cap]`` before taking log10, and the
    result is multiplied by ``sign(x)``. Zero maps to (signed) zero and NaN
    propagates. Works element-wise on scalars and numpy arrays.
    """
    magnitude = np.abs(x)
    bounded = np.clip(magnitude, epsilon, max_cap)
    return np.log10(bounded) * np.sign(x)

pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 15  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 1  # Adjust based on data distribution

# Keep Spike positions above 33 + 331 (= 364) with a positive enrichment ratio,
# then drop stop codons ("*").
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 0)]
df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Target sites as strings (matched later against Spike_AS_Position.astype(str))
sites_to_show = [
    str(site)
    for site in (
        [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]  # RBD-ACE2 interface
        + list(range(394, 414))  # R21 peptide sequence
        + list(range(484, 505))  # R13 peptide sequence
    )
]

# Number of distinct barcodes per immunization
barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

# Mean enrichment ratio per (immunization, position). The former
# `.assign(site_label=..., show_site=...)` step was removed. Both columns were
# discarded by this aggregation, and `show_site` compared integer positions with
# string site labels, so it was always False. show_site is computed correctly
# (via astype(str)) after reindexing in the plotting loop.
df_filtered_agg = df_filtered.groupby(
    ['immunization', 'Spike_AS_Position'], as_index=False
).agg({
    'Enrichment_Ratio': 'mean'
})

# Directory for the per-immunization CSV exports
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Line / band colour per immunization group
color_map = dict(
    Polyclonal_Ab='darkorange',
    Neutralizing_Ab='red',
    wildtype_RBD='green',
    Mutant_RBD='darkblue',
)

# A single shared axes for every immunization
fig, ax = plt.subplots(figsize=(8, 6))

# Plot every immunization on the same axes, excluding 'Library_ctrl' and
# 'Neutralizing_Ab'.
#
# Per-immunization frames are collected here. This dict used to be re-created
# inside the loop, so only the last immunization was kept.
immunization_data = {}

# Styling for de-emphasised traces. These are loop-invariant, so they are defined once.
faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
faint_alpha = 0.3
faint_linewidth = 1.5  # currently unused

for immunization in df_filtered_agg['immunization'].unique():
    if immunization in ('Library_ctrl', 'Neutralizing_Ab'):
        continue

    print(immunization)
    # Boolean mask + .copy(): gives an independent frame, so the column
    # assignments below do not raise SettingWithCopyWarning on a .query() view.
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization].copy()

    # Distinct barcodes for this immunization (currently unused).
    num_barcodes = barcode_counts.get(immunization, 0)

    print(f"Spike_AS_Position sample before aggregation:\n{df_filtered_im['Spike_AS_Position'].head()}")
    print(f"\nRows for immunization {immunization}: {df_filtered_im.shape[0]}")
    print(df_filtered_im[['Spike_AS_Position', 'Enrichment_Ratio']].head(10))

    # Fill missing ratios from neighbouring rows. .bfill()/.ffill() replace the
    # deprecated fillna(method=...).
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # Signed, clipped log10 of the mean enrichment ratio
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(symmetric_log10)

    # Smooth the log10 values.
    # NOTE(review): this runs before the reindex below, so the window spans rows
    # (positions present for this immunization). The CI band further down is
    # computed after the reindex, over consecutive positions.
    df_filtered_im['Smoothed_Enrichment'] = (
        df_filtered_im['Log10_Enrichment']
        .rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()
        .interpolate(method='linear', limit_direction='both')
    )

    # Flag positions whose smoothed log10 enrichment exceeds the threshold
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD

    # Save the per-position table (before reindexing) as CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex to every integer position between min and max. Missing positions
    # become NaN rows. int() keeps range() valid for float positions.
    first_pos = int(df_filtered_im['Spike_AS_Position'].min())
    last_pos = int(df_filtered_im['Spike_AS_Position'].max())
    df_filtered_im = (
        df_filtered_im.set_index('Spike_AS_Position')
        .reindex(range(first_pos, last_pos + 1))
        .reset_index()
    )

    # Reindexed rows are treated as not highlighted; mark the target sites
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))

    # Shaded band: normal-approximation 95% CI of the rolling mean of log10 enrichment.
    rolling_log10 = df_filtered_im['Log10_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1)
    rolling_mean = rolling_log10.mean()
    rolling_std = rolling_log10.std()
    # Use the number of non-missing values actually in each window. It is smaller
    # than ROLLING_WINDOW at the ends of the range and around reindexed gaps, so
    # dividing by sqrt(ROLLING_WINDOW) understated the standard error there.
    n = rolling_log10.count()
    half_width = 1.96 * (rolling_std / np.sqrt(n))

    ci_upper = rolling_mean + half_width
    ci_lower = rolling_mean - half_width

    ax.fill_between(
        df_filtered_im['Spike_AS_Position'],
        ci_lower,
        ci_upper,
        color=color_map.get(immunization, 'black'),
        alpha=0.1
    )

    # Smoothed line
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        show_col="show_site",
        ax=ax,
        linewidth=2,
        color=color_map.get(immunization, 'black')
    )
    if immunization in faint_immunizations:
        ax.lines[-1].set_alpha(faint_alpha)
    # Invisible proxy line that carries the legend label
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    # Short orange marks at y=0 for each target site
    highlight_sites = df_filtered_im[df_filtered_im['show_site']]
    for _, site_data in highlight_sites.iterrows():
        ax.hlines(
            y=0,  # set near bottom
            xmin=site_data['Spike_AS_Position'] - 0.1,
            xmax=site_data['Spike_AS_Position'] + 0.1,
            color='orange',
            linestyle='-',
            linewidth=4
        )
    print(f"\n[INFO] Immunization: {immunization}")
    print(f"→ Positions: {df_filtered_im['Spike_AS_Position'].min()} to {df_filtered_im['Spike_AS_Position'].max()}")
    print(f"→ Mean Enrichment: {df_filtered_im['Enrichment_Ratio'].mean():.3f}")
    print(f"→ Mean Log10 Enrichment: {df_filtered_im['Log10_Enrichment'].mean():.3f}")
    print(f"→ Mean Smoothed Log10 Enrichment: {df_filtered_im['Smoothed_Enrichment'].mean():.3f}")

    if immunization == 'wildtype_RBD':
        print("\n[DEBUG] Wildtype enrichment values (360–540):")
        print(df_filtered_im[df_filtered_im['Spike_AS_Position'].between(360, 540)][
            ['Spike_AS_Position', 'Enrichment_Ratio', 'Log10_Enrichment', 'Smoothed_Enrichment']
        ])

    immunization_data[immunization] = df_filtered_im.copy()
    
# Y range for the log10 enrichment curves
ax.set_ylim(bottom=-0.4, top=0.8)

# Title and axis labels (override the labels passed to dmslogo)
ax.set_title('', fontsize=14)
ax.set_xlabel('Spike AA Position', fontsize=16)
ax.set_ylabel('Log10 Antibody binding (Mean)', fontsize=18)

from matplotlib.patches import Patch

# Orange proxy patch matching the site marks drawn at y=0
epitope_patch = Patch(facecolor='orange', label='SARS-CoV-2 Spike antibody epitopes')

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

# Legend proxies in a fixed order: pAb first, then the two ASC groups, then the patch
drawn_handles, drawn_labels = ax.get_legend_handles_labels()
group_1_handles = [drawn_handles[drawn_labels.index(name)] for name in group_1_labels]
group_2_handles = [drawn_handles[drawn_labels.index(name)] for name in group_2_labels]
handles = [*group_1_handles, *group_2_handles, epitope_patch]
labels = [label_map[name] for name in (*group_1_labels, *group_2_labels)]
labels.append('SARS-CoV-2 Spike antibody epitopes')

ax.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='lower left',
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# This cell only imports MaxNLocator. Import MultipleLocator here so the cell
# does not depend on an import made in an earlier cell.
from matplotlib.ticker import MultipleLocator, FuncFormatter

# X axis: integer-labelled majors every 10 positions, unlabelled minors every 2
ax.xaxis.set_major_locator(MultipleLocator(10))
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{int(x)}"))
ax.xaxis.set_minor_locator(MultipleLocator(2))

# Tick styling (same for x and y)
ax.tick_params(axis='both', which='major', length=7, width=1.2)
ax.tick_params(axis='both', which='minor', length=4, width=0.8, direction='out')

plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')

# Y axis: one-decimal labelled majors every 0.1, unlabelled minors every 0.05
ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{y:.1f}"))
ax.yaxis.set_minor_locator(MultipleLocator(0.05))

# Horizontal gridlines only
ax.grid(True, axis='y')
ax.grid(False, axis='x')

# Save. The "_CI" suffix keeps this figure from overwriting the mean +/- SD
# figure, which the previous cell saves under the same base name.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_Line2_plot_final_log10_CI.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

plt.show()


In [None]:
# Rolling (smoothed) mean enrichment profile with a shaded variability band (smoothed ±1 SD)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.scale import ScaleBase
from matplotlib.transforms import Transform
from matplotlib.ticker import FixedLocator, FuncFormatter
import matplotlib.scale as mscale


class PiecewiseLinearScale(mscale.ScaleBase):
    """Axis scale that gives the negative and the positive range equal halves of the axis.

    Data in ``[vmin_neg, 0]`` maps linearly onto axis fraction ``[0, 0.5]`` and
    data in ``(0, vmax_pos]`` onto ``(0.5, 1]``, so zero always sits in the middle
    of the axis even when the data range is strongly asymmetric. Values outside
    ``[vmin_neg, vmax_pos]`` are extrapolated with the slope of the nearer half,
    and NaN inputs map to NaN.

    Keyword arguments (forwarded by ``ax.set_yscale('piecewise_linear', ...)``):
        vmin_neg: lower end of the negative half; must be < 0 (default -0.8).
        vmax_pos: upper end of the positive half; must be > 0 (default 0.1).

    Raises:
        ValueError: unless ``vmin_neg < 0 < vmax_pos`` (otherwise one half of the
            mapping divides by zero or runs backwards).
    """

    name = 'piecewise_linear'  # name used with ax.set_xscale / ax.set_yscale

    def __init__(self, axis, **kwargs):
        super().__init__(axis)
        self.vmin_neg = kwargs.get('vmin_neg', -0.8)
        self.vmax_pos = kwargs.get('vmax_pos', 0.1)
        # Fail when the scale is set, not with a divide-by-zero at draw time.
        PiecewiseLinearScale._check_bounds(self.vmin_neg, self.vmax_pos)

    @staticmethod
    def _check_bounds(vmin_neg, vmax_pos):
        """Raise ValueError unless vmin_neg < 0 < vmax_pos (NaN bounds are rejected too)."""
        if not (vmin_neg < 0 < vmax_pos):
            raise ValueError(
                "piecewise_linear scale requires vmin_neg < 0 < vmax_pos, "
                f"got vmin_neg={vmin_neg!r}, vmax_pos={vmax_pos!r}"
            )

    def get_transform(self):
        """Return the data -> axis-fraction transform for the current bounds."""
        return self.PiecewiseLinearTransform(self.vmin_neg, self.vmax_pos)

    def set_default_locators_and_formatters(self, axis):
        """Put 5 evenly spaced major ticks on each half (zero shared) with 3-decimal labels."""
        neg_ticks = np.linspace(self.vmin_neg, 0, 5)
        pos_ticks = np.linspace(0, self.vmax_pos, 5)[1:]  # drop the duplicate zero
        axis.set_major_locator(FixedLocator(np.concatenate([neg_ticks, pos_ticks])))
        axis.set_major_formatter(FuncFormatter(lambda x, _: f"{x:.3f}"))

    class PiecewiseLinearTransform(Transform):
        """Forward map: vmin_neg -> 0, 0 -> 0.5, vmax_pos -> 1 (linear on each side)."""

        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, vmin_neg, vmax_pos):
            super().__init__()
            PiecewiseLinearScale._check_bounds(vmin_neg, vmax_pos)
            self.vmin_neg = vmin_neg
            self.vmax_pos = vmax_pos
            self.len_neg = abs(vmin_neg)
            self.len_pos = vmax_pos

        def transform_non_affine(self, y):
            y = np.asarray(y, dtype=float)
            # np.where (instead of filling an np.empty_like buffer through two masks)
            # guarantees every element is written: NaN fails both `<= 0` and `> 0`,
            # so it used to come back as uninitialised memory; now it stays NaN.
            return np.where(
                y <= 0,
                0.5 * (y - self.vmin_neg) / self.len_neg,  # negative half -> [0, 0.5]
                0.5 + 0.5 * y / self.len_pos,              # positive half -> (0.5, 1]
            )

        def inverted(self):
            return PiecewiseLinearScale.InvertedPiecewiseLinearTransform(self.vmin_neg, self.vmax_pos)

    class InvertedPiecewiseLinearTransform(Transform):
        """Inverse map: axis fraction -> data value (0 -> vmin_neg, 0.5 -> 0, 1 -> vmax_pos)."""

        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, vmin_neg, vmax_pos):
            super().__init__()
            PiecewiseLinearScale._check_bounds(vmin_neg, vmax_pos)
            self.vmin_neg = vmin_neg
            self.vmax_pos = vmax_pos
            self.len_neg = abs(vmin_neg)
            self.len_pos = vmax_pos

        def transform_non_affine(self, y):
            y = np.asarray(y, dtype=float)
            # Same NaN-safe construction as the forward transform.
            return np.where(
                y <= 0.5,
                self.vmin_neg + (y / 0.5) * self.len_neg,
                ((y - 0.5) / 0.5) * self.len_pos,
            )

        def inverted(self):
            return PiecewiseLinearScale.PiecewiseLinearTransform(self.vmin_neg, self.vmax_pos)

# Register the custom scale with matplotlib so it can be selected by name.
from matplotlib.scale import register_scale
register_scale(PiecewiseLinearScale)



In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
%matplotlib inline
import os
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))

def symmetric_log10(x, epsilon=1e-8, max_cap=1e8):
    x_clipped = np.clip(np.abs(x), epsilon, max_cap)
    return np.sign(x) * np.log10(x_clipped)

pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 10  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 0  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 0)]

# Aggregate enrichment ratio by position
df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Define target sites
sites_to_show = map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
)
sites_to_show = list(sites_to_show)

barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

df_filtered_agg = df_filtered.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show)
)


# Aggregate filtered data by position and immunization before plotting
df_filtered_agg = df_filtered_agg.groupby(
    ['immunization', 'Spike_AS_Position'], as_index=False
).agg({
    'Enrichment_Ratio': 'median'
})

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create a single figure
fig, ax = plt.subplots(figsize=(8, 6))

# Define color mapping
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Loop through each immunization and plot on the same axes, excluding 'Library_ctrl'
for immunization in df_filtered_agg['immunization'].unique():
    immunization_data = {}
    if immunization == 'Library_ctrl':
        continue

    if immunization == 'Neutralizing_Ab':
        continue
    faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
    faint_alpha = 0.3
    faint_linewidth = 1.5

    #if immunization == 'Polyclonal_Ab':  # Skip 'Library_ctrl'
     #   continue
    #if immunization == 'Neutralizing_Ab':  # Skip 'Library_ctrl'
     #   continue

    print(immunization)
    df_filtered_im = df_filtered_agg.query(f'immunization == "{immunization}"')

    # Normalize by number of barcodes
    num_barcodes = barcode_counts.get(immunization, 0)
    #num_barcodes = df_filtered_agg[df_filtered_agg['immunization'] == immunization]['barcode'].nunique()

    print(f"Spike_AS_Position sample before aggregation:\n{df_filtered_im['Spike_AS_Position'].head()}")
    print(f"\nRows for immunization {immunization}: {df_filtered_im.shape[0]}")
    print(df_filtered_im[['Spike_AS_Position', 'Enrichment_Ratio']].head(10))
   
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].fillna(method='bfill').fillna(method='ffill')

    # Safe log transform (ignore or remove zero/negative values)
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(symmetric_log10)

    
    #Apply smoothing on the log2 transformed values
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Log10_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both')

    # Identify high enrichment
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    if immunization == 'Neutralizing_Ab':
        df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

        #df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].fillna(method='bfill').fillna(method='ffill')

    # Save CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(),
              df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()
    
    # Mark sites to highlight
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))

    # Interpolate final smoothed enrichment values to fill NaNs
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both')

    n = 10  # rolling window for smoothing std dev
    
    # First, compute per-position std dev of Log10_Enrichment from raw replicate data for this immunization
    df_im_raw = df_filtered[df_filtered['immunization'] == immunization].copy()
    
    # Make sure to transform enrichment ratio to log10 scale safely
    df_im_raw['Log10_Enrichment'] = df_im_raw['Enrichment_Ratio'].apply(symmetric_log10)
    print(f"[DEBUG] Raw Log10 enrichment values (first 20 rows):\n{df_im_raw[['Spike_AS_Position', 'Log10_Enrichment']].head(20)}\n")

    # Compute std dev per position across replicates/barcodes
    std_per_pos = df_im_raw.groupby('Spike_AS_Position')['Log10_Enrichment'].std()
    print(f"[DEBUG] Raw std dev per position (first 20 values):\n{std_per_pos.head(20)}\n")

    
    # Smooth the std dev values across positions
    smoothed_std = std_per_pos.rolling(window=n, center=True, min_periods=1).median()
    print(f"[DEBUG] Smoothed std dev (first 20 values):\n{smoothed_std.head(20)}\n")

    
    # Fill NaNs at edges if any
    smoothed_std = smoothed_std.fillna(method='bfill').fillna(method='ffill')
    
    # Join this smoothed std dev back to df_filtered_im by position
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position')
    df_filtered_im['Smoothed_StdDev'] = smoothed_std
    df_filtered_im = df_filtered_im.reset_index()
    
    # Fill any remaining NaNs for the new column if needed
    df_filtered_im['Smoothed_StdDev'] = df_filtered_im['Smoothed_StdDev'].fillna(method='bfill').fillna(method='ffill')
    
    # Use ±1 standard deviation around the smoothed mean
    std_upper = df_filtered_im['Smoothed_Enrichment'] + df_filtered_im['Smoothed_StdDev']
    std_lower = df_filtered_im['Smoothed_Enrichment'] - df_filtered_im['Smoothed_StdDev']
    
    # Plot shaded standard deviation area
    ax.fill_between(
        df_filtered_im['Spike_AS_Position'],
        std_lower,
        std_upper,
        color=color_map.get(immunization, 'black'),
        alpha=0.2,
        zorder=5
    )

    #df_filtered_im['Smoothed_Enrichment_min'] = df_filtered_im['Smoothed_Enrichment'].rolling(
    #window=ROLLING_WINDOW, center=True, min_periods=1).min()
    #df_filtered_im['Smoothed_Enrichment_max'] = df_filtered_im['Smoothed_Enrichment'].rolling(
    #window=ROLLING_WINDOW, center=True, min_periods=1).max()
    # First, create a rolling max and min on original (non-log) Enrichment_Ratio
     # Compute rolling 25th and 75th percentiles from the raw Enrichment_Ratio
    # Calculate min and max Enrichment_Ratio for each position

    
    # Use your already calculated smoothed mean as center for CI
    #ci_upper = df_filtered_im['Smoothed_Enrichment'] + 1.96 * (rolling_std / np.sqrt(n))
    #ci_lower = df_filtered_im['Smoothed_Enrichment'] - 1.96 * (rolling_std / np.sqrt(n))
    
    ## Plot shaded CI area
    #ax.fill_between(
    #    df_filtered_im['Spike_AS_Position'],
    #    ci_lower,
    #    ci_upper,
    #    color=color_map.get(immunization, 'black'),
    #    alpha=0.2,
    #    step='mid'
   # )

    # Plot
    
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        #show_col="show_site",
        ax=ax,
        linewidth=2,
        color=color_map.get(immunization, 'black')
    )
    if immunization in faint_immunizations:
        ax.lines[-1].set_alpha(faint_alpha)
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    # Highlight sites
    highlight_sites = df_filtered_im[df_filtered_im['show_site']]
    highlight_mask = df_filtered_im['show_site'].apply(lambda x: 1.0 if x else np.nan) # 1.0 where True, 0.0 where False
    bar_height = 0.002  # adjust the height of the bars here

    ax.fill_between(
        df_filtered_im['Spike_AS_Position'],
        -0.005,
        highlight_mask * bar_height,
        color='black',
        where=highlight_mask,
        alpha=0.5,
        step='mid'  # aligns fill to the positions nicely
    )

    print(f"\n[INFO] Immunization: {immunization}")
    print(f"→ Positions: {df_filtered_im['Spike_AS_Position'].min()} to {df_filtered_im['Spike_AS_Position'].max()}")
    print(f"→ Mean Enrichment: {df_filtered_im['Enrichment_Ratio'].mean():.3f}")
    print(f"→ Mean Log10 Enrichment: {df_filtered_im['Log10_Enrichment'].mean():.3f}")
    print(f"→ Mean Smoothed Log10 Enrichment: {df_filtered_im['Smoothed_Enrichment'].mean():.3f}")

    if immunization == 'wildtype_RBD':
        print("\n[DEBUG] Wildtype enrichment values (360–540):")
        print(df_filtered_im[df_filtered_im['Spike_AS_Position'].between(360, 540)][
            ['Spike_AS_Position', 'Enrichment_Ratio', 'Log10_Enrichment', 'Smoothed_Enrichment']
        ])


    immunization_data[immunization] = df_filtered_im.copy()
    
# Y-axis limit (adjust if needed for log scale)
ax.set_ylim(bottom=-0.4,
            top=0.04)

plt.title('', fontsize=16)
plt.xlabel('Spike AA Position', fontsize=18)
plt.ylabel('Log10 AB binding (Mean) \n All antibodies', fontsize=16)

from matplotlib.patches import Patch

# Create a square patch for the legend entry
epitope_patch = Patch(facecolor='orange', label='SARS-CoV-2 Spike antibody epitopes')


handles, labels = ax.get_legend_handles_labels()
# Legend (subset only to certain labels)
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

group_1_handles = [handles[labels.index(label)] for label in group_1_labels]
group_2_handles = [handles[labels.index(label)] for label in group_2_labels]

# Combine handles and labels with a custom order
handles = group_1_handles + group_2_handles + [epitope_patch]

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
labels = [label_map[label] for label in group_1_labels + group_2_labels] + ['SARS-CoV-2 Spike antibody epitopes']
fig.subplots_adjust(right=0.8)
plt.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='center left',          # Legend inside figure but on the left side of bbox_to_anchor
    bbox_to_anchor=(1, 0.5),    # Outside plot area on right, vertically centered
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# X-axis ticks
# Set x-ticks every 10 positions
# Major ticks every 10 positions
major_locator = MultipleLocator(10)
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{int(x)}"))

# Minor ticks every 2 positions (sub-ticks without labels)
minor_locator = MultipleLocator(2)
ax.xaxis.set_minor_locator(minor_locator)

# Enable minor ticks
ax.tick_params(axis='x', which='major', length=7, width=1.2)
ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')

ax.tick_params(axis='y', which='major', length=7, width=1.2)
ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

# Optionally rotate x labels
plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')
# Tick marks every 0.1
ax.yaxis.set_major_locator(MultipleLocator(0.05))


# Minor ticks every 0.1 or 0.05 (no labels)
ax.yaxis.set_minor_locator(MultipleLocator(0.01))
ax.grid(False, axis='y')   # Turn ON horizontal gridlines only
ax.grid(False, axis='x')  # Turn OFF vertical gridlines

# Label every 2nd minor tick (every 0.04) similarly or skip labeling minor ticks entirely
# If you want to label every 2nd minor tick:  # wider figure to fit legend on right
# Add a single horizontal gridline at y=0.3
#ax.axhline(y=0.3, color='gray', linestyle='--', linewidth=1, alpha=0.7)
#ax.axhline(y=0.1, color='gray', linestyle='--', linewidth=1, alpha=0.7)

ax.set_xlim(right=510)

# Set limits to your observed data range or desired axis range
vmin_neg = -0.75  # from your code
vmax_pos = 0.6

# Set numeric limits for your data range
ax.set_ylim(vmin_neg, vmax_pos)

# Set your custom scale on the y-axis with these limits
ax.set_yscale('piecewise_linear', vmin_neg=vmin_neg, vmax_pos=vmax_pos)

# Save

plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_log10_Mean_AllAntibodies.png")
fig.set_size_inches(10, 4)
fig.tight_layout()
fig.savefig(plot_file_path, format='png')  # Save after resizing/layout
plt.show()


In [None]:
# For publication: per-position median across all barcodes, with a shaded 95% confidence-interval band

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
%matplotlib inline
import scipy.stats as stats
import os
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))

def symmetric_log10(x, epsilon=1e-8, max_cap=1e8):
    x_clipped = np.clip(np.abs(x), epsilon, max_cap)
    return np.sign(x) * np.log10(x_clipped)

pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 10  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 0  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 0)]

# Aggregate enrichment ratio by position
df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Define target sites
sites_to_show = map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
)
sites_to_show = list(sites_to_show)

barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

df_filtered_agg = df_filtered.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show)
)


# Aggregate filtered data by position and immunization before plotting
df_filtered_agg = df_filtered_agg.groupby(
    ['immunization', 'Spike_AS_Position'], as_index=False
).agg({
    'Enrichment_Ratio': 'median'
})

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create a single figure
fig, ax = plt.subplots(figsize=(8, 6))

# Define color mapping
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Loop through each immunization and plot on the same axes, excluding 'Library_ctrl'
for immunization in df_filtered_agg['immunization'].unique():
    immunization_data = {}
    if immunization == 'Library_ctrl':
        continue

    if immunization == 'Neutralizing_Ab':
        continue
    faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
    faint_alpha = 0.5
    faint_linewidth = 1.5

    #if immunization == 'Polyclonal_Ab':  # Skip 'Library_ctrl'
     #   continue
    #if immunization == 'Neutralizing_Ab':  # Skip 'Library_ctrl'
     #   continue

    print(immunization)
    df_filtered_im = df_filtered_agg.query(f'immunization == "{immunization}"')

    # Normalize by number of barcodes
    num_barcodes = barcode_counts.get(immunization, 0)
    #num_barcodes = df_filtered_agg[df_filtered_agg['immunization'] == immunization]['barcode'].nunique()

    print(f"Spike_AS_Position sample before aggregation:\n{df_filtered_im['Spike_AS_Position'].head()}")
    print(f"\nRows for immunization {immunization}: {df_filtered_im.shape[0]}")
    print(df_filtered_im[['Spike_AS_Position', 'Enrichment_Ratio']].head(10))
   
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].fillna(method='bfill').fillna(method='ffill')

    # Safe log transform (ignore or remove zero/negative values)
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(symmetric_log10)

    #Apply smoothing on the log2 transformed values
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Log10_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both')

    # Identify high enrichment
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    if immunization == 'Neutralizing_Ab':
        df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

        #df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].fillna(method='bfill').fillna(method='ffill')

    # Save CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(),
              df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()
    
    # Mark sites to highlight
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))

    # Interpolate final smoothed enrichment values to fill NaNs
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both')
    n = 10  # rolling window for smoothing
    
    # Filter raw data for immunization
    df_im_raw = df_filtered[df_filtered['immunization'] == immunization].copy()
    
    # Transform enrichment ratio safely
    df_im_raw['Log10_Enrichment'] = df_im_raw['Enrichment_Ratio'].apply(symmetric_log10)
    print(f"[DEBUG] Raw Log10 enrichment values (first 20 rows):\n{df_im_raw[['Spike_AS_Position', 'Log10_Enrichment']].head(20)}\n")
    
    # Calculate std dev and count per position
    grouped = df_im_raw.groupby('Spike_AS_Position')['Log10_Enrichment']
    std_per_pos = grouped.std()
    count_per_pos = grouped.count()
    
    print(f"[DEBUG] Raw std dev per position (first 20 values):\n{std_per_pos.head(20)}\n")
    print(f"[DEBUG] Counts per position (first 20 values):\n{count_per_pos.head(20)}\n")
    
    # Calculate standard error of the mean (SEM)
    sem_per_pos = std_per_pos / np.sqrt(count_per_pos)
    
    # Calculate t-critical value for 95% confidence interval (two-tailed)
    # Degrees of freedom = count - 1, minimum 1 to avoid div by zero
    dof = count_per_pos - 1
    dof[dof < 1] = 1
    t_critical = dof.apply(lambda df: stats.t.ppf(0.975, df))  # 0.975 for two-tailed 95%
    
    # Margin of error = t-critical * SEM
    margin_of_error = t_critical * sem_per_pos
    
    # Smooth the margin of error across positions
    smoothed_margin = margin_of_error.rolling(window=n, center=True, min_periods=1).median()
    print(f"[DEBUG] Smoothed margin of error (first 20 values):\n{smoothed_margin.head(20)}\n")
    
    # Fill NaNs at edges if any
    smoothed_margin = smoothed_margin.fillna(method='bfill').fillna(method='ffill')
    
    # Join smoothed margin of error back to df_filtered_im
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position')
    df_filtered_im['Smoothed_CI_Margin'] = smoothed_margin
    df_filtered_im = df_filtered_im.reset_index()
    
    # Fill any remaining NaNs
    df_filtered_im['Smoothed_CI_Margin'] = df_filtered_im['Smoothed_CI_Margin'].fillna(method='bfill').fillna(method='ffill')
    
    # Use ± margin of error around smoothed mean
    std_upper = df_filtered_im['Smoothed_Enrichment'] + df_filtered_im['Smoothed_CI_Margin']
    std_lower = df_filtered_im['Smoothed_Enrichment'] - df_filtered_im['Smoothed_CI_Margin']
    
    # Plot shaded 95% CI area
    ax.fill_between(
        df_filtered_im['Spike_AS_Position'],
        std_lower,
        std_upper,
        color=color_map.get(immunization, 'black'),
        alpha=0.1,
        zorder=5
    )

    #df_filtered_im['Smoothed_Enrichment_min'] = df_filtered_im['Smoothed_Enrichment'].rolling(
    #window=ROLLING_WINDOW, center=True, min_periods=1).min()
    #df_filtered_im['Smoothed_Enrichment_max'] = df_filtered_im['Smoothed_Enrichment'].rolling(
    #window=ROLLING_WINDOW, center=True, min_periods=1).max()
    # First, create a rolling max and min on original (non-log) Enrichment_Ratio
     # Compute rolling 25th and 75th percentiles from the raw Enrichment_Ratio
    # Calculate min and max Enrichment_Ratio for each position

    
    # Use your already calculated smoothed mean as center for CI
    #ci_upper = df_filtered_im['Smoothed_Enrichment'] + 1.96 * (rolling_std / np.sqrt(n))
    #ci_lower = df_filtered_im['Smoothed_Enrichment'] - 1.96 * (rolling_std / np.sqrt(n))
    
    ## Plot shaded CI area
    #ax.fill_between(
    #    df_filtered_im['Spike_AS_Position'],
    #    ci_lower,
    #    ci_upper,
    #    color=color_map.get(immunization, 'black'),
    #    alpha=0.2,
    #    step='mid'
   # )

    # Plot
    line_color = color_map.get(immunization, 'none') 
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        #show_col="show_site",
        ax=ax,
        linewidth=2.5,
        color=line_color
    )
    if immunization in faint_immunizations:
        ax.lines[-1].set_alpha(faint_alpha)
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    # Highlight sites
    #highlight_sites = df_filtered_im[df_filtered_im['show_site']]
    #highlight_mask = df_filtered_im['show_site'].apply(lambda x: 1.0 if x else np.nan) # 1.0 where True, 0.0 where False
    #bar_height = 0.002  # adjust the height of the bars here

    #ax.fill_between(
    #    df_filtered_im['Spike_AS_Position'],
    #    -0.005,
    #    highlight_mask * bar_height,
    #    color='black',
    #    where=highlight_mask,
    #    alpha=0.5,
    #    step='mid'  # aligns fill to the positions nicely
    #)

    print(f"\n[INFO] Immunization: {immunization}")
    print(f"→ Positions: {df_filtered_im['Spike_AS_Position'].min()} to {df_filtered_im['Spike_AS_Position'].max()}")
    print(f"→ Mean Enrichment: {df_filtered_im['Enrichment_Ratio'].mean():.3f}")
    print(f"→ Mean Log10 Enrichment: {df_filtered_im['Log10_Enrichment'].mean():.3f}")
    print(f"→ Mean Smoothed Log10 Enrichment: {df_filtered_im['Smoothed_Enrichment'].mean():.3f}")

    if immunization == 'wildtype_RBD':
        print("\n[DEBUG] Wildtype enrichment values (360–540):")
        print(df_filtered_im[df_filtered_im['Spike_AS_Position'].between(360, 540)][
            ['Spike_AS_Position', 'Enrichment_Ratio', 'Log10_Enrichment', 'Smoothed_Enrichment']
        ])


    immunization_data[immunization] = df_filtered_im.copy()
    
# Y-axis limit (adjust if needed for log scale)
ax.set_ylim(bottom=-0.4,
            top=0.04)

plt.title('', fontsize=16)
plt.xlabel('Spike AA Position', fontsize=14)
plt.ylabel('Log10 AB binding (Median) \n All antibodies', fontsize=14)

from matplotlib.patches import Patch

# Create a square patch for the legend entry
epitope_patch = Patch(facecolor='orange', label='SARS-CoV-2 Spike antibody epitopes')


handles, labels = ax.get_legend_handles_labels()
# Legend (subset only to certain labels)
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

group_1_handles = [handles[labels.index(label)] for label in group_1_labels]
group_2_handles = [handles[labels.index(label)] for label in group_2_labels]

# Combine handles and labels with a custom order
handles = group_1_handles + group_2_handles + [epitope_patch]

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
labels = [label_map[label] for label in group_1_labels + group_2_labels] + ['SARS-CoV-2 Spike antibody epitopes']
fig.subplots_adjust(right=0.8)


# X-axis ticks
# Set x-ticks every 10 positions
# Major ticks every 10 positions
major_locator = MultipleLocator(15)
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{int(x)}"))

# Minor ticks every 2 positions (sub-ticks without labels)
minor_locator = MultipleLocator(2)
ax.xaxis.set_minor_locator(minor_locator)

# Enable minor ticks
ax.tick_params(axis='x', which='major', length=7, width=1.2)
ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')

ax.tick_params(axis='y', which='major', length=7, width=1.2)
ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

# Optionally rotate x labels
plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')
# Tick marks every 0.1
ax.yaxis.set_major_locator(MultipleLocator(0.2))


# Minor ticks every 0.1 or 0.05 (no labels)
ax.yaxis.set_minor_locator(MultipleLocator(0.1))
ax.grid(False, axis='y')   # Turn ON horizontal gridlines only
ax.grid(False, axis='x')  # Turn OFF vertical gridlines

# Label every 2nd minor tick (every 0.04) similarly or skip labeling minor ticks entirely
# If you want to label every 2nd minor tick:  # wider figure to fit legend on right
# Add a single horizontal gridline at y=0.3
#ax.axhline(y=0.3, color='gray', linestyle='--', linewidth=1, alpha=0.7)
#ax.axhline(y=0.1, color='gray', linestyle='--', linewidth=1, alpha=0.7)

ax.set_xlim(right=510)

# Set limits to your observed data range or desired axis range
vmin_neg = -1  # from your code
vmax_pos = 1

# Set numeric limits for your data range
ax.set_ylim(vmin_neg, vmax_pos)

# Set your custom scale on the y-axis with these limits
ax.set_yscale('piecewise_linear', vmin_neg=vmin_neg, vmax_pos=vmax_pos)
ax.spines['bottom'].set_visible(False)
ax.axhline(y=0, color="black", linewidth=1.2)

# Save

plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_log10_Median_AllAntibodies.png")
fig.set_size_inches(8, 5)
fig.tight_layout()
fig.savefig(plot_file_path, format='png')  # Save after resizing/layout
plt.show()


In [None]:
# Same as above, but the shaded band shows the range of the smoothed values instead of the 95% CI

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
%matplotlib inline
import scipy.stats as stats
import os
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))

def symmetric_log10(x, epsilon=1e-8, max_cap=1e8):
    """Signed, clipped log10: ``sign(x) * log10(clip(|x|, epsilon, max_cap))``.

    Parameters
    ----------
    x : scalar or array-like
        Value(s) to transform. In this notebook these are positive enrichment
        ratios, for which the result is simply a clipped ``log10(x)``.
    epsilon : float
        Lower bound applied to ``|x|`` before taking the log (avoids ``-inf``).
    max_cap : float
        Upper bound applied to ``|x|`` before taking the log.

    Returns
    -------
    numpy scalar or ndarray
        The signed log magnitude. Zero maps to 0 and NaN stays NaN. Note that the
        mapping is not monotonic through zero: e.g. ``-0.1`` maps to ``+1``.
    """
    magnitude = np.abs(x)
    bounded = np.minimum(np.maximum(magnitude, epsilon), max_cap)  # same as np.clip
    return np.log10(bounded) * np.sign(x)

pd.set_option('future.no_silent_downcasting', True)

# Smoothing / threshold parameters
ROLLING_WINDOW = 10  # centered rolling window, in rows (consecutive observed positions)
ENRICHMENT_THRESHOLD = 0  # compared against the smoothed log10 enrichment

# Keep Spike positions above 33 + 331 with a positive enrichment ratio.
# NOTE(review): 33 + 331 presumably combines a DMS-numbering offset with the RBD start — confirm.
# (Low-read samples are already removed where df_total is loaded.)
MIN_SPIKE_POSITION = 33 + 331
df_filtered = df_total[(df_total['Spike_AS_Position'] > MIN_SPIKE_POSITION) &
                       (df_total['Enrichment_Ratio'] > 0)]

# Drop stop codons ("*")
df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Sites to highlight. Stored as strings, so compare them against
# `Spike_AS_Position.astype(str)`, never against the raw integer column.
sites_to_show = [
    str(site)
    for site in (
        [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]  # RBD-ACE2 interface
        + list(range(394, 414))  # R21 peptide sequence
        + list(range(484, 505))  # R13 peptide sequence
    )
]

# Number of distinct barcodes per immunization (informational)
barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

# Median enrichment ratio per (immunization, position), pooled over all barcodes
# and amino acids. (A previous int-vs-str `show_site` column was computed here but
# always False and dropped by this aggregation; it is now computed after reindexing.)
df_filtered_agg = df_filtered.groupby(
    ['immunization', 'Spike_AS_Position'], as_index=False
).agg({
    'Enrichment_Ratio': 'median'
})

# Output directory for the per-immunization CSV exports
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Single figure shared by all immunizations
fig, ax = plt.subplots(figsize=(8, 6))

# Line / band colour per immunization
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Plot one smoothed median line (plus a shaded rolling min/max band) per
# immunization on the shared axes. 'Library_ctrl' and 'Neutralizing_Ab' are excluded.
SKIPPED_IMMUNIZATIONS = {'Library_ctrl', 'Neutralizing_Ab'}

# Lines for these immunizations are drawn semi-transparent
faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
faint_alpha = 0.5
faint_linewidth = 1.5  # NOTE(review): defined but not applied in this cell

# Processed per-position frame for every plotted immunization. Created once,
# before the loop: it used to be re-created inside the loop, so only the last
# immunization's frame survived.
immunization_data = {}

# Rolling window for smoothing the CI margin of error across positions
n = 10

for immunization in df_filtered_agg['immunization'].unique():
    if immunization in SKIPPED_IMMUNIZATIONS:
        continue

    print(immunization)
    # Boolean mask + .copy(): an independent frame, so the column assignments
    # below never write into a view of df_filtered_agg (SettingWithCopyWarning).
    df_filtered_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization].copy()

    print(f"Spike_AS_Position sample before aggregation:\n{df_filtered_im['Spike_AS_Position'].head()}")
    print(f"\nRows for immunization {immunization}: {df_filtered_im.shape[0]}")
    print(df_filtered_im[['Spike_AS_Position', 'Enrichment_Ratio']].head(10))

    # Fill missing medians from neighbouring rows
    # (fillna(method=...) is deprecated since pandas 2.1; bfill/ffill replace it).
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # Clipped log10 of the (positive) enrichment ratio
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(symmetric_log10)

    # Centered rolling mean over ROLLING_WINDOW consecutive rows (observed
    # positions, not a fixed number of residues), then fill any gaps.
    df_filtered_im['Smoothed_Enrichment'] = (
        df_filtered_im['Log10_Enrichment']
        .rolling(window=ROLLING_WINDOW, center=True, min_periods=1)
        .mean()
        .interpolate(method='linear', limit_direction='both')
    )

    df_filtered_im['High_Enrichment'] = (
        df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    ).astype(bool)

    # Save the per-position table (before gap positions are inserted)
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex onto every integer position between the first and last observed
    # one; unobserved positions become NaN rows.
    first_pos = df_filtered_im['Spike_AS_Position'].min()
    last_pos = df_filtered_im['Spike_AS_Position'].max()
    df_filtered_im = (
        df_filtered_im
        .set_index('Spike_AS_Position')
        .reindex(range(first_pos, last_pos + 1))
        .reset_index()
    )

    # Inserted rows are not "high enrichment"; flag the sites of interest
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    df_filtered_im['show_site'] = df_filtered_im['Spike_AS_Position'].astype(str).isin(sites_to_show)

    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))

    # Fill the smoothed curve across the inserted gap positions
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both')

    # --- 95% confidence-interval margin per position ---------------------------
    # Computed from the raw rows of df_filtered (not the per-position medians).
    # It is stored as 'Smoothed_CI_Margin' but NOT drawn: this figure shades the
    # rolling min/max range of the smoothed curve instead (fill_between below).
    df_im_raw = df_filtered[df_filtered['immunization'] == immunization].copy()
    df_im_raw['Log10_Enrichment'] = df_im_raw['Enrichment_Ratio'].apply(symmetric_log10)
    print(f"[DEBUG] Raw Log10 enrichment values (first 20 rows):\n{df_im_raw[['Spike_AS_Position', 'Log10_Enrichment']].head(20)}\n")

    grouped = df_im_raw.groupby('Spike_AS_Position')['Log10_Enrichment']
    std_per_pos = grouped.std()
    count_per_pos = grouped.count()

    print(f"[DEBUG] Raw std dev per position (first 20 values):\n{std_per_pos.head(20)}\n")
    print(f"[DEBUG] Counts per position (first 20 values):\n{count_per_pos.head(20)}\n")

    # Standard error of the mean (NaN where a position has a single value)
    sem_per_pos = std_per_pos / np.sqrt(count_per_pos)

    # Two-tailed 95% t critical value; degrees of freedom = count - 1, at least 1
    dof = (count_per_pos - 1).clip(lower=1)
    t_critical = pd.Series(stats.t.ppf(0.975, dof), index=dof.index)
    margin_of_error = t_critical * sem_per_pos

    # Median-smooth the margin across neighbouring positions, then fill the edges
    smoothed_margin = margin_of_error.rolling(window=n, center=True, min_periods=1).median()
    print(f"[DEBUG] Smoothed margin of error (first 20 values):\n{smoothed_margin.head(20)}\n")
    smoothed_margin = smoothed_margin.bfill().ffill()

    # Align the margin to df_filtered_im by position
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position')
    df_filtered_im['Smoothed_CI_Margin'] = smoothed_margin
    df_filtered_im = df_filtered_im.reset_index()
    df_filtered_im['Smoothed_CI_Margin'] = df_filtered_im['Smoothed_CI_Margin'].bfill().ffill()

    # --- Shaded band: rolling min/max of the smoothed curve ---------------------
    smoothed = df_filtered_im['Smoothed_Enrichment']
    rolling_min = smoothed.rolling(window=ROLLING_WINDOW, center=True, min_periods=1).min().bfill().ffill()
    rolling_max = smoothed.rolling(window=ROLLING_WINDOW, center=True, min_periods=1).max().bfill().ffill()

    ax.fill_between(
        df_filtered_im['Spike_AS_Position'],
        rolling_min,
        rolling_max,
        color=color_map.get(immunization, 'black'),
        alpha=0.1,
        zorder=5,
    )

    # Smoothed line (epitope-site highlighting via show_col is disabled here)
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        ax=ax,
        linewidth=2.5,
        color=color_map.get(immunization, 'black'),
    )
    # Assumes draw_line's data line is the last Line2D added to the axes
    if immunization in faint_immunizations:
        ax.lines[-1].set_alpha(faint_alpha)
    # Invisible proxy line that carries the legend label for this immunization
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    print(f"\n[INFO] Immunization: {immunization}")
    print(f"→ Positions: {df_filtered_im['Spike_AS_Position'].min()} to {df_filtered_im['Spike_AS_Position'].max()}")
    print(f"→ Mean Enrichment: {df_filtered_im['Enrichment_Ratio'].mean():.3f}")
    print(f"→ Mean Log10 Enrichment: {df_filtered_im['Log10_Enrichment'].mean():.3f}")
    print(f"→ Mean Smoothed Log10 Enrichment: {df_filtered_im['Smoothed_Enrichment'].mean():.3f}")

    if immunization == 'wildtype_RBD':
        print("\n[DEBUG] Wildtype enrichment values (360–540):")
        print(df_filtered_im[df_filtered_im['Spike_AS_Position'].between(360, 540)][
            ['Spike_AS_Position', 'Enrichment_Ratio', 'Log10_Enrichment', 'Smoothed_Enrichment']
        ])

    immunization_data[immunization] = df_filtered_im.copy()
    
# Imported here so this cell does not rely on imports from other cells
from matplotlib.patches import Patch
from matplotlib.ticker import MultipleLocator

# --- Axis labels (explicit Axes methods so they always target `ax`) ----------
ax.set_title('', fontsize=16)
ax.set_xlabel('Spike AA Position', fontsize=14)
ax.set_ylabel('Log10 AB binding (Median) \n All antibodies', fontsize=14)

# --- Legend entries (rendered by the separate "legend printing" cell) --------
# Square patch representing the epitope highlight
epitope_patch = Patch(facecolor='orange', label='SARS-CoV-2 Spike antibody epitopes')

# Desired legend order, split into groups
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    # NOTE(review): "B.1.135" looks like a transposition of the Beta lineage
    # "B.1.351" — confirm before publishing.
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}

# First handle per raw label (same pick as list.index). Immunizations that were
# not plotted are skipped with a warning instead of raising ValueError.
handle_by_label = {}
for existing_handle, existing_label in zip(*ax.get_legend_handles_labels()):
    handle_by_label.setdefault(existing_label, existing_handle)

legend_keys = [key for key in group_1_labels + group_2_labels if key in handle_by_label]
missing_keys = [key for key in group_1_labels + group_2_labels if key not in handle_by_label]
if missing_keys:
    print(f"[WARN] No plotted line for {missing_keys}; omitted from the legend.")

handles = [handle_by_label[key] for key in legend_keys] + [epitope_patch]
labels = [label_map[key] for key in legend_keys] + ['SARS-CoV-2 Spike antibody epitopes']

# --- X-axis ticks: labelled major ticks every 10, unlabelled minor ticks every 2
major_locator = MultipleLocator(10)
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{int(x)}"))
minor_locator = MultipleLocator(2)
ax.xaxis.set_minor_locator(minor_locator)

ax.tick_params(axis='x', which='major', length=7, width=1.2)
ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')
ax.tick_params(axis='y', which='major', length=7, width=1.2)
ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

# Horizontal, centred x tick labels
plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')

# --- Y-axis ticks: major every 0.2, unlabelled minor every 0.1 ---------------
ax.yaxis.set_major_locator(MultipleLocator(0.2))
ax.yaxis.set_minor_locator(MultipleLocator(0.1))

# No gridlines in either direction
ax.grid(False, axis='both')

# --- Limits ------------------------------------------------------------------
ax.set_xlim(right=510)

# Symmetric y-range. (An earlier set_ylim(-0.4, 0.04) was always overridden by this.)
vmin_neg = -1
vmax_pos = 1
ax.set_ylim(vmin_neg, vmax_pos)
# The custom 'piecewise_linear' y-scale is intentionally not applied to this figure.

# Hide the bottom spine and draw a solid baseline at y=0 instead
ax.spines['bottom'].set_visible(False)
ax.axhline(y=0, color="black", linewidth=1.2)

# --- Save --------------------------------------------------------------------
# The "_range" suffix stops this figure from overwriting the
# "..._log10_Median_AllAntibodies.png" written by the preceding median cell.
# NOTE(review): `immunization` is simply the last value iterated by the loop above.
plot_dir = os.path.expanduser("~/Desktop")
plot_file_path = os.path.join(plot_dir, f"{immunization}_log10_Median_AllAntibodies_range.png")
fig.set_size_inches(7, 4)
fig.tight_layout()  # also supersedes the former subplots_adjust(right=0.8)
fig.savefig(plot_file_path, format='png')  # save after resizing/layout
plt.show()


In [None]:
# Legend-only figure: renders the `handles` / `labels` built by the previous cell
# on a dedicated, axis-free figure so it can be exported on its own. (With the
# inline backend, calling plt.legend() after the previous cell's plt.show()
# drew the legend beside an empty, freshly created set of axes.)
from matplotlib import font_manager

legend_fig, legend_ax = plt.subplots(figsize=(4, 2))
legend_ax.axis('off')  # hide frame and ticks; only the legend is wanted
legend_ax.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='center',
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1,
)
plt.show()

## Code for generating median + CI/SD plots from all individual droplets

In [None]:
# Median per position, keeping only variants with an enrichment ratio (ER) > 1

In [None]:
# Single antibodies (individual barcodes), plotted together on one figure

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
%matplotlib inline
import os
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))


pd.set_option('future.no_silent_downcasting', True)

# Smoothing / threshold parameters
ROLLING_WINDOW = 15  # centered rolling window, in rows (consecutive observed positions)
# Compared against the smoothed *log10* enrichment, so 1 corresponds to a ratio of 10.
# NOTE(review): the ER > 1 filter below is on the raw ratio — confirm the intended units.
ENRICHMENT_THRESHOLD = 1

# Keep Spike positions above 33 + 331 whose enrichment ratio exceeds 1.
# NOTE(review): 33 + 331 presumably combines a DMS-numbering offset with the RBD start — confirm.
# (Low-read samples are already removed where df_total is loaded.)
MIN_SPIKE_POSITION = 33 + 331
df_filtered = df_total[(df_total['Spike_AS_Position'] > MIN_SPIKE_POSITION) &
                       (df_total['Enrichment_Ratio'] > 1)]

# Drop stop codons ("*")
df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Sites to highlight, stored as strings (compare with `Spike_AS_Position.astype(str)`)
sites_to_show = [
    str(site)
    for site in (
        [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]  # RBD-ACE2 interface
        + list(range(394, 414))  # R21 peptide sequence
        + list(range(484, 505))  # R13 peptide sequence
    )
]

# Number of distinct barcodes per immunization (informational)
barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

# Median enrichment ratio per (immunization, position, barcode), i.e. one value
# per position for each single antibody, pooled over amino acids. (A former
# int-vs-str `show_site` column here was always False and dropped by this agg.)
df_filtered_agg = df_filtered.groupby(
    ['immunization', 'Spike_AS_Position', 'barcode'], as_index=False
).agg({
    'Enrichment_Ratio': 'median'
})

# Output directory for the per-barcode CSV exports
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Single figure shared by all traces
fig, ax = plt.subplots(figsize=(8, 6))

# Line / band colour per immunization
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Imported here so this cell does not rely on imports from other cells
from matplotlib import font_manager
from matplotlib.patches import Patch
from matplotlib.ticker import FuncFormatter, MultipleLocator

# Draw one smoothed line (plus a shaded rolling min/max band) per single antibody
# (barcode) on the shared axes. 'Library_ctrl' and 'Neutralizing_Ab' are excluded.
SKIPPED_IMMUNIZATIONS = {'Library_ctrl', 'Neutralizing_Ab'}
MAX_BARCODES_PER_IMMUNIZATION = 5

# Lines for these immunizations are drawn semi-transparent
faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
faint_alpha = 0.3
faint_linewidth = 1.5  # NOTE(review): defined but not applied in this cell

# Processed frames: immunization_data[immunization][barcode]. Created once before
# the loop; it used to be reset per immunization and keyed only by immunization,
# so just the last barcode of the last iteration survived.
immunization_data = {}

# Highest smoothed value over ALL plotted traces (drives the y-axis top)
smoothed_max_overall = -np.inf

for immunization in df_filtered_agg['immunization'].unique():
    if immunization in SKIPPED_IMMUNIZATIONS:
        continue

    immunization_df = df_filtered[df_filtered['immunization'] == immunization]
    barcodes = list(immunization_df['barcode'].unique())[:MAX_BARCODES_PER_IMMUNIZATION]
    print(f"[INFO] Plotting first {MAX_BARCODES_PER_IMMUNIZATION} barcodes for immunization '{immunization}': {barcodes}")

    agg_im = df_filtered_agg[df_filtered_agg['immunization'] == immunization]
    plotted_any = False

    for barcode in barcodes:
        # Boolean mask instead of an f-string query: compares actual values, so
        # quotes in a barcode or a non-string barcode dtype cannot break the match.
        df_filtered_im = agg_im[agg_im['barcode'] == barcode].copy()
        if df_filtered_im.empty:
            continue

        print(f"Spike_AS_Position sample before aggregation:\n{df_filtered_im['Spike_AS_Position'].head()}")
        print(f"\nRows for immunization {immunization}: {df_filtered_im.shape[0]}")
        print(df_filtered_im[['Spike_AS_Position', 'Enrichment_Ratio']].head(10))

        # Fill missing medians from neighbouring rows (replaces deprecated fillna(method=...))
        df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

        # log10 of positive ratios; non-positive or missing values become NaN
        ratio = df_filtered_im['Enrichment_Ratio']
        df_filtered_im['Log10_Enrichment'] = np.log10(ratio.where(ratio > 0))

        # Centered rolling mean over ROLLING_WINDOW consecutive rows, then fill gaps
        df_filtered_im['Smoothed_Enrichment'] = (
            df_filtered_im['Log10_Enrichment']
            .rolling(window=ROLLING_WINDOW, center=True, min_periods=1)
            .mean()
            .interpolate(method='linear', limit_direction='both')
        )

        df_filtered_im['High_Enrichment'] = (
            df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
        ).astype(bool)

        # One CSV per (immunization, barcode); the old per-immunization filename
        # was overwritten by every successive barcode.
        csv_file_path = os.path.join(output_dir, f"{immunization}_{barcode}_data.csv")
        df_filtered_im.to_csv(csv_file_path, index=False)

        # One row per position (already unique from the upstream groupby; this is a
        # safeguard) and only the columns used below.
        df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
            'Smoothed_Enrichment': 'median',
            'High_Enrichment': 'any',
            'Log10_Enrichment': 'median',
            'Enrichment_Ratio': 'median'
        })

        # Reindex onto every integer position in the observed span; gaps become NaN
        # (the line is drawn with breaks there).
        first_pos = df_filtered_im['Spike_AS_Position'].min()
        last_pos = df_filtered_im['Spike_AS_Position'].max()
        df_filtered_im = (
            df_filtered_im
            .set_index('Spike_AS_Position')
            .reindex(range(first_pos, last_pos + 1))
            .reset_index()
        )

        # Inserted rows are not "high enrichment"; flag the sites of interest
        df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
        df_filtered_im['show_site'] = df_filtered_im['Spike_AS_Position'].astype(str).isin(sites_to_show)

        print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))

        # Shaded band: rolling min/max of the smoothed curve (NaN gaps skipped)
        smoothed = df_filtered_im['Smoothed_Enrichment']
        df_filtered_im['Smoothed_Enrichment_min'] = smoothed.rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1).min()
        df_filtered_im['Smoothed_Enrichment_max'] = smoothed.rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1).max()

        ax.fill_between(
            df_filtered_im['Spike_AS_Position'],
            df_filtered_im['Smoothed_Enrichment_min'],
            df_filtered_im['Smoothed_Enrichment_max'],
            color=color_map.get(immunization, 'black'),
            alpha=0.1,
        )

        dmslogo.line.draw_line(
            df_filtered_im,
            x_col="Spike_AS_Position",
            height_col="Smoothed_Enrichment",
            title="",
            xlabel="Spike AA Position",
            ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
            show_col="show_site",
            ax=ax,
            linewidth=2,
            color=color_map.get(immunization, 'black'),
        )
        # NOTE(review): assumes draw_line's data line is the last Line2D on the axes
        # even with show_col set — confirm against dmslogo.
        if immunization in faint_immunizations:
            ax.lines[-1].set_alpha(faint_alpha)

        plotted_any = True
        trace_max = df_filtered_im['Smoothed_Enrichment'].max()
        if pd.notna(trace_max):
            smoothed_max_overall = max(smoothed_max_overall, trace_max)

        print(f"\n[INFO] Immunization: {immunization}")
        print(f"→ Positions: {df_filtered_im['Spike_AS_Position'].min()} to {df_filtered_im['Spike_AS_Position'].max()}")
        print(f"→ Mean Enrichment: {df_filtered_im['Enrichment_Ratio'].mean():.3f}")
        print(f"→ Mean Log10 Enrichment: {df_filtered_im['Log10_Enrichment'].mean():.3f}")
        print(f"→ Mean Smoothed Log10 Enrichment: {df_filtered_im['Smoothed_Enrichment'].mean():.3f}")

        if immunization == 'wildtype_RBD':
            print("\n[DEBUG] Wildtype enrichment values (360–540):")
            print(df_filtered_im[df_filtered_im['Spike_AS_Position'].between(360, 540)][
                ['Spike_AS_Position', 'Enrichment_Ratio', 'Log10_Enrichment', 'Smoothed_Enrichment']
            ])

        immunization_data.setdefault(immunization, {})[barcode] = df_filtered_im.copy()

    # One invisible proxy line per immunization carries its legend label
    # (previously one was added per barcode).
    if plotted_any:
        ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

# --- Y-axis range: top from the highest smoothed value over ALL traces ---------
# (previously only the last barcode's trace was considered, clipping others)
if np.isfinite(smoothed_max_overall):
    ax.set_ylim(bottom=-0.05, top=smoothed_max_overall + 0.8)
else:
    ax.set_ylim(bottom=-0.05)

ax.set_title('', fontsize=14)
ax.set_xlabel('Spike AA Position', fontsize=16)
ax.set_ylabel('Log10 Antibody binding (Mean)', fontsize=18)

# --- Legend ------------------------------------------------------------------
# Square patch representing the epitope-site underlines
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')

group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    # NOTE(review): "B.1.135" looks like a transposition of the Beta lineage
    # "B.1.351" — confirm before publishing.
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}

# First handle per raw label; immunizations without a plotted line are skipped
# with a warning instead of raising ValueError.
handle_by_label = {}
for existing_handle, existing_label in zip(*ax.get_legend_handles_labels()):
    handle_by_label.setdefault(existing_label, existing_handle)

legend_keys = [key for key in group_1_labels + group_2_labels if key in handle_by_label]
missing_keys = [key for key in group_1_labels + group_2_labels if key not in handle_by_label]
if missing_keys:
    print(f"[WARN] No plotted line for {missing_keys}; omitted from the legend.")

handles = [handle_by_label[key] for key in legend_keys] + [epitope_patch]
labels = [label_map[key] for key in legend_keys] + ['SARS-CoV-2 Spike antibody epitopes']

ax.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='center left',          # anchor the legend's left edge ...
    bbox_to_anchor=(1, 0.5),    # ... just right of the axes, vertically centred
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1,
)

# --- X-axis ticks: labelled major every 10, unlabelled minor every 2 ---------
major_locator = MultipleLocator(10)
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{int(x)}"))
minor_locator = MultipleLocator(2)
ax.xaxis.set_minor_locator(minor_locator)

ax.tick_params(axis='x', which='major', length=7, width=1.2)
ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')
ax.tick_params(axis='y', which='major', length=7, width=1.2)
ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')

# --- Y-axis ticks: major every 0.1, labelled only at label_positions ---------
ax.yaxis.set_major_locator(MultipleLocator(0.1))

label_positions = [-0.1, 0, 0.1, 0.2, 0.4, 0.6, 0.8]


def _format_y_tick(y, _pos):
    """Return a one-decimal label for ticks listed in label_positions, else ''."""
    return f"{y:.1f}" if np.isclose(y, label_positions, atol=1e-3).any() else ""


ax.yaxis.set_major_formatter(FuncFormatter(_format_y_tick))

# Unlabelled minor ticks every 0.05
ax.yaxis.set_minor_locator(MultipleLocator(0.05))

# --- Save --------------------------------------------------------------------
# NOTE(review): `immunization` is the last value iterated by the loop (possibly a
# skipped one), not a description of the figure's content.
plot_dir = os.path.expanduser("~/Desktop")
plot_file_path = os.path.join(plot_dir, f"{immunization}_Line3_plot_final_log10.png")
fig.set_size_inches(14, 6)
fig.tight_layout()
fig.savefig(plot_file_path, format='png')  # save after resizing/layout
plt.show()

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
%matplotlib inline
import os
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))


pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 15  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 1  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 1)]

# Aggregate enrichment ratio by position
df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Define target sites
sites_to_show = map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
)
sites_to_show = list(sites_to_show)

barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

df_filtered_agg = df_filtered.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show)
)


# Aggregate filtered data by position and immunization before plotting
df_filtered_agg = df_filtered_agg.groupby(
    ['immunization', 'Spike_AS_Position', 'barcode'], as_index=False
).agg({
    'Enrichment_Ratio': 'median'
})

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create a single figure
fig, ax = plt.subplots(figsize=(8, 6))

# Define color mapping
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Loop through each immunization and plot on the same axes, excluding 'Library_ctrl'
for immunization in df_filtered_agg['immunization'].unique():
    immunization_data = {}
    if immunization == 'Library_ctrl':
        continue

    if immunization == 'Neutralizing_Ab':
        continue

    immunization_df = df_filtered[df_filtered['immunization'] == immunization]

    barcodes = np.random.choice(immunization_df['barcode'].unique(), size=min(3, immunization_df['barcode'].nunique()), replace=False)
    print(f"[INFO] Randomly selected 5 barcodes for immunization '{immunization}': {list(barcodes)}")

    
    for barcode in barcodes:
        df_filtered_im = df_filtered_agg.query(
            f'immunization == "{immunization}" and barcode == "{barcode}"'
        ).copy()
    
        if df_filtered_im.empty:
            continue  # Skip empty
    
        # Now each barcode-immunization pair has one row per Spike_AS_Position
        df_filtered_im = df_filtered_im.drop_duplicates(subset='Spike_AS_Position')
        faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
        faint_alpha = 0.3
        faint_linewidth = 1.5
    
        #if immunization == 'Polyclonal_Ab':  # Skip 'Library_ctrl'
         #   continue
        #if immunization == 'Neutralizing_Ab':  # Skip 'Library_ctrl'
         #   continue

    
        # Normalize by number of barcodes
        num_barcodes = barcode_counts.get(immunization, 0)
        #num_barcodes = df_filtered_agg[df_filtered_agg['immunization'] == immunization]['barcode'].nunique()
    
        print(f"Spike_AS_Position sample before aggregation:\n{df_filtered_im['Spike_AS_Position'].head()}")
        print(f"\nRows for immunization {immunization}: {df_filtered_im.shape[0]}")
        print(df_filtered_im[['Spike_AS_Position', 'Enrichment_Ratio']].head(10))
        
        #df_filtered_im_agg = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        #'Enrichment_Ratio': 'median'
        #})
       
        df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].fillna(method='bfill').fillna(method='ffill')
    
        # Safe log transform (ignore or remove zero/negative values)
        df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(
            lambda x: np.log10(x) if x > 0 else np.nan
        )
        
        # Apply smoothing on the log2 transformed values
        df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Log10_Enrichment'].rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1).mean()
        df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
            method='linear', limit_direction='both')
    
        # Identify high enrichment
        df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
        df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    
        if immunization == 'Neutralizing_Ab':
            df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()
    
            #df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].fillna(method='bfill').fillna(method='ffill')
    
        # Save CSV
        csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
        df_filtered_im.to_csv(csv_file_path, index=False)
    
        # Reindex
        # Ensure uniqueness by aggregating duplicates
        df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
            'Smoothed_Enrichment': 'median',
            'High_Enrichment': 'any',
            'Log10_Enrichment': 'median',
            'Enrichment_Ratio': 'median'
        })
        
        # Now reindex safely
        df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
            range(df_filtered_im['Spike_AS_Position'].min(),
                  df_filtered_im['Spike_AS_Position'].max() + 1)
        ).reset_index()

        # Interpolate numeric columns linearly
        df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')
        
        # For other relevant numeric columns you use for plotting, interpolate or fill as needed:
        df_filtered_im['Log10_Enrichment'] = df_filtered_im['Log10_Enrichment'].interpolate(method='linear', limit_direction='both')
        df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].interpolate(method='linear', limit_direction='both')
        
        # For boolean columns, fill missing with False (or appropriate default)
        df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
        
        # If you want, you can also do forward/backward fill to be more conservative:
        # df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].fillna(method='ffill').fillna(method='bfill')

        
        # Mark sites to highlight
        df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
        df_filtered_im = df_filtered_im.assign(
            show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
        )
    
        print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))
    
        # Rolling min and max for range area
        # Fixed (aligned min/max with smoothed mean)
        df_filtered_im['Smoothed_Enrichment_min'] = df_filtered_im['Smoothed_Enrichment'].rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1).min()
        df_filtered_im['Smoothed_Enrichment_max'] = df_filtered_im['Smoothed_Enrichment'].rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1).max()
    
    
        # Plot shaded area for range
        ax.fill_between(
            df_filtered_im['Spike_AS_Position'],
            df_filtered_im['Smoothed_Enrichment_min'],
            df_filtered_im['Smoothed_Enrichment_max'],
            color=color_map.get(immunization, 'black'),
            alpha=0.1
        )
    
        # Plot
        
        dmslogo.line.draw_line(
            df_filtered_im,
            x_col="Spike_AS_Position",
            height_col="Smoothed_Enrichment",
            title="",
            xlabel="Spike AA Position",
            ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
            show_col="show_site",
            ax=ax,
            linewidth=2,
            color=color_map.get(immunization, 'black')
        )
        if immunization in faint_immunizations:
            ax.lines[-1].set_alpha(faint_alpha)
        ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)
    
        # Highlight sites
        #highlight_sites = df_filtered_im[df_filtered_im['show_site']]
        #for _, site_data in highlight_sites.iterrows():
         #   ax.hlines(
         #       y=0,  # set near bottom
         #       xmin=site_data['Spike_AS_Position'] - 0.1,
         #       xmax=site_data['Spike_AS_Position'] + 0.1,
         #       color='black',
         #       linestyle='-',
         #       linewidth=6
         #   )
        print(f"\n[INFO] Immunization: {immunization}")
        print(f"→ Positions: {df_filtered_im['Spike_AS_Position'].min()} to {df_filtered_im['Spike_AS_Position'].max()}")
        print(f"→ Mean Enrichment: {df_filtered_im['Enrichment_Ratio'].mean():.3f}")
        print(f"→ Mean Log10 Enrichment: {df_filtered_im['Log10_Enrichment'].mean():.3f}")
        print(f"→ Mean Smoothed Log10 Enrichment: {df_filtered_im['Smoothed_Enrichment'].mean():.3f}")
    
        if immunization == 'wildtype_RBD':
            print("\n[DEBUG] Wildtype enrichment values (360–540):")
            print(df_filtered_im[df_filtered_im['Spike_AS_Position'].between(360, 540)][
                ['Spike_AS_Position', 'Enrichment_Ratio', 'Log10_Enrichment', 'Smoothed_Enrichment']
            ])
    
    
        immunization_data[immunization] = df_filtered_im.copy()
    
# Finalize the combined line plot drawn by the loop above: y-limits, axis labels,
# legend, tick layout, then resize, save and show.
from matplotlib import font_manager
from matplotlib.patches import Patch
from matplotlib.ticker import FuncFormatter, MultipleLocator

# Y-axis limit: use the maximum over every curve drawn on `ax`. After the loop,
# `df_filtered_im` only holds the *last* barcode processed, so basing the limit on
# it alone can clip taller curves of other barcodes/immunizations.
# NOTE(review): assumes every Line2D on `ax` is an enrichment curve or an empty
# legend proxy — confirm dmslogo's site markers are not Line2D objects above the data.
curve_y = [np.asarray(line.get_ydata(), dtype=float) for line in ax.lines]
curve_y = [y for y in curve_y if y.size]
y_top = (np.nanmax(np.concatenate(curve_y)) if curve_y
         else np.nanmax(df_filtered_im['Smoothed_Enrichment']))
ax.set_ylim(bottom=-0.05, top=y_top + 0.8)

plt.title('', fontsize=14)
plt.xlabel('Spike AA Position', fontsize=16)
plt.ylabel('Log10 Antibody binding (Mean)', fontsize=18)

# Square patch used as the legend entry for the highlighted epitope sites
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')

# The loop adds an empty proxy line labelled with the immunization name for each
# plotted curve; take the first proxy of each wanted immunization, in this order.
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}

proxy_handles, proxy_labels = ax.get_legend_handles_labels()
# An immunization with no plotted curve has no proxy; leave it out of the legend
# instead of failing with ValueError from list.index.
legend_keys = [key for key in group_1_labels + group_2_labels if key in proxy_labels]
handles = [proxy_handles[proxy_labels.index(key)] for key in legend_keys] + [epitope_patch]
labels = [label_map[key] for key in legend_keys] + ['SARS-CoV-2 Spike antibody epitopes']

plt.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='center left',          # anchor the legend's left edge at bbox_to_anchor
    bbox_to_anchor=(1, 0.5),    # outside the plot area on the right, vertically centered
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# X-axis: labelled major ticks every 10 positions, unlabelled minor ticks every 2
ax.xaxis.set_major_locator(MultipleLocator(10))
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{int(x)}"))
ax.xaxis.set_minor_locator(MultipleLocator(2))

ax.tick_params(axis='x', which='major', length=7, width=1.2, labelsize=14)
ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out', labelsize=14)

ax.tick_params(axis='y', which='major', length=7, width=1.2)
ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')

# Y-axis: major ticks every 0.1, but only the values in label_positions get a label
ax.yaxis.set_major_locator(MultipleLocator(0.1))
label_positions = [-0.1, 0, 0.1, 0.2, 0.4, 0.6, 0.8]
ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{y:.1f}" if any(np.isclose(y, lp, atol=1e-3) for lp in label_positions) else ""))

# Minor y ticks every 0.05
ax.yaxis.set_minor_locator(MultipleLocator(0.05))

# Save. NOTE(review): `immunization` here is the last value of the loop variable,
# not a property of the whole (multi-immunization) figure.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_Line3_plot_final_log10.png")
fig.set_size_inches(14, 6)
fig.tight_layout()
fig.savefig(plot_file_path, format='png')  # save after resizing/layout
plt.show()

In [None]:
# Median aggregation with ER > 1, plus a check on how much of each curve is interpolated

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
%matplotlib inline
import os
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))


pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 15  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 1  # Adjust based on data distribution

# Keep Spike positions above 364 (33 + 331) whose enrichment ratio is > 1.
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 1)]

# Drop stop codons ('*').
df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Spike positions to highlight, stored as strings because they are matched
# against Spike_AS_Position.astype(str) when plotting.
sites_to_show = [
    str(site)
    for site in (
        [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]  # RBD-ACE2 interface
        + list(range(394, 414))  # R21 peptide sequence
        + list(range(484, 505))  # R13 peptide sequence
    )
]

# Number of distinct barcodes per immunization
barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

# Median enrichment ratio per (immunization, position, barcode).
# The former `.assign(site_label=..., show_site=...)` step was removed: groupby/agg
# discarded both columns, and show_site compared int positions with string sites
# (always False).
df_filtered_agg = df_filtered.groupby(
    ['immunization', 'Spike_AS_Position', 'barcode'], as_index=False
).agg({
    'Enrichment_Ratio': 'median'
})

# Output directory for the per-barcode CSV files
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Single shared figure for all immunizations
fig, ax = plt.subplots(figsize=(8, 6))

# Line/band colour per immunization
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Plot up to N_BARCODES_PER_IMMUNIZATION randomly chosen barcodes per immunization
# onto the shared `ax`. For each barcode:
# - log10-transform its per-position median enrichment ratio;
# - smooth the result with a rolling mean;
# - skip the barcode if fewer than MIN_COVERAGE of the positions in its range carry
#   real data;
# - fill gaps by linear interpolation;
# - draw the curve plus a shaded rolling min/max band.
N_BARCODES_PER_IMMUNIZATION = 3
MIN_COVERAGE = 0.5  # minimum fraction of positions backed by real (non-interpolated) data
SKIPPED_IMMUNIZATIONS = {'Library_ctrl', 'Neutralizing_Ab'}
faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
faint_alpha = 0.3
faint_linewidth = 1.5

# Processed frame per immunization (the last plotted barcode wins). Created once,
# outside the loop: re-creating it per immunization discarded all earlier entries.
immunization_data = {}

for immunization in df_filtered_agg['immunization'].unique():
    if immunization in SKIPPED_IMMUNIZATIONS:
        continue

    immunization_df = df_filtered[df_filtered['immunization'] == immunization]

    # NOTE(review): unseeded global RNG, so the selected barcodes change every run.
    barcodes = np.random.choice(
        immunization_df['barcode'].unique(),
        size=min(N_BARCODES_PER_IMMUNIZATION, immunization_df['barcode'].nunique()),
        replace=False,
    )
    print(f"[INFO] Randomly selected {len(barcodes)} barcodes for immunization '{immunization}': {list(barcodes)}")

    for barcode in barcodes:
        # Boolean masks instead of DataFrame.query with an f-string: the query compared
        # against the *string* "<barcode>", which never matches a numeric barcode
        # column and breaks on values containing quotes.
        is_pair = ((df_filtered_agg['immunization'] == immunization)
                   & (df_filtered_agg['barcode'] == barcode))
        df_filtered_im = df_filtered_agg.loc[is_pair].copy()

        if df_filtered_im.empty:
            continue  # Skip empty

        # One row per Spike_AS_Position for this (immunization, barcode) pair
        df_filtered_im = df_filtered_im.drop_duplicates(subset='Spike_AS_Position')

        print(f"Spike_AS_Position sample before aggregation:\n{df_filtered_im['Spike_AS_Position'].head()}")
        print(f"\nRows for immunization {immunization}: {df_filtered_im.shape[0]}")
        print(df_filtered_im[['Spike_AS_Position', 'Enrichment_Ratio']].head(10))

        # fillna(method=...) is deprecated in pandas >= 2.1; bfill()/ffill() is equivalent
        df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

        # Safe log10 transform: non-positive or missing ratios become NaN
        ratio = df_filtered_im['Enrichment_Ratio']
        df_filtered_im['Log10_Enrichment'] = np.log10(ratio.where(ratio > 0))

        # Rolling mean of the log10 values (the window counts rows, i.e. observed
        # positions), then fill any remaining NaNs by interpolation
        df_filtered_im['Smoothed_Enrichment'] = (
            df_filtered_im['Log10_Enrichment']
            .rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()
            .interpolate(method='linear', limit_direction='both')
        )

        # ENRICHMENT_THRESHOLD is compared in log10 space
        df_filtered_im['High_Enrichment'] = (
            df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
        ).astype(bool)

        # One CSV per (immunization, barcode). With only the immunization in the
        # name, each barcode overwrote the previous barcode's file.
        csv_file_path = os.path.join(output_dir, f"{immunization}_{barcode}_data.csv")
        df_filtered_im.to_csv(csv_file_path, index=False)

        # Collapse to one row per position (also drops the immunization/barcode columns)
        df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
            'Smoothed_Enrichment': 'median',
            'High_Enrichment': 'any',
            'Log10_Enrichment': 'median',
            'Enrichment_Ratio': 'median'
        })

        # Coverage = observed positions / positions in the contiguous range they span
        full_pos_range = range(df_filtered_im['Spike_AS_Position'].min(),
                               df_filtered_im['Spike_AS_Position'].max() + 1)
        coverage_ratio = df_filtered_im['Spike_AS_Position'].nunique() / len(full_pos_range)
        if coverage_ratio < MIN_COVERAGE:
            print(f"Skipping barcode {barcode} due to low coverage ({coverage_ratio:.2f})")
            continue  # Skip plotting this barcode

        # Insert NaN rows for unobserved positions, then fill them by interpolation
        df_filtered_im = (
            df_filtered_im.set_index('Spike_AS_Position')
            .reindex(full_pos_range)
            .reset_index()
        )
        for col in ('Smoothed_Enrichment', 'Log10_Enrichment', 'Enrichment_Ratio'):
            df_filtered_im[col] = df_filtered_im[col].interpolate(method='linear', limit_direction='both')
        df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

        # Mark sites to highlight (sites_to_show holds position strings)
        df_filtered_im = df_filtered_im.assign(
            show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
        )
        print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))

        # Rolling min/max of the smoothed curve, drawn as a shaded band around it
        df_filtered_im['Smoothed_Enrichment_min'] = df_filtered_im['Smoothed_Enrichment'].rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1).min()
        df_filtered_im['Smoothed_Enrichment_max'] = df_filtered_im['Smoothed_Enrichment'].rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1).max()

        curve_color = color_map.get(immunization, 'black')
        ax.fill_between(
            df_filtered_im['Spike_AS_Position'],
            df_filtered_im['Smoothed_Enrichment_min'],
            df_filtered_im['Smoothed_Enrichment_max'],
            color=curve_color,
            alpha=0.1
        )

        dmslogo.line.draw_line(
            df_filtered_im,
            x_col="Spike_AS_Position",
            height_col="Smoothed_Enrichment",
            title="",
            xlabel="Spike AA Position",
            ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
            show_col="show_site",
            ax=ax,
            linewidth=2,
            color=curve_color
        )
        if immunization in faint_immunizations:
            ax.lines[-1].set_alpha(faint_alpha)
        # Empty proxy line so the legend can show one entry per immunization
        ax.plot([], [], color=curve_color, label=immunization)

        print(f"\n[INFO] Immunization: {immunization}")
        print(f"→ Positions: {df_filtered_im['Spike_AS_Position'].min()} to {df_filtered_im['Spike_AS_Position'].max()}")
        print(f"→ Mean Enrichment: {df_filtered_im['Enrichment_Ratio'].mean():.3f}")
        print(f"→ Mean Log10 Enrichment: {df_filtered_im['Log10_Enrichment'].mean():.3f}")
        print(f"→ Mean Smoothed Log10 Enrichment: {df_filtered_im['Smoothed_Enrichment'].mean():.3f}")

        if immunization == 'wildtype_RBD':
            print("\n[DEBUG] Wildtype enrichment values (360–540):")
            print(df_filtered_im[df_filtered_im['Spike_AS_Position'].between(360, 540)][
                ['Spike_AS_Position', 'Enrichment_Ratio', 'Log10_Enrichment', 'Smoothed_Enrichment']
            ])

        immunization_data[immunization] = df_filtered_im.copy()
    
# Finalize the combined line plot drawn by the loop above: y-limits, axis labels,
# legend, tick layout, then resize, save and show.
from matplotlib import font_manager
from matplotlib.patches import Patch
from matplotlib.ticker import FuncFormatter, MultipleLocator

# Y-axis limit: use the maximum over every curve drawn on `ax`. After the loop,
# `df_filtered_im` only holds the *last* barcode processed (possibly one that was
# skipped for low coverage), so basing the limit on it can clip other curves.
# NOTE(review): assumes every Line2D on `ax` is an enrichment curve or an empty
# legend proxy — confirm dmslogo's site markers are not Line2D objects above the data.
curve_y = [np.asarray(line.get_ydata(), dtype=float) for line in ax.lines]
curve_y = [y for y in curve_y if y.size]
y_top = (np.nanmax(np.concatenate(curve_y)) if curve_y
         else np.nanmax(df_filtered_im['Smoothed_Enrichment']))
ax.set_ylim(bottom=-0.05, top=y_top + 0.8)

plt.title('', fontsize=14)
plt.xlabel('Spike AA Position', fontsize=16)
plt.ylabel('Log10 Antibody binding (Mean)', fontsize=18)

# Square patch used as the legend entry for the highlighted epitope sites
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')

# The loop adds an empty proxy line labelled with the immunization name for each
# plotted barcode; take the first proxy of each wanted immunization, in this order.
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}

proxy_handles, proxy_labels = ax.get_legend_handles_labels()
# An immunization whose barcodes were all skipped (low coverage) has no proxy;
# leave it out of the legend instead of failing with ValueError from list.index.
legend_keys = [key for key in group_1_labels + group_2_labels if key in proxy_labels]
handles = [proxy_handles[proxy_labels.index(key)] for key in legend_keys] + [epitope_patch]
labels = [label_map[key] for key in legend_keys] + ['SARS-CoV-2 Spike antibody epitopes']

plt.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='center left',          # anchor the legend's left edge at bbox_to_anchor
    bbox_to_anchor=(1, 0.5),    # outside the plot area on the right, vertically centered
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# X-axis: labelled major ticks every 10 positions, unlabelled minor ticks every 2
ax.xaxis.set_major_locator(MultipleLocator(10))
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{int(x)}"))
ax.xaxis.set_minor_locator(MultipleLocator(2))

ax.tick_params(axis='x', which='major', length=7, width=1.2)
ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')

ax.tick_params(axis='y', which='major', length=7, width=1.2)
ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')

# Y-axis: major ticks every 0.1, but only the values in label_positions get a label
ax.yaxis.set_major_locator(MultipleLocator(0.1))
label_positions = [-0.1, 0, 0.1, 0.2, 0.4, 0.6, 0.8]
ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{y:.1f}" if any(np.isclose(y, lp, atol=1e-3) for lp in label_positions) else ""))

# Minor y ticks every 0.05
ax.yaxis.set_minor_locator(MultipleLocator(0.05))

# Save under a fixed, descriptive name. The old name was built from the leaked loop
# variable `immunization` and used the same pattern as the previous cell's plot,
# so the two figures could overwrite each other.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", "ERgt1_coverage_Line3_plot_final_log10.png")
fig.set_size_inches(14, 6)
fig.tight_layout()
fig.savefig(plot_file_path, format='png')  # save after resizing/layout
plt.show()

# Creating a 2×2 tile with four subplots

### With interpolation, 3 random barcodes, smoothing

In [None]:
# With a marker showing which plots are interpolated

# Plotting specific replicates:

In [None]:
# Combine four pre-rendered replicate PNGs into a single 2x2 figure with a shared
# legend on the right, and save it as combined_2x2_plot.png in png_dir.
import os
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
from matplotlib.patches import Patch

png_dir = r"/Users/lucaschlotheuber/Desktop/DMSnew"

filenames = [
    "replicate1_2x2_log10_median.png",
    "replicate2_2x2_log10_median.png",
    "replicate3_2x2_log10_median.png",
    "replicate4_2x2_log10_median.png"
]

# One slot per file; None marks a missing file so the other panels keep their
# position instead of shifting up.
images = []
for fname in filenames:
    path = os.path.join(png_dir, fname)
    if os.path.isfile(path):
        # Image.open is lazy and keeps the file handle open; copy() loads the pixels
        # so the handle can be closed right away.
        with Image.open(path) as img:
            images.append(img.copy())
    else:
        print(f"Warning: File not found {path}")
        images.append(None)

fig, axes = plt.subplots(2, 2, figsize=(12, 6))
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i < len(images) and images[i] is not None:
        ax.imshow(images[i])
        ax.set_aspect('equal')
        ax.margins(0)
    ax.axis('off')  # hide axes on image panels and empty panels alike

# Legend handles and labels
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}

handles = [
    Patch(color=color_map['Polyclonal_Ab'], label=label_map['Polyclonal_Ab']),
    Patch(color=color_map['wildtype_RBD'], label=label_map['wildtype_RBD']),
    Patch(color=color_map['Mutant_RBD'], label=label_map['Mutant_RBD']),
    Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')
]

# Legend outside the right side of the figure
fig.legend(
    handles=handles,
    loc='center left',
    bbox_to_anchor=(0.85, 0.5),
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# tight_layout() recomputes hspace, so the row-spacing tweak has to come after it
# (calling subplots_adjust first had no effect).
plt.tight_layout(rect=[0, 0, 0.85, 1])  # leave 15% space on the right for legend
fig.subplots_adjust(hspace=0.01)  # reduce space between rows

combined_path = os.path.join(png_dir, "combined_2x2_plot.png")
plt.savefig(combined_path, dpi=300, bbox_inches='tight')
plt.show()


### Plotting all replicates

In [None]:
# Build one 2x2 figure per group of four replicate PNGs, each with a shared legend
# on the right, and save them as combined_2x2_group<N>.png in png_dir.
import os
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
from matplotlib.patches import Patch

png_dir = r"/Users/lucaschlotheuber/Desktop/DMSnew"

# Five non-overlapping groups of four replicates: 1-4, 5-8, 9-12, 13-16, 17-20.
# (range(i, i+5) built overlapping groups of five; the fifth image of each group
# never fit into the 2x2 grid and was silently dropped.)
REPLICATES_PER_GROUP = 4
NUM_GROUPS = 5
replicate_groups = [
    list(range(start, start + REPLICATES_PER_GROUP))
    for start in range(1, NUM_GROUPS * REPLICATES_PER_GROUP + 1, REPLICATES_PER_GROUP)
]

# Legend setup
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}

legend_handles = [
    Patch(color=color_map['Polyclonal_Ab'], label=label_map['Polyclonal_Ab']),
    Patch(color=color_map['wildtype_RBD'], label=label_map['wildtype_RBD']),
    Patch(color=color_map['Mutant_RBD'], label=label_map['Mutant_RBD']),
    Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')
]

for group_idx, replicate_range in enumerate(replicate_groups, start=1):
    # One slot per replicate; None marks a missing file so every panel stays in its
    # replicate's position instead of shifting up.
    images = []
    filenames = [f"replicate{r}_2x2_log10_median.png" for r in replicate_range]

    for fname in filenames:
        path = os.path.join(png_dir, fname)
        if os.path.isfile(path):
            # Image.open is lazy and keeps the file handle open; copy() loads the
            # pixels so the handle is closed right away.
            with Image.open(path) as img:
                images.append(img.copy())
        else:
            print(f"Warning: File not found: {path}")
            images.append(None)

    fig, axes = plt.subplots(2, 2, figsize=(12, 6))
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        if i < len(images) and images[i] is not None:
            ax.imshow(images[i])
            ax.set_aspect('equal')
            ax.margins(0)
        ax.axis('off')  # hide axes on image panels and empty panels alike

    fig.legend(
        handles=legend_handles,
        loc='center left',
        bbox_to_anchor=(0.85, 0.5),
        title="Immunization",
        title_fontproperties=font_manager.FontProperties(weight='bold'),
        fontsize=11,
        frameon=False,
        handlelength=2,
        handleheight=1,
        markerscale=1
    )

    # tight_layout() recomputes hspace, so the row-spacing tweak must follow it
    plt.tight_layout(rect=[0, 0, 0.85, 1])
    fig.subplots_adjust(hspace=0.01)

    combined_path = os.path.join(png_dir, f"combined_2x2_group{group_idx}.png")
    plt.savefig(combined_path, dpi=300, bbox_inches='tight')

    plt.show()
    print(f"[✓] Saved: {combined_path}")


### For Manuscript

In [None]:
# Combine a hand-picked set of four replicate PNGs into a 2x2 manuscript figure
# with a shared legend on the right; the output filename lists the replicates used.
import os
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
from matplotlib.patches import Patch

# Directory where replicate images are stored
png_dir = r"/Users/lucaschlotheuber/Desktop"

# Custom list of replicate numbers (panel order: top-left, top-right, bottom-left, bottom-right)
replicates_to_plot = [2, 5, 6, 7]  # Custom 2x2

filenames = [f"replicate{r}_2x2_log10_median.png" for r in replicates_to_plot]

# One slot per replicate; None marks a missing file so each panel stays in its
# replicate's position instead of shifting up.
images = []
for fname in filenames:
    path = os.path.join(png_dir, fname)
    if os.path.isfile(path):
        # Image.open is lazy and keeps the file handle open; copy() loads the pixels
        # so the handle is closed right away.
        with Image.open(path) as img:
            images.append(img.copy())
    else:
        print(f"Warning: File not found: {path}")
        images.append(None)

# Legend
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}

legend_handles = [
    Patch(color=color_map['Polyclonal_Ab'], label=label_map['Polyclonal_Ab']),
    Patch(color=color_map['wildtype_RBD'], label=label_map['wildtype_RBD']),
    Patch(color=color_map['Mutant_RBD'], label=label_map['Mutant_RBD']),
    Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')
]

fig, axes = plt.subplots(2, 2, figsize=(12, 6))
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i < len(images) and images[i] is not None:
        ax.imshow(images[i])
        ax.set_aspect('equal')
        ax.margins(0)
    ax.axis('off')  # hide axes on image panels and empty panels alike

# Legend to the right of the panels
fig.legend(
    handles=legend_handles,
    loc='center left',
    bbox_to_anchor=(0.85, 0.5),
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# tight_layout() recomputes hspace, so the row-spacing tweak must follow it
plt.tight_layout(rect=[0, 0, 0.85, 1])
fig.subplots_adjust(hspace=0.01)

# Derive the filename from the replicates actually plotted; the hard-coded
# "replicates_11_12_18_8" did not match replicates_to_plot.
replicate_tag = "_".join(str(r) for r in replicates_to_plot)
output_path = os.path.join(png_dir, f"custom_2x2_replicates_{replicate_tag}.png")
plt.savefig(output_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"[✓] Saved: {output_path}")


### Above for publication

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
%matplotlib inline
import os
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))


pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 10  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 0  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 0)]

# Aggregate enrichment ratio by position
df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Define target sites
sites_to_show = map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
)
sites_to_show = list(sites_to_show)

barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

df_filtered_agg = df_filtered.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show)
)


# Aggregate filtered data by position and immunization before plotting
df_filtered_agg = df_filtered_agg.groupby(
    ['immunization', 'Spike_AS_Position', 'barcode'], as_index=False
).agg({
    'Enrichment_Ratio': 'median'
})

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

NUM_REPLICATES = 30

for replicate_idx in range(1, NUM_REPLICATES + 1):
    print(f"Generating replicate {replicate_idx}")

    y_min_global = np.inf
    y_max_global = -np.inf
    
    # Set seed for reproducibility but different sample each replicate
    #np.random.seed(42 + replicate_idx)
    
    # Create new figure and axis per replicate
    fig, ax = plt.subplots(figsize=(8, 6))
    
    # Define color mapping
    color_map = {
        'Polyclonal_Ab': 'darkorange',
        'Neutralizing_Ab': 'red',
        'wildtype_RBD': 'green',
        'Mutant_RBD': 'darkblue'
    }
    
    
    # Loop through each immunization and plot on the same axes, excluding 'Library_ctrl'
    for immunization in df_filtered_agg['immunization'].unique():
        immunization_data = {}
        if immunization == 'Library_ctrl':
            continue
    
        if immunization == 'Neutralizing_Ab':
            continue
    
        immunization_df = df_filtered[df_filtered['immunization'] == immunization]
    
        barcodes = np.random.choice(immunization_df['barcode'].unique(), size=min(1, immunization_df['barcode'].nunique()), replace=False)
        print(f"[INFO] Randomly selected 5 barcodes for immunization '{immunization}': {list(barcodes)}")
    
        
        for barcode in barcodes:
            df_filtered_im = df_filtered_agg.query(
                f'immunization == "{immunization}" and barcode == "{barcode}"'
            ).copy()
        
            if df_filtered_im.empty:
                continue  # Skip empty
        
            # Now each barcode-immunization pair has one row per Spike_AS_Position
            df_filtered_im = df_filtered_im.drop_duplicates(subset='Spike_AS_Position')
            faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
            faint_alpha = 0.4
            faint_linewidth = 1.5
        
            #if immunization == 'Polyclonal_Ab':  # Skip 'Library_ctrl'
             #   continue
            #if immunization == 'Neutralizing_Ab':  # Skip 'Library_ctrl'
             #   continue
    
        
            # Normalize by number of barcodes
            num_barcodes = barcode_counts.get(immunization, 0)
            #num_barcodes = df_filtered_agg[df_filtered_agg['immunization'] == immunization]['barcode'].nunique()
        
            print(f"Spike_AS_Position sample before aggregation:\n{df_filtered_im['Spike_AS_Position'].head()}")
            print(f"\nRows for immunization {immunization}: {df_filtered_im.shape[0]}")
            print(df_filtered_im[['Spike_AS_Position', 'Enrichment_Ratio']].head(10))
            
            #df_filtered_im_agg = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
            #'Enrichment_Ratio': 'median'
            #})
           
            df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].fillna(method='bfill').fillna(method='ffill')
        
            # Safe log transform (ignore or remove zero/negative values)
            df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(
                lambda x: np.log10(x) if x > 0 else np.nan
            )
            
            # Apply smoothing on the log2 transformed values
            df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Log10_Enrichment'].rolling(
                window=ROLLING_WINDOW, center=True, min_periods=1).mean()
            df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
                method='linear', limit_direction='both')
        
            # Identify high enrichment
            df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
            df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
        
            if immunization == 'Neutralizing_Ab':
                df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()
        
                #df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].fillna(method='bfill').fillna(method='ffill')
        
            # Save CSV
            csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
            df_filtered_im.to_csv(csv_file_path, index=False)
        
            # Reindex
            # Ensure uniqueness by aggregating duplicates
            df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
                'Smoothed_Enrichment': 'median',
                'High_Enrichment': 'any',
                'Log10_Enrichment': 'median',
                'Enrichment_Ratio': 'median'
            })
    
            # Reindex to get full range of positions expected
            full_pos_range = range(df_filtered_im['Spike_AS_Position'].min(), df_filtered_im['Spike_AS_Position'].max() + 1)
            
            # Count number of actual data points (before interpolation/filling)
            actual_positions = df_filtered_im['Spike_AS_Position'].nunique()
            
            # Total positions in full range
            total_positions = len(full_pos_range)
            
            # Calculate coverage ratio
            coverage_ratio = actual_positions / total_positions
            
            # Define minimum coverage threshold, e.g. 50%
            MIN_COVERAGE = 0.5
            
            if coverage_ratio < MIN_COVERAGE:
                print(f"Skipping barcode {barcode} due to low coverage ({coverage_ratio:.2f})")
                continue  # Skip plotting this barcode
    
            
            # Now reindex safely
            df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
                range(df_filtered_im['Spike_AS_Position'].min(),
                      df_filtered_im['Spike_AS_Position'].max() + 1)
            ).reset_index()
    
            # Interpolate numeric columns linearly
            df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')
            
            # For other relevant numeric columns you use for plotting, interpolate or fill as needed:
            df_filtered_im['Log10_Enrichment'] = df_filtered_im['Log10_Enrichment'].interpolate(method='linear', limit_direction='both')
            df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].interpolate(method='linear', limit_direction='both')
            
            # For boolean columns, fill missing with False (or appropriate default)
            df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
            
            # If you want, you can also do forward/backward fill to be more conservative:
            # df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].fillna(method='ffill').fillna(method='bfill')
    
            
            # Mark sites to highlight
            df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
            df_filtered_im = df_filtered_im.assign(
                show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
            )
        
            print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))
        
            n = 10  # rolling window for smoothing
            
            # Filter raw data for immunization
            df_im_raw = df_filtered[df_filtered['immunization'] == immunization].copy()
            
            # Transform enrichment ratio safely
            df_im_raw['Log10_Enrichment'] = df_im_raw['Enrichment_Ratio'].apply(symmetric_log10)
            print(f"[DEBUG] Raw Log10 enrichment values (first 20 rows):\n{df_im_raw[['Spike_AS_Position', 'Log10_Enrichment']].head(20)}\n")
            
            # Calculate std dev and count per position
            grouped = df_im_raw.groupby('Spike_AS_Position')['Log10_Enrichment']
            std_per_pos = grouped.std()
            count_per_pos = grouped.count()
            
            print(f"[DEBUG] Raw std dev per position (first 20 values):\n{std_per_pos.head(20)}\n")
            print(f"[DEBUG] Counts per position (first 20 values):\n{count_per_pos.head(20)}\n")
            
            # Calculate standard error of the mean (SEM)
            sem_per_pos = std_per_pos / np.sqrt(count_per_pos)
            
            # Calculate t-critical value for 95% confidence interval (two-tailed)
            # Degrees of freedom = count - 1, minimum 1 to avoid div by zero
            dof = count_per_pos - 1
            dof[dof < 1] = 1
            t_critical = dof.apply(lambda df: stats.t.ppf(0.975, df))  # 0.975 for two-tailed 95%
            
            # Margin of error = t-critical * SEM
            margin_of_error = t_critical * sem_per_pos
            
            # Smooth the margin of error across positions
            smoothed_margin = margin_of_error.rolling(window=n, center=True, min_periods=1).median()
            print(f"[DEBUG] Smoothed margin of error (first 20 values):\n{smoothed_margin.head(20)}\n")
            
            # Fill NaNs at edges if any
            smoothed_margin = smoothed_margin.fillna(method='bfill').fillna(method='ffill')
            
            # Join smoothed margin of error back to df_filtered_im
            df_filtered_im = df_filtered_im.set_index('Spike_AS_Position')
            df_filtered_im['Smoothed_CI_Margin'] = smoothed_margin
            df_filtered_im = df_filtered_im.reset_index()
            
            # Fill any remaining NaNs
            df_filtered_im['Smoothed_CI_Margin'] = df_filtered_im['Smoothed_CI_Margin'].fillna(method='bfill').fillna(method='ffill')
            
            # Use ± margin of error around smoothed mean
            std_upper = df_filtered_im['Smoothed_Enrichment'] + df_filtered_im['Smoothed_CI_Margin']
            std_lower = df_filtered_im['Smoothed_Enrichment'] - df_filtered_im['Smoothed_CI_Margin']
            
            # Plot shaded 95% CI area
            ax.fill_between(
                df_filtered_im['Spike_AS_Position'],
                std_lower,
                std_upper,
                color=color_map.get(immunization, 'black'),
                alpha=0.1,
                zorder=5
            )


            current_min = df_filtered_im['Smoothed_Enrichment'].min()
            current_max = df_filtered_im['Smoothed_Enrichment'].max()
            if current_min < y_min_global:
                y_min_global = current_min
            if current_max > y_max_global:
                y_max_global = current_max
        
            # Plot
            
            dmslogo.line.draw_line(
                df_filtered_im,
                x_col="Spike_AS_Position",
                height_col="Smoothed_Enrichment",
                title="",
                xlabel="Spike AA Position",
                ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
                #show_col="show_site",
                ax=ax,
                linewidth=3,
                color=color_map.get(immunization, 'black')
            )
            if immunization in faint_immunizations:
                ax.lines[-1].set_alpha(faint_alpha)
            ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)
        
            #Highlight sites
    
            print(f"\n[INFO] Immunization: {immunization}")
            print(f"→ Positions: {df_filtered_im['Spike_AS_Position'].min()} to {df_filtered_im['Spike_AS_Position'].max()}")
            print(f"→ Mean Enrichment: {df_filtered_im['Enrichment_Ratio'].mean():.3f}")
            print(f"→ Mean Log10 Enrichment: {df_filtered_im['Log10_Enrichment'].mean():.3f}")
            print(f"→ Mean Smoothed Log10 Enrichment: {df_filtered_im['Smoothed_Enrichment'].mean():.3f}")
        
            if immunization == 'wildtype_RBD':
                print("\n[DEBUG] Wildtype enrichment values (360–540):")
                print(df_filtered_im[df_filtered_im['Spike_AS_Position'].between(360, 540)][
                    ['Spike_AS_Position', 'Enrichment_Ratio', 'Log10_Enrichment', 'Smoothed_Enrichment']
                ])
        
        
            immunization_data[immunization] = df_filtered_im.copy()

    # Get y-axis limits
    y_min, y_max = ax.get_ylim()
    highlight_y = y_min + 0.05 * (y_max - y_min)  # 5% above bottom
    
    highlight_positions = [int(pos) for pos in sites_to_show]
    
    for pos in highlight_positions:
        ax.vlines(
            x=pos,
            ymin=-0.015,
            ymax=0,
            color='black',
            linewidth=3.5,
            alpha=0.9,
            zorder=10
        )
    # Y-axis limit (adjust if needed for log scale)
    #ax.set_ylim(bottom=-0.5,
    #            top=np.nanmax(df_filtered_im['Smoothed_Enrichment']) + 0.2)
    
    plt.title('', fontsize=14)
    plt.xlabel('Spike AA Position', fontsize=16)
    plt.ylabel('Log10 AB binding (Median)', fontsize=18)
    
    from matplotlib.patches import Patch
    
    # Create a square patch for the legend entry
    epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')
    
    
    handles, labels = ax.get_legend_handles_labels()
    # Legend (subset only to certain labels)
    group_1_labels = ['Polyclonal_Ab']
    group_2_labels = ['wildtype_RBD', 'Mutant_RBD']
    
    group_1_handles = [handles[labels.index(label)] for label in group_1_labels]
    group_2_handles = [handles[labels.index(label)] for label in group_2_labels]
    
    # Combine handles and labels with a custom order
    handles = group_1_handles + group_2_handles + [epitope_patch]
    
    label_map = {
        'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
        'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
        'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
    }
    labels = [label_map[label] for label in group_1_labels + group_2_labels] + ['SARS-CoV-2 Spike antibody epitopes']

    # X-axis ticks
    # Set x-ticks every 10 positions
    # Major ticks every 10 positions
    major_locator = MultipleLocator(10)
    ax.xaxis.set_major_locator(major_locator)
    ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{int(x)}"))
    
    # Minor ticks every 2 positions (sub-ticks without labels)
    minor_locator = MultipleLocator(2)
    ax.xaxis.set_minor_locator(minor_locator)
    
    # Enable minor ticks
    ax.tick_params(axis='x', which='major', length=7, width=1.2)
    ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')
    
    ax.tick_params(axis='y', which='major', length=7, width=1.2)
    ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')
    
    # Optionally rotate x labels
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')
    # Tick marks every 0.1
    ax.yaxis.set_major_locator(MultipleLocator(0.05))
    
    # Only label 0.0, 0.3, 0.6
    label_positions = [-0.6,-0.5,-0.4,-0.3,-0.2,-0.1,0,0.1,0.2, 0.3, 0.4]
    ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{y:.1f}" if any(np.isclose(y, lp, atol=1e-3) for lp in label_positions) else ""))

    padding = 0.01
    ax.set_ylim(bottom=y_min_global - padding, top=y_max_global + padding)
    
    # Optional minor ticks every 0.05
    ax.yaxis.set_minor_locator(MultipleLocator(0.05))
    ax.grid(True, axis='y')   # Turn ON horizontal gridlines only
    ax.grid(False, axis='x')  # Turn OFF vertical gridlines
    ax.set_xlim(right=510)
    # Save
    plot_file_path = os.path.join(
        r"/Users/lucaschlotheuber/Desktop",f"replicate{replicate_idx}_2x2_log10_median.png"
    )
    fig.set_size_inches(8, 5) 
    fig.tight_layout()
    fig.savefig(plot_file_path, format='png')  # Save after resizing/layout
    plt.show()

## Individual Log10 Median antibody binding fraction and escape fraction split plots

In [None]:
# Same random single-barcode replicate plots, split into enriched (top) and escape (bottom) panels

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, MultipleLocator, FuncFormatter
from matplotlib import font_manager
import matplotlib.gridspec as gridspec
from scipy import stats
from matplotlib.patches import Patch
%matplotlib inline

print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))

pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 10  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 0  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 0)]

df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Define target sites
sites_to_show = map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +
    list(range(394, 414)) +
    list(range(484, 505))
)
sites_to_show = list(sites_to_show)

barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

df_filtered_agg = df_filtered.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    #show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show)
)

# Aggregate filtered data by position and immunization
df_filtered_agg = df_filtered_agg.groupby(
    ['immunization', 'Spike_AS_Position', 'barcode'], as_index=False
).agg({'Enrichment_Ratio': 'median'})

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

NUM_REPLICATES = 30

for replicate_idx in range(1, NUM_REPLICATES + 1):
    print(f"Generating replicate {replicate_idx}")

    y_min_global = np.inf
    y_max_global = -np.inf

    # Create stacked subplots
    fig = plt.figure(figsize=(7, 5))
    gs = gridspec.GridSpec(2, 1, height_ratios=[1, 1], hspace=0.15)
    ax_top = fig.add_subplot(gs[0])
    ax_bottom = fig.add_subplot(gs[1], sharex=ax_top)
    plt.setp(ax_top.get_xticklabels(), visible=False)

    color_map = {
        'Polyclonal_Ab': 'darkorange',
        'Neutralizing_Ab': 'red',
        'wildtype_RBD': 'green',
        'Mutant_RBD': 'darkblue'
    }
    
    top_y_min_global = np.inf
    top_y_max_global = -np.inf
    bottom_y_min_global = np.inf
    bottom_y_max_global = -np.inf


    for immunization in df_filtered_agg['immunization'].unique():
        immunization_data = {}
        if immunization in ['Library_ctrl']:
            continue

        immunization_df = df_filtered[df_filtered['immunization'] == immunization]

        barcodes = np.random.choice(
            immunization_df['barcode'].unique(),
            size=min(1, immunization_df['barcode'].nunique()),
            replace=False
        )
        print(f"[INFO] Randomly selected barcodes for '{immunization}': {list(barcodes)}")

        for barcode in barcodes:
            df_filtered_im = df_filtered_agg.query(
                f'immunization == "{immunization}" and barcode == "{barcode}"'
            ).copy()

            if df_filtered_im.empty:
                continue

            df_filtered_im = df_filtered_im.drop_duplicates(subset='Spike_AS_Position')
            faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
            faint_alpha = 0.4
            faint_linewidth = 1.5

            df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].fillna(method='bfill').fillna(method='ffill')
            df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(
                lambda x: np.log10(x) if x > 0 else np.nan
            )

            # Split into top/bottom
            df_top = df_filtered_im[df_filtered_im['Log10_Enrichment'] > 0].copy()
            df_bottom = df_filtered_im[df_filtered_im['Log10_Enrichment'] < 0].copy()
            split_dfs = {'top': df_top, 'bottom': df_bottom}



            for key, df in split_dfs.items():
                if df.empty:
                    continue

                # Smooth enrichment
                df['Smoothed_Enrichment'] = df['Log10_Enrichment'].rolling(
                    window=ROLLING_WINDOW, center=True, min_periods=1
                ).mean().interpolate(method='linear', limit_direction='both')

                # Flag high enrichment
                df['High_Enrichment'] = (df['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD).fillna(False).astype(bool)

                # Aggregate duplicates per Spike_AS_Position
                df_agg = df.groupby('Spike_AS_Position', as_index=False).agg({
                    'Smoothed_Enrichment': 'median',
                    'High_Enrichment': 'any',
                    'Log10_Enrichment': 'median',
                    'Enrichment_Ratio': 'median'
                })

                # Reindex to full Spike position range
                full_range = range(df_filtered_im['Spike_AS_Position'].min(), df_filtered_im['Spike_AS_Position'].max() + 1)
                df_agg = df_agg.set_index('Spike_AS_Position').reindex(full_range).interpolate(method='linear', limit_direction='both').reset_index()

                split_dfs[key] = df_agg

                       # Update global min/max
            if not df_top.empty:
                top_y_min_global = min(top_y_min_global, df_top['Smoothed_Enrichment'].min())
                top_y_max_global = max(top_y_max_global, df_top['Smoothed_Enrichment'].max())
    
            if not df_bottom.empty:
                bottom_y_min_global = min(bottom_y_min_global, df_bottom['Smoothed_Enrichment'].min())
                bottom_y_max_global = max(bottom_y_max_global, df_bottom['Smoothed_Enrichment'].max())

            df_top = split_dfs['top']
            df_bottom = split_dfs['bottom']

            # Save CSV (original df)
            csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
            df_filtered_im.to_csv(csv_file_path, index=False)

            # Calculate confidence intervals
            df_im_raw = df_filtered[df_filtered['immunization'] == immunization].copy()
            df_im_raw['Log10_Enrichment'] = df_im_raw['Enrichment_Ratio'].apply(lambda x: np.log10(x) if x > 0 else np.nan)
            grouped = df_im_raw.groupby('Spike_AS_Position')['Log10_Enrichment']
            std_per_pos = grouped.std()
            count_per_pos = grouped.count()
            sem_per_pos = std_per_pos / np.sqrt(count_per_pos)
            dof = count_per_pos - 1
            dof[dof < 1] = 1
            t_critical = dof.apply(lambda df: stats.t.ppf(0.975, df))
            margin_of_error = t_critical * sem_per_pos
            smoothed_margin = margin_of_error.rolling(window=ROLLING_WINDOW, center=True, min_periods=1).median().fillna(method='bfill').fillna(method='ffill')

            df_filtered_im = df_filtered_im.set_index('Spike_AS_Position')
            df_filtered_im['Smoothed_CI_Margin'] = smoothed_margin
            df_filtered_im = df_filtered_im.reset_index().fillna(method='bfill').fillna(method='ffill')

            # Plot top/bottom separately
            for df, ax in zip([df_top, df_bottom], [ax_top, ax_bottom]):
                if df.empty:
                    continue

                ax.fill_between(
                    df['Spike_AS_Position'],
                    df['Smoothed_Enrichment'] - df_filtered_im['Smoothed_CI_Margin'],
                    df['Smoothed_Enrichment'] + df_filtered_im['Smoothed_CI_Margin'],
                    color=color_map.get(immunization, 'black'),
                    alpha=0.07,
                    zorder=5
                )

                # Use your existing plotting function
                dmslogo.line.draw_line(
                    df,
                    x_col='Spike_AS_Position',
                    height_col='Smoothed_Enrichment',
                    ax=ax,
                    linewidth=3.5,
                    color=color_map.get(immunization, 'black')
                )
            # Top axis limits
 
            # Top: arrow pointing up
            ax_top.set_ylabel(
                "Increased by IgG-binding\n(Enriched Fraction) →",
                fontsize=11,
                rotation=90,
                labelpad=15,  # adjust distance from axis
                ha='left',
                y=0.02
            )
            
            # Bottom: arrow pointing down
            ax_bottom.set_ylabel(
                "Decreased by IgG-binding\n← (Escape Fraction)",
                fontsize=11,
                rotation=90,
                labelpad=4,
                ha='left',
                y=0
            )
            ax_bottom.set_xlabel("Spike AA Position", fontsize=12)

            if immunization in faint_immunizations:
                ax_top.lines[-1].set_alpha(faint_alpha)
            ax_top.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

            immunization_data[immunization] = df_filtered_im.copy()


    # --- AFTER plotting all barcodes for all immunizations ---
    top_padding = 0.15
    bottom_padding = 0.15
    
    # Top always starts at 0
    if np.isfinite(top_y_max_global):
        ax_top.set_ylim(bottom=0, top=0.5)
    
    # Bottom always ends at 0
    if np.isfinite(bottom_y_min_global):
        ax_bottom.set_ylim(bottom=bottom_y_min_global - bottom_padding, top=0)


    # Highlight positions
    #highlight_positions = [int(pos) for pos in sites_to_show]
    #for ax in [ax_top, ax_bottom]:
     #  for pos in highlight_positions:
      #      ax.vlines(x=pos, ymin=-0.015, ymax=0, color='black', linewidth=3.5, alpha=0.9, zorder=10)
    ax_bottom.axhline(y=0, color='black', linewidth=2) 
    # X and Y axis formatting
    for ax in [ax_top]:
        ax.xaxis.set_major_locator(MultipleLocator(10))
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{int(x)}"))
        ax.tick_params(axis='y', which='major', length=7, width=1.2)
        ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')
        ax.yaxis.set_major_locator(MultipleLocator(0.1))
        ax.grid(True, axis='y')
        ax.grid(False, axis='x')

    for ax in [ax_bottom]:
        ax.xaxis.set_major_locator(MultipleLocator(10))
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{int(x)}"))
        ax.xaxis.set_minor_locator(MultipleLocator(2))
        ax.tick_params(axis='x', which='major', length=7, width=1.2)
        ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')
        ax.tick_params(axis='y', which='major', length=7, width=1.2)
        ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')
        ax.yaxis.set_major_locator(MultipleLocator(0.1))
        ax.yaxis.set_minor_locator(MultipleLocator(0.05))
        ax.grid(True, axis='y')
        ax.grid(False, axis='x')

    plt.setp(ax_top.xaxis.get_majorticklabels(), visible=False)
    plt.setp(ax_bottom.xaxis.get_majorticklabels(), rotation=0, ha='center')
    # Only show y=0 tick label on bottom axis
    for tick in ax_top.get_yticklabels():
        if tick.get_text() == '0':
            tick.set_visible(False)

    top_ticks = ax_top.get_yticks()
    # Hide only 0 for top axis
    top_tick_labels = ["" if t == 0 else str(round(t, 2)) for t in top_ticks]
    ax_top.set_yticklabels(top_tick_labels)
        
    ax_bottom.set_xlim(right=510)
    ax_top.set_xlabel("")   # Remove any top axis label
    ax_bottom.set_xlabel("Spike AA Position", fontsize=12)  # Only bottom
    plt.setp(ax_top.get_xticklabels(), visible=False)

    # Legend
    #epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')
    #handles, labels = ax_top.get_legend_handles_labels()
    #group_1_labels = ['Polyclonal_Ab']
    #group_2_labels = ['wildtype_RBD', 'Mutant_RBD']
    #group_1_handles = [handles[labels.index(label)] for label in group_1_labels]
    #group_2_handles = [handles[labels.index(label)] for label in group_2_labels]
    #handles = group_1_handles + group_2_handles + [epitope_patch]
    #label_map = {
    #    'Polyclonal_Ab': 'antiRBD pAB',
    #    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    #    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
    #}
    #labels = [label_map[label] for label in group_1_labels + group_2_labels] + ['SARS-CoV-2 Spike antibody epitopes']
    #ax_top.legend(handles=handles, labels=labels, fontsize=10)

    # Save
    plot_file_path = os.path.join(
        r"/Users/lucaschlotheuber/Desktop/DMSnew", f"replicate{replicate_idx}_2x2_log10_median.png"
    )
    fig.subplots_adjust(left=0.16) 
    fig.tight_layout()
    fig.savefig(plot_file_path, format='png')
    plt.show()


## Repeat for 5 other random barcodes/droplets:

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
from itertools import chain
%matplotlib inline
from matplotlib.ticker import MultipleLocator, FuncFormatter

import os
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))


pd.set_option('future.no_silent_downcasting', True)

# Window (in rows/positions) of the centred rolling mean used to smooth log10 enrichment
ROLLING_WINDOW = 15  # Adjust this window size to refine the smoothing
# Smoothed log10 enrichment above this value is flagged as High_Enrichment
ENRICHMENT_THRESHOLD = 0  # Adjust based on data distribution

# Keep Spike positions above 364 (33 + 331) with a strictly positive enrichment
# ratio; positivity is required for the log10 transform in the plotting loop.
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 0)]

# Drop stop-codon ("*") variants
df_filtered = df_filtered[df_filtered['Amino_Acid'] != "*"].copy()

# Sites to highlight, stored as strings; compare against Spike_AS_Position.astype(str)
sites_to_show = [
    str(site) for site in
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]  # RBD-ACE2 interface
    + list(range(394, 414))  # R21 peptide sequence
    + list(range(484, 505))  # R13 peptide sequence
]

# Number of distinct barcodes (droplets) per immunization
barcode_counts = df_filtered.groupby('immunization')['barcode'].nunique().to_dict()

fig, axs = plt.subplots(2, 2, figsize=(15, 12))
axs = axs.flatten()

# One median enrichment per (immunization, position, barcode).
# The previous site_label/show_site columns were dropped by this aggregation.
# show_site also compared ints with strings and was therefore always False.
# The highlight flag is recomputed per panel in the plotting loop.
df_filtered_agg = df_filtered.groupby(
    ['immunization', 'Spike_AS_Position', 'barcode'], as_index=False
).agg({
    'Enrichment_Ratio': 'median'
})

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Define color mapping
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Filled by the plotting loop: immunization -> per-position data of its plotted barcode
immunization_data = {}
# First four immunizations (one per panel of the 2x2 grid), excluding the
# library control and the neutralizing antibody
immunizations = [imm for imm in df_filtered_agg['immunization'].unique()
                 if imm not in ('Library_ctrl', 'Neutralizing_Ab')][:4]



# Seed for the per-immunization barcode draw. None keeps the previous behaviour
# (a fresh random barcode on every run). Set an int to make the selected
# barcodes, the saved CSVs and the figure reproducible.
BARCODE_SEED = None
rng = np.random.default_rng(BARCODE_SEED)

for i, immunization in enumerate(immunizations[:4]):  # Limit to 4 immunizations for 2x2 grid
    ax = axs[i]
    immunization_df = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Get barcodes for this immunization
    barcodes = immunization_df['barcode'].unique()
    if len(barcodes) == 0:
        print(f"[WARN] No barcodes for {immunization}, skipping.")
        continue

    # Pick one random barcode (droplet) per immunization
    barcode = rng.choice(barcodes)
    print(f"[INFO] Randomly selected barcode for immunization '{immunization}': {barcode}")

    df_filtered_im = immunization_df[immunization_df['barcode'] == barcode].copy()
    if df_filtered_im.empty:
        print(f"[WARN] No data for barcode {barcode} of immunization {immunization}")
        continue

    # Defensive: df_filtered_agg already has one row per (immunization, position, barcode)
    df_filtered_im = df_filtered_im.drop_duplicates(subset='Spike_AS_Position')

    # Number of barcodes for this immunization (not used below; kept for potential normalisation)
    num_barcodes = barcode_counts.get(immunization, 0)

    print(f"Spike_AS_Position sample before aggregation:\n{df_filtered_im['Spike_AS_Position'].head()}")
    print(f"\nRows for immunization {immunization}: {df_filtered_im.shape[0]}")
    print(df_filtered_im[['Spike_AS_Position', 'Enrichment_Ratio']].head(10))

    # Fill missing ratios from the next, then the previous, row.
    # bfill()/ffill() replace fillna(method=...), which is deprecated since pandas 2.1.
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # Safe log10 transform: non-positive ratios become NaN
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(
        lambda x: np.log10(x) if x > 0 else np.nan
    )

    # Centred rolling mean over ROLLING_WINDOW rows, then interpolate any remaining gaps
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Log10_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both'
    )

    # Flag positions whose smoothed log10 enrichment exceeds ENRICHMENT_THRESHOLD
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    # Save CSV (one file per immunization, overwritten on each run)
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Aggregate to one value per position
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Smoothed_Enrichment': 'median',
        'High_Enrichment': 'any',
        'Log10_Enrichment': 'median',
        'Enrichment_Ratio': 'median'
    })

    # Coverage check: fraction of positions present between the first and last observed position
    full_pos_range = range(df_filtered_im['Spike_AS_Position'].min(), df_filtered_im['Spike_AS_Position'].max() + 1)
    actual_positions = df_filtered_im['Spike_AS_Position'].nunique()
    total_positions = len(full_pos_range)
    coverage_ratio = actual_positions / total_positions

    if coverage_ratio < 0.5:
        print(f"Skipping barcode {barcode} due to low coverage ({coverage_ratio:.2f})")
        continue

    # Reindex onto the contiguous position range and interpolate the gaps
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(full_pos_range).reset_index()
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Log10_Enrichment'].interpolate(method='linear', limit_direction='both')
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].interpolate(method='linear', limit_direction='both')
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    # Mark sites to highlight (sites_to_show holds strings)
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    # Rolling min/max envelope for the shaded band
    df_filtered_im['Smoothed_Enrichment_min'] = df_filtered_im['Smoothed_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).min()
    df_filtered_im['Smoothed_Enrichment_max'] = df_filtered_im['Smoothed_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).max()

    # Shaded background
    ax.fill_between(
        df_filtered_im['Spike_AS_Position'],
        df_filtered_im['Smoothed_Enrichment_min'],
        df_filtered_im['Smoothed_Enrichment_max'],
        color=color_map.get(immunization, 'black'),
        alpha=0.1
    )

    # Plot line
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        ax=ax,
        linewidth=2,
        color=color_map.get(immunization, 'black')
    )

    if immunization in {'Polyclonal_Ab', 'Neutralizing_Ab'}:
        ax.lines[-1].set_alpha(0.3)

    # Invisible line that only carries the legend label for this immunization
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)
    ax.set_title(f"{immunization} - Barcode {barcode}", fontsize=12)
    ax.set_xlabel('Spike AA Position')
    ax.set_ylabel('Log10 Antibody binding')
    ax.set_xlim(df_filtered_im['Spike_AS_Position'].min(), df_filtered_im['Spike_AS_Position'].max())

    # Optional debug info
    print(f"\n[INFO] Immunization: {immunization}")
    print(f"→ Positions: {df_filtered_im['Spike_AS_Position'].min()} to {df_filtered_im['Spike_AS_Position'].max()}")
    print(f"→ Mean Enrichment: {df_filtered_im['Enrichment_Ratio'].mean():.3f}")
    print(f"→ Mean Log10 Enrichment: {df_filtered_im['Log10_Enrichment'].mean():.3f}")
    print(f"→ Mean Smoothed Log10 Enrichment: {df_filtered_im['Smoothed_Enrichment'].mean():.3f}")

    immunization_data[immunization] = df_filtered_im.copy()

from matplotlib.patches import Patch

# Annotate every panel that received a line in the loop above.
# Previously only the last `ax` was formatted. Panels left empty (skipped
# immunizations/barcodes, or fewer than four immunizations) are hidden.
plotted_axes = [panel for panel in axs if len(panel.lines) > 0]
for panel in axs:
    if len(panel.lines) == 0:
        panel.set_visible(False)

# Inclusive (start, end) runs of consecutive highlighted positions. The epitope
# bar covers only the listed sites, not one span from the first to the last site
# (that span also covered e.g. 414-454, none of which are in sites_to_show).
highlight_positions = sorted({int(pos) for pos in sites_to_show})
highlight_runs = []
for pos in highlight_positions:
    if highlight_runs and pos == highlight_runs[-1][1] + 1:
        highlight_runs[-1][1] = pos
    else:
        highlight_runs.append([pos, pos])

# Major y-ticks fall every 0.1; only these values get a label
label_positions = [-0.6, -0.4, -0.2, 0, 0.2, 0.4, 0.6, 0.8]

for panel in plotted_axes:
    # Fix the y-range first. The highlight height must come from the final limits;
    # it used to come from the auto limits and could land outside the visible range.
    panel.set_ylim(bottom=-0.5, top=0.5)
    y_min, y_max = panel.get_ylim()
    highlight_y = y_min + 0.05 * (y_max - y_min)  # 5% above bottom

    for start, end in highlight_runs:
        panel.hlines(
            y=highlight_y,
            xmin=start - 0.5,  # +/-0.5 so each position occupies its full slot
            xmax=end + 0.5,
            color='black',
            linewidth=6,  # thick line
            alpha=0.8,
            zorder=10
        )

    # Per-panel titles ("<immunization> - Barcode <id>") are kept
    panel.set_xlabel('Spike AA Position', fontsize=16)
    panel.set_ylabel('Log10 Antibody binding (Mean)', fontsize=18)

    # X-axis: labelled major ticks every 10 positions, unlabelled minor ticks every 2.
    # Locators/formatters are created per axis; matplotlib locators must not be shared.
    panel.xaxis.set_major_locator(MultipleLocator(10))
    panel.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f"{int(x)}"))
    panel.xaxis.set_minor_locator(MultipleLocator(2))

    panel.tick_params(axis='x', which='major', length=7, width=1.2)
    panel.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')
    panel.tick_params(axis='y', which='major', length=7, width=1.2)
    panel.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')
    plt.setp(panel.xaxis.get_majorticklabels(), rotation=0, ha='center')

    # Y-axis: major ticks every 0.1 (labelled only at label_positions), minor every 0.05
    panel.yaxis.set_major_locator(MultipleLocator(0.1))
    panel.yaxis.set_major_formatter(FuncFormatter(
        lambda y, _: f"{y:.1f}" if any(np.isclose(y, lp, atol=1e-3) for lp in label_positions) else ""
    ))
    panel.yaxis.set_minor_locator(MultipleLocator(0.05))

# Square patch representing the epitope bar in the legend
epitope_patch = Patch(facecolor='black', edgecolor='black', label='SARS-CoV-2 Spike antibody epitopes')

all_handles_labels = [panel.get_legend_handles_labels() for panel in axs]
all_handles, all_labels = zip(*all_handles_labels)

# Flatten lists
handles = list(chain.from_iterable(all_handles))
labels = list(chain.from_iterable(all_labels))

print("[DEBUG] Available legend labels:", labels)
# Legend (subset only to certain labels, in this order)
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}

# First handle per label. Immunizations that were skipped above are left out of the
# legend instead of raising ValueError from labels.index().
handle_for_label = {}
for handle, label in zip(handles, labels):
    handle_for_label.setdefault(label, handle)
wanted_labels = group_1_labels + group_2_labels
legend_keys = [key for key in wanted_labels if key in handle_for_label]
missing_keys = [key for key in wanted_labels if key not in handle_for_label]
if missing_keys:
    print(f"[WARN] No plotted line for {missing_keys}; omitted from legend.")

handles = [handle_for_label[key] for key in legend_keys] + [epitope_patch]
labels = [label_map[key] for key in legend_keys] + ['SARS-CoV-2 Spike antibody epitopes']

# Legend sits to the right of the last plotted panel
legend_ax = plotted_axes[-1] if plotted_axes else axs[-1]
legend_ax.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='center left',          # Legend anchored by its left edge ...
    bbox_to_anchor=(1, 0.5),    # ... just outside the axes on the right, vertically centred
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# Save. NOTE(review): absolute, user-specific path; the file name uses the last
# immunization iterated above although the figure contains all panels.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_Line3_plot_final_log10.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')  # Save after resizing/layout
plt.show()

### Statistics to confirm binding ratios and transformation: tables with raw data, log10, and rolling-mean-smoothed values

In [None]:
# Spot-check table: values at every 15th Spike position (360-540) taken from the
# per-immunization data plotted above.
summary_rows = []

target_positions = list(range(360, 541, 15))

for immunization, df_im in immunization_data.items():
    print(f"\n[INFO] Generating table values for: {immunization}")

    for target_pos in target_positions:
        hits = df_im.loc[df_im['Spike_AS_Position'] == target_pos]
        if hits.empty:
            print(f"  → Skipping position {target_pos} (not in data)")
            continue

        first_hit = hits.iloc[0]
        summary_rows.append(dict(
            Immunization=immunization,
            Target_Spike_Position=target_pos,
            Actual_Position_Used=first_hit['Spike_AS_Position'],
            Enrichment_Ratio=round(first_hit['Enrichment_Ratio'], 3),
            Log10_Enrichment=round(first_hit['Log10_Enrichment'], 3),
            Smoothed_Enrichment=round(first_hit['Smoothed_Enrichment'], 3),
        ))


In [None]:
# One row per (immunization, spot-check position)
summary_df = pd.DataFrame(data=summary_rows)


In [None]:
# Spot-check table (every 15th position, 360-540) with integer positions, built
# from the per-immunization data plotted above.
summary_rows = []
target_positions = list(range(360, 541, 15))

for immunization, df_im in immunization_data.items():
    print(f"\n[INFO] Generating table values for: {immunization}")

    for target_pos in target_positions:
        row_match = df_im[df_im['Spike_AS_Position'].eq(target_pos)]
        if not row_match.empty:
            row = row_match.iloc[0]
            summary_rows.append({
                'Immunization': immunization,
                'Target_Spike_Position': target_pos,
                'Actual_Position_Used': int(row['Spike_AS_Position']),
                **{column: round(row[column], 3)
                   for column in ('Enrichment_Ratio', 'Log10_Enrichment', 'Smoothed_Enrichment')},
            })
        else:
            print(f"  → Skipping position {target_pos} (not in data)")

summary_df = pd.DataFrame(summary_rows)

print("\n✅ Final Summary Table:")
print(summary_df.head(10))


In [None]:
# List which immunizations produced a plotted panel
print("[DEBUG] immunization_data keys:", [*immunization_data])


In [None]:
from IPython.display import display
import pandas as pd
import numpy as np

# Show complete tables: no row/column limits, no cell truncation, no line wrapping
for _option, _value in {
    'display.max_rows': None,
    'display.max_columns': None,
    'display.max_colwidth': None,
    'display.width': None,
    'display.expand_frame_repr': False,
}.items():
    pd.set_option(_option, _value)

# Every Spike position from 360 to 540 (inclusive) goes into the table
target_positions = list(range(360, 541))

summary_rows = []

def shorten_raw_values(vals, max_items=100):
    """Render a list of raw enrichment values as a compact display string.

    Lists with more than ``max_items`` elements are cut to their first
    ``max_items`` elements followed by ``", ...]"``. Shorter lists are rendered
    in full with ``str``.

    Args:
        vals: Table cell value, normally a list of floats.
        max_items: Number of elements shown before truncating (default 100).

    Returns:
        The display string, or ``vals`` unchanged when it is not a list.
    """
    if not isinstance(vals, list):
        return vals
    if len(vals) <= max_items:
        return str(vals)
    if max_items <= 0:
        # Nothing to show; avoid the malformed "[, ...]"
        return "[...]"
    return str(vals[:max_items])[:-1] + ", ...]"

for immunization, df_im in immunization_data.items():
    print(f"[INFO] Adding table rows for: {immunization}")

    for pos in target_positions:
        row_match = df_im[df_im['Spike_AS_Position'] == pos]
        if row_match.empty:
            continue

        row = row_match.iloc[0]

        # Raw per-variant enrichment ratios from the unsmoothed df_total (stop codons
        # excluded), floored at 1e-3. NOTE: these pool every barcode of the
        # immunization, whereas Smoothed_Enrichment comes from the single barcode
        # stored in immunization_data.
        raw_vals = df_total[
            (df_total['immunization'] == immunization) &
            (df_total['Spike_AS_Position'] == pos) &
            (df_total['Amino_Acid'] != "*")
        ]['Enrichment_Ratio'].clip(lower=1e-3).tolist()

        # Median computed once. It is NaN, with no empty-slice RuntimeWarning, when
        # the position has no raw values (e.g. a position filled in by interpolation).
        median_raw = np.median(raw_vals) if raw_vals else np.nan

        summary_rows.append({
            'Immunization': immunization,
            'Target_Spike_Position': pos,
            'Actual_Position_Used': int(row['Spike_AS_Position']),
            'Raw_Enrichment_Short': shorten_raw_values(raw_vals),
            # A median, not a mean (this column was formerly named Mean_Enrichment)
            'Median_Enrichment': round(median_raw, 3),
            'Log10_Enrichment': round(np.log10(median_raw), 3) if median_raw > 0 else np.nan,
            'Smoothed_Enrichment': round(row['Smoothed_Enrichment'], 3)
        })

# Final DataFrame
summary_df = pd.DataFrame(summary_rows)

# Centre all cells and give each header a minimum width
styled_summary = summary_df.style.set_table_styles([
    {"selector": "th", "props": [("text-align", "center")]},
    {"selector": "td", "props": [("text-align", "center")]},
    {"selector": "thead th.col0", "props": [("min-width", "90px")]},   # Immunization
    {"selector": "thead th.col1", "props": [("min-width", "60px")]},   # Target_Spike_Position
    {"selector": "thead th.col2", "props": [("min-width", "60px")]},   # Actual_Position_Used
    {"selector": "thead th.col3", "props": [("min-width", "450px")]},  # Raw_Enrichment_Short
    {"selector": "thead th.col4", "props": [("min-width", "80px")]},   # Median_Enrichment
    {"selector": "thead th.col5", "props": [("min-width", "100px")]},  # Log10_Enrichment
    {"selector": "thead th.col6", "props": [("min-width", "130px")]}   # Smoothed_Enrichment
])

print("\n✅ Final Summary Table:")
display(styled_summary)


In [None]:
from IPython.display import display
import pandas as pd
import numpy as np

# Show complete tables: no row/column limits, no cell truncation, no line wrapping
_display_options = {
    'display.max_rows': None,
    'display.max_columns': None,
    'display.max_colwidth': None,
    'display.width': None,
    'display.expand_frame_repr': False,
}
for _option, _value in _display_options.items():
    pd.set_option(_option, _value)

# Every Spike position from 360 to 540 (inclusive) goes into the table
target_positions = [*range(360, 541)]

summary_rows = []

def shorten_raw_values(vals, max_items=100):
    """Render a list of raw enrichment values as a compact display string.

    Lists with more than ``max_items`` elements are cut to their first
    ``max_items`` elements followed by ``", ...]"``. Shorter lists are rendered
    in full with ``str``.

    Args:
        vals: Table cell value, normally a list of floats.
        max_items: Number of elements shown before truncating (default 100).

    Returns:
        The display string, or ``vals`` unchanged when it is not a list.
    """
    if not isinstance(vals, list):
        return vals
    if len(vals) <= max_items:
        return str(vals)
    if max_items <= 0:
        # Nothing to show; avoid the malformed "[, ...]"
        return "[...]"
    return str(vals[:max_items])[:-1] + ", ...]"

for immunization, df_im in immunization_data.items():
    print(f"[INFO] Adding table rows for: {immunization}")

    for pos in target_positions:
        row_match = df_im[df_im['Spike_AS_Position'] == pos]
        if row_match.empty:
            continue

        row = row_match.iloc[0]

        # Raw per-variant enrichment ratios from the unsmoothed df_total (stop codons
        # excluded), floored at 1e-3. NOTE: these pool every barcode of the
        # immunization, whereas Smoothed_Enrichment comes from the single barcode
        # stored in immunization_data.
        raw_vals = df_total[
            (df_total['immunization'] == immunization) &
            (df_total['Spike_AS_Position'] == pos) &
            (df_total['Amino_Acid'] != "*")
        ]['Enrichment_Ratio'].clip(lower=1e-3).tolist()

        # Median computed once. It is NaN, with no empty-slice RuntimeWarning, when
        # the position has no raw values (e.g. a position filled in by interpolation).
        median_raw = np.median(raw_vals) if raw_vals else np.nan

        summary_rows.append({
            'Immunization': immunization,
            'Target_Spike_Position': pos,
            'Actual_Position_Used': int(row['Spike_AS_Position']),
            'Raw_Enrichment_Short': shorten_raw_values(raw_vals),
            'Median_Enrichment': round(median_raw, 3),
            'Log10_Enrichment': round(np.log10(median_raw), 3) if median_raw > 0 else np.nan,
            'Smoothed_Enrichment': round(row['Smoothed_Enrichment'], 3)
        })

# Final DataFrame
summary_df = pd.DataFrame(summary_rows)

# Centre all cells and give each header a minimum width
styled_summary = summary_df.style.set_table_styles([
    {"selector": "th", "props": [("text-align", "center")]},
    {"selector": "td", "props": [("text-align", "center")]},
    {"selector": "thead th.col0", "props": [("min-width", "90px")]},   # Immunization
    {"selector": "thead th.col1", "props": [("min-width", "60px")]},   # Target_Spike_Position
    {"selector": "thead th.col2", "props": [("min-width", "60px")]},   # Actual_Position_Used
    {"selector": "thead th.col3", "props": [("min-width", "450px")]},  # Raw_Enrichment_Short
    {"selector": "thead th.col4", "props": [("min-width", "80px")]},   # Median_Enrichment
    {"selector": "thead th.col5", "props": [("min-width", "100px")]},  # Log10_Enrichment
    {"selector": "thead th.col6", "props": [("min-width", "130px")]}   # Smoothed_Enrichment
])

print("\n✅ Final Summary Table:")
display(styled_summary)


### Control table with raw data and transformed data (all values)

In [None]:
import numpy as np
import pandas as pd

# Define region
region_start = 500
region_end = 510
region_name = f"{region_start}-{region_end}"

summary_rows_barcode = []

# Filter data for region, non-stop codons, positive enrichment
df_region = df_total[
    (df_total['Spike_AS_Position'] >= region_start) &
    (df_total['Spike_AS_Position'] <= region_end) &
    (df_total['Enrichment_Ratio'] > 0) &
    (df_total['Amino_Acid'] != "*")
].copy()

# Get unique immunizations and barcodes
immunizations = df_region['immunization'].unique()

for immunization in immunizations:
    df_im = df_region[df_region['immunization'] == immunization]
    barcodes = df_im['barcode'].unique()

    # Calculate mean enrichment per barcode in region
    for barcode in barcodes:
        df_bc = df_im[df_im['barcode'] == barcode]

        # Mean enrichment (raw) over region for barcode
        median_enrich = df_bc['Enrichment_Ratio'].median()

        # Log10 mean enrichment
        log10_median_enrich = np.log10(median_enrich) if median_enrich > 0 else np.nan

        summary_rows_barcode.append({
            'Immunization': immunization,
            'Barcode': barcode,
            'Region': region_name,
            'Median_Enrichment': round(median_enrich, 3),
            'Log10_Median_Enrichment': round(log10_median_enrich, 3)
        })

# Create DataFrame for barcode means
df_barcode_summary = pd.DataFrame(summary_rows_barcode)

# Calculate median over barcodes per immunization
median_summary = df_barcode_summary.groupby('Immunization').agg({
    'Median_Enrichment': 'median',
    'Log10_Median_Enrichment': 'median'
}).rename(columns={
    'Median_Enrichment': 'Median_Median_Enrichment_across_Barcodes',
    'Log10_Median_Enrichment': 'Median_Log10_Median_Enrichment_across_Barcodes'
}).reset_index()

print(f"\n✅ Per-barcode mean enrichment stats for region {region_name}:")
print(df_barcode_summary.head(100))

print(f"\n✅ Median mean enrichment across barcodes per immunization for region {region_name}:")
print(median_summary)

# Optionally, merge or display both tables side by side


In [None]:
import numpy as np
import pandas as pd

# Define region
region_start = 500
region_end = 510
region_name = f"{region_start}-{region_end}"

summary_rows_barcode_region = []
summary_rows_barcode_position = []

# Filter data for region, non-stop codons, positive enrichment
df_region = df_total[
    (df_total['Spike_AS_Position'] >= region_start) &
    (df_total['Spike_AS_Position'] <= region_end) &
    (df_total['Enrichment_Ratio'] > 0) &
    (df_total['Amino_Acid'] != "*")
].copy()

# Get unique immunizations and barcodes
immunizations = df_region['immunization'].unique()

for immunization in immunizations:
    df_im = df_region[df_region['immunization'] == immunization]
    barcodes = df_im['barcode'].unique()

    for barcode in barcodes:
        df_bc = df_im[df_im['barcode'] == barcode]

        # Median enrichment over the entire region per barcode
        median_enrich_region = df_bc['Enrichment_Ratio'].median()
        log10_median_enrich_region = np.log10(median_enrich_region) if median_enrich_region > 0 else np.nan

        summary_rows_barcode_region.append({
            'Immunization': immunization,
            'Barcode': barcode,
            'Region': region_name,
            'Median_Enrichment': round(median_enrich_region, 3),
            'Log10_Median_Enrichment': round(log10_median_enrich_region, 3)
        })

        # Median enrichment per position for this barcode
        for position in range(region_start, region_end + 1):
            df_bc_pos = df_bc[df_bc['Spike_AS_Position'] == position]

            if not df_bc_pos.empty:
                median_enrich_pos = df_bc_pos['Enrichment_Ratio'].median()
                log10_median_enrich_pos = np.log10(median_enrich_pos) if median_enrich_pos > 0 else np.nan
            else:
                median_enrich_pos = np.nan
                log10_median_enrich_pos = np.nan

            summary_rows_barcode_position.append({
                'Immunization': immunization,
                'Barcode': barcode,
                'Position': position,
                'Median_Enrichment': round(median_enrich_pos, 3) if not np.isnan(median_enrich_pos) else np.nan,
                'Log10_Median_Enrichment': round(log10_median_enrich_pos, 3) if not np.isnan(log10_median_enrich_pos) else np.nan
            })

# Create DataFrames
df_barcode_summary_region = pd.DataFrame(summary_rows_barcode_region)
df_barcode_summary_position = pd.DataFrame(summary_rows_barcode_position)

# Sort the position summary by Immunization then by Position ascending
df_barcode_summary_position = df_barcode_summary_position.sort_values(by=['Immunization', 'Position']).reset_index(drop=True)

# Final summary: median of medians
median_summary_region = df_barcode_summary_region.groupby('Immunization').agg({
    'Median_Enrichment': 'median',
    'Log10_Median_Enrichment': 'median'
}).rename(columns={
    'Median_Enrichment': 'Median_Enrichment_across_Barcodes',
    'Log10_Median_Enrichment': 'Median_Log10_Enrichment_across_Barcodes'
}).reset_index()

median_summary_position = df_barcode_summary_position.groupby(['Immunization', 'Position']).agg({
    'Median_Enrichment': 'median',
    'Log10_Median_Enrichment': 'median'
}).rename(columns={
    'Median_Enrichment': 'Median_Enrichment_across_Barcodes',
    'Log10_Median_Enrichment': 'Median_Log10_Enrichment_across_Barcodes'
}).reset_index()

# Output
print(f"\n✅ Per-barcode **median** enrichment stats for region {region_name}:")
print(df_barcode_summary_region.head(1000))

print(f"\n✅ Median enrichment across barcodes per immunization for region {region_name}:")
print(median_summary_region)

print(f"\n✅ Per-barcode **median** enrichment stats per position for region {region_name} (sorted by position):")
print(df_barcode_summary_position.head(1000))

print(f"\n✅ Median enrichment across barcodes per immunization per position for region {region_name}:")
print(median_summary_position)


In [None]:
# Immunizations present in the aggregated per-barcode table
_agg_immunizations = df_filtered_agg['immunization'].unique()
print("From df_filtered_agg:", _agg_immunizations)


In [None]:
# Per-immunization table at fixed positions using the nearest observed position.
# Statistics pool all barcodes: mean raw ratio (floored at 1e-3), its log10, and
# a centred rolling mean of the log10 values.
summary_rows = []

print("Available immunizations in df_total:")
print(df_total['immunization'].dropna().unique())

# Fixed spot-check positions: every 15th position from 360 to 540
target_positions = list(range(360, 541, 15))

valid_immunizations = df_total['immunization'].dropna().unique()
_skipped_immunizations = ('Library_ctrl', 'Neutralizing_Ab')

for immunization in valid_immunizations:

    if immunization in _skipped_immunizations:
        continue

    print(f"\n[INFO] Processing immunization: '{immunization}'")
    df_im = df_total[df_total['immunization'] == immunization].copy()
    print(f"  → Raw rows: {len(df_im)}")

    df_im = df_im[df_im['Amino_Acid'] != "*"]
    print(f"  → After filtering '*': {len(df_im)}")

    df_im = df_im.assign(Enrichment_Ratio=df_im['Enrichment_Ratio'].clip(lower=1e-3))

    if df_im.empty:
        print(f"[WARN] No data for {immunization} after filtering.")
        continue

    by_position = df_im.groupby('Spike_AS_Position')['Enrichment_Ratio']
    mean_enrichment = by_position.mean()
    log10_enrichment = mean_enrichment.apply(lambda x: np.log10(x) if x > 0 else np.nan)

    summary_data = pd.DataFrame({
        'Mean_Enrichment': mean_enrichment,
        'Raw_Enrichment_Values': by_position.apply(list),
        'Log10_Enrichment': log10_enrichment,
        'Smoothed_Enrichment': log10_enrichment.rolling(
            window=ROLLING_WINDOW, center=True, min_periods=1).mean(),
    }).reset_index()

    print(f"  → Positions in summary_data: {summary_data['Spike_AS_Position'].min()} to {summary_data['Spike_AS_Position'].max()}")

    for target_pos in target_positions:
        if summary_data.empty:
            continue

        # Nearest observed position (first one on ties)
        distance = (summary_data['Spike_AS_Position'] - target_pos).abs()
        nearest = summary_data.loc[distance.idxmin()]

        print(f"    Target pos {target_pos} → Closest actual pos {nearest['Spike_AS_Position']} (diff {distance.min()})")

        summary_rows.append({
            'Immunization': immunization,
            'Target_Spike_Position': target_pos,
            'Actual_Position_Used': int(nearest['Spike_AS_Position']),
            'Raw_Enrichment_Values': nearest['Raw_Enrichment_Values'],
            'Mean_Enrichment': round(nearest['Mean_Enrichment'], 3),
            'Log10_Enrichment': round(nearest['Log10_Enrichment'], 3),
            'Smoothed_Enrichment': round(nearest['Smoothed_Enrichment'], 3)
        })

summary_df = pd.DataFrame(summary_rows)

print("\nImmunizations included in summary_df:", summary_df['Immunization'].unique())
print(summary_df.head(10))


In [None]:
import pandas as pd
# Show complete tables: no row/column limits, no cell truncation, no line wrapping
for _option, _value in [
    ('display.max_rows', None),
    ('display.max_columns', None),
    ('display.max_colwidth', None),
    ('display.width', None),
    ('display.expand_frame_repr', False),
]:
    pd.set_option(_option, _value)
from IPython.display import display
display(summary_df)

### Specific check for WT values

In [None]:
def shorten_raw_values(lst, max_items=3):
    """Render a list of raw enrichment values as a compact display string.

    Lists with more than ``max_items`` elements are cut to their first
    ``max_items`` elements followed by ``", ...]"``. Shorter lists are rendered
    in full with ``str``. The limit used to be hard-coded here, while earlier
    definitions of this function in the notebook used a different limit.

    Args:
        lst: Table cell value, normally a list of floats.
        max_items: Number of elements shown before truncating (default 3).

    Returns:
        The display string, or ``lst`` unchanged when it is not a list.
    """
    if not isinstance(lst, list):
        return lst
    if len(lst) <= max_items:
        return str(lst)
    if max_items <= 0:
        # Nothing to show; avoid the malformed "[, ...]"
        return "[...]"
    return str(lst[:max_items])[:-1] + ", ...]"

summary_df['Raw_Enrichment_Short'] = summary_df['Raw_Enrichment_Values'].apply(shorten_raw_values)

# Now display with the shortened column instead of the full list
display(summary_df.drop(columns=['Raw_Enrichment_Values']))


## Line plot version with Double-aggregation: Median and Mean (not used)

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
from matplotlib import font_manager
from matplotlib.ticker import MultipleLocator, FuncFormatter

%matplotlib inline
import os
print("Current working dir:", os.getcwd())
print("Files in output_dir:", os.listdir(output_dir))


pd.set_option('future.no_silent_downcasting', True)

# Define rolling window size for smoothing
ROLLING_WINDOW = 15  # Adjust this window size to refine the smoothing
ENRICHMENT_THRESHOLD = 0  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 0)]

# Aggregate enrichment ratio by position
df_aa_agg = df_total[df_total['Amino_Acid'] != "*"].copy()
df_aa_agg['Enrichment_Ratio'] = df_aa_agg['Enrichment_Ratio'].clip(lower=1e-3)
df_aa_agg = df_aa_agg.groupby(['Spike_AS_Position', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'median'
})


# Define target sites
sites_to_show = map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
)
sites_to_show = list(sites_to_show)

df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show)
)

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Create a single figure
fig, ax = plt.subplots(figsize=(8, 6))

# Define color mapping
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Loop through each immunization and plot on the same axes, excluding 'Library_ctrl'
for immunization in df_filtered_agg['immunization'].unique():
    if immunization == 'Library_ctrl':
        continue

    if immunization == 'Neutralizing_Ab':
        continue
    faint_immunizations = {'Polyclonal_Ab', 'Neutralizing_Ab'}
    faint_alpha = 0.3
    faint_linewidth = 1.5

    #if immunization == 'Polyclonal_Ab':  # Skip 'Library_ctrl'
     #   continue
    #if immunization == 'Neutralizing_Ab':  # Skip 'Library_ctrl'
     #   continue

    print(immunization)
    df_filtered_im = df_filtered_agg.query(f'immunization == "{immunization}"')

    # Normalize by number of barcodes
    num_barcodes = df_filtered_agg[df_filtered_agg['immunization'] == immunization]['barcode'].nunique()
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'mean'
    })
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].fillna(method='bfill').fillna(method='ffill')

    # Safe log transform (ignore or remove zero/negative values)
    df_filtered_im['Log10_Enrichment'] = df_filtered_im['Enrichment_Ratio'].apply(
        lambda x: np.log10(x) if x > 0 else np.nan
    )
    
    # Apply smoothing on the log2 transformed values
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Log10_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).mean()
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(
        method='linear', limit_direction='both')

    # Identify high enrichment
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    if immunization == 'Neutralizing_Ab':
        df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].fillna(method='bfill').fillna(method='ffill')
    # Print what's being plotted
    print(f"\n[Every 5th Position Smoothed Enrichment for {immunization}]")
    subset = df_filtered_im[df_filtered_im['Spike_AS_Position'] % 5 == 0][['Spike_AS_Position', 'Smoothed_Enrichment']]
    print(subset.to_string(index=False))

    # Save CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(),
              df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()



    # Mark sites to highlight
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)
    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))

        # Rolling min and max for range area
    # Fixed (aligned min/max with smoothed mean)
    df_filtered_im['Smoothed_Enrichment_min'] = df_filtered_im['Smoothed_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).min()
    df_filtered_im['Smoothed_Enrichment_max'] = df_filtered_im['Smoothed_Enrichment'].rolling(
        window=ROLLING_WINDOW, center=True, min_periods=1).max()


    # Plot shaded area for range
    ax.fill_between(
        df_filtered_im['Spike_AS_Position'],
        df_filtered_im['Smoothed_Enrichment_min'],
        df_filtered_im['Smoothed_Enrichment_max'],
        color=color_map.get(immunization, 'black'),
        alpha=0.1
    )

    # Plot
    
    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₁₀ enrichment)",
        show_col="show_site",
        ax=ax,
        linewidth=2,
        color=color_map.get(immunization, 'black')
    )
    if immunization in faint_immunizations:
        ax.lines[-1].set_alpha(faint_alpha)
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

    # Highlight sites
    #highlight_sites = df_filtered_im[df_filtered_im['show_site']]
    #for _, site_data in highlight_sites.iterrows():
    #    ax.hlines(
    #        y=0.1,  # set near bottom
    #        xmin=site_data['Spike_AS_Position'] - 0.5,
    #        xmax=site_data['Spike_AS_Position'] + 0.5,
    #        color='black',
    #        linestyle='-',
    #        linewidth=10
    #    )

# Y-axis limit (adjust if needed for log scale)
ax.set_ylim(bottom=-0.05,
            top=np.nanmax(df_filtered_im['Smoothed_Enrichment']) + 0.2)

plt.title('', fontsize=14)
plt.xlabel('Spike AA Position', fontsize=16)
plt.ylabel('Log10 AB binding (Mean)', fontsize=18)

from matplotlib.patches import Patch

# Create a square patch for the legend entry
epitope_patch = Patch(facecolor='orange', edgecolor='orange', label='SARS-CoV-2 Spike antibody epitopes')


handles, labels = ax.get_legend_handles_labels()
# Legend (subset only to certain labels)
group_1_labels = ['Polyclonal_Ab']
group_2_labels = ['wildtype_RBD', 'Mutant_RBD']

group_1_handles = [handles[labels.index(label)] for label in group_1_labels]
group_2_handles = [handles[labels.index(label)] for label in group_2_labels]

# Combine handles and labels with a custom order
handles = group_1_handles + group_2_handles + [epitope_patch]

label_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'wildtype_RBD': 'ASCs from SARS-CoV-2 Wuhan Imm.',
    'Mutant_RBD': 'ASCs from SARS-CoV-2 B.1.135 Imm.'
}
labels = [label_map[label] for label in group_1_labels + group_2_labels] + ['Selected Antibody RBD epitopes']

plt.legend(
    handles, labels,
    title="Immunization",
    title_fontproperties=font_manager.FontProperties(weight='bold'),
    loc='upper left',
    fontsize=11,
    frameon=False,
    handlelength=2,
    handleheight=1,
    markerscale=1
)

# X-axis ticks
# Set x-ticks every 10 positions
# Major ticks every 10 positions
major_locator = MultipleLocator(10)
ax.xaxis.set_major_locator(major_locator)
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f"{int(x)}"))

# Minor ticks every 2 positions (sub-ticks without labels)
minor_locator = MultipleLocator(2)
ax.xaxis.set_minor_locator(minor_locator)

# Enable minor ticks
ax.tick_params(axis='x', which='major', length=7, width=1.2)
ax.tick_params(axis='x', which='minor', length=4, width=0.8, direction='out')

ax.tick_params(axis='y', which='major', length=7, width=1.2)
ax.tick_params(axis='y', which='minor', length=4, width=0.8, direction='out')

# Optionally rotate x labels
plt.setp(ax.xaxis.get_majorticklabels(), rotation=0, ha='center')
# Tick marks every 0.1
ax.yaxis.set_major_locator(MultipleLocator(0.1))

# Only label 0.0, 0.3, 0.6
label_positions = [0.0, 0.3, 0.6]
ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f"{y:.1f}" if any(np.isclose(y, lp, atol=1e-3) for lp in label_positions) else ""))

# Optional minor ticks every 0.05
ax.yaxis.set_minor_locator(MultipleLocator(0.05))

# Save
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_Line2_plot_final_log10.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

plt.show()


## Statistics for manuscript: comparing antibody binding (mean enrichment ratio per barcode, i.e. per individual droplet) across Spike protein regions

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu
import numpy as np

# Region-wise comparison of per-barcode mean enrichment between immunization
# groups. For every region and group, the enrichment ratios of each barcode
# inside the region are averaged. The barcode means of every pair of groups are
# then compared with a two-sided Mann-Whitney U test. Only rows with
# Enrichment_Ratio >= 1 enter this analysis.
#
# NOTE(review): no multiple-testing correction is applied across the
# pairs x regions tested. TODO decide whether one is needed for the manuscript.

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions (inclusive Spike position ranges) ---
regions = {
    "380-400": (380, 400),
    "400-410": (400, 410),
    "420-440": (420, 440),
    "490-510": (490, 510),
    "510-520": (510, 520),
}

UNENRICHED_LABEL = 'Un-enrich. Libr'

# --- Reference sample "Un-enrich. Libr" with Enrichment_Ratio = 1 at every region position ---
# It is excluded from `immunizations` below, so it does not enter the tests or
# plots. Rows left over from earlier runs of this (or a similar) cell are
# dropped first, so re-running no longer keeps appending duplicates to df_total.
positions = sorted({pos for start, end in regions.values() for pos in range(start, end + 1)})

new_sample_rows = [
    {
        'barcode': f"UnEnrichLib_{pos}",
        'immunization': UNENRICHED_LABEL,
        'Spike_AS_Position': pos,
        'Enrichment_Ratio': 1.0,
    }
    for pos in positions
]
new_sample_df = pd.DataFrame(new_sample_rows)

df_total = pd.concat(
    [df_total[df_total['immunization'] != UNENRICHED_LABEL], new_sample_df],
    ignore_index=True,
)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# Keep only rows with Enrichment_Ratio >= 1
df_total2 = df_total[(df_total['Enrichment_Ratio'] >= 1)]

# Groups to analyse: everything except the 'Un-enrich. Libr' reference
immunizations = [imm for imm in df_total2['immunization'].unique() if imm != UNENRICHED_LABEL]
print("Immunizations found in data (excluding 'Un-enrich. Libr'):", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl'
}

RESULT_COLUMNS = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
                  'U_stat', 'p_value', 'n1', 'n2']
PLOT_COLUMNS = ['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode']

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values_log = {}  # per group: log10 barcode means (plotting scale)
    region_values_raw = {}  # per group: raw barcode means (used for the tests)

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        grouped = df_region.groupby('barcode')
        print(f"  {imm}: {len(grouped)} unique barcodes in region")

        barcode_stats = grouped['Enrichment_Ratio'].agg(['mean', 'std', 'count']).reset_index()
        barcode_stats = barcode_stats.rename(columns={'mean': 'Mean_Enrichment', 'std': 'Std_Enrichment', 'count': 'n'})

        # Keep strictly positive means so that log10 is defined
        barcode_stats = barcode_stats[barcode_stats['Mean_Enrichment'] > 0].copy()

        region_values_raw[imm] = barcode_stats['Mean_Enrichment']

        barcode_stats['Mean_Enrichment_log10'] = np.log10(barcode_stats['Mean_Enrichment'])
        region_values_log[imm] = barcode_stats['Mean_Enrichment_log10']

        for _idx, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Mean_Enrichment_log10'],  # log10 for plotting
                'Std': row['Std_Enrichment'],  # Std still on the original scale
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    # Pairwise tests on the raw (not log10) barcode means. Mann-Whitney U is
    # rank-based and log10 is monotone, so the p-values would be identical on the log10 scale.
    for imm1, imm2 in combinations(region_values_raw.keys(), 2):
        vals1 = region_values_raw[imm1].dropna()
        vals2 = region_values_raw[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Explicit columns keep both frames well-formed (and results_df sortable)
# even when no test or no barcode qualified.
results_df = pd.DataFrame(results, columns=RESULT_COLUMNS).sort_values(by='p_value')
# NOTE(review): the "_log10" suffix refers to the plots; the statistics above use
# raw means. Later cells write the same file names and overwrite these outputs.
results_df.to_csv("pairwise_enrichment_stats_log10.csv", index=False)
print("\nSaved statistical test results to 'pairwise_enrichment_stats_log10.csv'")

plot_df = pd.DataFrame(plot_data, columns=PLOT_COLUMNS)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Translate a p-value into a significance marker.

    Returns '***' for p < 0.001, '**' for p < 0.01, '*' for p < 0.05, and an
    empty string otherwise (including NaN, for which every comparison is False).
    """
    for cutoff, marker in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < cutoff:
            return marker
    return ''

sns.set(style="whitegrid")
# squeeze=False + ravel() always gives a 1-D array of Axes, so the loop below
# also works if `regions` is reduced to a single region.
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True, squeeze=False)
axes = axes.ravel()

# Plotting order of the groups ('Un-enrich. Libr' is not shown)
immunizations_sorted = ['Neutralizing_Ab', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']
new_labels = [title_map.get(label, label) for label in immunizations_sorted]

# Fixed y-range. Significance brackets start two steps below the top and move
# up one step per bracket.
# NOTE(review): with three or more significant pairs in a region the brackets
# reach ymax and are clipped.
ymax = 2.25
ystep = ymax * 0.1

# Iterate over region names directly. Unpacking into `_` would shadow IPython's
# last-output variable.
for ax, region_name in zip(axes, regions):
    sub_df = plot_df[plot_df['Region'] == region_name]

    # Box plot (no fliers) with every barcode overlaid as a point
    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=7, order=immunizations_sorted)
    ax.set_title(region_name, fontsize=16)

    ax.set_ylabel('log10 (Barcode Mean Enrichment ER)', fontsize=14)
    ax.set_xlabel('')

    ax.set_xticks(range(len(immunizations_sorted)))
    ax.set_xticklabels(new_labels, rotation=0)

    region_results = results_df[results_df['Region'] == region_name]
    ax.set_ylim(0, ymax)
    y_offset = ymax - 2 * ystep
    ax.tick_params(axis='y', labelsize=12)  # larger y-axis tick labels

    # results_df is sorted by p-value, so the most significant pair gets the lowest bracket
    for _idx, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        x_center = (x1 + x2) / 2
        bracket_top = y_offset + ystep * 0.1
        ax.plot([x1, x1, x2, x2], [y_offset, bracket_top, bracket_top, y_offset], lw=1.5, c='k')
        ax.text(x_center, y_offset + ystep * 0.15, star, ha='center', va='bottom', fontsize=16, fontweight='bold')

        y_offset += ystep

plt.tight_layout()
plt.savefig("regionwise_enrichment_comparison_per_barcode_log10.png", dpi=300)
plt.show()

plot_df.to_csv("per_barcode_enrichment_stats_log10.csv", index=False)
print("Saved per-barcode enrichment stats to 'per_barcode_enrichment_stats_log10.csv'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for _idx, row in results_df.iterrows():
    star = pval_to_stars(row['p_value'])
    print(f"Region: {row['Region']}, {row['Group1']} vs {row['Group2']}, "
          f"p = {row['p_value']:.4g} {star}, U = {row['U_stat']:.2f}, "
          f"n1 = {row['n1']}, n2 = {row['n2']}")


## Manuscript statistics plot: per-barcode mean enrichment ratios (ER) compared across immunizations (colour) and across Spike regions (separate box plots; ER shown on a log10 scale). Significance test: two-sided Mann–Whitney U test on the raw ER means

In [None]:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu
import numpy as np

# Region-wise comparison of per-barcode mean enrichment between immunization
# groups. For every region and group, the enrichment ratios of each barcode
# inside the region are averaged. The barcode means of every pair of groups are
# then compared with a two-sided Mann-Whitney U test. Unlike the previous
# statistics cell, all rows with Enrichment_Ratio >= 0 are used here.
#
# NOTE(review): no multiple-testing correction is applied across the
# pairs x regions tested. TODO decide whether one is needed for the manuscript.

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions (inclusive Spike position ranges; commented entries are alternatives) ---
regions = {
    "360-370": (360, 370),
    "400-410": (400, 410),
    "410-420": (410, 420),
    #"420-430": (420, 430),
    #"430-440": (430, 440),
    #"440-450": (440, 450),
    #"450-460": (450, 460),
    #"460-470": (460, 470),
    #"470-480": (470, 480),
    #"480-490": (480, 490),
    "490-500": (490, 500),
    "500-510": (500, 510),
    #"510-520": (510, 520),
}

UNENRICHED_LABEL = 'Un-enrich. Libr'

# --- Reference sample "Un-enrich. Libr" with Enrichment_Ratio = 1 at every region position ---
# It is excluded from `immunizations` below, so it does not enter the tests or
# plots. Rows left over from earlier runs of this (or a similar) cell are
# dropped first, so re-running no longer keeps appending duplicates to df_total.
positions = sorted({pos for start, end in regions.values() for pos in range(start, end + 1)})

new_sample_rows = [
    {
        'barcode': f"UnEnrichLib_{pos}",
        'immunization': UNENRICHED_LABEL,
        'Spike_AS_Position': pos,
        'Enrichment_Ratio': 1.0,
    }
    for pos in positions
]
new_sample_df = pd.DataFrame(new_sample_rows)

df_total = pd.concat(
    [df_total[df_total['immunization'] != UNENRICHED_LABEL], new_sample_df],
    ignore_index=True,
)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# Keep rows with Enrichment_Ratio >= 0 (i.e. all non-negative ratios)
df_total2 = df_total[(df_total['Enrichment_Ratio'] >= 0)]

# Groups to analyse: everything except the 'Un-enrich. Libr' reference
immunizations = [imm for imm in df_total2['immunization'].unique() if imm != UNENRICHED_LABEL]
print("Immunizations found in data (excluding 'Un-enrich. Libr'):", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl'
}

# Semi-transparent (alpha 0.4) RGBA box colours per group
color_map = {
    'Polyclonal_Ab': (255/255, 140/255, 0/255, 0.4),    # darkorange
    'Neutralizing_Ab': (255/255, 0/255, 0/255, 0.4),    # red
    'wildtype_RBD': (0/255, 128/255, 0/255, 0.4),       # green
    'Mutant_RBD': (30/255, 144/255, 255/255, 0.4)       # dodgerblue
}

RESULT_COLUMNS = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
                  'U_stat', 'p_value', 'n1', 'n2']
PLOT_COLUMNS = ['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode']

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values_log = {}  # per group: log10 barcode means (plotting scale)
    region_values_raw = {}  # per group: raw barcode means (used for the tests)

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        grouped = df_region.groupby('barcode')
        print(f"  {imm}: {len(grouped)} unique barcodes in region")

        barcode_stats = grouped['Enrichment_Ratio'].agg(['mean', 'std', 'count']).reset_index()
        barcode_stats = barcode_stats.rename(columns={'mean': 'Mean_Enrichment', 'std': 'Std_Enrichment', 'count': 'n'})

        # Keep strictly positive means so that log10 is defined (it can be negative here)
        barcode_stats = barcode_stats[barcode_stats['Mean_Enrichment'] > 0].copy()

        region_values_raw[imm] = barcode_stats['Mean_Enrichment']

        barcode_stats['Mean_Enrichment_log10'] = np.log10(barcode_stats['Mean_Enrichment'])
        region_values_log[imm] = barcode_stats['Mean_Enrichment_log10']

        for _idx, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Mean_Enrichment_log10'],  # log10 for plotting
                'Std': row['Std_Enrichment'],  # Std still on the original scale
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    # Pairwise tests on the raw (not log10) barcode means. Mann-Whitney U is
    # rank-based and log10 is monotone, so the p-values would be identical on the log10 scale.
    for imm1, imm2 in combinations(region_values_raw.keys(), 2):
        vals1 = region_values_raw[imm1].dropna()
        vals2 = region_values_raw[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Explicit columns keep both frames well-formed (and results_df sortable)
# even when no test or no barcode qualified.
results_df = pd.DataFrame(results, columns=RESULT_COLUMNS).sort_values(by='p_value')
# NOTE(review): these file names are shared with the other statistics cells,
# so each run overwrites the others' outputs.
results_df.to_csv("pairwise_enrichment_stats_log10.csv", index=False)
print("\nSaved statistical test results to 'pairwise_enrichment_stats_log10.csv'")

plot_df = pd.DataFrame(plot_data, columns=PLOT_COLUMNS)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Translate a p-value into a significance marker.

    Returns '***' for p < 0.001, '**' for p < 0.01, '*' for p < 0.05, and an
    empty string otherwise (including NaN, for which every comparison is False).
    NOTE(review): the same helper is defined in the other statistics cells.
    """
    for cutoff, marker in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < cutoff:
            return marker
    return ''

sns.set(style="whitegrid")
# squeeze=False + ravel() always gives a 1-D array of Axes, so the loop below
# also works if `regions` is reduced to a single region.
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True, squeeze=False)
axes = axes.ravel()

# Plotting order of the groups ('Un-enrich. Libr' is not shown) and their box colours
immunizations_sorted = ['Neutralizing_Ab', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']
palette = {imm: color_map[imm] for imm in immunizations_sorted}
new_labels = [title_map.get(label, label) for label in immunizations_sorted]

# Fixed y-range. Significance brackets start four steps below the top and move
# up one step per bracket.
# NOTE(review): with five or more significant pairs in a region the brackets
# reach ymax and are clipped. Values below -0.4 are also cut off.
ymax = 2.5
ystep = ymax * 0.1

# Iterate over region names directly. Unpacking into `_` would shadow IPython's
# last-output variable.
for ax, region_name in zip(axes, regions):
    sub_df = plot_df[plot_df['Region'] == region_name]

    # Box plot (no fliers) with every barcode overlaid as a point
    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False,
                order=immunizations_sorted, palette=palette)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=7, order=immunizations_sorted)
    ax.set_title(region_name, fontsize=16)

    ax.set_ylabel('Log10 AB binding (Mean)', fontsize=14)
    ax.set_xlabel('')

    ax.set_xticks(range(len(immunizations_sorted)))
    ax.set_xticklabels(new_labels, rotation=0)

    region_results = results_df[results_df['Region'] == region_name]
    ax.set_ylim(-0.4, ymax)
    y_offset = ymax - 4 * ystep
    ax.tick_params(axis='y', labelsize=12)  # larger y-axis tick labels

    # results_df is sorted by p-value, so the most significant pair gets the lowest bracket
    for _idx, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        x_center = (x1 + x2) / 2
        bracket_top = y_offset + ystep * 0.1
        ax.plot([x1, x1, x2, x2], [y_offset, bracket_top, bracket_top, y_offset], lw=1.5, c='k')
        ax.text(x_center, y_offset + ystep * 0.01, star, ha='center', va='bottom', fontsize=22, fontweight='bold')

        y_offset += ystep

plt.tight_layout()
plt.savefig("regionwise_enrichment_comparison_per_barcode_log10.png", dpi=300)
plt.show()

plot_df.to_csv("per_barcode_enrichment_stats_log10.csv", index=False)
print("Saved per-barcode enrichment stats to 'per_barcode_enrichment_stats_log10.csv'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for _idx, row in results_df.iterrows():
    star = pval_to_stars(row['p_value'])
    print(f"Region: {row['Region']}, {row['Group1']} vs {row['Group2']}, "
          f"p = {row['p_value']:.4g} {star}, U = {row['U_stat']:.2f}, "
          f"n1 = {row['n1']}, n2 = {row['n2']}")


In [None]:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu
import numpy as np

# Region-wise comparison of per-barcode mean enrichment between immunization
# groups (all rows with Enrichment_Ratio >= 0). For every region and group, the
# enrichment ratios of each barcode inside the region are averaged. The barcode
# means of every pair of groups are then compared with a two-sided
# Mann-Whitney U test.
#
# NOTE(review): this cell repeats the previous statistics cell almost verbatim
# and writes the same output files. No multiple-testing correction is applied.

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions (inclusive Spike position ranges; commented entries are alternatives) ---
regions = {
    "360-370": (360, 370),
    "400-410": (400, 410),
    "410-420": (410, 420),
    #"420-430": (420, 430),
    #"430-440": (430, 440),
    #"440-450": (440, 450),
    #"450-460": (450, 460),
    #"460-470": (460, 470),
    #"470-480": (470, 480),
    #"480-490": (480, 490),
    "490-500": (490, 500),
    "500-510": (500, 510),
    #"510-520": (510, 520),
}

UNENRICHED_LABEL = 'Un-enrich. Libr'

# --- Reference sample "Un-enrich. Libr" with Enrichment_Ratio = 1 at every region position ---
# It is excluded from `immunizations` below, so it does not enter the tests or
# plots. Rows left over from earlier runs of this (or a similar) cell are
# dropped first, so re-running no longer keeps appending duplicates to df_total.
positions = sorted({pos for start, end in regions.values() for pos in range(start, end + 1)})

new_sample_rows = [
    {
        'barcode': f"UnEnrichLib_{pos}",
        'immunization': UNENRICHED_LABEL,
        'Spike_AS_Position': pos,
        'Enrichment_Ratio': 1.0,
    }
    for pos in positions
]
new_sample_df = pd.DataFrame(new_sample_rows)

df_total = pd.concat(
    [df_total[df_total['immunization'] != UNENRICHED_LABEL], new_sample_df],
    ignore_index=True,
)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# Keep rows with Enrichment_Ratio >= 0 (i.e. all non-negative ratios)
df_total2 = df_total[(df_total['Enrichment_Ratio'] >= 0)]

# Groups to analyse: everything except the 'Un-enrich. Libr' reference.
# Alternative that additionally drops 'Neutralizing_Ab':
#   immunizations = [imm for imm in df_total2['immunization'].unique()
#                    if imm not in ['Un-enrich. Libr', 'Neutralizing_Ab']]
immunizations = [imm for imm in df_total2['immunization'].unique() if imm != UNENRICHED_LABEL]
print("Immunizations found in data (excluding 'Un-enrich. Libr'):", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl'
}

# Semi-transparent (alpha 0.4) RGBA box colours per group
color_map = {
    'Polyclonal_Ab': (255/255, 140/255, 0/255, 0.4),    # darkorange
    'Neutralizing_Ab': (255/255, 0/255, 0/255, 0.4),    # red
    'wildtype_RBD': (0/255, 128/255, 0/255, 0.4),       # green
    'Mutant_RBD': (30/255, 144/255, 255/255, 0.4)       # dodgerblue
}

RESULT_COLUMNS = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
                  'U_stat', 'p_value', 'n1', 'n2']
PLOT_COLUMNS = ['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode']

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values_log = {}  # per group: log10 barcode means (plotting scale)
    region_values_raw = {}  # per group: raw barcode means (used for the tests)

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        grouped = df_region.groupby('barcode')
        print(f"  {imm}: {len(grouped)} unique barcodes in region")

        barcode_stats = grouped['Enrichment_Ratio'].agg(['mean', 'std', 'count']).reset_index()
        barcode_stats = barcode_stats.rename(columns={'mean': 'Mean_Enrichment', 'std': 'Std_Enrichment', 'count': 'n'})

        # Keep strictly positive means so that log10 is defined (it can be negative here)
        barcode_stats = barcode_stats[barcode_stats['Mean_Enrichment'] > 0].copy()

        region_values_raw[imm] = barcode_stats['Mean_Enrichment']

        barcode_stats['Mean_Enrichment_log10'] = np.log10(barcode_stats['Mean_Enrichment'])
        region_values_log[imm] = barcode_stats['Mean_Enrichment_log10']

        for _idx, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Mean_Enrichment_log10'],  # log10 for plotting
                'Std': row['Std_Enrichment'],  # Std still on the original scale
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    # Pairwise tests on the raw (not log10) barcode means. Mann-Whitney U is
    # rank-based and log10 is monotone, so the p-values would be identical on the log10 scale.
    for imm1, imm2 in combinations(region_values_raw.keys(), 2):
        vals1 = region_values_raw[imm1].dropna()
        vals2 = region_values_raw[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Explicit columns keep both frames well-formed (and results_df sortable)
# even when no test or no barcode qualified.
results_df = pd.DataFrame(results, columns=RESULT_COLUMNS).sort_values(by='p_value')
results_df.to_csv("pairwise_enrichment_stats_log10.csv", index=False)
print("\nSaved statistical test results to 'pairwise_enrichment_stats_log10.csv'")

plot_df = pd.DataFrame(plot_data, columns=PLOT_COLUMNS)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Convert a p-value into significance stars.

    Returns '***' for p < 0.001, '**' for p < 0.01, '*' for p < 0.05 and ''
    otherwise (including NaN, which fails every comparison).
    """
    for threshold, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < threshold:
            return stars
    return ''

# Figure: one panel per region with a box plot plus a per-barcode swarm of the
# log10 mean enrichment, annotated with brackets for significant group pairs.
sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Fixed x order of the plotted groups ('Un-enrich. Libr' is left out).
immunizations_sorted = ['Neutralizing_Ab', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']
palette = dict(zip(immunizations_sorted, (color_map[name] for name in immunizations_sorted)))

for ax, region_name in zip(axes, regions):
    sub_df = plot_df.loc[plot_df['Region'] == region_name]

    # Box and swarm share data, axes and category order.
    shared = {'data': sub_df, 'x': 'Immunization', 'y': 'Enrichment', 'ax': ax,
              'order': immunizations_sorted}
    sns.boxplot(**shared, showfliers=False, palette=palette)
    sns.swarmplot(**shared, color=".25", size=7)

    ax.set_title(region_name, fontsize=16)
    ax.set_ylabel('Log10 AB binding (Mean)', fontsize=14)
    ax.set_xlabel('')
    new_labels = [title_map.get(name, name) for name in immunizations_sorted]
    ax.set_xticklabels(new_labels, rotation=0)

    # Fixed y-range in log10 units; brackets start four steps below the top.
    ymax = 2.5
    ystep = ymax * 0.1 if ymax > 0 else 1
    ax.set_ylim(-0.4, ymax)
    ax.tick_params(axis='y', labelsize=12)  # larger y tick labels
    y_offset = ymax - 4 * ystep

    region_results = results_df[results_df['Region'] == region_name]
    for _, res in region_results.iterrows():
        star = pval_to_stars(res['p_value'])
        g1, g2 = res['Group1'], res['Group2']
        # Only significant pairs whose groups are both on the x-axis get a bracket.
        if not star or g1 not in immunizations_sorted or g2 not in immunizations_sorted:
            continue
        x1, x2 = immunizations_sorted.index(g1), immunizations_sorted.index(g2)
        bar_top = y_offset + ystep * 0.1
        ax.plot([x1, x1, x2, x2], [y_offset, bar_top, bar_top, y_offset], lw=1.5, c='k')
        ax.text((x1 + x2) / 2, y_offset + ystep * 0.01, star,
                ha='center', va='bottom', fontsize=22, fontweight='bold')
        y_offset += ystep  # stack the next bracket above this one

plt.tight_layout()
plt.savefig("regionwise_enrichment_comparison_per_barcode_log10_noNeut.png", dpi=300)
plt.show()

plot_df.to_csv("per_barcode_enrichment_stats_log10.csv", index=False)
print("Saved per-barcode enrichment stats to 'per_barcode_enrichment_stats_log10.csv'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for _, res in results_df.iterrows():
    print(f"Region: {res['Region']}, {res['Group1']} vs {res['Group2']}, "
          f"p = {res['p_value']:.4g} {pval_to_stars(res['p_value'])}, U = {res['U_stat']:.2f}, "
          f"n1 = {res['n1']}, n2 = {res['n2']}")


In [None]:
# Same analysis as above, but each (barcode, position) is first summarized by its median enrichment ratio

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu
import numpy as np

# Same per-barcode pipeline as the mean-based cell above, except that each
# (barcode, position) is first summarized by its MEDIAN Enrichment_Ratio.
# Assumes df_total (built and filtered in earlier cells) holds all samples combined.

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions: Spike_AS_Position windows, both ends inclusive ---
regions = {
    #"360-370": (360, 370),
    "370-380": (370, 380),
    #"380-390": (380, 390),
    "390-400": (390, 400),
    "400-410": (400, 410),
    "410-420": (410, 420),
    #"420-430": (420, 430),
    "430-440": (430, 440),
    "440-450": (440, 450),
    #"450-460": (450, 460),
    #"480-490": (480, 490),
    "490-500": (490, 500),
    #"500-510": (500, 510),
}

# --- Reference sample "Un-enrich. Libr": Enrichment_Ratio = 1 at every region position ---
positions = sorted({pos for start, end in regions.values() for pos in range(start, end + 1)})

new_sample_rows = [
    {
        'barcode': f"UnEnrichLib_{pos}",
        'immunization': 'Un-enrich. Libr',
        'Spike_AS_Position': pos,
        'Enrichment_Ratio': 1.0,
    }
    for pos in positions
]
new_sample_df = pd.DataFrame(
    new_sample_rows,
    columns=['barcode', 'immunization', 'Spike_AS_Position', 'Enrichment_Ratio'],
)

# Replace (rather than re-append) reference rows that an earlier run already
# added, so re-executing this cell does not keep duplicating them in df_total.
already_added = df_total['barcode'].isin(new_sample_df['barcode'])
df_total = pd.concat([df_total[~already_added], new_sample_df], ignore_index=True)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# Keep non-negative enrichment ratios (log10 below additionally requires > 0).
df_total2 = df_total[df_total['Enrichment_Ratio'] >= 0]

# Groups to analyze; the 'Un-enrich. Libr' reference sample is excluded.
immunizations = [imm for imm in df_total2['immunization'].unique() if imm != 'Un-enrich. Libr']
print("Immunizations found in data (excluding 'Un-enrich. Libr'):", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl'
}

color_map = {
    'Polyclonal_Ab': (255/255, 140/255, 0/255, 0.4),    # darkorange
    'Neutralizing_Ab': (255/255, 0/255, 0/255, 0.4),    # red
    'wildtype_RBD': (0/255, 128/255, 0/255, 0.4),       # green
    'Mutant_RBD': (30/255, 144/255, 255/255, 0.4)       # dodgerblue
}

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values_log = {}  # per barcode: mean of log10 position medians (plotted)
    region_values_raw = {}  # per barcode: mean of raw position medians (tested)

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        # Step 1: median Enrichment_Ratio per position per barcode
        median_per_pos = (
            df_region.groupby(['barcode', 'Spike_AS_Position'])['Enrichment_Ratio']
            .median()
            .reset_index()
        )

        # Step 2: keep only strictly positive medians so log10 is defined
        median_per_pos = median_per_pos[median_per_pos['Enrichment_Ratio'] > 0].copy()

        # Step 3: log10 transform
        median_per_pos['Enrichment_Ratio_log10'] = np.log10(median_per_pos['Enrichment_Ratio'])

        # Step 4: summarize the region's positions per barcode. The raw-scale mean is
        # computed in the same groupby so it stays aligned with its barcode row
        # (instead of relying on the order of a second groupby's .values).
        barcode_stats = median_per_pos.groupby('barcode').agg(
            Mean_Enrichment_log10=('Enrichment_Ratio_log10', 'mean'),
            Std_Enrichment_log10=('Enrichment_Ratio_log10', 'std'),
            n=('Enrichment_Ratio_log10', 'count'),
            Mean_Enrichment=('Enrichment_Ratio', 'mean'),
        ).reset_index()

        region_values_raw[imm] = barcode_stats['Mean_Enrichment']
        region_values_log[imm] = barcode_stats['Mean_Enrichment_log10']

        for _, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Mean_Enrichment_log10'],  # already log10
                'Std': row['Std_Enrichment_log10'],
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    # Pairwise Mann-Whitney U tests on the raw-scale per-barcode values (not log10)
    for imm1, imm2 in combinations(region_values_raw.keys(), 2):
        vals1 = region_values_raw[imm1].dropna()
        vals2 = region_values_raw[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Explicit columns keep results_df sortable and filterable even if every test was skipped.
results_df = pd.DataFrame(
    results,
    columns=['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
             'U_stat', 'p_value', 'n1', 'n2'],
).sort_values(by='p_value')
# Distinct file name: the mean-based and barcode-median cells previously wrote the
# same 'pairwise_enrichment_stats_log10.csv', so whichever ran last silently won.
results_csv = "pairwise_enrichment_stats_log10_posMedian.csv"
results_df.to_csv(results_csv, index=False)
print(f"\nSaved statistical test results to '{results_csv}'")

plot_df = pd.DataFrame(
    plot_data,
    columns=['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode'],
)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Map p to '***' (< 0.001), '**' (< 0.01), '*' (< 0.05) or '' (otherwise / NaN)."""
    cutoffs = [(0.001, '***'), (0.01, '**'), (0.05, '*')]
    return next((stars for limit, stars in cutoffs if p < limit), '')

# Figure: one panel per region with a box plot plus a per-barcode swarm of the
# mean log10 position medians, with brackets/stars for significant pairs.
sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Fixed x order of the plotted groups ('Un-enrich. Libr' is left out).
immunizations_sorted = ['Neutralizing_Ab', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']
palette = {imm: color_map[imm] for imm in immunizations_sorted}

# Distinct output names: the mean-based and barcode-median cells previously
# saved to the same PNG/CSV names, so whichever cell ran last overwrote the others.
figure_png = "regionwise_enrichment_comparison_per_barcode_log10_posMedian.png"
per_barcode_csv = "per_barcode_enrichment_stats_log10_posMedian.csv"

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False,
                order=immunizations_sorted, palette=palette)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=7, order=immunizations_sorted)
    ax.set_title(region_name, fontsize=16)

    ax.set_ylabel('Log10 AB binding (Median)', fontsize=14)
    ax.set_xlabel('')

    new_labels = [title_map.get(label, label) for label in immunizations_sorted]
    ax.set_xticklabels(new_labels, rotation=0)

    region_results = results_df[results_df['Region'] == region_name]
    # Fixed y-range in log10 units; brackets start just over four steps below the top.
    ymax = 0.5
    ystep = (ymax * 0.1) if ymax > 0 else 1
    ax.set_ylim(-0.65, ymax)
    y_offset = ymax - 4*ystep-0.04
    ax.tick_params(axis='y', labelsize=12)

    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        x_center = (x1 + x2) / 2
        ax.plot([x1, x1, x2, x2], [y_offset, y_offset+ystep*0.15, y_offset+ystep*0.15, y_offset], lw=1, c='k')
        ax.text(x_center, y_offset+ystep*0.01, star, ha='center', va='bottom', fontsize=18, fontweight='bold')

        y_offset += ystep  # stack the next bracket above this one

plt.tight_layout()
plt.savefig(figure_png, dpi=300)
plt.show()

plot_df.to_csv(per_barcode_csv, index=False)
print(f"Saved per-barcode enrichment stats to '{per_barcode_csv}'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for _, row in results_df.iterrows():
    star = pval_to_stars(row['p_value'])
    print(f"Region: {row['Region']}, {row['Group1']} vs {row['Group2']}, "
          f"p = {row['p_value']:.4g} {star}, U = {row['U_stat']:.2f}, "
          f"n1 = {row['n1']}, n2 = {row['n2']}")


In [None]:
# Same per-barcode analysis, with groups compared pairwise by the two-sample Kolmogorov-Smirnov (KS) test

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import ks_2samp
import numpy as np

# Per-barcode region summaries compared pairwise with the two-sample KS test.
# Assumes df_total (built and filtered in earlier cells) holds all samples combined.

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions: Spike_AS_Position windows, both ends inclusive ---
# Panel titles are generated from the ranges so a panel can never be labelled
# with a window other than the one its data came from. The previous hand-written
# keys ("370-400", "400-430", ..., "520-550") did not match the ranges
# (380, 390), (390, 400), ..., (430, 440) actually used for filtering.
# NOTE(review): if 30-residue windows were intended, change the ranges here.
region_ranges = [(380, 390), (390, 400), (400, 410), (410, 420), (420, 430), (430, 440)]
regions = {f"{start}-{end}": (start, end) for start, end in region_ranges}

# --- Reference sample "Un-enrich. Libr": Enrichment_Ratio = 1 at every region position ---
positions = sorted({pos for start, end in regions.values() for pos in range(start, end + 1)})

new_sample_rows = [
    {
        'barcode': f"UnEnrichLib_{pos}",
        'immunization': 'Un-enrich. Libr',
        'Spike_AS_Position': pos,
        'Enrichment_Ratio': 1.0,
    }
    for pos in positions
]
new_sample_df = pd.DataFrame(
    new_sample_rows,
    columns=['barcode', 'immunization', 'Spike_AS_Position', 'Enrichment_Ratio'],
)

# Replace (rather than re-append) reference rows that an earlier run already
# added, so re-executing this cell does not keep duplicating them in df_total.
already_added = df_total['barcode'].isin(new_sample_df['barcode'])
df_total = pd.concat([df_total[~already_added], new_sample_df], ignore_index=True)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# Keep non-negative enrichment ratios (log10 below additionally requires > 0).
df_total2 = df_total[df_total['Enrichment_Ratio'] >= 0]

# Groups to analyze; the 'Un-enrich. Libr' reference sample is excluded.
immunizations = [imm for imm in df_total2['immunization'].unique() if imm != 'Un-enrich. Libr']
print("Immunizations found in data (excluding 'Un-enrich. Libr'):", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl'
}

color_map = {
    'Polyclonal_Ab': (255/255, 140/255, 0/255, 0.4),    # darkorange
    'Neutralizing_Ab': (255/255, 0/255, 0/255, 0.4),    # red
    'wildtype_RBD': (0/255, 128/255, 0/255, 0.4),       # green
    'Mutant_RBD': (30/255, 144/255, 255/255, 0.4)       # dodgerblue
}

plot_data = []
results = []


for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values_log = {}  # per barcode: mean of log10 position aggregates (plotted)
    region_values_raw = {}  # per barcode: mean of raw position aggregates (tested)

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"{imm}: total barcodes = {df_imm['barcode'].nunique()}")
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        # Step 1: aggregate Enrichment_Ratio per position per barcode.
        # NOTE(review): this uses the MEAN, while the variable name, the previous
        # "Median per position" comment and the plot's "(median)" y-label all say
        # median. Confirm which aggregate is intended before publishing.
        median_per_pos = (
            df_region.groupby(['barcode', 'Spike_AS_Position'])['Enrichment_Ratio']
            .mean()
            .reset_index()
        )

        # Step 2: keep only strictly positive values so log10 is defined
        median_per_pos = median_per_pos[median_per_pos['Enrichment_Ratio'] > 0].copy()

        # Step 3: log10 transform
        median_per_pos['Enrichment_Ratio_log10'] = np.log10(median_per_pos['Enrichment_Ratio'])

        # Step 4: summarize the region's positions per barcode. The raw-scale mean is
        # computed in the same groupby so it stays aligned with its barcode row.
        barcode_stats = median_per_pos.groupby('barcode').agg(
            Mean_Enrichment_log10=('Enrichment_Ratio_log10', 'mean'),
            Std_Enrichment_log10=('Enrichment_Ratio_log10', 'std'),
            n=('Enrichment_Ratio_log10', 'count'),
            Mean_Enrichment=('Enrichment_Ratio', 'mean'),
        ).reset_index()

        region_values_raw[imm] = barcode_stats['Mean_Enrichment']
        region_values_log[imm] = barcode_stats['Mean_Enrichment_log10']

        for _, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Mean_Enrichment_log10'],  # already log10
                'Std': row['Std_Enrichment_log10'],
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    # Pairwise two-sample KS tests on the raw-scale per-barcode values (not log10)
    for imm1, imm2 in combinations(region_values_raw.keys(), 2):
        vals1 = region_values_raw[imm1].dropna()
        vals2 = region_values_raw[imm2].dropna()

        print(f"  KS test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping KS test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = ks_2samp(vals1, vals2)

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'KS_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Explicit columns keep results_df sortable and filterable even if every test was skipped.
results_df = pd.DataFrame(
    results,
    columns=['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
             'KS_stat', 'p_value', 'n1', 'n2'],
).sort_values(by='p_value')
results_df.to_csv("pairwise_enrichment_stats_KS.csv", index=False)
print("\nSaved statistical test results to 'pairwise_enrichment_stats_KS.csv'")

# All regions processed; build the plotting frame (explicit columns keep the
# 'Region' lookup below valid even when no barcode survived filtering).
plot_df = pd.DataFrame(
    plot_data,
    columns=['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode'],
)

print(f"\nPrepared plot DataFrame with {len(plot_df)} rows")
print("\nFinal mean log10 enrichment values per Region and Immunization (per barcode):\n")

for region_name in sorted(plot_df['Region'].unique()):
    print(f"Region: {region_name}")
    sub_region = plot_df[plot_df['Region'] == region_name]

    for imm in immunizations:
        sub_imm = sub_region[sub_region['Immunization'] == imm]
        if sub_imm.empty:
            print(f"  {imm}: No data")
            continue

        print(f"  Immunization: {imm} (n={len(sub_imm)})")
        for _, row in sub_imm.iterrows():
            print(f"    Barcode: {row['barcode']}, Mean log10 enrichment: {row['Enrichment']:.4f}")

        median_val = sub_imm['Enrichment'].median()
        mean_val = sub_imm['Enrichment'].mean()
        print(f"    Summary: median={median_val:.4f}, mean={mean_val:.4f}")
    print()

def pval_to_stars(p):
    """Stars for the KS figure: '***' if p < 0.001, '**' if p < 0.01, else ''.

    Unlike the other cells' version, p < 0.05 earns no '*' here, so single-star
    comparisons are neither drawn nor starred in the printed summary.
    """
    if p < 0.001:
        return '***'
    return '**' if p < 0.01 else ''


# Figure: one panel per region (box + per-barcode swarm) with brackets for the
# pairs that pass this cell's pval_to_stars thresholds.
sns.set(style="whitegrid")
# squeeze=False always returns a 2-D array of Axes, so .ravel() also works when
# only one region is enabled (a bare Axes has no .flatten()) or for a grid layout.
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True, squeeze=False)
axes = axes.ravel()

# Fixed x order of the plotted groups ('Un-enrich. Libr' is left out).
immunizations_sorted = ['Neutralizing_Ab', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']
palette = {imm: color_map[imm] for imm in immunizations_sorted}

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False,
                order=immunizations_sorted, palette=palette)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=7, order=immunizations_sorted)
    ax.set_title(region_name, fontsize=16)
    # NOTE(review): the data cell aggregates each position with .mean(); confirm
    # that "(median)" in this label is what is intended.
    ax.set_ylabel("Log10 AB binding (median)\n$\\mathbf{\\Leftarrow}$ Enrichment $\\mathbf{\\Rightarrow}$", 
                  rotation=90, labelpad=30, va='center', fontsize=18)
    ax.yaxis.set_label_coords(-0.3, 0.4)

    new_labels = [title_map.get(label, label) for label in immunizations_sorted]
    ax.set_xticklabels(new_labels, rotation=0)

    region_results = results_df[results_df['Region'] == region_name]
    # Brackets start relative to the panel's data maximum; the y-range itself is fixed.
    ymax = sub_df['Enrichment'].max() 
    ystep = (ymax * 0.1) if ymax > 0 else 1
    ax.set_ylim(-1, 4)
    y_offset = (ymax - 4*ystep-0.02)+ 1.3

    ax.tick_params(axis='y', labelsize=12)

    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        x_center = (x1 + x2) / 2
        ax.plot([x1, x1, x2, x2],
                [y_offset, y_offset+ystep*0.1, y_offset+ystep*0.1, y_offset],
                lw=1.5, c='k')
        ax.text(x_center, y_offset+ystep*0.01, star,
                ha='center', va='bottom',
                fontsize=22, fontweight='bold')
        y_offset += ystep  # stack the next bracket above this one

plt.tight_layout()
plt.savefig("regionwise_enrichment_comparison_per_barcode_log10_KS_NeutF.png", dpi=300)
plt.show()

# Distinct name: the Mann-Whitney cells write 'per_barcode_enrichment_stats_log10.csv',
# which this cell previously overwrote with KS-cell values.
per_barcode_csv = "per_barcode_enrichment_stats_log10_KS.csv"
plot_df.to_csv(per_barcode_csv, index=False)
print(f"Saved per-barcode enrichment stats to '{per_barcode_csv}'")

print("\nPairwise KS Test Results (sorted by p-value):")
for _, row in results_df.iterrows():
    star = pval_to_stars(row['p_value'])
    print(f"Region: {row['Region']}, {row['Group1']} vs {row['Group2']}, "
          f"p = {row['p_value']:.4g} {star}, KS_stat = {row['KS_stat']:.2f}, "
          f"n1 = {row['n1']}, n2 = {row['n2']}")


In [None]:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu
import numpy as np

# Per-barcode MEDIAN of all Enrichment_Ratio rows in each region (pooling
# positions and variants), compared pairwise with Mann-Whitney U tests.
# Assumes df_total (built and filtered in earlier cells) holds all samples combined.

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions: Spike_AS_Position windows, both ends inclusive ---
regions = {
    "360-370": (360, 370),
    "400-410": (400, 410),
    "410-420": (410, 420),
    #"420-430": (420, 430),
    #"430-440": (430, 440),
    #"440-450": (440, 450),
    #"450-460": (450, 460),
    #"460-470": (460, 470),
    #"470-480": (470, 480),
    #"480-490": (480, 490),
    "490-500": (490, 500),
    "500-510": (500, 510),
    #"510-520": (510, 520),
}

# --- Reference sample "Un-enrich. Libr": Enrichment_Ratio = 1 at every region position ---
positions = sorted({pos for start, end in regions.values() for pos in range(start, end + 1)})

new_sample_rows = [
    {
        'barcode': f"UnEnrichLib_{pos}",
        'immunization': 'Un-enrich. Libr',
        'Spike_AS_Position': pos,
        'Enrichment_Ratio': 1.0,
    }
    for pos in positions
]
new_sample_df = pd.DataFrame(
    new_sample_rows,
    columns=['barcode', 'immunization', 'Spike_AS_Position', 'Enrichment_Ratio'],
)

# Replace (rather than re-append) reference rows that an earlier run already
# added, so re-executing this cell does not keep duplicating them in df_total.
already_added = df_total['barcode'].isin(new_sample_df['barcode'])
df_total = pd.concat([df_total[~already_added], new_sample_df], ignore_index=True)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# Keep non-negative enrichment ratios (Enrichment_Ratio >= 0).
df_total2 = df_total[df_total['Enrichment_Ratio'] >= 0]

# Groups left out of the analysis. 'Un-enrich. Libr' is only a reference sample;
# add 'Neutralizing_Ab' here to run the comparison without that group.
EXCLUDED_GROUPS = {'Un-enrich. Libr'}
immunizations = [imm for imm in df_total2['immunization'].unique() if imm not in EXCLUDED_GROUPS]
print("Immunizations found in data (excluding 'Un-enrich. Libr'):", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl'
}


color_map = {
    'Polyclonal_Ab': (255/255, 140/255, 0/255, 0.4),    # darkorange
    'Neutralizing_Ab': (255/255, 0/255, 0/255, 0.4),    # red
    'wildtype_RBD': (0/255, 128/255, 0/255, 0.4),       # green
    'Mutant_RBD': (30/255, 144/255, 255/255, 0.4)       # dodgerblue
}

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values_log = {}  # per barcode: log10 region median (plotted)
    region_values_raw = {}  # per barcode: raw region median (tested)

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        grouped = df_region.groupby('barcode')
        print(f"  {imm}: {len(grouped)} unique barcodes in region")

        barcode_stats = (
            grouped['Enrichment_Ratio']
            .agg(['median', 'std', 'count'])
            .reset_index()
            .rename(columns={'median': 'Median_Enrichment', 'std': 'Std_Enrichment', 'count': 'n'})
        )

        # Keep strictly positive medians (> 0) so log10 is defined
        barcode_stats = barcode_stats[barcode_stats['Median_Enrichment'] > 0].copy()

        # Raw-scale medians are what the Mann-Whitney tests below compare
        region_values_raw[imm] = barcode_stats['Median_Enrichment']

        # log10 medians are what gets plotted
        barcode_stats['Median_Enrichment_log10'] = np.log10(barcode_stats['Median_Enrichment'])
        region_values_log[imm] = barcode_stats['Median_Enrichment_log10']

        for _, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Median_Enrichment_log10'],  # log10 for plotting
                'Std': row['Std_Enrichment'],  # Std still original scale
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    # Pairwise Mann-Whitney U tests on the raw-scale per-barcode medians (not log10)
    for imm1, imm2 in combinations(region_values_raw.keys(), 2):
        vals1 = region_values_raw[imm1].dropna()
        vals2 = region_values_raw[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Explicit columns keep results_df sortable and filterable even if every test was skipped.
results_df = pd.DataFrame(
    results,
    columns=['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
             'U_stat', 'p_value', 'n1', 'n2'],
).sort_values(by='p_value')
# Distinct file name: the mean-based and median-per-position cells previously wrote
# the same 'pairwise_enrichment_stats_log10.csv', so whichever ran last silently won.
results_csv = "pairwise_enrichment_stats_log10_barcodeMedian.csv"
results_df.to_csv(results_csv, index=False)
print(f"\nSaved statistical test results to '{results_csv}'")

plot_df = pd.DataFrame(
    plot_data,
    columns=['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode'],
)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Significance stars: '***' (p < 0.001), '**' (p < 0.01), '*' (p < 0.05), else ''."""
    if not p < 0.05:
        return ''          # also covers NaN
    if not p < 0.01:
        return '*'
    if not p < 0.001:
        return '**'
    return '***'

# Figure: one panel per region with a box plot plus a per-barcode swarm of the
# log10 region medians, with brackets/stars for significant pairs.
sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Fixed x order of the plotted groups ('Un-enrich. Libr' is left out).
immunizations_sorted = ['Neutralizing_Ab', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']
palette = {imm: color_map[imm] for imm in immunizations_sorted}

# Distinct output names: the mean-based and median-per-position cells previously
# saved to the same PNG/CSV names, so whichever cell ran last overwrote the others.
figure_png = "regionwise_enrichment_comparison_per_barcode_log10_barcodeMedian.png"
per_barcode_csv = "per_barcode_enrichment_stats_log10_barcodeMedian.csv"

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False,
                order=immunizations_sorted, palette=palette)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=7, order=immunizations_sorted)
    ax.set_title(region_name, fontsize=16)

    ax.set_ylabel('Log10 AB binding (Median)', fontsize=14)
    ax.set_xlabel('')

    new_labels = [title_map.get(label, label) for label in immunizations_sorted]
    ax.set_xticklabels(new_labels, rotation=0)

    region_results = results_df[results_df['Region'] == region_name]
    # Fixed y-range in log10 units; brackets start four steps below the top.
    ymax = 0.4
    ystep = (ymax * 0.1) if ymax > 0 else 1
    ax.set_ylim(-0.3, ymax)
    y_offset = ymax - 4*ystep
    ax.tick_params(axis='y', labelsize=12)  # increase y-axis tick label size

    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        x_center = (x1 + x2) / 2
        ax.plot([x1, x1, x2, x2], [y_offset, y_offset+ystep*0.1, y_offset+ystep*0.1, y_offset], lw=1.5, c='k')
        ax.text(x_center, y_offset+ystep*0.01, star, ha='center', va='bottom', fontsize=22, fontweight='bold')

        y_offset += ystep  # stack the next bracket above this one

plt.tight_layout()
plt.savefig(figure_png, dpi=300)
plt.show()

plot_df.to_csv(per_barcode_csv, index=False)
print(f"Saved per-barcode enrichment stats to '{per_barcode_csv}'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for _, row in results_df.iterrows():
    star = pval_to_stars(row['p_value'])
    print(f"Region: {row['Region']}, {row['Group1']} vs {row['Group2']}, "
          f"p = {row['p_value']:.4g} {star}, U = {row['U_stat']:.2f}, "
          f"n1 = {row['n1']}, n2 = {row['n2']}")


In [None]:
# The cells above are the versions used for publication

### Pre-analysis: statistics of the enrichment ratio (ER) values

In [None]:

import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu

print("Current working dir:", os.getcwd())


# Define regions of interest: Spike_AS_Position windows, both ends inclusive
regions = {
    "380-390": (380, 390),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-510": (470, 510),
}

# Immunization groups; this list order also fixes the x-axis order of the plots
immunizations = ['Polyclonal_Ab', 'Neutralizing_Ab', 'wildtype_RBD', 'Mutant_RBD']

# Display titles
title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135'
}

# Load one CSV per group.
# NOTE(review): output_dir is not defined in this cell; it must be set by an
# earlier cell before this one runs on a fresh kernel.
data_dict = {}
for imm in immunizations:
    csv_path = os.path.join(output_dir, f"{imm}_data.csv")
    df = pd.read_csv(csv_path)
    data_dict[imm] = df

# Collect data for plotting
plot_data = []
results = []

print(data_dict['Polyclonal_Ab'].columns)

for region_name, (start, end) in regions.items():
    region_values = {}

    for imm, df in data_dict.items():
        df_region = df[(df['Spike_AS_Position'] >= start) & (df['Spike_AS_Position'] <= end)]
        vals = df_region['Enrichment_Ratio'].dropna()
        region_values[imm] = vals

        for val in vals:
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': val
            })

    # Pairwise Mann-Whitney U tests on the raw per-variant enrichment ratios
    for imm1, imm2 in combinations(region_values.keys(), 2):
        vals1, vals2 = region_values[imm1], region_values[imm2]
        # An empty group cannot be tested; skip it, as the per-barcode cells do
        if len(vals1) == 0 or len(vals2) == 0:
            print(f"Skipping Mann-Whitney test for {imm1} vs {imm2} in {region_name}: empty group.")
            continue
        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')
        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval
        })

# Save stats (explicit columns keep the frame sortable even if every test was skipped)
results_df = pd.DataFrame(
    results,
    columns=['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
             'U_stat', 'p_value'],
).sort_values(by='p_value')
results_df.to_csv("pairwise_enrichment_stats.csv", index=False)

# Convert plot data to DataFrame
plot_df = pd.DataFrame(plot_data, columns=['Region', 'Immunization', 'Enrichment'])

# Helper: convert a p-value into significance stars for the plot annotations
def pval_to_stars(p):
    """Return '***' if p < 0.001, '**' if p < 0.01, '*' if p < 0.05, otherwise ''."""
    levels = (('***', 0.001), ('**', 0.01), ('*', 0.05))
    for stars, alpha in levels:
        if p < alpha:
            return stars
    return ''

# Plot: one panel per region; panels share the y-axis
sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Fixed y-axis range and placement of the significance brackets
ymax = 90
star_start = 50  # height of the first bracket in each panel
star_step = 3    # vertical distance between stacked brackets
# The bracket geometry below is expressed in units of ystep; it must be defined
# here, otherwise the first significant pair raises a NameError.
ystep = star_step

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]
    # Fix the category order so x positions match immunizations.index() below,
    # even when a group has no data in this region
    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=3, order=immunizations)

    ax.set_title(region_name)
    ax.set_ylabel('Enrichment', fontsize=14)
    ax.set_xlabel('')
    ax.set_ylim(0, ymax)

    # Relabel x-ticks with display titles (same order as `order=` above)
    ax.set_xticklabels([title_map.get(imm, imm) for imm in immunizations], rotation=0)

    # Stack one bracket per significant pair, starting at star_start
    region_results = results_df[results_df['Region'] == region_name]
    y_offset = star_start
    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        x1 = immunizations.index(row['Group1'])
        x2 = immunizations.index(row['Group2'])
        x_center = (x1 + x2) / 2
        ax.plot([x1, x1, x2, x2], [y_offset, y_offset+ystep*0.1, y_offset+ystep*0.1, y_offset], lw=1.5, c='k')
        ax.text(x_center, y_offset+ystep*0.15, star, ha='center', va='bottom', fontsize=12)
        y_offset += ystep  # increment to avoid overlap

plt.tight_layout()
plt.savefig("regionwise_enrichment_comparison.png", dpi=300)
plt.show()

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu

print("Current working dir:", os.getcwd())

# Directory holding the per-immunization CSV exports read below
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Define regions of interest (bounds inclusive on both ends)
regions = {
    "380-390": (380, 390),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-510": (470, 510),
}

# Immunization groups, in the left-to-right order used for plotting
immunizations = ['Polyclonal_Ab', 'Neutralizing_Ab', 'wildtype_RBD', 'Mutant_RBD']

# Display titles (multi-line x-tick labels)
title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    # Beta is Pango lineage B.1.351 (this label read "B.1.135").
    # NOTE(review): confirm Mutant_RBD is the Beta RBD.
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.351'
}

# Fixed column layout of the pairwise-test table, so an empty table still sorts/saves
result_columns = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2',
                  'Median1', 'Median2', 'U_stat', 'p_value', 'n1', 'n2']

# Load and transform datasets
data_dict = {}
for imm in immunizations:
    csv_path = os.path.join(output_dir, f"{imm}_data.csv")
    df = pd.read_csv(csv_path)

    # log10 is undefined for non-positive ratios; NaN ratios are dropped here too
    df = df[df['Enrichment_Ratio'] > 0].copy()
    df['Log10_Enrichment'] = np.log10(df['Enrichment_Ratio'])
    data_dict[imm] = df

# Per-site values for plotting, and pairwise test results
plot_data = []
results = []

for region_name, (start, end) in regions.items():
    region_values = {}

    for imm, df in data_dict.items():
        df_region = df[(df['Spike_AS_Position'] >= start) & (df['Spike_AS_Position'] <= end)]
        vals = df_region['Log10_Enrichment'].dropna()
        region_values[imm] = vals

        for val in vals:
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Log10_Enrichment': val
            })

    # Pairwise two-sided Mann-Whitney U tests on the log10 data
    for (imm1, imm2) in combinations(region_values.keys(), 2):
        vals1, vals2 = region_values[imm1], region_values[imm2]
        # mannwhitneyu cannot test an empty sample; skip it as the per-barcode cells do
        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} in {region_name} due to zero data size.")
            continue
        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')
        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Save statistics, most significant first
results_df = pd.DataFrame(results, columns=result_columns).sort_values(by='p_value')
results_df.to_csv("pairwise_enrichment_log10_stats.csv", index=False)

# Convert plot data to DataFrame (explicit columns so an empty result still has them)
plot_df = pd.DataFrame(plot_data, columns=['Region', 'Immunization', 'Log10_Enrichment'])

# Map a p-value onto the conventional significance-star notation
def pval_to_stars(p):
    """Return significance stars for a p-value.

    '***' if p < 0.001, '**' if p < 0.01, '*' if p < 0.05, otherwise ''.
    NaN also yields '' because every comparison with NaN is False.
    """
    thresholds = ((0.001, '***'), (0.01, '**'), (0.05, '*'))
    return next((stars for cutoff, stars in thresholds if p < cutoff), '')

# Plot: one panel per region; panels share the y-axis
sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

y_step = 0.2  # vertical distance between stacked significance brackets
# With sharey=True a per-panel set_ylim is overwritten by the next panel, so the
# shared limits are computed over all regions and applied once after drawing.
y_bottom = plot_df['Log10_Enrichment'].min() - 0.5
y_top = plot_df['Log10_Enrichment'].max() + 0.7

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]
    # Fix the category order so x positions match immunizations.index() below
    sns.boxplot(data=sub_df, x='Immunization', y='Log10_Enrichment', ax=ax, showfliers=False, order=immunizations)
    sns.swarmplot(data=sub_df, x='Immunization', y='Log10_Enrichment', ax=ax, color=".25", size=3, order=immunizations)

    ax.set_title(region_name)
    ax.set_ylabel('log₁₀(Enrichment)', fontsize=14)
    ax.set_xlabel('')

    # Relabel x-ticks with display titles (same order as `order=` above)
    ax.set_xticklabels([title_map.get(imm, imm) for imm in immunizations], rotation=0)

    # Brackets start just above this panel's highest point
    region_results = results_df[results_df['Region'] == region_name]
    y_offset = sub_df['Log10_Enrichment'].max() + 0.4

    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        x1 = immunizations.index(row['Group1'])
        x2 = immunizations.index(row['Group2'])
        x_center = (x1 + x2) / 2
        ax.plot([x1, x1, x2, x2], [y_offset, y_offset + 0.05, y_offset + 0.05, y_offset], lw=1.5, c='k')
        ax.text(x_center, y_offset + 0.07, star, ha='center', va='bottom', fontsize=12)
        y_offset += y_step
        # Leave room above the highest bracket for its star label
        y_top = max(y_top, y_offset + 0.1)

# One call sets the y-range of every panel (sharey=True)
axes[0].set_ylim(bottom=y_bottom, top=y_top)

plt.tight_layout()
plt.savefig("regionwise_log10_enrichment_comparison.png", dpi=300)
plt.show()


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu

# Per-barcode mean enrichment across Spike regions, restricted to sites with
# Enrichment_Ratio >= 1. Expects `df_total` from an earlier cell.

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# Spike amino-acid windows (bounds inclusive on both ends)
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-510": (470, 510),
}
df_total2 = df_total[df_total['Enrichment_Ratio'] >= 1]
# Rows with a missing immunization can never match `== imm`, so leave them out
immunizations = df_total2['immunization'].dropna().unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    # Beta is Pango lineage B.1.351 (this label read "B.1.135").
    # NOTE(review): confirm Mutant_RBD is the Beta RBD.
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.351',
    'Library_ctrl': 'Lib\nCtrl'
}

# Fixed column layouts so empty outputs still have the columns later code indexes
result_columns = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2',
                  'Median1', 'Median2', 'U_stat', 'p_value', 'n1', 'n2']
plot_columns = ['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode']

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values = {}

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        if 'barcode' not in df_region.columns:
            raise ValueError(f"'barcode' column missing after filtering for immunization {imm} in region {region_name}!")

        grouped = df_region.groupby('barcode')
        print(f"  {imm}: {len(grouped)} unique barcodes in region")

        # One row per barcode: mean/std of its site-level ratios and number of sites
        barcode_stats = (
            grouped['Enrichment_Ratio']
            .agg(Mean_Enrichment='mean', Std_Enrichment='std', n='count')
            .reset_index()
        )

        region_values[imm] = barcode_stats['Mean_Enrichment']

        for _, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Mean_Enrichment'],
                'Std': row['Std_Enrichment'],
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    # Pairwise two-sided Mann-Whitney U tests on the per-barcode means
    for imm1, imm2 in combinations(region_values.keys(), 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

results_df = pd.DataFrame(results, columns=result_columns).sort_values(by='p_value')
results_df.to_csv("pairwise_enrichment_stats.csv", index=False)
print("\nSaved statistical test results to 'pairwise_enrichment_stats.csv'")

plot_df = pd.DataFrame(plot_data, columns=plot_columns)

print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Return significance stars for a p-value.

    '***' if p < 0.001, '**' if p < 0.01, '*' if p < 0.05, otherwise ''.
    NaN also yields '' because every comparison with NaN is False.
    """
    thresholds = ((0.001, '***'), (0.01, '**'), (0.05, '*'))
    return next((stars for cutoff, stars in thresholds if p < cutoff), '')

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Explicit left-to-right plotting order (all five groups, including 'wildtype_RBD');
# immunizations not listed here are neither drawn nor annotated
immunizations_sorted = ['Library_ctrl', 'Neutralizing_Ab', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

# Significance brackets are stacked upwards from ymax + ystep (= 44).
ymax = 40
ystep = ymax * 0.1
# The data range is 0-35. The first bracket already sits above that, so the
# shared y-limit is raised after drawing to keep every bracket visible.
y_top = 35

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)

    ax.set_title(region_name)
    ax.set_ylabel('Mean Enrichment per Barcode')
    ax.set_xlabel('')

    new_labels = [title_map.get(label, label) for label in immunizations_sorted]
    ax.set_xticklabels(new_labels, rotation=0)

    region_results = results_df[results_df['Region'] == region_name]
    y_offset = ymax + ystep
    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        x_center = (x1 + x2) / 2
        ax.plot([x1, x1, x2, x2], [y_offset, y_offset+ystep*0.1, y_offset+ystep*0.1, y_offset], lw=1.5, c='k')
        ax.text(x_center, y_offset+ystep*0.15, star, ha='center', va='bottom', fontsize=12)
        y_offset += ystep
        y_top = max(y_top, y_offset)

# One call sets the y-range of every panel (sharey=True)
axes[0].set_ylim(0, y_top)

plt.tight_layout()
plt.savefig("regionwise_enrichment_comparison_per_barcode.png", dpi=300)
plt.show()

# Save per-barcode enrichment data (plot_df was built from plot_data in the stats step)
plot_df.to_csv("per_barcode_enrichment_stats.csv", index=False)
print("Saved per-barcode enrichment stats to 'per_barcode_enrichment_stats.csv'")


In [None]:
# Mann-Whitney U p-values across regions, using per-barcode means of all enrichment values (no >= 1 filter)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu

# Per-barcode mean enrichment across Spike regions using ALL enrichment values
# (no >= 1 filter). Expects `df_total` from an earlier cell.

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# Spike amino-acid windows (bounds inclusive on both ends)
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-510": (470, 510),
}

# Rows with a missing immunization can never match `== imm`, so leave them out
immunizations = df_total['immunization'].dropna().unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    # Beta is Pango lineage B.1.351 (this label read "B.1.135").
    # NOTE(review): confirm Mutant_RBD is the Beta RBD.
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.351',
    'Library_ctrl': 'Lib\nCtrl'
}

# Fixed column layouts so empty outputs still have the columns later code indexes
result_columns = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2',
                  'Median1', 'Median2', 'U_stat', 'p_value', 'n1', 'n2']
plot_columns = ['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode']

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values = {}

    for imm in immunizations:
        df_imm = df_total[df_total['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        if 'barcode' not in df_region.columns:
            raise ValueError(f"'barcode' column missing after filtering for immunization {imm} in region {region_name}!")

        grouped = df_region.groupby('barcode')
        print(f"  {imm}: {len(grouped)} unique barcodes in region")

        # One row per barcode: mean/std of its site-level ratios and number of sites
        barcode_stats = (
            grouped['Enrichment_Ratio']
            .agg(Mean_Enrichment='mean', Std_Enrichment='std', n='count')
            .reset_index()
        )

        region_values[imm] = barcode_stats['Mean_Enrichment']

        for _, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Mean_Enrichment'],
                'Std': row['Std_Enrichment'],
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    # Pairwise two-sided Mann-Whitney U tests on the per-barcode means
    for imm1, imm2 in combinations(region_values.keys(), 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Unfiltered results get an "_all" suffix. This avoids overwriting the ">= 1"
# results that other cells save as 'pairwise_enrichment_stats.csv'.
results_df = pd.DataFrame(results, columns=result_columns).sort_values(by='p_value')
results_df.to_csv("pairwise_enrichment_stats_all.csv", index=False)
print("\nSaved statistical test results to 'pairwise_enrichment_stats_all.csv'")

plot_df = pd.DataFrame(plot_data, columns=plot_columns)

print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Return significance stars for a p-value.

    '***' if p < 0.001, '**' if p < 0.01, '*' if p < 0.05, otherwise ''.
    NaN also yields '' because every comparison with NaN is False.
    """
    thresholds = ((0.001, '***'), (0.01, '**'), (0.05, '*'))
    return next((stars for cutoff, stars in thresholds if p < cutoff), '')

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Explicit left-to-right plotting order; groups not listed (e.g. 'Neutralizing_Ab')
# are neither drawn nor annotated
immunizations_sorted = ['Library_ctrl', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

# Significance brackets are stacked upwards from ymax + ystep
ymax = 15
ystep = ymax * 0.1
# Default data range is 0-22. The shared y-limit is raised after drawing if the
# stacked brackets would otherwise extend past it and be clipped.
y_top = 22

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)

    ax.set_title(region_name)
    ax.set_ylabel('Mean Enrichment per Barcode')
    ax.set_xlabel('')

    new_labels = [title_map.get(label, label) for label in immunizations_sorted]
    ax.set_xticklabels(new_labels, rotation=0)

    region_results = results_df[results_df['Region'] == region_name]
    y_offset = ymax + ystep
    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        x_center = (x1 + x2) / 2
        ax.plot([x1, x1, x2, x2], [y_offset, y_offset+ystep*0.1, y_offset+ystep*0.1, y_offset], lw=1.5, c='k')
        ax.text(x_center, y_offset+ystep*0.15, star, ha='center', va='bottom', fontsize=12)
        y_offset += ystep
        y_top = max(y_top, y_offset)

# One call sets the y-range of every panel (sharey=True)
axes[0].set_ylim(0, y_top)

plt.tight_layout()
# This cell plots ALL enrichment values, so it must not reuse the "_Above1" name of the >= 1 cell
plt.savefig("regionwise_enrichment_comparison_per_barcode_all.png", dpi=300)
plt.show()

# Save per-barcode enrichment data (plot_df was built from plot_data in the stats step)
plot_df.to_csv("per_barcode_enrichment_stats_all.csv", index=False)
print("Saved per-barcode enrichment stats to 'per_barcode_enrichment_stats_all.csv'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for _, row in results_df.iterrows():
    star = pval_to_stars(row['p_value'])
    print(f"Region: {row['Region']}, {row['Group1']} vs {row['Group2']}, "
          f"p = {row['p_value']:.4g} {star}, U = {row['U_stat']:.2f}, "
          f"n1 = {row['n1']}, n2 = {row['n2']}")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu
import numpy as np

# Per-barcode log10 mean enrichment (sites with Enrichment_Ratio >= 1) in
# 10-residue Spike windows. Expects `df_total` from an earlier cell.

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions (bounds inclusive on both ends) ---
# NOTE(review): there is no 410-420 window; confirm this gap is intentional.
regions = {
    "380-390": (380, 390),
    "390-400": (390, 400),
    "400-410": (400, 410),
    "420-430": (420, 430),
    "430-440": (430, 440),
    "440-450": (440, 450),
    "450-460": (450, 460),
    "460-470": (460, 470),
    "470-480": (470, 480),
    "480-490": (480, 490),
    "490-500": (490, 500),
    "500-510": (500, 510),
}

# --- Synthetic "Un-enrich. Libr" sample: Enrichment_Ratio = 1 at every region position ---
positions = sorted({pos for start, end in regions.values() for pos in range(start, end + 1)})

new_sample_df = pd.DataFrame({
    'barcode': [f"UnEnrichLib_{pos}" for pos in positions],
    'immunization': 'Un-enrich. Libr',
    'Spike_AS_Position': positions,
    'Enrichment_Ratio': 1.0,
})

# Combine into a cell-local frame rather than reassigning the global df_total.
# Appending to the global would leak these synthetic rows into later cells'
# statistics and duplicate them on every re-run.
df_total_with_unenriched = pd.concat([df_total, new_sample_df], ignore_index=True)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# Keep only Enrichment_Ratio >= 1
df_total2 = df_total_with_unenriched[df_total_with_unenriched['Enrichment_Ratio'] >= 1]

# Groups to compare; the synthetic sample and missing immunizations are excluded
immunizations = [imm for imm in df_total2['immunization'].dropna().unique() if imm != 'Un-enrich. Libr']
print("Immunizations found in data (excluding 'Un-enrich. Libr'):", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    #'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    # Beta is Pango lineage B.1.351 (this label read "B.1.135").
    # NOTE(review): confirm Mutant_RBD is the Beta RBD.
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.351',
    'Library_ctrl': 'Lib\nCtrl'
}

# Fixed column layouts so empty outputs still have the columns later code indexes
result_columns = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2',
                  'Median1', 'Median2', 'U_stat', 'p_value', 'n1', 'n2']
plot_columns = ['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode']

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values = {}

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        if 'barcode' not in df_region.columns:
            raise ValueError(f"'barcode' column missing after filtering for immunization {imm} in region {region_name}!")

        grouped = df_region.groupby('barcode')
        print(f"  {imm}: {len(grouped)} unique barcodes in region")

        # One row per barcode: mean/std of its site-level ratios and number of sites
        barcode_stats = (
            grouped['Enrichment_Ratio']
            .agg(Mean_Enrichment='mean', Std_Enrichment='std', n='count')
            .reset_index()
        )

        # Remove zero or negative means before log transform (should be >=1 anyway)
        barcode_stats = barcode_stats[barcode_stats['Mean_Enrichment'] > 0].copy()
        barcode_stats['Mean_Enrichment'] = np.log10(barcode_stats['Mean_Enrichment'])

        region_values[imm] = barcode_stats['Mean_Enrichment']

        for _, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Mean_Enrichment'],
                'Std': row['Std_Enrichment'],  # Note: Std is still on original scale
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    # Pairwise two-sided Mann-Whitney U tests on the log10 per-barcode means
    for imm1, imm2 in combinations(region_values.keys(), 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

results_df = pd.DataFrame(results, columns=result_columns).sort_values(by='p_value')
results_df.to_csv("pairwise_enrichment_stats_log10.csv", index=False)
print("\nSaved statistical test results to 'pairwise_enrichment_stats_log10.csv'")

plot_df = pd.DataFrame(plot_data, columns=plot_columns)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Return significance stars for a p-value.

    '***' if p < 0.001, '**' if p < 0.01, '*' if p < 0.05, otherwise ''.
    NaN also yields '' because every comparison with NaN is False.
    """
    thresholds = ((0.001, '***'), (0.01, '**'), (0.05, '*'))
    return next((stars for cutoff, stars in thresholds if p < cutoff), '')

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Plotting order; the synthetic 'Un-enrich. Libr' sample is not plotted
immunizations_sorted = ['Library_ctrl', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

# Significance brackets are stacked upwards from ymax - ystep. y_top grows past
# ymax when needed, so brackets beyond the 0-ymax data range are not clipped.
ymax = 5
ystep = ymax * 0.1
y_top = ymax

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)
    ax.set_title(region_name, fontsize=16)

    ax.set_ylabel('log10 (Mean Enrichment per Barcode)', fontsize=14)
    ax.set_xlabel('')

    new_labels = [title_map.get(label, label) for label in immunizations_sorted]
    ax.set_xticklabels(new_labels, rotation=0)
    ax.tick_params(axis='y', labelsize=12)  # increase y-axis tick label size

    region_results = results_df[results_df['Region'] == region_name]
    y_offset = ymax - ystep

    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        x_center = (x1 + x2) / 2
        ax.plot([x1, x1, x2, x2], [y_offset, y_offset+ystep*0.1, y_offset+ystep*0.1, y_offset], lw=1.5, c='k')
        ax.text(x_center, y_offset+ystep*0.15, star, ha='center', va='bottom', fontsize=16, fontweight='bold')

        y_offset += ystep
        y_top = max(y_top, y_offset)

# One call sets the y-range of every panel (sharey=True)
axes[0].set_ylim(0, y_top)

plt.tight_layout()
plt.savefig("regionwise_enrichment_comparison_per_barcode_log10.png", dpi=300)
plt.show()

plot_df.to_csv("per_barcode_enrichment_stats_log10.csv", index=False)
print("Saved per-barcode enrichment stats to 'per_barcode_enrichment_stats_log10.csv'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for _, row in results_df.iterrows():
    star = pval_to_stars(row['p_value'])
    print(f"Region: {row['Region']}, {row['Group1']} vs {row['Group2']}, "
          f"p = {row['p_value']:.4g} {star}, U = {row['U_stat']:.2f}, "
          f"n1 = {row['n1']}, n2 = {row['n2']}")


In [None]:
# Mann-Whitney U p-values across regions, using per-barcode means of enrichment values >= 1

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu

# Per-barcode mean enrichment across Spike regions, restricted to sites with
# Enrichment_Ratio >= 1. Expects `df_total` from an earlier cell.

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# Spike amino-acid windows (bounds inclusive on both ends)
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-510": (470, 510),
}
df_total2 = df_total[df_total['Enrichment_Ratio'] >= 1]
# Rows with a missing immunization can never match `== imm`, so leave them out
immunizations = df_total2['immunization'].dropna().unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    # Beta is Pango lineage B.1.351 (this label read "B.1.135").
    # NOTE(review): confirm Mutant_RBD is the Beta RBD.
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.351',
    'Library_ctrl': 'Lib\nCtrl'
}

# Fixed column layouts so empty outputs still have the columns later code indexes
result_columns = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2',
                  'Median1', 'Median2', 'U_stat', 'p_value', 'n1', 'n2']
plot_columns = ['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode']

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values = {}

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        if 'barcode' not in df_region.columns:
            raise ValueError(f"'barcode' column missing after filtering for immunization {imm} in region {region_name}!")

        grouped = df_region.groupby('barcode')
        print(f"  {imm}: {len(grouped)} unique barcodes in region")

        # One row per barcode: mean/std of its site-level ratios and number of sites
        barcode_stats = (
            grouped['Enrichment_Ratio']
            .agg(Mean_Enrichment='mean', Std_Enrichment='std', n='count')
            .reset_index()
        )

        region_values[imm] = barcode_stats['Mean_Enrichment']

        for _, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Mean_Enrichment'],
                'Std': row['Std_Enrichment'],
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    # Pairwise two-sided Mann-Whitney U tests on the per-barcode means
    for imm1, imm2 in combinations(region_values.keys(), 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

results_df = pd.DataFrame(results, columns=result_columns).sort_values(by='p_value')
results_df.to_csv("pairwise_enrichment_stats.csv", index=False)
print("\nSaved statistical test results to 'pairwise_enrichment_stats.csv'")

plot_df = pd.DataFrame(plot_data, columns=plot_columns)

print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Return significance stars for a p-value.

    '***' if p < 0.001, '**' if p < 0.01, '*' if p < 0.05, otherwise ''.
    NaN also yields '' because every comparison with NaN is False.
    """
    thresholds = ((0.001, '***'), (0.01, '**'), (0.05, '*'))
    return next((stars for cutoff, stars in thresholds if p < cutoff), '')

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Explicit left-to-right plotting order; groups not listed (e.g. 'Neutralizing_Ab')
# are neither drawn nor annotated
immunizations_sorted = ['Library_ctrl', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

# Significance brackets are stacked upwards from ymax + ystep
ymax = 15
ystep = ymax * 0.1
# Default data range is 0-22. The shared y-limit is raised after drawing if the
# stacked brackets would otherwise extend past it and be clipped.
y_top = 22

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)

    ax.set_title(region_name)
    ax.set_ylabel('Mean Enrichment per Barcode')
    ax.set_xlabel('')

    new_labels = [title_map.get(label, label) for label in immunizations_sorted]
    ax.set_xticklabels(new_labels, rotation=0)

    region_results = results_df[results_df['Region'] == region_name]
    y_offset = ymax + ystep
    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        x_center = (x1 + x2) / 2
        ax.plot([x1, x1, x2, x2], [y_offset, y_offset+ystep*0.1, y_offset+ystep*0.1, y_offset], lw=1.5, c='k')
        ax.text(x_center, y_offset+ystep*0.15, star, ha='center', va='bottom', fontsize=12)
        y_offset += ystep
        y_top = max(y_top, y_offset)

# One call sets the y-range of every panel (sharey=True)
axes[0].set_ylim(0, y_top)

plt.tight_layout()
plt.savefig("regionwise_enrichment_comparison_per_barcode_Above1.png", dpi=300)
plt.show()

# Save per-barcode enrichment data (plot_df was built from plot_data in the stats step)
plot_df.to_csv("per_barcode_enrichment_stats.csv", index=False)
print("Saved per-barcode enrichment stats to 'per_barcode_enrichment_stats.csv'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for _, row in results_df.iterrows():
    star = pval_to_stars(row['p_value'])
    print(f"Region: {row['Region']}, {row['Group1']} vs {row['Group2']}, "
          f"p = {row['p_value']:.4g} {star}, U = {row['U_stat']:.2f}, "
          f"n1 = {row['n1']}, n2 = {row['n2']}")


## Mean enrichment ratio (ER) per barcode, larger Spike regions

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu
import numpy as np

# Assume df_total is your original DataFrame with all data combined

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions ---
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-510": (470, 510),
}

# --- Add new sample "Un-enrich. Libr" with enrichment=1 at all region positions ---

positions = []
for start, end in regions.values():
    positions.extend(range(start, end + 1))
positions = sorted(set(positions))

new_sample_rows = []
for pos in positions:
    new_sample_rows.append({
        'barcode': f"UnEnrichLib_{pos}",
        'immunization': 'Un-enrich. Libr',
        'Spike_AS_Position': pos,
        'Enrichment_Ratio': 1.0
    })

new_sample_df = pd.DataFrame(new_sample_rows)

# Append new sample rows to df_total
df_total = pd.concat([df_total, new_sample_df], ignore_index=True)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# Filter df_total to keep Enrichment_Ratio >= 1
df_total2 = df_total[(df_total['Enrichment_Ratio'] >= 1)]

# Update immunizations array to include the new sample
immunizations = df_total2['immunization'].unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    #'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl'
    #'Un-enrich. Libr': 'Un-enrich.\nLibrary'
}

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")
    
    region_values = {}

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")
        
        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        if 'barcode' not in df_region.columns:
            raise ValueError(f"'barcode' column missing after filtering for immunization {imm} in region {region_name}!")

        grouped = df_region.groupby('barcode')
        print(f"  {imm}: {len(grouped)} unique barcodes in region")

        barcode_stats = grouped['Enrichment_Ratio'].agg(['mean', 'std', 'count']).reset_index()
        barcode_stats.rename(columns={'mean': 'Mean_Enrichment', 'std': 'Std_Enrichment', 'count': 'n'}, inplace=True)

        region_values[imm] = barcode_stats['Mean_Enrichment']

        for _, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Mean_Enrichment'],
                'Std': row['Std_Enrichment'],
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    for imm1, imm2 in combinations(region_values.keys(), 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()
    
        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")
    
        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue
    
        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')
    
        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

results_df = pd.DataFrame(results).sort_values(by='p_value')
results_df.to_csv("pairwise_enrichment_stats.csv", index=False)
print("\nSaved statistical test results to 'pairwise_enrichment_stats.csv'")

plot_df = pd.DataFrame(plot_data)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    if p < 0.001:
        return '***'
    elif p < 0.01:
        return '**'
    elif p < 0.05:
        return '*'
    else:
        return ''

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Make sure to include 'Un-enrich. Libr' in the sorted immunizations order for plotting
immunizations_sorted = ['Un-enrich. Libr','Library_ctrl', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)
    ax.set_title(region_name, fontsize=16)

    ax.set_ylabel('Mean Enrichment per Barcode', fontsize=14)
    ax.set_xlabel('')

    new_labels = [title_map.get(label, label) for label in immunizations_sorted]
    ax.set_xticklabels(new_labels, rotation=0)

    region_results = results_df[results_df['Region'] == region_name]
    ymax = 15
    ystep = (ymax * 0.1) if ymax > 0 else 1
    ax.set_ylim(0, 22)
    y_offset = ymax + ystep
    ax.tick_params(axis='y', labelsize=12)  # increase y-axis tick label size

    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star or 'Un-enrich. Libr' in [row['Group1'], row['Group2']]:
            continue
        if star:
            if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
                continue
            x1 = immunizations_sorted.index(row['Group1'])
            x2 = immunizations_sorted.index(row['Group2'])
            x_center = (x1 + x2) / 2
            ax.plot([x1, x1, x2, x2], [y_offset, y_offset+ystep*0.1, y_offset+ystep*0.1, y_offset], lw=1.5, c='k')
            ax.text(x_center, y_offset+ystep*0.15, star, ha='center', va='bottom', fontsize=16, fontweight='bold')

            y_offset += ystep

plt.tight_layout()
plt.savefig("regionwise_enrichment_comparison_per_barcode_Above1_Lib.png", dpi=300)
plt.show()

plot_df.to_csv("per_barcode_enrichment_stats.csv", index=False)
print("Saved per-barcode enrichment stats to 'per_barcode_enrichment_stats.csv'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for _, row in results_df.iterrows():
    star = pval_to_stars(row['p_value'])
    print(f"Region: {row['Region']}, {row['Group1']} vs {row['Group2']}, "
          f"p = {row['p_value']:.4g} {star}, U = {row['U_stat']:.2f}, "
          f"n1 = {row['n1']}, n2 = {row['n2']}")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu
import numpy as np

# Assume df_total is your original DataFrame with all data combined

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions ---
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-490": (470, 490),
    "490-510": (490, 510),
}

# --- Add new sample "Un-enrich. Libr" with enrichment=1 at all region positions ---

positions = []
for start, end in regions.values():
    positions.extend(range(start, end + 1))
positions = sorted(set(positions))

new_sample_rows = []
for pos in positions:
    new_sample_rows.append({
        'barcode': f"UnEnrichLib_{pos}",
        'immunization': 'Un-enrich. Libr',
        'Spike_AS_Position': pos,
        'Enrichment_Ratio': 1.0
    })

new_sample_df = pd.DataFrame(new_sample_rows)

# Append new sample rows to df_total
df_total = pd.concat([df_total, new_sample_df], ignore_index=True)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# Filter df_total to keep Enrichment_Ratio >= 1
df_total2 = df_total[(df_total['Enrichment_Ratio'] >= 1)]

# Update immunizations array to include the new sample
immunizations = df_total2['immunization'].unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\n2',
    'Un-enrich. Libr': 'Un-enrich.\nLibrary'
}

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")
    
    region_values = {}

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")
        
        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        if 'barcode' not in df_region.columns:
            raise ValueError(f"'barcode' column missing after filtering for immunization {imm} in region {region_name}!")

        grouped = df_region.groupby('barcode')
        print(f"  {imm}: {len(grouped)} unique barcodes in region")

        barcode_stats = grouped['Enrichment_Ratio'].agg(['mean', 'std', 'count']).reset_index()
        barcode_stats.rename(columns={'mean': 'Mean_Enrichment', 'std': 'Std_Enrichment', 'count': 'n'}, inplace=True)

        region_values[imm] = barcode_stats['Mean_Enrichment']

        for _, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Mean_Enrichment'],
                'Std': row['Std_Enrichment'],
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    for imm1, imm2 in combinations(region_values.keys(), 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()
    
        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")
    
        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue
    
        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')
    
        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

results_df = pd.DataFrame(results).sort_values(by='p_value')
results_df.to_csv("pairwise_enrichment_stats.csv", index=False)
print("\nSaved statistical test results to 'pairwise_enrichment_stats.csv'")

plot_df = pd.DataFrame(plot_data)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    if p < 0.001:
        return '***'
    elif p < 0.01:
        return '**'
    elif p < 0.05:
        return '*'
    else:
        return ''

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Make sure to include 'Un-enrich. Libr' in the sorted immunizations order for plotting
immunizations_sorted = ['Un-enrich. Libr','Library_ctrl', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)
    ax.set_title(region_name, fontsize=16)

    ax.set_ylabel('Mean Enrichment per Barcode', fontsize=14)
    ax.set_xlabel('')

    new_labels = [title_map.get(label, label) for label in immunizations_sorted]
    ax.set_xticklabels(new_labels, rotation=0)

    region_results = results_df[results_df['Region'] == region_name]
    ymax = 15
    ystep = (ymax * 0.1) if ymax > 0 else 1
    ax.set_ylim(0, 26)
    y_offset = ymax + ystep
    ax.tick_params(axis='y', labelsize=12)  # increase y-axis tick label size

    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star or 'Un-enrich. Libr' in [row['Group1'], row['Group2']]:
            continue
        if star:
            if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
                continue
            x1 = immunizations_sorted.index(row['Group1'])
            x2 = immunizations_sorted.index(row['Group2'])
            x_center = (x1 + x2) / 2
            ax.plot([x1, x1, x2, x2], [y_offset, y_offset+ystep*0.1, y_offset+ystep*0.1, y_offset], lw=1.5, c='k')
            ax.text(x_center, y_offset+ystep*0.15, star, ha='center', va='bottom', fontsize=18, fontweight='bold')

            y_offset += ystep

plt.tight_layout()
plt.savefig("regionwise_ER_per_barcode_Above1_Lib.png", dpi=300)
plt.show()

plot_df.to_csv("per_barcode_enrichment_stats.csv", index=False)
print("Saved per-barcode enrichment stats to 'per_barcode_enrichment_stats.csv'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for _, row in results_df.iterrows():
    star = pval_to_stars(row['p_value'])
    print(f"Region: {row['Region']}, {row['Group1']} vs {row['Group2']}, "
          f"p = {row['p_value']:. 4g} {star}, U = {row['U_stat']:.2f}, "
          f"n1 = {row['n1']}, n2 = {row['n2']}")


In [None]:
### Log10 Comparison

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu
import numpy as np

# Assume df_total is your original DataFrame with all data combined
print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions ---
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-490": (470, 490),
    "490-510": (490, 510),
}

# --- Add new sample "Un-enrich. Libr" with enrichment=1 at all region positions ---
positions = sorted({pos for start, end in regions.values() for pos in range(start, end + 1)})

new_sample_rows = [{
    'barcode': f"UnEnrichLib_{pos}",
    'immunization': 'Un-enrich. Libr',
    'Spike_AS_Position': pos,
    'Enrichment_Ratio': 1.0
} for pos in positions]

new_sample_df = pd.DataFrame(new_sample_rows)
df_total = pd.concat([df_total, new_sample_df], ignore_index=True)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# --- Filter and log-transform ---
df_total2 = df_total[df_total['Enrichment_Ratio'] >= 1].copy()
df_total2['Log10_Enrichment'] = np.log10(df_total2['Enrichment_Ratio'])

immunizations = df_total2['immunization'].unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\n2',
    'Un-enrich. Libr': 'Un-enrich.\nLibrary'
}

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")
    region_values = {}

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        grouped = df_region.groupby('barcode')

        barcode_stats = grouped['Log10_Enrichment'].agg(['mean', 'std', 'count']).reset_index()
        barcode_stats.rename(columns={'mean': 'Mean_Log10_Enrichment', 'std': 'Std_Log10_Enrichment', 'count': 'n'}, inplace=True)

        region_values[imm] = barcode_stats['Mean_Log10_Enrichment']

        for _, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Log10_Enrichment': row['Mean_Log10_Enrichment'],
                'Std': row['Std_Log10_Enrichment'],
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    for imm1, imm2 in combinations(region_values.keys(), 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()
        if len(vals1) == 0 or len(vals2) == 0:
            continue
        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')
        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

results_df = pd.DataFrame(results).sort_values(by='p_value')
results_df.to_csv("pairwise_log10_enrichment_stats.csv", index=False)
print("\nSaved statistical test results to 'pairwise_log10_enrichment_stats.csv'")

plot_df = pd.DataFrame(plot_data)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    if p < 0.001:
        return '***'
    elif p < 0.01:
        return '**'
    elif p < 0.05:
        return '*'
    else:
        return ''

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)
immunizations_sorted = ['Un-enrich. Libr','Library_ctrl', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Log10_Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Log10_Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)

    ax.set_title(region_name, fontsize=16)
    ax.set_ylabel('log10 (Mean Enrichment per Barcode)', fontsize=14)
    ax.set_xlabel('')
    ax.set_xticklabels([title_map.get(label, label) for label in immunizations_sorted], rotation=0)
    ax.tick_params(axis='y', labelsize=12)
    ax.set_ylim(0, 1.2)  # adjust as needed based on data range

    region_results = results_df[results_df['Region'] == region_name]
    y_offset = 0.4
    ystep = 0.1
    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star or 'Un-enrich. Libr' in [row['Group1'], row['Group2']]:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        x_center = (x1 + x2) / 2
        ax.plot([x1, x1, x2, x2], [y_offset, y_offset+ystep*0.2, y_offset+ystep*0.2, y_offset], lw=1.5, c='k')
        ax.text(x_center, y_offset+ystep*0.3, star, ha='center', va='bottom', fontsize=18, fontweight='bold')
        y_offset += ystep

plt.tight_layout()
plt.savefig("regionwise_log10_ER_per_barcode.png", dpi=300)
plt.show()

plot_df.to_csv("per_barcode_log10_enrichment_stats.csv", index=False)
print("Saved per-barcode log10 enrichment stats to 'per_barcode_log10_enrichment_stats.csv'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for _, row in results_df.iterrows():
    star = pval_to_stars(row['p_value'])
    print(f"Region: {row['Region']}, {row['Group1']} vs {row['Group2']}, "
          f"p = {row['p_value']:. 4g} {star}, U = {row['U_stat']:.2f}, "
          f"n1 = {row['n1']}, n2 = {row['n2']}")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu
import numpy as np

# Assume df_total is your original DataFrame with all data combined
print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions ---
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-490": (470, 490),
    "490-510": (490, 510),
}

# --- Add new sample "Un-enrich. Libr" with enrichment=1 at all region positions ---
positions = sorted({pos for start, end in regions.values() for pos in range(start, end + 1)})

new_sample_rows = [{
    'barcode': f"UnEnrichLib_{pos}",
    'immunization': 'Un-enrich. Libr',
    'Spike_AS_Position': pos,
    'Enrichment_Ratio': 1.0
} for pos in positions]

new_sample_df = pd.DataFrame(new_sample_rows)
df_total = pd.concat([df_total, new_sample_df], ignore_index=True)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# --- Filter and log-transform ---
df_total2 = df_total[df_total['Enrichment_Ratio'] >= 0].copy()
df_total2['Log10_Enrichment'] = np.log10(df_total2['Enrichment_Ratio'])

immunizations = df_total2['immunization'].unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\n2',
    'Un-enrich. Libr': 'Un-enrich.\nLibrary'
}

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")
    region_values = {}

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        grouped = df_region.groupby('barcode')

        barcode_stats = grouped['Log10_Enrichment'].agg(['mean', 'std', 'count']).reset_index()
        barcode_stats.rename(columns={'mean': 'Mean_Log10_Enrichment', 'std': 'Std_Log10_Enrichment', 'count': 'n'}, inplace=True)

        region_values[imm] = barcode_stats['Mean_Log10_Enrichment']

        for _, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Log10_Enrichment': row['Mean_Log10_Enrichment'],
                'Std': row['Std_Log10_Enrichment'],
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    for imm1, imm2 in combinations(region_values.keys(), 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()
        if len(vals1) == 0 or len(vals2) == 0:
            continue
        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')
        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

results_df = pd.DataFrame(results).sort_values(by='p_value')
results_df.to_csv("pairwise_log10_enrichment_stats.csv", index=False)
print("\nSaved statistical test results to 'pairwise_log10_enrichment_stats.csv'")

plot_df = pd.DataFrame(plot_data)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    if p < 0.001:
        return '***'
    elif p < 0.01:
        return '**'
    elif p < 0.05:
        return '*'
    else:
        return ''

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)
immunizations_sorted = ['Un-enrich. Libr','Library_ctrl', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

for ax, (region_name, _) in zip(axes, regions.items()):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Log10_Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Log10_Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)

    ax.set_title(region_name, fontsize=16)
    ax.set_ylabel('log10 (Mean Enrichment per Barcode)', fontsize=14)
    ax.set_xlabel('')
    ax.set_xticklabels([title_map.get(label, label) for label in immunizations_sorted], rotation=0)
    ax.tick_params(axis='y', labelsize=12)
    ax.set_ylim(-0.6, 0.8)  # adjust as needed based on data range

    region_results = results_df[results_df['Region'] == region_name]
    y_offset = 0.4
    ystep = 0.12
    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star or 'Un-enrich. Libr' in [row['Group1'], row['Group2']]:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        x_center = (x1 + x2) / 2
        ax.plot([x1, x1, x2, x2], [y_offset, y_offset+ystep*0.2, y_offset+ystep*0.2, y_offset], lw=1.5, c='k')
        ax.text(x_center, y_offset+ystep*0.3, star, ha='center', va='bottom', fontsize=18, fontweight='bold')
        y_offset += ystep

plt.tight_layout()
plt.savefig("regionwise_log10_ALLER_per_barcode.png", dpi=300)
plt.show()

plot_df.to_csv("per_barcode_log10_enrichment_stats.csv", index=False)
print("Saved per-barcode log10 enrichment stats to 'per_barcode_log10_enrichment_stats.csv'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for _, row in results_df.iterrows():
    star = pval_to_stars(row['p_value'])
    print(f"Region: {row['Region']}, {row['Group1']} vs {row['Group2']}, "
          f"p = {row['p_value']:. 4g} {star}, U = {row['U_stat']:.2f}, "
          f"n1 = {row['n1']}, n2 = {row['n2']}")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu
import numpy as np

# Assume df_total is your original DataFrame with all data combined

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions ---
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-490": (470, 490),
    "490-510": (490, 510),
}

# --- Add new sample "Un-enrich. Libr" with enrichment=1 at all region positions ---

positions = []
for start, end in regions.values():
    positions.extend(range(start, end + 1))
positions = sorted(set(positions))

new_sample_rows = []
for pos in positions:
    new_sample_rows.append({
        'barcode': f"UnEnrichLib_{pos}",
        'immunization': 'Un-enrich. Libr',
        'Spike_AS_Position': pos,
        'Enrichment_Ratio': 1.0
    })

new_sample_df = pd.DataFrame(new_sample_rows)

# Append new sample rows to df_total
df_total = pd.concat([df_total, new_sample_df], ignore_index=True)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# Filter df_total to keep Enrichment_Ratio >= 1
df_total2 = df_total[(df_total['Enrichment_Ratio'] >= 1)]

# Update immunizations array to include the new sample
immunizations = df_total2['immunization'].unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl',
    'Un-enrich. Libr': 'Un-enrich.\nLibrary'
}

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")
    
    region_values = {}

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")
        
        df_region = df_imm[(df_imm['Spike_AS_Position'] >= start) & (df_imm['Spike_AS_Position'] <= end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        if 'barcode' not in df_region.columns:
            raise ValueError(f"'barcode' column missing after filtering for immunization {imm} in region {region_name}!")

        grouped = df_region.groupby('barcode')
        print(f"  {imm}: {len(grouped)} unique barcodes in region")

        barcode_stats = grouped['Enrichment_Ratio'].agg(['mean', 'std', 'count']).reset_index()
        barcode_stats.rename(columns={'mean': 'Mean_Enrichment', 'std': 'Std_Enrichment', 'count': 'n'}, inplace=True)

        region_values[imm] = barcode_stats['Mean_Enrichment']

        for _, row in barcode_stats.iterrows():
            plot_data.append({
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row['Mean_Enrichment'],
                'Std': row['Std_Enrichment'],
                'n_sites': row['n'],
                'barcode': row['barcode']
            })

    for imm1, imm2 in combinations(region_values.keys(), 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()
    
        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")
    
        if len(vals1) == 0 or len(vals2) == 0:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue
    
        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')
    
        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

results_df = pd.DataFrame(results).sort_values(by='p_value')
results_df.to_csv("pairwise_enrichment_stats.csv", index=False)
print("\nSaved statistical test results to 'pairwise_enrichment_stats.csv'")

plot_df = pd.DataFrame(plot_data)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Translate a p-value into a significance marker for plot annotations.

    Returns '***' for p < 0.001, '**' for p < 0.01, '*' for p < 0.05 and an
    empty string otherwise (including NaN, since every comparison with NaN is False).
    """
    for cutoff, marker in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < cutoff:
            return marker
    return ''

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Plotting order, including the synthetic 'Un-enrich. Libr' baseline; groups missing from
# this list are neither drawn nor bracketed.
immunizations_sorted = ['Un-enrich. Libr', 'Library_ctrl', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

# Significance brackets are stacked upwards from BRACKET_BASE in steps of BRACKET_STEP.
BRACKET_BASE = 15
BRACKET_STEP = BRACKET_BASE * 0.1
Y_LIMIT = 26  # default upper limit of the shared y-axis
y_top = Y_LIMIT

for ax, region_name in zip(axes, regions):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)

    ax.set_title(region_name)
    ax.set_ylabel('Mean Enrichment per Barcode')
    ax.set_xlabel('')
    ax.set_xticklabels([title_map.get(label, label) for label in immunizations_sorted], rotation=0)

    # One bracket per significant pair, in results_df order (sorted by p-value).
    region_results = results_df[results_df['Region'] == region_name]
    y_offset = BRACKET_BASE + BRACKET_STEP
    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        bracket_top = y_offset + BRACKET_STEP * 0.1
        ax.plot([x1, x1, x2, x2], [y_offset, bracket_top, bracket_top, y_offset], lw=1.5, c='k')
        ax.text((x1 + x2) / 2, y_offset + BRACKET_STEP * 0.15, star, ha='center', va='bottom', fontsize=12)
        y_offset += BRACKET_STEP
    # y_offset now lies one step above the highest bracket drawn in this panel.
    y_top = max(y_top, y_offset)

# A fixed limit of 26 clipped the brackets once more than ~6 pairs were significant
# (10 pairs reach y = 31.5); raise the shared axis only as far as needed.
for ax in axes:
    ax.set_ylim(0, y_top)

plt.tight_layout()
plt.savefig("regionwise_enrichment_comparison_per_barcode_Above1_Lib.png", dpi=300)
plt.show()

# Per-barcode statistics behind the figure. The suffix matches the figure name; the
# un-suffixed file name was rewritten by the following cells.
per_barcode_csv = "per_barcode_enrichment_stats_Above1_Lib.csv"
plot_df.to_csv(per_barcode_csv, index=False)
print(f"Saved per-barcode enrichment stats to '{per_barcode_csv}'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for row in results_df.itertuples(index=False):
    star = pval_to_stars(row.p_value)
    print(f"Region: {row.Region}, {row.Group1} vs {row.Group2}, "
          f"p = {row.p_value:.4g} {star}, U = {row.U_stat:.2f}, "
          f"n1 = {row.n1}, n2 = {row.n2}")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu
import numpy as np

# Region-wise comparison of per-barcode mean enrichment using ALL enrichment ratios, plus a
# synthetic "Un-enrich. Libr" baseline sample whose ratio is 1 at every region position.
# df_total (built in earlier cells) is only read here, never modified.

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# --- Define regions ---
# Inclusive Spike_AS_Position ranges; boundary positions (450, 470, 490) fall in two regions.
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-490": (470, 490),
    "490-510": (490, 510),
}

# --- Add new sample "Un-enrich. Libr" with enrichment=1 at all region positions ---
# One pseudo-barcode per position, so each region gets one baseline value per residue.
positions = sorted({pos for start, end in regions.values() for pos in range(start, end + 1)})

new_sample_df = pd.DataFrame({
    'barcode': [f"UnEnrichLib_{pos}" for pos in positions],
    'immunization': 'Un-enrich. Libr',
    'Spike_AS_Position': positions,
    'Enrichment_Ratio': 1.0,
})

# Combine into a new frame instead of reassigning df_total. Re-running the cell no longer
# appends duplicate baseline rows, and later cells do not pick up the synthetic sample.
df_with_baseline = pd.concat([df_total, new_sample_df], ignore_index=True)
print(f"Added {len(new_sample_df)} new rows for 'Un-enrich. Libr' sample with Enrichment_Ratio=1.")

# Groups come from the frame analysed below, so the baseline sample is included. This previously
# read a df_total2 that this cell never defines (its filter line was commented out), i.e. a
# stale frame left over from another cell.
immunizations = df_with_baseline['immunization'].dropna().unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl',
    'Un-enrich. Libr': 'Un-enrich.\nLibrary'
}

PLOT_COLUMNS = ['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode']
RESULT_COLUMNS = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
                  'U_stat', 'p_value', 'n1', 'n2']

plot_data = []  # one record per (region, immunization, barcode)
results = []    # one record per pairwise Mann-Whitney test

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values = {}  # immunization -> per-barcode mean enrichment in this region

    for imm in immunizations:
        df_imm = df_with_baseline[df_with_baseline['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[df_imm['Spike_AS_Position'].between(start, end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        barcode_stats = (
            df_region.groupby('barcode')['Enrichment_Ratio']
            .agg(['mean', 'std', 'count'])
            .reset_index()
            .rename(columns={'mean': 'Mean_Enrichment', 'std': 'Std_Enrichment', 'count': 'n'})
        )
        print(f"  {imm}: {len(barcode_stats)} unique barcodes in region")

        region_values[imm] = barcode_stats['Mean_Enrichment']
        plot_data.extend(
            barcode_stats
            .rename(columns={'Mean_Enrichment': 'Enrichment', 'Std_Enrichment': 'Std', 'n': 'n_sites'})
            .assign(Region=region_name, Immunization=imm)
            .loc[:, PLOT_COLUMNS]
            .to_dict('records')
        )

    # Every unordered pair of groups is tested; no multiple-testing correction is applied.
    for imm1, imm2 in combinations(region_values, 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if vals1.empty or vals2.empty:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Explicit columns keep the frames valid even if no test/row was produced.
results_df = pd.DataFrame(results, columns=RESULT_COLUMNS).sort_values(by='p_value')
# Suffix matches this cell's figure name; the un-suffixed name was shared with other cells.
results_csv = "pairwise_enrichment_stats_ALL_Lib.csv"
results_df.to_csv(results_csv, index=False)
print(f"\nSaved statistical test results to '{results_csv}'")

plot_df = pd.DataFrame(plot_data, columns=PLOT_COLUMNS)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Translate a p-value into a significance marker for plot annotations.

    Returns '***' for p < 0.001, '**' for p < 0.01, '*' for p < 0.05 and an
    empty string otherwise (including NaN, since every comparison with NaN is False).
    """
    for cutoff, marker in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < cutoff:
            return marker
    return ''

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Plotting order, including the synthetic 'Un-enrich. Libr' baseline; groups missing from
# this list are neither drawn nor bracketed.
immunizations_sorted = ['Un-enrich. Libr', 'Library_ctrl', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

# Significance brackets are stacked upwards from BRACKET_BASE in steps of BRACKET_STEP.
BRACKET_BASE = 12
BRACKET_STEP = BRACKET_BASE * 0.15
Y_LIMIT = 27  # default upper limit of the shared y-axis
y_top = Y_LIMIT

for ax, region_name in zip(axes, regions):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)

    ax.set_title(region_name)
    ax.set_ylabel('Mean Enrichment per Barcode')
    ax.set_xlabel('')
    ax.set_xticklabels([title_map.get(label, label) for label in immunizations_sorted], rotation=0)

    # One bracket per significant pair, in results_df order (sorted by p-value).
    region_results = results_df[results_df['Region'] == region_name]
    y_offset = BRACKET_BASE + BRACKET_STEP
    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        bracket_top = y_offset + BRACKET_STEP * 0.1
        ax.plot([x1, x1, x2, x2], [y_offset, bracket_top, bracket_top, y_offset], lw=1.5, c='k')
        ax.text((x1 + x2) / 2, y_offset + BRACKET_STEP * 0.15, star, ha='center', va='bottom', fontsize=12)
        y_offset += BRACKET_STEP
    # y_offset now lies one step above the highest bracket drawn in this panel.
    y_top = max(y_top, y_offset)

# A fixed limit of 27 clipped the brackets once more than ~7 pairs were significant
# (10 pairs reach y = 31.8); raise the shared axis only as far as needed.
for ax in axes:
    ax.set_ylim(0, y_top)

plt.tight_layout()
plt.savefig("regionwise_enrichment_comparison_per_barcode_ALL_Lib.png", dpi=300)
plt.show()

# Per-barcode statistics behind the figure. The suffix matches the figure name; the
# un-suffixed file name was also written by the preceding and following cells.
per_barcode_csv = "per_barcode_enrichment_stats_ALL_Lib.csv"
plot_df.to_csv(per_barcode_csv, index=False)
print(f"Saved per-barcode enrichment stats to '{per_barcode_csv}'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for row in results_df.itertuples(index=False):
    star = pval_to_stars(row.p_value)
    print(f"Region: {row.Region}, {row.Group1} vs {row.Group2}, "
          f"p = {row.p_value:.4g} {star}, U = {row.U_stat:.2f}, "
          f"n1 = {row.n1}, n2 = {row.n2}")


In [None]:
# Mann-Whitney U p-values across regions, using enrichment ratios <= 1

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu

# Region-wise comparison of per-barcode mean enrichment restricted to Enrichment_Ratio <= 1,
# with pairwise two-sided Mann-Whitney U tests between immunization groups.
# df_total (built in earlier cells) is only read here.

print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# Inclusive Spike_AS_Position ranges; boundary positions (450, 470) fall in two regions.
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-510": (470, 510),
}

# Ratios of exactly 1 are kept (<=, not <).
df_total2 = df_total[df_total['Enrichment_Ratio'] <= 1]
immunizations = df_total2['immunization'].dropna().unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl'
}

PLOT_COLUMNS = ['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode']
RESULT_COLUMNS = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
                  'U_stat', 'p_value', 'n1', 'n2']

plot_data = []  # one record per (region, immunization, barcode)
results = []    # one record per pairwise Mann-Whitney test

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values = {}  # immunization -> per-barcode mean enrichment in this region

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[df_imm['Spike_AS_Position'].between(start, end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        barcode_stats = (
            df_region.groupby('barcode')['Enrichment_Ratio']
            .agg(['mean', 'std', 'count'])
            .reset_index()
            .rename(columns={'mean': 'Mean_Enrichment', 'std': 'Std_Enrichment', 'count': 'n'})
        )
        print(f"  {imm}: {len(barcode_stats)} unique barcodes in region")

        region_values[imm] = barcode_stats['Mean_Enrichment']
        plot_data.extend(
            barcode_stats
            .rename(columns={'Mean_Enrichment': 'Enrichment', 'Std_Enrichment': 'Std', 'n': 'n_sites'})
            .assign(Region=region_name, Immunization=imm)
            .loc[:, PLOT_COLUMNS]
            .to_dict('records')
        )

    # Every unordered pair of groups is tested; no multiple-testing correction is applied.
    for imm1, imm2 in combinations(region_values, 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if vals1.empty or vals2.empty:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Explicit columns keep the frames valid even if no test/row was produced.
results_df = pd.DataFrame(results, columns=RESULT_COLUMNS).sort_values(by='p_value')
# Distinct name: the un-suffixed file was overwritten by the other region-wise cells.
results_csv = "pairwise_enrichment_stats_Below1.csv"
results_df.to_csv(results_csv, index=False)
print(f"\nSaved statistical test results to '{results_csv}'")

plot_df = pd.DataFrame(plot_data, columns=PLOT_COLUMNS)

print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Translate a p-value into a significance marker for plot annotations.

    Returns '***' for p < 0.001, '**' for p < 0.01, '*' for p < 0.05 and an
    empty string otherwise (including NaN, since every comparison with NaN is False).
    """
    for cutoff, marker in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < cutoff:
            return marker
    return ''

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Explicit plotting order of the groups (wildtype_RBD is included); groups missing from
# this list are neither drawn nor bracketed.
immunizations_sorted = ['Library_ctrl', 'Neutralizing_Ab', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

# Significance brackets are stacked upwards from BRACKET_BASE in steps of BRACKET_STEP.
BRACKET_BASE = 1.1
BRACKET_STEP = BRACKET_BASE * 0.1
Y_LIMIT = 2  # default upper limit of the shared y-axis
y_top = Y_LIMIT

for ax, region_name in zip(axes, regions):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)

    ax.set_title(region_name)
    ax.set_ylabel('Mean Enrichment per Barcode')
    ax.set_xlabel('')
    ax.set_xticklabels([title_map.get(label, label) for label in immunizations_sorted], rotation=0)

    # One bracket per significant pair, in results_df order (sorted by p-value).
    region_results = results_df[results_df['Region'] == region_name]
    y_offset = BRACKET_BASE + BRACKET_STEP
    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        bracket_top = y_offset + BRACKET_STEP * 0.1
        ax.plot([x1, x1, x2, x2], [y_offset, bracket_top, bracket_top, y_offset], lw=1.5, c='k')
        ax.text((x1 + x2) / 2, y_offset + BRACKET_STEP * 0.15, star, ha='center', va='bottom', fontsize=12)
        y_offset += BRACKET_STEP
    # y_offset now lies one step above the highest bracket drawn in this panel.
    y_top = max(y_top, y_offset)

# A fixed limit of 2 clipped the brackets once more than ~7 pairs were significant;
# raise the shared axis only as far as needed.
for ax in axes:
    ax.set_ylim(0, y_top)

plt.tight_layout()
# NOTE(review): the name says "AllRatios" but this cell keeps only Enrichment_Ratio <= 1 —
# confirm the intended file name.
plt.savefig("regionwise_enrichment_comparison_per_barcodeAllRatios.png", dpi=300)
plt.show()


# plot_df from the statistics step is reused as-is (rebuilding it from plot_data yields the
# same frame). The suffix keeps this file from being overwritten by other region-wise cells.
per_barcode_csv = "per_barcode_enrichment_stats_Below1.csv"
plot_df.to_csv(per_barcode_csv, index=False)
print(f"Saved per-barcode enrichment stats to '{per_barcode_csv}'")

print("\nPairwise Mann-Whitney U Test Results (sorted by p-value):")
for row in results_df.itertuples(index=False):
    star = pval_to_stars(row.p_value)
    print(f"Region: {row.Region}, {row.Group1} vs {row.Group2}, "
          f"p = {row.p_value:.4g} {star}, U = {row.U_stat:.2f}, "
          f"n1 = {row.n1}, n2 = {row.n2}")



## Mann-Whitney U p-values across regions for enrichment ratios > 1 and < 1 (antibody escape)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu

# Antibody-escape view: per-barcode mean enrichment per region restricted to
# 0 <= Enrichment_Ratio <= 1, with pairwise two-sided Mann-Whitney U tests.
# df_total (built in earlier cells) is only read here.
print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# Inclusive Spike_AS_Position ranges; boundary positions (450, 470) fall in two regions.
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-510": (470, 510),
}

immunizations = df_total['immunization'].dropna().unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl'
}

PLOT_COLUMNS = ['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode']
RESULT_COLUMNS = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
                  'U_stat', 'p_value', 'n1', 'n2']

plot_data = []  # one record per (region, immunization, barcode)
results = []    # one record per pairwise Mann-Whitney test

# Everything below works on df_total2, the 0 <= ratio <= 1 subset (both ends inclusive).
df_total2 = df_total[df_total['Enrichment_Ratio'].between(0, 1)]

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values = {}  # immunization -> per-barcode mean enrichment in this region

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[df_imm['Spike_AS_Position'].between(start, end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        barcode_stats = (
            df_region.groupby('barcode')['Enrichment_Ratio']
            .agg(['mean', 'std', 'count'])
            .reset_index()
            .rename(columns={'mean': 'Mean_Enrichment', 'std': 'Std_Enrichment', 'count': 'n'})
        )
        print(f"  {imm}: {len(barcode_stats)} unique barcodes in region")

        region_values[imm] = barcode_stats['Mean_Enrichment']
        plot_data.extend(
            barcode_stats
            .rename(columns={'Mean_Enrichment': 'Enrichment', 'Std_Enrichment': 'Std', 'n': 'n_sites'})
            .assign(Region=region_name, Immunization=imm)
            .loc[:, PLOT_COLUMNS]
            .to_dict('records')
        )

    # Every unordered pair of groups is tested; no multiple-testing correction is applied.
    for imm1, imm2 in combinations(region_values, 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if vals1.empty or vals2.empty:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Explicit columns keep the frames valid even if no test/row was produced.
results_df = pd.DataFrame(results, columns=RESULT_COLUMNS).sort_values(by='p_value')
# Distinct name: the un-suffixed file was overwritten by the other region-wise cells.
results_csv = "pairwise_enrichment_stats_Escape.csv"
results_df.to_csv(results_csv, index=False)
print(f"\nSaved statistical test results to '{results_csv}'")

plot_df = pd.DataFrame(plot_data, columns=PLOT_COLUMNS)

print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Translate a p-value into a significance marker for plot annotations.

    Returns '***' for p < 0.001, '**' for p < 0.01, '*' for p < 0.05 and an
    empty string otherwise (including NaN, since every comparison with NaN is False).
    """
    for cutoff, marker in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < cutoff:
            return marker
    return ''

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Explicit plotting order of the groups (wildtype_RBD is included); groups missing from
# this list are neither drawn nor bracketed.
immunizations_sorted = ['Library_ctrl', 'Neutralizing_Ab', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

for ax, region_name in zip(axes, regions):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)

    ax.set_title(region_name)
    ax.set_ylabel('Mean Enrichment per Barcode')
    ax.set_xlabel('')
    ax.set_xticklabels([title_map.get(label, label) for label in immunizations_sorted], rotation=0)

    # Brackets for significant pairs start one step above this panel's highest point and
    # stack upwards in results_df order (sorted by p-value).
    region_results = results_df[results_df['Region'] == region_name]
    ymax = sub_df['Enrichment'].max()
    ystep = (ymax * 0.1) if ymax > 0 else 1
    y_offset = ymax + ystep
    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        bracket_top = y_offset + ystep * 0.1
        ax.plot([x1, x1, x2, x2], [y_offset, bracket_top, bracket_top, y_offset], lw=1.5, c='k')
        ax.text((x1 + x2) / 2, y_offset + ystep * 0.15, star, ha='center', va='bottom', fontsize=12)
        y_offset += ystep

plt.tight_layout()
# Distinct name: the un-suffixed file was overwritten by the following (>= 1) cell.
plt.savefig("regionwise_enrichment_comparison_per_barcode_Escape.png", dpi=300)
plt.show()


# plot_df from the statistics step is reused as-is (rebuilding it from plot_data yields
# the same frame).
escape_csv = "per_barcode_Escape_stats.csv"
plot_df.to_csv(escape_csv, index=False)
# The message previously named 'per_barcode_enrichment_stats.csv', which this cell never writes.
print(f"Saved per-barcode enrichment stats to '{escape_csv}'")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu

# NOTE(review): this cell repeats the previous cell line for line (same
# 0 <= Enrichment_Ratio <= 1 filter, same outputs) and looks like an accidental duplicate
# that can probably be deleted.
# Antibody-escape view: per-barcode mean enrichment per region restricted to
# 0 <= Enrichment_Ratio <= 1, with pairwise two-sided Mann-Whitney U tests.
# df_total (built in earlier cells) is only read here.
print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# Inclusive Spike_AS_Position ranges; boundary positions (450, 470) fall in two regions.
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-510": (470, 510),
}

immunizations = df_total['immunization'].dropna().unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl'
}

PLOT_COLUMNS = ['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode']
RESULT_COLUMNS = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
                  'U_stat', 'p_value', 'n1', 'n2']

plot_data = []  # one record per (region, immunization, barcode)
results = []    # one record per pairwise Mann-Whitney test

# Everything below works on df_total2, the 0 <= ratio <= 1 subset (both ends inclusive).
df_total2 = df_total[df_total['Enrichment_Ratio'].between(0, 1)]

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values = {}  # immunization -> per-barcode mean enrichment in this region

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[df_imm['Spike_AS_Position'].between(start, end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        barcode_stats = (
            df_region.groupby('barcode')['Enrichment_Ratio']
            .agg(['mean', 'std', 'count'])
            .reset_index()
            .rename(columns={'mean': 'Mean_Enrichment', 'std': 'Std_Enrichment', 'count': 'n'})
        )
        print(f"  {imm}: {len(barcode_stats)} unique barcodes in region")

        region_values[imm] = barcode_stats['Mean_Enrichment']
        plot_data.extend(
            barcode_stats
            .rename(columns={'Mean_Enrichment': 'Enrichment', 'Std_Enrichment': 'Std', 'n': 'n_sites'})
            .assign(Region=region_name, Immunization=imm)
            .loc[:, PLOT_COLUMNS]
            .to_dict('records')
        )

    # Every unordered pair of groups is tested; no multiple-testing correction is applied.
    for imm1, imm2 in combinations(region_values, 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if vals1.empty or vals2.empty:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Explicit columns keep the frames valid even if no test/row was produced.
results_df = pd.DataFrame(results, columns=RESULT_COLUMNS).sort_values(by='p_value')
# Distinct name: the un-suffixed file was overwritten by the other region-wise cells.
results_csv = "pairwise_enrichment_stats_Escape.csv"
results_df.to_csv(results_csv, index=False)
print(f"\nSaved statistical test results to '{results_csv}'")

plot_df = pd.DataFrame(plot_data, columns=PLOT_COLUMNS)

print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Translate a p-value into a significance marker for plot annotations.

    Returns '***' for p < 0.001, '**' for p < 0.01, '*' for p < 0.05 and an
    empty string otherwise (including NaN, since every comparison with NaN is False).
    """
    for cutoff, marker in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < cutoff:
            return marker
    return ''

sns.set(style="whitegrid")
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True)

# Explicit plotting order of the groups (wildtype_RBD is included); groups missing from
# this list are neither drawn nor bracketed.
immunizations_sorted = ['Library_ctrl', 'Neutralizing_Ab', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

for ax, region_name in zip(axes, regions):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)

    ax.set_title(region_name)
    ax.set_ylabel('Mean Enrichment per Barcode')
    ax.set_xlabel('')
    ax.set_xticklabels([title_map.get(label, label) for label in immunizations_sorted], rotation=0)

    # Brackets for significant pairs start one step above this panel's highest point and
    # stack upwards in results_df order (sorted by p-value).
    region_results = results_df[results_df['Region'] == region_name]
    ymax = sub_df['Enrichment'].max()
    ystep = (ymax * 0.1) if ymax > 0 else 1
    y_offset = ymax + ystep
    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        bracket_top = y_offset + ystep * 0.1
        ax.plot([x1, x1, x2, x2], [y_offset, bracket_top, bracket_top, y_offset], lw=1.5, c='k')
        ax.text((x1 + x2) / 2, y_offset + ystep * 0.15, star, ha='center', va='bottom', fontsize=12)
        y_offset += ystep

plt.tight_layout()
# Distinct name: the un-suffixed file was overwritten by the following (>= 1) cell.
plt.savefig("regionwise_enrichment_comparison_per_barcode_Escape.png", dpi=300)
plt.show()


# plot_df from the statistics step is reused as-is (rebuilding it from plot_data yields
# the same frame).
escape_csv = "per_barcode_Escape_stats.csv"
plot_df.to_csv(escape_csv, index=False)
# The message previously named 'per_barcode_enrichment_stats.csv', which this cell never writes.
print(f"Saved per-barcode enrichment stats to '{escape_csv}'")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu

# Per-barcode mean enrichment per region restricted to Enrichment_Ratio >= 1, with pairwise
# two-sided Mann-Whitney U tests between immunization groups.
# df_total (built in earlier cells) is only read here.
print("Initial df_total columns:", df_total.columns.tolist())

assert 'barcode' in df_total.columns, "'barcode' column missing in df_total!"
assert 'immunization' in df_total.columns, "'immunization' column missing in df_total!"

# Inclusive Spike_AS_Position ranges; boundary positions (450, 470) fall in two regions.
regions = {
    "390-420": (390, 420),
    "424-450": (424, 450),
    "450-470": (450, 470),
    "470-510": (470, 510),
}

immunizations = df_total['immunization'].dropna().unique()
print("Immunizations found in data:", immunizations)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\nCtrl'
}

PLOT_COLUMNS = ['Region', 'Immunization', 'Enrichment', 'Std', 'n_sites', 'barcode']
RESULT_COLUMNS = ['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
                  'U_stat', 'p_value', 'n1', 'n2']

plot_data = []  # one record per (region, immunization, barcode)
results = []    # one record per pairwise Mann-Whitney test

# Everything below works on df_total2, the ratio >= 1 subset (ratios of exactly 1 included).
df_total2 = df_total[df_total['Enrichment_Ratio'] >= 1]

for region_name, (start, end) in regions.items():
    print(f"\nProcessing region {region_name} from {start} to {end}")

    region_values = {}  # immunization -> per-barcode mean enrichment in this region

    for imm in immunizations:
        df_imm = df_total2[df_total2['immunization'] == imm]
        print(f"  {imm}: {len(df_imm)} rows before region filtering")

        df_region = df_imm[df_imm['Spike_AS_Position'].between(start, end)]
        print(f"  {imm}: {len(df_region)} rows after region filtering")

        barcode_stats = (
            df_region.groupby('barcode')['Enrichment_Ratio']
            .agg(['mean', 'std', 'count'])
            .reset_index()
            .rename(columns={'mean': 'Mean_Enrichment', 'std': 'Std_Enrichment', 'count': 'n'})
        )
        print(f"  {imm}: {len(barcode_stats)} unique barcodes in region")

        region_values[imm] = barcode_stats['Mean_Enrichment']
        plot_data.extend(
            barcode_stats
            .rename(columns={'Mean_Enrichment': 'Enrichment', 'Std_Enrichment': 'Std', 'n': 'n_sites'})
            .assign(Region=region_name, Immunization=imm)
            .loc[:, PLOT_COLUMNS]
            .to_dict('records')
        )

    # Every unordered pair of groups is tested; no multiple-testing correction is applied.
    for imm1, imm2 in combinations(region_values, 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()

        print(f"  Mann-Whitney test between {imm1} (n={len(vals1)}) and {imm2} (n={len(vals2)})")

        if vals1.empty or vals2.empty:
            print(f"   Skipping Mann-Whitney test for {imm1} vs {imm2} due to zero data size.")
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')

        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

# Explicit columns keep the frames valid even if no test/row was produced.
results_df = pd.DataFrame(results, columns=RESULT_COLUMNS).sort_values(by='p_value')
# Distinct name: the un-suffixed file was overwritten by the other region-wise cells.
results_csv = "pairwise_enrichment_stats_Above1.csv"
results_df.to_csv(results_csv, index=False)
print(f"\nSaved statistical test results to '{results_csv}'")

plot_df = pd.DataFrame(plot_data, columns=PLOT_COLUMNS)

print(f"Prepared plot DataFrame with {len(plot_df)} rows")

def pval_to_stars(p):
    """Translate a p-value into a conventional significance marker.

    Returns '***' below 0.001, '**' below 0.01, '*' below 0.05 and an empty
    string otherwise (NaN fails every comparison and therefore also yields '').
    """
    for cutoff, marker in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if p < cutoff:
            return marker
    return ''

# One panel per region: per-barcode mean enrichment for each immunization group
# (box + swarm), with brackets and stars for pairwise Mann-Whitney tests that
# reach p < 0.05.
sns.set(style="whitegrid")
# squeeze=False always yields a 2-D array of Axes, so iterating over axes[0]
# also works when `regions` holds a single entry (plain plt.subplots would
# return a bare Axes there, which zip() cannot iterate).
fig, axes = plt.subplots(1, len(regions), figsize=(16, 5), sharey=True, squeeze=False)

# Explicit x-axis order covering all five immunization groups.
immunizations_sorted = ['Library_ctrl', 'Neutralizing_Ab', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']

for ax, region_name in zip(axes[0], regions):
    sub_df = plot_df[plot_df['Region'] == region_name]

    sns.boxplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, showfliers=False, order=immunizations_sorted)
    sns.swarmplot(data=sub_df, x='Immunization', y='Enrichment', ax=ax, color=".25", size=4, order=immunizations_sorted)

    ax.set_title(region_name)
    ax.set_ylabel('Mean Enrichment per Barcode')
    ax.set_xlabel('')

    # title_map is defined elsewhere in the notebook; keys it does not map are shown as-is.
    new_labels = [title_map.get(label, label) for label in immunizations_sorted]
    ax.set_xticklabels(new_labels, rotation=0)

    ymax = sub_df['Enrichment'].max()
    if pd.isna(ymax):
        # No points in this region, so there is nothing to anchor brackets to.
        continue

    # Stack significance brackets upward, starting just above the tallest point.
    ystep = (ymax * 0.1) if ymax > 0 else 1
    y_offset = ymax + ystep
    region_results = results_df[results_df['Region'] == region_name]
    for _, row in region_results.iterrows():
        star = pval_to_stars(row['p_value'])
        if not star:
            continue
        if row['Group1'] not in immunizations_sorted or row['Group2'] not in immunizations_sorted:
            continue
        x1 = immunizations_sorted.index(row['Group1'])
        x2 = immunizations_sorted.index(row['Group2'])
        x_center = (x1 + x2) / 2
        ax.plot([x1, x1, x2, x2],
                [y_offset, y_offset + ystep * 0.1, y_offset + ystep * 0.1, y_offset],
                lw=1.5, c='k')
        ax.text(x_center, y_offset + ystep * 0.15, star, ha='center', va='bottom', fontsize=12)
        y_offset += ystep

plt.tight_layout()
plt.savefig("regionwise_enrichment_comparison_per_barcode.png", dpi=300)
plt.show()


# Persist the per-barcode statistics (one row per region x immunization x barcode).
plot_df = pd.DataFrame(plot_data)
print(f"Prepared plot DataFrame with {len(plot_df)} rows")

# Use one variable for the path so the log message always names the file that
# was actually written (it previously reported a different file name).
per_barcode_csv = "per_barcode_Escape_stats.csv"
plot_df.to_csv(per_barcode_csv, index=False)
print(f"Saved per-barcode enrichment stats to '{per_barcode_csv}'")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu

# Block 1 -- untransformed per-barcode enrichment.
# For each region: average Enrichment_Ratio over the region's positions per
# barcode, keep barcodes whose mean lies in (0, 1], and compare immunization
# groups pairwise with two-sided Mann-Whitney U tests. No log transform here.

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    region_values = {}

    for imm in immunizations:
        # Rows of this immunization with Spike position in [start, end] (inclusive).
        df_region = df_total[
            (df_total['immunization'] == imm)
            & df_total['Spike_AS_Position'].between(start, end)
        ]

        barcode_stats = (
            df_region.groupby('barcode')['Enrichment_Ratio']
            .agg(Mean_Enrichment='mean', Std_Enrichment='std', n='count')
            .reset_index()
        )

        # Keep barcodes whose mean enrichment is in (0, 1].
        in_range = barcode_stats['Mean_Enrichment'].gt(0) & barcode_stats['Mean_Enrichment'].le(1)
        barcode_stats = barcode_stats[in_range]

        region_values[imm] = barcode_stats['Mean_Enrichment']

        plot_data.extend(
            {
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row.Mean_Enrichment,
                'Std': row.Std_Enrichment,
                'n_sites': row.n,
                'barcode': row.barcode,
            }
            for row in barcode_stats.itertuples(index=False)
        )

    # Two-sided Mann-Whitney U test for every pair of immunization groups.
    for imm1, imm2 in combinations(region_values.keys(), 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()
        if len(vals1) == 0 or len(vals2) == 0:
            continue
        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')
        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

plot_df = pd.DataFrame(plot_data)
# Explicit columns: sort_values would raise KeyError on an empty, column-less frame.
results_df = pd.DataFrame(
    results,
    columns=['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
             'U_stat', 'p_value', 'n1', 'n2'],
).sort_values(by='p_value')

plot_df.to_csv("block1_filtered_no_transform_plot_data.csv", index=False)
results_df.to_csv("block1_filtered_no_transform_stats.csv", index=False)

# Violin plot of the per-barcode means, split by region and immunization.
fig, ax = plt.subplots(figsize=(12, 8))
sns.violinplot(data=plot_df, x='Region', y='Enrichment', hue='Immunization',
               cut=0, inner='quartile', ax=ax)
ax.set_title('Enrichment Ratios (0 < Enrichment ≤ 1) by Region and Immunization (No Transformation)')
ax.set_ylabel('Mean Enrichment Ratio')
ax.set_xlabel('Region')
ax.legend(title='Immunization')
fig.tight_layout()
plt.show()


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import combinations
from scipy.stats import mannwhitneyu

def enrichment_height(enrichment):
    """Return the log2 of an enrichment ratio (0 -> -inf, NaN/negative -> NaN).

    The former two-branch form (log2(e) for e >= 1, -log2(1/e) otherwise) is
    algebraically identical to log2(e) for every positive e, so it is computed
    directly. Going straight through NumPy also means:
      * 0 maps to -inf instead of raising ZeroDivisionError when a plain Python
        float reached ``1 / enrichment``;
      * array-like input is handled element-wise (the ``>=`` branch could not
        be evaluated on arrays).
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.log2(enrichment)

# Block 2 -- log2-transformed per-barcode enrichment.
# Per region: mean Enrichment_Ratio per barcode, transformed with
# enrichment_height, compared across immunizations with two-sided Mann-Whitney
# U tests, then drawn as grouped box + swarm plots with significance brackets.

plot_data = []
results = []

for region_name, (start, end) in regions.items():
    region_values = {}

    for imm in immunizations:
        # Rows of this immunization with Spike position in [start, end] (inclusive).
        df_region = df_total[
            (df_total['immunization'] == imm)
            & df_total['Spike_AS_Position'].between(start, end)
        ]
        barcode_stats = (
            df_region.groupby('barcode')['Enrichment_Ratio']
            .agg(Mean_Enrichment='mean', Std_Enrichment='std', n='count')
            .reset_index()
        )
        barcode_stats['Transformed'] = barcode_stats['Mean_Enrichment'].apply(enrichment_height)

        region_values[imm] = barcode_stats['Transformed']

        plot_data.extend(
            {
                'Region': region_name,
                'Immunization': imm,
                'Enrichment': row.Transformed,
                'Std': row.Std_Enrichment,
                'n_sites': row.n,
                'barcode': row.barcode,
            }
            for row in barcode_stats.itertuples(index=False)
        )

    # Two-sided Mann-Whitney U test for every pair of immunization groups.
    for imm1, imm2 in combinations(region_values.keys(), 2):
        vals1 = region_values[imm1].dropna()
        vals2 = region_values[imm2].dropna()

        if len(vals1) == 0 or len(vals2) == 0:
            continue

        stat, pval = mannwhitneyu(vals1, vals2, alternative='two-sided')
        results.append({
            'Region': region_name,
            'Group1': imm1,
            'Group2': imm2,
            'Mean1': vals1.mean(),
            'Mean2': vals2.mean(),
            'Median1': vals1.median(),
            'Median2': vals2.median(),
            'U_stat': stat,
            'p_value': pval,
            'n1': len(vals1),
            'n2': len(vals2)
        })

plot_df = pd.DataFrame(plot_data)
# Explicit columns: sort_values would raise KeyError on an empty, column-less frame.
results_df = pd.DataFrame(
    results,
    columns=['Region', 'Group1', 'Group2', 'Mean1', 'Mean2', 'Median1', 'Median2',
             'U_stat', 'p_value', 'n1', 'n2'],
).sort_values(by='p_value')
plot_df.to_csv("block2_transformed_all.csv", index=False)
results_df.to_csv("block2_stats_all.csv", index=False)

immunizations_sorted = ['Library_ctrl', 'Neutralizing_Ab', 'Polyclonal_Ab', 'wildtype_RBD', 'Mutant_RBD']
regions_unique = sorted(plot_df['Region'].unique())
n_hues = len(immunizations_sorted)

sns.set(style="whitegrid")
plt.figure(figsize=(12, 8))

ax = sns.boxplot(data=plot_df, x='Region', y='Enrichment', hue='Immunization',
                 order=regions_unique,
                 hue_order=immunizations_sorted,
                 showfliers=False)
sns.swarmplot(data=plot_df, x='Region', y='Enrichment', hue='Immunization',
              order=regions_unique,
              hue_order=immunizations_sorted,
              dodge=True, color=".25", ax=ax)

ax.set_title('Transformed Enrichment by Region and Immunization')
ax.set_ylabel('Transformed Enrichment (log2 scale)')
ax.set_xlabel('Region')

# The box and swarm layers can both contribute hue entries; keep the first n_hues.
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[:n_hues], labels[:n_hues], title='Immunization', loc='best')

# x-centre of each (region, immunization) box. Seaborn dodges hue levels inside
# a category across the default box width of 0.8: each level gets a slot of
# 0.8 / n_hues and sits at that slot's midpoint (for 5 levels: -0.32 ... +0.32).
# The previous formula spread centres over only +/-0.2, so brackets missed boxes.
DODGE_WIDTH = 0.8
slot_width = DODGE_WIDTH / n_hues
positions = {
    (region, imm): i - DODGE_WIDTH / 2 + (j + 0.5) * slot_width
    for i, region in enumerate(regions_unique)
    for j, imm in enumerate(immunizations_sorted)
}

# Bracket geometry in data units (log2 enrichment); tune if labels collide.
BRACKET_GAP = 0.3     # distance from the pair's highest point to its bracket
BRACKET_HEIGHT = 0.1  # height of the bracket legs
BRACKET_STEP = 0.4    # extra rise for each further bracket in the same region

brackets_drawn = {}  # region -> number of brackets already drawn in it
for row in results_df.itertuples(index=False):
    key1 = (row.Region, row.Group1)
    key2 = (row.Region, row.Group2)
    if key1 not in positions or key2 not in positions:
        continue  # group not part of the plotted hue order

    x1, x2 = positions[key1], positions[key2]
    y_max = plot_df.loc[
        (plot_df['Region'] == row.Region) & plot_df['Immunization'].isin([row.Group1, row.Group2]),
        'Enrichment',
    ].max()

    # Raise each further bracket in the region so brackets do not overlap.
    level = brackets_drawn.get(row.Region, 0)
    brackets_drawn[row.Region] = level + 1
    y = y_max + BRACKET_GAP + level * BRACKET_STEP
    h = BRACKET_HEIGHT

    ax.plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c='k')
    if row.p_value < 0.001:
        p_text = '***'
    elif row.p_value < 0.01:
        p_text = '**'
    elif row.p_value < 0.05:
        p_text = '*'
    else:
        p_text = f"ns\n({row.p_value:.2f})"

    ax.text((x1 + x2) * .5, y + h, p_text, ha='center', va='bottom', color='k', fontsize=12)

plt.tight_layout()
plt.show()


In [None]:
#plots

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

pd.set_option('future.no_silent_downcasting', True)

# Smoothing / threshold parameters for this figure.
ROLLING_WINDOW = 10  # rolling-mean window, counted in rows of observed positions
ENRICHMENT_THRESHOLD = 1  # compared against the smoothed log10 enrichment (High_Enrichment flag)

# Keep Spike positions beyond 33 + 331 = 364 and only enriched variants
# (Enrichment_Ratio > 1). No mutation-type filter is applied in this cell.
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 1)]

# Summed enrichment per (position, barcode, immunization).
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Sites marked under the curves, stored as strings because they are matched
# against Spike_AS_Position cast to str.
sites_to_show = [str(site) for site in (
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]  # RBD-ACE2 interface
    + list(range(394, 414))  # R21 peptide sequence
    + list(range(484, 505))  # R13 peptide sequence
)]

# Cast positions to str before matching; comparing the numeric column directly
# with the string list never matched.
df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Per-immunization CSVs go here.
# NOTE(review): the next two plotting cells write the same file names.
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

fig, ax = plt.subplots(figsize=(10, 6))

color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Union of marked positions across immunizations; bars are drawn once after the loop.
highlighted_positions = set()

# One smoothed log10 curve per immunization on the shared axes ('Library_ctrl' excluded).
for immunization in df_filtered_agg['immunization'].unique():
    if immunization == 'Library_ctrl':
        continue

    print(immunization)
    df_imm = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Sum over barcodes at each position, normalised by this immunization's barcode count.
    num_barcodes = df_imm['barcode'].nunique()
    df_filtered_im = df_imm.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'sum',
    })
    df_filtered_im['Enrichment_Ratio'] /= num_barcodes
    # Fill gaps from neighbouring rows (no-op when nothing is missing).
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # Zeros become NaN so log10 does not produce -inf.
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].replace(0, np.nan)
    df_filtered_im['Log_Enrichment'] = np.log10(df_filtered_im['Enrichment_Ratio'])

    # Centered rolling mean, then linear interpolation of any remaining gaps
    # (limit_direction='both' also fills the leading/trailing rows).
    df_filtered_im['Smoothed_Enrichment'] = (
        df_filtered_im['Log_Enrichment']
        .rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()
        .interpolate(method='linear', limit_direction='both')
    )

    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD

    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # One row per residue between the first and last observed position, so the
    # line breaks (NaN) where a position has no data.
    first_pos = int(df_filtered_im['Spike_AS_Position'].min())
    last_pos = int(df_filtered_im['Spike_AS_Position'].max())
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(first_pos, last_pos + 1)
    ).reset_index()

    # Rows added by the reindex have no flag yet; treat them as not enriched.
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )
    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))
    highlighted_positions.update(df_filtered_im.loc[df_filtered_im['show_site'], 'Spike_AS_Position'])

    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding",
        show_col="show_site",
        ax=ax,
        linewidth=2,
        color=color_map.get(immunization, 'black')
    )
    # Invisible proxy artist so the immunization appears in the legend.
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

# Y range with ticks every 0.1. A later MaxNLocator(integer=True) call used to
# replace these ticks, which left only the tick "1" inside this range.
ax.set_ylim(0.4, 1.35)
ax.set_yticks(np.arange(0.4, 1.35, 0.1))

# Mark the sites of interest once, on the bottom edge of the visible range.
# They were previously drawn at y=0, below the 0.4 limit, and so were never seen.
if highlighted_positions:
    marker_x = np.array(sorted(highlighted_positions))
    ax.hlines(y=ax.get_ylim()[0], xmin=marker_x - 0.5, xmax=marker_x + 0.5,
              color='black', linestyle='-', linewidth=10)

ax.set_title('Smoothed Enrichment Ratios for Different Immunizations', fontsize=14)
ax.set_xlabel('Spike AA Position', fontsize=16)
ax.set_ylabel('Antibody Repertoire - Enrichment \n (Median per-cell Polyreactivity)', fontsize=16)

# Legend in a fixed group order (antibodies, then RBDs), skipping groups that
# did not produce a curve instead of raising ValueError.
handles, labels = ax.get_legend_handles_labels()
handle_by_label = {}
for handle, label in zip(handles, labels):
    handle_by_label.setdefault(label, handle)
legend_order = ['Polyclonal_Ab', 'Neutralizing_Ab', 'wildtype_RBD', 'Mutant_RBD']
legend_labels = [label for label in legend_order if label in handle_by_label]
ax.legend([handle_by_label[label] for label in legend_labels], legend_labels,
          title="Single Droplet Repertoire", loc='upper right', fontsize=11, frameon=False,
          handlelength=2, handleheight=1, title_fontsize=11, markerscale=8)

# 10 evenly spaced x ticks across the position range; label every other one.
xticks = np.linspace(df_filtered_agg['Spike_AS_Position'].min(), df_filtered_agg['Spike_AS_Position'].max(), 10).astype(int)
ax.set_xticks(xticks)
ax.set_xticklabels([str(x) if i % 2 == 0 else '' for i, x in enumerate(xticks)])

# NOTE(review): the file name uses whatever `immunization` held last in the loop
# and follows the same pattern as the next cell, so the figures can overwrite
# each other. Consider a fixed, figure-specific name.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_Line2_plot.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

pd.set_option('future.no_silent_downcasting', True)

# Smoothing / threshold parameters for this figure.
ROLLING_WINDOW = 20  # rolling-mean window, counted in rows of observed positions
ENRICHMENT_THRESHOLD = 1  # compared against the smoothed median enrichment (High_Enrichment flag)

# Keep Spike positions beyond 33 + 331 = 364 and only enriched variants
# (Enrichment_Ratio > 1). No mutation-type filter is applied in this cell.
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) &
                       (df_total['Enrichment_Ratio'] > 1)]

# Summed enrichment per (position, barcode, immunization).
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Sites marked under the curves, stored as strings because they are matched
# against Spike_AS_Position cast to str.
sites_to_show = [str(site) for site in (
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]  # RBD-ACE2 interface
    + list(range(394, 414))  # R21 peptide sequence
    + list(range(484, 505))  # R13 peptide sequence
)]

# Cast positions to str before matching; comparing the numeric column directly
# with the string list never matched.
df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Per-immunization CSVs go here.
# NOTE(review): the neighbouring plotting cells write the same file names.
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

fig, ax = plt.subplots(figsize=(10, 6))

color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'darkblue',
    'Mutant_RBD': '#004c4c'
}

# Union of marked positions across immunizations; bars are drawn once after the loop.
highlighted_positions = set()

# One smoothed median curve per immunization on the shared axes ('Library_ctrl' excluded).
for immunization in df_filtered_agg['immunization'].unique():
    if immunization == 'Library_ctrl':
        continue

    print(immunization)
    df_imm = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Median across barcodes of the summed enrichment at each position.
    df_filtered_im = df_imm.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'median',
    })
    # Fill gaps from neighbouring rows (no-op when nothing is missing).
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # Centered rolling mean, then linear interpolation of any remaining gaps
    # (limit_direction='both' also fills the leading/trailing rows).
    df_filtered_im['Smoothed_Enrichment'] = (
        df_filtered_im['Enrichment_Ratio']
        .rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()
        .interpolate(method='linear', limit_direction='both')
    )

    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD

    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # One row per residue between the first and last observed position, so the
    # line breaks (NaN) where a position has no data.
    first_pos = int(df_filtered_im['Spike_AS_Position'].min())
    last_pos = int(df_filtered_im['Spike_AS_Position'].max())
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(first_pos, last_pos + 1)
    ).reset_index()

    # Rows added by the reindex have no flag yet; treat them as not enriched.
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )
    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))
    highlighted_positions.update(df_filtered_im.loc[df_filtered_im['show_site'], 'Spike_AS_Position'])

    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Smoothed_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding",
        show_col="show_site",
        ax=ax,
        linewidth=2.5,
        color=color_map.get(immunization, 'black')
    )
    # Invisible proxy artist so the immunization appears in the legend.
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

ax.set_ylim(0, 9)

# Mark the sites of interest once (previously redrawn for every immunization),
# on the bottom edge of the visible range.
if highlighted_positions:
    marker_x = np.array(sorted(highlighted_positions))
    ax.hlines(y=ax.get_ylim()[0], xmin=marker_x - 0.5, xmax=marker_x + 0.5,
              color='black', linestyle='-', linewidth=10)

ax.set_title('Smoothed Enrichment Ratios for Different Immunizations', fontsize=14)
ax.set_xlabel('Spike AA Position', fontsize=16)
ax.set_ylabel('Single-cell Polyreactivity', fontsize=16)

# Legend in a fixed group order (antibodies, then RBDs), skipping groups that
# did not produce a curve instead of raising ValueError.
handles, labels = ax.get_legend_handles_labels()
handle_by_label = {}
for handle, label in zip(handles, labels):
    handle_by_label.setdefault(label, handle)
legend_order = ['Polyclonal_Ab', 'Neutralizing_Ab', 'wildtype_RBD', 'Mutant_RBD']
legend_labels = [label for label in legend_order if label in handle_by_label]
ax.legend([handle_by_label[label] for label in legend_labels], legend_labels,
          title="Single Droplet Repertoire", loc='upper right', fontsize=11, frameon=False,
          handlelength=2, handleheight=1, title_fontsize=11, markerscale=8)

# 20 evenly spaced x ticks across the position range; label every other one.
xticks = np.linspace(df_filtered_agg['Spike_AS_Position'].min(), df_filtered_agg['Spike_AS_Position'].max(), 20).astype(int)
ax.set_xticks(xticks)
ax.set_xticklabels([str(x) if i % 2 == 0 else '' for i, x in enumerate(xticks)])

# Integer y ticks for the 0-9 range.
ax.yaxis.set_major_locator(MaxNLocator(integer=True, prune='lower', nbins=8))

# NOTE(review): the file name uses whatever `immunization` held last in the loop
# and follows the same pattern as the previous cell, so the figures can
# overwrite each other. Consider a fixed, figure-specific name.
plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop", f"{immunization}_Line2_plot.png")
fig.tight_layout()
fig.savefig(plot_file_path, format='png')

plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
%matplotlib inline

pd.set_option('future.no_silent_downcasting', True)

# Smoothing / threshold parameters for this figure.
ROLLING_WINDOW = 20  # rolling-mean window, counted in rows of observed positions
ENRICHMENT_THRESHOLD = 0.8  # compared against the smoothed (pre-log) enrichment

# Non-synonymous variants at Spike positions beyond 33 + 331 = 364.
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33 + 331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Summed enrichment per (position, amino acid, barcode, immunization).
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Sites marked under the curves, stored as strings because they are matched
# against Spike_AS_Position cast to str.
sites_to_show = [str(site) for site in (
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]  # RBD-ACE2 interface
    + list(range(394, 414))  # R21 peptide sequence
    + list(range(484, 505))  # R13 peptide sequence
)]

df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Per-immunization CSVs go here.
# NOTE(review): the neighbouring plotting cells write the same file names.
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

fig, ax = plt.subplots(figsize=(10, 6))

color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'darkblue',
    'Mutant_RBD': '#004c4c'
}

# Union of marked positions across immunizations; bars are drawn once after the loop.
highlighted_positions = set()

# One smoothed log2 curve per immunization on the shared axes ('Library_ctrl' excluded).
for immunization in df_filtered_agg['immunization'].unique():
    if immunization == 'Library_ctrl':
        continue

    print(immunization)
    df_imm = df_filtered_agg[df_filtered_agg['immunization'] == immunization]

    # Total enrichment at each position (summed over amino acids and barcodes).
    df_filtered_im = df_imm.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'sum',
    })
    # Fill gaps from neighbouring rows (no-op when nothing is missing).
    df_filtered_im['Enrichment_Ratio'] = df_filtered_im['Enrichment_Ratio'].bfill().ffill()

    # Centered rolling mean, then linear interpolation of any remaining gaps
    # (limit_direction='both' also fills the leading/trailing rows).
    df_filtered_im['Smoothed_Enrichment'] = (
        df_filtered_im['Enrichment_Ratio']
        .rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()
        .interpolate(method='linear', limit_direction='both')
    )

    # log2 after smoothing; zeros become NaN first and are then shown as 0.
    df_filtered_im['Log2_Enrichment'] = np.log2(df_filtered_im['Smoothed_Enrichment'].replace(0, np.nan))
    df_filtered_im['Log2_Enrichment'] = df_filtered_im['Log2_Enrichment'].fillna(0)

    # Flag based on the smoothed enrichment before the log transform.
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD

    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # One row per residue between the first and last observed position, so the
    # line breaks (NaN) where a position has no data.
    first_pos = int(df_filtered_im['Spike_AS_Position'].min())
    last_pos = int(df_filtered_im['Spike_AS_Position'].max())
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(first_pos, last_pos + 1)
    ).reset_index()

    # Rows added by the reindex have no flag yet; treat them as not enriched.
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    df_filtered_im = df_filtered_im.assign(
        show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
    )

    print(df_filtered_im[['Spike_AS_Position', 'show_site']].head(10))
    highlighted_positions.update(df_filtered_im.loc[df_filtered_im['show_site'], 'Spike_AS_Position'])

    dmslogo.line.draw_line(
        df_filtered_im,
        x_col="Spike_AS_Position",
        height_col="Log2_Enrichment",
        title="",
        xlabel="Spike AA Position",
        ylabel="Antibody Repertoire \n binding (log₂ Enrichment)",
        show_col="show_site",
        ax=ax,
        linewidth=2.5,
        color=color_map.get(immunization, 'black')
    )
    # Invisible proxy artist so the immunization appears in the legend.
    ax.plot([], [], color=color_map.get(immunization, 'black'), label=immunization)

ax.set_ylim(2.5, 11)

# Mark the sites of interest once, on the bottom edge of the visible range.
# They were previously drawn at y=0, below the 2.5 limit, and so were never seen.
if highlighted_positions:
    marker_x = np.array(sorted(highlighted_positions))
    ax.hlines(y=ax.get_ylim()[0], xmin=marker_x - 0.5, xmax=marker_x + 0.5,
              color='black', linestyle='-', linewidth=10)

ax.set_title('Smoothed Log₂ Enrichment Ratios for Different Immunizations', fontsize=14)
ax.set_xlabel('Spike AA Position', fontsize=16)
ax.set_ylabel('Antibody Repertoire - log₂ Enrichment', fontsize=16)

# Legend in a fixed group order (antibodies, then RBDs). Labels are filtered
# together with their handles: filtering only the handles (as before)
# misaligned the legend text whenever a group was missing.
handles, labels = ax.get_legend_handles_labels()
handle_by_label = {}
for handle, label in zip(handles, labels):
    handle_by_label.setdefault(label, handle)
legend_order = ['Polyclonal_Ab', 'Neutralizing_Ab', 'wildtype_RBD', 'Mutant_RBD']
legend_labels = [label for label in legend_order if label in handle_by_label]
ax.legend([handle_by_label[label] for label in legend_labels], legend_labels,
          title="Single Droplet Repertoire", loc='upper center', fontsize=11, frameon=False,
          handlelength=2, handleheight=1, title_fontsize=11, markerscale=8)

fig.tight_layout()
plt.show()


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Define rolling window size for smoothing
ROLLING_WINDOW = 20
ENRICHMENT_THRESHOLD = 0.8

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Aggregate enrichment ratio by position
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Define target sites
sites_to_show = map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  
    list(range(394,414)) +  
    list(range(484, 505))  
)

df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory
output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Color mapping based on immunization name
color_map = {
    'Polyclonal_Ab': '#d2691e',       # dark orange / terracotta
    'Neutralizing_Ab': '#ff0000',     # red
    'wildtype_RBD': '#0000ff',        # blue
    'mutant': '#00ced1',              # dark turquoise
}

# Prepare figure
fig, ax = plt.subplots(figsize=(12, 6))

for immunization in df_filtered_agg['immunization'].unique():
    print(immunization)
    df_filtered_im = df_filtered_agg.query('immunization == @immunization')

    # Aggregate enrichment ratios at each position
    df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
        'Enrichment_Ratio': 'sum',
    })

    # Apply rolling mean for smoothing
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()

    # Handle missing data
    df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

    # Identify clusters of high enrichment
    df_filtered_im['High_Enrichment'] = df_filtered_im['Smoothed_Enrichment'] > ENRICHMENT_THRESHOLD
    df_filtered_im['High_Enrichment'] = df_filtered_im['High_Enrichment'].fillna(False).astype(bool)

    # Save CSV
    csv_file_path = os.path.join(output_dir, f"{immunization}_data.csv")
    df_filtered_im.to_csv(csv_file_path, index=False)

    # Reindex for visualization
    df_filtered_im = df_filtered_im.set_index('Spike_AS_Position').reindex(
        range(df_filtered_im['Spike_AS_Position'].min(), df_filtered_im['Spike_AS_Position'].max() + 1)
    ).reset_index()

    # Plot line on same axis
    label = immunization.replace("_", " ")
    color = color_map.get(immunization, "gray")
    ax.plot(df_filtered_im['Spike_AS_Position'], df_filtered_im['Smoothed_Enrichment'], label=label, color=color)

# Final formatting
ax.set_xlim(420, df_filtered_im['Spike_AS_Position'].max())
ax.set_ylim(0, df_filtered_agg['Enrichment_Ratio'].max())
ax.set_xlabel("Spike AA Position")
ax.set_ylabel("Antibody Repertoire \n binding")
ax.set_title("Smoothed Enrichment Across Spike Positions by Immunization")
ax.legend(title="Immunization", fontsize=10)
plt.tight_layout()
plt.show()


In [None]:
# Quick check of which columns the aggregated frame carries
aggregated_columns = df_filtered_agg.columns
print(aggregated_columns)

Code block for generating smoothed enrichment line plots per barcode.
Each barcode (single droplet) gets its own figure, with one colored line per immunization, and each figure is also saved as a PNG.
This can be used to show how each barcode/single cell contributes to the grouped enrichments.

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

# Enable inline plotting in Jupyter
%matplotlib inline

# Define rolling window size for smoothing
ROLLING_WINDOW = 8  # Adjust for more smoothing
ENRICHMENT_THRESHOLD = 50  # Adjust based on data distribution

# Filter dataset to remove low-quality reads and select only non-synonymous mutations
df_filtered = df_total[(df_total['Spike_AS_Position'] > 33+331) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Aggregate enrichment ratio by position and barcode
df_filtered_agg = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Define target sites
sites_to_show = map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394,414)) +  # R21 peptide sequence
    list(range(484, 505))  # R13 peptide sequence
)
df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output directory
output_dir = "barcode_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Iterate over each barcode
for barcode in df_filtered_agg['barcode'].unique():
    print(f"Processing barcode: {barcode}")
    df_filtered_bc = df_filtered_agg[df_filtered_agg['barcode'] == barcode]
    
    fig, ax = plt.subplots(figsize=(6, 4))
    
    for immunization in df_filtered_bc['immunization'].unique():
        df_filtered_im = df_filtered_bc[df_filtered_bc['immunization'] == immunization]

        # Aggregate enrichment ratios at each position
        df_filtered_im = df_filtered_im.groupby('Spike_AS_Position', as_index=False).agg({
            'Enrichment_Ratio': 'sum',
        })

        # Apply rolling mean for smoothing
        df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Enrichment_Ratio'].rolling(window=ROLLING_WINDOW, center=True, min_periods=1).mean()

        # Handle missing data by interpolating missing values
        df_filtered_im['Smoothed_Enrichment'] = df_filtered_im['Smoothed_Enrichment'].interpolate(method='linear', limit_direction='both')

        # Plot each immunization as a separate line
        ax.plot(df_filtered_im['Spike_AS_Position'], df_filtered_im['Smoothed_Enrichment'], label=f'Immunization {immunization}', alpha=0.7)

    # Formatting plot
    ax.set_title(f"Barcode {barcode}")
    ax.set_xlabel("Spike AA Position")
    ax.set_ylabel("Antibody Repertoire \n Binding")

    ax.xaxis.set_minor_locator(MultipleLocator(5))
    ax.set_xlim(350, df_filtered_bc['Spike_AS_Position'].max())

    # Adjust legend to be outside the plot
    ax.legend(bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0.)

    # Save plot
    plot_file_path = os.path.join(r"/Users/lucaschlotheuber/Desktop/immunization_csv_files", f"Barcode_{barcode}_plot.png")
    fig.savefig(plot_file_path, format='png', bbox_inches='tight')
    
    # Show the plot inline in Jupyter
    plt.show()

    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt

plt.ion()
# Code block for generating logoplots of enriched positions
# Change immunization to barcode to do individual droplet analysis

# Mean enrichment ratio per position, amino acid, barcode and immunization
df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'mean'
})

# Sites to display, as a reusable set of strings (a `map` iterator is exhausted after one use).
# Alternative site sets used previously:
#   [(i+336) for i in range(107, 114) if i not in [33, 72, 81, 151]]
#   [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]  # RBD-ACE2 interface according to article
#   list(range(394,414))  # R21 peptide sequence with high affinity
#   list(range(484, 506))  # R13 peptide sequence with high affinity
sites_to_show = {str(site) for site in [417, 484, 501]}

df_logo_agg = df_logo_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)
# Upper-case amino acids and drop stop codons ('*') and gaps ('-')
df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[~df_logo_agg['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]

# Output folder (hard-coded local path; created if missing so savefig cannot fail)
save_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(save_dir, exist_ok=True)

# To restrict to one condition, filter df_logo_agg first, e.g.
# df_logo_agg[df_logo_agg['immunization'] == "wildtype_RBD"]
for barcode in df_logo_agg['barcode'].unique():
    print(barcode)
    # Boolean masks instead of a string-built query(), which breaks on barcodes containing quotes
    df_barcode_sites = df_logo_agg[(df_logo_agg['barcode'] == barcode) & df_logo_agg['show_site']]
    if df_barcode_sites.empty:
        print(f"No data to plot for barcode: {barcode}")
        continue

    fig, ax = dmslogo.draw_logo(
        df_barcode_sites,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio",
        title=f"{barcode} logoplot",
        addbreaks=True
    )

    ax.set_ylabel("Antigen Binding")  # y-axis label
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")  # x-axis label

    # Save the figure
    file_path = os.path.join(save_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')

    # Show inline, then close so figures do not accumulate across many barcodes
    plt.show()
    plt.close(fig)




In [None]:
import matplotlib.pyplot as plt

plt.ion()

# Sum enrichment ratios per position, amino acid, barcode and immunization
df_logo_agg = df_total.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
).agg({'Enrichment_Ratio': 'sum'})

# Keep only enriched variants (summed Enrichment_Ratio > 1)
df_logo_agg = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1]

# Sites to show, as a reusable set of strings
sites_to_show = {str(site) for site in [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]}

# Add site labels and flag the sites to show
df_logo_agg = df_logo_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Upper-case amino acids and drop stop codons ('*') and gaps ('-')
df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[~df_logo_agg['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]

# Output folder (hard-coded local path; created if missing so savefig cannot fail)
save_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(save_dir, exist_ok=True)

# One logo plot per barcode
for barcode in df_logo_agg['barcode'].unique():
    print(barcode)
    # Boolean masks instead of a string-built query(), which breaks on barcodes containing quotes
    df_barcode_sites = df_logo_agg[(df_logo_agg['barcode'] == barcode) & df_logo_agg['show_site']]
    if df_barcode_sites.empty:
        print(f"No data to plot for barcode: {barcode}")
        continue

    fig, ax = dmslogo.draw_logo(
        df_barcode_sites,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio",
        title=f"{barcode} logoplot",
        addbreaks=True
    )

    ax.set_ylabel("Antigen Binding")  # y-axis label
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")  # x-axis label

    # Save the figure
    file_path = os.path.join(save_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')

    # Show inline, then close so figures do not accumulate across many barcodes
    plt.show()
    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt
import os

plt.ion()

# Sum enrichment ratios per position, amino acid, barcode and immunization
df_logo_agg = df_total.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
).agg({'Enrichment_Ratio': 'sum'})

# Keep only enriched variants (summed Enrichment_Ratio > 1)
df_logo_agg = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1]

# Sites to show, as a reusable set of strings
sites_to_show = {str(site) for site in [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]}

# Add site labels and flag the sites to show
df_logo_agg = df_logo_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Upper-case amino acids and drop stop codons ('*') and gaps ('-')
df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[~df_logo_agg['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]

# Output folder (hard-coded local path; created if missing so savefig cannot fail)
save_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(save_dir, exist_ok=True)

# One logo plot per barcode, titled with its immunization(s)
for barcode in df_logo_agg['barcode'].unique():
    print(f"Processing barcode: {barcode}")
    df_barcode = df_logo_agg[df_logo_agg['barcode'] == barcode]

    # Normally a barcode belongs to one immunization; if it has several, list them
    # all instead of silently showing only the first.
    immunization = ", ".join(map(str, df_barcode['immunization'].unique()))

    df_barcode_sites = df_barcode[df_barcode['show_site']]
    if df_barcode_sites.empty:
        print(f"No data to plot for barcode: {barcode}")
        continue

    fig, ax = dmslogo.draw_logo(
        df_barcode_sites,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio",
        title=f'{barcode} Logoplot - Immunization: {immunization}',
        addbreaks=True
    )

    ax.set_ylabel("Antigen Binding")  # y-axis label
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")  # x-axis label

    # Save the figure
    file_path = os.path.join(save_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')
    print(f"Saved logoplot for {barcode} to {file_path}")

    # Show inline, then close so figures do not accumulate across many barcodes
    plt.show()
    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

plt.ion()

# Depleted variants only: Enrichment_Ratio < 1 and non-zero (zero cannot be inverted).
# .copy() makes this an independent frame, so the column assignments below do not
# write into a slice of df_total (SettingWithCopyWarning).
df_escape = df_total[(df_total['Enrichment_Ratio'] < 1) & (df_total['Enrichment_Ratio'] != 0)].copy()

# Escape score = log2(1 / Enrichment_Ratio); non-positive inverted values are passed through unchanged
df_escape['Enrichment_Ratio_inverted'] = df_escape['Enrichment_Ratio'].apply(lambda x: 1 / x if x != 0 else x)
df_escape['Enrichment_Ratio_log2'] = df_escape['Enrichment_Ratio_inverted'].apply(lambda x: np.log2(x) if x > 0 else x)

# Mean escape score per position, amino acid, barcode and immunization
df_escape_agg = df_escape.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
).agg({'Enrichment_Ratio_log2': 'mean'})

# Sites to show (RBD-ACE2 interface according to article), as a reusable set of strings
sites_to_show = {str(site) for site in [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]}

df_escape_agg = df_escape_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output folder (hard-coded local path; created if missing so savefig cannot fail)
save_dir = r"/Users/lucaschlotheuber/Desktop/immunization_escape"
os.makedirs(save_dir, exist_ok=True)

# To restrict to one condition, filter df_escape_agg first, e.g.
# df_escape_agg[df_escape_agg['immunization'] == "Mutant_RBD"]
for barcode in df_escape_agg['barcode'].unique():
    print(barcode)

    # Rows for this barcode at the selected sites (boolean masks, no string-built query)
    df_filtered = df_escape_agg[(df_escape_agg['barcode'] == barcode) & df_escape_agg['show_site']]
    # Drop stop codons/gaps BEFORE the emptiness check, so a subset holding only
    # '*'/'-' rows is skipped instead of being passed to dmslogo empty
    df_filtered = df_filtered[~df_filtered['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]
    if df_filtered.empty:
        continue

    fig, ax = dmslogo.draw_logo(
        df_filtered,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio_log2",
        title=f"{barcode} logoplot",
        addbreaks=True,
    )

    ax.set_ylabel("Antibody Escape")
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")

    # Save the figure
    file_path = os.path.join(save_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')

    plt.draw()
    plt.pause(0.1)  # Give the plot time to render
    plt.show()

    # Close the figure to free memory
    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

plt.ion()

# Depleted variants only: Enrichment_Ratio < 1 and non-zero (zero cannot be inverted).
# .copy() makes this an independent frame, so the column assignments below do not
# write into a slice of df_total (SettingWithCopyWarning).
df_escape = df_total[(df_total['Enrichment_Ratio'] < 1) & (df_total['Enrichment_Ratio'] != 0)].copy()

# Escape score = log2(1 / Enrichment_Ratio); non-positive inverted values are passed through unchanged
df_escape['Enrichment_Ratio_inverted'] = df_escape['Enrichment_Ratio'].apply(lambda x: 1 / x if x != 0 else x)
df_escape['Enrichment_Ratio_log2'] = df_escape['Enrichment_Ratio_inverted'].apply(lambda x: np.log2(x) if x > 0 else x)

# Summed escape score per position, amino acid, barcode and immunization
df_escape_agg = df_escape.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
).agg({'Enrichment_Ratio_log2': 'sum'})

# Sites to show, as a reusable set of strings
sites_to_show = {str(site) for site in [417, 484, 501]}

df_escape_agg = df_escape_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output folder (hard-coded local path; created if missing so savefig cannot fail)
save_dir = r"/Users/lucaschlotheuber/Desktop/immunization_escape"
os.makedirs(save_dir, exist_ok=True)

# To restrict to one condition, filter df_escape_agg first, e.g.
# df_escape_agg[df_escape_agg['immunization'] == "Mutant_RBD"]
for barcode in df_escape_agg['barcode'].unique():
    print(barcode)

    # Rows for this barcode at the selected sites (boolean masks, no string-built query)
    df_filtered = df_escape_agg[(df_escape_agg['barcode'] == barcode) & df_escape_agg['show_site']]
    # Drop stop codons/gaps BEFORE the emptiness check, so a subset holding only
    # '*'/'-' rows is skipped instead of being passed to dmslogo empty
    df_filtered = df_filtered[~df_filtered['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]
    if df_filtered.empty:
        continue

    fig, ax = dmslogo.draw_logo(
        df_filtered,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio_log2",
        title=f"{barcode} logoplot",
        addbreaks=True,
    )

    ax.set_ylabel("Antibody Escape")
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")

    # Save the figure with extra padding around the tight bounding box
    file_path = os.path.join(save_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight', pad_inches=0.8)

    plt.draw()
    plt.pause(0.1)  # Give the plot time to render
    plt.show()

    # Close the figure to free memory
    plt.close(fig)


In [None]:
# Wuhan-Hu-1 reference (wild-type) residue at each spike site of interest.
# Residue 505 is Tyr (Y505) in Wuhan-Hu-1, not Leu.
wuhan_strain_aa = {
    417: 'K', 439: 'N', 440: 'N', 452: 'L', 476: 'G', 477: 'S', 484: 'E',
    493: 'Q', 501: 'N', 502: 'G', 505: 'Y'
}

In [None]:
import matplotlib.pyplot as plt

plt.ion()

# Wuhan-Hu-1 reference residue at each spike site of interest (Y505 is Tyr)
wuhan_strain_aa = {
    417: 'K', 439: 'N', 440: 'N', 452: 'L', 476: 'G', 477: 'S', 484: 'E',
    493: 'Q', 501: 'N', 502: 'G', 505: 'Y'
}

# Sum enrichment ratios per position, amino acid, barcode and immunization
df_logo_agg = df_total.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
).agg({'Enrichment_Ratio': 'sum'})

# Keep only enriched variants. This filter must come AFTER the aggregation above:
# applied before it, the aggregation overwrote df_logo_agg and the filter was lost
# (and it read a df_logo_agg left over from an earlier cell).
df_logo_agg = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1]

# Sites to show, as a reusable set of strings
sites_to_show = {str(site) for site in [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]}

# Add site labels and flag the sites to show
df_logo_agg = df_logo_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Upper-case amino acids and drop stop codons ('*') and gaps ('-')
df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[~df_logo_agg['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]

# Debugging: Check the filtered positions and amino acids
print("Filtered data with relevant sites and amino acids:")
print(df_logo_agg[['Spike_AS_Position', 'Amino_Acid']].drop_duplicates())

# Output folder (hard-coded local path; created if missing so savefig cannot fail)
save_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files2"
os.makedirs(save_dir, exist_ok=True)

# One logo plot per barcode
for barcode in df_logo_agg['barcode'].unique():
    print(f"Generating plot for barcode: {barcode}")
    # Boolean masks instead of a string-built query(), which breaks on barcodes containing quotes
    filtered_data = df_logo_agg[(df_logo_agg['barcode'] == barcode) & df_logo_agg['show_site']]

    # Debugging: Check positions for the current barcode
    print(f"Positions and amino acids for {barcode}:")
    print(filtered_data[['Spike_AS_Position', 'Amino_Acid']])

    if filtered_data.empty:
        print(f"No data to plot for barcode: {barcode}")
        continue

    fig, ax = dmslogo.draw_logo(
        filtered_data,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio",
        title=f"{barcode} logoplot",
        addbreaks=True
    )

    ax.set_ylabel("Antigen Binding")  # y-axis label
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")  # x-axis label

    # Save the figure
    file_path = os.path.join(save_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')

    plt.draw()
    plt.pause(0.1)  # Give the plot time to render
    plt.show()

    # Close the figure to free memory
    plt.close(fig)


In [None]:
import pandas as pd

# Every site in wuhan_strain_aa, cast to the dtype of Spike_AS_Position so the merge
# keys are compatible. String keys cannot be merged with the numeric positions:
# pandas raises "You are trying to merge on object and int64 columns".
all_positions = pd.DataFrame(
    {'Spike_AS_Position': list(wuhan_strain_aa.keys())}
).astype({'Spike_AS_Position': df_logo_agg['Spike_AS_Position'].dtype})

# Output folder (hard-coded local path; created if missing so savefig cannot fail)
save_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(save_dir, exist_ok=True)

for barcode in df_logo_agg['barcode'].unique():
    print(f"Generating plot for barcode: {barcode}")
    # Boolean masks instead of a string-built query(), which breaks on barcodes containing quotes
    filtered_data = df_logo_agg[(df_logo_agg['barcode'] == barcode) & df_logo_agg['show_site']]

    # Left-merge so every site appears, even those with no data for this barcode
    filtered_data = all_positions.merge(filtered_data, on="Spike_AS_Position", how="left")

    # Placeholders for missing sites: empty letter, zero height. Assigned back rather
    # than chained fillna(inplace=True), which does not update the frame under copy-on-write.
    # NOTE(review): confirm dmslogo.draw_logo accepts an empty-string letter.
    filtered_data = filtered_data.fillna({'Amino_Acid': "", 'Enrichment_Ratio': 0})

    # Debugging: Check positions for the current barcode
    print(f"Positions and amino acids for {barcode}:")
    print(filtered_data[['Spike_AS_Position', 'Amino_Acid', 'Enrichment_Ratio']])

    # After the left merge there is always one row per site, so this is a safety check only
    if filtered_data.empty:
        print(f"No data to plot for barcode: {barcode}")
        continue

    fig, ax = dmslogo.draw_logo(
        filtered_data,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio",
        title=f"{barcode} logoplot",
        addbreaks=True
    )

    ax.set_ylabel("Antigen Binding")  # y-axis label
    ax.set_xlabel("SARS-CoV-2 Spike AA Position")  # x-axis label

    # Save the figure
    file_path = os.path.join(save_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')

    plt.draw()
    plt.pause(0.1)  # Give the plot time to render
    plt.show()

    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt

plt.ion()

# Wuhan-Hu-1 reference residue at each spike site of interest (Y505 is Tyr)
wuhan_strain_aa = {
    417: 'K', 439: 'N', 440: 'N', 452: 'L', 476: 'G', 477: 'S', 484: 'E',
    493: 'Q', 501: 'N', 502: 'G', 505: 'Y'
}

# Mean enrichment ratio per position, amino acid, barcode and immunization
df_logo_agg = df_total.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
).agg({'Enrichment_Ratio': 'mean'})

# Keep only enriched variants (mean Enrichment_Ratio > 1)
df_logo_agg = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1]

# Sites to show, as a reusable set of strings
sites_to_show = {str(site) for site in [417, 484, 501]}

# Add site labels and flag the sites to show
df_logo_agg = df_logo_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Upper-case amino acids and drop stop codons ('*') and gaps ('-')
df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[~df_logo_agg['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]

# Debugging: Check the filtered positions and amino acids
print("Filtered data with relevant sites and amino acids:")
print(df_logo_agg[['Spike_AS_Position', 'Amino_Acid']].drop_duplicates())

# Output folder (hard-coded local path; created if missing so savefig cannot fail)
save_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(save_dir, exist_ok=True)

# One logo plot per barcode
for barcode in df_logo_agg['barcode'].unique():
    print(f"Generating plot for barcode: {barcode}")
    # Boolean masks instead of a string-built query(), which breaks on barcodes containing quotes
    filtered_data = df_logo_agg[(df_logo_agg['barcode'] == barcode) & df_logo_agg['show_site']]

    # Debugging: Check positions for the current barcode
    print(f"Positions and amino acids for {barcode}:")
    print(filtered_data[['Spike_AS_Position', 'Amino_Acid']])

    if filtered_data.empty:
        print(f"No data to plot for barcode: {barcode}")
        continue

    fig, ax = dmslogo.draw_logo(
        filtered_data,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio",
        title=f"{barcode} logoplot",
        addbreaks=True
    )

    ax.set_ylabel("Antigen Binding")  # y-axis label
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")  # x-axis label

    # Save the figure
    file_path = os.path.join(save_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')

    plt.draw()
    plt.pause(0.1)  # Give the plot time to render
    plt.show()

    # Close the figure to free memory
    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt

plt.ion()

# Wuhan-Hu-1 reference residue at each spike site of interest (Y505 is Tyr)
wuhan_strain_aa = {
    417: 'K', 439: 'N', 440: 'N', 452: 'L', 476: 'G', 477: 'S', 484: 'E',
    493: 'Q', 501: 'N', 502: 'G', 505: 'Y'
}

# Depleted variants only: Enrichment_Ratio < 1 and non-zero (zero cannot be inverted).
# .copy() so the column assignment below does not write into a slice of df_total.
df_escape = df_total[(df_total['Enrichment_Ratio'] < 1) & (df_total['Enrichment_Ratio'] != 0)].copy()

# Invert the ratio so stronger depletion gives a larger value
df_escape['Enrichment_Ratio_inverted'] = df_escape['Enrichment_Ratio'].apply(lambda x: 1 / x if x != 0 else x)

# Sum the inverted ratios per position, amino acid, barcode and immunization
df_escape_agg = df_escape.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
).agg({'Enrichment_Ratio_inverted': 'sum'})

# log2 of the summed inverted ratio; non-positive values are passed through unchanged
df_escape_agg['Enrichment_Ratio_log2'] = df_escape_agg['Enrichment_Ratio_inverted'].apply(lambda x: np.log2(x) if x > 0 else x)

# Drop rows whose amino acid IS the Wuhan residue at that position (not a mutation).
# Vectorized lookup (also safe on an empty frame); positions absent from the dict map
# to NaN and never match.
is_wuhan_residue = (
    df_escape_agg['Spike_AS_Position'].map(wuhan_strain_aa)
    == df_escape_agg['Amino_Acid'].str.upper()
)
df_escape_agg = df_escape_agg[~is_wuhan_residue]

# Sites to show (RBD-ACE2 interface), as a reusable set of strings
sites_to_show = {str(site) for site in [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]}

df_escape_agg = df_escape_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output folder (hard-coded local path; created if missing so savefig cannot fail)
save_dir = r"/Users/lucaschlotheuber/Desktop/immunization_escape"
os.makedirs(save_dir, exist_ok=True)

# One escape logo plot per barcode
for barcode in df_escape_agg['barcode'].unique():
    print(barcode)

    # Rows for this barcode at the selected sites (boolean masks, no string-built query)
    df_filtered = df_escape_agg[(df_escape_agg['barcode'] == barcode) & df_escape_agg['show_site']]
    # Drop stop codons/gaps BEFORE the emptiness check, so a subset holding only
    # '*'/'-' rows is skipped instead of being passed to dmslogo empty
    df_filtered = df_filtered[~df_filtered['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]
    if df_filtered.empty:
        continue

    fig, ax = dmslogo.draw_logo(
        df_filtered,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio_log2",
        title=f"{barcode} logoplot",
        addbreaks=True,
    )

    ax.set_ylabel("Antibody Escape")
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")

    # Save the figure
    file_path = os.path.join(save_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')

    plt.draw()
    plt.pause(0.1)  # Give the plot time to render
    plt.show()

    # Close the figure to free memory
    plt.close(fig)


In [None]:
# Assess whether some variants are both escaped and enriched

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

plt.ion()

# Wuhan-Hu-1 reference residue at each spike site of interest (Y505 is Tyr)
wuhan_strain_aa = {
    417: 'K', 439: 'N', 440: 'N', 452: 'L', 476: 'G', 477: 'S', 484: 'E',
    493: 'Q', 501: 'N', 502: 'G', 505: 'Y'
}

# Depleted variants only: Enrichment_Ratio < 1 and non-zero (zero cannot be inverted).
# .copy() so the column assignment below does not write into a slice of df_total.
df_escape = df_total[(df_total['Enrichment_Ratio'] < 1) & (df_total['Enrichment_Ratio'] != 0)].copy()

# Invert the ratio so stronger depletion gives a larger value
df_escape['Enrichment_Ratio_inverted'] = df_escape['Enrichment_Ratio'].apply(lambda x: 1 / x if x != 0 else x)

# Sum the inverted ratios per position, amino acid, barcode and immunization
df_escape_agg = df_escape.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
).agg({'Enrichment_Ratio_inverted': 'sum'})

# log2 of the summed inverted ratio; non-positive values are passed through unchanged
df_escape_agg['Enrichment_Ratio_log2'] = df_escape_agg['Enrichment_Ratio_inverted'].apply(lambda x: np.log2(x) if x > 0 else x)

# Drop rows whose amino acid IS the Wuhan residue at that position (not a mutation).
# Vectorized lookup (also safe on an empty frame); positions absent from the dict map
# to NaN and never match.
is_wuhan_residue = (
    df_escape_agg['Spike_AS_Position'].map(wuhan_strain_aa)
    == df_escape_agg['Amino_Acid'].str.upper()
)
df_escape_agg = df_escape_agg[~is_wuhan_residue]

# Sites to show (RBD-ACE2 interface), as a reusable set of strings
sites_to_show = {str(site) for site in [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]}

df_escape_agg = df_escape_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output folder (hard-coded local path; created if missing so savefig cannot fail)
save_dir = r"/Users/lucaschlotheuber/Desktop/ETH/immunization_escape"
os.makedirs(save_dir, exist_ok=True)

# One escape logo plot per barcode, titled with its immunization(s)
for barcode in df_escape_agg['barcode'].unique():
    print(f"Processing barcode: {barcode}")
    df_barcode = df_escape_agg[df_escape_agg['barcode'] == barcode]

    # Normally a barcode belongs to one immunization; if it has several, list them
    # all instead of silently showing only the first.
    immunization = ", ".join(map(str, df_barcode['immunization'].unique()))

    # Selected sites only, then drop stop codons/gaps BEFORE the emptiness check
    df_filtered = df_barcode[df_barcode['show_site']]
    df_filtered = df_filtered[~df_filtered['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]
    if df_filtered.empty:
        continue

    fig, ax = dmslogo.draw_logo(
        df_filtered,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio_log2",
        title=f'{barcode} Logoplot - Immunization: {immunization}',
        addbreaks=True,
    )

    ax.set_ylabel("Antibody Escape")
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")

    # Save the figure
    file_path = os.path.join(save_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')

    plt.draw()
    plt.pause(0.1)  # Give the plot time to render
    plt.show()

    # Close the figure to free memory
    plt.close(fig)
    print(f"Saved logoplot for {barcode} - Immunization: {immunization} to {file_path}")


In [None]:
# Function to identify the conflicting cases (both Enrichment_Ratio > 1 and < 1 for the same barcode, position, and amino acid)
def find_conflicting_cases(df):
    """Collect (barcode, position, amino acid) groups measured as both enriched and depleted.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'barcode', 'Spike_AS_Position', 'Amino_Acid' and 'Enrichment_Ratio'.

    Returns
    -------
    list of tuple
        One ``(barcode, position, amino_acid, group)`` tuple per group that holds at
        least one Enrichment_Ratio > 1 and at least one < 1; ``group`` is that group's
        rows. Groups appear in groupby (sorted-key) order.
    """
    key_columns = ['barcode', 'Spike_AS_Position', 'Amino_Acid']
    return [
        (barcode, position, aa, group)
        for (barcode, position, aa), group in df.groupby(key_columns)
        if group['Enrichment_Ratio'].gt(1).any() and group['Enrichment_Ratio'].lt(1).any()
    ]

# Collect every (barcode, position, amino acid) group whose ratios point both ways.
# NOTE(review): if df_escape only holds Enrichment_Ratio < 1 rows (as built in the
# filtered-logo cell), no group can qualify and this prints nothing -- confirm the input.
conflicting_cases = find_conflicting_cases(df_escape)

# Dump each conflicting group so it can be inspected by eye
for barcode, position, aa, group in conflicting_cases:
    header = f"Conflicting cases for barcode {barcode}, position {position}, amino acid {aa}:"
    print(header)
    print(group[['barcode', 'Spike_AS_Position', 'Amino_Acid', 'Enrichment_Ratio']])
    print("\n---\n")


In [None]:
import matplotlib.pyplot as plt

plt.ion()

# Ancestral (Wuhan strain) residues at the displayed RBD-ACE2 interface sites. Logo letters
# equal to these residues are wild type and are removed before plotting.
# NOTE(review): in the Wuhan-Hu-1 Spike reference residue 505 is Y (cf. Omicron Y505H);
# confirm that 505: 'L' matches this dataset's numbering.
wuhan_strain_aa = {
    417: 'K', 439: 'N', 440: 'N', 452: 'L', 476: 'G', 477: 'S', 484: 'E',
    493: 'Q', 501: 'N', 502: 'G', 505: 'L'
}

# Columns that identify one amino-acid variant measured in one barcode (sample)
CONFLICT_KEYS = ['barcode', 'Spike_AS_Position', 'Amino_Acid']


def _conflict_mask(df):
    """Flag rows whose (barcode, position, amino acid) group mixes enrichment and escape.

    A group conflicts when it has at least one Enrichment_Ratio > 1 AND at least one
    Enrichment_Ratio < 1.

    Returns:
        Boolean Series aligned to ``df.index`` (True = row belongs to a conflicting group).
    """
    ratio = df['Enrichment_Ratio']
    flags = df[CONFLICT_KEYS].assign(_up=ratio.gt(1).to_numpy(), _down=ratio.lt(1).to_numpy())
    grouped = flags.groupby(CONFLICT_KEYS)
    # .eq(True) turns any missing transform result (rows with NaN keys) into False
    has_up = grouped['_up'].transform('any').eq(True)
    has_down = grouped['_down'].transform('any').eq(True)
    return has_up & has_down


def filter_positions_per_barcode(df):
    """Return the unique Spike positions having at least one conflicting
    (barcode, position, amino acid) group in ``df`` (see ``_conflict_mask``).

    Only meaningful on a table that still contains both ER > 1 and ER < 1 rows.
    """
    return df.loc[_conflict_mask(df).to_numpy(), 'Spike_AS_Position'].unique()


# Escape subset: depleted variants, 0 < Enrichment_Ratio < 1 (zeros excluded)
is_escape = (df_total['Enrichment_Ratio'] < 1) & (df_total['Enrichment_Ratio'] != 0)
df_escape = df_total[is_escape]

# Conflicts must be detected on the FULL table: df_escape only holds ER < 1 rows, so
# searching it for ER > 1 partners can never succeed and silently removes nothing.
is_conflict = _conflict_mask(df_total)
positions_to_filter = df_total.loc[is_conflict.to_numpy(), 'Spike_AS_Position'].unique()

# Drop only the conflicting (barcode, position, amino acid) rows from the escape subset,
# instead of discarding a whole position for every barcode. .copy() -> owned frame, so the
# column assignment below does not raise SettingWithCopyWarning.
df_filtered = df_total[(is_escape & ~is_conflict).to_numpy()].copy()

# Invert the ratio so stronger escape gives a larger value (zeros were excluded above)
df_filtered['Enrichment_Ratio_inverted'] = 1.0 / df_filtered['Enrichment_Ratio']

# Aggregate the data by position, amino acid, barcode, and immunization
df_filtered_agg = df_filtered.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
).agg({'Enrichment_Ratio_inverted': 'mean'})

# Apply log2 transformation (non-positive values are passed through unchanged)
df_filtered_agg['Enrichment_Ratio_log2'] = df_filtered_agg['Enrichment_Ratio_inverted'].apply(lambda x: np.log2(x) if x > 0 else x)

# Drop wild-type letters: rows whose amino acid equals the Wuhan residue at that position
# (vectorised lookup; positions missing from the dict map to NaN and never match)
is_wildtype = df_filtered_agg['Spike_AS_Position'].map(wuhan_strain_aa) == df_filtered_agg['Amino_Acid']
df_filtered_agg = df_filtered_agg[~is_wildtype]

# Define the sites to show (RBD-ACE2 interface). A list, not a `map` iterator, so it can
# be reused by other cells without being exhausted.
sites_to_show = list(map(str, [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]))

df_filtered_agg = df_filtered_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Output folder for these logos (created if missing).
# NOTE(review): the earlier escape-logo cell writes the same "<barcode>_logoplots.png" names
# to Desktop/ETH/immunization_escape; this cell uses Desktop/immunization_escape.
logo_output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_escape"
os.makedirs(logo_output_dir, exist_ok=True)

# Loop through each barcode to generate the plots
for barcode in df_filtered_agg['barcode'].unique():
    print(barcode)

    # Boolean masks instead of a string-built query (robust to any barcode value)
    df_filtered_barcode = df_filtered_agg[
        (df_filtered_agg['barcode'] == barcode) & df_filtered_agg['show_site']
    ]

    # Exclude stop codons / gaps BEFORE the emptiness check so draw_logo never gets an empty frame
    df_filtered_barcode = df_filtered_barcode[~df_filtered_barcode['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]

    if df_filtered_barcode.empty:
        continue

    # Create the plot
    fig, ax = dmslogo.draw_logo(
        df_filtered_barcode,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio_log2",
        title=barcode + ' logoplot',
        addbreaks=True,
    )

    ax.set_ylabel("Antibody Escape")  # Set the y-axis label
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")

    # Save the figure returned by dmslogo (not the implicit "current" figure)
    file_path = os.path.join(logo_output_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')  # Save with tight bounding box

    plt.draw()
    plt.pause(0.1)  # Give the plot time to render
    plt.show()

    # Close the figure to free memory
    plt.close(fig)


In [None]:
#volcano plot

In [None]:
# Split the table by direction of selection: ER > 1 (enriched) vs ER < 1 (escaped)
df_enriched = df_total.loc[df_total['Enrichment_Ratio'] > 1]
df_escaped = df_total.loc[df_total['Enrichment_Ratio'] < 1]

# Summed enrichment ratio per (position, barcode) for each direction
df_enriched_grouped = (
    df_enriched
    .groupby(['Spike_AS_Position', 'barcode'], as_index=False)
    .agg({'Enrichment_Ratio': 'sum'})
)
df_escaped_grouped = (
    df_escaped
    .groupby(['Spike_AS_Position', 'barcode'], as_index=False)
    .agg({'Enrichment_Ratio': 'sum'})
)

# Positions seen on both sides (in any barcode -- the intersection ignores the barcode)
common_positions = (
    set(df_enriched_grouped['Spike_AS_Position'])
    & set(df_escaped_grouped['Spike_AS_Position'])
)

# All rows of df_total at those shared positions, re-summed per (position, barcode)
df_enriched_escaped = df_total.loc[df_total['Spike_AS_Position'].isin(common_positions)]
df_enriched_escaped_grouped = (
    df_enriched_escaped
    .groupby(['Spike_AS_Position', 'barcode'], as_index=False)
    .agg({'Enrichment_Ratio': 'sum'})
)

print(df_enriched_escaped_grouped)

# Plot or analyze further, depending on your goal


In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Expects `df_combined` from an earlier cell with at least these columns:
# 'Spike_AS_Position', 'Amino_Acid', 'Sample_Count', 'Max_Enrichment'

# Ancestral (Wuhan strain) residues used for the E484K-style labels. Kept under its own name:
# re-binding the notebook-wide `wuhan_strain_aa` here replaced the 11-site reference used to
# drop wild-type letters from the logo plots with this 5-site subset for every later cell.
label_reference_aa = {
    484: 'E', 417: 'K', 501: 'N', 346: 'R', 452: 'L',
    # Add other positions as needed
}

# Create E484K-style mutation labels ('?' where the reference residue is not listed)
df_combined['mutation_label'] = df_combined.apply(
    lambda row: f"{label_reference_aa.get(row['Spike_AS_Position'], '?')}"
                f"{row['Spike_AS_Position']}"
                f"{row['Amino_Acid']}",
    axis=1
)

# Thresholds a mutation must exceed (both) to be labelled
label_threshold_enrichment = 10
label_threshold_count = 5

# Mark which rows to label
df_combined['label_this'] = (
    (df_combined['Max_Enrichment'] > label_threshold_enrichment) &
    (df_combined['Sample_Count'] > label_threshold_count)
)

# Set up the plot
fig, ax = plt.subplots(figsize=(10, 6))

# Background scatter of all mutations
sc = ax.scatter(
    df_combined['Max_Enrichment'],
    df_combined['Sample_Count'],
    c='grey',
    alpha=0.7,
    edgecolors='k',
    s=40
)

# Highlight points to be labeled
highlighted = df_combined[df_combined['label_this']]
ax.scatter(
    highlighted['Max_Enrichment'],
    highlighted['Sample_Count'],
    c='red',
    edgecolors='black',
    s=60,
    label='Labeled mutations'
)

# Add text labels for selected mutations
for _, row in highlighted.iterrows():
    ax.text(
        row['Max_Enrichment'],
        row['Sample_Count'],
        row['mutation_label'],
        fontsize=9,
        ha='center',
        va='bottom'
    )

# Formatting
ax.set_xlabel('Enrichment Score', fontsize=12)
ax.set_ylabel('Sample Count', fontsize=12)
ax.set_title('Volcano Plot of Spike Mutations', fontsize=14)
ax.grid(True)
ax.legend()

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text

# Combine enriched rows (ER > 1) from df_logo_agg with the escape rows in df_escape
# (df_escape is built elsewhere; in the filtered-logo cell it holds 0 < ER < 1 rows of df_total).
# NOTE(review): the name suggests df_logo_agg is aggregated while df_escape is per-observation,
# so the row counts below mean different things for the two halves -- confirm intent.
df_combined_volcano = pd.concat([
    df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1],
    df_escape,
])

# log2 needs ER > 0. .copy() gives an owned frame, so adding the column does not warn.
df_combined_volcano = df_combined_volcano[df_combined_volcano['Enrichment_Ratio'] > 0].copy()
df_combined_volcano['log2_Enrichment'] = np.log2(df_combined_volcano['Enrichment_Ratio'])

# One volcano point per mutation and immunization
volcano_keys = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Count appearances: number of ROWS per key (not unique barcodes)
df_counts = (
    df_combined_volcano
    .groupby(volcano_keys)
    .size()
    .reset_index(name='Sample_Count')
)

# Mean log2 enrichment per key (the variable name is historical; this is a mean)
df_max_enrichment = (
    df_combined_volcano
    .groupby(volcano_keys)['log2_Enrichment']
    .mean()
    .reset_index(name='log2_Enrichment')
)

# Merge
df_volcano = pd.merge(df_counts, df_max_enrichment, on=volcano_keys)

# Label '<mutant AA><position>', e.g. "K484" (no reference residue is available here)
df_volcano['mutation'] = df_volcano['Spike_AS_Position'].astype(str)
df_volcano['mutation_label'] = df_volcano['mutation'].str.zfill(3)
df_volcano['site_label'] = df_volcano['Amino_Acid'] + df_volcano['mutation_label']

# Exclude synonymous mutations and stop codons from the annotation candidates
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
# NOTE(review): site_label is "<mutant AA><position>", so str[0] is a residue letter and
# str[-1] a digit of the position -- effectively never equal, so nothing is removed here.
# A real synonymous check needs the reference residue at each position.
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]
stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]

# Top N enriched / escape: most-supported first, then the strongest effect
top_n = 5

top_enriched = df_filtered[df_filtered['log2_Enrichment'] > 0].nlargest(top_n, ['Sample_Count', 'log2_Enrichment'])
# nsmallest(['Sample_Count', ...]) picked the escapes seen in the FEWEST samples; rank by
# Sample_Count descending, then log2 ascending (most negative first), mirroring top_enriched.
top_escape = (
    df_filtered[df_filtered['log2_Enrichment'] < 0]
    .sort_values(['Sample_Count', 'log2_Enrichment'], ascending=[False, True])
    .head(top_n)
)

# Plot
plt.figure(figsize=(12, 6))
ax = sns.scatterplot(
    data=df_volcano,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    alpha=0.8,
    edgecolor='black'
)

plt.axvline(0, linestyle='--', color='gray')
plt.title("Volcano Plot: Escape vs Binding Mutations")
plt.xlabel("log2(Enrichment Ratio) (← Escape   |   Binding →)")
plt.ylabel("Sample Count")
plt.legend(title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Annotate top points
texts = []
for _, row in pd.concat([top_enriched, top_escape]).iterrows():
    offset_x = 1 if row['log2_Enrichment'] > 0 else -1
    offset_y = 2
    texts.append(
        plt.text(
            row['log2_Enrichment'] + offset_x,
            row['Sample_Count'] + offset_y,
            row['site_label'],
            fontsize=10,
            weight='bold',
            color='black'
        )
    )
    # Draw line to the actual point
    plt.plot(
        [row['log2_Enrichment'], row['log2_Enrichment'] + offset_x],
        [row['Sample_Count'], row['Sample_Count'] + offset_y],
        linestyle='--',
        color='gray',
        linewidth=0.8
    )

adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray'))

plt.tight_layout()
plt.show()


In [None]:
# One volcano point per mutation and immunization
volcano_keys = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Number of unique barcodes (samples) per key
df_counts = (
    df_combined_volcano
    .groupby(volcano_keys)
    .agg(Sample_Count=('barcode', 'nunique'))  # Count unique barcodes
    .reset_index()
)

# Median log2 enrichment per key
df_median_enrichment = (
    df_combined_volcano
    .groupby(volcano_keys)['log2_Enrichment']
    .median()
    .reset_index(name='log2_Enrichment')
)

# Merge the counts and median enrichment ratio
df_volcano = pd.merge(df_counts, df_median_enrichment, on=volcano_keys)

# Attach the per-observation columns. This is one-to-many: every aggregated row is repeated
# once per matching row of df_combined_volcano. `original_columns` must already exist (it is
# defined in another cell) and contain 'log2_Enrichment', which yields the _x/_y pair below.
df_volcano = pd.merge(df_volcano, df_combined_volcano[original_columns], on=volcano_keys, how='left')

# Keep the aggregated median (_x) as the single 'log2_Enrichment' column
df_volcano = (
    df_volcano
    .assign(log2_Enrichment=df_volcano['log2_Enrichment_x'])
    .drop(columns=['log2_Enrichment_x', 'log2_Enrichment_y'])
)

# Print out the columns to verify
print("Columns in df_volcano after merging and cleaning:")
print(df_volcano.columns.tolist())

# Label '<mutant AA><position>', e.g. "K484"
df_volcano['mutation'] = df_volcano['Spike_AS_Position'].astype(str)
df_volcano['mutation_label'] = df_volcano['mutation'].str.zfill(3)
df_volcano['site_label'] = df_volcano['Amino_Acid'] + df_volcano['mutation_label']

# Exclude synonymous mutations and stop codons from the annotation candidates
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
# NOTE(review): str[0] is the mutant residue and str[-1] a digit of the position, so this is
# effectively never True; a real check needs the reference residue per position.
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]
stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]

# Top N enriched / escape. Deduplicate first: the one-to-many merge repeats each mutation,
# which would otherwise let a single mutation occupy several of the top_n label slots.
top_n = 5
df_label_candidates = df_filtered.drop_duplicates(subset=volcano_keys)

top_enriched = df_label_candidates[df_label_candidates['log2_Enrichment'] > 0].nlargest(top_n, ['Sample_Count', 'log2_Enrichment'])
# Most-supported escapes first, then most negative log2 (nsmallest picked the least-supported)
top_escape = (
    df_label_candidates[df_label_candidates['log2_Enrichment'] < 0]
    .sort_values(['Sample_Count', 'log2_Enrichment'], ascending=[False, True])
    .head(top_n)
)

# Plot
plt.figure(figsize=(12, 6))
ax = sns.scatterplot(
    data=df_volcano,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    alpha=0.8,
    edgecolor='black'
)

plt.axvline(0, linestyle='--', color='gray')
plt.title("Volcano Plot: Escape vs Binding Mutations")
plt.xlabel("log2(Enrichment Ratio) (← Escape   |   Binding →)")
plt.ylabel("Sample Count")
plt.legend(title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Annotate top points
texts = []
for _, row in pd.concat([top_enriched, top_escape]).iterrows():
    offset_x = 1 if row['log2_Enrichment'] > 0 else -1
    offset_y = 2
    texts.append(
        plt.text(
            row['log2_Enrichment'] + offset_x,
            row['Sample_Count'] + offset_y,
            row['site_label'],
            fontsize=10,
            weight='bold',
            color='black'
        )
    )
    # Draw line to the actual point
    plt.plot(
        [row['log2_Enrichment'], row['log2_Enrichment'] + offset_x],
        [row['Sample_Count'], row['Sample_Count'] + offset_y],
        linestyle='--',
        color='gray',
        linewidth=0.8
    )

adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray'))

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text

# Combine enriched rows (ER > 1) from df_logo_agg with the escape rows in df_escape
# (df_escape is built elsewhere; in the filtered-logo cell it holds 0 < ER < 1 rows of df_total).
df_combined_volcano = pd.concat([
    df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1],
    df_escape,
])

# log2 needs ER > 0. .copy() gives an owned frame, so adding the column does not warn.
df_combined_volcano = df_combined_volcano[df_combined_volcano['Enrichment_Ratio'] > 0].copy()
df_combined_volcano['log2_Enrichment'] = np.log2(df_combined_volcano['Enrichment_Ratio'])

# Retain all columns from the original data frame (used by this and later merge cells)
original_columns = df_combined_volcano.columns.tolist()

# One volcano point per mutation and immunization
volcano_keys = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Count appearances: number of ROWS per key (not unique barcodes)
df_counts = (
    df_combined_volcano
    .groupby(volcano_keys)
    .size()
    .reset_index(name='Sample_Count')
)

# Mean log2 enrichment per key (the variable name is historical; this is a mean)
df_max_enrichment = (
    df_combined_volcano
    .groupby(volcano_keys)['log2_Enrichment']
    .mean()
    .reset_index(name='log2_Enrichment')
)

# Merge counts and enrichment data
df_volcano = pd.merge(df_counts, df_max_enrichment, on=volcano_keys)

# Attach the per-observation columns (one-to-many: each aggregated row is repeated once per
# matching df_combined_volcano row). original_columns contains 'log2_Enrichment' -> _x/_y.
df_volcano = pd.merge(df_volcano, df_combined_volcano[original_columns], on=volcano_keys, how='left')

# Keep the aggregated mean (_x) as the single 'log2_Enrichment' column
df_volcano = (
    df_volcano
    .assign(log2_Enrichment=df_volcano['log2_Enrichment_x'])
    .drop(columns=['log2_Enrichment_x', 'log2_Enrichment_y'])
)

# Print out the columns to verify
print("Columns in df_volcano after merging and cleaning:")
print(df_volcano.columns.tolist())

# Label '<mutant AA><position>', e.g. "K484"
df_volcano['mutation'] = df_volcano['Spike_AS_Position'].astype(str)
df_volcano['mutation_label'] = df_volcano['mutation'].str.zfill(3)
df_volcano['site_label'] = df_volcano['Amino_Acid'] + df_volcano['mutation_label']

# Exclude synonymous mutations and stop codons from the annotation candidates
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
# NOTE(review): str[0] is the mutant residue and str[-1] a digit of the position, so this is
# effectively never True; a real check needs the reference residue per position.
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]
stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]

# Top N enriched / escape, one candidate per mutation (the merge above repeats rows)
top_n = 5
df_label_candidates = df_filtered.drop_duplicates(subset=volcano_keys)

top_enriched = df_label_candidates[df_label_candidates['log2_Enrichment'] > 0].nlargest(top_n, ['Sample_Count', 'log2_Enrichment'])
# Most-supported escapes first, then most negative log2 (nsmallest picked the least-supported)
top_escape = (
    df_label_candidates[df_label_candidates['log2_Enrichment'] < 0]
    .sort_values(['Sample_Count', 'log2_Enrichment'], ascending=[False, True])
    .head(top_n)
)

# Plot
plt.figure(figsize=(12, 6))
ax = sns.scatterplot(
    data=df_volcano,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    alpha=0.8,
    edgecolor='black'
)

plt.axvline(0, linestyle='--', color='gray')
plt.title("Volcano Plot: Escape vs Binding Mutations")
plt.xlabel("log2(Enrichment Ratio) (← Escape   |   Binding →)")
plt.ylabel("Sample Count")
plt.legend(title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Annotate top points
texts = []
for _, row in pd.concat([top_enriched, top_escape]).iterrows():
    offset_x = 1 if row['log2_Enrichment'] > 0 else -1
    offset_y = 2
    texts.append(
        plt.text(
            row['log2_Enrichment'] + offset_x,
            row['Sample_Count'] + offset_y,
            row['site_label'],
            fontsize=10,
            weight='bold',
            color='black'
        )
    )
    # Draw line to the actual point
    plt.plot(
        [row['log2_Enrichment'], row['log2_Enrichment'] + offset_x],
        [row['Sample_Count'], row['Sample_Count'] + offset_y],
        linestyle='--',
        color='gray',
        linewidth=0.8
    )

adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray'))

plt.tight_layout()
plt.show()


In [None]:
# One volcano point per mutation and immunization
volcano_keys = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Number of unique barcodes (samples) per key
df_counts = (
    df_combined_volcano
    .groupby(volcano_keys)
    .agg(Sample_Count=('barcode', 'nunique'))  # Count unique barcodes
    .reset_index()
)

# Median log2 enrichment per key
df_median_enrichment = (
    df_combined_volcano
    .groupby(volcano_keys)['log2_Enrichment']
    .median()
    .reset_index(name='log2_Enrichment')
)

# Merge the counts and median enrichment ratio
df_volcano = pd.merge(df_counts, df_median_enrichment, on=volcano_keys)

# Attach the per-observation columns (one-to-many: each aggregated row is repeated once per
# matching df_combined_volcano row). `original_columns` must exist (defined in another cell)
# and include 'log2_Enrichment', which yields the _x (median) / _y (per-row) pair below.
df_volcano = pd.merge(df_volcano, df_combined_volcano[original_columns], on=volcano_keys, how='left')

# Keep the aggregated median (_x) as the single 'log2_Enrichment' column
df_volcano = (
    df_volcano
    .assign(log2_Enrichment=df_volcano['log2_Enrichment_x'])
    .drop(columns=['log2_Enrichment_x', 'log2_Enrichment_y'])
)

# Print out the columns to verify
print("Columns in df_volcano after merging and cleaning:")
print(df_volcano.columns.tolist())

# Label '<mutant AA><position>', e.g. "K484"
df_volcano['mutation'] = df_volcano['Spike_AS_Position'].astype(str)
df_volcano['mutation_label'] = df_volcano['mutation'].str.zfill(3)
df_volcano['site_label'] = df_volcano['Amino_Acid'] + df_volcano['mutation_label']

# Exclude synonymous mutations and stop codons from the annotation candidates
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
# NOTE(review): str[0] is the mutant residue and str[-1] a digit of the position, so this is
# effectively never True; a real check needs the reference residue per position.
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]
stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]

palette = {
    'Polyclonal_Ab': '#ff7f0e',     # Adjust colors as needed
    'Neutralizing_Ab': '#d62728',    # Adjust colors as needed
    'Mutant_RBD': '#1f77b4',         # Adjust colors as needed
    'wildtype_RBD': '#009688',       # Adjust colors as needed
}
# Only the immunizations that have a colour are plotted; .copy() -> owned frame, so later
# cells can add columns to it without SettingWithCopyWarning
df_volcano_filtered = df_volcano[df_volcano['immunization'].isin(palette.keys())].copy()

# Top N enriched / escape, computed HERE (previously this cell silently reused the variables
# left over from an earlier cell). One candidate per mutation, restricted to plotted groups.
top_n = 5
df_label_candidates = df_filtered[df_filtered['immunization'].isin(palette.keys())].drop_duplicates(subset=volcano_keys)

top_enriched = df_label_candidates[df_label_candidates['log2_Enrichment'] > 0].nlargest(top_n, ['Sample_Count', 'log2_Enrichment'])
# Most-supported escapes first, then most negative log2 (nsmallest picked the least-supported)
top_escape = (
    df_label_candidates[df_label_candidates['log2_Enrichment'] < 0]
    .sort_values(['Sample_Count', 'log2_Enrichment'], ascending=[False, True])
    .head(top_n)
)

# Plot
plt.figure(figsize=(12, 6))
ax = sns.scatterplot(
    data=df_volcano_filtered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    palette=palette,
    alpha=0.8,
)

plt.axvline(0, linestyle='--', color='gray')
plt.title("Volcano Plot: Escape vs Binding Mutations")
plt.xlabel("log2(Enrichment Ratio) (← Escape   |   Binding →)")
plt.ylabel("Sample Count")

# Legend display names, looked up by immunization value. Seaborn orders hue levels by their
# first appearance in the data, so a positional list of labels can be attached to the wrong
# entries; a key-based lookup cannot.
# NOTE(review): pairs inferred from the previous positional list -- confirm each one.
# "B.1.135" is probably the B.1.351 (Beta) lineage -- confirm the label text.
immunization_display_names = {
    'Mutant_RBD': 'B cells - B.1.135 RBD vaccine',
    'Polyclonal_Ab': 'Polyreactive pAB',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
    'Neutralizing_Ab': 'mAB (NEUT)',
}
handles, labels = ax.get_legend_handles_labels()
new_labels = [immunization_display_names.get(label, label) for label in labels]

# Apply the updated labels
ax.legend(handles, new_labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Enable minor ticks
plt.minorticks_on()

# Customize the minor ticks (smaller ticks)
ax.tick_params(axis='both', which='minor', length=4, width=1, color='gray')

# Annotate top points
texts = []
for _, row in pd.concat([top_enriched, top_escape]).iterrows():
    offset_x = 1 if row['log2_Enrichment'] > 0 else -1
    offset_y = 2
    texts.append(
        plt.text(
            row['log2_Enrichment'] + offset_x,
            row['Sample_Count'] + offset_y,
            row['site_label'],
            fontsize=10,
            weight='bold',
            color='black'
        )
    )
    # Draw line to the actual point
    plt.plot(
        [row['log2_Enrichment'], row['log2_Enrichment'] + offset_x],
        [row['Sample_Count'], row['Sample_Count'] + offset_y],
        linestyle='--',
        color='gray',
        linewidth=0.8
    )

adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray'))

plt.tight_layout()
plt.show()


In [None]:
# Show which columns are available in the combined volcano table
print("Columns in df_combined_volcano:", df_combined_volcano.columns.tolist(), sep="\n")


In [None]:
# One volcano point per mutation and immunization
volcano_keys = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Number of unique barcodes (samples) per key
df_counts = (
    df_combined_volcano
    .groupby(volcano_keys)
    .agg(Sample_Count=('barcode', 'nunique'))  # Count unique barcodes
    .reset_index()
)

# Median log2 enrichment per key
df_median_enrichment = (
    df_combined_volcano
    .groupby(volcano_keys)['log2_Enrichment']
    .median()
    .reset_index(name='log2_Enrichment')
)

# Merge the counts and median enrichment ratio
df_volcano = pd.merge(df_counts, df_median_enrichment, on=volcano_keys)

# Attach the per-observation columns (one-to-many: each aggregated row is repeated once per
# matching df_combined_volcano row). `original_columns` must exist (defined in another cell)
# and include 'log2_Enrichment', which yields the _x (median) / _y (per-row) pair below.
df_volcano = pd.merge(df_volcano, df_combined_volcano[original_columns], on=volcano_keys, how='left')

# Keep the aggregated median (_x) as the single 'log2_Enrichment' column
df_volcano = (
    df_volcano
    .assign(log2_Enrichment=df_volcano['log2_Enrichment_x'])
    .drop(columns=['log2_Enrichment_x', 'log2_Enrichment_y'])
)

# Print out the columns to verify
print("Columns in df_volcano after merging and cleaning:")
print(df_volcano.columns.tolist())

# Label '<mutant AA><position>', e.g. "K484"
df_volcano['mutation'] = df_volcano['Spike_AS_Position'].astype(str)
df_volcano['mutation_label'] = df_volcano['mutation'].str.zfill(3)
df_volcano['site_label'] = df_volcano['Amino_Acid'] + df_volcano['mutation_label']

# Exclude synonymous mutations and stop codons from the annotation candidates
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
# NOTE(review): str[0] is the mutant residue and str[-1] a digit of the position, so this is
# effectively never True; a real check needs the reference residue per position.
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]
stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]

palette = {
    'Polyclonal_Ab': '#ff7f0e',     # Adjust colors as needed
    'Neutralizing_Ab': '#d62728',    # Adjust colors as needed
    'Mutant_RBD': '#1f77b4',         # Adjust colors as needed
    'wildtype_RBD': '#009688',       # Adjust colors as needed
}
# Only the immunizations that have a colour are plotted; .copy() -> owned frame, so later
# cells can add columns to it without SettingWithCopyWarning
df_volcano_filtered = df_volcano[df_volcano['immunization'].isin(palette.keys())].copy()

# Mutations to annotate: the top_n most-supported enriched and escape mutations among the
# plotted groups (df_sig was previously never defined in this notebook, so the cell raised
# NameError on a fresh kernel). One candidate per mutation, since the merge repeats rows.
top_n = 5
df_label_candidates = df_filtered[df_filtered['immunization'].isin(palette.keys())].drop_duplicates(subset=volcano_keys)
top_enriched = df_label_candidates[df_label_candidates['log2_Enrichment'] > 0].nlargest(top_n, ['Sample_Count', 'log2_Enrichment'])
top_escape = (
    df_label_candidates[df_label_candidates['log2_Enrichment'] < 0]
    .sort_values(['Sample_Count', 'log2_Enrichment'], ascending=[False, True])
    .head(top_n)
)
df_sig = pd.concat([top_enriched, top_escape])

# Plot
plt.figure(figsize=(12, 6))
ax = sns.scatterplot(
    data=df_volcano_filtered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    palette=palette,
    alpha=0.8,
)

plt.axvline(0, linestyle='--', color='gray')
plt.title("Volcano Plot: Escape vs Binding Mutations")
plt.xlabel("log2(Enrichment Ratio) (← Escape   |   Binding →)")
plt.ylabel("Sample Count")

# Legend display names, looked up by immunization value (seaborn orders hue levels by first
# appearance in the data, so a positional label list can mislabel entries).
# NOTE(review): pairs inferred from the previous positional list -- confirm each one.
# "B.1.135" is probably the B.1.351 (Beta) lineage -- confirm the label text.
immunization_display_names = {
    'Mutant_RBD': 'B cells - B.1.135 RBD vaccine',
    'Polyclonal_Ab': 'Polyreactive pAB',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
    'Neutralizing_Ab': 'mAB (NEUT)',
}
handles, labels = ax.get_legend_handles_labels()
new_labels = [immunization_display_names.get(label, label) for label in labels]

# Apply the updated labels
ax.legend(handles, new_labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Enable minor ticks
plt.minorticks_on()

# Customize the minor ticks (smaller ticks)
ax.tick_params(axis='both', which='minor', length=4, width=1, color='gray')

# Annotate top points
# Drop duplicate site labels (same mutation in several immunization groups)
df_labels = df_sig.drop_duplicates(subset=['site_label'])

texts = []
for _, row in df_labels.iterrows():
    offset_x = 1 if row['log2_Enrichment'] > 0 else -1
    offset_y = 2
    texts.append(
        plt.text(
            row['log2_Enrichment'] + offset_x,
            row['Sample_Count'] + offset_y,
            row['site_label'],
            fontsize=10,
            weight='bold',
            color='black'
        )
    )
    plt.plot(
        [row['log2_Enrichment'], row['log2_Enrichment'] + offset_x],
        [row['Sample_Count'], row['Sample_Count'] + offset_y],
        linestyle='--',
        color='gray',
        linewidth=0.8
    )

adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray'))


plt.tight_layout()
plt.show()


In [None]:
# Hexagonal-bin density of all plotted points (logarithmic colour scale via bins='log')
plt.figure(figsize=(12, 6))
hb = plt.hexbin(
    x=df_volcano_filtered['log2_Enrichment'],
    y=df_volcano_filtered['Sample_Count'],
    bins='log',
    cmap='viridis',
    gridsize=50,
)
plt.colorbar(hb, label='Log10(count)')
plt.axvline(0, color='gray', linestyle='--')  # escape | binding divide
plt.ylabel("Sample Count")
plt.xlabel("log2(Enrichment Ratio) (← Escape   |   Binding →)")
plt.title("Hexbin Density of Enrichment vs Sample Count")


In [None]:
# Jittered strip plot of the plotted points, dodged by immunization.
# NOTE(review): stripplot treats one axis as categorical; with continuous log2_Enrichment on x,
# each distinct value may become its own category -- check the rendered axis.
sns.stripplot(
    data=df_volcano_filtered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    palette=palette,
    alpha=0.7,
    jitter=0.25,
    dodge=True,
)


In [None]:
# Non-overlapping swarm of the plotted points, dodged by immunization.
# NOTE(review): like stripplot this is a categorical plot, and swarm layout may be slow or
# drop points when there are many -- check the rendered output.
sns.swarmplot(
    data=df_volcano_filtered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    dodge=True,
    palette=palette,
)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

def _hexbin_facet(data, color):
    """Draw one facet's hexbin on the current axes; `color` is supplied by FacetGrid and unused."""
    return plt.hexbin(
        data['log2_Enrichment'],
        data['Sample_Count'],
        gridsize=40,
        cmap='viridis',
        bins='log',
    )


# One hexbin panel per immunization, two panels per row
g = sns.FacetGrid(df_volcano_filtered, col='immunization', col_wrap=2, height=4, aspect=1.5)
g.map_dataframe(_hexbin_facet)
g.set_titles(col_template="{col_name}")
g.set_axis_labels("log2(Enrichment Ratio)", "Sample Count")
g.fig.subplots_adjust(top=0.9)  # leave room for the figure title
g.fig.suptitle("Hexbin Plot of Enrichment vs Sample Count by Immunization")


In [None]:
import umap
import seaborn as sns
import matplotlib.pyplot as plt

from sklearn.preprocessing import StandardScaler

# 1. Numeric features describing each plotted point
features = df_volcano_filtered[['log2_Enrichment', 'Sample_Count']].copy()

# 2. Standardise so both features contribute on the same scale
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)

# 3. 2-D UMAP embedding (seeded for reproducibility)
reducer = umap.UMAP(random_state=42)
embedding = reducer.fit_transform(features_scaled)

# 4. Store the coordinates via .assign: df_volcano_filtered may be a boolean-indexed slice of
#    df_volcano, and writing new columns into such a slice raises SettingWithCopyWarning.
#    Rows of `embedding` are in the same order as df_volcano_filtered.
df_volcano_filtered = df_volcano_filtered.assign(
    UMAP1=embedding[:, 0],
    UMAP2=embedding[:, 1],
)

# 5. Plot UMAP
plt.figure(figsize=(10, 6))
sns.scatterplot(
    data=df_volcano_filtered,
    x='UMAP1',
    y='UMAP2',
    hue='immunization',
    palette=palette,
    alpha=0.8
)

plt.title("UMAP of Mutation Features")
plt.xlabel("UMAP 1")
plt.ylabel("UMAP 2")
plt.legend(title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


In [None]:

import pandas as pd
import numpy as np
from Bio import SeqIO, Seq

# --- Load the reference sequence (first record of the FASTA file) ---
fasta_file = r'/Users/lucaschlotheuber/Desktop/ETH/RBD201_DMS1.fa'
for record in SeqIO.parse(fasta_file, "fasta"):
    wuhan_sequence = str(record.seq)
    break  # Assumes only one sequence in FASTA

# --- Load and filter the DMS dataset ---
file_path = r'/Users/lucaschlotheuber/Desktop/ETH/summary_DMS_cleaned.xlsx'

df = pd.read_excel(file_path, usecols=[
    "DMS_RBD_AS_position", "Spike_AS_Position", "Count_of_Base", "Amino_Acid",
    "Type_of_Mutation", "Enrichment_Ratio", "barcode", "immunization",
    "condition", "Total_Reads", "Codon_Change", "Nucleotide_Ref"
])

# Adjust Spike position to match wuhan_sequence indexing
df["Spike_AS_Position"] = df["Spike_AS_Position"] - 5

# Filter out low read counts and NaNs
# NOTE(review): the loading cell at the top of the notebook keeps Total_Reads > 1000; this
# cell keeps > 500 -- confirm which threshold is intended.
df = df.dropna(subset=["Enrichment_Ratio", "Amino_Acid"])
df = df[df["Total_Reads"] > 500]

# Filter to non-synonymous mutations and region of interest (Spike position > 364).
# .copy(): make the result an owned frame so the column assignments below do not raise
# SettingWithCopyWarning.
df = df[(df['Spike_AS_Position'] > 364) & (df['Type_of_Mutation'] == 'NON-SYNOM')].copy()

# Log2 transform enrichment
# NOTE(review): Enrichment_Ratio == 0 gives -inf here (the inversion further down special-cases
# zeros), which propagates into min/mean statistics -- consider excluding ER <= 0.
df['log2_Enrichment'] = np.log2(df['Enrichment_Ratio'])

# Function to compute reference amino acid
def get_reference_aa(codon_change, nucleotide_ref):
    """Translate the reference codon implied by ``codon_change``.

    Every uppercase character of ``codon_change`` is replaced by
    ``nucleotide_ref``; other characters are kept as-is. The rebuilt codon
    is translated with Biopython.

    Returns the translated amino acid string, or ``np.nan`` if anything
    goes wrong (non-iterable input, non-string parts, translation error).
    """
    try:
        pieces = []
        for ch in codon_change:
            pieces.append(nucleotide_ref if ch.isupper() else ch)
        return str(Seq.Seq("".join(pieces)).translate())
    except Exception:
        # Best-effort: unusable rows get NaN instead of aborting the whole apply()
        return np.nan

# Reference amino acid per row (NaN when codon information is missing)
df['Reference_Amino_Acid'] = df.apply(
    lambda row: get_reference_aa(row['Codon_Change'], row['Nucleotide_Ref'])
    if pd.notna(row['Codon_Change']) and pd.notna(row['Nucleotide_Ref']) else np.nan,
    axis=1
)

# Mutation label such as N501Y: reference AA + Spike position + mutant AA ('' if either AA is missing)
df['site_label'] = df.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Final cleaned dataset
df_combined_volcano = df.copy()

# Inverted enrichment ratio (1/x); zero ratios are left at 0 to avoid division by zero
df_combined_volcano['Enrichment_Ratio_inverted'] = df_combined_volcano['Enrichment_Ratio'].apply(
    lambda x: 1 / x if x != 0 else x
)

# Sites of interest to highlight (kept as strings for downstream consumers)
sites_to_show = list(map(
    str,
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499] +  # RBD-ACE2 interface
    list(range(394, 414)) +  # R21 peptide
    list(range(484, 505))    # R13 peptide
))
# Compare numerically: if Spike_AS_Position is a float column, astype(str) yields
# '455.0', which would never match the string '455'.
df_combined_volcano['show_site'] = df_combined_volcano['Spike_AS_Position'].isin(
    [int(site) for site in sites_to_show]
)

# Columns carried into the per-mutation volcano tables
original_columns = [
    'Spike_AS_Position', 'Amino_Acid', 'immunization', 'site_label',
    'Enrichment_Ratio', 'log2_Enrichment', 'barcode', 'condition',
    'Total_Reads', 'Codon_Change', 'Reference_Amino_Acid',
    'Enrichment_Ratio_inverted', 'show_site'
]



In [None]:
# Summary statistics of log2 enrichment for the two antibody groups
subset = df_combined_volcano.loc[
    df_combined_volcano['immunization'].isin(['Polyclonal_Ab', 'Neutralizing_Ab'])
]
stats = (
    subset
    .groupby('immunization')['log2_Enrichment']
    .agg(['min', 'max', 'median'])
)

print("📊 Log2 Enrichment stats by immunization:")
print(stats)


In [None]:
# Load the per-immunization CSVs produced by the earlier plotting code.
# Paths are relative to the notebook's working directory.
# (This snippet was previously pasted three times; loading once is sufficient.)
poly_df = pd.read_csv("immunization_csv_files/Polyclonal_Ab_data.csv")
neut_df = pd.read_csv("immunization_csv_files/Neutralizing_Ab_data.csv")

# Compare the medians of the smoothed enrichment between the two groups
print("Polyclonal (Smoothed Enrichment):", poly_df["Smoothed_Enrichment"].median())
print("Neutralizing (Smoothed Enrichment):", neut_df["Smoothed_Enrichment"].median())


In [None]:
# Restrict to the two antibody groups compared in this figure
df_ab = df_combined_volcano[df_combined_volcano['immunization'].isin(['Polyclonal_Ab', 'Neutralizing_Ab'])]

plt.figure(figsize=(18, 6))
# Split violins: one half per immunization for every barcode
sns.violinplot(
    data=df_ab,
    x='barcode', y='Enrichment_Ratio', hue='immunization',
    split=True, inner=None, cut=0
)
# Individual rows overlaid as black points, dodged to their violin half
sns.stripplot(
    data=df_ab,
    x='barcode', y='Enrichment_Ratio', hue='immunization',
    dodge=True, color='black', alpha=0.25, jitter=True
)
plt.title("Enrichment Ratio by Barcode and Immunization")
plt.xticks(rotation=30)
plt.yscale('log')  # Again, helps with wide enrichment ranges

# Violin and strip plots each add one legend entry per hue level; keep only the
# violin entries (the first n handles) so every group is listed once.
handles, labels = plt.gca().get_legend_handles_labels()
n = df_ab['immunization'].nunique()
plt.legend(handles[:n], labels[:n], bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.savefig("Enrichment_Ratio_by_Barcode_and_Immunization.png", dpi=300)

plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Restrict to the two antibody groups compared in this figure
df_ab = df_combined_volcano[df_combined_volcano['immunization'].isin(['Polyclonal_Ab', 'Neutralizing_Ab'])]

plt.figure(figsize=(14, 6))

# Split violins: one half per immunization for every barcode
sns.violinplot(
    data=df_ab,
    x='barcode', y='Enrichment_Ratio', hue='immunization',
    split=True, inner=None, cut=0
)

# Individual rows overlaid as black points (not dodged)
sns.stripplot(
    data=df_ab,
    x='barcode', y='Enrichment_Ratio', hue='immunization',
    dodge=False, color='black', alpha=0.3, jitter=True
)

# Reference line at Enrichment Ratio = 1
plt.axhline(y=1, color='red', linestyle='--', linewidth=1.5)

plt.title("Enrichment Ratio by Barcode and Immunization")
plt.yscale('log')
plt.xticks(rotation=45)

# Both plots add one legend entry per hue level; keep only the violin entries.
# n must count the groups actually plotted: counting every immunization in
# df_combined_volcano (more than the two shown) would keep the duplicates.
handles, labels = plt.gca().get_legend_handles_labels()
n = df_ab['immunization'].nunique()
plt.legend(handles[:n], labels[:n], bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()
plt.savefig("enrichment_ratio_by_barcode.png", dpi=300)
plt.show()


In [None]:
# Restrict to the two antibody groups compared in this figure
df_ab = df_combined_volcano[df_combined_volcano['immunization'].isin(['Polyclonal_Ab', 'Neutralizing_Ab'])]

plt.figure(figsize=(22, 6))

# Violin plot (with hue split for immunization)
sns.violinplot(
    data=df_ab,
    x='barcode', y='Enrichment_Ratio', hue='immunization',
    split=True, inner=None, cut=0
)

# Strip plot (points overlaid, not dodged)
sns.stripplot(
    data=df_ab,
    x='barcode', y='Enrichment_Ratio', hue='immunization',
    dodge=False, color='black', alpha=0.3, jitter=True
)

plt.title("Enrichment Ratio by Barcode and Immunization")
plt.yscale('log')
plt.xticks(rotation=45)

# Remove duplicate legend entries: keep one (violin) handle per plotted group.
# Counting all immunizations in df_combined_volcano would keep the strip duplicates.
handles, labels = plt.gca().get_legend_handles_labels()
n = df_ab['immunization'].nunique()
plt.legend(handles[:n], labels[:n], bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Restrict to the two antibody groups compared in this figure
df_ab = df_combined_volcano[df_combined_volcano['immunization'].isin(['Polyclonal_Ab', 'Neutralizing_Ab'])]

plt.figure(figsize=(14, 6))

# Split violins: one half per immunization for every barcode
sns.violinplot(
    data=df_ab,
    x='barcode', y='Enrichment_Ratio', hue='immunization',
    split=True, inner=None, cut=0
)

# Individual rows overlaid as black points (not dodged)
sns.stripplot(
    data=df_ab,
    x='barcode', y='Enrichment_Ratio', hue='immunization',
    dodge=False, color='black', alpha=0.3, jitter=True
)

# Vertical marker at x=1, i.e. the second barcode category on the x axis
plt.axvline(x=1, color='red', linestyle='--', linewidth=1.5)

plt.title("Enrichment Ratio by Barcode and Immunization")
plt.yscale('log')
plt.xticks(rotation=45)

# Keep one (violin) legend handle per plotted group; counting every immunization
# in df_combined_volcano would also keep the stripplot duplicates.
handles, labels = plt.gca().get_legend_handles_labels()
n = df_ab['immunization'].nunique()
plt.legend(handles[:n], labels[:n], bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()

# NOTE(review): same filename as the axhline variant of this figure; whichever
# cell runs last overwrites the other's output.
plt.savefig("enrichment_ratio_by_barcode.png", dpi=300)
plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Distribution of per-row log2 enrichment for the two antibody groups
sns.violinplot(
    data=df_combined_volcano.loc[
        df_combined_volcano['immunization'].isin(['Polyclonal_Ab', 'Neutralizing_Ab'])
    ],
    x='immunization',
    y='log2_Enrichment',
)
plt.title("Log2 Enrichment Distribution by Immunization")
plt.show()


In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Columns available in the smoothed per-immunization CSVs (poly_df / neut_df)
print("Available columns:", poly_df.columns)

# Medians on the raw (linear) scale
print("\n--- Smoothed_Enrichment ---")
print("Polyclonal median:", poly_df["Smoothed_Enrichment"].median())
print("Neutralizing median:", neut_df["Smoothed_Enrichment"].median())

# log2 scale; the 1e-6 pseudo-count keeps zero enrichment finite
poly_df['log2_Enrichment'] = np.log2(poly_df["Smoothed_Enrichment"] + 1e-6)
neut_df['log2_Enrichment'] = np.log2(neut_df["Smoothed_Enrichment"] + 1e-6)

print("\n--- log2_Enrichment ---")
print("Polyclonal median (log2):", poly_df['log2_Enrichment'].median())
print("Neutralizing median (log2):", neut_df['log2_Enrichment'].median())

# Quantiles show how skewed each distribution is
print("\nQuantiles (Polyclonal - log2):\n", poly_df['log2_Enrichment'].quantile([0.1, 0.25, 0.5, 0.75, 0.9]))
print("\nQuantiles (Neutralizing - log2):\n", neut_df['log2_Enrichment'].quantile([0.1, 0.25, 0.5, 0.75, 0.9]))

# Violin plot of the smoothed CSV data (not the per-row df_combined_volcano data).
# Both CSVs are read with a default 0-based index, so ignore_index=True is needed
# to give the combined frame unique index labels.
df_combined = pd.concat([
    poly_df.assign(immunization='Polyclonal_Ab'),
    neut_df.assign(immunization='Neutralizing_Ab')
], ignore_index=True)

sns.violinplot(data=df_combined, x='immunization', y='log2_Enrichment')
plt.title("Log2 Enrichment Distribution by Immunization")
plt.show()


In [None]:
# Keep only the two antibody groups being compared
df_filtered = df_combined_volcano[
    df_combined_volcano['immunization'].isin(['Polyclonal_Ab', 'Neutralizing_Ab'])
]

# Median log2 enrichment per group
polyclonal_median = df_filtered.loc[
    df_filtered['immunization'] == 'Polyclonal_Ab', 'log2_Enrichment'
].median()
neutralizing_median = df_filtered.loc[
    df_filtered['immunization'] == 'Neutralizing_Ab', 'log2_Enrichment'
].median()

print("--- log2_Enrichment ---")
print(f"Polyclonal median (log2): {polyclonal_median}")
print(f"Neutralizing median (log2): {neutralizing_median}")

# 10th / 25th / 50th / 75th / 90th percentiles per group
polyclonal_quantiles = df_filtered.loc[
    df_filtered['immunization'] == 'Polyclonal_Ab', 'log2_Enrichment'
].quantile([0.10, 0.25, 0.50, 0.75, 0.90])
neutralizing_quantiles = df_filtered.loc[
    df_filtered['immunization'] == 'Neutralizing_Ab', 'log2_Enrichment'
].quantile([0.10, 0.25, 0.50, 0.75, 0.90])

print("\nQuantiles (Polyclonal - log2):")
print(polyclonal_quantiles)

print("\nQuantiles (Neutralizing - log2):")
print(neutralizing_quantiles)


In [None]:
# Quick sanity check of the cleaned per-row dataset
print("✅ Final columns in df_combined_volcano:")
print(list(df_combined_volcano.columns))

print("\n🔍 Preview of df_combined_volcano:")
print(df_combined_volcano.head(5))

In [None]:
# Rebuild every label from its parts (reference AA, Spike position, mutant AA)
# and compare it with the stored site_label.
df_combined_volcano['generated_label'] = df_combined_volcano.apply(
    lambda r: "{}{}{}".format(r['Reference_Amino_Acid'], r['Spike_AS_Position'], r['Amino_Acid']),
    axis=1
)

df_combined_volcano['label_correct'] = df_combined_volcano['generated_label'].eq(
    df_combined_volcano['site_label']
)

# Rows whose stored label differs from the rebuilt one
# (rows with a missing reference AA have site_label '' and always show up here)
incorrect_labels = df_combined_volcano.loc[~df_combined_volcano['label_correct']]
print(incorrect_labels[['Spike_AS_Position', 'Reference_Amino_Acid', 'Amino_Acid', 'site_label', 'generated_label']])


In [None]:
# Calculate the number of unique barcodes per unique combination of 'Spike_AS_Position' and 'Amino_Acid'
#FINALCODE

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd
from Bio import SeqIO

# Number of distinct barcodes (single antibodies) per mutation and immunization
df_counts = (
    df_combined_volcano
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .agg(Sample_Count=('barcode', 'nunique'))  # Count unique barcodes
    .reset_index()
)

# Median log2 enrichment for each (position, amino acid, immunization)
df_median_enrichment = (
    df_combined_volcano
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])['log2_Enrichment']
    .median()
    .reset_index(name='log2_Enrichment')
)

merge_keys = ['Spike_AS_Position', 'Amino_Acid', 'immunization']
df_volcano = pd.merge(df_counts, df_median_enrichment, on=merge_keys)

# Re-attach the per-row annotation columns. This merge is one-to-many: each
# mutation reappears once per matching df_combined_volcano row, all sharing the
# same aggregated Sample_Count and median log2_Enrichment.
df_volcano = pd.merge(df_volcano, df_combined_volcano[original_columns], on=merge_keys, how='left')

# Keep the aggregated (median) log2_Enrichment, drop the per-row copy
df_volcano['log2_Enrichment'] = df_volcano['log2_Enrichment_x']
df_volcano = df_volcano.drop(columns=['log2_Enrichment_x', 'log2_Enrichment_y'])

print("Columns in df_volcano after merging and cleaning:")
print(df_volcano.columns.tolist())

# Mutation label like E484K: reference AA + Spike position + mutant AA.
# (Amino_Acid + position alone, e.g. "K484", loses the reference residue and
# breaks the synonymous check below, which compares first and last characters.)
df_volcano['mutation'] = df_volcano['Spike_AS_Position'].astype(str)
df_volcano['mutation_label'] = df_volcano['mutation'].str.zfill(3)
df_volcano['site_label'] = df_volcano.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Flag synonymous changes (first char = reference AA, last char = mutant AA)
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]

# Non-synonymous, non-stop rows (used to pick the annotated mutations)
stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]

palette = {
    'Polyclonal_Ab': '#ff7f0e',
    'Neutralizing_Ab': '#d62728',
    'Mutant_RBD': '#1f77b4',
    'wildtype_RBD': '#009688',
}

# Plotted data: every row of the immunization groups in the palette
# NOTE(review): built from df_volcano, so stop codons are still plotted — confirm intended.
df_volcano_filtered = df_volcano[df_volcano['immunization'].isin(palette.keys())]

# Top-N most enriched / most escaping mutations to annotate, one row per
# mutation and immunization (df_filtered repeats each mutation per merged row).
top_n = 5
df_label_pool = (
    df_filtered[df_filtered['immunization'].isin(palette.keys())]
    .drop_duplicates(subset=merge_keys)
)
top_enriched = df_label_pool.nlargest(top_n, 'log2_Enrichment')
top_escape = df_label_pool.nsmallest(top_n, 'log2_Enrichment')

# --------------------------
# Significance thresholds: two-fold change in either direction
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# --------------------------
# Split into non-significant (inside the thresholds) and significant rows
df_non_sig = df_volcano_filtered[(df_volcano_filtered['log2_Enrichment'] >= low_thresh) &
                                 (df_volcano_filtered['log2_Enrichment'] <= high_thresh)]

df_sig = df_volcano_filtered[(df_volcano_filtered['log2_Enrichment'] < low_thresh) |
                             (df_volcano_filtered['log2_Enrichment'] > high_thresh)]

# --------------------------
# Plot
fig, ax = plt.subplots(figsize=(8, 6))

# Non-significant points in light gray
sns.scatterplot(
    data=df_non_sig,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    s=18,
    alpha=0.4,
    ax=ax
)

# Significant points colored by immunization
sns.scatterplot(
    data=df_sig,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    palette=palette,
    s=20,
    alpha=0.9,
    zorder=2,
    linewidth=0.4,
    ax=ax
)

# Threshold lines and the zero line
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(0, linestyle='--', color='gray', lw=1)

plt.text(low_thresh, 0.2, 'log2(0.5)', ha='center', va='bottom', fontsize=9)
plt.text(high_thresh, 0.2, 'log2(2)', ha='center', va='bottom', fontsize=9)

plt.title("Volcano Plot: Escape vs Binding Mutations")
plt.xlabel("← Escape   |   Binding →")
plt.ylabel("Single-Antibody Repertoire")

# Map legend entries by immunization key instead of by position: the order of
# hue handles depends on data order, so a fixed positional list can mislabel groups.
legend_names = {
    'Mutant_RBD': 'B cells - B.1.135 RBD vaccine',  # NOTE(review): likely B.1.351 (Beta) — confirm
    'Polyclonal_Ab': 'Polyreactive pAB',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
    'Neutralizing_Ab': 'mAB (NEUT)',
}
handles, labels = ax.get_legend_handles_labels()
new_labels = [legend_names.get(lbl, lbl) for lbl in labels]
plt.legend(handles, new_labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Enable minor ticks
plt.minorticks_on()
ax.tick_params(axis='both', which='minor', length=4, width=1, color='gray')

# Annotate the top-N enriched and escape mutations with a dashed leader line
texts = []
for _, row in pd.concat([top_enriched, top_escape]).iterrows():
    offset_x = 1 if row['log2_Enrichment'] > 0 else -1
    offset_y = 2
    texts.append(
        plt.text(
            row['log2_Enrichment'] + offset_x,
            row['Sample_Count'] + offset_y,
            row['site_label'],
            fontsize=10,
            weight='bold',
            color='black'
        )
    )
    plt.plot(
        [row['log2_Enrichment'], row['log2_Enrichment'] + offset_x],
        [row['Sample_Count'], row['Sample_Count'] + offset_y],
        linestyle='--',
        color='gray',
        linewidth=0.8
    )

adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray'))

plt.tight_layout()
plt.show()


In [None]:
# Calculate the number of unique barcodes per unique combination of 'Spike_AS_Position' and 'Amino_Acid'
#FINALCODE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd
from Bio import SeqIO

# Number of distinct barcodes (single antibodies) per mutation and immunization
df_counts = (
    df_combined_volcano
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .agg(Sample_Count=('barcode', 'nunique'))  # Count unique barcodes
    .reset_index()
)

# Median log2 enrichment for each (position, amino acid, immunization)
df_median_enrichment = (
    df_combined_volcano
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])['log2_Enrichment']
    .median()
    .reset_index(name='log2_Enrichment')
)

merge_keys = ['Spike_AS_Position', 'Amino_Acid', 'immunization']
df_volcano = pd.merge(df_counts, df_median_enrichment, on=merge_keys)

# Re-attach the per-row annotation columns (one-to-many merge: each mutation
# repeats once per matching df_combined_volcano row with identical aggregates)
df_volcano = pd.merge(df_volcano, df_combined_volcano[original_columns], on=merge_keys, how='left')

# Keep the aggregated (median) log2_Enrichment, drop the per-row copy
df_volcano['log2_Enrichment'] = df_volcano['log2_Enrichment_x']
df_volcano = df_volcano.drop(columns=['log2_Enrichment_x', 'log2_Enrichment_y'])

print("Columns in df_volcano after merging and cleaning:")
print(df_volcano.columns.tolist())

# Mutation label like E484K: reference AA + Spike position + mutant AA.
# (Amino_Acid + position alone, e.g. "K484", loses the reference residue and
# breaks the synonymous check below, which compares first and last characters.)
df_volcano['mutation'] = df_volcano['Spike_AS_Position'].astype(str)
df_volcano['mutation_label'] = df_volcano['mutation'].str.zfill(3)
df_volcano['site_label'] = df_volcano.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Flag synonymous changes (first char = reference AA, last char = mutant AA)
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]

# Non-synonymous, non-stop rows
# NOTE(review): df_filtered is not used for the plot below, which draws from df_volcano — confirm intended.
stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]

top_n = 5  # NOTE(review): not used in this cell

palette = {
    'Polyclonal_Ab': '#ff7f0e',
    'Neutralizing_Ab': '#d62728',
    'Mutant_RBD': '#1f77b4',
    'wildtype_RBD': '#009688',
}

# Plotted data: every row of the immunization groups in the palette
df_volcano_filtered = df_volcano[df_volcano['immunization'].isin(palette.keys())]

# --------------------------
# Significance thresholds: two-fold change in either direction
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# --------------------------
# Split into non-significant (inside the thresholds) and significant rows
df_non_sig = df_volcano_filtered[(df_volcano_filtered['log2_Enrichment'] >= low_thresh) &
                                 (df_volcano_filtered['log2_Enrichment'] <= high_thresh)]

df_sig = df_volcano_filtered[(df_volcano_filtered['log2_Enrichment'] < low_thresh) |
                             (df_volcano_filtered['log2_Enrichment'] > high_thresh)]

# --------------------------
# Plot
fig, ax = plt.subplots(figsize=(12, 6))

# Non-significant points in light gray
sns.scatterplot(
    data=df_non_sig,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    s=18,
    alpha=0.4,
    ax=ax
)

# Significant points colored by immunization
sns.scatterplot(
    data=df_sig,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    palette=palette,
    s=20,
    alpha=0.9,
    zorder=2,
    linewidth=0.4,
    ax=ax
)

# Threshold lines and the zero line
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(0, linestyle='--', color='gray', lw=1)

plt.text(low_thresh, 0.2, 'log2(0.5)', ha='center', va='bottom', fontsize=9)
plt.text(high_thresh, 0.2, 'log2(2)', ha='center', va='bottom', fontsize=9)

plt.title("Volcano Plot: Escape vs Binding Mutations")
plt.xlabel("← Escape   |   Binding →")
plt.ylabel("Single-Antibody Repertoire")

# Map legend entries by immunization key instead of by position: hue handle
# order follows data order, so a fixed positional list can mislabel groups.
legend_names = {
    'Neutralizing_Ab': 'Anti-RBD mAB (NEUT)',
    'Polyclonal_Ab': 'Polyreactive pAB',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
    'Mutant_RBD': 'B cells - B.1.135 RBD vaccine',  # NOTE(review): likely B.1.351 (Beta) — confirm
}
handles, labels = ax.get_legend_handles_labels()
new_labels = [legend_names.get(lbl, lbl) for lbl in labels]
plt.legend(handles, new_labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Enable minor ticks
plt.minorticks_on()
ax.tick_params(axis='both', which='minor', length=4, width=1, color='gray')

from collections import defaultdict

residues_to_label = [417, 439, 440, 452, 476, 477, 484, 493,494,495,496,497,498,499,500, 501, 502, 505]

# Step 1: strong effects (|log2| > 2) seen in exactly 3 barcodes at the selected residues
df_labels = df_sig[
    ((df_sig['log2_Enrichment'] > 2) | (df_sig['log2_Enrichment'] < -2)) &
    (df_sig['Sample_Count'] == 3) &
    (df_sig['Spike_AS_Position'].isin(residues_to_label))
].copy()

# Step 2: keep each site_label once per immunization group
seen = defaultdict(set)
filtered_rows = []
for _, row in df_labels.iterrows():
    label = row['site_label']
    group = row['immunization']
    if label not in seen[group]:
        seen[group].add(label)
        filtered_rows.append(row)

# Step 3: annotate with a dashed leader line from each point to its label
texts = []
for row in filtered_rows:
    offset_x = 1 if row['log2_Enrichment'] > 0 else -1
    offset_y = 2
    texts.append(
        plt.text(
            row['log2_Enrichment'] + offset_x,
            row['Sample_Count'] + offset_y,
            row['site_label'],
            fontsize=10,
            weight='bold',
            color='black'
        )
    )
    plt.plot(
        [row['log2_Enrichment'], row['log2_Enrichment'] + offset_x],
        [row['Sample_Count'], row['Sample_Count'] + offset_y],
        linestyle='--',
        color='gray',
        linewidth=0.8
    )

adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray'))

plt.tight_layout()
plt.show()


In [None]:
# Keep a single row per (position, amino acid, barcode, immunization) before any
# aggregation; when duplicates exist the first occurrence is retained.
df_combined_volcano = df_combined_volcano.drop_duplicates(
    subset=['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    keep='first',
)


In [None]:
# Calculate the number of unique barcodes per unique combination of 'Spike_AS_Position' and 'Amino_Acid'
#FINALCODE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd
from Bio import SeqIO

import matplotlib
import matplotlib.pyplot as plt
# No matplotlib.use('Agg') here: switching to the non-interactive Agg backend
# suppresses inline display in Jupyter. It was only needed because plt.show()
# ran before plt.savefig(); the figure is now saved first and then shown.


# Number of distinct barcodes (single antibodies) per mutation and immunization
df_counts = (
    df_combined_volcano
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .agg(Sample_Count=('barcode', 'nunique'))  # Count unique barcodes
    .reset_index()
)

# Median log2 enrichment for each (position, amino acid, immunization)
df_median_enrichment = (
    df_combined_volcano
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])['log2_Enrichment']
    .median()
    .reset_index(name='log2_Enrichment')
)

merge_keys = ['Spike_AS_Position', 'Amino_Acid', 'immunization']
df_volcano = pd.merge(df_counts, df_median_enrichment, on=merge_keys)

# Re-attach the per-row annotation columns (one-to-many merge: each mutation
# repeats once per matching df_combined_volcano row with identical aggregates)
df_volcano = pd.merge(df_volcano, df_combined_volcano[original_columns], on=merge_keys, how='left')

# Keep the aggregated (median) log2_Enrichment, drop the per-row copy
df_volcano['log2_Enrichment'] = df_volcano['log2_Enrichment_x']
df_volcano = df_volcano.drop(columns=['log2_Enrichment_x', 'log2_Enrichment_y'])

print("Columns in df_volcano after merging and cleaning:")
print(df_volcano.columns.tolist())

# Mutation label like E484K: reference AA + Spike position + mutant AA
df_volcano['site_label'] = df_volcano.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Flag synonymous changes (first char = reference AA, last char = mutant AA)
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]

# Non-synonymous, non-stop rows
# NOTE(review): df_filtered is not used for the plot below, which draws from df_volcano — confirm intended.
stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]

top_n = 5  # NOTE(review): not used in this cell; the extreme-value labels below use 15

palette = {
    'Polyclonal_Ab': '#ff7f0e',
    'Neutralizing_Ab': '#d62728',
    'Mutant_RBD': '#1f77b4',
    'wildtype_RBD': '#009688',
}

# Plotted data: every row of the immunization groups in the palette
df_volcano_filtered = df_volcano[df_volcano['immunization'].isin(palette.keys())]

# --------------------------
# Significance thresholds: two-fold change in either direction
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# --------------------------
# Split into non-significant (inside the thresholds) and significant rows
df_non_sig = df_volcano_filtered[(df_volcano_filtered['log2_Enrichment'] >= low_thresh) &
                                 (df_volcano_filtered['log2_Enrichment'] <= high_thresh)]

df_sig = df_volcano_filtered[(df_volcano_filtered['log2_Enrichment'] < low_thresh) |
                             (df_volcano_filtered['log2_Enrichment'] > high_thresh)]

# --------------------------
# Plot
fig, ax = plt.subplots(figsize=(10, 8))

# Small vertical jitter separates points sharing an integer Sample_Count.
# A seeded generator keeps the figure (and the saved file) reproducible.
y_jitter_std = 0.05
jitter_rng = np.random.default_rng(42)

# .assign returns new frames, avoiding SettingWithCopyWarning on the filtered slices
df_non_sig = df_non_sig.assign(
    Sample_Count_jitter=df_non_sig['Sample_Count'] + jitter_rng.normal(0, y_jitter_std, size=len(df_non_sig))
)
df_sig = df_sig.assign(
    Sample_Count_jitter=df_sig['Sample_Count'] + jitter_rng.normal(0, y_jitter_std, size=len(df_sig))
)

# Non-significant points (jittered) in light gray
sns.scatterplot(
    data=df_non_sig,
    x='log2_Enrichment',
    y='Sample_Count_jitter',
    color='lightgray',
    s=15,
    alpha=0.4,
    ax=ax
)

# Significant points (jittered) colored by immunization
sns.scatterplot(
    data=df_sig,
    x='log2_Enrichment',
    y='Sample_Count_jitter',
    hue='immunization',
    palette=palette,
    s=15,
    alpha=0.9,
    zorder=2,
    linewidth=0,
    ax=ax
)

# Threshold lines and the zero line
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(0, linestyle='--', color='gray', lw=1)

plt.text(low_thresh, 0.2, 'log2(0.5)', ha='center', va='bottom', fontsize=11)
plt.text(high_thresh, 0.2, 'log2(2)', ha='center', va='bottom', fontsize=11)

plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=18)
plt.ylabel("Single-Antibody (Droplet Barcode)", fontsize=18)

# Map legend entries by immunization key instead of by position: hue handle
# order follows data order, so a fixed positional list can mislabel groups.
legend_names = {
    'Neutralizing_Ab': 'Anti-RBD mAB (NEUT)',
    'Polyclonal_Ab': 'Polyreactive pAB',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
    'Mutant_RBD': 'B cells - B.1.135 RBD vaccine',  # NOTE(review): likely B.1.351 (Beta) — confirm
}
handles, labels = ax.get_legend_handles_labels()
legend_labels = [legend_names.get(lbl, lbl) for lbl in labels]
plt.legend(handles, legend_labels, title="Immunization", bbox_to_anchor=(0.75, 1), loc='upper left')

# Enable minor ticks
plt.minorticks_on()
ax.tick_params(axis='both', which='minor', length=4, width=1, color='gray')

from collections import defaultdict

residues_to_label = [417, 439, 440, 452, 477, 484, 495, 501, 502, 505]

# Label candidates, in priority order:
# 1. strong effects (|log2| > 3) seen in more than 2 barcodes at the selected residues
df_labels = df_sig[
    ((df_sig['log2_Enrichment'] > 3) | (df_sig['log2_Enrichment'] < -3)) &
    (df_sig['Sample_Count'] > 2) &
    (df_sig['Spike_AS_Position'].isin(residues_to_label))
].copy()

# 2. the 15 most enriched rows, the 15 most escaping rows and the row with the
#    highest Sample_Count (names top_5 / bottom_5 kept from earlier versions)
top_5 = df_volcano_filtered.nlargest(15, 'log2_Enrichment')
bottom_5 = df_volcano_filtered.nsmallest(15, 'log2_Enrichment')
max_sample_count_row = df_volcano_filtered.loc[df_volcano_filtered['Sample_Count'].idxmax()]
df_combined_labels = pd.concat([top_5, bottom_5, pd.DataFrame([max_sample_count_row])])

# 3. significant rows with Sample_Count > 12 and log2 enrichment > 3
extra_label_rows = df_sig[(df_sig['Sample_Count'] > 12) & (df_sig['log2_Enrichment'] > 3)].copy()

# Keep each site_label at most once per immunization group; skip stop codons ("*")
seen = defaultdict(set)
filtered_rows = []
for candidates in (df_labels, df_combined_labels, extra_label_rows):
    for _, row in candidates.iterrows():
        label = row['site_label']
        group = row['immunization']
        if "*" in label or label in seen[group]:
            continue
        seen[group].add(label)
        filtered_rows.append(row)

# Annotate the selected mutations
texts = []
for row in filtered_rows:
    offset_x = 0 if row['log2_Enrichment'] > 0 else -1
    offset_y = 0.2

    # Labels get their own random jitter, independent of the plotted point's jitter
    jittered_sample_count = row['Sample_Count'] + jitter_rng.normal(0, y_jitter_std)

    texts.append(
        plt.text(
            row['log2_Enrichment'] + offset_x,
            jittered_sample_count + offset_y,
            row['site_label'],
            fontsize=13,
            color='black'
        )
    )

# Spread labels to reduce overlap; no connector arrows or lines are drawn
adjust_text(
    texts,
    arrowprops=None,
    expand_text=(2, 2),
    expand_points=(2, 2),
    force_text=1.0,
    force_points=0.3,
    only_move={'points': 'y', 'texts': 'xy'},
    lim=100
)

plt.xlim(-7, 9)
plt.tight_layout()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Save BEFORE plt.show(): the inline backend closes the figure on show, which
# would leave savefig with an empty canvas. The extension now matches
# format='png' (it was '.jpeg' while the content was PNG).
fig.savefig('/Users/lucaschlotheuber/Desktop/Volcano2.png', format='png', dpi=300)
plt.show()


In [None]:
# Calculate the number of unique barcodes per unique combination of 'Spike_AS_Position' and 'Amino_Acid'
#FINALCODE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd
from Bio import SeqIO
from collections import defaultdict

import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend for saving to a file
import matplotlib.pyplot as plt

# Volcano plot of escape vs binding mutations, saved as Volcano3.png.
#   x: median log2 enrichment of a mutation within an immunization group
#   y: fraction of that group's barcodes in which the mutation was observed
# Uses df_combined_volcano and original_columns from earlier cells; original_columns
# must contain the three merge keys below and 'Reference_Amino_Acid'.
MUTATION_KEYS = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Number of distinct barcodes carrying each mutation within each immunization group
df_counts = (
    df_combined_volcano
    .groupby(MUTATION_KEYS)
    .agg(Sample_Count=('barcode', 'nunique'))
    .reset_index()
)

# Normalize Sample_Count to a fraction (0-1) of all barcodes in the same
# immunization group; a group missing from the totals is divided by 1.
barcode_totals = df_combined_volcano.groupby('immunization')['barcode'].nunique()
df_counts['Sample_Count'] = (
    df_counts['Sample_Count']
    / df_counts['immunization'].map(barcode_totals).fillna(1)
)

# Median (not mean) log2 enrichment per mutation and group
df_median_enrichment = (
    df_combined_volcano
    .groupby(MUTATION_KEYS)['log2_Enrichment']
    .median()
    .reset_index(name='log2_Enrichment')
)

df_volcano = pd.merge(df_counts, df_median_enrichment, on=MUTATION_KEYS)

# Attach the remaining annotation columns (e.g. Reference_Amino_Acid).
# df_combined_volcano can hold several rows (barcodes) per mutation/group -- that is
# what Sample_Count counts -- so it is collapsed to one row per key first; merging
# it as-is would repeat every volcano point, and its label candidates, once per
# barcode. Its per-barcode log2_Enrichment is dropped so the median above is plotted.
annotation_cols = (
    df_combined_volcano[original_columns]
    .drop(columns=['log2_Enrichment'], errors='ignore')
    .drop_duplicates(subset=MUTATION_KEYS)
)
df_volcano = pd.merge(df_volcano, annotation_cols, on=MUTATION_KEYS, how='left')

# Print out the columns to verify
print("Columns in df_volcano after merging and cleaning:")
print(df_volcano.columns.tolist())

# Mutation label such as E484K ('' when the reference residue is unknown)
df_volcano['site_label'] = df_volcano.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Exclude synonymous mutations (first letter of the label == last letter) and stop codons.
# .copy() so the new column is written to an independent frame, not a slice.
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]

stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]

palette = {
    'Polyclonal_Ab': '#ff7f0e',     # Adjust colors as needed
    'Neutralizing_Ab': '#d62728',    # Adjust colors as needed
    'Mutant_RBD': '#1f77b4',         # Adjust colors as needed
    'wildtype_RBD': '#009688',       # Adjust colors as needed
}

# Keep only the plotted immunization groups. Built from df_filtered so that the
# synonymous / stop-codon exclusion above actually reaches the plot.
df_volcano_filtered = df_filtered[df_filtered['immunization'].isin(palette.keys())]

# --------------------------
# Significance thresholds: 2-fold depletion (escape) / 2-fold enrichment (binding)
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# Split into non-significant (inside the thresholds, inclusive) and significant
# points. .copy() because jitter columns are added to both frames below.
df_non_sig = df_volcano_filtered[
    df_volcano_filtered['log2_Enrichment'].between(low_thresh, high_thresh)
].copy()
df_sig = df_volcano_filtered[
    (df_volcano_filtered['log2_Enrichment'] < low_thresh) |
    (df_volcano_filtered['log2_Enrichment'] > high_thresh)
].copy()

# --------------------------
# Plot
fig, ax = plt.subplots(figsize=(14, 8))

# Vertical jitter (std, in Sample_Count units) separating points with equal
# frequency; the seeded generator makes the saved figure reproducible.
y_jitter_std = 0.02
rng = np.random.default_rng(0)

df_non_sig['Sample_Count_jitter'] = df_non_sig['Sample_Count'] + rng.normal(0, y_jitter_std, size=len(df_non_sig))
df_sig['Sample_Count_jitter'] = df_sig['Sample_Count'] + rng.normal(0, y_jitter_std, size=len(df_sig))

# Non-significant points in grey
sns.scatterplot(
    data=df_non_sig,
    x='log2_Enrichment',
    y='Sample_Count_jitter',
    color='lightgray',
    s=15,
    alpha=0.4,
    ax=ax
)

# Significant points coloured by immunization group
sns.scatterplot(
    data=df_sig,
    x='log2_Enrichment',
    y='Sample_Count_jitter',
    hue='immunization',
    palette=palette,
    s=25,
    alpha=0.9,
    zorder=2,
    linewidth=0,
    ax=ax
)

# Threshold lines and x = 0
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(0, linestyle='--', color='gray', lw=1)

# Text labels for the thresholds
plt.text(low_thresh, -0.05, 'log2(0.5)', ha='center', va='bottom', fontsize=11)
plt.text(high_thresh, -0.05, 'log2(2)', ha='center', va='bottom', fontsize=11)

plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=18)
plt.ylabel("Frequency of Mutation in Single-Antibody Repertoire", fontsize=18)

# Legend: display names are looked up by immunization value, so every entry keeps
# its own colour whatever order seaborn emits the groups in (a positional list of
# names could be paired with the wrong handles).
# NOTE(review): mapping inferred from the group names -- confirm. The Beta variant
# lineage is B.1.351 (previously written 'B.1.135').
legend_names = {
    'Neutralizing_Ab': 'Anti-RBD mAB (NEUT)',
    'Polyclonal_Ab': 'Polyreactive pAB',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
    'Mutant_RBD': 'B cells - B.1.351 RBD vaccine',
}
handles, labels = ax.get_legend_handles_labels()
plt.legend(
    handles,
    [legend_names.get(label, label) for label in labels],
    title="Immunization",
    bbox_to_anchor=(0.8, 1),
    loc='upper left',
    fontsize=12,
    title_fontsize=14
)

# Enable minor ticks
plt.minorticks_on()
ax.tick_params(axis='both', which='minor', length=4, width=1, color='gray')

residues_to_label = [417, 439, 440, 452, 477, 484, 495, 501, 502, 505]

# Label candidates, step 1: significant mutations at the residues of interest.
# NOTE(review): Sample_Count is a fraction in [0, 1] after the normalization above,
# so `Sample_Count > 2` never holds and df_labels is always empty. The cutoff looks
# older than the normalization; switch to a fraction if these labels are wanted.
df_labels = df_sig[
    ((df_sig['log2_Enrichment'] > high_thresh) | (df_sig['log2_Enrichment'] < low_thresh)) &
    (df_sig['Sample_Count'] > 2) &
    (df_sig['Spike_AS_Position'].isin(residues_to_label))
].copy()

# Step 2: the 5 most enriched, the 25 most depleted (despite the variable name)
# and the 50 most frequent mutations, plus the single most frequent one.
top_5 = df_volcano_filtered.nlargest(5, 'log2_Enrichment')
bottom_5 = df_volcano_filtered.nsmallest(25, 'log2_Enrichment')
top_y = df_volcano_filtered.nlargest(50, 'Sample_Count')
max_sample_count_row = df_volcano_filtered.loc[df_volcano_filtered['Sample_Count'].idxmax()]

df_combined_labels = pd.concat([
    top_5,
    bottom_5,
    top_y,
    pd.DataFrame([max_sample_count_row])
])

# Step 3: frequent and strongly enriched mutations.
# NOTE(review): as in step 1, `Sample_Count > 10` cannot hold for a fraction,
# so this currently selects nothing.
new_label_condition = (df_sig['Sample_Count'] > 10) & (df_sig['log2_Enrichment'] > 3)
new_labels = df_sig[new_label_condition].copy()

# Step 4: keep the first candidate per (immunization group, site_label), in the
# order step 1 -> step 2 -> step 3, and never label stop codons ('*').
seen = defaultdict(set)
filtered_rows = []
for candidates in (df_labels, df_combined_labels, new_labels):
    for _, row in candidates.iterrows():
        label = row['site_label']
        group = row['immunization']
        if "*" in label or label in seen[group]:
            continue
        seen[group].add(label)
        filtered_rows.append(row)

# Step 5: place the labels
texts = []
leader_anchors = []  # (x, y) of the data point each label's connector starts from
for row in filtered_rows:
    # Labels of escape mutations (log2 < 0) start 0.6 left of the point, all
    # others 0.6 to the right, and slightly above it.
    offset_x = -0.6 if row['log2_Enrichment'] < 0 else 0.6
    offset_y = 0.01

    # The label's y gets its own random jitter, drawn independently of the
    # jitter applied to the plotted point.
    label_x = row['log2_Enrichment'] + offset_x
    label_y = row['Sample_Count'] + rng.normal(0, y_jitter_std) + offset_y

    texts.append(
        plt.text(label_x, label_y, row['site_label'], fontsize=12, color='black')
    )
    leader_anchors.append((row['log2_Enrichment'], row['Sample_Count']))

# Adjust text labels to prevent overlap
adjust_text(
    texts,
    arrowprops=None,  # No arrows; dashed connectors are drawn below
    expand_text=(2, 2),  # Increase repulsion in both x and y direction
    expand_points=(2, 2),  # Increase repulsion for points as well
    force_text=1.0,  # Stronger push away for text to prevent overlap
    force_points=0.3,  # Push points further away for better distribution
    only_move={'points': 'y', 'texts': 'xy'},  # Allow movement in both x and y for texts
    lim=100  # Limit the number of text label adjustments to prevent infinite loops
)

# Dashed connectors from each (un-jittered) data point to its label, drawn after
# adjust_text so that they end where the label finally sits.
for text, (point_x, point_y) in zip(texts, leader_anchors):
    label_x, label_y = text.get_position()
    plt.plot(
        [point_x, label_x],
        [point_y, label_y],
        linestyle='--',
        color='black',
        linewidth=0.8
    )

plt.xlim(-9, 9)
plt.tight_layout()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# The Agg backend cannot display figures, so the plot is only written to disk.
# The extension matches the PNG format actually written (was '.jpeg').
plt.draw()
plt.savefig('/Users/lucaschlotheuber/Desktop/Volcano3.png', format='png', dpi=300)


In [None]:
# Calculate the number of unique barcodes per unique combination of 'Spike_AS_Position' and 'Amino_Acid'
#FINALCODE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd
from Bio import SeqIO
from collections import defaultdict

import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend for saving to a file
import matplotlib.pyplot as plt

# Volcano plot of escape vs binding mutations, saved as Volcano9.png.
#   x: median log2 enrichment of a mutation within an immunization group
#   y: fraction of that group's barcodes in which the mutation was observed
# Uses df_combined_volcano and original_columns from earlier cells; original_columns
# must contain the three merge keys below and 'Reference_Amino_Acid'.
MUTATION_KEYS = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Number of distinct barcodes carrying each mutation within each immunization group
df_counts = (
    df_combined_volcano
    .groupby(MUTATION_KEYS)
    .agg(Sample_Count=('barcode', 'nunique'))
    .reset_index()
)

# Normalize Sample_Count to a fraction (0-1) of all barcodes in the same
# immunization group; a group missing from the totals is divided by 1.
barcode_totals = df_combined_volcano.groupby('immunization')['barcode'].nunique()
df_counts['Sample_Count'] = (
    df_counts['Sample_Count']
    / df_counts['immunization'].map(barcode_totals).fillna(1)
)

# Median (not mean) log2 enrichment per mutation and group
df_median_enrichment = (
    df_combined_volcano
    .groupby(MUTATION_KEYS)['log2_Enrichment']
    .median()
    .reset_index(name='log2_Enrichment')
)

df_volcano = pd.merge(df_counts, df_median_enrichment, on=MUTATION_KEYS)

# Attach the remaining annotation columns (e.g. Reference_Amino_Acid).
# df_combined_volcano can hold several rows (barcodes) per mutation/group -- that is
# what Sample_Count counts -- so it is collapsed to one row per key first; merging
# it as-is would repeat every volcano point, and its labels, once per barcode.
# Its per-barcode log2_Enrichment is dropped so the median above is plotted.
annotation_cols = (
    df_combined_volcano[original_columns]
    .drop(columns=['log2_Enrichment'], errors='ignore')
    .drop_duplicates(subset=MUTATION_KEYS)
)
df_volcano = pd.merge(df_volcano, annotation_cols, on=MUTATION_KEYS, how='left')

# Print out the columns to verify
print("Columns in df_volcano after merging and cleaning:")
print(df_volcano.columns.tolist())

# Mutation label such as E484K ('' when the reference residue is unknown)
df_volcano['site_label'] = df_volcano.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Exclude synonymous mutations (first letter of the label == last letter) and stop codons.
# .copy() so the new column is written to an independent frame, not a slice.
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]

stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]
print("Remaining rows after all filters:", len(df_filtered))

palette = {
    'Polyclonal_Ab': '#ff7f0e',     # Adjust colors as needed
    'Neutralizing_Ab': '#d62728',    # Adjust colors as needed
    'Mutant_RBD': '#009688',         # Adjust colors as needed
    'wildtype_RBD': '#1f77b4',       # Adjust colors as needed
}
print(df_volcano[['Reference_Amino_Acid', 'Spike_AS_Position', 'Amino_Acid', 'site_label']].drop_duplicates().head())

# Keep only the plotted immunization groups. Built from df_filtered so that the
# synonymous / stop-codon exclusion above actually reaches the plot.
df_volcano_filtered = df_filtered[df_filtered['immunization'].isin(palette.keys())]

# --------------------------
# Significance thresholds: 2-fold depletion (escape) / 2-fold enrichment (binding)
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# Split into non-significant (inside the thresholds, inclusive) and significant
# points. .copy() because jitter columns are added to both frames below.
df_non_sig = df_volcano_filtered[
    df_volcano_filtered['log2_Enrichment'].between(low_thresh, high_thresh)
].copy()
df_sig = df_volcano_filtered[
    (df_volcano_filtered['log2_Enrichment'] < low_thresh) |
    (df_volcano_filtered['log2_Enrichment'] > high_thresh)
].copy()

# --------------------------
# Plot
fig, ax = plt.subplots(figsize=(13, 10))

# Vertical jitter std (0 = no jitter); seeded so any non-zero value is reproducible
y_jitter_std = 0
rng = np.random.default_rng(0)

df_non_sig['Sample_Count_jitter'] = df_non_sig['Sample_Count'] + rng.normal(0, y_jitter_std, size=len(df_non_sig))
df_sig['Sample_Count_jitter'] = df_sig['Sample_Count'] + rng.normal(0, y_jitter_std, size=len(df_sig))

# Non-significant points in grey
sns.scatterplot(
    data=df_non_sig,
    x='log2_Enrichment',
    y='Sample_Count_jitter',
    color='lightgray',
    s=15,
    alpha=0.4,
    ax=ax
)

# Significant points coloured by immunization group
sns.scatterplot(
    data=df_sig,
    x='log2_Enrichment',
    y='Sample_Count_jitter',
    hue='immunization',
    palette=palette,
    s=25,
    alpha=0.9,
    zorder=2,
    linewidth=0,
    ax=ax
)

# Threshold lines and x = 0
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(0, linestyle='--', color='gray', lw=1)

# Text labels for the thresholds
plt.text(low_thresh, -0.05, 'log2(0.5)', ha='center', va='bottom', fontsize=13)
plt.text(high_thresh, -0.05, 'log2(2)', ha='center', va='bottom', fontsize=13)

plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
# Replace the default xlabel with one centred on x = 0 (data x, axes-fraction y)
plt.xlabel("")
ax.text(0, -0.05, "← Escape   |   Binding →", fontsize=21,
        ha='center', va='top', transform=ax.get_xaxis_transform())

plt.ylabel("Frequency of Mutation in Single-Antibody Repertoire", fontsize=21)

# Legend: display names are looked up by immunization value, so every entry keeps
# its own colour whatever order seaborn emits the groups in (a positional list of
# names could be paired with the wrong handles).
# NOTE(review): mapping inferred from the group names -- confirm. The Beta variant
# lineage is B.1.351 (previously written 'B.1.135').
legend_names = {
    'Neutralizing_Ab': 'Anti-RBD mAB (NEUT)',
    'Polyclonal_Ab': 'Polyreactive pAB',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
    'Mutant_RBD': 'B cells - B.1.351 RBD vaccine',
}
handles, labels = ax.get_legend_handles_labels()
plt.legend(
    handles,
    [legend_names.get(label, label) for label in labels],
    title="Immunization",
    bbox_to_anchor=(0.8, 0.9),
    loc='upper left',
    fontsize=12,
    title_fontsize=14,
    markerscale=1.8
)

# Enable minor ticks and enlarge tick labels
plt.minorticks_on()
ax.tick_params(axis='both', which='major', labelsize=15)
ax.tick_params(axis='both', which='minor', labelsize=15)

# Residues to label
residues_to_label = [417, 439, 440, 452, 477, 484, 495, 501, 502, 505]

# Label candidates, step 1: significant mutations at the residues of interest.
# NOTE(review): Sample_Count is a fraction in [0, 1] after the normalization above,
# so `Sample_Count > 2` never holds and df_labels is always empty.
df_labels = df_sig[
    ((df_sig['log2_Enrichment'] > high_thresh) | (df_sig['log2_Enrichment'] < low_thresh)) &
    (df_sig['Sample_Count'] > 2) &
    (df_sig['Spike_AS_Position'].isin(residues_to_label))
].copy()

# Step 2: the 5 most enriched, 5 most depleted and 20 most frequent mutations,
# plus the single most frequent one.
top_5 = df_volcano_filtered.nlargest(5, 'log2_Enrichment')
bottom_5 = df_volcano_filtered.nsmallest(5, 'log2_Enrichment')
top_y = df_volcano_filtered.nlargest(20, 'Sample_Count')
max_sample_count_row = df_volcano_filtered.loc[df_volcano_filtered['Sample_Count'].idxmax()]

df_combined_labels = pd.concat([
    top_5,
    bottom_5,
    top_y,
    pd.DataFrame([max_sample_count_row])
])

# Step 3: hand-tuned regions of the significant points (Sample_Count is a fraction)
new_labels = (
    ((df_sig['Sample_Count'] > 0.6) & (df_sig['log2_Enrichment'] > 2.5)) |
    ((df_sig['Sample_Count'] > 0.2) & (df_sig['log2_Enrichment'] > 5.5)) |
    ((df_sig['Sample_Count'] > 0.2) & (df_sig['log2_Enrichment'] < -4.7)) |
    ((df_sig['Sample_Count'] < 0.16) & (df_sig['log2_Enrichment'] > 2))  # rare but enriched
)

# Step 4: keep the first candidate per (immunization group, site_label), in the
# order step 1 -> step 2 -> step 3, and never label stop codons ('*').
seen = defaultdict(set)
filtered_rows = []
for candidates in (df_labels, df_combined_labels, df_sig[new_labels]):
    for _, row in candidates.iterrows():
        label = row['site_label']
        group = row['immunization']
        if "*" in label or label in seen[group]:
            continue
        seen[group].add(label)
        filtered_rows.append(row)

# Step 5: draw the labels. All of them go into `texts` for adjust_text; only the
# offset labels from filtered_rows get a connector, so those are tracked in a
# separate pair of parallel lists (keeping one list per purpose avoids indexing
# the anchor list with positions that belong to other labels).
texts = []
leader_texts = []
leader_anchors = []  # (x, y) of the data point each leader_texts entry belongs to

# Small label centred on every significant point at a residue of interest.
# NOTE(review): such a mutation can also appear in filtered_rows and is then labelled twice.
for _, row in df_sig.iterrows():
    if row['Spike_AS_Position'] in residues_to_label:
        texts.append(
            ax.text(
                row['log2_Enrichment'],
                row['Sample_Count_jitter'],
                row['site_label'],
                fontsize=9,
                ha='center',
                va='bottom',
                zorder=5,
                clip_on=True
            )
        )

for row in filtered_rows:
    if -1 <= row['log2_Enrichment'] <= 1:
        continue  # no offset label near x = 0

    # Start the label one unit left of escape points (log2 < 0), one unit right
    # of the others, slightly above the point.
    offset_x = -1 if row['log2_Enrichment'] < 0 else 1
    offset_y = 0.01

    # The label's y gets its own random jitter (zero while y_jitter_std == 0)
    label_x = row['log2_Enrichment'] + offset_x
    label_y = row['Sample_Count'] + rng.normal(0, y_jitter_std) + offset_y

    text = plt.text(label_x, label_y, row['site_label'], fontsize=12, color='black')
    texts.append(text)
    leader_texts.append(text)
    leader_anchors.append((row['log2_Enrichment'], row['Sample_Count']))

# Adjust the text labels to avoid overlap
adjust_text(
    texts,
    arrowprops=None,  # No arrows, we will manually handle lines
    expand_text=(2, 2),  # Increase repulsion in both x and y direction
    expand_points=(2, 2),  # Increase repulsion for points as well
    force_text=1.0,  # Stronger push for text to avoid overlap
    force_points=0.3,  # Push points further for better distribution
    only_move={'points': 'y', 'texts': 'xy'},  # Allow movement in both x and y for texts
    lim=100  # Limit the number of text label adjustments to prevent infinite loops
)

# Connectors from each data point to its (adjusted) offset label
for text, (point_x, point_y) in zip(leader_texts, leader_anchors):
    label_x, label_y = text.get_position()
    ax.plot(
        [point_x, label_x],
        [point_y, label_y],
        color='gray',
        linestyle='--',
        linewidth=0.5,
        zorder=1
    )

# Adjust plot
plt.xlim(-7, 9)
plt.tight_layout()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
for spine in ax.spines.values():
    spine.set_linewidth(2)
print(df_labels[['Spike_AS_Position', 'Amino_Acid', 'site_label', 'log2_Enrichment', 'Sample_Count_jitter']])

plt.draw()
# The extension matches the PNG format actually written (was '.jpeg')
plt.savefig('/Users/lucaschlotheuber/Desktop/Volcano9.png', format='png', dpi=300)

In [None]:
# Calculate the number of unique barcodes per unique combination of 'Spike_AS_Position' and 'Amino_Acid'
#FINALCODE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd
from Bio import SeqIO
from collections import defaultdict

import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend for saving to a file
import matplotlib.pyplot as plt

# Volcano plot of escape vs binding mutations with residue 484 highlighted,
# saved as Volcano8.png.
#   x: median log2 enrichment of a mutation within an immunization group
#   y: fraction of that group's barcodes in which the mutation was observed
# Uses df_combined_volcano and original_columns from earlier cells; original_columns
# must contain the three merge keys below and 'Reference_Amino_Acid'.
MUTATION_KEYS = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Number of distinct barcodes carrying each mutation within each immunization group
df_counts = (
    df_combined_volcano
    .groupby(MUTATION_KEYS)
    .agg(Sample_Count=('barcode', 'nunique'))
    .reset_index()
)

# Normalize Sample_Count to a fraction (0-1) of all barcodes in the same
# immunization group; a group missing from the totals is divided by 1.
barcode_totals = df_combined_volcano.groupby('immunization')['barcode'].nunique()
df_counts['Sample_Count'] = (
    df_counts['Sample_Count']
    / df_counts['immunization'].map(barcode_totals).fillna(1)
)

# Median (not mean) log2 enrichment per mutation and group
df_median_enrichment = (
    df_combined_volcano
    .groupby(MUTATION_KEYS)['log2_Enrichment']
    .median()
    .reset_index(name='log2_Enrichment')
)

df_volcano = pd.merge(df_counts, df_median_enrichment, on=MUTATION_KEYS)

# Attach the remaining annotation columns (e.g. Reference_Amino_Acid).
# df_combined_volcano can hold several rows (barcodes) per mutation/group -- that is
# what Sample_Count counts -- so it is collapsed to one row per key first; merging
# it as-is would repeat every volcano point, and its labels, once per barcode.
# Its per-barcode log2_Enrichment is dropped so the median above is plotted.
annotation_cols = (
    df_combined_volcano[original_columns]
    .drop(columns=['log2_Enrichment'], errors='ignore')
    .drop_duplicates(subset=MUTATION_KEYS)
)
df_volcano = pd.merge(df_volcano, annotation_cols, on=MUTATION_KEYS, how='left')

# Print out the columns to verify
print("Columns in df_volcano after merging and cleaning:")
print(df_volcano.columns.tolist())

# Mutation label such as E484K ('' when the reference residue is unknown)
df_volcano['site_label'] = df_volcano.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Exclude synonymous mutations (first letter of the label == last letter) and stop codons.
# .copy() so the new column is written to an independent frame, not a slice.
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]

stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]
print("Remaining rows after all filters:", len(df_filtered))

palette = {
    'Polyclonal_Ab': '#ff7f0e',     # Adjust colors as needed
    'Neutralizing_Ab': '#d62728',    # Adjust colors as needed
    'Mutant_RBD': '#009688',         # Adjust colors as needed
    'wildtype_RBD': '#1f77b4',       # Adjust colors as needed
}
print(df_volcano[['Reference_Amino_Acid', 'Spike_AS_Position', 'Amino_Acid', 'site_label']].drop_duplicates().head())

# Keep only the plotted immunization groups. Built from df_filtered so that the
# synonymous / stop-codon exclusion above actually reaches the plot.
df_volcano_filtered = df_filtered[df_filtered['immunization'].isin(palette.keys())]

# --------------------------
# Significance thresholds: 2-fold depletion (escape) / 2-fold enrichment (binding)
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# Split into non-significant (inside the thresholds, inclusive) and significant
# points. .copy() because jitter columns are added to both frames below.
df_non_sig = df_volcano_filtered[
    df_volcano_filtered['log2_Enrichment'].between(low_thresh, high_thresh)
].copy()
df_sig = df_volcano_filtered[
    (df_volcano_filtered['log2_Enrichment'] < low_thresh) |
    (df_volcano_filtered['log2_Enrichment'] > high_thresh)
].copy()

# --------------------------
# Plot
fig, ax = plt.subplots(figsize=(13, 10))

# Vertical jitter std (0 = no jitter); seeded so any non-zero value is reproducible
y_jitter_std = 0
rng = np.random.default_rng(0)

df_non_sig['Sample_Count_jitter'] = df_non_sig['Sample_Count'] + rng.normal(0, y_jitter_std, size=len(df_non_sig))
df_sig['Sample_Count_jitter'] = df_sig['Sample_Count'] + rng.normal(0, y_jitter_std, size=len(df_sig))

# Non-significant points in grey
sns.scatterplot(
    data=df_non_sig,
    x='log2_Enrichment',
    y='Sample_Count_jitter',
    color='lightgray',
    s=15,
    alpha=0.4,
    ax=ax
)

# Significant points coloured by immunization group
sns.scatterplot(
    data=df_sig,
    x='log2_Enrichment',
    y='Sample_Count_jitter',
    hue='immunization',
    palette=palette,
    s=25,
    alpha=0.9,
    zorder=2,
    linewidth=0,
    ax=ax
)

# Threshold lines and x = 0
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(0, linestyle='--', color='gray', lw=1)

# Text labels for the thresholds
plt.text(low_thresh, -0.05, 'log2(0.5)', ha='center', va='bottom', fontsize=13)
plt.text(high_thresh, -0.05, 'log2(2)', ha='center', va='bottom', fontsize=13)

plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
# Replace the default xlabel with one centred on x = 0 (data x, axes-fraction y)
plt.xlabel("")
ax.text(0, -0.05, "← Escape   |   Binding →", fontsize=21,
        ha='center', va='top', transform=ax.get_xaxis_transform())

plt.ylabel("Frequency of Mutation in Single-Antibody Repertoire", fontsize=21)

# Legend: display names are looked up by immunization value, so every entry keeps
# its own colour whatever order seaborn emits the groups in (a positional list of
# names could be paired with the wrong handles).
# NOTE(review): mapping inferred from the group names -- confirm. The Beta variant
# lineage is B.1.351 (previously written 'B.1.135').
legend_names = {
    'Neutralizing_Ab': 'Anti-RBD mAB (NEUT)',
    'Polyclonal_Ab': 'Polyreactive pAB',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
    'Mutant_RBD': 'B cells - B.1.351 RBD vaccine',
}
handles, labels = ax.get_legend_handles_labels()
plt.legend(
    handles,
    [legend_names.get(label, label) for label in labels],
    title="Immunization",
    bbox_to_anchor=(0.8, 0.9),
    loc='upper left',
    fontsize=12,
    title_fontsize=14,
    markerscale=1.8
)

# Enable minor ticks and enlarge tick labels
plt.minorticks_on()
ax.tick_params(axis='both', which='major', labelsize=15)
ax.tick_params(axis='both', which='minor', labelsize=15)

# Residues to label
residues_to_label = [417, 439, 440, 452, 477, 484, 495, 501, 502, 505]

# Label candidates, step 1: significant mutations at the residues of interest.
# NOTE(review): Sample_Count is a fraction in [0, 1] after the normalization above,
# so `Sample_Count > 2` never holds and df_labels is always empty.
df_labels = df_sig[
    ((df_sig['log2_Enrichment'] > high_thresh) | (df_sig['log2_Enrichment'] < low_thresh)) &
    (df_sig['Sample_Count'] > 2) &
    (df_sig['Spike_AS_Position'].isin(residues_to_label))
].copy()

# Step 2: the 5 most enriched, 5 most depleted and 20 most frequent mutations,
# plus the single most frequent one.
top_5 = df_volcano_filtered.nlargest(5, 'log2_Enrichment')
bottom_5 = df_volcano_filtered.nsmallest(5, 'log2_Enrichment')
top_y = df_volcano_filtered.nlargest(20, 'Sample_Count')
max_sample_count_row = df_volcano_filtered.loc[df_volcano_filtered['Sample_Count'].idxmax()]

df_combined_labels = pd.concat([
    top_5,
    bottom_5,
    top_y,
    pd.DataFrame([max_sample_count_row])
])

# Step 3: hand-tuned regions of the significant points (Sample_Count is a fraction)
# plus every label mentioning 484.
# NOTE(review): the two `Sample_Count < 0.16` terms together select every such row
# except log2_Enrichment == 2 exactly -- confirm that is intended.
is_484 = df_sig['site_label'].str.contains('484', na=False)

new_labels = (
    ((df_sig['Sample_Count'] > 0.6) & (df_sig['log2_Enrichment'] > 2.5)) |
    ((df_sig['Sample_Count'] > 0.2) & (df_sig['log2_Enrichment'] > 5.5)) |
    ((df_sig['Sample_Count'] > 0.2) & (df_sig['log2_Enrichment'] < -4.7)) |
    ((df_sig['Sample_Count'] < 0.16) & (df_sig['log2_Enrichment'] > 2)) |
    ((df_sig['Sample_Count'] < 0.16) & (df_sig['log2_Enrichment'] < 2)) |
    is_484
)

# Step 4: keep the first candidate per (immunization group, site_label), in the
# order step 1 -> step 2 -> step 3, and never label stop codons ('*').
seen = defaultdict(set)
filtered_rows = []
for candidates in (df_labels, df_combined_labels, df_sig[new_labels]):
    for _, row in candidates.iterrows():
        label = row['site_label']
        group = row['immunization']
        if "*" in label or label in seen[group]:
            continue
        seen[group].add(label)
        filtered_rows.append(row)

# Significant mutations at position 484, labelled in bold at their point
highlight_df = df_sig[
    (df_sig['Sample_Count'] > 0) &
    (df_sig['log2_Enrichment'] > -10) &
    (df_sig['Spike_AS_Position'] == 484)
]

# Step 5: draw the labels. All of them go into `texts` for adjust_text; only the
# offset labels from filtered_rows get a connector, so those are tracked in a
# separate pair of parallel lists (keeping one list per purpose avoids indexing
# the anchor list with positions that belong to other labels).
# NOTE(review): a 484 mutation can receive up to three labels here (residue label,
# bold highlight, offset label via is_484).
texts = []
leader_texts = []
leader_anchors = []  # (x, y) of the data point each leader_texts entry belongs to

# Small label centred on every significant point at a residue of interest
for _, row in df_sig.iterrows():
    if row['Spike_AS_Position'] in residues_to_label:
        texts.append(
            ax.text(
                row['log2_Enrichment'],
                row['Sample_Count_jitter'],
                row['site_label'],
                fontsize=9,
                ha='center',
                va='bottom',
                zorder=5,
                clip_on=True
            )
        )

# Bold label centred on each highlighted 484 point
for _, row in highlight_df.iterrows():
    texts.append(
        plt.text(
            row['log2_Enrichment'],
            row['Sample_Count_jitter'],
            row['site_label'],
            fontsize=10,
            color='black',
            weight='bold',
            ha='center',
            va='center'
        )
    )

for row in filtered_rows:
    if -1 <= row['log2_Enrichment'] <= 1:
        continue  # no offset label near x = 0

    # Start the label one unit left of escape points (log2 < 0), one unit right
    # of the others, slightly above the point.
    offset_x = -1 if row['log2_Enrichment'] < 0 else 1
    offset_y = 0.01

    # The label's y gets its own random jitter (zero while y_jitter_std == 0)
    label_x = row['log2_Enrichment'] + offset_x
    label_y = row['Sample_Count'] + rng.normal(0, y_jitter_std) + offset_y

    text = plt.text(label_x, label_y, row['site_label'], fontsize=12, color='black')
    texts.append(text)
    leader_texts.append(text)
    leader_anchors.append((row['log2_Enrichment'], row['Sample_Count']))

# Adjust the text labels to avoid overlap
adjust_text(
    texts,
    arrowprops=None,  # No arrows, we will manually handle lines
    expand_text=(2, 2),  # Increase repulsion in both x and y direction
    expand_points=(2, 2),  # Increase repulsion for points as well
    force_text=1.0,  # Stronger push for text to avoid overlap
    force_points=0.3,  # Push points further for better distribution
    only_move={'points': 'y', 'texts': 'xy'},  # Allow movement in both x and y for texts
    lim=100  # Limit the number of text label adjustments to prevent infinite loops
)

# Connectors from each data point to its (adjusted) offset label
for text, (point_x, point_y) in zip(leader_texts, leader_anchors):
    label_x, label_y = text.get_position()
    ax.plot(
        [point_x, label_x],
        [point_y, label_y],
        color='gray',
        linestyle='--',
        linewidth=0.5,
        zorder=1
    )

# Adjust plot
plt.xlim(-7, 9)
plt.tight_layout()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
for spine in ax.spines.values():
    spine.set_linewidth(2)
print(df_labels[['Spike_AS_Position', 'Amino_Acid', 'site_label', 'log2_Enrichment', 'Sample_Count_jitter']])

plt.draw()
# The extension matches the PNG format actually written (was '.jpeg')
plt.savefig('/Users/lucaschlotheuber/Desktop/Volcano8.png', format='png', dpi=300)

In [None]:
# Calculate the number of unique barcodes per unique combination of 'Spike_AS_Position' and 'Amino_Acid'
#FINALCODE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd
from Bio import SeqIO

import matplotlib
# Do NOT call matplotlib.use('Agg') here: switching the backend inside a running
# Jupyter kernel disables inline rendering for this and every later cell (a later
# plt.show() only warns that the canvas is non-interactive). plt.savefig() writes
# files fine with the default inline backend.

from collections import defaultdict

def stack_points(df, x_col, y_col, spacing=0.002):
    """
    Spread points that share identical x and y values along the y-axis.

    The first point at a given (x, y) coordinate stays in place. The k-th
    repeat (k = 1, 2, ...) is shifted up by ``k * spacing``. This is
    deterministic, unlike random jitter.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame; it is not modified.
    x_col, y_col : str
        Columns holding the x and y coordinates.
    spacing : float
        Vertical step added per repeated coordinate.

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` with ``y_col`` adjusted.
    """
    # Work on a copy: callers pass boolean-filtered slices (df_sig / df_non_sig),
    # and writing into those in place triggers SettingWithCopyWarning.
    df = df.copy()
    coords_counter = defaultdict(int)
    offsets = []
    # Group on the original coordinates; shifted values are not re-checked.
    for key in zip(df[x_col], df[y_col]):
        offsets.append(coords_counter[key] * spacing)
        coords_counter[key] += 1
    # Positional assignment stays correct even if index labels are duplicated
    # (df.at[idx, ...] would hit every row sharing that label).
    df[y_col] = df[y_col].to_numpy() + np.asarray(offsets, dtype=float)
    return df


# Number of distinct barcodes (samples) carrying each mutation per immunization group
df_counts = (
    df_combined_volcano
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .agg(Sample_Count=('barcode', 'nunique'))  # Count unique barcodes
    .reset_index()
)
# Normalize Sample_Count to a 0-1 frequency by dividing by the total number of
# distinct barcodes in the same immunization group (1 if the group is unknown).
barcode_totals = df_combined_volcano.groupby('immunization')['barcode'].nunique().to_dict()
df_counts['Sample_Count'] = (
    df_counts['Sample_Count'] / df_counts['immunization'].map(barcode_totals).fillna(1)
)

# Median log2 Enrichment Ratio for each unique combination
df_median_enrichment = (
    df_combined_volcano
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])['log2_Enrichment']
    .median()
    .reset_index(name='log2_Enrichment')  # Use median instead of mean
)

# Merge the counts and median enrichment ratio
df_volcano = pd.merge(df_counts, df_median_enrichment, on=['Spike_AS_Position', 'Amino_Acid', 'immunization'])

# Re-attach the per-record columns; this yields one row per record of
# df_combined_volcano in each group.
df_volcano = pd.merge(df_volcano, df_combined_volcano[original_columns], on=['Spike_AS_Position', 'Amino_Acid', 'immunization'], how='left')

# Keep the aggregated (median) log2_Enrichment from the left side of the merge
df_volcano = df_volcano.assign(log2_Enrichment=df_volcano['log2_Enrichment_x']).drop(
    columns=['log2_Enrichment_x', 'log2_Enrichment_y']
)

# Print out the columns to verify
print("Columns in df_volcano after merging and cleaning:")
print(df_volcano.columns.tolist())

# Assign mutation label like E484K
df_volcano['site_label'] = df_volcano.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Exclude synonymous mutations and stop codons.
# .copy() so the new column is written to a real frame, not a filtered view
# (avoids SettingWithCopyWarning).
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]

stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]
print("Remaining rows after all filters:", len(df_filtered))
# Get top N enriched and escape
top_n = 5

palette = {
    'Polyclonal_Ab': '#ff7f0e',     # Adjust colors as needed
    'Neutralizing_Ab': '#d62728',    # Adjust colors as needed
    'Mutant_RBD': '#009688',         # Adjust colors as needed
    'wildtype_RBD': '#1f77b4',       # Adjust colors as needed
}
print(df_volcano[['Reference_Amino_Acid', 'Spike_AS_Position', 'Amino_Acid', 'site_label']].drop_duplicates().head())

# Plot only the rows that survived the synonymous / stop-codon filter (df_filtered).
# Previously df_volcano was used here, so the exclusion above had no effect on the figure.
df_volcano_filtered = df_filtered[df_filtered['immunization'].isin(palette.keys())]

# --------------------------
# Significance thresholds on the log2 scale: ratio 0.5 and ratio 2
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# --------------------------
# Split into non-significant (inside the band) and significant (outside) points.
# .copy() so stack_points and later column writes do not act on a filtered view.
df_non_sig = df_volcano_filtered[(df_volcano_filtered['log2_Enrichment'] >= low_thresh) &
                                 (df_volcano_filtered['log2_Enrichment'] <= high_thresh)].copy()

df_sig = df_volcano_filtered[(df_volcano_filtered['log2_Enrichment'] < low_thresh) |
                             (df_volcano_filtered['log2_Enrichment'] > high_thresh)].copy()

df_non_sig = stack_points(df_non_sig, 'log2_Enrichment', 'Sample_Count')
df_sig = stack_points(df_sig, 'log2_Enrichment', 'Sample_Count')

# --------------------------
# Plot
fig, ax = plt.subplots(figsize=(13, 10))

# Non-significant points: a single light-grey cloud (no hue)
sns.scatterplot(
    data=df_non_sig,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    s=15,
    alpha=0.4,
    ax=ax,
)

# Significant points, coloured by immunization group
sns.scatterplot(
    data=df_sig,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    palette=palette,
    s=25,
    alpha=0.9,
    zorder=2,
    linewidth=0,
    ax=ax,
)

# Significance thresholds and the neutral (log2 = 0) line
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(0, linestyle='--', color='gray', lw=1)

# Threshold captions, in data coordinates just below y = 0
plt.text(low_thresh, -0.05, 'log2(0.5)', ha='center', va='bottom', fontsize=13)
plt.text(high_thresh, -0.05, 'log2(2)', ha='center', va='bottom', fontsize=13)

plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
# Replace the default x label with a direction hint centred at x = 0
plt.xlabel("")
ax.text(0, -0.05, "← Escape   |   Binding →", fontsize=21,
        ha='center', va='top', transform=ax.get_xaxis_transform())

plt.ylabel("Frequency of Mutation in Single-Antibody Repertoire", fontsize=21)

# Legend: rename entries by their immunization key. A positional list of names only
# matches if seaborn emits the hue levels in exactly that order (no hue_order is given).
legend_names = {
    'Neutralizing_Ab': 'Anti-RBD mAB (NEUT)',
    'Polyclonal_Ab': 'Polyreactive pAB',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
    # NOTE(review): the Beta variant lineage is usually written B.1.351 - confirm this label.
    'Mutant_RBD': 'B cells - B.1.135 RBD vaccine',
}
handles, labels = ax.get_legend_handles_labels()
new_labels = [legend_names.get(label, label) for label in labels]
plt.legend(
    handles,
    new_labels,
    title="Immunization",
    bbox_to_anchor=(0.8, 0.9),
    loc='upper left',
    fontsize=12,  # Adjust the size of the legend text
    title_fontsize=14,  # Adjust the size of the title
    markerscale=1.8
)

# Enable minor ticks
plt.minorticks_on()
ax.tick_params(axis='both', which='major', labelsize=15)  # Increase font size for major ticks
ax.tick_params(axis='both', which='minor', labelsize=15)

from collections import defaultdict

# Residues to label
residues_to_label = [417, 439, 440, 452, 477, 484, 495, 501, 502, 505]


def _add_unique_labels(frame, seen, filtered_rows):
    """Append rows of `frame` to `filtered_rows`, at most once per (immunization, site_label).

    Rows whose site_label contains '*' (stop codons) are skipped. `seen` maps an
    immunization group to the set of labels already collected for that group.
    """
    for _, row in frame.iterrows():
        label = row['site_label']
        group = row['immunization']
        if "*" in label or label in seen[group]:
            continue
        seen[group].add(label)
        filtered_rows.append(row)


# Step 1: significant rows at the residues of interest.
# NOTE(review): Sample_Count was normalised to a 0-1 fraction of barcodes earlier in
# this cell, so `> 2` likely matches nothing and df_labels is empty - confirm the cutoff.
df_labels = df_sig[
    ((df_sig['log2_Enrichment'] > high_thresh) | (df_sig['log2_Enrichment'] < low_thresh)) &  # Significant data
    (df_sig['Sample_Count'] > 2) &  # Filter based on Sample_Count
    (df_sig['Spike_AS_Position'].isin(residues_to_label))  # Only specific residues
].copy()

# Step 2: Deduplicate by site_label within each immunization group
seen = defaultdict(set)
filtered_rows = []
_add_unique_labels(df_labels, seen, filtered_rows)

# Step 3: top 5 / bottom 5 by enrichment plus the 20 most frequent rows
top_5 = df_volcano_filtered.nlargest(5, 'log2_Enrichment')
bottom_5 = df_volcano_filtered.nsmallest(5, 'log2_Enrichment')
top_y = df_volcano_filtered.nlargest(20, 'Sample_Count')
max_sample_count_row = df_volcano_filtered.loc[df_volcano_filtered['Sample_Count'].idxmax()]

df_combined_labels = pd.concat([
    top_5,
    bottom_5,
    top_y,
    pd.DataFrame([max_sample_count_row])
])

# Step 4: Deduplicate the combined labels by site_label within each immunization group
_add_unique_labels(df_combined_labels, seen, filtered_rows)

# Step 5: extra labels for hand-picked regions of the plot.
# The last clause was `< 2`, which combined with the `> 2` clause before it selected
# every low-frequency point; `< -2` is the symmetric escape-side cut.
new_labels = (
    ((df_sig['Sample_Count'] > 0.6) & (df_sig['log2_Enrichment'] > 2.5)) |
    ((df_sig['Sample_Count'] > 0.2) & (df_sig['log2_Enrichment'] > 5.5)) |
    ((df_sig['Sample_Count'] > 0.2) & (df_sig['log2_Enrichment'] < -4.7)) |
    ((df_sig['Sample_Count'] < 0.16) & (df_sig['log2_Enrichment'] > 2)) |
    ((df_sig['Sample_Count'] < 0.16) & (df_sig['log2_Enrichment'] < -2))
)

_add_unique_labels(df_sig[new_labels], seen, filtered_rows)


# --- Now we are only dealing with significant data, so we will proceed to add labels ---
# Step 6: Annotate plot with unique labels per immunization group
texts = []
offsets = []  # Initial label anchor positions (kept for inspection)
positions = []  # Data point each label belongs to; parallel to `texts`

# Labels for every significant point at a residue of interest, anchored on the point.
for _, row in df_sig.iterrows():
    if row['Spike_AS_Position'] in residues_to_label:
        texts.append(
            ax.text(
                row['log2_Enrichment'],
                row['Sample_Count'],
                row['site_label'],
                fontsize=9,
                ha='center',
                va='bottom',
                zorder=5,
                clip_on=True
            )
        )
        # The connector loop pairs texts[i] with positions[i]; without this entry
        # it raised IndexError and joined labels to the wrong points.
        positions.append((row['log2_Enrichment'], row['Sample_Count']))

# Labels for the curated rows, nudged one unit outward from the point
for row in filtered_rows:
    if -1 <= row['log2_Enrichment'] <= 1:
        continue  # Skip labels close to the centre
    offset_x = -1 if row['log2_Enrichment'] < 0 else 1
    offset_y = 0.01
    sample_count = row['Sample_Count']

    positions.append((row['log2_Enrichment'], sample_count))
    text = plt.text(
        row['log2_Enrichment'] + offset_x,
        sample_count + offset_y,
        row['site_label'],
        fontsize=12,
        color='black'
    )
    texts.append(text)
    offsets.append((row['log2_Enrichment'] + offset_x, sample_count + offset_y))

# Step 2: Adjust the text labels to avoid overlap ('text' is the documented only_move key)
adjust_text(
    texts,
    arrowprops=None,
    expand_text=(2, 2),
    expand_points=(2, 2),
    force_text=1.0,
    force_points=0.3,
    only_move={'points': 'y', 'text': 'xy'},
    lim=100
)

# Step 3: Dashed connector from each data point to its (moved) label
for text, (point_x, point_y) in zip(texts, positions):
    label_x, label_y = text.get_position()
    ax.plot(
        [point_x, label_x],
        [point_y, label_y],
        color='gray',
        linestyle='--',
        linewidth=0.5,
        zorder=1
    )

# Adjust plot
plt.xlim(-7, 9)
plt.tight_layout()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
for spine in ax.spines.values():
    spine.set_linewidth(2)
print(df_labels[['Spike_AS_Position', 'Amino_Acid', 'site_label', 'log2_Enrichment', 'Sample_Count']])

plt.draw()  # Ensure the plot is drawn before saving
# The content is PNG (format='png'), so use a matching extension.
plt.savefig('/Users/lucaschlotheuber/Desktop/Volcano9.png', format='png', dpi=300)

In [None]:
print(list(df_combined_volcano.columns))

In [None]:
from collections import defaultdict

def stack_points(df, x_col, y_col, stack_direction='vertical', min_distance=0.005):
    """
    Spread overlapping points for plotting, along y ('vertical') or along x.

    Points are grouped on their x and y values rounded to 4 decimals. Within a
    group the first point stays put, and the k-th repeat is offset by
    ``k * min_distance``:

    * ``stack_direction == 'vertical'``: the offset y is stored in the new column
      ``'Sample_Count_adjusted'``; ``y_col`` itself is left unchanged.
    * any other value: ``x_col`` is shifted in the returned frame, and
      ``'Sample_Count_adjusted'`` is a plain copy of ``y_col``.

    The input frame is not modified; a new frame is returned.
    """
    # Copy first: callers pass boolean-filtered slices. Adding a column or writing
    # via .at on those triggered SettingWithCopyWarning and mutated the caller's frame.
    df = df.copy()
    grouped = defaultdict(list)
    adjusted_y = []

    for idx, row in df.iterrows():
        key = (round(row[x_col], 4), round(row[y_col], 4))
        stack_pos = len(grouped[key])
        grouped[key].append(idx)

        if stack_direction == 'vertical':
            adjusted_y.append(row[y_col] + (stack_pos * min_distance))
        else:
            df.at[idx, x_col] = row[x_col] + (stack_pos * min_distance)
            adjusted_y.append(row[y_col])  # y stays unchanged

    df['Sample_Count_adjusted'] = adjusted_y
    return df

In [None]:
# Calculate the number of unique barcodes per unique combination of 'Spike_AS_Position' and 'Amino_Acid'
#FINALCODE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd
from Bio import SeqIO

import matplotlib
# Do NOT call matplotlib.use('Agg') here: switching the backend inside a running
# Jupyter kernel disables inline rendering for this and every later cell (a later
# plt.show() only warns that the canvas is non-interactive). plt.savefig() writes
# files fine with the default inline backend.


# Number of distinct barcodes (samples) carrying each mutation per immunization group
df_counts = (
    df_combined_volcano
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .agg(Sample_Count=('barcode', 'nunique'))  # Count unique barcodes
    .reset_index()
)
# Normalize Sample_Count to a 0-1 frequency by dividing by the total number of
# distinct barcodes in the same immunization group (1 if the group is unknown).
barcode_totals = df_combined_volcano.groupby('immunization')['barcode'].nunique().to_dict()
df_counts['Sample_Count'] = (
    df_counts['Sample_Count'] / df_counts['immunization'].map(barcode_totals).fillna(1)
)

# Median log2 Enrichment Ratio for each unique combination
df_median_enrichment = (
    df_combined_volcano
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])['log2_Enrichment']
    .median()
    .reset_index(name='log2_Enrichment')  # Use median instead of mean
)

# Merge the counts and median enrichment ratio
df_volcano = pd.merge(df_counts, df_median_enrichment, on=['Spike_AS_Position', 'Amino_Acid', 'immunization'])

# Re-attach the per-record columns; this yields one row per record of
# df_combined_volcano in each group.
df_volcano = pd.merge(df_volcano, df_combined_volcano[original_columns], on=['Spike_AS_Position', 'Amino_Acid', 'immunization'], how='left')

# Keep the aggregated (median) log2_Enrichment from the left side of the merge
df_volcano = df_volcano.assign(log2_Enrichment=df_volcano['log2_Enrichment_x']).drop(
    columns=['log2_Enrichment_x', 'log2_Enrichment_y']
)

# Print out the columns to verify
print("Columns in df_volcano after merging and cleaning:")
print(df_volcano.columns.tolist())

# Assign mutation label like E484K
df_volcano['site_label'] = df_volcano.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Exclude synonymous mutations and stop codons.
# .copy() so the new column is written to a real frame, not a filtered view
# (avoids SettingWithCopyWarning).
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]

stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
]
print("Remaining rows after all filters:", len(df_filtered))
# Get top N enriched and escape
top_n = 5

palette = {
    'Polyclonal_Ab': '#ff7f0e',     # Adjust colors as needed
    'Neutralizing_Ab': '#d62728',    # Adjust colors as needed
    'Mutant_RBD': '#009688',         # Adjust colors as needed
    'wildtype_RBD': '#1f77b4',       # Adjust colors as needed
}
print(df_volcano[['Reference_Amino_Acid', 'Spike_AS_Position', 'Amino_Acid', 'site_label']].drop_duplicates().head())

# Plot only the rows that survived the synonymous / stop-codon filter (df_filtered).
# Previously df_volcano was used here, so the exclusion above had no effect on the figure.
df_volcano_filtered = df_filtered[df_filtered['immunization'].isin(palette.keys())]

# --------------------------
# Significance thresholds on the log2 scale: ratio 0.5 and ratio 2
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# --------------------------
# Split into non-significant (inside the band) and significant (outside) points.
# .copy() so later helpers and column writes do not act on a filtered view.
df_non_sig = df_volcano_filtered[(df_volcano_filtered['log2_Enrichment'] >= low_thresh) &
                                 (df_volcano_filtered['log2_Enrichment'] <= high_thresh)].copy()

df_sig = df_volcano_filtered[(df_volcano_filtered['log2_Enrichment'] < low_thresh) |
                             (df_volcano_filtered['log2_Enrichment'] > high_thresh)].copy()

df_non_sig = stack_points(df_non_sig, 'log2_Enrichment', 'Sample_Count')
df_sig = stack_points(df_sig, 'log2_Enrichment', 'Sample_Count')

# --------------------------
# Plot
fig, ax = plt.subplots(figsize=(13, 10))

# Standard deviation of the Gaussian y-jitter (0 disables jitter)
y_jitter_std = 0

def ensure_no_overlap(df, x_col, y_col, min_distance=0.01):
    """
    Nudge points that share the same (x, y) coordinate upward so they do not overlap.

    The first point at a coordinate keeps its y value. The k-th repeat is raised
    by ``k * min_distance``. Returns a new frame; ``df`` is not modified.
    """
    from collections import defaultdict

    # Key on BOTH coordinates. Keying on y alone (x was read but never used) lifted
    # every point that shared a y value anywhere on the x axis by a cumulative
    # offset, distorting Sample_Count.
    df = df.copy()  # callers pass boolean-filtered slices
    occurrences = defaultdict(int)
    offsets = []
    for key in zip(df[x_col], df[y_col]):
        offsets.append(occurrences[key] * min_distance)
        occurrences[key] += 1
    df[y_col] = df[y_col].to_numpy() + np.asarray(offsets, dtype=float)
    return df

# Ensure no overlap for both non-significant and significant data
df_non_sig = ensure_no_overlap(df_non_sig, 'log2_Enrichment', 'Sample_Count')
df_sig = ensure_no_overlap(df_sig, 'log2_Enrichment', 'Sample_Count')

# Add jittered y values for plotting. `assign` returns a new frame, so this never
# writes into a filtered view.
df_non_sig = df_non_sig.assign(
    Sample_Count_jitter=df_non_sig['Sample_Count'] + np.random.normal(0, y_jitter_std, size=len(df_non_sig))
)
df_sig = df_sig.assign(
    Sample_Count_jitter=df_sig['Sample_Count'] + np.random.normal(0, y_jitter_std, size=len(df_sig))
)

# Non-significant points: a single light-grey cloud (no hue)
sns.scatterplot(
    data=df_non_sig,
    x='log2_Enrichment',
    y='Sample_Count_jitter',
    color='lightgray',
    s=15,
    alpha=0.4,
    ax=ax,
)

# Significant points, coloured by immunization group
sns.scatterplot(
    data=df_sig,
    x='log2_Enrichment',
    y='Sample_Count_jitter',
    hue='immunization',
    palette=palette,
    s=25,
    alpha=0.9,
    zorder=2,
    linewidth=0,
    ax=ax,
)

# Significance thresholds and the neutral (log2 = 0) line
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1)
plt.axvline(0, linestyle='--', color='gray', lw=1)

# Threshold captions, in data coordinates just below y = 0
plt.text(low_thresh, -0.05, 'log2(0.5)', ha='center', va='bottom', fontsize=13)
plt.text(high_thresh, -0.05, 'log2(2)', ha='center', va='bottom', fontsize=13)

plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
# Replace the default x label with a direction hint centred at x = 0
plt.xlabel("")
ax.text(0, -0.05, "← Escape   |   Binding →", fontsize=21,
        ha='center', va='top', transform=ax.get_xaxis_transform())

plt.ylabel("Frequency of Mutation in Single-Antibody Repertoire", fontsize=21)

# Legend: rename entries by their immunization key. A positional list of names only
# matches if seaborn emits the hue levels in exactly that order (no hue_order is given).
legend_names = {
    'Neutralizing_Ab': 'Anti-RBD mAB (NEUT)',
    'Polyclonal_Ab': 'Polyreactive pAB',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
    # NOTE(review): the Beta variant lineage is usually written B.1.351 - confirm this label.
    'Mutant_RBD': 'B cells - B.1.135 RBD vaccine',
}
handles, labels = ax.get_legend_handles_labels()
new_labels = [legend_names.get(label, label) for label in labels]
plt.legend(
    handles,
    new_labels,
    title="Immunization",
    bbox_to_anchor=(0.8, 0.9),
    loc='upper left',
    fontsize=12,  # Adjust the size of the legend text
    title_fontsize=14,  # Adjust the size of the title
    markerscale=1.8
)

# Enable minor ticks
plt.minorticks_on()
ax.tick_params(axis='both', which='major', labelsize=15)  # Increase font size for major ticks
ax.tick_params(axis='both', which='minor', labelsize=15)

from collections import defaultdict

# Residues to label
residues_to_label = [417, 439, 440, 452, 477, 484, 495, 501, 502, 505]


def _add_unique_labels(frame, seen, filtered_rows):
    """Append rows of `frame` to `filtered_rows`, at most once per (immunization, site_label).

    Rows whose site_label contains '*' (stop codons) are skipped. `seen` maps an
    immunization group to the set of labels already collected for that group.
    """
    for _, row in frame.iterrows():
        label = row['site_label']
        group = row['immunization']
        if "*" in label or label in seen[group]:
            continue
        seen[group].add(label)
        filtered_rows.append(row)


# Step 1: significant rows at the residues of interest.
# NOTE(review): Sample_Count was normalised to a 0-1 fraction of barcodes earlier in
# this cell, so `> 2` likely matches nothing and df_labels is empty - confirm the cutoff.
df_labels = df_sig[
    ((df_sig['log2_Enrichment'] > high_thresh) | (df_sig['log2_Enrichment'] < low_thresh)) &  # Significant data
    (df_sig['Sample_Count'] > 2) &  # Filter based on Sample_Count
    (df_sig['Spike_AS_Position'].isin(residues_to_label))  # Only specific residues
].copy()

# Step 2: Deduplicate by site_label within each immunization group
seen = defaultdict(set)
filtered_rows = []
_add_unique_labels(df_labels, seen, filtered_rows)

# Step 3: top 5 / bottom 5 by enrichment plus the 20 most frequent rows
top_5 = df_volcano_filtered.nlargest(5, 'log2_Enrichment')
bottom_5 = df_volcano_filtered.nsmallest(5, 'log2_Enrichment')
top_y = df_volcano_filtered.nlargest(20, 'Sample_Count')
max_sample_count_row = df_volcano_filtered.loc[df_volcano_filtered['Sample_Count'].idxmax()]

df_combined_labels = pd.concat([
    top_5,
    bottom_5,
    top_y,
    pd.DataFrame([max_sample_count_row])
])

# Step 4: Deduplicate the combined labels by site_label within each immunization group
_add_unique_labels(df_combined_labels, seen, filtered_rows)

# Step 5: extra labels for hand-picked regions of the plot.
# The last clause was `< 2`, which combined with the `> 2` clause before it selected
# every low-frequency point; `< -2` is the symmetric escape-side cut.
new_labels = (
    ((df_sig['Sample_Count'] > 0.6) & (df_sig['log2_Enrichment'] > 2.5)) |
    ((df_sig['Sample_Count'] > 0.2) & (df_sig['log2_Enrichment'] > 5.5)) |
    ((df_sig['Sample_Count'] > 0.2) & (df_sig['log2_Enrichment'] < -4.7)) |
    ((df_sig['Sample_Count'] < 0.16) & (df_sig['log2_Enrichment'] > 2)) |
    ((df_sig['Sample_Count'] < 0.16) & (df_sig['log2_Enrichment'] < -2))
)

_add_unique_labels(df_sig[new_labels], seen, filtered_rows)


# --- Now we are only dealing with significant data, so we will proceed to add labels ---
# Step 6: Annotate plot with unique labels per immunization group
texts = []
offsets = []  # Initial label anchor positions (kept for inspection)
jittered_positions = []  # Data point each label belongs to; parallel to `texts`

# Labels for every significant point at a residue of interest, anchored on the point.
for _, row in df_sig.iterrows():
    if row['Spike_AS_Position'] in residues_to_label:
        texts.append(
            ax.text(
                row['log2_Enrichment'],
                row['Sample_Count_jitter'],
                row['site_label'],
                fontsize=9,
                ha='center',
                va='bottom',
                zorder=5,
                clip_on=True
            )
        )
        # The connector loop pairs texts[i] with jittered_positions[i]; without this
        # entry it raised IndexError and joined labels to the wrong points.
        jittered_positions.append((row['log2_Enrichment'], row['Sample_Count_jitter']))

# Labels for the curated rows, nudged one unit outward from the point
for row in filtered_rows:
    if -1 <= row['log2_Enrichment'] <= 1:
        continue  # Skip labels close to the centre
    offset_x = -1 if row['log2_Enrichment'] < 0 else 1
    offset_y = 0.01

    # Random jitter of the label's starting height (zero when y_jitter_std == 0);
    # the connector still starts at the unjittered Sample_Count.
    jittered_sample_count = row['Sample_Count'] + np.random.normal(0, y_jitter_std)
    jittered_positions.append((row['log2_Enrichment'], row['Sample_Count']))

    text = plt.text(
        row['log2_Enrichment'] + offset_x,
        jittered_sample_count + offset_y,
        row['site_label'],
        fontsize=12,
        color='black'
    )
    texts.append(text)
    offsets.append((row['log2_Enrichment'] + offset_x, jittered_sample_count + offset_y))

# Step 2: Adjust the text labels to avoid overlap ('text' is the documented only_move key)
adjust_text(
    texts,
    arrowprops=None,  # No arrows, we will manually handle lines
    expand_text=(2, 2),  # Increase repulsion in both x and y direction
    expand_points=(2, 2),  # Increase repulsion for points as well
    force_text=1.0,  # Stronger push for text to avoid overlap
    force_points=0.3,  # Push points further for better distribution
    only_move={'points': 'y', 'text': 'xy'},  # Allow movement in both x and y for texts
    lim=100  # Limit the number of text label adjustments to prevent infinite loops
)

# Step 3: Dashed connector from each data point to its (moved) label
for text, (point_x, point_y) in zip(texts, jittered_positions):
    label_x, label_y = text.get_position()
    ax.plot(
        [point_x, label_x],
        [point_y, label_y],
        color='gray',
        linestyle='--',
        linewidth=0.5,
        zorder=1
    )

# Adjust plot
plt.xlim(-7, 9)
plt.tight_layout()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
for spine in ax.spines.values():
    spine.set_linewidth(2)
print(df_labels[['Spike_AS_Position', 'Amino_Acid', 'site_label', 'log2_Enrichment', 'Sample_Count_jitter']])

plt.draw()  # Ensure the plot is drawn before saving
# The content is PNG (format='png'), so use a matching extension.
plt.savefig('/Users/lucaschlotheuber/Desktop/Volcano9.png', format='png', dpi=300)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text
from Bio.Seq import Seq

# Define a function to calculate the reference amino acid
def get_reference_aa(codon_change, nucleotide_ref):
    """Translate the reference (pre-mutation) codon implied by a codon change.

    Every uppercase character in `codon_change` is replaced by `nucleotide_ref`
    (presumably the uppercase base marks the mutated position - confirm against
    the data). The resulting codon is translated with Bio.Seq.

    Returns the amino acid as a plain ``str``. A Bio ``Seq`` stored in a pandas
    object column is not a ``str``, so the ``.str`` accessor yields NaN for it.
    """
    original_codon = ''.join(
        nucleotide_ref if base.isupper() else base for base in codon_change
    )
    return str(Seq(original_codon).translate())

# Combine enriched rows (ratio > 1) from df_logo_agg with every df_escape row
df_combined_volcano = pd.concat([
    df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1],
    df_escape,
])

# Keep strictly positive ratios so log2 is defined. .copy() so the new column is
# written to a real frame, not a filtered view (avoids SettingWithCopyWarning).
df_combined_volcano = df_combined_volcano[df_combined_volcano['Enrichment_Ratio'] > 0].copy()
df_combined_volcano['log2_Enrichment'] = np.log2(df_combined_volcano['Enrichment_Ratio'])

# Retain all columns from the original data frame
original_columns = df_combined_volcano.columns.tolist()

# Count rows (appearances across barcodes) per site / amino acid / immunization
df_counts = (
    df_combined_volcano
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .size()
    .reset_index(name='Sample_Count')
)

# Mean log2 enrichment per group (the name df_max_enrichment is historical: it is a mean)
df_max_enrichment = (
    df_combined_volcano
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])['log2_Enrichment']
    .mean()
    .reset_index(name='log2_Enrichment')
)

# Merge counts and enrichment data
df_volcano = pd.merge(df_counts, df_max_enrichment, on=['Spike_AS_Position', 'Amino_Acid', 'immunization'])

# Re-attach the per-record columns; this yields one row per record of
# df_combined_volcano in each group.
df_volcano = pd.merge(df_volcano, df_combined_volcano[original_columns], on=['Spike_AS_Position', 'Amino_Acid', 'immunization'], how='left')

# Keep the aggregated (mean) log2_Enrichment from the left side of the merge
df_volcano = df_volcano.assign(log2_Enrichment=df_volcano['log2_Enrichment_x']).drop(
    columns=['log2_Enrichment_x', 'log2_Enrichment_y']
)

# Reference amino acid derived from the codon change (NaN when either input is missing)
df_volcano['Reference_Amino_Acid'] = df_volcano.apply(
    lambda row: get_reference_aa(row['Codon_Change'], row['Nucleotide_Ref'])
    if pd.notna(row['Codon_Change']) and pd.notna(row['Nucleotide_Ref']) else np.nan,
    axis=1
)

# Create labels like N501Y
df_volcano['site_label'] = df_volcano.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Exclude synonymous mutations and stop codons
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]
stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
].copy()

# Reference_Amino_Acid is already carried into df_filtered from df_volcano, where it
# was computed with a NaN guard. Recomputing it here without the guard could raise
# TypeError on a missing Codon_Change / Nucleotide_Ref, and it wrote into a slice.

# Get top N enriched, escaped, and top sample count
top_n = 10

top_enriched = df_filtered[df_filtered['log2_Enrichment'] > 0].nlargest(top_n, ['Sample_Count', 'log2_Enrichment'])
# Rank escape mutations like the enriched ones: most frequent first, ties broken by
# the most negative log2 enrichment. (nsmallest on Sample_Count picked the rarest.)
top_escape = (
    df_filtered[df_filtered['log2_Enrichment'] < 0]
    .sort_values(['Sample_Count', 'log2_Enrichment'], ascending=[False, True])
    .head(top_n)
)
top_sample_count = df_filtered.nlargest(top_n, 'Sample_Count')

# Combine and drop duplicates
top_combined = pd.concat([top_enriched, top_escape, top_sample_count]).drop_duplicates(subset=['site_label'])

plt.figure(figsize=(12, 6))

# Shade non-significant region (0.5–2 enrichment ratio in log2 space)
plt.axvspan(np.log2(0.5), np.log2(2), color='lightgray', alpha=0.3, zorder=0)

# Classify on the plotted (aggregated) x value so a point's colour matches the shaded
# band. The merged Enrichment_Ratio column is the per-record value and can disagree
# with the group mean plotted on x.
mask_significant = (
    (df_volcano['log2_Enrichment'] <= np.log2(0.5))
    | (df_volcano['log2_Enrichment'] >= np.log2(2))
)
mask_nonsig = ~mask_significant

# Plot non-significant mutations in gray (without hue)
sns.scatterplot(
    data=df_volcano[mask_nonsig],
    x='log2_Enrichment',
    y='Sample_Count',
    color='gray',  # Explicitly set color to gray
    alpha=0.4,
    edgecolor='black',
    s=10,
    label='Not significant',
    zorder=1
)

# Plot significant mutations with original coloring (keep hue)
ax = sns.scatterplot(
    data=df_volcano[mask_significant],
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',  # Keep hue for significant mutations
    alpha=0.8,
    edgecolor='black',
    s=10,
    zorder=2
)

plt.axvline(0, linestyle='--', color='gray')
plt.title("Volcano Plot: Escape vs Binding Mutations")
plt.xlabel("log2(Enrichment Ratio) (← Escape   |   Binding →)")
plt.ylabel("Sample Count")
plt.legend(title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Annotate only the top points with labels, pushed away from x = 0
texts = []
for _, row in top_combined.iterrows():
    offset_x = 0.2 if row['log2_Enrichment'] > 0 else -0.2
    offset_y = 0.2
    custom_label = f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    text = plt.text(
        row['log2_Enrichment'] + offset_x,
        row['Sample_Count'] + offset_y,
        custom_label,
        fontsize=8,
        color='black',
        ha='left' if row['log2_Enrichment'] > 0 else 'right',
        va='bottom'
    )
    texts.append(text)

# Adjust label positions to avoid overlap
adjust_text(texts, only_move={'points': 'xy', 'text': 'xy'}, expand_text=(1.1, 1.1), ax=ax)

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text
from Bio.Seq import Seq

def get_reference_aa(codon_change, nucleotide_ref):
    """Translate the reference (pre-mutation) codon implied by a codon annotation.

    Every upper-case character of ``codon_change`` is taken to be the mutated
    base and is swapped for ``nucleotide_ref``. Lower-case characters are kept
    as they are. The rebuilt codon is translated with Biopython, and the
    resulting ``Seq`` object is returned.
    """
    reference_codon = ""
    for base in codon_change:
        reference_codon += nucleotide_ref if base.isupper() else base
    return Seq(reference_codon).translate()

# Combine enriched (ratio > 1) mutations from df_logo_agg with the escape set in df_escape.
# Both frames are built earlier in the notebook (outside this view).
df_combined_volcano = pd.concat([
    df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1],
    df_escape,
])

# log2 is only defined for positive ratios. .copy() makes the new column land in an
# independent frame rather than a view (avoids SettingWithCopyWarning).
df_combined_volcano = df_combined_volcano[df_combined_volcano['Enrichment_Ratio'] > 0].copy()
df_combined_volcano['log2_Enrichment'] = np.log2(df_combined_volcano['Enrichment_Ratio'])

# Snapshot of every column (including log2_Enrichment) to carry through the merge below.
original_columns = df_combined_volcano.columns.tolist()

# A "mutation" is one position / amino-acid pair within one immunization group.
mutation_keys = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Count appearances across barcodes: number of rows per mutation and group.
df_counts = (
    df_combined_volcano
    .groupby(mutation_keys)
    .size()
    .reset_index(name='Sample_Count')
)

# Mean log2 enrichment per mutation and group.
# (The variable name says "max" for historical reasons; the statistic is the mean.)
df_max_enrichment = (
    df_combined_volcano
    .groupby(mutation_keys)['log2_Enrichment']
    .mean()
    .reset_index(name='log2_Enrichment')
)

# Merge counts and enrichment data
df_volcano = pd.merge(df_counts, df_max_enrichment, on=mutation_keys)

# Broadcast the per-mutation aggregates back onto every original row to retain all columns.
# The result has one row per original row, so a mutation can appear several times with
# identical Sample_Count / log2_Enrichment (handled when choosing labels below).
df_volcano = pd.merge(df_volcano, df_combined_volcano[original_columns], on=mutation_keys, how='left')

# Keep the aggregated (mean) log2 enrichment and drop the per-row copy.
df_volcano['log2_Enrichment'] = df_volcano['log2_Enrichment_x']
df_volcano = df_volcano.drop(columns=['log2_Enrichment_x', 'log2_Enrichment_y'])

# Reference amino acid reconstructed from the codon annotation (NaN when inputs are missing).
df_volcano['Reference_Amino_Acid'] = df_volcano.apply(
    lambda row: get_reference_aa(row['Codon_Change'], row['Nucleotide_Ref'])
    if pd.notna(row['Codon_Change']) and pd.notna(row['Nucleotide_Ref']) else np.nan,
    axis=1
)

# Create labels like N501Y (empty string when either amino acid is missing).
df_volcano['site_label'] = df_volcano.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Exclude rows without an amino acid, synonymous mutations and stop codons.
# .copy() so adding is_synonymous does not write into a view (SettingWithCopyWarning).
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
# Synonymous: first (reference) and last (mutant) character of the label agree.
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]
stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    (~df_volcano['Amino_Acid'].isin(stop_codons)) &
    (~df_volcano['is_synonymous'])
].copy()

# Map immunization groups to display categories; groups not listed stay NA.
# NOTE(review): the Beta lineage is usually written B.1.351; the label text is kept
# because later cells key their palettes on this exact string.
category_by_immunization = {
    'Neutralizing_Ab': 'mAB (NEUT)',
    'Polyclonal_Ab': 'Polyreactive pAB',
    'Mutant_RBD': 'B cells - B.1.135 RBD vaccine',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
}
df_filtered['Category'] = pd.NA
for immunization_group, category_name in category_by_immunization.items():
    df_filtered.loc[df_filtered['immunization'] == immunization_group, 'Category'] = category_name

# One colour per category
palette = {
    'mAB (NEUT)': 'red',
    'Polyreactive pAB': 'orange',
    'B cells - B.1.135 RBD vaccine': 'blue',
    'B cells - Ancestral Wuhan RBD vaccine': 'turquoise'
}

# Create the volcano plot
plt.figure(figsize=(12, 6))

# Shade non-significant region (0.5–2 enrichment ratio in log2 space)
plt.axvspan(np.log2(0.5), np.log2(2), color='lightgray', alpha=0.3, zorder=0)

sns.scatterplot(
    data=df_filtered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.8,
    edgecolor='black',
    s=30,
    zorder=2
)

# Label the most extreme mutations on each side plus the most frequently observed ones.
# Rank over one row per mutation/group (the merge above repeats rows), then draw each
# site label only once.
n_labels = 10
label_pool = df_filtered[df_filtered['site_label'] != ''].drop_duplicates(subset=mutation_keys)

# Escape is the left (negative log2) side of the x-axis; binding is the right side.
top_escape = label_pool.nsmallest(n_labels, 'log2_Enrichment')
top_binding = label_pool.nlargest(n_labels, 'log2_Enrichment')
top_sample_count = label_pool.nlargest(n_labels, 'Sample_Count')

labels_to_plot = (
    pd.concat([top_escape, top_binding, top_sample_count])
    .drop_duplicates(subset=['site_label'])
)

texts = [
    plt.text(row['log2_Enrichment'], row['Sample_Count'], row['site_label'], fontsize=10, ha='center', va='center')
    for _, row in labels_to_plot.iterrows()
]

# Adjust text positions to avoid overlap (skipped when there is nothing to place).
if texts:
    adjust_text(texts, arrowprops=dict(arrowstyle='->', color='gray', lw=0.5))

# Plot the line at x=0 for separation
plt.axvline(0, linestyle='--', color='gray')

# Title and axis labels
plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=14)
plt.ylabel("Sample Count", fontsize=14)

# Legend outside the axes, on the right
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text

# Keep only variants supported by more than 500 total reads.
# (Re-applying the same threshold to an already-filtered frame changes nothing.)
df_filtered = df_filtered[df_filtered['Total_Reads'] > 500]

# Significance band on the log2 scale: ratios between 0.5 and 2 are "not significant".
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# Split significant and non-significant
df_non_sig = df_filtered[(df_filtered['log2_Enrichment'] >= low_thresh) & (df_filtered['log2_Enrichment'] <= high_thresh)]
df_sig = df_filtered[(df_filtered['log2_Enrichment'] < low_thresh) | (df_filtered['log2_Enrichment'] > high_thresh)]


def add_y_jitter(arr, jitter_strength=10, rng=None):
    """Return ``arr`` plus Gaussian noise with mean 0 and SD ``jitter_strength``.

    Pass a ``numpy.random.Generator`` as ``rng`` for reproducible jitter; when it is
    omitted the global ``np.random`` state is used.
    """
    noise_source = np.random if rng is None else rng
    return arr + noise_source.normal(0, jitter_strength, size=len(arr))


# Seeded generator so the jittered figure is identical on every run.
jitter_rng = np.random.default_rng(0)
df_non_sig_jittered = df_non_sig.copy()
df_sig_jittered = df_sig.copy()
df_non_sig_jittered['Sample_Count'] = add_y_jitter(df_non_sig_jittered['Sample_Count'], jitter_strength=0.2, rng=jitter_rng)
df_sig_jittered['Sample_Count'] = add_y_jitter(df_sig_jittered['Sample_Count'], jitter_strength=0.2, rng=jitter_rng)

# Darker color palette
palette = {
    'B cells - Ancestral Wuhan RBD vaccine': '#1f77b4',
    'B cells - B.1.135 RBD vaccine': '#ff7f0e',
    'mAB (NEUT)': '#2ca02c',
    'Polyreactive pAB': '#d62728'
}

plt.figure(figsize=(12, 6))

# Non-significant gray points
sns.scatterplot(
    data=df_non_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    alpha=0.6,
    s=14,
    zorder=2,
    edgecolor=None
)

# Significant colored points
sns.scatterplot(
    data=df_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.9,
    s=10,
    zorder=3,
    edgecolor=None,
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ]
)

# Threshold lines and a line at log2 = 0
plt.axvline(low_thresh, linestyle='--', color='gray', lw=2, zorder=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=2, zorder=1)
plt.axvline(0, linestyle='--', color='gray')

# Threshold labels (y = -5 in data units, i.e. below the plotted counts)
plt.text(low_thresh, -5, 'log2(0.5)', ha='center', va='center', fontsize=12, color='black')
plt.text(high_thresh, -5, 'log2(2)', ha='center', va='center', fontsize=12, color='black')

# Rank over one row per mutation and immunization group; upstream rows repeat each
# mutation once per original sample row.
label_pool = df_filtered.drop_duplicates(subset=['Spike_AS_Position', 'Amino_Acid', 'immunization'])

# Escape is the left (negative log2) side of the x-axis, binding the right side.
top_10_escape = label_pool.nsmallest(10, 'log2_Enrichment')
top_10_binding = label_pool.nlargest(10, 'log2_Enrichment')
top_10_sample = label_pool.nlargest(10, 'Sample_Count')
top_10_labels = pd.concat([top_10_escape, top_10_binding, top_10_sample]).drop_duplicates(subset=['site_label'])

# Thresholds for labeling: top 5% by count combined with the extreme 5% of enrichment
count_thresh = df_filtered['Sample_Count'].quantile(0.95)
escape_thresh = df_filtered['log2_Enrichment'].quantile(0.05)  # very low log2 enrichment (escape)
binding_thresh = df_filtered['log2_Enrichment'].quantile(0.95)  # very high log2 enrichment (binding)

df_high_escape_and_count = df_filtered[(df_filtered['Sample_Count'] > count_thresh) & (df_filtered['log2_Enrichment'] < escape_thresh)]
df_high_binding_and_count = df_filtered[(df_filtered['Sample_Count'] > count_thresh) & (df_filtered['log2_Enrichment'] > binding_thresh)]

additional_labels = pd.concat([df_high_escape_and_count, df_high_binding_and_count])

# Each site label is drawn at most once, whichever source (top-N, residue list,
# quantile thresholds) reaches it first. Without this, repeated rows produce
# identical labels stacked on top of each other.
texts = []
labelled_sites = set()


def _place_label(x, y, label, fontsize=10, color=None):
    """Draw ``label`` at (x, y) unless it is empty or has already been drawn."""
    if not isinstance(label, str) or not label or label in labelled_sites:
        return
    labelled_sites.add(label)
    texts.append(plt.text(x, y, label, ha='center', va='center', fontsize=fontsize, color=color))


for _, row in top_10_labels.iterrows():
    _place_label(row['log2_Enrichment'], row['Sample_Count'], row['site_label'])

# Custom residue labels (larger font)
residues_to_label = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
for pos in residues_to_label:
    residue_data = df_filtered[df_filtered['Spike_AS_Position'] == pos]
    if not residue_data.empty:
        # NOTE(review): labels the first row at this position, i.e. whichever
        # mutation/group happens to come first — confirm that is the intended one.
        first_row = residue_data.iloc[0]
        _place_label(first_row['log2_Enrichment'], first_row['Sample_Count'], first_row['site_label'],
                     fontsize=12, color='black')

for _, row in additional_labels.iterrows():
    _place_label(row['log2_Enrichment'], row['Sample_Count'], row['site_label'], color='black')

# Adjust text to avoid overlap. 'text' is the adjustText key for label movement;
# the previous 'texts' key was not recognised.
if texts:
    adjust_text(texts, only_move={'points': 'y', 'text': 'xy'}, force_text=0.2, arrowprops=dict(arrowstyle='-', color='gray'))

# Titles and labels
plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=14, labelpad=20)
plt.ylabel("Sample Count", fontsize=14)

# Legend
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Symmetric x-axis around 0
xlim = max(abs(df_filtered['log2_Enrichment'].min()), abs(df_filtered['log2_Enrichment'].max()))
plt.xlim(-xlim, xlim)

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd

# Keep only variants supported by more than 500 total reads.
df_filtered = df_filtered[df_filtered['Total_Reads'] > 500]

# Significance band on the log2 scale: ratios between 0.5 and 2 are "not significant".
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# Split significant and non-significant
df_non_sig = df_filtered[(df_filtered['log2_Enrichment'] >= low_thresh) & (df_filtered['log2_Enrichment'] <= high_thresh)]
df_sig = df_filtered[(df_filtered['log2_Enrichment'] < low_thresh) | (df_filtered['log2_Enrichment'] > high_thresh)]


# Slight y-jitter function: arr plus Gaussian noise with SD jitter_strength.
def add_y_jitter(arr, jitter_strength=10):
    return arr + np.random.normal(0, jitter_strength, size=len(arr))


# jitter_strength=0 leaves Sample_Count unchanged in this figure.
df_non_sig_jittered = df_non_sig.copy()
df_sig_jittered = df_sig.copy()
df_non_sig_jittered['Sample_Count'] = add_y_jitter(df_non_sig_jittered['Sample_Count'], jitter_strength=0)
df_sig_jittered['Sample_Count'] = add_y_jitter(df_sig_jittered['Sample_Count'], jitter_strength=0)

# Color palette (darker turquoise for B.1.135)
palette = {
    'B cells - Ancestral Wuhan RBD vaccine': '#1f77b4',  # Blue for Wuhan
    'B cells - B.1.135 RBD vaccine': '#009688',  # Darker turquoise for B.1.135
    'mAB (NEUT)': '#d62728',  # Red for NEUT
    'Polyreactive pAB': '#ff7f0e'  # Orange for Polyreactive pAB
}

plt.figure(figsize=(12, 6))

# Non-significant gray points
sns.scatterplot(
    data=df_non_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    alpha=0.6,
    s=14,
    zorder=2,
    edgecolor=None
)

# Significant colored points
sns.scatterplot(
    data=df_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.9,
    s=14,
    zorder=3,
    edgecolor=None,
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ]
)

# Threshold lines and a line at log2 = 0
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1, zorder=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1, zorder=1)
plt.axvline(0, linestyle='--', color='gray')

# Threshold labels (y = -5 in data units, below the y-limit set further down)
plt.text(low_thresh, -5, 'log2(0.5)', ha='center', va='center', fontsize=10, color='black')
plt.text(high_thresh, -5, 'log2(2)', ha='center', va='center', fontsize=10, color='black')

# Escape is the left (negative log2) side of the x-axis, binding the right side.
top_5_escape = df_filtered.nsmallest(5, 'log2_Enrichment')
top_5_binding = df_filtered.nlargest(5, 'log2_Enrichment')
top_5_sample = df_filtered.nlargest(5, 'Sample_Count')

# Concatenate and drop duplicates based on 'site_label' to avoid overlap
top_5_labels = pd.concat([top_5_escape, top_5_binding, top_5_sample]).drop_duplicates(subset=['site_label'])

# Candidate labels: sample count > 10 and extreme enrichment (> 2 or < -5)
count_thresh = df_filtered['Sample_Count'] > 10
enrichment_thresh = df_filtered['log2_Enrichment'] > 2
escape_thresh = df_filtered['log2_Enrichment'] < -5

df_labels = df_filtered[(count_thresh & (enrichment_thresh | escape_thresh))]

# Combine all candidate labels
labels_to_plot = pd.concat([top_5_labels, df_labels]).drop_duplicates(subset=['site_label'])

# Only candidates at these residues are actually drawn.
residues_to_label = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]

# Fixed label offsets. Escape labels go left and binding labels go right; every label is lifted
# by the same amount. (The previous "1.1 if count < 50 else 1.1" was a no-op branch.)
label_x_offset = 1.1
label_y_offset = 1.1

is_residue_of_interest = labels_to_plot['Spike_AS_Position'].isin(residues_to_label)
# Empty site labels (reference amino acid unknown) would draw a leader line to nothing.
has_text = labels_to_plot['site_label'].fillna('').astype(str) != ''

texts = []
for _, row in labels_to_plot[is_residue_of_interest & has_text].iterrows():
    x_offset = -label_x_offset if row['log2_Enrichment'] < 0 else label_x_offset
    label_x = row['log2_Enrichment'] + x_offset
    label_y = row['Sample_Count'] + label_y_offset

    texts.append(plt.text(
        label_x, label_y, row['site_label'],
        ha='center', va='center', fontsize=10, color='black'
    ))
    # Leader line from the data point to the initial label position
    plt.plot([row['log2_Enrichment'], label_x],
             [row['Sample_Count'], label_y], color='gray', linestyle='-', linewidth=1)

# Adjust text to avoid overlap ('text' is the adjustText key for label movement).
if texts:
    adjust_text(texts, only_move={'points': 'y', 'text': 'xy'}, force_text=0.2, arrowprops=dict(arrowstyle='-', color='gray'))

# Titles and labels
plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=14, labelpad=10)
plt.ylabel("Droplet Monoclonal Antibody repertoire", fontsize=14)

# Legend
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.ylim(bottom=0.5)

# Symmetric x-axis around 0
xlim = max(abs(df_filtered['log2_Enrichment'].min()), abs(df_filtered['log2_Enrichment'].max()))
plt.xlim(-xlim, xlim)

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text

# Keep only variants supported by more than 500 total reads.
df_filtered = df_filtered[df_filtered['Total_Reads'] > 500]

# Significance band on the log2 scale: ratios between 0.5 and 2 are "not significant".
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# Split significant and non-significant
df_non_sig = df_filtered[(df_filtered['log2_Enrichment'] >= low_thresh) & (df_filtered['log2_Enrichment'] <= high_thresh)]
df_sig = df_filtered[(df_filtered['log2_Enrichment'] < low_thresh) | (df_filtered['log2_Enrichment'] > high_thresh)]


def add_y_jitter(arr, jitter_strength=10, rng=None):
    """Return ``arr`` plus Gaussian noise with mean 0 and SD ``jitter_strength``.

    Pass a ``numpy.random.Generator`` as ``rng`` for reproducible jitter; when it is
    omitted the global ``np.random`` state is used.
    """
    noise_source = np.random if rng is None else rng
    return arr + noise_source.normal(0, jitter_strength, size=len(arr))


# Seeded generator so the jittered figure is identical on every run.
jitter_rng = np.random.default_rng(0)
df_non_sig_jittered = df_non_sig.copy()
df_sig_jittered = df_sig.copy()
df_non_sig_jittered['Sample_Count'] = add_y_jitter(df_non_sig_jittered['Sample_Count'], jitter_strength=0.1, rng=jitter_rng)
df_sig_jittered['Sample_Count'] = add_y_jitter(df_sig_jittered['Sample_Count'], jitter_strength=0.1, rng=jitter_rng)

# Darker color palette
palette = {
    'B cells - Ancestral Wuhan RBD vaccine': '#1f77b4',
    'B cells - B.1.135 RBD vaccine': '#ff7f0e',
    'mAB (NEUT)': '#2ca02c',
    'Polyreactive pAB': '#d62728'
}

plt.figure(figsize=(12, 6))

# Non-significant gray points
sns.scatterplot(
    data=df_non_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    alpha=0.6,
    s=12,
    zorder=2,
    edgecolor=None
)

# Significant colored points
sns.scatterplot(
    data=df_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.9,
    s=12,
    zorder=3,
    edgecolor=None,
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ]
)

# Threshold lines and a line at log2 = 0
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1, zorder=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1, zorder=1)
plt.axvline(0, linestyle='--', color='gray')

# Threshold labels (y = -5 in data units, below the plotted counts)
plt.text(low_thresh, -5, 'log2(0.5)', ha='center', va='center', fontsize=10, color='black')
plt.text(high_thresh, -5, 'log2(2)', ha='center', va='center', fontsize=10, color='black')

# Escape is the left (negative log2) side of the x-axis, binding the right side.
top_5_escape = df_filtered.nsmallest(5, 'log2_Enrichment')
top_5_binding = df_filtered.nlargest(5, 'log2_Enrichment')
top_5_sample = df_filtered.nlargest(5, 'Sample_Count')

# Concatenate and drop duplicates based on 'site_label' to avoid overlap
top_5_labels = pd.concat([top_5_escape, top_5_binding, top_5_sample]).drop_duplicates(subset=['site_label'])

# Additional labels: sample count > 10 and extreme enrichment (> 2 or < -4)
count_thresh = df_filtered['Sample_Count'] > 10
enrichment_thresh = df_filtered['log2_Enrichment'] > 2
escape_thresh = df_filtered['log2_Enrichment'] < -4

df_labels = df_filtered[(count_thresh & (enrichment_thresh | escape_thresh))]

# Combine all labels for plotting
labels_to_plot = pd.concat([top_5_labels, df_labels]).drop_duplicates(subset=['site_label'])

# Each site label is drawn at most once, so a residue already covered by
# labels_to_plot is not drawn a second time on top of itself.
texts = []
labelled_sites = set()


def _place_label(x, y, label):
    """Draw ``label`` at (x, y) unless it is empty or has already been drawn."""
    if not isinstance(label, str) or not label or label in labelled_sites:
        return
    labelled_sites.add(label)
    texts.append(plt.text(x, y, label, ha='center', va='center', fontsize=10, color='black'))


# The old fixed vertical stub per label is gone: it did not follow labels that
# adjust_text moves. The connector drawn by adjust_text (arrowprops) remains.
for _, row in labels_to_plot.iterrows():
    _place_label(row['log2_Enrichment'], row['Sample_Count'], row['site_label'])

# Custom residue labels
residues_to_label = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
for pos in residues_to_label:
    residue_data = df_filtered[df_filtered['Spike_AS_Position'] == pos]
    if not residue_data.empty:
        # NOTE(review): labels the first row at this position — confirm that is the intended mutation.
        first_row = residue_data.iloc[0]
        _place_label(first_row['log2_Enrichment'], first_row['Sample_Count'], first_row['site_label'])

# Adjust text to avoid overlap ('text' is the adjustText key for label movement).
if texts:
    adjust_text(texts, only_move={'points': 'y', 'text': 'xy'}, force_text=0.2, arrowprops=dict(arrowstyle='-', color='gray'))

# Titles and labels
plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=14, labelpad=20)
plt.ylabel("Sample Count", fontsize=14)

# Legend
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Symmetric x-axis, computed from this cell's data before use. Previously xlim was also
# applied before this line, reusing a stale value from an earlier cell.
xlim = max(abs(df_filtered['log2_Enrichment'].min()), abs(df_filtered['log2_Enrichment'].max()))
plt.xlim(-xlim, xlim)

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd

# Keep only variants supported by more than 500 total reads.
df_filtered = df_filtered[df_filtered['Total_Reads'] > 500]

# Significance band on the log2 scale: ratios between 0.5 and 2 are "not significant".
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# Split significant and non-significant
df_non_sig = df_filtered[(df_filtered['log2_Enrichment'] >= low_thresh) & (df_filtered['log2_Enrichment'] <= high_thresh)]
df_sig = df_filtered[(df_filtered['log2_Enrichment'] < low_thresh) | (df_filtered['log2_Enrichment'] > high_thresh)]


def add_y_jitter(arr, jitter_strength=10, rng=None):
    """Return ``arr`` plus Gaussian noise with mean 0 and SD ``jitter_strength``.

    Pass a ``numpy.random.Generator`` as ``rng`` for reproducible jitter; when it is
    omitted the global ``np.random`` state is used.
    """
    noise_source = np.random if rng is None else rng
    return arr + noise_source.normal(0, jitter_strength, size=len(arr))


# Seeded generator so the jittered figure is identical on every run.
jitter_rng = np.random.default_rng(0)
df_non_sig_jittered = df_non_sig.copy()
df_sig_jittered = df_sig.copy()
df_non_sig_jittered['Sample_Count'] = add_y_jitter(df_non_sig_jittered['Sample_Count'], jitter_strength=0.1, rng=jitter_rng)
df_sig_jittered['Sample_Count'] = add_y_jitter(df_sig_jittered['Sample_Count'], jitter_strength=0.1, rng=jitter_rng)

# Color palette
palette = {
    'B cells - Ancestral Wuhan RBD vaccine': '#1f77b4',  # Blue
    'B cells - B.1.135 RBD vaccine': '#40E0D0',  # Turquoise
    'mAB (NEUT)': '#FF0000',  # Red (Neutralizing antibody)
    'Polyreactive pAB': '#FFA500'  # Orange (Polyclonal)
}

plt.figure(figsize=(12, 6))

# Non-significant gray points
sns.scatterplot(
    data=df_non_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    alpha=0.6,
    s=14,
    zorder=2,
    edgecolor=None
)

# Significant colored points
sns.scatterplot(
    data=df_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.9,
    s=10,
    zorder=3,
    edgecolor=None,
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ]
)

# Threshold lines and a line at log2 = 0
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1, zorder=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1, zorder=1)
plt.axvline(0, linestyle='--', color='gray')

# Threshold labels (y = -5 in data units, below the y-limit set further down)
plt.text(low_thresh, -5, 'log2(0.5)', ha='center', va='center', fontsize=10, color='black')
plt.text(high_thresh, -5, 'log2(2)', ha='center', va='center', fontsize=10, color='black')

# Escape is the left (negative log2) side of the x-axis, binding the right side.
top_5_escape = df_filtered.nsmallest(5, 'log2_Enrichment')
top_5_binding = df_filtered.nlargest(5, 'log2_Enrichment')
top_5_sample = df_filtered.nlargest(5, 'Sample_Count')

# Concatenate and drop duplicates based on 'site_label' to avoid overlap
top_5_labels = pd.concat([top_5_escape, top_5_binding, top_5_sample]).drop_duplicates(subset=['site_label'])

# Additional labels: sample count > 4 and extreme enrichment (> 2 or < -2)
count_thresh = df_filtered['Sample_Count'] > 4
enrichment_thresh = (df_filtered['log2_Enrichment'] > 2) | (df_filtered['log2_Enrichment'] < -2)

df_labels = df_filtered[count_thresh & enrichment_thresh]

# Combine all labels for plotting
labels_to_plot = pd.concat([top_5_labels, df_labels]).drop_duplicates(subset=['site_label'])

# Escape labels are pushed left, binding labels right; both are lifted by the same amount.
label_offset = 0.5

texts = []
for _, row in labels_to_plot.iterrows():
    # Empty site labels (reference amino acid unknown) would draw a leader line to nothing.
    if not isinstance(row['site_label'], str) or not row['site_label']:
        continue
    direction = -1 if row['log2_Enrichment'] < 0 else 1
    label_x = row['log2_Enrichment'] + direction * label_offset
    label_y = row['Sample_Count'] + label_offset

    texts.append(plt.text(
        label_x, label_y, row['site_label'],
        ha='center', va='center', fontsize=10
    ))
    # Leader line from the data point to the initial label position
    plt.plot([row['log2_Enrichment'], label_x],
             [row['Sample_Count'], label_y], color='gray', linestyle='-', linewidth=1)

# Adjust text to avoid overlap ('text' is the adjustText key for label movement).
if texts:
    adjust_text(texts, only_move={'points': 'y', 'text': 'xy'}, force_text=0.2, arrowprops=dict(arrowstyle='-', color='gray'))

# Titles and labels (title intentionally left blank)
plt.title("", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=14, labelpad=20)
plt.ylabel("Single-Antibody Repertoire n=1x droplet", fontsize=14)
plt.ylim(bottom=0.5)

# Legend
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Fixed symmetric x-axis.
# NOTE(review): points with |log2 enrichment| > 6 fall outside this view — confirm that is intended.
xlim = 6
plt.xlim(-xlim, xlim)

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd

# Keep only variants supported by more than 500 total reads.
df_filtered = df_filtered[df_filtered['Total_Reads'] > 500]

# Significance band on the log2 scale: ratios between 0.5 and 2 are "not significant".
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# Split significant and non-significant
df_non_sig = df_filtered[(df_filtered['log2_Enrichment'] >= low_thresh) & (df_filtered['log2_Enrichment'] <= high_thresh)]
df_sig = df_filtered[(df_filtered['log2_Enrichment'] < low_thresh) | (df_filtered['log2_Enrichment'] > high_thresh)]


def add_y_jitter(arr, jitter_strength=10, rng=None):
    """Return ``arr`` plus Gaussian noise with mean 0 and SD ``jitter_strength``.

    Pass a ``numpy.random.Generator`` as ``rng`` for reproducible jitter; when it is
    omitted the global ``np.random`` state is used.
    """
    noise_source = np.random if rng is None else rng
    return arr + noise_source.normal(0, jitter_strength, size=len(arr))


# Seeded generator so the jittered figure is identical on every run.
jitter_rng = np.random.default_rng(0)
df_non_sig_jittered = df_non_sig.copy()
df_sig_jittered = df_sig.copy()
df_non_sig_jittered['Sample_Count'] = add_y_jitter(df_non_sig_jittered['Sample_Count'], jitter_strength=0.2, rng=jitter_rng)
df_sig_jittered['Sample_Count'] = add_y_jitter(df_sig_jittered['Sample_Count'], jitter_strength=0.2, rng=jitter_rng)

# Darker color palette
palette = {
    'B cells - Ancestral Wuhan RBD vaccine': '#1f77b4',
    'B cells - B.1.135 RBD vaccine': '#ff7f0e',
    'mAB (NEUT)': '#2ca02c',
    'Polyreactive pAB': '#d62728'
}

plt.figure(figsize=(12, 6))

# Non-significant gray points
sns.scatterplot(
    data=df_non_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    alpha=0.6,
    s=14,
    zorder=2,
    edgecolor=None
)

# Significant colored points
sns.scatterplot(
    data=df_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.9,
    s=10,
    zorder=3,
    edgecolor=None,
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ]
)

# Threshold lines and a line at log2 = 0
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1, zorder=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1, zorder=1)
plt.axvline(0, linestyle='--', color='gray')

# Threshold labels (y = -5 in data units, below the y-limit set further down)
plt.text(low_thresh, -5, 'log2(0.5)', ha='center', va='center', fontsize=10, color='black')
plt.text(high_thresh, -5, 'log2(2)', ha='center', va='center', fontsize=10, color='black')

# Escape is the left (negative log2) side of the x-axis, binding the right side.
top_5_escape = df_filtered.nsmallest(5, 'log2_Enrichment')
top_5_binding = df_filtered.nlargest(5, 'log2_Enrichment')
top_5_sample = df_filtered.nlargest(5, 'Sample_Count')

# Concatenate and drop duplicates based on 'site_label' to avoid overlap
top_5_labels = pd.concat([top_5_escape, top_5_binding, top_5_sample]).drop_duplicates(subset=['site_label'])

# Additional labels: sample count > 10 and extreme enrichment (> 2 or < -5)
count_thresh = df_filtered['Sample_Count'] > 10
enrichment_thresh = df_filtered['log2_Enrichment'] > 2
escape_thresh = df_filtered['log2_Enrichment'] < -5

df_labels = df_filtered[(count_thresh & (enrichment_thresh | escape_thresh))]

# Combine all labels for plotting
labels_to_plot = pd.concat([top_5_labels, df_labels]).drop_duplicates(subset=['site_label'])

texts = []
for _, row in labels_to_plot.iterrows():
    # Empty site labels (reference amino acid unknown) would draw a leader line to nothing.
    if not isinstance(row['site_label'], str) or not row['site_label']:
        continue
    # Escape labels go left, binding labels right; high-count points are lifted further.
    x_offset = -1.2 if row['log2_Enrichment'] < 0 else 1.2
    y_offset = 1.9 if row['Sample_Count'] < 50 else 3
    label_x = row['log2_Enrichment'] + x_offset
    label_y = row['Sample_Count'] + y_offset

    texts.append(plt.text(
        label_x, label_y, row['site_label'],
        ha='center', va='center', fontsize=10, color='black'
    ))
    # Leader line from the data point to the initial label position
    plt.plot([row['log2_Enrichment'], label_x],
             [row['Sample_Count'], label_y], color='gray', linestyle='-', linewidth=1)

# Adjust text to avoid overlap ('text' is the adjustText key for label movement).
if texts:
    adjust_text(texts, only_move={'points': 'y', 'text': 'xy'}, force_text=0.2, arrowprops=dict(arrowstyle='-', color='gray'))

# Titles and labels
plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=14, labelpad=20)
plt.ylabel("Monoclonal Antibody repertoire [n = 1x droplet]", fontsize=14)

# Legend
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.ylim(bottom=0.5)

# Symmetric x-axis around 0
xlim = max(abs(df_filtered['log2_Enrichment'].min()), abs(df_filtered['log2_Enrichment'].max()))
plt.xlim(-xlim, xlim)

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd

# Filter for total reads > 500
# (reassigns df_filtered; re-running with the same threshold leaves it unchanged)
df_filtered = df_filtered[df_filtered['Total_Reads'] > 500]

# Define significance thresholds
# Ratios between 0.5 and 2 (log2 -1 to +1) are treated as "not significant".
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# Split significant and non-significant
df_non_sig = df_filtered[(df_filtered['log2_Enrichment'] >= low_thresh) & (df_filtered['log2_Enrichment'] <= high_thresh)]
df_sig = df_filtered[(df_filtered['log2_Enrichment'] < low_thresh) | (df_filtered['log2_Enrichment'] > high_thresh)]

# Slight y-jitter function
# Returns arr plus Gaussian noise with SD jitter_strength.
# NOTE(review): draws from the unseeded global np.random state, so the figure changes on every run.
def add_y_jitter(arr, jitter_strength=10):
    return arr + np.random.normal(0, jitter_strength, size=len(arr))

# Apply y-jitter (only to the plotted copies; df_filtered keeps the exact counts)
df_non_sig_jittered = df_non_sig.copy()
df_sig_jittered = df_sig.copy()

df_non_sig_jittered['Sample_Count'] = add_y_jitter(df_non_sig_jittered['Sample_Count'], jitter_strength=0.15)
df_sig_jittered['Sample_Count'] = add_y_jitter(df_sig_jittered['Sample_Count'], jitter_strength=0.15)

# Corrected color palette
palette = {
    'B cells - Ancestral Wuhan RBD vaccine': '#1f77b4',  # Blue for Wuhan
    'B cells - B.1.135 RBD vaccine': '#00bfae',  # Turquoise for B.1.135
    'mAB (NEUT)': '#d62728',  # Red for NEUT
    'Polyreactive pAB': '#ff7f0e'  # Orange for Polyreactive pAB
}

# Start plotting
plt.figure(figsize=(12, 6))

# Non-significant gray points
sns.scatterplot(
    data=df_non_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    alpha=0.6,
    s=10,
    zorder=2,
    edgecolor=None
)

# Significant colored points
sns.scatterplot(
    data=df_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.9,
    s=10,
    zorder=3,
    edgecolor=None,
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ]
)

# Plot threshold lines
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1, zorder=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1, zorder=1)
plt.axvline(0, linestyle='--', color='gray')

# Threshold labels
# Placed at y = -5 in data coordinates, so they only show if that lies near or below the
# visible y-range — presumably drawn under the axis; verify.
plt.text(low_thresh, -5, 'log2(0.5)', ha='center', va='center', fontsize=10, color='black')
plt.text(high_thresh, -5, 'log2(2)', ha='center', va='center', fontsize=10, color='black')

# Label top 5 by escape, binding, and count
# NOTE(review): nlargest picks the most positive log2 values, which sit on the "Binding →"
# side of this notebook's x-axis, so the escape/binding variable names look swapped.
# The sets are presumably combined further down (outside this view) — confirm.
top_5_escape = df_filtered.nlargest(5, 'log2_Enrichment')
top_5_binding = df_filtered.nsmallest(5, 'log2_Enrichment')
top_5_sample = df_filtered.nlargest(5, 'Sample_Count')

# Concatenate and drop duplicates based on 'site_label' to avoid overlap
top_5_labels = pd.concat([top_5_escape, top_5_binding, top_5_sample]).drop_duplicates(subset=['site_label'])

# Thresholds for labeling: sample count > 10 and extreme enrichment
count_thresh = df_filtered['Sample_Count'] > 10
enrichment_thresh = df_filtered['log2_Enrichment'] > 2
escape_thresh = df_filtered['log2_Enrichment'] < -5

df_labels = df_filtered[(count_thresh & (enrichment_thresh | escape_thresh))]

# Combine all labels for plotting
labels_to_plot = pd.concat([top_5_labels, df_labels]).drop_duplicates(subset=['site_label'])

texts = []
for _, row in labels_to_plot.iterrows():
    # Determine whether the label should be placed on the left or right
    if row['log2_Enrichment'] < 0:
        x_offset = -1.1  # Move labels to the left
    else:
        x_offset = 1.1  # Move labels to the right

    # Adjust the y position to avoid overlap
    y_offset = 1.1 if row['Sample_Count'] < 50 else 1.1

    # Plot the label with the adjusted position
    text = plt.text(
        row['log2_Enrichment'] + x_offset, row['Sample_Count'] + y_offset, row['site_label'],
        ha='center', va='center', fontsize=10, color='black'
    )
    texts.append(text)
    plt.plot([row['log2_Enrichment'], row['log2_Enrichment'] + x_offset], 
             [row['Sample_Count'], row['Sample_Count'] + y_offset], color='gray', linestyle='-', linewidth=1)

# Adjust text to avoid overlap
adjust_text(texts, only_move={'points': 'y', 'texts': 'xy'}, force_text=0.2, arrowprops=dict(arrowstyle='-', color='gray'))

# Titles and labels
plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=14, labelpad=20)
plt.ylabel("Droplet Monoclonal Antibody repertoire", fontsize=14)

# Legend
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.ylim(bottom=0.5)
plt.xlim(-6, 6)
# Symmetric x-axis
xlim = max(abs(df_filtered['log2_Enrichment'].min()), abs(df_filtered['log2_Enrichment'].max()))
plt.xlim(-xlim, xlim)

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

# ------------------ Data Preparation ------------------

# Expects `df_volcano` from an earlier-executed cell, with Spike_AS_Position, Amino_Acid,
# immunization, log2_Enrichment and Sample_Count columns.
# NOTE(review): df_volcano is built further down in this notebook, so this cell fails on
# Restart & Run All; consider moving it below the cell that creates df_volcano.
# Drop stop codons ('*') and gaps/deletions ('-'). The .copy() makes the later column
# assignments operate on an independent frame instead of a slice (SettingWithCopyWarning).
df_volcano = df_volcano[~df_volcano['Amino_Acid'].str.contains(r'[\*\-]', regex=True)].copy()

key_cols = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# One row per mutation/immunization. Some df_volcano variants repeat a mutation once per
# source row with identical (x, y); plotting/labeling those only stacks duplicates.
candidates = df_volcano.drop_duplicates(subset=key_cols)

# Top 10 enriched, escaped, and sample count mutations
top_enriched = candidates.nlargest(10, 'log2_Enrichment')
# Top 10 escaped (most negative enrichment)
top_escaped = candidates.nsmallest(10, 'log2_Enrichment')
# Top 10 by sample count
top_sample = candidates.nlargest(10, 'Sample_Count')

# Combine top 10 from each
label_indices = pd.concat([top_enriched, top_escaped, top_sample]).index.unique()
df_volcano['label_this'] = False
df_volcano.loc[label_indices, 'label_this'] = True

# Rows to annotate. Use 'site_label' when present; otherwise (e.g. the key-site df_volcano
# variant, which has no such column) fall back to "<position><amino acid>".
df_labelled = df_volcano.loc[df_volcano['label_this']].drop_duplicates(subset=key_cols).copy()
if 'site_label' in df_labelled.columns:
    df_labelled['_label'] = df_labelled['site_label']
else:
    df_labelled['_label'] = df_labelled['Spike_AS_Position'].astype(str) + df_labelled['Amino_Acid'].astype(str)

# ------------------ Plot ------------------

plt.figure(figsize=(12, 6))
sns.scatterplot(
    data=candidates,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    alpha=0.8,
    edgecolor='black'
)

# Add directional labels with arrows: escape (negative) labels to the left, binding to the right.
for _, row in df_labelled.iterrows():
    if row['log2_Enrichment'] >= 0:
        offset = (10, 5)  # points; small so labels stay within the plot
        ha = 'left'
    else:
        offset = (-10, 5)
        ha = 'right'

    plt.annotate(
        row['_label'],
        xy=(row['log2_Enrichment'], row['Sample_Count']),
        xytext=offset,
        textcoords='offset points',
        ha=ha,
        va='bottom',
        fontsize=8,
        arrowprops=dict(arrowstyle='-', lw=0.6, color='gray'),
        bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7),
        clip_on=True  # Keep inside axes limits
    )

# Vertical line for the center (log2(1) = 0)
plt.axvline(0, linestyle='--', color='gray')

# Title and labels
plt.title("Volcano Plot: Escape vs Binding Mutations")
plt.xlabel("log2(Enrichment Ratio) (← Escape   |   Binding →)")
plt.ylabel("Sample Count")
plt.legend(title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Adjust layout and display plot
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from Bio import SeqIO
from Bio.Seq import Seq

# Input files. NOTE(review): absolute local paths; move to a config cell / DATA_DIR.
fasta_file = r'/Users/lucaschlotheuber/Desktop/ETH/RBD201_DMS1.fa'
file_path = r'/Users/lucaschlotheuber/Desktop/ETH/summary_DMS_cleaned.xlsx'

# Reference RBD sequence: only the first FASTA record is used.
for record in SeqIO.parse(fasta_file, "fasta"):
    wuhan_sequence = str(record.seq)
    break

df_total = pd.read_excel(
    file_path,
    usecols=[
        "DMS_RBD_AS_position", "Spike_AS_Position", "Count_of_Base", "Amino_Acid",
        "Type_of_Mutation", "Enrichment_Ratio", "barcode", "immunization",
        "condition", "Total_Reads", "Nucleotide_Ref", "Codon_Change"
    ]
)

# Spike positions in the Excel sheet are shifted by +5 (they start at 336 instead of 331).
df_total["Spike_AS_Position"] = df_total["Spike_AS_Position"] - 5

# Treat +/-inf enrichment ratios as missing (dropna alone keeps inf), drop rows without a
# ratio or amino acid, and keep only rows with more than 1,000 total reads.
df_total["Enrichment_Ratio"] = df_total["Enrichment_Ratio"].replace([np.inf, -np.inf], np.nan)
df_total = df_total.dropna(subset=['Enrichment_Ratio', 'Amino_Acid'])
df_total = df_total[df_total["Total_Reads"] > 1000]

immunization = "Wuhan_Sequence"
barcode = "Wuhan_Barcode"

# Wild-type reference rows, one per RBD residue (Spike numbering = RBD position + 330).
# NOTE(review): data_wuhan is not used within this cell.
data_wuhan = [
    {
        'DMS_RBD_AS_position': position,
        'Spike_AS_Position': position + 330,
        'Amino_Acid': amino_acid,
        'immunization': immunization,
        'barcode': barcode,
        'Enrichment_Ratio': 1,  # wild type gets a neutral ratio of 1
    }
    for position, amino_acid in enumerate(wuhan_sequence, start=1)
]

# Escape rows: 0 < Enrichment_Ratio < 1 (log2 of the inverse is undefined for ratios <= 0).
df_escape = df_total[(df_total['Enrichment_Ratio'] < 1) & (df_total['Enrichment_Ratio'] > 0)].copy()

# Invert and log-transform Enrichment_Ratio
df_escape['Enrichment_Ratio_inverted'] = 1 / df_escape['Enrichment_Ratio']
df_escape['Enrichment_Ratio_log2'] = np.log2(df_escape['Enrichment_Ratio_inverted'])

agg_keys = ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization']

# Extract relevant info (avoid duplication of Nucleotide_Ref, Codon_Change per group)
df_info = df_escape[agg_keys + ['Nucleotide_Ref', 'Codon_Change']].drop_duplicates()

# Mean inverted log2 escape per mutation and barcode, with the codon info merged back in.
df_escape_agg = df_escape.groupby(agg_keys, as_index=False).agg({'Enrichment_Ratio_log2': 'mean'})
df_escape_agg = pd.merge(df_escape_agg, df_info, on=agg_keys, how='left')

# Mean Enrichment_Ratio per mutation and barcode (codon info assumed constant per group).
df_logo_agg = df_total[agg_keys + ['Enrichment_Ratio', 'Nucleotide_Ref', 'Codon_Change']]
df_logo_agg = df_logo_agg.groupby(agg_keys, as_index=False).agg({
    'Enrichment_Ratio': 'mean',
    'Nucleotide_Ref': 'first',
    'Codon_Change': 'first'
})

# Combine binding (per-barcode mean ER > 1) and escape (row-level ER < 1) observations.
# NOTE(review): the halves differ in granularity (df_logo_agg is aggregated per barcode,
# df_escape is row-level), so Sample_Count below counts rows, not distinct barcodes, and
# may be inflated for escape mutations. Confirm this is intended.
df_combined_volcano = pd.concat([
    df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1],
    df_escape,
])

necessary_columns = ['Nucleotide_Ref', 'Codon_Change']
missing_columns = [c for c in necessary_columns if c not in df_combined_volcano.columns]
assert not missing_columns, f"df_combined_volcano is missing columns: {missing_columns}"

# log2 is undefined for ratios <= 0; .copy() avoids chained-assignment warnings.
df_combined_volcano = df_combined_volcano[df_combined_volcano['Enrichment_Ratio'] > 0].copy()
df_combined_volcano['log2_Enrichment'] = np.log2(df_combined_volcano['Enrichment_Ratio'])

group_cols = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Number of rows per mutation/immunization.
df_counts = (
    df_combined_volcano
    .groupby(group_cols)
    .size()
    .reset_index(name='Sample_Count')
)

# Mean (not max, despite the name) log2 enrichment per mutation/immunization.
df_max_enrichment = (
    df_combined_volcano
    .groupby(group_cols)['log2_Enrichment']
    .mean()
    .reset_index(name='log2_Enrichment_max')
)

# Attach group-level count and mean enrichment to every row. The row-level log2_Enrichment
# is dropped first; renaming the group mean onto an existing column produced two columns
# with the same label, which made row['log2_Enrichment'] a Series further down.
df_combined_volcano = (
    df_combined_volcano
    .drop(columns='log2_Enrichment')
    .merge(df_counts, on=group_cols, how='left')
    .merge(df_max_enrichment, on=group_cols, how='left')
    .rename(columns={'log2_Enrichment_max': 'log2_Enrichment'})
)

# Assign label for volcano plot
df_combined_volcano['site_label'] = df_combined_volcano['Amino_Acid'] + "_" + df_combined_volcano['Spike_AS_Position'].astype(str)

# Filter to positions of interest
df_combined_volcano = df_combined_volcano[df_combined_volcano['Spike_AS_Position'] > 365]


def get_amino_acid(codon):
    """Translate a nucleotide codon to its one-letter amino acid; '?' if translation fails."""
    try:
        return str(Seq(codon).translate())
    except Exception:
        return '?'  # invalid/missing codon


def _reference_codon(codon_change, nucleotide_ref):
    """Rebuild the pre-mutation codon from a mutated codon.

    Assumes (TODO confirm against the Excel export) that the mutated base in
    `codon_change` is written in uppercase; it is replaced by `nucleotide_ref`.
    Returns None when either input is missing or not a string.
    """
    if not isinstance(codon_change, str) or not isinstance(nucleotide_ref, str):
        return None
    return ''.join(nucleotide_ref if base.isupper() else base for base in codon_change)


# ------------------ Data Preparation ------------------

# Drop stop codons ('*') and gaps/deletions ('-').
df_combined_volcano = df_combined_volcano[
    ~df_combined_volcano['Amino_Acid'].str.contains(r'[\*\-]', regex=True)
].copy()

# Reference codon/amino acid. Previously Nucleotide_Ref + Codon_Change was concatenated into
# a 4-nt string, so the translated "reference" came from the wrong bases.
df_combined_volcano['Reference_Codon'] = [
    _reference_codon(change, ref)
    for change, ref in zip(df_combined_volcano['Codon_Change'], df_combined_volcano['Nucleotide_Ref'])
]
df_combined_volcano['Reference_Amino_Acid'] = df_combined_volcano['Reference_Codon'].map(
    lambda codon: get_amino_acid(codon) if codon else '?'
)

# Mutation label "<ref AA><position><mutant AA>" (e.g. E484K); empty when AA is unchanged.
is_substitution = df_combined_volcano['Amino_Acid'] != df_combined_volcano['Reference_Amino_Acid']
df_combined_volcano['mutation_label'] = np.where(
    is_substitution,
    df_combined_volcano['Reference_Amino_Acid'].astype(str)
    + df_combined_volcano['Spike_AS_Position'].astype(str)
    + df_combined_volcano['Amino_Acid'].astype(str),
    '',
)

# One plotted point per mutation/immunization: all rows in a group share the same
# Sample_Count and mean log2_Enrichment, so the extra rows would only overplot and
# produce stacked duplicate labels.
df_combined_volcano['label_this'] = False
df_plot = df_combined_volcano.drop_duplicates(subset=group_cols)

# Top 10 enriched (binding), escaped (most negative enrichment), and by sample count
top_enriched = df_plot.nlargest(10, 'log2_Enrichment')
top_escaped = df_plot.nsmallest(10, 'log2_Enrichment')
top_sample = df_plot.nlargest(10, 'Sample_Count')

label_indices = pd.concat([top_enriched, top_escaped, top_sample]).index.unique()
df_combined_volcano.loc[label_indices, 'label_this'] = True
df_plot = df_combined_volcano.loc[df_plot.index]

# ------------------ Plot ------------------

plt.figure(figsize=(12, 6))
sns.scatterplot(
    data=df_plot,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    alpha=0.8,
    edgecolor='black'
)

# Directional labels: escape (negative) to the left, binding (positive) to the right.
for _, row in df_plot[df_plot['label_this']].iterrows():
    if row['log2_Enrichment'] >= 0:
        offset = (10, 5)  # points; small so labels stay within the plot
        ha = 'left'
    else:
        offset = (-10, 5)
        ha = 'right'

    plt.annotate(
        row['mutation_label'],
        xy=(row['log2_Enrichment'], row['Sample_Count']),
        xytext=offset,
        textcoords='offset points',
        ha=ha,
        va='bottom',
        fontsize=8,
        arrowprops=dict(arrowstyle='-', lw=0.6, color='gray'),
        bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7),
        clip_on=True  # Keep inside axes limits
    )

# Vertical line for the center (log2(1) = 0)
plt.axvline(0, linestyle='--', color='gray')

# Title and labels
plt.title("Volcano Plot: Escape vs Binding Mutations")
plt.xlabel("log2(Enrichment Ratio) (← Escape   |   Binding →)")
plt.ylabel("Sample Count")
plt.legend(title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Adjust layout and display plot
plt.tight_layout()
plt.show()



In [None]:
# Volcano plot

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

group_cols = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Combine binding (ER > 1) and escape observations; df_escape already holds the ER < 1 rows,
# so it is concatenated as-is.
df_combined_volcano = pd.concat([
    df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1],
    df_escape,
])

# Drop zero/negative ratios (log2 undefined). The .copy() avoids chained-assignment warnings.
df_combined_volcano = df_combined_volcano[df_combined_volcano['Enrichment_Ratio'] > 0].copy()

# Add log2 enrichment
df_combined_volcano['log2_Enrichment'] = np.log2(df_combined_volcano['Enrichment_Ratio'])

# Number of rows per mutation/immunization.
df_counts = (
    df_combined_volcano
    .groupby(group_cols)
    .size()
    .reset_index(name='Sample_Count')
)

# Mean log2 enrichment per mutation/immunization (the variable name says "max" for history).
df_max_enrichment = (
    df_combined_volcano
    .groupby(group_cols)['log2_Enrichment']
    .mean()
    .reset_index(name='log2_Enrichment')
)

# Merge count and enrichment
df_volcano = pd.merge(df_counts, df_max_enrichment, on=group_cols)

# Filter to positions of interest; .copy() so the label columns below are set on a real frame.
df_volcano = df_volcano[df_volcano['Spike_AS_Position'] > 365].copy()

# ------- E484K-style mutation labels for specific positions -------

# Ancestral (Wuhan) residues at key positions.
# NOTE: 346 is below the > 365 position filter above, so it never matches here.
wuhan_strain_aa = {
    417: 'K',
    452: 'L',
    484: 'E',
    501: 'N',
    346: 'R',
    # Add more if needed
}

# Only real substitutions at the key sites get a label; rows whose amino acid equals the
# ancestral residue (e.g. "E484E") are not mutations and are left unlabelled.
ref_aa = df_volcano['Spike_AS_Position'].map(wuhan_strain_aa)
is_key_site_mutant = ref_aa.notna() & (df_volcano['Amino_Acid'] != ref_aa)

df_volcano['mutation_label'] = np.where(
    is_key_site_mutant,
    ref_aa.fillna('') + df_volcano['Spike_AS_Position'].astype(str) + df_volcano['Amino_Acid'].astype(str),
    '',
)

positions_to_label = set(wuhan_strain_aa.keys())
df_volcano['label_this'] = is_key_site_mutant


# ------------------ PLOT ------------------

plt.figure(figsize=(12, 6))
sns.scatterplot(
    data=df_volcano,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    alpha=0.8,
    edgecolor='black'
)

# Text labels for the selected mutations, placed just above their points.
for _, row in df_volcano[df_volcano['label_this']].iterrows():
    plt.text(
        row['log2_Enrichment'],
        row['Sample_Count'] + 0.5,
        row['mutation_label'],
        fontsize=12,
        ha='center',
        va='bottom'
    )

plt.axvline(0, linestyle='--', color='gray')  # Center at log2(1) = 0
plt.title("Volcano Plot: Escape vs Binding Mutations")
plt.xlabel("log2(Enrichment Ratio) (← Escape   |   Binding →)")
plt.ylabel("Sample Count")
plt.legend(title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text
from Bio.Seq import Seq


def get_reference_aa(codon_change, nucleotide_ref):
    """Translate the reference (pre-mutation) codon of one DMS row.

    Assumes (TODO confirm against the Excel export) that the mutated base in
    `codon_change` is written in uppercase; it is replaced by `nucleotide_ref`
    to rebuild the reference codon. Returns the translation as a Bio.Seq.Seq.
    """
    original_codon = ''.join(
        nucleotide_ref if base.isupper() else base for base in codon_change
    )
    return Seq(original_codon).translate()


group_cols = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Combine binding (ER > 1) and escape observations; df_escape already holds the ER < 1 rows.
df_combined_volcano = pd.concat([
    df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1],
    df_escape,
])

# log2 is undefined for ratios <= 0; .copy() avoids chained-assignment warnings.
df_combined_volcano = df_combined_volcano[df_combined_volcano['Enrichment_Ratio'] > 0].copy()
df_combined_volcano['log2_Enrichment'] = np.log2(df_combined_volcano['Enrichment_Ratio'])

# Retain all columns from the original data frame
original_columns = df_combined_volcano.columns.tolist()

# Count appearances (rows) per mutation/immunization
df_counts = (
    df_combined_volcano
    .groupby(group_cols)
    .size()
    .reset_index(name='Sample_Count')
)

# Get mean log2 enrichment
df_max_enrichment = (
    df_combined_volcano
    .groupby(group_cols)['log2_Enrichment']
    .mean()
    .reset_index(name='log2_Enrichment')
)

# Merge counts and enrichment data
df_volcano = pd.merge(df_counts, df_max_enrichment, on=group_cols)

# Re-attach the row-level columns: one row per original observation, each carrying the
# group-level Sample_Count and mean log2_Enrichment (the row-level value is discarded).
df_volcano = (
    pd.merge(df_volcano, df_combined_volcano[original_columns], on=group_cols, how='left')
    .drop(columns='log2_Enrichment_y')
    .rename(columns={'log2_Enrichment_x': 'log2_Enrichment'})
)

# Reference amino acid from the codon information
df_volcano['Reference_Amino_Acid'] = df_volcano.apply(
    lambda row: get_reference_aa(row['Codon_Change'], row['Nucleotide_Ref'])
    if pd.notna(row['Codon_Change']) and pd.notna(row['Nucleotide_Ref']) else np.nan,
    axis=1
)

# Create labels like N501Y
df_volcano['site_label'] = df_volcano.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Exclude synonymous mutations and stop codons (.copy() so the new column is set on a real frame)
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]
stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    ~df_volcano['Amino_Acid'].isin(stop_codons) & ~df_volcano['is_synonymous']
].copy()

# Map immunization codes to display categories (unmapped codes stay missing).
# NOTE(review): "B.1.135" is probably a typo for B.1.351 (Beta); renaming it requires
# updating every palette / hue_order in the notebook at the same time.
category_by_immunization = {
    'Neutralizing_Ab': 'mAB (NEUT)',
    'Polyclonal_Ab': 'Polyreactive pAB',
    'Mutant_RBD': 'B cells - B.1.135 RBD vaccine',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
}
df_filtered['Category'] = df_filtered['immunization'].map(category_by_immunization)

# Category colors, matching the "Corrected color palette" used elsewhere in this notebook
# (Wuhan = blue, B.1.135 = turquoise).
palette = {
    'mAB (NEUT)': 'red',
    'Polyreactive pAB': 'orange',
    'B cells - B.1.135 RBD vaccine': 'turquoise',
    'B cells - Ancestral Wuhan RBD vaccine': 'blue'
}

# Create the volcano plot
plt.figure(figsize=(12, 6))

# Shade non-significant region (0.5–2 enrichment ratio in log2 space) with 50% transparency
plt.axvspan(np.log2(0.5), np.log2(2), color='gray', alpha=0.5, zorder=2)

# Black dashed borders of the non-significant region
plt.plot([np.log2(0.5), np.log2(0.5)], [0, df_filtered['Sample_Count'].max()], 'k--', lw=2, zorder=3)
plt.plot([np.log2(2), np.log2(2)], [0, df_filtered['Sample_Count'].max()], 'k--', lw=2, zorder=3)

# Label the intersection of the non-significant area with the x-axis
plt.text(np.log2(0.5), -5, 'log2(0.5)', ha='center', va='center', fontsize=12, color='black')
plt.text(np.log2(2), -5, 'log2(2)', ha='center', va='center', fontsize=12, color='black')

# Plot the data with custom colors for each category
sns.scatterplot(
    data=df_filtered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.8,
    edgecolor='black',
    s=30,
    zorder=4,
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ]
)

# Pick labels from one row per mutation/immunization; df_filtered repeats each mutation once
# per source row, which would otherwise fill the top-10 lists with copies of one mutation.
label_pool = df_filtered.drop_duplicates(subset=group_cols)

# Top 10 per side (x-axis: left = escape, right = binding) and by count
top_10_binding = label_pool.nlargest(10, 'log2_Enrichment')
top_10_escape = label_pool.nsmallest(10, 'log2_Enrichment')
top_10_sample = label_pool.nlargest(10, 'Sample_Count')

# Concatenate the dataframes and drop duplicates to ensure unique labels
top_10_labels = pd.concat([top_10_binding, top_10_escape, top_10_sample]).drop_duplicates(subset=['site_label'])

# Add labels
texts = []
for _, row in top_10_labels.iterrows():
    text = plt.text(
        row['log2_Enrichment'], row['Sample_Count'], row['site_label'],
        ha='center', va='center', fontsize=10
    )
    texts.append(text)

# Labels for specific residues (first mutation found at each position). They are added to
# `texts` so adjust_text can move them, and skipped when already labelled above.
already_labelled = set(top_10_labels['site_label'])
residues_to_label = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
for pos in residues_to_label:
    residue_data = label_pool[label_pool['Spike_AS_Position'] == pos]
    if residue_data.empty:
        continue
    first_row = residue_data.iloc[0]
    if first_row['site_label'] in already_labelled:
        continue
    already_labelled.add(first_row['site_label'])
    texts.append(plt.text(
        first_row['log2_Enrichment'],
        first_row['Sample_Count'],
        f"{first_row['site_label']}",
        fontsize=12, ha='center', va='center', color='black'
    ))

# Use adjustText to avoid overlap ("army_strength" is not an adjust_text parameter; removed)
adjust_text(texts, only_move={'points': 'y', 'texts': 'xy'}, force_text=0.1)

# Adjust x-axis title to avoid overlap with shaded area
plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=14, labelpad=20)
plt.ylabel("Sample Count", fontsize=14)

# Custom legend
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Adjust layout
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text
from Bio.Seq import Seq


def get_reference_aa(codon_change, nucleotide_ref):
    """Translate the reference (pre-mutation) codon of one DMS row.

    Assumes (TODO confirm against the Excel export) that the mutated base in
    `codon_change` is written in uppercase; it is replaced by `nucleotide_ref`
    to rebuild the reference codon. Returns the translation as a Bio.Seq.Seq.
    """
    original_codon = ''.join(
        nucleotide_ref if base.isupper() else base for base in codon_change
    )
    return Seq(original_codon).translate()


group_cols = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Combine binding (ER > 1) and escape observations; df_escape already holds the ER < 1 rows.
df_combined_volcano = pd.concat([
    df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1],
    df_escape,
])

# log2 is undefined for ratios <= 0; .copy() avoids chained-assignment warnings.
df_combined_volcano = df_combined_volcano[df_combined_volcano['Enrichment_Ratio'] > 0].copy()
df_combined_volcano['log2_Enrichment'] = np.log2(df_combined_volcano['Enrichment_Ratio'])

# Retain all columns from the original data frame
original_columns = df_combined_volcano.columns.tolist()

# Count appearances (rows) per mutation/immunization
df_counts = (
    df_combined_volcano
    .groupby(group_cols)
    .size()
    .reset_index(name='Sample_Count')
)

# Get mean log2 enrichment
df_max_enrichment = (
    df_combined_volcano
    .groupby(group_cols)['log2_Enrichment']
    .mean()
    .reset_index(name='log2_Enrichment')
)

# Merge counts and enrichment data
df_volcano = pd.merge(df_counts, df_max_enrichment, on=group_cols)

# Re-attach the row-level columns: one row per original observation, each carrying the
# group-level Sample_Count and mean log2_Enrichment (the row-level value is discarded).
df_volcano = (
    pd.merge(df_volcano, df_combined_volcano[original_columns], on=group_cols, how='left')
    .drop(columns='log2_Enrichment_y')
    .rename(columns={'log2_Enrichment_x': 'log2_Enrichment'})
)

# Reference amino acid from the codon information
df_volcano['Reference_Amino_Acid'] = df_volcano.apply(
    lambda row: get_reference_aa(row['Codon_Change'], row['Nucleotide_Ref'])
    if pd.notna(row['Codon_Change']) and pd.notna(row['Nucleotide_Ref']) else np.nan,
    axis=1
)

# Create labels like N501Y
df_volcano['site_label'] = df_volcano.apply(
    lambda row: f"{row['Reference_Amino_Acid']}{row['Spike_AS_Position']}{row['Amino_Acid']}"
    if pd.notna(row['Reference_Amino_Acid']) and pd.notna(row['Amino_Acid']) else '',
    axis=1
)

# Exclude synonymous mutations and stop codons (.copy() so the new column is set on a real frame)
df_volcano = df_volcano[df_volcano['Amino_Acid'].notna()].copy()
df_volcano['is_synonymous'] = df_volcano['site_label'].str[0] == df_volcano['site_label'].str[-1]
stop_codons = ['*', 'Stop', 'X']
df_filtered = df_volcano[
    ~df_volcano['Amino_Acid'].isin(stop_codons) & ~df_volcano['is_synonymous']
].copy()

# Map immunization codes to display categories (unmapped codes stay missing).
# NOTE(review): "B.1.135" is probably a typo for B.1.351 (Beta); renaming it requires
# updating every palette / hue_order in the notebook at the same time.
category_by_immunization = {
    'Neutralizing_Ab': 'mAB (NEUT)',
    'Polyclonal_Ab': 'Polyreactive pAB',
    'Mutant_RBD': 'B cells - B.1.135 RBD vaccine',
    'wildtype_RBD': 'B cells - Ancestral Wuhan RBD vaccine',
}
df_filtered['Category'] = df_filtered['immunization'].map(category_by_immunization)

# Category colors, matching the "Corrected color palette" used elsewhere in this notebook
# (Wuhan = blue, B.1.135 = turquoise).
palette = {
    'mAB (NEUT)': 'red',
    'Polyreactive pAB': 'orange',
    'B cells - B.1.135 RBD vaccine': 'turquoise',
    'B cells - Ancestral Wuhan RBD vaccine': 'blue'
}

# Create the volcano plot
plt.figure(figsize=(12, 6))

# Shade non-significant region (0.5–2 enrichment ratio in log2 space) with 50% transparency
plt.axvspan(np.log2(0.5), np.log2(2), color='gray', alpha=0.5, zorder=2)

# Black dashed borders of the non-significant region
plt.plot([np.log2(0.5), np.log2(0.5)], [0, df_filtered['Sample_Count'].max()], 'k--', lw=2, zorder=3)
plt.plot([np.log2(2), np.log2(2)], [0, df_filtered['Sample_Count'].max()], 'k--', lw=2, zorder=3)

# Label the intersection of the non-significant area with the x-axis
plt.text(np.log2(0.5), -5, 'log2(0.5)', ha='center', va='center', fontsize=12, color='black')
plt.text(np.log2(2), -5, 'log2(2)', ha='center', va='center', fontsize=12, color='black')

# Plot the data with custom colors for each category (smaller markers than the s=30 variant)
sns.scatterplot(
    data=df_filtered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.8,
    edgecolor='black',
    s=15,
    zorder=4,
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ]
)

# Pick labels from one row per mutation/immunization; df_filtered repeats each mutation once
# per source row, which would otherwise fill the top-10 lists with copies of one mutation.
label_pool = df_filtered.drop_duplicates(subset=group_cols)

# Top 10 per side (x-axis: left = escape, right = binding) and by count
top_10_binding = label_pool.nlargest(10, 'log2_Enrichment')
top_10_escape = label_pool.nsmallest(10, 'log2_Enrichment')
top_10_sample = label_pool.nlargest(10, 'Sample_Count')

# Concatenate the dataframes and drop duplicates to ensure unique labels
top_10_labels = pd.concat([top_10_binding, top_10_escape, top_10_sample]).drop_duplicates(subset=['site_label'])

# Add labels
texts = []
for _, row in top_10_labels.iterrows():
    text = plt.text(
        row['log2_Enrichment'], row['Sample_Count'], row['site_label'],
        ha='center', va='center', fontsize=10
    )
    texts.append(text)

# Labels for specific residues (first mutation found at each position). They are added to
# `texts` so adjust_text can move them, and skipped when already labelled above.
already_labelled = set(top_10_labels['site_label'])
residues_to_label = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
for pos in residues_to_label:
    residue_data = label_pool[label_pool['Spike_AS_Position'] == pos]
    if residue_data.empty:
        continue
    first_row = residue_data.iloc[0]
    if first_row['site_label'] in already_labelled:
        continue
    already_labelled.add(first_row['site_label'])
    texts.append(plt.text(
        first_row['log2_Enrichment'],
        first_row['Sample_Count'],
        f"{first_row['site_label']}",
        fontsize=12, ha='center', va='center', color='black'
    ))

# Use adjustText to avoid overlap ("army_strength" is not an adjust_text parameter; removed)
adjust_text(texts, only_move={'points': 'y', 'texts': 'xy'}, force_text=0.1)

# Adjust x-axis title to avoid overlap with shaded area
plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=14, labelpad=20)
plt.ylabel("Sample Count", fontsize=14)

# Custom legend
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Adjust layout
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text

# Fixed seed so the y-jitter (and therefore the figure) is identical on every re-run.
JITTER_SEED = 0
rng = np.random.default_rng(JITTER_SEED)

# Keep rows with > 500 total reads. Rows that originate from df_logo_agg (binding side,
# Enrichment_Ratio > 1) carry no Total_Reads column in the df_logo_agg cell, so they arrive
# here as NaN; a bare `> 500` comparison is False for NaN and silently removed every binding
# mutation. Those rows were aggregated from df_total after its > 1,000-read filter, so keep them.
df_filtered = df_filtered[df_filtered['Total_Reads'].isna() | (df_filtered['Total_Reads'] > 500)]

# Significance thresholds: a 2-fold change in either direction, in log2 space.
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# Split significant and non-significant
df_non_sig = df_filtered[(df_filtered['log2_Enrichment'] >= low_thresh) & (df_filtered['log2_Enrichment'] <= high_thresh)]
df_sig = df_filtered[(df_filtered['log2_Enrichment'] < low_thresh) | (df_filtered['log2_Enrichment'] > high_thresh)]


def add_y_jitter(arr, jitter_strength=0.1, rng=None):
    """Return `arr` plus Gaussian noise N(0, jitter_strength) to separate overlapping points.

    Parameters
    ----------
    arr : array-like or pd.Series
        Values to jitter.
    jitter_strength : float
        Standard deviation of the added noise.
    rng : np.random.Generator, optional
        Source of randomness. When None the legacy global ``np.random`` state is used,
        which is only reproducible if seeded elsewhere.
    """
    normal = rng.normal if rng is not None else np.random.normal
    return arr + normal(0, jitter_strength, size=len(arr))


# Apply y-jitter (seeded) to copies so df_non_sig / df_sig keep the raw counts.
# NOTE(review): sd = 10 is large relative to small Sample_Count values and can push points below 0.
df_non_sig_jittered = df_non_sig.copy()
df_sig_jittered = df_sig.copy()

df_non_sig_jittered['Sample_Count'] = add_y_jitter(df_non_sig_jittered['Sample_Count'], jitter_strength=10, rng=rng)
df_sig_jittered['Sample_Count'] = add_y_jitter(df_sig_jittered['Sample_Count'], jitter_strength=10, rng=rng)

# Darker color palette
palette = {
    'B cells - Ancestral Wuhan RBD vaccine': '#1f77b4',
    'B cells - B.1.135 RBD vaccine': '#ff7f0e',
    'mAB (NEUT)': '#2ca02c',
    'Polyreactive pAB': '#d62728'
}

# Start plotting
plt.figure(figsize=(12, 6))

# Non-significant gray points
sns.scatterplot(
    data=df_non_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    alpha=0.6,
    s=14,
    zorder=2,
    edgecolor=None
)

# Significant colored points
sns.scatterplot(
    data=df_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.9,
    s=14,
    zorder=3,
    edgecolor=None,
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ]
)

# Plot threshold lines
plt.axvline(low_thresh, linestyle='--', color='gray', lw=2, zorder=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=2, zorder=1)
plt.axvline(0, linestyle='--', color='gray')

# Threshold labels
plt.text(low_thresh, -5, 'log2(0.5)', ha='center', va='center', fontsize=12, color='black')
plt.text(high_thresh, -5, 'log2(2)', ha='center', va='center', fontsize=12, color='black')

# Label top 10 by escape, binding, and count
top_10_escape = df_filtered.nlargest(10, 'log2_Enrichment')
top_10_binding = df_filtered.nsmallest(10, 'log2_Enrichment')
top_10_sample = df_filtered.nlargest(10, 'Sample_Count')
top_10_labels = pd.concat([top_10_escape, top_10_binding, top_10_sample]).drop_duplicates(subset=['site_label'])

texts = []
for _, row in top_10_labels.iterrows():
    text = plt.text(
        row['log2_Enrichment'], row['Sample_Count'], row['site_label'],
        ha='center', va='center', fontsize=10
    )
    texts.append(text)

# Add custom residue labels
residues_to_label = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
for pos in residues_to_label:
    residue_data = df_filtered[df_filtered['Spike_AS_Position'] == pos]
    if not residue_data.empty:
        text = plt.text(
            residue_data['log2_Enrichment'].values[0],
            residue_data['Sample_Count'].values[0],
            f"{residue_data['site_label'].values[0]}",
            fontsize=12, ha='center', va='center', color='black'
        )
        texts.append(text)

# Adjust text to avoid overlap
adjust_text(texts, only_move={'points': 'y', 'texts': 'xy'}, force_text=0.2, arrowprops=dict(arrowstyle='-', color='gray'))

# Titles and labels
plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=14, labelpad=20)
plt.ylabel("Sample Count", fontsize=14)

# Legend
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Symmetric x-axis
xlim = max(abs(df_filtered['log2_Enrichment'].min()), abs(df_filtered['log2_Enrichment'].max()))
plt.xlim(-xlim, xlim)

plt.tight_layout()
plt.show()


In [None]:
# Volcano plot variant: every point coloured by immunization category, with the
# non-significant band [log2(0.5), log2(2)] shaded gray on top of the points.
# Relies on `df_filtered` and `palette` defined in earlier cells.

# Filter for total reads > 500
df_filtered = df_filtered[df_filtered['Total_Reads'] > 500]

# Create the volcano plot
plt.figure(figsize=(12, 6))

# Plot the data with custom colors for each category
sns.scatterplot(
    data=df_filtered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.8,
    edgecolor='black',
    s=12,
    zorder=4,
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ]
)

# Shaded non-significant region; zorder=5 puts it above the points (zorder=4)
plt.axvspan(np.log2(0.5), np.log2(2), color='gray', alpha=0.5, zorder=5)

# Plot non-significant region border lines
plt.plot([np.log2(0.5), np.log2(0.5)], [0, df_filtered['Sample_Count'].max()], 'k--', lw=2, zorder=3)
plt.plot([np.log2(2), np.log2(2)], [0, df_filtered['Sample_Count'].max()], 'k--', lw=2, zorder=3)

# Label axis thresholds
plt.text(np.log2(0.5), -5, 'log2(0.5)', ha='center', va='center', fontsize=12, color='black')
plt.text(np.log2(2), -5, 'log2(2)', ha='center', va='center', fontsize=12, color='black')

# Label top escape, binding, and sample mutations
top_10_escape = df_filtered.nlargest(10, 'log2_Enrichment')
top_10_binding = df_filtered.nsmallest(10, 'log2_Enrichment')
top_10_sample = df_filtered.nlargest(10, 'Sample_Count')
top_10_labels = pd.concat([top_10_escape, top_10_binding, top_10_sample]).drop_duplicates(subset=['site_label'])

texts = []
for _, row in top_10_labels.iterrows():
    text = plt.text(
        row['log2_Enrichment'], row['Sample_Count'], row['site_label'],
        ha='center', va='center', fontsize=10
    )
    texts.append(text)

# Label specific residues of interest. They are added to `texts` so that
# adjust_text also keeps them from overlapping the top-10 labels.
residues_to_label = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
for pos in residues_to_label:
    residue_data = df_filtered[df_filtered['Spike_AS_Position'] == pos]
    if not residue_data.empty:
        texts.append(plt.text(
            residue_data['log2_Enrichment'].values[0],
            residue_data['Sample_Count'].values[0],
            f"{residue_data['site_label'].values[0]}",
            fontsize=12, ha='center', va='center', color='black'
        ))

# Repel overlapping labels. (`army_strength` was dropped: it is not an
# adjustText option and was at best silently ignored.)
adjust_text(texts, only_move={'points': 'y', 'texts': 'xy'}, force_text=0.1)

plt.axvline(0, linestyle='--', color='gray')
plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)

# Set x-axis label with extra padding
plt.xlabel("← Escape   |   Binding →", fontsize=14, labelpad=20)

plt.ylabel("Sample Count", fontsize=14)

# Legend
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

plt.tight_layout()
plt.show()


In [None]:
#Scatterplotting

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from adjustText import adjust_text
import pandas as pd

# Volcano-style scatter of log2 enrichment vs. sample count in which only the
# residues listed in `residues_to_label` are annotated, each with a connector
# line from its dot. Relies on `df_filtered` built in earlier cells.

# Filter for total reads > 500
df_filtered = df_filtered[df_filtered['Total_Reads'] > 500]

# Define significance thresholds
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# Split significant and non-significant
df_non_sig = df_filtered[(df_filtered['log2_Enrichment'] >= low_thresh) & (df_filtered['log2_Enrichment'] <= high_thresh)]
df_sig = df_filtered[(df_filtered['log2_Enrichment'] < low_thresh) | (df_filtered['log2_Enrichment'] > high_thresh)]

# Seeded generator so the jitter (and therefore the figure) is reproducible on re-run
jitter_rng = np.random.default_rng(0)


def add_y_jitter(arr, jitter_strength=10, rng=None):
    """Return ``arr`` plus Gaussian noise with standard deviation ``jitter_strength``.

    ``rng`` is an optional ``numpy.random.Generator``; the global ``np.random``
    state is used when it is omitted.
    """
    normal = np.random.normal if rng is None else rng.normal
    return arr + normal(0, jitter_strength, size=len(arr))


# Apply y-jitter
df_non_sig_jittered = df_non_sig.copy()
df_sig_jittered = df_sig.copy()

df_non_sig_jittered['Sample_Count'] = add_y_jitter(df_non_sig_jittered['Sample_Count'], jitter_strength=0.15, rng=jitter_rng)
df_sig_jittered['Sample_Count'] = add_y_jitter(df_sig_jittered['Sample_Count'], jitter_strength=0.15, rng=jitter_rng)

# Corrected color palette (darker turquoise)
palette = {
    'B cells - Ancestral Wuhan RBD vaccine': '#1f77b4',  # Blue for Wuhan
    'B cells - B.1.135 RBD vaccine': '#009688',  # Darker turquoise for B.1.135
    'mAB (NEUT)': '#d62728',  # Red for NEUT
    'Polyreactive pAB': '#ff7f0e'  # Orange for Polyreactive pAB
}

# Start plotting
plt.figure(figsize=(12, 6))

# Non-significant gray points
sns.scatterplot(
    data=df_non_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    alpha=0.6,
    s=10,
    zorder=2,
    edgecolor=None
)

# Significant colored points
sns.scatterplot(
    data=df_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.9,
    s=10,
    zorder=3,
    edgecolor=None,
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ]
)

# Plot threshold lines
plt.axvline(low_thresh, linestyle='--', color='gray', lw=1, zorder=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=1, zorder=1)
plt.axvline(0, linestyle='--', color='gray')

# Threshold labels
plt.text(low_thresh, -5, 'log2(0.5)', ha='center', va='center', fontsize=10, color='black')
plt.text(high_thresh, -5, 'log2(2)', ha='center', va='center', fontsize=10, color='black')

# Label top 5 by escape, binding, and count
top_5_escape = df_filtered.nlargest(5, 'log2_Enrichment')
top_5_binding = df_filtered.nsmallest(5, 'log2_Enrichment')
top_5_sample = df_filtered.nlargest(5, 'Sample_Count')

# Concatenate and drop duplicates based on 'site_label' to avoid overlap
top_5_labels = pd.concat([top_5_escape, top_5_binding, top_5_sample]).drop_duplicates(subset=['site_label'])

# Thresholds for labeling: sample count > 10 and extreme enrichment
count_thresh = df_filtered['Sample_Count'] > 10
enrichment_thresh = df_filtered['log2_Enrichment'] > 2
escape_thresh = df_filtered['log2_Enrichment'] < -5

df_labels = df_filtered[(count_thresh & (enrichment_thresh | escape_thresh))]

# Combine all labels for plotting
labels_to_plot = pd.concat([top_5_labels, df_labels]).drop_duplicates(subset=['site_label'])

# Define the specific residues to label
residues_to_label = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]

# Offsets (data units) from each dot to its label: left of escape mutations
# (log2 < 0), right of binding mutations, always slightly above.
label_dx = 1.1
label_dy = 1.1

texts = []
for _, row in labels_to_plot.iterrows():
    # `site_label` is a string such as "E_484", so it never matches the integer
    # positions in `residues_to_label`; compare the numeric spike position instead.
    if row['Spike_AS_Position'] not in residues_to_label:
        continue

    x_offset = -label_dx if row['log2_Enrichment'] < 0 else label_dx
    y_offset = label_dy

    text = plt.text(
        row['log2_Enrichment'] + x_offset, row['Sample_Count'] + y_offset, row['site_label'],
        ha='center', va='center', fontsize=10, color='black'
    )
    texts.append(text)
    # Connector from the dot to the label's initial position; if adjust_text
    # moves the label, its arrowprops draw the remaining segment.
    plt.plot([row['log2_Enrichment'], row['log2_Enrichment'] + x_offset],
             [row['Sample_Count'], row['Sample_Count'] + y_offset], color='gray', linestyle='-', linewidth=1)

# Adjust text to avoid overlap
adjust_text(texts, only_move={'points': 'y', 'texts': 'xy'}, force_text=0.2, arrowprops=dict(arrowstyle='-', color='gray'))

# Titles and labels
plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
plt.xlabel("← Escape   |   Binding →", fontsize=14, labelpad=10)  # Adjusted labelpad for the x-axis title
plt.ylabel("Droplet Monoclonal Antibody repertoire", fontsize=14)

# Set x and y axis ticks manually
x_ticks = np.arange(-6, 6.1, 0.5)  # Ticks between -6 and 6 with a step of 0.5
y_ticks = np.arange(0, df_filtered['Sample_Count'].max(), 5)  # Adjust the range as needed

plt.xticks(x_ticks)
plt.yticks(y_ticks)

# Extra in-plot tick value labels drawn just above the x-axis
# NOTE(review): these duplicate the regular x tick labels — confirm intended
for tick in x_ticks:
    plt.text(tick, 0.5, f'{tick:.1f}', ha='center', va='bottom', fontsize=10, color='black')

# Legend
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.ylim(bottom=0.5)

# Symmetric x-axis (the former `plt.xlim(-6, 6)` was immediately overridden here)
xlim = max(abs(df_filtered['log2_Enrichment'].min()), abs(df_filtered['log2_Enrichment'].max()))
plt.xlim(-xlim, xlim)

plt.tight_layout()
plt.show()


In [None]:
#Volcano plotting

In [None]:
# Unlabelled redraw of the jittered volcano plot, fixed to x in [-6, 6].
# Uses df_non_sig_jittered, df_sig_jittered, palette, low_thresh and high_thresh
# from earlier cells.
fig, ax = plt.subplots(figsize=(12, 6))

# Background: non-significant mutations in light gray
sns.scatterplot(
    data=df_non_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    alpha=0.6,
    s=14,
    zorder=2,
    edgecolor=None,
    ax=ax,
)

# Foreground: significant mutations coloured by immunization category
sns.scatterplot(
    data=df_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ],
    palette=palette,
    alpha=0.9,
    s=10,
    zorder=3,
    edgecolor=None,
    ax=ax,
)

# Dashed guides at both significance thresholds and at zero
ax.axvline(low_thresh, linestyle='--', color='gray', lw=1, zorder=1)
ax.axvline(high_thresh, linestyle='--', color='gray', lw=1, zorder=1)
ax.axvline(0, linestyle='--', color='gray')

# Threshold captions just below the zero count line
ax.text(low_thresh, -5, 'log2(0.5)', ha='center', va='center', fontsize=10, color='black')
ax.text(high_thresh, -5, 'log2(2)', ha='center', va='center', fontsize=10, color='black')

ax.set_xlim(-6, 6)

ax.set_title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)
ax.set_xlabel("← Escape   |   Binding →", fontsize=14, labelpad=20)
ax.set_ylabel("Monoclonal Antibody repertoire [n = 1x droplet]", fontsize=14)

# Rebuild the legend outside the axes on the right
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

fig.tight_layout()
plt.show()


In [None]:
import seaborn as sns
import numpy as np

# Volcano plot with a small horizontal (log2 enrichment) jitter. Non-significant
# points are gray; significant ones are coloured by category. The top mutations
# and the residues of interest are labelled.
# Relies on `df_filtered`, `palette` and `adjust_text` from earlier cells.

# Filter for total reads > 500
df_filtered = df_filtered[df_filtered['Total_Reads'] > 500]

# Define significance threshold
low_thresh = np.log2(0.5)
high_thresh = np.log2(2)

# Split into significant and non-significant
df_non_sig = df_filtered[(df_filtered['log2_Enrichment'] >= low_thresh) & (df_filtered['log2_Enrichment'] <= high_thresh)]
df_sig = df_filtered[(df_filtered['log2_Enrichment'] < low_thresh) | (df_filtered['log2_Enrichment'] > high_thresh)]

# Seeded generator so the jitter (and therefore the figure) is reproducible on re-run
jitter_rng = np.random.default_rng(0)


def add_jitter(arr, jitter_strength=0, rng=None):
    """Return ``arr`` plus Gaussian noise with standard deviation ``jitter_strength``.

    ``rng`` is an optional ``numpy.random.Generator``; the global ``np.random``
    state is used when it is omitted.
    """
    normal = np.random.normal if rng is None else rng.normal
    return arr + normal(0, jitter_strength, size=len(arr))


# Apply jitter to copies of the data (split above uses the un-jittered values)
df_non_sig_jittered = df_non_sig.copy()
df_sig_jittered = df_sig.copy()

df_non_sig_jittered['log2_Enrichment'] = add_jitter(df_non_sig_jittered['log2_Enrichment'], 0.1, rng=jitter_rng)
df_sig_jittered['log2_Enrichment'] = add_jitter(df_sig_jittered['log2_Enrichment'], 0.1, rng=jitter_rng)

# Create the volcano plot
plt.figure(figsize=(12, 6))

# Plot non-significant points in light gray
sns.scatterplot(
    data=df_non_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    color='lightgray',
    alpha=0.6,
    s=14,
    zorder=2,
    edgecolor=None
)

# Plot significant points with category colors
sns.scatterplot(
    data=df_sig_jittered,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='Category',
    palette=palette,
    alpha=0.9,
    s=14,
    zorder=3,
    edgecolor=None,
    hue_order=[
        'B cells - Ancestral Wuhan RBD vaccine',
        'B cells - B.1.135 RBD vaccine',
        'mAB (NEUT)',
        'Polyreactive pAB'
    ]
)

# Plot threshold lines
plt.axvline(low_thresh, linestyle='--', color='gray', lw=2, zorder=1)
plt.axvline(high_thresh, linestyle='--', color='gray', lw=2, zorder=1)

# Label thresholds
plt.text(low_thresh, -5, 'log2(0.5)', ha='center', va='center', fontsize=12, color='black')
plt.text(high_thresh, -5, 'log2(2)', ha='center', va='center', fontsize=12, color='black')

# Label top escape, binding, and sample mutations
top_10_escape = df_filtered.nlargest(10, 'log2_Enrichment')
top_10_binding = df_filtered.nsmallest(10, 'log2_Enrichment')
top_10_sample = df_filtered.nlargest(10, 'Sample_Count')
top_10_labels = pd.concat([top_10_escape, top_10_binding, top_10_sample]).drop_duplicates(subset=['site_label'])

texts = []
for _, row in top_10_labels.iterrows():
    text = plt.text(
        row['log2_Enrichment'], row['Sample_Count'], row['site_label'],
        ha='center', va='center', fontsize=10
    )
    texts.append(text)

# Label specific residues of interest. They are added to `texts` so that
# adjust_text also keeps them from overlapping the top-10 labels.
residues_to_label = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
for pos in residues_to_label:
    residue_data = df_filtered[df_filtered['Spike_AS_Position'] == pos]
    if not residue_data.empty:
        texts.append(plt.text(
            residue_data['log2_Enrichment'].values[0],
            residue_data['Sample_Count'].values[0],
            f"{residue_data['site_label'].values[0]}",
            fontsize=12, ha='center', va='center', color='black'
        ))

# Repel overlapping labels. (`army_strength` was dropped: it is not an
# adjustText option and was at best silently ignored.)
adjust_text(texts, only_move={'points': 'y', 'texts': 'xy'}, force_text=0.1)

plt.axvline(0, linestyle='--', color='gray')
plt.title("Volcano Plot: Escape vs Binding Mutations", fontsize=16)

# Set x-axis label with extra padding
plt.xlabel("← Escape   |   Binding →", fontsize=14, labelpad=20)
plt.ylabel("Sample Count", fontsize=14)

# Legend for significant categories only
handles, labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, labels, title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')

# Make x-axis symmetric around 0
xlim = max(abs(df_filtered['log2_Enrichment'].min()), abs(df_filtered['log2_Enrichment'].max()))
plt.xlim(-xlim, xlim)

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Volcano plot of `df_volcano` coloured by immunization. Rows whose boolean
# `label_this` column is True get a boxed label with a gray leader line.
fig, ax = plt.subplots(figsize=(12, 6))
sns.scatterplot(
    data=df_volcano,
    x='log2_Enrichment',
    y='Sample_Count',
    hue='immunization',
    alpha=0.8,
    edgecolor='black',
    ax=ax,
)

label_rows = df_volcano[df_volcano['label_this']]
for mutation, x_val, y_val in zip(
    label_rows['mutation_label'],
    label_rows['log2_Enrichment'],
    label_rows['Sample_Count'],
):
    # Label sits 40 pt left / 20 pt up from the dot; the leader line points back to it
    ax.annotate(
        mutation,
        xy=(x_val, y_val),
        xytext=(-40, 20),
        textcoords='offset points',
        ha='right',
        va='bottom',
        fontsize=9,
        arrowprops=dict(arrowstyle='-', lw=0.8, color='gray'),
        bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7),
    )

ax.axvline(0, linestyle='--', color='gray')  # log2(1): no change
ax.set_title("Volcano Plot: Escape vs Binding Mutations")
ax.set_xlabel("log2(Enrichment Ratio) (← Escape   |   Binding →)")
ax.set_ylabel("Sample Count")
ax.legend(title="Immunization", bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()
plt.show()


In [None]:
import os

import matplotlib.pyplot as plt

plt.ion()

# Binding logo plots per barcode: letter height = summed Enrichment_Ratio
# (only ratios > 1 kept) at the RBD sites of interest. Each figure is saved as
# PNG to `logo_output_dir`.

# Wild-type (Wuhan) residue at each site of interest
wuhan_strain_aa = {
    417: 'K', 439: 'N', 440: 'N', 452: 'L', 476: 'G', 477: 'S', 484: 'E',
    493: 'Q', 501: 'N', 502: 'G', 505: 'Y'
}

# Where the PNGs are written; created if missing so savefig does not fail
logo_output_dir = r"/Users/lucaschlotheuber/Desktop/ETH/immunization_csv_files"

# Aggregate data
df_logo_agg = df_total.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
).agg({'Enrichment_Ratio': 'sum'})

# Filter for Enrichment_Ratio > 1
df_logo_agg = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1]

# Sites to show, as strings. A list (not a one-shot `map` iterator) so it can be reused.
sites_to_show = [str(site) for site in [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]]

# Add site labels and determine which sites to show
df_logo_agg = df_logo_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Ensure amino acids are uppercase and exclude stop codons / gaps
df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[~df_logo_agg['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]

# Debugging: Check the filtered positions and amino acids
print("Filtered data with relevant sites and amino acids:")
print(df_logo_agg[['Spike_AS_Position', 'Amino_Acid']].drop_duplicates())

# Debugging: raw rows for position 505 (independent of barcode, so printed once)
print("\nRows for position 505 in the original dataframe:")
rows_505 = df_total[df_total['Spike_AS_Position'] == 505][['Spike_AS_Position', 'Amino_Acid', 'Enrichment_Ratio', 'DMS_RBD_AS_position']]
print(rows_505)

os.makedirs(logo_output_dir, exist_ok=True)

# Generate logo plots for each unique barcode
for barcode in df_logo_agg['barcode'].unique():
    print(f"Generating plot for barcode: {barcode}")
    # Boolean mask instead of a string-built query (robust to quotes in barcodes)
    filtered_data = df_logo_agg[(df_logo_agg['barcode'] == barcode) & df_logo_agg['show_site']]

    # Debugging: Check positions for the current barcode
    print(f"Positions and amino acids for {barcode}:")
    print(filtered_data[['Spike_AS_Position', 'Amino_Acid']])

    if filtered_data.empty:
        print(f"No data to plot for barcode: {barcode}")
        continue

    fig, ax = dmslogo.draw_logo(
        filtered_data,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio",
        title=barcode + ' logoplot',
        addbreaks=True
    )

    ax.set_ylabel("Antigen Binding")
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")

    # Save the figure with a tight bounding box
    file_path = os.path.join(logo_output_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')

    plt.draw()
    plt.pause(0.1)  # Give the plot time to render
    plt.show()

    # Close the figure to free memory
    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import os

# Enable interactive plotting
plt.ion()

# Binding logo plots per barcode: letter height = mean Enrichment_Ratio (only
# means > 1 kept) at the RBD sites of interest. Each figure is saved as PNG to
# `logo_output_dir`.

# Wild-type (Wuhan) residue at each site of interest. 505 is Y (Y505), not L.
wuhan_strain_aa = {
    417: 'K', 439: 'N', 440: 'N', 452: 'L', 476: 'G', 477: 'S', 484: 'E',
    493: 'Q', 501: 'N', 502: 'G', 505: 'Y'
}

# Where the PNGs are written; created if missing so savefig does not fail
logo_output_dir = r"/Users/lucaschlotheuber/Desktop/ETH/immunization_csv_files"

# Aggregate data
df_logo_agg = df_total.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
).agg({'Enrichment_Ratio': 'mean'})

# Filter for Enrichment_Ratio > 1
df_logo_agg = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1]

# Sites to show, as strings. A list (not a one-shot `map` iterator) so it can be reused.
sites_to_show = [str(site) for site in [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]]

# Add site labels and determine which sites to show
df_logo_agg = df_logo_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Ensure amino acids are uppercase and exclude stop codons / gaps
df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[~df_logo_agg['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]

# Debugging: raw rows for position 505 (independent of barcode, so printed once)
print("\nRows for position 505 in the original dataframe:")
rows_505 = df_total[df_total['Spike_AS_Position'] == 505][['Spike_AS_Position', 'Amino_Acid', 'Enrichment_Ratio', 'DMS_RBD_AS_position']]
print(rows_505)

os.makedirs(logo_output_dir, exist_ok=True)

# Generate logo plots for each unique barcode
for barcode in df_logo_agg['barcode'].unique():
    print(f"Generating plot for barcode: {barcode}")
    # Boolean mask instead of a string-built query (robust to quotes in barcodes)
    filtered_data = df_logo_agg[(df_logo_agg['barcode'] == barcode) & df_logo_agg['show_site']]

    # Debugging: Check positions for the current barcode
    print(f"Positions and amino acids for {barcode}:")
    print(filtered_data[['Spike_AS_Position', 'Amino_Acid']])

    if filtered_data.empty:
        print(f"No data to plot for barcode: {barcode}")
        continue

    fig, ax = dmslogo.draw_logo(
        filtered_data,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio",
        title=barcode + ' logoplot',
        addbreaks=True
    )

    ax.set_ylabel("Antigen Binding")
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")

    # Save the figure with a tight bounding box
    file_path = os.path.join(logo_output_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')

    plt.draw()
    plt.pause(0.1)  # Give the plot time to render
    plt.show()

    # Close the figure to free memory
    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt

plt.ion()

# Escape logo plots per barcode. Letter height = log2 of the summed inverse
# enrichment ratio, so mutations depleted more strongly get taller letters.
# Mutations that are also enriched (> 1) at the same position in any sample
# are excluded.

# Keep ratios below 1 (escape) and drop exact zeros, which cannot be inverted.
# .copy() so adding a column below does not write into a view of df_total.
df_escape = df_total[(df_total['Enrichment_Ratio'] < 1) & (df_total['Enrichment_Ratio'] != 0)].copy()

# Invert the ratio (zeros were removed above, so no division by zero)
df_escape['Enrichment_Ratio_inverted'] = 1 / df_escape['Enrichment_Ratio']

# Aggregate the escape data by position, amino acid, barcode, and immunization
df_escape_agg = df_escape.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'],
    as_index=False
).agg({'Enrichment_Ratio_inverted': 'sum'})

# log2 of positive values; non-positive values are passed through unchanged
inverted = df_escape_agg['Enrichment_Ratio_inverted']
df_escape_agg['Enrichment_Ratio_log2'] = np.log2(inverted.where(inverted > 0)).where(inverted > 0, inverted)

# Exclude (position, amino acid) pairs that are both enriched and escaped.
# The pairing is per position: intersecting bare amino-acid letters across all
# positions would remove nearly every amino acid.
key_cols = ['Spike_AS_Position', 'Amino_Acid']
enriched_aa = df_total[df_total['Enrichment_Ratio'] > 1]
escaped_aa = df_escape
common_aa = pd.MultiIndex.from_frame(enriched_aa[key_cols]).intersection(
    pd.MultiIndex.from_frame(escaped_aa[key_cols])
)
escape_keys = pd.MultiIndex.from_frame(df_escape_agg[key_cols])
df_escape_agg = df_escape_agg[~escape_keys.isin(common_aa)]

# Sites to show, as strings. A list (not a one-shot `map` iterator) so it can be reused.
sites_to_show = [str(site) for site in [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]]

df_escape_agg = df_escape_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# NOTE(review): Windows-style path while other cells use macOS paths — confirm
escape_plot_dir = r"C:\Users\lschlotheube\Desktop\ETH/Thesis\LogoEscape3"

# Loop through each barcode to generate the plots
for barcode in df_escape_agg['barcode'].unique():
    print(barcode)

    # Selected sites for this barcode, excluding stop codons / gaps *before*
    # the emptiness check so an all-stop-codon subset is skipped, not plotted
    df_filtered = df_escape_agg[(df_escape_agg['barcode'] == barcode) & df_escape_agg['show_site']]
    df_filtered = df_filtered[~df_filtered['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]

    if df_filtered.empty:
        continue

    fig, ax = dmslogo.draw_logo(
        df_filtered,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio_log2",
        title=barcode + ' logoplot',
        addbreaks=True,
    )

    ax.set_ylabel("Antibody Escape")
    ax.set_xlabel("SARS-Cov-2 Spike AA Position")

    # Save the figure with a tight bounding box
    file_path = os.path.join(escape_plot_dir, f"{barcode}_logoplots.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')

    plt.draw()
    plt.pause(0.1)  # Give the plot time to render
    plt.show()

    # Close the figure to free memory
    plt.close(fig)


In [None]:
# Code block for generating logoplots of enriched positions grouped by immunization (sum of all barcodes)
# Aggregating by immunization, Spike_AS_Position, and Amino_Acid, and summing Enrichment_Ratio

df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Sites to show, as strings. A list (not a one-shot `map` iterator) so it can be reused.
sites_to_show = [str(site) for site in [417, 439, 440, 452, 476, 477, 484, 493, 501, 505]]

# Adding site labels and filtering for sites to show
df_logo_agg = df_logo_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)
df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()

# One logo plot per immunization group
for immunization in df_logo_agg['immunization'].unique():
    print(immunization)
    # Boolean mask instead of a string-built query (robust to quotes in group names)
    immunization_data = df_logo_agg[(df_logo_agg['immunization'] == immunization) & df_logo_agg['show_site']]
    if immunization_data.empty:
        print(f"No data to plot for immunization: {immunization}")
        continue
    fig, ax = dmslogo.draw_logo(
        immunization_data,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio",
        title=immunization + ' logoplot',
        addbreaks=True
    )
    # Save the figure (uncomment to save the plot)
    #file_path = os.path.join(r"C:\Users\au649453\OneDrive - Aarhus universitet\PhD\Luca\DMS_plots\Enriched_and_targeted_positions", f"{immunization}_logoplots.png")
    #plt.savefig(file_path, dpi = 300, bbox_inches = 'tight')
    #plt.close(fig)


In [None]:
# Code block for generating logoplots of enriched positions grouped by immunization (sum of all barcodes)
# Aggregating by immunization, Spike_AS_Position, and Amino_Acid, and summing Enrichment_Ratio
# NOTE(review): this cell duplicates the previous one — consider deleting one copy.

df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Sites to show, as strings. A list (not a one-shot `map` iterator) so it can be reused.
sites_to_show = [str(site) for site in [417, 439, 440, 452, 476, 477, 484, 493, 501, 505]]

# Adding site labels and filtering for sites to show
df_logo_agg = df_logo_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)
df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()

# One logo plot per immunization group
for immunization in df_logo_agg['immunization'].unique():
    print(immunization)
    # Boolean mask instead of a string-built query (robust to quotes in group names)
    immunization_data = df_logo_agg[(df_logo_agg['immunization'] == immunization) & df_logo_agg['show_site']]
    if immunization_data.empty:
        print(f"No data to plot for immunization: {immunization}")
        continue
    fig, ax = dmslogo.draw_logo(
        immunization_data,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio",
        title=immunization + ' logoplot',
        addbreaks=True
    )
    # Save the figure (uncomment to save the plot)
    #file_path = os.path.join(r"C:\Users\au649453\OneDrive - Aarhus universitet\PhD\Luca\DMS_plots\Enriched_and_targeted_positions", f"{immunization}_logoplots.png")
    #plt.savefig(file_path, dpi = 300, bbox_inches = 'tight')
    #plt.close(fig)


In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np

# Occurrence-based logo plots per immunization group. Letter height is the
# number of barcodes in which a mutation's summed Enrichment_Ratio exceeds 3.
# Letter colour encodes the largest such ratio on matplotlib's 'Reds' colormap.

# Aggregate enrichment ratio per barcode
df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})


def enrichment_color(enrichment, max_value=None):
    """Map an enrichment ratio to an RGBA colour from the 'Reds' colormap.

    Ratios at or below 3 map to the pale (near-white) low end; ``max_value``
    maps to the dark-red high end. The normalised value is clipped to [0, 1].

    Parameters
    ----------
    enrichment : float
        Enrichment ratio to colour.
    max_value : float, optional
        Ratio mapped to the top of the colormap. Defaults to the current maximum
        of ``df_logo_agg['Enrichment_Ratio']``. Pass it explicitly when colouring
        many values, so the maximum is not recomputed on every call.
    """
    if max_value is None:
        max_value = df_logo_agg['Enrichment_Ratio'].max()
    span = max_value - 3
    if span <= 0:
        # Nothing exceeds the cutoff: avoid dividing by zero or a negative span
        return plt.cm.Reds(0.0)
    norm_value = np.clip((enrichment - 3) / span, 0, 1)
    return plt.cm.Reds(norm_value)


# Select sites to display
sites_to_show = [str(x) for x in [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]]

df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()

# Keep only barcode-level enrichments above 3
df_filtered = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 3]

# One pass: number of barcodes per mutation and its strongest enrichment.
# This replaces two separate groupbys whose results were aligned only by row order.
df_combined = (
    df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])['Enrichment_Ratio']
    .agg(Sample_Count='size', Max_Enrichment='max')
    .reset_index()
)

# Colour each mutation by its strongest enrichment (colormap maximum computed once)
max_enrichment = df_logo_agg['Enrichment_Ratio'].max()
df_combined['color'] = df_combined['Max_Enrichment'].apply(enrichment_color, max_value=max_enrichment)

df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Generate logo plots
for immunization in df_combined['immunization'].unique():
    print(immunization)
    # Boolean mask instead of a string-built query (robust to quotes in group names)
    immunization_data = df_combined[(df_combined['immunization'] == immunization) & df_combined['show_site']]
    if immunization_data.empty:
        print(f"No data to plot for immunization: {immunization}")
        continue
    fig, ax = dmslogo.draw_logo(
        immunization_data,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Sample_Count",  # count of barcodes, not enrichment
        color_col="color",
        title=f"{immunization} logoplot (Occurrence-based)",
        addbreaks=True
    )


In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np

# Occurrence-based logo plots for all RBD positions (> 365), one per
# immunization. Each plot is written to `logo_plots/` without being displayed.
# Letter height = number of barcodes with Enrichment_Ratio > ENRICHMENT_THRESHOLD.
# Letter color = the highest enrichment across those barcodes.

ENRICHMENT_THRESHOLD = 3

# Aggregate enrichment ratio per barcode
df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Computed once rather than inside every enrichment_color call.
max_enrichment_global = df_logo_agg['Enrichment_Ratio'].max()


def enrichment_color(enrichment):
    """Map an enrichment ratio to an RGBA color on the 'Reds' colormap.

    The scale is linear from ENRICHMENT_THRESHOLD (lightest) to the global
    maximum (darkest red), with clipping outside that range.
    """
    span = max_enrichment_global - ENRICHMENT_THRESHOLD
    if not span > 0:
        return plt.cm.Reds(0.0)  # degenerate range
    return plt.cm.Reds(np.clip((enrichment - ENRICHMENT_THRESHOLD) / span, 0, 1))


df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()

# Only keep amino acids with Enrichment_Ratio > threshold
df_filtered = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > ENRICHMENT_THRESHOLD]

# Count across barcodes and max enrichment from a single aggregation, so the two
# columns are guaranteed to be aligned row-by-row.
df_combined = (
    df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False)
    .agg(Sample_Count=('Enrichment_Ratio', 'size'),
         Max_Enrichment=('Enrichment_Ratio', 'max'))
)
df_combined['color'] = df_combined['Max_Enrichment'].apply(enrichment_color)

df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# Keep only RBD positions (Spike position > 365)
df_combined = df_combined[df_combined['Spike_AS_Position'] > 365]

output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    fig, ax = dmslogo.draw_logo(
        df_combined[df_combined['immunization'] == immunization],
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Sample_Count",  # barcode count, not enrichment
        color_col="color",
        title=f"{immunization} logoplot (Occurrence-based)",
        addbreaks=True
    )

    # Save this specific figure (not whatever pyplot considers "current"), then
    # close it so it is not displayed and does not accumulate in memory.
    plot_filename = os.path.join(output_dir, f"{immunization}_logoplot.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    plt.close(fig)
    print(f"Plot saved as {plot_filename}")



In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np

# Occurrence-based logo plots for all RBD positions (> 365), one per
# immunization. Each plot is saved to `logo_plots/` and then shown inline.
# Letter height = number of barcodes with Enrichment_Ratio > ENRICHMENT_THRESHOLD.
# Letter color = highest enrichment across those barcodes (linear 'Reds' scale).

ENRICHMENT_THRESHOLD = 3

# Aggregate enrichment ratio per barcode
df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Max value for normalisation, computed once (not once per colored letter).
max_enrichment_global = df_logo_agg['Enrichment_Ratio'].max()


def enrichment_color(enrichment):
    """Return the 'Reds' RGBA color for an enrichment ratio.

    The scale runs linearly from ENRICHMENT_THRESHOLD (lightest) to the global
    maximum (darkest red), with clipping outside that range.
    """
    span = max_enrichment_global - ENRICHMENT_THRESHOLD
    if not span > 0:
        return plt.cm.Reds(0.0)  # degenerate range
    return plt.cm.Reds(np.clip((enrichment - ENRICHMENT_THRESHOLD) / span, 0, 1))


df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()

# Only keep amino acids with Enrichment_Ratio > threshold
df_filtered = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > ENRICHMENT_THRESHOLD]

# Barcode count and max enrichment per (position, amino acid, immunization),
# from one aggregation so the columns cannot fall out of alignment.
df_combined = (
    df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False)
    .agg(Sample_Count=('Enrichment_Ratio', 'size'),
         Max_Enrichment=('Enrichment_Ratio', 'max'))
)
df_combined['color'] = df_combined['Max_Enrichment'].apply(enrichment_color)

df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# Keep only RBD positions (Spike position > 365)
df_combined = df_combined[df_combined['Spike_AS_Position'] > 365]

output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    fig, ax = dmslogo.draw_logo(
        df_combined[df_combined['immunization'] == immunization],
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Sample_Count",  # barcode count, not enrichment
        color_col="color",
        title=f"{immunization} logoplot (Occurrence-based)",
        addbreaks=True
    )

    # Save BEFORE showing. With the inline backend, plt.show() closes the figure,
    # and a subsequent plt.savefig() would write a new, empty figure.
    plot_filename = os.path.join(output_dir, f"{immunization}_logoplot.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")
    plt.show()


In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np

# Occurrence-based logo plots for RBD positions (> 365), one per immunization.
# Letters are colored by log10(max enrichment + 1), and a colorbar is added.
# Letter height = number of barcodes with Enrichment_Ratio > 3.

COLOR_MAX_ENRICHMENT = 3000  # enrichment mapped to the darkest red
# Upper end of the log color scale. The lower end is 0, i.e. log10(0 + 1). The
# same limits drive both the letter colors and the colorbar, so they agree (the
# colorbar previously started at log10(4) while the letters started at 0).
LOG_COLOR_VMIN = 0.0
LOG_COLOR_VMAX = np.log10(COLOR_MAX_ENRICHMENT + 1)

# Aggregate enrichment ratio per barcode
df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})


def enrichment_color(enrichment):
    """Map an enrichment ratio to a 'Reds' RGBA color on a log10(x + 1) scale.

    0 maps to the lightest shade and COLOR_MAX_ENRICHMENT (or more) to the
    darkest red; values are clipped to that range.
    """
    norm_value = (np.log10(enrichment + 1) - LOG_COLOR_VMIN) / (LOG_COLOR_VMAX - LOG_COLOR_VMIN)
    return plt.cm.Reds(np.clip(norm_value, 0, 1))


df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()

# Only keep amino acids with Enrichment_Ratio > 3
df_filtered = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 3]

# Barcode count and max enrichment from one aggregation (guaranteed aligned).
df_combined = (
    df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False)
    .agg(Sample_Count=('Enrichment_Ratio', 'size'),
         Max_Enrichment=('Enrichment_Ratio', 'max'))
)
df_combined['color'] = df_combined['Max_Enrichment'].apply(enrichment_color)

df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# Keep only RBD positions (Spike position > 365)
df_combined = df_combined[df_combined['Spike_AS_Position'] > 365]

output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    fig, ax = dmslogo.draw_logo(
        df_combined[df_combined['immunization'] == immunization],
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Sample_Count",  # barcode count, not enrichment
        color_col="color",
        title=f"{immunization} logoplot (Occurrence-based)",
        addbreaks=True
    )

    # Colorbar using exactly the normalisation applied to the letters
    sm = plt.cm.ScalarMappable(cmap="Reds", norm=plt.Normalize(vmin=LOG_COLOR_VMIN, vmax=LOG_COLOR_VMAX))
    sm.set_array([])  # required for a colorbar without plotted data
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('log10(Enrichment Ratio + 1)', rotation=270, labelpad=15)

    # Save BEFORE plt.show(): the inline backend closes the figure on show.
    # A distinct file name keeps this variant from overwriting the
    # linear-color logo plots written to the same directory.
    plot_filename = os.path.join(output_dir, f"{immunization}_logoplot_log_color.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")
    plt.show()


In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np

# One logo plot per barcode (RBD positions > 365). Only mutations with
# Enrichment_Ratio > 3 are shown. Letters are colored by log10(enrichment + 1),
# with a matching colorbar.

COLOR_MAX_ENRICHMENT = 3000
# Shared color scale for letters and colorbar: log10(ER + 1) from 0 to log10(3001).
LOG_COLOR_VMIN = 0.0
LOG_COLOR_VMAX = np.log10(COLOR_MAX_ENRICHMENT + 1)

# Aggregate enrichment ratio per barcode
df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})


def enrichment_color(enrichment):
    """'Reds' RGBA color for an enrichment ratio on a clipped log10(x + 1) scale."""
    norm_value = (np.log10(enrichment + 1) - LOG_COLOR_VMIN) / (LOG_COLOR_VMAX - LOG_COLOR_VMIN)
    return plt.cm.Reds(np.clip(norm_value, 0, 1))


df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_filtered = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 3]

output_dir = "logo_plots_by_barcode"
os.makedirs(output_dir, exist_ok=True)

for barcode in df_filtered['barcode'].unique():
    df_barcode = df_filtered[df_filtered['barcode'] == barcode].copy()

    # NOTE(review): assumes each barcode belongs to a single immunization;
    # only the first one is used for the title.
    immunization = df_barcode['immunization'].iloc[0]

    # df_logo_agg is already aggregated per barcode. Within one barcode each
    # (position, amino acid) therefore usually occurs once, so Sample_Count acts
    # as a presence marker (1) unless a barcode spans several immunizations.
    # Count and max come from one aggregation so they stay aligned.
    df_combined = (
        df_barcode.groupby(['Spike_AS_Position', 'Amino_Acid'], as_index=False)
        .agg(Sample_Count=('Enrichment_Ratio', 'size'),
             Max_Enrichment=('Enrichment_Ratio', 'max'))
    )
    df_combined['color'] = df_combined['Max_Enrichment'].apply(enrichment_color)

    df_combined = df_combined.assign(
        site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
    )

    # Keep only RBD positions (Spike position > 365)
    df_combined = df_combined[df_combined['Spike_AS_Position'] > 365]

    if df_combined.empty:
        continue  # Skip barcodes with no data after filtering

    print(f"Generating plot for barcode {barcode} (immunization: {immunization})")

    fig, ax = dmslogo.draw_logo(
        df_combined,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Sample_Count",
        color_col="color",
        title=f"Barcode: {barcode}\nImmunization: {immunization}",
        addbreaks=True
    )

    # Colorbar with the same normalisation as the letter colors
    sm = plt.cm.ScalarMappable(cmap="Reds", norm=plt.Normalize(vmin=LOG_COLOR_VMIN, vmax=LOG_COLOR_VMAX))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('log10(Enrichment Ratio + 1)', rotation=270, labelpad=15)

    # Save BEFORE plt.show(): the inline backend closes the figure on show, and a
    # later plt.savefig() would write an empty figure.
    plot_filename = os.path.join(output_dir, f"{barcode}_logoplot.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")
    plt.show()


In [None]:
from collections import defaultdict
import pandas as pd

# Nested mapping: immunization -> Spike position -> pivot table with one row per
# barcode and one column per amino acid. Cells hold Enrichment_Ratio, 0 where
# absent; pivot_table's default aggregation is the mean.
immunization_position_matrices = defaultdict(dict)

# Enriched (ER > 3) RBD entries (Spike position > 365) with upper-cased residues.
strong_rbd_mask = (df_total['Enrichment_Ratio'] > 3) & (df_total['Spike_AS_Position'] > 365)
df_filtered = df_total.loc[strong_rbd_mask].copy()
df_filtered['Amino_Acid'] = df_filtered['Amino_Acid'].str.upper()

# sort=False visits groups in order of first appearance, matching iteration
# over .unique() in both loops.
for immunization, df_imm in df_filtered.groupby('immunization', sort=False):
    for position, df_pos in df_imm.groupby('Spike_AS_Position', sort=False):
        pivot = df_pos.pivot_table(
            index='barcode',
            columns='Amino_Acid',
            values='Enrichment_Ratio',
            fill_value=0,
        )
        immunization_position_matrices[immunization][position] = pivot


In [None]:
from scipy.spatial.distance import pdist, squareform
import numpy as np
import pandas as pd

# Reproducibility of enrichment profiles across barcodes. For each immunization
# and Spike position, compute the mean Pearson correlation over all barcode
# pairs, using the amino-acid columns of the matrices in
# `immunization_position_matrices`.
reproducibility_stats = []

for immunization, positions in immunization_position_matrices.items():
    for position, matrix in positions.items():
        n_barcodes = matrix.shape[0]
        if n_barcodes < 2:
            continue  # Need at least two barcodes to compute reproducibility

        # Transpose so the columns are barcodes -> barcode x barcode correlations.
        correlation_matrix = matrix.T.corr().to_numpy()

        # Each unordered barcode pair once: upper triangle, diagonal excluded.
        # The matrix is symmetric, so this gives the same mean as using all
        # off-diagonal cells. Pairs with an undefined correlation (NaN, e.g. a
        # constant profile) are dropped.
        pair_values = correlation_matrix[np.triu_indices(n_barcodes, k=1)]
        pair_values = pair_values[~np.isnan(pair_values)]
        # Report NaN explicitly instead of triggering numpy's empty-mean warning.
        mean_corr = pair_values.mean() if pair_values.size else np.nan

        reproducibility_stats.append({
            'immunization': immunization,
            'Spike_AS_Position': position,
            'mean_pairwise_correlation': mean_corr,
            'n_barcodes': n_barcodes
        })

# Explicit columns keep the schema even when no position qualifies.
df_reproducibility = pd.DataFrame(
    reproducibility_stats,
    columns=['immunization', 'Spike_AS_Position', 'mean_pairwise_correlation', 'n_barcodes']
)


In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np

# Occurrence-based logo plots for RBD positions (> 365), one per immunization,
# using a lower enrichment cutoff (ER > 1) and excluding stop codons (*).
# Letter height = number of barcodes with ER > 1 for that mutation.
# Letter color = log10(max ER + 1).

ENRICHMENT_THRESHOLD = 1
COLOR_MAX_ENRICHMENT = 3000
# Shared color scale for letters and colorbar. The colorbar previously started
# at log10(3 + 1) while the letters were normalised from 0.
LOG_COLOR_VMIN = 0.0
LOG_COLOR_VMAX = np.log10(COLOR_MAX_ENRICHMENT + 1)

# Aggregate enrichment ratio per barcode
df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})


def enrichment_color(enrichment):
    """'Reds' RGBA color for an enrichment ratio on a clipped log10(x + 1) scale."""
    norm_value = (np.log10(enrichment + 1) - LOG_COLOR_VMIN) / (LOG_COLOR_VMAX - LOG_COLOR_VMIN)
    return plt.cm.Reds(np.clip(norm_value, 0, 1))


df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()

# Drop stop codons; later cells also use this stop-free df_logo_agg.
df_logo_agg = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]

# Only keep amino acids with Enrichment_Ratio > 1
df_filtered = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > ENRICHMENT_THRESHOLD]

# Barcode count and max enrichment from one aggregation (guaranteed aligned).
df_combined = (
    df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False)
    .agg(Sample_Count=('Enrichment_Ratio', 'size'),
         Max_Enrichment=('Enrichment_Ratio', 'max'))
)
df_combined['color'] = df_combined['Max_Enrichment'].apply(enrichment_color)

df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# Keep only RBD positions (Spike position > 365)
df_combined = df_combined[df_combined['Spike_AS_Position'] > 365]

output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    fig, ax = dmslogo.draw_logo(
        df_combined[df_combined['immunization'] == immunization],
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Sample_Count",  # barcode count, not enrichment
        color_col="color",
        title=f"{immunization} logoplot (Occurrence-based)",
        addbreaks=True
    )

    fig.set_size_inches(45, 4)  # wide canvas: every RBD position is shown

    # Colorbar with the same normalisation as the letter colors
    sm = plt.cm.ScalarMappable(cmap="Reds", norm=plt.Normalize(vmin=LOG_COLOR_VMIN, vmax=LOG_COLOR_VMAX))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('log10(Enrichment Ratio + 1)', rotation=270, labelpad=15)

    # Save BEFORE plt.show() (the inline backend closes the figure on show).
    # The "_ER_gt1" suffix avoids overwriting the ER > 3 logo plots in this directory.
    plot_filename = os.path.join(output_dir, f"{immunization}_logoplot_ER_gt1.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")
    plt.show()


In [None]:
!pip install scikit-learn


In [None]:
# Cosine-similarity heatmaps between barcodes, using only entries with ER > 1

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import pdist, squareform
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os

# Within-immunization similarity of barcode enrichment profiles. Uses only
# entries with ER > 1 and excludes stop codons (*). Output files carry the
# "ER_gt1" tag because the all-ER analysis writes equivalent files to the same
# directory and would otherwise overwrite them.
OUTPUT_TAG = "ER_gt1"
os.makedirs(output_dir, exist_ok=True)

df_logo_filtered = df_logo_agg[
    (df_logo_agg['Amino_Acid'] != "*") &
    (df_logo_agg['Enrichment_Ratio'] > 1)
]

# --- Quick check: confirm all values have ER > 1 ---
num_total = len(df_logo_filtered)
num_below_or_equal_1 = (df_logo_filtered['Enrichment_Ratio'] <= 1).sum()
print(f"\nQuick Check: {num_total} entries total; {num_below_or_equal_1} have ER ≤ 1 (should be 0)")

# Step 1: one row per (immunization, barcode), one column per (position, amino
# acid); absent mutations are 0.
pivot = df_logo_filtered.pivot_table(
    index=['immunization', 'barcode'],
    columns=['Spike_AS_Position', 'Amino_Acid'],
    values='Enrichment_Ratio',
    fill_value=0
)
title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-CoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-CoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS-CoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS-CoV-2\nB.1.135',
    'Library_ctrl': 'Library-2',
    'Un-enrich. Libr': 'Un-enrich.\nLibrary'
}

# Step 2: Analyze within each immunization
similarity_stats = []

for immunization, sub_df in pivot.groupby(level=0):
    barcodes = sub_df.index.get_level_values('barcode')
    sim_matrix = cosine_similarity(sub_df.values)  # barcode x barcode

    # Unique barcode pairs (upper triangle, self-comparisons excluded). A
    # single-barcode group has no pairs: report NaN instead of numpy warnings.
    pairwise_sims = sim_matrix[np.triu_indices_from(sim_matrix, k=1)]
    has_pairs = pairwise_sims.size > 0
    similarity_stats.append({
        'immunization': immunization,
        'mean_similarity': np.mean(pairwise_sims) if has_pairs else np.nan,
        'std_similarity': np.std(pairwise_sims) if has_pairs else np.nan,
        'num_barcodes': len(barcodes)
    })

    fig, ax_heat = plt.subplots(figsize=(6, 5))
    sns.heatmap(sim_matrix, xticklabels=barcodes, yticklabels=barcodes,
                cmap='viridis', vmin=0, vmax=1, ax=ax_heat)
    pretty_name = title_map.get(immunization, immunization)
    ax_heat.set_title(f'Cosine Similarity Between Barcodes\n{pretty_name}')
    fig.tight_layout()
    filename = os.path.join(
        output_dir, f"similarity_heatmap_{str(immunization).replace(' ', '_')}_{OUTPUT_TAG}.png"
    )
    fig.savefig(filename, dpi=300)
    plt.show()

df_sim_stats = pd.DataFrame(similarity_stats)

# Step 3: overall similarity per immunization, with display names on the x-axis
fig, ax = plt.subplots(figsize=(8, 5))
sns.barplot(data=df_sim_stats, x='immunization', y='mean_similarity', ax=ax)
ax.set_ylabel('Mean Cosine Similarity Between Barcodes')
ax.set_title('Within-Immunization Barcode Similarity')

# Bars are drawn in df_sim_stats row order; unmapped names fall back to the raw name.
ax.set_xticks(range(len(df_sim_stats)))
ax.set_xticklabels([title_map.get(name, name) for name in df_sim_stats['immunization']], rotation=45)

fig.tight_layout()
barplot_file = os.path.join(output_dir, f"summary_barplot_similarity_{OUTPUT_TAG}.png")
fig.savefig(barplot_file, dpi=300)
plt.show()


In [None]:
# Cosine-similarity heatmaps between barcodes, using all enrichment ratios (ER)

In [None]:
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import pdist, squareform
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os

# Within-immunization similarity of barcode enrichment profiles using ALL
# enrichment ratios (stop codons excluded). Files are tagged "ER_all" so they do
# not overwrite the ER > 1 heatmaps written to the same directory.
OUTPUT_TAG = "ER_all"
os.makedirs(output_dir, exist_ok=True)

df_logo_cleaned = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]

# Step 1: one row per (immunization, barcode), one column per (position, amino
# acid); absent mutations are 0.
pivot = df_logo_cleaned.pivot_table(
    index=['immunization', 'barcode'],
    columns=['Spike_AS_Position', 'Amino_Acid'],
    values='Enrichment_Ratio',
    fill_value=0
)

# --- Quick check: confirm that ER < 1 values are still included ---
num_total = len(df_logo_cleaned)
num_below_1 = (df_logo_cleaned['Enrichment_Ratio'] < 1).sum()
print(f"\nQuick Check: {num_below_1} out of {num_total} entries have Enrichment_Ratio < 1")

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\n2',
    'Un-enrich. Libr': 'Un-enrich.\nLibrary'
}

# Step 2: Analyze within each immunization
similarity_stats = []

for immunization, sub_df in pivot.groupby(level=0):
    barcodes = sub_df.index.get_level_values('barcode')
    sim_matrix = cosine_similarity(sub_df.values)  # barcode x barcode

    # Unique barcode pairs; a single-barcode group has none -> NaN, no warnings.
    pairwise_sims = sim_matrix[np.triu_indices_from(sim_matrix, k=1)]
    has_pairs = pairwise_sims.size > 0
    similarity_stats.append({
        'immunization': immunization,
        'mean_similarity': np.mean(pairwise_sims) if has_pairs else np.nan,
        'std_similarity': np.std(pairwise_sims) if has_pairs else np.nan,
        'num_barcodes': len(barcodes)
    })

    fig, ax_heat = plt.subplots(figsize=(6, 5))
    sns.heatmap(sim_matrix, xticklabels=barcodes, yticklabels=barcodes,
                cmap='viridis', vmin=0, vmax=1, ax=ax_heat)
    pretty_name = title_map.get(immunization, immunization)
    ax_heat.set_title(f'Cosine Similarity Between Barcodes\n{pretty_name}')
    fig.tight_layout()
    filename = os.path.join(
        output_dir, f"similarity_heatmap_{str(immunization).replace(' ', '_')}_{OUTPUT_TAG}.png"
    )
    fig.savefig(filename, dpi=300)
    plt.show()

df_sim_stats = pd.DataFrame(similarity_stats)

# Step 3: overall similarity per immunization, with display names on the x-axis
fig, ax = plt.subplots(figsize=(8, 5))
sns.barplot(data=df_sim_stats, x='immunization', y='mean_similarity', ax=ax)
ax.set_ylabel('Mean Cosine Similarity Between Barcodes')
ax.set_title('Within-Immunization Barcode Similarity')

# Bars are drawn in df_sim_stats row order; unmapped names fall back to the raw name.
ax.set_xticks(range(len(df_sim_stats)))
ax.set_xticklabels([title_map.get(name, name) for name in df_sim_stats['immunization']], rotation=45)

fig.tight_layout()
barplot_file = os.path.join(output_dir, f"summary_barplot_similarity_{OUTPUT_TAG}.png")
fig.savefig(barplot_file, dpi=300)
plt.show()


In [None]:
# Cosine similarity between barcodes using all enrichment ratios (ER), with significance tests

In [None]:
import os
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import pdist, squareform
from scipy.stats import ttest_ind, f_oneway
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import combinations

# Within-immunization cosine similarity of barcode enrichment profiles, using
# ALL enrichment ratios (stop codons excluded). The per-pair similarity values
# are then compared with a one-way ANOVA and pairwise t-tests. The result is a
# bar chart with significance brackets.
#
# NOTE(review): similarities of pairs that share a barcode are not independent,
# and no multiple-testing correction is applied; treat the p-values as
# descriptive.

SHOW_HEATMAPS = False  # heatmaps were previously built and closed unseen; set True to display
OUTPUT_TAG = "ER_all"
os.makedirs(output_dir, exist_ok=True)

df_logo_cleaned = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]

# Step 1: one row per (immunization, barcode), one column per (position, amino
# acid); absent mutations are 0.
pivot = df_logo_cleaned.pivot_table(
    index=['immunization', 'barcode'],
    columns=['Spike_AS_Position', 'Amino_Acid'],
    values='Enrichment_Ratio',
    fill_value=0
)

title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.135',
    'Library_ctrl': 'Lib\n2',
    'Un-enrich. Libr': 'Un-enrich.\nLibrary'
}

# Summarise the data that actually enters this analysis. This previously read
# `df_filtered`, a leftover from an earlier cell with a different ER filter.
print("\n=== Enrichment Ratio Summary Per Immunization ===")
for name, group in df_logo_cleaned.groupby('immunization'):
    values = group['Enrichment_Ratio'].dropna()
    print(f"\nImmunization: {name}")
    print(f"  Count: {len(values)}")
    print(f"  Min:   {values.min():.4f}")
    print(f"  Max:   {values.max():.4f}")
    print(f"  Mean:  {values.mean():.4f}")
    print(f"  Std:   {values.std():.4f}")
    print(f"  Median:{values.median():.4f}")

# Step 2: Analyze within each immunization
similarity_stats = []
all_pairwise = {}  # immunization -> array of per-pair similarities (for the tests)

for immunization, sub_df in pivot.groupby(level=0):
    barcodes = sub_df.index.get_level_values('barcode')
    sim_matrix = cosine_similarity(sub_df.values)
    pairwise_sims = sim_matrix[np.triu_indices_from(sim_matrix, k=1)]
    has_pairs = pairwise_sims.size > 0  # False for single-barcode groups

    similarity_stats.append({
        'immunization': immunization,
        'mean_similarity': np.mean(pairwise_sims) if has_pairs else np.nan,
        'std_similarity': np.std(pairwise_sims) if has_pairs else np.nan,
        'num_barcodes': len(barcodes)
    })
    all_pairwise[immunization] = pairwise_sims

    if SHOW_HEATMAPS:
        fig, ax_heat = plt.subplots(figsize=(6, 5))
        sns.heatmap(sim_matrix, xticklabels=barcodes, yticklabels=barcodes,
                    cmap='viridis', vmin=0, vmax=1, ax=ax_heat)
        ax_heat.set_title(f'Cosine Similarity Between Barcodes\n{immunization}')
        fig.tight_layout()
        plt.show()

# Step 3: Create summary DataFrame
df_sim_stats = pd.DataFrame(similarity_stats)
print("\nSummary Statistics:")
print(df_sim_stats)

# Step 4: ANOVA and pairwise t-tests
print("\nStatistical Tests:")
immunizations = list(all_pairwise.keys())
# Immunizations with a single barcode have no pairs, so they cannot be tested.
testable = [imm for imm in immunizations if len(all_pairwise[imm]) > 0]
for imm in immunizations:
    if imm not in testable:
        print(f"Skipping {imm}: no barcode pairs")
groups = [all_pairwise[imm] for imm in testable]

# ANOVA only when more than two groups are compared
if len(testable) > 2:
    f_stat, p_val = f_oneway(*groups)
    print(f"ANOVA: F = {f_stat:.4f}, p = {p_val:.4e}")

pairwise_pvals = []
for a, b in combinations(testable, 2):
    stat, p_val = ttest_ind(all_pairwise[a], all_pairwise[b])
    print(f"{a} vs {b}: t = {stat:.4f}, p = {p_val:.4e}")
    pairwise_pvals.append((a, b, p_val))

# Step 5: bar plot with std error bars and significance brackets
fig, ax = plt.subplots(figsize=(8, 5))
sns.barplot(data=df_sim_stats, x='immunization', y='mean_similarity', ci=None, palette='Set2', ax=ax)
bar_positions = np.arange(len(df_sim_stats))
ax.errorbar(x=bar_positions,
            y=df_sim_stats['mean_similarity'],
            yerr=df_sim_stats['std_similarity'],
            fmt='none', c='black', capsize=5)
ax.set_ylabel('Mean Cosine Similarity')

# Bars follow df_sim_stats row order; unmapped names fall back to the raw name
# (Series.map(dict) would have produced NaN labels for them).
ax.set_xticks(bar_positions)
ax.set_xticklabels([title_map.get(name, name) for name in df_sim_stats['immunization']], rotation=0)


def add_sig_bracket(ax, x1, x2, y, h, p_val, fontsize=12):
    """Draw a bracket from bar x1 to bar x2 at height y and label it with stars.

    Labels are '***' for p < 0.001, '**' for p < 0.01, '*' for p < 0.05,
    and 'ns' otherwise.
    """
    barx = [x1, x1, x2, x2]
    bary = [y, y + h, y + h, y]
    ax.plot(barx, bary, c='black')
    if p_val < 0.001:
        stars = '***'
    elif p_val < 0.01:
        stars = '**'
    elif p_val < 0.05:
        stars = '*'
    else:
        stars = 'ns'
    ax.text((x1 + x2) / 2, y + h + 0.01, stars, ha='center', va='bottom', fontsize=fontsize)


# Stack brackets for significant comparisons above the tallest error bar.
bar_index = {imm: i for i, imm in enumerate(df_sim_stats['immunization'])}
y_max = df_sim_stats['mean_similarity'].max() + df_sim_stats['std_similarity'].max()
h = 0.1  # height of each bracket
offset = 0
for a, b, p in pairwise_pvals:
    if np.isnan(p) or p >= 0.05:
        continue
    y = y_max + h * offset
    add_sig_bracket(ax, bar_index[a], bar_index[b], y, h, p)
    offset += 3

fig.tight_layout()
# Summary figure across all immunizations. The old name came from the leaked
# loop variable `immunization` and said "heatmap".
filename = os.path.join(output_dir, f"cosine_similarity_barplot_{OUTPUT_TAG}.png")
fig.savefig(filename, dpi=300)
plt.show()


In [None]:
# Cosine similarity between barcodes restricted to ER > 1, with significance tests


In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from scipy.spatial.distance import pdist, squareform
from scipy.stats import ttest_ind, f_oneway
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import combinations
# Keep only non-stop variants that are enriched (ER > 1).
# NOTE(review): df_logo_agg is (re)defined by later cells of this notebook, so this
# cell uses whichever version was executed last.
df_logo_cleaned = df_logo_agg[(df_logo_agg['Amino_Acid'] != "*") & (df_logo_agg['Enrichment_Ratio'] > 1)]

# Step 1: Create pivot table per barcode
# (rows: immunization/barcode, columns: (position, amino acid), missing variants = 0)
pivot = df_logo_cleaned.pivot_table(
    index=['immunization', 'barcode'],
    columns=['Spike_AS_Position', 'Amino_Acid'],
    values='Enrichment_Ratio',
    fill_value=0
)

# Pretty x-axis labels. The Beta variant lineage is B.1.351.
title_map = {
    'Polyclonal_Ab': 'Anti-\nSARS-\nCoV-2\npAB',
    'Neutralizing_Ab': 'Anti-\nSARS-\nCoV-2\nnAb',
    'wildtype_RBD': 'ASCs\nSARS\nCoV-2\nWuhan',
    'Mutant_RBD': 'ASCs\nSARS\nCoV-2\nB.1.351',
    'Library_ctrl': 'Lib\n2',
    'Un-enrich. Libr': 'Un-enrich.\nLibrary'
}

# Summarise the enrichment ratios that actually feed this cell's analysis
# (df_logo_cleaned), not the global df_filtered that other cells define.
print("\n=== Enrichment Ratio Summary Per Immunization ===")
for name, group in df_logo_cleaned.groupby('immunization'):
    values = group['Enrichment_Ratio'].dropna()
    print(f"\nImmunization: {name}")
    print(f"  Count: {len(values)}")
    print(f"  Min:   {values.min():.4f}")
    print(f"  Max:   {values.max():.4f}")
    print(f"  Mean:  {values.mean():.4f}")
    print(f"  Std:   {values.std():.4f}")
    print(f"  Median:{values.median():.4f}")

# Step 2: Analyze within each immunization
SHOW_BARCODE_HEATMAPS = False  # set True to display the per-group similarity heatmaps

similarity_stats = []
all_pairwise = {}  # immunization -> 1-D array of pairwise barcode similarities

for immunization, sub_df in pivot.groupby(level=0):
    data = sub_df.values
    barcodes = sub_df.index.get_level_values('barcode')

    sim_matrix = cosine_similarity(data)
    # Upper triangle without the diagonal: every unordered barcode pair once.
    # NOTE(review): a group with a single barcode yields an empty array -> NaN stats.
    pairwise_sims = sim_matrix[np.triu_indices_from(sim_matrix, k=1)]

    similarity_stats.append({
        'immunization': immunization,
        'mean_similarity': np.mean(pairwise_sims),
        'std_similarity': np.std(pairwise_sims),
        'num_barcodes': len(barcodes)
    })

    all_pairwise[immunization] = pairwise_sims  # kept for the statistical comparison

    # Only build the heatmap when it will be displayed; creating a figure and
    # closing it before plt.show() renders nothing.
    if SHOW_BARCODE_HEATMAPS:
        plt.figure(figsize=(6, 5))
        sns.heatmap(sim_matrix, xticklabels=barcodes, yticklabels=barcodes, cmap='viridis', vmin=0, vmax=1)
        plt.title(f'Cosine Similarity Between Barcodes\n{immunization}')
        plt.tight_layout()
        plt.show()

# Step 3: Summary table, one row per immunization group
df_sim_stats = pd.DataFrame(similarity_stats)
print("\nSummary Statistics:")
print(df_sim_stats)

# Step 4: One-way ANOVA (only run when there are more than two groups),
# followed by pairwise t-tests between every pair of groups
print("\nStatistical Tests:")
immunizations = list(all_pairwise)
groups = [all_pairwise[imm] for imm in immunizations]

if len(immunizations) > 2:
    f_stat, p_val = f_oneway(*groups)
    print(f"ANOVA: F = {f_stat:.4f}, p = {p_val:.4e}")

# NOTE(review): pairwise similarities share barcodes, so they are not independent
# observations, and no multiple-testing correction is applied here.
pairwise_pvals = []
for a, b in combinations(immunizations, 2):
    stat, p_val = ttest_ind(all_pairwise[a], all_pairwise[b])
    print(f"{a} vs {b}: t = {stat:.4f}, p = {p_val:.4e}")
    pairwise_pvals.append((a, b, p_val))

# Step 5: Plot with error bars and significance lines
# NOTE(review): `ci=` is deprecated in seaborn >= 0.12 in favour of `errorbar=`;
# kept for compatibility with older seaborn versions.
plt.figure(figsize=(8, 5))
ax = sns.barplot(data=df_sim_stats, x='immunization', y='mean_similarity', ci=None, palette='Set2')
# Error bars: standard deviation of the pairwise similarities (std_similarity)
plt.errorbar(x=np.arange(len(df_sim_stats)),
             y=df_sim_stats['mean_similarity'],
             yerr=df_sim_stats['std_similarity'],
             fmt='none', c='black', capsize=5)
plt.ylabel('Mean Cosine Similarity')

# Update x-axis labels using title_map; unmapped groups keep their raw name
# instead of being rendered as "nan".
xticks = ax.get_xticks()
xticklabels = df_sim_stats['immunization'].map(title_map).fillna(df_sim_stats['immunization'])
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels, rotation=0)

# Add significance brackets
def add_sig_bracket(ax, x1, x2, y, h, p_val, fontsize=12):
    """Draw a significance bracket between two bar positions and label it.

    The bracket legs start at height ``y`` and rise to ``y + h`` above the
    x-positions ``x1`` and ``x2``. A star label is centred just above the bracket.

    Args:
        ax: matplotlib Axes to draw on.
        x1, x2: x-coordinates (bar positions) of the two compared groups.
        y: baseline height of the bracket legs.
        h: vertical height of the bracket.
        p_val: p-value that selects the label ('***', '**', '*' or 'ns').
        fontsize: font size of the label.
    """
    ax.plot([x1, x1, x2, x2], [y, y + h, y + h, y], c='black')

    # The first cutoff the p-value falls under wins. NaN compares False everywhere -> 'ns'.
    cutoffs = ((0.001, '***'), (0.01, '**'), (0.05, '*'))
    label = next((mark for limit, mark in cutoffs if p_val < limit), 'ns')

    ax.text((x1 + x2) / 2, y + h + 0.01, label, ha='center', va='bottom', fontsize=fontsize)

# Add brackets above bars, stacked so that successive comparisons do not overlap
y_max = df_sim_stats['mean_similarity'].max() + df_sim_stats['std_similarity'].max()
h = 0.05  # height of brackets
# Bars are drawn in the row order of df_sim_stats (assumes seaborn's default
# order-of-appearance for string categories), so map each group to its position
# explicitly instead of relying on the DataFrame index.
bar_pos = {imm: i for i, imm in enumerate(df_sim_stats['immunization'])}
offset = 0
for a, b, p in pairwise_pvals:
    if np.isnan(p) or p >= 0.05:
        continue  # only significant comparisons get a bracket
    y = y_max + h * offset
    add_sig_bracket(ax, bar_pos[a], bar_pos[b], y, h, p)
    offset += 3

plt.tight_layout()
# Distinct, descriptive name for the ER > 1 figure. A name built from the leftover
# loop variable `immunization` only reflects the last group and can overwrite the
# all-variant similarity plot.
filename = os.path.join(output_dir, "cosine_similarity_barplot_ER_gt1.png")
plt.savefig(filename, dpi=300)
plt.show()


In [None]:
!pip install pingouin


In [None]:
import pingouin as pg

# Intraclass correlation of the barcode profiles within each immunization group.
# NOTE(review): `pivot` is whichever pivot table was built last (the ER > 1 pivot
# from the cosine-similarity cell when the notebook runs top to bottom).
for imm, sub_df in pivot.groupby(level=0):
    n_barcodes = sub_df.shape[0]
    if n_barcodes < 2:
        # ICC across targets is undefined with fewer than two barcodes
        print(f"Skipping ICC for {imm}: only {n_barcodes} barcode(s).")
        continue

    # Move barcode from the index into a column and tag rows with the group
    df_icc = sub_df.reset_index(level='barcode')
    df_icc['immunization'] = imm

    # Flatten (position, amino acid) MultiIndex columns into "pos_AA" strings;
    # the metadata columns become plain 'barcode' / 'immunization'.
    df_icc.columns = ['_'.join(map(str, col)).strip('_') for col in df_icc.columns.values]

    # Long format: one row per (barcode, feature) value
    df_long = df_icc.melt(id_vars=['immunization', 'barcode'], var_name='feature', value_name='value')

    # NOTE(review): barcodes are the targets and features act as raters here.
    icc_res = pg.intraclass_corr(data=df_long, targets='barcode', raters='feature', ratings='value')
    icc_21 = icc_res.loc[icc_res['Type'] == 'ICC2', :]
    print(f"ICC results for {imm}:\n", icc_21[['ICC', 'F', 'df1', 'df2', 'pval']])





In [None]:
!pip install umap-learn


In [None]:
from sklearn.decomposition import PCA
import umap
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Pivot the data: one row per (immunization, barcode), one column per
# (Spike position, amino acid); variants absent for a barcode are filled with 0.
pivot = df_filtered.pivot_table(
    index=['immunization', 'barcode'],
    columns=['Spike_AS_Position', 'Amino_Acid'],
    values='Enrichment_Ratio',
    fill_value=0
)

print("\nPivot Table Columns:")
print(pivot.columns)

pivot_data = pivot.reset_index()
print("\nPivot Data (after reset_index):")
print(pivot_data.head())

# Take features and metadata straight from the pivot. Building `meta` from the
# index gives flat 'immunization' / 'barcode' columns, so the PCA/UMAP frames
# below need no MultiIndex column clean-up.
X = pivot.to_numpy()
meta = pivot.index.to_frame(index=False)
print("\nMetadata Columns:")
print(meta.columns)

# --- PCA ---
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)

pca_df = pd.concat([pd.DataFrame(X_pca, columns=['PC1', 'PC2']), meta], axis=1)

print("\nPCA DataFrame Columns:")
print(pca_df.columns)
print(pca_df.head())

# Zoom window for the PCA plot. Points outside it are not drawn, so report
# how many are hidden instead of clipping silently.
pca_xlim = (-75, 50)
pca_ylim = (-65, 0)
in_window = pca_df['PC1'].between(*pca_xlim) & pca_df['PC2'].between(*pca_ylim)
n_hidden = int((~in_window).sum())
if n_hidden:
    print(f"\nNote: {n_hidden} of {len(pca_df)} barcodes lie outside the PCA plot window and are not shown.")

# Plot PCA
plt.figure(figsize=(7, 6))
sns.scatterplot(data=pca_df, x='PC1', y='PC2', hue='immunization', style='immunization', s=100)
plt.title('PCA of Enrichment Profiles Across Barcodes')
plt.xlim(*pca_xlim)
plt.ylim(*pca_ylim)
plt.tight_layout()
plt.show()

# Print explained variance
print("\nPCA Explained Variance Ratio:")
print(f"PC1: {pca.explained_variance_ratio_[0]:.2%}, PC2: {pca.explained_variance_ratio_[1]:.2%}")

# --- UMAP ---
reducer = umap.UMAP(n_components=2, random_state=42)
X_umap = reducer.fit_transform(X)

umap_df = pd.concat([pd.DataFrame(X_umap, columns=['UMAP1', 'UMAP2']), meta], axis=1)

print("\nUMAP DataFrame Columns:")
print(umap_df.columns)
print(umap_df.head())

# Plot UMAP
plt.figure(figsize=(7, 6))
sns.scatterplot(data=umap_df, x='UMAP1', y='UMAP2', hue='immunization', style='immunization', s=100)
plt.title('UMAP of Enrichment Profiles Across Barcodes')
plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import umap

# Barcode x variant matrix: rows = (immunization, barcode), columns =
# (Spike position, amino acid), cells = enrichment ratio (0 when absent)
pivot = df_filtered.pivot_table(
    values='Enrichment_Ratio',
    index=['immunization', 'barcode'],
    columns=['Spike_AS_Position', 'Amino_Acid'],
    fill_value=0,
)

# Move the index levels back into columns, then split metadata from features
pivot_reset = pivot.reset_index()
meta = pivot_reset[['immunization', 'barcode']]
X = pivot_reset.drop(columns=['immunization', 'barcode']).values

# --- PCA: project every barcode onto the first two principal components ---
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)
pca_df = pd.DataFrame(X_pca, columns=['PC1', 'PC2']).assign(
    immunization=meta['immunization'].values,
    barcode=meta['barcode'].values,
)

plt.figure(figsize=(7, 6))
sns.scatterplot(data=pca_df, x='PC1', y='PC2', hue='immunization', style='immunization', s=100)
plt.title('PCA of Enrichment Profiles (Each Point = 1 Barcode)')
plt.tight_layout()
plt.show()

print(f"PCA explained variance: PC1 = {pca.explained_variance_ratio_[0]:.2%}, PC2 = {pca.explained_variance_ratio_[1]:.2%}")

# --- UMAP: non-linear 2-D embedding of the same matrix ---
reducer = umap.UMAP(n_components=2, random_state=42)
X_umap = reducer.fit_transform(X)
umap_df = pd.DataFrame(X_umap, columns=['UMAP1', 'UMAP2']).assign(
    immunization=meta['immunization'].values,
    barcode=meta['barcode'].values,
)

plt.figure(figsize=(7, 6))
sns.scatterplot(data=umap_df, x='UMAP1', y='UMAP2', hue='immunization', style='immunization', s=100)
plt.title('UMAP of Enrichment Profiles (Each Point = 1 Barcode)')
plt.tight_layout()
plt.show()


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import umap

# Exclude the library control group ('Library_ctrl') from the embeddings
df_filtered_no_ctrl = df_filtered.loc[df_filtered['immunization'] != 'Library_ctrl']

# Barcode x variant matrix: rows = (immunization, barcode), columns =
# (Spike position, amino acid), cells = enrichment ratio (0 when absent)
pivot = df_filtered_no_ctrl.pivot_table(
    values='Enrichment_Ratio',
    index=['immunization', 'barcode'],
    columns=['Spike_AS_Position', 'Amino_Acid'],
    fill_value=0,
)

# Move the index levels back into columns, then split metadata from features
pivot_reset = pivot.reset_index()
meta = pivot_reset[['immunization', 'barcode']]
X = pivot_reset.drop(columns=['immunization', 'barcode']).values

# --- PCA ---
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)
pca_df = pd.DataFrame(X_pca, columns=['PC1', 'PC2']).assign(
    immunization=meta['immunization'].values,
    barcode=meta['barcode'].values,
)

# Circles with black outlines, bright palette, one colour per immunization
plt.figure(figsize=(7, 6))
sns.scatterplot(
    data=pca_df, x='PC1', y='PC2',
    hue='immunization', palette="bright",
    s=100, marker='o', edgecolor='black', linewidth=0.8,
)
plt.title('PCA of Enrichment Profiles (Each Point = 1 Barcode)')
plt.legend(title='Immunization', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

print(f"PCA explained variance: PC1 = {pca.explained_variance_ratio_[0]:.2%}, PC2 = {pca.explained_variance_ratio_[1]:.2%}")

# --- UMAP ---
reducer = umap.UMAP(n_components=2, random_state=42)
X_umap = reducer.fit_transform(X)
umap_df = pd.DataFrame(X_umap, columns=['UMAP1', 'UMAP2']).assign(
    immunization=meta['immunization'].values,
    barcode=meta['barcode'].values,
)

plt.figure(figsize=(7, 6))
sns.scatterplot(
    data=umap_df, x='UMAP1', y='UMAP2',
    hue='immunization', palette="bright",
    s=100, marker='o', edgecolor='black', linewidth=0.8,
)
plt.title('UMAP of Enrichment Profiles (Each Point = 1 Barcode)')
plt.legend(title='Immunization', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
# NOTE(review): only the UMAP figure is written, despite the "PCA_" file name
plt.savefig('PCA_UMAP_Clusters.png', dpi=300)
plt.show()


In [None]:
from sklearn.preprocessing import StandardScaler
import numpy as np

# Log10 transform to tame skewed enrichment values; the 1e-3 pseudocount keeps
# the zero-filled cells finite (they become log10(1e-3) = -3).
X_log = np.log10(X + 1e-3)

# Standardise every feature to zero mean / unit variance before PCA
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_log)

pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)
pca_df = pd.DataFrame(X_pca, columns=['PC1', 'PC2']).assign(
    immunization=meta['immunization'].values,
    barcode=meta['barcode'].values,
)

plt.figure(figsize=(7, 6))
sns.scatterplot(
    data=pca_df, x='PC1', y='PC2',
    hue='immunization', palette="bright",
    s=40, marker='o', edgecolor='black', linewidth=0.8,
)
plt.title('PCA of Enrichment Profiles (Each Point = 1 Barcode)')
plt.legend(title='Immunization', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()

print(f"PCA explained variance: PC1 = {pca.explained_variance_ratio_[0]:.2%}, PC2 = {pca.explained_variance_ratio_[1]:.2%}")


In [None]:
color_map = {
    'Polyclonal_Ab': 'darkorange',
    'Neutralizing_Ab': 'red',
    'wildtype_RBD': 'green',
    'Mutant_RBD': 'darkblue'
}

# Pretty legend labels. The Beta variant lineage is B.1.351.
title_map = {
    'Polyclonal_Ab': 'Anti-SARS-CoV-2 pAB',
    'Neutralizing_Ab': 'Anti-SARS-CoV-2 nAb',
    'wildtype_RBD': 'IgG secreting cells\n SARS-CoV-2 Wuhan',
    'Mutant_RBD': 'IgG secreting cells \n SARS-CoV-2 B.1.351',
    'Library_ctrl': 'Lib\n2',
    'Un-enrich. Libr': 'Un-enrich.\nLibrary'
}

# Map the immunization short names to pretty labels (unmapped names are kept)
pca_df['immunization_pretty'] = pca_df['immunization'].map(title_map).fillna(pca_df['immunization'])

# Palette keyed by pretty label. seaborn rejects a palette dict that lacks a hue
# level, so any group without an explicit colour falls back to grey.
color_map_pretty = {
    title_map[k]: color_map[k] for k in title_map if k in color_map
}
for pretty_label in pca_df['immunization_pretty'].unique():
    color_map_pretty.setdefault(pretty_label, 'grey')

# Optional jitter for overlapping points. Seeded so the jittered figure is
# reproducible; set USE_JITTER = True to plot the jittered coordinates.
USE_JITTER = False
jitter_strength = 0.1
jitter_rng = np.random.default_rng(42)
pca_df['PC1_jitter'] = pca_df['PC1'] + jitter_rng.normal(0, jitter_strength, size=len(pca_df))
pca_df['PC2_jitter'] = pca_df['PC2'] + jitter_rng.normal(0, jitter_strength, size=len(pca_df))
x_col, y_col = ('PC1_jitter', 'PC2_jitter') if USE_JITTER else ('PC1', 'PC2')

plt.figure(figsize=(8, 6))
sns.scatterplot(
    data=pca_df,
    x=x_col, y=y_col,
    hue='immunization_pretty',
    palette=color_map_pretty,
    s=60,
    alpha=0.7,
    marker='o',
    edgecolor='black',
    linewidth=0.8
)
plt.title('')
plt.legend(title='Immunization', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
filename = os.path.join(output_dir, "PCA_Immuni.png")
plt.savefig(filename, dpi=300)
plt.show()


In [None]:
from sklearn.cluster import DBSCAN
import numpy as np

# Density-based clustering directly on the two PCA coordinates.
# eps / min_samples depend on the scale of X_pca and need tuning.
dbscan = DBSCAN(eps=2, min_samples=5)
pca_df['cluster'] = dbscan.fit_predict(X_pca)

# DBSCAN labels outliers -1; give them a readable name, the rest "cluster <id>"
pca_df['cluster_label'] = pca_df['cluster'].apply(
    lambda cid: f'cluster {cid}' if cid != -1 else 'noise'
)

# Colour = immunization, marker shape = DBSCAN cluster
plt.figure(figsize=(8, 6))
sns.scatterplot(
    data=pca_df, x='PC1', y='PC2',
    hue='immunization_pretty', style='cluster_label',
    palette=color_map_pretty,
    s=60, alpha=0.8, edgecolor='black', linewidth=0.8,
)
plt.title('PCA of Enrichment Profiles with DBSCAN Clusters')
plt.legend(title='Immunization / Cluster', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


In [None]:
from sklearn.cluster import DBSCAN
from scipy.spatial import ConvexHull
import numpy as np


In [None]:
# --- Clustering using DBSCAN on the 2-D UMAP embedding (label -1 = noise) ---
clustering = DBSCAN(eps=0.5, min_samples=5).fit(X_umap)
cluster_labels = clustering.labels_
umap_df['Cluster'] = cluster_labels

# UMAP coloured by cluster id rather than by immunization
plt.figure(figsize=(8, 7))
sns.scatterplot(
    data=umap_df, x='UMAP1', y='UMAP2',
    hue='Cluster', palette='tab10', legend='full',
    s=100, marker='o', edgecolor='black', linewidth=0.8,
)
plt.title('UMAP with DBSCAN Clusters')
plt.legend(title='Cluster', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


In [None]:
plt.figure(figsize=(8, 7))
sns.scatterplot(
    data=umap_df,
    x='UMAP1', y='UMAP2',
    hue='Cluster',
    palette='tab10',
    s=100,
    edgecolor='black',
    linewidth=0.8,
    legend='full'
)

# Draw convex hulls around each cluster
for cluster_id in sorted(umap_df['Cluster'].unique()):
    if cluster_id == -1:  # DBSCAN noise has no boundary
        continue
    cluster_points = umap_df.loc[umap_df['Cluster'] == cluster_id, ['UMAP1', 'UMAP2']].values
    # A 2-D hull needs at least 3 points that span a plane. Collinear or
    # coincident points make Qhull raise, so skip such clusters.
    if len(cluster_points) < 3 or np.linalg.matrix_rank(cluster_points - cluster_points.mean(axis=0)) < 2:
        continue
    hull = ConvexHull(cluster_points)
    hull_points = cluster_points[hull.vertices]
    plt.fill(hull_points[:, 0], hull_points[:, 1], alpha=0.2, label=f'Cluster {cluster_id} boundary')

plt.title('UMAP with Cluster Boundaries (DBSCAN)')
plt.legend(title='Cluster', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


In [None]:
def plot_umap_with_clusters(df_subset, title_suffix, eps=0.5, min_samples=5, random_state=42):
    """Embed barcode enrichment profiles with UMAP, cluster them with DBSCAN and plot.

    Args:
        df_subset: long-format DataFrame with 'immunization', 'barcode',
            'Spike_AS_Position', 'Amino_Acid' and 'Enrichment_Ratio' columns.
        title_suffix: text appended to the plot title and the log messages.
        eps: DBSCAN neighbourhood radius in UMAP coordinates.
        min_samples: DBSCAN minimum neighbourhood size for a core point.
        random_state: seed passed to UMAP for a reproducible embedding.

    Returns:
        None. Prints diagnostics and shows a figure. Returns early, without
        plotting, when there is nothing meaningful to embed.
    """
    print(f"\n--- Processing subset: {title_suffix} ---")
    print(f"Original subset size: {df_subset.shape}")

    if df_subset.empty:
        print(f"Subset '{title_suffix}' is empty. Skipping plot.")
        return

    # One row per (immunization, barcode), one column per (position, AA); 0 when absent
    pivot = df_subset.pivot_table(
        index=['immunization', 'barcode'],
        columns=['Spike_AS_Position', 'Amino_Acid'],
        values='Enrichment_Ratio',
        fill_value=0
    )
    # Summarise instead of dumping every value into the notebook output
    er = df_subset['Enrichment_Ratio']
    print(f"Enrichment_Ratio in this subset: n={er.size}, min={er.min():.4g}, max={er.max():.4g}")
    print(f"Pivot table shape: {pivot.shape}")

    # Features and metadata straight from the pivot (flat metadata columns)
    X = pivot.to_numpy()
    meta = pivot.index.to_frame(index=False)
    print(f"Feature matrix shape (X): {X.shape}")

    if X.shape[0] < 2:
        print(f"Need at least 2 barcodes to embed, got {X.shape[0]}. Skipping.")
        return
    if X.shape[1] == 0:
        print("No features after pivoting. Skipping.")
        return

    # Run UMAP
    reducer = umap.UMAP(random_state=random_state)
    embedding = reducer.fit_transform(X)
    print(f"UMAP embedding shape: {embedding.shape}")

    # DBSCAN on the embedding; label -1 marks noise points
    cluster_labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(embedding)
    print(f"Cluster labels assigned: {set(cluster_labels)}")

    umap_df = pd.DataFrame({
        'UMAP1': embedding[:, 0],
        'UMAP2': embedding[:, 1],
        'Cluster': cluster_labels,
        'immunization': meta['immunization'],
        'barcode': meta['barcode']
    })

    plt.figure(figsize=(8, 7))
    sns.scatterplot(
        data=umap_df,
        x='UMAP1', y='UMAP2',
        hue='Cluster',
        palette='tab10',
        s=100,
        edgecolor='black',
        linewidth=0.8,
        legend='full'
    )

    for cluster_id in sorted(umap_df['Cluster'].unique()):
        if cluster_id == -1:
            continue
        cluster_points = umap_df.loc[umap_df['Cluster'] == cluster_id, ['UMAP1', 'UMAP2']].values
        # ConvexHull needs >= 3 points spanning a plane; Qhull raises on
        # collinear/coincident points, so skip such clusters.
        if len(cluster_points) < 3 or np.linalg.matrix_rank(cluster_points - cluster_points.mean(axis=0)) < 2:
            continue
        hull = ConvexHull(cluster_points)
        hull_points = cluster_points[hull.vertices]
        plt.fill(hull_points[:, 0], hull_points[:, 1], alpha=0.2, label=f'Cluster {cluster_id} boundary')

    plt.title(f'UMAP with Cluster Boundaries (DBSCAN) — {title_suffix}')
    plt.legend(title='Cluster', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()


# Drop the library control, then split the remaining variants by enrichment:
# strictly enriched (ER > 1) vs strictly between 0 and 1.
# NOTE(review): ER exactly 1 (and ER <= 0) falls in neither subset.
df_filtered_no_ctrl = df_filtered.loc[df_filtered['immunization'] != 'Library_ctrl']
print(f"Data after removing 'Library_ctrl': {df_filtered_no_ctrl.shape}")

df_er_gt1 = df_filtered_no_ctrl.loc[df_filtered_no_ctrl['Enrichment_Ratio'].gt(1)]
print(f"Subset ER > 1 size: {df_er_gt1.shape}")

df_er_0to1 = df_filtered_no_ctrl.loc[
    df_filtered_no_ctrl['Enrichment_Ratio'].gt(0) & df_filtered_no_ctrl['Enrichment_Ratio'].lt(1)
]
print(f"Subset 0 < ER < 1 size: {df_er_0to1.shape}")

plot_umap_with_clusters(df_er_gt1, 'ER > 1')
plot_umap_with_clusters(df_er_0to1, '0 < ER < 1')


In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN
from scipy.spatial import ConvexHull
import numpy as np
import umap

# Exclude the library control group ('Library_ctrl')
df_filtered_no_ctrl = df_filtered.loc[df_filtered['immunization'] != 'Library_ctrl']

# Barcode x variant matrix: rows = (immunization, barcode), columns =
# (Spike position, amino acid), cells = enrichment ratio (0 when absent)
pivot = df_filtered_no_ctrl.pivot_table(
    values='Enrichment_Ratio',
    index=['immunization', 'barcode'],
    columns=['Spike_AS_Position', 'Amino_Acid'],
    fill_value=0,
)

# Move the index levels back into columns, then split metadata from features
pivot_reset = pivot.reset_index()
meta = pivot_reset[['immunization', 'barcode']]
X = pivot_reset.drop(columns=['immunization', 'barcode']).values

# --- PCA ---
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)
pca_df = pd.DataFrame(X_pca, columns=['PC1', 'PC2']).assign(
    immunization=meta['immunization'].values,
    barcode=meta['barcode'].values,
)

# Plot and save PCA (written to the current working directory)
plt.figure(figsize=(7, 6))
sns.scatterplot(
    data=pca_df, x='PC1', y='PC2',
    hue='immunization', palette="bright",
    s=100, marker='o', edgecolor='black', linewidth=0.8,
)
plt.title('PCA of Enrichment Profiles (Each Point = 1 Barcode)')
plt.legend(title='Immunization', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.savefig('PCA_Enrichment_Profiles.png', dpi=300)
plt.show()

print(f"PCA explained variance: PC1 = {pca.explained_variance_ratio_[0]:.2%}, PC2 = {pca.explained_variance_ratio_[1]:.2%}")

# --- UMAP: non-linear 2-D embedding of the barcode x variant matrix ---
reducer = umap.UMAP(n_components=2, random_state=42)
X_umap = reducer.fit_transform(X)
umap_df = pd.DataFrame(X_umap, columns=['UMAP1', 'UMAP2']).assign(
    immunization=meta['immunization'].values,
    barcode=meta['barcode'].values,
)

# --- Clustering using DBSCAN on the embedding (label -1 = noise) ---
clustering = DBSCAN(eps=0.5, min_samples=5).fit(X_umap)
cluster_labels = clustering.labels_
umap_df['Cluster'] = cluster_labels

# Save the barcode -> cluster assignment table
umap_df.loc[:, ['barcode', 'immunization', 'Cluster']].to_csv('UMAP_Clusters.csv', index=False)

# Plot UMAP with clusters
plt.figure(figsize=(8, 7))
sns.scatterplot(
    data=umap_df,
    x='UMAP1', y='UMAP2',
    hue='Cluster',
    palette='tab10',
    s=100,
    marker='o',
    edgecolor='black',
    linewidth=0.8,
    legend='full'
)

# Draw convex hulls around clusters (noise, label -1, is skipped)
for cluster_id in sorted(umap_df['Cluster'].unique()):
    if cluster_id == -1:
        continue
    cluster_points = umap_df.loc[umap_df['Cluster'] == cluster_id, ['UMAP1', 'UMAP2']].values
    # ConvexHull needs >= 3 points that span a plane; collinear or coincident
    # points make Qhull raise, so skip such clusters.
    if len(cluster_points) < 3 or np.linalg.matrix_rank(cluster_points - cluster_points.mean(axis=0)) < 2:
        continue
    hull = ConvexHull(cluster_points)
    hull_points = cluster_points[hull.vertices]
    plt.fill(hull_points[:, 0], hull_points[:, 1], alpha=0.2, label=f'Cluster {cluster_id} boundary')

plt.title('UMAP with Cluster Boundaries (DBSCAN)')
plt.legend(title='Cluster', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.savefig('UMAP_DBSCAN_Clusters.png', dpi=300)
plt.show()


In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np

# Sum the enrichment ratio over duplicate rows of each
# (position, amino acid, immunization, barcode) combination.
# NOTE(review): the following cell averages ('mean') instead — confirm which is intended.
df_logo_agg = (
    df_total
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False)
    .agg({'Enrichment_Ratio': 'sum'})
)

# Map an enrichment ratio to a colour on the 'Reds' colormap (light = low,
# dark red = high) using a log scale.
def enrichment_color(enrichment, max_enrichment=3000, cmap=None):
    """Return a colour for ``enrichment`` on a log10(ER + 1) scale.

    The value log10(enrichment + 1) is divided by log10(max_enrichment + 1) and
    clipped to [0, 1] before being passed to the colormap, so ratios at or above
    ``max_enrichment`` saturate.

    Args:
        enrichment: enrichment ratio (scalar or array-like).
        max_enrichment: ratio mapped to the top of the colormap.
        cmap: callable colormap; defaults to ``plt.cm.Reds``.

    Returns:
        Whatever ``cmap`` returns for the normalised value (an RGBA tuple for a
        matplotlib colormap and scalar input).
    """
    if cmap is None:
        cmap = plt.cm.Reds
    log_enrichment = np.log10(enrichment + 1)  # +1 keeps ER = 0 at the bottom of the scale
    norm_value = np.clip(log_enrichment / np.log10(max_enrichment + 1), 0, 1)
    return cmap(norm_value)

# Normalise amino-acid letters, drop stop codons ('*') and keep enriched variants
df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg.loc[df_logo_agg['Amino_Acid'] != "*"]
df_filtered = df_logo_agg.loc[df_logo_agg['Enrichment_Ratio'] > 1]

# Per (position, AA, immunization, barcode): row count (letter height) and
# largest enrichment (colour).
# NOTE(review): df_logo_agg is already aggregated on these same keys, so
# Sample_Count is always 1 here.
df_combined = (
    df_filtered
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'])
    .agg(Sample_Count=('Enrichment_Ratio', 'count'),
         Max_Enrichment=('Enrichment_Ratio', 'max'))
    .reset_index()
)

df_combined['color'] = df_combined['Max_Enrichment'].apply(enrichment_color)
df_combined['site_label'] = df_combined['Amino_Acid'] + "_" + df_combined['Spike_AS_Position'].astype(str)

# Restrict to Spike positions above 365
df_combined = df_combined.loc[df_combined['Spike_AS_Position'] > 365]

# Output directory for the per-barcode logo plots
output_dir = "barcode_logoplots"
os.makedirs(output_dir, exist_ok=True)

# One occurrence-based logo plot per barcode (title includes the immunization)
for barcode in df_combined['barcode'].unique():
    df_barcode = df_combined[df_combined['barcode'] == barcode]
    if df_barcode.empty:
        continue

    immunization = df_barcode['immunization'].iloc[0]

    print(f"Generating plot for barcode {barcode} ({immunization})...")

    fig, ax = dmslogo.draw_logo(
        df_barcode,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Sample_Count",
        color_col="color",
        title=f"{immunization} - {barcode} logoplot (Occurrence-based)",
        addbreaks=True
    )

    fig.set_size_inches(45, 4)

    # The colorbar must use the same normalisation as enrichment_color:
    # log10(ER + 1) mapped linearly from 0 to log10(3000 + 1).
    sm = plt.cm.ScalarMappable(cmap="Reds", norm=plt.Normalize(vmin=0, vmax=np.log10(3000 + 1)))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('log10(Enrichment Ratio + 1)', rotation=270, labelpad=15)

    # Save BEFORE showing: with the inline backend plt.show() releases the
    # figure, and a later plt.savefig writes an empty image.
    safe_barcode = str(barcode).replace("/", "_").replace("\\", "_")  # Sanitize filename
    plot_filename = os.path.join(output_dir, f"{safe_barcode}_logoplot.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per barcode


## Occurrence-based logo plotting counts how often each variant appears across the single-droplet (single-antibody) repertoire; letter height represents the number of antibodies that bind that variant.

In [None]:
#Colored by Enrichment ratio, Y axis is how many samples support it

import random
import matplotlib.pyplot as plt
import os
import numpy as np

# Average the enrichment ratio over duplicate rows of each
# (position, amino acid, immunization, barcode) combination
df_logo_agg = (
    df_total
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False)
    .agg({'Enrichment_Ratio': 'mean'})
)

# Map an enrichment ratio to a colour on the 'Reds' colormap (light = low,
# dark red = high) using a log scale.
def enrichment_color(enrichment, max_enrichment=3000, cmap=None):
    """Return a colour for ``enrichment`` on a log10(ER + 1) scale.

    The value log10(enrichment + 1) is divided by log10(max_enrichment + 1) and
    clipped to [0, 1] before being passed to the colormap, so ratios at or above
    ``max_enrichment`` saturate.

    Args:
        enrichment: enrichment ratio (scalar or array-like).
        max_enrichment: ratio mapped to the top of the colormap.
        cmap: callable colormap; defaults to ``plt.cm.Reds``.

    Returns:
        Whatever ``cmap`` returns for the normalised value (an RGBA tuple for a
        matplotlib colormap and scalar input).
    """
    if cmap is None:
        cmap = plt.cm.Reds
    log_enrichment = np.log10(enrichment + 1)  # +1 keeps ER = 0 at the bottom of the scale
    norm_value = np.clip(log_enrichment / np.log10(max_enrichment + 1), 0, 1)
    return cmap(norm_value)

df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()

df_logo_agg = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]  # drop stop codons

# Filter: only keep amino acids with Enrichment_Ratio > 1
df_filtered = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1]

# Per (site, amino acid, immunization): number of barcodes carrying the enriched
# variant (letter height) and the largest enrichment among them (colour).
# Both come from a single groupby, so the columns are aligned by key rather than
# by relying on two separate groupbys producing rows in the same order.
df_combined = (
    df_filtered
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .agg(Sample_Count=('Enrichment_Ratio', 'count'),
         Max_Enrichment=('Enrichment_Ratio', 'max'))
    .reset_index()
)
df_combined['color'] = df_combined['Max_Enrichment'].apply(enrichment_color)

df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

df_agg = df_total.copy()
df_agg['Amino_Acid'] = df_agg['Amino_Acid'].str.upper()
df_agg = df_agg[df_agg['Amino_Acid'] != "*"]  # remove stop codons

# Only plot a fixed set of Spike sites of interest
sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
df_combined = df_combined[df_combined['Spike_AS_Position'].isin(sites_to_show)]

# Create a directory to save the PNG files (if it doesn't already exist)
output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

# All per-immunization logo plots are written here.
# NOTE(review): hard-coded absolute local path — consider a configurable directory.
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Generate and save one occurrence-based logo plot per immunization group
for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    # Boolean mask instead of query(): no quoting problems with group names
    df_immunization = df_combined[df_combined['immunization'] == immunization]

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Sample_Count",  # number of barcodes, not enrichment
        color_col="color",
        title=f"{immunization} logoplot (Occurrence-based)",
        addbreaks=True
    )

    # The colorbar must use the same normalisation as enrichment_color:
    # log10(ER + 1) mapped linearly from 0 to log10(3000 + 1).
    sm = plt.cm.ScalarMappable(cmap="Reds", norm=plt.Normalize(vmin=0, vmax=np.log10(3000 + 1)))
    sm.set_array([])  # Required for the colorbar to work
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('log10(Enrichment Ratio + 1)', rotation=270, labelpad=15)

    ax.set_ylabel('IgG Secreting cell [n]')

    # One file per immunization, saved before showing. A name built from
    # `barcode` would reuse a variable left over from another cell and overwrite
    # the same file for every group.
    plot_filename = os.path.join(output_dir, f"{immunization}_logop.png")
    fig.savefig(plot_filename, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per group


In [None]:
#Colored by Enrichment ratio, Y axis is how many samples support it
import random
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.ticker as ticker
import pandas as pd
import seaborn as sns


# -----------------------------
# HEIGHT FUNCTION
# -----------------------------

def enrichment_height(enrichment, epsilon=1e-8, max_cap=1e8):
    """
    Map enrichment ratio(s) to logo letter heights on a log10 scale.

    The input is clipped into [epsilon, max_cap] first, so a ratio of 0 (or
    below) cannot produce -inf and extreme outliers cannot produce unbounded
    heights. Ratios above 1 give positive heights, ratios below 1 negative ones.

    Args:
        enrichment (float or array-like): Enrichment ratio(s).
        epsilon (float): Lower clip bound; the smallest height is log10(epsilon).
        max_cap (float): Upper clip bound; the largest height is log10(max_cap).

    Returns:
        float or np.ndarray: log10 of the clipped enrichment.
    """
    return np.log10(np.clip(enrichment, epsilon, max_cap))


# Colormap for the "% of single-droplet repertoire" letter colouring.
# Defined here because it is used below; previously it only existed after a
# later cell had run, so this cell failed on a fresh kernel.
orange_cmap = plt.get_cmap('YlOrBr')

# -----------------------------
# Clean enrichment ratios
# -----------------------------
# Coerce to numeric (unparseable strings -> NaN), then keep only finite,
# strictly positive ratios so the log transform is defined.
df_total['Enrichment_Ratio'] = pd.to_numeric(df_total['Enrichment_Ratio'], errors='coerce')
df_total = df_total[
    df_total['Enrichment_Ratio'].notna()
    & np.isfinite(df_total['Enrichment_Ratio'])
    & (df_total['Enrichment_Ratio'] > 0)
]

# Upper-case amino acids BEFORE grouping/merging so the join keys below
# match on both sides, and drop stop codons (they would otherwise get no
# sample count and therefore an invalid NaN colour).
df_total = df_total.assign(Amino_Acid=df_total['Amino_Acid'].str.upper())
df_total = df_total[df_total['Amino_Acid'] != "*"]
df_agg = df_total.copy()

# -----------------------------
# Per-barcode aggregation and counts
# -----------------------------
# Mean enrichment per (site, amino acid, immunization, barcode).
df_logo_agg = df_total.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False
).agg({'Enrichment_Ratio': 'mean'})

# Total number of unique barcodes per immunization, computed from the table
# built just above (not from a df_logo_agg left over from an earlier cell).
barcode_counts = (
    df_logo_agg.groupby('immunization')['barcode']
    .nunique()
    .reset_index(name='Total_Barcodes')
)

# Number of barcodes in which each (site, amino acid) was observed.
sample_counts = (
    df_logo_agg.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .size()
    .reset_index(name='Sample_Count')
)

# -----------------------------
# Combine, then add letter height and colour
# -----------------------------
# NOTE(review): drop_duplicates keeps the FIRST raw row per (site, amino
# acid, immunization), so Enrichment_Ratio here is a single barcode's value,
# not an aggregate across barcodes. Confirm this is intended.
df_combined = (
    df_agg
    .merge(sample_counts, on=['Spike_AS_Position', 'Amino_Acid', 'immunization'], how='left')
    .drop_duplicates(subset=['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .merge(barcode_counts, on='immunization', how='left')
)

# Percentage of the immunization's barcodes that carry this mutation.
df_combined['Sample_Fraction'] = 100 * df_combined['Sample_Count'] / df_combined['Total_Barcodes']
df_combined['letter_height'] = df_combined['Enrichment_Ratio'].apply(enrichment_height)
df_combined['color'] = df_combined['Sample_Fraction'].apply(lambda x: orange_cmap(x / 100))
df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# Restrict the plot to a fixed panel of RBD sites.
sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
df_combined = df_combined[df_combined['Spike_AS_Position'].isin(sites_to_show)]

# Create a directory to save the PNG files (if it doesn't already exist)
output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

# Export folder for the plots; created once instead of on every iteration.
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Generate and save one logo plot per immunization group.
for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    # Boolean mask rather than DataFrame.query with an interpolated string,
    # which breaks if a group name contains a double quote.
    df_immunization = df_combined[df_combined['immunization'] == immunization]
    print(df_immunization[['Spike_AS_Position', 'Amino_Acid', 'letter_height', 'Sample_Count', 'Sample_Fraction']])
    print(df_immunization['letter_height'].isna().sum(), "NaN letter_heights")

    # Letter height = log10 enrichment; letter colour = % of barcodes with the mutation.
    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="letter_height",
        color_col="color",
        title=f"{immunization}",
        addbreaks=True
    )

    ax.set_ylabel('Log10 AB Escape - Binding', fontsize=13)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(2))    # major tick every 2 log10 units
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.5))  # minor tick every 0.5 log10 units

    # The colorbar must use the same colormap object that coloured the letters.
    # Passing the string "orange_cmap" fails: no colormap is registered under that name.
    sm = plt.cm.ScalarMappable(cmap=orange_cmap, norm=plt.Normalize(vmin=0, vmax=100))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('Fraction % of \n Single-droplet repertoire', rotation=270, labelpad=35)

    # The high-res copy is named per immunization. The previous name used a stale
    # global `barcode`, so every iteration overwrote the same file.
    file_path = os.path.join(output_dir, f"{immunization}_logop_300dpi.png")
    plot_filename = os.path.join(output_dir, f"{immunization}_logop.png")

    fig.savefig(file_path, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {file_path}")
    ax.set_xlabel('')
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    # Show the plot in the notebook
    plt.show()


In [None]:
# Letter height = log2 enrichment ratio; letter colour (yellow-orange) = % of barcodes in which the mutation was observed
import random
import matplotlib.pyplot as plt
import os
import numpy as np

# -----------------------------
# HEIGHT FUNCTION
# -----------------------------
# Sequential yellow-orange-brown colormap for the "fraction of repertoire"
# letter colouring and its colorbar. Uses plt.get_cmap because
# matplotlib.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9.
orange_cmap = plt.get_cmap('YlOrBr')

def enrichment_height(enrichment):
    """Signed log2 letter height for a single (scalar) enrichment ratio.

    A ratio >= 1 maps to log2(ratio). A ratio below 1 maps to
    -log2(1 / ratio), i.e. a negative height whose magnitude equals the
    height of the reciprocal ratio. A ratio of 0 is not handled; the cell
    removes non-positive ratios before applying this.
    """
    return np.log2(enrichment) if enrichment >= 1 else -np.log2(1 / enrichment)


# -----------------------------
# Clean enrichment ratios
# -----------------------------
# Coerce to numeric (unparseable strings -> NaN), then keep only finite,
# strictly positive ratios so the log transform is defined.
df_total['Enrichment_Ratio'] = pd.to_numeric(df_total['Enrichment_Ratio'], errors='coerce')
df_total = df_total[
    df_total['Enrichment_Ratio'].notna()
    & np.isfinite(df_total['Enrichment_Ratio'])
    & (df_total['Enrichment_Ratio'] > 0)
]

# Upper-case amino acids BEFORE grouping/merging so the merge keys below
# match on both sides; drop stop codons.
df_total = df_total.assign(Amino_Acid=df_total['Amino_Acid'].str.upper())
df_total = df_total[df_total['Amino_Acid'] != "*"]
df_agg = df_total.copy()

# -----------------------------
# Per-barcode aggregation and counts
# -----------------------------
# Mean enrichment per (site, amino acid, immunization, barcode).
df_logo_agg = df_total.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False
).agg({'Enrichment_Ratio': 'mean'})

# Total number of unique barcodes per immunization, computed from the table
# built just above (not from a df_logo_agg left over from an earlier cell).
barcode_counts = (
    df_logo_agg.groupby('immunization')['barcode']
    .nunique()
    .reset_index(name='Total_Barcodes')
)

# Number of barcodes in which each (site, amino acid) was observed.
sample_counts = (
    df_logo_agg.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .size()
    .reset_index(name='Sample_Count')
)

# -----------------------------
# Combine, then add letter height and colour
# -----------------------------
# NOTE(review): drop_duplicates keeps the FIRST raw row per (site, amino
# acid, immunization), so Enrichment_Ratio here is a single barcode's value,
# not an aggregate across barcodes. Confirm this is intended.
df_combined = (
    df_agg
    .merge(sample_counts, on=['Spike_AS_Position', 'Amino_Acid', 'immunization'], how='left')
    .drop_duplicates(subset=['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .merge(barcode_counts, on='immunization', how='left')
)

# Percentage of the immunization's barcodes that carry this mutation.
df_combined['Sample_Fraction'] = 100 * df_combined['Sample_Count'] / df_combined['Total_Barcodes']
df_combined['letter_height'] = df_combined['Enrichment_Ratio'].apply(enrichment_height)
df_combined['color'] = df_combined['Sample_Fraction'].apply(lambda x: orange_cmap(x / 100))
df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# Restrict the plot to a fixed panel of RBD sites.
sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
df_combined = df_combined[df_combined['Spike_AS_Position'].isin(sites_to_show)]

# Create a directory to save the PNG files (if it doesn't already exist)
output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

import matplotlib.ticker as ticker  # tick locators below; not imported elsewhere in this cell


def separate_stacked_heights(df):
    """Order letters for stacking and record each letter's stack offset.

    Rows are sorted by Spike_AS_Position and then by decreasing letter_height.
    Non-negative and negative heights are then stacked independently per
    position. Rows whose height is NaN are dropped, because they are neither
    >= 0 nor < 0.

    Returns:
        pd.DataFrame: The kept rows plus a ``stack_bottom`` column (running sum
        of the heights stacked before each letter at the same position),
        sorted by position and then stack_bottom.
    """
    df_sorted = df.sort_values(by=['Spike_AS_Position', 'letter_height'], ascending=[True, False])
    df_positive = df_sorted[df_sorted['letter_height'] >= 0].copy()
    df_negative = df_sorted[df_sorted['letter_height'] < 0].copy()
    df_positive['stack_bottom'] = df_positive.groupby('Spike_AS_Position')['letter_height'].cumsum() - df_positive['letter_height']
    df_negative['stack_bottom'] = df_negative.groupby('Spike_AS_Position')['letter_height'].cumsum() - df_negative['letter_height']
    return pd.concat([df_positive, df_negative], axis=0).sort_values(by=['Spike_AS_Position', 'stack_bottom'])


# Export folder for the plots; created once instead of on every iteration.
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Generate and save one logo plot per immunization group.
for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    # Boolean mask rather than DataFrame.query with an interpolated string.
    df_immunization = df_combined[df_combined['immunization'] == immunization]
    heights = df_immunization['letter_height']

    # Diagnostics (each printed once).
    print(df_immunization[['Spike_AS_Position', 'Amino_Acid', 'letter_height', 'Sample_Count', 'Sample_Fraction']])
    print(f"Positions range from {df_immunization['Spike_AS_Position'].min()} to {df_immunization['Spike_AS_Position'].max()}")
    print(f"Unique amino acids: {sorted(df_immunization['Amino_Acid'].unique())}")
    print("Counts per position:")
    print(df_immunization.groupby('Spike_AS_Position')['Amino_Acid'].count())
    print(f"Letter height range: min {heights.min():.3f}, max {heights.max():.3f}")
    print(heights.isna().sum(), "NaN letter heights")

    # Stack positive and negative heights separately (NaN heights are dropped).
    df_immunization = separate_stacked_heights(df_immunization)

    # Letter height = log2 enrichment; colour = % of barcodes carrying the mutation.
    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="letter_height",
        color_col="color",
        title=f"{immunization}",
        addbreaks=True
    )
    ax.set_title('')
    ax.set_xlabel('')
    # Heights come from the log2 enrichment_height defined in this cell.
    ax.set_ylabel('Log2 AB Escape - Binding')
    ax.yaxis.set_major_locator(ticker.MultipleLocator(4))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(1))

    # The colorbar uses the same colormap and 0-100 normalisation as the letter colours.
    sm = plt.cm.ScalarMappable(cmap=orange_cmap, norm=plt.Normalize(vmin=0, vmax=100))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('Fraction % of \n Single-droplet repertoire', rotation=270, labelpad=35)

    # The high-res copy is named per immunization. The previous name used a stale
    # global `barcode`, so every iteration overwrote the same file.
    # NOTE(review): another plotting cell in this notebook also writes
    # "{immunization}_logop1.png"; whichever runs last wins.
    file_path = os.path.join(output_dir, f"{immunization}_logop1_300dpi.png")
    plot_filename = os.path.join(output_dir, f"{immunization}_logop1.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {file_path}")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    # Per-position breakdown of the plotted heights, plus running and total heights.
    position_summary = (
        df_immunization[['Spike_AS_Position', 'Amino_Acid', 'letter_height']]
        .sort_values(by=['Spike_AS_Position', 'letter_height'], ascending=[True, False])
    )
    position_summary['Cumulative_Height'] = position_summary.groupby('Spike_AS_Position')['letter_height'].cumsum()
    total_heights = (
        df_immunization.groupby('Spike_AS_Position')['letter_height']
        .sum()
        .reset_index(name='Total_Stacked_Height')
    )

    print(f"\n====== Immunization: {immunization} ======")
    print("Breakdown of letter heights at each position:")
    print(position_summary.to_string(index=False))
    print("\nTotal stacked height per position:")
    print(total_heights.to_string(index=False))

    # Show the plot in the notebook
    plt.show()


In [None]:
# Letter height = log10 enrichment ratio; letter colour (green) = % of barcodes in which the mutation was observed
import random
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.ticker as ticker
import pandas as pd
import seaborn as sns


# -----------------------------
# HEIGHT FUNCTION
# -----------------------------

def enrichment_height(enrichment, epsilon=1e-8, max_cap=1e8):
    """
    Clipped log10 transform used as logo letter height.

    Args:
        enrichment (float or array-like): Enrichment ratio(s).
        epsilon (float): Floor applied before the log; avoids log10(0).
        max_cap (float): Ceiling applied before the log; bounds huge ratios.

    Returns:
        float or np.ndarray: log10(min(max(enrichment, epsilon), max_cap)).
    """
    floored = np.maximum(enrichment, epsilon)
    bounded = np.minimum(floored, max_cap)
    return np.log10(bounded)


# -----------------------------
# Clean enrichment ratios
# -----------------------------
# Coerce to numeric (unparseable strings -> NaN), then keep only finite,
# strictly positive ratios so the log transform is defined.
df_total['Enrichment_Ratio'] = pd.to_numeric(df_total['Enrichment_Ratio'], errors='coerce')
df_total = df_total[
    df_total['Enrichment_Ratio'].notna()
    & np.isfinite(df_total['Enrichment_Ratio'])
    & (df_total['Enrichment_Ratio'] > 0)
]

# Upper-case amino acids BEFORE grouping/merging so the merge keys below
# match on both sides; drop stop codons.
df_total = df_total.assign(Amino_Acid=df_total['Amino_Acid'].str.upper())
df_total = df_total[df_total['Amino_Acid'] != "*"]
df_agg = df_total.copy()

# -----------------------------
# Per-barcode aggregation and counts
# -----------------------------
# Mean enrichment per (site, amino acid, immunization, barcode).
df_logo_agg = df_total.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False
).agg({'Enrichment_Ratio': 'mean'})

# Total number of unique barcodes per immunization, computed from the table
# built just above (not from a df_logo_agg left over from an earlier cell).
barcode_counts = (
    df_logo_agg.groupby('immunization')['barcode']
    .nunique()
    .reset_index(name='Total_Barcodes')
)

# Number of barcodes in which each (site, amino acid) was observed.
sample_counts = (
    df_logo_agg.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .size()
    .reset_index(name='Sample_Count')
)

# -----------------------------
# Combine, then add letter height and colour
# -----------------------------
# NOTE(review): drop_duplicates keeps the FIRST raw row per (site, amino
# acid, immunization), so Enrichment_Ratio here is a single barcode's value,
# not an aggregate across barcodes. Confirm this is intended.
df_combined = (
    df_agg
    .merge(sample_counts, on=['Spike_AS_Position', 'Amino_Acid', 'immunization'], how='left')
    .drop_duplicates(subset=['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .merge(barcode_counts, on='immunization', how='left')
)

# Percentage of the immunization's barcodes that carry this mutation.
df_combined['Sample_Fraction'] = 100 * df_combined['Sample_Count'] / df_combined['Total_Barcodes']
df_combined['letter_height'] = df_combined['Enrichment_Ratio'].apply(enrichment_height)
df_combined['color'] = df_combined['Sample_Fraction'].apply(lambda x: plt.cm.Greens(x / 100))
df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# Restrict the plot to a fixed panel of RBD sites.
sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
df_combined = df_combined[df_combined['Spike_AS_Position'].isin(sites_to_show)]

# Create a directory to save the PNG files (if it doesn't already exist)
output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

def separate_stacked_heights(df):
    """Order letters for stacking and record each letter's stack offset.

    Rows are sorted by Spike_AS_Position and then by decreasing letter_height.
    Non-negative and negative heights are then stacked independently per
    position. Rows whose height is NaN are dropped, because they are neither
    >= 0 nor < 0.

    Returns:
        pd.DataFrame: The kept rows plus a ``stack_bottom`` column (running sum
        of the heights stacked before each letter at the same position),
        sorted by position and then stack_bottom.
    """
    df_sorted = df.sort_values(by=['Spike_AS_Position', 'letter_height'], ascending=[True, False])
    df_positive = df_sorted[df_sorted['letter_height'] >= 0].copy()
    df_negative = df_sorted[df_sorted['letter_height'] < 0].copy()
    df_positive['stack_bottom'] = df_positive.groupby('Spike_AS_Position')['letter_height'].cumsum() - df_positive['letter_height']
    df_negative['stack_bottom'] = df_negative.groupby('Spike_AS_Position')['letter_height'].cumsum() - df_negative['letter_height']
    return pd.concat([df_positive, df_negative], axis=0).sort_values(by=['Spike_AS_Position', 'stack_bottom'])


# Export folder for the plots; created once instead of on every iteration.
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Generate and save one logo plot per immunization group.
for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    # Boolean mask rather than DataFrame.query with an interpolated string.
    df_immunization = df_combined[df_combined['immunization'] == immunization]
    heights = df_immunization['letter_height']

    # Diagnostics (each printed once).
    print(df_immunization[['Spike_AS_Position', 'Amino_Acid', 'letter_height', 'Sample_Count', 'Sample_Fraction']])
    print(f"Positions range from {df_immunization['Spike_AS_Position'].min()} to {df_immunization['Spike_AS_Position'].max()}")
    print(f"Unique amino acids: {sorted(df_immunization['Amino_Acid'].unique())}")
    print("Counts per position:")
    print(df_immunization.groupby('Spike_AS_Position')['Amino_Acid'].count())
    print(f"Letter height range: min {heights.min():.3f}, max {heights.max():.3f}")
    print(heights.isna().sum(), "NaN letter heights")

    # Stack positive and negative heights separately (NaN heights are dropped).
    df_immunization = separate_stacked_heights(df_immunization)

    # Letter height = log10 enrichment; colour = % of barcodes carrying the mutation.
    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="letter_height",
        color_col="color",
        title=f"{immunization}",
        addbreaks=True
    )
    ax.set_title('')
    ax.set_xlabel('')
    ax.set_ylabel('Log10 AB Escape - Binding', fontsize=13)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(0.5))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.25))

    # The colorbar uses the same colormap and 0-100 normalisation as the letter colours.
    sm = plt.cm.ScalarMappable(cmap="Greens", norm=plt.Normalize(vmin=0, vmax=100))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('Fraction % of \n Single-droplet repertoire', rotation=270, labelpad=35)

    # The high-res copy is named per immunization. The previous name used a stale
    # global `barcode`, so every iteration overwrote the same file.
    # NOTE(review): another plotting cell in this notebook also writes
    # "{immunization}_logop1.png"; whichever runs last wins.
    file_path = os.path.join(output_dir, f"{immunization}_logop1_300dpi.png")
    plot_filename = os.path.join(output_dir, f"{immunization}_logop1.png")
    fig.savefig(file_path, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {file_path}")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    # Per-position breakdown of the plotted heights, plus running and total heights.
    position_summary = (
        df_immunization[['Spike_AS_Position', 'Amino_Acid', 'letter_height']]
        .sort_values(by=['Spike_AS_Position', 'letter_height'], ascending=[True, False])
    )
    position_summary['Cumulative_Height'] = position_summary.groupby('Spike_AS_Position')['letter_height'].cumsum()
    total_heights = (
        df_immunization.groupby('Spike_AS_Position')['letter_height']
        .sum()
        .reset_index(name='Total_Stacked_Height')
    )

    print(f"\n====== Immunization: {immunization} ======")
    print("Breakdown of letter heights at each position:")
    print(position_summary.to_string(index=False))
    print("\nTotal stacked height per position:")
    print(total_heights.to_string(index=False))

    # Show the plot in the notebook
    plt.show()


In [None]:
#FINAL PLOT with inverted coloring

In [None]:
# Plots for single barcodes (Supplementary Information)

In [None]:
# Per-barcode SI logo plots: letter height = log10 enrichment ratio; letters shaded in grayscale by amino-acid identity
import random
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.ticker as ticker
import pandas as pd
from Bio import SeqIO
import altair as alt

# Reload the raw DMS summary from disk and clean it in one chained expression.
file_path = r'/Users/lucaschlotheuber/Desktop/ETH/summary_DMS_cleaned.xlsx'
df_total = (
    pd.read_excel(
        file_path,
        usecols=[
            "DMS_RBD_AS_position", "Spike_AS_Position", "Count_of_Base",
            "Amino_Acid", "Type_of_Mutation", "Enrichment_Ratio",
            "barcode", "immunization", "condition", "Total_Reads",
        ],
    )
    # Spike numbering in the sheet is offset by +5 (e.g. 336 -> 331).
    .assign(Spike_AS_Position=lambda d: d["Spike_AS_Position"] - 5)
    .dropna(subset=['Enrichment_Ratio', 'Amino_Acid'])
    .loc[lambda d: d["Total_Reads"] > 100]             # drop low-read rows
    .loc[lambda d: d["Type_of_Mutation"] != 'SYNOM']   # drop synonymous mutations
    .loc[lambda d: d["Enrichment_Ratio"] > 0]          # log transform needs > 0
    .loc[lambda d: d["Amino_Acid"] != '*']             # drop stop codons
)

# -----------------------------
# HEIGHT FUNCTION
def enrichment_height(enrichment, epsilon=1e-6, max_cap=1e6):
    """Clipped log10 of an enrichment ratio (scalar or array).

    Clipping to [epsilon, max_cap] avoids log10(0) and, with the defaults,
    bounds heights to the range [-6, 6].
    """
    clipped = np.clip(enrichment, epsilon, max_cap)
    return np.log10(clipped)

# One row per (site, amino acid, immunization, barcode), keeping the first
# occurrence; only sites after 364 are retained, amino acids upper-cased.
df_logo_agg = (
    df_total
    .drop_duplicates(subset=['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'])
    .loc[lambda d: d['Spike_AS_Position'] > 364]
    .assign(Amino_Acid=lambda d: d['Amino_Acid'].str.upper())
)

# Fixed colour per amino acid; letters missing from this table are drawn black.
aa_colors = {
    'A':'#1f77b4','R':'#ff7f0e','N':'#2ca02c','D':'#d62728',
    'C':'#9467bd','Q':'#8c564b','E':'#e377c2','G':'#7f7f7f',
    'H':'#bcbd22','I':'#17becf','L':'#aec7e8','K':'#ffbb78',
    'M':'#98df8a','F':'#ff9896','P':'#c5b0d5','S':'#c49c94',
    'T':'#f7b6d2','W':'#dbdb8d','Y':'#9edae5','V':'#393b79'
}

# Plot-ready table: log10 letter height, identity colour and an "AA_site" label.
df_combined = df_logo_agg.assign(
    letter_height=lambda d: d['Enrichment_Ratio'].apply(enrichment_height),
    color=lambda d: d['Amino_Acid'].map(aa_colors).fillna('#000000'),
    site_label=lambda d: d["Amino_Acid"] + "_" + d["Spike_AS_Position"].astype(str),
)

# Output folder for the PNG files (relative to the working directory).
output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

# -----------------------------
# Stacking function
def separate_stacked_heights(df):
    """Split letters into non-negative and negative stacks and compute offsets.

    Rows are ordered by Spike_AS_Position and then by decreasing
    letter_height. Within each position, every stack gets a ``stack_bottom``
    equal to the summed heights of the letters before it in that stack. Rows
    with a NaN height fall in neither stack and are dropped.

    Returns:
        pd.DataFrame: Both stacks concatenated and sorted by
        (Spike_AS_Position, stack_bottom).
    """
    ordered = df.sort_values(by=['Spike_AS_Position', 'letter_height'], ascending=[True, False])
    stacks = []
    for keep in (ordered['letter_height'] >= 0, ordered['letter_height'] < 0):
        part = ordered[keep].copy()
        running_total = part.groupby('Spike_AS_Position')['letter_height'].cumsum()
        part['stack_bottom'] = running_total - part['letter_height']
        stacks.append(part)
    return pd.concat(stacks, axis=0).sort_values(by=['Spike_AS_Position', 'stack_bottom'])

# -----------------------------
# Plot per barcode
# Fixed seed so the randomly chosen SI barcodes are the same on every run.
si_rng = random.Random(0)

for immunization in df_combined['immunization'].unique():
    # Boolean mask rather than DataFrame.query with an interpolated string.
    df_immunization = df_combined[df_combined['immunization'] == immunization]

    # Only barcodes covering more than 10 distinct sites are worth plotting.
    eligible_barcodes = [
        bc for bc, sub_df in df_immunization.groupby('barcode')
        if sub_df['Spike_AS_Position'].nunique() > 10
    ]
    if not eligible_barcodes:
        continue

    sampled_barcodes = si_rng.sample(eligible_barcodes, min(12, len(eligible_barcodes)))

    for barcode in sampled_barcodes:
        # The boolean mask works for any barcode dtype. query('barcode == "..."')
        # would match nothing if barcodes are stored as numbers.
        df_barcode = df_immunization[
            (df_immunization['barcode'] == barcode)
            & df_immunization['Spike_AS_Position'].between(420, 520)
        ].copy()
        if df_barcode.empty:
            continue  # nothing in the 420-520 window

        # Grayscale per amino acid, alphabetical: black (0.0) to light gray (0.9).
        aa_list = sorted(df_barcode['Amino_Acid'].unique())
        gray_values = np.linspace(0.0, 0.9, len(aa_list))
        aa_gray_colors = {aa: str(gray) for aa, gray in zip(aa_list, gray_values)}
        df_barcode['color'] = df_barcode['Amino_Acid'].map(aa_gray_colors)

        df_barcode = separate_stacked_heights(df_barcode)
        print(f"Generating plot: {immunization} - {barcode}")

        fig, ax = dmslogo.draw_logo(
            df_barcode,
            x_col="Spike_AS_Position",
            letter_col="Amino_Acid",
            letter_height_col="letter_height",
            color_col="color",
            title="",
            addbreaks=True
        )
        width, height = fig.get_size_inches()
        fig.set_size_inches(width, height + 1)  # extra room for the two-line y label

        # NOTE(review): the label says "(median)", but each panel shows a single
        # barcode's Enrichment_Ratio (first row kept by drop_duplicates when the
        # table is built). Confirm the ratio in the source sheet is itself a median.
        ax.set_ylabel(
            "Log10 AB binding (median)\n$\\mathbf{\\Leftarrow}$ Enrichment $\\mathbf{\\Rightarrow}$",
            rotation=90, labelpad=15, ha='right', fontsize=14
        )
        ax.yaxis.set_label_coords(-0.025, 1)
        ax.yaxis.set_major_locator(ticker.MultipleLocator(2))   # major tick every 2 units
        ax.yaxis.set_minor_locator(ticker.MultipleLocator(1))   # minor tick every unit
        ax.tick_params(axis='y', which='major', length=8, labelsize=14)
        ax.tick_params(axis='y', which='minor', length=4, labelsize=10)

        fig_path = os.path.join(output_dir, f"logo_{immunization}_{barcode}_RandomAll.png")
        fig.savefig(fig_path, bbox_inches='tight', dpi=300)
        plt.show()


In [None]:
# Per-barcode logo plots: letter height = log10 enrichment ratio
import random
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.ticker as ticker
import pandas as pd
from Bio import SeqIO

# Load the DMS summary sheet (only the columns this analysis needs).
file_path = r'/Users/lucaschlotheuber/Desktop/ETH/summary_DMS_cleaned.xlsx'
df_total = pd.read_excel(
    file_path,
    usecols=[
        "DMS_RBD_AS_position", "Spike_AS_Position", "Count_of_Base",
        "Amino_Acid", "Type_of_Mutation", "Enrichment_Ratio",
        "barcode", "immunization", "condition", "Total_Reads",
    ],
)
# Spike numbering in the sheet is offset by +5 (e.g. 336 -> 331).
df_total["Spike_AS_Position"] = df_total["Spike_AS_Position"] - 5

# Keep rows with a ratio and amino acid that pass all quality filters:
# more than 100 reads, non-synonymous, positive ratio, not a stop codon.
df_total = df_total.dropna(subset=['Enrichment_Ratio', 'Amino_Acid'])
df_total = df_total[
    (df_total["Total_Reads"] > 100)
    & (df_total["Type_of_Mutation"] != 'SYNOM')
    & (df_total["Enrichment_Ratio"] > 0)
    & (df_total["Amino_Acid"] != '*')
]

# -----------------------------
# HEIGHT FUNCTION
def enrichment_height(enrichment, epsilon=1e-6, max_cap=1e6):
    """Return log10 of the enrichment ratio(s), clamped to [epsilon, max_cap] first.

    Clamping keeps zero/negative ratios and extreme outliers finite
    (result lies in [log10(epsilon), log10(max_cap)]).
    """
    return np.log10(np.clip(enrichment, epsilon, max_cap))

# One row per (site, amino acid, immunization, barcode) for Spike sites > 364.
# Built as a single chain of .loc/.assign so no step writes a column into a
# boolean-indexed view (which raised pandas' SettingWithCopyWarning).
# NOTE(review): drop_duplicates keeps the FIRST row per key; its Enrichment_Ratio
# is used as-is (duplicates are not averaged) — confirm this is intended.
df_logo_agg = (
    df_total
    .drop_duplicates(subset=['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'])
    .loc[lambda d: d['Spike_AS_Position'] > 364]
    .assign(Amino_Acid=lambda d: d['Amino_Acid'].str.upper())
    .assign(letter_height=lambda d: d['Enrichment_Ratio'].apply(enrichment_height))
)

# -----------------------------
# Color amino acids by identity (unknown letters fall back to black)
aa_colors = {
    'A':'#1f77b4','R':'#ff7f0e','N':'#2ca02c','D':'#d62728',
    'C':'#9467bd','Q':'#8c564b','E':'#e377c2','G':'#7f7f7f',
    'H':'#bcbd22','I':'#17becf','L':'#aec7e8','K':'#ffbb78',
    'M':'#98df8a','F':'#ff9896','P':'#c5b0d5','S':'#c49c94',
    'T':'#f7b6d2','W':'#dbdb8d','Y':'#9edae5','V':'#393b79'
}

# Add the letter colour and a "<AA>_<position>" site label
df_logo_agg = df_logo_agg.assign(
    color=lambda x: x['Amino_Acid'].map(aa_colors).fillna('#000000'),
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
)

# -----------------------------
# Stacking function (DMS-style)
def separate_stacked_heights(df):
    """Compute a per-letter `stack_bottom` for positive and negative stacks.

    Rows are ordered per Spike_AS_Position by decreasing letter_height. Letters
    with height >= 0 and letters with height < 0 are stacked separately; each
    letter's stack_bottom is the running sum of the heights placed before it at
    the same position within its own stack. Rows whose letter_height is NaN
    belong to neither stack and are dropped.

    Returns a new DataFrame sorted by (Spike_AS_Position, stack_bottom).
    """
    ordered = df.sort_values(by=['Spike_AS_Position', 'letter_height'], ascending=[True, False])
    stacks = [
        ordered[ordered['letter_height'] >= 0].copy(),
        ordered[ordered['letter_height'] < 0].copy(),
    ]
    for stack in stacks:
        running_total = stack.groupby('Spike_AS_Position')['letter_height'].cumsum()
        stack['stack_bottom'] = running_total - stack['letter_height']
    combined = pd.concat(stacks, axis=0)
    return combined.sort_values(by=['Spike_AS_Position', 'stack_bottom'])

# Create output directory
output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

# Barcodes are sampled at random. A dedicated seeded RNG makes the selection
# (and therefore the set of saved figures) reproducible across kernel restarts
# without touching the global `random` state.
BARCODE_SAMPLE_SEED = 0
MAX_BARCODES_PER_IMMUNIZATION = 12
barcode_rng = random.Random(BARCODE_SAMPLE_SEED)

# -----------------------------
# Plot per barcode
for immunization in df_logo_agg['immunization'].unique():
    # Boolean masks rather than DataFrame.query(f'... == "{x}"'): no breakage on
    # labels containing quotes, and non-string barcodes compare correctly.
    df_immunization = df_logo_agg[df_logo_agg['immunization'] == immunization]

    # Only barcodes with data at more than 10 distinct sites are plotted
    eligible_barcodes = [
        bc for bc, sub_df in df_immunization.groupby('barcode')
        if sub_df['Spike_AS_Position'].nunique() > 10
    ]
    if not eligible_barcodes:
        continue

    sampled_barcodes = barcode_rng.sample(
        eligible_barcodes, min(MAX_BARCODES_PER_IMMUNIZATION, len(eligible_barcodes))
    )

    for barcode in sampled_barcodes:
        # Restrict to Spike sites 420-520 (inclusive)
        in_window = df_immunization['Spike_AS_Position'].between(420, 520)
        df_barcode = df_immunization[(df_immunization['barcode'] == barcode) & in_window].copy()
        if df_barcode.empty:
            continue

        # Apply stacking
        df_barcode = separate_stacked_heights(df_barcode)

        fig, ax = dmslogo.draw_logo(
            df_barcode,
            x_col="Spike_AS_Position",
            letter_col="Amino_Acid",
            letter_height_col="letter_height",
            color_col="color",
            title="",
            addbreaks=True
        )
        # NOTE(review): the label says "median", but each logo shows a single
        # barcode's Enrichment_Ratio — confirm the wording.
        ax.set_ylabel(
            "Log10 AB binding (median)\n$\\mathbf{\\Leftarrow}$ Enrichment $\\mathbf{\\Rightarrow}$",
            rotation=90, labelpad=15, ha='right', fontsize=14
        )
        ax.yaxis.set_label_coords(-0.025, 1)
        ax.yaxis.set_major_locator(ticker.MultipleLocator(2))
        ax.yaxis.set_minor_locator(ticker.MultipleLocator(1))
        ax.tick_params(axis='y', which='major', length=8, labelsize=15)
        ax.tick_params(axis='y', which='minor', length=4, labelsize=15)

        fig_path = os.path.join(output_dir, f"logo_{immunization}_{barcode}New.png")
        fig.savefig(fig_path, bbox_inches='tight', dpi=300)
        plt.show()
        # Release the figure; many figures are created in this loop
        plt.close(fig)


In [None]:
#single droplet letter.

In [None]:
#Logoplots for Figure 1. Log10 median ER (Logoplots are separated)

In [None]:
#Colored by fraction of barcodes (samples) carrying each mutation; Y axis is log10 enrichment ratio
import random
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.ticker as ticker
import pandas as pd
import seaborn as sns
import matplotlib.colors as colors


# -----------------------------
# HEIGHT FUNCTION
# -----------------------------

def enrichment_height(enrichment, epsilon=1e-8, max_cap=1e8):
    """
    Convert enrichment ratio(s) to log10 letter heights.

    Values are bounded to [epsilon, max_cap] before taking the logarithm, so
    zero/negative ratios and extreme outliers still give finite heights.
    Ratios below 1 map to negative heights, ratios above 1 to positive ones.

    Args:
        enrichment (float or array-like): Enrichment ratio(s).
        epsilon (float): Lower bound applied before log10 (avoids log10(0)).
        max_cap (float): Upper bound applied before log10.

    Returns:
        float or np.ndarray: log10 of the bounded value(s).
    """
    bounded = np.clip(enrichment, epsilon, max_cap)
    heights = np.log10(bounded)
    return heights


# -----------------------------------------------------------------------------
# Build df_combined: one row per (site, amino acid, immunization) with
#   Enrichment_Ratio : median over barcodes of each barcode's mean ER
#   Sample_Count     : number of barcodes carrying that mutation
#   Sample_Fraction  : Sample_Count as % of all barcodes of that immunization
# Everything is derived from df_total in THIS cell, so re-running it (or running
# it on a fresh kernel) gives the same result.
# -----------------------------------------------------------------------------

# 1) Clean the per-row table: numeric, finite, strictly positive ER; more than
#    3 counts of the mutant base; upper-case amino acids; no stop codons.
df_total = df_total.assign(
    Enrichment_Ratio=pd.to_numeric(df_total['Enrichment_Ratio'], errors='coerce')
)
df_total = df_total[df_total["Count_of_Base"] > 3]
df_total = df_total[
    df_total['Enrichment_Ratio'].notna()
    & np.isfinite(df_total['Enrichment_Ratio'])
    & (df_total['Enrichment_Ratio'] > 0)
]
# Upper-case before any grouping/merging so every frame uses the same letters
df_total = df_total.assign(Amino_Acid=df_total['Amino_Acid'].str.upper())
df_total = df_total[df_total['Amino_Acid'] != "*"]
df_agg = df_total.copy()

# 2) One value per barcode: mean ER of that barcode's rows for each site/AA
site_keys = ['Spike_AS_Position', 'Amino_Acid', 'immunization']
df_logo_agg = df_total.groupby(site_keys + ['barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'mean'
})

# 3) Denominator for Sample_Fraction, computed from the same data as the
#    numerator (it previously came from whatever df_logo_agg a prior cell left).
barcode_counts = (
    df_logo_agg.groupby('immunization')['barcode']
    .nunique()
    .reset_index(name='Total_Barcodes')
)

# 4) Aggregate across barcodes. Previously drop_duplicates on the raw rows kept
#    a single arbitrary barcode's ER per site, not a summary across barcodes.
df_combined = (
    df_logo_agg
    .groupby(site_keys, as_index=False)
    .agg(
        Enrichment_Ratio=('Enrichment_Ratio', 'median'),
        Sample_Count=('barcode', 'nunique'),
    )
    .merge(barcode_counts, on='immunization', how='left')
)
df_combined['Sample_Fraction'] = (
    100 * df_combined['Sample_Count'] / df_combined['Total_Barcodes']
)
assert df_combined['Sample_Fraction'].between(0, 100).all(), "Sample_Fraction outside 0-100 %"

# 5) Letter height (log10 ER) and colour (reversed Blues: 0 % dark, 100 % light)
df_combined['letter_height'] = enrichment_height(df_combined['Enrichment_Ratio'])
norm = colors.Normalize(vmin=0, vmax=100)
df_combined['color'] = df_combined['Sample_Fraction'].apply(lambda x: plt.cm.Blues(1 - norm(x)))

# Add site label
df_combined['site_label'] = df_combined["Amino_Acid"] + "_" + df_combined["Spike_AS_Position"].astype(str)

# Sites shown in Figure 1
# NOTE(review): 400 sits between 439 and 440 in this list — possibly a typo; confirm.
sites_to_show = [417,418, 439,400, 440,441,442, 452,453,454, 476, 477, 483,484,485, 493, 499,500,501, 502, 505,506]
df_combined = df_combined[df_combined['Spike_AS_Position'].isin(sites_to_show)]

# Create a directory to save the PNG files (if it doesn't already exist)
output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

# Generate and save logo plots (one figure per immunization group)

def separate_stacked_heights(df):
    """Compute a per-letter `stack_bottom` for positive and negative stacks.

    Within each Spike_AS_Position, letters are ordered by decreasing
    letter_height. Letters with height >= 0 and < 0 are stacked separately, and
    stack_bottom is the running sum of the heights placed before the letter in
    its stack. Rows with NaN letter_height fall into neither stack and are
    dropped. Returns a new frame sorted by (Spike_AS_Position, stack_bottom).
    """
    df_sorted = df.sort_values(
        by=['Spike_AS_Position', 'letter_height'], ascending=[True, False]
    )
    df_positive = df_sorted[df_sorted['letter_height'] >= 0].copy()
    df_negative = df_sorted[df_sorted['letter_height'] < 0].copy()
    df_positive['stack_bottom'] = df_positive.groupby('Spike_AS_Position')['letter_height'].cumsum() - df_positive['letter_height']
    df_negative['stack_bottom'] = df_negative.groupby('Spike_AS_Position')['letter_height'].cumsum() - df_negative['letter_height']
    return pd.concat([df_positive, df_negative], axis=0).sort_values(
        by=['Spike_AS_Position', 'stack_bottom']
    )


# All figures from this cell are written to this directory
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Colorbar mapping: reversed Blues, so 0 % is dark and 100 % is light
cmap = plt.cm.Blues_r
norm = colors.Normalize(vmin=0, vmax=100)

for immunization in df_combined['immunization'].unique():
    # Boolean mask instead of query(f'... == "{x}"'), which breaks on quotes
    df_immunization = df_combined[df_combined['immunization'] == immunization]
    heights = df_immunization['letter_height']

    print(f"Generating plot for {immunization}...")
    print(f"Sample_Fraction range: min={df_immunization['Sample_Fraction'].min()}, "
          f"max={df_immunization['Sample_Fraction'].max()}")
    print(f"Positions range from {df_immunization['Spike_AS_Position'].min()} "
          f"to {df_immunization['Spike_AS_Position'].max()}")
    print(f"Unique amino acids: {sorted(df_immunization['Amino_Acid'].unique())}")
    print("Counts per position:")
    print(df_immunization.groupby('Spike_AS_Position')['Amino_Acid'].count())
    print(f"Letter height range: min {heights.min():.3f}, max {heights.max():.3f}")
    print(heights.isna().sum(), "NaN letter_heights")

    df_immunization = separate_stacked_heights(df_immunization)

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="letter_height",  # log10 enrichment ratio
        color_col="color",                   # fraction of barcodes with the mutation
        title=f"{immunization}",
        addbreaks=True
    )
    ax.set_title('')
    ax.set_xlabel('')
    ax.set_ylabel('Log10 AB Escape - Binding', fontsize=13)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.5))

    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('Fraction % of \n Single-droplet repertoire', rotation=270, labelpad=35)

    # Save before showing. File names are keyed by immunization so groups do not
    # overwrite each other (the high-res name used a stale global `barcode`).
    hires_path = os.path.join(output_dir, f"{immunization}_logoplots.png")
    plot_filename = os.path.join(output_dir, f"{immunization}_logop1.png")
    fig.savefig(hires_path, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {hires_path}")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    # Per-letter heights in plotting order, with a running total per position
    position_summary = (
        df_immunization[['Spike_AS_Position', 'Amino_Acid', 'letter_height']]
        .sort_values(by=['Spike_AS_Position', 'letter_height'], ascending=[True, False])
    )
    position_summary['Cumulative_Height'] = position_summary.groupby('Spike_AS_Position')['letter_height'].cumsum()

    # Net height per position (negative letters subtract from positive ones)
    total_heights = (
        df_immunization.groupby('Spike_AS_Position')['letter_height']
        .sum()
        .reset_index(name='Total_Stacked_Height')
    )

    print(f"\n====== Immunization: {immunization} ======")
    print("Breakdown of letter heights at each position:")
    print(position_summary.to_string(index=False))
    print("\nTotal stacked height per position:")
    print(total_heights.to_string(index=False))

    plt.show()
    plt.close(fig)


In [None]:
#Colored by fraction of barcodes (samples) carrying each mutation; Y axis is log10 enrichment ratio
import random
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.ticker as ticker
import pandas as pd
import seaborn as sns
import matplotlib.colors as colors


# -----------------------------
# HEIGHT FUNCTION
# ----------------------------- #luca3
def enrichment_height(enrichment, epsilon=1e-8, max_cap=1e8):
    """
    Map enrichment ratio(s) onto a signed log10 letter-height scale.

    The input is clamped to [epsilon, max_cap] so non-positive ratios and huge
    outliers still produce finite heights; 1 maps to 0, smaller ratios to
    negative heights and larger ratios to positive heights.

    Args:
        enrichment (float or array-like): Enrichment ratio(s).
        epsilon (float): Smallest value allowed before log10 (avoids log10(0)).
        max_cap (float): Largest value allowed before log10.

    Returns:
        float or np.ndarray: log10 of the clamped value(s).
    """
    return np.log10(np.clip(enrichment, epsilon, max_cap))


# -----------------------------------------------------------------------------
# Build df_combined: one row per (site, amino acid, immunization) with
#   Enrichment_Ratio : median over barcodes of each barcode's mean ER
#   Sample_Count     : number of barcodes carrying that mutation
#   Sample_Fraction  : Sample_Count as % of all barcodes of that immunization
# Everything is derived from df_total in THIS cell, so the result does not depend
# on which df_logo_agg an earlier cell happened to leave behind.
# -----------------------------------------------------------------------------

# 1) Clean the per-row table: numeric, finite, strictly positive ER; upper-case
#    amino acids; no stop codons.
df_total = df_total.assign(
    Enrichment_Ratio=pd.to_numeric(df_total['Enrichment_Ratio'], errors='coerce')
)
df_total = df_total[
    df_total['Enrichment_Ratio'].notna()
    & np.isfinite(df_total['Enrichment_Ratio'])
    & (df_total['Enrichment_Ratio'] > 0)
]
# Upper-case before any grouping/merging so every frame uses the same letters
df_total = df_total.assign(Amino_Acid=df_total['Amino_Acid'].str.upper())
df_total = df_total[df_total['Amino_Acid'] != "*"]
df_agg = df_total.copy()

# 2) One value per barcode: mean ER of that barcode's rows for each site/AA
site_keys = ['Spike_AS_Position', 'Amino_Acid', 'immunization']
df_logo_agg = df_total.groupby(site_keys + ['barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'mean'
})

# 3) Denominator for Sample_Fraction, computed from the same data as the numerator
barcode_counts = (
    df_logo_agg.groupby('immunization')['barcode']
    .nunique()
    .reset_index(name='Total_Barcodes')
)

# 4) Aggregate across barcodes. Previously drop_duplicates on the raw rows kept a
#    single arbitrary barcode's ER per site instead of a summary across barcodes.
df_combined = (
    df_logo_agg
    .groupby(site_keys, as_index=False)
    .agg(
        Enrichment_Ratio=('Enrichment_Ratio', 'median'),
        Sample_Count=('barcode', 'nunique'),
    )
    .merge(barcode_counts, on='immunization', how='left')
)
df_combined['Sample_Fraction'] = (
    100 * df_combined['Sample_Count'] / df_combined['Total_Barcodes']
)
assert df_combined['Sample_Fraction'].between(0, 100).all(), "Sample_Fraction outside 0-100 %"

# 5) Letter height (log10 ER) and a default colour (reversed Blues: 0 % dark)
df_combined['letter_height'] = enrichment_height(df_combined['Enrichment_Ratio'])
norm = colors.Normalize(vmin=0, vmax=100)
df_combined['color'] = df_combined['Sample_Fraction'].apply(lambda x: plt.cm.Blues(1 - norm(x)))

# Add site label
df_combined['site_label'] = df_combined["Amino_Acid"] + "_" + df_combined["Spike_AS_Position"].astype(str)

# Sites to plot
sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
df_combined = df_combined[df_combined['Spike_AS_Position'].isin(sites_to_show)]

# Create a directory to save the PNG files (if it doesn't already exist)
output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

# Generate and save logo plots (one figure per immunization group)

# Letter colours and the colorbar share ONE mapping, defined before the loop:
# reversed Blues (0 % -> dark), sampled only over the first 75 % of its range so
# the lightest, near-white quarter is skipped.
norm = colors.Normalize(vmin=0, vmax=100)
orig_cmap = plt.cm.Blues_r
trunc_cmap = colors.LinearSegmentedColormap.from_list(
    'truncBlues_r',
    orig_cmap(np.linspace(0.0, 0.75, 256))
)
# Colour every letter before plotting, so every immunization group, including
# the first one, is drawn with exactly the mapping the colorbar shows.
df_combined = df_combined.assign(
    color=df_combined['Sample_Fraction'].apply(lambda x: trunc_cmap(norm(x)))
)


def separate_stacked_heights(df):
    """Compute a per-letter `stack_bottom` for positive and negative stacks.

    Within each Spike_AS_Position, letters are ordered by decreasing
    letter_height. Letters with height >= 0 and < 0 are stacked separately, and
    stack_bottom is the running sum of the heights placed before the letter in
    its stack. Rows with NaN letter_height fall into neither stack and are
    dropped. Returns a new frame sorted by (Spike_AS_Position, stack_bottom).
    """
    df_sorted = df.sort_values(
        by=['Spike_AS_Position', 'letter_height'], ascending=[True, False]
    )
    df_positive = df_sorted[df_sorted['letter_height'] >= 0].copy()
    df_negative = df_sorted[df_sorted['letter_height'] < 0].copy()
    df_positive['stack_bottom'] = df_positive.groupby('Spike_AS_Position')['letter_height'].cumsum() - df_positive['letter_height']
    df_negative['stack_bottom'] = df_negative.groupby('Spike_AS_Position')['letter_height'].cumsum() - df_negative['letter_height']
    return pd.concat([df_positive, df_negative], axis=0).sort_values(
        by=['Spike_AS_Position', 'stack_bottom']
    )


# All figures from this cell are written to this directory
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

for immunization in df_combined['immunization'].unique():
    # Boolean mask instead of query(f'... == "{x}"'), which breaks on quotes
    df_immunization = df_combined[df_combined['immunization'] == immunization]
    heights = df_immunization['letter_height']

    print(f"Generating plot for {immunization}...")
    print(f"Sample_Fraction range: min={df_immunization['Sample_Fraction'].min()}, "
          f"max={df_immunization['Sample_Fraction'].max()}")
    print(f"Positions range from {df_immunization['Spike_AS_Position'].min()} "
          f"to {df_immunization['Spike_AS_Position'].max()}")
    print(f"Unique amino acids: {sorted(df_immunization['Amino_Acid'].unique())}")
    print("Counts per position:")
    print(df_immunization.groupby('Spike_AS_Position')['Amino_Acid'].count())
    print(f"Letter height range: min {heights.min():.3f}, max {heights.max():.3f}")
    print(heights.isna().sum(), "NaN letter_heights")

    df_immunization = separate_stacked_heights(df_immunization)

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="letter_height",  # log10 enrichment ratio
        color_col="color",                   # fraction of barcodes with the mutation
        title=f"{immunization}",
        addbreaks=True
    )
    ax.set_title('')
    ax.set_xlabel('')
    ax.set_ylabel("Log10 AB binding (median)\n$\\mathbf{\\Leftarrow}$ Enrichment $\\mathbf{\\Rightarrow}$", 
                  rotation=90, labelpad=20, ha='right', fontsize=14)
    ax.yaxis.set_label_coords(-0.08, 0.97)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.5))

    # Colorbar uses the same truncated colormap and norm as the letters
    sm = plt.cm.ScalarMappable(cmap=trunc_cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('Fraction % of \n Single-droplet repertoire', rotation=270, labelpad=35)

    # Save before showing. File names are keyed by immunization so groups do not
    # overwrite each other (the high-res name used a stale global `barcode`).
    hires_path = os.path.join(output_dir, f"{immunization}_logoplots.png")
    plot_filename = os.path.join(output_dir, f"{immunization}_logop1.png")
    fig.savefig(hires_path, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {hires_path}")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    # Per-letter heights in plotting order, with a running total per position
    position_summary = (
        df_immunization[['Spike_AS_Position', 'Amino_Acid', 'letter_height']]
        .sort_values(by=['Spike_AS_Position', 'letter_height'], ascending=[True, False])
    )
    position_summary['Cumulative_Height'] = position_summary.groupby('Spike_AS_Position')['letter_height'].cumsum()

    # Net height per position (negative letters subtract from positive ones)
    total_heights = (
        df_immunization.groupby('Spike_AS_Position')['letter_height']
        .sum()
        .reset_index(name='Total_Stacked_Height')
    )

    print(f"\n====== Immunization: {immunization} ======")
    print("Breakdown of letter heights at each position:")
    print(position_summary.to_string(index=False))
    print("\nTotal stacked height per position:")
    print(total_heights.to_string(index=False))

    plt.show()
    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
from matplotlib.gridspec import GridSpec

def create_truncated_cmap(base_cmap_name, vmin=0, vmax=100, max_fraction=0.75):
    """Return (colormap, norm) built from the lower part of a named Matplotlib colormap.

    The base colormap is sampled at 256 evenly spaced points between 0.0 and
    ``max_fraction`` and rebuilt as a LinearSegmentedColormap named
    ``'<base_cmap_name>_trunc'``.

    Args:
        base_cmap_name (str): Registered Matplotlib colormap name, e.g. 'Blues_r'.
        vmin (float): Data value mapped to the start of the truncated colormap.
        vmax (float): Data value mapped to the end of the truncated colormap.
        max_fraction (float): Upper end (0-1) of the base colormap that is kept.

    Returns:
        tuple: (matplotlib.colors.LinearSegmentedColormap, matplotlib.colors.Normalize)
    """
    # pyplot.get_cmap works on every Matplotlib version; plt.cm.get_cmap
    # (matplotlib.cm.get_cmap) was deprecated in 3.7 and removed in 3.9.
    base_cmap = plt.get_cmap(base_cmap_name)
    trunc_cmap = colors.LinearSegmentedColormap.from_list(
        f'{base_cmap_name}_trunc',
        base_cmap(np.linspace(0.0, max_fraction, 256))
    )
    norm = colors.Normalize(vmin=vmin, vmax=vmax)
    return trunc_cmap, norm

# Stand-alone legend figure: three vertical colorbars (orange, green, blue) side
# by side, each mapping Sample_Fraction 0-100 %. Saved as ColorBarDMS.png (300 dpi).
fig = plt.figure(figsize=(1, 3))

# All reversed ('*_r') and truncated to the first 75 % of each map, so 0 % is the
# dark end; this mirrors the truncated 'Blues_r' used for the logo letters above.
orange_cmap, norm_orange = create_truncated_cmap('YlOrBr_r', max_fraction=0.75)
green_cmap,  norm_green  = create_truncated_cmap('Greens_r', max_fraction=0.75)
blue_cmap,   norm_blue   = create_truncated_cmap('Blues_r',  max_fraction=0.75)

# One row, three columns with no horizontal gap, so the bars touch
gs = GridSpec(1, 3, figure=fig, wspace=0)

axs = [fig.add_subplot(gs[0, i]) for i in range(3)]

# (axis, colormap, norm, label); only the first bar carries the axis label
colorbars = [
    (axs[0], orange_cmap, norm_orange, 'Fraction %\nSingle-droplet repertoire'),
    (axs[1], green_cmap, norm_green, ''),
    (axs[2], blue_cmap, norm_blue, '')
]

# Custom ticks and labels for blue bar:
# NOTE(review): the smallest real Sample_Fraction is 100 / Total_Barcodes, not 0,
# so the 0 % tick labelled "Present in 1x droplet" is an approximation — confirm wording.
custom_ticks = [0, 25, 50, 75, 100]
custom_labels = ['Present in \n 1x droplet', '25%', '50%', '75%', 'Present \n in all \n droplets']

for i, (ax, cmap, norm, label) in enumerate(colorbars):
    # Colorbar without plotted data: the mappable only carries cmap + norm
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = plt.colorbar(sm, cax=ax, orientation='vertical')
    cbar.set_label(label, rotation=270, labelpad=35, fontsize=14)

    if i == 2:  # Blue bar gets custom text ticks
        cbar.set_ticks(custom_ticks)
        cbar.set_ticklabels(custom_labels)
        cbar.ax.tick_params(labelsize=9)
    else:
        # Orange and green bars are shown without ticks
        cbar.set_ticks([])
        cbar.ax.tick_params(labelsize=0)

fig.suptitle('Color Legends for Enrichment Ratio Plot', fontsize=14, y=1.02)
plt.savefig('ColorBarDMS.png', dpi=300, bbox_inches='tight')
plt.show()



In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
from matplotlib.gridspec import GridSpec

def create_truncated_cmap(base_cmap_name, vmin=0, vmax=100, min_fraction=0.0, max_fraction=1.0):
    """Return (colormap, norm) built from a slice of a named Matplotlib colormap.

    The base colormap is sampled at 256 evenly spaced points between
    ``min_fraction`` and ``max_fraction`` and rebuilt as a
    LinearSegmentedColormap named ``'<base_cmap_name>_trunc'``.

    Args:
        base_cmap_name (str): Registered Matplotlib colormap name, e.g. 'Blues'.
        vmin (float): Data value mapped to the start of the truncated colormap.
        vmax (float): Data value mapped to the end of the truncated colormap.
        min_fraction (float): Lower end (0-1) of the base colormap that is kept.
        max_fraction (float): Upper end (0-1) of the base colormap that is kept.

    Returns:
        tuple: (matplotlib.colors.LinearSegmentedColormap, matplotlib.colors.Normalize)
    """
    # pyplot.get_cmap works on every Matplotlib version; plt.cm.get_cmap
    # (matplotlib.cm.get_cmap) was deprecated in 3.7 and removed in 3.9.
    base_cmap = plt.get_cmap(base_cmap_name)
    sampled_colors = base_cmap(np.linspace(min_fraction, max_fraction, 256))
    trunc_cmap = colors.LinearSegmentedColormap.from_list(
        f'{base_cmap_name}_trunc',
        sampled_colors
    )
    norm = colors.Normalize(vmin=vmin, vmax=vmax)
    return trunc_cmap, norm

# Stand-alone legend figure: three vertical colorbars side by side, 0-100 % each
fig = plt.figure(figsize=(2, 5))
gs = GridSpec(1, 3, figure=fig, wspace=0)

axs = [fig.add_subplot(gs[0, i]) for i in range(3)]

# Non-reversed, full-range colormaps: these run from the LIGHT end at 0 % to the
# DARK end at 100 %.
# NOTE(review): the reversed ('*_r') legend in the previous cell is dark at 0 %;
# confirm which orientation matches the letter colours before using this one.
orange_cmap, norm_orange = create_truncated_cmap('YlOrBr', 0, 100, 0.0, 1.0)
green_cmap, norm_green = create_truncated_cmap('Greens', 0, 100, 0.0, 1.0)
blue_cmap, norm_blue = create_truncated_cmap('Blues', 0, 100, 0.0, 1.0)

# (axis, colormap, norm, label); only the first bar carries the axis label
colorbars = [
    (axs[0], orange_cmap, norm_orange, 'Fraction %\nSingle-droplet repertoire'),
    (axs[1], green_cmap, norm_green, ''),
    (axs[2], blue_cmap, norm_blue, '')
]

for i, (ax, cmap, norm, label) in enumerate(colorbars):
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    # Explicit figure API instead of the pyplot state machine
    cbar = fig.colorbar(sm, cax=ax, orientation='vertical')
    cbar.set_label(label, rotation=270, labelpad=35, fontsize=12)
    ax.yaxis.set_label_coords(5, 0.5)

    if i == 2:
        # Only the blue bar gets numeric ticks: 0, 20, ..., 100
        cbar.set_ticks(np.linspace(0, 100, 6))
        cbar.ax.tick_params(labelsize=9)
    else:
        cbar.set_ticks([])
        cbar.ax.tick_params(labelsize=0)

    ax.set_xticks([])
    ax.set_xlabel('')

fig.suptitle('Color Legends for Enrichment Ratio Plot', fontsize=14, y=1.02)

plt.show()


In [None]:
#Colored by Enrichment ratio, Y axis is how many samples support it
import random
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.ticker as ticker
import pandas as pd
import seaborn as sns
import matplotlib.colors as colors


import os
import pandas as pd
import numpy as np
from Bio import SeqIO
import altair as alt

# Wuhan reference: sequence of the first record in the FASTA file
fasta_file = r'/Users/lucaschlotheuber/Desktop/ETH/RBD201_DMS1.fa'
for record in SeqIO.parse(fasta_file, "fasta"):
    wuhan_sequence = str(record.seq)
    break

# DMS summary table, restricted to the columns used downstream
file_path = r'/Users/lucaschlotheuber/Desktop/ETH/summary_DMS_cleaned.xlsx'
df_total = pd.read_excel(
    file_path,
    usecols=[
        "DMS_RBD_AS_position", "Spike_AS_Position", "Count_of_Base",
        "Amino_Acid", "Type_of_Mutation", "Enrichment_Ratio",
        "barcode", "immunization", "condition", "Total_Reads",
    ],
)
# Spike numbering in the file is shifted by +5 residues (336 -> 331)
df_total["Spike_AS_Position"] = df_total["Spike_AS_Position"] - 5

# Sequential clean-up: missing ER/AA -> >1000 total reads only -> no stop codons
df_total = (
    df_total
    .dropna(subset=['Enrichment_Ratio', 'Amino_Acid'])
    .loc[lambda d: d["Total_Reads"] > 1000]
    .loc[lambda d: d["Amino_Acid"] != '*']
)

# Append the Wuhan reference as a pseudo-sample: one row per residue (stop
# codons skipped) with Enrichment_Ratio = 1, i.e. log10 ER = 0.
# DMS position p corresponds to Spike position p + 330.
immunization = "Wuhan_Sequence"
barcode = "Wuhan_Barcode"
data_wuhan = [
    dict(
        DMS_RBD_AS_position=pos,
        Spike_AS_Position=pos + 330,
        Amino_Acid=aa,
        immunization=immunization,
        barcode=barcode,
        Enrichment_Ratio=1,
    )
    for pos, aa in enumerate(wuhan_sequence, start=1)
    if aa != '*'
]
df_wuhan = pd.DataFrame(data_wuhan)
df_total = pd.concat([df_total, df_wuhan], ignore_index=True)


# -----------------------------
# HEIGHT FUNCTION
# -----------------------------

def enrichment_height(enrichment, epsilon=1e-8, max_cap=1e8):
    """
    Convert enrichment ratio(s) into log10 letter heights, bounded to stay finite.

    Inputs are first confined to the interval [epsilon, max_cap]: zero or
    negative ratios become epsilon, huge ratios become max_cap. A ratio of 1
    maps to height 0.

    Args:
        enrichment (float or array-like): Enrichment value(s).
        epsilon (float): Lower bound applied before the log (avoids log10(0)).
        max_cap (float): Upper bound applied before the log.

    Returns:
        float or np.ndarray: log10 of the bounded enrichment.
    """
    floored = np.maximum(enrichment, epsilon)
    bounded = np.minimum(floored, max_cap)
    return np.log10(bounded)


# -----------------------------
# PREP: per-barcode aggregation, sample support, heights and colors
# -----------------------------
# Coerce to numeric (strings -> NaN), then keep only finite, strictly positive
# enrichment ratios (log10 letter heights need positive input).
df_total['Enrichment_Ratio'] = pd.to_numeric(df_total['Enrichment_Ratio'], errors='coerce')
df_total = df_total[
    df_total['Enrichment_Ratio'].notna()
    & np.isfinite(df_total['Enrichment_Ratio'])
    & (df_total['Enrichment_Ratio'] > 0)
]
# Upper-case once so the raw rows and the per-barcode table share merge keys.
df_total = df_total.assign(Amino_Acid=df_total['Amino_Acid'].str.upper())
df_total = df_total[df_total['Amino_Acid'] != "*"]  # drop stop codons
df_agg = df_total.copy()

group_cols = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Mean enrichment ratio per barcode (one row per site / amino acid / immunization / barcode).
df_logo_agg = df_total.groupby(group_cols + ['barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'mean'
})

# Total number of unique barcodes per immunization. Built from the df_logo_agg
# computed just above; reading df_logo_agg before this point used a stale table
# from an earlier cell (possibly one without a 'barcode' column).
barcode_counts = (
    df_logo_agg.groupby('immunization')['barcode']
    .nunique()
    .reset_index(name='Total_Barcodes')
)

# Number of barcodes in which each amino acid is observed at each site.
sample_counts = (
    df_logo_agg.groupby(group_cols)
    .size()
    .reset_index(name='Sample_Count')
)

# One row per site / amino acid / immunization, with sample support attached.
# NOTE(review): drop_duplicates keeps the Enrichment_Ratio of the first raw row
# of each group (a single barcode), not the per-barcode mean in df_logo_agg.
# Confirm this is the intended letter height.
df_combined = (
    df_agg.merge(sample_counts, on=group_cols, how='left')
    .drop_duplicates(subset=group_cols)
    .merge(barcode_counts, on='immunization', how='left')
)

df_combined['Sample_Fraction'] = (
    100 * df_combined['Sample_Count'] / df_combined['Total_Barcodes']
)
df_combined['letter_height'] = df_combined['Enrichment_Ratio'].apply(enrichment_height)

# Letter colors use the same truncated, reversed YlOrBr map as the colorbar in
# the plotting loop (0 % -> dark brown, 100 % -> light; top 25 % near-white end skipped).
norm = colors.Normalize(vmin=0, vmax=100)
trunc_cmap = colors.LinearSegmentedColormap.from_list(
    'YlOrBr_r', plt.cm.YlOrBr_r(np.linspace(0.0, 0.75, 256))
)
df_combined['color'] = df_combined['Sample_Fraction'].apply(lambda x: trunc_cmap(norm(x)))

# Site label such as "N_501".
df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# DEBUG: amino acids present at position 452
print("=== DEBUG: Amino acids at position 452 ===")
print(df_combined[df_combined['Spike_AS_Position'] == 452][['Amino_Acid', 'immunization', 'barcode', 'Enrichment_Ratio']])

# Sites to plot (484 and 485 were previously fused into 484485 by a missing comma).
sites_to_show = [417, 439, 440, 451, 452, 453, 454, 455, 456, 476, 483, 484, 485,
                 499, 500, 501, 502, 503, 504, 505]
df_combined = df_combined[df_combined['Spike_AS_Position'].isin(sites_to_show)].copy()

# -----------------------------
# Generate and save logo plots
# -----------------------------
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# One color mapping for letters AND colorbar: reversed YlOrBr truncated to
# [0, 0.75] so the near-white end is skipped (0 % -> dark, 100 % -> light).
# Assigned before the loop. Previously, colors were reassigned inside the loop
# only after the first group had already been sliced, so the first plot used a
# different (inverted) mapping than its colorbar and the later plots.
norm = colors.Normalize(vmin=0, vmax=100)
trunc_cmap = colors.LinearSegmentedColormap.from_list(
    'YlOrBr_r', plt.cm.YlOrBr_r(np.linspace(0.0, 0.75, 256))
)
df_combined = df_combined.assign(
    color=df_combined['Sample_Fraction'].apply(lambda x: trunc_cmap(norm(x)))
)


def separate_stacked_heights(df):
    """
    Reorder logo rows by position, splitting non-negative and negative heights.

    Within each position, rows are sorted by letter_height (descending), split
    by sign, and given a 'stack_bottom' column: the cumulative height of the
    preceding same-sign rows. The result is sorted by position, then stack_bottom.
    Rows with NaN letter_height match neither split and are dropped.
    NOTE(review): 'stack_bottom' is not passed to dmslogo.draw_logo, so
    presumably only the row order / NaN removal matters — confirm.
    """
    df_sorted = df.sort_values(
        by=['Spike_AS_Position', 'letter_height'], ascending=[True, False]
    )
    df_positive = df_sorted[df_sorted['letter_height'] >= 0].copy()
    df_negative = df_sorted[df_sorted['letter_height'] < 0].copy()

    df_positive['stack_bottom'] = df_positive.groupby('Spike_AS_Position')['letter_height'].cumsum() - df_positive['letter_height']
    df_negative['stack_bottom'] = df_negative.groupby('Spike_AS_Position')['letter_height'].cumsum() - df_negative['letter_height']

    return pd.concat([df_positive, df_negative], axis=0).sort_values(
        by=['Spike_AS_Position', 'stack_bottom']
    )


for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    # Boolean mask rather than an f-string query: safe for names containing quotes.
    df_immunization = df_combined[df_combined['immunization'] == immunization]
    heights = df_immunization['letter_height']
    print(df_immunization[['Spike_AS_Position', 'Amino_Acid', 'letter_height', 'Sample_Count', 'Sample_Fraction']])
    print(f"Positions range from {df_immunization['Spike_AS_Position'].min()} to {df_immunization['Spike_AS_Position'].max()}")
    print(f"Unique amino acids: {sorted(df_immunization['Amino_Acid'].unique())}")
    print("Counts per position:")
    print(df_immunization.groupby('Spike_AS_Position')['Amino_Acid'].count())
    print(f"Letter height range: min {heights.min():.3f}, max {heights.max():.3f}")
    print(heights.isna().sum(), "NaN letter heights")

    df_immunization = separate_stacked_heights(df_immunization)

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="letter_height",  # log10 enrichment ratio
        color_col="color",                  # sample fraction (mapping above)
        title=f"{immunization}",
        addbreaks=True
    )
    ax.set_title('')
    ax.set_xlabel('')
    ax.set_ylabel('Log10 AB Escape - Binding', fontsize=13)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.5))

    # Colorbar shares trunc_cmap/norm with the letter colors.
    sm = plt.cm.ScalarMappable(cmap=trunc_cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('Fraction % of \n Single-droplet repertoire', rotation=270, labelpad=35)

    # Per-immunization file names. The high-resolution file used the global
    # `barcode`, so every immunization overwrote the same PNG.
    file_path = os.path.join(output_dir, f"{immunization}_logoplots.png")
    plot_filename = os.path.join(output_dir, f"{immunization}_logop1.png")
    plt.savefig(file_path, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {file_path}")
    plt.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    # Per-position breakdown (sorted by height) and signed total height.
    position_summary = (
        df_immunization[['Spike_AS_Position', 'Amino_Acid', 'letter_height']]
        .sort_values(by=['Spike_AS_Position', 'letter_height'], ascending=[True, False])
    )
    position_summary['Cumulative_Height'] = position_summary.groupby('Spike_AS_Position')['letter_height'].cumsum()
    total_heights = (
        df_immunization.groupby('Spike_AS_Position')['letter_height']
        .sum()
        .reset_index(name='Total_Stacked_Height')
    )

    print(f"\n====== Immunization: {immunization} ======")
    print("Breakdown of letter heights at each position:")
    print(position_summary.to_string(index=False))
    print("\nTotal stacked height per position:")
    print(total_heights.to_string(index=False))

    plt.show()


In [None]:
# Letters colored by sample fraction (% of barcodes supporting each amino acid); letter height is log10 enrichment ratio
import random
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.ticker as ticker
import pandas as pd
import seaborn as sns
import matplotlib.colors as colors


# -----------------------------
# HEIGHT FUNCTION
# -----------------------------

def enrichment_height(enrichment, epsilon=1e-8, max_cap=1e8):
    """
    Map enrichment ratio(s) to log10 letter heights, clamped so they stay finite.

    The ratio is clamped to [epsilon, max_cap] before taking log10, so a ratio
    of 0 (or below) yields log10(epsilon) and a ratio of 1 yields 0.

    Args:
        enrichment (float or array-like): Enrichment value(s).
        epsilon (float): Smallest ratio used (avoids log10(0)).
        max_cap (float): Largest ratio used.

    Returns:
        float or np.ndarray: Clamped log10 enrichment.
    """
    lower, upper = epsilon, max_cap
    return np.log10(np.minimum(np.maximum(enrichment, lower), upper))


# -----------------------------
# PREP: per-barcode aggregation, sample support, heights and colors
# -----------------------------
# NOTE(review): uses the df_total left by an earlier cell (no reload here).
# Coerce to numeric (strings -> NaN), then keep only finite, strictly positive
# enrichment ratios (log10 letter heights need positive input).
df_total['Enrichment_Ratio'] = pd.to_numeric(df_total['Enrichment_Ratio'], errors='coerce')
df_total = df_total[
    df_total['Enrichment_Ratio'].notna()
    & np.isfinite(df_total['Enrichment_Ratio'])
    & (df_total['Enrichment_Ratio'] > 0)
]
# Upper-case once so the raw rows and the per-barcode table share merge keys.
df_total = df_total.assign(Amino_Acid=df_total['Amino_Acid'].str.upper())
df_total = df_total[df_total['Amino_Acid'] != "*"]  # drop stop codons
df_agg = df_total.copy()

group_cols = ['Spike_AS_Position', 'Amino_Acid', 'immunization']

# Mean enrichment ratio per barcode (one row per site / amino acid / immunization / barcode).
df_logo_agg = df_total.groupby(group_cols + ['barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'mean'
})

# Total number of unique barcodes per immunization. Built from the df_logo_agg
# computed just above; reading df_logo_agg before this point used a stale table
# from an earlier cell (possibly one without a 'barcode' column).
barcode_counts = (
    df_logo_agg.groupby('immunization')['barcode']
    .nunique()
    .reset_index(name='Total_Barcodes')
)

# Number of barcodes in which each amino acid is observed at each site.
sample_counts = (
    df_logo_agg.groupby(group_cols)
    .size()
    .reset_index(name='Sample_Count')
)

# One row per site / amino acid / immunization, with sample support attached.
# NOTE(review): drop_duplicates keeps the Enrichment_Ratio of the first raw row
# of each group (a single barcode), not the per-barcode mean in df_logo_agg.
# Confirm this is the intended letter height.
df_combined = (
    df_agg.merge(sample_counts, on=group_cols, how='left')
    .drop_duplicates(subset=group_cols)
    .merge(barcode_counts, on='immunization', how='left')
)

df_combined['Sample_Fraction'] = (
    100 * df_combined['Sample_Count'] / df_combined['Total_Barcodes']
)
df_combined['letter_height'] = df_combined['Enrichment_Ratio'].apply(enrichment_height)

# Letter colors use the same truncated, reversed YlOrBr map as the colorbar in
# the plotting loop (0 % -> dark brown, 100 % -> light; top 25 % near-white end skipped).
norm = colors.Normalize(vmin=0, vmax=100)
trunc_cmap = colors.LinearSegmentedColormap.from_list(
    'YlOrBr_r', plt.cm.YlOrBr_r(np.linspace(0.0, 0.75, 256))
)
df_combined['color'] = df_combined['Sample_Fraction'].apply(lambda x: trunc_cmap(norm(x)))

# Site label such as "N_501".
df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# Sites to plot.
sites_to_show = [415, 422, 440, 452, 466, 447]
df_combined = df_combined[df_combined['Spike_AS_Position'].isin(sites_to_show)].copy()

# -----------------------------
# Generate and save logo plots
# -----------------------------
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# One color mapping for letters AND colorbar: reversed YlOrBr truncated to
# [0, 0.75] so the near-white end is skipped (0 % -> dark, 100 % -> light).
# Assigned before the loop. Previously, colors were reassigned inside the loop
# only after the first group had already been sliced, so the first plot used a
# different (inverted) mapping than its colorbar and the later plots.
norm = colors.Normalize(vmin=0, vmax=100)
trunc_cmap = colors.LinearSegmentedColormap.from_list(
    'YlOrBr_r', plt.cm.YlOrBr_r(np.linspace(0.0, 0.75, 256))
)
df_combined = df_combined.assign(
    color=df_combined['Sample_Fraction'].apply(lambda x: trunc_cmap(norm(x)))
)


def separate_stacked_heights(df):
    """
    Reorder logo rows by position, splitting non-negative and negative heights.

    Within each position, rows are sorted by letter_height (descending), split
    by sign, and given a 'stack_bottom' column: the cumulative height of the
    preceding same-sign rows. The result is sorted by position, then stack_bottom.
    Rows with NaN letter_height match neither split and are dropped.
    NOTE(review): 'stack_bottom' is not passed to dmslogo.draw_logo, so
    presumably only the row order / NaN removal matters — confirm.
    """
    df_sorted = df.sort_values(
        by=['Spike_AS_Position', 'letter_height'], ascending=[True, False]
    )
    df_positive = df_sorted[df_sorted['letter_height'] >= 0].copy()
    df_negative = df_sorted[df_sorted['letter_height'] < 0].copy()

    df_positive['stack_bottom'] = df_positive.groupby('Spike_AS_Position')['letter_height'].cumsum() - df_positive['letter_height']
    df_negative['stack_bottom'] = df_negative.groupby('Spike_AS_Position')['letter_height'].cumsum() - df_negative['letter_height']

    return pd.concat([df_positive, df_negative], axis=0).sort_values(
        by=['Spike_AS_Position', 'stack_bottom']
    )


for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    # Boolean mask rather than an f-string query: safe for names containing quotes.
    df_immunization = df_combined[df_combined['immunization'] == immunization]
    heights = df_immunization['letter_height']
    print(df_immunization[['Spike_AS_Position', 'Amino_Acid', 'letter_height', 'Sample_Count', 'Sample_Fraction']])
    print(f"Positions range from {df_immunization['Spike_AS_Position'].min()} to {df_immunization['Spike_AS_Position'].max()}")
    print(f"Unique amino acids: {sorted(df_immunization['Amino_Acid'].unique())}")
    print("Counts per position:")
    print(df_immunization.groupby('Spike_AS_Position')['Amino_Acid'].count())
    print(f"Letter height range: min {heights.min():.3f}, max {heights.max():.3f}")
    print(heights.isna().sum(), "NaN letter heights")

    df_immunization = separate_stacked_heights(df_immunization)

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="letter_height",  # log10 enrichment ratio
        color_col="color",                  # sample fraction (mapping above)
        title=f"{immunization}",
        addbreaks=True
    )
    ax.set_title('')
    ax.set_xlabel('')
    ax.set_ylabel('Log10 AB Escape - Binding', fontsize=13)
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.5))

    # Colorbar shares trunc_cmap/norm with the letter colors.
    sm = plt.cm.ScalarMappable(cmap=trunc_cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('Fraction % of \n Single-droplet repertoire', rotation=270, labelpad=35)

    # Per-immunization file names. The high-resolution file used the global
    # `barcode`, so every immunization overwrote the same PNG.
    file_path = os.path.join(output_dir, f"{immunization}_logoplots.png")
    plot_filename = os.path.join(output_dir, f"{immunization}_logop1.png")
    plt.savefig(file_path, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {file_path}")
    plt.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    # Per-position breakdown (sorted by height) and signed total height.
    position_summary = (
        df_immunization[['Spike_AS_Position', 'Amino_Acid', 'letter_height']]
        .sort_values(by=['Spike_AS_Position', 'letter_height'], ascending=[True, False])
    )
    position_summary['Cumulative_Height'] = position_summary.groupby('Spike_AS_Position')['letter_height'].cumsum()
    total_heights = (
        df_immunization.groupby('Spike_AS_Position')['letter_height']
        .sum()
        .reset_index(name='Total_Stacked_Height')
    )

    print(f"\n====== Immunization: {immunization} ======")
    print("Breakdown of letter heights at each position:")
    print(position_summary.to_string(index=False))
    print("\nTotal stacked height per position:")
    print(total_heights.to_string(index=False))

    plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Visual check of the Blues mapping for sample fractions:
# low % -> strong blue, high % -> faint blue.
fractions = np.linspace(0, 100, 11)
# Named `swatch_colors` (not `colors`) so it does not shadow the
# `matplotlib.colors` module that other cells import as `colors`.
swatch_colors = [plt.cm.Blues(1 - x / 100) for x in fractions]

fig, ax = plt.subplots(figsize=(8, 1))
for i, swatch in enumerate(swatch_colors):
    ax.bar(i, 1, color=swatch)
ax.set_xticks(range(len(fractions)))
ax.set_xticklabels([f"{int(x)}%" for x in fractions])
ax.set_yticks([])
ax.set_title("Color mapping test: low % → strong blue, high % → faint blue")
plt.show()


In [None]:
#End of publication logo plots

In [None]:
#Statistics

In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np

# Summed enrichment ratio per (site, amino acid, immunization, barcode).
df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]

# Keep only enriched amino acids (Enrichment_Ratio > 1; adjust threshold as needed).
df_filtered = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 1]

# Number of barcodes supporting each amino acid at each site / immunization.
df_combined = (
    df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .size()
    .reset_index(name='Sample_Count')
)

# Mean enrichment ratio of those supporting barcodes.
avg_enrichment = df_filtered.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])['Enrichment_Ratio'].mean().reset_index(name='Avg_Enrichment')
df_combined = df_combined.merge(avg_enrichment, on=['Spike_AS_Position', 'Amino_Acid', 'immunization'])

# Largest support count over ALL sites (taken before the site filter below).
# Shared by the letter colors and the colorbar so both use the same scale.
max_count = df_combined['Sample_Count'].max()


def occurrence_color(count, max_count=None):
    """
    Map a support count to a Blues RGBA color (0 -> lightest, max_count -> darkest).

    Args:
        count: Number of supporting barcodes.
        max_count: Count mapped to the darkest blue. Defaults to the current
            maximum of df_combined['Sample_Count'].

    Returns:
        tuple: RGBA color.
    """
    if max_count is None:
        max_count = df_combined['Sample_Count'].max()
    return plt.cm.Blues(np.clip(count / max_count, 0, 1))


# Compute the maximum once instead of on every row.
df_combined['color'] = df_combined['Sample_Count'].apply(lambda c: occurrence_color(c, max_count))

df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# Sites to show (adjust as needed)
sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
df_combined = df_combined[df_combined['Spike_AS_Position'].isin(sites_to_show)]

output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    # Boolean mask rather than an f-string query: safe for names containing quotes.
    df_immunization = df_combined[df_combined['immunization'] == immunization]

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Avg_Enrichment",  # height = mean enrichment ratio
        color_col="color",                    # color = number of supporting barcodes
        title=f"{immunization}",
        addbreaks=True
    )

    # Colorbar matches the letter colors: Blues over [0, max_count] support count.
    # (It previously used `orange_cmap` on a 0-100 scale, which matched neither
    # the letter colors nor the "number of samples" label.)
    sm = plt.cm.ScalarMappable(cmap="Blues", norm=plt.Normalize(vmin=0, vmax=max_count))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('Occurrence (Number of Supporting Samples)', rotation=270, labelpad=15)

    ax.set_ylabel('Antibody-Binding (Mean)')

    plot_filename = os.path.join(output_dir, f"{immunization}_logoplot.png")
    plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    plt.show()


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import dmslogo  # make sure this is installed: pip install dmslogo

# -----------------------------
# PREP: median enrichment per (Spike_AS_Position, Amino_Acid, immunization)
# -----------------------------
# (A per-barcode "sum" aggregation with duplicate reporting used to precede this,
# but its result was overwritten before plotting, so it has been removed.)
df_agg = df_total.copy()
df_agg['Amino_Acid'] = df_agg['Amino_Acid'].str.upper()
df_agg = df_agg[df_agg['Amino_Acid'] != "*"]  # remove stop codons

# Median across barcodes
df_logo_agg = df_agg.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False
).agg({'Enrichment_Ratio': 'median'})


# -----------------------------
# COLOR FUNCTION
# -----------------------------
def enrichment_color(enrichment):
    """
    Coolwarm RGBA color for log2(enrichment), normalized over [-5, 5].

    A 1e-3 pseudo-count avoids log2(0); values outside [-5, 5] take the
    colormap's end colors.
    """
    log_enrichment = np.log2(enrichment + 1e-3)
    norm = plt.Normalize(vmin=-5, vmax=5)
    return plt.cm.coolwarm(norm(log_enrichment))


# -----------------------------
# HEIGHT FUNCTION
# -----------------------------
def enrichment_height(enrichment):
    """
    Signed letter height: log2(enrichment).

    Positive above a ratio of 1 (binding), negative below (escape). The former
    piecewise form `log2(e) if e >= 1 else -log2(1 / e)` computes this same
    quantity in both branches.
    """
    return np.log2(enrichment)


df_logo_agg['letter_height'] = df_logo_agg['Enrichment_Ratio'].apply(enrichment_height)
df_logo_agg['color'] = df_logo_agg['Enrichment_Ratio'].apply(enrichment_color)

# -----------------------------
# Filter specific Spike positions
# -----------------------------
sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
df_logo_agg = df_logo_agg[df_logo_agg['Spike_AS_Position'].isin(sites_to_show)]

# -----------------------------
# Output directory
# -----------------------------
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# -----------------------------
# Generate plots
# -----------------------------
for immunization in df_logo_agg['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    df_immunization = df_logo_agg[df_logo_agg['immunization'] == immunization]

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="letter_height",
        color_col="color",
        title=f"{immunization} Dual Binding Escape Logo Plot",
        addbreaks=True
    )

    # Zero line separates binding (above) from escape (below).
    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_ylabel('log₂(Enrichment Ratio)\n(Positive = binding, Negative = escape)')

    # Colorbar uses the same coolwarm / [-5, 5] mapping as enrichment_color.
    sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=plt.Normalize(vmin=-5, vmax=5))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('log₂(Enrichment)', rotation=270, labelpad=15)

    filename = os.path.join(output_dir, f"{immunization}_dual_escape_logo.png")
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {filename}")
    plt.show()


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import dmslogo  # make sure this is installed: pip install dmslogo

# -----------------------------
# PREP: Aggregation and Setup
# -----------------------------
# Use median Enrichment_Ratio per (Spike_AS_Position, Amino_Acid, immunization)
df_agg = df_total.copy()
df_agg['Amino_Acid'] = df_agg['Amino_Acid'].str.upper()
df_agg = df_agg[df_agg['Amino_Acid'] != "*"]  # remove stop codons

# Aggregate by median across barcodes
df_logo_agg = df_agg.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False
).agg({'Enrichment_Ratio': 'median'})

# -----------------------------
# COLOR FUNCTION
# -----------------------------
def enrichment_color(enrichment):
    """Return the 'coolwarm' RGBA colour for an enrichment ratio.

    A pseudo-count of 1e-3 keeps a ratio of 0 finite before the log2 transform.
    log2 values are mapped linearly from [-5, 5] onto the colormap; values beyond
    that range take the colormap's end colours (matplotlib's default under/over).
    """
    scale = plt.Normalize(vmin=-5, vmax=5)
    shifted_log2 = np.log2(enrichment + 1e-3)
    return plt.cm.coolwarm(scale(shifted_log2))

# -----------------------------
# HEIGHT FUNCTION
# -----------------------------
def enrichment_height(enrichment, floor=1e-5):
    """Letter height for the logo: log2 of the enrichment ratio.

    Ratios > 1 give positive heights (binding), ratios < 1 negative heights
    (escape). The ratio is floored at ``floor`` first, so a zero median maps
    to a large but finite negative height instead of raising
    ZeroDivisionError on ``1 / 0``.

    Args:
        enrichment: enrichment ratio (scalar, >= 0).
        floor: smallest ratio used before taking log2.

    Returns:
        log2(max(enrichment, floor)).
    """
    # Note: -log2(1/x) == log2(x) for x > 0, so a single expression covers both signs.
    return np.log2(max(enrichment, floor))

# -----------------------------
# Add color and letter height
# -----------------------------
# Computed for every aggregated row (before the site filter below) from the
# median enrichment ratio, using the height/colour functions defined in this cell.
df_logo_agg['letter_height'] = df_logo_agg['Enrichment_Ratio'].apply(enrichment_height)
df_logo_agg['color'] = df_logo_agg['Enrichment_Ratio'].apply(enrichment_color)

# -----------------------------
# Filter specific Spike positions
# -----------------------------
# Only these Spike positions are drawn in the logos below.
sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
df_logo_agg = df_logo_agg[df_logo_agg['Spike_AS_Position'].isin(sites_to_show)]

# -----------------------------
# Output directory
# -----------------------------
# NOTE(review): absolute, user-specific path.
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# -----------------------------
# Generate plots
# -----------------------------
# One logo per immunization: letter height = log2(median enrichment),
# colour = coolwarm over log2 values in [-5, 5].
for immunization in df_logo_agg['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    df_immunization = df_logo_agg[df_logo_agg['immunization'] == immunization]

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="letter_height",
        color_col="color",
        title=f"{immunization}",
        addbreaks=True
    )

    # Add zero line and labels
    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_ylabel('Antibody-binding (Log2) \n AA frequency')

    # Add colorbar
    # Same [-5, 5] log2 range that enrichment_color in this cell maps onto coolwarm.
    sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=plt.Normalize(vmin=-5, vmax=5))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('log₂(Enrichment)', rotation=270, labelpad=15)

    # Save figure
    # NOTE(review): same file name as the neighbouring logo cells — whichever
    # runs last wins.
    filename = os.path.join(output_dir, f"{immunization}_dual_escape_logo.png")
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {filename}")
    plt.show()


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import dmslogo

# -----------------------------
# PREP: Filter and uppercase, remove stop codons
# -----------------------------
# Row-level data (no aggregation across barcodes): upper-case AAs, drop stop
# codons, keep only the selected Spike positions.
df_logo_agg = df_total.copy()
df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]

sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
df_logo_agg = df_logo_agg[df_logo_agg['Spike_AS_Position'].isin(sites_to_show)]

# -----------------------------
# For each (Position, Amino Acid, immunization), keep row with max Total_Reads
# -----------------------------
# Sort descending by Total_Reads and keep the first duplicate: the retained
# Enrichment_Ratio is that single best-covered row's value, not an aggregate.
df_filtered = (
    df_logo_agg.sort_values('Total_Reads', ascending=False)
    .drop_duplicates(subset=['Spike_AS_Position', 'Amino_Acid', 'immunization'], keep='first')
).copy()

# -----------------------------
# Separate positive (>1) and negative (<=1 but >0) enrichment groups
# -----------------------------
# Rows with Enrichment_Ratio == 0 fall into neither group and are not plotted.
df_pos = df_filtered[df_filtered['Enrichment_Ratio'] > 1].copy()
df_neg = df_filtered[(df_filtered['Enrichment_Ratio'] <= 1) & (df_filtered['Enrichment_Ratio'] > 0)].copy()

# -----------------------------
# For positive: letter_height = 1 (presence), direction 'up'
# Color: Reds over log10(Enrichment + 1), clipped at log10(3000 + 1)
# -----------------------------
# One shared normalisation drives both the letter colours and the red colorbar
# below, so the colorbar reads off exactly the value -> colour mapping used for
# the letters.
red_norm = plt.Normalize(vmin=0, vmax=np.log10(3000 + 1), clip=True)

df_pos['letter_height'] = 1  # 1 means presence (one row per position/AA/immunization)
df_pos['direction'] = 'up'
df_pos['color'] = df_pos['Enrichment_Ratio'].apply(
    lambda x: plt.cm.Reds(red_norm(np.log10(x + 1)))
)

# -----------------------------
# For negative: letter_height = -1 (inverted for plotting), direction 'down'
# Color: Blues_r scaled by enrichment ratio 0-1
# -----------------------------
def reversed_blue_color(enrichment):
    """Map an enrichment ratio in [0, 1] to a reversed-Blues colour (0 -> darkest blue)."""
    norm = plt.Normalize(vmin=0, vmax=1)
    return plt.cm.Blues_r(norm(enrichment))

df_neg['letter_height'] = -1
df_neg['direction'] = 'down'
df_neg['color'] = df_neg['Enrichment_Ratio'].apply(reversed_blue_color)

# -----------------------------
# Combine both
# -----------------------------
df_plot = pd.concat([df_pos, df_neg], ignore_index=True)

# -----------------------------
# Plotting loop by immunization
# -----------------------------
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

for immunization in df_plot['immunization'].unique():
    print(f"Plotting {immunization}...")

    df_immunization = df_plot[df_plot['immunization'] == immunization]

    # Sanity check only: drop_duplicates above already keeps a single row per
    # (position, AA, immunization), and df_pos/df_neg are disjoint.
    duplicates = df_immunization.duplicated(subset=['Spike_AS_Position', 'Amino_Acid'])
    if duplicates.any():
        print(f"Warning: duplicates found for {immunization} at these rows:")
        print(df_immunization[duplicates])

    # Aggregate letter_height per group, color take first (all same anyway)
    df_agg = df_immunization.groupby(['Spike_AS_Position', 'Amino_Acid', 'direction']).agg({
        'letter_height': 'mean',
        'color': 'first'
    }).reset_index()

    fig, ax = dmslogo.draw_logo(
        df_agg,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="letter_height",
        color_col="color",
        title=f"{immunization} Escape vs. Binding Logo",
        addbreaks=True
    )
    ax.axhline(0, color='black', linewidth=0.8)
    # NOTE(review): heights are +/-1 presence markers, not cell counts — this
    # axis label may mislead.
    ax.set_ylabel('IgG Secreting Cell [n]')

    # Colorbars: the red bar reuses red_norm (log10(Enrichment + 1) scale) so it
    # matches the letter colours; the blue bar matches reversed_blue_color.
    sm_red = plt.cm.ScalarMappable(cmap="Reds", norm=red_norm)
    sm_blue = plt.cm.ScalarMappable(cmap="Blues_r", norm=plt.Normalize(vmin=0, vmax=1))
    sm_red.set_array([])
    sm_blue.set_array([])

    cbar_red = fig.colorbar(sm_red, ax=ax, orientation='vertical', fraction=0.05, pad=0.04)
    cbar_red.set_label('log₁₀(Enrichment + 1), Enrichment > 1', rotation=270, labelpad=15)

    cbar_blue = fig.colorbar(sm_blue, ax=ax, orientation='vertical', fraction=0.05, pad=0.10)
    cbar_blue.set_label('Enrichment ≤ 1 (Escape)', rotation=270, labelpad=15)

    filename = os.path.join(output_dir, f"{immunization}_escape_binding_logo.png")
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Saved plot to {filename}")
    plt.show()


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import dmslogo  # make sure this is installed: pip install dmslogo

# -----------------------------
# PREP: Aggregation and Setup
# -----------------------------
# Upper-case amino acids, drop stop codons ("*"), floor ratios at 1e-5 (so the
# log scale stays finite), then take the median across barcodes per
# (Spike_AS_Position, Amino_Acid, immunization).
df_agg = df_total.assign(Amino_Acid=df_total['Amino_Acid'].str.upper())
df_agg = df_agg.loc[df_agg['Amino_Acid'] != "*"]
df_agg = df_agg.assign(Enrichment_Ratio=df_agg['Enrichment_Ratio'].clip(lower=1e-5))

df_logo_agg = (
    df_agg
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False)
    .agg({'Enrichment_Ratio': 'median'})
)

# -----------------------------
# COLOR FUNCTION
# -----------------------------
def enrichment_color(enrichment):
    """Return the 'coolwarm' RGBA colour for an enrichment ratio.

    The ratio is floored at 1e-5 and log2-transformed; log2 values in the
    symmetric range [-10, 10] are spread linearly over the colormap.
    """
    log2_ratio = np.log2(max(enrichment, 1e-5))
    symmetric_scale = plt.Normalize(vmin=-10, vmax=10)
    return plt.cm.coolwarm(symmetric_scale(log2_ratio))

# -----------------------------
# HEIGHT FUNCTION
# -----------------------------
def enrichment_height(enrichment):
    """Letter height: log2 of the enrichment ratio, with the ratio floored at 1e-5."""
    # Same selection rule as max(enrichment, 1e-5): keep the ratio unless 1e-5 is strictly larger
    floored = 1e-5 if enrichment < 1e-5 else enrichment
    return np.log2(floored)

# -----------------------------
# Add color and letter height
# -----------------------------
# Uses this cell's enrichment_height / enrichment_color (log2 of the ratio
# floored at 1e-5; colour on a symmetric [-10, 10] log2 scale).
df_logo_agg['letter_height'] = df_logo_agg['Enrichment_Ratio'].apply(enrichment_height)
df_logo_agg['color'] = df_logo_agg['Enrichment_Ratio'].apply(enrichment_color)

# -----------------------------
# Filter specific Spike positions
# -----------------------------
# Three positions only for this focused figure.
sites_to_show = [453,484,501]
df_logo_agg = df_logo_agg[df_logo_agg['Spike_AS_Position'].isin(sites_to_show)]

# -----------------------------
# Output directory
# -----------------------------
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# -----------------------------
# Generate plots
# -----------------------------
# One logo per immunization: letter height = log2(median enrichment).
for immunization in df_logo_agg['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    df_immunization = df_logo_agg[df_logo_agg['immunization'] == immunization]

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="letter_height",
        color_col="color",
        title=f"{immunization}",
        addbreaks=True
    )

    # Add zero line and labels
    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_ylabel('Antibody-binding (Log2)')
    # Re-set the existing title with extra padding above the axes
    ax.set_title(ax.get_title(), pad=20) 
    # Add colorbar
    # [-10, 10] matches the range used by this cell's enrichment_color.
    sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=plt.Normalize(vmin=-10, vmax=10))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('log2 (Enrichment)', rotation=270, labelpad=15)

    # Save figure
    # NOTE(review): same file name as the earlier logo cells — this overwrites them.
    filename = os.path.join(output_dir, f"{immunization}_dual_escape_logo.png")
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {filename}")
    plt.show()


In [None]:
# --- Step 1: Preprocess Data ---
# Upper-case AAs, drop stop codons, floor ratios at 1e-5 for the log scale.
df = df_total.copy()
df['Amino_Acid'] = df['Amino_Acid'].str.upper()
df = df[df['Amino_Acid'] != "*"]
df['Enrichment_Ratio'] = df['Enrichment_Ratio'].clip(lower=1e-5)


# Assign color by enrichment. Defined before its first use below so this cell
# does not depend on an enrichment_color left over from an earlier cell.
def enrichment_color(enrichment):
    """Return a coolwarm RGBA colour for log2(max(enrichment, 1e-5)) on a [-10, 10] scale."""
    enrichment = max(enrichment, 1e-5)
    log_enrichment = np.log2(enrichment)
    norm = plt.Normalize(vmin=-10, vmax=10)
    return plt.cm.coolwarm(norm(log_enrichment))


# Per (position, AA, immunization): number of distinct barcodes carrying the
# variant, and the median enrichment ratio across its rows.
df_grouped = df.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])

# Count distinct barcodes rather than rows: a barcode can presumably contribute
# several rows for the same AA (the Fisher cells below also collapse per
# barcode first), and the y-axis is labelled as a number of cells.
df_counts = df_grouped['barcode'].nunique().reset_index(name='Sample_Count')
df_median_enrich = df_grouped['Enrichment_Ratio'].median().reset_index(name='Median_Enrichment')

# Merge
df_logo = pd.merge(df_counts, df_median_enrich, on=['Spike_AS_Position', 'Amino_Acid', 'immunization'])

# Assign color using median enrichment
df_logo['color'] = df_logo['Median_Enrichment'].apply(enrichment_color)

# Height = barcode count, drawn upward when the median ratio is > 1 and downward otherwise
df_logo['letter_height'] = df_logo['Sample_Count'].where(
    df_logo['Median_Enrichment'] > 1, -df_logo['Sample_Count']
)

# Filter specific Spike positions
sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
df_logo = df_logo[df_logo['Spike_AS_Position'].isin(sites_to_show)]

# --- Step 2: Plot ---
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# One logo per immunization: height = signed barcode count, colour = median enrichment.
for immunization in df_logo['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    df_immunization = df_logo[df_logo['immunization'] == immunization]

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="letter_height",
        color_col="color",
        title=f"{immunization} Occurrence-based Enrichment Logo",
        addbreaks=True
    )

    ax.axhline(0, color='black', linewidth=0.8)
    ax.set_ylabel('IgG Secreting Cells [n]')

    # Colour scale: log2 ratio in [-10, 10], as used for the letter colours above.
    sm = plt.cm.ScalarMappable(cmap="coolwarm", norm=plt.Normalize(vmin=-10, vmax=10))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('log2(Enrichment Ratio)', rotation=270, labelpad=15)

    filename = os.path.join(output_dir, f"{immunization}_combined_logoplot.png")
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Plot saved as {filename}")
    plt.show()


In [None]:
# In the next blocks we test whether A) the occurrence frequencies and B) the enrichment ratios
# at the selected Spike protein positions differ between the immunization conditions.

In [None]:
# Spike positions used by the Fisher / Mann-Whitney statistics cells that follow.
# NOTE(review): sites_to_show is reassigned by several cells; results depend on
# which assignment ran last, so run the notebook top to bottom.
sites_to_show = [416,417, 439, 440,441, 451, 452,453, 475,476, 477,478, 484, 492,493,494,495,500,501, 502, 503,505]

In [None]:
from itertools import combinations
from scipy.stats import fisher_exact

# Step 1: Filter for NON-SYNOM, selected sites, and Enrichment_Ratio > 3
non_syn_df = df_total[
    (df_total['Type_of_Mutation'] == 'NON-SYNOM') &
    (df_total['Spike_AS_Position'].isin(sites_to_show)) &
    (df_total['Enrichment_Ratio'] > 3)
]

# Step 2: Collapse to one row per (site, AA, immunization, barcode) so a barcode
# with several matching rows for the same AA is counted once below
df_non_syn_agg = (
    non_syn_df
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'])
    .agg({'Enrichment_Ratio': 'sum'})
    .reset_index()
)

# Step 3: Sample counts (how many barcodes had this AA at this site per immunization)
df_sample_counts = (
    df_non_syn_agg
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .size()
    .reset_index(name='Sample_Count')
)

# Step 4: Barcode totals per immunization (denominator of each Fisher table)
barcode_counts = df_total.groupby('immunization')['barcode'].nunique().to_dict()

# Step 5: Run Fisher's exact test for all immunization pairs.
# Pairs come from every immunization with barcodes, not only those in which the
# AA passed the filter, so "present in one condition, absent (0) in another" —
# the strongest contrast — is tested as well.
all_immunizations = sorted(barcode_counts)
freq_stats = []

# Loop over all AA-site combinations that passed the filter in at least one immunization
for (pos, aa), subdf in df_sample_counts.groupby(['Spike_AS_Position', 'Amino_Acid']):
    # immunization -> number of barcodes carrying this AA at this site
    counts_by_imm = dict(zip(subdf['immunization'], subdf['Sample_Count']))

    # Compare each pair of immunizations
    for imm1, imm2 in combinations(all_immunizations, 2):
        # AA not observed in an immunization -> 0 positive barcodes
        n_pos_1 = counts_by_imm.get(imm1, 0)
        n_pos_2 = counts_by_imm.get(imm2, 0)

        # Absent in both groups: nothing to compare
        if n_pos_1 == 0 and n_pos_2 == 0:
            continue

        total1 = barcode_counts.get(imm1, 0)
        total2 = barcode_counts.get(imm2, 0)

        # Skip if total is too small
        if total1 == 0 or total2 == 0:
            continue

        n_neg_1 = total1 - n_pos_1
        n_neg_2 = total2 - n_pos_2

        # 2x2 contingency table: rows = immunizations, columns = (with AA, without AA)
        table = [[n_pos_1, n_neg_1],
                 [n_pos_2, n_neg_2]]

        stat, p_val = fisher_exact(table)
        freq_stats.append({
            'Spike_AS_Position': pos,
            'Amino_Acid': aa,
            'Immunization_1': imm1,
            'Immunization_2': imm2,
            'n_pos_1': n_pos_1,
            'n_neg_1': n_neg_1,
            'n_pos_2': n_pos_2,
            'n_neg_2': n_neg_2,
            'p_value': p_val
        })

# Step 6: Create DataFrame with results (explicit columns so an empty result still has them)
df_freq_stats = pd.DataFrame(freq_stats, columns=[
    'Spike_AS_Position', 'Amino_Acid', 'Immunization_1', 'Immunization_2',
    'n_pos_1', 'n_neg_1', 'n_pos_2', 'n_neg_2', 'p_value'
])

# Optional: Filter for significant values
# NOTE(review): many pairwise tests and no multiple-testing correction.
significant = df_freq_stats[df_freq_stats['p_value'] < 0.05]

# Save or display results
print(df_freq_stats.head())


In [None]:
from itertools import combinations
from scipy.stats import fisher_exact

# Step 1: Filter relevant data
non_syn_df = df_total[
    (df_total['Type_of_Mutation'] == 'NON-SYNOM') &
    (df_total['Spike_AS_Position'].isin(sites_to_show)) &
    (df_total['Enrichment_Ratio'] > 1)
]

# Step 2: Aggregate counts — one row per (site, AA, immunization, barcode),
# then the number of barcodes per (site, AA, immunization)
df_non_syn_agg = (
    non_syn_df
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'])
    .agg({'Enrichment_Ratio': 'sum'})
    .reset_index()
)

df_sample_counts = (
    df_non_syn_agg
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .size()
    .reset_index(name='Sample_Count')
)

# Barcode totals per immunization (denominators)
barcode_counts = df_total.groupby('immunization')['barcode'].nunique().to_dict()

# Step 3: Fisher tests and print results.
# Pair every immunization that has barcodes, so conditions in which the AA was
# not observed (0 positive barcodes) are compared too.
all_immunizations = sorted(barcode_counts)
freq_stats = []

for (pos, aa), subdf in df_sample_counts.groupby(['Spike_AS_Position', 'Amino_Acid']):
    counts_by_imm = dict(zip(subdf['immunization'], subdf['Sample_Count']))

    for imm1, imm2 in combinations(all_immunizations, 2):
        n_pos_1 = counts_by_imm.get(imm1, 0)
        n_pos_2 = counts_by_imm.get(imm2, 0)

        # Absent in both groups: nothing to compare
        if n_pos_1 == 0 and n_pos_2 == 0:
            continue

        total1 = barcode_counts.get(imm1, 0)
        total2 = barcode_counts.get(imm2, 0)

        if total1 == 0 or total2 == 0:
            continue

        n_neg_1 = total1 - n_pos_1
        n_neg_2 = total2 - n_pos_2

        table = [[n_pos_1, n_neg_1], [n_pos_2, n_neg_2]]
        stat, p_val = fisher_exact(table)

        # Store results
        result = {
            'Spike_AS_Position': pos,
            'Amino_Acid': aa,
            'Immunization_1': imm1,
            'Immunization_2': imm2,
            'n_pos_1': n_pos_1,
            'n_neg_1': n_neg_1,
            'n_pos_2': n_pos_2,
            'n_neg_2': n_neg_2,
            'p_value': p_val
        }
        freq_stats.append(result)

        # Print formatted summary
        print(f"\n[Fisher's Exact Test] Site {pos}, AA {aa}")
        print(f"Comparison: {imm1} vs {imm2}")
        print(f"Table: [[{n_pos_1}, {n_neg_1}], [{n_pos_2}, {n_neg_2}]]")
        print(f"P-value: {p_val:.4g}")

# Convert to DataFrame (explicit columns so an empty result still has them)
df_freq_stats = pd.DataFrame(freq_stats, columns=[
    'Spike_AS_Position', 'Amino_Acid', 'Immunization_1', 'Immunization_2',
    'n_pos_1', 'n_neg_1', 'n_pos_2', 'n_neg_2', 'p_value'
])


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Define abbreviations for the immunization conditions
# (names missing from the map are shown unabbreviated via fillna below)
abbr_map = {
    'Library_ctrl': 'Lib',
    'Polyclonal_Ab': 'pAB',
    # NOTE(review): 'B.1.135' may be intended as B.1.351 (Beta) — confirm.
    'Mutant_RBD': 'B.1.135',
    'Neutralizing_Ab': 'mAB Neut',
    'wildtype_RBD': 'WT',
    # Add other mappings as needed
}

# Filter for significant results (p < 0.05)
# Uses df_freq_stats from the most recently run Fisher cell.
significant_results = df_freq_stats[df_freq_stats['p_value'] < 0.05].copy()

if significant_results.empty:
    print("No significant results found.")
else:
    # Replace long names with abbreviations
    significant_results['Immunization_1_abbr'] = significant_results['Immunization_1'].map(abbr_map).fillna(significant_results['Immunization_1'])
    significant_results['Immunization_2_abbr'] = significant_results['Immunization_2'].map(abbr_map).fillna(significant_results['Immunization_2'])

    # Create combined label with abbreviations
    # Multi-line x tick label: site / AA / imm1 / "vs." / imm2
    significant_results['Comparison_Label'] = (
        significant_results['Spike_AS_Position'].astype(str) + "\n" +
        significant_results['Amino_Acid'] + "\n" +
        significant_results['Immunization_1_abbr'] + "\n" +
        "vs." + "\n" +
        significant_results['Immunization_2_abbr']
    )

    plt.figure(figsize=(20, 6))
    sns.barplot(
        data=significant_results,
        x='Comparison_Label',
        y='p_value',
        palette='viridis'
    )
    plt.axhline(0.05, color='red', linestyle='--', label='Significance Threshold (0.05)')
    plt.xticks(rotation=0, ha='center', fontsize=8)
    plt.ylabel('P-value')
    plt.title('Significant Fisher\'s Exact Test Results')
    plt.legend()
    plt.tight_layout()
    output_path = os.path.join(output_dir, "pValue_FisherExactHist_.png")
    plt.savefig(output_path, format='png', dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
from itertools import combinations
from scipy.stats import fisher_exact

# Step 1: Filter relevant data
non_syn_df = df_total[
    (df_total['Type_of_Mutation'] == 'NON-SYNOM') &
    (df_total['Spike_AS_Position'].isin(sites_to_show)) &
    (df_total['Enrichment_Ratio'] > 1)
]

# Step 2: Aggregate counts — one row per (site, AA, immunization, barcode),
# then the number of barcodes per (site, AA, immunization)
df_non_syn_agg = (
    non_syn_df
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'])
    .agg({'Enrichment_Ratio': 'sum'})
    .reset_index()
)

df_sample_counts = (
    df_non_syn_agg
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'])
    .size()
    .reset_index(name='Sample_Count')
)

# Barcode totals per immunization (denominators)
barcode_counts = df_total.groupby('immunization')['barcode'].nunique().to_dict()

# Step 3: Fisher tests and print results.
# Pair every immunization that has barcodes, so conditions in which the AA was
# not observed (0 positive barcodes) are compared too; the odds ratio is then
# 0 or inf and the direction below still reads correctly.
all_immunizations = sorted(barcode_counts)
freq_stats = []

for (pos, aa), subdf in df_sample_counts.groupby(['Spike_AS_Position', 'Amino_Acid']):
    counts_by_imm = dict(zip(subdf['immunization'], subdf['Sample_Count']))

    for imm1, imm2 in combinations(all_immunizations, 2):
        n_pos_1 = counts_by_imm.get(imm1, 0)
        n_pos_2 = counts_by_imm.get(imm2, 0)

        # Absent in both groups: nothing to compare
        if n_pos_1 == 0 and n_pos_2 == 0:
            continue

        total1 = barcode_counts.get(imm1, 0)
        total2 = barcode_counts.get(imm2, 0)

        if total1 == 0 or total2 == 0:
            continue

        n_neg_1 = total1 - n_pos_1
        n_neg_2 = total2 - n_pos_2

        table = [[n_pos_1, n_neg_1], [n_pos_2, n_neg_2]]
        odds_ratio, p_val = fisher_exact(table)

        # Determine direction (odds of carrying the AA, imm1 relative to imm2)
        if odds_ratio > 1:
            direction = f"Higher in {imm1}"
        elif odds_ratio < 1:
            direction = f"Higher in {imm2}"
        else:
            direction = "No difference"

        # Store results
        result = {
            'Spike_AS_Position': pos,
            'Amino_Acid': aa,
            'Immunization_1': imm1,
            'Immunization_2': imm2,
            'n_pos_1': n_pos_1,
            'n_neg_1': n_neg_1,
            'n_pos_2': n_pos_2,
            'n_neg_2': n_neg_2,
            'odds_ratio': odds_ratio,
            'p_value': p_val,
            'direction': direction
        }
        freq_stats.append(result)

        # Print formatted summary
        print(f"\n[Fisher's Exact Test] Site {pos}, AA {aa}")
        print(f"Comparison: {imm1} vs {imm2}")
        print(f"Table: [[{n_pos_1}, {n_neg_1}], [{n_pos_2}, {n_neg_2}]]")
        print(f"Odds ratio: {odds_ratio:.4g}")
        print(f"P-value: {p_val:.4g}")
        print(f"Direction: {direction}")

# Convert to DataFrame (explicit columns so an empty result still has them)
df_freq_stats = pd.DataFrame(freq_stats, columns=[
    'Spike_AS_Position', 'Amino_Acid', 'Immunization_1', 'Immunization_2',
    'n_pos_1', 'n_neg_1', 'n_pos_2', 'n_neg_2',
    'odds_ratio', 'p_value', 'direction'
])


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Short display names for the immunization conditions; names missing from the
# map are shown unabbreviated (fillna below).
abbr_map = {
    'Library_ctrl': 'Lib',
    'Polyclonal_Ab': 'pAB',
    'Mutant_RBD': 'B.1.135',  # NOTE(review): possibly meant B.1.351 (Beta) — confirm
    'Neutralizing_Ab': 'mAB',
    'wildtype_RBD': 'WT',
}

# Keep p < 0.05, drop every comparison involving the library control, drop stop codons.
# Expects df_freq_stats from the odds-ratio Fisher cell (needs 'odds_ratio').
significant_results = df_freq_stats[df_freq_stats['p_value'] < 0.05].copy()
significant_results = significant_results.loc[
    (significant_results['Immunization_1'] != 'Library_ctrl')
    & (significant_results['Immunization_2'] != 'Library_ctrl')
]
significant_results = significant_results.loc[significant_results['Amino_Acid'] != '*']

if significant_results.empty:
    print("No significant results found after filtering Library_ctrl.")
else:
    significant_results['Immunization_1_abbr'] = (
        significant_results['Immunization_1'].map(abbr_map).fillna(significant_results['Immunization_1'])
    )
    significant_results['Immunization_2_abbr'] = (
        significant_results['Immunization_2'].map(abbr_map).fillna(significant_results['Immunization_2'])
    )

    def reorder_row(row):
        """Return the row's label fields with the higher-occurrence immunization first.

        An odds ratio below 1 means Immunization_2 carries the AA more often,
        so the two immunizations (and their abbreviations) are swapped.
        """
        first, second = ('2', '1') if row['odds_ratio'] < 1 else ('1', '2')
        return pd.Series({
            'Spike_AS_Position': row['Spike_AS_Position'],
            'Amino_Acid': row['Amino_Acid'],
            'Immunization_1': row[f'Immunization_{first}'],
            'Immunization_2': row[f'Immunization_{second}'],
            'Immunization_1_abbr': row[f'Immunization_{first}_abbr'],
            'Immunization_2_abbr': row[f'Immunization_{second}_abbr'],
            'p_value': row['p_value']
        })

    reordered = significant_results.apply(reorder_row, axis=1)

    # Multi-line tick label: site / AA / higher-occurrence imm / "vs." / other imm
    reordered['Comparison_Label'] = (
        reordered['Spike_AS_Position'].astype(str) + "\n"
        + reordered['Amino_Acid'] + "\n"
        + reordered['Immunization_1_abbr']
        + "\nvs.\n"
        + reordered['Immunization_2_abbr']
    )

    plt.figure(figsize=(14, 5))
    sns.barplot(data=reordered, x='Comparison_Label', y='p_value', palette='viridis')
    plt.axhline(0.05, color='red', linestyle='--', label='Significance Threshold (0.05)')
    plt.xticks(rotation=0, ha='center', fontsize=8)
    plt.ylabel('P-value')
    plt.title("Significant Fisher's Exact Test Results (Higher occurrence first)")
    plt.legend()
    plt.tight_layout()
    output_path = os.path.join(output_dir, "pValue_Occurence.png")
    plt.savefig(output_path, format='png', dpi=300, bbox_inches='tight')
    plt.show()


In [None]:
from itertools import combinations
from scipy.stats import mannwhitneyu

# Pairwise two-sided Mann-Whitney U tests of Enrichment_Ratio between
# immunizations, for every (Spike position, AA) among the selected sites.
# Filter relevant data: Non-syn mutations, selected sites, enrichment ratio >1
# NOTE(review): only ratios > 1 enter the test, so both samples are truncated at 1
# and the test compares enriched values only — confirm this is intended.
# NOTE(review): rows are not collapsed per barcode (unlike the Fisher cells), so
# one barcode may contribute several observations — check independence.
non_syn_df = df_total[
    (df_total['Type_of_Mutation'] == 'NON-SYNOM') &
    (df_total['Spike_AS_Position'].isin(sites_to_show)) &
    (df_total['Enrichment_Ratio'] > 1)
]

results = []

# Loop over each position and amino acid
for (pos, aa), group_df in non_syn_df.groupby(['Spike_AS_Position', 'Amino_Acid']):
    # Only immunizations with at least one row for this (site, AA) are compared
    immunizations = group_df['immunization'].unique()
    if len(immunizations) < 2:
        continue
    
    # Compare all pairs of immunizations
    for imm1, imm2 in combinations(immunizations, 2):
        enr_1 = group_df.loc[group_df['immunization'] == imm1, 'Enrichment_Ratio']
        enr_2 = group_df.loc[group_df['immunization'] == imm2, 'Enrichment_Ratio']

        # Skip if either group has no data
        # (defensive only: both immunizations come from group_df, so each has >= 1 row)
        if len(enr_1) == 0 or len(enr_2) == 0:
            continue
        
        # Mann-Whitney U test (two-sided)
        stat, p_val = mannwhitneyu(enr_1, enr_2, alternative='two-sided')
        
        median_1 = enr_1.median()
        median_2 = enr_2.median()
        
        # Determine direction based on median
        if median_1 > median_2:
            direction = f"Higher median in {imm1}"
        elif median_2 > median_1:
            direction = f"Higher median in {imm2}"
        else:
            direction = "No median difference"
        
        # Store results
        result = {
            'Spike_AS_Position': pos,
            'Amino_Acid': aa,
            'Immunization_1': imm1,
            'Immunization_2': imm2,
            'Median_Enrichment_1': median_1,
            'Median_Enrichment_2': median_2,
            'MannWhitneyU_stat': stat,
            'p_value': p_val,
            'Direction': direction
        }
        results.append(result)

        # Print summary
        print(f"\n[Median Enrichment Ratio Comparison] Site {pos}, AA {aa}")
        print(f"Comparison: {imm1} vs {imm2}")
        print(f"Median enrichment ratio: {median_1:.4f} vs {median_2:.4f}")
        print(f"Mann-Whitney U stat: {stat:.4g}")
        print(f"P-value: {p_val:.4g}")
        print(f"Direction: {direction}")

# Convert to DataFrame
df_enrichment_stats = pd.DataFrame(results)


In [None]:
# Wider site list (adds 418-422, 450 and 504 to the previous selection).
# NOTE(review): assigned after the Fisher / Mann-Whitney cells above have already
# used sites_to_show, so it does not change their results on Restart & Run All;
# it only affects cells that read sites_to_show afterwards.
sites_to_show = [416,417, 418,419,420,421,422,439, 440,441, 450,451, 452,453, 475,476, 477,478, 484, 492,493,494,495,500,501, 502, 503,504,505]

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

abbr_map = {
    'Library_ctrl': 'Lib',
    'Polyclonal_Ab': 'pAB',
    'Mutant_RBD': 'B.1.135',  # NOTE(review): possibly meant B.1.351 (Beta) — confirm
    'Neutralizing_Ab': 'mAB',
    'wildtype_RBD': 'WT',
}

# Filter significant results and exclude library and stop codons (*)
significant_enrich = df_enrichment_stats[
    (df_enrichment_stats['p_value'] < 0.05) &
    (~df_enrichment_stats['Immunization_1'].isin(['Library_ctrl'])) &
    (~df_enrichment_stats['Immunization_2'].isin(['Library_ctrl'])) &
    (df_enrichment_stats['Amino_Acid'] != '*')
].copy()

if significant_enrich.empty:
    print("No significant median enrichment ratio differences found (after filtering).")
else:
    significant_enrich['Imm1_abbr'] = significant_enrich['Immunization_1'].map(abbr_map).fillna(significant_enrich['Immunization_1'])
    significant_enrich['Imm2_abbr'] = significant_enrich['Immunization_2'].map(abbr_map).fillna(significant_enrich['Immunization_2'])

    def order_imm(row):
        """Return (high_abbr, low_abbr, high_median, low_median) with the higher median first."""
        if row['Median_Enrichment_1'] >= row['Median_Enrichment_2']:
            return (row['Imm1_abbr'], row['Imm2_abbr'], row['Median_Enrichment_1'], row['Median_Enrichment_2'])
        else:
            return (row['Imm2_abbr'], row['Imm1_abbr'], row['Median_Enrichment_2'], row['Median_Enrichment_1'])

    ordered = significant_enrich.apply(order_imm, axis=1, result_type='expand')
    significant_enrich['Imm_high'] = ordered[0]
    significant_enrich['Imm_low'] = ordered[1]
    significant_enrich['Median_high'] = ordered[2]
    significant_enrich['Median_low'] = ordered[3]

    # Multi-line tick label: site / AA / higher-median imm / "vs." / lower-median imm
    significant_enrich['Comparison_Label'] = (
        significant_enrich['Spike_AS_Position'].astype(str) + "\n" +
        significant_enrich['Amino_Acid'] + "\n" +
        significant_enrich['Imm_high'] + "\n" +
        "vs.\n" +
        significant_enrich['Imm_low']
    )

    plt.figure(figsize=(7, 5))
    sns.barplot(
        data=significant_enrich,
        x='Comparison_Label',
        y='p_value',
        palette='mako_r'
    )
    plt.axhline(0.05, color='red', linestyle='--', label='Significance Threshold (0.05)')
    # Horizontal multi-line labels must be centre-aligned to sit under their bars;
    # ha='right' with rotation=0 shifts every label to the left of its bar.
    plt.xticks(rotation=0, ha='center', fontsize=10)
    plt.ylabel('P-value (Mann-Whitney U test)')
    plt.title('Significant Median Enrichment Ratio Comparisons (Excluding Library and Stop Codons)')
    plt.legend()
    plt.tight_layout()
    output_path = os.path.join(output_dir, "pValue_ER_Ratio.png")
    plt.savefig(output_path, format='png', dpi=300, bbox_inches='tight')
    plt.show()



In [None]:
from scipy.stats import fisher_exact

# Total barcodes per immunization (denominator of each Fisher table).
# Counted on df_total: the most recent df_logo_agg in cell order is the
# per-(position, AA, immunization) median table, which has no 'barcode'
# column, so grouping it by 'barcode' raises KeyError on a fresh kernel.
barcode_counts = df_total.groupby('immunization')['barcode'].nunique().to_dict()

# Store results
freq_stats = []

# NOTE(review): df_combined is not created in the surrounding cells; it must
# provide Spike_AS_Position, Amino_Acid, immunization and Sample_Count — confirm
# where it comes from before relying on this cell.
for (pos, aa), subdf in df_combined.groupby(['Spike_AS_Position', 'Amino_Acid']):
    table = []
    for imm in subdf['immunization'].unique():
        # Barcodes with that AA in this immunization (first row if repeated)
        n_pos = subdf.loc[subdf['immunization'] == imm, 'Sample_Count'].iloc[0]
        total = barcode_counts.get(imm, 0)
        n_neg = total - n_pos
        table.append([n_pos, n_neg])

    # Only compare 2 groups (you can extend to Chi2 if needed)
    if len(table) == 2:
        stat, p_val = fisher_exact(table)
        freq_stats.append({'Spike_AS_Position': pos, 'Amino_Acid': aa, 'Test': 'Fisher', 'p_value': p_val})

df_freq_stats = pd.DataFrame(freq_stats)


In [None]:
# Distinct barcodes per immunization. Counted on df_total because df_logo_agg,
# as last assigned before this cell, has no 'barcode' column (KeyError).
print(df_total.groupby('immunization')['barcode'].nunique())


In [None]:
# Significant (p < 0.05) pairwise Mann-Whitney comparisons from
# df_enrichment_stats. Defined here because sig_enrich is not created in any of
# the surrounding cells, so printing it raised NameError on a fresh kernel.
sig_enrich = df_enrichment_stats[df_enrichment_stats['p_value'] < 0.05]
print("Significant enrichment differences:")
print(sig_enrich)


In [None]:
# Summary of the Mann-Whitney tests; the significant count is computed here so
# the cell does not depend on a sig_enrich variable defined elsewhere.
n_significant = int((df_enrichment_stats['p_value'] < 0.05).sum())
print(f"Total tests: {len(df_enrichment_stats)}")
print(f"Significant (p < 0.05): {n_significant}")


In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.colors as mcolors
from matplotlib.patches import Patch

# One logo plot per barcode: letter height = summed Enrichment_Ratio of each
# (position, amino acid), enriched variants (ratio > 1) at positions > 365 only.

# Aggregate enrichment ratio per barcode, amino acid, and position.
# NOTE: this rebinds df_logo_agg, which other cells of this notebook also read.
df_logo_agg = df_total.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False
).agg({
    'Enrichment_Ratio': 'sum'  # or 'mean' or 'max' depending on your preference
})

df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]

# Enriched variants at positions > 365. .copy() makes this an independent frame so
# the column assignments below don't raise SettingWithCopyWarning on a slice.
df_filtered = df_logo_agg[
    (df_logo_agg['Enrichment_Ratio'] > 1) & (df_logo_agg['Spike_AS_Position'] > 365)
].copy()

# Fixed color per amino acid letter (tab20 = 20 distinct colors, cycled if more)
aa_list = sorted(df_filtered['Amino_Acid'].unique())
colors = plt.cm.tab20.colors
aa_colors = {aa: colors[i % len(colors)] for i, aa in enumerate(aa_list)}
df_filtered['color'] = df_filtered['Amino_Acid'].map(aa_colors)

# Labels such as "N_501" (not used by the plot below)
df_filtered = df_filtered.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

output_dir = "barcode_logoplots"
os.makedirs(output_dir, exist_ok=True)

# The amino-acid legend is identical for every plot, so build its handles once
legend_elements = [Patch(facecolor=col, label=aa) for aa, col in aa_colors.items()]

for barcode in df_filtered['barcode'].unique():
    df_barcode = df_filtered[df_filtered['barcode'] == barcode]
    if df_barcode.empty:
        continue

    immunization = df_barcode['immunization'].iloc[0]
    print(f"Generating plot for barcode {barcode} ({immunization})...")

    fig, ax = dmslogo.draw_logo(
        df_barcode,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio",  # Use enrichment ratio for height
        color_col="color",
        title=f"{immunization} - {barcode} logoplot (Enrichment Ratio-based)",
        addbreaks=True
    )
    fig.set_size_inches(45, 4)
    ax.legend(handles=legend_elements, title='Amino Acid', bbox_to_anchor=(1.05, 1), loc='upper left')

    # Save BEFORE showing: with the inline backend plt.show() closes the figure,
    # so a plt.savefig() afterwards writes a new, empty canvas.
    safe_barcode = str(barcode).replace("/", "_").replace("\\", "_")  # Sanitize filename
    plot_filename = os.path.join(output_dir, f"{safe_barcode}_logoplot.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    plt.show()
    plt.close(fig)  # one figure per barcode: release it to avoid memory build-up


In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.colors as mcolors
from matplotlib.patches import Patch
import matplotlib.ticker as ticker

# Per immunization: one vertical panel of logo plots (log10 of the median
# Enrichment_Ratio), enriched variants (ratio > 1) at positions > 365 only.

# Barcodes per panel: the first N in order of appearance (not ranked by any score)
N_PANEL_BARCODES = 4

# Aggregate enrichment ratio per barcode, amino acid, and position.
# NOTE: this rebinds df_logo_agg, which later cells of this notebook reuse.
df_logo_agg = df_total.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False
).agg({
    'Enrichment_Ratio': 'median'  # or 'mean' or 'max'
})

df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]
df_logo_agg = df_logo_agg.replace([np.inf, -np.inf], np.nan)
df_logo_agg = df_logo_agg.dropna(subset=['Enrichment_Ratio'])

# Enriched variants at positions > 365; .copy() so the column assignments below
# act on an independent frame (no SettingWithCopyWarning on a slice).
df_filtered = df_logo_agg[
    (df_logo_agg['Enrichment_Ratio'] > 1) & (df_logo_agg['Spike_AS_Position'] > 365)
].copy()

print(df_filtered['Enrichment_Ratio'].min())  # Should be > 1

def safe_log10(x):
    """Return log10(x) for a positive scalar x, NaN otherwise."""
    return np.log10(x) if x > 0 else np.nan

df_filtered['Enrichment_Ratio_log'] = df_filtered['Enrichment_Ratio'].apply(safe_log10)
df_filtered = df_filtered.dropna(subset=['Enrichment_Ratio_log'])

# Fixed color per amino acid letter
aa_list = sorted(df_filtered['Amino_Acid'].unique())
colors = plt.cm.tab20.colors
aa_colors = {aa: colors[i % len(colors)] for i, aa in enumerate(aa_list)}
df_filtered['color'] = df_filtered['Amino_Acid'].map(aa_colors)

output_dir = "barcode_logoplots_panels"
os.makedirs(output_dir, exist_ok=True)

def log_ticks(ax):
    """Major y ticks every 1 log10 unit, labelled as powers of ten (10^y)."""
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda y, _: f"$10^{{{int(y)}}}$"))

# Same legend for every subplot: build the handles once
legend_elements = [Patch(facecolor=col, label=aa) for aa, col in aa_colors.items()]

for immunization in df_filtered['immunization'].unique():
    df_imm = df_filtered[df_filtered['immunization'] == immunization]
    barcodes = df_imm['barcode'].unique()[:N_PANEL_BARCODES]

    if len(barcodes) == 0:
        continue

    print(f"Generating panel for immunization: {immunization}")

    # squeeze=False always yields a 2-D array, so no special case for one barcode
    fig, axs = plt.subplots(len(barcodes), 1, figsize=(45, 4 * len(barcodes)), squeeze=False)
    axs = axs[:, 0]

    for i, barcode in enumerate(barcodes):
        df_barcode = df_imm[(df_imm['barcode'] == barcode) & (df_imm['Enrichment_Ratio_log'].notna())]

        if df_barcode.empty:
            axs[i].set_visible(False)  # don't leave an unlabeled blank panel
            continue

        dmslogo.draw_logo(
            df_barcode,
            x_col="Spike_AS_Position",
            letter_col="Amino_Acid",
            letter_height_col="Enrichment_Ratio_log",  # Use log-transformed values
            color_col="color",
            title=f"{immunization} - {barcode} logoplot (log10 Enrichment Ratio)",
            addbreaks=True,
            ax=axs[i]
        )

        axs[i].set_ylabel('log10(Enrichment Ratio)')
        log_ticks(axs[i])
        axs[i].legend(handles=legend_elements, title='Amino Acid', bbox_to_anchor=(1.01, 1), loc='upper left')

    fig.tight_layout()
    safe_immunization = str(immunization).replace("/", "_").replace("\\", "_")
    plot_filename = os.path.join(output_dir, f"{safe_immunization}_top{N_PANEL_BARCODES}_logoplots.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    plt.close(fig)
    print(f"Panel saved as {plot_filename}")


In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.colors as mcolors
from matplotlib.patches import Patch
import matplotlib.ticker as ticker
import pandas as pd  # added for concat/copy

# Per immunization: vertical panel of logo plots of clipped log10(Enrichment_Ratio)
# (no ratio > 1 filter, so negative heights occur), positions > 365 only.

# Barcodes per panel: the first N in order of appearance
N_PANEL_BARCODES = 4

# NOTE(review): df_logo_agg comes from whichever aggregation cell ran last
# (sum or median per barcode) — run one of those first on a fresh kernel.
# assign() instead of column assignment: df_logo_agg may itself be a filtered slice.
df_logo_agg = df_logo_agg.assign(Amino_Acid=df_logo_agg['Amino_Acid'].str.upper())
df_logo_agg = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]
df_logo_agg = df_logo_agg.replace([np.inf, -np.inf], np.nan)
df_logo_agg = df_logo_agg.dropna(subset=['Enrichment_Ratio'])

key_cols = ['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode']

# Show duplicated (position, AA, immunization, barcode) rows before they are dropped.
# (This used to inspect `df_barcode`, a loop variable left over from another cell:
# stale data, or a NameError on a fresh kernel.)
duplicates = df_logo_agg[df_logo_agg.duplicated(subset=key_cols, keep=False)]
print(duplicates)

# One row per key at positions > 365; .copy() so the column assignments below
# act on an independent frame (no SettingWithCopyWarning).
df_filtered = df_logo_agg.drop_duplicates(subset=key_cols)
df_filtered = df_filtered[df_filtered['Spike_AS_Position'] > 365].copy()

# No ratio > 1 filter in this cell: values below 1 (negative log10) are expected
print(df_filtered['Enrichment_Ratio'].min())

def safe_log10(enrichment, epsilon=1e-8, max_cap=1e8):
    """log10 after clipping to [epsilon, max_cap], so 0 and huge ratios stay finite.

    Works on scalars and on array-likes / Series (elementwise).
    """
    enrichment = np.clip(enrichment, epsilon, max_cap)
    return np.log10(enrichment)

df_filtered['Enrichment_Ratio_log'] = safe_log10(df_filtered['Enrichment_Ratio'])
df_filtered = df_filtered.dropna(subset=['Enrichment_Ratio_log'])

# Fixed color per amino acid letter
aa_list = sorted(df_filtered['Amino_Acid'].unique())
colors = plt.cm.tab20.colors
aa_colors = {aa: colors[i % len(colors)] for i, aa in enumerate(aa_list)}
df_filtered['color'] = df_filtered['Amino_Acid'].map(aa_colors)

output_dir = "barcode_logoplots_panels"
os.makedirs(output_dir, exist_ok=True)

def log_ticks_plain(ax, base=1):
    """Major y ticks every `base` units, labelled as plain integers."""
    ax.yaxis.set_major_locator(ticker.MultipleLocator(base))
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda y, _: f"{int(y)}"))

def separate_stacked_heights(df, x_col='Spike_AS_Position', h_col='Enrichment_Ratio_log'):
    """Add a `stack_bottom` column and order rows by position, then stack baseline.

    Positive and negative heights are stacked separately: within each position,
    stack_bottom is the running sum of the preceding heights of the same sign
    (rows taken largest-height first).
    NOTE(review): draw_logo below is not given stack_bottom, so only the resulting
    row order reaches dmslogo; whether that changes letter stacking depends on
    dmslogo's drawing logic — confirm.
    """
    df_sorted = df.sort_values(by=[x_col, h_col], ascending=[True, False]).copy()

    df_positive = df_sorted[df_sorted[h_col] >= 0].copy()
    df_negative = df_sorted[df_sorted[h_col] < 0].copy()

    # Running sum within each position minus own height = where this letter starts
    df_positive['stack_bottom'] = df_positive.groupby(x_col)[h_col].cumsum() - df_positive[h_col]
    df_negative['stack_bottom'] = df_negative.groupby(x_col)[h_col].cumsum() - df_negative[h_col]

    return pd.concat([df_positive, df_negative], axis=0).sort_values(by=[x_col, 'stack_bottom'])

# Same legend for every panel: build the handles once
legend_elements = [Patch(facecolor=col, label=aa) for aa, col in aa_colors.items()]

for immunization in df_filtered['immunization'].unique():
    df_imm = df_filtered[df_filtered['immunization'] == immunization]
    barcodes = df_imm['barcode'].unique()[:N_PANEL_BARCODES]

    if len(barcodes) == 0:
        continue

    print(f"Generating panel for immunization: {immunization}")

    # squeeze=False always yields a 2-D array, so no special case for one barcode
    fig, axs = plt.subplots(len(barcodes), 1, figsize=(45, 4 * len(barcodes)),
                            gridspec_kw={'hspace': 0.5}, squeeze=False)
    axs = axs[:, 0]

    for i, barcode in enumerate(barcodes):
        df_barcode = df_imm[(df_imm['barcode'] == barcode) &
                            (df_imm['Enrichment_Ratio_log'].notna())]

        if df_barcode.empty:
            axs[i].set_visible(False)  # don't leave an unlabeled blank panel
            continue

        df_barcode = separate_stacked_heights(
            df_barcode,
            x_col="Spike_AS_Position",
            h_col="Enrichment_Ratio_log"
        )

        title = f"{immunization} - {barcode} logoplot (log10 Enrichment Ratio)"
        dmslogo.draw_logo(
            df_barcode,
            x_col="Spike_AS_Position",
            letter_col="Amino_Acid",
            letter_height_col="Enrichment_Ratio_log",
            color_col="color",
            title=title,
            addbreaks=True,
            ax=axs[i]
        )

        axs[i].set_title(title, pad=25)
        axs[i].set_ylabel('log10 AB binding')
        # Axes-fraction coordinates: just left of the y axis. (-2 put the label two
        # axes-widths outside the figure, blowing up the tight bounding box.)
        axs[i].yaxis.set_label_coords(-0.01, 0.5)
        axs[i].set_ylim(-10, 6)
        # Tick every 2 log units (previously a MultipleLocator(2) here was silently
        # replaced by the helper's 1-unit locator)
        log_ticks_plain(axs[i], base=2)

    fig.legend(handles=legend_elements, title='Amino Acid', loc='center right', borderaxespad=0.1)

    # rect is in figure fractions and must lie within [0, 1]; the right 15% is
    # left free for the figure legend.
    fig.tight_layout(rect=[0, 0, 0.85, 1])
    safe_immunization = str(immunization).replace("/", "_").replace("\\", "_")
    plot_filename = os.path.join(output_dir, f"{safe_immunization}_top{N_PANEL_BARCODES}_logoplots.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print(f"Panel saved as {plot_filename}")


In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.colors as mcolors
from matplotlib.patches import Patch
import matplotlib.ticker as ticker

# Per immunization: vertical panel of logo plots of clipped log10(Enrichment_Ratio)
# (no ratio > 1 filter, so negative heights occur), positions > 365 only.

# Barcodes per panel: the first N in order of appearance
N_PANEL_BARCODES = 4

# NOTE(review): df_logo_agg comes from whichever aggregation cell ran last
# (sum or median per barcode) — run one of those first on a fresh kernel.
# assign() instead of column assignment: df_logo_agg may itself be a filtered slice.
df_logo_agg = df_logo_agg.assign(Amino_Acid=df_logo_agg['Amino_Acid'].str.upper())
df_logo_agg = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]
df_logo_agg = df_logo_agg.replace([np.inf, -np.inf], np.nan)
df_logo_agg = df_logo_agg.dropna(subset=['Enrichment_Ratio'])

key_cols = ['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode']

# Show duplicated (position, AA, immunization, barcode) rows before they are dropped.
# (This used to inspect `df_barcode`, a loop variable left over from another cell:
# stale data, or a NameError on a fresh kernel.)
duplicates = df_logo_agg[df_logo_agg.duplicated(subset=key_cols, keep=False)]
print(duplicates)

# One row per key at positions > 365; .copy() so the column assignments below
# act on an independent frame (no SettingWithCopyWarning).
df_filtered = df_logo_agg.drop_duplicates(subset=key_cols)
df_filtered = df_filtered[df_filtered['Spike_AS_Position'] > 365].copy()

# No ratio > 1 filter in this cell: values below 1 (negative log10) are expected
print(df_filtered['Enrichment_Ratio'].min())

def safe_log10(enrichment, epsilon=1e-8, max_cap=1e8):
    """log10 after clipping to [epsilon, max_cap], so 0 and huge ratios stay finite.

    Works on scalars and on array-likes / Series (elementwise).
    """
    enrichment = np.clip(enrichment, epsilon, max_cap)
    return np.log10(enrichment)

df_filtered['Enrichment_Ratio_log'] = safe_log10(df_filtered['Enrichment_Ratio'])
df_filtered = df_filtered.dropna(subset=['Enrichment_Ratio_log'])

# Fixed color per amino acid letter
aa_list = sorted(df_filtered['Amino_Acid'].unique())
colors = plt.cm.tab20.colors
aa_colors = {aa: colors[i % len(colors)] for i, aa in enumerate(aa_list)}
df_filtered['color'] = df_filtered['Amino_Acid'].map(aa_colors)

output_dir = "barcode_logoplots_panels"
os.makedirs(output_dir, exist_ok=True)

def log_ticks_plain(ax, base=1):
    """Major y ticks every `base` units, labelled as plain integers."""
    ax.yaxis.set_major_locator(ticker.MultipleLocator(base))
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda y, _: f"{int(y)}"))

def improved_draw_logo(df, **kwargs):
    """Call dmslogo.draw_logo with rows sorted by position, then ascending letter height.

    NOTE(review): whether this places the tallest letter on top depends on the
    order in which dmslogo stacks rows — confirm against its output.
    """
    df_sorted = df.sort_values(
        by=[kwargs['x_col'], kwargs['letter_height_col']],
        ascending=[True, True]
    )
    return dmslogo.draw_logo(df_sorted, **kwargs)

# Same legend for every panel: build the handles once
legend_elements = [Patch(facecolor=col, label=aa) for aa, col in aa_colors.items()]

for immunization in df_filtered['immunization'].unique():
    df_imm = df_filtered[df_filtered['immunization'] == immunization]
    barcodes = df_imm['barcode'].unique()[:N_PANEL_BARCODES]

    if len(barcodes) == 0:
        continue

    print(f"Generating panel for immunization: {immunization}")

    # squeeze=False always yields a 2-D array, so no special case for one barcode
    fig, axs = plt.subplots(len(barcodes), 1, figsize=(45, 12 * len(barcodes)),
                            gridspec_kw={'hspace': 2.5}, squeeze=False)
    axs = axs[:, 0]

    for i, barcode in enumerate(barcodes):
        df_barcode = df_imm[(df_imm['barcode'] == barcode) &
                            (df_imm['Enrichment_Ratio_log'].notna())]

        if df_barcode.empty:
            axs[i].set_visible(False)  # don't leave an unlabeled blank panel
            continue

        title = f"{immunization} - {barcode} logoplot (log10 Enrichment Ratio)"
        improved_draw_logo(
            df_barcode,
            x_col="Spike_AS_Position",
            letter_col="Amino_Acid",
            letter_height_col="Enrichment_Ratio_log",
            color_col="color",
            title=title,
            addbreaks=True,
            ax=axs[i]
        )

        axs[i].set_title(title, pad=25)
        axs[i].set_ylabel('Log10 AB binding (median) \u2190 Enrichment \u2192', fontsize=10)
        axs[i].yaxis.set_label_coords(-0.01, 0.5)
        axs[i].set_ylim(-10, 6)
        # Tick every 4 log units (previously a MultipleLocator(4) here was silently
        # replaced by the helper's 1-unit locator)
        log_ticks_plain(axs[i], base=4)

    fig.legend(handles=legend_elements, title='Amino Acid', loc='center right', borderaxespad=0.1)

    # rect is in figure fractions and must lie within [0, 1]; the right 15% is
    # left free for the figure legend.
    fig.tight_layout(rect=[0, 0, 0.85, 1])
    safe_immunization = str(immunization).replace("/", "_").replace("\\", "_")
    plot_filename = os.path.join(output_dir, f"{safe_immunization}_top{N_PANEL_BARCODES}_logoplots.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print(f"Panel saved as {plot_filename}")


In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.colors as mcolors
from matplotlib.patches import Patch
import matplotlib.ticker as ticker
import matplotlib.colors as colors
from matplotlib.ticker import ScalarFormatter

# Per immunization: vertical panel of logo plots for a random (but seeded,
# hence reproducible) selection of barcodes; letter height = clipped log10 of
# the summed Enrichment_Ratio, positions > 365 only.

PANEL_RANDOM_SEED = 42  # fixed so reruns pick the same barcodes
N_PANEL_BARCODES = 4

def enrichment_height(enrichment, epsilon=1e-8, max_cap=1e8):
    """
    Safe log10 transformation of enrichment values.
    Args:
        enrichment (float or array-like): Enrichment value(s).
        epsilon (float): Minimum enrichment value to avoid log10(0).
        max_cap (float): Maximum value to cap extremely high enrichments.
    Returns:
        float or np.array: Transformed enrichment.
    """
    enrichment = np.clip(enrichment, epsilon, max_cap)
    return np.log10(enrichment)

# Aggregate enrichment ratio per barcode, amino acid, and position.
# NOTE: this rebinds df_logo_agg, which other cells of this notebook also read.
df_logo_agg = df_total.groupby(
    ['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False
).agg({'Enrichment_Ratio': 'sum'})  # or 'mean' or 'max'

df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]
df_logo_agg = df_logo_agg.replace([np.inf, -np.inf], np.nan)
df_logo_agg = df_logo_agg.dropna(subset=['Enrichment_Ratio'])

# Only positions > 365 are kept; there is no sign/threshold filter here —
# ratios <= 1e-8 (including 0) are clipped by enrichment_height to a height of -8.
# .copy() so the column assignments below act on an independent frame.
df_filtered = df_logo_agg[df_logo_agg['Spike_AS_Position'] > 365].copy()

df_filtered['letter_height'] = enrichment_height(df_filtered['Enrichment_Ratio'])
df_filtered = df_filtered.dropna(subset=['letter_height'])

# Fixed color per amino acid letter
aa_list = sorted(df_filtered['Amino_Acid'].unique())
colors_list = plt.cm.tab20.colors
aa_colors = {aa: colors_list[i % len(colors_list)] for i, aa in enumerate(aa_list)}
df_filtered['color'] = df_filtered['Amino_Acid'].map(aa_colors)

output_dir = "barcode_logoplots_panels"
os.makedirs(output_dir, exist_ok=True)

def log_ticks(ax, base=1):
    """Major y ticks every `base` log10 units, labelled as powers of ten (10^y)."""
    ax.yaxis.set_major_locator(ticker.MultipleLocator(base))
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda y, _: f"$10^{{{int(y)}}}$"))

legend_elements = [Patch(facecolor=col, label=aa) for aa, col in aa_colors.items()]

# Dedicated seeded generator: reproducible and independent of other `random` use
barcode_sampler = random.Random(PANEL_RANDOM_SEED)

for immunization in df_filtered['immunization'].unique():
    df_imm = df_filtered[df_filtered['immunization'] == immunization]
    barcodes_all = df_imm['barcode'].unique()

    if len(barcodes_all) == 0:
        continue

    barcodes = barcode_sampler.sample(list(barcodes_all), min(N_PANEL_BARCODES, len(barcodes_all)))

    print(f"Generating panel for immunization: {immunization} with barcodes: {barcodes}")

    # squeeze=False always yields a 2-D array, so no special case for one barcode
    fig, axs = plt.subplots(
        len(barcodes), 1, figsize=(45, 4 * len(barcodes)),
        gridspec_kw={'hspace': 0.5}, squeeze=False
    )
    axs = axs[:, 0]

    for i, barcode in enumerate(barcodes):
        df_barcode = df_imm[(df_imm['barcode'] == barcode) & (df_imm['letter_height'].notna())]

        if df_barcode.empty:
            axs[i].set_visible(False)  # don't leave an unlabeled blank panel
            continue

        dmslogo.draw_logo(
            df_barcode,
            x_col="Spike_AS_Position",
            letter_col="Amino_Acid",
            letter_height_col="letter_height",
            color_col="color",
            title=f"{immunization} - {barcode} logoplot (log10 Enrichment Ratio)",
            addbreaks=True,
            ax=axs[i]
        )

        axs[i].set_ylabel('Enrichment Ratio (log10)')
        axs[i].set_ylim(-10, 6)
        # Tick every 2 log units (previously a MultipleLocator(2) here was silently
        # replaced by the helper's 1-unit locator)
        log_ticks(axs[i], base=2)

    fig.legend(handles=legend_elements, title='Amino Acid', loc='center right', borderaxespad=0.1)

    # Right 15% of the figure is reserved for the legend
    fig.tight_layout(rect=[0, 0, 0.85, 1])
    safe_immunization = str(immunization).replace("/", "_").replace("\\", "_")
    plot_filename = os.path.join(output_dir, f"{safe_immunization}_top{N_PANEL_BARCODES}_random_logoplots.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    plt.close(fig)
    print(f"Panel saved as {plot_filename}")


In [None]:
import inspect

import scipy
from scipy.stats import binomtest

# Record the SciPy version used for the binomial tests below, and show only the
# call signature of binomtest instead of dumping the full help() text into the
# notebook output.
print(scipy.__version__)
print(f"binomtest{inspect.signature(binomtest)}")



In [None]:
import pandas as pd
import numpy as np
from scipy.stats import binomtest
from statsmodels.stats.multitest import multipletests
import matplotlib.pyplot as plt
import seaborn as sns

# Reproducibility test: for each immunization and each (position, amino acid),
# count in how many distinct barcodes the variant occurs and test with a
# one-sided binomial test whether that count exceeds a null frequency P0.

# Null-hypothesis frequency for the binomial test (adjust as appropriate)
P0_NULL_FREQ = 0.05

# Start with df_total containing: ['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode']
# Filter out stop codons and positions <= 365
df_filtered = df_total[(df_total['Amino_Acid'] != '*') & (df_total['Spike_AS_Position'] > 365)]

results = []

for imm in df_filtered['immunization'].unique():
    df_imm = df_filtered[df_filtered['immunization'] == imm]
    n_barcodes = len(df_imm['barcode'].unique())

    # Distinct barcodes carrying each (position, amino acid). Integer counts are
    # passed to binomtest directly: recovering them as int(freq * n) can truncate
    # after float round-off (e.g. (1/49) * 49 == 0.999... -> 0).
    counts = df_imm.groupby(['Spike_AS_Position', 'Amino_Acid'])['barcode'].nunique()

    pvals = [
        binomtest(int(k), n_barcodes, p=P0_NULL_FREQ, alternative='greater').pvalue
        for k in counts.to_numpy()
    ]

    results.append(pd.DataFrame({
        'immunization': imm,
        'Spike_AS_Position': counts.index.get_level_values('Spike_AS_Position'),
        'Amino_Acid': counts.index.get_level_values('Amino_Acid'),
        'frequency': (counts / n_barcodes).to_numpy(),
        'p_value': pvals,
    }))

# Combine all immunizations (fresh 0..N-1 index instead of repeated per-group labels)
df_reproducibility = pd.concat(results, ignore_index=True)

# Benjamini-Hochberg FDR correction across all immunizations at once
df_reproducibility['p_adj'] = multipletests(df_reproducibility['p_value'], method='fdr_bh')[1]

# Attach adjusted p-values and frequencies to every row of df_filtered
df_plot = pd.merge(
    df_filtered,
    df_reproducibility,
    on=['immunization', 'Spike_AS_Position', 'Amino_Acid'],
    how='left'
)

# -log10 adjusted p-values; exact zeros replaced by the smallest positive double
df_plot['log_p_adj'] = -np.log10(df_plot['p_adj'].replace(0, np.nextafter(0, 1)))

# The plotted values depend only on (immunization, position, AA), so one point per
# key draws the same picture without overplotting thousands of identical markers.
plot_points = df_plot.drop_duplicates(subset=['immunization', 'Spike_AS_Position', 'Amino_Acid'])
aa_hue_order = sorted(plot_points['Amino_Acid'].dropna().unique())  # same colors in every facet

sns.set(style="whitegrid")
g = sns.FacetGrid(plot_points, col="immunization", col_wrap=2, height=5, sharey=False)
g.map_dataframe(
    sns.scatterplot,
    x='Spike_AS_Position',
    y='log_p_adj',
    hue='Amino_Acid',
    hue_order=aa_hue_order,
    palette='tab20',
    edgecolor='black',
    linewidth=0.5,
    s=70
)
g.set_axis_labels("Spike Amino Acid Position", "-log10 Adjusted p-value")
g.add_legend(title='Amino Acid')
for ax in g.axes.flat:
    ax.axhline(-np.log10(0.05), color='red', linestyle='--')  # FDR 0.05 threshold
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

sns.set(style="whitegrid")

# One line per immunization: at each Spike position, the smallest adjusted
# p-value across all amino acids there (i.e. the strongest signal per site).
fig, ax = plt.subplots(figsize=(14, 7))

immunizations = df_reproducibility['immunization'].unique()

for imm in immunizations:
    is_this_imm = df_reproducibility['immunization'] == imm
    best_p_per_site = df_reproducibility.loc[is_this_imm].groupby('Spike_AS_Position')['p_adj'].min()
    ax.plot(best_p_per_site.index, best_p_per_site.values, marker='o', linestyle='-', label=imm)

ax.axhline(0.05, color='red', linestyle='--', label='Significance threshold (0.05)')
ax.set_yscale('log')
ax.set_xlabel('Spike Amino Acid Position')
ax.set_ylabel('Adjusted p-value (log scale)')
ax.set_title('Adjusted p-values by Spike Position for each Immunization')
ax.legend(title='Immunization')
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

sns.set(style="whitegrid")

# Smallest positive double: replaces p_adj == 0 so -log10 stays finite
# (otherwise +inf cannot be plotted and falls outside every bin below).
P_FLOOR = np.nextafter(0, 1)

# One scatter per immunization: -log10(p_adj) per position, colored by amino acid
for imm in df_reproducibility['immunization'].unique():
    df_imm = df_reproducibility[df_reproducibility['immunization'] == imm].copy()
    df_imm['log_p_adj'] = -np.log10(df_imm['p_adj'].replace(0, P_FLOOR))

    plt.figure(figsize=(16, 6))
    sns.scatterplot(
        data=df_imm,
        x='Spike_AS_Position',
        y='log_p_adj',
        hue='Amino_Acid',
        palette='tab20',
        edgecolor='black',
        linewidth=0.5,
        s=80
    )

    # Threshold line at p = 0.05 (adjusted)
    plt.axhline(-np.log10(0.05), color='red', linestyle='--', label='FDR = 0.05')

    plt.title(f'-log10 Adjusted p-values by Position and Amino Acid: {imm}')
    plt.xlabel('Spike Amino Acid Position')
    plt.ylabel('-log10(p_adj)')
    plt.legend(title='Amino Acid', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

# Bins of -log10(p_adj): 0-1, 1-2, ..., 9-10, plus an open-ended "10+" bin so very
# small p-values are counted instead of silently falling outside every bin.
bins = np.append(np.arange(0, 11, 1), np.inf)
bin_labels = [f"{i}-{i+1}" for i in range(10)] + ["10+"]

for imm in df_reproducibility['immunization'].unique():
    df_imm = df_reproducibility[df_reproducibility['immunization'] == imm].copy()
    df_imm['log_p_adj'] = -np.log10(df_imm['p_adj'].replace(0, P_FLOOR))

    df_imm['log_p_bin'] = pd.cut(df_imm['log_p_adj'], bins=bins, labels=bin_labels, right=False)

    bin_counts = df_imm['log_p_bin'].value_counts(sort=False).reindex(bin_labels, fill_value=0).astype(int)
    bin_percent = (bin_counts / len(df_imm) * 100).round(2)
    n_unbinned = int(df_imm['log_p_bin'].isna().sum())  # only rows with missing p_adj

    print(f"\nImmunization: {imm}")
    print("Bin Range | Count | Percent")
    for b in bin_labels:
        print(f"{b:>8} | {bin_counts[b]:>5} | {bin_percent[b]:>6}%")
    if n_unbinned:
        print(f"(not binned, p_adj missing: {n_unbinned})")



In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

sns.set(style="whitegrid")

# Smallest positive double: replaces p_adj == 0 so -log10 stays finite
# (otherwise +inf cannot be plotted and falls outside every bin below).
P_FLOOR = np.nextafter(0, 1)

# One scatter per immunization: -log10(p_adj) per position, colored by amino acid
for imm in df_reproducibility['immunization'].unique():
    df_imm = df_reproducibility[df_reproducibility['immunization'] == imm].copy()
    df_imm['log_p_adj'] = -np.log10(df_imm['p_adj'].replace(0, P_FLOOR))

    plt.figure(figsize=(16, 6))
    sns.scatterplot(
        data=df_imm,
        x='Spike_AS_Position',
        y='log_p_adj',
        hue='Amino_Acid',
        palette='tab20',
        edgecolor='black',
        linewidth=0.5,
        s=80
    )

    # Threshold line at p = 0.05 (adjusted)
    plt.axhline(-np.log10(0.05), color='red', linestyle='--', label='FDR = 0.05')

    plt.title(f'-log10 Adjusted p-values by Position and Amino Acid: {imm}')
    plt.xlabel('Spike Amino Acid Position')
    plt.ylabel('-log10(p_adj)')
    plt.legend(title='Amino Acid', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()

# Bins of -log10(p_adj): 0-1, ..., 9-10, plus an open-ended "10+" bin so every
# plotted point lands in exactly one bin and the percentages sum to 100.
bins = np.append(np.arange(0, 11, 1), np.inf)
bin_labels = [f"{i}-{i+1}" for i in range(10)] + ["10+"]

for imm in df_reproducibility['immunization'].unique():
    df_imm = df_reproducibility[df_reproducibility['immunization'] == imm].copy()
    df_imm['log_p_adj'] = -np.log10(df_imm['p_adj'].replace(0, P_FLOOR))

    # Rows with NaN would not be plotted; .copy() so the bin column below is
    # written to an independent frame (no SettingWithCopyWarning).
    df_imm_valid = df_imm.dropna(subset=['log_p_adj']).copy()

    df_imm_valid['log_p_bin'] = pd.cut(df_imm_valid['log_p_adj'], bins=bins, labels=bin_labels, right=False)

    bin_counts = df_imm_valid['log_p_bin'].value_counts(sort=False).reindex(bin_labels, fill_value=0).astype(int)
    bin_percent = (bin_counts / len(df_imm_valid) * 100).round(2)

    print(f"\nImmunization: {imm}")
    print("Bin Range | Count | Percent")
    for b in bin_labels:
        print(f"{b:>8} | {bin_counts[b]:>5} | {bin_percent[b]:>6}%")

    print(f"Total points plotted: {len(df_imm_valid)} — Sum of %: {bin_percent.sum()}%")


In [None]:
# Quick look at the structure of the reproducibility results table
print("Columns in df_reproducibility:")
print(list(df_reproducibility.columns))

print("\nSample data (first 5 rows):")
print(df_reproducibility.head(5))


In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# Bar chart per immunization: bootstrap mean frequency (freq_mean) per Spike
# position, dodged by amino acid, with freq_se error bars on each bar.
# Expects bootstrap_results with columns immunization, Spike_AS_Position,
# Amino_Acid, freq_mean, freq_se (as used below). Library_ctrl is excluded.
df_bootstrap = pd.DataFrame(bootstrap_results)
df_bootstrap['Pos'] = df_bootstrap['Spike_AS_Position'].astype(str)
df_bootstrap = df_bootstrap[df_bootstrap['immunization'] != 'Library_ctrl']

# Shared category orders: every facet gets the same x slots and the same amino
# acid colors, so the x ticks and error-bar offsets below line up in all facets.
pos_lookup = (
    df_bootstrap[['Spike_AS_Position', 'Pos']]
    .drop_duplicates()
    .sort_values('Spike_AS_Position')
)
pos_order = pos_lookup['Pos'].tolist()
pos_numeric = pos_lookup['Spike_AS_Position'].tolist()
aa_order = sorted(df_bootstrap['Amino_Acid'].dropna().unique())
BAR_GROUP_WIDTH = 0.8  # seaborn barplot default `width` for one x slot


def plot_with_errorbars(data, color, **kwargs):
    """Draw dodged bars per (position, amino acid) with freq_se error bars on each bar.

    `color` is supplied by FacetGrid.map_dataframe and ignored (hue sets colors).
    """
    ax = plt.gca()
    data = data[data['Amino_Acid'].isin(aa_order)].sort_values(by=['Spike_AS_Position', 'Amino_Acid'])

    sns.barplot(
        data=data,
        x='Pos',
        y='freq_mean',
        hue='Amino_Acid',
        order=pos_order,
        hue_order=aa_order,
        errorbar=None,
        palette='tab20',
        ax=ax,
        **kwargs
    )

    # With hue dodging, each amino acid gets BAR_GROUP_WIDTH / n_levels of the slot;
    # its bar centre is the slot index plus that hue level's offset. (Previously all
    # error bars were drawn at the slot centre, i.e. not on their bars.)
    each_width = BAR_GROUP_WIDTH / len(aa_order)
    pos_index = {p: i for i, p in enumerate(pos_order)}
    aa_index = {a: j for j, a in enumerate(aa_order)}
    bar_centres = [
        pos_index[p] - BAR_GROUP_WIDTH / 2 + each_width * (aa_index[a] + 0.5)
        for p, a in zip(data['Pos'], data['Amino_Acid'])
    ]
    ax.errorbar(
        x=bar_centres,
        y=data['freq_mean'],
        yerr=data['freq_se'],
        fmt='none',
        ecolor='gray',
        elinewidth=1,
        capsize=2
    )


sns.set(style="whitegrid")
g = sns.FacetGrid(
    df_bootstrap,
    col="immunization",
    col_wrap=3,
    height=5,
    sharey=False
)
g.map_dataframe(plot_with_errorbars)

g.set_axis_labels("Spike Amino Acid Position", "Mean Frequency")
g.set_titles(col_template="{col_name}")

# Label every 10th Spike position; indices refer to the shared pos_order slots
tick_positions = [i for i, p in enumerate(pos_numeric) if int(p) % 10 == 0]
tick_labels = [str(int(pos_numeric[i])) for i in tick_positions]
for ax in g.axes.flat:
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.tick_params(axis='x', rotation=90)

g.add_legend(title='Amino Acid', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# One bar chart per (immunization, enrichment bin), stacked vertically: bootstrap
# mean frequency per position, dodged by amino acid, with freq_se error bars.

BAR_GROUP_WIDTH = 0.8  # seaborn barplot default `width` for one x slot

# Sort immunizations and bins for consistent order
immunizations = sorted(df_bootstrap['immunization'].unique())
enrich_bins = sorted(df_bootstrap['enrichment_bin'].unique(), key=lambda x: enrichment_labels.index(x))

# One hue order for ALL subplots, so the single legend is valid for every panel
aa_hue_order = sorted(df_bootstrap['Amino_Acid'].dropna().unique())
each_width = BAR_GROUP_WIDTH / len(aa_hue_order)
aa_index = {a: j for j, a in enumerate(aa_hue_order)}

# Figure height: 5 inches per subplot; wide enough for the notebook
plot_height_per = 5
total_plots = len(immunizations) * len(enrich_bins)
fig_height = plot_height_per * total_plots
fig_width = 16

fig, axes = plt.subplots(total_plots, 1, figsize=(fig_width, fig_height), squeeze=False)
axes = axes.flatten()

legend_shown = False
plot_i = 0
for imm in immunizations:
    for enrich_bin in enrich_bins:
        ax = axes[plot_i]
        plot_i += 1
        data = df_bootstrap[
            (df_bootstrap['immunization'] == imm) &
            (df_bootstrap['enrichment_bin'] == enrich_bin) &
            (df_bootstrap['Amino_Acid'].isin(aa_hue_order))
        ]

        if data.empty:
            ax.set_visible(False)
            continue

        data = data.sort_values(['Spike_AS_Position', 'Amino_Acid'])
        pos_order = data['Pos'].unique().tolist()  # positions present, in numeric order

        # errorbar=None: the manual freq_se bars below are the only error bars
        sns.barplot(
            data=data,
            x='Pos',
            y='freq_mean',
            hue='Amino_Acid',
            order=pos_order,
            hue_order=aa_hue_order,
            errorbar=None,
            palette='tab20',
            ax=ax
        )

        # Error bar at the centre of each dodged bar (slot index + hue offset);
        # previously they were all drawn at the slot centre.
        pos_index = {p: i for i, p in enumerate(pos_order)}
        bar_centres = [
            pos_index[p] - BAR_GROUP_WIDTH / 2 + each_width * (aa_index[a] + 0.5)
            for p, a in zip(data['Pos'], data['Amino_Acid'])
        ]
        ax.errorbar(
            x=bar_centres,
            y=data['freq_mean'],
            yerr=data['freq_se'],
            fmt='none',
            ecolor='gray',
            elinewidth=1,
            capsize=2
        )

        ax.set_title(f"Immunization: {imm} | Enrichment Bin: {enrich_bin}", fontsize=14)
        ax.set_xlabel("Spike Amino Acid Position")
        ax.set_ylabel("Mean Frequency")
        ax.tick_params(axis='x', rotation=90)

        # Single legend on the first subplot actually drawn (even if earlier
        # combinations were empty); remove it everywhere else to avoid clutter.
        if not legend_shown:
            ax.legend(title='Amino Acid', bbox_to_anchor=(1.05, 1), loc='upper left')
            legend_shown = True
        elif ax.get_legend() is not None:
            ax.get_legend().remove()

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Shared y-axis limit: tallest bar plus its error bar, with 5% padding, so every panel
# is directly comparable and no error bar is clipped.
y_max = (df_bootstrap['freq_mean'] + df_bootstrap['freq_se'].fillna(0)).max() * 1.05

immunizations = sorted(df_bootstrap['immunization'].unique())
enrich_bins = sorted(df_bootstrap['enrichment_bin'].unique(), key=lambda x: enrichment_labels.index(x))

# One global amino-acid order for every subplot. seaborn assigns palette colours by hue
# order, so a per-subplot order would give different amino acids the same colour in
# different panels and the single legend would only be valid for one panel.
aa_order = sorted(df_bootstrap['Amino_Acid'].dropna().unique())
aa_index = {aa: k for k, aa in enumerate(aa_order)}
group_width = 0.8  # seaborn's default total width of one dodged bar group
bar_width = group_width / max(len(aa_order), 1)

plot_height_per = 5
total_plots = len(immunizations) * len(enrich_bins)
fig_height = plot_height_per * total_plots
fig_width = 16

fig, axes = plt.subplots(total_plots, 1, figsize=(fig_width, fig_height), squeeze=False)
axes = axes.flatten()

legend_drawn = False
plot_i = 0
for imm in immunizations:
    for enrich_bin in enrich_bins:
        ax = axes[plot_i]
        data = df_bootstrap[
            (df_bootstrap['immunization'] == imm) &
            (df_bootstrap['enrichment_bin'] == enrich_bin)
        ]

        if data.empty:
            ax.set_visible(False)
            plot_i += 1
            continue

        # Sort by Spike position so the x categories appear in sequence order
        data = data.sort_values(['Spike_AS_Position', 'Amino_Acid'])
        pos_order = list(data['Pos'].unique())
        pos_index = {p: k for k, p in enumerate(pos_order)}

        sns.barplot(
            data=data,
            x='Pos',
            y='freq_mean',
            hue='Amino_Acid',
            order=pos_order,
            hue_order=aa_order,
            palette='tab20',
            ax=ax
        )

        # Error bars centred on each dodged bar: category index plus the hue offset
        # seaborn applies when dodging (centre of slot k minus half the group width).
        bar_x = [
            pos_index[p] + (aa_index[a] + 0.5) * bar_width - group_width / 2
            for p, a in zip(data['Pos'], data['Amino_Acid'])
        ]
        ax.errorbar(
            x=bar_x,
            y=data['freq_mean'].to_numpy(),
            yerr=data['freq_se'].to_numpy(),
            fmt='none',
            ecolor='gray',
            elinewidth=1,
            capsize=2
        )

        # Fix y axis limits (same for all)
        ax.set_ylim(0, y_max)
        # Remove gaps on x axis (tight margins)
        ax.margins(x=0)

        ax.set_title(f"Immunization: {imm} | Enrichment Bin: {enrich_bin}", fontsize=14)
        ax.set_xlabel("Spike Amino Acid Position")
        ax.set_ylabel("Bootstrap Mean Frequency (AA Mutation per Binding-Ratio bin)")
        ax.tick_params(axis='x', rotation=90)

        # Legend only on the first drawn panel (valid for all: hue order/colours are shared)
        if not legend_drawn:
            ax.legend(title='Amino Acid', bbox_to_anchor=(1.05, 1), loc='upper left')
            legend_drawn = True
        elif ax.get_legend() is not None:
            ax.get_legend().remove()

        plot_i += 1

plt.tight_layout()
plt.show()


In [None]:
import numpy as np

# Number of synthetic samples per data point for violin shape
n_synthetic = 100
synth_rng = np.random.default_rng(0)  # seeded so the violin shapes are reproducible

# NOTE: reuses immunizations, enrich_bins, total_plots, fig_width, fig_height and y_max
# defined by the previous plotting cell.

# One global amino-acid order so colours (and the single legend) agree across panels.
aa_order = sorted(df_bootstrap['Amino_Acid'].dropna().unique())
aa_index = {aa: k for k, aa in enumerate(aa_order)}
group_width = 0.8  # seaborn's default total width of one dodged group (bars and violins)
bar_width = group_width / max(len(aa_order), 1)

fig, axes = plt.subplots(total_plots, 1, figsize=(fig_width, fig_height), squeeze=False)
axes = axes.flatten()

legend_drawn = False
plot_i = 0
for imm in immunizations:
    for enrich_bin in enrich_bins:
        ax = axes[plot_i]
        data = df_bootstrap[
            (df_bootstrap['immunization'] == imm) &
            (df_bootstrap['enrichment_bin'] == enrich_bin)
        ]

        if data.empty:
            ax.set_visible(False)
            plot_i += 1
            continue

        data = data.sort_values(['Spike_AS_Position', 'Amino_Acid'])
        pos_order = list(data['Pos'].unique())
        pos_index = {p: k for k, p in enumerate(pos_order)}

        # Synthetic samples for a violin-like shape: n_synthetic draws per row from
        # Normal(freq_mean, freq_se), clipped at 0 since a frequency can't be negative.
        samples = synth_rng.normal(
            loc=data['freq_mean'].to_numpy(dtype=float)[:, None],
            scale=data['freq_se'].to_numpy(dtype=float)[:, None],
            size=(len(data), n_synthetic),
        )
        samples = np.clip(samples, a_min=0, a_max=None)
        df_synth = pd.DataFrame({
            'Pos': np.repeat(data['Pos'].to_numpy(), n_synthetic),
            'Amino_Acid': np.repeat(data['Amino_Acid'].to_numpy(), n_synthetic),
            'freq': samples.ravel(),
        })

        # Plot violin behind bars (`scale` is named `density_norm` in seaborn >= 0.13)
        sns.violinplot(
            data=df_synth,
            x='Pos',
            y='freq',
            hue='Amino_Acid',
            order=pos_order,
            hue_order=aa_order,
            palette='tab20',
            cut=0,
            scale='width',
            inner=None,
            linewidth=0,
            ax=ax,
            dodge=True
        )

        # Then plot the bars on top, dodged into the same slots as the violins
        sns.barplot(
            data=data,
            x='Pos',
            y='freq_mean',
            hue='Amino_Acid',
            order=pos_order,
            hue_order=aa_order,
            palette='tab20',
            ax=ax,
            edgecolor='black',
            alpha=0.8,
            dodge=True
        )

        # Error bars centred on each dodged bar (category index + seaborn hue offset)
        bar_x = [
            pos_index[p] + (aa_index[a] + 0.5) * bar_width - group_width / 2
            for p, a in zip(data['Pos'], data['Amino_Acid'])
        ]
        ax.errorbar(
            x=bar_x,
            y=data['freq_mean'].to_numpy(),
            yerr=data['freq_se'].to_numpy(),
            fmt='none',
            ecolor='gray',
            elinewidth=1,
            capsize=2
        )

        # Fix y axis limits (same for all)
        ax.set_ylim(0, y_max)
        ax.margins(x=0)

        ax.set_title(f"Immunization: {imm} | Enrichment Bin: {enrich_bin}", fontsize=14)
        ax.set_xlabel("Spike Amino Acid Position")
        ax.set_ylabel("Mean Frequency")
        ax.tick_params(axis='x', rotation=90)

        # Legend only on the first drawn panel. The violin and bar layers each add one
        # entry per amino acid, so keep the first handle per label.
        if not legend_drawn:
            handles, labels = ax.get_legend_handles_labels()
            unique_entries = {}
            for handle, label in zip(handles, labels):
                unique_entries.setdefault(label, handle)
            ax.legend(list(unique_entries.values()), list(unique_entries.keys()),
                      title='Amino Acid', bbox_to_anchor=(1.05, 1), loc='upper left')
            legend_drawn = True
        elif ax.get_legend() is not None:
            ax.get_legend().remove()

        plot_i += 1

plt.tight_layout()
plt.show()


In [None]:
# Here we compute bootstrap samples and, at the same time, p-values to test whether sampling a given amino-acid mutation
# within a given enrichment bin (0-1, 1-10, 10-100, 100-500, >500) is purely random (by chance), or expected because
# the mutation is well represented in the repertoire.

In [None]:
# Compute p-values by comparing observed vs bootstrapped frequencies.
# For each row of df_bootstrap (immunization, enrichment bin, Spike position, amino acid):
#   obs_freq = (# distinct barcodes in that immunization/bin carrying pos+aa in df_filtered)
#              / (# distinct barcodes in that immunization/bin)
# and p = fraction of the matching df_bootstrap 'freq_mean' values >= obs_freq.
# NOTE(review): df_bootstrap only stores summary statistics, so `freq_array` below is the
# freq_mean of this same row (presumably exactly one row per immunization/bin/pos/aa — TODO
# confirm). The "p-value" is then 1.0 when freq_mean >= obs_freq and 0.0 otherwise; an
# empirical p-value needs the per-iteration bootstrap frequencies.
# NOTE(review): both filters run inside iterrows, i.e. O(rows^2) on df_bootstrap.
# NOTE(review): loop variables (imm, enrich_bin, barcodes, df_bin, ...) stay defined after
# this cell, and later cells in this notebook read `barcodes`, `df_bin`, `imm` and `enrich_bin`.
p_values = []

# Walk every bootstrap row and look up its observed frequency in df_filtered
for i, row in df_bootstrap.iterrows():
    imm = row['immunization']
    enrich_bin = row['enrichment_bin']
    pos = row['Spike_AS_Position']
    aa = row['Amino_Acid']

    # Get all barcodes for this bin
    df_imm = df_filtered[df_filtered['immunization'] == imm]
    df_bin = df_imm[df_imm['enrichment_bin'] == enrich_bin]
    barcodes = df_bin['barcode'].unique()

    # No barcodes in this immunization/bin: frequency undefined
    if len(barcodes) == 0:
        p_values.append(np.nan)
        continue

    # Observed frequency
    total_count = df_bin[
        (df_bin['Spike_AS_Position'] == pos) &
        (df_bin['Amino_Acid'] == aa)
    ]['barcode'].nunique()

    obs_freq = total_count / len(barcodes)

    # Get bootstrap freq distribution (see NOTE above: this is the row's own freq_mean)
    freq_array = df_bootstrap[
        (df_bootstrap['immunization'] == imm) &
        (df_bootstrap['enrichment_bin'] == enrich_bin) &
        (df_bootstrap['Spike_AS_Position'] == pos) &
        (df_bootstrap['Amino_Acid'] == aa)
    ]['freq_mean'].values

    # p-value: proportion of bootstrap freq >= observed freq
    p_val = np.mean(freq_array >= obs_freq)
    p_values.append(p_val)

df_bootstrap['p_value'] = p_values
# p == 0 is replaced by 1e-10 so -log10 gives 10 instead of +inf; NaN p-values stay NaN
df_bootstrap['log10_p'] = -np.log10(df_bootstrap['p_value'].replace(0, 1e-10))  # avoid -inf


In [None]:
# Code to produce heatmaps of the bootstrap mean frequencies and p-values for each amino acid

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns


def _show_aa_heatmap(pivot, cmap, cbar_label, title):
    """Draw one amino-acid x Spike-position heatmap in its own 20x5 figure and show it."""
    plt.figure(figsize=(20, 5))
    sns.heatmap(pivot, cmap=cmap, cbar_kws={'label': cbar_label}, linewidths=0.1)
    plt.title(title)
    plt.xlabel('Spike Position')
    plt.ylabel('Amino Acid')
    plt.tight_layout()
    plt.show()


# For each immunization x enrichment bin with data: first a heatmap of the mean bootstrap
# frequency, then one of -log10(p-value) (rows = amino acid, columns = Spike position).
for imm in df_bootstrap['immunization'].unique():
    for enrich_bin in enrichment_labels:
        df_plot = df_bootstrap.loc[
            (df_bootstrap['immunization'] == imm)
            & (df_bootstrap['enrichment_bin'] == enrich_bin)
        ]
        if len(df_plot) == 0:
            continue

        pivot_freq = df_plot.pivot_table(index='Amino_Acid', columns='Spike_AS_Position', values='freq_mean')
        _show_aa_heatmap(pivot_freq, 'magma', 'Mean AA Frequency',
                         f'{imm} | Bin: {enrich_bin} — Mean Amino Acid Frequency')

        pivot_logp = df_plot.pivot_table(index='Amino_Acid', columns='Spike_AS_Position', values='log10_p')
        _show_aa_heatmap(pivot_logp, 'viridis', '-log10(p-value)',
                         f'{imm} | Bin: {enrich_bin} — Significance of AA Enrichment')


In [None]:
from collections import defaultdict
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

# Parameters
n_iter = 1000
enrichment_bins = [0, 1, 10, 100, 500, np.inf]
enrichment_labels = ['0–1', '1–10', '10–100', '100–500', '>500']
null_rng = np.random.default_rng(42)  # seeded so the empirical p-values are reproducible

# Drop stop codons and keep Spike positions > 365
df_filtered = df_total[
    (df_total['Amino_Acid'] != '*') & (df_total['Spike_AS_Position'] > 365)
].copy()

# Bin ERs (left-closed intervals: [0, 1), [1, 10), ...)
df_filtered['enrichment_bin'] = pd.cut(
    df_filtered['Enrichment_Ratio'],
    bins=enrichment_bins,
    labels=enrichment_labels,
    right=False
)

n_bins = len(enrichment_labels)
# Integer bin code per row, positionally aligned with df_filtered. Rows in no bin
# (pd.cut -> NaN, code -1) get the extra code n_bins: they can be drawn from a pool
# but are never counted in any bin (same as value_counts dropping NaN).
bin_codes = df_filtered['enrichment_bin'].cat.codes.to_numpy()
bin_codes = np.where(bin_codes < 0, n_bins, bin_codes)


def _bin_counts(codes_2d):
    """Count bin codes per row of an (n_rows, n_draw) array -> (n_rows, n_bins) counts."""
    n_rows = codes_2d.shape[0]
    # Shift each row into its own block of n_bins + 1 codes so one bincount counts all rows
    offsets = (np.arange(n_rows) * (n_bins + 1))[:, None]
    counts = np.bincount((codes_2d + offsets).ravel(), minlength=n_rows * (n_bins + 1))
    return counts.reshape(n_rows, n_bins + 1)[:, :n_bins]


# Pool per (immunization, position): bin codes of every variant observed there
pool_codes = {
    key: bin_codes[idx]
    for key, idx in df_filtered.groupby(['immunization', 'Spike_AS_Position']).indices.items()
}

# Null distributions keyed by (immunization, position, sample size). Each holds the bin
# counts of n_iter random draws (with replacement) of `sample size` variants from the
# position's pool. The sample size must equal the number of variants observed for the
# amino acid being tested: resampling the whole pool instead gives null counts that scale
# with the pool and are almost never exceeded by a single amino acid (p ~ 1 everywhere).
null_distributions = {}
results = []

obs_groups = df_filtered.groupby(['immunization', 'Spike_AS_Position', 'Amino_Acid']).indices
for (imm, pos, aa), idx in tqdm(obs_groups.items(), desc="Testing enrichment"):
    pool = pool_codes.get((imm, pos))
    if pool is None or len(pool) < 2:
        continue

    n_draw = len(idx)
    null_key = (imm, pos, n_draw)
    if null_key not in null_distributions:
        draws = null_rng.choice(pool, size=(n_iter, n_draw), replace=True)
        null_distributions[null_key] = _bin_counts(draws)
    null_counts = null_distributions[null_key]

    obs_counts = _bin_counts(bin_codes[idx][None, :])[0]

    for i, label in enumerate(enrichment_labels):
        observed = obs_counts[i]
        # One-sided empirical p-value: share of null draws with at least `observed` variants in this bin
        p_val = np.mean(null_counts[:, i] >= observed)
        log10_p = -np.log10(p_val) if p_val > 0 else 10  # capped at 10 when no null draw reaches `observed`

        results.append({
            'immunization': imm,
            'Spike_AS_Position': pos,
            'Amino_Acid': aa,
            'enrichment_bin': label,
            'observed_count': observed,
            'p_value': p_val,
            'log10_p': log10_p
        })

df_results = pd.DataFrame(results)


In [None]:
from collections import defaultdict
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

# Parameters
n_iter = 1000
chunk_size = 500  # bootstrap iterations drawn per batch; bounds the temporary draw array (tune to memory)
enrichment_bins = [0, 1, 10, 100, 500, np.inf]
enrichment_labels = ['0–1', '1–10', '10–100', '100–500', '>500']
null_rng = np.random.default_rng(42)  # seeded so the empirical p-values are reproducible

# Drop stop codons and keep Spike positions > 365
df_filtered = df_total[
    (df_total['Amino_Acid'] != '*') & (df_total['Spike_AS_Position'] > 365)
].copy()

# Bin ERs (left-closed intervals: [0, 1), [1, 10), ...)
df_filtered['enrichment_bin'] = pd.cut(
    df_filtered['Enrichment_Ratio'],
    bins=enrichment_bins,
    labels=enrichment_labels,
    right=False
)

n_bins = len(enrichment_labels)
# Integer bin code per row, positionally aligned with df_filtered. Rows in no bin
# (pd.cut -> NaN, code -1) get the extra code n_bins: drawable but never counted.
bin_codes = df_filtered['enrichment_bin'].cat.codes.to_numpy()
bin_codes = np.where(bin_codes < 0, n_bins, bin_codes)


def _bin_counts(codes_2d):
    """Count bin codes per row of an (n_rows, n_draw) array -> (n_rows, n_bins) counts."""
    n_rows = codes_2d.shape[0]
    # Shift each row into its own block of n_bins + 1 codes so one bincount counts all rows
    offsets = (np.arange(n_rows) * (n_bins + 1))[:, None]
    counts = np.bincount((codes_2d + offsets).ravel(), minlength=n_rows * (n_bins + 1))
    return counts.reshape(n_rows, n_bins + 1)[:, :n_bins]


# Step 1: pool per (immunization, position) = bin codes of every variant observed there.
# groupby(...).indices gives positional row indices, avoiding a boolean scan per group.
pool_codes = {
    key: bin_codes[idx]
    for key, idx in df_filtered.groupby(['immunization', 'Spike_AS_Position']).indices.items()
}

# Step 2/3: null distributions keyed by (immunization, position, sample size), built lazily.
# Each null draws exactly as many variants as were observed for the tested amino acid;
# resampling the whole pool would give null counts that scale with the pool size and are
# almost never exceeded by a single amino acid (p ~ 1 everywhere).
null_distributions = {}
results = []

obs_groups = df_filtered.groupby(['immunization', 'Spike_AS_Position', 'Amino_Acid']).indices
for (imm, pos, aa), idx in tqdm(obs_groups.items(), desc='Testing observed groups'):
    pool = pool_codes.get((imm, pos))
    if pool is None or len(pool) < 2:
        continue

    n_draw = len(idx)
    null_key = (imm, pos, n_draw)
    if null_key not in null_distributions:
        counts_array = np.empty((n_iter, n_bins), dtype=np.int64)
        for start in range(0, n_iter, chunk_size):
            stop = min(start + chunk_size, n_iter)
            draws = null_rng.choice(pool, size=(stop - start, n_draw), replace=True)
            counts_array[start:stop] = _bin_counts(draws)
        null_distributions[null_key] = counts_array
    null_counts = null_distributions[null_key]

    obs_counts = _bin_counts(bin_codes[idx][None, :])[0]

    for k, label in enumerate(enrichment_labels):
        observed = obs_counts[k]
        # One-sided empirical p-value: share of null draws with at least `observed` variants in this bin
        p_val = np.mean(null_counts[:, k] >= observed)
        log10_p = -np.log10(p_val) if p_val > 0 else 10  # capped at 10 when no null draw reaches `observed`

        results.append({
            'immunization': imm,
            'Spike_AS_Position': pos,
            'Amino_Acid': aa,
            'enrichment_bin': label,
            'observed_count': observed,
            'p_value': p_val,
            'log10_p': log10_p
        })

df_results = pd.DataFrame(results)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Distribution (violin) and individual tests (points) of -log10(p) per enrichment bin,
# split by immunization; bins are shown in enrichment_labels order.
plt.figure(figsize=(12, 6))
sns.violinplot(data=df_results, x='enrichment_bin', y='log10_p', hue='immunization',
               order=enrichment_labels, inner=None, cut=0)
sns.stripplot(data=df_results, x='enrichment_bin', y='log10_p', hue='immunization',
              order=enrichment_labels, dodge=True, jitter=True, alpha=0.5)

plt.ylabel('-log10(p-value)')
plt.title('Significance of Amino Acid Enrichment Ratios (per Immunization)')
plt.axhline(-np.log10(0.005), linestyle='--', color='red', label='p = 0.005')

# Both seaborn layers add one legend entry per immunization; keep the first handle per label
handles, labels = plt.gca().get_legend_handles_labels()
legend_entries = {}
for handle, label in zip(handles, labels):
    legend_entries.setdefault(label, handle)
plt.legend(list(legend_entries.values()), list(legend_entries.keys()),
           bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


In [None]:
# Compute stats + empirical p-value
# NOTE(review): this cell is a loop body detached from its enclosing loop. It reads
# `position_aa_freqs`, `barcodes`, `df_bin`, `n_iter`, `imm`, `enrich_bin` and
# `bootstrap_results`, none of which are defined here, so it only runs when those are
# left over from earlier cells (NameError on a fresh kernel otherwise).
# For each (position, amino acid) with a list of bootstrap frequencies, append one summary
# record (mean/std/se of the bootstrap frequencies, observed frequency, p-value).
for (pos, aa), freqs in position_aa_freqs.items():
    freq_array = np.array(freqs)

    if len(barcodes) == 0:
        continue  # Skip division by zero

    # Compute observed frequency using all barcodes
    total_count = df_bin[
        (df_bin['Spike_AS_Position'] == pos) &
        (df_bin['Amino_Acid'] == aa)
    ]['barcode'].nunique()
    obs_freq = total_count / len(barcodes)

    # p-value: fraction of bootstraps where freq ≥ observed freq
    # NOTE(review): divides by n_iter rather than len(freq_array); these differ if some
    # iterations recorded no frequency for this (pos, aa) — confirm how `freqs` is filled.
    p_value = np.sum(freq_array >= obs_freq) / n_iter

    bootstrap_results.append({
        'immunization': imm,
        'Spike_AS_Position': pos,
        'Amino_Acid': aa,
        'enrichment_bin': enrich_bin,
        'freq_mean': np.mean(freq_array),
        'freq_std': np.std(freq_array),
        # NOTE(review): std / sqrt(n_iter) is the Monte-Carlo error of the bootstrap mean and
        # shrinks as n_iter grows; the bootstrap standard error of the frequency itself is
        # freq_std. Check which one the error bars drawn from freq_se are meant to show.
        'freq_se': np.std(freq_array) / np.sqrt(n_iter),
        'obs_freq': obs_freq,
        'p_value': p_value
    })


In [None]:
# Set p-value threshold
p_thresh = 0.005

# Count total and significant amino acids for each (immunization, enrichment_bin).
# Named aggregation replaces groupby.apply, which is deprecated in pandas >= 2.2 when it
# operates on the grouping columns; observed=True keeps only combinations that have rows.
summary = (
    df_bootstrap
    .groupby(['immunization', 'enrichment_bin'], observed=True)
    .agg(
        n_total=('p_value', 'size'),
        n_significant=('p_value', lambda p: int((p < p_thresh).sum())),
    )
    .reset_index()
)

# Calculate percentage
summary['percent_significant'] = 100 * summary['n_significant'] / summary['n_total']

summary


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# One -log10(p-value) heatmap (amino acid x Spike position) per immunization and
# enrichment bin; combinations without any rows are skipped.
for imm in df_bootstrap['immunization'].unique():
    _imm_mask = df_bootstrap['immunization'] == imm
    for enrich_bin in enrichment_labels:
        df_plot = df_bootstrap[_imm_mask & (df_bootstrap['enrichment_bin'] == enrich_bin)]
        if df_plot.empty:
            continue  # no data for this immunization / bin

        pivot = df_plot.pivot_table(
            values='log10_p',
            index='Amino_Acid',
            columns='Spike_AS_Position',
            aggfunc='mean'
        )

        plt.figure(figsize=(20, 5))
        sns.heatmap(
            pivot,
            cmap='viridis',
            linewidths=0.1,
            linecolor='gray',
            cbar_kws={'label': '-log10(p-value)'}
        )
        plt.title(f'{imm} | Bin: {enrich_bin} — Amino Acid Reproducibility')
        plt.ylabel('Amino Acid')
        plt.xlabel('Spike Position')
        plt.tight_layout()
        plt.show()


In [None]:
# Let's check for any correlation between amino-acid chemical features (e.g. charge) and enrichment.

In [None]:
# Numeric amino-acid properties, one row per residue:
#   hydrophobicity - Kyte-Doolittle hydropathy index
#   size           - values match the free amino-acid molecular weights in Da (rounded)
#   charge         - side-chain charge: +1 for R/K, -1 for D/E, 0.1 (partial) for His, else 0
_AA_PROPERTY_TABLE = (
    # aa, hydrophobicity, size, charge
    ('A',  1.8,  89,  0),
    ('R', -4.5, 174,  1),
    ('N', -3.5, 132,  0),
    ('D', -3.5, 133, -1),
    ('C',  2.5, 121,  0),
    ('Q', -3.5, 146,  0),
    ('E', -3.5, 147, -1),
    ('G', -0.4,  75,  0),
    ('H', -3.2, 155,  0.1),
    ('I',  4.5, 131,  0),
    ('L',  3.8, 131,  0),
    ('K', -3.9, 146,  1),
    ('M',  1.9, 149,  0),
    ('F',  2.8, 165,  0),
    ('P', -1.6, 115,  0),
    ('S', -0.8, 105,  0),
    ('T', -0.7, 119,  0),
    ('W', -0.9, 204,  0),
    ('Y', -1.3, 181,  0),
    ('V',  4.2, 117,  0),
)

aa_properties = {
    aa: {'hydrophobicity': hydro, 'size': size, 'charge': charge}
    for aa, hydro, size, charge in _AA_PROPERTY_TABLE
}


In [None]:
# Categorical amino-acid features used to group the bootstrap results.
# Each row is (size, charge, hydro, polarity).
# NOTE(review): some labels are judgement calls and may disagree with other annotations
# (e.g. G is 'Hydrophilic' here but falls in the 'Neutral' Kyte-Doolittle bin of -0.5..0.5).
_AA_FEATURE_NAMES = ('size', 'charge', 'hydro', 'polarity')
_AA_FEATURE_ROWS = {
    'A': ('Small', 'Neutral', 'Hydrophobic', 'Nonpolar'),
    'R': ('Large', 'Positive', 'Hydrophilic', 'Polar'),
    'N': ('Medium', 'Neutral', 'Hydrophilic', 'Polar'),
    'D': ('Small', 'Negative', 'Hydrophilic', 'Polar'),
    'C': ('Small', 'Neutral', 'Hydrophobic', 'Polar'),
    'E': ('Medium', 'Negative', 'Hydrophilic', 'Polar'),
    'Q': ('Medium', 'Neutral', 'Hydrophilic', 'Polar'),
    'G': ('Small', 'Neutral', 'Hydrophilic', 'Nonpolar'),
    'H': ('Large', 'Positive', 'Hydrophilic', 'Polar'),
    'I': ('Medium', 'Neutral', 'Hydrophobic', 'Nonpolar'),
    'L': ('Medium', 'Neutral', 'Hydrophobic', 'Nonpolar'),
    'K': ('Large', 'Positive', 'Hydrophilic', 'Polar'),
    'M': ('Medium', 'Neutral', 'Hydrophobic', 'Nonpolar'),
    'F': ('Large', 'Neutral', 'Hydrophobic', 'Nonpolar'),
    'P': ('Small', 'Neutral', 'Hydrophobic', 'Nonpolar'),
    'S': ('Small', 'Neutral', 'Hydrophilic', 'Polar'),
    'T': ('Small', 'Neutral', 'Hydrophilic', 'Polar'),
    'W': ('Large', 'Neutral', 'Hydrophobic', 'Nonpolar'),
    'Y': ('Large', 'Neutral', 'Hydrophilic', 'Polar'),
    'V': ('Small', 'Neutral', 'Hydrophobic', 'Nonpolar'),
}
aa_features = {aa: dict(zip(_AA_FEATURE_NAMES, row)) for aa, row in _AA_FEATURE_ROWS.items()}

# Map to df_bootstrap: one '<feature>_group' column per feature; unknown amino acids -> NaN
for feature in ['charge', 'size', 'hydro', 'polarity']:
    _lookup = {aa: props[feature] for aa, props in aa_features.items()}
    df_bootstrap[f'{feature}_group'] = df_bootstrap['Amino_Acid'].map(_lookup)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

grouping_features = ['charge_group', 'size_group', 'hydro_group', 'polarity_group']
enrichment_bins_ordered = ['0–1', '1–10', '10–100', '100–500', '>500']

# One boxplot per (feature, enrichment bin): bootstrap mean frequency per feature group
# on the x axis, split by immunization.
for group_feature in grouping_features:
    _feature_label = group_feature.replace('_group', '').capitalize()
    for enrich_bin in enrichment_bins_ordered:
        df_plot = df_bootstrap.loc[df_bootstrap['enrichment_bin'] == enrich_bin]

        plt.figure(figsize=(10, 6))
        sns.boxplot(data=df_plot, x=group_feature, y='freq_mean', hue='immunization')
        plt.title(f'Bootstrap freq_mean by {group_feature} | Enrichment bin {enrich_bin}')
        plt.xlabel(_feature_label)
        plt.ylabel('Bootstrap Frequency Mean')
        plt.legend(title='Immunization', bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

grouping_features = ['charge_group', 'size_group', 'hydro_group', 'polarity_group']
enrichment_bins_ordered = ['0–1', '1–10', '10–100', '100–500', '>500']

# One boxplot per (feature, enrichment bin): immunizations on the x axis, boxes coloured
# by feature group, both in sorted order.
for group_feature in grouping_features:
    _feature_label = group_feature.replace('_group', '').capitalize()
    for enrich_bin in enrichment_bins_ordered:
        df_plot = df_bootstrap.loc[df_bootstrap['enrichment_bin'] == enrich_bin].copy()

        # Categorical dtype keeps the feature levels (and hence colours) explicit
        df_plot[group_feature] = pd.Categorical(df_plot[group_feature])
        immunization_order = sorted(df_plot['immunization'].unique())
        group_order = sorted(df_plot[group_feature].unique())

        plt.figure(figsize=(10, 6))
        sns.boxplot(
            data=df_plot,
            x='immunization',
            y='freq_mean',
            hue=group_feature,
            order=immunization_order,
            hue_order=group_order
        )
        plt.title(f'{_feature_label} groups by Immunization | Enrichment bin {enrich_bin}')
        plt.xlabel('Immunization')
        plt.ylabel('Bootstrap Frequency Mean')
        plt.legend(title=_feature_label, bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.show()


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

grouping_features = ['charge_group', 'size_group', 'hydro_group', 'polarity_group']
enrichment_bins_ordered = ['0–1', '1–10', '10–100', '100–500', '>500']

for group_feature in grouping_features:
    for enrich_bin in enrichment_bins_ordered:
        df_plot = df_bootstrap[df_bootstrap['enrichment_bin'] == enrich_bin].copy()
        if df_plot.empty:
            continue  # nothing to plot for this enrichment bin

        # Ensure categorical types for consistency; NaN levels are excluded from the
        # orders (sorting a mix of strings and NaN would raise TypeError)
        df_plot[group_feature] = pd.Categorical(df_plot[group_feature])
        immunization_order = sorted(df_plot['immunization'].dropna().unique())
        group_order = sorted(df_plot[group_feature].dropna().unique())
        feature_label = group_feature.replace('_group', '').capitalize()

        plt.figure(figsize=(10, 6))

        # Boxplot (grouped by immunization, colored by feature)
        sns.boxplot(
            data=df_plot,
            x='immunization',
            y='freq_mean',
            hue=group_feature,
            order=immunization_order,
            hue_order=group_order,
            showfliers=False  # individual points are drawn by the stripplot below
        )

        # Individual data points, dodged like the boxes. seaborn (>= 0.12) ignores `color`
        # when `hue` is assigned, so an all-black palette is passed to get black points.
        sns.stripplot(
            data=df_plot,
            x='immunization',
            y='freq_mean',
            hue=group_feature,
            order=immunization_order,
            hue_order=group_order,
            palette={level: 'black' for level in group_order},
            dodge=True,
            jitter=True,
            marker='o',
            size=3,
            alpha=0.5
        )

        # Clean legend: boxplot entries come first; drop the duplicate stripplot entries
        handles, labels = plt.gca().get_legend_handles_labels()
        n_groups = len(group_order)
        plt.legend(handles[:n_groups], labels[:n_groups], title=feature_label, bbox_to_anchor=(1.05, 1), loc='upper left')

        plt.title(f'{feature_label} groups by Immunization | Enrichment bin {enrich_bin}')
        plt.ylabel('Bootstrap Frequency Mean')
        plt.xlabel('Immunization')
        plt.tight_layout()
        plt.show()


In [None]:
!pip install statannotations


In [None]:
from statannotations.Annotator import Annotator
import itertools
import seaborn as sns
import matplotlib.pyplot as plt

# For every feature grouping and enrichment bin: boxplot of the bootstrap mean frequency per
# feature group (x) split by immunization (hue), annotated with Mann-Whitney tests between
# every pair of feature groups within the same immunization.
# NOTE(review): no multiple-comparison correction is configured, although many pairs are
# tested per figure.
# NOTE(review): df_plot, ax and group_feature keep their last-iteration values after this
# cell; the following cells read them at top level.
grouping_features = ['charge_group', 'size_group', 'hydro_group', 'polarity_group']
enrichment_bins_ordered = ['0–1', '1–10', '10–100', '100–500', '>500']

for group_feature in grouping_features:
    for enrich_bin in enrichment_bins_ordered:
        df_plot = df_bootstrap[df_bootstrap['enrichment_bin'] == enrich_bin].copy()

        # Drop NaN in y-value or group_feature
        df_plot['freq_mean'] = pd.to_numeric(df_plot['freq_mean'], errors='coerce')
        df_plot = df_plot.dropna(subset=['freq_mean', group_feature, 'immunization'])

        if df_plot.empty:
            continue

        plt.figure(figsize=(12, 6))
        ax = sns.boxplot(
            data=df_plot,
            x=group_feature,
            y='freq_mean',
            hue='immunization'
        )
        plt.title(f'Bootstrap freq_mean by {group_feature} | Enrichment bin {enrich_bin}')
        plt.ylabel('Bootstrap Frequency Mean')
        plt.xlabel(group_feature.replace('_group', '').capitalize())

        # Build valid group pairs: each element is (feature group, immunization), matching
        # x=group_feature / hue='immunization' of the boxplot above
        pairs = []
        for immun in df_plot['immunization'].unique():
            sub_df = df_plot[df_plot['immunization'] == immun]
            levels = sub_df[group_feature].dropna().unique()
            if len(levels) >= 2:
                combos = list(itertools.combinations(sorted(levels), 2))
                pairs += [((a, immun), (b, immun)) for a, b in combos]

        # Validate pairs exist in the actual data
        valid_groups = set(tuple(row) for row in df_plot[[group_feature, 'immunization']].dropna().values)
        filtered_pairs = [pair for pair in pairs if pair[0] in valid_groups and pair[1] in valid_groups]

        if filtered_pairs:
            annotator = Annotator(ax, filtered_pairs, data=df_plot,
                                  x=group_feature, y='freq_mean', hue='immunization')
            annotator.configure(test='Mann-Whitney', text_format='star', loc='outside', verbose=0)
            annotator.apply_and_annotate()
        else:
            print(f"Skipped annotation for {group_feature} | {enrich_bin} (no valid pairs)")

        plt.legend(title='Immunization', bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.tight_layout()
        plt.show()


In [None]:
import itertools
from statannotations.Annotator import Annotator

# Mann-Whitney comparisons between feature groups within each immunization, for the
# `df_plot` / `group_feature` left over from the last iteration of the statannotations loop
# above. A fresh boxplot is drawn with exactly the layout the Annotator is configured for
# (x='immunization', hue=group_feature). Annotating the leftover `ax` put brackets on a
# figure that had already been shown and was plotted with x=group_feature instead.

# Ensure freq_mean is numeric and drop rows that can't be plotted or ordered
df_plot['freq_mean'] = pd.to_numeric(df_plot['freq_mean'], errors='coerce')
df_plot = df_plot.dropna(subset=['freq_mean', group_feature, 'immunization'])

immunization_order = sorted(df_plot['immunization'].unique())
group_order = sorted(df_plot[group_feature].unique())

plt.figure(figsize=(12, 6))
ax = sns.boxplot(
    data=df_plot,
    x='immunization',
    y='freq_mean',
    hue=group_feature,
    order=immunization_order,
    hue_order=group_order
)

# All pairwise feature-group combinations within each immunization: ((immun, a), (immun, b))
pairs = []
for immun in immunization_order:
    sub_df = df_plot[df_plot['immunization'] == immun]
    group_levels = sorted(sub_df[group_feature].unique())
    if len(group_levels) >= 2:
        pair_combos = list(itertools.combinations(group_levels, 2))
        pairs += [((immun, a), (immun, b)) for a, b in pair_combos]

# Filter pairs to only keep (immunization, group) combinations actually present in the data
valid_groups = set(zip(df_plot['immunization'], df_plot[group_feature]))
filtered_pairs = [pair for pair in pairs if pair[0] in valid_groups and pair[1] in valid_groups]

# Add annotation only if there are valid pairs
if filtered_pairs:
    annotator = Annotator(ax, filtered_pairs, data=df_plot,
                          x='immunization', y='freq_mean', hue=group_feature,
                          order=immunization_order, hue_order=group_order)
    annotator.configure(test='Mann-Whitney', text_format='star', loc='outside', verbose=0)
    annotator.apply_and_annotate()
else:
    print("No valid group pairs found for annotation.")

feature_label = group_feature.replace('_group', '').capitalize()
plt.title(f'{feature_label} groups by Immunization (Mann-Whitney)')
plt.xlabel('Immunization')
plt.ylabel('Bootstrap Frequency Mean')
plt.legend(title=feature_label, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


In [None]:
import itertools
from statannotations.Annotator import Annotator

# Mann-Whitney comparisons between feature groups within each immunization, for the
# `df_plot` / `group_feature` left over from the last iteration of the statannotations loop
# above. A fresh boxplot is drawn with exactly the layout the Annotator is configured for
# (x='immunization', hue=group_feature). Annotating the leftover `ax` put brackets on a
# figure that had already been shown and was plotted with x=group_feature instead.

# Ensure freq_mean is numeric and drop rows that can't be plotted or ordered
df_plot['freq_mean'] = pd.to_numeric(df_plot['freq_mean'], errors='coerce')
df_plot = df_plot.dropna(subset=['freq_mean', group_feature, 'immunization'])

immunization_order = sorted(df_plot['immunization'].unique())
group_order = sorted(df_plot[group_feature].unique())

plt.figure(figsize=(12, 6))
ax = sns.boxplot(
    data=df_plot,
    x='immunization',
    y='freq_mean',
    hue=group_feature,
    order=immunization_order,
    hue_order=group_order
)

# All pairwise feature-group combinations within each immunization: ((immun, a), (immun, b))
pairs = []
for immun in immunization_order:
    sub_df = df_plot[df_plot['immunization'] == immun]
    group_levels = sorted(sub_df[group_feature].unique())
    if len(group_levels) >= 2:
        pair_combos = list(itertools.combinations(group_levels, 2))
        pairs += [((immun, a), (immun, b)) for a, b in pair_combos]

# Filter pairs to only keep (immunization, group) combinations actually present in the data
valid_groups = set(zip(df_plot['immunization'], df_plot[group_feature]))
filtered_pairs = [pair for pair in pairs if pair[0] in valid_groups and pair[1] in valid_groups]

# Add annotation only if there are valid pairs
if filtered_pairs:
    annotator = Annotator(ax, filtered_pairs, data=df_plot,
                          x='immunization', y='freq_mean', hue=group_feature,
                          order=immunization_order, hue_order=group_order)
    annotator.configure(test='Mann-Whitney', text_format='star', loc='outside', verbose=0)
    annotator.apply_and_annotate()
else:
    print("No valid group pairs found for annotation.")

feature_label = group_feature.replace('_group', '').capitalize()
plt.title(f'{feature_label} groups by Immunization (Mann-Whitney)')
plt.xlabel('Immunization')
plt.ylabel('Bootstrap Frequency Mean')
plt.legend(title=feature_label, bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


In [None]:
# Attach the numeric amino-acid properties from aa_properties as columns
# (hydrophobicity, size, charge); amino acids missing from the table get NaN.
for _prop in ('hydrophobicity', 'size', 'charge'):
    df_bootstrap[_prop] = df_bootstrap['Amino_Acid'].map(
        lambda aa, _p=_prop: aa_properties.get(aa, {}).get(_p, np.nan)
    )


In [None]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import umap

# Feature matrix for dimensionality reduction (rows with any missing value are dropped)
features = df_bootstrap[['freq_mean', 'freq_std', 'hydrophobicity', 'size', 'charge']].dropna()

# Standardize every column first. The raw features have very different scales
# (frequencies are fractions while `size` spans roughly 75-204), so PCA on unscaled
# data would be dominated by `size`.
features_scaled = StandardScaler().fit_transform(features)

# PCA example:
pca = PCA(n_components=2)
pca_result = pca.fit_transform(features_scaled)

df_bootstrap.loc[features.index, 'PCA1'] = pca_result[:, 0]
df_bootstrap.loc[features.index, 'PCA2'] = pca_result[:, 1]

# Or UMAP (on the same standardized matrix):
# reducer = umap.UMAP(n_components=2, random_state=42)
# umap_result = reducer.fit_transform(features_scaled)
# df_bootstrap.loc[features.index, 'UMAP1'] = umap_result[:,0]
# df_bootstrap.loc[features.index, 'UMAP2'] = umap_result[:,1]


In [None]:
# Add biochemical properties as before
df_bootstrap['hydrophobicity'] = df_bootstrap['Amino_Acid'].map(lambda aa: aa_properties.get(aa, {}).get('hydrophobicity', np.nan))
df_bootstrap['charge'] = df_bootstrap['Amino_Acid'].map(lambda aa: aa_properties.get(aa, {}).get('charge', np.nan))

# Categorize charge groups
def charge_group(c):
    if c > 0:
        return 'Positive'
    elif c < 0:
        return 'Negative'
    else:
        return 'Neutral'
df_bootstrap['charge_group'] = df_bootstrap['charge'].apply(charge_group)

# Similarly, you can bin hydrophobicity, e.g.:
df_bootstrap['hydro_group'] = pd.cut(
    df_bootstrap['hydrophobicity'],
    bins=[-5, -0.5, 0.5, 5],
    labels=['Hydrophilic', 'Neutral', 'Hydrophobic']
)


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Boxplot of bootstrap mean frequency per charge group, split by immunization.
fig, ax = plt.subplots(figsize=(12, 6))
sns.boxplot(
    data=df_bootstrap,
    x='charge_group',
    y='freq_mean',
    hue='immunization',
    ax=ax,
)
ax.set_title('Bootstrap frequency mean by amino acid charge group and immunization')
ax.set_ylabel('Bootstrap Frequency Mean')
ax.set_xlabel('Charge Group')
# Legend placed outside the axes, to the right.
ax.legend(title='Immunization', bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()
plt.show()


In [None]:
from scipy.stats import kruskal

# For each immunization, test whether bootstrap mean frequencies differ between charge groups
# (Kruskal-Wallis H-test). The test needs at least two non-empty groups, and NaN values
# would otherwise propagate into a NaN p-value, so both cases are handled explicitly.
for imm in df_bootstrap['immunization'].unique():
    subset = df_bootstrap[df_bootstrap['immunization'] == imm]

    groups = [
        group['freq_mean'].dropna().values
        for _, group in subset.groupby('charge_group')
    ]
    groups = [g for g in groups if len(g) > 0]

    if len(groups) < 2:
        print(f"Immunization {imm}: fewer than two non-empty charge groups, Kruskal-Wallis test skipped")
        continue

    try:
        stat, p = kruskal(*groups)
    except ValueError as err:
        # e.g. scipy rejects input where all values are identical
        print(f"Immunization {imm}: Kruskal-Wallis test not computable ({err})")
        continue
    print(f"Immunization {imm}: Kruskal-Wallis H-test across charge groups p-value = {p:.4f}")


In [None]:
import statsmodels.formula.api as smf

# Drop rows missing any of the listed variables. .copy() makes df_reg an independent frame,
# so the categorical conversion below is not a chained assignment on a derived frame
# (which pandas may flag with SettingWithCopyWarning).
df_reg = df_bootstrap.dropna(subset=['freq_mean', 'hydrophobicity', 'charge', 'immunization']).copy()

# Encode immunization as a categorical variable
df_reg['immunization'] = df_reg['immunization'].astype('category')

# OLS: freq_mean explained by hydrophobicity plus an immunization fixed effect
model = smf.ols('freq_mean ~ hydrophobicity + C(immunization)', data=df_reg).fit()

print(model.summary())


In [None]:
# Visualizing the bootstrap mean frequency as a logo plot, colored by amino acid.
# This plot has not been validated.

In [None]:
# List the immunization groups present in the bootstrap table.
immunization_groups = df_bootstrap['immunization'].unique()
print(immunization_groups)


In [None]:
# Visualizing bootstrap sampling as pie charts: we want to know what proportion of the amino acids/positions
# falls within each bootstrap mean-frequency bin (0-0.1, 0.1-0.2, 0.2-0.3, and so on).
# A high proportion of mean frequencies > 0.8 could indicate that the sampling of amino acids (mutations) is very robust in that sample.

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Bootstrap mean frequencies lie in [0, 1]; split them into ten bins of width 0.1.
# np.linspace guarantees exactly 11 edges ending at 1.0 (a float-step np.arange can
# gain or lose an endpoint through rounding).
bins = np.linspace(0, 1, 11)
bin_labels = [f"{bins[i]:.1f}-{bins[i+1]:.1f}" for i in range(len(bins)-1)]

immunizations = sorted(df_bootstrap['immunization'].unique())
# Order enrichment bins as listed in `enrichment_labels` (defined in an earlier cell).
enrich_bins = sorted(df_bootstrap['enrichment_bin'].unique(), key=lambda x: enrichment_labels.index(x))

fig_width = 3 * len(enrich_bins)  # 3 inches per enrichment bin
fig_height = 4 * len(immunizations)  # 4 inches per immunization

# squeeze=False always returns a 2-D array of Axes, so axes[i, j] works for every grid
# shape without special-casing single rows/columns.
fig, axes = plt.subplots(
    len(immunizations), len(enrich_bins),
    figsize=(fig_width, fig_height), squeeze=False
)

for i, imm in enumerate(immunizations):
    for j, enrich_bin in enumerate(enrich_bins):
        ax = axes[i, j]

        data = df_bootstrap[
            (df_bootstrap['immunization'] == imm) &
            (df_bootstrap['enrichment_bin'] == enrich_bin)
        ]

        if data.empty:
            ax.axis('off')
            continue

        # Bin freq_mean values; values outside [0, 1] (and NaNs) fall outside the edges
        # and are not counted.
        counts, _ = np.histogram(data['freq_mean'], bins=bins)
        total = counts.sum()
        if total == 0:
            # Nothing countable in this cell: avoid 0/0 proportions and an empty pie.
            ax.axis('off')
            continue
        proportions = counts / total

        # Remove zero-prop bins for cleaner pie
        nonzero_idx = proportions > 0
        props_nonzero = proportions[nonzero_idx]
        labels_nonzero = np.array(bin_labels)[nonzero_idx]

        # Pie chart with % labels
        wedges, texts, autotexts = ax.pie(
            props_nonzero,
            labels=labels_nonzero,
            autopct='%1.1f%%',
            startangle=90,
            textprops={'fontsize': 8}
        )

        ax.set_title(f"Imm: {imm}\nEnrich bin: {enrich_bin}", fontsize=10)
        ax.axis('equal')

# Global figure title
plt.suptitle("Distribution of Bootstrap Mean Frequencies by Immunization and Enrichment Bin", fontsize=16, y=1.02)
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Bootstrap mean frequencies lie in [0, 1]; split them into ten bins of width 0.1.
# np.linspace guarantees exactly 11 edges ending at 1.0.
bins = np.linspace(0, 1, 11)
bin_labels = [f"{bins[i]:.1f}-{bins[i+1]:.1f}" for i in range(len(bins)-1)]

immunizations = sorted(df_bootstrap['immunization'].unique())
# Order enrichment bins as listed in `enrichment_labels` (defined in an earlier cell).
enrich_bins = sorted(df_bootstrap['enrichment_bin'].unique(), key=lambda x: enrichment_labels.index(x))

fig_width = 3 * len(enrich_bins)  # 3 inches per enrichment bin
fig_height = 4 * len(immunizations)  # 4 inches per immunization

# squeeze=False always returns a 2-D array of Axes, so axes[i, j] works for every grid shape.
fig, axes = plt.subplots(
    len(immunizations), len(enrich_bins),
    figsize=(fig_width, fig_height), squeeze=False
)

# Fixed colour per frequency bin so the same bin has the same colour in every pie.
import seaborn as sns
palette = sns.color_palette("Spectral", n_colors=len(bin_labels))
color_map = dict(zip(bin_labels, palette))

for i, imm in enumerate(immunizations):
    for j, enrich_bin in enumerate(enrich_bins):
        ax = axes[i, j]

        data = df_bootstrap[
            (df_bootstrap['immunization'] == imm) &
            (df_bootstrap['enrichment_bin'] == enrich_bin)
        ]

        if data.empty:
            ax.axis('off')
            continue

        # Bin freq_mean values; values outside [0, 1] (and NaNs) are not counted.
        counts, _ = np.histogram(data['freq_mean'], bins=bins)
        total = counts.sum()
        if total == 0:
            # Nothing countable in this cell: avoid 0/0 proportions and an empty pie.
            ax.axis('off')
            continue
        proportions = counts / total

        # Remove zero-prop bins for cleaner pie
        nonzero_idx = proportions > 0
        props_nonzero = proportions[nonzero_idx]
        labels_nonzero = np.array(bin_labels)[nonzero_idx]

        # Pick colors for the nonzero bins consistently from color_map
        colors_nonzero = [color_map[label] for label in labels_nonzero]

        wedges, texts, autotexts = ax.pie(
            props_nonzero,
            labels=labels_nonzero,
            autopct='%1.1f%%',
            startangle=90,
            colors=colors_nonzero,
            textprops={'fontsize': 12},   # bin label font size
            pctdistance=0.75,              # keeps percentages closer to center
            labeldistance=1.05              # keeps labels close for big pies
        )

        # Make percentage text bigger
        for t in autotexts:
            t.set_fontsize(12)

        ax.set_title(f"Imm: {imm}\nEnrich bin: {enrich_bin}", fontsize=14)
        ax.axis('equal')

plt.suptitle(
    "Distribution of Bootstrap Mean Frequencies by Immunization and Enrichment Bin",
    fontsize=20, y=1.02
)
plt.tight_layout()
# Save the figure before displaying (after show() the inline backend releases the figure)
output_path = "bootstrap_frequency_distribution.png"  # Change filename if needed
plt.savefig(output_path, format='png', bbox_inches='tight', dpi=300)  # High-res save
print(f"Figure saved as {output_path}")

plt.show()



In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from matplotlib.patches import Patch

# Same pie grid as above, but without the library control and with one shared legend
# instead of per-wedge bin labels.
df_plot = df_bootstrap[df_bootstrap['immunization'] != 'library_ctrl']

# Ten bins of width 0.1 over [0, 1]; np.linspace guarantees exactly 11 edges ending at 1.0.
bins = np.linspace(0, 1, 11)
bin_labels = [f"{bins[i]:.1f}-{bins[i+1]:.1f}" for i in range(len(bins)-1)]

immunizations = sorted(df_plot['immunization'].unique())
# Order enrichment bins as listed in `enrichment_labels` (defined in an earlier cell).
enrich_bins = sorted(df_plot['enrichment_bin'].unique(), key=lambda x: enrichment_labels.index(x))

fig_width = 3 * len(enrich_bins) + 2  # extra space for legend
fig_height = 4 * len(immunizations)

# squeeze=False always returns a 2-D array of Axes, so axes[i, j] works for every grid shape.
fig, axes = plt.subplots(
    len(immunizations), len(enrich_bins),
    figsize=(fig_width, fig_height), squeeze=False
)

# Color map for bins
palette = sns.color_palette("Spectral", n_colors=len(bin_labels))
color_map = dict(zip(bin_labels, palette))

for i, imm in enumerate(immunizations):
    for j, enrich_bin in enumerate(enrich_bins):
        ax = axes[i, j]

        data = df_plot[
            (df_plot['immunization'] == imm) &
            (df_plot['enrichment_bin'] == enrich_bin)
        ]

        if data.empty:
            ax.axis('off')
            continue

        # Bin freq_mean values; values outside [0, 1] (and NaNs) are not counted.
        counts, _ = np.histogram(data['freq_mean'], bins=bins)
        total = counts.sum()
        if total == 0:
            # Nothing countable in this cell: avoid 0/0 proportions and an empty pie.
            ax.axis('off')
            continue
        proportions = counts / total

        # Remove zero-prop bins
        nonzero_idx = proportions > 0
        props_nonzero = proportions[nonzero_idx]
        labels_nonzero = np.array(bin_labels)[nonzero_idx]

        # Colors
        colors_nonzero = [color_map[label] for label in labels_nonzero]

        wedges, _, autotexts = ax.pie(
            props_nonzero,
            labels=None,                 # bins are identified via the shared legend
            autopct='%1.1f%%',
            startangle=90,
            colors=colors_nonzero,
            textprops={'fontsize': 12},
            pctdistance=0.75
        )

        for t in autotexts:
            t.set_fontsize(12)

        ax.set_title(f"Imm: {imm}", fontsize=14)
        ax.axis('equal')

# Shared legend for the frequency bins
legend_elements = [Patch(facecolor=color_map[label], label=label) for label in bin_labels]
fig.legend(
    handles=legend_elements,
    loc='center right',
    title="Bootstrap Mean Bins",
    fontsize=12,
    title_fontsize=14
)

plt.suptitle(
    "Distribution of Bootstrap Mean Frequencies by Immunization and Enrichment Bin",
    fontsize=20, y=1.02
)
# rect keeps the right 15% free for the legend. tight_layout recomputes all subplot
# parameters, so a preceding subplots_adjust(right=0.85) would have no effect.
plt.tight_layout(rect=[0, 0, 0.85, 0.96])

# Distinct file name: the previous cell already writes bootstrap_frequency_distribution.png,
# which this figure (library_ctrl excluded) would otherwise overwrite.
output_path = "bootstrap_frequency_distribution_no_library_ctrl.png"
plt.savefig(output_path, format='png', bbox_inches='tight', dpi=300)
print(f"Figure saved as {output_path}")

plt.show()


In [None]:
!pip install logomaker

In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np
import matplotlib.colors as mcolors
from matplotlib.patches import Patch
import matplotlib.ticker as ticker

# NOTE(review): df_logo_agg is (re)built from df_total by the enrichment logo-plot cells
# further down; this cell relies on one of them having been run first.
df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()
df_logo_agg = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]  # drop stop codons
df_logo_agg = df_logo_agg.replace([np.inf, -np.inf], np.nan)
df_logo_agg = df_logo_agg.dropna(subset=['Enrichment_Ratio'])

# One row per (position, amino acid, immunization, barcode), restricted to Spike positions
# > 365. .copy() makes df_filtered independent of df_logo_agg, so the column assignments
# below are not chained assignments on a slice (SettingWithCopyWarning).
df_filtered = df_logo_agg.drop_duplicates(
    subset=['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode']
)
df_filtered = df_filtered[df_filtered['Spike_AS_Position'] > 365].copy()

# Sanity check: smallest enrichment ratio retained (no lower threshold is applied here).
print(df_filtered['Enrichment_Ratio'].min())

# (position, amino acid) pairs occurring more than once, e.g. across barcodes/immunizations.
duplicates = df_filtered[df_filtered.duplicated(
    subset=['Spike_AS_Position', 'Amino_Acid'], keep=False)]
print(duplicates)

def safe_log10(enrichment, epsilon=1e-8, max_cap=1e8):
    """Return log10(enrichment) after clipping to [epsilon, max_cap], so zeros and very
    large ratios map to finite values (-8 and 8 with the defaults)."""
    enrichment = np.clip(enrichment, epsilon, max_cap)
    return np.log10(enrichment)

# Vectorised: np.clip / np.log10 operate element-wise on the Series.
df_filtered['Enrichment_Ratio_log'] = safe_log10(df_filtered['Enrichment_Ratio'])
df_filtered = df_filtered.dropna(subset=['Enrichment_Ratio_log'])

# Colour per amino acid (tab20 cycles if there are more than 20 letters)
aa_list = sorted(df_filtered['Amino_Acid'].unique())
colors = plt.cm.tab20.colors
aa_colors = {aa: colors[i % len(colors)] for i, aa in enumerate(aa_list)}
df_filtered['color'] = df_filtered['Amino_Acid'].map(aa_colors)

# Output directory
output_dir = "barcode_logoplots_panels"
os.makedirs(output_dir, exist_ok=True)

def log_ticks_plain(ax):
    """Put a y tick at every integer (log10 unit) and label it as a plain integer."""
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda y, _: f"{int(y)}"))

def improved_draw_logo(df, **kwargs):
    """Sort rows by x position, then ascending letter height, before calling dmslogo.draw_logo.

    NOTE(review): intended to stack taller letters on top; whether draw_logo honours the
    input row order (rather than re-sorting itself) is not verified here.
    """
    df_sorted = df.sort_values(
        by=[kwargs['x_col'], kwargs['letter_height_col']],
        ascending=[True, True]
    )
    return dmslogo.draw_logo(df_sorted, **kwargs)

# One panel figure per immunization with (up to) the first four barcodes in order of
# appearance - they are not ranked by any score.
for immunization in df_filtered['immunization'].unique():
    df_imm = df_filtered[df_filtered['immunization'] == immunization]
    barcodes = df_imm['barcode'].unique()[:4]

    if len(barcodes) == 0:
        continue

    print(f"Generating panel for immunization: {immunization}")

    # squeeze=False + column slice gives a 1-D array of Axes even for a single barcode.
    fig, axs = plt.subplots(
        len(barcodes), 1,
        figsize=(45, 12 * len(barcodes)),
        gridspec_kw={'hspace': 2.5},
        squeeze=False
    )
    axs = axs[:, 0]

    for i, barcode in enumerate(barcodes):
        df_barcode = df_imm[
            (df_imm['barcode'] == barcode) &
            (df_imm['Enrichment_Ratio_log'].notna())
        ]

        if df_barcode.empty:
            continue

        # Aggregate to one row per (position, amino acid) to avoid overlapping letters
        df_barcode_agg = (
            df_barcode.groupby(["Spike_AS_Position", "Amino_Acid"], as_index=False)
            .agg({"Enrichment_Ratio_log": "mean"})
        )
        df_barcode_agg["color"] = df_barcode_agg["Amino_Acid"].map(aa_colors)

        fig_sub, ax = improved_draw_logo(
            df_barcode_agg,
            x_col="Spike_AS_Position",
            letter_col="Amino_Acid",
            letter_height_col="Enrichment_Ratio_log",
            color_col="color",
            title=f"{immunization} - {barcode} logoplot (log10 Enrichment Ratio)",
            addbreaks=True,
            ax=axs[i]
        )

        axs[i].set_title(
            f"{immunization} - {barcode} logoplot (log10 Enrichment Ratio)", pad=25
        )
        axs[i].set_ylabel('Log10 AB binding (median) \u2190 Enrichment \u2192', fontsize=10)
        axs[i].yaxis.set_label_coords(-0.01, 0.5)
        axs[i].set_ylim(-10, 6)
        # Integer ticks every log10 unit (a separate MultipleLocator(4) would be overridden here).
        log_ticks_plain(axs[i])

    legend_elements = [Patch(facecolor=col, label=aa) for aa, col in aa_colors.items()]
    fig.subplots_adjust(right=0.85)
    fig.legend(handles=legend_elements, title='Amino Acid',
               loc='center right', borderaxespad=0.1)

    # NOTE(review): top=2 lies outside normalised figure coordinates (0-1), so the axes are
    # laid out beyond the canvas and the output relies on bbox_inches='tight' - confirm intent.
    plt.tight_layout(rect=[0, 0, 0.85, 2])
    safe_immunization = immunization.replace("/", "_").replace("\\", "_")
    plot_filename = os.path.join(
        output_dir, f"{safe_immunization}_top4_logoplots.png"
    )
    plt.savefig(plot_filename, format='png', bbox_inches='tight')
    plt.show()
    # These figures are very large (45 in wide); release each one once saved and shown.
    plt.close(fig)
    print(f"Panel saved as {plot_filename}")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os
import dmslogo

# One dmslogo logo plot per immunization from the bootstrap table (letter height =
# bootstrap mean frequency), saved to dmslogo_immunization_plots/. Figures are closed
# after saving, not displayed.
# NOTE(review): assumes df_bootstrap has 'immunization', 'Spike_AS_Position', 'Amino_Acid'
# and 'freq_mean' columns with no missing positions (astype(int) below fails on NaN).

output_dir = "dmslogo_immunization_plots"
os.makedirs(output_dir, exist_ok=True)

for imm in df_bootstrap['immunization'].unique():
    df_subset = df_bootstrap[df_bootstrap['immunization'] == imm].copy()
    df_subset = df_subset.rename(columns={
        'Spike_AS_Position': 'site',
        'Amino_Acid': 'letter',
        'freq_mean': 'height'
    })

    # Ensure 'site' is integer
    df_subset['site'] = df_subset['site'].astype(int)

    fig, ax = dmslogo.logo.draw_logo(
        data=df_subset,
        x_col='site',
        letter_col='letter',
        letter_height_col='height',
        title=f"Logo Plot — {imm}"
    )

    fig.tight_layout()
    # str() so non-string immunization labels still give a valid file name.
    filename = f"{str(imm).replace('/', '_')}_logo_dms.png"
    fig.savefig(os.path.join(output_dir, filename), dpi=300)
    plt.close(fig)
# No trailing tight_layout()/show(): every figure above is already closed, so those calls
# would only act on (and display) a new empty figure.

In [None]:
#Colored by Enrichment ratio, Y axis is how many samples support it

import random
import matplotlib.pyplot as plt
import os
import numpy as np

# Aggregate enrichment ratio per barcode (sum over rows sharing position, amino acid,
# immunization and barcode)
df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

def enrichment_color(enrichment):
    """Map an enrichment ratio to an RGBA colour from matplotlib's 'Reds' colormap.

    The colour position is log10(enrichment + 1) / log10(3000 + 1), clipped to [0, 1]:
    small ratios get the lightest reds and ratios >= 3000 the darkest red.
    """
    max_enrichment = 3000  # ratio mapped to the top of the colour scale
    log_enrichment = np.log10(enrichment + 1)  # +1 keeps 0 finite
    max_log_enrichment = np.log10(max_enrichment + 1)
    norm_value = np.clip(log_enrichment / max_log_enrichment, 0, 1)
    return plt.cm.Reds(norm_value)

df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()

df_logo_agg = df_logo_agg[df_logo_agg['Amino_Acid'] != "*"]  # drop stop codons

# Filter: Only keep amino acids with Enrichment_Ratio > 3
df_filtered = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 3]

# Per (position, amino acid, immunization): number of barcodes passing the threshold and the
# largest enrichment among them. A single groupby keeps both columns aligned by key instead
# of relying on two separate groupbys returning rows in the same order.
df_combined = (
    df_filtered
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False)
    .agg(
        Sample_Count=('Enrichment_Ratio', 'size'),
        Max_Enrichment=('Enrichment_Ratio', 'max'),
    )
)
df_combined['color'] = df_combined['Max_Enrichment'].apply(enrichment_color)

df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# Select specific sites from Code 1 to plot (for example: positions 417, 439, 440, etc.)
sites_to_show = [416,417,418, 439, 440,441, 452,453, 476, 477,478,482,483, 484,485, 493,494,495,496,497,498,499,500, 501, 502, 503,504,505]

# Filter the df_combined to include only these sites
df_combined = df_combined[df_combined['Spike_AS_Position'].isin(sites_to_show)]

# Output directory for the PNG files
output_dir = r"/Users/lucaschlotheuber/Desktop/immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# Colorbar normalisation matching enrichment_color: colour position 0 corresponds to
# log10(ER + 1) = 0 and position 1 to log10(3000 + 1).
color_norm = plt.Normalize(vmin=0, vmax=np.log10(3000 + 1))

# Generate and save logo plots
for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    # Boolean mask instead of query(): safe for labels containing quotes
    df_immunization = df_combined[df_combined['immunization'] == immunization]

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Sample_Count",  # letter height = number of supporting barcodes
        color_col="color",
        title=f"{immunization} logoplot (Occurrence-based)",
        addbreaks=True
    )

    sm = plt.cm.ScalarMappable(cmap="Reds", norm=color_norm)
    sm.set_array([])  # Required for the colorbar to work
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('Log10(Enrichment Ratio + 1)', rotation=270, labelpad=15)

    ax.set_ylabel('IgG Secreting cell [n]')

    # One file per immunization, saved before plt.show() (the inline backend releases the
    # figure on show). The previous extra save used a stale global `barcode` in its file
    # name, so every iteration overwrote the same file.
    plot_filename = os.path.join(output_dir, f"{immunization}_logoplot2.png")
    fig.savefig(plot_filename, format='png', dpi=300, bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    plt.show()
    plt.close(fig)

In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np

# Aggregate enrichment ratio per barcode (sum over rows sharing position, amino acid,
# immunization and barcode)
df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

def enrichment_color(enrichment):
    """Map an enrichment ratio to an RGBA colour from matplotlib's 'Reds' colormap.

    The colour position is log10(enrichment + 1) / log10(3000 + 1), clipped to [0, 1]:
    small ratios get the lightest reds and ratios >= 3000 the darkest red.
    """
    max_enrichment = 3000  # ratio mapped to the top of the colour scale
    log_enrichment = np.log10(enrichment + 1)  # +1 keeps 0 finite
    max_log_enrichment = np.log10(max_enrichment + 1)
    norm_value = np.clip(log_enrichment / max_log_enrichment, 0, 1)
    return plt.cm.Reds(norm_value)

df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()

# Filter: Only keep amino acids with Enrichment_Ratio > 3
df_filtered = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 3]

# Per (position, amino acid, immunization): number of barcodes passing the threshold and the
# largest enrichment among them, from one groupby so the columns are aligned by key.
df_combined = (
    df_filtered
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False)
    .agg(
        Barcode_Count=('Enrichment_Ratio', 'size'),
        Max_Enrichment=('Enrichment_Ratio', 'max'),
    )
)
df_combined['color'] = df_combined['Max_Enrichment'].apply(enrichment_color)

df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# Select specific sites from Code 1 to plot (for example: positions 417, 439, 440, etc.)
sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]

# Filter the df_combined to include only these sites
df_combined = df_combined[df_combined['Spike_AS_Position'].isin(sites_to_show)]

# Create a directory to save the PNG files (if it doesn't already exist)
output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

# Colorbar normalisation matching enrichment_color (0 .. log10(3000 + 1)).
color_norm = plt.Normalize(vmin=0, vmax=np.log10(3000 + 1))

# Generate and save logo plots
for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    df_immunization = df_combined[df_combined['immunization'] == immunization]

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Barcode_Count",  # letter height = number of supporting barcodes
        color_col="color",
        title=f"{immunization} logoplot (Barcode Count-based)",
        addbreaks=True
    )

    sm = plt.cm.ScalarMappable(cmap="Reds", norm=color_norm)
    sm.set_array([])  # Required for the colorbar to work
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('Log10(Enrichment Ratio + 1)', rotation=270, labelpad=15)

    ax.set_ylabel('Number of Barcodes')

    # Save BEFORE plt.show(): the inline backend releases the figure on show, and a
    # subsequent plt.savefig writes a new, blank figure.
    plot_filename = os.path.join(output_dir, f"{immunization}_logoplot.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    plt.show()
    plt.close(fig)


In [None]:
import random
import matplotlib.pyplot as plt
import os
import numpy as np

# Aggregate enrichment ratio per barcode (sum over rows sharing position, amino acid,
# immunization and barcode)
df_logo_agg = df_total.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization', 'barcode'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

def enrichment_color(enrichment):
    """Map an enrichment ratio to an RGBA colour from matplotlib's 'Reds' colormap.

    The colour position is log10(enrichment + 1) / log10(3000 + 1), clipped to [0, 1]:
    small ratios get the lightest reds and ratios >= 3000 the darkest red.
    """
    max_enrichment = 3000  # ratio mapped to the top of the colour scale
    log_enrichment = np.log10(enrichment + 1)  # +1 keeps 0 finite
    max_log_enrichment = np.log10(max_enrichment + 1)
    norm_value = np.clip(log_enrichment / max_log_enrichment, 0, 1)
    return plt.cm.Reds(norm_value)

df_logo_agg['Amino_Acid'] = df_logo_agg['Amino_Acid'].str.upper()

# Filter: Only keep amino acids with Enrichment_Ratio > 3
df_filtered = df_logo_agg[df_logo_agg['Enrichment_Ratio'] > 3]

# Per (position, amino acid, immunization): number of barcodes passing the threshold and the
# largest enrichment among them, from one groupby so the columns are aligned by key.
df_combined = (
    df_filtered
    .groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False)
    .agg(
        Barcode_Count=('Enrichment_Ratio', 'size'),
        Max_Enrichment=('Enrichment_Ratio', 'max'),
    )
)
df_combined['color'] = df_combined['Max_Enrichment'].apply(enrichment_color)

df_combined = df_combined.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str)
)

# Select specific sites from Code 1 to plot (for example: positions 417, 439, 440, etc.)
sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]

# Filter the df_combined to include only these sites
df_combined = df_combined[df_combined['Spike_AS_Position'].isin(sites_to_show)]

# Create a directory to save the PNG files (if it doesn't already exist)
output_dir = "logo_plots"
os.makedirs(output_dir, exist_ok=True)

# Colorbar normalisation matching enrichment_color (0 .. log10(3000 + 1)).
color_norm = plt.Normalize(vmin=0, vmax=np.log10(3000 + 1))

# Generate and save logo plots
for immunization in df_combined['immunization'].unique():
    print(f"Generating plot for {immunization}...")

    df_immunization = df_combined[df_combined['immunization'] == immunization]

    fig, ax = dmslogo.draw_logo(
        df_immunization,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Barcode_Count",  # letter height = number of supporting barcodes
        color_col="color",
        title=f"{immunization} logoplot (Barcode Count-based)",
        addbreaks=True
    )

    sm = plt.cm.ScalarMappable(cmap="Reds", norm=color_norm)
    sm.set_array([])  # Required for the colorbar to work
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
    cbar.set_label('Log10(Enrichment Ratio + 1)', rotation=270, labelpad=15)

    # Y axis = number of barcodes (single droplets) supporting each amino acid
    ax.set_ylabel('Single Droplet')

    # Save BEFORE plt.show() (otherwise a blank figure is written). The "_single_droplet"
    # suffix keeps these files from overwriting the Barcode-count plots of the previous cell,
    # which use the same "{immunization}_logoplot.png" names in the same directory.
    plot_filename = os.path.join(output_dir, f"{immunization}_logoplot_single_droplet.png")
    fig.savefig(plot_filename, format='png', bbox_inches='tight')
    print(f"Plot saved as {plot_filename}")

    plt.show()
    plt.close(fig)


Generating logoplots


The following code focuses on escape mutations (i.e., enrichment ratios < 1). To visualize the most de-enriched positions, we invert the ratios (1 / ratio) so that the lowest enrichment ratios receive the highest values.

In [None]:
# Only non-synonymous mutations are considered: given the sensitivity of the analysis,
# synonymous changes are unlikely to be of interest in an escape setting.
df_escape = df_total[(df_total['Enrichment_Ratio'] < 1) & (df_total['Type_of_Mutation'] == 'NON-SYNOM')]

# Remove the first 34 positions (Spike position <= 331 + 34) due to bad read quality.
# .copy() gives an independent frame, so adding a column below is not a chained assignment.
df_escape = df_escape[df_escape['Spike_AS_Position'] > 34 + 331].copy()

# Invert the ratios so the most de-enriched (escape) mutations get the largest values.
# NOTE(review): a ratio of exactly 0 is left at 0 rather than inverted, so complete
# depletion shows up as no signal - confirm this is intended.
df_escape['Enrichment_Ratio_inverted'] = df_escape['Enrichment_Ratio'].apply(lambda x: 1 / x if x != 0 else x)

df_escape_agg = df_escape.groupby(['Spike_AS_Position', 'Amino_Acid', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio_inverted': 'sum'
})

# Sites highlighted in the line plots.
# NOTE(review): an earlier note mentions discarding library-enriched positions 33, 72, 81
# and 151, but that exclusion is not applied here.
sites_to_show_escape = (
    [455, 456, 472, 473, 484, 485, 486, 490, 496, 499]  # RBD-ACE2 interface according to article
    + list(range(394, 414))  # R21 peptide sequence with high affinity
    + list(range(484, 503))  # R13 peptide sequence with high affinity
)
# A plain list (not a one-shot map iterator), compared numerically: comparing string forms
# would match nothing if Spike_AS_Position were stored as float ("455.0").
df_escape_agg = df_escape_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show_escape),
)

output_dir = "immunization_csv_files"
os.makedirs(output_dir, exist_ok=True)

# show_site depends only on position: one flag per position for the merge below
# (merging the full table would duplicate rows that then have to be dropped again).
site_flags = df_escape_agg[['Spike_AS_Position', 'show_site']].drop_duplicates(subset=['Spike_AS_Position'])

for immunization in df_escape_agg['immunization'].unique():
    print(immunization)
    # Sum the inverted ratios over amino acids so each position appears exactly once.
    df_filtered_escape = (
        df_escape_agg[df_escape_agg['immunization'] == immunization]
        .groupby('Spike_AS_Position', as_index=False)
        .agg({'Enrichment_Ratio_inverted': 'sum'})
    )

    csv_file_path = os.path.join(output_dir, f"{immunization}_escape_data.csv")
    df_filtered_escape.to_csv(csv_file_path, index=False)

    df_filtered_escape = df_filtered_escape.merge(site_flags, on='Spike_AS_Position', how='left')

    # draw_line needs an unbroken run of integer positions: reindex over the full range.
    all_positions = range(
        int(df_filtered_escape['Spike_AS_Position'].min()),
        int(df_filtered_escape['Spike_AS_Position'].max()) + 1,
    )
    df_filtered_escape = (
        df_filtered_escape.set_index('Spike_AS_Position')
        .reindex(all_positions)
        .rename_axis('Spike_AS_Position')
        .reset_index()
    )
    # Gap positions: zero signal in the plotted column itself, and not highlighted.
    df_filtered_escape['Enrichment_Ratio_inverted'] = df_filtered_escape['Enrichment_Ratio_inverted'].fillna(0)
    df_filtered_escape['show_site'] = df_filtered_escape['show_site'].fillna(False).astype(bool)

    fig, ax = dmslogo.line.draw_line(
        df_filtered_escape,
        x_col="Spike_AS_Position",
        height_col="Enrichment_Ratio_inverted",
        title=immunization + ' lineplot',
        xlabel="Spike AA Position",
        # Plotted values are summed inverted ratios (no log2 transform is applied).
        ylabel="Inverted Enrichment Ratio (sum)",
        show_col="show_site"
    )






In [None]:
# Depleted variants (Enrichment_Ratio < 1) of any mutation type. .copy() gives an independent
# frame, so the added column is not a chained assignment on a slice of df_total.
df_escape = df_total[df_total['Enrichment_Ratio'] < 1].copy()

# Invert so stronger depletion -> larger value (a ratio of exactly 0 is left at 0).
df_escape['Enrichment_Ratio_inverted'] = df_escape['Enrichment_Ratio'].apply(lambda x: 1 / x if x != 0 else x)

# Sum per position / amino acid / barcode / immunization, then take log2
# (non-positive sums are kept as-is).
df_escape_agg = df_escape.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio_inverted': 'sum'
})
df_escape_agg['Enrichment_Ratio_log2'] = df_escape_agg['Enrichment_Ratio_inverted'].apply(lambda x: np.log2(x) if x > 0 else x)

# Sites to display (RBD-ACE2 interface according to article). A plain list, compared
# numerically, so it can be reused and works whether positions are int or float.
# NOTE(review): the last entry was `50`; every other copy of this list ends in 505, so it is
# treated as a typo - confirm.
sites_to_show = [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]
df_escape_agg = df_escape_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].isin(sites_to_show),
)

# To restrict to one condition, filter df_escape_agg first,
# e.g. df_escape_agg.query('immunization == "Mutant_RBD"').
for barcode in df_escape_agg['barcode'].unique():
    print(barcode)
    df_filtered = df_escape_agg[(df_escape_agg['barcode'] == barcode) & df_escape_agg['show_site']]
    if df_filtered.empty:
        continue

    # Title from this barcode's own immunization(s), not a variable left over from another cell.
    barcode_immunization = ', '.join(map(str, df_filtered['immunization'].unique()))
    fig, ax = dmslogo.draw_logo(
        df_filtered,
        x_col="Spike_AS_Position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio_log2",
        title=barcode + 'Escape' + barcode_immunization,
        addbreaks=True,
        heightscale=0.8
    )

    ax.set_title("")  # blank the title
    ax.set_ylabel("Escape [Log2]")
    # Show and release each figure so one-per-barcode plots do not accumulate in memory.
    plt.show()
    plt.close(fig)



In [None]:
df_escape = df_total[df_total['Enrichment_Ratio'] < 1]

df_escape['Enrichment_Ratio_inverted'] = df_escape['Enrichment_Ratio'].apply(lambda x: 1 / x if x != 0 else x)

df_escape_agg = df_escape.groupby(['Spike_AS_Position', 'Amino_Acid','barcode','immunization'], as_index=False).agg({
    'Enrichment_Ratio_inverted': 'sum'
})

df_escape_agg['Enrichment_Ratio_log2'] = df_escape_agg['Enrichment_Ratio_inverted'].apply(lambda x: np.log2(x) if x > 0 else x)

sites_to_show = map(
    str,
    [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 50]  # RBD-ACE2 interface according to article
)

df_escape_agg = df_escape_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Add the following before the "barcode.unique" to select various conditions
# .query('immunization == "Mutant_RBD"')

for barcode in df_escape_agg['barcode'].unique():
    print(barcode)
    
    # Filter based on barcode and selected sites
    df_filtered = df_escape_agg.query(f'barcode == "{barcode}"').query("show_site")
    
    if not df_filtered.empty:
        # Exclude stop codons before plotting
        df_filtered = df_filtered[~df_filtered['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]

        fig, ax = dmslogo.draw_logo(
            df_filtered,
            x_col="Spike_AS_Position",
            letter_col="Amino_Acid",
            letter_height_col="Enrichment_Ratio_log2",
            title=barcode + '',
            addbreaks=True,
            heightscale=0.8
        )
        
        ax.set_title(f"")  # Change the title
        ax.set_ylabel("Escape [Log2]")  # Change the y-axis label


In [None]:
plt.ion()
df_escape = df_total[df_total['Enrichment_Ratio'] < 1]

df_escape['Enrichment_Ratio_inverted'] = df_escape['Enrichment_Ratio'].apply(lambda x: 1 / x if x != 0 else x)

df_escape_agg = df_escape.groupby(['Spike_AS_Position', 'Amino_Acid','barcode','immunization'], as_index=False).agg({
    'Enrichment_Ratio_inverted': 'first'
})

df_escape_agg['Enrichment_Ratio_log2'] = df_escape_agg['Enrichment_Ratio_inverted'].apply(lambda x: np.log2(x) if x > 0 else x)

sites_to_show = map(
    str,
    [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 50]  # RBD-ACE2 interface according to article
)

df_escape_agg = df_escape_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Add the following before the "barcode.unique" to select various conditions
# .query('immunization == "Mutant_RBD"')

for barcode in df_escape_agg['barcode'].unique():
    print(barcode)
    
    # Filter based on barcode and selected sites
    df_filtered = df_escape_agg.query(f'barcode == "{barcode}"').query("show_site")
    
    if not df_filtered.empty:
        # Exclude stop codons before plotting
        df_filtered = df_filtered[~df_filtered['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]
        
        # Extract immunization condition from the filtered DataFrame
        immunization_condition = df_filtered['immunization'].iloc[0]  # Assuming all rows for the barcode have the same immunization

        fig, ax = dmslogo.draw_logo(
            df_filtered,
            x_col="Spike_AS_Position",
            letter_col="Amino_Acid",
            letter_height_col="Enrichment_Ratio_log2",
            title=barcode + 'n \n n',
            addbreaks=True,
            heightscale=0.8
        )
        
        ax.set_title(f"")  # Change the title
        ax.set_ylabel("Escape [Log2]")  # Change the y-axis label
        

        
        # Add immunization info to the figure text
        fig.text(1.05, 0.5, f"Barcode: {barcode} and Immunization: {immunization_condition}", ha='left', va='center', fontsize=14, rotation=90)
        immunization_condition = df_filtered['immunization'].iloc[0]  # Assuming the immunization condition is the same for all rows in this barcode group
        filename = f"C:/Users/lschlotheube/Desktop/Thesis/LogoEscape/{barcode}_{immunization_condition}.png"
        plt.savefig(filename, bbox_inches='tight')
        plt.draw()


In [None]:
import matplotlib.pyplot as plt

plt.ion()
df_escape = df_total[df_total['Enrichment_Ratio'] < 1]

df_escape['Enrichment_Ratio_inverted'] = df_escape['Enrichment_Ratio'].apply(lambda x: 1 / x if x != 0 else x)

df_escape_agg = df_escape.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio_inverted': 'sum'
})

df_escape_agg['Enrichment_Ratio_log2'] = df_escape_agg['Enrichment_Ratio_inverted'].apply(lambda x: np.log2(x) if x > 0 else x)

sites_to_show = map(
    str,
    [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 50]  # RBD-ACE2 interface according to article
)

df_escape_agg = df_escape_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Add the following before the "barcode.unique" to select various conditions
# .query('immunization == "Mutant_RBD"')

for barcode in df_escape_agg['barcode'].unique():
    print(barcode)
    
    # Filter based on barcode and selected sites
    df_filtered = df_escape_agg.query(f'barcode == "{barcode}"').query("show_site")
    
    if not df_filtered.empty:
        # Exclude stop codons before plotting
        df_filtered = df_filtered[~df_filtered['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]
        
        # Create the plot
        fig, ax = dmslogo.draw_logo(
            df_filtered,
            x_col="Spike_AS_Position",
            letter_col="Amino_Acid",
            letter_height_col="Enrichment_Ratio_log2",
            title='',  # Empty title for now
            addbreaks=True,
            heightscale=0.8
        )
        
        ax.set_ylabel("Escape [Log2]")  # Set the y-axis label
        ax.set_xlabel("SARS-Cov-2 Spike AA Position")  
        
        # Add title to the right side of the plot using fig.text()
        fig.text(1.05, 0.5, f"Barcode: {barcode}", ha='left', va='center', fontsize=14, rotation=90)
        
        # Save the figure
        immunization_condition = df_filtered['immunization'].iloc[0]  # Assuming the immunization condition is the same for all rows in this barcode group
        filename = f"C:/Users/lschlotheube/Desktop/Thesis/LogoEscape/{barcode}_{immunization_condition}.png"
        plt.savefig(filename, bbox_inches='tight')  # Save with tight bounding box
        
        plt.draw()
        plt.pause(0.1)  # Give the plot time to render
        plt.show()

        # Optionally close the figure to free memory
        plt.close(fig)


In [None]:
import matplotlib.pyplot as plt

plt.ion()


# No need for inversion or log2 transformation anymore
df_binding_agg = df_binding.groupby(['Spike_AS_Position', 'Amino_Acid', 'barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'sum'
})

# Filter for sites to show
sites_to_show = map(
    str,
    [417, 439, 440, 452, 476, 477, 484, 493, 501, 502, 505]  # RBD-ACE2 interface according to article
)

df_binding_agg = df_binding_agg.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["Spike_AS_Position"].astype(str),
    show_site=lambda x: x["Spike_AS_Position"].astype(str).isin(sites_to_show)
)

# Add the following before the "barcode.unique" to select various conditions
# .query('immunization == "Mutant_RBD"')

for barcode in df_binding_agg['barcode'].unique():
    print(barcode)
    
    # Filter based on barcode and selected sites
    df_filtered = df_binding_agg.query(f'barcode == "{barcode}"').query("show_site")
    df_filtered = df_filtered[~df_filtered['Amino_Acid'].str.contains(r'[\*\-]', regex=True)]
    if not df_filtered.empty:
        # Exclude stop codons before plotting
        df_filtered['Amino_Acid'] = df_filtered['Amino_Acid'].str.upper()
        # Create the plot
        print(f"Enrichment_Ratio Range for Barcode {barcode}: {df_filtered['Enrichment_Ratio'].min()} to {df_filtered['Enrichment_Ratio'].max()}")
        print(f"Enrichment_Ratio Distribution for Barcode {barcode} (first 10 values): {df_filtered['Enrichment_Ratio'].head(10).values}")
        
        fig, ax = dmslogo.draw_logo(
            df_filtered,
            x_col="Spike_AS_Position",
            letter_col="Amino_Acid",
            letter_height_col="Enrichment_Ratio",  # Plot the actual ratio (no log or inversion)
            title='',  # Empty title for now
            addbreaks=True,
            heightscale=0.8
        )
        
        ax.set_ylabel("Binding Ratio")  # Set the y-axis label to Binding Ratio
        
        # Add title to the right side of the plot using fig.text()
        fig.text(1.05, 0.5, f"Barcode: {barcode}", ha='left', va='center', fontsize=14, rotation=90)
        
        # Save the figure
        immunization_condition = df_filtered['immunization'].iloc[0]  # Assuming the immunization condition is the same for all rows in this barcode group
        filename = f"C:/Users/lschlotheube/Desktop/Thesis/LogoBinding/{barcode}_{immunization_condition}.png"
        plt.savefig(filename, bbox_inches='tight')  # Save with tight bounding box
        
        plt.draw()
        plt.pause(0.1)  # Give the plot time to render
        plt.show()

        # Optionally close the figure to free memory
        plt.close(fig)


The following cell can generate logoplots for sections which corresponds to the produced peptides

In [None]:
df_logo_agg_test = df_total.groupby(['DMS_RBD_AS_position', 'Amino_Acid','barcode', 'immunization'], as_index=False).agg({
    'Enrichment_Ratio': 'max'
})


sites_to_show_test = list(
    map(
        str,
        list(range(73, 79)) +  # RBD-ACE2 interface according to article
        list(range(87, 91)) +  # R21 peptide sequence with high affinity
        list(range(125, 131)) +  # R13 peptide sequence with high affinity
        list(range(143, 146)) +
        list(range(157, 164)) +
        list(range(174, 176))
    )
)
df_logo_agg_test = df_logo_agg_test.assign(
    site_label=lambda x: x["Amino_Acid"] + "_" + x["DMS_RBD_AS_position"].astype(str),
    show_site=lambda x: x["DMS_RBD_AS_position"].astype(str).isin(sites_to_show_test)
)

#The query can be changed to filter for specific barcodes or removed to get all barcodes
#('immunization == "wildtype_RBD"')
#
for barcode in df_logo_agg_test.query('immunization == "wildtype_RBD"')['barcode'].unique():
    print(barcode)
    fig, ax = dmslogo.draw_logo(
        df_logo_agg_test.query(f'barcode == "{barcode}"').query("show_site"),
        x_col="DMS_RBD_AS_position",
        letter_col="Amino_Acid",
        letter_height_col="Enrichment_Ratio",
        title=barcode + ' logoplot',
        addbreaks=True
    )





Extract the most enriched amino acid for each position for a given barcode. To use for AlphaFold structure prediction