Enhance sparsegpt #505

Merged · 6 commits · Apr 12, 2023
6 changes: 5 additions & 1 deletion mmrazor/implementations/pruning/sparse_gpt/mutator.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 import torch.nn as nn

 from mmrazor.utils import print_log
@@ -19,11 +20,14 @@ def end_init_hessian(self):
         for module in self.sparse_ops:
             module.end_init_hessian()

-    def prune_24(self):
+    def prune_24(self, device=torch.device('cuda:0')):
         for name, module in self.named_sparse_ops:
             try:
+                original_device = next(module.parameters()).device
+                module = module.to(device)
                 error = module.prune_24()
                 print_log(f'prune {name} success \t error = {error}')
+                module.to(original_device)
             except Exception as e:
                 print_log(f'prune {name} failed as {e}')

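Note: the new `device` argument lets `prune_24` move each sparse op onto the GPU only for its own pruning step and then send it back, so the rest of the model can stay off-device between steps. A minimal sketch of that shuttle pattern on plain modules (the `run_prune_step` callback is hypothetical, not part of this PR):

```python
import torch
import torch.nn as nn


def prune_one_by_one(named_modules, run_prune_step,
                     device=torch.device('cuda:0')):
    """Move each module to `device` only while it is being pruned."""
    for name, module in named_modules:
        original_device = next(module.parameters()).device
        module = module.to(device)      # prune on the fast device
        error = run_prune_step(module)  # hypothetical pruning callback
        module.to(original_device)      # move back, freeing device memory
        print(f'prune {name} success \t error = {error}')


# Toy usage with a no-op "pruning" step, run on CPU for illustration:
prune_one_by_one([('fc', nn.Linear(8, 8))], lambda m: 0.0,
                 device=torch.device('cpu'))
```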
12 changes: 11 additions & 1 deletion mmrazor/implementations/pruning/sparse_gpt/ops.py
@@ -190,9 +190,14 @@ def __init__(self, *args, **kwargs) -> None:
         self._sparse_gpt_mix_in_init()

     @classmethod
-    def convert_from(cls, module: nn.Linear):
+    def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
         new_module = super().convert_from(module)
         new_module.load_state_dict(module.state_dict(), strict=False)
+
+        device = next(module.parameters()).device
+        dtype = next(module.parameters()).dtype
+        new_module = new_module.to(device).to(dtype)
+
         return new_module


@@ -206,6 +211,11 @@ def __init__(self, *args, **kwargs) -> None:
     def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
         new_module = super().convert_from(module)
         new_module.load_state_dict(module.state_dict(), strict=False)
+
+        device = next(module.parameters()).device
+        dtype = next(module.parameters()).dtype
+        new_module = new_module.to(device).to(dtype)
+
         return new_module

     def format_input(self, input: torch.Tensor):
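Both `convert_from` overrides now copy the source module's device and dtype onto the freshly built dynamic op; without this, a converted fp16 GPU layer would come back as an fp32 CPU module. A minimal sketch of the same pattern, using two plain `nn.Linear` layers as stand-ins for the real dynamic ops:

```python
import torch
import torch.nn as nn


def convert_preserving_placement(module: nn.Module,
                                 new_module: nn.Module) -> nn.Module:
    # strict=False tolerates extra buffers the dynamic op may register.
    new_module.load_state_dict(module.state_dict(), strict=False)
    # Match the source module's device and dtype (e.g. cuda:0 / float16).
    param = next(module.parameters())
    return new_module.to(param.device).to(param.dtype)


src = nn.Linear(4, 4).to(torch.float64)  # stand-in for an fp16 CUDA layer
dst = convert_preserving_placement(src, nn.Linear(4, 4))
assert next(dst.parameters()).dtype == torch.float64
```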
55 changes: 55 additions & 0 deletions mmrazor/implementations/pruning/sparse_gpt/utils.py
@@ -6,6 +6,7 @@

 from mmrazor.models.architectures.dynamic_ops import DynamicMixin
 from mmrazor.models.utils import get_module_device
+from mmrazor.utils import print_log


 class ModuleProtocol(Protocol):
@@ -44,3 +45,57 @@ def replace_op(model: nn.Module, name: str, module: nn.Module):
         new_module = dynamicop_map[type(module)].convert_from(module).to(
             get_module_device(module))
         replace_op(model, name, new_module)
+
+
+def register_efficient_forward_hook(module: nn.Module,
+                                    device=torch.device('cuda:0')):
+
+    def forward_pre_hook(module: nn.Module, input):
+        module.to(device)
+
+    def forward_hook(module: nn.Module, input, output):
+        module.to('cpu')
+        torch.cuda.empty_cache()
+
+    h1 = module.register_forward_pre_hook(forward_pre_hook)
+    h2 = module.register_forward_hook(forward_hook)
+    return [h1, h2]
+
+
+def enable_efficient_forward(model: nn.Module,
+                             device=torch.device('cuda:0'),
+                             wrap_modules=[]):
+    handles = []
+    blocks = []
+    for name, module in model.named_children():
+        if type(module) in wrap_modules or len(module._parameters) != 0 or len(
+                module._buffers) != 0:
+            handles_ = register_efficient_forward_hook(module, device)
+            blocks_ = [name]
+        else:
+            handles_, blocks_ = enable_efficient_forward(
+                module, device, wrap_modules)
+        handles += handles_
+        blocks += blocks_
+    return handles, blocks
+
+
+class memory_efficient_forward:
+
+    def __init__(self,
+                 model: nn.Module,
+                 device=torch.device('cuda:0'),
+                 wrap_modules=[]) -> None:
+        self.model = model
+        self.device = device
+        self.wrap_modules = wrap_modules
+
+    def __enter__(self, ):
+        handles, blocks = enable_efficient_forward(self.model, self.device,
+                                                   self.wrap_modules)
+        print_log(f'enable memory efficient forward for {blocks}')
+        self.handlers = handles
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        for h in self.handlers:
+            h.remove()
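These utilities implement a simple offloading scheme: every parameterized child (or any module type listed in `wrap_modules`) gets a pre-forward hook that moves it to the GPU and a post-forward hook that moves it back to the CPU, so peak GPU memory stays near one block's weights. The `torch.cuda.empty_cache()` call after every block keeps usage low at the cost of some latency. A usage sketch with made-up sizes:

```python
import torch
import torch.nn as nn

from mmrazor.implementations.pruning.sparse_gpt.utils import \
    memory_efficient_forward

model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])  # on CPU

with memory_efficient_forward(model, device=torch.device('cuda:0')):
    # Each Linear is shuttled to cuda:0 just before its forward runs and
    # back to CPU right after, so only ~one layer resides on GPU at a time.
    out = model(torch.randn(2, 1024, device='cuda:0'))
```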
87 changes: 52 additions & 35 deletions projects/sparse_gpt/llm/opt_sparse.py
@@ -2,6 +2,10 @@
 import torch
 import torch.nn as nn
 from transformers import OPTForCausalLM
+from transformers.models.opt.modeling_opt import OPTDecoderLayer
+
+from mmrazor.implementations.pruning.sparse_gpt.utils import \
+    memory_efficient_forward

 has_wandb = False

@@ -24,35 +28,47 @@ def skip(*args, **kwargs):
     return model


+def fold_tokens(tokens: torch.Tensor, batch_seq_len=2048):
+    # tokens: 1 N
+    N = tokens.shape[1]
+    num_drop = N % batch_seq_len
+    if num_drop != 0:
+        tokens = tokens[:, :-num_drop]
+    tokens = tokens.reshape([-1, batch_seq_len])  # B N
+    return tokens
+
+
 @torch.no_grad()
 def opt_eval(model: OPTForCausalLM,
              testenc,
-             dev,
-             dataset: str,
+             dev=torch.device('cuda:0'),
+             batch_size=64,
              log_wandb: bool = False):
     print('Evaluating ...')

-    testenc: torch.Tensor = testenc.input_ids  # type: ignore
-    nsamples = testenc.numel() // model.seqlen
+    seqlen = model.seqlen
+
+    testenc: torch.Tensor = testenc.input_ids  # type: ignore # 1, N
+    testenc = fold_tokens(testenc, seqlen)  # B N

     use_cache = model.config.use_cache
     model.config.use_cache = False
     nlls = []

-    for i in range(nsamples):
-        batch = testenc[:, (i * model.seqlen):(i + 1) * model.seqlen].to(dev)
-        out = model(batch)[0]  # 1
+    for batch in torch.split(testenc, batch_size):
+        B = batch.shape[0]
+
+        batch = batch.to(dev)
+        out: torch.Tensor = model(batch)[0]  # 1

-        shift_logits = out[:, :-1, :].contiguous()  # 1 N C
-        shift_labels = batch[:, 1:]  # 1 N
+        shift_logits = out[:, :-1, :].contiguous().flatten(0, 1)  # (B N) C
+        shift_labels = batch[:, 1:].flatten()  # (B N)

         loss_fct = nn.CrossEntropyLoss()
-        loss = loss_fct(
-            shift_logits.view(-1, shift_logits.size(-1)),
-            shift_labels.view(-1))
-        neg_log_likelihood = loss.float() * model.seqlen
+        loss = loss_fct(shift_logits, shift_labels)
+        neg_log_likelihood = loss.float() * seqlen * B
         nlls.append(neg_log_likelihood)
-    ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
+    ppl = torch.exp(torch.stack(nlls).sum() / (testenc.numel()))
     print(f'Perplexity: {ppl.item():3f}')
     model.config.use_cache = use_cache

@@ -62,22 +78,21 @@ def opt_infer(
     model: OPTForCausalLM,
     testenc,
     dev,
-    num_samples=128,
+    batch_size=64,
 ):
     print('Infer ...')

-    testenc: torch.Tensor = testenc.input_ids  # type: ignore
-    nsamples = testenc.numel() // model.seqlen
+    seqlen = model.seqlen
+
+    testenc: torch.Tensor = testenc.input_ids  # type: ignore # 1, N
+    testenc = fold_tokens(testenc, seqlen)  # B N

     model.config.use_cache = False

-    for i in range(nsamples):
-        batch = testenc[:, (i * model.seqlen):(i + 1) * model.seqlen].to(dev)
+    for batch in torch.split(testenc, batch_size):
+        batch = batch.to(dev)
         _ = model(batch)[0]  # 1
-
-        if i > num_samples:
-            break


 if __name__ == '__main__':
     import argparse
@@ -107,23 +122,25 @@ def opt_infer(

     model = get_opt(args.model)
     model.eval()
-    model = model.cuda()
     print('load model over')
     DEV = torch.device('cuda:0')

     dataloader, testloader = get_loaders(
         'c4', seed=args.seed, model=args.model, seqlen=model.seqlen)

     from mmrazor.implementations.pruning import sparse_gpt
-    mutator = sparse_gpt.SparseGptMutator.init_from_a_model(model)
-
-    mutator.start_init_hessian()
-    opt_infer(model, testloader, DEV, num_samples=128)
-    mutator.end_init_hessian()
-    mutator.prune_24()
-
-    for dataset in ['wikitext2', 'ptb', 'c4']:
-        dataloader, testloader = get_loaders(
-            dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
-        print(dataset)
-        opt_eval(model, testloader, DEV, dataset)
+    mutator = sparse_gpt.SparseGptMutator.init_from_a_model(
+        model.model.decoder)
+
+    with memory_efficient_forward(model, wrap_modules=[OPTDecoderLayer]):
+
+        mutator.start_init_hessian()
+        opt_infer(model, testloader, DEV)
+        mutator.end_init_hessian()
+        mutator.prune_24()
+
+        for dataset in ['wikitext2', 'ptb', 'c4']:
+            dataloader, testloader = get_loaders(
+                dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
+            print(dataset)
+            opt_eval(model, testloader, DEV)
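`fold_tokens` turns the single long token stream into a `B x seqlen` batch so that `opt_eval` and `opt_infer` can process `batch_size` rows per forward pass instead of one. The perplexity bookkeeping stays equivalent to the old per-sample loop: each batch contributes `loss * seqlen * B` to the summed NLL, and the final division is by the total token count. A small self-contained check of that arithmetic on toy sizes (no OPT model involved):

```python
import torch


def fold_tokens(tokens: torch.Tensor, batch_seq_len=8):
    num_drop = tokens.shape[1] % batch_seq_len
    if num_drop != 0:
        tokens = tokens[:, :-num_drop]  # drop the ragged tail
    return tokens.reshape([-1, batch_seq_len])  # B x seqlen


stream = torch.arange(50).reshape(1, 50)  # 1 x N token stream
folded = fold_tokens(stream)              # 6 x 8; last 2 tokens dropped
assert folded.shape == (6, 8)

# Summing loss * seqlen * B over batches, then dividing by the total token
# count, recovers the mean per-token NLL (here a constant loss of 2.0).
nlls = [torch.tensor(2.0) * 8 * b.shape[0] for b in torch.split(folded, 4)]
ppl = torch.exp(torch.stack(nlls).sum() / folded.numel())
assert torch.isclose(ppl, torch.exp(torch.tensor(2.0)))
```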
32 changes: 0 additions & 32 deletions projects/sparse_gpt/pipe.py

This file was deleted.

40 changes: 30 additions & 10 deletions projects/sparse_gpt/torch_sample.py
@@ -6,9 +6,10 @@
 import torchvision
 import torchvision.datasets as datasets
 import torchvision.transforms as transforms
-from pipe import sparse_model
 from torch.utils.data import DataLoader

+from mmrazor.implementations.pruning import sparse_gpt
+

 def get_dataloaders(batch_size, n_workers, path=''):
     normalize = transforms.Normalize(
@@ -50,13 +51,14 @@ def get_dataloaders(batch_size, n_workers, path=''):
     return dataloader_train, dataloader_test


-def eval(model: nn.Module, dataloader_test: DataLoader):
+@torch.no_grad()
+def eval(model: nn.Module,
+         dataloader_test: DataLoader,
+         device=torch.device('cuda:0')):

     total = 0
     correct = 0

-    device = next(model.parameters()).device
-
     model.eval()
     with torch.no_grad():
         for x, y in dataloader_test:
@@ -72,15 +74,33 @@ def eval(model: nn.Module,
     return acc


-# sparse_model(model, train_loader, 512)
+@torch.no_grad()
+def infer(model: nn.Module,
+          dataloader: torch.utils.data.DataLoader,
+          num_batchs=256,
+          device=torch.device('cuda:0')):
+    model.eval()
+    with torch.no_grad():
+        accumulate_batch = 0
+        for x, _ in dataloader:
+            x = x.to(device)
+            model(x)
+            B = x.shape[0]
+            accumulate_batch += B
+            if accumulate_batch > num_batchs:
+                break


 if __name__ == '__main__':
+    # sparse_model(model, train_loader, 512)
-    model = torchvision.models.resnet18(pretrained=True)
-    train_loader, test_loader = get_dataloaders(128, 4, 'data/imagenet_torch')
-
-    model = model.cuda()
-    model = sparse_model(model, test_loader, num_batchs=512)
+    model = torchvision.models.resnet18(pretrained=True).cuda()
+    train_loader, test_loader = get_dataloaders(256, 4, 'data/imagenet_torch')
+
+    mutator = sparse_gpt.SparseGptMutator.init_from_a_model(model)
+    mutator.start_init_hessian()
+    infer(model, test_loader, num_batchs=512)
+    mutator.end_init_hessian()
+    mutator.prune_24()

     print('start evaluation')
     model = model.cuda()
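With `pipe.py` removed, the sample drives SparseGPT through the mutator API directly: register the ops, accumulate Hessian statistics over a few calibration batches, then apply 2:4 pruning. A condensed sketch of that call order on a toy CNN (random data stands in for ImageNet, and a CUDA device is assumed):

```python
import torch
import torch.nn as nn

from mmrazor.implementations.pruning import sparse_gpt

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1),
                      nn.Flatten(), nn.Linear(8, 10)).cuda()

mutator = sparse_gpt.SparseGptMutator.init_from_a_model(model)

mutator.start_init_hessian()       # hooks begin accumulating input stats
with torch.no_grad():
    for _ in range(4):             # calibration forward passes
        model(torch.randn(2, 3, 32, 32, device='cuda'))
mutator.end_init_hessian()
mutator.prune_24()                 # 2:4 structured sparsity per op
```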