In [9]:
"""Wrapper around custom APIs."""
import asyncio
from functools import partial
from typing import Any, Dict, List, Mapping, Optional

import requests
from pydantic import Extra, root_validator

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens

# Task identifiers this wrapper is documented to support.
VALID_TASKS = ("text2text-generation", "text-generation")

class HuggingFaceEndpoint(LLM):
    """Wrapper around a HuggingFace Inference Endpoint reached over HTTP.

    Only supports `text-generation` and `text2text-generation` for now.

    The prompt is POSTed as JSON to ``endpoint_url`` with a Bearer token, and
    the ``generated_text`` of the first returned item is used as the completion.

    Example:
        .. code-block:: python

            llm = HuggingFaceEndpoint(
                endpoint_url="https://api/completion-endpoint/",
                token="<api-token>",
                model_kwargs={"max_new_tokens": 100},
            )
            response = llm("Tell me a joke.")
    """

    endpoint_url: str
    """Endpoint URL to use for completion."""
    token: str
    """API token sent as ``Authorization: Bearer <token>`` with every request."""
    task: Optional[str] = None
    """Task to call the model with. Should be a task that returns `generated_text`."""
    model_kwargs: Optional[dict] = None
    """Key word arguments to pass to the model (sent as request ``parameters``)."""
    headers: Optional[dict] = None
    """Endpoint specific headers, sent in addition to the Authorization header."""
    timeout: Optional[float] = None
    """Seconds to wait for the endpoint; ``None`` waits indefinitely."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    # NOTE: `task` is not validated against VALID_TASKS and is never sent to
    # the endpoint; it is only reported through `_identifying_params`.

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "endpoint_url": self.endpoint_url,
            "task": self.task,
            "model_kwargs": self.model_kwargs or {},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "huggingface_endpoint"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Call out to the HuggingFace Inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If the HTTP request fails, the endpoint reports an
                error, or the response does not contain `generated_text`.

        Example:
            .. code-block:: python

                response = llm("Tell me a joke.")
        """
        # Default generation budget; anything in `model_kwargs` overrides it.
        parameters = {"max_new_tokens": 250, **(self.model_kwargs or {})}
        payload = {"inputs": prompt, "parameters": parameters}

        # Caller-supplied headers go first so the token-derived Authorization
        # header always wins.
        request_headers = {
            **(self.headers or {}),
            "Authorization": f"Bearer {self.token}",
        }

        try:
            response = requests.post(
                self.endpoint_url,
                headers=request_headers,
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error raised by inference endpoint: {e}") from e

        response_data = response.json()

        # An error payload is a dict, a successful one is a list; check for the
        # error before indexing so the real message is not masked by KeyError.
        if isinstance(response_data, dict) and "error" in response_data:
            raise ValueError(
                f"Error raised by inference API: {response_data['error']}"
            )

        try:
            text = response_data[0]["generated_text"]
        except (KeyError, IndexError, TypeError) as e:
            raise ValueError(
                f"Unexpected response from inference endpoint: {response_data}"
            ) from e

        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> str:
        """Call out to the inference endpoint without blocking the event loop.

        The blocking `_call` is run in the loop's default executor.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            The string generated by the model.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, partial(self._call, prompt, stop))

In [10]:

import os

# Build the wrapper from the endpoint URL and API token held in the environment.
llm = HuggingFaceEndpoint(
    endpoint_url=os.environ.get("HUGGINGFACEHUB_ENDPOINT"),
    token=os.environ.get("HUGGINGFACEHUB_API_TOKEN"),
)

In [11]:
# Send a sample prompt through the wrapper; the recorded reply follows.
llm("Tell me a joke.")

'\nWhy did the tomato turn red? Because it saw the salad dressing!'