# Multiclass SVM 구현

In [44]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

#IRIS 데이터 로드
iris =  sns.load_dataset('iris') 
X= iris.iloc[:,:4] #학습할데이터
y = iris.iloc[:,-1] #타겟
print(y)

0         setosa
1         setosa
2         setosa
3         setosa
4         setosa
         ...    
145    virginica
146    virginica
147    virginica
148    virginica
149    virginica
Name: species, Length: 150, dtype: object


In [45]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=48)

In [46]:
def standardization(train, test):
    scaler = StandardScaler()
    train = scaler.fit_transform(train)
    test = scaler.transform(test)
    return train, test

X_train, X_test = standardization(X_train, X_test)

In [47]:
X_test.shape

(30, 4)

In [96]:
class SVM_OVR:
    def __init__(self, num_classes, kernel, C, gamma):
      
        self.num_classes = num_classes
        self.clfs = [SVC(kernel = kernel, C = C, gamma = gamma) for _ in range(num_classes)]
        self.classes = None

         #clfs 리스트 생성 후 SVC를 필요한 갯수만큼 append해줍니다. OVR이기 때문에 class갯수만큼 해줍니다.
        
    def fit(self, X_train, y_train): #트레이닝하기 위한 메소드입니다.
      
        y_train = pd.get_dummies(y_train) #y_train을 원핫인코딩해줍니다.
        for i in range(self.num_classes): #모든 class에 대하여 svc객체의 fit메소드를 이용해서 학습을 돌려줍니다.
            self.clfs[i].fit(X_train,y_train.iloc[:,i]) 
        self.classes = y_train.columns

    
    def predict(self, X_test): #트레이닝 된 모델을 이용하여 테스트를 진행합니다.

        y_pred = pd.DataFrame([svm.predict(X_test) for svm in self.clfs]) #각 모델에 대해 예측을 수행 
        decisions = np.array([svm.decision_function(X_test) for svm in self.clfs]).T
        print(y_pred) #투표를 통해 class를 결정할 수 없을때 사용
        print(len(y_pred),type(y_pred))
        
        tmp= []
        #정상적으로 분류할경우 
        for i in range(len(y_pred.columns)):
            if sum(y_pred.iloc[:,i]) == 1:
                label = y_pred.iloc[:,i][y_pred.iloc[:,i] == 1].index[0]
                tmp.append(self.classes[label])

                print(tmp)

        #동점이 나올경우 decision 메소드를 사용해서 판별
            else:
                label = np.argmax(decisions[i])
                tmp.append(self.classes[label])
        
      
        return tmp

In [100]:
clf=SVM_OVR(num_classes=3,kernel='rbf',C=10,gamma=10)

In [101]:
clf.fit(X_train,y_train)

In [102]:
pred = clf.predict(X_test)

   0   1   2   3   4   5   6   7   8   9   ...  20  21  22  23  24  25  26  \
0   0   0   0   0   0   0   1   0   0   0  ...   1   0   0   0   0   1   1   
1   1   1   0   0   0   0   0   0   0   1  ...   0   0   0   1   0   0   0   
2   0   0   0   0   0   1   0   1   0   0  ...   0   1   0   0   1   0   0   

   27  28  29  
0   0   0   0  
1   0   0   1  
2   1   1   0  

[3 rows x 30 columns]
3 <class 'pandas.core.frame.DataFrame'>
['versicolor']
['versicolor', 'versicolor']
['versicolor', 'versicolor', 'versicolor', 'virginica', 'virginica', 'virginica']
['versicolor', 'versicolor', 'versicolor', 'virginica', 'virginica', 'virginica', 'setosa']
['versicolor', 'versicolor', 'versicolor', 'virginica', 'virginica', 'virginica', 'setosa', 'virginica']
['versicolor', 'versicolor', 'versicolor', 'virginica', 'virginica', 'virginica', 'setosa', 'virginica', 'setosa', 'versicolor']
['versicolor', 'versicolor', 'versicolor', 'virginica', 'virginica', 'virginica', 'setosa', 'virginica', 'se

In [103]:
for gt, pr in zip(y_test, pred):
    print('%s%20s'%(gt, pr))

versicolor          versicolor
versicolor          versicolor
virginica          versicolor
setosa           virginica
versicolor           virginica
virginica           virginica
setosa              setosa
virginica           virginica
setosa              setosa
versicolor          versicolor
virginica           virginica
setosa              setosa
setosa              setosa
virginica           virginica
versicolor          versicolor
versicolor          versicolor
setosa              setosa
versicolor           virginica
versicolor           virginica
virginica           virginica
setosa              setosa
virginica           virginica
versicolor          versicolor
versicolor          versicolor
virginica           virginica
setosa              setosa
setosa              setosa
virginica           virginica
virginica           virginica
versicolor          versicolor


In [104]:
accuracy_score(y_test, pred)

0.8333333333333334