In [None]:
import mxnet as mx
from mxnet import gluon as g
from mxnet import nd
import numpy as np
from mxnet import autograd as ag
from matplotlib import pyplot as plt
import os
import time

In [None]:
# Device used for the parameters and for every batch (see `as_in_context(ctx)`
# below). Use the first GPU when MXNet reports one, otherwise fall back to the
# CPU so the notebook still runs on machines without a GPU.
ctx = mx.gpu() if mx.context.num_gpus() > 0 else mx.cpu()

In [None]:
# Number of images per mini-batch; passed to both record iterators and to
# the trainer's step() call.
batch_size = 128

# Image size fed to the network: the width is seven times the height.
height = 16
width = height * 7

In [None]:
# Per-channel mean/std (on the 0-255 pixel scale) applied by both iterators.
_normalization = dict(
    mean_r=123.68, mean_g=116.28, mean_b=103.53,
    std_r=58.395, std_g=57.12, std_b=57.375,
)
# Channels-first shape of a single sample.
_sample_shape = (3, height, width)

# Training images, read from dataset/train.rec.
data_iter = mx.io.ImageRecordIter(
    path_imgrec=os.path.join('dataset', 'train.rec'),
    data_shape=_sample_shape,
    shuffle=True,
    batch_size=batch_size,
    **_normalization
)

# Validation images, read from dataset/val.rec with the same settings.
data_iter_val = mx.io.ImageRecordIter(
    path_imgrec=os.path.join('dataset', 'val.rec'),
    data_shape=_sample_shape,
    shuffle=True,
    batch_size=batch_size,
    **_normalization
)

In [None]:
class autoencoder(g.nn.HybridBlock):
    """Small convolutional autoencoder.

    Encoder: two stages of (3x3 conv, 32 filters, ReLU) followed by a 2x2
    max-pool with stride 2.

    Decoder: three 3x3 convolutions. The outputs of the first two (32 filters,
    ReLU) are each upsampled by a factor of 2 with nearest-neighbour
    interpolation; the last one maps to 3 channels through ``tanh``.

    The forward pass is written as ``hybrid_forward`` (MXNet 1.x Gluon API) so
    that ``hybridize()`` covers the whole network rather than only the child
    layers.
    """

    def __init__(self):
        super(autoencoder, self).__init__()
        with self.name_scope():
            self.encoder = g.nn.HybridSequential('encoder_')
            with self.encoder.name_scope():
                self.encoder.add(g.nn.Conv2D(32, 3, padding=1, activation='relu'))
                self.encoder.add(g.nn.MaxPool2D(2, 2))
                self.encoder.add(g.nn.Conv2D(32, 3, padding=1, activation='relu'))
                self.encoder.add(g.nn.MaxPool2D(2, 2))

            self.decoder = g.nn.HybridSequential('decoder_')
            with self.decoder.name_scope():
                self.decoder.add(g.nn.Conv2D(32, 3, padding=1, activation='relu'))
                self.decoder.add(g.nn.Conv2D(32, 3, padding=1, activation='relu'))
                self.decoder.add(g.nn.Conv2D(3, 3,  padding=1,  activation='tanh'))

    def hybrid_forward(self, F, x):
        """Encode ``x`` and decode it again.

        Parameters
        ----------
        F : module
            ``mxnet.ndarray`` or ``mxnet.symbol``, supplied by Gluon depending
            on whether the block runs imperatively or hybridized.
        x : NDArray or Symbol
            Input batch. NOTE(review): the two pool/upsample pairs only give
            back the input's spatial size when height and width are multiples
            of 4 -- confirm for other input sizes.

        Returns
        -------
        NDArray or Symbol
            The 3-channel ``tanh`` output of the decoder multiplied by 4.
        """
        x = self.encoder(x)
        num_layers = len(self.decoder)
        for idx in range(num_layers):
            x = self.decoder[idx](x)
            # Upsample after every decoder layer except the final one, undoing
            # the two stride-2 poolings of the encoder.
            if idx < num_layers - 1:
                x = F.UpSampling(x, scale=2, sample_type='nearest')
        # tanh is bounded to (-1, 1); the factor 4 widens the output range,
        # presumably to cover the mean/std-normalized inputs -- TODO confirm.
        return x*4

In [None]:
# Build the network, enable hybridization and initialize its weights on ctx.
model = autoencoder()
model.hybridize()
model.collect_params().initialize(mx.init.Xavier(magnitude=2), ctx=ctx)

# Take the first batch of the training iterator and stop right away.
for batch in data_iter:
    break

# One forward pass as a sanity check on the output shape. The Conv2D layers
# are declared without input channel counts, so this pass presumably also
# finalizes their parameter shapes before saving -- TODO confirm.
first_input = batch.data[0].as_in_context(ctx)
first_output = model(first_input)
print(first_output.shape)

# Snapshot of the initial weights; the training cell reloads this file.
model.save_parameters("ae_init.params")

In [None]:
# Reconstruction loss: the training and validation loops apply it to the
# input batch and the model's output for that batch.
criterion = g.loss.L2Loss()

### Optimization

In [None]:
# Mean absolute error between input and reconstruction, tracked separately
# for the training and the validation pass.
train_metric = mx.metric.MAE()
val_metric = mx.metric.MAE()

# Checkpoints are written to this directory after every epoch; create it up
# front so the first save does not fail on a missing folder.
if not os.path.isdir('process'):
    os.makedirs('process')

#for opt in sorted({'Adam','RMSProp','SGD'}):
for opt in sorted({'SGD'}):
#    for lr in sorted({0.1,0.01,0.001},reverse=True):
    for lr in sorted({0.1},reverse=True):
        start_epoch = 0
        epochs = 100
        # Every (optimizer, learning rate) run starts from the same saved
        # initial weights.
        model.load_parameters("ae_init.params")
        optimizer = g.Trainer(model.collect_params(), opt, {'learning_rate': lr, 'wd': 1e-5})
        print('### Optimizer: %s ### Learning Rate: %.3f' % (opt,lr))
        for epoch in range(start_epoch, epochs):
            # train
            train_loss = 0
            train_metric.reset()
            data_iter.reset()
            tic = time.time()
            counter = 0
            for batch in data_iter:
                x = batch.data[0].as_in_context(ctx)
                # Only the forward pass and the loss need to be recorded.
                with ag.record():
                    y = model(x)
                    # (prediction, target) order; the squared error itself is
                    # symmetric in its two arguments.
                    loss = criterion(y, x)
                loss.backward()
                optimizer.step(batch_size, ignore_stale_grad=True)
                # Accumulated outside the recorded scope: this is bookkeeping
                # only and asscalar() blocks until the value is available.
                train_loss += mx.nd.sum(loss).asscalar()
                train_metric.update(x, y)
                counter += 1
            toc = time.time()
            name_train, val_train = train_metric.get()
            # validation
            val_loss = 0
            val_counter = 0
            val_metric.reset()
            data_iter_val.reset()
            for batch in data_iter_val:
                x = batch.data[0].as_in_context(ctx)
                y = model(x)
                loss = criterion(y, x)
                val_loss += mx.nd.sum(loss).asscalar()
                val_metric.update(x, y)
                val_counter += 1
            name_val, val_val = val_metric.get()
            # Each summed loss is divided by the number of samples of its own
            # pass; the validation loss uses the validation batch count, not
            # the training one. Speed is training samples per second.
            print('epoch:%3d;\t train:%.6e;%.6e;val:%.6e;%.6e;\t Speed:%d'
                  %(epoch, train_loss/(counter*batch_size), val_train, val_loss/(val_counter*batch_size), val_val, (counter*batch_size)/(toc-tic)))
            model.save_parameters('process/ae_%s_%.3f_%d.params' % (opt, lr, epoch))

### Visualization

In [None]:
# The validation pass of the training cell runs this iterator to its end, so
# rewind it first; without the reset the loop below yields no batch and no
# image is drawn.
data_iter_val.reset()
for batch in data_iter_val:
    x = batch.data[0].as_in_context(ctx)
    y = model(x)
    # First reconstruction of the batch, copied to the CPU for plotting.
    val_image = y[0,:,:,:].as_in_context(mx.cpu())
    # Channels-first -> channels-last so the per-channel constants broadcast
    # over the last axis. These std/mean values are the iterator's
    # normalization constants divided by 255 (e.g. 58.395 / 255 = 0.229), so
    # this undoes the input normalization on a 0-1 scale.
    val_image = val_image.transpose((1, 2, 0)) * nd.array((0.229, 0.224, 0.225)) + nd.array((0.485, 0.456, 0.406))
    # Back to 8-bit pixel values, clipped to the valid range.
    val_image = (val_image * 255).clip(0, 255)
    val_image = val_image.asnumpy()
    val_image = val_image.astype(np.uint8)
    plt.imshow(val_image)
    plt.show()
    # Only the first batch is visualized.
    break