# Efficient Continuous Pareto Exploration in Multi-Task Learning
Source code for ICML submission #640 "Efficient Continuous Pareto Exploration in Multi-Task Learning"

This script generates Figure 3 in the paper.

# Problem setup

In [None]:
from collections import deque

import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
from scipy.sparse.linalg import LinearOperator, minres
import scipy.optimize

from common import *
from zdt2_variant import Zdt2Variant
from pretty_tabular import PrettyTabular

We first define the hyperparameters and fix the random seed:

In [None]:
K = 2               # Number of tangent directions generated per expansion.
N = 10              # Number of Pareto points to collect per seed (exploration stops once more than N are found).
s = 0.1             # Step size of each expansion step.
maxiter = 2         # Maximum allowable number of iterations in MINRES.
gamma = 0.9         # Multiplicative decay factor applied to eta at each backtracking step of the line search.
num_exp = 10        # How many times to repeat the experiment, each time with a different seed.
# Hyperparameters for MGDA and its line search.
ths = 1e-5          # Convergence threshold for MGDA. Smaller values take longer to converge.
eta_init = 1        # Initial step size in line search.
c1 = 0              # Sufficient-decrease (Armijo) constant in line search. Must be non-negative; typically 1e-4.
                    # Larger values require more backtracking iterations in line search.
np.random.seed(42)

We then import the ZDT2-variant problem definition:

In [None]:
problem = Zdt2Variant()
# n: dimension of the decision variable x; m: number of objectives.
n = problem.n
m = problem.m

# Our method
We define a few functions here: one that solves for $\alpha$ as described in Equation (3) of the paper, one that performs the line search used by MGDA, and our MGDA routine itself. Specifically, `mgda_optimize` is used to obtain a Pareto optimal solution from any initial guess.

In [None]:
# Solve for alpha in MGDA (Equation (3) of the paper). This does not call `problem`, so it does not
# increment any evaluation counter. The rows g[0], g[1], ... are not required to be parallel.
def compute_alpha(g):
    """Return the min-norm convex-combination weights of the objective gradients.

    g: gradients stacked row-wise, shape (m, n).
    Returns alpha in R^m with alpha >= 0 and sum(alpha) == 1 that minimizes ||alpha^T g||^2.
    Raises RuntimeError if the cvxpy solve does not produce a solution.
    """
    g = ndarray(g)
    assert g.shape == (m, n)
    alpha = cp.Variable(m)
    objective = cp.Minimize(cp.sum_squares(alpha.T @ g))
    constraints = [alpha >= 0, cp.sum(alpha) == 1]
    alpha_prob = cp.Problem(objective, constraints)
    alpha_prob.solve()
    # cvxpy leaves Variable.value as None when the solver fails (infeasible, solver error, ...).
    # Fail loudly instead of propagating a meaningless alpha.
    if alpha.value is None:
        raise RuntimeError('Failed to solve for alpha in MGDA (solver status: {}).'.format(alpha_prob.status))
    return ndarray(alpha.value).ravel()

In [None]:
# Backtracking line search used by MGDA. Every trial step evaluates problem.f, so each trial
# counts toward the problem's f-evaluation counter.
def line_search(x, f, grad, d, eta, c1):
    """Shrink eta by gamma until x + eta * d passes the sufficient-decrease test for every objective.

    x: current location, R^n.
    f: objective values f(x), R^m.
    grad: gradients grad(x), R^{m x n}, one objective per row.
    d: a direction expected to be *descent* for all objectives, R^n.
    eta: initial step size.
    c1: sufficient-decrease constant.
    Returns the first eta in eta, gamma * eta, gamma^2 * eta, ... such that, for every i,
    f_i(x + eta * d) <= f_i(x) + c1 * grad_i(x).dot(d) * eta
    (the Armijo condition applied to each objective).
    NOTE(review): there is no cap on the number of backtracking steps.
    """
    x = ndarray(x).ravel()
    assert x.size == n
    f = ndarray(f).ravel()
    assert f.size == m
    grad = ndarray(grad)
    assert grad.shape == (m, n)
    d = ndarray(d).ravel()
    assert d.size == n
    # The directional derivatives grad_i . d do not depend on eta, so compute them once.
    slopes = [row.dot(d) for row in grad]
    while not np.all([f_trial <= f_cur + c1 * slope * eta
                      for f_cur, slope, f_trial in zip(f, slopes, problem.f(x + eta * d))]):
        eta *= gamma
    return eta

In [None]:
# Pareto optimization: push an arbitrary (unoptimized) x onto the Pareto set with MGDA.
def mgda_optimize(x):
    """Run MGDA from x and return the point where it stops.

    x: starting location, R^n. It is copied, so the caller's array is not modified.
    Each iteration evaluates problem.grad and problem.f once. It takes the negated min-norm
    combination of the gradients as a common descent direction and backtracks along it with
    line_search. Iteration stops when that direction's norm, or the length of the accepted
    step, falls below ths.
    """
    x = ndarray(x).ravel()
    assert x.size == n
    current = np.copy(x)
    while True:
        jac = problem.grad(current)
        values = problem.f(current)
        weights = compute_alpha(jac)
        # Negate the min-norm combination so that it is a *descent* direction for every objective.
        direction = -ndarray(weights.T @ jac).ravel()
        # Sanity check: no objective may increase to first order (up to round-off).
        assert all(row.dot(direction) <= 0 or np.isclose(row.dot(direction), 0) for row in jac)
        direction_norm = np.linalg.norm(direction)
        # Stop 1: the common descent direction has (nearly) vanished.
        if direction_norm < ths:
            return current
        eta = line_search(current, values, jac, direction, eta_init, c1)
        current += eta * direction
        # Stop 2: the accepted step was negligible, i.e. eta has shrunk too far.
        if eta * direction_norm < ths:
            return current

We now define our expansion method and a baseline method for comparison. The weighted-sum baseline expands the local Pareto set as if it were taking a single gradient-descent step on a perturbed weighted combination of the two losses.

In [None]:
# Pareto expansion: two methods are compared below.
# Method 1 (baseline, weighted_sum_expand): perturb the MGDA weights alpha and take one gradient-descent
#     step along the re-weighted combination of the two gradients.
# Method 2 (ours, minres_expand): approximately solve for tangent directions with MINRES, capped at
#     `maxiter` iterations.
# Both methods normalize the directions they return so that the step size s has the same effect.
def weighted_sum_expand(x):
    """Baseline expansion: return K candidates x - s * d_k built from perturbed gradient weights.

    x: current location, R^n.
    Returns an array of shape (K, n). Row k is x - s * d_k, where d_k is the unit-norm combination
    of the objective gradients using alpha scaled elementwise by factors drawn from U(0.9, 1.1)
    and renormalized to sum to 1.
    """
    x = ndarray(x).ravel()
    assert x.size == n
    jac = problem.grad(x)
    alpha = compute_alpha(jac)
    # One jittered weight vector per direction; each row is rescaled back to sum to 1.
    weights = alpha * np.random.uniform(0.9, 1.1, size=(K, m))
    weights = weights / np.sum(weights, axis=1, keepdims=True)
    directions = weights @ jac
    # Unit-normalize each direction. Without this, GD can walk along the Pareto front with little
    # correction, but its steps are extremely small because in ZDT2-variant the gradient direction
    # is orthogonal to the Pareto set.
    directions = directions / np.sqrt(np.sum(directions ** 2, axis=1, keepdims=True))
    # Subtract: this mimics a single gradient-*descent* step.
    return ndarray(x - s * directions)

def minres_expand(x):
    """Our expansion: return K candidates x + s * v_k along approximate tangent directions.

    x: current location, R^n.
    Each v_k is the unit-normalized MINRES solution (at most `maxiter` iterations) of H v = beta^T g.
    Here H is the operator applied by problem.hvp(x, alpha, .), alpha are the MGDA weights at x,
    g are the gradients at x, and beta ~ N(0, I_m).
    Returns an array of shape (K, n).
    """
    x = ndarray(x).ravel()
    assert x.size == n
    jac = problem.grad(x)
    alpha = compute_alpha(jac)

    def hess_vec(v):
        v = ndarray(v).ravel()
        assert v.size == n
        return problem.hvp(x, alpha, v)

    unit_dirs = []
    for _ in range(K):
        rhs = np.random.normal(size=m).T @ jac
        # MINRES expects a symmetric operator, so the same product serves as matvec and rmatvec.
        # The operator is rebuilt on every pass: when no dtype is given, SciPy may apply matvec
        # once during construction, which shows up in the hvp counter.
        op = LinearOperator((n, n), matvec=hess_vec, rmatvec=hess_vec)
        sol = ndarray(minres(op, rhs, maxiter=maxiter)[0])
        unit_dirs.append(sol / np.linalg.norm(sol))
    return ndarray(x + s * ndarray(unit_dirs))

# Visualization

We now generate Figure 3 and print a table that summarizes the computational cost (numbers of function, gradient, and Hessian-vector-product evaluations) of our method and the baseline. Note that our method makes slightly more Hessian-vector-product calls when determining the expansion directions, in exchange for much more efficient optimization afterwards. In particular, the solutions (orange circles) explored by our method closely track the analytic Pareto front. Meanwhile, the baseline uses a cheap direction to explore new solutions but costs much more to optimize them back to the Pareto front. You are welcome to tune the step size or the line-search hyperparameters to improve the baseline; we expect the advantage of our method over this baseline to be observed consistently.

If you cannot see figures, try replacing `%matplotlib tk` with `%matplotlib inline`.

In [None]:
# Comment out '%matplotlib tk' and uncomment '%matplotlib inline' if this cell does not generate figures on your computer.
%matplotlib tk
#%matplotlib inline

# A minimal counter that accumulates function/gradient/Hessian-vector-product evaluation counts.
class Counter:
    """Running totals of f, gradient and Hessian-vector-product evaluation counts."""

    def __init__(self):
        self.f_cnt, self.g_cnt, self.h_cnt = 0, 0, 0

    def add(self, f_cnt, g_cnt, h_cnt):
        """Accumulate one batch of evaluation counts into the running totals."""
        self.f_cnt, self.g_cnt, self.h_cnt = (
            self.f_cnt + f_cnt,
            self.g_cnt + g_cnt,
            self.h_cnt + h_cnt,
        )

# Set up the figure: one subplot per exploration method.
fig = plt.figure(figsize=(10, 5))
for idx, (name, expand) in enumerate(zip(('MINRES', 'MGDA'), (minres_expand, weighted_sum_expand))):
    # Repeat the same experiment num_exp times, each with a different random seed.
    pareto_fronts = []
    explored = []

    # Evaluation counters, accumulated over all seeds, kept separately for the expansion step
    # and for the re-optimization (MGDA) step.
    expand_counter = Counter()
    optimize_counter = Counter()
    head = { 'method ({})'.format(name): '{:>20}', 'eval_f_cnt': '{:4d}', 'eval_g_cnt': '{:4d}', 'eval_hvp_cnt': '{:4d}' }
    tabular = PrettyTabular(head)
    print_info(tabular.head_string())
    for seed in range(num_exp):
        np.random.seed(seed)
        # Generate the initial Pareto optimal point.
        x0 = problem.sample_pareto_set()
        f0 = problem.f(x0)
        # BFS over newly found Pareto points, starting from x0.
        q = deque()
        q.append((x0, f0))
        pareto_front = [f0]

        while q:
            xi, _ = q.popleft()
            # Expand: propose K new candidates around xi; only these evaluations are counted.
            problem.reset_count()
            x1 = expand(xi)
            expand_counter.add(problem.eval_f_cnt, problem.eval_grad_cnt, problem.eval_hvp_cnt)

            # Objective values of the candidates, for plotting only (cleared by the reset below).
            f1 = ndarray([problem.f(x1i) for x1i in x1])
            explored.append(f1)

            # Optimize: pull every candidate back onto the Pareto set; only these evaluations are counted.
            problem.reset_count()
            x2 = [mgda_optimize(x1i) for x1i in x1]
            optimize_counter.add(problem.eval_f_cnt, problem.eval_grad_cnt, problem.eval_hvp_cnt)

            # Add the optimized points back to the queue. These f evaluations are bookkeeping only.
            for x2i in x2:
                f2i = problem.f(x2i)
                q.append((x2i, f2i))
                pareto_front.append(f2i)
            problem.reset_count()

            # Terminate this seed once more than N Pareto points have been collected.
            if len(pareto_front) > N:
                break
        pareto_fronts.append(pareto_front)
    # Print the counter information.
    for stage, counter in (('expand', expand_counter), ('optimize', optimize_counter)):
        row_data = { 'method ({})'.format(name): stage, 'eval_f_cnt': counter.f_cnt,
                     'eval_g_cnt': counter.g_cnt, 'eval_hvp_cnt': counter.h_cnt }
        print(tabular.row_string(row_data))

    pareto_fronts = ndarray(np.vstack(pareto_fronts))
    explored = ndarray(np.vstack(explored))

    ax = fig.add_subplot(1, 2, idx + 1)
    # Plot the Pareto front.
    problem.plot_pareto_front(ax, label='Pareto front')
    ax.scatter(explored[:, 0], explored[:, 1], c='tab:orange', s=45, alpha=0.7, label='$x_i$')
    ax.scatter(pareto_fronts[:, 0], pareto_fronts[:, 1], c='tab:red', s=15, label='$f(x^*_i)$')
    ax.legend()
    ax.set_xlim([-0.1, 1.1])
    ax.set_ylim([-0.1, 1.4])
    ax.set_xticks(np.linspace(0, 1, 6))
    ax.set_yticks(np.linspace(0, 1.2, 7))
    ax.set_title(name)

# Show the figure once, after both subplots have been drawn. Calling show() inside the loop can
# render the figure before the second subplot exists; with the inline backend it is then closed,
# and the second subplot never appears.
plt.show()