From ad04feaf1a8660aa94e22043868fc5723a24ea4b Mon Sep 17 00:00:00 2001
From: liukai
Date: Fri, 14 Apr 2023 15:18:31 +0800
Subject: [PATCH 1/2] update

---
 .../pruning/sparse_gpt/mutator.py             |  47 +++--
 .../implementations/pruning/sparse_gpt/ops.py | 160 +++++++++---------
 .../pruning/sparse_gpt/utils.py               |  29 +++-
 .../test_pruning/test_sparse_gpt/test_op.py   |  17 ++
 4 files changed, 159 insertions(+), 94 deletions(-)

diff --git a/mmrazor/implementations/pruning/sparse_gpt/mutator.py b/mmrazor/implementations/pruning/sparse_gpt/mutator.py
index bbff9894a..d15362725 100644
--- a/mmrazor/implementations/pruning/sparse_gpt/mutator.py
+++ b/mmrazor/implementations/pruning/sparse_gpt/mutator.py
@@ -7,10 +7,38 @@
 from .utils import replace_with_dynamic_ops
 
 
+def to_static_model(model: nn.Module):
+    from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
+                                                      load_fix_subnet)
+    fix_subnet = export_fix_subnet(model)[0]
+    load_fix_subnet(model, fix_subnet)
+    return model
+
+
 class SparseGptMutator():
 
-    def __init__(self, sparse_model: nn.Module) -> None:
-        self.model = sparse_model
+    # init
+
+    def __init__(self) -> None:
+        self.model: nn.Module = None
+
+    def prepare_from_supernet(self,
+                              model: nn.Module,
+                              prune_conv=True,
+                              prune_linear=True) -> None:
+        self.model = model
+        prune_modules: dict = {}
+        if prune_conv:
+            prune_modules[nn.Conv2d] = SparseGptConv2d
+        if prune_linear:
+            prune_modules[nn.Linear] = SparseGptLinear
+        replace_with_dynamic_ops(model, prune_modules)
+
+    @classmethod
+    def to_static_model(cls, model):
+        return to_static_model(model)
+
+    # hessian
 
     def start_init_hessian(self):
         for module in self.sparse_ops:
@@ -20,6 +48,8 @@ def end_init_hessian(self):
         for module in self.sparse_ops:
             module.end_init_hessian()
 
+    # prune
+
     def prune_24(self, device=torch.device('cuda:0')):
         for name, module in self.named_sparse_ops:
             try:
@@ -28,11 +58,15 @@ def prune_24(self, device=torch.device('cuda:0')):
                 error = module.prune_24()
                 print_log(f'prune {name} success \t error = {error}')
                 module.to(original_device)
+                torch.cuda.empty_cache()
             except Exception as e:
                 print_log(f'prune {name} failed as {e}')
 
+    # ops
+
     @property
     def sparse_ops(self):
+        assert self.model is not None
         for module in self.model.modules():
             if isinstance(module, SparseGptMixIn):
                 yield module
@@ -42,12 +76,3 @@ def named_sparse_ops(self):
         for name, module in self.model.named_modules():
             if isinstance(module, SparseGptMixIn):
                 yield name, module
-
-    @classmethod
-    def init_from_a_model(cls, model: nn.Module):
-        replace_with_dynamic_ops(model, {
-            nn.Linear: SparseGptLinear,
-            nn.Conv2d: SparseGptConv2d
-        })
-        mutator = cls(model)
-        return mutator
diff --git a/mmrazor/implementations/pruning/sparse_gpt/ops.py b/mmrazor/implementations/pruning/sparse_gpt/ops.py
index e39c1cc7b..14b397593 100644
--- a/mmrazor/implementations/pruning/sparse_gpt/ops.py
+++ b/mmrazor/implementations/pruning/sparse_gpt/ops.py
@@ -7,7 +7,7 @@
 from mmrazor.models.architectures.dynamic_ops import (DynamicConv2d,
                                                       DynamicLinear)
-from .utils import ModuleProtocol
+from .utils import ModuleProtocol, torch_setting
 
 
 class SparseGptMixIn(ModuleProtocol):
@@ -89,95 +89,97 @@ def end_init_hessian(self):
 
     @torch.no_grad()
     def prune_24(self):
-        # Converted from https://github.com/ist-daslab/sparsegpt
-        percdamp = 0.01
-        blocksize = 128
-        prunem = 4
-        prunen = 2
-        sparsity = 0.5
-
-        assert self.hessian is not None
-        W: torch.Tensor = self.weight_matrix.float()  # out in
-
-        H = self.hessian
-
-        dead = torch.diag(H) == 0
-        H[dead, dead] = 1
-        W[:, dead] = 0
-
-        Losses = torch.zeros(self.rows, device=W.device)
-
-        damp = percdamp * torch.mean(torch.diag(H))
-        diag = torch.arange(self.columns, device=W.device)
-        H[diag, diag] += damp
-        H = torch.linalg.cholesky(H)
-        H = torch.cholesky_inverse(H)
-        H = torch.linalg.cholesky(H, upper=True)
-        Hinv = H
-
-        mask = None
-
-        for i1 in range(0, self.columns, blocksize):
-            i2 = min(i1 + blocksize, self.columns)
-            count = i2 - i1
-
-            W1 = W[:, i1:i2].clone()
-            Q1 = torch.zeros_like(W1)
-            Err1 = torch.zeros_like(W1)
-            Losses1 = torch.zeros_like(W1)
-            Hinv1 = Hinv[i1:i2, i1:i2]
-
-            if prunen == 0:
-                if mask is not None:
-                    mask1 = mask[:, i1:i2]
+        with torch_setting(dtype=torch.float):
+            # Converted from https://github.com/ist-daslab/sparsegpt
+            percdamp = 0.01
+            blocksize = 128
+            prunem = 4
+            prunen = 2
+            sparsity = 0.5
+
+            assert self.hessian is not None
+            W: torch.Tensor = self.weight_matrix.float()  # out in
+
+            H = self.hessian
+
+            dead = torch.diag(H) == 0
+            H[dead, dead] = 1
+            W[:, dead] = 0
+
+            Losses = torch.zeros(self.rows, device=W.device)
+
+            damp = percdamp * torch.mean(torch.diag(H))
+            diag = torch.arange(self.columns, device=W.device)
+            H[diag, diag] += damp
+            H = torch.linalg.cholesky(H)
+            H = torch.cholesky_inverse(H)
+            H = torch.linalg.cholesky(H, upper=True)
+            Hinv = H
+
+            mask = None
+
+            for i1 in range(0, self.columns, blocksize):
+                i2 = min(i1 + blocksize, self.columns)
+                count = i2 - i1
+
+                W1 = W[:, i1:i2].clone()
+                Q1 = torch.zeros_like(W1)
+                Err1 = torch.zeros_like(W1)
+                Losses1 = torch.zeros_like(W1)
+                Hinv1 = Hinv[i1:i2, i1:i2]
+
+                if prunen == 0:
+                    if mask is not None:
+                        mask1 = mask[:, i1:i2]
+                    else:
+                        tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1)))**2
+                        thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() *
+                                                                  sparsity)]
+                        mask1 = tmp <= thresh
                 else:
-                    tmp = W1**2 / (torch.diag(Hinv1).reshape((1, -1)))**2
-                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() *
-                                                              sparsity)]
-                    mask1 = tmp <= thresh
-            else:
-                mask1 = torch.zeros_like(W1) == 1
+                    mask1 = torch.zeros_like(W1) == 1
 
-            for i in range(count):
-                w = W1[:, i]
-                d = Hinv1[i, i]
+                for i in range(count):
+                    w = W1[:, i]
+                    d = Hinv1[i, i]
 
-                if prunen != 0 and i % prunem == 0:
-                    tmp = W1[:, i:(i + prunem)]**2 / (torch.diag(Hinv1)[i:(
-                        i + prunem)].reshape((1, -1)))**2
-                    mask1.scatter_(
-                        1,
-                        i + torch.topk(tmp, prunen, dim=1, largest=False)[1],
-                        True)
+                    if prunen != 0 and i % prunem == 0:
+                        tmp = W1[:, i:(i + prunem)]**2 / (torch.diag(Hinv1)[i:(
+                            i + prunem)].reshape((1, -1)))**2
+                        mask1.scatter_(
+                            1, i +
+                            torch.topk(tmp, prunen, dim=1, largest=False)[1],
+                            True)
 
-            q = w.clone()
-            q[mask1[:, i]] = 0
+                    q = w.clone()
+                    q[mask1[:, i]] = 0
 
-            Q1[:, i] = q
-            Losses1[:, i] = (w - q)**2 / d**2
+                    Q1[:, i] = q
+                    Losses1[:, i] = (w - q)**2 / d**2
 
-            err1 = (w - q) / d
-            W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i,
-                                                        i:].unsqueeze(0))
-            Err1[:, i] = err1
+                    err1 = (w - q) / d
+                    W1[:,
+                       i:] -= err1.unsqueeze(1).matmul(Hinv1[i,
+                                                             i:].unsqueeze(0))
+                    Err1[:, i] = err1
 
-        W[:, i1:i2] = Q1
-        Losses += torch.sum(Losses1, 1) / 2
+                W[:, i1:i2] = Q1
+                Losses += torch.sum(Losses1, 1) / 2
 
-        W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+                W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
 
-        torch.cuda.synchronize()
-        from .sparse24_utils import is_weight_sparse_24
-        assert is_weight_sparse_24(
-            W, -1), f'Weight dose not satisfy 24 with shape {W.shape}'
-        error = torch.sum(Losses)
+            torch.cuda.synchronize()
+            from .sparse24_utils import is_weight_sparse_24
+            assert is_weight_sparse_24(
+                W, -1), f'Weight does not satisfy 2:4 with shape {W.shape}'
+            error = torch.sum(Losses)
 
-        if torch.isnan(error).any():
-            raise Exception('get nan error')
-        else:
-            self.weight_matrix = W.data
+            if torch.isnan(error).any():
+                raise Exception('got NaN error')
+            else:
+                self.weight_matrix = W.data
 
-        return error
+            return error
 
 
 # SparseGpt Ops for Linear and Conv2d
diff --git a/mmrazor/implementations/pruning/sparse_gpt/utils.py b/mmrazor/implementations/pruning/sparse_gpt/utils.py
index 59c1bbf8e..95c4e51b6 100644
--- a/mmrazor/implementations/pruning/sparse_gpt/utils.py
+++ b/mmrazor/implementations/pruning/sparse_gpt/utils.py
@@ -84,18 +84,39 @@ class memory_efficient_forward:
 
     def __init__(self,
                  model: nn.Module,
+                 enabled=True,
                  device=torch.device('cuda:0'),
                  wrap_modules=[]) -> None:
         self.model = model
         self.device = device
         self.wrap_modules = wrap_modules
+        self.enabled = enabled
+        self.handlers: list = []
+
+        if not enabled:
+            model.to(device)
 
     def __enter__(self, ):
-        handles, blocks = enable_efficient_forward(self.model, self.device,
-                                                   self.wrap_modules)
-        print_log(f'enable memory efficient forward for {blocks}')
-        self.handlers = handles
+        if self.enabled:
+            handles, blocks = enable_efficient_forward(
+                self.model, self.device, self.wrap_modules)
+            print_log(f'enable memory efficient forward for {blocks}')
+            self.handlers = handles
 
     def __exit__(self, exc_type, exc_value, exc_traceback):
         for h in self.handlers:
             h.remove()
+
+
+class torch_setting():
+
+    def __init__(self, dtype=None) -> None:
+        self.original_dtype = torch.get_default_dtype()
+        self.dtype = dtype
+
+    def __enter__(self):
+        if self.dtype is not None:
+            torch.set_default_dtype(self.dtype)
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        torch.set_default_dtype(self.original_dtype)
diff --git a/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py b/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py
index b8afd2c0a..092b0780c 100644
--- a/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py
+++ b/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py
@@ -49,3 +49,20 @@ def infer(model, dataset):
 
         print('norm:', linear(data_0).norm(2))
         print('distance:', get_loss(linear, sparse_linear, data_0))
+
+    @torch.no_grad()
+    def test_model(self):
+        import torchvision
+        model = torchvision.models.resnet18()
+
+        mutator = sparse_gpt.SparseGptMutator()
+        mutator.prepare_from_supernet(model)
+
+        x = torch.rand(10, 3, 224, 224)
+        mutator.start_init_hessian()
+        model(x)
+        mutator.end_init_hessian()
+        mutator.prune_24()
+
+        model = mutator.to_static_model(model)
+        assert type(model.conv1) is nn.Conv2d
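
[Usage note for PATCH 1/2] The diffs above move SparseGptMutator from wrapping a model at construction time (init_from_a_model) to an explicit prepare_from_supernet / to_static_model workflow. The following is a minimal end-to-end sketch, not part of the patch: it mirrors the added test_model test, and it assumes a CUDA device (required by prune_24's default device argument) and that sparse_gpt and memory_efficient_forward are importable from the paths shown; the enabled=False path exercises the new fallback that simply moves the whole model to the device up front, and the new self.handlers default keeps __exit__ safe when forwarding is not wrapped:

    import torch
    import torchvision

    from mmrazor.implementations.pruning import sparse_gpt
    from mmrazor.implementations.pruning.sparse_gpt.utils import \
        memory_efficient_forward

    model = torchvision.models.resnet18()

    # Replace nn.Conv2d / nn.Linear with their SparseGpt dynamic counterparts.
    mutator = sparse_gpt.SparseGptMutator()
    mutator.prepare_from_supernet(model)

    # Accumulate Hessian statistics from calibration data, then prune to 2:4.
    with torch.no_grad(), memory_efficient_forward(model, enabled=False):
        mutator.start_init_hessian()
        model(torch.rand(10, 3, 224, 224, device='cuda:0'))
        mutator.end_init_hessian()
        mutator.prune_24()

    # Export the pruned weights back into plain static modules.
    model = mutator.to_static_model(model)
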
From c8634831353b65ec4e5567f20903387c7887fd2f Mon Sep 17 00:00:00 2001
From: liukai
Date: Mon, 17 Apr 2023 10:51:50 +0800
Subject: [PATCH 2/2] update

---
 .../pruning/sparse_gpt/mutator.py             | 22 +++++++++++++++----
 .../implementations/pruning/sparse_gpt/ops.py | 12 ++++------
 .../test_pruning/test_sparse_gpt/test_op.py   |  2 +-
 3 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/mmrazor/implementations/pruning/sparse_gpt/mutator.py b/mmrazor/implementations/pruning/sparse_gpt/mutator.py
index d15362725..e790f9b1c 100644
--- a/mmrazor/implementations/pruning/sparse_gpt/mutator.py
+++ b/mmrazor/implementations/pruning/sparse_gpt/mutator.py
@@ -49,19 +49,33 @@ def end_init_hessian(self):
             module.end_init_hessian()
 
     # prune
-
-    def prune_24(self, device=torch.device('cuda:0')):
+    def prune(self,
+              sparsity,
+              prunen=0,
+              prunem=0,
+              blocksize=128,
+              percdamp=.01,
+              device=torch.device('cuda:0')):
         for name, module in self.named_sparse_ops:
             try:
                 original_device = next(module.parameters()).device
-                module = module.to(device)
-                error = module.prune_24()
+                module: SparseGptMixIn = module.to(device)
+                error = module.prune(
+                    sparsity=sparsity,
+                    prunen=prunen,
+                    prunem=prunem,
+                    blocksize=blocksize,
+                    percdamp=percdamp,
+                )
                 print_log(f'prune {name} success \t error = {error}')
                 module.to(original_device)
                 torch.cuda.empty_cache()
             except Exception as e:
                 print_log(f'prune {name} failed as {e}')
 
+    def prune_24(self, device=torch.device('cuda:0')):
+        self.prune(0.5, prunen=2, prunem=4, device=device)
+
     # ops
 
     @property
diff --git a/mmrazor/implementations/pruning/sparse_gpt/ops.py b/mmrazor/implementations/pruning/sparse_gpt/ops.py
index 14b397593..f689f3190 100644
--- a/mmrazor/implementations/pruning/sparse_gpt/ops.py
+++ b/mmrazor/implementations/pruning/sparse_gpt/ops.py
@@ -88,14 +88,9 @@ def end_init_hessian(self):
     # prune
 
     @torch.no_grad()
-    def prune_24(self):
+    def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01):
         with torch_setting(dtype=torch.float):
             # Converted from https://github.com/ist-daslab/sparsegpt
-            percdamp = 0.01
-            blocksize = 128
-            prunem = 4
-            prunen = 2
-            sparsity = 0.5
 
             assert self.hessian is not None
             W: torch.Tensor = self.weight_matrix.float()  # out in
@@ -170,8 +165,9 @@ def prune_24(self):
 
             torch.cuda.synchronize()
             from .sparse24_utils import is_weight_sparse_24
-            assert is_weight_sparse_24(
-                W, -1), f'Weight does not satisfy 2:4 with shape {W.shape}'
+            if prunen == 2 and prunem == 4:
+                assert is_weight_sparse_24(
+                    W, -1), f'Weight does not satisfy 2:4 with shape {W.shape}'
             error = torch.sum(Losses)
 
diff --git a/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py b/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py
index 092b0780c..3e9b45c36 100644
--- a/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py
+++ b/tests/test_impl/test_pruning/test_sparse_gpt/test_op.py
@@ -43,7 +43,7 @@ def infer(model, dataset):
 
         infer(sparse_linear, random_data)
         sparse_linear.end_init_hessian()
-        sparse_linear.prune_24()
+        sparse_linear.prune(sparsity=0.5, prunen=2, prunem=4)
 
         # compare
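
[Usage note for PATCH 2/2] This patch generalizes the hard-coded 2:4 kernel into prune(sparsity, prunen, prunem, blocksize, percdamp), keeping prune_24 as a thin wrapper around self.prune(0.5, prunen=2, prunem=4). The sketch below is illustrative only and shows a case the tests above do not cover, unstructured 50% sparsity: with prunen=prunem=0 the op takes the threshold branch in ops.py that drops the weights with the smallest W**2 / diag(Hinv)**2 saliency instead of building a 2:4 mask, and the guarded is_weight_sparse_24 assert is skipped. As in the previous note, the import path and the availability of a CUDA device are assumptions:

    import torch
    import torchvision

    from mmrazor.implementations.pruning import sparse_gpt

    model = torchvision.models.resnet18().cuda()

    mutator = sparse_gpt.SparseGptMutator()
    mutator.prepare_from_supernet(model)

    # Calibrate the Hessian estimates on a small batch.
    with torch.no_grad():
        mutator.start_init_hessian()
        model(torch.rand(10, 3, 224, 224, device='cuda:0'))
        mutator.end_init_hessian()

    # Unstructured 50% sparsity; prunen=0 / prunem=0 selects weights by the
    # saliency threshold rather than a 2:4 pattern.
    mutator.prune(
        sparsity=0.5,
        prunen=0,
        prunem=0,
        blocksize=128,
        percdamp=0.01,
        device=torch.device('cuda:0'))

    model = mutator.to_static_model(model)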