In [1]:
device = 'mps'


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt
import tqdm as tqdm


In [3]:
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize((32, 32)),
])

train_dataset = torchvision.datasets.QMNIST('data', train=True, download=True, transform=transforms)
# test_dataset = torchvision.datasets.QMNIST('data', train=False, download=True, transform=transforms)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)
# test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

print(f"Training with {len(train_dataset)} samples")
# print(f"Testing with {len(test_dataset)} samples")


Training with 60000 samples


In [93]:
class VariationalEncoder(nn.Module):
    def __init__(self, d, latent_dim):
        super(VariationalEncoder, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, 128 // d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128 // d, 256 // d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256 // d, 512 // d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512 // d,latent_dim, 4, 1, 0),
            nn.Flatten(),
        )
        
        self.mu = nn.Linear(latent_dim, latent_dim)
        self.logvar = nn.Linear(latent_dim, latent_dim)
        
    def forward(self, x):
        x = self.layers(x)
        mu = self.mu(x)
        logvar = self.logvar(x)
        return mu, logvar
    
    def kl_divergence(self, mu, logvar):
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    def sample(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
class VariationalDecoder(nn.Module):
    def __init__(self, d, latent_dim):
        super(VariationalDecoder, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(latent_dim, 1024 // d),
            nn.ReLU(),
            nn.Unflatten(1, (1024 // d, 1, 1)),
            nn.ConvTranspose2d(1024 // d, 512 // d, 4, 1, 0),
            nn.ReLU(),
            nn.ConvTranspose2d(512 // d, 256 // d, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(256 // d, 128 // d, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(128 // d, 1, 4, 2, 1),
            nn.Sigmoid(),
        )
        
    def forward(self, x):
        return self.layers(x)
    
class Discriminator(nn.Module):
    def __init__(self, d):
        super(Discriminator, self).__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(1, 128 // d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128 // d, 256 // d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256 // d, 512 // d, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(1024, 1),            
        )
        
    def forward(self, x):
        return self.layers(x)
    
    def gradient_penalty(self, real, fake):
        alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
        interpolates = alpha * real + (1 - alpha) * fake
        d_interpolates = self(interpolates)
        gradients = torch.autograd.grad(outputs=d_interpolates, inputs=interpolates,
                                        grad_outputs=torch.ones(d_interpolates.size(), device=real.device),
                                        create_graph=True, retain_graph=True, only_inputs=True)[0]
        gradients = gradients.view(gradients.size(0), -1)
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


In [94]:
encoder = VariationalEncoder(8, 32).to(device)
decoder = VariationalDecoder(8, 32).to(device)
discriminator = Discriminator(8).to(device)


In [95]:
# VAEGAN
encoder_optimizer = optim.Adam(encoder.parameters(), lr=1e-4)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=1e-4)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)


In [103]:
for epoch in range(10):
    with tqdm.tqdm(enumerate(train_loader), desc=f"Epoch {epoch}") as pbar:
        for i, (x, _) in pbar:
            x = x.to(device)
            
            # Forward pass through encoder and decoder
            mu, logvar = encoder(x)
            z = encoder.sample(mu, logvar)
            x_tilde = decoder(z)
            
            # Discriminator outputs
            real_preds = discriminator(x)
            fake_preds = discriminator(x_tilde.detach())  # Detach to avoid backpropagation through decoder
            
            # Encoder loss: KL Divergence + Discriminator likelihood
            encoder_optimizer.zero_grad()
            loss_prior = encoder.kl_divergence(mu, logvar)  # Your encoder should have this method
            loss_real = F.binary_cross_entropy_with_logits(real_preds, torch.ones_like(real_preds))
            loss_fake_for_encoder = F.binary_cross_entropy_with_logits(fake_preds, torch.zeros_like(fake_preds))
            encoder_loss = loss_prior + (loss_real + loss_fake_for_encoder)
            encoder_loss.backward()
            encoder_optimizer.step()
            
            # Decoder loss: Discriminator likelihood - GAN loss
            decoder_optimizer.zero_grad()
            # Re-sample z to avoid using the same values as in encoder loss backprop
            z = encoder.sample(mu, logvar)
            x_tilde = decoder(z)
            fake_preds_for_decoder = discriminator(x_tilde)
            loss_gan = -fake_preds_for_decoder.mean()  # Minimize -E[log(D(G(z)))]
            loss_fake_for_encoder = F.binary_cross_entropy_with_logits(fake_preds, torch.zeros_like(fake_preds))
            decoder_loss = loss_fake_for_encoder - loss_gan
            decoder_loss.backward()
            decoder_optimizer.step()
            
            # Discriminator loss: BCE with logits + Gradient penalty
            discriminator_optimizer.zero_grad()
            # Compute the loss for real and fake images
            loss_real = F.binary_cross_entropy_with_logits(real_preds, torch.ones_like(real_preds))
            loss_fake_for_discriminator = F.binary_cross_entropy_with_logits(fake_preds_for_decoder, torch.zeros_like(fake_preds_for_decoder))
            loss_discriminator = loss_real + loss_fake_for_discriminator
            # Calculate gradient penalty
            loss_gradient_penalty = discriminator.gradient_penalty(x, x_tilde.detach())  # Your discriminator should have this method
            discriminator_loss = loss_discriminator + loss_gradient_penalty * 10  # Adjust lambda for gradient penalty if necessary
            discriminator_loss.backward()
            discriminator_optimizer.step()
            
            # Progress bar update
            pbar.set_postfix_str(f"Encoder Loss: {encoder_loss.item()}, Decoder Loss: {decoder_loss.item()}, Discriminator Loss: {discriminator_loss.item()}")


Epoch 0: 0it [00:00, ?it/s]

Epoch 0: 0it [00:01, ?it/s]


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.