In [82]:
import pandas as pd
import numpy as np
from schema import Schema


def arr2_in_arr1(arr1, arr2):
    """ Find the positions in arr1 of the elements of arr2 (arr2 must be a subset of arr1).

    Args:
        arr1 (np.array): 1-D array (or array-like) to search in. It does not need to be sorted.
        arr2 (np.array): values to locate (array or array-like); every value must occur in arr1.

    Returns:
        np.array: integer positions with the same shape as arr2 such that
        ``arr1[result] == arr2`` element-wise.

    Raises:
        ValueError: if arr2 contains a value that is not present in arr1.
            NaN never compares equal to itself and is therefore reported as missing.
    """
    arr1 = np.asarray(arr1)
    arr2 = np.asarray(arr2)

    if arr1.size == 0:
        # Nothing can be found in an empty array; only an empty query is valid.
        if arr2.size:
            raise ValueError("arr2 contains values that are not present in arr1")
        return np.empty(arr2.shape, dtype=np.intp)

    sort_idx = arr1.argsort()
    pos = np.searchsorted(arr1, arr2, sorter=sort_idx)
    # searchsorted returns len(arr1) for values larger than the maximum of arr1;
    # clamp so the lookup below cannot run out of bounds (the check that follows
    # still rejects those values).
    ret = sort_idx[np.minimum(pos, arr1.size - 1)]

    # searchsorted yields an insertion point even for absent values, so confirm
    # that every position really holds the requested value.
    if not np.array_equal(arr1[ret], arr2):
        raise ValueError("arr2 contains values that are not present in arr1")
    return ret

In [83]:

class Fold(dict):
    """A single fold: a mapping from dataset name to an array of sample indices.

    Layout:
        fold = { <dataset name>: array([indices]), <dataset name>: array([indices]), ... }

    Attributes set by the constructor:
        dataset_names (list of str): dataset names in insertion order.
        dtype (type): NumPy scalar type shared by all index arrays
            (``np.str_`` or a subclass of ``np.integer``).

    ``dataset_names`` and ``dtype`` are captured at construction time and ``index``
    is cached after its first use; none of them is refreshed if the mapping is
    mutated afterwards.
    """

    def __init__(self, *args, **kw):
        """Build the fold from the same arguments ``dict`` accepts and validate it.

        Raises:
            AssertionError: if no dataset is named 'train', a dataset name is not a
                string, a value is not a NumPy array, the indices are neither
                strings nor integers, or the arrays do not share one dtype.
        """
        super().__init__(*args, **kw)

        # dicts preserve insertion order, so this list is ordered as well
        self.dataset_names = list(self.keys())
        self._index = None

        # Structural checks run first so that malformed input fails with the
        # intended message instead of an AttributeError/IndexError on `.dtype`.
        assert 'train' in self.dataset_names, "At least one dataset must be named 'train'"
        for k, v in self.items():
            assert isinstance(k, str), "Dataset names must be string"
            assert isinstance(v, np.ndarray), "Dataset indices must be arrays"

        self.dtype = self[self.dataset_names[0]].dtype.type
        # Accept every integer width (int32, intp, ...), not only int64: the default
        # integer dtype and the dtype of positional indices depend on the platform.
        assert issubclass(self.dtype, (np.str_, np.integer)), "Dataset indices must be int or str"
        for v in self.values():
            assert v.dtype.type == self.dtype, "Dataset indices must be of same dtype"

    def __str__(self):
        """One line per dataset with its name and number of points."""
        return "".join(
            f"Dataset: {d:<10} Num points: {len(self[d])}\n" for d in self.dataset_names
        )

    @property
    def index(self):
        """Sorted array of the unique indices found in any dataset of this fold.

        Cached property - calculated only when called first time as it can be
        expensive for large folds.
        """
        if self._index is None:
            # Concatenate once instead of growing an array with np.append in a loop,
            # which re-copies the accumulated data on every iteration.
            parts = [np.ravel(self[d]) for d in self.dataset_names]
            self._index = np.unique(np.concatenate(parts))
        return self._index


In [138]:
class Split(list):
    """A list of Fold objects that share dataset names and index dtype.

    Layout:
        split = [ Fold(), Fold(), ... ]

    Attributes set by the constructor:
        dataset_names (list of str): dataset names, taken from the first fold.
        dtype (type): NumPy scalar type of the indices, taken from the first fold.

    ``index`` is cached after its first use and is not refreshed if the list or
    its folds are mutated afterwards.
    """

    def __init__(self, *args, **kw):
        """Build the split from the same arguments ``list`` accepts and validate it.

        Raises:
            AssertionError: if the split is empty, an element is not a Fold, or the
                folds differ in dataset names (order included) or dtype.
        """
        super().__init__(*args, **kw)
        self._index = None

        # Validate the container before reading attributes of its first element,
        # so bad input fails with an assertion rather than IndexError/AttributeError.
        assert len(self) > 0, "Split must contain at least one fold"
        for v in self:
            assert isinstance(v, Fold), "All elements of a Split must be Fold instances"

        self.dataset_names = self[0].dataset_names
        self.dtype = self[0].dtype

        for v in self:
            assert list(v.keys()) == self.dataset_names, "All folds must have same dataset names"
            assert v.dtype == self.dtype, "All folds must be of same dtype"

    @property
    def index(self):
        """Sorted array of the unique indices found in any dataset of any fold.

        Cached property - calculated only when called first time as it can be
        expensive for large folds.
        """
        if self._index is None:
            # Gather all arrays and concatenate once; np.append inside the loops
            # would re-copy the accumulated data on every iteration.
            parts = [
                np.ravel(np.asarray(fold[dataset], dtype=self.dtype))
                for fold in self
                for dataset in self.dataset_names
            ]
            self._index = np.unique(np.concatenate(parts))
        return self._index

    def __str__(self):
        """One line per (fold, dataset) pair followed by a summary line."""
        lines = [
            f"Fold: {name:<10} Dataset: {d:<10} Num points: {len(fold[d])}\n"
            for name, fold in enumerate(self)
            for d in self.dataset_names
        ]
        summary = f"\n{len(self)} folds, {len(self.dataset_names)} datasets, {len(self.index)} points"
        return "".join(lines) + summary

    def reset_index(self):
        """Return a new Split whose indices are positions into ``self.index``.

        Every original index value is replaced by its position in the sorted,
        de-duplicated ``self.index``. This split is left unchanged.

        Returns:
            Split: folds with the same dataset names and int64 positional indices.
        """
        index = self.index
        new_split = [
            Fold({
                # Positions come back as the platform's index integer type; cast to
                # int64 so the resulting folds have the same dtype on every platform.
                dataset_name: arr2_in_arr1(arr1=index, arr2=indices).astype(np.int64, copy=False)
                for dataset_name, indices in fold.items()
            })
            for fold in self
        ]
        return Split(new_split)

    def iter(self):
        """Iterate over the folds.

        Yields:
            tuple: ``(<fold no>, <watchlist>)`` where
            ``<watchlist> = [(<dataset name 1>, <dataset idx 1>), (<dataset name 2>, <dataset idx 2>), ...]``
        """
        for ix, fold in enumerate(self):
            yield ix, [(d, fold[d]) for d in self.dataset_names]


In [141]:
# Smoke test: two string-indexed folds that overlap on the label '2'.
s = Split([Fold({'train': np.array(['1','2'])}),
           Fold({'train': np.array(['2','3'])})])
# Union of the labels of both folds (np.unique makes it sorted and de-duplicated).
# NOTE(review): the recorded output shows the bare expressions in this cell being
# displayed although they are not the cell's last statement - presumably the kernel
# runs with ast_node_interactivity="all"; confirm, or wrap them in display().
s.index
print(s)
# reset_index() returns a new Split and its result is not assigned here, so the
# loop below still iterates over the original string labels held by `s`.
s.reset_index()
for k, v in s.iter():
    k, v

array(['1', '2', '3'], dtype='<U1')

Fold: 0          Dataset: train      Num points: 2
Fold: 1          Dataset: train      Num points: 2

2 folds, 1 datasets, 3 points


[{'train': array([0, 1])}, {'train': array([1, 2])}]

(0, [('train', array(['1', '2'], dtype='<U1'))])

(1, [('train', array(['2', '3'], dtype='<U1'))])

In [40]:
# NOTE(review): `x` is not defined in any cell visible in this notebook, and this
# cell's execution count (40) is lower than that of the cells above - looks like a
# leftover scratch cell. Confirm it runs after Restart & Run All, or delete it.
x[1].dtype.type

numpy.str_