In [None]:
# === Imports ===
import os
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from ast import literal_eval
from scipy.stats import mannwhitneyu
from statannotations.Annotator import Annotator
from statsmodels.stats.multitest import multipletests

# === Setup ===
# Absolute, normalised working directory; all data paths below are built from it.
curr_wd = os.path.abspath(os.curdir)
print(f"Current working directory: {curr_wd}")

# === Utility Functions ===
def return_chop_seq(seq, list_of_slices):
    """Concatenate the pieces of ``seq`` selected by half-open ``(start, end)`` pairs, in order."""
    pieces = [seq[lo:hi] for lo, hi in list_of_slices]
    return "".join(pieces)

def combine_dicts(dicts_list):
    """Merge dictionaries key-wise, collecting every value of a key into a list.

    Keys keep their first-seen order; each list holds the values in the order
    the dictionaries were given.
    """
    merged = {}
    for source in dicts_list:
        for key in source:
            if key not in merged:
                merged[key] = []
            merged[key].append(source[key])
    return merged

def count_consecutive_stretches_of_1(lst):
    """Return the runs of consecutive 1's in ``lst``.

    Each run is reported as a half-open ``(start, stop, "IDR")`` tuple; a run
    that reaches the end of the list stops at ``len(lst)``.
    """
    stretches = []
    run_start = None  # index where the currently open run began, if any
    for pos, val in enumerate(lst):
        if val == 1:
            if run_start is None:
                run_start = pos
            continue
        if run_start is not None:
            stretches.append((run_start, pos, "IDR"))
            run_start = None
    if run_start is not None:
        stretches.append((run_start, len(lst), "IDR"))
    return stretches

def is_either_between(low_range, high_range, a, b):
    """Return True if ``a`` or ``b`` lies in the closed interval ``[low_range, high_range]``."""
    return any(low_range <= point <= high_range for point in (a, b))

def assign_group(protein, group1_name, group1_string, group2_name, group2_string):
    """Return the label of the first of two groups (checked in order) containing ``protein``.

    ``group*_name`` are member containers tested with ``in``; ``group*_string``
    are the labels returned on a match. Returns 'Not in any group' otherwise.
    """
    for members, label in ((group1_name, group1_string), (group2_name, group2_string)):
        if protein in members:
            return label
    return 'Not in any group'

def assign_group_4(protein, group1_name, group1_string, group2_name, group2_string, group3_name, group3_string, group4_name, group4_string):
    """Return the label of the first of four groups (checked in order) containing ``protein``.

    Each ``groupN_name`` is a member container tested with ``in``, and
    ``groupN_string`` is the label returned on a match. Returns
    'Not in any group' when no group matches.
    """
    ordered_groups = (
        (group1_name, group1_string),
        (group2_name, group2_string),
        (group3_name, group3_string),
        (group4_name, group4_string),
    )
    return next(
        (label for members, label in ordered_groups if protein in members),
        'Not in any group',
    )

def create_boxplot_with_dots(data, hue_order, scatter=True, **kwargs):
    """Draw pos-vs-neg boxplots per metric, optionally overlay the points, and add Mann-Whitney stars.

    Parameters
    ----------
    data : pandas.DataFrame
        Long-format frame with a 'data' value column, a 'Group' column
        ("pos"/"neg") and a 'metric' column used as hue.
    hue_order : list
        Metrics to draw, in order. Each gets its own pos-vs-neg comparison.
    scatter : bool
        If True, overlay the individual values as a dodged strip plot.
    **kwargs
        Forwarded to ``sns.boxplot``.

    Returns
    -------
    (fig, annotator)
        The figure is closed before returning, so it is not auto-displayed.
    """
    groups = ["pos", "neg"]
    fig = plt.figure(figsize=(12, 6))
    ax = sns.boxplot(data=data, y='data', x='Group', hue='metric', width=0.8,
                     order=groups, hue_order=hue_order,
                     showfliers=False, ax=plt.gca(), zorder=5, **kwargs)

    if scatter:
        sns.stripplot(data=data, y='data', x='Group', hue='metric', color='black',
                      order=groups, hue_order=hue_order,
                      size=5, jitter=0.1, alpha=0.3, dodge=True, zorder=10, ax=ax)

    # One pos-vs-neg comparison for every metric shown
    comparisons = []
    for metric_name in hue_order:
        comparisons.append([('pos', metric_name), ('neg', metric_name)])

    annotator = Annotator(ax, comparisons, data=data, y='data', x='Group', hue='metric')
    annotator.configure(test='Mann-Whitney', text_format='star', loc='outside',
                        hide_non_significant=False, verbose=0)
    annotator.apply_and_annotate()

    ax.set_xlabel("GAR positive group vs negative group")
    ax.set_xlim(-0.5, 1.5)
    ax.legend().remove()
    ax.grid(axis='y', zorder=0)
    ax.set_axisbelow(True)
    plt.close(fig)

    return fig, annotator

def mult_count(seq, list_of_letters):
    """Return the summed number of occurrences in ``seq`` of every entry of ``list_of_letters``."""
    return sum(map(seq.count, list_of_letters))

def pval_to_asterisk(value):
    """Map a p-value to a significance label.

    Returns 'ns' for p > 0.05 or NaN, otherwise '*' (p <= 0.05), '**' (<= 0.01),
    '***' (<= 0.001) or '****' (<= 0.0001).
    """
    if value > 0.05 or np.isnan(value):
        return 'ns'
    # Walk down the cut-offs; the first one the value still exceeds decides the label
    for cutoff, stars in ((0.01, '*'), (0.001, '**'), (0.0001, '***')):
        if value > cutoff:
            return stars
    return "****"

# === Constants ===
# Amino-acid classes -> one-letter residues they contain (order defines aa_types_all)
aa_type_to_aa_dict = dict(
    aliphatic="GAVLMI",
    aromatic="FYW",
    pos_charged="KRH",
    neg_charged="DE",
    uncharged="STCNPQ",
    aliphatic_noG="AVLMI",
    pos_charged_noR="KH",
)
aa_types_all = [*aa_type_to_aa_dict]
# The 20 standard amino acids in alphabetical one-letter order
aa_all = [*"ACDEFGHIKLMNPQRSTVWY"]

# === Load Data ===
motif_info_set_df = pd.read_parquet(os.path.join(
    curr_wd, 'data/processed/GAR_motif_Wang_set_human_cleaned_annot_filtered.parquet'))
annotated_IDR_df = pd.read_parquet(os.path.join(
    curr_wd, 'data/processed/annotation_datasets/all_IDR_human.parquet'))
annotated_domain_df = pd.read_parquet(os.path.join(
    curr_wd, 'data/processed/annotation_datasets/all_domains_human.parquet'))
annotated_PTMs_df = pd.read_parquet(os.path.join(
    curr_wd, 'data/processed/annotation_datasets/all_PTMs_human.parquet'))

# PTM name -> residue: the 'AA' of the first distinct (AA, ptm) row seen for each PTM
df2 = annotated_PTMs_df.dropna(subset=['AA', 'ptm']).drop_duplicates(subset=['AA', 'ptm'])
first_aa_per_ptm = df2.drop_duplicates(subset='ptm')
ptm_to_aa_dict = dict(zip(first_aa_per_ptm['ptm'], first_aa_per_ptm['AA']))
print(ptm_to_aa_dict)


In [None]:
ver = "v3"

# Named protein sets -> stems of list files (data/processed/final_set_lists/<stem>.txt);
# the stripped lines of all listed files are concatenated to form each set.
set_definitions = {
    "GAR_full": ["GAR_subset_full"],
    "GAR_LLPS_pos": [
        "4_LLPS_positive_set_and_GAR_subset",
        "5_LLPS_positive_set_and_NA_positive_set_and_GAR_subset"
    ],
    "GAR_LLPS_pos_NA_neg": ["4_LLPS_positive_set_and_GAR_subset"],
    "GAR_LLPS_neg": [
        "6_NA_positive_set_and_GAR_subset",
        "7_GAR_subset_only"
    ],
    "GAR_LLPS_neg_NA_pos": ["6_NA_positive_set_and_GAR_subset"],
    "GAR_NA_pos": [
        "5_LLPS_positive_set_and_NA_positive_set_and_GAR_subset",
        "6_NA_positive_set_and_GAR_subset"
    ],
    "GAR_NA_neg": [
        "4_LLPS_positive_set_and_GAR_subset",
        "7_GAR_subset_only"
    ],
    "GAR_pos": ["5_LLPS_positive_set_and_NA_positive_set_and_GAR_subset"],
    "GAR_neg": ["7_GAR_subset_only"],
}

set_dict = {}            # set name -> list of protein IDs (file order, duplicates kept)
set_list = []            # the same list objects, in set_dict insertion order
proteins_sets_dict = {}  # version tag -> set_dict

for set_name, file_names in set_definitions.items():
    proteins = []
    for fname in file_names:
        file_path = f"{curr_wd}/data/processed/final_set_lists/{fname}.txt"
        with open(file_path, "r") as f:
            proteins.extend(line.strip() for line in f)
    set_dict[set_name] = proteins
    set_list.append(proteins)

# Full proteome: every stripped line of the file is taken as one protein ID
full_proteome_path = f"{curr_wd}/data/processed/list_of_human_proteins.csv"
with open(full_proteome_path, "r") as f:
    full_proteome = [line.strip() for line in f]
set_dict["full_proteome"] = full_proteome
set_list.append(full_proteome)

set_names = list(set_dict.keys())
proteins_sets_dict[ver] = set_dict

my_pos_group = "GAR_pos"
my_neg_group = "GAR_neg"

GAR_LLPS_pos_NA_neg_group = "GAR_LLPS_pos_NA_neg"
GAR_LLPS_neg_NA_pos_group = "GAR_LLPS_neg_NA_pos"

# Label each motif row with the first matching group. assign_group_4 tests
# membership with `in` once per row, so pass sets (O(1)) rather than lists (O(n)).
motif_info_set_df['Group'] = motif_info_set_df['UniqueID'].apply(
    assign_group_4,
    args=(
        set(set_dict[my_pos_group]), "pos",
        set(set_dict[my_neg_group]), "neg",
        set(set_dict[GAR_LLPS_pos_NA_neg_group]), "LLPS_pos_NA_neg",
        set(set_dict[GAR_LLPS_neg_NA_pos_group]), "LLPS_neg_NA_pos",
    ),
)

In [None]:
def _expand_idr_run(idr_info, anchor):
    """Return half-open ``(lo, hi)`` bounds of the run of 1's that contains ``idr_info[anchor]``.

    The caller guarantees ``idr_info[anchor] == 1``.
    """
    lo = anchor
    while lo > 0 and idr_info[lo - 1] == 1:
        lo -= 1
    hi = anchor
    while hi < len(idr_info) and idr_info[hi] == 1:
        hi += 1
    return lo, hi


def _bound_holds_motif_end(bound, first, last):
    """Return True if 0-based residue ``first`` or ``last`` falls inside half-open ``bound``."""
    lo, hi = bound
    return is_either_between(lo, hi - 1, first, last)


def _find_motif_idr_bounds(idr_info, motif_df):
    """Return the disordered runs that contain the first or last residue of some motif.

    Motif ``start``/``end`` are treated as 1-based inclusive coordinates, matching
    the per-motif slicing ``full_sequence[start - 1:end]`` used below. Each run is
    returned once, as half-open 0-based ``(lo, hi)`` bounds.
    """
    bounds = []
    n = len(idr_info)
    for motif_start, motif_end in motif_df[["start", "end"]].itertuples(index=False):
        first, last = int(motif_start) - 1, int(motif_end) - 1
        if not 0 <= first < n:
            continue  # motif starts outside the annotated region (continue: rows may be unsorted)
        if any(_bound_holds_motif_end(b, first, last) for b in bounds):
            continue  # run already recorded through an earlier motif
        if idr_info[first] == 1:
            anchor = first
        elif 0 <= last < n and idr_info[last] == 1:
            anchor = last
        else:
            continue  # neither motif end is disordered
        bounds.append(_expand_idr_run(idr_info, anchor))
    return bounds


def _motif_idr_bound(bounds, first, last):
    """Return the recorded run holding the motif's first/last residue, or None if there is none."""
    return next((b for b in bounds if _bound_holds_motif_end(b, first, last)), None)


protein_dict_with_aa_metrics = {}

# Index the motif and IDR tables by protein once: boolean-filtering the full
# frames for every proteome entry is quadratic. Groups keep the original row
# order and index labels.
motif_df_by_protein = dict(tuple(motif_info_set_df.groupby("UniqueID", sort=False)))
idr_df_by_protein = dict(tuple(annotated_IDR_df.groupby("protein_name", sort=False)))

proteome_ids = proteins_sets_dict['v3']['full_proteome']
for i, curr_protein in enumerate(proteome_ids):
    # Skip proteins without motif info
    curr_motif_df = motif_df_by_protein.get(curr_protein)
    if curr_motif_df is None or curr_motif_df.empty:
        continue

    full_sequence = curr_motif_df['full_seq'].iloc[0]
    prot_length = len(full_sequence)

    # Skip proteins without IDR info
    curr_idr_df = idr_df_by_protein.get(curr_protein)
    if curr_idr_df is None or curr_idr_df.empty:
        continue

    IDR_info = curr_idr_df["prediction-disorder-mobidb_lite"].iloc[0]
    # Normalise to a plain list *before* validating. Parquet yields ndarrays, and a
    # list-only check lets annotations containing -1 slip through.
    if not isinstance(IDR_info, (list, tuple, np.ndarray)):
        continue  # missing / non-sequence annotation
    IDR_info = np.asarray(IDR_info).tolist()
    if -1 in IDR_info:
        continue

    print(i, " out of ", len(proteome_ids))
    print(curr_protein)

    # === Disordered runs: all of them, and those holding a motif end ===
    motif_IDR_bounds = _find_motif_idr_bounds(IDR_info, curr_motif_df)
    all_IDR_bounds = count_consecutive_stretches_of_1(IDR_info)

    protein_entry = {
        'num_of_IDR_regions': len(all_IDR_bounds),
        'num_of_IDR_regions_w_motif': len(motif_IDR_bounds),
        'num_of_IDR_regions_wo_motif': len(all_IDR_bounds) - len(motif_IDR_bounds),
        'motif_IDR_range': motif_IDR_bounds,
        'all_IDR_ranges': all_IDR_bounds,
    }
    protein_dict_with_aa_metrics[curr_protein] = protein_entry

    # === Sequence regions ===
    # Ordered regions are the gaps between runs plus any tail past the last run.
    # "Other" IDRs are the runs not recorded as motif-holding.
    motif_bound_set = set(motif_IDR_bounds)
    non_IDR_bounds, other_IDR_bounds = [], []
    last_end = 0
    for run_start, run_end, _ in all_IDR_bounds:
        if run_start > last_end:
            non_IDR_bounds.append((last_end, run_start))
        last_end = run_end
        if (run_start, run_end) not in motif_bound_set:
            other_IDR_bounds.append((run_start, run_end))
    if last_end < prot_length:
        non_IDR_bounds.append((last_end, prot_length))

    seq_outside_IDR = return_chop_seq(full_sequence, non_IDR_bounds)
    seq_other_IDR = return_chop_seq(full_sequence, other_IDR_bounds)

    # === Per-motif and per-AA analysis ===
    for i2, row2 in curr_motif_df.iterrows():
        motif_seq = row2['motif']
        motif_length = len(motif_seq)
        start = int(row2["start"]) - 1  # 0-based index of the first motif residue
        end = int(row2["end"])          # 0-based exclusive end

        # None when neither motif end lies in a disordered run; the motif-IDR
        # metrics then stay NaN instead of reusing another motif's run.
        motif_bound = _motif_idr_bound(motif_IDR_bounds, start, end - 1)
        motif_idr_seq = (full_sequence[motif_bound[0]:motif_bound[1]]
                         if motif_bound is not None else "")

        # Single flanking residue at distance `step` past each motif end
        # ("" when off the sequence; never wraps around via negative indices).
        flanks = []
        for step in range(1, 11):
            plus_pos, minus_pos = end + step - 1, start - step
            plus_seq = full_sequence[plus_pos:plus_pos + 1]
            minus_seq = full_sequence[minus_pos:minus_pos + 1] if minus_pos >= 0 else ""
            flanks.append((step, plus_seq, minus_seq))

        motif_entry = {}
        protein_entry[i2] = motif_entry
        for aa in aa_all + aa_types_all:
            aa_str = aa_type_to_aa_dict[aa] if aa in aa_types_all else aa
            n_overall = mult_count(full_sequence, aa_str)
            n_motif = mult_count(motif_seq, aa_str)

            aa_metrics = {
                "counts_overall": n_overall,
                "density_overall": n_overall / prot_length if prot_length else np.nan,
                "counts_outside_IDR": np.nan,
                "density_outside_IDR": np.nan,
                "counts_in_other_IDR": np.nan,
                "density_in_other_IDRs": np.nan,
                "counts_in_motif_IDR": np.nan,
                "density_in_motif_IDR": np.nan,
                "counts_in_motif": n_motif,
                "density_in_motif": n_motif / motif_length if motif_length else np.nan,
            }

            # +/- flanking-position metrics
            for step, plus_seq, minus_seq in flanks:
                aa_metrics[f"counts_in_motif+{step}"] = mult_count(plus_seq, aa_str) if plus_seq else np.nan
                aa_metrics[f"counts_in_motif-{step}"] = mult_count(minus_seq, aa_str) if minus_seq else np.nan

            # Region-specific counts. Each density is normalised by the length of the
            # exact sequence that was counted.
            if seq_outside_IDR:
                n_out = mult_count(seq_outside_IDR, aa_str)
                aa_metrics["counts_outside_IDR"] = n_out
                aa_metrics["density_outside_IDR"] = n_out / len(seq_outside_IDR)

            if seq_other_IDR:
                n_other = mult_count(seq_other_IDR, aa_str)
                aa_metrics["counts_in_other_IDR"] = n_other
                aa_metrics["density_in_other_IDRs"] = n_other / len(seq_other_IDR)

            if motif_idr_seq:
                n_midr = mult_count(motif_idr_seq, aa_str)
                aa_metrics["counts_in_motif_IDR"] = n_midr
                aa_metrics["density_in_motif_IDR"] = n_midr / len(motif_idr_seq)

            motif_entry[aa] = aa_metrics


In [None]:
import pandas as pd
import numpy as np
from scipy.stats import mannwhitneyu
from statsmodels.stats.multitest import multipletests

# Flatten protein -> motif -> aa -> metric -> value into long-format rows.
# Protein-level entries that are not per-motif dicts (IDR counts/ranges) are skipped.
records = []
for prot_id, per_motif in protein_dict_with_aa_metrics.items():
    for motif_key, per_aa in per_motif.items():
        if not isinstance(per_aa, dict):
            continue
        for aa_key, metric_values in per_aa.items():
            records.extend(
                (prot_id, motif_key, aa_key, metric_name, value)
                for metric_name, value in metric_values.items()
            )

protein_df = pd.DataFrame(records, columns=['prot', 'motif', 'aa', 'metric', 'data'])
print(len(protein_df))

# Metrics that could not be computed were stored as NaN; drop them
protein_df = protein_df.dropna(subset=['data'])
print(len(protein_df))

# One-letter names are single residues; longer names are residue classes
unique_aas = protein_df['aa'].unique()
aa_all = [name for name in unique_aas if len(name) == 1]
aa_types_all = [name for name in unique_aas if len(name) > 1]
print(aa_all)
print(aa_types_all)

metrics_all = list(protein_df['metric'].unique())
print(metrics_all)

# Define group assignment
def assign_groups_advanced(protein, g1, g1_label, g2, g2_label, g3, g3_label, g4, g4_label):
    """Return the label of the first group among g1..g4 (checked in order) containing ``protein``.

    Returns 'Not in any group' when none matches.
    """
    members_in_order = (g1, g2, g3, g4)
    labels_in_order = (g1_label, g2_label, g3_label, g4_label)
    for members, label in zip(members_in_order, labels_in_order):
        if protein in members:
            return label
    return 'Not in any group'

# Label each protein "pos"/"neg" (uses set_list, set_names, my_pos_group, my_neg_group
# from the set-definition cell). Sets make the per-row membership test O(1).
protein_df['Group'] = protein_df['prot'].apply(
    assign_group, args=(
        set(set_list[set_names.index(my_pos_group)]), "pos",
        set(set_list[set_names.index(my_neg_group)]), "neg",
    )
)

# Infer data types
protein_df = protein_df.infer_objects()

# Split 'data' once into (metric, aa, Group) -> list of values, instead of
# re-masking the whole frame for every metric/aa pair.
values_by_key = {
    key: series.tolist()
    for key, series in protein_df.groupby(['metric', 'aa', 'Group'], sort=False)['data']
}
aa_order = list(protein_df['aa'].unique())

# statistics_dict[metric][aa] = [mean_pos, mean_neg, MWU statistic, raw p, BH-adjusted p]
statistics_dict = {}
for metric in metrics_all:
    metric_stats = {}
    for aa in aa_order:
        pos_data = values_by_key.get((metric, aa, "pos"), [])
        neg_data = values_by_key.get((metric, aa, "neg"), [])

        if pos_data and neg_data:
            mwu_stat, mwu_p = mannwhitneyu(pos_data, neg_data)
            # NOTE(review): each mean gets a pseudocount of 1/len(<other group>), presumably
            # to keep the downstream fold change finite when a mean is 0 — confirm the
            # cross-over of group sizes is intended.
            mean_pos = np.mean(pos_data) + 1 / len(neg_data)
            mean_neg = np.mean(neg_data) + 1 / len(pos_data)
        else:
            # mannwhitneyu raises on an empty sample; mark the comparison as untestable
            mwu_stat = mwu_p = mean_pos = mean_neg = np.nan

        metric_stats[aa] = [mean_pos, mean_neg, mwu_stat, mwu_p]

    # Benjamini-Hochberg across amino acids, applied to testable p-values only
    # (multipletests does not handle NaN inputs); untestable entries stay NaN.
    p_vals = np.array([stats[3] for stats in metric_stats.values()], dtype=float)
    bh_corrected = np.full(p_vals.shape, np.nan)
    testable = np.isfinite(p_vals)
    if testable.any():
        bh_corrected[testable] = multipletests(p_vals[testable], method='fdr_bh')[1]
    for stats, corrected_p in zip(metric_stats.values(), bh_corrected):
        stats.append(corrected_p)

    statistics_dict[metric] = metric_stats

# Output
print(list(statistics_dict.keys()))
print(statistics_dict)


In [None]:
# Persist the long-format metrics table and the statistics for the plotting cells.
# open(..., "wb") fails if the results folder does not exist yet, so create it first.
os.makedirs(os.path.join(curr_wd, "data/results"), exist_ok=True)

output_path = f"{curr_wd}/data/results/proteins_with_aa_metrics_df.pkl"
with open(output_path, "wb") as fp:
    pickle.dump(protein_df, fp)
    print("DF saved successfully to file.")

output_path = f"{curr_wd}/data/results/proteins_with_aa_metrics_statistics_dict.pkl"
with open(output_path, "wb") as fp:
    pickle.dump(statistics_dict, fp)
    print("Dict saved successfully to file.")


In [None]:
import math
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import pandas as pd

# Reload the statistics written by the save cell
stats_pickle = f"{curr_wd}/data/results/proteins_with_aa_metrics_statistics_dict.pkl"
with open(stats_pickle, 'rb') as fp:
    statistics_dict = pickle.load(fp)

# NOTE(review): hard-coded message, not derived from statistics_dict — confirm it is still accurate.
print("missing are: I, H, M, T, Q ,S")

def create_heatmap(
    data,
    metrics_to_portray,
    aa_to_portray=None,
    change_label_dict_y=None,
    change_label_dict_x=None,
    vmax=None,
    cut_ns=True,
    sort_map=True,
    cell_size=10,
    output_data=False
):
    """Plot a log2 fold-change heatmap (pos vs. neg means) annotated with FDR significance.

    Parameters
    ----------
    data : dict
        ``data[metric][aa]`` is a list whose items 0/1 are the pos/neg group means
        and item 4 the BH-adjusted p-value (as built by the statistics cell).
    metrics_to_portray : list of str
        Metric keys drawn as columns, in order.
    aa_to_portray : list of str, optional
        Rows to draw. Defaults to all keys of the first metric in ``data``.
    change_label_dict_y, change_label_dict_x : dict, optional
        Tick-label replacements (original label -> new label).
    vmax : float, optional
        Symmetric colour limit (-vmax..vmax). If None, the largest finite fold
        change is used. Infinite fold changes are clipped to +/-vmax, keeping sign.
    cut_ns : bool
        Drop rows in which every cell is "ns".
    sort_map : bool
        Sort rows by descending summed fold change (NaN cells ignored).
    cell_size : int
        Unused; kept for backward compatibility.
    output_data : bool
        Also return a DataFrame of ``(fold change, adjusted p)`` per cell.

    Returns
    -------
    Figure, or ``(Figure, DataFrame)`` when ``output_data`` is True.

    Raises
    ------
    ValueError
        If no rows remain to plot, or no finite fold change exists to derive ``vmax``.
    """
    change_label_dict_y = change_label_dict_y or {}
    change_label_dict_x = change_label_dict_x or {}

    # Rows: requested amino acids, or every key of the first metric's entry
    if aa_to_portray:
        custom_yticks = list(aa_to_portray)
    else:
        custom_yticks = list(next(iter(data.values())).keys())
    custom_xticks = list(metrics_to_portray)

    # Collect fold changes, significance labels, and adjusted p-values
    fold_changes, sign_labels, num_labels = [], [], []
    for aa in custom_yticks:
        row_fc, row_sign, row_num = [], [], []
        for met in custom_xticks:
            stats = data[met][aa]
            mean_pos, mean_neg, pval_fdr = stats[0], stats[1], stats[4]
            if mean_neg == 0 and mean_pos == 0:
                fold_change = np.log2(1)
            else:
                # A zero mean gives +/-inf (clipped below); silence the numpy warnings
                with np.errstate(divide='ignore', invalid='ignore'):
                    fold_change = np.log2(np.divide(mean_pos, mean_neg))
            row_fc.append(fold_change)
            row_sign.append(pval_to_asterisk(pval_fdr))
            row_num.append(pval_fdr)
        fold_changes.append(row_fc)
        sign_labels.append(row_sign)
        num_labels.append(row_num)

    # Remove rows where all significance labels are "ns"
    if cut_ns:
        keep = [r for r, row in enumerate(sign_labels) if not all(lbl == "ns" for lbl in row)]
        fold_changes = [fold_changes[r] for r in keep]
        sign_labels = [sign_labels[r] for r in keep]
        custom_yticks = [custom_yticks[r] for r in keep]
        num_labels = [num_labels[r] for r in keep]

    if not custom_yticks:
        raise ValueError("create_heatmap: no rows left to plot.")

    fold_changes = np.array(fold_changes, dtype=float)
    sign_labels = np.array(sign_labels)
    num_rows, num_cols = fold_changes.shape

    # Colour limit from the finite fold changes of the rows actually shown
    if vmax is None:
        finite_values = fold_changes[np.isfinite(fold_changes)]
        if finite_values.size == 0:
            raise ValueError("create_heatmap: no finite fold changes to scale the colour map.")
        vmax = np.max(finite_values)
        vmin = np.min(finite_values)
        print(vmax)
        print(vmin)

    # Clip infinite fold changes to the colour limit, keeping their sign
    fold_changes = np.where(np.isposinf(fold_changes), vmax, fold_changes)
    fold_changes = np.where(np.isneginf(fold_changes), -vmax, fold_changes)

    # Optionally sort rows by total fold change (nansum keeps NaN cells from scrambling the order)
    if sort_map:
        order = sorted(range(num_rows), key=lambda r: np.nansum(fold_changes[r]), reverse=True)
        fold_changes = fold_changes[order]
        sign_labels = sign_labels[order]
        custom_yticks = [custom_yticks[r] for r in order]
        num_labels = [num_labels[r] for r in order]

    # Diverging colour map: red (neg > pos) -> white -> green (pos > neg)
    cmap = mcolors.LinearSegmentedColormap.from_list(
        'custom_colormap',
        [(0, '#FF4040'), (0.5, 'white'), (1, '#8DB600')],
        N=256
    )

    # Plot setup
    fig = plt.figure(figsize=(num_cols * 1.305, num_rows * 0.495))
    plt.imshow(
        fold_changes,
        cmap=cmap,
        vmin=-vmax,
        vmax=vmax,
        extent=[0, num_cols, num_rows, 0],
        aspect=0.66
    )

    # Cell annotation with significance stars
    x_positions = np.arange(num_cols) + 0.5
    y_positions = np.arange(num_rows) + 0.5
    for i, y in enumerate(y_positions):
        for j, x in enumerate(x_positions):
            plt.text(x, y, sign_labels[i][j], ha='center', va='center', fontsize=10, color='black')

    # Tick labels
    plt.xticks(x_positions, custom_xticks, rotation=0, ha='right')
    plt.yticks(y_positions, custom_yticks, va='center', size=11)
    plt.gca().tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    # Apply y-axis label replacements
    y_labels = [item.get_text() for item in plt.gca().get_yticklabels()]
    for k, v in change_label_dict_y.items():
        try:
            y_labels[y_labels.index(k)] = v
        except ValueError:
            print(f"Label {k} could not be found!")
    plt.gca().set_yticklabels(y_labels)

    # Apply x-axis label replacements
    x_labels = [item.get_text() for item in plt.gca().get_xticklabels()]
    plt.gca().set_xticklabels([change_label_dict_x.get(lbl, lbl) for lbl in x_labels])

    # Assemble output data: row label -> {metric: (fold change, adjusted p)}
    data_dict = {}
    for i, row in enumerate(fold_changes):
        data_dict[custom_yticks[i]] = {}
        for j, val in enumerate(row):
            data_dict[custom_yticks[i]][custom_xticks[j]] = (val, num_labels[i][j])
    data_df = pd.DataFrame.from_dict(data_dict, orient='index')

    # Axis labels and colorbar
    plt.setp(plt.gca().get_xticklabels(), rotation=0, ha="center", rotation_mode="anchor")
    plt.colorbar(
        fraction=0.25,
        location='right',
        pad=0.001 * len(y_labels)
    ).set_label("log2(fold change of means)")
    plt.xlabel("")
    plt.ylabel("amino acids", size=12, labelpad=20)

    return (fig, data_df) if output_data else fig

# x-axis label replacements for the region metrics
xlabel_dict = {
    "density_overall": "full\nprotein",
    "density_outside_IDR": "ordered\nregions",
    "density_in_other_IDRs": "oIDRs",
    "density_in_motif_IDR": "mIDRs",
    "density_in_motif": "RG-\nmotif"
}

fig4a_metrics = ["density_outside_IDR", "density_in_other_IDRs", "density_in_motif_IDR", "density_in_motif"]
fig4a_aas = ['K', 'Y', 'N', 'D', 'E', 'F', 'R', 'G', 'S', 'H', "I", "M", "Q", "T", 'P', 'V', 'W', 'A', 'C', 'L']

this_fig, this_data = create_heatmap(
    statistics_dict,
    fig4a_metrics,
    fig4a_aas,
    change_label_dict_x=xlabel_dict,
    vmax=None,
    sort_map=True,
    output_data=True,
)

subfigure_dir = os.path.join(curr_wd, "data/results/subfigures/")
os.makedirs(subfigure_dir, exist_ok=True)
this_fig.savefig(os.path.join(subfigure_dir, "fig4_A.svg"), transparent=True)


In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

# Legend-only figure: colour key for the four region categories
colors = ['#73A5EB', '#C487ED', '#E6E05C', '#E69454']
labels = ['oIDRs', 'ordered regions',  'mIDRs', 'RG-motif']

# Create custom legend handles
handles = [Patch(color=color, label=label) for color, label in zip(colors, labels)]

# Dummy figure that holds only the legend
fig, ax = plt.subplots()
ax.axis('off')

legend = ax.legend(handles=handles, ncol=4, loc='center', frameon=True)

os.makedirs(os.path.join(curr_wd, "data/results/subfigures/"), exist_ok=True)
# Save the legend figure `fig`, not `this_fig`, which is the fig4_A heatmap from the previous cell
fig.savefig(os.path.join(curr_wd, "data/results/subfigures/fig4_B.svg"), transparent=True)

plt.show()


In [None]:
import os
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from typing import List

# === Constants ===
SAVE_PATH = "data/results/subfigures/suppl_fig_S5.svg"

# === Load cleaned AA-metric data ===
aa_metrics_pickle = f"{curr_wd}/data/results/proteins_with_aa_metrics_df.pkl"
with open(aa_metrics_pickle, 'rb') as fp:
    protein_df = pickle.load(fp)

# === Split positive/negative groups ===
is_pos = protein_df['Group'] == "pos"
is_neg = protein_df['Group'] == "neg"
cut_data_pos = protein_df[is_pos]
cut_data_neg = protein_df[is_neg]

# === Heatmap plotting function ===
def plot_corr_heatmap(data: pd.DataFrame, aa_list: List[str], group_label: str):
    """
    Plot and save a Pearson correlation heatmap of per-motif amino-acid counts for one protein group.

    Parameters
    ----------
    data : pd.DataFrame
        Long-format metrics with columns 'prot', 'aa', 'metric' and 'data'.
    aa_list : list of str
        Amino acids to include. Each AA present in ``data`` becomes one row/column.
    group_label : str
        Inserted into the output file name (``suppl_fig_S5_<group_label>.svg``).
    """
    # Keep only the requested AAs and the per-motif count metric
    filtered = data[data['aa'].isin(aa_list)]
    filtered = filtered[filtered['metric'] == 'counts_in_motif'].copy()

    # Combine AA and metric into a single column
    filtered['merged_metric'] = filtered['aa'] + "_" + filtered['metric']

    # Pivot to get protein x merged_metric matrix
    reshaped = filtered.pivot_table(index='prot', columns='merged_metric', values='data')

    # Sort columns by AA and relative motif position if present
    def sort_key(col):
        prefix, suffix = col.split('_', 1)
        offset = suffix.split('motif')[-1]
        return (prefix, int(offset) if offset else 0)

    reshaped = reshaped[sorted(reshaped.columns, key=sort_key)]

    # Compute correlation matrix
    corr_matrix = reshaped.corr(method='pearson')

    # Plot heatmap
    plt.figure(figsize=(6, 6))
    cmap = sns.light_palette("#3b4c1f", as_cmap=True)
    sns.heatmap(corr_matrix, annot=False, cbar=True, cmap=cmap)

    # Label each tick with the AA of the actual matrix column. A fixed A..Y list
    # mislabels the axes whenever an AA is absent from this group's matrix.
    aa_of_column = dict(zip(filtered['merged_metric'], filtered['aa']))
    tick_labels = [aa_of_column[col] for col in corr_matrix.columns]
    tick_positions = np.arange(0.5, len(tick_labels), 1.0)
    plt.xticks(tick_positions, tick_labels, rotation=0)
    plt.yticks(tick_positions, tick_labels, rotation=0)

    plt.xlabel("AA + metric position")
    plt.ylabel("")

    # Save figure
    os.makedirs(os.path.join(curr_wd, "data/results/subfigures/"), exist_ok=True)
    save_file = os.path.join(curr_wd, SAVE_PATH.replace("suppl_fig_S5", f"suppl_fig_S5_{group_label}"))
    plt.savefig(save_file, transparent=True)
    plt.show()


# === Plot for both groups ===
for label, subset in (("pos", cut_data_pos), ("neg", cut_data_neg)):
    plot_corr_heatmap(subset, aa_all, group_label=label)


In [None]:
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Filter data by group
cut_data_pos, cut_data_neg = (protein_df[protein_df['Group'] == label] for label in ("pos", "neg"))

def plot_corr_heat_map_detailed(cut_data, chosen_aa, group_label, outdir):
    """Plot and save an annotated Pearson correlation heatmap of per-motif counts for selected AAs.

    Parameters
    ----------
    cut_data : pandas.DataFrame
        Long-format metrics with columns 'prot', 'aa', 'metric' and 'data'.
    chosen_aa : list of str
        Amino acids to include. Tick labels follow the matrix (alphabetical)
        column order, not the order of this list.
    group_label : str
        "pos" saves ``fig5_E.svg``; anything else saves ``fig5_F.svg``.
    outdir : str
        Output directory (created if missing).
    """
    # Filter by selected amino acids and metric
    cut_data = cut_data[cut_data['aa'].isin(chosen_aa)].copy()
    cut_data['merged_metric'] = cut_data['aa'] + "_" + cut_data['metric']

    # Keep only selected metrics
    selected_metrics = ['counts_in_motif']
    cut_data = cut_data[cut_data['metric'].isin(selected_metrics)]

    # Pivot table: proteins x (AA_position)
    matrix = cut_data.pivot_table(index='prot', columns='merged_metric', values='data')

    # Sort columns by AA and numeric suffix in the metric name
    def sort_key(col):
        aa, rest = col.split('_', 1)
        suffix = rest.split('motif')[1]
        offset = int(suffix) if suffix else 0
        return aa, offset

    matrix = matrix[sorted(matrix.columns, key=sort_key)]

    # Compute correlation
    corr_matrix = matrix.corr(method='pearson')

    # Create plot
    plt.figure(figsize=(5, 5))
    cmap = sns.light_palette("#3b4c1f", as_cmap=True)
    sns.heatmap(
        corr_matrix, annot=True, cbar=True, cmap=cmap,
        annot_kws={"fontsize": 12}, square=True
    )

    # Label ticks from the matrix columns themselves. They are sorted by AA, so
    # labelling with `chosen_aa` would be wrong for an unsorted or partial list.
    aa_of_column = dict(zip(cut_data['merged_metric'], cut_data['aa']))
    tick_labels = [aa_of_column[col] for col in corr_matrix.columns]
    plt.xticks(np.arange(0.5, len(tick_labels), 1.0), tick_labels, rotation=0, fontsize=12)
    plt.yticks(np.arange(0.5, len(tick_labels), 1.0), tick_labels, rotation=0, fontsize=12)
    plt.xlabel("amino acids", fontsize=14)
    plt.ylabel("")

    # Save figure
    os.makedirs(outdir, exist_ok=True)
    filename = "fig5_E" if group_label == "pos" else "fig5_F"
    plt.savefig(os.path.join(outdir, filename + ".svg"), transparent=True)

# Run for both groups
chosen_aa = [*"FGRWY"]
output_dir = os.path.join(curr_wd, "data/results/subfigures/")
for group_label, group_data in (("pos", cut_data_pos), ("neg", cut_data_neg)):
    plot_corr_heat_map_detailed(group_data, chosen_aa, group_label, output_dir)


In [None]:
import os
import matplotlib.pyplot as plt
from matplotlib_venn import venn2
import pandas as pd

# === Parameters ===
letter1, letter2 = 'Y', 'F'  # residues compared in the Venn diagram
threshold = 5                # minimum occurrences required in the window
window = 30                  # flanking residues added on each side of the motif
group_used = ["pos"]         # motif groups included

# Output path
output_path = os.path.join(curr_wd, "data/results/subfigures/")
os.makedirs(output_path, exist_ok=True)

# === Filter motif dataframe ===
used_motif_df = motif_info_set_df.loc[motif_info_set_df["Group"].isin(group_used)]

# === Count motifs with at least 'threshold' of letter1/letter2 in the ±window region ===
def motif_has_letter(row, letter, window_size=None, min_count=None):
    """Return True if the motif plus flanks contains at least ``min_count`` copies of ``letter``.

    Parameters
    ----------
    row : mapping
        Needs ``full_seq`` plus 1-based inclusive ``start``/``end`` motif coordinates.
    letter : str
        Residue (substring) to count.
    window_size : int, optional
        Residues added on each side of the motif; defaults to the module-level ``window``.
    min_count : int, optional
        Required number of occurrences; defaults to the module-level ``threshold``.
    """
    if window_size is None:
        window_size = window
    if min_count is None:
        min_count = threshold
    # Clip at the N-terminus: a negative slice start would wrap around to the
    # C-terminus, usually leaving an empty window for motifs near the start.
    lo = max(row["start"] - 1 - window_size, 0)
    seq = str(row["full_seq"][lo : row["end"] + window_size])
    return seq.count(letter) >= min_count

# Per-motif flags: does the ±window region hold at least `threshold` copies of each letter?
mask_A, mask_B = (
    used_motif_df.apply(lambda row, ltr=ltr: motif_has_letter(row, ltr), axis=1)
    for ltr in (letter1, letter2)
)
mask_AB = mask_A & mask_B

count_A, count_B, count_AB = (flags.sum() for flags in (mask_A, mask_B, mask_AB))

print(f"Motifs with ≥{threshold} '{letter1}': {count_A}")
print(f"Motifs with ≥{threshold} '{letter2}': {count_B}")
print(f"Motifs with both ≥{threshold} '{letter1}' and '{letter2}': {count_AB}")

# === Venn diagram ===
plt.figure(figsize=(4, 4))
only_first = count_A - count_AB
only_second = count_B - count_AB
venn = venn2(
    subsets=(only_first, only_second, count_AB),
    set_labels=("", ""),
    set_colors=("mediumvioletred", "lightseagreen"),
    alpha=0.6,
)

# Grey outlines on the three regions (skip regions that are missing)
for region_patch in venn.patches[:3]:
    if not region_patch:
        continue
    region_patch.set_edgecolor("grey")
    region_patch.set_linewidth(2)

# Larger count labels inside the regions
for count_text in venn.subset_labels:
    if count_text:
        count_text.set_fontsize(14)

# Save
plt.tight_layout()
plt.savefig(os.path.join(output_path, "fig5_A.svg"), transparent=True)
plt.show()


In [None]:
import gseapy as gp
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import mygene
import numpy as np
import time

def convert_uniprot_to_gene_symbols(uniprot_ids):
    """Map UniProt accessions to human gene symbols via MyGene.info.

    A query accession that resolves to more than one distinct symbol is reported
    as ambiguous and left out. Every other symbol is returned once, in the order
    its query first appears in the MyGene response.

    Parameters
    ----------
    uniprot_ids : list of str
        UniProt accessions to look up.

    Returns
    -------
    list of str
        Unique, unambiguously mapped gene symbols.
    """
    mg = mygene.MyGeneInfo()
    query_result = mg.querymany(uniprot_ids, scopes='uniprot', fields='symbol', species='human')

    # Collect the distinct symbols returned for each queried accession. Grouping by
    # query (not by symbol) is what makes one-to-many mappings detectable.
    symbols_per_query = {}
    for entry in query_result:
        if "symbol" in entry:
            symbols = symbols_per_query.setdefault(entry.get("query"), [])
            if entry["symbol"] not in symbols:
                symbols.append(entry["symbol"])

    gene_symbols = []
    seen = set()
    for query, symbols in symbols_per_query.items():
        if len(symbols) > 1:
            print(f"Potential ambiguous mapping for {query}: {set(symbols)}")
            continue
        symbol = symbols[0]
        if symbol not in seen:
            seen.add(symbol)
            gene_symbols.append(symbol)

    return gene_symbols

def run_enrichment(protein_list, library="GO_Biological_Process_2021", max_retries=3, top_n=10):
    """Run an Enrichr enrichment for a list of UniProt accessions.

    The accessions are converted to gene symbols first
    (``convert_uniprot_to_gene_symbols``). Then ``gp.enrichr`` is called against
    ``library``, and is retried up to ``max_retries`` times on errors.

    Parameters
    ----------
    protein_list : list of str
        UniProt accessions.
    library : str
        Enrichr gene-set library name.
    max_retries : int, default 3
        Maximum number of Enrichr requests.
    top_n : int, default 10
        Number of terms with the smallest adjusted p-value to return.

    Returns
    -------
    pandas.DataFrame or None
        The ``top_n`` rows of the Enrichr result table sorted by
        "Adjusted P-value". Returns None if no symbols could be mapped, the
        result is empty, or every attempt failed.
    """
    gene_symbols = convert_uniprot_to_gene_symbols(protein_list)
    if not gene_symbols:
        # Nothing to submit; avoid max_retries doomed requests.
        print(f"No gene symbols mapped; skipping {library}.")
        return None

    for attempt in range(1, max_retries + 1):
        try:
            enr = gp.enrichr(gene_list=gene_symbols, gene_sets=library, organism='human', outdir=None)
        except Exception as e:  # network/service failures from Enrichr; retry
            print(f"Attempt {attempt} failed: {e}")
            if attempt < max_retries:
                time.sleep(2)  # back off before the next try (not after the last one)
            continue

        if enr.results is None or enr.results.empty:
            print(f"No significant enrichment found for {library}.")
            return None
        return enr.results.sort_values(by="Adjusted P-value").head(top_n)

    print("Enrichr request failed after multiple attempts.")
    return None

def plot_enrichment(enrichment_results, title):
    """Plot enrichment terms as horizontal bars of -log10(adjusted p-value).

    Parameters
    ----------
    enrichment_results : pandas.DataFrame or None
        Table with "Term" and "Adjusted P-value" columns. It is not modified.
        Rows are drawn top-to-bottom in table order.
    title : str
        Plot title; also used in the "nothing to plot" message.
    """
    if enrichment_results is None or enrichment_results.empty:
        print(f"No significant enrichment found for {title}.")
        return

    # Work on a copy so the caller's table does not gain a new column.
    plot_df = enrichment_results.copy()
    plot_df["-log10(Adjusted P-value)"] = -np.log10(plot_df["Adjusted P-value"])

    plt.figure(figsize=(8, 5))
    sns.barplot(
        data=plot_df,
        y="Term",
        x="-log10(Adjusted P-value)",
        palette="viridis"
    )
    plt.xlabel("-log10(Adjusted P-value)")
    plt.ylabel("GO Term")
    plt.title(title)
    # seaborn already puts the first category at the top of a categorical y axis.
    # For input sorted by ascending adjusted p-value, that is the most enriched
    # term, so the axis is deliberately not inverted.
    plt.tight_layout()
    plt.show()

# Run GO enrichment for the proteins specific to each motif group.

# UniqueIDs of proteins owning at least one motif that passes each letter criterion.
proteins_A = used_motif_df.loc[mask_A, "UniqueID"].tolist()
proteins_B = used_motif_df.loc[mask_B, "UniqueID"].tolist()

# Proteins appearing in both lists are removed, so each list is group-specific.
proteins_AB = list(set(proteins_A) & set(proteins_B))
shared_proteins = set(proteins_AB)
protein_list1 = list(set(proteins_A) - shared_proteins)
protein_list2 = list(set(proteins_B) - shared_proteins)

go_terms = ["GO_Biological_Process_2021", "GO_Molecular_Function_2021", "GO_Cellular_Component_2021"]
results_1 = []
results_2 = []
for go in go_terms:
    results_1.append(run_enrichment(protein_list1, go))
    results_2.append(run_enrichment(protein_list2, go))

# Save to file (one list of per-library result tables per group).
# NOTE(review): file names assume letter1 is 'Y' and letter2 is 'F' — confirm.
results_dir = os.path.join(curr_wd, "data", "results")
os.makedirs(results_dir, exist_ok=True)
with open(os.path.join(results_dir, "protein_for_Y.pkl"), "wb") as f:
    pickle.dump(results_1, f)

with open(os.path.join(results_dir, "protein_for_F.pkl"), "wb") as f:
    pickle.dump(results_2, f)



In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap

def _load_pickle(path):
    """Unpickle and return the object stored at ``path`` (a file this notebook wrote)."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Reload the per-library enrichment tables for both protein groups.
results_1 = _load_pickle(curr_wd + '/data/results/' + 'protein_for_Y.pkl')
results_2 = _load_pickle(curr_wd + '/data/results/' + 'protein_for_F.pkl')


def shorten_label(label, max_length=50):
    """Shorten a label to at most ``max_length`` characters, keeping a trailing "(...)" part.

    If the label contains parentheses, the text from the last "(" to the end
    (typically the GO identifier) is kept intact. The text before it is
    truncated and joined with "… ". If the parenthesised part alone does not
    leave room for any base text, only that part is returned. Labels without
    parentheses are cut and end with "…".
    """
    if len(label) <= max_length:
        return label
    if "(" in label and ")" in label:
        term = label[label.rfind("("):]  # GO term in parentheses
        # Reserve 2 characters for the "… " separator; clamp at 0 so a long term
        # does not turn into a negative slice that keeps most of the label.
        keep = max(0, max_length - len(term) - 2)
        base = label[:keep].rstrip()
        return f"{base}… {term}" if base else term
    return label[:max(0, max_length - 1)] + "…"  # fallback truncation

def plot_enrichment(enrichment_results1, enrichment_results2, title, p_threshold=0.01):
    """Plot two groups' enrichment results as stacked horizontal bar panels.

    Only terms with ``Adjusted P-value < p_threshold`` are drawn.
    - The top panel shows the first table (tyrosine protein group, pink shades).
    - The bottom panel shows the second table (phenylalanine protein group, teal shades).

    Each bar carries its shortened term name; terms found in only one group are
    bold. When ``title`` is "GO Biological Process", "GO Molecular Function" or
    "GO Cellular Component", the figure is saved as fig5_B/C/D.svg under
    ``<curr_wd>/data/results/subfigures/``. It is always shown.

    Parameters
    ----------
    enrichment_results1, enrichment_results2 : pandas.DataFrame or None
        Tables with "Term" and "Adjusted P-value" columns. ``None`` (no result)
        is treated as an empty table. Inputs are not modified.
    title : str
        Figure title; also selects the output file name.
    p_threshold : float, default 0.01
        Adjusted p-value cutoff for a term to be plotted.
    """
    def _significant(df):
        # Significant rows only; None becomes an empty table with the needed columns.
        if df is None:
            return pd.DataFrame({"Term": pd.Series(dtype=object),
                                 "Adjusted P-value": pd.Series(dtype=float)})
        return df[df["Adjusted P-value"] < p_threshold].copy()

    enrichment_results1 = _significant(enrichment_results1)
    enrichment_results2 = _significant(enrichment_results2)

    if enrichment_results1.empty and enrichment_results2.empty:
        print(f"No significant enrichment found for {title}.")
        return

    # Clip at the smallest positive float so a reported p-value of 0 gives a finite bar.
    tiny = np.finfo(float).tiny
    for df in (enrichment_results1, enrichment_results2):
        df["-log10(Adjusted P-value)"] = -np.log10(df["Adjusted P-value"].clip(lower=tiny))

    num_terms1 = len(enrichment_results1)
    num_terms2 = len(enrichment_results2)

    # Height ratios proportional to term counts keep bar thickness equal across panels.
    height_ratios = [num_terms1, num_terms2] if num_terms1 and num_terms2 else [1, 1]

    fig, axs = plt.subplots(
        2, 1, sharex=True, figsize=(5.2, 1.5 + 0.25 * (num_terms1 + num_terms2)),
        gridspec_kw={"height_ratios": height_ratios}
    )

    # Light-to-dark colormaps: pink for the first group, teal for the second.
    custom_cmap_Y = LinearSegmentedColormap.from_list("custom_rdpu", ["#feb4e3", "#ca549e"], N=256)
    custom_cmap_F = LinearSegmentedColormap.from_list("custom_rdpu", ["#c6fbf9", "#71c6c2"], N=256)

    # Terms significant in both groups are drawn in normal weight, the rest in bold.
    shared_terms = set(enrichment_results1["Term"]) & set(enrichment_results2["Term"])

    for ax, enrichment_results, cmap in zip(
        axs,
        [enrichment_results1, enrichment_results2],
        [custom_cmap_Y, custom_cmap_F],
    ):
        if not enrichment_results.empty:
            enrichment_results = enrichment_results.copy()
            enrichment_results["Short Term"] = enrichment_results["Term"].apply(shorten_label)

            # Normalise to [0, 1] for the colormap. Because of the sign, the most
            # significant term maps to 0 (lightest shade). A constant column would
            # divide by zero (NaN -> transparent bars), so it falls back to mid-scale.
            values = -enrichment_results["-log10(Adjusted P-value)"]
            value_range = values.max() - values.min()
            if len(values) > 1 and value_range > 0:
                norm_values = (values - values.min()) / value_range
            else:
                norm_values = [0.5] * len(values)
            colors = [cmap(val) for val in norm_values]

            sns.barplot(
                data=enrichment_results,
                y="Short Term",
                x="-log10(Adjusted P-value)",
                palette=colors, ax=ax
            )

            # Write each term inside its bar, near the bar's base (x = 0.1 in data units).
            for bar, term, label in zip(ax.patches, enrichment_results["Term"], enrichment_results["Short Term"]):
                fontweight = "normal" if term in shared_terms else "bold"
                ax.text(
                    0.1,
                    bar.get_y() + bar.get_height() / 2,  # vertically centred on the bar
                    label,
                    ha="left", va="center", fontsize=10, fontweight=fontweight, color="black"
                )

        # Terms are written on the bars, so hide the y label, ticks and tick labels.
        ax.set_ylabel("")
        ax.tick_params(axis='y', left=False, labelleft=False)
        ax.set_xlabel("")

    axs[-1].set_xlabel("-log10(Adjusted P-value)")
    plt.suptitle(title, fontsize=14, fontweight="bold")
    plt.tight_layout()

    out_dir = os.path.join(curr_wd, "data/results/subfigures/")
    os.makedirs(out_dir, exist_ok=True)
    figure_files = {
        "GO Biological Process": "fig5_B.svg",
        "GO Molecular Function": "fig5_C.svg",
        "GO Cellular Component": "fig5_D.svg",
    }
    if title in figure_files:
        plt.savefig(os.path.join(out_dir, figure_files[title]), transparent=True)
    plt.show()



# Titles follow the same library order (BP, MF, CC) used when results_1/results_2 were built.
go_titles = ["GO Biological Process", "GO Molecular Function", "GO Cellular Component"]
for i, go in enumerate(go_titles):
    plot_enrichment(results_1[i], results_2[i], go)


In [None]:
from collections import Counter
from Bio import SeqIO
import numpy as np

def find_all_occurrences(s, char, add=0):
    """Return the indices at which an element of ``s`` equals ``char``, each shifted by ``add``.

    ``add`` converts positions within a slice back to coordinates of the
    full sequence the slice was taken from.
    """
    hits = []
    for offset, residue in enumerate(s):
        if residue == char:
            hits.append(offset + add)
    return hits

def signed_distances_to_region_rel(positions, region, length=1):
    """Return signed distances of ``positions`` from the closed interval ``region``, divided by ``length``.

    Positions left of ``region[0]`` give negative values measured from the
    start. Positions right of ``region[1]`` give positive values measured from
    the end. Positions inside ``[start, end]`` are dropped.
    """
    lo, hi = region
    return [
        (pos - lo) / length if pos < lo else (pos - hi) / length
        for pos in positions
        if pos < lo or pos > hi
    ]

# Define the amino acids to scan for
chars_to_find = aa_all

# Split data into groups
pos_motif_df = motif_info_set_df[motif_info_set_df["Group"] == "pos"]
neg_motif_df = motif_info_set_df[motif_info_set_df["Group"] == "neg"]


def _motif_idr_windows(motif_df):
    """Return (idr_seq, idr_start, motif_start, motif_end) for each motif inside an IDR.

    A motif qualifies when it lies strictly inside one of its protein's
    ``motif_IDR_range`` entries (first match wins). Motifs without such a
    range are skipped.
    """
    windows = []
    for _, r in motif_df.iterrows():
        motif_IDR_ranges = protein_dict_with_aa_metrics[r["UniqueID"]]["motif_IDR_range"]
        corr_rng = next((rng for rng in motif_IDR_ranges if r["start"] > rng[0] and r["end"] < rng[1]), None)
        if corr_rng is None:
            continue
        windows.append((r["full_seq"][corr_rng[0]:corr_rng[1]], corr_rng[0], r["start"], r["end"]))
    return windows


# Locate each motif's enclosing IDR once, instead of once per amino acid.
pos_windows = _motif_idr_windows(pos_motif_df)
neg_windows = _motif_idr_windows(neg_motif_df)

# For each amino acid: one list per motif holding the signed distances (in
# residues, full-sequence coordinates) of that residue's occurrences in the
# enclosing IDR to the motif. Occurrences inside [start, end] are not counted.
pos_new_dict = {}
neg_new_dict = {}
for let in chars_to_find:
    pos_new_dict[let] = [
        signed_distances_to_region_rel(find_all_occurrences(idr_seq, let, idr_start), (start, end))
        for idr_seq, idr_start, start, end in pos_windows
    ]
    neg_new_dict[let] = [
        signed_distances_to_region_rel(find_all_occurrences(idr_seq, let, idr_start), (start, end))
        for idr_seq, idr_start, start, end in neg_windows
    ]

# Collect subsequences around motifs for positive group.
# For every "pos" motif lying strictly inside one of its protein's motif_IDR_range
# entries (first match wins), store the IDR sequence left of the motif, the motif,
# the IDR sequence right of the motif, and the whole IDR window.
# NOTE(review): the pos and neg loops slice the flanks with different offsets:
#   pos: pre = full_seq[rng_start:start - 1], post = full_seq[end:rng_end]
#   neg: pre = full_seq[rng_start:start],     post = full_seq[end + 1:rng_end]
# So the two groups are aligned to the motif one residue apart on each side, and
# they are later compared position-by-position. Which offsets are right depends on
# how `start`/`end` are indexed (0- or 1-based, inclusive or exclusive end), which
# is not visible here. Confirm, then make both groups use the same slicing.
all_pos_pre_seqs = []
all_pos_post_seqs = []
all_pos_motifs = []
all_pos_full_length = []

for _, r in pos_motif_df.iterrows():
    motif_IDR_ranges = protein_dict_with_aa_metrics[r["UniqueID"]]["motif_IDR_range"]
    # First IDR range that strictly contains the motif; motifs without one are skipped.
    corr_rng = next((rng for rng in motif_IDR_ranges if r["start"] > rng[0] and r["end"] < rng[1]), None)
    if corr_rng is None:
        continue
    all_pos_pre_seqs.append(r["full_seq"][corr_rng[0]:r["start"] - 1])
    all_pos_motifs.append(r["motif"])
    all_pos_post_seqs.append(r["full_seq"][r["end"]:corr_rng[1]])
    all_pos_full_length.append(r["full_seq"][corr_rng[0]:corr_rng[1]])

# Collect subsequences around motifs for negative group
# (same selection as above; note the different flank offsets flagged above).
all_neg_pre_seqs = []
all_neg_post_seqs = []

for _, r in neg_motif_df.iterrows():
    motif_IDR_ranges = protein_dict_with_aa_metrics[r["UniqueID"]]["motif_IDR_range"]
    corr_rng = next((rng for rng in motif_IDR_ranges if r["start"] > rng[0] and r["end"] < rng[1]), None)
    if corr_rng is None:
        continue
    all_neg_pre_seqs.append(r["full_seq"][corr_rng[0]:r["start"]])
    all_neg_post_seqs.append(r["full_seq"][r["end"] + 1:corr_rng[1]])

def extract_sequences_from_fasta(fasta_file):
    """Read a FASTA file into ``{id: sequence}``.

    The key is the second "|"-separated field of each record id (the accession
    in UniProt-style headers such as ``sp|P12345|NAME``). Later duplicates
    overwrite earlier ones.
    """
    return {
        record.id.split("|")[1]: str(record.seq)
        for record in SeqIO.parse(fasta_file, "fasta")
    }

def filter_sequences(sequences, masks):
    """Keep only the residues whose mask flag is truthy.

    Sequences without a mask, or whose mask length differs from the sequence
    length, are left out of the result.
    """
    kept = {}
    for name, seq in sequences.items():
        if name not in masks:
            continue
        mask = masks[name]
        if len(seq) != len(mask):
            continue
        kept[name] = ''.join(res for res, flag in zip(seq, mask) if flag)
    return kept

def compute_aa_proportions(sequences):
    """Return per-residue mean and SEM of composition across sequences, in percent.

    Each non-empty sequence (values of the ``sequences`` mapping) contributes
    its fraction of each of the 20 standard residues. Empty sequences are
    ignored.

    Returns
    -------
    (dict, dict)
        ``{residue: 100 * mean fraction}`` and ``{residue: 100 * SEM}``. The SEM
        uses the sample standard deviation (ddof=1). Keys are in the order
        "ACDEFGHIKLMNPQRSTVWY".
    """
    alphabet = "ACDEFGHIKLMNPQRSTVWY"
    compositions = [(Counter(s), len(s)) for s in sequences.values() if len(s) != 0]
    fractions = {
        residue: [tally.get(residue, 0) / n for tally, n in compositions]
        for residue in alphabet
    }

    means = {}
    sems = {}
    for residue in alphabet:
        vals = fractions[residue]
        means[residue] = 100 * np.mean(vals)
        sems[residue] = 100 * (np.std(vals, ddof=1) / np.sqrt(len(vals)))
    return means, sems

# Reference composition: load the proteome FASTA, keep only residues flagged by
# the "prediction-disorder-mobidb_lite" mask, and average residue proportions.
fasta_path = '/mnt/d/phd/scripts/raw_data/proteomes/UP000005640_9606.fasta'
sequences = extract_sequences_from_fasta(fasta_path)
masks = (
    annotated_IDR_df
    .set_index("protein_name")["prediction-disorder-mobidb_lite"]
    .to_dict()
)
filtered = filter_sequences(sequences, masks)

avg, sem = compute_aa_proportions(filtered)
for caption, value in (
    ("Average Proportions:", avg),
    ("Sum of averages:", sum(avg.values())),
    ("Standard Error:", sem),
):
    print(caption, value)


In [None]:

import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
from itertools import chain

def calculate_aa_percentage(sequence, aa_list):
    """Return ``{aa: percentage of sequence equal to aa}`` for each entry of ``aa_list``.

    An empty (falsy) sequence yields 0 for every residue instead of dividing by zero.
    """
    counts = Counter(sequence)
    if sequence:
        denominator = len(sequence)
    else:
        denominator = 1
    percentages = {}
    for residue in aa_list:
        percentages[residue] = counts[residue] / denominator * 100
    return percentages

def compute_proportions(sequences, amino_acids, scope, window_size=1):
    """Compute amino acid proportions and standard deviation at each position.

    For each offset ``i``, a slice of every long-enough sequence is taken and the
    percentage of each residue in it is computed (``calculate_aa_percentage``).
    - Offsets are ``1..scope`` when ``scope > 0``, counted from the sequence start.
    - Offsets are ``scope..-1`` when ``scope < 0``, counted from the sequence end.
    The mean and population standard deviation (``np.std``, ddof=0) across
    sequences are recorded per offset.

    Parameters
    ----------
    sequences : list of str
        Flanking sequences.
    amino_acids : iterable of str
        Residue letters to report.
    scope : int
        Number of offsets; its sign selects the side (see above).
    window_size : int, default 1
        Controls the slice width (see NOTE below).

    Returns
    -------
    (dict, dict, dict)
        Per residue, three lists: mean percentage, standard deviation, and the
        number of contributing sequences at each offset (floored at 1). Offsets
        with no contributing sequence get ``np.mean([])``, i.e. NaN.

    NOTE(review): the slicing looks inconsistent; confirm it is intended:
      * both branches take ``window_size + 1`` residues, not ``window_size``;
      * the post side (scope > 0) at i=1 starts at seq[0], but the pre side
        (scope < 0) at i=-1 is ``seq[-2 - window_size:-1]``, which excludes the
        last residue (the one adjacent to the motif);
      * the post-side length filter ``len(seq) > i + 2`` does not depend on
        ``window_size``, so for window_size > 3 windows near the end can be
        truncated. The pre side clamps at the sequence start, with the same effect;
      * scope == 0 yields one offset whose slice is always empty.
    """
    aa_percentages = {aa: [] for aa in amino_acids}
    aa_stddev = {aa: [] for aa in amino_acids}
    aa_counts = {aa: [] for aa in amino_acids}

    # range(1, scope + 1) for scope > 0; range(scope, 0) for scope < 0.
    for i in range(min(scope, 1), max(scope, -1) + 1):
        if scope > 0:
            merged_seqs = [seq[i-1:i+window_size] for seq in sequences if len(seq) > i+2]
        else:
            merged_seqs = [seq[i - 1 - window_size:i] for seq in sequences if len(seq) > abs(i)]

        proportions_list = [calculate_aa_percentage(seq, amino_acids) for seq in merged_seqs]

        for aa in amino_acids:
            values = [p.get(aa, 0) for p in proportions_list]
            aa_percentages[aa].append(np.mean(values))
            aa_stddev[aa].append(np.std(values))
            aa_counts[aa].append(max(1, len(values)))  # Avoid division by zero

    return aa_percentages, aa_stddev, aa_counts

def plot_aa_proportion(pre_pos, post_pos, pre_neg, post_neg, amino_acids,
                       ref_props, ref_sems, ylim_values=None, window_size=1, scope=100, gap=1):
    """Plot per-position amino acid proportions flanking motifs for the pos and neg groups.

    One panel per residue in ``amino_acids``, four panels per row. Each panel shows:
    - the mean percentage (line) ± SEM (shaded band) from ``compute_proportions``,
      green for the positive group and red for the negative group;
    - the left flank at x = -scope..-1 and the right flank at x = 1+gap..scope+gap;
    - a grey band between the flanks marking the motif;
    - a dashed line at ``ref_props[aa]`` when that residue is present.
    The figure is saved to ``<curr_wd>/data/results/subfigures/suppl_fig_S4.svg``
    and shown.

    Parameters
    ----------
    pre_pos, post_pos, pre_neg, post_neg : list of str
        Left/right flanking sequences for the positive and negative groups.
    amino_acids : sequence of str
        Residues to plot, one panel each.
    ref_props : dict
        Reference percentage per residue (dashed line).
    ref_sems : dict
        Accepted for interface compatibility; currently not drawn.
    ylim_values : dict, optional
        Per-residue ``(ymin, ymax)`` overrides.
    window_size : int, default 1
        Passed to ``compute_proportions``.
    scope : int, default 100
        Number of positions per flank; passed to ``compute_proportions``.
    gap : int, default 1
        Horizontal shift of the right flank, leaving room for the motif band.
    """
    pre_pos_props, pre_pos_stddev, pre_pos_counts = compute_proportions(pre_pos, amino_acids, -scope, window_size)
    post_pos_props, post_pos_stddev, post_pos_counts = compute_proportions(post_pos, amino_acids, scope, window_size)
    pre_neg_props, pre_neg_stddev, pre_neg_counts = compute_proportions(pre_neg, amino_acids, -scope, window_size)
    post_neg_props, post_neg_stddev, post_neg_counts = compute_proportions(post_neg, amino_acids, scope, window_size)

    n_cols = 4
    n_panels = len(amino_acids)
    # Round the row count up so no residue is dropped when n_panels is not a multiple of 4.
    n_rows = max(1, -(-n_panels // n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5.5 * n_cols, 1 + 2/4 * n_panels),
                             sharex=False, squeeze=False)
    axes = axes.flatten()
    for unused_ax in axes[n_panels:]:
        unused_ax.set_visible(False)

    # Major-tick spacing per flank; clamped so scope < 5 cannot give a zero range() step.
    pre_step = min(-1, -scope // 5)
    post_step = max(1, scope // 5)

    for i, (ax, aa) in enumerate(zip(axes, amino_acids), start=1):
        pre_len = len(pre_pos_props[aa])
        post_len = len(post_pos_props[aa])

        xticks_pre = np.arange(-pre_len, 0)
        xticks_post = np.arange(1, post_len + 1) + gap

        # SEM = stddev / sqrt(number of contributing sequences)
        sem_pre_pos = np.array(pre_pos_stddev[aa]) / np.sqrt(pre_pos_counts[aa])
        sem_post_pos = np.array(post_pos_stddev[aa]) / np.sqrt(post_pos_counts[aa])
        sem_pre_neg = np.array(pre_neg_stddev[aa]) / np.sqrt(pre_neg_counts[aa])
        sem_post_neg = np.array(post_neg_stddev[aa]) / np.sqrt(post_neg_counts[aa])

        # Reference line
        if aa in ref_props:
            ax.axhline(ref_props[aa], color="black", linestyle="dashed", label=f"Avg {aa} in Disordered Regions")

        # Shaded ±SEM bands (green = positive group, red = negative group)
        ax.fill_between(xticks_pre, pre_pos_props[aa] - sem_pre_pos, pre_pos_props[aa] + sem_pre_pos, color='#8DB600', alpha=0.2)
        ax.fill_between(xticks_post, post_pos_props[aa] - sem_post_pos, post_pos_props[aa] + sem_post_pos, color='#8DB600', alpha=0.2)
        ax.fill_between(xticks_pre, pre_neg_props[aa] - sem_pre_neg, pre_neg_props[aa] + sem_pre_neg, color='#FF4040', alpha=0.2)
        ax.fill_between(xticks_post, post_neg_props[aa] - sem_post_neg, post_neg_props[aa] + sem_post_neg, color='#FF4040', alpha=0.2)

        # Mean lines
        ax.plot(xticks_pre, pre_pos_props[aa], marker="o", markersize=4, linewidth=2, color='#8DB600', label="Pre-Motif (Pos)")
        ax.plot(xticks_post, post_pos_props[aa], marker="o", markersize=4, linewidth=2, color='#8DB600')
        ax.plot(xticks_pre, pre_neg_props[aa], marker="o", markersize=4, linewidth=2, color='#FF4040', label="Pre-Motif (Neg)")
        ax.plot(xticks_post, post_neg_props[aa], marker="o", markersize=4, linewidth=2, color='#FF4040')

        ax.set_ylabel(f"Proportion of {aa} (%)", size=10)
        ax.grid(True)

        # Grey motif band spanning the gap between the flanks (x = -1 .. 1 + gap).
        y_min, y_max = ax.get_ylim()
        ax.fill_between([-1, 1 + gap], y_min - 5, y_max + 10, color="gray", alpha=0.3)
        ax.set_ylim(y_min, y_max)  # the oversized band must not expand the y range
        if ylim_values and aa in ylim_values:
            ax.set_ylim(ylim_values[aa])

        # Major ticks counted outward from the motif; a minor tick at every position.
        xticks_major = np.concatenate([
            range(xticks_pre[-1], xticks_pre[0]-1, pre_step),
            range(xticks_post[0], xticks_post[-1]+1, post_step)
        ])
        xticks_minor = np.concatenate([
            range(xticks_pre[-1], xticks_pre[0]-1, -1),
            range(xticks_post[0], xticks_post[-1]+1)
        ])
        ax.set_xticks(xticks_major)
        ax.set_xticks(xticks_minor, minor=True)
        ax.set_xticklabels(xticks_major, size=9)
        # Resize y tick labels without freezing them (set_yticklabels(get_yticklabels())
        # before a draw can capture empty label texts).
        ax.tick_params(axis="y", labelsize=9)
        ax.set_xlim(-scope - 1, scope + gap + 1)

        # x label on the last panel of each column
        if i > n_panels - n_cols:
            ax.set_xlabel("Amino acid position relative to RG-motif", size=10)

    plt.tight_layout()
    os.makedirs(os.path.join(curr_wd, "data/results/subfigures/"), exist_ok=True)
    plt.savefig(os.path.join(curr_wd, "data/results/subfigures/suppl_fig_S4.svg"), transparent=True)

    plt.show()

# Plot every residue in `aa_all`, using the IDR-wide averages (avg/sem) as reference.
selected_aas = aa_all
ref_y_props, ref_y_sem = avg, sem
# Optional per-residue y-limits; currently not applied (ylim_values=None below).
ylim_manual = {"A": (0, 25), "F": (0, 5), "R": (0, 25), "G": (0, 25)}

plot_aa_proportion(
    pre_pos=all_pos_pre_seqs,
    post_pos=all_pos_post_seqs,
    pre_neg=all_neg_pre_seqs,
    post_neg=all_neg_post_seqs,
    amino_acids=selected_aas,
    ref_props=ref_y_props,
    ref_sems=ref_y_sem,
    ylim_values=None,
    window_size=4,
    scope=40,
)

