In [2]:
%load_ext autoreload
%autoreload 2

import multiprocessing
import os
import pickle
import warnings
from dataclasses import InitVar, dataclass, field, fields
from functools import partial
from io import BytesIO

import h5py
import pyarrow as pa
from Bio import BiopythonWarning
from datasets import ClassLabel, Dataset, Image, Value, load_dataset
from datasets.utils.file_utils import xopen

from src.data import foldseek, mdcath_processing

# Suppress every warning whose category is BiopythonWarning (imported from
# Bio above) for the rest of the kernel session; other warnings still show.
warnings.filterwarnings("ignore", category=BiopythonWarning)

# Working directories used by this notebook, keyed by the kind of data they
# hold. All paths are relative to the notebook's working directory and each
# value ends with a trailing slash.
FILE_PATHS = dict(
    [
        ("3Di", "../tmp/data/3Di/"),
        ("mdCATH", "../tmp/data/mdCATH/"),
        ("trajectories", "../tmp/data/trajectories/"),
        ("cache", "../tmp/data/cache/"),
    ]
)

# Create whichever directories are missing; existing ones are left as they are.
for directory in FILE_PATHS.values():
    os.makedirs(directory, exist_ok=True)

In [13]:
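# Disabled scratch snippet: reads the file behind `trajectory_url` fully into
# an in-memory BytesIO buffer. The URL is intentionally left empty.
# trajectory_url = ""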

# with xopen(trajectory_url, "rb") as file:
#         bytes_ = BytesIO(file.read())

In [18]:
# Smoke test: run the extraction on one trajectory file that is already on
# disk. The function is reached through the `mdcath_processing` module that
# the first cell imports, so this cell needs no import of its own.
# NOTE(review): judging by the printed output, the function visits every
# temperature/replica combination listed below — confirm in mdcath_processing.
extraction_config = {
    "temperatures": ["320", "348"],
    "replicas": ["0", "1"],
}

mdcath_processing.extract_mdcath_information(
    # os.path.join keeps the path correct even if the directory entry in
    # FILE_PATHS ever loses its trailing slash.
    file_path=os.path.join(FILE_PATHS["trajectories"], "mdcath_dataset_1avyB00.h5"),
    config=extraction_config,
)

1avyB00 320 0
1avyB00 320 1
1avyB00 348 0
1avyB00 348 1


In [2]:
# Open the mdCATH dataset as a stream (nothing is downloaded up front) and
# resume after the entries that earlier sessions already processed.
dataset_mdcath = load_dataset("compsciencelab/mdCATH", split="train", streaming=True)
# Ask for the "image" column to be delivered undecoded.
# NOTE(review): confirm the dataset actually exposes an "image" column.
dataset_mdcath = dataset_mdcath.cast_column("image", Image(decode=False))

# The resume marker lives in the 3Di working directory; derive the path from
# FILE_PATHS so it cannot drift from the directory created in the first cell.
path_state_dict = os.path.join(FILE_PATHS["3Di"], "_state.pkl")

if os.path.exists(path_state_dict):
    # SECURITY NOTE: unpickling can execute arbitrary code. This is tolerated
    # only because the file is written by this notebook itself; never point
    # this path at a file from an untrusted source.
    with open(path_state_dict, "rb") as f:
        state = pickle.load(f)
    # The marker must be a plain non-negative entry count, because it is
    # passed to skip() below and incremented by the processing cell.
    if isinstance(state, bool) or not isinstance(state, int) or state < 0:
        raise ValueError(
            f"Unexpected resume state in {path_state_dict!r}: {state!r}"
        )
    print("State loaded,", state)
    dataset_mdcath = dataset_mdcath.skip(state)
else:
    state = 0
    print("No state found", state)

Resolving data files:   0%|          | 0/5400 [00:00<?, ?it/s]

No state found 0


In [3]:
# Process the next batch of entries from the stream in parallel, then persist
# how far we got so that a later session can resume after this batch.
iterations = 10

# Keep the batch under its own name. Overwriting `dataset_mdcath` with the
# truncated view would make a re-run of this cell process the same entries
# again while still advancing the saved offset.
dataset_batch = dataset_mdcath.take(iterations)

cpu_count = multiprocessing.cpu_count()
print("CPU Count: ", cpu_count)

# Leave one core free for the notebook itself, but never go below one worker.
n_workers = max(1, cpu_count - 1)
process_config = {
    "traj_temp": "320",
    "traj_sim": "0",
}

# The context manager terminates the workers if `map` raises, so a failed run
# does not leave orphaned processes behind.
with multiprocessing.Pool(processes=n_workers) as pool:
    dataset_mdcath_mapped = pool.map(
        partial(mdcath_processing.download_process, config=process_config),
        dataset_batch,
    )
    # Normal path: let the workers finish and exit cleanly.
    pool.close()
    pool.join()

# Advance by the number of entries that were actually processed; the stream
# may hold fewer than `iterations` entries when it is close to its end.
n_processed = len(dataset_mdcath_mapped)
state = state + n_processed
with open(path_state_dict, "wb") as f:
    pickle.dump(state, f)
print(dataset_batch.state_dict())
print("State saved", state)

# Move the in-memory stream past the finished batch so that re-running this
# cell continues with the following entries, in step with the saved offset.
dataset_mdcath = dataset_mdcath.skip(n_processed)

Elapsed time: 328.3863 seconds




Elapsed time: 0.9007 seconds
Elapsed time: 0.1924 seconds
Elapsed time: 397.9927 seconds




Elapsed time: 1.4628 seconds
Elapsed time: 0.3007 seconds
Elapsed time: 429.2657 seconds




Elapsed time: 1.3457 seconds
Elapsed time: 0.2384 seconds
Elapsed time: 576.5274 seconds




Elapsed time: 1.8924 seconds
Elapsed time: 0.2091 seconds
Elapsed time: 323.6645 seconds




Elapsed time: 1.0554 seconds
Elapsed time: 0.1486 seconds
Elapsed time: 786.6145 seconds




Elapsed time: 2.7971 seconds
Elapsed time: 0.2389 seconds
Elapsed time: 835.0794 seconds




Elapsed time: 3.1871 seconds
Elapsed time: 0.2277 seconds
Elapsed time: 1082.0669 seconds




Elapsed time: 4.7903 seconds
Elapsed time: 0.5071 seconds
Elapsed time: 1123.8518 seconds




Elapsed time: 5.1936 seconds
Elapsed time: 0.4152 seconds
Elapsed time: 1156.6098 seconds




Elapsed time: 5.7324 seconds
Elapsed time: 0.4725 seconds
{'num_taken': 10, 'ex_iterable': {'shard_idx': 9, 'shard_example_idx': 1}}
State saved 10


In [4]:
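# Disabled scratch cell, kept for reference only.
# NOTE(review): it reads from "../tmp/data/traj/" and calls
# `extract_dataset_information`, whereas the active cells use
# FILE_PATHS["trajectories"] and `extract_mdcath_information` — confirm both
# before re-enabling.
# from src.data import foldseek, mdcath_processing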
# import h5py

# file_path = "../tmp/data/traj/mdcath_dataset_1avyB00.h5"
# process_config = {
#     "traj_temp": "320",
#     "traj_sim": "0",
# }

# data = mdcath_processing.extract_dataset_information(file_path, traj_temp=process_config["traj_temp"], traj_sim=process_config["traj_sim"])
# mdcath_processing.mdcath_process(data)