-
Notifications
You must be signed in to change notification settings - Fork 8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add weight
kwarg to AlchemiscaleClient.action_tasks
method
#209
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work @ianmkenney! In reviewing I realized we should pivot our approach from this a bit, using some existing functionality in Neo4jStore
instead of modifying its action_tasks
method.
Some additional points/questions:
- We need an explicit test for
Task
weight setting via theAlchemiscaleClient
intests/integration/interface/client/test_client.py
. - Let's change the default
weight
on theAlchemiscaleClient.action_tasks
method toNone
, and have this mean that no change is made toweight
s forTask
s already actioned, and newly-actionedTask
s retain the default of 0.5.
alchemiscale/interface/api.py
Outdated
taskhub_sk = n4js.get_taskhub(sk) | ||
actioned_sks = n4js.action_tasks(tasks, taskhub_sk) | ||
actioned_sks = n4js.action_tasks(tasks, taskhub_sk, weight=weight) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In reviewing the Neo4jStateStore
codebase, I realized that we already have a set_task_weights
method. Instead of modifying Neo4jStore.action_tasks
to also set the weight away from the default of 0.5, let's use the set_task_weights
method to accomplish this following our use of action_tasks
here.
alchemiscale/interface/api.py
Outdated
if not 0 <= weight <= 1: | ||
raise HTTPException( | ||
status_code=status.HTTPS_400_BAD_REQUEST, | ||
detail=f"weight must between 0.0 and 1.0 (inclusive), the " | ||
"provided weight was: {weight}", | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As we did in #212, can we put these bounds checks in Neo4jStore.set_task_weights
? Then we'll try..except
our call to that method and raise an HTTPException
in the API endpoint here.
alchemiscale/interface/api.py
Outdated
@@ -490,14 +490,22 @@ def action_tasks( | |||
network_scoped_key, | |||
*, | |||
tasks: List[ScopedKey] = Body(embed=True), | |||
weight: float = Body(embed=True), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since Neo4jStore.set_task_weights
can also handle setting a different weight per Task
, can we accept a list of weight
s here in addition to accepting a float
? Could then make the typing: Union[List[float],float]
. We would then build a dict in this method of {task_sk
: weight
} key-value pairs from the two lists tasks
and weight
.
alchemiscale/storage/statestore.py
Outdated
CREATE (th)-[ar:ACTIONS {{weight: 1.0}}]->(task) | ||
CREATE (th)-[ar:ACTIONS {{weight: {weight}}}]->(task) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can revert changes to this method, since let's go ahead and use set_task_weights
instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd assume we'd actually want this at 0.5 since it seems to be the new default. Or is there a reason to make it 1.0 here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, yes. Make the new default 0.5.
alchemiscale/interface/client.py
Outdated
@@ -612,7 +612,7 @@ def get_transformation_status( | |||
return status_counts | |||
|
|||
def action_tasks( | |||
self, tasks: List[ScopedKey], network: ScopedKey | |||
self, tasks: List[ScopedKey], network: ScopedKey, weight: float = 0.5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make this method also take weight
as a list of floats, making the typing signature Union[List[float],float]
. This will allow users to set a whole swath of Task
weights with a single call to this method, which would be more performant overall. I think the rest of this method can then remain unchanged.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's also set the default value of weight
to None
. What this means for users:
- using
action_tasks
without specifyingweight
will set the weight to 0.5 for newly-actionedTask
s;Task
s already actioned on the givenAlchemicalNetwork
will experience no change in their weights. - using
action_tasks
while specifyingweight
will set the weights of the actionedTask
s, even if they were already actioned.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple small notes.
@@ -918,12 +918,12 @@ def test_get_set_weights(self, n4js: Neo4jStore, network_tyk2, scope_test): | |||
|
|||
# weights should all be the default 1.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix comment.
@@ -918,12 +918,12 @@ def test_get_set_weights(self, n4js: Neo4jStore, network_tyk2, scope_test): | |||
|
|||
# weights should all be the default 1.0 | |||
weights = n4js.get_task_weights(task_sks, taskhub_sk) | |||
assert all([w == 1.0 for w in weights]) | |||
assert all([w == 0.5 for w in weights]) | |||
|
|||
# set weights on the tasks to be all 10 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix comment.
* Test the use of including weights with the action_task method in the AlchemiscaleClient
…2-weight-action_tasks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking great @ianmkenney! I think we need some additional test coverage for the AlchemiscaleClient.action_tasks
usage, but otherwise looking very good!
Once you've added this additional coverage, we should be good to merge!
network_sk, | ||
weight, | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a test that calls action_tasks
with weight
None
and checks with n4j_preloaded.get_task_weights
that the weights are unchanged?
Can you also check that if you call action_tasks
on already-actioned Task
s with new weight
s that these get set appropriately?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
task_sks = user_client.create_tasks(transformation_sk, count=3) | ||
|
||
if isinstance(weight, list): | ||
weight = weight * len(task_sks) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will create a list of all the same weight
s, which isn't bad, but I think it would be more thorough to test if we set a list of different weight
s that this works as expected. You can use n4js_preloaded.get_task_weights
to get at this information.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reworked how this is done
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## main #209 +/- ##
==========================================
+ Coverage 81.94% 82.03% +0.09%
==========================================
Files 22 22
Lines 2780 2795 +15
==========================================
+ Hits 2278 2293 +15
Misses 502 502 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work @ianmkenney! Please merge if satisfied!
# task weights | ||
user_client.action_tasks(task_sks, network_sk, weight=None) | ||
|
||
task_weights = n4js.get_task_weights(task_sks, th_sk) | ||
assert task_weights == _weight |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Beautiful!
task_weights = n4js.get_task_weights(task_sks, th_sk) | ||
assert task_weights == _weight | ||
|
||
def test_action_tasks_update_weights( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice! Thank you!
[ | ||
(None, False), | ||
(1.0, False), | ||
([0.25, 0.5, 0.75], False), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Much better!
In this PR, the kwarg
weight
has been added to theAlchemiScaleClient.action_tasks
method, where its value is added to the data payload. Validation of the weight (restricted to be between 0 and 1, see #208 ) is handled at the API layer. TheNeo4jStore
also now has aweight
kwarg that defaults to 0.5 (note that the previous weight was 1 and was hard coded).fixes #202