Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support of Mac m1 #18

Merged
merged 2 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ The Code Llama and Code Llama - Python models are not fine-tuned to follow instr
See `example_completion.py` for some examples. To illustrate, see command below to run it with the `CodeLlama-7b` model (`nproc_per_node` needs to be set to the `MP` value):

```
torchrun --nproc_per_node 1 example_code_completion.py \
torchrun --nproc_per_node 1 example_completion.py \
--ckpt_dir CodeLlama-7b/ \
--tokenizer_path CodeLlama-7b/tokenizer.model \
--max_seq_len 128 --max_batch_size 4
Expand All @@ -66,7 +66,7 @@ Code Llama and Code Llama - Instruct 7B and 13B models are capable of filling in

See `example_infilling.py` for some examples. The `CodeLlama-7b` model can be run for infilling with the command below (`nproc_per_node` needs to be set to the `MP` value):
```
torchrun --nproc_per_node 1 example_text_infilling.py \
torchrun --nproc_per_node 1 example_infilling.py \
--ckpt_dir CodeLlama-7b/ \
--tokenizer_path CodeLlama-7b/tokenizer.model \
--max_seq_len 192 --max_batch_size 4
Expand Down
35 changes: 26 additions & 9 deletions llama/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

if torch.backends.mps.is_available():
device = torch.device('mps')
elif torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')

Role = Literal["system", "user", "assistant"]


Expand Down Expand Up @@ -65,14 +72,19 @@ def build(
model_parallel_size: Optional[int] = None,
) -> "Llama":
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if device == "cuda":
torch.distributed.init_process_group("nccl")
else:
torch.distributed.init_process_group("gloo")
if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
if device == "cuda":
torch.cuda.set_device(local_rank)


# seed must be the same in all processes
torch.manual_seed(1)
Expand All @@ -98,12 +110,17 @@ def build(
)
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
# support for mac
if device == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
torch.set_default_tensor_type(torch.HalfTensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
model.to(device)
print(f"Loaded in {time.time() - start_time:.2f} seconds")

return Llama(model, tokenizer)
Expand Down Expand Up @@ -135,14 +152,14 @@ def generate(
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
token_logprobs = torch.zeros_like(tokens, dtype=torch.float, device=device)

prev_pos = 0
stop_reached = torch.tensor([False] * bsz, device="cuda")
stop_reached = torch.tensor([False] * bsz, device=device)
input_text_mask = tokens != pad_id
for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
Expand Down
28 changes: 20 additions & 8 deletions llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
)
from torch import nn

if torch.backends.mps.is_available():
device = torch.device('mps')
elif torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')


@dataclass
class ModelArgs:
Expand Down Expand Up @@ -48,6 +55,7 @@ def forward(self, x):

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

t = torch.arange(end, device=freqs.device, dtype=torch.float32) # type: ignore
freqs = torch.outer(t, freqs) # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
Expand All @@ -67,12 +75,17 @@ def apply_rotary_emb(
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
if not torch.cuda.is_available():
xq = xq.to('cpu')
xk = xk.to('cpu')
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
if not torch.cuda.is_available():
freqs_cis = freqs_cis.to('cpu')
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
return xq_out.type_as(xq).to(device), xk_out.type_as(xk).to(device)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
Expand All @@ -84,7 +97,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
x[:, :, :, None, :]
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
)
) #.to(device)


class Attention(nn.Module):
Expand Down Expand Up @@ -133,15 +146,15 @@ def __init__(self, args: ModelArgs):
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
).to(device)
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
).to(device)

def forward(
self,
Expand Down Expand Up @@ -252,7 +265,7 @@ def __init__(self, params: ModelArgs):
self.n_layers = params.n_layers

self.tok_embeddings = ParallelEmbedding(
params.vocab_size, params.dim, init_method=lambda x: x
params.vocab_size, params.dim, init_method=lambda x: x,
)

self.layers = torch.nn.ModuleList()
Expand All @@ -274,18 +287,17 @@ def __init__(self, params: ModelArgs):
def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

mask = None
if seqlen > 1:
mask = torch.full(
(1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
(1, 1, seqlen, seqlen), float("-inf"), device=torch.device('cpu')
)
mask = mask.to(torch.float32).triu(diagonal=start_pos+1).type_as(h)

for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = layer(h, start_pos, freqs_cis, (mask.to(device) if mask is not None else mask))
h = self.norm(h)
output = self.output(h).float()
return output
2 changes: 1 addition & 1 deletion llama/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
return t

def decode(self, t: List[int]) -> str:
return self.sp_model.decode(t)
return self.sp_model.decode(list(filter(lambda tk: tk != -1, t)))

def encode_infilling(self, s: str) -> List[int]:
"""Encode a string without an implicit leading space."""
Expand Down