In [1]:
import sys
sys.path.append('../src')

import importlib
import numpy as np
from sklearn.covariance import empirical_covariance

import policies 

# Covariance Constrained Policy Testing

Testing the covariance-constrained policy. At first, this produced a lot of constraint violations. I initially suspected the problem was simply infeasible, but on inspection it was feasible. Switching the solver used by cvxpy to Gurobi resolved the issue.

The first line of the next cell, %%capture, suppresses the cell's printed output (both stdout and stderr). This is much nicer than redirecting stdout manually, which doesn't work as expected inside Jupyter notebooks. Note: if your code is multithreaded, this no longer works for some reason.

In [21]:
%%capture

importlib.reload(policies)

n = 100
n_features = 6
pred_dim = 6

np.random.seed(30)

# Gaussian xs, with true labels having a different linear relationship with xs in each coordinate
xs = np.random.normal(size=(n, n_features))
slopes = np.random.uniform(size = n_features)
ys = np.multiply(xs, slopes)

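def meta_model(coord, slopes):
    """Build a model that is exact in coordinate `coord` and pure noise elsewhere."""
    def model(xs):
        # Random predictions, one row per input row
        preds = np.random.normal(size=(xs.shape[0], pred_dim))
        true_ys = np.multiply(xs, slopes)
        # Overwrite the chosen coordinate with the true labels
        preds[:, coord] = true_ys[:, coord]
        return preds
    return model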

model = meta_model(0, slopes)

alpha = 0.9
policy = policies.VarianceConstrained(pred_dim, model, 0.1, alpha, ys)
out = policy.run_given_preds(ys);

In [22]:
out

array([[9.60852781e-10, 9.99999997e-01, 2.55313001e-11, 9.00783916e-10,
        3.07819780e-10, 5.00705747e-10],
       [7.54686652e-12, 8.77117848e-08, 8.77443274e-12, 7.50539889e-12,
        9.99999912e-01, 8.20039187e-12],
       [1.91751937e-10, 2.53293148e-12, 9.99999999e-01, 7.38496930e-10,
        4.77996752e-11, 4.52196993e-10],
       ...,
       [1.04077839e-08, 1.53514710e-09, 9.99999949e-01, 1.41299553e-08,
        1.33785896e-08, 1.17086987e-08],
       [1.29970266e-07, 1.69285051e-08, 2.31114975e-08, 1.23471479e-07,
        9.99999693e-01, 1.31613885e-08],
       [2.12553512e-09, 9.99999671e-01, 8.34060607e-10, 1.59162518e-08,
        3.10161634e-07, 2.80261627e-10]])

Sanity checks:

1. Does each row sum to 1? Yes, up to a tolerance that absorbs small floating-point errors.

In [13]:
tolerance = 1e-3

print(np.sum(out, axis=1))
print("Violations: ", n - np.sum(np.isclose(np.sum(out, axis=1), 1.0, atol=tolerance)))

2. Is each allocation bounded by 0 and 1?

In [52]:
print(f"Number of allocations greater than 1: {np.sum(out - tolerance > 1)}")
print(f"Number of allocations less than 0: {np.sum(out+tolerance < 0)}")
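Checks 1 and 2 together say that every row of the output lies on the probability simplex. A minimal sketch of a reusable helper for that, run on small made-up arrays rather than on the policy output; the name `simplex_violations` and the example data are hypothetical:

```python
import numpy as np

def simplex_violations(allocations, tol=1e-3):
    """Count rows not summing to 1 and entries outside [0, 1], up to tol."""
    row_sums = allocations.sum(axis=1)
    bad_rows = int(np.sum(~np.isclose(row_sums, 1.0, atol=tol, rtol=0.0)))
    bad_entries = int(np.sum((allocations < -tol) | (allocations > 1 + tol)))
    return bad_rows, bad_entries

# Made-up examples: `good` is valid, `bad` has one short row and two out-of-range entries
good = np.array([[0.2, 0.8], [1.0, 0.0]])
bad = np.array([[0.2, 0.7], [1.5, -0.5]])

good_counts = simplex_violations(good)
bad_counts = simplex_violations(bad)
print(good_counts, bad_counts)
```

Passing a scalar `1.0` to `np.isclose` avoids having to keep a hard-coded array length in sync with `n`.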

3. Are the variance constraints approximately satisfied? I wrote yes here before, but it now seems like the answer is no.

In [53]:
cov = empirical_covariance(ys, assume_centered=False)

# Quadratic form out[i] @ cov @ out[i] for each allocation row
viol = np.zeros(len(out))
for i in range(len(out)):
    viol[i] = out[i] @ cov @ out[i]

print(f"Max allowed variance: {alpha}")
print(f"Largest observed variance: {viol.max()}")
print(f"Number of variance constraints which violate the max allowed variance: {sum(viol > alpha+0.1)}")
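The per-row quadratic form can also be computed without a Python loop. A minimal sketch using `np.einsum` on synthetic data (the arrays `weights` and `cov_demo` are stand-ins, not the notebook's `out` and `cov`), checked against the explicit loop:

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.dirichlet(np.ones(6), size=100)         # synthetic allocation rows
samples = rng.normal(size=(100, 6))                   # synthetic labels
cov_demo = np.cov(samples, rowvar=False, bias=True)   # divides by n, not n - 1

# Explicit loop: w @ C @ w for each row w
loop_vals = np.array([w @ cov_demo @ w for w in weights])

# Vectorized: sum over j and k of w[i, j] * C[j, k] * w[i, k]
vec_vals = np.einsum("ij,jk,ik->i", weights, cov_demo, weights)

print(np.allclose(loop_vals, vec_vals))
```

For 100 rows the loop is perfectly adequate; the vectorized form mainly helps when the number of allocations grows.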

4. Is the covariance matrix actually measuring the correct thing? As a sanity check, the diagonal of the matrix should equal the empirical variance of each coordinate.

In [54]:
print(f"Variances of each coordinate of the ys: \n {np.var(ys, axis=0)}")

print(f"Diagonal of the covariance matrix: \n {np.diagonal(cov)}")

5. Are all the constraints convex, and is the problem feasible?

If C is positive semi-definite, then xCx^T <= val is a convex constraint on x in R^n. Could something have gone wrong in the calculation of C?

If all eigenvalues of the symmetric matrix C are positive, then C is positive definite, and therefore also positive semi-definite. Here we see this is the case:

In [55]:
np.linalg.eigvals(cov) > 0

array([ True,  True,  True,  True,  True,  True])
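A strict `> 0` test can fail on a covariance matrix that is positive semi-definite but rank-deficient, because round-off can push a zero eigenvalue slightly negative. A minimal sketch of a tolerance-based check, assuming a symmetric input; the helper name `is_psd` and the default tolerance are hypothetical choices:

```python
import numpy as np

def is_psd(matrix, tol=1e-10):
    """Return True if the symmetrized matrix has no eigenvalue below -tol."""
    sym = (matrix + matrix.T) / 2             # guard against tiny asymmetries
    return bool(np.linalg.eigvalsh(sym).min() >= -tol)

rank_one = np.ones((3, 3))                    # eigenvalues 0, 0, 3: PSD but singular
indefinite = np.array([[1.0, 2.0],
                       [2.0, 1.0]])           # eigenvalues -1 and 3

print(is_psd(rank_one), is_psd(indefinite))
```

`np.linalg.eigvalsh` is intended for symmetric matrices and returns real eigenvalues in ascending order, so the smallest one is simply the first entry.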

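If the covariance constraint were the only one, xCx^T would be bounded by the eigenvalues of the covariance matrix: for a unit-norm x it lies between the smallest and the largest eigenvalue. Each allocation row is nonnegative and sums to 1, so its Euclidean norm is at most 1 and xCx^T is at most the largest eigenvalue. This makes the eigenvalues a good spot check for how feasible things are.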

In [40]:
np.linalg.eigvals(cov)

array([1.04541420e+00, 8.10321387e-01, 1.04229047e-01, 2.73097313e-03,
       8.41962782e-04, 1.57097101e-03])
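The eigenvalue bound can be spot-checked numerically. A minimal sketch on synthetic data (the scaling vector and sample sizes are arbitrary assumptions, not the notebook's values): for rows drawn from the simplex, the quadratic form should sit between the smallest and largest eigenvalue times the squared norm of the row, and therefore below the largest eigenvalue.

```python
import numpy as np

rng = np.random.default_rng(1)
scales = np.array([1.0, 0.9, 0.3, 0.05, 0.03, 0.04])  # arbitrary per-coordinate scales
samples = rng.normal(size=(200, 6)) * scales
cov_demo = np.cov(samples, rowvar=False, bias=True)

eigenvalues = np.linalg.eigvalsh(cov_demo)             # ascending order
lam_min, lam_max = eigenvalues[0], eigenvalues[-1]

weights = rng.dirichlet(np.ones(6), size=1000)         # random points on the simplex
quad = np.einsum("ij,jk,ik->i", weights, cov_demo, weights)
sq_norms = np.sum(weights ** 2, axis=1)                # at most 1 on the simplex

within_bounds = bool(
    np.all(quad <= lam_max * sq_norms + 1e-12)
    and np.all(quad >= lam_min * sq_norms - 1e-12)
)
print(within_bounds, quad.max() <= lam_max)
```

If the variance cap is below the smallest value the quadratic form can take on the simplex, no allocation can satisfy the constraint, which is what this kind of check is meant to rule out.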

In [4]:
%%capture
print('hi, stdout')
print('hi, stderr', file=sys.stderr)

In [5]:
print(5)

5
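For comparison with the %%capture cells above, here is a minimal sketch of capturing stdout in plain Python with the standard library's `contextlib.redirect_stdout`. This is not what the notebook uses, and the function `noisy_step` is a hypothetical stand-in for a call that prints progress. It only intercepts writes that go through `sys.stdout`, so anything native code writes directly to the process's file descriptor is not captured.

```python
import contextlib
import io

def noisy_step():
    # Hypothetical stand-in for a call that prints progress information
    print("iteration 1: objective 0.5")
    return 42

buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    result = noisy_step()          # printed text goes into `buffer`

captured_text = buffer.getvalue()
print(f"result={result}, captured {len(captured_text)} characters")
```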
