# Create Final Datasets for Training

## Setup

In [None]:
import os
import shutil
import ast

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

## Paths and Directories

In [None]:
# Root directory holding every per-source dataset and the aggregated outputs.
all_datasets_directory = "./datasets"

# Per-source datasets (each a subdirectory of all_datasets_directory) that are
# combined into the final datasets.
datasets_to_load = [
    "rcsb_cif_na",
    "rf2na_distillation_cis_bp",
    "rf2na_distillation_transfac",
]

# Outputs of the chain clustering (protein and nucleic acid) and protein
# family labeling steps.
protein_chain_clustering_path = (
    "./clustering/all_protein_sequences/all_protein_sequences_clusters.csv"
)
nucleic_acid_chain_clustering_path = (
    "./clustering/all_nucleic_acid_sequences/all_nucleic_acid_sequences_clusters.csv"
)
protein_family_labeling_path = "./protein_family_labeling/all_protein_family_labels.csv"

# Aggregate of all datasets: its output directory and combined CSV file.
all_datasets_output_directory = os.path.join(all_datasets_directory, "all_datasets")
all_datasets_output_path = os.path.join(all_datasets_output_directory, "all_datasets.csv")

# Columns whose cells hold lists.
list_columns = [
    "ppm_paths",
    "protein_chain_cluster_ids",
    "protein_chain_cluster_ids_chain_types",
    "nucleic_acid_chain_cluster_ids",
    "nucleic_acid_chain_cluster_ids_chain_types",
    "pfam_ids",
    "pfam_descriptions",
    "interpro_ids",
    "interpro_descriptions",
]

# Output directories for the design and specificity datasets.
design_dataset_output_directory = os.path.join(all_datasets_directory, "design_dataset_v2")
specificity_dataset_output_directory = os.path.join(all_datasets_directory, "specificity_dataset_v2")

# Helper Functions

In [None]:
def compute_chain_cluster_degrees(data_df, chain_cluster_ids_column_name):
    """
    Count how often each chain cluster id occurs across the entries of data_df
    and attach the per-chain counts as a new column.

    Every occurrence is counted, so an id repeated within a single entry
    contributes once per repetition.

    Arguments:
        data_df (pd.DataFrame):
            Dataset whose chain_cluster_ids_column_name cells are iterables of
            chain cluster ids. Modified in place.
        chain_cluster_ids_column_name (str):
            Name of the cluster id column. The new column's name is this name
            with every "ids" replaced by "degrees".

    Returns:
        dict: chain cluster id -> degree, keyed in order of first occurrence.
    """
    degree_by_cluster = {}
    for entry_cluster_ids in data_df[chain_cluster_ids_column_name]:
        for cluster_id in entry_cluster_ids:
            if cluster_id in degree_by_cluster:
                degree_by_cluster[cluster_id] += 1
            else:
                degree_by_cluster[cluster_id] = 1

    def _lookup_degrees(entry_cluster_ids):
        # Degree of each chain in the entry, aligned with its cluster ids.
        return [degree_by_cluster[cluster_id] for cluster_id in entry_cluster_ids]

    degrees_column_name = chain_cluster_ids_column_name.replace("ids", "degrees")
    data_df[degrees_column_name] = data_df[chain_cluster_ids_column_name].apply(_lookup_degrees)

    return degree_by_cluster

In [None]:
def read_text_file(path):
    """Return the entire contents of the text file at path."""
    with open(path, mode = "rt") as handle:
        contents = handle.read()
    return contents

def write_text_file(path, content):
    """Write content to the text file at path, replacing any existing contents."""
    with open(path, mode = "wt") as handle:
        handle.write(content)

def read_seed(path):
    """Read an integer seed (e.g. the seed.txt written by save_train_valid_test_split) from path."""
    seed_text = read_text_file(path)
    return int(seed_text)

In [None]:
def split_train_valid_test_clusters(chain_cluster_to_degree,
                                    valid_fraction,
                                    test_fraction,
                                    max_valid_test_cluster_degree,
                                    extra_test_cluster_ids,
                                    seed):
    """
    Randomly partition chain cluster ids into train, validation, and test sets.

    Test and validation clusters are sampled without replacement from the
    eligible clusters. A cluster is eligible if its degree is
    <= max_valid_test_cluster_degree (every cluster qualifies when that is
    None) and it is not in extra_test_cluster_ids.

    With n = len(chain_cluster_to_degree), int(test_fraction * n) clusters are
    drawn for test first. Then int(valid_fraction * n) of the remaining
    eligible clusters are drawn for validation. extra_test_cluster_ids are
    added to the test set afterwards, and every other cluster becomes a train
    cluster.

    Arguments:
        chain_cluster_to_degree ((cluster id -> int) dict):
            Degree of every chain cluster (e.g. the dict returned by
            compute_chain_cluster_degrees). Its keys are the full set of
            clusters to split.
        valid_fraction (float):
            Number of validation clusters as a fraction of all clusters.
        test_fraction (float):
            Number of sampled test clusters as a fraction of all clusters.
            extra_test_cluster_ids come on top of this.
        max_valid_test_cluster_degree (int or None):
            Maximum degree for a cluster to be eligible for validation/test.
        extra_test_cluster_ids (list or None):
            Cluster ids forced into the test set. They must be keys of
            chain_cluster_to_degree, otherwise the final coverage assert fails.
        seed (int):
            Seed passed to np.random.default_rng.

    Returns:
        (train_cluster_ids, valid_cluster_ids, test_cluster_ids):
            train_cluster_ids is a list. valid_cluster_ids is a numpy array.
            test_cluster_ids is a numpy array, or a list when
            extra_test_cluster_ids is given.

    Raises:
        ZeroDivisionError: if chain_cluster_to_degree is empty.
        AssertionError: if the eligible clusters are fewer than
            (test_fraction + valid_fraction) of all clusters, or if the
            resulting sets overlap or do not cover every cluster.
    """
    # Create the rng.
    rng = np.random.default_rng(seed)

    all_cluster_ids = list(chain_cluster_to_degree.keys())
    # Collect the clusters eligible for validation/test, optionally limited to
    # clusters whose degree does not exceed max_valid_test_cluster_degree.
    if max_valid_test_cluster_degree is None:
        cluster_ids_for_valid_test = all_cluster_ids
    else:
        cluster_ids_for_valid_test = [cluster_id for cluster_id in all_cluster_ids if chain_cluster_to_degree[cluster_id] <= max_valid_test_cluster_degree]
    
    # Remove the extra test cluster ids from the valid/test cluster ids.
    # NOTE(review): here, and for the validation pool below, candidates pass
    # through set(), so the order handed to rng.choice follows set iteration
    # order. For string cluster ids that order depends on PYTHONHASHSEED, so
    # the same seed may not reproduce the same split across interpreter runs.
    # Confirm the cluster id type if exact reproducibility from seed.txt matters.
    if extra_test_cluster_ids is not None:
        cluster_ids_for_valid_test = list(set(cluster_ids_for_valid_test) - set(extra_test_cluster_ids))

    # Ensure there are enough clusters for the valid and test sets with the
    # appropriate degree.
    assert((len(cluster_ids_for_valid_test) / len(all_cluster_ids)) >= (test_fraction + valid_fraction))

    # Choose the test and valid cluster ids. Both sizes are fractions of ALL
    # clusters, but they are drawn only from the eligible ones. Validation is
    # drawn from the eligible clusters not already chosen for test.
    test_cluster_ids = rng.choice(list(cluster_ids_for_valid_test), 
                                  size = int(test_fraction * len(all_cluster_ids)), 
                                  replace = False)
    valid_cluster_ids = rng.choice(list(set(cluster_ids_for_valid_test) - set(test_cluster_ids)), 
                                   size = int(valid_fraction * len(all_cluster_ids)), 
                                   replace = False)

    # Add the extra test cluster ids to the test set.
    if extra_test_cluster_ids is not None:
        test_cluster_ids = list(set(test_cluster_ids).union(set(extra_test_cluster_ids)))
    
    # The train clusters are the rest of the clusters.
    train_cluster_ids = list(set(all_cluster_ids) - set(valid_cluster_ids) - set(test_cluster_ids))

    # Check that the train, validation, and test clusters are disjoint.
    assert(len(set(train_cluster_ids).intersection(set(valid_cluster_ids))) == 0)
    assert(len(set(train_cluster_ids).intersection(set(test_cluster_ids))) == 0)
    assert(len(set(valid_cluster_ids).intersection(set(test_cluster_ids))) == 0)

    # Assert that nothing is left out from all. This also fails if
    # extra_test_cluster_ids contains ids that are not in
    # chain_cluster_to_degree, because the union is then larger than all.
    assert(len(set(train_cluster_ids).union(set(valid_cluster_ids)).union(set(test_cluster_ids))) == len(all_cluster_ids))

    print(f"clusters: {len(all_cluster_ids)}")
    print(f"train clusters: {len(train_cluster_ids)}")
    print(f"validation clusters: {len(valid_cluster_ids)}")
    print(f"test clusters: {len(test_cluster_ids)}")
    print()

    return train_cluster_ids, valid_cluster_ids, test_cluster_ids

def split_train_valid_test_entries(data_df, 
                                   chain_cluster_ids_column_name,
                                   train_cluster_ids,
                                   valid_cluster_ids,
                                   test_cluster_ids):
    """
    Assign every entry (row) of data_df to the train, validation, or test set
    according to the chain clusters it contains.

    An entry goes to test if any of its chain clusters is a test cluster.
    Otherwise it goes to validation if any of its clusters is a validation
    cluster, and otherwise to train. Consequently, train entries contain no
    test or validation clusters, and validation entries contain no test
    clusters. Test and validation entries may, however, also contain chains
    from clusters assigned to other splits.

    Arguments:
        data_df (pd.DataFrame):
            Dataset with a unique "id" column and a cluster id column.
        chain_cluster_ids_column_name (str):
            Column whose cells are iterables of chain cluster ids.
        train_cluster_ids (iterable):
            Train cluster ids. Not consulted: any entry without test or
            validation clusters becomes a train entry.
        valid_cluster_ids (iterable):
            Validation cluster ids.
        test_cluster_ids (iterable):
            Test cluster ids.

    Returns:
        (train_data_df, valid_data_df, test_data_df): row subsets of data_df,
        each in the original row order.

    Raises:
        AssertionError: if entry ids are not unique, or if a consistency check
            on the resulting split fails.
    """
    # Assert all entry ids in the dataset are unique.
    assert(len(set(data_df["id"])) == len(data_df))

    # Build the cluster lookup sets once. They are reused for every entry and
    # for the consistency checks below, instead of being rebuilt per row.
    test_cluster_ids_set = set(test_cluster_ids)
    valid_cluster_ids_set = set(valid_cluster_ids)

    # Get the dataset entries for each set. Test takes priority over
    # validation, which takes priority over train.
    test_entry_ids = []
    valid_entry_ids = []
    train_entry_ids = []
    for entry_id, chain_cluster_ids in zip(data_df["id"], data_df[chain_cluster_ids_column_name]):
        if not test_cluster_ids_set.isdisjoint(chain_cluster_ids):
            test_entry_ids.append(entry_id)
        elif not valid_cluster_ids_set.isdisjoint(chain_cluster_ids):
            valid_entry_ids.append(entry_id)
        else:
            train_entry_ids.append(entry_id)
    
    # Check that the train, validation, and test entries are disjoint.
    assert(len(set(train_entry_ids).intersection(set(valid_entry_ids))) == 0)
    assert(len(set(train_entry_ids).intersection(set(test_entry_ids))) == 0)
    assert(len(set(valid_entry_ids).intersection(set(test_entry_ids))) == 0)

    # Assert that nothing is left out from all.
    assert(len(set(train_entry_ids).union(set(valid_entry_ids)).union(set(test_entry_ids))) == len(data_df))

    # Split the data into train, validation, and test sets.
    train_data_df = data_df[data_df["id"].isin(train_entry_ids)]
    valid_data_df = data_df[data_df["id"].isin(valid_entry_ids)]
    test_data_df = data_df[data_df["id"].isin(test_entry_ids)]

    # Assert that the train and valid dataframes don't contain any chains from
    # the test set, and that the train dataframe doesn't contain any chains from
    # the validation set.
    for chain_cluster_ids in train_data_df[chain_cluster_ids_column_name]:
        assert(test_cluster_ids_set.isdisjoint(chain_cluster_ids))
        assert(valid_cluster_ids_set.isdisjoint(chain_cluster_ids))
    for chain_cluster_ids in valid_data_df[chain_cluster_ids_column_name]:
        assert(test_cluster_ids_set.isdisjoint(chain_cluster_ids))

    print(f"entries: {len(data_df)}")
    print(f"train entries: {len(train_data_df)}")
    print(f"validation entries: {len(valid_data_df)}")
    print(f"test entries: {len(test_data_df)}")
    print()

    return train_data_df, valid_data_df, test_data_df

# Nucleic acid chain types reported by evaluate_train_valid_test_split, paired
# with the short key used for their per-split entry counts.
_NUCLEIC_ACID_CHAIN_TYPE_KEYS = (
    ("dna", "polydeoxyribonucleotide"),
    ("rna", "polyribonucleotide"),
    ("dna_rna", "polydeoxyribonucleotide/polyribonucleotide hybrid"),
)


def _dataset_source(dataset_name, crystal_dataset_names, distillation_dataset_names):
    """Return "crystal" or "distill" for a known dataset name, otherwise None."""
    if dataset_name in crystal_dataset_names:
        return "crystal"
    if dataset_name in distillation_dataset_names:
        return "distill"
    return None


def _increment_label_counts(counts_by_source, source, label):
    """Add one to label's count under "all" and, when source is not None, under source."""
    counts_by_source["all"][label] = counts_by_source["all"].get(label, 0) + 1
    if source is not None:
        counts_by_source[source][label] = counts_by_source[source].get(label, 0) + 1


def evaluate_train_valid_test_split(split_dfs,
                                    crystal_dataset_names = ("rcsb_cif_na",),
                                    distillation_dataset_names = ("rf2na_distillation_cis_bp",
                                                                  "rf2na_distillation_transfac")):
    """
    Print summary statistics for each split.

    For every split this reports:
      * the number of entries containing DNA, RNA, or DNA/RNA hybrid chains,
        separately for entries without protein chains (printed as e.g.
        "DNA entries") and with protein chains. An entry with several nucleic
        acid types is counted once per type;
      * the number of entries with at least one PPM, overall and per data
        source (crystal vs. distillation);
      * for entries with PPMs only, how many entries carry each PFAM id and
        PFAM description, overall and per data source.

    Arguments:
        split_dfs ((str -> pd.DataFrame) dict):
            Split name to dataframe. Each dataframe needs the sequence-valued
            columns "protein_chain_cluster_ids",
            "nucleic_acid_chain_cluster_ids_chain_types", "ppm_paths",
            "pfam_ids", "pfam_descriptions", plus a "dataset_name" column.
        crystal_dataset_names (collection of str):
            Dataset names whose entries count as "from crystal".
        distillation_dataset_names (collection of str):
            Dataset names whose entries count as "from distillation". Entries
            from any other dataset only contribute to the overall totals.

    Side Effects:
        Prints the statistics to stdout.
    """
    for split_name, split_df in split_dfs.items():
        # Entries containing each nucleic acid type, split by whether the
        # complex also contains protein chains.
        entries_without_protein = {key: 0 for key, _ in _NUCLEIC_ACID_CHAIN_TYPE_KEYS}
        entries_with_protein = {key: 0 for key, _ in _NUCLEIC_ACID_CHAIN_TYPE_KEYS}
        # Entries with PPMs, overall and per data source.
        entries_with_ppm = {"all": 0, "crystal": 0, "distill": 0}
        # Per-label entry counts, restricted to entries with PPMs.
        pfam_id_counts = {"all": {}, "crystal": {}, "distill": {}}
        pfam_description_counts = {"all": {}, "crystal": {}, "distill": {}}

        for (protein_chain_cluster_ids,
             na_chain_types,
             ppm_paths,
             dataset_name,
             pfam_ids,
             pfam_descriptions) in zip(
            split_df["protein_chain_cluster_ids"],
            split_df["nucleic_acid_chain_cluster_ids_chain_types"],
            split_df["ppm_paths"],
            split_df["dataset_name"],
            split_df["pfam_ids"],
            split_df["pfam_descriptions"]
        ):
            # Count the entry once for each nucleic acid type it contains.
            na_counts = entries_with_protein if len(protein_chain_cluster_ids) > 0 else entries_without_protein
            for key, chain_type in _NUCLEIC_ACID_CHAIN_TYPE_KEYS:
                if chain_type in na_chain_types:
                    na_counts[key] += 1

            # The remaining statistics only consider entries with PPMs.
            if len(ppm_paths) == 0:
                continue
            source = _dataset_source(dataset_name, crystal_dataset_names, distillation_dataset_names)
            entries_with_ppm["all"] += 1
            if source is not None:
                entries_with_ppm[source] += 1

            # Count each pfam id and description at most once per entry,
            # because there may not be a one-to-one correspondence between
            # chains and PPMs.
            for pfam_id in set(pfam_ids):
                _increment_label_counts(pfam_id_counts, source, pfam_id)
            for pfam_description in set(pfam_descriptions):
                _increment_label_counts(pfam_description_counts, source, pfam_description)

        print(f"{split_name}")
        print(f"\tDNA entries: {entries_without_protein['dna']}")
        print(f"\tRNA entries: {entries_without_protein['rna']}")
        print(f"\tDNA/RNA hybrid entries: {entries_without_protein['dna_rna']}")
        print(f"\tDNA entries with protein in the complex: {entries_with_protein['dna']}")
        print(f"\tRNA entries with protein in the complex: {entries_with_protein['rna']}")
        print(f"\tDNA/RNA hybrid entries with protein in the complex: {entries_with_protein['dna_rna']}")
        print(f"\tEntries with PPMs: {entries_with_ppm['all']}")
        print(f"\tEntries with PPMs from crystal: {entries_with_ppm['crystal']}")
        print(f"\tEntries with PPMs from distillation: {entries_with_ppm['distill']}")
        print(f"\tPFAM id counts for entries with PPMs: {pfam_id_counts['all']}")
        print(f"\tPFAM id counts for entries with PPMs from crystal: {pfam_id_counts['crystal']}")
        print(f"\tPFAM id counts for entries with PPMs from distillation: {pfam_id_counts['distill']}")
        print(f"\tPFAM description counts for entries with PPMs: {pfam_description_counts['all']}")
        print(f"\tPFAM description counts for entries with PPMs from crystal: {pfam_description_counts['crystal']}")
        print(f"\tPFAM description counts for entries with PPMs from distillation: {pfam_description_counts['distill']}")

def save_train_valid_test_split(output_directory, 
                                chain_cluster_ids_column_name, 
                                data_df, 
                                split_dfs,
                                split_cluster_ids,
                                seed):
    """
    Write a train/valid/test split to output_directory.

    The directory is deleted if it exists and then recreated. It receives:
      * all.csv, train.csv, valid.csv, test.csv: the full dataframe and each
        split, written without the index;
      * {split}_{chain_cluster_ids_column_name}.txt: each split's cluster ids,
        one per line;
      * seed.txt: the seed used to produce the split.

    Arguments:
        output_directory (str): Directory to (re)create and write into.
        chain_cluster_ids_column_name (str): Used to name the cluster id files.
        data_df (pd.DataFrame): The full dataset.
        split_dfs (dict): "train"/"valid"/"test" -> dataframe.
        split_cluster_ids (dict): "train"/"valid"/"test" -> iterable of cluster ids.
        seed: Seed to record.
    """
    # Start from an empty directory so no files from a previous split remain.
    if os.path.exists(output_directory):
        shutil.rmtree(output_directory)
    os.makedirs(output_directory)

    split_names = ("train", "valid", "test")

    # The full dataframe first, then one CSV per split.
    data_df.to_csv(os.path.join(output_directory, "all.csv"), index = False)
    for split_name in split_names:
        split_dfs[split_name].to_csv(os.path.join(output_directory, f"{split_name}.csv"), index = False)

    # One text file of cluster ids per split.
    for split_name in split_names:
        cluster_ids_path = os.path.join(output_directory, f"{split_name}_{chain_cluster_ids_column_name}.txt")
        cluster_ids_text = "\n".join(str(cluster_id) for cluster_id in split_cluster_ids[split_name])
        write_text_file(cluster_ids_path, cluster_ids_text)

    # Record the seed.
    write_text_file(os.path.join(output_directory, "seed.txt"), str(seed))


def train_valid_test_split(data_df, 
                           chain_cluster_to_degree, 
                           chain_cluster_ids_column_name,
                           output_directory,
                           valid_fraction = None, 
                           test_fraction = None,
                           max_valid_test_cluster_degree = None,
                           extra_test_cluster_ids = None,
                           seed = None):
    """
    Split the data_df (grouped by entries) into train, validation, and test
    sets, based on chain clusters.

    The chain clusters are split first; valid_fraction and test_fraction are
    fractions of the clusters, not of the entries. Each entry then goes to
    test if it contains any test cluster, otherwise to validation if it
    contains any validation cluster, otherwise to train. Train therefore gets
    the clusters not sampled for validation/test and not listed in
    extra_test_cluster_ids.

    Arguments:
        data_df (pd.DataFrame):
            DataFrame with the data. Must have a unique "id" column.
        chain_cluster_to_degree ((cluster id -> int) dict):
            Dictionary with the degree of each chain cluster.
        chain_cluster_ids_column_name (str):
            Name of the column with the chain cluster ids.
        output_directory (str):
            Directory to save the split data. It is deleted and recreated.
        valid_fraction (float):
            Fraction of the chain clusters to use for validation. Required;
            the None default is kept only for signature compatibility.
        test_fraction (float):
            Fraction of the chain clusters to sample for testing. Required;
            the None default is kept only for signature compatibility.
        max_valid_test_cluster_degree (optional, int):
            Maximum degree of a cluster to be included in the validation and 
            test sets.
        extra_test_cluster_ids (optional, list):
            List of extra cluster ids to include in the test set, in addition
            to the sampled test clusters.
        seed (optional, int):
            Seed for the random number generator. If None, a random seed is
            drawn. Either way, the seed used is written to seed.txt.

    Raises:
        ValueError: if valid_fraction or test_fraction is None.

    Side Effects:
        Prints cluster/entry counts and per-split statistics.
        Saves the train, validation, and test dataframes to the output directory.
        Saves the train, validation, and test chain cluster ids to the output directory.
    """
    # Fail early with a clear message; otherwise the missing fraction only
    # surfaces as a TypeError inside the cluster splitting arithmetic.
    if valid_fraction is None or test_fraction is None:
        raise ValueError("valid_fraction and test_fraction must both be provided.")

    # Draw a seed when none is given, so that the seed actually used can be
    # saved alongside the split.
    if seed is None:
        seed_rng = np.random.default_rng()
        seed = seed_rng.integers(0, 2 ** 32 - 1)

    # Split the chain clusters into train, validation, and test sets.
    train_cluster_ids, valid_cluster_ids, test_cluster_ids = \
        split_train_valid_test_clusters(chain_cluster_to_degree = chain_cluster_to_degree,
                                        valid_fraction = valid_fraction,
                                        test_fraction = test_fraction,
                                        max_valid_test_cluster_degree = max_valid_test_cluster_degree,
                                        extra_test_cluster_ids = extra_test_cluster_ids,
                                        seed = seed)

    # Split the entries in the dataset into train, validation, and test sets.
    train_data_df, valid_data_df, test_data_df = \
        split_train_valid_test_entries(data_df,
                                       chain_cluster_ids_column_name,
                                       train_cluster_ids,
                                       valid_cluster_ids,
                                       test_cluster_ids)

    # Report per-split entry statistics (DNA/RNA/hybrid composition, PPMs,
    # PFAM labels). These count every entry in each split. Because clusters
    # are split but whole entries (complexes) are assigned, a test entry may
    # also contain chains from train/validation clusters, and a validation
    # entry may contain chains from train clusters.
    # NOTE: When evaluating the models, we should only look at chains whose
    # cluster belongs to the evaluated set. Test/valid chains never end up in
    # training entries.
    split_cluster_ids = {
        "train": train_cluster_ids,
        "valid": valid_cluster_ids,
        "test": test_cluster_ids
    }
    split_dfs = {
        "train": train_data_df,
        "valid": valid_data_df,
        "test": test_data_df
    }
    evaluate_train_valid_test_split(split_dfs)

    # Save the train, validation, and test information.
    save_train_valid_test_split(output_directory,
                                chain_cluster_ids_column_name,
                                data_df,
                                split_dfs,
                                split_cluster_ids,
                                seed)

# Aggregate the Datasets

In [None]:
# Recreate the aggregate output directory from scratch: any existing contents
# are deleted before the (empty) directory is created again.
if os.path.exists(all_datasets_output_directory):
    shutil.rmtree(all_datasets_output_directory)
os.makedirs(all_datasets_output_directory)

## Combine the Datasets

In [None]:
# Combine the dataframes from the datasets.
# Each name in datasets_to_load is a subdirectory of all_datasets_directory
# containing a "preprocessing_output.csv". The per-dataset frames are stacked
# row-wise and re-indexed 0..n-1 (ignore_index = True).
# NOTE(review): pd.read_csv is called without converters, so any list-valued
# cells (see list_columns) are presumably still strings at this point; TODO
# confirm where they are parsed.
all_dfs = []
for dataset_name in datasets_to_load:
    dataset_directory = os.path.join(all_datasets_directory, dataset_name)
    
    preprocessing_output_path = os.path.join(dataset_directory, "preprocessing_output.csv")
    preprocessing_output_df = pd.read_csv(preprocessing_output_path)

    all_dfs.append(preprocessing_output_df)

all_datasets_df = pd.concat(all_dfs, ignore_index = True)

In [None]:
# Display the combined dataframe as this notebook cell's output.
all_datasets_df

## Create Dictionaries for Auxiliary Chain Information

In [None]:
# Load the family labels and chain clusters.
all_protein_chain_clusters_df = pd.read_csv(protein_chain_clustering_path)
all_nucleic_acid_chain_clusters_df = pd.read_csv(nucleic_acid_chain_clustering_path)
all_protein_family_labels_df = pd.read_csv(protein_family_labeling_path)

# Create a dictionary mapping protein chain sequence to protein cluster id.
# If a sequence occurs in several rows, dict() keeps the cluster id from the
# last such row.
protein_sequence_to_chain_cluster_id = dict(zip(all_protein_chain_clusters_df["sequence"], all_protein_chain_clusters_df["protein_chain_cluster_id"]))

# Create a dictionary mapping nucleic acid chain sequence to nucleic acid cluster id.
# Repeated sequences behave the same way (last row wins).
nucleic_acid_sequence_to_chain_cluster_id = dict(zip(all_nucleic_acid_chain_clusters_df["sequence"], all_nucleic_acid_chain_clusters_df["nucleic_acid_chain_cluster_id"]))

# Create dictionaries mapping protein sequence to pfam id, pfam description, interpro id, and interpro description.
# Each sequence maps to a list with one element per label row of that
# sequence (row order, duplicates kept). The "signature_*" columns are treated
# as Pfam labels here; presumably the labeling output only holds Pfam
# signatures (TODO confirm).
# NOTE: these lists are the same objects stored in
# all_protein_family_labels_df_grouped, so mutating one in place also changes
# the corresponding grouped cell (and any other holder of that list).
all_protein_family_labels_df_grouped = all_protein_family_labels_df.groupby("sequence").agg({"signature_accession": list,
                                                                                             "signature_description": list,
                                                                                             "interpro_accession": list,
                                                                                             "interpro_description": list})
protein_sequence_to_pfam_ids = dict(zip(all_protein_family_labels_df_grouped.index, all_protein_family_labels_df_grouped["signature_accession"]))
protein_sequence_to_pfam_descriptions = dict(zip(all_protein_family_labels_df_grouped.index, all_protein_family_labels_df_grouped["signature_description"]))
protein_sequence_to_interpro_ids = dict(zip(all_protein_family_labels_df_grouped.index, all_protein_family_labels_df_grouped["interpro_accession"]))
protein_sequence_to_interpro_descriptions = dict(zip(all_protein_family_labels_df_grouped.index, all_protein_family_labels_df_grouped["interpro_description"]))

# Create dictionaries mapping pfam id to pfam description and interpro id to interpro description.
# The asserts check that every accession always appears with the same
# description.
# NOTE(review): missing values (NaN) in these columns would break this.
# NaN != NaN, so a repeated accession whose description is NaN fails the
# assert. Confirm the labels CSV has no missing accessions/descriptions.
pfam_id_to_pfam_description = dict()
interpro_id_to_interpro_description = dict()
for (pfam_id, 
     pfam_description, 
     interpro_id,
     interpro_description) in zip(
    all_protein_family_labels_df["signature_accession"],
    all_protein_family_labels_df["signature_description"],
    all_protein_family_labels_df["interpro_accession"],
    all_protein_family_labels_df["interpro_description"]
):
    # pfam_id to pfam_description mapping.
    if pfam_id not in pfam_id_to_pfam_description:
        pfam_id_to_pfam_description[pfam_id] = pfam_description
    else:
        assert(pfam_id_to_pfam_description[pfam_id] == pfam_description)
    
    # interpro_id to interpro_description mapping.
    if interpro_id not in interpro_id_to_interpro_description:
        interpro_id_to_interpro_description[interpro_id] = interpro_description
    else:
        assert(interpro_id_to_interpro_description[interpro_id] == interpro_description)

# Create dictionaries mapping protein chain cluster id to the pfam ids and
# interpro ids of all its member sequences, using sequence as the
# intermediate. Labels are concatenated in member-sequence order (duplicates
# kept); a cluster whose sequences have no labels maps to an empty list.
#
# Each cluster gets its own fresh list (setdefault with a new []), and member
# labels are extended into it. Storing a member sequence's list object
# directly would alias it: later extends for other members would then also
# mutate the per-sequence lists in protein_sequence_to_pfam_ids /
# protein_sequence_to_interpro_ids, which are reused on their own (e.g. to
# label entries' pfam_ids).
protein_chain_cluster_id_to_pfam_ids = dict()
protein_chain_cluster_id_to_interpro_ids = dict()
for sequence, protein_chain_cluster_id in protein_sequence_to_chain_cluster_id.items():
    pfam_ids = protein_sequence_to_pfam_ids.get(sequence, [])
    protein_chain_cluster_id_to_pfam_ids.setdefault(protein_chain_cluster_id, []).extend(pfam_ids)

    interpro_ids = protein_sequence_to_interpro_ids.get(sequence, [])
    protein_chain_cluster_id_to_interpro_ids.setdefault(protein_chain_cluster_id, []).extend(interpro_ids)

In [None]:
# Save some protein label/protein cluster dictionaries for use in specificity
# dataset creation.
# np.save stores each dict wrapped in a 0-d object array (pickled), so reading
# one back presumably needs np.load(path, allow_pickle = True).item().
pfam_id_to_pfam_description_path = os.path.join(all_datasets_output_directory, "pfam_id_to_pfam_description.npy")
interpro_id_to_interpro_description_path = os.path.join(all_datasets_output_directory, "interpro_id_to_interpro_description.npy")
protein_chain_cluster_id_to_pfam_ids_path = os.path.join(all_datasets_output_directory, "protein_chain_cluster_id_to_pfam_ids.npy")
protein_chain_cluster_id_to_interpro_ids_path = os.path.join(all_datasets_output_directory, "protein_chain_cluster_id_to_interpro_ids.npy")

np.save(pfam_id_to_pfam_description_path, pfam_id_to_pfam_description)
np.save(interpro_id_to_interpro_description_path, interpro_id_to_interpro_description)
np.save(protein_chain_cluster_id_to_pfam_ids_path, protein_chain_cluster_id_to_pfam_ids)
np.save(protein_chain_cluster_id_to_interpro_ids_path, protein_chain_cluster_id_to_interpro_ids)

## Load the Auxiliary Data

In [None]:
def label_sequence_with_auxillary_data(sequences_path, sequence_to_auxillary_data_dict, chain_types_to_consider, save_chain_types = False):
    """Collect the auxiliary data for the sequences of a single entry.

    Args:
        sequences_path: Path (or file-like object) of a sequences CSV with
            "chain_type" and "sequence" columns. An already-loaded pandas
            DataFrame with those columns is also accepted, which avoids
            re-reading the same file for several kinds of auxiliary data.
        sequence_to_auxillary_data_dict: Maps a sequence to either a single
            value or a list of values; list values are flattened into the
            result.
        chain_types_to_consider: Chain types whose sequences are labeled;
            sequences of any other chain type are skipped.
        save_chain_types: If True, also collect the chain type of each
            collected value, aligned element-wise with the data list.

    Returns:
        A tuple (per_sequence_auxillary_data, chain_types). chain_types is an
        empty list when save_chain_types is False.
    """
    # Load the sequences unless they were passed in already loaded.
    if isinstance(sequences_path, pd.DataFrame):
        sequences_df = sequences_path
    else:
        sequences_df = pd.read_csv(sequences_path)

    per_sequence_auxillary_data = []
    chain_types = []
    for chain_type, sequence in zip(sequences_df["chain_type"], sequences_df["sequence"]):
        # Skip chains of other types and sequences without auxiliary data.
        if chain_type not in chain_types_to_consider or sequence not in sequence_to_auxillary_data_dict:
            continue

        auxillary_data = sequence_to_auxillary_data_dict[sequence]
        # Treat a scalar as a one-element list so both cases flatten the same way.
        values = auxillary_data if isinstance(auxillary_data, list) else [auxillary_data]
        per_sequence_auxillary_data.extend(values)
        if save_chain_types:
            chain_types.extend([chain_type] * len(values))

    return per_sequence_auxillary_data, chain_types

In [None]:
# Chain types, for use in the loading of auxiliary data.
protein_chain_types = [
    "polypeptide(L)"
]
nucleic_acid_chain_types = [
    "polydeoxyribonucleotide/polyribonucleotide hybrid", 
    "polydeoxyribonucleotide", 
    "polyribonucleotide"
]

# Information of the format auxillary_data_metadata[auxillary_data_name] = (sequence_to_auxillary_data_dict, chain_types_to_consider, save_chain_types)
auxillary_data_metadata = {
    "protein_chain_cluster_ids": (protein_sequence_to_chain_cluster_id, protein_chain_types, True),
    "nucleic_acid_chain_cluster_ids": (nucleic_acid_sequence_to_chain_cluster_id, nucleic_acid_chain_types, True),
    "pfam_ids": (protein_sequence_to_pfam_ids, protein_chain_types, False),
    "pfam_descriptions": (protein_sequence_to_pfam_descriptions, protein_chain_types, False),
    "interpro_ids": (protein_sequence_to_interpro_ids, protein_chain_types, False),
    "interpro_descriptions": (protein_sequence_to_interpro_descriptions, protein_chain_types, False)
}

# Pre-create one list per output column (plus "<name>_chain_types" where chain
# types are saved). Every column then exists even when all_datasets_df has no
# rows, instead of raising KeyError at assignment time.
auxillary_data_dict = {}
for auxillary_data_name, (_, _, save_chain_types) in auxillary_data_metadata.items():
    auxillary_data_dict[auxillary_data_name] = []
    if save_chain_types:
        auxillary_data_dict[auxillary_data_name + "_chain_types"] = []

# For each entry, label its sequences with every kind of auxiliary data. The
# sequences CSV is read once per entry and reused for all kinds.
for sequences_path in all_datasets_df["sequences_path"]:
    sequences_df = pd.read_csv(sequences_path)

    for auxillary_data_name, (sequence_to_auxillary_data_dict, chain_types_to_consider, save_chain_types) in auxillary_data_metadata.items():
        per_sequence_auxillary_data = []
        per_sequence_chain_types = []
        for chain_type, sequence in zip(sequences_df["chain_type"], sequences_df["sequence"]):
            # Skip chains of other types and sequences without auxiliary data.
            if chain_type not in chain_types_to_consider or sequence not in sequence_to_auxillary_data_dict:
                continue

            auxillary_data = sequence_to_auxillary_data_dict[sequence]
            # Flatten list-valued data; a scalar becomes a single element.
            values = auxillary_data if isinstance(auxillary_data, list) else [auxillary_data]
            per_sequence_auxillary_data.extend(values)
            per_sequence_chain_types.extend([chain_type] * len(values))

        # Record the auxiliary data (and optionally chain types) for the entry.
        auxillary_data_dict[auxillary_data_name].append(per_sequence_auxillary_data)
        if save_chain_types:
            auxillary_data_dict[auxillary_data_name + "_chain_types"].append(per_sequence_chain_types)

# Attach each collected list as a column, holding one list of values per entry.
# The column order matches the metadata order above.
for column_name, column_values in auxillary_data_dict.items():
    all_datasets_df[column_name] = column_values

In [None]:
# Inspect the aggregate dataset with the auxiliary data columns attached.
all_datasets_df

## Remove Any Entries with No Protein or Nucleic Acid Chain IDs

Such entries arise only when every chain in the complex was too short to be clustered. These short sequences are also excluded when computing the sampling probability.

In [None]:
# Keep only entries with at least one protein or nucleic acid chain cluster id.
has_protein_chain_cluster_ids = all_datasets_df.protein_chain_cluster_ids.map(len) > 0
has_nucleic_acid_chain_cluster_ids = all_datasets_df.nucleic_acid_chain_cluster_ids.map(len) > 0
all_datasets_df = all_datasets_df[has_protein_chain_cluster_ids | has_nucleic_acid_chain_cluster_ids]
all_datasets_df

## Save the Dataset

In [None]:
# Write the aggregate dataset to CSV without the row index.
all_datasets_df.to_csv(path_or_buf = all_datasets_output_path, index = False)

# Design Dataset

In [None]:
# Reload the aggregate dataset, parsing the stringified list columns back into
# Python lists.
list_column_converters = dict.fromkeys(list_columns, ast.literal_eval)
all_datasets_df = pd.read_csv(all_datasets_output_path, converters = list_column_converters)

## Limit to RCSB CIF NA

In [None]:
# The design dataset contains only the rcsb_cif_na entries (as an independent copy).
is_rcsb_cif_na_entry = all_datasets_df["dataset_name"].eq("rcsb_cif_na")
design_dataset_df = all_datasets_df.loc[is_rcsb_cif_na_entry].copy()

In [None]:
# Inspect the design dataset (rcsb_cif_na entries only).
design_dataset_df

## Compute the Chain Cluster Degrees

In [None]:
# Degree of every protein / nucleic acid chain cluster over the design dataset.
design_protein_chain_cluster_to_degree = compute_chain_cluster_degrees(
    data_df = design_dataset_df,
    chain_cluster_ids_column_name = "protein_chain_cluster_ids",
)
design_nucleic_acid_chain_cluster_to_degree = compute_chain_cluster_degrees(
    data_df = design_dataset_df,
    chain_cluster_ids_column_name = "nucleic_acid_chain_cluster_ids",
)

In [None]:
# Inspect the design dataset again after the chain cluster degree computation.
design_dataset_df

## Compute Sampling Probability

In [None]:
# Sampling probability of an entry: the mean of 1 / (1 + degree) over all of
# its protein and nucleic acid chain clusters. np.concatenate joins the two
# degree collections whether they are stored as lists or arrays; with "+",
# arrays would be added element-wise instead of concatenated.
design_dataset_df["sampling_probability"] = design_dataset_df.apply(
    lambda row: np.mean(1 / (1 + np.concatenate([
        row["protein_chain_cluster_degrees"],
        row["nucleic_acid_chain_cluster_degrees"],
    ]))),
    axis = 1,
)

In [None]:
# Inspect the design dataset with its new sampling_probability column.
design_dataset_df

## Split into All/Train/Valid/Test

In [None]:
# Investigate the distribution of the degrees of the protein chain clusters.
print(len(design_protein_chain_cluster_to_degree))
# Pass a plain list: some matplotlib versions mishandle a raw dict_values view.
plt.hist(list(design_protein_chain_cluster_to_degree.values()), bins = 1000)
plt.xlabel("Protein chain cluster degree")
plt.ylabel("Number of protein chain clusters")
plt.show()

In [None]:
# Investigate the distribution of the degrees of the nucleic acid chain clusters.
print(len(design_nucleic_acid_chain_cluster_to_degree))
# Pass a plain list: some matplotlib versions mishandle a raw dict_values view.
plt.hist(list(design_nucleic_acid_chain_cluster_to_degree.values()), bins = 1000)
plt.xlabel("Nucleic acid chain cluster degree")
plt.ylabel("Number of nucleic acid chain clusters")
plt.show()

In [None]:
# Remove some pseudoknots as an extra test set: the nucleic acid chain clusters
# of these structures are collected and held out.
pseudoknot_pdb_ids = ["7kd1", "3q3z", "4plx", "2m8k", "4oqu", "7kga", "1drz", "7qr4", "2miy", "4znp"]

# Design dataset rows that belong to the pseudoknot structures.
pseudoknot_rows_df = design_dataset_df[design_dataset_df["id"].isin(pseudoknot_pdb_ids)]

# Fail with the offending ids (rather than a bare IndexError) if any structure
# is absent, e.g. because it was filtered out upstream.
missing_pseudoknot_pdb_ids = sorted(set(pseudoknot_pdb_ids) - set(pseudoknot_rows_df["id"]))
if missing_pseudoknot_pdb_ids:
    raise ValueError(f"Pseudoknot PDB ids missing from the design dataset: {missing_pseudoknot_pdb_ids}")

# Union of the nucleic acid chain cluster ids over every matching row.
pseudoknot_nucleic_acid_chain_cluster_ids = set()
for nucleic_acid_chain_cluster_ids in pseudoknot_rows_df["nucleic_acid_chain_cluster_ids"]:
    pseudoknot_nucleic_acid_chain_cluster_ids.update(nucleic_acid_chain_cluster_ids)

print(pseudoknot_nucleic_acid_chain_cluster_ids)

In [None]:
# Split the design dataset by nucleic acid chain cluster. The pseudoknot
# clusters are passed in as extra test clusters; output goes to
# design_dataset_output_directory.
train_valid_test_split(
    data_df = design_dataset_df,
    output_directory = design_dataset_output_directory,
    chain_cluster_ids_column_name = "nucleic_acid_chain_cluster_ids",
    chain_cluster_to_degree = design_nucleic_acid_chain_cluster_to_degree,
    extra_test_cluster_ids = pseudoknot_nucleic_acid_chain_cluster_ids,
    max_valid_test_cluster_degree = 25,
    valid_fraction = 0.1,
    test_fraction = 0.1,
)

# Specificity Dataset

In [None]:
# Reload the aggregate dataset, parsing the stringified list columns back into
# Python lists.
list_column_converters = dict.fromkeys(list_columns, ast.literal_eval)
all_datasets_df = pd.read_csv(all_datasets_output_path, converters = list_column_converters)

## Include All Datasets

In [None]:
# The specificity dataset uses every dataset; take an independent (deep) copy.
specificity_dataset_df = all_datasets_df.copy(deep = True)

In [None]:
# Inspect the specificity dataset (a copy of all datasets).
specificity_dataset_df

## Compute the Chain Cluster Degrees

In [None]:
# Degree of every protein / nucleic acid chain cluster over the specificity dataset.
specificity_protein_chain_cluster_to_degree = compute_chain_cluster_degrees(
    data_df = specificity_dataset_df,
    chain_cluster_ids_column_name = "protein_chain_cluster_ids",
)
specificity_nucleic_acid_chain_cluster_to_degree = compute_chain_cluster_degrees(
    data_df = specificity_dataset_df,
    chain_cluster_ids_column_name = "nucleic_acid_chain_cluster_ids",
)

In [None]:
# Inspect the specificity dataset again after the chain cluster degree computation.
specificity_dataset_df

## Compute Sampling Probability

In [None]:
# Sampling probability of an entry: the mean of 1 / (1 + degree) over all of
# its protein and nucleic acid chain clusters. np.concatenate joins the two
# degree collections whether they are stored as lists or arrays; with "+",
# arrays would be added element-wise instead of concatenated.
specificity_dataset_df["sampling_probability"] = specificity_dataset_df.apply(
    lambda row: np.mean(1 / (1 + np.concatenate([
        row["protein_chain_cluster_degrees"],
        row["nucleic_acid_chain_cluster_degrees"],
    ]))),
    axis = 1,
)

In [None]:
# Inspect the specificity dataset with its new sampling_probability column.
specificity_dataset_df

## Split into All/Train/Valid/Test

In [None]:
# Investigate the distribution of the degrees of the protein chain clusters.
print(len(specificity_protein_chain_cluster_to_degree))
# Pass a plain list: some matplotlib versions mishandle a raw dict_values view.
plt.hist(list(specificity_protein_chain_cluster_to_degree.values()), bins = 1000)
plt.xlabel("Protein chain cluster degree")
plt.ylabel("Number of protein chain clusters")
plt.show()

In [None]:
# Investigate the distribution of the degrees of the nucleic acid chain clusters.
print(len(specificity_nucleic_acid_chain_cluster_to_degree))
# Pass a plain list: some matplotlib versions mishandle a raw dict_values view.
plt.hist(list(specificity_nucleic_acid_chain_cluster_to_degree.values()), bins = 1000)
plt.xlabel("Nucleic acid chain cluster degree")
plt.ylabel("Number of nucleic acid chain clusters")
plt.show()

In [None]:
from collections import Counter

# Count, for each pfam description, the number of entries with PPMs in which it
# occurs: overall, from crystal structures (rcsb_cif_na), and from the
# distillation datasets. A description is counted at most once per entry.
distillation_dataset_names = ("rf2na_distillation_cis_bp", "rf2na_distillation_transfac")

pfam_description_to_num_entries_with_ppm = Counter()
pfam_description_to_num_entries_with_ppm_from_crystal = Counter()
pfam_description_to_num_entries_with_ppm_from_distillation = Counter()
for pfam_descriptions, ppm_paths, dataset_name in zip(
    specificity_dataset_df["pfam_descriptions"], 
    specificity_dataset_df["ppm_paths"],
    specificity_dataset_df["dataset_name"]
):
    # Only entries with at least one PPM contribute.
    if not ppm_paths:
        continue

    unique_pfam_descriptions = set(pfam_descriptions)
    pfam_description_to_num_entries_with_ppm.update(unique_pfam_descriptions)
    if dataset_name == "rcsb_cif_na":
        pfam_description_to_num_entries_with_ppm_from_crystal.update(unique_pfam_descriptions)
    elif dataset_name in distillation_dataset_names:
        pfam_description_to_num_entries_with_ppm_from_distillation.update(unique_pfam_descriptions)

In [None]:
# Plot histograms for the number of entries with PPMs, with PPMs from crystal,
# and with PPMs from distillation. Each cell also lists the pfam descriptions
# from most to least frequent.
# Pass a plain list: some matplotlib versions mishandle a raw dict_values view.
plt.hist(list(pfam_description_to_num_entries_with_ppm.values()), bins = 1000)
plt.xlabel("Number of entries with PPMs")
plt.ylabel("Number of pfam descriptions")
plt.show()
print(sorted(pfam_description_to_num_entries_with_ppm.items(), key = lambda x: x[1], reverse = True))

In [None]:
# Number of crystal (rcsb_cif_na) entries with PPMs per pfam description.
# Pass a plain list: some matplotlib versions mishandle a raw dict_values view.
plt.hist(list(pfam_description_to_num_entries_with_ppm_from_crystal.values()), bins = 1000)
plt.xlabel("Number of crystal entries with PPMs")
plt.ylabel("Number of pfam descriptions")
plt.show()
print(sorted(pfam_description_to_num_entries_with_ppm_from_crystal.items(), key = lambda x: x[1], reverse = True))

In [None]:
# Number of distillation entries with PPMs per pfam description.
# Pass a plain list: some matplotlib versions mishandle a raw dict_values view.
plt.hist(list(pfam_description_to_num_entries_with_ppm_from_distillation.values()), bins = 1000)
plt.xlabel("Number of distillation entries with PPMs")
plt.ylabel("Number of pfam descriptions")
plt.show()
print(sorted(pfam_description_to_num_entries_with_ppm_from_distillation.items(), key = lambda x: x[1], reverse = True))

In [None]:
# Split the specificity dataset by protein chain cluster; output goes to
# specificity_dataset_output_directory.
train_valid_test_split(
    data_df = specificity_dataset_df,
    output_directory = specificity_dataset_output_directory,
    chain_cluster_ids_column_name = "protein_chain_cluster_ids",
    chain_cluster_to_degree = specificity_protein_chain_cluster_to_degree,
    max_valid_test_cluster_degree = 25,
    valid_fraction = 0.1,
    test_fraction = 0.1,
)