Skip to content
This repository has been archived by the owner on Jan 18, 2023. It is now read-only.

Commit

Permalink
Merge pull request #42 from milvus-io/split-files
Browse files Browse the repository at this point in the history
Refine utils file structure
  • Loading branch information
czhen-zilliz committed Oct 13, 2021
2 parents 5beb1c2 + 837d675 commit f632925
Show file tree
Hide file tree
Showing 5 changed files with 371 additions and 361 deletions.
73 changes: 73 additions & 0 deletions milvus_cli/Fs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from Types import ParameterException
import os


def readCsvFile(path='', withCol=True):
if not path or not path[-4:] == '.csv':
raise ParameterException('Path is empty or target file is not .csv')
fileSize = os.stat(path).st_size
if fileSize >= 512000000:
raise ParameterException(
'File is too large! Only allow csv files less than 512MB.')
from csv import reader
from json import JSONDecodeError
import click
try:
result = {'columns': [], 'data': []}
with click.open_file(path, 'r') as csv_file:
click.echo(f'Opening csv file({fileSize} bytes)...')
csv_reader = reader(csv_file, delimiter=',')
# For progressbar, transform it to list.
rows = list(csv_reader)
line_count = 0
with click.progressbar(rows, label='Reading csv rows...', show_percent=True) as bar:
# for row in csv_reader:
for row in bar:
if withCol and line_count == 0:
result['columns'] = row
line_count += 1
else:
formatRowForData(row, result['data'])
line_count += 1
click.echo(f'''Column names are {result['columns']}''')
click.echo(f'Processed {line_count} lines.')
except FileNotFoundError as fe:
raise ParameterException(f'FileNotFoundError {str(fe)}')
except UnicodeDecodeError as ue:
raise ParameterException(f'UnicodeDecodeError {str(ue)}')
except JSONDecodeError as je:
raise ParameterException(f'JSONDecodeError {str(je)}')
else:
return result


# For readCsvFile formatting data.
def formatRowForData(row=[], data=[]):
from json import loads
# init data with empty list
if not data:
for _in in range(len(row)):
data.append([])
for idx, val in enumerate(row):
formattedVal = loads(val)
data[idx].append(formattedVal)


def writeCsvFile(path, rows, headers=[]):
if not path:
raise ParameterException(f'Path should not be empty')
from csv import writer
import click
try:
with click.open_file(path, 'w+') as csv_file:
csv_writer = writer(csv_file, delimiter=',')
if headers:
csv_writer.writerow(headers)
line_count = 0
with click.progressbar(rows, label='Writing csv rows...', show_percent=True) as bar:
for row in bar:
csv_writer.writerow(row)
line_count += 1
click.echo(f'Processed {line_count} lines.')
except Exception as e:
raise ParameterException(f'Export csv file error! {str(e)}')
118 changes: 118 additions & 0 deletions milvus_cli/Types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from functools import reduce


class ParameterException(Exception):
"Custom Exception for parameters checking."

def __init__(self, msg):
self.msg = msg

def __str__(self):
return str(self.msg)


class ConnectException(Exception):
"Custom Exception for milvus connection."

def __init__(self, msg):
self.msg = msg

def __str__(self):
return str(self.msg)


FiledDataTypes = [
"BOOL",
"INT8",
"INT16",
"INT32",
"INT64",
"FLOAT",
"DOUBLE",
"STRING",
"BINARY_VECTOR",
"FLOAT_VECTOR"
]

IndexTypes = [
"FLAT",
"IVF_FLAT",
"IVF_SQ8",
"IVF_PQ",
"RNSG",
"HNSW",
# "NSG",
"ANNOY",
# "RHNSW_FLAT",
# "RHNSW_PQ",
# "RHNSW_SQ",
# "BIN_FLAT",
# "BIN_IVF_FLAT"
]

IndexParams = [
"nlist",
"m",
"nbits",
"M",
"efConstruction",
"n_trees",
"PQM",
]

IndexTypesMap = {
"FLAT": {
"index_building_parameters": [],
"search_parameters": ["metric_type"],
},
"IVF_FLAT": {
"index_building_parameters": ["nlist"],
"search_parameters": ["nprobe"],
},
"IVF_SQ8": {
"index_building_parameters": ["nlist"],
"search_parameters": ["nprobe"],
},
"IVF_PQ": {
"index_building_parameters": ["nlist", "m", "nbits"],
"search_parameters": ["nprobe"],
},
"RNSG": {
"index_building_parameters": ["out_degree", "candidate_pool_size", "search_length", "knng"],
"search_parameters": ["search_length"],
},
"HNSW": {
"index_building_parameters": ["M", "efConstruction"],
"search_parameters": ["ef"],
},
"ANNOY": {
"index_building_parameters": ["n_trees"],
"search_parameters": ["search_k"],
},
}

DupSearchParams = reduce(
lambda x, y: x+IndexTypesMap[y]['search_parameters'], IndexTypesMap.keys(), [])
SearchParams = list(dict.fromkeys(DupSearchParams))

MetricTypes = [
"L2",
"IP",
"HAMMING",
"TANIMOTO"
]

DataTypeByNum = {
0: 'NONE',
1: 'BOOL',
2: 'INT8',
3: 'INT16',
4: 'INT32',
5: 'INT64',
10: 'FLOAT',
11: 'DOUBLE',
20: 'STRING',
100: 'BINARY_VECTOR',
101: 'FLOAT_VECTOR',
999: 'UNKNOWN',
}
169 changes: 169 additions & 0 deletions milvus_cli/Validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from Types import ParameterException
from Types import FiledDataTypes, IndexTypes, IndexTypesMap, SearchParams, MetricTypes
from Fs import readCsvFile


def validateParamsByCustomFunc(customFunc, errMsg, *params):
try:
customFunc(*params)
except Exception as e:
raise ParameterException(f"{errMsg}")


def validateCollectionParameter(collectionName, primaryField, fields):
if not collectionName:
raise ParameterException('Missing collection name.')
if not primaryField:
raise ParameterException('Missing primary field.')
if not fields:
raise ParameterException('Missing fields.')
fieldNames = []
for field in fields:
fieldList = field.split(':')
if not (len(fieldList) == 3):
raise ParameterException(
'Field should contain three paremeters and concat by ":".')
[fieldName, fieldType, fieldData] = fieldList
fieldNames.append(fieldName)
if fieldType not in FiledDataTypes:
raise ParameterException(
'Invalid field data type, should be one of {}'.format(str(FiledDataTypes)))
if fieldType in ['BINARY_VECTOR', 'FLOAT_VECTOR']:
try:
int(fieldData)
except ValueError as e:
raise ParameterException("""Vector's dim should be int.""")
# Dedup field name.
newNames = list(set(fieldNames))
if not (len(newNames) == len(fieldNames)):
raise ParameterException('Field names are duplicated.')
if primaryField not in fieldNames:
raise ParameterException(
"""Primary field name doesn't exist in input fields.""")


def validateIndexParameter(indexType, metricType, params):
if indexType not in IndexTypes:
raise ParameterException(
'Invalid index type, should be one of {}'.format(str(IndexTypes)))
if metricType not in MetricTypes:
raise ParameterException(
'Invalid index metric type, should be one of {}'.format(str(MetricTypes)))
# if not params:
# raise ParameterException('Missing params')
paramNames = []
buildingParameters = IndexTypesMap[indexType]['index_building_parameters']
for param in params:
paramList = param.split(':')
if not (len(paramList) == 2):
raise ParameterException(
'Params should contain two paremeters and concat by ":".')
[paramName, paramValue] = paramList
paramNames.append(paramName)
if paramName not in buildingParameters:
raise ParameterException(
'Invalid index param, should be one of {}'.format(str(buildingParameters)))
try:
int(paramValue)
except ValueError as e:
raise ParameterException("""Index param's value should be int.""")
# Dedup field name.
newNames = list(set(paramNames))
if not (len(newNames) == len(paramNames)):
raise ParameterException('Index params are duplicated.')


def validateSearchParams(data, annsField, metricType, params, limit, expr, partitionNames, timeout, roundDecimal, hasIndex=True):
import json
result = {}
# Validate data
try:
if '.csv' in data:
csvData = readCsvFile(data, withCol=False)
result['data'] = csvData['data'][0]
else:
result['data'] = json.loads(
data.replace('\'', '').replace('\"', ''))
except Exception as e:
raise ParameterException(
'Format(list[list[float]]) "Data" error! {}'.format(str(e)))
# Validate annsField
if not annsField:
raise ParameterException('annsField is empty!')
result['anns_field'] = annsField
if hasIndex:
# Validate metricType
if metricType not in MetricTypes:
raise ParameterException(
'Invalid index metric type, should be one of {}'.format(str(MetricTypes)))
# Validate params
paramDict = {}
if type(params) == str:
paramsList = params.replace(' ', '').split(',')
else:
paramsList = params
for param in paramsList:
if not param:
continue
paramList = param.split(':')
if not (len(paramList) == 2):
raise ParameterException(
'Params should contain two paremeters and concat by ":".')
[paramName, paramValue] = paramList
if paramName not in SearchParams:
raise ParameterException(
'Invalid search parameter, should be one of {}'.format(str(SearchParams)))
try:
paramDict[paramName] = int(paramValue)
except ValueError as e:
raise ParameterException(
"""Search parameter's value should be int.""")
result['param'] = {"metric_type": metricType}
if paramDict.keys():
result['param']['params'] = paramDict
else:
result['param'] = {}
# Validate limit
try:
result['limit'] = int(limit)
except Exception as e:
raise ParameterException(
'Format(int) "limit" error! {}'.format(str(e)))
# Validate expr
result['expr'] = expr
# Validate partitionNames
if partitionNames:
try:
result['partition_names'] = partitionNames.replace(
' ', '').split(',')
except Exception as e:
raise ParameterException(
'Format(list[str]) "partitionNames" error! {}'.format(str(e)))
# Validate timeout
if timeout:
result['timeout'] = float(timeout)
if roundDecimal:
result['round_decimal'] = int(roundDecimal)
return result


def validateQueryParams(expr, partitionNames, outputFields, timeout):
result = {}
if not expr:
raise ParameterException('expr is empty!')
if ' in ' not in expr:
raise ParameterException(
'expr only accepts "<field_name> in [<min>,<max>]"!')
result['expr'] = expr
if not outputFields:
result['output_fields'] = None
else:
nameList = outputFields.replace(' ', '').split(',')
result['output_fields'] = nameList
if not partitionNames:
result['partition_names'] = None
else:
nameList = partitionNames.replace(' ', '').split(',')
result['partition_names'] = nameList
result['timeout'] = float(timeout) if timeout else None
return result
16 changes: 9 additions & 7 deletions milvus_cli/scripts/milvus_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
currentdir = os.path.dirname(os.path.realpath(__file__))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)
from utils import PyOrm, Completer
from utils import getPackageVersion, readCsvFile
from utils import validateParamsByCustomFunc, validateCollectionParameter, validateIndexParameter, validateSearchParams, validateQueryParams
from utils import ParameterException, ConnectException
from utils import MetricTypes, IndexParams, SearchParams, IndexTypesMap, IndexTypes
from utils import PyOrm, Completer, getPackageVersion
from Fs import readCsvFile
from Validation import validateParamsByCustomFunc, validateCollectionParameter, validateIndexParameter, validateSearchParams, validateQueryParams
from Types import ParameterException, ConnectException
from Types import MetricTypes, IndexTypesMap, IndexTypes


pass_context = click.make_pass_decorator(PyOrm, ensure=True)

Expand Down Expand Up @@ -506,7 +507,7 @@ def search(obj):
The names of partitions to search(split by "," if multiple) ['_default'] []: _default
timeout []:
Example-3(collection has no index):
Collection name (car, car2): car
Expand Down Expand Up @@ -548,7 +549,8 @@ def search(obj):
else:
metricType = ''
params = []
roundDecimal = click.prompt('The specified number of decimal places of returned distance', default=-1, type=int)
roundDecimal = click.prompt(
'The specified number of decimal places of returned distance', default=-1, type=int)
limit = click.prompt(
'The max number of returned record, also known as topk', default=None, type=int)
expr = click.prompt(
Expand Down

0 comments on commit f632925

Please sign in to comment.