In [None]:
#|default_exp utils.splits

# Splits

> Helper functions for splitting data into subsets.

In [None]:
#|export
from fastcore.all import *
import random as _random
from collections import defaultdict
from polvo.utils.misc import random_local_seed

  from tqdm.autonotebook import tqdm


In [None]:
#|export
def random(items, probs, seed=None):
    """Shuffle a copy of `items` and cut it into consecutive splits sized by `probs`.

    Args:
        items: Sliceable, mutable sequence such as a list. It is copied with
            `items[:]` before shuffling, so a list passed in is not reordered.
        probs: Fraction of `items` that goes to each split, e.g. `[0.8, 0.1, 0.1]`.
        seed: Forwarded to `random_local_seed`, which wraps the shuffle
            (presumably to make it reproducible; see that helper).

    Returns:
        A list with one slice of the shuffled copy per entry in `probs`.
        For non-negative `probs` the split sizes are non-negative and sum to
        `len(items)`.
    """
    n = len(items)
    # Convert each fraction to an absolute size, rounded to the nearest integer
    sizes = [int(round(prob * n)) for prob in probs]
    # Rounding can leave the total above or below n. The difference goes to the
    # largest split; if the overshoot is bigger than that split, it is emptied
    # and the rest is taken from the next largest one. Letting a size go
    # negative would produce negative slice indexes and overlapping splits.
    remainder = n - sum(sizes)
    while remainder:
        largest = sizes.index(max(sizes))
        step = max(remainder, -sizes[largest])  # never push a size below zero
        sizes[largest] += step
        remainder -= step
    # Shuffle a copy of the items with the given seed
    with random_local_seed(seed):
        shuffled = items[:]
        _random.shuffle(shuffled)
    # Walk through the shuffled copy, taking one consecutive slice per split
    splits, start = [], 0
    for size in sizes:
        splits.append(shuffled[start:start + size])
        start += size
    return splits

In [None]:
# Example: split ten integers 80/10/10 (no seed is passed here).
random(list(range(10)), [0.8, 0.1, 0.1])

[[6, 9, 3, 5, 7, 1, 4, 2], [8], [0]]

In [None]:
#|export
def from_fn(items, fn):
    """Group `items` into subsets according to the index returned by `fn`.

    `fn` should return the index for each subset. Indexes must be sortable
    and hashable.

    Returns a tuple of lists ordered by ascending index, with items kept in
    their input order inside each list. Only indexes that `fn` actually
    returned get a subset, so a skipped index does not leave an empty
    placeholder. Empty `items` yields an empty tuple.
    """
    splits = defaultdict(list)
    for item in items: splits[fn(item)].append(item)
    # Building the tuple from the sorted keys also covers the empty case,
    # where unzipping the sorted items would leave nothing to index into.
    return tuple(splits[idx] for idx in sorted(splits))

In [None]:
# Example: items greater than 4 map to index 0, the rest to index 1.
from_fn(list(range(10)), lambda x: 0 if x>4 else 1)

([5, 6, 7, 8, 9], [0, 1, 2, 3, 4])

In [None]:
#|hide
# Run nbdev's export step (presumably writes the `#|export` cells to the module named by `#|default_exp`).
from nbdev import nbdev_export
nbdev_export()