## 1. Key code

In [None]:
def _upsampling_generator(first_width, out_activation):
    """Build a DCGAN-style generator as an nn.Sequential moved to `device`.

    Four nn.ConvTranspose2d stages (stride 2, no padding) enlarge the
    spatial size of the latent input; for an (N, latent_dim, 1, 1) input
    the side length goes 1 -> 3 -> 7 -> 14 -> 28.  The first three stages
    are each followed by BatchNorm2d + ReLU; the last one maps to
    `channels` feature maps and is followed by `out_activation`.

    Args:
        first_width: number of channels produced by the first stage.
        out_activation: module applied to the final output.
    """
    # (in_channels, out_channels, kernel_size) of the hidden stages.
    hidden_stages = (
        (latent_dim, first_width, 3),
        (first_width, 128, 3),
        (128, 64, 2),
    )
    layers = []
    for c_in, c_out, k in hidden_stages:
        # nn.ConvTranspose2d can be seen as the inverse operation of Conv2d:
        # it produces an upscaled feature map.
        layers.extend([
            nn.ConvTranspose2d(c_in, c_out, kernel_size=k, stride=2),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
        ])
    layers.append(nn.ConvTranspose2d(64, channels, kernel_size=2, stride=2))
    layers.append(out_activation)
    return nn.Sequential(*layers).to(device)


# NOTE(review): the two generators emit different value ranges (Tanh: [-1, 1],
# Sigmoid: [0, 1]); the real images fed to each discriminator should be scaled
# to the same range — confirm against the data loader's normalization.

#The generator of WGAN (Tanh output, values in [-1, 1])
generator = _upsampling_generator(512, nn.Tanh())

#The generator of GAN (Sigmoid output: image intensities are in [0, 1])
generator2 = _upsampling_generator(256, nn.Sigmoid())

In [None]:
class Flatten(nn.Module):
    """Collapse every dimension except the first (batch) dimension.

    Maps a tensor of shape (N, d1, d2, ...) to (N, d1 * d2 * ...) so that
    convolutional feature maps can be fed to an nn.Linear layer.
    """

    def forward(self, x):
        # reshape() returns a view for contiguous inputs (same result as the
        # former view()) but, unlike view(), also accepts non-contiguous ones.
        return x.reshape(x.size(0), -1)

#The discriminator of WGAN (critic): the final Linear has no activation,
# so it outputs an unbounded score per image.
# NOTE(review): the LayerNorm widths (13, 7, 3) and Linear(512, 1) only fit one
# input size / (ks, st, pa) combination — e.g. a 28x28 input with ks=4, st=2,
# pa=0 gives 13 -> 7 -> 3 -> 1 pixels per side; confirm against the config.
# NOTE(review): nn.LayerNorm(k) normalizes over the last dimension only (the
# image width), not over (C, H, W) — confirm this is intended.
discriminator = nn.Sequential(
    # stage 1 -> 64 channels, geometry taken from the global ks/st/pa
    nn.Conv2d(channels, 64, kernel_size=ks, stride=st, padding=pa),
    nn.Dropout(p=0.5), nn.LayerNorm(13), nn.ReLU(),
    # stage 2 -> 128 channels
    nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
    nn.LayerNorm(7), nn.ReLU(),
    # stage 3 -> 256 channels
    nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=pa),
    nn.LayerNorm(3), nn.ReLU(),
    # stage 4 -> 512 channels, no normalization
    nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=pa),
    nn.ReLU(),
    # score head
    Flatten(), nn.Linear(512, 1),
).to(device)

#The discriminator of GAN: three identical Conv -> BatchNorm -> LeakyReLU(0.2)
# stages (all using the global ks/st/pa), then a linear layer and a sigmoid.
# NOTE(review): if `loss` for LOSS == 'BCEWLS' is a with-logits loss
# (e.g. nn.BCEWithLogitsLoss), the final Sigmoid here is applied twice —
# confirm which loss is used.
discriminator2 = nn.Sequential(
    *[
        layer
        for c_in, c_out in ((channels, 64), (64, 128), (128, 512))
        for layer in (
            nn.Conv2d(c_in, c_out, kernel_size=ks, stride=st, padding=pa),
            nn.BatchNorm2d(c_out),
            nn.LeakyReLU(0.2),
        )
    ],
    Flatten(),
    nn.Linear(512, 1),
    nn.Sigmoid(),
).to(device)

In [None]:
#spectral normalization
class SpectralNorm(nn.Module):
    """Spectrally normalize one weight of a wrapped module.

    On construction the module's parameter ``name`` is replaced by three
    registered parameters: ``<name>_bar`` (the trainable, un-normalized
    weight) and ``<name>_u`` / ``<name>_v`` (power-iteration vectors created
    with ``requires_grad=False``).  Before every forward pass ``<name>`` is
    set to ``<name>_bar / sigma``, where ``sigma`` is a power-iteration
    estimate of the largest singular value of ``<name>_bar`` reshaped to
    (shape[0], -1).

    NOTE(review): relies on ``l2normalize`` defined elsewhere in the notebook;
    presumably it scales a vector to unit L2 norm — confirm.

    Args:
        module: layer whose weight is normalized (e.g. an nn.Conv2d).
        name: attribute name of the weight to normalize.
        power_iterations: power-iteration steps run per forward pass.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        # Only create u, v and w_bar if the module does not already have them.
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Refresh u/v by power iteration and set ``module.<name>`` to
        ``<name>_bar / sigma`` with ``sigma = u . (W v)``."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        w_mat = w.view(height, -1)
        for _ in range(self.power_iterations):
            # Updated through .data, i.e. outside of autograd.
            v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data))
            u.data = l2normalize(torch.mv(w_mat.data, v.data))

        # Computed on w itself (not w.data) so the normalized weight stays
        # differentiable with respect to <name>_bar.
        sigma = u.dot(w_mat.mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True if ``<name>_u``, ``<name>_v`` and ``<name>_bar`` all exist."""
        return all(
            hasattr(self.module, self.name + suffix)
            for suffix in ("_u", "_v", "_bar")
        )

    def _make_params(self):
        """Replace ``module.<name>`` by ``<name>_bar`` plus random unit u/v."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Drop the original parameter so that _update_u_v can later assign
        # module.<name> as a plain (normalized) tensor.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args, **kwargs):
        """Normalize the weight, then run the wrapped module's forward."""
        self._update_u_v()
        return self.module.forward(*args, **kwargs)

In [None]:
#Training
# True data is given label 1, while fake data is given label 0. The labels
# never change, so they are built once instead of on every batch.
true_label = torch.ones(batch_size, 1).to(device)
fake_label = torch.zeros(batch_size, 1).to(device)

for epoch in range(num_epochs):
    # Per-batch losses of the current epoch (reset at the start of each epoch).
    batch_d_loss, batch_g_loss = [], []

    for batch_idx, (x, _) in enumerate(train_loader):
        # Skip a trailing partial batch: labels and latent codes are sized
        # batch_size.
        if x.size(0) != batch_size:
            continue
        x_true = x.to(device)

        if LOSS == 'BCEWLS':
            discriminator2.zero_grad()
            generator2.zero_grad()

            # Step 1. Send real data through the discriminator and
            #         backpropagate its error.
            output = discriminator2(x_true)
            error_true = loss(output, true_label)
            error_true.backward()

            # Step 2. Generate fake data G(z), where z ~ N(0, 1) is a latent code.
            z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
            x_fake = generator2(z)

            # Step 3. Send fake data through the discriminator, propagate the
            #         error and update D's weights. detach() keeps this
            #         backward pass out of the generator's gradients.
            output = discriminator2(x_fake.detach())
            error_fake = loss(output, fake_label)
            error_fake.backward()
            discriminator_optim2.step()

            # Step 4. Send fake data through the (updated) discriminator again,
            #         propagate the generator's error and update G's weights.
            #         The gradients this also writes into D are cleared by the
            #         zero_grad() at the start of the next batch.
            output = discriminator2(x_fake)
            error_generator = loss(output, true_label)
            error_generator.backward()
            generator_optim2.step()

            # NOTE(review): records error_true / (error_true + error_fake), a
            # ratio rather than the summed discriminator loss — confirm intended.
            batch_d_loss.append((error_true / (error_true + error_fake)).item())
            batch_g_loss.append(error_generator.item())

        elif LOSS == 'wasserstein':
            # Critic: disc_iters updates per batch (the critic is trained more
            # often than the generator).
            # NOTE(review): no weight clipping is applied here — confirm where
            # the critic's Lipschitz constraint is enforced.
            for _ in range(disc_iters):
                discriminator.zero_grad()
                generator.zero_grad()
                z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
                # detach(): the critic loss must not backpropagate into G.
                x_fake = generator(z).detach()
                error_critic = (
                    -discriminator(x_true).mean() + discriminator(x_fake).mean()
                )
                error_critic.backward()
                discriminator_optim.step()

            # Generator: one update that raises the critic's score of G(z).
            discriminator.zero_grad()
            generator.zero_grad()
            z = torch.randn(batch_size, latent_dim, 1, 1).to(device)
            error_generator = -discriminator(generator(z)).mean()
            error_generator.backward()
            generator_optim.step()

            batch_d_loss.append(error_critic.item())
            batch_g_loss.append(error_generator.item())

        # Global index of this batch across all epochs.
        batches_done = epoch * len(train_loader) + batch_idx

## 2. Experimental results
### 2.1 Results on MNIST

Real images of MNIST:

<img src="images/true_mnist.png" />

Loss and fake images of the standard GAN:

<img src="images/fake_mnist_gan.png" />


Loss and fake images of the WGAN with gradient clipping:

<img src="images/fake_mnist_gc.png" />

Loss and fake images of the WGAN with spectral normalization:

<img src="images/fake_mnist_sn.png" />

### 2.2 Results on CIFAR10

Real images of CIFAR10:

<img src="images/true_cifar.png" />

Fake images of the standard GAN:

<img src="images/fake_cifar_gan.png" />


Fake images of the WGAN with gradient clipping:

<img src="images/fake_cifar_gc.png" />

Fake images of the WGAN with spectral normalization:

<img src="images/fake_cifar_sn.png" />

### 2.3 Quantitative results

Each model (GAN, WGAN-GC and WGAN-SN) was run 3 times on each dataset (MNIST and CIFAR10).

Average FID score:

<img src="images/fid_bar.png" />

Confidence Interval:


<img src="images/CI.png" />