In [1]:
from __future__ import annotations
import matplotlib.pyplot as plt
from google.colab import drive
from PIL import Image
import numpy as np
import pickle
import os

In [2]:
# Mount Google Drive at /content/drive; later cells read the generator
# parameters and the per-subject data files from paths under this mount point.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
!pip install mxnet-cu101

from typing import Tuple, Union
from mxnet import nd, symbol
from mxnet.gluon.nn import HybridBlock
from mxnet.gluon.parameter import Parameter
from mxnet.initializer import Zero
from mxnet.gluon.nn import Conv2D, HybridSequential, LeakyReLU, Dense
from mxnet import nd, gluon, autograd
import mxnet as mx
from mxnet.io import NDArrayIter



In [4]:
def load_dataset(t, x, batch_size):
    """Build a shuffling ``NDArrayIter`` over paired inputs and targets.

    Both sequences are unpacked and re-stacked along axis 0 before being
    handed to the iterator.

    Args:
        t: Sequence of target arrays; exposed as the label field named "t".
        x: Sequence of input arrays; exposed as the data field named "x".
        batch_size: Number of samples per batch.

    Returns:
        An ``NDArrayIter`` with shuffling enabled.
    """
    inputs = {"x": nd.stack(*x, axis=0)}
    targets = {"t": nd.stack(*t, axis=0)}
    return NDArrayIter(inputs, label=targets, batch_size=batch_size, shuffle=True)

In [5]:
class Linear(HybridSequential):
    """Sequential container holding a single fully connected layer.

    Args:
        n_in: Number of input features of the dense layer.
        n_out: Number of output units of the dense layer.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        # The layer is created inside the name scope so that its parameters
        # are prefixed with this block's name.
        with self.name_scope():
            self.add(Dense(units=n_out, in_units=n_in))


class Pixelnorm(HybridBlock):
    """Scale the input by the reciprocal root-mean-square taken over axis 1.

    Args:
        epsilon: Constant added to the mean square before the reciprocal
            square root is taken.
    """

    def __init__(self, epsilon: float = 1e-8) -> None:
        super().__init__()
        self.epsilon = epsilon

    def hybrid_forward(self, F, x) -> nd:
        """Return ``x / sqrt(mean(x**2, axis=1, keepdims=True) + epsilon)``."""
        mean_square = F.mean(F.square(x), axis=1, keepdims=True)
        inverse_norm = F.rsqrt(mean_square + self.epsilon)
        return x * inverse_norm


class Bias(HybridBlock):
    """Learnable bias that is added along axis 1 of a 4-D input.

    The parameter ``b`` is zero-initialised with the given ``shape`` and is
    broadcast over axes 0, 2 and 3 of the input.

    Args:
        shape: Shape of the bias parameter. The forward pass treats it as a
            vector with one entry per channel, e.g. ``(512,)``.
    """

    def __init__(self, shape: Tuple) -> None:
        super(Bias, self).__init__()
        self.shape = shape
        with self.name_scope():
            self.b = self.params.get("b", init=Zero(), shape=shape)

    def hybrid_forward(self, F, x, b) -> nd:
        """Return ``x`` plus the bias broadcast to ``(1, C, 1, 1)``."""
        # Reshape instead of ``b[None, :, None, None]``: tuple/None indexing is
        # only available on NDArrays, so the indexing form fails as soon as the
        # block is hybridized and ``F`` is the symbol API. For a 1-D bias both
        # forms produce the same (1, C, 1, 1) array.
        b = F.reshape(b, shape=(1, -1, 1, 1))
        return F.broadcast_add(x, b)


class Block(HybridSequential):
    """Upsample by 2x along axes 2 and 3, then apply two conv stages.

    Each stage is a 3x3 convolution (padding 1) followed by ``LeakyReLU(0.2)``
    and ``Pixelnorm``. The first convolution maps ``in_channels`` to
    ``channels``; the second keeps ``channels``.

    Args:
        channels: Number of output channels of both convolutions.
        in_channels: Number of channels of the block's input.
    """

    def __init__(self, channels: int, in_channels: int) -> None:
        super().__init__()
        self.channels = channels
        self.in_channels = in_channels
        with self.name_scope():
            # Children are created in the order conv, activation, norm for
            # each stage, so their indices run 0..5.
            for stage_in_channels in (in_channels, channels):
                self.add(Conv2D(channels, 3, padding=1, in_channels=stage_in_channels))
                self.add(LeakyReLU(0.2))
                self.add(Pixelnorm())

    def hybrid_forward(self, F, x) -> nd:
        """Repeat every element twice along axes 2 and 3, then run all children."""
        for spatial_axis in (2, 3):
            x = F.repeat(x, repeats=2, axis=spatial_axis)
        for child_index in range(len(self)):
            x = self[child_index](x)
        return x

In [6]:
class Generator(HybridSequential):
    """Stack of layers mapping a 512-d latent vector to a 3-channel image.

    Children, by index:
        0-1   Pixelnorm and a bias-free Dense(512 -> 8192), whose output is
              reshaped to (-1, 512, 4, 4) in the forward pass
        2-7   Bias, LeakyReLU, Pixelnorm, 3x3 Conv2D, LeakyReLU, Pixelnorm
        8-15  eight ``Block`` stages, each doubling axes 2 and 3
        16    1x1 Conv2D producing 3 output channels
    """

    def __init__(self) -> None:
        super(Generator, self).__init__()
        with self.name_scope():
            self.add(Pixelnorm())
            self.add(Dense(8192, use_bias=False, in_units=512))
            self.add(Bias((512,)))
            self.add(LeakyReLU(0.2))
            self.add(Pixelnorm())
            self.add(Conv2D(512, 3, padding=1, in_channels=512))
            self.add(LeakyReLU(0.2))
            self.add(Pixelnorm())

            self.add(Block(512, 512))  # child index 8
            self.add(Block(512, 512))
            self.add(Block(512, 512))
            self.add(Block(256, 512))
            self.add(Block(128, 256))
            self.add(Block(64, 128))
            self.add(Block(32, 64))
            self.add(Block(16, 32))  # child index 15
            self.add(Conv2D(3, 1, in_channels=16))

    def hybrid_forward(self, F, x: nd, layer: int) -> nd:
        """Run the network and return the output of child ``layer + 7``.

        Args:
            F: The MXNet API module supplied by Gluon. It is left unannotated,
                as in the other blocks of this file, because it is a module
                and not a type.
            x: Latent input; the Dense child is declared with ``in_units=512``.
            layer: Selects where to stop: the output of the child at index
                ``layer + 7`` is returned. With the 17 children defined above,
                ``layer=9`` yields the final 3-channel output. If no child in
                2..16 has that index, the final output is returned.

        Returns:
            The activation of the selected child.
        """
        # Children 0 and 1 are applied outside the loop because the Dense
        # output has to be reshaped to 4-D before the remaining children run.
        x = F.Reshape(self[1](self[0](x)), (-1, 512, 4, 4))
        for i in range(2, len(self)):
            x = self[i](x)
            if i == layer + 7:
                return x
        return x

In [7]:
# Optimisation settings used by the training loop.
max_epoch = 1500  # upper bound on the number of training epochs per subject
batch_size = 30   # mini-batch size of the training iterator

# Input / output widths of the voxel-to-latent dense layer.
n_vox, n_lat = 4096, 512

In [8]:
# Note that we are using gradient descent to fit the weights of the dense layer whereas ordinary least squares would yield a similar
# solution. However, the current setup allows you to experiment and try different things to make more sophisticated models (e.g., predict
# intermediate layer activations of PGGAN and include this in your loss function).

# Generator with its parameters loaded from Drive; it is only used to render
# images from predicted latents and is not passed to any trainer.
generator = Generator()
generator.load_parameters("/content/drive/MyDrive/HYPER/data/generator.params")

# Loss between predicted and target latents. NOTE: per the Gluon documentation
# L2Loss includes a factor of 1/2, so it is half of the plain mean squared error.
mean_squared_error = gluon.loss.L2Loss()
# For each subject: load the data, fit a dense voxel -> latent mapping, save
# the loss curves, and render one image per test sample with the generator.
for subject in [1, 2]:

    # Data
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file; only load data files that come from a trusted source.
    with open("/content/drive/MyDrive/HYPER/data/data_%i.dat" % subject, 'rb') as f:
        X_tr, T_tr, X_te, T_te = pickle.load(f)
    train = load_dataset(nd.array(T_tr), nd.array(X_tr), batch_size)
    # Evaluation batch size is hard-coded; presumably the number of test
    # samples, so the test set fits in one batch -- TODO confirm.
    test = load_dataset(nd.array(T_te), nd.array(X_te), batch_size=36)

    # Training
    vox_to_lat = Linear(n_vox, n_lat)
    vox_to_lat.initialize()
    trainer = gluon.Trainer(vox_to_lat.collect_params(), "Adam", {"learning_rate": 0.00001, "wd": 0.01})
    results_tr = []
    results_te = []
    for epoch in range(1, max_epoch + 1):
        train.reset()
        test.reset()

        loss_tr = 0.0
        count = 0
        for batch_tr in train:
            with autograd.record():
                lat_Y = vox_to_lat(batch_tr.data[0])
                loss = mean_squared_error(lat_Y, batch_tr.label[0])
            loss.backward()
            trainer.step(batch_size)
            # asscalar() yields a plain number; asnumpy() would accumulate
            # 1-element arrays that then end up in the "%.4f" formatting.
            loss_tr += float(loss.mean().asscalar())
            count += 1

        loss_te = 0.0
        count_te = 0
        for batch_te in test:
            lat_Y = vox_to_lat(batch_te.data[0])
            loss = mean_squared_error(lat_Y, batch_te.label[0])
            loss_te += float(loss.mean().asscalar())
            count_te += 1

        # Report per-batch averages for BOTH splits so the two curves are on
        # the same scale (the test loss used to be a plain sum over batches).
        # max(..., 1) keeps an empty iterator from dividing by zero.
        loss_tr_normalized = loss_tr / max(count, 1)
        loss_te_normalized = loss_te / max(count_te, 1)
        results_tr.append(loss_tr_normalized)
        results_te.append(loss_te_normalized)
        print("Epoch %i: %.4f / %.4f" % (epoch, loss_tr_normalized, loss_te_normalized))

    # Loss curves: one point per completed epoch, at integer epoch numbers.
    epoch_axis = np.arange(1, len(results_tr) + 1)
    fig, ax = plt.subplots()
    ax.plot(epoch_axis, results_tr, label="train")
    ax.plot(epoch_axis, results_te, label="test")
    ax.set_xlabel("epoch")
    ax.set_ylabel("L2 loss")
    ax.set_title("Subject %i" % subject)
    ax.legend()
    fig.savefig("loss_s%i.png" % subject)

    # Testing and reconstructing
    lat_Y = vox_to_lat(nd.array(X_te))
    out_dir = '/content/faces_%i' % subject  # avoids shadowing the builtin dir()
    os.makedirs(out_dir, exist_ok=True)
    for i, latent in enumerate(lat_Y):
        # latent[None] adds the batch axis; layer 9 selects the output of the
        # generator's last child (index 9 + 7 = 16).
        face = generator(latent[None], 9).asnumpy()
        # Map values in [-1, 1] onto [0, 255]; anything outside is clipped.
        face = np.clip(np.rint(127.5 * face + 127.5), 0.0, 255.0)
        face = face.astype("uint8")
        # Channels-first -> channels-last, the layout PIL expects.
        face = face.transpose(0, 2, 3, 1)
        Image.fromarray(face[0], 'RGB').save(os.path.join(out_dir, "%d.png" % i))