# In [None]:
#GAN
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# ---- Hyper-parameters -------------------------------------------------------
BATCH_SIZE = 64       # number of paintings per training batch
LR_G = 1e-4           # learning rate for the generator
LR_D = 1e-4           # learning rate for the discriminator
N_IDEAS = 5           # think of this as number of ideas for generating an art work
ART_COMPONENTS = 15   # total number of points that make up one painting

# x-coordinates of every painting in a batch: the same ART_COMPONENTS evenly
# spaced values in [-1, 1], repeated on each of the BATCH_SIZE rows.
PAINT_POINTS = np.tile(np.linspace(-1, 1, ART_COMPONENTS), (BATCH_SIZE, 1))

# Preview the two parabolas drawn as the 'upper bound' and 'lower bound'
# reference curves. Each entry is (scale, shift, colour, legend label) for the
# curve  scale * x^2 + shift  evaluated on the first row of PAINT_POINTS.
for scale, shift, colour, tag in (
    (2, 1, '#74BCFF', 'upper bound'),
    (1, 0, '#FF9359', 'lower bound'),
):
    plt.plot(
        PAINT_POINTS[0],
        scale * np.power(PAINT_POINTS[0], 2) + shift,
        c=colour,
        lw=3,
        label=tag,
    )
plt.legend(loc='upper right')
plt.show()

def artist_works():
    """Return one batch of "real" paintings as a float32 tensor.

    Each row is the curve ``a * x**2 + (a - 1)`` evaluated on the matching row
    of ``PAINT_POINTS``, where ``a`` is drawn per row from
    ``np.random.uniform(1, 2)``. The result has the same shape as
    ``PAINT_POINTS``.
    """
    # One coefficient per painting, kept as a column so it broadcasts per row.
    coeff = np.random.uniform(1, 2, size=(BATCH_SIZE, 1))
    squared_x = np.power(PAINT_POINTS, 2)
    curves = coeff * squared_x + (coeff - 1)
    # NumPy yields float64; convert to float32 for the torch networks.
    return torch.from_numpy(curves).to(torch.float32)

# Generator: turns N_IDEAS random "ideas" into one painting of
# ART_COMPONENTS points.
G = nn.Sequential(
    nn.Linear(in_features=N_IDEAS, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=ART_COMPONENTS),
)

# Discriminator: takes one painting and outputs, through the final sigmoid,
# the probability that it is a good (real artist) work.
D = nn.Sequential(
    nn.Linear(in_features=ART_COMPONENTS, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=1),
    nn.Sigmoid(),
)

# One Adam optimizer per network, each with its own learning rate.
opt_D = torch.optim.Adam(params=D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(params=G.parameters(), lr=LR_G)

plt.ion()  # interactive mode so the figure can be redrawn while training runs

# Alternating GAN training: one discriminator step, then one generator step,
# per iteration.
for step in range(10000):
    artist_paintings = artist_works()            # real paintings
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)   # random ideas, one row per painting
    G_paintings = G(G_ideas)                     # fake paintings from the generator

    # ---- Discriminator update ------------------------------------------------
    # D should give real paintings a high probability and fakes a low one.
    # The fakes are detached so this backward pass stays inside D and does not
    # build gradients for G.
    prob_artist0 = D(artist_paintings)           # P(good work) for real paintings
    prob_artist1 = D(G_paintings.detach())       # P(good work) for generated paintings
    D_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1 - prob_artist1))

    opt_D.zero_grad()
    D_loss.backward()
    opt_D.step()

    # ---- Generator update ----------------------------------------------------
    # opt_D.step() changed D's weights in place, so the fakes must be scored
    # again: backpropagating through a graph recorded before that step trips
    # autograd's in-place modification check (RuntimeError) on recent PyTorch.
    prob_artist1 = D(G_paintings)
    # G wants D to rate its paintings as good, i.e. to push prob_artist1 up.
    G_loss = torch.mean(torch.log(1 - prob_artist1))

    opt_G.zero_grad()
    G_loss.backward()   # also leaves grads on D; opt_D.zero_grad() clears them next step
    opt_G.step()

    if step % 50 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.detach().numpy()[0], c='#4AD631', lw=3, label='Generated painting',)
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')
        # Mean probability D assigned to the real paintings in this batch.
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.detach().numpy().mean(), fontdict={'size': 13})
        # Negated discriminator loss for this batch.
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.detach().numpy(), fontdict={'size': 13})
        plt.ylim((0, 3))
        plt.legend(loc='upper right', fontsize=10)
        plt.draw()
        plt.pause(0.01)

plt.ioff()
plt.show()