In [45]:
import glob
import webdataset as wds
import os

data_dir = '../data/cpg_methylation/hyenadna-tiny-1k'

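def print_samples_id(dataloader, n_samples=10):
    """Print the `__key__` of the first `n_samples` samples yielded by `dataloader`."""
    for idx, sample in enumerate(dataloader):
        print(sample['__key__'])

        # Stop once n_samples keys have been printed.
        if idx == n_samples - 1:
            break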

# Types of shuffling in WebDataset

Usually, the whole dataset is shuffled once up front, which breaks the positional relationships between samples.

However, the BEND benchmark skips this step. Instead, embeddings are computed from the annotation `.bed` file, which is sorted by chromosome and sequence position. Sequences are therefore embedded sequentially and written to WebDataset shards in their positional order.

### Creating shards

The `chunk_size` parameter in `embed.yml` defines the shard size.

Annotation data is split into chunks of `chunk_size` samples. Each chunk is embedded by processing its samples sequentially and is then saved into a `shard` (a tar file).

The default `chunk_size` is 50,000.
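
The chunking described above can be sketched in plain Python. This is only an illustration, not the BEND code: `chunk_ranges` and `shard_of` are hypothetical helpers, the split size of 120,000 is made up, and the sketch assumes that sample keys are global indices into the sorted annotation file and that shards are named `<split>_<chunk_id>.tar.gz`, as the file listing and outputs in this notebook suggest.

```python
def chunk_ranges(n_samples, chunk_size=50_000):
    """Yield (chunk_id, start, stop) index ranges, one per shard."""
    for chunk_id, start in enumerate(range(0, n_samples, chunk_size)):
        yield chunk_id, start, min(start + chunk_size, n_samples)


def shard_of(sample_index, chunk_size=50_000):
    """Return the chunk id of the shard that would hold a given sample index."""
    return sample_index // chunk_size


# Hypothetical split with 120,000 annotated samples.
for chunk_id, start, stop in chunk_ranges(120_000):
    print(f"test_{chunk_id}.tar.gz -> sample_{start} .. sample_{stop - 1}")
```

Under these assumptions, `sample_50000` is the first sample of `test_1.tar.gz`, and only the last shard can be smaller than `chunk_size`.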


#### Shuffling shards

Shards are not shuffled explicitly (or perhaps not knowingly): although shards are named after `split_chunk-id`, they are read with `glob.glob(path_to_shards)`, which returns paths in arbitrary order. Hence, they are NOT in ascending order by `chunk_id`, which amounts to a degree of 'shuffling'.

Following this behaviour, the more shards there are, the more this incidental reordering mixes the data.

The proper way to shuffle at the shard level is to set `shardshuffle` to the number of shards when creating the `WebDataset`.
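
To make the shard order independent of what `glob.glob()` happens to return, the paths can be sorted by chunk id first and, if desired, shuffled with an explicit seed. The sketch below uses only the standard library; `chunk_id`, `ordered_shards` and `shuffled_shards` are hypothetical helpers, and the regular expression assumes the `<split>_<chunk_id>.tar.gz` naming seen in this notebook.

```python
import os
import random
import re


def chunk_id(path):
    """Extract the trailing chunk id from a shard name such as 'test_1.tar.gz'."""
    match = re.search(r"_(\d+)\.tar\.gz$", os.path.basename(path))
    if match is None:
        raise ValueError(f"unexpected shard name: {path}")
    return int(match.group(1))


def ordered_shards(paths):
    """Sort shards numerically by chunk id (so 'test_10' comes after 'test_2')."""
    return sorted(paths, key=chunk_id)


def shuffled_shards(paths, seed):
    """Shuffle shards reproducibly: sort first, then permute with a seeded RNG."""
    shards = ordered_shards(paths)
    random.Random(seed).shuffle(shards)
    return shards
```

Sorting before shuffling means the result depends only on the seed, not on the filesystem order. The numeric key also avoids the lexicographic trap where `test_10` would sort before `test_2`.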

In [46]:
split = 'test'

tars = glob.glob(f"{data_dir}/*.tar.gz")
data = [x for x in tars if os.path.split(x)[-1].startswith(split)]
data

['../data/cpg_methylation/hyenadna-tiny-1k/test_1.tar.gz',
 '../data/cpg_methylation/hyenadna-tiny-1k/test_2.tar.gz',
 '../data/cpg_methylation/hyenadna-tiny-1k/test_0.tar.gz']

In [None]:
# shardshuffle is None in the original code, which means the shards are not shuffled.
# However, None raises a warning asking for it to be set explicitly to False or a number.
# To avoid the warning and keep the original behaviour, we set it to False.

dataset = wds.WebDataset(data, shardshuffle=False)

dataloader = wds.WebLoader(dataset, num_workers=0, batch_size=None)

print_samples_id(dataloader)

sample_50000
sample_50001
sample_50002
sample_50003
sample_50004
sample_50005
sample_50006
sample_50007
sample_50008
sample_50009


### In-memory shuffling
Each worker sequentially loads samples from one or more shards into batches. A worker cannot be assigned zero shards (this throws an error), so make sure the number of workers is <= the number of shards.

With `.shuffle(buffer_size)`, samples are first loaded into a buffer and then drawn from it at random to form batches.

A buffer size equal to the number of samples in a shard allows that shard to be shuffled completely.
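
The effect of the buffer can be illustrated with a simplified model in plain Python. This is not WebDataset's implementation: `buffered_shuffle` is a hypothetical function that only covers the case where the buffer is filled completely before the first sample is emitted (the `initial=buffer_size` setting used in this notebook).

```python
import random


def buffered_shuffle(stream, buffer_size, seed=0):
    """Yield items from `stream` in a locally shuffled order.

    The buffer is filled with `buffer_size` items first. After that, every
    incoming item is added to the buffer and one random item is emitted.
    """
    rng = random.Random(seed)
    buffer = []
    for item in stream:
        buffer.append(item)
        if len(buffer) > buffer_size:
            # Swap a random element to the end and emit it (O(1) removal).
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # The stream is exhausted: emit whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffered_shuffle(range(1000), buffer_size=200))
print(shuffled[:10])
```

In this model the item emitted at position `i` can only come from the first `i + buffer_size + 1` items of the stream, so the shuffle is local. Only a buffer at least as large as the whole stream produces a full shuffle.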

In [50]:
dataset = wds.WebDataset(data, shardshuffle=False)


# Explanation of the `initial` argument: https://github.com/webdataset/webdataset/issues/62
# Basically, it makes the stream wait until the given number of samples is buffered before shuffled samples are emitted.
buffer_size = 200
dataset = dataset.shuffle(buffer_size, initial=buffer_size)

dataloader = wds.WebLoader(dataset, num_workers=0, batch_size=None)

print_samples_id(dataloader)

sample_50094
sample_50193
sample_50184
sample_50175
sample_50177
sample_50201
sample_50102
sample_50054
sample_50166
sample_50183


### Multiprocessing shuffling

When the dataloader's `num_workers` parameter is set to 0, one shard is loaded at a time. Consequently, the shards are accessed in the order in which the `.tar` files were returned by `glob.glob()`.

In [51]:
data

['../data/cpg_methylation/hyenadna-tiny-1k/test_1.tar.gz',
 '../data/cpg_methylation/hyenadna-tiny-1k/test_2.tar.gz',
 '../data/cpg_methylation/hyenadna-tiny-1k/test_0.tar.gz']

However, increasing the number of workers to a value greater than 0 enables multiprocessing, and the number of shards accessed simultaneously equals the number of workers.

This greatly increases randomisation, as samples from multiple shards are processed at once. If the number of shards equals the number of workers, samples can be drawn from any __section__ of the dataset.

However, each shard/section is accessed sequentially.
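
The resulting order can be modelled without any dataloader. In the sketch below, each worker is a generator that reads one shard sequentially, and the loader collects one sample from each worker in turn. `shard_keys` and `round_robin` are hypothetical helpers. The model assumes one shard per worker, keys of the form `sample_<index>`, a `chunk_size` of 50,000 and strict round-robin collection; a real dataloader involves prefetching and process scheduling that this ignores.

```python
from itertools import islice, zip_longest

_DONE = object()  # sentinel marking an exhausted worker


def shard_keys(chunk_id, n_samples, chunk_size=50_000):
    """Simulate one worker reading the keys of one shard in order."""
    start = chunk_id * chunk_size
    return (f"sample_{i}" for i in range(start, start + n_samples))


def round_robin(streams):
    """Take one item from each stream in turn, skipping exhausted streams."""
    for group in zip_longest(*streams, fillvalue=_DONE):
        for item in group:
            if item is not _DONE:
                yield item


# Shard order as returned by glob in this notebook: test_1, test_2, test_0.
workers = [shard_keys(chunk_id, n_samples=5) for chunk_id in (1, 2, 0)]
for key in islice(round_robin(workers), 10):
    print(key)
```

This prints `sample_50000`, `sample_100000`, `sample_0`, `sample_50001`, and so on: the shards are interleaved, but the samples within each shard stay in their original order.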

In [52]:
dataset = wds.WebDataset(data, shardshuffle=False)

dataloader = wds.WebLoader(dataset, num_workers=len(data), batch_size=None)

print_samples_id(dataloader)

sample_50000
sample_100000
sample_0
sample_50001
sample_100001
sample_1
sample_50002
sample_100002
sample_2
sample_50003


Of course, `dataset.shuffle(buffer_size)` can also be used here, so that each worker shuffles the samples of its shard within a rolling buffer of N samples, where N = `buffer_size`.

In [53]:
dataset = wds.WebDataset(data, shardshuffle=False)

buffer_size = 200
dataset = dataset.shuffle(buffer_size, initial=buffer_size)

dataloader = wds.WebLoader(dataset, num_workers=len(data), batch_size=None)

print_samples_id(dataloader)
    

sample_50130
sample_100060
sample_17
sample_50152
sample_100139
sample_74
sample_50102
sample_100027
sample_156
sample_50028
