Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions openai/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Tuple,
Union,
overload,
Any,
)
from urllib.parse import urlencode, urlsplit, urlunsplit

Expand All @@ -29,7 +30,7 @@
import openai
from openai import error, util, version
from openai.openai_response import OpenAIResponse
from openai.util import ApiType
from openai.util import ApiType, to_key_val_list

TIMEOUT_SECS = 600
MAX_CONNECTION_RETRIES = 2
Expand Down Expand Up @@ -569,12 +570,8 @@ async def arequest_raw(
)

if files:
# TODO: Use `aiohttp.MultipartWriter` to create the multipart form data here.
# For now we use the private `requests` method that is known to have worked so far.
data, content_type = requests.models.RequestEncodingMixin._encode_files( # type: ignore
files, data
)
headers["Content-Type"] = content_type
data = _aiohttp_encode_formdata(files, data)

request_kwargs = {
"method": method,
"url": abs_url,
Expand Down Expand Up @@ -685,6 +682,38 @@ def _interpret_response_line(
return resp


def _aiohttp_encode_formdata(files: Any, data: Any) -> aiohttp.FormData:
if not files:
raise ValueError("Files must be provided.")
elif isinstance(data, (str, bytes, bytearray)):
raise ValueError("Data must not be a string.")

form_data = aiohttp.FormData(charset="utf-8")
fields = to_key_val_list(data or {})
files = to_key_val_list(files or {})

for name, file_metadata in files:
content_type: Optional[str] = None
if len(file_metadata) == 2:
file_name, file_content = file_metadata
elif len(file_metadata) == 3:
file_name, file_content, content_type = file_metadata
else:
raise ValueError(f"The file named {name} has an invalid payload.")

form_data.add_field(
name,
file_content,
filename=file_name,
content_type=content_type,
)

for field_name, value in fields:
form_data.add_field(field_name, value)

return form_data


@asynccontextmanager
async def aiohttp_session() -> AsyncIterator[aiohttp.ClientSession]:
user_set_session = openai.aiosession.get()
Expand Down
16 changes: 15 additions & 1 deletion openai/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
import os
import re
import sys
from collections.abc import Mapping
from enum import Enum
from typing import Optional
from typing import Optional, Any, List, Tuple

import openai

Expand Down Expand Up @@ -173,6 +174,19 @@ def merge_dicts(x, y):
return z


def to_key_val_list(value: Any) -> Optional[List[Tuple[str, Any]]]:
if value is None:
return None

if isinstance(value, (str, bytes, bool, int)):
raise ValueError("cannot encode objects that are not 2-tuples")

if isinstance(value, Mapping):
value = value.items()

return list(value)


def default_api_key() -> str:
if openai.api_key_path:
with open(openai.api_key_path, "rt") as k:
Expand Down