<a href="https://colab.research.google.com/github/eoinleen/Protein-design-random/blob/main/WIP_v3_RFdiff_MSA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
RF_Diffusion Sequence Analysis Pipeline
-------------------------------------
Created by: Claude AI (Anthropic) with input from user
Version: 2.1
Date: January 27, 2025

This script analyzes sequences from RF_diffusion output files (.fasta format).
It processes sequences that are in the format:
>design:0 n:0|mpnn:1.589|plddt:0.460|i_ptm:0.173|i_pae:20.678|rmsd:8.341 SEQUENCE1/SEQUENCE2

IMPORTANT UPDATE: V2.1 fixes sequence extraction to analyze only the designed sequences
after the '/' delimiter, ignoring the scaffold sequence before the '/'.

Features:
- Extracts sequences after the '/' delimiter
- Performs multiple sequence alignment using MUSCLE
- Generates phylogenetic trees (rectangular and circular)
- Calculates sequence conservation scores
- Creates position-specific scoring matrices (PSSM)
- Performs cluster analysis
- Colors sequences by i_pae scores
- Provides statistical analysis of tree structure

Usage:
1. Upload this script to Google Colab
2. Modify the fasta_path variable at the bottom to point to your input file
3. Run the entire script

Output:
All files are saved in the same directory as the input file:
- conservation_plot.png: Shows conservation across sequence positions
- phylogenetic_tree_rectangular.png: Traditional tree visualization
- phylogenetic_tree_circular.png: Circular tree visualization
- cluster_X_tree.png: Individual cluster trees
- pssm_heatmap.png: Shows amino acid frequencies at each position
- tree_statistics.txt: Statistical analysis of tree structure
- aligned.fasta: Multiple sequence alignment output

Credits:
Primary Developer: Claude AI (Anthropic)
Contributing Developer: Dr Eoin Leen, University of Leeds
"""

# [Previous imports and installations remain the same]

class RFDiffusionAnalyzer:
    """Analyze the designed sequences contained in an RF_diffusion FASTA file.

    Each input record is expected to hold ``SCAFFOLD/DESIGNED``; only the part
    after the first '/' is analyzed. Outputs are written to the directory that
    contains the input file.
    """

    # The 20 standard amino-acid one-letter codes accepted by validate_sequence.
    _VALID_AA = frozenset('ACDEFGHIKLMNPQRSTVWY')

    def __init__(self, fasta_path):
        """Record the input path and prepare the output directory.

        Args:
            fasta_path: Path to the RF_diffusion FASTA file to analyze.
        """
        self.fasta_path = fasta_path
        # Outputs live next to the input file.
        self.output_dir = os.path.dirname(os.path.abspath(fasta_path))
        self.sequences = []              # SeqRecords of designed sequences
        self.alignment = None            # set by create_alignment()
        self.tree = None
        self.conservation_scores = None
        self.pssm = None
        self.clusters = None

        # exist_ok avoids the check-then-create race of an explicit exists() test
        os.makedirs(self.output_dir, exist_ok=True)

        print(f"Analysis will use file: {self.fasta_path}")
        print(f"Outputs will be saved to: {self.output_dir}")

    def extract_sequences(self):
        """Extract only the designed sequences after the '/' from RF_diffusion output.

        Records without a '/' delimiter, or with nothing after it, are skipped
        (a warning reports how many).

        Returns:
            list: SeqRecord objects with ids ``design_0``, ``design_1``, ...;
            each keeps the original record description.

        Raises:
            ValueError: If no record yields a designed sequence.
        """
        sequences = []
        design_lengths = set()  # To track sequence lengths
        skipped = 0

        for record in SeqIO.parse(self.fasta_path, "fasta"):
            split_seq = str(record.seq).split('/')
            designed_seq = split_seq[1].strip() if len(split_seq) > 1 else ""
            if not designed_seq:
                # An empty sequence cannot be aligned; do not pass it downstream.
                skipped += 1
                continue

            design_lengths.add(len(designed_seq))

            # Create new record with only the designed sequence
            sequences.append(
                SeqRecord(
                    Seq(designed_seq),
                    id=f"design_{len(sequences)}",
                    description=record.description,
                )
            )

        if skipped:
            print(f"WARNING: Skipped {skipped} record(s) with no designed sequence after '/'")

        # Validation checks
        if not sequences:
            raise ValueError("No valid sequences found. Check if file contains '/' delimiters.")

        if len(design_lengths) > 1:
            print("WARNING: Designed sequences have different lengths:", design_lengths)
            # Report the real range instead of an arbitrary element of the set.
            length_text = f"lengths {min(design_lengths)}-{max(design_lengths)}"
        else:
            length_text = f"length {next(iter(design_lengths))}"

        print(f"Extracted {len(sequences)} designed sequences of {length_text}")

        self.sequences = sequences
        return sequences

    def validate_sequence(self, seq):
        """Validate sequence contains only valid amino acid characters.

        Args:
            seq: Sequence string; the check is case-insensitive.

        Returns:
            bool: True if every character is one of the 20 standard amino
            acids (an empty string is valid), otherwise False after printing
            a warning that lists the offending characters.
        """
        invalid_chars = set(seq.upper()) - self._VALID_AA
        if invalid_chars:
            print(f"WARNING: Found invalid amino acid characters: {invalid_chars}")
            return False
        return True

    def create_alignment(self):
        """Create MSA using MUSCLE with added validation.

        Writes the designed sequences to a temporary FASTA file in the output
        directory, runs the ``muscle`` executable on it, and reads the result
        from ``aligned.fasta`` in the output directory.

        Returns:
            The alignment read from ``aligned.fasta`` (also stored on
            ``self.alignment``).

        Raises:
            subprocess.CalledProcessError: If MUSCLE exits with a non-zero status.
            FileNotFoundError: If the ``muscle`` executable is not on PATH.
            Exception: Any error raised while reading the alignment is
                printed and re-raised.
        """
        # Local import: the file-level import block is not visible from here.
        import subprocess

        if not self.sequences:
            self.extract_sequences()

        # Validate sequences before alignment (warnings only; alignment still runs)
        for record in self.sequences:
            self.validate_sequence(str(record.seq))

        # Keep intermediate and result files with the other outputs rather
        # than in whatever the current working directory happens to be.
        temp_fasta = os.path.join(self.output_dir, "temp_sequences.fasta")
        aligned_fasta = os.path.join(self.output_dir, "aligned.fasta")
        SeqIO.write(self.sequences, temp_fasta, "fasta")

        try:
            # Argument list + check=True: works outside IPython, needs no shell
            # quoting, and a failed run raises instead of silently leaving a
            # stale aligned.fasta to be read below.
            # NOTE(review): -in/-out is the MUSCLE v3 syntax; confirm the
            # installed MUSCLE version accepts it.
            subprocess.run(
                ["muscle", "-in", temp_fasta, "-out", aligned_fasta],
                check=True,
            )
            self.alignment = AlignIO.read(aligned_fasta, "fasta")
            print(f"Successfully aligned {len(self.alignment)} sequences of length {self.alignment.get_alignment_length()}")
            return self.alignment
        except Exception as e:
            print(f"Error during alignment: {str(e)}")
            raise
        finally:
            # The temporary input is not an analysis product; clean it up.
            if os.path.exists(temp_fasta):
                os.remove(temp_fasta)

    # [Previous methods remain the same until plot_conservation]

    def plot_conservation(self):
        """Plot conservation scores with improved visualization.

        Computes the scores first if they are not available yet, draws them
        with a dashed line at their mean, saves the figure as
        ``conservation_plot.png`` in the output directory, and shows it.
        """
        if self.conservation_scores is None:
            self.calculate_conservation()

        plt.figure(figsize=(15, 5))
        plt.plot(self.conservation_scores, 'b-', linewidth=2)
        plt.fill_between(range(len(self.conservation_scores)),
                        self.conservation_scores,
                        alpha=0.2)

        plt.title('Sequence Conservation by Position')
        plt.xlabel('Position in Designed Sequence')
        plt.ylabel('Conservation Score')
        plt.grid(True, alpha=0.3)

        # Add mean conservation line
        mean_conservation = np.mean(self.conservation_scores)
        plt.axhline(y=mean_conservation, color='r', linestyle='--',
                   label=f'Mean Conservation: {mean_conservation:.2f}')
        plt.legend()

        plt.ylim(0, 1.05)  # Set y-axis limits with small padding
        plt.tight_layout()

        save_path = os.path.join(self.output_dir, 'conservation_plot.png')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Conservation plot saved to: {save_path}")
        plt.show()
        plt.close()

    # [Rest of the methods remain the same]

    def run_complete_analysis(self):
        """Run all analyses and generate plots with improved error handling.

        Steps run in order: sequence extraction, alignment, conservation,
        tree generation, PSSM, then plotting.

        Raises:
            Exception: Any failure is printed with an "ERROR" prefix and
                re-raised to the caller.
        """
        try:
            print("Analysis starting...")
            print("Output will be saved to:", self.output_dir)

            print("\n1. Extracting sequences...")
            self.extract_sequences()

            print("\n2. Creating alignment...")
            self.create_alignment()

            print("\n3. Calculating conservation...")
            self.calculate_conservation()

            print("\n4. Generating tree...")
            self.generate_tree()

            print("\n5. Calculating PSSM...")
            self.calculate_pssm()

            print("\n6. Generating and saving plots...")
            self.plot_conservation()
            self.plot_tree_with_options(style='both', color_by_pae=True, cluster_analysis=True)
            self.plot_pssm_heatmap()

            print("\nAnalysis complete! Files saved in:", self.output_dir)

        except Exception as e:
            print(f"\nERROR: Analysis failed - {str(e)}")
            raise

# [Previous helper functions remain the same]

# Specify the full path to your file
fasta_path = "/content/drive/MyDrive/path/to/your/design.fasta"  # Modify this path

# Bind the result name before running: if the analysis fails, `analyzer` is
# None instead of unbound (or, in a notebook, a stale object left over from an
# earlier successful run).
analyzer = None

# Run analysis; failures are reported here rather than propagated so the
# notebook cell finishes with a readable message.
try:
    analyzer = run_analysis(fasta_path)
except Exception as e:
    print(f"Failed to complete analysis: {str(e)}")