In [None]:
import sys
import os
import math
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

from fastprogress import master_bar, progress_bar

# Data Creation

In [None]:
def get_data(k, d, N, return_rank=False):
    """Generate synthetic data by pushing Gaussian noise through a random QResLayer.

    ``N`` points are drawn from a zero-mean, identity-covariance normal in
    ``k`` dimensions and passed through a newly constructed (untrained)
    ``QResLayer(k, d)`` with gradients disabled.

    Args:
        k: dimension of the Gaussian source samples.
        d: number of output features of the generating layer.
        N: number of samples to draw.
        return_rank: when True, also return the numerical rank of the
            generated matrix as computed by ``numpy.linalg.matrix_rank``.

    Returns:
        The generated float tensor, or ``(tensor, rank)`` when
        ``return_rank`` is True.
    """
    samples = np.random.multivariate_normal(np.zeros(k), np.eye(k), N)
    samples = torch.tensor(samples).float()

    # A fresh layer is built on every call, so each call uses new random weights.
    generator = QResLayer(k, d)
    with torch.no_grad():
        mat = generator(samples)

    if not return_rank:
        return mat

    rank = np.linalg.matrix_rank(mat.detach().cpu().numpy())
    return mat, rank

# configs
k = 20    # dimension of the Gaussian source samples passed to get_data
N = 5000  # number of samples to generate
d = 100   # width of the generated data (and of the autoencoder input)

# Checkpoints of every training run are written below this directory.
result_path = './results/multi_gaussian/var_%d/quad/' % k

# exist_ok=True replaces the separate os.path.exists() test, which could race
# with another process creating the directory between the check and the call.
os.makedirs(result_path, exist_ok=True)

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

# Architecture

In [None]:
# Adapted from: 
# https://github.com/jayroxis/qres/blob/master/QRes/discussion/sin_wave/models.py

class QResLayer(nn.Module):
    """Quadratic residual layer.

    For an input ``x`` the layer returns

        (x @ weight_1.T) * (x @ weight_2.T) + x @ weight_1.T + bias

    i.e. an element-wise product of two bias-free linear maps plus the first
    linear map (with the optional bias).

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        bias: when False the additive bias is omitted and ``self.bias`` is None.
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight_1: torch.Tensor
    weight_2: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super(QResLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Storage is allocated uninitialised here; reset_parameters() fills it in.
        self.weight_1 = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_2 = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise both weight matrices and, if present, the bias."""
        nn.init.kaiming_uniform_(self.weight_1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight_2, a=math.sqrt(5))
        if self.bias is not None:
            # Bias is drawn uniformly from +/- 1/sqrt(fan_in) of weight_1.
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight_1)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the quadratic residual map to ``input``."""
        h_1 = F.linear(input, self.weight_1, bias=None)
        h_2 = F.linear(input, self.weight_2, bias=None)
        # Reuse h_1 for the linear term instead of multiplying by weight_1 again.
        linear_term = h_1 if self.bias is None else h_1 + self.bias
        return torch.add(torch.mul(h_1, h_2), linear_term)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

# Expose the layer under the torch.nn namespace; other cells refer to it as nn.QResLayer.
nn.QResLayer = QResLayer

In [None]:
class DAM(nn.Module):
    """ Discriminative Amplitude Modulator Layer (1-D)

    Scales each of the ``in_dim`` input features by a gate

        mask = relu(tanh(alpha ** 2 * (mu + beta)))

    where ``mu`` is a fixed ramp ``5 * i / in_dim`` (i = 0 .. in_dim - 1),
    ``alpha`` is fixed at 1 and ``beta`` (initialised to 5) is the only
    trainable value. Entries with ``mu + beta <= 0`` are gated to exactly
    zero, so lowering ``beta`` switches features off starting from index 0.

    Args:
        in_dim: number of features along the last dimension of the input.
    """
    def __init__(self, in_dim):
        super(DAM, self).__init__()
        self.in_dim = in_dim

        # Assigning an nn.Parameter to an attribute already registers it on the
        # module, so no explicit register_parameter() calls are needed.
        mu = torch.arange(self.in_dim).float() / self.in_dim * 5.0
        self.mu = nn.Parameter(mu, requires_grad=False)
        self.beta = nn.Parameter(torch.ones(1) * 5, requires_grad=True)
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=False)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``x`` multiplied element-wise (with broadcasting) by the mask."""
        return x * self.mask()

    def mask(self):
        """Return the current gate vector of length ``in_dim`` with values in [0, 1)."""
        return self.relu(self.tanh((self.alpha ** 2) * (self.mu + self.beta)))

In [None]:
class AEnc(torch.nn.Module):
    """Autoencoder with a DAM-gated bottleneck and a quadratic-residual decoder.

    The encoder is three linear layers (LeakyReLU after the first two), the
    bottleneck is gated by a ``DAM`` layer, and the decoder is a single
    ``nn.QResLayer`` mapping the gated code back to the input width.

    Args:
        num_neuron: number of input (and reconstructed output) features.
        hidden_dim: width of the two hidden encoder layers (default 256).
        latent_dim: width of the bottleneck before gating (default 50).
    """
    def __init__(self, num_neuron, hidden_dim=256, latent_dim=50):
        super(AEnc, self).__init__()

        self.num_neuron = num_neuron
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.enc_layer_1 = nn.Linear(self.num_neuron, hidden_dim)
        self.enc_layer_2 = nn.Linear(hidden_dim, hidden_dim)
        self.enc_layer_3 = nn.Linear(hidden_dim, latent_dim)

        self.dam_layer = DAM(latent_dim)
        self.dec_layer = nn.QResLayer(latent_dim, self.num_neuron)
        # Only leakyrelu is used in forward(); the other activations are kept so
        # existing code that accesses these attributes keeps working.
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        self.leakyrelu = nn.LeakyReLU()
        self.elu = nn.ELU()
        # torch.sin is used directly rather than a lambda wrapper: a lambda
        # attribute prevents the module object from being pickled.
        self.sin = torch.sin

    def forward(self, x):
        """Encode, gate and decode ``x``.

        Returns:
            A tuple ``(x_r, h)`` of the reconstruction and the gated bottleneck code.
        """
        out = self.leakyrelu(self.enc_layer_1(x))
        out = self.leakyrelu(self.enc_layer_2(out))
        out = self.enc_layer_3(out)
        h = self.dam_layer(out)
        x_r = self.dec_layer(h)
        return x_r, h


# Training Module

In [None]:
# Five independent runs. Each run builds a fresh model, draws a fresh synthetic
# data set, trains full-batch for 5000 steps and saves a checkpoint.
mb = master_bar(range(1, 6))

for run in mb:
    lambda_r = 0.01   # weight of the beta penalty added to the reconstruction loss
    compact_dim = []  # number of non-zero mask entries recorded after every step
    net = AEnc(d).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.01, weight_decay=1e-6)

    pb = progress_bar(range(5000), parent=mb)
    mb.names = ['layer Embd']

    # Plot the initial mask; no autograd graph is needed for plotting.
    with torch.no_grad():
        y1 = net.dam_layer.mask().detach().cpu().numpy()
    x_bounds = [0, len(y1) + 1]
    y_bounds = [0, 1]
    x_n = np.arange(len(y1))
    graphs = [[x_n, y1],]
    mb.update_graph(graphs, x_bounds, y_bounds)
    print("[Epoch\tloss\tMSE\tReg\tbeta_1]")

    x, rank = get_data(k, d, N, return_rank=True)
    x = x.float().to(device)

    for epoch in pb:
        optimizer.zero_grad()
        x_rc, _ = net(x)
        beta_1 = net.dam_layer.beta
        loss_data = criterion(x_rc, x)
        # Penalising beta drives it down, which gates more bottleneck units to zero.
        loss_reg = lambda_r * beta_1
        loss = loss_data + loss_reg
        loss.backward()

        optimizer.step()

        # Evaluate the updated mask once per step and reuse it below.
        with torch.no_grad():
            mask = net.dam_layer.mask()
        btl_dim = (mask != 0).sum().item()
        compact_dim.append(btl_dim)

        if epoch % 10 == 0:
            y1 = mask.detach().cpu().numpy()
            graphs = [[x_n, y1],]
            mb.update_graph(graphs, x_bounds, y_bounds)

        # One value per header column: epoch, total loss, MSE, penalty term, beta.
        # The three loss values are from before the optimizer step; beta is the updated value.
        sys.stdout.write("\r[%d\t%.5e\t%.5e\t%.5e\t%.3f]" % (
            epoch, loss.item(), loss_data.item(), loss_reg.item(),
            net.dam_layer.beta.item()))

    print('\nFinal Embedding dim:', btl_dim)

    # 'btn_dim' holds the per-step history of non-zero mask counts.
    torch.save({
        'state_dict': net.state_dict(),
        'btn_dim': compact_dim,
        'rank': rank
    }, os.path.join(result_path, 'run%d.pt' % run))