# Combined Dataset Loader Strategies

This notebook explores several strategies for combining multiple datasets into a single data loader using GluonTS. The goal is for each batch of the data loader to **contain a balanced mix of samples from each dataset**.

## Initialization

Define a function to generate dummy datasets.

In [37]:
from gluonts.dataset.common import ListDataset
import pandas as pd
import numpy as np


def generate_dummy_dataset(
    start_date="01-01-2019",
    freq="1h",
    num_entries=10,
    num_timesteps=100,
    source=None,
):
    """
    Utility function to generate a dataset of random time series entries.

    Args:
        start_date (str or pd.Timestamp): The start timestamp for each entry
            (e.g., "2023-01-01").
        freq (str): The frequency of the time series (e.g., "1H", "1D").
        num_entries (int): The number of time series to generate.
        num_timesteps (int): The number of time steps in each time series.
        source (str): A label identifying which dataset the time series
            belongs to.

    Returns:
        ListDataset: A dataset of time series entries, each a dictionary with
            "start", "target", "source", and "series_number" keys.
    """
    start = pd.Period(start_date, freq=freq)
    dataset = np.random.normal(size=(num_entries, num_timesteps))

    return ListDataset(
        [
            {
                "start": start,
                "target": target,
                "source": source,  # Custom field to track dataset
                "series_number": index,
            }
            for index, target in enumerate(dataset)
        ],
        freq=freq,
    )

Create four dummy datasets of very different sizes, ranging from 10 entries of 100 time steps to 1000 entries of 1000 time steps.

In [38]:
dataset_one = generate_dummy_dataset(
    source="Dataset 1",
    num_entries=10,
    num_timesteps=100,
)
dataset_two = generate_dummy_dataset(
    source="Dataset 2",
    num_entries=50,
    num_timesteps=500,
)
dataset_three = generate_dummy_dataset(
    source="Dataset 3",
    num_entries=100,
    num_timesteps=1000,
)
dataset_four = generate_dummy_dataset(
    source="Dataset 4",
    num_entries=1000,
    num_timesteps=1000,
)


context_length = 48
prediction_length = 24

Take a look at the first entry of each dataset.

In [39]:
datasets = [dataset_one, dataset_two, dataset_three, dataset_four]

for dataset in datasets:
    first_entry = next(iter(dataset))
    print(f"start: {first_entry['start']}")
    print(f"target: {first_entry['target'][:10]}")
    print(f"source: {first_entry['source']}")
    print(f"series number: {first_entry['series_number']}")
    print("-" * 80)

start: 2019-01-01 00:00
target: [-0.6787465  -0.95481426  0.80087197  0.60148644  1.9362926   0.6924021
 -0.63677317 -0.38945454  1.1554981  -2.0923688 ]
source: Dataset 1
series number: 0
--------------------------------------------------------------------------------
start: 2019-01-01 00:00
target: [ 0.33069488  0.2391058   1.4121292   1.9540036  -0.24719605  0.2795431
  0.82791764 -1.6132694  -0.91395915 -1.8362839 ]
source: Dataset 2
series number: 0
--------------------------------------------------------------------------------
start: 2019-01-01 00:00
target: [ 0.02935562 -2.1665876  -0.4610103  -0.15737787 -1.3025767  -0.7422203
  1.6018269   0.26374477 -0.93608266 -0.3171792 ]
source: Dataset 3
series number: 0
--------------------------------------------------------------------------------
start: 2019-01-01 00:00
target: [ 2.0260003   0.8996754  -0.062847   -1.1292576   1.2507395   0.89989054
  0.9305872  -0.89675844  1.0092256   1.8340392 ]
source: Dataset 4
series number: 0


## Strategies

Here are some strategies I came up with for combining multiple datasets into a single data loader.

### Strategy 1

Store all of the datasets in a `DatasetCollection` with `interleave=True`, wrapping each one in `Cyclic` so that it repeats indefinitely.

In [40]:
from gluonts.dataset import DatasetCollection
from gluonts.itertools import Cyclic

dataset_collection = DatasetCollection(
    datasets=[Cyclic(dataset) for dataset in datasets],
    interleave=True,
)
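As I understand it, `interleave=True` makes `DatasetCollection` take entries from its datasets in round-robin order, and `Cyclic` restarts each dataset once it is exhausted. The sketch below mimics that combination in plain Python; `interleave_cyclic` is a hypothetical helper, not a GluonTS function. It shows that every dataset contributes one entry per round, regardless of how many entries it holds.

```python
from itertools import cycle, islice


def interleave_cyclic(iterables):
    """Hypothetical helper: yield one element from each iterable in turn,
    restarting any iterable that runs out (similar to Cyclic + interleave=True)."""
    iterators = [cycle(iterable) for iterable in iterables]
    while True:
        for iterator in iterators:
            yield next(iterator)


# A small dataset (2 entries) and a larger one (3 entries)
small = ["small-0", "small-1"]
large = ["large-0", "large-1", "large-2"]

# Take the first 6 entries of the interleaved stream
print(list(islice(interleave_cyclic([small, large]), 6)))
# ['small-0', 'large-0', 'small-1', 'large-1', 'small-0', 'large-2']
```

The small dataset wraps around and keeps contributing half of the entries, even though it has fewer entries than the large one.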

Define transformations for the data loader.

In [41]:
from gluonts.dataset.field_names import FieldName
from gluonts.transform import (
    AddObservedValuesIndicator,
    InstanceSplitter,
    ExpectedNumInstanceSampler,
)

mask_unobserved = AddObservedValuesIndicator(
    target_field=FieldName.TARGET,
    output_field=FieldName.OBSERVED_VALUES,
)

instance_sampler = ExpectedNumInstanceSampler(
    num_instances=1,
    min_future=prediction_length,
)

time_series_fields = [FieldName.OBSERVED_VALUES]

splitter = InstanceSplitter(
    target_field=FieldName.TARGET,
    is_pad_field=FieldName.IS_PAD,
    start_field=FieldName.START,
    forecast_start_field=FieldName.FORECAST_START,
    instance_sampler=instance_sampler,
    past_length=context_length,
    future_length=prediction_length,
    time_series_fields=time_series_fields,
)
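For intuition about what the splitter produces, here is a simplified NumPy sketch of cutting one series at a split point: the `past_length` points before it form the past window and the `future_length` points from it onward form the future window, with zero left-padding and a padding indicator when there is not enough history. `split_at` is a hypothetical helper; the real `InstanceSplitter` also handles extra fields such as the observed-value indicators and `forecast_start`.

```python
import numpy as np


def split_at(target, idx, past_length, future_length):
    """Hypothetical helper: cut `target` into past/future windows at `idx`.

    Past windows shorter than `past_length` are left-padded with zeros, and
    `past_is_pad` marks the padded positions with 1.0.
    """
    past = target[max(idx - past_length, 0):idx]
    pad_length = past_length - len(past)
    past_target = np.concatenate([np.zeros(pad_length), past])
    past_is_pad = np.concatenate([np.ones(pad_length), np.zeros(len(past))])
    future_target = target[idx:idx + future_length]
    return past_target, past_is_pad, future_target


series = np.arange(100, dtype=float)  # stand-in for one target series
past, is_pad, future = split_at(series, idx=30, past_length=48, future_length=24)
print(past.shape, int(is_pad.sum()), future[:3])
# (48,) 18 [30. 31. 32.]
```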

Initialize a `TrainDataLoader`. Note that the `shuffle_buffer_length` argument is commented out in this cell, so the instances in each batch are not shuffled.

In [42]:
from gluonts.dataset.loader import TrainDataLoader
from gluonts.torch.batchify import batchify

batch_size = 32
num_batches_per_epoch = 1000

data_loader = TrainDataLoader(
    dataset=dataset_collection,
    batch_size=batch_size,
    stack_fn=batchify,
    transform=mask_unobserved + splitter,
    num_batches_per_epoch=num_batches_per_epoch,
    # shuffle_buffer_length=1024,
)

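Count how many samples from each dataset appear in each batch, and keep running totals across batches.
- Each batch is a dict of key-value pairs:
  - Each key is a different feature, e.g., "start" or "past_target"
  - Each value holds `batch_size` elements: a stacked tensor for array fields such as "past_target", and a plain list for other fields such as "source"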
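The counting loop in this section accumulates running totals. To inspect the mix of a single batch instead, the counts can be computed per batch. The sketch uses a hand-made stand-in for a batch with the same layout (a `"source"` list and a `"series_number"` list); `batch_share` is a hypothetical helper. Pairing each source with its series number, rather than using series numbers as dict keys, keeps entries from different datasets that share a series number apart.

```python
from collections import Counter


def batch_share(batch, key="source"):
    """Hypothetical helper: fraction of samples in one batch from each source."""
    counts = Counter(batch[key])
    total = sum(counts.values())
    return {source: counts[source] / total for source in sorted(counts)}


# Hand-made stand-in for one batch; real batches hold batch_size entries
batch = {
    "source": ["Dataset 1", "Dataset 4", "Dataset 4", "Dataset 2"],
    "series_number": [0, 0, 1, 0],
}

print(batch_share(batch))
# {'Dataset 1': 0.25, 'Dataset 2': 0.25, 'Dataset 4': 0.5}

# (source, series_number) pairs keep entries from different datasets apart
print(list(zip(batch["source"], batch["series_number"])))
```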

In [43]:
from collections import Counter

counter = Counter()

for i, batch in enumerate(data_loader):
    print(f"Batch {i + 1}")

    series_numbers = batch["series_number"]
    sources = batch["source"]

    # See which time series were used for this batch. Note: keying the dict by
    # series number collapses entries from different datasets that share the
    # same number, so fewer than batch_size pairs are printed.
    print(
        {
            series_number: source
            for series_number, source in zip(series_numbers, sources)
        }
    )

    # Update the running counter with the number of entries from each source
    counter.update(sources)
    print(f"Running counter: {dict(sorted(counter.items()))}")

    # Total number of entries seen so far, across all batches
    total = sum(counter.values())

    # Convert each running count to a percentage
    percentages = {k: f"{(v / total) * 100}%" for k, v in counter.items()}
    print(f"Running percentages: {dict(sorted(percentages.items()))}")
    print("-" * 80)

Batch 1
{0: 'Dataset 4', 1: 'Dataset 4', 2: 'Dataset 4', 3: 'Dataset 3', 4: 'Dataset 4', 5: 'Dataset 4', 6: 'Dataset 4'}
Running counter: {'Dataset 1': 1, 'Dataset 2': 7, 'Dataset 3': 11, 'Dataset 4': 13}
Running percentages: {'Dataset 1': '3.125%', 'Dataset 2': '21.875%', 'Dataset 3': '34.375%', 'Dataset 4': '40.625%'}
--------------------------------------------------------------------------------
Batch 2
{7: 'Dataset 4', 8: 'Dataset 4', 9: 'Dataset 3', 10: 'Dataset 4', 1: 'Dataset 1', 11: 'Dataset 4', 12: 'Dataset 4', 13: 'Dataset 4', 14: 'Dataset 4', 15: 'Dataset 2'}
Running counter: {'Dataset 1': 2, 'Dataset 2': 13, 'Dataset 3': 24, 'Dataset 4': 25}
Running percentages: {'Dataset 1': '3.125%', 'Dataset 2': '20.3125%', 'Dataset 3': '37.5%', 'Dataset 4': '39.0625%'}
--------------------------------------------------------------------------------
Batch 3
{15: 'Dataset 4', 16: 'Dataset 4', 17: 'Dataset 4', 19: 'Dataset 3', 20: 'Dataset 4', 21: 'Dataset 3'}
Running counter: {'Dataset 1

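Thoughts:
- Every batch has an uneven number of samples from each dataset.
  - The datasets with the longest series (Datasets 3 and 4, 1000 time steps each) dominate, while Dataset 1 (100 time steps) contributes only about 3% of samples.
  - Dataset 4 has ten times as many entries as Dataset 3 but receives a similar share, which suggests that series length, rather than the number of entries, drives the imbalance.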

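### Strategy 2

Sample from the datasets at random instead of interleaving them. Define a helper that swaps the smallest element of a list with the largest, the second smallest with the second largest, and so on.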

In [44]:
def symmetric_extreme_swap(lst):
    """
    Swap the smallest element with the largest, the second smallest with the
    second largest, and so on.

    Args:
        lst (list): A list of numbers.

    Returns:
        list: A new list with the specified elements swapped.
    """
    lst_copy = lst.copy()
    sorted_indices = sorted(range(len(lst)), key=lambda i: lst[i])
    n = len(lst)

    for i in range(n // 2):
        small_idx = sorted_indices[i]
        large_idx = sorted_indices[-(i + 1)]
        lst_copy[small_idx], lst_copy[large_idx] = (
            lst_copy[large_idx],
            lst_copy[small_idx],
        )

    return lst_copy

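Use `RandomYield` to sample from the datasets at random, with the aim of increasing the probability of sampling smaller datasets. The probabilities below are inversely proportional to the number of entries in each dataset. Note that they are only printed, not passed to `RandomYield`, so in this run each dataset is chosen with equal probability.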

In [48]:
from gluonts.itertools import RandomYield, Cyclic

dataset_lengths = [len(list(dataset)) for dataset in datasets]
inverse_lengths = np.array([1.0 / length for length in dataset_lengths])
probabilities = (inverse_lengths / inverse_lengths.sum()).tolist()

print(probabilities)

random_yield = RandomYield(
    iterables=[Cyclic(dataset) for dataset in datasets],
)

[0.7633587786259542, 0.15267175572519084, 0.07633587786259542, 0.007633587786259542]
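For intuition about what sampling with these probabilities would do, here is a small stand-in for weighted random yielding: at each step, pick a dataset according to `probabilities` and yield its next entry. `weighted_random_yield` is a hypothetical helper, not the GluonTS `RandomYield` implementation; it cycles every iterable so none runs out. It balances *entries*; the number of training windows drawn per entry can still differ with series length.

```python
from collections import Counter
from itertools import cycle

import numpy as np


def weighted_random_yield(iterables, probabilities, seed=None):
    """Hypothetical helper: repeatedly pick an iterable at random according to
    `probabilities` and yield its next element (iterables are cycled)."""
    rng = np.random.default_rng(seed)
    iterators = [cycle(iterable) for iterable in iterables]
    while True:
        index = rng.choice(len(iterators), p=probabilities)
        yield next(iterators[index])


# Toy datasets whose sizes match the dummy datasets: 10, 50, 100, 1000 entries
sizes = [10, 50, 100, 1000]
toy_datasets = [[f"Dataset {i + 1}"] * size for i, size in enumerate(sizes)]

# Sampling probabilities inversely proportional to dataset size
inverse = np.array([1.0 / size for size in sizes])
probabilities = (inverse / inverse.sum()).tolist()

stream = weighted_random_yield(toy_datasets, probabilities, seed=0)
counts = Counter(next(stream) for _ in range(10_000))
print(dict(sorted(counts.items())))
```

With inverse-size weights, roughly three quarters of the draws come from Dataset 1, so this weighting over-corrects rather than producing an even mix.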


Create a data loader, this time with `shuffle_buffer_length=1024`.

In [49]:
data_loader = TrainDataLoader(
    dataset=random_yield,
    batch_size=batch_size,
    stack_fn=batchify,
    transform=mask_unobserved + splitter,
    num_batches_per_epoch=num_batches_per_epoch,
    shuffle_buffer_length=1024,
)
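`shuffle_buffer_length` enables buffer-based pseudo-shuffling: instances pass through a fixed-size buffer and are emitted from random positions in it. The sketch below is a simplified stand-in (`pseudo_shuffle` is a hypothetical helper, and the exact GluonTS logic may differ). It illustrates that such a buffer reorders instances locally but leaves the overall mix of datasets unchanged.

```python
import random


def pseudo_shuffle(iterable, buffer_length, seed=None):
    """Hypothetical helper: shuffle a stream through a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        buffer.append(item)
        if len(buffer) >= buffer_length:
            # Emit a random element once the buffer is full
            yield buffer.pop(rng.randrange(len(buffer)))
    # Drain what is left in random order when the input ends
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))


stream = ["A"] * 30 + ["B"] * 10
shuffled = list(pseudo_shuffle(stream, buffer_length=8, seed=0))
print(shuffled[:10])
print(shuffled.count("A"), shuffled.count("B"))  # 30 10
```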

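Count how many samples from each dataset appear in each batch, and keep running totals across batches.
- Each batch is a dict of key-value pairs:
  - Each key is a different feature, e.g., "start" or "past_target"
  - Each value holds `batch_size` elements: a stacked tensor for array fields such as "past_target", and a plain list for other fields such as "source"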

In [50]:
from collections import Counter

counter = Counter()

for i, batch in enumerate(data_loader):
    print(f"Batch {i + 1}")

    series_numbers = batch["series_number"]
    sources = batch["source"]

    # See which time series were used for this batch. Note: keying the dict by
    # series number collapses entries from different datasets that share the
    # same number, so fewer than batch_size pairs are printed.
    print(
        {
            series_number: source
            for series_number, source in zip(series_numbers, sources)
        }
    )

    # Update the running counter with the number of entries from each source
    counter.update(sources)
    print(f"Running counter: {dict(sorted(counter.items()))}")

    # Total number of entries seen so far, across all batches
    total = sum(counter.values())

    # Convert each running count to a percentage
    percentages = {k: f"{(v / total) * 100}%" for k, v in counter.items()}
    print(f"Running percentages: {dict(sorted(percentages.items()))}")
    print("-" * 80)

Batch 1
{179: 'Dataset 4', 41: 'Dataset 2', 137: 'Dataset 4', 83: 'Dataset 3', 11: 'Dataset 2', 87: 'Dataset 3', 63: 'Dataset 3', 15: 'Dataset 2', 8: 'Dataset 1', 69: 'Dataset 4', 85: 'Dataset 3', 7: 'Dataset 3'}
Running counter: {'Dataset 1': 1, 'Dataset 2': 7, 'Dataset 3': 14, 'Dataset 4': 10}
Running percentages: {'Dataset 1': '3.125%', 'Dataset 2': '21.875%', 'Dataset 3': '43.75%', 'Dataset 4': '31.25%'}
--------------------------------------------------------------------------------
Batch 2
{41: 'Dataset 2', 24: 'Dataset 2', 224: 'Dataset 4', 73: 'Dataset 3', 153: 'Dataset 4', 8: 'Dataset 4', 1: 'Dataset 1', 52: 'Dataset 3', 194: 'Dataset 4', 45: 'Dataset 3', 43: 'Dataset 3', 59: 'Dataset 3', 159: 'Dataset 4', 241: 'Dataset 4', 86: 'Dataset 3'}
Running counter: {'Dataset 1': 2, 'Dataset 2': 10, 'Dataset 3': 27, 'Dataset 4': 25}
Running percentages: {'Dataset 1': '3.125%', 'Dataset 2': '15.625%', 'Dataset 3': '42.1875%', 'Dataset 4': '39.0625%'}
------------------------------------