Skip to content

Commit

Permalink
[checkpointio] support huggingface from_pretrained for all plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
Fridge003 committed Sep 4, 2023
1 parent 508ca36 commit c44c62a
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 143 deletions.
2 changes: 2 additions & 0 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
get_optimizer_base_filenames,
get_shard_filename,
load_shard_state_dict,
save_config_file,
save_state_dict,
save_state_dict_shards,
)
Expand Down Expand Up @@ -111,6 +112,7 @@ def save_sharded_model(self,
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.module, checkpoint_path)
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
Expand Down
2 changes: 2 additions & 0 deletions colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
Expand Down Expand Up @@ -185,6 +186,7 @@ def save_sharded_model(self,

index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint_path, is_master=True)
logging.info(f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
Expand Down
42 changes: 35 additions & 7 deletions tests/test_checkpoint_io/test_general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, parameterize
from tests.kit.model_zoo import model_zoo

# ========
# Note:
Expand Down Expand Up @@ -74,13 +75,6 @@ def test_sharded_model_checkpoint(use_safetensors: bool):
optimizer.step()

# create a temp file for checkpoint
if use_safetensors:
suffix = ".safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
else:
suffix = ".bin"
WEIGHTS_INDEX_NAME = "model.bin.index.json"

model_ckpt_dir = tempfile.TemporaryDirectory()
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()

Expand Down Expand Up @@ -208,3 +202,37 @@ def test_sharded_optimizer_multiple_param_groups():
# check for model and optimizer state dict recursively
check_state_dict_equal(model.state_dict(), new_model.state_dict())
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())


@pytest.mark.parametrize('model_name', ['transformers_gpt', 'transformers_bert'])
@pytest.mark.parametrize('use_safetensors', [True, False])
def test_huggingface_from_pretrained(model_name: str, use_safetensors: bool):

(model_fn, data_gen_fn, output_transform_fn, loss_fn,
_) = next(iter(model_zoo.get_sub_registry(model_name).values()))

model = model_fn().cuda()
optimizer = Adam(model.parameters(), lr=0.001)
criterion = loss_fn
data = data_gen_fn()
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}

# run fwd and bwd
output = model(**data)
loss = criterion(output)
loss.backward()
optimizer.step()

# create a temp file for checkpoint
model_ckpt_dir = tempfile.TemporaryDirectory()

# save the model
ckpt_io = GeneralCheckpointIO()
ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors)

# load new model
new_model = model.__class__.from_pretrained(model_ckpt_dir.name)
new_model = new_model.cuda()

# check the loaded model
check_state_dict_equal(model.state_dict(), new_model.state_dict())
129 changes: 0 additions & 129 deletions tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,37 @@
@clear_cache_before_run()
@parameterize('shard', [True])
@parameterize('model_name', ['transformers_gpt'])
@parameterize('size_per_shard', [32])
@parameterize('test_config', [{
'tp_size': 4,
'pp_size': 1,
'precision': 'fp32',
'from_pretrained': False
}, {
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'precision': 'fp16',
'initial_scale': 1
'initial_scale': 1,
'from_pretrained': True
}, {
'tp_size': 2,
'pp_size': 1,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
'initial_scale': 1,
'from_pretrained': False
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
'initial_scale': 1,
'from_pretrained': True
}])
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
def exam_state_dict(shard: bool, model_name: str, test_config: dict, size_per_shard: int = 32):

from_pretrained = test_config.pop('from_pretrained')

(model_fn, data_gen_fn, output_transform_fn, loss_fn,
_) = next(iter(model_zoo.get_sub_registry(model_name).values()))
Expand Down Expand Up @@ -101,12 +106,18 @@ def _preprocess_data(data):
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
dist.barrier()

new_model = model_fn().cuda()
if from_pretrained:
# test compatibility with huggingface from_pretrained
new_model = model.unwrap().__class__.from_pretrained(model_ckpt_path)
else:
new_model = model_fn().cuda()
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

booster.load_model(new_model, model_ckpt_path)
if not from_pretrained:
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)

booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False)
dist.barrier()
Expand Down
83 changes: 83 additions & 0 deletions tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os

import pytest
import torch
import torch.distributed as dist
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
check_state_dict_equal,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo


@clear_cache_before_run()
@parameterize('model_name', ['transformers_gpt'])
@parameterize('plugin_type', ['ddp', 'zero', 'gemini'])
def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
(model_fn, data_gen_fn, output_transform_fn, loss_fn,
_) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = loss_fn

if plugin_type == 'ddp':
plugin = TorchDDPPlugin()
elif plugin_type == 'zero':
plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
elif plugin_type == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', precision="fp16", initial_scale=32)
else:
raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.")

booster = Booster(plugin=plugin)

model = model_fn().cuda()
model_huggingface_cls = model.__class__
optimizer = HybridAdam(model.parameters(), lr=0.001)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

data = data_gen_fn()
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
output = model(**data)
loss = criterion(output)

booster.backward(loss, optimizer)
optimizer.step()

with shared_tempdir() as tempdir:

model_ckpt_path = f"{tempdir}/model"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
dist.barrier()

new_model = model_huggingface_cls.from_pretrained(model_ckpt_path)
new_model = new_model.cuda()
new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

if plugin_type == 'gemini':
check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
new_model.unwrap().state_dict(only_rank_0=False), False)
else:
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
dist.barrier()


def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_from_pretrained()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_huggingface_compatibility(world_size):
spawn(run_dist, world_size)

0 comments on commit c44c62a

Please sign in to comment.