# In [None]:
# --- Import libraries ---
# Standard library
import base64
import io
import math
import os
import re
import subprocess
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed  # threading
from functools import lru_cache  # caching of loaded substitution matrices
from tempfile import NamedTemporaryFile  # file handling

# Core functionality & visualization
import ipywidgets as widgets
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Jupyter display helpers
from IPython.display import display, Markdown, HTML

# BioPython substitution matrices
from Bio.Align import substitution_matrices

# Advanced layout for Matplotlib
from mpl_toolkits.axes_grid1 import make_axes_locatable

# --- User configuration: EDIT HERE! ---
# Tell the notebook where your allele lists live: plain-text files with one
# allele name per line (lines starting with '#' are treated as comments).
mchI_alleles_file = ''   # path to the MHC class I allele list
mchII_alleles_file = ''  # path to the MHC class II allele list



# --- Widget declarations (GUI) ---
# Input widgets: sequence, species, MHC class, alleles and k-mer lengths.

# Protein sequence to analyse (single-letter amino-acid code)
sequence_input = widgets.Textarea(
    description='Sequence:',
    placeholder='Paste your protein sequence here (single letter code)...',
    value='',
    layout=widgets.Layout(width='100%', height='100px'),
    style=dict(description_width='initial'),
)

# Host species; also selects the hostdb/<species> sub-directory
species_dropdown = widgets.Dropdown(
    description='Species:',
    options=['human', 'mouse'],
    value='human',
    style=dict(description_width='initial'),
)

# Which MHC class(es) to predict for
mhc_class_dropdown = widgets.Dropdown(
    description='MHC Class:',
    options=['I', 'II', 'Both'],
    value='I',
    style=dict(description_width='initial'),
)

# Alleles to run; 'ALL' expands to every allele that has a host database.
# The option list is refreshed by update_allele_list.
allele_multiselect = widgets.SelectMultiple(
    description='Alleles:',
    options=['ALL'],
    value=['ALL'],
    layout=widgets.Layout(width='50%', height='100px'),
    style=dict(description_width='initial'),
)

# MHC-I peptide length: 9, every length 8-11, or a custom value
kmer_i_choice = widgets.Dropdown(
    description='k-mer (MHC I):',
    options=['Default (9)', 'All (8-11)', 'Custom'],
    value='Default (9)',
    style=dict(description_width='initial'),
)

# Custom MHC-I length (8-11); hidden until 'Custom' is chosen above
kmer_i_custom = widgets.BoundedIntText(
    description='Custom k-mer (MHC I):',
    value=9, min=8, max=11, step=1,
    layout=widgets.Layout(visibility='hidden'),
    style=dict(description_width='initial'),
)

# MHC-II peptide length: 15, every length 13-25, or a custom value
kmer_ii_choice = widgets.Dropdown(
    description='k-mer (MHC II):',
    options=['Default (15)', 'All (13-25)', 'Custom'],
    value='Default (15)',
    style=dict(description_width='initial'),
)

# Custom MHC-II length (13-25); hidden until 'Custom' is chosen above
kmer_ii_custom = widgets.BoundedIntText(
    description='Custom k-mer (MHC II):',
    value=15, min=13, max=25, step=1,
    layout=widgets.Layout(visibility='hidden'),
    style=dict(description_width='initial'),
)

# Stride settings: step (in residues) between consecutive k-mer start positions.
# BoundedIntText with min=1 rejects 0 (which makes range() raise when k-mers are
# generated) and negative values (which silently produce no k-mers at all).
stride_i = widgets.BoundedIntText(
    value=3,
    min=1,
    max=10000,
    step=1,
    description='Stride (MHC I):',
    style={'description_width': 'initial'}
)

stride_ii = widgets.BoundedIntText(
    value=4,
    min=1,
    max=10000,
    step=1,
    description='Stride (MHC II):',
    style={'description_width': 'initial'}
)

# Binding-strength filter; 'Custom range' reveals the two %Rank boxes below
binding_strength = widgets.Dropdown(
    description='Binding filter:',
    options=['Strong only', 'Strong + Weak', 'Custom range'],
    value='Strong + Weak',
    style=dict(description_width='initial'),
)

# Custom %Rank thresholds (0-100), hidden unless 'Custom range' is selected
custom_rank_strong = widgets.BoundedFloatText(
    description='Strong ≤ %Rank:',
    value=1.0, min=0.0, max=100.0, step=0.1,
    layout=widgets.Layout(visibility='hidden', width='50%'),
    style=dict(description_width='initial'),
)

custom_rank_weak = widgets.BoundedFloatText(
    description='Weak ≤ %Rank:',
    value=5.0, min=0.0, max=100.0, step=0.1,
    layout=widgets.Layout(visibility='hidden', width='50%'),
    style=dict(description_width='initial'),
)

# Substitution matrix for host-similarity scoring: every Biopython matrix
# whose name contains "BLOSUM"
blosum_dropdown = widgets.Dropdown(
    description='Scoring matrix:',
    options=[name for name in substitution_matrices.load() if "BLOSUM" in name],
    value='BLOSUM62',
    style=dict(description_width='initial'),
)

# How per-residue heatmap scores are combined, and whether alleles are merged
heatmap_mode = widgets.ToggleButtons(
    description='Scoring Mode:',
    options=['Additive', 'Max'],
    value='Additive',
    style=dict(description_width='initial'),
)

heatmap_view = widgets.ToggleButtons(
    description='View:',
    options=['Composite', 'Separate'],
    value='Composite',
    style=dict(description_width='initial'),
)

# Pipeline trigger and the panel that receives all messages and plots
run_button = widgets.Button(
    description='Run ImmunoMap',
    icon='play',
    button_style='success',
)

output = widgets.Output()

# --- Display UI ---
# One display call, rendered top to bottom; related controls share a row (HBox).
display(
    Markdown("### 🧬 ImmunoMap: Input Configuration"),
    sequence_input,
    widgets.HBox([species_dropdown, mhc_class_dropdown]),
    allele_multiselect,
    widgets.HBox([kmer_i_choice, kmer_i_custom]),
    widgets.HBox([kmer_ii_choice, kmer_ii_custom]),
    widgets.HBox([stride_i, stride_ii]),
    binding_strength,
    widgets.HBox([custom_rank_strong, custom_rank_weak]),
    blosum_dropdown,
    heatmap_mode,
    heatmap_view,
    run_button,
    output,
)

# --- Helper Functions ---

# Converts allele names into a safe file-friendly format (e.g., HLA-A*02:01 → HLA-A0201)
def normalize_allele_format(allele):
    """Return ``allele`` with '*' and ':' removed and every '/' turned into '-'.

    The result is used both as a display name and as the stem of the
    per-allele host database file name.
    """
    table = str.maketrans({'*': None, ':': None, '/': '-'})
    return allele.translate(table)


# Loads a list of MHC alleles from a plain-text file and normalizes them
def load_alleles_from_file(path):
    """Read allele names from a text file, one name per line.

    Blank lines and lines whose first non-whitespace character is '#' are
    skipped (previously an indented comment line was kept as an allele).
    Every remaining name is passed through ``normalize_allele_format``.

    Args:
        path: path to the allele list.

    Returns:
        Sorted list of unique normalized allele names, or ``[]`` when the
        path is empty or the file cannot be read (a message is printed).
    """
    if not path:
        print("No allele file configured (empty path) - set the allele file paths at the top of the notebook.")
        return []

    try:
        with open(path, 'r', encoding='utf-8') as f:
            lines = [line.strip() for line in f]
    except FileNotFoundError:
        print(f"File not found: {path}")
        return []
    except OSError as err:
        # e.g. a directory or an unreadable file: report instead of crashing the notebook
        print(f"Could not read allele file {path}: {err}")
        return []

    # Normalize and deduplicate; test the *stripped* line for comments
    normalized = {
        normalize_allele_format(line)
        for line in lines
        if line and not line.startswith('#')
    }
    return sorted(normalized)
    


# Updates the allele_multiselect options based on the current species and MHC class
def update_allele_list(*args):
    """Repopulate ``allele_multiselect`` for the current species / MHC class.

    Only alleles that have a host-binder database file at
    ``hostdb/<species>/mhc<class>/<normalized allele>_hits.csv`` are offered.
    Options are reset to ``['ALL'] + sorted(alleles)`` and the selection to
    ``('ALL',)``.

    Args:
        *args: ignored; absorbs the change dict passed by ``observe``.
    """
    species = species_dropdown.value
    mhc_class = mhc_class_dropdown.value
    base_dir = os.path.join("hostdb", species)
    classes = ['I', 'II'] if mhc_class == 'Both' else [mhc_class]

    # Same file-name normalization as everywhere else (no hand-copied replace chain)
    def has_database_file(allele, cls):
        db_file = os.path.join(base_dir, f"mhc{cls}", f"{normalize_allele_format(allele)}_hits.csv")
        return os.path.exists(db_file)

    # Set comprehension deduplicates; a single sort happens below
    alleles = {
        normalize_allele_format(a)
        for cls in classes
        for a in allele_db[species].get(cls, [])
        if has_database_file(a, cls)
    }

    # Update widget options and reset selection
    allele_multiselect.options = ['ALL'] + sorted(alleles)
    allele_multiselect.value = ('ALL',)


# Toggles visibility of custom k-mer input fields based on dropdown selection
def toggle_kmer_custom(change):
    """Show a custom k-mer box only while its dropdown reads 'Custom'.

    ``change['owner']`` identifies which dropdown fired and ``change['new']``
    is its new value.
    """
    pairs = ((kmer_i_choice, kmer_i_custom), (kmer_ii_choice, kmer_ii_custom))
    for dropdown, custom_box in pairs:
        if change['owner'] == dropdown:
            custom_box.layout.visibility = 'hidden' if change['new'] != 'Custom' else 'visible'


# Toggles visibility of custom rank input fields when "Custom range" is selected
def toggle_custom_binding(change):
    """Show both custom %Rank boxes iff the binding filter is 'Custom range'."""
    state = 'hidden'
    if change['new'] == 'Custom range':
        state = 'visible'
    for box in (custom_rank_strong, custom_rank_weak):
        box.layout.visibility = state


# Generates overlapping k-mers of specified lengths and stride from a protein sequence
def generate_kmers(sequence, k_lengths, stride):
    """Slide windows of each length in ``k_lengths`` across ``sequence``.

    Windows start at 0, stride, 2*stride, ...; only complete windows are
    emitted, so trailing residues can be uncovered when
    ``len(sequence) - k`` is not a multiple of ``stride``. Lengths that are
    not positive, or longer than the sequence, are skipped.

    Args:
        sequence: protein sequence string.
        k_lengths: iterable of window lengths.
        stride: step between window starts; must be >= 1.

    Returns:
        List of ``{'peptide', 'start', 'length'}`` dicts, ordered by k then start.

    Raises:
        ValueError: if ``stride`` is less than 1 (0 used to fail inside
            ``range`` and negative values silently returned nothing).
    """
    if stride < 1:
        raise ValueError(f"stride must be a positive integer, got {stride!r}")
    seq_len = len(sequence)
    return [
        {'peptide': sequence[i:i + k], 'start': i, 'length': k}
        for k in k_lengths
        if 0 < k <= seq_len
        for i in range(0, seq_len - k + 1, stride)
    ]


# Writes k-mer peptides to a temporary file (one per line)
def write_peptides_to_file(peptides, filepath):
    """Write each record's 'peptide' string to ``filepath``, one per line.

    The file is created or truncated. ``peptides`` is an iterable of dicts
    such as those produced by ``generate_kmers``.
    """
    with open(filepath, 'w') as handle:
        handle.writelines(record['peptide'] + '\n' for record in peptides)


# Runs NetMHCpan or NetMHCIIpan using subprocess and checks for errors
def run_netmhcpan(peptide_file, alleles, output_file, mhc_class='I'):
    """Run netMHCpan (class I) or netMHCIIpan (class II) on a peptide file.

    Args:
        peptide_file: path to a file with one peptide per line (str or path-like).
        alleles: iterable of allele names, passed comma-joined to ``-a``.
        output_file: path the tool writes its ``-xls`` table to.
        mhc_class: 'I' for netMHCpan, 'II' for netMHCIIpan.

    Raises:
        ValueError: if ``mhc_class`` is neither 'I' nor 'II'.
        RuntimeError: if the tool exits with a non-zero status.
        FileNotFoundError: if the tool exits cleanly but ``output_file`` is
            missing (``subprocess.run`` also raises it when the executable is
            not on PATH).
    """
    # Accept path-like arguments; the diagnostics below join cmd as strings.
    peptide_file = os.fspath(peptide_file)
    output_file = os.fspath(output_file)
    allele_arg = ",".join(alleles)

    if mhc_class == 'I':
        cmd = [
            "netMHCpan", "-p", peptide_file,
            "-a", allele_arg, "-BA", "-xls", "-xlsfile", output_file
        ]
    elif mhc_class == 'II':
        cmd = [
            "netMHCIIpan", "-f", peptide_file,
            "-a", allele_arg, "-inptype", "1", "-xls", "-xlsfile", output_file
        ]
    else:
        raise ValueError("Unsupported MHC class")

    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    def diagnostics():
        # errors='replace': tool output is not guaranteed to be UTF-8, and a
        # UnicodeDecodeError here would hide the real failure.
        return (
            f"Command: {' '.join(cmd)}\n"
            f"Stdout:\n{result.stdout.decode(errors='replace')}\n"
            f"Stderr:\n{result.stderr.decode(errors='replace')}"
        )

    # Raise error if NetMHC failed
    if result.returncode != 0:
        raise RuntimeError(f"NetMHC command failed.\n{diagnostics()}")

    # Check that output file was actually created
    if not os.path.exists(output_file):
        raise FileNotFoundError(f"Expected output file not found: {output_file}\n{diagnostics()}")


# Parses NetMHC output (in XLS format) and returns peptides grouped by allele/ID
def parse_mhc_output(xls_file, strong_thresh=1.0, weak_thresh=5.0):
    """Read a NetMHC ``-xls`` table and label each row's binding level.

    The file is read as tab-separated with its first line skipped. The rank
    column is 'BA_Rank' when present, otherwise 'Rank'. Each row gets a new
    'binding_level' of 'strong' (rank <= strong_thresh), 'weak'
    (rank <= weak_thresh) or 'non'.

    Returns:
        dict mapping each value of the grouping column ('ID', else 'Allele',
        else the first column) to that group's rows as record dicts; ``{}``
        (after printing the available columns) when no rank column exists.
    """
    table = pd.read_csv(xls_file, sep='\t', skiprows=1)
    present = set(table.columns)

    rank_col = next((name for name in ('BA_Rank', 'Rank') if name in present), None)
    if rank_col is None:
        print("❌ Could not find a suitable rank column. Columns are:", table.columns.tolist())
        return {}

    def label(value):
        if value <= strong_thresh:
            return 'strong'
        if value <= weak_thresh:
            return 'weak'
        return 'non'

    table['binding_level'] = table[rank_col].apply(label)

    # Prefer 'ID', then 'Allele'; otherwise group by whatever column comes first
    group_col = next((name for name in ('ID', 'Allele') if name in present), table.columns[0])
    return {key: rows.to_dict('records') for key, rows in table.groupby(group_col)}



# --- Compute BLOSUM similarity between two peptides ---
@lru_cache(maxsize=None)  # key space is the handful of matrix names
def _load_substitution_matrix(name):
    """Load a named Biopython substitution matrix once and reuse it."""
    return substitution_matrices.load(name)


def blosum_score(peptide1, peptide2, matrix='BLOSUM62'):
    """Sum substitution-matrix scores over aligned positions of two peptides.

    Positions are paired with ``zip``, so the longer peptide's tail is ignored.
    Each residue pair is looked up as a sorted tuple (valid because BLOSUM
    matrices are symmetric); pairs missing from the matrix score -4.

    Args:
        peptide1, peptide2: amino-acid strings.
        matrix: the name of a Biopython substitution matrix, loaded once and
            cached (the original re-read it from disk on every call), or an
            already-loaded matrix / mapping supporting ``.get((a, b), default)``.

    Returns:
        The summed score (0 for empty input).
    """
    scores = _load_substitution_matrix(matrix) if isinstance(matrix, str) else matrix
    total = 0
    for a, b in zip(peptide1, peptide2):
        total += scores.get(tuple(sorted((a, b))), -4)  # -4: default mismatch penalty
    return total

# --- Identify similar peptides from host based on allowed substitutions ---
def find_similar_peptides(query, host_peptides, max_subs=3):
    """Return host peptides within ``max_subs`` substitutions of ``query``.

    Only peptides of exactly the same length are considered. The distance is
    the count of differing positions. Matches keep the iteration order of
    ``host_peptides``.
    """
    query_len = len(query)
    return [
        candidate
        for candidate in host_peptides
        if len(candidate) == query_len
        and sum(x != y for x, y in zip(query, candidate)) <= max_subs
    ]

# --- Final immunogenicity score calculation ---
def score_peptide(peptide, rank, host_peptides, matrix='BLOSUM62', cutoff="0.5"):
    """Score how immunogenic a predicted binder is relative to the host.

    score = (1 + cutoff/100) - (similarity + rank/100), rounded to 4 places.

    ``similarity`` is the best BLOSUM score against the closest host peptides,
    divided by the peptide's self-score. If no host peptide is found within
    the allowed substitutions, ``similarity`` is the sentinel -100, giving a
    very large score. A peptide present in ``host_peptides`` scores exactly 0.

    Args:
        peptide: predicted binder.
        rank: its %Rank from NetMHC.
        host_peptides: collection of host self-peptides for the same allele.
        matrix: substitution-matrix name passed to ``blosum_score``.
        cutoff: %Rank cutoff; anything ``float()`` accepts.
    """
    cutoff_value = float(cutoff)

    if peptide in host_peptides:
        # Identical to a host self-peptide: completely non-immunogenic.
        return 0.0

    # Widen the allowed substitutions one at a time, starting at 3 and staying
    # below the peptide length; stop at the first limit with any candidates.
    candidates = []
    for limit in range(3, len(peptide)):
        candidates = find_similar_peptides(peptide, host_peptides, limit)
        if candidates:
            break

    if candidates:
        best = max(blosum_score(peptide, other, matrix) for other in candidates)
        similarity_score = best / blosum_score(peptide, peptide, matrix)
    else:
        similarity_score = -100.0  # nothing host-like found: strongly immunogenic

    final_score = -1.0 * ((similarity_score + (rank / 100)) - (1.0 + cutoff_value/100))
    return round(final_score, 4)

# --- Load MHC allele lists from disk and populate lookup dictionary ---
# Both lists are read once when the cell runs; allele_db then splits them by
# species using allele-name prefixes.
mhc_i_alleles = load_alleles_from_file(mchI_alleles_file)
mhc_ii_alleles = load_alleles_from_file(mchII_alleles_file)

allele_db = {
    species_name: {
        'I': [a for a in mhc_i_alleles if a.startswith(class_i_prefix)],
        'II': [a for a in mhc_ii_alleles if a.startswith(class_ii_prefix)],
    }
    for species_name, class_i_prefix, class_ii_prefix in (
        ('human', 'HLA-', 'HLA-D'),
        ('mouse', 'H-2-', 'H-2-I'),
    )
}

# --- Set up reactive behavior for dropdowns and input elements ---
species_dropdown.observe(update_allele_list, names='value')
mhc_class_dropdown.observe(update_allele_list, names='value')
kmer_i_choice.observe(toggle_kmer_custom, names='value')
kmer_ii_choice.observe(toggle_kmer_custom, names='value')
binding_strength.observe(toggle_custom_binding, names='value')

# Populate the allele selector for the initial species/class. Observers only
# fire on *changes*, so without this call the list shows just 'ALL' until the
# user touches one of the dropdowns.
update_allele_list()

# --- Main Execution Callback: Triggered when "Run ImmunoMap" button is clicked ---
def run_immunomap_callback(btn):
    """Run the ImmunoMap pipeline with the current widget settings.

    Registered as the click handler of ``run_button``; all messages and plots
    go to the ``output`` widget. Steps:

    1. Validate the pasted protein sequence (all whitespace removed, upper-cased).
    2. Resolve the alleles, expanding 'ALL' to every allele of the chosen
       species/class that has a ``hostdb/<species>/mhc<class>/<allele>_hits.csv``
       host-binder database.
    3. For each MHC class: load host binders (rank <= the weak threshold), cut
       the sequence into k-mers and run netMHCpan / netMHCIIpan once per allele
       in a thread pool.
    4. Deduplicate each allele's predicted binders and score them with
       ``score_peptide`` against that allele's host peptides.
    5. Print a per-allele summary and render the immunogenicity heatmap, which
       is re-rendered whenever the heatmap mode/view toggles change.

    Args:
        btn: the clicked button (unused; required by ``on_click``).
    """
    start_time = time.time()
    with output:
        output.clear_output()
        print("Running ImmunoMap...")

        # --- Validate input sequence ---
        # Remove *all* whitespace (spaces, tabs, '\r\n' from pasted text), not only
        # '\n' and ' ', so valid pasted sequences are not rejected as invalid.
        sequence = re.sub(r'\s+', '', sequence_input.value).upper()

        if not sequence:
            print("Error: No sequence provided.")
            return
        if not re.fullmatch(r'[ACDEFGHIKLMNPQRSTVWY]+', sequence):
            print("Error: Sequence contains invalid characters. Only single-letter amino acid codes are allowed.")
            return
        if len(sequence) < 8:
            print("Error: Sequence is too short. Minimum length is 8 amino acids.")
            return
        if stride_i.value < 1 or stride_ii.value < 1:
            print("Error: Stride values must be positive integers.")
            return

        # --- Gather settings from user inputs ---
        species = species_dropdown.value
        mhc_class = mhc_class_dropdown.value
        alleles = list(allele_multiselect.value)
        selected = blosum_dropdown.value
        base_dir = os.path.join("hostdb", species)

        def host_db_path(allele, cls):
            """Path of the host-binder CSV for ``allele`` and MHC class ``cls``."""
            return os.path.join(base_dir, f"mhc{cls}", f"{normalize_allele_format(allele)}_hits.csv")

        # --- Expand 'ALL' into all available alleles for selected species and class ---
        if 'ALL' in alleles:
            expand_classes = ['I', 'II'] if mhc_class == 'Both' else [mhc_class]
            alleles = [
                a
                for cls in expand_classes
                for a in allele_db[species][cls]
                if os.path.exists(host_db_path(a, cls))
            ]

        if not alleles:
            print("Error: No alleles selected, or none of the available alleles has a host database.")
            return

        # --- Define k-mer lengths based on widget selection ---
        kmer_i = (
            [9] if kmer_i_choice.value == 'Default (9)' else
            list(range(8, 12)) if kmer_i_choice.value == 'All (8-11)' else
            [kmer_i_custom.value]
        )
        kmer_ii = (
            [15] if kmer_ii_choice.value == 'Default (15)' else
            list(range(13, 26)) if kmer_ii_choice.value == 'All (13-25)' else
            [kmer_ii_custom.value]
        )

        # --- Per-class %rank thresholds as (strong, weak) ---
        # The weak threshold also filters host binders and is the score cutoff.
        if binding_strength.value == 'Strong only':
            rank_thresholds = {'I': (0.5, 0.5), 'II': (1.0, 1.0)}
        elif binding_strength.value == 'Strong + Weak':
            rank_thresholds = {'I': (0.5, 2.0), 'II': (1.0, 5.0)}
        else:
            custom = (custom_rank_strong.value, custom_rank_weak.value)
            rank_thresholds = {'I': custom, 'II': custom}

        # --- Store user input as a config dictionary ---
        run_config = {
            'sequence': sequence,
            'species': species,
            'mhc_class': mhc_class,
            'alleles': alleles,
            'kmer_lengths': {'I': kmer_i, 'II': kmer_ii},
            'stride': {'I': stride_i.value, 'II': stride_ii.value},
            'rank_thresholds': rank_thresholds,
        }

        eta = (.5 + 5 * len(run_config['alleles']) + .001 * len(run_config['sequence'])) * 2
        print(f"🕒 Estimated runtime: ~{math.ceil(eta)} seconds.")

        # --- Determine which MHC classes to evaluate ---
        mhc_classes = ['I', 'II'] if mhc_class == 'Both' else [mhc_class]

        # Separate alleles by class
        alleles_by_class = {
            cls: [a for a in alleles if a in allele_db[species][cls]]
            for cls in ('I', 'II')
        }

        # Run NetMHC on a single allele; returns (allele, binders) or (allele, error text)
        def run_single_allele(allele, cls, pep_path, rank_strong, rank_weak):
            # uuid-based name: workers run concurrently and must not share an output file
            out_file = f"immunomap_output_{uuid.uuid4().hex}.xls"
            try:
                run_netmhcpan(pep_path, [allele], out_file, cls)
                binders_by_group = parse_mhc_output(out_file, rank_strong, rank_weak)
                binders = [
                    b
                    for group in binders_by_group.values()
                    for b in group
                    if b['binding_level'] in ('strong', 'weak')
                ]
                return allele, binders
            except Exception as e:  # thread boundary: report per-allele failures, keep the run going
                return allele, f"{type(e).__name__}: {e}"
            finally:
                if os.path.exists(out_file):
                    os.remove(out_file)

        all_results = {}
        # Keyed by allele across *all* classes. (It used to be re-created for each
        # class, so with 'Both' the class I alleles were scored against an empty
        # host set.)
        host_peptides_by_allele = {}
        score_cutoff_by_allele = {}

        # --- Loop over MHC Class I and/or II ---
        for cls in mhc_classes:
            class_alleles = alleles_by_class[cls]
            if not class_alleles:
                continue
            rank_strong, rank_weak = rank_thresholds[cls]

            # Load host peptide binders (rank <= weak threshold) for each allele
            for allele in class_alleles:
                score_cutoff_by_allele[allele] = rank_weak
                host_peptides = set()
                host_path = host_db_path(allele, cls)
                if os.path.exists(host_path):
                    host_df = pd.read_csv(host_path)
                    if 'rank' in host_df.columns and 'peptide' in host_df.columns:
                        host_peptides = set(host_df.loc[host_df['rank'] <= rank_weak, 'peptide'])
                    else:
                        print(f"⚠️ Warning: 'rank'/'peptide' column missing for {allele}. Skipping.")
                host_peptides_by_allele[allele] = host_peptides

            # --- Generate peptides from the user sequence ---
            peptides = generate_kmers(
                sequence,
                run_config['kmer_lengths'][cls],
                run_config['stride'][cls]
            )
            if not peptides:
                print(f"⚠️ Sequence is shorter than every MHC {cls} k-mer length; skipping class {cls}.")
                continue

            # Create the temp file closed, then write it by path: re-opening a
            # still-open NamedTemporaryFile by name may fail on some platforms
            # (e.g. Windows), and writing inside the try guarantees clean-up.
            with NamedTemporaryFile(delete=False, mode='w', suffix=".pep") as pep_file:
                pep_path = pep_file.name
            try:
                write_peptides_to_file(peptides, pep_path)

                # --- Parallelize NetMHC calls using threads ---
                with ThreadPoolExecutor() as executor:
                    futures = [
                        executor.submit(run_single_allele, allele, cls, pep_path, rank_strong, rank_weak)
                        for allele in class_alleles
                    ]
                    for future in as_completed(futures):
                        allele, result = future.result()
                        if isinstance(result, str):  # error message
                            print(f"⚠️ Error running NetMHC for {allele}: {result}")
                        else:
                            all_results[allele] = result
            finally:
                if os.path.exists(pep_path):
                    os.remove(pep_path)

        # --- Score and deduplicate peptides for each allele ---
        for allele, binders in all_results.items():
            host_set = host_peptides_by_allele.get(allele, set())
            cutoff = score_cutoff_by_allele[allele]
            seen = set()
            deduped = []
            for b in binders:
                pep = b.get('Peptide') or b.get('peptide')
                if not pep or pep in seen:
                    continue  # skip rows without a peptide, and duplicates
                seen.add(pep)
                b['Peptide'] = pep  # normalize the key used by the printout and heatmap
                rank = float(b.get('BA_Rank', b.get('Rank', 100)))
                # Score using selected BLOSUM matrix and host similarity
                b['score'] = score_peptide(pep, rank, host_set, selected, cutoff)
                deduped.append(b)
            all_results[allele] = deduped

        # --- Display summary and top binders per allele ---
        for allele, binders in all_results.items():
            strong_count = sum(1 for b in binders if b['binding_level'] == 'strong')
            weak_count = sum(1 for b in binders if b['binding_level'] == 'weak')
            print(f"\n  {allele}: {strong_count} strong binders, {weak_count} weak binders")

            print(f"🔬 Scored peptides for {allele}:")
            for b in sorted(binders, key=lambda x: x.get('score', 0)):
                print(f"  {b['Peptide']} - Score: {b['score']}, Rank: {b.get('BA_Rank', b.get('Rank'))}")

        # --- Render immunogenicity heatmap (composite or separate) ---
        def render_heatmap(sequence, scored_peptides_by_allele, mode='additive', composite=True):
            """Plot per-residue immunogenicity scores as a scrollable PNG.

            Every occurrence of a scored peptide in ``sequence`` adds its score
            to (mode='additive') or raises to its score (mode='max') the cells
            it covers. ``composite=True`` merges all alleles into one row;
            otherwise each allele gets its own row.
            """
            seq_len = len(sequence)
            allele_names = list(scored_peptides_by_allele)
            if not allele_names:
                # An empty 2-D map would crash on .max() and a zero-height figure
                print("⚠️ No predictions to plot.")
                return

            # Scores: 1-D for composite, one row per allele otherwise
            score_map = np.zeros(seq_len) if composite else np.zeros((len(allele_names), seq_len))

            for a_idx, allele in enumerate(allele_names):
                row = np.zeros(seq_len)
                for b in scored_peptides_by_allele[allele]:
                    pep = b['Peptide']
                    score = b['score']
                    span = len(pep)
                    # Zero-width lookahead finds overlapping occurrences too
                    for match in re.finditer(f'(?={re.escape(pep)})', sequence):
                        i = match.start()
                        if mode == 'additive':
                            row[i:i + span] += score
                        elif mode == 'max':
                            row[i:i + span] = np.maximum(row[i:i + span], score)

                if composite:
                    if mode == 'additive':
                        score_map += row
                    elif mode == 'max':
                        score_map = np.maximum(score_map, row)
                else:
                    score_map[a_idx] = row

            # --- Plot the heatmap ---
            fig, ax = plt.subplots(
                figsize=(max(20, seq_len // 10), 1.5 if composite else 1.5 * len(allele_names))
            )
            cmap = plt.get_cmap("Reds")
            peak = score_map.max()
            norm = mcolors.Normalize(vmin=0, vmax=1 if peak < 1 else peak)

            # Side colorbar via an axes divider
            divider = make_axes_locatable(ax)
            cax = divider.append_axes("left", size=.3, pad=1)

            if composite:
                im = ax.imshow(score_map[np.newaxis, :], aspect='auto', cmap=cmap, norm=norm)
                ax.set_yticks([0])
                ax.set_yticklabels(['Composite'])
            else:
                im = ax.imshow(score_map, aspect='auto', cmap=cmap, norm=norm)
                ax.set_yticks(range(len(allele_names)))
                ax.set_yticklabels(allele_names)

            ax.set_xticks(range(seq_len))
            ax.set_xticklabels(list(sequence), rotation=0, fontsize=8)
            ax.tick_params(axis='x', pad=5)
            ax.set_xlabel("Protein Sequence Position")

            cbar = fig.colorbar(im, cax=cax, orientation='vertical')
            cbar.ax.yaxis.set_ticks_position('left')
            cbar.ax.yaxis.set_label_position('left')
            cbar.set_label("Immunogenicity Score")
            plt.tight_layout()

            # --- Render as scrollable PNG (wide sequences overflow horizontally) ---
            buf = io.BytesIO()
            fig.savefig(buf, format='png', bbox_inches='tight')
            plt.close(fig)
            image_data = base64.b64encode(buf.getvalue()).decode('utf-8')
            scroll_html = f'''
            <div style="width: 100%; height: 100%; overflow-y: auto; overflow-x: auto; border: 1px solid #ccc; padding: 10px;">
                <img src="data:image/png;base64,{image_data}" style="width: auto; height: auto; max-width: none; max-height: none; display: block; margin: 0 auto;"/>
            </div>
            '''
            display(HTML(scroll_html))

        # Dedicated panel so toggle changes can redraw the heatmap without
        # wiping the text summary above it.
        heatmap_out = widgets.Output()

        def update_heatmap(*args):
            """(Re)render the heatmap using the current mode/view toggles."""
            with heatmap_out:
                heatmap_out.clear_output(wait=True)
                render_heatmap(
                    run_config['sequence'].upper(),
                    all_results,
                    mode=heatmap_mode.value.lower(),
                    composite=(heatmap_view.value == 'Composite')
                )

        print("✅ Finished MHC predictions.")
        display(heatmap_out)
        update_heatmap()  # initial render

        # Hook the toggle buttons so they re-render this run's heatmap; first drop
        # the handler left by a previous run so handlers do not pile up.
        previous_handler = getattr(run_immunomap_callback, '_heatmap_handler', None)
        for toggle in (heatmap_mode, heatmap_view):
            if previous_handler is not None:
                try:
                    toggle.unobserve(previous_handler, names='value')
                except ValueError:
                    pass  # not registered on this (possibly re-created) widget
            toggle.observe(update_heatmap, names='value')
        run_immunomap_callback._heatmap_handler = update_heatmap

        print(f"🕒 Total runtime: {int(time.time() - start_time)} seconds.")
        print(len(alleles), "alleles processed.")
        print(len(sequence), "amino acids in sequence.")
# Clear previous handlers and register the callback function.
# NOTE(review): `_click_handlers` is a private ipywidgets attribute, so this may
# break on an ipywidgets upgrade. Clearing it removes every click handler on
# the button before the callback is attached, so re-running the cell cannot
# stack duplicate handlers. Because `run_button` is re-created earlier in this
# cell, the clear is presumably a no-op on a fresh Button — confirm before
# relying on it.
run_button._click_handlers.callbacks.clear()
run_button.on_click(run_immunomap_callback)
