Commit ea5c5eb: Add LLaMA support

oobabooga committed Mar 3, 2023
1 parent 2bff646 commit ea5c5eb

Showing 4 changed files with 110 additions and 2 deletions.
96 changes: 96 additions & 0 deletions modules/LLaMA.py
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.

import json
import os
import sys
import time
from pathlib import Path
from typing import Tuple

import fire
import torch
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from llama import LLaMA, ModelArgs, Tokenizer, Transformer

os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MP'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '2223'

def setup_model_parallel() -> Tuple[int, int]:
    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    torch.distributed.init_process_group("gloo")
    initialize_model_parallel(world_size)
    torch.cuda.set_device(local_rank)

    # seed must be the same in all processes
    torch.manual_seed(1)
    return local_rank, world_size

def load(
    ckpt_dir: str,
    tokenizer_path: str,
    local_rank: int,
    world_size: int,
    max_seq_len: int,
    max_batch_size: int,
) -> LLaMA:
    start_time = time.time()
    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
    assert world_size == len(
        checkpoints
    ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
    ckpt_path = checkpoints[local_rank]
    print("Loading")
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    with open(Path(ckpt_dir) / "params.json", "r") as f:
        params = json.loads(f.read())

    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
    )
    tokenizer = Tokenizer(model_path=tokenizer_path)
    model_args.vocab_size = tokenizer.n_words
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    model = Transformer(model_args)
    torch.set_default_tensor_type(torch.FloatTensor)
    model.load_state_dict(checkpoint, strict=False)

    generator = LLaMA(model, tokenizer)
    print(f"Loaded in {time.time() - start_time:.2f} seconds")
    return generator


class LLaMAModel:
    def __init__(self):
        pass

    @classmethod
    def from_pretrained(self, path, max_seq_len=512, max_batch_size=32):

[Inline review comments on the line above]

@Sumanai commented on Mar 3, 2023:

max_batch_size should be reduced, perhaps even to 1, because the web UI does not run in batch mode, so setting it higher than 1 just wastes memory.

@oobabooga (Author, Owner) replied on Mar 3, 2023:

Done in 5a79863. I have also increased max_seq_len to 2048 because it seems to be a hard limit on the size of the generated text. I wonder if there is a way to keep this unlimited.

        tokenizer_path = path / "tokenizer.model"
        path = os.path.abspath(path)
        tokenizer_path = os.path.abspath(tokenizer_path)

        local_rank, world_size = setup_model_parallel()
        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        generator = load(
            path, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
        )

        result = self()
        result.pipeline = generator
        return result

    def generate(self, prompt, token_count=512, temperature=0.8, top_p=0.95):

        results = self.pipeline.generate(
            [prompt], max_gen_len=token_count, temperature=temperature, top_p=top_p
        )

        return results[0]
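
For orientation, here is a minimal standalone usage sketch of the wrapper defined above. The models/llama-7b directory and the explicit max_seq_len=2048 / max_batch_size=1 values (taken from the review thread on from_pretrained, not from this commit's defaults of 512/32) are assumptions for illustration, not something this commit pins down.

from pathlib import Path

from modules.LLaMA import LLaMAModel

# Assumed layout: models/llama-7b contains a single .pth checkpoint shard,
# params.json, and tokenizer.model (WORLD_SIZE is pinned to '1' above).
model = LLaMAModel.from_pretrained(Path('models/llama-7b'),
                                   max_seq_len=2048, max_batch_size=1)

# generate() forwards one prompt to the underlying llama generator and
# returns the generated text for that prompt.
reply = model.generate('The three most important ingredients of a good soup are',
                       token_count=128, temperature=0.8, top_p=0.95)
print(reply)
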
12 changes: 11 additions & 1 deletion modules/models.py
@@ -39,9 +39,10 @@ def load_model(model_name):
     t0 = time.time()
 
     shared.is_RWKV = model_name.lower().startswith('rwkv-')
+    shared.is_LLaMA = model_name.lower().startswith('llama-')
 
     # Default settings
-    if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV):
+    if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV or shared.is_LLaMA):
         if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
         else:
@@ -85,6 +86,15 @@ def load_model(model_name):
 
         return model, None
 
+    # LLaMA model (not on HuggingFace)
+    elif shared.is_LLaMA:
+        import modules.LLaMA
+        from modules.LLaMA import LLaMAModel
+
+        model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
+
+        return model, None
+
     # Custom
     else:
         command = "AutoModelForCausalLM.from_pretrained"
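
Since this branch simply passes models/<model_name> to LLaMAModel.from_pretrained, the folder is expected to hold a raw Meta-format checkpoint: a name starting with 'llama-' (case-insensitive, so that is_LLaMA is set above), exactly one .pth shard (WORLD_SIZE is pinned to '1' in modules/LLaMA.py), plus params.json and tokenizer.model. Below is a small pre-flight check along those lines; the models/llama-7b path is only an assumed example.

from pathlib import Path

ckpt_dir = Path('models/llama-7b')  # assumed example; the folder name must start with 'llama-'

# load() asserts len(shards) == world_size, and WORLD_SIZE is forced to '1'.
shards = sorted(ckpt_dir.glob('*.pth'))
assert len(shards) == 1, f"expected exactly one .pth shard for MP=1, found {len(shards)}"

# params.json feeds ModelArgs; tokenizer.model is resolved by from_pretrained().
assert (ckpt_dir / 'params.json').is_file()
assert (ckpt_dir / 'tokenizer.model').is_file()
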
2 changes: 2 additions & 0 deletions modules/shared.py
@@ -6,6 +6,7 @@
 soft_prompt_tensor = None
 soft_prompt = False
 is_RWKV = False
+is_LLaMA = False
 
 # Chat variables
 history = {'internal': [], 'visible': []}
@@ -42,6 +43,7 @@
         'default': 'NovelAI-Sphinx Moth',
         'pygmalion-*': 'Pygmalion',
         'RWKV-*': 'Naive',
+        'llama-*': 'Naive',
         '(rosey|chip|joi)_.*_instruct.*': 'Instruct Joi (Contrastive Search)'
     },
     'prompts': {
2 changes: 1 addition & 1 deletion modules/text_generation.py
@@ -83,7 +83,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     if not shared.args.cpu:
         torch.cuda.empty_cache()
 
-    if shared.is_RWKV:
+    if shared.is_RWKV or shared.is_LLaMA:
         if shared.args.no_stream:
             reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p)
             yield formatted_outputs(reply, shared.model_name)

2 comments on commit ea5c5eb

@BetaDoggo

Wow that was fast

@USBhost (Contributor) commented on ea5c5eb Mar 3, 2023

Cheers!!