Skip to content
This repository has been archived by the owner on Feb 22, 2020. It is now read-only.

Commit

Permalink
refactor(base): make pipelineencoder more general and allow pipelinep…
Browse files Browse the repository at this point in the history
…reprocessor
  • Loading branch information
hanhxiao committed Aug 1, 2019
1 parent 8150cf1 commit 52f87c7
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 76 deletions.
68 changes: 67 additions & 1 deletion gnes/base/__init__.py
Expand Up @@ -22,7 +22,7 @@
import tempfile
import uuid
from functools import wraps
from typing import Dict, Any, Union, TextIO, TypeVar, Type
from typing import Dict, Any, Union, TextIO, TypeVar, Type, List, Callable

import ruamel.yaml.constructor

Expand Down Expand Up @@ -355,3 +355,69 @@ def _dump_instance_to_yaml(data):
if p:
r['gnes_config'] = p
return r

def _copy_from(self, x: 'TrainableBase') -> None:
pass


class CompositionalTrainableBase(TrainableBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._component = None # type: List[T]

@property
def component(self) -> Union[List[T], Dict[str, T]]:
return self._component

@property
def is_pipeline(self):
return isinstance(self.component, list)

@component.setter
def component(self, comps: Callable[[], Union[list, dict]]):
if not callable(comps):
raise TypeError('component must be a callable function that returns '
'a List[BaseEncoder]')
if not getattr(self, 'init_from_yaml', False):
self._component = comps()
else:
self.logger.info('component is omitted from construction, '
'as it is initialized from yaml config')

def close(self):
super().close()
# pipeline
if isinstance(self.component, list):
for be in self.component:
be.close()
# no typology
elif isinstance(self.component, dict):
for be in self.component.values():
be.close()
elif self.component is None:
pass
else:
raise TypeError('component must be dict or list, received %s' % type(self.component))

def _copy_from(self, x: T):
if isinstance(self.component, list):
for be1, be2 in zip(self.component, x.component):
be1._copy_from(be2)
elif isinstance(self.component, dict):
for k, v in self.component.items():
v._copy_from(x.component[k])
else:
raise TypeError('component must be dict or list, received %s' % type(self.component))

@classmethod
def to_yaml(cls, representer, data):
tmp = super()._dump_instance_to_yaml(data)
tmp['component'] = data.component
return representer.represent_mapping('!' + cls.__name__, tmp)

@classmethod
def from_yaml(cls, constructor, node):
obj, data, from_dump = super()._get_instance_from_yaml(constructor, node)
if not from_dump and 'component' in data:
obj.component = lambda: data['component']
return obj
1 change: 0 additions & 1 deletion gnes/encoder/__init__.py
Expand Up @@ -35,7 +35,6 @@
'BaseBinaryEncoder': 'base',
'BaseTextEncoder': 'base',
'BaseNumericEncoder': 'base',
'CompositionalEncoder': 'base',
'PipelineEncoder': 'base',
'HashEncoder': 'numeric.hash',
'BasePytorchEncoder': 'image.base',
Expand Down
69 changes: 3 additions & 66 deletions gnes/encoder/base.py
Expand Up @@ -16,11 +16,11 @@
# pylint: disable=low-comment-ratio


from typing import List, Any, Union, Dict, Callable
from typing import List, Any

import numpy as np

from ..base import TrainableBase
from ..base import TrainableBase, CompositionalTrainableBase


class BaseEncoder(TrainableBase):
Expand Down Expand Up @@ -58,70 +58,7 @@ def encode(self, data: np.ndarray, *args, **kwargs) -> bytes:
return data.tobytes()


class CompositionalEncoder(BaseEncoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._component = None # type: List['BaseEncoder']

@property
def component(self) -> Union[List['BaseEncoder'], Dict[str, 'BaseEncoder']]:
return self._component

@property
def is_pipeline(self):
return isinstance(self.component, list)

@component.setter
def component(self, comps: Callable[[], Union[list, dict]]):
if not callable(comps):
raise TypeError('component must be a callable function that returns '
'a List[BaseEncoder]')
if not getattr(self, 'init_from_yaml', False):
self._component = comps()
else:
self.logger.info('component is omitted from construction, '
'as it is initialized from yaml config')

def close(self):
super().close()
# pipeline
if isinstance(self.component, list):
for be in self.component:
be.close()
# no typology
elif isinstance(self.component, dict):
for be in self.component.values():
be.close()
elif self.component is None:
pass
else:
raise TypeError('component must be dict or list, received %s' % type(self.component))

def _copy_from(self, x: 'CompositionalEncoder'):
if isinstance(self.component, list):
for be1, be2 in zip(self.component, x.component):
be1._copy_from(be2)
elif isinstance(self.component, dict):
for k, v in self.component.items():
v._copy_from(x.component[k])
else:
raise TypeError('component must be dict or list, received %s' % type(self.component))

@classmethod
def to_yaml(cls, representer, data):
tmp = super()._dump_instance_to_yaml(data)
tmp['component'] = data.component
return representer.represent_mapping('!' + cls.__name__, tmp)

@classmethod
def from_yaml(cls, constructor, node):
obj, data, from_dump = super()._get_instance_from_yaml(constructor, node)
if not from_dump and 'component' in data:
obj.component = lambda: data['component']
return obj


class PipelineEncoder(CompositionalEncoder):
class PipelineEncoder(CompositionalTrainableBase):
def encode(self, data: Any, *args, **kwargs) -> Any:
if not self.component:
raise NotImplementedError
Expand Down
1 change: 1 addition & 0 deletions gnes/preprocessor/__init__.py
Expand Up @@ -20,6 +20,7 @@

_cls2file_map = {
'BasePreprocessor': 'base',
'PipelinePreprocessor': 'base',
'TextPreprocessor': 'text.simple',
'BaseImagePreprocessor': 'image.base',
'BaseTextPreprocessor': 'text.base',
Expand Down
18 changes: 17 additions & 1 deletion gnes/preprocessor/base.py
Expand Up @@ -21,7 +21,7 @@
import numpy as np
from PIL import Image

from ..base import TrainableBase
from ..base import TrainableBase, CompositionalTrainableBase
from ..proto import gnes_pb2, array2blob


Expand All @@ -38,6 +38,22 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
doc.doc_type = self.doc_type


class PipelinePreprocessor(CompositionalTrainableBase):
def apply(self, doc: 'gnes_pb2.Document') -> None:
if not self.component:
raise NotImplementedError
for be in self.component:
be.apply(doc)

def train(self, data, *args, **kwargs):
if not self.component:
raise NotImplementedError
for idx, be in enumerate(self.component):
be.train(data, *args, **kwargs)
if idx + 1 < len(self.component):
data = be.apply(data, *args, **kwargs)


class BaseUnaryPreprocessor(BasePreprocessor):

def __init__(self, doc_type: int, *args, **kwargs):
Expand Down
6 changes: 3 additions & 3 deletions gnes/preprocessor/video/ffmpeg.py
Expand Up @@ -25,7 +25,7 @@
class FFmpegPreprocessor(BaseVideoPreprocessor):

def __init__(self,
frame_size: str = "192*168",
frame_size: str = '192*168',
duplicate_rm: bool = True,
use_phash_weight: bool = False,
phash_thresh: int = 5,
Expand All @@ -48,8 +48,8 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
frames = get_video_frames(
doc.raw_bytes,
s=self.frame_size,
vsync=self._ffmpeg_kwargs.get("vsync", "vfr"),
vf=self._ffmpeg_kwargs.get("vf", "select=eq(pict_type\\,I)"))
vsync=self._ffmpeg_kwargs.get('vsync', 'vfr'),
vf=self._ffmpeg_kwargs.get('vf', 'select=eq(pict_type\\,I)'))

# remove dupliated key frames by phash value
if self.duplicate_rm:
Expand Down
8 changes: 4 additions & 4 deletions gnes/preprocessor/video/shotdetect.py
Expand Up @@ -26,9 +26,9 @@ class ShotDetectPreprocessor(BaseVideoPreprocessor):
store_args_kwargs = True

def __init__(self,
frame_size: str = "192*168",
descriptor: str = "block_hsv_histogram",
distance_metric: str = "bhattacharya",
frame_size: str = '192*168',
descriptor: str = 'block_hsv_histogram',
distance_metric: str = 'bhattacharya',
*args,
**kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -47,7 +47,7 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
frames = get_video_frames(
doc.raw_bytes,
s=self.frame_size,
vsync="vfr",
vsync='vfr',
vf='select=eq(pict_type\\,I)')

descriptors = []
Expand Down
45 changes: 45 additions & 0 deletions tests/test_pipelinepreprocess.py
@@ -0,0 +1,45 @@
import os
import unittest

from gnes.preprocessor.base import BasePreprocessor, PipelinePreprocessor
from gnes.proto import gnes_pb2


class P1(BasePreprocessor):
def apply(self, doc: 'gnes_pb2.Document'):
doc.doc_id += 1


class P2(BasePreprocessor):
def apply(self, doc: 'gnes_pb2.Document'):
doc.doc_id *= 3


class TestPartition(unittest.TestCase):
def setUp(self):
self.dirname = os.path.dirname(__file__)
self.p3_name = 'pipe-p12'
self.yml_dump_path = os.path.join(self.dirname, '%s.yml' % self.p3_name)
self.bin_dump_path = os.path.join(self.dirname, '%s.bin' % self.p3_name)

def tearDown(self):
if os.path.exists(self.yml_dump_path):
os.remove(self.yml_dump_path)
if os.path.exists(self.bin_dump_path):
os.remove(self.bin_dump_path)

def test_pipelinepreproces(self):
p3 = PipelinePreprocessor()
p3.component = lambda: [P1(), P2()]
d = gnes_pb2.Document()
d.doc_id = 1
p3.apply(d)
self.assertEqual(d.doc_id, 6)

p3.name = self.p3_name
p3.dump_yaml()
p3.dump()

p4 = BasePreprocessor.load_yaml(p3.yaml_full_path)
p4.apply(d)
self.assertEqual(d.doc_id, 21)

0 comments on commit 52f87c7

Please sign in to comment.