Skip to content
Permalink
Browse files

refactor(base): make pipelineencoder more general and allow pipelinepreprocessor
  • Loading branch information...
hanxiao committed Aug 1, 2019
1 parent 8150cf1 commit 52f87c7fa2d54b25a6b075cf549ce960ed63b59d
@@ -22,7 +22,7 @@
import tempfile import tempfile
import uuid import uuid
from functools import wraps from functools import wraps
from typing import Dict, Any, Union, TextIO, TypeVar, Type from typing import Dict, Any, Union, TextIO, TypeVar, Type, List, Callable


import ruamel.yaml.constructor import ruamel.yaml.constructor


@@ -355,3 +355,69 @@ def _dump_instance_to_yaml(data):
if p: if p:
r['gnes_config'] = p r['gnes_config'] = p
return r return r

def _copy_from(self, x: 'TrainableBase') -> None:
    """Copy trained state from another instance *x*; a no-op here.

    Compositional subclasses override this to propagate the copy to
    their sub-components.
    """
    pass


class CompositionalTrainableBase(TrainableBase):
    """A trainable made of other trainables.

    ``component`` holds either a list (an ordered pipeline) or a dict
    (a typology keyed by name) of sub-components; lifecycle operations
    (``close``, ``_copy_from``) and yaml (de)serialization fan out to them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._component = None  # type: List[T]

    @property
    def component(self) -> Union[List[T], Dict[str, T]]:
        """The sub-components, or None when not yet set."""
        return self._component

    @property
    def is_pipeline(self):
        """True when the components form an ordered pipeline (a list)."""
        return isinstance(self.component, list)

    @component.setter
    def component(self, comps: Callable[[], Union[list, dict]]):
        # a factory callable is required (rather than the container itself) so
        # construction can be skipped when the object is restored from yaml
        if not callable(comps):
            # fixed message: this base class is general, not encoder-specific
            raise TypeError('component must be a callable function that returns '
                            'a List[TrainableBase]')
        if not getattr(self, 'init_from_yaml', False):
            self._component = comps()
        else:
            self.logger.info('component is omitted from construction, '
                             'as it is initialized from yaml config')

    def close(self):
        super().close()
        # pipeline
        if isinstance(self.component, list):
            for be in self.component:
                be.close()
        # no typology
        elif isinstance(self.component, dict):
            for be in self.component.values():
                be.close()
        elif self.component is None:
            pass
        else:
            raise TypeError('component must be dict or list, received %s' % type(self.component))

    def _copy_from(self, x: T):
        """Copy trained state from the matching sub-components of *x*."""
        if isinstance(self.component, list):
            for be1, be2 in zip(self.component, x.component):
                be1._copy_from(be2)
        elif isinstance(self.component, dict):
            for k, v in self.component.items():
                v._copy_from(x.component[k])
        elif self.component is None:
            # nothing to copy yet; mirror the tolerance of close()
            # (previously this raised TypeError, unlike close())
            pass
        else:
            raise TypeError('component must be dict or list, received %s' % type(self.component))

    @classmethod
    def to_yaml(cls, representer, data):
        """Represent *data* as a '!ClassName' yaml mapping, components included."""
        tmp = super()._dump_instance_to_yaml(data)
        tmp['component'] = data.component
        return representer.represent_mapping('!' + cls.__name__, tmp)

    @classmethod
    def from_yaml(cls, constructor, node):
        """Rebuild an instance from a yaml node; components are set lazily."""
        obj, data, from_dump = super()._get_instance_from_yaml(constructor, node)
        if not from_dump and 'component' in data:
            obj.component = lambda: data['component']
        return obj
@@ -35,7 +35,6 @@
'BaseBinaryEncoder': 'base', 'BaseBinaryEncoder': 'base',
'BaseTextEncoder': 'base', 'BaseTextEncoder': 'base',
'BaseNumericEncoder': 'base', 'BaseNumericEncoder': 'base',
'CompositionalEncoder': 'base',
'PipelineEncoder': 'base', 'PipelineEncoder': 'base',
'HashEncoder': 'numeric.hash', 'HashEncoder': 'numeric.hash',
'BasePytorchEncoder': 'image.base', 'BasePytorchEncoder': 'image.base',
@@ -16,11 +16,11 @@
# pylint: disable=low-comment-ratio # pylint: disable=low-comment-ratio




from typing import List, Any, Union, Dict, Callable from typing import List, Any


import numpy as np import numpy as np


from ..base import TrainableBase from ..base import TrainableBase, CompositionalTrainableBase




class BaseEncoder(TrainableBase): class BaseEncoder(TrainableBase):
@@ -58,70 +58,7 @@ def encode(self, data: np.ndarray, *args, **kwargs) -> bytes:
return data.tobytes() return data.tobytes()




class CompositionalEncoder(BaseEncoder): class PipelineEncoder(CompositionalTrainableBase):
def __init__(self, *args, **kwargs):
    """Initialize with an empty (unset) component container."""
    super().__init__(*args, **kwargs)
    # filled in later through the ``component`` setter
    self._component = None  # type: List['BaseEncoder']

@property
def component(self) -> Union[List['BaseEncoder'], Dict[str, 'BaseEncoder']]:
    """Return the sub-encoders (list or dict), or None when unset."""
    return self._component

@property
def is_pipeline(self):
    """True when the sub-encoders form an ordered pipeline (a list)."""
    return isinstance(self._component, list)

@component.setter
def component(self, comps: Callable[[], Union[list, dict]]):
    """Set the sub-encoders via a factory callable.

    The factory is skipped (and a log message emitted) when the object
    is being initialized from a yaml config.
    """
    if not callable(comps):
        raise TypeError('component must be a callable function that returns '
                        'a List[BaseEncoder]')
    if getattr(self, 'init_from_yaml', False):
        self.logger.info('component is omitted from construction, '
                         'as it is initialized from yaml config')
    else:
        self._component = comps()

def close(self):
    """Close this encoder, then every sub-encoder it holds."""
    super().close()
    comp = self.component
    if comp is None:
        return
    if isinstance(comp, list):
        # pipeline
        members = comp
    elif isinstance(comp, dict):
        # no typology
        members = comp.values()
    else:
        raise TypeError('component must be dict or list, received %s' % type(self.component))
    for be in members:
        be.close()

def _copy_from(self, x: 'CompositionalEncoder'):
    """Copy trained state from the matching sub-encoders of *x*."""
    comp = self.component
    if isinstance(comp, list):
        for mine, theirs in zip(comp, x.component):
            mine._copy_from(theirs)
    elif isinstance(comp, dict):
        for key, mine in comp.items():
            mine._copy_from(x.component[key])
    else:
        raise TypeError('component must be dict or list, received %s' % type(self.component))

@classmethod
def to_yaml(cls, representer, data):
    """Represent *data* as a '!ClassName' yaml mapping including its components."""
    state = super()._dump_instance_to_yaml(data)
    state['component'] = data.component
    return representer.represent_mapping('!' + cls.__name__, state)

@classmethod
def from_yaml(cls, constructor, node):
    """Rebuild an instance from a yaml node, deferring component construction."""
    obj, data, from_dump = super()._get_instance_from_yaml(constructor, node)
    needs_component = (not from_dump) and 'component' in data
    if needs_component:
        # wrap in a lambda: the setter expects a factory callable
        obj.component = lambda: data['component']
    return obj


class PipelineEncoder(CompositionalEncoder):
def encode(self, data: Any, *args, **kwargs) -> Any: def encode(self, data: Any, *args, **kwargs) -> Any:
if not self.component: if not self.component:
raise NotImplementedError raise NotImplementedError
@@ -20,6 +20,7 @@


_cls2file_map = { _cls2file_map = {
'BasePreprocessor': 'base', 'BasePreprocessor': 'base',
'PipelinePreprocessor': 'base',
'TextPreprocessor': 'text.simple', 'TextPreprocessor': 'text.simple',
'BaseImagePreprocessor': 'image.base', 'BaseImagePreprocessor': 'image.base',
'BaseTextPreprocessor': 'text.base', 'BaseTextPreprocessor': 'text.base',
@@ -21,7 +21,7 @@
import numpy as np import numpy as np
from PIL import Image from PIL import Image


from ..base import TrainableBase from ..base import TrainableBase, CompositionalTrainableBase
from ..proto import gnes_pb2, array2blob from ..proto import gnes_pb2, array2blob




@@ -38,6 +38,22 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
doc.doc_type = self.doc_type doc.doc_type = self.doc_type




class PipelinePreprocessor(CompositionalTrainableBase):
    """Run a list of preprocessors over a document, one after another."""

    def apply(self, doc: 'gnes_pb2.Document') -> None:
        """Apply every component to *doc* in order.

        Components are expected to mutate *doc* in place (see the base
        ``apply`` returning None).
        """
        if not self.component:
            raise NotImplementedError
        for be in self.component:
            be.apply(doc)

    def train(self, data, *args, **kwargs):
        """Train each component, feeding the transformed data forward.

        Preprocessor ``apply`` conventionally mutates its argument in
        place and returns None; the original code rebound ``data`` to
        that None, starving every component after the first.  Here the
        current data object is kept whenever ``apply`` yields nothing.
        """
        if not self.component:
            raise NotImplementedError
        for idx, be in enumerate(self.component):
            be.train(data, *args, **kwargs)
            if idx + 1 < len(self.component):
                out = be.apply(data, *args, **kwargs)
                if out is not None:
                    data = out


class BaseUnaryPreprocessor(BasePreprocessor): class BaseUnaryPreprocessor(BasePreprocessor):


def __init__(self, doc_type: int, *args, **kwargs): def __init__(self, doc_type: int, *args, **kwargs):
@@ -25,7 +25,7 @@
class FFmpegPreprocessor(BaseVideoPreprocessor): class FFmpegPreprocessor(BaseVideoPreprocessor):


def __init__(self, def __init__(self,
frame_size: str = "192*168", frame_size: str = '192*168',
duplicate_rm: bool = True, duplicate_rm: bool = True,
use_phash_weight: bool = False, use_phash_weight: bool = False,
phash_thresh: int = 5, phash_thresh: int = 5,
@@ -48,8 +48,8 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
frames = get_video_frames( frames = get_video_frames(
doc.raw_bytes, doc.raw_bytes,
s=self.frame_size, s=self.frame_size,
vsync=self._ffmpeg_kwargs.get("vsync", "vfr"), vsync=self._ffmpeg_kwargs.get('vsync', 'vfr'),
vf=self._ffmpeg_kwargs.get("vf", "select=eq(pict_type\\,I)")) vf=self._ffmpeg_kwargs.get('vf', 'select=eq(pict_type\\,I)'))


# remove duplicated key frames by phash value # remove duplicated key frames by phash value
if self.duplicate_rm: if self.duplicate_rm:
@@ -26,9 +26,9 @@ class ShotDetectPreprocessor(BaseVideoPreprocessor):
store_args_kwargs = True store_args_kwargs = True


def __init__(self, def __init__(self,
frame_size: str = "192*168", frame_size: str = '192*168',
descriptor: str = "block_hsv_histogram", descriptor: str = 'block_hsv_histogram',
distance_metric: str = "bhattacharya", distance_metric: str = 'bhattacharya',
*args, *args,
**kwargs): **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -47,7 +47,7 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
frames = get_video_frames( frames = get_video_frames(
doc.raw_bytes, doc.raw_bytes,
s=self.frame_size, s=self.frame_size,
vsync="vfr", vsync='vfr',
vf='select=eq(pict_type\\,I)') vf='select=eq(pict_type\\,I)')


descriptors = [] descriptors = []
@@ -0,0 +1,45 @@
import os
import unittest

from gnes.preprocessor.base import BasePreprocessor, PipelinePreprocessor
from gnes.proto import gnes_pb2


class P1(BasePreprocessor):
    """Test fixture: bumps the document id by one."""

    def apply(self, doc: 'gnes_pb2.Document'):
        doc.doc_id = doc.doc_id + 1


class P2(BasePreprocessor):
    """Test fixture: triples the document id."""

    def apply(self, doc: 'gnes_pb2.Document'):
        doc.doc_id = doc.doc_id * 3


class TestPartition(unittest.TestCase):
    """Round-trip a PipelinePreprocessor through yaml dump and load."""

    def setUp(self):
        self.dirname = os.path.dirname(__file__)
        self.p3_name = 'pipe-p12'
        self.yml_dump_path = os.path.join(self.dirname, '%s.yml' % self.p3_name)
        self.bin_dump_path = os.path.join(self.dirname, '%s.bin' % self.p3_name)

    def tearDown(self):
        # best-effort removal of the dump artifacts produced by the test
        for dump_path in (self.yml_dump_path, self.bin_dump_path):
            if os.path.exists(dump_path):
                os.remove(dump_path)

    def test_pipelinepreproces(self):
        p3 = PipelinePreprocessor()
        p3.component = lambda: [P1(), P2()]
        d = gnes_pb2.Document()
        d.doc_id = 1
        p3.apply(d)
        # (1 + 1) * 3 == 6
        self.assertEqual(d.doc_id, 6)

        p3.name = self.p3_name
        p3.dump_yaml()
        p3.dump()

        # reload from the yaml dump and apply once more: (6 + 1) * 3 == 21
        p4 = BasePreprocessor.load_yaml(p3.yaml_full_path)
        p4.apply(d)
        self.assertEqual(d.doc_id, 21)

0 comments on commit 52f87c7

Please sign in to comment.
You can’t perform that action at this time.