<a href="https://colab.research.google.com/github/eoinleen/basic_plotting/blob/main/basic-x-y_scatter_with_batch_mode.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Interactive Scatter Plot for Google Colab with File Upload
# Upload .txt file and customize plotting parameters
# NOW WITH BATCH MODE!

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from google.colab import files
import io
import os
from pathlib import Path

# ========================================
# MODE SELECTION
# ========================================
BATCH_MODE = False  # True: plot every .txt file found in BATCH_DIRECTORY; False: upload and plot a single file

# Only read when BATCH_MODE = True: the directory that is searched for *.txt files.
# If this path does not exist when the cell runs, the script offers to upload
# files instead and stores them under /content/uploaded_batch.
BATCH_DIRECTORY = "/content/scatter_data"  # Change this to your directory path

# ========================================
# PLOTTING PARAMETERS - MODIFY HERE
# ========================================

# Data parameters
# Lines are counted from 0 after surrounding whitespace is stripped from the
# whole file; lines before DATA_START_ROW and lines starting with '#' are ignored.
DATA_START_ROW = 4  # First line treated as data (0-indexed, so row 5 in your file)
X_COLUMN = 3        # Whitespace-separated field used for the X-axis (0-indexed, so column 4 = pTM)
Y_COLUMN = 4        # Whitespace-separated field used for the Y-axis (0-indexed, so column 5 = ipTM)

# Axis labels and title used when the plot function is called without explicit
# labels (single mode). In batch mode the labels are derived from the
# '<x>_vs_<y>.txt' filename instead.
X_LABEL = "pTM"
Y_LABEL = "ipTM"
PLOT_TITLE = "ColabFold Results: pTM vs ipTM"

# Font settings (points)
FONT_SIZE = 12        # Base font size; also the axis-label size on batch-grid subplots (titles there use FONT_SIZE + 2)
TITLE_FONT_SIZE = 16  # Title size of standalone plots
LABEL_FONT_SIZE = 14  # Axis-label size of standalone plots
TICK_FONT_SIZE = 10   # Tick-label size

# Grid settings
SHOW_GRID = True            # Draw major and minor grid lines
GRID_ALPHA = 0.3            # Opacity of major grid lines; minor grid lines use half of this
MAJOR_GRID_LINEWIDTH = 0.8
MINOR_GRID_LINEWIDTH = 0.4

# Marker settings
MARKER_STYLE = 'o'      # Options: 'o', 's', '^', 'v', '<', '>', 'D', 'p', '*', etc.
MARKER_SIZE = 3         # Marker size; squared before being passed to scatter() as the marker area `s`
MARKER_COLOR = 'blue'   # Color name or hex code (e.g., '#FF5733')
MARKER_ALPHA = 0.6      # Transparency (0.0 to 1.0)
MARKER_EDGE_COLOR = 'black'  # Outline color of each marker
MARKER_EDGE_WIDTH = 0.3      # Outline line width of each marker

# Plot appearance
FIGURE_SIZE = (10, 8)   # Width, Height in inches (per subplot cell in the combined batch grid)
DPI = 100               # Resolution of the figure as created; PNGs written in batch mode are always saved at 300 dpi
BACKGROUND_TRANSPARENT = True  # Transparent background for standalone plots and for PNGs saved in batch mode

# Tick settings
MAJOR_TICK_SPACING_X = None  # Set to None for automatic, or specify value (e.g., 0.1)
MINOR_TICK_SPACING_X = None  # Set to None for automatic, or specify value (e.g., 0.05)
MAJOR_TICK_SPACING_Y = None  # Set to None for automatic, or specify value (e.g., 0.1)
MINOR_TICK_SPACING_Y = None  # Set to None for automatic, or specify value (e.g., 0.05)

# Axis limits (set to None for automatic; one side may be left as None)
X_MIN = None
X_MAX = None
Y_MIN = None
Y_MAX = None

# Batch mode specific settings
BATCH_PLOTS_PER_ROW = 2  # Number of columns in the combined batch figure
BATCH_SAVE_INDIVIDUAL = True  # Save each dataset as its own 'plot_<file stem>.png'
BATCH_SAVE_COMBINED = True   # Save the whole grid as 'batch_scatter_plots_combined.png'

# ========================================
# UTILITY FUNCTIONS
# ========================================

def parse_filename_for_labels(filename):
    """Turn a name like 'pae_interaction_vs_binder_aligned_rmsd.txt' into axis labels.

    The directory and extension are dropped, the remaining stem is split on
    '_vs_', and the first two pieces are title-cased with underscores turned
    into spaces. Names without '_vs_' yield the generic pair
    ("X Value", "Y Value").

    Returns:
        A ``(x_label, y_label)`` tuple of strings.
    """
    stem = Path(filename).stem

    # Guard clause: the naming convention is not followed
    if '_vs_' not in stem:
        return "X Value", "Y Value"

    # Only the first two pieces are used; anything after a second '_vs_' is ignored
    x_label, y_label = (
        piece.replace('_', ' ').title() for piece in stem.split('_vs_')[:2]
    )
    return x_label, y_label

def load_and_parse_file(filepath_or_content, is_content=False, *,
                        data_start_row=None, x_column=None, y_column=None):
    """Load whitespace-delimited text and extract one X and one Y column.

    Args:
        filepath_or_content: Path of a text file, or the text itself when
            ``is_content`` is True.
        is_content: Treat ``filepath_or_content`` as already-loaded text.
        data_start_row: 0-based index of the first line treated as data,
            counted after surrounding whitespace is stripped from the text.
            Defaults to the module-level ``DATA_START_ROW``.
        x_column: 0-based index of the field used as X. Defaults to the
            module-level ``X_COLUMN``.
        y_column: 0-based index of the field used as Y. Defaults to the
            module-level ``Y_COLUMN``.

    Returns:
        dict with keys:
            'filename'    - file name, or "uploaded_file" for raw content
            'x_values'    - list of floats
            'y_values'    - list of floats, same length as 'x_values'
            'failed_rows' - non-blank data rows that were dropped because
                            they had too few fields or a non-numeric value
            'total_rows'  - number of candidate lines at or after the start
                            row that are not '#' comments (blank lines are
                            included in this count)
    """
    # Fall back to the notebook-level settings only when no override is given
    if data_start_row is None:
        data_start_row = DATA_START_ROW
    if x_column is None:
        x_column = X_COLUMN
    if y_column is None:
        y_column = Y_COLUMN

    if is_content:
        data_string = filepath_or_content
        filename = "uploaded_file"
    else:
        # Explicit encoding so batch files are read the same way as uploads,
        # which are decoded as UTF-8 by the caller
        with open(filepath_or_content, 'r', encoding='utf-8') as f:
            data_string = f.read()
        filename = Path(filepath_or_content).name

    lines = data_string.strip().split('\n')

    # Drop header lines (by position) and comment lines (by '#' prefix)
    data_lines = [
        line.strip()
        for i, line in enumerate(lines)
        if i >= data_start_row and not line.strip().startswith('#')
    ]

    fields_needed = max(x_column, y_column) + 1
    x_values = []
    y_values = []
    failed_rows = 0

    for line in data_lines:
        if not line:  # Skip empty lines
            continue

        # str.split() with no argument splits on any run of spaces or tabs
        parts = line.split()

        if len(parts) < fields_needed:
            # Previously dropped without a trace; count it so the caller can report it
            failed_rows += 1
            continue

        try:
            x_val = float(parts[x_column])
            y_val = float(parts[y_column])
        except (ValueError, IndexError):
            failed_rows += 1
            continue

        # Append only after both conversions succeed so the lists stay aligned
        x_values.append(x_val)
        y_values.append(y_val)

    return {
        'filename': filename,
        'x_values': x_values,
        'y_values': y_values,
        'failed_rows': failed_rows,
        'total_rows': len(data_lines)
    }

def _apply_tick_spacing(axis, major_spacing, minor_spacing):
    """Place ticks of one matplotlib Axis at multiples of the given spacings (None = leave as is)."""
    # Local import keeps this block self-contained
    from matplotlib.ticker import MultipleLocator

    if major_spacing is not None:
        axis.set_major_locator(MultipleLocator(major_spacing))
    if minor_spacing is not None:
        axis.set_minor_locator(MultipleLocator(minor_spacing))


def create_single_plot(data_dict, ax=None, x_label=None, y_label=None, title=None):
    """Draw one scatter plot, either standalone or on an existing subplot axis.

    Args:
        data_dict: Mapping with 'x_values', 'y_values' and 'filename' keys.
        ax: Axes to draw on. When None, a new figure of FIGURE_SIZE / DPI is
            created and the plot is treated as standalone (larger fonts,
            optional transparent background, tight layout).
        x_label: X-axis label; falls back to X_LABEL when falsy.
        y_label: Y-axis label; falls back to Y_LABEL when falsy.
        title: Plot title; falls back to PLOT_TITLE when falsy.

    Returns:
        The axes that were drawn on, or None when ``data_dict`` holds no points.
    """
    x_values = data_dict['x_values']
    y_values = data_dict['y_values']

    if len(x_values) == 0:
        print(f"❌ No valid data points in {data_dict['filename']}")
        return None

    # Use provided axis or create new figure
    standalone = ax is None
    fig = None
    if standalone:
        fig, ax = plt.subplots(figsize=FIGURE_SIZE, dpi=DPI)

        # Kept for later figures created in this session
        plt.rcParams.update({
            'font.size': FONT_SIZE,
            'axes.titlesize': TITLE_FONT_SIZE,
            'axes.labelsize': LABEL_FONT_SIZE,
            'xtick.labelsize': TICK_FONT_SIZE,
            'ytick.labelsize': TICK_FONT_SIZE
        })
        # The axes above already exist, so the rcParams change is not
        # guaranteed to reach their tick labels; set the size explicitly.
        ax.tick_params(axis='both', which='major', labelsize=TICK_FONT_SIZE)

    # Create scatter plot (scatter's `s` is an area, hence the square)
    ax.scatter(x_values, y_values,
               s=MARKER_SIZE**2,
               c=MARKER_COLOR,
               marker=MARKER_STYLE,
               alpha=MARKER_ALPHA,
               edgecolors=MARKER_EDGE_COLOR,
               linewidths=MARKER_EDGE_WIDTH)

    # Set labels and title (subplots use the smaller base font)
    ax.set_xlabel(x_label or X_LABEL, fontsize=LABEL_FONT_SIZE if standalone else FONT_SIZE)
    ax.set_ylabel(y_label or Y_LABEL, fontsize=LABEL_FONT_SIZE if standalone else FONT_SIZE)
    ax.set_title(title or PLOT_TITLE, fontsize=TITLE_FONT_SIZE if standalone else FONT_SIZE+2)

    # Set axis limits if specified (a None side stays automatic)
    if X_MIN is not None or X_MAX is not None:
        ax.set_xlim(X_MIN, X_MAX)
    if Y_MIN is not None or Y_MAX is not None:
        ax.set_ylim(Y_MIN, Y_MAX)

    # Configure grid
    if SHOW_GRID:
        ax.grid(True, alpha=GRID_ALPHA, linewidth=MAJOR_GRID_LINEWIDTH)
        ax.grid(True, which='minor', alpha=GRID_ALPHA/2, linewidth=MINOR_GRID_LINEWIDTH)

    # Enable automatic minor ticks FIRST: minorticks_on() installs its own
    # minor locator, so calling it after a custom spacing was set would
    # silently discard MINOR_TICK_SPACING_X / MINOR_TICK_SPACING_Y.
    ax.minorticks_on()

    # Custom spacing, where requested, overrides the automatic locators.
    # Locators place ticks at exact multiples inside the current limits.
    _apply_tick_spacing(ax.xaxis, MAJOR_TICK_SPACING_X, MINOR_TICK_SPACING_X)
    _apply_tick_spacing(ax.yaxis, MAJOR_TICK_SPACING_Y, MINOR_TICK_SPACING_Y)

    if standalone:
        # Set transparent background if requested
        if BACKGROUND_TRANSPARENT:
            ax.patch.set_alpha(0.0)
            fig.patch.set_alpha(0.0)
        fig.tight_layout()

    return ax

def print_statistics(data_dict):
    """Print summary statistics for one dataset.

    Prints the number of points, the number of skipped rows (if any), and
    min / max / mean / std for X and Y, followed by the Pearson correlation
    coefficient. Nothing is printed when either value list is empty.

    Args:
        data_dict: Mapping with 'filename', 'x_values', 'y_values' and
            'failed_rows' keys.
    """
    x_vals = data_dict['x_values']
    y_vals = data_dict['y_values']

    if not (x_vals and y_vals):
        return

    # Derive labels from filename if possible
    x_label, y_label = parse_filename_for_labels(data_dict['filename'])

    print(f"\n📈 STATISTICS: {data_dict['filename']}")
    print("=" * 50)
    print(f"✅ Data points: {len(x_vals)}")
    if data_dict['failed_rows'] > 0:
        print(f"⚠️  Skipped rows: {data_dict['failed_rows']}")

    x_std = np.std(x_vals)
    y_std = np.std(y_vals)
    print(f"🔢 {x_label}: min={min(x_vals):.3f}, max={max(x_vals):.3f}, mean={np.mean(x_vals):.3f}, std={x_std:.3f}")
    print(f"🔢 {y_label}: min={min(y_vals):.3f}, max={max(y_vals):.3f}, mean={np.mean(y_vals):.3f}, std={y_std:.3f}")

    # Correlation coefficient: undefined for a single point or a constant
    # series (np.corrcoef would warn and yield nan), so report that instead.
    if len(x_vals) >= 2 and x_std > 0 and y_std > 0:
        correlation = np.corrcoef(x_vals, y_vals)[0, 1]
        print(f"📊 Correlation coefficient: {correlation:.3f}")
    else:
        print("📊 Correlation coefficient: n/a (needs at least 2 points and non-constant data)")

# ========================================
# MAIN EXECUTION LOGIC
# ========================================

if not BATCH_MODE:
    # ========================================
    # SINGLE FILE MODE (ORIGINAL BEHAVIOR)
    # ========================================
    # Upload one text file, preview it, plot it and print its statistics.

    print("📁 SINGLE FILE MODE - UPLOAD YOUR .TXT DATA FILE")
    print("=" * 60)
    print("Click the button below to upload your .txt file containing the data.")
    print("Expected format: Space or tab-separated columns")
    print("Comments (lines starting with #) will be automatically skipped.")
    print()

    # Upload file (used below as a mapping of filename -> raw bytes)
    uploaded = files.upload()

    # Process the uploaded file
    uploaded_data = None
    filename = None

    if len(uploaded) > 1:
        # Make the "first file only" rule visible instead of silently ignoring the rest
        print(f"⚠️  {len(uploaded)} files uploaded - single file mode only plots the first one.")

    for filename, content in uploaded.items():
        print(f"✅ Successfully uploaded: {filename}")

        try:
            uploaded_data = content.decode('utf-8')
            print(f"📊 File size: {len(content)} bytes")

            # Count lines for user info
            lines = uploaded_data.strip().split('\n')
            total_lines = len(lines)
            comment_lines = sum(1 for line in lines if line.strip().startswith('#'))
            data_lines = total_lines - comment_lines

            print(f"📋 Total lines: {total_lines}")
            print(f"💬 Comment lines: {comment_lines}")
            print(f"📈 Potential data lines: {data_lines}")

            # Show first few lines as preview
            print("\n📖 File preview (first 10 lines):")
            print("-" * 40)
            for i, line in enumerate(lines[:10]):
                print(f"{i+1:2d}: {line}")
            if len(lines) > 10:
                print(f"... and {len(lines)-10} more lines")
            print("-" * 40)

        except UnicodeDecodeError:
            print(f"❌ Error: Could not decode {filename}. Please ensure it's a text file.")
            uploaded_data = None

        # Single mode handles exactly one file
        break

    if uploaded_data is not None:
        print("\n🎨 CREATING PLOT...")
        print("=" * 30)

        data_dict = load_and_parse_file(uploaded_data, is_content=True)
        # Report the real upload name (and let the statistics derive labels
        # from it) instead of the generic placeholder used for raw content
        data_dict['filename'] = filename

        if len(data_dict['x_values']) > 0:
            create_single_plot(data_dict)
            plt.show()

            print_statistics(data_dict)
            print("\n✅ Plot created successfully!")
        else:
            print("❌ No valid data points found. Please check your parameters.")
    else:
        print("\n⏸️  No file uploaded.")

else:
    # ========================================
    # BATCH MODE - PROCESS MULTIPLE FILES
    # ========================================
    # Plot every .txt file of BATCH_DIRECTORY as a combined grid and/or as
    # individual PNG files.

    print("📁 BATCH MODE - PROCESSING MULTIPLE FILES")
    print("=" * 60)

    # Check if directory exists, if not, offer to upload files
    if not os.path.exists(BATCH_DIRECTORY):
        print(f"⚠️  Directory '{BATCH_DIRECTORY}' not found.")
        print("\nOptions:")
        print("1. Mount Google Drive and update BATCH_DIRECTORY path")
        print("2. Upload multiple files now")
        print()

        # Tolerate stray whitespace and a spelled-out "yes"
        response = input("Upload files now? (y/n): ").strip().lower()

        if response in ('y', 'yes'):
            print("\n📤 Upload all .txt files you want to process:")
            uploaded = files.upload()

            # Create temporary directory for uploaded files
            BATCH_DIRECTORY = "/content/uploaded_batch"
            os.makedirs(BATCH_DIRECTORY, exist_ok=True)

            # Save uploaded files
            for filename, content in uploaded.items():
                filepath = os.path.join(BATCH_DIRECTORY, filename)
                with open(filepath, 'wb') as f:
                    f.write(content)
                print(f"✅ Saved: {filename}")
        else:
            print("\n💡 To use batch mode:")
            print("   1. Mount Google Drive: from google.colab import drive; drive.mount('/content/drive')")
            print("   2. Set BATCH_DIRECTORY to your folder path")
            print("   3. Run this cell again")
            raise SystemExit("Batch mode cancelled")

    # Find all .txt files in directory (sorted for a reproducible plot order)
    txt_files = sorted(Path(BATCH_DIRECTORY).glob("*.txt"))

    if len(txt_files) == 0:
        print(f"❌ No .txt files found in {BATCH_DIRECTORY}")
        raise SystemExit("No files to process")

    print(f"\n✅ Found {len(txt_files)} .txt files:")
    for f in txt_files:
        print(f"   📄 {f.name}")
    print()

    # Load all data, keeping only files that yielded at least one point
    all_data = []
    for filepath in txt_files:
        print(f"Loading: {filepath.name}...")
        data_dict = load_and_parse_file(str(filepath))
        if len(data_dict['x_values']) > 0:
            all_data.append(data_dict)
            print(f"   ✅ {len(data_dict['x_values'])} data points loaded")
        else:
            print("   ⚠️  No valid data found, skipping")

    print(f"\n✅ Successfully loaded {len(all_data)} files")

    if len(all_data) == 0:
        print("❌ No valid data to plot!")
        raise SystemExit("No valid data")

    # Calculate grid dimensions (at least one column, so a 0 setting cannot divide by zero)
    n_plots = len(all_data)
    n_cols = max(1, int(BATCH_PLOTS_PER_ROW))
    n_rows = int(np.ceil(n_plots / n_cols))

    # Create combined figure
    if BATCH_SAVE_COMBINED:
        print(f"\n🎨 Creating combined plot grid ({n_rows}x{n_cols})...")
        # squeeze=False always returns a 2-D array of axes. Wrapping the
        # result in a list when n_plots == 1 broke as soon as the grid had
        # more than one column, because the "single axis" was then an array.
        fig, axes = plt.subplots(n_rows, n_cols,
                                 figsize=(FIGURE_SIZE[0]*n_cols, FIGURE_SIZE[1]*n_rows),
                                 dpi=DPI, squeeze=False)
        axes = axes.flatten()

        # Create each subplot
        for idx, data_dict in enumerate(all_data):
            x_label, y_label = parse_filename_for_labels(data_dict['filename'])
            title = f"{x_label} vs {y_label}"

            create_single_plot(data_dict, ax=axes[idx], x_label=x_label, y_label=y_label, title=title)
            print_statistics(data_dict)

        # Hide any unused subplots
        for idx in range(n_plots, len(axes)):
            axes[idx].axis('off')

        fig.tight_layout()

        # Save combined figure before showing it, so the file never depends
        # on what the display backend does to the figure in plt.show()
        combined_filename = "batch_scatter_plots_combined.png"
        fig.savefig(combined_filename, dpi=300, bbox_inches='tight', transparent=BACKGROUND_TRANSPARENT)

        plt.show()
        print(f"\n💾 Saved combined plot: {combined_filename}")
        plt.close(fig)

    # Create and save individual plots
    if BATCH_SAVE_INDIVIDUAL:
        print("\n🎨 Creating individual plots...")

        for data_dict in all_data:
            x_label, y_label = parse_filename_for_labels(data_dict['filename'])
            title = f"{x_label} vs {y_label}"

            # create_single_plot opens its own figure when no axis is given;
            # creating another one here first would leave an empty figure open.
            single_ax = create_single_plot(data_dict, x_label=x_label, y_label=y_label, title=title)
            if single_ax is None:
                continue
            single_fig = single_ax.figure

            # Statistics are otherwise only printed with the combined grid
            if not BATCH_SAVE_COMBINED:
                print_statistics(data_dict)

            # Save individual plot
            output_filename = f"plot_{Path(data_dict['filename']).stem}.png"
            single_fig.savefig(output_filename, dpi=300, bbox_inches='tight', transparent=BACKGROUND_TRANSPARENT)
            print(f"   💾 Saved: {output_filename}")
            plt.close(single_fig)

    print("\n" + "="*60)
    print("✅ BATCH PROCESSING COMPLETE!")
    print("="*60)
    print(f"📊 Processed {len(all_data)} files successfully")
    if BATCH_SAVE_COMBINED:
        print("📁 Combined plot saved")
    if BATCH_SAVE_INDIVIDUAL:
        print(f"📁 {len(all_data)} individual plots saved")

# ========================================
# QUICK REFERENCE
# ========================================

# Emit the customization cheat sheet with a single print; the joined text is
# identical to printing each line separately.
print("\n".join([
    "\n" + "=" * 60,
    "🎛️  QUICK CUSTOMIZATION REFERENCE",
    "=" * 60,
    "🔄 MODE: Set BATCH_MODE = True for multiple files, False for single file",
    "📍 Column indices (0-based): Column 1=0, Column 2=1, Column 3=2, etc.",
    "🎯 Common marker styles: 'o' (circle), 's' (square), '^' (triangle up),",
    "                        'v' (triangle down), 'D' (diamond), '*' (star)",
    "🎨 Common colors: 'red', 'blue', 'green', 'orange', 'purple', 'brown',",
    "                  'pink', 'gray', 'olive', 'cyan', or hex codes like '#FF5733'",
    "📏 Tick spacing: Set to specific values (e.g., 0.1) or None for automatic",
    "🔄 To modify: Change values in the 'PLOTTING PARAMETERS' section above",
    "=" * 60,
]))