In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("..")
from functools import partial
import torch
from src.frank_wolfe import pseudo_frank_wolfe, term, isosceles_triangular_extreme_points
from src.affine_cond_x import objective_template, objective_gradient
from src.affine_cond_x import obj_matrix_form, grad_matrix_form, opt_step_len, approx_step_len

In [None]:
# Fn check: the element-wise objective and its matrix form should agree.
# Seed the RNG so this check is reproducible under Restart & Run All.
torch.manual_seed(0)
xs = torch.randn((2, 1))
precisions = torch.randn((2, 1))
# NOTE(review): randn can produce negative "precisions"; if the objective assumes
# positive precisions, use fixed inputs such as torch.tensor([[1.0, 2.0]]).T — confirm.
sigma_sq_k, sigma_sq_m = 1, 1
full = objective_template(precisions, xs, sigma_sq_k, sigma_sq_m)
mat = obj_matrix_form(precisions, xs, sigma_sq_k, sigma_sq_m)
print(full, mat)

In [None]:
# Grad check: the element-wise gradient and its matrix form should agree.
# Fixed inputs (the previous random draws were immediately overwritten anyway).
# NOTE: `xs`, `sigma_sq_k` and `sigma_sq_m` set here are reused by the cells below.
xs = torch.tensor([[2.0, 3.0]]).T
precisions = torch.tensor([[1.0, 2.0]]).T
sigma_sq_k, sigma_sq_m = 1, 1
full = objective_gradient(precisions, xs, sigma_sq_k, sigma_sq_m)
mat = grad_matrix_form(precisions, xs, sigma_sq_k, sigma_sq_m)
print(full)
print(mat)

In [None]:
# Matrix-form pieces of the objective for the current `xs` (an (n, 1) column here):
#   quadr_mat[i, j] = xs[i]**2 - xs[i] * xs[j]
#   linear[i]       = n / sigma_sq_m * xs[i]**2 + n / sigma_sq_k
num_samples = xs.size(0)
xs_elem_sq = xs ** 2
first_quadr = xs_elem_sq.repeat((1, num_samples))  # first_quadr[i, j] = xs[i]**2
second_quadr = xs @ xs.T                            # second_quadr[i, j] = xs[i] * xs[j]
quadr_mat = first_quadr - second_quadr
# ones_like keeps xs's dtype/device; torch.ones(xs.size()) was always float32 on CPU.
linear = num_samples / sigma_sq_m * xs_elem_sq + num_samples / sigma_sq_k * torch.ones_like(xs)

In [None]:
# Step-length selector built on opt_step_len from the precomputed matrix-form terms.
# NOTE(review): presumably the factor 2 turns quadr_mat into the Hessian of the
# quadratic term — confirm against opt_step_len. This selector is reassigned to an
# approx_step_len selector in the Frank-Wolfe cell before it is ever used.
step_len_selector = partial(
    opt_step_len,
    quadr_mat=quadr_mat * 2,
    linear=linear,
)


In [None]:
import numpy as np
import matplotlib.cm as cm
import matplotlib.pyplot as plt


# Evaluate the objective on a regular grid over (alpha_1, alpha_2) for the contour plot.
delta = 3
alpha_1 = np.arange(-50.0, 100.0, delta)
alpha_2 = np.arange(-50.0, 100.0, delta)

fn = partial(obj_matrix_form, xs=xs, sigma_sq_k=sigma_sq_k, sigma_sq_m=sigma_sq_m)
# Z uses np.meshgrid's default "xy" layout, which ax.contour(X, Y, Z) expects:
# rows index alpha_2 (y), columns index alpha_1 (x). Indexing Z[i_x, i_y]
# would silently transpose the contour plot.
Z = torch.empty((alpha_2.shape[0], alpha_1.shape[0]))
for (i_x, x) in enumerate(alpha_1):
    for (i_y, y) in enumerate(alpha_2):
        alpha = torch.tensor([x, y], dtype=xs.dtype).reshape((2, 1))
        Z[i_y, i_x] = fn(alpha)

In [None]:
# Run pseudo Frank-Wolfe from alpha_init over the triangle with vertices
# (0, 0), (budget, 0), (0, budget) (the set drawn by plot_constraint_set).
budget = 85
alpha_init = torch.tensor([[40.0, 0.0]]).T
grad_fn = partial(objective_gradient, xs=xs, sigma_sq_k=sigma_sq_k, sigma_sq_m=sigma_sq_m)
# Dimension comes from the data, not from `alpha` (a leftover loop variable of the
# grid cell that only exists if that cell happened to run first).
extr_point_finder = partial(isosceles_triangular_extreme_points, budget=budget, num_samples=xs.size(0))
# Approximate line search on the objective `fn` defined in the grid cell.
step_len_selector = partial(approx_step_len, obj_fn=fn, num_steps=5)
# NOTE(review): the first argument (presumably the objective) is passed as None —
# confirm pseudo_frank_wolfe does not need it.
alpha_star, alpha_store = pseudo_frank_wolfe(None, grad_fn, extr_point_finder, step_len_selector, term, alpha_init, num_iter=100)
alpha_store

In [None]:
def plot_constraint_set(ax, budget):
    """Draw the triangle with vertices (0, 0), (budget, 0) and (0, budget) on ``ax``.

    Each of the three edges is drawn with its own red solid-line ``ax.plot`` call.

    Args:
        ax: Matplotlib axes (anything with a ``plot(x, y, fmt)`` method).
        budget: Length of the two legs of the triangle.
    """
    # (x coordinates, y coordinates) of each edge, in drawing order.
    edges = (
        ((0, budget), (budget, 0)),  # hypotenuse from (0, budget) to (budget, 0)
        ((0, 0), (0, budget)),       # vertical leg on x = 0
        ((0, budget), (0, 0)),       # horizontal leg on y = 0
    )
    for edge_x, edge_y in edges:
        ax.plot(np.array(edge_x), np.array(edge_y), "r-")

# Contour of the objective, the constraint set, its extreme points, the
# Frank-Wolfe trajectory and the gradient at the final iterate.
X, Y = np.meshgrid(alpha_1, alpha_2)
fig, ax = plt.subplots()
CS = ax.contour(X, Y, Z)
ax.clabel(CS, inline=True, fontsize=10)

plot_constraint_set(ax, budget)

# Columns of extr_points are the extreme points (x in row 0, y in row 1).
extr_points = isosceles_triangular_extreme_points(budget, len(xs))
ax.plot(extr_points[0, :], extr_points[1, :], "rX")

# Trajectory of iterates:
ax.plot(alpha_store[0, :], alpha_store[1, :], "b*-")
# Gradient at the final iterate; the 0.05 factor only shortens the arrow for display.
grad_star = grad_fn(alpha_star)
g_plot = torch.column_stack((alpha_star, alpha_star + 0.05 * grad_star))
ax.plot(g_plot[0, :], g_plot[1, :], "b-")
ax.set(xlabel=r"$\alpha_1$", ylabel=r"$\alpha_2$", title="Pseudo Frank-Wolfe iterates over objective contours")

# grad(alpha_star)^T (v - alpha_star) for each extreme point v (one entry per column),
# i.e. the first-order change of the objective moving from alpha_star towards v.
lin_star = grad_star.T @ (extr_points - alpha_star)
print(lin_star)
# Scratch: direction from alpha_init towards the vertex (0, budget), in alpha_init's dtype.
opt_extr_init = torch.tensor([[0, budget]], dtype=alpha_init.dtype).T
search_dir = opt_extr_init - alpha_init