In [None]:
import keras
import keras.backend as K
import numpy as np
import matplotlib.pyplot as plt
import os.path
import csv
from math import ceil
from keras.preprocessing.image import DirectoryIterator, ImageDataGenerator

# Dataset tag; in the cells shown here it is only used to build the checkpoint file name.
dataset = 'faces'
# Width of the latent code: unit count of the encoder's mu/logvar heads and the decoder's input size.
z_dims = 32
# In the cells shown here this only appears in the checkpoint file name
# (presumably the beta-VAE KL weight used during training -- TODO confirm).
beta = 20

# Number of image channels: first axis of the model input and filter count of the last decoder layer.
channels = 1
# NOTE(review): not referenced by any of the cells shown here; presumably selects the
# reconstruction loss in the training code -- confirm before removing.
likelihood = 'bernoulli'

class NormalSampler(keras.layers.Layer):
    """Sampling layer implementing the reparameterisation trick.

    Given ``[mu, logvar]`` it returns ``mu + epsilon * exp(logvar / 2)`` with
    ``epsilon`` drawn from a standard normal of the same shape as ``mu``.
    The layer has no weights of its own.
    """

    def __init__(self, **kwargs):
        # Forward the standard Layer options (name, dtype, trainable, ...) to the
        # base class so the layer can be named and configured like built-in layers.
        super(NormalSampler, self).__init__(**kwargs)

    def call(self, mu_logvar):
        """Draw one sample per input row.

        Args:
            mu_logvar: pair ``(mu, logvar)`` of tensors with identical shape.

        Returns:
            Tensor with the shape of ``mu``.
        """
        mu, logvar = mu_logvar
        epsilon = K.random_normal(shape=K.shape(mu))
        # logvar is log(sigma^2), so exp(logvar / 2) is sigma.
        std = K.exp(logvar / 2)
        return mu + epsilon*std


# The architecture below is laid out channels-first: the input is
# (channels, 64, 64) and the decoder reshapes to (64, 4, 4) before its
# transposed convolutions. Pin the data format explicitly so the layers do not
# depend on whatever image_data_format the local Keras configuration defaults to.
K.set_image_data_format('channels_first')

inputs = keras.Input(shape=(channels, 64, 64))

# encoder
# Four stride-2 'same' convolutions halve the spatial size each time: 64 -> 32 -> 16 -> 8 -> 4.
x = keras.layers.Conv2D(filters=32, kernel_size=4, strides=2, padding='same', activation='relu')(inputs)
x = keras.layers.Conv2D(filters=32, kernel_size=4, strides=2, padding='same', activation='relu')(x)
x = keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, padding='same', activation='relu')(x)
x = keras.layers.Conv2D(filters=64, kernel_size=4, strides=2, padding='same', activation='relu')(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(units=256, activation='relu')(x)
# Two linear heads parameterise the approximate posterior: mean and log-variance.
mu = keras.layers.Dense(units=z_dims, activation=None)(x)
logvar = keras.layers.Dense(units=z_dims, activation=None)(x)
z = NormalSampler()([mu, logvar])

# encoder_mu exposes only the posterior mean (deterministic code);
# encoder returns the sampled code together with both posterior parameters.
encoder_mu = keras.Model(inputs=inputs, outputs=mu, name='encoder_mu')
encoder = keras.Model(inputs=inputs, outputs=[z, mu, logvar], name='encoder')


# decoder
# Mirror of the encoder: dense layers, reshape to (64, 4, 4), then four stride-2
# transposed convolutions doubling the spatial size: 4 -> 8 -> 16 -> 32 -> 64.
d_inputs = keras.Input(shape=(z_dims,))
x = keras.layers.Dense(units=256, activation='relu')(d_inputs)
x = keras.layers.Dense(units=64*4*4, activation='relu')(x)
x = keras.layers.Reshape((64, 4, 4))(x)
x = keras.layers.Conv2DTranspose(filters=64, kernel_size=4, strides=2, padding='same', activation='relu')(x)
x = keras.layers.Conv2DTranspose(filters=32, kernel_size=4, strides=2, padding='same', activation='relu')(x)
x = keras.layers.Conv2DTranspose(filters=32, kernel_size=4, strides=2, padding='same', activation='relu')(x)
# No activation on the last layer: the decoder emits raw (unsquashed) values.
x = keras.layers.Conv2DTranspose(filters=channels, kernel_size=4, strides=2, padding='same', activation=None)(x)
decoder = keras.Model(inputs=d_inputs, outputs=x, name='decoder')


# combined
# The full autoencoder decodes the sampled code z (first encoder output).
outputs = decoder(encoder(inputs)[0])
model = keras.Model(inputs=inputs, outputs=outputs)


# Restore trained weights when a checkpoint for this (dataset, z_dims, beta)
# combination exists under ./checkpoints.
model_path = os.path.join('checkpoints', '{}-{}-{}'.format(dataset, z_dims, beta))
if os.path.exists(model_path):
    model.load_weights(model_path)
    print('Loaded existing checkpoint')
else:
    # Make the fallback visible: without a checkpoint every later cell would
    # silently work with a randomly initialised encoder.
    print('WARNING: no checkpoint found at {}; model keeps its random initial weights'.format(model_path))

In [None]:
# Read the image array and its per-image factor labels from the archive.
# np.load returns a lazily-read NpzFile for .npz archives; the context manager
# closes the underlying file once both arrays have been materialised.
with np.load('faces-labelled.npz') as faces:
    data = faces['images']
    factors = faces['factors']

In [None]:
# Encode every image to its posterior mean in one batched predict call instead
# of one call per image; the result has one row of latent means per image.
latents = encoder_mu.predict(data.reshape(-1, channels, 64, 64))

np.save('latents-0.1', latents)

print(latents.shape)
print(data.shape)
print(factors.shape)

# create representation
# Allowed values of each of the five labelled factors, in grid order. A factor
# value's position in its list is its index along the matching grid axis.
a_ = [-50, -40, -30, -20, -10, 0, 10, 20, 30, 40, 50]
b_ = [-20, -16, -12, -8, -4, 0, 4, 8, 12, 16, 20]
c_ = [-90, -72, -54, -36, -18, 0, 18, 36, 54, 72, 90]
d_ = [-40, -20, 0, 20, 40, 60, 80, 100]
e_ = [-3, -2, -1, 0, 1, 2, 3]

factor_values = [a_, b_, c_, d_, e_]
grid_shape = tuple(len(values) for values in factor_values)

# Start from NaN rather than uninitialised memory so a factor combination that
# never occurs in `factors` cannot be mistaken for a real latent later on.
latents_labelled = np.full(grid_shape + (latents.shape[1],), np.nan)
filled = np.zeros(grid_shape, dtype=bool)
for latent, factor in zip(latents, factors):
    # Round before converting: int() alone truncates, so a label stored as
    # e.g. 19.999999 would otherwise fail the lookup.
    index = tuple(
        values.index(int(round(float(value))))
        for values, value in zip(factor_values, factor)
    )
    latents_labelled[index] = latent
    filled[index] = True

if not filled.all():
    print('WARNING: {} factor combinations have no latent; their entries are NaN'.format(int((~filled).sum())))

In [None]:
import torch
import torch.nn as nn
import random
import numpy as np

# Number of distinct values along each factor axis; used as the upper bound
# when random factor indices are drawn in disentanglement_epoch.
v_shape = [11, 11, 11, 8, 7]
# Number of factors that are randomised and that the classifier tells apart
# (its output size). The trailing "3" is the author's alternative value.
v_dims = 5 # 3
# Length of a latent vector; input size of the linear classifier.
z_dims = 32
# Number of training steps, and also the number of evaluation steps, in disentanglement().
batches = 40000
# Number of latent pairs averaged into the single difference vector fed to the classifier.
batch = 256

class LinearClassifier(nn.Module):
    """Linear probe mapping a latent-difference vector to factor log-probabilities.

    The output is passed through ``LogSoftmax`` because the notebook trains the
    probe with ``NLLLoss``, which expects log-probabilities; a plain ``Softmax``
    would make that loss the negative probability instead of the negative
    log-likelihood. The arg-max prediction is the same either way.

    Args:
        in_dims: size of the input vector; defaults to the notebook-level ``z_dims``.
        out_dims: number of classes; defaults to the notebook-level ``v_dims``.
    """

    def __init__(self, in_dims=None, out_dims=None):
        super(LinearClassifier, self).__init__()
        # Resolve the defaults lazily so the notebook constants are only needed
        # when the caller does not pass explicit sizes.
        if in_dims is None:
            in_dims = z_dims
        if out_dims is None:
            out_dims = v_dims
        self.classifier = nn.Sequential(
            nn.Linear(in_dims, out_dims),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        """Return log-probabilities of shape ``(rows, out_dims)`` for a 2-D input."""
        return self.classifier(x)

# Probe, optimiser and criterion shared by the two disentanglement functions below.
# Note: NLLLoss expects the model output to be log-probabilities.
classifier = LinearClassifier()
optim = torch.optim.Adam(params=classifier.parameters(), lr=1e-3 * 0.5)
loss = nn.NLLLoss()


def disentanglement_epoch(i):
    """Classify one averaged latent difference in which factor ``i`` is shared.

    Draws ``batch`` random pairs of grid positions that agree on factor ``i``,
    takes the absolute difference of the two latents of each pair, averages the
    differences over the pairs and feeds the resulting single row to the
    module-level ``classifier``.

    Args:
        i: index of the factor that is identical within every pair.

    Returns:
        The classifier output for the one averaged difference row.
    """
    def draw_position():
        # One single-element list per grid axis; axes past v_dims stay at [0].
        position = [[0]]*5
        for axis in range(v_dims):
            position[axis] = [random.randrange(v_shape[axis])]
        return position

    def pair_difference():
        first = draw_position()
        second = draw_position()
        # Force the pair to agree on the chosen factor.
        second[i] = first[i]
        # Indexing with a tuple of one-element lists yields a single row; [0] unwraps it.
        latent_a = latents_labelled[tuple(first)][0]
        latent_b = latents_labelled[tuple(second)][0]
        return np.abs(latent_a - latent_b)

    differences = [pair_difference() for _ in range(batch)]

    stacked = torch.Tensor(np.stack(differences))
    averaged = stacked.mean(dim=0).unsqueeze(0)
    return classifier(averaged)

def disentanglement(log_every=1000):
    """Train the linear probe, then measure how often it recovers the shared factor.

    Training runs ``batches`` optimisation steps; each step picks a random
    factor index, obtains the classifier output for one averaged latent
    difference from ``disentanglement_epoch`` and minimises ``loss`` against
    that index. Evaluation repeats the sampling ``batches`` times without
    gradients and counts exact matches of the arg-max prediction.

    Args:
        log_every: print the training step number every ``log_every`` steps;
            a falsy value disables progress output.

    Returns:
        Fraction of evaluation steps with a correct prediction, as a Python float.
    """
    for b in range(batches):
        # Periodic progress only: one line per step would flood the cell output.
        if log_every and b % log_every == 0:
            print(b)
        i = random.randrange(v_dims)
        # NLLLoss needs an integer (long) class index; build it explicitly
        # instead of relying on torch.full's version-dependent dtype inference.
        target = torch.tensor([i], dtype=torch.long)
        output = loss(disentanglement_epoch(i), target)
        optim.zero_grad()
        output.backward()
        optim.step()

    with torch.no_grad():
        corr_predictions = 0
        tot_predictions = 0

        for _ in range(batches):
            i = random.randrange(v_dims)
            # disentanglement_epoch yields exactly one row, hence one prediction.
            output = disentanglement_epoch(i)
            pred = output.argmax(dim=1)
            corr_predictions += int(pred.item() == i)
            tot_predictions += 1

    # Plain Python ints: true division regardless of the installed torch version.
    return corr_predictions / tot_predictions

# Train and evaluate the linear probe; as the cell's last expression, the
# returned accuracy is displayed as the cell output.
disentanglement()