In [241]:
import torch 
from torch import nn
from torch import FloatTensor
from torch.autograd import Variable
from torch import optim

from numpy.linalg import matrix_rank

from sklearn.datasets import make_spd_matrix

In [47]:
temp = torch.nn.Linear(5, 3)
b = Variable(torch.rand((5, 5)))
a = Variable(torch.rand((2, 5, 5)))

In [50]:
b

Variable containing:
 0.5040  0.6785  0.2569  0.6513  0.8617
 0.4109  0.0840  0.2289  0.8198  0.9596
 0.9519  0.0509  0.2460  0.5015  0.9868
 0.0956  0.9169  0.4598  0.3063  0.6225
 0.6994  0.1210  0.9268  0.7742  0.5865
[torch.FloatTensor of size 5x5]

In [39]:
a[0]

Variable containing:
 0.3181  0.3285  0.7137  0.8696  0.6389
 0.2745  0.7231  0.4505  0.7706  0.4151
 0.1746  0.9885  0.1559  0.9463  0.3336
 0.2589  0.9943  0.8707  0.6326  0.0786
 0.3552  0.6404  0.1191  0.3398  0.8872
[torch.FloatTensor of size 5x5]

In [209]:
class Linear2D(nn.Module):
    
    def __init__(self, in_features, out_features):
        super(Linear2D, self).__init__()
        self.W = nn.Parameter(torch.rand((out_features, in_features)))
        
    def forward(self, X):
        XW = self._matmul(X, self.W.transpose(0 ,1))
        WXW = self._matmul(XW, self.W, kind='left')
        return WXW
        
    def _matmul(self, X, Y, kind='right'):
        results = [] 
        for i in range(X.size(0)):
            if kind == 'right':
                result = torch.mm(X[i], Y)
            elif kind == 'left':
                result = torch.mm(Y, X[i])
            results.append(result.unsqueeze(0))
        return torch.cat(results, 0)
    
    def check_rank(self):
        return min(self.W.data.size()) == matrix_rank(self.W.data.numpy())

In [210]:
class SymmetricallyCleanLayer(nn.Module):
    def __init__(self):
        super(SymmetricallyCleanLayer, self).__init__()
        self.relu = nn.ReLU()
        
    def forward(self, X):
        return self.relu(X) 

In [211]:
class NonLinearityBlock(nn.Module):
    def __init__(self, activation):
        super(NonLinearityBlock, self).__init__()
        self.activation = activation
        
    def forward(self, X):
        return self.activation(X)

In [218]:
class ShrinkBlock(nn.Module):
    def __init__(self, features):
        super(ShrinkBlock, self).__init__()
        self.weights_matrix = nn.Parameter(torch.zeros((features, features)).normal_())
        self.weights_matrix_square = torch.mm(self.weights_matrix.transpose(0, 1), self.weights_matrix)
        self.inverse_eye = Variable(torch.ones((features, features)) - torch.eye(features))
    
    def forward(self, X):
        return X - self.weights_matrix_square * self.inverse_eye

In [231]:
class MatrixEncoder(nn.Module):
    def __init__(self, n_features):
        super(MatrixEncoder, self).__init__()
        self.n_features = n_features
        
        self.encoder = nn.Sequential(
            Linear2D(n_features, 5),
            ShrinkBlock(5),
            SymmetricallyCleanLayer(),
            Linear2D(5, 3),
            ShrinkBlock(3)
        )
        
        self.decoder = nn.Sequential(
            Linear2D(3, 5),
            ShrinkBlock(5),
            SymmetricallyCleanLayer(),
            Linear2D(5, n_features),
            ShrinkBlock(n_features)
        )
        
    def forward(self, X):
        return self.decoder(self.encoder(X))

In [245]:
from torch.nn.modules.loss import MSELoss

In [279]:
coder = MatrixEncoder(6)
matrix = Variable(FloatTensor(make_spd_matrix(6))).unsqueeze(0)
optimizer = optim.Adam(coder.parameters(), lr=0.01)
criterion = MSELoss()

for epoch in range(4000):
    optimizer.zero_grad()
    outputs = coder(matrix)
    loss = criterion(outputs, matrix)
    loss.backward(retain_graph=True)
    if epoch % 100 == 0:
        print(loss.data[0])
    optimizer.step()

85001.234375
164.30828857421875
85.11470031738281
52.89616012573242
36.56184387207031
27.230751037597656
21.408578872680664
17.52837562561035
14.807859420776367
12.82198715209961
11.323862075805664
10.197795867919922
9.350672721862793
8.643263816833496
8.04478931427002
7.532347679138184
7.088569164276123
6.700085639953613
6.356471538543701
6.04948616027832
5.772524356842041
5.520229816436768
5.28825044631958
5.075875759124756
4.876154899597168
4.686581611633301
4.505087375640869
4.329866409301758
4.159339427947998
3.992133140563965
3.8274006843566895
3.6646995544433594
3.50199556350708
3.338881254196167
3.1752147674560547
3.01114559173584
2.8471343517303467
2.6839442253112793
2.5226051807403564
2.3643441200256348


In [280]:
coder(matrix)

Variable containing:
(0 ,.,.) = 
  0.2774 -0.3904 -0.3473  0.5940 -0.0180  0.9289
 -0.3904  0.3194 -0.6425  1.2896  0.8844 -2.9324
 -0.3473 -0.6425  0.2600  0.5464 -0.5799  3.3653
  0.5940  1.2896  0.5464  3.3272  0.0102  0.8790
 -0.0180  0.8844 -0.5799  0.0102  3.7061  0.2997
  0.9289 -2.9324  3.3653  0.8790  0.2997  2.4377
[torch.FloatTensor of size 1x6x6]

In [281]:
matrix

Variable containing:
(0 ,.,.) = 
  1.4477 -0.1077 -0.3719  0.3169 -1.4884 -1.1962
 -0.1077  0.6834  0.1723 -0.1011  0.6562  0.3868
 -0.3719  0.1723  0.6402 -0.1156  0.6623  0.6781
  0.3169 -0.1011 -0.1156  0.7181 -0.5229 -0.3761
 -1.4884  0.6562  0.6623 -0.5229  3.7567  2.2983
 -1.1962  0.3868  0.6781 -0.3761  2.2983  2.4977
[torch.FloatTensor of size 1x6x6]

In [265]:
coder2 = MatrixEncoder(6)
coder2(matrix)

Variable containing:
(0 ,.,.) = 
  403.3969  403.4521  328.4248  433.2410  286.2404  272.7939
  403.4521  411.7397  332.8844  438.6746  293.6486  275.8667
  328.4248  332.8845  266.6563  349.3683  237.2738  229.8308
  433.2410  438.6745  349.3683  466.1306  314.3354  298.2518
  286.2404  293.6486  237.2738  314.3354  212.1189  196.8404
  272.7939  275.8667  229.8308  298.2518  196.8404  191.2636
[torch.FloatTensor of size 1x6x6]

In [None]:
coder(mat)

In [283]:
import SPD

In [284]:
SPD.run_test()

38152.85546875
112.74282836914062
52.5504264831543
32.5604362487793
23.447019577026367
18.202831268310547
14.932071685791016
12.75731372833252
11.233356475830078
10.116395950317383
9.264239311218262
8.58940315246582
8.035533905029297
7.5648345947265625
7.1509881019592285
6.775049686431885
6.422967910766602
6.08411979675293
5.750424385070801
5.4159111976623535
5.076608657836914
4.730676651000977
4.378704071044922
4.024061679840088
3.6730358600616455
3.334488868713379
3.0187008380889893
2.735435724258423
2.4916892051696777
2.290128707885742
2.128908395767212
2.002943515777588
1.9057472944259644
1.8309561014175415
1.773125171661377
1.7279243469238281
1.6920229196548462
1.6628886461257935
1.6386020183563232
1.6177077293395996


In [235]:
list(block.parameters())

[Parameter containing:
  0.2013  0.0128  0.7827  0.3468  0.1750  0.1051  0.1103  0.4556
  0.8672  0.3314  0.1549  0.9517  0.3401  0.7927  0.9596  0.3282
 [torch.FloatTensor of size 2x8]]

In [236]:
list(block3.parameters())

[Parameter containing:
  0.7057  1.4323
 -0.5567  0.7995
 [torch.FloatTensor of size 2x2]]

In [237]:
list(block4.parameters())

[Parameter containing:
  0.0850  0.3646
  0.6269  0.2221
  0.8985  0.4029
  0.4628  0.0708
  0.1214  0.6744
  0.4403  0.0930
  0.5214  0.4614
  0.1011  0.9890
 [torch.FloatTensor of size 8x2]]