In [None]:
import ast
import pandas as pd
import numpy as np

In [None]:
# Root of the na_mpnn project checkout. Every CSV read below lives under this
# directory, so pointing the notebook at another checkout is a one-line change.
NA_MPNN_ROOT = "/home/akubaney/projects/na_mpnn"
DATASET_DIR = f"{NA_MPNN_ROOT}/data/datasets"
EVAL_SUMMARY_DIR = f"{NA_MPNN_ROOT}/evaluation/evaluation_summaries"
SPLIT_NAMES = ("train", "valid", "test")

# Design dataset splits, keyed by split name.
design_dfs = {
    split_name: pd.read_csv(f"{DATASET_DIR}/design_dataset_v2/{split_name}.csv")
    for split_name in SPLIT_NAMES
}

# Specificity dataset splits, keyed by split name.
specificity_dfs = {
    split_name: pd.read_csv(f"{DATASET_DIR}/specificity_dataset_v2/{split_name}.csv")
    for split_name in SPLIT_NAMES
}

# Evaluation summaries, keyed by summary name; each is stored on disk as
# "<summary name>_plot.csv".
eval_dfs = {
    summary_name: pd.read_csv(f"{EVAL_SUMMARY_DIR}/{summary_name}_plot.csv")
    for summary_name in (
        "design_valid",
        "design_test",
        "design_rna_monomer_test",
        "design_pseudoknot_test",
        "specificity_valid",
        "specificity_test",
    )
}

In [None]:
def get_train_valid_test_entry_statistics(
    split_dfs,
    polymer_type_stats = True,
    ppm_stats = True
):
    """Print per-split entry counts for a collection of dataset splits.

    Args:
        split_dfs (dict): maps a split name to a DataFrame. Depending on the
            requested reports the DataFrame must provide these columns:
            - polymer_type_stats: "protein_chain_cluster_ids" and
              "nucleic_acid_chain_cluster_ids_chain_types", each holding the
              string form of a Python literal (parsed with
              ast.literal_eval).
            - ppm_stats: "ppm_paths" (string form of a Python literal) and
              "dataset_name".
            Columns belonging to a report that is switched off are not read.
        polymer_type_stats (bool): if True, print the number of DNA, RNA,
            DNA/RNA hybrid and multi-nucleic-acid-type entries, each split
            into entries without and with protein.
        ppm_stats (bool): if True, print the number of entries with and
            without PPMs, broken down by the dataset the entry comes from.

    Returns:
        None. All results are written to stdout.
    """
    # (chain type string found in the data, label used in the report).
    nucleic_acid_labels = (
        ("polydeoxyribonucleotide", "DNA"),
        ("polyribonucleotide", "RNA"),
        ("polydeoxyribonucleotide/polyribonucleotide hybrid", "DNA/RNA hybrid"),
    )
    multi_label = "Multi-nucleic acid types"
    polymer_labels = [label for _, label in nucleic_acid_labels] + [multi_label]

    # (dataset_name value found in the data, label used in the report).
    ppm_source_labels = (
        ("rcsb_cif_na", "From crystal"),
        ("rf2na_distillation_cis_bp", "From distillation (CisBP)"),
        ("rf2na_distillation_transfac", "From distillation (TRANSFAC)"),
    )

    for split_name, split_df in split_dfs.items():
        # Entries per (polymer label, protein present?) pair.
        polymer_counts = {
            (label, has_protein): 0
            for label in polymer_labels
            for has_protein in (False, True)
        }
        if polymer_type_stats:
            for protein_ids_str, chain_types_str in zip(
                split_df["protein_chain_cluster_ids"],
                split_df["nucleic_acid_chain_cluster_ids_chain_types"]
            ):
                # Both columns store list literals as text.
                protein_ids = ast.literal_eval(protein_ids_str)
                chain_types = ast.literal_eval(chain_types_str)

                present_labels = [
                    label
                    for chain_type, label in nucleic_acid_labels
                    if chain_type in chain_types
                ]
                # Exactly one nucleic acid type present -> that type's
                # bucket. Anything else (several types, or none of the three
                # recognized types) is counted in the multi-type bucket.
                if len(present_labels) == 1:
                    label = present_labels[0]
                else:
                    label = multi_label
                polymer_counts[(label, len(protein_ids) > 0)] += 1

        # Entries with / without PPMs, overall and per recognized dataset.
        ppm_totals = {True: 0, False: 0}
        ppm_by_source = {
            (has_ppm, dataset_name): 0
            for has_ppm in (True, False)
            for dataset_name, _ in ppm_source_labels
        }
        if ppm_stats:
            for ppm_paths_str, dataset_name in zip(
                split_df["ppm_paths"],
                split_df["dataset_name"]
            ):
                has_ppm = len(ast.literal_eval(ppm_paths_str)) > 0
                ppm_totals[has_ppm] += 1
                # Entries from a dataset outside the three known ones still
                # count towards the totals but have no per-source line.
                if (has_ppm, dataset_name) in ppm_by_source:
                    ppm_by_source[(has_ppm, dataset_name)] += 1

        print(f"{split_name}")
        if polymer_type_stats:
            for label in polymer_labels:
                print(f"\t{label} entries: {polymer_counts[(label, False)]}")
            for label in polymer_labels:
                print(f"\t{label}-protein entries: {polymer_counts[(label, True)]}")
        if ppm_stats:
            for has_ppm, heading in (
                (True, "Entries with PPMs"),
                (False, "Entries without PPMs"),
            ):
                print(f"\t{heading}: {ppm_totals[has_ppm]}")
                for dataset_name, source_label in ppm_source_labels:
                    print(f"\t\t{source_label}: {ppm_by_source[(has_ppm, dataset_name)]}")

In [None]:
# Design dataset: report only the polymer-type breakdown of each split.
get_train_valid_test_entry_statistics(
    design_dfs,
    polymer_type_stats=True,
    ppm_stats=False,
)

In [None]:
# Specificity dataset: report only the PPM breakdown of each split.
get_train_valid_test_entry_statistics(
    specificity_dfs,
    polymer_type_stats=False,
    ppm_stats=True,
)

In [None]:
def get_evaluation_statistics(
    eval_df_name,
    eval_df
):
    """Print input, model and run counts for one evaluation summary.

    The kind of summary is taken from ``eval_df_name``: a name containing
    "design" is reported with 10 runs and a per-polymer-type breakdown; a name
    containing "specificity" is reported with 1 run and a breakdown by PPM
    source. "design" is checked first.

    Args:
        eval_df_name (str): name of the summary; must contain "design" or
            "specificity".
        eval_df (pd.DataFrame): summary rows. Must provide the columns
            "structure_path", "Model" and "Group"; specificity summaries
            must also provide "dataset_name".

    Returns:
        None. All results are written to stdout.

    Raises:
        ValueError: if ``eval_df_name`` contains neither "design" nor
            "specificity".
        AssertionError: if the number of rows is not
            (unique inputs) * (unique models) * (runs).
    """
    print(f"{eval_df_name}")

    num_inputs = np.unique(eval_df["structure_path"]).shape[0]
    num_models = np.unique(eval_df["Model"]).shape[0]

    if "design" in eval_df_name:
        num_runs = 10
        polymer_type_stats = True
        ppm_stats = False
    elif "specificity" in eval_df_name:
        num_runs = 1
        polymer_type_stats = False
        ppm_stats = True
    else:
        # Without this branch the run count would be undefined and the
        # function would fail later with an unrelated-looking error.
        raise ValueError(
            f"Cannot infer evaluation type from name {eval_df_name!r}; "
            "expected it to contain 'design' or 'specificity'."
        )

    expected_rows = num_inputs * num_models * num_runs
    assert len(eval_df) == expected_rows, (
        f"{eval_df_name}: expected {expected_rows} rows "
        f"({num_inputs} inputs * {num_models} models * {num_runs} runs), "
        f"found {len(eval_df)}"
    )

    print(f"\tNumber of input structures: {num_inputs}")
    print(f"\tNumber of models: {num_models}")
    print(f"\tNumber of runs: {num_runs}")
    print()

    rows_per_input = num_models * num_runs

    def count_inputs(row_mask):
        # Convert a number of matching rows into a number of inputs by
        # dividing by the rows each input contributes (models * runs). An
        # empty summary has no models, so report zero instead of dividing by
        # zero.
        if rows_per_input == 0:
            return 0
        return int(row_mask.sum()) // rows_per_input

    if polymer_type_stats:
        num_dna = count_inputs(eval_df["Group"] == "DNA")
        num_rna = count_inputs(eval_df["Group"] == "RNA")
        num_dna_protein = count_inputs(eval_df["Group"] == "DNA (protein context)")
        num_rna_protein = count_inputs(eval_df["Group"] == "RNA (protein context)")

        print(f"\tNumber of DNA inputs: {num_dna}")
        print(f"\tNumber of RNA inputs: {num_rna}")
        print(f"\tNumber of DNA-protein inputs: {num_dna_protein}")
        print(f"\tNumber of RNA-protein inputs: {num_rna_protein}")

    if ppm_stats:
        is_distillation = eval_df["Group"] == "Distillation"
        num_distillation = count_inputs(is_distillation)
        num_distillation_cis_bp = count_inputs(
            is_distillation & (eval_df["dataset_name"] == "rf2na_distillation_cis_bp")
        )
        num_distillation_transfac = count_inputs(
            is_distillation & (eval_df["dataset_name"] == "rf2na_distillation_transfac")
        )
        num_crystal = count_inputs(eval_df["Group"] == "Crystal")

        print(f"\tNumber of inputs with PPMs from distillation: {num_distillation}")
        print(f"\t\tFrom distillation (CisBP): {num_distillation_cis_bp}")
        print(f"\t\tFrom distillation (TRANSFAC): {num_distillation_transfac}")
        print(f"\tNumber of inputs with PPMs from crystal: {num_crystal}")

In [None]:
# Print the statistics of every loaded evaluation summary, in load order.
for eval_df_name, eval_df in eval_dfs.items():
    get_evaluation_statistics(eval_df_name=eval_df_name, eval_df=eval_df)

In [None]:
# Show the specificity test evaluation summary as the cell output for manual
# inspection of its rows and columns.
eval_dfs["specificity_test"]