In [1]:
import numpy as np
from sklearn.datasets import load_wine

from sklearn.model_selection import train_test_split

from sklearn.tree import DecisionTreeClassifier 
from sklearn.ensemble import RandomForestClassifier 
from sklearn import svm
from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import LogisticRegression 

from sklearn.metrics import classification_report
from sklearn.metrics import recall_score

In [2]:
# Shared random seed: passed as random_state to the train/test split and to
# every classifier in this notebook so that results are repeatable.
seed = 7

In [3]:
# Load the wine dataset and pull out the feature matrix and the class labels.
wine = load_wine()
wine_data, wine_label = wine.data, wine.target

In [4]:
# Number of samples for each class label in the full dataset.
np.bincount(wine_label)

array([59, 71, 48])

In [5]:
# Hold out 20% of the samples as a test set; the split is seeded for repeatability.
# NOTE(review): no `stratify` argument is passed, so the class proportions of the
# test set are not forced to match those of the full dataset.
X_train, X_test, y_train, y_test = train_test_split(
    wine_data,
    wine_label,
    test_size=0.2,
    random_state=seed,
)

In [6]:
# Sanity check: shapes of the train/test feature matrices and label vectors.
X_train.shape,  y_train.shape, X_test.shape, y_test.shape

((142, 13), (142,), (36, 13), (36,))

In [7]:
# Per-class sample counts in the training labels and in the test labels.
np.bincount(y_train),np.bincount(y_test)

(array([52, 54, 36]), array([ 7, 17, 12]))

In [8]:
# Decision tree: build and fit in one step (fit() returns the estimator itself),
# then report per-class precision/recall/F1 on the held-out test set.
decision_tree = DecisionTreeClassifier(random_state=seed).fit(X_train, y_train)
decision_y = decision_tree.predict(X_test)
print(classification_report(y_true=y_test, y_pred=decision_y))

              precision    recall  f1-score   support

           0       0.88      1.00      0.93         7
           1       0.89      0.94      0.91        17
           2       1.00      0.83      0.91        12

    accuracy                           0.92        36
   macro avg       0.92      0.92      0.92        36
weighted avg       0.92      0.92      0.92        36



In [9]:
# Random forest: build and fit in one step (fit() returns the estimator itself),
# then report per-class precision/recall/F1 on the held-out test set.
random_forest = RandomForestClassifier(random_state=seed).fit(X_train, y_train)
random_y = random_forest.predict(X_test)
print(classification_report(y_true=y_test, y_pred=random_y))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00         7
           1       1.00      1.00      1.00        17
           2       1.00      1.00      1.00        12

    accuracy                           1.00        36
   macro avg       1.00      1.00      1.00        36
weighted avg       1.00      1.00      1.00        36



In [10]:
# Support-vector classifier built without an explicit kernel argument, i.e. with
# scikit-learn's default kernel; pass kernel='linear' to SVC to try a linear one.
svm_model = svm.SVC(random_state=seed)

svm_model.fit(X_train, y_train)
svm_y = svm_model.predict(X_test)
# Per-class precision/recall/F1 on the held-out test set.
print(classification_report(y_test, svm_y))

              precision    recall  f1-score   support

           0       0.86      0.86      0.86         7
           1       0.58      0.88      0.70        17
           2       0.33      0.08      0.13        12

    accuracy                           0.61        36
   macro avg       0.59      0.61      0.56        36
weighted avg       0.55      0.61      0.54        36



In [11]:
# Linear classifier trained with SGD: build and fit in one step (fit() returns
# the estimator itself), then report per-class metrics on the test set.
sgd_model = SGDClassifier(random_state=seed).fit(X_train, y_train)
sgd_y = sgd_model.predict(X_test)
print(classification_report(y_true=y_test, y_pred=sgd_y))

              precision    recall  f1-score   support

           0       0.86      0.86      0.86         7
           1       0.85      0.65      0.73        17
           2       0.62      0.83      0.71        12

    accuracy                           0.75        36
   macro avg       0.78      0.78      0.77        36
weighted avg       0.77      0.75      0.75        36



In [12]:
# Logistic regression with the iteration cap raised to 10000: build and fit in
# one step (fit() returns the estimator itself), then report test-set metrics.
logistic_model = LogisticRegression(
    random_state=seed,
    max_iter=10000,
).fit(X_train, y_train)
log_y = logistic_model.predict(X_test)
print(classification_report(y_true=y_test, y_pred=log_y))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00         7
           1       0.94      1.00      0.97        17
           2       1.00      0.92      0.96        12

    accuracy                           0.97        36
   macro avg       0.98      0.97      0.98        36
weighted avg       0.97      0.97      0.97        36



3개 클래스의 비율이 균등하지 않으므로 accuracy로는 모델을 비교하지 않는다.

하나의 클래스를 찾는 게 목표가 아니기 때문에 F1 score로 비교한다.

Random Forest가 macro 평균과 weighted 평균 모두 1.00으로 가장 높은 F1 score를 달성하였다.