Running web_server.py on Multi GPU instance. #9

Open
maxpain opened this issue Mar 8, 2023 · 6 comments

maxpain commented Mar 8, 2023

Hello. I started an 8x A100 80G instance in Google Cloud and can't start the 65B model:

root@llama:/pyllama/apps/flask# python3 web_server.py --ckpt_dir /var/llama/65B --tokenizer_path /var/llama/tokenizer.model
Traceback (most recent call last):
  File "/pyllama/apps/flask/web_server.py", line 101, in <module>
    generator = init_generator(
  File "/pyllama/apps/flask/web_server.py", line 88, in init_generator
    local_rank, world_size = setup_model_parallel()
  File "/pyllama/apps/flask/web_server.py", line 39, in setup_model_parallel
    dist.init_process_group("nccl")
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 754, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/rendezvous.py", line 236, in _env_rendezvous_handler
    rank = int(_get_env_or_raise("RANK"))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/rendezvous.py", line 221, in _get_env_or_raise
    raise _env_error(env_var)
ValueError: Error initializing torch.distributed using env:// rendezvous: environment variable RANK expected, but not set

Maxpa1n commented Mar 9, 2023

I got it working with this command:
CUDA_VISIBLE_DEVICES=1,2 torchrun --nproc_per_node 2 web_server.py --ckpt_dir /nfs/dataset/llama/13B --tokenizer_path /nfs/dataset/llama/tokenizer.model
The request:
curl -X POST -H "Content-Type: application/json" -d '{"prompts":["have nice day"]}' http://127.0.0.1:8042/llama/
The returned data:
{"responses":["have nice day! good luck.\nAww, that’"]}

I changed max_seq_len=100 and max_batch_size=1 in web_server.py; the default settings cannot run on an RTX 3090.
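
For the 65B model from the original post, which is sharded into 8 checkpoint files, the same approach should just need 8 processes (I have not tried that size myself):

torchrun --nproc_per_node 8 web_server.py --ckpt_dir /var/llama/65B --tokenizer_path /var/llama/tokenizer.model

Here is my modified web_server.py: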

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

from typing import Tuple
import os
import sys
import argparse
import torch
import time
import json

from pathlib import Path
from typing import List

from pydantic import BaseModel
from fastapi import FastAPI
import uvicorn
import torch.distributed as dist

from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from llama import ModelArgs, Transformer, Tokenizer, LLaMA


parser = argparse.ArgumentParser()
parser.add_argument('--ckpt_dir', type=str, required=True)
parser.add_argument('--tokenizer_path', type=str, required=True)
parser.add_argument('--max_seq_len', type=int, default=100)
parser.add_argument('--max_batch_size', type=int, default=1)


app = FastAPI()


def setup_model_parallel() -> Tuple[int, int]:
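    # LOCAL_RANK and WORLD_SIZE (plus RANK, MASTER_ADDR and MASTER_PORT) are set by
    # torchrun; launching with plain python3 leaves them unset, which is what raises
    # the "environment variable RANK expected, but not set" error above.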
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    dist.init_process_group("nccl")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size


def load(
    ckpt_dir: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
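    # One checkpoint shard per model-parallel rank (2 for 13B, 8 for 65B).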
    assert world_size == len(
        checkpoints
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator


def init_generator(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int = 512,
    max_batch_size: int = 32,
):
    local_rank, world_size = setup_model_parallel()
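    # Silence stdout on non-zero ranks so only rank 0 prints loading progress.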
    if local_rank > 0:
        sys.stdout = open(os.devnull, "w")

    generator = load(
        ckpt_dir, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
    )

    return generator


if __name__ == "__main__":
    args = parser.parse_args()
    generator = init_generator(
        args.ckpt_dir,
        args.tokenizer_path,
        args.max_seq_len,
        args.max_batch_size,
    )

    class Config(BaseModel):
        prompts: List[str] = ["Have a nice day"]
        max_gen_len: int = 20
        temperature: float = 0.8
        top_p: float = 0.95

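    # Only rank 0 serves HTTP; it broadcasts each request's parameters to the other
    # ranks (broadcast_object_list defaults to src=0) so every model-parallel shard
    # runs the same generate() call.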
    if dist.get_rank() == 0:
        @app.post("/llama/")
        def generate(config: Config):
            if len(config.prompts) > args.max_batch_size:
                return { 'error': 'too many prompts.' }
            for prompt in config.prompts:
                if len(prompt) + config.max_gen_len > args.max_seq_len:
                    return { 'error': 'max_gen_len too large.' }
            dist.broadcast_object_list([config.prompts, config.max_gen_len, config.temperature, config.top_p])

            results = generator.generate(
                config.prompts, max_gen_len=config.max_gen_len, temperature=config.temperature, top_p=config.top_p
            )

            return {"responses": results}

        uvicorn.run(app, host="127.0.0.1", port=8042)
    else:
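        # Non-zero ranks block on the broadcast until rank 0 receives a request, then
        # run the same generation so the model-parallel group stays in sync.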
        while True:
            config = [None] * 4
            try:
                dist.broadcast_object_list(config)
                generator.generate(
                    config[0], max_gen_len=config[1], temperature=config[2], top_p=config[3]
                )
            except Exception:
                pass

I hope it helps you; we have the same name, it's funny.
I would also hope that the author @juncongmoo adds this to the README.MD, since I see many people asking for it.

@mldevorg (Collaborator)

Thanks for the tip, @Maxpa1n. It helped me a lot!

@wise-east

@Maxpa1n Thanks, that worked for me too! I just needed to make some small adjustments for llama-2's API.

@bbjbbjbbj

Thanks a lot for your solution! I'm not familiar with distributed models, and this had troubled me for many days; your code helped me solve a very important problem. Thank you very much, @Maxpa1n!

bank010 commented Sep 5, 2023

good!

bapman commented Dec 9, 2023

Hi everyone, many thanks @Maxpa1n, it also helped me understand the distributed model. I am stuck trying to get llama2-13b-chat working on 2 GPUs (2x RTX 3090); I receive the following error at inference:
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
TypeError: new(): invalid data type 'str'
Probably some stupid mistake on my part: the 2 processes start fine and load the model on ranks 0 & 1 with world size 2; the bug happens when trying to do inference using "Have a nice day" as the prompt. If anyone has a good idea, please let me know.

-- Update --
Never mind, I got it to work. Hint for anyone stuck in the same situation: I was missing the tokenizer.encode(...) and tokenizer.decode(...) calls before and after the generator.generate(...) call. The source code in llama is quite straightforward to follow. Thanks again for the above; it was key for me to run Llama 2 behind a web server.
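
In case it helps, here is a minimal sketch of that change (assuming the llama-2 repo's Llama.generate() and Tokenizer signatures; my actual code differs slightly):

# llama-2's generate() takes and returns token ids, unlike pyllama's string-based LLaMA.generate()
prompt_tokens = [generator.tokenizer.encode(p, bos=True, eos=False) for p in config.prompts]
out_tokens, _ = generator.generate(
    prompt_tokens,
    max_gen_len=config.max_gen_len,
    temperature=config.temperature,
    top_p=config.top_p,
)
results = [generator.tokenizer.decode(t) for t in out_tokens]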
