In [1]:
import numpy as np
import statsmodels.api as sm
from sklearn.model_selection import cross_val_score, cross_validate


In [2]:
def gen_delta_grid(X, c, n):
    """ Returns list of n random matrices such the the l2 norm of the jth colum is equal to c[j]
    """
    dim = X.shape
    matrix_grid = []
    for _ in range(n):
        # generate random nxp matrix
        rand_mat = np.random.randn(*dim)
        # append list with matrix with normalized columns = cj 
        matrix_grid.append((rand_mat / np.sqrt(np.square(rand_mat).sum(axis=0))[None, :])*c[None, :])
    return matrix_grid

Testing delta grid function

In [3]:
X = np.random.rand(10,5)

In [4]:
c=np.random.randint(1, 20, size=5)

In [5]:
vals  =  gen_delta_grid(X, c=c, n=15)

In [6]:
len(vals)

15

In [7]:
c

array([11,  5,  9, 11,  9])

In [8]:
for i in vals:
    print(np.sqrt(np.square(i).sum(axis=0)))

[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]
[11.  5.  9. 11.  9.]


### Alternative robustness

In case we need to implement alternative where c a vector of bounds we optimize over

Thought process: 
 - each c is the same bound for all the column l2 norms (this is the case when robust equals lasso)
 - In this case the original "delta" matrix can vary since we are optimizing over c, so we generate it once before cross validation.

In [9]:
def robust(X, c, mat):
    "Returns matrix such that "
    delta = (mat / np.sqrt(np.square(mat).sum(axis=0))[None, :])*c
    return delta + X

In [10]:
X = np.random.rand(10,5)

In [11]:
mat =  np.random.rand(10,5)

In [12]:
c=4

In [13]:
delta = (mat / np.sqrt(np.square(mat).sum(axis=0))[None, :])*c

In [14]:
np.sqrt(np.square(delta).sum(axis=0))

array([4., 4., 4., 4., 4.])

In [15]:
robust(X,c,mat)

array([[1.55638977, 1.85806391, 1.24029493, 0.21505146, 0.9228908 ],
       [0.95692248, 1.15784485, 2.22916218, 1.24780699, 2.16761251],
       [2.31322771, 2.10847926, 1.88937537, 2.08335792, 0.67251758],
       [2.43670609, 1.03208406, 1.58506442, 0.50702534, 2.73457836],
       [2.24596614, 2.05336586, 0.93634453, 2.34992408, 0.56218328],
       [1.20037504, 1.32584986, 2.4559286 , 1.9469776 , 1.97071833],
       [1.0966167 , 0.86715091, 1.41927029, 1.8786771 , 0.24614737],
       [1.03706903, 2.2512734 , 1.93088597, 1.0593049 , 2.95791712],
       [2.22783101, 1.9679335 , 1.94754847, 1.22041397, 1.99666461],
       [2.0320349 , 1.12553208, 1.68083529, 1.95077616, 0.35247605]])