Skip to content

Commit

Permalink
[checkpointio] support huggingface from_pretrained for all plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
Fridge003 committed Sep 4, 2023
1 parent 0a94fcd commit 31f4d4b
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 129 deletions.
2 changes: 2 additions & 0 deletions colossalai/booster/plugin/gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
get_optimizer_base_filenames,
get_shard_filename,
load_shard_state_dict,
save_config_file,
save_state_dict,
save_state_dict_shards,
)
Expand Down Expand Up @@ -111,6 +112,7 @@ def save_sharded_model(self,
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.module, checkpoint_path)
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
Expand Down
2 changes: 2 additions & 0 deletions colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
Expand Down Expand Up @@ -185,6 +186,7 @@ def save_sharded_model(self,

index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint_path, is_master=True)
logging.info(f"The model is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
Expand Down
129 changes: 0 additions & 129 deletions tests/test_checkpoint_io/test_hybrid_huggingface_compatibility.py

This file was deleted.

83 changes: 83 additions & 0 deletions tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os

import pytest
import torch
import torch.distributed as dist
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
check_state_dict_equal,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo


@clear_cache_before_run()
@parameterize('model_name', ['transformers_gpt'])
@parameterize('plugin_type', ['ddp', 'zero', 'gemini'])
def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
(model_fn, data_gen_fn, output_transform_fn, loss_fn,
_) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = loss_fn

if plugin_type == 'ddp':
plugin = TorchDDPPlugin()
elif plugin_type == 'zero':
plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
elif plugin_type == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', precision="fp16", initial_scale=32)
else:
raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.")

booster = Booster(plugin=plugin)

model = model_fn().cuda()
model_huggingface_cls = model.__class__
optimizer = HybridAdam(model.parameters(), lr=0.001)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

data = data_gen_fn()
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
output = model(**data)
loss = criterion(output)

booster.backward(loss, optimizer)
optimizer.step()

with shared_tempdir() as tempdir:

model_ckpt_path = f"{tempdir}/model"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
dist.barrier()

new_model = model_huggingface_cls.from_pretrained(model_ckpt_path)
new_model = new_model.cuda()
new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

if plugin_type == 'gemini':
check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
new_model.unwrap().state_dict(only_rank_0=False), False)
else:
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
dist.barrier()


def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_from_pretrained()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_huggingface_compatibility(world_size):
spawn(run_dist, world_size)

0 comments on commit 31f4d4b

Please sign in to comment.