# SVM

Here I use the off-the-shelf SVM implementation (`SVC`) from [scikit-learn](http://sklearn.lzjqsdd.com/modules/classes.html#module-sklearn.svm).

In [5]:
from sklearn.svm import SVC


In [12]:
class SVM:
    """Thin wrapper around scikit-learn's ``SVC`` that accepts either data layout.

    Feature matrices may be passed as (num, dim) -- samples in rows, the
    layout ``SVC`` expects -- or as (dim, num). ``fit`` and ``get_acc`` use
    the number of labels, and ``predict``/``get_score`` use the feature
    count recorded by ``fit``, to decide whether X must be transposed.

    Args:
        params: Optional dict of ``SVC`` constructor arguments. ``kernel``
            defaults to ``'rbf'``, ``gamma`` to ``'scale'`` and ``verbose``
            to ``False``. Any other key that is a valid ``SVC`` parameter
            (e.g. ``C``) is forwarded too; unknown keys are ignored.
    """

    def __init__(self, params=None):
        # None instead of a shared mutable ``{}`` default argument.
        params = {} if params is None else params
        # Feature count seen during ``fit``; None until the model is trained.
        self.dim = None
        # Forward every recognised SVC argument, silently dropping the rest
        # (the original wrapper ignored everything but the three below).
        valid = SVC().get_params()
        kwargs = {k: v for k, v in params.items() if k in valid}
        kwargs['kernel'] = self.set_params('rbf', 'kernel', params)
        kwargs['gamma'] = self.set_params('scale', 'gamma', params)
        kwargs['verbose'] = self.set_params(False, 'verbose', params)
        self.clf = SVC(**kwargs)

    def set_params(self, default, label, params):
        """Return ``params[label]`` if present, otherwise ``default``.

        Despite the name, this only looks a value up; it does not modify the
        wrapped estimator (unlike scikit-learn's own ``set_params``).
        """
        return params.get(label, default)

    @staticmethod
    def _orient_with_labels(X, y):
        """Return X as (num, dim), treating ``len(y)`` as the sample count.

        X is transposed only when its rows do not already match the labels
        but its columns do, so a square sample-major X is left untouched.
        """
        n = len(y)
        if X.shape[0] != n and X.shape[1] == n:
            return X.transpose()
        return X

    def _orient_with_dim(self, X):
        """Return X as (num, dim), using the feature count recorded by ``fit``.

        Before ``fit`` (``self.dim is None``) X is passed through unchanged,
        so the wrapped estimator raises its own not-fitted error.
        """
        if self.dim is not None and X.shape[1] != self.dim and X.shape[0] == self.dim:
            return X.transpose()
        return X

    def fit(self, X, y):
        """Train the classifier on X (either layout) with labels y."""
        X = self._orient_with_labels(X, y)
        self.dim = X.shape[1]
        self.clf.fit(X, y)

    def get_acc(self, X, y):
        """Return the wrapped classifier's ``score`` (mean accuracy) on (X, y)."""
        X = self._orient_with_labels(X, y)
        return self.clf.score(X, y)

    def predict(self, X):
        """Return predicted labels for X (either layout)."""
        return self.clf.predict(self._orient_with_dim(X))

    def get_score(self, X):
        """Return ``decision_function`` values for X (either layout)."""
        return self.clf.decision_function(self._orient_with_dim(X))

    def get_SV(self):
        """Return the indices of the support vectors in the training data."""
        return self.clf.support_

## Testing

In [13]:
from sklearn.datasets import make_classification

In [14]:
# Build a synthetic classification dataset with 4 features, then relabel
# class 0 as -1 so the labels use the +/-1 convention shown in the output.
# NOTE(review): no random_state is passed, so the data (and every recorded
# output below) presumably changes between runs -- confirm if reproducibility matters.
X,y = make_classification(n_features=4)
y[y==0] = -1
# Bare expression: the notebook echoes the relabelled targets.
y

array([-1,  1,  1, -1, -1,  1,  1,  1, -1,  1,  1, -1,  1, -1, -1, -1, -1,
        1, -1,  1, -1,  1,  1, -1,  1, -1,  1,  1,  1, -1, -1, -1, -1,  1,
       -1,  1,  1,  1, -1,  1,  1,  1, -1,  1, -1,  1,  1,  1, -1,  1,  1,
       -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1,  1, -1,  1, -1, -1, -1,
       -1,  1,  1,  1, -1, -1,  1, -1, -1,  1, -1,  1,  1,  1,  1, -1,  1,
        1, -1,  1,  1,  1, -1,  1, -1, -1, -1,  1,  1, -1, -1,  1])

In [15]:
# Wrapper with default settings (RBF kernel, gamma='scale').
M = SVM()

M.fit(X,y)

# Accuracy on the same data the model was trained on (training accuracy).
M.get_acc(X,y)

0.95

In [16]:
# Predicted labels for the first 10 samples.
M.predict(X[:10,:])

array([-1,  1,  1, -1, -1,  1,  1,  1, -1,  1])

In [17]:
# Decision-function values for the same 10 samples; in the recorded output
# their signs line up with the predicted labels above.
M.get_score(X[:10,:])

array([-1.45771553,  1.57786638,  1.43511928, -0.47695879, -0.63750515,
        0.18763603,  0.8008186 ,  1.59067511, -0.84210904,  1.46913523])

In [18]:
# Indices (into the training set) of the samples chosen as support vectors.
M.get_SV()

array([ 3,  4,  8, 14, 16, 20, 51, 52, 54, 55, 56, 63, 68, 73, 75, 78, 83,
       92, 98,  5,  6, 10, 12, 21, 22, 40, 47, 50, 62, 64, 69, 70, 77, 80,
       87, 91, 95, 96, 99], dtype=int32)