In [5]:
import numpy as np
#import pandas as pd
#import bz2file as bz2
import os
from typing import Tuple, Optional


## READ gisette dataset

Each line looks like this:
-1 1:-1 2:-1 3:0.913914 4:-1 5:-1 6:0.4530 ...
The first number is either 1 or -1 (the label y),
and it is followed by 5000 pairs of the form integer_index:float_value. The floats are the x values (features).


In [6]:
from src.utils import read_gisette_data

In [8]:
# Value forwarded as ``max_lines`` to the reader for both files
# (presumably caps how many lines are parsed -- confirm in src.utils).
MAX_LINES = 5000

# Both archives sit in the sibling ``data`` directory, relative to the notebook.
file_path_train, file_path_test = (
    os.path.join("..", "data", archive_name)
    for archive_name in ("gisette_scale.bz2", "gisette_scale.t.bz2")
)

# The reader's result is unpacked as (labels, features), in that order.
y_train, X_train = read_gisette_data(file_path_train, max_lines=MAX_LINES)
y_test, X_test = read_gisette_data(file_path_test, max_lines=MAX_LINES)

## SVM problem definition

* The optimization problem we should solve is the following:
$$
\begin{equation}
\begin{aligned}
& \min \quad \frac{1}{2} \|\mathbf{w}\|^2 + C \sum_{i=1}^m \xi_i \\
& \text{subject to} \quad y_i (\mathbf{w} \cdot \mathbf{x}_i + b) \geq 1 - \xi_i, \quad \xi_i \geq 0, \quad i = 1, \dots, m
\end{aligned}
\end{equation}
$$

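Following Platt's Sequential Minimal Optimization (SMO) algorithm (J. C. Platt, "Sequential Minimal Optimization: A Fast Algorithm for Training Support Vector Machines", Microsoft Research Technical Report MSR-TR-98-14, 1998), it is preferable to solve the dual problem, which is the following: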

$$
\begin{equation}
\begin{aligned}
& \max_{\boldsymbol{\alpha}} \quad \sum_{i=1}^m \alpha_i - \frac{1}{2} \sum_{i=1}^m \sum_{j=1}^m \alpha_i \alpha_j y_i y_j \mathbf{x}_i \cdot \mathbf{x}_j \\
& \text{subject to} \quad \sum_{i=1}^m \alpha_i y_i = 0 \\
& \quad \quad \quad \quad 0 \leq \alpha_i \leq C, \quad i = 1, \dots, m
\end{aligned}
\end{equation}
$$
     

In [21]:
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, f1_score, precision_score

# Linear-kernel SVM; C is the slack penalty from the problem definition above.
svm_params = {'kernel': 'linear', 'C': 100}
svm = SVC(**svm_params)

# Fit on the training split. Kept as the cell's last expression so the
# fitted estimator is still displayed as the cell output.
svm.fit(X_train, y_train)


In [22]:
# Label predictions for the held-out test split.
y_pred = svm.predict(X_test)

# Score the predictions with each metric, in the same order they are reported.
accuracy, f1, precision = (
    scorer(y_test, y_pred)
    for scorer in (accuracy_score, f1_score, precision_score)
)

# One line per metric, formatted to two decimals.
print(
    f'Accuracy of SVM on test set: {accuracy:.2f}',
    f'F1 of SVM on test set: {f1:.2f}',
    f'Precision of SVM on test set: {precision:.2f}',
    sep='\n',
)



Accuracy of SVM on test set: 0.97
F1 of SVM on test set: 0.97
Precision of SVM on test set: 0.98
