# Build GAN (Generative Adversarial Networks) with PyTorch and SageMaker

There are many public datasets on the Internet that are very helpful for machine learning engineering and scientific research, such as studying and evaluating algorithms. In this example, we use the MNIST dataset of handwritten digits to train a GAN model, and then use the model to generate fake "handwritten" digits.


### Environment setup
Upgrade packages

In [None]:
!pip install --upgrade pip 
!pip install --upgrade sagemaker awscli boto3 pandas Pillow==7.1.2

The following command is for ```SageMaker Studio``` only.

In [None]:
!pip uninstall -y tqdm

In [None]:
%cd /root/ml-on-aws/byos-pytorch-gan

### Data preparation

The torchvision library, part of the PyTorch project, includes a `torchvision.datasets` package that provides the MNIST dataset. Use the following commands to download MNIST to local storage for later use.


In [None]:
from torchvision import datasets

dataroot = './data'

trainset = datasets.MNIST(root=dataroot, train=True, download=True)
testset = datasets.MNIST(root=dataroot, train=False, download=True)
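torchvision saves the downloaded files under `./data/MNIST/raw` in the IDX binary format: a 4-byte magic number (two zero bytes, a data-type code such as `0x08` for unsigned bytes, and the number of dimensions), followed by each dimension size as a big-endian 32-bit integer. The sketch below reads that header with the standard library; `read_idx_header` is a hypothetical helper, not part of torchvision.

```python
import struct

def read_idx_header(data):
    """Parse the header of an IDX file (the format of MNIST's raw files)."""
    # Two zero bytes, one data-type byte, one byte for the number of dimensions
    zero, dtype_code, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0:
        raise ValueError("not an IDX file")
    # Each dimension size is a big-endian unsigned 32-bit integer
    dims = struct.unpack(f">{ndim}I", data[4:4 + 4 * ndim])
    return dtype_code, dims

# Synthetic header matching MNIST's training images: ubyte (0x08), 3 dimensions
header = struct.pack(">HBB3I", 0, 0x08, 3, 60000, 28, 28)
print(read_idx_header(header))  # (8, (60000, 28, 28))
```

On a real file you could pass the first bytes, e.g. `open("./data/MNIST/raw/train-images-idx3-ubyte", "rb").read(16)`.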

The SageMaker SDK creates a default Amazon S3 bucket that you can use to store the files and data needed throughout the machine learning lifecycle. You can get the name of this bucket with the `default_bucket` method of the `sagemaker.session.Session` class.

In [None]:
from sagemaker.session import Session

sess = Session()

# S3 bucket for saving code and model artifacts.
# Feel free to specify a different bucket here if you wish.
bucket = sess.default_bucket()

# Location to save your custom code in tar.gz format.
s3_custom_code_upload_location = f's3://{bucket}/customcode/byos-pytorch-gan'

# Location where results of model training are saved.
s3_model_artifacts_location = f's3://{bucket}/artifacts/'

The SageMaker SDK also provides utilities for working with AWS services. For example, the `S3Downloader` class downloads objects from S3, and the `S3Uploader` class uploads local files to S3. Next, upload the dataset files to Amazon S3 for model training. Because the training job reads its data from S3 rather than from the Internet, it avoids the network latency of fetching data from the Internet as well as the security risks of giving the training environment direct Internet access.


In [None]:
import os
from sagemaker.s3 import S3Uploader as s3up

s3_data_location = s3up.upload(os.path.join(dataroot, "MNIST"), f"s3://{bucket}/data/mnist")
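When given a directory, `S3Uploader.upload` uploads it recursively and keeps each file's path relative to that directory under the destination prefix. The standard-library sketch below previews the S3 URIs such an upload would produce. `preview_s3_keys` is a hypothetical helper, and the mapping reflects our understanding of the SDK's behavior rather than its exact code.

```python
import os
import tempfile

def preview_s3_keys(local_dir, s3_prefix):
    """Hypothetical helper: list the S3 URIs a recursive directory upload would create."""
    s3_prefix = s3_prefix.rstrip("/")
    uris = []
    for dirpath, _dirnames, filenames in os.walk(local_dir):
        # Path of this directory relative to the uploaded root ("." for the root itself)
        rel_dir = os.path.relpath(dirpath, local_dir)
        for name in filenames:
            rel_path = name if rel_dir == "." else os.path.join(rel_dir, name)
            # S3 keys always use forward slashes, regardless of the local OS
            uris.append(f"{s3_prefix}/{rel_path.replace(os.sep, '/')}")
    return sorted(uris)

# Usage with a throwaway directory that mimics ./data/MNIST/raw
with tempfile.TemporaryDirectory() as tmp:
    os.makedirs(os.path.join(tmp, "raw"))
    open(os.path.join(tmp, "raw", "train-images-idx3-ubyte"), "wb").close()
    print(preview_s3_keys(tmp, "s3://my-bucket/data/mnist"))
    # ['s3://my-bucket/data/mnist/raw/train-images-idx3-ubyte']
```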

### Training




The `sagemaker.get_execution_role()` method returns the IAM role assigned to the notebook instance. SageMaker uses this role to obtain training resources, for example to pull the training framework images and to allocate Amazon EC2 instances.

In [None]:
from sagemaker import get_execution_role

# IAM execution role that gives SageMaker access to resources in your AWS account.
# We can use the SageMaker Python SDK to get the role from our notebook environment. 
role = get_execution_role()

The hyperparameters used by the training job can be defined in the notebook, which keeps them separate from the algorithm and training code. They are passed in when the training job is created and applied to that job, so you can change them without modifying the training script.

In [None]:
import json

hps = {
    'seed': 0,
    'learning-rate': 0.0002,
    'epochs': 18,
    'dataset': 'mnist',
    'pin-memory': 1,
    'beta1': 0.5,
    'nc': 1,
    'nz': 100,
    'ngf': 28,
    'ndf': 28,
    'batch-size': 128,
    'sample-interval': 100,
    'log-interval': 20,
}

str_hps = json.dumps(hps, indent=4)
print(str_hps)
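In SageMaker script mode, each hyperparameter is handed to the entry point as a command-line argument (for example `--learning-rate 0.0002`). The actual `networks/DCGAN/train.py` is not shown here. The following is a minimal, hypothetical sketch of how such a script could parse a subset of these hyperparameters with `argparse`.

```python
import argparse

def parse_hyperparameters(argv=None):
    """Hypothetical parser for a few of the hyperparameters defined in `hps`."""
    parser = argparse.ArgumentParser(description="DCGAN training (sketch)")
    # argparse maps '--learning-rate' to the attribute 'learning_rate'
    parser.add_argument("--learning-rate", type=float, default=0.0002)
    parser.add_argument("--epochs", type=int, default=18)
    parser.add_argument("--batch-size", type=int, default=128)
    parser.add_argument("--beta1", type=float, default=0.5)
    parser.add_argument("--nz", type=int, default=100)
    parser.add_argument("--dataset", type=str, default="mnist")
    # parse_known_args returns undeclared arguments separately instead of failing
    args, _unknown = parser.parse_known_args(argv)
    return args

# Simulate the arguments SageMaker would pass for part of `hps`
args = parse_hyperparameters(["--learning-rate", "0.0002", "--epochs", "18", "--nz", "100"])
print(args.learning_rate, args.epochs, args.nz)  # 0.0002 18 100
```

Using `parse_known_args` keeps the sketch tolerant of hyperparameters it does not declare, such as `pin-memory` or `sample-interval`.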

The ```PyTorch``` class from the `sagemaker.pytorch` package is an estimator for the PyTorch framework: it can create and run training jobs, and deploy trained models. In the parameter list, ``instance_type`` specifies the instance type, such as a CPU or GPU instance. ``source_dir`` specifies the directory that contains the training script and the model code, and ``entry_point`` names the training script file. These parameters are passed to the training job along with the others, and together they determine the job's environment settings. The code below uses the parameter names of SageMaker Python SDK v2, which the ``pip install --upgrade sagemaker`` step installs; SDK v1 used ``train_instance_count``, ``train_instance_type``, ``train_use_spot_instances``, and ``train_max_wait`` instead.

In [None]:
from sagemaker.pytorch import PyTorch

estimator = PyTorch(role=role,
                    entry_point='train.py',
                    source_dir='networks/DCGAN',
                    output_path=s3_model_artifacts_location,
                    code_location=s3_custom_code_upload_location,
                    instance_count=1,
                    instance_type='ml.p3.2xlarge',
                    use_spot_instances=True,
                    max_wait=86400,  # seconds; must be >= max_run (24 hours by default)
                    framework_version='1.5.0',
                    py_version='py3',
                    hyperparameters=hps,
                   )

Pay special attention to the ``use_spot_instances`` parameter: ``True`` means you prefer to run on Spot Instances. Machine learning training jobs often need substantial compute for long periods, so making good use of Spot Instances can help you control costs effectively. Spot Instance prices can be 20% to 60% of On-Demand prices; the actual price varies with the instance type, Region, and time. ``max_wait`` sets, in seconds, the maximum total time for the job, including time spent waiting for Spot capacity; here it is 86400 seconds (24 hours).
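After a Spot training job finishes, you can estimate the savings from two fields returned by the `DescribeTrainingJob` API: `TrainingTimeInSeconds` and `BillableTimeInSeconds`. A minimal sketch of the arithmetic, similar to the savings figure SageMaker reports in job logs; the numbers below are made up for illustration.

```python
def spot_savings_percent(training_seconds, billable_seconds):
    """Percentage saved relative to paying On-Demand for the full training time."""
    if training_seconds <= 0:
        raise ValueError("training_seconds must be positive")
    return (1 - billable_seconds / training_seconds) * 100

# Illustrative values only: 900 s of training billed as 300 s
print(f"{spot_savings_percent(900, 300):.1f}%")  # 66.7%
```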

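You have created the PyTorch estimator; now you can use it to fit the data previously uploaded to Amazon S3. The following command starts the training job, and the training data is made available to the training environment through an input channel named **MNIST**. When training starts, the data in Amazon S3 is downloaded to the local file system of the training environment, and the training script ```train.py``` loads it from local disk.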

In [None]:
# Start training
estimator.fit({"MNIST": s3_data_location}, wait=False)
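Inside the training container, each input channel is downloaded to `/opt/ml/input/data/<channel_name>`, and the SageMaker training toolkit exposes that path in an environment variable named `SM_CHANNEL_<CHANNEL_NAME>`, so the **MNIST** channel appears as `SM_CHANNEL_MNIST`. A training script could locate its data roughly as follows. The helper name and the fallback to `./data/MNIST` for local runs are our own assumptions, not something the original `train.py` is known to do.

```python
import os

def resolve_channel_dir(channel="MNIST", local_default="./data/MNIST"):
    """Return the directory where SageMaker placed a channel, or a local fallback."""
    env_var = f"SM_CHANNEL_{channel.upper()}"
    # On SageMaker this is typically /opt/ml/input/data/MNIST
    return os.environ.get(env_var, local_default)

print(resolve_channel_dir())  # ./data/MNIST when run outside SageMaker
```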

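Depending on the training instance type you choose, training can take anywhere from tens of minutes to several hours. We recommend setting the ``wait`` parameter to ``False``: this detaches the notebook from the training job, so that long-running jobs with verbose logs do not lose the notebook context because of a network interruption or session timeout. While the job is detached, its output is not shown in the notebook; run the following code to have the notebook find and reattach to the previous training job.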

In [None]:
%%time
from sagemaker.estimator import Estimator

# Attaching previous training session
training_job_name = estimator.latest_training_job.name
attached_estimator = Estimator.attach(training_job_name)
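As we understand it, `Estimator.attach` waits until the job finishes. If you only want to check progress, you can poll the job status yourself. The sketch below takes the describe function as a parameter so it can run offline; in the notebook you would pass `boto3.client("sagemaker").describe_training_job`, whose response includes `TrainingJobStatus` (`InProgress`, `Completed`, `Failed`, `Stopping`, or `Stopped`). The helper name `wait_for_training_job` is hypothetical.

```python
import time

TERMINAL_STATES = {"Completed", "Failed", "Stopped"}

def wait_for_training_job(describe, job_name, poll_seconds=60, sleep=time.sleep):
    """Poll a training job until it reaches a terminal state; return the final status."""
    while True:
        desc = describe(TrainingJobName=job_name)
        status = desc["TrainingJobStatus"]
        # SecondaryStatus gives finer detail, e.g. 'Downloading' or 'Training'
        print(status, desc.get("SecondaryStatus", ""))
        if status in TERMINAL_STATES:
            return status
        sleep(poll_seconds)

# Offline usage with a fake describe function standing in for boto3
responses = iter([
    {"TrainingJobStatus": "InProgress", "SecondaryStatus": "Training"},
    {"TrainingJobStatus": "Completed", "SecondaryStatus": "Completed"},
])
final = wait_for_training_job(lambda **kw: next(responses), "demo-job", sleep=lambda s: None)
print(final)  # Completed
```

In the notebook this would be called as `wait_for_training_job(boto3.client("sagemaker").describe_training_job, training_job_name)`.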

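Because the model is designed to take advantage of GPU acceleration, training on a GPU instance is much faster than on a CPU instance: for example, a p3.2xlarge instance takes about 15 minutes, whereas a c5.xlarge instance may take more than 6 hours. The model does not currently support distributed or parallel training, so adding instances, CPUs, or GPUs will not speed up training further.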

After training completes, the model is uploaded to Amazon S3, to the location given by the `output_path` parameter when the `PyTorch` object was created.
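SageMaker writes each job's artifacts under a prefix named after the training job: the contents of the model directory (`/opt/ml/model`) are packed into `model.tar.gz`, and anything the script saved to its output-data directory (`/opt/ml/output/data`) into `output.tar.gz`, both under `<output_path>/<job_name>/output/`. A small sketch of the expected URIs; the helper name and the job name below are made up.

```python
def expected_artifact_uris(output_path, job_name):
    """Build the S3 URIs SageMaker normally uses for a job's artifacts."""
    # rstrip lets output_path be given with or without a trailing slash
    base = f"{output_path.rstrip('/')}/{job_name}/output"
    return {
        "model": f"{base}/model.tar.gz",    # contents of /opt/ml/model
        "output": f"{base}/output.tar.gz",  # contents of /opt/ml/output/data
    }

uris = expected_artifact_uris("s3://my-bucket/artifacts/", "pytorch-training-2020-01-01-00-00-00-000")
print(uris["model"])
# s3://my-bucket/artifacts/pytorch-training-2020-01-01-00-00-00-000/output/model.tar.gz
```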

### Model validation

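Next, download the trained model from Amazon S3 to the local file system of the notebook instance. The code below loads the model, feeds it random noise as input, and displays the generated output as an image.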


In [None]:
from helper import *

last_artifact_location = s3_model_artifacts_location + training_job_name

last_model_url = get_object_path_by_filename(last_artifact_location, 'model.tar.gz')
last_output_url = get_object_path_by_filename(last_artifact_location, 'output.tar.gz')

print(last_model_url)
print(last_output_url)

In [None]:
from sagemaker.s3 import S3Downloader as s3down

!rm -rf ./tmp/* ./model/*
s3down.download(last_model_url, './tmp')
s3down.download(last_output_url, './tmp')

In [None]:
!tar -zxf tmp/model.tar.gz -C ./tmp
!tar -zxf tmp/output.tar.gz -C ./tmp
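If shell commands are unavailable, the same extraction can be done with Python's standard `tarfile` module. A minimal sketch; `extract_archives` is a hypothetical helper that extracts each archive into a destination directory and returns the member names.

```python
import os
import tarfile

def extract_archives(archives, dest="./tmp"):
    """Extract each .tar.gz archive into dest, like `tar -zxf <archive> -C <dest>`."""
    os.makedirs(dest, exist_ok=True)
    extracted = []
    for path in archives:
        with tarfile.open(path, "r:gz") as tar:
            extracted.extend(tar.getnames())
            if hasattr(tarfile, "data_filter"):
                # Newer Pythons: reject absolute paths, '..' components and other unsafe members
                tar.extractall(path=dest, filter="data")
            else:
                tar.extractall(path=dest)
    return extracted

# In the notebook:
# extract_archives(["tmp/model.tar.gz", "tmp/output.tar.gz"], dest="./tmp")
```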

Run the following code to load the trained model and use it to generate a set of "handwritten" digits.

In [None]:
import helper
import matplotlib.pyplot as plt
import numpy as np
import torch
from networks.DCGAN.model import Generator
from networks.DCGAN.model_tools import generate_fake_handwriting

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

params = {'nz': hps['nz'], 'nc': hps['nc'], 'ngf': hps['ngf']}
model = helper.load_model("./tmp/generator_state.pth", model_cls=Generator, params=params, device=device, strict=False)
img = generate_fake_handwriting(model, num_images=64, nz=hps['nz'], device=device)

plt.imshow(np.asarray(img))
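`generate_fake_handwriting` comes from the project's `model_tools` module, and its internals are not shown here. To illustrate the general idea of displaying many samples at once, the hypothetical function below tiles 64 single-channel images into an 8×8 grid with NumPy; it works on plain arrays, not on the project's tensors.

```python
import numpy as np

def tile_images(images, ncols=8):
    """Arrange an (N, H, W) array of grayscale images into one (rows*H, ncols*W) grid."""
    n, h, w = images.shape
    nrows = -(-n // ncols)  # ceiling division
    grid = np.zeros((nrows * h, ncols * w), dtype=images.dtype)
    for i, img in enumerate(images):
        r, c = divmod(i, ncols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = img
    return grid

samples = np.random.rand(64, 28, 28)  # stand-in for generator output
print(tile_images(samples).shape)  # (224, 224)
```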

### Training loss tracking (Optional)

In [None]:
from IPython.display import HTML

HTML("""
<div align="middle">
<img align=left src="tmp/loss_tracking.png" type="image/png" width=600>
</div>""")

### Fake image samples looping (Optional)

In [None]:
from PIL import Image
import os

# Collect the sample images saved by the training script: files directly in
# tmp/ (not in subdirectories) whose names start with "fake" and end with "b0000.png"
fake_files = [
    name for name in os.listdir("tmp")
    if os.path.isfile(os.path.join("tmp", name))
    and name.startswith("fake")
    and name.endswith("b0000.png")
]

fake_files.sort()

images = []
for file in fake_files:
    im = Image.open(f'tmp/{file}')
    images.append(im)

images[0].save('tmp/gan.gif',
               save_all=True, append_images=images[1:], optimize=False, duration=500, loop=0)
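`fake_files.sort()` sorts lexicographically, which yields chronological frames only if the numbers in the filenames are zero-padded. The actual filenames are not shown here; if they are not padded, a natural sort key such as the following sketch (the filenames below are hypothetical) keeps `e10` after `e9`.

```python
import re

def natural_key(name):
    """Split a filename into text and integer chunks so numbers compare numerically."""
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", name)]

files = ["fake-e10-b0000.png", "fake-e2-b0000.png", "fake-e1-b0000.png"]
print(sorted(files))
# ['fake-e1-b0000.png', 'fake-e10-b0000.png', 'fake-e2-b0000.png']
print(sorted(files, key=natural_key))
# ['fake-e1-b0000.png', 'fake-e2-b0000.png', 'fake-e10-b0000.png']
```

In the notebook this would become `fake_files.sort(key=natural_key)`.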

In [None]:
from IPython.display import HTML

HTML("""
<div align="middle">
<img align=left src="tmp/gan.gif" type="image/gif" width=300>
</div>""")