[checkpoint] checkpoint for ColoTensor Model (#1196)
1 parent 291e22a · commit f38006e · showing 4 changed files with 292 additions and 1 deletion
@@ -0,0 +1,3 @@
from .module_checkpoint import save_checkpoint, load_checkpoint

__all__ = ['save_checkpoint', 'load_checkpoint']
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
import torch.distributed as dist
import collections
from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
from colossalai.utils.model.colo_init_context import colo_state_dict


def save_checkpoint(dire,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer = None,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    *args,
                    **kwargs):
    """save_checkpoint
    Save a model whose parameters are `ColoTensor`s.

    Args:
        dire (str): directory in which the checkpoint files are written.
        epoch (int): current epoch number, used in the checkpoint file names.
        model (torch.nn.Module): model to save; its state dict is produced via `colo_state_dict`.
        optimizer (torch.optim.Optimizer, optional): optimizer to save. Defaults to None.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr scheduler to save. Defaults to None.
    """
    model_state = {
        'epoch': epoch,
        'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)
    }
    # Only rank 0 writes the model state; every rank writes its own optimizer file below.
    if dist.get_rank() == 0:
        torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
    # The scheduler's state dict carries the wrapped `after_scheduler` as an object,
    # so store that object's own state dict instead.
    lr_scheduler_dict = lr_scheduler.state_dict()
    lr_scheduler_dict['after_scheduler'] = lr_scheduler_dict['after_scheduler'].state_dict()
    optim_state = {
        'epoch': epoch,
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler_dict
    }
    torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))


def load_checkpoint(dire,
                    epoch: int,
                    rank: int,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer = None,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    *args,
                    **kwargs):
    """load_checkpoint
    Load a model whose parameters are `ColoTensor`s.

    Args:
        dire (str): directory from which the checkpoint files are read.
        epoch (int): epoch number of the checkpoint to load.
        rank (int): rank of the current process, used to pick its optimizer file.
        model (torch.nn.Module): model to restore.
        optimizer (torch.optim.Optimizer, optional): optimizer to restore. Defaults to None.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr scheduler to restore. Defaults to None.
    """
    model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
    # Strip the leading wrapper prefix (e.g. the 'module.' added by ColoDDP) from parameter names.
    model_state['model'] = collections.OrderedDict([(k.split('.', 1)[1], v) for k, v in model_state['model'].items()])
    model.load_state_dict(model_state['model'])
    optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, rank))
    optimizer.load_state_dict(optim_state['optimizer'])
    lr_scheduler_dict = optim_state['lr_scheduler']
    # Rebuild the wrapped scheduler from its saved state before restoring the outer scheduler.
    after_scheduler_dict = lr_scheduler_dict['after_scheduler']
    lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR(
        optimizer,
        after_scheduler_dict['T_max'],
        after_scheduler_dict['eta_min'],
        after_scheduler_dict['last_epoch']
    )
    lr_scheduler.load_state_dict(lr_scheduler_dict)
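
A minimal usage sketch of the two helpers above, not part of the commit: the module, process group, config and hyper-parameters below are illustrative placeholders, and the model is wrapped in ColoDDP so that the prefix-stripping in load_checkpoint applies.

import os
import torch
import torch.nn as nn
import torch.distributed as dist
import colossalai
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ProcessGroup
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint

# Hypothetical single-process launch; config, host and port are placeholders.
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

with ColoInitContext(device=get_current_device()):
    model = nn.Linear(256, 16)
pg = ProcessGroup(tp_degree=1)
model = ColoDDP(model.cuda(), pg)    # the loader strips the 'module.' prefix ColoDDP adds
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=20, warmup_steps=5)

os.makedirs('./checkpoint', exist_ok=True)
epoch = 0
save_checkpoint('./checkpoint', epoch, model, optimizer, lr_scheduler)
# Files written, per the format strings above:
#   ./checkpoint/epoch_0_model.pth           (model state, rank 0 only)
#   ./checkpoint/epoch_0_optim_rank_0.pth    (optimizer + lr scheduler state, one per rank)
dist.barrier()
load_checkpoint('./checkpoint', epoch, dist.get_rank(), model, optimizer, lr_scheduler)
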
@@ -0,0 +1,211 @@
from abc import ABC, abstractmethod
import os, sys, shutil
import torch
import torch.nn as nn
import pytest
import copy
import operator
import colossalai
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup, ColoTensor
from colossalai.core import global_context as gpc
from functools import partial
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR


class DummyDataGenerator(ABC):

    def __init__(self, length=10):
        self.length = length

    @abstractmethod
    def generate(self):
        pass

    def __iter__(self):
        self.step = 0
        return self

    def __next__(self):
        if self.step < self.length:
            self.step += 1
            return self.generate()
        else:
            raise StopIteration

    def __len__(self):
        return self.length


class DummyDataLoader(DummyDataGenerator):
    batch_size = 128
    category = 16
    feature_size = 256

    def generate(self):
        image_dict = {}
        image_dict['pixel_values'] = torch.rand(
            DummyDataLoader.batch_size, DummyDataLoader.feature_size, device=get_current_device()) * 2 - 1
        image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
                                            dtype=torch.int64,
                                            device=get_current_device())
        return image_dict


class MLP(nn.Module):

    def __init__(self, in_features, out_features, hidden_features=None):
        super().__init__()
        if hidden_features is None:
            hidden_features = out_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x


def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n:
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)


def check_param_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)


def remove(path):
    """ param <path> could either be relative or absolute. """
    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    else:
        raise ValueError("file {} is not a file or dir.".format(path))

def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
    train_dataloader = DummyDataLoader(length=16)
    with ColoInitContext(device=get_current_device()):
        # `model` is trained and checkpointed at `test_epoch`, `model_ref` keeps a copy of
        # its weights at that point, and `model_reload` is later restored from the checkpoint.
        model = MLP(256, 16, 64)
        model_reload = MLP(256, 16, 64)
        model_ref = MLP(256, 16, 64)
    model = model.cuda()
    model_reload = model_reload.cuda()
    model_ref = model_ref.cuda()
    if use_ddp:
        model = ColoDDP(model, pg)
        model_reload = ColoDDP(model_reload, pg)
        model_ref = ColoDDP(model_ref, pg)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    optimizer_reload = torch.optim.Adam(model_reload.parameters(),
                                        lr=0.001,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0)
    optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=20, warmup_steps=5)
    lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload, total_steps=20, warmup_steps=5)
    lr_scheduler_ref = CosineAnnealingWarmupLR(optimizer=optimizer_ref, total_steps=20, warmup_steps=5)

    init_spec_func(model, pg)
    init_spec_func(model_ref, pg)

    for epoch in range(0, 20):
        if epoch <= test_epoch:
            for i, image_dict in enumerate(train_dataloader):
                if use_ddp:
                    model.zero_grad()
                else:
                    optimizer.zero_grad()
                logits = model(image_dict['pixel_values'])
                loss = criterion(logits, image_dict['label'])
                if use_ddp:
                    model.backward(loss)
                else:
                    loss.backward()
                optimizer.step()

            if epoch == test_epoch:
                # Snapshot the reference state and write the checkpoint at `test_epoch`.
                for ref_p, p in zip(model_ref.parameters(), model.parameters()):
                    ref_p.data.copy_(p)
                optimizer_ref = copy.deepcopy(optimizer)
                lr_scheduler_ref = copy.deepcopy(lr_scheduler)

                check_param_equal(model, model_ref)
                save_checkpoint('./checkpoint', epoch, model, optimizer, lr_scheduler)
                dist.barrier()
        else:
            if epoch == test_epoch + 1:
                # Restore the checkpoint into the fresh model, then keep training the
                # reference and the reloaded model side by side.
                load_checkpoint('./checkpoint', test_epoch, dist.get_rank(), model_reload, optimizer_reload,
                                lr_scheduler_reload)
                init_spec_func(model_reload, pg)
            for i, image_dict in enumerate(train_dataloader):
                if use_ddp:
                    model_ref.zero_grad()
                    model_reload.zero_grad()
                else:
                    optimizer_ref.zero_grad()
                    optimizer_reload.zero_grad()
                logits_ref = model_ref(image_dict['pixel_values'])
                logits_reload = model_reload(image_dict['pixel_values'])
                loss_ref = criterion(logits_ref, image_dict['label'])
                loss_reload = criterion(logits_reload, image_dict['label'])
                if use_ddp:
                    model_ref.backward(loss_ref)
                    model_reload.backward(loss_reload)
                else:
                    loss_ref.backward()
                    loss_reload.backward()
                optimizer_ref.step()
                optimizer_reload.step()
        lr_scheduler.step()

    # After reload, the restored model must match the reference trained past the checkpoint.
    check_param_equal(model_ref, model_reload)

def run_dist(rank, world_size, port, use_ddp, test_epoch):
    if use_ddp and world_size == 1:
        return
    tp_world_size = world_size // 2 if use_ddp else world_size
    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)
    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, pg)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@pytest.mark.parametrize('use_ddp', [True])
@pytest.mark.parametrize('test_epoch', [1, 2, 3])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, test_epoch):
    if not os.path.isdir('./checkpoint'):
        os.mkdir('./checkpoint')
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, test_epoch=test_epoch)
    mp.spawn(run_func, nprocs=world_size)
    remove('./checkpoint')


if __name__ == '__main__':
    test_checkpoint(4, True, 1)