# Image Generation

## 1.1 Generative adversarial network
#### In this exercise, you will implement a Deep Convolutional Generative Adversarial Network (DCGAN) to synthesize images using the provided anime faces dataset.

---


- Construct a <font color=red>DCGAN</font> trained with the GAN objective below. You can refer to the [tutorial website](https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html) provided by PyTorch for the implementation.
$$
    \begin{aligned}
    &\max _{D} \mathcal{L}(D) =\mathbb{E}_{\boldsymbol{x} \sim p_{\text {data }}} \log D(\boldsymbol{x})+\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}} \log (1-D(G(\boldsymbol{z}))) \\
    &\min _{G} \mathcal{L}(G) =\mathbb{E}_{\boldsymbol{z} \sim p_{\boldsymbol{z}}} \log (1-D(G(\boldsymbol{z})))
    \end{aligned}
$$
- <font color=red>Draw</font> some samples generated by your generator at <font color=red>different training stages</font>. For example, you may show the results after the $5^{\text{th}}$ epoch and after the final epoch (epoch 100). (10\%)


<center>
    <img src="https://i.imgur.com/tnRR3tr.png" width="350px" />
    <img src="https://i.imgur.com/g9AnDwN.png" width="350px" />
</center>
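The training loops below express this objective through `torch.nn.BCELoss`. The following minimal check uses made-up discriminator outputs (`d_real`, `d_fake` are illustrative values, not model outputs) to show that the BCE terms equal the negated log-likelihoods above. Note that the generator step targets the "real" label, so it minimizes $-\log D(G(\boldsymbol{z}))$, the non-saturating variant suggested in the original GAN paper and used in the PyTorch tutorial, instead of $\log(1-D(G(\boldsymbol{z})))$, which gives weak gradients early in training.

```python
import torch

criterion = torch.nn.BCELoss()             # averages over the batch by default
d_real = torch.tensor([0.9, 0.8, 0.7])     # hypothetical D(x) on real images
d_fake = torch.tensor([0.2, 0.1, 0.3])     # hypothetical D(G(z)) on generated images

# Discriminator: maximizing log D(x) + log(1 - D(G(z))) is minimizing this BCE sum
loss_d = criterion(d_real, torch.ones(3)) + criterion(d_fake, torch.zeros(3))
manual_d = -(torch.log(d_real).mean() + torch.log(1 - d_fake).mean())

# Generator (non-saturating form): minimize -log D(G(z)) by labelling fakes as real
loss_g = criterion(d_fake, torch.ones(3))
manual_g = -torch.log(d_fake).mean()

print(loss_d.item(), manual_d.item())
print(loss_g.item(), manual_g.item())
```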





In [None]:
# # Download and unzip data
# !gdown 1K1oB7GOUerTCIa68bbxETcGajLeE_5j1
# !unzip resized_64x64.zip

In [None]:
import random
import torch
import torchvision
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np
from IPython.display import HTML
import tqdm

### Please write the GAN code here

In [None]:
# Please write the GAN code here
# Note: In our experience, training on a subset of around 10000 images already gives acceptable results.

seed = 42
random.seed(seed)
torch.manual_seed(seed)

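data_path = './data/'  # ImageFolder expects the images inside a sub-folder, e.g. ./data/resized_64x64/
batch_size = 128
image_size = 64

num_channels = 3            # RGB images
num_latents = 100           # length of the latent vector z
size_feature_map_gen = 64   # base number of feature maps in the generator
size_feature_map_disc = 64  # base number of feature maps in the discriminator
lr = 0.0002
beta1 = 0.5                 # Adam beta1, as in the DCGAN paper
workers = 0                 # DataLoader worker processes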

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dataset = torchvision.datasets.ImageFolder(
    root=data_path,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_size),
        torchvision.transforms.CenterCrop(image_size),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
    ])
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers
)
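The note at the top of this cell says roughly 10,000 images are enough. One way to train on such a subset, sketched here with a hypothetical helper `take_random_subset` (not part of the assignment code), is to wrap the `ImageFolder` dataset in `torch.utils.data.Subset` using a seeded random permutation:

```python
import torch
import torch.utils.data


def take_random_subset(dataset, n=10000, seed=42):
    """Return a reproducible random subset of at most n samples (hypothetical helper)."""
    n = min(n, len(dataset))
    g = torch.Generator().manual_seed(seed)                 # fixed seed -> same subset every run
    indices = torch.randperm(len(dataset), generator=g)[:n].tolist()
    return torch.utils.data.Subset(dataset, indices)
```

To use it, pass `take_random_subset(dataset)` instead of `dataset` when building `dataloader`; shuffling within the subset is still handled by `shuffle=True`.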

In [None]:
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(torchvision.utils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

In [None]:
def init_weights(m):
    if type(m) == torch.nn.Conv2d or type(m) == torch.nn.ConvTranspose2d:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif type(m) == torch.nn.BatchNorm2d:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

In [None]:
class Generator(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.main = torch.nn.Sequential(
            # input is Z, going into a convolution
            torch.nn.ConvTranspose2d(num_latents, size_feature_map_gen * 8, 4, 1, 0, bias=False),
            torch.nn.BatchNorm2d(size_feature_map_gen * 8),
            torch.nn.ReLU(True),
            # state size. (size_feature_map_gen*8) x 4 x 4
            torch.nn.ConvTranspose2d(size_feature_map_gen * 8, size_feature_map_gen * 4, 4, 2, 1, bias=False),
            torch.nn.BatchNorm2d(size_feature_map_gen * 4),
            torch.nn.ReLU(True),
            # state size. (size_feature_map_gen*4) x 8 x 8
            torch.nn.ConvTranspose2d(size_feature_map_gen * 4, size_feature_map_gen * 2, 4, 2, 1, bias=False),
            torch.nn.BatchNorm2d(size_feature_map_gen * 2),
            torch.nn.ReLU(True),
            # state size. (size_feature_map_gen*2) x 16 x 16
            torch.nn.ConvTranspose2d(size_feature_map_gen * 2, size_feature_map_gen, 4, 2, 1, bias=False),
            torch.nn.BatchNorm2d(size_feature_map_gen),
            torch.nn.ReLU(True),
            # state size. (size_feature_map_gen) x 32 x 32
            torch.nn.ConvTranspose2d(size_feature_map_gen, num_channels, 4, 2, 1, bias=False),
            torch.nn.Tanh()
            # state size. (num_channels) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

G_net = Generator().to(device)
G_net.apply(init_weights)
print(G_net)
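The `# state size` comments in the generator follow from the output-size formula of `ConvTranspose2d` with dilation 1 and no output padding. The illustrative helper `conv_transpose_out` (a name introduced here, not a PyTorch function) reproduces the progression from the 1 x 1 latent input to a 64 x 64 image:

```python
def conv_transpose_out(size, kernel, stride, padding):
    # ConvTranspose2d with dilation=1, output_padding=0:
    # out = (in - 1) * stride - 2 * padding + kernel
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) of the five ConvTranspose2d layers in Generator
gen_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
gen_sizes = [1]  # z enters as a (num_latents) x 1 x 1 map
for k, s, p in gen_layers:
    gen_sizes.append(conv_transpose_out(gen_sizes[-1], k, s, p))
print(gen_sizes)  # [1, 4, 8, 16, 32, 64]
```

The first layer (stride 1, no padding) expands 1 x 1 to 4 x 4; each later layer with kernel 4, stride 2, padding 1 exactly doubles the size.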

In [None]:
class Discriminator(torch.nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = torch.nn.Sequential(
            # input is (num_channels) x 64 x 64
            torch.nn.Conv2d(num_channels, size_feature_map_disc, 4, 2, 1, bias=False),
            torch.nn.LeakyReLU(0.2, inplace=True),
            # state size. (size_feature_map_disc) x 32 x 32
            torch.nn.Conv2d(size_feature_map_disc, size_feature_map_disc * 2, 4, 2, 1, bias=False),
            torch.nn.BatchNorm2d(size_feature_map_disc * 2),
            torch.nn.LeakyReLU(0.2, inplace=True),
            # state size. (size_feature_map_disc*2) x 16 x 16
            torch.nn.Conv2d(size_feature_map_disc * 2, size_feature_map_disc * 4, 4, 2, 1, bias=False),
            torch.nn.BatchNorm2d(size_feature_map_disc * 4),
            torch.nn.LeakyReLU(0.2, inplace=True),
            # state size. (size_feature_map_disc*4) x 8 x 8
            torch.nn.Conv2d(size_feature_map_disc * 4, size_feature_map_disc * 8, 4, 2, 1, bias=False),
            torch.nn.BatchNorm2d(size_feature_map_disc * 8),
            torch.nn.LeakyReLU(0.2, inplace=True),
            # state size. (size_feature_map_disc*8) x 4 x 4
            torch.nn.Conv2d(size_feature_map_disc * 8, 1, 4, 1, 0, bias=False),
            torch.nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
    
D_net = Discriminator().to(device)
D_net.apply(init_weights)
print(D_net)
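The discriminator's `# state size` comments follow from the `Conv2d` output-size formula with dilation 1. `conv_out` below is an illustrative helper, not part of the assignment code; it shows how five convolutions shrink a 64 x 64 image to a single score:

```python
def conv_out(size, kernel, stride, padding):
    # Conv2d with dilation=1: out = floor((in + 2 * padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of the five Conv2d layers in Discriminator
disc_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]
disc_sizes = [64]
for k, s, p in disc_layers:
    disc_sizes.append(conv_out(disc_sizes[-1], k, s, p))
print(disc_sizes)  # [64, 32, 16, 8, 4, 1]
```

The final 4 x 4 convolution without padding collapses the 4 x 4 map to 1 x 1, which `Sigmoid` turns into a probability.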

In [None]:
criterion = torch.nn.BCELoss()
fixed_noise = torch.randn(64, num_latents, 1, 1, device=device)
real_label = 1.
fake_label = 0.

d_optimizer = torch.optim.Adam(D_net.parameters(), lr=lr, betas=(beta1, 0.999))
g_optimizer = torch.optim.Adam(G_net.parameters(), lr=lr, betas=(beta1, 0.999))

img_list = []
G_losses = []
D_losses = []
iters = 0
temp_epochs, final_epochs = 5, 100

In [None]:
for epoch in range(temp_epochs):
    temp_gloss, temp_dloss = 0, 0
    with tqdm.tqdm(total=len(dataloader)) as pbar:
        for i, data in enumerate(dataloader, 0):

            D_net.zero_grad()

            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

            output = D_net(real_cpu).view(-1)

            errD_real = criterion(output, label)

            errD_real.backward()
            D_x = output.mean().item()

            noise = torch.randn(b_size, num_latents, 1, 1, device=device)

            fake = G_net(noise)
            label.fill_(fake_label)
            
            output = D_net(fake.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            d_optimizer.step()

            G_net.zero_grad()
            label.fill_(real_label)
            output = D_net(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            g_optimizer.step()

            temp_gloss += errG.item()
            temp_dloss += errD.item()

            if (iters % 500 == 0) or ((epoch == temp_epochs-1) and (i == len(dataloader)-1)):
                with torch.no_grad():
                    fake = G_net(fixed_noise).detach().cpu()
                img_list.append(torchvision.utils.make_grid(fake, padding=2, normalize=True))

            iters += 1

            pbar.set_description(f"Epoch {epoch}")
            pbar.set_postfix(Loss_D=errD.item(), Loss_G=errG.item(), D_x=D_x, D_G_z1=D_G_z1, D_G_z2=D_G_z2)
            pbar.update(1)

        # errG and errD are already batch means (BCELoss reduction='mean'), so average over batches only
        G_losses.append(temp_gloss / len(dataloader))
        D_losses.append(temp_dloss / len(dataloader))
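Two details in this loop are easy to miss: `fake.detach()` in the discriminator step stops gradients from reaching the generator, while the generator step reuses `fake` without detaching, so gradients flow through `D_net` into `G_net`. The toy sketch below, with tiny linear layers standing in for `G_net` and `D_net` (an assumption made only to keep it small), demonstrates this:

```python
import torch

torch.manual_seed(0)
G = torch.nn.Linear(4, 8)                                            # toy "generator"
D = torch.nn.Sequential(torch.nn.Linear(8, 1), torch.nn.Sigmoid())   # toy "discriminator"
criterion = torch.nn.BCELoss()

z = torch.randn(16, 4)
fake = G(z)

# Discriminator step: detach() cuts the graph, so no gradient reaches G
loss_d = criterion(D(fake.detach()).view(-1), torch.zeros(16))
loss_d.backward()
g_grad_after_d_step = G.weight.grad          # still None at this point

# Generator step: no detach, so gradients pass through D into G
loss_g = criterion(D(fake).view(-1), torch.ones(16))
loss_g.backward()
print(g_grad_after_d_step is None, G.weight.grad is not None)
```

Detaching also avoids wasting compute on generator gradients that the discriminator optimizer would never use.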

In [None]:
plt.figure(figsize=(5,5))
plt.title("Generator and Discriminator Loss During Training with 5 epochs")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("Epochs")
plt.ylabel("Binary Cross Entropy Loss")
plt.legend()
plt.show()

In [None]:
plt.figure(figsize=(8,8))
plt.title("Fake Images with 5 Epochs")
plt.axis("off")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.savefig("fake_images_5.png")

In [None]:
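# Re-initialize the generator and discriminator, their optimizers, and the bookkeeping
# so that the 100-epoch run starts from a fresh state
G_net.apply(init_weights)
D_net.apply(init_weights)
d_optimizer = torch.optim.Adam(D_net.parameters(), lr=lr, betas=(beta1, 0.999))
g_optimizer = torch.optim.Adam(G_net.parameters(), lr=lr, betas=(beta1, 0.999))
G_losses = []
D_losses = []
img_list = []
iters = 0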

In [None]:
for epoch in range(final_epochs):
    temp_gloss, temp_dloss = 0, 0
    with tqdm.tqdm(total=len(dataloader)) as pbar:
        for i, data in enumerate(dataloader, 0):

            D_net.zero_grad()

            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

            output = D_net(real_cpu).view(-1)

            errD_real = criterion(output, label)

            errD_real.backward()
            D_x = output.mean().item()

            noise = torch.randn(b_size, num_latents, 1, 1, device=device)

            fake = G_net(noise)
            label.fill_(fake_label)
            
            output = D_net(fake.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            d_optimizer.step()

            G_net.zero_grad()
            label.fill_(real_label)
            output = D_net(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            g_optimizer.step()

            temp_gloss += errG.item()
            temp_dloss += errD.item()

            if (iters % 500 == 0) or ((epoch == final_epochs-1) and (i == len(dataloader)-1)):
                with torch.no_grad():
                    fake = G_net(fixed_noise).detach().cpu()
                img_list.append(torchvision.utils.make_grid(fake, padding=2, normalize=True))

            iters += 1

            pbar.set_description(f"Epoch {epoch}")
            pbar.set_postfix(Loss_D=errD.item(), Loss_G=errG.item(), D_x=D_x, D_G_z1=D_G_z1, D_G_z2=D_G_z2)
            pbar.update(1)

        # errG and errD are already batch means (BCELoss reduction='mean'), so average over batches only
        G_losses.append(temp_gloss / len(dataloader))
        D_losses.append(temp_dloss / len(dataloader))

In [None]:
plt.figure(figsize=(5,5))
plt.title("Generator and Discriminator Loss During Training with 100 epochs")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("Epochs")
plt.ylabel("Binary Cross Entropy Loss")
plt.legend()
plt.savefig("losses_100.png")  # save before show(), which can leave an empty figure behind
plt.show()

In [None]:
plt.figure(figsize=(8,8))
plt.title("Fake Images with 100 Epochs")
plt.axis("off")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.savefig("fake_images_100.png")

In [None]:
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
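ani.save("animation.gif", writer="pillow")  # Pillow is already a matplotlib dependency; "imagemagick" needs a separate install
HTML(ani.to_jshtml())  # Jupyter only displays the last expression of a cell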

### 1.1.a 
### Draw some samples generated by your generator at different training stages. For example, you may show the results after the 5th epoch and after the final epoch (epoch 100).

<center>
    <img src = "./image/fake_images_5.png" width="49%">
    <img src = "./image/fake_images_100.png" width="49%">
</center>

### 1.1.b
### The Helvetica scenario often occurs during GAN training. Please explain why this problem occurs and how to avoid it. We suggest reading the original paper and basing your discussion on it.
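As background for this question: in the original GAN paper, the Helvetica scenario is the generator mapping too many different $\boldsymbol{z}$ to the same output, commonly called mode collapse. One mitigation from later work is to let the discriminator see batch-level statistics, so that a batch of near-identical fakes is easy to reject. The sketch below is a simplified minibatch standard-deviation layer in the spirit of Karras et al. (2018); it is an illustration only, not part of the provided DCGAN code, and the class name `MinibatchStdDev` is introduced here.

```python
import torch


class MinibatchStdDev(torch.nn.Module):
    """Append one feature map holding the average per-position std across the batch."""

    def forward(self, x):
        # Standard deviation over the batch for every (channel, h, w) position, shape (C, H, W)
        std = ((x - x.mean(dim=0, keepdim=True)) ** 2).mean(dim=0).add(1e-8).sqrt()
        # Average into one scalar and broadcast it as an extra channel for every sample
        extra = std.mean().expand(x.size(0), 1, x.size(2), x.size(3))
        return torch.cat((x, extra), dim=1)


x = torch.randn(8, 4, 4, 4)
print(MinibatchStdDev()(x).shape)  # torch.Size([8, 5, 4, 4])
```

To try it (an assumption about placement), one could insert `MinibatchStdDev()` before the last `Conv2d` of `Discriminator` and change that layer's input channels to `size_feature_map_disc * 8 + 1`. Because real and fake batches are passed to `D_net` separately, a collapsed fake batch produces a much smaller extra channel than a real batch.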


## 1.2 Denoising Diffusion Probabilistic Model (30%)

#### In this exercise, you will implement a <font color=red>Denoising Diffusion Probabilistic Model (DDPM)</font> to generate images using the provided <font color=red>anime faces dataset</font>. The figure below illustrates the diffusion model: a forward process gradually adds noise to an image, and a reverse process transforms the noise back into a sample from the target distribution. See [link1](https://lilianweng.github.io/posts/2021-07-11-diffusion-models/) and [link2](https://www.youtube.com/watch?v=azBugJzmz-o&t=190s) for detailed introductions to diffusion models.

<center>
  <img src="https://i.imgur.com/BqpRi4v.png"/>
</center>
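The forward process has a closed form: with $\alpha_t = 1-\beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$, a noisy sample at step $t$ can be drawn directly as $\boldsymbol{x}_t = \sqrt{\bar{\alpha}_t}\,\boldsymbol{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}$ with $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$. A minimal sketch, assuming the same linear $\beta$ schedule ($10^{-4}$ to $2\times10^{-2}$) as the training code; `q_sample` is an illustrative name and the indices here are 0-based:

```python
import torch


def q_sample(x0, t, alpha_bar, eps=None):
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I) in one step."""
    if eps is None:
        eps = torch.randn_like(x0)
    ab = alpha_bar[t].reshape(-1, 1, 1, 1)   # one noise level per image in the batch
    return ab.sqrt() * x0 + (1 - ab).sqrt() * eps


T = 500
betas = torch.linspace(1e-4, 2e-2, T)        # linear schedule
alpha_bar = torch.cumprod(1 - betas, dim=0)  # abar for steps 1..T stored at 0..T-1

x0 = torch.rand(4, 3, 64, 64) * 2 - 1        # stand-in batch of images scaled to [-1, 1]
t = torch.tensor([0, 100, 250, T - 1])       # a different step for each image
xt = q_sample(x0, t, alpha_bar)
print(xt.shape)
```

Because of this closed form, training never has to run the forward chain step by step: each batch picks random steps and noises the clean images directly.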

1. Construct <font color='blue'>DDPM</font> by completing the <font color='red'>2 TODOs</font> and following the instructions. Note that you are not allowed to load the model by directly calling a library or API. The total number of epochs is 10. (20\%)

  (a) **Draw** some generated samples based on diffusion steps $T = 500$ and $T = 1000$. We provide the **pre-trained weights**, trained with 500 and 1000 steps respectively. Hint: In the paper, the steps start at 1.

  (b) **Discuss** the result based on different diffusion steps.

### Training (You can skip this)

- Note that because training a diffusion model requires powerful hardware, Colab may not be suitable. We therefore provide the training code for reference only.

In [None]:
# !gdown 1E8yulcTDMk9dvz2dJ_TniLKdU4n6AFwa
# !gdown 1g_RYSP1A2rXg_ud18ARlWXK8BWhiHdjV

In [None]:
# !pip install torchmetrics

In [None]:
import torch, sys
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm, trange
from model import Unet
from torchmetrics import MeanMetric
from dataloader import get_loader

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device) #make sure this is cuda

In [None]:
T = 500
ALPHA = 1-torch.linspace(1e-4, 2e-2, T)
def alpha(t):
    at = torch.prod(ALPHA[:t]).reshape((1, ))
    return torch.sqrt(torch.cat((at, 1-at)))
ALPHA_bar = torch.stack([alpha(t) for t in range(T)]).to(device)

batch_size = 32
update_step = 1
save_step = 2
save_step_ = 20
num_workers = 6
epochs = 100
loss_func = torch.nn.MSELoss()
lr = 5e-4
model = Unet(
    in_channels=3
)
# state_dict = torch.load('checkpoint_100epoch_T500.pth')  # not used by train(); load it only to start from the pre-trained weights
optimizer = Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)
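The `alpha(t)` loop above builds row $t$ of `ALPHA_bar` as $[\sqrt{\bar\alpha}, \sqrt{1-\bar\alpha}]$ with $\bar\alpha$ = `prod(ALPHA[:t])`, so row 0 uses an empty product ($\bar\alpha = 1$, i.e. no noise). The same table can be built in $O(T)$ instead of $O(T^2)$ with an exclusive cumulative product. This sketch, run on the CPU, checks the two constructions against each other:

```python
import torch

T = 500
ALPHA = 1 - torch.linspace(1e-4, 2e-2, T)


# Loop version from the training cell: row t uses abar = prod(ALPHA[:t])
def alpha(t):
    at = torch.prod(ALPHA[:t]).reshape((1,))
    return torch.sqrt(torch.cat((at, 1 - at)))


loop_table = torch.stack([alpha(t) for t in range(T)])

# Vectorized version: exclusive cumulative product (prepend 1, drop the last element)
abar = torch.cat((torch.ones(1), torch.cumprod(ALPHA, dim=0)[:-1]))
vec_table = torch.stack((abar.sqrt(), (1 - abar).sqrt()), dim=1)

print(torch.allclose(loop_table, vec_table, atol=1e-5))
```

The tolerance accounts for float32 rounding, since `torch.prod` and `torch.cumprod` may multiply in a different order.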

In [None]:
def train(model, data_loader):
    running_loss = MeanMetric(accumulate=True)

    model.train()
    optimizer.zero_grad()

    for epoch in (overall:=trange(1, epochs+1, position=1, desc='[Overall]')):
        running_loss.reset()

        for i, X_0 in enumerate(bar := tqdm(data_loader, position=0, desc=f'[Train {epoch:3d}] lr={scheduler.get_last_lr()[0]:.2e}'), start=1):
            X_0 = X_0.to(device)
            eps = torch.randn(X_0.shape, device=device)
            t = torch.randint(0, T, (X_0.shape[0], ), device=device)

            # print(ALPHA_bar[t, 0].reshape(-1, 1, 1, 1)*X_0)
            with torch.no_grad():
                X_noise = ALPHA_bar[t, 0].reshape(-1, 1, 1, 1)*X_0 + ALPHA_bar[t, 1].reshape(-1, 1, 1, 1)*eps
            # X_noise = X_noise.to(device)
            t = t.to(device)
            
            pred = model(X_noise, t+1)

            loss = loss_func(eps, pred)
            loss.backward()

            if i%update_step == 0 or i == bar.total:
                optimizer.step()
                optimizer.zero_grad()

            running_loss.update(loss.item())
            bar.set_postfix_str(f'loss {running_loss.compute():.2e}')

        scheduler.step()
        tqdm.write('\r\033[K', end='')

        if epoch % save_step == 0:
            save_checkpoint(epoch, model, optimizer, 'checkpoint.pth')
        if epoch % save_step_ == 0:
            save_checkpoint(epoch, model, optimizer, f'checkpoint_{epoch}.pth')    

def save_checkpoint(epoch, model, optimizer, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    }, path)
    tqdm.write('Save checkpoint')

In [None]:
# Training start
train_loader = get_loader(
    './data/resized_64x64/',
    batch_size=batch_size, 
    num_workers=num_workers,
)
model = model.to(device)
train(model, train_loader)

### Sampling

In [None]:
# !gdown 1E8yulcTDMk9dvz2dJ_TniLKdU4n6AFwa 
# # dataloader.py
# !gdown 1g_RYSP1A2rXg_ud18ARlWXK8BWhiHdjV 
# # model.py
# !gdown 1n9K-HSY3GJKTS4HkHTCA_AJT0q1cKBZ1 
# # checkpoint_epoch100_T1000.pth
# !gdown 1jPycQFo_f_fPRUg6OuauTrsXbdibTvKI 
# # checkpoint_epoch100_T500.pth

In [None]:
import torch
from tqdm import trange
from torchvision.utils import save_image
from model import Unet

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)  # CUDA is recommended
T1 = 500    # diffusion steps; must match the checkpoint loaded into model1
model1 = Unet(in_channels=3).to(device)
state_dict1 = torch.load('checkpoint_100epoch_T500.pth', map_location=device)
model1.load_state_dict(state_dict1['model_state_dict'])

T2 = 1000   # diffusion steps; must match the checkpoint loaded into model2
model2 = Unet(in_channels=3).to(device)
state_dict2 = torch.load('checkpoint_100epoch_T1000.pth', map_location=device)
model2.load_state_dict(state_dict2['model_state_dict'])

<center>
    <img src = "./image/DDPM.png">
</center>
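For reference, the training objective and the sampling step of the DDPM paper (Ho et al., 2020, Algorithms 1 and 2) can be written as follows, with $\alpha_t = 1-\beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s$; $\sigma_t^2 = \beta_t$ is one of the two variance choices discussed in the paper.

$$
\begin{aligned}
&\textbf{Training: } \min_{\theta}\; \mathbb{E}_{t, \boldsymbol{x}_0, \boldsymbol{\epsilon}} \left\| \boldsymbol{\epsilon} - \boldsymbol{\epsilon}_{\theta}\!\left(\sqrt{\bar{\alpha}_t}\,\boldsymbol{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon},\, t\right) \right\|^2, \quad t \sim \mathrm{Uniform}\{1,\dots,T\},\; \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\
&\textbf{Sampling: } \boldsymbol{x}_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(\boldsymbol{x}_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\boldsymbol{\epsilon}_{\theta}(\boldsymbol{x}_t, t)\right) + \sigma_t \boldsymbol{z}, \quad \boldsymbol{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \text{ if } t > 1, \text{ else } \boldsymbol{z} = \mathbf{0}
\end{aligned}
$$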

In [None]:
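@torch.no_grad()
def generate_and_save(model, gen_N, T=500, chan=3, resolu=(64, 64), path='L.jpg'):
    # DDPM ancestral sampling (Ho et al., 2020, Algorithm 2); T must match the checkpoint
    model.eval()
    # Linear schedule as in training: alpha_t = 1 - beta_t, beta_t from 1e-4 to 2e-2.
    # Index t = 0..T-1 holds alpha_{t+1}, because the paper's steps start at 1.
    ALPHA = 1 - torch.linspace(1e-4, 2e-2, T, device=device)
    ALPHA_bar = torch.cumprod(ALPHA, dim=0)
    L = []
    # Sample Gaussian noise X_T and denoise it
    X_t = torch.randn(gen_N, chan, *resolu, device=device)
    L.append(X_t)

    for t in (bar := trange(T-1, -1, -1)):
        bar.set_description(f'[Denoising] step: {t+1}')
        # Predict the noise at step t+1, passing one time index per image as in training
        t_batch = torch.full((gen_N,), t+1, dtype=torch.long, device=device)
        eps_pred = model(X_t, t_batch)
        # Mean of p(x_{t-1} | x_t)
        X_t = (X_t - (1-ALPHA[t]) / torch.sqrt(1-ALPHA_bar[t]) * eps_pred) / torch.sqrt(ALPHA[t])
        # Add noise with variance sigma_t^2 = beta_t, except at the final step
        if t > 0:
            X_t = X_t + torch.sqrt(1-ALPHA[t]) * torch.randn_like(X_t)

    L.append(X_t)
    # Save the initial noise and the generated images, mapping [-1, 1] to [0, 1]
    save_image(torch.cat(L)/2+0.5, path)

### 1.2.a 
#### **Draw** some generated samples based on diffusion steps $T = 500$ and $T = 1000$. We provide the **pre-trained weights**, trained with 500 and 1000 steps respectively. Hint: In the paper, the steps start at 1.

In [None]:
# gen_N sets the number of output images.
# resolu sets the output image resolution; you should not change it.
# The function saves the sample images automatically; you only need to display them here.
generate_and_save(model1, gen_N=64, T=T1, resolu=(64, 64), path='L_T500.jpg')

In [None]:
generate_and_save(model2, gen_N=64, T=T2, resolu=(64, 64), path='L_T1000.jpg')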

### 1.2.b
#### **Discuss** the result based on different diffusion steps.
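One quantity that may help this discussion is how much of the original image survives at the last forward step. With the same linear $\beta$ range, the two settings reach different $\bar{\alpha}_T$; the helper `final_signal_level` is introduced here for illustration and assumes the schedule from the training code:

```python
import torch


def final_signal_level(T, beta_start=1e-4, beta_end=2e-2):
    """Return (alpha_bar_T, sqrt(alpha_bar_T)) for a linear beta schedule with T steps."""
    betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float64)
    alpha_bar_T = torch.prod(1 - betas).item()
    return alpha_bar_T, alpha_bar_T ** 0.5


for steps in (500, 1000):
    abar, coef = final_signal_level(steps)
    print(f"T={steps}: alpha_bar_T={abar:.2e}, sqrt(alpha_bar_T)={coef:.3f}")
```

A larger $\bar{\alpha}_T$ means $\boldsymbol{x}_T$ during training still contains some signal, so it matches the pure $\mathcal{N}(\mathbf{0}, \mathbf{I})$ starting point used at sampling less closely; a larger $T$ narrows that gap but needs one network evaluation per step, so sampling takes about twice as long. These are factors to weigh against the samples you observe, not a prediction of which setting looks better.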

## 1.3 Comparison between GAN and DDPM (10%)
#### (a) Both GAN and DDPM are generative models. The following figures are randomly generated results by using GAN (left) and DDPM (right). Please describe the pros and cons of the two models. (10%)

<center>
    <img src="https://i.imgur.com/pU77cfa.jpg" width="600px"/>
</center>

### 1.3.a
### Both GAN and DDPM are generative models. The figures are randomly generated results by using GAN (left) and DDPM (right). Please describe the pros and cons of the two models based on your observation.