Foundations for Multi-node support #19
torchstore/__init__.py
Outdated
from torchstore.store import MultiProcessStore

async def create_store(num_hosts=1) -> "LocalClient":
    """Initializes the global store, and returns a local client.
Interesting use of a comma.
> Interesting use of a comma.
very async
This is incredible! I know this is still WIP but couldn't wait to take a look. Left some comments mainly re readability. This is a really good starting point to build upon to support tensor parallelism etc.
torchstore/controller.py
Outdated
        raise NotImplementedError()

    @classmethod
    def get_client_id(cls):
This probably shouldn't be called get_client_id(). It's very misleading.
If I understand correctly, this actually specifies the id of the VOLUME the current worker should connect to.
Suggested change:
- def get_client_id(cls):
+ def get_remote_volume_id(cls):
+     """Returns the id of the StorageVolume the current worker should connect to."""
That's actually a great point!
The way I was thinking about this was each client has an ID, and the policy maps that ID to the ID of a storage volume in the 'select' method. This only happens on put. I added some comments, curious on your thoughts if you still think the refactor would make sense?
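The mapping described in this reply can be sketched roughly as follows. This is an illustration only, not torchstore's code: `SketchStrategy`, its `volume_ids` argument, and the modulo rule inside `select` are assumptions. Only the names `get_client_id`, `select`, and the `LOCAL_RANK` variable come from the thread.

```python
import os
from typing import List


class SketchStrategy:
    """Hypothetical strategy: maps a client id to a storage volume id."""

    def __init__(self, volume_ids: List[str]) -> None:
        if not volume_ids:
            raise ValueError("at least one storage volume id is required")
        self.volume_ids = volume_ids

    @classmethod
    def get_client_id(cls) -> str:
        # Assumption: the client id is taken from the LOCAL_RANK variable.
        return os.environ["LOCAL_RANK"]

    def select(self, client_id: str) -> str:
        # Assumption: clients are spread across volumes by rank modulo
        # the number of volumes. The real policy may differ.
        return self.volume_ids[int(client_id) % len(self.volume_ids)]
```

In this sketch the client id and the volume id are separate values, and `select` is the only place where one is translated into the other, which is the distinction the reply draws.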
    object_type: ObjectType
    tensor_slices: Set[Optional[TensorSlice]] = field(default_factory=set)

    def update(self, other_storage_info: "StorageInfo"):
It's not apparent that update here has the same meaning as set.update(), maybe:
Suggested change:
- def update(self, other_storage_info: "StorageInfo"):
+ def merge(self, other_storage_info: "StorageInfo"):
Hm, really? In my mind this is actually the same as set/dict.update. Would you mind clarifying?
> Hm, really? In my mind this is actually the same as set/dict.update. Would you mind clarifying?
Sorry, yes this is exactly the same as set/dict.update. What I meant is that it's not obvious why StorageInfo has a dict-like interface.
For example, without reading the implementation it is very possible to read
storage_info.update(another_storage_info) as meaning "replace the content of storage_info with the content of another_storage_info",
whereas if you write
storage_info.merge(another_storage_info) then it is perfectly clear what is going on.
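For reference, the two readings can be compared in a small sketch. The built-in `dict.update` and `set.update` add to the existing contents rather than replacing them. `StorageInfoSketch` below is a hypothetical stand-in for `StorageInfo` (plain strings replace `ObjectType` and `TensorSlice`), and its `update` body is an assumption about the intended behaviour, not the code under review.

```python
from dataclasses import dataclass, field
from typing import Set

# Built-in behaviour: update() merges into the existing container.
settings = {"x": 1}
settings.update({"y": 2})  # "x" is kept, "y" is added

seen = {"a"}
seen.update({"b"})  # "a" is kept, "b" is added


@dataclass
class StorageInfoSketch:
    """Hypothetical stand-in for StorageInfo, using strings for simplicity."""

    object_type: str
    tensor_slices: Set[str] = field(default_factory=set)

    def update(self, other: "StorageInfoSketch") -> None:
        # Assumption: merging is only meaningful for the same object type.
        if other.object_type != self.object_type:
            raise ValueError("cannot merge StorageInfo of different object types")
        # Same semantics as set.update: union in place, nothing is dropped.
        self.tensor_slices.update(other.tensor_slices)
```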
I see what you mean @casteryh and I think both options are reasonable.
IIUC tho the "update" nomenclature is consistent with the method "update" on a standard Python dictionary, so I'm less worried that a torchstore user would be confused.
Also as an aside, both of you should stop working 😂
> I see what you mean @casteryh and I think both options are reasonable.
> IIUC tho the "update" nomenclature is consistent with the method "update" on a standard Python dictionary, so I'm less worried that a torchstore user would be confused.
That's fair - I think I was secretly reviewing this like C++ code - if this is standard python nomenclature used beyond python's standard library then it's totally up to you guys.
I just want this thing done @joecummings 😆
self.file_store_name = file_store_name

#torchstore will fail without this (see LocalRankStrategy)
os.environ["LOCAL_RANK"] = str(self.rank)
maybe we should assert that LOCAL_RANK is set here too with an error message
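One way the suggested check could look is sketched below. The helper name `require_local_rank` and the wording of the message are hypothetical. The sketch raises an explicit exception instead of using a bare `assert`, because assertions are skipped when Python runs with `-O`.

```python
import os


def require_local_rank() -> int:
    """Return LOCAL_RANK as an int, or fail with a descriptive error."""
    value = os.environ.get("LOCAL_RANK")
    if value is None:
        # Hypothetical message; the real wording is up to the maintainers.
        raise RuntimeError(
            "LOCAL_RANK is not set. torchstore needs it to pick a storage "
            "volume (see LocalRankStrategy)."
        )
    return int(value)
```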
DEFAULT_TORCHSTORE_NAME: str = "TorchStoreController"

# cache for local clients
_local_clent_map: Dict[str, LocalClient] = {}
One suggestion here is to use contextvars to hold _local_client_map.
As a global, it might leak state across concurrent tasks, which has been a problem for Monarch before. contextvars would provide safe, per-task isolation.
LGTM, approval
Implements the bare bones of multi-node support.
Major updates:
* StorageVolume: Actual in-memory storage
* Controller: singleton handles global coordination
* Strategy: defines how storage is created
* LocalClient: user interface for querying from store
Good places to start review
Large things missing:
-- fetch and re-create only what we need. This actually makes things so slow we may want to wait for a fix before merging. (ptr)
-- pass in a non_local proc mesh to storage volume spawn and test on mast.
Optimization ideas:
-- coalesce multiple fetches for tensor slices in a storage volume, and instead return all in a single chunk.
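The coalescing idea could be sketched as grouping pending slice requests by storage volume, so that each volume is contacted once. Everything here is hypothetical: the `(volume_id, slice_key)` request shape and the `coalesce_by_volume` helper are illustrative and not part of torchstore.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical request shape: (volume_id, slice_key).
FetchRequest = Tuple[str, str]


def coalesce_by_volume(requests: List[FetchRequest]) -> Dict[str, List[str]]:
    """Group slice keys by the volume that holds them, preserving order."""
    batches: Dict[str, List[str]] = defaultdict(list)
    for volume_id, slice_key in requests:
        batches[volume_id].append(slice_key)
    return dict(batches)


pending = [
    ("volume-0", "weight[0:4]"),
    ("volume-1", "weight[4:8]"),
    ("volume-0", "bias[0:4]"),
]
batched = coalesce_by_volume(pending)
```

With this grouping, three separate fetches become two batched ones, one per volume.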
Design in pictures!