In [1]:
from config import *

def load(model, model_file, map_location="cpu"):
    """Restore a model's parameters from a checkpoint file.

    Args:
        model: module whose ``load_state_dict`` receives the checkpoint.
        model_file: path (or file-like object) accepted by ``torch.load``.
        map_location: where ``torch.load`` materializes the stored tensors.
            Defaults to ``"cpu"`` so a checkpoint written on a GPU can still
            be read when that device is not available.

    Returns:
        The same ``model`` object, with its parameters overwritten in place.
    """
    # load_state_dict copies values into the model's existing parameters,
    # so the model stays on whatever device it already lives on.
    state_dict = torch.load(model_file, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

In [2]:
# Loss criteria: cross-entropy for the auxiliary (class) output and
# binary cross-entropy on raw logits for the adversarial output.
aux_loss = nn.CrossEntropyLoss()
adversarial_loss_sigmoid = nn.BCEWithLogitsLoss()

# Generator / discriminator pair built by the project helper.
G, D = create_model()

# One Adam optimizer per network, each with its own learning rate.
g_optimizer = torch.optim.Adam(G.parameters(), lr=g_lr, betas=[beta1, beta2])
d_optimizer = torch.optim.Adam(D.parameters(), lr=d_lr, betas=[beta1, beta2])

# Fixed noise and labels, reused at every sampling step so generated
# images can be compared across epochs.
fixed_z = tensor2var(torch.randn(batch_size, z_dim), ctx=ctx)  # (batch_size, z_dim)
fixed_labels = tensor2var(
    torch.randint(0, n_classes, (batch_size,), dtype=torch.long),
    ctx=ctx,
)

def load_model(epoch, G=G, D=D):
    """Return the (G, D) pair, optionally restored from an epoch checkpoint.

    Args:
        epoch: checkpoint index; a negative value means "do not load".
        G: generator to restore; defaults to the module-level ``G`` as it
            was bound when this function was defined.
        D: discriminator to restore; same default binding rule as ``G``.

    Returns:
        ``(G, D)``: untouched when ``epoch`` is negative, otherwise loaded
        from ``results/G{epoch}.pt`` and ``results/D{epoch}.pt``.
    """
    if epoch < 0:
        return G, D
    # Tuple elements are evaluated left to right: G is restored before D.
    return (
        load(G, f'results/G{epoch}.pt'),
        load(D, f'results/D{epoch}.pt'),
    )

In [3]:
import torch

# Epoch to resume from; a negative value makes load_model return the
# networks unchanged (no checkpoint is read).
load_epoch = -1
G, D = load_model(load_epoch)

# Clear the GPU cache.
torch.cuda.empty_cache()

In [4]:
epochs = 15000

# Move the networks and loss modules to the training device once, before the
# loop: the placement is loop-invariant, so repeating it per batch is wasted work.
train_device = f'cuda:{ctx}'
D = D.to(train_device)
G = G.to(train_device)
adversarial_loss_sigmoid = adversarial_loss_sigmoid.to(train_device)
aux_loss = aux_loss.to(train_device)

# Epoch numbering continues after load_epoch so checkpoint file names stay unique
# when resuming.
for epoch in range(load_epoch+1, epochs+load_epoch+1):
    # start time
    start_time = time.time()
    data_loader = data_iter()
    for i, (real_images, labels) in enumerate(data_loader):
        # configure input
        real_images = tensor2var(real_images, ctx=ctx)
        # labels are used as-is (assumed to be 0-based class indices, as
        # CrossEntropyLoss requires; TODO confirm against the dataset)
        labels = tensor2var(labels, ctx=ctx)
        cur_batch = real_images.size(0)

        # adversarial ground truths; real targets are smoothed to 0.9
        valid = tensor2var(torch.full((cur_batch,), 0.9), ctx=ctx)  # (*, )
        fake = tensor2var(torch.full((cur_batch,), 0.0), ctx=ctx)  # (*, )

        # ==================== Train D ==================
        # G.eval() is called at sampling time, so training mode is restored here.
        D.train()
        G.train()

        D.zero_grad()
        # compute loss with real images
        dis_out_real, aux_out_real = D(real_images)
        d_loss_real = adversarial_loss_sigmoid(dis_out_real, valid) + aux_loss(aux_out_real, labels)

        # noise z and random class labels for the generator
        z = tensor2var(torch.randn(cur_batch, z_dim), ctx=ctx)  # (*, z_dim)
        gen_labels = torch.randint(0, n_classes, (cur_batch,), dtype=torch.long)
        gen_labels = tensor2var(gen_labels, ctx=ctx)

        fake_images = G(z, gen_labels)
        # Detach so the discriminator update does not backpropagate through G;
        # any gradients it left on G were discarded by G.zero_grad() anyway.
        dis_out_fake, aux_out_fake = D(fake_images.detach())  # (*,)

        d_loss_fake = adversarial_loss_sigmoid(
            dis_out_fake, fake) + aux_loss(aux_out_fake, gen_labels)

        # total d loss
        d_loss = d_loss_real + d_loss_fake

        d_loss.backward()
        # update D
        d_optimizer.step()

        # calculate dis accuracy
        d_acc = compute_acc(aux_out_real, aux_out_fake, labels, gen_labels)

        # train the generator once every g_num discriminator steps
        if i % g_num == 0:
            # =================== Train G =====================
            G.zero_grad()
            # regenerate from the same noise, this time keeping the graph through G
            fake_images = G(z, gen_labels)

            # compute loss with fake images
            dis_out_fake, aux_out_fake = D(fake_images)  # batch x n

            # G is rewarded when D scores its samples as real and classifies
            # them as the requested labels.
            g_loss_fake = adversarial_loss_sigmoid(dis_out_fake, valid) + \
                aux_loss(aux_out_fake, gen_labels)

            g_loss_fake.backward()
            # update G
            g_optimizer.step()

    # log the most recent batch losses of this epoch as plain Python floats
    logger.add_scalar('d_loss', d_loss.item(), epoch)
    logger.add_scalar('g_loss_fake', g_loss_fake.item(), epoch)
    # end one epoch

    # print out log info
    if epoch % log_step == 0:
        elapsed = time.time() - start_time
        elapsed = str(datetime.timedelta(seconds=elapsed))
        print("Elapsed [{}], G_step [{}/{}], D_step[{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, Acc: {:.4f}"
                    .format(elapsed, epoch, epochs, epoch,
                            epochs, d_loss.item(), g_loss_fake.item(), d_acc))

    # sample images and checkpoint both networks
    if epoch % sample_step == 0:
        G.eval()
        # save real images (the last batch seen in this epoch)
        save_sample(sample_path + '/real_images/', real_images, epoch)
        with torch.no_grad():
            fake_images = G(fixed_z, fixed_labels)
            # save fake images generated from the fixed noise/labels
            save_sample(sample_path + '/fake_images/', fake_images, epoch)
        # save a single combined sample of real and fake images
        save_sample_one_image(sample_path, real_images, fake_images, epoch)
        torch.save(G.state_dict(), f"results/G{epoch}.pt")
        torch.save(D.state_dict(), f"results/D{epoch}.pt")

Elapsed [0:00:19.587910], G_step [0/15000], D_step[0/15000], d_loss: 5.5627, g_loss: 5.4476, Acc: 0.0938
Elapsed [0:00:17.325533], G_step [10/15000], D_step[10/15000], d_loss: 5.3165, g_loss: 7.9280, Acc: 0.1562
Elapsed [0:00:16.998439], G_step [20/15000], D_step[20/15000], d_loss: 5.5939, g_loss: 7.7659, Acc: 0.0312
Elapsed [0:00:17.582046], G_step [30/15000], D_step[30/15000], d_loss: 5.2078, g_loss: 8.0218, Acc: 0.1875
Elapsed [0:00:17.495244], G_step [40/15000], D_step[40/15000], d_loss: 5.2776, g_loss: 8.1110, Acc: 0.1875
Elapsed [0:00:17.365253], G_step [50/15000], D_step[50/15000], d_loss: 5.3861, g_loss: 8.3730, Acc: 0.1250
Elapsed [0:00:17.674589], G_step [60/15000], D_step[60/15000], d_loss: 5.2897, g_loss: 7.0865, Acc: 0.1875
Elapsed [0:00:17.889655], G_step [70/15000], D_step[70/15000], d_loss: 5.4459, g_loss: 9.3411, Acc: 0.1562
Elapsed [0:00:17.411195], G_step [80/15000], D_step[80/15000], d_loss: 5.1701, g_loss: 7.5020, Acc: 0.2500
Elapsed [0:00:17.585452], G_step [90/15