In [1]:
import os
import time
import torch
import sys 
from pathlib import Path
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
sys.path.append(os.path.abspath(r"../.."))
from model.generator import Generator, Generator_Transpose
from model.discriminator import DiscriminatorResnet, DiscriminatorLinear, DiscriminatorConv
from utils_.utils import weights_init, weight_init
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

## 1. Model parameters

In [2]:
# Shape of a single image as [channels, height, width]; passed to the discriminator constructor.
input_size = [3, 32, 32]
# Number of images per batch served by the DataLoaders.
batch_size = 128
# Number of passes over the training DataLoader performed by the training loop.
Epoch = 1000
# NOTE(review): not referenced by any cell visible in this notebook - confirm whether it is still needed.
GenEpoch = 1
# Channel count of the random noise tensor fed to the generator (also its `in_channel` argument).
in_channel = 64

## 2. Result save path

In [3]:
# Timestamped run directory, so runs started at different times keep separate results.
time_now = time.strftime('%Y-%m-%d-%H_%M_%S', time.localtime())
log_path = f'./log/{time_now}'
# Directory layout used by the training cell:
#   <log_path>/                  checkpoints and the loss curve
#   <log_path>/image/            one sample grid per epoch
#   <log_path>/image/image_all/  sample grids saved while an epoch is running
# One call creates the whole chain of directories; exist_ok=True keeps the cell
# re-runnable even when it is executed twice within the same second.
os.makedirs(f'{log_path}/image/image_all', exist_ok=True)

In [4]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'using device: {device}')

using device: cuda


## 3. Loading data

In [6]:
# Load the CIFAR-10 dataset (training and test splits)
cifar_root = '../../../../data/cifar'
# Both splits share one preprocessing pipeline: resize to 32x32, then convert to a tensor.
cifar_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# Training split, wrapped in a shuffling DataLoader.
cifar_train = DataLoader(
    datasets.CIFAR10(cifar_root, train=True, transform=cifar_transform, download=True),
    batch_size=batch_size,
    shuffle=True,
)

# Test split, wrapped the same way.
cifar_test = DataLoader(
    datasets.CIFAR10(cifar_root, train=False, transform=cifar_transform, download=True),
    batch_size=batch_size,
    shuffle=True,
)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Sanity check: print the tensor shapes of the first training batch, then stop.
for images, targets in cifar_train:
    print(images.shape, targets.shape)
    break

torch.Size([128, 3, 32, 32]) torch.Size([128])


## 4. Model definition

In [8]:
# Instantiate the generator and the discriminator.
gen = Generator_Transpose(in_channel=in_channel)
dis = DiscriminatorConv(input_size=input_size)
# Apply the custom weight initialisation to each network, then move it to the
# compute device (both calls modify the module in place).
for network in (gen, dis):
    network.apply(weight_init)
    network.to(device)
# Show both architectures as the cell output.
gen, dis

(Generator_Transpose(
   (up1): Sequential(
     (0): ConvTranspose2d(64, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
     (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): ReLU(inplace=True)
   )
   (up2): Sequential(
     (0): ConvTranspose2d(256, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)
     (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): ReLU(inplace=True)
   )
   (up3): Sequential(
     (0): ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=(2, 2), bias=False)
     (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): ReLU(inplace=True)
   )
   (up4): Sequential(
     (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), bias=False)
     (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (2): ReLU(inplace=True)
   )
   (out_layer): Sequential(
     (0): ConvTranspose2d(32, 3, ker

## 6. Train the model

In [9]:
# Loss function: cross-entropy over the discriminator's class scores.
# Targets used below: class 1 = real image, class 0 = generated image.
Loss = nn.CrossEntropyLoss()
# Optimizers: a separate Adam instance for each network
opt_gen = optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_dis = optim.Adam(dis.parameters(), lr=2e-4, betas=(0.5, 0.999))
# Model training
gen.train()
dis.train()
gen_loss_list = []       # mean generator loss of every finished epoch
dis_loss_list = []       # mean discriminator loss of every finished epoch
num_batches = len(cifar_train)
for epoch in range(Epoch):
    gen_loss_avg = []    # per-batch generator losses of the current epoch
    dis_loss_avg = []    # per-batch discriminator losses of the current epoch
    with tqdm(total=num_batches, desc=f'Epoch {epoch + 1}/{Epoch}') as pbar:
        for batchidx, (img, _) in enumerate(cifar_train):
            # Fetch the real images (class labels of the dataset are not used)
            img = img.to(device)
            batch_len = img.size(0)
            # Target labels, created directly on the compute device
            valid = torch.ones(batch_len, dtype=torch.int64, device=device)
            fake = torch.zeros(batch_len, dtype=torch.int64, device=device)
            # Random noise fed to the generator. It does not need requires_grad:
            # only the network parameters are optimised.
            G_img = torch.randn(batch_len, in_channel, 1, 1, device=device)
            # ------------------ discriminator update ------------------
            # Forward pass; detach() keeps the discriminator loss from
            # back-propagating into the generator.
            G_pred_gen = gen(G_img)
            G_pred_dis = dis(G_pred_gen.detach())
            R_pred_dis = dis(img)
            # Loss: mean of the "generated" and "real" terms
            G_loss = Loss(G_pred_dis, fake)
            R_loss = Loss(R_pred_dis, valid)
            dis_loss = (R_loss + G_loss) / 2
            dis_loss_avg.append(dis_loss.item())
            # Backward pass
            opt_dis.zero_grad()
            dis_loss.backward()
            opt_dis.step()
            # ------------------ generator update ------------------
            # Forward pass through the freshly updated discriminator
            G_pred_gen = gen(G_img)
            G_pred_dis = dis(G_pred_gen)
            # Loss: the generator is rewarded when its images are classified as real
            gen_loss = Loss(G_pred_dis, valid)
            gen_loss_avg.append(gen_loss.item())
            # Backward pass
            opt_gen.zero_grad()
            gen_loss.backward()
            opt_gen.step()
            # Save intermediate samples every 100 batches and on the last batch
            if batchidx % 100 == 0 or batchidx + 1 == num_batches:
                save_image(G_pred_gen, f'{log_path}/image/image_all/epoch-{epoch}-index-{batchidx}.png')
            # ------------------ progress bar update ------------------
            pbar.set_postfix(**{
                'gen-loss': sum(gen_loss_avg) / len(gen_loss_avg),
                'dis-loss': sum(dis_loss_avg) / len(dis_loss_avg)
            })
            pbar.update(1)
    # Epoch summary: computed once and reused for the file name and the history
    gen_loss_mean = sum(gen_loss_avg) / len(gen_loss_avg)
    dis_loss_mean = sum(dis_loss_avg) / len(dis_loss_avg)
    # Samples generated from the last batch of the epoch
    save_image(G_pred_gen, f'{log_path}/image/epoch-{epoch}.png')
    # Checkpoint both networks; the file name records the epoch index and mean losses
    filename = 'epoch%d-genLoss%.2f-disLoss%.2f' % (epoch, gen_loss_mean, dis_loss_mean)
    torch.save(gen.state_dict(), f'{log_path}/{filename}-gen.pth')
    torch.save(dis.state_dict(), f'{log_path}/{filename}-dis.pth')
    # Record the losses
    gen_loss_list.append(gen_loss_mean)
    dis_loss_list.append(dis_loss_mean)
    # Redraw the loss curves and overwrite loss.png. The two curves use
    # different line styles so they can be told apart.
    fig, ax = plt.subplots()
    epochs_so_far = range(epoch + 1)
    ax.plot(epochs_so_far, gen_loss_list, 'r--', label='gen loss')
    ax.plot(epochs_so_far, dis_loss_list, 'b-', label='dis loss')
    ax.legend()
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss')
    fig.savefig(f'{log_path}/loss.png', dpi=300)
    # Close the figure so that one figure per epoch does not accumulate in memory
    plt.close(fig)

Epoch 1/1000:  32%|███▏      | 125/391 [00:08<00:17, 14.97it/s, dis-loss=0.61, gen-loss=0.748] 


KeyboardInterrupt: 

## 7. Predict

In [10]:
# Settings are repeated here so that this cell only depends on the import cell,
# not on the configuration or training cells.
input_size = [3, 32, 32]
in_channel = 64
# NOTE: point both paths at checkpoints that exist on this machine;
# torch.load raises FileNotFoundError otherwise.
gen_para_path = '../../log/2023-02-11-17_52_12/epoch999-genLoss1.21-disLoss0.40-gen.pth'
dis_para_path = '../../log/2023-02-11-17_52_12/epoch999-genLoss1.21-disLoss0.40-dis.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator_Transpose(in_channel=in_channel).to(device)
# NOTE(review): the training cell in this notebook builds DiscriminatorConv;
# confirm that DiscriminatorLinear is the architecture this checkpoint was saved from.
dis = DiscriminatorLinear(input_size=input_size).to(device)
# Restore the trained weights of BOTH networks. Without loading the discriminator
# checkpoint, the printed score would come from randomly initialised weights.
# Only load checkpoints from trusted sources: torch.load unpickles the file.
gen.load_state_dict(torch.load(gen_para_path, map_location=device))
dis.load_state_dict(torch.load(dis_para_path, map_location=device))
# Switch both networks to inference mode
gen.eval()
dis.eval()
# Inference only: no autograd graph is needed
with torch.no_grad():
    # Draw one random noise sample
    G_img = torch.randn([1, in_channel, 1, 1], device=device)
    # Run it through the networks
    G_pred = gen(G_img)
    G_dis = dis(G_pred)
print('generator-dis:', G_dis)
# Image display: take the single sample, scale to 0..255 and reorder CHW -> HWC
pixels = G_pred[0, ...].cpu().numpy()
# Clip before the uint8 cast so out-of-range values saturate instead of wrapping around
pixels = np.clip(pixels * 255, 0, 255)
pixels = np.transpose(pixels, [1, 2, 0])
# As before, G_pred ends up holding the PIL image
G_pred = Image.fromarray(np.uint8(pixels))
G_pred.show()

FileNotFoundError: [Errno 2] No such file or directory: '../../log/2023-02-11-17_52_12/epoch999-genLoss1.21-disLoss0.40-gen.pth'