Skip to content

Commit

Permalink
DataLoaders again
Browse files Browse the repository at this point in the history
  • Loading branch information
daizutabi committed May 31, 2020
1 parent 739b19e commit 34d22fc
Show file tree
Hide file tree
Showing 26 changed files with 370 additions and 182 deletions.
25 changes: 25 additions & 0 deletions examples/nnabla.yml
@@ -0,0 +1,25 @@
library: nnabla
datasets:
  data:
    class: rectangle.data.Data
    n_splits: 4
  dataset:
    fold: 0
model:
  class: rectangle.nnabla.Model
  hidden_sizes: [20, 30]
optimizer:
  class: nnabla.solvers.Sgd
  lr: 1e-3
results:
metrics:
monitor:
  metric: val_loss
early_stopping:
  patience: 10
trainer:
  loss: mse
  batch_size: 10
  epochs: 10
  shuffle: true
  verbose: 2
19 changes: 19 additions & 0 deletions examples/rectangle/nnabla.py
@@ -0,0 +1,19 @@
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

import ivory.nnabla.model


class Model(ivory.nnabla.model.Model):
    """Feed-forward regression model built from nnabla parametric functions.

    Args:
        hidden_sizes: Output width of each hidden affine layer, in order.
    """

    def __init__(self, hidden_sizes):
        super().__init__()
        self.hidden_sizes = hidden_sizes

    def forward(self, x):
        # Hidden stack: affine + ReLU, each layer in its own parameter scope
        # so weights are created (and later reused) per layer.
        for k, hidden_size in enumerate(self.hidden_sizes):
            with nn.parameter_scope(f"layer{k}"):
                x = F.relu(PF.affine(x, hidden_size))
        # Output head producing a single value. Using len(...) instead of the
        # leaked loop variable `k + 1` yields the same scope name for any
        # non-empty hidden_sizes, and no longer raises NameError when
        # hidden_sizes is empty.
        with nn.parameter_scope(f"layer{len(self.hidden_sizes)}"):
            x = PF.affine(x, 1)
        return x
1 change: 1 addition & 0 deletions examples/tf.yml
Expand Up @@ -25,4 +25,5 @@ trainer:
loss: mse
batch_size: 10
epochs: 10
shuffle: true
verbose: 2
3 changes: 2 additions & 1 deletion examples/torch.yml
Expand Up @@ -24,7 +24,8 @@ monitor:
early_stopping:
patience: 10
trainer:
loss: torch.nn.functional.mse_loss
loss: mse
batch_size: 10
epochs: 10
shuffle: true
verbose: 2
3 changes: 0 additions & 3 deletions ivory/callbacks/early_stopping.py
Expand Up @@ -21,9 +21,6 @@ class EarlyStopping(State):
patience: int
wait: int = 0

# def __post_init__(self):
# self.wait = 0

def on_epoch_end(self, run: Run):
if run.monitor.is_best:
self.wait = 0
Expand Down
24 changes: 23 additions & 1 deletion ivory/callbacks/metrics.py
@@ -1,5 +1,7 @@
"""Metrics to record scores while training."""
from typing import Any, Dict
from typing import Any, Dict, List

import numpy as np

import ivory.core.collections
from ivory.core import instance
Expand Down Expand Up @@ -55,6 +57,26 @@ def metrics_dict(self, run: Run) -> Dict[str, Any]:
return metrics_dict


class BatchMetrics(Metrics):
    """Metrics that accumulate per-batch loss values and store their mean.

    Records ``"loss"`` at the end of training and ``"val_loss"`` at the end
    of validation for the current epoch.
    """

    def on_epoch_begin(self, run: Run):
        # Mirror the trainer's epoch counter so records line up with it.
        self.epoch = run.trainer.epoch

    def on_train_begin(self, run: Run):
        self.losses: List[float] = []

    def on_val_begin(self, run: Run):
        self.losses = []

    def step(self, loss: float):
        """Record one batch loss."""
        self.losses.append(loss)

    def on_train_end(self, run: Run):
        self["loss"] = self._average()

    def on_val_end(self, run: Run):
        self["val_loss"] = self._average()

    def _average(self):
        # Mean of the losses collected since the last on_*_begin.
        return np.mean(self.losses)


METRICS = {"mse": "sklearn.metrics.mean_squared_error"}


Expand Down
24 changes: 24 additions & 0 deletions ivory/callbacks/results.py
Expand Up @@ -74,6 +74,30 @@ def mean(self):
return results


class BatchResults(Results):
    """Results that collect per-batch arrays and merge them on demand."""

    def reset(self):
        """Clear accumulated batches in addition to base-class state."""
        super().reset()
        self.indexes = []
        self.outputs = []
        self.targets = []

    def step(self, index, output, target=None):
        """Record one batch; ``target`` is optional (e.g. for test mode)."""
        self.indexes.append(index)
        self.outputs.append(output)
        if target is None:
            return
        self.targets.append(target)

    def result_dict(self):
        """Concatenate batches, feed them to the base class, and return its dict."""
        target = np.concatenate(self.targets) if self.targets else None
        super().step(np.concatenate(self.indexes), np.concatenate(self.outputs), target)
        return super().result_dict()


def stack(x: List[np.ndarray]) -> np.ndarray:
if x[0].ndim == 1:
return np.hstack(x)
Expand Down
27 changes: 26 additions & 1 deletion ivory/core/data.py
Expand Up @@ -26,7 +26,7 @@
"""


from dataclasses import dataclass
from dataclasses import InitVar, dataclass
from typing import Callable, Optional, Tuple

import numpy as np
Expand Down Expand Up @@ -228,6 +228,7 @@ class Datasets(ivory.core.collections.Dict):
test (Dataset): Test dataset.
fold: Fold number.
"""

data: Data
dataset: Callable
fold: int
Expand All @@ -236,3 +237,27 @@ def __post_init__(self):
super().__init__()
for mode in ["train", "val", "test"]:
self[mode] = self.dataset(self.data, mode, self.fold)


class DataLoaders(ivory.core.collections.Dict):
    """Collection of one dataloader per mode (train/val/test).

    Args:
        datasets: `Datasets` instance supplying one dataset per mode.
        batch_size: Batch size handed to every dataloader.
        shuffle: If True, the train dataloader shuffles; val and test never do.

    Attributes:
        train: Train dataloader.
        val: Validation dataloader.
        test: Test dataloader.
    """

    def __init__(self, datasets: Datasets, batch_size: int, shuffle: bool):
        super().__init__()
        for mode in ["train", "val", "test"]:
            # Shuffling only ever applies to the training split.
            mode_shuffle = shuffle if mode == "train" else False
            self[mode] = self.get_dataloader(datasets[mode], batch_size, mode_shuffle)

    def get_dataloader(self, dataset, batch_size, shuffle):
        """Build a framework-specific dataloader; subclasses must override."""
        raise NotImplementedError
8 changes: 7 additions & 1 deletion ivory/core/default.py
Expand Up @@ -45,6 +45,12 @@ def get(name: str) -> Dict[str, Any]:
"trainer": "ivory.tensorflow.trainer.Trainer",
}

# Default component classes for runs using the nnabla library
# (selected via `library: nnabla` in the YAML config).
DEFAULT_CLASS["nnabla"] = {
    "results": "ivory.callbacks.results.BatchResults",
    "metrics": "ivory.nnabla.metrics.Metrics",
    "trainer": "ivory.nnabla.trainer.Trainer",
}


DEFAULT_CLASS["sklearn"] = {
"estimator": "ivory.sklearn.estimator.Estimator",
Expand All @@ -53,7 +59,7 @@ def get(name: str) -> Dict[str, Any]:


def update_class(params: Dict[str, Any], library: str = "core"):
if 'library' in params:
if "library" in params:
library = params.pop("library")
for key, value in params.items():
if value is None:
Expand Down
9 changes: 3 additions & 6 deletions ivory/core/trainer.py
Expand Up @@ -63,9 +63,6 @@ def loop(self, run: Run):
if isinstance(pruned, TrialPruned):
raise pruned

def get_dataloader(self, run: Run, mode: str):
raise NotImplementedError

def tqdm(self, dataloader, mode):
if self.verbose == 1:
mode = "%-5s" % (mode[0].upper() + mode[1:])
Expand All @@ -74,22 +71,22 @@ def tqdm(self, dataloader, mode):

def train_loop(self, run: Run):
    """Run one training epoch: begin/end callbacks around per-batch steps."""
    run.on_train_begin()
    for index, input, target in self.tqdm(run.dataloaders.train, "train"):
        # Advance the global step counter once per training batch.
        self.global_step += 1
        self.train_step(run, index, input, target)
    run.on_train_end()

def val_loop(self, run: Run):
    """Run one validation epoch over the validation dataloader."""
    run.on_val_begin()
    for index, input, target in self.tqdm(run.dataloaders.val, "val"):
        self.val_step(run, index, input, target)
    run.on_val_end()

def test_loop(self, run: Run):
    """Run inference over the test dataloader.

    The target is star-unpacked because test batches may omit it.
    """
    run.on_test_begin()
    for index, input, *target in self.tqdm(run.dataloaders.test, "test"):
        self.test_step(run, index, input, *target)
    run.on_test_end()
Expand Down
2 changes: 0 additions & 2 deletions ivory/main.py
Expand Up @@ -8,8 +8,6 @@
import ivory
from ivory.core import parser

# from ivory.utils.range import Range

if "." not in sys.path:
sys.path.insert(0, ".")

Expand Down
Empty file added ivory/nnabla/__init__.py
Empty file.
40 changes: 40 additions & 0 deletions ivory/nnabla/data.py
@@ -0,0 +1,40 @@
from dataclasses import dataclass

from nnabla.utils.data_iterator import data_iterator_simple

import ivory.core.data
from ivory.core.data import Dataset


@dataclass
class DataLoader:
    """Minibatch iterator over a `Dataset`, backed by nnabla's simple iterator.

    Args:
        dataset: Indexable dataset; ``dataset[i]`` yields one example.
        batch_size: Number of examples per batch.
        shuffle: If True, the underlying iterator shuffles examples.
        with_memory_cache: Passed through to ``data_iterator_simple``.
        with_file_cache: Passed through to ``data_iterator_simple``.
    """

    dataset: Dataset
    batch_size: int
    shuffle: bool = False
    with_memory_cache: bool = False
    with_file_cache: bool = False

    def __post_init__(self):
        # The nnabla iterator is created once and reused across epochs.
        self.iterator = data_iterator_simple(
            self.load_func,
            len(self.dataset),
            self.batch_size,
            shuffle=self.shuffle,
            with_memory_cache=self.with_memory_cache,
            with_file_cache=self.with_file_cache,
        )

    def load_func(self, index):
        """Fetch a single example for the underlying iterator."""
        return self.dataset[index]

    def __len__(self):
        # Full batches per epoch; a trailing partial batch is dropped.
        return len(self.dataset) // self.batch_size

    def __iter__(self):
        # Draw exactly len(self) batches from the endless nnabla iterator.
        batches = len(self)
        for _ in range(batches):
            yield next(self.iterator)


class DataLoaders(ivory.core.data.DataLoaders):
    """nnabla-specific collection: builds one `DataLoader` per mode."""

    def get_dataloader(self, dataset, batch_size, shuffle):
        """Wrap a dataset in an nnabla-backed `DataLoader`."""
        return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
5 changes: 5 additions & 0 deletions ivory/nnabla/functions.py
@@ -0,0 +1,5 @@
import nnabla.functions as F


def mse(input, target):
    """Return the mean squared error between *input* and *target* as a graph node."""
    squared = F.squared_error(input, target)
    return F.mean(squared)
7 changes: 7 additions & 0 deletions ivory/nnabla/metrics.py
@@ -0,0 +1,7 @@
import ivory.callbacks.metrics
from ivory.core.run import Run


class Metrics(ivory.callbacks.metrics.BatchMetrics):
    """Batch-loss metrics for nnabla runs.

    Losses are recorded by the `BatchMetrics` callbacks; no additional
    metrics are computed here.
    """

    def metrics_dict(self, run: Run):
        # Nothing beyond the batch losses to report.
        return dict()
68 changes: 68 additions & 0 deletions ivory/nnabla/model.py
@@ -0,0 +1,68 @@
import nnabla as nn

import ivory.core.collections


class Model:
    """Base class for nnabla models built as static computation graphs.

    Subclasses override `forward`. `build` constructs separate train/test
    graphs that share parameters under a per-instance parameter scope.
    """

    # Class-level counter giving each instance a unique parameter scope name.
    NUM_MODELS = 0

    def __init__(self):
        # Training flag (torch-style); forward() may branch on it while
        # graphs are being built.
        self.training = True
        # Unique scope so multiple Model instances do not share parameters.
        self.scope = f"model{self.__class__.NUM_MODELS}"
        self.__class__.NUM_MODELS += 1

    def train(self, mode: bool = True):
        """Set training mode (True) or evaluation mode (False)."""
        self.training = mode

    def eval(self):
        """Switch to evaluation mode."""
        self.train(False)

    def build(self, loss, dataset, batch_size):
        """Construct the train and test graphs from one dataset sample.

        Args:
            loss: Callable mapping (output, target) variables to a loss variable.
            dataset: Indexable dataset yielding (index, input, target) tuples.
            batch_size: Fixed batch size baked into the graph's shapes.
        """
        # Use the first sample only to discover input/target shapes.
        index, input, target = dataset[0]
        input_shape = [batch_size] + list(input.shape)
        target_shape = [batch_size] + list(target.shape)
        self.input = ivory.core.collections.Dict()
        self.output = ivory.core.collections.Dict()
        self.target = ivory.core.collections.Dict()
        self.loss = ivory.core.collections.Dict()
        for mode in ["train", "test"]:
            self.input[mode] = nn.Variable(input_shape)
            self.target[mode] = nn.Variable(target_shape)
            with nn.parameter_scope(self.scope):
                # Temporarily flip the flag so forward() builds the graph
                # appropriate for this mode; restored after the loop.
                self.training = mode == "train"
                self.output[mode] = self.forward(self.input[mode])
            if mode == "train":
                # NOTE(review): presumably keeps the output buffer readable
                # after forward when buffer clearing is enabled — confirm.
                self.output[mode].persistent = True
            self.loss[mode] = loss(self.output[mode], self.target[mode])
        self.training = True

    def parameters(self):
        """Return this model's parameters, restricted to its own scope."""
        with nn.parameter_scope(self.scope):
            return nn.get_parameters()

    def forward(self, input):
        """Build the network graph from *input*; subclasses must override."""
        raise NotImplementedError

    def __call__(self, input, target=None):
        """Run a forward pass.

        Copies *input* (and *target*, if given) into the prebuilt graph's
        variables, executes it, and returns the output array — or
        (output, loss) when *target* is provided.
        """
        if self.training:
            mode = "train"
        else:
            mode = "test"
        self.input[mode].data.data = input
        if target is not None:
            self.target[mode].data.data = target
            # Forward through the loss node so the loss value is computed too.
            node = self.loss[mode]
        else:
            node = self.output[mode]
        if mode == "train":
            node.forward()  # clear_no_need_grad=True)
        else:
            node.forward()  # clear_buffer=True)
        # Copy so callers get data detached from the graph's buffers.
        output = self.output[mode].data.data.copy()
        if target is None:
            return output
        else:
            return output, self.loss[mode].data.data.copy()

    def backward(self):
        """Backpropagate through the training loss graph."""
        self.loss["train"].backward()  # clear_buffer=True)

0 comments on commit 34d22fc

Please sign in to comment.