In [None]:
from jaxl.constants import *
from jaxl.datasets import get_dataset
from jaxl.learning_utils import get_learner
from jaxl.plot_utils import set_size
from jaxl.utils import parse_dict

import copy
import cvxpy as cp
import jax
import jax.numpy as jnp
import jax.random as jrandom
import json
import math
import matplotlib.pyplot as plt
import numpy as np
import optax
import os

from orbax.checkpoint import PyTreeCheckpointer, CheckpointManager

# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8*" and 3.8 removed the
# old names, so prefer the new name and fall back to "seaborn" on older installs.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")

In [None]:
# Experiment config. The default below is a machine-specific absolute path; set the
# JAXL_CONFIG_PATH environment variable to point at the config on another machine
# instead of editing this cell.
config_path = os.environ.get(
    "JAXL_CONFIG_PATH",
    "/Users/chanb/research/personal/jaxl/configs/icl/gpt-linear_classification-query_pred_only.json",
)
seed = 0

with open(config_path, "r") as f:
    config_dict = json.load(f)
# Parsing does not need the file handle, so do it after the file is closed.
config = parse_dict(config_dict)

dataset = get_dataset(config.learner_config.dataset_config, seed=seed)

In [None]:
# Take one task from the dataset. It unpacks into four arrays; presumably the first pair
# is the in-context training data and the second pair the query (TODO confirm).
train_x, one_hot_y, test_x, test_y = dataset[10]

# One-hot -> class index, then relabel class 0 as -1. For a binary problem this yields
# labels in {-1, +1}, which the SVM formulations below assume.
train_y = np.argmax(one_hot_y, axis=-1)
train_y = np.where(train_y == 0, -1, train_y)
print(train_x.shape, train_y.shape)

# Scikit-Learn Solution

In [None]:
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC


def make_svm(inputs, outputs, reg_coef, max_iter=2000):
    """Fit a linear SVM with scikit-learn's ``LinearSVC`` (hinge loss, with intercept).

    Args:
        inputs: 2-D feature matrix, one row per sample.
        outputs: either one-hot labels (2-D; the class index is taken with ``argmax``
            over axis 1) or a 1-D vector of class labels.
        reg_coef: forwarded to ``LinearSVC`` as ``C``. This is the *inverse*
            regularization strength: larger values penalize margin violations more,
            pushing the solution toward the hard-margin SVM.
        max_iter: maximum number of liblinear iterations. Large ``reg_coef`` values
            may need more iterations to converge.

    Returns:
        The fitted single-step sklearn ``Pipeline``; the classifier itself is ``svm[0]``.

    Note:
        liblinear also regularizes the intercept, so the result can differ slightly
        from QP formulations that leave the bias unregularized.
    """
    labels = np.asarray(outputs)
    # One-hot (or score) matrices are reduced to class indices; 1-D labels pass through.
    if labels.ndim > 1:
        labels = np.argmax(labels, axis=1)
    svm = make_pipeline(
        LinearSVC(C=reg_coef, max_iter=max_iter, loss="hinge", fit_intercept=True),
    )
    svm.fit(inputs, labels)
    return svm


# Large C (= reg_coef) approximates the hard-margin SVM solved with cvxpy below.
svm = make_svm(inputs=train_x, outputs=one_hot_y, reg_coef=10000)

# SolveQP

$$
\begin{align*}
    \min_x & \frac{1}{2} x^\top P x + q^\top x \\
    \text{s.t. } & Gx \leq h\\
    & Ax = b\\
    & lb \leq x \leq ub
\end{align*}
$$

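Example usage with CVXPY (adapted from the CVXPY quadratic-program example). The cells below use the same $P, q, G, h, A, b$ notation. The box constraints $lb \leq x \leq ub$ are not needed here; if required, they can be folded into $G$ and $h$: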
```
# Import packages.
import cvxpy as cp
import numpy as np

# Generate a random non-trivial quadratic program.
m = 15
n = 10
p = 5
np.random.seed(1)
P = np.random.randn(n, n)
P = P.T @ P
q = np.random.randn(n)
G = np.random.randn(m, n)
h = G @ np.random.randn(n)
A = np.random.randn(p, n)
b = np.random.randn(p)

# Define and solve the CVXPY problem.
x = cp.Variable(n)
prob = cp.Problem(cp.Minimize((1/2)*cp.quad_form(x, P) + q.T @ x),
                 [G @ x <= h,
                  A @ x == b])
prob.solve()

# Print result.
print("\nThe optimal value is", prob.value)
print("A solution x is")
print(x.value)
print("A dual solution corresponding to the inequality constraints is")
print(prob.constraints[0].dual_value)
```

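# Primal SVM
$$
\begin{align*}
    \min_{w, b} & \frac{1}{2} w^\top w \\
    \text{s.t. } & (w^\top x_j + b) y_j \geq 1, \forall j \in [N]
\end{align*}
$$

In the QP notation above, the optimization variable is $[w; b]$ and:
- $P = \mathrm{diag}(1, \dots, 1, 0)$, so the bias is not regularized.
- $q = 0$.
- Each margin constraint becomes the row $G_j = -y_j [x_j^\top, 1]$ with $h_j = -1$.

Here $b$ is the SVM bias, not the equality-constraint vector $b$ of the QP. This hard-margin problem is infeasible when the data are not linearly separable.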

In [None]:
# Hard-margin primal SVM as a QP over the stacked variable [w; b].
num_samples, num_features = train_x.shape
num_params = num_features + 1  # weights plus one bias entry

# Append a constant-1 feature so the bias is the last entry of the variable.
padded_x = np.concatenate((train_x, np.ones((num_samples, 1))), axis=-1)

# 1/2 w^T w: identity on the weights, zero on the bias (the bias is not regularized).
P = np.eye(num_params)
P[-1] = 0
q = np.zeros(num_params)

# y_j (w^T x_j + b) >= 1 rewritten as G @ [w; b] <= h.
G = padded_x * -train_y[:, None]
h = -np.ones(num_samples)

primal_var = cp.Variable(num_params)
prob = cp.Problem(
    cp.Minimize((1 / 2) * cp.quad_form(primal_var, P) + q.T @ primal_var),
    [G @ primal_var <= h],
)
print(prob.solve())
if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
    # e.g. "infeasible" when the data are not linearly separable; params is then None.
    print(f"Warning: primal SVM solver status is {prob.status!r}")
params = primal_var.value

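The primal problem equals $\min_{w, b} \max_{\alpha \geq 0} L(w, b, \alpha)$. The dual variables $\alpha$ are the Lagrange multipliers of the margin constraints. For given parameters $(w, b)$, the inner problem over them is:
$$
\begin{align*}
    \max_{\alpha} & \; L(w, b, \alpha) = \frac{1}{2} \lVert w \rVert^2 - \sum_{i=1}^N \alpha_i \left( y_i (x_i \cdot w + b) - 1 \right) \\
    \text{s.t. } & \alpha_i \geq 0
\end{align*}
$$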

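# Dual SVM
$$
\begin{align*}
    \max_{\alpha} & \sum_{i}^N \alpha_i - \frac{1}{2} \sum_{i, j} \alpha_i \alpha_j y_i y_j (x_i \cdot x_j) \\
    \text{s.t. } & \sum_{i} \alpha_i y_i = 0\\
    & \alpha_i \geq 0
\end{align*}
$$

To put this in the QP form, negate the objective and minimize. This gives:
- $P_{ij} = y_i y_j \, (x_i \cdot x_j)$ and $q = -\mathbf{1}$.
- $A = y^\top$ and $b = 0$.
- $G = -I$ and $h = 0$.

The primal weights are recovered as $w = \sum_i \alpha_i y_i x_i$. For any support vector ($\alpha_s > 0$), the bias is $b = y_s - w \cdot x_s$.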

In [None]:
# Linear kernel (Gram matrix): K[i, j] = x_i . x_j
K = train_x @ train_x.T
# Dual Hessian P[i, j] = y_i y_j K[i, j]. It is PSD in exact arithmetic, but round-off can
# leave it slightly asymmetric/indefinite, which can make cvxpy's quad_form reject it.
# Symmetrize explicitly and declare it PSD with psd_wrap.
P = (train_y[:, None] @ train_y[None]) * K
P = 0.5 * (P + P.T)
q = -np.ones(len(train_x))
# Equality constraint: sum_i alpha_i y_i = 0.
A = train_y[None]
b = np.zeros(1)
# Non-negativity alpha_i >= 0, written as G @ alpha <= h.
G = -np.eye(len(train_x))
h = np.zeros(len(train_x))

dual_var = cp.Variable(len(train_x))
prob = cp.Problem(
    cp.Minimize((1 / 2) * cp.quad_form(dual_var, cp.psd_wrap(P)) + q.T @ dual_var),
    [G @ dual_var <= h, A @ dual_var == b],
)
print(prob.solve())
if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
    # The hard-margin dual is unbounded when the data are not linearly separable.
    print(f"Warning: dual SVM solver status is {prob.status!r}")
alphas = dual_var.value

In [None]:
# Compare the three solutions. The primal QP variable is [w; b]; LinearSVC stores w in
# coef_ (one row for a binary problem) and b in intercept_.
print("primal  [w, b]:", params)
print("sklearn w, b  :", svm[0].coef_.ravel(), svm[0].intercept_)

# The dual gives w = sum_i alpha_i y_i x_i directly. Recover b from support vectors, which
# satisfy y_s (w . x_s + b) = 1, i.e. b = y_s - w . x_s for y_s in {-1, +1}.
dual_w = np.sum((alphas * train_y)[:, None] * train_x, axis=0)
# Off-support multipliers come back only approximately zero; relative threshold is a heuristic.
dual_support = alphas > 1e-6 * np.max(alphas)
dual_b = (
    np.mean(train_y[dual_support] - train_x[dual_support] @ dual_w)
    if dual_support.any()
    else np.nan
)
print("dual    w, b  :", dual_w, dual_b)
print(alphas)

In [None]:
# Support vectors lie on or inside the margin: |f(x)| <= 1. LinearSVC is an approximate
# solver, so points on the margin rarely evaluate to exactly 1. Allow slack equal to the
# solver's stopping tolerance (heuristic choice).
margin_tol = svm[0].tol
decision_function = svm.decision_function(train_x)
support_vector_indices = np.where(np.abs(decision_function) <= 1 + margin_tol)[0]
print(support_vector_indices)

In [None]:
# Scatter the 2-D training inputs, coloured by their {-1, +1} label.
fig, ax = plt.subplots()
for label in [-1, 1]:
    idxes = train_y == label
    ax.scatter(train_x[idxes][:, 0], train_x[idxes][:, 1], label=f"y = {label}")
ax.legend()
# Fixed [-1, 1] window; NOTE(review): assumes the inputs lie in this range — confirm.
ax.set_xlim(-1, 1)
ax.set_ylim(-1, 1)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_title("Training inputs by label")
plt.show()