In [None]:
from rcwc import *
import numpy as np
import cvxpy as cp
from matplotlib import pyplot as plt

In [None]:
def build_data(k):
  '''
  Given k, build the selective prediction instance corresponding to n=2^k.

  One row is created for every pair (w, t), where the window width is
  w = 2^i for i = 0..k-1 and the split position t runs over w..n-w.
  Rows are ordered by increasing width, then by increasing split position.

  Returns a tuple (A, b, weights, a_sp):
    A       -- (m, n) array; a row is 1 before position t and 0 from t onward.
    b       -- (m, n) array; a row is uniform (summing to 1) on positions [t, t+w).
    weights -- length-m array; the rows sharing one width split a total
               weight of 1/k equally among themselves.
    a_sp    -- (m, n) array; a row is uniform (summing to 1) on positions [t-w, t).
  '''
  n = 2**k
  # Number of (width, split) pairs: the sum over i of (n + 1 - 2^(i+1)).
  m = n*(k-2) + k + 2

  # List every (width, split) pair up front so the fill loop is a flat enumerate.
  windows = [(2**i, t) for i in range(k) for t in range(2**i, n - 2**i + 1)]

  A = np.ones((m,n))
  b = np.zeros((m,n))
  a_sp = np.zeros((m,n))
  weights = np.zeros(m)

  for row, (width, split) in enumerate(windows):
    A[row, split:] = 0
    b[row, split:split+width] = 1
    a_sp[row, split-width:split] = 1
    # There are n + 1 - 2*width rows of this width; together they weigh 1/k.
    weights[row] = 1 / (k*(n + 1 - 2*width))

  # Turn the indicator rows into uniform distributions over their windows.
  b = b / b.sum(axis=1, keepdims=True)
  a_sp = a_sp / a_sp.sum(axis=1, keepdims=True)

  return A, b, weights, a_sp

In [None]:
# Largest exponent to run: one instance is solved for each k = 1..max_k.
max_k = 6
# Per-k outputs, stored in order of increasing k (index k-1).
weights_rcwc, weights_sp = [], []
rcwc_errors, sp_errors = np.zeros(max_k), np.zeros(max_k)

for k in range(1, max_k+1):
  # Blank line followed by k, to separate the verbose solver output per instance.
  print(f'\n{k}')
  A, b, weights, a_sp = build_data(k)

  # Solve for the RCWC result, then score it and the a_sp baseline against b.
  a_rcwc = rcwc(A, b, weights=weights, is_verbose=True)
  weights_rcwc.append(a_rcwc)
  weights_sp.append(a_sp)
  rcwc_errors[k-1] = evaluate_weights_grothendieck(
    a_rcwc, b, weights=weights
  )
  sp_errors[k-1] = evaluate_weights_grothendieck(
    a_sp, b, weights=weights
  )

In [None]:
# Plotting results
from matplotlib import cm
import matplotlib
from matplotlib.ticker import ScalarFormatter
import seaborn as sns

In [None]:
# Error-vs-n comparison of the two methods on log-log axes.
# Note: the rcParams update is global and also affects any later figures.
matplotlib.rcParams.update({'font.size': 30})

# Instance sizes n = 2^k on the x-axis. The k = 1 entry is left out, hence
# the exponent range starting at 2 and the [1:] slices of the error arrays.
n_values = 2 ** np.arange(2, max_k+1)

fig, ax = plt.subplots(figsize=(8,7))
ax.plot(n_values, sp_errors[1:], color='blue', linewidth=4, marker='o', markersize=10, alpha=0.8)
ax.plot(n_values, rcwc_errors[1:], color='green', linewidth=4, marker='D', markersize=10, alpha=0.8)

ax.set_yscale('log')
# Matplotlib 3.3 renamed the `basex` keyword to `base` and 3.5 removed the old
# spelling, so try the current name first and fall back for older releases.
try:
  ax.set_xscale('log', base=2)
except (TypeError, ValueError):
  ax.set_xscale('log', basex=2)

# Plain decimal tick labels instead of the default power notation of log axes.
ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.2f'))
ax.set_yticks([0.30, 1.00, 3.00])
ax.set_ylim([0.15,3.5])
ax.xaxis.set_major_formatter(ScalarFormatter())
# Tick at every plotted n so the axis stays in sync with max_k.
ax.set_xticks(n_values)

ax.legend(('Selective Prediction', 'RCWC'), frameon=False, loc=(0,0))
ax.set_ylabel('Prediction Error (Log Scale)')
ax.set_xlabel('n (Log Scale)')
fig.tight_layout()