# Train a GAN

A Generative Adversarial Network (GAN) can be used to generate data such as images and videos. This tutorial will show you how to train a GAN with MMEngine step by step!

The tutorial is divided into the following steps:

> - [Train a GAN](#train-a-gan)
> - [Building a Data Loader](#building-a-data-loader)
> - [Building a Dataset](#building-a-dataset)
> - [Build a Generator Network and a Discriminator Network](#build-a-generator-network-and-a-discriminator-network)
> - [Build a Generative Adversarial Network Model](#build-a-generative-adversarial-network-model)
> - [Building an Optimizer](#building-an-optimizer)
> - [Training with Runner](#training-with-runner)

## Building a Data Loader

### Building a Dataset

First, we build a dataset class `MNISTDataset` for the MNIST dataset. It inherits from the base dataset class [BaseDataset](mmengine.dataset.BaseDataset) and overrides the `load_data_list` method to ensure the return value is a `list[dict]`, where each `dict` represents one data sample.
For more details about using datasets in MMEngine, refer to [the Dataset tutorial](../tutorials/basedataset.md).

```python
import numpy as np
from mmcv.transforms import to_tensor
from torch.utils.data import random_split
from torchvision.datasets import MNIST

from mmengine.dataset import BaseDataset


class MNISTDataset(BaseDataset):

    def __init__(self, data_root, pipeline, test_mode=False):
        # Download the MNIST dataset. In test mode, use the official test
        # split; otherwise split the full training set into 55000 training
        # and 5000 validation images.
        if test_mode:
            self.mnist_dataset = MNIST(data_root, train=False, download=True)
        else:
            mnist_full = MNIST(data_root, train=True, download=True)
            self.mnist_dataset, _ = random_split(mnist_full, [55000, 5000])

        super().__init__(
            data_root=data_root, pipeline=pipeline, test_mode=test_mode)

    @staticmethod
    def totensor(img):
        # Ensure a channel dimension and convert HWC -> CHW before
        # converting to a tensor.
        if len(img.shape) < 3:
            img = np.expand_dims(img, -1)
        img = np.ascontiguousarray(img.transpose(2, 0, 1))
        return to_tensor(img)

    def load_data_list(self):
        return [
            dict(inputs=self.totensor(np.array(x[0])))
            for x in self.mnist_dataset
        ]


dataset = MNISTDataset("./data", [])
```
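
As a quick sanity check, you can inspect the dataset directly; each item should be a `dict` holding an `inputs` tensor (the counts below assume the 55000/5000 split above):

```python
# Each sample is the dict built in load_data_list, passed through the
# (empty) pipeline.
print(len(dataset))                # 55000
print(dataset[0]['inputs'].shape)  # torch.Size([1, 28, 28])
```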

Then use the `build_dataloader` method of the Runner to build the data loader.

```python
import os
import torch
from mmengine.runner import Runner

NUM_WORKERS = int(os.cpu_count() / 2)
BATCH_SIZE = 256 if torch.cuda.is_available() else 64

train_dataloader = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dataset)
train_dataloader = Runner.build_dataloader(train_dataloader)
```
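
You can also peek at one batch. With the default (pseudo) collate function used by `Runner.build_dataloader`, `inputs` arrives as a list of per-sample tensors; the data preprocessor introduced below stacks and normalizes them:

```python
batch = next(iter(train_dataloader))
# `inputs` is a list of BATCH_SIZE tensors, each of shape (1, 28, 28).
print(len(batch['inputs']), batch['inputs'][0].shape)
```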

## Build a Generator Network and a Discriminator Network

The following code builds and instantiates a Generator and a Discriminator.

```python
import numpy as np
import torch.nn as nn


class Generator(nn.Module):

    def __init__(self, noise_size, img_shape):
        super().__init__()
        self.img_shape = img_shape
        self.noise_size = noise_size

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(noise_size, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img
```

```python
class Discriminator(nn.Module):

    def __init__(self, img_shape):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity
```

```python
generator = Generator(100, (1, 28, 28))
discriminator = Discriminator((1, 28, 28))
```
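
A quick shape check (optional) confirms the two networks fit together: a size-100 noise vector becomes a `(1, 28, 28)` image, and the discriminator maps it back to a single score in `(0, 1)`:

```python
import torch

z = torch.randn(4, 100)        # a batch of 4 noise vectors
fake = generator(z)            # -> shape (4, 1, 28, 28)
score = discriminator(fake)    # -> shape (4, 1), Sigmoid output in (0, 1)
print(fake.shape, score.shape)
```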

## Build a Generative Adversarial Network Model

In MMEngine, we use [ImgDataPreprocessor](mmengine.model.ImgDataPreprocessor) to normalize the data and convert the color channels.

```python
from mmengine.model import ImgDataPreprocessor

data_preprocessor = ImgDataPreprocessor(mean=[127.5], std=[127.5])
```
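
With `mean=[127.5]` and `std=[127.5]`, pixel values in `[0, 255]` are mapped to `[-1, 1]`, matching the `Tanh` output range of the generator. A small illustrative check (the dict-style input and output mirror how `train_step` calls the preprocessor below):

```python
import torch

raw = dict(inputs=[torch.zeros(1, 28, 28), torch.full((1, 28, 28), 255.)])
out = data_preprocessor(raw, True)
# (0 - 127.5) / 127.5 = -1.0 and (255 - 127.5) / 127.5 = 1.0
print(out['inputs'].min().item(), out['inputs'].max().item())  # -1.0 1.0
```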

The following code implements the basic GAN algorithm. To implement an algorithm with MMEngine, your algorithm class needs to inherit from the [BaseModel](mmengine.model.BaseModel) base class and implement the training process in `train_step`. A GAN alternates between training the discriminator and the generator, which is handled by `train_discriminator` and `train_generator`, while `disc_loss` and `gen_loss` compute the discriminator and generator loss functions respectively.
For more details about BaseModel, refer to the [Model tutorial](../tutorials/model.md).

```python
import torch
import torch.nn.functional as F
from mmengine.model import BaseModel


class GAN(BaseModel):

    def __init__(self, generator, discriminator, noise_size,
                 data_preprocessor):
        super().__init__(data_preprocessor=data_preprocessor)
        assert generator.noise_size == noise_size
        self.generator = generator
        self.discriminator = discriminator
        self.noise_size = noise_size

    def train_step(self, data, optim_wrapper):
        # Acquire and preprocess the data
        inputs_dict = self.data_preprocessor(data, True)
        # Train the discriminator
        disc_optimizer_wrapper = optim_wrapper['discriminator']
        with disc_optimizer_wrapper.optim_context(self.discriminator):
            log_vars = self.train_discriminator(inputs_dict,
                                                disc_optimizer_wrapper)

        # Train the generator
        set_requires_grad(self.discriminator, False)
        gen_optimizer_wrapper = optim_wrapper['generator']
        with gen_optimizer_wrapper.optim_context(self.generator):
            log_vars_gen = self.train_generator(inputs_dict,
                                                gen_optimizer_wrapper)

        set_requires_grad(self.discriminator, True)
        log_vars.update(log_vars_gen)

        return log_vars

    def forward(self, batch_inputs, data_samples=None, mode=None):
        return self.generator(batch_inputs)

    def disc_loss(self, disc_pred_fake, disc_pred_real):
        losses_dict = dict()
        losses_dict['loss_disc_fake'] = F.binary_cross_entropy(
            disc_pred_fake, 0. * torch.ones_like(disc_pred_fake))
        losses_dict['loss_disc_real'] = F.binary_cross_entropy(
            disc_pred_real, 1. * torch.ones_like(disc_pred_real))

        loss, log_var = self.parse_losses(losses_dict)
        return loss, log_var

    def gen_loss(self, disc_pred_fake):
        losses_dict = dict()
        losses_dict['loss_gen'] = F.binary_cross_entropy(
            disc_pred_fake, 1. * torch.ones_like(disc_pred_fake))
        loss, log_var = self.parse_losses(losses_dict)
        return loss, log_var

    def train_discriminator(self, inputs, optimizer_wrapper):
        real_imgs = inputs['inputs']
        z = torch.randn(
            (real_imgs.shape[0], self.noise_size)).type_as(real_imgs)
        # Freeze the generator's gradients when computing fake images
        with torch.no_grad():
            fake_imgs = self.generator(z)

        disc_pred_fake = self.discriminator(fake_imgs)
        disc_pred_real = self.discriminator(real_imgs)

        parsed_losses, log_vars = self.disc_loss(disc_pred_fake,
                                                 disc_pred_real)
        optimizer_wrapper.update_params(parsed_losses)
        return log_vars

    def train_generator(self, inputs, optimizer_wrapper):
        real_imgs = inputs['inputs']
        z = torch.randn(real_imgs.shape[0], self.noise_size).type_as(real_imgs)

        fake_imgs = self.generator(z)

        disc_pred_fake = self.discriminator(fake_imgs)
        parsed_loss, log_vars = self.gen_loss(disc_pred_fake)

        optimizer_wrapper.update_params(parsed_loss)
        return log_vars
```

The `set_requires_grad` function is used to lock the weights of the discriminator while the generator is being trained.

```python
def set_requires_grad(nets, requires_grad=False):
    """Set requires_grad for all the networks.

    Args:
        nets (nn.Module | list[nn.Module]): A list of networks or a single
            network.
        requires_grad (bool): Whether the networks require gradients or not.
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad
```

```python
model = GAN(generator, discriminator, 100, data_preprocessor)
```

## Building an Optimizer

MMEngine uses [OptimWrapper](mmengine.optim.OptimWrapper) to wrap optimizers. For multiple optimizers, we use [OptimWrapperDict](mmengine.optim.OptimWrapperDict) to further wrap OptimWrapper.
For more details about optimizers, refer to the [Optimizer tutorial](../tutorials/optimizer.md).

```python
from mmengine.optim import OptimWrapper, OptimWrapperDict

opt_g = torch.optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
opt_g_wrapper = OptimWrapper(opt_g)

opt_d = torch.optim.Adam(
    discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
opt_d_wrapper = OptimWrapper(opt_d)

opt_wrapper_dict = OptimWrapperDict(
    generator=opt_g_wrapper, discriminator=opt_d_wrapper)
```
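
Note that the keys of the `OptimWrapperDict` ('generator' and 'discriminator') must match the names looked up in `train_step` above, which fetches each sub-wrapper by key:

```python
# Each sub-wrapper is retrieved by name, exactly as train_step does.
gen_wrapper = opt_wrapper_dict['generator']
disc_wrapper = opt_wrapper_dict['discriminator']
```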

## Training with Runner

The following code demonstrates how to use Runner for model training.
For more details about the Runner, refer to the [Runner tutorial](../tutorials/runner.md).

```python
train_cfg = dict(by_epoch=True, max_epochs=220)
runner = Runner(
    model,
    work_dir='runs/gan/',
    train_dataloader=train_dataloader,
    train_cfg=train_cfg,
    optim_wrapper=opt_wrapper_dict)
runner.train()
```

We have now completed an example of training a GAN. The following code can be used to sample from the GAN we just trained and view the results.
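
Here is a minimal sketch (it assumes training has finished and that `torchvision` is available for saving an image grid):

```python
from torchvision.utils import save_image

# Sample noise on the same device as the trained model, run the generator
# through the model's forward(), and save an 8x8 grid of generated digits.
device = next(model.parameters()).device
z = torch.randn(64, 100, device=device)
img = model(z)
save_image(img, 'result.png', normalize=True)
```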

![GAN generate an image](https://user-images.githubusercontent.com/22982797/186811532-1517a0f7-5452-4a39-b6d0-6c685e4545e2.png)

If you want to learn more about using MMEngine to implement GANs and other generative models, we highly recommend trying [MMGeneration](https://github.com/open-mmlab/mmgeneration/tree/dev-1.x), a generative framework built on MMEngine.