Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add confusion matrix to allowed metrics in BaseDataset #1574

Open
wants to merge 4 commits into
base: 0.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions mmaction/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
from mmcv.utils import print_log
from torch.utils.data import Dataset

from ..core import (mean_average_precision, mean_class_accuracy,
mmit_mean_average_precision, top_k_accuracy)
from ..core import (confusion_matrix, mean_average_precision,
mean_class_accuracy, mmit_mean_average_precision,
top_k_accuracy)
from ..utils import visualize_confusion_matrix
from .pipelines import Compose


Expand Down Expand Up @@ -179,7 +181,7 @@ def evaluate(self,
metrics = metrics if isinstance(metrics, (list, tuple)) else [metrics]
allowed_metrics = [
'top_k_accuracy', 'mean_class_accuracy', 'mean_average_precision',
'mmit_mean_average_precision'
'mmit_mean_average_precision', 'confusion_matrix'
]

for metric in metrics:
Expand Down Expand Up @@ -240,6 +242,13 @@ def evaluate(self,
print_log(log_msg, logger=logger)
continue

if metric == 'confusion_matrix':
y_pred = [np.argmax(result) for result in results]
cm = confusion_matrix(y_pred, gt_labels, 'true')
print_log(cm, logger=logger)
visualize_confusion_matrix(cm)
continue

return eval_results

@staticmethod
Expand Down
5 changes: 3 additions & 2 deletions mmaction/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
from .collect_env import collect_env
from .gradcam_utils import GradCAM
from .logger import get_root_logger
from .misc import get_random_string, get_shm_dir, get_thread_id
from .misc import (get_random_string, get_shm_dir, get_thread_id,
visualize_confusion_matrix)
from .module_hooks import register_module_hooks
from .precise_bn import PreciseBNHook
from .setup_env import setup_multi_processes

__all__ = [
'get_root_logger', 'collect_env', 'get_random_string', 'get_thread_id',
'get_shm_dir', 'GradCAM', 'PreciseBNHook', 'register_module_hooks',
'setup_multi_processes'
'setup_multi_processes', 'visualize_confusion_matrix'
]
14 changes: 14 additions & 0 deletions mmaction/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,17 @@ def get_thread_id():
def get_shm_dir():
"""Get shm dir for temporary usage."""
return '/dev/shm'


def visualize_confusion_matrix(confusion_matrix):
"""Visualize a confusion matrix.

Args:
confusion_matrix (np.array): the confusion matrix
"""
import matplotlib.pyplot as plt
import seaborn as sns
plt.figure(figsize=(15, 10))
sns.set(font_scale=1.5)
sns.heatmap(confusion_matrix, annot=True, square=True, cbar=True)
plt.show()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to save it as an image file to a specified location.