In [1]:
%reload_ext autoreload
%autoreload 2

from src.data.tokenize import (
    FoldSeekTokenizer,
    Bio2TokenTokenizer,
    FoldToken4Tokenizer,
)
from src.data.load import (
    MDCATHDataset,
    MisatoDataset,
    AtlasDataset,
)

In [3]:
mdcath_dataset = MDCATHDataset(data_dir="../tmp/data/mdcath/data", save_path="../tmp/data/tokenized/mdcath")
print(mdcath_dataset.trajectory_locations)
print(mdcath_dataset.used_trajectory_locations)

['../tmp/data/mdcath/data/mdcath_dataset_1a0aA00.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a0sP00.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a15A00.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a3oA00.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a2nA02.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a0rP01.h5', '../tmp/data/mdcath/data/mdcath_dataset_1ba5A00.h5', '../tmp/data/mdcath/data/mdcath_dataset_12asA00.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a48A01.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a05A00.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a02F00.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a6cA02.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a1zA00.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a66A00.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a6aB01.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a0hA01.h5', '../tmp/data/mdcath/data/mdcath_dataset_1avyB00.h5', '../tmp/data/mdcath/data/mdcath_dataset_1bdhA01.h5', '../tmp/data/mdcath/data/mdcath_dataset_1a3dA

In [None]:
mdcath_dataset.reset()
print(mdcath_dataset[10])
mdcath_dataset.use_trajectory_location(10)
try:
    mdcath_dataset[10]
except ValueError as e:
    print(e)
mdcath_dataset.reset()
print(mdcath_dataset[10])

In [None]:
mdcath_dataset.is_used_index(10), mdcath_dataset.is_used_index(9)

In [11]:
from concurrent.futures import ThreadPoolExecutor
import threading
from typing import Callable
import os
from src.data.load.trajectory_dataset import TrajectoryDataset
import time

def process_dataset_parallel(dataset: TrajectoryDataset, process_fn: Callable, max_workers: int = os.cpu_count() - 1):
    def _process_item(idx: int):
        try:
            item = dataset[idx]
            result = process_fn(item)
            dataset.use_trajectory_location(idx)
            return result
        except ValueError as e:
            print(e)

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(_process_item, range(len(dataset))))

    return [r for r in results if r is not None]


def example_process_fn(item):
    time.sleep(1)
    print(item)
    time.sleep(1)
    return item


results = process_dataset_parallel(mdcath_dataset, example_process_fn)
print(results)

Trajectory ../tmp/data/mdcath/data/mdcath_dataset_1a0aA00.h5 has already been accessed and used.Trajectory ../tmp/data/mdcath/data/mdcath_dataset_1a0sP00.h5 has already been accessed and used.

Trajectory ../tmp/data/mdcath/data/mdcath_dataset_1a15A00.h5 has already been accessed and used.
Trajectory ../tmp/data/mdcath/data/mdcath_dataset_1a3oA00.h5 has already been accessed and used.
Trajectory ../tmp/data/mdcath/data/mdcath_dataset_1a2nA02.h5 has already been accessed and used.
Trajectory ../tmp/data/mdcath/data/mdcath_dataset_1a0rP01.h5 has already been accessed and used.
Trajectory ../tmp/data/mdcath/data/mdcath_dataset_1ba5A00.h5 has already been accessed and used.
Trajectory ../tmp/data/mdcath/data/mdcath_dataset_12asA00.h5 has already been accessed and used.
1a48A01
1a05A00
1a6cA02
1a02F00
1a6aB01
1a0hA01
1a1zA00
1a66A00
1avyB00
1bdhA01
16pkA02
1a3dA00
1a6sA00
153lA00
1a39A00
1a5cA00
['1a48A01', '1a05A00', '1a02F00', '1a6cA02', '1a1zA00', '1a66A00', '1a6aB01', '1a0hA01', '1avyB0

In [12]:
len(results)

16

In [None]:
len(mdcath_dataset.used_trajectory_locations)

In [9]:
mdcath_dataset.reset()