# FactorVAE

In [None]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

# Run configuration.
batch_size = 128
epochs = 50
seed = 1

# Seed torch's RNG so runs start from the same random state.
torch.manual_seed(seed)

# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
root = '../data'

# Convert each image to a tensor, then flatten it into a 1-D vector.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambd=lambda x: x.view(-1)),
])

# Loader options shared by the train and test loaders.
kwargs = dict(batch_size=batch_size, num_workers=1, pin_memory=True)

# Only the training split requests a download; it is shuffled on iteration.
mnist_train = datasets.MNIST(root=root, train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(mnist_train, shuffle=True, **kwargs)

# The test split is read from the same root and iterated in order.
mnist_test = datasets.MNIST(root=root, train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_test, shuffle=False, **kwargs)

In [None]:
from pixyz.distributions import Normal, Bernoulli, Deterministic
from pixyz.losses import KullbackLeibler, CrossEntropy, AdversarialKullbackLeibler
from pixyz.models import Model

In [None]:
# Size of a flattened 28x28 input image.
x_dim = 28 * 28
# Number of latent dimensions.
z_dim = 8


class Inference(Normal):
    """Inference model q(z|x): a Normal over ``z`` conditioned on ``x``.

    Two 512-unit ReLU layers feed two linear heads of width ``z_dim`` that
    produce the location and (softplus-transformed) scale parameters.
    """

    def __init__(self):
        super(Inference, self).__init__(cond_var=["x"], var=["z"], name="q")

        # Shared trunk.
        self.fc1 = nn.Linear(x_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        # Parameter heads: fc31 -> loc, fc32 -> pre-softplus scale.
        self.fc31 = nn.Linear(512, z_dim)
        self.fc32 = nn.Linear(512, z_dim)

    def forward(self, x):
        """Return the ``loc`` and ``scale`` parameters computed from ``x``."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        loc = self.fc31(hidden)
        # softplus keeps the scale strictly positive.
        scale = F.softplus(self.fc32(hidden))
        return {"loc": loc, "scale": scale}
    
class Generator(Bernoulli):
    """Generative model p(x|z): a Bernoulli over ``x`` conditioned on ``z``.

    Two 512-unit ReLU layers followed by a linear layer of width ``x_dim``
    whose sigmoid gives the Bernoulli probabilities.
    """

    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        self.fc1 = nn.Linear(z_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, x_dim)

    def forward(self, z):
        """Return the ``probs`` parameter computed from ``z``."""
        hidden = F.relu(self.fc1(z))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return {"probs": torch.sigmoid(logits)}
    
# Prior p(z): a Normal over z with scalar loc 0 and scale 1, created
# directly on the selected device, spanning z_dim dimensions.
loc = torch.tensor(0., device=device)
scale = torch.tensor(1., device=device)
prior = Normal(
    loc=loc,
    scale=scale,
    var=["z"],
    dim=z_dim,
    name="p_prior",
)

# Initial generator / inference instances (both names are re-bound to fresh
# instances in a later cell, before the losses are built).
p = Generator()
q = Inference()

In [None]:
class InferenceShuffleDim(Deterministic):
    """Sample ``z`` from ``q`` given ``x_`` and shuffle it across the batch.

    Each latent dimension (column) is permuted along the batch axis with its
    own independent random permutation, so every output row mixes values that
    came from different input samples. This is the ``permute_dims`` step of
    FactorVAE, used to approximate samples from the product of the marginals
    of q(z).

    The previous implementation reordered the *columns* with one shared
    permutation, which leaves each row's values jointly sampled and therefore
    does not break the dependence between latent dimensions.

    Note: ``q`` is resolved as a module-level name each time ``forward`` runs,
    so the instance bound to ``q`` at call time is the one that is sampled.
    """

    def __init__(self):
        super(InferenceShuffleDim, self).__init__(cond_var=["x_"], var=["z"], name="q_shuffle")

    def forward(self, x_):
        """Return ``{"z": z_perm}`` where ``z_perm`` has the same shape as the
        sample drawn from ``q`` for ``x_``."""
        z = q.sample({"x": x_}, return_all=False)["z"]
        # Argsort of i.i.d. uniform noise along dim 0 yields an independent
        # random permutation of the batch indices for every column.
        perm = torch.rand_like(z).argsort(dim=0)
        # out[i, j] = z[perm[i, j], j]
        return {"z": torch.gather(z, 0, perm)}

q_shuffle = InferenceShuffleDim()

In [None]:
class Discriminator(Deterministic):
    """Deterministic map from a latent ``z`` to a single value ``t`` in (0, 1).

    A three-layer MLP (z_dim -> 512 -> 256 -> 1) with LeakyReLU(0.2)
    activations and a final sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__(cond_var=["z"], var=["t"], name="d")

        layers = [
            nn.Linear(z_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return ``{"t": model(z)}``."""
        return {"t": self.model(z)}

In [None]:
# Fresh instances of the three trainable networks.
p = Generator()
q = Inference()
d = Discriminator()

# Move their parameters to the selected device.
for _net in (p, q, d):
    _net.to(device)

# Print each distribution, including the shuffling wrapper around q.
for _dist in (p, q, q_shuffle, d):
    print(_dist)

In [None]:
# Reconstruction term built from q and p.
reconst = CrossEntropy(q, p)

# KL term between q and the prior.
kl = KullbackLeibler(q, prior)

# Adversarial KL between q and its shuffled counterpart, estimated with the
# discriminator d, which gets its own Adam optimizer.
d_optimizer_params = {"lr": 1e-3}
tc = AdversarialKullbackLeibler(
    q,
    q_shuffle,
    discriminator=d,
    optimizer=optim.Adam,
    optimizer_params=d_optimizer_params,
)

# Total objective: batch means of the first two terms plus 10x the adversarial term.
loss_cls = reconst.mean() + kl.mean() + 10 * tc
print(loss_cls)

In [None]:
# Wrap the objective and the two distributions it trains (p and q) in a
# pixyz Model with an Adam optimizer.
model = Model(
    loss_cls,
    distributions=[p, q],
    optimizer=optim.Adam,
    optimizer_params={"lr": 1e-3},
)
print(model)

In [None]:
def train(epoch):
    """Run one pass over ``train_loader`` and report the scaled losses.

    Every batch is split in two halves: the first is passed as ``x`` and the
    second as ``x_`` to both ``model.train`` and ``tc.train``.

    Args:
        epoch: epoch number; only used in the printed message.

    Returns:
        The sum of the values returned by ``model.train``, multiplied by the
        loader's batch size and divided by the dataset length.
    """
    total_loss = 0
    total_d_loss = 0
    for batch, _ in tqdm(train_loader):
        batch = batch.to(device)
        half = batch.shape[0] // 2
        first_half, second_half = batch[:half], batch[half:]

        # Update the model first, then the discriminator behind `tc`.
        total_loss += model.train({"x": first_half, "x_": second_half})
        total_d_loss += tc.train({"x": first_half, "x_": second_half})

    n_samples = len(train_loader.dataset)
    train_loss = total_loss * train_loader.batch_size / n_samples
    train_d_loss = total_d_loss * train_loader.batch_size / n_samples
    print('Epoch: {} Train loss: {:.4f}, {:.4f}'.format(epoch, train_loss.item(), train_d_loss.item()))
    return train_loss

In [None]:
def test(epoch):
    """Evaluate on ``test_loader`` and report the scaled losses.

    Every batch is split in two halves: the first is passed as ``x`` and the
    second as ``x_`` to both ``model.test`` and ``tc.test``.

    Args:
        epoch: accepted for symmetry with ``train``; not used in the body.

    Returns:
        The sum of the values returned by ``model.test``, multiplied by the
        loader's batch size and divided by the dataset length.
    """
    total_loss = 0
    total_d_loss = 0
    for batch, _ in tqdm(test_loader):
        batch = batch.to(device)
        half = batch.shape[0] // 2
        first_half, second_half = batch[:half], batch[half:]

        total_loss += model.test({"x": first_half, "x_": second_half})
        total_d_loss += tc.test({"x": first_half, "x_": second_half})

    n_samples = len(test_loader.dataset)
    test_loss = total_loss * test_loader.batch_size / n_samples
    test_d_loss = total_d_loss * test_loader.batch_size / n_samples
    print('Test loss: {:.4f}, {:.4f}'.format(test_loss.item(), test_d_loss.item()))
    return test_loss

In [None]:
def plot_reconstrunction(x):
    """Return the inputs stacked on top of their reconstructions.

    ``x`` is encoded by sampling from ``q``; the sample is passed to
    ``p.sample_mean`` and the result is reshaped to 1x28x28 images.

    Returns:
        A CPU tensor holding ``x`` (viewed as 1x28x28 images) followed by the
        reconstructions, concatenated along the first dimension.
    """
    with torch.no_grad():
        latent = q.sample({"x": x}, return_all=False)
        reconstructed = p.sample_mean(latent).view(-1, 1, 28, 28)
        originals = x.view(-1, 1, 28, 28)
        stacked = torch.cat([originals, reconstructed])
    return stacked.cpu()
    
def plot_image_from_latent(z_sample):
    """Decode latent vectors into images.

    Passes ``z_sample`` to ``p.sample_mean`` and returns the result as a CPU
    tensor reshaped to 1x28x28 images.
    """
    with torch.no_grad():
        decoded = p.sample_mean({"z": z_sample})
        images = decoded.view(-1, 1, 28, 28)
    return images.cpu()

In [None]:
writer = SummaryWriter()

plot_dim = 8

# Latent traversal inputs: for each of the first `plot_dim` latent dimensions,
# build `plot_dim` vectors that are zero everywhere except that dimension,
# which sweeps linearly from -1 to 1. Requires plot_dim <= z_dim because
# column i of a (plot_dim, z_dim) tensor is indexed.
z_sample = []
for i in range(plot_dim):
    z_batch = torch.zeros(plot_dim, z_dim)
    z_batch[:, i] = (torch.arange(plot_dim, dtype=torch.float32) * 2.) / (plot_dim - 1.) - 1
    z_sample.append(z_batch)
z_sample = torch.cat(z_sample, dim=0).to(device)
# Alternative: random latents, e.g. 0.5 * torch.randn(64, z_dim).to(device)

# Fixed batch of test images reused for the reconstruction plot every epoch.
# The builtin next() is used because iterator objects have no `.next()`
# method in Python 3.
_x, _ = next(iter(test_loader))
_x = _x.to(device)

try:
    for epoch in range(1, epochs + 1):
        train_loss = train(epoch)
        test_loss = test(epoch)

        recon = plot_reconstrunction(_x[:8])
        sample = plot_image_from_latent(z_sample)

        writer.add_scalar('train_loss', train_loss.item(), epoch)
        writer.add_scalar('test_loss', test_loss.item(), epoch)

        # Both tensors are batches of shape (N, 1, 28, 28), so they are logged
        # with add_images (NCHW batches) rather than add_image (single image).
        writer.add_images('Image_from_latent', sample, epoch)
        writer.add_images('Image_reconstrunction', recon, epoch)
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()