Skip to content
Permalink
Browse files

style: reformat code and optimize import

  • Loading branch information...
hanxiao committed Jul 19, 2019
1 parent 55d7792 commit d3347910b7048a8b41c7d3c11093fbf34ef9efe1
@@ -21,7 +21,6 @@
import numpy as np

from ..base import TrainableBase
from ..proto import gnes_pb2


class BaseEncoder(TrainableBase):
@@ -53,12 +53,13 @@ def fn_parser(self, layer: str) -> Callable:

if '(' not in layer and ')' not in layer:
# this is a shorthand syntax we need to add "(x)" at the end
layer = 'm.%s(x)'%layer
layer = 'm.%s(x)' % layer
else:
pass

def layer_fn(x, l, m, torch):
return eval(l)

return lambda x: layer_fn(x, layer, self.m, torch)

def forward(self, x):
@@ -93,4 +94,3 @@ def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
result_npy.append(encodes.data.cpu().numpy())

return np.array(result_npy, dtype=np.float32)

@@ -14,11 +14,13 @@
# limitations under the License.

from typing import List

import numpy as np
from gnes.helper import batch_iterator
from ..base import BaseImageEncoder
from PIL import Image

from ..base import BaseImageEncoder
from ...helper import batch_iterator


class CVAEEncoder(BaseImageEncoder):

@@ -56,7 +58,7 @@ def post_init(self):
def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
ret = []
img = [(np.array(Image.fromarray(im).resize((120, 120)),
dtype=np.float32)/255) for im in img]
dtype=np.float32) / 255) for im in img]
for _im in batch_iterator(img, self.batch_size):
_mean, _var = self.sess.run((self.mean, self.var),
feed_dict={self.inputs: _im})
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
import numpy as np
import tensorflow as tf


class CVAE(tf.keras.Model):
@@ -23,53 +23,53 @@ def __init__(self, latent_dim):
self.latent_dim = latent_dim
self.inference_net = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=(120, 120, 3)),
tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=(2, 2),
padding='SAME',
activation='relu'),
tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=(2, 2),
padding='SAME',
activation='relu'),
tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=(2, 2),
padding='SAME',
activation='relu'),
tf.keras.layers.Flatten(),
# No activation
tf.keras.layers.Dense(latent_dim + latent_dim),
tf.keras.layers.InputLayer(input_shape=(120, 120, 3)),
tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=(2, 2),
padding='SAME',
activation='relu'),
tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=(2, 2),
padding='SAME',
activation='relu'),
tf.keras.layers.Conv2D(
filters=32, kernel_size=3, strides=(2, 2),
padding='SAME',
activation='relu'),
tf.keras.layers.Flatten(),
# No activation
tf.keras.layers.Dense(latent_dim + latent_dim),
]
)

self.generative_net = tf.keras.Sequential(
[
tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
tf.keras.layers.Dense(units=15*15*32,
activation=tf.nn.relu),
tf.keras.layers.Reshape(target_shape=(15, 15, 32)),
tf.keras.layers.Conv2DTranspose(
filters=32,
kernel_size=3,
strides=(2, 2),
padding="SAME",
activation='relu'),
tf.keras.layers.Conv2DTranspose(
filters=32,
kernel_size=3,
strides=(2, 2),
padding="SAME",
activation='relu'),
tf.keras.layers.Conv2DTranspose(
filters=32,
kernel_size=3,
strides=(2, 2),
padding="SAME",
activation='relu'),
# No activation
tf.keras.layers.Conv2DTranspose(
filters=3, kernel_size=3, strides=(1, 1), padding="SAME"),
]
[
tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
tf.keras.layers.Dense(units=15 * 15 * 32,
activation=tf.nn.relu),
tf.keras.layers.Reshape(target_shape=(15, 15, 32)),
tf.keras.layers.Conv2DTranspose(
filters=32,
kernel_size=3,
strides=(2, 2),
padding="SAME",
activation='relu'),
tf.keras.layers.Conv2DTranspose(
filters=32,
kernel_size=3,
strides=(2, 2),
padding="SAME",
activation='relu'),
tf.keras.layers.Conv2DTranspose(
filters=32,
kernel_size=3,
strides=(2, 2),
padding="SAME",
activation='relu'),
# No activation
tf.keras.layers.Conv2DTranspose(
filters=3, kernel_size=3, strides=(1, 1), padding="SAME"),
]
)

def sample(self, eps=None):
@@ -14,10 +14,12 @@
# limitations under the License.

from typing import List

import numpy as np
from PIL import Image

from ..base import BaseImageEncoder
from ...helper import batching, batch_iterator
from PIL import Image


class TFInceptionEncoder(BaseImageEncoder):
@@ -63,7 +65,8 @@ def post_init(self):
def encode(self, img: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
ret = []
img = [(np.array(Image.fromarray(im).resize((self.inception_size_x,
self.inception_size_y)), dtype=np.float32) * 2 / 255. - 1.) for im in img]
self.inception_size_y)), dtype=np.float32) * 2 / 255. - 1.) for im
in img]
for _im in batch_iterator(img, self.batch_size):
_, end_points_ = self.sess.run((self.logits, self.end_points),
feed_dict={self.inputs: _im})
@@ -70,6 +70,5 @@ def encode(self, vecs: np.ndarray, *args, **kwargs) -> np.ndarray:
self._graph['ph_centroids']: self.centroids})
return tmp.astype(np.uint8)


def close(self):
self._sess.close()
@@ -66,4 +66,3 @@ def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
_pooled = pooling_np(_layer_data, self.pooling_strategy)
pooled_data.append(_pooled)
return np.array(pooled_data, dtype=np.float32)

@@ -109,7 +109,6 @@ def encode(self, text: List[str], *args, **kwargs) -> np.ndarray:
return output_tensor.numpy()



class GPT2Encoder(GPTEncoder):

def _get_token_ids(self, x):
@@ -20,9 +20,9 @@
from typing import List, Callable

import cv2
import imagehash
import numpy as np
from PIL import Image
import imagehash


def get_video_frames(buffer_data: bytes, image_format: str = "cv2",
@@ -73,7 +73,7 @@ def get_video_frames(buffer_data: bytes, image_format: str = "cv2",
def block_descriptor(image: "np.ndarray",
descriptor_fn: Callable,
num_blocks: int = 3) -> "np.ndarray":
h, w, _ = image.shape # find shape of image and channel
h, w, _ = image.shape # find shape of image and channel
block_h = int(np.ceil(h / num_blocks))
block_w = int(np.ceil(w / num_blocks))

@@ -91,7 +91,7 @@ def pyramid_descriptor(image: "np.ndarray",
max_level: int = 2) -> "np.ndarray":
descriptors = []
for level in range(max_level + 1):
num_blocks = 2**level
num_blocks = 2 ** level
descriptors.extend(block_descriptor(image, descriptor_fn, num_blocks))

return np.array(descriptors)
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

from ..base import BasePreprocessor
from ...proto import gnes_pb2
from typing import List


class BaseImagePreprocessor(BasePreprocessor):
@@ -35,4 +36,4 @@ def _get_all_chunks_weight(self, image_set: List['np.ndarray']) -> List[float]:
def _torch_transform(cls, image):
import torchvision.transforms as transforms
return transforms.Compose([transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])(image)
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])(image)
@@ -1,10 +1,12 @@
from .base import BaseImagePreprocessor
from ...proto import array2blob
from PIL import Image
import numpy as np
import io
import os

import numpy as np
from PIL import Image

from .base import BaseImagePreprocessor
from ...proto import array2blob


class SegmentPreprocessor(BaseImagePreprocessor):

@@ -81,7 +81,8 @@ def _get_all_sliding_window(self, img: 'np.ndarray') -> List['np.ndarray']:
writeable=False
)
expanded_input = expanded_input.reshape((-1, self.window_size, self.window_size, 3))
return [np.array(Image.fromarray(img).resize((self.target_img_size, self.target_img_size))) for img in expanded_input]
return [np.array(Image.fromarray(img).resize((self.target_img_size, self.target_img_size))) for img in
expanded_input]


class VanillaSlidingPreprocessor(BaseSlidingPreprocessor):
@@ -94,4 +95,3 @@ class WeightedSlidingPreprocessor(BaseSlidingPreprocessor):

def _get_all_chunks_weight(self, image_set: List['np.ndarray']) -> List[float]:
return FFmpegPreprocessor.pic_weight(image_set)

@@ -14,11 +14,12 @@
# limitations under the License.

from typing import List

import numpy as np

from .base import BaseVideoPreprocessor
from ...proto import gnes_pb2, array2blob
from ..helper import get_video_frames, phash_descriptor
from ...proto import gnes_pb2, array2blob


class FFmpegPreprocessor(BaseVideoPreprocessor):
@@ -16,9 +16,10 @@
# pylint: disable=low-comment-ratio

import numpy as np

from .base import BaseVideoPreprocessor
from ...proto import gnes_pb2, array2blob
from ..helper import get_video_frames, compute_descriptor, compare_descriptor
from ...proto import gnes_pb2, array2blob


class ShotDetectPreprocessor(BaseVideoPreprocessor):
@@ -66,7 +67,7 @@ def apply(self, doc: 'gnes_pb2.Document') -> None:
clt = KMeans(n_clusters=2)
clt.fit(dists)

#select which cluster includes shot frames
# select which cluster includes shot frames
big_center = np.argmax(clt.cluster_centers_)

shots = []
@@ -14,7 +14,7 @@
# limitations under the License.

# pylint: disable=low-comment-ratio
from typing import List, Optional, Generator
from typing import List, Generator

from ..base import TrainableBase
from ..proto import gnes_pb2, merge_routes
@@ -16,7 +16,7 @@
# pylint: disable=low-comment-ratio
from typing import List, Union

from .base import BaseService as BS, MessageHandler, BlockMessage
from .base import BaseService as BS, MessageHandler
from ..proto import gnes_pb2, array2blob, blob2array


@@ -135,8 +135,8 @@ def __init__(self, args):
self.logger = set_logger(self.__class__.__name__, args.verbose)
self.server = grpc.server(
futures.ThreadPoolExecutor(max_workers=args.max_concurrency),
options=[('grpc.max_send_message_length', args.max_send_size*1024*1024),
('grpc.max_receive_message_length', args.max_receive_size*1024*1024)])
options=[('grpc.max_send_message_length', args.max_send_size * 1024 * 1024),
('grpc.max_receive_message_length', args.max_receive_size * 1024 * 1024)])
self.logger.info('start a grpc server with %d workers' % args.max_concurrency)
gnes_pb2_grpc.add_GnesRPCServicer_to_server(GNESServicer(args), self.server)

@@ -17,7 +17,7 @@

import numpy as np

from .base import BaseService as BS, ComponentNotLoad, MessageHandler, ServiceError
from .base import BaseService as BS, MessageHandler, ServiceError
from ..proto import gnes_pb2, blob2array


@@ -15,7 +15,7 @@

# pylint: disable=low-comment-ratio

from .base import BaseService as BS, MessageHandler, ComponentNotLoad
from .base import BaseService as BS, MessageHandler
from ..proto import gnes_pb2


0 comments on commit d334791

Please sign in to comment.
You can’t perform that action at this time.