Remove data, metric and common to neural_compressor #244

Merged
merged 39 commits on Dec 11, 2022
Commits
39 commits
f55987f
refine class name
changwangss Dec 8, 2022
c69b9d1
add common folder
changwangss Dec 8, 2022
d2389da
add data folder
changwangss Dec 8, 2022
66ddd2a
rename DATASETS
changwangss Dec 8, 2022
61982c0
append dataset and transform feature
changwangss Dec 8, 2022
86f8188
fix pylint and pydocstyle
changwangss Dec 8, 2022
bc016ca
fix docstring
changwangss Dec 8, 2022
ad5d78e
replace UT DATASETS
changwangss Dec 9, 2022
9db4958
remove common
changwangss Dec 9, 2022
d01f5c0
rebase master
changwangss Dec 9, 2022
9d2e61b
fix path
changwangss Dec 9, 2022
a98c7e7
rebase model/model.py
changwangss Dec 9, 2022
f6dabae
add no cover
changwangss Dec 9, 2022
c27df84
remove import for tensorflow_model.py
changwangss Dec 9, 2022
5179326
fix MODELS
changwangss Dec 9, 2022
6e06124
fix ut issue about experimental data and data
changwangss Dec 9, 2022
1970aab
Merge branch 'master' into wangchang/api
changwangss Dec 9, 2022
c3ce078
refine class name
changwangss Dec 8, 2022
cb64192
add common folder
changwangss Dec 8, 2022
b282506
add data folder
changwangss Dec 8, 2022
0ff55f8
rename DATASETS
changwangss Dec 8, 2022
7c7e7d9
append dataset and transform feature
changwangss Dec 8, 2022
13b90f5
fix pylint and pydocstyle
changwangss Dec 8, 2022
ad154d3
fix docstring
changwangss Dec 8, 2022
4dcc749
replace UT DATASETS
changwangss Dec 9, 2022
02b41d5
remove common
changwangss Dec 9, 2022
f9f497c
rebase master
changwangss Dec 9, 2022
f64ce3a
fix path
changwangss Dec 9, 2022
df3203f
rebase model/model.py
changwangss Dec 9, 2022
78e01ae
add no cover
changwangss Dec 9, 2022
7b60dc2
remove import for tensorflow_model.py
changwangss Dec 9, 2022
9dbc7be
fix MODELS
changwangss Dec 9, 2022
bbfc9e4
fix ut issue about experimental data and data
changwangss Dec 9, 2022
dabb103
rebase master
changwangss Dec 10, 2022
a9d67bb
fix ut
changwangss Dec 10, 2022
5219a12
fix conflict
changwangss Dec 11, 2022
6273157
Merge branch 'master' into wangchang/api
changwangss Dec 11, 2022
c33cc2a
fix name model to tensorflow_model
changwangss Dec 11, 2022
d8890bc
fix import name
changwangss Dec 11, 2022
@@ -173,7 +173,7 @@ After prepare step is done, we add tune and benchmark code to generate quantized model
#### Benchmark
```python
from neural_compressor.experimental import Benchmark, common
from neural_compressor.model.model import get_model_type
from neural_compressor.model.tensorflow_model import get_model_type
evaluator = Benchmark(FLAGS.config)
dataset = Dataset(eval_file, FLAGS.eval_batch_size)
evaluator.b_dataloader = common.DataLoader(\
@@ -1109,7 +1109,7 @@ def result(self):
evaluator.metric = Accuracy()


from neural_compressor.model.model import get_model_type
from neural_compressor.model.tensorflow_model import get_model_type
model_type = get_model_type(FLAGS.input_model)
if model_type == 'frozen_pb':
evaluator.model = FLAGS.input_model
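For orientation, here is a minimal sketch of the relocated helper in use after this change. Only the import path and the 'frozen_pb' return value come from the diff above; the model path is hypothetical.

```python
# Sketch only: get_model_type now lives in neural_compressor.model.tensorflow_model.
from neural_compressor.model.tensorflow_model import get_model_type

input_model = "/path/to/model.pb"  # hypothetical frozen-graph path
if get_model_type(input_model) == 'frozen_pb':
    print("frozen TensorFlow graph detected")
```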
2 changes: 1 addition & 1 deletion neural_compressor/adaptor/mxnet_utils/util.py
@@ -24,7 +24,7 @@
from enum import Enum
from tempfile import TemporaryDirectory
from neural_compressor.utils.utility import LazyImport
from neural_compressor.model.model import MXNetModel as NCModel
from neural_compressor.model.mxnet_model import MXNetModel as NCModel

mx = LazyImport("mxnet")

7 changes: 4 additions & 3 deletions neural_compressor/adaptor/tensorflow.py
@@ -136,7 +136,7 @@ def train(self, model, dataloader, optimizer_tuple,
criterion_tuple, hooks, postprocess, **kwargs):
# check model is savedmodel or not
import tensorflow as tf
from neural_compressor.model.model import get_model_type
from neural_compressor.model.tensorflow_model import get_model_type
tf.random.set_seed(1)
self.model_type = get_model_type(model._model)
optimizer = optimizer_tuple[0](**optimizer_tuple[1])
@@ -1204,7 +1204,7 @@ def inspect_tensor(self, model, dataloader=None, op_list=[], iteration_list=[],
]
}
"""
from neural_compressor.model.model import TensorflowBaseModel
from neural_compressor.model.tensorflow_model import TensorflowBaseModel
from neural_compressor.utils.utility import load_data_from_pkl, dump_data_to_local
from neural_compressor.adaptor.tf_utils.graph_util import GraphAnalyzer
from .tf_utils.util import int8_node_name_reverse
@@ -1586,7 +1586,8 @@ def _get_mse_order(self, fp32_model, tune_cfg, replace_cfgs, ops_lst, dataloader

def _partial_dataset_of(self, dataloader, confidence_batches):
from neural_compressor.experimental.data.datasets.dummy_dataset import DummyDataset
if isinstance(dataloader.dataset, DummyDataset):
from neural_compressor.data.datasets.dummy_dataset import DummyDataset as DummyDataset_v2_x
if isinstance(dataloader.dataset, DummyDataset) or isinstance(dataloader.dataset, DummyDataset_v2_x):
assert(isinstance(confidence_batches, int))
ds = copy.deepcopy(dataloader.dataset)
ds.dataset = ds.dataset[:confidence_batches]
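The hunk above keeps backward compatibility by accepting the DummyDataset from both the experimental and the new data namespaces. As a hedged aside, an equivalent check could be written with a class tuple; the helper name below is hypothetical and not part of the PR.

```python
def _is_dummy_dataset(dataset):
    """Return True for either DummyDataset flavour (hypothetical helper, not in the PR)."""
    from neural_compressor.experimental.data.datasets.dummy_dataset import DummyDataset
    from neural_compressor.data.datasets.dummy_dataset import DummyDataset as DummyDataset_v2_x
    # isinstance accepts a tuple of classes, collapsing the two checks from the diff into one call.
    return isinstance(dataset, (DummyDataset, DummyDataset_v2_x))
```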
15 changes: 9 additions & 6 deletions neural_compressor/data/__init__.py
@@ -14,26 +14,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Built-in dataloaders, datasets, transforms, filters for multiple framework backends."""


from .dataloaders import DataLoader
import neural_compressor.data.datasets
import neural_compressor.data.transforms
from ..experimental.data.datasets import DATASETS, Dataset, IterableDataset, dataset_registry
from ..experimental.data.transforms import TRANSFORMS, BaseTransform, transform_registry
from ..experimental.data.dataloaders import DATALOADERS
from ..experimental.data.filters import FILTERS, Filter, filter_registry
from .datasets import Datasets, Dataset, IterableDataset, dataset_registry
from .dataloaders import DATALOADERS, DataLoader
from .transforms import TRANSFORMS, BaseTransform, transform_registry, Postprocess

from .filters import FILTERS, Filter, filter_registry

__all__ = [
"DataLoader",
"DATALOADERS",
"DATASETS",
"Datasets",
"Dataset",
"IterableDataset",
"dataset_registry",
"TRANSFORMS",
"BaseTransform",
"transform_registry",
"Postprocess",
"FILTERS",
"Filter",
"filter_registry",]
6 changes: 4 additions & 2 deletions neural_compressor/data/dataloaders/__init__.py
@@ -14,9 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from .dataloader import DataLoader
from .dataloader import DataLoader, DATALOADERS

__all__ = [
"DataLoader",
]
"DATALOADERS"
]
119 changes: 119 additions & 0 deletions neural_compressor/data/dataloaders/base_dataloader.py
@@ -0,0 +1,119 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BaseDataloder of all dataloaders."""

from abc import abstractmethod


class BaseDataLoader: # pragma: no cover
"""Base class for all DataLoaders.

A subclass implements _generate_dataloader to build a framework-specific dataloader
from general parameters such as batch_size and sampler. Dynamic batching simply
regenerates the dataloader with new batch_size and last_batch settings.

"""

def __init__(self, dataset, batch_size=1, last_batch='rollover', collate_fn=None,
sampler=None, batch_sampler=None, num_workers=0, pin_memory=False,
shuffle=False, distributed=False):
"""Initialize BaseDataLoader.

Args:
dataset (object): dataset from which to load the data
batch_size (int, optional): number of samples per batch. Defaults to 1.
last_batch (str, optional): whether to drop the last batch if it is incomplete.
Supports ['rollover', 'discard']: 'rollover' keeps the incomplete batch (drop_last=False), 'discard' drops it (drop_last=True).
Defaults to 'rollover'.
collate_fn (callable, optional): merge data with outer dimension batch size. Defaults to None.
sampler (Sampler, optional): Sampler object to sample data. Defaults to None.
batch_sampler (BatchSampler, optional): BatchSampler object to generate batch of indices. Defaults to None.
num_workers (int, optional): number of subprocesses to use for data loading. Defaults to 0.
pin_memory (bool, optional): whether to copy data into pinned memory before returning. Defaults to False.
shuffle (bool, optional): whether to shuffle data. Defaults to False.
distributed (bool, optional): whether the dataloader is distributed. Defaults to False.
"""
self.dataset = dataset
self.collate_fn = collate_fn
self.sampler = sampler
self.batch_sampler = batch_sampler
self.num_workers = num_workers
self.pin_memory = pin_memory
self._batch_size = batch_size
self.shuffle = shuffle
self.distributed = distributed
self.last_batch = last_batch
self.drop_last = False if last_batch == 'rollover' else True

self.dataloader = self._generate_dataloader(
self.dataset,
batch_size=batch_size,
last_batch=last_batch,
collate_fn=collate_fn,
sampler=sampler,
batch_sampler=batch_sampler,
num_workers=num_workers,
pin_memory=pin_memory,
shuffle=shuffle,
distributed=distributed)

def batch(self, batch_size, last_batch=None):
"""Set batch size for dataloader.

Args:
batch_size (int): number of samples per batch.
last_batch (str, optional): whether to drop the last batch if it is incomplete.
Supports ['rollover', 'discard']: 'rollover' keeps the incomplete batch (drop_last=False), 'discard' drops it (drop_last=True).
Defaults to None.
"""
self._batch_size = batch_size
if last_batch is not None:
self.last_batch = last_batch
self.dataloader = self._generate_dataloader(
self.dataset,
batch_size,
self.last_batch,
self.collate_fn,
self.sampler,
self.batch_sampler,
self.num_workers,
self.pin_memory,
self.shuffle,
self.distributed)

@property
def batch_size(self):
"""Get dataloader's batch_size.

Returns:
int: batch_size
"""
return self._batch_size

def __iter__(self):
"""Yield data in iterative order.

Returns:
iterator: iterator for the dataloader
"""
return iter(self.dataloader)

@abstractmethod
def _generate_dataloader(self, dataset, batch_size, last_batch, collate_fn, sampler,
batch_sampler, num_workers, pin_memory, shuffle, distributed):
raise NotImplementedError
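To make the abstract contract concrete, here is a toy, hypothetical subclass (not part of the PR) that implements _generate_dataloader for an in-memory dataset; real backends such as DefaultDataLoader below add samplers, fetchers, and collation.

```python
from neural_compressor.data.dataloaders.base_dataloader import BaseDataLoader

class ListDataLoader(BaseDataLoader):
    """Minimal sketch: batches an indexable or iterable in-memory dataset."""

    def _generate_dataloader(self, dataset, batch_size, last_batch, collate_fn, sampler,
                             batch_sampler, num_workers, pin_memory, shuffle, distributed):
        def _iterate():
            batch = []
            for sample in dataset:
                batch.append(sample)
                if len(batch) == batch_size:
                    yield batch
                    batch = []
            if batch and last_batch == 'rollover':  # keep the incomplete tail batch
                yield batch
        return _iterate()

loader = ListDataLoader(dataset=list(range(10)), batch_size=4)
print([batch for batch in loader])  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
loader.batch(5)                     # dynamic batching regenerates the inner iterator
print([batch for batch in loader])  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```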
3 changes: 2 additions & 1 deletion neural_compressor/data/dataloaders/dataloader.py
@@ -15,13 +15,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Built-in dataloaders for multiple framework backends."""

from neural_compressor.experimental.data.dataloaders import DATALOADERS

# THIS API IS TO BE DEPRECATED!
class DataLoader(object):
"""Entrance of all configured DataLoaders. Will dispatch the DataLoaders to framework
specific one. Users will be not aware of the dispatching, and the Interface is unified.

"""

def __new__(cls, framework, dataset, batch_size=1, collate_fn=None,
143 changes: 143 additions & 0 deletions neural_compressor/data/dataloaders/default_dataloader.py
@@ -0,0 +1,143 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Default dataloader for multiple framework backends."""

import collections
import numpy as np
from math import ceil, floor
from abc import abstractmethod
from .sampler import IterableSampler, SequentialSampler, BatchSampler
from .fetcher import FETCHERS
from .base_dataloader import BaseDataLoader

def default_collate(batch): # pragma: no cover
"""Merge data with outer dimension batch size."""
elem = batch[0]
if isinstance(elem, collections.abc.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, collections.abc.Sequence):
batch = zip(*batch)
return [default_collate(samples) for samples in batch]
elif isinstance(elem, np.ndarray):
try:
return np.stack(batch)
except:
return batch
else:
return batch

class DefaultDataLoader(BaseDataLoader): # pragma: no cover
"""DefaultDataLoader for multiple framework backends."""

def __init__(self, dataset, batch_size=1, last_batch='rollover', collate_fn=None,
sampler=None, batch_sampler=None, num_workers=0, pin_memory=False,
shuffle=False, distributed=False):
"""Initialize DefaultDataLoader.

Args:
dataset (object): dataset from which to load the data
batch_size (int, optional): number of samples per batch. Defaults to 1.
last_batch (str, optional): whether to drop the last batch if it is incomplete.
Supports ['rollover', 'discard']: 'rollover' keeps the incomplete batch (drop_last=False), 'discard' drops it (drop_last=True).
Defaults to 'rollover'.
collate_fn (callable, optional): merge data with outer dimension batch size. Defaults to None.
sampler (Sampler, optional): Sampler object to sample data. Defaults to None.
batch_sampler (BatchSampler, optional): BatchSampler object to generate batch of indices. Defaults to None.
num_workers (int, optional): number of subprocesses to use for data loading. Defaults to 0.
pin_memory (bool, optional): whether to copy data into pinned memory before returning. Defaults to False.
shuffle (bool, optional): whether to shuffle data. Defaults to False.
distributed (bool, optional): whether the dataloader is distributed. Defaults to False.
"""
self.dataset = dataset
self.last_batch = last_batch
self.sampler = sampler
self.batch_sampler = batch_sampler
self.num_workers = num_workers
self.pin_memory = pin_memory
self.collate_fn = collate_fn
self._batch_size = batch_size
self.shuffle = shuffle
self.distributed = distributed
self.drop_last = False if last_batch == 'rollover' else True
if self.collate_fn is None:
self.collate_fn = default_collate

def batch(self, batch_size, last_batch='rollover'):
"""Set batch_size and last_batch."""
self._batch_size = batch_size
self.last_batch = last_batch

@property
def dataloader(self):
"""Return dataloader."""
return self

def __iter__(self):
"""Yield data in iterative order."""
return self._generate_dataloader(
self.dataset,
batch_size=self.batch_size,
last_batch=self.last_batch,
collate_fn=self.collate_fn,
sampler=self.sampler,
batch_sampler=self.batch_sampler,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
shuffle=self.shuffle,
distributed=self.distributed)

def __len__(self):
"""Get dataset length."""
try:
dataset_len = self.dataset.__len__()
except (AttributeError, TypeError):
dataset_len = 0
for _ in self.dataset:
dataset_len += 1
except Exception:
raise ValueError(f"{self.dataset} is invalid: it does not support"
" calculating the length of its dataloader")
if not self.drop_last:
dataloader_len = ceil(dataset_len / self.batch_size)
else:
dataloader_len = floor(dataset_len / self.batch_size)
return dataloader_len

def _generate_dataloader(self, dataset, batch_size, last_batch, collate_fn, sampler,
batch_sampler, num_workers, pin_memory, shuffle, distributed):

sampler = self._generate_sampler(dataset, distributed)
self.batch_sampler = BatchSampler(sampler, batch_size, self.drop_last)
self.fetcher = FETCHERS[self.dataset_type](dataset, collate_fn, self.drop_last, distributed)

for batched_indices in self.batch_sampler:
try:
data = self.fetcher(batched_indices)
yield data
except StopIteration:
return

def _generate_sampler(self, dataset, distributed):
if hasattr(dataset, "__getitem__"):
self.dataset_type = 'index'
return SequentialSampler(dataset, distributed)
elif hasattr(dataset, "__iter__"):
self.dataset_type = 'iter'
return IterableSampler(dataset)
else:
raise ValueError("dataset type only support (index, iter)")