# **AlphaProteo Target Protein Specification**

We recognize that for non-structural biologists, accurately defining the target
region of a protein can be challenging, particularly because of the confusion
introduced by author-assigned (AUTH) residue numbering
([see FAQ](https://docs.google.com/document/d/1wYLQve9gp59UL_8jFKrdZoZC7kS9SfimTEw5SIwjJP8/preview?tab=t.0#heading=h.3gm45nxykaln)
for more details). This Colab provides an interface for specifying target
protein regions, performing basic validation, and visualising the selected
target. All chain IDs and residue numbers offered here use the mmCIF label
numbering (`label_asym_id` / `label_seq_id`), not AUTH numbering.

To begin, please type or paste your target PDB ID into the corresponding field
of Cell 3, then select "Runtime" -> "Run all" from the menu. In Cell 4, use the
drop-down lists to select your desired target residue ranges and hotspots.
Buttons are available to add more residue ranges or hotspots as needed. To
visualise your selection and hotspots on the structure, click the "Display
Specification" button.

Once satisfied with your selection, please copy the PDB ID (from Cell 3), along
with the Target Protein Specification and Hotspots (if any) from Cell 4, into
the "Submit Design Request" form. Additionally, use the maximum binder length
reported in Cell 4 to fill in the "Desired binder length range" section of the
form.

In [None]:
#@title 1. Install dependencies

# --- Global Settings ---
MINIMAL_SEGMENT_LENGTH = 5 # Minimum allowed length for a single specified segment
MINIMAL_SPEC_TARGET_LENGTH = 40 # Minimum allowed length for the whole protein specification
# --- End Global Settings ---

print(f"Installing py3Dmol and biopython... ")
!uv pip install --quiet py3Dmol biopython

from dataclasses import dataclass
from typing import Any
import gzip
import io
import os
import re

import requests

import py3Dmol
from Bio.PDB.MMCIF2Dict import MMCIF2Dict

from google.colab import output
from IPython.display import display, update_display, HTML
import ipywidgets as widgets


In [None]:
# @title 2. Helper and Parsing/Validation Functions


@dataclass(frozen=True, kw_only=True, slots=True)
class SpecSegment:
  """Represents a single parsed segment of the target specification.

  Instances are immutable (``frozen=True``) and must be constructed with
  keyword arguments (``kw_only=True``). A single-residue segment such as
  'A20' is stored with ``start_res == end_res``; residue bounds are treated
  as inclusive by the validators (length is ``end_res - start_res + 1``).

  Attributes:
    chain_id: The chain identifier for the segment (e.g., 'A').
    start_res: The first residue number of the segment (inclusive).
    end_res: The last residue number of the segment (inclusive).
    original_segment_str: The original, unparsed string for this segment (e.g.,
      'A10-20'); used verbatim in validation error messages.
    raw_residue_part: The part of the original string representing the residue
      range (e.g., '10-20').
  """

  chain_id: str
  start_res: int
  end_res: int
  original_segment_str: str
  raw_residue_part: str


def _split_hotspot_string(hotspot: str) -> tuple[str, int] | None:
  """Splits a hotspot string like 'A10' into chain and residue number.

  Args:
    hotspot: The hotspot string.

  Returns:
    A tuple (chain_id, residue_number) or None if parsing fails.
  """
  match = re.match(
      r'([A-Za-z]+)([0-9-]+)', hotspot
  )  # Allow for negative numbers if any PDB uses them
  if match:
    chain_id = match.group(1)
    try:
      residue_number = int(match.group(2))
      return chain_id, residue_number
    except ValueError:
      return None  # Residue part is not a valid integer
  return None


def _split_spec_segment_string(spec: str) -> tuple[str, str] | None:
  """Splits a specification segment string like 'A10-15' or 'A20'.

  Args:
    spec: The specification segment string.

  Returns:
    A tuple (chain_id, residue_range_str) or None if parsing fails.
    residue_range_str can be '10-15' or '20'.
  """
  match = re.match(r'([A-Za-z]+)([0-9-]+)', spec)
  if match:
    letters = match.group(1)
    number_part = match.group(2)
    return letters, number_part
  return None


def parse_hotspots_for_validation(
    hotspot_input: str,
) -> tuple[list[tuple[str, int]], list[str]]:
  """Turns a comma-separated hotspot list into (chain, residue) tuples.

  Blank entries (for example from a trailing comma) are ignored. Each entry
  that cannot be split into a chain ID and an integer residue number yields
  one critical error message instead of a parsed tuple.

  Args:
    hotspot_input: e.g., 'A10,B20,C30'.

  Returns:
    A tuple: (parsed_hotspots, error_messages).
    parsed_hotspots: List of (chain_id, residue_number) tuples, in input order.
    error_messages: List of strings for critical parsing errors.
  """
  hotspots: list[tuple[str, int]] = []
  errors: list[str] = []
  if hotspot_input.strip() == '':
    return hotspots, errors

  stripped_tokens = (token.strip() for token in hotspot_input.split(','))
  # filter(None, ...) drops the empty tokens left by stray commas.
  for token in filter(None, stripped_tokens):
    split_result = _split_hotspot_string(token)
    if split_result is None:
      errors.append(
          'CRITICAL ERROR (Hotspot Parsing): Hotspot definition'
          f" '{token}' is malformed."
      )
      continue
    hotspots.append(split_result)
  return hotspots, errors


def parse_spec_for_validation(
    spec_input: str,
) -> tuple[list[SpecSegment], list[str]]:
  """Parses a full specification string into a structured list of segments for validation.

  The input is split on '/' (chain groups) and then on ',' (segments); blank
  pieces are skipped. Each segment is a chain ID followed by either a single
  residue number ('A20', stored with start == end) or an inclusive range
  ('A10-25'). Residue numbers may be negative (e.g. 'A-5' or 'A-5-10'),
  matching what the hotspot parser accepts.

  Args:
    spec_input: e.g., 'A10-25,A26/B30-40'.

  Returns:
    A tuple: (parsed_segments, error_messages).
    parsed_segments: A list of SpecSegment objects, in input order.
    error_messages: List of strings for critical parsing errors.
  """
  parsed_segments: list[SpecSegment] = []
  error_messages: list[str] = []
  if not spec_input.strip():
    return parsed_segments, error_messages

  for chain_spec_str in spec_input.split('/'):
    for segment_str in chain_spec_str.split(','):
      segment_str = segment_str.strip()
      if not segment_str:
        continue
      parsed_split = _split_spec_segment_string(segment_str)
      if parsed_split is None:
        error_messages.append(
            f"CRITICAL ERROR (Spec Parsing): Segment '{segment_str}' is"
            ' malformed.'
        )
        continue
      chain_id, raw_residue_part = parsed_split
      # A plain split('-') cannot tell the range separator from a minus sign
      # (so '-5' and '-5-10' used to be rejected); match both possibly
      # negative bounds explicitly instead.
      range_match = re.fullmatch(
          r'(-?[0-9]+)(?:-(-?[0-9]+))?', raw_residue_part
      )
      if range_match is None:
        error_messages.append(
            'CRITICAL ERROR (Spec Parsing): Residue numbers in'
            f" '{segment_str}' are not valid integers."
        )
        continue
      start_str, end_str = range_match.groups()
      start_res = int(start_str)
      # A single residue ('A20') is stored as a one-residue range.
      end_res = start_res if end_str is None else int(end_str)

      parsed_segments.append(
          SpecSegment(
              chain_id=chain_id,
              start_res=start_res,
              end_res=end_res,
              original_segment_str=segment_str,
              raw_residue_part=raw_residue_part,
          )
      )
  return parsed_segments, error_messages


def validate_target_specification(
    parsed_segments: list[SpecSegment], min_segment_len: int
) -> list[str]:
  """Validates parsed target specification segments.

  Checks, in order:
    1. Range direction: a segment whose start exceeds its end is reported and
       excluded from the remaining checks.
    2. Minimum length: each segment must cover at least ``min_segment_len``
       residues (``end_res - start_res + 1``).
    3. Overlaps: on each chain, every segment that overlaps any segment
       starting at or before it is reported once, paired with the earlier
       segment that extends furthest. Non-adjacent overlaps are therefore not
       missed.

  Args:
    parsed_segments: Output from parse_spec_for_validation.
    min_segment_len: Minimum allowed length for any single segment.

  Returns:
    A list of error message strings. Empty if valid.
  """
  error_messages: list[str] = []
  if not parsed_segments:
    return error_messages

  segments_by_chain: dict[str, list[SpecSegment]] = {}

  for segment in parsed_segments:
    # Check 1: Range direction
    if segment.start_res > segment.end_res:
      error_messages.append(
          f'ERROR (Target Spec): Invalid range {segment.original_segment_str}'
          ' (start > end).'
      )
      continue  # Skip further checks for this malformed segment

    # Check 2: Minimum segment length (inclusive residue count)
    segment_length = segment.end_res - segment.start_res + 1
    if segment_length < min_segment_len:
      error_messages.append(
          f'ERROR (Target Spec): Segment {segment.original_segment_str} is too'
          f' short (length {segment_length}, minimum is {min_segment_len}).'
      )

    segments_by_chain.setdefault(segment.chain_id, []).append(segment)

  # Check 3: Overlaps within each chain
  for chain_id, chain_segments in segments_by_chain.items():
    # Sort by start residue (stable for ties) so one sweep finds overlaps.
    sorted_chain_segments = sorted(chain_segments, key=lambda s: s.start_res)
    # Compare each segment with the earlier segment reaching furthest right,
    # not only its neighbour: with A1-100, A10-20, A30-40 a neighbour-only
    # check misses that A1-100 also overlaps A30-40.
    furthest: SpecSegment | None = None
    for seg_j in sorted_chain_segments:
      if furthest is not None and furthest.end_res >= seg_j.start_res:
        error_messages.append(
            f'ERROR (Target Spec): Overlap on chain {chain_id}. '
            f'Segment {furthest.original_segment_str} (residues'
            f' {furthest.start_res}-{furthest.end_res}) overlaps with '
            f'segment {seg_j.original_segment_str} (residues'
            f' {seg_j.start_res}-{seg_j.end_res}).'
        )
      if furthest is None or seg_j.end_res > furthest.end_res:
        furthest = seg_j
  return error_messages


def _prepare_target_ranges_for_hotspot_validation(
    parsed_spec_segments: list[SpecSegment],
) -> tuple[dict[str, list[tuple[int, int]]] | None, list[str]]:
  """Groups parsed spec segments into per-chain (start, end) ranges.

  Segments are expected to have passed validate_target_specification
  already. Any segment whose start exceeds its end is reported as an internal
  error. Each chain's ranges are sorted by start residue (a stable sort, so
  ties keep their input order).

  Args:
    parsed_spec_segments: Output from parse_spec_for_validation.

  Returns:
    A tuple: (target_ranges_dict, error_messages_list).
    target_ranges_dict: {'ChainID': [(start, end), ...]}, or None if any
      segment produced an error.
    error_messages_list: Internal error messages; empty on success.
  """
  ranges_by_chain: dict[str, list[tuple[int, int]]] = {}
  problems: list[str] = []

  for seg in parsed_spec_segments:
    if seg.start_res <= seg.end_res:
      ranges_by_chain.setdefault(seg.chain_id, []).append(
          (seg.start_res, seg.end_res)
      )
    else:
      # Should have been caught by validate_target_specification.
      problems.append(
          f'Internal Error: Segment {seg.original_segment_str} has start'
          ' > end during hotspot prep.'
      )

  if problems:
    return None, problems

  for chain_ranges in ranges_by_chain.values():
    chain_ranges.sort(key=lambda bounds: bounds[0])
  return ranges_by_chain, problems


def validate_hotspots(
    parsed_hotspots: list[tuple[str, int]],
    target_spec_ranges: dict[str, list[tuple[int, int]]] | None,
) -> list[str]:
  """Validates parsed hotspots for uniqueness and containment.

  Args:
    parsed_hotspots: Output from parse_hotspots_for_validation.
    target_spec_ranges: Dict from _prepare_target_ranges_for_hotspot_validation.
      None if target spec had critical parsing errors.

  Returns:
    A list of error message strings. Empty if valid.
  """
  error_messages: list[str] = []
  if not parsed_hotspots:
    return error_messages

  if target_spec_ranges is None:
    error_messages.append(
        'INFO (Hotspots): Cannot validate hotspots; target specification has'
        ' critical parsing errors or is empty.'
    )
    return error_messages

  seen_hotspots: set[tuple[str, int]] = set()
  reported_duplicates: set[tuple[str, int]] = set()

  # Check 1: Uniqueness
  for hs_chain, hs_res_num in parsed_hotspots:
    hotspot_tuple = (hs_chain, hs_res_num)
    if hotspot_tuple in seen_hotspots:
      if hotspot_tuple not in reported_duplicates:
        error_messages.append(
            f'ERROR (Hotspots): Duplicate hotspot: {hs_chain}{hs_res_num}.'
        )
        reported_duplicates.add(hotspot_tuple)
    else:
      seen_hotspots.add(hotspot_tuple)

  # Check 2: Containment (using the unique hotspots from seen_hotspots)
  for hs_chain, hs_res_num in seen_hotspots:  # Iterate over unique ones
    hotspot_str_repr = f'{hs_chain}{hs_res_num}'
    if hs_chain not in target_spec_ranges:
      error_messages.append(
          f'ERROR (Hotspots): Hotspot {hotspot_str_repr} is on chain'
          f" '{hs_chain}', which is not in the target specification."
      )
      continue

    is_contained = any(
        spec_start <= hs_res_num <= spec_end
        for spec_start, spec_end in target_spec_ranges[hs_chain]
    )
    if not is_contained:
      error_messages.append(
          f'ERROR (Hotspots): Hotspot {hotspot_str_repr} (residue'
          f' {hs_res_num}) '
          f"is not within any specified segment on chain '{hs_chain}'."
      )
  return error_messages


def get_structure_file(pdb_id: str, timeout: float = 60.0) -> str | None:
  """Downloads an mmCIF file from the PDB.

  Args:
      pdb_id: The PDB ID of the protein structure (case-insensitive).
      timeout: Connect/read timeout in seconds passed to ``requests``. Without
        it, a stalled connection would block the notebook cell indefinitely.

  Returns:
      String containing the contents of the mmCIF file from the PDB.
      None if the download, decompression or decoding fails; the reason is
      printed.
  """
  pdb_id_lower = pdb_id.strip().lower()
  url = f'https://files.wwpdb.org/pub/pdb/data/structures/all/mmCIF/{pdb_id_lower}.cif.gz'

  try:
    with requests.get(url, timeout=timeout) as res:
      # Report HTTP failures (e.g. 404 for an unknown ID) directly, instead of
      # surfacing them later as a confusing gzip decompression error.
      res.raise_for_status()
      return gzip.decompress(res.content).decode('utf-8')

  except Exception as e:  # Best-effort boundary: report and return None.
    print(f'Error downloading or processing PDB file {pdb_id}: {e}')
    return None


def get_chain_residue_ids_from_mmcif(
    mmcif_contents: str,
) -> dict[str, list[str]]:
  """Parses mmCIF file for chain IDs and their residue IDs (label_seq_id).

  Only ``_atom_site`` rows are used whose ``label_comp_id`` is one of the 20
  standard amino acids and whose ``label_seq_id`` is not '.'.

  Args:
      mmcif_contents: Contents of the mmCIF file.

  Returns:
      A dictionary mapping chain IDs (label_asym_id) to a list of unique
      residue IDs (label_seq_id as strings). Integer IDs come first, in numeric
      order and re-rendered with str(int(...)); any other IDs follow in
      lexicographic order. An empty dict is returned if parsing fails.
  """
  final_dict: dict[str, list[str]] = {}

  try:
    with io.StringIO(mmcif_contents) as f:
      mmcif_dict = MMCIF2Dict(f)
  except Exception as e:
    print(f'Error parsing MMCIF file: {e}')
    return final_dict

  # Standard residue types (uppercase)
  standard_residues = frozenset({
      'ALA',
      'ARG',
      'ASN',
      'ASP',
      'CYS',
      'GLN',
      'GLU',
      'GLY',
      'HIS',
      'ILE',
      'LEU',
      'LYS',
      'MET',
      'PHE',
      'PRO',
      'SER',
      'THR',
      'TRP',
      'TYR',
      'VAL',
  })

  chain_residue_dict: dict[str, set[str]] = {}
  try:
    # dict.get never raises KeyError, so missing columns just yield no rows.
    atom_rows = zip(
        mmcif_dict.get('_atom_site.label_asym_id', []),
        mmcif_dict.get('_atom_site.label_seq_id', []),
        mmcif_dict.get('_atom_site.label_comp_id', []),
    )
    for chain_id, res_id_str, res_name in atom_rows:
      # '.' indicates unknown seq id
      if res_name.upper() in standard_residues and res_id_str != '.':
        chain_residue_dict.setdefault(chain_id, set()).add(res_id_str)
  except Exception as e:
    print(f'Unexpected error during MMCIF data extraction: {e}')
    return final_dict

  # Strict ASCII integer pattern: anything it accepts is safe for int(), so
  # there is no fallback that would sort numbers as text ('1', '10', '2').
  integer_id = re.compile(r'-?[0-9]+')
  for chain_id_key, res_id_set in chain_residue_dict.items():
    numeric_res_ids = sorted(
        int(r) for r in res_id_set if integer_id.fullmatch(r)
    )
    other_res_ids = sorted(
        r for r in res_id_set if not integer_id.fullmatch(r)
    )
    # Numeric IDs first (as strings for the dropdowns), then the rest.
    final_dict[chain_id_key] = [
        str(r) for r in numeric_res_ids
    ] + other_res_ids

  return final_dict


# --- Lightweight Parsers for Visualization ---
def _parse_spec_for_visualization(spec_str: str) -> list[tuple[str, str]]:
  """Lightweight parser for spec string, for py3Dmol.

  Args:
      spec_str: Specification string.

  Returns:
      list of (chain_id, residue_range_str) e.g. ('A', '10-15').
  """
  parsed = []
  if not spec_str.strip():
    return parsed
  for chain_part in spec_str.split('/'):
    for segment_part in chain_part.split(','):
      segment_part = segment_part.strip()
      if not segment_part:
        continue
      match = _split_spec_segment_string(segment_part)
      if match:
        parsed.append(match)
  return parsed


def _parse_hotspot_for_visualization(hotspot_str: str) -> list[tuple[str, str]]:
  """Lightweight parser for hotspot string, for py3Dmol.

  Args:
      hotspot_str: Hotspot string.

  Returns:
      list of (chain_id, residue_num_str) e.g. ('A', '10').
  """
  parsed = []
  if not hotspot_str.strip():
    return parsed
  for hs_part in hotspot_str.split(','):
    hs_part = hs_part.strip()
    if not hs_part:
      continue
    match = _split_spec_segment_string(
        hs_part
    )  # Re-use as it gives (chain, res_str)
    if match:
      parsed.append(match)  # (chain_id, res_num_as_string)
  return parsed


# --- Functions to Create Molecular Graphics ---
def display_structure(
    pdb_id: str,
    spec_str: str,
    hotspot_str: str,
    use_auth_numbering: bool,
    mol_widget_output_area: widgets.Output,
) -> py3Dmol.view | None:
  """Creates a new py3Dmol viewer for ``pdb_id`` and draws the selection.

  The structure is drawn as a cartoon coloured by chain. If ``spec_str`` is
  non-empty, the whole cartoon is first faded to opacity 0.5 and the
  specified segments are then redrawn fully opaque. Hotspot residues are
  coloured yellow. Residues are selected by label fields (lchain/lresi)
  unless ``use_auth_numbering`` is True, in which case chain/resi are used.
  Hover labels always show label numbering.

  Args:
    pdb_id: PDB ID to load (lower-cased before querying).
    spec_str: Target specification string, e.g. 'A10-20/B5-15'.
    hotspot_str: Comma-separated hotspots, e.g. 'A12,B7'.
    use_auth_numbering: Select residues by author instead of label fields.
    mol_widget_output_area: Output widget that is cleared and then receives
      the rendered viewer or the error message.

  Returns:
    The viewer, carrying a ``pdbid`` attribute with the lower-cased ID, or
    None if building the view raised (the error is printed).
  """

  def _selector(chain_id: str, residues: str) -> dict[str, str]:
    # Label fields by default; author fields only when requested.
    if use_auth_numbering:
      return {'chain': chain_id, 'resi': residues}
    return {'lchain': chain_id, 'lresi': residues}

  with mol_widget_output_area:
    mol_widget_output_area.clear_output(wait=True)
    try:
      pdb_key = pdb_id.lower()
      viewer = py3Dmol.view(query=f'pdb:{pdb_key}', width=800, height=800)
      # Remember which entry this viewer shows so callers can reuse it.
      viewer.pdbid = pdb_key
      viewer.setStyle({'cartoon': {'colorscheme': 'chain'}})

      if spec_str:
        segments = _parse_spec_for_visualization(spec_str)
        viewer.setStyle({'cartoon': {'colorscheme': 'chain', 'opacity': 0.5}})
        for chain_id, res_range_str in segments:
          viewer.setStyle(
              _selector(chain_id, res_range_str),
              {'cartoon': {'colorscheme': 'chain', 'opacity': 1.0}},
          )

      if hotspot_str:
        hotspots = _parse_hotspot_for_visualization(hotspot_str)
        for chain_id, res_num_str in hotspots:
          viewer.setStyle(
              _selector(chain_id, res_num_str),
              {
                  'cartoon': {'color': 'yellow'},
                  'stick': {'opacity': 0.9, 'color': 'yellow'},
              },
          )

      # Hover labels (always use label_asym_id and label_seq_id for hover)
      viewer.setHoverable(
          {},
          True,
          """function(atom,viewer,event,container) {
                if(!atom.label) {
                    atom.label = viewer.addLabel(atom.lchain+atom.lresi+"/"+atom.resn,
                                              {position: atom, backgroundColor: 'mintcream', fontColor:'black'});
                }
            }""",
          """function(atom,viewer) {
                if(atom.label) {
                    viewer.removeLabel(atom.label);
                    delete atom.label;
                }
            }""",
      )
      viewer.zoomTo()
      viewer.show()
      return viewer
    except Exception as e:
      print(f'Error displaying structure for {pdb_id}: {e}')
      return None


def update_structure_visualization(
    viewer: py3Dmol.view | None,
    spec_str: str,
    hotspot_str: str,
    use_auth_numbering: bool,
    mol_widget_output_area: widgets.Output,
) -> py3Dmol.view | None:
  """Restyles an existing py3Dmol viewer for a new selection and re-shows it.

  Styling follows display_structure. The cartoon is reset to chain colours.
  With a non-empty ``spec_str`` everything is faded to opacity 0.5 and the
  specified segments are redrawn fully opaque; otherwise the whole cartoon
  is set fully opaque. Hotspots are then coloured yellow.

  Args:
    viewer: Viewer previously returned by display_structure.
    spec_str: Target specification string, e.g. 'A10-20/B5-15'.
    hotspot_str: Comma-separated hotspots, e.g. 'A12,B7'.
    use_auth_numbering: Select residues by author instead of label fields.
    mol_widget_output_area: Output widget that is cleared and re-rendered.

  Returns:
    The same viewer, or None (after printing an error) if ``viewer`` is falsy.
  """

  def _selector(chain_id: str, residues: str) -> dict[str, str]:
    # Label fields by default; author fields only when requested.
    if use_auth_numbering:
      return {'chain': chain_id, 'resi': residues}
    return {'lchain': chain_id, 'lresi': residues}

  with mol_widget_output_area:
    mol_widget_output_area.clear_output(wait=True)
    if not viewer:
      print('ERROR: No viewer provided to update.')
      return None

    # Start again from the plain chain-coloured cartoon.
    viewer.setStyle({'cartoon': {'colorscheme': 'chain'}})

    if not spec_str:
      # No target spec: keep the whole cartoon fully opaque.
      viewer.setStyle({'cartoon': {'colorscheme': 'chain', 'opacity': 1.0}})
    else:
      segments = _parse_spec_for_visualization(spec_str)
      viewer.setStyle({'cartoon': {'colorscheme': 'chain', 'opacity': 0.5}})
      for chain_id, res_range_str in segments:
        viewer.setStyle(
            _selector(chain_id, res_range_str),
            {'cartoon': {'colorscheme': 'chain', 'opacity': 1.0}},
        )

    if hotspot_str:
      hotspots = _parse_hotspot_for_visualization(hotspot_str)
      for chain_id, res_num_str in hotspots:
        viewer.setStyle(
            _selector(chain_id, res_num_str),
            {
                'cartoon': {'color': 'yellow'},
                'stick': {'opacity': 0.9, 'color': 'yellow'},
            },
        )

    # Re-apply hoverable as styles might clear it
    viewer.setHoverable(
        {},
        True,
        """function(atom,viewer,event,container) {
                if(!atom.label) {
                    atom.label = viewer.addLabel(atom.lchain+atom.lresi+"/"+atom.resn,
                                              {position: atom, backgroundColor: 'mintcream', fontColor:'black'});
                }
            }""",
        """function(atom,viewer) {
                if(atom.label) {
                    viewer.removeLabel(atom.label);
                    delete atom.label;
                }
            }""",
    )

    viewer.zoomTo()
    viewer.show()
    return viewer

In [None]:
# @title 3. Get PDB structure

pdb_id_input = '5vli'  # @param {type:"string"}
# @markdown - e.g. 5vli

# Normalised ID (trimmed, lower-case) used by all later cells.
current_pdb_id: str = pdb_id_input.strip().lower()
mmcif_file: str | None = None
chain_residue_data: dict[str, list[str]] = {}
available_chains: list[str] = []

if current_pdb_id:
  print(f'Downloading and processing {current_pdb_id}...')
  try:
    mmcif_file = get_structure_file(current_pdb_id)
    # get_structure_file prints the reason and returns None on failure; skip
    # parsing then, rather than passing None on to the mmCIF parser.
    if mmcif_file is not None:
      chain_residue_data = get_chain_residue_ids_from_mmcif(mmcif_file)
  except Exception as e:
    print(f'Error processing PDB file for {current_pdb_id}: {e}')
    chain_residue_data = {}

  # An empty key list means chain_residue_data itself is empty.
  available_chains = sorted(chain_residue_data)
  if available_chains:
    print(f'Done! Found chains: {", ".join(available_chains)}')
  else:
    print(
        'No standard protein chains with parsable residue IDs'
        f' found for {current_pdb_id}.'
    )
else:
  print('No PDB ID entered.')

In [None]:
# @title 4. Visual selection of the target protein specification and hotspots

# Switching scrolling off
output.no_vertical_scroll()

# --- Widget Lists and Global Variables ---
# Spec string rebuilt from the range rows by _collect_and_validate_selections,
# e.g. 'A10-20,A30-40/B5-15': consecutive rows on the same chain are joined
# with ',' and a '/' starts a new group whenever the chain changes.
target_protein_spec_str: str = ''
# Comma-separated, de-duplicated, sorted hotspot string, e.g. 'A15,B7'.
hotspot_spec_str: str = ''
# NOTE(review): assigned in _collect_and_validate_selections (declared global
# there); presumably the binder-length info shown in this cell — confirm.
current_binder_length: int | str = 0

# One entry per range row; index i of these three lists belongs to row i.
range_chain_dropdowns: list[widgets.Dropdown] = []
range_residue_dropdowns_start: list[widgets.Dropdown] = []
range_residue_dropdowns_end: list[widgets.Dropdown] = []

# One entry per hotspot row; index i of both lists belongs to row i.
hotspot_chain_dropdowns: list[widgets.Dropdown] = []
hotspot_residue_dropdowns: list[widgets.Dropdown] = []


# Initialize VBoxes early and ensure they are always objects
# (rows are appended to their `children` by the create_* functions below).
selection_widgets_ranges_vbox: widgets.VBox = widgets.VBox([])
selection_widgets_hotspots_vbox: widgets.VBox = widgets.VBox([])
# Viewer reused by the 'Display Specification' callback while the PDB ID is unchanged.
mol_viewer_instance: py3Dmol.view | None = None
output_area: widgets.Output = widgets.Output()  # Warnings/messages area
mol_widget_area: widgets.Output = (
    widgets.Output()
)  # Area the py3Dmol viewer is rendered into


# --- Functions to Create Interactive Dropdown Selectors ---
def create_chain_residue_range_dropdowns() -> None:
  """Creates and displays a new set of dropdowns for a residue range.

  Appends one row (chain, start residue, end residue) to
  ``selection_widgets_ranges_vbox``. The three dropdowns are registered in
  the parallel lists ``range_chain_dropdowns``,
  ``range_residue_dropdowns_start`` and ``range_residue_dropdowns_end``, and
  any change in the row re-runs ``_collect_and_validate_selections``. If no
  chains were loaded in Cell 3, only an orange warning is shown in
  ``output_area``.
  """

  if not available_chains:
    with output_area:  # output_area is created at the top of this cell
      output_area.clear_output(wait=True)
      display(
          HTML(
              "<font color='orange'>No chains available from PDB to select"
              ' ranges. Please load a valid PDB ID first.</font>'
          )
      )
    return

  # 1-based row number: existing rows plus this one.
  chain_label = widgets.Label(
      value=f'Selection {len(range_chain_dropdowns) + 1}: Chain '
  )
  chain_dd = widgets.Dropdown(
      options=available_chains,
      disabled=False,
      layout=widgets.Layout(width='10%'),
  )
  # The chain dropdown starts on the first chain, so offer its residues.
  initial_res_options = chain_residue_data.get(available_chains[0], [])
  res_label_start = widgets.Label(value=' from ')
  res_dd_start = widgets.Dropdown(
      options=initial_res_options,
      disabled=False,
      layout=widgets.Layout(width='10%'),
  )
  res_label_end = widgets.Label(value=' to ')
  # NOTE(review): the initial end options skip the first residue, whereas
  # _update_end_residue_options_for_range later offers the start residue
  # itself as a valid end. Confirm which behaviour is intended.
  initial_end_options = (
      initial_res_options[1:]
      if len(initial_res_options) > 1
      else initial_res_options
  )
  res_dd_end = widgets.Dropdown(
      options=initial_end_options,
      disabled=False,
      layout=widgets.Layout(width='10%'),
  )

  # --- Callbacks for range dropdowns ---
  def _update_residue_options_for_range(
      change: dict[str, Any],
      start_dd: widgets.Dropdown,
      end_dd: widgets.Dropdown,
  ) -> None:
    """Repopulates the start/end dropdowns after the chain changes.

    The previous start/end values are kept when they still exist on the new
    chain; otherwise the first valid option is chosen. End options begin at
    the start residue's position in the option list.
    """
    # `change` is read by attribute (`change.new`), so despite the dict
    # annotation it must be an attribute-style change object.
    selected_chain = change.new
    new_res_options = chain_residue_data.get(selected_chain, [])
    # Capture the old values before reassigning options.
    current_start_val_str = str(start_dd.value)
    current_end_val_str = str(end_dd.value)

    start_dd.options = new_res_options
    if current_start_val_str in new_res_options:
      start_dd.value = current_start_val_str
    elif new_res_options:
      start_dd.value = new_res_options[0]

    if new_res_options:
      actual_start_val_for_end_options = str(start_dd.value)
      try:
        start_idx = new_res_options.index(actual_start_val_for_end_options)
        valid_end_opts = new_res_options[start_idx:]
        end_dd.options = valid_end_opts
        if current_end_val_str in valid_end_opts:
          end_dd.value = current_end_val_str
        elif valid_end_opts:
          end_dd.value = valid_end_opts[0]
      except (ValueError, IndexError):
        # Start value not found among the options: offer everything.
        end_dd.options = new_res_options
        if new_res_options:
          end_dd.value = new_res_options[0]
    else:
      end_dd.options = []
    _collect_and_validate_selections(None)

  def _update_end_residue_options_for_range(
      change: dict[str, Any],
      start_dd: widgets.Dropdown,
      end_dd: widgets.Dropdown,
  ) -> None:
    """Restricts the end dropdown to residues at or after the chosen start.

    The comparison is numeric when every option parses as int. Otherwise it
    falls back to the start's position in the option list, and to all
    options if the start cannot be found.
    """
    selected_start_res_str = str(change.new)
    all_options_for_chain = list(start_dd.options)
    current_end_val = str(end_dd.value)
    if selected_start_res_str and all_options_for_chain:
      try:
        selected_start_res_int = int(selected_start_res_str)
        valid_end_options_str = [
            r_str
            for r_str in all_options_for_chain
            if int(r_str) >= selected_start_res_int
        ]
      except ValueError:
        # Some value is not an integer: fall back to list position.
        try:
          start_idx = all_options_for_chain.index(selected_start_res_str)
          valid_end_options_str = all_options_for_chain[start_idx:]
        except ValueError:
          valid_end_options_str = all_options_for_chain
      end_dd.options = valid_end_options_str
      # Keep the previous end value when it is still valid.
      if current_end_val in valid_end_options_str:
        end_dd.value = current_end_val
      elif valid_end_options_str:
        end_dd.value = valid_end_options_str[0]
    else:
      end_dd.options = all_options_for_chain
      if all_options_for_chain:
        end_dd.value = all_options_for_chain[0]
    _collect_and_validate_selections(None)

  # --- End Callbacks for range dropdowns ---

  # The lambdas bind this row's own start/end dropdowns to the callbacks.
  chain_dd.observe(
      lambda c: _update_residue_options_for_range(c, res_dd_start, res_dd_end),
      names='value',
  )
  res_dd_start.observe(
      lambda c: _update_end_residue_options_for_range(
          c, res_dd_start, res_dd_end
      ),
      names='value',
  )
  res_dd_end.observe(
      lambda c: _collect_and_validate_selections(None), names='value'
  )

  # Register the row so _collect_and_validate_selections can read it.
  range_chain_dropdowns.append(chain_dd)
  range_residue_dropdowns_start.append(res_dd_start)
  range_residue_dropdowns_end.append(res_dd_end)

  new_hbox = widgets.HBox([
      chain_label,
      chain_dd,
      res_label_start,
      res_dd_start,
      res_label_end,
      res_dd_end,
  ])
  # selection_widgets_ranges_vbox is created at the top of this cell; adding
  # to `children` shows the new row.
  selection_widgets_ranges_vbox.children += (new_hbox,)
  _collect_and_validate_selections(None)


def create_hotspot_dropdowns() -> None:
  """Appends one hotspot row (chain + residue dropdowns) to the hotspot panel.

  When Cell 3 loaded no chains, this only shows an orange warning in
  ``output_area``. A new row defaults to the first available chain and that
  chain's first residue. The dropdowns are registered in
  ``hotspot_chain_dropdowns`` / ``hotspot_residue_dropdowns``, and any change
  re-runs ``_collect_and_validate_selections``.
  """

  if not available_chains:
    with output_area:
      output_area.clear_output(wait=True)
      display(
          HTML(
              "<font color='orange'>No chains available from PDB to select"
              ' hotspots. Please load a valid PDB ID first.</font>'
          )
      )
    return

  row_number = len(hotspot_chain_dropdowns) + 1
  chain_label = widgets.Label(value=f'Hotspot {row_number}: Chain ')
  chain_dd = widgets.Dropdown(
      options=available_chains,
      disabled=False,
      layout=widgets.Layout(width='10%'),
  )
  first_chain_residues = chain_residue_data.get(available_chains[0], [])
  res_label = widgets.Label(value=' Residue ')
  res_dd = widgets.Dropdown(
      options=first_chain_residues,
      disabled=False,
      layout=widgets.Layout(width='10%'),
  )

  def _on_chain_change(change: dict[str, Any]) -> None:
    # Show the residues of the newly selected chain and pick the first one.
    res_dd.options = chain_residue_data.get(change.new, [])
    if res_dd.options:
      res_dd.value = res_dd.options[0]
    _collect_and_validate_selections(None)

  chain_dd.observe(_on_chain_change, names='value')
  res_dd.observe(
      lambda _change: _collect_and_validate_selections(None), names='value'
  )

  hotspot_chain_dropdowns.append(chain_dd)
  hotspot_residue_dropdowns.append(res_dd)

  row = widgets.HBox([chain_label, chain_dd, res_label, res_dd])
  selection_widgets_hotspots_vbox.children += (row,)
  _collect_and_validate_selections(None)


def _display_specification_button_clicked(b: widgets.Button | None) -> None:
  """Handles the 'Display Specification' button by (re)drawing the structure.

  The spec/hotspot strings are refreshed first. A new viewer is then created
  when none exists or the existing one shows a different PDB entry;
  otherwise the existing viewer is restyled. Label numbering is always used.

  Args:
    b: The clicked button (unused; None when called programmatically).
  """
  global mol_viewer_instance
  _collect_and_validate_selections(None)  # Ensure spec strings are up-to-date

  if not mol_widget_area:
    return

  if not current_pdb_id:
    with mol_widget_area:
      mol_widget_area.clear_output(wait=True)
      print(
          'Please enter a PDB ID in Cell 3 and run it to display the structure.'
      )
    return

  wanted_pdb_id = current_pdb_id.lower()
  # Reuse only a live viewer that was created for this same PDB ID.
  can_reuse_viewer = (
      bool(mol_viewer_instance)
      and getattr(mol_viewer_instance, 'pdbid', None) == wanted_pdb_id
  )

  if can_reuse_viewer:
    mol_viewer_instance = update_structure_visualization(
        mol_viewer_instance,
        target_protein_spec_str,
        hotspot_spec_str,
        False,
        mol_widget_area,
    )
  else:
    mol_viewer_instance = display_structure(
        wanted_pdb_id,
        target_protein_spec_str,
        hotspot_spec_str,
        False,
        mol_widget_area,
    )


# Residue budget shared by the specified target and the binder: the maximal
# binder length reported to the user is this budget minus the target length.
_TOTAL_RESIDUE_BUDGET = 512
# A maximal binder length below this is reported as an error.
_MIN_BINDER_LENGTH = 40
# A maximal binder length below this (but >= _MIN_BINDER_LENGTH) is a warning.
_WARN_BINDER_LENGTH = 60


def _assemble_target_spec_from_dropdowns() -> str:
  """Builds the target specification string from the segment dropdown rows.

  Rows with an unset chain, start or end residue are skipped. Consecutive
  rows on the same chain are joined with ',' and chain groups with '/',
  e.g. 'A10-50,A60-90/B5-40'.

  Returns:
    The specification string ('' when no row is complete).
  """
  # NOTE(review): non-adjacent rows on the same chain produce separate groups
  # (e.g. 'A1-10/B1-10/A20-30'); confirm the downstream format accepts that.
  chain_groups: list[tuple[str, list[str]]] = []
  for chain_dd, start_dd, end_dd in zip(
      range_chain_dropdowns,
      range_residue_dropdowns_start,
      range_residue_dropdowns_end,
  ):
    chain, start_res, end_res = chain_dd.value, start_dd.value, end_dd.value
    if not (chain and start_res and end_res):
      continue
    segment_text = f'{chain}{start_res}-{end_res}'
    if chain_groups and chain_groups[-1][0] == chain:
      chain_groups[-1][1].append(segment_text)
    else:
      chain_groups.append((chain, [segment_text]))
  return '/'.join(','.join(segments) for _, segments in chain_groups)


def _assemble_hotspot_spec_from_dropdowns() -> str:
  """Builds the hotspot string from the hotspot dropdown rows.

  Incomplete rows are skipped; duplicates are removed and the entries are
  sorted lexicographically and comma-joined, e.g. 'A45,B12'.
  """
  hotspots = {
      f'{chain_dd.value}{res_dd.value}'
      for chain_dd, res_dd in zip(
          hotspot_chain_dropdowns, hotspot_residue_dropdowns
      )
      if chain_dd.value and res_dd.value
  }
  return ','.join(sorted(hotspots))


def _gather_spec_and_hotspot_errors(
    spec_str: str, hotspots_str: str
) -> tuple[list[str], Any]:
  """Parses and validates the target specification and hotspot strings.

  Args:
    spec_str: Target specification string.
    hotspots_str: Comma-separated hotspot string; may be empty.

  Returns:
    A tuple (error messages, parsed target segments as returned by
    `parse_spec_for_validation`).
  """
  errors: list[str] = []
  parsed_segments, spec_parsing_errors = parse_spec_for_validation(spec_str)
  errors.extend(spec_parsing_errors)

  target_ranges: dict[str, list[tuple[int, int]]] | None = None
  hs_prep_errors: list[str] = []
  if not spec_parsing_errors:
    errors.extend(
        validate_target_specification(parsed_segments, MINIMAL_SEGMENT_LENGTH)
    )
    target_ranges, hs_prep_errors = (
        _prepare_target_ranges_for_hotspot_validation(parsed_segments)
    )
    errors.extend(hs_prep_errors)

  if hotspots_str:
    parsed_hotspots, hs_parsing_errors = parse_hotspots_for_validation(
        hotspots_str
    )
    errors.extend(hs_parsing_errors)
    # Hotspots are checked against the target ranges, which only exist once
    # the target spec has parsed; otherwise the spec errors already explain
    # the problem and `target_ranges` would still be None.
    if not (spec_parsing_errors or hs_parsing_errors or hs_prep_errors):
      errors.extend(validate_hotspots(parsed_hotspots, target_ranges))
  return errors, parsed_segments


def _render_validation_report(errors: list[str], parsed_segments: Any) -> None:
  """Writes the validation outcome and length summary into `output_area`.

  Sets the global `current_binder_length` to the residue budget minus the
  specified target length when validation passes, or to the string
  'N/A (validation errors)' otherwise.

  Args:
    errors: Messages from `_gather_spec_and_hotspot_errors`.
    parsed_segments: Parsed target segments (each with `start_res` and
      `end_res`); only read when `errors` is empty.
  """
  global current_binder_length
  with output_area:
    output_area.clear_output(wait=True)
    if errors:
      error_html = '<br>'.join(f'- {e}' for e in sorted(set(errors)))
      display(
          HTML(
              "<font color='red'><b>Validation"
              f' Issues:</b><br>{error_html}</font>'
          )
      )
      current_binder_length = 'N/A (validation errors)'
      return

    # Segments with start > end contribute nothing to the length.
    total_spec_length = sum(
        segment.end_res - segment.start_res + 1
        for segment in parsed_segments
        if segment.start_res <= segment.end_res
    )
    current_binder_length = _TOTAL_RESIDUE_BUDGET - total_spec_length

    length_errors: list[str] = []
    if current_binder_length < _MIN_BINDER_LENGTH:
      # A target of exactly this size still leaves _MIN_BINDER_LENGTH residues.
      max_target_length = _TOTAL_RESIDUE_BUDGET - _MIN_BINDER_LENGTH
      length_errors.append(
          f'{current_binder_length} residues left for binder is below'
          ' the minimal binder length limit'
          f' ({_MIN_BINDER_LENGTH} residues).'
      )
      length_errors.append(
          'Reduce size of the target protein selection to at most'
          f' {max_target_length} residues (current size is'
          f' {total_spec_length} residues).'
      )
    if total_spec_length < MINIMAL_SPEC_TARGET_LENGTH:
      length_errors.append(
          f'{total_spec_length} is below the minimal protein specification'
          f' limit ({MINIMAL_SPEC_TARGET_LENGTH} residues).'
      )

    # Only report success when no length limit is violated.
    if not length_errors:
      display(
          HTML(
              "<font color='green'>Target specification and hotspots validated"
              ' successfully.</font>'
          )
      )
    print(f'Specified target length: {total_spec_length} residues.')
    print(f'Maximal binder length: {current_binder_length} residues.')
    if _MIN_BINDER_LENGTH <= current_binder_length < _WARN_BINDER_LENGTH:
      display(
          HTML(
              "<font color='orange'><b>WARNING:</b>"
              f' {current_binder_length} residues for binder may be too'
              ' short.</font>'
          )
      )
    for message in length_errors:
      display(HTML(f"<font color='red'><b>ERROR:</b> {message}</font>"))


def _collect_and_validate_selections(b: widgets.Button | None) -> None:
  """Collects dropdown selections, validates them, and refreshes the UI.

  Rebuilds the global `target_protein_spec_str` and `hotspot_spec_str` from
  the dropdown rows, mirrors them into the read-only text fields, runs the
  specification/hotspot validators and writes the outcome (plus the maximal
  binder length, stored in the global `current_binder_length`) to
  `output_area`.

  Args:
    b: Triggering button, if any; unused.
  """
  global target_protein_spec_str, hotspot_spec_str
  del b  # Unused; kept so the function fits the widget-callback signature.

  target_protein_spec_str = _assemble_target_spec_from_dropdowns()
  hotspot_spec_str = _assemble_hotspot_spec_from_dropdowns()
  target_spec_text_widget.value = target_protein_spec_str
  hotspot_spec_text_widget.value = hotspot_spec_str

  errors, parsed_segments = _gather_spec_and_hotspot_errors(
      target_protein_spec_str, hotspot_spec_str
  )
  _render_validation_report(errors, parsed_segments)


# --- Button Click Handlers ---
def _add_range_selection_clicked(b: widgets.Button) -> None:
  """'Add Target Segment' handler: appends another segment dropdown row."""
  del b  # The clicked button itself is not needed.
  create_chain_residue_range_dropdowns()


def _add_hotspot_selection_clicked(b: widgets.Button) -> None:
  """'Add Hotspot' handler: appends another hotspot dropdown row."""
  del b  # The clicked button itself is not needed.
  create_hotspot_dropdowns()


def _remove_last_range_selection(b: widgets.Button) -> None:
  """'Remove Last Segment' handler.

  Drops the last row from `selection_widgets_ranges_vbox`, pops the matching
  chain/start/end dropdowns from their tracking lists and re-validates. Does
  nothing when there is no row to remove.

  Args:
    b: The clicked button; unused.
  """
  rows_box = selection_widgets_ranges_vbox
  if not rows_box or not rows_box.children:
    return

  rows_box.children = rows_box.children[:-1]
  # Keep the parallel tracking lists in step with the displayed rows.
  for tracked_dropdowns in (
      range_chain_dropdowns,
      range_residue_dropdowns_start,
      range_residue_dropdowns_end,
  ):
    if tracked_dropdowns:
      tracked_dropdowns.pop()
  _collect_and_validate_selections(None)


def _remove_last_hotspot_selection(b: widgets.Button) -> None:
  """'Remove Last Hotspot' handler.

  Drops the last row from `selection_widgets_hotspots_vbox`, pops the matching
  chain/residue dropdowns from their tracking lists and re-validates. Does
  nothing when there is no row to remove.

  Args:
    b: The clicked button; unused.
  """
  rows_box = selection_widgets_hotspots_vbox
  if not rows_box or not rows_box.children:
    return

  rows_box.children = rows_box.children[:-1]
  # Keep the parallel tracking lists in step with the displayed rows.
  for tracked_dropdowns in (hotspot_chain_dropdowns, hotspot_residue_dropdowns):
    if tracked_dropdowns:
      tracked_dropdowns.pop()
  _collect_and_validate_selections(None)


def _create_separator_html(
    thin: bool = False, transparent: bool = False
) -> widgets.HTML:
  """Returns an `<hr>` HTML widget used to separate UI sections.

  Args:
    thin: Use a 1px rule with a 5px vertical margin instead of a 2px rule with
      a 20px margin.
    transparent: Draw an invisible 1px rule (spacing only). This overrides the
      border chosen by `thin`, but `thin` still selects the margin.

  Returns:
    An `ipywidgets.HTML` widget containing the styled `<hr>`.
  """
  if transparent:
    border_style = '1px solid transparent'
  elif thin:
    border_style = '1px solid #ccc'
  else:
    border_style = '2px solid #ccc'

  if thin:
    margin = '5px 0'
  else:
    margin = '20px 0'

  hr_markup = f'<hr style="border-top: {border_style}; margin: {margin};">'
  return widgets.HTML(hr_markup)


# --- Initialize UI Elements ---
# Builds and displays the Cell 4 UI: read-only spec/hotspot text fields, the
# validation output area, add/remove buttons for segments and hotspots, and
# the structure viewer area.
display(
    HTML(
        '<link rel="stylesheet"'
        ' href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">'
    )
)

target_spec_label = widgets.Label(
    value='Target Specification:', layout=widgets.Layout(width='15%')
)
target_spec_text_widget = widgets.Text(
    value=target_protein_spec_str,
    disabled=True,
    layout=widgets.Layout(width='80%'),
)
display(widgets.HBox([target_spec_label, target_spec_text_widget]))

hotspot_spec_label = widgets.Label(
    value='Hotspots:', layout=widgets.Layout(width='15%')
)
hotspot_spec_text_widget = widgets.Text(
    value=hotspot_spec_str, disabled=True, layout=widgets.Layout(width='80%')
)
display(widgets.HBox([hotspot_spec_label, hotspot_spec_text_widget]))

display(_create_separator_html(transparent=True))

# output_area is already defined globally
display(output_area)
print(
    '\nCopy text from "Target Specification" and "Hotspots" into the submission'
    ' form.'
)
print(
    'Use "Add Target Segment" button to add more segments to the protein'
    ' specification'
)
print('Use "Add Hotspot" button to add more hotspot residues')
print(
    'Use "Remove Last Segment" and "Remove Last Hotspot" buttons to remove'
    ' corresponding elements'
)

display(_create_separator_html())

add_range_btn = widgets.Button(
    description='Add Target Segment', icon='plus-square'
)
add_range_btn.layout.width = '400px'
add_range_btn.on_click(_add_range_selection_clicked)
remove_range_btn = widgets.Button(
    description='Remove Last Segment', icon='trash'
)
remove_range_btn.layout.width = '400px'
remove_range_btn.on_click(_remove_last_range_selection)
display(widgets.HBox([add_range_btn, remove_range_btn]))

# selection_widgets_ranges_vbox is already defined globally
display(selection_widgets_ranges_vbox)

display(_create_separator_html(thin=True))

add_hotspot_btn = widgets.Button(description='Add Hotspot', icon='plus-square')
add_hotspot_btn.layout.width = '400px'
add_hotspot_btn.on_click(_add_hotspot_selection_clicked)
remove_hotspot_btn = widgets.Button(
    description='Remove Last Hotspot', icon='trash'
)
remove_hotspot_btn.layout.width = '400px'
remove_hotspot_btn.on_click(_remove_last_hotspot_selection)
display(widgets.HBox([add_hotspot_btn, remove_hotspot_btn]))

# selection_widgets_hotspots_vbox is already defined globally
display(selection_widgets_hotspots_vbox)

display(_create_separator_html())

display_spec_btn = widgets.Button(
    description='Display Specification on Structure', icon='eye'
)
display_spec_btn.layout.width = '400px'
display_spec_btn.on_click(_display_specification_button_clicked)
display(widgets.HBox([display_spec_btn]))

display(_create_separator_html(transparent=True))
print(
    'Use "Display Specification on Structure" to visualize the target and'
    ' hotspots.'
)
print(
    'Regions that are not specified are displayed semi-transparently, while'
    ' specified segments are solid'
)
# Real newlines: a doubled backslash here would print a literal "\n\n".
print('Hotspot residues are displayed as yellow sticks\n\n')

# mol_widget_area is already defined globally
display(mol_widget_area)

# --- Initial UI Setup ---
if available_chains:
  create_chain_residue_range_dropdowns()
  # Re-collects and validates the selections before drawing the structure.
  _display_specification_button_clicked(None)
else:
  # Validate first: it initialises the spec text fields and globals (including
  # current_binder_length, printed by Cell 5) but also rewrites output_area,
  # so the guidance message below must be displayed after it, not before.
  _collect_and_validate_selections(None)
  with output_area:  # output_area is guaranteed defined
    output_area.clear_output(wait=True)
    display(
        HTML(
            "<font color='orange'>Please provide a valid PDB ID in Cell 3 and"
            ' run it.</font>'
        )
    )
  with mol_widget_area:  # mol_widget_area is guaranteed defined
    mol_widget_area.clear_output(wait=True)
    print(
        'Viewer will appear here once a PDB is loaded and specification is'
        ' displayed.'
    )

In [None]:
# @title 5. (Optional) Run this cell to print final target specification and hotspots

# Prints the values to copy into the "Submit Design Request" form: the PDB ID
# from Cell 3 and the spec, hotspots and maximal binder length from Cell 4.
print(f'PDB id: {current_pdb_id or "Not set"}')
print(f'Target Protein Specification: {target_spec_text_widget.value}')
print(f'Hotspots: {hotspot_spec_text_widget.value}')
print(f'Calculated Maximal Binder Length: {current_binder_length} residues')