checkpoint for ColoTensor Model #1196

Merged · 6 commits · Jul 6, 2022
3 changes: 3 additions & 0 deletions colossalai/utils/checkpoint/__init__.py
@@ -0,0 +1,3 @@
from .module_checkpoint import save_checkpoint, load_checkpoint

__all__ = ['save_checkpoint', 'load_checkpoint']
73 changes: 73 additions & 0 deletions colossalai/utils/checkpoint/module_checkpoint.py
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn
import torch.distributed as dist
import collections
from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
from colossalai.utils.model.colo_init_context import colo_state_dict


def save_checkpoint(dire,
                    epoch: int,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer = None,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    *args,
                    **kwargs):
    """save_checkpoint
    Save a model whose parameters are `ColoTensor`s.
    Args:
        dire (str): directory to save the checkpoint files in.
        epoch (int): the number of the epoch being saved; it is used in the file names.
        model (torch.nn.Module): a model whose parameters are `ColoTensor`s.
        optimizer (torch.optim.Optimizer, optional): the optimizer whose state is saved. Defaults to None.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): the LR scheduler whose state is saved. Defaults to None.
    """
    # Gather the distributed parameters into replicated tensors; only rank 0 writes the model file.
    model_state = {
        'epoch': epoch,
        'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)
    }
    if dist.get_rank() == 0:
        torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))

    # Optimizer and LR-scheduler states are rank-local, so every rank writes its own file.
    lr_scheduler_dict = lr_scheduler.state_dict()
    lr_scheduler_dict['after_scheduler'] = lr_scheduler_dict['after_scheduler'].state_dict()
    optim_state = {
        'epoch': epoch,
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler_dict
    }
    torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))


def load_checkpoint(dire,
                    epoch: int,
                    rank: int,
                    model: torch.nn.Module,
                    optimizer: torch.optim.Optimizer = None,
                    lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
                    *args,
                    **kwargs):
    """load_checkpoint
    Load a model whose parameters are `ColoTensor`s.
    Args:
        dire (str): directory that holds the checkpoint files.
        epoch (int): the number of the epoch to load.
        rank (int): the rank whose optimizer and LR-scheduler state is loaded.
        model (torch.nn.Module): a model whose parameters are `ColoTensor`s.
        optimizer (torch.optim.Optimizer, optional): the optimizer whose state is restored. Defaults to None.
        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): the LR scheduler whose state is restored. Defaults to None.
    """
    model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch))
    # Drop the first prefix component of every key (e.g. the 'module.' added by ColoDDP).
    model_state['model'] = collections.OrderedDict([(k.split('.', 1)[1], v) for k, v in model_state['model'].items()])
    model.load_state_dict(model_state['model'])

    optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, rank))
    optimizer.load_state_dict(optim_state['optimizer'])

    # The inner scheduler was saved as a plain state dict, so rebuild a CosineAnnealingLR
    # instance from it before handing the whole dict back to the wrapping scheduler.
    lr_scheduler_dict = optim_state['lr_scheduler']
    after_scheduler_dict = lr_scheduler_dict['after_scheduler']
    lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR(
        optimizer,
        after_scheduler_dict['T_max'],
        after_scheduler_dict['eta_min'],
        after_scheduler_dict['last_epoch']
    )
    lr_scheduler.load_state_dict(lr_scheduler_dict)
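For reference, a minimal usage sketch of the two new functions (not part of the diff). It assumes colossalai.launch() has already initialised torch.distributed and that model was built under ColoInitContext so its parameters are ColoTensors; model, epoch and the ./checkpoint directory are illustrative names, and the scheduler mirrors the one used in the test below.

import torch
import torch.distributed as dist
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint

ckpt_dir = './checkpoint'    # the directory must already exist on every rank
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=20, warmup_steps=5)

save_checkpoint(ckpt_dir, epoch, model, optimizer, lr_scheduler)    # only rank 0 writes the model file
dist.barrier()                                                      # let rank 0 finish before any rank reloads
load_checkpoint(ckpt_dir, epoch, dist.get_rank(), model, optimizer, lr_scheduler)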
6 changes: 5 additions & 1 deletion colossalai/utils/model/colo_init_context.py
@@ -38,15 +38,18 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
     # build param to spec mapping
     mapping1 = dict()
     mapping2 = dict()
+    mapping3 = dict()
     # gather all params
     has_dist_parameter = False
     with torch.no_grad():
         for param in self.parameters():
-            if isinstance(param, ColoParameter) and param.has_compute_spec():
+            if isinstance(param, ColoParameter):
                 has_dist_parameter = True
                 mapping1[id(param)] = copy(param.dist_spec)
                 mapping2[id(param)] = copy(param.compute_spec)
+                mapping3[id(param)] = param.get_process_group()
                 param.set_dist_spec(distspec.replicate())
+                param.process_group = None

     # TODO: fix when keep_vars = True
     # when keep_vars = False, the state_dict_func will call detach to create
@@ -64,6 +67,7 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
                 if param_id in mapping1:
                     dist_spec = mapping1[id(param)]
                     compute_spec = mapping2[id(param)]
+                    param.process_group = mapping3[id(param)]
                     param.set_tensor_spec(dist_spec, compute_spec)
     return ret

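For orientation, the function being modified follows a stash/replicate/restore pattern around a plain state_dict call. A condensed paraphrase of that flow (simplified, not the actual implementation; the helper names mirror the ones visible in the diff):

from copy import copy

import torch
import torch.nn as nn
from colossalai.tensor import ColoParameter, distspec


def replicated_state_dict(model: nn.Module):
    # 1. stash each distributed parameter's placement and turn it into a full replica
    stashed = {}
    with torch.no_grad():
        for p in model.parameters():
            if isinstance(p, ColoParameter):
                stashed[id(p)] = (copy(p.dist_spec), copy(p.compute_spec), p.get_process_group())
                p.set_dist_spec(distspec.replicate())
                p.process_group = None
        # 2. take an ordinary state dict while every tensor is fully replicated
        state = nn.Module.state_dict(model)
        # 3. restore the original dist spec, compute spec and process group
        for p in model.parameters():
            if id(p) in stashed:
                dist_spec, compute_spec, pg = stashed[id(p)]
                p.process_group = pg
                p.set_tensor_spec(dist_spec, compute_spec)
    return state

The mapping3 entries added in this diff are the third element of that stash: the process group, which is cleared before serialisation and re-attached afterwards.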
211 changes: 211 additions & 0 deletions tests/test_utils/test_colo_checkpoint.py
@@ -0,0 +1,211 @@
from abc import ABC, abstractmethod
import os, sys, shutil
import torch
import torch.nn as nn
import pytest
import copy
import operator
import colossalai
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp
import torch.distributed as dist
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup, ColoTensor
from colossalai.core import global_context as gpc
from functools import partial
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR


class DummyDataGenerator(ABC):

    def __init__(self, length=10):
        self.length = length

    @abstractmethod
    def generate(self):
        pass

    def __iter__(self):
        self.step = 0
        return self

    def __next__(self):
        if self.step < self.length:
            self.step += 1
            return self.generate()
        else:
            raise StopIteration

    def __len__(self):
        return self.length


class DummyDataLoader(DummyDataGenerator):
    batch_size = 128
    category = 16
    feature_size = 256

    def generate(self):
        image_dict = {}
        image_dict['pixel_values'] = torch.rand(
            DummyDataLoader.batch_size, DummyDataLoader.feature_size, device=get_current_device()) * 2 - 1
        image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
                                            dtype=torch.int64,
                                            device=get_current_device())
        return image_dict


class MLP(nn.Module):

    def __init__(self, in_features, out_features, hidden_features=None):
        super().__init__()
        if hidden_features is None:
            hidden_features = out_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x


def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n:
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)


def check_param_equal(model, torch_model):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1)


def remove(path):
    """ param <path> could either be relative or absolute. """
    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    else:
        raise ValueError("file {} is not a file or dir.".format(path))


def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
    train_dataloader = DummyDataLoader(length=16)
    with ColoInitContext(device=get_current_device()):
        model = MLP(256, 16, 64)
        model_reload = MLP(256, 16, 64)
        model_ref = MLP(256, 16, 64)
    model = model.cuda()
    model_reload = model_reload.cuda()
    model_ref = model_ref.cuda()
    if use_ddp:
        model = ColoDDP(model, pg)
        model_reload = ColoDDP(model_reload, pg)
        model_ref = ColoDDP(model_ref, pg)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    optimizer_reload = torch.optim.Adam(model_reload.parameters(),
                                        lr=0.001,
                                        betas=(0.9, 0.999),
                                        eps=1e-08,
                                        weight_decay=0)
    optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=20, warmup_steps=5)
    lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload, total_steps=20, warmup_steps=5)
    lr_scheduler_ref = CosineAnnealingWarmupLR(optimizer=optimizer_ref, total_steps=20, warmup_steps=5)

    init_spec_func(model, pg)
    init_spec_func(model_ref, pg)

    for epoch in range(0, 20):
        if epoch <= test_epoch:
            for i, image_dict in enumerate(train_dataloader):
                if use_ddp:
                    model.zero_grad()
                else:
                    optimizer.zero_grad()
                logits = model(image_dict['pixel_values'])
                loss = criterion(logits, image_dict['label'])
                if use_ddp:
                    model.backward(loss)
                else:
                    loss.backward()
                optimizer.step()

            if epoch == test_epoch:
                for ref_p, p in zip(model_ref.parameters(), model.parameters()):
                    ref_p.data.copy_(p)
                optimizer_ref = copy.deepcopy(optimizer)
                lr_scheduler_ref = copy.deepcopy(lr_scheduler)

                check_param_equal(model, model_ref)
                save_checkpoint('./checkpoint', epoch, model, optimizer, lr_scheduler)
                dist.barrier()
        else:
            if epoch == test_epoch + 1:
                load_checkpoint('./checkpoint', test_epoch, dist.get_rank(), model_reload, optimizer_reload,
                                lr_scheduler_reload)
                init_spec_func(model_reload, pg)
            for i, image_dict in enumerate(train_dataloader):
                if use_ddp:
                    model_ref.zero_grad()
                    model_reload.zero_grad()
                else:
                    optimizer_ref.zero_grad()
                    optimizer_reload.zero_grad()
                logits_ref = model_ref(image_dict['pixel_values'])
                logits_reload = model_reload(image_dict['pixel_values'])
                loss_ref = criterion(logits_ref, image_dict['label'])
                loss_reload = criterion(logits_reload, image_dict['label'])
                if use_ddp:
                    model_ref.backward(loss_ref)
                    model_reload.backward(loss_reload)
                else:
                    loss_ref.backward()
                    loss_reload.backward()
                optimizer_ref.step()
                optimizer_reload.step()
        lr_scheduler.step()

    check_param_equal(model_ref, model_reload)


def run_dist(rank, world_size, port, use_ddp, test_epoch):
    if use_ddp and world_size == 1:
        return
    tp_world_size = world_size // 2 if use_ddp else world_size
    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)
    run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, pg)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@pytest.mark.parametrize('use_ddp', [True])
@pytest.mark.parametrize('test_epoch', [1, 2, 3])
@rerun_if_address_is_in_use()
def test_checkpoint(world_size, use_ddp, test_epoch):
    if not os.path.isdir('./checkpoint'):
        os.mkdir('./checkpoint')
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, test_epoch=test_epoch)
    mp.spawn(run_func, nprocs=world_size)
    remove('./checkpoint')


if __name__ == '__main__':
    test_checkpoint(4, True, 1)
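If you want to run just this test locally, one option (an assumption about the usual workflow, not something stated in the PR) is to drive pytest from Python so the dist marker is selected:

import pytest

pytest.main(['-m', 'dist', 'tests/test_utils/test_colo_checkpoint.py'])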