@@ -1,12 +1,10 @@
# flake8: noqa

from openai.api_resources.abstract.api_resource import APIResource
from openai.api_resources.abstract.singleton_api_resource import SingletonAPIResource
from openai.api_resources.abstract.createable_api_resource import CreateableAPIResource
from openai.api_resources.abstract.updateable_api_resource import UpdateableAPIResource
from openai.api_resources.abstract.deletable_api_resource import DeletableAPIResource
from openai.api_resources.abstract.listable_api_resource import ListableAPIResource
from openai.api_resources.abstract.custom_method import custom_method
from openai.api_resources.abstract.nested_resource_class_methods import (
nested_resource_class_methods,
)
from openai.api_resources.abstract.updateable_api_resource import UpdateableAPIResource
@@ -14,16 +14,16 @@ def retrieve(cls, id, api_key=None, request_id=None, **params):
return instance

def refresh(self, request_id=None):
headers = util.populate_headers(request_id=request_id)
self.refresh_from(self.request("get", self.instance_url(), headers=headers))
self.refresh_from(
self.request("get", self.instance_url(), request_id=request_id)
)
return self

@classmethod
def class_url(cls):
if cls == APIResource:
raise NotImplementedError(
"APIResource is an abstract class. You should perform "
"actions on its subclasses (e.g. Charge, Customer)"
"APIResource is an abstract class. You should perform actions on its subclasses."
)
# Namespaces are separated in object names with periods (.) and in URLs
# with forward slashes (/), so replace the former with the latter.
@@ -54,7 +54,6 @@ def _static_request(
url_,
api_key=None,
api_base=None,
idempotency_key=None,
request_id=None,
api_version=None,
organization=None,
@@ -66,8 +65,9 @@ def _static_request(
organization=organization,
api_base=api_base,
)
headers = util.populate_headers(idempotency_key, request_id)
response, _, api_key = requestor.request(method_, url_, params, headers)
response, _, api_key = requestor.request(
method_, url_, params, request_id=request_id
)
return util.convert_to_openai_object(
response, api_key, api_version, organization
)
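
With `idempotency_key` and `util.populate_headers` dropped from `_static_request`, per-call tracing goes through the new `request_id` keyword, which the requestor turns into an `X-Request-Id` header (see the test added near the end of this diff). A minimal usage sketch; the model name and request id are placeholders:

```python
import openai

openai.api_key = "sk-..."  # placeholder key

# request_id is forwarded through requestor.request and sent as the
# X-Request-Id header, so the call can be correlated with server-side logs.
model = openai.Model.retrieve("ada", request_id="my-trace-id")
print(model.id)
```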
@@ -1,7 +1,5 @@
from __future__ import absolute_import, division, print_function

from openai.api_resources.abstract.api_resource import APIResource
from openai import api_requestor, util
from openai.api_resources.abstract.api_resource import APIResource


class CreateableAPIResource(APIResource):
@@ -12,7 +10,6 @@ def create(
cls,
api_key=None,
api_base=None,
idempotency_key=None,
request_id=None,
api_version=None,
organization=None,
@@ -25,8 +22,9 @@ def create(
organization=organization,
)
url = cls.class_url()
headers = util.populate_headers(idempotency_key, request_id)
response, _, api_key = requestor.request("post", url, params, headers)
response, _, api_key = requestor.request(
"post", url, params, request_id=request_id
)

return util.convert_to_openai_object(
response,

This file was deleted.

@@ -1,16 +1,12 @@
from urllib.parse import quote_plus

from openai import util
from openai.api_resources.abstract.api_resource import APIResource


class DeletableAPIResource(APIResource):
@classmethod
def _cls_delete(cls, sid, **params):
def delete(cls, sid, **params):
if isinstance(cls, APIResource):
raise ValueError(".delete may only be called as a class method now.")
url = "%s/%s" % (cls.class_url(), quote_plus(sid))
return cls._static_request("delete", url, **params)

@util.class_method_variant("_cls_delete")
def delete(self, **params):
self.refresh_from(self.request("delete", self.instance_url(), params))
return self
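
`DeletableAPIResource.delete` is now a plain class method (the `_cls_delete`/`class_method_variant` pairing is removed), so deletion is invoked on the class with an id rather than on an instance. A sketch using a placeholder file id, mirroring the CLI change later in this diff:

```python
import openai

openai.api_key = "sk-..."  # placeholder key

# Previously: openai.File(id="file-abc123").delete()
# Now: delete is called on the class and passed the resource id.
deleted = openai.File.delete("file-abc123")  # placeholder id
print(deleted)
```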
@@ -4,6 +4,7 @@

from openai import api_requestor, error, util
from openai.api_resources.abstract.api_resource import APIResource
from openai.openai_response import OpenAIResponse

MAX_TIMEOUT = 20

@@ -31,7 +32,6 @@ def create(
cls,
api_key=None,
api_base=None,
idempotency_key=None,
request_id=None,
api_version=None,
organization=None,
@@ -62,12 +62,12 @@ def create(
organization=organization,
)
url = cls.class_url(engine)
headers = util.populate_headers(idempotency_key, request_id)
response, _, api_key = requestor.request(
"post", url, params, headers, stream=stream
"post", url, params, stream=stream, request_id=request_id
)

if stream:
assert not isinstance(response, OpenAIResponse) # must be an iterator
return (
util.convert_to_openai_object(
line,
@@ -99,9 +99,7 @@ def instance_url(self):

if not isinstance(id, str):
raise error.InvalidRequestError(
"Could not determine which URL to request: %s instance "
"has invalid ID: %r, %s. ID should be of type `str` (or"
" `unicode`)" % (type(self).__name__, id, type(id)),
f"Could not determine which URL to request: {type(self).__name__} instance has invalid ID: {id}, {type(id)}. ID should be of type str.",
"id",
)
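
`EngineAPIResource.create` follows the same pattern: `idempotency_key` is gone, `request_id` and `stream` are passed straight to `requestor.request`, and a streamed response is asserted to be an iterator rather than an `OpenAIResponse`. A hedged sketch of both call styles; engine name, prompt, and request id are placeholders:

```python
import openai

openai.api_key = "sk-..."  # placeholder key

# Plain call; request_id travels to the API as the X-Request-Id header.
completion = openai.Completion.create(
    engine="ada", prompt="Say this is a test", max_tokens=5, request_id="trace-1"
)
print(completion.choices[0].text)

# Streaming call; the SDK yields OpenAIObject chunks from a generator.
for chunk in openai.Completion.create(
    engine="ada", prompt="Say this is a test", max_tokens=5, stream=True
):
    print(chunk)
```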

@@ -1,5 +1,3 @@
from __future__ import absolute_import, division, print_function

from openai import api_requestor, util
from openai.api_resources.abstract.api_resource import APIResource

@@ -19,15 +17,16 @@ def list(
api_base=None,
**params,
):
headers = util.populate_headers(request_id=request_id)
requestor = api_requestor.APIRequestor(
api_key,
api_base=api_base or cls.api_base(),
api_version=api_version,
organization=organization,
)
url = cls.class_url()
response, _, api_key = requestor.request("get", url, params, headers)
response, _, api_key = requestor.request(
"get", url, params, request_id=request_id
)
openai_object = util.convert_to_openai_object(
response, api_key, api_version, organization
)
@@ -28,7 +28,6 @@ def nested_resource_request(
method,
url,
api_key=None,
idempotency_key=None,
request_id=None,
api_version=None,
organization=None,
@@ -37,8 +36,9 @@ def nested_resource_request(
requestor = api_requestor.APIRequestor(
api_key, api_version=api_version, organization=organization
)
headers = util.populate_headers(idempotency_key, request_id)
response, _, api_key = requestor.request(method, url, params, headers)
response, _, api_key = requestor.request(
method, url, params, request_id=request_id
)
return util.convert_to_openai_object(
response, api_key, api_version, organization
)

This file was deleted.

@@ -1,6 +1,5 @@
from urllib.parse import quote_plus

from openai import util
from openai.api_resources.abstract.api_resource import APIResource


@@ -9,15 +8,3 @@ class UpdateableAPIResource(APIResource):
def modify(cls, sid, **params):
url = "%s/%s" % (cls.class_url(), quote_plus(sid))
return cls._static_request("post", url, **params)

def save(self, idempotency_key=None, request_id=None):
updated_params = self.serialize(None)
headers = util.populate_headers(idempotency_key, request_id)

if updated_params:
self.refresh_from(
self.request("post", self.instance_url(), updated_params, headers)
)
else:
util.logger.debug("Trying to save already saved object %r", self)
return self
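
With `save()` and its header plumbing removed from `UpdateableAPIResource`, updates go through the `modify` class method only, matching the CLI change to `openai.Engine.modify` later in this diff. A sketch with placeholder values:

```python
import openai

openai.api_key = "sk-..."  # placeholder key

# Previously:
#   engine = openai.Engine(id="my-engine")
#   engine.replicas = 2
#   engine.save()
# Now a single POST via the class method:
engine = openai.Engine.modify("my-engine", replicas=2)  # placeholder values
print(engine)
```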
@@ -3,7 +3,7 @@
from openai import util
from openai.api_resources.abstract import DeletableAPIResource, ListableAPIResource
from openai.api_resources.abstract.engine_api_resource import EngineAPIResource
from openai.error import TryAgain, InvalidRequestError
from openai.error import InvalidRequestError, TryAgain


class Completion(EngineAPIResource, ListableAPIResource, DeletableAPIResource):
@@ -1,10 +1,7 @@
import time

from openai import util
from openai.api_resources.abstract import (
ListableAPIResource,
UpdateableAPIResource,
)
from openai.api_resources.abstract import ListableAPIResource, UpdateableAPIResource
from openai.error import TryAgain


@@ -1,23 +1,22 @@
from __future__ import absolute_import, division, print_function
from typing import Optional

from openai.util import merge_dicts
from openai.openai_object import OpenAIObject
from openai.util import merge_dicts


class ErrorObject(OpenAIObject):
def refresh_from(
self,
values,
api_key=None,
partial=False,
api_version=None,
organization=None,
last_response=None,
response_ms: Optional[int] = None,
):
# Unlike most other API resources, the API will omit attributes in
# error objects when they have a null value. We manually set default
# values here to facilitate generic error handling.
values = merge_dicts({"message": None, "type": None}, values)
return super(ErrorObject, self).refresh_from(
values, api_key, partial, api_version, organization, last_response
values, api_key, api_version, organization, response_ms
)
@@ -1,34 +1,41 @@
from __future__ import absolute_import, division, print_function

import json
import os
from typing import cast

import openai
from openai import api_requestor, util
from openai.api_resources.abstract import (
DeletableAPIResource,
ListableAPIResource,
)
from openai.api_resources.abstract import DeletableAPIResource, ListableAPIResource


class File(ListableAPIResource, DeletableAPIResource):
OBJECT_NAME = "file"

@classmethod
def create(
cls, api_key=None, api_base=None, api_version=None, organization=None, **params
cls,
file,
purpose,
model=None,
api_key=None,
api_base=None,
api_version=None,
organization=None,
):
if purpose != "search" and model is not None:
raise ValueError("'model' is only meaningful if 'purpose' is 'search'")
requestor = api_requestor.APIRequestor(
api_key,
api_base=api_base or openai.file_api_base or openai.api_base,
api_base=api_base or openai.api_base,
api_version=api_version,
organization=organization,
)
url = cls.class_url()
supplied_headers = {"Content-Type": "multipart/form-data"}
response, _, api_key = requestor.request(
"post", url, params=params, headers=supplied_headers
)
# Set the filename on 'purpose' and 'model' to None so they are
# interpreted as form data.
files = [("file", file), ("purpose", (None, purpose))]
if model is not None:
files.append(("model", (None, model)))
response, _, api_key = requestor.request("post", url, files=files)
return util.convert_to_openai_object(
response, api_key, api_version, organization
)
@@ -39,17 +46,21 @@ def download(
):
requestor = api_requestor.APIRequestor(
api_key,
api_base=api_base or openai.file_api_base or openai.api_base,
api_base=api_base or openai.api_base,
api_version=api_version,
organization=organization,
)
url = f"{cls.class_url()}/{id}/content"
rbody, rcode, rheaders, _, _ = requestor.request_raw("get", url)
if not 200 <= rcode < 300:
result = requestor.request_raw("get", url)
if not 200 <= result.status_code < 300:
raise requestor.handle_error_response(
rbody, rcode, json.loads(rbody), rheaders, stream_error=False
result.content,
result.status_code,
json.loads(cast(bytes, result.content)),
result.headers,
stream_error=False,
)
return rbody
return result.content

@classmethod
def find_matching_files(
@@ -71,7 +82,7 @@ def find_matching_files(
)
all_files = cls.list(
api_key=api_key,
api_base=api_base or openai.file_api_base or openai.api_base,
api_base=api_base or openai.api_base,
api_version=api_version,
organization=organization,
).get("data", [])
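
`File.create` now takes explicit `file` and `purpose` arguments (with `model` accepted only when `purpose` is "search") and uploads them as multipart form data, and `download` reads status, headers, and content from the result object returned by `request_raw`. A usage sketch; the path and purpose are placeholders:

```python
import openai

openai.api_key = "sk-..."  # placeholder key

# Upload a JSONL file; purpose is now a required, explicit argument.
with open("train.jsonl", "rb") as f:  # placeholder path
    uploaded = openai.File.create(file=f, purpose="fine-tune")
print(uploaded.id)

# 'model' is only meaningful for search files:
# openai.File.create(file=f, purpose="fine-tune", model="ada")  -> ValueError

# Fetch the raw bytes of a stored file.
contents = openai.File.download(uploaded.id)
print(len(contents))
```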
@@ -1,11 +1,12 @@
from urllib.parse import quote_plus

from openai import api_requestor, util
from openai.api_resources.abstract import (
ListableAPIResource,
CreateableAPIResource,
ListableAPIResource,
nested_resource_class_methods,
)
from openai import api_requestor, util
from openai.openai_response import OpenAIResponse


@nested_resource_class_methods("event", operations=["list"])
@@ -18,8 +19,7 @@ def cancel(cls, id, api_key=None, request_id=None, **params):
extn = quote_plus(id)
url = "%s/%s/cancel" % (base, extn)
instance = cls(id, api_key, **params)
headers = util.populate_headers(request_id=request_id)
return instance.request("post", url, headers=headers)
return instance.request("post", url, request_id=request_id)

@classmethod
def stream_events(
@@ -42,11 +42,11 @@ def stream_events(
organization=organization,
)
url = "%s/%s/events?stream=true" % (base, extn)
headers = util.populate_headers(request_id=request_id)
response, _, api_key = requestor.request(
"get", url, params, headers=headers, stream=True
"get", url, params, stream=True, request_id=request_id
)

assert not isinstance(response, OpenAIResponse) # must be an iterator
return (
util.convert_to_openai_object(
line,
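
The fine-tune resource's `cancel` and `stream_events` also switch from `populate_headers` to the `request_id` keyword, and `stream_events` asserts it received an iterator. A sketch assuming this file defines `openai.FineTune` (the class name is not shown in the diff); the job id and request id are placeholders:

```python
import openai

openai.api_key = "sk-..."  # placeholder key
job_id = "ft-abc123"       # placeholder fine-tune job id

# Cancel a job; request_id is threaded through instance.request.
openai.FineTune.cancel(id=job_id, request_id="trace-2")

# Stream job events; the call yields OpenAIObject events from a generator.
for event in openai.FineTune.stream_events(id=job_id):
    print(event)
```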
@@ -1,7 +1,4 @@
from openai.api_resources.abstract import (
ListableAPIResource,
DeletableAPIResource,
)
from openai.api_resources.abstract import DeletableAPIResource, ListableAPIResource


class Model(ListableAPIResource, DeletableAPIResource):
@@ -5,6 +5,7 @@
import warnings

import openai
from openai.upload_progress import BufferReader
from openai.validators import (
apply_necessary_remediation,
apply_optional_remediation,
@@ -60,9 +61,7 @@ def get(cls, args):

@classmethod
def update(cls, args):
engine = openai.Engine(id=args.id)
engine.replicas = args.replicas
engine.save()
engine = openai.Engine.modify(args.id, replicas=args.replicas)
display(engine)

@classmethod
@@ -181,14 +180,12 @@ def create(cls, args):
class Model:
@classmethod
def get(cls, args):
resp = openai.Model.retrieve(
id=args.id,
)
resp = openai.Model.retrieve(id=args.id)
print(resp)

@classmethod
def delete(cls, args):
model = openai.Model(id=args.id).delete()
model = openai.Model.delete(args.id)
print(model)

@classmethod
@@ -200,10 +197,10 @@ def list(cls, args):
class File:
@classmethod
def create(cls, args):
with open(args.file, "rb") as file_reader:
buffer_reader = BufferReader(file_reader.read(), desc="Upload progress")
resp = openai.File.create(
file=open(args.file),
purpose=args.purpose,
model=args.model,
file=buffer_reader, purpose=args.purpose, model=args.model
)
print(resp)

@@ -214,7 +211,7 @@ def get(cls, args):

@classmethod
def delete(cls, args):
file = openai.File(id=args.id).delete()
file = openai.File.delete(args.id)
print(file)

@classmethod
@@ -1,5 +1,3 @@
from __future__ import absolute_import, division, print_function

import openai


@@ -66,7 +64,7 @@ def construct_error_object(self):
return None

return openai.api_resources.error_object.ErrorObject.construct_from(
self.json_body["error"], openai.api_key
self.json_body["error"], key=None
)


@@ -95,10 +93,6 @@ def __init__(
self.should_retry = should_retry


class IdempotencyError(OpenAIError):
pass


class InvalidRequestError(OpenAIError):
def __init__(
self,
@@ -138,6 +132,10 @@ class RateLimitError(OpenAIError):
pass


class ServiceUnavailableError(OpenAIError):
pass


class SignatureVerificationError(OpenAIError):
def __init__(self, message, sig_header, http_body=None):
super(SignatureVerificationError, self).__init__(message, http_body)
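
`IdempotencyError` disappears along with the idempotency machinery, and `ServiceUnavailableError` is added to the exception hierarchy. A sketch of handling it; engine and prompt are placeholders:

```python
import openai
from openai.error import RateLimitError, ServiceUnavailableError

openai.api_key = "sk-..."  # placeholder key

try:
    openai.Completion.create(engine="ada", prompt="ping", max_tokens=1)
except ServiceUnavailableError:
    # New in this release: the API reported it is temporarily unavailable.
    print("service unavailable, retry later")
except RateLimitError:
    print("rate limited, back off")
```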

This file was deleted.

This file was deleted.

This file was deleted.

@@ -1,5 +1,3 @@
from __future__ import absolute_import, division, print_function

from openai import api_resources
from openai.api_resources.experimental.completion_config import CompletionConfig

@@ -1,66 +1,32 @@
from __future__ import absolute_import, division, print_function

import datetime
import json
from copy import deepcopy
from typing import Optional

import openai
from openai import api_requestor, util


def _compute_diff(current, previous):
if isinstance(current, dict):
previous = previous or {}
diff = current.copy()
for key in set(previous.keys()) - set(diff.keys()):
diff[key] = ""
return diff
return current if current is not None else ""


def _serialize_list(array, previous):
array = array or []
previous = previous or []
params = {}

for i, v in enumerate(array):
previous_item = previous[i] if len(previous) > i else None
if hasattr(v, "serialize"):
params[str(i)] = v.serialize(previous_item)
else:
params[str(i)] = _compute_diff(v, previous_item)

return params
from openai.openai_response import OpenAIResponse


class OpenAIObject(dict):
api_base_override = None

class ReprJSONEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, datetime.datetime):
return api_requestor._encode_datetime(obj)
return super(OpenAIObject.ReprJSONEncoder, self).default(obj)

def __init__(
self,
id=None,
api_key=None,
api_version=None,
organization=None,
last_response=None,
response_ms: Optional[int] = None,
api_base=None,
engine=None,
**params,
):
super(OpenAIObject, self).__init__()

self._unsaved_values = set()
self._transient_values = set()
self._last_response = last_response
if response_ms is not None and not isinstance(response_ms, int):
raise TypeError(f"response_ms is a {type(response_ms).__name__}.")
self._response_ms = response_ms

self._retrieve_params = params
self._previous = None

object.__setattr__(self, "api_key", api_key)
object.__setattr__(self, "api_version", api_version)
@@ -72,14 +38,8 @@ def __init__(
self["id"] = id

@property
def last_response(self):
return self._last_response

def update(self, update_dict):
for k in update_dict:
self._unsaved_values.add(k)

return super(OpenAIObject, self).update(update_dict)
def response_ms(self) -> Optional[int]:
return self._response_ms

def __setattr__(self, k, v):
if k[0] == "_" or k in self.__dict__:
@@ -91,7 +51,6 @@ def __setattr__(self, k, v):
def __getattr__(self, k):
if k[0] == "_":
raise AttributeError(k)

try:
return self[k]
except KeyError as err:
@@ -110,37 +69,10 @@ def __setitem__(self, k, v):
"We interpret empty strings as None in requests."
"You may set %s.%s = None to delete the property" % (k, str(self), k)
)

# Allows for unpickling in Python 3.x
if not hasattr(self, "_unsaved_values"):
self._unsaved_values = set()

self._unsaved_values.add(k)

super(OpenAIObject, self).__setitem__(k, v)

def __getitem__(self, k):
try:
return super(OpenAIObject, self).__getitem__(k)
except KeyError as err:
if k in self._transient_values:
raise KeyError(
"%r. HINT: The %r attribute was set in the past. "
"It was then wiped when refreshing the object with "
"the result returned by OpenAI's API, probably as a "
"result of a save(). The attributes currently "
"available on this object are: %s"
% (k, k, ", ".join(list(self.keys())))
)
else:
raise err

def __delitem__(self, k):
super(OpenAIObject, self).__delitem__(k)

# Allows for unpickling in Python 3.x
if hasattr(self, "_unsaved_values") and k in self._unsaved_values:
self._unsaved_values.remove(k)
raise NotImplementedError("del is not supported")

# Custom unpickling method that uses `update` to update the dictionary
# without calling __setitem__, which would fail if any value is an empty
@@ -172,52 +104,40 @@ def construct_from(
api_version=None,
organization=None,
engine=None,
last_response=None,
response_ms: Optional[int] = None,
):
instance = cls(
values.get("id"),
api_key=key,
api_version=api_version,
organization=organization,
engine=engine,
last_response=last_response,
response_ms=response_ms,
)
instance.refresh_from(
values,
api_key=key,
api_version=api_version,
organization=organization,
last_response=last_response,
response_ms=response_ms,
)
return instance

def refresh_from(
self,
values,
api_key=None,
partial=False,
api_version=None,
organization=None,
last_response=None,
response_ms: Optional[int] = None,
):
self.api_key = api_key or getattr(values, "api_key", None)
self.api_version = api_version or getattr(values, "api_version", None)
self.organization = organization or getattr(values, "organization", None)
self._last_response = last_response or getattr(values, "_last_response", None)

# Wipe old state before setting new. This is useful for e.g.
# updating a customer, where there is no persistent card
# parameter. Mark those values which don't persist as transient
if partial:
self._unsaved_values = self._unsaved_values - set(values)
else:
removed = set(self.keys()) - set(values)
self._transient_values = self._transient_values | removed
self._unsaved_values = set()
self.clear()

self._transient_values = self._transient_values - set(values)
self._response_ms = response_ms or getattr(values, "_response_ms", None)

# Wipe old state before setting new.
self.clear()
for k, v in values.items():
super(OpenAIObject, self).__setitem__(
k, util.convert_to_openai_object(v, api_key, api_version, organization)
@@ -230,7 +150,14 @@ def api_base(cls):
return None

def request(
self, method, url, params=None, headers=None, stream=False, plain_old_data=False
self,
method,
url,
params=None,
headers=None,
stream=False,
plain_old_data=False,
request_id: Optional[str] = None,
):
if params is None:
params = self._retrieve_params
@@ -241,10 +168,11 @@ def request(
organization=self.organization,
)
response, stream, api_key = requestor.request(
method, url, params, headers, stream=stream
method, url, params, stream=stream, headers=headers, request_id=request_id
)

if stream:
assert not isinstance(response, OpenAIResponse) # must be an iterator
return (
util.convert_to_openai_object(
line,
@@ -284,7 +212,7 @@ def __repr__(self):

def __str__(self):
obj = self.to_dict_recursive()
return json.dumps(obj, sort_keys=True, indent=2, cls=self.ReprJSONEncoder)
return json.dumps(obj, sort_keys=True, indent=2)

def to_dict(self):
return dict(self)
@@ -305,27 +233,6 @@ def to_dict_recursive(self):
def openai_id(self):
return self.id

def serialize(self, previous):
params = {}
unsaved_keys = self._unsaved_values or set()
previous = previous or self._previous or {}

for k, v in self.items():
if k == "id" or (isinstance(k, str) and k.startswith("_")):
continue
elif isinstance(v, openai.api_resources.abstract.APIResource):
continue
elif hasattr(v, "serialize"):
child = v.serialize(previous.get(k, None))
if child != {}:
params[k] = child
elif k in unsaved_keys:
params[k] = _compute_diff(v, previous.get(k, None))
elif k == "additional_owners" and v is not None:
params[k] = _serialize_list(v, previous.get(k, None))

return params

# This class overrides __setitem__ to throw exceptions on inputs that it
# doesn't like. This can cause problems when we try to copy an object
# wholesale because some data that's returned from the API may not be valid
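
`OpenAIObject` loses the mutation-tracking and `serialize` machinery and gains a read-only `response_ms` property, populated from the `Openai-Processing-Ms` header via `convert_to_openai_object`. A sketch of reading it; engine and prompt are placeholders:

```python
import openai

openai.api_key = "sk-..."  # placeholder key

completion = openai.Completion.create(engine="ada", prompt="hello", max_tokens=1)

# Server-side processing time parsed from the Openai-Processing-Ms header,
# or None if the header was absent.
print(completion.response_ms)

# In-place deletion on response objects is no longer supported:
# del completion["id"]  -> NotImplementedError("del is not supported")
```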
@@ -1,25 +1,20 @@
from __future__ import absolute_import, division, print_function
from typing import Optional

import json

class OpenAIResponse:
def __init__(self, data, headers):
self._headers = headers
self.data = data

class OpenAIResponse(object):
def __init__(self, body, code, headers):
self.body = body
self.code = code
self.headers = headers
self.data = json.loads(body)
@property
def request_id(self) -> Optional[str]:
return self._headers.get("request-id")

@property
def idempotency_key(self):
try:
return self.headers["idempotency-key"]
except KeyError:
return None
def organization(self) -> Optional[str]:
return self._headers.get("OpenAI-Organization")

@property
def request_id(self):
try:
return self.headers["request-id"]
except KeyError:
return None
def response_ms(self) -> Optional[int]:
h = self._headers.get("Openai-Processing-Ms")
return None if h is None else int(h)
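
`OpenAIResponse` is reduced to decoded data plus the response headers, with `request_id`, `organization`, and `response_ms` exposed as properties. A minimal construction sketch; the header values are invented for illustration:

```python
from openai.openai_response import OpenAIResponse

resp = OpenAIResponse(
    data={"id": "model-xyz"},          # already-decoded JSON body
    headers={                          # placeholder header values
        "request-id": "req-123",
        "OpenAI-Organization": "org-abc",
        "Openai-Processing-Ms": "42",
    },
)

print(resp.request_id)    # "req-123"
print(resp.organization)  # "org-abc"
print(resp.response_ms)   # 42 (parsed to int)
```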

This file was deleted.

@@ -0,0 +1,27 @@
import json

import requests
from pytest_mock import MockerFixture

from openai import Model


def test_requestor_sets_request_id(mocker: MockerFixture) -> None:
# Fake out 'requests' and confirm that the X-Request-Id header is set.

got_headers = {}

def fake_request(self, *args, **kwargs):
nonlocal got_headers
got_headers = kwargs["headers"]
r = requests.Response()
r.status_code = 200
r.headers["content-type"] = "application/json"
r._content = json.dumps({}).encode("utf-8")
return r

mocker.patch("requests.sessions.Session.request", fake_request)
fake_request_id = "1234"
Model.retrieve("xxx", request_id=fake_request_id) # arbitrary API resource
got_request_id = got_headers.get("X-Request-Id")
assert got_request_id == fake_request_id
@@ -1,7 +1,8 @@
import openai
import io
import json

import openai


# FILE TESTS
def test_file_upload():
@@ -12,15 +13,18 @@ def test_file_upload():
assert result.purpose == "search"
assert "id" in result

result = openai.File.retrieve(id=result.id)
assert result.status == "uploaded"


# COMPLETION TESTS
def test_completions():
result = openai.Completion.create(prompt="This was a test", n=5, engine="davinci")
result = openai.Completion.create(prompt="This was a test", n=5, engine="ada")
assert len(result.choices) == 5


def test_completions_multiple_prompts():
result = openai.Completion.create(
prompt=["This was a test", "This was another test"], n=5, engine="davinci"
prompt=["This was a test", "This was another test"], n=5, engine="ada"
)
assert len(result.choices) == 10
@@ -0,0 +1,39 @@
import json
import subprocess
import time
from tempfile import NamedTemporaryFile

STILL_PROCESSING = "File is still processing. Check back later."


def test_file_cli() -> None:
contents = json.dumps({"prompt": "1 + 3 =", "completion": "4"}) + "\n"
with NamedTemporaryFile(suffix=".jsonl", mode="wb") as train_file:
train_file.write(contents.encode("utf-8"))
train_file.flush()
create_output = subprocess.check_output(
["openai", "api", "files.create", "-f", train_file.name, "-p", "fine-tune"]
)
file_obj = json.loads(create_output)
assert file_obj["bytes"] == len(contents)
file_id: str = file_obj["id"]
assert file_id.startswith("file-")
start_time = time.time()
while True:
delete_result = subprocess.run(
["openai", "api", "files.delete", "-i", file_id],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8",
)
if delete_result.returncode == 0:
break
elif STILL_PROCESSING in delete_result.stderr:
time.sleep(0.5)
if start_time + 60 < time.time():
raise RuntimeError("timed out waiting for file to become available")
continue
else:
raise RuntimeError(
f"delete failed: stdout={delete_result.stdout} stderr={delete_result.stderr}"
)
@@ -0,0 +1,30 @@
from tempfile import NamedTemporaryFile

import pytest

import openai
from openai import util


@pytest.fixture(scope="function")
def api_key_file():
saved_path = openai.api_key_path
try:
with NamedTemporaryFile(prefix="openai-api-key", mode="wt") as tmp:
openai.api_key_path = tmp.name
yield tmp
finally:
openai.api_key_path = saved_path


def test_openai_api_key_path(api_key_file) -> None:
print("sk-foo", file=api_key_file)
api_key_file.flush()
assert util.default_api_key() == "sk-foo"


def test_openai_api_key_path_with_malformed_key(api_key_file) -> None:
print("malformed-api-key", file=api_key_file)
api_key_file.flush()
with pytest.raises(ValueError, match="Malformed API key"):
util.default_api_key()
@@ -1,36 +1,23 @@
import functools
import hmac
import io
import logging
import os
import re
import sys
from urllib.parse import parse_qsl
from typing import Optional

import openai


OPENAI_LOG = os.environ.get("OPENAI_LOG")

logger = logging.getLogger("openai")

__all__ = [
"io",
"parse_qsl",
"log_info",
"log_debug",
"log_warn",
"dashboard_link",
"logfmt",
]


def is_appengine_dev():
return "APPENGINE_RUNTIME" in os.environ and "Dev" in os.environ.get(
"SERVER_SOFTWARE", ""
)


def _console_log_level():
if openai.log in ["debug", "info"]:
return openai.log
@@ -60,21 +47,6 @@ def log_warn(message, **params):
logger.warn(msg)


def _test_or_live_environment():
if openai.api_key is None:
return
match = re.match(r"sk_(live|test)_", openai.api_key)
if match is None:
return
return match.groups()[0]


def dashboard_link(request_id):
return "https://dashboard.openai.com/{env}/logs/{reqid}".format(
env=_test_or_live_environment() or "test", reqid=request_id
)


def logfmt(props):
def fmt(key, val):
# Handle case where val is a bytes or bytesarray
@@ -93,10 +65,6 @@ def fmt(key, val):
return " ".join([fmt(key, val) for key, val in sorted(props.items())])


def secure_compare(val1, val2):
return hmac.compare_digest(val1, val2)


def get_object_classes():
# This is here to avoid a circular dependency
from openai.object_classes import OBJECT_CLASSES
@@ -112,18 +80,13 @@ def convert_to_openai_object(
engine=None,
plain_old_data=False,
):
# If we get a OpenAIResponse, we'll want to return a
# OpenAIObject with the last_response field filled out with
# the raw API response information
openai_response = None
# If we get a OpenAIResponse, we'll want to return a OpenAIObject.

response_ms: Optional[int] = None
if isinstance(resp, openai.openai_response.OpenAIResponse):
# TODO: move this elsewhere
openai_response = resp
resp = openai_response.data
organization = (
openai_response.headers.get("OpenAI-Organization") or organization
)
organization = resp.organization
response_ms = resp.response_ms
resp = resp.data

if plain_old_data:
return resp
@@ -151,7 +114,7 @@ def convert_to_openai_object(
api_key,
api_version=api_version,
organization=organization,
last_response=openai_response,
response_ms=response_ms,
engine=engine,
)
else:
@@ -178,47 +141,22 @@ def convert_to_dict(obj):
return obj


def populate_headers(idempotency_key=None, request_id=None):
headers = {}
if idempotency_key is not None:
headers["Idempotency-Key"] = idempotency_key
if request_id is not None:
headers["X-Request-Id"] = request_id
if openai.debug:
headers["OpenAI-Debug"] = "true"

return headers


def merge_dicts(x, y):
z = x.copy()
z.update(y)
return z


class class_method_variant(object):
def __init__(self, class_method_name):
self.class_method_name = class_method_name

def __call__(self, method):
self.method = method
return self

def __get__(self, obj, objtype=None):
@functools.wraps(self.method)
def _wrapper(*args, **kwargs):
if obj is not None:
# Method was called as an instance method, e.g.
# instance.method(...)
return self.method(obj, *args, **kwargs)
elif len(args) > 0 and isinstance(args[0], objtype):
# Method was called as a class method with the instance as the
# first argument, e.g. Class.method(instance, ...) which in
# Python is the same thing as calling an instance method
return self.method(args[0], *args[1:], **kwargs)
else:
# Method was called as a class method, e.g. Class.method(...)
class_method = getattr(objtype, self.class_method_name)
return class_method(*args, **kwargs)

return _wrapper
def default_api_key() -> str:
if openai.api_key_path:
with open(openai.api_key_path, "rt") as k:
api_key = k.read().strip()
if not api_key.startswith("sk-"):
raise ValueError(f"Malformed API key in {openai.api_key_path}.")
return api_key
elif openai.api_key is not None:
return openai.api_key
else:
raise openai.error.AuthenticationError(
"No API key provided. You can set your API key in code using 'openai.api_key = <API-KEY>', or you can set the environment variable OPENAI_API_KEY=<API-KEY>). If your API key is stored in a file, you can point the openai module at it with 'openai.api_key_path = <PATH>'. You can generate API keys in the OpenAI web interface. See https://onboard.openai.com for details, or email support@openai.com if you have any questions."
)
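
`populate_headers`, `class_method_variant`, and the dashboard helpers are removed from `util`; API-key resolution is centralized in `default_api_key`, which checks `openai.api_key_path` first and validates the `sk-` prefix. A sketch of both configuration routes; the key and path are placeholders:

```python
import openai
from openai import util

# Route 1: set the key directly (api_key_path takes precedence, so keep it unset).
openai.api_key_path = None
openai.api_key = "sk-your-key-here"          # placeholder key
assert util.default_api_key() == openai.api_key

# Route 2: point the module at a file that holds the key.
openai.api_key_path = "/etc/openai/api_key"  # placeholder path
# default_api_key() now reads that file and raises
# ValueError("Malformed API key ...") unless its contents start with "sk-".
```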
@@ -1,9 +1,9 @@
import os
import sys
import pandas as pd
import numpy as np
from typing import Any, Callable, NamedTuple, Optional

from typing import NamedTuple, Optional, Callable, Any
import numpy as np
import pandas as pd


class Remediation(NamedTuple):
@@ -1 +1 @@
VERSION = "0.10.5"
VERSION = "0.11.0"
@@ -20,8 +20,8 @@
"pandas-stubs>=1.1.0.11", # Needed for type hints for mypy
"openpyxl>=3.0.7", # Needed for CLI fine-tuning data preparation tool xlsx format
],
extras_require={"dev": ["black==20.8b1", "pytest==6.*"]},
python_requires=">=3.6",
extras_require={"dev": ["black~=21.6b0", "pytest==6.*"]},
python_requires=">=3.7.1",
scripts=["bin/openai"],
packages=find_packages(exclude=["tests", "tests.*"]),
package_data={