In [None]:
import os
import torch
import sys
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import numpy as np

import warnings
warnings.filterwarnings('ignore')

from fastprogress import master_bar, progress_bar

In [None]:
# Run index: picks the CUDA device ('cuda:<run>') and tags saved checkpoints.
run: int = 1

In [None]:
# Directory for the torch.save() checkpoints written by the sweep; exist_ok
# makes the cell idempotent and removes the exists-then-mkdir race.
os.makedirs('./saved_models', exist_ok=True)

In [None]:
def to_img(x):
    """Undo the [-1, 1] pixel normalisation and reshape rows into images.

    Args:
        x: tensor whose first dimension indexes samples, with 28 * 28 values
           per sample in normalised space.

    Returns:
        Tensor of shape ``(x.size(0), 1, 28, 28)`` with values clamped to [0, 1].
    """
    n_images = x.size(0)
    rescaled = (x + 1) * 0.5  # maps [-1, 1] back onto [0, 1]
    return rescaled.clamp(0, 1).view(n_images, 1, 28, 28)

In [None]:
# ToTensor gives pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the range of the decoder's tanh output and the input range to_img() inverts.
img_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
)

In [None]:
# Training hyper-parameters shared by every model in the lambda sweep.
num_epochs, batch_size, learning_rate = 100, 512, 1e-3

In [None]:
# download=True is a no-op when the files already exist under ../data and
# avoids the "Dataset not found" error on a fresh machine.
train_dataset = MNIST('../data', train=True, transform=img_transform, download=True)
test_dataset = MNIST('../data', train=False, transform=img_transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8)
# No need to shuffle for evaluation; a fixed order keeps the test metric reproducible.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8)

In [None]:
class DAM(nn.Module):
    """Discriminative Amplitude Modulator layer (1-D).

    Multiplies its input element-wise by the gate
    ``relu(tanh(alpha**2 * (mu + beta)))``, broadcast over the last dimension:

    * ``mu``    -- fixed ramp ``arange(in_dim) / in_dim * mu_scale`` (not trained),
    * ``beta``  -- single trainable scalar offset, initialised to 1,
    * ``alpha`` -- single fixed scalar gain, initialised to 1 (not trained).

    Lowering ``beta`` drives the gate of the smallest-``mu`` units to exactly
    zero (relu of a non-positive tanh), so ``mask() > 0`` counts active units.

    Args:
        in_dim: number of gated units (size of the input's last dimension).
        mu_scale: upper end of the ``mu`` ramp; defaults to 5.
    """

    def __init__(self, in_dim, mu_scale=5):
        super(DAM, self).__init__()
        self.in_dim = in_dim

        # Assigning nn.Parameter attributes registers them on the module, so
        # no explicit register_parameter() calls are needed.
        self.mu = nn.Parameter(torch.arange(0, in_dim).float() / in_dim * mu_scale,
                               requires_grad=False)
        self.beta = nn.Parameter(torch.ones(1), requires_grad=True)
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=False)

        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``x`` scaled element-wise by the current gate ``mask()``."""
        return x * self.mask()

    def mask(self):
        """Return the gate vector of shape ``(in_dim,)`` with values in [0, 1)."""
        return self.relu(self.tanh((self.alpha ** 2) * (self.mu + self.beta)))

In [None]:
class AEnc(nn.Module):
    """Fully connected autoencoder with a DAM gate on the latent code.

    Encoder: 784 -> 128 -> 64 -> 32 -> ``init_dim`` (ReLU after the first three
    layers, none after the last). The code is gated by ``DAM(init_dim)``. The
    decoder mirrors the encoder back to 784, with ReLU after the first three
    layers and tanh after the last.

    ``forward`` takes rows of 28 * 28 values and returns
    ``(reconstruction, gated_code)``.
    """

    def __init__(self, init_dim):
        super(AEnc, self).__init__()
        widths = [28 * 28, 128, 64, 32, init_dim]
        # Registration order: enc_layer_1..4, dam_layer, dec_layer_1..4, relu,
        # tanh. The attribute names are the state_dict keys of saved checkpoints.
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, 'enc_layer_%d' % idx, nn.Linear(fan_in, fan_out))
        self.dam_layer = DAM(init_dim)
        mirrored = widths[::-1]
        for idx, (fan_in, fan_out) in enumerate(zip(mirrored, mirrored[1:]), start=1):
            setattr(self, 'dec_layer_%d' % idx, nn.Linear(fan_in, fan_out))
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        hidden = x
        for layer in (self.enc_layer_1, self.enc_layer_2, self.enc_layer_3):
            hidden = self.relu(layer(hidden))
        code = self.dam_layer(self.enc_layer_4(hidden))

        recon = code
        for layer in (self.dec_layer_1, self.dec_layer_2, self.dec_layer_3):
            recon = self.relu(layer(recon))
        return self.tanh(self.dec_layer_4(recon)), code

In [None]:
if torch.cuda.is_available():
    # Wrap the run index onto the GPUs actually present: a bare 'cuda:<run>'
    # fails with "invalid device ordinal" when fewer than run + 1 GPUs exist.
    cuda_ = 'cuda:' + str(run % torch.cuda.device_count())
    device = cuda_
else:
    cuda_ = 'cuda:' + str(run)
    device = 'cpu'
print(device)

In [None]:
init_dims = [50]
# Regularisation strengths to sweep, in increasing order.
lambda_rs = [0.01, 0.05, 0.1, 0.5, 1., 2., 3., 4.]+np.arange(5., 10., 0.5).tolist()

verbose = []   # model tags "<init_dim>x<lambda_r>", one per trained model
test_mse = []  # test reconstruction MSE per model (NaN if the sweep stopped early)
btl_dim = []   # number of latent units whose DAM mask is > 0, per model


def _dam_mask(model):
    """Return ``model.dam_layer.mask()`` detached, as a NumPy array on the CPU."""
    return model.dam_layer.mask().detach().cpu().numpy()


def _evaluate_mse(model, loader, loss_fn, dev):
    """Return the sample-weighted mean of ``loss_fn`` over every batch in ``loader``.

    Runs under ``torch.no_grad()`` so no autograd graph is built. Each batch's
    (mean-reduced) loss is weighted by its batch size, so a short final batch
    does not skew the average. Returns NaN for an empty loader.
    """
    model.eval()
    total, n_samples = 0.0, 0
    with torch.no_grad():
        for batch, _ in loader:
            batch = batch.view(batch.size(0), -1).to(dev)
            recon, _ = model(batch)
            total += loss_fn(recon, batch).item() * batch.size(0)
            n_samples += batch.size(0)
    return total / n_samples if n_samples else float('nan')


for init_dim in init_dims:
    for lambda_r in lambda_rs:
        print('####################################################')
        print('Run: %d INIT_DIM: %d \t LAMBDA_REG: %f' %(run, init_dim, lambda_r))
        verbose.append(str(init_dim)+'x'+str(lambda_r))
        print('####################################################')

        net = AEnc(init_dim).to(device)

        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

        mb = master_bar(range(1))
        pb = progress_bar(range(num_epochs), parent=mb)
        mb.names = ['Embd dim']

        # Live plot of the DAM mask value of each latent unit.
        y1 = _dam_mask(net)
        x_n = np.arange(len(y1))
        x_bounds = [0, len(y1) + 1]
        y_bounds = [0, 1]

        graphs = [[x_n, y1],]
        mb.update_graph(graphs, x_bounds, y_bounds)
        print("[Epoch\tloss\tMSE\tReg\tbeta_1]")

        for _ in mb:
            for epoch in range(num_epochs):
                for img, _ in train_loader:
                    img = img.view(img.size(0), -1).to(device)

                    # ===================forward=====================
                    output, _ = net(img)

                    # The mask grows monotonically with beta, so penalising beta
                    # pushes units towards a zero gate (a smaller bottleneck).
                    beta_1 = net.dam_layer.beta
                    loss_gate = torch.mean(beta_1)
                    loss_data = criterion(output, img)
                    loss_reg = lambda_r * loss_gate

                    loss = loss_data + loss_reg

                    # ===================backward====================
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                if epoch % 2 == 0:
                    y1 = _dam_mask(net)
                    graphs = [[x_n, y1],]
                    mb.update_graph(graphs, x_bounds, y_bounds)

                    # Loss values are those of the epoch's last training batch.
                    sys.stdout.write("\r[%d\t%.5e\t%.5e\t%.5e\t%.3f]" % (
                        epoch, loss.item(), loss_data.item(), loss_reg.item(),
                        net.dam_layer.beta.item()))
        # Terminate the carriage-return progress line before the summary prints.
        sys.stdout.write("\n")

        btl_d = (_dam_mask(net) > 0).sum()
        btl_dim.append(btl_d)

        if btl_d == 0:
            # Every latent unit is gated off: stop sweeping lambda for this
            # init_dim without evaluating or saving the model. Record NaN so
            # test_mse stays index-aligned with verbose and btl_dim.
            test_mse.append(float('nan'))
            break

        test_mse.append(_evaluate_mse(net, test_loader, criterion, device))

        print('MODEL:', verbose[-1])
        print('BTLNK_DIM:', btl_dim[-1])
        print('TEST_MSE:', test_mse[-1])

        torch.save(net.state_dict(), 'saved_models/dam_model_'+verbose[-1]+'_run_'+str(run)+'.pt')

