In [None]:
import torch # Still needed for torch_geometric utilities potentially
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx # Utility to convert if needed (though less direct here)
from torch_geometric.data import Data # For potential type hinting or future use
import traceback
import os
import random
from typing import Optional, Tuple, Dict, List

SEED = 42  # Shared seed so graph layouts can be reproduced between runs

# Seed every RNG this script may touch (PyTorch, NumPy, stdlib `random`).
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
# No device setting needed unless using GPU-accelerated layout (unlikely here)

VALID_AA = 'ARNDCQEGHILKMFPSTWYV'  # The 20 one-letter amino-acid codes accepted as graph nodes
AMINO_ACIDS = VALID_AA + '-'       # Same alphabet plus the '-' padding character

EXPECTED_SEQ_LEN = 33                      # Residue window length; the K of interest sits in the middle
CENTRAL_K_POS_ABS = EXPECTED_SEQ_LEN // 2  # 0-based index of the central K (== 16)

# Node-size styling shared by every subplot (matplotlib marker areas).
_BASE_NODE_SIZE = 200       # size used when no feature-based sizing applies
_K_NODE_SIZE_BONUS = 200    # central K is drawn at least this much above the base size
_MIN_NODE_SIZE = 100        # lower bound when scaling by a feature
_MAX_NODE_SIZE = 600        # upper bound when scaling by a feature

# Node / edge colours; anything touching the central K is highlighted.
_DEFAULT_NODE_COLOR = '#1f77b4'  # blue
_K_NODE_COLOR = '#ff7f0e'        # orange
_DEFAULT_EDGE_WIDTH = 0.6
_K_EDGE_WIDTH = 2.0
_DEFAULT_EDGE_COLOR = 'darkgrey'
_K_EDGE_COLOR = '#d62728'        # red

_DEFAULT_CUTOFFS = (8.0, 10.0)
_SIZE_FEATURES = ('degree', 'plddt', 'sasa')


def _parse_numeric_array(raw) -> np.ndarray:
    """Parse a stored Python list literal (e.g. "[1.0, -1.0]") into a float32 array.

    ``ast.literal_eval`` accepts only literals, so a malformed or hostile CSV
    cell cannot execute code the way ``eval`` could. Non-string values are
    passed through ``str`` first, mirroring the previous ``eval(str(...))``.
    """
    import ast  # local import keeps this helper self-contained

    return np.array(ast.literal_eval(str(raw)), dtype=np.float32)


def _load_residue_feature(
    sample_row: pd.Series,
    column_names: Tuple[str, ...],
    display_name: str,
    valid_pos_indices: List[int],
) -> Optional[np.ndarray]:
    """Return a per-residue feature restricted to the valid residues, or None.

    The first name in ``column_names`` that is present and non-NaN in
    ``sample_row`` is parsed. NaNs in the array become 0 (``np.nan_to_num``).
    A warning is printed and None returned when no usable column exists, or
    when the parsed array is not 1-D of length EXPECTED_SEQ_LEN.
    """
    column = next(
        (c for c in column_names if c in sample_row and not pd.isna(sample_row[c])),
        None,
    )
    if column is None:
        print(f"Warning: {display_name} node size requested but '{column_names[0]}' data missing/invalid.")
        return None

    full = _parse_numeric_array(sample_row[column])
    # Require exactly one value per sequence position; this also rejects 0-d/2-D input.
    if full.shape != (EXPECTED_SEQ_LEN,):
        print(f"Warning: {display_name} array shape {full.shape} doesn't match expected ({EXPECTED_SEQ_LEN},). Cannot use for sizing.")
        return None
    return np.nan_to_num(full[valid_pos_indices])


def _build_cutoff_graphs(valid_distance_map: np.ndarray, cutoffs: List[float]) -> Dict[float, "nx.Graph"]:
    """Build one undirected graph per cutoff from an (n, n) residue distance matrix.

    Nodes are 0..n-1. An edge joins i != j when 0 < distance < cutoff.
    """
    num_nodes = valid_distance_map.shape[0]
    graphs = {}
    for cutoff in cutoffs:
        # Non-positive distances are excluded. Padding was mapped to inf by the caller,
        # so it can never satisfy `< cutoff`.
        adjacency = (valid_distance_map < cutoff) & (valid_distance_map > 0)
        np.fill_diagonal(adjacency, False)
        graph = nx.Graph()
        graph.add_nodes_from(range(num_nodes))
        # argwhere yields both (i, j) and (j, i); nx.Graph stores each pair only once.
        graph.add_edges_from(np.argwhere(adjacency).tolist())
        graphs[cutoff] = graph
    return graphs


def _compute_shared_layout(graphs: Dict[float, "nx.Graph"], layout_seed: Optional[int]) -> Dict[int, np.ndarray]:
    """Compute a single node layout that every subplot reuses, so panels are comparable.

    The layout comes from the graph with the most (undirected) edges; the first
    such graph wins ties. Kamada-Kawai is tried first, and spring_layout seeded
    with ``layout_seed`` is the fallback. ``graphs`` must be non-empty.
    """
    layout_graph = max(graphs.values(), key=lambda g: g.number_of_edges())
    print(f"Calculating layout based on graph with {layout_graph.number_of_edges()} edges...")
    try:
        # Kamada-Kawai often gives good 'physical'-looking layouts.
        layout = nx.kamada_kawai_layout(layout_graph)
        if layout is None or len(layout) != layout_graph.number_of_nodes():
            print("Warning: kamada_kawai_layout failed or returned wrong size, using spring_layout.")
            layout = nx.spring_layout(layout_graph, seed=layout_seed)
    except Exception as e_layout:  # broad on purpose: any layout failure falls back to spring_layout
        print(f"Layout calculation failed ({e_layout}), using spring_layout.")
        layout = nx.spring_layout(layout_graph, seed=layout_seed)
    return layout


def _compute_node_sizes(
    graph: "nx.Graph",
    node_size_feature: str,
    node_plddt: Optional[np.ndarray],
    node_sasa: Optional[np.ndarray],
    central_k_idx: int,
) -> np.ndarray:
    """Return one marker size per node, scaled by the requested feature.

    Falls back to a uniform _BASE_NODE_SIZE when the feature is unknown or its
    data is unavailable. The central K is always at least
    _BASE_NODE_SIZE + _K_NODE_SIZE_BONUS.
    """
    span = _MAX_NODE_SIZE - _MIN_NODE_SIZE
    if node_size_feature == 'degree':
        degrees = np.array([deg for _, deg in graph.degree()], dtype=float)
        sizes = _MIN_NODE_SIZE + span * (degrees / max(1.0, degrees.max()))  # avoid /0 when no edges
    elif node_size_feature == 'plddt' and node_plddt is not None:
        # The clip to [0, 1] after dividing by 100 assumes pLDDT is on a 0-100 scale.
        sizes = _MIN_NODE_SIZE + span * np.clip(node_plddt / 100.0, 0, 1)
    elif node_size_feature == 'sasa' and node_sasa is not None:
        sizes = _MIN_NODE_SIZE + span * (node_sasa / max(1e-6, np.max(node_sasa)))  # avoid /0
    else:
        sizes = np.full(graph.number_of_nodes(), _BASE_NODE_SIZE, dtype=float)
    sizes = np.asarray(sizes, dtype=float)
    sizes[central_k_idx] = max(sizes[central_k_idx], _BASE_NODE_SIZE + _K_NODE_SIZE_BONUS)
    return sizes


def visualize_graph_cutoffs(
    sample_row: pd.Series,
    node_size_feature: str = 'degree', # Options: 'degree', 'plddt', 'sasa'
    cutoffs: Optional[List[float]] = None,
    save_path: str = "graph_cutoffs.pdf",
    fig_size: Tuple[int, int] = (12, 6),
    layout_seed: Optional[int] = SEED
    ):
    """
    Generate a side-by-side visualization of one sample's residue graph at several distance cutoffs.

    Residues whose sequence character is in VALID_AA become nodes. Two nodes are
    joined when their distance-map entry is > 0 and < the cutoff. All panels
    share one layout, computed from the panel with the most edges, so node
    positions are comparable across cutoffs. The central K is drawn in orange
    and enlarged, and its edges are red. Problems are reported with print
    rather than raised, and the figure is always closed.

    Args:
        sample_row (pd.Series): A single row from the DataFrame. Must contain
            'sequence' (length EXPECTED_SEQ_LEN with 'K' at CENTRAL_K_POS_ABS)
            and 'distance_map' (a list literal of EXPECTED_SEQ_LEN**2 values,
            where -1 marks padding/missing distances). 'entry', 'pos' and
            'label' are shown in titles when present. 'plDDT' (or 'pLDDT') and
            'sasa' are needed only for the matching node_size_feature.
        node_size_feature (str): Feature mapped to node size: 'degree', 'plddt'
            or 'sasa'. Unknown values or missing feature data fall back to
            uniform node sizes.
        cutoffs (Optional[List[float]]): Distance thresholds to visualize, one
            panel each. Defaults to [8.0, 10.0].
        save_path (str): Path to save the output PDF file.
        fig_size (Tuple[int, int]): Figure size for the plot.
        layout_seed (Optional[int]): Seed for the spring_layout fallback.
            Kamada-Kawai, tried first, takes no seed.
    """
    # Copy into a fresh list: avoids a shared mutable default and accepts any iterable.
    cutoffs = list(_DEFAULT_CUTOFFS if cutoffs is None else cutoffs)

    print(f"\nGenerating visualization for sample: Entry={sample_row.get('entry', 'N/A')}, Pos={sample_row.get('pos', 'N/A')}")
    print(f"Node size based on: {node_size_feature}")
    print(f"Cutoffs: {cutoffs}")

    if not cutoffs:
        print("Error: No cutoffs provided; nothing to visualize.")
        return
    if node_size_feature not in _SIZE_FEATURES:
        print(f"Warning: Unknown node_size_feature '{node_size_feature}'; using uniform node sizes.")

    try:
        # --- 1. Validate the sequence and parse the distance map ---
        sequence = sample_row['sequence']
        if pd.isna(sequence) or len(sequence) != EXPECTED_SEQ_LEN or sequence[CENTRAL_K_POS_ABS] != 'K':
            print("Error: Invalid sequence or central K in sample row.")
            return

        distance_map = _parse_numeric_array(sample_row['distance_map']).reshape(EXPECTED_SEQ_LEN, EXPECTED_SEQ_LEN)
        # Stored -1 (padding/unknown) becomes inf so it never passes a `< cutoff` test.
        distance_map[distance_map == -1] = np.inf

        # --- 2. Keep only real residues and locate the central K among them ---
        valid_pos_indices = [i for i, aa in enumerate(sequence) if aa in VALID_AA]
        if not valid_pos_indices:
            print("Error: No valid residues found.")
            return
        try:
            central_k_idx = valid_pos_indices.index(CENTRAL_K_POS_ABS)
        except ValueError:
            print("Error: Central K not found among valid residues.")
            return

        node_labels = dict(enumerate(sequence[i] for i in valid_pos_indices))

        node_plddt = node_sasa = None
        if node_size_feature == 'plddt':
            node_plddt = _load_residue_feature(sample_row, ('plDDT', 'pLDDT'), 'pLDDT', valid_pos_indices)
        elif node_size_feature == 'sasa':
            node_sasa = _load_residue_feature(sample_row, ('sasa',), 'SASA', valid_pos_indices)

        # --- 3. Build graphs and one shared layout (before any figure exists) ---
        valid_distance_map = distance_map[np.ix_(valid_pos_indices, valid_pos_indices)]
        graphs = _build_cutoff_graphs(valid_distance_map, cutoffs)
        layout = _compute_shared_layout(graphs, layout_seed)

        # --- 4. Draw one panel per cutoff; always release the figure ---
        fig, axes = plt.subplots(1, len(cutoffs), figsize=fig_size, squeeze=False)
        try:
            for ax, cutoff in zip(axes[0], cutoffs):
                graph = graphs[cutoff]

                node_colors = [_DEFAULT_NODE_COLOR] * graph.number_of_nodes()
                node_colors[central_k_idx] = _K_NODE_COLOR
                node_sizes = _compute_node_sizes(graph, node_size_feature, node_plddt, node_sasa, central_k_idx)

                # Take edges from the graph itself so styles line up with the drawing order.
                edges = list(graph.edges())
                touches_k = [central_k_idx in (u, v) for u, v in edges]
                edge_widths = [_K_EDGE_WIDTH if k else _DEFAULT_EDGE_WIDTH for k in touches_k]
                edge_colors = [_K_EDGE_COLOR if k else _DEFAULT_EDGE_COLOR for k in touches_k]

                nx.draw_networkx_nodes(graph, layout, ax=ax, node_color=node_colors, node_size=node_sizes.tolist(), alpha=0.85)
                nx.draw_networkx_edges(graph, layout, ax=ax, edgelist=edges, width=edge_widths, edge_color=edge_colors, alpha=0.6)
                nx.draw_networkx_labels(graph, layout, ax=ax, labels=node_labels, font_size=7, font_weight='normal', font_color='black')

                ax.set_title(f"Cutoff = {cutoff} Å ({len(edges)} edges)")
                ax.axis('off')

            fig.suptitle(f"Graph Connectivity vs. Distance Cutoff\nEntry={sample_row.get('entry', 'N/A')}, Pos={sample_row.get('pos', 'N/A')}, Label={sample_row.get('label', 'N/A')}", fontsize=14, y=1.0)
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])

            try:
                fig.savefig(save_path, format="pdf", bbox_inches='tight', dpi=300)
                print(f"Visualization saved to: {save_path}")
            except Exception as e:
                print(f"Error saving figure to {save_path}: {e}")
        finally:
            plt.close(fig)  # free the figure even if drawing fails

    except KeyError as e:
        print(f"Error: Missing expected column in sample_row: {e}")
    except Exception as e:
        print(f"An unexpected error occurred during visualization: {e}")
        traceback.print_exc()

if __name__ == "__main__":
    # Load the processed dataset and render one sample once per node-sizing mode.
    # --- !!! UPDATE THIS PATH !!! ---
    data_csv_path = "../data/processed_features_fixed_train_contactmap.csv"  # CSV with sequence, distance_map, etc.
    sample_index_to_plot = 69  # 0-based index of the row to visualize
    output_dir = "visualizations"  # Subdirectory for the generated PDFs
    size_features = ('degree', 'plddt', 'sasa')  # Remove entries to skip those plots

    try:
        print(f"Loading data from: {data_csv_path}")
        if not os.path.exists(data_csv_path):
            raise FileNotFoundError(f"Data CSV not found at: {data_csv_path}")

        df_full = pd.read_csv(data_csv_path)
        print(f"Loaded {len(df_full)} total samples.")

        if df_full.empty:
            print("Error: Loaded DataFrame is empty.")
        elif not 0 <= sample_index_to_plot < len(df_full):
            print(f"Error: Sample index {sample_index_to_plot} is out of bounds for DataFrame length {len(df_full)}.")
        else:
            sample_row = df_full.iloc[sample_index_to_plot]
            entry = sample_row.get('entry', f'Index_{sample_index_to_plot}')
            pos = sample_row.get('pos', 'N/A')
            os.makedirs(output_dir, exist_ok=True)

            # One PDF per sizing mode, named after the feature, entry and position.
            for feature in size_features:
                print(f"\nGenerating plot with node size based on '{feature}'...")
                save_name = os.path.join(output_dir, f"graph_cutoffs_{feature}_{entry}_{pos}.pdf")
                visualize_graph_cutoffs(sample_row, node_size_feature=feature, save_path=save_name)

            print("\nVisualization generation complete.")

    except FileNotFoundError as e:
        print(e)
    except Exception as e:
        print(f"An error occurred in the example usage block: {e}")
        traceback.print_exc()

Loading data from: ../data/processed_features_fixed_train_contactmap.csv
Loaded 8853 total samples.

Generating plot with node size based on 'degree'...

Generating visualization for sample: Entry=P00805, Pos=184
Node size based on: degree
Cutoffs: [8.0, 10.0]
Calculating layout based on graph with 438 edges...
Visualization saved to: visualizations/graph_cutoffs_degree_P00805_184.pdf

Generating plot with node size based on 'plddt'...

Generating visualization for sample: Entry=P00805, Pos=184
Node size based on: plddt
Cutoffs: [8.0, 10.0]
Calculating layout based on graph with 438 edges...
Visualization saved to: visualizations/graph_cutoffs_plddt_P00805_184.pdf

Generating plot with node size based on 'sasa'...

Generating visualization for sample: Entry=P00805, Pos=184
Node size based on: sasa
Cutoffs: [8.0, 10.0]
Calculating layout based on graph with 438 edges...
Visualization saved to: visualizations/graph_cutoffs_sasa_P00805_184.pdf

Visualization generation complete.
