
[tensor] ColoTensor supports ZeRo #1015

Merged: 27 commits merged into main on May 31, 2022

Conversation

@ver217 (Member) commented on May 24, 2022

Usage:

chunk_size = 38 * 1024**2 if use_chunk else None
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
model = ColoDDPV2(model, chunk_manager)

chunk_size=None means chunking is not used.
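
As a rough end-to-end sketch of this usage (not code from this PR): the import paths are inferred from the files touched here (colossalai/tensor/chunk.py and colossalai/nn/parallel.py), and the forward/backward calls at the end are assumptions about ColoDDPV2's interface.

import torch
from colossalai.tensor.chunk import ChunkManager   # path inferred from colossalai/tensor/chunk.py
from colossalai.nn.parallel import ColoDDPV2       # path inferred from colossalai/nn/parallel.py

use_chunk = True    # pack parameters into fixed-size chunks
use_zero = True     # shard chunk storage across data-parallel ranks

chunk_size = 38 * 1024**2 if use_chunk else None    # None disables chunking
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)

model = torch.nn.Linear(1024, 1024).cuda()          # any nn.Module
model = ColoDDPV2(model, chunk_manager)

# Hypothetical training step; the exact ColoDDPV2 interface is not shown in
# this PR, so treat these two lines as a sketch.
out = model(torch.randn(8, 1024, device='cuda'))
model.backward(out.sum())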


def __init__(self, chunk_manager: ChunkManager) -> None:
    super().__init__()
    self.chunk_manager = chunk_manager
Contributor:
self._chunk_manager
Use _xxx for internal (private) variables of a class.

@ver217 marked this pull request as ready for review on May 27, 2022, 09:38
    self._update_tensors_state(TensorState.HOLD)

def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
    assert tensor != TensorState.FREE, 'Can only set a chunk of tesors to FREE'
Contributor:

Typo: 'tesors' should be 'tensors'.

Member Author:

fixed

@feifeibear (Contributor) commented:

The main concern is about suspended parameters. For example:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(...)
        self.suspend_param = nn.Parameter(...)

self.suspend_param will appear in module.parameters(), so your DDPV2 will append it to the chunk manager. Won't managing the state of such a param be a disaster for your design?

@ver217 (Member Author) commented on May 30, 2022:

> The main concern is about suspended parameters. For example:
>
> class Net(nn.Module):
>     def __init__(self):
>         super().__init__()
>         self.fc1 = nn.Linear(...)
>         self.suspend_param = nn.Parameter(...)
>
> self.suspend_param will appear in module.parameters(), so your DDPV2 will append it to the chunk manager. Won't managing the state of such a param be a disaster for your design?

Eventually, we will use the computation graph to derive the computation order of params and filter out the unused ones. The chunk manager will only manage used params.
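
As an illustration of that plan (not the PR's code), a traced computation graph can be used to find which parameters actually participate in the forward pass, so that only those are handed to the chunk manager. The helper below uses torch.fx and is only a sketch; the function name used_parameter_names is hypothetical.

import torch
import torch.nn as nn
from torch.fx import symbolic_trace

def used_parameter_names(module: nn.Module):
    # Trace the forward pass and collect the qualified names of parameters
    # that actually appear in the graph. get_attr nodes may also name
    # buffers, which is acceptable for this sketch.
    gm = symbolic_trace(module)
    used = set()
    for node in gm.graph.nodes:
        if node.op == 'call_module':
            sub = gm.get_submodule(node.target)
            used.update(f'{node.target}.{name}' for name, _ in sub.named_parameters())
        elif node.op == 'get_attr':
            used.add(node.target)
    return used

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.suspend_param = nn.Parameter(torch.zeros(4))  # never used in forward

    def forward(self, x):
        return self.fc1(x)

net = Net()
used = used_parameter_names(net)
print(sorted(used))             # ['fc1.bias', 'fc1.weight']
print('suspend_param' in used)  # False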

        return
    self.tensors_info[tensor].state = tensor_state

def update_tensor(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
Contributor:

The name update is not informative enough.
I think you mean
def copy_tensor_to_chunk_slice()

Member Author:

done
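
For reference, a minimal standalone sketch of what a copy_tensor_to_chunk_slice method does. The real implementation lives in colossalai/tensor/chunk.py; the TinyChunk class below is a made-up stand-in whose layout (flat buffer plus per-tensor offsets) is an assumption for illustration only.

import torch

class TinyChunk:
    # A deliberately simplified stand-in for the PR's Chunk: a flat buffer
    # plus per-tensor offsets, just enough to illustrate the renamed method.

    def __init__(self, capacity: int, dtype=torch.float32, device='cpu'):
        self.data = torch.zeros(capacity, dtype=dtype, device=device)
        self.offsets = {}   # tensor -> (begin, numel)
        self.used = 0

    def append(self, tensor: torch.Tensor) -> None:
        numel = tensor.numel()
        assert self.used + numel <= self.data.numel(), 'chunk is full'
        self.offsets[tensor] = (self.used, numel)
        self.data[self.used:self.used + numel].copy_(tensor.detach().flatten())
        self.used += numel

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor,
                                   data_slice: torch.Tensor) -> None:
        # Copy new data into the region of the chunk owned by `tensor`.
        begin, numel = self.offsets[tensor]
        self.data[begin:begin + numel].copy_(data_slice.detach().flatten())

chunk = TinyChunk(capacity=8)
p = torch.ones(4)
chunk.append(p)
chunk.copy_tensor_to_chunk_slice(p, torch.full((4,), 2.0))
print(chunk.data)   # first four elements are now 2.0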

if not self.enable_distributed_storage:
    return
chunk = self.tensor_chunk_map[tensor]
if chunk not in self.accessed_chunks:
Contributor:

Is accessed_chunks necessary?
It is only used on this line. Can't you tell whether a chunk has been accessed from its tensor states?

Member Author:

If a rank stores a chunk, the chunk's initial state on that rank is HOLD, but the rank still has to broadcast it.

Contributor:

I cannot understand what you mean. You can explain in Chinese here.

Member Author:

For example, with dp size = 2, rank0 stores chunk0. The initial state of chunk0 on rank0 is HOLD, while on rank1 it is FREE. When we want to access chunk0, rank0 still has to call broadcast() even though it already holds chunk0. We can determine whether a rank holds a chunk from its state, but we cannot tell from the state whether the broadcast() has been done.
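
A small sketch of that bookkeeping (not the PR's code): the chunk states describe which rank stores the data, while a separate accessed_chunks set remembers whether the collective has already run. The ChunkAccessTracker class and src_rank argument are hypothetical.

import torch
import torch.distributed as dist

class ChunkAccessTracker:
    # Sketch only: in the PR the chunks live in ChunkManager, and the source
    # rank would be derived from which rank actually stores the chunk.

    def __init__(self):
        self.accessed_chunks = set()

    def access_chunk(self, chunk_data: torch.Tensor, src_rank: int) -> None:
        if id(chunk_data) in self.accessed_chunks:
            return  # broadcast already performed for this chunk
        # Every rank must join the collective, including the rank whose copy
        # is already in state HOLD, otherwise the FREE ranks would hang.
        dist.broadcast(chunk_data, src=src_rank)
        self.accessed_chunks.add(id(chunk_data))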

colossalai/tensor/chunk.py: 2 resolved review threads
if self.chunk_manager.is_chunk_free(p) or not p.requires_grad:
    p.grad = None
else:
    p.grad = p.data
Contributor:

What does this mean?
Reusing the fp16 grad storage with the fp16 param?
Should it be p.data = p.grad?

Member Author:

This sets p.grad to the correct pointer. p.data holds the grad here, due to the reuse.

Contributor:

If a chunk is moved from GPU to CPU later, this line leaves p.grad pointing to an old memory space.

Member Author:

No, moving the chunk to another device updates p.data at the same time.

Contributor:

This line makes p.grad point to p.data (addr1).
Afterwards, you move the chunk, so p.data now points to addr2.
However, p.grad still points to addr1.

Member Author:

If so, we can just set p.grad again, or move this code snippet into optimizer.step() after the device move is done.

Member Author:

To check grads in the unit test, I just set p.grad here.

Contributor:

Is it necessary to build a dict {param: chunk slice} to index a grad and its true memory space (which may be reused with param.data)?

Member Author:

Not necessary for now, I think. p.grad should always point to the chunk slice memory of p, since we always reuse it now. Without reuse, it would be necessary.

Contributor:

OK, add it if necessary later.
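
A compact sketch of the reuse pattern discussed in this thread (not the PR's code): the fp16 grad shares storage with the fp16 param, so p.grad is pointed at p.data and must be re-pointed if the chunk's storage moves.

import torch

p = torch.nn.Parameter(torch.randn(4, dtype=torch.half))

# Pretend backward wrote the gradient into the reused storage behind p.data.
p.data.copy_(torch.full_like(p.data, 0.5))

p.grad = p.data   # the line under review: expose the grad through p.grad

# If the chunk manager later moves the chunk, p.data gets re-assigned to new
# storage (a clone stands in for the CPU<->GPU move here), so p.grad has to
# be set again afterwards.
p.data = p.data.clone()
p.grad = p.data
assert p.grad.data_ptr() == p.data.data_ptr()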

colossalai/nn/parallel.py: 2 resolved review threads (one outdated)
@ver217 merged commit 9492a56 into main on May 31, 2022.
@ver217 deleted the feature/colo-tensor-zero branch on May 31, 2022, 07:37.