[tensor] ColoTensor supports ZeRO #1015
Conversation
    def __init__(self, chunk_manager: ChunkManager) -> None:
        super().__init__()
        self.chunk_manager = chunk_manager
self._chunk_manager
Use _xxx for internal variables of a class.
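A minimal sketch of the suggested rename. The enclosing class is not shown in the diff hunk, so the class name ColoDDPV2 and the import path below are assumptions:

import torch
from colossalai.tensor.chunk import ChunkManager  # path taken from the file reviewed below

class ColoDDPV2(torch.nn.Module):  # class name is an assumption
    def __init__(self, chunk_manager: ChunkManager) -> None:
        super().__init__()
        # a leading underscore marks the attribute as internal to the class
        self._chunk_manager = chunk_manager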
colossalai/tensor/chunk.py
Outdated
        self._update_tensors_state(TensorState.HOLD)

    def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
        assert tensor != TensorState.FREE, 'Can only set a chunk of tesors to FREE'
typo: tensor
fixed
The main concern is about suspending parameters.
The self.suspend_param will occur in module.parameters(), so your DDPV2 will append it to the chunk manager.
Finally, we will use the computation graph to derive the computation order of params and filter out the unused ones. The chunk manager only manages used params.
colossalai/tensor/chunk.py
Outdated
            return
        self.tensors_info[tensor].state = tensor_state

    def update_tensor(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
The name update is not informative enough.
I think you mean
def copy_tensor_to_chunk_slice()
done
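For reference, a minimal sketch of the renamed method. Only the old update_tensor signature appears in the diff, so the body below is an assumption about how the copy into the chunk slice might look:

import torch

class ChunkManager:
    # ...other methods from the diff omitted...

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
        """Copy data_slice into the chunk memory that backs tensor."""
        # tensor is assumed to be a view into its chunk's flat buffer,
        # so an in-place copy writes straight into the chunk slice
        tensor.data.copy_(data_slice)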
        if not self.enable_distributed_storage:
            return
        chunk = self.tensor_chunk_map[tensor]
        if chunk not in self.accessed_chunks:
Is accessed_chunks necessary?
It is only used on this line. Can't you tell whether the chunk has been accessed from its tensor states?
If a rank stores a chunk, its initial state is HOLD, but the rank still has to broadcast.
I cannot understand what you mean. You can explain it in Chinese here.
For example, with dp size = 2, rank0 stores chunk0. The initial state of chunk0 on rank0 is HOLD, while on rank1 it is FREE. When we want to access chunk0, rank0 still has to call broadcast() even though it already holds chunk0. We can determine whether a rank has a chunk from its state, but we cannot tell from the state whether the broadcast() has already been done.
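To make this concrete, here is a rough sketch (not the PR's exact code) of why accessed_chunks is needed in addition to the tensor states. The method name access_chunk and the chunk.data / chunk.src_rank attributes are assumptions:

import torch
import torch.distributed as dist

class ChunkManager:
    # ...other methods from the diff omitted...

    def access_chunk(self, tensor: torch.Tensor) -> None:
        # distributed storage disabled: every rank already holds all chunks
        if not self.enable_distributed_storage:
            return
        chunk = self.tensor_chunk_map[tensor]
        if chunk in self.accessed_chunks:
            return  # the broadcast for this chunk has already been issued
        # Every rank must join the broadcast, including the owner whose tensors
        # are in state HOLD; on the other ranks the chunk is FREE and only
        # becomes valid after the broadcast. The tensor state alone cannot tell
        # us whether this collective has already happened.
        dist.broadcast(chunk.data, src=chunk.src_rank)  # attribute names are assumptions
        self.accessed_chunks.add(chunk)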
        if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
            p.grad = None
        else:
            p.grad = p.data
What does this mean?
Reusing the fp16 grad with the fp16 param?
Should it be p.data = p.grad?
We set p.grad to the correct pointer here. p.data holds the grad, due to reuse.
If a chunk is moved from GPU to CPU later, this line makes p.grad point to an old memory space.
No, moving the chunk to another device will update p.data at the same time.
This line makes p.grad point to p.data (addr1).
Afterwards, you move the chunk, so p.data points to new memory (addr2).
However, p.grad still points to addr1.
If so, just set p.grad again, or move this code snippet to optimizer.step() after the device move is done.
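A toy PyTorch illustration of the aliasing issue and the suggested fix; this is standalone demo code, not the PR's implementation:

import torch

p = torch.nn.Parameter(torch.ones(4))

p.grad = p.data                                  # grad aliases the current param storage (addr1)
assert p.grad.data_ptr() == p.data.data_ptr()

# simulate the chunk manager moving the chunk: p.data is rebound to new storage (addr2)
p.data = p.data.clone()

print(p.grad.data_ptr() == p.data.data_ptr())    # False: p.grad still points to addr1

p.grad = p.data                                  # the suggested fix: refresh the alias after the move
assert p.grad.data_ptr() == p.data.data_ptr()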
To check the grad in the unit test, I just set p.grad here.
Is it necessary to build a dict {param: chunk slice} to index each grad and its true memory space (which may be reused with param.data)?
Not necessary for now, I think. p.grad should always point to the chunk slice memory of p, since we always reuse it now. If we did not reuse, I think it would be necessary.
OK, add it if necessary later.
Usage: chunk_size=None means chunks are not used.
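A hypothetical usage sketch based on the note above; the exact ChunkManager constructor signature is not shown in this thread, so the arguments here are assumptions:

from colossalai.tensor.chunk import ChunkManager  # path taken from the file reviewed above

# with a chunk size, tensors are packed into fixed-size chunks
chunk_manager = ChunkManager(chunk_size=1024 * 1024)

# chunk_size=None means chunks are not used: each tensor is managed on its own
plain_manager = ChunkManager(chunk_size=None)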