依赖安装

In [None]:
import locale

# Force UTF-8 as the preferred encoding, presumably to work around Colab
# runtimes that report a non-UTF-8 locale encoding (TODO confirm).
# Keep the real signature `getpreferredencoding(do_setlocale=True)` so callers
# that pass the argument (e.g. `locale.getpreferredencoding(False)`) do not
# get a TypeError from a zero-argument lambda.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"

!pip install protobuf==3.20.0 transformers==4.27.1 icetk cpm_kernels torch
!pip install fastapi pydantic uvicorn sse_starlette pyngrok nest-asyncio

环境配置

In [None]:
import os

import torch

# ChatGLM-6B checkpoints on the Hugging Face Hub.
chatglm_models = [
    "THUDM/chatglm-6b",       # original model
    "THUDM/chatglm-6b-int8",  # int8 quantized
    "THUDM/chatglm-6b-int4",  # int4 quantized
]

# Checkpoint loaded by init_chatglm; the int4 build is the most heavily
# quantized of the three.
CHATGLM_MODEL = "THUDM/chatglm-6b-int4"

# "GPU" or "CPU", consumed by init_chatglm. Without a CUDA device, fall back to
# CPU instead of keeping "GPU", which would make the later `.cuda()` call fail.
if torch.cuda.is_available():
    RUNNING_DEVICE = "GPU"
else:
    print("请在右上方 Colab 运行时类型中，选择 GPU 类型的运行时。")
    RUNNING_DEVICE = "CPU"

# API token. Read from the environment so secrets are not saved in the
# notebook; "token1" remains the fallback for compatibility.
TOKEN = os.environ.get("CHATGLM_API_TOKEN", "token1")

模型启动

In [None]:
from transformers import AutoModel, AutoTokenizer

def init_chatglm(model_name: str, running_device: str):
    """Load a ChatGLM tokenizer/model pair and switch the model to eval mode.

    Args:
        model_name: Hugging Face Hub id, e.g. one of ``chatglm_models``.
        running_device: ``"GPU"`` casts the model to float16 and moves it to
            CUDA; any other value keeps it as float32 on the CPU.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # trust_remote_code lets transformers run the custom tokenizer/model code
    # shipped in the model repository.
    remote = {"trust_remote_code": True}
    tokenizer = AutoTokenizer.from_pretrained(model_name, **remote)
    model = AutoModel.from_pretrained(model_name, **remote)

    model = model.half().cuda() if running_device == "GPU" else model.float()
    model.eval()
    return tokenizer, model


tokenizer, model = init_chatglm(CHATGLM_MODEL, RUNNING_DEVICE)

模型测试

In [None]:
# Smoke test: two chat turns, feeding each turn's history into the next.
history = []
for prompt in ("你好", "很高兴认识你"):
    response, history = model.chat(tokenizer, prompt, history=history)
    print(response)
    print(history)

下载 Web UI（chatgpt-web 前端）

In [None]:
# Fetch the pre-built chatgpt-web front end and unpack it into ./dist.
!wget https://github.com/ninehills/chatgpt-web/releases/download/1.0/dist.tgz -O dist.tgz
!tar zxvf dist.tgz
# The server mounts only dist/assets (at /assets) and redirects "/" to
# /assets/index.html, so index.html has to live inside dist/assets.
!mv ./dist/index.html ./dist/assets/

定义 API 服务（在最后一个单元格中启动）

In [None]:
import torch
from fastapi import FastAPI, Request, status, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import json
from typing import List, Optional


# 参考 https://github.com/josStorer/selfhostedAI/blob/master/main.py

def torch_gc():
    """Release cached CUDA memory on device 0; does nothing without CUDA.

    Empties PyTorch's CUDA caching allocator and collects CUDA IPC memory so
    memory held from a previous generation can be reused.
    """
    if not torch.cuda.is_available():
        return
    with torch.cuda.device(0):
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


# Application object: serves the bundled web UI and the OpenAI-compatible API.
app = FastAPI()

# Wide-open CORS so any browser front end can call the API.
# NOTE(review): Starlette's docs say wildcard origins/methods/headers cannot be
# combined with credentials; confirm whether allow_credentials=True is needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Static web UI unpacked into dist/assets by the download cell.
app.mount("/assets", StaticFiles(directory="dist/assets"), name="assets")


class Message(BaseModel):
    """One chat turn in an OpenAI-style ``messages`` array."""
    # Speaker role; the completions handler recognises "system", "user" and
    # "assistant".
    role: str
    # Text of the turn.
    content: str


class Body(BaseModel):
    """Request body for ``POST /v1/chat/completions`` (OpenAI-compatible subset)."""
    # Conversation so far; the handler requires the last entry to be a user turn.
    messages: List[Message]
    # Required by the schema, but the handler never reads it: the loaded
    # ChatGLM checkpoint always answers.
    model: str
    # When true, the reply is sent as server-sent events.
    stream: Optional[bool] = False
    # The handler passes max(2048, max_tokens) to ChatGLM as max_length.
    max_tokens: Optional[int] = 256
    # Sampling parameters forwarded to ChatGLM's chat/stream_chat.
    temperature: Optional[float] = 0.95
    top_p: Optional[float] = 0.7



@app.get("/")
def read_root():
    """Redirect the site root to the bundled web UI's entry page."""
    return RedirectResponse(url="/assets/index.html")


@app.get("/v1/models")
def get_models():
    """OpenAI ``/v1/models`` stand-in advertising a single fixed model.

    The server reports itself as ``gpt-3.5-turbo`` so OpenAI-compatible
    clients accept it; the payload is static.
    """
    permission = {
        "created": 1680818747,
        "id": "modelperm-fTUZTbzFp7uLLTeMSo9ks6oT",
        "object": "model_permission",
        "allow_create_engine": False,
        "allow_sampling": True,
        "allow_logprobs": True,
        "allow_search_indices": False,
        "allow_view": True,
        "allow_fine_tuning": False,
        "organization": "*",
        "group": None,
        "is_blocking": False,
    }
    model_card = {
        "created": 1677610602,
        "id": "gpt-3.5-turbo",
        "object": "model",
        "owned_by": "openai",
        "permission": [permission],
        "root": "gpt-3.5-turbo",
        "parent": None,
    }
    return {"data": [model_card], "object": "list"}

def generate_response(content: str):
    """Wrap a complete reply in an OpenAI ``chat.completion`` response dict.

    The id, timestamp, model name and token usage are fixed placeholder
    values; only ``content`` varies.
    """
    choice = {
        "message": {"role": "assistant", "content": content},
        "finish_reason": "stop",
        "index": 0,
    }
    usage = {"prompt_tokens": 10, "completion_tokens": 10, "total_tokens": 20}
    return {
        "id": "chatcmpl-77PZm95TtxE0oYLRx3cxa6HtIDI7s",
        "object": "chat.completion",
        "created": 1682000966,
        "model": "gpt-3.5-turbo-0301",
        "usage": usage,
        "choices": [choice],
    }

def _stream_chunk(delta: dict, finish_reason):
    """Wrap ``delta`` in an OpenAI ``chat.completion.chunk`` envelope.

    Every chunk carries the same fixed id, timestamp and model name.
    """
    return {
        "id": "chatcmpl-77QWpn5cxFi9sVMw56DZReDiGKmcB",
        "object": "chat.completion.chunk",
        "created": 1682004627,
        "model": "gpt-3.5-turbo-0301",
        "choices": [{"delta": delta, "index": 0, "finish_reason": finish_reason}],
    }


def generate_stream_response_start():
    """First stream chunk: announces the assistant role, no content yet."""
    return _stream_chunk({"role": "assistant"}, None)


def generate_stream_response(content: str):
    """Intermediate stream chunk carrying the next piece of reply text."""
    return _stream_chunk({"content": content}, None)


def generate_stream_response_stop():
    """Final stream chunk: empty delta with ``finish_reason`` set to "stop"."""
    return _stream_chunk({}, "stop")

def _split_messages(messages: List[Message]):
    """Turn an OpenAI-style message list into ChatGLM's ``(question, history)``.

    The last message must come from the user and becomes the question. Each
    ``assistant`` message is paired with the most recent preceding ``user``
    message into a ``(query, answer)`` history tuple. Each ``system`` message
    is added as a ``(system_prompt, "OK")`` turn, because the ``chat`` and
    ``stream_chat`` calls below take history only as such pairs. Other roles
    are ignored.

    Raises:
        HTTPException: 400 if ``messages`` is empty or its last entry is not
            a user message.
    """
    # Check emptiness before indexing: messages[-1] on an empty list raised
    # IndexError, which surfaced as a 500 instead of a client error.
    if not messages or messages[-1].role != 'user':
        raise HTTPException(status.HTTP_400_BAD_REQUEST, "No Question Found")

    history = []
    user_question = ''
    for message in messages:
        if message.role == 'system':
            history.append((message.content, "OK"))
        elif message.role == 'user':
            user_question = message.content
        elif message.role == 'assistant':
            history.append((user_question, message.content))
    return messages[-1].content, history


def _generation_kwargs(body: Body):
    """Build the sampling kwargs shared by ``model.chat`` and ``model.stream_chat``.

    The Body fields are ``Optional``, so a client may send an explicit JSON
    ``null``. Fall back to Body's declared defaults (keep these in sync)
    instead of passing ``None`` on; ``max(2048, None)`` raises ``TypeError``.
    """
    max_tokens = 256 if body.max_tokens is None else body.max_tokens
    temperature = 0.95 if body.temperature is None else body.temperature
    top_p = 0.7 if body.top_p is None else body.top_p
    return {
        "temperature": temperature,
        "top_p": top_p,
        # Never below 2048; presumably ChatGLM's max_length bounds prompt plus
        # reply, so a small max_tokens would cut multi-turn prompts short.
        "max_length": max(2048, max_tokens),
    }


@app.post("/v1/chat/completions")
async def completions(body: Body, request: Request):
    """OpenAI-compatible chat completion endpoint answered by ChatGLM.

    When ``body.stream`` is true, returns an SSE stream of
    ``chat.completion.chunk`` events ending with ``[DONE]``. Otherwise returns
    a single ``chat.completion`` JSON object.

    Raises:
        HTTPException: 400 when the conversation is empty or does not end with
            a user message.
    """
    # Authentication is disabled (TOKEN from the config cell is not checked).
    # To enable it without crashing when the header is missing:
    #     auth = request.headers.get("Authorization", "")
    #     if auth.partition(" ")[2] != TOKEN:
    #         raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token is wrong!")

    # Free cached CUDA memory left over from earlier requests.
    torch_gc()

    question, history = _split_messages(body.messages)
    gen_kwargs = _generation_kwargs(body)
    print(f"question = {question}, history = {history}")

    # NOTE(review): model.chat / model.stream_chat are synchronous and run on
    # the event-loop thread, so one generation blocks all other requests.

    if body.stream:
        async def eval_chatglm():
            # Send the role chunk up front so the stream is well-formed even if
            # the model yields nothing.
            yield json.dumps(generate_stream_response_start(), ensure_ascii=False)
            sent = 0
            for response, _ in model.stream_chat(tokenizer, question, history, **gen_kwargs):
                if await request.is_disconnected():
                    return
                # stream_chat yields the whole reply so far; forward only the
                # part not sent yet.
                delta, sent = response[sent:], len(response)
                yield json.dumps(generate_stream_response(delta), ensure_ascii=False)
            yield json.dumps(generate_stream_response_stop(), ensure_ascii=False)
            yield "[DONE]"

        # A very large ping interval presumably disables keep-alive pings.
        return EventSourceResponse(eval_chatglm(), ping=10000)

    response, _ = model.chat(tokenizer, question, history, **gen_kwargs)
    print(f"response: {response}")
    return JSONResponse(content=generate_response(response))

下载 cloudflared（内网穿透隧道客户端）

In [None]:
# Download the latest cloudflared binary (linux-amd64 build). The launch cell
# runs it to tunnel http://localhost:8000 to a public trycloudflare URL.
!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O cloudflared
!chmod a+x cloudflared

启动隧道与服务

In [None]:
import re
import subprocess
import time

import nest_asyncio
import uvicorn

# Needed to run uvicorn's event loop inside the notebook's running loop.
nest_asyncio.apply()

# Alternative to cloudflared: a pyngrok tunnel (pyngrok is installed above),
# which needs an ngrok auth token:
#     from pyngrok import ngrok
#     ngrok.set_auth_token(os.environ["ngrok_token"])
#     print(ngrok.connect(PORT).public_url)

PORT = 8000
TUNNEL_LOG = "stdout"  # file receiving cloudflared's combined stdout/stderr
TUNNEL_URL_RE = re.compile(r"https://[-a-z0-9]+\.trycloudflare\.com")


def _wait_for_tunnel_url(log_path, process, timeout=30.0, interval=0.5):
    """Poll cloudflared's log until a public trycloudflare URL appears.

    Returns the last URL found. Returns None if no URL shows up within
    ``timeout`` seconds, or if ``process`` exits before one appears.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        with open(log_path, encoding="utf-8", errors="replace") as log:
            urls = TUNNEL_URL_RE.findall(log.read())
        if urls:
            return urls[-1]
        if process.poll() is not None:
            return None
        time.sleep(interval)
    return None


print("start cloudflared tunnel")
with open(TUNNEL_LOG, "w") as tunnel_log:
    # The child gets its own copy of the descriptor, so closing ours at the end
    # of the with-block does not stop cloudflared from logging to the file.
    p = subprocess.Popen(
        ["./cloudflared", "--url", f"http://localhost:{PORT}"],
        bufsize=0, stdout=tunnel_log, stderr=subprocess.STDOUT,
    )

# Poll instead of a fixed sleep: the URL can take longer than a few seconds.
tunnel_url = _wait_for_tunnel_url(TUNNEL_LOG, p)
if tunnel_url:
    print(tunnel_url)
else:
    print(f"tunnel URL not found yet; check ./{TUNNEL_LOG}")

print("start app")
try:
    uvicorn.run(app, port=PORT)
finally:
    # Stop the tunnel together with the server so that re-running this cell
    # does not leave stray cloudflared processes behind.
    p.terminate()