In [2]:
# Experiment configuration for the DCGAN notebook.
# (Author's original note: "dataset file access issue".)
import argparse

parser = argparse.ArgumentParser("DCGAN")

# Value-carrying options as (flag, type, default), registered in this exact order.
_VALUE_OPTIONS = (
    ('--dataset_dir', str, ""),      # dataset directory
    ('--result_dir', str, ''),       # log image directory
    ('--batch_size', int, 64),       # batch size
    ('--n_epoch', int, 20),          # number of training epochs
    ('--n_cpu', int, 4),             # number of dataset worker processes
    ('--log_iter', int, 1000),       # print log message / save checkpoint / show images every N iterations
    ('--nz', int, 100),              # noise (latent) dimension
    ('--nc', int, 3),                # image channels (generator output, discriminator input)
    ('--ndf', int, 64),              # base feature-map width of the Discriminator
    ('--ngf', int, 64),              # base feature-map width of the Generator
    ('--lr', float, 0.0002),         # learning rate
    ('--device', str, "NPU"),        # target device name
    ('--beta', float, 0.5),          # Adam beta1
    ('--criterion', str, 'BCE'),     # 'BCE' or 'MSE'
)
for _flag, _kind, _default in _VALUE_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default)

# Boolean switch: apply Tanh to the Generator output.
parser.add_argument('--tanh', action='store_true')

# parse_known_args (rather than parse_args) ignores unrecognised arguments,
# e.g. extra flags the notebook kernel process may have been started with.
config, _ = parser.parse_known_args()



构建网络

In [3]:
import mindspore.nn as nn
import os
class Generator(nn.Cell):
    """DCGAN generator: a stack of 4x4 transposed convolutions mapping noise to images.

    Input is expected to be a ``(N, config.nz, 1, 1)`` noise tensor. With the
    transposed-conv size formula ``(H - 1) * stride - 2 * padding + 4`` the spatial
    size goes 1 -> 4 -> 8 -> 16 -> 32 -> 64, i.e. the output is
    ``(N, config.nc, 64, 64)``. Channel widths go ngf*8 -> ngf*4 -> ngf*2 -> ngf -> nc.
    The last layer has no normalization/activation; Tanh is applied only when
    ``config.tanh`` is set (NOTE(review): real images are normalized to [-1, 1],
    so the default of no Tanh leaves the output range unbounded).
    """

    def __init__(self):
        super(Generator, self).__init__()

        def deconv(c_in, c_out, stride, padding):
            # 4x4 transposed convolution, explicit padding, no bias.
            return nn.Conv2dTranspose(in_channels=c_in, out_channels=c_out, kernel_size=4,
                                      stride=stride, padding=padding, pad_mode="pad",
                                      has_bias=False)

        # Project the 1x1 noise to a 4x4 map with ngf*8 channels.
        layers = [deconv(config.nz, config.ngf * 8, 1, 0),
                  nn.BatchNorm2d(config.ngf * 8),
                  nn.ReLU()]
        # Three upsampling stages: each doubles H/W and halves the channel width.
        for mult in (8, 4, 2):
            width_out = config.ngf * (mult // 2)
            layers += [deconv(config.ngf * mult, width_out, 2, 1),
                       nn.BatchNorm2d(width_out),
                       nn.ReLU()]
        # Output stage down to nc image channels.
        layers.append(deconv(config.ngf, config.nc, 2, 1))

        self.tanh = nn.Tanh()
        self.main = nn.SequentialCell(*layers)

    def construct(self, x):
        images = self.main(x)
        return self.tanh(images) if config.tanh else images


class Discriminator(nn.Cell):
    """DCGAN discriminator: strided 4x4 convolutions reducing an image to one score.

    For a ``(N, config.nc, 64, 64)`` input the spatial size goes
    64 -> 32 -> 16 -> 8 -> 4 -> 1 (conv size formula ``(H + 2p - 4) // s + 1``),
    and the result is reshaped to ``(N, 1)``. Channel widths go
    nc -> ndf -> ndf*2 -> ndf*4 -> ndf*8 -> 1. A sigmoid is applied only when
    ``config.criterion == 'BCE'``; otherwise raw scores are returned.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        def down(c_in, c_out):
            # 4x4 convolution, stride 2, padding 1: halves H and W. No bias.
            return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2,
                             padding=1, has_bias=False, pad_mode="pad")

        # First stage has no BatchNorm.
        layers = [down(config.nc, config.ndf), nn.LeakyReLU(0.2)]
        # Three more downsampling stages, each doubling the channel width.
        for mult in (1, 2, 4):
            width_out = config.ndf * mult * 2
            layers += [down(config.ndf * mult, width_out),
                       nn.BatchNorm2d(width_out),
                       nn.LeakyReLU(0.2)]
        # Final 4x4 valid convolution to a single score channel.
        layers.append(nn.Conv2d(in_channels=config.ndf * 8, out_channels=1, kernel_size=4,
                                stride=1, padding=0, has_bias=False, pad_mode="pad"))

        self.main = nn.SequentialCell(*layers)
        self.sigmoid = nn.Sigmoid()

    def construct(self, x):
        scores = self.main(x).reshape(-1, 1)
        return self.sigmoid(scores) if config.criterion == 'BCE' else scores




In [4]:
def show_generated_images(fake_images, max_images=10):
    """Display the first few images of a generated batch in one matplotlib row.

    The whole batch is min-max rescaled to [0, 255] using ONE global minimum and
    maximum (relative brightness/colour across images is preserved), cast to
    uint8 and shown without axes.

    Args:
        fake_images: tensor with ``.asnumpy()`` holding a batch in (N, C, H, W)
            layout, e.g. ``Generator`` output. C should be 1, 3 or 4 for imshow.
        max_images: maximum number of images to display (default 10).
    """
    images = fake_images.asnumpy().astype(np.float32)
    if images.shape[0] == 0:
        return
    # Global min/max. (ops.aminmax reduces along a single axis -- axis 0 by
    # default -- which rescaled each pixel position across the batch instead.)
    lo, hi = float(images.min()), float(images.max())
    span = hi - lo
    if span > 0:
        images = (images - lo) / span * 255.0
    else:
        # Constant batch: avoid division by zero, show it as black.
        images = np.zeros_like(images)
    # NCHW -> NHWC, which is the layout matplotlib expects.
    images_hwc = images.transpose(0, 2, 3, 1).astype(np.uint8)
    if images_hwc.shape[-1] == 1:
        # imshow wants (H, W) for single-channel images.
        images_hwc = images_hwc[..., 0]

    n_show = min(images_hwc.shape[0], max_images)
    # squeeze=False keeps `axes` 2-D even when only one image is shown.
    fig, axes = plt.subplots(1, n_show, figsize=(10, 2), squeeze=False)
    for ax, img in zip(axes[0], images_hwc[:n_show]):
        ax.imshow(img, cmap='gray' if img.ndim == 2 else None)
        ax.axis('off')
    plt.show()
    # Release the figure; this is called repeatedly during training.
    plt.close(fig)

In [5]:
import mindspore.experimental.optim as optim
import mindspore as ms
import mindspore.dataset as msda
from PIL import Image
import numpy as np
from mindspore import Tensor,ops,context,Parameter,value_and_grad
import matplotlib.pyplot as plt


class Trainer:
    """Alternating discriminator/generator training loop for the DCGAN.

    Owns a ``Generator``, a ``Discriminator``, the adversarial loss (chosen by
    ``config.criterion``) and one Adam optimizer per network (learning rate
    ``config.lr``, beta1 ``config.beta``, beta2 0.999). Gradients come from
    ``mindspore.value_and_grad``; the networks are evaluated *inside* the
    differentiated forward functions so the gradients actually depend on the
    network weights.
    """

    def __init__(self):
        self.generator = Generator()
        self.discriminator = Discriminator()
        # The Discriminator applies a sigmoid exactly when criterion == 'BCE', so
        # pair that with binary cross-entropy; otherwise least-squares (MSE).
        if config.criterion == 'BCE':
            self.loss = nn.BCELoss(reduction='mean')
        else:
            self.loss = nn.MSELoss()
        # Best losses seen at logging steps; a checkpoint is written on improvement.
        self.real_best_loss = float('inf')
        self.fake_best_loss = float('inf')

        self.optimizer_g = nn.Adam(self.generator.trainable_params(), learning_rate=config.lr,
                                   beta1=config.beta, beta2=0.999)
        self.optimizer_d = nn.Adam(self.discriminator.trainable_params(), learning_rate=config.lr,
                                   beta1=config.beta, beta2=0.999)

        # Build the gradient functions once instead of on every step. With
        # has_aux=True only the first output (the loss) is differentiated, and
        # only w.r.t. the weights owned by the corresponding optimizer.
        self._value_and_grad_d = value_and_grad(self.forward_fn_d, None,
                                                self.optimizer_d.parameters, has_aux=True)
        self._value_and_grad_g = value_and_grad(self.forward_fn_g, None,
                                                self.optimizer_g.parameters, has_aux=True)

    def forward_fn_g(self, noise, label):
        """Generator loss: distance of D(G(noise)) from ``label`` (normally all-real).

        Returns ``(loss, d_fake)``.
        """
        d_fake = self.discriminator(self.generator(noise))
        return self.loss(d_fake, label), d_fake

    def forward_fn_d(self, real, label_real, fake, label_fake):
        """Discriminator loss on a real batch plus a fake batch.

        Returns ``(loss, d_real, d_fake)``.
        """
        d_real = self.discriminator(real)
        d_fake = self.discriminator(fake)
        loss = self.loss(d_real, label_real) + self.loss(d_fake, label_fake)
        return loss, d_real, d_fake

    def grad_fn_g(self, noise, label_real):
        """Return ``((loss, d_fake), grads)`` with grads w.r.t. the generator weights."""
        return self._value_and_grad_g(noise, label_real)

    def grad_fn_d(self, real, label_real, fake, label_fake):
        """Return ``((loss, d_real, d_fake), grads)`` with grads w.r.t. the discriminator weights."""
        return self._value_and_grad_d(real, label_real, fake, label_fake)

    def train_d(self, label_real, real, label_fake, fake):
        """Apply one discriminator update; return ``(loss, d_real)``."""
        # The fake batch is a constant for the discriminator step.
        (loss, d_real, _), grads = self.grad_fn_d(real, label_real,
                                                  ops.stop_gradient(fake), label_fake)
        self.optimizer_d(grads)
        return loss, d_real

    def train_g(self, noise, label_real):
        """Apply one generator update from ``noise``; return ``(loss, d_fake)``.

        The generator is re-run on ``noise`` inside the differentiated function so
        gradients flow through G (a precomputed fake batch would carry none).
        """
        (loss, d_fake), grads = self.grad_fn_g(noise, label_real)
        self.optimizer_g(grads)
        return loss, d_fake

    def train(self, dataloader):
        """Train both networks for ``config.n_epoch`` epochs.

        Args:
            dataloader: MindSpore dataset yielding ``(images, attr)`` tuples (only
                ``images`` is used); must provide ``get_dataset_size()``.

        Every ``config.log_iter`` steps it prints mean discriminator outputs and
        losses, writes ``discriminator_best.ckpt`` / ``generator_best.ckpt`` when
        the respective loss beats the best seen so far, and displays the fakes.
        """
        # Cells start in inference mode; BatchNorm must use batch statistics and
        # update its moving averages during training.
        self.generator.set_train(True)
        self.discriminator.set_train(True)
        steps_per_epoch = dataloader.get_dataset_size()

        for epoch in range(config.n_epoch):
            for i, (data, _) in enumerate(dataloader):
                real = ops.cast(data, ms.float32)
                batch = real.shape[0]
                # Size labels and noise by the actual batch so a short final
                # batch cannot cause a shape mismatch in the loss.
                label_real = Tensor(np.ones((batch, 1)), dtype=ms.float32)
                label_fake = Tensor(np.zeros((batch, 1)), dtype=ms.float32)
                noise = ops.standard_normal((batch, config.nz, 1, 1))

                fake = self.generator(noise)
                loss_d, d_real = self.train_d(label_real, real, label_fake, fake)
                loss_g, d_fake = self.train_g(noise, label_real)

                if i % config.log_iter == 0:
                    loss_d_val = float(loss_d.asnumpy())
                    loss_g_val = float(loss_g.asnumpy())
                    print("[Epoch {:03d}] ({}/{}) d_real: {:.4f}, d_fake: {:.4f}, "
                          "loss_d: {:.4f}, loss_g: {:.4f}".format(
                              epoch, i, steps_per_epoch,
                              float(d_real.mean().asnumpy()), float(d_fake.mean().asnumpy()),
                              loss_d_val, loss_g_val))
                    if loss_d_val < self.real_best_loss:
                        self.real_best_loss = loss_d_val
                        ms.save_checkpoint(self.discriminator, 'discriminator_best.ckpt')
                    if loss_g_val < self.fake_best_loss:
                        self.fake_best_loss = loss_g_val
                        ms.save_checkpoint(self.generator, 'generator_best.ckpt')
                    show_generated_images(fake)
                                                             


In [None]:
def main(full_dataset=False):
    """Build the CelebA input pipeline, select the device and train the DCGAN.

    ``config.dataset_dir`` / ``config.result_dir`` fall back to the example
    folders "CeleA" / "result" only when they were not given on the command line.

    Args:
        full_dataset: when False (default, the original demo behaviour) only
            ``config.batch_size`` samples are read, i.e. a single batch per
            epoch; when True the whole dataset is used.

    Returns:
        The ``Trainer`` instance after training.
    """
    # Example folders, used only when the CLI did not supply paths.
    if not config.dataset_dir:
        config.dataset_dir = "CeleA"
    if not config.result_dir:
        config.result_dir = "result"

    # Honour --device; "NPU" (the default) is MindSpore's "Ascend" target.
    # Unrecognised values keep the previously hard-coded "Ascend".
    device_targets = {"NPU": "Ascend", "ASCEND": "Ascend", "GPU": "GPU", "CPU": "CPU"}
    ms.set_context(device_target=device_targets.get(config.device.upper(), "Ascend"))

    # Resize to 64x64, HWC uint8 -> CHW float in [0, 1], then normalize to [-1, 1].
    transform = msda.transforms.Compose([
        msda.vision.Resize((64, 64)),
        msda.vision.ToTensor(),
        msda.vision.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), False)])
    ms_dataloader = msda.CelebADataset(
        config.dataset_dir, usage='all', decode=True, shuffle=True,
        num_parallel_workers=config.n_cpu,
        num_samples=None if full_dataset else config.batch_size)
    ms_dataloader = ms_dataloader.map(transform, ["image"])
    # Drop a short trailing batch so every step sees exactly config.batch_size images.
    ms_dataloader = ms_dataloader.batch(config.batch_size, drop_remainder=True)

    trainer = Trainer()
    trainer.train(ms_dataloader)
    return trainer
    

if __name__ == '__main__':
    # Recycle MindSpore's cached runtime memory (e.g. from an earlier run in
    # this kernel) before training.
    ms.ms_memory_recycle()
    main()
    # Sample one batch from the saved generator checkpoint, if one exists.
    if os.path.exists("generator_best.ckpt"):
        gen = Generator()
        # Read the parameters once and load them straight into `gen`; no second,
        # throw-away Generator is needed.
        param_dict_g = ms.load_checkpoint("generator_best.ckpt")
        ms.load_param_into_net(gen, param_dict_g)
        gen.set_train(False)  # inference mode: BatchNorm uses its moving statistics
        noise = ops.standard_normal((config.batch_size, config.nz, 1, 1))
        fake_images = gen(noise)
        show_generated_images(fake_images)

0
