
Conversation

@pradeepfn (Contributor) commented Aug 11, 2025

1. Single learner and generator. No resharding.
2. Next step is to introduce DTensor and resharding to this example.
3. TODO: currently the learner and generator run synchronously with each other.
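
For context, the driver loop in the example is roughly the following (condensed from the hunks quoted in the review below; this is the initial stream-based revision, the merged version switches generation to a plain call_one, and mesh/store setup is elided):

learner = await learner_mesh.spawn("learner", Learner, store)
generators = await gen_mesh.spawn("generator", Generator, store)

# generators read the latest weights from torchstore before producing rollouts
generation_stream = generators.generate.stream(torch.randn(4, 4, device="cuda"))
for step in range(3):
    generations = [gen.get() for gen in generation_stream]
    # learner consumes rollouts, steps the optimizer, and publishes new weights
    loss, rewards = await learner.step.call_one(generations)
    print(f"step: {step}, loss: {loss}, rewards: {rewards}")
    generation_stream = generators.generate.stream(torch.randn(4, 4, device="cuda"))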

The output of the example run is as follows. You can see the weights getting updated between steps.

(newts) [pradeepfdo@devvm2487.eag0 ~/torchstore (actor_example)]$ python example/torchstore_rl.py
[0] [learner] weights: OrderedDict({'weight': tensor([[-0.2220, -0.1694, -0.2595, -0.3997],
[0] [ 0.0110, -0.1343, 0.0462, 0.1690],
[0] [ 0.0455, 0.4084, -0.2703, -0.4669],
[0] [-0.2401, -0.2885, 0.4115, 0.1186]])})
step: 0, loss: -87.97368621826172, rewards: -3.6650943756103516
[0] [generator 0] original weights: OrderedDict({'weight': tensor([[ 0.0074, 0.1013, -0.2601, 0.1455],
[0] [-0.4886, -0.3607, -0.2561, 0.1958],
[0] [-0.0184, 0.1435, 0.2396, 0.1962],
[0] [-0.3341, -0.3102, 0.4408, 0.3402]], device='cuda:0')})
[0] [generator 0] new weights: OrderedDict({'weight': tensor([[-0.2220, -0.1694, -0.2595, -0.3997],
[0] [ 0.0110, -0.1343, 0.0462, 0.1690],
[0] [ 0.0455, 0.4084, -0.2703, -0.4669],
[0] [-0.2401, -0.2885, 0.4115, 0.1186]], device='cuda:0')})
[0] [learner] weights: OrderedDict({'weight': tensor([[-0.2210, -0.1699, -0.2605, -0.3999],
[0] [ 0.0101, -0.1341, 0.0471, 0.1688],
[0] [ 0.0462, 0.4092, -0.2712, -0.4673],
[0] [-0.2400, -0.2876, 0.4114, 0.1196]])})
step: 1, loss: 47.551456451416016, rewards: 2.125025987625122
[0] [generator 0] original weights: OrderedDict({'weight': tensor([[-0.2220, -0.1694, -0.2595, -0.3997],
[0] [ 0.0110, -0.1343, 0.0462, 0.1690],
[0] [ 0.0455, 0.4084, -0.2703, -0.4669],
[0] [-0.2401, -0.2885, 0.4115, 0.1186]], device='cuda:0')})
[0] [generator 0] new weights: OrderedDict({'weight': tensor([[-0.2210, -0.1699, -0.2605, -0.3999],
[0] [ 0.0101, -0.1341, 0.0471, 0.1688],
[0] [ 0.0462, 0.4092, -0.2712, -0.4673],
[0] [-0.2400, -0.2876, 0.4114, 0.1196]], device='cuda:0')})
[0] [learner] weights: OrderedDict({'weight': tensor([[-0.2211, -0.1701, -0.2615, -0.4000],
[0] [ 0.0096, -0.1341, 0.0478, 0.1686],
[0] [ 0.0461, 0.4100, -0.2721, -0.4673],
[0] [-0.2394, -0.2878, 0.4120, 0.1200]])})
step: 2, loss: -18.45710563659668, rewards: -0.824840784072876
[0] [generator 0] original weights: OrderedDict({'weight': tensor([[-0.2210, -0.1699, -0.2605, -0.3999],
[0] [ 0.0101, -0.1341, 0.0471, 0.1688],
[0] [ 0.0462, 0.4092, -0.2712, -0.4673],
[0] [-0.2400, -0.2876, 0.4114, 0.1196]], device='cuda:0')})
[0] [generator 0] new weights: OrderedDict({'weight': tensor([[-0.2211, -0.1701, -0.2615, -0.4000],
[0] [ 0.0096, -0.1341, 0.0478, 0.1686],
[0] [ 0.0461, 0.4100, -0.2721, -0.4673],
[0] [-0.2394, -0.2878, 0.4120, 0.1200]], device='cuda:0')})
done

meta-cla bot added the CLA Signed label on Aug 11, 2025
@pradeepfn pradeepfn requested a review from LucasLLC August 11, 2025 21:00
@LucasLLC (Contributor) left a comment:

This is an awesome starting point for the RL example, and I think one we absolutely need to have in the repo. Thanks so much for getting started on this!

Can we also add some assertions so we can guarantee data is actually getting moved around? Thank you so much!

learner = await learner_mesh.spawn("learner", Learner, store)
generators = await gen_mesh.spawn("generator", Generator, store)

generation_stream = generators.generate.stream(torch.randn(4, 4, device="cuda"))
Contributor:
Cool! I haven't actually used stream before. Mind explaining the usage here?

generation_stream = generators.generate.stream(torch.randn(4, 4, device="cuda"))
for step in range(3):
generations = [gen.get() for gen in generation_stream]
loss, rewards = await learner.step.call_one(generations)
Contributor:
Hm, one possible improvement: in single-controller land we don't want actual data to be moved into the controller (since it's slow and may not scale). Can we orchestrate this such that "generate" pushes the data directly to the trainer? Another idea would be to use torchstore here as well.

@pradeepfn (Contributor Author), Aug 11, 2025:

Yes, we can use torchstore itself for that interaction. But let's try that out as a follow-up(?).
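
A rough sketch of what that could look like (hypothetical: store.put / store.get are placeholder names, not confirmed torchstore APIs; only get_state_dict is used in this PR):

# generator side: publish rollouts under a per-step key instead of returning them
async def publish_rollout(store, step, logits, reward):
    await store.put(f"rollout/{step}", (logits.detach().cpu(), float(reward)))

# learner side: pull rollouts directly from the store, bypassing the controller
async def fetch_rollout(store, step):
    logits, reward = await store.get(f"rollout/{step}")
    return logits, reward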

generations = [gen.get() for gen in generation_stream]
loss, rewards = await learner.step.call_one(generations)
print(f"step: {step}, loss: {loss}, rewards: {rewards}")
generation_stream = generators.generate.stream(torch.randn(4, 4, device="cuda"))
Contributor:
maybe this one should be .call?

Contributor:
Do we know if generation_stream is guaranteed to finish before the next call?

Contributor Author:
@allenwang28 I do not fully understand the guarantees of the stream call. Can you help? In particular, is it possible for the weight update to race with the reward generation (stream)? Thanks!

Contributor:
generate is the same thing as call, but it returns a generator rather than a ValueMesh. The main idea here was supposed to be to start the generations while the learner steps to simulate off-by-1, but I realize that it's not actually doing what we want lol
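
Something closer to the intended off-by-one could look like this (rough sketch, untested; it assumes the endpoint call can be wrapped in an asyncio task, and it still has the weight-update race discussed above):

import asyncio
import torch

async def run(generators, learner, num_steps=3):
    async def gen_once():
        return await generators.generate.call_one(torch.randn(4, 4, device="cuda"))

    pending = asyncio.create_task(gen_once())
    for step in range(num_steps):
        generations = await pending
        # kick off generation for step+1 before the learner step for this step finishes
        if step + 1 < num_steps:
            pending = asyncio.create_task(gen_once())
        loss, rewards = await learner.step.call_one(generations)
        print(f"step: {step}, loss: {loss}, rewards: {rewards}")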

@pradeepfn (Contributor Author):

> Can we also add some assertions so we can guarantee data is actually getting moved around? Thank you so much!

This is a bit tricky to do, as there is no single source of truth at any particular actor.
I rely on print statements to verify the weight updates. See my updated PR summary.
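
One possible check (a sketch only; it assumes we add a hypothetical get_weights endpoint on both actors that returns a CPU copy of model.state_dict()):

learner_sd = await learner.get_weights.call_one()
gen_sd = await generators.get_weights.call_one()
for name, t in learner_sd.items():
    assert torch.allclose(t, gen_sd[name].cpu()), f"weight mismatch for {name}"

That would at least assert that the generator is serving the weights the learner just published.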

@pradeepfn pradeepfn requested a review from allenwang28 August 11, 2025 23:17
@pradeepfn (Contributor Author):

PTAL. @LucasLLC

Here is the output of the run. We can see the weights being updated.

(newts) [pradeepfdo@devvm2487.eag0 ~/torchstore (actor_example)]$ python example/torchstore_rl.py
[0] [learner] weights: OrderedDict({'weight': tensor([[ 0.4177, -0.4198, -0.4904, 0.4202],
[0] [ 0.3519, 0.4973, 0.0048, 0.1372],
[0] [-0.3005, -0.0588, -0.2192, -0.0132],
[0] [ 0.1890, 0.1302, 0.0713, -0.3678]])})
[0] [generator 0] original weights: OrderedDict({'weight': tensor([[-0.1823, 0.2535, -0.3074, 0.4840],
[0] [ 0.2224, -0.3184, -0.3903, 0.4777],
[0] [-0.2466, -0.0403, 0.3493, 0.2106],
[0] [ 0.4137, -0.1508, -0.4138, -0.0231]], device='cuda:0')})
[0] [generator 0] new weights: OrderedDict({'weight': tensor([[ 0.4177, -0.4198, -0.4904, 0.4202],
[0] [ 0.3519, 0.4973, 0.0048, 0.1372],
[0] [-0.3005, -0.0588, -0.2192, -0.0132],
[0] [ 0.1890, 0.1302, 0.0713, -0.3678]], device='cuda:0')})
[0] [learner] weights: OrderedDict({'weight': tensor([[ 0.4182, -0.4208, -0.4914, 0.4208],
[0] [ 0.3529, 0.4982, 0.0050, 0.1382],
[0] [-0.3015, -0.0596, -0.2193, -0.0126],
[0] [ 0.1898, 0.1296, 0.0722, -0.3687]])})
[0] [generator 0] original weights: OrderedDict({'weight': tensor([[ 0.4177, -0.4198, -0.4904, 0.4202],
[0] [ 0.3519, 0.4973, 0.0048, 0.1372],
[0] [-0.3005, -0.0588, -0.2192, -0.0132],
[0] [ 0.1890, 0.1302, 0.0713, -0.3678]], device='cuda:0')})
[0] [generator 0] new weights: OrderedDict({'weight': tensor([[ 0.4182, -0.4208, -0.4914, 0.4208],
[0] [ 0.3529, 0.4982, 0.0050, 0.1382],
[0] [-0.3015, -0.0596, -0.2193, -0.0126],
[0] [ 0.1898, 0.1296, 0.0722, -0.3687]], device='cuda:0')})
[0] [learner] weights: OrderedDict({'weight': tensor([[ 0.4182, -0.4204, -0.4921, 0.4210],
[0] [ 0.3537, 0.4982, 0.0056, 0.1384],
[0] [-0.3021, -0.0602, -0.2195, -0.0118],
[0] [ 0.1906, 0.1296, 0.0727, -0.3691]])})
[0] [generator 0] original weights: OrderedDict({'weight': tensor([[ 0.4182, -0.4208, -0.4914, 0.4208],
[0] [ 0.3529, 0.4982, 0.0050, 0.1382],
[0] [-0.3015, -0.0596, -0.2193, -0.0126],
[0] [ 0.1898, 0.1296, 0.0722, -0.3687]], device='cuda:0')})
[0] [generator 0] new weights: OrderedDict({'weight': tensor([[ 0.4182, -0.4204, -0.4921, 0.4210],
[0] [ 0.3537, 0.4982, 0.0056, 0.1384],
[0] [-0.3021, -0.0602, -0.2195, -0.0118],
[0] [ 0.1906, 0.1296, 0.0727, -0.3691]], device='cuda:0')})
done

@pradeepfn pradeepfn requested a review from LucasLLC August 18, 2025 15:27
@LucasLLC (Contributor) left a comment:

Looking good! Tysm Pradeep, this will be a super helpful example to base off of.

)
# Fetch weights from torch.store
await get_state_dict(
self.store, key="toy_app", user_state_dict=self.model.state_dict()
Contributor:
Not related to this PR, but just wondering out loud if it makes sense to use a different version for each update in this example. Might be a good topic to discuss and circle back to.
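
If we do explore that, a minimal version could just suffix the key with the step (sketch only; put_state_dict is assumed here to mirror get_state_dict and is not part of this PR):

# learner side: publish weights under a per-step key
await put_state_dict(self.store, key=f"toy_app/v{step}", user_state_dict=self.model.state_dict())

# generator side: load exactly the version it was asked to serve
await get_state_dict(self.store, key=f"toy_app/v{step}", user_state_dict=self.model.state_dict())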

self.store, key="toy_app", user_state_dict=self.model.state_dict()
)
print(
"[generator {}] new weights: {}".format(self.index, self.model.state_dict())
Contributor:
If we could turn this into a test, imo the PR becomes even more powerful, but again not blocking :)

learner = await learner_mesh.spawn("learner", Learner, store)
generators = await gen_mesh.spawn("generator", Generator, store)

logits, reward = await generators.generate.call_one(
Contributor:
Lgtm! I see we have a follow-up to make this example distributed, and just in case it's a new semantic, I usually do something like:

for _, val in generators.generate.call_one():
    logits, reward = val
    ...

@pradeepfn pradeepfn merged commit 7e51b5e into main Aug 18, 2025
1 check passed
