-
Notifications
You must be signed in to change notification settings - Fork 1
/
svm.py
87 lines (65 loc) · 2.08 KB
/
svm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fake news detection
The SVM model
"""
#import all the required libraries
from getEmbeddings import getEmbeddings
import numpy as np
from sklearn.svm import LinearSVC
import matplotlib.pyplot as plt
import scikitplot as skplt
import os
import pickle
from sklearn import metrics
def plot_cmat(yte, ypred):
    """
    Draw the confusion matrix for a set of predictions and display it.

    yte   -- the true labels
    ypred -- the labels predicted by the classifier

    The matrix is rendered through scikit-plot and the resulting
    matplotlib figure is shown with plt.show().
    """
    skplt.metrics.plot_confusion_matrix(
        yte,
        ypred,
    )
    plt.show()
def svm_model():
    """
    Train a LinearSVC classifier on the embeddings and report its results.

    Steps performed:
      1. If any of the cached arrays (xtr/xte/ytr/yte .npy files in the
         current directory) is missing, call getEmbeddings on
         "datasets/train.csv" and save the four arrays it returns.
         More about getEmbeddings is in getEmbeddings.py.
      2. Load the four arrays from the .npy files (NumPy array files).
      3. Fit a LinearSVC on the training data (xtr, ytr).
      4. Pickle the fitted classifier to 'svm_model.sav' so the
         pre-trained model can be reused later.
      5. Predict on the test data (xte), print the accuracy as a
         percentage and draw the confusion matrix.

    Takes no arguments and returns nothing; all results are written to
    disk, printed, or plotted.
    """
    # Regenerate the cache when at least one of the four files is missing.
    cache_files = ('./xtr.npy', './xte.npy', './ytr.npy', './yte.npy')
    if not all(os.path.isfile(path) for path in cache_files):
        xtr, xte, ytr, yte = getEmbeddings("datasets/train.csv")
        np.save('./xtr', xtr)
        np.save('./xte', xte)
        np.save('./ytr', ytr)
        np.save('./yte', yte)

    # Always read back from disk so both code paths use the same data.
    xtr = np.load('./xtr.npy')
    xte = np.load('./xte.npy')
    ytr = np.load('./ytr.npy')
    yte = np.load('./yte.npy')

    # Create the LinearSVC classifier and fit it on the training split
    # (xtr = training features, ytr = training labels).
    clf = LinearSVC()
    clf.fit(xtr, ytr)

    # Save the fitted model so that the pre-trained classifier can be reused.
    # The context manager guarantees the file handle is flushed and closed.
    model_file = 'svm_model.sav'
    with open(model_file, 'wb') as model_fh:
        pickle.dump(clf, model_fh)

    # Predict the labels for the test split (xte).
    y_pred = clf.predict(xte)

    # Report the accuracy on the test split as a percentage.
    print("Accuracy = " + format(metrics.accuracy_score(yte,y_pred)*100, '.2f') + "%")

    # Draw the confusion matrix.
    plot_cmat(yte, y_pred)
# Script entry point: train and evaluate the SVM when run directly.
if "__main__" == __name__:
    svm_model()