From 1fbd8726e4e747bec56fc983e7448c93082ff04d Mon Sep 17 00:00:00 2001 From: LKJacky <108643365+LKJacky@users.noreply.github.com> Date: Wed, 26 Apr 2023 15:55:14 +0800 Subject: [PATCH] refine sparse gpt (#526) * save cpu memory * update * update * update * update * refine * update * update --------- Co-authored-by: Your Name --- .../pruning/sparse_gpt/mutator.py | 4 +- .../implementations/pruning/sparse_gpt/ops.py | 51 ++++- .../ResNet/sparse_gpt/resnet18_sparse_gpt.py | 6 +- .../language_models/Llama/datautils.py | 46 +++- .../Llama/llama_sparse_gpt_fsdp.py | 198 ++++++++++++++++++ .../language_models/Llama/utils.py | 91 ++++++++ .../OPT/opt_sparse_gpt_fsdp.py | 182 +++++++--------- .../language_models/OPT/utils.py | 76 +++++++ 8 files changed, 535 insertions(+), 119 deletions(-) create mode 100644 projects/mmrazor_large/examples/model_examples/language_models/Llama/llama_sparse_gpt_fsdp.py diff --git a/mmrazor/implementations/pruning/sparse_gpt/mutator.py b/mmrazor/implementations/pruning/sparse_gpt/mutator.py index 3ef53c666..c4406b68b 100644 --- a/mmrazor/implementations/pruning/sparse_gpt/mutator.py +++ b/mmrazor/implementations/pruning/sparse_gpt/mutator.py @@ -48,9 +48,9 @@ def end_init_hessian(self): for module in self.sparse_ops: module.end_init_hessian() - def keep_hessian_in_float(self): + def init_hessian(self, device=None): for op in self.sparse_ops: - op.keep_hessian_in_float() + op.init_hessian(device=device) # prune def prune(self, diff --git a/mmrazor/implementations/pruning/sparse_gpt/ops.py b/mmrazor/implementations/pruning/sparse_gpt/ops.py index 8f438e76f..9cba48aaf 100644 --- a/mmrazor/implementations/pruning/sparse_gpt/ops.py +++ b/mmrazor/implementations/pruning/sparse_gpt/ops.py @@ -20,9 +20,7 @@ def _sparse_gpt_mix_in_init(self): self.rows = self.weight_matrix.shape[0] self.columns = self.weight_matrix.shape[1] - _hessian = torch.zeros([self.columns, self.columns]) - self.register_buffer('_hessian', _hessian) - self._hessian: torch.Tensor + self._hessian: torch.Tensor = None self.hessian_batch = 0 # weight and input adaptive @@ -50,17 +48,38 @@ def format_input(self, input: torch.Tensor): @property def hessian(self): """hessian always return float.""" - return self._hessian + if dist.is_initialized(): + if dist.get_rank() == 0: + assert self._hessian is not None, 'hessian is not initialized.' + hessian = self._hessian.to(self.weight_matrix.device) + else: + hessian = torch.zeros( + self.columns, + self.columns, + device=self.weight_matrix.device) + dist.broadcast(hessian, 0) + return hessian + else: + return self._hessian @hessian.setter def hessian(self, value: torch.Tensor): with torch.no_grad(): - self._hessian.data.copy_(value.data) + if dist.is_initialized(): + if dist.get_rank() == 0: + assert self._hessian is not None, 'hessian is not initialized.' 
# noqa + self._hessian.data.copy_( + value.data.to(self._hessian.device)) + else: + self._hessian = None + else: + self._hessian.data.copy_(value.data.to(self._hessian.device)) @torch.no_grad() def update_hessian(self, input: torch.Tensor): - input = self.format_input(input).float() + H_save = self.hessian + H_save = H_save.to(input.device) assert len(input.shape) == 3 B = input.shape[0] # B N C @@ -71,8 +90,8 @@ def update_hessian(self, input: torch.Tensor): if dist.is_initialized(): dist.all_reduce(H) B *= dist.get_world_size() - self.hessian = (self.hessian * self.hessian_batch + H) / ( - self.hessian_batch + B) + H_save = (H_save * self.hessian_batch + H) / (self.hessian_batch + B) + self.hessian = H_save self.hessian_batch = self.hessian_batch + B def start_init_hessian(self): @@ -89,8 +108,18 @@ def end_init_hessian(self): for h in self.sparse_gpt_handles: h.remove() - def keep_hessian_in_float(self): - self._hessian = self._hessian.float() + def init_hessian(self, device=None): + if dist.is_initialized(): + if dist.get_rank() == 0: + self._hessian = torch.zeros([self.columns, self.columns], + device=device, + dtype=torch.float) + else: + self._hessian = None + else: + self._hessian = torch.zeros([self.columns, self.columns], + device=device, + dtype=torch.float) # prune @@ -102,7 +131,7 @@ def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01): assert self.hessian is not None W: torch.Tensor = self.weight_matrix.float() # out in - H = self.hessian.float() + H = self.hessian.float().to(W.device) dead = torch.diag(H) == 0 H[dead, dead] = 1 diff --git a/projects/mmrazor_large/examples/model_examples/ResNet/sparse_gpt/resnet18_sparse_gpt.py b/projects/mmrazor_large/examples/model_examples/ResNet/sparse_gpt/resnet18_sparse_gpt.py index bd3cf3cb0..be703f2a5 100644 --- a/projects/mmrazor_large/examples/model_examples/ResNet/sparse_gpt/resnet18_sparse_gpt.py +++ b/projects/mmrazor_large/examples/model_examples/ResNet/sparse_gpt/resnet18_sparse_gpt.py @@ -116,11 +116,15 @@ def infer(model: nn.Module, num_samples = args.num_samples batch_size = args.batch_size - model = torchvision.models.resnet18(pretrained=True).cuda() + model = torchvision.models.resnet18(pretrained=True) train_loader, test_loader = get_dataloaders(batch_size, 4, data_path) mutator = sparse_gpt.SparseGptMutator() mutator.prepare_from_supernet(model) + + model.cuda() + + mutator.init_hessian() mutator.start_init_hessian() infer(model, test_loader, num_samples=num_samples) mutator.end_init_hessian() diff --git a/projects/mmrazor_large/examples/model_examples/language_models/Llama/datautils.py b/projects/mmrazor_large/examples/model_examples/language_models/Llama/datautils.py index 4066d92b1..04697d560 100755 --- a/projects/mmrazor_large/examples/model_examples/language_models/Llama/datautils.py +++ b/projects/mmrazor_large/examples/model_examples/language_models/Llama/datautils.py @@ -1,6 +1,8 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
import numpy as np import torch +from torch.utils.data import DataLoader +from torch.utils.data import Dataset as TorchDataset +from torch.utils.data import DistributedSampler def set_seed(seed): @@ -106,3 +108,45 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''): return get_ptb(nsamples, seed, seqlen, model) if 'c4' in name: return get_c4(nsamples, seed, seqlen, model) + + +def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048): + # tokens: 1 N + N = tokens.shape[1] + num_drop = N % batch_seq_len + if num_drop != 0: + tokens = tokens[:, :-num_drop] + tokens = tokens.reshape([-1, batch_seq_len]) # B N + return tokens + + +class LanguageDataset(TorchDataset): + + def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None: + super().__init__() + # seq: 1, N + self.seq_len = seq_len + + self.seq = fold_tokens(seq) # B N + + def __len__(self) -> int: + return self.seq.shape[0] + + def __getitem__(self, index): + return self.seq[index] + + +def build_language_loader(testloader, world_size, rank, model, batch_size=128): + val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen) + distributed_sampler = DistributedSampler( + val_dataset, num_replicas=world_size, rank=rank, shuffle=False) + batch_size = min(len(val_dataset) // world_size, batch_size) + val_dataloader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=0, + pin_memory=True, + drop_last=True, + sampler=distributed_sampler) + return val_dataloader diff --git a/projects/mmrazor_large/examples/model_examples/language_models/Llama/llama_sparse_gpt_fsdp.py b/projects/mmrazor_large/examples/model_examples/language_models/Llama/llama_sparse_gpt_fsdp.py new file mode 100644 index 000000000..7728bce3d --- /dev/null +++ b/projects/mmrazor_large/examples/model_examples/language_models/Llama/llama_sparse_gpt_fsdp.py @@ -0,0 +1,198 @@ +import functools +import os + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from datautils import build_language_loader, get_loaders +from llama_sparse_gpt import get_model +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import ShardingStrategy +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload +from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy +from utils import init_on_meta, opt_eval_fsdp, opt_infer_fsdp + +from mmrazor.implementations.pruning import sparse_gpt +from mmrazor.utils import print_log + + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12356' + + dist.init_process_group('nccl', rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + print_log(f'init {rank}/{world_size}', only_rank0=False) + + +def init_fn_wrapper(model: nn.Module, model_copy: nn.Module): + + def find_module_in_model_copy(module: nn.Module): + name2module = dict(model.named_modules()) + module2name = dict([(v, k) for k, v in name2module.items()]) + + name = module2name[module] + return dict(model_copy.named_modules())[name] + + def _materialize_meta_module(module: nn.Module, ): + + def meta_to_empty(p: torch.Tensor): + if p.device == torch.device('meta'): + return p.new_empty(p.shape, device='cpu') + else: + return p + + module._apply(meta_to_empty) + if dist.get_rank() == 0: + assert model_copy is not None + module_copy = find_module_in_model_copy(module) + + name2p = dict(module_copy.named_parameters(remove_duplicate=False)) + for n, 
p in module.named_parameters(): + if '_flat_param' not in n: + n = n.replace('_fsdp_wrapped_module.', '') + try: + p.data.copy_(name2p[n]) + except Exception: + pass + name2p = dict(module_copy.named_buffers(remove_duplicate=False)) + for n, p in module.named_buffers(): + if '_flat_param' not in n: + n = n.replace('_fsdp_wrapped_module.', '') + try: + p.data.copy_(name2p[n]) + except Exception: + pass + + return _materialize_meta_module + + +def main(rank, world_size=8, args=None): + setup(rank, world_size) + + model_name = args.model + batch_size = args.batch_size + + def build(): + model = get_model(model_name) + + # init mutator + mutator = sparse_gpt.SparseGptMutator() + mutator.prepare_from_supernet(model.model.layers) + return model, mutator + + with init_on_meta(enable=True): + model, mutator = build() + + if rank == 0: + model_copy, _ = build() # init on cpu + else: + model_copy = None + + # init fsdp + size_based_auto_wrap_policy_x = functools.partial( + size_based_auto_wrap_policy, min_num_params=int(1e8)) + + model = FSDP( + model, + auto_wrap_policy=size_based_auto_wrap_policy_x, + cpu_offload=CPUOffload(True), + sharding_strategy=ShardingStrategy.FULL_SHARD, + device_id=rank, + param_init_fn=init_fn_wrapper(model, model_copy), + sync_module_states=True) + print_log(model) + + # init hessian + + mutator.init_hessian(device='cuda') + mutator.start_init_hessian() + + _, testloader = get_loaders( + args.dataset, seed=args.seed, model=model_name, seqlen=model.seqlen) + testloader = build_language_loader( + testloader, world_size, rank, model, batch_size=batch_size) + opt_infer_fsdp(model, testloader) + + mutator.end_init_hessian() + + # prune + name2module = dict(model.named_modules()) + module2name = {} + module2name = dict([(v, k) for k, v in name2module.items()]) + + with torch.no_grad(): + for fsdp in FSDP.fsdp_modules(model): + fsdp._reset_lazy_init() + with FSDP.summon_full_params(fsdp, recurse=False): + fsdp_name = module2name[fsdp] + for name, op in fsdp.named_modules(): + if name.count('_fsdp_wrapped_module') <= 1: + if isinstance(op, sparse_gpt.SparseGptMixIn): + try: + op.prune(0.5, prunen=2, prunem=4) + print_log( + f'prune {fsdp_name}.{name} successfully.', # noqa + only_rank0=True) + except Exception as e: + print_log( + f'prune {fsdp_name}.{name} failed, as {e}', # noqa + only_rank0=True) + fsdp._reset_lazy_init() + + # save + if args.save: + print_log(f'save model in {args.save}') + model._reset_lazy_init() + with FSDP.summon_full_params(model, rank0_only=True, writeback=False): + if dist.get_rank() == 0: + model.save_pretrained(args.save) + + # val + torch.cuda.empty_cache() + model._reset_lazy_init() + for dataset in ['wikitext2', 'ptb', 'c4']: + _, testloader = get_loaders( + dataset, seed=args.seed, model=model_name, seqlen=model.seqlen) + testloader = build_language_loader( + testloader, world_size, rank, model, batch_size=batch_size) + print_log(dataset) + opt_eval_fsdp(model, testloader, torch.device('cuda')) + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + + parser.add_argument( + 'model', type=str, help='OPT model to load; pass `facebook/opt-X`.') + parser.add_argument( + 'dataset', + type=str, + choices=['wikitext2', 'ptb', 'c4'], + help='Where to extract calibration data from.') + parser.add_argument( + '--seed', + type=int, + default=0, + help='Seed for sampling the calibration data.') + parser.add_argument( + '--nsamples', + type=int, + default=128, + help='Number of calibration data samples.') + 
parser.add_argument( + '--batch_size', + type=int, + default=64, + help='Batchsize for calibration and evaluation.') + + parser.add_argument( + '--save', type=str, default='', help='Path to saved model.') + parser.add_argument( + '--world_size', type=int, default=1, help='Number of GPUs to use.') + args = parser.parse_args() + + WORLD_SIZE = args.world_size + mp.spawn(main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) diff --git a/projects/mmrazor_large/examples/model_examples/language_models/Llama/utils.py b/projects/mmrazor_large/examples/model_examples/language_models/Llama/utils.py index 8d3beb2ae..a728a2268 100644 --- a/projects/mmrazor_large/examples/model_examples/language_models/Llama/utils.py +++ b/projects/mmrazor_large/examples/model_examples/language_models/Llama/utils.py @@ -2,6 +2,8 @@ # Example for opt is converted from https://github.com/ist-daslab/sparsegpt import torch import torch.nn as nn +from torch import distributed as dist +from torch.utils.data import DataLoader from transformers import OPTForCausalLM from mmrazor.utils import print_log @@ -78,3 +80,92 @@ def opt_infer( if (i + 1) * batch_size >= num_samples: break + + +class init_on_meta: + + def __init__(self, enable=True) -> None: + self.enable = enable + self.default_device = torch.ones([]).device + + def __enter__(self): + if self.enable: + torch.set_default_device('meta') + + def __exit__(self, exc_type, exc_value, traceback): + if self.enable: + torch.set_default_device(self.default_device) + + +@torch.no_grad() +def opt_eval_fsdp( + model: nn.Module, + dataloader: DataLoader, + dev=torch.device('cuda:0'), +): + print_log('Evaluating ...') + + use_cache = model.config.use_cache + model.config.use_cache = False + loss_sum = torch.zeros([1], device=dev) + total_seq_len = torch.zeros([1], device=dev, dtype=torch.long) + + for i, batch in enumerate(dataloader): + B, seq_len = batch.shape[:2] + + batch = batch.to(dev) + out: torch.Tensor = model(batch)[0] # 1 + + shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C + shift_labels = batch[:, 1:].flatten() # (B N) + + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + neg_log_likelihood = loss.float() * seq_len * B + total_seq_len += seq_len * B + loss_sum += neg_log_likelihood + + if dist.is_initialized(): + world_size = dist.get_world_size() + else: + world_size = 1 + infered_batch = (i + 1) * B * world_size + + print_log(f'{infered_batch} / {len(dataloader.dataset)}') + + if dist.is_initialized(): + dist.all_reduce(loss_sum) + dist.all_reduce(total_seq_len) + + ppl = torch.exp(loss_sum / total_seq_len) + print_log(f'Perplexity: {ppl.item():3f}') + model.config.use_cache = use_cache + + +@torch.no_grad() +def opt_infer_fsdp( + model: nn.Module, + dataloader: DataLoader, + dev=torch.device('cuda:0'), + num_samples=128, +): + print_log('Infering ...') + + model.config.use_cache = False + + for i, batch in enumerate(dataloader): + B = batch.shape[0] + + batch = batch.to(dev) + model(batch)[0] # 1 + + if dist.is_initialized(): + world_size = dist.get_world_size() + else: + world_size = 1 + infered_batch = (i + 1) * B * world_size + + print_log(f'{infered_batch} / {len(dataloader.dataset)}') + if infered_batch >= num_samples: + break diff --git a/projects/mmrazor_large/examples/model_examples/language_models/OPT/opt_sparse_gpt_fsdp.py b/projects/mmrazor_large/examples/model_examples/language_models/OPT/opt_sparse_gpt_fsdp.py index dbac11146..95dc772c2 100644 --- 
a/projects/mmrazor_large/examples/model_examples/language_models/OPT/opt_sparse_gpt_fsdp.py +++ b/projects/mmrazor_large/examples/model_examples/language_models/OPT/opt_sparse_gpt_fsdp.py @@ -11,8 +11,7 @@ from torch.distributed.fsdp.api import ShardingStrategy from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy -from torch.utils.data import DataLoader -from utils import init_on_meta +from utils import init_on_meta, opt_eval_fsdp, opt_infer_fsdp from mmrazor.implementations.pruning import sparse_gpt from mmrazor.utils import print_log @@ -27,73 +26,46 @@ def setup(rank, world_size): print_log(f'init {rank}/{world_size}', only_rank0=False) -@torch.no_grad() -def opt_eval( - model: nn.Module, - dataloader: DataLoader, - dev=torch.device('cuda:0'), -): - print_log('Evaluating ...') +def init_fn_wrapper(model: nn.Module, model_copy: nn.Module): - use_cache = model.config.use_cache - model.config.use_cache = False - loss_sum = torch.zeros([1], device=dev) - total_seq_len = torch.zeros([1], device=dev, dtype=torch.long) + def find_module_in_model_copy(module: nn.Module): + name2module = dict(model.named_modules()) + module2name = dict([(v, k) for k, v in name2module.items()]) - for i, batch in enumerate(dataloader): - B, seq_len = batch.shape[:2] + name = module2name[module] + return dict(model_copy.named_modules())[name] - batch = batch.to(dev) - out: torch.Tensor = model(batch)[0] # 1 + def _materialize_meta_module(module: nn.Module, ): - shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C - shift_labels = batch[:, 1:].flatten() # (B N) + def meta_to_empty(p: torch.Tensor): + if p.device == torch.device('meta'): + return p.new_empty(p.shape, device='cpu') + else: + return p - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct(shift_logits, shift_labels) + module._apply(meta_to_empty) + if dist.get_rank() == 0: + assert model_copy is not None + module_copy = find_module_in_model_copy(module) - neg_log_likelihood = loss.float() * seq_len * B - total_seq_len += seq_len * B - loss_sum += neg_log_likelihood + name2p = dict(module_copy.named_parameters(remove_duplicate=False)) + for n, p in module.named_parameters(): + if '_flat_param' not in n: + n = n.replace('_fsdp_wrapped_module.', '') + try: + p.data.copy_(name2p[n]) + except Exception: + pass + name2p = dict(module_copy.named_buffers(remove_duplicate=False)) + for n, p in module.named_buffers(): + if '_flat_param' not in n: + n = n.replace('_fsdp_wrapped_module.', '') + try: + p.data.copy_(name2p[n]) + except Exception: + pass - print_log(f'{(i+1)*B} / {len(dataloader.dataset)}', only_rank0=False) - - if dist.is_initialized(): - dist.all_reduce(loss_sum) - dist.all_reduce(total_seq_len) - - ppl = torch.exp(loss_sum / total_seq_len) - print_log(f'Perplexity: {ppl.item():3f}') - model.config.use_cache = use_cache - - -@torch.no_grad() -def opt_infer( - model: nn.Module, - dataloader: DataLoader, - dev=torch.device('cuda:0'), -): - print_log('Infering ...') - - model.config.use_cache = False - - for i, batch in enumerate(dataloader): - B, seq_len = batch.shape[:2] - - batch = batch.to(dev) - model(batch)[0] # 1 - - print_log(f'{(i+1)*B} / {len(dataloader.dataset)}') - - -def _materialize_meta_module(module: nn.Module, ): - # Run default meta device initialization - - module.to_empty(device=torch.device('cpu')) - for p in module.parameters(): - p.data.fill_(0) - for p in module.buffers(): - p.data.fill_(0) + return _materialize_meta_module def 
main(rank, world_size=8, args=None): @@ -102,20 +74,25 @@ def main(rank, world_size=8, args=None): model_name = args.model batch_size = args.batch_size - with init_on_meta(enable=args.m): - if args.m: - print_log('init on meta') + def build(): model = get_model(model_name) # init mutator mutator = sparse_gpt.SparseGptMutator() mutator.prepare_from_supernet(model.model.decoder) + return model, mutator - mutator.keep_hessian_in_float() + with init_on_meta(enable=True): + model, mutator = build() + + if rank == 0: + model_copy, _ = build() # init on cpu + else: + model_copy = None # init fsdp size_based_auto_wrap_policy_x = functools.partial( - size_based_auto_wrap_policy) + size_based_auto_wrap_policy, min_num_params=int(1e8)) model = FSDP( model, @@ -123,64 +100,65 @@ def main(rank, world_size=8, args=None): cpu_offload=CPUOffload(True), sharding_strategy=ShardingStrategy.FULL_SHARD, device_id=rank, - param_init_fn=_materialize_meta_module) - + param_init_fn=init_fn_wrapper(model, model_copy), + sync_module_states=True) print_log(model) # init hessian + mutator.init_hessian(device='cuda') mutator.start_init_hessian() _, testloader = get_loaders( - 'c4', seed=args.seed, model=model_name, seqlen=model.seqlen) + args.dataset, seed=args.seed, model=model_name, seqlen=model.seqlen) testloader = build_language_loader( testloader, world_size, rank, model, batch_size=batch_size) - opt_infer(model, testloader) + opt_infer_fsdp(model, testloader) mutator.end_init_hessian() # prune + name2module = dict(model.named_modules()) + module2name = {} + module2name = dict([(v, k) for k, v in name2module.items()]) with torch.no_grad(): - - total_num_op = 0 for fsdp in FSDP.fsdp_modules(model): - if len(FSDP.fsdp_modules(fsdp)) == 1: - fsdp._reset_lazy_init() - with FSDP.summon_full_params(fsdp): - num_op = 0 - for name, op in fsdp.named_modules(): - if isinstance(op, sparse_gpt.SparseGptMixIn): - if num_op % world_size == rank: - try: - op.prune(0.5, prunen=2, prunem=4) - print_log( - f'prune {name} on rank:{rank} successfully.', # noqa - only_rank0=False) - except Exception as e: - print_log( - f'prune {name} on rank:{rank} failed, as {e}', # noqa - only_rank0=False) - num_op += 1 - num_op = 0 - for name, op in fsdp.named_modules(): + fsdp._reset_lazy_init() + with FSDP.summon_full_params(fsdp, recurse=False): + fsdp_name = module2name[fsdp] + for name, op in fsdp.named_modules(): + if name.count('_fsdp_wrapped_module') <= 1: if isinstance(op, sparse_gpt.SparseGptMixIn): - dist.broadcast(op.weight, num_op % world_size) - num_op += 1 - total_num_op += num_op + try: + op.prune(0.5, prunen=2, prunem=4) + print_log( + f'prune {fsdp_name}.{name} successfully.', # noqa + only_rank0=True) + except Exception as e: + print_log( + f'prune {fsdp_name}.{name} failed, as {e}', # noqa + only_rank0=True) + fsdp._reset_lazy_init() + + # save + if args.save: + print_log(f'save model in {args.save}') + model._reset_lazy_init() + with FSDP.summon_full_params(model, rank0_only=True, writeback=False): + if dist.get_rank() == 0: + model.save_pretrained(args.save) - fsdp._reset_lazy_init() - torch.cuda.empty_cache() # val torch.cuda.empty_cache() model._reset_lazy_init() for dataset in ['wikitext2', 'ptb', 'c4']: _, testloader = get_loaders( - dataset, seed=1000, model=model_name, seqlen=model.seqlen) + dataset, seed=args.seed, model=model_name, seqlen=model.seqlen) testloader = build_language_loader( testloader, world_size, rank, model, batch_size=batch_size) print_log(dataset) - opt_eval(model, testloader, torch.device('cuda')) + 
opt_eval_fsdp(model, testloader, torch.device('cuda')) if __name__ == '__main__': @@ -209,11 +187,7 @@ def main(rank, world_size=8, args=None): type=int, default=64, help='Batchsize for calibration and evaluation.') - parser.add_argument( - '-m', - type=bool, - default=False, - help='Init on meta device to save memory.') + parser.add_argument( '--save', type=str, default='', help='Path to saved model.') parser.add_argument( diff --git a/projects/mmrazor_large/examples/model_examples/language_models/OPT/utils.py b/projects/mmrazor_large/examples/model_examples/language_models/OPT/utils.py index fc494214f..a728a2268 100644 --- a/projects/mmrazor_large/examples/model_examples/language_models/OPT/utils.py +++ b/projects/mmrazor_large/examples/model_examples/language_models/OPT/utils.py @@ -2,6 +2,8 @@ # Example for opt is converted from https://github.com/ist-daslab/sparsegpt import torch import torch.nn as nn +from torch import distributed as dist +from torch.utils.data import DataLoader from transformers import OPTForCausalLM from mmrazor.utils import print_log @@ -93,3 +95,77 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): if self.enable: torch.set_default_device(self.default_device) + + +@torch.no_grad() +def opt_eval_fsdp( + model: nn.Module, + dataloader: DataLoader, + dev=torch.device('cuda:0'), +): + print_log('Evaluating ...') + + use_cache = model.config.use_cache + model.config.use_cache = False + loss_sum = torch.zeros([1], device=dev) + total_seq_len = torch.zeros([1], device=dev, dtype=torch.long) + + for i, batch in enumerate(dataloader): + B, seq_len = batch.shape[:2] + + batch = batch.to(dev) + out: torch.Tensor = model(batch)[0] # 1 + + shift_logits = out[:, :-1, :].contiguous().flatten(0, 1) # (B N) C + shift_labels = batch[:, 1:].flatten() # (B N) + + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(shift_logits, shift_labels) + + neg_log_likelihood = loss.float() * seq_len * B + total_seq_len += seq_len * B + loss_sum += neg_log_likelihood + + if dist.is_initialized(): + world_size = dist.get_world_size() + else: + world_size = 1 + infered_batch = (i + 1) * B * world_size + + print_log(f'{infered_batch} / {len(dataloader.dataset)}') + + if dist.is_initialized(): + dist.all_reduce(loss_sum) + dist.all_reduce(total_seq_len) + + ppl = torch.exp(loss_sum / total_seq_len) + print_log(f'Perplexity: {ppl.item():3f}') + model.config.use_cache = use_cache + + +@torch.no_grad() +def opt_infer_fsdp( + model: nn.Module, + dataloader: DataLoader, + dev=torch.device('cuda:0'), + num_samples=128, +): + print_log('Infering ...') + + model.config.use_cache = False + + for i, batch in enumerate(dataloader): + B = batch.shape[0] + + batch = batch.to(dev) + model(batch)[0] # 1 + + if dist.is_initialized(): + world_size = dist.get_world_size() + else: + world_size = 1 + infered_batch = (i + 1) * B * world_size + + print_log(f'{infered_batch} / {len(dataloader.dataset)}') + if infered_batch >= num_samples: + break
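
A minimal usage sketch of the calibration-and-prune flow after this patch, following the resnet18_sparse_gpt.py example above: hessian buffers are now allocated explicitly via init_hessian() instead of being registered in __init__. The random tensors below stand in for a real calibration loader, and the 0.5 / 2:4 sparsity settings simply mirror the examples in this patch; this is an illustrative sketch, not part of the diff.

import torch
import torchvision

from mmrazor.implementations.pruning import sparse_gpt

# Build the dense model on CPU; the mutator only needs the module structure here.
model = torchvision.models.resnet18(pretrained=True)
mutator = sparse_gpt.SparseGptMutator()
mutator.prepare_from_supernet(model)
model.cuda()

# Allocate float32 hessian buffers explicitly (new in this patch), then register
# the forward hooks that accumulate them during calibration.
mutator.init_hessian(device='cuda')
mutator.start_init_hessian()
with torch.no_grad():
    for _ in range(8):  # a few calibration batches; random data is only a placeholder
        model(torch.randn(2, 3, 224, 224, device='cuda'))
mutator.end_init_hessian()

# 2:4 semi-structured pruning at 50% sparsity, matching the examples above.
with torch.no_grad():
    for op in mutator.sparse_ops:
        op.prune(0.5, prunen=2, prunem=4)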
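
The memory saving itself comes from two patterns applied throughout this patch: the hessian accumulator lives only on rank 0 and is broadcast to the other ranks on read, and the FSDP examples build the model on the meta device while only rank 0 keeps a materialized CPU copy (param_init_fn plus sync_module_states=True then populates the shards). Below is a toy, standalone illustration of the first pattern, assuming a process group is already initialized and that columns and device are given; the helper names are hypothetical.

import torch
import torch.distributed as dist


def init_accumulator(columns, device=None):
    # Only rank 0 pays for the persistent float32 buffer; every other rank stores None.
    if dist.get_rank() == 0:
        return torch.zeros(columns, columns, device=device, dtype=torch.float)
    return None


def read_accumulator(acc, columns, device):
    # Rank 0 moves its buffer to the compute device; the other ranks allocate a
    # same-shaped placeholder, and the broadcast fills it with rank 0's values.
    if dist.get_rank() == 0:
        buf = acc.to(device)
    else:
        buf = torch.zeros(columns, columns, device=device, dtype=torch.float)
    dist.broadcast(buf, src=0)
    return buf

This is the same shape as init_hessian() and the hessian property in ops.py above: during FSDP calibration only rank 0 holds a persistent copy of each columns x columns matrix, which is what the "save cpu memory" commit message refers to.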