Direct server-to-server communication during finetuning #560

Draft · wants to merge 4 commits into main
254 changes: 254 additions & 0 deletions examples/workbench_call_rpc_directly.ipynb
@@ -0,0 +1,254 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"id": "21e78d30",
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"from typing import Sequence, Tuple, Iterable, List\n",
"from tqdm.auto import trange\n",
"\n",
"import torch\n",
"import hivemind\n",
"import petals\n",
"\n",
"from petals.server.handler import TransformerConnectionHandler, split_for_streaming\n",
"from petals.client import RemoteSequenceManager, ClientConfig\n",
"from petals.client.remote_forward_backward import DEFAULT_MAX_MSG_SIZE, iter_as_aiter, aiter_with_timeout, deserialize_tensor_stream\n",
"from petals.data_structures import ModuleUID, PeerID, CHAIN_DELIMITER, UID_DELIMITER\n",
"from petals.utils.packaging import pack_args_kwargs, unpack_args_kwargs\n",
"\n",
"from hivemind.compression import serialize_torch_tensor\n",
"from hivemind.utils import MSGPackSerializer, nested_flatten\n",
"from hivemind.proto import runtime_pb2\n",
"\n",
"_END_OF_STREAM_KEY = \"_EOS\"\n",
"\n",
"\n",
"async def pack_as_expert_requests(uid, flat_tensors, codecs, metadata):\n",
" # Asynchronous serialization\n",
" loop = asyncio.get_running_loop()\n",
" serialized_tensors = await asyncio.gather(\n",
" *(\n",
" loop.run_in_executor(None, serialize_torch_tensor, tensor, compression)\n",
" for tensor, compression in zip(flat_tensors, codecs)\n",
" )\n",
" )\n",
"\n",
" parts = [\n",
" tensor_part for tensor in serialized_tensors\n",
" for tensor_part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)\n",
" ]\n",
" if len(parts) > 1:\n",
" serialized_metadata = MSGPackSerializer.dumps(metadata)\n",
" serialized_metadata_last_piece = MSGPackSerializer.dumps(dict(metadata, **{_END_OF_STREAM_KEY: True}))\n",
" \n",
" return [\n",
" runtime_pb2.ExpertRequest(\n",
" uid=uid, tensors=[tensor_part], \n",
" metadata=serialized_metadata if i != len(parts) - 1 else serialized_metadata_last_piece)\n",
" for i, tensor_part in enumerate(parts)\n",
" ]\n",
" \n",
"async def run_remote_forward_backward(\n",
" sequence_manager: RemoteSequenceManager,\n",
" peer_id: PeerID,\n",
" span_uids: Sequence[ModuleUID],\n",
" *args: torch.Tensor,\n",
" **kwargs: torch.Tensor,\n",
") -> Tuple[torch.Tensor, ...]:\n",
" \"\"\"\n",
" Serializes input tensors and calls \"rpc_forward_backward\" on a remote server.\n",
" Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198\n",
" but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.\n",
" \"\"\"\n",
" merged_uid = CHAIN_DELIMITER.join(span_uids)\n",
" stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, peer_id)\n",
" flat_tensors, args_structure = pack_args_kwargs(*args, **kwargs)\n",
" metadata = sequence_manager.get_request_metadata(\"rpc_forward\", args_structure, uids=span_uids, *args, peer_id=peer_id, **kwargs) #TODO fix metadata api\n",
" #codecs = sequence_manager.get_compression_codecs(peer_id, \"rpc_forward\", span_uids, *args, **kwargs)\n",
" codecs = [runtime_pb2.CompressionType.NONE for _ in args] #TODO replace with proper compression\n",
" flat_tensors = tuple(tensor.cpu().detach().requires_grad_(tensor.requires_grad) for tensor in flat_tensors)\n",
" args_structure = metadata.setdefault(\"args_structure\", args_structure)\n",
" if codecs is None:\n",
" codecs = [runtime_pb2.CompressionType.NONE] * len(flat_tensors)\n",
" else:\n",
" codecs = list(nested_flatten(codecs))\n",
" assert len(codecs) == len(flat_tensors), f\"got {len(flat_tensors)} tensors but {len(codecs)} compression codecs\"\n",
"\n",
"\n",
" # call RPC on remote server\n",
" size = sum(t.element_size() * t.nelement() for t in flat_tensors)\n",
" # Hotfix: we use \"// 2\" since hivemind==1.1.5 serializes bfloat16 tensors in float32, so they take 2x more space - TODO remove in the next PR\n",
" \n",
" ### HERE BEGINS INLINED REQUEST SENDER \n",
" # used to look like this:\n",
" # output_tensors = await _run_forward_part(\n",
" # merged_uid, serialized_tensors, stub, sequence_manager.config, metadata=metadata\n",
" # )\n",
" config = sequence_manager.config\n",
" assert _END_OF_STREAM_KEY not in metadata\n",
" forward_requests = await pack_as_expert_requests(merged_uid, flat_tensors, codecs, metadata)\n",
" backward_codecs = [runtime_pb2.CompressionType.NONE] #TODO replace with proper compression\n",
" fake_grad_outputs = torch.randn_like(flat_tensors[0])\n",
" _, backward_args_structure = pack_args_kwargs(args[0], fake_grad_outputs, *args[1:], **kwargs)\n",
" backward_metadata = dict(metadata, args_structure=backward_args_structure)\n",
" \n",
" grad_requests = await pack_as_expert_requests(merged_uid, (fake_grad_outputs,), backward_codecs, backward_metadata)\n",
" \n",
" received_outputs = asyncio.Event()\n",
"\n",
" async def iterate_inputs():\n",
" for request in forward_requests:\n",
" yield request\n",
" print(\"WAITING FOR OUTPUTS\")\n",
" await received_outputs.wait()\n",
" print(\"RECEIVED OUTPUTS - SENDING GRADS\")\n",
" for request in grad_requests:\n",
" yield request\n",
" print(\"SENT GRADS\")\n",
"\n",
" async def _wrap_input_stream(stream):\n",
" async for expert_request in stream:\n",
" yield expert_request\n",
" if not expert_request.metadata:\n",
" continue #TODO write more generally\n",
" metadata = MSGPackSerializer.loads(expert_request.metadata)\n",
" print(metadata)\n",
" if metadata.get(_END_OF_STREAM_KEY):\n",
" break\n",
"\n",
" print(\"CALLING stub.rpc_forward_stream on serialized inputs\", iterate_inputs())\n",
" outputs_stream = await asyncio.wait_for(stub.rpc_forward_backward_stream(iterate_inputs()), config.connect_timeout)\n",
" outputs_stream = aiter_with_timeout(outputs_stream, config.request_timeout)\n",
" \n",
" output_hidden_states = await deserialize_tensor_stream(msg.tensors async for msg in _wrap_input_stream(outputs_stream))\n",
" received_outputs.set()\n",
"\n",
" grad_inputs = await deserialize_tensor_stream(msg.tensors async for msg in _wrap_input_stream(outputs_stream))\n",
" print(\"RECEIVED GRAD INPUTS\")\n",
" #TODOreturn output_hidden_states, grads\n",
"\n",
" ####\n",
" \n",
" # backward compatibility: ensure requires_grad; remove after https://github.com/learning-at-home/hivemind/pull/591\n",
" requires_grad = any(tensor.requires_grad for tensor in flat_tensors)\n",
" output_tensors = [tensor.requires_grad_(requires_grad) for tensor in output_hidden_states]\n",
" return output_tensors, grad_inputs\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1c47c89a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Mar 17 18:37:25.661 [\u001b[1m\u001b[34mINFO\u001b[0m] Make sure you follow the LLaMA's terms of use: https://bit.ly/llama2-license for LLaMA 2, https://bit.ly/llama-license for LLaMA 1\n",
"Mar 17 18:37:25.661 [\u001b[1m\u001b[34mINFO\u001b[0m] Using DHT prefix: TinyLLama-v0-hf\n",
"100%|██████████| 1/1 [00:00<00:00, 26.19it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CALLING stub.rpc_forward_stream on serialized inputs <async_generator object run_remote_forward_backward.<locals>.iterate_inputs at 0x75eb8d134d60>\n",
"WAITING FOR OUTPUTS\n",
"{'_EOS': True}\n",
"RECEIVED OUTPUTS - SENDING GRADS\n",
"SENT GRADS\n",
"RECEIVED GRAD INPUTS\n",
"outputs: tensor([[[-0.0835, 0.3027, 0.2217, ..., 1.1719 ...\n",
"It works!\n",
"shutting down\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"INITIAL_PEERS = ['/ip4/127.0.0.1/tcp/1337/p2p/QmRTdR9XmTHNXKiwtqRJ4i7tNofnmFrxkufBefguZUyXej']\n",
"peer_id_string = INITIAL_PEERS[0].split(\"/\")[-1]\n",
"model_name = \"Maykeye/TinyLLama-v0\"\n",
"\n",
"model_config = petals.DistributedLlamaConfig.from_pretrained(model_name)\n",
"block_uids = [\n",
" f\"{model_config.dht_prefix}{UID_DELIMITER}{i}\"\n",
" for i in range(model_config.num_hidden_layers)\n",
"]\n",
"\n",
"block_in_use = block_uids[0:2]\n",
"\n",
"try:\n",
" dht = hivemind.DHT(start=True, client_mode=True, initial_peers=INITIAL_PEERS)\n",
" sequence_manager = petals.RemoteSequenceManager(model_config, block_uids, dht=dht)\n",
" sequence_manager.rpc_info\n",
" p2p = await dht.replicate_p2p()\n",
" \n",
" dummy_inputs = [\n",
" torch.rand(1, 128, model_config.hidden_size, dtype=model_config.torch_dtype),\n",
" torch.empty(0, dtype=model_config.torch_dtype),\n",
" ]\n",
" peer_id = hivemind.PeerID.from_base58(peer_id_string)\n",
" for i in trange(1):\n",
" (outputs,), grads = await run_remote_forward_backward(sequence_manager, peer_id, block_in_use, *dummy_inputs)\n",
" print('outputs:', repr(outputs)[:50], '...')\n",
" print(\"It works!\")\n",
"\n",
"finally:\n",
" print(\"shutting down\")\n",
" await p2p.shutdown()\n",
" dht.shutdown() # it is okay to remove this clause, but you will be summoning a horde of daemons as you debug"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Binary file added server1.id
Binary file not shown.
76 changes: 76 additions & 0 deletions src/petals/server/handler.py
@@ -590,3 +590,79 @@ async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) ->
        result.update(block_info)

        return runtime_pb2.ExpertInfo(serialized_info=MSGPackSerializer.dumps(result))

    @staticmethod
    async def _read_until_eos(stream):
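        # Yield messages from one logical client stream, stopping after the message whose
        # metadata carries the _EOS flag; the underlying stream stays open for the next phase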
        while True:
            expert_request = await anext(stream)
            yield expert_request
            if not expert_request.metadata:
                continue
            metadata = MSGPackSerializer.loads(expert_request.metadata)
            print(metadata)
            if metadata.get("_EOS"):
                break

    async def rpc_forward_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
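        """
        Combined forward + backward pass over a single bidirectional stream: read inputs until
        _EOS, run the forward pass and stream hidden states back, then read grad_outputs until
        _EOS and stream grad_inputs back. Keeping one handler alive for both halves lets the
        server reuse the inputs it received for the backward pass.
        """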
        async with timeout(self.request_timeout):

            # Parse requests and prepare backends
            uid_str, flat_inputs, metadata = await self._gather_inputs(self._read_until_eos(requests), context)
            requested_uids = self._check_uids(uid_str)
            self._log_request("rpc_forward_backward_stream", requested_uids, context)

            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
            active_adapter = self._get_active_adapter(metadata)
            points = metadata.get("points", 0)
            args_structure = metadata.get("args_structure")
            assert isinstance(
                points, (float, int)
            ), f"rpc_forward_backward_stream should have number of points as number or None, got {points}"

            print(f"{requested_backends=}, {active_adapter=}, {points=}, {args_structure=}")

            hidden_states = await run_rpc_forward(
                *flat_inputs,
                requested_backends=requested_backends,
                prioritizer=self._prioritizer,
                active_adapter=active_adapter,
                points=points,
                args_structure=args_structure,
            )

            # Stream hidden states back; only the last part carries the _EOS marker
            # (tagging every part would make the client stop reading after the first one)
            output_parts = [
                part
                for tensor in self._serialize_outputs(hidden_states, requested_backends, metadata)
                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
            ]
            eos_metadata = MSGPackSerializer.dumps({"_EOS": True})
            for i, part in enumerate(output_parts):
                yield runtime_pb2.ExpertResponse(
                    tensors=[part], metadata=eos_metadata if i == len(output_parts) - 1 else b""
                )


            # --- backward half: read grad_outputs from the same request stream ---
            new_uid_str, flat_extra_inputs, extra_metadata = await self._gather_inputs(self._read_until_eos(requests), context)
            backward_args_structure = extra_metadata.get("args_structure")
            assert len(flat_extra_inputs) == 1, "expected exactly one grad_outputs tensor"
            assert new_uid_str == uid_str, "the backward half must address the same chain of blocks"
            print("I solemnly swear to think about how to use extra_metadata for pushing when it comes to this")
            grad_outputs, = flat_extra_inputs

            print("HERE!")

            print("FLAT INPUTS", flat_inputs)
            print("GRAD OUTPUTS", grad_outputs)
            print(backward_args_structure)

            grads = await run_rpc_backward(
                flat_inputs[0],
                grad_outputs,
                *flat_inputs[1:],
                requested_backends=requested_backends,
                prioritizer=self._prioritizer,
                active_adapter=active_adapter,
                points=points,
                args_structure=backward_args_structure,
            )

            # Split the serialized grad_inputs for streaming and respond
            for tensor in self._serialize_grads(grads, requested_backends, metadata):
                for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE):
                    print("SENDING GRADS:", part)
                    yield runtime_pb2.ExpertResponse(tensors=[part])