Skip to content

Commit

Permalink
Check if method is supported
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobgil committed Apr 15, 2021
1 parent 4e71657 commit 04d367a
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ def get_args():
"""

args = get_args()
methods = \
{"gradcam": GradCAM,
"scorecam": ScoreCAM,
"gradcam++": GradCAMPlusPlus,
"ablationcam": AblationCAM,
"xgradcam": XGradCAM}

if args.method not in list(methods.keys()):
raise Exception(f"method should be one of {list(methods.keys())}")

model = models.resnet50(pretrained=True)

# Choose the target layer you want to compute the visualization for.
Expand All @@ -53,12 +63,6 @@ def get_args():
# You can print the model to help chose the layer
target_layer = model.layer4[-1]

methods = \
{"gradcam": GradCAM,
"scorecam": ScoreCAM,
"gradcam++": GradCAMPlusPlus,
"ablationcam": AblationCAM,
"xgradcam": XGradCAM}

if args.method not in methods:
raise Exception(f"Method {args.method} not implemented")
Expand Down

0 comments on commit 04d367a

Please sign in to comment.