In [1]:
from typing import Any, Dict, Generator, List, Optional
import requests
from langchain.llms.base import LLM

class TextGenerationWebUI(LLM):
    """LangChain ``LLM`` wrapper for the text-generation-webui HTTP API.

    Generation settings below are sent as the JSON body of
    ``POST {base_url}/api/v1/generate``. Helper methods expose the
    ``/api/v1/model`` and ``/api/v1/token-count`` endpoints. Every API call
    raises ``requests.HTTPError`` on an error status instead of silently
    returning ``None``.
    """

    max_new_tokens: int = 200
    """maximum number of tokens to generate."""

    temperature: float = 0.7
    """Primary factor to control randomness of outputs. 0 = deterministic (only the most likely token is used). Higher value = more randomness."""

    top_p: float = 1.0
    """If not set to 1, select tokens with probabilities adding up to less than this number. Higher value = higher range of possible random results."""

    top_k: int = 40
    """Similar to top_p, but select instead only the top_k most likely tokens. Higher value = higher range of possible random results."""

    typical_p: float = 1.0
    """If not set to 1, select only tokens that are at least this much more likely to appear than random tokens, given the prior text."""

    repetition_penalty: float = 1.2
    """Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition."""

    encoder_repetition_penalty: float = 1.0
    """Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge."""

    no_repeat_ngram_size: int = 0
    """If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases."""

    min_length: int = 0
    """Minimum generation length in tokens."""

    do_sample: bool = True
    """Forwarded to the server as ``do_sample``; presumably toggles sampling versus greedy decoding."""

    seed: int = -1
    """seed: -1 for random"""

    penalty_alpha: float = 0.0
    """Contrastive search"""

    num_beams: int = 1
    """Beam search (uses a lot of VRAM)"""

    length_penalty: float = 1.0
    """Beam search length penalty"""

    early_stopping: bool = False
    """Beam search early-stopping flag, forwarded to the server as ``early_stopping``."""

    truncation_length: int = 2048
    """The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048."""

    stop: Optional[List[str]] = None
    """Default custom stopping strings; a ``stop`` argument passed per call takes precedence."""

    add_bos_token: bool = True
    """Add the bos_token to the beginning of prompts. Disabling this can make the replies more creative."""

    ban_eos_token: bool = False
    """Ban the eos_token. Forces the model to never end the generation prematurely."""

    skip_special_tokens: bool = True
    """Skip special tokens. Some specific models need this unset."""

    api_host: str = "localhost"
    """Hostname of the text-generation-webui API server."""

    api_port: int = 5000
    """Port of the blocking (non-streaming) API used by every method here."""

    api_streaming_port: int = 5005
    """Port of the streaming API; not used by this wrapper, which only makes blocking requests."""

    use_https: bool = False
    """Build ``https://`` instead of ``http://`` base URLs."""

    request_timeout: Optional[float] = None
    """Timeout in seconds for every HTTP request. ``None`` (the default) waits indefinitely."""

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return a fresh dict of the generation parameters sent with every request."""
        return {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "encoder_repetition_penalty": self.encoder_repetition_penalty,
            "typical_p": self.typical_p,
            "min_length": self.min_length,
            "no_repeat_ngram_size": self.no_repeat_ngram_size,
            "num_beams": self.num_beams,
            "length_penalty": self.length_penalty,
            "penalty_alpha": self.penalty_alpha,
            "early_stopping": self.early_stopping,
            "seed": self.seed,
            "add_bos_token": self.add_bos_token,
            "ban_eos_token": self.ban_eos_token,
            "truncation_length": self.truncation_length,
            "skip_special_tokens": self.skip_special_tokens,
            "do_sample": self.do_sample,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the parameters that identify this LLM configuration."""
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "text-generation-webui"

    def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
        """Build the request payload (everything except the prompt).

        Args:
            stop: Per-call stop strings. When given (even as an empty list)
                they take precedence over the instance default ``self.stop``.

        Returns:
            A fresh dict of generation parameters. The stop list is sent under
            both ``stop`` and ``stopping_strings`` (presumably for
            compatibility with different server versions).
        """
        params = self._default_params
        # An explicit per-call value wins; previously it was silently ignored
        # whenever ``self.stop`` was set.
        chosen = stop if stop is not None else self.stop
        # Copy so the payload never aliases the caller's or the instance's list.
        params["stop"] = list(chosen or [])
        params["stopping_strings"] = params["stop"]
        return params

    def get_base_url(self) -> str:
        """Return the API root URL, e.g. ``http://localhost:5000``."""
        proto = "https" if self.use_https else "http"
        return f"{proto}://{self.api_host}:{self.api_port}"

    def _api_url(self, endpoint: str) -> str:
        """Return the full URL of the ``/api/v1/<endpoint>`` route."""
        return f"{self.get_base_url()}/api/v1/{endpoint}"

    def get_model_name(self) -> str:
        """Return the ``result`` field of ``GET /api/v1/model`` (the loaded model's name).

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        response = requests.get(self._api_url("model"), timeout=self.request_timeout)
        response.raise_for_status()
        return response.json()["result"]

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        """Call the text-generation-webui API and return the output.

        Args:
            prompt: The prompt to use for generation.
            stop: A list of strings to stop generation when encountered;
                overrides ``self.stop`` when given.
            **kwargs: Extra generation parameters that override the instance
                settings for this call only (e.g. ``temperature=0.5``).

        Returns:
            The generated text.

        Raises:
            requests.HTTPError: If the server responds with an error status.
                The previous code returned ``None`` here, which violates
                the ``str`` return contract.
        """
        request = self._get_parameters(stop)
        request.update(kwargs)
        request["prompt"] = prompt

        response = requests.post(
            self._api_url("generate"), json=request, timeout=self.request_timeout
        )
        response.raise_for_status()
        return response.json()["results"][0]["text"]

    def get_token_count(self, prompt: str) -> int:
        """Return the token count the server reports for ``prompt``.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        response = requests.post(
            self._api_url("token-count"),
            json={"prompt": prompt},
            timeout=self.request_timeout,
        )
        response.raise_for_status()
        return response.json()["results"][0]["tokens"]


In [4]:
# Connect to the default endpoint (http://localhost:5000, from the class
# defaults) with a low temperature; all other settings keep their defaults.
llm = TextGenerationWebUI(temperature=0.1)

# Query /api/v1/model for the name of the model loaded on the server; as the
# cell's last expression, its value is displayed as the cell output.
llm.get_model_name()

'TheBloke_wizardLM-7B-GPTQ'

In [5]:
# Chat-style prompt using "### Human:" / "### Assistant:" turn markers. It ends
# right after the assistant marker so the model continues with the reply.
prompt = """A chat between a human and an assistant.

### Human: write a youtube video title about Large Language Model.

TITLE:
### Assistant:"""

# The stop strings are forwarded to the server (as `stop`/`stopping_strings`)
# so generation can end before the model starts writing another turn.
result = llm(prompt, stop=["\n### Human:", "\n### Assistant:"])
result

' "Unlocking the Power of Large Language Models with AI"'

In [6]:
# Ask the server (POST /api/v1/token-count) how many tokens the prompt above is.
llm.get_token_count(prompt)

41