In [1]:
import torch
from torch import nn

In [2]:
class Generator(nn.Module):
    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()

        self.z_dim = z_dim
        self.generator = nn.Sequential(
            self.build_gen_block(z_dim, hidden_dim * 4),
            self.build_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.build_gen_block(hidden_dim * 2, hidden_dim),
            self.build_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True))

    def build_gen_block(self,input_channels,output_channels,kernel_size=3,stride=2,final_layer=False):
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),nn.ReLU(inplace=True))
        else:
            return nn.Sequential(nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),nn.Tanh())

    def forward(self, noise):
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.generator(x)


In [3]:
class Critic(nn.Module):
    def __init__(self, im_chan=1, hidden_dim=64):
        super(Critic, self).__init__()

        self.critic = nn.Sequential(
            self.make_crit_block(im_chan, hidden_dim),
            self.make_crit_block(hidden_dim, hidden_dim * 2),
            self.make_crit_block(hidden_dim * 2, 1, final_layer=True))

    def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        if not final_layer:

            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(nn.Conv2d(input_channels, output_channels, kernel_size, stride))

    def forward(self, image):
        crit_pred = self.critic(image)
        return crit_pred.view(len(crit_pred), -1)

In [4]:
import torch
from torch import nn
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

torch.manual_seed(0)

<torch._C.Generator at 0x7d980535bdd0>

In [5]:
def plot_images_from_tensor(image_tensor, num_images=25, size=(1, 28, 28)):
    image_tensor = (image_tensor + 1) / 2
    img_detached = image_tensor.detach().cpu()
    image_grid = make_grid(img_detached[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [6]:
def make_grad_hook():
    gradients_list = []

    def grad_hook(m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            gradients_list.append(m.weight.grad)
    return gradients_list, grad_hook

In [7]:
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

def get_noise(n_samples, z_dim, device="cpu"):
    return torch.randn(n_samples, z_dim, device=device)

In [8]:
def get_gen_loss(critic_fake_prediction):
    gen_loss = -1.0 * torch.mean(critic_fake_prediction)
    return gen_loss

assert torch.isclose(get_gen_loss(torch.tensor(1.0)), torch.tensor(-1.0))
assert torch.isclose(get_gen_loss(torch.rand(10000)), torch.tensor(-0.5), 0.05)
print("Success!")

Success!


In [9]:
def get_crit_loss(critic_fake_prediction, crit_real_pred, gp, c_lambda):
    crit_loss = (torch.mean(critic_fake_prediction) - torch.mean(crit_real_pred) + c_lambda * gp)
    return crit_loss

assert torch.isclose(get_crit_loss(torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0), 0.1),torch.tensor(-0.7))
assert torch.isclose(get_crit_loss(torch.tensor(20.0), torch.tensor(-20.0), torch.tensor(2.0), 10),torch.tensor(60.0))
print("Success!")

Success!


In [10]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [14]:
n_epochs = 100
z_dim = 64
display_step = 50
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10
crit_repeats = 5
device = "cpu"

In [12]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,)),])

dataloader = DataLoader(MNIST("/content", download=True, transform=transform),batch_size=batch_size,shuffle=True,)


100%|██████████| 9.91M/9.91M [00:00<00:00, 59.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.72MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.8MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.71MB/s]


In [15]:
generator = Generator(z_dim).to(device)
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta_1, beta_2))

critic = Critic().to(device)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=lr, betas=(beta_1, beta_2))

generator = generator.apply(weights_init)
critic = critic.apply(weights_init)

In [16]:
def gradient_of_critic_score(critic, real, fake, epsilon):
    interpolated_images = real * epsilon + fake * (1 - epsilon)
    mixed_scores = critic(interpolated_images)
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,)[0]
    return gradient

In [17]:
def test_gradient_of_critic_score(image_shape):
    real = torch.randn(*image_shape, device=device) + 1
    fake = torch.randn(*image_shape, device=device) - 1
    epsilon_shape = [1 for _ in image_shape]
    epsilon_shape[0] = image_shape[0]
    epsilon = torch.rand(epsilon_shape, device=device).requires_grad_()
    gradient = gradient_of_critic_score(critic, real, fake, epsilon)
    assert tuple(gradient.shape) == image_shape
    assert gradient.max() > 0
    assert gradient.min() < 0
    return gradient

gradient = test_gradient_of_critic_score((256, 1, 28, 28))
print("Success!")

Success!


In [18]:
def gradient_penalty_l2_norm(gradient):
    gradient = gradient.view(len(gradient), -1)
    gradient_norm = gradient.norm(2, dim=1)
    penalty = torch.mean((gradient_norm - 1) ** 2)
    return penalty

In [19]:
def test_gradient_penalty_l2_norm(image_shape):
    bad_gradient = torch.zeros(*image_shape)
    print(bad_gradient)
    bad_gradient_penalty = gradient_penalty_l2_norm(bad_gradient)
    assert torch.isclose(bad_gradient_penalty, torch.tensor(1.0))
    image_size = torch.prod(torch.Tensor(image_shape[1:]))

    print("torch.sqrt(image_size) ", torch.sqrt(image_size))
    good_gradient = torch.ones(*image_shape) / torch.sqrt(image_size)
    good_gradient_penalty = gradient_penalty_l2_norm(good_gradient)

    assert torch.isclose(good_gradient_penalty, torch.tensor(0.0))
    random_gradient = test_gradient_of_critic_score(image_shape)
    random_gradient_penalty = gradient_penalty_l2_norm(random_gradient)
    assert torch.abs(random_gradient_penalty - 1) < 0.1



test_gradient_penalty_l2_norm((256, 1, 28, 28))
print("Success!")

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0.

In [None]:
import matplotlib.pyplot as plt

current_step = 0
generator_losses = []
critic_losses_across_critic_repeats = []
for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        mean_critic_loss_for_this_iteration = 0
        for _ in range(crit_repeats):
            critic_optimizer.zero_grad()

            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = generator(fake_noise)

            critic_fake_prediction = critic(fake.detach())
            crit_real_pred = critic(real)

            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
            gradient = gradient_of_critic_score(critic, real, fake.detach(), epsilon)
            gp = gradient_penalty_l2_norm(gradient)

            crit_loss = get_crit_loss(critic_fake_prediction, crit_real_pred, gp, c_lambda)
            mean_critic_loss_for_this_iteration += crit_loss.item() / crit_repeats


            crit_loss.backward(retain_graph=True)
            critic_optimizer.step()
        critic_losses_across_critic_repeats += [mean_critic_loss_for_this_iteration]


        gen_optimizer.zero_grad()

        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = generator(fake_noise_2)

        critic_fake_prediction = critic(fake_2)
        gen_loss = get_gen_loss(critic_fake_prediction)

        gen_loss.backward()
        gen_optimizer.step()

        generator_losses += [gen_loss.item()]

        if current_step % display_step == 0 and current_step > 0:

            generator_mean_loss_display_step = (sum(generator_losses[-display_step:]) / display_step)

            critic_mean_loss_display_step = (sum(critic_losses_across_critic_repeats[-display_step:]) / display_step)
            print(f"Step {current_step}: Generator loss: {generator_mean_loss_display_step}, critic loss: {critic_mean_loss_display_step}")

            plot_images_from_tensor(fake)
            plot_images_from_tensor(real)

            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins

            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss")
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(critic_losses_across_critic_repeats[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic Loss")
            plt.legend()
            plt.show()

        current_step += 1

Output hidden; open in https://colab.research.google.com to view.