diff --git a/src/sparseml/pytorch/models/classification/inception_v3.py b/src/sparseml/pytorch/models/classification/inception_v3.py index 5b6b429a127..6b62b2b52a1 100644 --- a/src/sparseml/pytorch/models/classification/inception_v3.py +++ b/src/sparseml/pytorch/models/classification/inception_v3.py @@ -470,7 +470,7 @@ def forward(self, x_tens: Tensor) -> Tuple[Tensor, ...]: domain="cv", sub_domain="classification", architecture="inception_v3", - sub_architecture="none", + sub_architecture=None, default_dataset="imagenet", default_desc="base", def_ignore_error_tensors=[ diff --git a/src/sparseml/pytorch/models/classification/resnet.py b/src/sparseml/pytorch/models/classification/resnet.py index 5dbf1e30a21..85eb36235ae 100644 --- a/src/sparseml/pytorch/models/classification/resnet.py +++ b/src/sparseml/pytorch/models/classification/resnet.py @@ -825,7 +825,7 @@ def resnetv2_50(num_classes: int = 1000, class_type: str = "single") -> ResNet: domain="cv", sub_domain="classification", architecture="resnet_v1", - sub_architecture="50-2xwidth", + sub_architecture="50_2x", default_dataset="imagenet", default_desc="base", def_ignore_error_tensors=["classifier.fc.weight", "classifier.fc.bias"], @@ -1067,7 +1067,7 @@ def resnetv2_101(num_classes: int = 1000, class_type: str = "single") -> ResNet: domain="cv", sub_domain="classification", architecture="resnet_v1", - sub_architecture="101-2xwidth", + sub_architecture="101_2x", default_dataset="imagenet", default_desc="base", def_ignore_error_tensors=["classifier.fc.weight", "classifier.fc.bias"], diff --git a/src/sparseml/pytorch/models/classification/vgg.py b/src/sparseml/pytorch/models/classification/vgg.py index d71825d25ef..bc4fd269816 100644 --- a/src/sparseml/pytorch/models/classification/vgg.py +++ b/src/sparseml/pytorch/models/classification/vgg.py @@ -238,7 +238,7 @@ def vgg11(num_classes: int = 1000, class_type: str = "single") -> VGG: domain="cv", sub_domain="classification", architecture="vgg", - sub_architecture="11-bn", + sub_architecture="11_bn", default_dataset="imagenet", default_desc="base", def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"], @@ -324,7 +324,7 @@ def vgg13(num_classes: int = 1000, class_type: str = "single") -> VGG: domain="cv", sub_domain="classification", architecture="vgg", - sub_architecture="13-bn", + sub_architecture="13_bn", default_dataset="imagenet", default_desc="base", def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"], @@ -410,7 +410,7 @@ def vgg16(num_classes: int = 1000, class_type: str = "single") -> VGG: domain="cv", sub_domain="classification", architecture="vgg", - sub_architecture="16-bn", + sub_architecture="16_bn", default_dataset="imagenet", default_desc="base", def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"], @@ -496,7 +496,7 @@ def vgg19(num_classes: int = 1000, class_type: str = "single") -> VGG: domain="cv", sub_domain="classification", architecture="vgg", - sub_architecture="19-bn", + sub_architecture="19_bn", default_dataset="imagenet", default_desc="base", def_ignore_error_tensors=["classifier.mlp.6.weight", "classifier.mlp.6.bias"], diff --git a/src/sparseml/pytorch/models/detection/yolo_v3.py b/src/sparseml/pytorch/models/detection/yolo_v3.py index 0fcc144507e..81569248dfb 100644 --- a/src/sparseml/pytorch/models/detection/yolo_v3.py +++ b/src/sparseml/pytorch/models/detection/yolo_v3.py @@ -276,7 +276,7 @@ def forward(self, inp: Tensor): domain="cv", sub_domain="detection", architecture="yolo_v3", - sub_architecture="none", + sub_architecture="spp", default_dataset="coco", default_desc="base", ) diff --git a/src/sparseml/tensorflow_v1/optim/mask_pruning.py b/src/sparseml/tensorflow_v1/optim/mask_pruning.py index 058a788b042..287ca458ac8 100644 --- a/src/sparseml/tensorflow_v1/optim/mask_pruning.py +++ b/src/sparseml/tensorflow_v1/optim/mask_pruning.py @@ -20,8 +20,10 @@ from collections import namedtuple from typing import List, Tuple + try: import tensorflow.contrib.graph_editor as graph_editor + tf_contrib_err = None except Exception as err: graph_editor = None diff --git a/src/sparseml/tensorflow_v1/utils/variable.py b/src/sparseml/tensorflow_v1/utils/variable.py index 503e51823c4..aeda0a39201 100644 --- a/src/sparseml/tensorflow_v1/utils/variable.py +++ b/src/sparseml/tensorflow_v1/utils/variable.py @@ -17,6 +17,7 @@ import numpy + try: import tensorflow.contrib.graph_editor as graph_editor from tensorflow.contrib.graph_editor.util import ListView @@ -238,7 +239,9 @@ def get_ops_and_inputs_by_name_or_regex( nm_ks_consuming_ops_with_input = [ (consuming_op, inp) for output_tens in graph_editor.sgv(op).outputs - for consuming_op in graph_editor.get_consuming_ops(output_tens) + for consuming_op in graph_editor.get_consuming_ops( + output_tens + ) if "_nm_ks" not in consuming_op.name ] prunable_ops_and_inputs += nm_ks_consuming_ops_with_input