In [1]:
from optimization import optimize
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from oracle import Oracle, make_oracle
import seaborn as sns
from typing import List
from tabulate import tabulate

In [2]:
def linesearch_params(ls):
    """Return the ``(c1, c2)`` constants that ``run`` passes to ``optimize`` for line search ``ls``.

    ``"armijo"`` -> (0.25, None), ``"wolfe"`` -> (1e-4, 0.9), ``"nesterov"`` -> (2, 2);
    any other name (e.g. golden_section, brent, dbrent) gets the default (0.25, None).
    """
    constants = {
        "armijo": (0.25, None),
        "wolfe": (1e-4, 0.9),
        "nesterov": (2, 2),
    }
    return constants.get(ls, (0.25, None))


def run(w, oracle, optimization_methods, linesearch_methods, title=None, max_iter=None):
    """Run every (optimization method, line search) pair from ``w`` and print a summary table.

    Parameters
    ----------
    w : starting point handed to ``optimize``; each run receives its own copy.
    oracle : oracle object passed through to ``optimize``.
    optimization_methods : iterable of method names understood by ``optimize``.
    linesearch_methods : iterable of line-search names; their c1/c2 constants
        come from ``linesearch_params``.
    title : optional heading printed above the table (omitted when None).
    max_iter : iteration cap for every run. When None, the notebook-level
        ``max_iter`` is used (100 if that name is not defined).

    Prints a GitHub-style table built from the concatenated ``log.best`` results
    (presumably one row per run — confirm against ``optimize``); returns None.
    """
    if max_iter is None:
        # Keep honouring the config cell's `max_iter`, but make the dependency explicit.
        max_iter = globals().get("max_iter", 100)

    best_rows = []
    for opt in optimization_methods:
        for ls in linesearch_methods:
            c1, c2 = linesearch_params(ls)
            # Defensive copy: whether `optimize` updates its argument in place is not
            # visible here, and every run must start from the same point.
            _, _, log = optimize(
                np.copy(w), oracle, opt, ls,
                max_iter=max_iter, output_log=True, c1=c1, c2=c2,
            )
            best_rows.append(log.best)

    if title is not None:
        print("-", title, "-")
    if not best_rows:
        # pd.concat([]) raises "No objects to concatenate"; report the real cause instead.
        print("(no runs: optimization_methods or linesearch_methods is empty)")
        return

    df = pd.concat(best_rows, ignore_index=True)
    table = [df.columns.values.tolist()] + df.values.tolist()
    # One float format per column of `log.best`; any trailing column not covered
    # here (e.g. `rk`) falls back to tabulate's default float format.
    floatfmt = ["", "", ".0e", ".2e", ".2e", "", ".2e", "", ".4f", ".1e"]
    print(tabulate(table, headers="firstrow", tablefmt="github", floatfmt=floatfmt))

In [3]:
# Experiment configuration shared by the cells below.
max_iter = 100  # per-run iteration cap; `run` forwards it to `optimize`

optimization_methods = [
    "newton",
    "conjugate_gradient",
]

linesearch_methods = [
    "golden_section",
    "brent",
    "dbrent",
    "armijo",
    "wolfe",
]

# Oracle built by `make_oracle` from the a1a.libsvm file; `a1a.m` sizes the starting points.
a1a = make_oracle("a1a.libsvm")

In [4]:
# Deterministic baseline: every (method, line search) pair starts from the all-ones column vector.
w1 = np.ones(shape=(a1a.m, 1), dtype=float)
run(w1, a1a, optimization_methods, linesearch_methods, title="ones")

- ones -
| OptMethod          | LineSearch     |   tol |       c1 |       c2 | hf_criterion   |   entropy |   num_iter |   oracle_calls |    time |          rk |
|--------------------|----------------|-------|----------|----------|----------------|-----------|------------|----------------|---------|-------------|
| newton             | golden_section | 1e-08 | 2.50e-01 |          | num2006        |  2.98e-01 |         19 |            108 | 1.1e-01 | 6.44242e-09 |
| newton             | brent          | 1e-08 | 2.50e-01 |          | num2006        |  2.98e-01 |         19 |             80 | 1.0e-01 | 6.44242e-09 |
| newton             | dbrent         | 1e-08 | 2.50e-01 |          | num2006        |  2.98e-01 |         19 |             75 | 1.1e-01 | 6.44242e-09 |
| newton             | armijo         | 1e-08 | 2.50e-01 |          | num2006        |  2.98e-01 |         18 |             67 | 9.6e-02 | 9.75609e-09 |
| newton             | wolfe          | 1e-08 | 1.00e-04 | 9.00e-01 | num

In [5]:
# Random starting points drawn uniformly from [-1, 1) per coordinate.
# A seeded, cell-local generator keeps the trials reproducible under Restart & Run All,
# and the trial index in each title keeps the printed tables distinguishable.
n_uniform_trials = 10
rng_uniform = np.random.default_rng(0)
for trial in range(n_uniform_trials):
    w_uni = rng_uniform.uniform(-1, 1, size=(a1a.m, 1))
    run(w_uni, a1a, optimization_methods, linesearch_methods, f"uniform #{trial}")


- uniform -
| OptMethod          | LineSearch     |   tol |       c1 |       c2 | hf_criterion   |   entropy |   num_iter |   oracle_calls |    time |          rk |
|--------------------|----------------|-------|----------|----------|----------------|-----------|------------|----------------|---------|-------------|
| newton             | golden_section | 1e-08 | 2.50e-01 |          | num2006        |  2.98e-01 |          9 |             70 | 4.9e-02 | 1.70472e-09 |
| newton             | brent          | 1e-08 | 2.50e-01 |          | num2006        |  2.98e-01 |          9 |             42 | 4.6e-02 | 1.70472e-09 |
| newton             | dbrent         | 1e-08 | 2.50e-01 |          | num2006        |  2.98e-01 |          9 |             38 | 5.0e-02 | 1.70472e-09 |
| newton             | armijo         | 1e-08 | 2.50e-01 |          | num2006        |  2.98e-01 |          9 |             31 | 4.4e-02 | 1.46745e-09 |
| newton             | wolfe          | 1e-08 | 1.00e-04 | 9.00e-01 | 

In [6]:
# Random starting points drawn from a standard normal distribution (mean 0, std 1).
# A seeded, cell-local generator keeps the trials reproducible under Restart & Run All,
# and the trial index in each title keeps the printed tables distinguishable.
n_gauss_trials = 10
rng_gauss = np.random.default_rng(1)
for trial in range(n_gauss_trials):
    w_gauss = rng_gauss.normal(size=(a1a.m, 1))
    run(w_gauss, a1a, optimization_methods, linesearch_methods, f"gauss #{trial}")


- gauss -
| OptMethod          | LineSearch     |   tol |       c1 |       c2 | hf_criterion   |   entropy |   num_iter |   oracle_calls |    time |          rk |
|--------------------|----------------|-------|----------|----------|----------------|-----------|------------|----------------|---------|-------------|
| newton             | golden_section | 1e-08 | 2.50e-01 |          | num2006        |  2.98e-01 |         10 |             75 | 5.6e-02 | 4.64796e-09 |
| newton             | brent          | 1e-08 | 2.50e-01 |          | num2006        |  2.98e-01 |         10 |             45 | 5.1e-02 | 4.64796e-09 |
| newton             | dbrent         | 1e-08 | 2.50e-01 |          | num2006        |  2.98e-01 |         10 |             40 | 5.0e-02 | 4.64796e-09 |
| newton             | armijo         | 1e-08 | 2.50e-01 |          | num2006        |  2.98e-01 |         10 |             34 | 4.7e-02 | 4.65498e-09 |
| newton             | wolfe          | 1e-08 | 1.00e-04 | 9.00e-01 | nu