add model metadata #96

Merged: 6 commits, Mar 22, 2024
backend/extraction/parsing.py (0 additions, 21 deletions)

@@ -1,7 +1,6 @@
 """Convert binary input to blobs and parse them using the appropriate parser."""
 from __future__ import annotations

-import io
 from typing import BinaryIO, List

 from fastapi import HTTPException
@@ -10,7 +9,6 @@
 from langchain.document_loaders.parsers.txt import TextParser
 from langchain_community.document_loaders import Blob
 from langchain_core.documents import Document
-from pdfminer.pdfpage import PDFPage

 HANDLERS = {
     "application/pdf": PDFMinerParser(),
@@ -28,7 +26,6 @@
 SUPPORTED_MIMETYPES = sorted(HANDLERS.keys())

 MAX_FILE_SIZE_MB = 10  # in MB
-MAX_PAGES = 50  # for PDFs


 def _guess_mimetype(file_bytes: bytes) -> str:
@@ -54,13 +51,6 @@ def _get_file_size_in_mb(data: BinaryIO) -> float:
     return file_size_in_mb


-def _get_pdf_page_count(file_bytes: bytes) -> int:
-    """Get the number of pages in a PDF file."""
-    file_stream = io.BytesIO(file_bytes)
-    pages = PDFPage.get_pages(file_stream)
-    return sum(1 for _ in pages)
-
-
 # PUBLIC API

 MIMETYPE_BASED_PARSER = MimeTypeBasedParser(
@@ -83,17 +73,6 @@ def convert_binary_input_to_blob(data: BinaryIO) -> Blob:
     mimetype = _guess_mimetype(file_data)
     file_name = data.name

-    if mimetype == "application/pdf":
-        number_of_pages = _get_pdf_page_count(file_data)
-        if number_of_pages > MAX_PAGES:
-            raise HTTPException(
-                status_code=413,
-                detail=(
-                    f"PDF has too many pages: {number_of_pages}, "
-                    f"exceeding the maximum of {MAX_PAGES}."
-                ),
-            )
-
     return Blob.from_data(
         data=file_data,
         path=file_name,
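Removing the page cap here does not leave large inputs unbounded: the updated tests below patch server.extraction_runnable.settings.MAX_CHUNKS, which suggests the limit now applies to the number of chunks after text splitting rather than to PDF page counts at parse time. A minimal sketch of the truncation behavior the tests imply; the Settings class and cap_chunks helper are hypothetical stand-ins, and only the MAX_CHUNKS name comes from the test patches:

from dataclasses import dataclass


@dataclass
class Settings:
    MAX_CHUNKS: int = 1  # the test patches this value to 1


settings = Settings()


def cap_chunks(chunks: list[str]) -> list[str]:
    """Keep only the first MAX_CHUNKS chunks (hypothetical helper)."""
    return chunks[: settings.MAX_CHUNKS]


# With the splitter patched to return ["a", "b"], the endpoint responds
# 200 with {"data": ["a"]}: the overflow chunk is dropped, not rejected.
assert cap_chunks(["a", "b"]) == ["a"]

Truncating instead of rejecting means oversized documents now succeed with partial results, which is consistent with the updated test expecting a 200 rather than a 413.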
backend/server/models.py (26 additions, 16 deletions)

@@ -10,22 +10,32 @@ def get_supported_models():
     """Get models according to environment secrets."""
     models = {}
     if "OPENAI_API_KEY" in os.environ:
-        models["gpt-3.5-turbo"] = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
-        models["gpt-4-0125-preview"] = ChatOpenAI(
-            model="gpt-4-0125-preview", temperature=0
-        )
+        models["gpt-3.5-turbo"] = {
+            "chat_model": ChatOpenAI(model="gpt-3.5-turbo", temperature=0),
+            "description": "GPT-3.5 Turbo",
+        }
+        models["gpt-4-0125-preview"] = {
+            "chat_model": ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
+            "description": "GPT-4 0125 Preview",
+        }
     if "FIREWORKS_API_KEY" in os.environ:
-        models["fireworks"] = ChatFireworks(
-            model="accounts/fireworks/models/firefunction-v1",
-            temperature=0,
-        )
+        models["fireworks"] = {
+            "chat_model": ChatFireworks(
+                model="accounts/fireworks/models/firefunction-v1",
+                temperature=0,
+            ),
+            "description": "Fireworks Firefunction-v1",
+        }
     if "TOGETHER_API_KEY" in os.environ:
-        models["together-ai-mistral-8x7b-instruct-v0.1"] = ChatOpenAI(
-            base_url="https://api.together.xyz/v1",
-            api_key=os.environ["TOGETHER_API_KEY"],
-            model="mistralai/Mixtral-8x7B-Instruct-v0.1",
-            temperature=0,
-        )
+        models["together-ai-mistral-8x7b-instruct-v0.1"] = {
+            "chat_model": ChatOpenAI(
+                base_url="https://api.together.xyz/v1",
+                api_key=os.environ["TOGETHER_API_KEY"],
+                model="mistralai/Mixtral-8x7B-Instruct-v0.1",
+                temperature=0,
+            ),
+            "description": "Mixtral 8x7B Instruct v0.1 (Together AI)",
+        }

     return models

@@ -47,7 +57,7 @@ def get_chunk_size(model_name: str) -> int:
 def get_model(model_name: Optional[str] = None) -> BaseChatModel:
     """Get the model."""
     if model_name is None:
-        return SUPPORTED_MODELS[DEFAULT_MODEL]
+        return SUPPORTED_MODELS[DEFAULT_MODEL]["chat_model"]
     else:
         supported_model_names = list(SUPPORTED_MODELS.keys())
         if model_name not in supported_model_names:
@@ -56,4 +66,4 @@ def get_model(model_name: Optional[str] = None) -> BaseChatModel:
                 f"Supported models: {supported_model_names}"
             )
         else:
-            return SUPPORTED_MODELS[model_name]
+            return SUPPORTED_MODELS[model_name]["chat_model"]
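Each SUPPORTED_MODELS entry now pairs the chat model instance with a human-readable description, which is why get_model indexes into "chat_model" above. A minimal sketch of how a caller might surface the new metadata, assuming only the dict shape introduced in this diff; the list_models helper is illustrative and not part of the PR:

from typing import Any, Dict, List


def list_models(supported_models: Dict[str, Dict[str, Any]]) -> List[Dict[str, str]]:
    """Flatten the metadata dicts into name/description pairs for a UI."""
    return [
        {"name": name, "description": entry["description"]}
        for name, entry in supported_models.items()
    ]


# e.g. list_models(SUPPORTED_MODELS) ->
# [{"name": "gpt-3.5-turbo", "description": "GPT-3.5 Turbo"}, ...]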
backend/tests/unit_tests/api/test_api_extract.py (13 additions, 7 deletions)

@@ -132,6 +132,11 @@ async def test_extract_from_file() -> None:
     assert response.json() == {"data": ["This is a "]}


+@patch(
+    "server.extraction_runnable.extraction_runnable",
+    new=RunnableLambda(mock_extraction_runnable),
+)
+@patch("server.extraction_runnable.TokenTextSplitter", mock_text_splitter)
 async def test_extract_from_large_file() -> None:
     user_id = str(uuid4())
     headers = {"x-key": user_id}
@@ -167,22 +172,23 @@ async def test_extract_from_large_file() -> None:
     )
     assert response.status_code == 413

-    # Test page number constraint
+    # Test chunk count constraint
     with tempfile.NamedTemporaryFile(mode="w+t", delete=True) as f:
         f.write("This is a named temporary file.")
         f.seek(0)
         f.flush()
-        with patch(
-            "extraction.parsing._guess_mimetype", return_value="application/pdf"
-        ):
-            with patch("extraction.parsing._get_pdf_page_count", return_value=100):
+        with patch("server.extraction_runnable.settings.MAX_CHUNKS", 1):
+            with patch.object(
+                CharacterTextSplitter, "split_text", return_value=["a", "b"]
+            ):
                 response = await client.post(
                     "/extract",
                     data={
                         "extractor_id": extractor_id,
                         "mode": "entire_document",
                     },
-                    files={"file": f.name},
+                    files={"file": f},
                    headers=headers,
                 )
-                assert response.status_code == 413
+                assert response.status_code == 200
+                assert response.json() == {"data": ["a"]}
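One subtle fix above is files={"file": f.name} becoming files={"file": f}. HTTP test clients such as httpx generally treat a string value in the files mapping as the upload's raw content, so passing the path string would have uploaded the literal path text rather than the file's contents. A small illustration of the difference, assuming httpx-style files semantics; the path shown is a made-up example:

import io

# A string value is sent as the body itself; a file-like object is read.
wrong = {"file": "/tmp/tmpabc123"}               # uploads the text "/tmp/tmpabc123"
right = {"file": io.StringIO("file contents")}   # uploads "file contents"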