loader/core.py

In [1]:
import math
import random
import warnings
from contextlib import contextmanager
from copy import deepcopy

import numpy as np
import pandas as pd
import torch.utils.data

In [2]:
# ignore
import sys; sys.path.append("..")

In [3]:
# replace(torchtable, ..custom_types)
from torchtable import *

In [4]:
# replace(torchtable, .)
from torchtable.utils import *

In [5]:
class RandomShuffler(object):
    """
    Use random functions while keeping track of the random state to make it
    reproducible and deterministic. Borrowed from torchtext.

    The shuffler owns a private copy of the global ``random`` module state.
    Each shuffle temporarily installs that copy as the global state, draws
    from it, saves the advanced state back, and then restores whatever global
    state was active before. The caller's global RNG stream is therefore
    neither consumed nor perturbed.

    Args:
        random_state: A state object as returned by ``random.getstate()``.
            Defaults to a snapshot of the current global state.
    """

    def __init__(self, random_state=None):
        self._random_state = random_state
        if self._random_state is None:
            self._random_state = random.getstate()

    @contextmanager
    def use_internal_state(self):
        """Temporarily install the internal RNG state as the global one.

        On exit, including exit by exception, the advanced internal state is
        saved and the caller's previous global state is restored.
        """
        old_state = random.getstate()
        random.setstate(self._random_state)
        try:
            yield
        finally:
            # Remember how far the private stream advanced, then hand the
            # global generator back exactly as we found it. Without `finally`
            # an exception in the with-body would leak our state globally.
            self._random_state = random.getstate()
            random.setstate(old_state)

    @property
    def random_state(self):
        """A deep copy of the internal state, safe to store and restore later."""
        return deepcopy(self._random_state)

    @random_state.setter
    def random_state(self, s):
        self._random_state = s

    def __call__(self, data):
        """Shuffle and return a new list.

        *data* must be a sequence (e.g. a list or ``range``); it is not
        modified. Consumes randomness from the internal state only.
        """
        with self.use_internal_state():
            return random.sample(data, len(data))

In [6]:
# Type of one loader batch: an ``(inputs, targets)`` pair of dicts keyed by column
# name. Each value is a tensor, or (presumably via ``OneorMore``) a list of tensors
# for columns backed by several fields. ``torch.Tensor`` is the tensor *type*;
# ``torch.tensor`` is a factory function and is not a valid annotation.
ProcessedBatch = Tuple[Dict[ColumnName, OneorMore[torch.Tensor]], Dict[ColumnName, OneorMore[torch.Tensor]]]

In [7]:
class DefaultLoader(torch.utils.data.DataLoader):
    """
    Defines an iterator that loads batches of data from a Dataset.
    Heavily based on the Iterator from torchtext.

    Iterating yields ``(inputs, targets)`` pairs of dicts keyed by column name
    (see ``_process_batch``). ``DataLoader.__init__`` is never called. This
    class implements batching, shuffling, ``__len__`` and checkpointing
    (``state_dict`` / ``load_state_dict``) itself.

    Args:
        dataset: The Dataset object to load examples from. It must expose
            ``train`` and ``fields`` attributes and accept a list of integer
            positions as an index.
        batch_size: Batch size; must be at least 1.
        device (str or `torch.device`): A string or instance of `torch.device`
            specifying which device the Variables are going to be created on.
            If None, the tensors will be created on cpu. An int is accepted
            but deprecated: it triggers a warning and is replaced by None.
        repeat: Whether to repeat the iterator for multiple epochs.
        shuffle: Whether to shuffle examples between epochs. When None, the
            value of ``dataset.train`` is used (via ``with_default``).

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """

    def __init__(self, dataset: torch.utils.data.Dataset, batch_size: int,
                 device: Optional[torch.device]=None, repeat: bool=False,
                 shuffle: Optional[bool]=None):
        if batch_size < 1:
            # A zero batch size would make __len__ divide by zero and make
            # _batches emit the entire dataset as one final batch.
            raise ValueError("batch_size must be a positive integer, got {!r}".format(batch_size))
        self.batch_size, self.dataset = batch_size, dataset
        self.iterations = 0  # batches yielded; reset each epoch unless `repeat`
        self.repeat = repeat
        self.shuffle = with_default(shuffle, self.dataset.train)

        if isinstance(device, int):
            warnings.warn("The `device` argument should be set by using `torch.device`" +
                           " or passing a string as an argument. This behavior will be" +
                           " deprecated soon and currently defaults to cpu.")
            device = None
        self.device = device
        if self.shuffle:
            # TODO: Clean interface
            self.index_generator = RandomShuffler()
        else:
            # Identity: positions are visited in dataset order.
            self.index_generator = lambda x: x

        # For state loading/saving only
        self._iterations_this_epoch = 0  # batches already yielded in the current epoch
        self._random_state_this_epoch = None  # shuffler state captured at epoch start
        self._restored_from_state = False  # set by load_state_dict, consumed by init_epoch

    @classmethod
    def from_dataset(cls, dataset: torch.utils.data.Dataset, batch_size: int,
                 device: torch.device=None, repeat: bool=False, shuffle: Optional[bool]=None):
        """Create a single loader; equivalent to calling the constructor."""
        return cls(dataset, batch_size, device=device, repeat=repeat, shuffle=shuffle)

    @classmethod
    def from_datasets(cls, train_ds: torch.utils.data.Dataset, batch_size: OneorMore[int],
                      val_ds: Optional[torch.utils.data.Dataset]=None, test_ds: Optional[torch.utils.data.Dataset]=None,
                      device: OneorMore[torch.device]=None, repeat: OneorMore[bool]=False,
                      shuffle: Optional[OneorMore[Optional[bool]]]=None) -> Iterable['DefaultLoader']:
        """
        Lazily create one loader per supplied dataset, in train, val, test order.

        ``val_ds`` and ``test_ds`` are skipped when None, so unpack the result
        into as many names as datasets were given. ``batch_size``, ``device``,
        ``repeat`` and ``shuffle`` may each be a single value or one value per
        supplied dataset. They are passed through ``expand``, which presumably
        broadcasts a scalar to that many copies (confirm in torchtable.utils).
        """
        datasets = [train_ds] + [ds for ds in (val_ds, test_ds) if ds is not None]
        n_ds = len(datasets)
        batch_sizes = expand(batch_size, n_ds)
        devices = expand(device, n_ds)
        repeats = expand(repeat, n_ds)
        shuffles = expand(shuffle, n_ds)
        for i, ds in enumerate(datasets):
            yield cls.from_dataset(ds, batch_sizes[i], device=devices[i],
                                   repeat=repeats[i], shuffle=shuffles[i])

    def _process_batch(self, data: Dict[ColumnName, OneorMore[ArrayLike]]) -> ProcessedBatch:
        """
        Converts examples in a dataset to model inputs by using the fields to transform
        the inputs to tensors. Override in subclass to add custom behavior.

        A column goes to the targets dict when its field's ``is_target`` is set,
        and to the inputs dict otherwise. A column backed by a tuple/list of
        fields contributes a list holding one transformed value per field, in
        field order. Routing is decided per field, so such a column can appear
        in both dicts.
        """
        in_data = {}
        tgt_data = {}
        for k, batch in data.items():
            fld = self.dataset.fields[k]
            if isinstance(fld, (tuple, list)):
                # Multi-field column: `batch` is presumably a sequence of
                # per-field values parallel to `fld` (TODO confirm with dataset).
                for f, v in zip(fld, batch):
                    data_dict = tgt_data if f.is_target else in_data
                    data_dict.setdefault(k, []).append(
                        f.transform_batch(v, device=self.device, train=self.dataset.train))
            else:
                tsr = fld.transform_batch(batch, device=self.device, train=self.dataset.train)
                if fld.is_target:
                    tgt_data[k] = tsr
                else:
                    in_data[k] = tsr
        return in_data, tgt_data

    def _batches(self) -> Iterable[ProcessedBatch]:
        """
        Iterates through the dataset while generating batches of input and target variables.
        Assumes dataset can be indexed using a list.

        Positions ``0 .. len(dataset) - 1`` are passed through ``index_generator``
        (a shuffle or the identity) and grouped into chunks of ``batch_size``.
        The final chunk may be smaller.
        """
        indices = []
        for i in self.index_generator(range(len(self.dataset))):
            indices.append(i)
            if len(indices) == self.batch_size:
                yield self._process_batch(self.dataset[indices])
                indices = []
        if indices:
            yield self._process_batch(self.dataset[indices])

    def init_epoch(self):
        """Set up the batch generator for a new epoch.

        When shuffling, this either rewinds the shuffler to the state saved by
        ``load_state_dict``, or records the shuffler's current state so the
        epoch can be replayed later. It also clears the per-epoch counter
        (unless just restored) and, when not repeating, resets ``iterations``.
        """
        if self.shuffle:
            if self._restored_from_state:
                self.index_generator.random_state = self._random_state_this_epoch
            else:
                self._random_state_this_epoch = self.index_generator.random_state

        if self._restored_from_state:
            # Keep the restored counter so __iter__ can fast-forward past it.
            self._restored_from_state = False
        else:
            self._iterations_this_epoch = 0

        if not self.repeat: self.iterations = 0

    @property
    def epoch(self):
        """Number of complete epochs covered by ``iterations`` (0 if there are no batches)."""
        n_batches = len(self)
        if n_batches == 0:
            return 0
        return math.floor(self.iterations / n_batches)

    def __len__(self):
        """Number of batches per epoch; the last batch may be partial."""
        return math.ceil(len(self.dataset) / self.batch_size)

    def __iter__(self) -> Iterable[ProcessedBatch]:
        """Yield ``(inputs, targets)`` batches, forever if ``repeat`` is set."""
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self._batches()):
                # fast-forward if loaded from state: skip batches already
                # yielded before the checkpoint was taken.
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                yield minibatch
            # An empty dataset yields nothing per epoch; repeating it would
            # spin forever without ever yielding.
            if not self.repeat or len(self) == 0:
                break

    def state_dict(self) -> Dict[str, Any]:
        """Snapshot the counters and epoch-start shuffler state for checkpointing."""
        return {
            "iterations": self.iterations,
            "iterations_this_epoch": self._iterations_this_epoch,
            "random_state_this_epoch": self._random_state_this_epoch,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore a ``state_dict`` snapshot; the next iteration resumes mid-epoch."""
        self.iterations = state_dict["iterations"]
        self._iterations_this_epoch = state_dict["iterations_this_epoch"]
        self._random_state_this_epoch = state_dict["random_state_this_epoch"]
        self._restored_from_state = True

# Tests

test_loader.py

In [8]:
import pytest
import itertools

In [9]:
# uncomment
# from torchtable import *
# from torchtable.field import *
# from torchtable.dataset import *
# from torchtable.loader import *

In [10]:
# ignore
from torchtable.field import *
from torchtable.dataset import *

In [11]:
def flatten(x):
    """Flatten *x* by exactly one level.

    Tuple and list items are expanded in place. Every other item, including
    strings and any lists nested inside those tuples/lists, is yielded as-is.
    """
    for item in x:
        if not isinstance(item, (tuple, list)):
            yield item
            continue
        for sub in item:
            yield sub

In [12]:
# test_from_dataset
df = pd.DataFrame({"a": [1, 2, 3, 4, 5],
                   "b": [-0.4, -2.1, 3.3, 4.4, 5.5]})
ds = TabularDataset.from_df(df, fields={
    "a": CategoricalField(max_features=100),
    "b": [NumericField(normalization="Gaussian"), IdentityField()],
}, train=True)
# Build a loader through `from_dataset` and check every batch it yields:
# 5 rows in batches of 2 -> sizes 2, 2, 1.
dl = DefaultLoader.from_dataset(ds, 2)
assert len(dl) == 3
batches = list(dl)
assert len(batches) == 3
for (x, y), expected in zip(batches, (2, 2, 1)):
    # len() rather than .size(): the IdentityField output may not be a tensor.
    for v in flatten(itertools.chain(x.values(), y.values())):
        assert len(v) == expected

In [13]:
# test_from_datasets
df1 = pd.DataFrame({"a": [1, 2, 3, 4, 5],
                   "b": [-0.4, -2.1, 3.3, 4.4, 5.5]})
df2 = pd.DataFrame({"a": [1, 2, 3], "b": [-1., -2, -3.]})
df3 = pd.DataFrame({"a": [3, 2], "b": [-1., -2]})
train, val, test = TabularDataset.from_dfs(df1, val_df=df2, test_df=df3, fields={
    "a": CategoricalField(),
    "b": [NumericField(normalization="Gaussian"), CategoricalField(handle_unk=True)],
})
# all present: loaders come back in train, val, test order
train_dl, val_dl, test_dl = DefaultLoader.from_datasets(train, 3, val_ds=val, test_ds=test)
assert train_dl.dataset is train and val_dl.dataset is val and test_dl.dataset is test
# 5, 3 and 2 rows in batches of 3
assert (len(train_dl), len(val_dl), len(test_dl)) == (2, 1, 1)
# val only
train_dl, val_dl = DefaultLoader.from_datasets(train, 3, val_ds=val, test_ds=None)
assert train_dl.dataset is train and val_dl.dataset is val
# test only: the second loader must wrap the test set, not a missing val set
train_dl, test_dl = DefaultLoader.from_datasets(train, 3, val_ds=None, test_ds=test)
assert train_dl.dataset is train and test_dl.dataset is test

In [14]:
# test_from_datasets_multiple_args
df1 = pd.DataFrame({"a": [3, 4, 5, 1, 2],
                   "b": [1.3, -2.1, 2.3, 5.4, 5.6]})
df2 = pd.DataFrame({"a": [1, 2, 3, 4, 5], "b": [-1., -2, -3., -4., -5.]})
df3 = pd.DataFrame({"a": [3, 2], "b": [-1., -2]})
train, val, test = TabularDataset.from_dfs(df1, val_df=df2, test_df=df3, fields={
    "a": CategoricalField(),
    "b": [NumericField(normalization="Gaussian"), CategoricalField(handle_unk=True)],
})
# Per-loader settings: one tuple entry per loader, in train/val/test order.
train_dl, val_dl, test_dl = DefaultLoader.from_datasets(
    train, (5, 3, 2), val_ds=val, test_ds=test,
    device=(None, None, None), repeat=(True, True, True), shuffle=(True, True, True),
)
for loader, expected in ((train_dl, 5), (val_dl, 3), (test_dl, 2)):
    # Every tensor in the first batch must carry that loader's batch size.
    x, y = next(iter(loader))
    for v in flatten(itertools.chain(x.values(), y.values())):
        assert v.size()[0] == expected

# Train + val only, with a distinct batch size for each.
train_dl, val_dl = DefaultLoader.from_datasets(train, (3, 4), val_ds=val, test_ds=None)
for loader, expected in ((train_dl, 3), (val_dl, 4)):
    x, y = next(iter(loader))
    for v in flatten(itertools.chain(x.values(), y.values())):
        assert v.size()[0] == expected

In [15]:
# test_real_data
"""Smoke test for real dataset"""
df = pd.read_csv("./tests/resources/sample.csv")
# Column -> field mapping; columns mapped to None are presumably excluded
# from the dataset (confirm in TabularDataset.from_df).
real_data_fields = {
    "category_1": None,
    "category_3": None,
    "merchant_id": None,
    "subsector_id": CategoricalField(min_freq=3),
    "merchant_category_id": CategoricalField(min_freq=3),
    "city_id": None,
    "month_lag": NumericField(normalization="RankGaussian"),
    "card_id": None,
    "installments": NumericField(normalization=None),
    "state_id": CategoricalField(),
    "category_2": NumericField(normalization=None),
    "authorized_flag": CategoricalField(min_freq=3, handle_unk=True),
    "purchase_date": datetime_fields(),
    "purchase_amount": NumericField(normalization=None, fill_missing=None, is_target=True),
}
ds = TabularDataset.from_df(df, fields=real_data_fields, train=True)

bs = 32
loader = DefaultLoader.from_dataset(ds, bs)
x, y = next(iter(loader))
# Every input and target tensor in the first batch must be full-sized.
for tensor in flatten(itertools.chain(x.values(), y.values())):
    assert tensor.size()[0] == bs