In [1]:
import  torch
from    torch import nn, optim, autograd
import  numpy as np
import  visdom
from    torch.nn import functional as F
from    matplotlib import pyplot as plt
import  random
import torchvision.models as model
import torchvision.datasets as datasets
import torch.utils.data as data
import torchvision.transforms as transforms

In [2]:
import time

In [9]:
viz=visdom.Visdom()
viz.images(np.random.randn(16, 3, 64, 64),win='X',
    opts=dict(title='fake image', caption='fake image'))
while True:
    viz.images(np.random.randn(16, 3, 64, 64),win='X',
    opts=dict(title='fake image', caption='fake image'), update='append')
    time.sleep(1)

Setting up a new session...


TypeError: images() got an unexpected keyword argument 'update'

In [2]:
batch_size=16
epochs = 1000
real_data = datasets.CIFAR10('./realdata/',transform=transforms.ToTensor(),download=True)
real_loader = data.DataLoader(real_data,batch_size=batch_size,shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Files already downloaded and verified


In [3]:
#网络的输入为[batch_size,100]----每个样本100维
#网络的输出为[batch_size，3，32，32]----每个样本为3通道的32×32图片
class Generator(nn.Module):
    def __init__(self,dp=False):
        super(Generator,self).__init__()
        self.linear = nn.Linear(100,256)#32×32/4=256,先转成1/4大小单通道，再上采样
        self.upsample=nn.Sequential(
        nn.ConvTranspose2d(1,3,kernel_size=2,stride=2,bias=False),
        nn.BatchNorm2d(3),
        nn.ReLU(inplace=True),
        #nn.Dropout(0.15,inplace=True)
        )
        self.conv1 = nn.Sequential(
        nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1,bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        #nn.Dropout(0.15,inplace=True)
        )
        self.conv2 = nn.Sequential(
        nn.Conv2d(64,128,kernel_size=1,stride=1,padding=0,bias=False),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),
        #nn.Dropout(0.15,inplace=True)
        )
        self.conv3 = nn.Sequential(
        nn.Conv2d(128,64,kernel_size=1,stride=1,padding=0,bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        #nn.Dropout(0.15,inplace=True)
        )
        self.conv4 = nn.Sequential(
        nn.Conv2d(64,3,kernel_size=3,stride=1,padding=1,bias=False),
        nn.BatchNorm2d(3),
        nn.ReLU(inplace=True)
        )
        
    def forward(self,x):
        x = F.relu(self.linear(x))
        #print(x.shape)
        x = x.view(x.size(0),1,16,16)
        x = self.upsample(x)
        x=self.conv4(self.conv3(self.conv2(self.conv1(x))))
        return x

In [5]:
#基于resnet18修改得到
class Discriminator(nn.Module):
    def __init__(self,model):
        super(Discriminator,self).__init__()
        self.res = nn.Sequential(*list(resnet18.children())[:-1])
        self.fc = nn.Linear(512,1)
        self.sig = nn.Sigmoid()
    def forward(self,x):
        x = self.res(x)
        #print(x.size(),"d")
        x = x.view(x.size(0),-1)
        x = self.fc(x)
        x = self.sig(x)
        return x.view(-1)

In [47]:
G = Generator().to(device)
resnet18 = model.resnet18(pretrained=True)
D = Discriminator(resnet18).to(device)
optim_d = optim.Adam(D.parameters(),lr=0.001,betas=(0.5,0.9))
optim_g = optim.Adam(G.parameters(),lr=0.001,betas=(0.5,0.9))

In [48]:
def gradient_penalty(x_real,x_fake):
    lambada = 0.3
    x_real = x_real.detach()
    x_fake = x_fake.detach()
    alpha = torch.rand(batch_size,1,1,1).expand_as(x_real).to(device)
    interpolates = alpha*x_real+(1-alpha)*x_fake
    interpolates.requires_grad=True
    D_interpolates = D(interpolates)
    gradient = autograd.grad(outputs=D_interpolates,inputs=interpolates,
                             grad_outputs=torch.ones_like(D_interpolates),
                             create_graph=True,retain_graph=True,only_inputs=True)[0]
    norm = []
    for i in gradient:
        norm.append(torch.norm(i))
    norm = torch.tensor(norm).to(device)
    return ((norm-1)**2).mean()*lambada

In [49]:
def train_discriminator(x_real):
    x_real = x_real.to(device)
    #print(x_real.shape)
    pred_real = D(x_real)
    z = torch.rand(batch_size,100).to(device)
    x_fake = G(z).detach()
    pred_fake = D(x_fake)
    #print(pred_real)
    loss_real = -(pred_real.mean())
    loss_fake = pred_fake.mean()
    gp = gradient_penalty(x_real,x_fake)
    #del x_real,x_fake
    loss = loss_real+loss_fake+gp
    optim_d.zero_grad()
    loss.backward()
    optim_d.step()
    return loss

In [50]:
def train_generator():
    l = 0
    nums = 0
    for i in range(10):
        z = torch.rand(batch_size,100).to(device)
        x_fake = G(z)
        #del z
        pred_f = D(x_fake)
        loss = -(pred_f.mean())
        optim_g.zero_grad()
        #print(pred_f)
        #print(loss)
        loss.backward()
        optim_g.step()
        #nums+=len(x_real)
        l+=(loss*batch_size)
    return l/10*batch_size,x_fake

In [89]:
def main():
    D.train()
    G.train()
    viz = visdom.Visdom()
    viz.line([[0,0]], [0], win='loss', opts=dict(title='loss',legend=['D', 'G']))
    viz.images(np.random.randn(batch_size, 3, 64, 64),
    opts=dict(title='fake image', caption='fake image'))
    for epoch in range(epochs):
        torch.manual_seed(32)#设置一个随机数种子，设置后每次随机取值都会从同样的值开始迭代取值
        for i,(x_real,_) in enumerate(real_loader):
            loss_d,num = 0,0
            num+=len(x_real)
            loss_d += (train_discriminator(x_real).item())*len(x_real)
            if (i+1)%10==0:
                loss_g,img = train_generator()
                print('epoch: %d,batch: %d,discriminator loss: %f,generator loss: %f'%(epoch,i+1,loss_d/num,loss_g))
                #print(loss_d/num,loss_d/num)
                #print(type(loss_d))
                #print(type(loss_g))
                viz.line([[loss_d/num,loss_g.item()]], [epoch*len(real_loader)//10+(i+1)//10], win='loss', update='append')
                loss_d,num = 0,0

In [90]:
main()

Setting up a new session...


epoch: 0,batch: 10,discriminator loss: 0.225868,generator loss: -26.615234
epoch: 0,batch: 20,discriminator loss: 0.816112,generator loss: -28.149494
epoch: 0,batch: 30,discriminator loss: 0.328266,generator loss: -25.621403
epoch: 0,batch: 40,discriminator loss: -0.517473,generator loss: -25.615309
epoch: 0,batch: 50,discriminator loss: 1.135155,generator loss: -34.984409
epoch: 0,batch: 60,discriminator loss: -0.512183,generator loss: -33.452843
epoch: 0,batch: 70,discriminator loss: -0.387489,generator loss: -32.088509
epoch: 0,batch: 80,discriminator loss: -0.480134,generator loss: -25.656391
epoch: 0,batch: 90,discriminator loss: -0.613337,generator loss: -26.898777
epoch: 0,batch: 100,discriminator loss: -0.512195,generator loss: -26.274153
epoch: 0,batch: 110,discriminator loss: -0.508509,generator loss: -21.867231
epoch: 0,batch: 120,discriminator loss: 0.762613,generator loss: -22.861303
epoch: 0,batch: 130,discriminator loss: 0.362146,generator loss: -36.459007
epoch: 0,batch

KeyboardInterrupt: 