In [18]:
import numpy as np
from typing import List, Dict
import RNA  # ViennaRNA Python bindings
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import matplotlib.ticker as ticker

color_palette = sns.color_palette("muted")

class RNASequenceEvaluator:
    def __init__(self, training_set: List[str]):
        self.training_set = set(training_set)

    def calculate_mfe_structure(self, sequence: str) -> Dict:
        """
        Fold ``sequence`` with ViennaRNA and report its minimum free energy
        (MFE) prediction.

        Returns a dict holding the predicted structure under "structure" and
        the corresponding energy under "mfe".
        """
        predicted_structure, free_energy = RNA.fold_compound(sequence).mfe()
        return dict(structure=predicted_structure, mfe=free_energy)

    def calculate_partition_function(self, sequence: str) -> tuple:
        """
        Calculate base pair probabilities using ViennaRNA partition function.

        Args:
            sequence: RNA sequence to fold.

        Returns:
            The base pair probability matrix exactly as produced by
            ``fold_compound.bpp()``.
            NOTE(review): for a 6-nt input this comes back as a 7x7 tuple of
            tuples, so the matrix is presumably 1-indexed with an unused
            row/column 0 -- confirm against the ViennaRNA documentation.
        """
        fc = RNA.fold_compound(sequence)
        # The partition function has to be computed before the pair
        # probabilities are queried.
        fc.pf()
        return fc.bpp()

    def analyze_structural_ensemble(self, sequence: str) -> Dict:
        """
        Measure how diverse the structural ensemble of ``sequence`` is.

        The partition function is computed with ViennaRNA and the mean base
        pair distance of the ensemble is returned under the
        "ensemble_diversity" key.
        """
        compound = RNA.fold_compound(sequence)
        compound.pf()  # must run before the ensemble statistic is queried
        mean_pair_distance = compound.mean_bp_distance()
        return {"ensemble_diversity": mean_pair_distance}

    def predict_tertiary_structure(self, sequence: str) -> Dict:
        """
        Mock tertiary structure prediction.

        No real 3D prediction happens here: the secondary structure MFE from
        ViennaRNA is shifted by a fixed -2.0 offset purely for demonstration,
        and the result is returned under "tertiary_structure_energy".
        """
        # Only the energy is needed; the predicted structure is discarded.
        _, secondary_mfe = RNA.fold_compound(sequence).mfe()
        return {"tertiary_structure_energy": secondary_mfe - 2.0}

    def assess_validity(self, sequence: str) -> bool:
        """
        Check if the sequence is chemically valid.
        """
        return all(base in "ACGU" for base in sequence)

    def calculate_novelty(self, sequences: List[str]) -> float:
        """
        Calculate the novelty of the generated sequences.
        """
        novel_count = sum(1 for seq in sequences if seq not in self.training_set)
        return novel_count / len(sequences)

    def calculate_external_diversity(self, generated_sequences: List[str]) -> float:
        """
        Calculate external diversity based on the Tanimoto distance.
        """
        def tanimoto_distance(x: str, y: str) -> float:
            set_x = set(x)
            set_y = set(y)
            intersection = len(set_x & set_y)
            union = len(set_x | set_y)
            return 1 - intersection / union if union != 0 else 0

        total_distance = 0.0
        comparisons = 0

        for x in generated_sequences:
            for y in self.training_set:
                total_distance += tanimoto_distance(x, y)
                comparisons += 1

        return total_distance / comparisons if comparisons > 0 else 0.0

    def calculate_perplexity(self, probabilities: List[float]) -> float:
        """
        Calculate perplexity given probabilities.
        """
        entropy = -sum(p * np.log2(p) for p in probabilities if p > 0)
        return 2 ** entropy

    def calculate_foldability(self, plDDT_scores: List[float]) -> float:
        """
        Calculate foldability based on pLDDT scores.
        """
        return np.mean(plDDT_scores)

    def evaluate_sequence(self, sequence: str) -> Dict:
        """
        Run every per-sequence analysis on ``sequence`` and collect the
        outcomes in one dict.

        Keys, in order: "mfe_structure", "partition_function",
        "structural_ensemble", "tertiary_structure" and "validity".
        """
        report = {}
        report["mfe_structure"] = self.calculate_mfe_structure(sequence)
        report["partition_function"] = self.calculate_partition_function(sequence)
        report["structural_ensemble"] = self.analyze_structural_ensemble(sequence)
        report["tertiary_structure"] = self.predict_tertiary_structure(sequence)
        report["validity"] = self.assess_validity(sequence)
        return report

    def evaluate_sequences(self, sequences: List[str]) -> Dict:
        """
        Evaluate a whole collection of RNA sequences.

        Returns a dict with the per-sequence reports under "evaluations" and
        the set-level scores under "novelty" and "external_diversity".
        """
        per_sequence_reports = []
        for seq in sequences:
            per_sequence_reports.append(self.evaluate_sequence(seq))

        return {
            "evaluations": per_sequence_reports,
            "novelty": self.calculate_novelty(sequences),
            "external_diversity": self.calculate_external_diversity(sequences),
        }

    def compare_metrics_distributions(self, rna_lists: List[List[str]], labels: List[str]) -> None:
        """
        Plot how the "mfe" and "ensemble_diversity" metrics are distributed
        across several groups of RNA sequences.

        ``rna_lists`` and ``labels`` are paired positionally. For each metric
        five images are written to the working directory, named
        ``<plot>_<metric>.png`` with <plot> being violin, bar, ridge, density
        and box.
        """
        metrics = ["mfe", "ensemble_diversity"]
        results = {label: {metric: [] for metric in metrics} for label in labels}

        # Collect the metric values of every sequence, grouped by label.
        for label, sequences in zip(labels, rna_lists):
            collected = results[label]
            for seq in sequences:
                evaluation = self.evaluate_sequence(seq)
                collected["mfe"].append(evaluation["mfe_structure"]["mfe"])
                collected["ensemble_diversity"].append(
                    evaluation["structural_ensemble"]["ensemble_diversity"]
                )

        # File-name prefix and plotting routine, in the order the plots are
        # produced.
        plotters = (
            ("violin", self.plot_violin_compare),
            ("bar", self.plot_lost_data_bar),
            ("ridge", self.plot_ridge_compare),
            ("density", self.plot_density_compare),
            ("box", self.plot_box_compare),
        )

        for metric in metrics:
            data_arrays = [results[label][metric] for label in labels]
            for prefix, draw in plotters:
                draw(data_arrays, labels, metric, f"{prefix}_{metric}.png")

    def plot_lost_data_bar(self, data_arrays, labels, ylabel, filename):
        """
        Save a bar chart with one bar per label whose height is the sum of
        that label's data.

        The figure is written to ``filename`` and then closed.
        """
        # Same pairing as zipping labels with data; a repeated label keeps
        # the total of its last occurrence.
        totals_by_label = dict(zip(labels, map(sum, data_arrays)))

        plt.figure()
        plt.bar(totals_by_label.keys(), totals_by_label.values())
        plt.title("Lost Data Comparison")
        plt.xlabel("RNA Groups")
        plt.ylabel(ylabel)
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()

    def plot_violin_compare(self, data_arrays, labels, ylabel, filename):
        """
        Save a violin plot comparing the distributions in ``data_arrays``.

        Args:
            data_arrays: One sequence of values per group.
            labels: Tick label for each group, in the same order.
            ylabel: Label of the value axis.
            filename: Path the figure is written to.
        """
        plt.figure()
        # ``fliersize`` and ``whis`` are box-plot options that violinplot
        # does not define, so they are no longer passed here.
        # NOTE(review): depending on the seaborn version they were either
        # ignored or forwarded to matplotlib -- confirm against the
        # installed version.
        sns.violinplot(
            data=data_arrays,
            inner='box',
            palette=color_palette,
            width=0.6,
            linewidth=1.5,
        )
        plt.xticks(range(len(data_arrays)), labels)
        plt.ylabel(ylabel)
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()

    def plot_ridge_compare(self, data_arrays, labels, xlabel, filename):
        """
        Save a ridge-style plot: one filled KDE per group, stacked in rows.

        Args:
            data_arrays: One sequence of values per group; the groups may
                differ in length.
            labels: Name of each group, in the same order.
            xlabel: Label of the value axis.
            filename: Path the figure is written to.
        """
        # Build the long-form table directly. Going through a wide
        # DataFrame (one column per group) fails when the groups do not
        # all have the same number of values.
        categories = []
        values = []
        for label, data in zip(labels, data_arrays):
            categories.extend([label] * len(data))
            values.extend(data)
        df_long = pd.DataFrame({'Category': categories, 'Value': values})

        g = sns.FacetGrid(df_long, row="Category", hue="Category", aspect=3, height=1.5,
                          palette=color_palette,)
        g.map(sns.kdeplot, "Value", fill=True, alpha=0.6)
        g.set_axis_labels(xlabel, "Density")
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()

    def plot_density_compare(self, data_arrays, labels, xlabel, filename):
        """
        Save overlaid, filled KDE curves for all groups on one set of axes.

        Args:
            data_arrays: One sequence of values per group.
            labels: Legend entry for each group, in the same order.
            xlabel: Label of the value axis.
            filename: Path the figure is written to.
        """
        plt.figure()
        for data, label in zip(data_arrays, labels):
            sns.kdeplot(data, fill=True, label=label, alpha=0.3)
        plt.xlabel(xlabel, labelpad=20)
        # Only add a legend when something labelled was actually drawn;
        # calling legend() on empty axes just emits a "No artists with
        # labels found" warning.
        handles, _ = plt.gca().get_legend_handles_labels()
        if handles:
            plt.legend()
        plt.tight_layout()
        plt.savefig(filename)
        plt.close()

    def plot_box_compare(self, data_arrays, labels, ylabel, filename):
        """
        Save a box plot with one box per group.

        Groups are labelled on the x axis with ``labels``, the value axis
        gets dashed horizontal grid lines, and the figure is written to
        ``filename`` and then closed.
        """
        plt.figure()
        box_style = dict(
            palette=color_palette,
            width=0.6,
            linewidth=1.5,
            fliersize=4,
            whis=1.5,
        )
        sns.boxplot(data=data_arrays, **box_style)

        tick_positions = range(len(data_arrays))
        plt.xticks(tick_positions, labels)
        plt.ylabel(ylabel)
        plt.grid(True, axis='y', linestyle='--', linewidth=0.6, alpha=0.7)

        # Tick locations on the value axis are not restricted to integers.
        value_axis = plt.gca().yaxis
        value_axis.set_major_locator(ticker.MaxNLocator(integer=False))

        plt.tight_layout()
        plt.savefig(filename)
        plt.close()

if __name__ == "__main__":
    # Small demo: two groups of short sequences are compared against a
    # two-sequence training set.
    training_set = ["ACGUGA", "GCUAUGC"]
    rna_sequences_group_1 = ["ACGUGA", "UCGGAC", "GGCCAU"]
    rna_sequences_group_2 = ["UGCAGU", "CCAUUG", "GAUCGU"]

    evaluator = RNASequenceEvaluator(training_set)

    # Writes the comparison plots for every metric to the working directory.
    rna_groups = [rna_sequences_group_1, rna_sequences_group_2]
    group_labels = ["Group 1", "Group 2"]
    evaluator.compare_metrics_distributions(rna_groups, labels=group_labels)

    # Full evaluation report for the first group only.
    evaluate_results = evaluator.evaluate_sequences(rna_sequences_group_1)
    print(evaluate_results)


  with pd.option_context('mode.use_inf_as_na', True):
  func(*plot_args, **plot_kwargs)
  with pd.option_context('mode.use_inf_as_na', True):
  func(*plot_args, **plot_kwargs)
  with pd.option_context('mode.use_inf_as_na', True):
  sns.kdeplot(data, fill=True, label=label, alpha=0.3)
  with pd.option_context('mode.use_inf_as_na', True):
  sns.kdeplot(data, fill=True, label=label, alpha=0.3)
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):


{'evaluations': [{'mfe_structure': {'structure': '......', 'mfe': 0.0}, 'partition_function': ((0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)), 'structural_ensemble': {'ensemble_diversity': 0.0}, 'tertiary_structure': {'tertiary_structure_energy': -2.0}, 'validity': True}, {'mfe_structure': {'structure': '......', 'mfe': 0.0}, 'partition_function': ((0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 6.958186003773551e-05, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)), 'structural_ensemble': {'ensemble_diversity': 0.0001391540368049784}, 'tertiary_structure': {'tertiary_structure_energy': -2.0}, 'validity': True}, {'mfe_struc