In [217]:
import numpy as np
import torch
from torch.nn.utils import clip_grad_norm_


def simplex_projection(s):
    """Projection onto the unit simplex."""
    if np.sum(s) <=1 and np.alltrue(s >= 0):
        return s
    # Code taken from https://gist.github.com/daien/1272551
    # get the array of cumulative sums of a sorted (decreasing) copy of v
    u = np.sort(s)[::-1]
    cssv = np.cumsum(u)
    # get the number of > 0 components of the optimal solution
    rho = np.nonzero(u * np.arange(1, len(u)+1) > (cssv - 1))[0][-1]
    # compute the Lagrange multiplier associated to the simplex constraint
    theta = (cssv[rho] - 1) / (rho + 1.0)
    # compute the projection by thresholding v using theta
    return np.maximum(s-theta, 0)

def svd(A, full_matrices=False, print_flag=False):
    """Projection onto nuclear norm ball."""
    U, s, V = np.linalg.svd(A, full_matrices=full_matrices)
    if print_flag:
        print(s)
    return U, s, V
    
def nuclear_projection(A, print_flag=False):
    """Projection onto nuclear norm ball."""
    U, s, V = svd(A, full_matrices=False, print_flag=print_flag)
    s = simplex_projection(s)
    print(s)
    return U.dot(np.diag(s).dot(V))



In [137]:
# A = np.random.randn(5,5)/5
A = torch.randn(5,5)/3


In [201]:
A_ = nuclear_projection(A)
# A_ = np.array(A_)
_,s_A,_ = svd(A_, print_flag=True)
A_ = torch.tensor(A_)
s_A = torch.tensor(s_A)


[0.81445134 0.18554878 0.         0.         0.        ]
[8.1445134e-01 1.8554880e-01 4.0307957e-09 1.9482722e-09 1.4969507e-12]


In [222]:


def net_Matrix(Xs):
    M = Xs[0]
    if len(M.shape)==2:
        for X in Xs[1:]:
            M = M@X
    elif len(M.shape)==1:
        for X in Xs[1:]:
            M = M*X
    else:
        raise ValueError
    return M

def weight_loss(Xs):
    l=0
    for X in Xs:
        l+= X.pow(2).sum()
    return l


def train(epoch, lr, wd, Xs, A):
    optim = torch.optim.SGD(Xs, lr=lr/len(Xs), momentum=0.0)
    for i in range(epoch):
        optim.zero_grad()
        X = net_Matrix(Xs)
        L = (X - A).pow(2).sum() + wd*weight_loss(Xs)
        L.backward()
        clip_grad_norm_(Xs, clip_val)
        optim.step()

In [223]:
A_

tensor([[ 0.0999, -0.2374, -0.0653,  0.0820, -0.0269],
        [-0.1570,  0.3934,  0.0916, -0.1214,  0.0426],
        [-0.0135,  0.0376,  0.0059, -0.0091,  0.0037],
        [-0.2381,  0.3693,  0.2617, -0.2682,  0.0613],
        [-0.0817,  0.2588,  0.0184, -0.0431,  0.0230]])

In [260]:

N=5
depth = 2 #4 #8 

# Xs = [torch.eye(N, requires_grad=True) for _ in range(depth)]
Xs = [torch.ones(N, requires_grad=True) for _ in range(depth)]

clip_val=0.1/2
lr=4
wd=0.01
epoch=200

train(epoch, lr, wd, Xs, s_A)
X = net_Matrix(Xs)
print(X)
# _,_,_ = svd(X.detach().numpy(), print_flag=True)

lr=1
wd=0.00001
epoch=100

train(epoch, lr, wd, Xs, s_A)
X = net_Matrix(Xs)
print(X)
# _,_,_ = svd(X.detach().numpy(), print_flag=True)


tensor([7.2130e-01, 1.7555e-01, 3.6433e-05, 3.6433e-05, 3.6433e-05],
       grad_fn=<MulBackward0>)
tensor([8.1444e-01, 1.8554e-01, 3.6104e-05, 3.6104e-05, 3.6104e-05],
       grad_fn=<MulBackward0>)


In [210]:

N=5
depth = 8 #2 #6

# Xs = [torch.eye(N, requires_grad=True) for _ in range(depth)]
Xs = [torch.ones(N, requires_grad=True) for _ in range(depth)]

    
lr=0.4
wd=0.01
epoch=2000

train(epoch, lr, wd, Xs, s_A)
X = net_Matrix(Xs)
print(X)
# _,_,_ = svd(X.detach().numpy(), print_flag=True)

lr=0.5
wd=0.00001
epoch=400

train(epoch, lr, wd, Xs, s_A)
X = net_Matrix(Xs)
print(X)
# _,_,_ = svd(X.detach().numpy(), print_flag=True)


tensor([8.0266e-01, 1.4242e-01, 7.7902e-09, 7.7902e-09, 7.7902e-09],
       grad_fn=<MulBackward0>)
tensor([8.1444e-01, 1.8551e-01, 7.7596e-09, 7.7596e-09, 7.7596e-09],
       grad_fn=<MulBackward0>)


In [209]:

N=5
depth = 8 #2 #6

Xs = [torch.eye(N, requires_grad=True) for _ in range(depth)]

lr=0.4
wd=0.01
epoch=2000

train(epoch, lr, wd, Xs, A_)
X = net_Matrix(Xs)
# print(X)
_,_,_ = svd(X.detach().numpy(), print_flag=True)

lr=0.5
wd=0.00001
epoch=400

train(epoch, lr, wd, Xs, A_)
X = net_Matrix(Xs)
# print(X)
_,_,_ = svd(X.detach().numpy(), print_flag=True)


[8.0265892e-01 1.4241293e-01 1.6195807e-08 5.8746785e-09 4.1077870e-09]
[8.1443971e-01 1.8551345e-01 1.0652803e-08 6.9684631e-09 2.1435971e-09]


In [205]:
A_

tensor([[ 0.0999, -0.2374, -0.0653,  0.0820, -0.0269],
        [-0.1570,  0.3934,  0.0916, -0.1214,  0.0426],
        [-0.0135,  0.0376,  0.0059, -0.0091,  0.0037],
        [-0.2381,  0.3693,  0.2617, -0.2682,  0.0613],
        [-0.0817,  0.2588,  0.0184, -0.0431,  0.0230]])

In [206]:
X

tensor([[ 0.1002, -0.2364, -0.0663,  0.0829, -0.0270],
        [-0.1575,  0.3915,  0.0937, -0.1231,  0.0428],
        [-0.0136,  0.0371,  0.0064, -0.0095,  0.0038],
        [-0.2373,  0.3724,  0.2583, -0.2654,  0.0610],
        [-0.0823,  0.2563,  0.0211, -0.0453,  0.0232]], grad_fn=<MmBackward0>)