# Bimodal VAE (+PCA term?)

As observed in the latent space, we see that the when projecting on a direction the images are organized as a sum of two gaussians, one centered on one half of the hyperplane and the other in the second one. This is what we will change the traditional VAE that has an a priori that is one gaussian.


In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
import numpy as np
import generate_nb
import matplotlib.pyplot as plt
import torch
import numpy as np
import dnnlib
import legacy
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from torch.optim import Adam
import projector_nb
import numpy as np
import torch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch
import torch.nn as nn
import torch.nn.functional as F


In [8]:
batch_size = 16

In [11]:
network = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with dnnlib.util.open_url(network) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device)
        
def generate_samples(G, num_samples):
    z_samples = np.random.RandomState().randn(num_samples, 512).astype(np.float32) 
    
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples.cpu().numpy()
    w_samples = w_samples[:,0,:]

    return w_samples


w_samples_ae = generate_samples(G, 100)

# Define your data preprocessing pipeline
preprocessing_pipeline = Pipeline([
    ('standard', StandardScaler()),
    ('minmax', MinMaxScaler())
])


w_samples_scaled = preprocessing_pipeline.fit_transform(w_samples_ae)
w_samples_scaled = np.clip(w_samples_scaled, 0, 1)
print(np.min(w_samples_scaled))
print(np.max(w_samples_scaled))
dataset = TensorDataset(torch.tensor(w_samples_scaled, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

input_dim = 512
hidden_dim = 256
latent_dim = 2
batch_size = 16
epochs = 10
learning_rate = 1e-3

0.0
1.0


In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hidden_dim//2, hidden_dim//4),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hidden_dim//4, latent_dim * 4)  # 4 for mean1, log_var1, mean2, log_var2
        )

    def forward(self, x):
        h = self.encoder(x)
        mean1, log_var1, mean2, log_var2 = torch.chunk(h, 4, dim=1)
        return mean1, log_var1, mean2, log_var2

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim//4),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hidden_dim//4, hidden_dim//2),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hidden_dim//2, hidden_dim),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid()
        )

    def forward(self, z):
        return self.decoder(z)

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()

        self.encoder = Encoder(input_dim, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, input_dim)

    def encode(self, x):
        return self.encoder(x)

    def reparameterize(self, mean1, log_var1, mean2, log_var2):
        std1 = torch.exp(0.5 * log_var1)
        eps1 = torch.randn_like(std1)
        z1 = mean1 + eps1 * std1
        
        std2 = torch.exp(0.5 * log_var2)
        eps2 = torch.randn_like(std2)
        z2 = mean2 + eps2 * std2
        
        return z1, z2

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mean1, log_var1, mean2, log_var2 = self.encode(x)
        z1, z2 = self.reparameterize(mean1, log_var1, mean2, log_var2)
        z = z1+z2
        print(z.shape)
        x_recon = self.decode(z)  
        return x_recon, mean1, log_var1, mean2, log_var2,z

def orthogonality_constraint(z):
    z_prod = torch.matmul(z.T,z)
    identity = torch.eye(z_prod.size(0)).to(device)
    ortho_loss = torch.norm(z_prod - identity, p='fro')
    return ortho_loss

def loss_function(recon_x, x,mean1, log_var1, mean2, log_var2,z,beta=200.0,alpha=100.0):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    print("BCE Loss is:", BCE)

    # KL divergence for the sum of two Gaussians
    KLD1 = -0.5 * torch.sum(1 + log_var1 - mean1.pow(2) - log_var1.exp())
    KLD2 = -0.5 * torch.sum(1 + log_var2 - mean2.pow(2) - log_var2.exp())
    KLD = KLD1 + KLD2
    print("beta x KLD is :", beta*KLD)

    # orthopenalty
    
    orthogonality_loss =  orthogonality_constraint(z) 
    print("alpha x orthogonality loss is:", alpha*orthogonality_loss)
    return BCE + beta*KLD + alpha*orthogonality_loss


In [26]:
vae = VAE(input_dim, hidden_dim, latent_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=learning_rate)

In [28]:
for epoch in range(epochs):
    for data in dataloader:
        x = data[0].to(device)
        print(x.shape)
        x_recon, mean1, log_var1, mean2, log_var2,z= vae(x)
    
        loss = loss_function(x_recon, x,mean1, log_var1, mean2, log_var2,z,beta=500.0,alpha=100.0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

print("Training complete.")

torch.save(vae.state_dict(), 'vae_model.pth')

torch.Size([16, 512])
torch.Size([16, 2])
BCE Loss is: tensor(4721.1890, grad_fn=<BinaryCrossEntropyBackward0>)
beta x KLD is : tensor(1637.9420, grad_fn=<MulBackward0>)
alpha x orthogonality loss is: tensor(3581.7268, grad_fn=<MulBackward0>)
torch.Size([16, 512])
torch.Size([16, 2])
BCE Loss is: tensor(4727.5225, grad_fn=<BinaryCrossEntropyBackward0>)
beta x KLD is : tensor(1488.6875, grad_fn=<MulBackward0>)
alpha x orthogonality loss is: tensor(4227.0918, grad_fn=<MulBackward0>)
torch.Size([16, 512])
torch.Size([16, 2])
BCE Loss is: tensor(4713.1870, grad_fn=<BinaryCrossEntropyBackward0>)
beta x KLD is : tensor(1251.2552, grad_fn=<MulBackward0>)
alpha x orthogonality loss is: tensor(3590.6963, grad_fn=<MulBackward0>)
torch.Size([16, 512])
torch.Size([16, 2])
BCE Loss is: tensor(4666.1455, grad_fn=<BinaryCrossEntropyBackward0>)
beta x KLD is : tensor(1046.7148, grad_fn=<MulBackward0>)
alpha x orthogonality loss is: tensor(3716.4023, grad_fn=<MulBackward0>)
torch.Size([16, 512])
torch.