<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Solution-of-master-problem" data-toc-modified-id="Solution-of-master-problem-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Solution of master problem</a></span><ul class="toc-item"><li><span><a href="#Local-solution" data-toc-modified-id="Local-solution-1.1"><span class="toc-item-num">1.1&nbsp;&nbsp;</span>Local solution</a></span></li></ul></li></ul></div>

Here we apply Mert's idea of computing the derivative locally, and check how well it works in practice.

We assume that the cost of each subproblem is simply a quadratic form plus a linear term, $z_i^\top A_i(\theta) z_i + b_i(\theta)^\top z_i$.

**Note:** All of this could (and perhaps should) be done with an autodiff framework such as JAX or PyTorch.

----


In [None]:
import cvxpy as cp
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Size of the local (uncoupled) part x and of the coupled part y of each block vector
dim_x, dim_y = 2, 2

In [None]:
dim_tot = dim_x + dim_y  # length of each block vector z_i = (x_i, y)

In [None]:
class problem_data():
    """Random parametric two-block QP used to test implicit differentiation.

    The problem solved by ``solve_exact`` is::

        minimize    z1' A1(t) z1 + b1(t)' z1 + z2' A2(t) z2 + b2(t)' z2
        subject to  z1[dim_x:] == z2[dim_x:]          (i.e. B1 z1 + B2 z2 = 0)

    with ``A_i(t) = A_i_ + t * dA_i_theta`` and ``b_i(t) = b_i_ + t * db_i_theta``.
    Each block vector has ``dim_tot = dim_x + dim_y`` entries; the last
    ``dim_y`` entries are shared between the two blocks.
    """

    def __init__(self, dim_x, dim_y, rng=None):
        """Draw random problem data.

        Args:
            dim_x: number of local (uncoupled) variables per block.
            dim_y: number of coupled variables per block.
            rng: optional random source with a ``standard_normal(size)`` method
                (e.g. ``np.random.default_rng(seed)``). Defaults to the global
                ``np.random`` state, which reproduces the original draws.
        """
        if rng is None:
            rng = np.random
        self.dim_x, self.dim_y = dim_x, dim_y
        # Use the instance size, not a notebook global.
        self.dim_tot = dim_tot = dim_x + dim_y

        # M @ M.T makes the matrices symmetric PSD, so A_i(t) stays PSD for t >= 0.
        dA1_theta = rng.standard_normal((dim_tot, dim_tot))
        dA2_theta = rng.standard_normal((dim_tot, dim_tot))
        self.dA2_theta = dA2_theta@dA2_theta.T
        self.dA1_theta = dA1_theta@dA1_theta.T
        self.db1_theta = rng.standard_normal(dim_tot)
        self.db2_theta = rng.standard_normal(dim_tot)

        A1_ = rng.standard_normal((dim_tot, dim_tot))
        A2_ = rng.standard_normal((dim_tot, dim_tot))
        self.A1_ = A1_@A1_.T
        self.A2_ = A2_@A2_.T
        self.b1_ = rng.standard_normal(dim_tot) + 10*np.ones(dim_tot)
        self.b2_ = rng.standard_normal(dim_tot) + 20*np.ones(dim_tot)

        # Coupling constraint B1 z1 + B2 z2 = z1[dim_x:] - z2[dim_x:] = 0
        self.B1 = np.concatenate((np.zeros((dim_y, dim_x)), np.eye(dim_y)), axis=1)
        self.B2 = np.concatenate((np.zeros((dim_y, dim_x)), -np.eye(dim_y)), axis=1)

    def get_data(self, theta):
        """Return ``(A1, A2, b1, b2, B1, B2)`` evaluated at parameter ``theta``."""
        A1 = self.A1_ + self.dA1_theta*theta
        A2 = self.A2_ + self.dA2_theta*theta
        b1 = self.b1_ + theta*self.db1_theta
        b2 = self.b2_ + theta*self.db2_theta

        return A1, A2, b1, b2, self.B1, self.B2

    def compute_dg_theta(self, z1, z2):
        """Return d/dtheta of the KKT residual at the primal point ``(z1, z2)``.

        The stationarity rows are ``2 A_i(t) z_i + b_i(t) + B_i' lam``, so their
        theta-derivative (at fixed z, lam) is ``2 dA_i z_i + db_i``. The
        coupling rows do not depend on theta and contribute zeros.

        Side effect: stores the per-block parts as ``self.dg_1`` / ``self.dg_2``.
        """
        self.dg_1 = 2*self.dA1_theta@z1 + self.db1_theta
        self.dg_2 = 2*self.dA2_theta@z2 + self.db2_theta
        return np.concatenate((self.dg_1, self.dg_2, np.zeros((self.dim_y,))))

    def solve_exact(self, theta):
        """Solve the coupled QP at ``theta`` with CVXPY.

        Returns:
            (z1, z2, lam_, dg_theta): primal solutions of both blocks, the dual
            value of the coupling constraint, and the theta-derivative of the
            KKT residual (see ``compute_dg_theta``; also sets ``dg_1``/``dg_2``).

        Raises:
            RuntimeError: if the solver does not report an optimal solution.
        """
        A1, A2, b1, b2, B1, B2 = self.get_data(theta)

        z1, z2 = cp.Variable(self.dim_tot), cp.Variable(self.dim_tot)
        constraints = [z1[self.dim_x:] == z2[self.dim_x:]]
        cost = cp.quad_form(z1, A1) + b1@z1 + cp.quad_form(z2, A2) + b2@z2

        prob = cp.Problem(cp.Minimize(cost), constraints)
        prob.solve()
        # Fail loudly instead of propagating None values downstream.
        if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
            raise RuntimeError(f"QP solve failed at theta={theta}: status {prob.status}")
        lam_ = constraints[0].dual_value

        dg_theta = self.compute_dg_theta(z1.value, z2.value)
        return z1.value, z2.value, lam_, dg_theta

    def compute_J(self, theta):
        """Return the KKT Jacobian w.r.t. ``(z1, z2, lam)`` at ``theta``.

        Block layout (block sizes dim_tot, dim_tot, dim_y)::

            [[2*A1, 0,    B1'],
             [0,    2*A2, B2'],
             [B1,   B2,   0  ]]

        The factor 2 is the Hessian of the quadratic form z' A z for symmetric A,
        matching the ``cp.quad_form`` cost used in ``solve_exact``.
        """
        n, m = self.dim_tot, self.dim_y
        A1, A2, _, _, B1, B2 = self.get_data(theta)

        J = np.zeros((2*n + m, 2*n + m))
        J[:n, :n] = 2*A1
        J[n:2*n, n:2*n] = 2*A2
        J[2*n:, :n] = B1
        J[2*n:, n:2*n] = B2
        J[:n, 2*n:] = B1.T
        J[n:2*n, 2*n:] = B2.T

        return J

In [None]:
# Base parameter drawn uniformly from [0, 1); keeping it non-negative keeps
# A_i(theta) = A_i_ + theta * dA_i_theta a sum of PSD matrices.
theta0 = np.random.random_sample()

# Solution of master problem


In [None]:
# Random problem instance shared by all following cells
data = problem_data(dim_x=dim_x, dim_y=dim_y)

In [None]:
# Reference point: exact solution at theta0 and the KKT Jacobian evaluated there
z1, z2, lam_, dg_theta = data.solve_exact(theta0)
J = data.compute_J(theta0)

In [None]:
x1_0 = z1[0]  # reference value: first entry of block 1's solution at theta0

In [None]:
# Display the reference value of x1 at theta0.
x1_0

In [None]:
# Coupled sensitivity of (z1, z2, lam): solve J dz = -dg_theta.
# np.linalg.solve avoids forming inv(J) explicitly (cheaper, more accurate).
dzdtheta = -np.linalg.solve(J, dg_theta)

# "Local" estimate: keep only block 1's top-left Jacobian block and its own
# theta-derivative, ignoring the coupling rows/columns and block 2.
J1 = J[:data.dim_tot, :data.dim_tot]

dz1dtheta = -np.linalg.solve(J1, data.dg_1)

In [None]:
# Coupled sensitivity d x1 / d theta (first entry of the full KKT solve).
dzdtheta[0]

In [None]:
# Local (uncoupled) estimate of d x1 / d theta.
dz1dtheta[0]

In [None]:
# Exact re-solves at perturbed parameters theta0 + dtheta (log-spaced steps).
# A comprehension is used so the globals z1, z2, lam_, dg_theta -- the reference
# solution at theta0 used by earlier cells -- are not overwritten.
# NOTE: solve_exact still overwrites data.dg_1 / data.dg_2 on every call.
dtheta_vec = np.logspace(-6, -2, 10)
x1_vals = np.array([data.solve_exact(theta0 + dtheta)[0][0] for dtheta in dtheta_vec])

In [None]:
# First-order predictions x1(theta0 + dtheta) ~= x1_0 + slope * dtheta, using the
# coupled slope and the local slope respectively.
x1_pred, x1_pred_local = (x1_0 + slope*dtheta_vec for slope in (dzdtheta[0], dz1dtheta[0]))

In [None]:
# Quick look at the exact x1 values across the perturbation steps.
plt.plot(dtheta_vec, x1_vals)

In [None]:
# Exact x1 versus the two first-order predictions (coupled 'diff', local 'local_diff').
plt.plot(dtheta_vec, x1_vals, label='exact')
plt.plot(dtheta_vec, x1_pred, label='diff')
plt.plot(dtheta_vec, x1_pred_local, label='local_diff')
plt.xlabel('dtheta')
plt.ylabel('x1')
plt.grid()
# plt.xscale('log')  # dtheta_vec is log-spaced; uncomment for a log x-axis
plt.legend()


In [None]:
# Pointwise relative error |x_exact - x_pred| / |x_exact| of both predictions
e_diff, e_local_diff = (np.abs((x1_vals - pred) / x1_vals) for pred in (x1_pred, x1_pred_local))

In [None]:
# Relative error of the coupled and the local first-order predictions vs. step size
plt.plot(dtheta_vec, e_diff, label='rel_err_diff')
plt.plot(dtheta_vec, e_local_diff, label='rel_err_local_diff')
plt.xlabel('dtheta')
plt.ylabel('relative error')
plt.grid()
# plt.xscale('log')  # dtheta_vec is log-spaced; uncomment for a log x-axis
plt.legend()