In [1]:
import torch

In [39]:
class SC(torch.nn.Module):
  """Additive matrix model with a nuclear-norm and an L1 penalty.

  The value fitted for entry (i, t) of the N x T target is
  ``L[i, t] + (X @ h)[i] + gamma[i] + delta[t]``.

  Learnable tensors (all initialised to zero):
    h:     P x 1
    L:     N x T
    delta: T x 1
    gamma: N x 1
  """

  def __init__(self, N, P, T, lambda_L=1, lambda_h=1):
    """
    N: number of rows of the target matrix
    P: number of columns of X
    T: number of columns of the target matrix
    lambda_L: weight of the nuclear-norm penalty on L (default 1)
    lambda_h: weight of the L1 penalty on h (default 1)
    """
    super().__init__()

    self.N = N
    self.P = P
    self.T = T

    # Wrapping the tensors in nn.Parameter registers them with the module, so
    # model.parameters(), state_dict() and .to(device) all see them.  They are
    # still reachable as model.h / model.L / model.delta / model.gamma.
    self.h = torch.nn.Parameter(torch.zeros(P, 1))
    self.L = torch.nn.Parameter(torch.zeros(N, T))
    self.delta = torch.nn.Parameter(torch.zeros(T, 1))
    self.gamma = torch.nn.Parameter(torch.zeros(N, 1))

    self.lambda_L = lambda_L
    self.lambda_h = lambda_h

  def forward(self, X, Y):
    """
    X: N x P
    Y: N x T

    Returns a scalar tensor: the 2-norm of the residual over all entries,
    divided by N * T, plus lambda_L * ||L||_nuc + lambda_h * sum(|h|).
    """
    Xh = torch.matmul(X, self.h)  # N x 1

    # Broadcasting expands Xh and gamma (N x 1) across columns and delta.T
    # (1 x T) across rows, so no explicit ones-matrices are needed.
    residual = Y - self.L - Xh - self.gamma - self.delta.T  # N x T

    # With p=2 and no `dim`, torch.norm treats the input as a flat vector, so
    # this is the Frobenius norm of the residual, not its spectral norm.
    loss = torch.norm(residual, 2) / (self.N * self.T)

    nuclear_penalty = self.lambda_L * torch.norm(self.L, 'nuc')
    l1_penalty = self.lambda_h * torch.sum(torch.abs(self.h))

    return loss + nuclear_penalty + l1_penalty

In [40]:
# Dimensions used below: X is N x P and Y is N x T.
N, P, T = 1000, 10, 100

# Inputs filled entirely with ones.
X = torch.ones((N, P))
Y = torch.ones((N, T))

In [41]:
# Build the model for the dimensions defined above.
model = SC(N=N, P=P, T=T)

In [44]:
# SGD with momentum over the four learnable tensors of the model.
optimizer = torch.optim.SGD(
    params=[model.h, model.L, model.delta, model.gamma],
    lr=0.01,
    momentum=0.9,
)

In [50]:
# The model input does not change between iterations, so build it once
# instead of re-allocating X + 1 on every step.
# NOTE(review): the model is fit on X + 1 rather than X — confirm the shift is intended.
X_shifted = X + 1

for _ in range(10000):
  optimizer.zero_grad()
  # Call the module rather than .forward() so nn.Module.__call__ (and any
  # registered hooks) runs.
  loss = model(X_shifted, Y)
  loss.backward()
  optimizer.step()

In [52]:
# Display the model's L tensor (N x T) as it stands after the training loop above.
model.L

tensor([[-5.1445e-03,  3.2464e-05,  3.2493e-05,  ...,  3.2471e-05,
          3.2464e-05,  3.2598e-05],
        [ 3.2501e-05, -5.1444e-03,  3.2492e-05,  ...,  3.2579e-05,
          3.2556e-05,  3.2560e-05],
        [ 3.2497e-05,  3.2555e-05, -5.1442e-03,  ...,  3.2544e-05,
          3.2513e-05,  3.2610e-05],
        ...,
        [-1.5553e-06, -1.5721e-06, -1.5653e-06,  ..., -1.5649e-06,
         -1.5641e-06, -1.5676e-06],
        [-1.5553e-06, -1.5721e-06, -1.5653e-06,  ..., -1.5649e-06,
         -1.5641e-06, -1.5676e-06],
        [-1.5553e-06, -1.5721e-06, -1.5653e-06,  ..., -1.5649e-06,
         -1.5641e-06, -1.5676e-06]], requires_grad=True)