In [1]:
%reload_ext autoreload
%autoreload 2

import h5py
from plms import auto_tokenizer
from datasets import Dataset, load_from_disk


In [2]:
import os

# ProstT5 and ProtT5 tokenizers, loaded via the project's `plms.auto_tokenizer` helper.
tokenizer_prostT5 = auto_tokenizer("Rostlab/ProstT5")
tokenizer_protT5 = auto_tokenizer("Rostlab/prot_t5_xl_uniref50")

# Source HDF5 file with the per-trajectory PSSMs. Set PSSM_DATA_PATH to point the
# notebook at another copy; the default keeps the original author's location so
# existing runs are unaffected.
data_path = os.environ.get(
    "PSSM_DATA_PATH",
    "/home/lfi/mnt/dev/prot-md-pssm-legacy/tmp/data/pssm/pssm_data.h5",
)

In [None]:
def gen():
    """Yield one example per (protein, trajectory) PSSM stored in the HDF5 file at `data_path`.

    Each top-level entry of the file is treated as a protein group carrying
    `name` and `sequence` attributes, with one child per trajectory whose key is
    split on "_" into temperature (first field) and replica (second field).
    NOTE(review): key layout inferred from the parsing below — confirm against the
    code that writes pssm_data.h5.

    Yields:
        dict with the protein `name` and `sequence`, the trajectory's
        `temperature` and `replica` (strings), the unpadded/untruncated ProstT5
        and ProtT5 `input_ids` / `attention_mask`, and the trajectory's PSSM
        array as `pssm_features`.
    """
    with h5py.File(data_path, "r") as f:
        for protein_id in f:
            protein_group = f[protein_id]
            # Nothing would be yielded for a protein without trajectories; skip it
            # before reading attributes or tokenizing.
            if len(protein_group) == 0:
                continue

            name = protein_group.attrs["name"]
            sequence = protein_group.attrs["sequence"]

            # The sequence is shared by every trajectory of this protein, so tokenize
            # once per protein rather than once per trajectory.
            tokenized_prostT5 = tokenizer_prostT5.encode(
                text=sequence, padding=False, truncation=False
            )
            tokenized_protT5 = tokenizer_protT5.encode(
                text=sequence, padding=False, truncation=False
            )

            for trajectory_pssm in protein_group:
                # Metadata is encoded in the trajectory key: "<temperature>_<replica>...".
                key_parts = trajectory_pssm.split("_")
                yield {
                    "name": name,
                    "temperature": key_parts[0],
                    "replica": key_parts[1],
                    "sequence": sequence,
                    "input_ids_prostT5": tokenized_prostT5["input_ids"],
                    "attention_mask_prostT5": tokenized_prostT5["attention_mask"],
                    "input_ids_protT5": tokenized_protT5["input_ids"],
                    "attention_mask_protT5": tokenized_protT5["attention_mask"],
                    "pssm_features": protein_group[trajectory_pssm][:],
                }


# Build a Hugging Face `Dataset` from every example yielded by `gen`.
ds = Dataset.from_generator(generator=gen)

In [None]:
# Persist the full dataset, then rebind `ds` to the copy loaded back from disk.
# Re-running this cell on its own would make `ds` overwrite its own source
# files, which `datasets` refuses; re-run the generator cell first.
full_dataset_path = "../tmp/data/pssm/mdcath_pssm_full"
ds.save_to_disk(full_dataset_path)
ds = load_from_disk(full_dataset_path)

In [5]:
# Faster way to save filtered datasets: each split is filtered and written in
# its own worker process.

import multiprocessing
from concurrent.futures import ProcessPoolExecutor
from functools import partial

identifiers_temperature = ["320", "348", "379", "413", "450"]
identifiers_replica = ["0", "1", "2", "3", "4"]


def save_filtered(identifier, column):
    """Save the subset of `ds` whose `column` equals `identifier`.

    Args:
        identifier: value to keep, compared with `==` against the column's
            values (strings in this dataset, e.g. "450" or "2").
        column: metadata column to filter on ("temperature" or "replica").

    Returns:
        The directory the filtered dataset was written to.
    """
    # Passing only `column` to the predicate avoids decoding every row's token
    # ids and PSSM matrix just to compare one short string.
    ds_tmp = ds.filter(lambda value: value == identifier, input_columns=column)
    # NOTE(review): temperature and replica splits share this path pattern; this
    # only works because the two identifier sets do not overlap.
    path = f"../tmp/data/pssm/mdcath_pssm_{identifier}"
    ds_tmp.save_to_disk(path)
    return path


# Workers inherit `ds` and `save_filtered` from the notebook process, and that
# requires the "fork" start method: spawn/forkserver workers cannot import
# functions defined in a notebook. Pin it instead of relying on the platform
# default, which is spawn on macOS/Windows and forkserver on Linux from Python 3.14.
fork_context = multiprocessing.get_context("fork")
with ProcessPoolExecutor(mp_context=fork_context) as executor:
    results_temperature = list(
        executor.map(partial(save_filtered, column="temperature"), identifiers_temperature)
    )
    results_replica = list(
        executor.map(partial(save_filtered, column="replica"), identifiers_replica)
    )

In [None]:
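# Slower, sequential way to save filtered datasets (kept for reference only; superseded by the parallel cell above, and it writes to a different path pattern)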
# file_name = "_prostT5"
# for temp in ["320", "348", "379", "413", "450"]:
#     ds_tmp = ds.filter(lambda x: x["temperature"] == temp)
#     ds_tmp.save_to_disk(f"../tmp/data/pssm/pssm_dataset{file_name}_{temp}")

In [15]:
# Reload the 450 split written by the filtering cell for a quick inspection.
split_450_path = "../tmp/data/pssm/mdcath_pssm_450"
ds_reloaded = load_from_disk(split_450_path)

In [None]:
import pandas as pd

# Sanity check on the reloaded split: number of examples per temperature and per
# replica. Named Series label each table; a one-column DataFrame built from a
# list would title both outputs with the same anonymous column `0`.
display(pd.Series(ds_reloaded["temperature"], name="temperature").value_counts())
display(pd.Series(ds_reloaded["replica"], name="replica").value_counts())