# Instructions

1. **Set file paths and options** in the **Setup** cell:
   - `iptm_file_path`: Path to the IPTM vs. PEAK file (required).
   - `spoc_file_path`: Path to the SPOC score file (optional).
   - `SPOC_analysis`: Set to `True` if you want to do SPOC-based analysis (requires a valid SPOC file), otherwise `False`.
   - `output_dir`: Where to save charts and selected data (defaults to creating an "analysis" folder next to your IPTM file).

2. **Run the notebook cells in order**:
   - The second cell loads the IPTM data and checks whether to proceed with SPOC or basic analysis.
   - If SPOC analysis is enabled and the file is provided, the subsequent cells will merge data and show the SPOC-based chart.
   - Otherwise, you'll see the basic IPTM vs. PEAK chart.

3. **Interact with the charts**:
   - Use **Lasso/Box select** to label points persistently.
   - Use the **Search** widget to highlight points by partial name.
   - **Clear** labels or search highlights as needed.
   - **Save** the plot as HTML/PDF or **export** selected data as a CSV.

4. **Check the output directory** for your saved files.
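
The options from step 1 can be sanity-checked before the remaining cells run. The sketch below shows one possible way to do that. `check_setup` is a hypothetical helper that is not part of this notebook; it only assumes the four Setup variables listed above and mirrors the documented default of an "analysis" folder next to the IPTM file.

```python
import os

def check_setup(iptm_file_path, spoc_file_path=None, SPOC_analysis=False, output_dir=None):
    """Hypothetical helper: validate the Setup options and resolve the output folder."""
    if not os.path.isfile(iptm_file_path):
        raise FileNotFoundError(f"IPTM file not found: {iptm_file_path}")

    # SPOC analysis only makes sense when a readable SPOC file is available.
    if SPOC_analysis and (spoc_file_path is None or not os.path.isfile(spoc_file_path)):
        print("SPOC file missing; falling back to the basic analysis.")
        SPOC_analysis = False

    # Default: an "analysis" folder next to the IPTM file.
    if output_dir is None:
        base = os.path.dirname(os.path.abspath(iptm_file_path))
        output_dir = os.path.join(base, "analysis")

    return SPOC_analysis, output_dir
```

Calling `check_setup(iptm_file_path, spoc_file_path, SPOC_analysis)` returns the effective SPOC flag and the resolved output folder, so a missing SPOC file degrades to the basic chart instead of failing later.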

In [6]:
# === STEP 1: BASIC SETUP ===

import os
import pandas as pd
import plotly.express as px
import ipywidgets as widgets
from IPython.display import display, Markdown
from plotly.graph_objs import FigureWidget

# ---------------- USER INPUTS ----------------
# Required: path to the IPTM vs. PEAK file
iptm_file_path = "IPTM_vs_PTM.txt"


# Optional: path to the SPOC file
# Set to None if not available
spoc_file_path = "spoc_dir_SPOC_analysis.csv"

# Boolean flag indicating whether you want to do SPOC analysis
SPOC_analysis = True  # or False

# Output directory (default is a subfolder 'analysis' next to the IPTM file)
# If you want to override, set output_dir = "/your/desired/output"
default_base = os.path.dirname(iptm_file_path)  # Folder of the IPTM file
default_out = os.path.join(default_base, "analysis")
output_dir = default_out


### Set hover fields and fetch protein complex association data

In [7]:
# ==========================
# A) USER CONFIGURATION FOR THRESHOLD-BASED FETCH
# ==========================
# Dictionary mapping each column to the threshold at or above which complex info will be fetched.
# e.g. if "IPTM_max" >= 0.5 OR "scaled_PEAKavg" >= 0.75, fetch complex info.
# If a column doesn't exist or its value is NaN, it won't trigger a fetch.
FETCH_COLUMNS = {
    "IPTM_max": 0.5,
    "scaled_PEAKavg": 0.75
}

print(
    f"Will fetch complex info if ANY of these conditions are met:\n"
    + "\n".join([f"  {col} >= {thresh}" for col, thresh in FETCH_COLUMNS.items()])
)

import ipywidgets as widgets
from IPython.display import display

fetch_toggle = widgets.ToggleButtons(
    options=['Fetch Complex Data', 'Skip Complex Data'],
    description='Complex Info:'
)

display(fetch_toggle)


Will fetch complex info if ANY of these conditions are met:
  IPTM_max >= 0.5
  scaled_PEAKavg >= 0.75


ToggleButtons(description='Complex Info:', options=('Fetch Complex Data', 'Skip Complex Data'), value='Fetch C…
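
The "ANY of these conditions" rule printed above can be illustrated on a few rows. The following is a minimal sketch, not the notebook's own implementation: `passes_any_threshold` is a hypothetical vectorized helper, and the toy values are invented purely for illustration.

```python
import numpy as np
import pandas as pd

FETCH_COLUMNS = {"IPTM_max": 0.5, "scaled_PEAKavg": 0.75}

# Toy data, invented for illustration only.
demo = pd.DataFrame({
    "IPTM_max": [0.62, 0.31, np.nan, 0.10],
    "scaled_PEAKavg": [0.20, 0.80, 0.90, np.nan],
})

def passes_any_threshold(df, fetch_config):
    """Boolean mask: True where any configured column meets its threshold."""
    mask = pd.Series(False, index=df.index)
    for col, threshold in fetch_config.items():
        if col in df.columns:
            # Comparisons with NaN evaluate to False, so missing values never trigger.
            mask |= df[col] >= threshold
    return mask

demo["fetch"] = passes_any_threshold(demo, FETCH_COLUMNS)
print(demo)
```

Only the last row fails both conditions; the third row still qualifies through `scaled_PEAKavg` even though its `IPTM_max` is missing.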

In [8]:
# Cell 2: IPTM & SPOC Data Loading, Complex Portal Retrieval, and Hover Widget Setup
# ==========================

import os
import requests
import json
import numpy as np
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, Markdown
import concurrent.futures

# 1) Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)

# 2) Print diagnostic info
print("IPTM file path   :", iptm_file_path)
print("SPOC file path   :", spoc_file_path)
print("SPOC_analysis    :", SPOC_analysis)
print("Output directory :", output_dir)


# We will check fetch_toggle.value later to see if the user selected "Skip".

# ==========================
# B) Class & Function Definitions
# ==========================

# --------------------------
# B1) ProteinComplexInfo class
# --------------------------
class ProteinComplexInfo:
    def __init__(self, uniprot_id, uniprot_cache, complex_cache):
        """
        Initialize with a UniProt ID and retrieve its JSON record
        using the provided caches.
        """
        self.uniprot_id = uniprot_id
        self.uniprot_cache = uniprot_cache
        self.complex_cache = complex_cache
        
        # Use a cached fetch
        self.uniprot_data = self.get_uniprot_data(uniprot_id)
        self.complex_ids = self.extract_complex_ids(self.uniprot_data) if self.uniprot_data else []

    def get_uniprot_data(self, uniprot_id):
        if uniprot_id in self.uniprot_cache:
            return self.uniprot_cache[uniprot_id]
        
        url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json"
        try:
            response = requests.get(url, timeout=30)  # avoid hanging on a stalled connection
            response.raise_for_status()
            data = response.json()
            # Store in cache
            self.uniprot_cache[uniprot_id] = data
            return data
        except requests.exceptions.RequestException as e:
            print(f"Error retrieving data for {uniprot_id}: {e}")
            self.uniprot_cache[uniprot_id] = None
            return None

    def extract_complex_ids(self, uniprot_data):
        complexes = []
        for xref in uniprot_data.get("uniProtKBCrossReferences", []):
            if xref.get("database") == "ComplexPortal":
                comp_id = xref.get("id", "N/A")
                complexes.append(comp_id)
        return complexes

    def get_complex_details_json(self, complex_id):
        if complex_id in self.complex_cache:
            return self.complex_cache[complex_id]

        url = f"https://www.ebi.ac.uk/intact/complex-ws/complex/{complex_id}"
        headers = {"accept": "application/json"}
        try:
            response = requests.get(url, headers=headers, timeout=30)  # avoid hanging on a stalled connection
            response.raise_for_status()
            data = response.json()
            self.complex_cache[complex_id] = data
            return data
        except requests.exceptions.RequestException as e:
            print(f"Error retrieving details for complex {complex_id}: {e}")
            self.complex_cache[complex_id] = None
            return None
        except json.decoder.JSONDecodeError as e:
            print(f"JSON decode error for complex {complex_id}: {e}")
            self.complex_cache[complex_id] = None
            return None

    def get_all_complex_info(self):
        complex_info_list = []
        for comp_id in self.complex_ids:
            details = self.get_complex_details_json(comp_id)
            if details:
                name = details.get("name", "No name provided")
                functions = details.get("functions", [])
                complex_info_list.append({
                    "complex_id": comp_id,
                    "name": name,
                    "functions": functions
                })
        return complex_info_list
    
# --------------------------
# B2) Helper to retrieve complex info from a UniProt ID
# --------------------------
def get_complex_info_from_target(target_uniprot_id, uniprot_cache, complex_cache):
    pci = ProteinComplexInfo(target_uniprot_id, uniprot_cache, complex_cache)
    complex_info_list = pci.get_all_complex_info()
    if complex_info_list:
        # For simplicity, take the first complex if multiple exist
        first_complex = complex_info_list[0]
        complex_name = first_complex.get("name", "")
        functions = first_complex.get("functions", [])
        functions_str = "; ".join(functions) if functions else ""
        return (complex_name, functions_str)
    else:
        return (None, None)

# --------------------------
# B3) Parse target UniProt ID and short name from the "NAME" field
# --------------------------
def parse_target_info(full_name):
    """
    Given a string like:
      "76_sp-Q92610-ZN592_HUMAN_vs_sp-Q13889-TF2H3_HUMAN"
    return (target_uniprot_id, name).

    Example: ("Q13889", "TF2H3")
    """
    if pd.isnull(full_name):
        return (None, None)
    try:
        parts = full_name.split("_vs_")
        if len(parts) < 2:
            return (None, None)
        target_part = parts[1]  # e.g., "sp-Q13889-TF2H3_HUMAN"
        chunks = target_part.split("-")
        if len(chunks) < 3:
            return (None, target_part)
        target_uniprot_id = chunks[1]  # e.g., Q13889
        name = chunks[2].split("_")[0]  # e.g., TF2H3
        return (target_uniprot_id, name)
    except Exception as e:
        print(f"Error parsing '{full_name}': {e}")
        return (None, None)


# --------------------------
# B4) Extract the max IPTM score from the "IPTM" column
# --------------------------
def extract_max_from_iptm(value):
    try:
        if pd.isna(value):
            return np.nan
        parts = str(value).split(":")
        nums = []
        for part in parts:
            try:
                nums.append(float(part))
            except ValueError:
                # Skip fragments that are not numeric
                continue
        return max(nums) if nums else np.nan
    except Exception:
        return np.nan


# --------------------------
# B5) Logic to decide if complex info should be fetched, based on user-defined thresholds
# --------------------------
def should_fetch_complex_info(row, fetch_config):
    """
    Returns True if ANY of the user-specified columns in `fetch_config`
    meets or exceeds its threshold. Otherwise returns False.

    - `fetch_config` is a dict: {column_name: threshold_value, ...}
    - If the column is missing or NaN, it won't trigger a fetch.
    """
    for col, threshold in fetch_config.items():
        if col in row and pd.notnull(row[col]):
            if row[col] >= threshold:
                return True
    return False


def fetch_complex_info_threshold(row):
    """
    Return (complex_name, complex_function) only if should_fetch_complex_info
    returns True. Otherwise, (None, None).
    """
    target_id = row["target_uniprot_id"]
    if pd.isnull(target_id):
        return pd.Series([None, None])

    # Check the user-defined thresholds
    if not should_fetch_complex_info(row, FETCH_COLUMNS):
        return pd.Series([None, None])

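    # The helper needs both caches (module-level dicts defined in section C)
    return pd.Series(get_complex_info_from_target(target_id, UNIPROT_CACHE, COMPLEX_DETAILS_CACHE))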

# ==========================
# C) Main Data Processing
# ==========================

# 1) Load the IPTM data
df_iptm = pd.read_csv(iptm_file_path, sep="\t")
print("Loaded IPTM DataFrame with shape:", df_iptm.shape)

# 2) Create an IPTM_max column
df_iptm["IPTM_max"] = df_iptm["IPTM"].apply(extract_max_from_iptm)
print("Created 'IPTM_max' column with the maximum IPTM score for each row.")

# 3) Merge with SPOC data if provided
if SPOC_analysis and spoc_file_path is not None:
    print("SPOC analysis is True, and a SPOC file is provided. Proceeding with SPOC-based code...")
    df_spoc = pd.read_csv(spoc_file_path)
    print("SPOC DataFrame shape:", df_spoc.shape)
    
    # Merge
    merged_df = pd.merge(
        df_iptm,
        df_spoc,
        left_on="NAME",
        right_on="complex_name",
        how="left"
    )
    print("Merged DataFrame shape:", merged_df.shape)
else:
    print("Either SPOC_analysis is False or no SPOC file provided.")
    print("Proceeding without SPOC merge (basic bubble chart).")
    merged_df = df_iptm.copy()


# Check whether the user chose to skip the complex lookup
if fetch_toggle.value == 'Skip Complex Data':
    print("Skipping fetch...")
    merged_df["complex"] = None
    merged_df["complex_info"] = None
else:
    print("Fetching data...")
    # 4) Let the user know we will fetch complex info
    msg = (
        "Fetching complex associations if any of these conditions are met:\n" +
        "\n".join([f"  - {col} >= {thresh}" for col, thresh in FETCH_COLUMNS.items()]) +
        "\nPlease be patient, or increase the thresholds to fetch complex info for fewer rows..."
    )
    print(msg)



# 5) Parse (target_uniprot_id, name) from "NAME"
merged_df[["target_uniprot_id", "name"]] = merged_df["NAME"].apply(
    lambda x: pd.Series(parse_target_info(x))
)

# Parallel downloading
import concurrent.futures

UNIPROT_CACHE = {}
COMPLEX_DETAILS_CACHE = {}

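# 1) Figure out which rows pass your thresholds
skip_fetch = fetch_toggle.value == 'Skip Complex Data'
rows_to_fetch = merged_df.apply(lambda r: should_fetch_complex_info(r, FETCH_COLUMNS), axis=1)
sub_df = merged_df[rows_to_fetch].copy()

# 2) Extract unique UniProt IDs.
#    If the user chose to skip, use an empty list so the fetch steps below make no requests.
unique_uniprot_ids = [] if skip_fetch else sub_df["target_uniprot_id"].dropna().unique()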

# 3) Parallel fetch UniProt JSON data
def fetch_uniprot_and_cache(uniprot_id):
    # Same request as ProteinComplexInfo.get_uniprot_data, written into the shared cache
    url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.json"
    try:
        response = requests.get(url, timeout=30)  # avoid hanging a worker thread
        response.raise_for_status()
        data = response.json()
        UNIPROT_CACHE[uniprot_id] = data
    except Exception as e:
        print(f"Error fetching {uniprot_id}: {e}")
        UNIPROT_CACHE[uniprot_id] = None

# Actually run the concurrent fetch
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
    future_to_uid = {
        executor.submit(fetch_uniprot_and_cache, uid): uid
        for uid in unique_uniprot_ids
    }
    for future in concurrent.futures.as_completed(future_to_uid):
        uid = future_to_uid[future]
        try:
            future.result()  # triggers fetch
        except Exception as e:
            print(f"Error in future for {uid}: {e}")

# 4) Collect the ComplexPortal IDs from each UniProt record
all_complex_ids = set()
for uid in unique_uniprot_ids:
    record = UNIPROT_CACHE.get(uid)
    if record and "uniProtKBCrossReferences" in record:
        for xref in record["uniProtKBCrossReferences"]:
            if xref.get("database") == "ComplexPortal":
                comp_id = xref.get("id")
                if comp_id:
                    all_complex_ids.add(comp_id)

# 5) Parallel fetch ComplexPortal details
def fetch_complexportal_and_cache(comp_id):
    url = f"https://www.ebi.ac.uk/intact/complex-ws/complex/{comp_id}"
    headers = {"accept": "application/json"}
    try:
        r = requests.get(url, headers=headers, timeout=30)  # avoid hanging a worker thread
        r.raise_for_status()
        COMPLEX_DETAILS_CACHE[comp_id] = r.json()
    except Exception as e:
        print(f"Error fetching ComplexPortal {comp_id}: {e}")
        COMPLEX_DETAILS_CACHE[comp_id] = None

with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
    future_to_compid = {
        executor.submit(fetch_complexportal_and_cache, cid): cid
        for cid in all_complex_ids
    }
    for future in concurrent.futures.as_completed(future_to_compid):
        cid = future_to_compid[future]
        try:
            future.result()
        except Exception as e:
            print(f"Error in future for complex {cid}: {e}")

print("Parallel fetching of UniProt and ComplexPortal data complete.")

def final_complex_info(row):
    """
    Return (complex_name, functions_str) for the row's target_uniprot_id,
    but only if it passes thresholds. We use UNIPROT_CACHE and COMPLEX_DETAILS_CACHE
    for quick lookups.
    """
    if pd.isnull(row["target_uniprot_id"]):
        return (None, None)
    if not should_fetch_complex_info(row, FETCH_COLUMNS):
        return (None, None)

    # We do a mini-version of get_complex_info_from_target here, 
    # but purely with the caches:
    uid = row["target_uniprot_id"]
    uniprot_record = UNIPROT_CACHE.get(uid, {})
    if not uniprot_record:
        return (None, None)

    # Extract complex IDs:
    complex_ids = []
    for xref in uniprot_record.get("uniProtKBCrossReferences", []):
        if xref.get("database") == "ComplexPortal":
            comp_id = xref.get("id")
            if comp_id:
                complex_ids.append(comp_id)
    if not complex_ids:
        return (None, None)

    # For simplicity, just take the first one
    first_comp_id = complex_ids[0]
    complex_record = COMPLEX_DETAILS_CACHE.get(first_comp_id, {})
    if not complex_record:
        return (None, None)

    name = complex_record.get("name", "")
    functions = complex_record.get("functions", [])
    functions_str = "; ".join(functions) if functions else ""
    return (name, functions_str)

# Finally, assign the columns:
merged_df[["complex", "complex_info"]] = merged_df.apply(
    final_complex_info, axis=1, result_type="expand"
)

# ==========================
# D) (Optional) SPOC-Specific Enhancements
# ==========================
if SPOC_analysis and spoc_file_path is not None:
    # Create an "opacity" column based on spoc_score (if present)
    if "spoc_score" in merged_df.columns and merged_df["spoc_score"].notnull().any():
        min_score = merged_df["spoc_score"].min()
        max_score = 1.0  # forcing maximum to 1.0
        def compute_opacity(score):
            if pd.isnull(score):
                return 0.1
            if max_score == min_score:
                return 1.0
            return 0.1 + (score - min_score) / (max_score - min_score) * (1.0 - 0.1)
        merged_df["opacity"] = merged_df["spoc_score"].apply(compute_opacity)
    else:
        merged_df["opacity"] = 1.0


IPTM file path   : IPTM_vs_PTM.txt
SPOC file path   : spoc_dir_SPOC_analysis.csv
SPOC_analysis    : True
Output directory : analysis
Loaded IPTM DataFrame with shape: (831, 10)
Created 'IPTM_max' column with the maximum IPTM score for each row.
SPOC analysis is True, and a SPOC file is provided. Proceeding with SPOC-based code...
SPOC DataFrame shape: (210, 30)
Merged DataFrame shape: (831, 41)
Skipping fetch...
Parallel fetching of UniProt and ComplexPortal data complete.
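
The two parsing helpers in this cell can be tried in isolation. The sketch below is an alternative regex-based formulation under stated assumptions, not the code used above: `NAME_PATTERN`, `parse_target` and `max_iptm` are hypothetical names, the pattern assumes NAME values shaped like `<index>_sp-<accession>-<entry>_vs_sp-<accession>-<entry>`, and the IPTM string in the usage line is invented.

```python
import re
import numpy as np

# Assumed NAME layout: "<index>_sp-<ACC>-<ENTRY>_vs_sp-<ACC>-<ENTRY>"
NAME_PATTERN = re.compile(r"_vs_[a-z]{2}-(?P<acc>[A-Z0-9]+)-(?P<gene>[A-Z0-9]+)_")

def parse_target(full_name):
    """Return (target_uniprot_id, short_name), or (None, None) if the pattern does not match."""
    if not isinstance(full_name, str):
        return (None, None)
    match = NAME_PATTERN.search(full_name)
    return (match["acc"], match["gene"]) if match else (None, None)

def max_iptm(value):
    """Largest number in a colon-separated string; NaN if nothing parses."""
    nums = []
    for part in str(value).split(":"):
        try:
            nums.append(float(part))
        except ValueError:
            continue  # skip non-numeric fragments
    return max(nums) if nums else np.nan

print(parse_target("76_sp-Q92610-ZN592_HUMAN_vs_sp-Q13889-TF2H3_HUMAN"))
print(max_iptm("0.41:0.78:0.65"))  # invented example value
```

A regex makes the expected layout explicit in one place, at the cost of rejecting names that the split-based parser above would partially accept.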


In [9]:
# Merge with annotation file (optional)

import sys
sys.path.append("/Users/matthias.vorlaender/Library/CloudStorage/OneDrive-VBC/scripts/python/Tools/DataMerger")

from DataMerger import DataFrameMerger
#merged_df.head()

# Instantiate the merger
merger = DataFrameMerger(merged_df)

# Set the file path (the file can be CSV, TSV, XLS, or XLSX)
merger.set_file_path("/Volumes/plaschka/shared/data/mass-spec/MS_analysis/protein_annotations/POIs_PolII_Spliceosome_250225_MV.xlsx")

#Optional
#merged_df.head()

# Preview the external file
merger.preview_file(n=5)

# Merge the dataframe with the external file (e.g., protein annotations from MS analysis)
merger.set_merge_columns(
    df_col='target_uniprot_id',
    external_col='Other UniProt Accessions',
    multiple_id_delimiter=','
)

# Perform the merge (e.g., as a left merge)
merged_df = merger.merge_data(how='left')



Successfully read file: /Volumes/plaschka/shared/data/mass-spec/MS_analysis/protein_annotations/POIs_PolII_Spliceosome_250225_MV.xlsx
External DataFrame shape: (534, 16)
Showing first 5 rows of external DataFrame:


Unnamed: 0,Gene Symbol,Complex,Class / family,Other UniProt Accessions,More Aliases from Entrez,Organism,Ensemble Gene ID,Molecular weight from SwissProt,Domains from SwissProt,Motifs from SwissProt,Entrez Gene ID,Full Gene Name from Entrez,UniProt Entry Name,Unnamed: 13,Pdbs,Compositional AA Bias from SwissProt
0,TAF1,,Transcription initiation,P21675,,,,,,,,Transcription initiation factor TFIID subunit ...,TAF1_HUMAN,,,
1,TAF2,,Transcription initiation,Q6P1X5,,,,,,,,Transcription initiation factor TFIID subunit ...,TAF2_HUMAN,,,
2,TAF3,,Transcription initiation,Q5VWG9,,,,,,,,Transcription initiation factor TFIID subunit ...,TAF3_HUMAN,,,
3,TAF4,,Transcription initiation,O00268,,,,,,,,Transcription initiation factor TFIID subunit ...,TAF4_HUMAN,,,
4,TAF5,,Transcription initiation,Q15542,,,,,,,,Transcription initiation factor TFIID subunit ...,TAF5_HUMAN,,,


Merging on:
- Current DF column: target_uniprot_id
- External DF column: Other UniProt Accessions
Handling multiple IDs in the external file with delimiter: ','
Merged DataFrame shape: (831, 62) using how='left'
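
The internals of `DataFrameMerger` are not shown in this notebook. As a rough illustration of what "handling multiple IDs with a delimiter" can mean, the sketch below does a comparable left merge in plain pandas; it is an assumption about the general technique, not the tool's actual implementation. The toy frames are invented, and `A0A024` is a placeholder string rather than a real accession.

```python
import pandas as pd

# Toy frames, invented for illustration.
hits = pd.DataFrame({"target_uniprot_id": ["P21675", "Q15542", "Q13889"]})
annotations = pd.DataFrame({
    "Gene Symbol": ["TAF1", "TAF5"],
    "Other UniProt Accessions": ["P21675, A0A024", "Q15542"],
})

# One row per accession: split on the delimiter, explode, strip whitespace.
exploded = (
    annotations
    .assign(merge_id=annotations["Other UniProt Accessions"].str.split(","))
    .explode("merge_id")
    .reset_index(drop=True)
)
exploded["merge_id"] = exploded["merge_id"].str.strip()

# Left merge keeps every hit, annotated or not.
merged = hits.merge(exploded, left_on="target_uniprot_id", right_on="merge_id", how="left")
print(merged[["target_uniprot_id", "Gene Symbol"]])
```

With `how="left"`, the row count stays at the number of hits as long as each accession appears only once in the annotation file, which matches the unchanged row count (831) reported above.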


In [10]:
# ==========================
# E) Build default hover text & interactive widget
# ==========================
# 1) Default columns for hover
default_hover_columns = ["NAME", "IPTM", "PEAK", "complex"]
# Remove any that don't exist in merged_df
default_hover_columns = [c for c in default_hover_columns if c in merged_df.columns]

merged_df["hover_text"] = merged_df.apply(
    lambda row: "<br>".join([f"{col}: {row[col]}" for col in default_hover_columns]),
    axis=1
)

# 2) Build widget to allow user to update hover columns
available_hover_columns = list(merged_df.columns)
if default_hover_columns:
    preselected = tuple(default_hover_columns)
else:
    preselected = (available_hover_columns[0],)  # fallback

hover_columns_selector = widgets.SelectMultiple(
    options=available_hover_columns,
    value=preselected,
    description="Hover Columns:",
    disabled=False,
    layout={'width': '400px'}
)
update_hover_button = widgets.Button(
    description="Update Hover Info",
    button_style="primary"
)

def update_hover_info(b):
    selected_columns = list(hover_columns_selector.value)
    if not selected_columns:
        print("Please select at least one column for hover info.")
        return
    merged_df["hover_text"] = merged_df.apply(
        lambda row: "<br>".join([f"{col}: {row[col]}" for col in selected_columns]),
        axis=1
    )
    print("Hover info updated using columns:", selected_columns)

update_hover_button.on_click(update_hover_info)

display(Markdown("### SPOC Hover-Column Selection"))
display(widgets.HBox([hover_columns_selector, update_hover_button]))

# Your 'merged_df' is now ready for further analysis or plotting.

### SPOC Hover-Column Selection

HBox(children=(SelectMultiple(description='Hover Columns:', index=(0, 1, 6, 41), layout=Layout(width='400px'),…

### Plot with or without SPOC values
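
In the SPOC-based chart, marker opacity follows the SPOC score: scores are mapped linearly onto the range 0.1 to 1.0, with the upper end of the score range fixed at 1.0, and missing scores get the minimum opacity. The standalone sketch below restates that mapping for experimentation. The `floor` parameter and the `to_rgba` helper are illustrative names introduced here, and the input values are invented.

```python
import math

def compute_opacity(score, min_score, max_score=1.0, floor=0.1):
    """Linearly map score from [min_score, max_score] onto [floor, 1.0]."""
    if score is None or math.isnan(score):
        return floor  # missing scores are drawn nearly transparent
    if max_score == min_score:
        return 1.0
    return floor + (score - min_score) / (max_score - min_score) * (1.0 - floor)

def to_rgba(r, g, b, opacity):
    """Format integer RGB components and an opacity as a CSS rgba() string."""
    return f"rgba({r},{g},{b},{opacity:.2f})"

# Invented example: a score of 0.55 when the lowest observed score is 0.1.
print(to_rgba(68, 1, 84, compute_opacity(0.55, min_score=0.1)))
```

Because the maximum is fixed rather than taken from the data, a point only becomes fully opaque when its score actually reaches 1.0.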

In [11]:
# === STEP 4a: SPOC-BASED BUBBLE CHART ===
import re
from plotly.colors import sample_colorscale

# Global variable to store the current hover column selection.
# Initialize with the default hover columns.
current_hover_columns = default_hover_columns

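# Unregister the handler attached in the previous cell so the button does not fire twice
update_hover_button.on_click(update_hover_info, remove=True)

def update_hover_info(b):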
    global current_hover_columns
    selected_columns = list(hover_columns_selector.value)
    if not selected_columns:
        print("Please select at least one column for hover info.")
        return
    current_hover_columns = selected_columns

    merged_df["hover_text"] = merged_df.apply(
        lambda row: "<br>".join([f"{col}: {row[col]}" for col in  selected_columns]),
        axis=1
    )
    print("Hover info updated using columns:", selected_columns)

update_hover_button.on_click(update_hover_info)

def parse_name_field(name_str):
    """
    Given a string of form:
       "76_sp-Q92610-ZN592_HUMAN_vs_sp-Q13889-TF2H3_HUMAN"
    return a dict with:
       {
         'index': '76',
         'protein1': 'sp-Q92610-ZN592_HUMAN',
         'protein2': 'sp-Q13889-TF2H3_HUMAN'
       }
    If parsing fails, returns something fallback with empty strings.
    """
    try:
        # Split around '_vs_'
        parts = name_str.split("_vs_")
        left_part = parts[0]  # e.g. "76_sp-Q92610-ZN592_HUMAN"
        right_part = parts[1] # e.g. "sp-Q13889-TF2H3_HUMAN"

        # Now split the left_part on the first underscore, to separate index from protein1
        left_sub = left_part.split("_", 1)
        idx = left_sub[0]  # "76"
        prot1 = left_sub[1]  # "sp-Q92610-ZN592_HUMAN"

        return {
            "index": idx,
            "protein1": prot1,
            "protein2": right_part
        }
    except Exception:
        # If something goes wrong, return placeholders
        return {
            "index": "",
            "protein1": "",
            "protein2": ""
        }

def parse_color(color_str):
    """Converts a hex or rgb(a) color string to (r, g, b)."""
    # This helper is used if you need numeric r,g,b from a string.
    # If you only need to pass e.g. "red" or "#ff0000" to Plotly,
    # you can skip converting to (r,g,b). Plotly can handle them directly.
    if color_str.startswith("#"):
        hex_color = color_str.lstrip("#")
        r, g, b = tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4))
        return r, g, b
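    elif color_str.startswith("rgb"):
        # Components may be floats (e.g. "rgb(68.1, 1.2, 84.0)"), so match decimals too
        nums = re.findall(r'\d+(?:\.\d+)?', color_str)
        r, g, b = tuple(int(round(float(n))) for n in nums[:3])
        return r, g, b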
    else:
        # Try named CSS color (e.g. "red", "blue") - Plotly accepts those directly
        return color_str

# Before building the figure, let's ensure 'index' is in merged_df
if SPOC_analysis and spoc_file_path is not None:
    if "index" not in merged_df.columns:
        # Parse once for all rows
        parsed_info = merged_df["NAME"].apply(parse_name_field).apply(pd.Series)
        merged_df = pd.concat([merged_df, parsed_info], axis=1)
        # Ensure 'index' column is treated as a string
        merged_df["index"] = merged_df["index"].astype(str)

    # Build the scatter figure
    fig = px.scatter(
        merged_df,
        x="scaled_PEAKavg",
        y="IPTMavg",
        size="IPTM_max",    
        color="scaled_PEAKavg",
        color_continuous_scale="viridis_r",
        title="(SPOC) IPTM vs. Scaled PEAKavg",
        labels={"IPTMavg": "IPTMavg", "scaled_PEAKavg": "Scaled PEAKavg"}
    )

    # Build color array with custom opacity
    min_val = merged_df["scaled_PEAKavg"].min()
    max_val = merged_df["scaled_PEAKavg"].max()
    if max_val != min_val:
        norm = (merged_df["scaled_PEAKavg"] - min_val) / (max_val - min_val)
    else:
        norm = merged_df["scaled_PEAKavg"] * 0  # or just 0

    base_colors = px.colors.sequential.Viridis_r
    rgba_colors = []
    for val, opa in zip(norm, merged_df["opacity"]):
        color_str = sample_colorscale(base_colors, val)[0]  # returns a color string
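        try:
            parsed = parse_color(color_str)
            if isinstance(parsed, str):
                # Named CSS color: pass it through unchanged (no custom opacity)
                rgba_colors.append(parsed)
            else:
                r, g, b = parsed
                rgba_colors.append(f"rgba({r},{g},{b},{opa})")
        except Exception:
            # Fall back to black if the color string cannot be parsed
            rgba_colors.append(f"rgba(0,0,0,{opa})")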

    fig.update_traces(
        marker=dict(color=rgba_colors),
        customdata=merged_df[["hover_text"]].values,
        hovertemplate="%{customdata[0]}<extra></extra>"
    )
        
    fig.update_layout(
        # Update x and y axes: no grid, with black axis lines.
        xaxis=dict(
            showgrid=False,
            showline=True,
            linewidth=2,
            linecolor='black'
        ),
        yaxis=dict(
            showgrid=False,
            showline=True,
            linewidth=2,
            linecolor='black'
        ),
        # Add a rectangle shape as an outer border.
        shapes=[
            dict(
                type="rect",
                xref="paper", yref="paper",
                x0=0, y0=0, x1=1, y1=1,
                line=dict(color="black", width=2)
            )
        ],
        # Optionally, set the template and margins.
        template="plotly_white",
        margin=dict(l=50, r=50, t=50, b=50)
    )
        
    figw = FigureWidget(fig)

    # --- GLOBAL STORAGE FOR SELECTIONS ---
    global_persisted_indices_spoc = set()

    def handle_selection(trace, points, selector):
        global global_persisted_indices_spoc
        global_persisted_indices_spoc.update(points.point_inds)
        if not global_persisted_indices_spoc:
            print("No points selected.")
            return
        print("Accumulated selected indices:", global_persisted_indices_spoc)
        selected_df = merged_df.iloc[list(global_persisted_indices_spoc)]
        
        # Example label: extract a short uniprot name from 'NAME' or just show the "index"
        def process_name(name_str):
            try:
                parts = name_str.split("_vs_")
                if len(parts) < 2:
                    return name_str
                # E.g. "sp-Q13889-TF2H3_HUMAN"
                hit = parts[1]
                hit_parts = hit.split("-")
                if len(hit_parts) < 3:
                    return hit
                return hit_parts[2].split("_")[0]
            except AttributeError:  # non-string NAME (e.g. NaN)
                return name_str
        
        labels = selected_df["NAME"].apply(process_name)
        
        # Check if we already have a "Persistent Labels" trace
        persistent_trace = None
        for t in figw.data:
            # Must match the name given to the label trace in add_scatter below
            if t.name == "Labels (double click to hide)":
                persistent_trace = t
                break
        
        if persistent_trace is None:
            figw.add_scatter(
                x=selected_df["scaled_PEAKavg"],
                y=selected_df["IPTMavg"],
                mode="text",
                text=labels,
                textposition="top center",
                name="Labels (double click to hide)",
                hoverinfo="skip",
                textfont=dict(color="black", size=8)
            )
        else:
            persistent_trace.x = selected_df["scaled_PEAKavg"]
            persistent_trace.y = selected_df["IPTMavg"]
            persistent_trace.text = labels

    # Attach selection callback
    for trace in figw.data:
        trace.on_selection(handle_selection)

    # --- STANDARD SEARCH (by substring) WIDGETS ---
    search_input_spoc = widgets.Text(
        value="",
        placeholder="Enter partial name to search",
        description="Search NAME:",
        style={'description_width': '120px'},
        layout={'width': '400px'}
    )
    search_button_spoc = widgets.Button(
        description="Search",
        tooltip="Search partial matches",
        button_style="primary"
    )
    clear_search_button_spoc = widgets.Button(
        description="Clear Search",
        tooltip="Remove search highlights",
        button_style="warning"
    )

    def on_search_button_click_spoc(b):
        query = search_input_spoc.value.strip()
        if not query:
            print("Please enter a search query.")
            return
        mask = merged_df["NAME"].str.contains(query, case=False, na=False)
        matched = merged_df[mask]
        if matched.empty:
            print("No matches found.")
            return
        
        # Add highlight scatter
        figw.add_scatter(
            x=matched["scaled_PEAKavg"],
            y=matched["IPTMavg"],
            mode="markers+text",
            marker=dict(symbol="circle-open", size=8, line=dict(width=2, color="red")),
            text=[query]*len(matched),
            textposition="top center",
            name="Search Highlight",
            hoverinfo="skip"
        )
        print(f"Found {len(matched)} match(es). Highlights added.")

    def on_clear_search_button_click_spoc(b):
        indices_to_remove = [i for i, t in enumerate(figw.data) if t.name == "Search Highlight"]
        if not indices_to_remove:
            print("No search highlights to clear.")
            return
        for idx in sorted(indices_to_remove, reverse=True):
            figw.data = figw.data[:idx] + figw.data[idx+1:]
        print("Search highlights cleared.")
    
    search_button_spoc.on_click(on_search_button_click_spoc)
    clear_search_button_spoc.on_click(on_clear_search_button_click_spoc)

    # --- NEW MULTI-GROUP INDEX + COLOR HIGHLIGHT ---
    # Example input: (1,5,12,19=green) (2,9,200=red)
    group_highlight_input_spoc = widgets.Text(
        ## TRICK THE COPY-PASTE BUG HERE!! ##
        value="(245,22,250,743,690,261,233,229,479,464,107,1,203,660,659,648,363,462,474,475,492,271,192,97=green) (33=grey) (591,425,761,771,286,385,233,479,464,203,660,659,648,462,474,492,192,508,181,190,64,579,40,708,364,416,35,151=red) (436,117,630,573,3,60,687=black)(400,411,423=blue)",
        placeholder="(192,97=green) (35,151=red)",
        description="Multi-Groups:",
        style={'description_width': '100px'},
        layout={'width': '600px'}
    )
    group_highlight_button_spoc = widgets.Button(
        description="Highlight Groups",
        tooltip="Highlight multiple index groups, each with a color",
        button_style="info"
    )

    def on_group_highlight_button_click_spoc(b):
        """
        Example input: (743=green) (385,23,151=red) (20,21,423=blue)
        Each group is parsed, and for each group we use the name value for labeling.
        """
        input_str = group_highlight_input_spoc.value.strip()
        if not input_str:
            print("No group spec given. Format: (1,5,12=red) (2,9=green)")
            return

        # Split by closing parenthesis, filtering out empties.
        group_specs = [chunk.strip() for chunk in input_str.split(")") if chunk.strip()]

        for spec in group_specs:
            # Remove any leading "(" if present.
            if spec.startswith("("):
                spec = spec[1:].strip()

            # Split on "=" to separate indices from the color.
            if "=" in spec:
                left_part, color_part = spec.split("=", 1)
                indices_str = left_part.strip()
                color_str = color_part.strip()
            else:
                indices_str = spec
                color_str = "red"  # default

            # Split indices (comma-separated).
            idx_list = [x.strip() for x in indices_str.split(",") if x.strip()]
            if not idx_list:
                print(f"No valid indices found in '{spec}'")
                continue

            # Find all matching rows in merged_df.
            matched = merged_df[merged_df["index"].isin(idx_list)]
            if matched.empty:
                print(f"No match for indices {idx_list}")
                continue

            # Use the name column for labeling.
            group_label = matched["name"]

            # Add one scatter trace for the group.
            figw.add_scatter(
                x=matched["scaled_PEAKavg"],
                y=matched["IPTMavg"],
                mode="markers+text",
                showlegend=False,             # <--- Hide from legend
                marker=dict(symbol="circle", color=color_str, size=5),
                text=group_label,
                textposition="top center",
                name="Highlight",
                hoverinfo="skip"
            )
            print(f"Highlighted indices {idx_list} in color '{color_str}'")

        print("Group highlight done.")

    group_highlight_button_spoc.on_click(on_group_highlight_button_click_spoc)

    # --- CLEAR LABELS & SAVE PLOT ---
    clear_labels_button_spoc = widgets.Button(
        description="Clear Labels",
        button_style="warning"
    )
    def on_clear_labels_click_spoc(b):
        # The set is mutated in place, so no `global` declaration is needed.
        global_persisted_indices_spoc.clear()

        # Remove the "Persistent Labels" trace and any group-highlight traces
        # (those whose name starts with "Highlight"). To keep the highlights,
        # drop the startswith() check below.
        names_to_remove = ["Persistent Labels"]
        indices_to_remove = [
            i for i, t in enumerate(figw.data)
            if t.name in names_to_remove or (t.name or "").startswith("Highlight")
        ]
        for idx in sorted(indices_to_remove, reverse=True):
            figw.data = figw.data[:idx] + figw.data[idx+1:]
        print("Persistent labels and highlight traces cleared.")

    clear_labels_button_spoc.on_click(on_clear_labels_click_spoc)

    save_plot_button_spoc = widgets.Button(
        description="Save Plot (HTML & PDF)",
        tooltip="Save the current plot",
        button_style="info"
    )


    file_name_widget_spoc = widgets.Text(
        value="selected_data.csv",
        placeholder="Enter file name",
        description="Save CSV as:",
        disabled=False
    )
    save_data_button_spoc = widgets.Button(
        description="Save Data",
        button_style="success"
    )
    save_data_output_spoc = widgets.Output()
    
    
    import datetime

    custom_suffix_save = widgets.Text(
        value="",
        placeholder="Add file name suffix",
        description="Filename Suffix:",
        layout={'width': '400px'}
    )
    
    def on_save_plot_click_spoc(b):
        try:
            # Retrieve custom suffix from widget and current timestamp.
            suffix = custom_suffix_save.value.strip()
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            if suffix:
                suffix_str = f"_{suffix}_{timestamp}"
            else:
                suffix_str = f"_{timestamp}"
            
            # Build file paths with the custom suffix and timestamp.
            html_filename = f"spoc_bubble_chart{suffix_str}.html"
            pdf_filename = f"spoc_bubble_chart{suffix_str}.pdf"
            html_path = os.path.join(output_dir, html_filename)
            pdf_path = os.path.join(output_dir, pdf_filename)
            
            # Save the figure.
            figw.write_html(html_path)
            figw.write_image(pdf_path, format="pdf")
            print(f"Plot saved:\n  HTML: {html_path}\n  PDF: {pdf_path}")
        except Exception as e:
            print("Error saving plot:", e)

    save_plot_button_spoc.on_click(on_save_plot_click_spoc)

    # --- DISPLAY ---
    instructions_text_spoc = """
    **SPOC-Based Plot Instructions:**
    1. **Use Lasso/Box select to pick points and persist labels**.
    2. (Optional) Search by partial `NAME` using the first box, then clear highlights if needed.
    3. **Highlight by index** (the digits before the underscore) using the second box, 
       e.g. (1,5,12,19=green) (2,9,200=red) (20=blue)
    4. Clear persistent labels and/or highlight traces if needed.
    5. Save the plot (HTML & PDF) or selected data (CSV).
    """
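    import textwrap
    # Dedent first: Markdown renders text indented by four spaces as a code block.
    display(Markdown(textwrap.dedent(instructions_text_spoc)))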
    display(widgets.HBox([search_input_spoc, search_button_spoc, clear_search_button_spoc]))
    
    # Multi-group index highlight input and button
    display(widgets.HBox([group_highlight_input_spoc, group_highlight_button_spoc]))
    display(widgets.HBox([clear_labels_button_spoc]))

    display(widgets.HBox([custom_suffix_save, save_plot_button_spoc]))


    display(figw)
    display(widgets.HBox([file_name_widget_spoc, save_data_button_spoc]))
    display(save_data_output_spoc)

else:
    print("Skipping SPOC-based bubble chart...")



## Merge with external data
To work around the copy-paste bug, search the code for "##TRICK THE COPY PASTE BIUG HERE!!###" and enter your highlights there.
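
The highlight string follows the pattern `(i,j,k=color)`, one parenthesised group per color. As a minimal sketch of how such a string can be parsed outside the widget, the snippet below uses a regular expression. The helper name `parse_group_spec` and the pattern are assumptions for illustration and are not part of the notebook. Indices stay strings, as in the notebook's own parser, and a group without `=color` falls back to red.

```python
import re

# Hypothetical helper (not part of the notebook): parse "(1,5=green) (2,9=red)"
# into a list of (indices, color) pairs.
GROUP_PATTERN = re.compile(r"\(([^()=]*)(?:=([^()]*))?\)")

def parse_group_spec(spec, default_color="red"):
    groups = []
    for indices_part, color_part in GROUP_PATTERN.findall(spec):
        # Indices are kept as strings, like idx_list in the notebook.
        indices = [tok.strip() for tok in indices_part.split(",") if tok.strip()]
        if not indices:
            continue  # skip empty groups such as "()"
        color = (color_part or "").strip() or default_color
        groups.append((indices, color))
    return groups

example = "(192,97=green) (35,151=red)(400)"
for indices, color in parse_group_spec(example):
    print(color, indices)
```

Parsing the string this way before pasting it into the widget makes it easy to spot a misplaced parenthesis or a missing color.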
 


In [12]:
# === STEP 1: MERGE MS DATA ===
ms_file_path = "/Volumes/plaschka/shared/data/mass-spec/MS_analysis/analysis/MV_RPB3_FLAG_pretty/exports/merged_data_20250319_182512_with_nuc_vs_chrom_with_fraction_dependency.csv"

#df_ms = pd.read_csv(ms_file_path, sep=",")
df_ms = pd.read_csv(ms_file_path)

print("Loaded MS data with shape:", df_ms.shape)

#def extract_target_uniprot(name_str):
#    try:
#        parts = name_str.split("_vs_")
#        if len(parts) < 2:
#            return None
#        target = parts[1]  # e.g., "sp-Q9Y3X0-CCDC9_HUMAN"
#        target_parts = target.split("-")
#        if len(target_parts) < 2:
#            return None
#        return target_parts[1]  # e.g. "Q9Y3X0"
#    except Exception:
#        return None
#
## Create a new column in merged_df with the target uniprot IDs.
#merged_df["target_uniprot"] = merged_df["NAME"].apply(extract_target_uniprot)
#print("Extracted target_uniprot in merged_df.")

# Left-merge the MS data into merged_df: "target_uniprot_id" in merged_df is matched against "Accession" in the MS data.
merged_df = pd.merge(merged_df, df_ms, left_on="target_uniprot_id", right_on="Accession", how="left")
print("Merged DataFrame shape after merging MS data:", merged_df.shape)

Loaded MS data with shape: (1681, 5)
Merged DataFrame shape after merging MS data: (831, 71)
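
A left merge keeps every row of the left frame, but it silently adds rows if a key is duplicated on the right, and it leaves NaN where no match exists. The sketch below shows one way to check both, using toy data with made-up values. The column names mirror the cell above, and `indicator` and `validate` are standard `pandas.merge` parameters.

```python
import pandas as pd

# Toy stand-ins for merged_df and df_ms (values are made up for illustration).
left = pd.DataFrame({
    "NAME": ["a", "b", "c"],
    "target_uniprot_id": ["Q9Y3X0", "P12345", "O00000"],
})
right = pd.DataFrame({
    "Accession": ["Q9Y3X0", "P99999"],
    "abundance": [1.5, 2.0],
})

merged = pd.merge(
    left, right,
    left_on="target_uniprot_id", right_on="Accession",
    how="left",
    indicator=True,          # adds a "_merge" column: "both" or "left_only"
    validate="many_to_one",  # raises MergeError if Accession repeats on the right
)

n_matched = int((merged["_merge"] == "both").sum())
print(f"{n_matched} of {len(merged)} rows found an MS match")
```

If the row count after the merge is larger than before, the right-hand key is not unique; `validate` turns that case into an explicit error.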


### Choose columns for mapping color, size and opacity

In [13]:
# === CELL 1: Column & Transform Selection, Save to JSON Config ===
import os
import json
import ipywidgets as widgets
from IPython.display import display, Markdown

config_file = "config.json"

numeric_cols = merged_df.select_dtypes(include=[float, int]).columns.tolist()
transform_options = ["None", "log2", "log10"]

if os.path.exists(config_file):
    with open(config_file, "r") as f:
        prev_cfg = json.load(f)
    print(f"Loaded existing config from {config_file}: {prev_cfg}")
else:
    prev_cfg = {}
    print("No config file found; using empty defaults.")

def dict_get(d, key, fallback):
    return d[key] if key in d else fallback

color_selector = widgets.Dropdown(
    options=numeric_cols,
    value=dict_get(prev_cfg, "color_column", "scaled_PEAKavg"),
    description="Color Col:",
    layout={'width': '220px'}
)
size_selector = widgets.Dropdown(
    options=numeric_cols,
    value=dict_get(prev_cfg, "size_column", "spoc_score"),
    description="Size Col:",
    layout={'width': '220px'}
)
opacity_selector = widgets.Dropdown(
    options=numeric_cols,
    value=dict_get(prev_cfg, "opacity_column", "spoc_score"),
    description="Opacity Col:",
    layout={'width': '220px'}
)

color_transform_selector = widgets.Dropdown(
    options=transform_options,
    value=dict_get(prev_cfg, "color_transform", "None"),
    description="Color Transform:",
    layout={'width': '220px'}
)
size_transform_selector = widgets.Dropdown(
    options=transform_options,
    value=dict_get(prev_cfg, "size_transform", "None"),
    description="Size Transform:",
    layout={'width': '220px'}
)
opacity_transform_selector = widgets.Dropdown(
    options=transform_options,
    value=dict_get(prev_cfg, "opacity_transform", "None"),
    description="Opac Transform:",
    layout={'width': '220px'}
)

display(Markdown("### Select columns for mapping color, size and transparency and their transformations:"))
display(widgets.HBox([color_selector, size_selector, opacity_selector]))
display(widgets.HBox([
    color_transform_selector, size_transform_selector, opacity_transform_selector
]))

def on_save_config(b):
    cfg = {
        "color_column": color_selector.value,
        "size_column": size_selector.value,
        "opacity_column": opacity_selector.value,
        "color_transform": color_transform_selector.value,
        "size_transform": size_transform_selector.value,
        "opacity_transform": opacity_transform_selector.value
    }
    with open(config_file, "w") as f:
        json.dump(cfg, f)
    print("[Cell1] Configuration saved to", config_file, ":", cfg)

save_button = widgets.Button(description="Save Config", button_style="success")
save_button.on_click(on_save_config)
display(save_button)

print("Adjust columns/transforms as desired, then click 'Save Config'. Next, run Cell 2.")

Loaded existing config from config.json: {'color_column': 'FractionDependent', 'size_column': 'DNaseI digest chromatin FLAG-mCh_RPB3 vs DNaseI digest chromatin WT', 'opacity_column': 'IPTM_max', 'color_transform': 'None', 'size_transform': 'None', 'opacity_transform': 'None'}



Adjust columns/transforms as desired, then click 'Save Config'. Next, run Cell 2.
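
The saved config is plain JSON, so it can go stale when the merged table changes: an ipywidgets `Dropdown` rejects a `value` that is not among its `options`. The sketch below shows one way to guard against that. `load_config` and `DEFAULTS` are hypothetical names, and the default values are taken from the fallbacks used in the cell above.

```python
import json
import os
import tempfile

# Defaults mirror the fallbacks used in the cell above.
DEFAULTS = {
    "color_column": "scaled_PEAKavg",
    "size_column": "spoc_score",
    "opacity_column": "spoc_score",
    "color_transform": "None",
    "size_transform": "None",
    "opacity_transform": "None",
}

def load_config(path, valid_columns):
    """Hypothetical helper: merge a saved config over DEFAULTS and drop stale columns."""
    cfg = dict(DEFAULTS)
    if os.path.exists(path):
        with open(path, "r") as f:
            cfg.update(json.load(f))
    for key in ("color_column", "size_column", "opacity_column"):
        if cfg[key] not in valid_columns:
            cfg[key] = DEFAULTS[key]  # fall back when the saved column no longer exists
    return cfg

# Demonstration with a temporary file standing in for config.json.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    with open(path, "w") as f:
        json.dump({"color_column": "IPTM_max", "size_column": "gone"}, f)
    cfg = load_config(path, valid_columns=["scaled_PEAKavg", "spoc_score", "IPTM_max"])
    print(cfg["color_column"], cfg["size_column"])
```

This sketch assumes the default columns themselves exist in the table; if they might not, the fallback would need a further check.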


### Optional config file overwrite

In [14]:
### Config file overwrite
#Load a different config file if required
config_file = "config.json"

### Select dynamic range and generate plots 
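
The cell below replaces missing and zero values with the 5% quantile of the positive values, optionally applies a log transform, and later rescales the opacity column to the range 0 to 1. As a rough standalone sketch of that arithmetic, the snippet below uses NumPy only. The function names `floor_zeros_then_log` and `to_unit_interval` are assumptions for illustration and do not appear in the notebook.

```python
import numpy as np

def floor_zeros_then_log(values, kind="log10"):
    """Sketch: NaN/zero -> 5% quantile of the positive values, then optional log."""
    arr = np.nan_to_num(np.asarray(values, dtype=float), nan=0.0)
    positive = arr[arr > 0]
    floor = np.quantile(positive, 0.05) if positive.size else 1e-6
    arr = np.where(arr == 0, floor, arr)
    if kind == "None":
        return arr
    if (arr < 0).any():
        raise ValueError(f"Negative data found for {kind} transform.")
    if kind == "log2":
        return np.log2(arr)
    if kind == "log10":
        return np.log10(arr)
    raise ValueError("Invalid transform option")

def to_unit_interval(arr, low, high):
    """Map values linearly from [low, high] to [0, 1], clipping outside the range."""
    arr = np.asarray(arr, dtype=float)
    if np.isclose(high, low):
        return np.full(arr.shape, 0.5)  # degenerate range: use a mid opacity
    return np.clip((arr - low) / (high - low), 0.0, 1.0)

raw = [0.0, np.nan, 10.0, 100.0, 1000.0]
logged = floor_zeros_then_log(raw, "log10")
opacity = to_unit_interval(logged, low=1.0, high=3.0)
print(logged)
print(opacity)
```

Because the replacement value is a quantile of the positive data, replaced zeros can end up above the smallest positive values, which is worth keeping in mind when choosing the minimum of each range.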

In [None]:
# === CELL 2: Load & Apply + Ranges + Color Scale => Generate Advanced Plot ===
import os
import json
import numpy as np
import ipywidgets as widgets
from IPython.display import display, Markdown
import plotly.express as px
from plotly.graph_objs import FigureWidget
import datetime

###############################################################################
# 1) Transform function: zero -> 5% quantile, optional log
###############################################################################
def apply_transform_with_5pct(series, transform_kind="None"):
    s = series.fillna(0).copy()
    pos_mask = (s > 0)
    if pos_mask.any():
        q5 = np.quantile(s[pos_mask], 0.05)
        if q5 <= 0:
            q5 = 1e-6
    else:
        q5 = 1e-6
    s[s==0] = q5
    if transform_kind == "None":
        return s
    if (s < 0).any():
        raise ValueError(f"Negative data found for {transform_kind} transform.")
    if transform_kind == "log2":
        return np.log2(s)
    elif transform_kind == "log10":
        return np.log10(s)
    else:
        raise ValueError("Invalid transform option")


###############################################################################
# 2) Load & Apply => create *processed columns => display default min/max
###############################################################################
load_output = widgets.Output()
color_min_box = widgets.FloatText(description="Color Min:", layout={'width': '200px'})
color_max_box = widgets.FloatText(description="Color Max:", layout={'width': '200px'})
size_min_box  = widgets.FloatText(description="Size Min:", layout={'width': '200px'})
size_max_box  = widgets.FloatText(description="Size Max:", layout={'width': '200px'})
opac_min_box  = widgets.FloatText(description="Opac Min:", layout={'width': '200px'})
opac_max_box  = widgets.FloatText(description="Opac Max:", layout={'width': '200px'})

# Let user pick color scale & reverse
color_scales = ["viridis","magma","inferno","plasma","Blues","Reds","RdBu","cividis","PuOr"]
color_scale_selector = widgets.Dropdown(
    options=color_scales,
    value="viridis",
    description="Color Scale:",
    layout={'width': '220px'}
)
reverse_scale_checkbox = widgets.Checkbox(value=False, description="Reverse Scale")

# Store the chosen columns/transforms in these globals:
global_color_col = None
global_size_col = None
global_opacity_col = None
global_color_xform = None
global_size_xform = None
global_opacity_xform = None

def on_load_apply_config(b):
    global global_color_col, global_size_col, global_opacity_col
    global global_color_xform, global_size_xform, global_opacity_xform
    with load_output:
        load_output.clear_output()
        try:
            # 1) read config (uses the module-level `config_file`, so the
            #    optional overwrite cell above takes effect)
            if not os.path.exists(config_file):
                print(f"No config file found at {config_file}. Please run Cell 1 and save config.")
                return
            with open(config_file,"r") as f:
                final_cfg = json.load(f)

            # 2) Store them in the global variables
            global_color_col   = final_cfg.get("color_column","scaled_PEAKavg")
            global_size_col    = final_cfg.get("size_column","spoc_score")
            global_opacity_col = final_cfg.get("opacity_column","spoc_score")
            global_color_xform = final_cfg.get("color_transform","None")
            global_size_xform  = final_cfg.get("size_transform","None")
            global_opacity_xform = final_cfg.get("opacity_transform","None")

            # 3) transform the columns in merged_df
            merged_df["color_processed"]   = apply_transform_with_5pct(merged_df[global_color_col],   global_color_xform)
            merged_df["size_processed"]    = apply_transform_with_5pct(merged_df[global_size_col],    global_size_xform)
            merged_df["opacity_processed"] = apply_transform_with_5pct(merged_df[global_opacity_col], global_opacity_xform)

            # Show min/max
            cmin, cmax = merged_df["color_processed"].min(), merged_df["color_processed"].max()
            smin, smax = merged_df["size_processed"].min(),  merged_df["size_processed"].max()
            omin, omax = merged_df["opacity_processed"].min(), merged_df["opacity_processed"].max()

            color_min_box.value, color_max_box.value = cmin, cmax
            size_min_box.value, size_max_box.value   = smin, smax
            opac_min_box.value, opac_max_box.value   = omin, omax

            # 4) if config had color_scale & reverse, load them
            cscale = final_cfg.get("color_scale","viridis")
            color_scale_selector.value = cscale
            rev = final_cfg.get("reverse_scale",False)
            reverse_scale_checkbox.value = rev
            
            print("[Load] Config:", final_cfg)
            print(f" color_processed => [{cmin:.3f}..{cmax:.3f}]")
            print(f" size_processed => [{smin:.3f}..{smax:.3f}]")
            print(f" opacity_processed => [{omin:.3f}..{omax:.3f}]")
        except (ValueError, KeyError) as e:  # KeyError: a configured column is missing from merged_df
            print("[Load Error]", e)

load_button = widgets.Button(description="Load & Apply Config", button_style="primary")
load_button.on_click(on_load_apply_config)



###############################################################################
# 3) Generate Plot (with advanced features)
###############################################################################
plot_output = widgets.Output()
global_persisted_indices_dynamic = set()  # for selection

def on_generate_plot_click(b):
    with plot_output:
        plot_output.clear_output()
        
        # 1) read range values from the widgets
        cmin = color_min_box.value
        cmax = color_max_box.value
        smin = size_min_box.value
        smax = size_max_box.value
        omin = opac_min_box.value
        omax = opac_max_box.value

        # color scale
        chosen_cscale = color_scale_selector.value
        if reverse_scale_checkbox.value:
            chosen_cscale += "_r"

        # 2) clamp color & size
        def clamp(arr, low, high):
            return arr.clip(low,high)

        merged_df["color_clamped"] = clamp(merged_df["color_processed"], cmin, cmax)
        merged_df["size_clamped"]  = clamp(merged_df["size_processed"], smin, smax)

        # 3) map opacity => [0..1]
        arr_op = merged_df["opacity_processed"].values
        if np.isclose(omax, omin):
            arr_map = np.full_like(arr_op, 0.5, dtype=float)  # dtype=float so 0.5 is not truncated for integer columns
        else:
            arr_map = (arr_op - omin)/(omax - omin)
        arr_map = np.clip(arr_map, 0, 1)

        # 4) build figure
        fig = px.scatter(
            merged_df,
            x="scaled_PEAKavg",
            y="IPTMavg",
            color="color_clamped",
            size="size_clamped",
            hover_data=["NAME","color_processed","size_processed","opacity_processed","hover_text"],
            color_continuous_scale=chosen_cscale,
            range_color=[cmin, cmax],
            labels={"color_clamped": ""},  # blank legend label
            title="AlphaFold PEAKavg vs. IPTMavg plot"
        )
        
        # 5) text annotation describing the mappings actually applied
        #    (taken from the loaded config rather than the Cell 1 widgets)
        mapping_text = (
            f"Color mapped to {global_color_col} (transform: {global_color_xform})<br>"
            f"Size mapped to {global_size_col} (transform: {global_size_xform})<br>"
            f"Opacity mapped to {global_opacity_col} (transform: {global_opacity_xform})"
        )
        fig.add_annotation(
            xref="paper",
            yref="paper",
            x=0,
            y=0.92,
            xanchor="left",
            yanchor="bottom",
            showarrow=False,
            text=mapping_text,
            font=dict(size=12, color="black")
        )

        fig.update_layout(
            clickmode="event+select",
            autosize=False,
            width=800,   # pick a fixed width
            height=600,
            legend=dict(
                x=0.05,       # place legend near the left edge of the plot area
                xanchor="left",
                y=1.0,
                yanchor="top"
            ),
            xaxis=dict(
                showgrid=False, 
                showline=True, 
                linewidth=2, 
                linecolor='black',
                range=[0,1]  # Force x-axis range
            ),
            yaxis=dict(
                showgrid=False, 
                showline=True, 
                linewidth=2, 
                linecolor='black',
                range=[0,1]  # Force y-axis range
            ),
            shapes=[
                # Outer border rectangle (in paper coordinates)
                dict(
                    type="rect",
                    xref="paper", yref="paper",
                    x0=0, y0=0, x1=1, y1=1,
                    line=dict(color="black", width=2)
                ),
                # Horizontal line at y=0.5 (in axis coordinates)
                dict(
                    type="line",
                    xref="x", yref="y",
                    x0=0, x1=1, y0=0.5, y1=0.5,
                    line=dict(color="black", width=2, dash="dash")
                ),
                # Vertical line at x=0.75 (in axis coordinates)
                dict(
                    type="line",
                    xref="x", yref="y",
                    x0=0.75, x1=0.75, y0=0, y1=1,
                    line=dict(color="black", width=2, dash="dash")
                )
            ],
            template="plotly_white",
            margin=dict(l=50, r=50, t=50, b=50)
        )
        # 6) define selection callback for lasso/box
        def handle_selection(trace, points, selector):
            global global_persisted_indices_dynamic
            global_persisted_indices_dynamic.update(points.point_inds)
            if not global_persisted_indices_dynamic:
                print("[Plot] No points selected.")
                return
            print("[Plot] Accumulated indices:", global_persisted_indices_dynamic)
            sel_df = merged_df.iloc[list(global_persisted_indices_dynamic)]

            labels = sel_df["name"]  # or another label

            # see if "Persistent Labels" trace exists
            persist_tr = None
            for t in figw.data:
                if t.name == "Persistent Labels":
                    persist_tr = t
                    break
            if persist_tr is None:
                figw.add_scatter(
                    x=sel_df["scaled_PEAKavg"],
                    y=sel_df["IPTMavg"],
                    mode="text",
                    text=labels,
                    textposition="top center",
                    name="Persistent Labels",
                    showlegend=False,             # <--- Hide from legend
                    hoverinfo="skip",
                    textfont=dict(color="black", size=8)
                )
            else:
                persist_tr.x = sel_df["scaled_PEAKavg"]
                persist_tr.y = sel_df["IPTMavg"]
                persist_tr.text = labels
        
        
        # display final figure
        display(Markdown("### Select the lasso tool and click to label a single point or lasso-select to label a range. Double-click on the background to restore transparency settings"))


        figw = FigureWidget(fig)
        display(figw)

        # 7) attach selection and click callbacks to every trace
        for tr in figw.data:
            tr.on_selection(handle_selection)  # lasso/box selection
            tr.on_click(handle_click)          # single-point click
 
        # store for searches, highlights, saves
        global global_figw_for_search
        global_figw_for_search = figw

        # apply per-point marker opacity
        for trace in figw.data:
            trace.marker.opacity = arr_map
            trace.marker.line = dict(width=1, color="black")
            
        

generate_plot_button = widgets.Button(
    description="Generate Plot",
    button_style="info"
)
generate_plot_button.on_click(on_generate_plot_click)


###############################################################################
# SAVE FIGURE & SETTINGS BUTTON
###############################################################################
save_figure_button = widgets.Button(description="Save Figure & Settings", button_style="success")
save_figure_suffix = widgets.Text(
    value="",
    placeholder="File name suffix",
    description="Filename Suffix:",
    layout={'width': '300px'}
)
save_figure_output = widgets.Output()

def on_save_figure_click(b):
    with save_figure_output:
        save_figure_output.clear_output()
        global global_figw_for_search
        if global_figw_for_search is None:
            print("No figure to save! Please 'Generate Plot' first.")
            return
        
        suffix_input = save_figure_suffix.value.strip()
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        if suffix_input:
            suffix_str = f"_{suffix_input}_{timestamp}"
        else:
            suffix_str = f"_{timestamp}"

        # Build final file paths (suffix_str already starts with "_")
        html_name = os.path.join(output_dir, f"AF_plot{suffix_str}.html")
        pdf_name  = os.path.join(output_dir, f"AF_plot{suffix_str}.pdf")
        json_name = os.path.join(output_dir, f"AF_plot{suffix_str}_settings.json")

        # Gather user settings from the globals set by 'Load & Apply Config'.
        # They are initialised to None, so test for None rather than NameError.
        if global_color_col is None:
            print("No config loaded, can't save settings!")
            return
        ccol   = global_color_col
        scol   = global_size_col
        ocol   = global_opacity_col
        cx     = global_color_xform
        sx     = global_size_xform
        ox     = global_opacity_xform
        
        # Convert widget values to native Python types
        cmin = float(color_min_box.value)
        cmax = float(color_max_box.value)
        smin = float(size_min_box.value)
        smax = float(size_max_box.value)
        omin = float(opac_min_box.value)
        omax = float(opac_max_box.value)
        chosen_scale = color_scale_selector.value
        rev_scale = bool(reverse_scale_checkbox.value)

        final_settings = {
            "color_column": ccol,
            "size_column": scol,
            "opacity_column": ocol,
            "color_transform": cx,
            "size_transform": sx,
            "opacity_transform": ox,
            "color_min": cmin,
            "color_max": cmax,
            "size_min": smin,
            "size_max": smax,
            "opacity_min": omin,
            "opacity_max": omax,
            "color_scale": chosen_scale,
            "reverse_scale": rev_scale
        }

        # Save figure to HTML & PDF
        try:
            global_figw_for_search.write_html(html_name)
            global_figw_for_search.write_image(pdf_name, format="pdf")
        except Exception as e:
            print("Error saving figure:", e)
            return

        # Save settings to JSON
        with open(json_name, "w") as jf:
            json.dump(final_settings, jf, indent=2)
        # Overwrite config.json so it can be reloaded later
        with open("config.json", "w") as jf:
            json.dump(final_settings, jf, indent=2)

        print(f"Figure saved:\n  HTML: {html_name}\n  PDF: {pdf_name}")
        print(f"Settings saved:\n  JSON: {json_name}")
        print("You can re-load these exact settings (including min/max) later.")

save_figure_button.on_click(on_save_figure_click)


###############################################################################
# A) SEARCH WIDGETS (partial NAME => highlight)
###############################################################################
global_figw_for_search = None  # we set it when we generate the plot

search_input_dynamic = widgets.Text(
    value="",
    placeholder="Enter partial NAME to search (use ; for multiple)",
    description="Search NAME:",
    layout={'width': '300px'}
)
search_button_dynamic = widgets.Button(description="Search", button_style="primary")
clear_search_button_dynamic = widgets.Button(description="Clear Search", button_style="warning")

def on_search_button_click_dynamic(b):
    """
    Allows multiple queries separated by semicolons.
    Each query spawns its own highlight trace, labeled with that query.
    All highlight traces share the name "Search Highlight" so that
    'Clear Search' can find and remove them.
    """
    if global_figw_for_search is None:
        print("No figure built yet! Please 'Generate Plot' first.")
        return
    
    raw_value = search_input_dynamic.value.strip()
    if not raw_value:
        print("Please enter one or more partial NAME substrings, separated by semicolons.")
        return
    
    queries = [q.strip() for q in raw_value.split(";") if q.strip()]
    if not queries:
        print("No valid queries found (check input).")
        return
    
    for q in queries:
        mask = merged_df["NAME"].str.contains(q, case=False, na=False)
        matched = merged_df[mask]
        if matched.empty:
            print(f"[{q}] No matches found.")
            continue

        # highlight each query in a new scatter
        global_figw_for_search.add_scatter(
            x=matched["scaled_PEAKavg"],
            y=matched["IPTMavg"],
            mode="markers+text",
            showlegend=False,             # <--- Hide from legend
            marker=dict(symbol="circle-open", size=8, line=dict(width=2, color="red")),
            text=[q]*len(matched),
            textposition="top center",
            name="Search Highlight",  # fixed name so 'Clear Search' can remove it
            hoverinfo="skip"
        )
        print(f"[{q}] Found {len(matched)} match(es). Highlights added.")

def on_clear_search_button_click_dynamic(b):
    if global_figw_for_search is None:
        print("No figure built yet!")
        return
    # Remove all traces with name == "Search Highlight"
    # If you prefer to remove traces named after each query, you can adapt it accordingly.
    to_remove = [i for i,t in enumerate(global_figw_for_search.data) if t.name == "Search Highlight"]
    if not to_remove:
        print("No search highlights to clear.")
        return
    for idx in sorted(to_remove, reverse=True):
        global_figw_for_search.data = global_figw_for_search.data[:idx] + global_figw_for_search.data[idx+1:]
    print("Search highlights cleared.")

search_button_dynamic.on_click(on_search_button_click_dynamic)
clear_search_button_dynamic.on_click(on_clear_search_button_click_dynamic)
search_ui = widgets.HBox([search_input_dynamic, search_button_dynamic, clear_search_button_dynamic])

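# Module-level selection callback for lasso/box, bound to the dynamic figure
def handle_selection(trace, points, selector):
    global global_persisted_indices_dynamic
    figw = global_figw_for_search  # the figure built by 'Generate Plot'
    global_persisted_indices_dynamic.update(points.point_inds)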
    if not global_persisted_indices_dynamic:
        print("[Plot] No points selected.")
        return
    print("[Plot] Accumulated indices:", global_persisted_indices_dynamic)
    sel_df = merged_df.iloc[list(global_persisted_indices_dynamic)]

    labels = sel_df["name"]  # or any column
    # see if "Persistent Labels" trace exists
    persist_tr = None
    for t in figw.data:
        if t.name == "Persistent Labels":
            persist_tr = t
            break
    if persist_tr is None:
        figw.add_scatter(
            x=sel_df["scaled_PEAKavg"],
            y=sel_df["IPTMavg"],
            mode="text",
            text=labels,
            textposition="top center",
            name="Persistent Labels",
            hoverinfo="skip",
            textfont=dict(color="black", size=8)
        )
    else:
        persist_tr.x = sel_df["scaled_PEAKavg"]
        persist_tr.y = sel_df["IPTMavg"]
        persist_tr.text = labels


##################################
# NEW: handle_click for point click
##################################
def handle_click(trace, points, state):
    """
    For each clicked point, we label it with NAME + index.
    We gather all clicked points in a single "Clicked Labels" scatter trace.
    """
    if not points.point_inds:
        return

    # We'll allow multiple clicks, so let's accumulate them in a separate trace
    clicked_trace = None
    for t in figw.data:
        if t.name == "Clicked Labels":
            clicked_trace = t
            break

    # If no "Clicked Labels" trace, create one
    if clicked_trace is None:
        figw.add_scatter(
            x=[], 
            y=[], 
            mode="text",
            text=[],
            textposition="top center",
            name="Clicked Labels",
            showlegend=False,             # <--- Hide from legend
            hoverinfo="skip",
            textfont=dict(color="blue", size=8)
        )
        # re‑find it
        for t in figw.data:
            if t.name == "Clicked Labels":
                clicked_trace = t
                break

    # For each clicked point index, collect the coordinates and label first
    new_x, new_y, new_text = [], [], []
    for i in points.point_inds:
        row = merged_df.iloc[i]
        new_x.append(row["scaled_PEAKavg"])
        new_y.append(row["IPTMavg"])
        new_text.append(f"{row['NAME']} (index={row['index']})")

    # Append to the existing arrays with one assignment per property
    clicked_trace.x = list(clicked_trace.x) + new_x
    clicked_trace.y = list(clicked_trace.y) + new_y
    clicked_trace.text = list(clicked_trace.text) + new_text

    print("[Click] Labeled points:", points.point_inds)


###############################################################################
# B) GROUP HIGHLIGHT WIDGETS
###############################################################################
group_highlight_input_dynamic = widgets.Text(
    value="(245,22,250,743,690,261,233,229,479,464,107,1,203,660,659,648,363,462,474,475,492,271,192,97=blue)",
    placeholder="(1,5=green) (2,9=red)",
    description="Multi-Groups:",
    layout={'width': '600px'}
)
group_highlight_button_dynamic = widgets.Button(
    description="Highlight Groups",
    tooltip="Highlight multiple index groups",
    button_style="info"
)

def on_group_highlight_button_click_dynamic(b):
    if global_figw_for_search is None:
        print("No figure built yet! Please 'Generate Plot' first.")
        return
    
    input_str = group_highlight_input_dynamic.value.strip()
    if not input_str:
        print("No group spec. Format: (1,5=red) (2,9=blue)")
        return
    group_specs = [chunk.strip() for chunk in input_str.split(")") if chunk.strip()]

    for spec in group_specs:
        if spec.startswith("("):
            spec = spec[1:].strip()
        if "=" in spec:
            left_part, color_part = spec.split("=",1)
            idx_str = left_part.strip()
            color_str = color_part.strip()
        else:
            idx_str = spec
            color_str = "red"
        idx_list = [x.strip() for x in idx_str.split(",") if x.strip()]

        # Compare as strings: idx_list holds text tokens, while "index" may be numeric
        matched = merged_df[merged_df["index"].astype(str).isin(idx_list)]
        if matched.empty:
            print(f"No match for indices {idx_list}")
            continue

        # label with "name"
        group_label = matched["name"]
        global_figw_for_search.add_scatter(
            x=matched["scaled_PEAKavg"],
            y=matched["IPTMavg"],
            mode="markers+text",
            showlegend=False,             # <--- Hide from legend
            marker=dict(symbol="circle", color=color_str, size=5, opacity=0.5),
            text=group_label,
            textfont=dict(size=6),
            textposition="top center",
            name="Highlight (Dynamic)",
            hoverinfo="skip"
        )
        print(f"Highlighted {len(matched)} points in color '{color_str}'.")
    print("Group highlight done.")

group_highlight_button_dynamic.on_click(on_group_highlight_button_click_dynamic)
group_ui = widgets.HBox([group_highlight_input_dynamic, group_highlight_button_dynamic])

# 1) Define the button
clear_labels_button = widgets.Button(
    description="Clear Labels",
    button_style="warning"
)

# 2) Define the callback
def on_clear_labels_click(b):
    if global_figw_for_search is None:
        print("No figure to clear labels from!")
        return

    # We can remove both "Persistent Labels" and "Clicked Labels" if they exist
    names_to_remove = ["Persistent Labels", "Clicked Labels"]

    # Gather indices in reverse to remove them from global_figw_for_search.data
    to_remove = []
    for i, trace in enumerate(global_figw_for_search.data):
        if trace.name in names_to_remove:
            to_remove.append(i)
    
    if not to_remove:
        print("No persistent/clicked labels to clear.")
        return

    for i in sorted(to_remove, reverse=True):
        global_figw_for_search.data = (
            global_figw_for_search.data[:i] + global_figw_for_search.data[i+1:]
        )
    print("Cleared persistent/clicked label traces.")

# 3) Bind the callback
clear_labels_button.on_click(on_clear_labels_click)


###############################################################################
# C) EXPORT SELECTED POINTS TO CSV
###############################################################################
file_name_widget_dynamic = widgets.Text(
    value="selected_data_dynamic.csv",
    placeholder="Enter file name",
    description="Save CSV as:",
    disabled=False,
    layout={'width': '300px'}
)
save_data_button_dynamic = widgets.Button(description="Save Data", button_style="success")
save_data_output_dynamic = widgets.Output()

def on_save_data_click_dynamic(b):
    with save_data_output_dynamic:
        save_data_output_dynamic.clear_output()
        if not global_persisted_indices_dynamic:
            print("No points selected. Nothing to save.")
            return
        
        # selected subset
        selected_df = merged_df.iloc[list(global_persisted_indices_dynamic)]

        import datetime  # local import in case the module is not already loaded
        suffix_input = save_figure_suffix.value.strip()
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        if suffix_input:
            suffix_str = f"_{suffix_input}_{timestamp}"
        else:
            suffix_str = f"_{timestamp}"

        # Use the timestamped suffix so repeated exports do not overwrite each other
        os.makedirs(output_dir, exist_ok=True)
        csv_name = os.path.join(output_dir, f"AF_plot_selected_hits{suffix_str}.csv")
        
        try:
            out_path = csv_name  # or os.path.join(output_dir, csv_name)
            selected_df.to_csv(out_path, index=False)
            print(f"Saved {len(selected_df)} selected rows to '{out_path}'.")
        except Exception as e:
            print("Error saving CSV:", e)


save_data_button_dynamic.on_click(on_save_data_click_dynamic)
display(Markdown("### Click the Load & Apply Config button first!"))
display(load_button)
display(load_output)
display(Markdown("### Adjust Ranges & Color Scale"))
display(widgets.HBox([color_scale_selector, reverse_scale_checkbox]))
display(Markdown("**Color Range**"))
display(widgets.HBox([color_min_box, color_max_box]))
display(Markdown("**Size Range**"))
display(widgets.HBox([size_min_box, size_max_box]))
display(Markdown("**Opacity Range** (Decrease upper value to increase visibility)"))
display(widgets.HBox([opac_min_box, opac_max_box]))

display(Markdown("**Note**: You can save the path of the config json file in Cell 1 to load later."))


display(Markdown("### 4) Generate Plot with chosen ranges"))
display(generate_plot_button)

display(Markdown("### Searching & Group Highlight"))
display(Markdown("**Enter multiple semicolon-separated values in the 'Search NAME' box to retrieve partial matches.**"))

display(search_ui)
# Finally, display the button

display(Markdown("**Enter comma-separated index values and the highlight color in the 'Multi-Groups' box, e.g. (1,5=green) (2,9=red).**"))

display(group_ui)
display(plot_output)

display(Markdown("### Save Figure & Settings & Export Selected Points to CSV"))
display(widgets.HBox([save_figure_suffix, save_figure_button]))
# Display each Output widget on its own; passing a list only prints its repr
display(save_figure_output)
display(save_data_output_dynamic)
display(widgets.HBox([file_name_widget_dynamic, save_data_button_dynamic]))


### Click the Load & Apply Config button first!

Button(button_style='primary', description='Load & Apply Config', style=ButtonStyle())

Output()

### Adjust Ranges & Color Scale

HBox(children=(Dropdown(description='Color Scale:', layout=Layout(width='220px'), options=('viridis', 'magma',…

**Color Range**

HBox(children=(FloatText(value=0.0, description='Color Min:', layout=Layout(width='200px')), FloatText(value=0…

**Size Range**

HBox(children=(FloatText(value=0.0, description='Size Min:', layout=Layout(width='200px')), FloatText(value=0.…

**Opacity Range** (Decrease upper value to increase visibility)

HBox(children=(FloatText(value=0.0, description='Opac Min:', layout=Layout(width='200px')), FloatText(value=0.…

**Note**: You can save the path of the config json file in Cell 1 to load later.

### 4) Generate Plot with chosen ranges

Button(button_style='info', description='Generate Plot', style=ButtonStyle())

### Searching & Group Highlight

**Enter multiple semicolon-separated values in the 'Search NAME' box to retrieve partial matches.**

HBox(children=(Text(value='', description='Search NAME:', layout=Layout(width='300px'), placeholder='Enter par…

**Enter comma-separated index values and the highlight color in the 'Multi-Groups' box, e.g. (1,5=green) (2,9=red).**

HBox(children=(Text(value='(245,22,250,743,690,261,233,229,479,464,107,1,203,660,659,648,363,462,474,475,492,2…

Output()

### Save Figure & Settings & Export Selected Points to CSV

HBox(children=(Text(value='', description='Filename Suffix:', layout=Layout(width='300px'), placeholder='File …

[Output(), Output()]

Output()

HBox(children=(Text(value='selected_data_dynamic.csv', description='Save CSV as:', layout=Layout(width='300px'…
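
The 'Multi-Groups' box takes specs such as `(1,5=green) (2,9=red)`. The following is a minimal, standalone sketch of how such a string can be split into index groups and colors, which may help when checking a spec before pasting it into the widget. `parse_group_specs` and its `default_color` parameter are hypothetical names for illustration and are not part of the notebook; the fallback to red when no color is given mirrors the behavior of the highlight callback.

```python
def parse_group_specs(spec_text, default_color="red"):
    """Split a spec like '(1,5=green) (2,9)' into (indices, color) pairs.

    Hypothetical helper for illustration only. Indices are returned as
    strings, exactly as typed, with surrounding whitespace removed.
    """
    groups = []
    for chunk in spec_text.split(")"):
        # Drop surrounding whitespace and the opening parenthesis
        chunk = chunk.strip().lstrip("(").strip()
        if not chunk:
            continue
        # Everything before the first '=' is the index list, the rest is the color
        idx_part, _, color = chunk.partition("=")
        indices = [tok.strip() for tok in idx_part.split(",") if tok.strip()]
        groups.append((indices, color.strip() or default_color))
    return groups


for indices, color in parse_group_specs("(1,5=green) (2,9)"):
    print(color, indices)
```

Because the parsed indices are strings, they only match a numeric `index` column if that column is converted to text before comparison.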