torch store utility example with monarch app (toy). #6

Conversation
This is an awesome starting point for the RL example, and I think one that we absolutely need to have in the repo, thanks so much for getting started on this!
Can we also add some assertions so we can guarantee data is actually getting moved around? Thank you so much!
example/torchstore_rl.py (Outdated)

    learner = await learner_mesh.spawn("learner", Learner, store)
    generators = await gen_mesh.spawn("generator", Generator, store)

    generation_stream = generators.generate.stream(torch.randn(4, 4, device="cuda"))
Cool! I haven't actually used stream before. Mind explaining the usage here?
example/torchstore_rl.py (Outdated)

    generation_stream = generators.generate.stream(torch.randn(4, 4, device="cuda"))
    for step in range(3):
        generations = [gen.get() for gen in generation_stream]
        loss, rewards = await learner.step.call_one(generations)
hm, one possible improvement here is that in single-controller land we don't want actual data to be moved into the controller (since it's slow, and may not scale). Can we orchestrate this such that "generate" pushes the data directly to the trainer? Another idea would be to use torchstore here as well
Yes, we can use torchstore itself for that interaction. But let's try that out as a follow-up(?).
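For the follow-up, a minimal sketch of what that interaction could look like, assuming monarch's Actor/endpoint pattern used by this example; the store.put/store.get calls and the "generations/{step}" key scheme are hypothetical, not a confirmed torchstore API:

    import torch
    from monarch.actor import Actor, endpoint  # assumed import path

    class Generator(Actor):
        def __init__(self, store):
            self.store = store
            self.model = torch.nn.Linear(4, 4, device="cuda")

        @endpoint
        async def generate(self, step: int) -> None:
            logits = self.model(torch.randn(4, 4, device="cuda"))
            # Hypothetical: push the generation into the store so it
            # never passes through the single controller.
            await self.store.put(f"generations/{step}", logits)

    class Learner(Actor):
        @endpoint
        async def step(self, step: int):
            # Hypothetical: pull the generation straight from the store.
            generations = await self.store.get(f"generations/{step}")
            ...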
example/torchstore_rl.py (Outdated)

        generations = [gen.get() for gen in generation_stream]
        loss, rewards = await learner.step.call_one(generations)
        print(f"step: {step}, loss: {loss}, rewards: {rewards}")
        generation_stream = generators.generate.stream(torch.randn(4, 4, device="cuda"))
maybe this one should be .call?
Do we know if generation_stream is guaranteed to finish before the next call?
@allenwang28 I do not fully understand the guarantees of the stream call. Can you help? Especially, is it possible for the weights update to have a data race with the reward generation (stream)? Thanks!
stream is the same thing as call, but it returns a generator rather than a ValueMesh. The main idea here was supposed to be to start the generations while the learner steps to simulate off-by-1, but I realize that it's not actually doing what we want lol
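To make that concrete, a minimal sketch of the intended off-by-one overlap, assuming the .stream()/.call_one() semantics described above; whether the learner's weight push can race the next round's weight fetch is exactly the open question in this thread:

    generation_stream = generators.generate.stream(torch.randn(4, 4, device="cuda"))
    for step in range(3):
        # Drain the in-flight round first; nothing from round N-1 can
        # race the learner's weight push below.
        generations = [gen.get() for gen in generation_stream]
        # Start round N now; these generators intentionally see the
        # previous weights version (the off-by-one).
        generation_stream = generators.generate.stream(torch.randn(4, 4, device="cuda"))
        # Learner steps on round N-1 (and pushes new weights) while
        # round N generates.
        loss, rewards = await learner.step.call_one(generations)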
This is a bit tricky to do, as there is no source of truth at any particular actor.

PTAL. @LucasLLC Here is the output of the run. We can see the weights being updated.
Looking good! Tysm Pradeep, this will be a super helpful example to base off of
    )
    # Fetch weights from torch.store
    await get_state_dict(
        self.store, key="toy_app", user_state_dict=self.model.state_dict()
Not related to this PR, but just wondering out loud if it makes sense to use different versions for each update in this example. Might be a good topic to discuss and circle back to.
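For illustration, versioning could be as small as suffixing the key, assuming the get_state_dict call from this diff; the f"toy_app/v{version}" scheme is a made-up convention, not a torchstore one:

    # Hypothetical: fetch a specific, immutable weights snapshot
    # instead of the mutable "toy_app" key.
    version = step  # or any monotonically increasing counter
    await get_state_dict(
        self.store, key=f"toy_app/v{version}", user_state_dict=self.model.state_dict()
    )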
        self.store, key="toy_app", user_state_dict=self.model.state_dict()
    )
    print(
        "[generator {}] new weights: {}".format(self.index, self.model.state_dict())
If we could turn this into a test imo the PR becomes even more powerful, but again not blocking :)
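Seconding the earlier ask for assertions, a minimal sketch of what such a test could check, assuming the Learner/Generator actors from this example; the get_weights endpoints are hypothetical helpers, not part of the current diff:

    import torch

    async def test_weights_propagate(learner, generators):
        # Hypothetical inspection endpoints on both actors.
        learner_sd = await learner.get_weights.call_one()
        generator_sd = await generators.get_weights.call_one()
        for name, tensor in learner_sd.items():
            # Generator weights live on cuda:0 in the logs below.
            assert torch.allclose(tensor, generator_sd[name].cpu()), (
                f"generator weights for {name!r} did not match the learner"
            )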
    learner = await learner_mesh.spawn("learner", Learner, store)
    generators = await gen_mesh.spawn("generator", Generator, store)

    logits, reward = await generators.generate.call_one(
Lgtm! I see we have a follow-up to make this example distributed, and just in case it's a new semantic I usually do something like:

    for _, val in await generators.generate.call():
        logits, reward = val
        ...
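For contrast with the single-actor form used in this PR, assuming the endpoint semantics discussed above (.call_one() returns one actor's value, .call() returns one value per actor in the mesh):

    # Single generator, as in this PR: one (logits, reward) pair back.
    logits, reward = await generators.generate.call_one(torch.randn(4, 4, device="cuda"))

    # Whole mesh, for the distributed follow-up: one pair per generator.
    for _, (logits, reward) in await generators.generate.call(torch.randn(4, 4, device="cuda")):
        ...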
1. Single learner and generator. No resharding.
2. Next step is to introduce DTensor and resharding to this example.
3. TODO: currently learner/generator run synchronously with each other.

Output of the example run is as follows. You can see the weights getting updated between steps.
(newts) [pradeepfdo@devvm2487.eag0 ~/torchstore (actor_example)]$ python example/torchstore_rl.py
[0] [learner] weights: OrderedDict({'weight': tensor([[-0.2220, -0.1694, -0.2595, -0.3997],
[0] [ 0.0110, -0.1343, 0.0462, 0.1690],
[0] [ 0.0455, 0.4084, -0.2703, -0.4669],
[0] [-0.2401, -0.2885, 0.4115, 0.1186]])})
step: 0, loss: -87.97368621826172, rewards: -3.6650943756103516
[0] [generator 0] original weights: OrderedDict({'weight': tensor([[ 0.0074, 0.1013, -0.2601, 0.1455],
[0] [-0.4886, -0.3607, -0.2561, 0.1958],
[0] [-0.0184, 0.1435, 0.2396, 0.1962],
[0] [-0.3341, -0.3102, 0.4408, 0.3402]], device='cuda:0')})
[0] [generator 0] new weights: OrderedDict({'weight': tensor([[-0.2220, -0.1694, -0.2595, -0.3997],
[0] [ 0.0110, -0.1343, 0.0462, 0.1690],
[0] [ 0.0455, 0.4084, -0.2703, -0.4669],
[0] [-0.2401, -0.2885, 0.4115, 0.1186]], device='cuda:0')})
[0] [learner] weights: OrderedDict({'weight': tensor([[-0.2210, -0.1699, -0.2605, -0.3999],
[0] [ 0.0101, -0.1341, 0.0471, 0.1688],
[0] [ 0.0462, 0.4092, -0.2712, -0.4673],
[0] [-0.2400, -0.2876, 0.4114, 0.1196]])})
step: 1, loss: 47.551456451416016, rewards: 2.125025987625122
[0] [generator 0] original weights: OrderedDict({'weight': tensor([[-0.2220, -0.1694, -0.2595, -0.3997],
[0] [ 0.0110, -0.1343, 0.0462, 0.1690],
[0] [ 0.0455, 0.4084, -0.2703, -0.4669],
[0] [-0.2401, -0.2885, 0.4115, 0.1186]], device='cuda:0')})
[0] [generator 0] new weights: OrderedDict({'weight': tensor([[-0.2210, -0.1699, -0.2605, -0.3999],
[0] [ 0.0101, -0.1341, 0.0471, 0.1688],
[0] [ 0.0462, 0.4092, -0.2712, -0.4673],
[0] [-0.2400, -0.2876, 0.4114, 0.1196]], device='cuda:0')})
[0] [learner] weights: OrderedDict({'weight': tensor([[-0.2211, -0.1701, -0.2615, -0.4000],
[0] [ 0.0096, -0.1341, 0.0478, 0.1686],
[0] [ 0.0461, 0.4100, -0.2721, -0.4673],
[0] [-0.2394, -0.2878, 0.4120, 0.1200]])})
step: 2, loss: -18.45710563659668, rewards: -0.824840784072876
[0] [generator 0] original weights: OrderedDict({'weight': tensor([[-0.2210, -0.1699, -0.2605, -0.3999],
[0] [ 0.0101, -0.1341, 0.0471, 0.1688],
[0] [ 0.0462, 0.4092, -0.2712, -0.4673],
[0] [-0.2400, -0.2876, 0.4114, 0.1196]], device='cuda:0')})
[0] [generator 0] new weights: OrderedDict({'weight': tensor([[-0.2211, -0.1701, -0.2615, -0.4000],
[0] [ 0.0096, -0.1341, 0.0478, 0.1686],
[0] [ 0.0461, 0.4100, -0.2721, -0.4673],
[0] [-0.2394, -0.2878, 0.4120, 0.1200]], device='cuda:0')})
done