### Import Packages

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm

### Config

In [None]:
# Select the compute device once: CUDA when PyTorch can see a GPU, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Weight-clipping bound: after every discriminator update in train_gan, each
# discriminator parameter is clamped in place to [-clip_value, clip_value].
clip_value = 0.01

### Data Loader

In [None]:
# Preprocessing pipeline applied to every MNIST image:
#   * Resize(64)  - upscale to the 64x64 resolution the networks operate on
#   * ToTensor    - convert to a float tensor in [0, 1]
#   * Normalize   - (x - 0.5) / 0.5, i.e. rescale to [-1, 1] to match the
#                   generator's Tanh output range
# NOTE: bound to `transform` (singular) so the imported `transforms` module is
# not overwritten; rebinding the module name made this cell fail on re-run.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5,], [0.5,])
])

# Shuffled mini-batches of 64 images from the MNIST training split
# (downloaded into ./data on first use).
dataloader = DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=64,
    shuffle=True
)

### WGAN

#### Generator

In [None]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    """DCGAN-style generator mapping a noise vector to a 64x64 image.

    The network is a stack of transposed convolutions: one projection layer
    that turns the noise vector into a 4x4 feature map, three upsampling
    stages that double the spatial size each time, and a final upsampling
    layer followed by Tanh so that outputs lie in [-1, 1].
    """

    def __init__(self, noise_dim=100, img_channels=1, feature_g=64):
        """
        Args:
            noise_dim: Dimension of input noise vector (typically 100)
            img_channels: Number of image channels (1 for grayscale, 3 for RGB)
            feature_g: Base number of feature maps in generator
        """
        super().__init__()
        self.noise_dim = noise_dim

        # Channel widths after each hidden stage: 4x4, 8x8, 16x16, 32x32.
        widths = [feature_g * mult for mult in (16, 8, 4, 2)]

        # Projection: (noise_dim, 1, 1) -> (widths[0], 4, 4).
        layers = [
            nn.ConvTranspose2d(noise_dim, widths[0], kernel_size=4, stride=1,
                               padding=0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(inplace=True),
        ]

        # Each stage halves the channel count and doubles height/width.
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2,
                                   padding=1, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))

        # Output head: (widths[-1], 32, 32) -> (img_channels, 64, 64) in [-1, 1].
        layers.append(
            nn.ConvTranspose2d(widths[-1], img_channels, kernel_size=4,
                               stride=2, padding=1, bias=False)
        )
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Generate images from noise of shape (batch, noise_dim)."""
        # The transposed convolutions expect a 4-D input, so view the noise
        # as a 1x1 "image" with noise_dim channels.
        latent = x.view(-1, self.noise_dim, 1, 1)
        return self.model(latent)

#### Discriminator

In [None]:
class Discriminator(nn.Module):
    """WGAN critic scoring 64x64 images with an unbounded real value.

    Four stride-2 convolutions halve the spatial size from 64x64 down to 4x4,
    then a final 4x4 convolution collapses each image to a single score.

    The output is deliberately NOT passed through a sigmoid: the Wasserstein
    losses used for training (mean critic score on real vs. fake images) need
    an unbounded score, not a probability. Squashing it to (0, 1) caps the
    score difference at 1 and saturates the gradients.
    """

    def __init__(self, img_channels=1, feature_d=64):
        """
        Args:
            img_channels: Number of image channels (1 for grayscale, 3 for RGB)
            feature_d: Base number of feature maps in discriminator
        """
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            # Input: img_channels x 64 x 64
            nn.Conv2d(img_channels, feature_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: feature_d x 32 x 32

            nn.Conv2d(feature_d, feature_d * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_d*2) x 16 x 16

            nn.Conv2d(feature_d * 2, feature_d * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_d*4) x 8 x 8

            nn.Conv2d(feature_d * 4, feature_d * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_d*8) x 4 x 4

            # Linear critic head: no activation, the raw score is the output.
            nn.Conv2d(feature_d * 8, 1, 4, 1, 0, bias=False),
            # Output: 1 x 1 x 1
        )

    def forward(self, x):
        """Return critic scores of shape (batch, 1) for a batch of images."""
        output = self.model(x)
        return output.view(-1, 1)  # Flatten to (batch, 1)

### Create WGAN

In [None]:
# Hyper-parameters shared by the models and the training loop.
noise_dim, img_channel = 100, 1

# Networks: built generator first, then discriminator, and moved to `device`.
generator = Generator(noise_dim=noise_dim, img_channels=img_channel)
generator = generator.to(device)
discriminator = Discriminator(img_channels=img_channel)
discriminator = discriminator.to(device)

# One RMSprop optimizer per network, both with the same learning rate.
optimizer_G = optim.RMSprop(generator.parameters(), lr=2e-4)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=2e-4)

### Utils

In [None]:
import os
os.makedirs('WGAN_output_g1_d5_clip_0.01', exist_ok=True)


def show_generated_images(epoch, generator, fixed_noise,
                          output_dir='WGAN_output_g1_d5_clip_0.01'):
    """Render, save and display a grid of generator samples.

    Args:
        epoch: Epoch number, used in the plot title and the output file name.
        generator: Module called as ``generator(fixed_noise)``; its output is
            reshaped to (-1, 1, 64, 64).
        fixed_noise: Noise batch fed to the generator. Passing the same batch
            every epoch makes the saved grids comparable across epochs.
        output_dir: Directory the PNG is written to (created if missing).

    Side effects:
        Writes ``<output_dir>/generated_images_epoch_<epoch>.png``, shows the
        figure, then closes it. The generator is put in eval mode while
        sampling and restored to whatever mode it was in on entry.
    """
    # Remember the caller's mode instead of unconditionally forcing train().
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake = generator(fixed_noise).reshape(-1, 1, 64, 64)
            # Undo the (x - 0.5) / 0.5 normalisation: assuming outputs in
            # [-1, 1] (Tanh), this maps them to [0, 1] for display.
            fake = fake * 0.5 + 0.5
    finally:
        generator.train(was_training)

    grid = torchvision.utils.make_grid(fake, nrow=8)

    os.makedirs(output_dir, exist_ok=True)
    fig = plt.figure(figsize=(8, 8))
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')
    plt.title(f'Generated Images at Epoch {epoch}')
    plt.savefig(f'{output_dir}/generated_images_epoch_{epoch}.png')
    plt.show()
    # Close explicitly: this runs once per epoch and open figures otherwise
    # accumulate in memory on non-inline backends.
    plt.close(fig)

### Train

In [None]:
def train_gan(train_dataloader, num_epochs, g_step=1, d_step=1):
    """Train the module-level WGAN with weight clipping.

    Uses the module-level ``generator``, ``discriminator``, ``optimizer_G``,
    ``optimizer_D``, ``noise_dim``, ``clip_value`` and ``device``. For every
    batch it runs ``d_step`` discriminator updates followed by ``g_step``
    generator updates, and at the end of each epoch prints the average losses
    and calls ``show_generated_images`` on a fixed noise batch.

    Args:
        train_dataloader: Iterable yielding ``(images, labels)`` batches;
            labels are ignored.
        num_epochs: Number of passes over ``train_dataloader``.
        g_step: Generator updates per batch.
        d_step: Discriminator updates per batch.

    Returns:
        ``(d_losses, g_losses)``: two lists with one per-epoch average loss
        each. An entry is NaN when no update of that kind ran in the epoch
        (step count of 0 or an empty dataloader).
    """
    # Same noise every epoch so the sample grids are comparable over time.
    fixed_noise = torch.randn(64, noise_dim).to(device)

    # Store epoch-wise average losses
    d_losses = []
    g_losses = []

    for epoch in range(num_epochs):

        d_loss_epoch = 0.0
        g_loss_epoch = 0.0

        d_count = 0
        g_count = 0

        pbar = tqdm(train_dataloader,
                    desc=f"Epoch {epoch+1}/{num_epochs}",
                    leave=False)

        for real_images, _ in pbar:

            batch_size = real_images.size(0)
            real_images = real_images.to(device)

            # ---------------------
            # Train Discriminator
            # ---------------------
            for _ in range(d_step):

                z = torch.randn(batch_size, noise_dim).to(device)
                # The generator is not updated here, so skip building its
                # autograd graph instead of building and then detaching it.
                with torch.no_grad():
                    fake_images = generator(z)

                # Critic loss: push scores on real images up and scores on
                # generated images down.
                d_loss = (
                    -torch.mean(discriminator(real_images)) +
                     torch.mean(discriminator(fake_images))
                )

                optimizer_D.zero_grad()
                d_loss.backward()
                optimizer_D.step()

                d_loss_epoch += d_loss.item()
                d_count += 1

                # Weight clipping: keep every critic parameter inside
                # [-clip_value, clip_value] after each update.
                with torch.no_grad():
                    for p in discriminator.parameters():
                        p.clamp_(-clip_value, clip_value)

            # -----------------
            # Train Generator
            # -----------------
            for _ in range(g_step):

                z = torch.randn(batch_size, noise_dim).to(device)
                fake_images = generator(z)

                # Generator loss: raise the critic's score on generated images.
                g_loss = -torch.mean(discriminator(fake_images))

                optimizer_G.zero_grad()
                g_loss.backward()
                optimizer_G.step()

                g_loss_epoch += g_loss.item()
                g_count += 1

        # Compute average losses for this epoch. Guard against zero updates
        # (d_step/g_step == 0 or an empty dataloader) instead of raising
        # ZeroDivisionError.
        avg_d_loss = d_loss_epoch / d_count if d_count else float("nan")
        avg_g_loss = g_loss_epoch / g_count if g_count else float("nan")

        d_losses.append(avg_d_loss)
        g_losses.append(avg_g_loss)

        print(
            f"Epoch [{epoch+1}/{num_epochs}] | "
            f"D Loss: {avg_d_loss:.4f} | "
            f"G Loss: {avg_g_loss:.4f}"
        )

        show_generated_images(epoch+1, generator, fixed_noise)

    return d_losses, g_losses

### Run Training

In [None]:
# Launch training: 50 epochs, five discriminator updates per generator update.
print("training: 1-step Gen, 5-step Dis")
d_losses, g_losses = train_gan(
    train_dataloader=dataloader,
    num_epochs=50,
    g_step=1,
    d_step=5,
)

In [None]:
# Plot the per-epoch average losses of both networks on one figure and save
# it next to the generated sample grids.
plt.figure()
for loss_curve, curve_label in (
    (d_losses, "Discriminator Loss"),
    (g_losses, "Generator Loss"),
):
    plt.plot(loss_curve, label=curve_label)

plt.title("WGAN Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

plt.savefig("WGAN_output_g1_d5_clip_0.01/wgan_training_loss.png")
plt.show()

### Generate video

In [None]:
import cv2
import os

def images_to_video(image_dir, output_path, fps=9):
    """Assemble the PNG/JPG images in a directory into an MP4 video.

    Images are ordered by file name using a natural sort, so numbered frames
    such as ``..._epoch_2.png`` come before ``..._epoch_10.png``. The frame
    size is taken from the first readable image; later images of a different
    size are resized to match, and unreadable images are skipped.

    Args:
        image_dir: Directory to scan (non-recursively) for ``.png``/``.jpg``.
        output_path: Output file path WITHOUT extension; ``.mp4`` is appended.
        fps: Frames per second of the resulting video.

    Raises:
        FileNotFoundError: If ``image_dir`` contains no ``.png``/``.jpg`` file.
        RuntimeError: If the video writer could not be opened.
        ValueError: If none of the images could be decoded.
    """
    import re  # local import: only needed for the natural sort key

    def natural_key(name):
        # re.split with a capture group alternates text / digit runs, so odd
        # positions are always the numeric parts and compare as integers.
        parts = re.split(r"([0-9]+)", name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)]

    images = sorted(
        (name for name in os.listdir(image_dir)
         if name.endswith((".png", ".jpg"))),
        key=natural_key,
    )
    if not images:
        raise FileNotFoundError(
            f"No .png or .jpg images found in {image_dir!r}"
        )

    video_path = output_path + ".mp4"
    video = None
    width = height = None

    try:
        for name in images:
            frame = cv2.imread(os.path.join(image_dir, name))
            if frame is None:
                # cv2.imread signals failure by returning None, not raising.
                print("Skipping unreadable image:", name)
                continue

            if video is None:
                # First readable image fixes the frame size of the video.
                height, width = frame.shape[:2]
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                video = cv2.VideoWriter(video_path, fourcc, fps,
                                        (width, height))
                if not video.isOpened():
                    raise RuntimeError(
                        f"Could not open video writer for {video_path!r}"
                    )
            elif frame.shape[:2] != (height, width):
                # The writer only accepts frames of its configured size.
                frame = cv2.resize(frame, (width, height))

            video.write(frame)
    finally:
        # Always flush/close the container, even if a frame fails mid-way.
        if video is not None:
            video.release()

    if video is None:
        raise ValueError(
            f"None of the images in {image_dir!r} could be read"
        )

    print("Video saved as:", video_path)


In [None]:
# Read frames from the directory show_generated_images saves them to, not the
# current directory. NOTE(review): that directory also holds the loss-curve
# PNG written after training; move or exclude it if it should not be a frame.
images_to_video("WGAN_output_g1_d5_clip_0.01", "output_video")