Skip to content

Commit 8691f55

Browse files
author
User
committed
Merge branch 'master' of github.com:lazyprogrammer/machine_learning_examples
2 parents fa46455 + 05e0c52 commit 8691f55

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

ann_logistic_extra/ann_train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def cross_entropy(T, pY):
7777
print("Final train classification_rate:", classification_rate(Ytrain, predict(pYtrain)))
7878
print("Final test classification_rate:", classification_rate(Ytest, predict(pYtest)))
7979

80-
legend1, = plt.plot(train_costs, label='train cost')
81-
legend2, = plt.plot(test_costs, label='test cost')
82-
plt.legend([legend1, legend2])
80+
plt.plot(train_costs, label='train cost')
81+
plt.plot(test_costs, label='test cost')
82+
plt.legend()
8383
plt.show()

ann_logistic_extra/logistic_softmax_train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def cross_entropy(T, pY):
7070
print("Final train classification_rate:", classification_rate(Ytrain, predict(pYtrain)))
7171
print("Final test classification_rate:", classification_rate(Ytest, predict(pYtest)))
7272

73-
legend1, = plt.plot(train_costs, label='train cost')
74-
legend2, = plt.plot(test_costs, label='test cost')
75-
plt.legend([legend1, legend2])
73+
plt.plot(train_costs, label='train cost')
74+
plt.plot(test_costs, label='test cost')
75+
plt.legend()
7676
plt.show()

ann_logistic_extra/logistic_train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ def cross_entropy(T, pY):
5555
print("Final train classification_rate:", classification_rate(Ytrain, np.round(pYtrain)))
5656
print("Final test classification_rate:", classification_rate(Ytest, np.round(pYtest)))
5757

58-
legend1, = plt.plot(train_costs, label='train cost')
59-
legend2, = plt.plot(test_costs, label='test cost')
60-
plt.legend([legend1, legend2])
58+
plt.plot(train_costs, label='train cost')
59+
plt.plot(test_costs, label='test cost')
60+
plt.legend()
6161
plt.show()
6262

6363

0 commit comments

Comments
 (0)