In [2]:
import torch.nn as nn
import torch 

In [3]:
print(torch.__version__)

2.4.1+cu124


In [4]:
print(torch.cuda.is_available())

True


In [5]:
import torchvision
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data as dataloader
from torch.utils.tensorboard import SummaryWriter

In [20]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Conv2d(
                channels_img, features_d, kernel_size=4, stride=2, padding=1
            ),
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d * 2, 4, 2, 1),
            self._block(features_d * 2, features_d * 4, 4, 2, 1),
            self._block(features_d * 4, features_d * 8, 4, 2, 1),
            self._block(features_d * 8, features_d * 16, 4, 2, 1),
            
            nn.Conv2d(features_d * 16, 1, kernel_size=4, stride=2, padding=0),
        )
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )
    def forward(self, x):
        return self.disc(x)
    


In [6]:
class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            self._block(channels_noise, features_g * 16, 4, 1, 0), # 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1), # 8x8 
            self._block(features_g * 8, features_g * 4, 4, 2, 1), # 16 x 16
            self._block(features_g * 4, features_g * 2, 4, 2, 1), # 32 x 32
            self._block(features_g * 2, features_g, 4, 2, 1), # 64 x 64
            nn.ConvTranspose2d(
                features_g * 1, channels_img, kernel_size=4, stride=2, padding=1
            ),
            # 128 x 128
            nn.Tanh(),
        )
        
    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    
    def forward(self, x):
        return self.net(x)

In [28]:
def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal(m.weight.data, 0.0, 0.02)

In [11]:
from torchsummary import summary

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
IMG_SIZE = 64
CHANNELS_IMG = 3
Z_DIM = 100
NUM_CLASSES = 5
FEATURES_CRITIC = 64
FEATURES_GEN = 64
CRITIC_ITERATIONS = 5
NUM_EPOCHS = 5
LAMBDA_GP = 10

cuda


In [21]:
def test_classes():
    N, in_channels, H, W = 8, 3, 128, 128
    noise_dim = 100
    x = torch.randn((N, in_channels, H, W)).to(device)
    disc = Discriminator(in_channels, 8).to(device)
    # assert disc(x).shape == (N, 1, 1, 1), "Discriminator test failed"
    print(summary(disc, input_size=(in_channels, H, H)))
    gen = Generator(channels_noise=noise_dim, channels_img=in_channels, features_g=8).to(device)
    z = torch.randn((N, noise_dim, 1, 1)).to(device)
    # assert gen(z).shape == (N, in_channels, H, W), " Generator test failed"
    print(summary(gen, input_size=(noise_dim, 1, 1)))
    
test_classes()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 64, 64]             392
         LeakyReLU-2            [-1, 8, 64, 64]               0
            Conv2d-3           [-1, 16, 32, 32]           2,048
    InstanceNorm2d-4           [-1, 16, 32, 32]               0
         LeakyReLU-5           [-1, 16, 32, 32]               0
            Conv2d-6           [-1, 32, 16, 16]           8,192
    InstanceNorm2d-7           [-1, 32, 16, 16]               0
         LeakyReLU-8           [-1, 32, 16, 16]               0
            Conv2d-9             [-1, 64, 8, 8]          32,768
   InstanceNorm2d-10             [-1, 64, 8, 8]               0
        LeakyReLU-11             [-1, 64, 8, 8]               0
           Conv2d-12            [-1, 128, 4, 4]         131,072
   InstanceNorm2d-13            [-1, 128, 4, 4]               0
        LeakyReLU-14            [-1, 12

In [44]:
transforms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

In [49]:
def gradient_penalty(critic, real, fake, device):
    BATCH_SIZE, C, H, W = real.shape
    epsilon = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * epsilon + fake * (1 - epsilon)
    
    
    # calcuated mixed scores
    mixed_scores = critic(interpolated_images)
    
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    
    gradient = gradient.view(gradient.shape[0], -1) # flattening
    gradient_norm = gradient.norm(2, dim=1) # taking norm of flattened dim
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

In [17]:
def save_checkpoint(state, filename='celeba_wgan_gp'):
    print("-> saving checkpoint")
    torch.save(state, filename)
    
def load_checkpoint(filename, gen, disc):
    print("-> loading checkpoint")
    gen.load_state_dict(torch.load(filename))
    disc.load_state_dict(torch.load(filename))

In [18]:
from torch.utils.data import DataLoader

In [45]:
celeba_dataset = datasets.CelebA(root='data',
                                 split='train',
                                 transform=transforms,
                                 download=True)
celeba_loader = DataLoader(dataset=celeba_dataset, batch_size=BATCH_SIZE, shuffle=True)

Files already downloaded and verified


In [29]:
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)   
disc = Discriminator(CHANNELS_IMG, FEATURES_CRITIC).to(device)

initialize_weights(gen)
initialize_weights(disc)

  nn.init.normal(m.weight.data, 0.0, 0.02)


In [30]:
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE)
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

In [35]:
fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/WPGAN_CELEBA/base/real")
writer_fake = SummaryWriter(f"logs/WPGAN_CELEBA/base/fake")

step = 0
img_idx = 0
gen.train()
disc.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Conv2d(512, 1, kernel_size=(4, 4), stride=(2, 2))
  )
)

In [33]:
from tqdm import tqdm

In [34]:
from torchvision.transforms import ToPILImage

In [51]:
# main training loop
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(tqdm(celeba_loader)):
        # print(real.shape)
        real = real.to(device)
        cur_batch_size = real.shape[0]
        
        # Train Discriminator
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise)
            # print("fake shape", fake.shape)
            # print("real shape", real.shape)
            critical_real = disc(real).reshape(-1)
            critical_fake = disc(fake).reshape(-1)
            gp = gradient_penalty(disc, real, fake, device)
            loss_disc = (
                -(torch.mean(critical_real) - torch.mean(critical_fake)) + LAMBDA_GP * gp
            )
            disc.zero_grad()
            loss_disc.backward(retain_graph=True)
            opt_disc.step()
            
        # Train Generator
        gen_fake = disc(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()
        
        if batch_idx % 100 == 0 and batch_idx != 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(celeba_loader)} \
                loss D: {loss_disc:.4f}, loss G {loss_gen:.4f}"
            )
            
            with torch.no_grad():
                fake = gen(fixed_noise)
                
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
                
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
                
                to_pil = ToPILImage()
                
                img_fake = to_pil(img_grid_fake)
                img_fake.save(f" ../images/base/fake/fake_images_grid_{img_idx}.png"
                             )
            step += 1

  0%|          | 7/2544 [00:39<3:59:12,  5.66s/it]


KeyboardInterrupt: 