Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
56 lines (45 sloc) 1.76 KB
#!/usr/bin/env python3
import itertools
import logging
import os

import argh
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np

from binary_classification import SVM
from kernel import Kernel
def _numbered_filename(filename, index):
    """Return *filename* with *index* inserted before its extension.

    ``("svm.pdf", 1) -> "svm1.pdf"``; a name without an extension simply gets
    the index appended.
    """
    stem, ext = os.path.splitext(filename)
    return stem + str(index) + ext


def example(num_samples=100, num_features=2, grid_size=200, filename="svm.pdf"):
    """Fit an RBF-kernel and a linear-kernel SVM to random data and plot each.

    ``num_samples`` points with ``num_features`` standard-normal features are
    drawn.  A point is labelled +1.0 when its features sum to a positive value
    and -1.0 otherwise.  Each trained classifier's decision regions are saved
    by ``plot``:

    * the RBF model goes to ``filename`` with "1" inserted before the
      extension (``svm1.pdf`` with the default);
    * the linear model goes to ``filename`` with "2" inserted (``svm2.pdf``).

    Args:
        num_samples: number of random training points.
        num_features: feature count; must be 2 because the plot is 2-D.
        grid_size: resolution of the plotting grid along each axis.
        filename: base output path from which both plot filenames are derived.

    Raises:
        ValueError: if ``num_features`` is not 2.
    """
    if num_features != 2:
        # plot() reads columns 0 and 1 and queries the model with 2-feature
        # points, so any other dimensionality cannot be visualised.  Fail
        # before spending time on training.
        raise ValueError(
            "num_features must be 2 to plot a 2-D decision boundary, got %r"
            % (num_features,))

    # Same draw order as generating num_samples * num_features values and
    # reshaping them row-major.
    samples = np.random.normal(size=(num_samples, num_features))
    # bool -> {0, 1} -> {-1.0, +1.0}
    labels = 2 * (samples.sum(axis=1) > 0) - 1.0

    # NOTE(review): the 0.1 arguments are interpreted by kernel /
    # binary_classification (presumably the RBF width and the SVM
    # regularisation constant) -- confirm there.
    clf = SVM(Kernel.rbf(0.1), 0.1)
    clf.fit(samples, labels)
    plot(clf, samples, labels, grid_size, _numbered_filename(filename, 1))

    clf = SVM(Kernel.linear(), 0.1)
    clf.fit(samples, labels)
    plot(clf, samples, labels, grid_size, _numbered_filename(filename, 2))
def plot(predictor, X, y, grid_size, filename):
    """Shade *predictor*'s decision regions over the span of *X* and save them.

    A ``grid_size`` x ``grid_size`` lattice is built over the first two
    columns of ``X``, padded by 1 on every side.  Every lattice point is
    classified one at a time.  The prediction surface is then filled around
    the 0 level, the samples are scattered on top coloured by ``y``, and the
    figure is written to ``filename``.

    Args:
        predictor: object whose ``predict`` is called with a (1, 2) array.
        X: samples; only columns 0 and 1 are read.
        y: labels used to colour the scattered samples.
        grid_size: number of lattice points along each axis.
        filename: output path handed to ``plt.savefig``.
    """
    def as_vector(values):
        # Coerce to a flat 1-D ndarray for scatter().
        return np.asarray(values).reshape(-1)

    first, second = X[:, 0], X[:, 1]
    x_min, x_max = first.min() - 1, first.max() + 1
    y_min, y_max = second.min() - 1, second.max() + 1

    axis_x = np.linspace(x_min, x_max, grid_size)
    axis_y = np.linspace(y_min, y_max, grid_size)
    xx, yy = np.meshgrid(axis_x, axis_y, indexing='ij')

    # Zipping the C-order ravel() of both grids visits cells row by row, so
    # the C-order reshape below puts every prediction back on its own cell.
    predictions = [
        predictor.predict(np.array([gx, gy]).reshape(1, 2))
        for gx, gy in zip(xx.ravel(), yy.ravel())
    ]
    Z = np.array(predictions).reshape(xx.shape)

    plt.clf()
    # Levels hugging 0 with extend='both' shade the negative and positive
    # prediction regions separately.
    plt.contourf(xx, yy, Z, cmap=cm.Paired, levels=[-0.001, 0.001],
                 extend='both', alpha=0.8)
    plt.scatter(as_vector(first), as_vector(second),
                c=as_vector(y), cmap=cm.Paired)
    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.savefig(filename)
if __name__ == "__main__":
    # Only ERROR-and-above log records reach the root handler while the demo
    # runs.
    logging.basicConfig(level="ERROR")
    # argh builds the command-line interface from example()'s signature.
    argh.dispatch_command(example)