In [None]:
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgba
import random
import matplotlib.patches as mpatches
import os

# Module-wide seed: every randomised step below (sample picking, layouts)
# derives from this value so repeated runs produce the same figures.
SEED = 42
np.random.seed(SEED)  # NumPy's legacy global generator
random.seed(SEED)     # Python's stdlib global generator

def visualize_protein_graph(df, sample_idx=None, threshold=8.0, figsize=(12, 10), seed=42):
    """
    Visualize a protein graph from the dataset.

    Each non-padded position of a 33-residue window becomes a node coloured by
    its secondary-structure code ('H' red, 'E' blue, anything else green).
    Pairs of positions whose distance-map entry is not -1 and is strictly below
    ``threshold`` are joined by an edge. Position index 16 (``K_POS``) is always
    drawn larger and in purple.

    Args:
        df: DataFrame with ``sequence``, ``distance_map`` and ``ss`` columns.
            ``distance_map`` is a string that is ``eval``-ed and reshaped to
            33x33; ``-1`` entries are treated as "no distance".
        sample_idx: Positional (``iloc``) index of the sample to visualize. If
            None, one is picked with ``random.randint`` after reseeding.
        threshold: Distance threshold for edge creation (strict ``<``).
        figsize: Size of the figure.
        seed: Seed applied to the global ``random`` and ``numpy.random`` state
            and to the spring layout.

    Returns:
        fig, ax, G: The matplotlib figure, its axes, and the networkx graph.
    """
    # Reseeds the *global* RNGs, so any caller drawing from ``random``
    # afterwards sees a deterministic sequence.
    random.seed(seed)
    np.random.seed(seed)

    # Select a random sample if not specified
    if sample_idx is None:
        sample_idx = random.randint(0, len(df) - 1)

    sample = df.iloc[sample_idx]
    sequence = sample['sequence']

    # '-' marks padding; a set gives O(1) membership tests in the loops below.
    padded = {pos for pos, aa in enumerate(sequence) if aa == '-'}

    # NOTE(security): eval() executes arbitrary code contained in the CSV cell.
    # Only use on trusted data; consider ast.literal_eval once the serialized
    # format is confirmed to be plain list literals.
    dist_map = np.array(eval(sample['distance_map'])).reshape(33, 33)
    ss_string = sample['ss']

    G = nx.Graph()

    K_POS = 16  # Highlighted centre position (0-based; labelled with 17)
    node_colors = []
    node_sizes = []
    node_labels = {}

    # Nodes are added in ascending index order, so node_colors / node_sizes
    # line up with G.nodes when handed to draw_networkx_nodes.
    for i in range(33):
        if i in padded:
            continue

        G.add_node(i)

        ss_type = ss_string[i]
        if ss_type == 'H':  # Helix
            color = 'red'
        elif ss_type == 'E':  # Sheet
            color = 'blue'
        else:  # Loop
            color = 'green'

        if i == K_POS:
            color = to_rgba('purple', alpha=0.9)
            node_sizes.append(500)  # Make K larger
        else:
            color = to_rgba(color, alpha=0.6)
            node_sizes.append(300)

        node_colors.append(color)
        node_labels[i] = f"{sequence[i]}{i+1}"

    # Add edges based on distance threshold (upper triangle only: j > i)
    for i in range(33):
        if i in padded:
            continue
        for j in range(i + 1, 33):
            if j in padded:
                continue
            if dist_map[i, j] != -1 and dist_map[i, j] < threshold:
                G.add_edge(i, j, weight=dist_map[i, j])

    pos = nx.spring_layout(G, seed=seed, k=0.8)

    fig, ax = plt.subplots(figsize=figsize)

    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes, ax=ax)

    # Thicker lines for closer residues. Widths are read from the stored edge
    # weights in G.edges() order, which is the order draw_networkx_edges uses.
    edge_widths = [2.0 * (1.0 / (w + 0.1)) for _, _, w in G.edges(data='weight')]
    nx.draw_networkx_edges(G, pos, width=edge_widths, alpha=0.6, edge_color='gray', ax=ax)

    nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10, font_weight='bold', ax=ax)

    legend_elements = [
        mpatches.Patch(color='red', alpha=0.6, label='Helix (H)'),
        mpatches.Patch(color='blue', alpha=0.6, label='Sheet (E)'),
        mpatches.Patch(color='green', alpha=0.6, label='Loop (L)'),
        mpatches.Patch(color='purple', alpha=0.9, label='K position')
    ]
    ax.legend(handles=legend_elements, loc='upper right')

    # Operate on the returned fig/ax explicitly rather than pyplot's "current" ones.
    ax.set_title(f"Protein Graph Visualization (Sample {sample_idx}, Distance Threshold: {threshold}Å)")
    fig.tight_layout()

    fig.text(0.02, 0.02, f"Nodes: {G.number_of_nodes()}, Edges: {G.number_of_edges()}", fontsize=10)
    ax.axis('off')

    return fig, ax, G

def visualize_multiple_samples(df, num_samples=3, thresholds=(8.0,), save_dir='.'):
    """
    Create visualizations for multiple random samples with different thresholds.

    Samples are drawn without replacement using the global ``random`` module.
    Every (sample, threshold) pair is rendered with ``visualize_protein_graph``
    and written to ``save_dir`` as a 300-dpi PNG; each figure is closed after
    saving, even if saving fails.

    Args:
        df: DataFrame containing protein data
        num_samples: Number of samples to visualize (capped at ``len(df)``)
        thresholds: Iterable of distance thresholds to use
        save_dir: Directory to save figures (created if missing)
    """
    os.makedirs(save_dir, exist_ok=True)

    if num_samples > len(df):
        num_samples = len(df)
        print(f"Warning: Number of samples reduced to {num_samples} (size of dataset)")

    # Drawn before any plotting: visualize_protein_graph reseeds the global RNG.
    sample_indices = random.sample(range(len(df)), num_samples)

    for idx in sample_indices:
        for threshold in thresholds:
            print(f"Creating visualization for sample {idx} with threshold {threshold}Å")
            fig, ax, G = visualize_protein_graph(df, sample_idx=idx, threshold=threshold)
            try:
                filename = f"{save_dir}/protein_graph_sample_{idx}_threshold_{threshold:.1f}.png"
                # Save/close the figure we were handed, not pyplot's "current" one.
                fig.savefig(filename, dpi=300, bbox_inches='tight')
                print(f"  Saved to {filename}")
                print(f"  Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")
            finally:
                plt.close(fig)

    print(f"\nAll visualizations saved to {save_dir}/")

def visualize_k_centered(df, sample_idx=None, max_samples=5, threshold=8.0, save_dir='.', seed=None):
    """
    Create K-centered visualizations that emphasize the local neighborhood around the K residue.

    Only samples whose ``sequence`` has 'K' at 0-based index 16 are drawn; the
    others (including sequences too short to have index 16) are skipped with a
    warning. Nodes shrink with their index distance from position 16, edges
    touching K are drawn thick, and the layout is translated so K sits at the
    origin. Each figure is saved as a 300-dpi PNG and closed.

    Args:
        df: DataFrame containing protein data
        sample_idx: Specific sample to visualize (if None, random samples are chosen)
        max_samples: Maximum number of samples to visualize if sample_idx is None
        threshold: Distance threshold for edge creation (strict ``<``)
        save_dir: Directory to save figures (created if missing)
        seed: Seed for the spring layout; defaults to the module-level ``SEED``
    """
    os.makedirs(save_dir, exist_ok=True)

    if sample_idx is not None:
        sample_indices = [sample_idx]
    else:
        sample_indices = random.sample(range(len(df)), min(max_samples, len(df)))

    layout_seed = SEED if seed is None else seed
    K_POS = 16  # Position of the K residue (0-based)

    for idx in sample_indices:
        print(f"Creating K-centered visualization for sample {idx}")
        sample = df.iloc[idx]
        sequence = sample['sequence']

        # A sequence shorter than K_POS + 1 cannot have K there; skip instead of
        # raising IndexError.
        if len(sequence) <= K_POS or sequence[K_POS] != 'K':
            print(f"  Warning: Sample {idx} doesn't have K at position 16, skipping")
            continue

        padded = {pos for pos, aa in enumerate(sequence) if aa == '-'}

        # NOTE(security): eval() executes arbitrary code from the CSV cell; only
        # use on trusted data.
        dist_map = np.array(eval(sample['distance_map'])).reshape(33, 33)
        ss_string = sample['ss']

        G = nx.Graph()
        node_colors = []
        node_sizes = []
        node_labels = {}

        # Ascending insertion keeps node_colors / node_sizes aligned with G.nodes.
        for i in range(33):
            if i in padded:
                continue

            G.add_node(i)

            ss_type = ss_string[i]
            if ss_type == 'H':  # Helix
                color = 'red'
            elif ss_type == 'E':  # Sheet
                color = 'blue'
            else:  # Loop
                color = 'green'

            if i == K_POS:
                color = to_rgba('purple', alpha=0.9)
                node_sizes.append(700)  # Make K even larger
            else:
                color = to_rgba(color, alpha=0.6)
                # Larger nodes for residues closer (by index) to K; floor at 300.
                node_sizes.append(500 - min(200, abs(i - K_POS) * 25))

            node_colors.append(color)
            node_labels[i] = f"{sequence[i]}{i+1}"

        for i in range(33):
            if i in padded:
                continue
            for j in range(i + 1, 33):  # Upper triangle only: avoid duplicate edges
                if j in padded:
                    continue
                if dist_map[i, j] != -1 and dist_map[i, j] < threshold:
                    G.add_edge(i, j, weight=dist_map[i, j])

        # Widths derived from G.edges() itself, so they are guaranteed to be in
        # the same order draw_networkx_edges iterates. Edges touching K are thick.
        edge_widths = [
            3.0 if K_POS in (u, v) else 1.5 * (1.0 / (w + 0.1))
            for u, v, w in G.edges(data='weight')
        ]

        # Translate the layout so K sits at the origin, then spread out by 1.2.
        pos = nx.spring_layout(G, seed=layout_seed, k=0.9)
        center = pos[K_POS]
        pos = {node: (xy - center) * 1.2 for node, xy in pos.items()}
        pos[K_POS] = np.array([0.0, 0.0])

        fig, ax = plt.subplots(figsize=(12, 10))
        try:
            nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes, ax=ax)
            nx.draw_networkx_edges(G, pos, width=edge_widths, alpha=0.6, edge_color='gray', ax=ax)
            nx.draw_networkx_labels(G, pos, labels=node_labels, font_size=10, font_weight='bold', ax=ax)

            legend_elements = [
                mpatches.Patch(color='red', alpha=0.6, label='Helix (H)'),
                mpatches.Patch(color='blue', alpha=0.6, label='Sheet (E)'),
                mpatches.Patch(color='green', alpha=0.6, label='Loop (L)'),
                mpatches.Patch(color='purple', alpha=0.9, label='K position')
            ]
            ax.legend(handles=legend_elements, loc='upper right')

            ax.set_title(f"K-Centered Protein Graph (Sample {idx}, Threshold: {threshold}Å)")
            fig.tight_layout()

            fig.text(0.02, 0.02, f"Nodes: {G.number_of_nodes()}, Edges: {G.number_of_edges()}", fontsize=10)
            ax.axis('off')

            filename = f"{save_dir}/k_centered_graph_sample_{idx}.png"
            fig.savefig(filename, dpi=300, bbox_inches='tight')
            print(f"  Saved to {filename}")
            print(f"  Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

if __name__ == "__main__":
    import sys

    # Candidate data files, tried in order; the first one that loads is used.
    train_path = "../../../data/train/structure/processed_features_train.csv"
    test_path = "../../../data/test/structure/processed_features_test.csv"
    output_dir = './protein_visualizations'

    data_loaded = False
    for path in [train_path, test_path]:
        if not os.path.exists(path):
            continue
        print(f"Loading data from {path}")
        try:
            df = pd.read_csv(path)
        except Exception as e:  # top-level boundary: report and try the next file
            print(f"Error loading {path}: {e}")
            continue
        print(f"Successfully loaded {len(df)} samples")
        data_loaded = True
        break

    if not data_loaded:
        print("Could not load data files. Please check file paths.")
        # sys.exit is the programmatic exit; the builtin exit() is a site-module
        # helper intended for the interactive shell.
        sys.exit(1)

    # Print some basic information about the dataset
    print("\nDataset information:")
    print(f"Number of samples: {len(df)}")

    if 'label' in df.columns:
        label_counts = df['label'].value_counts()
        print(f"Label distribution: {label_counts.to_dict()}")

    print("\nCreating standard graph visualizations...")
    visualize_multiple_samples(df, num_samples=3, thresholds=[6.0, 8.0, 10.0], save_dir=output_dir)

    print("\nCreating K-centered graph visualizations...")
    visualize_k_centered(df, max_samples=3, threshold=8.0, save_dir=output_dir)

    print("\nAll visualizations complete!")

Loading data from ../data/processed_features_fixed_train_contactmap.csv
Successfully loaded 8853 samples

Dataset information:
Number of samples: 8853
Label distribution: {1: 4592, 0: 4261}

Creating standard graph visualizations...
Creating visualization for sample 1824 with threshold 6.0Å
  Saved to ./protein_visualizations/protein_graph_sample_1824_threshold_6.0.png
  Graph has 27 nodes and 54 edges
Creating visualization for sample 1824 with threshold 8.0Å
  Saved to ./protein_visualizations/protein_graph_sample_1824_threshold_8.0.png
  Graph has 27 nodes and 91 edges
Creating visualization for sample 1824 with threshold 10.0Å
  Saved to ./protein_visualizations/protein_graph_sample_1824_threshold_10.0.png
  Graph has 27 nodes and 122 edges
Creating visualization for sample 409 with threshold 6.0Å
  Saved to ./protein_visualizations/protein_graph_sample_409_threshold_6.0.png
  Graph has 33 nodes and 73 edges
Creating visualization for sample 409 with threshold 8.0Å
  Saved to ./pro