<a href="https://colab.research.google.com/github/eoinleen/Biophysics-general/blob/main/20250130_dickin_abbout_Seq_analysis_RFdiff_v8_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
"""
Sequence Analysis Pipeline with Cluster Images
Author: [Your Name]
Date: [Current Date]

Description:
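This script extracts the designed chain from each record of a FASTA file,
one-hot encodes the sequences, computes a pairwise distance matrix, and saves
a similarity heatmap, a hierarchical-clustering dendrogram, and one dendrogram
image per dendrogram colour group.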
"""

import os
from typing import List, Tuple, Dict, Optional
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.cluster import AgglomerativeClustering
from scipy.spatial.distance import pdist, squareform
from scipy.cluster import hierarchy
from joblib import Parallel, delayed
import re
import logging

# Root logging setup: INFO and above, timestamped "time - level - message" layout.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every function and method in this file.
logger = logging.getLogger(__name__)

@dataclass
class PipelineConfig:
    """Configuration settings for the analysis pipeline.

    ``input_file`` and ``output_dir`` may be passed either as ``Path`` objects
    or as plain strings; both are normalised to ``Path`` after initialisation
    so that ``.exists()``, ``.mkdir()`` and the ``/`` operator always work.
    """
    # FASTA file the designed sequences are read from.
    input_file: Path
    # Directory that receives the extracted FASTA, the log file and all figures.
    output_dir: Path
    # Stop after this many FASTA records; None (or 0) means "read everything".
    max_sequences: Optional[int] = None
    # NOTE(review): not read by the pipeline code visible in this file.
    chunk_size: int = 1000
    # NOTE(review): not read by the pipeline code visible in this file.
    n_jobs: int = -1
    # Resolution passed to every ``savefig`` call.
    dpi: int = 300
    # Alphabet used both for validation and for the one-hot column order.
    amino_acids: str = "ACDEFGHIKLMNPQRSTVWY"
    # Metric name handed to ``scipy.spatial.distance.pdist``.
    distance_metric: str = "hamming"
    # Method name handed to ``scipy.cluster.hierarchy.linkage``.
    linkage_method: str = "average"

    def __post_init__(self) -> None:
        """Coerce path-like arguments (e.g. plain strings) to ``Path``."""
        self.input_file = Path(self.input_file)
        self.output_dir = Path(self.output_dir)

class SequenceAnalysisPipeline:
    """Extract designed sequences, compute pairwise distances and plot clusters.

    The pipeline reads a FASTA file, keeps the designed chain of every record,
    one-hot encodes the sequences, computes a pairwise distance matrix and
    writes a similarity heatmap, a full dendrogram and one dendrogram per
    dendrogram colour group into ``config.output_dir``.
    """

    # Header fields identifying a record: "design:<int>" followed by "n:<int>".
    _DESCRIPTION_PATTERN = re.compile(r'design:(\d+).*n:(\d+)')
    # The heatmap only shows the first sequences so tick labels stay legible.
    _HEATMAP_LIMIT = 50

    def __init__(self, config: "PipelineConfig"):
        """Store *config*, create the output directory and attach the log file."""
        self.config = config
        self.config.output_dir.mkdir(parents=True, exist_ok=True)
        self._setup_logging()

    def _setup_logging(self) -> None:
        """Attach a file handler writing ``pipeline.log`` in the output directory.

        The handler is added at most once per log file: creating several
        pipelines (or re-running a notebook cell) would otherwise stack
        handlers on the module logger and duplicate every log line.
        """
        log_file = self.config.output_dir / "pipeline.log"
        # FileHandler stores its target as an absolute path in ``baseFilename``.
        target = os.path.abspath(os.fspath(log_file))
        for handler in logger.handlers:
            if isinstance(handler, logging.FileHandler) and handler.baseFilename == target:
                return

        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger.addHandler(file_handler)

    def extract_designed_sequences(self) -> "List[SeqRecord]":
        """Read the input FASTA and return the designed chain of each record.

        A record is kept when its description matches ``design:<n> ... n:<m>``,
        its sequence contains a ``/`` delimiter, and the second ``/``-separated
        part (with ``-`` removed) is non-empty and made only of
        ``config.amino_acids``. Kept records are renamed ``d<n>_n<m>`` and also
        written to ``extracted_sequences.fasta`` in the output directory.

        Returns:
            The list of extracted records, in input order.

        Raises:
            FileNotFoundError: if ``config.input_file`` does not exist.
            ValueError: if no record passes validation.
        """
        if not self.config.input_file.exists():
            raise FileNotFoundError(f"Input file not found: {self.config.input_file}")

        extracted_sequences = []
        design_lengths = set()
        valid_residues = set(self.config.amino_acids)

        logger.info(f"Extracting sequences from {self.config.input_file}")

        for idx, record in enumerate(SeqIO.parse(self.config.input_file, "fasta")):
            # The limit counts records read, not records kept.
            if self.config.max_sequences and idx >= self.config.max_sequences:
                break

            try:
                match = self._DESCRIPTION_PATTERN.search(record.description)
                if not match:
                    logger.warning(f"Skipping sequence {idx}: Could not parse description")
                    continue

                new_id = f"d{match.group(1)}_n{match.group(2)}"
                raw_sequence = str(record.seq)

                if '/' not in raw_sequence:
                    logger.warning(f"Skipping sequence {idx}: No '/' delimiter found")
                    continue

                # The designed chain is the part after the first '/'.
                designed_seq = raw_sequence.split('/')[1].replace('-', '').strip()

                if not designed_seq:
                    logger.warning(f"Skipping sequence {idx}: Empty sequence after processing")
                    continue

                if not set(designed_seq) <= valid_residues:
                    logger.warning(f"Skipping sequence {idx}: Invalid amino acids found")
                    continue

                design_lengths.add(len(designed_seq))
                extracted_sequences.append(SeqRecord(Seq(designed_seq), id=new_id, description=""))

            except Exception as e:
                # Deliberate best-effort: one malformed record must not abort the run.
                logger.error(f"Error processing sequence {idx}: {e}")
                continue

        if not extracted_sequences:
            raise ValueError("No valid sequences found in input file")

        if len(design_lengths) > 1:
            logger.warning(f"Multiple sequence lengths found: {design_lengths}")

        output_fasta = self.config.output_dir / "extracted_sequences.fasta"
        SeqIO.write(extracted_sequences, output_fasta, "fasta")
        logger.info(f"Extracted {len(extracted_sequences)} sequences to {output_fasta}")

        return extracted_sequences

    def one_hot_encode(self, sequences: List[str]) -> np.ndarray:
        """One-hot encode *sequences* into a compact ``int8`` matrix.

        Each sequence becomes one row of ``max_length * len(amino_acids)``
        columns; position ``j`` with residue index ``k`` sets column
        ``j * len(amino_acids) + k``. Sequences shorter than the longest one
        are zero-padded on the right, and residues outside
        ``config.amino_acids`` leave their position all-zero.

        Raises:
            ValueError: if *sequences* is empty.
        """
        if not sequences:
            raise ValueError("Cannot one-hot encode an empty list of sequences")

        alphabet = self.config.amino_acids
        n_symbols = len(alphabet)
        column_of = {aa: i for i, aa in enumerate(alphabet)}
        # Size rows for the longest sequence so ragged input cannot index out of range.
        seq_length = max(len(seq) for seq in sequences)

        encoding = np.zeros((len(sequences), seq_length * n_symbols), dtype=np.int8)

        for row, seq in enumerate(sequences):
            for pos, aa in enumerate(seq):
                col = column_of.get(aa)
                if col is not None:
                    encoding[row, pos * n_symbols + col] = 1

        return encoding

    def compute_distance_matrix(self, encoded_seqs: np.ndarray) -> np.ndarray:
        """Return the square pairwise distance matrix of the encoded sequences.

        Distances are computed with ``scipy.spatial.distance.pdist`` using
        ``config.distance_metric`` and expanded with ``squareform``.

        Note:
            With the ``"hamming"`` metric on one-hot rows, a mismatch between
            two alphabet residues changes two columns, so the distance is the
            fraction of differing one-hot columns, not the fraction of
            differing residues.
        """
        logger.info("Computing distance matrix...")
        distances = pdist(encoded_seqs, metric=self.config.distance_metric)
        return squareform(distances)

    def create_visualizations(self, distance_matrix: np.ndarray, sequence_names: List[str]) -> None:
        """Create and save the heatmap, the full tree and per-cluster trees.

        Writes ``similarity_heatmap.png``, ``sequence_tree.png`` and one
        ``cluster_<i>.png`` per dendrogram colour group (numbered in
        left-to-right order of first appearance in the full tree) into the
        output directory. Groups with fewer than two sequences are skipped
        because no linkage can be computed for them.
        """
        logger.info("Generating visualizations...")

        self._save_similarity_heatmap(distance_matrix, sequence_names)

        if len(sequence_names) < 2:
            logger.warning("Skipping dendrograms: at least two sequences are required")
            return

        dendro = self._save_sequence_tree(distance_matrix, sequence_names)

        # 'leaves_color_list' is aligned with 'leaves' (one colour per leaf);
        # 'color_list' holds one colour per *link* and cannot be used here.
        leaf_colors = dendro.get('leaves_color_list')
        if leaf_colors is None:
            logger.warning("Skipping cluster images: dendrogram did not report leaf colours")
            return

        # dict preserves insertion order, giving stable cluster numbering.
        members_by_color = {}
        for leaf, color in zip(dendro['leaves'], leaf_colors):
            members_by_color.setdefault(color, []).append(leaf)

        for color_idx, (color, members) in enumerate(members_by_color.items()):
            if len(members) < 2:
                logger.warning(f"Skipping cluster {color_idx}: fewer than two sequences")
                continue
            self._save_cluster_tree(color_idx, color or 'C0', sorted(members),
                                    distance_matrix, sequence_names)

    def _save_similarity_heatmap(self, distance_matrix: np.ndarray, sequence_names: List[str]) -> None:
        """Save a ``1 - distance`` heatmap of the first sequences."""
        limit = self._HEATMAP_LIMIT
        plt.figure(figsize=(12, 10))
        sns.heatmap(1 - distance_matrix[:limit, :limit],
                    cmap="viridis",
                    square=True,
                    xticklabels=sequence_names[:limit],
                    yticklabels=sequence_names[:limit])
        plt.title("Sequence Similarity Matrix (First 50 Sequences)")
        plt.tight_layout()
        plt.savefig(self.config.output_dir / "similarity_heatmap.png",
                    dpi=self.config.dpi)
        plt.close()

    def _save_sequence_tree(self, distance_matrix: np.ndarray, sequence_names: List[str]) -> dict:
        """Save the dendrogram of all sequences and return SciPy's dendrogram dict."""
        condensed_dist = squareform(distance_matrix)
        Z = hierarchy.linkage(condensed_dist, method=self.config.linkage_method)

        plt.figure(figsize=(20, 10))
        dendro = hierarchy.dendrogram(Z,
                                      labels=sequence_names,
                                      leaf_rotation=90,
                                      leaf_font_size=8)
        plt.title("Sequence Similarity Tree")
        plt.xlabel("Sequence ID")
        plt.ylabel("Distance")
        plt.tight_layout()
        plt.savefig(self.config.output_dir / "sequence_tree.png",
                    dpi=self.config.dpi)
        plt.close()
        return dendro

    def _save_cluster_tree(self, color_idx: int, color: str, members: List[int],
                           distance_matrix: np.ndarray, sequence_names: List[str]) -> None:
        """Save the dendrogram of one cluster, drawn in the cluster's colour."""
        cluster_dist = distance_matrix[np.ix_(members, members)]
        cluster_names = [sequence_names[i] for i in members]
        cluster_Z = hierarchy.linkage(squareform(cluster_dist), method=self.config.linkage_method)

        # Widen the figure with the number of leaves so labels do not collide.
        width = max(20, len(cluster_names) * 0.2)
        height = 12
        plt.figure(figsize=(width, height))

        # Draw without names; custom staggered labels are added below.
        dend = hierarchy.dendrogram(
            cluster_Z,
            labels=None,
            leaf_rotation=45,
            leaf_font_size=8,
            link_color_func=lambda _node, c=color: c
        )

        ax = plt.gca()

        # Tick positions are in leaf order, matching dend['leaves'].
        xticks = ax.get_xticks()
        labels = [cluster_names[i] for i in dend['leaves']]

        ax.set_xticks([])
        ax.set_xlabel('')

        # Alternate vertical offset and alignment so neighbouring labels
        # do not overlap.
        for idx, (x, label) in enumerate(zip(xticks, labels)):
            y_offset = 0.02 if idx % 2 == 0 else -0.04
            alignment = 'left' if idx % 2 == 0 else 'right'

            ax.text(x, y_offset, label,
                    rotation=45,
                    ha=alignment,
                    va='top',
                    fontsize=8)

        plt.title(f"Cluster {color_idx}: {len(cluster_names)} sequences")
        plt.ylabel("Distance")
        plt.subplots_adjust(bottom=0.15)

        # Tight bounding box plus padding so the rotated labels are included.
        plt.savefig(self.config.output_dir / f"cluster_{color_idx}.png",
                    dpi=self.config.dpi,
                    bbox_inches='tight',
                    pad_inches=0.5)
        plt.close()

    def run(self) -> "Tuple[List[SeqRecord], np.ndarray]":
        """Run extraction, encoding, distance computation and plotting.

        Returns:
            A tuple ``(sequences, distance_matrix)`` with the extracted
            records and their square pairwise distance matrix.

        Raises:
            Any exception raised by a pipeline step; it is logged with its
            traceback and re-raised unchanged.
        """
        try:
            logger.info("Starting sequence analysis pipeline...")

            sequences = self.extract_designed_sequences()
            sequence_list = [str(seq.seq) for seq in sequences]
            sequence_names = [seq.id for seq in sequences]

            encoded_seqs = self.one_hot_encode(sequence_list)
            distance_matrix = self.compute_distance_matrix(encoded_seqs)
            self.create_visualizations(distance_matrix, sequence_names)

            logger.info("Analysis pipeline completed successfully")
            return sequences, distance_matrix

        except Exception as e:
            logger.exception("Pipeline failed: %s", e)
            raise

def main():
    """Build the example configuration and run the full pipeline once."""
    fasta_path = Path("/content/drive/MyDrive/Fasta-files/3NOB_90-110/3NOB_90-110_design.fasta")
    results_dir = Path("/content/drive/MyDrive/Fasta-files/3NOB_90-110/analysis_output")

    settings = PipelineConfig(
        input_file=fasta_path,
        output_dir=results_dir,
        max_sequences=None,
        chunk_size=1000,
        n_jobs=-1,
        dpi=300,
    )

    # run() returns the extracted records and their distance matrix.
    sequences, distance_matrix = SequenceAnalysisPipeline(settings).run()


if __name__ == "__main__":
    main()

In [15]:
# Mount Google Drive at /content/drive; main() above reads its input from and
# writes its output to paths under /content/drive/MyDrive.
# NOTE(review): in a fresh Colab runtime this cell presumably has to run
# before the pipeline cell — confirm the intended cell order.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
