refine sparse gpt (#526)

* save cpu memory

* update

* update

* update

* update

* refine

* update

* update

---------

Co-authored-by: Your Name <you@example.com>
LKJacky and Your Name committed Apr 26, 2023
1 parent 595849b commit 1fbd872
Showing 8 changed files with 535 additions and 119 deletions.
4 changes: 2 additions & 2 deletions mmrazor/implementations/pruning/sparse_gpt/mutator.py
@@ -48,9 +48,9 @@ def end_init_hessian(self):
         for module in self.sparse_ops:
             module.end_init_hessian()

-    def keep_hessian_in_float(self):
+    def init_hessian(self, device=None):
         for op in self.sparse_ops:
-            op.keep_hessian_in_float()
+            op.init_hessian(device=device)

     # prune
     def prune(self,
51 changes: 40 additions & 11 deletions mmrazor/implementations/pruning/sparse_gpt/ops.py
@@ -20,9 +20,7 @@ def _sparse_gpt_mix_in_init(self):
         self.rows = self.weight_matrix.shape[0]
         self.columns = self.weight_matrix.shape[1]

-        _hessian = torch.zeros([self.columns, self.columns])
-        self.register_buffer('_hessian', _hessian)
-        self._hessian: torch.Tensor
+        self._hessian: torch.Tensor = None
         self.hessian_batch = 0

     # weight and input adaptive
@@ -50,17 +48,38 @@ def format_input(self, input: torch.Tensor):
     @property
     def hessian(self):
         """hessian always return float."""
-        return self._hessian
+        if dist.is_initialized():
+            if dist.get_rank() == 0:
+                assert self._hessian is not None, 'hessian is not initialized.'
+                hessian = self._hessian.to(self.weight_matrix.device)
+            else:
+                hessian = torch.zeros(
+                    self.columns,
+                    self.columns,
+                    device=self.weight_matrix.device)
+            dist.broadcast(hessian, 0)
+            return hessian
+        else:
+            return self._hessian

     @hessian.setter
     def hessian(self, value: torch.Tensor):
         with torch.no_grad():
-            self._hessian.data.copy_(value.data)
+            if dist.is_initialized():
+                if dist.get_rank() == 0:
+                    assert self._hessian is not None, 'hessian is not initialized.'  # noqa
+                    self._hessian.data.copy_(
+                        value.data.to(self._hessian.device))
+                else:
+                    self._hessian = None
+            else:
+                self._hessian.data.copy_(value.data.to(self._hessian.device))

     @torch.no_grad()
     def update_hessian(self, input: torch.Tensor):

         input = self.format_input(input).float()
+        H_save = self.hessian
+        H_save = H_save.to(input.device)

         assert len(input.shape) == 3
         B = input.shape[0]  # B N C
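The memory saving in this hunk comes from keeping the full columns x columns Hessian on rank 0 only: every other rank stores None and receives a broadcast copy whenever the property is read. A minimal standalone sketch of the same rank-0-owner pattern, with illustrative names that are not part of this diff:

    import torch
    import torch.distributed as dist

    def read_rank0_tensor(local, shape, device):
        """Rank 0 owns the tensor; every other rank gets a broadcast copy."""
        if not dist.is_initialized():
            return local
        if dist.get_rank() == 0:
            assert local is not None, 'owner rank must hold the tensor'
            t = local.to(device)
        else:
            t = torch.zeros(shape, device=device)  # scratch buffer for the broadcast
        dist.broadcast(t, src=0)
        return t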
@@ -71,8 +90,8 @@ def update_hessian(self, input: torch.Tensor):
         if dist.is_initialized():
             dist.all_reduce(H)
             B *= dist.get_world_size()
-        self.hessian = (self.hessian * self.hessian_batch + H) / (
-            self.hessian_batch + B)
+        H_save = (H_save * self.hessian_batch + H) / (self.hessian_batch + B)
+        self.hessian = H_save
         self.hessian_batch = self.hessian_batch + B

     def start_init_hessian(self):
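update_hessian keeps a running mean of the per-sample Hessian contributions: after n samples have been seen, folding in a batch of B gives H <- (H * n + H_batch) / (n + B). The new local H_save reads the property once at the top of the method and moves the result to the input's device, which matters now that the stored Hessian may live elsewhere (e.g. on rank 0 only). The accumulation in isolation, with illustrative names:

    # After n_seen samples, fold in a batch of size B so that H remains
    # the mean over all samples seen so far.
    H = (H * n_seen + H_batch) / (n_seen + B)
    n_seen += B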
@@ -89,8 +108,18 @@ def end_init_hessian(self):
         for h in self.sparse_gpt_handles:
             h.remove()

-    def keep_hessian_in_float(self):
-        self._hessian = self._hessian.float()
+    def init_hessian(self, device=None):
+        if dist.is_initialized():
+            if dist.get_rank() == 0:
+                self._hessian = torch.zeros([self.columns, self.columns],
+                                            device=device,
+                                            dtype=torch.float)
+            else:
+                self._hessian = None
+        else:
+            self._hessian = torch.zeros([self.columns, self.columns],
+                                        device=device,
+                                        dtype=torch.float)

     # prune

@@ -102,7 +131,7 @@ def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01):
         assert self.hessian is not None
         W: torch.Tensor = self.weight_matrix.float()  # out in

-        H = self.hessian.float()
+        H = self.hessian.float().to(W.device)

         dead = torch.diag(H) == 0
         H[dead, dead] = 1
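prune() accepts prunen/prunem arguments requesting N:M semi-structured sparsity; the LLaMA script below calls op.prune(0.5, prunen=2, prunem=4), i.e. at most 2 nonzeros in every group of 4 weights. As a toy illustration of that sparsity pattern only (magnitude-based, not the Hessian-aware criterion SparseGPT actually uses):

    import torch

    def nm_mask(w: torch.Tensor, n: int = 2, m: int = 4) -> torch.Tensor:
        """Keep the n largest-|w| entries in every group of m along the last dim."""
        groups = w.abs().reshape(-1, m)
        idx = groups.topk(n, dim=1).indices
        mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(1, idx, True)
        return mask.reshape(w.shape)

    w = torch.randn(4, 8)
    w_sparse = w * nm_mask(w)  # two nonzeros in each group of four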
@@ -116,11 +116,15 @@ def infer(model: nn.Module,
     num_samples = args.num_samples
     batch_size = args.batch_size

-    model = torchvision.models.resnet18(pretrained=True).cuda()
+    model = torchvision.models.resnet18(pretrained=True)
     train_loader, test_loader = get_dataloaders(batch_size, 4, data_path)

     mutator = sparse_gpt.SparseGptMutator()
     mutator.prepare_from_supernet(model)

+    model.cuda()
+
+    mutator.init_hessian()
     mutator.start_init_hessian()
     infer(model, test_loader, num_samples=num_samples)
     mutator.end_init_hessian()
@@ -1,6 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np
 import torch
+from torch.utils.data import DataLoader
 from torch.utils.data import Dataset as TorchDataset
+from torch.utils.data import DistributedSampler


 def set_seed(seed):
@@ -106,3 +108,45 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
         return get_ptb(nsamples, seed, seqlen, model)
     if 'c4' in name:
         return get_c4(nsamples, seed, seqlen, model)
+
+
+def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
+    # tokens: 1 N
+    N = tokens.shape[1]
+    num_drop = N % batch_seq_len
+    if num_drop != 0:
+        tokens = tokens[:, :-num_drop]
+    tokens = tokens.reshape([-1, batch_seq_len])  # B N
+    return tokens
+
+
+class LanguageDataset(TorchDataset):
+
+    def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
+        super().__init__()
+        # seq: 1, N
+        self.seq_len = seq_len
+
+        self.seq = fold_tokens(seq, seq_len)  # B N
+
+    def __len__(self) -> int:
+        return self.seq.shape[0]
+
+    def __getitem__(self, index):
+        return self.seq[index]
+
+
+def build_language_loader(testloader, world_size, rank, model, batch_size=128):
+    val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
+    distributed_sampler = DistributedSampler(
+        val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
+    batch_size = min(len(val_dataset) // world_size, batch_size)
+    val_dataloader = DataLoader(
+        val_dataset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=0,
+        pin_memory=True,
+        drop_last=True,
+        sampler=distributed_sampler)
+    return val_dataloader
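For reference, a hedged usage sketch of fold_tokens (the token values here are fake; in the scripts this stream comes from testloader.input_ids):

    import torch

    tokens = torch.randint(0, 32000, (1, 10000))      # 1 x N stream of token ids
    blocks = fold_tokens(tokens, batch_seq_len=2048)  # -> 4 x 2048; the 1808-token remainder is dropped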
@@ -0,0 +1,198 @@
import functools
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from datautils import build_language_loader, get_loaders
from llama_sparse_gpt import get_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from utils import init_on_meta, opt_eval_fsdp, opt_infer_fsdp

from mmrazor.implementations.pruning import sparse_gpt
from mmrazor.utils import print_log


def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12356'

dist.init_process_group('nccl', rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
print_log(f'init {rank}/{world_size}', only_rank0=False)


def init_fn_wrapper(model: nn.Module, model_copy: nn.Module):

def find_module_in_model_copy(module: nn.Module):
name2module = dict(model.named_modules())
module2name = dict([(v, k) for k, v in name2module.items()])

name = module2name[module]
return dict(model_copy.named_modules())[name]

    def _materialize_meta_module(module: nn.Module):

def meta_to_empty(p: torch.Tensor):
if p.device == torch.device('meta'):
return p.new_empty(p.shape, device='cpu')
else:
return p

module._apply(meta_to_empty)
if dist.get_rank() == 0:
assert model_copy is not None
module_copy = find_module_in_model_copy(module)

name2p = dict(module_copy.named_parameters(remove_duplicate=False))
for n, p in module.named_parameters():
if '_flat_param' not in n:
n = n.replace('_fsdp_wrapped_module.', '')
try:
p.data.copy_(name2p[n])
except Exception:
pass
name2p = dict(module_copy.named_buffers(remove_duplicate=False))
for n, p in module.named_buffers():
if '_flat_param' not in n:
n = n.replace('_fsdp_wrapped_module.', '')
try:
p.data.copy_(name2p[n])
except Exception:
pass

return _materialize_meta_module
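The model is first constructed on the meta device (see init_on_meta below), so it has no real storage; FSDP then materializes each wrapped submodule through this param_init_fn. Meta tensors become empty CPU tensors, rank 0 copies real weights in from the CPU model copy, and sync_module_states=True broadcasts rank 0's values to the other ranks. A minimal sketch of just the materialization step, with illustrative names not taken from this file:

    import torch
    import torch.nn as nn

    def materialize_meta_module(module: nn.Module) -> None:
        """Swap meta-device tensors for uninitialized CPU storage."""

        def meta_to_empty(t: torch.Tensor):
            if t.device == torch.device('meta'):
                return t.new_empty(t.shape, device='cpu')
            return t

        module._apply(meta_to_empty)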


def main(rank, world_size=8, args=None):
setup(rank, world_size)

model_name = args.model
batch_size = args.batch_size

def build():
model = get_model(model_name)

# init mutator
mutator = sparse_gpt.SparseGptMutator()
mutator.prepare_from_supernet(model.model.layers)
return model, mutator

with init_on_meta(enable=True):
model, mutator = build()

if rank == 0:
model_copy, _ = build() # init on cpu
else:
model_copy = None

# init fsdp
size_based_auto_wrap_policy_x = functools.partial(
size_based_auto_wrap_policy, min_num_params=int(1e8))

model = FSDP(
model,
auto_wrap_policy=size_based_auto_wrap_policy_x,
cpu_offload=CPUOffload(True),
sharding_strategy=ShardingStrategy.FULL_SHARD,
device_id=rank,
param_init_fn=init_fn_wrapper(model, model_copy),
sync_module_states=True)
print_log(model)

# init hessian

mutator.init_hessian(device='cuda')
mutator.start_init_hessian()

_, testloader = get_loaders(
args.dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
testloader = build_language_loader(
testloader, world_size, rank, model, batch_size=batch_size)
opt_infer_fsdp(model, testloader)

mutator.end_init_hessian()

# prune
name2module = dict(model.named_modules())
    module2name = dict([(v, k) for k, v in name2module.items()])

with torch.no_grad():
for fsdp in FSDP.fsdp_modules(model):
fsdp._reset_lazy_init()
with FSDP.summon_full_params(fsdp, recurse=False):
fsdp_name = module2name[fsdp]
for name, op in fsdp.named_modules():
if name.count('_fsdp_wrapped_module') <= 1:
if isinstance(op, sparse_gpt.SparseGptMixIn):
try:
op.prune(0.5, prunen=2, prunem=4)
print_log(
f'prune {fsdp_name}.{name} successfully.', # noqa
only_rank0=True)
except Exception as e:
print_log(
f'prune {fsdp_name}.{name} failed, as {e}', # noqa
only_rank0=True)
fsdp._reset_lazy_init()

# save
if args.save:
print_log(f'save model in {args.save}')
model._reset_lazy_init()
with FSDP.summon_full_params(model, rank0_only=True, writeback=False):
if dist.get_rank() == 0:
model.save_pretrained(args.save)

# val
torch.cuda.empty_cache()
model._reset_lazy_init()
for dataset in ['wikitext2', 'ptb', 'c4']:
_, testloader = get_loaders(
dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
testloader = build_language_loader(
testloader, world_size, rank, model, batch_size=batch_size)
print_log(dataset)
opt_eval_fsdp(model, testloader, torch.device('cuda'))


if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()

    parser.add_argument(
        'model', type=str, help='LLaMA model to load; pass a HuggingFace model path.')
parser.add_argument(
'dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
parser.add_argument(
'--seed',
type=int,
default=0,
help='Seed for sampling the calibration data.')
parser.add_argument(
'--nsamples',
type=int,
default=128,
help='Number of calibration data samples.')
parser.add_argument(
'--batch_size',
type=int,
default=64,
help='Batchsize for calibration and evaluation.')

parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'--world_size', type=int, default=1, help='Number of GPUs to use.')
args = parser.parse_args()

WORLD_SIZE = args.world_size
mp.spawn(main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
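mp.spawn launches one process per GPU; each rank runs main() and joins the NCCL process group on localhost:12356. An invocation would look roughly like the following, where the script filename and model path are placeholders rather than names from this commit:

    python <this_script>.py <model-path> c4 --batch_size 64 --world_size 8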