In [None]:
import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import CIFAR10
import numpy as np
from torch  import optim
import torchvision.utils as vutil
from tensorboard_logger import Logger
import torchvision as tv

In [None]:
class Config:
    """Hyper-parameters for the WGAN run; the rest of the notebook reads them as ``opt.<name>``."""

    # ---- data ----
    image_size = 64     # resize / crop size; also the height of the real-image buffer
    image_size2 = 64    # width of the real-image buffer
    nc = 3              # image channels (RGB)
    batch_size = 1024
    workers = 4         # DataLoader worker processes

    # ---- networks ----
    nz = 100            # length of the generator's noise vector
    ngf = 64            # base feature-map count of the generator
    ndf = 64            # base feature-map count of the discriminator (critic)
    gpuids = 2          # NOTE: used as a GPU *count* (models iterate range(gpuids))

    # ---- optimisation ----
    lr = 0.0002
    beta1 = 0.5         # not read by the RMSprop optimizers used in this notebook
    max_epoch = 12      # =1 when debugging
    clamp_num = 0.01    # WGAN weight-clipping bound for the critic

opt = Config()


In [None]:
# Load the data.
# ImageFolder reads images from the class sub-directories of DATA_ROOT; the
# class labels are discarded by the training loop (`real, _ = data`).
DATA_ROOT = '/home/x/data/pre/train_new/nouse/'

# transforms.Scale was renamed to transforms.Resize (and Scale has since been
# removed from torchvision); fall back to Scale only on very old releases.
_resize = getattr(transforms, 'Resize', None) or transforms.Scale

dataset = tv.datasets.ImageFolder(
    DATA_ROOT,
    transform=transforms.Compose([
        _resize(opt.image_size),                     # shorter side -> image_size
        transforms.RandomCrop(opt.image_size),       # random image_size x image_size patch
        transforms.ToTensor(),                       # values in [0, 1]
        transforms.Normalize([0.5] * 3, [0.5] * 3),  # -> [-1, 1], the range of G's Tanh output
    ]))
dataloader = t.utils.data.DataLoader(dataset, opt.batch_size, shuffle=True,
                                     num_workers=opt.workers)

In [None]:
# Network architecture

class ModelG(nn.Module):
    """DCGAN-style generator.

    Maps a noise batch of shape (N, nz, 1, 1) to images of shape
    (N, nc, 64, 64) with values in [-1, 1] (final Tanh).

    Args:
        ngpu: number of GPUs to spread the forward pass over with
            ``nn.parallel.data_parallel`` (devices ``0 .. ngpu-1``).
            0 / None runs the model directly on its current device.
    """

    def __init__(self, ngpu):
        super(ModelG, self).__init__()
        self.ngpu = ngpu
        self.model = nn.Sequential()
        # Module names are kept stable so saved state dicts keep loading.
        # 1x1 noise -> 4x4 feature map.
        self.model.add_module('deconv1', nn.ConvTranspose2d(opt.nz, opt.ngf * 8, 4, 1, 0, bias=False))
        self.model.add_module('bnorm1', nn.BatchNorm2d(opt.ngf * 8))
        self.model.add_module('relu1', nn.ReLU(True))
        # 4x4 -> 8x8
        self.model.add_module('deconv2', nn.ConvTranspose2d(opt.ngf * 8, opt.ngf * 4, 4, 2, 1, bias=False))
        self.model.add_module('bnorm2', nn.BatchNorm2d(opt.ngf * 4))
        self.model.add_module('relu2', nn.ReLU(True))
        # 8x8 -> 16x16
        self.model.add_module('deconv3', nn.ConvTranspose2d(opt.ngf * 4, opt.ngf * 2, 4, 2, 1, bias=False))
        self.model.add_module('bnorm3', nn.BatchNorm2d(opt.ngf * 2))
        self.model.add_module('relu3', nn.ReLU(True))
        # 16x16 -> 32x32
        self.model.add_module('deconv4', nn.ConvTranspose2d(opt.ngf * 2, opt.ngf, 4, 2, 1, bias=False))
        self.model.add_module('bnorm4', nn.BatchNorm2d(opt.ngf))
        self.model.add_module('relu4', nn.ReLU(True))
        # 32x32 -> 64x64 with nc output channels, squashed to [-1, 1].
        self.model.add_module('deconv5', nn.ConvTranspose2d(opt.ngf, opt.nc, 4, 2, 1, bias=False))
        self.model.add_module('tanh', nn.Tanh())

    def forward(self, input):
        """Generate images from ``input`` noise, data-parallel when ``ngpu`` is set."""
        if self.ngpu:
            return nn.parallel.data_parallel(self.model, input,
                                             device_ids=list(range(self.ngpu)))
        # ngpu == 0 / None: plain forward on whatever device the model lives on.
        return self.model(input)

def weight_init(m):
    """DCGAN-style parameter initialisation, intended for ``net.apply(weight_init)``.

    Convolution / transposed-convolution weights are drawn from N(0, 0.02) and
    normalisation-layer weights (the scale) from N(1, 0.02). All other modules,
    and modules without a learnable ``weight``, are left untouched.

    Args:
        m: an ``nn.Module`` visited by ``Module.apply``.
    """
    # Compare case-insensitively: PyTorch class names are 'Conv2d',
    # 'ConvTranspose2d', 'BatchNorm2d', so searching the raw name for the
    # lowercase 'conv' / 'norm' never matches anything.
    class_name = m.__class__.__name__.lower()
    weight = getattr(m, 'weight', None)
    if weight is None:
        # e.g. BatchNorm with affine=False registers weight=None.
        return
    if 'conv' in class_name:
        weight.data.normal_(0, 0.02)
    if 'norm' in class_name:
        weight.data.normal_(1.0, 0.02)
    
class ModelD(nn.Module):
    """DCGAN-style critic (the WGAN "discriminator").

    For an (N, nc, 64, 64) batch it returns a tensor of shape (1,) holding the
    mean critic score over the batch. There is no sigmoid: WGAN trains on the
    raw score, not a probability or a loss.

    Args:
        ngpu: number of GPUs to spread the forward pass over with
            ``nn.parallel.data_parallel`` (devices ``0 .. ngpu-1``).
            0 / None runs the model directly on its current device.
    """

    def __init__(self, ngpu):
        super(ModelD, self).__init__()
        self.ngpu = ngpu
        self.model = nn.Sequential()
        # Spatial sizes below assume a 64x64 input.
        # 64 -> 32; no BatchNorm on the first layer.
        self.model.add_module('conv1', nn.Conv2d(opt.nc, opt.ndf, 4, 2, 1, bias=False))
        self.model.add_module('relu1', nn.LeakyReLU(0.2, inplace=True))

        # 32 -> 16
        self.model.add_module('conv2', nn.Conv2d(opt.ndf, opt.ndf * 2, 4, 2, 1, bias=False))
        self.model.add_module('bnorm2', nn.BatchNorm2d(opt.ndf * 2))
        self.model.add_module('relu2', nn.LeakyReLU(0.2, inplace=True))

        # 16 -> 8
        self.model.add_module('conv3', nn.Conv2d(opt.ndf * 2, opt.ndf * 4, 4, 2, 1, bias=False))
        self.model.add_module('bnorm3', nn.BatchNorm2d(opt.ndf * 4))
        self.model.add_module('relu3', nn.LeakyReLU(0.2, inplace=True))

        # 8 -> 4
        self.model.add_module('conv4', nn.Conv2d(opt.ndf * 4, opt.ndf * 8, 4, 2, 1, bias=False))
        self.model.add_module('bnorm4', nn.BatchNorm2d(opt.ndf * 8))
        self.model.add_module('relu4', nn.LeakyReLU(0.2, inplace=True))

        # 4 -> 1: one raw score per sample.
        self.model.add_module('conv5', nn.Conv2d(opt.ndf * 8, 1, 4, 1, 0, bias=False))

    def forward(self, input):
        """Return the batch-mean critic score (a score, not a loss) with shape (1,)."""
        if self.ngpu:
            scores = nn.parallel.data_parallel(self.model, input,
                                               device_ids=list(range(self.ngpu)))
        else:
            # ngpu == 0 / None: plain forward on the model's current device.
            scores = self.model(input)
        # Flatten per-sample scores, average over the batch, return shape (1,).
        return scores.view(-1, 1).mean(0).view(1)

# Build the generator and the critic, move both to the GPU, then apply the
# custom initialisation (critic first, then generator).
netg = ModelG(opt.gpuids)
netd = ModelD(opt.gpuids)

for net in (netd, netg):
    net.cuda()
for net in (netd, netg):
    net.apply(weight_init)

netg  # cell output: the generator's module summary


In [None]:
# Optimizers. WGAN training avoids momentum-based optimizers such as Adam,
# so both networks use RMSprop with the configured learning rate.
optimizerD, optimizerG = (optim.RMSprop(net.parameters(), lr=opt.lr)
                          for net in (netd, netg))

# Inputs for D and G, allocated once on the GPU:
#   input:       real-image batch, shape (batch_size, nc, image_size, image_size2)
#   noise:       generator latent batch, shape (batch_size, nz, 1, 1)
#   fixed_noise: 64 latent vectors drawn once from N(0, 1) and then kept fixed
input = Variable(
    t.FloatTensor(opt.batch_size, opt.nc, opt.image_size, opt.image_size2).cuda())
noise = Variable(
    t.FloatTensor(opt.batch_size, opt.nz, 1, 1).cuda())
fixed_noise = Variable(
    t.cuda.FloatTensor(64, opt.nz, 1, 1).normal_(0, 1))


In [None]:
# WGAN training loop.
# No criterion (e.g. BCELoss) is used: WGAN needs no log / cross-entropy. The
# critic's mean score is back-propagated directly with a +1 / -1 upstream gradient.
import traceback

from PIL import ImageFile

# Let PIL decode truncated image files instead of raising inside the DataLoader.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Upstream gradients for backward(). They must live on the same device as the
# critic's output (CUDA here); a CPU tensor causes a device mismatch.
one = t.FloatTensor([1]).cuda()
mone = -1 * one

# WGAN schedule: one generator update per `n_critic` critic updates.
n_critic = 5
# Periodic sample dumps to badgen2/ are disabled; set True to re-enable them.
save_samples = False

# Reusable CUDA buffer for real images (resized to each batch below).
input = Variable(t.FloatTensor(opt.batch_size, opt.nc, opt.image_size, opt.image_size2).cuda())

for epoch in range(150, 151):
    try:
        for ii, data in enumerate(dataloader, 0):
            print(ii)

            # ---- critic (D) step ----
            netd.zero_grad()  # necessary: the G step below also accumulates grads in D
            real, _ = data
            input.data.resize_(real.size()).copy_(real.cuda())

            output = netd(input)
            output.backward(one)
            D_x = output.data.mean()

            noise.data.resize_(input.size()[0], opt.nz, 1, 1).normal_(0, 1)
            fake_pic = netg(noise).detach()  # keep gradients out of G during the D step
            output2 = netd(fake_pic)
            output2.backward(mone)
            D_x2 = output2.data.mean()
            optimizerD.step()
            # Weight clipping applies to the critic only; the generator is never clipped.
            for parm in netd.parameters():
                parm.data.clamp_(-opt.clamp_num, opt.clamp_num)

            # ---- generator (G) step: once every n_critic critic steps ----
            if ii % n_critic == 0:
                netg.zero_grad()
                noise.data.normal_(0, 1)
                fake_pic = netg(noise)
                output = netd(fake_pic)
                output.backward(one)
                optimizerG.step()
                D_G_z2 = output.data.mean()

            if save_samples and ii % 10 == 0 and ii > 0:
                fake_u = netg(fixed_noise)
                vutil.save_image(fake_u.data, 'badgen2/fake%s_%s.png' % (epoch, ii))
                print('%s %s' % (epoch, ii))
    except Exception as e:
        # Best effort, as before: abort this epoch but keep the kernel alive.
        # Print the full traceback rather than only the message.
        print(e)
        traceback.print_exc()

![Generated (fake) samples](image/fake.png)