Skip to content

Commit

Permalink
small bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
joshuawe committed Nov 10, 2023
1 parent 68181b8 commit dcdf1a2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@ fmt: ## Format code using black & isort.

.PHONY: lint
lint: ## Run pep8, black, mypy linters.
@echo "Running linters ..."
@echo "--- Running flake8 ---"
$(ENV_PREFIX)flake8 plotsandgraphs/
@echo "--- Running black ---"
$(ENV_PREFIX)black -l 79 --check plotsandgraphs/
$(ENV_PREFIX)black -l 79 --check tests/
@echo "--- Running mypy ---"
$(ENV_PREFIX)mypy --ignore-missing-imports plotsandgraphs/

.PHONY: test
Expand Down
8 changes: 4 additions & 4 deletions plotsandgraphs/binary_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def plot_accuracy(y_true, y_pred, name='', save_fig_path=None) -> Figure:
path = Path(save_fig_path)
path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_fig_path, bbox_inches='tight')
return fig, accuracy
return fig

def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, save_fig_path=None) -> Figure:
import matplotlib.colors as colors
Expand Down Expand Up @@ -136,7 +136,7 @@ def plot_classification_report(y_test: np.ndarray,
mask[:,cols-1] = True

bounds = np.linspace(0, 1, 11)
cmap = plt.cm.get_cmap('YlOrRd', len(bounds)+1)
cmap = plt.cm.get_cmap('YlOrRd', len(bounds)+1) # type: ignore
norm = colors.BoundaryNorm(bounds, cmap.N) # type: ignore[attr-defined]

ax = sns.heatmap(df, mask=mask, annot=False, cmap=cmap, fmt='.3g',
Expand Down Expand Up @@ -428,8 +428,8 @@ def plot_pr_curve(

# Plot Precision-Recall curve
ax.plot(recall, precision, label=label, color=color)
ax.set_xlim([0.0, 1.01])
ax.set_ylim([-0.01, 1.01])
ax.set_xlim((0.0, 1.01))
ax.set_ylim((-0.01, 1.01))
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
if title is not None:
Expand Down

0 comments on commit dcdf1a2

Please sign in to comment.