<a href="https://colab.research.google.com/github/eoinleen/Biophysics-general/blob/main/Dickin-about-Seq_analysis_RFdiff_v6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
import pandas as pd
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.cluster import AgglomerativeClustering
from scipy.spatial.distance import pdist, squareform
from scipy.cluster import hierarchy
from joblib import Parallel, delayed
import re

def extract_designed_sequences(input_file, output_file):
    """Extract designed sequences from RF_diffusion output and save to a new FASTA file.

    For every FASTA record whose description matches ``design:<int> ... n:<int>``,
    the segment of the sequence after the first ``'/'`` (up to a second ``'/'``,
    if any) is kept. Gap characters ``'-'`` and surrounding whitespace are removed.
    The result is stored under the ID ``d<design>_n<n>``.

    Records with an unparsable header, no ``'/'`` in the sequence, or an empty
    designed segment are skipped with a printed warning. A warning is also
    printed when the extracted sequences do not all share one length, because
    the position-by-position comparison done later in this file compares
    residues by index.

    Args:
        input_file: Path to the FASTA file to read.
        output_file: Path of the FASTA file to write. Its parent directory is
            created if it does not exist.

    Returns:
        List of ``SeqRecord`` objects that were written to ``output_file``.

    Raises:
        FileNotFoundError: If ``input_file`` does not exist.
        ValueError: If no record yields a designed sequence.
    """
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")

    output_parent = os.path.dirname(output_file)
    if output_parent:
        # os.makedirs("") raises, so only create a directory when the path has one.
        os.makedirs(output_parent, exist_ok=True)

    # Compiled once rather than on every record.
    header_pattern = re.compile(r'design:(\d+).*n:(\d+)')
    extracted_sequences = []
    design_lengths = set()

    for idx, record in enumerate(SeqIO.parse(input_file, "fasta")):
        # Best-effort per record: a malformed record is reported, not fatal.
        try:
            match = header_pattern.search(record.description)
            if not match:
                print(f"Warning: Could not parse sequence {idx}")
                continue

            seq_text = str(record.seq)
            if '/' not in seq_text:
                print(f"Warning: No '/' found in sequence {record.id}")
                continue

            designed_seq = seq_text.split('/')[1].replace('-', '').strip()
            if not designed_seq:
                print(f"Warning: Empty designed chain in sequence {record.id}")
                continue

            new_id = f"d{match.group(1)}_n{match.group(2)}"
            design_lengths.add(len(designed_seq))
            extracted_sequences.append(SeqRecord(Seq(designed_seq), id=new_id, description=""))
        except Exception as e:
            print(f"Error processing sequence {idx}: {e}")

    if not extracted_sequences:
        raise ValueError("No valid sequences found in input file")

    if len(design_lengths) > 1:
        print(f"Warning: designed sequences have differing lengths: {sorted(design_lengths)}")

    SeqIO.write(extracted_sequences, output_file, "fasta")
    print(f"Extracted {len(extracted_sequences)} sequences, saved to: {output_file}")
    return extracted_sequences

def one_hot_encode(sequences):
    """One-hot encode amino acid sequences into a flat 2-D feature matrix.

    Each sequence becomes one row of length ``max_len * 20``. Position ``j``
    occupies columns ``20*j`` to ``20*j + 19`` and has a 1 in the column of its
    residue, using the alphabet ``ACDEFGHIKLMNPQRSTVWY``. Characters outside
    that alphabet (e.g. gaps, ``X``, lowercase letters) leave their position
    all-zero.

    Sequences shorter than the longest one are zero-padded at the end, so
    unequal lengths do not raise. When all sequences share one length, the
    result is exactly the plain per-sequence encoding.

    Args:
        sequences: Indexable collection of amino-acid strings.

    Returns:
        Float ``numpy.ndarray`` of shape ``(len(sequences), max_len * 20)``.
        An empty input yields shape ``(0, 0)``.
    """
    amino_acids = "ACDEFGHIKLMNPQRSTVWY"
    aa_dict = {aa: i for i, aa in enumerate(amino_acids)}
    n_seqs = len(sequences)
    # Size by the longest sequence; sizing by sequences[0] crashed on any
    # longer later sequence and on empty input.
    max_len = max((len(seq) for seq in sequences), default=0)
    encoding = np.zeros((n_seqs, max_len, len(amino_acids)))

    for i, seq in enumerate(sequences):
        for j, char in enumerate(seq):
            col = aa_dict.get(char)
            if col is not None:
                encoding[i, j, col] = 1
    # Explicit width: reshape(n, -1) is ambiguous (and fails) for empty arrays.
    return encoding.reshape(n_seqs, max_len * len(amino_acids))

def compute_distance_matrix(encoded_seqs, metric="hamming"):
    """Return the square matrix of pairwise distances between encoded sequences.

    Rows of ``encoded_seqs`` are compared with ``scipy.spatial.distance.pdist``,
    and the condensed result is expanded with ``squareform``. This runs in a
    single process; no parallelism is involved.

    With the default ``"hamming"`` metric, each entry is the fraction of
    columns that differ between two rows. For rows produced by
    ``one_hot_encode`` (20 columns per position), a position holding two
    different recognised residues differs in 2 of its 20 columns. When every
    position holds a recognised residue, the value is therefore one tenth of
    the per-residue mismatch fraction.

    Args:
        encoded_seqs: 2-D array-like, one encoded sequence per row.
        metric: Any metric name or callable accepted by ``pdist``.

    Returns:
        Symmetric ``(n, n)`` float array with a zero diagonal, where ``n`` is
        the number of rows. Zero or one row yields an all-zero ``(n, n)`` array.
    """
    encoded = np.asarray(encoded_seqs)
    if encoded.ndim == 2 and encoded.shape[0] < 2:
        # squareform() maps an empty condensed vector to a 1x1 matrix, which
        # is wrong for zero rows; build the trivial result directly.
        return np.zeros((encoded.shape[0], encoded.shape[0]))
    return squareform(pdist(encoded, metric=metric))

def create_phylogenetic_tree(distance_matrix, sequence_names, output_dir, max_sequences=100):
    """Save an average-linkage dendrogram of (at most) the first ``max_sequences`` sequences.

    The leading ``max_sequences`` x ``max_sequences`` corner of
    ``distance_matrix`` is clustered with ``scipy.cluster.hierarchy.linkage``
    using ``method='average'`` (UPGMA). It is drawn as a dendrogram labelled
    with the matching ``sequence_names`` and written to
    ``<output_dir>/sequence_tree.png`` at 300 dpi. This is distance-based
    hierarchical clustering, not model-based phylogenetic inference.

    If fewer than two sequences are available, no tree can be built. A message
    is printed and no file is written.

    Args:
        distance_matrix: Square, symmetric distance array with a zero diagonal.
        sequence_names: Leaf labels, in the same order as the matrix rows.
        output_dir: Directory for the PNG; created if missing.
        max_sequences: Upper bound on the number of leaves drawn (default 100),
            which keeps the leaf labels legible.
    """
    os.makedirs(output_dir, exist_ok=True)
    n_leaves = min(max_sequences, len(distance_matrix))
    if n_leaves < 2:
        # linkage() raises on fewer than two observations.
        print("Tree visualization skipped: at least 2 sequences are required.")
        return

    subset_matrix = distance_matrix[:n_leaves, :n_leaves]
    subset_names = sequence_names[:n_leaves]
    # linkage() expects a condensed distance vector, not a square matrix.
    Z = hierarchy.linkage(squareform(subset_matrix), method='average')

    fig, ax = plt.subplots(figsize=(20, 5))
    try:
        hierarchy.dendrogram(Z, labels=subset_names, leaf_rotation=90, leaf_font_size=10, ax=ax)
        ax.set_title("Sequence Similarity Tree")
        ax.set_xlabel("Sequence ID")
        ax.set_ylabel("Distance")
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, 'sequence_tree.png'), dpi=300)
    finally:
        # Close even if saving fails, so repeated notebook runs don't leak figures.
        plt.close(fig)
    print("Tree visualization saved.")

def _plot_similarity_heatmap(distance_matrix, sequence_names, output_path, max_sequences=50):
    """Save a heatmap of ``1 - distance`` for the first ``max_sequences`` sequences to ``output_path``."""
    n_shown = min(max_sequences, len(distance_matrix))
    shown_names = list(sequence_names[:n_shown])
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        # Label axes with sequence IDs so cells can be traced back to designs.
        sns.heatmap(
            1 - distance_matrix[:n_shown, :n_shown],
            cmap="viridis",
            square=True,
            xticklabels=shown_names,
            yticklabels=shown_names,
            ax=ax,
        )
        ax.set_title("Sequence Similarity Matrix")
        fig.tight_layout()
        fig.savefig(output_path, dpi=300)
    finally:
        # Close even if saving fails, so repeated notebook runs don't leak figures.
        plt.close(fig)


def analyze_sequences(input_file, output_dir, *, heatmap_size=50):
    """Run complete sequence analysis pipeline.

    Steps:
      1. Extract designed sequences from ``input_file`` into
         ``<output_dir>/extracted_sequences.fasta``.
      2. One-hot encode them and compute the pairwise distance matrix.
      3. Save a ``1 - distance`` heatmap of the first ``heatmap_size``
         sequences to ``<output_dir>/similarity_heatmap.png``. Its scale
         depends on the metric used by ``compute_distance_matrix``.
      4. Save a dendrogram via ``create_phylogenetic_tree``.

    Args:
        input_file: Path to the RF_diffusion FASTA output.
        output_dir: Directory for all outputs; created if missing.
        heatmap_size: Maximum number of sequences shown in the heatmap.

    Returns:
        Tuple ``(sequences, distance_matrix)``: the extracted ``SeqRecord``
        list and the full square distance matrix.
    """
    os.makedirs(output_dir, exist_ok=True)
    extracted_file = os.path.join(output_dir, "extracted_sequences.fasta")
    sequences = extract_designed_sequences(input_file, extracted_file)
    sequence_list = [str(record.seq) for record in sequences]
    sequence_names = [record.id for record in sequences]

    print("Encoding sequences...")
    encoded_seqs = one_hot_encode(sequence_list)
    print("Computing distance matrix...")
    distance_matrix = compute_distance_matrix(encoded_seqs)

    print("Generating similarity heatmap...")
    _plot_similarity_heatmap(
        distance_matrix,
        sequence_names,
        os.path.join(output_dir, "similarity_heatmap.png"),
        max_sequences=heatmap_size,
    )

    print("Creating phylogenetic tree...")
    create_phylogenetic_tree(distance_matrix, sequence_names, output_dir)

    print("Analysis complete. Results saved in", output_dir)
    return sequences, distance_matrix

if __name__ == "__main__":
    # Paths below look like a Colab session with Google Drive mounted at
    # /content/drive; change them to analyse a different design run.
    output_dir = "/content/drive/MyDrive/Fasta-files/3NOB_90-110/analysis_output"
    input_file = "/content/drive/MyDrive/Fasta-files/3NOB_90-110/3NOB_90-110_design.fasta"
    analyze_sequences(input_file=input_file, output_dir=output_dir)


Extracted 2048 sequences, saved to: /content/drive/MyDrive/Fasta-files/3NOB_90-110/analysis_output/extracted_sequences.fasta
Encoding sequences...
Computing distance matrix...
Generating similarity heatmap...
Creating phylogenetic tree...
Tree visualization saved.
Analysis complete. Results saved in /content/drive/MyDrive/Fasta-files/3NOB_90-110/analysis_output
