In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torch.nn.utils import weight_norm

from itertools import combinations
import numpy as np
import os
from scipy.sparse import csr_matrix
import secrets

In [2]:
def get_weights(X, gauss_coef=1.0, neighbors=None):
    ''' 
    Construct convext clustering weights according to approaches from Hocking et al. 2011 
    and Chi and Lange, 2013. X is the original data matrix with observations in the rows. 
    ''' 
    from scipy.spatial.distance import pdist, squareform
    dist_vec = pdist(X) / X.shape[0]
    w = np.exp(-1*gauss_coef*(dist_vec)**2)
    if neighbors is not None:
        from sklearn.neighbors import NearestNeighbors
        nbrs = NearestNeighbors(n_neighbors=neighbors, algorithm='ball_tree').fit(X)
        _, indices = nbrs.kneighbors(X)
        comb_list = list(combinations(range(X.shape[0]),2))
        neighbors_indicator = []
        for comb in comb_list:
            nn_i = indices[comb[0],:]
            if comb[1] in nn_i:
                neighbors_indicator.append(1.0)
            else:
                neighbors_indicator.append(0.0)
        w *= np.array(neighbors_indicator)
    return w


def sparse_D(n,p):
    '''
    Construct a sparse matrix, that when applied to a vector containing concatenated vectors
    of coefficients b = [b_1 b_2 ... b_n] where each b_i is p=num_var long and there are
    n = num_vec of them. Differences are taken between conformal elements (e.g. b_11 and b_21)
    across all unique pairwise combinations of vectors.
    '''
    comb_list = list(combinations(range(n),2))
    combs_arr = np.array(comb_list)
    num_combs = combs_arr.shape[0]
    data = np.ones_like(combs_arr)
    data[:,1] *= -1
    row = np.repeat(range(num_combs),2)
    col = combs_arr.flatten()
    D = csr_matrix((data.flatten(), (row, col)), shape=(num_combs, n))
    return sparse_mx_to_torch_sparse_tensor(D)

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    '''
    Convert a scipy sparse matrix to a torch sparse tensor.
    '''
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)

def convclust_penalty(recons, weights, wasserstein=True, q=2):
    '''
    Computes the differences between all rows of the output.
    :return: (Tensor)
    '''
    import torch
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    n,p = recons.shape        
    D = sparse_D(n,p).to(device)
    if wasserstein:
        recons = torch.sort(recons)[0]
    diffs = torch.norm(D.matmul(recons), q, dim=1)
    return torch.norm(torch.mul(weights, diffs), 1)


In [3]:
class LinearAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, cc_lambda=0, init_scale=0.001, weight_reg_type=None):
        super(LinearAE, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        
        self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)
        
        self.encoder.weight.data.normal_(0.0, init_scale)
        self.decoder.weight.data.normal_(0.0, init_scale)
        
        self.cc_lambda = cc_lambda
        self.weight_reg_type = weight_reg_type

    def set_weights(self,x):
        x = x.to(device) #cuda()
        weights = torch.from_numpy(get_weights(x.detach().cpu().numpy().reshape(x.shape[0], \
                                                                             np.prod(x.shape[1::])))).to(device) 
        self.convclust_weights = weights
            
    def forward(self, x):
        return self.get_reconstruction_loss(x) + self._get_reg_loss() #+ self.cc_lambda*self.get_convclust_loss(x)
        
    def compute_trace_norm(self):
        """
        Computes the trace norm of the autoencoder, as well as decoder and encoder individually
        :return: trace_norm(W2W1), trace_norm(W1), trace_norm(W2)
        """
        return torch.matmul(self.decoder.weight, self.encoder.weight).norm(p='nuc'), \
               self.encoder.weight.norm(p='nuc'), \
               self.decoder.weight.norm(p='nuc'),

    def get_reconstruction_loss(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)
        # MSE loss with mean reduction
        recon_loss = torch.sum((x - recon) ** 2) / len(x) 
        return recon_loss
    
    def get_convclust_loss(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)
        weights = self.convclust_weights
        # Convex cluster penalty loss
        convclust_loss = convclust_penalty(recon, weights, wasserstein=False)
        return convclust_loss

    def _get_reg_loss(self):
        # Standard L2 regularization, applied to W2W1 (product loss)
        if self.weight_reg_type == 'uniform_product':
            return self.l2_reg_scalar * (torch.norm(torch.matmul(self.decoder.weight, self.encoder.weight)) ** 2)

        # Standard L2 regularization for encoder and decoder separately (sum loss)
        elif self.weight_reg_type == 'uniform_sum':
            # regularize both encoder and decoder
            return self.l2_reg_scalar * (torch.norm(self.encoder.weight) ** 2 + torch.norm(self.decoder.weight) ** 2)

        # non-uniform sum
        elif self.weight_reg_type == 'non_uniform_sum':
            return torch.norm(self.diag_weights @ self.encoder.weight) ** 2 \
                   + torch.norm(self.decoder.weight @ self.diag_weights) ** 2

        # Do not apply regularization
        elif self.weight_reg_type is None:
            return 0.0

        else:
            raise ValueError("weight_reg_type should be one of (uniform_product, uniform_sum, non_uniform_sum, None)")


In [4]:
for x in data_loader:
#     print(x[100,:])
    x_cuda = x.to(device)
    lae_model = LinearAE(input_dim=28*28, cc_lambda=0, hidden_dim=20, init_scale=0.0001).to(device)
#     lae_model.set_weights(x)
    

NameError: name 'data_loader' is not defined

In [None]:
x.shape

In [5]:
lae_model.set_weights(x)

lae_model.convclust_weights
lae_model.get_convclust_loss(x)
# z = lae_model.encoder(x)
# z.shape
# recon = lae_model.decoder(z)

NameError: name 'lae_model' is not defined

In [21]:
def train_model(data, data_loader, input_dim, cc_lambda=0, hidden_dim=2, weight_reg_type=None, lr=0.0001, init_scale = 0.0001, train_itr=100):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    train_loss = 0
    for train_i in range(train_itr):
        losses = {}
        for batch_idx, x  in enumerate(data_loader):
            x_cuda = x.to(device)

    #         # ---- Log weights ----
    #         if (train_i == 0) or (train_i + 1) % 100 == 0:
    #             weight_history.log_weights(epoch=train_i)

            # ---- Optimize ----
            lae_model = LinearAE(input_dim=input_dim, cc_lambda=cc_lambda, hidden_dim=hidden_dim, init_scale=init_scale).to(device)
            #lae_model.set_weights(x)
            optimizer = torch.optim.SGD(lae_model.parameters(), lr=lr, momentum=0.9, nesterov=True)
            optimizer.zero_grad()

            loss = lae_model(x_cuda)

            loss.backward()

            # Rotation (https://github.com/XuchanBao/linear-ae/)
            y = lae_model.encoder.weight @ x_cuda.T
            yy_t_norm = y @ y.T / float(len(x))
            yy_t_upper = yy_t_norm - yy_t_norm.tril()
            gamma = 0.5 * (yy_t_upper - yy_t_upper.T)
            lae_model.encoder.weight.grad -= gamma @ lae_model.encoder.weight
            lae_model.decoder.weight.grad -= lae_model.decoder.weight @ gamma.T

            optimizer.step()

#                 losses = loss.item()
            train_loss += loss.item()

#             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#                         train_i, batch_idx * len(x), len(data_loader.dataset),
#                         100. * batch_idx / len(data_loader), loss.item() / len(x)))
            
            if (train_i + 1) % 10 == 0:
                print('====> Epoch:',train_i,'Alignment:',metric_alignment(lae_model, data.eigvectors))
    #             print('Train Epoch:',train_i,'Subspace metric:',metric_subspace(lae_model, data.eigvectors, data.eigs))

        print('====> Epoch: {} Convclust lambda: {}  Average loss: {:.4f}'.format(train_i, cc_lambda,\
                                                                               train_loss / len(data_loader.dataset)))

    return lae_model
            

In [7]:
from torch.utils.data import Dataset

import numpy as np
from scipy.stats import ortho_group

class DataGeneratorPCA(Dataset):
    def __init__(self, dims, hdims, min_sv=0.11, max_sv=5.0, total=10000, sv_list=None,
                 load_data=None):
        self.dims = dims
        self.hdims = hdims

        if load_data is None:
            if isinstance(sv_list, list):
                assert len(sv_list) == dims
                self.full_eigs = np.array(sorted(sv_list, reverse=True))
            else:
                self.full_eigs = min_sv + (max_sv - min_sv) * np.linspace(1, 0, dims)
            self.eigs = self.full_eigs[:hdims]

            self.full_svs = np.sqrt(self.full_eigs)

            self.full_eigvectors = ortho_group.rvs(dims)
            self.eigvectors = self.full_eigvectors[:, :hdims]

            self.total = total

            self.full_z_sample = np.random.normal(size=(total, self.dims))
            self.x_sample = (self.full_eigvectors @ np.diag(self.full_svs) @ self.full_z_sample.T).T.astype(np.float32)

        else:
            self.x_sample = load_data
            u, s, vh = np.linalg.svd(self.x_sample.T, full_matrices=False)
            self.eigs = s[:self.hdims]
            self.eigvectors = u[:, :self.hdims]
            self.total = len(self.x_sample)

    def __getitem__(self, i):
        return self.x_sample[i]

    def __len__(self):
        return self.total

    @property
    def shape(self):
        return self.x_sample.shape

In [8]:
def get_weight_tensor_from_seq(weight_seq):
    if isinstance(weight_seq, nn.Linear):
        return weight_seq.weight.detach()
    elif isinstance(weight_seq, nn.Sequential):
        weight_tensor = None
        for layer in weight_seq:
            if isinstance(layer, nn.Linear):
                layer_weight = layer.weight.detach()
                if weight_tensor is None:
                    weight_tensor = layer_weight
                else:
                    weight_tensor = layer_weight @ weight_tensor
            elif isinstance(layer, nn.BatchNorm1d):
                bn_weight = layer.weight.detach()

                # ignore bias

                if weight_tensor is None:
                    weight_tensor = torch.diag(bn_weight)
                else:
                    weight_tensor = torch.diag(bn_weight) @ weight_tensor
            else:
                raise ValueError("Layer type {} not supported!".format(type(layer)))
        return weight_tensor

def metric_alignment(model, gt_eigvectors):
    """
    Metric for alignment of decoder columns to ground truth eigenvectors
    :param model: Linear AE model
    :param gt_eigvectors: ground truth eigenvectors (input_dims,hidden_dims)
    :return: sum_i (1 - max_j (cos(eigvector_i, normalized_decoder column_j)))
    """
    decoder_weight = get_weight_tensor_from_seq(model.decoder)
    decoder_np = decoder_weight.detach().cpu().numpy()

    # normalize columns of gt_eigvectors
    norm_gt_eigvectors = gt_eigvectors / np.linalg.norm(gt_eigvectors, axis=0)
    # normalize columns of decoder
    norm_decoder = decoder_np / (np.linalg.norm(decoder_np, axis=0) + 1e-8)

    total_angles = 0.0
    for eig_i in range(gt_eigvectors.shape[1]):
        eigvector = norm_gt_eigvectors[:, eig_i]
        total_angles += 1. - np.max(np.abs(norm_decoder.T @ eigvector)) ** 2

    return total_angles / float(model.hidden_dim)


def metric_subspace(model, gt_eigvectors, gt_eigs):
    decoder_weight = get_weight_tensor_from_seq(model.decoder)
    decoder_np = decoder_weight.detach().cpu().numpy()

    # k - tr(UU^T WW^T), where W is left singular vector matrix of decoder
    u, s, vh = np.linalg.svd(decoder_np, full_matrices=False)
    return 1 - np.trace(gt_eigvectors @ gt_eigvectors.T @ u @ u.T) / float(model.hidden_dim)


In [9]:
import torchvision
# Get MNIST data
input_dim = 28 * 28
hidden_dim = 20

mnist_data = torchvision.datasets.MNIST(root='./data', train=True, download=True,
                                  transform=torchvision.transforms.Compose([
                                      torchvision.transforms.ToTensor()
                                  ]))
# full batch
batch_size = len(mnist_data)

mnist_loader = torch.utils.data.DataLoader(mnist_data, batch_size=batch_size, shuffle=True)

_, (raw_data, __) = next(enumerate(mnist_loader))
raw_data = torch.squeeze(raw_data.view(-1, input_dim))
raw_data = raw_data#[0:200,:]
labels = mnist_data.targets#[0:200]

# Center the data, and find ground truth principle directions
data_mean = torch.mean(raw_data, dim=0)
centered_data = raw_data - data_mean

data = DataGeneratorPCA(input_dim, hidden_dim, load_data=centered_data.numpy())
data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False)


In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
out_model = train_model(data, data_loader, input_dim, cc_lambda=0, hidden_dim=hidden_dim, lr=0.0001, init_scale = 0.0001, train_itr=1000)



====> Epoch: 0 Alignment: 0.9940118028664487
====> Epoch: 0 Convclust lambda: 0  Average loss: 0.0009
====> Epoch: 1 Alignment: 0.9926550648679822
====> Epoch: 1 Convclust lambda: 0  Average loss: 0.0018
====> Epoch: 2 Alignment: 0.9938786906066399
====> Epoch: 2 Convclust lambda: 0  Average loss: 0.0026
====> Epoch: 3 Alignment: 0.9940317102809997
====> Epoch: 3 Convclust lambda: 0  Average loss: 0.0035
====> Epoch: 4 Alignment: 0.9924410944538294
====> Epoch: 4 Convclust lambda: 0  Average loss: 0.0044
====> Epoch: 5 Alignment: 0.99479541308992
====> Epoch: 5 Convclust lambda: 0  Average loss: 0.0053
====> Epoch: 6 Alignment: 0.9937208910095855
====> Epoch: 6 Convclust lambda: 0  Average loss: 0.0062
====> Epoch: 7 Alignment: 0.9931135122298956
====> Epoch: 7 Convclust lambda: 0  Average loss: 0.0070
====> Epoch: 8 Alignment: 0.9937744880561373
====> Epoch: 8 Convclust lambda: 0  Average loss: 0.0079
====> Epoch: 9 Alignment: 0.9936334156401474
====> Epoch: 9 Convclust lambda: 0  Ave

====> Epoch: 80 Alignment: 0.9936434815852293
====> Epoch: 80 Convclust lambda: 0  Average loss: 0.0712
====> Epoch: 81 Alignment: 0.9939889932372361
====> Epoch: 81 Convclust lambda: 0  Average loss: 0.0721
====> Epoch: 82 Alignment: 0.9935717368562758
====> Epoch: 82 Convclust lambda: 0  Average loss: 0.0729
====> Epoch: 83 Alignment: 0.9939237062755627
====> Epoch: 83 Convclust lambda: 0  Average loss: 0.0738
====> Epoch: 84 Alignment: 0.9941976979148542
====> Epoch: 84 Convclust lambda: 0  Average loss: 0.0747
====> Epoch: 85 Alignment: 0.994064012998734
====> Epoch: 85 Convclust lambda: 0  Average loss: 0.0756
====> Epoch: 86 Alignment: 0.9936997741656846
====> Epoch: 86 Convclust lambda: 0  Average loss: 0.0765
====> Epoch: 87 Alignment: 0.9938919150458064
====> Epoch: 87 Convclust lambda: 0  Average loss: 0.0773
====> Epoch: 88 Alignment: 0.9937770030573896
====> Epoch: 88 Convclust lambda: 0  Average loss: 0.0782
====> Epoch: 89 Alignment: 0.9934561874196064
====> Epoch: 89 Con

====> Epoch: 158 Alignment: 0.993635806503175
====> Epoch: 158 Convclust lambda: 0  Average loss: 0.1397
====> Epoch: 159 Alignment: 0.9939917253115755
====> Epoch: 159 Convclust lambda: 0  Average loss: 0.1406
====> Epoch: 160 Alignment: 0.9938457169630815
====> Epoch: 160 Convclust lambda: 0  Average loss: 0.1415
====> Epoch: 161 Alignment: 0.9934106248787481
====> Epoch: 161 Convclust lambda: 0  Average loss: 0.1424
====> Epoch: 162 Alignment: 0.9928807023070334
====> Epoch: 162 Convclust lambda: 0  Average loss: 0.1432
====> Epoch: 163 Alignment: 0.9947784761600182
====> Epoch: 163 Convclust lambda: 0  Average loss: 0.1441
====> Epoch: 164 Alignment: 0.9921518183744837
====> Epoch: 164 Convclust lambda: 0  Average loss: 0.1450
====> Epoch: 165 Alignment: 0.9940095510706053
====> Epoch: 165 Convclust lambda: 0  Average loss: 0.1459
====> Epoch: 166 Alignment: 0.99283204535298
====> Epoch: 166 Convclust lambda: 0  Average loss: 0.1468
====> Epoch: 167 Alignment: 0.9924002009135162
==

====> Epoch: 236 Alignment: 0.9944929700631011
====> Epoch: 236 Convclust lambda: 0  Average loss: 0.2083
====> Epoch: 237 Alignment: 0.9940420802797695
====> Epoch: 237 Convclust lambda: 0  Average loss: 0.2091
====> Epoch: 238 Alignment: 0.9936647965512428
====> Epoch: 238 Convclust lambda: 0  Average loss: 0.2100
====> Epoch: 239 Alignment: 0.9937563652030974
====> Epoch: 239 Convclust lambda: 0  Average loss: 0.2109
====> Epoch: 240 Alignment: 0.9941953641847864
====> Epoch: 240 Convclust lambda: 0  Average loss: 0.2118
====> Epoch: 241 Alignment: 0.9942969194377278
====> Epoch: 241 Convclust lambda: 0  Average loss: 0.2127
====> Epoch: 242 Alignment: 0.9935711059178184
====> Epoch: 242 Convclust lambda: 0  Average loss: 0.2135
====> Epoch: 243 Alignment: 0.9944952171065788
====> Epoch: 243 Convclust lambda: 0  Average loss: 0.2144
====> Epoch: 244 Alignment: 0.992790591432545
====> Epoch: 244 Convclust lambda: 0  Average loss: 0.2153
====> Epoch: 245 Alignment: 0.9940909481955451


====> Epoch: 314 Alignment: 0.9928680158285417
====> Epoch: 314 Convclust lambda: 0  Average loss: 0.2768
====> Epoch: 315 Alignment: 0.9938561725010823
====> Epoch: 315 Convclust lambda: 0  Average loss: 0.2777
====> Epoch: 316 Alignment: 0.992545597198901
====> Epoch: 316 Convclust lambda: 0  Average loss: 0.2786
====> Epoch: 317 Alignment: 0.9934399914195076
====> Epoch: 317 Convclust lambda: 0  Average loss: 0.2794
====> Epoch: 318 Alignment: 0.9937748183722782
====> Epoch: 318 Convclust lambda: 0  Average loss: 0.2803
====> Epoch: 319 Alignment: 0.9941102271375364
====> Epoch: 319 Convclust lambda: 0  Average loss: 0.2812
====> Epoch: 320 Alignment: 0.9935419721583786
====> Epoch: 320 Convclust lambda: 0  Average loss: 0.2821
====> Epoch: 321 Alignment: 0.9936502215616881
====> Epoch: 321 Convclust lambda: 0  Average loss: 0.2830
====> Epoch: 322 Alignment: 0.9932566706616395
====> Epoch: 322 Convclust lambda: 0  Average loss: 0.2838
====> Epoch: 323 Alignment: 0.9943276434158491


====> Epoch: 392 Alignment: 0.9942551741974421
====> Epoch: 392 Convclust lambda: 0  Average loss: 0.3453
====> Epoch: 393 Alignment: 0.9935868228660748
====> Epoch: 393 Convclust lambda: 0  Average loss: 0.3462
====> Epoch: 394 Alignment: 0.9944568446275632
====> Epoch: 394 Convclust lambda: 0  Average loss: 0.3471
====> Epoch: 395 Alignment: 0.9937445769839297
====> Epoch: 395 Convclust lambda: 0  Average loss: 0.3480
====> Epoch: 396 Alignment: 0.9939190717550751
====> Epoch: 396 Convclust lambda: 0  Average loss: 0.3489
====> Epoch: 397 Alignment: 0.9945026190726625
====> Epoch: 397 Convclust lambda: 0  Average loss: 0.3497
====> Epoch: 398 Alignment: 0.994539074461313
====> Epoch: 398 Convclust lambda: 0  Average loss: 0.3506
====> Epoch: 399 Alignment: 0.9938751185211345
====> Epoch: 399 Convclust lambda: 0  Average loss: 0.3515
====> Epoch: 400 Alignment: 0.9934803489196566
====> Epoch: 400 Convclust lambda: 0  Average loss: 0.3524
====> Epoch: 401 Alignment: 0.9952243638873883


====> Epoch: 470 Alignment: 0.9940533145406247
====> Epoch: 470 Convclust lambda: 0  Average loss: 0.4139
====> Epoch: 471 Alignment: 0.9935975487272548
====> Epoch: 471 Convclust lambda: 0  Average loss: 0.4148
====> Epoch: 472 Alignment: 0.9933190589802606
====> Epoch: 472 Convclust lambda: 0  Average loss: 0.4156
====> Epoch: 473 Alignment: 0.9935986027060295
====> Epoch: 473 Convclust lambda: 0  Average loss: 0.4165
====> Epoch: 474 Alignment: 0.9934297976295069
====> Epoch: 474 Convclust lambda: 0  Average loss: 0.4174
====> Epoch: 475 Alignment: 0.9942889338958496
====> Epoch: 475 Convclust lambda: 0  Average loss: 0.4183
====> Epoch: 476 Alignment: 0.9929721259712133
====> Epoch: 476 Convclust lambda: 0  Average loss: 0.4192
====> Epoch: 477 Alignment: 0.994171326471745
====> Epoch: 477 Convclust lambda: 0  Average loss: 0.4200
====> Epoch: 478 Alignment: 0.9939132137420048
====> Epoch: 478 Convclust lambda: 0  Average loss: 0.4209
====> Epoch: 479 Alignment: 0.9940127750514332


====> Epoch: 548 Alignment: 0.9942087140319767
====> Epoch: 548 Convclust lambda: 0  Average loss: 0.4824
====> Epoch: 549 Alignment: 0.9943831141424992
====> Epoch: 549 Convclust lambda: 0  Average loss: 0.4833
====> Epoch: 550 Alignment: 0.9931845790715441
====> Epoch: 550 Convclust lambda: 0  Average loss: 0.4842
====> Epoch: 551 Alignment: 0.9942511163122012
====> Epoch: 551 Convclust lambda: 0  Average loss: 0.4851
====> Epoch: 552 Alignment: 0.9950009829929114
====> Epoch: 552 Convclust lambda: 0  Average loss: 0.4859
====> Epoch: 553 Alignment: 0.9947007229460816
====> Epoch: 553 Convclust lambda: 0  Average loss: 0.4868
====> Epoch: 554 Alignment: 0.9930559268011315
====> Epoch: 554 Convclust lambda: 0  Average loss: 0.4877
====> Epoch: 555 Alignment: 0.9928559308761373
====> Epoch: 555 Convclust lambda: 0  Average loss: 0.4886
====> Epoch: 556 Alignment: 0.9946269636313858
====> Epoch: 556 Convclust lambda: 0  Average loss: 0.4895
====> Epoch: 557 Alignment: 0.9924142058812275

====> Epoch: 626 Alignment: 0.9944694482916031
====> Epoch: 626 Convclust lambda: 0  Average loss: 0.5510
====> Epoch: 627 Alignment: 0.9943714843430074
====> Epoch: 627 Convclust lambda: 0  Average loss: 0.5519
====> Epoch: 628 Alignment: 0.994286927914847
====> Epoch: 628 Convclust lambda: 0  Average loss: 0.5527
====> Epoch: 629 Alignment: 0.9943963710703869
====> Epoch: 629 Convclust lambda: 0  Average loss: 0.5536
====> Epoch: 630 Alignment: 0.9938633476664218
====> Epoch: 630 Convclust lambda: 0  Average loss: 0.5545
====> Epoch: 631 Alignment: 0.9934862543039831
====> Epoch: 631 Convclust lambda: 0  Average loss: 0.5554
====> Epoch: 632 Alignment: 0.9942025372100807
====> Epoch: 632 Convclust lambda: 0  Average loss: 0.5562
====> Epoch: 633 Alignment: 0.9924513960159752
====> Epoch: 633 Convclust lambda: 0  Average loss: 0.5571
====> Epoch: 634 Alignment: 0.9936925835022619
====> Epoch: 634 Convclust lambda: 0  Average loss: 0.5580
====> Epoch: 635 Alignment: 0.9937059316496875


====> Epoch: 704 Alignment: 0.9940193430006176
====> Epoch: 704 Convclust lambda: 0  Average loss: 0.6195
====> Epoch: 705 Alignment: 0.9940084874772195
====> Epoch: 705 Convclust lambda: 0  Average loss: 0.6204
====> Epoch: 706 Alignment: 0.9935378828107458
====> Epoch: 706 Convclust lambda: 0  Average loss: 0.6213
====> Epoch: 707 Alignment: 0.9938255908905174
====> Epoch: 707 Convclust lambda: 0  Average loss: 0.6222
====> Epoch: 708 Alignment: 0.994009281343309
====> Epoch: 708 Convclust lambda: 0  Average loss: 0.6230
====> Epoch: 709 Alignment: 0.9942598071520703
====> Epoch: 709 Convclust lambda: 0  Average loss: 0.6239
====> Epoch: 710 Alignment: 0.9944424389561067
====> Epoch: 710 Convclust lambda: 0  Average loss: 0.6248
====> Epoch: 711 Alignment: 0.9940439873730064
====> Epoch: 711 Convclust lambda: 0  Average loss: 0.6257
====> Epoch: 712 Alignment: 0.9936895249972304
====> Epoch: 712 Convclust lambda: 0  Average loss: 0.6265
====> Epoch: 713 Alignment: 0.9939018751715208


====> Epoch: 782 Alignment: 0.9937142934324296
====> Epoch: 782 Convclust lambda: 0  Average loss: 0.6881
====> Epoch: 783 Alignment: 0.9930343639923672
====> Epoch: 783 Convclust lambda: 0  Average loss: 0.6889
====> Epoch: 784 Alignment: 0.9938272227048082
====> Epoch: 784 Convclust lambda: 0  Average loss: 0.6898
====> Epoch: 785 Alignment: 0.9934363830291975
====> Epoch: 785 Convclust lambda: 0  Average loss: 0.6907
====> Epoch: 786 Alignment: 0.9940803421216893
====> Epoch: 786 Convclust lambda: 0  Average loss: 0.6916
====> Epoch: 787 Alignment: 0.9934277094519863
====> Epoch: 787 Convclust lambda: 0  Average loss: 0.6925
====> Epoch: 788 Alignment: 0.9944129232373726
====> Epoch: 788 Convclust lambda: 0  Average loss: 0.6933
====> Epoch: 789 Alignment: 0.9943877511855543
====> Epoch: 789 Convclust lambda: 0  Average loss: 0.6942
====> Epoch: 790 Alignment: 0.9936050557364478
====> Epoch: 790 Convclust lambda: 0  Average loss: 0.6951
====> Epoch: 791 Alignment: 0.9935960845842879

====> Epoch: 860 Alignment: 0.9942686180109079
====> Epoch: 860 Convclust lambda: 0  Average loss: 0.7566
====> Epoch: 861 Alignment: 0.994577649675495
====> Epoch: 861 Convclust lambda: 0  Average loss: 0.7575
====> Epoch: 862 Alignment: 0.9930870858379338
====> Epoch: 862 Convclust lambda: 0  Average loss: 0.7584
====> Epoch: 863 Alignment: 0.9945865198139725
====> Epoch: 863 Convclust lambda: 0  Average loss: 0.7592
====> Epoch: 864 Alignment: 0.9945551818892515
====> Epoch: 864 Convclust lambda: 0  Average loss: 0.7601
====> Epoch: 865 Alignment: 0.9946492572768024
====> Epoch: 865 Convclust lambda: 0  Average loss: 0.7610
====> Epoch: 866 Alignment: 0.9945198417057257
====> Epoch: 866 Convclust lambda: 0  Average loss: 0.7619
====> Epoch: 867 Alignment: 0.9941097819818262
====> Epoch: 867 Convclust lambda: 0  Average loss: 0.7628
====> Epoch: 868 Alignment: 0.9934750709536434
====> Epoch: 868 Convclust lambda: 0  Average loss: 0.7636
====> Epoch: 869 Alignment: 0.9943587707041062


====> Epoch: 938 Alignment: 0.9935941128944232
====> Epoch: 938 Convclust lambda: 0  Average loss: 0.8251
====> Epoch: 939 Alignment: 0.9938057034684074
====> Epoch: 939 Convclust lambda: 0  Average loss: 0.8260
====> Epoch: 940 Alignment: 0.9950664376562534
====> Epoch: 940 Convclust lambda: 0  Average loss: 0.8269
====> Epoch: 941 Alignment: 0.9938891762884442
====> Epoch: 941 Convclust lambda: 0  Average loss: 0.8278
====> Epoch: 942 Alignment: 0.993649017395275
====> Epoch: 942 Convclust lambda: 0  Average loss: 0.8287
====> Epoch: 943 Alignment: 0.9920536603001535
====> Epoch: 943 Convclust lambda: 0  Average loss: 0.8295
====> Epoch: 944 Alignment: 0.9937993932861084
====> Epoch: 944 Convclust lambda: 0  Average loss: 0.8304
====> Epoch: 945 Alignment: 0.9936651478959471
====> Epoch: 945 Convclust lambda: 0  Average loss: 0.8313
====> Epoch: 946 Alignment: 0.9939713861740126
====> Epoch: 946 Convclust lambda: 0  Average loss: 0.8322
====> Epoch: 947 Alignment: 0.9938894231194721


In [11]:
print(metric_alignment(out_model, data.eigvectors))
print(metric_subspace(out_model, data.eigvectors, data.eigs))


0.9938783559128727
0.9761425346136093


In [19]:
input_dim = 1000
hidden_dim = 2

n_data = 100
batch_size = n_data

max_sv = float(input_dim) * 0.1
min_sv = 1.0
sigma = 0.5

gt_data = DataGeneratorPCA(input_dim, hidden_dim, min_sv=min_sv, max_sv=max_sv, total=n_data)
data = DataGeneratorPCA(input_dim, hidden_dim, load_data=gt_data.x_sample)

data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False)


In [20]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
out_model = train_model(data, data_loader, input_dim, cc_lambda=0, hidden_dim=hidden_dim, lr=0.0001, init_scale = 0.0001, train_itr=1000)


====> Epoch: 0 Alignment: 0.9971986798229162
====> Epoch: 0 Convclust lambda: 0  Average loss: 500.2075
====> Epoch: 1 Alignment: 0.9963386408138781
====> Epoch: 1 Convclust lambda: 0  Average loss: 1000.4150
====> Epoch: 2 Alignment: 0.9980415229649326
====> Epoch: 2 Convclust lambda: 0  Average loss: 1500.6224
====> Epoch: 3 Alignment: 0.9999887339482273
====> Epoch: 3 Convclust lambda: 0  Average loss: 2000.8299
====> Epoch: 4 Alignment: 0.9993092597335149
====> Epoch: 4 Convclust lambda: 0  Average loss: 2501.0374
====> Epoch: 5 Alignment: 0.9991553323414358
====> Epoch: 5 Convclust lambda: 0  Average loss: 3001.2449
====> Epoch: 6 Alignment: 0.9983851818628737
====> Epoch: 6 Convclust lambda: 0  Average loss: 3501.4524
====> Epoch: 7 Alignment: 0.9960444346518844
====> Epoch: 7 Convclust lambda: 0  Average loss: 4001.6599
====> Epoch: 8 Alignment: 0.998779538521462
====> Epoch: 8 Convclust lambda: 0  Average loss: 4501.8674
====> Epoch: 9 Alignment: 0.9995711175805988
====> Epoch:

====> Epoch: 78 Alignment: 0.99313856268533
====> Epoch: 78 Convclust lambda: 0  Average loss: 39516.3907
====> Epoch: 79 Alignment: 0.995626430238615
====> Epoch: 79 Convclust lambda: 0  Average loss: 40016.5982
====> Epoch: 80 Alignment: 0.9979925586260937
====> Epoch: 80 Convclust lambda: 0  Average loss: 40516.8056
====> Epoch: 81 Alignment: 0.993388570509418
====> Epoch: 81 Convclust lambda: 0  Average loss: 41017.0131
====> Epoch: 82 Alignment: 0.9992373631635597
====> Epoch: 82 Convclust lambda: 0  Average loss: 41517.2205
====> Epoch: 83 Alignment: 0.9980006007835487
====> Epoch: 83 Convclust lambda: 0  Average loss: 42017.4280
====> Epoch: 84 Alignment: 0.9980294715010267
====> Epoch: 84 Convclust lambda: 0  Average loss: 42517.6355
====> Epoch: 85 Alignment: 0.9995660533026482
====> Epoch: 85 Convclust lambda: 0  Average loss: 43017.8430
====> Epoch: 86 Alignment: 0.9995232572168644
====> Epoch: 86 Convclust lambda: 0  Average loss: 43518.0504
====> Epoch: 87 Alignment: 0.997

====> Epoch: 156 Alignment: 0.9988124721790717
====> Epoch: 156 Convclust lambda: 0  Average loss: 78532.5735
====> Epoch: 157 Alignment: 0.9993693822110091
====> Epoch: 157 Convclust lambda: 0  Average loss: 79032.7810
====> Epoch: 158 Alignment: 0.9995056820433441
====> Epoch: 158 Convclust lambda: 0  Average loss: 79532.9884
====> Epoch: 159 Alignment: 0.9956856819612397
====> Epoch: 159 Convclust lambda: 0  Average loss: 80033.1959
====> Epoch: 160 Alignment: 0.9975386791309442
====> Epoch: 160 Convclust lambda: 0  Average loss: 80533.4034
====> Epoch: 161 Alignment: 0.9993356186350137
====> Epoch: 161 Convclust lambda: 0  Average loss: 81033.6109
====> Epoch: 162 Alignment: 0.996740970672381
====> Epoch: 162 Convclust lambda: 0  Average loss: 81533.8184
====> Epoch: 163 Alignment: 0.9980599795889645
====> Epoch: 163 Convclust lambda: 0  Average loss: 82034.0259
====> Epoch: 164 Alignment: 0.997792843846961
====> Epoch: 164 Convclust lambda: 0  Average loss: 82534.2334
====> Epoch:

====> Epoch: 231 Alignment: 0.9988084945277262
====> Epoch: 231 Convclust lambda: 0  Average loss: 116048.1342
====> Epoch: 232 Alignment: 0.9980497773190053
====> Epoch: 232 Convclust lambda: 0  Average loss: 116548.3417
====> Epoch: 233 Alignment: 0.9973725539868257
====> Epoch: 233 Convclust lambda: 0  Average loss: 117048.5491
====> Epoch: 234 Alignment: 0.9983753799418292
====> Epoch: 234 Convclust lambda: 0  Average loss: 117548.7565
====> Epoch: 235 Alignment: 0.9995288096007096
====> Epoch: 235 Convclust lambda: 0  Average loss: 118048.9640
====> Epoch: 236 Alignment: 0.9992793810595744
====> Epoch: 236 Convclust lambda: 0  Average loss: 118549.1715
====> Epoch: 237 Alignment: 0.9989677897895903
====> Epoch: 237 Convclust lambda: 0  Average loss: 119049.3790
====> Epoch: 238 Alignment: 0.9958319750498299
====> Epoch: 238 Convclust lambda: 0  Average loss: 119549.5865
====> Epoch: 239 Alignment: 0.9994750158031096
====> Epoch: 239 Convclust lambda: 0  Average loss: 120049.7939
=

====> Epoch: 306 Alignment: 0.9982042949766785
====> Epoch: 306 Convclust lambda: 0  Average loss: 153563.6944
====> Epoch: 307 Alignment: 0.9956265400652443
====> Epoch: 307 Convclust lambda: 0  Average loss: 154063.9019
====> Epoch: 308 Alignment: 0.9948778572627474
====> Epoch: 308 Convclust lambda: 0  Average loss: 154564.1093
====> Epoch: 309 Alignment: 0.9988054018964543
====> Epoch: 309 Convclust lambda: 0  Average loss: 155064.3168
====> Epoch: 310 Alignment: 0.9958395183153507
====> Epoch: 310 Convclust lambda: 0  Average loss: 155564.5243
====> Epoch: 311 Alignment: 0.9962265929983513
====> Epoch: 311 Convclust lambda: 0  Average loss: 156064.7318
====> Epoch: 312 Alignment: 0.9995005213611077
====> Epoch: 312 Convclust lambda: 0  Average loss: 156564.9393
====> Epoch: 313 Alignment: 0.9981825717868471
====> Epoch: 313 Convclust lambda: 0  Average loss: 157065.1468
====> Epoch: 314 Alignment: 0.9996015240993485
====> Epoch: 314 Convclust lambda: 0  Average loss: 157565.3542
=

====> Epoch: 381 Alignment: 0.9998139689015273
====> Epoch: 381 Convclust lambda: 0  Average loss: 191079.2549
====> Epoch: 382 Alignment: 0.9977757988725608
====> Epoch: 382 Convclust lambda: 0  Average loss: 191579.4623
====> Epoch: 383 Alignment: 0.9982262901920294
====> Epoch: 383 Convclust lambda: 0  Average loss: 192079.6697
====> Epoch: 384 Alignment: 0.9981662748153921
====> Epoch: 384 Convclust lambda: 0  Average loss: 192579.8772
====> Epoch: 385 Alignment: 0.9957399342676765
====> Epoch: 385 Convclust lambda: 0  Average loss: 193080.0847
====> Epoch: 386 Alignment: 0.9920155079719448
====> Epoch: 386 Convclust lambda: 0  Average loss: 193580.2922
====> Epoch: 387 Alignment: 0.9990142418069197
====> Epoch: 387 Convclust lambda: 0  Average loss: 194080.4996
====> Epoch: 388 Alignment: 0.9948698324214943
====> Epoch: 388 Convclust lambda: 0  Average loss: 194580.7071
====> Epoch: 389 Alignment: 0.9995223009379565
====> Epoch: 389 Convclust lambda: 0  Average loss: 195080.9146
=

====> Epoch: 455 Alignment: 0.9992884684969554
====> Epoch: 455 Convclust lambda: 0  Average loss: 228094.6078
====> Epoch: 456 Alignment: 0.9992633163949133
====> Epoch: 456 Convclust lambda: 0  Average loss: 228594.8152
====> Epoch: 457 Alignment: 0.9976006585330535
====> Epoch: 457 Convclust lambda: 0  Average loss: 229095.0227
====> Epoch: 458 Alignment: 0.9968951949873158
====> Epoch: 458 Convclust lambda: 0  Average loss: 229595.2302
====> Epoch: 459 Alignment: 0.9993557598000912
====> Epoch: 459 Convclust lambda: 0  Average loss: 230095.4377
====> Epoch: 460 Alignment: 0.9986333952221894
====> Epoch: 460 Convclust lambda: 0  Average loss: 230595.6452
====> Epoch: 461 Alignment: 0.9959393170526922
====> Epoch: 461 Convclust lambda: 0  Average loss: 231095.8527
====> Epoch: 462 Alignment: 0.9964983405115468
====> Epoch: 462 Convclust lambda: 0  Average loss: 231596.0601
====> Epoch: 463 Alignment: 0.9995981055206705
====> Epoch: 463 Convclust lambda: 0  Average loss: 232096.2675
=

====> Epoch: 531 Alignment: 0.9967177270720997
====> Epoch: 531 Convclust lambda: 0  Average loss: 266110.3754
====> Epoch: 532 Alignment: 0.9991167902381256
====> Epoch: 532 Convclust lambda: 0  Average loss: 266610.5829
====> Epoch: 533 Alignment: 0.9985550774946461
====> Epoch: 533 Convclust lambda: 0  Average loss: 267110.7904
====> Epoch: 534 Alignment: 0.9982314476223731
====> Epoch: 534 Convclust lambda: 0  Average loss: 267610.9979
====> Epoch: 535 Alignment: 0.9994409998378742
====> Epoch: 535 Convclust lambda: 0  Average loss: 268111.2053
====> Epoch: 536 Alignment: 0.9986683027597487
====> Epoch: 536 Convclust lambda: 0  Average loss: 268611.4128
====> Epoch: 537 Alignment: 0.9966921473682104
====> Epoch: 537 Convclust lambda: 0  Average loss: 269111.6202
====> Epoch: 538 Alignment: 0.9988611364749791
====> Epoch: 538 Convclust lambda: 0  Average loss: 269611.8277
====> Epoch: 539 Alignment: 0.9979736106425433
====> Epoch: 539 Convclust lambda: 0  Average loss: 270112.0352
=

====> Epoch: 607 Alignment: 0.9996161171405009
====> Epoch: 607 Convclust lambda: 0  Average loss: 304126.1434
====> Epoch: 608 Alignment: 0.9959551956346796
====> Epoch: 608 Convclust lambda: 0  Average loss: 304626.3509
====> Epoch: 609 Alignment: 0.9988342289319128
====> Epoch: 609 Convclust lambda: 0  Average loss: 305126.5582
====> Epoch: 610 Alignment: 0.9966424396446789
====> Epoch: 610 Convclust lambda: 0  Average loss: 305626.7657
====> Epoch: 611 Alignment: 0.9999095715443764
====> Epoch: 611 Convclust lambda: 0  Average loss: 306126.9731
====> Epoch: 612 Alignment: 0.9983975573933912
====> Epoch: 612 Convclust lambda: 0  Average loss: 306627.1805
====> Epoch: 613 Alignment: 0.9986110887214144
====> Epoch: 613 Convclust lambda: 0  Average loss: 307127.3879
====> Epoch: 614 Alignment: 0.9995381592455267
====> Epoch: 614 Convclust lambda: 0  Average loss: 307627.5954
====> Epoch: 615 Alignment: 0.9988077850194963
====> Epoch: 615 Convclust lambda: 0  Average loss: 308127.8029
=

====> Epoch: 681 Alignment: 0.9991195895750791
====> Epoch: 681 Convclust lambda: 0  Average loss: 341141.4963
====> Epoch: 682 Alignment: 0.9989785590351425
====> Epoch: 682 Convclust lambda: 0  Average loss: 341641.7037
====> Epoch: 683 Alignment: 0.9994221785953772
====> Epoch: 683 Convclust lambda: 0  Average loss: 342141.9112
====> Epoch: 684 Alignment: 0.997494903222828
====> Epoch: 684 Convclust lambda: 0  Average loss: 342642.1187
====> Epoch: 685 Alignment: 0.9990025715199347
====> Epoch: 685 Convclust lambda: 0  Average loss: 343142.3262
====> Epoch: 686 Alignment: 0.997607221708575
====> Epoch: 686 Convclust lambda: 0  Average loss: 343642.5337
====> Epoch: 687 Alignment: 0.9993969240574512
====> Epoch: 687 Convclust lambda: 0  Average loss: 344142.7411
====> Epoch: 688 Alignment: 0.9995101524887532
====> Epoch: 688 Convclust lambda: 0  Average loss: 344642.9485
====> Epoch: 689 Alignment: 0.9990263874159988
====> Epoch: 689 Convclust lambda: 0  Average loss: 345143.1560
===

====> Epoch: 755 Alignment: 0.9998177315283867
====> Epoch: 755 Convclust lambda: 0  Average loss: 378156.8495
====> Epoch: 756 Alignment: 0.9973638322889204
====> Epoch: 756 Convclust lambda: 0  Average loss: 378657.0570
====> Epoch: 757 Alignment: 0.9996475554513786
====> Epoch: 757 Convclust lambda: 0  Average loss: 379157.2645
====> Epoch: 758 Alignment: 0.997126848264452
====> Epoch: 758 Convclust lambda: 0  Average loss: 379657.4719
====> Epoch: 759 Alignment: 0.9960068414861228
====> Epoch: 759 Convclust lambda: 0  Average loss: 380157.6794
====> Epoch: 760 Alignment: 0.998152330179181
====> Epoch: 760 Convclust lambda: 0  Average loss: 380657.8868
====> Epoch: 761 Alignment: 0.9982369277477539
====> Epoch: 761 Convclust lambda: 0  Average loss: 381158.0943
====> Epoch: 762 Alignment: 0.9964619671448938
====> Epoch: 762 Convclust lambda: 0  Average loss: 381658.3018
====> Epoch: 763 Alignment: 0.9985142320424365
====> Epoch: 763 Convclust lambda: 0  Average loss: 382158.5093
===

====> Epoch: 829 Alignment: 0.9953164572626128
====> Epoch: 829 Convclust lambda: 0  Average loss: 415172.2027
====> Epoch: 830 Alignment: 0.9986238204230616
====> Epoch: 830 Convclust lambda: 0  Average loss: 415672.4102
====> Epoch: 831 Alignment: 0.9937877000683005
====> Epoch: 831 Convclust lambda: 0  Average loss: 416172.6177
====> Epoch: 832 Alignment: 0.9945872635564846
====> Epoch: 832 Convclust lambda: 0  Average loss: 416672.8252
====> Epoch: 833 Alignment: 0.9996263459096573
====> Epoch: 833 Convclust lambda: 0  Average loss: 417173.0327
====> Epoch: 834 Alignment: 0.9990891535982827
====> Epoch: 834 Convclust lambda: 0  Average loss: 417673.2402
====> Epoch: 835 Alignment: 0.9983544881380354
====> Epoch: 835 Convclust lambda: 0  Average loss: 418173.4477
====> Epoch: 836 Alignment: 0.9995916109016943
====> Epoch: 836 Convclust lambda: 0  Average loss: 418673.6552
====> Epoch: 837 Alignment: 0.99915096487269
====> Epoch: 837 Convclust lambda: 0  Average loss: 419173.8626
===

====> Epoch: 904 Alignment: 0.9998787084671341
====> Epoch: 904 Convclust lambda: 0  Average loss: 452687.7633
====> Epoch: 905 Alignment: 0.9985045606110412
====> Epoch: 905 Convclust lambda: 0  Average loss: 453187.9708
====> Epoch: 906 Alignment: 0.9988731219132583
====> Epoch: 906 Convclust lambda: 0  Average loss: 453688.1783
====> Epoch: 907 Alignment: 0.9956006380038795
====> Epoch: 907 Convclust lambda: 0  Average loss: 454188.3858
====> Epoch: 908 Alignment: 0.9977922093727697
====> Epoch: 908 Convclust lambda: 0  Average loss: 454688.5932
====> Epoch: 909 Alignment: 0.9983203008968318
====> Epoch: 909 Convclust lambda: 0  Average loss: 455188.8007
====> Epoch: 910 Alignment: 0.9984946711652039
====> Epoch: 910 Convclust lambda: 0  Average loss: 455689.0082
====> Epoch: 911 Alignment: 0.9970055687324386
====> Epoch: 911 Convclust lambda: 0  Average loss: 456189.2157
====> Epoch: 912 Alignment: 0.9995167715988423
====> Epoch: 912 Convclust lambda: 0  Average loss: 456689.4232
=

====> Epoch: 978 Alignment: 0.996009355727471
====> Epoch: 978 Convclust lambda: 0  Average loss: 489703.1159
====> Epoch: 979 Alignment: 0.9992602474493448
====> Epoch: 979 Convclust lambda: 0  Average loss: 490203.3234
====> Epoch: 980 Alignment: 0.9977933704122578
====> Epoch: 980 Convclust lambda: 0  Average loss: 490703.5309
====> Epoch: 981 Alignment: 0.9990572711641068
====> Epoch: 981 Convclust lambda: 0  Average loss: 491203.7384
====> Epoch: 982 Alignment: 0.9977761813936916
====> Epoch: 982 Convclust lambda: 0  Average loss: 491703.9459
====> Epoch: 983 Alignment: 0.9990363796953087
====> Epoch: 983 Convclust lambda: 0  Average loss: 492204.1534
====> Epoch: 984 Alignment: 0.998498951724009
====> Epoch: 984 Convclust lambda: 0  Average loss: 492704.3608
====> Epoch: 985 Alignment: 0.9998721838781786
====> Epoch: 985 Convclust lambda: 0  Average loss: 493204.5683
====> Epoch: 986 Alignment: 0.9980842648641582
====> Epoch: 986 Convclust lambda: 0  Average loss: 493704.7757
===

In [None]:
class LinearAE(nn.Module):
    def __init__(self,
                 input_dim, hidden_dim, init_scale=0.001,
                 weight_reg_type=None, l2_reg_list=None):
        super(LinearAE, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)

        self.weight_reg_type = weight_reg_type
        self.l2_reg_scalar = None
        self.l2_reg_list = l2_reg_list

        self.encoder.weight.data.normal_(0.0, init_scale)
        self.decoder.weight.data.normal_(0.0, init_scale)

        # configure regularization parameters

        assert self.weight_reg_type is None or isinstance(self.l2_reg_list, list), \
            "l2_reg_list must be a list if weight_reg_type is not None"

        assert self.l2_reg_list is None or len(self.l2_reg_list) == hidden_dim, \
            "Length of l2_reg_list must match latent dimension"

        if weight_reg_type in ("uniform_product", "uniform_sum"):
            self.l2_reg_scalar = l2_reg_list[0] ** 2    # more efficient to use scalar than diag_weights

        elif weight_reg_type == "non_uniform_sum":
            self.reg_weights = torch.tensor(
                np.array(self.l2_reg_list).astype(np.float32)
            )
            self.diag_weights = nn.Parameter(torch.diag(self.reg_weights), requires_grad=False)

    def forward(self, x):
        return self.get_reconstruction_loss(x) + self._get_reg_loss()

    def compute_trace_norm(self):
        """
        Computes the trace norm of the autoencoder, as well as decoder and encoder individually
        :return: trace_norm(W2W1), trace_norm(W1), trace_norm(W2)
        """
        return torch.matmul(self.decoder.weight, self.encoder.weight).norm(p='nuc'), \
               self.encoder.weight.norm(p='nuc'), \
               self.decoder.weight.norm(p='nuc'),

    def get_reconstruction_loss(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)

        recon_loss = torch.sum((x - recon) ** 2) / len(x)
        return recon_loss

    def get_reg_weights_np(self):
        if self.weight_reg_type is None:
            return np.zeros(self.hidden_dim)
        return np.array(self.l2_reg_list)

    def _get_reg_loss(self):
        # Standard L2 regularization, applied to W2W1 (product loss)
        if self.weight_reg_type == 'uniform_product':
            return self.l2_reg_scalar * (torch.norm(torch.matmul(self.decoder.weight, self.encoder.weight)) ** 2)

        # Standard L2 regularization for encoder and decoder separately (sum loss)
        elif self.weight_reg_type == 'uniform_sum':
            # regularize both encoder and decoder
            return self.l2_reg_scalar * (torch.norm(self.encoder.weight) ** 2 + torch.norm(self.decoder.weight) ** 2)

        # non-uniform sum
        elif self.weight_reg_type == 'non_uniform_sum':
            return torch.norm(self.diag_weights @ self.encoder.weight) ** 2 \
                   + torch.norm(self.decoder.weight @ self.diag_weights) ** 2

        # Do not apply regularization
        elif self.weight_reg_type is None:
            return 0.0

        else:
            raise ValueError("weight_reg_type should be one of (uniform_product, uniform_sum, non_uniform_sum, None)")