In [None]:
# Standard libraries
import os
import pickle

# Third-party libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import requests
from statannotations.Annotator import Annotator
from scipy.stats import chi2_contingency, ttest_ind

# Resolve and print the current working directory.
# Every data and result path in this notebook is built from `curr_wd`, so the
# notebook has to be run from the directory that contains the `data/` folder.
curr_wd = os.path.abspath(os.getcwd())
print(f"Current working directory: {curr_wd}")


def get_go_term_details(go_id):
    """
    Retrieve GO term details from EBI QuickGO API.

    Parameters:
        go_id (str): The GO term identifier.

    Returns:
        dict or None: Dictionary with the keys "id", "name", "namespace" and
        "definition" ("definition" is None when the term carries no
        definition text), or None if the response status is not 200 or the
        response contains no results.

    Raises:
        requests.RequestException: If the request fails at the network level
        or does not complete within the timeout.
    """
    url = f"https://www.ebi.ac.uk/QuickGO/services/ontology/go/terms/{go_id}"
    # Without a timeout a stalled connection would block the notebook forever.
    response = requests.get(url, timeout=30)
    if response.status_code != 200:
        return None

    results = response.json().get("results")
    if not results:
        return None

    term = results[0]
    return {
        "id": term["id"],
        "name": term["name"],
        "namespace": term["aspect"],
        # Tolerate a missing/empty definition instead of raising.
        "definition": (term.get("definition") or {}).get("text"),
    }


def assign_group(protein, group1_list, group1_label, group2_list, group2_label):
    """
    Return the label of the first group that contains the protein.

    Parameters:
        protein (str): Protein ID.
        group1_list (list): Proteins in group 1.
        group1_label (str): Label for group 1.
        group2_list (list): Proteins in group 2.
        group2_label (str): Label for group 2.

    Returns:
        str: Label of the matching group (group 1 takes precedence), or
        'Not in any group' if the protein is in neither list.
    """
    candidates = (
        (group1_list, group1_label),
        (group2_list, group2_label),
    )
    for members, label in candidates:
        if protein in members:
            return label
    return 'Not in any group'


def count_consecutive_stretches_of_1(lst, label="IDR"):
    """
    Locate runs of consecutive 1s in a list.

    A run opens at the first value equal to 1 and closes at the next value
    equal to 0; a run still open at the end of the list is closed at
    len(lst).

    Parameters:
        lst (list): List of binary values.
        label (str): Label to assign to each stretch.

    Returns:
        list of tuples: (start, end, label) per stretch, where start is the
        index of the first 1 and end is the index just past the stretch.
    """
    found = []
    run_start = None  # index where the currently open run began, if any
    for position, value in enumerate(lst):
        if run_start is None:
            if value == 1:
                run_start = position
        elif value == 0:
            found.append((run_start, position, label))
            run_start = None
    if run_start is not None:
        found.append((run_start, len(lst), label))
    return found


def is_either_between(low, high, start, end):
    """
    Test whether at least one motif endpoint lies inside a range.

    Both range bounds are inclusive. Only the two endpoints are tested, so a
    motif that spans the whole range with both endpoints outside it yields
    False.

    Parameters:
        low (int): Lower bound of range.
        high (int): Upper bound of range.
        start (int): Start of motif.
        end (int): End of motif.

    Returns:
        bool: True if either motif endpoint is within range.
    """
    start_inside = low <= start <= high
    end_inside = low <= end <= high
    return start_inside or end_inside


def merge_duplicates(input_string):
    """
    Collapse every run of identical consecutive characters to one character.

    Parameters:
        input_string (str): The input string.

    Returns:
        str: String with consecutive duplicates removed.
    """
    kept = []
    last_seen = None
    for symbol in input_string:
        if symbol == last_seen:
            continue
        kept.append(symbol)
        last_seen = symbol
    return "".join(kept)


def cat_string(input_string):
    """
    Map a structure string to the most specific I/D pattern it contains.

    Parameters:
        input_string (str): The motif pattern string.

    Returns:
        str: "more\\ncomplex" (with a real newline) when both "DIDI" and
        "IDID" occur; otherwise the first pattern found when checking, in
        this order, "IDID", "DIDI", "IDI", "DID", "ID", "DI", "I", "D";
        "None" if no pattern occurs.
    """
    if "DIDI" in input_string and "IDID" in input_string:
        return "more\ncomplex"

    # Ordered from most to least specific; the first hit is the category.
    for pattern in ("IDID", "DIDI", "IDI", "DID", "ID", "DI", "I", "D"):
        if pattern in input_string:
            return pattern
    return "None"

def make_struct_string(region_list):
    """
    Encode a list of annotated regions as a string of one-letter codes.

    Parameters:
        region_list (list of tuples): Three-element tuples whose last element
                                      is the region type
                                      (e.g., "IDR", "MOTIF", "NABIND", "DOMAIN").

    Returns:
        str: One character per region, in list order: "I" for IDR, "M" for
        MOTIF, "R" for NABIND, "D" for DOMAIN. Regions of any other type
        contribute nothing.
    """
    symbol_for = {"IDR": "I", "MOTIF": "M", "NABIND": "R", "DOMAIN": "D"}
    return "".join(symbol_for.get(kind, "") for _, _, kind in region_list)


def assign_groups_advanced(
    protein,
    group1_list, group1_label,
    group2_list, group2_label,
    group3_list, group3_label,
    group4_list, group4_label
):
    """
    Return the label of the first of four groups that contains the protein.

    Parameters:
        protein (str): Protein ID.
        groupX_list (list): List of proteins in group X.
        groupX_label (str): Label for group X.

    Returns:
        str: Label of the first matching group (checked in the order
        1, 2, 3, 4) or 'Not in any group'.
    """
    ordered_groups = (
        (group1_list, group1_label),
        (group2_list, group2_label),
        (group3_list, group3_label),
        (group4_list, group4_label),
    )
    for members, label in ordered_groups:
        if protein in members:
            return label
    return 'Not in any group'


def calculate_domain_motif_distance(domain, motif, mode):
    """
    Calculate the signed distance between a domain and a motif in a protein.

    The sign gives the position of the motif relative to the domain: positive
    when the motif lies after the domain, negative when it lies before it.
    Intervals are treated as inclusive, so touching endpoints count as
    overlapping.

    Parameters:
        domain (tuple): (start, end) of the domain.
        motif (tuple): (start, end) of the motif.
        mode (str): 'e' for edge-based distance, 'c' for center-based distance.

    Returns:
        int or float: Distance between domain and motif. If overlapping, returns 0.

    Raises:
        ValueError: If mode is neither 'e' nor 'c'.
    """
    # Fail loudly instead of silently returning None for an unknown mode.
    if mode not in ("e", "c"):
        raise ValueError("Mode must be 'e' (edge) or 'c' (center).")

    start1, end1 = domain
    start2, end2 = motif

    # Overlap is handled identically in both modes.
    if max(start1, start2) <= min(end1, end2):
        return 0

    if mode == "e":
        if start1 < start2:  # domain left of motif
            return start2 - end1
        return end2 - start1  # domain right of motif -> negative value

    domain_center = start1 + (end1 - start1) / 2
    motif_center = start2 + (end2 - start2) / 2
    return motif_center - domain_center


# Load annotated datasets (all paths are relative to the current working directory).

# Motif annotations; later cells read its "UniqueID", "start" and "end" columns.
motif_info_set_df = pd.read_parquet(
    os.path.join(curr_wd, 'data/processed/GAR_motif_Wang_set_human_cleaned_annot_filtered.parquet')
)
# Disorder annotations; later cells read its "protein_name" and
# "prediction-disorder-mobidb_lite" columns.
annotated_IDR_df = pd.read_parquet(
    os.path.join(curr_wd, 'data/processed/annotation_datasets/all_IDR_human.parquet')
)
# Domain annotations; later cells read its "protein_name", "name", "start",
# "end", "databases", "GO_identifiers" and "GO_names" columns.
annotated_domain_df = pd.read_parquet(
    os.path.join(curr_wd, 'data/processed/annotation_datasets/all_domains_human.parquet')
)


In [None]:
# Version tag under which the loaded protein sets are stored in
# `proteins_sets_dict` further down.
ver = "v3"

# Define named sets and their associated file names.
# Each entry maps a set name to the base names (without ".txt") of the list
# files in data/processed/final_set_lists/ that are concatenated, in the given
# order, to form that set.
set_definitions = {
    "GAR_full": ["GAR_subset_full"],
    "GAR_LLPS_pos": [
        "4_LLPS_positive_set_and_GAR_subset",
        "5_LLPS_positive_set_and_NA_positive_set_and_GAR_subset"
    ],
    "GAR_LLPS_pos_NA_neg": ["4_LLPS_positive_set_and_GAR_subset"],
    "GAR_LLPS_neg": [
        "6_NA_positive_set_and_GAR_subset",
        "7_GAR_subset_only"
    ],
    "GAR_LLPS_neg_NA_pos": ["6_NA_positive_set_and_GAR_subset"],
    "GAR_NA_pos": [
        "5_LLPS_positive_set_and_NA_positive_set_and_GAR_subset",
        "6_NA_positive_set_and_GAR_subset"
    ],
    "GAR_NA_neg": [
        "4_LLPS_positive_set_and_GAR_subset",
        "7_GAR_subset_only"
    ],
    # "GAR_pos" / "GAR_neg" are the positive / negative groups used by the
    # later cells (see pos_prot_list / neg_prot_list).
    "GAR_pos": ["5_LLPS_positive_set_and_NA_positive_set_and_GAR_subset"],
    "GAR_neg": ["7_GAR_subset_only"],
}

# Read every named protein set: one protein ID per line, with the files of a
# set concatenated in the order given in `set_definitions`.
set_dict = {}
for set_name, file_names in set_definitions.items():
    proteins = []
    for fname in file_names:
        file_path = f"{curr_wd}/data/processed/final_set_lists/{fname}.txt"
        with open(file_path, "r") as f:
            proteins += [line.strip() for line in f]
    set_dict[set_name] = proteins

# Load full proteome (one entry per line) and register it as a further set
full_proteome_path = f"{curr_wd}/data/processed/list_of_human_proteins.csv"
with open(full_proteome_path, "r") as f:
    full_proteome = [line.strip() for line in f]
set_dict["full_proteome"] = full_proteome

# Parallel views of the loaded sets: names, and the list objects themselves
# in the same (insertion) order
set_names = list(set_dict)
set_list = list(set_dict.values())

# Store sets in versioned dictionary
proteins_sets_dict = {ver: set_dict}

# Extract positive and negative protein lists and their de-duplicated union
pos_prot_list = set_dict["GAR_pos"]
neg_prot_list = set_dict["GAR_neg"]
set_prot_list = list(set(pos_prot_list + neg_prot_list))

# Keep only the domain rows whose 'databases' entry contains "pfam"
pfam_annotated_domain_df = annotated_domain_df.loc[
    annotated_domain_df['databases'].map(lambda x: "pfam" in x)
]

# Domain name -> GO identifiers, skipping rows without any GO identifier.
# A domain name that occurs in several rows keeps the value of its last row.
domain_GO_dict = {}
for _, row in pfam_annotated_domain_df.iterrows():
    if len(row['GO_identifiers']) > 0:
        domain_GO_dict[row['name']] = row['GO_identifiers']

print(f"Domains with GO annotations: {len(domain_GO_dict)}")

# Flatten all GO identifiers into one list of strings, then deduplicate
all_GO_terms = []
for value in domain_GO_dict.values():
    if isinstance(value, (list, np.ndarray)):
        all_GO_terms += [str(term) for term in value]
    else:
        all_GO_terms += [str(value)]

print(f"Total GO term entries (before deduplication): {len(all_GO_terms)}")
all_GO_terms = list(set(all_GO_terms))
print(f"Unique GO terms: {len(all_GO_terms)}")
# print(all_GO_terms)

# GO ID -> GO name, paired positionally per row across all annotated domains
# (not only the Pfam subset)
GO_ID_dict = {}
for _, row in annotated_domain_df.iterrows():
    ids = row.get("GO_identifiers", [])
    names = row.get("GO_names", [])
    GO_ID_dict.update(zip(ids, names))

print(f"GO term name mappings: {len(GO_ID_dict)}")
# print("Resulting Dictionary:", GO_ID_dict)


In [None]:
# Count occurrences of domain names in full and Pfam-annotated domain dataframes
# (each dict maps a domain name to its number of rows)
occurrences_all = annotated_domain_df['name'].value_counts().to_dict()
print(f"Total unique domains (all): {len(occurrences_all)}")

occurrences_pfam = pfam_annotated_domain_df['name'].value_counts().to_dict()
print(f"Total unique Pfam domains: {len(occurrences_pfam)}")

# Filter domains occurring at least 100 times
occurrences_pfam_filtered = {k: v for k, v in occurrences_pfam.items() if v >= 100}
print(f"Domains with >=100 occurrences: {len(occurrences_pfam_filtered)}")

# List of filtered domain names
domains_to_check = list(occurrences_pfam_filtered.keys())
print(f"Filtered domains to check: {domains_to_check}")
print(f"Number of filtered domains: {len(domains_to_check)}")

# Sort filtered domains by occurrences in descending order
# NOTE(review): value_counts() presumably already returns counts in descending
# order, so this sort mainly makes the ordering explicit — confirm.
sorted_domains = dict(sorted(occurrences_pfam_filtered.items(), key=lambda item: item[1], reverse=True))

# Separate keys and values for plotting
domain_names = list(sorted_domains.keys())
domain_counts = list(sorted_domains.values())

# Plot bar chart of occurrences (displayed only, not saved to disk)
plt.figure(figsize=(12, 6))
plt.bar(domain_names, domain_counts, color='skyblue')
plt.xlabel('Domain Names')
plt.ylabel('Occurrences')
plt.title('Occurrences of Pfam Domains with >=100 Counts')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()


In [None]:
column_name = "GO_identifiers"  # Replace with your column name
rows_with_empty_list_or_array = pfam_annotated_domain_df[column_name].apply(
    lambda x: isinstance(x, (list, np.ndarray)) and len(x) == 0
).sum()
print(f"Rows with empty lists in column '{column_name}':", rows_with_empty_list_or_array)

# Load GO terms related to nucleic acid binding (one term per line)
with open(curr_wd + '/data/external/InterPro/' + 'GO_terms_NAbinding.txt', "r") as f:
    list_of_GO_terms_withNAbinding = f.read().strip().split("\n")

# Output: protein ID -> dict of per-protein structure metrics
protein_dict_with_domain_metrics = {}

# Iterate over proteins in the full proteome of the active set version
# (uses `ver` rather than a hard-coded version string)
full_proteome = proteins_sets_dict[ver]['full_proteome']
for i, curr_protein in enumerate(full_proteome):
    # Look the motif rows up once: they act as the inclusion filter and are
    # reused for the IDR/motif overlap labelling below.
    motif_rows = motif_info_set_df[motif_info_set_df["UniqueID"] == curr_protein]
    if motif_rows.empty:
        # Skip if no motif info available
        continue

    print(f"{i} out of {len(full_proteome)} — {curr_protein}")

    # Get IDR annotation if available
    curr_IDR_rows = annotated_IDR_df.loc[
        annotated_IDR_df["protein_name"] == curr_protein,
        "prediction-disorder-mobidb_lite"
    ]
    if curr_IDR_rows.empty:
        list_of_IDRs = []
        list_of_IDRs_unchanged = []
    else:
        # Only the first matching row is used; its value is converted with
        # .tolist(), so it is expected to be array-like.
        curr_IDR_info = curr_IDR_rows.tolist()[0].tolist()

        # `_unchanged` keeps the plain "IDR" labels; the working copy gets
        # relabelled below. The tuples are immutable, so a shallow copy suffices.
        list_of_IDRs_unchanged = count_consecutive_stretches_of_1(curr_IDR_info)
        list_of_IDRs = list(list_of_IDRs_unchanged)

        # Label IDRs overlapping with motifs
        # NOTE(review): the overlap test only checks whether a motif endpoint
        # falls inside the IDR (bounds inclusive, while the IDR end index is
        # exclusive), so a motif enclosing a whole IDR is not detected —
        # confirm this is intended, and that motif coordinates use the same
        # indexing as the IDR array positions.
        for idx, idr_region in enumerate(list_of_IDRs):
            for _, motif in motif_rows[["start", "end"]].iterrows():
                if is_either_between(idr_region[0], idr_region[1], motif['start'], motif['end']):
                    list_of_IDRs[idx] = (idr_region[0], idr_region[1], "MOTIF")
                    break

    # Get domain annotations
    curr_domain_info = pfam_annotated_domain_df[
        pfam_annotated_domain_df['protein_name'] == curr_protein
    ]

    list_of_domains = []    # (start, end, domain name)
    list_of_domains_R = []  # (start, end, "NABIND" | "DOMAIN")
    for _, row in curr_domain_info.iterrows():
        list_of_domains.append((row['start'], row['end'], row['name']))

        # Annotate based on NA-binding GO terms
        if set(row['GO_identifiers']).intersection(list_of_GO_terms_withNAbinding):
            list_of_domains_R.append((row['start'], row['end'], "NABIND"))
        else:
            list_of_domains_R.append((row['start'], row['end'], "DOMAIN"))

    # Combine and sort domains and IDRs by start position
    combined_features = sorted(list_of_domains_R + list_of_IDRs, key=lambda x: x[0])
    combined_features_old = sorted(list_of_domains_R + list_of_IDRs_unchanged, key=lambda x: x[0])

    # Store results
    protein_dict_with_domain_metrics[curr_protein] = {
        "num_of_IDR": len(list_of_IDRs),
        "structure_string": make_struct_string(combined_features),
        "structure_string_old": make_struct_string(combined_features_old),
        "num_domains": len(list_of_domains),
        "domains": list_of_domains,
        "IDR_bounds": list_of_IDRs
    }

# Convert to DataFrame (one row per protein)
proteins_with_domain_metrics_df = pd.DataFrame(protein_dict_with_domain_metrics).transpose()
proteins_with_domain_metrics_df.reset_index(inplace=True)
proteins_with_domain_metrics_df.rename(columns={'index': 'proteins'}, inplace=True)

# Annotate group (pos/neg)
proteins_with_domain_metrics_df['Group'] = proteins_with_domain_metrics_df['proteins'].apply(
    assign_group, args=(pos_prot_list, "pos", neg_prot_list, "neg")
)

# Create reduced structure strings
proteins_with_domain_metrics_df['reduced_struct_string'] = proteins_with_domain_metrics_df['structure_string'].apply(merge_duplicates)
proteins_with_domain_metrics_df['reduced_struct_string_old'] = proteins_with_domain_metrics_df['structure_string_old'].apply(merge_duplicates)
proteins_with_domain_metrics_df['categ_reduced_struct_string_old'] = proteins_with_domain_metrics_df['reduced_struct_string_old'].apply(cat_string)

# Convert dtypes
proteins_with_domain_metrics_df = proteins_with_domain_metrics_df.infer_objects()
print(proteins_with_domain_metrics_df.dtypes)

# Persist the result for the following cells
output_path = f"{curr_wd}/data/results/proteins_with_domain_metrics_df.pkl"
with open(output_path, "wb") as fp:
    pickle.dump(proteins_with_domain_metrics_df, fp)
    print("DataFrame saved successfully to file.")


In [None]:
# Reload the metrics table written by the previous cell.
# NOTE: pickle must only be used on trusted files; this one is produced by
# this notebook itself.
with open(f"{curr_wd}/data/results/proteins_with_domain_metrics_df.pkl", 'rb') as fp:
    proteins_with_domain_metrics_df = pickle.load(fp)

# Number of domains per protein; anything that is not a list counts as 0
proteins_with_domain_metrics_df["domains_number"] = proteins_with_domain_metrics_df['domains'].map(
    lambda x: len(x) if isinstance(x, list) else 0
)

# Distribution of domain counts in the positive group, then the negative group
print(
    proteins_with_domain_metrics_df.loc[
        proteins_with_domain_metrics_df["Group"] == 'pos', "domains_number"
    ].value_counts()
)
print(
    proteins_with_domain_metrics_df.loc[
        proteins_with_domain_metrics_df["Group"] == 'neg', "domains_number"
    ].value_counts()
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import colors

def conditional_autopct(pct, all_values):
    """
    Format a pie-wedge label as an absolute count, hiding very small wedges.

    Parameters:
        pct (float): Percentage of the total represented by the wedge.
        all_values (iterable): All wedge values; their sum is the total.

    Returns:
        str: The rounded absolute count if it is greater than 1, else "".
    """
    total = sum(all_values)
    absolute = int(round(pct * total / 100.0))
    if absolute > 1:
        return f"{absolute}"
    return ""

# Display labels for the domain-count categories (key order = category order)
custom_label_dict = {
    0: "no domains",
    1: "1 domain",
    **{n: f"{n} domains" for n in (2, 3, 4)},
    "more than 4": "≥ 5 domains",
}


# --- Prepare data ---
def preprocess_group_data(df, group_name):
    """
    Summarise how many proteins of one group have 0, 1, ... domains.

    Parameters:
        df (DataFrame): Table with "Group" and "domains_number" columns.
        group_name (str): Value of "Group" to select.

    Returns:
        tuple: (counts, labels, n) where counts is a Series of protein counts
        per category (domain counts above 4 pooled under "more than 4"),
        labels are the display labels for the categories in the same order,
        and n is the number of proteins in the group.
    """
    members = df.loc[df["Group"] == group_name]
    per_count = members["domains_number"].value_counts().sort_index()

    def bucket(n_domains):
        return "more than 4" if n_domains > 4 else n_domains

    bucketed = per_count.groupby(bucket).sum()
    display_labels = [custom_label_dict.get(key, key) for key in bucketed.index]
    return bucketed, display_labels, len(members)

data_pos, labels_pos, pos_len = preprocess_group_data(proteins_with_domain_metrics_df, "pos")
data_neg, labels_neg, neg_len = preprocess_group_data(proteins_with_domain_metrics_df, "neg")

# --- Plotting ---
fig, axes = plt.subplots(1, 2, figsize=(8, 5))
palette = sns.color_palette("OrRd", 7)
palette[0] = 'lightgrey'

# Fixed category -> colour assignment. Colouring wedges by position would give
# the same category different colours in the two donuts (and a wrong legend)
# whenever one group lacks a category that the other one has.
category_order = list(custom_label_dict)
category_colors = dict(zip(category_order, palette))


def plot_donut(ax, data, labels, title_text):
    """
    Draw one donut chart of domain-count categories.

    Parameters:
        ax (Axes): Axes to draw on.
        data (Series): Counts per category; the index holds the category keys.
        labels (list): Display labels for the categories (not drawn on the
                       wedges; kept for interface compatibility).
        title_text (str): Text placed in the centre of the donut.

    Returns:
        list: The wedge patches, in the order of `data`.
    """
    wedge_colors = [category_colors.get(cat, palette[-1]) for cat in data.index]
    wedges, _, autotexts = ax.pie(
        data.values,
        colors=wedge_colors,
        autopct=lambda pct: conditional_autopct(pct, data.values),
        startangle=90,
        wedgeprops=dict(width=0.4)
    )
    ax.text(0, 0, title_text, ha='center', va='center', fontsize=12)
    # Push the count labels outwards so they sit outside the ring
    for autotext in autotexts:
        pos = autotext.get_position()
        autotext.set_position((pos[0] * 1.3, pos[1] * 1.3))
        autotext.set_fontsize(12)
    return wedges


# Positive group plot
wedges_pos = plot_donut(axes[0], data_pos, labels_pos, f"positive group\nn={pos_len}")

# Negative group plot
wedges_neg = plot_donut(axes[1], data_neg, labels_neg, f"negative group\nn={neg_len}")

# Legend: one entry per category present in either group, in category order
legend_handles = {}
for group_data, group_wedges in ((data_pos, wedges_pos), (data_neg, wedges_neg)):
    for cat, wedge in zip(group_data.index, group_wedges):
        legend_handles.setdefault(cat, wedge)
legend_categories = [cat for cat in category_order if cat in legend_handles]
axes[1].legend(
    [legend_handles[cat] for cat in legend_categories],
    [custom_label_dict.get(cat, cat) for cat in legend_categories],
    title="# of domains", loc="center left", bbox_to_anchor=(1, 0.5),
    fontsize=11, title_fontsize=12
)

plt.tight_layout()
os.makedirs(os.path.join(curr_wd, "data/results/subfigures/"), exist_ok=True)
plt.savefig(os.path.join(curr_wd, "data/results/subfigures/fig2_A_B.svg"), transparent=True)

plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# --- Utility functions ---
def extract_domains(df, group_name, domain_GO_dict):
    """
    Build a one-row-per-domain table for a specific group (pos/neg).

    Parameters:
        df (DataFrame): Table with a "Group" column and a "domains" column
                        holding lists of (start, end, name) tuples.
        group_name (str): Value of "Group" to select.
        domain_GO_dict (dict): Domain name -> GO terms.

    Returns:
        DataFrame: The exploded rows of the group that carry a domain, with
        the added columns "domain_name" and "GO_terms" ("no GO term" for
        domains missing from domain_GO_dict).
    """
    def name_of(entry):
        # Exploded rows without a domain do not hold a tuple
        return entry[2] if isinstance(entry, tuple) else "no domain"

    def go_terms_of(name):
        return domain_GO_dict.get(name, "no GO term")

    per_domain = df.loc[df["Group"] == group_name].explode("domains").reset_index(drop=True)
    per_domain["domain_name"] = per_domain["domains"].apply(name_of)
    per_domain["GO_terms"] = per_domain["domain_name"].apply(go_terms_of)
    has_domain = per_domain["domain_name"] != "no domain"
    return per_domain[has_domain]


def simple_barplot(data_main, data_ref, mode, min_apps=10, label_dict=None):
    """
    Draw and save a horizontal bar chart of domain appearance counts.

    The figure is written to data/results/subfigures/fig2_C.svg (mode
    'positive') or fig2_D.svg (mode 'negative') below `curr_wd` and shown.

    Parameters:
        data_main (dict): Domain name -> number of appearances to plot.
        data_ref (dict): Counts of the other group; currently not drawn,
                         kept for interface compatibility.
        mode (str): 'positive' or 'negative'; selects bar colour and output file.
        min_apps (int): Minimum number of appearances for a domain to be shown.
        label_dict (dict, optional): Domain name -> shorter display label.

    Returns:
        None

    Raises:
        ValueError: If mode is neither 'positive' nor 'negative'.
    """
    if label_dict is None:
        label_dict = {}

    # Color settings
    if mode == 'positive':
        main_color = "#8DB600"
    elif mode == 'negative':
        main_color = "#FF4040"
    else:
        raise ValueError("Mode must be 'positive' or 'negative'.")

    # Filter by minimum count
    filtered = {k: v for k, v in data_main.items() if v >= min_apps}
    if not filtered:
        # A figure of height 0 cannot be created; report instead of crashing.
        print(f"No domain reaches {min_apps} appearances in mode '{mode}'; nothing to plot.")
        return
    keys = list(filtered)
    values = list(filtered.values())

    plt.figure(figsize=(3.5, len(keys) * 0.35))

    y_positions = np.arange(len(keys))
    plt.barh(y_positions, values, color=main_color, edgecolor='black', height=0.75, alpha=0.75)

    # Write each label inside the plot area, near the left edge of its bar
    for y, key in enumerate(keys):
        label = label_dict.get(key, key)
        plt.text(plt.gca().get_xlim()[1] * 0.01, y, label, ha='left', va='center', fontsize=10)

    # Aesthetic settings
    plt.xlabel('# of domain appearances')
    plt.yticks([])  # Hide y-axis ticks
    if mode == "negative":
        # Fixed tick positions for the negative-group figure
        plt.xticks(ticks=range(0, 15, 2))

    plt.grid(axis='x', linestyle='--', zorder=0)
    plt.gca().set_axisbelow(True)
    plt.gca().invert_yaxis()
    # First bar on top, with extra empty space below the last bar
    plt.ylim(len(keys) + 1.5, -0.5)

    plt.tight_layout()
    os.makedirs(os.path.join(curr_wd, "data/results/subfigures/"), exist_ok=True)
    if mode == "positive":
        plt.savefig(os.path.join(curr_wd, "data/results/subfigures/fig2_C.svg"), transparent=True)
    else:
        plt.savefig(os.path.join(curr_wd, "data/results/subfigures/fig2_D.svg"), transparent=True)
    plt.show()

# --- Main processing ---
# Define custom label(s) if needed
# (maps a full domain name to the shorter text drawn next to its bar)
label_dict = {
    "Serine-threonine/tyrosine-protein kinase, catalytic domain": "Ser-Thr/Tyr protein kinase"
}

# Prepare domain data for both groups (one row per domain occurrence)
data_pos_nod = extract_domains(proteins_with_domain_metrics_df, 'pos', domain_GO_dict)
data_neg_nod = extract_domains(proteins_with_domain_metrics_df, 'neg', domain_GO_dict)

# Generate value count dictionaries (domain name -> number of occurrences)
counts_pos = data_pos_nod["domain_name"].value_counts().to_dict()
counts_neg = data_neg_nod["domain_name"].value_counts().to_dict()

# --- Plotting ---
# Only domains with at least 5 appearances in the plotted group are shown;
# the first call writes fig2_C.svg, the second fig2_D.svg.
simple_barplot(counts_pos, counts_neg, mode="positive", min_apps=5, label_dict=label_dict)
simple_barplot(counts_neg, counts_pos,  mode="negative", min_apps=5, label_dict=label_dict)


In [None]:
#### GO terms preparation
# Expand each domain row into one row per GO term, then report how many
# distinct GO terms occur in the positive and in the negative group.
data_pos_nog = data_pos_nod.explode("GO_terms")
print(data_pos_nog["GO_terms"].value_counts().size)

data_neg_nog = data_neg_nod.explode("GO_terms")
print(data_neg_nog["GO_terms"].value_counts().size)


def simple_barplot_GO(cdp, cdn, mode, min_apps=10, label_dict=None):
    """
    Draw and save a horizontal bar chart of GO term appearance counts.

    Each bar is labelled with the GO term name from `GO_ID_dict` followed by
    its ontology aspect, which is looked up through `get_go_term_details`
    (one web request per bar). The figure is written to
    data/results/subfigures/fig2_E.svg (mode 'positive') or fig2_F.svg
    (mode 'negative') below `curr_wd` and shown.

    Parameters:
        cdp (dict): GO term -> number of appearances in the positive group.
        cdn (dict): GO term -> number of appearances in the negative group.
        mode (str): 'positive' plots cdp, 'negative' plots cdn.
        min_apps (int): Minimum number of appearances for a term to be shown.
        label_dict (dict, optional): Currently unused; kept for interface
                                     compatibility.

    Returns:
        None

    Raises:
        ValueError: If mode is neither 'positive' nor 'negative'.
    """
    if label_dict is None:
        label_dict = {}

    if mode == 'positive':
        data_dict = cdp
        main_color = "#8DB600"
    elif mode == "negative":
        data_dict = cdn
        main_color = '#FF4040'
    else:
        raise ValueError("Mode must be 'positive' or 'negative'")

    # Filter GO terms by minimum count
    filtered_data = {key: value for key, value in data_dict.items() if value >= min_apps}
    if not filtered_data:
        # A figure of height 0 cannot be created; report instead of crashing.
        print(f"No GO term reaches {min_apps} appearances in mode '{mode}'; nothing to plot.")
        return
    keys = list(filtered_data)
    values = list(filtered_data.values())

    # Suffix appended to the label for each ontology aspect
    namespace_suffix = {
        "biological_process": " (BP)",
        "cellular_component": " (CC)",
        "molecular_function": " (MF)",
    }

    plt.figure(figsize=(3.5, len(keys) * 0.35))

    # Plot main bars
    plt.barh(
        keys,
        values,
        color=main_color,
        edgecolor='black',
        label='Main Values',
        height=0.75,
        alpha=0.75
    )

    # Annotate each bar
    for idx, key in enumerate(keys):
        label = GO_ID_dict.get(key, key)
        if label == "regulation of DNA-templated transcription":
            label = "reg. of DNA-templ. transcription"

        # "no GO term" is the placeholder for domains without GO annotation;
        # it is not a GO identifier, so no lookup is attempted for it.
        category = None if key == "no GO term" else get_go_term_details(key)
        if category is None:
            label = "no functional annotation"
        else:
            label += namespace_suffix.get(category["namespace"], "")

        plt.text(
            plt.gca().get_xlim()[1] * 0.01,
            idx,
            label,
            ha='left',
            va='center',
            rotation=0
        )

    # Plot formatting
    plt.xlabel('# of domain appearances')
    plt.yticks(ticks=[], labels=[])
    plt.grid(axis='x', linestyle='--', zorder=0)
    plt.gca().set_axisbelow(True)
    plt.gca().invert_yaxis()

    os.makedirs(os.path.join(curr_wd, "data/results/subfigures/"), exist_ok=True)
    if mode == "positive":
        plt.savefig(os.path.join(curr_wd, "data/results/subfigures/fig2_E.svg"), transparent=True)
    else:
        plt.savefig(os.path.join(curr_wd, "data/results/subfigures/fig2_F.svg"), transparent=True)
    plt.show()


# Run plots
# Both calls receive the GO term counts of the positive group first and of the
# negative group second; `mode` selects which of the two is drawn.
# Only GO terms with at least 10 appearances are shown.
simple_barplot_GO(
    data_pos_nog.GO_terms.value_counts().to_dict(),
    data_neg_nog.GO_terms.value_counts().to_dict(),
    mode="positive",
    min_apps=10,
    label_dict=label_dict
)

simple_barplot_GO(
    data_pos_nog.GO_terms.value_counts().to_dict(),
    data_neg_nog.GO_terms.value_counts().to_dict(),
    mode="negative",
    min_apps=10,
    label_dict=label_dict
)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def _collect_domain_motif_distances(group_df, motif_info_df):
    """Return {domain name: (regular distances, singular distances)} for one group.

    `group_df` holds one row per domain occurrence ("proteins", "domains").
    A distance is "singular" when its protein has exactly one domain row and
    exactly one motif; all other distances are "regular".
    """
    # Index the motifs and the per-protein domain counts once, instead of
    # re-filtering the full tables for every domain row.
    group_motifs = motif_info_df[motif_info_df["UniqueID"].isin(group_df["proteins"])]
    motifs_by_protein = {
        protein_id: list(zip(rows["start"], rows["end"]))
        for protein_id, rows in group_motifs.groupby("UniqueID")
    }
    domains_per_protein = group_df["proteins"].value_counts().to_dict()

    distances = {}
    for protein_id, domain in zip(group_df["proteins"], group_df["domains"]):
        regular, singular = distances.setdefault(domain[2], ([], []))
        motifs = motifs_by_protein.get(protein_id, [])
        is_singular = domains_per_protein[protein_id] == 1 and len(motifs) == 1
        target = singular if is_singular else regular
        for motif in motifs:
            target.append(calculate_domain_motif_distance(domain[:2], motif, "c"))
    return distances


def plot_motif_distances(df, motif_info_df, domain_GO_dict, view_limit, label_dict=None, show_singular_domains=False):
    """
    Plot, per domain, the center-to-center distances between domain and motifs.

    One panel is drawn for the "pos" group and one for the "neg" group. Each
    row is a domain (plus an "All domains combined" row); only rows with more
    than 9 distances inside +/- view_limit are shown, ordered by that number.
    The figure is written to data/results/subfigures/fig2_G.svg below
    `curr_wd` and shown.

    Parameters:
        df (DataFrame): Table with "Group", "proteins" and "domains" columns,
                        "domains" holding lists of (start, end, name) tuples.
        motif_info_df (DataFrame): Motif table with "UniqueID", "start" and
                                   "end" columns.
        domain_GO_dict (dict): Currently unused; kept for interface
                               compatibility.
        view_limit (int): Half-width of the x-axis; also the cut-off used when
                          counting distances for row selection.
        label_dict (dict, optional): Domain name -> shorter y-axis label.
        show_singular_domains (bool): If True, distances from proteins with
            exactly one domain and one motif are drawn as crosses instead of
            dots.

    Returns:
        None
    """
    if label_dict is None:
        label_dict = {}

    group_labels = ["pos", "neg"]
    data_by_group = {}
    y_tick_counts = []

    for group_label in group_labels:
        filtered_df = (
            df[df["Group"] == group_label]
            .explode("domains")
            .dropna(subset=['domains'])
            .reset_index(drop=True)
        )

        domain_motif_distances = _collect_domain_motif_distances(filtered_df, motif_info_df)

        # Pooled row; built before it is inserted so it does not include itself.
        all_combined = (
            [d for regular, _ in domain_motif_distances.values() for d in regular],
            [d for _, singular in domain_motif_distances.values() for d in singular],
        )
        domain_motif_distances["All domains combined"] = all_combined

        # Number of distances per domain that fall inside the viewing window
        points_in_view = {
            dom: sum(abs(x) <= view_limit for x in regular) + sum(abs(x) <= view_limit for x in singular)
            for dom, (regular, singular) in domain_motif_distances.items()
        }

        # Sort by number of visible data points (stable for ties) and keep
        # only domains with more than 9 of them
        selected_domains = [
            dom
            for dom in sorted(points_in_view, key=points_in_view.get, reverse=True)
            if points_in_view[dom] > 9
        ]

        y_tick_counts.append(len(selected_domains))
        data_by_group[group_label] = {
            "domain_motif_distances": domain_motif_distances,
            "selected_domains": selected_domains
        }

    # Plot setup. A panel without any selected domain still gets one row of
    # height, which avoids a division by zero and a zero-height figure.
    panel_rows = [max(count, 1) for count in y_tick_counts]
    height_ratios = [rows / sum(panel_rows) for rows in panel_rows]
    fig, axes = plt.subplots(
        2, 1,
        figsize=(10, sum(panel_rows) / 3.25),
        sharex=True,
        gridspec_kw={'height_ratios': height_ratios}
    )

    for ax, group_label in zip(axes, group_labels):
        group_data = data_by_group[group_label]
        domains = group_data["selected_domains"]
        distances = group_data["domain_motif_distances"]

        domain_idx_map = {domain: idx for idx, domain in enumerate(domains)}
        tab_colors = plt.get_cmap("tab10").colors

        # Shade every other row to make the rows easier to follow
        for i in range(0, len(domains), 2):
            ax.axhspan(i - 0.5, i + 0.5, color='lightgrey', alpha=0.5)

        for idx, domain in enumerate(domains):
            is_combined = domain == "All domains combined"
            color = "black" if is_combined else tab_colors[idx % len(tab_colors)]
            dotalpha = 0.25 if is_combined else 0.5
            dotsize = 50 if is_combined else 175
            crosssize = 25 if is_combined else 75

            x_vals, special_x_vals = distances[domain]
            # Random vertical jitter so overlapping points stay visible
            y_vals = [idx + np.random.uniform(-0.30, 0.30) for _ in x_vals]
            special_y_vals = [idx + np.random.uniform(-0.30, 0.30) for _ in special_x_vals]

            ax.scatter(x_vals, y_vals, alpha=dotalpha, label=domain, color=color, s=dotsize, marker=".", edgecolor="black")

            if show_singular_domains:
                ax.scatter(special_x_vals, special_y_vals, marker='x', color=color, alpha=1.0, s=crosssize)
            else:
                ax.scatter(special_x_vals, special_y_vals, alpha=dotalpha, label=domain, color=color, s=dotsize, marker=".", edgecolor="black")

        ax.set_yticks(list(domain_idx_map.values()))
        ax.set_yticklabels([label_dict.get(dom, dom) for dom in domain_idx_map], fontsize=11)
        ax.grid(True, axis='x', linestyle='--', alpha=0.5)
        ax.set_xlim(-view_limit, view_limit)
        ax.set_ylim(-0.5, len(domains) - 0.5)

    axes[1].set_xlabel('# of residues from the domain to the RG-motif', fontsize=11)
    axes[0].text(view_limit * 1.03, y_tick_counts[0] // 2, "positive", rotation=90, fontsize=11, fontweight="bold", va="center")
    axes[1].text(view_limit * 1.03, y_tick_counts[1] // 2, "negative", rotation=90, fontsize=11, fontweight="bold", va="center")

    fig.text(-0.01, 0.5, "Domains", va='center', rotation='vertical', fontsize=12, fontweight="bold")
    plt.tight_layout()
    os.makedirs(os.path.join(curr_wd, "data/results/subfigures/"), exist_ok=True)
    plt.savefig(os.path.join(curr_wd, "data/results/subfigures/fig2_G.svg"), transparent=True)
    plt.show()

# Example call
# Shorter y-axis label for one long domain name
y_label_dict = {
    'Serine-threonine/tyrosine-protein kinase, catalytic domain': "Ser-Thr/Tyr-protein kinase, catalytic"
}

# Draws the distance plot for the positive and negative group, limited to
# +/- 1250 on the x-axis, and writes fig2_G.svg.
plot_motif_distances(
    proteins_with_domain_metrics_df,
    motif_info_set_df,
    domain_GO_dict,
    view_limit=1250,
    label_dict=y_label_dict,
    show_singular_domains=True
)
