Skip to content
Permalink
Browse files

fix(proto): fix version check in recv message

  • Loading branch information...
hanxiao committed Sep 16, 2019
1 parent 458bf91 commit 8828535ca10516101ff04c0ba29f3df2bb07e90f
Showing with 37 additions and 32 deletions.
  1. +34 −29 gnes/client/cli.py
  2. +3 −3 gnes/proto/__init__.py
@@ -17,8 +17,7 @@
import sys
import time
import zipfile
from math import ceil
from typing import List
from typing import List, Generator

from termcolor import colored

@@ -29,25 +28,25 @@
class CLIClient(GrpcClient):
def __init__(self, args):
super().__init__(args)
getattr(self, self.args.mode)(self.read_all())
getattr(self, self.args.mode)()
self.close()

def train(self, all_bytes: List[bytes]):
with ProgressBar(all_bytes, self.args.batch_size, task_name=self.args.mode) as p_bar:
for _ in self._stub.StreamCall(RequestGenerator.train(all_bytes,
def train(self):
with ProgressBar(task_name=self.args.mode) as p_bar:
for _ in self._stub.StreamCall(RequestGenerator.train(self.bytes_generator,
doc_id_start=self.args.start_doc_id,
batch_size=self.args.batch_size)):
p_bar.update()

def index(self, all_bytes: List[bytes]):
with ProgressBar(all_bytes, self.args.batch_size, task_name=self.args.mode) as p_bar:
for _ in self._stub.StreamCall(RequestGenerator.index(all_bytes,
def index(self):
with ProgressBar(task_name=self.args.mode) as p_bar:
for _ in self._stub.StreamCall(RequestGenerator.index(self.bytes_generator,
doc_id_start=self.args.start_doc_id,
batch_size=self.args.batch_size)):
p_bar.update()

def query(self, all_bytes: List[bytes]):
for idx, q in enumerate(all_bytes):
def query(self):
for idx, q in enumerate(self.bytes_generator):
for req in RequestGenerator.query(q, request_id_start=idx, top_k=self.args.top_k):
resp = self._stub.Call(req)
self.query_callback(req, resp)
@@ -77,45 +76,51 @@ def read_all(self) -> List[bytes]:

return all_bytes

@property
def bytes_generator(self) -> Generator[bytes]:
if self.args.txt_file:
all_bytes = (v.encode() for v in self.args.txt_file)
elif self.args.image_zip_file:
zipfile_ = zipfile.ZipFile(self.args.image_zip_file)
all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist())
elif self.args.video_zip_file:
zipfile_ = zipfile.ZipFile(self.args.video_zip_file)
all_bytes = (zipfile_.open(v).read() for v in zipfile_.namelist())
else:
raise AttributeError('--txt_file, --image_zip_file, --video_zip_file one must be given')

return all_bytes


class ProgressBar:
def __init__(self, all_bytes: List[bytes], batch_size: int, bar_len: int = 20, task_name: str = ''):
self.all_bytes_len = [len(v) for v in all_bytes]
self.batch_size = batch_size
self.total_batch = ceil(len(self.all_bytes_len) / self.batch_size)
def __init__(self, bar_len: int = 20, task_name: str = ''):
self.bar_len = bar_len
self.task_name = task_name

def update(self):
if self.num_batch > self.total_batch - 1:
return
sys.stdout.write('\r')
elapsed = time.perf_counter() - self.start_time
elapsed_str = colored('elapsed', 'yellow')
speed_str = colored('speed', 'yellow')
estleft_str = colored('left', 'yellow')
self.num_batch += 1
percent = self.num_batch / self.total_batch
num_bytes = sum(self.all_bytes_len[((self.num_batch - 1) * self.batch_size):(self.num_batch * self.batch_size)])
self.num_bars += 1
if self.num_bars > self.bar_len:
self.num_bars -= self.bar_len
sys.stdout.write('\n')
sys.stdout.write(
'{:>10} [{:<{}}] {:3.0f}% {:>8}: {:3.1f}s {:>8}: {:3.1f} bytes/s {:3.1f} batch/s {:>8}: {:3.1f}s'.format(
'{:>10} [{:<{}}] {:3.0f}% {:>8}: {:3.1f}s {:>8}: {:3.1f} batch/s'.format(
colored(self.task_name, 'cyan'),
colored('=' * int(self.bar_len * percent), 'green'),
colored('=' * self.num_bars, 'green'),
self.bar_len + 9,
percent * 100,
elapsed_str,
elapsed,
speed_str,
num_bytes / elapsed,
self.num_batch / elapsed,
estleft_str,
(self.total_batch - self.num_batch) / ((self.num_batch + 0.0001) / elapsed)
self.num_bars / elapsed,
))
sys.stdout.flush()

def __enter__(self):
self.start_time = time.perf_counter()
self.num_batch = -1
self.num_bars = -1
sys.stdout.write('\n')
self.update()
return self
@@ -15,7 +15,7 @@

import ctypes
import random
from typing import List
from typing import List, Iterator
from typing import Optional

import numpy as np
@@ -30,7 +30,7 @@

class RequestGenerator:
@staticmethod
def index(data: List[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Document.TEXT,
def index(data: Iterator[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Document.TEXT,
doc_id_start: int = 0, request_id_start: int = 0,
random_doc_id: bool = False,
*args, **kwargs):
@@ -49,7 +49,7 @@ def index(data: List[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Docum
request_id_start += 1

@staticmethod
def train(data: List[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Document.TEXT,
def train(data: Iterator[bytes], batch_size: int = 0, doc_type: int = gnes_pb2.Document.TEXT,
doc_id_start: int = 0, request_id_start: int = 0,
random_doc_id: bool = False,
*args, **kwargs):

0 comments on commit 8828535

Please sign in to comment.
You can’t perform that action at this time.