add model metadata #96

Merged · 6 commits · Mar 22, 2024

Changes from 4 commits
22 changes: 1 addition & 21 deletions backend/extraction/parsing.py
@@ -1,7 +1,6 @@
 """Convert binary input to blobs and parse them using the appropriate parser."""
 from __future__ import annotations
 
-import io
 from typing import BinaryIO, List
 
 from fastapi import HTTPException
@@ -10,7 +9,6 @@
 from langchain.document_loaders.parsers.txt import TextParser
 from langchain_community.document_loaders import Blob
 from langchain_core.documents import Document
-from pdfminer.pdfpage import PDFPage
 
 HANDLERS = {
     "application/pdf": PDFMinerParser(),
@@ -28,7 +26,7 @@
 SUPPORTED_MIMETYPES = sorted(HANDLERS.keys())
 
 MAX_FILE_SIZE_MB = 10  # in MB
-MAX_PAGES = 50  # for PDFs
+MAX_CHUNK_COUNT = 50
 
 
 def _guess_mimetype(file_bytes: bytes) -> str:
@@ -54,13 +52,6 @@ def _get_file_size_in_mb(data: BinaryIO) -> float:
     return file_size_in_mb
 
 
-def _get_pdf_page_count(file_bytes: bytes) -> int:
-    """Get the number of pages in a PDF file."""
-    file_stream = io.BytesIO(file_bytes)
-    pages = PDFPage.get_pages(file_stream)
-    return sum(1 for _ in pages)
-
-
 # PUBLIC API
 
 MIMETYPE_BASED_PARSER = MimeTypeBasedParser(
@@ -83,17 +74,6 @@ def convert_binary_input_to_blob(data: BinaryIO) -> Blob:
     mimetype = _guess_mimetype(file_data)
     file_name = data.name
 
-    if mimetype == "application/pdf":
-        number_of_pages = _get_pdf_page_count(file_data)
-        if number_of_pages > MAX_PAGES:
-            raise HTTPException(
-                status_code=413,
-                detail=(
-                    f"PDF has too many pages: {number_of_pages}, "
-                    f"exceeding the maximum of {MAX_PAGES}."
-                ),
-            )
-
     return Blob.from_data(
         data=file_data,
         path=file_name,
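The PDF-specific page cap (MAX_PAGES) is removed here in favor of a format-agnostic chunk cap; MAX_CHUNK_COUNT is enforced after text splitting in backend/server/extraction_runnable.py below.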
3 changes: 2 additions & 1 deletion backend/server/api/configurables.py
@@ -4,7 +4,7 @@
 from fastapi import APIRouter
 from typing_extensions import TypedDict
 
-from extraction.parsing import MAX_FILE_SIZE_MB, SUPPORTED_MIMETYPES
+from extraction.parsing import MAX_CHUNK_COUNT, MAX_FILE_SIZE_MB, SUPPORTED_MIMETYPES
 from server.models import SUPPORTED_MODELS
 
 router = APIRouter(
@@ -29,4 +29,5 @@ def get() -> ConfigurationResponse:
         "available_models": sorted(SUPPORTED_MODELS),
         "accepted_mimetypes": SUPPORTED_MIMETYPES,
         "max_file_size_mb": MAX_FILE_SIZE_MB,
+        "max_chunk_count": MAX_CHUNK_COUNT,
     }
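For reference, a minimal sketch of the configuration payload a client would now receive; the route prefix is elided in this diff, and everything except max_chunk_count is illustrative:

# Hypothetical shape of the configuration response after this PR; the
# model list and mimetypes vary by deployment and are illustrative only.
# The new key is "max_chunk_count", sourced from MAX_CHUNK_COUNT (50).
config_response = {
    "available_models": ["gpt-3.5-turbo", "gpt-4-0125-preview"],
    "accepted_mimetypes": ["application/pdf", "text/plain"],
    "max_file_size_mb": 10,
    "max_chunk_count": 50,
}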
6 changes: 6 additions & 0 deletions backend/server/extraction_runnable.py
@@ -14,6 +14,7 @@
 from typing_extensions import TypedDict
 
 from db.models import Example, Extractor
+from extraction.parsing import MAX_CHUNK_COUNT
 from extraction.utils import update_json_schema
 from server.models import DEFAULT_MODEL, get_chunk_size, get_model
 from server.validators import validate_json_schema
@@ -191,6 +192,11 @@ async def extract_entire_document(
         model_name=DEFAULT_MODEL,
     )
     texts = text_splitter.split_text(content)
+    if len(texts) > MAX_CHUNK_COUNT:
+        raise HTTPException(
+            status_code=413,
+            detail=f"Text exceeds the maximum limit of {MAX_CHUNK_COUNT} chunks.",
+        )
     extraction_requests = [
         ExtractRequest(
             text=text,

Inline comment from the PR author (collaborator) on the new chunk-count check:

@eyurtsev what's your opinion on

  1. raising an error, as we do here;
  2. truncating (e.g., proceeding with the first N chunks) and propagating that information back to the user. If we did that, would we need to add metadata to the extraction response? lmk what you think (can do this in a separate PR too).
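For comparison, a minimal sketch of the truncation alternative (option 2) floated in the comment above; the helper name and metadata keys are hypothetical, not part of this PR:

from typing import Dict, List, Tuple


def truncate_chunks(
    texts: List[str], max_chunks: int
) -> Tuple[List[str], Dict[str, object]]:
    """Hypothetical option 2: keep the first N chunks and report truncation."""
    truncated = len(texts) > max_chunks
    kept = texts[:max_chunks] if truncated else texts
    # Metadata the extraction response could carry so the client knows
    # that input was dropped.
    return kept, {"truncated": truncated, "chunks_processed": len(kept)}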
42 changes: 26 additions & 16 deletions backend/server/models.py
@@ -10,22 +10,32 @@ def get_supported_models():
     """Get models according to environment secrets."""
     models = {}
     if "OPENAI_API_KEY" in os.environ:
-        models["gpt-3.5-turbo"] = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
-        models["gpt-4-0125-preview"] = ChatOpenAI(
-            model="gpt-4-0125-preview", temperature=0
-        )
+        models["gpt-3.5-turbo"] = {
+            "chat_model": ChatOpenAI(model="gpt-3.5-turbo", temperature=0),
+            "description": "GPT-3.5 Turbo",
+        }
+        models["gpt-4-0125-preview"] = {
+            "chat_model": ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
+            "description": "GPT-4 0125 Preview",
+        }
     if "FIREWORKS_API_KEY" in os.environ:
-        models["fireworks"] = ChatFireworks(
-            model="accounts/fireworks/models/firefunction-v1",
-            temperature=0,
-        )
+        models["fireworks"] = {
+            "chat_model": ChatFireworks(
+                model="accounts/fireworks/models/firefunction-v1",
+                temperature=0,
+            ),
+            "description": "Fireworks Firefunction-v1",
+        }
     if "TOGETHER_API_KEY" in os.environ:
-        models["together-ai-mistral-8x7b-instruct-v0.1"] = ChatOpenAI(
-            base_url="https://api.together.xyz/v1",
-            api_key=os.environ["TOGETHER_API_KEY"],
-            model="mistralai/Mixtral-8x7B-Instruct-v0.1",
-            temperature=0,
-        )
+        models["together-ai-mistral-8x7b-instruct-v0.1"] = {
+            "chat_model": ChatOpenAI(
+                base_url="https://api.together.xyz/v1",
+                api_key=os.environ["TOGETHER_API_KEY"],
+                model="mistralai/Mixtral-8x7B-Instruct-v0.1",
+                temperature=0,
+            ),
+            "description": "Mixtral 8x7B Instruct v0.1 (Together AI)",
+        }
 
     return models
 
@@ -47,7 +57,7 @@ def get_chunk_size(model_name: str) -> int:
 def get_model(model_name: Optional[str] = None) -> BaseChatModel:
     """Get the model."""
     if model_name is None:
-        return SUPPORTED_MODELS[DEFAULT_MODEL]
+        return SUPPORTED_MODELS[DEFAULT_MODEL]["chat_model"]
     else:
         supported_model_names = list(SUPPORTED_MODELS.keys())
         if model_name not in supported_model_names:
@@ -56,4 +66,4 @@ def get_model(model_name: Optional[str] = None) -> BaseChatModel:
                 f"Supported models: {supported_model_names}"
             )
         else:
-            return SUPPORTED_MODELS[model_name]
+            return SUPPORTED_MODELS[model_name]["chat_model"]
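Each registry entry is now a dict carrying the chat model plus human-readable metadata. A minimal sketch of how a caller might consume the new field; the describe_models helper is hypothetical, not part of this PR:

from typing import Dict


def describe_models(supported_models: Dict[str, dict]) -> Dict[str, str]:
    """Map each supported model name to its human-readable description."""
    return {name: entry["description"] for name, entry in supported_models.items()}


# With OPENAI_API_KEY set, this would yield e.g.:
# {"gpt-3.5-turbo": "GPT-3.5 Turbo", "gpt-4-0125-preview": "GPT-4 0125 Preview"}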
30 changes: 16 additions & 14 deletions backend/tests/unit_tests/api/test_api_extract.py
@@ -132,6 +132,11 @@ async def test_extract_from_file() -> None:
     assert response.json() == {"data": ["This is a "]}
 
 
+@patch(
+    "server.extraction_runnable.extraction_runnable",
+    new=RunnableLambda(mock_extraction_runnable),
+)
+@patch("server.extraction_runnable.TokenTextSplitter", mock_text_splitter)
 async def test_extract_from_large_file() -> None:
     user_id = str(uuid4())
     headers = {"x-key": user_id}
@@ -167,22 +172,19 @@ async def test_extract_from_large_file() -> None:
     )
     assert response.status_code == 413
 
-    # Test page number constraint
+    # Test chunk count constraint
    with tempfile.NamedTemporaryFile(mode="w+t", delete=True) as f:
         f.write("This is a named temporary file.")
         f.seek(0)
         f.flush()
-        with patch(
-            "extraction.parsing._guess_mimetype", return_value="application/pdf"
-        ):
-            with patch("extraction.parsing._get_pdf_page_count", return_value=100):
-                response = await client.post(
-                    "/extract",
-                    data={
-                        "extractor_id": extractor_id,
-                        "mode": "entire_document",
-                    },
-                    files={"file": f.name},
-                    headers=headers,
-                )
+        with patch("server.extraction_runnable.MAX_CHUNK_COUNT", 0):
+            response = await client.post(
+                "/extract",
+                data={
+                    "extractor_id": extractor_id,
+                    "mode": "entire_document",
+                },
+                files={"file": f},
+                headers=headers,
+            )
     assert response.status_code == 413
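Note: patching server.extraction_runnable.MAX_CHUNK_COUNT to 0 forces any nonempty split over the limit, so the test exercises the new 413 path without the PDF-specific mocks it previously needed.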