Skip to content

Commit

Permalink
[checkpoint] checkpoint for ColoTensor Model (#1196)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifeibear committed Jul 6, 2022
1 parent 291e22a commit f38006e
Show file tree
Hide file tree
Showing 4 changed files with 292 additions and 1 deletion.
3 changes: 3 additions & 0 deletions colossalai/utils/checkpoint/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .module_checkpoint import save_checkpoint, load_checkpoint

__all__ = ['save_checkpoint', 'load_checkpoint']
73 changes: 73 additions & 0 deletions colossalai/utils/checkpoint/module_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
import torch.distributed as dist
import collections
from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
from colossalai.utils.model.colo_init_context import colo_state_dict

def save_checkpoint(dire,
epoch: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args,
**kwargs):
"""save_checkpoint
save a model, whose parameters are `ColoTensor`s.
Args:
dire (_type_): _description_
epoch (int): _description_
model (torch.nn.Module): _description_
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
"""
model_state = {
'epoch': epoch,
'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)
}
if dist.get_rank() == 0:
torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
lr_scheduler_dict = lr_scheduler.state_dict()
lr_scheduler_dict['after_scheduler'] = lr_scheduler_dict['after_scheduler'].state_dict()
optim_state = {
'epoch': epoch,
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler_dict
}
torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))




def load_checkpoint(dire,
epoch: int,
rank: int,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
*args,
**kwargs):
"""load_checkpoint
load a model, whose parameters are `ColoTensor`s.
Args:
dire (_type_): _description_
epoch (int): _description_
rank (int): _description_
model (torch.nn.Module): _description_
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
"""
model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
model_state['model'] = collections.OrderedDict([(k.split('.', 1)[1], v) for k, v in model_state['model'].items()])
model.load_state_dict(model_state['model'])
optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, rank))
optimizer.load_state_dict(optim_state['optimizer'])
lr_scheduler_dict = optim_state['lr_scheduler']
after_scheduler_dict = lr_scheduler_dict['after_scheduler']
lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR(
optimizer,
after_scheduler_dict['T_max'],
after_scheduler_dict['eta_min'],
after_scheduler_dict['last_epoch']
)
lr_scheduler.load_state_dict(lr_scheduler_dict)
6 changes: 5 additions & 1 deletion colossalai/utils/model/colo_init_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,18 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
# build param to spec mapping
mapping1 = dict()
mapping2 = dict()
mapping3 = dict()
# gather all params
has_dist_parameter = False
with torch.no_grad():
for param in self.parameters():
if isinstance(param, ColoParameter) and param.has_compute_spec():
if isinstance(param, ColoParameter):
has_dist_parameter = True
mapping1[id(param)] = copy(param.dist_spec)
mapping2[id(param)] = copy(param.compute_spec)
mapping3[id(param)] = param.get_process_group()
param.set_dist_spec(distspec.replicate())
param.process_group = None

# TODO: fix when keep_vars = True
# when keep_vars = False, the state_dict_func will call detach to create
Expand All @@ -64,6 +67,7 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
if param_id in mapping1:
dist_spec = mapping1[id(param)]
compute_spec = mapping2[id(param)]
param.process_group = mapping3[id(param)]
param.set_tensor_spec(dist_spec, compute_spec)
return ret

Expand Down
211 changes: 211 additions & 0 deletions tests/test_utils/test_colo_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from abc import ABC, abstractmethod
import os, sys, shutil
import torch
import torch.nn as nn
import pytest
import copy
import operator
import colossalai
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup, ColoTensor
from colossalai.core import global_context as gpc
from functools import partial
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR


class DummyDataGenerator(ABC):

def __init__(self, length=10):
self.length = length

@abstractmethod
def generate(self):
pass

def __iter__(self):
self.step = 0
return self

def __next__(self):
if self.step < self.length:
self.step += 1
return self.generate()
else:
raise StopIteration

def __len__(self):
return self.length


class DummyDataLoader(DummyDataGenerator):
batch_size = 128
category = 16
feature_size = 256

def generate(self):
image_dict = {}
image_dict['pixel_values'] = torch.rand(
DummyDataLoader.batch_size, DummyDataLoader.feature_size, device=get_current_device()) * 2 - 1
image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
dtype=torch.int64,
device=get_current_device())
return image_dict


class MLP(nn.Module):

def __init__(self, in_features, out_features, hidden_features=None):
super().__init__()
if hidden_features is None:
hidden_features = out_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
self.activation = nn.ReLU()

def forward(self, x):
x = self.fc1(x)
x = self.activation(x)
x = self.fc2(x)
return x


def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'weight' in n:
p.set_process_group(pg)
p.set_tensor_spec(*spec)


def check_param_equal(model, torch_model):
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)


def remove(path):
""" param <path> could either be relative or absolute. """
if os.path.isfile(path) or os.path.islink(path):
os.remove(path)
elif os.path.isdir(path):
shutil.rmtree(path)
else:
raise ValueError("file {} is not a file or dir.".format(path))


def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
train_dataloader = DummyDataLoader(length=16)
with ColoInitContext(device=get_current_device()):
model = MLP(256, 16, 64)
model_reload = MLP(256, 16, 64)
model_ref = MLP(256, 16, 64)
model = model.cuda()
model_reload = model_reload.cuda()
model_ref = model_ref.cuda()
if use_ddp:
model = ColoDDP(model, pg)
model_reload = ColoDDP(model_reload, pg)
model_ref = ColoDDP(model_ref, pg)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
optimizer_reload = torch.optim.Adam(model_reload.parameters(),
lr=0.001,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=0)
optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=20, warmup_steps=5)
lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload, total_steps=20, warmup_steps=5)
lr_scheduler_ref = CosineAnnealingWarmupLR(optimizer=optimizer_ref, total_steps=20, warmup_steps=5)

init_spec_func(model, pg)
init_spec_func(model_ref, pg)

for epoch in range(0, 20):
if epoch <= test_epoch:
for i, image_dict in enumerate(train_dataloader):
if use_ddp:
model.zero_grad()
else:
optimizer.zero_grad()
logits = model(image_dict['pixel_values'])
loss = criterion(logits, image_dict['label'])
if use_ddp:
model.backward(loss)
else:
loss.backward()
optimizer.step()

if epoch == test_epoch:
for ref_p, p in zip(model_ref.parameters(), model.parameters()):
ref_p.data.copy_(p)
optimizer_ref = copy.deepcopy(optimizer)
lr_scheduler_ref = copy.deepcopy(lr_scheduler)

check_param_equal(model, model_ref)
save_checkpoint('./checkpoint', epoch, model, optimizer, lr_scheduler)
dist.barrier()
else:
if epoch == test_epoch + 1:
load_checkpoint('./checkpoint', test_epoch, dist.get_rank(), model_reload, optimizer_reload,
lr_scheduler_reload)
init_spec_func(model_reload, pg)
for i, image_dict in enumerate(train_dataloader):
if use_ddp:
model_ref.zero_grad()
model_reload.zero_grad()
else:
optimizer_ref.zero_grad()
optimizer_reload.zero_grad()
logits_ref = model_ref(image_dict['pixel_values'])
logits_reload = model_reload(image_dict['pixel_values'])
loss_ref = criterion(logits_ref, image_dict['label'])
loss_reload = criterion(logits_reload, image_dict['label'])
if use_ddp:
model_ref.backward(loss_ref)
model_reload.backward(loss_reload)
else:
loss_ref.backward()
loss_reload.backward()
optimizer_ref.step()
optimizer_reload.step()
lr_scheduler.step()

check_param_equal(model_ref, model_reload)


def run_dist(rank, world_size, port, use_ddp, test_epoch):
if use_ddp and world_size == 1:
return
tp_world_size = world_size // 2 if use_ddp else world_size
config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size)
run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, pg)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@pytest.mark.parametrize('use_ddp', [True])
@pytest.mark.parametrize('test_epoch', [1, 2, 3])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, test_epoch):
if not os.path.isdir('./checkpoint'):
os.mkdir('./checkpoint')
run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, test_epoch=test_epoch)
mp.spawn(run_func, nprocs=world_size)
remove('./checkpoint')


if __name__ == '__main__':
test_checkpoint(4, True, 1)

0 comments on commit f38006e

Please sign in to comment.