In [None]:
# default_exp data.mixed

# Mixed data

> DataLoader that can take data from multiple dataloaders with different types of data

In [None]:
#export
from tsai.imports import *

In [None]:
# export
# This implementation of a mixed dataloader is based on a great implementation created by Zach Mueller in this fastai thread:
# https://forums.fast.ai/t/combining-tabular-images-in-fastai2-and-should-work-with-almost-any-other-type/73197

from packaging import version
from fastai.data.load import _FakeLoader
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter, _DatasetKind
_loaders = (_MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter)


class MixedDataLoader():
    """Iterate several `DataLoader`s in lockstep and merge their batches.

    Each yielded item is `(inps, outs)`: `inps` groups the inputs of every child loader
    (see `_split_idxs`) and `outs` holds the de-duplicated targets (see `_get_idxs`).
    """

    def __init__(self, *loaders, path='.', shuffle=False, device=None, bs=None):
        "Accepts any number of `DataLoader` and a device"
        self.path = path
        device = ifnone(device, default_device())
        self.device = device
        self.c = None
        self.d = None
        # All children are forced to a common batch size: `bs` if given, else the smallest one.
        self.bs = ifnone(bs, min([dl.bs for dl in loaders]))
        for i, dl in enumerate(loaders):
            # Mirror optional metadata from the children; if several define it, the last one wins.
            for attr in ('vars', 'len', 'split_idxs'):
                if hasattr(dl, attr):
                    setattr(self, attr, getattr(dl, attr))
            dl.bs = self.bs
            # Every child shares this object's `shuffle_fn` so they all draw the same permutation.
            dl.shuffle_fn = self.shuffle_fn
            # `c` and `d` come from the first child that defines them.
            if self.c is None and hasattr(dl, "c"):
                self.c = dl.c
            if self.d is None and hasattr(dl, "d"):
                self.d = dl.d
            if i == 0:
                self.dataset = dl.dataset
            dl.to(device=device)
        self.shuffle = shuffle
        if not self.shuffle:
            self.rng = np.arange(len(self.dataset)).tolist()
        self.loaders = loaders
        # Number of children that have already consumed the current permutation (see `shuffle_fn`).
        self.count = 0
        # `_FakeLoader` takes one more positional argument from fastai 2.1 onwards.
        self.fake_l = _FakeLoader(self, False, 0, 0, 0) if version.parse(
            fastai.__version__) >= version.parse("2.1") else _FakeLoader(self, False, 0, 0)
        if sum([len(dl.dataset) for dl in loaders]) > 0:
            self._get_idxs()  # Do not apply on an empty dataset

    def new(self, *args, **kwargs):
        "Create a new `MixedDataLoader` from `dl.new(*args, **kwargs)` of every child loader"
        # NOTE(review): `shuffle` and `bs` of this instance are not forwarded to the new one.
        loaders = [dl.new(*args, **kwargs) for dl in self.loaders]
        return type(self)(*loaders, path=self.path, device=self.device)

    def __len__(self):
        "Number of batches, as reported by the first child loader"
        # `__len__` is called directly instead of through `len()`.
        return self.loaders[0].__len__()

    def _get_vals(self, x):
        "Return the positions of the first occurrence of each distinct array in `x`"
        # Work on a private list so the caller's collection is left untouched.
        flat = [o.cpu().numpy().flatten() for o in x]
        idxs, seen = [], []
        for idx, o in enumerate(flat):
            if not self._arrayisin(o, seen):
                idxs.append(idx)
                seen.append(o)
        return idxs

    def _get_idxs(self):
        "Get `x` and `y` indices for batches of data"
        self.n_inps = [dl.n_inp for dl in self.loaders]
        self.x_idxs = self._split_idxs(self.n_inps)

        # Identify duplicate targets by comparing the targets of one batch from every child.
        outs = L([])
        for dl, n_inp in zip(self.loaders, self.n_inps):
            b = next(iter(dl))
            outs += L(b[n_inp:])
        self.y_idxs = self._get_vals(outs)

    def __iter__(self):
        "Yield `(inps, outs)` built from one batch of every child loader"
        # Pick the single-process or multi-process torch iterator for each child's fake loader.
        z = zip(*[_loaders[i.fake_l.num_workers == 0](i.fake_l) for i in self.loaders])
        for b in z:
            inps = []
            outs = []
            if self.device is not None:
                b = to_device(b, self.device)
            for batch, dl in zip(b, self.loaders):
                # Expose the indices of the most recent batch when a child tracks them.
                if hasattr(dl, 'idxs'): self.idxs = dl.idxs
                if hasattr(dl, 'input_idxs'): self.input_idxs = dl.input_idxs
                batch = dl.after_batch(batch)
                inps += batch[:dl.n_inp]
                outs += batch[dl.n_inp:]
            # One entry per child: a tuple when the child has several inputs, the bare item otherwise.
            grouped = tuple(tuple(L(inps)[idx]) if isinstance(idx, list) else inps[idx]
                            for idx in self.x_idxs)
            # With a single child its inputs are returned unwrapped. This branch used to read
            # from `outs`, which handed the targets back as inputs.
            inps = grouped if len(self.x_idxs) > 1 else grouped[0]
            outs = tuple(L(outs)[self.y_idxs]) if len(self.y_idxs) > 1 else L(outs)[self.y_idxs][0]
            yield inps, outs

    def one_batch(self):
        "Grab one batch of data"
        with self.fake_l.no_multiproc():
            res = first(self)
        if hasattr(self, 'it'):
            delattr(self, 'it')
        return res

    def shuffle_fn(self, idxs):
        "Generate the same idxs for all dls in each batch when shuffled"
        if self.count == 0:
            # First child of a round: draw a new permutation and sort the indices within each
            # batch. The remaining children reuse this array unchanged.
            shuffled = np.random.permutation(idxs)
            for start in range(0, len(shuffled), self.bs):
                shuffled[start:start + self.bs] = np.sort(shuffled[start:start + self.bs])
            self.shuffled_idxs = shuffled
        self.count += 1
        if self.count == len(self.loaders):
            self.count = 0  # every child has been served; the next call starts a new round
        return self.shuffled_idxs

    def show_batch(self):
        "Show a batch of data"
        for dl in self.loaders:
            dl.show_batch()

    def to(self, device):
        "Set the device batches are moved to in `__iter__` and return `self` so calls can be chained"
        self.device = device
        return self

    def _arrayisin(self, arr, arr_list):
        "Checks if `arr` is in `arr_list`"
        return any(np.array_equal(arr, a) for a in arr_list)

    def _split_idxs(self, a):
        "Split `range(sum(a))` into consecutive groups of sizes `a`; groups of size <= 1 give a bare index"
        b = np.arange(sum(a)).tolist()
        start = 0
        b_ = []
        for end in np.array(a).cumsum().tolist():
            b_.append(b[start:end] if end - start > 1 else b[start])
            start = end
        return b_
    

class MixedDataLoaders(DataLoaders):
    # `DataLoaders` subclass that adds no behaviour of its own; `get_mixed_dls` uses it to hold
    # the train/valid pair of `MixedDataLoader`s.
    pass

In [None]:
# export

def get_mixed_dls(*dls, device=None, shuffle_train=None, shuffle_valid=None, **kwargs):
    """Combine several `DataLoaders` into one `MixedDataLoaders`.

    The `train` loaders of all `dls` are wrapped in one `MixedDataLoader` and the `valid`
    loaders in another. Any of `device`, `shuffle_train` and `shuffle_valid` left as `None`
    is taken from the first item of `dls`. `kwargs` (e.g. `bs`, `path`) are forwarded to
    both `MixedDataLoader`s.
    """
    _mixed_train_dls = []
    _mixed_valid_dls = []
    for dl in dls:
        _mixed_train_dls.append(dl.train)
        _mixed_valid_dls.append(dl.valid)
        # Only the first `dl` can fill in a missing value; later ones see it already set
        # (unless the value it provided was itself `None`).
        if shuffle_train is None: shuffle_train = dl.train.shuffle
        if shuffle_valid is None: shuffle_valid = dl.valid.shuffle
        if device is None: device = dl.train.device
    # Forward the resolved device. Without it `MixedDataLoader.__init__` moves the child loaders
    # to `default_device()`, while the wrapper below is set to `device`.
    mixed_train_dl = MixedDataLoader(*_mixed_train_dls, shuffle=shuffle_train, device=device, **kwargs)
    mixed_valid_dl = MixedDataLoader(*_mixed_valid_dls, shuffle=shuffle_valid, device=device, **kwargs)
    mixed_dls = MixedDataLoaders(mixed_train_dl, mixed_valid_dl, device=device)
    return mixed_dls

In [None]:
from tsai.data.tabular import *

# Demo data: two tabular `DataLoaders` built from the same dataframe, target and splits,
# but with different feature columns and different batch sizes (512 vs 128).
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
# df['salary'] = np.random.rand(len(df)) # uncomment to simulate a cont dependent variable
target = 'salary'
# The same `splits` object is passed to both loaders below.
splits = RandomSplitter()(range_of(df))

# First loader: three categorical and two continuous features.
cat_names = ['workclass', 'education', 'marital-status']
cont_names = ['age', 'fnlwgt']
dls1 = get_tabular_dls(df, cat_names=cat_names, cont_names=cont_names, y_names=target, splits=splits, bs=512)
dls1.show_batch()

# Second loader: no categorical features, one continuous feature.
# Note that `cat_names` / `cont_names` are rebound here.
cat_names = None #['occupation', 'relationship', 'race']
cont_names = ['education-num']
dls2 = get_tabular_dls(df, cat_names=cat_names, cont_names=cont_names, y_names=target, splits=splits, bs=128)
dls2.show_batch()

Unnamed: 0,workclass,education,marital-status,age,fnlwgt,salary
0,?,Some-college,Married-civ-spouse,62.999999,149697.998687,<50k
1,Private,5th-6th,Separated,36.0,177616.000313,<50k
2,Local-gov,Some-college,Separated,30.0,178383.000379,<50k
3,Self-emp-not-inc,Some-college,Married-civ-spouse,27.0,411950.01119,<50k
4,Private,Bachelors,Married-civ-spouse,37.0,192938.999932,>=50k
5,Private,Masters,Divorced,54.0,161691.000074,>=50k
6,Private,HS-grad,Married-civ-spouse,36.0,95336.001179,>=50k
7,State-gov,HS-grad,Married-civ-spouse,46.0,273770.997233,<50k
8,Self-emp-not-inc,HS-grad,Married-civ-spouse,68.000001,197015.000115,<50k
9,Self-emp-not-inc,Some-college,Married-civ-spouse,28.0,149323.999684,<50k


Unnamed: 0,education-num_na,education-num,salary
0,False,10.0,<50k
1,False,9.0,<50k
2,False,9.0,>=50k
3,False,10.0,<50k
4,False,13.0,>=50k
5,False,13.0,<50k
6,False,10.0,<50k
7,False,11.0,<50k
8,False,12.0,>=50k
9,False,10.0,<50k


In [None]:
# `bs=8` is forwarded to `MixedDataLoader`, which applies it to every child loader.
dls = get_mixed_dls(dls1, dls2, bs=8)
# Pull one batch from each split; the results are discarded.
first(dls.train)
first(dls.valid)
# Round-trip the mixed loaders through `torch.save` / `torch.load`.
# NOTE(review): assumes an `export/` directory already exists in the working directory.
torch.save(dls,'export/mixed_dls.pth')
del dls
dls = torch.load('export/mixed_dls.pth')
# `MixedDataLoader.show_batch` calls `show_batch` on every child loader, hence two tables.
dls.train.show_batch()

Unnamed: 0,workclass,education,marital-status,age,fnlwgt,salary
0,Self-emp-not-inc,HS-grad,Never-married,19.0,137577.998451,<50k
1,Private,HS-grad,Never-married,23.000001,199884.000276,<50k
2,Self-emp-not-inc,Prof-school,Married-civ-spouse,52.999999,33303.99988,>=50k
3,Private,10th,Never-married,28.0,204516.000215,<50k
4,Private,10th,Never-married,28.0,412148.999546,<50k
5,Self-emp-not-inc,Bachelors,Married-civ-spouse,57.999999,310013.997374,<50k
6,State-gov,HS-grad,Never-married,35.0,237873.000453,<50k
7,Private,Bachelors,Married-civ-spouse,37.0,178948.000389,>=50k


Unnamed: 0,education-num_na,education-num,salary
0,False,9.0,<50k
1,False,9.0,<50k
2,False,15.0,>=50k
3,False,6.0,<50k
4,False,6.0,<50k
5,False,13.0,<50k
6,False,9.0,<50k
7,False,13.0,>=50k


In [None]:
# Inputs come back grouped per child loader: `xb[0]` holds the tensors of the first loader
# and `xb[1]` those of the second.
xb, yb = first(dls.train)
xb

((tensor([[ 7, 12,  5],
          [ 5, 12,  5],
          [ 7, 15,  3],
          [ 5,  1,  5],
          [ 5,  1,  5],
          [ 7, 10,  3],
          [ 8, 12,  5],
          [ 5, 10,  3]]),
  tensor([[-1.4394, -0.4971],
          [-1.1456,  0.0957],
          [ 1.0581, -1.4893],
          [-0.7783,  0.1397],
          [-0.7783,  2.1153],
          [ 1.4254,  1.1435],
          [-0.2641,  0.4571],
          [-0.1172, -0.1035]])),
 (tensor([[1],
          [1],
          [1],
          [1],
          [1],
          [1],
          [1],
          [1]]),
  tensor([[-0.4232],
          [-0.4232],
          [ 1.9228],
          [-1.5961],
          [-1.5961],
          [ 1.1408],
          [-0.4232],
          [ 1.1408]])))

In [None]:
# Shapes of the four input tensors (two per child loader) in one mixed batch.
xs, ys = first(dls.train)
xs[0][0].shape, xs[0][1].shape, xs[1][0].shape, xs[1][1].shape

(torch.Size([8, 3]),
 torch.Size([8, 2]),
 torch.Size([8, 1]),
 torch.Size([8, 1]))

In [None]:
from tsai.data.validation import TimeSplitter
from tsai.data.core import TSRegression, get_ts_dls
# Synthetic time-series fixture: sample i is a (2, 5) array filled with the value i (i = 0..7);
# the 8 samples are then duplicated, giving 16 samples in total.
X = np.repeat(np.repeat(np.arange(8)[:, None, None], 2, 1), 5, 2).astype(float)
X = np.concatenate([X, X])
# Targets are 0..7 repeated twice, so each sample's target equals the value stored in it.
y = np.concatenate([np.arange(len(X)//2)]*2)
# `alphabet` is only needed by the commented-out line below.
alphabet = np.array(list(string.ascii_lowercase))
# y = alphabet[y]
# `splits` is rebound here and reused by the tabular fixture that follows.
splits = TimeSplitter(.5, show_plot=False)(range_of(X))
tfms = [None, TSRegression()]
dls1 = get_ts_dls(X, y, splits=splits, tfms=tfms)
dls1.one_batch()

(TSTensor(samples:8, vars:2, len:5, device=cpu),
 tensor([0., 1., 2., 3., 4., 5., 6., 7.]))

In [None]:
# Synthetic tabular fixture aligned with the time-series one: row i holds (i, 10*i, 100*i)
# for i = 0..7, duplicated to 16 rows, so each row can be matched to its time-series sample.
data = np.concatenate([np.repeat(np.arange(8)[:, None], 3, 1)*np.array([1, 10, 100])]*2)
df = pd.DataFrame(data, columns=['cat1', 'cat2', 'cont'])
df['cont'] = df['cont'].astype(float)
# Same targets and same `splits` as the time-series loader.
df['target'] = y
cat_names = ['cat1', 'cat2']
cont_names = ['cont']
target = 'target'
# `Normalize` is left commented out of `procs`.
dls2 = get_tabular_dls(df, procs=[Categorify, FillMissing, #Normalize
                                 ], cat_names=cat_names, cont_names=cont_names, y_names=target, splits=splits, bs=8)
dls2.one_batch()

(tensor([[7, 7],
         [1, 1],
         [8, 8],
         [6, 6],
         [4, 4],
         [2, 2],
         [5, 5],
         [3, 3]]),
 tensor([[600.],
         [  0.],
         [700.],
         [500.],
         [300.],
         [100.],
         [400.],
         [200.]]),
 tensor([[6],
         [0],
         [7],
         [5],
         [3],
         [1],
         [4],
         [2]], dtype=torch.int8))

In [None]:
# Peek at the raw iterator `MixedDataLoader.__iter__` builds for a child loader: index
# `_loaders` with the `num_workers == 0` flag, then call the result on the child's `fake_l`.
z = zip(_loaders[dls1.train.fake_l.num_workers == 0](dls1.train.fake_l))
# Print only the first item, then stop.
for b in z: 
    print(b)
    break

((TSTensor(samples:8, vars:2, len:5, device=cpu), tensor([0., 1., 2., 3., 4., 5., 6., 7.])),)


In [None]:
# Check that a mixed batch keeps the time-series and tabular samples aligned.
bs = 8
dls = get_mixed_dls(dls1, dls2, bs=bs)
dl = dls.train
xb, yb = dl.one_batch()
test_eq(len(xb), 2)        # one entry per child loader
test_eq(len(xb[0]), bs)    # time-series loader: a single input tensor
test_eq(len(xb[1]), 2)     # tabular loader: two input tensors
test_eq(len(xb[1][0]), bs)
test_eq(len(xb[1][1]), bs)
# The `- 1` offset matches the `dls2.one_batch()` output, where the codes are the raw values plus one.
test_eq(xb[0].data[:, 0, 0].long(), xb[1][0][:, 0] - 1) # categorical data and ts are in sync
test_eq(xb[0].data[:, 0, 0], (xb[1][1]/100).flatten()) # continuous data and ts are in sync
# NOTE(review): presumably passes because train targets coincide with the sample indices — confirm against `TimeSplitter`.
test_eq(tensor(dl.input_idxs), yb.long().cpu())
dl = dls.valid
xb, yb = dl.one_batch()
# On the validation split the targets are looked up in `y` through the batch indices.
test_eq(tensor(y[dl.input_idxs]), yb.long().cpu())

In [None]:
#hide
from tsai.imports import create_scripts
from tsai.export import get_nb_name
# Pass this notebook's name, as returned by `get_nb_name`, to `create_scripts`.
nb_name = get_nb_name()
create_scripts(nb_name);