In [13]:
import os
import argparse
import torch
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data import get_worker_info
import torch.distributed as dist
import transformers
from transformers import default_data_collator
from transformers import AutoTokenizer
from omegaconf import DictConfig, OmegaConf as om
from tqdm import tqdm
import numpy as np
import random
random.seed(111)
from collections.abc import Mapping
from typing import Optional, Tuple, Union, Any, Dict, List
import logging		


# Add ModernBERT to path
# TODO: add morerngena distr target path argparse
import sys
sys.path.append(os.path.abspath("../"))
sys.path.append(os.path.abspath("."))

from src.text_data import NoStreamingGenomeDataset

class _GenomeDatasetForMasking(NoStreamingGenomeDataset):
	def __getitem__(self, index: int):
		shard_id, shard_sample_id = self.spanner[index]
		shard = self.shards[shard_id]
		sample = shard[shard_sample_id]
		return sample['file_id'], sample['line_id'], shard_id


data_dir= "/mnt/nfs_dna/shadskiy/promoters/pretrena/mds_v2/"
tokenizer = AutoTokenizer.from_pretrained("AIRI-Institute/gena-lm-bert-base-t2t")
# Create dataset
ds = _GenomeDatasetForMasking(
	local=data_dir,
	split="train",
	max_seq_len=1024,
	tokenizer = tokenizer,
)




In [None]:
genomes = set()
last_shard=30
for i in tqdm(range(0,len(ds),1)):
	_ = ds[i]
	genome = _[0]
	shard_id = _[2]
	if shard_id>last_shard:
		print (i,shard_id)
		break
	else:
		if i % 5000 == 0:
			print (i,shard_id)
	# if genome.find("GCF_000001405")!=-1:
	# 	print (i)
	# 	print (genome)
	# 	break
	# else:
	# 	genomes.add(genome)
print (len(genomes))

In [28]:
print (ds.spanner.shard_bounds[0:5])

[    0  2586  5093  7654 10256]


In [27]:
len(ds.shards[0])

2586

In [37]:
import h5py
with h5py.File("../runs/test/mlm_efficiency/train/shard_0.hdf5", "r") as f:
	data = f['0'][:]

In [40]:
data.sum(), len(data)

(11434, 11434)

In [41]:
data[:10]

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True])

In [1]:
import numpy as np
bool(np.nan)

True

In [1]:
import h5py
import numpy as np

with h5py.File("../runs/test/test.h5","w") as f:
	for _ in range(100):
		f.create_dataset(str(_), 
				data = np.zeros(shape=(3000,), dtype=bool), 
				dtype=bool
				)

In [6]:
# test time
import datetime
import random
import h5py
import numpy as np

# ids = [2,3,7]
ids = np.arange(50)
ids = np.random.permutation(ids).tolist()
print (ids)

_ = datetime.datetime.now()
random.seed(23)
for repeat in range(100):
	with h5py.File("../runs/test/test.h5","a") as f:
		sample = random.randint(0,99)
		data = f[str(sample)][:]
		
		data[ids] = True
		f[str(sample)][:] = data  # Modify existing dataset in-place
timedelta = datetime.datetime.now() - _
print ("Variant1 ", timedelta)

_ = datetime.datetime.now()
random.seed(23)
for repeat in range(10):
	with h5py.File("../runs/test/test.h5","a") as f:
		sample = random.randint(0,99)
		for i in ids:
			f[str(sample)][i] = True
timedelta = datetime.datetime.now() - _
print ("Variant2 ", timedelta)

[26, 31, 11, 6, 8, 16, 15, 21, 47, 17, 13, 42, 20, 43, 41, 4, 34, 14, 39, 45, 25, 3, 29, 9, 33, 32, 37, 35, 46, 19, 18, 1, 27, 10, 44, 22, 23, 28, 24, 5, 0, 40, 2, 30, 12, 49, 36, 38, 7, 48]
Variant1  0:00:00.293402
Variant2  0:00:00.142648
