Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
from fastapi import FastAPI
from fastapi import FastAPI, responses, status
import controller
from data import data_type
from typing import List, Dict, Tuple
Expand All @@ -25,7 +25,7 @@
@app.get("/classification/recommend/{data_type}")
def recommendations(
data_type: str,
) -> Tuple[List[Dict[str, str]], int]:
) -> responses.JSONResponse:
recommends = [
### English ###
{
Expand Down Expand Up @@ -92,39 +92,44 @@ def recommendations(
},
]

return recommends, 200
return responses.JSONResponse(status_code=status.HTTP_200_OK, content=recommends)


@app.post("/classification/encode")
def encode_classification(request: data_type.Request) -> Tuple[int, str]:
def encode_classification(request: data_type.Request) -> responses.PlainTextResponse:
# session logic for threads in side
return controller.start_encoding_thread(request, "classification"), ""
status_code = controller.start_encoding_thread(request, "classification")

return responses.PlainTextResponse(status_code=status_code)


@app.post("/extraction/encode")
def encode_extraction(request: data_type.Request) -> Tuple[int, str]:
def encode_extraction(request: data_type.Request) -> responses.PlainTextResponse:
# session logic for threads in side
return controller.start_encoding_thread(request, "extraction"), ""
status_code = controller.start_encoding_thread(request, "extraction")
return responses.PlainTextResponse(status_code=status_code)


@app.delete("/delete/{project_id}/{embedding_id}")
def delete_embedding(project_id: str, embedding_id: str) -> Tuple[int, str]:
def delete_embedding(project_id: str, embedding_id: str) -> responses.PlainTextResponse:
session_token = general.get_ctx_token()
return_value = controller.delete_embedding(project_id, embedding_id)
status_code = controller.delete_embedding(project_id, embedding_id)
general.remove_and_refresh_session(session_token)
return return_value, ""
return responses.PlainTextResponse(status_code=status_code)


@app.post("/upload_tensor_data/{project_id}/{embedding_id}")
def upload_tensor_data(project_id: str, embedding_id: str) -> Tuple[int, str]:
def upload_tensor_data(
project_id: str, embedding_id: str
) -> responses.PlainTextResponse:
session_token = general.get_ctx_token()
controller.upload_embedding_as_file(project_id, embedding_id)
request_util.post_embedding_to_neural_search(project_id, embedding_id)
general.remove_and_refresh_session(session_token)
return 200, ""
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)


@app.put("/config_changed")
def config_changed() -> int:
def config_changed() -> responses.PlainTextResponse:
config_handler.refresh_config()
return 200
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
17 changes: 9 additions & 8 deletions controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
notification,
organization,
)
from fastapi import status
import pickle
import torch
import traceback
Expand Down Expand Up @@ -85,7 +86,7 @@ def get_docbins(
def start_encoding_thread(request: data_type.Request, embedding_type: str) -> int:
doc_ock.post_embedding_creation(request.user_id, request.config_string)
daemon.run(prepare_run_encoding, request, embedding_type)
return 200
return status.HTTP_200_OK


def prepare_run_encoding(request: data_type.Request, embedding_type: str) -> int:
Expand Down Expand Up @@ -215,7 +216,7 @@ def run_encoding(
send_project_update(
request.project_id, f"notification_created:{request.user_id}", True
)
return 422
return status.HTTP_422_UNPROCESSABLE_ENTITY
except ValueError:
embedding.update_embedding_state_failed(
request.project_id,
Expand All @@ -239,7 +240,7 @@ def run_encoding(
send_project_update(
request.project_id, f"notification_created:{request.user_id}", True
)
return 422
return status.HTTP_422_UNPROCESSABLE_ENTITY

if not embedder:
embedding.update_embedding_state_failed(
Expand Down Expand Up @@ -288,7 +289,7 @@ def run_encoding(
f"embedding:{embedding_id}:state:{enums.EmbeddingState.FAILED.value}",
)
doc_ock.post_embedding_failed(request.user_id, request.config_string)
return 422
return status.HTTP_422_UNPROCESSABLE_ENTITY

try:
record_ids, attribute_values_raw = record.get_attribute_data(
Expand Down Expand Up @@ -410,7 +411,7 @@ def run_encoding(
)
print(traceback.format_exc(), flush=True)
doc_ock.post_embedding_failed(request.user_id, request.config_string)
return 500
return status.HTTP_500_INTERNAL_SERVER_ERROR

if embedding.get(request.project_id, embedding_id):
for warning_type, idx_list in embedder.get_warnings().items():
Expand Down Expand Up @@ -484,7 +485,7 @@ def run_encoding(
doc_ock.post_embedding_finished(request.user_id, request.config_string)
general.commit()
general.remove_and_refresh_session(session_token)
return 200
return status.HTTP_200_OK


def delete_embedding(project_id: str, embedding_id: str) -> int:
Expand All @@ -494,12 +495,12 @@ def delete_embedding(project_id: str, embedding_id: str) -> int:
object_name = f"embedding_tensors_{embedding_id}.csv.bz2"

org_id = organization.get_id_by_project_id(project_id)
s3.delete_object(org_id, project_id + "/" + object_name)
s3.delete_object(org_id, f"{project_id}/{object_name}")
request_util.delete_embedding_from_neural_search(embedding_id)
pickle_path = os.path.join("/inference", project_id, f"embedder-{embedding_id}.pkl")
if os.path.exists(pickle_path):
os.remove(pickle_path)
return 200
return status.HTTP_200_OK


@param_throttle(seconds=5)
Expand Down
12 changes: 8 additions & 4 deletions data/doc_ock.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@ def _post_event(user_id: str, config_string: str, state: str) -> Any:
"State": state,
"Host": os.getenv("S3_ENDPOINT"),
}

response = requests.post(url, json=data)
if response.status_code == 200:
result, _ = response.json()
return result
else:

if response.status_code != 200:
raise Exception("Could not send data to Doc Ock")

if response.headers.get("content-type") == "application/json":
return response.json()
else:
return response.text
15 changes: 7 additions & 8 deletions util/config_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,21 @@ def __get_config() -> Dict[str, Any]:

def refresh_config():
response = requests.get(REQUEST_URL)
if response.status_code == 200:
global __config
__config = json.loads(json.loads(response.text))
daemon.run(invalidate_after, 3600) # one hour
else:
raise Exception(
if response.status_code != 200:
raise ValueError(
f"Config service cant be reached -- response.code{response.status_code}"
)
global __config
__config = response.json()
daemon.run(invalidate_after, 3600) # one hour


def get_config_value(
key: str, subkey: Optional[str] = None
) -> Union[str, Dict[str, str]]:
config = __get_config()
if key not in config:
raise Exception(f"Key {key} coudn't be found in config")
raise ValueError(f"Key {key} coudn't be found in config")
value = config[key]

if not subkey:
Expand All @@ -44,7 +43,7 @@ def get_config_value(
if isinstance(value, dict) and subkey in value:
return value[subkey]
else:
raise Exception(f"Subkey {subkey} coudn't be found in config[{key}]")
raise ValueError(f"Subkey {subkey} coudn't be found in config[{key}]")


def invalidate_after(sec: int) -> None:
Expand Down