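## Summary

Load each "core" and "interface" dataset listed in `resources`, drop rows with fewer than two mutations or fewer than two unique effect values, assign every row a CRC32C-based `unique_id`, and write one parquet file per dataset to this notebook's output directory.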

**Notes:**

This notebook should be run on a machine with more than 32 GB of memory.

---

## Imports

In [None]:
import os
from pathlib import Path

import crc32c
import pyarrow as pa
import pyarrow.parquet as pq
from tqdm.notebook import tqdm

## Parameters

In [None]:
# Notebook identifier; it also names this notebook's working directory and output directory.
NOTEBOOK_NAME: str = "01_load_data"

In [None]:
# Per-notebook working directory: ./<NOTEBOOK_NAME> as an absolute, symlink-resolved path.
# Created on first run; re-running the cell is harmless.
NOTEBOOK_DIR = (Path.cwd() / NOTEBOOK_NAME).resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

In [None]:
# Root directory under which the input datasets (see `resources`) are looked up:
# $DATAPKG_OUTPUT_DIR when that environment variable is set, otherwise NOTEBOOK_DIR.
DATAPKG_OUTPUT_DIR = (
    Path(os.environ["DATAPKG_OUTPUT_DIR"]).resolve()
    if "DATAPKG_OUTPUT_DIR" in os.environ
    else NOTEBOOK_DIR
)
DATAPKG_OUTPUT_DIR.mkdir(exist_ok=True)

DATAPKG_OUTPUT_DIR

In [None]:
# Base directory for this project's outputs: $DATAPKG_OUTPUT_DIR/elaspic-v2 when the
# environment variable is set, otherwise the parent of NOTEBOOK_DIR.
OUTPUT_DIR = (
    Path(os.environ["DATAPKG_OUTPUT_DIR"]).joinpath("elaspic-v2").resolve()
    if "DATAPKG_OUTPUT_DIR" in os.environ
    else NOTEBOOK_DIR.parent
)
OUTPUT_DIR.mkdir(exist_ok=True)

OUTPUT_DIR

## Datasets

In [None]:
# Input parquet files, keyed by "<dataset>-<category>", where <category> (the text after the
# last "-") is "core" or "interface". Every file lives under DATAPKG_OUTPUT_DIR.
resources = {
    resource_name: DATAPKG_OUTPUT_DIR.joinpath(*path_parts)
    for resource_name, path_parts in [
        # === Core ===
        (
            "elaspic-training-set-core",
            ("elaspic-training-set", "02_export_data_core", "elaspic-training-set-core.parquet"),
        ),
        (
            "protherm-dagger-core",
            ("protein-folding-energy", "protherm_dagger", "mutation-by-sequence.parquet"),
        ),
        (
            "rocklin-2017-core",
            ("protein-folding-energy", "rocklin_2017", "mutation-ssm2.parquet"),
        ),
        (
            "dunham-2020-core",
            ("protein-folding-energy", "dunham_2020_tianyu", "monomers.parquet"),
        ),
        (
            "starr-2020-core",
            ("protein-folding-energy", "starr_2020_tianyu", "stability.parquet"),
        ),
        (
            "cagi5-frataxin-core",
            ("protein-folding-energy", "cagi5_frataxin", "1ekg-ddg.parquet"),
        ),
        (
            "huang-2020-core",
            ("protein-folding-energy", "huang_2020", "2jie-ddg.parquet"),
        ),
        # === Interface ===
        (
            "elaspic-training-set-interface",
            (
                "elaspic-training-set",
                "02_export_data_interface",
                "elaspic-training-set-interface.parquet",
            ),
        ),
        (
            "skempi-v2-interface",
            ("protein-folding-energy", "skempi_v2", "skempi-v2.parquet"),
        ),
        # Disabled (the reason is not recorded in this notebook):
        # (
        #     "intact-mutations-interface",
        #     ("protein-folding-energy", "intact_mutations", "intact-mutations.parquet"),
        # ),
        (
            "dunham-2020-interface",
            ("protein-folding-energy", "dunham_2020_tianyu", "dimers.parquet"),
        ),
        (
            "starr-2020-interface",
            ("protein-folding-energy", "starr_2020_tianyu", "affinity.parquet"),
        ),
    ]
}

In [None]:
# Per-dataset overrides of the parquet row-group size (rows per group). Datasets not listed
# here fall back to the default given where the output files are written.
row_group_sizes = dict.fromkeys(
    [
        "dunham-2020-core",
        "dunham-2020-interface",
        "starr-2020-core",
        "starr-2020-interface",
        "huang-2020-core",
    ],
    1,
)

In [None]:
# Verify every input file exists before any (slow) loading starts, and report all missing
# files together instead of stopping at the first one.
missing_resources = [
    f"{resource_name}: {path}"
    for resource_name, path in resources.items()
    if not Path(path).is_file()
]
assert not missing_resources, "Missing input files:\n" + "\n".join(missing_resources)

## Load data

In [None]:
# Columns every output file contains, in output order.
columns = [
    "unique_id", "dataset", "name",
    "protein_sequence", "ligand_sequence",
    "mutation", "effect", "effect_type",
    "protein_structure",
]

# Optional score columns, copied to the output only for datasets that provide them.
extra_columns = ["provean_score", "foldx_score", "elaspic_score"]

In [None]:
def get_unique_id(dataset, effect_type, protein_sequence, ligand_sequence):
    """Return the CRC32C checksum of a row's identifying fields.

    The hashed key is ``dataset|effect_type|protein_sequence`` with ``|ligand_sequence``
    appended when ``ligand_sequence`` is not None; each field is rendered with f-string
    formatting and the key is UTF-8 encoded. The checksum is 32-bit, so distinct rows can
    collide; uniqueness must be checked by the caller.
    """
    key_fields = [dataset, effect_type, protein_sequence]
    if ligand_sequence is not None:
        key_fields.append(ligand_sequence)
    joined_key = "|".join(f"{field}" for field in key_fields)
    return crc32c.crc32c(joined_key.encode("utf-8"))

In [None]:
def get_unique_id_2(dataset, name, effect_type, protein_sequence, ligand_sequence):
    """Return the CRC32C checksum of a row's identifying fields, including ``name``.

    Like ``get_unique_id`` but with ``name`` inserted after ``dataset``; the key is
    ``dataset|name|effect_type|protein_sequence`` plus ``|ligand_sequence`` when
    ``ligand_sequence`` is not None, UTF-8 encoded before hashing.
    """
    fields = (dataset, name, effect_type, protein_sequence)
    if ligand_sequence is not None:
        fields += (ligand_sequence,)
    return crc32c.crc32c("|".join(f"{value}" for value in fields).encode("utf-8"))

In [None]:
# Where this notebook writes its files: <OUTPUT_DIR>/<NOTEBOOK_NAME>, created if missing.
output_dir = (OUTPUT_DIR / NOTEBOOK_NAME).resolve()
output_dir.mkdir(exist_ok=True)

output_dir

In [None]:
# Unique ids already emitted, per category. Ids must be unique across all output files of
# the same category ("core" or "interface"), not only within a single file.
_seen = {
    "core": set(),
    "interface": set(),
}

# Set to False to keep (not rewrite) output files that already exist. Ids of skipped
# datasets are still computed and registered in `_seen`, so cross-file checks stay complete.
OVERWRITE_EXISTING = True

for dataset_name, dataset_file in resources.items():
    print(dataset_name)

    # The last "-"-separated token of the resource key is the category.
    coi = dataset_name.rsplit("-", 1)[-1]
    assert coi in ["core", "interface"], f"Unexpected category {coi!r} in {dataset_name!r}"

    df = (
        pq.read_table(dataset_file)
        .to_pandas(integer_object_nulls=True)
        .rename(columns={"mutations": "mutation"})
    )
    print(f"Read {len(df)} rows.")

    # Remove unneeded data
    mask = df["mutation"].apply(len) >= 2
    print(f"Removing {(~mask).sum()} rows with fewer than two mutations.")
    df = df[mask]

    mask = df["effect"].apply(lambda x: len(set(x))) >= 2
    print(f"Removing {(~mask).sum()} rows with fewer than two unique effects.")
    # `.copy()` detaches the filtered frame from its parent, so the column assignments
    # below modify an independent frame and do not raise SettingWithCopyWarning.
    df = df[mask].copy()

    if "dataset" not in df:
        df["dataset"] = dataset_name
    if "ligand_sequence" not in df:
        df["ligand_sequence"] = None

    # Add a unique id: hash the short key first; if that collides within this dataset,
    # rehash every row with `name` added to the key.
    id_columns = ["dataset", "effect_type", "protein_sequence", "ligand_sequence"]
    df["unique_id"] = [get_unique_id(*row) for row in df[id_columns].values]
    unique_ids = set(df["unique_id"].values)
    if len(unique_ids) != len(df):
        id_columns = ["dataset", "name", "effect_type", "protein_sequence", "ligand_sequence"]
        df["unique_id"] = [get_unique_id_2(*row) for row in df[id_columns].values]
        unique_ids = set(df["unique_id"].values)
    assert len(unique_ids) == len(df), (
        f"{dataset_name}: {len(df) - len(unique_ids)} rows share a unique_id with another row"
    )
    reused_ids = unique_ids & _seen[coi]
    assert not reused_ids, (
        f"{dataset_name}: {len(reused_ids)} unique_id values already used by another "
        f"{coi} dataset"
    )
    _seen[coi].update(unique_ids)

    columns_all = columns + [c for c in extra_columns if c in df]
    df_out = df[columns_all]

    # Write output (100 rows per row group unless overridden in `row_group_sizes`).
    output_file = output_dir.joinpath(f"{dataset_name}.parquet")
    if OVERWRITE_EXISTING or not output_file.is_file():
        pq.write_table(
            pa.Table.from_pandas(df_out, preserve_index=False),
            output_file,
            row_group_size=row_group_sizes.get(dataset_name, 100),
        )
    else:
        print(f"Refusing to overwrite existing file: {output_file}.")
    del df, df_out
    print()