In [1]:
import numpy as np
import cvxpy as cp

In [84]:
# Data taken from https://github.com/cvxgrp/cvxbook_additional_exercises/blob/main/python/rob_logistic_reg_data.py

np.random.seed(0x364a_23F5)
d = 40
n = 60
# epsilon = 0.5
# epsilon = 0.75
epsilon = 1

true_theta = np.random.randn(d)
true_X = np.random.randn(n, d)
noise = 2 * epsilon * np.random.rand(n, d) - epsilon

X = true_X + noise
y = np.sign(true_X @ true_theta + 0.1 * np.random.rand(n) - 0.05)

true_X_test = np.random.randn(n, d)
noise = 2 * epsilon * np.random.rand(n, d) - epsilon

X_test = true_X_test + noise
y_test = np.sign(true_X_test @ true_theta + 0.1 * np.random.rand(n) - 0.05)

In [85]:
# train a typical logistic classifier
theta = cp.Variable(d)
lse = cp.logistic

loss = cp.sum( [ lse( -y[i] * theta@X[i, :] ) for i in range(n)] )
prob = cp.Problem(cp.Minimize(loss), [])
prob.solve(solver=cp.CLARABEL)

1.9880240659714824e-09

In [86]:
# Train Classification
num_correct = 0
for i in range(X.shape[0]):
    y_hat = cp.sign(theta.value @ X[i, :]).value
    if y_hat == y[i]:
        num_correct += 1
print(num_correct)

60


In [87]:
num_correct = 0
for i in range(X_test.shape[0]):
    y_hat = cp.sign(theta.value @ X_test[i, :]).value
    if y_hat == y_test[i]:
        num_correct += 1
print(num_correct)

39


In [94]:
# train a robust classifier
theta_robust = cp.Variable(d)
u = cp.Variable(n)

lse = cp.logistic

loss = cp.sum( [ lse( u[i] ) for i in range(n)] )

constrs = [ -y[i]*theta_robust@X[i, :] + cp.norm((-epsilon*y[i]*theta_robust), 'inf') <= u[i]
           for i in range(n)]

prob = cp.Problem(cp.Minimize(loss), constrs)
prob.solve(solver=cp.CLARABEL)

4.6770196119973843e-08

In [95]:
# Train Classification
num_correct = 0
for i in range(X.shape[0]):
    y_hat = cp.sign(theta_robust.value @ X[i, :]).value
    if y_hat == y[i]:
        num_correct += 1
print(num_correct)

60


In [96]:
num_correct = 0
for i in range(X_test.shape[0]):
    y_hat = cp.sign(theta_robust.value @ X_test[i, :]).value
    if y_hat == y_test[i]:
        num_correct += 1
print(num_correct)

43
