In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import torch
import torchvision
from torchvision import transforms
from torch.utils import data

In [4]:
from model import get_Generator, get_Discriminator

generator = get_Generator(from_old_model=False, model_path=None, device='cpu', G_type="L")
discriminator = get_Discriminator(from_old_model=False, model_path=None, device='cpu', D_type="L")
print(generator)
print(discriminator)

Generator_Linear(
  (gen): Sequential(
    (0): Linear(in_features=100, out_features=256, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=512, out_features=1024, bias=True)
    (5): ReLU(inplace=True)
    (6): Linear(in_features=1024, out_features=784, bias=True)
    (7): Tanh()
  )
)
Discriminator_Linear(
  (dis): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Linear(in_features=256, out_features=1, bias=True)
    (7): Sigmoid()
  )
)


In [5]:
generator = get_Generator(from_old_model=False, model_path=None, device='cpu', G_type="C")
discriminator = get_Discriminator(from_old_model=False, model_path=None, device='cpu', D_type="C")
print(generator)
print(discriminator)

Generator_Conv(
  (expand): Sequential(
    (0): Linear(in_features=100, out_features=256, bias=True)
    (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Dropout(p=0.3, inplace=False)
    (3): LeakyReLU(negative_slope=True)
    (4): Linear(in_features=256, out_features=484, bias=True)
    (5): BatchNorm1d(484, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Dropout(p=0.5, inplace=False)
    (7): LeakyReLU(negative_slope=True)
  )
  (gen): Sequential(
    (0): ConvTranspose2d(1, 4, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=True)
    (3): ConvTranspose2d(4, 8, kernel_size=(3, 3), stride=(1, 1))
    (4): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=True)
    (6): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(1, 1))
    (7): BatchNorm2

In [2]:
output_path = 'data'

img_transform = transforms.Compose([
    # 将像素值从 [0, 255] 转换到 [0, 1]
    transforms.ToTensor(),

    # 将像素值从 [0, 1] 转换到 [-1, 1], 将输入的图像按照通道进行标准化
    # MINIST 数据集是灰度图像，只有一个通道，所以 mean 和 std 都是一个数
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

target_transform = transforms.Lambda(lambda y: torch.tensor(y, dtype=torch.float32))

if not os.path.exists(output_path):
    os.makedirs(output_path)

# len = 60000, each element is a tuple of (image, label), image's shape (1, 28, 28)
minist_train = torchvision.datasets.MNIST(
    root=output_path,
    train=True,
    transform=img_transform,
    target_transform=target_transform,
    download=False
)

# len = 10000, each element is a tuple of (image, label), image's shape (1, 28, 28)
minist_test = torchvision.datasets.MNIST(
    root=output_path,
    train=False,
    transform=img_transform,
    target_transform=target_transform,
    download=False
)

In [3]:
batch_size = 64

train_data_loader = data.DataLoader(
    dataset=minist_train,
    batch_size=batch_size,
    shuffle=True,
)

test_data_loader = data.DataLoader(
    dataset=minist_test,
    batch_size=batch_size,
    shuffle=False,
)

for X, y in train_data_loader:
    print(X.shape, y.shape)
    break

torch.Size([64, 1, 28, 28]) torch.Size([64])


In [15]:
import torch.nn as nn
import time
from torchvision.utils import save_image
import random
from torch.autograd import Variable
from model import *
from torch.optim import AdamW

from_old_model = False

img_seed_dim = 256

G_model_path = 'model/G_model.pth'
D_model_path = 'model/D_model.pth'

G_type = 'Linear'

criterion = nn.BCELoss()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'
print('device:', device)

img_output_path = f'output_images/{G_type}'

if not os.path.exists(img_output_path):
    os.makedirs(img_output_path)

device: cuda


In [18]:
G_model = get_Generator(
    from_old_model=from_old_model, model_path=G_model_path, device=device, G_type=G_type
    )

D_model = get_Discriminator(
    from_old_model=from_old_model, model_path=D_model_path, device=device
    )

G_optimizer = AdamW(G_model.parameters(), lr=1e-4, weight_decay=1e-6)
D_optimizer = AdamW(D_model.parameters(), lr=1e-4, weight_decay=1e-6)

In [20]:
train_start = time.time()

epochs = 5

for epoch in range(epochs):
    print(f'Epoch {epoch + 1}/{epochs}')
    batch_num = len(train_data_loader)
    D_loss_sum = torch.Tensor([0.0]).to(device)
    G_loss_sum = torch.Tensor([0.0]).to(device)
    count = torch.Tensor([0]).to(device)

    for index, (images, _) in enumerate(train_data_loader):
        count += 1

        real_images = images.to(device)
        real_labels = (1 - torch.rand(batch_size)/10).to(device)

        fake_images = G_model(torch.randn(batch_size, img_seed_dim).to(device))
        fake_labels = Variable(torch.zeros(batch_size)).to(device)

        D_optimizer.zero_grad()
        real_output = D_model(real_images)
        D_loss_real = criterion(real_output, real_labels)
        fake_output = D_model(fake_images)
        D_loss_fake = criterion(fake_output, fake_labels)

        D_loss = D_loss_real + D_loss_fake
        D_loss_sum += D_loss.item()

        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        fake_images = G_model(torch.randn(batch_size, img_seed_dim).to(device))
        fake_output = D_model(fake_images)

        G_loss = criterion(fake_output, real_labels)
        G_loss_sum += G_loss.item()

        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        print(f'[epoch: {epoch} batch: {index + 1}/{batch_num}] D_loss: {D_loss.item():.6f}, G_loss: {G_loss.item():.6f}')
        break
    
    torch.save(G_model.state_dict(), G_model_path)
    torch.save(D_model.state_dict(), D_model_path)

    fake_images = G_model(torch.randn(64, img_seed_dim).to(device)).cpu().detach()

    fake_images = (fake_images + 1) * 0.5
    fake_images = fake_images.clamp(0, 1)

    fake_images = fake_images.view(-1, 1, 28, 28)
    save_image(fake_images, f'{img_output_path}/epoch_{epoch}.png')

    print(f'Epoch {epoch + 1}/{epochs} D_loss: {(D_loss_sum / count).item():.6f}, G_loss: {(G_loss_sum / count).item():.6f}')
    current_time = time.time()
    print(f'Time: {(current_time - train_start):.6f}')

print('Finished Training')

Epoch 1/5
[epoch: 0 batch: 1/938] D_loss: 0.591601, G_loss: 3.395355
Epoch 1/5 D_loss: 0.591601, G_loss: 3.395355
Time: 0.701381
Epoch 2/5
[epoch: 1 batch: 1/938] D_loss: 0.272522, G_loss: 5.340329
Epoch 2/5 D_loss: 0.272522, G_loss: 5.340329
Time: 1.073398
Epoch 3/5
[epoch: 2 batch: 1/938] D_loss: 0.266933, G_loss: 6.585814
Epoch 3/5 D_loss: 0.266933, G_loss: 6.585814
Time: 1.438436
Epoch 4/5
[epoch: 3 batch: 1/938] D_loss: 0.234293, G_loss: 7.945290
Epoch 4/5 D_loss: 0.234293, G_loss: 7.945290
Time: 1.802925
Epoch 5/5
[epoch: 4 batch: 1/938] D_loss: 0.287280, G_loss: 9.170646
Epoch 5/5 D_loss: 0.287280, G_loss: 9.170646
Time: 2.160757
Finished Training


['G_L_D_L_G_model.pth', 'G_L_D_L_D_model.pth', 'output_images/G_L_D_L']