# **Import Libraries**

In [1]:
import warnings

import torch
import torch.nn as nn
from torch.optim import RMSprop
from torch.optim import Adam
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import FashionMNIST

from tqdm import tqdm

# **Ignoring Any Warning Messages**

In [2]:
warnings.filterwarnings('ignore')

# **Define Important Hyperparameters**

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 5e-5
num_classes = 10
embedding_size = 100
num_epochs = 30
batch_size = 64
img_dim = 64
z_dim = 128
img_channels = 1
features = 64
critic_iterations = 5
weight_clip = 0.01
fixed_noise = torch.rand(32, z_dim, 1, 1).to(device)
lambda_gp = 10

# **Create Discriminator**

In [4]:
class Discriminator(nn.Module):
  def __init__(self, img_channels, features):
    super().__init__()

    self.net = nn.Sequential(
        nn.Conv2d(img_channels, features * 1, 4, 2, 1),
        nn.LeakyReLU(.2),

        nn.Conv2d(features * 1, features * 2, 4, 2, 1),
        nn.BatchNorm2d(features * 2),
        nn.LeakyReLU(.2),

        nn.Conv2d(features * 2, features * 4, 4, 2, 1),
        nn.BatchNorm2d(features * 4),
        nn.LeakyReLU(.2),

        nn.Conv2d(features * 4, features * 8, 4, 2, 1),
        nn.BatchNorm2d(features * 8),
        nn.LeakyReLU(.2),

        nn.Conv2d(features * 8,  1, 4, 2, 0),
        nn.Sigmoid(),
    )

  def forward(self, x):
    return self.net(x)

# **Create Generator**

In [5]:
class Generator(nn.Module):
  def __init__(self, noise_channels, img_channels, features):
    super().__init__()

    self.net = nn.Sequential(
        nn.ConvTranspose2d(noise_channels, features * 16, 4, 1, 0),
        nn.ReLU(),

        nn.ConvTranspose2d(features * 16, features * 8, 4, 2, 1),
        nn.BatchNorm2d(features * 8),
        nn.ReLU(),

        nn.ConvTranspose2d(features * 8, features * 4, 4, 2, 1),
        nn.BatchNorm2d(features * 4),
        nn.ReLU(),

        nn.ConvTranspose2d(features * 4, features * 2, 4, 2, 1),
        nn.BatchNorm2d(features * 2),
        nn.ReLU(),

        nn.ConvTranspose2d(features * 2, img_channels, 4, 2, 1),
        nn.Tanh(),
    )

  def forward(self, x):
    return self.net(x)

# **Define Weights Intialization Function**

In [6]:
def initialize_weights(model):
  for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d,  nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data, 0, .02)

# **Generate Dataloader**

### **Define Transformation**

In [7]:
transforms = transforms.Compose(
    [
        transforms.Resize(img_dim),
        transforms.ToTensor(),
        transforms.Normalize((.5, ), (.5, ))
    ]
)

### **Create Dataloader**

In [8]:
dataloader = DataLoader(FashionMNIST('.', download = True, transform = transforms), batch_size = batch_size, shuffle = True)

100%|██████████| 26.4M/26.4M [00:01<00:00, 14.2MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 270kB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 5.03MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 11.4MB/s]


# **Training Loop [WGAN]**

In [9]:
gen = Generator(z_dim, img_channels, features).to(device)
initialize_weights(gen)

critic = Discriminator(img_channels, features).to(device)
initialize_weights(critic)

opt_gen  = RMSprop(gen.parameters(),  lr = lr)
opt_critic = RMSprop(critic.parameters(), lr = lr)

for epoch in range(num_epochs):
    gen.train()
    critic.train()

    for real, _ in tqdm(dataloader):
        for _ in range(critic_iterations):
            batch_size = real.shape[0]

            real = real.to(device)

            noise = torch.randn((batch_size, z_dim, 1, 1)).to(device)
            fake = gen(noise)

            critic_real = critic(real).view(-1)
            critic_fake = critic(fake.detach()).view(-1)

            lossC = -(torch.mean(critic_real) - torch.mean(critic_fake))

            critic.zero_grad()
            lossC.backward(retain_graph=True)
            opt_critic.step()

            for p in critic.parameters():
                p.data.clamp_(-weight_clip, weight_clip)

        output = critic(fake)
        lossG = -torch.mean(output)

        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

    print(f"epoch{epoch + 1}:  lossC:{-lossC:.4f}   lossG:{-lossG:.4f}")

    with torch.no_grad():
        gen.eval()
        critic.eval()

        fake = gen(fixed_noise).reshape(-1, 1, 64, 64)

        real, _ = next(iter(dataloader))
        real = real.to(device)

        img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)

        torchvision.utils.save_image(img_grid_fake, "fake_grid_WGAN.png")
        torchvision.utils.save_image(img_grid_real, "real_grid_WGAN.png")

100%|██████████| 938/938 [07:04<00:00,  2.21it/s]


epoch1:  lossC:0.3678   lossG:0.3206


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch2:  lossC:0.3602   lossG:0.3244


100%|██████████| 938/938 [07:17<00:00,  2.14it/s]


epoch3:  lossC:0.2760   lossG:0.3619


100%|██████████| 938/938 [07:18<00:00,  2.14it/s]


epoch4:  lossC:0.2666   lossG:0.3775


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch5:  lossC:0.2721   lossG:0.3470


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch6:  lossC:0.3075   lossG:0.3485


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch7:  lossC:0.1795   lossG:0.4934


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch8:  lossC:0.3087   lossG:0.3492


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch9:  lossC:0.3164   lossG:0.3432


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch10:  lossC:0.1569   lossG:0.3567


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch11:  lossC:0.2094   lossG:0.3961


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch12:  lossC:0.2408   lossG:0.4062


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch13:  lossC:0.2541   lossG:0.3644


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch14:  lossC:0.2593   lossG:0.3723


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch15:  lossC:0.2643   lossG:0.3659


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch16:  lossC:0.2401   lossG:0.3595


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch17:  lossC:0.2473   lossG:0.3725


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch18:  lossC:0.2346   lossG:0.3662


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch19:  lossC:0.1805   lossG:0.4241


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch20:  lossC:0.2222   lossG:0.3647


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch21:  lossC:0.1741   lossG:0.3672


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch22:  lossC:0.2101   lossG:0.4040


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch23:  lossC:0.1929   lossG:0.3761


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch24:  lossC:0.2037   lossG:0.3799


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch25:  lossC:0.2104   lossG:0.4167


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch26:  lossC:0.1974   lossG:0.4136


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch27:  lossC:0.1965   lossG:0.4260


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]


epoch28:  lossC:0.1982   lossG:0.3677


100%|██████████| 938/938 [07:15<00:00,  2.15it/s]


epoch29:  lossC:0.1760   lossG:0.4348


100%|██████████| 938/938 [07:16<00:00,  2.15it/s]

epoch30:  lossC:0.1767   lossG:0.4037





# **Define Gradient Penalty**

In [10]:
def gradient_penalty(critic, fake, real, device = 'cpu'):
  B, C, H, W = real.shape

  alpha = torch.rand((B, 1, 1, 1)).repeat((1, C, H, W)).to(device)
  interpolated_imgs = real * alpha + fake * (1 - alpha)
  mixed_scores = critic(interpolated_imgs)

  gradient = torch.autograd.grad(
      inputs = interpolated_imgs,
      outputs = mixed_scores,
      grad_outputs = torch.ones_like(mixed_scores),
      create_graph = True,
      retain_graph = True
  )[0]

  gradient = gradient.view(gradient.shape[0], -1)
  gradient_norm = gradient.norm(2, dim = 1)
  gradient_penalty = torch.mean((gradient_norm - 1) ** 2)

  return gradient_penalty

# **Training Loop [WGAN-GP]**

In [11]:
gen_gp = Generator(z_dim, img_channels, features).to(device)
initialize_weights(gen_gp)

critic_gp = Discriminator(img_channels, features).to(device)
initialize_weights(critic_gp)

opt_gen_gp = Adam(gen_gp.parameters(), lr=lr, betas=(0.0, 0.9))
opt_critic_gp = Adam(critic_gp.parameters(), lr=lr, betas=(0.0, 0.9))

for epoch in range(num_epochs):
    gen_gp.train()
    critic_gp.train()

    for real, _ in tqdm(dataloader):
        batch_size = real.shape[0]
        real = real.to(device)

        for _ in range(critic_iterations):
            noise = torch.randn((batch_size, z_dim, 1, 1)).to(device)
            fake = gen_gp(noise)

            critic_real = critic_gp(real).view(-1)
            critic_fake = critic_gp(fake.detach()).view(-1)

            gp = gradient_penalty(critic_gp, fake, real, device)
            lossC_gp = (-(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp)

            critic_gp.zero_grad()
            lossC_gp.backward(retain_graph=True)
            opt_critic_gp.step()

        output = critic_gp(fake)
        lossG_gp = -torch.mean(output)

        gen_gp.zero_grad()
        lossG_gp.backward()
        opt_gen_gp.step()

    print(f"epoch {epoch + 1}:  lossC: {-lossC_gp:.4f}   lossG: {-lossG_gp:.4f}")

    with torch.no_grad():
        gen_gp.eval()
        critic_gp.eval()

        fake = gen_gp(fixed_noise).reshape(-1, 1, 64, 64)

        real, _ = next(iter(dataloader))
        real = real.to(device)

        img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)

        torchvision.utils.save_image(img_grid_fake, "fake_grid_WGAN_GP.png")
        torchvision.utils.save_image(img_grid_real, "real_grid_WGAN_GP.png")

100%|██████████| 938/938 [13:53<00:00,  1.13it/s]


epoch 1:  lossC: 0.7835   lossG: 0.0027


100%|██████████| 938/938 [13:51<00:00,  1.13it/s]


epoch 2:  lossC: 0.8791   lossG: 0.0004


100%|██████████| 938/938 [13:51<00:00,  1.13it/s]


epoch 3:  lossC: 0.2450   lossG: 0.5839


100%|██████████| 938/938 [13:51<00:00,  1.13it/s]


epoch 4:  lossC: 0.1681   lossG: 0.7050


100%|██████████| 938/938 [13:51<00:00,  1.13it/s]


epoch 5:  lossC: 0.8821   lossG: 0.0216


100%|██████████| 938/938 [13:53<00:00,  1.13it/s]


epoch 6:  lossC: 0.7766   lossG: 0.1382


100%|██████████| 938/938 [13:51<00:00,  1.13it/s]


epoch 7:  lossC: 0.7539   lossG: 0.1539


100%|██████████| 938/938 [13:53<00:00,  1.12it/s]


epoch 8:  lossC: 0.8573   lossG: 0.0939


100%|██████████| 938/938 [13:52<00:00,  1.13it/s]


epoch 9:  lossC: 0.8840   lossG: 0.0564


100%|██████████| 938/938 [13:52<00:00,  1.13it/s]


epoch 10:  lossC: 0.7842   lossG: 0.1067


100%|██████████| 938/938 [13:51<00:00,  1.13it/s]


epoch 11:  lossC: 0.8172   lossG: 0.1216


100%|██████████| 938/938 [13:51<00:00,  1.13it/s]


epoch 12:  lossC: 0.7684   lossG: 0.1590


100%|██████████| 938/938 [13:51<00:00,  1.13it/s]


epoch 13:  lossC: 0.7618   lossG: 0.1311


100%|██████████| 938/938 [13:50<00:00,  1.13it/s]


epoch 14:  lossC: 0.7964   lossG: 0.1643


100%|██████████| 938/938 [13:51<00:00,  1.13it/s]


epoch 15:  lossC: 0.7509   lossG: 0.0665


100%|██████████| 938/938 [13:50<00:00,  1.13it/s]


epoch 16:  lossC: 0.6520   lossG: 0.1012


100%|██████████| 938/938 [13:51<00:00,  1.13it/s]


epoch 17:  lossC: 0.6873   lossG: 0.0778


100%|██████████| 938/938 [13:51<00:00,  1.13it/s]


epoch 18:  lossC: 0.8539   lossG: 0.0874


100%|██████████| 938/938 [13:50<00:00,  1.13it/s]


epoch 19:  lossC: 0.7866   lossG: 0.1783


100%|██████████| 938/938 [13:51<00:00,  1.13it/s]


epoch 20:  lossC: 0.8971   lossG: 0.0441


100%|██████████| 938/938 [13:52<00:00,  1.13it/s]


epoch 21:  lossC: 0.6360   lossG: 0.1781


100%|██████████| 938/938 [13:50<00:00,  1.13it/s]


epoch 22:  lossC: 0.6500   lossG: 0.1951


100%|██████████| 938/938 [13:50<00:00,  1.13it/s]


epoch 23:  lossC: 0.5433   lossG: 0.1798


100%|██████████| 938/938 [13:50<00:00,  1.13it/s]


epoch 24:  lossC: 0.8423   lossG: 0.1246


100%|██████████| 938/938 [13:48<00:00,  1.13it/s]


epoch 25:  lossC: 0.8225   lossG: 0.0489


100%|██████████| 938/938 [13:50<00:00,  1.13it/s]


epoch 26:  lossC: 0.6887   lossG: 0.2071


100%|██████████| 938/938 [13:49<00:00,  1.13it/s]


epoch 27:  lossC: 0.8622   lossG: 0.0996


100%|██████████| 938/938 [13:48<00:00,  1.13it/s]


epoch 28:  lossC: 0.8612   lossG: 0.0962


100%|██████████| 938/938 [13:50<00:00,  1.13it/s]


epoch 29:  lossC: 0.8407   lossG: 0.0789


100%|██████████| 938/938 [13:49<00:00,  1.13it/s]

epoch 30:  lossC: 0.6569   lossG: 0.1509





# **Create Conditional Discriminator**

In [12]:
class ConditionalDiscriminator(nn.Module):
  def __init__(self, img_channels, features, num_classes, img_size):
    super().__init__()

    self.img_size = img_size

    self.embed = nn.Embedding(num_classes, img_size * img_size)

    self.net = nn.Sequential(
        nn.Conv2d(img_channels + 1, features * 1, 4, 2, 1),
        nn.LeakyReLU(.2),

        nn.Conv2d(features * 1, features * 2, 4, 2, 1),
        nn.BatchNorm2d(features * 2),
        nn.LeakyReLU(.2),

        nn.Conv2d(features * 2, features * 4, 4, 2, 1),
        nn.BatchNorm2d(features * 4),
        nn.LeakyReLU(.2),

        nn.Conv2d(features * 4, features * 8, 4, 2, 1),
        nn.BatchNorm2d(features * 8),
        nn.LeakyReLU(.2),

        nn.Conv2d(features * 8,  1, 4, 2, 0),
        nn.Sigmoid(),
    )

  def forward(self, x, labels):
    embedding = self.embed(labels).view(labels.shape[0], 1, self.img_size, self.img_size)
    x = torch.cat([embedding, x], dim = 1)
    return self.net(x)

# **Create Conditional Generator**

In [13]:
class ConditionalGenerator(nn.Module):
  def __init__(self, noise_channels, img_channels, features, num_classes, embedding_size):
    super().__init__()

    self.embed = nn.Embedding(num_classes, embedding_size)

    self.net = nn.Sequential(
        nn.ConvTranspose2d(noise_channels + embedding_size, features * 16, 4, 1, 0),
        nn.ReLU(),

        nn.ConvTranspose2d(features * 16, features * 8, 4, 2, 1),
        nn.BatchNorm2d(features * 8),
        nn.ReLU(),

        nn.ConvTranspose2d(features * 8, features * 4, 4, 2, 1),
        nn.BatchNorm2d(features * 4),
        nn.ReLU(),

        nn.ConvTranspose2d(features * 4, features * 2, 4, 2, 1),
        nn.BatchNorm2d(features * 2),
        nn.ReLU(),

        nn.ConvTranspose2d(features * 2, img_channels, 4, 2, 1),
        nn.Tanh(),
    )

  def forward(self, x, labels):
    embedding = self.embed(labels).unsqueeze(2).unsqueeze(3)
    x = torch.cat([embedding, x], dim = 1)
    return self.net(x)

# **Define Conditional Gradient Penalty**

In [14]:
def conditional_gradient_penalty(critic, fake, real, labels, device = 'cpu'):
  B, C, H, W = real.shape

  alpha = torch.rand((B, 1, 1, 1)).repeat((1, C, H, W)).to(device)
  interpolated_imgs = real * alpha + fake * (1 - alpha)
  mixed_scores = critic(interpolated_imgs, labels)

  gradient = torch.autograd.grad(
      inputs = interpolated_imgs,
      outputs = mixed_scores,
      grad_outputs = torch.ones_like(mixed_scores),
      create_graph = True,
      retain_graph = True
  )[0]

  gradient = gradient.view(gradient.shape[0], -1)
  gradient_norm = gradient.norm(2, dim = 1)
  gradient_penalty = torch.mean((gradient_norm - 1) ** 2)

  return gradient_penalty

# **Training Loop [Conditional WGAN]**

In [15]:
con_gen_gp = ConditionalGenerator(z_dim, img_channels, features, num_classes, embedding_size).to(device)
initialize_weights(con_gen_gp)

con_critic_gp = ConditionalDiscriminator(img_channels, features, num_classes, img_dim).to(device)
initialize_weights(con_critic_gp)

opt_con_gen_gp = Adam(con_gen_gp.parameters(), lr=lr, betas=(0.0, 0.9))
opt_con_critic_gp = Adam(con_critic_gp.parameters(), lr=lr, betas=(0.0, 0.9))

for epoch in range(num_epochs):
    con_gen_gp.train()
    con_critic_gp.train()

    for real, labels in tqdm(dataloader):
        batch_size = real.shape[0]
        real = real.to(device)
        labels = labels.to(device)

        for _ in range(critic_iterations):
            noise = torch.randn((batch_size, z_dim, 1, 1)).to(device)
            fake = con_gen_gp(noise, labels)

            critic_real = con_critic_gp(real, labels).view(-1)
            critic_fake = con_critic_gp(fake.detach(), labels).view(-1)

            gp = conditional_gradient_penalty(con_critic_gp, fake, real, labels, device)
            lossC_con_gp = -(torch.mean(critic_real) - torch.mean(critic_fake)) + lambda_gp * gp

            con_critic_gp.zero_grad()
            lossC_con_gp.backward(retain_graph=True)
            opt_con_critic_gp.step()

        output = con_critic_gp(fake, labels)
        lossG_con_gp = -torch.mean(output)

        con_gen_gp.zero_grad()
        lossG_con_gp.backward()
        opt_con_gen_gp.step()

    print(f"epoch {epoch + 1}:  lossC: {-lossC_con_gp:.4f}   lossG: {-lossG_con_gp:.4f}")

    with torch.no_grad():
        con_gen_gp.eval()
        con_critic_gp.eval()

        fake = con_gen_gp(fixed_noise, labels).reshape(-1, 1, 64, 64)

        real, _ = next(iter(dataloader))
        real = real.to(device)

        img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
        img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)

        torchvision.utils.save_image(img_grid_fake, "fake_grid_Conditional_WGAN_GP.png")
        torchvision.utils.save_image(img_grid_real, "real_grid_Conditional_WGAN_GP.png")

100%|██████████| 938/938 [13:56<00:00,  1.12it/s]


epoch 1:  lossC: 0.5547   lossG: 0.0452


100%|██████████| 938/938 [14:01<00:00,  1.11it/s]


epoch 2:  lossC: -5.1338   lossG: 0.3657


100%|██████████| 938/938 [14:01<00:00,  1.12it/s]


epoch 3:  lossC: 0.3243   lossG: 0.3450


100%|██████████| 938/938 [14:01<00:00,  1.12it/s]


epoch 4:  lossC: 0.6667   lossG: 0.0911


100%|██████████| 938/938 [14:01<00:00,  1.11it/s]


epoch 5:  lossC: 0.5337   lossG: 0.2338


100%|██████████| 938/938 [14:00<00:00,  1.12it/s]


epoch 6:  lossC: 0.6592   lossG: 0.2765


100%|██████████| 938/938 [14:00<00:00,  1.12it/s]


epoch 7:  lossC: 0.5669   lossG: 0.3723


100%|██████████| 938/938 [13:59<00:00,  1.12it/s]


epoch 8:  lossC: 0.6495   lossG: 0.2753


100%|██████████| 938/938 [13:58<00:00,  1.12it/s]


epoch 9:  lossC: 0.4197   lossG: 0.2523


100%|██████████| 938/938 [13:59<00:00,  1.12it/s]


epoch 10:  lossC: 0.7646   lossG: 0.1885


100%|██████████| 938/938 [13:59<00:00,  1.12it/s]


epoch 11:  lossC: 0.8082   lossG: 0.1700


100%|██████████| 938/938 [13:59<00:00,  1.12it/s]


epoch 12:  lossC: 0.7093   lossG: 0.2445


100%|██████████| 938/938 [13:58<00:00,  1.12it/s]


epoch 13:  lossC: 0.7576   lossG: 0.1714


100%|██████████| 938/938 [13:58<00:00,  1.12it/s]


epoch 14:  lossC: 0.7343   lossG: 0.2230


100%|██████████| 938/938 [13:58<00:00,  1.12it/s]


epoch 15:  lossC: 0.7228   lossG: 0.1655


100%|██████████| 938/938 [13:58<00:00,  1.12it/s]


epoch 16:  lossC: 0.7654   lossG: 0.1998


100%|██████████| 938/938 [13:58<00:00,  1.12it/s]


epoch 17:  lossC: 0.8102   lossG: 0.1529


100%|██████████| 938/938 [13:58<00:00,  1.12it/s]


epoch 18:  lossC: 0.7829   lossG: 0.1890


100%|██████████| 938/938 [13:57<00:00,  1.12it/s]


epoch 19:  lossC: 0.8396   lossG: 0.1296


100%|██████████| 938/938 [13:58<00:00,  1.12it/s]


epoch 20:  lossC: 0.7830   lossG: 0.2005


100%|██████████| 938/938 [13:57<00:00,  1.12it/s]


epoch 21:  lossC: 0.5951   lossG: 0.3248


100%|██████████| 938/938 [13:58<00:00,  1.12it/s]


epoch 22:  lossC: 0.4740   lossG: 0.1950


100%|██████████| 938/938 [13:57<00:00,  1.12it/s]


epoch 23:  lossC: 0.5153   lossG: 0.3646


100%|██████████| 938/938 [13:56<00:00,  1.12it/s]


epoch 24:  lossC: 0.7187   lossG: 0.1412


100%|██████████| 938/938 [13:57<00:00,  1.12it/s]

epoch 25:  lossC: 0.8447   lossG: 0.0354



