Priority tasks #47

Merged: 40 commits, Sep 10, 2022

Commits
1117867  priority in handlers and backend pools (GreenFatGuy, Aug 20, 2022)
170a57a  simple dirty dust points system (GreenFatGuy, Aug 21, 2022)
45caeef  fix renaming missprint (GreenFatGuy, Aug 21, 2022)
6c5be80  rework dusty client side, add client side tests (GreenFatGuy, Aug 27, 2022)
532f1cc  fix tests (GreenFatGuy, Aug 27, 2022)
3d0b9b8  we're forked (justheuristic, Aug 29, 2022)
a9f9133  Merge branch 'main' into priority-tasks (justheuristic, Aug 29, 2022)
db481f3  pass metadata (justheuristic, Aug 29, 2022)
aea7070  Merge branch 'main' into priority-tasks (GreenFatGuy, Sep 6, 2022)
0c6350d  intermediate changes (GreenFatGuy, Sep 6, 2022)
0253ea7  Merge branch 'main' into priority-tasks (GreenFatGuy, Sep 6, 2022)
74abb12  priortize task in handler before submit task (GreenFatGuy, Sep 6, 2022)
64a2a24  default to fifo (justheuristic, Sep 6, 2022)
09d5533  Merge branch 'priority-tasks' of github.com:bigscience-workshop/petal… (justheuristic, Sep 6, 2022)
b7ed72c  default to fifo (justheuristic, Sep 6, 2022)
5c67465  re-fix (justheuristic, Sep 6, 2022)
fb0aa13  re-fix (justheuristic, Sep 6, 2022)
a86145b  serialize points in inference session (justheuristic, Sep 6, 2022)
1c456f4  Merge branch 'main' into priority-tasks (justheuristic, Sep 6, 2022)
f295ec9  WIP (GreenFatGuy, Sep 6, 2022)
65ad7a1  refactor priority pool, copy-paste runtime from hivemind (but without… (justheuristic, Sep 7, 2022)
a668a67  explain test (justheuristic, Sep 7, 2022)
6a41bbf  create pools with proper max batch size (justheuristic, Sep 7, 2022)
5b53429  make points optional (justheuristic, Sep 7, 2022)
6a07bea  black-isort (justheuristic, Sep 7, 2022)
aa6badc  fix tests (justheuristic, Sep 7, 2022)
8b845fd  unused import (justheuristic, Sep 7, 2022)
a7395fe  delete DustyBlock, cosmetic changes (GreenFatGuy, Sep 7, 2022)
0d35dca  s/expert/block/g (justheuristic, Sep 8, 2022)
2906cec  s/expert/block/g (justheuristic, Sep 8, 2022)
97dd3c8  s/expert/block/g (justheuristic, Sep 8, 2022)
964dc32  switch to local code (justheuristic, Sep 8, 2022)
ef0a016  Merge branch 'main' into priority-tasks (justheuristic, Sep 8, 2022)
4ab44eb  WIP (justheuristic, Sep 9, 2022)
e4a9329  WIP (justheuristic, Sep 10, 2022)
49083c7  Merge branch 'priority-tasks' of github.com:bigscience-workshop/petal… (justheuristic, Sep 10, 2022)
6675d25  cover edge case (justheuristic, Sep 10, 2022)
b8c88e3  cover edge case (justheuristic, Sep 10, 2022)
42cfb96  cover edge case (justheuristic, Sep 10, 2022)
8b8d54a  typo kwargs (justheuristic, Sep 10, 2022)
10 changes: 7 additions & 3 deletions cli/run_server.py
@@ -31,15 +31,19 @@ def main():
parser.add_argument('--num_handlers', type=int, default=8, required=False,
help='server will use this many processes to handle incoming requests')
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all expert operations')
help='Minimum required batch size for all operations (in total tokens)')
parser.add_argument('--max_batch_size', type=int, default=16384,
help='The total number of tokens in the same batch will not exceed this value')
parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
help='Pre-form this many subsequent batches while GPU is processing the current one')
parser.add_argument('--sender_threads', type=int, default=1, required=False,
help='Use this many threads to pass results/exceptions from Runtime to Pools')
parser.add_argument('--inference_max_length', type=int, default=16384,
help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
parser.add_argument('--device', type=str, default=None, required=False,
help='all experts will use this device in torch notation; default: cuda if available else cpu')
help='all blocks will use this device in torch notation; default: cuda if available else cpu')
parser.add_argument("--torch_dtype", type=str, default="auto",
help="Use this dtype to store block weights and do computations. "
"By default, respect the dtypes in the pre-trained state dict.")
@@ -58,7 +62,7 @@ def main():
'on the first run and uses these estimates for future runs. '
'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
parser.add_argument('--update_period', type=float, required=False, default=30,
help='Server will report experts to DHT once in this many seconds')
help='Server will report blocks to DHT once in this many seconds')
parser.add_argument('--expiration', type=float, required=False, default=None,
help='DHT entries will expire after this many seconds')
parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
1 change: 1 addition & 0 deletions src/client/__init__.py
@@ -2,3 +2,4 @@
from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
from src.client.sequence_manager import RemoteSequenceManager
from src.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase
3 changes: 2 additions & 1 deletion src/client/inference_session.py
@@ -43,14 +43,15 @@ def __init__(
outputs_aiter: AsyncIterator,
*,
max_length: int,
points: int = 0,
):
self.uid, self.rpc_info = uid, rpc_info
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
# warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
# using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length))
self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
self.stepped = False
self.closed = False

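The client-side change above is that an optional `points` value is now packed into the same MessagePack metadata blob as `max_length`. A minimal round-trip sketch using hivemind's `MSGPackSerializer`; the example values and the receiving-side variable names are illustrative, not taken from this diff:

```python
from hivemind.utils.serializer import MSGPackSerializer

# Client side: pack the session limit and the priority budget into one metadata blob.
metadata = MSGPackSerializer.dumps(dict(max_length=2048, points=100))

# Receiving side (illustrative): recover the dict and fall back to defaults if a field is absent.
parsed = MSGPackSerializer.loads(metadata)
max_length = parsed.get("max_length", 0)
points = parsed.get("points", 0)  # 0 keeps the pre-PR behaviour: no extra priority
print(max_length, points)  # 2048 100
```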
156 changes: 156 additions & 0 deletions src/client/remote_forward_backward.py
@@ -0,0 +1,156 @@
"""
Utility functions that call RPC forward or backward on a single remote server
"""
import asyncio
from typing import Iterable, List, Sequence, Tuple

import torch
from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
from hivemind.p2p import StubBase
from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
from hivemind.utils.streaming import split_for_streaming

from src.data_structures import ModuleUID, RPCInfo


async def run_remote_forward(
uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
) -> Tuple[torch.Tensor, ...]:
"""
Serializes input tensors and calls "rpc_forward" on a remote server.
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""

# Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
# detach to avoid pickling the computation graph
assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}

# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = (inputs, kwargs)

# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
# TODO: rm this assert when support arbitrary number of input tensors
assert len(args_schema) == 1 and len(inputs) == 2
forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)

if not nested_compare(forward_inputs, forward_schema_with_prompts):
raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")

forward_inputs = nested_flatten(forward_inputs)
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)

# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
)
)

# call RPC on remote server
size = sum(t.element_size() * t.nelement() for t in inputs)
if size > MAX_UNARY_PAYLOAD_SIZE:
deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs)
else:
deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs)

return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])


async def _forward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))

outputs = await stub.rpc_forward_stream(
amap_in_executor(
lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
iter_as_aiter(split),
),
)

tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
return await deserialize_tensor_stream(tensors_stream)


async def _forward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
)
return [deserialize_torch_tensor(t) for t in outputs.tensors]


async def _backward_stream(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))

grad_inputs = await stub.rpc_backward_stream(
amap_in_executor(
lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
iter_as_aiter(split),
),
)
tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
return await deserialize_tensor_stream(tensors_stream)


async def run_remote_backward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
inputs: torch.Tensor,
grad_outputs: List[torch.Tensor],
*extra_tensors: torch.Tensor,
**kwargs,
) -> Sequence[torch.Tensor]:
"""
Serializes grad outputs and calls "rpc_backward" on a remote server.
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""

grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))

# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
# TODO generalize this
prompts_schema = next(iter(args_schema))
backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))

# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
)
)

size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
if size > MAX_UNARY_PAYLOAD_SIZE:
deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, **kwargs)
else:
deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs)

return deserialized_grad_inputs


async def _backward_unary(
uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
) -> List[torch.Tensor]:
grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
)
return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
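Both the forward and backward paths in this new module choose between a unary RPC and a streaming RPC based on the raw size of the tensors being sent. A standalone sketch of that routing decision; the threshold below is an illustrative stand-in for `MAX_UNARY_PAYLOAD_SIZE`, whose real value comes from hivemind:

```python
import torch

# Illustrative stand-in; the real constant is imported from
# hivemind.p2p.p2p_daemon_bindings.control in the code above.
MAX_UNARY_PAYLOAD_SIZE = 2 * 1024 * 1024  # payloads above this size are split and streamed


def choose_transport(tensors):
    """Mirror the size check in run_remote_forward / run_remote_backward."""
    size = sum(t.element_size() * t.nelement() for t in tensors)
    return "stream" if size > MAX_UNARY_PAYLOAD_SIZE else "unary"


small = [torch.zeros(16, 1024)]    # ~64 KiB of float32, fits in a single unary call
large = [torch.zeros(4096, 4096)]  # ~64 MiB of float32, must be chunked and streamed
print(choose_transport(small))  # unary
print(choose_transport(large))  # stream
```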
4 changes: 3 additions & 1 deletion src/client/sequence_manager.py
@@ -9,6 +9,7 @@
from hivemind.proto import runtime_pb2
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from src.client.spending_policy import NoSpendingPolicy
from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
from src.dht_utils import get_remote_module_infos
from src.server.handler import TransformerConnectionHandler
@@ -24,6 +25,7 @@ class RemoteSequenceManager:
"""

def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
assert len(block_uids) > 0, "Sequences must contain at least one block"
self.dht, self.p2p = dht, p2p
self.block_uids: List[ModuleUID] = list(block_uids)
self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
@@ -39,7 +41,7 @@ def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retr
assert info is not None, f"Found no remote peers for block {uid}"
assert self.spans_by_priority and self.spans_containing_block

def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> Sequence[RemoteSpanInfo]:
def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
"""
Form a sequence of remote servers that collectively serve all consecutive layers

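For reference, `make_sequence` now advertises a concrete `List[RemoteSpanInfo]`, and each span exposes `start`, `end`, and `peer_id` (the fields read by `sequential_autograd` below). A hypothetical sketch of walking such a route, with a stand-in dataclass instead of the real `RemoteSpanInfo` and made-up peer names:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Span:  # stand-in for src.data_structures.RemoteSpanInfo, fields only
    start: int
    end: int
    peer_id: str


def describe_route(spans: List[Span]) -> str:
    """Render a route the way the client walks it: consecutive block ranges, one peer per span."""
    return " -> ".join(f"{span.peer_id}[{span.start}:{span.end}]" for span in spans)


# A made-up 12-block route; a real one would come from RemoteSequenceManager.make_sequence(0, 12).
route = [Span(0, 5, "peerA"), Span(5, 9, "peerB"), Span(9, 12, "peerC")]
print(describe_route(route))  # peerA[0:5] -> peerB[5:9] -> peerC[9:12]
```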
99 changes: 10 additions & 89 deletions src/client/sequential_autograd.py
@@ -1,102 +1,22 @@
"""
A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
"""
import asyncio
import logging
from typing import List, Optional, Sequence, Tuple

import torch
from hivemind import serialize_torch_tensor
from hivemind.moe.client.expert import expert_backward, expert_forward
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import StubBase
from hivemind.utils.nested import nested_compare, nested_flatten, nested_pack

from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
from src.client.sequence_manager import RemoteSequenceManager
from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
from src.server.handler import TransformerConnectionHandler
from src.utils.misc import DUMMY, is_dummy

MAX_TOKENS_IN_BATCH = 1024


async def run_expert_forward(
uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, **kwargs
) -> Tuple[torch.Tensor, ...]:
"""
Serializes input tensors and calls "expert_forward".
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""

# Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
# detach to avoid pickling the computation graph
assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}

# Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
forward_inputs = (inputs, kwargs)

# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
# TODO: rm this assert when support arbitrary number of input tensors
assert len(args_schema) == 1 and len(inputs) == 2
forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)

if not nested_compare(forward_inputs, forward_schema_with_prompts):
raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")

forward_inputs = nested_flatten(forward_inputs)
inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)

# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
)
)

deserialized_outputs = await expert_forward(uid, inputs, serialized_tensors, stub)
flat_outputs = tuple(deserialized_outputs)
return nested_pack(flat_outputs, structure=rpc_info["outputs_schema"])


async def run_expert_backward(
uid: ModuleUID,
stub: StubBase,
rpc_info: RPCInfo,
inputs: torch.Tensor,
grad_outputs: List[torch.Tensor],
*extra_tensors: torch.Tensor,
) -> Sequence[torch.Tensor]:
"""
Serializes grad outputs and calls "expert_backward".
Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
"""

grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))

# Modify forward_schema to support prompts
args_schema, kwargs_schema = rpc_info["forward_schema"]
assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
# TODO generalize this
prompts_schema = next(iter(args_schema))
backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))

# Asynchronous serialization
loop = asyncio.get_running_loop()
serialized_tensors = await asyncio.gather(
*(
loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
)
)

deserialized_grad_inputs = await expert_backward(uid, inputs_and_grad_outputs, serialized_tensors, stub)
return deserialized_grad_inputs


async def sequential_forward(
inputs: torch.Tensor,
prompts: torch.Tensor,
@@ -121,16 +41,17 @@ async def sequential_forward(
sequences = sequence_manager.make_sequence(start_index, end_index)
intermediate_inputs = []
done_sequences = []
outputs = inputs

while len(sequences) > 0:
while True:
span = sequences.pop(0)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
try:
span = sequences.pop(0)
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
inputs_and_prompts = [inputs, prompts[span.start : span.end]]

(outputs,) = await run_expert_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)
(outputs,) = await run_remote_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)

assert isinstance(outputs, torch.Tensor)
assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"
@@ -171,7 +92,7 @@ async def sequential_backward(
span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
try:
stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
grad_outputs, *span_grad_prompts = await run_expert_backward(
grad_outputs, *span_grad_prompts = await run_remote_backward(
span_uids, stub, sequence_manager.rpc_info, inputs, grad_outputs, prompts[span.start : span.end]
)
grad_outputs = [grad_outputs]
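The reworked loop in `sequential_forward` pops the next span before entering the `try` block and keeps iterating until the whole route is done, so a failing server only costs the unfinished tail of the route. A simplified synchronous sketch of that retry pattern (not the actual Petals code; `make_route` and `forward_on` are placeholders, and the real except branch lies outside the visible hunk):

```python
import random
from typing import Callable, List, Tuple


def fault_tolerant_forward(
    num_blocks: int,
    make_route: Callable[[int, int], List[Tuple[int, int, str]]],
    forward_on: Callable[[str, int, int, float], float],
    x: float,
) -> float:
    """Run blocks [0, num_blocks) span by span, re-planning the unfinished tail on failure."""
    route = make_route(0, num_blocks)
    block_index = 0
    while block_index < num_blocks:
        start, end, peer = route.pop(0)
        try:
            x = forward_on(peer, start, end, x)
            block_index = end  # this span succeeded, advance past it
        except RuntimeError:
            route = make_route(block_index, num_blocks)  # rebuild only the blocks not yet done
    return x


# Toy demo: flaky "servers" that fail ~30% of the time; each block adds 1.0 to the activation.
def make_route(start: int, end: int) -> List[Tuple[int, int, str]]:
    return [(i, i + 1, f"peer{i}") for i in range(start, end)]


def flaky_forward(peer: str, start: int, end: int, x: float) -> float:
    if random.random() < 0.3:
        raise RuntimeError(f"{peer} dropped the request")
    return x + (end - start)


print(fault_tolerant_forward(8, make_route, flaky_forward, 0.0))  # prints 8.0 once all spans succeed
```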
14 changes: 14 additions & 0 deletions src/client/spending_policy.py
@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod

from hivemind.proto.runtime_pb2 import ExpertRequest


class SpendingPolicyBase(ABC):
@abstractmethod
def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
pass


class NoSpendingPolicy(SpendingPolicyBase):
def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
return 0.0
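`SpendingPolicyBase` is the client-side hook for deciding how many priority points a given request is worth, and `NoSpendingPolicy` opts out by always returning zero, matching the PR's default-to-FIFO behaviour. A hypothetical alternative policy, not part of this PR, that scales points with the serialized size of the request:

```python
from hivemind.proto.runtime_pb2 import ExpertRequest

from src.client.spending_policy import SpendingPolicyBase


class SizeProportionalSpendingPolicy(SpendingPolicyBase):
    """Hypothetical policy: spend one point per kilobyte of tensors attached to the request."""

    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
        payload_bytes = sum(tensor.ByteSize() for tensor in request.tensors)
        return payload_bytes / 1024
```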