API model #34

Open
Fenglly opened this issue Mar 4, 2024 · 3 comments

Comments

@Fenglly

Fenglly commented Mar 4, 2024

If I want to test the qwen model with the API, can I just use the GPTAPI class and replace the model URL with the qwen one?

@zehuichen123
Collaborator

Sorry, we do not currently support customized API models. You may initialize a new model from BaseAPIModel in lagent and write the code yourself. BTW, we will soon release a template file showing how to customize an API model class in lagent.

@zehuichen123
Collaborator

Here is some unfinished reference code:

import json
from logging import getLogger
from typing import Dict, List, Optional, Union

import requests

from .base_api import BaseAPIModel


class CustomAPI(BaseAPIModel):
    """Model wrapper around Custom API models.

    Args:
        model_url (str): The url of the requested API model.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        key (str or List[str], optional): Key(s) for the API model. If a
            list is given, the keys will be used in round-robin manner.
            Defaults to None.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        gen_params: Default generation configuration which could be overridden
            on the fly of generation.
    """

    def __init__(self,
                 model_type: str,
                 model_url: str,
                 key: Union[str, List[str], None] = None,
                 query_per_second: int = 1,
                 retry: int = 2,
                 meta_template: Optional[Dict] = [
                     dict(role='system', api_role='system'),
                     dict(role='user', api_role='user'),
                     dict(role='assistant', api_role='assistant')
                 ],
                 **gen_params):
        self.url = model_url
        super().__init__(
            model_type=model_type,
            meta_template=meta_template,
            query_per_second=query_per_second,
            retry=retry,
            **gen_params)
        self.logger = getLogger(__name__)
        # Counter used to tag each request with a session id.
        self._session_id = 0
        if key is None:
            self.keys = None
        elif isinstance(key, str):
            self.keys = [key]
        else:
            self.keys = key

    def _generate(self,
                  inputs: str,
                  max_out_len: int = None,
                  temperature: float = None) -> str:
        """Generate results given an input prompt.

        Args:
            inputs (str): The input prompt.
            max_out_len (int): The maximum length of the output.
            temperature (float): What sampling temperature to use,
                between 0 and 2. Higher values like 0.8 will make the output
                more random, while lower values like 0.2 will make it more
                focused and deterministic.

        Returns:
            str: The generated string.
        """
        assert isinstance(inputs, str)
        max_num_retries = 0
        while max_num_retries < self.retry:

            header = {
                'content-type': 'application/json',
            }
            if self.keys:
                # Rotate through the configured keys; assumes bearer-token
                # auth -- adjust to whatever scheme your backend expects.
                api_key = self.keys[self._session_id % len(self.keys)]
                header['Authorization'] = f'Bearer {api_key}'
            self._session_id = (self._session_id + 1) % 1000000

            try:
                data = dict(
                    model=self.model_type,
                    session_id=self._session_id,
                    prompt=inputs,
                    sequence_start=True,
                    sequence_end=True,
                    max_tokens=max_out_len,
                    temperature=temperature,  # drop if the backend rejects it
                )
                raw_response = requests.post(
                    self.url, headers=header, data=json.dumps(data))
            except requests.ConnectionError:
                self.logger.warning('Got connection error, retrying...')
                max_num_retries += 1
                continue
            try:
                response = raw_response.json()
            except requests.JSONDecodeError:
                self.logger.warning('JSONDecode error, got %s',
                                    str(raw_response.content))
                max_num_retries += 1
                continue
            try:
                if 'completion' in self.url:
                    # OpenAI-style completion endpoints nest the text
                    # under choices[0].
                    return response['choices'][0]['text'].strip()
                else:
                    return response['text'].strip()
            except KeyError:
                self.logger.warning('Unexpected response format: %s', response)
                max_num_retries += 1

        raise RuntimeError('Calling API model failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')
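The parse-and-retry loop inside `_generate` can be exercised in isolation. Below is a self-contained sketch of the same pattern with the HTTP call abstracted behind a `post` callable, so it runs without a live endpoint; the name `call_with_retry`, the fake transport, and the endpoint URL are all illustrative, not part of lagent:

```python
import json


def call_with_retry(post, url, payload, retry=2):
    """Sketch of the retry loop above. `post` is any callable that
    returns the raw response body as a string or raises ConnectionError."""
    attempts = 0
    while attempts < retry:
        try:
            raw = post(url, data=json.dumps(payload))
        except ConnectionError:
            attempts += 1
            continue
        try:
            body = json.loads(raw)
        except json.JSONDecodeError:
            attempts += 1
            continue
        try:
            # Same branching as CustomAPI: completion-style endpoints
            # nest the generated text under choices[0].
            if 'completion' in url:
                return body['choices'][0]['text'].strip()
            return body['text'].strip()
        except KeyError:
            attempts += 1
    raise RuntimeError(f'Calling API model failed after {attempts} retries.')


# Fake transport standing in for requests.post, for offline testing.
def fake_post(url, data):
    return json.dumps({'text': ' hello from qwen '})


print(call_with_retry(fake_post, 'http://localhost:8000/generate',
                      {'prompt': 'hi'}))
# prints "hello from qwen"
```

Swapping `fake_post` for a thin wrapper around `requests.post` recovers the networked behaviour, while keeping the parsing logic testable on its own.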

@Fenglly
Author

Fenglly commented Mar 5, 2024

Thanks! I will try.
