In [None]:
"""
This file will implement ViTs as encoders for GANs 
"""

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
from transformers import ViTModel, ViTConfig


In [4]:
class ViTGenerator(nn.Module):
    """
    This class implements a vision transformer to be used as a generator for GANs.
    """
    def __init__(self, img_size=64, patch_size=16, latent_dim=100, num_channels=3):
        super(ViTGenerator, self).__init__()
        
        # latent serves as the input to the generator network in a GAN
        # it is usually sampled from a gaussian distribution.
        self.latent_dim = latent_dim
        self.img_size = img_size

        # in transformers, the input (image) is divided into patches.
        self.patch_size = patch_size
        
        # Configuration for Vision Transformer
        config = ViTConfig(
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
            image_size=img_size,
            patch_size=patch_size,
            num_channels=num_channels
        )
        
        self.vit = ViTModel(config)
        
        # Upsampling layers to scale up the generated image patches
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(768, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, num_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )
        
    def forward(self, z):
        # Generate patches
        patches = self.vit(inputs_embeds=z)[0]
        patches = patches.permute(0, 2, 1).contiguous().view(-1, 768, self.img_size // self.patch_size, self.img_size // self.patch_size)
        
        # Upsample patches to generate full image
        img = self.upsample(patches)
        return img


In [5]:
class Discriminator(nn.Module):
    def __init__(self, img_size=64, num_channels=3):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )
        
    def forward(self, img):
        return self.main(img).view(-1, 1).squeeze(1)


In [7]:
def train_gan(generator, discriminator, dataloader, num_epochs=100, lr=0.0002, latent_dim=100, device='cuda'):
    generator.to(device)
    discriminator.to(device)
    
    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    
    for epoch in range(num_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            batch_size = imgs.size(0)
            valid = torch.ones(batch_size, device=device)
            fake = torch.zeros(batch_size, device=device)
            
            real_imgs = imgs.to(device)
            
            # Train Generator
            optimizer_G.zero_grad()
            
            # here, we are sampling the latent input from a gaussian distribution.
            z = torch.randn(batch_size, latent_dim, 1, 1, device=device)
            gen_imgs = generator(z)
            
            g_loss = criterion(discriminator(gen_imgs), valid)
            
            g_loss.backward()
            optimizer_G.step()
            
            # Train Discriminator
            optimizer_D.zero_grad()
            
            real_loss = criterion(discriminator(real_imgs), valid)
            fake_loss = criterion(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            
            d_loss.backward()
            optimizer_D.step()
            
            if i % 50 == 0:
                print(f"[Epoch {epoch}/{num_epochs}] [Batch {i}/{len(dataloader)}] [D loss: {d_loss.item()}] [G loss: {g_loss.item()}]")


In [8]:
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

dataset = datasets.CIFAR10(root='data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Initialize the generator and discriminator
generator = ViTGenerator()
discriminator = Discriminator()

# Train the GAN
train_gan(generator, discriminator, dataloader)


Files already downloaded and verified


TypeError: ViTModel.forward() got an unexpected keyword argument 'inputs_embeds'