Skip to content
Permalink
Browse files

refactor(base): make PipelineEncoder more general and allow PipelinePreprocessor
  • Loading branch information...
hanxiao committed Aug 1, 2019
1 parent 8150cf1 commit 52f87c7fa2d54b25a6b075cf549ce960ed63b59d
@@ -22,7 +22,7 @@
import tempfile
import uuid
from functools import wraps
from typing import Dict, Any, Union, TextIO, TypeVar, Type
from typing import Dict, Any, Union, TextIO, TypeVar, Type, List, Callable

import ruamel.yaml.constructor

@@ -355,3 +355,69 @@ def _dump_instance_to_yaml(data):
if p:
r['gnes_config'] = p
return r

def _copy_from(self, x: 'TrainableBase') -> None:
    """Copy trainable state from *x* into this instance.

    Base implementation is a no-op; subclasses (e.g. compositional
    trainables) override it to copy their components' state.
    """
    pass


class CompositionalTrainableBase(TrainableBase):
    """A trainable that is composed of other trainables.

    The children are stored in ``component`` either as an ordered list
    (a pipeline) or as a named dict (a typology). The setter takes a
    zero-argument callable (not the components directly) so that
    construction can be skipped when the object is restored from a
    YAML config.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._component = None  # type: List[T]

    @property
    def component(self) -> Union[List[T], Dict[str, T]]:
        """The child components, or ``None`` when not yet assigned."""
        return self._component

    @property
    def is_pipeline(self):
        """``True`` when the components form an ordered pipeline (a list)."""
        return isinstance(self.component, list)

    @component.setter
    def component(self, comps: Callable[[], Union[list, dict]]):
        """Set components via a factory callable.

        :param comps: zero-argument callable returning a list or dict of
            components.
        :raises TypeError: if *comps* is not callable.
        """
        if not callable(comps):
            # FIX: the message used to say 'a List[BaseEncoder]' — stale after
            # this class was generalized beyond encoders.
            raise TypeError('component must be a callable function that returns '
                            'a list or dict of TrainableBase components')
        if not getattr(self, 'init_from_yaml', False):
            self._component = comps()
        else:
            # components will be restored from the YAML config instead
            self.logger.info('component is omitted from construction, '
                             'as it is initialized from yaml config')

    def close(self):
        """Close all child components, then this object."""
        super().close()
        # pipeline
        if isinstance(self.component, list):
            for be in self.component:
                be.close()
        # no typology
        elif isinstance(self.component, dict):
            for be in self.component.values():
                be.close()
        elif self.component is None:
            pass
        else:
            raise TypeError('component must be dict or list, received %s' % type(self.component))

    def _copy_from(self, x: T):
        """Copy each child's trainable state from the matching child of *x*."""
        if isinstance(self.component, list):
            for be1, be2 in zip(self.component, x.component):
                be1._copy_from(be2)
        elif isinstance(self.component, dict):
            for k, v in self.component.items():
                v._copy_from(x.component[k])
        elif self.component is None:
            # FIX: mirror close() — an unset component is valid and there is
            # simply nothing to copy; previously this raised TypeError.
            pass
        else:
            raise TypeError('component must be dict or list, received %s' % type(self.component))

    @classmethod
    def to_yaml(cls, representer, data):
        """Serialize *data* (including its components) as a YAML mapping."""
        tmp = super()._dump_instance_to_yaml(data)
        tmp['component'] = data.component
        return representer.represent_mapping('!' + cls.__name__, tmp)

    @classmethod
    def from_yaml(cls, constructor, node):
        """Rebuild an instance from a YAML node, restoring its components."""
        obj, data, from_dump = super()._get_instance_from_yaml(constructor, node)
        if not from_dump and 'component' in data:
            obj.component = lambda: data['component']
        return obj
@@ -35,7 +35,6 @@
'BaseBinaryEncoder': 'base',
'BaseTextEncoder': 'base',
'BaseNumericEncoder': 'base',
'CompositionalEncoder': 'base',
'PipelineEncoder': 'base',
'HashEncoder': 'numeric.hash',
'BasePytorchEncoder': 'image.base',
@@ -16,11 +16,11 @@
# pylint: disable=low-comment-ratio


from typing import List, Any, Union, Dict, Callable
from typing import List, Any

import numpy as np

from ..base import TrainableBase
from ..base import TrainableBase, CompositionalTrainableBase


class BaseEncoder(TrainableBase):
@@ -58,70 +58,7 @@ def encode(self, data: np.ndarray, *args, **kwargs) -> bytes:
return data.tobytes()


class CompositionalEncoder(BaseEncoder):
    """An encoder built from child encoders.

    Children live in ``component``: a list for an ordered pipeline, or a
    dict for a named typology. The setter accepts a zero-argument factory
    callable so construction can be deferred/skipped when loading from a
    YAML config.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._component = None  # type: List['BaseEncoder']

    @property
    def component(self) -> Union[List['BaseEncoder'], Dict[str, 'BaseEncoder']]:
        """The child encoders, or ``None`` when not yet assigned."""
        return self._component

    @property
    def is_pipeline(self):
        """Whether the children form an ordered pipeline (a list)."""
        return isinstance(self.component, list)

    @component.setter
    def component(self, comps: Callable[[], Union[list, dict]]):
        """Assign child encoders via a factory callable.

        :raises TypeError: if *comps* is not callable.
        """
        if not callable(comps):
            raise TypeError('component must be a callable function that returns '
                            'a List[BaseEncoder]')
        if getattr(self, 'init_from_yaml', False):
            # children will be restored from the YAML config instead
            self.logger.info('component is omitted from construction, '
                             'as it is initialized from yaml config')
        else:
            self._component = comps()

    def close(self):
        """Close every child encoder, then this encoder."""
        super().close()
        children = self.component
        if children is None:
            return
        if isinstance(children, list):
            for child in children:
                child.close()
        elif isinstance(children, dict):
            for child in children.values():
                child.close()
        else:
            raise TypeError('component must be dict or list, received %s' % type(self.component))

    def _copy_from(self, x: 'CompositionalEncoder'):
        """Copy each child's state from the matching child of *x*."""
        mine = self.component
        if isinstance(mine, list):
            for dst, src in zip(mine, x.component):
                dst._copy_from(src)
        elif isinstance(mine, dict):
            for key, dst in mine.items():
                dst._copy_from(x.component[key])
        else:
            raise TypeError('component must be dict or list, received %s' % type(self.component))

    @classmethod
    def to_yaml(cls, representer, data):
        """Serialize *data*, including its children, as a YAML mapping."""
        state = super()._dump_instance_to_yaml(data)
        state['component'] = data.component
        return representer.represent_mapping('!' + cls.__name__, state)

    @classmethod
    def from_yaml(cls, constructor, node):
        """Rebuild an instance from a YAML node, restoring its children."""
        obj, data, from_dump = super()._get_instance_from_yaml(constructor, node)
        if not from_dump and 'component' in data:
            obj.component = lambda: data['component']
        return obj


class PipelineEncoder(CompositionalEncoder):
class PipelineEncoder(CompositionalTrainableBase):
def encode(self, data: Any, *args, **kwargs) -> Any:
if not self.component:
raise NotImplementedError
@@ -20,6 +20,7 @@

_cls2file_map = {
'BasePreprocessor': 'base',
'PipelinePreprocessor': 'base',
'TextPreprocessor': 'text.simple',
'BaseImagePreprocessor': 'image.base',
'BaseTextPreprocessor': 'text.base',
@@ -21,7 +21,7 @@
import numpy as np
from PIL import Image

from ..base import TrainableBase
from ..base import TrainableBase, CompositionalTrainableBase
from ..proto import gnes_pb2, array2blob


@@ -38,6 +38,22 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
doc.doc_type = self.doc_type


class PipelinePreprocessor(CompositionalTrainableBase):
    """Apply a sequence of child preprocessors to a document, in order."""

    def apply(self, doc: 'gnes_pb2.Document') -> None:
        """Run every component's ``apply`` on *doc*, mutating it in place.

        :raises NotImplementedError: when no components have been set.
        """
        if not self.component:
            raise NotImplementedError
        for be in self.component:
            be.apply(doc)

    def train(self, data, *args, **kwargs):
        """Train each component in turn, forwarding the transformed data.

        :raises NotImplementedError: when no components have been set.
        """
        if not self.component:
            raise NotImplementedError
        for idx, be in enumerate(self.component):
            be.train(data, *args, **kwargs)
            # skip the transform after the last component — nothing consumes it
            if idx + 1 < len(self.component):
                # NOTE(review): ``apply`` is annotated ``-> None`` and mutates
                # the document in place; if components follow that contract,
                # this reassignment makes ``data`` None for the next
                # component's train(). Confirm the intended behavior.
                data = be.apply(data, *args, **kwargs)


class BaseUnaryPreprocessor(BasePreprocessor):

def __init__(self, doc_type: int, *args, **kwargs):
@@ -25,7 +25,7 @@
class FFmpegPreprocessor(BaseVideoPreprocessor):

def __init__(self,
frame_size: str = "192*168",
frame_size: str = '192*168',
duplicate_rm: bool = True,
use_phash_weight: bool = False,
phash_thresh: int = 5,
@@ -48,8 +48,8 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
frames = get_video_frames(
doc.raw_bytes,
s=self.frame_size,
vsync=self._ffmpeg_kwargs.get("vsync", "vfr"),
vf=self._ffmpeg_kwargs.get("vf", "select=eq(pict_type\\,I)"))
vsync=self._ffmpeg_kwargs.get('vsync', 'vfr'),
vf=self._ffmpeg_kwargs.get('vf', 'select=eq(pict_type\\,I)'))

# remove dupliated key frames by phash value
if self.duplicate_rm:
@@ -26,9 +26,9 @@ class ShotDetectPreprocessor(BaseVideoPreprocessor):
store_args_kwargs = True

def __init__(self,
frame_size: str = "192*168",
descriptor: str = "block_hsv_histogram",
distance_metric: str = "bhattacharya",
frame_size: str = '192*168',
descriptor: str = 'block_hsv_histogram',
distance_metric: str = 'bhattacharya',
*args,
**kwargs):
super().__init__(*args, **kwargs)
@@ -47,7 +47,7 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
frames = get_video_frames(
doc.raw_bytes,
s=self.frame_size,
vsync="vfr",
vsync='vfr',
vf='select=eq(pict_type\\,I)')

descriptors = []
@@ -0,0 +1,45 @@
import os
import unittest

from gnes.preprocessor.base import BasePreprocessor, PipelinePreprocessor
from gnes.proto import gnes_pb2


class P1(BasePreprocessor):
    """Toy preprocessor for the test: increments the document id by one."""

    def apply(self, doc: 'gnes_pb2.Document'):
        doc.doc_id = doc.doc_id + 1


class P2(BasePreprocessor):
    """Toy preprocessor for the test: triples the document id."""

    def apply(self, doc: 'gnes_pb2.Document'):
        doc.doc_id = doc.doc_id * 3


class TestPartition(unittest.TestCase):
    """End-to-end test of PipelinePreprocessor: apply, YAML/binary dump, reload."""

    def setUp(self):
        # dump paths live next to this test file and are derived from the name
        self.dirname = os.path.dirname(__file__)
        self.p3_name = 'pipe-p12'
        self.yml_dump_path = os.path.join(self.dirname, '%s.yml' % self.p3_name)
        self.bin_dump_path = os.path.join(self.dirname, '%s.bin' % self.p3_name)

    def tearDown(self):
        # remove any dump artifacts the test produced
        if os.path.exists(self.yml_dump_path):
            os.remove(self.yml_dump_path)
        if os.path.exists(self.bin_dump_path):
            os.remove(self.bin_dump_path)

    def test_pipelinepreproces(self):
        p3 = PipelinePreprocessor()
        # component is set via a factory callable, not the list itself
        p3.component = lambda: [P1(), P2()]
        d = gnes_pb2.Document()
        d.doc_id = 1
        p3.apply(d)
        # P1 then P2: (1 + 1) * 3 == 6
        self.assertEqual(d.doc_id, 6)

        p3.name = self.p3_name
        p3.dump_yaml()
        p3.dump()

        # reload from the YAML dump and check the pipeline still works
        p4 = BasePreprocessor.load_yaml(p3.yaml_full_path)
        p4.apply(d)
        # applied again on the mutated doc: (6 + 1) * 3 == 21
        self.assertEqual(d.doc_id, 21)

0 comments on commit 52f87c7

Please sign in to comment.
You can’t perform that action at this time.