<a href="https://colab.research.google.com/github/eoinleen/Protein-design-random/blob/main/WIP_v2_RFdiff_MSA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
RF_Diffusion Sequence Analysis Pipeline
-------------------------------------
Created by: Claude AI (Anthropic) with input from user
Version: 2.0
Date: January 27, 2025

This script analyzes sequences from RF_diffusion output files (.fasta format).
It processes FASTA records whose header line carries the design scores and
whose sequence line holds the chains separated by '/':
>design:0 n:0|mpnn:1.589|plddt:0.460|i_ptm:0.173|i_pae:20.678|rmsd:8.341
SEQUENCE1/SEQUENCE2

Features:
- Extracts sequences after the '/' delimiter
- Performs multiple sequence alignment using MUSCLE
- Generates phylogenetic trees (rectangular and circular)
- Calculates sequence conservation scores
- Creates a position-specific scoring matrix (PSSM) holding per-column amino acid frequencies
- Performs cluster analysis
- Colors sequences by i_pae scores
- Provides statistical analysis of tree structure

Usage:
1. Upload this script to Google Colab
2. Modify the fasta_path variable at the bottom to point to your input file
3. Run the entire script

Output:
All files are saved in the same directory as the input file:
- conservation_plot.png: Shows conservation across sequence positions
- phylogenetic_tree_rectangular.png: Traditional tree visualization
- phylogenetic_tree_circular.png: Circular tree visualization
- cluster_X_tree.png: Individual cluster trees
- pssm_heatmap.png: Shows amino acid frequencies at each position
- tree_statistics.txt: Statistical analysis of tree structure
- aligned.fasta: Multiple sequence alignment output

Credits:
Primary Developer: Claude AI (Anthropic)
Contributing User: [Your input helped shape the specific application]
"""

# Required installations
import sys
!{sys.executable} -m pip install bio
!{sys.executable} -m pip install matplotlib
!{sys.executable} -m pip install seaborn
!{sys.executable} -m pip install scipy
!apt-get install muscle
!apt-get install -y hmmer

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Import required libraries
import os
import re
import subprocess
import tempfile
import traceback
from collections import Counter, defaultdict
from io import StringIO

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from Bio import SeqIO, AlignIO, Phylo
from Bio.Align import MultipleSeqAlignment
from Bio.Phylo.TreeConstruction import DistanceCalculator, DistanceTreeConstructor
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

class RFDiffusionAnalyzer:
    """Analyze the designed chains stored in an RF_diffusion FASTA file.

    Each FASTA record is expected to hold several chains separated by '/';
    the chain after the first '/' is extracted, aligned with MUSCLE and then
    summarized (conservation, frequency matrix, UPGMA tree, clusters, plots).
    Every output file is written to the directory containing the input file.

    Attributes:
        fasta_path: Path of the input FASTA file.
        output_dir: Absolute directory of ``fasta_path``; all outputs go here.
        sequences: List of extracted ``SeqRecord`` objects (ids ``design_N``).
        alignment: MUSCLE alignment of ``sequences`` or None.
        tree: UPGMA tree built from ``alignment`` or None.
        conservation_scores: Per-column conservation values or None.
        pssm: 20 x alignment-length frequency matrix or None.
        clusters: Mapping of leaf name to cluster number or None.
    """

    # Row order of the frequency matrix built by calculate_pssm().
    AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'

    # Matches "key:value" tokens that are delimited by '|' or whitespace.
    _SCORE_PATTERN = re.compile(r'([^\s|:]+)\s*:\s*([^\s|:]+)')

    def __init__(self, fasta_path):
        """Remember the input path and derive the output directory from it."""
        self.fasta_path = fasta_path
        self.output_dir = os.path.dirname(os.path.abspath(fasta_path))
        self.sequences = []
        self.alignment = None
        self.tree = None
        self.conservation_scores = None
        self.pssm = None
        self.clusters = None

        print(f"Analysis will use file: {self.fasta_path}")
        print(f"Outputs will be saved to: {self.output_dir}")

    def extract_sequences(self):
        """Extract the chain that follows the first '/' of every record.

        Records without a '/' or with an empty second chain are skipped.
        Extracted records get the id ``design_N`` (N counts kept records) and
        keep the original header as their description.

        Returns:
            The list of extracted ``SeqRecord`` objects (also stored in
            ``self.sequences``).
        """
        sequences = []
        for record in SeqIO.parse(self.fasta_path, "fasta"):
            chains = str(record.seq).split('/')
            if len(chains) < 2:
                continue
            chain = chains[1].strip()
            if not chain:
                # An empty sequence cannot be aligned.
                continue
            sequences.append(
                SeqRecord(
                    Seq(chain),
                    id=f"design_{len(sequences)}",
                    description=record.description,
                )
            )
        self.sequences = sequences
        return sequences

    def create_alignment(self):
        """Create a multiple sequence alignment with the ``muscle`` executable.

        The alignment is written to ``aligned.fasta`` in ``self.output_dir``.
        Sequences are extracted first if that has not happened yet.

        Returns:
            The alignment read back from ``aligned.fasta``.

        Raises:
            ValueError: If the input file yields no sequences.
            FileNotFoundError: If the ``muscle`` executable is not installed.
            RuntimeError: If MUSCLE exits with an error.
        """
        if not self.sequences:
            self.extract_sequences()
        if not self.sequences:
            raise ValueError(
                f"No sequences with a '/' chain separator found in {self.fasta_path}"
            )

        aligned_path = os.path.join(self.output_dir, "aligned.fasta")
        with tempfile.TemporaryDirectory() as workdir:
            unaligned_path = os.path.join(workdir, "temp_sequences.fasta")
            SeqIO.write(self.sequences, unaligned_path, "fasta")

            # MUSCLE 3.x uses -in/-out, MUSCLE 5.x uses -align/-output; try both.
            commands = (
                ["muscle", "-in", unaligned_path, "-out", aligned_path],
                ["muscle", "-align", unaligned_path, "-output", aligned_path],
            )
            failures = []
            for command in commands:
                result = subprocess.run(command, capture_output=True, text=True)
                if result.returncode == 0:
                    break
                failures.append(result.stderr.strip())
            else:
                raise RuntimeError("MUSCLE failed:\n" + "\n".join(failures))

        self.alignment = AlignIO.read(aligned_path, "fasta")
        return self.alignment

    def calculate_conservation(self):
        """Calculate a conservation score for each alignment column.

        The score is the fraction of sequences that share the most common
        symbol of the column; gap characters count as a symbol.

        Returns:
            List of scores, one per alignment column (also stored in
            ``self.conservation_scores``).
        """
        if self.alignment is None:
            self.create_alignment()

        depth = len(self.alignment)
        conservation_scores = []
        for i in range(self.alignment.get_alignment_length()):
            most_common_count = Counter(self.alignment[:, i]).most_common(1)[0][1]
            conservation_scores.append(most_common_count / depth)

        self.conservation_scores = conservation_scores
        return conservation_scores

    def analyze_clusters(self, n_clusters=5, method='average'):
        """Cluster the tree leaves by their pairwise distances along the tree.

        Args:
            n_clusters: Maximum number of flat clusters to form.
            method: SciPy linkage method. The default is 'average' because the
                input is a precomputed, non-Euclidean distance matrix, for
                which 'ward' is not well defined.

        Returns:
            Dict mapping leaf name to its 1-based cluster number (also stored
            in ``self.clusters``).
        """
        if self.tree is None:
            self.generate_tree()

        terminals = self.tree.get_terminals()
        names = [leaf.name for leaf in terminals]
        if len(terminals) < 2:
            # Hierarchical clustering needs at least two observations.
            self.clusters = {name: 1 for name in names}
            return self.clusters

        count = len(terminals)
        dist_matrix = np.zeros((count, count))
        for i, first in enumerate(terminals):
            for j, second in enumerate(terminals[i + 1:], start=i + 1):
                dist_matrix[i, j] = dist_matrix[j, i] = self.tree.distance(first, second)

        # linkage() interprets a 2-D array as observation vectors, so the
        # square distance matrix must be passed in condensed form.
        condensed = squareform(dist_matrix, checks=False)
        Z = linkage(condensed, method=method)
        labels = fcluster(Z, t=n_clusters, criterion='maxclust')

        self.clusters = {name: int(label) for name, label in zip(names, labels)}
        return self.clusters

    def extract_scores(self, sequence_description):
        """Extract numeric ``key:value`` scores from a record description.

        Tokens are delimited by '|' or whitespace, so a header such as
        ``design:0 n:0|mpnn:1.589|i_pae:20.678`` yields every pair. Tokens
        whose value is not a number are ignored.

        Args:
            sequence_description: FASTA header text.

        Returns:
            Dict mapping score name to float value.
        """
        scores = {}
        for key, value in self._SCORE_PATTERN.findall(sequence_description):
            try:
                scores[key] = float(value)
            except ValueError:
                continue
        return scores

    def generate_tree(self):
        """Build a UPGMA tree from identity distances of the alignment.

        Returns:
            The tree (also stored in ``self.tree``).
        """
        if self.alignment is None:
            self.create_alignment()

        constructor = DistanceTreeConstructor(DistanceCalculator('identity'), 'upgma')
        self.tree = constructor.build_tree(self.alignment)
        return self.tree

    def calculate_pssm(self):
        """Calculate per-column amino acid frequencies.

        Returns:
            Array of shape (20, alignment length); row order follows
            ``AMINO_ACIDS``. Gaps and other symbols are not counted, so a
            column may sum to less than 1. Also stored in ``self.pssm``.
        """
        if self.alignment is None:
            self.create_alignment()

        amino_acids = self.AMINO_ACIDS
        length = self.alignment.get_alignment_length()
        depth = len(self.alignment)
        pssm = np.zeros((len(amino_acids), length))

        for i in range(length):
            counts = Counter(self.alignment[:, i])
            for j, aa in enumerate(amino_acids):
                pssm[j, i] = counts[aa] / depth

        self.pssm = pssm
        return pssm

    def plot_conservation(self):
        """Plot the conservation scores and save them as conservation_plot.png."""
        if self.conservation_scores is None:
            self.calculate_conservation()

        plt.figure(figsize=(15, 5))
        plt.plot(self.conservation_scores)
        plt.title('Sequence Conservation by Position')
        plt.xlabel('Position')
        plt.ylabel('Conservation Score')
        plt.savefig(os.path.join(self.output_dir, 'conservation_plot.png'))
        plt.show()
        plt.close()

    def plot_tree_with_options(self, style='rectangular', color_by_pae=True,
                               cluster_analysis=True, pae_threshold=8):
        """Plot the phylogenetic tree and write the tree statistics.

        Args:
            style: 'rectangular', 'circular' or 'both'.
            color_by_pae: Color leaf labels blue when the design's i_pae score
                is below ``pae_threshold`` (black otherwise).
            cluster_analysis: Also cluster the leaves and save one plot per
                cluster.
            pae_threshold: i_pae cut-off used for the label colors.

        Raises:
            ValueError: If ``style`` is not one of the supported values.
        """
        if style not in ('rectangular', 'circular', 'both'):
            raise ValueError(
                f"style must be 'rectangular', 'circular' or 'both', got {style!r}"
            )
        if self.tree is None:
            self.generate_tree()

        if cluster_analysis:
            self.analyze_clusters()

        def format_label(name):
            parts = name.split('_')
            if len(parts) >= 2:
                return f"{parts[0]}-{parts[1]}"
            return name

        def custom_label(clade):
            # Returning None makes Phylo.draw skip unnamed clades.
            return format_label(clade.name) if clade.name else None

        # Phylo.draw looks colors up by the rendered label text, so the keys
        # must be the formatted labels of the extracted records (design-N),
        # not the ids found in the original FASTA file.
        colors = {}
        if color_by_pae:
            if not self.sequences:
                self.extract_sequences()
            for record in self.sequences:
                i_pae = self.extract_scores(record.description).get('i_pae', float('inf'))
                colors[format_label(record.id)] = 'blue' if i_pae < pae_threshold else 'black'

        layouts = ['rectangular', 'circular'] if style == 'both' else [style]
        for layout in layouts:
            if layout == 'rectangular':
                # Pass the axes explicitly; otherwise Phylo.draw opens its own
                # default-sized figure and this one is left unused.
                fig, ax = plt.subplots(figsize=(20, 20))
                Phylo.draw(self.tree,
                           do_show=False,
                           label_func=custom_label,
                           show_confidence=False,
                           branch_labels=None,
                           label_colors=colors,
                           axes=ax)
            else:  # circular layout
                fig = plt.figure(figsize=(20, 20))
                ax = fig.add_subplot(111, projection='polar')
                self._draw_circular_tree(ax, custom_label, colors)

            ax.set_title(f'Phylogenetic Tree of Designs ({layout} layout)')
            fig.savefig(os.path.join(self.output_dir, f'phylogenetic_tree_{layout}.png'),
                        dpi=300, bbox_inches='tight')
            plt.close(fig)

        # These do not depend on the layout, so run them once.
        self.analyze_tree_statistics()
        if cluster_analysis:
            self.plot_clusters()

    def _draw_circular_tree(self, ax, label_func, label_colors):
        """Draw ``self.tree`` on polar axes: radius is depth, leaves span the circle."""
        depths = self.tree.depths()
        if not max(depths.values()):
            # No usable branch lengths: fall back to one unit per branch.
            depths = self.tree.depths(unit_branch_lengths=True)

        terminals = self.tree.get_terminals()
        step = 2 * np.pi / len(terminals)
        angles = {leaf: index * step for index, leaf in enumerate(terminals)}
        # Post-order guarantees children are placed before their parent.
        for clade in self.tree.find_clades(order='postorder'):
            if clade not in angles:
                child_angles = [angles[child] for child in clade.clades]
                angles[clade] = (min(child_angles) + max(child_angles)) / 2

        for parent in self.tree.get_nonterminals():
            radius = depths[parent]
            for child in parent.clades:
                # Arc at the parent's radius, sampled so it renders as a curve.
                arc = np.linspace(angles[parent], angles[child], 30)
                ax.plot(arc, np.full_like(arc, radius), color='black', linewidth=0.8)
                # Radial segment out to the child.
                ax.plot([angles[child], angles[child]], [radius, depths[child]],
                        color='black', linewidth=0.8)

        outer = max(depths.values())
        for leaf in terminals:
            label = label_func(leaf)
            if label is None:
                continue
            degrees = float(np.degrees(angles[leaf]))
            # Flip labels on the left half so they are not upside down.
            flipped = 90 < degrees < 270
            ax.text(angles[leaf], depths[leaf] + 0.02 * outer, label,
                    color=label_colors.get(label, 'black'),
                    rotation=degrees + 180 if flipped else degrees,
                    rotation_mode='anchor',
                    ha='right' if flipped else 'left',
                    va='center')
        ax.set_axis_off()

    def analyze_tree_statistics(self):
        """Compute summary statistics of the tree and save them to tree_statistics.txt.

        Returns:
            Dict with total branch length, maximum root-to-node depth (in
            branch-length units and in number of branches), number of leaves,
            Colless balance index and the average length of all branches that
            have a length.
        """
        if self.tree is None:
            self.generate_tree()

        branch_lengths = [clade.branch_length for clade in self.tree.find_clades()
                          if clade.branch_length is not None]
        stats = {
            'total_branch_length': self.tree.total_branch_length(),
            'max_depth': max(self.tree.depths().values()),
            'max_topological_depth': max(
                self.tree.depths(unit_branch_lengths=True).values()),
            'num_terminals': len(self.tree.get_terminals()),
            'balance': self.calculate_tree_balance(),
            'avg_branch_length': float(np.mean(branch_lengths)) if branch_lengths else 0.0,
        }

        # Save statistics to file
        with open(os.path.join(self.output_dir, 'tree_statistics.txt'), 'w') as f:
            for key, value in stats.items():
                f.write(f"{key}: {value}\n")

        return stats

    def calculate_tree_balance(self):
        """Calculate Colless's tree balance index.

        For every node with exactly two children the absolute difference of
        their leaf counts is added; nodes with another number of children
        contribute nothing themselves.

        Returns:
            The index as an int (0 for a perfectly balanced binary tree).
        """
        if self.tree is None:
            self.generate_tree()

        # Iterative post-order walk: each leaf count is computed once and deep
        # (caterpillar-like) trees cannot hit the recursion limit.
        leaf_counts = {}
        imbalance = 0
        stack = [(self.tree.root, False)]
        while stack:
            clade, children_done = stack.pop()
            if clade.is_terminal():
                leaf_counts[id(clade)] = 1
            elif not children_done:
                stack.append((clade, True))
                stack.extend((child, False) for child in clade.clades)
            else:
                sizes = [leaf_counts[id(child)] for child in clade.clades]
                leaf_counts[id(clade)] = sum(sizes)
                if len(sizes) == 2:
                    imbalance += abs(sizes[0] - sizes[1])
        return imbalance

    def plot_clusters(self):
        """Save one tree plot per cluster as cluster_<id>_tree.png.

        Each plot shows the subtree below the common ancestor of the cluster
        members, so it can contain leaves of other clusters as well. Does
        nothing when no clusters have been computed.
        """
        if not self.clusters:
            return

        # Group leaf names by cluster
        cluster_groups = defaultdict(list)
        for name, cluster in self.clusters.items():
            cluster_groups[cluster].append(name)

        # Look leaves up as clade objects so matching does not depend on how
        # Biopython resolves string targets.
        leaves_by_name = {leaf.name: leaf for leaf in self.tree.get_terminals()}

        for cluster_id, members in cluster_groups.items():
            subtree = self.tree.common_ancestor([leaves_by_name[name] for name in members])

            fig, ax = plt.subplots(figsize=(15, 15))
            Phylo.draw(subtree,
                       do_show=False,
                       label_func=lambda clade: clade.name[:10] if clade.name else None,
                       show_confidence=False,
                       axes=ax)
            ax.set_title(f'Cluster {cluster_id} (n={len(members)})')
            fig.savefig(os.path.join(self.output_dir, f'cluster_{cluster_id}_tree.png'),
                        dpi=300, bbox_inches='tight')
            plt.close(fig)

    def plot_pssm_heatmap(self):
        """Plot the frequency matrix as a heatmap and save it as pssm_heatmap.png."""
        if self.pssm is None:
            self.calculate_pssm()

        plt.figure(figsize=(20, 10))
        sns.heatmap(self.pssm,
                    yticklabels=list(self.AMINO_ACIDS),
                    cmap='YlOrRd')
        plt.title('Position-Specific Scoring Matrix')
        plt.xlabel('Position')
        plt.ylabel('Amino Acid')
        plt.savefig(os.path.join(self.output_dir, 'pssm_heatmap.png'))
        plt.show()
        plt.close()

    def run_complete_analysis(self):
        """Run every analysis step in order and generate all plots."""
        print("Analysis starting...")
        print("Output will be saved to:", self.output_dir)

        print("\n1. Extracting sequences...")
        self.extract_sequences()

        print("2. Creating alignment...")
        self.create_alignment()

        print("3. Calculating conservation...")
        self.calculate_conservation()

        print("4. Generating tree...")
        self.generate_tree()

        print("5. Calculating PSSM...")
        self.calculate_pssm()

        print("\n6. Generating and saving plots...")
        self.plot_conservation()
        self.plot_tree_with_options(style='both', color_by_pae=True, cluster_analysis=True)
        self.plot_pssm_heatmap()

        print("\nAnalysis complete! Files saved in:", self.output_dir)

def run_analysis(fasta_path):
    """Run the complete analysis on an RF_diffusion output file.

    Any exception raised during the run is reported (message plus full
    traceback, so the failing step can be identified) instead of propagating.

    Args:
        fasta_path: Path of the FASTA file to analyze.

    Returns:
        The ``RFDiffusionAnalyzer`` used for the run, or None if the run failed.
    """
    try:
        analyzer = RFDiffusionAnalyzer(fasta_path)
        analyzer.run_complete_analysis()
    except Exception as e:
        print(f"Error occurred: {str(e)}")
        # The message alone hides the exception type and where it was raised.
        traceback.print_exc()
        return None
    return analyzer

# Full path to the RF_diffusion FASTA file to analyze. Output plots are saved
# to the directory that contains this file.
fasta_path = "/content/drive/MyDrive/path/to/your/design.fasta"  # Modify this path

# Run the analysis. `analyzer` is the RFDiffusionAnalyzer instance used for the
# run, or None when run_analysis() caught an exception.
analyzer = run_analysis(fasta_path)