refine sparse gpt #526

Merged · 8 commits · Apr 26, 2023
4 changes: 2 additions & 2 deletions mmrazor/implementations/pruning/sparse_gpt/mutator.py
@@ -48,9 +48,9 @@ def end_init_hessian(self):
for module in self.sparse_ops:
module.end_init_hessian()

def keep_hessian_in_float(self):
def init_hessian(self, device=None):
for op in self.sparse_ops:
op.keep_hessian_in_float()
op.init_hessian(device=device)

# prune
def prune(self,
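At the mutator level, keep_hessian_in_float is renamed to init_hessian(device=None) and simply forwards the device to every sparse op; the per-layer Hessian buffers are now allocated explicitly (and, under distributed training, only on rank 0, see ops.py below) instead of being registered as zero buffers at construction time. A minimal usage sketch, assuming a model already handed to prepare_from_supernet (the 'cuda' device string is just an example):

mutator = sparse_gpt.SparseGptMutator()
mutator.prepare_from_supernet(model)
mutator.init_hessian(device='cuda')   # replaces the old mutator.keep_hessian_in_float()
mutator.start_init_hessian()          # register the forward hooks that accumulate Hessians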
51 changes: 40 additions & 11 deletions mmrazor/implementations/pruning/sparse_gpt/ops.py
@@ -20,9 +20,7 @@ def _sparse_gpt_mix_in_init(self):
self.rows = self.weight_matrix.shape[0]
self.columns = self.weight_matrix.shape[1]

_hessian = torch.zeros([self.columns, self.columns])
self.register_buffer('_hessian', _hessian)
self._hessian: torch.Tensor
self._hessian: torch.Tensor = None
self.hessian_batch = 0

# weight and input adaptive
@@ -50,17 +48,38 @@ def format_input(self, input: torch.Tensor):
@property
def hessian(self):
"""hessian always return float."""
return self._hessian
if dist.is_initialized():
if dist.get_rank() == 0:
assert self._hessian is not None, 'hessian is not initialized.'
hessian = self._hessian.to(self.weight_matrix.device)
else:
hessian = torch.zeros(
self.columns,
self.columns,
device=self.weight_matrix.device)
dist.broadcast(hessian, 0)
return hessian
else:
return self._hessian

@hessian.setter
def hessian(self, value: torch.Tensor):
with torch.no_grad():
self._hessian.data.copy_(value.data)
if dist.is_initialized():
if dist.get_rank() == 0:
assert self._hessian is not None, 'hessian is not initialized.' # noqa
self._hessian.data.copy_(
value.data.to(self._hessian.device))
else:
self._hessian = None
else:
self._hessian.data.copy_(value.data.to(self._hessian.device))

@torch.no_grad()
def update_hessian(self, input: torch.Tensor):

input = self.format_input(input).float()
H_save = self.hessian
H_save = H_save.to(input.device)

assert len(input.shape) == 3
B = input.shape[0] # B N C
@@ -71,8 +90,8 @@ def update_hessian(self, input: torch.Tensor):
if dist.is_initialized():
dist.all_reduce(H)
B *= dist.get_world_size()
self.hessian = (self.hessian * self.hessian_batch + H) / (
self.hessian_batch + B)
H_save = (H_save * self.hessian_batch + H) / (self.hessian_batch + B)
self.hessian = H_save
self.hessian_batch = self.hessian_batch + B

def start_init_hessian(self):
@@ -89,8 +108,18 @@ def end_init_hessian(self):
for h in self.sparse_gpt_handles:
h.remove()

def keep_hessian_in_float(self):
self._hessian = self._hessian.float()
def init_hessian(self, device=None):
if dist.is_initialized():
if dist.get_rank() == 0:
self._hessian = torch.zeros([self.columns, self.columns],
device=device,
dtype=torch.float)
else:
self._hessian = None
else:
self._hessian = torch.zeros([self.columns, self.columns],
device=device,
dtype=torch.float)

# prune

@@ -102,7 +131,7 @@ def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01):
assert self.hessian is not None
W: torch.Tensor = self.weight_matrix.float() # out in

H = self.hessian.float()
H = self.hessian.float().to(W.device)

dead = torch.diag(H) == 0
H[dead, dead] = 1
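In ops.py the Hessian handling becomes rank-aware: init_hessian allocates a float32 [columns, columns] buffer on rank 0 only (other ranks keep _hessian = None), the hessian getter broadcasts rank 0's copy to every rank, the setter writes back on rank 0 only, and update_hessian all-reduces the per-batch statistic across ranks before folding it into a running average weighted by the number of sequences seen so far. This keeps a single authoritative copy of each Hessian in memory while every rank still sees a consistent value. A single-process sketch of that running average (shapes are made up; the per-batch formula 2 * X @ X.T follows the standard SparseGPT form, and the exact line sits in the code elided between the hunks above):

import torch

columns, seq_len = 16, 8
H_save = torch.zeros(columns, columns)      # plays the role of self._hessian
hessian_batch = 0                           # plays the role of self.hessian_batch

for _ in range(4):                          # pretend calibration batches
    x = torch.randn(2, seq_len, columns)    # B N C, as asserted in update_hessian
    B = x.shape[0]
    flat = x.transpose(0, -1).flatten(1)    # C, B*N
    H = flat @ flat.T * 2                   # per-batch Hessian estimate
    H_save = (H_save * hessian_batch + H) / (hessian_batch + B)
    hessian_batch += B

The next hunk updates the ResNet-18 example script to the new API.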
@@ -116,11 +116,15 @@ def infer(model: nn.Module,
num_samples = args.num_samples
batch_size = args.batch_size

model = torchvision.models.resnet18(pretrained=True).cuda()
model = torchvision.models.resnet18(pretrained=True)
train_loader, test_loader = get_dataloaders(batch_size, 4, data_path)

mutator = sparse_gpt.SparseGptMutator()
mutator.prepare_from_supernet(model)

model.cuda()

mutator.init_hessian()
mutator.start_init_hessian()
infer(model, test_loader, num_samples=num_samples)
mutator.end_init_hessian()
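The ResNet-18 example now builds the model on CPU, lets prepare_from_supernet replace the target layers with SparseGPT ops, and only afterwards moves the model to GPU and calls mutator.init_hessian() before the calibration pass. A hedged sketch of how the flow typically continues once calibration ends (the argument values mirror the op-level prune signature in ops.py and the 2:4 settings used by the FSDP script below; they are examples, not part of this hunk):

mutator.end_init_hessian()             # remove the calibration hooks
for op in mutator.sparse_ops:          # the same collection iterated in mutator.py above
    op.prune(0.5, prunen=2, prunem=4)  # 2:4 structured sparsity (50%)

The hunks that follow extend the datautils helpers, the module imported by the new FSDP script, with utilities for building a distributed language-model dataloader.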
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from torch.utils.data import DistributedSampler


def set_seed(seed):
@@ -106,3 +108,45 @@ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
return get_ptb(nsamples, seed, seqlen, model)
if 'c4' in name:
return get_c4(nsamples, seed, seqlen, model)


def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
# tokens: 1 N
N = tokens.shape[1]
num_drop = N % batch_seq_len
if num_drop != 0:
tokens = tokens[:, :-num_drop]
tokens = tokens.reshape([-1, batch_seq_len]) # B N
return tokens


class LanguageDataset(TorchDataset):

def __init__(self, seq: torch.Tensor, seq_len: int = 2048) -> None:
super().__init__()
# seq: 1, N
self.seq_len = seq_len

self.seq = fold_tokens(seq) # B N

def __len__(self) -> int:
return self.seq.shape[0]

def __getitem__(self, index):
return self.seq[index]


def build_language_loader(testloader, world_size, rank, model, batch_size=128):
val_dataset = LanguageDataset(testloader.input_ids, seq_len=model.seqlen)
distributed_sampler = DistributedSampler(
val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
batch_size = min(len(val_dataset) // world_size, batch_size)
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=0,
pin_memory=True,
drop_last=True,
sampler=distributed_sampler)
return val_dataloader
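fold_tokens drops the trailing remainder of a 1 x N token stream and reshapes it into fixed-length rows; LanguageDataset wraps the result (note that it calls fold_tokens(seq) with the default 2048-token length, so the seq_len it stores is not actually forwarded), and build_language_loader shards the dataset across ranks with a DistributedSampler, capping the batch size so every rank receives at least one batch. A small sketch of the folding behaviour, assuming datautils is importable (token values are made up):

import torch
from datautils import fold_tokens          # the helper added above

tokens = torch.arange(5000).unsqueeze(0)   # a fake 1 x 5000 token stream
folded = fold_tokens(tokens)               # default batch_seq_len=2048
print(folded.shape)                        # torch.Size([2, 2048]); the 904-token tail is dropped

The last file in the diff is a new 198-line script that prunes LLaMA under FSDP, shown next.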
@@ -0,0 +1,198 @@
import functools
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from datautils import build_language_loader, get_loaders
from llama_sparse_gpt import get_model
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from utils import init_on_meta, opt_eval_fsdp, opt_infer_fsdp

from mmrazor.implementations.pruning import sparse_gpt
from mmrazor.utils import print_log


def setup(rank, world_size):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12356'

dist.init_process_group('nccl', rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
print_log(f'init {rank}/{world_size}', only_rank0=False)


def init_fn_wrapper(model: nn.Module, model_copy: nn.Module):

def find_module_in_model_copy(module: nn.Module):
name2module = dict(model.named_modules())
module2name = dict([(v, k) for k, v in name2module.items()])

name = module2name[module]
return dict(model_copy.named_modules())[name]

def _materialize_meta_module(module: nn.Module, ):

def meta_to_empty(p: torch.Tensor):
if p.device == torch.device('meta'):
return p.new_empty(p.shape, device='cpu')
else:
return p

module._apply(meta_to_empty)
if dist.get_rank() == 0:
assert model_copy is not None
module_copy = find_module_in_model_copy(module)

name2p = dict(module_copy.named_parameters(remove_duplicate=False))
for n, p in module.named_parameters():
if '_flat_param' not in n:
n = n.replace('_fsdp_wrapped_module.', '')
try:
p.data.copy_(name2p[n])
except Exception:
pass
name2p = dict(module_copy.named_buffers(remove_duplicate=False))
for n, p in module.named_buffers():
if '_flat_param' not in n:
n = n.replace('_fsdp_wrapped_module.', '')
try:
p.data.copy_(name2p[n])
except Exception:
pass

return _materialize_meta_module


def main(rank, world_size=8, args=None):
setup(rank, world_size)

model_name = args.model
batch_size = args.batch_size

def build():
model = get_model(model_name)

# init mutator
mutator = sparse_gpt.SparseGptMutator()
mutator.prepare_from_supernet(model.model.layers)
return model, mutator

with init_on_meta(enable=True):
model, mutator = build()

if rank == 0:
model_copy, _ = build() # init on cpu
else:
model_copy = None

# init fsdp
size_based_auto_wrap_policy_x = functools.partial(
size_based_auto_wrap_policy, min_num_params=int(1e8))

model = FSDP(
model,
auto_wrap_policy=size_based_auto_wrap_policy_x,
cpu_offload=CPUOffload(True),
sharding_strategy=ShardingStrategy.FULL_SHARD,
device_id=rank,
param_init_fn=init_fn_wrapper(model, model_copy),
sync_module_states=True)
print_log(model)

# init hessian

mutator.init_hessian(device='cuda')
mutator.start_init_hessian()

_, testloader = get_loaders(
args.dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
testloader = build_language_loader(
testloader, world_size, rank, model, batch_size=batch_size)
opt_infer_fsdp(model, testloader)

mutator.end_init_hessian()

# prune
name2module = dict(model.named_modules())
module2name = {}
module2name = dict([(v, k) for k, v in name2module.items()])

with torch.no_grad():
for fsdp in FSDP.fsdp_modules(model):
fsdp._reset_lazy_init()
with FSDP.summon_full_params(fsdp, recurse=False):
fsdp_name = module2name[fsdp]
for name, op in fsdp.named_modules():
if name.count('_fsdp_wrapped_module') <= 1:
if isinstance(op, sparse_gpt.SparseGptMixIn):
try:
op.prune(0.5, prunen=2, prunem=4)
print_log(
f'prune {fsdp_name}.{name} successfully.', # noqa
only_rank0=True)
except Exception as e:
print_log(
f'prune {fsdp_name}.{name} failed, as {e}', # noqa
only_rank0=True)
fsdp._reset_lazy_init()

# save
if args.save:
print_log(f'save model in {args.save}')
model._reset_lazy_init()
with FSDP.summon_full_params(model, rank0_only=True, writeback=False):
if dist.get_rank() == 0:
model.save_pretrained(args.save)

# val
torch.cuda.empty_cache()
model._reset_lazy_init()
for dataset in ['wikitext2', 'ptb', 'c4']:
_, testloader = get_loaders(
dataset, seed=args.seed, model=model_name, seqlen=model.seqlen)
testloader = build_language_loader(
testloader, world_size, rank, model, batch_size=batch_size)
print_log(dataset)
opt_eval_fsdp(model, testloader, torch.device('cuda'))


if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()

parser.add_argument(
'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
parser.add_argument(
'dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
parser.add_argument(
'--seed',
type=int,
default=0,
help='Seed for sampling the calibration data.')
parser.add_argument(
'--nsamples',
type=int,
default=128,
help='Number of calibration data samples.')
parser.add_argument(
'--batch_size',
type=int,
default=64,
help='Batchsize for calibration and evaluation.')

parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'--world_size', type=int, default=1, help='Number of GPUs to use.')
args = parser.parse_args()

WORLD_SIZE = args.world_size
mp.spawn(main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
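The new script builds the LLaMA model on the meta device, wraps it with FSDP (full sharding, CPU offload, a param_init_fn that copies real weights from a rank-0 CPU copy, and sync_module_states=True), collects Hessians through one calibration pass with opt_infer_fsdp, then prunes each wrapped module to 2:4 sparsity under summon_full_params before optionally saving and evaluating on wikitext2, ptb and c4. It is launched directly, e.g. python <script>.py <llama-checkpoint> c4 --world_size 8 --batch_size 64 (names are placeholders; note the positional model argument's help string still says facebook/opt-X even though the script imports llama_sparse_gpt). A hypothetical way to spot-check the 2:4 pattern on a pruned linear weight:

import torch

def check_2_of_4(weight: torch.Tensor) -> bool:
    # prune(0.5, prunen=2, prunem=4) should leave at least 2 exact zeros in
    # every group of 4 consecutive input weights. Assumes in_features % 4 == 0.
    groups = weight.reshape(-1, 4)
    return bool(((groups == 0).sum(dim=1) >= 2).all())

# usage sketch, the exact attribute path depends on the checkpoint layout:
# check_2_of_4(model.model.layers[0].self_attn.q_proj.weight)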