In [4]:
from torch.utils.data import Dataset
from streaming.base.format import reader_from_json
import numpy as np
from streaming.base.spanner import Spanner

class NoStreamingDataset(Dataset):
	"""
	A dataset class that can read data with raw mds-format (mosaic streaming-format without compression)
	from local. In comparison with `StreamingTextDataset` that also can read data with mds-format from local,
	this class is slimmer, more efficient, and does not contain redundant code required for streaming.
	"""

	def __init__(
		self,
		local: str,
		split: str,
		max_seq_len: int,
		tokenizer = None,
		pad_sequences: bool = True,
	) -> None:
		super().__init__()
		if split is not None:
			split_path = os.path.join(local, split)
		else:
			split_path = local
		index_file_path = os.path.join(split_path, "index.json")
		obj = json.load(open(index_file_path))
		self.shards = []
		for info in obj["shards"]:
			shard = reader_from_json(local, split, info)
			raw_filename = os.path.join(shard.dirname, shard.split, shard.raw_data.basename)
			assert os.path.isfile(raw_filename), f"Raw file {raw_filename} does not exist"
			shard.validate(True)
			self.shards.append(shard)
		samples_per_shard = np.array([shard.samples for shard in self.shards], np.int64)
		self.len = samples_per_shard.sum()
		self.spanner = Spanner(samples_per_shard)
		self.max_seq_len = max_seq_len
		self.tokenizer = tokenizer
		self.pad_sequences = pad_sequences

	def _tokenize(self, text_sample):
		assert self.tokenizer is not None, "Tokenizer required if data is not pretokenized"
		# if self.tokenizer._pad_token is None:
		#     # Some tokenizers (e.g. GPT2 tokenizer) have no padding token which causes bugs
		#     raise RuntimeError("If tokenizing on-the-fly, tokenizer must have a pad_token_id")

		return self.tokenizer(
			text_sample["text"],
			truncation=True,
			padding="max_length" if self.pad_sequences else False,
			max_length=self.max_seq_len,
		)

	def __getitem__(self, index: int):
		shard_id, shard_sample_id = self.spanner[index]
		if index == 303114:
			print (shard_id, shard_sample_id)
		shard = self.shards[shard_id]
		sample = shard[shard_sample_id]
		if "input_ids" in sample:
			for k in list(sample.keys()):
				if isinstance(sample[k], np.ndarray):
					sample[k] = sample[k][: self.max_seq_len]
				else:
					del sample[k]
			if "attention_mask" not in sample:
				sample["attention_mask"] = np.ones_like(sample["input_ids"])
			return sample
		elif "text" in sample:
			s = self._tokenize(sample)
		else:
			RuntimeError("Data sample must contain a field with `input_ids` or `text`")

	def __len__(self):
		return self.len

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import os
import json
from tqdm import tqdm


data_dir = '../data/promoters_mds_full/'

dataset = NoStreamingDataset(data_dir, "train", 10, tokenizer = lambda x, *args, **kwargs: x)
print (len(dataset))

35758782


In [7]:
dataset.__getitem__(303114)

163 143


In [9]:
!md5sum /mnt/nfs_dna/minja/DNALM/promoter_pretrain/train/shard.00163.mds

a4b812714567a4eeec2e24819f0b13ba  /mnt/nfs_dna/minja/DNALM/promoter_pretrain/train/shard.00163.mds


In [8]:
print("Starting dataset integrity check...")

for i in tqdm(range(len(dataset)), desc="Reading samples"):
	try:
		sample = dataset.__getitem__(i)
	except Exception as e:
		print(f"\nCorrupt sample at index {i}: {e}\nSample: {sample}")
		break

print("Check finished.")


Starting dataset integrity check...


Reading samples:   1%|          | 303305/35758782 [05:03<9:40:15, 1018.37it/s]

163 143


Reading samples:   1%|          | 352232/35758782 [05:51<9:49:23, 1001.20it/s]


KeyboardInterrupt: 

In [11]:
import pandas as pd
original = pd.read_csv('../data/promoters_mds_full/train.md5', sep=' ', header=None, names=['md5','NA', 'fname']).drop(columns=['NA'])
H200 = pd.read_csv('../data/promoters_mds_full/train_H200.md5', sep=' ', header=None, names=['md5H200','NA', 'fname']).drop(columns=['NA'])

assert len(original) == len(H200)
data = pd.merge(original, H200, on='fname', validate='one_to_one')

In [13]:
data['diff'] = data['md5'] != data['md5H200']
data[data['diff']]


Unnamed: 0,md5,fname,md5H200,diff
2762,4174201032735aaf4615e5f9b8bf54ba,train/shard.02761.mds,2f5b413880ab4568b17061f02bf456b2,True
2789,3de611cf19b5e10286f8a44d17124437,train/shard.02788.mds,b1d85d48c8a08ae5901e27db59b3b8a8,True
4099,f5de18521c9196d0e6b906d8f5f866cf,train/shard.04098.mds,36f96ec443afc43798085beb7fd7362c,True
4692,dc1def6d9dd37fba48690eda455408e8,train/shard.04691.mds,2ff513289ba58a88b6d10ecc392c4eaf,True
6078,cdc083ec6b2c4576107d9fb3b0f5385f,train/shard.06077.mds,6e5f93659864df7a033d7f50eddaf237,True
6489,40247af76e8a6d3b7bbf5f0cd3439c78,train/shard.06488.mds,53d66485f7709062afc6a074830084ca,True
12737,d36ea01a4a32043435c92f0c5819f484,train/shard.12736.mds,7a5fac7a7e8183b43ce371286fc890d7,True
15142,e7527ccac80772c8646af3c4e05d6681,train/shard.15141.mds,018b8eb6400de307453839437a7b5a42,True


In [14]:
import os
import subprocess

# Create tmp directory
os.makedirs('tmp', exist_ok=True)

# Get list of files that need to be redownloaded (where diff is True)
files_to_check = data[data['diff']]['fname'].tolist()

print("Checking files...")
for fname in files_to_check:
    shard_name = fname.split('/')[-1]
    s3_path = f"s3://genalm/data/pretraining/promoters/train/{shard_name}"
    local_path = f"tmp/{shard_name}"
    
    # Download file
    print(f"\nDownloading {shard_name}...")
    subprocess.run([
        "aws", "s3", "cp", 
        s3_path, local_path,
        "--endpoint-url", "https://s3.cloud.ru",
        "--profile", "airi"
    ])
    
    # Check MD5
    md5_proc = subprocess.run(["md5sum", local_path], capture_output=True, text=True)
    downloaded_md5 = md5_proc.stdout.split()[0]
    expected_md5 = data[data['fname'] == fname]['md5'].iloc[0]
    
    print(f"MD5 check for {shard_name}:")
    print(f"Expected:   {expected_md5}")
    print(f"Downloaded: {downloaded_md5}")
    print(f"Match: {downloaded_md5 == expected_md5}")


Checking files...

Downloading shard.02761.mds...
download: s3://genalm/data/pretraining/promoters/train/shard.02761.mds to tmp/shard.02761.mds
MD5 check for shard.02761.mds:
Expected:   4174201032735aaf4615e5f9b8bf54ba
Downloaded: 4174201032735aaf4615e5f9b8bf54ba
Match: True

Downloading shard.02788.mds...
download: s3://genalm/data/pretraining/promoters/train/shard.02788.mds to tmp/shard.02788.mds
MD5 check for shard.02788.mds:
Expected:   3de611cf19b5e10286f8a44d17124437
Downloaded: 3de611cf19b5e10286f8a44d17124437
Match: True

Downloading shard.04098.mds...
download: s3://genalm/data/pretraining/promoters/train/shard.04098.mds to tmp/shard.04098.mds
MD5 check for shard.04098.mds:
Expected:   f5de18521c9196d0e6b906d8f5f866cf
Downloaded: f5de18521c9196d0e6b906d8f5f866cf
Match: True

Downloading shard.04691.mds...
download: s3://genalm/data/pretraining/promoters/train/shard.04691.mds to tmp/shard.04691.mds
MD5 check for shard.04691.mds:
Expected:   dc1def6d9dd37fba48690eda455408e8
Dow

In [17]:
" ".join(files_to_check)

'train/shard.02761.mds train/shard.02788.mds train/shard.04098.mds train/shard.04691.mds train/shard.06077.mds train/shard.06488.mds train/shard.12736.mds train/shard.15141.mds'

In [20]:
files = " ".join(files_to_check)
print(f'for f in {files}; do shard_name=$(basename $f); echo "Downloading $shard_name..."; aws s3 cp "s3://genalm/data/pretraining/promoters/$shard_name" "./$f" --endpoint-url "https://s3.cloud.ru" --profile "airi"; done')


for f in train/shard.02761.mds train/shard.02788.mds train/shard.04098.mds train/shard.04691.mds train/shard.06077.mds train/shard.06488.mds train/shard.12736.mds train/shard.15141.mds; do shard_name=$(basename $f); echo "Downloading $shard_name..."; aws s3 cp "s3://genalm/data/pretraining/promoters/$shard_name" "./$f" --endpoint-url "https://s3.cloud.ru" --profile "airi"; done
